A lightweight library to help with training neural networks in PyTorch.
—
Comprehensive metric collection system covering classification, regression, NLP, computer vision, clustering, and GAN evaluation. PyTorch Ignite provides 80+ built-in metrics with a consistent API and support for custom metrics.
Foundation classes that give all metrics a consistent interface and behavior.
class Metric:
    """Common base for every metric.

    Concrete metrics inherit from this class and supply ``reset()``,
    ``update()`` and ``compute()``; ``attach()`` registers the metric on an
    engine.
    """

    def reset(self):
        """Return the metric's internal state to its initial values."""

    def update(self, output):
        """Fold one step of engine output into the metric state.

        Parameters:
        - output: engine output, usually a ``(y_pred, y)`` tuple
        """

    def compute(self):
        """Produce the metric value from the accumulated state.

        Returns:
            The computed metric value.
        """

    def attach(self, engine, name, usage=EpochWise()):
        """Register this metric on ``engine`` under ``name``.

        Parameters:
        - engine: the Engine the metric is attached to
        - name: key under which the metric value is reported
        - usage: how the metric is used (EpochWise, etc.)
        """
class EpochMetric(Metric):
    """Metric evaluated once, at the end of each epoch.

    Parameters:
    - compute_fn: callable that produces the metric from the accumulated values
    - output_transform: callable applied to the engine output
    - check_compute_fn: whether the signature of compute_fn is validated
    - device: device used for tensor operations
    """

    def __init__(self, compute_fn, output_transform=None, check_compute_fn=True, device=None):
        pass
class RunningAverage(Metric):
    """
    Running average wrapper for any metric.
    Parameters:
    - src: source metric to average. Optional (defaults to None); when it is
      omitted, the value to average presumably comes from output_transform.
      NOTE(review): upstream ignite requires exactly one of src /
      output_transform to be given — confirm against the installed version.
    - alpha: smoothing factor (default: 0.98)
    - output_transform: function to transform engine output
    """
    # src gained a None default (backward-compatible) so the
    # output_transform-only form, e.g. averaging a trainer's loss value, fits.
    def __init__(self, src=None, alpha=0.98, output_transform=None): ...


# Metrics for classification tasks including binary, multi-class, and multi-label scenarios.
class Accuracy(Metric):
    """Accuracy for classification tasks.

    Parameters:
    - output_transform: callable applied to the engine output
    - is_multilabel: when True, the task is treated as multi-label classification
    - device: device used for tensor operations (CPU unless overridden)
    """

    def __init__(self, output_transform=None, is_multilabel=False, device=torch.device("cpu")):
        pass
class TopKCategoricalAccuracy(Metric):
    """Top-K categorical accuracy.

    Parameters:
    - k: how many of the top predictions are considered (default 5)
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, k=5, output_transform=None, device=None):
        pass
class Precision(Metric):
    """
    Precision score for classification.
    Parameters:
    - average: averaging strategy ('micro', 'macro', 'weighted', None, or the
      default False). False is not one of the string/None modes.
      NOTE(review): upstream ignite treats False as "no averaging" (per-class
      scores) and also accepts True and 'samples' — confirm against the
      installed version.
    - is_multilabel: whether to treat as multi-label
    - output_transform: function to transform engine output
    - device: device for tensor operations
    """
    def __init__(self, average=False, is_multilabel=False, output_transform=None, device=None): ...
class Recall(Metric):
    """
    Recall score for classification.
    Parameters:
    - average: averaging strategy ('micro', 'macro', 'weighted', None, or the
      default False). False is not one of the string/None modes.
      NOTE(review): upstream ignite treats False as "no averaging" (per-class
      scores) and also accepts True and 'samples' — confirm against the
      installed version.
    - is_multilabel: whether to treat as multi-label
    - output_transform: function to transform engine output
    - device: device for tensor operations
    """
    def __init__(self, average=False, is_multilabel=False, output_transform=None, device=None): ...
class Fbeta(Metric):
    """
    F-beta score for classification.
    Parameters:
    - beta: weight of recall in harmonic mean
    - average: whether the score is averaged across classes (default True).
      NOTE(review): upstream ignite's Fbeta takes a boolean here (True for an
      averaged score, False for per-class scores) rather than the
      'micro'/'macro'/'weighted' strings. It is also a factory function that
      returns a MetricsLambda rather than a Metric subclass — confirm.
    - precision: precision metric (optional)
    - recall: recall metric (optional)
    - output_transform: function to transform engine output
    - device: device for tensor operations
    """
    def __init__(self, beta, average=True, precision=None, recall=None, output_transform=None, device=None): ...
class ConfusionMatrix(Metric):
    """
    Confusion matrix for classification.
    Parameters:
    - num_classes: number of classes
    - average: averaging strategy (None, 'samples')
      NOTE(review): upstream ignite also accepts 'recall' and 'precision'
      normalisations here — confirm against the installed version.
    - output_transform: function to transform engine output
    - device: device for tensor operations
    """
    def __init__(self, num_classes, average=None, output_transform=None, device=None): ...
class MultilabelConfusionMatrix(Metric):
    """
    Multi-label confusion matrix.
    Parameters:
    - num_classes: number of classes
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite spells this class ``MultiLabelConfusionMatrix``
    (capital L) and adds a ``normalized`` flag — confirm before importing it.
    """
    def __init__(self, num_classes, output_transform=None, device=None): ...
class CohenKappa(Metric):
    """
    Cohen's kappa coefficient.
    Parameters:
    - num_classes: number of classes
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite's CohenKappa has no ``num_classes`` argument;
    its signature is (output_transform, weights=None, check_compute_fn=False,
    device) — confirm before copying this signature.
    """
    def __init__(self, num_classes, output_transform=None, device=None): ...
class IoU(Metric):
    """
    Intersection over Union for segmentation.
    Parameters:
    - num_classes: number of classes
    - ignore_index: class index to leave out of the result (optional)
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite implements IoU as a function ``IoU(cm,
    ignore_index=None)`` over an existing ConfusionMatrix metric, not as a
    Metric subclass — confirm.
    """
    def __init__(self, num_classes, ignore_index=None, output_transform=None, device=None): ...
class mIoU(Metric):
    """
    Mean Intersection over Union.
    Parameters:
    - num_classes: number of classes
    - ignore_index: class index to leave out of the result (optional)
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite implements mIoU as a function ``mIoU(cm,
    ignore_index=None)`` over an existing ConfusionMatrix metric, not as a
    Metric subclass — confirm.
    """
    def __init__(self, num_classes, ignore_index=None, output_transform=None, device=None): ...
class DiceCoefficient(Metric):
    """
    Dice coefficient for segmentation.
    Parameters:
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite implements this as a function
    ``DiceCoefficient(cm, ignore_index=None)`` over an existing ConfusionMatrix
    metric — confirm.
    """
    def __init__(self, output_transform=None, device=None): ...


# Metrics for regression tasks and continuous value prediction.
class MeanAbsoluteError(Metric):
    """Mean absolute error.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class MeanSquaredError(Metric):
    """Mean squared error.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class RootMeanSquaredError(Metric):
    """Root mean squared error.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class MeanPairwiseDistance(Metric):
    """Mean pairwise distance.

    Parameters:
    - p: distance order (default 2; presumably the p of a p-norm)
    - eps: small constant (default 1e-6), presumably for numerical stability
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, p=2, eps=1e-6, output_transform=None, device=None):
        pass
class CosineSimilarity(Metric):
    """
    Cosine similarity.
    Parameters:
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite's CosineSimilarity also takes a leading
    ``eps`` argument (default 1e-8) guarding against division by zero —
    confirm against the installed version.
    """
    def __init__(self, output_transform=None, device=None): ...
class PearsonCorrelation(Metric):
    """
    Pearson correlation coefficient.
    Parameters:
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite's PearsonCorrelation appears to take a
    leading ``eps`` argument (default 1e-8) — confirm against the installed
    version.
    """
    def __init__(self, output_transform=None, device=None): ...
class SpearmanRankCorrelation(Metric):
    """Spearman rank correlation.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class R2Score(Metric):
    """R-squared (coefficient of determination).

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class MeanAbsolutePercentageError(Metric):
    """Mean absolute percentage error.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class MedianAbsoluteError(Metric):
    """Median absolute error.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class MedianAbsolutePercentageError(Metric):
    """Median absolute percentage error.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class FractionalAbsoluteError(Metric):
    """Fractional absolute error.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class FractionalBias(Metric):
    """
    Fractional bias.
    Parameters:
    - output_transform: function to transform engine output
    - device: device for tensor operations
    """
    def __init__(self, output_transform=None, device=None): ...


# Metrics for ranking tasks and probability-based evaluation.
class RocAuc(Metric):
    """
    ROC AUC score.
    Parameters:
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite exposes this metric as ``ROC_AUC`` (with an
    extra ``check_compute_fn`` flag), not ``RocAuc`` — confirm before
    importing it by this name.
    """
    def __init__(self, output_transform=None, device=None): ...
class PrecisionRecallCurve(Metric):
    """Precision-recall curve.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class AveragePrecision(Metric):
    """Average precision score.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class MeanAveragePrecision(Metric):
    """
    Mean average precision.
    Parameters:
    - output_transform: function to transform engine output
    - device: device for tensor operations
    """
    def __init__(self, output_transform=None, device=None): ...


# Wrapper for computing loss values as metrics.
class Loss(Metric):
    """
    Loss metric wrapper.
    Parameters:
    - loss_fn: PyTorch loss function
    - output_transform: function to transform engine output
    - batch_size: batch size used when averaging (default None here).
      NOTE(review): upstream ignite expects a *callable* here that returns the
      batch size from the target tensor, defaulting to ``len``; this None
      default does not match upstream — confirm, and prefer omitting the
      argument.
    - device: device for tensor operations
    """
    def __init__(self, loss_fn, output_transform=None, batch_size=None, device=None): ...
class GeometricAverage(Metric):
    """Geometric average of metric values.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class Average(Metric):
    """Arithmetic average of values.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class VariableAccumulation(Metric):
    """
    Variable accumulation over batches.
    Parameters:
    - op: accumulation operation, presumably a callable that combines the
      running accumulator with each batch's value — confirm its signature
      upstream.
    - output_transform: function to transform engine output
    - device: device for tensor operations
    """
    def __init__(self, op, output_transform=None, device=None): ...


# Metrics for evaluating image quality and similarity.
class PSNR(Metric):
    """
    Peak signal-to-noise ratio.
    Parameters:
    - data_range: dynamic range of input images (default 1.0 here).
      NOTE(review): upstream ignite makes data_range a required argument with
      no default — confirm, and pass it explicitly.
    - output_transform: function to transform engine output
    - device: device for tensor operations
    """
    def __init__(self, data_range=1.0, output_transform=None, device=None): ...
class SSIM(Metric):
    """
    Structural similarity index.
    Parameters:
    - data_range: dynamic range of input images
    - kernel_size: size of sliding window
    - sigma: standard deviation for Gaussian kernel
    - k1, k2: algorithm parameters
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite requires data_range (no default). It accepts
    kernel_size / sigma either as scalars or as (h, w) pairs (defaults (11, 11)
    and (1.5, 1.5)) and adds a ``gaussian`` flag — confirm.
    """
    def __init__(self, data_range=1.0, kernel_size=11, sigma=1.5, k1=0.01, k2=0.03, output_transform=None, device=None): ...


# Metrics for natural language processing tasks.
class Bleu(Metric):
    """BLEU score for text generation.

    Parameters:
    - ngram: n-gram order used (default 4)
    - smooth: name of the smoothing method (default "no_smooth")
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, ngram=4, smooth="no_smooth", output_transform=None, device=None):
        pass
class Rouge(Metric):
    """
    ROUGE score for text summarization.
    Parameters:
    - multiref: strategy used when several reference texts are given
      (default "best"), presumably selecting how per-reference scores are
      combined. NOTE(review): upstream ignite accepts "best" and "average" and
      defaults to "average" — confirm against the installed version.
    - variants: ROUGE variants to compute
    - output_transform: function to transform engine output
    - device: device for tensor operations
    """
    def __init__(self, multiref="best", variants=None, output_transform=None, device=None): ...


# Metrics for evaluating clustering algorithms.
class MutualInformation(Metric):
    """Mutual information score.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class NormalizedMutualInformation(Metric):
    """Normalized mutual information.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class AdjustedMutualInformation(Metric):
    """Adjusted mutual information.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class AdjustedRandIndex(Metric):
    """Adjusted Rand index.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class RandIndex(Metric):
    """Rand index.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class FowlkesMallowsIndex(Metric):
    """Fowlkes-Mallows index.

    Parameters:
    - output_transform: callable applied to the engine output
    - device: device used for tensor operations
    """

    def __init__(self, output_transform=None, device=None):
        pass
class CalinskeHarabaszIndex(Metric):
    """
    Calinski-Harabasz index.
    Parameters:
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): the class name spells "Calinske" while the metric is
    Calinski-Harabasz. The name is kept unchanged here for compatibility, but
    upstream ignite appears to call this metric ``CalinskiHarabaszScore`` —
    confirm before importing it.
    """
    def __init__(self, output_transform=None, device=None): ...
class DaviesBouldinIndex(Metric):
    """
    Davies-Bouldin index.
    Parameters:
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite appears to call this metric
    ``DaviesBouldinScore`` — confirm before importing it by this name.
    """
    def __init__(self, output_transform=None, device=None): ...
class SilhouetteScore(Metric):
    """
    Silhouette score.
    Parameters:
    - output_transform: function to transform engine output
    - device: device for tensor operations
    """
    def __init__(self, output_transform=None, device=None): ...


# Metrics for evaluating generative adversarial networks.
class FID(Metric):
    """
    Fréchet Inception Distance.
    Parameters:
    - num_features: number of features in feature extractor (default 2048 here)
    - feature_extractor: feature extraction model
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite defaults num_features to None (2048 is
    implied by the built-in Inception extractor). When a custom
    feature_extractor is supplied, num_features should match its output
    size — confirm.
    """
    def __init__(self, num_features=2048, feature_extractor=None, output_transform=None, device=None): ...
class InceptionScore(Metric):
    """
    Inception Score.
    Parameters:
    - output_transform: function to transform engine output
    - device: device for tensor operations
    NOTE(review): upstream ignite's InceptionScore also accepts
    ``num_features`` and ``feature_extractor`` (like FID) — confirm.
    """
    def __init__(self, output_transform=None, device=None): ...


# Metrics for measuring frequency and timing information.
class Frequency(Metric):
    """
    Frequency measurement.
    Parameters:
    - output_transform: function to transform engine output
    - device: device for tensor operations
    """
    def __init__(self, output_transform=None, device=None): ...


from ignite.metrics import Accuracy, Loss
from ignite.engine import create_supervised_evaluator

# `model`, `criterion` (the loss function wrapped by Loss) and `val_loader`
# (the validation data loader) are assumed to be defined beforehand.

# Create metrics; the dict keys are the names read back from state.metrics
metrics = {
    'accuracy': Accuracy(),
    'loss': Loss(criterion),
}
# Create evaluator with metrics
evaluator = create_supervised_evaluator(model, metrics=metrics)
# Run evaluation
evaluator.run(val_loader)
print(f"Accuracy: {evaluator.state.metrics['accuracy']}")
print(f"Loss: {evaluator.state.metrics['loss']}")


from ignite.metrics import Accuracy
def custom_output_transform(output):
    """Map the engine output to the ``(y_pred, y)`` pair Accuracy consumes.

    This example passes the pair through unchanged; adapt the unpacking when
    your engine emits a different structure (e.g. a dict).
    """
    # Extract predictions and targets from engine output
    y_pred, y = output
    return y_pred, y


accuracy = Accuracy(output_transform=custom_output_transform)


from ignite.metrics import RunningAverage, Loss
# `Events` is used by the decorator below and must be imported.
from ignite.engine import Events

# `trainer` is assumed to be an existing training Engine.
# Running average of the training loss. Loss.update() expects (y_pred, y),
# but a trainer's output is typically the batch loss value itself
# (NOTE(review): true for create_supervised_trainer's default output —
# confirm for custom trainers). So no src metric is passed; the output is
# averaged directly via output_transform.
running_avg_loss = RunningAverage(output_transform=lambda x: x, alpha=0.98)
# Attach to trainer
running_avg_loss.attach(trainer, 'loss')


@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_loss(engine):
    """Print the running-average loss every 100 iterations."""
    print(f"Running avg loss: {engine.state.metrics['loss']}")


# Install with Tessl CLI
npx tessl i tessl/pypi-pytorch-ignite