CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-keras-hub

Pretrained models for Keras with multi-framework compatibility.

Pending
Overview
Eval results
Files

docs/image-models.md

Image Models

Computer vision models for image classification, object detection, and image segmentation. KerasHub provides implementations of popular architectures such as ResNet, Vision Transformer (ViT), and EfficientNet, along with specialized models for other visual understanding tasks.

Capabilities

Base Classes

Foundation classes that define the interface for different types of image models.

class ImageClassifier(Task):
    """Base class for image classification models.

    Args:
        backbone: Backbone network the classification task is built on.
        num_classes: Number of output classes.
        preprocessor: Optional preprocessor for raw inputs. Defaults to
            ``None``.
        **kwargs: Additional keyword arguments (presumably forwarded to
            ``Task`` -- not shown in this stub).
    """
    def __init__(
        self,
        backbone: Backbone,
        num_classes: int,
        preprocessor: "Preprocessor | None" = None,
        **kwargs
    ): ...

class ObjectDetector(Task):
    """Base class for object detection models.

    Args:
        backbone: Backbone network the detection task is built on.
        num_classes: Number of object classes to detect.
        preprocessor: Optional preprocessor for raw inputs. Defaults to
            ``None``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        backbone: Backbone,
        num_classes: int,
        preprocessor: "Preprocessor | None" = None,
        **kwargs
    ): ...

class ImageSegmenter(Task):
    """Base class for image segmentation models.

    Args:
        backbone: Backbone network the segmentation task is built on.
        num_classes: Number of segmentation classes.
        preprocessor: Optional preprocessor for raw inputs. Defaults to
            ``None``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        backbone: Backbone,
        num_classes: int,
        preprocessor: "Preprocessor | None" = None,
        **kwargs
    ): ...

# Alias: ``ImageObjectDetector`` is bound to the same class object as
# ``ObjectDetector``, so the two names are interchangeable.
ImageObjectDetector = ObjectDetector

ResNet (Residual Networks)

Deep residual networks for image classification. Their skip connections make it possible to train very deep networks.

class ResNetBackbone(Backbone):
    """ResNet backbone architecture.

    The ``stackwise_*`` lists describe the network stage by stage
    (presumably one entry per stage, so all three share a length).

    Args:
        stackwise_num_filters: Number of filters for each stage.
        stackwise_num_blocks: Number of residual blocks in each stage.
        stackwise_num_strides: Stride used by each stage.
        block_type: Residual block variant. Defaults to ``"basic_block"``.
        use_pre_activation: Whether blocks use pre-activation. Defaults to
            ``False``.
        image_shape: Input shape as ``(height, width, channels)``.
            Defaults to ``(224, 224, 3)``.
        **kwargs: Additional keyword arguments.
    """
    # NOTE(review): upstream keras_hub's ResNetBackbone appears to also
    # require ``input_conv_filters`` and ``input_conv_kernel_sizes`` and to
    # give ``block_type`` no default -- verify against the installed version.
    def __init__(
        self,
        stackwise_num_filters: list,
        stackwise_num_blocks: list,
        stackwise_num_strides: list,
        block_type: str = "basic_block",
        use_pre_activation: bool = False,
        image_shape: tuple = (224, 224, 3),
        **kwargs
    ): ...

class ResNetImageClassifier(ImageClassifier):
    """ResNet model for image classification.

    Args:
        backbone: A ``ResNetBackbone`` used as the feature extractor.
        num_classes: Number of output classes.
        preprocessor: Optional preprocessor for raw inputs. Defaults to
            ``None``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        backbone: ResNetBackbone,
        num_classes: int,
        preprocessor: "Preprocessor | None" = None,
        **kwargs
    ): ...

class ResNetImageClassifierPreprocessor:
    """Preprocessor for ResNet image classification.

    Args:
        image_converter: Converter that resizes/normalizes input images.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...

class ResNetImageConverter:
    """Image converter for ResNet models.

    Args:
        height: Target image height in pixels. Defaults to ``224``.
        width: Target image width in pixels. Defaults to ``224``.
        crop_to_aspect_ratio: Whether to crop rather than distort when the
            target aspect ratio differs. Defaults to ``True``.
        interpolation: Resizing interpolation method. Defaults to
            ``"bilinear"``.
        data_format: Optional image data format (e.g. channels-last).
            Defaults to ``None``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        height: int = 224,
        width: int = 224,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        data_format: "str | None" = None,
        **kwargs
    ): ...

Vision Transformer (ViT)

A Transformer architecture applied to image classification. Each image is split into fixed-size patches, and the patches are processed as a sequence of tokens.

class ViTBackbone(Backbone):
    """Vision Transformer backbone.

    Args:
        image_shape: Input shape as ``(height, width, channels)``.
            Defaults to ``(224, 224, 3)``.
        patch_size: Side length of each image patch, in pixels. Defaults to
            ``16``.
        num_layers: Number of Transformer encoder layers. Defaults to ``12``.
        num_heads: Number of attention heads per layer. Defaults to ``12``.
        hidden_dim: Token embedding dimension. Defaults to ``768``.
        mlp_dim: Hidden dimension of each layer's MLP. Defaults to ``3072``.
        dropout: Dropout rate. Defaults to ``0.1``.
        **kwargs: Additional keyword arguments.
    """
    # NOTE(review): upstream keras_hub's ViTBackbone appears to name this
    # argument ``dropout_rate`` (default 0.0) and to require the size
    # arguments explicitly -- verify against the installed version.
    def __init__(
        self,
        image_shape: tuple = (224, 224, 3),
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        hidden_dim: int = 768,
        mlp_dim: int = 3072,
        dropout: float = 0.1,
        **kwargs
    ): ...

class ViTImageClassifier(ImageClassifier):
    """Vision Transformer for image classification.

    Args:
        backbone: A ``ViTBackbone`` used as the feature extractor.
        num_classes: Number of output classes.
        preprocessor: Optional preprocessor for raw inputs. Defaults to
            ``None``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        backbone: ViTBackbone,
        num_classes: int,
        preprocessor: "Preprocessor | None" = None,
        **kwargs
    ): ...

class ViTImageClassifierPreprocessor:
    """Preprocessor for ViT image classification.

    Args:
        image_converter: Converter that resizes/normalizes input images.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...

class ViTImageConverter:
    """Image converter for ViT models.

    Args:
        height: Target image height in pixels. Defaults to ``224``.
        width: Target image width in pixels. Defaults to ``224``.
        crop_to_aspect_ratio: Whether to crop rather than distort when the
            target aspect ratio differs. Defaults to ``True``.
        interpolation: Resizing interpolation method. Defaults to
            ``"bilinear"``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        height: int = 224,
        width: int = 224,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...

EfficientNet

Scalable convolutional neural network architecture optimized for efficiency.

class EfficientNetBackbone(Backbone):
    """EfficientNet backbone architecture.

    The ``stackwise_*`` lists describe the network stage by stage
    (presumably one entry per stage, so all of them share a length).

    Args:
        stackwise_kernel_sizes: Convolution kernel size for each stage.
        stackwise_num_repeats: Number of block repeats in each stage.
        stackwise_input_filters: Input filter count for each stage.
        stackwise_output_filters: Output filter count for each stage.
        stackwise_expand_ratios: Expansion ratio for each stage.
        stackwise_strides: Stride used by each stage.
        width_coefficient: Scaling factor for network width. Defaults to
            ``1.0``.
        depth_coefficient: Scaling factor for network depth. Defaults to
            ``1.0``.
        image_shape: Input shape as ``(height, width, channels)``.
            Defaults to ``(224, 224, 3)``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        stackwise_kernel_sizes: list,
        stackwise_num_repeats: list,
        stackwise_input_filters: list,
        stackwise_output_filters: list,
        stackwise_expand_ratios: list,
        stackwise_strides: list,
        width_coefficient: float = 1.0,
        depth_coefficient: float = 1.0,
        image_shape: tuple = (224, 224, 3),
        **kwargs
    ): ...

class EfficientNetImageClassifier(ImageClassifier):
    """EfficientNet model for image classification.

    Args:
        backbone: An ``EfficientNetBackbone`` used as the feature extractor.
        num_classes: Number of output classes.
        preprocessor: Optional preprocessor for raw inputs. Defaults to
            ``None``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        backbone: EfficientNetBackbone,
        num_classes: int,
        preprocessor: "Preprocessor | None" = None,
        **kwargs
    ): ...

class EfficientNetImageClassifierPreprocessor:
    """Preprocessor for EfficientNet image classification.

    Args:
        image_converter: Converter that resizes/normalizes input images.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...

class EfficientNetImageConverter:
    """Image converter for EfficientNet models.

    Args:
        height: Target image height in pixels. Defaults to ``224``.
        width: Target image width in pixels. Defaults to ``224``.
        crop_to_aspect_ratio: Whether to crop rather than distort when the
            target aspect ratio differs. Defaults to ``True``.
        interpolation: Resizing interpolation method. Defaults to
            ``"bilinear"``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        height: int = 224,
        width: int = 224,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...

Object Detection Models

Models specialized for detecting and localizing objects in images.

class RetinaNetBackbone(Backbone):
    """RetinaNet backbone for object detection.

    Args:
        stackwise_num_filters: Number of filters for each stage.
        stackwise_num_blocks: Number of blocks in each stage.
        stackwise_num_strides: Stride used by each stage.
        image_shape: Input shape as ``(height, width, channels)``.
            Defaults to ``(512, 512, 3)``.
        **kwargs: Additional keyword arguments.
    """
    # NOTE(review): upstream keras_hub's RetinaNetBackbone appears to wrap
    # an ``image_encoder`` (plus pyramid-level arguments) rather than take
    # ``stackwise_*`` lists -- verify against the installed version.
    def __init__(
        self,
        stackwise_num_filters: list,
        stackwise_num_blocks: list,
        stackwise_num_strides: list,
        image_shape: tuple = (512, 512, 3),
        **kwargs
    ): ...

class RetinaNetObjectDetector(ObjectDetector):
    """RetinaNet model for object detection.

    Args:
        backbone: A ``RetinaNetBackbone`` used as the feature extractor.
        num_classes: Number of object classes to detect.
        preprocessor: Optional preprocessor for raw inputs. Defaults to
            ``None``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        backbone: RetinaNetBackbone,
        num_classes: int,
        preprocessor: "Preprocessor | None" = None,
        **kwargs
    ): ...

class RetinaNetObjectDetectorPreprocessor:
    """Preprocessor for RetinaNet object detection.

    Args:
        image_converter: Converter that resizes/normalizes input images.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...

class RetinaNetImageConverter:
    """Image converter for RetinaNet models.

    Args:
        height: Target image height in pixels. Defaults to ``512``.
        width: Target image width in pixels. Defaults to ``512``.
        crop_to_aspect_ratio: Whether to crop rather than distort when the
            target aspect ratio differs. Defaults to ``True``.
        interpolation: Resizing interpolation method. Defaults to
            ``"bilinear"``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        height: int = 512,
        width: int = 512,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...

class ViTDetBackbone(Backbone):
    """Vision Transformer backbone for object detection.

    Args:
        image_shape: Input shape as ``(height, width, channels)``.
            Defaults to ``(1024, 1024, 3)``.
        patch_size: Side length of each image patch, in pixels. Defaults to
            ``16``.
        num_layers: Number of Transformer layers. Defaults to ``12``.
        num_heads: Number of attention heads per layer. Defaults to ``12``.
        hidden_dim: Token embedding dimension. Defaults to ``768``.
        mlp_dim: Hidden dimension of each layer's MLP. Defaults to ``3072``.
        **kwargs: Additional keyword arguments.
    """
    # NOTE(review): upstream keras_hub's ViTDetBackbone appears to use
    # ``hidden_size`` / ``intermediate_dim`` and to require
    # ``global_attention_layer_indices`` -- verify against the installed
    # version.
    def __init__(
        self,
        image_shape: tuple = (1024, 1024, 3),
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        hidden_dim: int = 768,
        mlp_dim: int = 3072,
        **kwargs
    ): ...

Image Segmentation Models

Models for pixel-level classification and semantic segmentation.

class DeepLabV3Backbone(Backbone):
    """DeepLab V3 backbone for semantic segmentation.

    Args:
        image_shape: Input shape as ``(height, width, channels)``.
            Defaults to ``(512, 512, 3)``.
        low_level_feature_key: Key of the feature map used as the low-level
            feature input. Defaults to ``"P2"``.
        spatial_pyramid_pooling_key: Key of the feature map fed to spatial
            pyramid pooling. Defaults to ``"P5"``.
        **kwargs: Additional keyword arguments.
    """
    # NOTE(review): upstream keras_hub's DeepLabV3Backbone appears to also
    # require an ``image_encoder`` (plus ``upsampling_size`` and
    # ``dilation_rates``) -- verify against the installed version.
    def __init__(
        self,
        image_shape: tuple = (512, 512, 3),
        low_level_feature_key: str = "P2",
        spatial_pyramid_pooling_key: str = "P5",
        **kwargs
    ): ...

class DeepLabV3ImageSegmenter(ImageSegmenter):
    """DeepLab V3 model for image segmentation.

    Args:
        backbone: A ``DeepLabV3Backbone`` used as the feature extractor.
        num_classes: Number of segmentation classes.
        preprocessor: Optional preprocessor for raw inputs. Defaults to
            ``None``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        backbone: DeepLabV3Backbone,
        num_classes: int,
        preprocessor: "Preprocessor | None" = None,
        **kwargs
    ): ...

class DeepLabV3ImageSegmenterPreprocessor:
    """Preprocessor for DeepLab V3 segmentation.

    Args:
        image_converter: Converter that resizes/normalizes input images.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...

class DeepLabV3ImageConverter:
    """Image converter for DeepLab V3 models.

    Args:
        height: Target image height in pixels. Defaults to ``512``.
        width: Target image width in pixels. Defaults to ``512``.
        crop_to_aspect_ratio: Whether to crop rather than distort when the
            target aspect ratio differs. Defaults to ``True``.
        interpolation: Resizing interpolation method. Defaults to
            ``"bilinear"``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        height: int = 512,
        width: int = 512,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...

class BASNetBackbone(Backbone):
    """BASNet backbone for boundary-aware salient object detection.

    Args:
        image_shape: Input shape as ``(height, width, channels)``.
            Defaults to ``(224, 224, 3)``.
        **kwargs: Additional keyword arguments.
    """
    # NOTE(review): upstream keras_hub's BASNetBackbone appears to also
    # require an ``image_encoder`` and ``num_classes`` -- verify against the
    # installed version.
    def __init__(
        self,
        image_shape: tuple = (224, 224, 3),
        **kwargs
    ): ...

class BASNetImageSegmenter(ImageSegmenter):
    """BASNet model for image segmentation.

    Unlike the ``ImageSegmenter`` base signature, this constructor takes no
    ``num_classes`` argument.

    Args:
        backbone: A ``BASNetBackbone`` used as the feature extractor.
        preprocessor: Optional preprocessor for raw inputs. Defaults to
            ``None``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        backbone: BASNetBackbone,
        preprocessor: "Preprocessor | None" = None,
        **kwargs
    ): ...

class BASNetPreprocessor:
    """Preprocessor for BASNet segmentation.

    Args:
        image_converter: Converter that resizes/normalizes input images.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...

class BASNetImageConverter:
    """Image converter for BASNet models.

    Args:
        height: Target image height in pixels. Defaults to ``224``.
        width: Target image width in pixels. Defaults to ``224``.
        crop_to_aspect_ratio: Whether to crop rather than distort when the
            target aspect ratio differs. Defaults to ``True``.
        interpolation: Resizing interpolation method. Defaults to
            ``"bilinear"``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        height: int = 224,
        width: int = 224,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...

class SegFormerBackbone(Backbone):
    """SegFormer backbone for semantic segmentation.

    Args:
        image_shape: Input shape as ``(height, width, channels)``.
            Defaults to ``(512, 512, 3)``.
        num_layers: Number of layers in each stage. Defaults to
            ``(2, 2, 2, 2)``.
        hidden_dims: Hidden dimension of each stage. Defaults to
            ``(32, 64, 160, 256)``.
        **kwargs: Additional keyword arguments.
    """
    # Defaults are tuples rather than lists: a mutable default would be a
    # single object shared by every call. Lists are still accepted.
    # NOTE(review): upstream keras_hub's SegFormerBackbone appears to take
    # ``image_encoder`` and ``projection_filters`` instead -- verify against
    # the installed version.
    def __init__(
        self,
        image_shape: tuple = (512, 512, 3),
        num_layers: "list | tuple" = (2, 2, 2, 2),
        hidden_dims: "list | tuple" = (32, 64, 160, 256),
        **kwargs
    ): ...

class SegFormerImageSegmenter(ImageSegmenter):
    """SegFormer model for image segmentation.

    Args:
        backbone: A ``SegFormerBackbone`` used as the feature extractor.
        num_classes: Number of segmentation classes.
        preprocessor: Optional preprocessor for raw inputs. Defaults to
            ``None``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        backbone: SegFormerBackbone,
        num_classes: int,
        preprocessor: "Preprocessor | None" = None,
        **kwargs
    ): ...

class SegFormerImageSegmenterPreprocessor:
    """Preprocessor for SegFormer segmentation.

    Args:
        image_converter: Converter that resizes/normalizes input images.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...

class SegFormerImageConverter:
    """Image converter for SegFormer models.

    Args:
        height: Target image height in pixels. Defaults to ``512``.
        width: Target image width in pixels. Defaults to ``512``.
        crop_to_aspect_ratio: Whether to crop rather than distort when the
            target aspect ratio differs. Defaults to ``True``.
        interpolation: Resizing interpolation method. Defaults to
            ``"bilinear"``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        height: int = 512,
        width: int = 512,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...

class SAMBackbone(Backbone):
    """Segment Anything Model backbone.

    Args:
        image_shape: Input shape as ``(height, width, channels)``.
            Defaults to ``(1024, 1024, 3)``.
        patch_size: Side length of each image patch, in pixels. Defaults to
            ``16``.
        num_layers: Number of Transformer layers. Defaults to ``12``.
        num_heads: Number of attention heads per layer. Defaults to ``12``.
        hidden_dim: Token embedding dimension. Defaults to ``768``.
        **kwargs: Additional keyword arguments.
    """
    # NOTE(review): upstream keras_hub's SAMBackbone appears to be composed
    # from ``image_encoder``, ``prompt_encoder`` and ``mask_decoder`` -- verify
    # against the installed version.
    def __init__(
        self,
        image_shape: tuple = (1024, 1024, 3),
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        hidden_dim: int = 768,
        **kwargs
    ): ...

class SAMImageSegmenter(ImageSegmenter):
    """Segment Anything Model for image segmentation.

    Unlike the ``ImageSegmenter`` base signature, this constructor takes no
    ``num_classes`` argument.

    Args:
        backbone: A ``SAMBackbone`` used as the feature extractor.
        preprocessor: Optional preprocessor for raw inputs. Defaults to
            ``None``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        backbone: SAMBackbone,
        preprocessor: "Preprocessor | None" = None,
        **kwargs
    ): ...

class SAMImageSegmenterPreprocessor:
    """Preprocessor for SAM segmentation.

    Args:
        image_converter: Converter that resizes/normalizes input images.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...

class SAMImageConverter:
    """Image converter for SAM models.

    Args:
        height: Target image height in pixels. Defaults to ``1024``.
        width: Target image width in pixels. Defaults to ``1024``.
        crop_to_aspect_ratio: Whether to crop rather than distort when the
            target aspect ratio differs. Defaults to ``True``.
        interpolation: Resizing interpolation method. Defaults to
            ``"bilinear"``.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        height: int = 1024,
        width: int = 1024,
        crop_to_aspect_ratio: bool = True,
        interpolation: str = "bilinear",
        **kwargs
    ): ...

Additional Image Classification Models

Other popular architectures for image classification tasks.

# NOTE: constructor signatures for the classes below are not documented on
# this page; only the class names and their base classes are listed.

# DenseNet (Densely Connected Networks)
class DenseNetBackbone(Backbone):
    """DenseNet backbone."""
class DenseNetImageClassifier(ImageClassifier):
    """DenseNet model for image classification."""
class DenseNetImageClassifierPreprocessor:
    """Preprocessor for DenseNet image classification."""
class DenseNetImageConverter:
    """Image converter for DenseNet models."""

# MobileNet (Efficient Mobile Networks)
class MobileNetBackbone(Backbone):
    """MobileNet backbone."""
class MobileNetImageClassifier(ImageClassifier):
    """MobileNet model for image classification."""
class MobileNetImageClassifierPreprocessor:
    """Preprocessor for MobileNet image classification."""
class MobileNetImageConverter:
    """Image converter for MobileNet models."""

# VGG (Visual Geometry Group)
class VGGBackbone(Backbone):
    """VGG backbone."""
class VGGImageClassifier(ImageClassifier):
    """VGG model for image classification."""
class VGGImageClassifierPreprocessor:
    """Preprocessor for VGG image classification."""
class VGGImageConverter:
    """Image converter for VGG models."""

# Xception
class XceptionBackbone(Backbone):
    """Xception backbone."""
class XceptionImageClassifier(ImageClassifier):
    """Xception model for image classification."""
class XceptionImageClassifierPreprocessor:
    """Preprocessor for Xception image classification."""
class XceptionImageConverter:
    """Image converter for Xception models."""

# DeiT (Data-efficient Image Transformer)
class DeiTBackbone(Backbone):
    """DeiT backbone."""
class DeiTImageClassifier(ImageClassifier):
    """DeiT model for image classification."""
class DeiTImageClassifierPreprocessor:
    """Preprocessor for DeiT image classification."""
class DeiTImageConverter:
    """Image converter for DeiT models."""

# CSPNet (Cross Stage Partial Network)
class CSPNetBackbone(Backbone):
    """CSPNet backbone."""
class CSPNetImageClassifier(ImageClassifier):
    """CSPNet model for image classification."""
class CSPNetImageClassifierPreprocessor:
    """Preprocessor for CSPNet image classification."""
class CSPNetImageConverter:
    """Image converter for CSPNet models."""

# HGNet V2 (High Performance GPU Network V2)
class HGNetV2Backbone(Backbone):
    """HGNetV2 backbone."""
class HGNetV2ImageClassifier(ImageClassifier):
    """HGNetV2 model for image classification."""
class HGNetV2ImageClassifierPreprocessor:
    """Preprocessor for HGNetV2 image classification."""
class HGNetV2ImageConverter:
    """Image converter for HGNetV2 models."""

# MiT (Mix Transformer)
class MiTBackbone(Backbone):
    """Mix Transformer (MiT) backbone."""
class MiTImageClassifier(ImageClassifier):
    """MiT model for image classification."""
class MiTImageClassifierPreprocessor:
    """Preprocessor for MiT image classification."""
class MiTImageConverter:
    """Image converter for MiT models."""

# DINOV2 (Self-Supervised Vision Transformer)
# Only a backbone and an image converter are listed; there is no DINOV2
# classifier task on this page.
class DINOV2Backbone(Backbone):
    """DINOV2 self-supervised Vision Transformer backbone."""
class DINOV2ImageConverter:
    """Image converter for DINOV2 models."""

Utility Backbones

Specialized backbone architectures for various computer vision tasks.

class FeaturePyramidBackbone(Backbone):
    """Feature Pyramid Network (FPN) backbone.

    Args:
        backbone: Backbone the pyramid is presumably built on top of.
        feature_size: Size of the pyramid feature maps (presumably the
            channel count). Defaults to ``256``.
        **kwargs: Additional keyword arguments.
    """
    # NOTE(review): in upstream keras_hub, FeaturePyramidBackbone appears to
    # be a base class for backbones that expose multi-level
    # ``pyramid_outputs``, not a wrapper taking ``backbone`` /
    # ``feature_size`` -- verify against the installed version.
    def __init__(
        self,
        backbone: Backbone,
        feature_size: int = 256,
        **kwargs
    ): ...

Preprocessor Base Classes

Base classes for image preprocessing.

class ImageClassifierPreprocessor(Preprocessor):
    """Base preprocessor for image classification.

    Args:
        image_converter: Converter that resizes/normalizes input images.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...

class ImageSegmenterPreprocessor(Preprocessor):
    """Base preprocessor for image segmentation.

    Args:
        image_converter: Converter that resizes/normalizes input images.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...

class ObjectDetectorPreprocessor(Preprocessor):
    """Base preprocessor for object detection.

    Args:
        image_converter: Converter that resizes/normalizes input images.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        image_converter: ImageConverter,
        **kwargs
    ): ...

# Alias: ``ImageObjectDetectorPreprocessor`` is bound to the same class
# object as ``ObjectDetectorPreprocessor``, so the two names are
# interchangeable.
ImageObjectDetectorPreprocessor = ObjectDetectorPreprocessor

Usage Examples

Image Classification with ResNet

import keras_hub
import numpy as np

# Load a pretrained ResNet-50 classifier (ImageNet weights). Loading the
# task from a preset also attaches the matching preprocessor, so raw images
# can be passed in directly.
classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet_50_imagenet")

# Build a batch of one image with shape (batch, height, width, channels).
# Random pixel values in [0, 255] stand in for a real decoded image.
images = np.random.randint(0, 256, size=(1, 224, 224, 3)).astype("float32")

# Predict: one row of class scores per image.
predictions = classifier.predict(images)
print(f"Predictions shape: {predictions.shape}")

# Index of the highest-scoring class for the first image.
predicted_class = np.argmax(predictions[0])
print(f"Predicted class: {predicted_class}")

Custom Image Classification

import keras
import keras_hub

PRESET = "resnet_50_imagenet"

# Reuse the pretrained ResNet-50 backbone together with its matching
# preprocessor. Without a preprocessor, fit()/predict() would receive images
# that have not been resized or normalized the way the backbone expects.
backbone = keras_hub.models.ResNetBackbone.from_preset(PRESET)
preprocessor = keras_hub.models.ResNetImageClassifierPreprocessor.from_preset(PRESET)

# A fresh 2-class head on top of the pretrained backbone.
classifier = keras_hub.models.ResNetImageClassifier(
    backbone=backbone,
    num_classes=2,  # Binary classification
    preprocessor=preprocessor,
)

# The classification head outputs logits by default (no final softmax), so
# the loss must be told to expect logits. The plain
# "sparse_categorical_crossentropy" string would treat them as probabilities.
classifier.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Train with your data (integer labels in {0, 1}):
# classifier.fit(train_images, train_labels, epochs=10)

Object Detection with RetinaNet

import keras_hub
import numpy as np

# Load a pretrained RetinaNet detector (ResNet-50 + FPN, COCO weights).
detector = keras_hub.models.RetinaNetObjectDetector.from_preset(
    "retinanet_resnet50_fpn_coco"
)

# Build a batch of one image with shape (batch, height, width, channels).
# Random pixel values in [0, 255] stand in for a real decoded image.
images = np.random.randint(0, 256, size=(1, 512, 512, 3)).astype("float32")

# Detect objects in the batch.
detections = detector.predict(images)

# Process detections.
# `detections` contains bounding boxes, class predictions, and confidence
# scores for each image.
print("Detections:", detections)

Image Segmentation with DeepLab V3

import keras_hub
import numpy as np

# Load a pretrained DeepLab V3+ segmentation model (ResNet-50, Pascal VOC).
segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
    "deeplab_v3_plus_resnet50_pascalvoc"
)

# Build a batch of one image with shape (batch, height, width, channels).
# Random pixel values in [0, 255] stand in for a real decoded image.
images = np.random.randint(0, 256, size=(1, 512, 512, 3)).astype("float32")

# Segment the batch.
segmentation_mask = segmenter.predict(images)

# The output is a segmentation mask with class predictions for each pixel.
print(f"Segmentation mask shape: {segmentation_mask.shape}")

Using Vision Transformer

import keras_hub
import numpy as np

# Load a pretrained ViT-Base/16 classifier (224x224 inputs, ImageNet).
vit_classifier = keras_hub.models.ViTImageClassifier.from_preset(
    "vit_base_patch16_224_imagenet"
)

# Build a batch of one image with shape (batch, height, width, channels).
# Random pixel values in [0, 255] stand in for a real decoded image.
images = np.random.randint(0, 256, size=(1, 224, 224, 3)).astype("float32")

# Classify the batch: one row of class scores per image.
predictions = vit_classifier.predict(images)
print("ViT predictions:", predictions)

Install with Tessl CLI

npx tessl i tessl/pypi-keras-hub

docs

audio-models.md

evaluation-metrics.md

generative-models.md

image-models.md

index.md

layers-components.md

multimodal-models.md

text-generation-sampling.md

text-models.md

tokenizers.md

utilities-helpers.md

tile.json