CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-yellowbrick

A suite of visual analysis and diagnostic tools for machine learning.

Overview
Eval results
Files

docs/classification.md

Classification Analysis

Comprehensive visualizers for evaluating classification model performance, providing insights into prediction accuracy, class distributions, decision boundaries, and threshold optimization. These tools support both binary and multi-class classification problems.

Capabilities

ROC/AUC Analysis

ROC (Receiver Operating Characteristic) curves and AUC (Area Under Curve) analysis for binary and multi-class classification models. Visualizes the trade-off between true positive rate and false positive rate across different classification thresholds.

class ROCAUC(ClassificationScoreVisualizer):
    """
    Plots Receiver Operating Characteristic curves together with their
    Area Under Curve scores for binary and multi-class classifiers.

    Parameters:
    - estimator: scikit-learn classifier to evaluate
    - ax: matplotlib axes object to draw on
    - micro: bool, draw the micro-averaged ROC curve for multi-class (default: True)
    - macro: bool, draw the macro-averaged ROC curve for multi-class (default: True)
    - per_class: bool, draw one ROC curve per class (default: True)
    - binary: bool, force binary classification mode (default: False)
    - classes: list of class labels used for display
    - encoder: label encoder applied to class labels before display
    - is_fitted: "auto", True, or False; whether the estimator is pre-fitted
    - force_model: bool, use the model even when it is not strictly required
    """

    def __init__(
        self,
        estimator,
        ax=None,
        micro=True,
        macro=True,
        per_class=True,
        binary=False,
        classes=None,
        encoder=None,
        is_fitted="auto",
        force_model=False,
        **kwargs,
    ): ...

    def fit(self, X, y, **kwargs): ...

    def score(self, X, y, **kwargs): ...

    def show(self, **kwargs): ...

def roc_auc(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
    """
    Quick-method (functional) API for ROC/AUC visualization.

    Parameters:
    - estimator: scikit-learn classifier to evaluate
    - X_train: feature matrix for training
    - y_train: label vector for training
    - X_test: optional feature matrix for evaluation
    - y_test: optional label vector for evaluation
    - classes: optional list of class labels used for display

    Returns:
    ROCAUC visualizer instance
    """

Usage Example:

from yellowbrick.classifier import ROCAUC, roc_auc
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Class-based API: construct the visualizer, fit, score, then render.
# NOTE(review): X_train/y_train/X_test/y_test are assumed to be defined by
# the reader; train_test_split is imported but the split is not shown here.
model = RandomForestClassifier()
visualizer = ROCAUC(model, classes=['Benign', 'Malignant'])
visualizer.fit(X_train, y_train)
visualizer.score(X_test, y_test)
visualizer.show()

# Functional API: single-call shorthand for the class-based workflow above
roc_auc(model, X_train, y_train, X_test, y_test, classes=['Benign', 'Malignant'])

Confusion Matrix

Confusion matrix visualization showing prediction accuracy and error patterns across different classes. Displays counts or percentages with customizable color schemes and normalization options.

class ConfusionMatrix(ClassificationScoreVisualizer):
    """
    Draws a confusion matrix exposing prediction accuracy and error
    patterns across the classes of a fitted classifier.

    Parameters:
    - estimator: scikit-learn classifier to evaluate
    - ax: matplotlib axes object to draw on
    - sample_weight: array-like of per-sample weights
    - percent: bool, show percentages rather than raw counts (default: False)
    - classes: list of class labels used for display
    - encoder: label encoder applied to class labels before display
    - cmap: str, matplotlib colormap name (default: "YlOrRd")
    - fontsize: int, font size of the text drawn in each cell
    - is_fitted: "auto", True, or False; whether the estimator is pre-fitted
    - force_model: bool, use the model even when it is not strictly required
    """

    def __init__(
        self,
        estimator,
        ax=None,
        sample_weight=None,
        percent=False,
        classes=None,
        encoder=None,
        cmap="YlOrRd",
        fontsize=None,
        is_fitted="auto",
        force_model=False,
        **kwargs,
    ): ...

    def fit(self, X, y, **kwargs): ...

    def score(self, X, y, **kwargs): ...

    def show(self, **kwargs): ...

def confusion_matrix(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
    """
    Quick-method (functional) API for confusion matrix visualization.

    Parameters:
    - estimator: scikit-learn classifier to evaluate
    - X_train: feature matrix for training
    - y_train: label vector for training
    - X_test: optional feature matrix for evaluation
    - y_test: optional label vector for evaluation
    - classes: optional list of class labels used for display

    Returns:
    ConfusionMatrix visualizer instance
    """

Classification Report

Heatmap visualization of classification metrics including precision, recall, F1-score, and support for each class. Provides a comprehensive overview of model performance across all classes.

class ClassificationReport(ClassificationScoreVisualizer):
    """
    Heatmap of per-class precision, recall, F1-score, and (optionally)
    support for a fitted classifier.

    Parameters:
    - estimator: scikit-learn classifier to evaluate
    - classes: list of class labels used for display
    - sample_weight: array-like of per-sample weights
    - support: bool, include the support column in the heatmap
    - cmap: matplotlib colormap used for the heatmap cells
    """

    def __init__(
        self,
        estimator,
        classes=None,
        sample_weight=None,
        support=True,
        cmap='RdYlBu_r',
        **kwargs,
    ): ...

    def fit(self, X, y, **kwargs): ...

    def score(self, X, y, **kwargs): ...

    def show(self, **kwargs): ...

def classification_report(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
    """
    Quick-method (functional) API for the classification report heatmap.

    Parameters:
    - estimator: scikit-learn classifier to evaluate
    - X_train: feature matrix for training
    - y_train: label vector for training
    - X_test: optional feature matrix for evaluation
    - y_test: optional label vector for evaluation
    - classes: optional list of class labels used for display

    Returns:
    ClassificationReport visualizer instance
    """

Class Prediction Error

Bar chart showing the difference between actual and predicted class distributions, helping identify systematic prediction biases and class imbalance issues.

class ClassPredictionError(ClassificationScoreVisualizer):
    """
    Bar chart contrasting actual versus predicted class distributions,
    surfacing systematic prediction biases.

    Parameters:
    - estimator: scikit-learn classifier to evaluate
    - classes: list of class labels used for display
    """

    def __init__(self, estimator, classes=None, **kwargs): ...

    def fit(self, X, y, **kwargs): ...

    def score(self, X, y, **kwargs): ...

    def show(self, **kwargs): ...

def class_prediction_error(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
    """
    Quick-method (functional) API for class prediction error visualization.

    Parameters:
    - estimator: scikit-learn classifier to evaluate
    - X_train: feature matrix for training
    - y_train: label vector for training
    - X_test: optional feature matrix for evaluation
    - y_test: optional label vector for evaluation
    - classes: optional list of class labels used for display

    Returns:
    ClassPredictionError visualizer instance
    """

Precision-Recall Curves

Precision-Recall curves for evaluating binary and multi-class classifiers, particularly useful for imbalanced datasets where ROC curves may be overly optimistic.

class PrecisionRecallCurve(ClassificationScoreVisualizer):
    """
    Draws precision-recall curves for binary and multi-class classifiers;
    especially informative on imbalanced datasets.

    Parameters:
    - estimator: scikit-learn classifier to evaluate
    - classes: list of class labels used for display
    - binary: bool, force binary classification mode
    - micro: bool, draw the micro-averaged PR curve
    - per_class: bool, draw one PR curve per class
    - iso_f1_curves: bool, overlay iso-F1 reference curves
    - fill_area: bool, shade the area under each curve
    - ap_score: bool, annotate the average precision score
    """

    def __init__(
        self,
        estimator,
        classes=None,
        binary=False,
        micro=True,
        per_class=True,
        iso_f1_curves=False,
        fill_area=True,
        ap_score=True,
        **kwargs,
    ): ...

    def fit(self, X, y, **kwargs): ...

    def score(self, X, y, **kwargs): ...

    def show(self, **kwargs): ...

# Backwards-compatible alias: PRCurve is the short historical name for
# PrecisionRecallCurve and refers to the exact same class object.
PRCurve = PrecisionRecallCurve

def precision_recall_curve(estimator, X_train, y_train, X_test=None, y_test=None, classes=None, **kwargs):
    """
    Quick-method (functional) API for precision-recall curve visualization.

    Parameters:
    - estimator: scikit-learn classifier to evaluate
    - X_train: feature matrix for training
    - y_train: label vector for training
    - X_test: optional feature matrix for evaluation
    - y_test: optional label vector for evaluation
    - classes: optional list of class labels used for display

    Returns:
    PrecisionRecallCurve visualizer instance
    """

Discrimination Threshold

Visualization of precision, recall, F1-score, and queue rate across different classification thresholds, helping optimize threshold selection for specific business requirements.

class DiscriminationThreshold(ClassificationScoreVisualizer):
    """
    Plots precision, recall, F1, and queue rate as functions of the
    decision threshold of a binary classifier, to aid threshold tuning.

    Parameters:
    - estimator: scikit-learn binary classifier to evaluate
    - n_trials: int, number of threshold points evaluated
    - random_state: int, seed for reproducible trials
    """

    def __init__(self, estimator, n_trials=50, random_state=None, **kwargs): ...

    def fit(self, X, y, **kwargs): ...

    def score(self, X, y, **kwargs): ...

    def show(self, **kwargs): ...

def discrimination_threshold(estimator, X_train, y_train, X_test=None, y_test=None, **kwargs):
    """
    Quick-method (functional) API for discrimination threshold visualization.

    Parameters:
    - estimator: scikit-learn binary classifier to evaluate
    - X_train: feature matrix for training
    - y_train: label vector for training
    - X_test: optional feature matrix for evaluation
    - y_test: optional label vector for evaluation

    Returns:
    DiscriminationThreshold visualizer instance
    """

Class Balance

Visualization of class distribution in the dataset, helping identify class imbalance issues that may affect model performance.

class ClassBalance(Visualizer):
    """
    Bar chart of the target class distribution, used to detect class
    imbalance before modeling.

    Parameters:
    - labels: list of class labels used for display
    """

    def __init__(self, labels=None, **kwargs): ...

    def fit(self, y, **kwargs): ...

    def show(self, **kwargs): ...

def class_balance(y, labels=None, **kwargs):
    """
    Quick-method (functional) API for class balance visualization.

    Parameters:
    - y: vector of target labels to tally
    - labels: optional list of class labels used for display

    Returns:
    ClassBalance visualizer instance
    """

Base Classes

class ClassificationScoreVisualizer(ScoreVisualizer):
    """
    Shared base class for the classification scoring visualizers above;
    centralizes the common fit/score plumbing for classifier evaluation.
    """

    def __init__(self, estimator, **kwargs): ...

    def fit(self, X, y, **kwargs): ...

    def score(self, X, y, **kwargs): ...

Usage Patterns

Basic Classification Evaluation

from yellowbrick.classifier import ROCAUC, ConfusionMatrix, ClassificationReport
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Prepare data and model
# NOTE(review): X and y are assumed to be supplied by the reader (e.g.
# loaded from a dataset) -- this snippet does not define them itself.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
model = RandomForestClassifier()

# ROC/AUC Analysis: fit on the training split, score on the held-out split
roc_viz = ROCAUC(model)
roc_viz.fit(X_train, y_train)
roc_viz.score(X_test, y_test)
roc_viz.show()

# Confusion Matrix (percent=True shows percentages instead of raw counts)
cm_viz = ConfusionMatrix(model, percent=True)
cm_viz.fit(X_train, y_train)
cm_viz.score(X_test, y_test)
cm_viz.show()

# Classification Report: per-class precision/recall/F1 heatmap
cr_viz = ClassificationReport(model)
cr_viz.fit(X_train, y_train)
cr_viz.score(X_test, y_test)
cr_viz.show()

Multi-class Classification Analysis

from yellowbrick.classifier import ROCAUC, PrecisionRecallCurve
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Load multi-class dataset
iris = load_iris()
X, y = iris.data, iris.target
class_names = iris.target_names

# Hold out a test split -- the original snippet used X_train/X_test/y_train/
# y_test without ever creating them, which raises NameError when run.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Multi-class ROC analysis: one curve per class plus micro/macro averages
model = RandomForestClassifier()
roc_viz = ROCAUC(model, classes=class_names)
roc_viz.fit(X_train, y_train)
roc_viz.score(X_test, y_test)
roc_viz.show()

# Multi-class Precision-Recall: per-class curves plus the micro average
pr_viz = PrecisionRecallCurve(model, classes=class_names, per_class=True, micro=True)
pr_viz.fit(X_train, y_train)
pr_viz.score(X_test, y_test)
pr_viz.show()

Threshold Optimization

from yellowbrick.classifier import DiscriminationThreshold
from sklearn.linear_model import LogisticRegression

# Binary classification threshold analysis
# NOTE(review): assumes X_train/y_train/X_test/y_test are already defined
# by the reader -- this snippet does not create a data split itself.
model = LogisticRegression()
threshold_viz = DiscriminationThreshold(model)
threshold_viz.fit(X_train, y_train)
threshold_viz.score(X_test, y_test)
threshold_viz.show()

Install with Tessl CLI

npx tessl i tessl/pypi-yellowbrick

docs

classification.md

clustering.md

data-utilities.md

features.md

index.md

model-selection.md

regression.md

text.md

tile.json