CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-setfit

Efficient few-shot learning with Sentence Transformers

Pending
Overview
Eval results
Files

docs/knowledge-distillation.md

Knowledge Distillation

Teacher-student training framework for model compression and efficiency improvements. Knowledge distillation allows training smaller, faster student models that retain much of the performance of larger teacher models.

Capabilities

Distillation Trainer

Main trainer class for knowledge distillation between SetFit models.

class DistillationTrainer:
    """Trainer that transfers knowledge from a teacher SetFit model to a
    smaller student SetFit model.

    NOTE(review): this is an API-reference stub — the methods below are
    docstring-only signatures with intentionally empty bodies.
    """

    def __init__(
        self,
        teacher_model: SetFitModel,
        student_model: SetFitModel,
        args: Optional[TrainingArguments] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        model_init: Optional[Callable[[], SetFitModel]] = None,
        compute_metrics: Optional[Callable] = None,
        callbacks: Optional[List] = None,
        optimizers: Optional[Tuple] = None,
        preprocess_logits_for_metrics: Optional[Callable] = None,
        column_mapping: Optional[Dict[str, str]] = None
    ):
        """
        Initialize a distillation trainer for knowledge transfer.

        Parameters:
        - teacher_model: Pre-trained SetFit model to distill knowledge from
        - student_model: Smaller SetFit model to train as student
        - args: Training arguments for distillation process
        - train_dataset: Training dataset for distillation
        - eval_dataset: Evaluation dataset for monitoring performance
        - model_init: Function to initialize student model (for HP search)
        - compute_metrics: Function to compute evaluation metrics
        - callbacks: List of training callbacks
        - optimizers: Custom optimizers for student model
        - preprocess_logits_for_metrics: Function to preprocess logits
        - column_mapping: Mapping of dataset columns to expected names
        """

    def train(self) -> None:
        """
        Train the student model using knowledge distillation.

        The training process involves:
        1. Generate embeddings from teacher model
        2. Train student model to match teacher embeddings
        3. Fine-tune student classification head
        """

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Evaluate the student model on evaluation dataset.

        Parameters:
        - eval_dataset: Evaluation dataset (uses trainer's eval_dataset if None)

        Returns:
        Dictionary of evaluation metrics for student model
        """

    def predict(self, test_dataset: Dataset) -> "PredictionOutput":
        """
        Generate predictions using the trained student model.

        Parameters:
        - test_dataset: Test dataset for predictions

        Returns:
        Predictions from student model
        """

Distillation Dataset Classes

Specialized dataset classes for contrastive distillation training.

class ContrastiveDataset:
    """Dataset that pairs labelled sentences into positive/negative examples
    for contrastive training.

    NOTE(review): API-reference stub — signature and docstring only.
    """

    def __init__(
        self,
        sentences: List[str],
        labels: List[int],
        sampling_strategy: str = "oversampling"
    ):
        """
        Dataset for contrastive learning with positive and negative pairs.

        Parameters:
        - sentences: List of input sentences
        - labels: List of corresponding labels
        - sampling_strategy: Strategy for sampling pairs ("oversampling", "undersampling", "unique")
        """

class ContrastiveDistillationDataset:
    """Contrastive dataset variant that additionally carries pre-computed
    teacher embeddings as the distillation target.

    NOTE(review): API-reference stub — signature and docstring only.
    """

    def __init__(
        self,
        sentences: List[str],
        labels: List[int],
        teacher_embeddings: np.ndarray,
        sampling_strategy: str = "oversampling"
    ):
        """
        Dataset for contrastive distillation with teacher embeddings.

        Parameters:
        - sentences: List of input sentences
        - labels: List of corresponding labels  
        - teacher_embeddings: Pre-computed embeddings from teacher model
        - sampling_strategy: Strategy for sampling pairs
        """

Usage Examples

Basic Knowledge Distillation

from setfit import SetFitModel, DistillationTrainer, TrainingArguments
from datasets import Dataset

# Prepare a tiny labelled training set (1 = positive, 0 = negative)
train_texts = [
    "I love this movie!", "This film is terrible.", 
    "Amazing cinematography!", "Waste of time.",
    "Brilliant acting!", "Poor storyline."
]
train_labels = [1, 0, 1, 0, 1, 0]

train_dataset = Dataset.from_dict({
    "text": train_texts,
    "label": train_labels
})

# Load pre-trained teacher model (larger, more accurate).
# NOTE(review): the teacher is not fine-tuned on the task before distillation
# here; for a meaningful distillation signal it would normally be trained
# first (see the comparison example on this page) — confirm intent.
teacher_model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")

# Initialize student model (smaller, faster)
student_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Configure distillation training.
# Fix: the original set eval_strategy="epoch" and save_strategy="epoch" even
# though no eval_dataset is ever passed to the trainer, so per-epoch
# evaluation could never run; those strategies are omitted here.
args = TrainingArguments(
    output_dir="./distillation_results",
    batch_size=16,
    num_epochs=4,
    learning_rate=2e-5,
    logging_steps=50
)

# Create distillation trainer
distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    args=args,
    train_dataset=train_dataset,
    column_mapping={"text": "text", "label": "label"}
)

# Train student model through distillation
print("Starting knowledge distillation...")
distillation_trainer.train()

# The student model is now trained to mimic the teacher
student_predictions = student_model.predict([
    "This movie is fantastic!",
    "I didn't enjoy this film."
])
print(f"Student predictions: {student_predictions}")

Comparing Teacher vs Student Performance

# Fix: SetFitTrainer is used below but was missing from the import list.
from setfit import SetFitModel, SetFitTrainer, DistillationTrainer, TrainingArguments
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report
import time

# Load dataset
train_dataset = load_dataset("SetFit/sst2", split="train[:100]")  # Small subset for demo
test_dataset = load_dataset("SetFit/sst2", split="test[:50]")

# Teacher model (large, accurate)
teacher_model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")

# Student model (small, fast)
student_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Train teacher model first so the distillation signal is task-specific
print("Training teacher model...")
teacher_trainer = SetFitTrainer(
    model=teacher_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16)
)
teacher_trainer.train()

# Train student via distillation
print("Training student model via distillation...")
distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16)
)
distillation_trainer.train()

# Compare performance and speed on the held-out split
test_texts = test_dataset["text"]
test_labels = test_dataset["label"]

# Teacher predictions (timed)
start_time = time.time()
teacher_preds = teacher_model.predict(test_texts)
teacher_time = time.time() - start_time

# Student predictions (timed)
start_time = time.time()
student_preds = student_model.predict(test_texts)
student_time = time.time() - start_time

# Calculate metrics
teacher_acc = accuracy_score(test_labels, teacher_preds)
student_acc = accuracy_score(test_labels, student_preds)

print(f"\nPerformance Comparison:")
print(f"Teacher accuracy: {teacher_acc:.3f} (Time: {teacher_time:.3f}s)")
print(f"Student accuracy: {student_acc:.3f} (Time: {student_time:.3f}s)")
print(f"Speed improvement: {teacher_time/student_time:.1f}x")
print(f"Accuracy retention: {student_acc/teacher_acc:.1%}")

print(f"\nDetailed Student Results:")
print(classification_report(test_labels, student_preds))

Multi-Teacher Distillation

# Fix: SetFitTrainer and Dataset are used below but were never imported.
from setfit import SetFitModel, SetFitTrainer, DistillationTrainer, TrainingArguments
from datasets import Dataset
import numpy as np

# NOTE(review): `train_dataset` is assumed to come from the earlier examples
# on this page — define it before running this snippet standalone.

# Load multiple teacher models with different strengths
teacher1 = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
teacher2 = SetFitModel.from_pretrained("sentence-transformers/all-roberta-large-v1") 
teacher3 = SetFitModel.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

# Train teachers on the same dataset
teachers = [teacher1, teacher2, teacher3]
for i, teacher in enumerate(teachers):
    print(f"Training teacher {i+1}...")
    trainer = SetFitTrainer(
        model=teacher,
        train_dataset=train_dataset,
        args=TrainingArguments(num_epochs=3, batch_size=16)
    )
    trainer.train()

# Create ensemble predictions for student training
def create_ensemble_dataset(teachers, dataset):
    """Build a dataset whose `soft_labels` column averages the class
    probabilities predicted by all teachers.

    Parameters:
    - teachers: trained SetFit models to ensemble
    - dataset: dataset with "text" and "label" columns

    Returns:
    A `datasets.Dataset` with the original columns plus `soft_labels`.
    """
    texts = dataset["text"]
    labels = dataset["label"]
    
    # Get class-probability predictions from all teachers
    teacher_probs = []
    for teacher in teachers:
        probs = teacher.predict_proba(texts)
        teacher_probs.append(probs)
    
    # Average teacher predictions
    ensemble_probs = np.mean(teacher_probs, axis=0)
    
    # Use soft labels from ensemble
    return Dataset.from_dict({
        "text": texts,
        "label": labels,
        "soft_labels": ensemble_probs.tolist()
    })

# Create enhanced training dataset
enhanced_dataset = create_ensemble_dataset(teachers, train_dataset)

# Student model
student_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Custom distillation trainer that uses ensemble guidance
# (This would require custom implementation in practice)
distillation_trainer = DistillationTrainer(
    teacher_model=teacher1,  # Use first teacher as primary
    student_model=student_model,
    train_dataset=enhanced_dataset,
    args=TrainingArguments(num_epochs=5, batch_size=16)
)

distillation_trainer.train()

Progressive Distillation

# Fix: SetFitTrainer and the `time` module are used below but were never imported.
from setfit import SetFitModel, SetFitTrainer, DistillationTrainer, TrainingArguments
import time

# NOTE(review): `train_dataset` is assumed to come from the earlier examples
# on this page — define it before running this snippet standalone.

# Create a chain of models: Large -> Medium -> Small
large_model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
medium_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L12-v2")  
small_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Stage 1: Train large model (teacher)
print("Stage 1: Training large model...")
large_trainer = SetFitTrainer(
    model=large_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16)
)
large_trainer.train()

# Stage 2: Distill large -> medium
print("Stage 2: Distilling large -> medium...")
medium_distillation = DistillationTrainer(
    teacher_model=large_model,
    student_model=medium_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16)
)
medium_distillation.train()

# Stage 3: Distill medium -> small  
print("Stage 3: Distilling medium -> small...")
small_distillation = DistillationTrainer(
    teacher_model=medium_model,
    student_model=small_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16)
)
small_distillation.train()

# Compare all models on a couple of sample sentences
models = {
    "Large": large_model,
    "Medium": medium_model, 
    "Small": small_model
}

test_texts = ["This is amazing!", "This is terrible."]

print("\nProgressive Distillation Results:")
for name, model in models.items():
    start_time = time.time()
    predictions = model.predict(test_texts)
    inference_time = time.time() - start_time
    
    print(f"{name} model: {predictions} (Time: {inference_time:.4f}s)")

Distillation with Custom Loss

from setfit import DistillationTrainer, TrainingArguments
import torch
import torch.nn.functional as F

class CustomDistillationTrainer(DistillationTrainer):
    """Distillation trainer with a temperature-scaled combined loss.

    Blends a soft-target KL term (weight ``alpha``) with the standard
    cross-entropy on the hard labels (weight ``1 - alpha``).
    """

    def __init__(self, *args, temperature=4.0, alpha=0.7, **kwargs):
        super().__init__(*args, **kwargs)
        # Temperature > 1 softens both probability distributions.
        self.temperature = temperature
        # Relative weight of the distillation term vs the task term.
        self.alpha = alpha

    def compute_distillation_loss(self, teacher_logits, student_logits, labels):
        """Return alpha * KL(student || teacher) + (1 - alpha) * CE(student, labels)."""
        T = self.temperature

        # Soft-target term: KL divergence between the temperature-scaled
        # teacher distribution and the student's log-probabilities, rescaled
        # by T^2 to keep gradient magnitudes comparable across temperatures.
        soft_targets = F.softmax(teacher_logits / T, dim=1)
        log_preds = F.log_softmax(student_logits / T, dim=1)
        soft_loss = F.kl_div(log_preds, soft_targets, reduction='batchmean') * (T ** 2)

        # Hard-target term: plain cross-entropy on the ground-truth labels.
        hard_loss = F.cross_entropy(student_logits, labels)

        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

# Use custom trainer
custom_trainer = CustomDistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16),
    temperature=5.0,  # Higher temperature for softer distributions
    alpha=0.8        # More weight on distillation loss
)

custom_trainer.train()

Evaluating Distillation Quality

from setfit import SetFitModel, DistillationTrainer
from sklearn.metrics import accuracy_score
import numpy as np
from scipy.stats import spearmanr

def evaluate_distillation_quality(teacher_model, student_model, test_dataset):
    """Quantify how faithfully the student reproduces the teacher.

    Parameters:
    - teacher_model: trained teacher SetFit model
    - student_model: distilled student SetFit model
    - test_dataset: dataset with "text" and "label" columns

    Returns:
    Dict with both accuracies, accuracy retention, prediction agreement,
    Spearman correlation of confidence scores, and the mean KL divergence
    between the two models' probability distributions.
    """
    texts = test_dataset["text"]
    gold = test_dataset["label"]

    # Hard predictions and full class-probability outputs from both models.
    t_preds = teacher_model.predict(texts)
    s_preds = student_model.predict(texts)
    t_probs = np.asarray(teacher_model.predict_proba(texts))
    s_probs = np.asarray(student_model.predict_proba(texts))

    teacher_acc = accuracy_score(gold, t_preds)
    student_acc = accuracy_score(gold, s_preds)

    # How often the student reproduces the teacher's prediction.
    agreement = accuracy_score(t_preds, s_preds)

    # Do the two models rank their confidence similarly?
    prob_correlation, _ = spearmanr(
        np.max(t_probs, axis=1),
        np.max(s_probs, axis=1),
    )

    # Mean per-example KL(teacher || student); epsilon avoids log(0).
    eps = 1e-8
    avg_kl_div = np.mean(np.sum(t_probs * np.log((t_probs + eps) / (s_probs + eps)), axis=1))

    return {
        "teacher_accuracy": teacher_acc,
        "student_accuracy": student_acc,
        "accuracy_retention": student_acc / teacher_acc,
        "prediction_agreement": agreement,
        "probability_correlation": prob_correlation,
        "avg_kl_divergence": avg_kl_div
    }

# Evaluate distillation
evaluation_results = evaluate_distillation_quality(
    teacher_model=teacher_model,
    student_model=student_model,
    test_dataset=test_dataset
)

print("Distillation Quality Assessment:")
for metric, value in evaluation_results.items():
    line = f"{metric}: {value:.4f}" if isinstance(value, float) else f"{metric}: {value}"
    print(line)

Install with Tessl CLI

npx tessl i tessl/pypi-setfit

docs

absa.md

core-model-training.md

data-utilities.md

index.md

knowledge-distillation.md

model-cards.md

model-export.md

tile.json