"""Efficient few-shot learning with Sentence Transformers.

Main model classes and training functionality for few-shot text classification
with sentence transformers. These components form the foundation of SetFit's
approach to efficient few-shot learning.
"""

# The main model class that combines a sentence transformer for embedding
# generation with a classification head for predictions.
class SetFitModel:
    """Few-shot text classifier combining a sentence transformer body
    (embedding generation) with a classification head (predictions).
    """

    def __init__(
        self,
        model_body: Optional[SentenceTransformer] = None,
        model_head: Optional[Union[SetFitHead, LogisticRegression]] = None,
        multi_target_strategy: Optional[str] = None,
        normalize_embeddings: bool = False,
        labels: Optional[List[str]] = None,
        model_card_data: Optional[SetFitModelCardData] = None,
        sentence_transformers_kwargs: Optional[Dict] = None,
    ):
        """
        Initialize a SetFit model with sentence transformer and classification head.

        Parameters:
        - model_body: Pre-trained sentence transformer model for embeddings
        - model_head: Classification head (sklearn LogisticRegression or SetFitHead)
        - multi_target_strategy: Strategy for multi-label classification
          ("one-vs-rest", "multi-output", "classifier-chain")
        - normalize_embeddings: Whether to normalize embeddings before classification
        - labels: List of label names for interpretation
        - model_card_data: Metadata for model card generation
        - sentence_transformers_kwargs: Additional arguments for sentence transformer
        """

    def fit(
        self,
        x_train: List[str],
        y_train: Union[List[int], List[List[int]]],
        num_epochs: int,
        batch_size: Optional[int] = None,
        body_learning_rate: Optional[float] = None,
        head_learning_rate: Optional[float] = None,
        end_to_end: bool = False,
        l2_weight: Optional[float] = None,
        max_length: Optional[int] = None,
        show_progress_bar: bool = True,
    ):
        """
        Fit the SetFit model on training data.

        Parameters:
        - x_train: Training texts (list of strings)
        - y_train: Training labels (list of integers, or lists of integers for
          multi-label classification)
        - num_epochs: Number of training epochs
        - batch_size: Training batch size (optional)
        - body_learning_rate: Learning rate for sentence transformer body (optional)
        - head_learning_rate: Learning rate for classification head (optional)
        - end_to_end: Whether to perform end-to-end training
        - l2_weight: L2 regularization weight (optional)
        - max_length: Maximum sequence length for tokenization (optional)
        - show_progress_bar: Whether to show training progress bar
        """

    def predict(
        self,
        inputs: Union[str, List[str]],
        batch_size: int = 32,
        as_numpy: bool = False,
        use_labels: bool = True,
        show_progress_bar: Optional[bool] = None,
    ) -> Union[torch.Tensor, np.ndarray, List[str], int, str]:
        """
        Make predictions on test data.

        Parameters:
        - inputs: Input text(s) to predict (single string or list of strings)
        - batch_size: Batch size for prediction (default: 32)
        - as_numpy: Return predictions as numpy array instead of torch tensor
        - use_labels: Return label names instead of integers (if labels available)
        - show_progress_bar: Whether to show progress bar during prediction

        Returns:
            Predicted class labels (format depends on parameters)
        """

    def predict_proba(
        self,
        inputs: Union[str, List[str]],
        batch_size: int = 32,
        as_numpy: bool = False,
        show_progress_bar: Optional[bool] = None,
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Get prediction probabilities for test data.

        Parameters:
        - inputs: Input text(s) to predict (single string or list of strings)
        - batch_size: Batch size for prediction (default: 32)
        - as_numpy: Return probabilities as numpy array instead of torch tensor
        - show_progress_bar: Whether to show progress bar during prediction

        Returns:
            Prediction probabilities for each class
        """

    def encode(
        self,
        inputs: List[str],
        batch_size: int = 32,
        show_progress_bar: Optional[bool] = None,
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Generate embeddings for input texts.

        Parameters:
        - inputs: Input texts (list of strings)
        - batch_size: Batch size for encoding (default: 32)
        - show_progress_bar: Whether to show progress bar during encoding

        Returns:
            Text embeddings as tensor or numpy array
        """

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[str] = None,
        **kwargs,
    ):
        """
        Load a pre-trained SetFit model from Hugging Face Hub or local path.

        Parameters:
        - model_id: Model identifier or local path
        - revision: Model revision/branch to use
        - cache_dir: Directory to cache downloaded models
        - force_download: Force re-download even if cached
        - local_files_only: Only use local files, no downloads
        - token: Hugging Face access token for private models

        Returns:
            The loaded SetFitModel instance.
        """

    def save_pretrained(self, save_directory: str, **kwargs):
        """
        Save the model to a directory.

        Parameters:
        - save_directory: Directory path to save model files
        """

    @property
    def device(self):
        """Get the device (CPU/GPU) the model is on."""

    @property
    def has_differentiable_head(self) -> bool:
        """Check if model uses a differentiable (PyTorch) head."""

    @property
    def id2label(self) -> Dict[int, str]:
        """Mapping from label IDs to label names."""

    @property
    def label2id(self) -> Dict[str, int]:
        """Mapping from label names to label IDs."""


# Differentiable classification head for end-to-end training with
# sentence transformers.
class SetFitHead:
    """Differentiable classification head for end-to-end training with
    sentence transformers.
    """

    def __init__(
        self,
        in_features: Optional[int] = None,
        out_features: int = 2,
        temperature: float = 1.0,
        eps: float = 1e-5,
        bias: bool = True,
        device: Optional[Union[torch.device, str]] = None,
        multitarget: bool = False,
    ):
        """
        Initialize a differentiable classification head.

        Parameters:
        - in_features: Number of input features (embedding dimension)
        - out_features: Number of output classes
        - temperature: Temperature for softmax normalization
        - eps: Small epsilon for numerical stability
        - bias: Whether to use bias in linear layer
        - device: Device to place the model on
        - multitarget: Whether this is for multi-label classification
        """

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the classification head.

        Parameters:
        - features: Input embeddings tensor

        Returns:
            Logits tensor
        """

    def predict(self, features: torch.Tensor) -> torch.Tensor:
        """
        Get class predictions from features.

        Parameters:
        - features: Input embeddings tensor

        Returns:
            Predicted class indices
        """

    def predict_proba(self, features: torch.Tensor) -> torch.Tensor:
        """
        Get prediction probabilities from features.

        Parameters:
        - features: Input embeddings tensor

        Returns:
            Class probabilities tensor
        """


# Main trainer class for training SetFit models with comprehensive
# training configuration and monitoring.
class SetFitTrainer:
    """Main trainer class for training SetFit models with comprehensive
    training configuration and monitoring.
    """

    def __init__(
        self,
        model: Optional[SetFitModel] = None,
        args: Optional[TrainingArguments] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        model_init: Optional[Callable[[], SetFitModel]] = None,
        compute_metrics: Optional[Callable] = None,
        callbacks: Optional[List] = None,
        optimizers: Optional[Tuple] = None,
        preprocess_logits_for_metrics: Optional[Callable] = None,
        column_mapping: Optional[Dict[str, str]] = None,
    ):
        """
        Initialize a SetFit trainer.

        Parameters:
        - model: SetFit model to train
        - args: Training arguments and hyperparameters
        - train_dataset: Training dataset (HuggingFace Dataset)
        - eval_dataset: Evaluation dataset (HuggingFace Dataset)
        - model_init: Function to initialize model (for hyperparameter search)
        - compute_metrics: Function to compute evaluation metrics
        - callbacks: List of training callbacks
        - optimizers: Custom optimizers (body_optimizer, head_optimizer)
        - preprocess_logits_for_metrics: Function to preprocess logits before metrics
        - column_mapping: Mapping of dataset columns to expected names
        """

    def train(self) -> None:
        """
        Train the SetFit model using the configured training arguments.

        Performs two-phase training:
        1. Fine-tune sentence transformer on contrastive pairs
        2. Train classification head on embeddings
        """

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Evaluate the model on evaluation dataset.

        Parameters:
        - eval_dataset: Evaluation dataset (uses trainer's eval_dataset if None)

        Returns:
            Dictionary of evaluation metrics
        """

    def predict(self, test_dataset: Dataset) -> "PredictionOutput":
        """
        Generate predictions on test dataset.

        Parameters:
        - test_dataset: Test dataset

        Returns:
            Predictions and optionally metrics
        """

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable] = None,
        compute_objective: Optional[Callable] = None,
        n_trials: int = 20,
        direction: str = "maximize",
        backend: Optional[str] = None,
        hp_name: Optional[Callable] = None,
        **kwargs,
    ):
        """
        Perform hyperparameter search using Optuna.

        Parameters:
        - hp_space: Function defining hyperparameter search space
        - compute_objective: Function to compute optimization objective
        - n_trials: Number of trials to run
        - direction: Optimization direction ("maximize" or "minimize")
        - backend: Backend for hyperparameter search
        - hp_name: Function to generate trial names
        """


# Comprehensive configuration class for training hyperparameters and settings.
class TrainingArguments:
def __init__(
self,
output_dir: str = "./results",
batch_size: int = 16,
num_epochs: Union[int, Tuple[int, int]] = 1,
max_steps: Union[int, Tuple[int, int]] = -1,
sampling_strategy: str = "oversampling",
learning_rate: Union[float, Tuple[float, float]] = 2e-5,
loss: Callable = None,
distance_metric: Callable = None,
margin: float = 0.25,
use_amp: bool = False,
warmup_proportion: float = 0.1,
l2_weight: float = 0.01,
max_length: int = 512,
show_progress_bar: bool = True,
seed: int = 42,
use_differentiable_head: bool = False,
normalize_embeddings: bool = False,
eval_strategy: str = "no",
eval_steps: int = 500,
eval_max_steps: int = -1,
eval_delay: float = 0,
load_best_model_at_end: bool = False,
metric_for_best_model: str = "eval_loss",
greater_is_better: bool = False,
run_name: Optional[str] = None,
logging_dir: Optional[str] = None,
logging_strategy: str = "steps",
logging_steps: int = 500,
save_strategy: str = "steps",
save_steps: int = 500,
save_total_limit: Optional[int] = None,
no_cuda: bool = False,
dataloader_drop_last: bool = False,
dataloader_num_workers: int = 0,
dataloader_pin_memory: bool = True,
**kwargs
):
"""
Training arguments for SetFit model training.
Parameters:
- output_dir: Directory to save model outputs and logs
- batch_size: Training batch size
- num_epochs: Number of training epochs (can be tuple for body/head)
- max_steps: Maximum training steps (overrides num_epochs if > 0)
- sampling_strategy: Strategy for sampling training pairs ("oversampling", "undersampling", "unique")
- learning_rate: Learning rate (can be tuple for body/head)
- loss: Custom loss function for contrastive learning
- distance_metric: Distance metric for similarity computation
- margin: Margin for triplet loss
- use_amp: Use automatic mixed precision training
- warmup_proportion: Proportion of steps for learning rate warmup
- l2_weight: L2 regularization weight
- max_length: Maximum sequence length for tokenization
- show_progress_bar: Show progress bar during training
- seed: Random seed for reproducibility
- use_differentiable_head: Use PyTorch head instead of sklearn
- normalize_embeddings: Normalize embeddings before classification
- eval_strategy: Evaluation strategy ("no", "steps", "epoch")
- eval_steps: Number of steps between evaluations
- eval_max_steps: Maximum steps for evaluation
- eval_delay: Delay before starting evaluation
- load_best_model_at_end: Load best model based on metric at end
- metric_for_best_model: Metric to use for best model selection
- greater_is_better: Whether greater metric value is better
- run_name: Name for the training run (for logging)
- logging_dir: Directory for training logs
- logging_strategy: When to log ("no", "steps", "epoch")
- logging_steps: Number of steps between logging
- save_strategy: When to save checkpoints ("no", "steps", "epoch")
- save_steps: Number of steps between saves
- save_total_limit: Maximum number of checkpoints to keep
- no_cuda: Disable CUDA even if available
- dataloader_drop_last: Drop last incomplete batch
- dataloader_num_workers: Number of dataloader workers
- dataloader_pin_memory: Pin memory in dataloader for faster GPU transfer
"""from setfit import SetFitModel, SetFitTrainer, TrainingArguments
from datasets import Dataset
from sklearn.metrics import accuracy_score

# Prepare a tiny binary sentiment dataset.
train_dataset = Dataset.from_dict({
    "text": ["Great movie!", "Terrible film.", "Love it!", "Hate it."],
    "label": [1, 0, 1, 0],
})
eval_dataset = Dataset.from_dict({
    "text": ["Good film.", "Not good."],
    "label": [1, 0],
})

# Initialize model from a pretrained sentence-transformer checkpoint.
model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Configure training.
args = TrainingArguments(
    batch_size=16,
    num_epochs=(2, 16),          # 2 epochs for body, 16 for head
    learning_rate=(2e-5, 1e-3),  # Different rates for body/head
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    greater_is_better=True,
)


def compute_metrics(eval_pred):
    """Return accuracy computed from a (predictions, labels) pair."""
    predictions, labels = eval_pred
    return {"accuracy": accuracy_score(labels, predictions)}


# Create trainer
trainer = SetFitTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    column_mapping={"text": "text", "label": "label"},
)

# Train and evaluate
trainer.train()
results = trainer.evaluate()
print(f"Final accuracy: {results['eval_accuracy']:.3f}")

# --- Example 2: end-to-end training with a differentiable head ---
from setfit import SetFitModel, SetFitHead, TrainingArguments, SetFitTrainer
# Create model with differentiable head
model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model.model_head = SetFitHead(
    in_features=384,  # Embedding dimension
    out_features=3,   # Number of classes
    temperature=0.1,  # Lower temperature for sharper predictions
)

# Configure for end-to-end training
args = TrainingArguments(
    use_differentiable_head=True,
    batch_size=32,
    num_epochs=5,
    learning_rate=2e-5,
    warmup_proportion=0.1,
    use_amp=True,  # Use mixed precision for speed
)

trainer = SetFitTrainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()

# --- Example 3: hyperparameter search with Optuna ---
from setfit.integrations import default_hp_space_optuna
def model_init():
    """Create a fresh model so each search trial starts from scratch."""
    return SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")


trainer = SetFitTrainer(
    model_init=model_init,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

# Run hyperparameter search
best_trial = trainer.hyperparameter_search(
    hp_space=default_hp_space_optuna,
    n_trials=10,
    direction="maximize",
)
print(f"Best hyperparameters: {best_trial.hyperparameters}")
print(f"Best score: {best_trial.objective}")

# Install with Tessl CLI
# npx tessl i tessl/pypi-setfit