CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-torchmetrics

PyTorch native metrics library providing 400+ rigorously tested metrics across classification, regression, audio, image, text, and other ML domains

Overview
Eval results
Files

text.mddocs/

Text Metrics

Natural language processing metrics for translation, summarization, and text generation evaluation including BLEU, ROUGE, and semantic similarity measures for comprehensive text quality assessment.

Capabilities

Machine Translation Metrics

Metrics for evaluating machine translation quality and n-gram overlap.

class BLEUScore(Metric):
    def __init__(
        self,
        n_gram: int = 4,
        smooth: bool = False,
        weights: Optional[Sequence[float]] = None,
        **kwargs
    ): ...

class SacreBLEUScore(Metric):
    def __init__(
        self,
        n_gram: int = 4,
        smooth: bool = False,
        tokenize: Optional[str] = None,
        lowercase: bool = False,
        **kwargs
    ): ...

class CHRFScore(Metric):
    def __init__(
        self,
        n_char_order: int = 6,
        n_word_order: int = 2,
        beta: float = 2.0,
        lowercase: bool = False,
        whitespace: bool = False,
        **kwargs
    ): ...

Summarization Metrics

Metrics for evaluating automatic summarization quality.

class ROUGEScore(Metric):
    def __init__(
        self,
        rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL"),
        use_stemmer: bool = False,
        normalizer: Optional[Callable[[str], str]] = None,
        tokenizer: Optional[Callable[[str], Sequence[str]]] = None,
        accumulate: str = "best",
        **kwargs
    ): ...

Error Rate Metrics

Character and word-level error rate measurements for ASR and text processing.

class CharErrorRate(Metric):
    def __init__(
        self,
        **kwargs
    ): ...

class WordErrorRate(Metric):
    def __init__(
        self,
        **kwargs
    ): ...

class MatchErrorRate(Metric):
    def __init__(
        self,
        **kwargs
    ): ...

class TranslationEditRate(Metric):
    def __init__(
        self,
        normalize: bool = False,
        no_punctuation: bool = False,
        lowercase: bool = True,
        asian_support: bool = False,
        **kwargs
    ): ...

Edit Distance Metrics

String similarity and distance measures for sequence comparison.

class EditDistance(Metric):
    def __init__(
        self,
        substitution_cost: int = 1,
        reduction: Optional[str] = "mean",
        **kwargs
    ): ...

class ExtendedEditDistance(Metric):
    def __init__(
        self,
        language: str = "en",
        return_sentence_level_score: bool = False,
        alpha: float = 2.0,
        rho: float = 0.3,
        deletion: float = 0.2,
        insertion: float = 1.0,
        substitution: float = 1.0,
        **kwargs
    ): ...

Information Metrics

Information-theoretic measures for text quality assessment.

class WordInfoLost(Metric):
    def __init__(
        self,
        **kwargs
    ): ...

class WordInfoPreserved(Metric):
    def __init__(
        self,
        **kwargs
    ): ...

class Perplexity(Metric):
    def __init__(
        self,
        ignore_index: int = -100,
        **kwargs
    ): ...

Question Answering Metrics

Specialized metrics for question answering task evaluation.

class SQuAD(Metric):
    def __init__(
        self,
        **kwargs
    ): ...

Semantic Similarity Metrics

Deep learning-based semantic similarity measures (require optional dependencies).

class BERTScore(Metric):
    def __init__(
        self,
        model_name_or_path: str = "distilbert-base-uncased",
        num_layers: Optional[int] = None,
        all_layers: bool = False,
        model_type: Optional[str] = None,
        user_forward_fn: Optional[Callable[[Any, Tensor], Tensor]] = None,
        user_tokenizer: Optional[Any] = None,
        verbose: bool = False,
        idf: bool = False,
        device: Optional[Union[str, torch.device]] = None,
        max_length: int = 512,
        batch_size: int = 64,
        num_threads: int = 4,
        return_hash: bool = False,
        lang: str = "en",
        rescale_with_baseline: bool = False,
        baseline_path: Optional[str] = None,
        use_fast_tokenizer: bool = False,
        **kwargs
    ): ...

class InfoLM(Metric):
    def __init__(
        self,
        model_name_or_path: str = "google/bert_uncased_L-2_H-128_A-2",
        temperature: float = 0.25,
        measure_to_use: str = "fisher_rao",
        max_length: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        batch_size: int = 64,
        num_threads: int = 4,
        verbose: bool = False,
        return_sentence_level_score: bool = False,
        **kwargs
    ): ...

Usage Examples

BLEU Score for Translation

import torch
from torchmetrics.text import BLEUScore

# Initialize BLEU metric
bleu = BLEUScore()

# Sample predictions and references
preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", "a cat is on the mat"]]

# Compute BLEU score
bleu_score = bleu(preds, target)
print(f"BLEU Score: {bleu_score:.4f}")

# 4-gram BLEU with smoothing
bleu_smooth = BLEUScore(n_gram=4, smooth=True)
bleu_smooth_score = bleu_smooth(preds, target)
print(f"Smoothed BLEU: {bleu_smooth_score:.4f}")

ROUGE for Summarization

from torchmetrics.text import ROUGEScore

# Initialize ROUGE metric
rouge = ROUGEScore()

# Sample summaries and references
preds = ["the quick brown fox jumps over the lazy dog"]
target = ["a quick brown fox jumps over a lazy dog"]

# Compute ROUGE scores
rouge_scores = rouge(preds, target)
print(f"ROUGE-1: {rouge_scores['rouge1_fmeasure']:.4f}")
print(f"ROUGE-2: {rouge_scores['rouge2_fmeasure']:.4f}")
print(f"ROUGE-L: {rouge_scores['rougeL_fmeasure']:.4f}")

# Custom ROUGE configuration
rouge_custom = ROUGEScore(rouge_keys=("rouge1", "rouge2", "rougeL", "rougeLsum"))
rouge_custom_scores = rouge_custom(preds, target)

Word Error Rate for ASR

from torchmetrics.text import WordErrorRate, CharErrorRate

# Initialize error rate metrics
wer = WordErrorRate()
cer = CharErrorRate()

# ASR outputs vs ground truth
preds = ["this is a test"]
target = ["this is the test"]

# Compute error rates
wer_score = wer(preds, target)
cer_score = cer(preds, target)

print(f"Word Error Rate: {wer_score:.4f}")
print(f"Character Error Rate: {cer_score:.4f}")

BERTScore for Semantic Similarity

from torchmetrics.text import BERTScore

# Initialize BERTScore (requires transformers)
try:
    bertscore = BERTScore(model_name_or_path="distilbert-base-uncased")
    
    # Sample texts
    preds = ["the cat sat on the mat"]
    target = ["a cat was sitting on the mat"]
    
    # Compute BERTScore
    bert_scores = bertscore(preds, target)
    print(f"BERTScore F1: {bert_scores['f1']:.4f}")
    print(f"BERTScore Precision: {bert_scores['precision']:.4f}")
    print(f"BERTScore Recall: {bert_scores['recall']:.4f}")
    
except ImportError:
    print("BERTScore requires the 'transformers' package")

Edit Distance

from torchmetrics.text import EditDistance

# Initialize edit distance
edit_dist = EditDistance()

# Sample strings
preds = ["kitten"]
target = ["sitting"]

# Compute edit distance
distance = edit_dist(preds, target)
print(f"Edit Distance: {distance:.0f}")

# Mean-reduced edit distance (averaged over the batch; this is the default reduction)
edit_dist_mean = EditDistance(reduction="mean")
mean_distance = edit_dist_mean(preds, target)
print(f"Mean Edit Distance: {mean_distance:.4f}")

Perplexity for Language Models

from torchmetrics.text import Perplexity
import torch

# Initialize perplexity metric
perplexity = Perplexity()

# Language model predictions (log probabilities)
# Shape: (batch_size, sequence_length, vocab_size)
preds = torch.randn(2, 8, 1000)  # 2 sequences, 8 tokens, vocab size 1000
target = torch.randint(0, 1000, (2, 8))  # target token ids

# Compute perplexity
ppl_score = perplexity(preds, target)
print(f"Perplexity: {ppl_score:.4f}")

Translation Edit Rate (TER)

from torchmetrics.text import TranslationEditRate

# Initialize TER metric
ter = TranslationEditRate(normalize=True, lowercase=True)

# Translation examples
preds = ["The cat is on the mat"]
target = ["There is a cat on the mat"]

# Compute TER
ter_score = ter(preds, target)
print(f"Translation Edit Rate: {ter_score:.4f}")

SQuAD Metric for QA

from torchmetrics.text import SQuAD

# Initialize SQuAD metric
squad = SQuAD()

# QA predictions and references
preds = [{"prediction_text": "Denver Broncos", "id": "56be4db0acb8001400a502ec"}]
target = [{"answers": {"answer_start": [177], "text": ["Denver Broncos"]}, 
          "id": "56be4db0acb8001400a502ec"}]

# Compute SQuAD scores
squad_scores = squad(preds, target)
print(f"Exact Match: {squad_scores['exact_match']:.4f}")
print(f"F1 Score: {squad_scores['f1']:.4f}")

Multi-Reference Evaluation

from torchmetrics.text import BLEUScore, ROUGEScore

# Multiple reference translations/summaries
preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", 
          "a cat is on the mat", 
          "the cat sits on the mat"]]

# BLEU with multiple references
bleu_multi = BLEUScore()
bleu_score = bleu_multi(preds, target)
print(f"Multi-reference BLEU: {bleu_score:.4f}")

# ROUGE with multiple references
rouge_multi = ROUGEScore()
rouge_scores = rouge_multi(preds, target)
print(f"Multi-reference ROUGE-L: {rouge_scores['rougeL_fmeasure']:.4f}")

Types

from typing import Union, Optional, Sequence, List, Dict, Any, Callable, Tuple
import torch
from torch import Tensor

TextInput = Union[str, List[str]]
TextTarget = Union[str, List[str], List[List[str]]]  # Multiple references supported

ROUGEKeys = Union[str, Tuple[str, ...]]
AccumulateType = Literal["avg", "best"]
MeasureType = Literal["fisher_rao", "kl_divergence", "js_divergence"]

Install with Tessl CLI

npx tessl i tessl/pypi-torchmetrics

docs

audio.md

classification.md

clustering.md

detection.md

functional.md

image.md

index.md

multimodal.md

nominal.md

regression.md

retrieval.md

segmentation.md

shape.md

text.md

utilities.md

video.md

tile.json