# Efficient few-shot learning with Sentence Transformers
Automatic model card generation and metadata management for reproducibility and documentation. Model cards capture a model's performance, training details, and usage guidelines in one place.
Dataclass for storing and managing model card metadata throughout the training and deployment lifecycle.

```python
class SetFitModelCardData:
def __init__(
self,
language: Optional[Union[str, List[str]]] = None,
license: Optional[str] = None,
tags: Optional[List[str]] = None,
model_name: Optional[str] = None,
model_id: Optional[str] = None,
dataset_name: Optional[str] = None,
dataset_id: Optional[str] = None,
dataset_revision: Optional[str] = None,
task_name: Optional[str] = None,
st_id: Optional[str] = None
):
"""
Initialize model card data for SetFit models.
Parameters:
- language: Language(s) supported by the model
- license: Model license (e.g., "MIT", "Apache-2.0")
- tags: List of tags for model categorization
- model_name: Human-readable model name
- model_id: Unique model identifier (HuggingFace Hub ID)
- dataset_name: Name of training dataset
- dataset_id: Dataset identifier (HuggingFace Hub ID)
- dataset_revision: Specific dataset revision/commit
- task_name: ML task type (e.g., "text-classification")
- st_id: Sentence Transformer model ID used as base
"""
def set_best_model_step(self, step: int) -> None:
"""
Record the training step where best performance was achieved.
Parameters:
- step: Training step number with best validation metrics
"""
def set_widget_examples(self, examples: List[Dict[str, str]]) -> None:
"""
Set example inputs for HuggingFace Hub model widget.
Parameters:
- examples: List of example inputs with expected outputs
"""
def set_train_set_metrics(self, metrics: Dict[str, float]) -> None:
"""
Record training set performance metrics.
Parameters:
- metrics: Dictionary of metric names and values
"""
def set_label_examples(self, examples: Dict[str, List[str]]) -> None:
"""
Set example texts for each label class.
Parameters:
- examples: Mapping from label names to example texts
"""
def infer_dataset_id(self, dataset_name: str) -> Optional[str]:
"""
Infer HuggingFace dataset ID from dataset name.
Parameters:
- dataset_name: Name of the dataset
Returns:
Inferred dataset ID or None if not found
"""
def register_model(self, model_id: str, exists_ok: bool = False) -> None:
"""
Register model with HuggingFace Hub.
Parameters:
- model_id: Model ID for registration
- exists_ok: Whether to allow overwriting existing model
"""
def infer_st_id(self, st_model: SentenceTransformer) -> str:
"""
Infer sentence transformer model ID from model object.
Parameters:
- st_model: Sentence transformer model instance
Returns:
Model ID string
"""
def set_st_id(self, st_id: str) -> None:
"""
Set the sentence transformer model ID.
Parameters:
- st_id: Sentence transformer model identifier
"""
def post_training_eval_results(self, results: Dict[str, Any]) -> None:
"""
Record post-training evaluation results.
Parameters:
- results: Evaluation results including metrics and metadata
"""
def to_dict(self) -> Dict[str, Any]:
"""
Convert model card data to dictionary format.
Returns:
Dictionary representation of model card data
"""
def to_yaml(self) -> str:
"""
Convert model card data to YAML format for HuggingFace Hub.
Returns:
YAML string representation
"""Training callback that automatically tracks metrics and generates model cards during training.
Training callback that automatically tracks metrics and generates model cards during training.

```python
class ModelCardCallback:
def __init__(self, trainer: "Trainer"):
"""
Initialize model card callback for automatic tracking.
Parameters:
- trainer: SetFit trainer instance to monitor
"""
def on_init_end(
self,
args: TrainingArguments,
state: "TrainerState",
control: "TrainerControl"
) -> None:
"""
Called when trainer initialization ends.
Parameters:
- args: Training arguments
- state: Current trainer state
- control: Training control flags
"""
def on_train_begin(
self,
args: TrainingArguments,
state: "TrainerState",
control: "TrainerControl"
) -> None:
"""
Called when training begins.
Parameters:
- args: Training arguments
- state: Current trainer state
- control: Training control flags
"""
def on_evaluate(
self,
args: TrainingArguments,
state: "TrainerState",
control: "TrainerControl",
model: SetFitModel,
logs: Optional[Dict[str, float]] = None
) -> None:
"""
Called after each evaluation.
Parameters:
- args: Training arguments
- state: Current trainer state
- control: Training control flags
- model: Model being trained
- logs: Evaluation metrics
"""
def on_log(
self,
args: TrainingArguments,
state: "TrainerState",
control: "TrainerControl",
model: SetFitModel,
logs: Optional[Dict[str, float]] = None
) -> None:
"""
Called when metrics are logged.
Parameters:
- args: Training arguments
- state: Current trainer state
- control: Training control flags
- model: Model being trained
- logs: Logged metrics
"""Utility functions for generating model cards from trained models.
Utility functions for generating model cards from trained models.

```python
def generate_model_card(model: SetFitModel) -> str:
"""
Generate comprehensive model card for a SetFit model.
Parameters:
- model: Trained SetFit model with model card data
Returns:
Complete model card as markdown string
"""
def is_on_huggingface(repo_id: str, is_model: bool = True) -> bool:
"""
Check if a repository exists on HuggingFace Hub.
Parameters:
- repo_id: Repository identifier to check
- is_model: Whether to check model hub (True) or dataset hub (False)
Returns:
True if repository exists, False otherwise
"""from setfit import SetFitModel, SetFitTrainer, TrainingArguments, SetFitModelCardData
### Basic Training with Automatic Model Card Generation

```python
from setfit import SetFitModel, SetFitTrainer, TrainingArguments, SetFitModelCardData
from datasets import load_dataset
# Create model card data
model_card_data = SetFitModelCardData(
language="en",
license="apache-2.0",
tags=["setfit", "sentence-transformers", "text-classification", "few-shot"],
model_name="SetFit Sentiment Classifier",
dataset_name="emotion",
task_name="text-classification"
)
# Initialize model with model card data
model = SetFitModel.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2",
model_card_data=model_card_data
)
# Load dataset
train_dataset = load_dataset("emotion", split="train[:100]")
eval_dataset = load_dataset("emotion", split="validation[:50]")
# Set up training with model card tracking
args = TrainingArguments(
output_dir="./results",
num_epochs=4,
batch_size=16,
eval_strategy="epoch",
save_strategy="epoch",
logging_steps=10,
load_best_model_at_end=True,
metric_for_best_model="eval_accuracy"
)
trainer = SetFitTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
column_mapping={"text": "text", "label": "label"}
)
# Train model (model card will be updated automatically)
trainer.train()
# Generate and display model card
model_card = model.generate_model_card()
print(model_card)
# Save model with model card
model.save_pretrained("./my-setfit-model")from setfit import SetFitModel, SetFitModelCardData, generate_model_card
import json
# Create detailed model card data
model_card_data = SetFitModelCardData(
language=["en", "es", "fr"], # Multi-language support
license="mit",
tags=[
"setfit",
"sentence-transformers",
"text-classification",
"few-shot-learning",
"multilingual",
"sentiment-analysis"
],
model_name="Multilingual SetFit Sentiment Classifier",
model_id="my-org/multilingual-setfit-sentiment",
dataset_name="multilingual_sentiment",
dataset_id="my-org/multilingual-sentiment-dataset",
task_name="text-classification"
)
# Add widget examples for HuggingFace Hub
widget_examples = [
{"text": "I love this product!"},
{"text": "This is terrible quality."},
{"text": "¡Me encanta este producto!"}, # Spanish
{"text": "J'adore ce produit!"} # French
]
model_card_data.set_widget_examples(widget_examples)
# Add label examples
label_examples = {
"positive": [
"This is amazing!",
"I absolutely love it!",
"Best purchase ever!"
],
"negative": [
"This is awful.",
"Worst product I've bought.",
"Complete waste of money."
],
"neutral": [
"It's okay.",
"Average quality product.",
"Nothing special about it."
]
}
model_card_data.set_label_examples(label_examples)
# Initialize model
model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
model_card_data=model_card_data
)
# Simulate training metrics
training_metrics = {
"final_train_accuracy": 0.94,
"final_train_f1": 0.93,
"final_eval_accuracy": 0.89,
"final_eval_f1": 0.88,
"training_time": "00:15:32",
"total_steps": 200,
"best_step": 180
}
# Update model card with training results
model_card_data.set_train_set_metrics(training_metrics)
model_card_data.set_best_model_step(180)
# Add post-training evaluation results
eval_results = {
"test_accuracy": 0.87,
"test_f1": 0.86,
"test_precision": 0.88,
"test_recall": 0.85,
"per_class_metrics": {
"positive": {"precision": 0.91, "recall": 0.89, "f1": 0.90},
"negative": {"precision": 0.88, "recall": 0.87, "f1": 0.88},
"neutral": {"precision": 0.85, "recall": 0.81, "f1": 0.83}
},
"confusion_matrix": [[45, 3, 2], [4, 41, 5], [7, 6, 37]]
}
model_card_data.post_training_eval_results(eval_results)
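# Inspect the assembled metadata as a plain dict (a sketch; the exact shape
# depends on the installed setfit version). json is imported above for this.
print(json.dumps(model_card_data.to_dict(), indent=2, default=str))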
# Generate comprehensive model card
model_card = generate_model_card(model)
print("Generated Model Card:")
print("=" * 50)
print(model_card)
# Save model card to file
with open("MODEL_CARD.md", "w") as f:
f.write(model_card)
print("\nModel card saved to MODEL_CARD.md")from setfit import SetFitTrainer, ModelCardCallback, SetFitModelCardData
from datasets import load_dataset
# Set up model card data
model_card_data = SetFitModelCardData(
language="en",
license="apache-2.0",
tags=["setfit", "text-classification", "emotion-detection"],
model_name="SetFit Emotion Classifier",
dataset_name="emotion",
task_name="multi-class-classification"
)
# Initialize model
model = SetFitModel.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2",
model_card_data=model_card_data
)
# Load dataset
train_dataset = load_dataset("emotion", split="train[:200]")
eval_dataset = load_dataset("emotion", split="validation[:100]")
# Create trainer with model card callback
trainer = SetFitTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=TrainingArguments(
output_dir="./emotion-classifier",
num_epochs=3,
batch_size=16,
eval_strategy="epoch",
logging_steps=20,
save_strategy="epoch"
),
    # Note: the trainer attaches ModelCardCallback automatically; it is
    # constructed with the trainer instance (see ModelCardCallback.__init__
    # above), so it does not need to be passed in explicitly.
)
# Train model - callback will automatically update model card
print("Training model with automatic model card generation...")
trainer.train()
# Model card is automatically generated and updated during training
final_model_card = model.generate_model_card()
print("Final Model Card:")
print(final_model_card)
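# Note: recent setfit versions also write the generated card to README.md
# inside the save directory when calling save_pretrained (verify for your version).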
# Save model with updated model card
model.save_pretrained("./emotion-classifier-final")from setfit import SetFitModel, SetFitModelCardData
from huggingface_hub import HfApi
import os
# Create production-ready model card
model_card_data = SetFitModelCardData(
language="en",
license="apache-2.0",
tags=[
"setfit",
"sentence-transformers",
"text-classification",
"few-shot-learning",
"pytorch",
"scikit-learn"
],
model_name="SetFit Binary Sentiment Classifier",
model_id="my-username/setfit-sentiment-binary",
dataset_name="imdb",
dataset_id="imdb",
task_name="binary-classification"
)
# Load trained model
model = SetFitModel.from_pretrained("./my-trained-model")
model.model_card_data = model_card_data
# Add comprehensive evaluation metrics
evaluation_metrics = {
"accuracy": 0.922,
"precision": 0.918,
"recall": 0.925,
"f1": 0.921,
"roc_auc": 0.965,
"samples_seen": 25000,
"training_samples": 16, # Few-shot learning
"eval_samples": 25000
}
model_card_data.set_train_set_metrics(evaluation_metrics)
# Set widget examples for interactive testing
widget_examples = [
{"text": "This movie is absolutely fantastic! I loved every minute of it."},
{"text": "Boring and predictable. Waste of time."},
{"text": "Not bad, but could be better."},
{"text": "Outstanding performance by the lead actor!"}
]
model_card_data.set_widget_examples(widget_examples)
# Generate model card
model_card_content = generate_model_card(model)
# Save model locally with model card
save_path = "./setfit-sentiment-for-hub"
model.save_pretrained(save_path)
# Write model card to README.md
with open(f"{save_path}/README.md", "w") as f:
f.write(model_card_content)
print(f"Model and model card saved to {save_path}")
# Upload to HuggingFace Hub (requires authentication)
# api = HfApi()
# api.upload_folder(
# folder_path=save_path,
# repo_id="my-username/setfit-sentiment-binary",
# repo_type="model"
# )
# print("Model uploaded to HuggingFace Hub!")from setfit import SetFitModelCardData
from jinja2 import Template
# Custom model card template
CUSTOM_TEMPLATE = """
# {{ model_name }}
{{ description }}
## Model Details
- **Model Type**: SetFit (Sentence Transformers + Classification Head)
- **Language**: {{ language }}
- **License**: {{ license }}
- **Base Model**: {{ st_id }}
## Training Data
- **Dataset**: {{ dataset_name }}
- **Training Samples**: {{ training_samples }}
- **Validation Samples**: {{ validation_samples }}
## Performance Metrics
| Metric | Score |
|--------|-------|
{% for metric, score in metrics.items() %}
| {{ metric.title() }} | {{ "%.3f"|format(score) }} |
{% endfor %}
## Usage
```python
from setfit import SetFitModel
# Load model
model = SetFitModel.from_pretrained("{{ model_id }}")
# Make predictions
predictions = model.predict([
"Your text here"
])
```

## Limitations
{{ limitations }}

## Citation
If you use this model, please cite:

```bibtex
@article{setfit2022,
title={Efficient Few-Shot Learning Without Prompts},
author={Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren},
journal={arXiv preprint arXiv:2209.11055},
year={2022}
}"""
class CustomModelCardGenerator:
    def __init__(self, template_string=CUSTOM_TEMPLATE):
        self.template = Template(template_string)
def generate(self, model_card_data, **kwargs):
"""Generate model card from template and data."""
# Prepare template variables
template_vars = {
"model_name": model_card_data.model_name,
"description": "Few-shot text classification model trained with SetFit framework",
"language": model_card_data.language,
"license": model_card_data.license,
"st_id": model_card_data.st_id,
"dataset_name": model_card_data.dataset_name,
"model_id": model_card_data.model_id,
"training_samples": kwargs.get("training_samples", "Unknown"),
"validation_samples": kwargs.get("validation_samples", "Unknown"),
"metrics": kwargs.get("metrics", {}),
"batch_size": kwargs.get("batch_size", "16"),
"learning_rate": kwargs.get("learning_rate", "2e-5"),
"num_epochs": kwargs.get("num_epochs", "1"),
"optimizer": kwargs.get("optimizer", "AdamW"),
"limitations": kwargs.get("limitations", "This model may have biases present in the training data.")
}
        return self.template.render(**template_vars)

# Build card data for the custom template
model_card_data = SetFitModelCardData(
    model_name="Custom SetFit Classifier",
    language="en",
    license="mit",
    st_id="sentence-transformers/all-MiniLM-L6-v2",
    dataset_name="custom_dataset",
    model_id="username/custom-setfit-model"
)
generator = CustomModelCardGenerator()
custom_card = generator.generate(
    model_card_data,
    training_samples=500,
    validation_samples=100,
    metrics={
        "accuracy": 0.89,
        "f1": 0.87,
        "precision": 0.91,
        "recall": 0.84
    },
    batch_size=32,
    learning_rate="5e-5",
    num_epochs=3,
    limitations="Model trained on domain-specific data. May not generalize well to other domains."
)
print("Custom Model Card:") print(custom_card)
### Model Card Validation and Quality Checks
```python
from setfit import SetFitModelCardData, is_on_huggingface
import re
from typing import Any, Dict
def validate_model_card_data(model_card_data: SetFitModelCardData) -> Dict[str, Any]:
"""
Validate model card data for completeness and quality.
Returns:
Dictionary with validation results and recommendations
"""
errors = []
warnings = []
suggestions = []
# Required fields check
required_fields = ['model_name', 'language', 'license', 'task_name']
for field in required_fields:
if not getattr(model_card_data, field):
errors.append(f"Missing required field: {field}")
# License validation
valid_licenses = ['apache-2.0', 'mit', 'cc-by-4.0', 'cc-by-sa-4.0', 'gpl-3.0']
if model_card_data.license and model_card_data.license.lower() not in valid_licenses:
warnings.append(f"Unusual license: {model_card_data.license}")
# Model ID validation
if model_card_data.model_id:
if not re.match(r'^[a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+$', model_card_data.model_id):
errors.append("Model ID should follow format: username/model-name")
# Check if model exists on HuggingFace Hub
if is_on_huggingface(model_card_data.model_id):
warnings.append(f"Model ID {model_card_data.model_id} already exists on HuggingFace Hub")
# Dataset validation
if model_card_data.dataset_id:
if not is_on_huggingface(model_card_data.dataset_id, is_model=False):
warnings.append(f"Dataset {model_card_data.dataset_id} not found on HuggingFace Hub")
# Tag validation
if model_card_data.tags:
recommended_tags = ['setfit', 'sentence-transformers', 'text-classification']
missing_recommended = [tag for tag in recommended_tags if tag not in model_card_data.tags]
if missing_recommended:
suggestions.append(f"Consider adding recommended tags: {missing_recommended}")
else:
suggestions.append("Add tags to improve model discoverability")
# Language validation
if model_card_data.language:
if isinstance(model_card_data.language, str):
            if len(model_card_data.language) > 2:
                warnings.append("Language should use ISO 639-1 codes (e.g., 'en', 'fr')")
return {
"errors": errors,
"warnings": warnings,
"suggestions": suggestions,
"is_valid": len(errors) == 0
}
# Example validation
model_card_data = SetFitModelCardData(
model_name="Test Model",
language="english", # Should be "en"
license="custom-license", # Not standard
model_id="invalid-id-format", # Invalid format
tags=["classification"], # Missing recommended tags
dataset_id="nonexistent/dataset"
)
validation_results = validate_model_card_data(model_card_data)
print("Model Card Validation Results:")
print(f"Valid: {validation_results['is_valid']}")
if validation_results['errors']:
print("\nErrors (must fix):")
for error in validation_results['errors']:
print(f" ❌ {error}")
if validation_results['warnings']:
print("\nWarnings (should review):")
for warning in validation_results['warnings']:
print(f" ⚠️ {warning}")
if validation_results['suggestions']:
print("\nSuggestions (recommended):")
for suggestion in validation_results['suggestions']:
print(f" 💡 {suggestion}")Install with Tessl CLI
npx tessl i tessl/pypi-setfit