# Pre-trained Models (Deprecated)

**Note**: This functionality is deprecated. For new projects, use the `deeplearning4j-zoo` module, which provides a larger and actively maintained set of pre-trained models.

This module provides legacy support for popular pre-trained image classification models, specifically VGG16 variants with ImageNet weights.

## TrainedModels Enum

The `TrainedModels` enum provides access to pre-trained models, downloading and setting them up automatically on first use.

```java { .api }
public enum TrainedModels {
    VGG16,      // VGG16 with ImageNet weights and classification head
    VGG16NOTOP; // VGG16 with ImageNet weights, no classification head

    // Get the complete model as ComputationGraph
    public ComputationGraph getComputationGraph() throws IOException;

    // Get appropriate preprocessor for the model
    public DataSetPreProcessor getDataSetPreProcessor();

    // Get ImageNet class labels
    public ArrayList<String> getLabels();

    // Get expected input shape
    public int[] getInputShape();
}
```

## Available Models

### VGG16

Complete VGG16 model with ImageNet weights and classification head.

```java { .api }
TrainedModels.VGG16
```

**Specifications:**
- **Input Shape**: 224 × 224 × 3 (RGB images)
- **Output**: 1000 classes (ImageNet categories)
- **Weights**: Pre-trained on the ImageNet dataset
- **Architecture**: 16 weight layers (13 convolutional, 3 fully connected), including the classification head
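
To make the layer count concrete, the plain-Java sketch below tallies weights and biases for the standard VGG16 layout (13 convolutional layers with 3 × 3 kernels, then 3 fully connected layers). It is arithmetic only, not a DL4J API, and it assumes the imported model follows that standard layout; if it does, the totals should agree with `numParams()` on the loaded `ComputationGraph`.

```java
public class Vgg16Params {
    // {input channels, output channels} of the 13 convolutional layers (3x3 kernels)
    static final int[][] CONV = {
        {3, 64}, {64, 64},                  // block 1
        {64, 128}, {128, 128},              // block 2
        {128, 256}, {256, 256}, {256, 256}, // block 3
        {256, 512}, {512, 512}, {512, 512}, // block 4
        {512, 512}, {512, 512}, {512, 512}  // block 5
    };

    // {inputs, outputs} of the 3 fully connected layers in the classification head
    static final int[][] DENSE = {
        {7 * 7 * 512, 4096}, {4096, 4096}, {4096, 1000}
    };

    // Total number of weights and biases, with or without the classification head
    static long totalParams(boolean includeTop) {
        long total = 0;
        for (int[] conv : CONV) {
            total += 3L * 3 * conv[0] * conv[1] + conv[1]; // 3x3 kernels + one bias per filter
        }
        if (includeTop) {
            for (int[] dense : DENSE) {
                total += (long) dense[0] * dense[1] + dense[1]; // weight matrix + biases
            }
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println("VGG16:      " + totalParams(true));  // 138357544
        System.out.println("VGG16NOTOP: " + totalParams(false)); // 14714688
    }
}
```

About 124 million of the roughly 138 million parameters sit in the three fully connected layers, which is why the VGG16NOTOP variant is much smaller.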

### VGG16NOTOP

VGG16 model without the final classification layers, useful for feature extraction and transfer learning.

```java { .api }
TrainedModels.VGG16NOTOP
```

**Specifications:**
- **Input Shape**: 224 × 224 × 3 (RGB images)
- **Output**: Feature maps with a 7 × 7 spatial size and 512 channels
- **Weights**: Pre-trained on the ImageNet dataset
- **Architecture**: The 13 convolutional layers of VGG16, without the 3 final fully connected layers
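
The 7 × 7 × 512 output size follows from the architecture: VGG16's 3 × 3 convolutions use stride 1 with "same" padding, so only its five 2 × 2, stride-2 max-pooling layers shrink the image. A plain-Java sketch of that arithmetic (a hypothetical helper, not a DL4J API):

```java
public class Vgg16FeatureShape {
    // Height/width after a number of 2x2, stride-2 max-pooling layers.
    // VGG16's 3x3 convolutions (stride 1, "same" padding) keep the size unchanged.
    static int spatialSize(int inputSize, int poolingLayers) {
        int size = inputSize;
        for (int i = 0; i < poolingLayers; i++) {
            size /= 2; // each pooling layer halves height and width
        }
        return size;
    }

    public static void main(String[] args) {
        int side = spatialSize(224, 5); // 224 -> 112 -> 56 -> 28 -> 14 -> 7
        int channels = 512;             // filters in the last convolutional block
        // prints "7 x 7 x 512 = 25088"
        System.out.println(side + " x " + side + " x " + channels + " = " + side * side * channels);
    }
}
```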

## Usage Examples

### Basic Image Classification

```java
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;

// Load VGG16 (downloaded and cached on first use; throws IOException)
ComputationGraph vgg16 = TrainedModels.VGG16.getComputationGraph();

// Get appropriate preprocessor
DataSetPreProcessor preprocessor = TrainedModels.VGG16.getDataSetPreProcessor();

// Load the image resized to 224 x 224 with 3 channels: shape [1, 3, 224, 224]
NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
INDArray imageArray = loader.asMatrix(new File("image.jpg"));

// Preprocess; preProcess modifies the DataSet's features in place
DataSet ds = new DataSet(imageArray, null);
preprocessor.preProcess(ds);
INDArray input = ds.getFeatures();

// Make prediction: output has shape [1, 1000]
INDArray output = vgg16.outputSingle(input);

// Get predicted class (index of the highest score)
int predictedClass = Nd4j.argMax(output, 1).getInt(0);
String predictedLabel = TrainedModels.VGG16.getLabels().get(predictedClass);

System.out.println("Predicted: " + predictedLabel);
```

### Feature Extraction

```java
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;

// Load VGG16 without top layers for feature extraction
ComputationGraph featureExtractor = TrainedModels.VGG16NOTOP.getComputationGraph();

// processedImage: an image loaded and preprocessed as in the classification
// example above, with shape [1, 3, 224, 224]
INDArray features = featureExtractor.outputSingle(processedImage);

// DL4J uses a channels-first layout, so features has shape [1, 512, 7, 7];
// flatten it to a [1, 25088] vector for use in other models
INDArray flatFeatures = features.reshape(1, 7 * 7 * 512);
```
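
To relate positions in the flattened vector back to the feature map, the sketch below computes the flat index of element `(c, h, w)`. It assumes DL4J's channels-first layout, in which the features have shape `[1, 512, 7, 7]`, and row-major ("c") ordering, ND4J's default; `flatIndex` is a hypothetical helper.

```java
public class FlattenIndex {
    // Position of element (c, h, w) of a [channels, height, width] feature map
    // in the flattened vector, assuming row-major ("c") order
    static int flatIndex(int c, int h, int w, int height, int width) {
        return (c * height + h) * width + w;
    }

    public static void main(String[] args) {
        // Channel 1 starts right after the 7 * 7 values of channel 0
        System.out.println(flatIndex(1, 0, 0, 7, 7));   // 49
        // The last element of a [512, 7, 7] map is the last entry of the 25088-long vector
        System.out.println(flatIndex(511, 6, 6, 7, 7)); // 25087
    }
}
```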
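
### Transfer Learning Setup

In this setup, VGG16NOTOP serves as a frozen feature extractor: it is only used for inference, and a new classification head is trained on its flattened output.

```java
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

int numCustomClasses = 5; // hypothetical: number of classes in your own dataset

// Frozen base model: used only for inference (feature extraction),
// so its pre-trained weights are never updated
ComputationGraph baseModel = TrainedModels.VGG16NOTOP.getComputationGraph();

// New classification head whose input is the flattened VGG16NOTOP output
ComputationGraphConfiguration.GraphBuilder confBuilder = new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("vgg16_features");

// Add custom classification layers
confBuilder.addLayer("custom_dense", new DenseLayer.Builder()
        .nIn(7 * 7 * 512)
        .nOut(128)
        .activation(Activation.RELU)
        .build(), "vgg16_features");

confBuilder.addLayer("output", new OutputLayer.Builder()
        .nIn(128)
        .nOut(numCustomClasses)
        .activation(Activation.SOFTMAX)
        .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .build(), "custom_dense");

confBuilder.setOutputs("output");

// Build and initialize the head; train it on flattened features from baseModel
ComputationGraph customModel = new ComputationGraph(confBuilder.build());
customModel.init();
```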

## ImageNet Labels

Access to ImageNet class labels for interpreting model predictions.

### ImageNetLabels Utility Class

```java { .api }
public class ImageNetLabels {
    // Get all 1000 ImageNet class labels
    public static ArrayList<String> getLabels();

    // Get specific class label by index (0-999)
    public static String getLabel(int n);
}
```

### Usage Examples

```java
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.Utils.ImageNetLabels;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;

// Get all labels
ArrayList<String> allLabels = ImageNetLabels.getLabels();
System.out.println("Total classes: " + allLabels.size());

// Get a specific label; index 281 is the ImageNet "tabby cat" class
String label = ImageNetLabels.getLabel(281);
System.out.println("Class 281: " + label);

// Top-k predictions
// model: a VGG16 ComputationGraph; input: a preprocessed [1, 3, 224, 224] image
INDArray predictions = model.outputSingle(input); // shape [1, 1000]
int[] topK = getTopKIndices(predictions, 5); // custom helper (not part of DL4J)

System.out.println("Top 5 predictions:");
for (int i = 0; i < topK.length; i++) {
    double confidence = predictions.getDouble(topK[i]);
    String className = ImageNetLabels.getLabel(topK[i]);
    System.out.println((i + 1) + ". " + className + " (" + confidence + ")");
}
```
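
The `getTopKIndices` helper used above is not part of DL4J. One possible sketch works on a plain `double[]` of scores; to use it with the `[1, 1000]` prediction row, first copy the scores into an array (for example, `scores[i] = predictions.getDouble(i)` for each of the 1000 classes).

```java
import java.util.Arrays;

public class TopK {
    // Hypothetical helper (not part of DL4J): indices of the k largest scores, highest first
    static int[] getTopKIndices(double[] scores, int k) {
        if (k < 0 || k > scores.length) {
            throw new IllegalArgumentException("k must be between 0 and scores.length");
        }
        Integer[] order = new Integer[scores.length];
        for (int i = 0; i < scores.length; i++) {
            order[i] = i;
        }
        // Sort indices by descending score; the sort is stable, so ties keep the lower index first
        Arrays.sort(order, (a, b) -> Double.compare(scores[b], scores[a]));
        int[] top = new int[k];
        for (int i = 0; i < k; i++) {
            top[i] = order[i];
        }
        return top;
    }

    public static void main(String[] args) {
        double[] scores = {0.05, 0.60, 0.10, 0.25};
        System.out.println(Arrays.toString(getTopKIndices(scores, 2))); // [1, 3]
    }
}
```

Sorting all 1000 indices is simple and fast enough at this size; a bounded min-heap is the usual alternative when the output is much larger.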

## Data Preprocessing

### VGG16ImagePreProcessor

The VGG16 models require the same preprocessing that was applied to their training data.

```java { .api }
// VGG16 preprocessing (package org.nd4j.linalg.dataset.api.preprocessor)
public class VGG16ImagePreProcessor implements DataSetPreProcessor {
    // Applies VGG16-specific preprocessing:
    // - Subtracts the ImageNet mean pixel value of each channel (BGR order),
    //   modifying the features in place
    // - Does not reorder channels or rescale values: input is expected in BGR
    //   order with raw [0, 255] pixel values, as produced by NativeImageLoader
}
```
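
To show what per-channel mean subtraction does to pixel values, here is a plain-Java sketch on a channels-first `double[channel][row][col]` array. It is not the ND4J implementation, which operates on `INDArray`s. The mean values (103.939, 116.779, 123.68 in BGR order) are the ones commonly used with VGG16; treat them as an assumption and check the ND4J source for your version.

```java
public class VggMeanSubtraction {
    // Commonly used ImageNet mean pixel values for VGG16, in BGR order (assumed)
    static final double[] MEAN_BGR = {103.939, 116.779, 123.68};

    // image is channels-first [channel][row][col], BGR order, values in [0, 255].
    // Subtracts each channel's mean in place; no rescaling is applied.
    static void subtractMean(double[][][] image) {
        for (int c = 0; c < image.length; c++) {
            for (double[] row : image[c]) {
                for (int x = 0; x < row.length; x++) {
                    row[x] -= MEAN_BGR[c];
                }
            }
        }
    }

    public static void main(String[] args) {
        // A single mid-grey pixel (value 128 in every channel)
        double[][][] pixel = {{{128}}, {{128}}, {{128}}};
        subtractMean(pixel);
        System.out.printf("B=%.3f G=%.3f R=%.3f%n", pixel[0][0][0], pixel[1][0][0], pixel[2][0][0]);
    }
}
```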

### Custom Preprocessing Pipeline

```java
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import java.io.File;

// Image loading and preprocessing pipeline: resize to 224 x 224 with 3 channels
NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
VGG16ImagePreProcessor preprocessor = new VGG16ImagePreProcessor();

// Load image as a [1, 3, 224, 224] array (throws IOException)
INDArray imageArray = loader.asMatrix(new File("input.jpg"));

// Apply preprocessing (modifies the DataSet's features in place)
DataSet ds = new DataSet(imageArray, null);
preprocessor.preProcess(ds);
INDArray processedImage = ds.getFeatures();

// Now ready for VGG16 inference; vgg16Model is a ComputationGraph
// loaded with TrainedModels.VGG16.getComputationGraph()
INDArray prediction = vgg16Model.outputSingle(processedImage);
```

## Model Caching and Downloads

### Automatic Model Management

The `TrainedModels` enum automatically handles:
- Model downloading from remote URLs
- Local caching in the `~/.dl4j/trainedmodels/` directory
- Version management and updates

### Cache Structure

```
~/.dl4j/trainedmodels/
├── vgg16/
│   ├── vgg16.json              # Model architecture
│   └── vgg16_weights.h5        # Pre-trained weights
└── vgg16notop/
    ├── vgg16notop.json         # Model architecture
    └── vgg16notop_weights.h5   # Pre-trained weights
```
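
A minimal sketch for checking whether a model is already cached, for example before running without network access. The file names are taken from the tree above and may differ between releases; `cacheRoot` and `isCached` are hypothetical helpers, not DL4J APIs.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class Vgg16CacheCheck {
    // Cache root: ~/.dl4j/trainedmodels/
    static Path cacheRoot(String userHome) {
        return Paths.get(userHome, ".dl4j", "trainedmodels");
    }

    // True if both the architecture (.json) and the weights (.h5) of a model are present,
    // following the <model>/<model>.json and <model>/<model>_weights.h5 layout
    static boolean isCached(Path root, String modelDir) {
        Path dir = root.resolve(modelDir);
        Path json = dir.resolve(modelDir + ".json");
        Path weights = dir.resolve(modelDir + "_weights.h5");
        return Files.isRegularFile(json) && Files.isRegularFile(weights);
    }

    public static void main(String[] args) {
        Path root = cacheRoot(System.getProperty("user.home"));
        for (String model : new String[] {"vgg16", "vgg16notop"}) {
            System.out.println(model + " cached: " + isCached(root, model));
        }
    }
}
```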

### Manual Cache Management

```java
import java.io.File;
import org.apache.commons.io.FileUtils;

// Models are downloaded automatically on first use and cached under
// System.getProperty("user.home") + "/.dl4j/trainedmodels/"
File cacheDir = new File(System.getProperty("user.home"), ".dl4j/trainedmodels");

// To clear the cache, delete the directory (uses Apache Commons IO;
// throws IOException on failure). Models are downloaded again on next use.
if (cacheDir.exists()) {
    FileUtils.deleteDirectory(cacheDir);
}
```

## Migration to DL4J Zoo

### Recommended Migration Path

Instead of using the deprecated `TrainedModels` enum, use the `deeplearning4j-zoo` module:

```xml
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-zoo</artifactId>
    <version>0.9.1</version>
</dependency>
```

```java
// New approach with DL4J Zoo
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.VGG16;

// Load VGG16 from the zoo; initPretrained downloads the ImageNet weights
// if they are not already cached and throws IOException on failure
VGG16 vgg16 = VGG16.builder().build();
ComputationGraph model = (ComputationGraph) vgg16.initPretrained(PretrainedType.IMAGENET);
```

### Benefits of DL4J Zoo

- **More Models**: ResNet, AlexNet, LeNet, etc.
- **Better Maintenance**: Active development and updates
- **Improved APIs**: More consistent and feature-rich
- **Better Documentation**: Comprehensive examples and guides
- **Performance**: Optimized implementations

## Limitations

### Deprecated Status
- No new features or models will be added
- Limited bug fixes and maintenance
- May be removed in future versions

### Model Limitations
- Only VGG16 variants available
- Fixed input sizes (224×224 for VGG16)
- ImageNet-specific preprocessing requirements
- Limited customization options

### Compatibility Issues
- May not work with newer Keras/TensorFlow versions
- Requires HDF5 support to read the `.h5` weight files
- Network connectivity required for initial download