# Image Models

Comprehensive computer vision models for image classification, object detection, and image segmentation. KerasHub provides implementations of popular architectures such as ResNet, Vision Transformer (ViT), and EfficientNet, along with specialized models for other visual understanding tasks.

## Capabilities

### Base Classes

Foundation classes that define the interface for each type of image model.

```python { .api }
class ImageClassifier(Task):
    """Base class for image classification models.

    Args:
        backbone: `Backbone` that extracts image features.
        num_classes: Number of output classes.
        preprocessor: Optional `Preprocessor`. Defaults to `None`.
    """

    def __init__(
        self,
        backbone: Backbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class ObjectDetector(Task):
    """Base class for object detection models.

    Args:
        backbone: `Backbone` that extracts image features.
        num_classes: Number of object classes.
        preprocessor: Optional `Preprocessor`. Defaults to `None`.
    """

    def __init__(
        self,
        backbone: Backbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class ImageSegmenter(Task):
    """Base class for image segmentation models.

    Args:
        backbone: `Backbone` that extracts image features.
        num_classes: Number of segmentation classes.
        preprocessor: Optional `Preprocessor`. Defaults to `None`.
    """

    def __init__(
        self,
        backbone: Backbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


# Aliases
# `ImageObjectDetector` refers to the same class as `ObjectDetector`.
ImageObjectDetector = ObjectDetector
```

### ResNet (Residual Networks)

Deep residual networks for image classification. Skip (residual) connections make it possible to train very deep networks.

```python { .api }
class ResNetBackbone(Backbone):
    """ResNet backbone architecture.

    Args:
        stackwise_num_filters: Number of filters for each stack.
        stackwise_num_blocks: Number of residual blocks in each stack.
        stackwise_num_strides: Stride of each stack.
        block_type: Residual block type. Defaults to `"basic_block"`.
        use_pre_activation: Whether to use pre-activation residual blocks.
            Defaults to `False`.
        image_shape: Input image shape as `(height, width, channels)`.
            Defaults to `(224, 224, 3)`.
    """

    def __init__(
        self,
        stackwise_num_filters: list,
        stackwise_num_blocks: list,
        stackwise_num_strides: list,
        block_type: str = "basic_block",
        use_pre_activation: bool = False,
        image_shape: tuple = (224, 224, 3),
        **kwargs
    ): ...


class ResNetImageClassifier(ImageClassifier):
    """ResNet model for image classification.

    Same arguments as `ImageClassifier`, with `backbone` typed as a
    `ResNetBackbone`.
    """

    def __init__(
        self,
        backbone: ResNetBackbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class ResNetImageClassifierPreprocessor(ImageClassifierPreprocessor):
    """Preprocessor for ResNet image classification.

    Args:
        image_converter: `ImageConverter` applied to input images.
    """

    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...


class ResNetImageConverter(ImageConverter):
    """Image converter for ResNet models.

    Args:
        height: Target image height in pixels. Defaults to `224`.
        width: Target image width in pixels. Defaults to `224`.
        crop_to_aspect_ratio: Defaults to `True`.
        interpolation: Resize interpolation method. Defaults to `"bilinear"`.
        data_format: Image data format. Defaults to `None`.
    """

    def __init__(
        self,
        height: int = 224,
        width: int = 224,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        data_format: str = None,
        **kwargs
    ): ...
```

### Vision Transformer (ViT)

Transformer architecture applied to image classification: each image is split into fixed-size patches, which are processed as a sequence.

```python { .api }
class ViTBackbone(Backbone):
    """Vision Transformer backbone.

    Args:
        image_shape: Input image shape as `(height, width, channels)`.
            Defaults to `(224, 224, 3)`.
        patch_size: Size of each image patch. Defaults to `16`.
        num_layers: Number of transformer layers. Defaults to `12`.
        num_heads: Number of attention heads. Defaults to `12`.
        hidden_dim: Hidden (embedding) dimension. Defaults to `768`.
        mlp_dim: Dimension of the MLP block. Defaults to `3072`.
        dropout: Dropout rate. Defaults to `0.1`.
    """

    def __init__(
        self,
        image_shape: tuple = (224, 224, 3),
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        hidden_dim: int = 768,
        mlp_dim: int = 3072,
        dropout: float = 0.1,
        **kwargs
    ): ...


class ViTImageClassifier(ImageClassifier):
    """Vision Transformer for image classification.

    Same arguments as `ImageClassifier`, with `backbone` typed as a
    `ViTBackbone`.
    """

    def __init__(
        self,
        backbone: ViTBackbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class ViTImageClassifierPreprocessor(ImageClassifierPreprocessor):
    """Preprocessor for ViT image classification.

    Args:
        image_converter: `ImageConverter` applied to input images.
    """

    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...


class ViTImageConverter(ImageConverter):
    """Image converter for ViT models.

    Args:
        height: Target image height in pixels. Defaults to `224`.
        width: Target image width in pixels. Defaults to `224`.
        crop_to_aspect_ratio: Defaults to `True`.
        interpolation: Resize interpolation method. Defaults to `"bilinear"`.
    """

    def __init__(
        self,
        height: int = 224,
        width: int = 224,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...
```

### EfficientNet

Scalable convolutional neural network architecture optimized for efficiency. Model size is controlled by the `width_coefficient` and `depth_coefficient` multipliers.

```python { .api }
class EfficientNetBackbone(Backbone):
    """EfficientNet backbone architecture.

    Args:
        stackwise_kernel_sizes: Kernel size for each stack.
        stackwise_num_repeats: Number of block repeats in each stack.
        stackwise_input_filters: Input filters for each stack.
        stackwise_output_filters: Output filters for each stack.
        stackwise_expand_ratios: Expansion ratio for each stack.
        stackwise_strides: Stride of each stack.
        width_coefficient: Network width multiplier. Defaults to `1.0`.
        depth_coefficient: Network depth multiplier. Defaults to `1.0`.
        image_shape: Input image shape as `(height, width, channels)`.
            Defaults to `(224, 224, 3)`.
    """

    def __init__(
        self,
        stackwise_kernel_sizes: list,
        stackwise_num_repeats: list,
        stackwise_input_filters: list,
        stackwise_output_filters: list,
        stackwise_expand_ratios: list,
        stackwise_strides: list,
        width_coefficient: float = 1.0,
        depth_coefficient: float = 1.0,
        image_shape: tuple = (224, 224, 3),
        **kwargs
    ): ...


class EfficientNetImageClassifier(ImageClassifier):
    """EfficientNet model for image classification.

    Same arguments as `ImageClassifier`, with `backbone` typed as an
    `EfficientNetBackbone`.
    """

    def __init__(
        self,
        backbone: EfficientNetBackbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class EfficientNetImageClassifierPreprocessor(ImageClassifierPreprocessor):
    """Preprocessor for EfficientNet image classification.

    Args:
        image_converter: `ImageConverter` applied to input images.
    """

    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...


class EfficientNetImageConverter(ImageConverter):
    """Image converter for EfficientNet models.

    Args:
        height: Target image height in pixels. Defaults to `224`.
        width: Target image width in pixels. Defaults to `224`.
        crop_to_aspect_ratio: Defaults to `True`.
        interpolation: Resize interpolation method. Defaults to `"bilinear"`.
    """

    def __init__(
        self,
        height: int = 224,
        width: int = 224,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...
```

### Object Detection Models

Models that detect and localize objects in images.

```python { .api }
class RetinaNetBackbone(Backbone):
    """RetinaNet backbone for object detection.

    Args:
        stackwise_num_filters: Number of filters for each stack.
        stackwise_num_blocks: Number of blocks in each stack.
        stackwise_num_strides: Stride of each stack.
        image_shape: Input image shape as `(height, width, channels)`.
            Defaults to `(512, 512, 3)`.
    """

    def __init__(
        self,
        stackwise_num_filters: list,
        stackwise_num_blocks: list,
        stackwise_num_strides: list,
        image_shape: tuple = (512, 512, 3),
        **kwargs
    ): ...


class RetinaNetObjectDetector(ObjectDetector):
    """RetinaNet model for object detection.

    Same arguments as `ObjectDetector`, with `backbone` typed as a
    `RetinaNetBackbone`.
    """

    def __init__(
        self,
        backbone: RetinaNetBackbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class RetinaNetObjectDetectorPreprocessor(ObjectDetectorPreprocessor):
    """Preprocessor for RetinaNet object detection.

    Args:
        image_converter: `ImageConverter` applied to input images.
    """

    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...


class RetinaNetImageConverter(ImageConverter):
    """Image converter for RetinaNet models.

    Args:
        height: Target image height in pixels. Defaults to `512`.
        width: Target image width in pixels. Defaults to `512`.
        crop_to_aspect_ratio: Defaults to `True`.
        interpolation: Resize interpolation method. Defaults to `"bilinear"`.
    """

    def __init__(
        self,
        height: int = 512,
        width: int = 512,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...


class ViTDetBackbone(Backbone):
    """Vision Transformer backbone for object detection.

    Args:
        image_shape: Input image shape as `(height, width, channels)`.
            Defaults to `(1024, 1024, 3)`.
        patch_size: Size of each image patch. Defaults to `16`.
        num_layers: Number of transformer layers. Defaults to `12`.
        num_heads: Number of attention heads. Defaults to `12`.
        hidden_dim: Hidden (embedding) dimension. Defaults to `768`.
        mlp_dim: Dimension of the MLP block. Defaults to `3072`.
    """

    def __init__(
        self,
        image_shape: tuple = (1024, 1024, 3),
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        hidden_dim: int = 768,
        mlp_dim: int = 3072,
        **kwargs
    ): ...
```

### Image Segmentation Models

Models for pixel-level prediction. They cover semantic segmentation (DeepLab V3, SegFormer), boundary-aware salient object detection (BASNet), and the Segment Anything Model (SAM).

```python { .api }
class DeepLabV3Backbone(Backbone):
    """DeepLab V3 backbone for semantic segmentation.

    Args:
        image_shape: Input image shape as `(height, width, channels)`.
            Defaults to `(512, 512, 3)`.
        low_level_feature_key: Key of the feature map used as low-level
            features. Defaults to `"P2"`.
        spatial_pyramid_pooling_key: Key of the feature map passed to
            spatial pyramid pooling. Defaults to `"P5"`.
    """

    def __init__(
        self,
        image_shape: tuple = (512, 512, 3),
        low_level_feature_key: str = "P2",
        spatial_pyramid_pooling_key: str = "P5",
        **kwargs
    ): ...


class DeepLabV3ImageSegmenter(ImageSegmenter):
    """DeepLab V3 model for image segmentation.

    Same arguments as `ImageSegmenter`, with `backbone` typed as a
    `DeepLabV3Backbone`.
    """

    def __init__(
        self,
        backbone: DeepLabV3Backbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class DeepLabV3ImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
    """Preprocessor for DeepLab V3 segmentation.

    Args:
        image_converter: `ImageConverter` applied to input images.
    """

    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...


class DeepLabV3ImageConverter(ImageConverter):
    """Image converter for DeepLab V3 models.

    Args:
        height: Target image height in pixels. Defaults to `512`.
        width: Target image width in pixels. Defaults to `512`.
        crop_to_aspect_ratio: Defaults to `True`.
        interpolation: Resize interpolation method. Defaults to `"bilinear"`.
    """

    def __init__(
        self,
        height: int = 512,
        width: int = 512,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...


class BASNetBackbone(Backbone):
    """BASNet backbone for boundary-aware salient object detection.

    Args:
        image_shape: Input image shape as `(height, width, channels)`.
            Defaults to `(224, 224, 3)`.
    """

    def __init__(
        self,
        image_shape: tuple = (224, 224, 3),
        **kwargs
    ): ...


class BASNetImageSegmenter(ImageSegmenter):
    """BASNet model for image segmentation.

    Unlike the `ImageSegmenter` base signature, no `num_classes` argument
    is listed here.

    Args:
        backbone: `BASNetBackbone` that extracts image features.
        preprocessor: Optional `Preprocessor`. Defaults to `None`.
    """

    def __init__(
        self,
        backbone: BASNetBackbone,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class BASNetPreprocessor(ImageSegmenterPreprocessor):
    """Preprocessor for BASNet segmentation.

    Args:
        image_converter: `ImageConverter` applied to input images.
    """

    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...


class BASNetImageConverter(ImageConverter):
    """Image converter for BASNet models.

    Args:
        height: Target image height in pixels. Defaults to `224`.
        width: Target image width in pixels. Defaults to `224`.
        crop_to_aspect_ratio: Defaults to `True`.
        interpolation: Resize interpolation method. Defaults to `"bilinear"`.
    """

    def __init__(
        self,
        height: int = 224,
        width: int = 224,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...


class SegFormerBackbone(Backbone):
    """SegFormer backbone for semantic segmentation.

    Args:
        image_shape: Input image shape as `(height, width, channels)`.
            Defaults to `(512, 512, 3)`.
        num_layers: Layer counts. Defaults to `(2, 2, 2, 2)`.
        hidden_dims: Hidden dimensions. Defaults to `(32, 64, 160, 256)`.
    """

    def __init__(
        self,
        image_shape: tuple = (512, 512, 3),
        # Immutable tuple defaults: a list default would be a single object
        # shared (and mutable) across every call that omits the argument.
        num_layers: tuple = (2, 2, 2, 2),
        hidden_dims: tuple = (32, 64, 160, 256),
        **kwargs
    ): ...


class SegFormerImageSegmenter(ImageSegmenter):
    """SegFormer model for image segmentation.

    Same arguments as `ImageSegmenter`, with `backbone` typed as a
    `SegFormerBackbone`.
    """

    def __init__(
        self,
        backbone: SegFormerBackbone,
        num_classes: int,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class SegFormerImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
    """Preprocessor for SegFormer segmentation.

    Args:
        image_converter: `ImageConverter` applied to input images.
    """

    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...


class SegFormerImageConverter(ImageConverter):
    """Image converter for SegFormer models.

    Args:
        height: Target image height in pixels. Defaults to `512`.
        width: Target image width in pixels. Defaults to `512`.
        crop_to_aspect_ratio: Defaults to `True`.
        interpolation: Resize interpolation method. Defaults to `"bilinear"`.
    """

    def __init__(
        self,
        height: int = 512,
        width: int = 512,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...


class SAMBackbone(Backbone):
    """Segment Anything Model backbone.

    Args:
        image_shape: Input image shape as `(height, width, channels)`.
            Defaults to `(1024, 1024, 3)`.
        patch_size: Size of each image patch. Defaults to `16`.
        num_layers: Number of transformer layers. Defaults to `12`.
        num_heads: Number of attention heads. Defaults to `12`.
        hidden_dim: Hidden (embedding) dimension. Defaults to `768`.
    """

    def __init__(
        self,
        image_shape: tuple = (1024, 1024, 3),
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        hidden_dim: int = 768,
        **kwargs
    ): ...


class SAMImageSegmenter(ImageSegmenter):
    """Segment Anything Model for image segmentation.

    Unlike the `ImageSegmenter` base signature, no `num_classes` argument
    is listed here.

    Args:
        backbone: `SAMBackbone` that extracts image features.
        preprocessor: Optional `Preprocessor`. Defaults to `None`.
    """

    def __init__(
        self,
        backbone: SAMBackbone,
        preprocessor: Preprocessor = None,
        **kwargs
    ): ...


class SAMImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
    """Preprocessor for SAM segmentation.

    Args:
        image_converter: `ImageConverter` applied to input images.
    """

    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...


class SAMImageConverter(ImageConverter):
    """Image converter for SAM models.

    Args:
        height: Target image height in pixels. Defaults to `1024`.
        width: Target image width in pixels. Defaults to `1024`.
        crop_to_aspect_ratio: Defaults to `True`.
        interpolation: Resize interpolation method. Defaults to `"bilinear"`.
    """

    def __init__(
        self,
        height: int = 1024,
        width: int = 1024,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...
```

### Additional Image Classification Models

Other popular architectures for image classification. DINOV2 is provided only as a backbone and image converter, with no dedicated classifier.

```python { .api }
# Each family follows the same pattern: a backbone, a classifier task, a
# classification preprocessor, and an image converter.

# DenseNet (Densely Connected Networks)
class DenseNetBackbone(Backbone): ...
class DenseNetImageClassifier(ImageClassifier): ...
class DenseNetImageClassifierPreprocessor(ImageClassifierPreprocessor): ...
class DenseNetImageConverter(ImageConverter): ...

# MobileNet (Efficient Mobile Networks)
class MobileNetBackbone(Backbone): ...
class MobileNetImageClassifier(ImageClassifier): ...
class MobileNetImageClassifierPreprocessor(ImageClassifierPreprocessor): ...
class MobileNetImageConverter(ImageConverter): ...

# VGG (Visual Geometry Group)
class VGGBackbone(Backbone): ...
class VGGImageClassifier(ImageClassifier): ...
class VGGImageClassifierPreprocessor(ImageClassifierPreprocessor): ...
class VGGImageConverter(ImageConverter): ...

# Xception
class XceptionBackbone(Backbone): ...
class XceptionImageClassifier(ImageClassifier): ...
class XceptionImageClassifierPreprocessor(ImageClassifierPreprocessor): ...
class XceptionImageConverter(ImageConverter): ...

# DeiT (Data-efficient Image Transformer)
class DeiTBackbone(Backbone): ...
class DeiTImageClassifier(ImageClassifier): ...
class DeiTImageClassifierPreprocessor(ImageClassifierPreprocessor): ...
class DeiTImageConverter(ImageConverter): ...

# CSPNet (Cross Stage Partial Network)
class CSPNetBackbone(Backbone): ...
class CSPNetImageClassifier(ImageClassifier): ...
class CSPNetImageClassifierPreprocessor(ImageClassifierPreprocessor): ...
class CSPNetImageConverter(ImageConverter): ...

# HGNet V2 (High Performance GPU Network V2)
class HGNetV2Backbone(Backbone): ...
class HGNetV2ImageClassifier(ImageClassifier): ...
class HGNetV2ImageClassifierPreprocessor(ImageClassifierPreprocessor): ...
class HGNetV2ImageConverter(ImageConverter): ...

# MiT (Mix Transformer)
class MiTBackbone(Backbone): ...
class MiTImageClassifier(ImageClassifier): ...
class MiTImageClassifierPreprocessor(ImageClassifierPreprocessor): ...
class MiTImageConverter(ImageConverter): ...

# DINOV2 (Self-Supervised Vision Transformer)
# Backbone and image converter only; no classifier task is listed.
class DINOV2Backbone(Backbone): ...
class DINOV2ImageConverter(ImageConverter): ...
```

### Utility Backbones

Specialized backbone architectures shared across computer vision tasks.

```python { .api }
class FeaturePyramidBackbone(Backbone):
    """Feature Pyramid Network backbone.

    Args:
        backbone: Underlying `Backbone`.
        feature_size: Feature size of the pyramid. Defaults to `256`.
    """

    def __init__(
        self,
        backbone: Backbone,
        feature_size: int = 256,
        **kwargs
    ): ...
```

### Preprocessor Base Classes

Base classes for image preprocessing. Each one takes an `ImageConverter`, which handles image resizing.

```python { .api }
class ImageClassifierPreprocessor(Preprocessor):
    """Base preprocessor for image classification.

    Args:
        image_converter: `ImageConverter` applied to input images.
    """

    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...


class ImageSegmenterPreprocessor(Preprocessor):
    """Base preprocessor for image segmentation.

    Args:
        image_converter: `ImageConverter` applied to input images.
    """

    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...


class ObjectDetectorPreprocessor(Preprocessor):
    """Base preprocessor for object detection.

    Args:
        image_converter: `ImageConverter` applied to input images.
    """

    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...


# Alias
# `ImageObjectDetectorPreprocessor` refers to the same class as
# `ObjectDetectorPreprocessor`.
ImageObjectDetectorPreprocessor = ObjectDetectorPreprocessor
```

## Usage Examples

Each example below is self-contained and uses random data as a stand-in for real images.

### Image Classification with ResNet

```python
import keras_hub
import numpy as np

# Load a pretrained ResNet-50 classifier (backbone, head, and preprocessor).
classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet_50_imagenet")

# Build an example input of shape (height, width, channels).
# Use raw pixel values in [0, 255]; the preset's preprocessor applies the
# model's own rescaling, so pre-scaled [0, 1) data would be mis-normalized.
image = np.random.uniform(0, 255, size=(224, 224, 3))  # Example random image
images = np.expand_dims(image, axis=0)  # Add batch dimension

# Predict
predictions = classifier.predict(images)
print(f"Predictions shape: {predictions.shape}")

# Get top prediction
predicted_class = np.argmax(predictions[0])
print(f"Predicted class: {predicted_class}")
```

### Custom Image Classification

Attach a new classification head (here, two classes) to a pretrained backbone and fine-tune it on your own data.

```python
import keras
import keras_hub

# Pretrained ResNet-50 backbone, reused as the feature extractor.
backbone = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet")

# Reuse the preset's preprocessing so training images are resized and
# normalized the same way the backbone was pretrained.
preprocessor = keras_hub.models.ResNetImageClassifierPreprocessor.from_preset(
    "resnet_50_imagenet"
)

# Create custom ResNet for binary classification
classifier = keras_hub.models.ResNetImageClassifier(
    backbone=backbone,
    num_classes=2,  # Binary classification
    preprocessor=preprocessor,
)

# Compile model.
# The classification head outputs logits (no final activation by default),
# so the loss must be told `from_logits=True`. The plain string
# "sparse_categorical_crossentropy" would treat logits as probabilities.
classifier.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Train with your data
# classifier.fit(train_images, train_labels, epochs=10)
```

### Object Detection with RetinaNet

```python
import keras_hub
import numpy as np

# Load pretrained RetinaNet detector
detector = keras_hub.models.RetinaNetObjectDetector.from_preset("retinanet_resnet50_fpn_coco")

# Example batch of one image, shape (batch, height, width, channels), with
# raw pixel values in [0, 255]. Defined here so the snippet runs on its own.
images = np.random.uniform(0, 255, size=(1, 512, 512, 3))

# Detect objects in image
detections = detector.predict(images)

# Process detections
# detections contains bounding boxes, class predictions, and confidence scores
print("Detections:", detections)
```

### Image Segmentation with DeepLab V3

```python
import keras_hub
import numpy as np

# Load pretrained segmentation model
segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset("deeplab_v3_plus_resnet50_pascalvoc")

# Example batch of one image, shape (batch, height, width, channels), with
# raw pixel values in [0, 255]. Defined here so the snippet runs on its own.
images = np.random.uniform(0, 255, size=(1, 512, 512, 3))

# Segment image. The model returns per-pixel class scores, with the class
# axis last, rather than a ready-made mask.
class_scores = segmenter.predict(images)

# Reduce the class axis to get one integer class ID per pixel.
segmentation_mask = np.argmax(class_scores, axis=-1)
print(f"Segmentation mask shape: {segmentation_mask.shape}")
```

### Using Vision Transformer

```python
import keras_hub
import numpy as np

# Load pretrained ViT-Base with 16x16 patches at 224x224 resolution.
vit_classifier = keras_hub.models.ViTImageClassifier.from_preset("vit_base_patch16_224_imagenet")

# Example batch of one image, shape (batch, height, width, channels), with
# raw pixel values in [0, 255]. Defined here so the snippet runs on its own.
images = np.random.uniform(0, 255, size=(1, 224, 224, 3))

# Classify images
predictions = vit_classifier.predict(images)
print("ViT predictions:", predictions)
```