CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-mmengine

Engine of OpenMMLab projects for training deep learning models based on PyTorch with large-scale training frameworks, configuration management, and monitoring capabilities

Pending
Overview
Eval results
Files

docs/models.md

Models and Hooks

Comprehensive model management system with base classes, weight initialization, model wrappers for distributed training, and extensive hook system for customizing training behaviors. The system provides foundation classes and utilities for building robust training pipelines.

Capabilities

Base Model Classes

Foundation classes for all models in MMEngine with standardized interfaces for training, validation, and testing.

class BaseModel:
    """Base class for all MMEngine models.

    Subclasses implement ``forward``; the runner drives optimization via
    ``train_step`` and evaluation via ``val_step`` / ``test_step``.
    """

    def __init__(self, init_cfg: dict = None, data_preprocessor: dict = None):
        """
        Base class for all models.

        Parameters:
        - init_cfg: Weight initialization configuration; ``None`` keeps the
          default initialization
        - data_preprocessor: Data preprocessor configuration; ``None`` means
          no preprocessor is configured
        """

    def forward(self, *args, **kwargs):
        """
        Forward pass implementation (provided by subclasses).

        Parameters:
        - *args: Input arguments
        - **kwargs: Input keyword arguments

        Returns:
        Model outputs
        """

    def train_step(self, data, optim_wrapper):
        """
        Training step implementation: compute losses and apply the parameter
        update through the optimizer wrapper.

        Parameters:
        - data: Input data batch
        - optim_wrapper: Optimizer wrapper used to apply the update

        Returns:
        Dictionary containing loss and log variables
        """

    def val_step(self, data):
        """
        Validation step implementation.

        Parameters:
        - data: Input data batch

        Returns:
        Validation outputs
        """

    def test_step(self, data):
        """
        Test step implementation.

        Parameters:
        - data: Input data batch

        Returns:
        Test outputs
        """

    def init_weights(self):
        """Initialize model weights according to ``init_cfg``."""

    @property
    def device(self):
        """Get the device the model's parameters currently reside on."""

    def cuda(self, device=None):
        """Move model to a CUDA device (current device when ``device`` is None)."""

    def cpu(self):
        """Move model to CPU."""

    def train(self, mode: bool = True):
        """Set training mode (``mode=False`` is equivalent to ``eval()``)."""

    def eval(self):
        """Set evaluation mode."""

Data Preprocessors

Classes for preprocessing input data before feeding to models.

class BaseDataPreprocessor:
    """Preprocesses a batch (device transfer, normalization, padding) before
    it reaches the model."""

    def __init__(self, mean: list = None, std: list = None, pad_size_divisor: int = 1, pad_value: float = 0, bgr_to_rgb: bool = False, rgb_to_bgr: bool = False, non_blocking: bool = False):
        """
        Base data preprocessor.

        Parameters:
        - mean: Mean values for normalization (``None`` disables normalization)
        - std: Standard deviation values for normalization
        - pad_size_divisor: Padding size divisor; 1 means no extra padding
        - pad_value: Padding value
        - bgr_to_rgb: Whether to convert BGR to RGB (should not be combined
          with ``rgb_to_bgr``)
        - rgb_to_bgr: Whether to convert RGB to BGR
        - non_blocking: Whether to use non-blocking (asynchronous) device
          transfer

        NOTE(review): upstream mmengine's BaseDataPreprocessor only accepts
        ``non_blocking``; the normalization/padding options live on
        ImgDataPreprocessor — confirm against the installed version.
        """

    def forward(self, data: dict, training: bool = False) -> dict:
        """
        Forward pass for data preprocessing.

        Parameters:
        - data: Input data dictionary
        - training: Whether in training mode

        Returns:
        Preprocessed data
        """

    def cast_data(self, data):
        """
        Cast data to appropriate types and devices.

        Parameters:
        - data: Input data

        Returns:
        Casted data
        """

class ImgDataPreprocessor(BaseDataPreprocessor):
    """Image-specific preprocessor: normalization, channel-order conversion
    and size padding for image batches."""

    def __init__(self, mean: list = None, std: list = None, pad_size_divisor: int = 1, pad_value: float = 0, bgr_to_rgb: bool = False, rgb_to_bgr: bool = False, batch_augments: list = None):
        """
        Image data preprocessor.

        Parameters:
        - mean: Per-channel mean values for normalization
        - std: Per-channel std values for normalization
        - pad_size_divisor: Padding size divisor; 1 means no extra padding
        - pad_value: Padding value
        - bgr_to_rgb: Whether to convert BGR to RGB
        - rgb_to_bgr: Whether to convert RGB to BGR
        - batch_augments: Batch augmentation transforms (presumably applied
          only when ``training=True`` — confirm upstream)
        """

Base Module Classes

Enhanced PyTorch module classes with initialization and utility features.

class BaseModule:
    """PyTorch-style module base with config-driven weight initialization."""

    def __init__(self, init_cfg: dict = None):
        """
        Base module with weight initialization support.

        Parameters:
        - init_cfg: Initialization configuration; ``None`` keeps the default
          initialization
        """

    def init_weights(self):
        """Initialize module weights according to ``init_cfg``."""

class ModuleDict:
    """Dict-like container of modules exposing the standard mapping
    protocol (item access, ``len``, iteration, ``keys``/``values``/``items``)."""

    def __init__(self, modules: dict = None):
        """
        Module dictionary container.

        Parameters:
        - modules: Dictionary of modules (``None`` creates an empty container)
        """

    def __getitem__(self, key: str):
        """Get module by key."""

    def __setitem__(self, key: str, module):
        """Set module by key."""

    def __delitem__(self, key: str):
        """Delete module by key."""

    def __len__(self) -> int:
        """Get number of modules."""

    def __iter__(self):
        """Iterate over module keys."""

    def keys(self):
        """Get module keys."""

    def values(self):
        """Get module values."""

    def items(self):
        """Get (key, module) pairs."""

class ModuleList:
    """List-like container of modules exposing the standard sequence
    protocol (indexing, ``len``, iteration, ``append``/``extend``/``insert``)."""

    def __init__(self, modules: list = None):
        """
        Module list container.

        Parameters:
        - modules: List of modules (``None`` creates an empty container)
        """

    def __getitem__(self, idx: int):
        """Get module by index."""

    def __setitem__(self, idx: int, module):
        """Set module by index."""

    def __delitem__(self, idx: int):
        """Delete module by index."""

    def __len__(self) -> int:
        """Get number of modules."""

    def __iter__(self):
        """Iterate over modules."""

    def append(self, module):
        """Append module to the end of the list."""

    def extend(self, modules: list):
        """Extend list with modules."""

    def insert(self, index: int, module):
        """Insert module at index."""

class Sequential:
    """Container that chains its child modules in registration order."""

    def __init__(self, *args):
        """
        Sequential module container.

        Parameters:
        - *args: Modules to add, executed in the order given
        """

    def forward(self, input):
        """Run ``input`` through each child module in order and return the
        final output."""

Hook System

Comprehensive hook system for customizing training behaviors at different stages.

class Hook:
    """Base class for training hooks.

    Every callback below is a no-op; subclasses override only the stages
    they need. ``runner`` is the Runner that owns the train/val/test loops.
    """

    priority = 'NORMAL'  # Hook priority level; used to order hook execution
    
    def before_run(self, runner):
        """Called before training starts."""

    def after_run(self, runner):
        """Called after training ends."""

    def before_train(self, runner):
        """Called before the training loop."""

    def after_train(self, runner):
        """Called after the training loop."""

    def before_train_epoch(self, runner):
        """Called before each training epoch."""

    def after_train_epoch(self, runner):
        """Called after each training epoch."""

    def before_train_iter(self, runner):
        """Called before each training iteration."""

    def after_train_iter(self, runner):
        """Called after each training iteration."""

    def before_val(self, runner):
        """Called before validation."""

    def after_val(self, runner):
        """Called after validation."""

    def before_val_epoch(self, runner):
        """Called before a validation epoch."""

    def after_val_epoch(self, runner):
        """Called after a validation epoch."""

    def before_val_iter(self, runner):
        """Called before a validation iteration."""

    def after_val_iter(self, runner):
        """Called after a validation iteration."""

    def before_save_checkpoint(self, runner, checkpoint: dict):
        """Called before saving a checkpoint; receives the checkpoint dict."""

    def after_load_checkpoint(self, runner, checkpoint: dict):
        """Called after loading a checkpoint; receives the loaded dict."""

    def before_test(self, runner):
        """Called before testing."""

    def after_test(self, runner):
        """Called after testing."""

Built-in Hooks

Collection of commonly used hooks for various training scenarios.

class CheckpointHook(Hook):
    """Periodically saves model/optimizer/scheduler state during training."""

    def __init__(self, interval: int = -1, by_epoch: bool = True, save_optimizer: bool = True, save_param_scheduler: bool = True, out_dir: str = None, max_keep_ckpts: int = -1, save_last: bool = True, save_best: str = 'auto', rule: str = 'greater', greater_keys: list = None, less_keys: list = None, file_client_args: dict = None, published_keys: list = None):
        """
        Hook for saving checkpoints.

        Parameters:
        - interval: Save interval (presumably -1 disables periodic saving —
          confirm upstream)
        - by_epoch: Whether ``interval`` counts epochs (True) or iterations
        - save_optimizer: Whether to save optimizer state
        - save_param_scheduler: Whether to save scheduler state
        - out_dir: Output directory (``None`` uses the runner's work dir —
          TODO confirm)
        - max_keep_ckpts: Maximum checkpoints to keep (-1 keeps all)
        - save_last: Whether to always save the last checkpoint
        - save_best: Best-checkpoint strategy ('auto' picks a metric)
        - rule: Comparison rule for the best checkpoint ('greater' or 'less')
        - greater_keys: Metric keys where larger is better
        - less_keys: Metric keys where smaller is better
        - file_client_args: File client arguments for remote/local storage
        - published_keys: Keys to keep in a published checkpoint
        """

class LoggerHook(Hook):
    """Collects and emits training log records at a fixed interval."""

    def __init__(self, interval: int = 10, ignore_last: bool = True, reset_flag: bool = False, by_epoch: bool = True):
        """
        Hook for logging training information.

        Parameters:
        - interval: Logging interval (in iterations)
        - ignore_last: Whether to ignore the last incomplete interval
        - reset_flag: Whether to reset the log flag after logging
        - by_epoch: Whether to log by epoch

        NOTE(review): this signature resembles the mmcv-era LoggerHook;
        mmengine's LoggerHook exposes a different parameter set — confirm
        against the installed version.
        """

class IterTimerHook(Hook):
    def __init__(self):
        """Hook that measures the duration of each training iteration."""

class DistSamplerSeedHook(Hook):
    def __init__(self):
        """Hook for setting the distributed sampler seed (presumably
        per-epoch, so shuffling differs across epochs — confirm upstream)."""

class ParamSchedulerHook(Hook):
    def __init__(self):
        """Hook that steps the parameter schedulers (e.g. learning rate)
        during training."""

class EMAHook(Hook):
    """Maintains an exponential moving average of model parameters."""

    def __init__(self, ema_type: str = 'ExponentialMovingAverage', momentum: float = 0.0002, update_buffers: bool = False, priority: int = 49):
        """
        Hook for exponential moving average.

        Parameters:
        - ema_type: Type of EMA ('ExponentialMovingAverage',
          'MomentumAnnealingEMA')
        - momentum: EMA momentum
        - update_buffers: Whether to also average module buffers
        - priority: Hook priority (49 by default)
        """

class EmptyCacheHook(Hook):
    """Frees cached CUDA memory at configurable points in the loop."""

    def __init__(self, before_epoch: bool = False, after_epoch: bool = True, after_iter: bool = False):
        """
        Hook for emptying the CUDA cache.

        Parameters:
        - before_epoch: Whether to empty the cache before each epoch
        - after_epoch: Whether to empty the cache after each epoch
        - after_iter: Whether to empty the cache after each iteration
          (frequent emptying can slow training)
        """

class SyncBuffersHook(Hook):
    def __init__(self):
        """Hook for synchronizing model buffers (e.g. BatchNorm running
        stats) across processes in distributed training."""

class RuntimeInfoHook(Hook):
    """Collects runtime information (epoch, iteration, metrics) for logging."""

    def __init__(self, enable_tensorboard: bool = True):
        """
        Hook for collecting runtime information.

        Parameters:
        - enable_tensorboard: Whether to enable tensorboard logging

        NOTE(review): upstream mmengine's RuntimeInfoHook takes no
        constructor arguments — confirm this parameter exists in the
        installed version.
        """

class EarlyStoppingHook(Hook):
    """Stops training when a monitored metric stops improving."""

    def __init__(self, monitor: str, min_delta: float = 0, patience: int = 5, verbose: bool = False, mode: str = 'min', baseline: float = None, restore_best_weights: bool = False):
        """
        Hook for early stopping.

        Parameters:
        - monitor: Metric to monitor
        - min_delta: Minimum change to qualify as an improvement
        - patience: Number of epochs with no improvement after which
          training stops
        - verbose: Whether to print early stopping messages
        - mode: 'min' or 'max' — whether a decrease or increase of the
          metric counts as improvement
        - baseline: Baseline value for the monitored quantity
        - restore_best_weights: Whether to restore the best weights on stop

        NOTE(review): upstream mmengine's EarlyStoppingHook uses
        ``rule='greater'/'less'`` rather than ``mode`` — confirm against the
        installed version.
        """

class ProfilerHook(Hook):
    """Profiles training performance (time, memory, FLOPs) for a span of
    iterations."""

    def __init__(self, by_epoch: bool = True, profile_iters: int = 1, activities: list = None, schedule: dict = None, on_trace_ready: callable = None, record_shapes: bool = False, profile_memory: bool = False, with_stack: bool = False, with_flops: bool = False, json_trace_path: str = None):
        """
        Hook for profiling training performance.

        Parameters:
        - by_epoch: Whether to profile by epoch
        - profile_iters: Number of iterations to profile
        - activities: List of activities to profile (presumably CPU/CUDA —
          confirm against torch.profiler)
        - schedule: Profiling schedule
        - on_trace_ready: Callback invoked when a trace is ready
        - record_shapes: Whether to record tensor shapes
        - profile_memory: Whether to profile memory
        - with_stack: Whether to record stack traces
        - with_flops: Whether to record FLOPs
        - json_trace_path: Path to save the JSON trace (``None`` disables
          export)
        """

Model Utilities

Utility functions for model operations and management.

def stack_batch(tensors: list, pad_size_divisor: int = 1, pad_value: float = 0) -> torch.Tensor:
    """
    Stack a list of tensors into a single batch tensor, padding each tensor
    so its trailing (spatial) sizes are divisible by ``pad_size_divisor``.

    Parameters:
    - tensors: List of tensors to stack
    - pad_size_divisor: Padding size divisor; 1 means no extra padding.
      Fixed from 0: a divisor of 0 is invalid (ceil-division by zero), and
      both upstream mmengine and ImgDataPreprocessor above default to 1.
    - pad_value: Padding value

    Returns:
    Stacked batch tensor
    """

def merge_dict(*dicts: dict) -> dict:
    """
    Merge multiple dictionaries into a single dictionary.

    Parameters:
    - *dicts: Dictionaries to merge (presumably later dictionaries take
      precedence on key conflicts — confirm upstream)

    Returns:
    Merged dictionary
    """

def detect_anomalous_params(loss: torch.Tensor, model: torch.nn.Module) -> dict:
    """
    Detect anomalous parameters (NaN or Inf).

    NOTE(review): upstream mmengine's function of this name reports
    parameters that do not contribute to ``loss`` (absent from the autograd
    graph), not NaN/Inf values — confirm which behavior applies here.

    Parameters:
    - loss: Loss tensor
    - model: Model to check

    Returns:
    Dictionary of anomalous parameters
    """

def convert_sync_batchnorm(model: torch.nn.Module, process_group=None) -> torch.nn.Module:
    """
    Convert BatchNorm layers to SyncBatchNorm for distributed training.

    Parameters:
    - model: Model to convert (converted recursively)
    - process_group: Process group used for statistics synchronization
      (``None`` presumably uses the default/world group — confirm)

    Returns:
    Model with SyncBatchNorm
    """

def revert_sync_batchnorm(model: torch.nn.Module) -> torch.nn.Module:
    """
    Revert SyncBatchNorm layers back to regular BatchNorm (useful for
    single-device inference after distributed training).

    Parameters:
    - model: Model to revert

    Returns:
    Model with BatchNorm
    """

Model Wrappers

Wrappers for models to handle distributed training and other special scenarios.

def is_model_wrapper(model) -> bool:
    """
    Check whether ``model`` is wrapped by a distributed/model wrapper.

    Parameters:
    - model: Model to check

    Returns:
    True if the model is wrapped, False otherwise
    """

Test-Time Augmentation

Base class for test-time augmentation models.

class BaseTTAModel:
    """Wraps a model to run test-time augmentation: run the wrapped module
    on augmented views of the input, then merge the predictions."""

    def __init__(self, module, tta_cfg: dict = None):
        """
        Base test-time augmentation model.

        Parameters:
        - module: Base model module to wrap
        - tta_cfg: TTA configuration (``None`` means no extra configuration)
        """

    def test_step(self, data):
        """
        Test step with augmentation.

        Parameters:
        - data: Input data

        Returns:
        Augmented test results
        """

    def merge_preds(self, data_samples_list: list):
        """
        Merge predictions obtained from different augmentations.

        Parameters:
        - data_samples_list: List of per-augmentation predictions

        Returns:
        Merged predictions
        """

Usage Examples

Basic Model Implementation

from mmengine.model import BaseModel
import torch.nn as nn

class MyModel(BaseModel):
    """Minimal image classifier demonstrating the BaseModel interface."""

    def __init__(self, num_classes=10, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Small conv feature extractor ending in global average pooling.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Linear(64, num_classes)

    def forward(self, inputs):
        # Extract features, drop the 1x1 spatial dims, then classify.
        features = self.backbone(inputs).flatten(1)
        return self.head(features)

    def train_step(self, data, optim_wrapper):
        batch_inputs, batch_labels = data['inputs'], data['labels']

        logits = self(batch_inputs)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, batch_labels)

        # parse_losses sums the loss dict and produces loggable scalars;
        # the optimizer wrapper performs the actual parameter update.
        parsed_loss, log_vars = self.parse_losses({'loss': loss})
        optim_wrapper.update_params(parsed_loss)

        return {'loss': parsed_loss, 'log_vars': log_vars}

Custom Hook Implementation

from mmengine.hooks import Hook

class CustomValidationHook(Hook):
    """Example hook: trigger validation every ``val_interval`` epochs."""

    def __init__(self, val_interval=1):
        # Number of training epochs between validation runs.
        self.val_interval = val_interval
    
    def after_train_epoch(self, runner):
        # runner.epoch is 0-based, so +1 means "every Nth completed epoch".
        if (runner.epoch + 1) % self.val_interval == 0:
            runner.val()
            
            # Custom validation logic
            # NOTE(review): assumes a 'val_acc' scalar is in the message hub
            # and exposes a ``current`` value — confirm the metric name.
            val_metrics = runner.message_hub.get_scalar('val_acc')
            if val_metrics.current > 0.95:
                runner.logger.info("High accuracy achieved!")

# Register and use hook
runner.register_hook(CustomValidationHook(val_interval=5))

Model with Data Preprocessor

from mmengine.model import BaseModel, ImgDataPreprocessor

model = BaseModel(
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # Presumably the standard ImageNet RGB statistics on the 0-255
        # scale — confirm for your dataset.
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,  # inputs arrive as BGR (e.g. from OpenCV)
        pad_size_divisor=32  # pad spatial sizes up to multiples of 32
    )
)

Using Built-in Hooks

from mmengine import Runner
from mmengine.hooks import CheckpointHook, LoggerHook, EMAHook

# Configure hooks
# Configure hooks
default_hooks = dict(
    timer=dict(type='IterTimerHook'),  # per-iteration timing
    logger=dict(type='LoggerHook', interval=100),  # log every 100 iters
    param_scheduler=dict(type='ParamSchedulerHook'),  # step LR schedulers
    sampler_seed=dict(type='DistSamplerSeedHook'),  # reseed DDP sampler
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,  # save every epoch (by_epoch defaults to True)
        save_best='auto',
        max_keep_ckpts=3  # keep only the 3 most recent checkpoints
    )
)

custom_hooks = [
    # priority=49 schedules the EMA update relative to NORMAL hooks —
    # presumably lower numbers run earlier; confirm ordering semantics.
    dict(type='EMAHook', momentum=0.0002, priority=49)
]

runner = Runner(
    model=model,
    default_hooks=default_hooks,
    custom_hooks=custom_hooks
)

Model Utilities Usage

from mmengine.model import convert_sync_batchnorm, detect_anomalous_params

# Convert model for distributed training
model = convert_sync_batchnorm(model)

# Check for anomalous parameters during training
def training_step(model, data, optimizer):
    """Run one optimization step with an anomaly check on the loss graph.

    Parameters:
    - model: Module being trained (called directly on ``data``)
    - data: Input batch
    - optimizer: Optimizer holding the model's parameters
    """
    loss = model(data)

    # Check for anomalies
    anomalous = detect_anomalous_params(loss, model)
    if anomalous:
        print(f"Anomalous parameters detected: {anomalous}")

    # Clear stale gradients before backprop; without this, gradients
    # accumulate across calls and corrupt the update (bug in the original
    # example, which never called zero_grad).
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Priority-based Hook Ordering

from mmengine.hooks import Hook
from mmengine.runner import get_priority

class HighPriorityHook(Hook):
    """Example hook that runs ahead of NORMAL-priority hooks at each stage."""

    priority = 'HIGH'  # or get_priority('HIGH')
    
    def before_train_iter(self, runner):
        # This runs before normal priority hooks
        pass

class LowPriorityHook(Hook):
    """Example hook that runs after NORMAL-priority hooks at each stage."""

    priority = 'LOW'
    
    def after_train_iter(self, runner):
        # This runs after normal priority hooks
        pass

Install with Tessl CLI

npx tessl i tessl/pypi-mmengine

docs

configuration.md

dataset.md

distributed.md

fileio.md

index.md

logging.md

models.md

optimization.md

registry.md

training.md

visualization.md

tile.json