Model training capabilities including unified ModelTrainer, hyperparameter tuning, fine-tuning for foundation models, and distributed training support.
Important: The sagemaker.train module uses lazy loading for some classes. Import paths:
# Available via lazy loading from sagemaker.train
from sagemaker.train import ModelTrainer, Session, get_execution_role
from sagemaker.train import BenchMarkEvaluator, CustomScorerEvaluator, LLMAsJudgeEvaluator
from sagemaker.train import get_benchmarks, get_builtin_metrics
# Requires full module path (not lazy-loaded)
from sagemaker.train.sft_trainer import SFTTrainer
from sagemaker.train.dpo_trainer import DPOTrainer
from sagemaker.train.rlaif_trainer import RLAIFTrainer
from sagemaker.train.rlvr_trainer import RLVRTrainer
from sagemaker.train.tuner import HyperparameterTuner, WarmStartTypes
Main class for training models using AWS SageMaker, replacing the V2 Estimator and framework-specific classes.
class ModelTrainer:
"""
Train models on SageMaker or locally.
Parameters:
training_mode: Mode - LOCAL_CONTAINER or SAGEMAKER_TRAINING_JOB (default: SAGEMAKER_TRAINING_JOB)
sagemaker_session: Optional[Session] - SageMaker session (default: creates new session)
role: Optional[str] - IAM role ARN (required for SageMaker training)
base_job_name: Optional[str] - Base name for training jobs (default: auto-generated)
source_code: Optional[SourceCode] - Source code configuration for custom training scripts
distributed: Optional[Union[Torchrun, MPI]] - Distributed training configuration
compute: Optional[Compute] - Compute resource configuration (required)
networking: Optional[Networking] - Network configuration for VPC deployments
stopping_condition: Optional[StoppingCondition] - Stopping criteria (default: 1 day max runtime)
algorithm_name: Optional[str] - Built-in algorithm name (e.g., "xgboost", "kmeans")
training_image: Optional[str] - Training container image URI (required unless algorithm_name provided)
training_image_config: Optional[TrainingImageConfig] - Image configuration for private registries
output_data_config: Optional[OutputDataConfig] - Output configuration (default: default S3 bucket)
input_data_config: Optional[List[Union[Channel, InputData]]] - Input data channels
checkpoint_config: Optional[CheckpointConfig] - Checkpoint configuration for resumable training
training_input_mode: Optional[str] - Input mode: "File" or "Pipe" (default: "File")
environment: Optional[Dict[str, str]] - Environment variables passed to container
hyperparameters: Optional[Union[Dict[str, Any], str]] - Training hyperparameters (values converted to strings)
tags: Optional[List[Tag]] - Resource tags for cost tracking
local_container_root: Optional[str] - Local container root directory (for LOCAL_CONTAINER mode)
Methods:
train(input_data_config=None, wait=True, logs=True) -> None
Train the model.
Parameters:
input_data_config: Optional[List[InputData]] - Input data (overrides constructor value)
wait: bool - Block until training completes (default: True)
logs: bool - Show CloudWatch logs (default: True)
Raises:
ValueError: If required parameters missing or invalid
ClientError: If AWS API call fails
RuntimeError: If training job fails
create_input_data_channel(channel_name, data_source, key_prefix=None, ignore_patterns=None) -> InputData
Create input data channel from local files or S3.
Parameters:
channel_name: str - Channel name (e.g., "training", "validation")
data_source: str - Local path or S3 URI
key_prefix: Optional[str] - S3 key prefix for local uploads
ignore_patterns: Optional[List[str]] - Patterns to ignore for local uploads
Returns:
InputData: Configured input data channel
with_tensorboard_output_config(tensorboard_output_config) -> ModelTrainer
Configure TensorBoard logging.
Parameters:
tensorboard_output_config: TensorBoardOutputConfig - TensorBoard configuration
Returns:
ModelTrainer: Self for method chaining
with_retry_strategy(retry_strategy) -> ModelTrainer
Set retry strategy for training job.
Parameters:
retry_strategy: RetryStrategy - Retry configuration
Returns:
ModelTrainer: Self for method chaining
with_checkpoint_config(checkpoint_config) -> ModelTrainer
Set checkpoint configuration.
Parameters:
checkpoint_config: CheckpointConfig - Checkpoint configuration
Returns:
ModelTrainer: Self for method chaining
with_metric_definitions(metric_definitions) -> ModelTrainer
Set metric definitions for extracting metrics from logs.
Parameters:
metric_definitions: List[MetricDefinition] - Metric extraction patterns
Returns:
ModelTrainer: Self for method chaining
Class Methods:
from_recipe(training_recipe, compute, ...) -> ModelTrainer
Create from training recipe configuration.
from_jumpstart_config(jumpstart_config, ...) -> ModelTrainer
Create from JumpStart configuration.
Attributes:
_latest_training_job: TrainingJob - Most recent training job resource
Raises:
ValueError: Invalid configuration parameters
ClientError: AWS API errors (permission, quota, validation)
RuntimeError: Training execution errors
"""Usage:
from sagemaker.train import ModelTrainer, Session
from sagemaker.train.configs import InputData, Compute, SourceCode
from botocore.exceptions import ClientError  # needed for the error handling below
# Create trainer with basic configuration
trainer = ModelTrainer(
training_image="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0-gpu-py310",
role="arn:aws:iam::123456789012:role/SageMakerRole",
compute=Compute(
instance_type="ml.p3.2xlarge",
instance_count=1,
volume_size_in_gb=50
),
source_code=SourceCode(
source_dir="./code",
entry_script="train.py",
requirements="requirements.txt"
),
hyperparameters={
"epochs": 10,
"learning_rate": 0.001,
"batch_size": 32
}
)
# Configure training data
train_data = InputData(
channel_name="training",
data_source="s3://my-bucket/train",
content_type="application/json"
)
val_data = InputData(
channel_name="validation",
data_source="s3://my-bucket/val",
content_type="application/json"
)
# Train model with error handling
try:
trainer.train(
input_data_config=[train_data, val_data],
wait=True,
logs=True
)
# Access training job details
job = trainer._latest_training_job
print(f"Training job: {job.training_job_name}")
print(f"Model artifacts: {job.model_artifacts}")
except ClientError as e:
error_code = e.response['Error']['Code']
if error_code == 'ResourceLimitExceeded':
print("Exceeded instance limit. Request quota increase.")
elif error_code == 'ValidationException':
print(f"Invalid configuration: {e}")
else:
raise
except RuntimeError as e:
print(f"Training failed: {e}")
# Check CloudWatch logs for details
Advanced Configuration:
from sagemaker.train.configs import (
StoppingCondition, CheckpointConfig,
Networking, RetryStrategy, MetricDefinition
)
# Configure with all options
trainer = ModelTrainer(
training_image="my-training-image",
role=role_arn,
compute=Compute(
instance_type="ml.p3.8xlarge",
instance_count=4,
volume_size_in_gb=100,
enable_managed_spot_training=True, # Cost savings
keep_alive_period_in_seconds=300 # Warm pool for repeated training
),
networking=Networking(
subnets=["subnet-xxx", "subnet-yyy"],
security_group_ids=["sg-xxx"],
enable_inter_container_traffic_encryption=True
),
stopping_condition=StoppingCondition(
max_runtime_in_seconds=86400, # 24 hours
max_wait_time_in_seconds=172800 # 48 hours for spot
),
checkpoint_config=CheckpointConfig(
s3_uri="s3://my-bucket/checkpoints",
local_path="/opt/ml/checkpoints"
),
environment={
"MY_ENV_VAR": "value",
"LOG_LEVEL": "INFO"
},
tags=[
{"Key": "Project", "Value": "MyProject"},
{"Key": "Environment", "Value": "Production"}
]
)
# Add metric definitions for custom metrics
trainer.with_metric_definitions([
MetricDefinition(
name="train:loss",
regex="Train Loss: ([0-9\\.]+)"
),
MetricDefinition(
name="val:accuracy",
regex="Validation Accuracy: ([0-9\\.]+)"
)
])
# Add retry strategy for resilience
trainer.with_retry_strategy(
RetryStrategy(maximum_retry_attempts=2)
)
# Train with full configuration
trainer.train(input_data_config=[train_data, val_data])
Manages hyperparameter tuning jobs for model optimization with support for Bayesian, Random, Grid, and Hyperband strategies.
class HyperparameterTuner:
"""
Hyperparameter tuning for model optimization.
Parameters:
model_trainer: ModelTrainer - The trainer instance to tune (required)
objective_metric_name: Union[str, PipelineVariable] - Metric to optimize (required)
hyperparameter_ranges: Dict[str, ParameterRange] - Parameter ranges to search (required unless autotune=True)
metric_definitions: Optional[List[Dict]] - Metric definitions with 'Name' and 'Regex' keys
strategy: Union[str, PipelineVariable] - Search strategy: "Bayesian", "Random", "Hyperband", "Grid" (default: "Bayesian")
objective_type: Union[str, PipelineVariable] - "Maximize" or "Minimize" (default: "Maximize")
max_jobs: Union[int, PipelineVariable] - Maximum number of training jobs (required)
max_parallel_jobs: Union[int, PipelineVariable] - Maximum parallel jobs (default: 1)
max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] - Maximum runtime for entire tuning job (default: unlimited)
tags: Optional[Tags] - Resource tags
base_tuning_job_name: Optional[str] - Base tuning job name (default: auto-generated from timestamp)
warm_start_config: Optional[HyperParameterTuningJobWarmStartConfig] - Warm start config for transfer learning
strategy_config: Optional[HyperParameterTuningJobStrategyConfig] - Strategy-specific configuration (required for Hyperband)
completion_criteria_config: Optional[TuningJobCompletionCriteria] - Auto-stop completion criteria
early_stopping_type: Union[str, PipelineVariable] - Early stopping: "Off", "Auto" (default: "Off")
model_trainer_name: Optional[str] - Name for multi-trainer configurations
random_seed: Optional[int] - Random seed for reproducible hyperparameter search (range: 0-2147483647)
autotune: bool - Enable automatic parameter range selection (default: False)
hyperparameters_to_keep_static: Optional[List[str]] - Hyperparameters to exclude from auto-tuning when autotune=True
Methods:
tune(inputs=None, job_name=None, model_trainer_kwargs=None, wait=True, **kwargs) -> None
Start hyperparameter tuning job.
Parameters:
inputs: Optional - Input data for training (overrides model_trainer configuration)
job_name: Optional[str] - Tuning job name (default: auto-generated)
model_trainer_kwargs: Optional[Dict] - Additional ModelTrainer kwargs
wait: bool - Block until tuning completes (default: True)
**kwargs: Additional tuning job parameters
Raises:
ValueError: Invalid configuration (e.g., Hyperband with early_stopping_type != "Off")
ClientError: AWS API errors
stop_tuning_job() -> None
Stop the latest tuning job immediately.
Raises:
ClientError: If job cannot be stopped
describe() -> dict
Get full tuning job details from DescribeHyperParameterTuningJob API.
Returns:
dict: Complete job description including all training jobs
wait() -> None
Wait for tuning job to complete.
Raises:
WaiterError: If job fails or times out
best_training_job() -> str
Get name of best training job from latest tuning job.
Returns:
str: Training job name with best objective metric value
Raises:
RuntimeError: If tuning job not yet started or failed
analytics() -> HyperparameterTuningJobAnalytics
Get analytics instance with pandas DataFrame of results.
Returns:
HyperparameterTuningJobAnalytics: Analytics object with dataframe() method
transfer_learning_tuner(additional_parents=None, model_trainer=None) -> HyperparameterTuner
Create transfer learning tuner from this tuner.
Parameters:
additional_parents: Optional[Set[str]] - Additional parent tuning job names
model_trainer: Optional[ModelTrainer] - New trainer (default: reuse current)
Returns:
HyperparameterTuner: New tuner configured for transfer learning
override_resource_config(instance_configs) -> None
Override instance configuration for model trainers.
Parameters:
instance_configs: List[HyperParameterTuningInstanceConfig] - Instance configs
Class Methods:
create(model_trainer_dict, objective_metric_name_dict, hyperparameter_ranges_dict, ...) -> HyperparameterTuner
Create multi-model-trainer tuner.
Parameters:
model_trainer_dict: Dict[str, ModelTrainer] - Named trainers
objective_metric_name_dict: Dict[str, str] - Objective metrics per trainer
hyperparameter_ranges_dict: Dict[str, Dict[str, ParameterRange]] - Ranges per trainer
Returns:
HyperparameterTuner: Configured multi-trainer tuner
Attributes:
latest_tuning_job: HyperParameterTuningJob - Most recent tuning job resource
Raises:
ValueError: Invalid strategy, parameters, or configuration
ClientError: AWS API errors
RuntimeError: Tuning execution errors
Notes:
- Bayesian strategy: Most efficient for continuous parameters, learns from previous jobs
- Random strategy: Good baseline, works well for all parameter types
- Grid strategy: Exhaustive search, best for small search spaces with discrete parameters
- Hyperband strategy: Multi-fidelity based, dynamically allocates resources, requires strategy_config
- Early stopping "Auto" may stop underperforming jobs but is not guaranteed
- For Hyperband, early_stopping_type must be "Off" (has internal early stopping)
- Warm start types: IdenticalDataAndAlgorithm (same data/image) or TransferLearning (different data/config)
"""Tuning Strategies:
strategy_config)Early Stopping:
Usage:
from sagemaker.train.tuner import HyperparameterTuner, WarmStartTypes
from sagemaker.core.parameter import ContinuousParameter, IntegerParameter, CategoricalParameter
from botocore.exceptions import ClientError  # needed for the error handling below
# Define parameter ranges with scaling
hyperparameter_ranges = {
"learning_rate": ContinuousParameter(
min_value=0.001,
max_value=0.1,
scaling_type="Logarithmic" # Better for learning rates
),
"batch_size": IntegerParameter(
min_value=32,
max_value=256,
scaling_type="Linear"
),
"optimizer": CategoricalParameter(
values=["adam", "sgd", "rmsprop"]
),
"dropout": ContinuousParameter(
min_value=0.0,
max_value=0.5,
scaling_type="Linear"
)
}
# Create tuner with Bayesian strategy
tuner = HyperparameterTuner(
model_trainer=trainer,
objective_metric_name="validation:accuracy",
hyperparameter_ranges=hyperparameter_ranges,
metric_definitions=[
{"Name": "validation:accuracy", "Regex": "Val Acc: ([0-9\\.]+)"}
],
strategy="Bayesian",
objective_type="Maximize",
max_jobs=20,
max_parallel_jobs=3,
early_stopping_type="Auto"
)
# Start tuning with error handling
try:
tuner.tune(wait=True)
# Get best training job
best_job = tuner.best_training_job()
print(f"Best training job: {best_job}")
# Access analytics
analytics = tuner.analytics()
df = analytics.dataframe() # Pandas DataFrame with all results
print(df.sort_values('FinalObjectiveValue', ascending=False).head())
except ValueError as e:
print(f"Invalid configuration: {e}")
except ClientError as e:
print(f"AWS API error: {e}")Hyperband Strategy:
from sagemaker.core.shapes import HyperbandStrategyConfig, HyperParameterTuningJobStrategyConfig
# Configure Hyperband strategy
hyperband_config = HyperbandStrategyConfig(
min_resource=1, # Minimum epochs/resources for training jobs
max_resource=27, # Maximum epochs/resources
reduction_factor=3, # Factor for reducing configurations (typically 2-4)
number_of_brackets=4 # Number of bracket levels (more brackets = more exploration)
)
strategy_config = HyperParameterTuningJobStrategyConfig(
hyperband_strategy_config=hyperband_config
)
# Create tuner with Hyperband
tuner = HyperparameterTuner(
model_trainer=trainer,
objective_metric_name="validation:accuracy",
hyperparameter_ranges=hyperparameter_ranges,
strategy="Hyperband",
strategy_config=strategy_config,
early_stopping_type="Off", # Required for Hyperband
max_jobs=100,
max_parallel_jobs=5
)
tuner.tune()
Warm Start (Transfer Learning):
from sagemaker.core.shapes import (
HyperParameterTuningJobWarmStartConfig,
ParentHyperParameterTuningJob
)
from sagemaker.train.tuner import WarmStartTypes
# Configure warm start from parent tuning job
warm_start_config = HyperParameterTuningJobWarmStartConfig(
parent_hyper_parameter_tuning_jobs=[
ParentHyperParameterTuningJob(
hyper_parameter_tuning_job_name="my-previous-tuning-job"
)
],
warm_start_type=WarmStartTypes.TRANSFER_LEARNING.value
)
# Create tuner with warm start
tuner = HyperparameterTuner(
model_trainer=trainer,
objective_metric_name="validation:accuracy",
hyperparameter_ranges=hyperparameter_ranges,
warm_start_config=warm_start_config,
max_jobs=20,
max_parallel_jobs=3
)
# Or use transfer_learning_tuner helper
new_tuner = tuner.transfer_learning_tuner(
additional_parents={"parent-job-1", "parent-job-2"}
)
new_tuner.tune()
Warm Start Types Enum:
class WarmStartTypes(Enum):
"""
Enum for hyperparameter tuning warm start types.
Values:
IDENTICAL_DATA_AND_ALGORITHM: "IdenticalDataAndAlgorithm"
- Same input data and training image required
- Can change: hyperparameter ranges, max training jobs
- Cannot change: algorithm version, input data
TRANSFER_LEARNING: "TransferLearning"
- Different data, ranges, or training configuration allowed
- Can change: input data, hyperparameter ranges, training image version
- Use when continuing tuning with different dataset or updated model
"""
IDENTICAL_DATA_AND_ALGORITHM = "IdenticalDataAndAlgorithm"
TRANSFER_LEARNING = "TransferLearning"
Completion Criteria:
from sagemaker.core.shapes import (
TuningJobCompletionCriteria,
BestObjectiveNotImproving,
ConvergenceDetected
)
# Auto-stop when target metric reached or convergence detected
completion_criteria = TuningJobCompletionCriteria(
target_objective_metric_value=0.95, # Stop when metric reaches this value
best_objective_not_improving=BestObjectiveNotImproving(
max_number_of_training_jobs_not_improving=10 # Stop if no 1%+ improvement
),
convergence_detected=ConvergenceDetected(
complete_on_convergence="Enabled" # Stop if AMT detects convergence
)
)
tuner = HyperparameterTuner(
model_trainer=trainer,
objective_metric_name="validation:accuracy",
hyperparameter_ranges=hyperparameter_ranges,
completion_criteria_config=completion_criteria,
max_jobs=100 # May complete with fewer jobs due to criteria
)
Multi-Model-Trainer Tuning:
# Create tuner for multiple trainers
tuner = HyperparameterTuner.create(
model_trainer_dict={
"trainer1": trainer1,
"trainer2": trainer2
},
objective_metric_name_dict={
"trainer1": "val:accuracy",
"trainer2": "val:f1"
},
hyperparameter_ranges_dict={
"trainer1": {
"learning_rate": ContinuousParameter(0.001, 0.1)
},
"trainer2": {
"learning_rate": ContinuousParameter(0.0001, 0.01)
}
},
strategy="Bayesian",
max_jobs=20,
max_parallel_jobs=4
)
AutoTune Feature:
# Enable automatic parameter range selection
tuner = HyperparameterTuner(
model_trainer=trainer,
objective_metric_name="validation:accuracy",
hyperparameter_ranges={}, # Empty when autotune=True
autotune=True,
hyperparameters_to_keep_static=["epochs", "batch_size"], # Don't tune these
max_jobs=50,
max_parallel_jobs=5
)
# SageMaker will automatically determine optimal ranges for tunable hyperparameters
tuner.tune()
Analytics class for accessing tuning job results as a pandas DataFrame.
class HyperparameterTuningJobAnalytics:
"""
Analytics for hyperparameter tuning jobs.
Parameters:
tuning_job_name: str - Name of completed tuning job
sagemaker_session: Optional[Session] - SageMaker session
Attributes:
name: str - Name of tuning job
tuning_ranges: Dict[str, ParameterRange] - Dictionary of parameter ranges used
Methods:
description() -> dict
Get full job description from DescribeHyperParameterTuningJob API.
Returns:
dict: Complete tuning job metadata and configuration
training_job_summaries() -> List[Dict]
Get all training job summaries from tuning job.
Returns:
List[Dict]: List of training job summaries with parameters and metrics
dataframe() -> pandas.DataFrame
Pandas DataFrame with all tuning results for analysis.
Returns:
pandas.DataFrame: Results with columns for parameters, metrics, and metadata
DataFrame Columns:
- Tuned hyperparameter values (one column per hyperparameter)
- TrainingJobName: Name of training job
- TrainingJobStatus: Status (Completed, Failed, Stopped)
- FinalObjectiveValue: Final metric value
- TrainingStartTime: Training start timestamp
- TrainingEndTime: Training end timestamp
- TrainingElapsedTimeSeconds: Training duration in seconds
- TrainingJobDefinitionName: Training job definition name
Usage:
Analyze tuning results, identify best parameters, visualize search progress
"""Usage:
from sagemaker.train.tuner import HyperparameterTuningJobAnalytics
import pandas as pd
import matplotlib.pyplot as plt
# Create analytics instance
analytics = HyperparameterTuningJobAnalytics(
tuning_job_name="my-tuning-job-2024-01-01-12-00-00"
)
# Get results as DataFrame
df = analytics.dataframe()
# Analyze results
print("Best 5 configurations:")
print(df.sort_values('FinalObjectiveValue', ascending=False).head())
# Analyze parameter impact
print("\nLearning rate vs Accuracy:")
print(df.groupby('learning_rate')['FinalObjectiveValue'].mean())
# Visualize search progress
plt.figure(figsize=(10, 6))
plt.plot(range(len(df)), df['FinalObjectiveValue'].cummax())
plt.xlabel('Training Job Number')
plt.ylabel('Best Objective Value')
plt.title('Hyperparameter Tuning Progress')
plt.show()
# Export best configuration
best_job = df.loc[df['FinalObjectiveValue'].idxmax()]
best_params = {
col: best_job[col]
for col in df.columns
if col in analytics.tuning_ranges.keys()
}
print(f"\nBest hyperparameters: {best_params}")Fine-tuning trainers for foundation models with support for Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and Reinforcement Learning techniques.
Supervised Fine-Tuning trainer for foundation models using instruction-following or task-specific datasets.
class SFTTrainer:
"""
Supervised Fine-Tuning trainer for foundation models.
Parameters:
model: Union[str, ModelPackage] - Foundation model to fine-tune
- Model name/ID (e.g., "meta-llama/Llama-2-7b-hf")
- Model ARN (e.g., "arn:aws:sagemaker:...")
- ModelPackage object
(required)
training_type: Union[TrainingType, str] - Fine-tuning approach (default: TrainingType.LORA)
- TrainingType.LORA: Parameter-efficient, faster, lower cost
- TrainingType.FULL: All parameters updated, higher quality for complex tasks
model_package_group: Optional[Union[str, ModelPackageGroup]] - Model package group for output
- Group name (e.g., "my-fine-tuned-models")
- Group ARN
- ModelPackageGroup object
mlflow_resource_arn: Optional[str] - MLflow tracking server ARN for experiment tracking
mlflow_experiment_name: Optional[str] - MLflow experiment name
mlflow_run_name: Optional[str] - MLflow run name
training_dataset: Optional[Union[str, DataSet]] - Training dataset (required at train() time if not provided)
- S3 URI (e.g., "s3://bucket/train.jsonl")
- Dataset ARN
- DataSet object from AI Registry
validation_dataset: Optional[Union[str, DataSet]] - Validation dataset
- Same formats as training_dataset
s3_output_path: Optional[str] - S3 path for outputs
- Defaults to: s3://sagemaker-<region>-<account>/output
kms_key_id: Optional[str] - KMS key ID for output encryption
networking: Optional[VpcConfig] - VPC configuration for secure training
accept_eula: Optional[bool] - Accept EULA for gated models (default: False)
- Required as True for models with usage agreements
Methods:
train(training_dataset=None, validation_dataset=None, wait=True) -> TrainingJob
Execute supervised fine-tuning training job.
Parameters:
training_dataset: Optional[Union[str, DataSet]] - Override constructor dataset
validation_dataset: Optional[Union[str, DataSet]] - Override constructor dataset
wait: bool - Block until training completes (default: True)
Returns:
TrainingJob: Completed training job resource
Raises:
ValueError: If training_dataset not provided either in constructor or train()
ValueError: If accept_eula=False for gated model
ClientError: AWS API errors
Attributes:
hyperparameters: FineTuningOptions - Dynamic hyperparameters with validation
- Access via: trainer.hyperparameters.parameter_name
- Common parameters:
- epochs: int (default: 3)
- learning_rate: float (default: 2e-4)
- per_device_train_batch_size: int (default: 4)
- per_device_eval_batch_size: int (default: 8)
- gradient_accumulation_steps: int (default: 1)
- lora_r: int (default: 8, only for LORA)
- lora_alpha: int (default: 16, only for LORA)
- lora_dropout: float (default: 0.05, only for LORA)
_latest_training_job: TrainingJob - Most recent training job
Raises:
ValueError: Invalid configuration or missing required parameters
ClientError: AWS API errors (permissions, quotas)
RuntimeError: Training execution errors
Notes:
- Training dataset format: JSONL with "prompt" and "completion" fields
- Validation dataset optional but recommended for monitoring
- LORA typically 5-10x faster and cheaper than FULL fine-tuning
- Use FULL training for significant domain shift or when quality critical
- Gated models (Llama, etc.) require accept_eula=True
- Output model registered to Model Registry if model_package_group specified
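Example (a minimal sketch of preparing the JSONL dataset described above;
the bucket name and records are illustrative, and the upload uses plain boto3):
    import json
    import boto3
    records = [
        {"prompt": "Question: What is ML?", "completion": "Machine learning is..."},
        {"prompt": "Question: Define overfitting.", "completion": "Overfitting occurs when..."},
    ]
    with open("train.jsonl", "w") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")  # one JSON object per line
    boto3.client("s3").upload_file("train.jsonl", "my-bucket", "sft/train.jsonl")
    # Then pass "s3://my-bucket/sft/train.jsonl" as training_dataset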
"""Dataset Format:
Training datasets must be in JSONL format with specific fields:
{"prompt": "Question: What is machine learning?", "completion": "Machine learning is..."}
{"prompt": "Question: Explain neural networks.", "completion": "Neural networks are..."}Usage:
from sagemaker.train.sft_trainer import SFTTrainer
from sagemaker.train.common import TrainingType
from botocore.exceptions import ClientError  # needed for the error handling below
# Create SFT trainer for Llama model
trainer = SFTTrainer(
model="meta-llama/Llama-2-7b-hf",
training_type=TrainingType.LORA,
model_package_group="my-llm-models",
training_dataset="s3://my-bucket/train.jsonl",
validation_dataset="s3://my-bucket/val.jsonl",
s3_output_path="s3://my-bucket/output",
accept_eula=True # Required for Llama models
)
# Configure hyperparameters
trainer.hyperparameters.epochs = 3
trainer.hyperparameters.learning_rate = 2e-4
trainer.hyperparameters.per_device_train_batch_size = 4
trainer.hyperparameters.gradient_accumulation_steps = 4 # Effective batch size = 4*4 = 16
# LoRA-specific parameters
trainer.hyperparameters.lora_r = 8 # LoRA rank
trainer.hyperparameters.lora_alpha = 16 # LoRA alpha
trainer.hyperparameters.lora_dropout = 0.05
trainer.hyperparameters.target_modules = ["q_proj", "v_proj"] # Which layers to adapt
# Train model with error handling
try:
job = trainer.train(wait=True)
print(f"Training completed: {job.training_job_name}")
print(f"Model artifacts: {job.model_artifacts}")
except ValueError as e:
print(f"Configuration error: {e}")
except ClientError as e:
print(f"AWS API error: {e}")Full Fine-Tuning:
# Use FULL training for maximum quality
trainer = SFTTrainer(
model="meta-llama/Llama-2-7b-hf",
training_type=TrainingType.FULL, # Updates all model parameters
training_dataset="s3://my-bucket/large-train-dataset.jsonl",
validation_dataset="s3://my-bucket/val.jsonl",
accept_eula=True
)
# Adjust hyperparameters for full training
trainer.hyperparameters.epochs = 1 # Typically fewer epochs needed
trainer.hyperparameters.learning_rate = 1e-5 # Lower learning rate
trainer.hyperparameters.per_device_train_batch_size = 1 # Smaller batch due to memory
trainer.hyperparameters.gradient_accumulation_steps = 16 # Larger effective batch
job = trainer.train()
Direct Preference Optimization trainer for aligning models with human preferences using preference pair datasets.
class DPOTrainer:
"""
Direct Preference Optimization trainer for foundation models.
Parameters:
model: Union[str, ModelPackage] - Foundation model to fine-tune (required)
- Model name/ID, ARN, or ModelPackage object
training_type: Union[TrainingType, str] - Fine-tuning approach (default: TrainingType.LORA)
- TrainingType.LORA or TrainingType.FULL
model_package_group: Optional[Union[str, ModelPackageGroup]] - Model package group for output
mlflow_resource_arn: Optional[str] - MLflow tracking server ARN
mlflow_experiment_name: Optional[str] - MLflow experiment name
mlflow_run_name: Optional[str] - MLflow run name
training_dataset: Optional[Union[str, DataSet]] - Training dataset with preference pairs
- Must contain "prompt", "chosen", "rejected" fields
validation_dataset: Optional[Union[str, DataSet]] - Validation dataset
s3_output_path: Optional[str] - S3 path for outputs
kms_key_id: Optional[str] - KMS key ID for encryption
networking: Optional[VpcConfig] - VPC configuration
accept_eula: bool - Accept EULA for gated models (default: False)
Methods:
train(training_dataset=None, validation_dataset=None, wait=True) -> TrainingJob
Execute DPO training.
Returns:
TrainingJob: Completed training job
Raises:
ValueError: Invalid configuration or dataset format
ClientError: AWS API errors
Attributes:
hyperparameters: FineTuningOptions - Dynamic hyperparameters
- epochs, learning_rate, batch_size, etc.
- beta: float - DPO temperature parameter (default: 0.1)
_latest_training_job: TrainingJob - Most recent training job
Notes:
- Dataset format: JSONL with "prompt", "chosen", "rejected" fields
- DPO directly optimizes for preference without reward model
- More stable and efficient than RLHF
- Requires high-quality preference data
"""Dataset Format for DPO:
{"prompt": "Explain quantum computing", "chosen": "Quantum computing uses...", "rejected": "It's magic..."}
{"prompt": "Write a poem", "chosen": "Roses are red...", "rejected": "asdfghjkl"}Usage:
from sagemaker.train.dpo_trainer import DPOTrainer
# Create DPO trainer with preference pairs dataset
trainer = DPOTrainer(
model="meta-llama/Llama-2-7b-hf",
model_package_group="aligned-models",
training_dataset="s3://my-bucket/preferences.jsonl",
accept_eula=True
)
# Configure DPO-specific parameters
trainer.hyperparameters.beta = 0.1 # Temperature parameter
trainer.hyperparameters.learning_rate = 1e-5
trainer.hyperparameters.epochs = 1
# Train with preference optimization
try:
job = trainer.train(wait=True)
print(f"DPO training completed: {job.training_job_name}")
except ValueError as e:
print(f"Dataset format error - check 'chosen' and 'rejected' fields: {e}")Reinforcement Learning from AI Feedback trainer using LLM-based reward signals.
class RLAIFTrainer:
"""
Reinforcement Learning from AI Feedback trainer.
Parameters:
model: Union[str, ModelPackage] - Foundation model to fine-tune (required)
training_type: Union[TrainingType, str] - Fine-tuning approach (default: TrainingType.LORA)
model_package_group: Optional[Union[str, ModelPackageGroup]] - Model package group
reward_model_id: str - Bedrock model ID for generating AI feedback (required)
Allowed values:
- "openai.gpt-oss-120b-1:0" (us-west-2, us-east-1, ap-northeast-1, eu-west-1)
- "openai.gpt-oss-20b-1:0" (us-west-2, us-east-1, ap-northeast-1, eu-west-1)
- "qwen.qwen3-32b-v1:0" (us-west-2, us-east-1, ap-northeast-1, eu-west-1)
- "qwen.qwen3-coder-30b-a3b-v1:0" (us-west-2, us-east-1, ap-northeast-1, eu-west-1)
- "qwen.qwen3-coder-480b-a35b-v1:0" (us-west-2, ap-northeast-1)
- "qwen.qwen3-235b-a22b-2507-v1:0" (us-west-2, ap-northeast-1)
reward_prompt: Union[str, Evaluator] - Reward prompt or evaluator (required)
Built-in metrics:
- "Builtin.Helpfulness" - Response helpfulness
- "Builtin.Harmlessness" - Safety and harmlessness
- "Builtin.Honesty" - Truthfulness and accuracy
- "Builtin.Conciseness" - Response conciseness
Or custom Evaluator object
mlflow_resource_arn: Optional[Union[str, MlflowTrackingServer]] - MLflow tracking
mlflow_experiment_name: Optional[str] - MLflow experiment name
mlflow_run_name: Optional[str] - MLflow run name
training_dataset: Optional[Union[str, DataSet]] - Training dataset
validation_dataset: Optional[Union[str, DataSet]] - Validation dataset
s3_output_path: Optional[str] - S3 path for outputs
kms_key_id: Optional[str] - KMS key ID for encryption
networking: Optional[VpcConfig] - VPC configuration
accept_eula: bool - Accept EULA for gated models (default: False)
Methods:
train(training_dataset=None, validation_dataset=None, wait=True) -> TrainingJob
Execute RLAIF training with AI feedback.
Returns:
TrainingJob: Completed training job
Raises:
ValueError: Invalid reward_model_id or reward_prompt
ClientError: AWS API errors
Attributes:
reward_model_id: str - Bedrock reward model identifier
reward_prompt: Union[str, Evaluator] - Reward prompt configuration
hyperparameters: FineTuningOptions - Dynamic hyperparameters
_latest_training_job: TrainingJob - Most recent training job
_evaluator_arn: str - Evaluator ARN (if using custom evaluator)
Notes:
- Reward model must be available in your AWS region
- Built-in metrics provide standardized evaluation
- Custom evaluators allow domain-specific reward signals
- Training dataset: prompts for model to generate responses
- Reward model evaluates generated responses during training
"""Usage:
from sagemaker.train.rlaif_trainer import RLAIFTrainer
# Create RLAIF trainer with built-in metric
trainer = RLAIFTrainer(
model="meta-llama/Llama-2-7b-hf",
model_package_group="rl-models",
reward_model_id="openai.gpt-oss-120b-1:0",
reward_prompt="Builtin.Harmlessness", # Built-in safety metric
training_dataset="s3://my-bucket/prompts.jsonl",
accept_eula=True
)
# Configure training
trainer.hyperparameters.epochs = 3
trainer.hyperparameters.learning_rate = 1e-5
# Train with AI feedback
try:
job = trainer.train(wait=True)
print(f"RLAIF training completed: {job.training_job_name}")
except ValueError as e:
if "reward_model_id" in str(e):
print("Invalid reward model ID or not available in your region")
raise
Custom Reward Evaluator:
from sagemaker.ai_registry.evaluator import Evaluator
# Create custom evaluator
evaluator = Evaluator.create(
name="code-quality-evaluator",
type="RewardFunction",
source="arn:aws:lambda:us-west-2:123:function:evaluate-code-quality",
wait=True
)
# Use custom evaluator for domain-specific rewards
trainer = RLAIFTrainer(
model="codellama/CodeLlama-7b-hf",
reward_model_id="openai.gpt-oss-120b-1:0",
reward_prompt=evaluator, # Custom evaluator
training_dataset="s3://my-bucket/code-prompts.jsonl",
accept_eula=True
)
job = trainer.train()
Reinforcement Learning from Verifiable Rewards trainer using custom reward functions.
class RLVRTrainer:
"""
Reinforcement Learning from Verifiable Rewards trainer.
Parameters:
model: Union[str, ModelPackage] - Foundation model to fine-tune (required)
training_type: Union[TrainingType, str] - Fine-tuning approach (default: TrainingType.LORA)
model_package_group: Optional[Union[str, ModelPackageGroup]] - Model package group
custom_reward_function: Optional[Union[str, Evaluator]] - Custom reward function evaluator (required)
- Evaluator ARN
- Evaluator object from AI Registry
- Lambda function ARN
- Must return numeric reward value
mlflow_resource_arn: Optional[Union[str, MlflowTrackingServer]] - MLflow tracking
mlflow_experiment_name: Optional[str] - MLflow experiment name
mlflow_run_name: Optional[str] - MLflow run name
training_dataset: Optional[Union[str, DataSet]] - Training dataset
validation_dataset: Optional[Union[str, DataSet]] - Validation dataset
s3_output_path: Optional[str] - S3 path for outputs
kms_key_id: Optional[str] - KMS key ID for encryption
networking: Optional[VpcConfig] - VPC configuration
accept_eula: bool - Accept EULA for gated models (default: False)
Methods:
train(training_dataset=None, validation_dataset=None, wait=True) -> TrainingJob
Execute RLVR training with custom rewards.
Returns:
TrainingJob: Completed training job
Raises:
ValueError: If custom_reward_function not provided
ClientError: AWS API errors
Attributes:
custom_reward_function: Union[str, Evaluator] - Custom reward evaluator
hyperparameters: FineTuningOptions - Dynamic hyperparameters
_latest_training_job: TrainingJob - Most recent training job
Notes:
- Reward function must return numeric score for model outputs
- Use for objective, verifiable metrics (code correctness, math accuracy, etc.)
- Reward function invoked during training to evaluate generated responses
- Training dataset contains prompts for model to generate responses
"""Usage:
from sagemaker.train.rlvr_trainer import RLVRTrainer
from sagemaker.ai_registry.evaluator import Evaluator
# Create custom evaluator for code correctness
evaluator = Evaluator.create(
name="code-correctness-evaluator",
type="RewardFunction",
source="./reward_function.py", # Local file uploaded to Lambda
wait=True
)
# Create RLVR trainer with custom rewards
trainer = RLVRTrainer(
model="codellama/CodeLlama-7b-hf",
model_package_group="rl-code-models",
custom_reward_function=evaluator,
training_dataset="s3://my-bucket/coding-tasks.jsonl",
accept_eula=True
)
# Configure training
trainer.hyperparameters.epochs = 5
trainer.hyperparameters.learning_rate = 1e-5
# Train with verifiable rewards
try:
job = trainer.train(wait=True)
print(f"RLVR training completed: {job.training_job_name}")
except ValueError as e:
print(f"Reward function error: {e}")Example Reward Function (reward_function.py):
import json
def lambda_handler(event, context):
"""
Reward function for code correctness.
Args:
event: Dict with 'prompt' and 'response' fields
Returns:
Dict with 'reward' field (float between 0 and 1)
"""
prompt = event['prompt']
response = event['response']
    # Extract test cases from the prompt; extract_test_cases and execute_code
    # are user-defined helpers (not shown here)
    test_cases = extract_test_cases(prompt)
# Execute generated code with test cases
    try:
        passed = 0
        for test_input, expected_output in test_cases:
            actual_output = execute_code(response, test_input)
            if actual_output == expected_output:
                passed += 1
        # Guard against prompts with no extractable test cases
        reward = passed / len(test_cases) if test_cases else 0.0
    except Exception:
        # Syntax errors or runtime failures get 0 reward
        reward = 0.0
return {
'reward': reward
}
Dynamic hyperparameters class with validation for fine-tuning trainers.
class FineTuningOptions:
"""
Dynamic fine-tuning hyperparameters with validation.
Parameters:
options_dict: Dict[str, Any] - Dictionary of options with specifications
Structure: {param_name: {type, default, min, max, enum}}
Methods:
to_dict() -> Dict[str, Any]
Convert to dictionary with string values for API calls.
Returns:
Dict[str, Any]: Hyperparameters as string-valued dictionary
get_info(param_name=None) -> None
Display parameter information and valid ranges.
Parameters:
param_name: Optional[str] - Specific parameter (shows all if None)
Attributes:
_specs: Dict - Parameter specifications (type, default, min, max, enum)
_initialized: bool - Initialization flag for attribute access control
Notes:
- Hyperparameters become dynamic attributes based on model recipe
- Validates types (float, integer, string) when setting values
- Validates ranges (min/max) when setting numeric values
- Validates enums when setting categorical values
- Access via: trainer.hyperparameters.parameter_name
- Setting invalid value raises ValueError with helpful message
Example:
trainer.hyperparameters.epochs = 5 # Valid
trainer.hyperparameters.learning_rate = 0.001 # Valid
trainer.hyperparameters.epochs = -1 # Raises ValueError
"""Common Fine-Tuning Hyperparameters:
# View available hyperparameters
trainer.hyperparameters.get_info()
# Training configuration
trainer.hyperparameters.epochs = 3 # Number of training epochs
trainer.hyperparameters.learning_rate = 2e-4 # Learning rate
trainer.hyperparameters.warmup_ratio = 0.1 # Warmup ratio
trainer.hyperparameters.weight_decay = 0.01 # Weight decay for regularization
# Batch size configuration
trainer.hyperparameters.per_device_train_batch_size = 4 # Batch size per device
trainer.hyperparameters.per_device_eval_batch_size = 8 # Eval batch size
trainer.hyperparameters.gradient_accumulation_steps = 4 # Effective batch = 4*4
# LoRA-specific (only for TrainingType.LORA)
trainer.hyperparameters.lora_r = 8 # LoRA rank (higher = more capacity)
trainer.hyperparameters.lora_alpha = 16 # LoRA scaling (typically 2*lora_r)
trainer.hyperparameters.lora_dropout = 0.05 # LoRA dropout
trainer.hyperparameters.target_modules = ["q_proj", "v_proj"] # Layers to adapt
# Optimization
trainer.hyperparameters.optimizer = "adamw" # Optimizer type
trainer.hyperparameters.max_grad_norm = 1.0 # Gradient clipping
trainer.hyperparameters.lr_scheduler_type = "cosine" # LR schedule
# Logging and saving
trainer.hyperparameters.logging_steps = 10 # Log every N steps
trainer.hyperparameters.save_steps = 500 # Save checkpoint every N steps
trainer.hyperparameters.eval_steps = 100 # Evaluate every N steps
# Check parameter constraints
try:
trainer.hyperparameters.epochs = 100 # May exceed max
except ValueError as e:
print(f"Invalid value: {e}")
trainer.hyperparameters.get_info("epochs") # Show valid rangeConfiguration classes for distributed training with torchrun, MPI, and SageMaker Model Parallelism.
class Torchrun:
"""
Configuration for torchrun-based distributed training.
Parameters:
process_count_per_node: Optional[int] - Number of processes per node
- Typically set to number of GPUs per instance
- Default: auto-detect from instance type
smp: Optional[SMP] - SageMaker Model Parallelism v2 configuration
- Required for model parallelism or tensor parallelism
Attributes:
driver_dir: str - Driver directory path (default: "/opt/ml/code")
driver_script: str - Driver script name
Notes:
- Use for PyTorch distributed data parallel training
- Automatically configures NCCL for GPU communication
- Works with PyTorch DistributedDataParallel (DDP)
"""class MPI:
"""
Configuration for MPI-based distributed training.
Parameters:
process_count_per_node: Optional[int] - Number of processes per node
- Typically equals number of GPUs per instance
mpi_additional_options: Optional[List[str]] - Additional MPI options
- Example: ["-x", "NCCL_DEBUG=INFO"]
Attributes:
driver_dir: str - Driver directory path
driver_script: str - Driver script name
Notes:
- Use for MPI-based frameworks (Horovod, etc.)
- Provides more control over process placement
- Useful for hybrid CPU/GPU workloads
"""class SMP:
"""
SageMaker Model Parallelism v2 configuration.
Parameters:
hybrid_shard_degree: Optional[int] - Hybrid sharding degree
- Number of model shards
- Must evenly divide total GPU count
- Default: 1 (no sharding)
sm_activation_offloading: Optional[bool] - Enable activation offloading
- Offload activations to CPU to save GPU memory
- Trades compute for memory
- Default: False
activation_loading_horizon: Optional[int] - Activation loading horizon
- Prefetch window for activation loading
- Default: 4
fsdp_cache_flush_warnings: Optional[bool] - FSDP cache flush warnings
- Enable warnings for FSDP cache issues
- Default: False
allow_empty_shards: Optional[bool] - Allow empty shards
- Permit shards with no parameters
- Default: False
tensor_parallel_degree: Optional[int] - Tensor parallelism degree
- Number of ways to split tensors
- Must evenly divide GPU count
- Default: 1 (no tensor parallelism)
context_parallel_degree: Optional[int] - Context parallelism degree
- Split sequence dimension for long context
- Default: 1
expert_parallel_degree: Optional[int] - Expert parallelism degree
- For Mixture of Experts (MoE) models
- Default: 1
random_seed: Optional[int] - Random seed for reproducibility
- Range: 0-2147483647
- Default: 12345
Notes:
- Use for training large models that don't fit in single GPU memory
- hybrid_shard_degree: splits model weights across GPUs
- tensor_parallel_degree: splits individual tensors across GPUs
- Combine parallelism types for very large models
- Total parallelism: hybrid_shard × tensor_parallel × context_parallel
- Must evenly divide total GPU count
"""Usage:
from sagemaker.train import ModelTrainer
from sagemaker.train.distributed import Torchrun, SMP
from sagemaker.train.configs import Compute
# Configure distributed training with model parallelism
distributed_config = Torchrun(
process_count_per_node=8, # 8 GPUs per node
smp=SMP(
hybrid_shard_degree=4, # Shard model across 4 GPUs
tensor_parallel_degree=2, # Additional tensor parallelism
sm_activation_offloading=True, # Save memory
random_seed=42
)
)
# Create trainer with distributed config
trainer = ModelTrainer(
training_image="pytorch-training-image",
distributed=distributed_config,
compute=Compute(
instance_type="ml.p4d.24xlarge", # 8 A100 GPUs
instance_count=2, # 16 total GPUs
volume_size_in_gb=500
),
role=role
)
# Distributed configuration automatically applied
trainer.train(input_data_config=[train_data])
Data Parallel Only:
# Simple data parallel (no model parallelism)
distributed_config = Torchrun(
process_count_per_node=8 # 8-way data parallelism
)
trainer = ModelTrainer(
training_image="pytorch-image",
distributed=distributed_config,
compute=Compute(
instance_type="ml.p3.16xlarge",
instance_count=4 # 32 total GPUs for data parallelism
)
)
MPI-Based Training:
from sagemaker.train.distributed import MPI
# Configure MPI for Horovod
mpi_config = MPI(
process_count_per_node=8,
mpi_additional_options=[
"-x", "NCCL_DEBUG=INFO",
"-x", "HOROVOD_TIMELINE=timeline.json"
]
)
trainer = ModelTrainer(
training_image="horovod-training-image",
distributed=mpi_config,
compute=Compute(
instance_type="ml.p3.16xlarge",
instance_count=2
)
)
class SourceCode:
"""
Source code configuration for training.
Fields:
source_dir: Optional[str] - Directory containing source code
- Local directory path
- Contents uploaded to S3 automatically
- Extracted to /opt/ml/code in container
entry_script: Optional[str] - Entry point script
- Relative to source_dir
- Example: "train.py"
- Script receives hyperparameters as arguments
command: Optional[str] - Command to execute
- Overrides entry_script if provided
- Example: "python train.py --custom-args"
requirements: Optional[str] - Requirements file path
- Relative to source_dir
- Example: "requirements.txt"
- Installed via pip before training
ignore_patterns: Optional[List[str]] - Patterns to ignore during upload
- Glob patterns
- Example: ["*.pyc", "__pycache__", ".git"]
Notes:
- Either entry_script or command must be provided
- requirements.txt automatically installed if present in source_dir
- Large directories may slow uploads - use ignore_patterns
- Source code uploaded to S3 and cached for subsequent runs
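Example (a sketch built only from the fields documented above; paths are illustrative):
    from sagemaker.train.configs import SourceCode
    # Entry-script style: SageMaker runs train.py with hyperparameters as arguments
    src = SourceCode(
        source_dir="./code",
        entry_script="train.py",
        requirements="requirements.txt",
        ignore_patterns=["*.pyc", "__pycache__", ".git"],
    )
    # Command style: an explicit command overrides entry_script
    src_cmd = SourceCode(
        source_dir="./code",
        command="python train.py --custom-args",
    )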
"""class Compute:
"""
Compute resource configuration.
Fields:
instance_type: str - EC2 instance type (required)
- Training: ml.m5.xlarge, ml.p3.2xlarge, ml.p4d.24xlarge, etc.
- Check SageMaker documentation for full list
instance_count: int - Number of instances (default: 1)
- Range: 1-100
- Use >1 for distributed training
volume_size_in_gb: int - EBS volume size (default: 30)
- Range: 1-16384 GB
- Must be large enough for training data + model
volume_kms_key_id: Optional[str] - KMS key for volume encryption
- Format: "arn:aws:kms:region:account:key/key-id"
enable_managed_spot_training: Optional[bool] - Use managed spot training (default: False)
- Can reduce costs by up to 90%
- Training may be interrupted
- Requires checkpoint_config for resumption
keep_alive_period_in_seconds: Optional[int] - Keep alive period for warm pools (default: 0)
- Range: 0-3600 seconds
- Reduces startup time for repeated training
- Instances billed while warm
Notes:
- Spot training requires max_wait_time_in_seconds > max_runtime_in_seconds
- Warm pools incur costs during keep-alive period
- Choose instance type based on:
- CPU-bound: ml.m5, ml.c5
- GPU training: ml.p3, ml.p4d
- Memory-intensive: ml.r5
- Multi-instance requires distributed training configuration
"""class StoppingCondition:
"""
Training stopping criteria.
Fields:
max_runtime_in_seconds: Optional[int] - Maximum training runtime (default: 86400)
- Range: 1-2419200 (28 days)
- Training stopped if exceeded
- Billable time capped at this value
max_pending_time_in_seconds: Optional[int] - Maximum pending time (default: unlimited)
- Time waiting for resources
- Training fails if exceeded
max_wait_time_in_seconds: Optional[int] - Maximum wait time for spot (default: unlimited)
- Only for managed spot training
- Must be >= max_runtime_in_seconds
- Includes time waiting for spot capacity + training time
Notes:
- Always set reasonable max_runtime to control costs
- For spot: max_wait_time >= max_runtime
- Exceeded limits result in training job failure
- Use checkpoints to resume interrupted spot training
"""class Networking:
"""
Network configuration for training.
Fields:
subnets: Optional[List[str]] - VPC subnet IDs
- Format: ["subnet-xxx", "subnet-yyy"]
- Use multiple subnets across AZs for availability
security_group_ids: Optional[List[str]] - Security group IDs
- Format: ["sg-xxx"]
- Must allow inter-instance communication for distributed training
enable_network_isolation: Optional[bool] - Enable network isolation (default: False)
- Blocks all network access except S3/ECR
- No internet access
- Use for compliance requirements
enable_inter_container_traffic_encryption: Optional[bool] - Enable encryption (default: False)
- Encrypts traffic between training instances
- Required for distributed training with network isolation
- Performance impact for high-bandwidth workloads
Notes:
- VPC configuration required for private data sources
- Security groups must allow:
- Ingress: all traffic from same security group (distributed training)
- Egress: HTTPS to S3/ECR
- Network isolation incompatible with custom VPC endpoints
- Encryption adds latency to inter-instance communication
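Example (a sketch of a locked-down configuration; all IDs are placeholders):
    from sagemaker.train.configs import Networking
    net = Networking(
        subnets=["subnet-aaa", "subnet-bbb"],  # spread across AZs for availability
        security_group_ids=["sg-ccc"],         # must allow intra-group traffic
        enable_network_isolation=True,         # no internet; only S3/ECR reachable
        enable_inter_container_traffic_encryption=True,
    )
    # Pass to ModelTrainer(networking=net)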
"""class InputData:
"""
Simplified input data configuration.
Fields:
channel_name: str - Channel name (required)
- Examples: "training", "validation", "test"
- Accessible in container at /opt/ml/input/data/{channel_name}
data_source: Union[str, DataSource] - S3 URI or DataSource object (required)
- S3 URI: "s3://bucket/prefix"
- DataSource object for advanced configuration
content_type: Optional[str] - Content type
- Examples: "application/json", "text/csv", "image/jpeg"
- Passed to training container
Notes:
- Multiple channels can be specified
- S3 data must be in same region as training job
- Automatic download to local storage by default
- Use Channel for advanced options (compression, input mode)
"""class Channel:
"""
Full input data channel configuration.
Fields:
channel_name: str - Channel name (required)
data_source: DataSource - Data source configuration (required)
- S3 data source with URI and distribution settings
content_type: Optional[str] - Content type
- MIME type of data
compression_type: Optional[str] - Compression type
- "None" or "Gzip"
- Default: "None"
input_mode: Optional[str] - Input mode
- "File": Download all data before training (default)
- "Pipe": Stream data during training (for large datasets)
Notes:
- Pipe mode reduces training start time for large datasets
- Pipe mode requires training code to read from pipe
- File mode simpler but slower for large data
- Gzip decompressed automatically
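Example (a sketch of Pipe-mode streaming; the DataSource object is a placeholder
since its exact shape is not defined in this section, and the import path is
assumed to match the other config classes):
    from sagemaker.train.configs import Channel
    channel = Channel(
        channel_name="training",
        data_source=my_s3_data_source,  # DataSource built for s3://my-bucket/train
        content_type="text/csv",
        compression_type="Gzip",        # decompressed automatically
        input_mode="Pipe",              # stream during training instead of pre-downloading
    )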
"""class OutputDataConfig:
"""
Output data location configuration for training jobs.
Fields:
s3_output_path: Optional[str] - S3 URI where output data will be stored
- Format: "s3://bucket/prefix"
- Default: s3://sagemaker-{region}-{account}/output
- Model artifacts saved to {s3_output_path}/{job-name}/output/model.tar.gz
kms_key_id: Optional[str] - KMS key for encrypting model artifacts
- Format: "arn:aws:kms:region:account:key/key-id"
- Applied to model artifacts and training output
compression_type: Optional[str] - Model output compression type
- "None" or "Gzip"
- Default: "Gzip"
Notes:
- Output bucket must be in same region as training job
- Execution role needs s3:PutObject permission
- Encryption key must allow encrypt/decrypt from execution role
- Compressed output reduces storage costs and transfer time
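Example (a sketch; the bucket and KMS key ARN are placeholders):
    from sagemaker.train.configs import OutputDataConfig
    output_config = OutputDataConfig(
        s3_output_path="s3://my-bucket/training-output",
        kms_key_id="arn:aws:kms:us-west-2:123456789012:key/example-key-id",
        compression_type="Gzip",
    )
    # Pass to ModelTrainer(output_data_config=output_config)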
"""class CheckpointConfig:
"""
Checkpoint configuration for training jobs.
Fields:
s3_uri: Optional[str] - S3 path for checkpoint data (required)
- Format: "s3://bucket/prefix/checkpoints"
- Checkpoints automatically saved here during training
local_path: Optional[str] - Local directory for checkpoints
- Default: "/opt/ml/checkpoints"
- Training code should save checkpoints here
- Automatically synced to S3
Notes:
- Essential for spot training to resume after interruption
- Training code must implement checkpoint save/load
- Checkpoints synced to S3 asynchronously
- Use for long-running training jobs (>1 hour)
- Execution role needs s3:PutObject permission on s3_uri
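Example (a minimal PyTorch-style save/resume pattern inside the training script;
model, train_one_epoch, and num_epochs stand in for your own objects - the SDK
only syncs local_path to S3, so save/load logic lives in your code):
    import os
    import torch
    CKPT = "/opt/ml/checkpoints/state.pt"   # under the default local_path
    start_epoch = 0
    if os.path.exists(CKPT):                # resuming after a spot interruption
        state = torch.load(CKPT)
        model.load_state_dict(state["model"])
        start_epoch = state["epoch"] + 1
    for epoch in range(start_epoch, num_epochs):
        train_one_epoch(model)
        os.makedirs(os.path.dirname(CKPT), exist_ok=True)
        torch.save({"model": model.state_dict(), "epoch": epoch}, CKPT)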
"""class TensorBoardOutputConfig:
"""
Storage locations for TensorBoard output.
Fields:
s3_output_path: Optional[str] - S3 path for TensorBoard output (required)
- Format: "s3://bucket/prefix/tensorboard"
- TensorBoard logs automatically uploaded
local_path: Optional[str] - Local path for TensorBoard output
- Default: "/opt/ml/output/tensorboard"
- Training code should write TensorBoard logs here
Notes:
- Use with: trainer.with_tensorboard_output_config()
- View logs: tensorboard --logdir=s3://bucket/prefix/tensorboard
- Logs synced to S3 during and after training
- Requires TensorBoard in training container
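Example (a sketch using the documented with_tensorboard_output_config() hook;
the import path is assumed to match the other config classes):
    from sagemaker.train.configs import TensorBoardOutputConfig
    tb_config = TensorBoardOutputConfig(
        s3_output_path="s3://my-bucket/tensorboard",
        local_path="/opt/ml/output/tensorboard",
    )
    trainer.with_tensorboard_output_config(tb_config)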
"""class TrainingImageConfig:
"""
Configuration for training container image.
Fields:
training_repository_access_mode: str - Access mode (required)
- "Platform": Use SageMaker-managed images
- "Vpc": Use private registry via VPC endpoint
training_repository_auth_config: Optional[TrainingRepositoryAuthConfig]
- Authentication for private registries
- Required if using private registry
Notes:
- Platform mode: no additional configuration needed
- VPC mode: requires VPC endpoint to private registry
- Private registries must support Docker Registry HTTP API V2
"""class RetryStrategy:
"""
Retry strategy for training jobs.
Fields:
maximum_retry_attempts: int - Maximum number of retry attempts (required)
- Range: 1-10
- Retries for infrastructure failures only
- Not for algorithmic errors
Notes:
- Use with: trainer.with_retry_strategy()
- Retries automatically for:
- InternalServerError
- CapacityError (insufficient instance capacity)
- ThrottlingException
- Does not retry for:
- AlgorithmError (code errors)
- ValidationException (configuration errors)
- Each retry counted as separate training job for billing
"""class InfraCheckConfig:
"""
Infrastructure health check configuration.
Fields:
enable_infra_check: bool - Enable infrastructure checks before training (required)
- Validates network connectivity, instance health
- Adds ~2-3 minutes to startup time
- Recommended for production workloads
Notes:
- Catches infrastructure issues before training starts
- Reduces wasted time on failing training jobs
- Checks: network connectivity, volume mounts, GPU availability
"""class RemoteDebugConfig:
"""
Remote debugging configuration for training jobs.
Fields:
enable_remote_debug: bool - Enable remote debugging (required)
- Allows SSH access to training instance
- Requires additional setup in training script
Notes:
- Use for debugging complex training issues
- Requires SSH key configuration
- Security consideration: opens SSH port
- Only for development, not production
"""class SessionChainingConfig:
"""
Session chaining configuration for training jobs.
Fields:
enable_session_tag_chaining: bool - Enable session tag chaining (required)
- Propagates session tags to created resources
- For compliance and cost tracking
Notes:
- Tags from session automatically applied to:
- Training jobs
- Model artifacts
- Endpoints (if created from this training)
- Useful for automated tagging policies
"""class MetricDefinition:
"""
Definition for extracting metrics from training job logs.
Fields:
name: str - Metric name (required)
- Example: "train:loss"
- Appears in CloudWatch and training job details
regex: str - Regex pattern to extract metric from logs (required)
- Must contain exactly one capture group
- Capture group must match a number
- Example: "Train Loss: ([0-9\\.]+)"
Notes:
- Metrics extracted from CloudWatch logs
- Use with: trainer.with_metric_definitions()
- Captured metrics available for:
- Hyperparameter tuning objectives
- CloudWatch metrics and alarms
- Training job summaries
- Regex must match log output format exactly
- Common regex patterns:
- Float: ([0-9\\.]+)
- Scientific notation: ([0-9\\.e\\-\\+]+)
- Percentage: ([0-9\\.]+)%
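Example (a sketch pairing a log line emitted by the training script with the
definition that captures it; import path assumed to match the other configs):
    # Training script prints lines like: "Train Loss: 0.4321"
    from sagemaker.train.configs import MetricDefinition
    metric = MetricDefinition(
        name="train:loss",
        regex="Train Loss: ([0-9\\.]+)",  # exactly one numeric capture group
    )
    trainer.with_metric_definitions([metric])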
"""class Mode(Enum):
"""
Training mode for ModelTrainer.
Values:
LOCAL_CONTAINER = 1
- Run training in a local Docker container
- For testing before submitting to SageMaker
- Requires Docker installed locally
SAGEMAKER_TRAINING_JOB = 2
- Run training as a SageMaker training job (default)
- Uses SageMaker managed infrastructure
- Billed per second of usage
Notes:
- LOCAL_CONTAINER useful for:
- Debugging training code
- Testing without SageMaker costs
- Validating container configuration
- SAGEMAKER_TRAINING_JOB provides:
- Scalable infrastructure
- Distributed training
- Spot instances
- Automatic checkpointing to S3
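Example (a sketch of local testing before submitting a real job; assumes Mode is
importable from sagemaker.train alongside ModelTrainer, and source_code and the
image name are placeholders):
    from sagemaker.train import ModelTrainer, Mode
    trainer = ModelTrainer(
        training_mode=Mode.LOCAL_CONTAINER,   # runs in a local Docker container
        training_image="my-training-image",
        source_code=source_code,
        local_container_root="./local_runs",  # where local run artifacts land
    )
    trainer.train()  # no SageMaker training job is created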
"""class TrainingType(Enum):
"""
Fine-tuning training type.
Values:
LORA = "lora"
- Low-Rank Adaptation (parameter-efficient fine-tuning)
- Updates only small adapter layers (~1% of parameters)
- 5-10x faster and cheaper than full fine-tuning
- Good quality for most use cases
- Recommended for initial experiments
FULL = "full"
- Full model fine-tuning (all parameters updated)
- Maximum quality for complex adaptations
- Requires more compute and memory
- Use when LORA quality insufficient
- Better for significant domain shifts
Notes:
- LORA advantages:
- Lower compute costs
- Faster training
- Less memory required
- Multiple adapters from same base model
- FULL advantages:
- Maximum quality
- Better for complex tasks
- No architectural constraints
- Start with LORA, upgrade to FULL if needed
"""class CustomizationTechnique(Enum):
"""
Model customization technique.
Values:
SFT = "sft"
- Supervised Fine-Tuning
- Uses labeled instruction-response pairs
- Best for adapting to new tasks
DPO = "dpo"
- Direct Preference Optimization
- Uses preference pairs (chosen vs rejected)
- Aligns model with human preferences
- More stable than RLHF
RLAIF = "rlaif"
- Reinforcement Learning from AI Feedback
- Uses AI model for reward signals
- Scales better than human feedback
RLVR = "rlvr"
- Reinforcement Learning from Verifiable Rewards
- Uses programmatic reward functions
- Best for objective metrics (code, math)
Notes:
- SFT: Start here for task adaptation
- DPO: Use after SFT for alignment
- RLAIF: Alternative to RLHF with AI rewards
- RLVR: For tasks with objective evaluation
"""Parameter range classes for defining hyperparameter search spaces in tuning jobs.
class ParameterRange:
"""
Base class for parameter ranges.
Abstract base class for defining hyperparameter ranges in tuning jobs.
Subclasses: ContinuousParameter, IntegerParameter, CategoricalParameter
"""class ContinuousParameter(ParameterRange):
"""
Continuous parameter range for floating-point hyperparameters.
Parameters:
min_value: float - Minimum value (inclusive) (required)
max_value: float - Maximum value (inclusive) (required)
scaling_type: str - Scaling type (default: "Auto")
- "Auto": Automatically choose scaling
- "Linear": Linear scaling
- "Logarithmic": Log scaling (for learning rates, etc.)
- "ReverseLogarithmic": Reverse log scaling
Notes:
- Use Logarithmic for parameters spanning orders of magnitude
- Linear for parameters in similar scale
- Scaling affects sampling strategy in Bayesian optimization
"""class IntegerParameter(ParameterRange):
"""
Integer parameter range for integer hyperparameters.
Parameters:
min_value: int - Minimum value (inclusive) (required)
max_value: int - Maximum value (inclusive) (required)
scaling_type: str - Scaling type (default: "Auto")
- "Auto", "Linear", "Logarithmic", "ReverseLogarithmic"
Notes:
- Values always integers even with log scaling
- Use for batch sizes, layers, etc.
"""class CategoricalParameter(ParameterRange):
"""
Categorical parameter range for discrete hyperparameters.
Parameters:
values: List[Union[str, int, float]] - List of possible values (required)
- Examples: ["adam", "sgd"], [32, 64, 128]
Notes:
- No scaling applied (categorical)
- Values must be JSON-serializable
- Order doesn't matter for Bayesian optimization
"""Usage:
from sagemaker.core.parameter import (
ContinuousParameter, IntegerParameter, CategoricalParameter
)
# Define parameter ranges
hyperparameter_ranges = {
# Learning rate with log scaling (spans orders of magnitude)
"learning_rate": ContinuousParameter(
min_value=0.001,
max_value=0.1,
scaling_type="Logarithmic"
),
# Batch size with linear scaling
"batch_size": IntegerParameter(
min_value=32,
max_value=256,
scaling_type="Linear"
),
# Optimizer choices
"optimizer": CategoricalParameter(
values=["adam", "sgd", "rmsprop", "adamw"]
),
# Number of layers
"num_layers": IntegerParameter(
min_value=2,
max_value=10
),
# Dropout rate
"dropout": ContinuousParameter(
min_value=0.0,
max_value=0.5
)
}
class Session:
"""
SageMaker session for managing API interactions.
Main class for managing interactions with SageMaker APIs and AWS services.
Handles authentication, region configuration, and service clients.
Parameters:
boto_session: Optional[boto3.Session] - Boto3 session for AWS credentials
region_name: Optional[str] - AWS region (default: from boto session)
default_bucket: Optional[str] - Default S3 bucket for SageMaker resources
Methods:
get_execution_role() -> str: Get IAM role ARN
default_bucket() -> str: Get default S3 bucket
upload_data(path, bucket, key_prefix) -> str: Upload data to S3
download_data(path, bucket, key_prefix): Download data from S3
Notes:
- Automatically created if not provided to ModelTrainer
- Reuse session across multiple trainers for consistency
- Default bucket: sagemaker-{region}-{account_id}
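Example (a sketch of one session shared across trainers; the region, bucket,
role, and compute object are illustrative):
    import boto3
    from sagemaker.train import ModelTrainer, Session
    session = Session(
        boto_session=boto3.Session(region_name="us-west-2"),
        default_bucket="my-sagemaker-bucket",
    )
    s3_uri = session.upload_data("./data", "my-sagemaker-bucket", "inputs")
    trainer = ModelTrainer(
        training_image="my-training-image",
        role="arn:aws:iam::123456789012:role/SageMakerRole",
        compute=compute,            # a Compute config defined elsewhere
        sagemaker_session=session,  # reuse for consistent region and bucket
    )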
"""def get_execution_role() -> str:
"""
Get the execution role ARN for SageMaker.
Returns:
str: IAM role ARN for SageMaker execution
Format: "arn:aws:iam::{account}:role/{role-name}"
Raises:
ValueError: If role cannot be determined from environment
- Not running in SageMaker notebook instance
- IAM role not configured in environment
- No default execution role available
Notes:
- Automatically detects role when running in:
- SageMaker notebook instances
- SageMaker Studio
- SageMaker Processing jobs
- SageMaker Training jobs
- Outside SageMaker, explicitly provide role ARN to ModelTrainer
Usage:
# In SageMaker environment
role = get_execution_role()
# Outside SageMaker
role = "arn:aws:iam::123456789012:role/SageMakerRole"
"""ResourceLimitExceeded:
ValidationException: Invalid hyperparameter:
S3 Access Denied:
Spot Instance Interruption:
AlgorithmError:
Invalid Gated Model Access: