# Knowledge Distillation

A teacher-student training framework for model compression and faster inference. Knowledge distillation trains a smaller, faster student model that retains much of the performance of a larger teacher model.

## Capabilities

### Distillation Trainer

Main trainer class for knowledge distillation between SetFit models.

```python { .api }
class DistillationTrainer:
    """Main trainer class for knowledge distillation between SetFit models."""

    def __init__(
        self,
        teacher_model: SetFitModel,
        student_model: SetFitModel,
        args: Optional[TrainingArguments] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        model_init: Optional[Callable[[], SetFitModel]] = None,
        compute_metrics: Optional[Callable] = None,
        callbacks: Optional[List] = None,
        optimizers: Optional[Tuple] = None,
        preprocess_logits_for_metrics: Optional[Callable] = None,
        column_mapping: Optional[Dict[str, str]] = None
    ):
        """
        Initialize a distillation trainer for knowledge transfer.

        Parameters:
        - teacher_model: Pre-trained SetFit model to distill knowledge from
        - student_model: Smaller SetFit model to train as student
        - args: Training arguments for distillation process
        - train_dataset: Training dataset for distillation
        - eval_dataset: Evaluation dataset for monitoring performance
        - model_init: Function to initialize student model (for HP search)
        - compute_metrics: Function to compute evaluation metrics
        - callbacks: List of training callbacks
        - optimizers: Custom optimizers for student model
        - preprocess_logits_for_metrics: Function to preprocess logits
        - column_mapping: Mapping of dataset columns to expected names
        """

    def train(self) -> None:
        """
        Train the student model using knowledge distillation.

        The training process involves:
        1. Generate embeddings from teacher model
        2. Train student model to match teacher embeddings
        3. Fine-tune student classification head
        """

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Evaluate the student model on evaluation dataset.

        Parameters:
        - eval_dataset: Evaluation dataset (uses trainer's eval_dataset if None)

        Returns:
        Dictionary of evaluation metrics for student model
        """

    def predict(self, test_dataset: Dataset) -> "PredictionOutput":
        """
        Generate predictions using the trained student model.

        Parameters:
        - test_dataset: Test dataset for predictions

        Returns:
        Predictions from student model
        """
```

### Distillation Dataset Classes

Specialized dataset classes for contrastive distillation training.

```python { .api }
class ContrastiveDataset:
    """Dataset for contrastive learning with positive and negative pairs."""

    def __init__(
        self,
        sentences: List[str],
        labels: List[int],
        sampling_strategy: str = "oversampling"
    ):
        """
        Dataset for contrastive learning with positive and negative pairs.

        Parameters:
        - sentences: List of input sentences
        - labels: List of corresponding labels
        - sampling_strategy: Strategy for sampling pairs ("oversampling", "undersampling", "unique")
        """


class ContrastiveDistillationDataset:
    """Dataset for contrastive distillation with teacher embeddings."""

    def __init__(
        self,
        sentences: List[str],
        labels: List[int],
        teacher_embeddings: np.ndarray,
        sampling_strategy: str = "oversampling"
    ):
        """
        Dataset for contrastive distillation with teacher embeddings.

        Parameters:
        - sentences: List of input sentences
        - labels: List of corresponding labels
        - teacher_embeddings: Pre-computed embeddings from teacher model
        - sampling_strategy: Strategy for sampling pairs
        """
```

## Usage Examples

### Basic Knowledge Distillation

```python
from setfit import SetFitModel, DistillationTrainer, Trainer, TrainingArguments
from datasets import Dataset

# Prepare training data
train_texts = [
    "I love this movie!", "This film is terrible.",
    "Amazing cinematography!", "Waste of time.",
    "Brilliant acting!", "Poor storyline."
]
train_labels = [1, 0, 1, 0, 1, 0]

train_dataset = Dataset.from_dict({
    "text": train_texts,
    "label": train_labels
})

# Load the teacher model (larger, more accurate).
# A raw sentence-transformers checkpoint has no fitted classification head,
# so the teacher must be trained on labeled data before it can guide a student.
teacher_model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
teacher_trainer = Trainer(
    model=teacher_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16)
)
teacher_trainer.train()

# Initialize student model (smaller, faster)
student_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Configure distillation training
args = TrainingArguments(
    output_dir="./distillation_results",
    batch_size=16,
    num_epochs=4,
    body_learning_rate=2e-5,  # SetFit splits the learning rate into body/head settings
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50
)

# Create distillation trainer
distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    args=args,
    train_dataset=train_dataset,
    # Identity mapping shown for illustration; a mapping is only needed when
    # the dataset's column names differ from "text" / "label".
    column_mapping={"text": "text", "label": "label"}
)

# Train student model through distillation
print("Starting knowledge distillation...")
distillation_trainer.train()

# The student model is now trained to mimic the teacher
student_predictions = student_model.predict([
    "This movie is fantastic!",
    "I didn't enjoy this film."
])
print(f"Student predictions: {student_predictions}")
```

### Comparing Teacher vs Student Performance

```python
import time

from setfit import SetFitModel, DistillationTrainer, Trainer, TrainingArguments
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report

# Load dataset
train_dataset = load_dataset("SetFit/sst2", split="train[:100]")  # Small subset for demo
test_dataset = load_dataset("SetFit/sst2", split="test[:50]")

# Teacher model (large, accurate)
teacher_model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")

# Student model (small, fast)
student_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Train teacher model first (`Trainer` is the class that takes `TrainingArguments`)
print("Training teacher model...")
teacher_trainer = Trainer(
    model=teacher_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16)
)
teacher_trainer.train()

# Train student via distillation
print("Training student model via distillation...")
distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16)
)
distillation_trainer.train()

# Compare performance and speed
test_texts = test_dataset["text"]
test_labels = test_dataset["label"]

# Teacher predictions (perf_counter is monotonic, so it is safe for measuring intervals)
start_time = time.perf_counter()
teacher_preds = teacher_model.predict(test_texts)
teacher_time = time.perf_counter() - start_time

# Student predictions
start_time = time.perf_counter()
student_preds = student_model.predict(test_texts)
student_time = time.perf_counter() - start_time

# Calculate metrics
teacher_acc = accuracy_score(test_labels, teacher_preds)
student_acc = accuracy_score(test_labels, student_preds)

# Guard the ratios: a zero denominator would otherwise raise ZeroDivisionError
speedup = f"{teacher_time / student_time:.1f}x" if student_time > 0 else "n/a"
retention = f"{student_acc / teacher_acc:.1%}" if teacher_acc > 0 else "n/a"

print("\nPerformance Comparison:")
print(f"Teacher accuracy: {teacher_acc:.3f} (Time: {teacher_time:.3f}s)")
print(f"Student accuracy: {student_acc:.3f} (Time: {student_time:.3f}s)")
print(f"Speed improvement: {speedup}")
print(f"Accuracy retention: {retention}")

print("\nDetailed Student Results:")
print(classification_report(test_labels, student_preds))
```

### Multi-Teacher Distillation

```python
from setfit import SetFitModel, DistillationTrainer, Trainer, TrainingArguments
from datasets import Dataset
import numpy as np

# NOTE: `train_dataset` is a labeled dataset with "text" and "label" columns,
# such as the one built in the earlier examples.

# Load multiple teacher models with different strengths
teacher1 = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
teacher2 = SetFitModel.from_pretrained("sentence-transformers/all-roberta-large-v1")
teacher3 = SetFitModel.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

# Train teachers on the same dataset
teachers = [teacher1, teacher2, teacher3]
for i, teacher in enumerate(teachers):
    print(f"Training teacher {i+1}...")
    trainer = Trainer(
        model=teacher,
        train_dataset=train_dataset,
        args=TrainingArguments(num_epochs=3, batch_size=16)
    )
    trainer.train()


# Create ensemble predictions for student training
def create_ensemble_dataset(teachers, dataset):
    """Create training dataset with ensemble teacher guidance.

    Parameters:
    - teachers: Non-empty sequence of trained models exposing `predict_proba`
    - dataset: Dataset with "text" and "label" columns

    Returns:
    Dataset with "text", "label" and "soft_labels" columns, where
    "soft_labels" holds the element-wise mean of the teachers' probabilities

    Raises:
    ValueError if `teachers` is empty
    """
    if not teachers:
        raise ValueError("At least one teacher model is required")

    texts = dataset["text"]
    labels = dataset["label"]

    # Get predictions from all teachers. Ask for NumPy arrays explicitly so
    # the averaging below does not depend on tensor-to-array conversion.
    teacher_probs = [
        np.asarray(teacher.predict_proba(texts, as_numpy=True))
        for teacher in teachers
    ]

    # Average teacher predictions
    ensemble_probs = np.mean(np.stack(teacher_probs), axis=0)

    # Use soft labels from ensemble
    return Dataset.from_dict({
        "text": texts,
        "label": labels,
        "soft_labels": ensemble_probs.tolist()
    })


# Create enhanced training dataset
enhanced_dataset = create_ensemble_dataset(teachers, train_dataset)

# Student model
student_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Custom distillation trainer that uses ensemble guidance
# (This would require custom implementation in practice)
distillation_trainer = DistillationTrainer(
    teacher_model=teacher1,  # Use first teacher as primary
    student_model=student_model,
    train_dataset=enhanced_dataset,
    args=TrainingArguments(num_epochs=5, batch_size=16)
)

distillation_trainer.train()
```

### Progressive Distillation

```python
import time

from setfit import SetFitModel, DistillationTrainer, Trainer, TrainingArguments

# NOTE: `train_dataset` is a labeled dataset with "text" and "label" columns,
# such as the one built in the earlier examples.

# Create a chain of models: Large -> Medium -> Small
large_model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
medium_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L12-v2")
small_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Stage 1: Train large model (teacher)
print("Stage 1: Training large model...")
large_trainer = Trainer(
    model=large_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16)
)
large_trainer.train()

# Stage 2: Distill large -> medium
print("Stage 2: Distilling large -> medium...")
medium_distillation = DistillationTrainer(
    teacher_model=large_model,
    student_model=medium_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16)
)
medium_distillation.train()

# Stage 3: Distill medium -> small (the distilled medium model is now the teacher)
print("Stage 3: Distilling medium -> small...")
small_distillation = DistillationTrainer(
    teacher_model=medium_model,
    student_model=small_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16)
)
small_distillation.train()

# Compare all models
models = {
    "Large": large_model,
    "Medium": medium_model,
    "Small": small_model
}

test_texts = ["This is amazing!", "This is terrible."]

print("\nProgressive Distillation Results:")
for name, model in models.items():
    # perf_counter is monotonic, so it is safe for measuring short intervals
    start_time = time.perf_counter()
    predictions = model.predict(test_texts)
    inference_time = time.perf_counter() - start_time

    print(f"{name} model: {predictions} (Time: {inference_time:.4f}s)")
```

### Distillation with Custom Loss

```python
from setfit import DistillationTrainer, TrainingArguments
import torch
import torch.nn.functional as F


class CustomDistillationTrainer(DistillationTrainer):
    """Distillation trainer that blends a soft-target loss with a hard-target loss.

    NOTE(review): the DistillationTrainer API documented on this page does not
    show a hook that calls `compute_distillation_loss`; confirm how the base
    trainer is expected to invoke it before relying on this subclass.
    """

    def __init__(self, *args, temperature=4.0, alpha=0.7, **kwargs):
        """
        Parameters:
        - temperature: Softmax temperature for the soft targets; must be > 0
        - alpha: Weight of the distillation loss; the task loss gets 1 - alpha
        - *args, **kwargs: Forwarded unchanged to DistillationTrainer

        Raises:
        ValueError if temperature is not positive or alpha is outside [0, 1]
        """
        # Fail fast: the logits are divided by the temperature, and alpha is
        # used as a convex-combination weight.
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")

        super().__init__(*args, **kwargs)
        self.temperature = temperature
        self.alpha = alpha  # Weight for distillation loss vs task loss

    def compute_distillation_loss(self, teacher_logits, student_logits, labels):
        """Custom distillation loss combining soft and hard targets.

        Parameters:
        - teacher_logits: Teacher logits; softmax is taken over dim 1
        - student_logits: Student logits; log-softmax is taken over dim 1
        - labels: Ground-truth targets passed to cross-entropy

        Returns:
        alpha * distillation_loss + (1 - alpha) * task_loss
        """
        # Soft target loss (KL divergence)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)
        distillation_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        # Multiply by T^2, as the original does; written out-of-place so the
        # tensor returned by kl_div is not modified in place.
        distillation_loss = distillation_loss * (self.temperature ** 2)

        # Hard target loss (standard cross-entropy)
        task_loss = F.cross_entropy(student_logits, labels)

        # Combined loss
        total_loss = self.alpha * distillation_loss + (1 - self.alpha) * task_loss
        return total_loss


# Use custom trainer
# NOTE: `teacher_model`, `student_model` and `train_dataset` come from the
# earlier examples (the teacher must already be trained).
custom_trainer = CustomDistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16),
    temperature=5.0,  # Higher temperature for softer distributions
    alpha=0.8  # More weight on distillation loss
)

custom_trainer.train()
```

### Evaluating Distillation Quality

```python
from setfit import SetFitModel, DistillationTrainer
from sklearn.metrics import accuracy_score
import numpy as np
from scipy.stats import spearmanr


def evaluate_distillation_quality(teacher_model, student_model, test_dataset):
    """Comprehensive evaluation of distillation quality.

    Parameters:
    - teacher_model: Trained teacher exposing `predict` and `predict_proba`
    - student_model: Distilled student exposing `predict` and `predict_proba`
    - test_dataset: Dataset with "text" and "label" columns

    Returns:
    Dictionary with teacher/student accuracy, accuracy retention (NaN when
    the teacher accuracy is 0), teacher-student prediction agreement,
    Spearman correlation of the top-class probabilities, and the average
    KL divergence from teacher to student probabilities
    """
    test_texts = test_dataset["text"]
    test_labels = test_dataset["label"]

    # Get predictions and probabilities. Request NumPy arrays so that the
    # NumPy reductions below work regardless of the models' default output type.
    teacher_preds = np.asarray(teacher_model.predict(test_texts, as_numpy=True))
    student_preds = np.asarray(student_model.predict(test_texts, as_numpy=True))

    teacher_probs = np.asarray(teacher_model.predict_proba(test_texts, as_numpy=True))
    student_probs = np.asarray(student_model.predict_proba(test_texts, as_numpy=True))

    # Calculate metrics
    teacher_acc = accuracy_score(test_labels, teacher_preds)
    student_acc = accuracy_score(test_labels, student_preds)

    # Prediction agreement between teacher and student
    agreement = accuracy_score(teacher_preds, student_preds)

    # Probability correlation (how similar are the confidence scores)
    teacher_max_probs = np.max(teacher_probs, axis=1)
    student_max_probs = np.max(student_probs, axis=1)
    prob_correlation, _ = spearmanr(teacher_max_probs, student_max_probs)

    # KL divergence between probability distributions, computed for all rows
    # at once; the small epsilon avoids log(0) and division by zero.
    eps = 1e-8
    kl_divergences = np.sum(
        teacher_probs * np.log((teacher_probs + eps) / (student_probs + eps)),
        axis=1,
    )
    avg_kl_div = np.mean(kl_divergences)

    # A teacher with zero accuracy would otherwise raise ZeroDivisionError
    accuracy_retention = student_acc / teacher_acc if teacher_acc > 0 else float("nan")

    results = {
        "teacher_accuracy": teacher_acc,
        "student_accuracy": student_acc,
        "accuracy_retention": accuracy_retention,
        "prediction_agreement": agreement,
        "probability_correlation": prob_correlation,
        "avg_kl_divergence": avg_kl_div
    }

    return results


# Evaluate distillation
# NOTE: `teacher_model`, `student_model` and `test_dataset` come from the
# earlier examples.
evaluation_results = evaluate_distillation_quality(
    teacher_model=teacher_model,
    student_model=student_model,
    test_dataset=test_dataset
)

print("Distillation Quality Assessment:")
for metric, value in evaluation_results.items():
    if isinstance(value, float):
        print(f"{metric}: {value:.4f}")
    else:
        print(f"{metric}: {value}")
```