# ImageNet Integration

Utilities for working with ImageNet classifications, including label decoding and interpretation of predictions from models trained on the ImageNet dataset.

## Capabilities

### ImageNetLabels

Utility class for decoding ImageNet predictions into human-readable labels. It downloads the ImageNet class labels from a remote JSON file, caches them, and provides methods for interpreting model predictions.

```java { .api }
/**
 * Helper class with methods for returning ImageNet label descriptions and
 * decoding prediction arrays to human-readable format.
 */
class ImageNetLabels {
    /**
     * Creates an ImageNetLabels instance and loads the label data.
     * Downloads the class labels from the remote JSON file if they are not already cached.
     */
    ImageNetLabels();

    /**
     * Returns the description of the nth of the 1000 ImageNet classes.
     * @param n Class index (0-999)
     * @return String description of the ImageNet class
     */
    String getLabel(int n);

    /**
     * Decodes a prediction array into the top 5 matches with their probabilities.
     * Given predictions from a trained model, returns a formatted string
     * listing the top five matches and their respective probabilities.
     * @param predictions INDArray of model predictions, shape [batchSize, 1000]
     * @return Formatted string with the top 5 predictions and probabilities
     */
    String decodePredictions(INDArray predictions);
}
```
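
The top-5 selection described for `decodePredictions` can be sketched in plain Java. This is a minimal illustration, not the DL4J implementation: it assumes a single softmax probability vector as a `float[]` and the labels as a `List<String>` ordered by class index, and `TopKDecoder`, `topK`, and `format` are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

/** Hypothetical helper illustrating top-k decoding of one prediction vector. */
class TopKDecoder {

    /** Returns the indices of the k largest probabilities, highest first. */
    static int[] topK(float[] probs, int k) {
        int n = Math.min(k, probs.length);
        int[] best = new int[n];
        boolean[] taken = new boolean[probs.length];
        for (int slot = 0; slot < n; slot++) {
            int argMax = -1;
            for (int i = 0; i < probs.length; i++) {
                if (!taken[i] && (argMax < 0 || probs[i] > probs[argMax])) {
                    argMax = i;
                }
            }
            taken[argMax] = true; // exclude this class from later passes
            best[slot] = argMax;
        }
        return best;
    }

    /** Formats the top-k matches as "percentage%, label" lines. */
    static String format(float[] probs, List<String> labels, int k) {
        StringBuilder sb = new StringBuilder();
        for (int idx : topK(probs, k)) {
            sb.append(String.format(Locale.ROOT, "%.3f%%, %s\n", probs[idx] * 100, labels.get(idx)));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<String> labels = Arrays.asList("tabby", "tiger_cat", "Persian_cat", "Egyptian_cat");
        float[] probs = {0.10f, 0.05f, 0.25f, 0.60f};
        System.out.print(format(probs, labels, 2));
    }
}
```

Repeated argmax passes cost O(k·n), which is cheap for k = 5 and n = 1000 and avoids sorting the whole vector.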

**Usage Examples:**

```java
// Create the ImageNet labels helper
ImageNetLabels imageNetLabels = new ImageNetLabels();

// Get specific class labels
String label283 = imageNetLabels.getLabel(283); // "Persian_cat"
String label285 = imageNetLabels.getLabel(285); // "Egyptian_cat"

// Load a pretrained VGG16; it is a ComputationGraph, so cast the returned Model
// (initPretrained throws IOException)
VGG16 vgg16 = new VGG16(1000, 42, 1);
ComputationGraph model = (ComputationGraph) vgg16.initPretrained(PretrainedType.IMAGENET);

// Assuming imageInput holds preprocessed image data as an INDArray
INDArray predictions = model.outputSingle(imageInput);

// Decode to human-readable format
String topPredictions = imageNetLabels.decodePredictions(predictions);
// Output format (probabilities are illustrative):
// Predictions for batch :
// 85.234%, Egyptian_cat
// 12.456%, Persian_cat
// 1.789%, tabby
// 0.345%, tiger_cat
// 0.176%, lynx
```

**Complete Image Classification Example:**

```java
import java.io.IOException;

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.VGG16;
import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels;
import org.nd4j.linalg.api.ndarray.INDArray;

public class ClassifyImage {
    public static void main(String[] args) throws IOException {
        // 1. Load the pretrained ImageNet model (VGG16 is a ComputationGraph)
        VGG16 vgg16 = new VGG16(1000, 42, 1);
        ComputationGraph model = (ComputationGraph) vgg16.initPretrained(PretrainedType.IMAGENET);

        // 2. Create the ImageNet labels decoder
        ImageNetLabels imageNetLabels = new ImageNetLabels();

        // 3. Preprocess your image (not shown; requires image loading and preprocessing)
        INDArray imageInput = preprocessImage(args[0]);

        // 4. Get model predictions
        INDArray predictions = model.outputSingle(imageInput);

        // 5. Decode predictions to readable format
        System.out.println(imageNetLabels.decodePredictions(predictions));

        // 6. Get the label of the highest-scoring class for the first image
        int topClass = predictions.argMax(1).getInt(0);
        System.out.println("Top prediction: " + imageNetLabels.getLabel(topClass));
    }

    // Hypothetical placeholder: load the image and convert it to the model's input format
    private static INDArray preprocessImage(String imagePath) {
        throw new UnsupportedOperationException("image loading and preprocessing not shown");
    }
}
```
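
In step 6, `argMax(1)` reduces along the class dimension (dimension 1), giving one index per image, and `getInt(0)` takes the first image's index. The plain-Java sketch below mirrors that per-row reduction over a `float[][]` of shape `[batchSize][numClasses]`; `RowArgMax` is a hypothetical name, and ND4J's own tie-breaking behavior is not assumed.

```java
/** Hypothetical plain-Java equivalent of argMax along dimension 1 (one result per row). */
class RowArgMax {

    /** For a [batchSize][numClasses] matrix, returns the column index of each row's maximum. */
    static int[] argMaxPerRow(float[][] predictions) {
        int[] result = new int[predictions.length];
        for (int row = 0; row < predictions.length; row++) {
            int best = 0;
            for (int col = 1; col < predictions[row].length; col++) {
                if (predictions[row][col] > predictions[row][best]) {
                    best = col; // strict '>' means ties keep the lower index
                }
            }
            result[row] = best;
        }
        return result;
    }

    public static void main(String[] args) {
        float[][] batch = {
            {0.1f, 0.7f, 0.2f}, // image 0: class 1 scores highest
            {0.5f, 0.2f, 0.3f}  // image 1: class 0 scores highest
        };
        int[] top = argMaxPerRow(batch);
        for (int i = 0; i < top.length; i++) {
            System.out.println("Image " + i + " -> class " + top[i]);
        }
    }
}
```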

**Batch Processing Example:**

```java
ImageNetLabels labels = new ImageNetLabels();

// Process multiple images at once; model is a pretrained ComputationGraph such as VGG16
// (cast from initPretrained's return value), and batchInput stacks the preprocessed
// images along dimension 0
INDArray batchPredictions = model.outputSingle(batchInput); // Shape: [batchSize, 1000]

// Decode the entire batch; the output lists results for each image
String batchResults = labels.decodePredictions(batchPredictions);
System.out.println(batchResults);
// Output format for a batch (probabilities are illustrative):
// Predictions for batch 0 :
// 85.234%, Egyptian_cat
// ...
// Predictions for batch 1 :
// 67.891%, golden_retriever
// ...
```

**Label Index Reference:**

```java
ImageNetLabels labels = new ImageNetLabels();

// Examples of common ImageNet classes:
String label0 = labels.getLabel(0);     // "tench"
String label1 = labels.getLabel(1);     // "goldfish"
String label151 = labels.getLabel(151); // "Chihuahua"
String label285 = labels.getLabel(285); // "Egyptian_cat"
String label945 = labels.getLabel(945); // "bell_pepper"

// ImageNet has 1000 classes (indices 0-999)
for (int i = 0; i < 1000; i++) {
    String className = labels.getLabel(i);
    System.out.println("Class " + i + ": " + className);
}
```
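
`getLabel` maps an index to a name; the reverse lookup is not part of the API shown here. A minimal sketch, assuming the 1000 names have been collected into a `List<String>` (for example by calling `getLabel(i)` for `i` from 0 to 999), builds a name-to-index map once. `LabelIndex` is a hypothetical class, and because label names are not guaranteed to be unique it keeps the first index seen for a repeated name.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical reverse index from class name to class index. */
class LabelIndex {
    private final Map<String, Integer> indexByName = new HashMap<>();

    /** labels.get(i) is assumed to be the name of class i, as returned by getLabel(i). */
    LabelIndex(List<String> labels) {
        for (int i = 0; i < labels.size(); i++) {
            indexByName.putIfAbsent(labels.get(i), i); // keep the first index if a name repeats
        }
    }

    /** Returns the class index for a name, or -1 if the name is unknown. */
    int indexOf(String name) {
        return indexByName.getOrDefault(name, -1);
    }

    public static void main(String[] args) {
        LabelIndex index = new LabelIndex(Arrays.asList("tench", "goldfish", "great_white_shark"));
        System.out.println(index.indexOf("goldfish"));
    }
}
```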

**Error Handling:**

```java
try {
    ImageNetLabels labels = new ImageNetLabels(); // loads the labels on first use
    int index = 500;

    // Valid indices are 0-999; an index outside that range raises a runtime exception,
    // so check the range before calling getLabel
    if (index >= 0 && index < 1000) {
        String validLabel = labels.getLabel(index); // Returns a valid class name
        System.out.println(validLabel);
    } else {
        System.err.println("Index out of range for ImageNet: " + index);
    }
} catch (Exception e) {
    // For example, a failure while downloading or reading the label file
    System.err.println("Error accessing ImageNet labels: " + e.getMessage());
}
```
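
A range check can be wrapped once instead of being repeated at every call site. In this sketch a `List<String>` stands in for the loaded labels, and `SafeLabels.label` is a hypothetical helper that returns `Optional.empty()` for an out-of-range index rather than throwing.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

/** Hypothetical wrapper that range-checks indices before looking up a label. */
class SafeLabels {

    /** Returns the label for index n, or Optional.empty() if n is outside 0..size-1. */
    static Optional<String> label(List<String> labels, int n) {
        if (n < 0 || n >= labels.size()) {
            return Optional.empty();
        }
        return Optional.of(labels.get(n));
    }

    public static void main(String[] args) {
        List<String> labels = Arrays.asList("tench", "goldfish");
        System.out.println(label(labels, 1).orElse("<unknown>"));
        System.out.println(label(labels, 1000).orElse("<unknown>"));
    }
}
```

Checking against `labels.size()` rather than a hard-coded 1000 keeps the helper correct if it is reused with a label set of a different size.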

**Label Data Source:**

The `ImageNetLabels` class automatically downloads the ImageNet class index from:

- **URL**: `http://blob.deeplearning4j.org/utils/imagenet_class_index.json`
- **Format**: JSON object mapping each class index to `[internal_id, human_readable_name]`, where the internal ID is a WordNet ID such as `n01440764`
- **Caching**: Labels are cached in memory after the first download
- **Total Classes**: 1000 classes (indices 0-999)
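
The loading and caching steps listed above can be sketched as follows, assuming the JSON has already been parsed (for example with a JSON library) into a `Map<String, List<String>>` keyed by the class index as a string. `LabelCache`, `toOrderedNames`, and the `Supplier` standing in for download-and-parse are hypothetical; this illustrates the idea and is not the DL4J implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

/** Hypothetical sketch: order the parsed class index by index and cache it in memory. */
class LabelCache {
    private static List<String> cached; // null until the first successful load

    /** Each entry maps "i" to [wordnetId, humanReadableName]; keeps only the name, in index order. */
    static List<String> toOrderedNames(Map<String, List<String>> classIndex) {
        List<String> names = new ArrayList<>(classIndex.size());
        for (int i = 0; i < classIndex.size(); i++) {
            List<String> entry = classIndex.get(String.valueOf(i));
            if (entry == null || entry.size() < 2) {
                throw new IllegalArgumentException("missing or malformed entry for index " + i);
            }
            names.add(entry.get(1));
        }
        return Collections.unmodifiableList(names);
    }

    /** Calls the loader only on the first request; later calls return the cached list. */
    static synchronized List<String> get(Supplier<Map<String, List<String>>> loader) {
        if (cached == null) {
            cached = toOrderedNames(loader.get());
        }
        return cached;
    }

    public static void main(String[] args) {
        Map<String, List<String>> parsed = new HashMap<>();
        parsed.put("0", Arrays.asList("n01440764", "tench"));
        parsed.put("1", Arrays.asList("n01443537", "goldfish"));
        System.out.println(get(() -> parsed));
    }
}
```

Iterating over indices 0 to size-1 rather than over the map's entries matters because JSON object keys carry no guaranteed order once parsed into a `HashMap`.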

**Integration with Zoo Models:**

```java
// Works with any zoo model pretrained on ImageNet; input is a preprocessed image batch
ImageNetLabels labels = new ImageNetLabels();

// VGG16 with ImageNet weights (a ComputationGraph)
VGG16 vgg16 = new VGG16(1000, 42, 1);
ComputationGraph vgg16Model = (ComputationGraph) vgg16.initPretrained(PretrainedType.IMAGENET);
String vggResults = labels.decodePredictions(vgg16Model.outputSingle(input));

// AlexNet (only if ImageNet weights are available); AlexNet is a MultiLayerNetwork
AlexNet alexNet = new AlexNet(1000, 42, 1);
if (alexNet.pretrainedAvailable(PretrainedType.IMAGENET)) {
    MultiLayerNetwork alexNetModel = (MultiLayerNetwork) alexNet.initPretrained(PretrainedType.IMAGENET);
    INDArray predictions = alexNetModel.output(input);
    String results = labels.decodePredictions(predictions);
}

// ResNet50 with ImageNet weights (a ComputationGraph)
ResNet50 resNet50 = new ResNet50(1000, 42, 1);
if (resNet50.pretrainedAvailable(PretrainedType.IMAGENET)) {
    ComputationGraph resNetModel = (ComputationGraph) resNet50.initPretrained(PretrainedType.IMAGENET);
    INDArray predictions = resNetModel.outputSingle(input);
    String results = labels.decodePredictions(predictions);
}
```