State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Comprehensive training framework with built-in optimization, distributed training support, logging, evaluation, and extensive customization options. The Trainer provides a high-level interface for fine-tuning transformer models while supporting advanced features like gradient accumulation, mixed precision, and custom training loops.
Main training class that handles the complete training loop with automatic optimization, logging, and evaluation.
class Trainer:
def __init__(
self,
model: PreTrainedModel = None,
args: TrainingArguments = None,
data_collator: DataCollator = None,
train_dataset: Dataset = None,
eval_dataset: Union[Dataset, Dict[str, Dataset]] = None,
processing_class: Union[PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] = None,
model_init: Callable[[], PreTrainedModel] = None,
compute_metrics: Callable[[EvalPrediction], Dict] = None,
callbacks: List[TrainerCallback] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None
):
"""
Initialize trainer with model, training arguments, and datasets.
Args:
model: Model to train
args: Training configuration
data_collator: Collates batch data
train_dataset: Training dataset
eval_dataset: Evaluation dataset(s)
processing_class: Processing class (tokenizer, image processor, etc.) for the model
model_init: Function that instantiates a fresh model (used for hyperparameter search)
compute_metrics: Function to compute evaluation metrics
callbacks: List of training callbacks
optimizers: Custom optimizer and scheduler tuple
preprocess_logits_for_metrics: Preprocess logits before metrics
"""
def train(
self,
resume_from_checkpoint: Union[str, bool] = None,
trial: Union[optuna.Trial, Dict[str, Any]] = None,
ignore_keys_for_eval: List[str] = None,
**kwargs
) -> TrainOutput:
"""
Start training process.
Args:
resume_from_checkpoint: Path to checkpoint or True for latest
trial: Optuna trial for hyperparameter optimization
ignore_keys_for_eval: Keys to ignore during evaluation
Returns:
Training output with metrics and statistics
"""
def evaluate(
self,
eval_dataset: Dataset = None,
ignore_keys: List[str] = None,
metric_key_prefix: str = "eval"
) -> Dict[str, float]:
"""
Evaluate model on evaluation dataset.
Args:
eval_dataset: Dataset to evaluate on (uses default if None)
ignore_keys: Keys to ignore in output
metric_key_prefix: Prefix for metric names
Returns:
Dictionary of evaluation metrics
"""
def predict(
self,
test_dataset: Dataset,
ignore_keys: List[str] = None,
metric_key_prefix: str = "test"
) -> PredictionOutput:
"""
Make predictions on test dataset.
Args:
test_dataset: Dataset to predict on
ignore_keys: Keys to ignore in output
metric_key_prefix: Prefix for metric names
Returns:
Predictions with metrics and labels
"""
def save_model(
self,
output_dir: str = None,
_internal_call: bool = False
) -> None:
"""Save model and tokenizer to directory."""
def save_state(self) -> None:
"""Save trainer state for resuming training."""
def log(self, logs: Dict[str, float]) -> None:
"""Log metrics and values."""
def create_optimizer_and_scheduler(
self,
num_training_steps: int
) -> None:
"""Create optimizer and learning rate scheduler."""Comprehensive configuration class for all training hyperparameters and settings.
class TrainingArguments:
def __init__(
self,
output_dir: str,
overwrite_output_dir: bool = False,
do_train: bool = False,
do_eval: bool = False,
do_predict: bool = False,
evaluation_strategy: Union[IntervalStrategy, str] = "no",
prediction_loss_only: bool = False,
per_device_train_batch_size: int = 8,
per_device_eval_batch_size: int = 8,
per_gpu_train_batch_size: Optional[int] = None,
per_gpu_eval_batch_size: Optional[int] = None,
gradient_accumulation_steps: int = 1,
eval_accumulation_steps: Optional[int] = None,
eval_delay: Optional[float] = 0,
learning_rate: float = 5e-5,
weight_decay: float = 0.0,
adam_beta1: float = 0.9,
adam_beta2: float = 0.999,
adam_epsilon: float = 1e-8,
max_grad_norm: float = 1.0,
num_train_epochs: float = 3.0,
max_steps: int = -1,
lr_scheduler_type: Union[SchedulerType, str] = "linear",
warmup_ratio: float = 0.0,
warmup_steps: int = 0,
log_level: Optional[str] = "passive",
log_level_replica: Optional[str] = "warning",
log_on_each_node: bool = True,
logging_dir: Optional[str] = None,
logging_strategy: Union[IntervalStrategy, str] = "steps",
logging_first_step: bool = False,
logging_steps: int = 500,
logging_nan_inf_filter: bool = True,
save_strategy: Union[IntervalStrategy, str] = "steps",
save_steps: int = 500,
save_total_limit: Optional[int] = None,
save_safetensors: Optional[bool] = True,
save_on_each_node: bool = False,
no_cuda: bool = False,
use_cpu: bool = False,
use_mps_device: bool = False,
seed: int = 42,
data_seed: Optional[int] = None,
jit_mode_eval: bool = False,
use_ipex: bool = False,
bf16: bool = False,
fp16: bool = False,
fp16_opt_level: str = "O1",
half_precision_backend: str = "auto",
bf16_full_eval: bool = False,
fp16_full_eval: bool = False,
tf32: Optional[bool] = None,
local_rank: int = -1,
ddp_backend: Optional[str] = None,
ddp_timeout: Optional[int] = 1800,
ddp_find_unused_parameters: Optional[bool] = None,
ddp_bucket_cap_mb: Optional[int] = None,
ddp_broadcast_buffers: Optional[bool] = None,
dataloader_pin_memory: bool = True,
dataloader_num_workers: int = 0,
past_index: int = -1,
run_name: Optional[str] = None,
disable_tqdm: Optional[bool] = None,
remove_unused_columns: bool = True,
label_names: Optional[List[str]] = None,
load_best_model_at_end: Optional[bool] = False,
metric_for_best_model: Optional[str] = None,
greater_is_better: Optional[bool] = None,
ignore_data_skip: bool = False,
sharded_ddp: str = "",
fsdp: str = "",
fsdp_min_num_params: int = 0,
fsdp_config: Optional[str] = None,
fsdp_transformer_layer_cls_to_wrap: Optional[str] = None,
deepspeed: Optional[str] = None,
label_smoothing_factor: float = 0.0,
optim: Union[OptimizerNames, str] = "adamw_torch",
optim_args: Optional[str] = None,
adafactor: bool = False,
group_by_length: bool = False,
length_column_name: Optional[str] = "length",
report_to: Optional[List[str]] = None,
skip_memory_metrics: bool = True,
use_legacy_prediction_loop: bool = False,
push_to_hub: bool = False,
resume_from_checkpoint: Optional[str] = None,
hub_model_id: Optional[str] = None,
hub_strategy: Union[HubStrategy, str] = "every_save",
hub_token: Optional[str] = None,
hub_private_repo: bool = False,
hub_always_push: bool = False,
gradient_checkpointing: bool = False,
include_inputs_for_metrics: bool = False,
fp16_backend: str = "auto",
push_to_hub_model_id: Optional[str] = None,
push_to_hub_organization: Optional[str] = None,
push_to_hub_token: Optional[str] = None,
mp_parameters: str = "",
auto_find_batch_size: bool = False,
full_determinism: bool = False,
torchdynamo: Optional[str] = None,
ray_scope: Optional[str] = "last",
torch_compile: bool = False,
torch_compile_backend: Optional[str] = None,
torch_compile_mode: Optional[str] = None,
dispatch_batches: Optional[bool] = None,
split_batches: Optional[bool] = None,
include_tokens_per_second: Optional[bool] = False,
**kwargs
):
"""
Configure training parameters.
Key parameters:
output_dir: Directory to save model and checkpoints
num_train_epochs: Number of training epochs
per_device_train_batch_size: Batch size per device
learning_rate: Learning rate for optimization
weight_decay: Weight decay for regularization
warmup_steps: Linear warmup steps
logging_steps: Log every N steps
save_steps: Save checkpoint every N steps
evaluation_strategy: When to evaluate ("steps", "epoch", "no")
fp16: Enable mixed precision training
gradient_accumulation_steps: Accumulate gradients over N steps
dataloader_num_workers: Number of data loading workers
remove_unused_columns: Remove unused dataset columns
load_best_model_at_end: Load best model after training
metric_for_best_model: Metric to determine best model
push_to_hub: Upload model to Hugging Face Hub
"""Utilities for batching and preprocessing data during training.
class DataCollatorWithPadding:
def __init__(
self,
tokenizer: PreTrainedTokenizer,
padding: Union[bool, str] = True,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_tensors: str = "pt"
):
"""
Collator that pads sequences to the same length.
Args:
tokenizer: Tokenizer to use for padding
padding: Padding strategy
max_length: Maximum sequence length
pad_to_multiple_of: Pad to multiple of this value
return_tensors: Format of returned tensors
"""
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
"""Collate and pad batch of features."""
class DataCollatorForLanguageModeling:
def __init__(
self,
tokenizer: PreTrainedTokenizer,
mlm: bool = True,
mlm_probability: float = 0.15,
pad_to_multiple_of: Optional[int] = None,
tf_experimental_compile: bool = False,
return_tensors: str = "pt"
):
"""
Collator for language modeling tasks.
Args:
tokenizer: Tokenizer to use
mlm: Whether to use masked language modeling
mlm_probability: Probability of masking tokens
pad_to_multiple_of: Pad to multiple of this value
return_tensors: Format of returned tensors
"""
class DataCollatorForSeq2Seq:
def __init__(
self,
tokenizer: PreTrainedTokenizer,
model: Optional[PreTrainedModel] = None,
padding: Union[bool, str] = True,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
label_pad_token_id: int = -100,
return_tensors: str = "pt"
):
"""
Collator for sequence-to-sequence tasks.
Args:
tokenizer: Tokenizer to use
model: Model to get decoder start token
padding: Padding strategy
max_length: Maximum sequence length
label_pad_token_id: Token ID for padding labels
return_tensors: Format of returned tensors
"""Extensible callback system for customizing training behavior.
class TrainerCallback:
"""Base class for trainer callbacks."""
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Called at the end of trainer initialization."""
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Called at the beginning of training."""
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Called at the end of training."""
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Called at the beginning of each epoch."""
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Called at the end of each epoch."""
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Called at the beginning of each training step."""
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Called at the end of each training step."""
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Called after evaluation."""
class EarlyStoppingCallback(TrainerCallback):
def __init__(
self,
early_stopping_patience: int = 1,
early_stopping_threshold: Optional[float] = 0.0
):
"""
Callback for early stopping based on evaluation metrics.
Args:
early_stopping_patience: Number of evaluations to wait
early_stopping_threshold: Minimum improvement threshold
"""
class TensorBoardCallback(TrainerCallback):
"""Log training metrics to TensorBoard."""
class WandbCallback(TrainerCallback):
"""Log training metrics to Weights & Biases."""Learning rate schedulers and optimizers for effective training.
def get_scheduler(
name: Union[str, SchedulerType],
optimizer: torch.optim.Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
**kwargs
) -> torch.optim.lr_scheduler.LambdaLR:
"""
Create learning rate scheduler.
Args:
name: Scheduler type ("linear", "cosine", "polynomial", etc.)
optimizer: Optimizer to schedule
num_warmup_steps: Number of warmup steps
num_training_steps: Total training steps
Returns:
Configured scheduler
"""
def get_linear_schedule_with_warmup(
optimizer: torch.optim.Optimizer,
num_warmup_steps: int,
num_training_steps: int,
last_epoch: int = -1
) -> torch.optim.lr_scheduler.LambdaLR:
"""Linear schedule with linear warmup."""
def get_cosine_schedule_with_warmup(
optimizer: torch.optim.Optimizer,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float = 0.5,
last_epoch: int = -1
) -> torch.optim.lr_scheduler.LambdaLR:
"""Cosine schedule with linear warmup."""
class AdamW(torch.optim.Optimizer):
def __init__(
self,
params,
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.01,
correct_bias: bool = True
):
"""
AdamW optimizer with weight decay.
Args:
params: Model parameters
lr: Learning rate
betas: Adam beta parameters
eps: Epsilon for numerical stability
weight_decay: Weight decay coefficient
correct_bias: Apply bias correction
"""Structured outputs from training and evaluation methods.
class TrainOutput:
"""Output from training."""
global_step: int
training_loss: float
metrics: Dict[str, float]
class EvalPrediction:
"""Predictions and labels for evaluation."""
predictions: Union[np.ndarray, Tuple[np.ndarray]]
label_ids: Optional[np.ndarray]
inputs: Optional[np.ndarray]
class PredictionOutput:
"""Output from prediction."""
predictions: Union[np.ndarray, Tuple[np.ndarray]]
label_ids: Optional[np.ndarray]
metrics: Optional[Dict[str, float]]
class TrainerState:
"""Internal trainer state."""
epoch: Optional[float] = None
global_step: int = 0
max_steps: int = 0
logging_steps: int = 500
eval_steps: int = 500
save_steps: int = 500
train_batch_size: int = None
num_train_epochs: int = 0
total_flos: int = 0
log_history: List[Dict[str, float]] = None
best_metric: Optional[float] = None
best_model_checkpoint: Optional[str] = None
is_local_process_zero: bool = True
is_world_process_zero: bool = True
class TrainerControl:
"""Control flags for trainer behavior."""
should_training_stop: bool = False
should_epoch_stop: bool = False
should_save: bool = False
should_evaluate: bool = False
should_log: bool = False

Common training patterns and configurations:
# Basic fine-tuning setup (train_dataset, eval_dataset, and test_dataset are
# assumed to be tokenized datasets prepared beforehand)
import numpy as np
from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

def compute_metrics(eval_pred):
    # eval_pred.predictions holds logits; eval_pred.label_ids holds gold labels
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": float((preds == labels).mean())}

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)
# Start training
trainer.train()
# Evaluate final model
eval_results = trainer.evaluate()
# Make predictions on the held-out test set
predictions = trainer.predict(test_dataset)

Install with Tessl CLI
npx tessl i tessl/pypi-transformers