# Model Selection and Utilities

Tools for programmatically selecting, instantiating, and working with multiple zoo models, including helper classes for building custom architectures.

## Capabilities

### ModelSelector

Utility class for selecting and instantiating multiple zoo models by type. Provides overloaded methods for different configuration needs.
```java { .api }
/**
 * Helper class for selecting multiple models from the zoo.
 */
class ModelSelector {

    /**
     * Select models by type with default configuration
     * @param zooType Type of models to select
     * @return Map of ZooType to ZooModel instances
     */
    static Map<ZooType, ZooModel> select(ZooType zooType);

    /**
     * Select models by type with custom label count
     * @param zooType Type of models to select
     * @param numLabels Number of output classes
     * @return Map of ZooType to ZooModel instances
     */
    static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels);

    /**
     * Select models by type with workspace mode
     * @param zooType Type of models to select
     * @param numLabels Number of output classes
     * @param workspaceMode Memory workspace configuration
     * @return Map of ZooType to ZooModel instances
     */
    static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, WorkspaceMode workspaceMode);

    /**
     * Select models by type with training parameters
     * @param zooType Type of models to select
     * @param numLabels Number of output classes
     * @param seed Random seed for reproducibility
     * @param iterations Number of training iterations
     * @return Map of ZooType to ZooModel instances
     */
    static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, int seed, int iterations);

    /**
     * Select models by type with full parameter control
     * @param zooType Type of models to select
     * @param numLabels Number of output classes
     * @param seed Random seed for reproducibility
     * @param iterations Number of training iterations
     * @param workspaceMode Memory workspace configuration
     * @return Map of ZooType to ZooModel instances
     */
    static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, int seed, int iterations, WorkspaceMode workspaceMode);

    /**
     * Select specific model types with workspace mode
     * @param workspaceMode Memory workspace configuration
     * @param zooTypes Specific model types to select
     * @return Map of ZooType to ZooModel instances
     */
    static Map<ZooType, ZooModel> select(WorkspaceMode workspaceMode, ZooType... zooTypes);

    /**
     * Select specific model types with default configuration
     * @param zooTypes Specific model types to select
     * @return Map of ZooType to ZooModel instances
     */
    static Map<ZooType, ZooModel> select(ZooType... zooTypes);

    /**
     * Select specific model types with full parameter control
     * @param numLabels Number of output classes
     * @param seed Random seed for reproducibility
     * @param iterations Number of training iterations
     * @param workspaceMode Memory workspace configuration
     * @param zooTypes Specific model types to select
     * @return Map of ZooType to ZooModel instances
     */
    static Map<ZooType, ZooModel> select(int numLabels, int seed, int iterations, WorkspaceMode workspaceMode, ZooType... zooTypes);
}
```

**Usage Examples:**

```java
// Select all CNN models with default settings
Map<ZooType, ZooModel> cnnModels = ModelSelector.select(ZooType.CNN);
// Returns: AlexNet, VGG16, VGG19, ResNet50, GoogLeNet, LeNet, SimpleCNN

// Select all models (CNN + RNN)
Map<ZooType, ZooModel> allModels = ModelSelector.select(ZooType.ALL);

// Select specific models
Map<ZooType, ZooModel> specificModels = ModelSelector.select(
        ZooType.ALEXNET,
        ZooType.VGG16,
        ZooType.RESNET50
);

// Select with custom configuration
Map<ZooType, ZooModel> customModels = ModelSelector.select(
        ZooType.CNN,
        10,  // 10 classes
        42,  // seed
        100, // iterations
        WorkspaceMode.SINGLE
);

// Iterate through selected models
for (Map.Entry<ZooType, ZooModel> entry : cnnModels.entrySet()) {
    ZooType type = entry.getKey();
    ZooModel model = entry.getValue();

    System.out.println("Model: " + type);
    Model initializedModel = model.init();
    ModelMetaData metadata = model.metaData();
    System.out.println("Input shape: " + Arrays.deepToString(metadata.getInputShape()));
}
```

### ZooType Enumeration

Classification system for different model types and categories.

```java { .api }
/**
 * Enumerator for choosing different models, and different types of models.
 */
enum ZooType {
    /** All available models */
    ALL,

    /** All CNN models */
    CNN,

    /** Simple CNN architecture */
    SIMPLECNN,

    /** AlexNet architecture */
    ALEXNET,

    /** LeNet architecture */
    LENET,

    /** GoogLeNet/Inception architecture */
    GOOGLENET,

    /** VGG16 architecture */
    VGG16,

    /** VGG19 architecture */
    VGG19,

    /** ResNet50 architecture */
    RESNET50,

    /** InceptionResNetV1 architecture */
    INCEPTIONRESNETV1,

    /** FaceNet NN4 Small2 architecture */
    FACENETNN4SMALL2,

    /** All RNN models */
    RNN,

    /** Text generation LSTM */
    TEXTGENLSTM
}
```

**Model Type Hierarchies:**

```java
// CNN models include:
ModelSelector.select(ZooType.CNN); // Returns all CNN architectures
// - SIMPLECNN, ALEXNET, LENET, GOOGLENET, VGG16, VGG19, RESNET50

// RNN models include:
ModelSelector.select(ZooType.RNN); // Returns all RNN architectures
// - TEXTGENLSTM

// ALL includes both CNN and RNN:
ModelSelector.select(ZooType.ALL); // Returns all available models
```

### PretrainedType Enumeration

Types of pre-trained model weights available for supported models.

```java { .api }
/**
 * Enumerator for choosing different pre-trained weight types.
 */
enum PretrainedType {
    /** ImageNet dataset pre-trained weights (1000 classes) */
    IMAGENET,

    /** MNIST dataset pre-trained weights (10 digit classes) */
    MNIST,

    /** CIFAR-10 dataset pre-trained weights (10 object classes) */
    CIFAR10,

    /** VGGFace dataset pre-trained weights (face recognition) */
    VGGFACE
}
```

**Pre-trained Weight Availability:**

```java
VGG16 vgg16 = new VGG16(1000, 42, 1);

// Check which pre-trained weights are available
boolean hasImageNet = vgg16.pretrainedAvailable(PretrainedType.IMAGENET); // true
boolean hasCIFAR10 = vgg16.pretrainedAvailable(PretrainedType.CIFAR10);   // true
boolean hasVGGFace = vgg16.pretrainedAvailable(PretrainedType.VGGFACE);   // true
boolean hasMNIST = vgg16.pretrainedAvailable(PretrainedType.MNIST);       // false

// Load specific pre-trained weights
Model imageNetModel = vgg16.initPretrained(PretrainedType.IMAGENET);
Model cifar10Model = vgg16.initPretrained(PretrainedType.CIFAR10);
```
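
Since not every architecture ships weights for every dataset, it can help to guard `initPretrained` behind an availability check. A minimal sketch of this pattern, continuing from the `vgg16` instance above (error handling omitted; loading may also fail if the weight download is interrupted):

```java
// Guard pretrained loading: fall back to random initialization
// when the requested weight set is unavailable for this model
PretrainedType wanted = PretrainedType.MNIST; // VGG16 has no MNIST weights
Model model = vgg16.pretrainedAvailable(wanted)
        ? vgg16.initPretrained(wanted) // downloads/caches the weight file
        : vgg16.init();                // randomly initialized fallback
```

The same check works inside a `ModelSelector` loop, skipping models that lack the desired weight type instead of failing mid-iteration.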

### Helper Classes

#### FaceNetHelper

Utility class for building Inception-style layers used in FaceNet and other advanced architectures.

```java { .api }
/**
 * Helper class for building Inception-style modules used in FaceNet models.
 * Provides pre-configured layers and graph building utilities.
 */
class FaceNetHelper {
    /**
     * Returns base module name for inception layers
     * @return "inception"
     */
    static String getModuleName();

    /**
     * Returns namespaced module name
     * @param layerName Name of the specific layer
     * @return Formatted module name
     */
    static String getModuleName(String layerName);

    /**
     * Creates 1x1 convolution layer
     * @param in Number of input channels
     * @param out Number of output channels
     * @param bias Bias initialization value
     * @return ConvolutionLayer configured as 1x1 convolution
     */
    static ConvolutionLayer conv1x1(int in, int out, double bias);

    /**
     * Creates 3x3 convolution layer
     * @param in Number of input channels
     * @param out Number of output channels
     * @param bias Bias initialization value
     * @return ConvolutionLayer configured as 3x3 convolution
     */
    static ConvolutionLayer conv3x3(int in, int out, double bias);

    /**
     * Creates 5x5 convolution layer
     * @param in Number of input channels
     * @param out Number of output channels
     * @param bias Bias initialization value
     * @return ConvolutionLayer configured as 5x5 convolution
     */
    static ConvolutionLayer conv5x5(int in, int out, double bias);

    /**
     * Creates 7x7 convolution layer
     * @param in Number of input channels
     * @param out Number of output channels
     * @param bias Bias initialization value
     * @return ConvolutionLayer configured as 7x7 convolution
     */
    static ConvolutionLayer conv7x7(int in, int out, double bias);

    /**
     * Creates average pooling layer
     * @param size Pool size (NxN)
     * @param stride Stride for pooling
     * @return SubsamplingLayer configured for average pooling
     */
    static SubsamplingLayer avgPoolNxN(int size, int stride);

    /**
     * Creates max pooling layer
     * @param size Pool size (NxN)
     * @param stride Stride for pooling
     * @return SubsamplingLayer configured for max pooling
     */
    static SubsamplingLayer maxPoolNxN(int size, int stride);

    /**
     * Creates p-norm pooling layer
     * @param pNorm P-norm value
     * @param size Pool size (NxN)
     * @param stride Stride for pooling
     * @return SubsamplingLayer configured for p-norm pooling
     */
    static SubsamplingLayer pNormNxN(int pNorm, int size, int stride);

    /**
     * Creates fully connected (dense) layer
     * @param in Number of input units
     * @param out Number of output units
     * @param dropOut Dropout rate
     * @return DenseLayer with specified configuration
     */
    static DenseLayer fullyConnected(int in, int out, double dropOut);

    /**
     * Creates batch normalization layer
     * @param in Number of input channels
     * @param out Number of output channels
     * @return BatchNormalization layer
     */
    static BatchNormalization batchNorm(int in, int out);

    /**
     * Appends complete Inception module to a computation graph with default parameters
     * @param graph Existing graph builder
     * @param moduleLayerName Name for this inception module
     * @param inputSize Number of input channels
     * @param kernelSize Array of kernel sizes for different paths
     * @param kernelStride Array of strides for different paths
     * @param outputSize Array of output sizes for different paths
     * @param reduceSize Array of reduction sizes for different paths
     * @param poolingType Type of pooling to use
     * @param transferFunction Activation function
     * @param inputLayer Name of input layer to connect to
     * @return Updated GraphBuilder with inception module added
     */
    static ComputationGraphConfiguration.GraphBuilder appendGraph(
            ComputationGraphConfiguration.GraphBuilder graph,
            String moduleLayerName,
            int inputSize,
            int[] kernelSize,
            int[] kernelStride,
            int[] outputSize,
            int[] reduceSize,
            SubsamplingLayer.PoolingType poolingType,
            Activation transferFunction,
            String inputLayer
    );

    /**
     * Appends complete Inception module to a computation graph with p-norm pooling
     * @param graph Existing graph builder
     * @param moduleLayerName Name for this inception module
     * @param inputSize Number of input channels
     * @param kernelSize Array of kernel sizes for different paths
     * @param kernelStride Array of strides for different paths
     * @param outputSize Array of output sizes for different paths
     * @param reduceSize Array of reduction sizes for different paths
     * @param poolingType Type of pooling to use
     * @param pNorm P-norm value (if using p-norm pooling)
     * @param transferFunction Activation function
     * @param inputLayer Name of input layer to connect to
     * @return Updated GraphBuilder with inception module added
     */
    static ComputationGraphConfiguration.GraphBuilder appendGraph(
            ComputationGraphConfiguration.GraphBuilder graph,
            String moduleLayerName,
            int inputSize,
            int[] kernelSize,
            int[] kernelStride,
            int[] outputSize,
            int[] reduceSize,
            SubsamplingLayer.PoolingType poolingType,
            int pNorm,
            Activation transferFunction,
            String inputLayer
    );

    /**
     * Appends complete Inception module to a computation graph with custom pooling parameters
     * @param graph Existing graph builder
     * @param moduleLayerName Name for this inception module
     * @param inputSize Number of input channels
     * @param kernelSize Array of kernel sizes for different paths
     * @param kernelStride Array of strides for different paths
     * @param outputSize Array of output sizes for different paths
     * @param reduceSize Array of reduction sizes for different paths
     * @param poolingType Type of pooling to use
     * @param poolSize Size of pooling window
     * @param poolStride Stride for pooling
     * @param transferFunction Activation function
     * @param inputLayer Name of input layer to connect to
     * @return Updated GraphBuilder with inception module added
     */
    static ComputationGraphConfiguration.GraphBuilder appendGraph(
            ComputationGraphConfiguration.GraphBuilder graph,
            String moduleLayerName,
            int inputSize,
            int[] kernelSize,
            int[] kernelStride,
            int[] outputSize,
            int[] reduceSize,
            SubsamplingLayer.PoolingType poolingType,
            int poolSize,
            int poolStride,
            Activation transferFunction,
            String inputLayer
    );

    /**
     * Appends complete Inception module to a computation graph with full parameter control
     * @param graph Existing graph builder
     * @param moduleLayerName Name for this inception module
     * @param inputSize Number of input channels
     * @param kernelSize Array of kernel sizes for different paths
     * @param kernelStride Array of strides for different paths
     * @param outputSize Array of output sizes for different paths
     * @param reduceSize Array of reduction sizes for different paths
     * @param poolingType Type of pooling to use
     * @param pNorm P-norm value (if using p-norm pooling)
     * @param poolSize Size of pooling window
     * @param poolStride Stride for pooling
     * @param transferFunction Activation function
     * @param inputLayer Name of input layer to connect to
     * @return Updated GraphBuilder with inception module added
     */
    static ComputationGraphConfiguration.GraphBuilder appendGraph(
            ComputationGraphConfiguration.GraphBuilder graph,
            String moduleLayerName,
            int inputSize,
            int[] kernelSize,
            int[] kernelStride,
            int[] outputSize,
            int[] reduceSize,
            SubsamplingLayer.PoolingType poolingType,
            int pNorm,
            int poolSize,
            int poolStride,
            Activation transferFunction,
            String inputLayer
    );
}
```

**Usage Example:**

```java
// Building custom architecture with Inception modules
ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("input");

// Add custom Inception module
graph = FaceNetHelper.appendGraph(
        graph,
        "inception_1",        // module name
        64,                   // input channels
        new int[]{3, 5},      // kernel sizes
        new int[]{1, 1},      // strides
        new int[]{128, 64},   // output sizes
        new int[]{32, 16, 8}, // reduction sizes
        SubsamplingLayer.PoolingType.MAX,
        0,                    // p-norm (not used for MAX pooling)
        3,                    // pool size
        1,                    // pool stride
        Activation.RELU,      // activation
        "input"               // input layer name
);
```
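
The individual layer factories can also be used on their own when wiring a graph by hand, without going through `appendGraph`. A hedged sketch (the layer names, channel counts, and dropout rate here are illustrative, not taken from any zoo model):

```java
// Bottleneck-then-expand pattern built from the FaceNetHelper factories:
// a 1x1 reduction, a 3x3 convolution, a max pool, and a dense head
ComputationGraphConfiguration.GraphBuilder custom = new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("input")
        .addLayer("reduce", FaceNetHelper.conv1x1(64, 32, 0.2), "input")  // 1x1 bottleneck
        .addLayer("conv",   FaceNetHelper.conv3x3(32, 128, 0.2), "reduce") // 3x3 convolution
        .addLayer("pool",   FaceNetHelper.maxPoolNxN(3, 2), "conv")        // 3x3 max pool, stride 2
        .addLayer("dense",  FaceNetHelper.fullyConnected(128, 256, 0.5), "pool"); // dense + dropout
```

This keeps the kernel-size and pooling boilerplate out of the graph definition while leaving the topology explicit.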

#### InceptionResNetHelper

Helper class for building Inception-ResNet architectures that combine Inception modules with residual connections.

```java { .api }
/**
 * Helper class for building Inception-ResNet modules that combine residual shortcuts
 * with Inception-style networks. Based on the Inception-ResNet paper.
 */
class InceptionResNetHelper {
    /**
     * Creates layer name with block and iteration naming
     * @param blockName Name of the inception block
     * @param layerName Name of the specific layer
     * @param i Iteration/block number
     * @return Formatted layer name
     */
    static String nameLayer(String blockName, String layerName, int i);

    /**
     * Appends Inception-ResNet A blocks to a computation graph
     * @param graph Existing graph builder
     * @param blockName Name for this inception block
     * @param scale Number of blocks to add
     * @param activationScale Scaling factor for activations
     * @param input Name of input layer to connect to
     * @return Updated GraphBuilder with Inception-ResNet A blocks added
     */
    static ComputationGraphConfiguration.GraphBuilder inceptionV1ResA(
            ComputationGraphConfiguration.GraphBuilder graph,
            String blockName,
            int scale,
            double activationScale,
            String input
    );

    /**
     * Appends Inception-ResNet B blocks to a computation graph
     * @param graph Existing graph builder
     * @param blockName Name for this inception block
     * @param scale Number of blocks to add
     * @param activationScale Scaling factor for activations
     * @param input Name of input layer to connect to
     * @return Updated GraphBuilder with Inception-ResNet B blocks added
     */
    static ComputationGraphConfiguration.GraphBuilder inceptionV1ResB(
            ComputationGraphConfiguration.GraphBuilder graph,
            String blockName,
            int scale,
            double activationScale,
            String input
    );

    /**
     * Appends Inception-ResNet C blocks to a computation graph
     * @param graph Existing graph builder
     * @param blockName Name for this inception block
     * @param scale Number of blocks to add
     * @param activationScale Scaling factor for activations
     * @param input Name of input layer to connect to
     * @return Updated GraphBuilder with Inception-ResNet C blocks added
     */
    static ComputationGraphConfiguration.GraphBuilder inceptionV1ResC(
            ComputationGraphConfiguration.GraphBuilder graph,
            String blockName,
            int scale,
            double activationScale,
            String input
    );
}
```

**Usage Example:**

```java
// Building InceptionResNet architecture
ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("input");

// Add Inception-ResNet A blocks
graph = InceptionResNetHelper.inceptionV1ResA(
        graph,
        "resnet_a", // block name
        3,          // number of blocks
        0.1,        // activation scaling
        "input"     // input layer
);
```
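
The B and C builders follow the same signature, so a full Inception-ResNet body can chain all three block groups. A hedged sketch (block counts and activation scales are illustrative, loosely following the Inception-ResNet paper; the input name passed to each later group is assumed to be the output vertex of the previous group, whose exact name depends on what `nameLayer` emits):

```java
// Chain A, B, and C block groups; each group consumes the previous group's output
graph = InceptionResNetHelper.inceptionV1ResA(graph, "block_a", 5, 0.17, "input");
graph = InceptionResNetHelper.inceptionV1ResB(graph, "block_b", 10, 0.10, "block_a"); // assumed output name
graph = InceptionResNetHelper.inceptionV1ResC(graph, "block_c", 5, 0.20, "block_b"); // assumed output name
```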