Efficient few-shot learning with Sentence Transformers
npx @tessl/cli install tessl/pypi-setfit@1.1.0

SetFit is an efficient, prompt-free framework for few-shot fine-tuning of Sentence Transformers that achieves high accuracy with minimal labeled data. It eliminates the need for handcrafted prompts by generating rich embeddings directly from text examples, trains significantly faster than large-scale models such as T0 or GPT-3, and supports multilingual classification through any Sentence Transformer model.
pip install setfit

import setfit

Common imports for working with SetFit models:
# Basic usage: train a binary sentiment classifier from four labeled examples.
# `Trainer` (not the legacy `SetFitTrainer` alias) is the class that accepts
# `args=TrainingArguments(...)`.
from setfit import SetFitModel, Trainer, TrainingArguments
from datasets import Dataset

# Prepare your few-shot dataset
train_texts = [
    "I love this movie!",
    "This film is terrible.",
    "Amazing cinematography!",
    "Waste of time.",
]
train_labels = [1, 0, 1, 0]  # 1 = positive, 0 = negative
train_dataset = Dataset.from_dict({
    "text": train_texts,
    "label": train_labels,
})

# Initialize a SetFit model from a pre-trained Sentence Transformer
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Create training arguments. The evaluation schedule is the `eval_strategy`
# field of TrainingArguments; per-epoch evaluation only has an effect when an
# `eval_dataset` is passed to the Trainer (none is supplied in this minimal example).
args = TrainingArguments(
    batch_size=16,
    num_epochs=4,
    eval_strategy="epoch",
)

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    args=args,
    # Identity mapping shown for illustration; a mapping is only needed when the
    # dataset's column names differ from "text"/"label".
    column_mapping={"text": "text", "label": "label"},
)

# Train the model
trainer.train()

# Make predictions
predictions = model.predict([
    "This movie is fantastic!",
    "I didn't enjoy this film.",
])
print(predictions)  # e.g. [1, 0] (exact output depends on training)

# Get prediction probabilities (one row per input, one column per class)
probs = model.predict_proba([
    "This movie is fantastic!",
    "I didn't enjoy this film.",
])
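print(probs) # [[0.1, 0.9], [0.8, 0.2]]

SetFit combines two key components for efficient few-shot learning:

1. A Sentence Transformer body (`model_body`), fine-tuned contrastively on pairs of training sentences so that examples with the same label are embedded close together.
2. A classification head (`model_head`: a scikit-learn `LogisticRegression` or a `SetFitHead`), trained on the resulting embeddings to predict the label.

This design lets SetFit achieve strong performance with minimal training data (as few as 8 examples per class) while training much faster than large generative models.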
Main model classes and training functionality for few-shot text classification with sentence transformers.
class SetFitModel:
    """Few-shot text classifier: a Sentence Transformer body plus a classification head.

    ``model_body`` embeds the input texts and ``model_head`` (either a
    ``SetFitHead`` or a scikit-learn ``LogisticRegression``) classifies those
    embeddings. The usage example builds instances with
    ``SetFitModel.from_pretrained(...)``; that constructor is not listed in
    this signature-only reference.
    """

    def __init__(
        self,
        model_body: Optional[SentenceTransformer] = None,
        model_head: Optional[Union[SetFitHead, LogisticRegression]] = None,
        multi_target_strategy: Optional[str] = None,
        normalize_embeddings: bool = False,
        labels: Optional[List[str]] = None,
        model_card_data: Optional[SetFitModelCardData] = None,
        sentence_transformers_kwargs: Optional[Dict] = None
    ):
        """Assemble a model from an embedding body and a classification head.

        Args:
            model_body: Sentence Transformer that produces the text embeddings.
            model_head: Classifier applied to the embeddings.
            multi_target_strategy: Name of the strategy used for multi-label
                problems; accepted values are not listed here (confirm in the
                setfit docs).
            normalize_embeddings: Presumably whether embeddings are normalized
                before reaching the head -- confirm.
            labels: Optional human-readable names for the class indices.
            model_card_data: Metadata used when generating the model card.
            sentence_transformers_kwargs: Extra keyword arguments, presumably
                forwarded to ``SentenceTransformer`` -- confirm.
        """
        ...

    def fit(
        self,
        x_train: List[str],
        y_train: Union[List[int], List[List[int]]],
        num_epochs: int,
        batch_size: Optional[int] = None,
        body_learning_rate: Optional[float] = None,
        head_learning_rate: Optional[float] = None,
        end_to_end: bool = False,
        l2_weight: Optional[float] = None,
        max_length: Optional[int] = None,
        show_progress_bar: bool = True
    ):
        """Train the model directly on raw texts and labels.

        Args:
            x_train: Training texts.
            y_train: One integer label per text, or a list of integer labels
                per text for multi-label data.
            num_epochs: Number of training epochs.
            batch_size: Optional batch size; ``None`` presumably falls back to
                a library default -- confirm.
            body_learning_rate: Optional learning rate for the embedding body.
            head_learning_rate: Optional learning rate for the classification head.
            end_to_end: Presumably whether the body is updated together with
                the head -- confirm.
            l2_weight: Presumably a regularization weight for the head -- confirm.
            max_length: Presumably the maximum tokenized input length -- confirm.
            show_progress_bar: Whether to display a progress bar.
        """
        ...

    def predict(
        self,
        inputs: Union[str, List[str]],
        batch_size: int = 32,
        as_numpy: bool = False,
        use_labels: bool = True,
        show_progress_bar: Optional[bool] = None
    ):
        """Predict a class for each input.

        Args:
            inputs: A single text or a list of texts.
            batch_size: Number of inputs processed per batch.
            as_numpy: Return NumPy output instead of the default output type.
            use_labels: Presumably maps class indices to ``labels`` when those
                were provided -- confirm.
            show_progress_bar: Whether to display a progress bar (``None``
                leaves the choice to the library).

        Returns:
            One prediction per input (see the usage example).
        """
        ...

    def predict_proba(
        self,
        inputs: Union[str, List[str]],
        batch_size: int = 32,
        as_numpy: bool = False,
        show_progress_bar: Optional[bool] = None
    ):
        """Return class probabilities: one row per input, one column per class.

        Args:
            inputs: A single text or a list of texts.
            batch_size: Number of inputs processed per batch.
            as_numpy: Return NumPy output instead of the default output type.
            show_progress_bar: Whether to display a progress bar.
        """
        ...

    def encode(
        self,
        inputs: List[str],
        batch_size: int = 32,
        show_progress_bar: Optional[bool] = None
    ):
        """Embed ``inputs``, presumably with ``model_body`` -- confirm.

        Args:
            inputs: Texts to embed.
            batch_size: Number of inputs processed per batch.
            show_progress_bar: Whether to display a progress bar.
        """
        ...
class SetFitTrainer:
    """Trainer that fits a ``SetFitModel`` on a (few-shot) training dataset.

    NOTE(review): in setfit >= 1.0 the trainer that accepts
    ``args: TrainingArguments`` is ``setfit.Trainer``; ``SetFitTrainer`` is kept
    as a deprecated alias with the older keyword arguments. Also, this stub names
    its metric hook ``compute_metrics`` while ``DistillationTrainer`` in this
    reference uses ``metric``. Confirm both against the installed version.
    """

    def __init__(
        self,
        model: Optional[SetFitModel] = None,
        args: Optional[TrainingArguments] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        model_init: Optional[Callable[[], SetFitModel]] = None,
        compute_metrics: Optional[Callable] = None,
        callbacks: Optional[List] = None,
        column_mapping: Optional[Dict[str, str]] = None
    ):
        """Configure training.

        Args:
            model: Model to train.
            args: Training hyperparameters.
            train_dataset: Labeled training data.
            eval_dataset: Optional labeled evaluation data.
            model_init: Zero-argument factory returning a fresh ``SetFitModel``.
            compute_metrics: Optional callable used to compute evaluation metrics.
            callbacks: Optional list of training callbacks.
            column_mapping: Column-name mapping, presumably
                ``{dataset_column: expected_column}`` -- confirm the direction.
        """
        ...

    def train(self):
        """Run training."""
        ...

    def evaluate(self, eval_dataset: Optional[Dataset] = None):
        """Evaluate on ``eval_dataset``; ``None`` presumably falls back to the
        dataset given at construction -- confirm."""
        ...

    def predict(self, test_dataset: Dataset):
        """Run prediction on ``test_dataset``."""
        ...
class TrainingArguments:
    """Hyperparameters for SetFit training (only the main fields are listed)."""

    def __init__(
        self,
        output_dir: str = "./results",
        batch_size: int = 16,
        num_epochs: Union[int, Tuple[int, int]] = 1,
        max_steps: Union[int, Tuple[int, int]] = -1,
        sampling_strategy: str = "oversampling",
        learning_rate: Union[float, Tuple[float, float]] = 2e-5,
        eval_strategy: str = "no",
        save_strategy: str = "steps"
        # ... more parameters available
    ):
        """Store training hyperparameters.

        Args:
            output_dir: Directory for training outputs.
            batch_size: Training batch size.
            num_epochs: Epoch count; the tuple form presumably sets separate
                values for the two training phases (body, head) -- confirm.
            max_steps: Step limit; ``-1`` presumably means no limit -- confirm.
            sampling_strategy: How training pairs are sampled (default
                ``"oversampling"``).
            learning_rate: Learning rate, or a tuple presumably giving
                per-phase values -- confirm.
            eval_strategy: When to evaluate; ``"no"`` presumably disables
                evaluation -- confirm.
            save_strategy: When to save checkpoints.
        """
        ...

Dataset preparation, sampling, and templating utilities for few-shot learning scenarios.
def sample_dataset(
    dataset: Dataset,
    label_column: str = "label",
    num_samples: int = 8,
    seed: int = 42
) -> Dataset:
    """Draw a small few-shot subset of ``dataset``.

    Args:
        dataset: Source dataset.
        label_column: Name of the column holding the labels.
        num_samples: Presumably the number of examples kept per distinct
            label value -- confirm.
        seed: Random seed for reproducible sampling.

    Returns:
        The sampled ``Dataset``.
    """
    ...
def get_templated_dataset(
    dataset: Optional[Dataset] = None,
    candidate_labels: Optional[List[str]] = None,
    reference_dataset: Optional[str] = None,
    template: str = "This example is {}",
    sample_size: int = 2,
    text_column: str = "text",
    label_column: str = "label",
    multi_label: bool = False,
    label_names_column: str = "label_text"
) -> Dataset:
    """Build a dataset of synthetic examples by filling ``template`` with label names.

    Each label name is substituted into the ``{}`` placeholder of ``template``.

    Args:
        dataset: Optional existing dataset; presumably the generated examples
            are added to it -- confirm.
        candidate_labels: Label names to generate examples for.
        reference_dataset: Presumably the name of a dataset whose label names
            are used when ``candidate_labels`` is not given -- confirm.
        template: Format string with a ``{}`` placeholder for the label name.
        sample_size: Presumably the number of examples generated per label -- confirm.
        text_column: Name of the text column.
        label_column: Name of the label column.
        multi_label: Presumably whether labels are emitted in multi-label
            form -- confirm.
        label_names_column: Name of the column holding label names.

    Returns:
        The templated ``Dataset``.
    """
    ...
def create_fewshot_splits(
    dataset: Dataset,
    sample_sizes: List[int] = [2, 4, 8, 16, 32, 64],
    add_data_augmentation: bool = False,
    dataset_name: Optional[str] = None
) -> DatasetDict:
    """Build a ``DatasetDict`` of few-shot splits, presumably one per entry in
    ``sample_sizes`` -- confirm.

    Args:
        dataset: Source dataset.
        sample_sizes: Split sizes to generate. NOTE(review): the default list
            object is shared across calls, so callers must not mutate it.
        add_data_augmentation: Whether to add augmented examples.
        dataset_name: Optional dataset name, presumably used for augmentation
            or naming -- confirm.

    Returns:
        A ``DatasetDict`` of the generated splits.
    """
    ...

Teacher-student training framework for model compression and efficiency improvements.
class DistillationTrainer:
    """Teacher-student trainer: distills ``teacher_model`` into ``student_model``."""

    def __init__(
        self,
        teacher_model: SetFitModel,
        student_model: Optional[SetFitModel] = None,
        args: Optional[TrainingArguments] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        model_init: Optional[Callable[[], SetFitModel]] = None,
        metric: Union[str, Callable] = "accuracy",
        column_mapping: Optional[Dict[str, str]] = None
    ):
        """Configure distillation.

        Args:
            teacher_model: Trained model whose behavior is distilled.
            student_model: Model to train (typically smaller).
            args: Training hyperparameters.
            train_dataset: Training data (presumably texts; the teacher's outputs
                serve as targets -- confirm whether labels are required).
            eval_dataset: Optional evaluation data.
            model_init: Zero-argument factory returning a fresh student ``SetFitModel``.
            metric: Metric name or callable used during evaluation (default ``"accuracy"``).
            column_mapping: Column-name mapping for the datasets.
        """
        ...

    def train(self):
        """Run distillation training."""
        ...

    def evaluate(self, eval_dataset: Optional[Dataset] = None):
        """Evaluate the student on ``eval_dataset`` (or, presumably, the dataset
        given at construction -- confirm)."""
        ...

Specialized models and trainers for aspect-based sentiment analysis tasks with span-level predictions.
class AbsaModel:
    """Aspect-based sentiment analysis model combining an aspect model and a polarity model.

    NOTE(review): the released class may take additional components (e.g. an
    aspect extractor) and is usually loaded via ``from_pretrained`` -- confirm.
    """

    def __init__(self, aspect_model=None, polarity_model=None):
        """Store the aspect-detection model and the polarity-classification model."""
        ...

    def predict(self, inputs):
        """Predict aspects and their polarities for ``inputs``."""
        ...
class AspectModel:
    """Span-level model that decides which candidate spans are aspects."""

    def __init__(self, spacy_model="en_core_web_sm", span_context=0):
        """Configure span handling.

        Args:
            spacy_model: Name of the spaCy pipeline (default ``"en_core_web_sm"``).
            span_context: Presumably the number of surrounding tokens included
                with each span -- confirm.
        """
        ...
class PolarityModel:
    """Span-level model that classifies the sentiment polarity of detected aspects."""

    def __init__(self, spacy_model="en_core_web_sm", span_context=0):
        """Configure span handling.

        Args:
            spacy_model: Name of the spaCy pipeline (default ``"en_core_web_sm"``).
            span_context: Presumably the number of surrounding tokens included
                with each aspect span. NOTE(review): the released polarity model
                may use a wider default context than the aspect model -- confirm.
        """
        ...

Aspect-Based Sentiment Analysis
Export functionality for ONNX and OpenVINO formats to enable efficient deployment and inference.
# Note: These functions require explicit imports from submodules:
# from setfit.exporters.onnx import export_onnx
# from setfit.exporters.openvino import export_to_openvino
def export_onnx(
    model_body: SentenceTransformer,
    model_head: Union[torch.nn.Module, LogisticRegression],
    opset: int,
    output_path: str = "model.onnx",
    ignore_ir_version: bool = True,
    use_hummingbird: bool = False
) -> None:
    """Export a SetFit body and head to an ONNX file at ``output_path``.

    Import with ``from setfit.exporters.onnx import export_onnx``.

    Args:
        model_body: Sentence Transformer body to export.
        model_head: Head to export (a torch module or a scikit-learn
            ``LogisticRegression``).
        opset: ONNX opset version to target.
        output_path: Destination file (default ``"model.onnx"``).
        ignore_ir_version: Presumably skips the ONNX IR-version check when
            combining graphs -- confirm.
        use_hummingbird: Presumably converts a scikit-learn head via
            Hummingbird -- confirm.
    """
    ...
def export_to_openvino(
    model: SetFitModel,
    output_path: str = "model.xml"
) -> None:
    """Export ``model`` to OpenVINO format at ``output_path`` (default ``"model.xml"``).

    Import with ``from setfit.exporters.openvino import export_to_openvino``.
    """
    ...

Automatic model card generation and metadata management for reproducibility and documentation.
class SetFitModelCardData:
    """Metadata used to generate a SetFit model card."""

    def __init__(self, language=None, license=None, tags=None, model_name=None):
        """Store card metadata: language, license, tags and model name (all optional)."""
        ...

    def set_train_set_metrics(self, metrics):
        """Presumably records training-set statistics on the card -- confirm."""
        ...

    def generate_model_card(self):
        """Presumably renders the model card -- confirm.

        NOTE(review): in setfit this may be a module-level function taking a
        model rather than a method on this class -- confirm.
        """
        ...


__version__: str  # Library version ("1.1.3")
# NOTE(review): the install command at the top of this page pins 1.1.0; confirm
# which release this reference describes.

# Core types used across the API
from typing import List, Dict, Optional, Union, Tuple, Any, Callable
import numpy as np
import torch
from datasets import Dataset, DatasetDict
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
# SetFit-specific type aliases
ModelBody = SentenceTransformer  # embedding model (SetFitModel.model_body)
ModelHead = Union[LogisticRegression, "SetFitHead"]  # classifier applied to the embeddings
Labels = Union[List[int], List[List[int]]]  # Single or multi-label
# SetFitModel.predict returns a torch tensor by default (as_numpy=False), a NumPy
# array when as_numpy=True, and label names when use_labels=True and `labels` is set.
# List[int] is kept for backward compatibility with earlier uses of this alias.
PredictionOutput = Union[torch.Tensor, np.ndarray, List[int], List[str]]
# SetFitModel.predict_proba returns a torch tensor by default, or a NumPy array
# when as_numpy=True; List[List[float]] is kept for backward compatibility.
ProbabilityOutput = Union[torch.Tensor, np.ndarray, List[List[float]]]