
# Training

The sentence-transformers package provides a modern training framework built on the Hugging Face Trainer, supporting various learning objectives and multi-dataset training for sentence transformer models.

## SentenceTransformerTrainer

### Constructor

```python
SentenceTransformerTrainer(
    model: SentenceTransformer | None = None,
    args: SentenceTransformerTrainingArguments | None = None,
    train_dataset: Dataset | None = None,
    eval_dataset: Dataset | None = None,
    tokenizer: PreTrainedTokenizer | None = None,
    data_collator: DataCollator | None = None,
    compute_metrics: callable | None = None,
    callbacks: list[TrainerCallback] | None = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: callable | None = None,
    loss: torch.nn.Module | dict[str, torch.nn.Module] | None = None
)
```

`{ .api }`

Modern trainer for sentence transformer models.

**Parameters**:

- `model`: SentenceTransformer model to train
- `args`: Training configuration arguments
- `train_dataset`: Training dataset(s) - single Dataset or dict of datasets
- `eval_dataset`: Evaluation dataset(s) - single Dataset or dict of datasets
- `tokenizer`: Tokenizer (usually auto-detected from the model)
- `data_collator`: Custom data collator for batching
- `compute_metrics`: Function to compute evaluation metrics
- `callbacks`: List of training callbacks
- `optimizers`: Custom optimizer and learning rate scheduler
- `preprocess_logits_for_metrics`: Function to preprocess logits for metrics
- `loss`: Loss function(s) - a single loss or a dict mapping dataset names to losses

### Training Methods

```python
def train(
    resume_from_checkpoint: str | bool | None = None,
    trial: dict[str, Any] | None = None,
    ignore_keys_for_eval: list[str] | None = None,
    **kwargs
) -> TrainOutput
```

`{ .api }`

Train the sentence transformer model.

**Parameters**:

- `resume_from_checkpoint`: Path to a checkpoint, or True to resume from the latest checkpoint
- `trial`: Hyperparameter optimization trial object
- `ignore_keys_for_eval`: Keys to ignore during evaluation
- `**kwargs`: Additional arguments passed to the base trainer

**Returns**: Training output with logs and metrics

```python
def evaluate(
    eval_dataset: Dataset | None = None,
    ignore_keys: list[str] | None = None,
    metric_key_prefix: str = "eval"
) -> dict[str, float]
```

`{ .api }`

Evaluate the model on the evaluation dataset(s).

```python
def predict(
    test_dataset: Dataset,
    ignore_keys: list[str] | None = None,
    metric_key_prefix: str = "test"
) -> PredictionOutput
```

`{ .api }`

Make predictions on a test dataset.

### Multi-Dataset Training

```python
def add_dataset(
    train_dataset: Dataset,
    eval_dataset: Dataset | None = None,
    dataset_name: str | None = None,
    loss: torch.nn.Module | None = None
) -> None
```

`{ .api }`

Add an additional dataset to a multi-dataset training setup.

**Parameters**:

- `train_dataset`: Training dataset to add
- `eval_dataset`: Optional evaluation dataset
- `dataset_name`: Name for the dataset (auto-generated if None)
- `loss`: Specific loss function for this dataset

## SentenceTransformerTrainingArguments

```python
class SentenceTransformerTrainingArguments(TrainingArguments):
    def __init__(
        self,
        output_dir: str,
        evaluation_strategy: str | IntervalStrategy = "no",
        eval_steps: int | None = None,
        eval_delay: float = 0,
        logging_dir: str | None = None,
        logging_strategy: str | IntervalStrategy = "steps",
        logging_steps: int = 500,
        save_strategy: str | IntervalStrategy = "steps",
        save_steps: int = 500,
        save_total_limit: int | None = None,
        seed: int = 42,
        data_seed: int | None = None,
        jit_mode_eval: bool = False,
        use_ipex: bool = False,
        bf16: bool = False,
        fp16: bool = False,
        fp16_opt_level: str = "O1",
        half_precision_backend: str = "auto",
        bf16_full_eval: bool = False,
        fp16_full_eval: bool = False,
        tf32: bool | None = None,
        local_rank: int = -1,
        ddp_backend: str | None = None,
        tpu_num_cores: int | None = None,
        tpu_metrics_debug: bool = False,
        debug: str | list[DebugOption] = "",
        dataloader_drop_last: bool = False,
        dataloader_num_workers: int = 0,
        past_index: int = -1,
        run_name: str | None = None,
        disable_tqdm: bool | None = None,
        remove_unused_columns: bool = True,
        label_names: list[str] | None = None,
        load_best_model_at_end: bool = False,
        ignore_data_skip: bool = False,
        fsdp: str | list[str] = "",
        fsdp_min_num_params: int = 0,
        fsdp_config: dict[str, Any] | None = None,
        fsdp_transformer_layer_cls_to_wrap: str | None = None,
        deepspeed: str | None = None,
        label_smoothing_factor: float = 0.0,
        optim: str | OptimizerNames = "adamw_torch",
        optim_args: str | None = None,
        adafactor: bool = False,
        group_by_length: bool = False,
        length_column_name: str | None = "length",
        report_to: str | list[str] | None = None,
        ddp_find_unused_parameters: bool | None = None,
        ddp_bucket_cap_mb: int | None = None,
        ddp_broadcast_buffers: bool | None = None,
        dataloader_pin_memory: bool = True,
        skip_memory_metrics: bool = True,
        use_legacy_prediction_loop: bool = False,
        push_to_hub: bool = False,
        resume_from_checkpoint: str | None = None,
        hub_model_id: str | None = None,
        hub_strategy: str | HubStrategy = "every_save",
        hub_token: str | None = None,
        hub_private_repo: bool = False,
        hub_always_push: bool = False,
        gradient_checkpointing: bool = False,
        include_inputs_for_metrics: bool = False,
        auto_find_batch_size: bool = False,
        full_determinism: bool = False,
        torchdynamo: str | None = None,
        ray_scope: str | None = "last",
        ddp_timeout: int = 1800,
        torch_compile: bool = False,
        torch_compile_backend: str | None = None,
        torch_compile_mode: str | None = None,
        dispatch_batches: bool | None = None,
        split_batches: bool | None = None,
        include_tokens_per_second: bool = False,
        # Sentence Transformers specific arguments
        batch_sampler: str = "batch_sampler",
        multi_dataset_batch_sampler: str = "proportional",
        **kwargs
    )
```

`{ .api }`

Training arguments extending HuggingFace TrainingArguments with sentence transformer specific options.

**Key Sentence Transformer Parameters**:

- `batch_sampler`: Strategy for sampling batches from datasets
- `multi_dataset_batch_sampler`: Strategy for multi-dataset batch sampling ("proportional", "round_robin")
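The two multi-dataset strategies differ in how many batches each dataset contributes per epoch. The following pure-Python sketch is illustrative only (not the library's internals), with made-up dataset sizes: "proportional" lets each dataset contribute all of its batches, so larger datasets are sampled more often, while "round_robin" cycles datasets equally until the smallest is exhausted.

```python
def batches_per_epoch(sizes, batch_size, strategy):
    """Sketch of how many batches each dataset contributes per epoch."""
    per_dataset = {name: n // batch_size for name, n in sizes.items()}
    if strategy == "proportional":
        # Every dataset contributes all of its batches.
        return per_dataset
    if strategy == "round_robin":
        # Datasets alternate equally; sampling stops when the smallest
        # dataset runs out, so each contributes the same count.
        smallest = min(per_dataset.values())
        return {name: smallest for name in per_dataset}
    raise ValueError(f"unknown strategy: {strategy}")

sizes = {"similarity": 8000, "triplet": 2000}  # made-up sizes
print(batches_per_epoch(sizes, 16, "proportional"))  # {'similarity': 500, 'triplet': 125}
print(batches_per_epoch(sizes, 16, "round_robin"))   # {'similarity': 125, 'triplet': 125}
```

With "proportional", the larger dataset dominates training; with "round_robin", part of the larger dataset goes unseen each epoch in exchange for balanced exposure.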

## SentenceTransformerModelCardData

```python
class SentenceTransformerModelCardData:
    def __init__(
        self,
        language: str | list[str] | None = None,
        license: str | None = None,
        tags: str | list[str] | None = None,
        model_name: str | None = None,
        model_id: str | None = None,
        eval_results: list[EvalResult] | None = None,
        train_datasets: str | list[str] | None = None,
        eval_datasets: str | list[str] | None = None,
        prior_models: str | list[str] | None = None,
        base_model: str | None = None,
        similarity_fn_name: str | None = None,
        model_max_length: int | None = None
    )
```

`{ .api }`

Data class for generating comprehensive model cards for sentence transformer models.

**Parameters**:

- `language`: Supported language(s)
- `license`: Model license
- `tags`: Categorization tags
- `model_name`: Human-readable model name
- `model_id`: Unique model identifier
- `eval_results`: Evaluation results and benchmarks
- `train_datasets`: Datasets used for training
- `eval_datasets`: Datasets used for evaluation
- `prior_models`: Models used as starting points
- `base_model`: Base transformer model
- `similarity_fn_name`: Default similarity function
- `model_max_length`: Maximum input length

## Batch Samplers

### DefaultBatchSampler

```python
class DefaultBatchSampler:
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool = False,
        generator: torch.Generator | None = None
    )
```

`{ .api }`

Standard batch sampler for single-dataset training.
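The effect of `drop_last` can be sketched in plain Python (an illustrative helper, not the sampler's actual implementation): indices are split into consecutive batches, and an incomplete final batch is either kept or discarded.

```python
def batch_index_lists(num_samples, batch_size, drop_last=False):
    """Split indices 0..num_samples-1 into consecutive batches."""
    batches = [list(range(start, min(start + batch_size, num_samples)))
               for start in range(0, num_samples, batch_size)]
    # drop_last=True discards a trailing batch smaller than batch_size
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches

print(batch_index_lists(10, 4))                   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(batch_index_lists(10, 4, drop_last=True))   # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Dropping the last batch can matter for contrastive losses, which benefit from a consistent number of in-batch negatives.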

### MultiDatasetDefaultBatchSampler

```python
class MultiDatasetDefaultBatchSampler:
    def __init__(
        self,
        datasets: dict[str, Dataset],
        batch_sizes: dict[str, int] | int,
        sampling_strategy: str = "proportional",
        generator: torch.Generator | None = None
    )
```

`{ .api }`

Base class for multi-dataset batch sampling.

**Parameters**:

- `datasets`: Dictionary mapping dataset names to Dataset objects
- `batch_sizes`: Batch sizes per dataset, or a single batch size for all
- `sampling_strategy`: How to sample from multiple datasets ("proportional", "round_robin")
- `generator`: Random number generator for reproducibility

## Usage Examples

### Basic Training Setup

```python
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from datasets import Dataset

# Prepare training data
train_data = [
    {"anchor": "The cat sits on the mat", "positive": "A feline rests on a rug"},
    {"anchor": "Python is a programming language", "positive": "Python is used for coding"},
    {"anchor": "Machine learning uses data", "positive": "ML algorithms process datasets"}
]

train_dataset = Dataset.from_list(train_data)

# A held-out split is needed because the arguments below evaluate every
# 100 steps and track "eval_loss"
eval_dataset = Dataset.from_list([
    {"anchor": "Birds can fly", "positive": "Most birds are capable of flight"}
])

# Initialize model
model = SentenceTransformer('distilbert-base-uncased')

# Define loss function
loss = MultipleNegativesRankingLoss(model)

# Training arguments
args = SentenceTransformerTrainingArguments(
    output_dir='./sentence-transformer-training',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    logging_steps=10,
    logging_dir='./logs',
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    run_name="sentence-transformer-training"
)

# Create trainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss
)

# Train model
trainer.train()

# Save trained model
model.save('./trained-sentence-transformer')
```

### Multi-Dataset Training

```python
from sentence_transformers.losses import CosineSimilarityLoss, TripletLoss

# Prepare multiple datasets
dataset1 = Dataset.from_list([
    {"sentence1": "The cat sits", "sentence2": "A cat is sitting", "label": 1.0},
    {"sentence1": "Dogs are pets", "sentence2": "Cats are animals", "label": 0.3}
])

dataset2 = Dataset.from_list([
    {"anchor": "Python programming", "positive": "Coding in Python", "negative": "Java development"},
    {"anchor": "Machine learning", "positive": "AI algorithms", "negative": "Web design"}
])

# Define different losses for different datasets
loss1 = CosineSimilarityLoss(model)
loss2 = TripletLoss(model)

# Multi-dataset training: dataset names map to their losses
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset={"similarity": dataset1, "triplet": dataset2},
    loss={"similarity": loss1, "triplet": loss2}
)

trainer.train()
```

### Advanced Training with Evaluation

```python
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import numpy as np

# Prepare evaluation data
eval_sentences1 = ["The cat sits on the mat", "I love programming"]
eval_sentences2 = ["A feline rests on a rug", "I enjoy coding"]
eval_scores = [0.9, 0.8]  # Similarity scores

evaluator = EmbeddingSimilarityEvaluator(
    eval_sentences1,
    eval_sentences2,
    eval_scores,
    name="dev"
)

# An eval split is needed for evaluation_strategy="epoch" and eval_loss tracking
eval_dataset = Dataset.from_list([
    {"anchor": s1, "positive": s2}
    for s1, s2 in zip(eval_sentences1, eval_sentences2)
])

def compute_metrics(eval_pred):
    """Custom metrics computation."""
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids

    # Compute custom metrics
    mse = np.mean((predictions - labels) ** 2)
    return {"mse": mse}

# Enhanced training arguments
args = SentenceTransformerTrainingArguments(
    output_dir='./advanced-training',
    num_train_epochs=5,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    push_to_hub=False,
    report_to=["tensorboard"],
    run_name="advanced-sentence-transformer"
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    compute_metrics=compute_metrics
)

# Train, then run the similarity evaluator on the trained model
trainer.train()
results = evaluator(model, output_path=args.output_dir + "/evaluation")
```

### Custom Loss Function Training

```python
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

# Matryoshka representation learning
base_loss = MultipleNegativesRankingLoss(model)
matryoshka_loss = MatryoshkaLoss(
    model=model,
    loss=base_loss,
    matryoshka_dims=[768, 512, 256, 128, 64]  # Progressive dimensions
)

# Training with progressive dimensionality
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=matryoshka_loss
)

trainer.train()

# The trained model can now produce embeddings at multiple dimensions
embeddings_768 = model.encode(["Test sentence"], truncate_dim=768)
embeddings_256 = model.encode(["Test sentence"], truncate_dim=256)
embeddings_64 = model.encode(["Test sentence"], truncate_dim=64)
```

461

462

### Distributed Training

463

464

```python

465

# For multi-GPU training

466

args = SentenceTransformerTrainingArguments(

467

output_dir='./distributed-training',

468

num_train_epochs=3,

469

per_device_train_batch_size=16, # Per GPU batch size

470

gradient_accumulation_steps=4, # Effective batch size: 16 * 4 * num_gpus

471

dataloader_num_workers=4,

472

ddp_find_unused_parameters=False,

473

fp16=True, # Mixed precision training

474

logging_steps=10,

475

save_steps=500,

476

evaluation_strategy="steps",

477

eval_steps=500,

478

warmup_ratio=0.1,

479

learning_rate=2e-5,

480

run_name="distributed-training"

481

)

482

483

# The trainer automatically handles multi-GPU setup

484

trainer = SentenceTransformerTrainer(

485

model=model,

486

args=args,

487

train_dataset=train_dataset,

488

loss=loss

489

)

490

491

# Launch with: torchrun --nproc_per_node=2 train_script.py

492

trainer.train()

493

```
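The effective batch size noted in the comment above works out as follows, with the GPU count taken from the `torchrun --nproc_per_node=2` launch command:

```python
# Effective batch size = per-device batch * gradient accumulation * GPU count
per_device_train_batch_size = 16
gradient_accumulation_steps = 4
num_gpus = 2  # from torchrun --nproc_per_node=2

effective_batch_size = (per_device_train_batch_size
                        * gradient_accumulation_steps
                        * num_gpus)
print(effective_batch_size)  # 128
```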

### Hyperparameter Optimization

```python
from ray import tune

def model_init():
    """Initialize a fresh model for each hyperparameter search trial."""
    return SentenceTransformer('distilbert-base-uncased')

def hp_space(trial):
    """Define the hyperparameter search space."""
    return {
        "learning_rate": tune.loguniform(1e-6, 1e-4),
        "per_device_train_batch_size": tune.choice([16, 32, 64]),
        "warmup_ratio": tune.uniform(0.0, 0.3),
        "weight_decay": tune.uniform(0.0, 0.3),
    }

# Hyperparameter search: passing the loss class (rather than an instance)
# lets it be instantiated with each freshly initialized model
trainer = SentenceTransformerTrainer(
    model_init=model_init,
    args=args,
    train_dataset=train_dataset,
    loss=MultipleNegativesRankingLoss,
    compute_metrics=compute_metrics
)

best_trial = trainer.hyperparameter_search(
    hp_space=hp_space,
    compute_objective=lambda metrics: metrics["eval_loss"],
    n_trials=10,
    direction="minimize",
    backend="ray"  # hp_space above uses Ray Tune samplers
)

print(f"Best hyperparameters: {best_trial.hyperparameters}")
```

532

533

### Custom Data Collator

534

535

```python

536

from transformers import DataCollatorWithPadding

537

from typing import Dict, List, Any

538

import torch

539

540

class CustomDataCollator:

541

"""Custom data collator for sentence transformer training."""

542

543

def __init__(self, tokenizer, padding=True, max_length=None):

544

self.tokenizer = tokenizer

545

self.padding = padding

546

self.max_length = max_length

547

548

def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:

549

# Custom processing of training examples

550

batch = {}

551

552

# Extract texts

553

texts = []

554

for feature in features:

555

if 'anchor' in feature and 'positive' in feature:

556

texts.extend([feature['anchor'], feature['positive']])

557

if 'negative' in feature:

558

texts.append(feature['negative'])

559

elif 'sentence1' in feature and 'sentence2' in feature:

560

texts.extend([feature['sentence1'], feature['sentence2']])

561

562

# Tokenize

563

tokenized = self.tokenizer(

564

texts,

565

padding=self.padding,

566

truncation=True,

567

max_length=self.max_length,

568

return_tensors='pt'

569

)

570

571

return tokenized

572

573

# Use custom data collator

574

custom_collator = CustomDataCollator(model.tokenizer, max_length=512)

575

576

trainer = SentenceTransformerTrainer(

577

model=model,

578

args=args,

579

train_dataset=train_dataset,

580

data_collator=custom_collator,

581

loss=loss

582

)

583

```

### Training Callbacks

```python
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

class CustomCallback(TrainerCallback):
    """Custom training callback for monitoring."""

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        print("Training started!")

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Custom logic at epoch end
        model = kwargs.get('model')
        if model:
            # Evaluate on custom data or log additional metrics
            pass

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        print(f"Model saved at step {state.global_step}")

# Add callbacks to trainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
    callbacks=[CustomCallback()]
)
```

### Model Card Generation

```python
from sentence_transformers import SentenceTransformer, SentenceTransformerModelCardData

# Create comprehensive model card data
model_card_data = SentenceTransformerModelCardData(
    language=['en'],
    license='apache-2.0',
    tags=['sentence-transformers', 'sentence-similarity', 'embeddings'],
    model_name='Custom Sentence Transformer',
    base_model='distilbert-base-uncased',
    train_datasets=['custom-similarity-dataset'],
    eval_datasets=['sts-benchmark'],
    similarity_fn_name='cosine',
    model_max_length=512
)

# Attach the card data when constructing the model; it is then used
# automatically when the model is saved or pushed to the Hub
model = SentenceTransformer('distilbert-base-uncased', model_card_data=model_card_data)

# ... train as in the examples above ...

# Save the model with its generated card
trainer.save_model('./final-model')
model.save('./final-model')

# Push to the Hub with the model card
model.push_to_hub('my-username/my-sentence-transformer')
```

## Best Practices

1. **Data Preparation**: Ensure high-quality, diverse training data
2. **Loss Selection**: Choose a loss function appropriate to your data format and task
3. **Batch Size**: Use larger batch sizes when possible for contrastive learning
4. **Learning Rate**: Start with 2e-5 and adjust based on model size
5. **Evaluation**: Use domain-relevant evaluation metrics during training
6. **Checkpointing**: Save regular checkpoints and use early stopping
7. **Multi-Dataset**: Combine multiple datasets for robust representations
8. **Progressive Training**: Consider curriculum learning or progressive approaches
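The batch-size advice (item 3) matters most for in-batch-negative losses such as MultipleNegativesRankingLoss, where every anchor is scored against all other positives in the batch. A small illustration of how the negative count grows with batch size:

```python
def in_batch_negatives(batch_size):
    """Each anchor is contrasted against the other batch_size - 1 positives."""
    return batch_size - 1

for bs in (16, 64, 256):
    print(bs, "->", in_batch_negatives(bs), "negatives per anchor")
# 16 -> 15, 64 -> 63, 256 -> 255
```

Doubling the batch size roughly doubles the number of negatives each example sees, which is why contrastive training often improves with memory-permitting larger batches.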