# Training

The sentence-transformers package provides a modern training framework built on the HuggingFace Trainer, supporting a wide range of learning objectives and multi-dataset training for sentence transformer models.

## SentenceTransformerTrainer

### Constructor

```python
SentenceTransformerTrainer(
    model: SentenceTransformer | None = None,
    args: SentenceTransformerTrainingArguments | None = None,
    train_dataset: Dataset | None = None,
    eval_dataset: Dataset | None = None,
    tokenizer: PreTrainedTokenizer | None = None,
    data_collator: DataCollator | None = None,
    compute_metrics: callable | None = None,
    callbacks: list[TrainerCallback] | None = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: callable | None = None,
    loss: torch.nn.Module | dict[str, torch.nn.Module] | None = None
)
```
`{ .api }`

Modern trainer for sentence transformer models.

**Parameters**:
- `model`: SentenceTransformer model to train
- `args`: Training configuration arguments
- `train_dataset`: Training dataset(s) - single Dataset or dict of datasets
- `eval_dataset`: Evaluation dataset(s) - single Dataset or dict of datasets
- `tokenizer`: Tokenizer (usually auto-detected from model)
- `data_collator`: Custom data collator for batching
- `compute_metrics`: Function to compute evaluation metrics
- `callbacks`: List of training callbacks
- `optimizers`: Custom optimizer and learning rate scheduler
- `preprocess_logits_for_metrics`: Function to preprocess logits for metrics
- `loss`: Loss function(s) - single loss or dict mapping dataset names to losses

### Training Methods

```python
def train(
    resume_from_checkpoint: str | bool | None = None,
    trial: dict[str, Any] | None = None,
    ignore_keys_for_eval: list[str] | None = None,
    **kwargs
) -> TrainOutput
```
`{ .api }`

Train the sentence transformer model.

**Parameters**:
- `resume_from_checkpoint`: Path to checkpoint or True to resume from latest
- `trial`: Hyperparameter optimization trial object
- `ignore_keys_for_eval`: Keys to ignore during evaluation
- `**kwargs`: Additional arguments passed to base trainer

**Returns**: Training output with logs and metrics

```python
def evaluate(
    eval_dataset: Dataset | None = None,
    ignore_keys: list[str] | None = None,
    metric_key_prefix: str = "eval"
) -> dict[str, float]
```
`{ .api }`

Evaluate the model on evaluation dataset(s).

```python
def predict(
    test_dataset: Dataset,
    ignore_keys: list[str] | None = None,
    metric_key_prefix: str = "test"
) -> PredictionOutput
```
`{ .api }`

Make predictions on a test dataset.

### Multi-Dataset Training

```python
def add_dataset(
    train_dataset: Dataset,
    eval_dataset: Dataset | None = None,
    dataset_name: str | None = None,
    loss: torch.nn.Module | None = None
) -> None
```
`{ .api }`

Add an additional dataset to a multi-dataset training setup.

**Parameters**:
- `train_dataset`: Training dataset to add
- `eval_dataset`: Optional evaluation dataset
- `dataset_name`: Name for the dataset (auto-generated if None)
- `loss`: Specific loss function for this dataset

## SentenceTransformerTrainingArguments

```python
class SentenceTransformerTrainingArguments(TrainingArguments):
    def __init__(
        self,
        output_dir: str,
        evaluation_strategy: str | IntervalStrategy = "no",
        eval_steps: int | None = None,
        eval_delay: float = 0,
        logging_dir: str | None = None,
        logging_strategy: str | IntervalStrategy = "steps",
        logging_steps: int = 500,
        save_strategy: str | IntervalStrategy = "steps",
        save_steps: int = 500,
        save_total_limit: int | None = None,
        seed: int = 42,
        data_seed: int | None = None,
        jit_mode_eval: bool = False,
        use_ipex: bool = False,
        bf16: bool = False,
        fp16: bool = False,
        fp16_opt_level: str = "O1",
        half_precision_backend: str = "auto",
        bf16_full_eval: bool = False,
        fp16_full_eval: bool = False,
        tf32: bool | None = None,
        local_rank: int = -1,
        ddp_backend: str | None = None,
        tpu_num_cores: int | None = None,
        tpu_metrics_debug: bool = False,
        debug: str | list[DebugOption] = "",
        dataloader_drop_last: bool = False,
        dataloader_num_workers: int = 0,
        past_index: int = -1,
        run_name: str | None = None,
        disable_tqdm: bool | None = None,
        remove_unused_columns: bool = True,
        label_names: list[str] | None = None,
        load_best_model_at_end: bool = False,
        ignore_data_skip: bool = False,
        fsdp: str | list[str] = "",
        fsdp_min_num_params: int = 0,
        fsdp_config: dict[str, Any] | None = None,
        fsdp_transformer_layer_cls_to_wrap: str | None = None,
        deepspeed: str | None = None,
        label_smoothing_factor: float = 0.0,
        optim: str | OptimizerNames = "adamw_torch",
        optim_args: str | None = None,
        adafactor: bool = False,
        group_by_length: bool = False,
        length_column_name: str | None = "length",
        report_to: str | list[str] | None = None,
        ddp_find_unused_parameters: bool | None = None,
        ddp_bucket_cap_mb: int | None = None,
        ddp_broadcast_buffers: bool | None = None,
        dataloader_pin_memory: bool = True,
        skip_memory_metrics: bool = True,
        use_legacy_prediction_loop: bool = False,
        push_to_hub: bool = False,
        resume_from_checkpoint: str | None = None,
        hub_model_id: str | None = None,
        hub_strategy: str | HubStrategy = "every_save",
        hub_token: str | None = None,
        hub_private_repo: bool = False,
        hub_always_push: bool = False,
        gradient_checkpointing: bool = False,
        include_inputs_for_metrics: bool = False,
        auto_find_batch_size: bool = False,
        full_determinism: bool = False,
        torchdynamo: str | None = None,
        ray_scope: str | None = "last",
        ddp_timeout: int = 1800,
        torch_compile: bool = False,
        torch_compile_backend: str | None = None,
        torch_compile_mode: str | None = None,
        dispatch_batches: bool | None = None,
        split_batches: bool | None = None,
        include_tokens_per_second: bool = False,
        # Sentence Transformers specific arguments
        batch_sampler: str = "batch_sampler",
        multi_dataset_batch_sampler: str = "proportional",
        **kwargs
    )
```
`{ .api }`

Training arguments extending the HuggingFace TrainingArguments with sentence-transformer-specific options.

**Key Sentence Transformer Parameters**:
- `batch_sampler`: Strategy for sampling batches from datasets
- `multi_dataset_batch_sampler`: Strategy for multi-dataset batch sampling ("proportional", "round_robin")

## SentenceTransformerModelCardData

```python
class SentenceTransformerModelCardData:
    def __init__(
        self,
        language: str | list[str] | None = None,
        license: str | None = None,
        tags: str | list[str] | None = None,
        model_name: str | None = None,
        model_id: str | None = None,
        eval_results: list[EvalResult] | None = None,
        train_datasets: str | list[str] | None = None,
        eval_datasets: str | list[str] | None = None,
        prior_models: str | list[str] | None = None,
        base_model: str | None = None,
        similarity_fn_name: str | None = None,
        model_max_length: int | None = None
    )
```
`{ .api }`

Data class for generating comprehensive model cards for sentence transformer models.

**Parameters**:
- `language`: Supported language(s)
- `license`: Model license
- `tags`: Categorization tags
- `model_name`: Human-readable model name
- `model_id`: Unique model identifier
- `eval_results`: Evaluation results and benchmarks
- `train_datasets`: Datasets used for training
- `eval_datasets`: Datasets used for evaluation
- `prior_models`: Models used as starting points
- `base_model`: Base transformer model
- `similarity_fn_name`: Default similarity function
- `model_max_length`: Maximum input length

## Batch Samplers

### DefaultBatchSampler

```python
class DefaultBatchSampler:
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool = False,
        generator: torch.Generator | None = None
    )
```
`{ .api }`

Standard batch sampler for single-dataset training.

### MultiDatasetDefaultBatchSampler

```python
class MultiDatasetDefaultBatchSampler:
    def __init__(
        self,
        datasets: dict[str, Dataset],
        batch_sizes: dict[str, int] | int,
        sampling_strategy: str = "proportional",
        generator: torch.Generator | None = None
    )
```
`{ .api }`

Default batch sampler for training over multiple named datasets.

**Parameters**:
- `datasets`: Dictionary mapping dataset names to Dataset objects
- `batch_sizes`: Batch sizes per dataset or single batch size for all
- `sampling_strategy`: How to sample from multiple datasets ("proportional", "round_robin")
- `generator`: Random number generator for reproducibility
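
The two `sampling_strategy` values differ in how many batches each dataset contributes per epoch. This standalone, library-free sketch illustrates the idea (the dataset names and sizes are made up for illustration):

```python
import random

def proportional_order(sizes: dict[str, int], seed: int = 42) -> list[str]:
    """Sample dataset names in proportion to their sizes: every example
    from every dataset is seen once per epoch, so larger datasets
    contribute more batches."""
    pool = [name for name, size in sizes.items() for _ in range(size)]
    random.Random(seed).shuffle(pool)
    return pool

def round_robin_order(sizes: dict[str, int]) -> list[str]:
    """Cycle through the datasets until the smallest is exhausted: every
    dataset contributes equally, but larger datasets are undersampled."""
    order = []
    for _ in range(min(sizes.values())):
        order.extend(sizes.keys())
    return order

sizes = {"similarity": 4, "triplet": 2}
print(len(proportional_order(sizes)))  # 6: all examples used
print(len(round_robin_order(sizes)))   # 4: capped by the smallest dataset
```

The trade-off: "proportional" uses all the data but lets large datasets dominate; "round_robin" balances datasets at the cost of discarding part of the larger ones each epoch.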

## Usage Examples

### Basic Training Setup

```python
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from datasets import Dataset

# Prepare training data
train_data = [
    {"anchor": "The cat sits on the mat", "positive": "A feline rests on a rug"},
    {"anchor": "Python is a programming language", "positive": "Python is used for coding"},
    {"anchor": "Machine learning uses data", "positive": "ML algorithms process datasets"}
]

train_dataset = Dataset.from_list(train_data)

# Initialize model
model = SentenceTransformer('distilbert-base-uncased')

# Define loss function
loss = MultipleNegativesRankingLoss(model)

# Training arguments (no eval_dataset is passed to the trainer below,
# so evaluation-related options are left out)
args = SentenceTransformerTrainingArguments(
    output_dir='./sentence-transformer-training',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    logging_steps=10,
    logging_dir='./logs',
    save_steps=100,
    save_total_limit=2,
    run_name="sentence-transformer-training"
)

# Create trainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss
)

# Train model
trainer.train()

# Save trained model
model.save('./trained-sentence-transformer')
```

### Multi-Dataset Training

```python
from sentence_transformers.losses import CosineSimilarityLoss, TripletLoss

# Prepare multiple datasets
dataset1 = Dataset.from_list([
    {"sentence1": "The cat sits", "sentence2": "A cat is sitting", "label": 1.0},
    {"sentence1": "Dogs are pets", "sentence2": "Cats are animals", "label": 0.3}
])

dataset2 = Dataset.from_list([
    {"anchor": "Python programming", "positive": "Coding in Python", "negative": "Java development"},
    {"anchor": "Machine learning", "positive": "AI algorithms", "negative": "Web design"}
])

# Define different losses for different datasets
loss1 = CosineSimilarityLoss(model)
loss2 = TripletLoss(model)

# Multi-dataset training
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset={"similarity": dataset1, "triplet": dataset2},
    loss={"similarity": loss1, "triplet": loss2}
)

trainer.train()
```

### Advanced Training with Evaluation

```python
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from datasets import Dataset
import numpy as np

# Prepare evaluation data
eval_sentences1 = ["The cat sits on the mat", "I love programming"]
eval_sentences2 = ["A feline rests on a rug", "I enjoy coding"]
eval_scores = [0.9, 0.8]  # Similarity scores

evaluator = EmbeddingSimilarityEvaluator(
    eval_sentences1,
    eval_sentences2,
    eval_scores,
    name="dev"
)

# A small held-out dataset so evaluation_strategy="epoch" has data to run on
eval_dataset = Dataset.from_list([
    {"anchor": "The dog runs in the park", "positive": "A canine jogs across the grass"}
])

def compute_metrics(eval_pred):
    """Custom metrics computation."""
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids

    # Compute custom metrics
    mse = np.mean((predictions - labels) ** 2)
    return {"mse": mse}

# Enhanced training arguments
args = SentenceTransformerTrainingArguments(
    output_dir='./advanced-training',
    num_train_epochs=5,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    push_to_hub=False,
    report_to=["tensorboard"],
    run_name="advanced-sentence-transformer"
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    compute_metrics=compute_metrics
)

# Train, then run the evaluator for a final similarity check
trainer.train()
results = evaluator(model, output_path=args.output_dir + "/evaluation")
```

### Custom Loss Function Training

```python
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

# Matryoshka representation learning
base_loss = MultipleNegativesRankingLoss(model)
matryoshka_loss = MatryoshkaLoss(
    model=model,
    loss=base_loss,
    matryoshka_dims=[768, 512, 256, 128, 64]  # Progressive dimensions
)

# Training with progressive dimensionality
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=matryoshka_loss
)

trainer.train()

# The trained model can now produce embeddings at multiple dimensions
# (on older sentence-transformers releases, pass truncate_dim to the
# SentenceTransformer constructor instead of encode)
embeddings_768 = model.encode(["Test sentence"], truncate_dim=768)
embeddings_256 = model.encode(["Test sentence"], truncate_dim=256)
embeddings_64 = model.encode(["Test sentence"], truncate_dim=64)
```
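
Conceptually, consuming a Matryoshka model at a smaller dimension amounts to keeping the first `k` components of each embedding and re-normalizing. A minimal, model-free sketch of that operation (the example vector is made up):

```python
import math

def truncate_and_normalize(embedding: list[float], dim: int) -> list[float]:
    """Keep the first `dim` components and rescale to unit length,
    which is how lower-dimensional Matryoshka embeddings are consumed."""
    head = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    return [x / norm for x in head]

vec = [0.6, 0.8, 0.1, 0.05]  # pretend full-size embedding
small = truncate_and_normalize(vec, 2)
print(len(small), sum(x * x for x in small))  # 2 components, unit norm
```

Because the loss trained each prefix to be useful on its own, the truncated vectors remain meaningful for similarity search at a fraction of the storage cost.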

### Distributed Training

```python
# For multi-GPU training
args = SentenceTransformerTrainingArguments(
    output_dir='./distributed-training',
    num_train_epochs=3,
    per_device_train_batch_size=16,  # Per-GPU batch size
    gradient_accumulation_steps=4,   # Effective batch size: 16 * 4 * num_gpus
    dataloader_num_workers=4,
    ddp_find_unused_parameters=False,
    fp16=True,  # Mixed precision training
    logging_steps=10,
    save_steps=500,
    evaluation_strategy="steps",
    eval_steps=500,
    warmup_ratio=0.1,
    learning_rate=2e-5,
    run_name="distributed-training"
)

# The trainer automatically handles multi-GPU setup
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss
)

# Launch with: torchrun --nproc_per_node=2 train_script.py
trainer.train()
```
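
The effective batch size noted in the comment above is a simple product; a quick helper makes the arithmetic explicit (the values mirror the configuration sketched here, with 2 GPUs assumed):

```python
import math

def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Examples contributing to each optimizer step."""
    return per_device * grad_accum * num_gpus

def optimizer_steps_per_epoch(dataset_size: int, per_device: int,
                              grad_accum: int, num_gpus: int) -> int:
    """Optimizer updates per pass over the data (with drop_last=False)."""
    return math.ceil(dataset_size / effective_batch_size(per_device, grad_accum, num_gpus))

print(effective_batch_size(16, 4, 2))               # 128
print(optimizer_steps_per_epoch(10_000, 16, 4, 2))  # 79
```

Keeping the effective batch size fixed while varying GPU count or accumulation steps is the usual way to make runs comparable across hardware.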

### Hyperparameter Optimization

```python
from ray import tune

def model_init():
    """Initialize model for hyperparameter search."""
    return SentenceTransformer('distilbert-base-uncased')

def hp_space(trial):
    """Define hyperparameter search space."""
    return {
        "learning_rate": tune.loguniform(1e-6, 1e-4),
        "per_device_train_batch_size": tune.choice([16, 32, 64]),
        "warmup_ratio": tune.uniform(0.0, 0.3),
        "weight_decay": tune.uniform(0.0, 0.3),
    }

# Hyperparameter search: the loss class acts as a factory that is
# called with each freshly initialized model
trainer = SentenceTransformerTrainer(
    model_init=model_init,
    args=args,
    train_dataset=train_dataset,
    loss=MultipleNegativesRankingLoss,
    compute_metrics=compute_metrics
)

best_trial = trainer.hyperparameter_search(
    hp_space=hp_space,
    compute_objective=lambda metrics: metrics["eval_loss"],
    n_trials=10,
    direction="minimize",
    backend="ray"
)

print(f"Best hyperparameters: {best_trial.hyperparameters}")
```

### Custom Data Collator

```python
from typing import Dict, List, Any
import torch

class CustomDataCollator:
    """Custom data collator for sentence transformer training."""

    def __init__(self, tokenizer, padding=True, max_length=None):
        self.tokenizer = tokenizer
        self.padding = padding
        self.max_length = max_length

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Custom processing of training examples: extract texts
        texts = []
        for feature in features:
            if 'anchor' in feature and 'positive' in feature:
                texts.extend([feature['anchor'], feature['positive']])
                if 'negative' in feature:
                    texts.append(feature['negative'])
            elif 'sentence1' in feature and 'sentence2' in feature:
                texts.extend([feature['sentence1'], feature['sentence2']])

        # Tokenize
        tokenized = self.tokenizer(
            texts,
            padding=self.padding,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        return tokenized

# Use custom data collator
custom_collator = CustomDataCollator(model.tokenizer, max_length=512)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=custom_collator,
    loss=loss
)
```

### Training Callbacks

```python
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

class CustomCallback(TrainerCallback):
    """Custom training callback for monitoring."""

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        print("Training started!")

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Custom logic at epoch end
        model = kwargs.get('model')
        if model:
            # Evaluate on custom data or log additional metrics
            pass

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        print(f"Model saved at step {state.global_step}")

# Add callbacks to trainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
    callbacks=[CustomCallback()]
)
```

### Model Card Generation

```python
from sentence_transformers import SentenceTransformer, SentenceTransformerModelCardData

# Create comprehensive model card metadata
model_card_data = SentenceTransformerModelCardData(
    language=['en'],
    license='apache-2.0',
    tags=['sentence-transformers', 'sentence-similarity', 'embeddings'],
    model_name='Custom Sentence Transformer',
    base_model='distilbert-base-uncased',
    train_datasets=['custom-similarity-dataset'],
    eval_datasets=['sts-benchmark'],
    similarity_fn_name='cosine',
    model_max_length=512
)

# Attach the card data when constructing the model, so it is picked up
# automatically when saving or pushing
model = SentenceTransformer('distilbert-base-uncased', model_card_data=model_card_data)

# ... train with SentenceTransformerTrainer ...

# Save model with card
trainer.save_model('./final-model')

# Push to hub with model card
model.push_to_hub('my-username/my-sentence-transformer')
```

## Best Practices

1. **Data Preparation**: Ensure high-quality, diverse training data
2. **Loss Selection**: Choose appropriate loss functions for your task
3. **Batch Size**: Use larger batch sizes when possible for contrastive learning
4. **Learning Rate**: Start with 2e-5 and adjust based on model size
5. **Evaluation**: Use domain-relevant evaluation metrics during training
6. **Checkpointing**: Save regular checkpoints and use early stopping
7. **Multi-Dataset**: Combine multiple datasets for robust representations
8. **Progressive Training**: Consider curriculum learning or progressive approaches
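
Point 3 follows directly from how in-batch negatives work with losses like MultipleNegativesRankingLoss: every other positive in the batch serves as a negative for a given anchor, so a batch of size B supplies B - 1 negatives per anchor and a full B x B score matrix. A small counting sketch:

```python
def negatives_per_anchor(batch_size: int) -> int:
    """In-batch negatives available to each anchor."""
    return batch_size - 1

def scores_per_batch(batch_size: int) -> int:
    """Anchor-candidate similarities computed per batch (a B x B matrix)."""
    return batch_size * batch_size

for b in (16, 64, 256):
    print(b, negatives_per_anchor(b), scores_per_batch(b))
```

Doubling the batch size roughly doubles the negatives each anchor must discriminate against, which is why contrastive objectives tend to benefit from the largest batch that fits in memory.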