Training

Model training capabilities, including the unified ModelTrainer, hyperparameter tuning, fine-tuning for foundation models, and distributed training support.

Package Information

Important: The sagemaker.train module uses lazy loading for some classes. Import paths:

# Available via lazy loading from sagemaker.train
from sagemaker.train import ModelTrainer, Session, get_execution_role
from sagemaker.train import BenchMarkEvaluator, CustomScorerEvaluator, LLMAsJudgeEvaluator
from sagemaker.train import get_benchmarks, get_builtin_metrics

# Requires full module path (not lazy-loaded)
from sagemaker.train.sft_trainer import SFTTrainer
from sagemaker.train.dpo_trainer import DPOTrainer
from sagemaker.train.rlaif_trainer import RLAIFTrainer
from sagemaker.train.rlvr_trainer import RLVRTrainer
from sagemaker.train.tuner import HyperparameterTuner, WarmStartTypes

Capabilities

ModelTrainer

Main class for training models on AWS SageMaker, replacing the V2 Estimator and framework-specific estimator classes.

class ModelTrainer:
    """
    Train models on SageMaker or locally.

    Parameters:
        training_mode: Mode - LOCAL_CONTAINER or SAGEMAKER_TRAINING_JOB (default: SAGEMAKER_TRAINING_JOB)
        sagemaker_session: Optional[Session] - SageMaker session (default: creates new session)
        role: Optional[str] - IAM role ARN (required for SageMaker training)
        base_job_name: Optional[str] - Base name for training jobs (default: auto-generated)
        source_code: Optional[SourceCode] - Source code configuration for custom training scripts
        distributed: Optional[Union[Torchrun, MPI]] - Distributed training configuration
        compute: Optional[Compute] - Compute resource configuration (required)
        networking: Optional[Networking] - Network configuration for VPC deployments
        stopping_condition: Optional[StoppingCondition] - Stopping criteria (default: 1 day max runtime)
        algorithm_name: Optional[str] - Built-in algorithm name (e.g., "xgboost", "kmeans")
        training_image: Optional[str] - Training container image URI (required unless algorithm_name provided)
        training_image_config: Optional[TrainingImageConfig] - Image configuration for private registries
        output_data_config: Optional[OutputDataConfig] - Output configuration (default: default S3 bucket)
        input_data_config: Optional[List[Union[Channel, InputData]]] - Input data channels
        checkpoint_config: Optional[CheckpointConfig] - Checkpoint configuration for resumable training
        training_input_mode: Optional[str] - Input mode: "File" or "Pipe" (default: "File")
        environment: Optional[Dict[str, str]] - Environment variables passed to container
        hyperparameters: Optional[Union[Dict[str, Any], str]] - Training hyperparameters (values converted to strings)
        tags: Optional[List[Tag]] - Resource tags for cost tracking
        local_container_root: Optional[str] - Local container root directory (for LOCAL_CONTAINER mode)

    Methods:
        train(input_data_config=None, wait=True, logs=True) -> None
            Train the model.
            
            Parameters:
                input_data_config: Optional[List[InputData]] - Input data (overrides constructor value)
                wait: bool - Block until training completes (default: True)
                logs: bool - Show CloudWatch logs (default: True)
            
            Raises:
                ValueError: If required parameters missing or invalid
                ClientError: If AWS API call fails
                RuntimeError: If training job fails
        
        create_input_data_channel(channel_name, data_source, key_prefix=None, ignore_patterns=None) -> InputData
            Create input data channel from local files or S3.
            
            Parameters:
                channel_name: str - Channel name (e.g., "training", "validation")
                data_source: str - Local path or S3 URI
                key_prefix: Optional[str] - S3 key prefix for local uploads
                ignore_patterns: Optional[List[str]] - Patterns to ignore for local uploads
            
            Returns:
                InputData: Configured input data channel
        
        with_tensorboard_output_config(tensorboard_output_config) -> ModelTrainer
            Configure TensorBoard logging.
            
            Parameters:
                tensorboard_output_config: TensorBoardOutputConfig - TensorBoard configuration
            
            Returns:
                ModelTrainer: Self for method chaining
        
        with_retry_strategy(retry_strategy) -> ModelTrainer
            Set retry strategy for training job.
            
            Parameters:
                retry_strategy: RetryStrategy - Retry configuration
            
            Returns:
                ModelTrainer: Self for method chaining
        
        with_checkpoint_config(checkpoint_config) -> ModelTrainer
            Set checkpoint configuration.
            
            Parameters:
                checkpoint_config: CheckpointConfig - Checkpoint configuration
            
            Returns:
                ModelTrainer: Self for method chaining
        
        with_metric_definitions(metric_definitions) -> ModelTrainer
            Set metric definitions for extracting metrics from logs.
            
            Parameters:
                metric_definitions: List[MetricDefinition] - Metric extraction patterns
            
            Returns:
                ModelTrainer: Self for method chaining

    Class Methods:
        from_recipe(training_recipe, compute, ...) -> ModelTrainer
            Create from training recipe configuration.
            
        from_jumpstart_config(jumpstart_config, ...) -> ModelTrainer
            Create from JumpStart configuration.

    Attributes:
        _latest_training_job: TrainingJob - Most recent training job resource
        
    Raises:
        ValueError: Invalid configuration parameters
        ClientError: AWS API errors (permission, quota, validation)
        RuntimeError: Training execution errors
    """

Usage:

from sagemaker.train import ModelTrainer, Session
from sagemaker.train.configs import InputData, Compute, SourceCode
from botocore.exceptions import ClientError

# Create trainer with basic configuration
trainer = ModelTrainer(
    training_image="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0-gpu-py310",
    role="arn:aws:iam::123456789012:role/SageMakerRole",
    compute=Compute(
        instance_type="ml.p3.2xlarge",
        instance_count=1,
        volume_size_in_gb=50
    ),
    source_code=SourceCode(
        source_dir="./code",
        entry_script="train.py",
        requirements="requirements.txt"
    ),
    hyperparameters={
        "epochs": 10,
        "learning_rate": 0.001,
        "batch_size": 32
    }
)

# Configure training data
train_data = InputData(
    channel_name="training",
    data_source="s3://my-bucket/train",
    content_type="application/json"
)

val_data = InputData(
    channel_name="validation",
    data_source="s3://my-bucket/val",
    content_type="application/json"
)

# Train model with error handling
try:
    trainer.train(
        input_data_config=[train_data, val_data],
        wait=True,
        logs=True
    )
    
    # Access training job details
    job = trainer._latest_training_job
    print(f"Training job: {job.training_job_name}")
    print(f"Model artifacts: {job.model_artifacts}")
    
except ClientError as e:
    error_code = e.response['Error']['Code']
    if error_code == 'ResourceLimitExceeded':
        print("Exceeded instance limit. Request quota increase.")
    elif error_code == 'ValidationException':
        print(f"Invalid configuration: {e}")
    else:
        raise
except RuntimeError as e:
    print(f"Training failed: {e}")
    # Check CloudWatch logs for details
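
The create_input_data_channel helper documented above can also build channels from local directories or S3 prefixes before calling train(). A minimal sketch based on that signature; the paths, bucket, and prefix are illustrative:

# Upload a local dataset and register it as the "training" channel
train_channel = trainer.create_input_data_channel(
    channel_name="training",
    data_source="./data/train",                 # local path, uploaded to S3
    key_prefix="datasets/my-project",           # S3 prefix for the upload
    ignore_patterns=["*.tmp", "__pycache__"]
)

# Reference an existing S3 prefix as the "validation" channel
val_channel = trainer.create_input_data_channel(
    channel_name="validation",
    data_source="s3://my-bucket/val"
)

trainer.train(input_data_config=[train_channel, val_channel])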

Advanced Configuration:

from sagemaker.train.configs import (
    StoppingCondition, CheckpointConfig, 
    Networking, RetryStrategy, MetricDefinition
)

# Configure with all options
trainer = ModelTrainer(
    training_image="my-training-image",
    role=role_arn,
    compute=Compute(
        instance_type="ml.p3.8xlarge",
        instance_count=4,
        volume_size_in_gb=100,
        enable_managed_spot_training=True,  # Cost savings
        keep_alive_period_in_seconds=300  # Warm pool for repeated training
    ),
    networking=Networking(
        subnets=["subnet-xxx", "subnet-yyy"],
        security_group_ids=["sg-xxx"],
        enable_inter_container_traffic_encryption=True
    ),
    stopping_condition=StoppingCondition(
        max_runtime_in_seconds=86400,  # 24 hours
        max_wait_time_in_seconds=172800  # 48 hours for spot
    ),
    checkpoint_config=CheckpointConfig(
        s3_uri="s3://my-bucket/checkpoints",
        local_path="/opt/ml/checkpoints"
    ),
    environment={
        "MY_ENV_VAR": "value",
        "LOG_LEVEL": "INFO"
    },
    tags=[
        {"Key": "Project", "Value": "MyProject"},
        {"Key": "Environment", "Value": "Production"}
    ]
)

# Add metric definitions for custom metrics
trainer.with_metric_definitions([
    MetricDefinition(
        name="train:loss",
        regex="Train Loss: ([0-9\\.]+)"
    ),
    MetricDefinition(
        name="val:accuracy",
        regex="Validation Accuracy: ([0-9\\.]+)"
    )
])

# Add retry strategy for resilience
trainer.with_retry_strategy(
    RetryStrategy(maximum_retry_attempts=2)
)

# Train with full configuration
trainer.train(input_data_config=[train_data, val_data])

Hyperparameter Tuning

Manages hyperparameter tuning jobs for model optimization with support for Bayesian, Random, Grid, and Hyperband strategies.

class HyperparameterTuner:
    """
    Hyperparameter tuning for model optimization.

    Parameters:
        model_trainer: ModelTrainer - The trainer instance to tune (required)
        objective_metric_name: Union[str, PipelineVariable] - Metric to optimize (required)
        hyperparameter_ranges: Dict[str, ParameterRange] - Parameter ranges to search (required unless autotune=True)
        metric_definitions: Optional[List[Dict]] - Metric definitions with 'Name' and 'Regex' keys
        strategy: Union[str, PipelineVariable] - Search strategy: "Bayesian", "Random", "Hyperband", "Grid" (default: "Bayesian")
        objective_type: Union[str, PipelineVariable] - "Maximize" or "Minimize" (default: "Maximize")
        max_jobs: Union[int, PipelineVariable] - Maximum number of training jobs (required)
        max_parallel_jobs: Union[int, PipelineVariable] - Maximum parallel jobs (default: 1)
        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] - Maximum runtime for entire tuning job (default: unlimited)
        tags: Optional[Tags] - Resource tags
        base_tuning_job_name: Optional[str] - Base tuning job name (default: auto-generated from timestamp)
        warm_start_config: Optional[HyperParameterTuningJobWarmStartConfig] - Warm start config for transfer learning
        strategy_config: Optional[HyperParameterTuningJobStrategyConfig] - Strategy-specific configuration (required for Hyperband)
        completion_criteria_config: Optional[TuningJobCompletionCriteria] - Auto-stop completion criteria
        early_stopping_type: Union[str, PipelineVariable] - Early stopping: "Off", "Auto" (default: "Off")
        model_trainer_name: Optional[str] - Name for multi-trainer configurations
        random_seed: Optional[int] - Random seed for reproducible hyperparameter search (range: 0-2147483647)
        autotune: bool - Enable automatic parameter range selection (default: False)
        hyperparameters_to_keep_static: Optional[List[str]] - Hyperparameters to exclude from auto-tuning when autotune=True

    Methods:
        tune(inputs=None, job_name=None, model_trainer_kwargs=None, wait=True, **kwargs) -> None
            Start hyperparameter tuning job.
            
            Parameters:
                inputs: Optional - Input data for training (overrides model_trainer configuration)
                job_name: Optional[str] - Tuning job name (default: auto-generated)
                model_trainer_kwargs: Optional[Dict] - Additional ModelTrainer kwargs
                wait: bool - Block until tuning completes (default: True)
                **kwargs: Additional tuning job parameters
            
            Raises:
                ValueError: Invalid configuration (e.g., Hyperband with early_stopping_type != "Off")
                ClientError: AWS API errors
        
        stop_tuning_job() -> None
            Stop the latest tuning job immediately.
            
            Raises:
                ClientError: If job cannot be stopped
        
        describe() -> dict
            Get full tuning job details from DescribeHyperParameterTuningJob API.
            
            Returns:
                dict: Complete job description including all training jobs
        
        wait() -> None
            Wait for tuning job to complete.
            
            Raises:
                WaiterError: If job fails or times out
        
        best_training_job() -> str
            Get name of best training job from latest tuning job.
            
            Returns:
                str: Training job name with best objective metric value
            
            Raises:
                RuntimeError: If tuning job not yet started or failed
        
        analytics() -> HyperparameterTuningJobAnalytics
            Get analytics instance with pandas DataFrame of results.
            
            Returns:
                HyperparameterTuningJobAnalytics: Analytics object with dataframe() method
        
        transfer_learning_tuner(additional_parents=None, model_trainer=None) -> HyperparameterTuner
            Create transfer learning tuner from this tuner.
            
            Parameters:
                additional_parents: Optional[Set[str]] - Additional parent tuning job names
                model_trainer: Optional[ModelTrainer] - New trainer (default: reuse current)
            
            Returns:
                HyperparameterTuner: New tuner configured for transfer learning
        
        override_resource_config(instance_configs) -> None
            Override instance configuration for model trainers.
            
            Parameters:
                instance_configs: List[HyperParameterTuningInstanceConfig] - Instance configs

    Class Methods:
        create(model_trainer_dict, objective_metric_name_dict, hyperparameter_ranges_dict, ...) -> HyperparameterTuner
            Create multi-model-trainer tuner.
            
            Parameters:
                model_trainer_dict: Dict[str, ModelTrainer] - Named trainers
                objective_metric_name_dict: Dict[str, str] - Objective metrics per trainer
                hyperparameter_ranges_dict: Dict[str, Dict[str, ParameterRange]] - Ranges per trainer
            
            Returns:
                HyperparameterTuner: Configured multi-trainer tuner

    Attributes:
        latest_tuning_job: HyperParameterTuningJob - Most recent tuning job resource
        
    Raises:
        ValueError: Invalid strategy, parameters, or configuration
        ClientError: AWS API errors
        RuntimeError: Tuning execution errors
    
    Notes:
        - Bayesian strategy: Most efficient for continuous parameters, learns from previous jobs
        - Random strategy: Good baseline, works well for all parameter types
        - Grid strategy: Exhaustive search, best for small search spaces with discrete parameters
        - Hyperband strategy: Multi-fidelity based, dynamically allocates resources, requires strategy_config
        - Early stopping "Auto" may stop underperforming jobs but is not guaranteed
        - For Hyperband, early_stopping_type must be "Off" (has internal early stopping)
        - Warm start types: IdenticalDataAndAlgorithm (same data/image) or TransferLearning (different data/config)
    """

Tuning Strategies:

  • Bayesian (default): Uses Bayesian optimization for intelligent hyperparameter search
  • Random: Random search optimization strategy
  • Hyperband: Multi-fidelity based tuning with dynamic resource allocation (requires strategy_config)
  • Grid: Grid search across all hyperparameter combinations

Early Stopping:

  • Off (default): No early stopping attempted
  • Auto: SageMaker may stop underperforming training jobs (not guaranteed)
  • Note: For Hyperband strategy, early stopping must be "Off" (Hyperband has internal early stopping)

Usage:

from sagemaker.train.tuner import HyperparameterTuner, WarmStartTypes
from sagemaker.core.parameter import ContinuousParameter, IntegerParameter, CategoricalParameter
from botocore.exceptions import ClientError

# Define parameter ranges with scaling
hyperparameter_ranges = {
    "learning_rate": ContinuousParameter(
        min_value=0.001,
        max_value=0.1,
        scaling_type="Logarithmic"  # Better for learning rates
    ),
    "batch_size": IntegerParameter(
        min_value=32,
        max_value=256,
        scaling_type="Linear"
    ),
    "optimizer": CategoricalParameter(
        values=["adam", "sgd", "rmsprop"]
    ),
    "dropout": ContinuousParameter(
        min_value=0.0,
        max_value=0.5,
        scaling_type="Linear"
    )
}

# Create tuner with Bayesian strategy
tuner = HyperparameterTuner(
    model_trainer=trainer,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges=hyperparameter_ranges,
    metric_definitions=[
        {"Name": "validation:accuracy", "Regex": "Val Acc: ([0-9\\.]+)"}
    ],
    strategy="Bayesian",
    objective_type="Maximize",
    max_jobs=20,
    max_parallel_jobs=3,
    early_stopping_type="Auto"
)

# Start tuning with error handling
try:
    tuner.tune(wait=True)
    
    # Get best training job
    best_job = tuner.best_training_job()
    print(f"Best training job: {best_job}")
    
    # Access analytics
    analytics = tuner.analytics()
    df = analytics.dataframe()  # Pandas DataFrame with all results
    print(df.sort_values('FinalObjectiveValue', ascending=False).head())
    
except ValueError as e:
    print(f"Invalid configuration: {e}")
except ClientError as e:
    print(f"AWS API error: {e}")

Hyperband Strategy:

from sagemaker.core.shapes import HyperbandStrategyConfig, HyperParameterTuningJobStrategyConfig

# Configure Hyperband strategy
hyperband_config = HyperbandStrategyConfig(
    min_resource=1,  # Minimum epochs/resources for training jobs
    max_resource=27,  # Maximum epochs/resources
    reduction_factor=3,  # Factor for reducing configurations (typically 2-4)
    number_of_brackets=4  # Number of bracket levels (more brackets = more exploration)
)

strategy_config = HyperParameterTuningJobStrategyConfig(
    hyperband_strategy_config=hyperband_config
)

# Create tuner with Hyperband
tuner = HyperparameterTuner(
    model_trainer=trainer,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges=hyperparameter_ranges,
    strategy="Hyperband",
    strategy_config=strategy_config,
    early_stopping_type="Off",  # Required for Hyperband
    max_jobs=100,
    max_parallel_jobs=5
)

tuner.tune()

Warm Start (Transfer Learning):

from sagemaker.core.shapes import (
    HyperParameterTuningJobWarmStartConfig,
    ParentHyperParameterTuningJob
)
from sagemaker.train.tuner import WarmStartTypes

# Configure warm start from parent tuning job
warm_start_config = HyperParameterTuningJobWarmStartConfig(
    parent_hyper_parameter_tuning_jobs=[
        ParentHyperParameterTuningJob(
            hyper_parameter_tuning_job_name="my-previous-tuning-job"
        )
    ],
    warm_start_type=WarmStartTypes.TRANSFER_LEARNING.value
)

# Create tuner with warm start
tuner = HyperparameterTuner(
    model_trainer=trainer,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges=hyperparameter_ranges,
    warm_start_config=warm_start_config,
    max_jobs=20,
    max_parallel_jobs=3
)

# Or use transfer_learning_tuner helper
new_tuner = tuner.transfer_learning_tuner(
    additional_parents={"parent-job-1", "parent-job-2"}
)
new_tuner.tune()

Warm Start Types Enum:

class WarmStartTypes(Enum):
    """
    Enum for hyperparameter tuning warm start types.

    Values:
        IDENTICAL_DATA_AND_ALGORITHM: "IdenticalDataAndAlgorithm"
            - Same input data and training image required
            - Can change: hyperparameter ranges, max training jobs
            - Cannot change: algorithm version, input data
            
        TRANSFER_LEARNING: "TransferLearning"
            - Different data, ranges, or training configuration allowed
            - Can change: input data, hyperparameter ranges, training image version
            - Use when continuing tuning with different dataset or updated model
    """
    IDENTICAL_DATA_AND_ALGORITHM = "IdenticalDataAndAlgorithm"
    TRANSFER_LEARNING = "TransferLearning"

Completion Criteria:

from sagemaker.core.shapes import (
    TuningJobCompletionCriteria,
    BestObjectiveNotImproving,
    ConvergenceDetected
)

# Auto-stop when target metric reached or convergence detected
completion_criteria = TuningJobCompletionCriteria(
    target_objective_metric_value=0.95,  # Stop when metric reaches this value
    best_objective_not_improving=BestObjectiveNotImproving(
        max_number_of_training_jobs_not_improving=10  # Stop if no 1%+ improvement
    ),
    convergence_detected=ConvergenceDetected(
        complete_on_convergence="Enabled"  # Stop if AMT detects convergence
    )
)

tuner = HyperparameterTuner(
    model_trainer=trainer,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges=hyperparameter_ranges,
    completion_criteria_config=completion_criteria,
    max_jobs=100  # May complete with fewer jobs due to criteria
)

Multi-Model-Trainer Tuning:

# Create tuner for multiple trainers
tuner = HyperparameterTuner.create(
    model_trainer_dict={
        "trainer1": trainer1,
        "trainer2": trainer2
    },
    objective_metric_name_dict={
        "trainer1": "val:accuracy",
        "trainer2": "val:f1"
    },
    hyperparameter_ranges_dict={
        "trainer1": {
            "learning_rate": ContinuousParameter(0.001, 0.1)
        },
        "trainer2": {
            "learning_rate": ContinuousParameter(0.0001, 0.01)
        }
    },
    strategy="Bayesian",
    max_jobs=20,
    max_parallel_jobs=4
)

AutoTune Feature:

# Enable automatic parameter range selection
tuner = HyperparameterTuner(
    model_trainer=trainer,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges={},  # Empty when autotune=True
    autotune=True,
    hyperparameters_to_keep_static=["epochs", "batch_size"],  # Don't tune these
    max_jobs=50,
    max_parallel_jobs=5
)

# SageMaker will automatically determine optimal ranges for tunable hyperparameters
tuner.tune()

HyperparameterTuningJobAnalytics

Analytics class for accessing tuning job results as a pandas DataFrame.

class HyperparameterTuningJobAnalytics:
    """
    Analytics for hyperparameter tuning jobs.

    Parameters:
        tuning_job_name: str - Name of completed tuning job
        sagemaker_session: Optional[Session] - SageMaker session

    Attributes:
        name: str - Name of tuning job
        tuning_ranges: Dict[str, ParameterRange] - Dictionary of parameter ranges used

    Methods:
        description() -> dict
            Get full job description from DescribeHyperParameterTuningJob API.
            
            Returns:
                dict: Complete tuning job metadata and configuration
        
        training_job_summaries() -> List[Dict]
            Get all training job summaries from tuning job.
            
            Returns:
                List[Dict]: List of training job summaries with parameters and metrics
        
        dataframe() -> pandas.DataFrame
            Pandas DataFrame with all tuning results for analysis.
            
            Returns:
                pandas.DataFrame: Results with columns for parameters, metrics, and metadata

    DataFrame Columns:
        - Tuned hyperparameter values (one column per hyperparameter)
        - TrainingJobName: Name of training job
        - TrainingJobStatus: Status (Completed, Failed, Stopped)
        - FinalObjectiveValue: Final metric value
        - TrainingStartTime: Training start timestamp
        - TrainingEndTime: Training end timestamp
        - TrainingElapsedTimeSeconds: Training duration in seconds
        - TrainingJobDefinitionName: Training job definition name
    
    Usage:
        Analyze tuning results, identify best parameters, visualize search progress
    """

Usage:

from sagemaker.train.tuner import HyperparameterTuningJobAnalytics
import pandas as pd
import matplotlib.pyplot as plt

# Create analytics instance
analytics = HyperparameterTuningJobAnalytics(
    tuning_job_name="my-tuning-job-2024-01-01-12-00-00"
)

# Get results as DataFrame
df = analytics.dataframe()

# Analyze results
print("Best 5 configurations:")
print(df.sort_values('FinalObjectiveValue', ascending=False).head())

# Analyze parameter impact
print("\nLearning rate vs Accuracy:")
print(df.groupby('learning_rate')['FinalObjectiveValue'].mean())

# Visualize search progress
plt.figure(figsize=(10, 6))
plt.plot(range(len(df)), df['FinalObjectiveValue'].cummax())
plt.xlabel('Training Job Number')
plt.ylabel('Best Objective Value')
plt.title('Hyperparameter Tuning Progress')
plt.show()

# Export best configuration
best_job = df.loc[df['FinalObjectiveValue'].idxmax()]
best_params = {
    col: best_job[col] 
    for col in df.columns 
    if col in analytics.tuning_ranges.keys()
}
print(f"\nBest hyperparameters: {best_params}")

Fine-Tuning for Foundation Models

Fine-tuning trainers for foundation models with support for Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and Reinforcement Learning techniques.

SFTTrainer

Supervised Fine-Tuning trainer for foundation models using instruction-following or task-specific datasets.

class SFTTrainer:
    """
    Supervised Fine-Tuning trainer for foundation models.

    Parameters:
        model: Union[str, ModelPackage] - Foundation model to fine-tune
            - Model name/ID (e.g., "meta-llama/Llama-2-7b-hf")
            - Model ARN (e.g., "arn:aws:sagemaker:...")
            - ModelPackage object
            (required)
        
        training_type: Union[TrainingType, str] - Fine-tuning approach (default: TrainingType.LORA)
            - TrainingType.LORA: Parameter-efficient, faster, lower cost
            - TrainingType.FULL: All parameters updated, higher quality for complex tasks
        
        model_package_group: Optional[Union[str, ModelPackageGroup]] - Model package group for output
            - Group name (e.g., "my-fine-tuned-models")
            - Group ARN
            - ModelPackageGroup object
        
        mlflow_resource_arn: Optional[str] - MLflow tracking server ARN for experiment tracking
        mlflow_experiment_name: Optional[str] - MLflow experiment name
        mlflow_run_name: Optional[str] - MLflow run name
        
        training_dataset: Optional[Union[str, DataSet]] - Training dataset (required at train() time if not provided)
            - S3 URI (e.g., "s3://bucket/train.jsonl")
            - Dataset ARN
            - DataSet object from AI Registry
        
        validation_dataset: Optional[Union[str, DataSet]] - Validation dataset
            - Same formats as training_dataset
        
        s3_output_path: Optional[str] - S3 path for outputs
            - Defaults to: s3://sagemaker-<region>-<account>/output
        
        kms_key_id: Optional[str] - KMS key ID for output encryption
        
        networking: Optional[VpcConfig] - VPC configuration for secure training
        
        accept_eula: Optional[bool] - Accept EULA for gated models (default: False)
            - Must be set to True for models with usage agreements

    Methods:
        train(training_dataset=None, validation_dataset=None, wait=True) -> TrainingJob
            Execute supervised fine-tuning training job.
            
            Parameters:
                training_dataset: Optional[Union[str, DataSet]] - Override constructor dataset
                validation_dataset: Optional[Union[str, DataSet]] - Override constructor dataset
                wait: bool - Block until training completes (default: True)
            
            Returns:
                TrainingJob: Completed training job resource
            
            Raises:
                ValueError: If training_dataset not provided either in constructor or train()
                ValueError: If accept_eula=False for gated model
                ClientError: AWS API errors

    Attributes:
        hyperparameters: FineTuningOptions - Dynamic hyperparameters with validation
            - Access via: trainer.hyperparameters.parameter_name
            - Common parameters:
                - epochs: int (default: 3)
                - learning_rate: float (default: 2e-4)
                - per_device_train_batch_size: int (default: 4)
                - per_device_eval_batch_size: int (default: 8)
                - gradient_accumulation_steps: int (default: 1)
                - lora_r: int (default: 8, only for LORA)
                - lora_alpha: int (default: 16, only for LORA)
                - lora_dropout: float (default: 0.05, only for LORA)
        
        _latest_training_job: TrainingJob - Most recent training job
    
    Raises:
        ValueError: Invalid configuration or missing required parameters
        ClientError: AWS API errors (permissions, quotas)
        RuntimeError: Training execution errors
    
    Notes:
        - Training dataset format: JSONL with "prompt" and "completion" fields
        - Validation dataset optional but recommended for monitoring
        - LORA typically 5-10x faster and cheaper than FULL fine-tuning
        - Use FULL training for significant domain shift or when quality critical
        - Gated models (Llama, etc.) require accept_eula=True
        - Output model registered to Model Registry if model_package_group specified
    """

Dataset Format:

Training datasets must be in JSONL format with specific fields:

{"prompt": "Question: What is machine learning?", "completion": "Machine learning is..."}
{"prompt": "Question: Explain neural networks.", "completion": "Neural networks are..."}

Usage:

from sagemaker.train.sft_trainer import SFTTrainer
from sagemaker.train.common import TrainingType
from botocore.exceptions import ClientError

# Create SFT trainer for Llama model
trainer = SFTTrainer(
    model="meta-llama/Llama-2-7b-hf",
    training_type=TrainingType.LORA,
    model_package_group="my-llm-models",
    training_dataset="s3://my-bucket/train.jsonl",
    validation_dataset="s3://my-bucket/val.jsonl",
    s3_output_path="s3://my-bucket/output",
    accept_eula=True  # Required for Llama models
)

# Configure hyperparameters
trainer.hyperparameters.epochs = 3
trainer.hyperparameters.learning_rate = 2e-4
trainer.hyperparameters.per_device_train_batch_size = 4
trainer.hyperparameters.gradient_accumulation_steps = 4  # Effective batch size = 4*4 = 16

# LoRA-specific parameters
trainer.hyperparameters.lora_r = 8  # LoRA rank
trainer.hyperparameters.lora_alpha = 16  # LoRA alpha
trainer.hyperparameters.lora_dropout = 0.05
trainer.hyperparameters.target_modules = ["q_proj", "v_proj"]  # Which layers to adapt

# Train model with error handling
try:
    job = trainer.train(wait=True)
    print(f"Training completed: {job.training_job_name}")
    print(f"Model artifacts: {job.model_artifacts}")
except ValueError as e:
    print(f"Configuration error: {e}")
except ClientError as e:
    print(f"AWS API error: {e}")

Full Fine-Tuning:

# Use FULL training for maximum quality
trainer = SFTTrainer(
    model="meta-llama/Llama-2-7b-hf",
    training_type=TrainingType.FULL,  # Updates all model parameters
    training_dataset="s3://my-bucket/large-train-dataset.jsonl",
    validation_dataset="s3://my-bucket/val.jsonl",
    accept_eula=True
)

# Adjust hyperparameters for full training
trainer.hyperparameters.epochs = 1  # Typically fewer epochs needed
trainer.hyperparameters.learning_rate = 1e-5  # Lower learning rate
trainer.hyperparameters.per_device_train_batch_size = 1  # Smaller batch due to memory
trainer.hyperparameters.gradient_accumulation_steps = 16  # Larger effective batch

job = trainer.train()

DPOTrainer

Direct Preference Optimization trainer for aligning models with human preferences using preference pair datasets.

class DPOTrainer:
    """
    Direct Preference Optimization trainer for foundation models.

    Parameters:
        model: Union[str, ModelPackage] - Foundation model to fine-tune (required)
            - Model name/ID, ARN, or ModelPackage object
        
        training_type: Union[TrainingType, str] - Fine-tuning approach (default: TrainingType.LORA)
            - TrainingType.LORA or TrainingType.FULL
        
        model_package_group: Optional[Union[str, ModelPackageGroup]] - Model package group for output
        
        mlflow_resource_arn: Optional[str] - MLflow tracking server ARN
        mlflow_experiment_name: Optional[str] - MLflow experiment name
        mlflow_run_name: Optional[str] - MLflow run name
        
        training_dataset: Optional[Union[str, DataSet]] - Training dataset with preference pairs
            - Must contain "prompt", "chosen", "rejected" fields
        
        validation_dataset: Optional[Union[str, DataSet]] - Validation dataset
        
        s3_output_path: Optional[str] - S3 path for outputs
        
        kms_key_id: Optional[str] - KMS key ID for encryption
        
        networking: Optional[VpcConfig] - VPC configuration
        
        accept_eula: bool - Accept EULA for gated models (default: False)

    Methods:
        train(training_dataset=None, validation_dataset=None, wait=True) -> TrainingJob
            Execute DPO training.
            
            Returns:
                TrainingJob: Completed training job
            
            Raises:
                ValueError: Invalid configuration or dataset format
                ClientError: AWS API errors

    Attributes:
        hyperparameters: FineTuningOptions - Dynamic hyperparameters
            - epochs, learning_rate, batch_size, etc.
            - beta: float - DPO temperature parameter (default: 0.1)
        
        _latest_training_job: TrainingJob - Most recent training job
    
    Notes:
        - Dataset format: JSONL with "prompt", "chosen", "rejected" fields
        - DPO directly optimizes for preference without reward model
        - More stable and efficient than RLHF
        - Requires high-quality preference data
    """

Dataset Format for DPO:

{"prompt": "Explain quantum computing", "chosen": "Quantum computing uses...", "rejected": "It's magic..."}
{"prompt": "Write a poem", "chosen": "Roses are red...", "rejected": "asdfghjkl"}

Usage:

from sagemaker.train.dpo_trainer import DPOTrainer

# Create DPO trainer with preference pairs dataset
trainer = DPOTrainer(
    model="meta-llama/Llama-2-7b-hf",
    model_package_group="aligned-models",
    training_dataset="s3://my-bucket/preferences.jsonl",
    accept_eula=True
)

# Configure DPO-specific parameters
trainer.hyperparameters.beta = 0.1  # Temperature parameter
trainer.hyperparameters.learning_rate = 1e-5
trainer.hyperparameters.epochs = 1

# Train with preference optimization
try:
    job = trainer.train(wait=True)
    print(f"DPO training completed: {job.training_job_name}")
except ValueError as e:
    print(f"Dataset format error - check 'chosen' and 'rejected' fields: {e}")

RLAIFTrainer

Reinforcement Learning from AI Feedback trainer using LLM-based reward signals.

class RLAIFTrainer:
    """
    Reinforcement Learning from AI Feedback trainer.

    Parameters:
        model: Union[str, ModelPackage] - Foundation model to fine-tune (required)
        
        training_type: Union[TrainingType, str] - Fine-tuning approach (default: TrainingType.LORA)
        
        model_package_group: Optional[Union[str, ModelPackageGroup]] - Model package group
        
        reward_model_id: str - Bedrock model ID for generating AI feedback (required)
            Allowed values:
            - "openai.gpt-oss-120b-1:0" (us-west-2, us-east-1, ap-northeast-1, eu-west-1)
            - "openai.gpt-oss-20b-1:0" (us-west-2, us-east-1, ap-northeast-1, eu-west-1)
            - "qwen.qwen3-32b-v1:0" (us-west-2, us-east-1, ap-northeast-1, eu-west-1)
            - "qwen.qwen3-coder-30b-a3b-v1:0" (us-west-2, us-east-1, ap-northeast-1, eu-west-1)
            - "qwen.qwen3-coder-480b-a35b-v1:0" (us-west-2, ap-northeast-1)
            - "qwen.qwen3-235b-a22b-2507-v1:0" (us-west-2, ap-northeast-1)
        
        reward_prompt: Union[str, Evaluator] - Reward prompt or evaluator (required)
            Built-in metrics:
            - "Builtin.Helpfulness" - Response helpfulness
            - "Builtin.Harmlessness" - Safety and harmlessness
            - "Builtin.Honesty" - Truthfulness and accuracy
            - "Builtin.Conciseness" - Response conciseness
            Or custom Evaluator object
        
        mlflow_resource_arn: Optional[Union[str, MlflowTrackingServer]] - MLflow tracking
        mlflow_experiment_name: Optional[str] - MLflow experiment name
        mlflow_run_name: Optional[str] - MLflow run name
        
        training_dataset: Optional[Union[str, DataSet]] - Training dataset
        validation_dataset: Optional[Union[str, DataSet]] - Validation dataset
        
        s3_output_path: Optional[str] - S3 path for outputs
        kms_key_id: Optional[str] - KMS key ID for encryption
        networking: Optional[VpcConfig] - VPC configuration
        accept_eula: bool - Accept EULA for gated models (default: False)

    Methods:
        train(training_dataset=None, validation_dataset=None, wait=True) -> TrainingJob
            Execute RLAIF training with AI feedback.
            
            Returns:
                TrainingJob: Completed training job
            
            Raises:
                ValueError: Invalid reward_model_id or reward_prompt
                ClientError: AWS API errors

    Attributes:
        reward_model_id: str - Bedrock reward model identifier
        reward_prompt: Union[str, Evaluator] - Reward prompt configuration
        hyperparameters: FineTuningOptions - Dynamic hyperparameters
        _latest_training_job: TrainingJob - Most recent training job
        _evaluator_arn: str - Evaluator ARN (if using custom evaluator)
    
    Notes:
        - Reward model must be available in your AWS region
        - Built-in metrics provide standardized evaluation
        - Custom evaluators allow domain-specific reward signals
        - Training dataset: prompts for model to generate responses
        - Reward model evaluates generated responses during training
    """

Usage:

from sagemaker.train.rlaif_trainer import RLAIFTrainer

# Create RLAIF trainer with built-in metric
trainer = RLAIFTrainer(
    model="meta-llama/Llama-2-7b-hf",
    model_package_group="rl-models",
    reward_model_id="openai.gpt-oss-120b-1:0",
    reward_prompt="Builtin.Harmlessness",  # Built-in safety metric
    training_dataset="s3://my-bucket/prompts.jsonl",
    accept_eula=True
)

# Configure training
trainer.hyperparameters.epochs = 3
trainer.hyperparameters.learning_rate = 1e-5

# Train with AI feedback
try:
    job = trainer.train(wait=True)
    print(f"RLAIF training completed: {job.training_job_name}")
except ValueError as e:
    if "reward_model_id" in str(e):
        print("Invalid reward model ID or not available in your region")
    raise

Custom Reward Evaluator:

from sagemaker.ai_registry.evaluator import Evaluator

# Create custom evaluator
evaluator = Evaluator.create(
    name="code-quality-evaluator",
    type="RewardFunction",
    source="arn:aws:lambda:us-west-2:123:function:evaluate-code-quality",
    wait=True
)

# Use custom evaluator for domain-specific rewards
trainer = RLAIFTrainer(
    model="codellama/CodeLlama-7b-hf",
    reward_model_id="openai.gpt-oss-120b-1:0",
    reward_prompt=evaluator,  # Custom evaluator
    training_dataset="s3://my-bucket/code-prompts.jsonl",
    accept_eula=True
)

job = trainer.train()

RLVRTrainer

Reinforcement Learning from Verifiable Rewards trainer using custom reward functions.

class RLVRTrainer:
    """
    Reinforcement Learning from Verifiable Rewards trainer.

    Parameters:
        model: Union[str, ModelPackage] - Foundation model to fine-tune (required)
        
        training_type: Union[TrainingType, str] - Fine-tuning approach (default: TrainingType.LORA)
        
        model_package_group: Optional[Union[str, ModelPackageGroup]] - Model package group
        
        custom_reward_function: Optional[Union[str, Evaluator]] - Custom reward function evaluator (required)
            - Evaluator ARN
            - Evaluator object from AI Registry
            - Lambda function ARN
            - Must return numeric reward value
        
        mlflow_resource_arn: Optional[Union[str, MlflowTrackingServer]] - MLflow tracking
        mlflow_experiment_name: Optional[str] - MLflow experiment name
        mlflow_run_name: Optional[str] - MLflow run name
        
        training_dataset: Optional[Union[str, DataSet]] - Training dataset
        validation_dataset: Optional[Union[str, DataSet]] - Validation dataset
        
        s3_output_path: Optional[str] - S3 path for outputs
        kms_key_id: Optional[str] - KMS key ID for encryption
        networking: Optional[VpcConfig] - VPC configuration
        accept_eula: bool - Accept EULA for gated models (default: False)

    Methods:
        train(training_dataset=None, validation_dataset=None, wait=True) -> TrainingJob
            Execute RLVR training with custom rewards.
            
            Returns:
                TrainingJob: Completed training job
            
            Raises:
                ValueError: If custom_reward_function not provided
                ClientError: AWS API errors

    Attributes:
        custom_reward_function: Union[str, Evaluator] - Custom reward evaluator
        hyperparameters: FineTuningOptions - Dynamic hyperparameters
        _latest_training_job: TrainingJob - Most recent training job
    
    Notes:
        - Reward function must return numeric score for model outputs
        - Use for objective, verifiable metrics (code correctness, math accuracy, etc.)
        - Reward function invoked during training to evaluate generated responses
        - Training dataset contains prompts for model to generate responses
    """

Usage:

from sagemaker.train.rlvr_trainer import RLVRTrainer
from sagemaker.ai_registry.evaluator import Evaluator

# Create custom evaluator for code correctness
evaluator = Evaluator.create(
    name="code-correctness-evaluator",
    type="RewardFunction",
    source="./reward_function.py",  # Local file uploaded to Lambda
    wait=True
)

# Create RLVR trainer with custom rewards
trainer = RLVRTrainer(
    model="codellama/CodeLlama-7b-hf",
    model_package_group="rl-code-models",
    custom_reward_function=evaluator,
    training_dataset="s3://my-bucket/coding-tasks.jsonl",
    accept_eula=True
)

# Configure training
trainer.hyperparameters.epochs = 5
trainer.hyperparameters.learning_rate = 1e-5

# Train with verifiable rewards
try:
    job = trainer.train(wait=True)
    print(f"RLVR training completed: {job.training_job_name}")
except ValueError as e:
    print(f"Reward function error: {e}")

Example Reward Function (reward_function.py):

import json

def lambda_handler(event, context):
    """
    Reward function for code correctness.
    
    Args:
        event: Dict with 'prompt' and 'response' fields
    
    Returns:
        Dict with 'reward' field (float between 0 and 1)
    """
    prompt = event['prompt']
    response = event['response']
    
    # Extract test cases from prompt (extract_test_cases and execute_code
    # are user-supplied helpers, not provided by the SDK)
    test_cases = extract_test_cases(prompt)
    
    # Execute generated code with test cases
    try:
        passed = 0
        for test_input, expected_output in test_cases:
            actual_output = execute_code(response, test_input)
            if actual_output == expected_output:
                passed += 1
        
        reward = passed / len(test_cases)
        
    except Exception:
        # Syntax errors or runtime failures get 0 reward
        reward = 0.0
    
    return {
        'reward': reward
    }

FineTuningOptions

Dynamic hyperparameters class with validation for fine-tuning trainers.

class FineTuningOptions:
    """
    Dynamic fine-tuning hyperparameters with validation.

    Parameters:
        options_dict: Dict[str, Any] - Dictionary of options with specifications
            Structure: {param_name: {type, default, min, max, enum}}

    Methods:
        to_dict() -> Dict[str, Any]
            Convert to dictionary with string values for API calls.
            
            Returns:
                Dict[str, Any]: Hyperparameters as string-valued dictionary
        
        get_info(param_name=None) -> None
            Display parameter information and valid ranges.
            
            Parameters:
                param_name: Optional[str] - Specific parameter (shows all if None)

    Attributes:
        _specs: Dict - Parameter specifications (type, default, min, max, enum)
        _initialized: bool - Initialization flag for attribute access control

    Notes:
        - Hyperparameters become dynamic attributes based on model recipe
        - Validates types (float, integer, string) when setting values
        - Validates ranges (min/max) when setting numeric values
        - Validates enums when setting categorical values
        - Access via: trainer.hyperparameters.parameter_name
        - Setting invalid value raises ValueError with helpful message
    
    Example:
        trainer.hyperparameters.epochs = 5  # Valid
        trainer.hyperparameters.learning_rate = 0.001  # Valid
        trainer.hyperparameters.epochs = -1  # Raises ValueError
    """

Common Fine-Tuning Hyperparameters:

# View available hyperparameters
trainer.hyperparameters.get_info()

# Training configuration
trainer.hyperparameters.epochs = 3  # Number of training epochs
trainer.hyperparameters.learning_rate = 2e-4  # Learning rate
trainer.hyperparameters.warmup_ratio = 0.1  # Warmup ratio
trainer.hyperparameters.weight_decay = 0.01  # Weight decay for regularization

# Batch size configuration
trainer.hyperparameters.per_device_train_batch_size = 4  # Batch size per device
trainer.hyperparameters.per_device_eval_batch_size = 8  # Eval batch size
trainer.hyperparameters.gradient_accumulation_steps = 4  # Effective batch = 4*4

# LoRA-specific (only for TrainingType.LORA)
trainer.hyperparameters.lora_r = 8  # LoRA rank (higher = more capacity)
trainer.hyperparameters.lora_alpha = 16  # LoRA scaling (typically 2*lora_r)
trainer.hyperparameters.lora_dropout = 0.05  # LoRA dropout
trainer.hyperparameters.target_modules = ["q_proj", "v_proj"]  # Layers to adapt

# Optimization
trainer.hyperparameters.optimizer = "adamw"  # Optimizer type
trainer.hyperparameters.max_grad_norm = 1.0  # Gradient clipping
trainer.hyperparameters.lr_scheduler_type = "cosine"  # LR schedule

# Logging and saving
trainer.hyperparameters.logging_steps = 10  # Log every N steps
trainer.hyperparameters.save_steps = 500  # Save checkpoint every N steps
trainer.hyperparameters.eval_steps = 100  # Evaluate every N steps

# Check parameter constraints
try:
    trainer.hyperparameters.epochs = 100  # May exceed max
except ValueError as e:
    print(f"Invalid value: {e}")
    trainer.hyperparameters.get_info("epochs")  # Show valid range

Distributed Training

Configuration classes for distributed training with torchrun, MPI, and SageMaker Model Parallelism.

Torchrun

class Torchrun:
    """
    Configuration for torchrun-based distributed training.

    Parameters:
        process_count_per_node: Optional[int] - Number of processes per node
            - Typically set to number of GPUs per instance
            - Default: auto-detect from instance type
        
        smp: Optional[SMP] - SageMaker Model Parallelism v2 configuration
            - Required for model parallelism or tensor parallelism

    Attributes:
        driver_dir: str - Driver directory path (default: "/opt/ml/code")
        driver_script: str - Driver script name

    Notes:
        - Use for PyTorch distributed data parallel training
        - Automatically configures NCCL for GPU communication
        - Works with PyTorch DistributedDataParallel (DDP)
    """

MPI

class MPI:
    """
    Configuration for MPI-based distributed training.

    Parameters:
        process_count_per_node: Optional[int] - Number of processes per node
            - Typically equals number of GPUs per instance
        
        mpi_additional_options: Optional[List[str]] - Additional MPI options
            - Example: ["-x", "NCCL_DEBUG=INFO"]

    Attributes:
        driver_dir: str - Driver directory path
        driver_script: str - Driver script name

    Notes:
        - Use for MPI-based frameworks (Horovod, etc.)
        - Provides more control over process placement
        - Useful for hybrid CPU/GPU workloads
    """

SMP (SageMaker Model Parallelism v2)

class SMP:
    """
    SageMaker Model Parallelism v2 configuration.

    Parameters:
        hybrid_shard_degree: Optional[int] - Hybrid sharding degree
            - Number of model shards
            - Must evenly divide total GPU count
            - Default: 1 (no sharding)
        
        sm_activation_offloading: Optional[bool] - Enable activation offloading
            - Offload activations to CPU to save GPU memory
            - Trades compute for memory
            - Default: False
        
        activation_loading_horizon: Optional[int] - Activation loading horizon
            - Prefetch window for activation loading
            - Default: 4
        
        fsdp_cache_flush_warnings: Optional[bool] - FSDP cache flush warnings
            - Enable warnings for FSDP cache issues
            - Default: False
        
        allow_empty_shards: Optional[bool] - Allow empty shards
            - Permit shards with no parameters
            - Default: False
        
        tensor_parallel_degree: Optional[int] - Tensor parallelism degree
            - Number of ways to split tensors
            - Must evenly divide GPU count
            - Default: 1 (no tensor parallelism)
        
        context_parallel_degree: Optional[int] - Context parallelism degree
            - Split sequence dimension for long context
            - Default: 1
        
        expert_parallel_degree: Optional[int] - Expert parallelism degree
            - For Mixture of Experts (MoE) models
            - Default: 1
        
        random_seed: Optional[int] - Random seed for reproducibility
            - Range: 0-2147483647
            - Default: 12345

    Notes:
        - Use for training large models that don't fit in single GPU memory
        - hybrid_shard_degree: splits model weights across GPUs
        - tensor_parallel_degree: splits individual tensors across GPUs
        - Combine parallelism types for very large models
        - Total parallelism: hybrid_shard × tensor_parallel × context_parallel
        - Must evenly divide total GPU count
    """

Usage:

from sagemaker.train import ModelTrainer
from sagemaker.train.distributed import Torchrun, SMP
from sagemaker.train.configs import Compute

# Configure distributed training with model parallelism
distributed_config = Torchrun(
    process_count_per_node=8,  # 8 GPUs per node
    smp=SMP(
        hybrid_shard_degree=4,  # Shard model across 4 GPUs
        tensor_parallel_degree=2,  # Additional tensor parallelism
        sm_activation_offloading=True,  # Save memory
        random_seed=42
    )
)

# Create trainer with distributed config
trainer = ModelTrainer(
    training_image="pytorch-training-image",
    distributed=distributed_config,
    compute=Compute(
        instance_type="ml.p4d.24xlarge",  # 8 A100 GPUs
        instance_count=2,  # 16 total GPUs
        volume_size_in_gb=500
    ),
    role=role
)

# Distributed configuration automatically applied
trainer.train(input_data_config=[train_data])

Data Parallel Only:

# Simple data parallel (no model parallelism)
distributed_config = Torchrun(
    process_count_per_node=8  # 8-way data parallelism
)

trainer = ModelTrainer(
    training_image="pytorch-image",
    distributed=distributed_config,
    compute=Compute(
        instance_type="ml.p3.16xlarge",
        instance_count=4  # 32 total GPUs for data parallelism
    )
)

MPI-Based Training:

from sagemaker.train.distributed import MPI

# Configure MPI for Horovod
mpi_config = MPI(
    process_count_per_node=8,
    mpi_additional_options=[
        "-x", "NCCL_DEBUG=INFO",
        "-x", "HOROVOD_TIMELINE=timeline.json"
    ]
)

trainer = ModelTrainer(
    training_image="horovod-training-image",
    distributed=mpi_config,
    compute=Compute(
        instance_type="ml.p3.16xlarge",
        instance_count=2
    )
)

Configuration Classes

SourceCode

class SourceCode:
    """
    Source code configuration for training.

    Fields:
        source_dir: Optional[str] - Directory containing source code
            - Local directory path
            - Contents uploaded to S3 automatically
            - Extracted to /opt/ml/code in container
        
        entry_script: Optional[str] - Entry point script
            - Relative to source_dir
            - Example: "train.py"
            - Script receives hyperparameters as arguments
        
        command: Optional[str] - Command to execute
            - Overrides entry_script if provided
            - Example: "python train.py --custom-args"
        
        requirements: Optional[str] - Requirements file path
            - Relative to source_dir
            - Example: "requirements.txt"
            - Installed via pip before training
        
        ignore_patterns: Optional[List[str]] - Patterns to ignore during upload
            - Glob patterns
            - Example: ["*.pyc", "__pycache__", ".git"]

    Notes:
        - Either entry_script or command must be provided
        - requirements.txt automatically installed if present in source_dir
        - Large directories may slow uploads - use ignore_patterns
        - Source code uploaded to S3 and cached for subsequent runs
    """

Compute

class Compute:
    """
    Compute resource configuration.

    Fields:
        instance_type: str - EC2 instance type (required)
            - Training: ml.m5.xlarge, ml.p3.2xlarge, ml.p4d.24xlarge, etc.
            - Check SageMaker documentation for full list
        
        instance_count: int - Number of instances (default: 1)
            - Range: 1-100
            - Use >1 for distributed training
        
        volume_size_in_gb: int - EBS volume size (default: 30)
            - Range: 1-16384 GB
            - Must be large enough for training data + model
        
        volume_kms_key_id: Optional[str] - KMS key for volume encryption
            - Format: "arn:aws:kms:region:account:key/key-id"
        
        enable_managed_spot_training: Optional[bool] - Use managed spot training (default: False)
            - Can reduce costs by up to 90%
            - Training may be interrupted
            - Requires checkpoint_config for resumption
        
        keep_alive_period_in_seconds: Optional[int] - Keep alive period for warm pools (default: 0)
            - Range: 0-3600 seconds
            - Reduces startup time for repeated training
            - Instances billed while warm

    Notes:
        - Spot training requires max_wait_time_in_seconds > max_runtime_in_seconds
        - Warm pools incur costs during keep-alive period
        - Choose instance type based on:
            - CPU-bound: ml.m5, ml.c5
            - GPU training: ml.p3, ml.p4d
            - Memory-intensive: ml.r5
        - Multi-instance requires distributed training configuration
    """

StoppingCondition

class StoppingCondition:
    """
    Training stopping criteria.

    Fields:
        max_runtime_in_seconds: Optional[int] - Maximum training runtime (default: 86400)
            - Range: 1-2419200 (28 days)
            - Training stopped if exceeded
            - Billable time capped at this value
        
        max_pending_time_in_seconds: Optional[int] - Maximum pending time (default: unlimited)
            - Time waiting for resources
            - Training fails if exceeded
        
        max_wait_time_in_seconds: Optional[int] - Maximum wait time for spot (default: unlimited)
            - Only for managed spot training
            - Must be >= max_runtime_in_seconds
            - Includes time waiting for spot capacity + training time

    Notes:
        - Always set reasonable max_runtime to control costs
        - For spot: max_wait_time >= max_runtime
        - Exceeded limits result in training job failure
        - Use checkpoints to resume interrupted spot training
    """

Networking

class Networking:
    """
    Network configuration for training.

    Fields:
        subnets: Optional[List[str]] - VPC subnet IDs
            - Format: ["subnet-xxx", "subnet-yyy"]
            - Use multiple subnets across AZs for availability
        
        security_group_ids: Optional[List[str]] - Security group IDs
            - Format: ["sg-xxx"]
            - Must allow inter-instance communication for distributed training
        
        enable_network_isolation: Optional[bool] - Enable network isolation (default: False)
            - Blocks all network access except S3/ECR
            - No internet access
            - Use for compliance requirements
        
        enable_inter_container_traffic_encryption: Optional[bool] - Enable encryption (default: False)
            - Encrypts traffic between training instances
            - Required for distributed training with network isolation
            - Performance impact for high-bandwidth workloads

    Notes:
        - VPC configuration required for private data sources
        - Security groups must allow:
            - Ingress: all traffic from same security group (distributed training)
            - Egress: HTTPS to S3/ECR
        - Network isolation incompatible with custom VPC endpoints
        - Encryption adds latency to inter-instance communication
    """

InputData

class InputData:
    """
    Simplified input data configuration.

    Fields:
        channel_name: str - Channel name (required)
            - Examples: "training", "validation", "test"
            - Accessible in container at /opt/ml/input/data/{channel_name}
        
        data_source: Union[str, DataSource] - S3 URI or DataSource object (required)
            - S3 URI: "s3://bucket/prefix"
            - DataSource object for advanced configuration
        
        content_type: Optional[str] - Content type
            - Examples: "application/json", "text/csv", "image/jpeg"
            - Passed to training container

    Notes:
        - Multiple channels can be specified
        - S3 data must be in same region as training job
        - Automatic download to local storage by default
        - Use Channel for advanced options (compression, input mode)
    """

Channel

class Channel:
    """
    Full input data channel configuration.

    Fields:
        channel_name: str - Channel name (required)
        
        data_source: DataSource - Data source configuration (required)
            - S3 data source with URI and distribution settings
        
        content_type: Optional[str] - Content type
            - MIME type of data
        
        compression_type: Optional[str] - Compression type
            - "None" or "Gzip"
            - Default: "None"
        
        input_mode: Optional[str] - Input mode
            - "File": Download all data before training (default)
            - "Pipe": Stream data during training (for large datasets)

    Notes:
        - Pipe mode reduces training start time for large datasets
        - Pipe mode requires training code to read from pipe
        - File mode simpler but slower for large data
        - Gzip decompressed automatically
    """

OutputDataConfig

class OutputDataConfig:
    """
    Output data location configuration for training jobs.

    Fields:
        s3_output_path: Optional[str] - S3 URI where output data will be stored
            - Format: "s3://bucket/prefix"
            - Default: s3://sagemaker-{region}-{account}/output
            - Model artifacts saved to {s3_output_path}/{job-name}/output/model.tar.gz
        
        kms_key_id: Optional[str] - KMS key for encrypting model artifacts
            - Format: "arn:aws:kms:region:account:key/key-id"
            - Applied to model artifacts and training output
        
        compression_type: Optional[str] - Model output compression type
            - "None" or "Gzip"
            - Default: "Gzip"

    Notes:
        - Output bucket must be in same region as training job
        - Execution role needs s3:PutObject permission
        - Encryption key must allow encrypt/decrypt from execution role
        - Compressed output reduces storage costs and transfer time
    """

CheckpointConfig

class CheckpointConfig:
    """
    Checkpoint configuration for training jobs.

    Fields:
        s3_uri: Optional[str] - S3 path for checkpoint data (required)
            - Format: "s3://bucket/prefix/checkpoints"
            - Checkpoints automatically saved here during training
        
        local_path: Optional[str] - Local directory for checkpoints
            - Default: "/opt/ml/checkpoints"
            - Training code should save checkpoints here
            - Automatically synced to S3

    Notes:
        - Essential for spot training to resume after interruption
        - Training code must implement checkpoint save/load
        - Checkpoints synced to S3 asynchronously
        - Use for long-running training jobs (>1 hour)
        - Execution role needs s3:PutObject permission on s3_uri
    """

TensorBoardOutputConfig

class TensorBoardOutputConfig:
    """
    Storage locations for TensorBoard output.

    Fields:
        s3_output_path: Optional[str] - S3 path for TensorBoard output (required)
            - Format: "s3://bucket/prefix/tensorboard"
            - TensorBoard logs automatically uploaded
        
        local_path: Optional[str] - Local path for TensorBoard output
            - Default: "/opt/ml/output/tensorboard"
            - Training code should write TensorBoard logs here

    Notes:
        - Use with: trainer.with_tensorboard_output_config()
        - View logs: tensorboard --logdir=s3://bucket/prefix/tensorboard
        - Logs synced to S3 during and after training
        - Requires TensorBoard in training container
    """

TrainingImageConfig

class TrainingImageConfig:
    """
    Configuration for training container image.

    Fields:
        training_repository_access_mode: str - Access mode (required)
            - "Platform": Use SageMaker-managed images
            - "Vpc": Use private registry via VPC endpoint
        
        training_repository_auth_config: Optional[TrainingRepositoryAuthConfig]
            - Authentication for private registries
            - Required if using private registry

    Notes:
        - Platform mode: no additional configuration needed
        - VPC mode: requires VPC endpoint to private registry
        - Private registries must support Docker Registry HTTP API V2
    """

RetryStrategy

class RetryStrategy:
    """
    Retry strategy for training jobs.

    Fields:
        maximum_retry_attempts: int - Maximum number of retry attempts (required)
            - Range: 1-10
            - Retries for infrastructure failures only
            - Not for algorithmic errors

    Notes:
        - Use with: trainer.with_retry_strategy()
        - Retries automatically for:
            - InternalServerError
            - CapacityError (insufficient instance capacity)
            - ThrottlingException
        - Does not retry for:
            - AlgorithmError (code errors)
            - ValidationException (configuration errors)
        - Each retry counted as separate training job for billing
    """

InfraCheckConfig

class InfraCheckConfig:
    """
    Infrastructure health check configuration.

    Fields:
        enable_infra_check: bool - Enable infrastructure checks before training (required)
            - Validates network connectivity, instance health
            - Adds ~2-3 minutes to startup time
            - Recommended for production workloads

    Notes:
        - Catches infrastructure issues before training starts
        - Reduces wasted time on failing training jobs
        - Checks: network connectivity, volume mounts, GPU availability
    """

RemoteDebugConfig

class RemoteDebugConfig:
    """
    Remote debugging configuration for training jobs.

    Fields:
        enable_remote_debug: bool - Enable remote debugging (required)
            - Allows SSH access to training instance
            - Requires additional setup in training script

    Notes:
        - Use for debugging complex training issues
        - Requires SSH key configuration
        - Security consideration: opens SSH port
        - Only for development, not production
    """

SessionChainingConfig

class SessionChainingConfig:
    """
    Session chaining configuration for training jobs.

    Fields:
        enable_session_tag_chaining: bool - Enable session tag chaining (required)
            - Propagates session tags to created resources
            - For compliance and cost tracking

    Notes:
        - Tags from session automatically applied to:
            - Training jobs
            - Model artifacts
            - Endpoints (if created from this training)
        - Useful for automated tagging policies
    """

MetricDefinition

class MetricDefinition:
    """
    Definition for extracting metrics from training job logs.

    Fields:
        name: str - Metric name (required)
            - Example: "train:loss"
            - Appears in CloudWatch and training job details
        
        regex: str - Regex pattern to extract metric from logs (required)
            - Must contain exactly one capture group
            - Capture group must match a number
            - Example: "Train Loss: ([0-9\\.]+)"

    Notes:
        - Metrics extracted from CloudWatch logs
        - Use with: trainer.with_metric_definitions()
        - Captured metrics available for:
            - Hyperparameter tuning objectives
            - CloudWatch metrics and alarms
            - Training job summaries
        - Regex must match log output format exactly
        - Common regex patterns:
            - Float: ([0-9\\.]+)
            - Scientific notation: ([0-9\\.e\\-\\+]+)
            - Percentage: ([0-9\\.]+)%
    """

Enums

Mode

class Mode(Enum):
    """
    Training mode for ModelTrainer.

    Values:
        LOCAL_CONTAINER = 1
            - Run training in a local Docker container
            - For testing before submitting to SageMaker
            - Requires Docker installed locally
        
        SAGEMAKER_TRAINING_JOB = 2
            - Run training as a SageMaker training job (default)
            - Uses SageMaker managed infrastructure
            - Billed per second of usage

    Notes:
        - LOCAL_CONTAINER useful for:
            - Debugging training code
            - Testing without SageMaker costs
            - Validating container configuration
        - SAGEMAKER_TRAINING_JOB provides:
            - Scalable infrastructure
            - Distributed training
            - Spot instances
            - Automatic checkpointing to S3
    """

TrainingType

class TrainingType(Enum):
    """
    Fine-tuning training type.

    Values:
        LORA = "lora"
            - Low-Rank Adaptation (parameter-efficient fine-tuning)
            - Updates only small adapter layers (~1% of parameters)
            - 5-10x faster and cheaper than full fine-tuning
            - Good quality for most use cases
            - Recommended for initial experiments
        
        FULL = "full"
            - Full model fine-tuning (all parameters updated)
            - Maximum quality for complex adaptations
            - Requires more compute and memory
            - Use when LORA quality insufficient
            - Better for significant domain shifts

    Notes:
        - LORA advantages:
            - Lower compute costs
            - Faster training
            - Less memory required
            - Multiple adapters from same base model
        - FULL advantages:
            - Maximum quality
            - Better for complex tasks
            - No architectural constraints
        - Start with LORA, upgrade to FULL if needed
    """

CustomizationTechnique

class CustomizationTechnique(Enum):
    """
    Model customization technique.

    Values:
        SFT = "sft"
            - Supervised Fine-Tuning
            - Uses labeled instruction-response pairs
            - Best for adapting to new tasks
        
        DPO = "dpo"
            - Direct Preference Optimization
            - Uses preference pairs (chosen vs rejected)
            - Aligns model with human preferences
            - More stable than RLHF
        
        RLAIF = "rlaif"
            - Reinforcement Learning from AI Feedback
            - Uses AI model for reward signals
            - Scales better than human feedback
        
        RLVR = "rlvr"
            - Reinforcement Learning from Verifiable Rewards
            - Uses programmatic reward functions
            - Best for objective metrics (code, math)

    Notes:
        - SFT: Start here for task adaptation
        - DPO: Use after SFT for alignment
        - RLAIF: Alternative to RLHF with AI rewards
        - RLVR: For tasks with objective evaluation
    """

Parameter Ranges

Parameter range classes for defining hyperparameter search spaces in tuning jobs.

ParameterRange

class ParameterRange:
    """
    Base class for parameter ranges.

    Abstract base class for defining hyperparameter ranges in tuning jobs.
    Subclasses: ContinuousParameter, IntegerParameter, CategoricalParameter
    """

ContinuousParameter

class ContinuousParameter(ParameterRange):
    """
    Continuous parameter range for floating-point hyperparameters.

    Parameters:
        min_value: float - Minimum value (inclusive) (required)
        max_value: float - Maximum value (inclusive) (required)
        scaling_type: str - Scaling type (default: "Auto")
            - "Auto": Automatically choose scaling
            - "Linear": Linear scaling
            - "Logarithmic": Log scaling (for learning rates, etc.)
            - "ReverseLogarithmic": Reverse log scaling

    Notes:
        - Use Logarithmic for parameters spanning orders of magnitude
        - Linear for parameters in similar scale
        - Scaling affects sampling strategy in Bayesian optimization
    """

IntegerParameter

class IntegerParameter(ParameterRange):
    """
    Integer parameter range for integer hyperparameters.

    Parameters:
        min_value: int - Minimum value (inclusive) (required)
        max_value: int - Maximum value (inclusive) (required)
        scaling_type: str - Scaling type (default: "Auto")
            - "Auto", "Linear", "Logarithmic", "ReverseLogarithmic"

    Notes:
        - Values always integers even with log scaling
        - Use for batch sizes, layers, etc.
    """

CategoricalParameter

class CategoricalParameter(ParameterRange):
    """
    Categorical parameter range for discrete hyperparameters.

    Parameters:
        values: List[Union[str, int, float]] - List of possible values (required)
            - Examples: ["adam", "sgd"], [32, 64, 128]

    Notes:
        - No scaling applied (categorical)
        - Values must be JSON-serializable
        - Order doesn't matter for Bayesian optimization
    """

Usage:

from sagemaker.core.parameter import (
    ContinuousParameter, IntegerParameter, CategoricalParameter
)

# Define parameter ranges
hyperparameter_ranges = {
    # Learning rate with log scaling (spans orders of magnitude)
    "learning_rate": ContinuousParameter(
        min_value=0.001,
        max_value=0.1,
        scaling_type="Logarithmic"
    ),
    
    # Batch size with linear scaling
    "batch_size": IntegerParameter(
        min_value=32,
        max_value=256,
        scaling_type="Linear"
    ),
    
    # Optimizer choices
    "optimizer": CategoricalParameter(
        values=["adam", "sgd", "rmsprop", "adamw"]
    ),
    
    # Number of layers
    "num_layers": IntegerParameter(
        min_value=2,
        max_value=10
    ),
    
    # Dropout rate
    "dropout": ContinuousParameter(
        min_value=0.0,
        max_value=0.5
    )
}

Helper Functions

Session

class Session:
    """
    SageMaker session for managing API interactions.

    Main class for managing interactions with SageMaker APIs and AWS services.
    Handles authentication, region configuration, and service clients.
    
    Parameters:
        boto_session: Optional[boto3.Session] - Boto3 session for AWS credentials
        region_name: Optional[str] - AWS region (default: from boto session)
        default_bucket: Optional[str] - Default S3 bucket for SageMaker resources
        
    Methods:
        get_execution_role() -> str: Get IAM role ARN
        default_bucket() -> str: Get default S3 bucket
        upload_data(path, bucket, key_prefix) -> str: Upload data to S3
        download_data(path, bucket, key_prefix): Download data from S3
        
    Notes:
        - Automatically created if not provided to ModelTrainer
        - Reuse session across multiple trainers for consistency
        - Default bucket: sagemaker-{region}-{account_id}
    """

get_execution_role

def get_execution_role() -> str:
    """
    Get the execution role ARN for SageMaker.

    Returns:
        str: IAM role ARN for SageMaker execution
            Format: "arn:aws:iam::{account}:role/{role-name}"

    Raises:
        ValueError: If role cannot be determined from environment
            - Not running in SageMaker notebook instance
            - IAM role not configured in environment
            - No default execution role available

    Notes:
        - Automatically detects role when running in:
            - SageMaker notebook instances
            - SageMaker Studio
            - SageMaker Processing jobs
            - SageMaker Training jobs
        - Outside SageMaker, explicitly provide role ARN to ModelTrainer
        
    Usage:
        # In SageMaker environment
        role = get_execution_role()
        
        # Outside SageMaker
        role = "arn:aws:iam::123456789012:role/SageMakerRole"
    """

Validation and Constraints

Training Job Constraints

  • Maximum runtime: 28 days (2,419,200 seconds)
  • Minimum instance count: 1
  • Maximum instance count: 20 (can request service quota increase)
  • Instance volume size: 1 GB - 16,384 GB
  • Maximum input channels: 20
  • Maximum hyperparameters: 100
  • Hyperparameter name length: 1-256 characters
  • Hyperparameter value length: 1-2500 characters
  • Job name length: 1-63 characters
  • Container environment variables: Maximum 512 entries

S3 Constraints

  • S3 URIs must be in same region as training job
  • Maximum object size: 5 TB
  • Multipart upload recommended for files >100 MB
  • S3 bucket names: 3-63 characters, lowercase, no underscores

Network Constraints

  • VPC subnets: Must be in same VPC
  • Security groups: Maximum 5 per training job
  • Network isolation: Incompatible with VPC configuration
  • Inter-container encryption: Only for distributed training (instance_count >1)

Hyperparameter Tuning Constraints

  • Maximum concurrent tuning jobs: 100 (can request increase)
  • Maximum training jobs per tuning: 500
  • Maximum parallel training jobs: 100
  • Parameter ranges: Maximum 30
  • Warm start parent jobs: Maximum 5
  • Metric definitions: Maximum 40

Fine-Tuning Constraints

  • Supported models: Check SageMaker JumpStart for full list
  • LORA rank (lora_r): 1-256 (typically 4-16)
  • Training dataset: Minimum 10 examples, recommended 100+
  • Validation dataset: Recommended 10% of training size
  • Maximum sequence length: Model-dependent (typically 512-4096 tokens)

Common Error Scenarios

  1. ResourceLimitExceeded
    • Cause: Exceeded service quota for the requested instance type
    • Solution: Request a quota increase or use a different instance type
  2. ValidationException (invalid hyperparameter)
    • Cause: Hyperparameter value outside the valid range
    • Solution: Check hyperparameters.get_info() for valid ranges
  3. S3 Access Denied
    • Cause: Execution role lacks S3 permissions
    • Solution: Add s3:GetObject and s3:PutObject permissions to the role
  4. Spot Instance Interruption
    • Cause: Spot capacity reclaimed by AWS
    • Solution: Enable checkpointing and set max_wait_time_in_seconds appropriately
  5. AlgorithmError
    • Cause: Error in the training code
    • Solution: Check CloudWatch logs; test locally with LOCAL_CONTAINER mode
  6. Invalid Gated Model Access
    • Cause: Model requires EULA acceptance or a subscription
    • Solution: Set accept_eula=True and ensure the model is accessible in your region