# Core Model and Training

Main model classes and training functionality for few-shot text classification with sentence transformers. These components form the foundation of SetFit's approach to efficient few-shot learning.

## Capabilities

### SetFit Model

The main model class that combines a sentence transformer for embedding generation with a classification head for predictions.

```python { .api }
11
class SetFitModel:
12
def __init__(
13
self,
14
model_body: Optional[SentenceTransformer] = None,
15
model_head: Optional[Union[SetFitHead, LogisticRegression]] = None,
16
multi_target_strategy: Optional[str] = None,
17
normalize_embeddings: bool = False,
18
labels: Optional[List[str]] = None,
19
model_card_data: Optional[SetFitModelCardData] = None,
20
sentence_transformers_kwargs: Optional[Dict] = None
21
):
22
"""
23
Initialize a SetFit model with sentence transformer and classification head.
24
25
Parameters:
26
- model_body: Pre-trained sentence transformer model for embeddings
27
- model_head: Classification head (sklearn LogisticRegression or SetFitHead)
28
- multi_target_strategy: Strategy for multi-label classification ("one-vs-rest", "multi-output", "classifier-chain")
29
- normalize_embeddings: Whether to normalize embeddings before classification
30
- labels: List of label names for interpretation
31
- model_card_data: Metadata for model card generation
32
- sentence_transformers_kwargs: Additional arguments for sentence transformer
33
"""
34
35
def fit(
36
self,
37
x_train: List[str],
38
y_train: Union[List[int], List[List[int]]],
39
num_epochs: int,
40
batch_size: Optional[int] = None,
41
body_learning_rate: Optional[float] = None,
42
head_learning_rate: Optional[float] = None,
43
end_to_end: bool = False,
44
l2_weight: Optional[float] = None,
45
max_length: Optional[int] = None,
46
show_progress_bar: bool = True
47
):
48
"""
49
Fit the SetFit model on training data.
50
51
Parameters:
52
- x_train: Training texts (list of strings)
53
- y_train: Training labels (a list of integers, or a list of integer lists for multi-label)
54
- num_epochs: Number of training epochs
55
- batch_size: Training batch size (optional)
56
- body_learning_rate: Learning rate for sentence transformer body (optional)
57
- head_learning_rate: Learning rate for classification head (optional)
58
- end_to_end: Whether to perform end-to-end training
59
- l2_weight: L2 regularization weight (optional)
60
- max_length: Maximum sequence length for tokenization (optional)
61
- show_progress_bar: Whether to show training progress bar
62
"""
63
64
def predict(
65
self,
66
inputs: Union[str, List[str]],
67
batch_size: int = 32,
68
as_numpy: bool = False,
69
use_labels: bool = True,
70
show_progress_bar: Optional[bool] = None
71
) -> Union[torch.Tensor, np.ndarray, List[str], int, str]:
72
"""
73
Make predictions on test data.
74
75
Parameters:
76
- inputs: Input text(s) to predict (single string or list of strings)
77
- batch_size: Batch size for prediction (default: 32)
78
- as_numpy: Return predictions as numpy array instead of torch tensor
79
- use_labels: Return label names instead of integers (if labels available)
80
- show_progress_bar: Whether to show progress bar during prediction
81
82
Returns:
83
Predicted class labels (format depends on parameters)
84
"""
85
86
def predict_proba(
87
self,
88
inputs: Union[str, List[str]],
89
batch_size: int = 32,
90
as_numpy: bool = False,
91
show_progress_bar: Optional[bool] = None
92
) -> Union[torch.Tensor, np.ndarray]:
93
"""
94
Get prediction probabilities for test data.
95
96
Parameters:
97
- inputs: Input text(s) to predict (single string or list of strings)
98
- batch_size: Batch size for prediction (default: 32)
99
- as_numpy: Return probabilities as numpy array instead of torch tensor
100
- show_progress_bar: Whether to show progress bar during prediction
101
102
Returns:
103
Prediction probabilities for each class
104
"""
105
106
def encode(
107
self,
108
inputs: List[str],
109
batch_size: int = 32,
110
show_progress_bar: Optional[bool] = None
111
) -> Union[torch.Tensor, np.ndarray]:
112
"""
113
Generate embeddings for input texts.
114
115
Parameters:
116
- inputs: Input texts (list of strings)
117
- batch_size: Batch size for encoding (default: 32)
118
- show_progress_bar: Whether to show progress bar during encoding
119
120
Returns:
121
Text embeddings as tensor or numpy array
122
"""
123
124
@classmethod
125
def from_pretrained(
126
cls,
127
model_id: str,
128
revision: Optional[str] = None,
129
cache_dir: Optional[str] = None,
130
force_download: bool = False,
131
local_files_only: bool = False,
132
token: Optional[str] = None,
133
**kwargs
134
):
135
"""
136
Load a pre-trained SetFit model from Hugging Face Hub or local path.
137
138
Parameters:
139
- model_id: Model identifier or local path
140
- revision: Model revision/branch to use
141
- cache_dir: Directory to cache downloaded models
142
- force_download: Force re-download even if cached
143
- local_files_only: Only use local files, no downloads
144
- token: Hugging Face access token for private models
145
"""
146
147
def save_pretrained(self, save_directory: str, **kwargs):
148
"""
149
Save the model to a directory.
150
151
Parameters:
152
- save_directory: Directory path to save model files
153
"""
154
155
@property
156
def device(self):
157
"""Get the device (CPU/GPU) the model is on."""
158
159
@property
160
def has_differentiable_head(self) -> bool:
161
"""Check if model uses a differentiable (PyTorch) head."""
162
163
@property
164
def id2label(self) -> Dict[int, str]:
165
"""Mapping from label IDs to label names."""
166
167
@property
168
def label2id(self) -> Dict[str, int]:
169
"""Mapping from label names to label IDs."""
170
```

### SetFit Head

Differentiable classification head for end-to-end training with sentence transformers.

```python { .api }
177
class SetFitHead:
178
def __init__(
179
self,
180
in_features: Optional[int] = None,
181
out_features: int = 2,
182
temperature: float = 1.0,
183
eps: float = 1e-5,
184
bias: bool = True,
185
device: Optional[Union[torch.device, str]] = None,
186
multitarget: bool = False
187
):
188
"""
189
Initialize a differentiable classification head.
190
191
Parameters:
192
- in_features: Number of input features (embedding dimension)
193
- out_features: Number of output classes
194
- temperature: Temperature for softmax normalization
195
- eps: Small epsilon for numerical stability
196
- bias: Whether to use bias in linear layer
197
- device: Device to place the model on
198
- multitarget: Whether this is for multi-label classification
199
"""
200
201
def forward(self, features: torch.Tensor) -> torch.Tensor:
202
"""
203
Forward pass through the classification head.
204
205
Parameters:
206
- features: Input embeddings tensor
207
208
Returns:
209
Logits tensor
210
"""
211
212
def predict(self, features: torch.Tensor) -> torch.Tensor:
213
"""
214
Get class predictions from features.
215
216
Parameters:
217
- features: Input embeddings tensor
218
219
Returns:
220
Predicted class indices
221
"""
222
223
def predict_proba(self, features: torch.Tensor) -> torch.Tensor:
224
"""
225
Get prediction probabilities from features.
226
227
Parameters:
228
- features: Input embeddings tensor
229
230
Returns:
231
Class probabilities tensor
232
"""
233
```

### SetFit Trainer

Main trainer class for training SetFit models with comprehensive training configuration and monitoring.

```python { .api }
240
class SetFitTrainer:
241
def __init__(
242
self,
243
model: Optional[SetFitModel] = None,
244
args: Optional[TrainingArguments] = None,
245
train_dataset: Optional[Dataset] = None,
246
eval_dataset: Optional[Dataset] = None,
247
model_init: Optional[Callable[[], SetFitModel]] = None,
248
compute_metrics: Optional[Callable] = None,
249
callbacks: Optional[List] = None,
250
optimizers: Optional[Tuple] = None,
251
preprocess_logits_for_metrics: Optional[Callable] = None,
252
column_mapping: Optional[Dict[str, str]] = None
253
):
254
"""
255
Initialize a SetFit trainer.
256
257
Parameters:
258
- model: SetFit model to train
259
- args: Training arguments and hyperparameters
260
- train_dataset: Training dataset (HuggingFace Dataset)
261
- eval_dataset: Evaluation dataset (HuggingFace Dataset)
262
- model_init: Function to initialize model (for hyperparameter search)
263
- compute_metrics: Function to compute evaluation metrics
264
- callbacks: List of training callbacks
265
- optimizers: Custom optimizers (body_optimizer, head_optimizer)
266
- preprocess_logits_for_metrics: Function to preprocess logits before metrics
267
- column_mapping: Mapping of dataset columns to expected names
268
"""
269
270
def train(self) -> None:
271
"""
272
Train the SetFit model using the configured training arguments.
273
274
Performs two-phase training:
275
1. Fine-tune sentence transformer on contrastive pairs
276
2. Train classification head on embeddings
277
"""
278
279
def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
280
"""
281
Evaluate the model on evaluation dataset.
282
283
Parameters:
284
- eval_dataset: Evaluation dataset (uses trainer's eval_dataset if None)
285
286
Returns:
287
Dictionary of evaluation metrics
288
"""
289
290
def predict(self, test_dataset: Dataset) -> "PredictionOutput":
291
"""
292
Generate predictions on test dataset.
293
294
Parameters:
295
- test_dataset: Test dataset
296
297
Returns:
298
Predictions and optionally metrics
299
"""
300
301
def hyperparameter_search(
302
self,
303
hp_space: Optional[Callable] = None,
304
compute_objective: Optional[Callable] = None,
305
n_trials: int = 20,
306
direction: str = "maximize",
307
backend: Optional[str] = None,
308
hp_name: Optional[Callable] = None,
309
**kwargs
310
):
311
"""
312
Perform hyperparameter search using Optuna.
313
314
Parameters:
315
- hp_space: Function defining hyperparameter search space
316
- compute_objective: Function to compute optimization objective
317
- n_trials: Number of trials to run
318
- direction: Optimization direction ("maximize" or "minimize")
319
- backend: Backend for hyperparameter search
320
- hp_name: Function to generate trial names
321
"""
322
```

### Training Arguments

Comprehensive configuration class for training hyperparameters and settings.

```python { .api }
329
class TrainingArguments:
330
def __init__(
331
self,
332
output_dir: str = "./results",
333
batch_size: int = 16,
334
num_epochs: Union[int, Tuple[int, int]] = 1,
335
max_steps: Union[int, Tuple[int, int]] = -1,
336
sampling_strategy: str = "oversampling",
337
learning_rate: Union[float, Tuple[float, float]] = 2e-5,
338
loss: Callable = None,
339
distance_metric: Callable = None,
340
margin: float = 0.25,
341
use_amp: bool = False,
342
warmup_proportion: float = 0.1,
343
l2_weight: float = 0.01,
344
max_length: int = 512,
345
show_progress_bar: bool = True,
346
seed: int = 42,
347
use_differentiable_head: bool = False,
348
normalize_embeddings: bool = False,
349
eval_strategy: str = "no",
350
eval_steps: int = 500,
351
eval_max_steps: int = -1,
352
eval_delay: float = 0,
353
load_best_model_at_end: bool = False,
354
metric_for_best_model: str = "eval_loss",
355
greater_is_better: bool = False,
356
run_name: Optional[str] = None,
357
logging_dir: Optional[str] = None,
358
logging_strategy: str = "steps",
359
logging_steps: int = 500,
360
save_strategy: str = "steps",
361
save_steps: int = 500,
362
save_total_limit: Optional[int] = None,
363
no_cuda: bool = False,
364
dataloader_drop_last: bool = False,
365
dataloader_num_workers: int = 0,
366
dataloader_pin_memory: bool = True,
367
**kwargs
368
):
369
"""
370
Training arguments for SetFit model training.
371
372
Parameters:
373
- output_dir: Directory to save model outputs and logs
374
- batch_size: Training batch size
375
- num_epochs: Number of training epochs; an int applies to both phases, a (body, head) tuple sets them separately
376
- max_steps: Maximum training steps (overrides num_epochs if > 0)
377
- sampling_strategy: Strategy for sampling training pairs ("oversampling", "undersampling", "unique")
378
- learning_rate: Learning rate; a float applies to both phases, a (body, head) tuple sets them separately
379
- loss: Custom loss function for contrastive learning
380
- distance_metric: Distance metric for similarity computation
381
- margin: Margin for triplet loss
382
- use_amp: Use automatic mixed precision training
383
- warmup_proportion: Proportion of steps for learning rate warmup
384
- l2_weight: L2 regularization weight
385
- max_length: Maximum sequence length for tokenization
386
- show_progress_bar: Show progress bar during training
387
- seed: Random seed for reproducibility
388
- use_differentiable_head: Use PyTorch head instead of sklearn
389
- normalize_embeddings: Normalize embeddings before classification
390
- eval_strategy: Evaluation strategy ("no", "steps", "epoch")
391
- eval_steps: Number of steps between evaluations
392
- eval_max_steps: Maximum steps for evaluation
393
- eval_delay: Number of epochs or steps (depending on eval_strategy) to wait before the first evaluation
394
- load_best_model_at_end: Load the best model, as ranked by metric_for_best_model, at the end of training
395
- metric_for_best_model: Metric to use for best model selection
396
- greater_is_better: Whether a larger value of metric_for_best_model indicates a better model
397
- run_name: Name for the training run (for logging)
398
- logging_dir: Directory for training logs
399
- logging_strategy: When to log ("no", "steps", "epoch")
400
- logging_steps: Number of steps between logging
401
- save_strategy: When to save checkpoints ("no", "steps", "epoch")
402
- save_steps: Number of steps between saves
403
- save_total_limit: Maximum number of checkpoints to keep
404
- no_cuda: Disable CUDA even if available
405
- dataloader_drop_last: Drop last incomplete batch
406
- dataloader_num_workers: Number of dataloader workers
407
- dataloader_pin_memory: Pin memory in dataloader for faster GPU transfer
408
"""
409
```

## Usage Examples

### Basic Training Pipeline

```python
416
from setfit import SetFitModel, SetFitTrainer, TrainingArguments
417
from datasets import Dataset
418
from sklearn.metrics import accuracy_score
419
420
# Prepare dataset
421
train_dataset = Dataset.from_dict({
422
"text": ["Great movie!", "Terrible film.", "Love it!", "Hate it."],
423
"label": [1, 0, 1, 0]
424
})
425
426
eval_dataset = Dataset.from_dict({
427
"text": ["Good film.", "Not good."],
428
"label": [1, 0]
429
})
430
431
# Initialize model
432
model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
433
434
# Configure training
435
args = TrainingArguments(
436
batch_size=16,
437
num_epochs=(2, 16), # 2 epochs for body, 16 for head
438
learning_rate=(2e-5, 1e-3), # Different rates for body/head
439
eval_strategy="epoch",
440
save_strategy="epoch",
441
load_best_model_at_end=True,
442
metric_for_best_model="eval_accuracy",
443
greater_is_better=True
444
)
445
446
def compute_metrics(eval_pred):
447
predictions, labels = eval_pred
448
return {"accuracy": accuracy_score(labels, predictions)}
449
450
# Create trainer
451
trainer = SetFitTrainer(
452
model=model,
453
args=args,
454
train_dataset=train_dataset,
455
eval_dataset=eval_dataset,
456
compute_metrics=compute_metrics,
457
column_mapping={"text": "text", "label": "label"}
458
)
459
460
# Train and evaluate
461
trainer.train()
462
results = trainer.evaluate()
463
print(f"Final accuracy: {results['eval_accuracy']:.3f}")
464
```

### Using Differentiable Head

```python
469
from setfit import SetFitModel, SetFitHead, TrainingArguments, SetFitTrainer
470
471
# Create model with differentiable head
472
model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
473
model.model_head = SetFitHead(
474
in_features=384, # Embedding dimension
475
out_features=3, # Number of classes
476
temperature=0.1 # Lower temperature for sharper predictions
477
)
478
479
# Configure for end-to-end training
480
args = TrainingArguments(
481
use_differentiable_head=True,
482
batch_size=32,
483
num_epochs=5,
484
learning_rate=2e-5,
485
warmup_proportion=0.1,
486
use_amp=True # Use mixed precision for speed
487
)
488
489
trainer = SetFitTrainer(model=model, args=args, train_dataset=train_dataset)
490
trainer.train()
491
```

### Hyperparameter Search

```python
496
from setfit.integrations import default_hp_space_optuna
497
498
def model_init():
499
return SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
500
501
trainer = SetFitTrainer(
502
model_init=model_init,
503
args=args,
504
train_dataset=train_dataset,
505
eval_dataset=eval_dataset,
506
compute_metrics=compute_metrics
507
)
508
509
# Run hyperparameter search
510
best_trial = trainer.hyperparameter_search(
511
hp_space=default_hp_space_optuna,
512
n_trials=10,
513
direction="maximize"
514
)
515
516
print(f"Best hyperparameters: {best_trial.hyperparameters}")
517
print(f"Best score: {best_trial.objective}")
518
```