or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

audio-models.md  evaluation-metrics.md  generative-models.md  image-models.md  index.md  layers-components.md  multimodal-models.md  text-generation-sampling.md  text-models.md  tokenizers.md  utilities-helpers.md

docs/image-models.md

0

# Image Models

1

2

Keras Hub offers a comprehensive set of computer vision models for image classification, object detection, and image segmentation. It provides implementations of popular architectures such as ResNet, Vision Transformer (ViT), and EfficientNet, as well as specialized models for various visual understanding tasks.

3

4

## Capabilities

5

6

### Base Classes

7

8

Foundation classes that define the interface for different types of image models.

9

10

```python { .api }

11

class ImageClassifier(Task):

12

"""Base class for image classification models."""

13

def __init__(

14

self,

15

backbone: Backbone,

16

num_classes: int,

17

preprocessor: Preprocessor = None,

18

**kwargs

19

): ...

20

21

class ObjectDetector(Task):

22

"""Base class for object detection models."""

23

def __init__(

24

self,

25

backbone: Backbone,

26

num_classes: int,

27

preprocessor: Preprocessor = None,

28

**kwargs

29

): ...

30

31

class ImageSegmenter(Task):

32

"""Base class for image segmentation models."""

33

def __init__(

34

self,

35

backbone: Backbone,

36

num_classes: int,

37

preprocessor: Preprocessor = None,

38

**kwargs

39

): ...

40

41

# Aliases

42

ImageObjectDetector = ObjectDetector

43

```

44

45

### ResNet (Residual Networks)

46

47

Deep residual networks for image classification with skip connections to enable training of very deep networks.

48

49

```python { .api }

50

class ResNetBackbone(Backbone):

51

"""ResNet backbone architecture."""

52

def __init__(

53

self,

54

stackwise_num_filters: list,

55

stackwise_num_blocks: list,

56

stackwise_num_strides: list,

57

block_type: str = "basic_block",

58

use_pre_activation: bool = False,

59

image_shape: tuple = (224, 224, 3),

60

**kwargs

61

): ...

62

63

class ResNetImageClassifier(ImageClassifier):

64

"""ResNet model for image classification."""

65

def __init__(

66

self,

67

backbone: ResNetBackbone,

68

num_classes: int,

69

preprocessor: Preprocessor = None,

70

**kwargs

71

): ...

72

73

class ResNetImageClassifierPreprocessor:

74

"""Preprocessor for ResNet image classification."""

75

def __init__(

76

self,

77

image_converter: ImageConverter,

78

**kwargs

79

): ...

80

81

class ResNetImageConverter:

82

"""Image converter for ResNet models."""

83

def __init__(

84

self,

85

height: int = 224,

86

width: int = 224,

87

crop_to_aspect_ratio: bool = True,

88

interpolation: str = "bilinear",

89

data_format: str = None,

90

**kwargs

91

): ...

92

```

93

94

### Vision Transformer (ViT)

95

96

Transformer architecture applied to image classification by treating image patches as sequences.

97

98

```python { .api }

99

class ViTBackbone(Backbone):

100

"""Vision Transformer backbone."""

101

def __init__(

102

self,

103

image_shape: tuple = (224, 224, 3),

104

patch_size: int = 16,

105

num_layers: int = 12,

106

num_heads: int = 12,

107

hidden_dim: int = 768,

108

mlp_dim: int = 3072,

109

dropout: float = 0.1,

110

**kwargs

111

): ...

112

113

class ViTImageClassifier(ImageClassifier):

114

"""Vision Transformer for image classification."""

115

def __init__(

116

self,

117

backbone: ViTBackbone,

118

num_classes: int,

119

preprocessor: Preprocessor = None,

120

**kwargs

121

): ...

122

123

class ViTImageClassifierPreprocessor:

124

"""Preprocessor for ViT image classification."""

125

def __init__(

126

self,

127

image_converter: ImageConverter,

128

**kwargs

129

): ...

130

131

class ViTImageConverter:

132

"""Image converter for ViT models."""

133

def __init__(

134

self,

135

height: int = 224,

136

width: int = 224,

137

crop_to_aspect_ratio: bool = True,

138

interpolation: str = "bilinear",

139

**kwargs

140

): ...

141

```

142

143

### EfficientNet

144

145

Scalable convolutional neural network architecture optimized for efficiency.

146

147

```python { .api }

148

class EfficientNetBackbone(Backbone):

149

"""EfficientNet backbone architecture."""

150

def __init__(

151

self,

152

stackwise_kernel_sizes: list,

153

stackwise_num_repeats: list,

154

stackwise_input_filters: list,

155

stackwise_output_filters: list,

156

stackwise_expand_ratios: list,

157

stackwise_strides: list,

158

width_coefficient: float = 1.0,

159

depth_coefficient: float = 1.0,

160

image_shape: tuple = (224, 224, 3),

161

**kwargs

162

): ...

163

164

class EfficientNetImageClassifier(ImageClassifier):

165

"""EfficientNet model for image classification."""

166

def __init__(

167

self,

168

backbone: EfficientNetBackbone,

169

num_classes: int,

170

preprocessor: Preprocessor = None,

171

**kwargs

172

): ...

173

174

class EfficientNetImageClassifierPreprocessor:

175

"""Preprocessor for EfficientNet image classification."""

176

def __init__(

177

self,

178

image_converter: ImageConverter,

179

**kwargs

180

): ...

181

182

class EfficientNetImageConverter:

183

"""Image converter for EfficientNet models."""

184

def __init__(

185

self,

186

height: int = 224,

187

width: int = 224,

188

crop_to_aspect_ratio: bool = True,

189

interpolation: str = "bilinear",

190

**kwargs

191

): ...

192

```

193

194

### Object Detection Models

195

196

Models specialized for detecting and localizing objects in images.

197

198

```python { .api }

199

class RetinaNetBackbone(Backbone):

200

"""RetinaNet backbone for object detection."""

201

def __init__(

202

self,

203

stackwise_num_filters: list,

204

stackwise_num_blocks: list,

205

stackwise_num_strides: list,

206

image_shape: tuple = (512, 512, 3),

207

**kwargs

208

): ...

209

210

class RetinaNetObjectDetector(ObjectDetector):

211

"""RetinaNet model for object detection."""

212

def __init__(

213

self,

214

backbone: RetinaNetBackbone,

215

num_classes: int,

216

preprocessor: Preprocessor = None,

217

**kwargs

218

): ...

219

220

class RetinaNetObjectDetectorPreprocessor:

221

"""Preprocessor for RetinaNet object detection."""

222

def __init__(

223

self,

224

image_converter: ImageConverter,

225

**kwargs

226

): ...

227

228

class RetinaNetImageConverter:

229

"""Image converter for RetinaNet models."""

230

def __init__(

231

self,

232

height: int = 512,

233

width: int = 512,

234

crop_to_aspect_ratio: bool = True,

235

interpolation: str = "bilinear",

236

**kwargs

237

): ...

238

239

class ViTDetBackbone(Backbone):

240

"""Vision Transformer backbone for object detection."""

241

def __init__(

242

self,

243

image_shape: tuple = (1024, 1024, 3),

244

patch_size: int = 16,

245

num_layers: int = 12,

246

num_heads: int = 12,

247

hidden_dim: int = 768,

248

mlp_dim: int = 3072,

249

**kwargs

250

): ...

251

```

252

253

### Image Segmentation Models

254

255

Models for pixel-level classification and semantic segmentation.

256

257

```python { .api }

258

class DeepLabV3Backbone(Backbone):

259

"""DeepLab V3 backbone for semantic segmentation."""

260

def __init__(

261

self,

262

image_shape: tuple = (512, 512, 3),

263

low_level_feature_key: str = "P2",

264

spatial_pyramid_pooling_key: str = "P5",

265

**kwargs

266

): ...

267

268

class DeepLabV3ImageSegmenter(ImageSegmenter):

269

"""DeepLab V3 model for image segmentation."""

270

def __init__(

271

self,

272

backbone: DeepLabV3Backbone,

273

num_classes: int,

274

preprocessor: Preprocessor = None,

275

**kwargs

276

): ...

277

278

class DeepLabV3ImageSegmenterPreprocessor:

279

"""Preprocessor for DeepLab V3 segmentation."""

280

def __init__(

281

self,

282

image_converter: ImageConverter,

283

**kwargs

284

): ...

285

286

class DeepLabV3ImageConverter:

287

"""Image converter for DeepLab V3 models."""

288

def __init__(

289

self,

290

height: int = 512,

291

width: int = 512,

292

crop_to_aspect_ratio: bool = True,

293

interpolation: str = "bilinear",

294

**kwargs

295

): ...

296

297

class BASNetBackbone(Backbone):

298

"""BASNet backbone for boundary-aware salient object detection."""

299

def __init__(

300

self,

301

image_shape: tuple = (224, 224, 3),

302

**kwargs

303

): ...

304

305

class BASNetImageSegmenter(ImageSegmenter):

306

"""BASNet model for image segmentation."""

307

def __init__(

308

self,

309

backbone: BASNetBackbone,

310

preprocessor: Preprocessor = None,

311

**kwargs

312

): ...

313

314

class BASNetPreprocessor:

315

"""Preprocessor for BASNet segmentation."""

316

def __init__(

317

self,

318

image_converter: ImageConverter,

319

**kwargs

320

): ...

321

322

class BASNetImageConverter:

323

"""Image converter for BASNet models."""

324

def __init__(

325

self,

326

height: int = 224,

327

width: int = 224,

328

crop_to_aspect_ratio: bool = True,

329

interpolation: str = "bilinear",

330

**kwargs

331

): ...

332

333

class SegFormerBackbone(Backbone):

334

"""SegFormer backbone for semantic segmentation."""

335

def __init__(

336

self,

337

image_shape: tuple = (512, 512, 3),

338

num_layers: list = [2, 2, 2, 2],

339

hidden_dims: list = [32, 64, 160, 256],

340

**kwargs

341

): ...

342

343

class SegFormerImageSegmenter(ImageSegmenter):

344

"""SegFormer model for image segmentation."""

345

def __init__(

346

self,

347

backbone: SegFormerBackbone,

348

num_classes: int,

349

preprocessor: Preprocessor = None,

350

**kwargs

351

): ...

352

353

class SegFormerImageSegmenterPreprocessor:

354

"""Preprocessor for SegFormer segmentation."""

355

def __init__(

356

self,

357

image_converter: ImageConverter,

358

**kwargs

359

): ...

360

361

class SegFormerImageConverter:

362

"""Image converter for SegFormer models."""

363

def __init__(

364

self,

365

height: int = 512,

366

width: int = 512,

367

crop_to_aspect_ratio: bool = True,

368

interpolation: str = "bilinear",

369

**kwargs

370

): ...

371

372

class SAMBackbone(Backbone):

373

"""Segment Anything Model backbone."""

374

def __init__(

375

self,

376

image_shape: tuple = (1024, 1024, 3),

377

patch_size: int = 16,

378

num_layers: int = 12,

379

num_heads: int = 12,

380

hidden_dim: int = 768,

381

**kwargs

382

): ...

383

384

class SAMImageSegmenter(ImageSegmenter):

385

"""Segment Anything Model for image segmentation."""

386

def __init__(

387

self,

388

backbone: SAMBackbone,

389

preprocessor: Preprocessor = None,

390

**kwargs

391

): ...

392

393

class SAMImageSegmenterPreprocessor:

394

"""Preprocessor for SAM segmentation."""

395

def __init__(

396

self,

397

image_converter: ImageConverter,

398

**kwargs

399

): ...

400

401

class SAMImageConverter:

402

"""Image converter for SAM models."""

403

def __init__(

404

self,

405

height: int = 1024,

406

width: int = 1024,

407

crop_to_aspect_ratio: bool = True,

408

interpolation: str = "bilinear",

409

**kwargs

410

): ...

411

```

412

413

### Additional Image Classification Models

414

415

Other popular architectures for image classification tasks.

416

417

```python { .api }

418

# DenseNet (Densely Connected Networks)

419

class DenseNetBackbone(Backbone): ...

420

class DenseNetImageClassifier(ImageClassifier): ...

421

class DenseNetImageClassifierPreprocessor: ...

422

class DenseNetImageConverter: ...

423

424

# MobileNet (Efficient Mobile Networks)

425

class MobileNetBackbone(Backbone): ...

426

class MobileNetImageClassifier(ImageClassifier): ...

427

class MobileNetImageClassifierPreprocessor: ...

428

class MobileNetImageConverter: ...

429

430

# VGG (Visual Geometry Group)

431

class VGGBackbone(Backbone): ...

432

class VGGImageClassifier(ImageClassifier): ...

433

class VGGImageClassifierPreprocessor: ...

434

class VGGImageConverter: ...

435

436

# Xception

437

class XceptionBackbone(Backbone): ...

438

class XceptionImageClassifier(ImageClassifier): ...

439

class XceptionImageClassifierPreprocessor: ...

440

class XceptionImageConverter: ...

441

442

# DeiT (Data-efficient Image Transformer)

443

class DeiTBackbone(Backbone): ...

444

class DeiTImageClassifier(ImageClassifier): ...

445

class DeiTImageClassifierPreprocessor: ...

446

class DeiTImageConverter: ...

447

448

# CSPNet (Cross Stage Partial Network)

449

class CSPNetBackbone(Backbone): ...

450

class CSPNetImageClassifier(ImageClassifier): ...

451

class CSPNetImageClassifierPreprocessor: ...

452

class CSPNetImageConverter: ...

453

454

# HGNet V2 (High Performance GPU Network V2)

455

class HGNetV2Backbone(Backbone): ...

456

class HGNetV2ImageClassifier(ImageClassifier): ...

457

class HGNetV2ImageClassifierPreprocessor: ...

458

class HGNetV2ImageConverter: ...

459

460

# MiT (Mix Transformer)

461

class MiTBackbone(Backbone): ...

462

class MiTImageClassifier(ImageClassifier): ...

463

class MiTImageClassifierPreprocessor: ...

464

class MiTImageConverter: ...

465

466

# DINOV2 (Self-Supervised Vision Transformer)

467

class DINOV2Backbone(Backbone): ...

468

class DINOV2ImageConverter: ...

469

```

470

471

### Utility Backbones

472

473

Specialized backbone architectures for various computer vision tasks.

474

475

```python { .api }

476

class FeaturePyramidBackbone(Backbone):

477

"""Feature Pyramid Network backbone."""

478

def __init__(

479

self,

480

backbone: Backbone,

481

feature_size: int = 256,

482

**kwargs

483

): ...

484

```

485

486

### Preprocessor Base Classes

487

488

Base classes for image preprocessing.

489

490

```python { .api }

491

class ImageClassifierPreprocessor(Preprocessor):

492

"""Base preprocessor for image classification."""

493

def __init__(

494

self,

495

image_converter: ImageConverter,

496

**kwargs

497

): ...

498

499

class ImageSegmenterPreprocessor(Preprocessor):

500

"""Base preprocessor for image segmentation."""

501

def __init__(

502

self,

503

image_converter: ImageConverter,

504

**kwargs

505

): ...

506

507

class ObjectDetectorPreprocessor(Preprocessor):

508

"""Base preprocessor for object detection."""

509

def __init__(

510

self,

511

image_converter: ImageConverter,

512

**kwargs

513

): ...

514

515

# Alias

516

ImageObjectDetectorPreprocessor = ObjectDetectorPreprocessor

517

```

518

519

## Usage Examples

520

521

### Image Classification with ResNet

522

523

```python

524

import keras_hub

525

import numpy as np

526

527

# Load pretrained ResNet classifier

528

classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet50_imagenet")

529

530

# Load and preprocess an image

531

# Image should be a numpy array of shape (height, width, channels)

532

image = np.random.random((224, 224, 3)) # Example random image

533

images = np.expand_dims(image, axis=0) # Add batch dimension

534

535

# Predict

536

predictions = classifier.predict(images)

537

print(f"Predictions shape: {predictions.shape}")

538

539

# Get top prediction

540

predicted_class = np.argmax(predictions[0])

541

print(f"Predicted class: {predicted_class}")

542

```

543

544

### Custom Image Classification

545

546

```python

547

import keras_hub

548

549

# Create custom ResNet for binary classification

550

backbone = keras_hub.models.ResNetBackbone.from_preset("resnet50_imagenet")

551

552

classifier = keras_hub.models.ResNetImageClassifier(

553

backbone=backbone,

554

num_classes=2, # Binary classification

555

)

556

557

# Compile model

558

classifier.compile(

559

optimizer="adam",

560

loss="sparse_categorical_crossentropy",

561

metrics=["accuracy"]

562

)

563

564

# Train with your data

565

# classifier.fit(train_images, train_labels, epochs=10)

566

```

567

568

### Object Detection with RetinaNet

569

570

```python

571

import keras_hub

572

573

# Load pretrained RetinaNet detector

574

detector = keras_hub.models.RetinaNetObjectDetector.from_preset("retinanet_resnet50_pascalvoc")

575

576

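# Detect objects in a batch of images (`images` is a batched array with a leading batch dimension, built as in the ResNet example above)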

577

detections = detector.predict(images)

578

579

# Process detections

580

# detections contains bounding boxes, class predictions, and confidence scores

581

print("Detections:", detections)

582

```

583

584

### Image Segmentation with DeepLab V3

585

586

```python

587

import keras_hub

588

589

# Load pretrained segmentation model

590

segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset("deeplabv3_resnet50_pascalvoc")

591

592

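# Segment a batch of images (`images` is a batched array with a leading batch dimension, built as in the ResNet example above)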

593

segmentation_mask = segmenter.predict(images)

594

595

# The output is a segmentation mask with class predictions for each pixel

596

print(f"Segmentation mask shape: {segmentation_mask.shape}")

597

```

598

599

### Using Vision Transformer

600

601

```python

602

import keras_hub

603

604

# Load pretrained ViT

605

vit_classifier = keras_hub.models.ViTImageClassifier.from_preset("vit_base_patch16_224")

606

607

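# Classify a batch of images (`images` is a batched array with a leading batch dimension, built as in the ResNet example above)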

608

predictions = vit_classifier.predict(images)

609

print("ViT predictions:", predictions)

610

```