or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

absa.md, core-model-training.md, data-utilities.md, index.md, knowledge-distillation.md, model-cards.md, model-export.md

docs/knowledge-distillation.md

# Knowledge Distillation

Teacher-student training framework for compressing SetFit models. Knowledge distillation trains a smaller, faster student model to reproduce the behavior of a larger, already fine-tuned teacher model. The student retains much of the teacher's accuracy at a fraction of the inference cost.

## Capabilities

### Distillation Trainer

Main trainer class for distilling one SetFit model (the teacher) into another (the student).

```python { .api }

class DistillationTrainer:
    """Trainer that transfers knowledge from a teacher SetFit model to a student SetFit model."""

    def __init__(
        self,
        teacher_model: SetFitModel,
        student_model: SetFitModel,
        args: Optional[TrainingArguments] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        model_init: Optional[Callable[[], SetFitModel]] = None,
        compute_metrics: Optional[Callable] = None,
        callbacks: Optional[List] = None,
        optimizers: Optional[Tuple] = None,
        preprocess_logits_for_metrics: Optional[Callable] = None,
        column_mapping: Optional[Dict[str, str]] = None,
    ):
        """
        Set up a trainer that distills a teacher model into a student model.

        Parameters:
        - teacher_model: Already-trained SetFit model whose knowledge is transferred
        - student_model: Smaller SetFit model that is trained as the student
        - args: Training arguments that control the distillation run
        - train_dataset: Dataset used for distillation training
        - eval_dataset: Dataset used to monitor the student's performance
        - model_init: Factory returning a fresh student model (used for hyperparameter search)
        - compute_metrics: Callable that computes evaluation metrics
        - callbacks: Training callbacks to attach
        - optimizers: Custom optimizer(s) for the student model
        - preprocess_logits_for_metrics: Callable applied to logits before metrics are computed
        - column_mapping: Maps dataset column names to the names the trainer expects
        """

    def train(self) -> None:
        """
        Run knowledge distillation, updating the student model in place.

        Steps:
        1. Compute embeddings with the teacher model
        2. Train the student body to match the teacher embeddings
        3. Fine-tune the student's classification head
        """

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Score the student model on an evaluation dataset.

        Parameters:
        - eval_dataset: Dataset to evaluate on; falls back to the trainer's eval_dataset when None

        Returns:
        Mapping from metric name to value for the student model
        """

    def predict(self, test_dataset: Dataset) -> "PredictionOutput":
        """
        Run the trained student model on a test dataset.

        Parameters:
        - test_dataset: Dataset to predict on

        Returns:
        The student model's predictions
        """

```

### Distillation Dataset Classes

Specialized dataset classes for contrastive distillation training.

```python { .api }

class ContrastiveDataset:
    """Builds positive/negative sentence pairs for contrastive training."""

    def __init__(
        self,
        sentences: List[str],
        labels: List[int],
        sampling_strategy: str = "oversampling",
    ):
        """
        Create contrastive sentence pairs from labeled sentences.

        Parameters:
        - sentences: Input sentences
        - labels: Label of each sentence, aligned by position
        - sampling_strategy: How pairs are sampled: "oversampling", "undersampling" or "unique"
        """


class ContrastiveDistillationDataset:
    """Builds sentence pairs for distillation, guided by precomputed teacher embeddings."""

    def __init__(
        self,
        sentences: List[str],
        labels: List[int],
        teacher_embeddings: np.ndarray,
        sampling_strategy: str = "oversampling",
    ):
        """
        Create distillation pairs whose targets come from the teacher.

        Parameters:
        - sentences: Input sentences
        - labels: Label of each sentence, aligned by position
        - teacher_embeddings: Embeddings of the sentences, precomputed by the teacher model
        - sampling_strategy: How pairs are sampled
        """

```

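## Usage Examples

Some later snippets reuse models and datasets created in earlier ones (for example `teacher_model`, `student_model` and `train_dataset`); run them in order.

### Basic Knowledge Distillation

```python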

from setfit import SetFitModel, Trainer, DistillationTrainer, TrainingArguments
from datasets import Dataset

# Prepare training data
train_texts = [
    "I love this movie!", "This film is terrible.",
    "Amazing cinematography!", "Waste of time.",
    "Brilliant acting!", "Poor storyline.",
]
train_labels = [1, 0, 1, 0, 1, 0]

train_dataset = Dataset.from_dict({
    "text": train_texts,
    "label": train_labels,
})

# Small held-out set: eval_strategy="epoch" below needs data to evaluate on
eval_dataset = Dataset.from_dict({
    "text": ["What a wonderful film!", "Dull and far too long."],
    "label": [1, 0],
})

# Teacher model (larger, more accurate). A checkpoint freshly loaded from
# sentence-transformers has an untrained classification head, so fine-tune
# it first (or load an already fine-tuned SetFit model instead).
teacher_model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
teacher_trainer = Trainer(
    model=teacher_model,
    args=TrainingArguments(batch_size=16, num_epochs=4),
    train_dataset=train_dataset,
)
teacher_trainer.train()

# Initialize student model (smaller, faster)
student_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Configure distillation training
args = TrainingArguments(
    output_dir="./distillation_results",
    batch_size=16,
    num_epochs=4,
    # SetFit has separate body/head learning rates; there is no plain `learning_rate`
    body_learning_rate=2e-5,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
)

# Create distillation trainer
distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Identity mapping shown for illustration; only needed when the dataset's
    # columns are not already named "text" and "label"
    column_mapping={"text": "text", "label": "label"},
)

# Train student model through distillation
print("Starting knowledge distillation...")
distillation_trainer.train()

# The student model is now trained to mimic the teacher
student_predictions = student_model.predict([
    "This movie is fantastic!",
    "I didn't enjoy this film.",
])
print(f"Student predictions: {student_predictions}")

```

### Comparing Teacher vs Student Performance

```python

from setfit import SetFitModel, Trainer, DistillationTrainer, TrainingArguments
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report
import time

# Load dataset (small subsets keep the demo quick)
train_dataset = load_dataset("SetFit/sst2", split="train[:100]")
test_dataset = load_dataset("SetFit/sst2", split="test[:50]")

# Teacher model (large, accurate)
teacher_model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")

# Student model (small, fast)
student_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Train teacher model first: distillation needs a fine-tuned teacher
print("Training teacher model...")
teacher_trainer = Trainer(
    model=teacher_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16),
)
teacher_trainer.train()

# Train student via distillation
print("Training student model via distillation...")
distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16),
)
distillation_trainer.train()

# Compare performance and speed
test_texts = test_dataset["text"]
test_labels = test_dataset["label"]

# perf_counter is monotonic and high-resolution, unlike time.time();
# as_numpy=True returns NumPy arrays rather than torch tensors, for sklearn
start_time = time.perf_counter()
teacher_preds = teacher_model.predict(test_texts, as_numpy=True)
teacher_time = time.perf_counter() - start_time

start_time = time.perf_counter()
student_preds = student_model.predict(test_texts, as_numpy=True)
student_time = time.perf_counter() - start_time

# Calculate metrics
teacher_acc = accuracy_score(test_labels, teacher_preds)
student_acc = accuracy_score(test_labels, student_preds)
# Guard the ratio: a teacher scoring 0 would otherwise raise ZeroDivisionError
accuracy_retention = student_acc / teacher_acc if teacher_acc else float("nan")

print("\nPerformance Comparison:")
print(f"Teacher accuracy: {teacher_acc:.3f} (Time: {teacher_time:.3f}s)")
print(f"Student accuracy: {student_acc:.3f} (Time: {student_time:.3f}s)")
print(f"Speed improvement: {teacher_time/student_time:.1f}x")
print(f"Accuracy retention: {accuracy_retention:.1%}")

print("\nDetailed Student Results:")
print(classification_report(test_labels, student_preds))

```

### Multi-Teacher Distillation

```python

from setfit import SetFitModel, Trainer, DistillationTrainer, TrainingArguments
from datasets import Dataset, load_dataset
import numpy as np

# Training data shared by all teachers and the student
train_dataset = load_dataset("SetFit/sst2", split="train[:100]")

# Load multiple teacher models with different strengths
teacher1 = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
teacher2 = SetFitModel.from_pretrained("sentence-transformers/all-roberta-large-v1")
teacher3 = SetFitModel.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

# Train teachers on the same dataset
teachers = [teacher1, teacher2, teacher3]
for i, teacher in enumerate(teachers, start=1):
    print(f"Training teacher {i}...")
    trainer = Trainer(
        model=teacher,
        train_dataset=train_dataset,
        args=TrainingArguments(num_epochs=3, batch_size=16),
    )
    trainer.train()


def create_ensemble_dataset(teachers, dataset):
    """Return a copy of *dataset* with an extra "soft_labels" column.

    The column holds per-class probabilities averaged over all teachers.

    Parameters:
    - teachers: Fine-tuned SetFit models that share the same label set
    - dataset: Dataset with "text" and "label" columns

    Returns:
    A `datasets.Dataset` with "text", "label" and "soft_labels" columns.
    """
    texts = dataset["text"]
    labels = dataset["label"]

    # as_numpy=True: predict_proba returns torch tensors by default, which
    # np.mean cannot reliably stack (e.g. when they live on a GPU)
    teacher_probs = [teacher.predict_proba(texts, as_numpy=True) for teacher in teachers]

    # Average the teachers' class probabilities, example by example
    ensemble_probs = np.mean(teacher_probs, axis=0)

    # Use soft labels from ensemble
    return Dataset.from_dict({
        "text": texts,
        "label": labels,
        "soft_labels": ensemble_probs.tolist(),
    })


# Create enhanced training dataset
enhanced_dataset = create_ensemble_dataset(teachers, train_dataset)

# Student model
student_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# NOTE: DistillationTrainer distills from a single teacher and does not read
# the "soft_labels" column. Actually training on the ensemble's soft labels
# requires a custom implementation; here the first teacher is the primary teacher.
distillation_trainer = DistillationTrainer(
    teacher_model=teacher1,
    student_model=student_model,
    train_dataset=enhanced_dataset,
    args=TrainingArguments(num_epochs=5, batch_size=16),
)

distillation_trainer.train()

```

### Progressive Distillation

```python

from setfit import SetFitModel, Trainer, DistillationTrainer, TrainingArguments
from datasets import load_dataset
import time

# Training data used at every stage
train_dataset = load_dataset("SetFit/sst2", split="train[:100]")

# Create a chain of models: Large -> Medium -> Small
large_model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
medium_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L12-v2")
small_model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Stage 1: Train large model (teacher)
print("Stage 1: Training large model...")
large_trainer = Trainer(
    model=large_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16),
)
large_trainer.train()

# Stage 2: Distill large -> medium
print("Stage 2: Distilling large -> medium...")
medium_distillation = DistillationTrainer(
    teacher_model=large_model,
    student_model=medium_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16),
)
medium_distillation.train()

# Stage 3: Distill medium -> small (the distilled medium model is now the teacher)
print("Stage 3: Distilling medium -> small...")
small_distillation = DistillationTrainer(
    teacher_model=medium_model,
    student_model=small_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16),
)
small_distillation.train()

# Compare all models
models = {
    "Large": large_model,
    "Medium": medium_model,
    "Small": small_model,
}

test_texts = ["This is amazing!", "This is terrible."]

print("\nProgressive Distillation Results:")
for name, model in models.items():
    # perf_counter is monotonic and high-resolution, suited to timing short calls
    start_time = time.perf_counter()
    predictions = model.predict(test_texts)
    inference_time = time.perf_counter() - start_time

    print(f"{name} model: {predictions} (Time: {inference_time:.4f}s)")

```

### Distillation with Custom Loss

```python

from setfit import DistillationTrainer, TrainingArguments
import torch
import torch.nn.functional as F


class CustomDistillationTrainer(DistillationTrainer):
    """DistillationTrainer extended with a classic soft/hard-target distillation loss.

    NOTE: DistillationTrainer never calls ``compute_distillation_loss`` on its own.
    It trains the student body to match teacher embeddings rather than classifier
    logits, so defining this method alone does not change training. Call it from
    your own training loop, e.g. one that trains a differentiable classification head.
    """

    def __init__(self, *args, temperature=4.0, alpha=0.7, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature  # > 1 softens both probability distributions
        self.alpha = alpha  # Weight for distillation loss vs task loss

    def compute_distillation_loss(self, teacher_logits, student_logits, labels):
        """Combine a soft-target (KL) loss with a hard-target (cross-entropy) loss.

        Parameters:
        - teacher_logits: Teacher logits, shape (batch, num_classes)
        - student_logits: Student logits, same shape as teacher_logits
        - labels: Ground-truth class indices, shape (batch,), integer dtype

        Returns:
        Scalar tensor equal to alpha * distillation_loss + (1 - alpha) * task_loss
        """
        # Soft target loss (KL divergence between temperature-scaled distributions)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)
        distillation_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        # Scale by T^2 so soft-target gradients keep their magnitude as T grows
        distillation_loss *= (self.temperature ** 2)

        # Hard target loss (standard cross-entropy)
        task_loss = F.cross_entropy(student_logits, labels)

        # Combined loss
        total_loss = self.alpha * distillation_loss + (1 - self.alpha) * task_loss
        return total_loss


# teacher_model (fine-tuned), student_model and train_dataset come from the examples above
custom_trainer = CustomDistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_dataset=train_dataset,
    args=TrainingArguments(num_epochs=4, batch_size=16),
    temperature=5.0,  # Higher temperature for softer distributions
    alpha=0.8,  # More weight on distillation loss
)

custom_trainer.train()

# The custom loss must be invoked explicitly, e.g. on a batch of logits
# (here dummy values: batch of 4, 2 classes)
example_loss = custom_trainer.compute_distillation_loss(
    torch.randn(4, 2), torch.randn(4, 2), torch.tensor([0, 1, 1, 0])
)
print(f"Example distillation loss: {example_loss.item():.4f}")

```

### Evaluating Distillation Quality

```python

from datasets import load_dataset
from sklearn.metrics import accuracy_score
import numpy as np
from scipy.stats import spearmanr


def evaluate_distillation_quality(teacher_model, student_model, test_dataset, eps=1e-8):
    """Compare a distilled student model against its teacher on a labeled test set.

    Parameters:
    - teacher_model: Fine-tuned teacher SetFit model
    - student_model: Distilled student SetFit model
    - test_dataset: Dataset with "text" and "label" columns
    - eps: Smoothing constant that keeps the KL divergence finite when a probability is 0

    Returns:
    Dict with teacher and student accuracy, accuracy retention (NaN when the
    teacher scores 0), teacher/student prediction agreement, the Spearman
    correlation of top-class confidences, and the mean KL(teacher || student).
    """
    test_texts = test_dataset["text"]
    test_labels = test_dataset["label"]

    # as_numpy=True: SetFit returns torch tensors by default, and
    # np.max(tensor, axis=1) would not yield a plain array of maxima
    teacher_preds = teacher_model.predict(test_texts, as_numpy=True)
    student_preds = student_model.predict(test_texts, as_numpy=True)

    teacher_probs = teacher_model.predict_proba(test_texts, as_numpy=True)
    student_probs = student_model.predict_proba(test_texts, as_numpy=True)

    # Calculate metrics
    teacher_acc = accuracy_score(test_labels, teacher_preds)
    student_acc = accuracy_score(test_labels, student_preds)

    # Prediction agreement between teacher and student
    agreement = accuracy_score(teacher_preds, student_preds)

    # Probability correlation (how similarly the models rank their confidence)
    prob_correlation, _ = spearmanr(teacher_probs.max(axis=1), student_probs.max(axis=1))

    # KL divergence per example, computed row-wise in one vectorized pass
    kl_divergences = np.sum(
        teacher_probs * np.log((teacher_probs + eps) / (student_probs + eps)), axis=1
    )
    avg_kl_div = float(np.mean(kl_divergences))

    results = {
        "teacher_accuracy": teacher_acc,
        "student_accuracy": student_acc,
        # Guard the ratio: a teacher scoring 0 would otherwise raise ZeroDivisionError
        "accuracy_retention": student_acc / teacher_acc if teacher_acc else float("nan"),
        "prediction_agreement": agreement,
        "probability_correlation": prob_correlation,
        "avg_kl_divergence": avg_kl_div,
    }

    return results


# teacher_model and student_model are the trained models from the examples above
test_dataset = load_dataset("SetFit/sst2", split="test[:50]")

# Evaluate distillation
evaluation_results = evaluate_distillation_quality(
    teacher_model=teacher_model,
    student_model=student_model,
    test_dataset=test_dataset,
)

print("Distillation Quality Assessment:")
for metric, value in evaluation_results.items():
    if isinstance(value, float):
        print(f"{metric}: {value:.4f}")
    else:
        print(f"{metric}: {value}")

```