PyTorch-native metrics library providing 400+ rigorously tested metrics across classification, regression, audio, image, text, and other ML domains.

Natural language processing metrics for translation, summarization, and text generation evaluation, including BLEU, ROUGE, and semantic similarity measures for comprehensive text quality assessment.

Metrics for evaluating machine translation quality via n-gram overlap.
class BLEUScore(Metric):
    """BLEU score for machine-translation quality, based on n-gram overlap with references."""

    def __init__(
        self,
        n_gram: int = 4,          # maximum n-gram order included in the score
        smooth: bool = False,     # apply smoothing for short/low-overlap hypotheses
        weights: Optional[Sequence[float]] = None,  # per-order weights; presumably uniform when None — TODO confirm
        **kwargs
    ): ...
class SacreBLEUScore(Metric):
    """SacreBLEU variant of BLEU with standardized tokenization options."""

    def __init__(
        self,
        n_gram: int = 4,                 # maximum n-gram order included in the score
        smooth: bool = False,            # apply smoothing
        tokenize: Optional[str] = None,  # tokenizer name; library default when None — TODO confirm valid values
        lowercase: bool = False,         # lowercase inputs before scoring
        **kwargs
    ): ...
class CHRFScore(Metric):
    """chrF score: character and word n-gram F-score for machine-translation evaluation."""

    def __init__(
        self,
        n_char_order: int = 6,     # maximum character n-gram order
        n_word_order: int = 2,     # maximum word n-gram order
        beta: float = 2.0,         # recall weight in the F-beta combination
        lowercase: bool = False,   # lowercase inputs before scoring
        whitespace: bool = False,  # include whitespace in character n-grams
        **kwargs
    ): ...


# Metrics for evaluating automatic summarization quality.
class ROUGEScore(Metric):
    """ROUGE score family (rouge1/rouge2/rougeL/...) for summarization quality."""

    def __init__(
        self,
        rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL"),  # which ROUGE variants to compute
        use_stemmer: bool = False,  # apply stemming before matching — TODO confirm stemmer implementation
        normalizer: Optional[Callable[[str], str]] = None,           # custom text normalizer
        tokenizer: Optional[Callable[[str], Sequence[str]]] = None,  # custom tokenizer
        accumulate: str = "best",   # multi-reference aggregation: "best" or "avg" (see AccumulateType)
        **kwargs
    ): ...


# Character and word-level error rate measurements for ASR and text processing.
class CharErrorRate(Metric):
    """Character-level error rate for ASR/text output (no configuration parameters)."""

    def __init__(
        self,
        **kwargs
    ): ...
class WordErrorRate(Metric):
    """Word-level error rate for ASR/text output (no configuration parameters)."""

    def __init__(
        self,
        **kwargs
    ): ...
class MatchErrorRate(Metric):
    """Match error rate between predicted and reference word sequences (no configuration parameters)."""

    def __init__(
        self,
        **kwargs
    ): ...
class TranslationEditRate(Metric):
    """TER: edit-based metric for translation quality."""

    def __init__(
        self,
        normalize: bool = False,       # apply generalized tokenization/normalization — TODO confirm exact rules
        no_punctuation: bool = False,  # strip punctuation before scoring
        lowercase: bool = True,        # lowercase inputs before scoring (note: default True, unlike other metrics here)
        asian_support: bool = False,   # enable tokenization support for Asian scripts
        **kwargs
    ): ...


# String similarity and distance measures for sequence comparison.
class EditDistance(Metric):
    """Edit distance between predicted and target strings."""

    def __init__(
        self,
        substitution_cost: int = 1,         # cost charged per substitution
        reduction: Optional[str] = "mean",  # batch aggregation; "mean" default — TODO confirm other accepted values
        **kwargs
    ): ...
class ExtendedEditDistance(Metric):
    """Extended edit distance (EED) with tunable per-operation costs."""

    def __init__(
        self,
        language: str = "en",                       # language code used for tokenization
        return_sentence_level_score: bool = False,  # also return per-sentence scores
        alpha: float = 2.0,         # EED hyperparameter — see the EED reference for semantics
        rho: float = 0.3,           # EED hyperparameter — TODO confirm semantics
        deletion: float = 0.2,      # deletion cost
        insertion: float = 1.0,     # insertion cost
        substitution: float = 1.0,  # substitution cost
        **kwargs
    ): ...


# Information-theoretic measures for text quality assessment.
class WordInfoLost(Metric):
    """Word information lost (WIL) between prediction and reference (no configuration parameters)."""

    def __init__(
        self,
        **kwargs
    ): ...
class WordInfoPreserved(Metric):
    """Word information preserved (WIP) between prediction and reference (no configuration parameters)."""

    def __init__(
        self,
        **kwargs
    ): ...
class Perplexity(Metric):
    """Perplexity of language-model predictions over target token ids."""

    def __init__(
        self,
        ignore_index: int = -100,  # target id excluded from the computation (e.g. padding)
        **kwargs
    ): ...


# Specialized metrics for question answering task evaluation.
class SQuAD(Metric):
    """SQuAD exact-match and F1 metrics for extractive question answering."""

    def __init__(
        self,
        **kwargs
    ): ...


# Deep learning-based semantic similarity measures (require optional dependencies).
class BERTScore(Metric):
    """BERTScore: semantic similarity via contextual embeddings (requires `transformers`)."""

    def __init__(
        self,
        model_name_or_path: str = "distilbert-base-uncased",  # HF hub id or local path
        num_layers: Optional[int] = None,     # which layer's embeddings to use; model default when None
        all_layers: bool = False,             # aggregate over all layers instead of a single one
        model_type: Optional[str] = None,     # presumably overrides model architecture detection — TODO confirm
        user_forward_fn: Optional[Callable[[Any, Tensor], Tensor]] = None,  # custom model forward
        user_tokenizer: Optional[Any] = None, # custom tokenizer replacing the HF one
        verbose: bool = False,
        idf: bool = False,                    # weight tokens by inverse document frequency
        device: Optional[Union[str, torch.device]] = None,
        max_length: int = 512,                # tokenizer truncation length
        batch_size: int = 64,
        num_threads: int = 4,
        return_hash: bool = False,            # presumably also return a configuration hash — TODO confirm
        lang: str = "en",
        rescale_with_baseline: bool = False,  # rescale scores against a per-language baseline
        baseline_path: Optional[str] = None,  # path to the baseline file used when rescaling
        use_fast_tokenizer: bool = False,
        **kwargs
    ): ...
class InfoLM(Metric):
    """InfoLM: information-theoretic text metric over masked-LM token distributions."""

    def __init__(
        self,
        model_name_or_path: str = "google/bert_uncased_L-2_H-128_A-2",  # HF hub id or local path
        temperature: float = 0.25,           # temperature applied to the LM distribution
        measure_to_use: str = "fisher_rao",  # "fisher_rao", "kl_divergence" or "js_divergence" (see MeasureType)
        max_length: Optional[int] = None,    # tokenizer truncation length
        device: Optional[Union[str, torch.device]] = None,
        batch_size: int = 64,
        num_threads: int = 4,
        verbose: bool = False,
        return_sentence_level_score: bool = False,  # also return per-sentence scores
        **kwargs
    ): ...


import torch
# --- Example: BLEU score ---
import torch
from torchmetrics.text import BLEUScore

# Initialize BLEU metric with defaults (4-gram, no smoothing)
bleu = BLEUScore()

# Sample predictions and references; each prediction may have several references
preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", "a cat is on the mat"]]

# Compute BLEU score
bleu_score = bleu(preds, target)
print(f"BLEU Score: {bleu_score:.4f}")

# 4-gram BLEU with smoothing enabled
bleu_smooth = BLEUScore(n_gram=4, smooth=True)
bleu_smooth_score = bleu_smooth(preds, target)
print(f"Smoothed BLEU: {bleu_smooth_score:.4f}")

from torchmetrics.text import ROUGEScore
# --- Example: ROUGE score ---
from torchmetrics.text import ROUGEScore

# Initialize ROUGE metric with default keys (rouge1, rouge2, rougeL)
rouge = ROUGEScore()

# Sample summaries and references
preds = ["the quick brown fox jumps over the lazy dog"]
target = ["a quick brown fox jumps over a lazy dog"]

# Compute ROUGE scores; result is a dict keyed by "<variant>_<statistic>"
rouge_scores = rouge(preds, target)
print(f"ROUGE-1: {rouge_scores['rouge1_fmeasure']:.4f}")
print(f"ROUGE-2: {rouge_scores['rouge2_fmeasure']:.4f}")
print(f"ROUGE-L: {rouge_scores['rougeL_fmeasure']:.4f}")

# Custom ROUGE configuration including the rougeLsum variant
rouge_custom = ROUGEScore(rouge_keys=("rouge1", "rouge2", "rougeL", "rougeLsum"))
rouge_custom_scores = rouge_custom(preds, target)

from torchmetrics.text import WordErrorRate, CharErrorRate
# --- Example: word / character error rates ---
from torchmetrics.text import WordErrorRate, CharErrorRate

# Initialize error rate metrics
wer = WordErrorRate()
cer = CharErrorRate()

# ASR outputs vs ground truth
preds = ["this is a test"]
target = ["this is the test"]

# Compute error rates (lower is better)
wer_score = wer(preds, target)
cer_score = cer(preds, target)
print(f"Word Error Rate: {wer_score:.4f}")
print(f"Character Error Rate: {cer_score:.4f}")

from torchmetrics.text import BERTScore
# --- Example: BERTScore (requires the optional `transformers` dependency) ---
from torchmetrics.text import BERTScore

try:
    bertscore = BERTScore(model_name_or_path="distilbert-base-uncased")
    # Sample texts
    preds = ["the cat sat on the mat"]
    target = ["a cat was sitting on the mat"]
    # Compute BERTScore; result is a dict with precision/recall/f1
    bert_scores = bertscore(preds, target)
    print(f"BERTScore F1: {bert_scores['f1']:.4f}")
    print(f"BERTScore Precision: {bert_scores['precision']:.4f}")
    print(f"BERTScore Recall: {bert_scores['recall']:.4f}")
except ImportError:
    # ModuleNotFoundError (a subclass of ImportError) is raised when transformers is absent
    print("BERTScore requires the 'transformers' package")

from torchmetrics.text import EditDistance
# --- Example: edit distance ---
from torchmetrics.text import EditDistance

# Initialize edit distance with defaults
edit_dist = EditDistance()

# Classic example pair
preds = ["kitten"]
target = ["sitting"]

# Compute edit distance
distance = edit_dist(preds, target)
print(f"Edit Distance: {distance:.0f}")

# Explicit reduction over the batch.
# NOTE(review): "mean" is already the default, so this is equivalent to EditDistance().
edit_dist_norm = EditDistance(reduction="mean")
norm_distance = edit_dist_norm(preds, target)
print(f"Mean Edit Distance: {norm_distance:.4f}")

from torchmetrics.text import Perplexity
# --- Example: perplexity ---
import torch
from torchmetrics.text import Perplexity

# Initialize perplexity metric
perplexity = Perplexity()

# Random scores standing in for language-model outputs.
# Shape: (batch_size, sequence_length, vocab_size).
# NOTE(review): torch.randn produces unnormalized values, not log-probabilities —
# acceptable only as a shape demo; verify the expected input scale against the docs.
preds = torch.randn(2, 8, 1000)            # 2 sequences, 8 tokens, vocab size 1000
target = torch.randint(0, 1000, (2, 8))    # target token ids

# Compute perplexity
ppl_score = perplexity(preds, target)
print(f"Perplexity: {ppl_score:.4f}")

from torchmetrics.text import TranslationEditRate
# --- Example: translation edit rate (TER) ---
from torchmetrics.text import TranslationEditRate

# Initialize TER metric with normalization and lowercasing
ter = TranslationEditRate(normalize=True, lowercase=True)

# Translation examples
preds = ["The cat is on the mat"]
target = ["There is a cat on the mat"]

# Compute TER (lower is better)
ter_score = ter(preds, target)
print(f"Translation Edit Rate: {ter_score:.4f}")

from torchmetrics.text import SQuAD
# --- Example: SQuAD evaluation ---
from torchmetrics.text import SQuAD

# Initialize SQuAD metric
squad = SQuAD()

# QA predictions and references use the official SQuAD dict format
preds = [{"prediction_text": "Denver Broncos", "id": "56be4db0acb8001400a502ec"}]
target = [{
    "answers": {"answer_start": [177], "text": ["Denver Broncos"]},
    "id": "56be4db0acb8001400a502ec",
}]

# Compute SQuAD scores (exact match and token-level F1)
squad_scores = squad(preds, target)
print(f"Exact Match: {squad_scores['exact_match']:.4f}")
print(f"F1 Score: {squad_scores['f1']:.4f}")

from torchmetrics.text import BLEUScore, ROUGEScore
# --- Example: multiple reference translations/summaries ---
from torchmetrics.text import BLEUScore, ROUGEScore

preds = ["the cat is on the mat"]
target = [[
    "there is a cat on the mat",
    "a cat is on the mat",
    "the cat sits on the mat",
]]

# BLEU with multiple references
bleu_multi = BLEUScore()
bleu_score = bleu_multi(preds, target)
print(f"Multi-reference BLEU: {bleu_score:.4f}")

# ROUGE with multiple references
rouge_multi = ROUGEScore()
rouge_scores = rouge_multi(preds, target)
print(f"Multi-reference ROUGE-L: {rouge_scores['rougeL_fmeasure']:.4f}")

from typing import Union, Optional, Sequence, List, Dict, Any, Callable, Tuple
# --- Type aliases used throughout the text-metric APIs ---
from typing import Union, Optional, Sequence, List, Dict, Any, Callable, Tuple, Literal

import torch
from torch import Tensor

# A single prediction string or a batch of them.
TextInput = Union[str, List[str]]
# Targets may supply multiple references per prediction.
TextTarget = Union[str, List[str], List[List[str]]]
ROUGEKeys = Union[str, Tuple[str, ...]]
# Fixed string choices must be expressed with Literal, not Union —
# Union["avg", "best"] is invalid typing syntax.
AccumulateType = Literal["avg", "best"]
MeasureType = Literal["fisher_rao", "kl_divergence", "js_divergence"]

# Install with Tessl CLI
Install with the Tessl CLI: `npx tessl i tessl/pypi-torchmetrics`