SetFit — efficient few-shot learning with Sentence Transformers.

Dataset preparation, sampling, and templating utilities for few-shot learning scenarios. These functions help create balanced training sets, generate synthetic examples, and prepare data in the format expected by SetFit models.

sample_dataset: create balanced few-shot datasets by sampling an equal number of examples per class.
def sample_dataset(
    dataset: "Dataset",
    label_column: str = "label",
    num_samples: int = 8,
    seed: int = 42,
) -> "Dataset":
    """Sample ``dataset`` so each class contributes the same number of examples.

    Args:
        dataset: HuggingFace ``Dataset`` to draw samples from.
        label_column: Name of the column containing the class labels.
        num_samples: Number of examples to keep per class.
        seed: Random seed so the sampling is reproducible.

    Returns:
        A new ``Dataset`` with ``num_samples`` balanced samples per class.
    """
def create_samples(
    df: "pd.DataFrame",
    sample_size: int,
    seed: int = 42,
) -> "pd.DataFrame":
    """Draw an equal-sized sample of rows for every class in ``df``.

    Args:
        df: Input pandas DataFrame.
        sample_size: Number of rows to keep per class.
        seed: Random seed for reproducibility.

    Returns:
        Sampled DataFrame with balanced classes.
    """
def create_samples_multilabel(
    df: "pd.DataFrame",
    sample_size: int,
    seed: int = 42,
) -> "pd.DataFrame":
    """Sample rows of ``df`` for multilabel classification scenarios.

    Args:
        df: Input pandas DataFrame with multilabel targets.
        sample_size: Number of rows to select.
        seed: Random seed for reproducibility.

    Returns:
        Sampled DataFrame suitable for multilabel training.
    """


# Generate training splits with different sample sizes for few-shot
# learning experiments.
def create_fewshot_splits(
    dataset: "Dataset",
    sample_sizes: Optional[List[int]] = None,
    add_data_augmentation: bool = False,
    dataset_name: Optional[str] = None,
) -> "DatasetDict":
    """Create training splits with equal samples per class for several shot sizes.

    Args:
        dataset: Source dataset to create splits from.
        sample_sizes: List of per-class sample sizes to create splits for.
            Defaults to ``[2, 4, 8, 16, 32, 64]``; a ``None`` sentinel is used
            instead of a mutable list default to avoid the shared
            mutable-default-argument pitfall.
        add_data_augmentation: Whether to add data augmentation.
        dataset_name: Name of the dataset for tracking.

    Returns:
        DatasetDict with one split per sample size.
    """
    # Resolve the default here so every call gets a fresh list.
    if sample_sizes is None:
        sample_sizes = [2, 4, 8, 16, 32, 64]
def create_fewshot_splits_multilabel(
    dataset: "Dataset",
    sample_sizes: Optional[List[int]] = None,
) -> "DatasetDict":
    """Create multilabel training splits with different sample sizes.

    Args:
        dataset: Source multilabel dataset.
        sample_sizes: List of sample sizes to create splits for. Defaults to
            ``[2, 4, 8, 16]``; a ``None`` sentinel replaces the mutable list
            default so the default is never shared between calls.

    Returns:
        DatasetDict with multilabel splits for each sample size.
    """
    # Resolve the default here so every call gets a fresh list.
    if sample_sizes is None:
        sample_sizes = [2, 4, 8, 16]


# Generate synthetic training examples using templates and candidate labels.
def get_templated_dataset(
    dataset: Optional["Dataset"] = None,
    candidate_labels: Optional[List[str]] = None,
    reference_dataset: Optional[str] = None,
    template: str = "This example is {}",
    sample_size: int = 2,
    text_column: str = "text",
    label_column: str = "label",
    multi_label: bool = False,
    label_names_column: str = "label_text",
) -> "Dataset":
    """Create templated examples from a reference dataset or candidate labels.

    Args:
        dataset: Source dataset to template (optional).
        candidate_labels: Label names to create templated examples for.
        reference_dataset: Name of a reference dataset to use.
        template: Template string with a ``{}`` placeholder for the label.
        sample_size: Number of examples to generate per label.
        text_column: Name of the text column in the dataset.
        label_column: Name of the label column in the dataset.
        multi_label: Whether this is a multi-label classification task.
        label_names_column: Column containing the label names.

    Returns:
        Dataset with templated examples.
    """
def get_candidate_labels(
    dataset_name: str,
    label_names_column: str = "label_text",
) -> List[str]:
    """Extract the unique candidate labels from a dataset.

    Args:
        dataset_name: Name of the dataset to extract labels from.
        label_names_column: Column containing the label names.

    Returns:
        List of unique label names.
    """


# Custom dataset class for training differentiable heads with PyTorch.
class SetFitDataset:
    """Torch-style dataset used to train SetFit's differentiable classification head."""

    def __init__(
        self,
        x: List[str],
        y: Union[List[int], List[List[int]]],
        tokenizer: "PreTrainedTokenizerBase",
        max_length: int = 512,
    ):
        """Dataset for training a differentiable head on text classification.

        Args:
            x: Input texts.
            y: Labels — ints for single-label, lists of ints for multi-label.
            tokenizer: HuggingFace tokenizer used to process the texts.
            max_length: Maximum sequence length for tokenization.
        """

    def __len__(self) -> int:
        """Return the number of examples in the dataset."""

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Fetch a single example.

        Args:
            index: Position of the example to retrieve.

        Returns:
            Dictionary with ``input_ids``, ``attention_mask``, and ``labels``.
        """

    @staticmethod
    def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate a list of examples into a batch.

        Args:
            batch: Examples produced by ``__getitem__``.

        Returns:
            Batched tensors ready for model input.
        """


# Default seeds for reproducible sampling
# Default seeds used for reproducible sampling runs.
SEEDS: List[int] = list(range(10))

# Default per-class sample sizes for few-shot experiments.
SAMPLE_SIZES: List[int] = [2, 4, 8, 16, 32, 64]

# Type alias for a tokenizer's output mapping (field name -> token ids).
TokenizerOutput = Dict[str, List[int]]

# --- Example: creating a balanced few-shot dataset ---
from setfit import sample_dataset
from collections import Counter

from datasets import load_dataset

# Load a dataset
dataset = load_dataset("emotion", split="train")
print(f"Original dataset size: {len(dataset)}")

# Create balanced 8-shot dataset
few_shot_dataset = sample_dataset(
    dataset=dataset,
    label_column="label",
    num_samples=8,
    seed=42,
)
print(f"Few-shot dataset size: {len(few_shot_dataset)}")
print(f"Samples per class: 8")

# Check the per-class distribution of the sampled labels
label_dist = Counter(few_shot_dataset["label"])
print(f"Label distribution: {dict(label_dist)}")

# --- Example: creating training splits of several shot sizes ---
from setfit import create_fewshot_splits
from datasets import load_dataset

# Load dataset
dataset = load_dataset("imdb", split="train")

# Create multiple few-shot splits, one per sample size
splits = create_fewshot_splits(
    dataset=dataset,
    sample_sizes=[2, 4, 8, 16, 32],
    add_data_augmentation=False,
)

print(f"Created splits: {list(splits.keys())}")
for split_name, split_data in splits.items():
    print(f"{split_name}: {len(split_data)} examples")

# Use a specific split for training
train_2_shot = splits["train-2"]
print(f"2-shot training set: {len(train_2_shot)} examples")

# --- Example: generating a synthetic templated dataset ---
from setfit import get_templated_dataset
# Build a synthetic dataset purely from label names and a template string.
candidate_labels = ["positive", "negative", "neutral"]
templated_dataset = get_templated_dataset(
    candidate_labels=candidate_labels,
    template="This text expresses {} sentiment.",
    sample_size=4,
)

print("Generated templated examples:")
for i, example in enumerate(templated_dataset):
    print(f"{i}: {example['text']} -> {example['label']}")

# --- Example: training the differentiable head with SetFitDataset ---
from setfit import SetFitDataset
import torch
from transformers import AutoTokenizer

# Prepare data
texts = ["I love this!", "This is terrible.", "Amazing work!", "Not good."]
labels = [1, 0, 1, 0]
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Create dataset
dataset = SetFitDataset(
    x=texts,
    y=labels,
    tokenizer=tokenizer,
    max_length=256,
)

# Create dataloader; collate_fn batches the tokenized examples into tensors
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    collate_fn=dataset.collate_fn,
    shuffle=True,
)

# Inspect a single batch, then stop
for batch in dataloader:
    print(f"Batch shape: input_ids={batch['input_ids'].shape}")
    print(f"Labels: {batch['labels']}")
    break

# --- Example: multi-label sampling utilities ---
from setfit import create_fewshot_splits_multilabel, create_samples_multilabel
import pandas as pd
from datasets import Dataset

# Create multi-label dataset
data = {
    "text": ["Great action movie", "Romantic comedy", "Scary thriller", "Funny drama"],
    "labels": [[1, 0, 0], [0, 1, 1], [0, 0, 1], [1, 1, 0]],  # [action, comedy, drama]
}
dataset = Dataset.from_dict(data)

# Create few-shot splits for multi-label
ml_splits = create_fewshot_splits_multilabel(
    dataset=dataset,
    sample_sizes=[2, 4],
)
print(f"Multi-label splits: {list(ml_splits.keys())}")

# Or use with pandas DataFrame
df = pd.DataFrame(data)
sampled_df = create_samples_multilabel(df, sample_size=2, seed=42)
print(f"Sampled DataFrame shape: {sampled_df.shape}")

# --- Example: benchmarking accuracy across shot sizes ---
from setfit import SetFitModel, SetFitTrainer, TrainingArguments, create_fewshot_splits
# Benchmark SetFit accuracy as a function of the number of shots.
# NOTE: the original example imported numpy as np but never used it;
# the unused import has been removed.
from datasets import load_dataset

# Load dataset
dataset = load_dataset("SetFit/sst2", split="train")
test_dataset = load_dataset("SetFit/sst2", split="test")

# Create multiple training splits
splits = create_fewshot_splits(
    dataset=dataset,
    sample_sizes=[2, 4, 8, 16],
    add_data_augmentation=False,
)

# Benchmark performance across sample sizes
results = {}
model_name = "sentence-transformers/all-MiniLM-L6-v2"
for split_name, train_split in splits.items():
    print(f"\nTraining on {split_name}...")
    # Initialize fresh model for each experiment so runs don't share weights
    model = SetFitModel.from_pretrained(model_name)
    args = TrainingArguments(
        batch_size=16,
        num_epochs=4,
        eval_strategy="epoch",
    )
    trainer = SetFitTrainer(
        model=model,
        args=args,
        train_dataset=train_split,
        eval_dataset=test_dataset,
    )
    trainer.train()
    eval_results = trainer.evaluate()
    results[split_name] = eval_results["eval_accuracy"]
    print(f"{split_name} accuracy: {eval_results['eval_accuracy']:.3f}")

print("\nFinal Results:")
for split_name, accuracy in results.items():
    print(f"{split_name}: {accuracy:.3f}")

# Install with Tessl CLI
To install this package with the Tessl CLI, run: `npx tessl i tessl/pypi-setfit`