# Training

Comprehensive training framework with built-in optimization, distributed training support, logging, evaluation, and extensive customization options. The Trainer provides a high-level interface for fine-tuning transformer models while supporting advanced features like gradient accumulation, mixed precision, and custom training loops.

## Capabilities

### Trainer Class

Main training class that handles the complete training loop with automatic optimization, logging, and evaluation.

```python { .api }
class Trainer:
    def __init__(
        self,
        model: PreTrainedModel = None,
        args: TrainingArguments = None,
        data_collator: DataCollator = None,
        train_dataset: Dataset = None,
        eval_dataset: Union[Dataset, Dict[str, Dataset]] = None,
        processing_class: Union[PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Callable[[EvalPrediction], Dict] = None,
        callbacks: List[TrainerCallback] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None
    ):
        """
        Initialize trainer with model, training arguments, and datasets.

        Args:
            model: Model to train
            args: Training configuration
            data_collator: Collates batch data
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset(s)
            processing_class: Processing class (tokenizer, image processor, etc.) for the model
            model_init: Function that instantiates the model to train
            compute_metrics: Function to compute evaluation metrics
            callbacks: List of training callbacks
            optimizers: Custom optimizer and scheduler tuple
            preprocess_logits_for_metrics: Preprocess logits before metrics
        """

    def train(
        self,
        resume_from_checkpoint: Union[str, bool] = None,
        trial: Union[optuna.Trial, Dict[str, Any]] = None,
        ignore_keys_for_eval: List[str] = None,
        **kwargs
    ) -> TrainOutput:
        """
        Start training process.

        Args:
            resume_from_checkpoint: Path to checkpoint or True for latest
            trial: Optuna trial for hyperparameter optimization
            ignore_keys_for_eval: Keys to ignore during evaluation

        Returns:
            Training output with metrics and statistics
        """

    def evaluate(
        self,
        eval_dataset: Dataset = None,
        ignore_keys: List[str] = None,
        metric_key_prefix: str = "eval"
    ) -> Dict[str, float]:
        """
        Evaluate model on evaluation dataset.

        Args:
            eval_dataset: Dataset to evaluate on (uses default if None)
            ignore_keys: Keys to ignore in output
            metric_key_prefix: Prefix for metric names

        Returns:
            Dictionary of evaluation metrics
        """

    def predict(
        self,
        test_dataset: Dataset,
        ignore_keys: List[str] = None,
        metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Make predictions on test dataset.

        Args:
            test_dataset: Dataset to predict on
            ignore_keys: Keys to ignore in output
            metric_key_prefix: Prefix for metric names

        Returns:
            Predictions with metrics and labels
        """

    def save_model(
        self,
        output_dir: str = None,
        _internal_call: bool = False
    ) -> None:
        """Save model and tokenizer to directory."""

    def save_state(self) -> None:
        """Save trainer state for resuming training."""

    def log(self, logs: Dict[str, float]) -> None:
        """Log metrics and values."""

    def create_optimizer_and_scheduler(
        self,
        num_training_steps: int
    ) -> None:
        """Create optimizer and learning rate scheduler."""
```

### Training Arguments

Comprehensive configuration class for all training hyperparameters and settings.

```python { .api }
class TrainingArguments:
    def __init__(
        self,
        output_dir: str,
        overwrite_output_dir: bool = False,
        do_train: bool = False,
        do_eval: bool = False,
        do_predict: bool = False,
        evaluation_strategy: Union[IntervalStrategy, str] = "no",
        prediction_loss_only: bool = False,
        per_device_train_batch_size: int = 8,
        per_device_eval_batch_size: int = 8,
        per_gpu_train_batch_size: Optional[int] = None,
        per_gpu_eval_batch_size: Optional[int] = None,
        gradient_accumulation_steps: int = 1,
        eval_accumulation_steps: Optional[int] = None,
        eval_delay: Optional[float] = 0,
        learning_rate: float = 5e-5,
        weight_decay: float = 0.0,
        adam_beta1: float = 0.9,
        adam_beta2: float = 0.999,
        adam_epsilon: float = 1e-8,
        max_grad_norm: float = 1.0,
        num_train_epochs: float = 3.0,
        max_steps: int = -1,
        lr_scheduler_type: Union[SchedulerType, str] = "linear",
        warmup_ratio: float = 0.0,
        warmup_steps: int = 0,
        log_level: Optional[str] = "passive",
        log_level_replica: Optional[str] = "warning",
        log_on_each_node: bool = True,
        logging_dir: Optional[str] = None,
        logging_strategy: Union[IntervalStrategy, str] = "steps",
        logging_first_step: bool = False,
        logging_steps: int = 500,
        logging_nan_inf_filter: bool = True,
        save_strategy: Union[IntervalStrategy, str] = "steps",
        save_steps: int = 500,
        save_total_limit: Optional[int] = None,
        save_safetensors: Optional[bool] = True,
        save_on_each_node: bool = False,
        no_cuda: bool = False,
        use_cpu: bool = False,
        use_mps_device: bool = False,
        seed: int = 42,
        data_seed: Optional[int] = None,
        jit_mode_eval: bool = False,
        use_ipex: bool = False,
        bf16: bool = False,
        fp16: bool = False,
        fp16_opt_level: str = "O1",
        half_precision_backend: str = "auto",
        bf16_full_eval: bool = False,
        fp16_full_eval: bool = False,
        tf32: Optional[bool] = None,
        local_rank: int = -1,
        ddp_backend: Optional[str] = None,
        ddp_timeout: Optional[int] = 1800,
        ddp_find_unused_parameters: Optional[bool] = None,
        ddp_bucket_cap_mb: Optional[int] = None,
        ddp_broadcast_buffers: Optional[bool] = None,
        dataloader_pin_memory: bool = True,
        dataloader_num_workers: int = 0,
        past_index: int = -1,
        run_name: Optional[str] = None,
        disable_tqdm: Optional[bool] = None,
        remove_unused_columns: bool = True,
        label_names: Optional[List[str]] = None,
        load_best_model_at_end: Optional[bool] = False,
        metric_for_best_model: Optional[str] = None,
        greater_is_better: Optional[bool] = None,
        ignore_data_skip: bool = False,
        sharded_ddp: str = "",
        fsdp: str = "",
        fsdp_min_num_params: int = 0,
        fsdp_config: Optional[str] = None,
        fsdp_transformer_layer_cls_to_wrap: Optional[str] = None,
        deepspeed: Optional[str] = None,
        label_smoothing_factor: float = 0.0,
        optim: Union[OptimizerNames, str] = "adamw_torch",
        optim_args: Optional[str] = None,
        adafactor: bool = False,
        group_by_length: bool = False,
        length_column_name: Optional[str] = "length",
        report_to: Optional[List[str]] = None,
        skip_memory_metrics: bool = True,
        use_legacy_prediction_loop: bool = False,
        push_to_hub: bool = False,
        resume_from_checkpoint: Optional[str] = None,
        hub_model_id: Optional[str] = None,
        hub_strategy: Union[HubStrategy, str] = "every_save",
        hub_token: Optional[str] = None,
        hub_private_repo: bool = False,
        hub_always_push: bool = False,
        gradient_checkpointing: bool = False,
        include_inputs_for_metrics: bool = False,
        fp16_backend: str = "auto",
        push_to_hub_model_id: Optional[str] = None,
        push_to_hub_organization: Optional[str] = None,
        push_to_hub_token: Optional[str] = None,
        mp_parameters: str = "",
        auto_find_batch_size: bool = False,
        full_determinism: bool = False,
        torchdynamo: Optional[str] = None,
        ray_scope: Optional[str] = "last",
        torch_compile: bool = False,
        torch_compile_backend: Optional[str] = None,
        torch_compile_mode: Optional[str] = None,
        dispatch_batches: Optional[bool] = None,
        split_batches: Optional[bool] = None,
        include_tokens_per_second: Optional[bool] = False,
        **kwargs
    ):
        """
        Configure training parameters.

        Key parameters:
            output_dir: Directory to save model and checkpoints
            num_train_epochs: Number of training epochs
            per_device_train_batch_size: Batch size per device
            learning_rate: Learning rate for optimization
            weight_decay: Weight decay for regularization
            warmup_steps: Linear warmup steps
            logging_steps: Log every N steps
            save_steps: Save checkpoint every N steps
            evaluation_strategy: When to evaluate ("steps", "epoch", "no")
            fp16: Enable mixed precision training
            gradient_accumulation_steps: Accumulate gradients over N steps
            dataloader_num_workers: Number of data loading workers
            remove_unused_columns: Remove unused dataset columns
            load_best_model_at_end: Load best model after training
            metric_for_best_model: Metric to determine best model
            push_to_hub: Upload model to Hugging Face Hub
        """
```

### Data Collators

Utilities for batching and preprocessing data during training.

```python { .api }
class DataCollatorWithPadding:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        padding: Union[bool, str] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: str = "pt"
    ):
        """
        Collator that pads sequences to the same length.

        Args:
            tokenizer: Tokenizer to use for padding
            padding: Padding strategy
            max_length: Maximum sequence length
            pad_to_multiple_of: Pad to multiple of this value
            return_tensors: Format of returned tensors
        """

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate and pad batch of features."""

class DataCollatorForLanguageModeling:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        mlm: bool = True,
        mlm_probability: float = 0.15,
        pad_to_multiple_of: Optional[int] = None,
        tf_experimental_compile: bool = False,
        return_tensors: str = "pt"
    ):
        """
        Collator for language modeling tasks.

        Args:
            tokenizer: Tokenizer to use
            mlm: Whether to use masked language modeling
            mlm_probability: Probability of masking tokens
            pad_to_multiple_of: Pad to multiple of this value
            return_tensors: Format of returned tensors
        """

class DataCollatorForSeq2Seq:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        model: Optional[PreTrainedModel] = None,
        padding: Union[bool, str] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        label_pad_token_id: int = -100,
        return_tensors: str = "pt"
    ):
        """
        Collator for sequence-to-sequence tasks.

        Args:
            tokenizer: Tokenizer to use
            model: Model to get decoder start token
            padding: Padding strategy
            max_length: Maximum sequence length
            pad_to_multiple_of: Pad to multiple of this value
            label_pad_token_id: Token ID for padding labels
            return_tensors: Format of returned tensors
        """
```

### Training Callbacks

Extensible callback system for customizing training behavior.

```python { .api }
class TrainerCallback:
    """Base class for trainer callbacks."""

    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of trainer initialization."""

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the beginning of training."""

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of training."""

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the beginning of each epoch."""

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of each epoch."""

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the beginning of each training step."""

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of each training step."""

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called after evaluation."""

class EarlyStoppingCallback(TrainerCallback):
    def __init__(
        self,
        early_stopping_patience: int = 1,
        early_stopping_threshold: Optional[float] = 0.0
    ):
        """
        Callback for early stopping based on evaluation metrics.

        Args:
            early_stopping_patience: Number of evaluations to wait
            early_stopping_threshold: Minimum improvement threshold
        """

class TensorBoardCallback(TrainerCallback):
    """Log training metrics to TensorBoard."""

class WandbCallback(TrainerCallback):
    """Log training metrics to Weights & Biases."""
```

### Optimization and Scheduling

Learning rate schedulers and optimizers for effective training.

```python { .api }
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    **kwargs
) -> torch.optim.lr_scheduler.LambdaLR:
    """
    Create learning rate scheduler.

    Args:
        name: Scheduler type ("linear", "cosine", "polynomial", etc.)
        optimizer: Optimizer to schedule
        num_warmup_steps: Number of warmup steps
        num_training_steps: Total training steps

    Returns:
        Configured scheduler
    """

def get_linear_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    last_epoch: int = -1
) -> torch.optim.lr_scheduler.LambdaLR:
    """Linear schedule with linear warmup."""

def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1
) -> torch.optim.lr_scheduler.LambdaLR:
    """Cosine schedule with linear warmup."""

class AdamW(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.01,
        correct_bias: bool = True
    ):
        """
        AdamW optimizer with weight decay.

        Args:
            params: Model parameters
            lr: Learning rate
            betas: Adam beta parameters
            eps: Epsilon for numerical stability
            weight_decay: Weight decay coefficient
            correct_bias: Apply bias correction
        """
```

### Training Output Types

Structured outputs from training and evaluation methods.

```python { .api }
class TrainOutput:
    """Output from training."""
    global_step: int
    training_loss: float
    metrics: Dict[str, float]

class EvalPrediction:
    """Predictions and labels for evaluation."""
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Optional[np.ndarray]
    inputs: Optional[np.ndarray]

class PredictionOutput:
    """Output from prediction."""
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]

class TrainerState:
    """Internal trainer state."""
    epoch: Optional[float] = None
    global_step: int = 0
    max_steps: int = 0
    logging_steps: int = 500
    eval_steps: int = 500
    save_steps: int = 500
    train_batch_size: int = None
    num_train_epochs: int = 0
    total_flos: int = 0
    log_history: List[Dict[str, float]] = None
    best_metric: Optional[float] = None
    best_model_checkpoint: Optional[str] = None
    is_local_process_zero: bool = True
    is_world_process_zero: bool = True

class TrainerControl:
    """Control flags for trainer behavior."""
    should_training_stop: bool = False
    should_epoch_stop: bool = False
    should_save: bool = False
    should_evaluate: bool = False
    should_log: bool = False
```

## Training Examples

Common training patterns and configurations:

```python
# Basic fine-tuning setup
from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics
)

# Start training
trainer.train()

# Evaluate final model
eval_results = trainer.evaluate()

# Make predictions
predictions = trainer.predict(test_dataset)
```