or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

cnn-models.md · core-interface.md · imagenet-integration.md · index.md · model-selection.md · rnn-models.md

docs/model-selection.md

0

# Model Selection and Utilities

1

2

Tools for programmatically selecting, instantiating, and working with multiple zoo models, including helper classes for building custom architectures.

3

4

## Capabilities

5

6

### ModelSelector

7

8

Utility class for selecting and instantiating multiple zoo models based on type. Provides various overloaded methods for different configuration needs.

9

10

```java { .api }

11

/**

12

* Helper class for selecting multiple models from the zoo.

13

*/

14

class ModelSelector {

15

/**

16

* Select models by type with default configuration

17

* @param zooType Type of models to select

18

* @return Map of ZooType to ZooModel instances

19

*/

20

static Map<ZooType, ZooModel> select(ZooType zooType);

21

22

/**

23

* Select models by type with custom label count

24

* @param zooType Type of models to select

25

* @param numLabels Number of output classes

26

* @return Map of ZooType to ZooModel instances

27

*/

28

static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels);

29

30

/**

31

* Select models by type with workspace mode

32

* @param zooType Type of models to select

33

* @param numLabels Number of output classes

34

* @param workspaceMode Memory workspace configuration

35

* @return Map of ZooType to ZooModel instances

36

*/

37

static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, WorkspaceMode workspaceMode);

38

39

/**

40

* Select models by type with training parameters

41

* @param zooType Type of models to select

42

* @param numLabels Number of output classes

43

* @param seed Random seed for reproducibility

44

* @param iterations Number of training iterations

45

* @return Map of ZooType to ZooModel instances

46

*/

47

static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, int seed, int iterations);

48

49

/**

50

* Select models by type with full parameter control

51

* @param zooType Type of models to select

52

* @param numLabels Number of output classes

53

* @param seed Random seed for reproducibility

54

* @param iterations Number of training iterations

55

* @param workspaceMode Memory workspace configuration

56

* @return Map of ZooType to ZooModel instances

57

*/

58

static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, int seed, int iterations, WorkspaceMode workspaceMode);

59

60

/**

61

* Select specific model types with workspace mode

62

* @param workspaceMode Memory workspace configuration

63

* @param zooTypes Specific model types to select

64

* @return Map of ZooType to ZooModel instances

65

*/

66

static Map<ZooType, ZooModel> select(WorkspaceMode workspaceMode, ZooType... zooTypes);

67

68

/**

69

* Select specific model types with default configuration

70

* @param zooTypes Specific model types to select

71

* @return Map of ZooType to ZooModel instances

72

*/

73

static Map<ZooType, ZooModel> select(ZooType... zooTypes);

74

75

/**

76

* Select specific model types with full parameter control

77

* @param numLabels Number of output classes

78

* @param seed Random seed for reproducibility

79

* @param iterations Number of training iterations

80

* @param workspaceMode Memory workspace configuration

81

* @param zooTypes Specific model types to select

82

* @return Map of ZooType to ZooModel instances

83

*/

84

static Map<ZooType, ZooModel> select(int numLabels, int seed, int iterations, WorkspaceMode workspaceMode, ZooType... zooTypes);

85

}

86

```

87

88

**Usage Examples:**

89

90

```java

91

// Select all CNN models with default settings

92

Map<ZooType, ZooModel> cnnModels = ModelSelector.select(ZooType.CNN);

93

// Returns: AlexNet, VGG16, VGG19, ResNet50, GoogLeNet, LeNet, SimpleCNN

94

95

// Select all models (CNN + RNN)

96

Map<ZooType, ZooModel> allModels = ModelSelector.select(ZooType.ALL);

97

98

// Select specific models

99

Map<ZooType, ZooModel> specificModels = ModelSelector.select(

100

ZooType.ALEXNET,

101

ZooType.VGG16,

102

ZooType.RESNET50

103

);

104

105

// Select with custom configuration

106

Map<ZooType, ZooModel> customModels = ModelSelector.select(

107

ZooType.CNN,

108

10, // 10 classes

109

42, // seed

110

100, // iterations

111

WorkspaceMode.SINGLE

112

);

113

114

// Iterate through selected models

115

for (Map.Entry<ZooType, ZooModel> entry : cnnModels.entrySet()) {

116

ZooType type = entry.getKey();

117

ZooModel model = entry.getValue();

118

119

System.out.println("Model: " + type);

120

Model initializedModel = model.init();

121

ModelMetaData metadata = model.metaData();

122

System.out.println("Input shape: " + Arrays.deepToString(metadata.getInputShape()));

123

}

124

```

125

126

### ZooType Enumeration

127

128

Classification system for different model types and categories.

129

130

```java { .api }

131

/**

132

* Enumerator for choosing different models, and different types of models.

133

*/

134

enum ZooType {

135

/** All available models */

136

ALL,

137

138

/** All CNN models */

139

CNN,

140

141

/** Simple CNN architecture */

142

SIMPLECNN,

143

144

/** AlexNet architecture */

145

ALEXNET,

146

147

/** LeNet architecture */

148

LENET,

149

150

/** GoogLeNet/Inception architecture */

151

GOOGLENET,

152

153

/** VGG16 architecture */

154

VGG16,

155

156

/** VGG19 architecture */

157

VGG19,

158

159

/** ResNet50 architecture */

160

RESNET50,

161

162

/** InceptionResNetV1 architecture */

163

INCEPTIONRESNETV1,

164

165

/** FaceNet NN4 Small2 architecture */

166

FACENETNN4SMALL2,

167

168

/** All RNN models */

169

RNN,

170

171

/** Text generation LSTM */

172

TEXTGENLSTM

173

}

174

```

175

176

**Model Type Hierarchies:**

177

178

```java

179

// CNN models include:

180

ModelSelector.select(ZooType.CNN); // Returns all CNN architectures

181

// - SIMPLECNN, ALEXNET, LENET, GOOGLENET, VGG16, VGG19, RESNET50

182

183

// RNN models include:

184

ModelSelector.select(ZooType.RNN); // Returns all RNN architectures

185

// - TEXTGENLSTM

186

187

// ALL includes both CNN and RNN:

188

ModelSelector.select(ZooType.ALL); // Returns all available models

189

```

190

191

### PretrainedType Enumeration

192

193

Types of pre-trained model weights available for supported models.

194

195

```java { .api }

196

/**

197

* Enumerator for choosing different pre-trained weight types.

198

*/

199

enum PretrainedType {

200

/** ImageNet dataset pre-trained weights (1000 classes) */

201

IMAGENET,

202

203

/** MNIST dataset pre-trained weights (10 digit classes) */

204

MNIST,

205

206

/** CIFAR-10 dataset pre-trained weights (10 object classes) */

207

CIFAR10,

208

209

/** VGGFace dataset pre-trained weights (face recognition) */

210

VGGFACE

211

}

212

```

213

214

**Pre-trained Weight Availability:**

215

216

```java

217

VGG16 vgg16 = new VGG16(1000, 42, 1);

218

219

// Check which pre-trained weights are available

220

boolean hasImageNet = vgg16.pretrainedAvailable(PretrainedType.IMAGENET); // true

221

boolean hasCIFAR10 = vgg16.pretrainedAvailable(PretrainedType.CIFAR10); // true

222

boolean hasVGGFace = vgg16.pretrainedAvailable(PretrainedType.VGGFACE); // true

223

boolean hasMNIST = vgg16.pretrainedAvailable(PretrainedType.MNIST); // false

224

225

// Load specific pre-trained weights

226

Model imageNetModel = vgg16.initPretrained(PretrainedType.IMAGENET);

227

Model cifar10Model = vgg16.initPretrained(PretrainedType.CIFAR10);

228

```

229

230

### Helper Classes

231

232

#### FaceNetHelper

233

234

Utility class for building Inception-style layers used in FaceNet and other advanced architectures.

235

236

```java { .api }

237

/**

238

* Helper class for building Inception-style modules used in FaceNet models.

239

* Provides pre-configured layers and graph building utilities.

240

*/

241

class FaceNetHelper {

242

/**

243

* Returns base module name for inception layers

244

* @return "inception"

245

*/

246

static String getModuleName();

247

248

/**

249

* Returns namespaced module name

250

* @param layerName Name of the specific layer

251

* @return Formatted module name

252

*/

253

static String getModuleName(String layerName);

254

255

/**

256

* Creates 1x1 convolution layer

257

* @param in Number of input channels

258

* @param out Number of output channels

259

* @param bias Bias initialization value

260

* @return ConvolutionLayer configured as 1x1 convolution

261

*/

262

static ConvolutionLayer conv1x1(int in, int out, double bias);

263

264

/**

265

* Creates 3x3 convolution layer

266

* @param in Number of input channels

267

* @param out Number of output channels

268

* @param bias Bias initialization value

269

* @return ConvolutionLayer configured as 3x3 convolution

270

*/

271

static ConvolutionLayer conv3x3(int in, int out, double bias);

272

273

/**

274

* Creates 5x5 convolution layer

275

* @param in Number of input channels

276

* @param out Number of output channels

277

* @param bias Bias initialization value

278

* @return ConvolutionLayer configured as 5x5 convolution

279

*/

280

static ConvolutionLayer conv5x5(int in, int out, double bias);

281

282

/**

283

* Creates 7x7 convolution layer

284

* @param in Number of input channels

285

* @param out Number of output channels

286

* @param bias Bias initialization value

287

* @return ConvolutionLayer configured as 7x7 convolution

288

*/

289

static ConvolutionLayer conv7x7(int in, int out, double bias);

290

291

/**

292

* Creates average pooling layer

293

* @param size Pool size (NxN)

294

* @param stride Stride for pooling

295

* @return SubsamplingLayer configured for average pooling

296

*/

297

static SubsamplingLayer avgPoolNxN(int size, int stride);

298

299

/**

300

* Creates max pooling layer

301

* @param size Pool size (NxN)

302

* @param stride Stride for pooling

303

* @return SubsamplingLayer configured for max pooling

304

*/

305

static SubsamplingLayer maxPoolNxN(int size, int stride);

306

307

/**

308

* Creates p-norm pooling layer

309

* @param pNorm P-norm value

310

* @param size Pool size (NxN)

311

* @param stride Stride for pooling

312

* @return SubsamplingLayer configured for p-norm pooling

313

*/

314

static SubsamplingLayer pNormNxN(int pNorm, int size, int stride);

315

316

/**

317

* Creates fully connected (dense) layer

318

* @param in Number of input units

319

* @param out Number of output units

320

* @param dropOut Dropout rate

321

* @return DenseLayer with specified configuration

322

*/

323

static DenseLayer fullyConnected(int in, int out, double dropOut);

324

325

/**

326

* Creates batch normalization layer

327

* @param in Number of input channels

328

* @param out Number of output channels

329

* @return BatchNormalization layer

330

*/

331

static BatchNormalization batchNorm(int in, int out);

332

333

/**

334

* Appends complete Inception module to a computation graph with default parameters

335

* @param graph Existing graph builder

336

* @param moduleLayerName Name for this inception module

337

* @param inputSize Number of input channels

338

* @param kernelSize Array of kernel sizes for different paths

339

* @param kernelStride Array of strides for different paths

340

* @param outputSize Array of output sizes for different paths

341

* @param reduceSize Array of reduction sizes for different paths

342

* @param poolingType Type of pooling to use

343

* @param transferFunction Activation function

344

* @param inputLayer Name of input layer to connect to

345

* @return Updated GraphBuilder with inception module added

346

*/

347

static ComputationGraphConfiguration.GraphBuilder appendGraph(

348

ComputationGraphConfiguration.GraphBuilder graph,

349

String moduleLayerName,

350

int inputSize,

351

int[] kernelSize,

352

int[] kernelStride,

353

int[] outputSize,

354

int[] reduceSize,

355

SubsamplingLayer.PoolingType poolingType,

356

Activation transferFunction,

357

String inputLayer

358

);

359

360

/**

361

* Appends complete Inception module to a computation graph with p-norm pooling

362

* @param graph Existing graph builder

363

* @param moduleLayerName Name for this inception module

364

* @param inputSize Number of input channels

365

* @param kernelSize Array of kernel sizes for different paths

366

* @param kernelStride Array of strides for different paths

367

* @param outputSize Array of output sizes for different paths

368

* @param reduceSize Array of reduction sizes for different paths

369

* @param poolingType Type of pooling to use

370

* @param pNorm P-norm value (if using p-norm pooling)

371

* @param transferFunction Activation function

372

* @param inputLayer Name of input layer to connect to

373

* @return Updated GraphBuilder with inception module added

374

*/

375

static ComputationGraphConfiguration.GraphBuilder appendGraph(

376

ComputationGraphConfiguration.GraphBuilder graph,

377

String moduleLayerName,

378

int inputSize,

379

int[] kernelSize,

380

int[] kernelStride,

381

int[] outputSize,

382

int[] reduceSize,

383

SubsamplingLayer.PoolingType poolingType,

384

int pNorm,

385

Activation transferFunction,

386

String inputLayer

387

);

388

389

/**

390

* Appends complete Inception module to a computation graph with custom pooling parameters

391

* @param graph Existing graph builder

392

* @param moduleLayerName Name for this inception module

393

* @param inputSize Number of input channels

394

* @param kernelSize Array of kernel sizes for different paths

395

* @param kernelStride Array of strides for different paths

396

* @param outputSize Array of output sizes for different paths

397

* @param reduceSize Array of reduction sizes for different paths

398

* @param poolingType Type of pooling to use

399

* @param poolSize Size of pooling window

400

* @param poolStride Stride for pooling

401

* @param transferFunction Activation function

402

* @param inputLayer Name of input layer to connect to

403

* @return Updated GraphBuilder with inception module added

404

*/

405

static ComputationGraphConfiguration.GraphBuilder appendGraph(

406

ComputationGraphConfiguration.GraphBuilder graph,

407

String moduleLayerName,

408

int inputSize,

409

int[] kernelSize,

410

int[] kernelStride,

411

int[] outputSize,

412

int[] reduceSize,

413

SubsamplingLayer.PoolingType poolingType,

414

int poolSize,

415

int poolStride,

416

Activation transferFunction,

417

String inputLayer

418

);

419

420

/**

421

* Appends complete Inception module to a computation graph with full parameter control

422

* @param graph Existing graph builder

423

* @param moduleLayerName Name for this inception module

424

* @param inputSize Number of input channels

425

* @param kernelSize Array of kernel sizes for different paths

426

* @param kernelStride Array of strides for different paths

427

* @param outputSize Array of output sizes for different paths

428

* @param reduceSize Array of reduction sizes for different paths

429

* @param poolingType Type of pooling to use

430

* @param pNorm P-norm value (if using p-norm pooling)

431

* @param poolSize Size of pooling window

432

* @param poolStride Stride for pooling

433

* @param transferFunction Activation function

434

* @param inputLayer Name of input layer to connect to

435

* @return Updated GraphBuilder with inception module added

436

*/

437

static ComputationGraphConfiguration.GraphBuilder appendGraph(

438

ComputationGraphConfiguration.GraphBuilder graph,

439

String moduleLayerName,

440

int inputSize,

441

int[] kernelSize,

442

int[] kernelStride,

443

int[] outputSize,

444

int[] reduceSize,

445

SubsamplingLayer.PoolingType poolingType,

446

int pNorm,

447

int poolSize,

448

int poolStride,

449

Activation transferFunction,

450

String inputLayer

451

);

452

}

453

```

454

455

**Usage Example:**

456

457

```java

458

// Building custom architecture with Inception modules

459

ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder()

460

.graphBuilder()

461

.addInputs("input");

462

463

// Add custom Inception module

464

graph = FaceNetHelper.appendGraph(

465

graph,

466

"inception_1", // module name

467

64, // input channels

468

new int[]{3, 5}, // kernel sizes

469

new int[]{1, 1}, // strides

470

new int[]{128, 64}, // output sizes

471

new int[]{32, 16, 8}, // reduction sizes

472

SubsamplingLayer.PoolingType.MAX,

473

0, // p-norm (not used for MAX pooling)

474

3, // pool size

475

1, // pool stride

476

Activation.RELU, // activation

477

"input" // input layer name

478

);

479

```

480

481

#### InceptionResNetHelper

482

483

Helper class for building Inception-ResNet architectures that combine Inception modules with residual connections.

484

485

```java { .api }

486

/**

487

* Helper class for building Inception-ResNet modules that combine residual shortcuts

488

* with Inception-style networks. Based on the Inception-ResNet paper.

489

*/

490

class InceptionResNetHelper {

491

/**

492

* Creates layer name with block and iteration naming

493

* @param blockName Name of the inception block

494

* @param layerName Name of the specific layer

495

* @param i Iteration/block number

496

* @return Formatted layer name

497

*/

498

static String nameLayer(String blockName, String layerName, int i);

499

500

/**

501

* Appends Inception-ResNet A blocks to a computation graph

502

* @param graph Existing graph builder

503

* @param blockName Name for this inception block

504

* @param scale Number of blocks to add

505

* @param activationScale Scaling factor for activations

506

* @param input Name of input layer to connect to

507

* @return Updated GraphBuilder with Inception-ResNet A blocks added

508

*/

509

static ComputationGraphConfiguration.GraphBuilder inceptionV1ResA(

510

ComputationGraphConfiguration.GraphBuilder graph,

511

String blockName,

512

int scale,

513

double activationScale,

514

String input

515

);

516

517

/**

518

* Appends Inception-ResNet B blocks to a computation graph

519

* @param graph Existing graph builder

520

* @param blockName Name for this inception block

521

* @param scale Number of blocks to add

522

* @param activationScale Scaling factor for activations

523

* @param input Name of input layer to connect to

524

* @return Updated GraphBuilder with Inception-ResNet B blocks added

525

*/

526

static ComputationGraphConfiguration.GraphBuilder inceptionV1ResB(

527

ComputationGraphConfiguration.GraphBuilder graph,

528

String blockName,

529

int scale,

530

double activationScale,

531

String input

532

);

533

534

/**

535

* Appends Inception-ResNet C blocks to a computation graph

536

* @param graph Existing graph builder

537

* @param blockName Name for this inception block

538

* @param scale Number of blocks to add

539

* @param activationScale Scaling factor for activations

540

* @param input Name of input layer to connect to

541

* @return Updated GraphBuilder with Inception-ResNet C blocks added

542

*/

543

static ComputationGraphConfiguration.GraphBuilder inceptionV1ResC(

544

ComputationGraphConfiguration.GraphBuilder graph,

545

String blockName,

546

int scale,

547

double activationScale,

548

String input

549

);

550

}

551

```

552

553

**Usage Example:**

554

555

```java

556

// Building InceptionResNet architecture

557

ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder()

558

.graphBuilder()

559

.addInputs("input");

560

561

// Add Inception-ResNet A blocks

562

graph = InceptionResNetHelper.inceptionV1ResA(

563

graph,

564

"resnet_a", // block name

565

3, // number of blocks

566

0.1, // activation scaling

567

"input" // input layer

568

);

569

```