A suite of visual analysis and diagnostic tools for machine learning.
Comprehensive visualizers for evaluating classification model performance, providing insights into prediction accuracy, class distributions, decision boundaries, and threshold optimization. These tools support both binary and multi-class classification problems.
ROC (Receiver Operating Characteristic) curves and AUC (Area Under Curve) analysis for binary and multi-class classification models. Visualizes the trade-off between true positive rate and false positive rate across different classification thresholds.
class ROCAUC(ClassificationScoreVisualizer):
"""
ROC/AUC visualizer for classification models.

Plots Receiver Operating Characteristic curves annotated with the Area
Under the Curve. Binary problems yield a single curve; for multi-class
problems the micro, macro, and per_class flags control which averaged
and per-class curves are drawn.

Typical workflow (see usage example below): fit(X_train, y_train),
score(X_test, y_test), show().

Parameters:
- estimator: scikit-learn classifier
- ax: matplotlib axes object, axes to plot on
- micro: bool, whether to plot micro-averaged ROC for multi-class (default: True)
- macro: bool, whether to plot macro-averaged ROC for multi-class (default: True)
- per_class: bool, whether to plot per-class ROC curves (default: True)
- binary: bool, whether to force binary classification mode (default: False)
- classes: list of class labels for display
- encoder: label encoder for transforming class labels
- is_fitted: str or bool, whether estimator is already fitted ("auto", True, False)
- force_model: bool, whether to force model usage even if not required
"""
def __init__(self, estimator, ax=None, micro=True, macro=True, per_class=True, binary=False, classes=None, encoder=None, is_fitted="auto", force_model=False, **kwargs): ...
def fit(self, X, y, **kwargs): ...  # fit the wrapped estimator on training data
def score(self, X, y, **kwargs): ...  # evaluate on X, y and draw the ROC curve(s)
def show(self, **kwargs): ...  # finalize and render the figure
def roc_auc(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
"""
Functional API for ROC/AUC visualization.

One-call shorthand for the ROCAUC class-based workflow. Additional
keyword arguments are forwarded to the ROCAUC constructor.

Parameters:
- estimator: scikit-learn classifier
- X_train: training features
- y_train: training labels
- X_test: test features (optional)
- y_test: test labels (optional; presumably scoring falls back to the
  training data when the test split is omitted — TODO confirm)
- classes: list of class labels
Returns:
ROCAUC visualizer instance
"""Usage Example:
from yellowbrick.classifier import ROCAUC, roc_auc
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Class-based API
model = RandomForestClassifier()
visualizer = ROCAUC(model, classes=['Benign', 'Malignant'])
visualizer.fit(X_train, y_train)
visualizer.score(X_test, y_test)
visualizer.show()
# Functional API
roc_auc(model, X_train, y_train, X_test, y_test, classes=['Benign', 'Malignant'])

Confusion matrix visualization showing prediction accuracy and error patterns across different classes. Displays counts or percentages with customizable color schemes and normalization options.
class ConfusionMatrix(ClassificationScoreVisualizer):
"""
Confusion matrix visualizer for classification models.

Renders a grid of actual-vs-predicted class counts (or percentages when
percent=True) as a heatmap, highlighting prediction accuracy and error
patterns across classes.

Parameters:
- estimator: scikit-learn classifier
- ax: matplotlib axes object, axes to plot on
- sample_weight: array-like of sample weights
- percent: bool, whether to display percentages instead of counts (default: False)
- classes: list of class labels for display
- encoder: label encoder for transforming class labels
- cmap: str, matplotlib colormap name (default: "YlOrRd")
- fontsize: int, font size for matrix text
- is_fitted: str or bool, whether estimator is already fitted ("auto", True, False)
- force_model: bool, whether to force model usage even if not required
"""
def __init__(self, estimator, ax=None, sample_weight=None, percent=False, classes=None, encoder=None, cmap="YlOrRd", fontsize=None, is_fitted="auto", force_model=False, **kwargs): ...
def fit(self, X, y, **kwargs): ...  # fit the wrapped estimator on training data
def score(self, X, y, **kwargs): ...  # predict on X and draw the confusion matrix against y
def show(self, **kwargs): ...  # finalize and render the figure
def confusion_matrix(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
"""
Functional API for confusion matrix visualization.

One-call shorthand for the ConfusionMatrix class-based workflow.
Additional keyword arguments are forwarded to the ConfusionMatrix
constructor (e.g. percent, cmap).

Parameters:
- estimator: scikit-learn classifier
- X_train: training features
- y_train: training labels
- X_test: test features (optional)
- y_test: test labels (optional; presumably scoring falls back to the
  training data when the test split is omitted — TODO confirm)
- classes: list of class labels
Returns:
ConfusionMatrix visualizer instance
"""Heatmap visualization of classification metrics including precision, recall, F1-score, and support for each class. Provides a comprehensive overview of model performance across all classes.
class ClassificationReport(ClassificationScoreVisualizer):
"""
Classification report heatmap visualizer.

Displays per-class precision, recall, and F1-score (and optionally
support) as a color-coded heatmap for quick performance comparison
across classes.

Parameters:
- estimator: scikit-learn classifier
- classes: list of class labels for display
- sample_weight: array-like of sample weights
- support: bool, whether to draw support column (default: True)
- cmap: matplotlib colormap for heatmap (default: 'RdYlBu_r')
"""
def __init__(self, estimator, classes=None, sample_weight=None, support=True, cmap='RdYlBu_r', **kwargs): ...
def fit(self, X, y, **kwargs): ...  # fit the wrapped estimator on training data
def score(self, X, y, **kwargs): ...  # compute per-class metrics on X, y and draw the heatmap
def show(self, **kwargs): ...  # finalize and render the figure
def classification_report(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
"""
Functional API for classification report visualization.

One-call shorthand for the ClassificationReport class-based workflow.
Additional keyword arguments are forwarded to the ClassificationReport
constructor (e.g. support, cmap).

Parameters:
- estimator: scikit-learn classifier
- X_train: training features
- y_train: training labels
- X_test: test features (optional)
- y_test: test labels (optional; presumably scoring falls back to the
  training data when the test split is omitted — TODO confirm)
- classes: list of class labels
Returns:
ClassificationReport visualizer instance
"""Bar chart showing the difference between actual and predicted class distributions, helping identify systematic prediction biases and class imbalance issues.
class ClassPredictionError(ClassificationScoreVisualizer):
"""
Class prediction error visualizer.

Draws a stacked bar chart comparing actual versus predicted class
distributions, exposing systematic prediction biases and which classes
the model confuses.

Parameters:
- estimator: scikit-learn classifier
- classes: list of class labels for display
"""
def __init__(self, estimator, classes=None, **kwargs): ...
def fit(self, X, y, **kwargs): ...  # fit the wrapped estimator on training data
def score(self, X, y, **kwargs): ...  # predict on X and draw the per-class error bars against y
def show(self, **kwargs): ...  # finalize and render the figure
def class_prediction_error(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
"""
Functional API for class prediction error visualization.

One-call shorthand for the ClassPredictionError class-based workflow.
Additional keyword arguments are forwarded to the ClassPredictionError
constructor.

Parameters:
- estimator: scikit-learn classifier
- X_train: training features
- y_train: training labels
- X_test: test features (optional)
- y_test: test labels (optional; presumably scoring falls back to the
  training data when the test split is omitted — TODO confirm)
- classes: list of class labels
Returns:
ClassPredictionError visualizer instance
"""Precision-Recall curves for evaluating binary and multi-class classifiers, particularly useful for imbalanced datasets where ROC curves may be overly optimistic.
class PrecisionRecallCurve(ClassificationScoreVisualizer):
"""
Precision-Recall curve visualizer.

Plots precision against recall across classification thresholds —
particularly informative for imbalanced datasets where ROC curves can
be overly optimistic. Also available under the alias PRCurve.

Parameters:
- estimator: scikit-learn classifier
- classes: list of class labels for display
- binary: bool, whether to force binary classification mode (default: False)
- micro: bool, whether to plot micro-averaged PR curve (default: True)
- per_class: bool, whether to plot per-class PR curves (default: True)
- iso_f1_curves: bool, whether to draw iso-F1 curves (default: False)
- fill_area: bool, whether to fill area under curve (default: True)
- ap_score: bool, whether to annotate average precision score (default: True)
"""
def __init__(self, estimator, classes=None, binary=False, micro=True, per_class=True, iso_f1_curves=False, fill_area=True, ap_score=True, **kwargs): ...
def fit(self, X, y, **kwargs): ...  # fit the wrapped estimator on training data
def score(self, X, y, **kwargs): ...  # evaluate on X, y and draw the PR curve(s)
def show(self, **kwargs): ...  # finalize and render the figure
# Alias for compatibility
PRCurve = PrecisionRecallCurve
def precision_recall_curve(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
"""
Functional API for precision-recall curve visualization.

One-call shorthand for the PrecisionRecallCurve class-based workflow.
Additional keyword arguments are forwarded to the PrecisionRecallCurve
constructor (e.g. micro, per_class, iso_f1_curves).

Parameters:
- estimator: scikit-learn classifier
- X_train: training features
- y_train: training labels
- X_test: test features (optional)
- y_test: test labels (optional; presumably scoring falls back to the
  training data when the test split is omitted — TODO confirm)
- classes: list of class labels
Returns:
PrecisionRecallCurve visualizer instance
"""Visualization of precision, recall, F1-score, and queue rate across different classification thresholds, helping optimize threshold selection for specific business requirements.
class DiscriminationThreshold(ClassificationScoreVisualizer):
"""
Discrimination threshold visualizer for binary classification.

Plots precision, recall, F1-score, and queue rate across candidate
decision thresholds to help select the operating threshold for a
binary classifier. Not applicable to multi-class problems.

Parameters:
- estimator: scikit-learn binary classifier
- n_trials: int, number of threshold points to evaluate (default: 50)
- random_state: int, random state for reproducibility
"""
def __init__(self, estimator, n_trials=50, random_state=None, **kwargs): ...
def fit(self, X, y, **kwargs): ...  # fit the wrapped estimator on training data
def score(self, X, y, **kwargs): ...  # evaluate metrics across thresholds on X, y and draw the curves
def show(self, **kwargs): ...  # finalize and render the figure
def discrimination_threshold(estimator, X_train, y_train, X_test=None, y_test=None, **kwargs):
"""
Functional API for discrimination threshold visualization.

One-call shorthand for the DiscriminationThreshold class-based
workflow. Additional keyword arguments are forwarded to the
DiscriminationThreshold constructor (e.g. n_trials, random_state).

Parameters:
- estimator: scikit-learn binary classifier
- X_train: training features
- y_train: training labels
- X_test: test features (optional)
- y_test: test labels (optional; presumably scoring falls back to the
  training data when the test split is omitted — TODO confirm)
Returns:
DiscriminationThreshold visualizer instance
"""Visualization of class distribution in the dataset, helping identify class imbalance issues that may affect model performance.
class ClassBalance(Visualizer):
"""
Class balance visualizer for examining target class distributions.

Unlike the score visualizers in this module, this is a target analyzer:
it takes no estimator and fit() receives only the labels y, drawing a
bar chart of class frequencies to expose imbalance.

Parameters:
- labels: list of class labels for display
"""
def __init__(self, labels=None, **kwargs): ...
def fit(self, y, **kwargs): ...  # tally class frequencies in y and draw the bar chart (note: no X)
def show(self, **kwargs): ...  # finalize and render the figure
def class_balance(y, labels=None, **kwargs):
"""
Functional API for class balance visualization.

One-call shorthand for the ClassBalance class-based workflow; operates
on the target labels only (no estimator or features required).

Parameters:
- y: target labels
- labels: list of class labels for display
Returns:
ClassBalance visualizer instance
"""class ClassificationScoreVisualizer(ScoreVisualizer):
"""
Base class for classification scoring visualizers.
Provides common functionality for classification model evaluation.
"""
def __init__(self, estimator, **kwargs): ...
def fit(self, X, y, **kwargs): ...
def score(self, X, y, **kwargs): ...

from yellowbrick.classifier import ROCAUC, ConfusionMatrix, ClassificationReport
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Prepare data and model
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
model = RandomForestClassifier()
# ROC/AUC Analysis
roc_viz = ROCAUC(model)
roc_viz.fit(X_train, y_train)
roc_viz.score(X_test, y_test)
roc_viz.show()
# Confusion Matrix
cm_viz = ConfusionMatrix(model, percent=True)
cm_viz.fit(X_train, y_train)
cm_viz.score(X_test, y_test)
cm_viz.show()
# Classification Report
cr_viz = ClassificationReport(model)
cr_viz.fit(X_train, y_train)
cr_viz.score(X_test, y_test)
cr_viz.show()

from yellowbrick.classifier import ROCAUC, PrecisionRecallCurve
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Load multi-class dataset and create train/test splits
iris = load_iris()
X, y = iris.data, iris.target
class_names = iris.target_names
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
model = RandomForestClassifier()
roc_viz = ROCAUC(model, classes=class_names)
roc_viz.fit(X_train, y_train)
roc_viz.score(X_test, y_test)
roc_viz.show()
# Multi-class Precision-Recall
pr_viz = PrecisionRecallCurve(model, classes=class_names, per_class=True, micro=True)
pr_viz.fit(X_train, y_train)
pr_viz.score(X_test, y_test)
pr_viz.show()

from yellowbrick.classifier import DiscriminationThreshold
from sklearn.linear_model import LogisticRegression
# Binary classification threshold analysis
model = LogisticRegression()
threshold_viz = DiscriminationThreshold(model)
threshold_viz.fit(X_train, y_train)
threshold_viz.score(X_test, y_test)
threshold_viz.show()

Install with Tessl CLI
npx tessl i tessl/pypi-yellowbrick