or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

cnn-models.md, core-interface.md, imagenet-integration.md, index.md, model-selection.md, rnn-models.md

docs/imagenet-integration.md

0

# ImageNet Integration

1

2

Utilities for working with ImageNet classification outputs, including label decoding and prediction interpretation for models trained on the ImageNet dataset.

3

4

## Capabilities

5

6

### ImageNetLabels

7

8

Utility class for decoding ImageNet predictions to human-readable labels. Downloads and caches ImageNet class labels from a remote JSON file and provides methods for interpreting model predictions.

9

10

```java { .api }

11

/**

12

* Helper class with methods for returning ImageNet label descriptions and

13

* decoding prediction arrays to human-readable format.

14

*/

15

class ImageNetLabels {

16

/**

17

* Creates ImageNetLabels instance and loads label data

18

* Downloads class labels from remote JSON if not already cached

19

*/

20

ImageNetLabels();

21

22

/**

23

* Returns the description of the nth class in the 1000 ImageNet classes

24

* @param n Class index (0-999)

25

* @return String description of the ImageNet class

26

*/

27

String getLabel(int n);

28

29

/**

30

* Decodes prediction array to top 5 matches with probabilities

31

* Given predictions from trained model, returns formatted string

32

* listing the top five matches and their respective probabilities

33

* @param predictions INDArray containing model predictions

34

* @return Formatted string with top 5 predictions and probabilities

35

*/

36

String decodePredictions(INDArray predictions);

37

}

38

```

39

40

**Usage Examples:**

41

42

```java

43

// Create ImageNet labels helper

44

ImageNetLabels imageNetLabels = new ImageNetLabels();

45

46

// Get specific class label

47

String label283 = imageNetLabels.getLabel(283); // "Persian cat"

48

String label285 = imageNetLabels.getLabel(285); // "Egyptian cat"

49

50

// Decode model predictions

51

VGG16 vgg16 = new VGG16(1000, 42, 1);

52

Model model = vgg16.initPretrained(PretrainedType.IMAGENET);

53

54

// Assuming you have preprocessed image data as INDArray input

55

INDArray predictions = model.output(imageInput);

56

57

// Decode to human-readable format

58

String topPredictions = imageNetLabels.decodePredictions(predictions);

59

// Output format:

60

// Predictions for batch :

61

// 85.234%, Egyptian cat

62

// 12.456%, Persian cat

63

// 1.789%, tabby cat

64

// 0.345%, tiger cat

65

// 0.176%, lynx

66

```

67

68

**Complete Image Classification Example:**

69

70

```java

71

import org.deeplearning4j.zoo.model.VGG16;

72

import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels;

73

import org.deeplearning4j.zoo.PretrainedType;

74

import org.nd4j.linalg.api.ndarray.INDArray;

75

76

// 1. Load pre-trained ImageNet model

77

VGG16 vgg16 = new VGG16(1000, 42, 1);

78

Model model = vgg16.initPretrained(PretrainedType.IMAGENET);

79

80

// 2. Create ImageNet labels decoder

81

ImageNetLabels imageNetLabels = new ImageNetLabels();

82

83

// 3. Preprocess your image (not shown - requires image loading and preprocessing)

84

// INDArray imageInput = preprocessImage(imagePath);

85

86

// 4. Get model predictions

87

INDArray predictions = model.output(imageInput);

88

89

// 5. Decode predictions to readable format

90

String results = imageNetLabels.decodePredictions(predictions);

91

System.out.println(results);

92

93

// 6. Get individual class labels

94

int topClass = predictions.argMax(1).getInt(0);

95

String topClassName = imageNetLabels.getLabel(topClass);

96

System.out.println("Top prediction: " + topClassName);

97

```

98

99

**Batch Processing Example:**

100

101

```java

102

ImageNetLabels labels = new ImageNetLabels();

103

104

// Process multiple images in batch

105

INDArray batchPredictions = model.output(batchInput); // Shape: [batchSize, 1000]

106

107

// Decode entire batch - shows results for each image in batch

108

String batchResults = labels.decodePredictions(batchPredictions);

109

System.out.println(batchResults);

110

// Output format for batch:

111

// Predictions for batch 0 :

112

// 85.234%, Egyptian cat

113

// ...

114

// Predictions for batch 1 :

115

// 67.891%, golden retriever

116

// ...

117

```

118

119

**Label Index Reference:**

120

121

```java

122

ImageNetLabels labels = new ImageNetLabels();

123

124

// Common ImageNet classes examples:

125

String label0 = labels.getLabel(0); // "tench"

126

String label1 = labels.getLabel(1); // "goldfish"

127

String label151 = labels.getLabel(151); // "Chihuahua"

128

String label285 = labels.getLabel(285); // "Egyptian cat"

129

String label945 = labels.getLabel(945); // "bell pepper"

130

131

// ImageNet has 1000 classes (indices 0-999)

132

for (int i = 0; i < 1000; i++) {

133

String className = labels.getLabel(i);

134

System.out.println("Class " + i + ": " + className);

135

}

136

```

137

138

**Error Handling:**

139

140

```java

141

ImageNetLabels labels = new ImageNetLabels();

142

143

try {

144

// Get label for valid index

145

String validLabel = labels.getLabel(500); // Returns valid class name

146

147

// Note: Invalid indices may cause exceptions

148

// Always ensure index is within 0-999 range for ImageNet

149

150

} catch (Exception e) {

151

System.err.println("Error accessing ImageNet labels: " + e.getMessage());

152

}

153

```

154

155

**Label Data Source:**

156

157

The `ImageNetLabels` class automatically downloads the ImageNet class index from:

158

- **URL**: `http://blob.deeplearning4j.org/utils/imagenet_class_index.json`

159

- **Format**: JSON object mapping each class index (a string key) to a two-element array `[wordnet_id, human_readable_name]`, e.g. `"0": ["n01440764", "tench"]`

160

- **Caching**: Labels are cached in memory after the first download

161

- **Total Classes**: 1000 classes (indices 0-999)

162

163

**Integration with Zoo Models:**

164

165

```java

166

// Works with any ImageNet pre-trained model from the zoo

167

ImageNetLabels labels = new ImageNetLabels();

168

169

// VGG16 with ImageNet weights

170

VGG16 vgg16 = new VGG16(1000, 42, 1);

171

Model vgg16Model = vgg16.initPretrained(PretrainedType.IMAGENET);

172

173

// AlexNet (if it had ImageNet weights available)

174

AlexNet alexNet = new AlexNet(1000, 42, 1);

175

if (alexNet.pretrainedAvailable(PretrainedType.IMAGENET)) {

176

Model alexNetModel = alexNet.initPretrained(PretrainedType.IMAGENET);

177

INDArray predictions = alexNetModel.output(input);

178

String results = labels.decodePredictions(predictions);

179

}

180

181

// ResNet50 with ImageNet weights

182

ResNet50 resNet50 = new ResNet50(1000, 42, 1);

183

if (resNet50.pretrainedAvailable(PretrainedType.IMAGENET)) {

184

Model resNetModel = resNet50.initPretrained(PretrainedType.IMAGENET);

185

INDArray predictions = resNetModel.output(input);

186

String results = labels.decodePredictions(predictions);

187

}

188

```