# Pre-trained Models (Deprecated)

**Note**: This functionality is deprecated. For new projects, use the `deeplearning4j-zoo` module, which provides a more comprehensive and maintained set of pre-trained models.

Legacy support for popular pre-trained image classification models, specifically VGG16 variants with ImageNet weights.

## TrainedModels Enum

The `TrainedModels` enum provides access to pre-trained models with automatic downloading and setup.

```java { .api }
public enum TrainedModels {
    VGG16,      // VGG16 with ImageNet weights and classification head
    VGG16NOTOP; // VGG16 with ImageNet weights, no classification head

    // Get the complete model as ComputationGraph
    public ComputationGraph getComputationGraph() throws IOException;

    // Get appropriate preprocessor for the model
    public DataSetPreProcessor getDataSetPreProcessor();

    // Get ImageNet class labels
    public ArrayList<String> getLabels();

    // Get expected input shape
    public int[] getInputShape();
}
```

## Available Models

### VGG16

Complete VGG16 model with ImageNet weights and classification head.

```java { .api }
TrainedModels.VGG16
```

**Specifications:**

- **Input Shape**: 224 × 224 × 3 (RGB images)
- **Output**: 1000 classes (ImageNet categories)
- **Weights**: Pre-trained on ImageNet dataset
- **Architecture**: 16-layer VGG network with classification head

### VGG16NOTOP

VGG16 model without the final classification layers, useful for feature extraction and transfer learning.

```java { .api }
TrainedModels.VGG16NOTOP
```

**Specifications:**

- **Input Shape**: 224 × 224 × 3 (RGB images)
- **Output**: Feature maps (7 × 7 × 512) from the last pooling layer
- **Weights**: Pre-trained on ImageNet dataset
- **Architecture**: VGG16 without final dense layers
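
The 7 × 7 spatial size of the VGG16NOTOP output follows from the architecture. VGG16's 3 × 3 convolutions use padding 1, so they preserve height and width. Each of its five 2 × 2, stride-2 max-pooling layers halves them, giving 224 / 2^5 = 7. The sketch below reproduces that arithmetic in plain Java. It assumes the standard VGG16 block configuration rather than reading it from the imported model, and the class name is hypothetical.

```java
import java.util.Arrays;

public class Vgg16Shapes {
    // Standard VGG16 conv blocks: output channels per block (assumed, not read from the model)
    static final int[] BLOCK_CHANNELS = {64, 128, 256, 512, 512};

    // Output size of a 2x2, stride-2 max-pooling layer with no padding
    static int poolOut(int size) {
        return (size - 2) / 2 + 1;
    }

    // Returns {channels, height, width} after the final pooling layer
    static int[] noTopOutputShape(int inputHeight, int inputWidth) {
        int c = 3; // RGB input
        int h = inputHeight;
        int w = inputWidth;
        for (int channels : BLOCK_CHANNELS) {
            // 3x3 convolutions with padding 1 keep h and w unchanged;
            // only the pooling layer at the end of each block shrinks them
            c = channels;
            h = poolOut(h);
            w = poolOut(w);
        }
        return new int[] {c, h, w};
    }

    public static void main(String[] args) {
        int[] shape = noTopOutputShape(224, 224);
        int flat = shape[0] * shape[1] * shape[2];
        System.out.println(Arrays.toString(shape) + " -> " + flat);
    }
}
```

Running `main` prints `[512, 7, 7] -> 25088`, which is the number of values a flattened VGG16NOTOP feature vector contains.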

## Usage Examples

### Basic Image Classification

```java
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;

import java.io.File;

// Load VGG16 model (downloaded and cached on first use)
ComputationGraph vgg16 = TrainedModels.VGG16.getComputationGraph();

// Get appropriate preprocessor
DataSetPreProcessor preprocessor = TrainedModels.VGG16.getDataSetPreProcessor();

// Load the image, resized to 224x224 with 3 channels, and preprocess it
NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
INDArray imageArray = loader.asMatrix(new File("image.jpg"));
DataSet ds = new DataSet(imageArray, null);
preprocessor.preProcess(ds);
INDArray input = ds.getFeatures();

// Make prediction
INDArray output = vgg16.outputSingle(input);

// Get predicted class
int predictedClass = Nd4j.argMax(output, 1).getInt(0);
String predictedLabel = TrainedModels.VGG16.getLabels().get(predictedClass);

System.out.println("Predicted: " + predictedLabel);
```

### Feature Extraction

```java
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;

// Load VGG16 without top layers for feature extraction
ComputationGraph featureExtractor = TrainedModels.VGG16NOTOP.getComputationGraph();

// Extract features from image
// preprocessedImage: an image loaded and preprocessed as in the Data Preprocessing section
INDArray imageArray = preprocessedImage;
INDArray features = featureExtractor.outputSingle(imageArray);

// Features shape: [1, 512, 7, 7] (DL4J uses channels-first layout);
// flatten to a 25088-element vector for use in other models
INDArray flatFeatures = features.reshape(1, 7 * 7 * 512);
```
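
`reshape(1, 7 * 7 * 512)` keeps every activation and only changes how the activations are indexed. This sketch assumes two things: the features are laid out as [1, 512, 7, 7], and the reshape uses row-major ('c') order, ND4J's default. Under those assumptions, element (c, h, w) lands at position c·49 + h·7 + w of the flat vector. The plain-Java code below makes that mapping explicit. The class and method names are hypothetical.

```java
public class FlattenIndex {
    static final int CHANNELS = 512;
    static final int HEIGHT = 7;
    static final int WIDTH = 7;

    // Position of activation (c, h, w) in the flattened vector, row-major order
    static int flatIndex(int c, int h, int w) {
        return (c * HEIGHT + h) * WIDTH + w;
    }

    // Flattens a [channels][height][width] array in the same row-major order
    static double[] flatten(double[][][] features) {
        double[] flat = new double[features.length * features[0].length * features[0][0].length];
        int i = 0;
        for (double[][] channel : features) {
            for (double[] row : channel) {
                for (double value : row) {
                    flat[i++] = value;
                }
            }
        }
        return flat;
    }

    public static void main(String[] args) {
        System.out.println(flatIndex(0, 0, 1));   // prints 1 (next column)
        System.out.println(flatIndex(1, 0, 0));   // prints 49 (next channel starts after 7 * 7 values)
        System.out.println(flatIndex(511, 6, 6)); // prints 25087 (last element)
    }
}
```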

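### Transfer Learning Setup

```java
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

int numCustomClasses = 5; // placeholder: number of classes in your own dataset

// Load feature extractor
ComputationGraph baseModel = TrainedModels.VGG16NOTOP.getComputationGraph();

// Freeze the VGG16 layers and add a custom classification head with DL4J's
// TransferLearning API. "block5_pool" is assumed to be the name of the final
// pooling layer in the imported graph; check the layer names of your model.
ComputationGraph customModel = new TransferLearning.GraphBuilder(baseModel)
        .setFeatureExtractor("block5_pool") // freezes this layer and everything before it
        .addLayer("custom_dense", new DenseLayer.Builder()
                        .nIn(7 * 7 * 512)
                        .nOut(128)
                        .activation(Activation.RELU)
                        .build(),
                new CnnToFeedForwardPreProcessor(7, 7, 512), // flattens the conv output
                "block5_pool")
        .addLayer("output", new OutputLayer.Builder()
                .nIn(128)
                .nOut(numCustomClasses)
                .activation(Activation.SOFTMAX)
                .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .build(), "custom_dense")
        .setOutputs("output")
        .build(); // returns an initialized graph with the pre-trained weights copied in
```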

## ImageNet Labels

Access to ImageNet class labels for interpreting model predictions.

### ImageNetLabels Utility Class

```java { .api }
public class ImageNetLabels {
    // Get all 1000 ImageNet class labels
    public static ArrayList<String> getLabels();

    // Get specific class label by index (0-999)
    public static String getLabel(int n);
}
```

### Usage Examples

```java
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.Utils.ImageNetLabels;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.ArrayList;

// Get all labels
ArrayList<String> allLabels = ImageNetLabels.getLabels();
System.out.println("Total classes: " + allLabels.size());

// Get specific label
String label = ImageNetLabels.getLabel(281); // the "tabby cat" class
System.out.println("Class 281: " + label);

// Top-k predictions
// model: a loaded VGG16 ComputationGraph; input: a preprocessed image
INDArray predictions = model.outputSingle(input);
int[] topK = getTopKIndices(predictions, 5); // Custom method to get top-5 indices

System.out.println("Top 5 predictions:");
for (int i = 0; i < topK.length; i++) {
    double confidence = predictions.getDouble(topK[i]);
    String className = ImageNetLabels.getLabel(topK[i]);
    System.out.println((i + 1) + ". " + className + " (" + confidence + ")");
}
```
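
`getTopKIndices` is not part of the library; the example leaves it to you. Below is one possible plain-Java sketch. It assumes the scores have already been copied into a `double[]`, for example with `predictions.data().asDouble()` for a single-example output, so the call becomes `getTopKIndices(predictions.data().asDouble(), 5)`. The sketch sorts indices by descending score and keeps the first k.

```java
import java.util.Arrays;
import java.util.Comparator;

public class TopK {
    // Returns the indices of the k largest scores, highest score first
    static int[] getTopKIndices(double[] scores, int k) {
        Integer[] indices = new Integer[scores.length];
        for (int i = 0; i < scores.length; i++) {
            indices[i] = i;
        }
        // Sort indices by descending score (stable sort, so ties keep index order)
        Arrays.sort(indices, Comparator.comparingDouble((Integer idx) -> scores[idx]).reversed());
        int n = Math.min(k, scores.length);
        int[] top = new int[n];
        for (int i = 0; i < n; i++) {
            top[i] = indices[i];
        }
        return top;
    }

    public static void main(String[] args) {
        double[] scores = {0.05, 0.60, 0.10, 0.25};
        System.out.println(Arrays.toString(getTopKIndices(scores, 2))); // prints [1, 3]
    }
}
```

Sorting all 1000 indices is simple and fast enough for a single prediction. For very large label sets, a bounded priority queue of size k would avoid the full sort.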

## Data Preprocessing

### VGG16ImagePreProcessor

The VGG16 models require specific preprocessing to match training conditions.

```java { .api }
// VGG16 preprocessing (from nd4j-dataset-api)
public class VGG16ImagePreProcessor implements DataSetPreProcessor {
    // Applies VGG16-specific preprocessing:
    // - Converts RGB to BGR
    // - Subtracts ImageNet mean values
    // - Scales appropriately
}
```
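
The central step, per-channel mean subtraction, can be illustrated on a plain Java array. This is not the ND4J implementation. It assumes a channels-first [3][height][width] image in RGB order and uses the widely cited ImageNet channel means 123.68, 116.779 and 103.939 (R, G, B). The class's exact constants and channel handling may differ, so treat this as a sketch of the idea. The class name is hypothetical.

```java
public class MeanSubtraction {
    // Widely cited ImageNet per-channel means in R, G, B order (assumed values)
    static final double[] IMAGENET_MEAN_RGB = {123.68, 116.779, 103.939};

    // Subtracts the channel mean from every pixel of a [3][height][width] image, in place
    static void subtractMean(double[][][] image) {
        for (int c = 0; c < image.length; c++) {
            for (double[] row : image[c]) {
                for (int x = 0; x < row.length; x++) {
                    row[x] -= IMAGENET_MEAN_RGB[c];
                }
            }
        }
    }

    public static void main(String[] args) {
        // A 1x1 "image" whose pixel equals the mean becomes all zeros
        double[][][] pixel = {{{123.68}}, {{116.779}}, {{103.939}}};
        subtractMean(pixel);
        System.out.println(pixel[0][0][0] + " " + pixel[1][0][0] + " " + pixel[2][0][0]);
    }
}
```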

### Custom Preprocessing Pipeline

```java
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;

import java.io.File;

// Image loading and preprocessing pipeline
NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
VGG16ImagePreProcessor preprocessor = new VGG16ImagePreProcessor();

// Load the image, resized to 224x224 with 3 channels
INDArray imageArray = loader.asMatrix(new File("input.jpg"));

// Apply preprocessing
DataSet ds = new DataSet(imageArray, null);
preprocessor.preProcess(ds);
INDArray processedImage = ds.getFeatures();

// Now ready for VGG16 inference
// vgg16Model: a ComputationGraph loaded with TrainedModels.VGG16.getComputationGraph()
INDArray prediction = vgg16Model.outputSingle(processedImage);
```

## Model Caching and Downloads

### Automatic Model Management

The `TrainedModels` enum automatically handles:

- Model downloading from remote URLs
- Local caching in the `~/.dl4j/trainedmodels/` directory
- Version management and updates
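
Download-on-first-use follows a common caching pattern. The code resolves the file's location in the cache, fetches the file only if it is missing, and reuses it afterwards. The sketch below shows that pattern in plain Java. It is not the `TrainedModels` implementation. The `Downloader` interface and `ModelCache` class are hypothetical stand-ins for whatever transport actually fetches the remote file.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class ModelCache {
    // Hypothetical hook that writes a remote file to the given local path
    interface Downloader {
        void download(String fileName, Path target) throws IOException;
    }

    // Returns the cached file, downloading it only when it is not already present
    static Path ensureCached(Path cacheDir, String fileName, Downloader downloader) throws IOException {
        Path target = cacheDir.resolve(fileName);
        if (!Files.exists(target)) {
            Files.createDirectories(target.getParent());
            downloader.download(fileName, target);
        }
        return target;
    }

    public static void main(String[] args) throws IOException {
        Path cacheDir = Files.createTempDirectory("trainedmodels");
        // Fake downloader that just writes a placeholder byte
        Downloader fake = (name, target) -> {
            System.out.println("downloading " + name);
            Files.write(target, new byte[] {0});
        };
        ensureCached(cacheDir, "vgg16/vgg16.json", fake); // downloads
        ensureCached(cacheDir, "vgg16/vgg16.json", fake); // reuses the cached copy
    }
}
```

Injecting the downloader keeps the cache logic testable without network access; `main` prints the "downloading" line only once.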

### Cache Structure

```
~/.dl4j/trainedmodels/
├── vgg16/
│   ├── vgg16.json              # Model architecture
│   └── vgg16_weights.h5        # Pre-trained weights
└── vgg16notop/
    ├── vgg16notop.json         # Model architecture
    └── vgg16notop_weights.h5   # Pre-trained weights
```
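
Given that layout, checking whether a model is already cached amounts to checking for its two files. This minimal sketch assumes the file names follow the `<model>/<model>.json` and `<model>/<model>_weights.h5` pattern shown above. The helper names are hypothetical.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class CacheLayout {
    // Default cache root: ~/.dl4j/trainedmodels
    static Path defaultRoot() {
        return Paths.get(System.getProperty("user.home"), ".dl4j", "trainedmodels");
    }

    // Architecture file, e.g. vgg16/vgg16.json
    static Path architectureFile(Path root, String model) {
        return root.resolve(model).resolve(model + ".json");
    }

    // Weights file, e.g. vgg16/vgg16_weights.h5
    static Path weightsFile(Path root, String model) {
        return root.resolve(model).resolve(model + "_weights.h5");
    }

    // True when both files for the model are present
    static boolean isCached(Path root, String model) {
        return Files.isRegularFile(architectureFile(root, model))
                && Files.isRegularFile(weightsFile(root, model));
    }

    public static void main(String[] args) {
        for (String model : new String[] {"vgg16", "vgg16notop"}) {
            System.out.println(model + " cached: " + isCached(defaultRoot(), model));
        }
    }
}
```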

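### Manual Cache Management

```java
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.stream.Stream;

// Models are automatically downloaded on first use
// Cache location: System.getProperty("user.home") + "/.dl4j/trainedmodels/"

// To clear the cache, delete the directory (models are re-downloaded on next use)
File cacheDir = new File(System.getProperty("user.home"), ".dl4j/trainedmodels");
if (cacheDir.exists()) {
    // Walk the tree and delete children before their parent directories
    try (Stream<Path> paths = Files.walk(cacheDir.toPath())) {
        paths.sorted(Comparator.reverseOrder()).forEach(p -> p.toFile().delete());
    }
}
```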

## Migration to DL4J Zoo

### Recommended Migration Path

Instead of using the deprecated `TrainedModels` enum, use the `deeplearning4j-zoo` module:

```xml
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-zoo</artifactId>
    <version>0.9.1</version>
</dependency>
```

```java
// New approach with DL4J Zoo
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.VGG16;

// Load VGG16 with ImageNet weights from the zoo
VGG16 vgg16 = VGG16.builder().build();
ComputationGraph model = (ComputationGraph) vgg16.initPretrained(PretrainedType.IMAGENET);
```

### Benefits of DL4J Zoo

- **More Models**: ResNet, AlexNet, LeNet, etc.
- **Better Maintenance**: Active development and updates
- **Improved APIs**: More consistent and feature-rich
- **Better Documentation**: Comprehensive examples and guides
- **Performance**: Optimized implementations

## Limitations

### Deprecated Status

- No new features or models will be added
- Limited bug fixes and maintenance
- May be removed in future versions

### Model Limitations

- Only VGG16 variants available
- Fixed input sizes (224×224 for VGG16)
- ImageNet-specific preprocessing requirements
- Limited customization options

### Compatibility Issues

- May not work with newer Keras/TensorFlow versions
- Requires HDF5 support to read the `.h5` weight files
- Network connectivity required for the initial download