# Training

Comprehensive training framework with built-in optimization, distributed training support, logging, evaluation, and extensive customization options. The Trainer provides a high-level interface for fine-tuning transformer models while supporting advanced features like gradient accumulation, mixed precision, and custom training loops.

## Capabilities

### Trainer Class

Main training class that handles the complete training loop with automatic optimization, logging, and evaluation.
```python { .api }
class Trainer:
    def __init__(
        self,
        model: PreTrainedModel = None,
        args: TrainingArguments = None,
        data_collator: DataCollator = None,
        train_dataset: Dataset = None,
        eval_dataset: Union[Dataset, Dict[str, Dataset]] = None,
        processing_class: Union[PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Callable[[EvalPrediction], Dict] = None,
        callbacks: List[TrainerCallback] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None
    ):
        """
        Initialize trainer with model, training arguments, and datasets.

        Args:
            model: Model to train
            args: Training configuration
            data_collator: Collates batch data
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset(s)
            processing_class: Processing class (tokenizer, image processor, etc.) for the model
            model_init: Function that returns a fresh model instance (used for hyperparameter search)
            compute_metrics: Function to compute evaluation metrics
            callbacks: List of training callbacks
            optimizers: Custom optimizer and scheduler tuple
            preprocess_logits_for_metrics: Preprocess logits before metrics
        """

    def train(
        self,
        resume_from_checkpoint: Union[str, bool] = None,
        trial: Union[optuna.Trial, Dict[str, Any]] = None,
        ignore_keys_for_eval: List[str] = None,
        **kwargs
    ) -> TrainOutput:
        """
        Start training process.

        Args:
            resume_from_checkpoint: Path to checkpoint or True for latest
            trial: Optuna trial for hyperparameter optimization
            ignore_keys_for_eval: Keys to ignore during evaluation

        Returns:
            Training output with metrics and statistics
        """

    def evaluate(
        self,
        eval_dataset: Dataset = None,
        ignore_keys: List[str] = None,
        metric_key_prefix: str = "eval"
    ) -> Dict[str, float]:
        """
        Evaluate model on evaluation dataset.

        Args:
            eval_dataset: Dataset to evaluate on (uses default if None)
            ignore_keys: Keys to ignore in output
            metric_key_prefix: Prefix for metric names

        Returns:
            Dictionary of evaluation metrics
        """

    def predict(
        self,
        test_dataset: Dataset,
        ignore_keys: List[str] = None,
        metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Make predictions on test dataset.

        Args:
            test_dataset: Dataset to predict on
            ignore_keys: Keys to ignore in output
            metric_key_prefix: Prefix for metric names

        Returns:
            Predictions with metrics and labels
        """

    def save_model(
        self,
        output_dir: str = None,
        _internal_call: bool = False
    ) -> None:
        """Save model and tokenizer to directory."""

    def save_state(self) -> None:
        """Save trainer state for resuming training."""

    def log(self, logs: Dict[str, float]) -> None:
        """Log metrics and values."""

    def create_optimizer_and_scheduler(
        self,
        num_training_steps: int
    ) -> None:
        """Create optimizer and learning rate scheduler."""
```
### Training Arguments

Comprehensive configuration class for all training hyperparameters and settings.
```python { .api }
class TrainingArguments:
    def __init__(
        self,
        output_dir: str,
        overwrite_output_dir: bool = False,
        do_train: bool = False,
        do_eval: bool = False,
        do_predict: bool = False,
        evaluation_strategy: Union[IntervalStrategy, str] = "no",
        prediction_loss_only: bool = False,
        per_device_train_batch_size: int = 8,
        per_device_eval_batch_size: int = 8,
        per_gpu_train_batch_size: Optional[int] = None,
        per_gpu_eval_batch_size: Optional[int] = None,
        gradient_accumulation_steps: int = 1,
        eval_accumulation_steps: Optional[int] = None,
        eval_delay: Optional[float] = 0,
        learning_rate: float = 5e-5,
        weight_decay: float = 0.0,
        adam_beta1: float = 0.9,
        adam_beta2: float = 0.999,
        adam_epsilon: float = 1e-8,
        max_grad_norm: float = 1.0,
        num_train_epochs: float = 3.0,
        max_steps: int = -1,
        lr_scheduler_type: Union[SchedulerType, str] = "linear",
        warmup_ratio: float = 0.0,
        warmup_steps: int = 0,
        log_level: Optional[str] = "passive",
        log_level_replica: Optional[str] = "warning",
        log_on_each_node: bool = True,
        logging_dir: Optional[str] = None,
        logging_strategy: Union[IntervalStrategy, str] = "steps",
        logging_first_step: bool = False,
        logging_steps: int = 500,
        logging_nan_inf_filter: bool = True,
        save_strategy: Union[IntervalStrategy, str] = "steps",
        save_steps: int = 500,
        save_total_limit: Optional[int] = None,
        save_safetensors: Optional[bool] = True,
        save_on_each_node: bool = False,
        no_cuda: bool = False,
        use_cpu: bool = False,
        use_mps_device: bool = False,
        seed: int = 42,
        data_seed: Optional[int] = None,
        jit_mode_eval: bool = False,
        use_ipex: bool = False,
        bf16: bool = False,
        fp16: bool = False,
        fp16_opt_level: str = "O1",
        half_precision_backend: str = "auto",
        bf16_full_eval: bool = False,
        fp16_full_eval: bool = False,
        tf32: Optional[bool] = None,
        local_rank: int = -1,
        ddp_backend: Optional[str] = None,
        ddp_timeout: Optional[int] = 1800,
        ddp_find_unused_parameters: Optional[bool] = None,
        ddp_bucket_cap_mb: Optional[int] = None,
        ddp_broadcast_buffers: Optional[bool] = None,
        dataloader_pin_memory: bool = True,
        dataloader_num_workers: int = 0,
        past_index: int = -1,
        run_name: Optional[str] = None,
        disable_tqdm: Optional[bool] = None,
        remove_unused_columns: bool = True,
        label_names: Optional[List[str]] = None,
        load_best_model_at_end: Optional[bool] = False,
        metric_for_best_model: Optional[str] = None,
        greater_is_better: Optional[bool] = None,
        ignore_data_skip: bool = False,
        sharded_ddp: str = "",
        fsdp: str = "",
        fsdp_min_num_params: int = 0,
        fsdp_config: Optional[str] = None,
        fsdp_transformer_layer_cls_to_wrap: Optional[str] = None,
        deepspeed: Optional[str] = None,
        label_smoothing_factor: float = 0.0,
        optim: Union[OptimizerNames, str] = "adamw_torch",
        optim_args: Optional[str] = None,
        adafactor: bool = False,
        group_by_length: bool = False,
        length_column_name: Optional[str] = "length",
        report_to: Optional[List[str]] = None,
        skip_memory_metrics: bool = True,
        use_legacy_prediction_loop: bool = False,
        push_to_hub: bool = False,
        resume_from_checkpoint: Optional[str] = None,
        hub_model_id: Optional[str] = None,
        hub_strategy: Union[HubStrategy, str] = "every_save",
        hub_token: Optional[str] = None,
        hub_private_repo: bool = False,
        hub_always_push: bool = False,
        gradient_checkpointing: bool = False,
        include_inputs_for_metrics: bool = False,
        fp16_backend: str = "auto",
        push_to_hub_model_id: Optional[str] = None,
        push_to_hub_organization: Optional[str] = None,
        push_to_hub_token: Optional[str] = None,
        mp_parameters: str = "",
        auto_find_batch_size: bool = False,
        full_determinism: bool = False,
        torchdynamo: Optional[str] = None,
        ray_scope: Optional[str] = "last",
        torch_compile: bool = False,
        torch_compile_backend: Optional[str] = None,
        torch_compile_mode: Optional[str] = None,
        dispatch_batches: Optional[bool] = None,
        split_batches: Optional[bool] = None,
        include_tokens_per_second: Optional[bool] = False,
        **kwargs
    ):
        """
        Configure training parameters.

        Key parameters:
            output_dir: Directory to save model and checkpoints
            num_train_epochs: Number of training epochs
            per_device_train_batch_size: Batch size per device
            learning_rate: Learning rate for optimization
            weight_decay: Weight decay for regularization
            warmup_steps: Linear warmup steps
            logging_steps: Log every N steps
            save_steps: Save checkpoint every N steps
            evaluation_strategy: When to evaluate ("steps", "epoch", "no")
            fp16: Enable mixed precision training
            gradient_accumulation_steps: Accumulate gradients over N steps
            dataloader_num_workers: Number of data loading workers
            remove_unused_columns: Remove unused dataset columns
            load_best_model_at_end: Load best model after training
            metric_for_best_model: Metric to determine best model
            push_to_hub: Upload model to Hugging Face Hub
        """
```
### Data Collators

Utilities for batching and preprocessing data during training.
```python { .api }
class DataCollatorWithPadding:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        padding: Union[bool, str] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: str = "pt"
    ):
        """
        Collator that pads sequences to the same length.

        Args:
            tokenizer: Tokenizer to use for padding
            padding: Padding strategy
            max_length: Maximum sequence length
            pad_to_multiple_of: Pad to multiple of this value
            return_tensors: Format of returned tensors
        """

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate and pad batch of features."""

class DataCollatorForLanguageModeling:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        mlm: bool = True,
        mlm_probability: float = 0.15,
        pad_to_multiple_of: Optional[int] = None,
        tf_experimental_compile: bool = False,
        return_tensors: str = "pt"
    ):
        """
        Collator for language modeling tasks.

        Args:
            tokenizer: Tokenizer to use
            mlm: Whether to use masked language modeling
            mlm_probability: Probability of masking tokens
            pad_to_multiple_of: Pad to multiple of this value
            tf_experimental_compile: Compile the masking step with TF experimental compilation (TensorFlow only)
            return_tensors: Format of returned tensors
        """

class DataCollatorForSeq2Seq:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        model: Optional[PreTrainedModel] = None,
        padding: Union[bool, str] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        label_pad_token_id: int = -100,
        return_tensors: str = "pt"
    ):
        """
        Collator for sequence-to-sequence tasks.

        Args:
            tokenizer: Tokenizer to use
            model: Model to get decoder start token
            padding: Padding strategy
            max_length: Maximum sequence length
            pad_to_multiple_of: Pad to multiple of this value
            label_pad_token_id: Token ID for padding labels
            return_tensors: Format of returned tensors
        """
```
### Training Callbacks

Extensible callback system for customizing training behavior.
```python { .api }
class TrainerCallback:
    """Base class for trainer callbacks."""

    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of trainer initialization."""

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the beginning of training."""

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of training."""

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the beginning of each epoch."""

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of each epoch."""

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the beginning of each training step."""

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of each training step."""

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called after evaluation."""

class EarlyStoppingCallback(TrainerCallback):
    def __init__(
        self,
        early_stopping_patience: int = 1,
        early_stopping_threshold: Optional[float] = 0.0
    ):
        """
        Callback for early stopping based on evaluation metrics.

        Args:
            early_stopping_patience: Number of evaluations to wait
            early_stopping_threshold: Minimum improvement threshold
        """

class TensorBoardCallback(TrainerCallback):
    """Log training metrics to TensorBoard."""

class WandbCallback(TrainerCallback):
    """Log training metrics to Weights & Biases."""
```
### Optimization and Scheduling

Learning rate schedulers and optimizers for effective training.
```python { .api }
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    **kwargs
) -> torch.optim.lr_scheduler.LambdaLR:
    """
    Create learning rate scheduler.

    Args:
        name: Scheduler type ("linear", "cosine", "polynomial", etc.)
        optimizer: Optimizer to schedule
        num_warmup_steps: Number of warmup steps
        num_training_steps: Total training steps

    Returns:
        Configured scheduler
    """

def get_linear_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    last_epoch: int = -1
) -> torch.optim.lr_scheduler.LambdaLR:
    """Linear schedule with linear warmup."""

def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1
) -> torch.optim.lr_scheduler.LambdaLR:
    """Cosine schedule with linear warmup."""

class AdamW(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.01,
        correct_bias: bool = True
    ):
        """
        AdamW optimizer with weight decay.

        Args:
            params: Model parameters
            lr: Learning rate
            betas: Adam beta parameters
            eps: Epsilon for numerical stability
            weight_decay: Weight decay coefficient
            correct_bias: Apply bias correction
        """
```
454
### Training Output Types
455
456
Structured outputs from training and evaluation methods.
457
458
```python { .api }
class TrainOutput:
    """Output from training."""
    global_step: int
    training_loss: float
    metrics: Dict[str, float]

class EvalPrediction:
    """Predictions and labels for evaluation."""
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Optional[np.ndarray]
    inputs: Optional[np.ndarray]

class PredictionOutput:
    """Output from prediction."""
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]

class TrainerState:
    """Internal trainer state."""
    epoch: Optional[float] = None
    global_step: int = 0
    max_steps: int = 0
    logging_steps: int = 500
    eval_steps: int = 500
    save_steps: int = 500
    train_batch_size: int = None
    num_train_epochs: int = 0
    total_flos: int = 0
    log_history: List[Dict[str, float]] = None
    best_metric: Optional[float] = None
    best_model_checkpoint: Optional[str] = None
    is_local_process_zero: bool = True
    is_world_process_zero: bool = True

class TrainerControl:
    """Control flags for trainer behavior."""
    should_training_stop: bool = False
    should_epoch_stop: bool = False
    should_save: bool = False
    should_evaluate: bool = False
    should_log: bool = False
```
## Training Examples

Common training patterns and configurations:
```python
# Basic fine-tuning setup
from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics
)

# Start training
trainer.train()

# Evaluate final model
eval_results = trainer.evaluate()

# Make predictions
predictions = trainer.predict(test_dataset)
```