Statistical data visualization library for Python built on matplotlib
—
Create heatmaps and clustered heatmaps for visualizing matrix data, correlation matrices, pivot tables, and hierarchical clustering results. These functions excel at revealing patterns in high-dimensional data and relationships between variables.
Plot rectangular data as a color-encoded matrix with customizable annotations and styling.
def heatmap(
    data,
    *,
    vmin=None,
    vmax=None,
    cmap=None,
    center=None,
    robust=False,
    annot=None,
    fmt=".2g",
    annot_kws=None,
    linewidths=0,
    linecolor="white",
    cbar=True,
    cbar_kws=None,
    cbar_ax=None,
    square=False,
    xticklabels="auto",
    yticklabels="auto",
    mask=None,
    ax=None,
    **kwargs
):
    """
    Plot rectangular data as a color-encoded matrix.

    Every argument after ``data`` is keyword-only (bare ``*`` in the signature).

    Parameters:
    - data: 2D array-like, rectangular dataset for heatmap
    - vmin, vmax: float, colormap range anchors
    - cmap: str or colormap, colormap for mapping data values to colors
    - center: float, value at which to center colormap
    - robust: bool, use robust quantiles for colormap range (default False)
    - annot: bool or array-like, annotate cells with data values
    - fmt: str, format string for annotations (default ".2g")
    - annot_kws: dict, keyword arguments for annotation text
    - linewidths: float, width of lines dividing cells (default 0)
    - linecolor: str, color of lines dividing cells (default "white")
    - cbar: bool, draw colorbar (default True)
    - cbar_kws: dict, keyword arguments for colorbar
    - cbar_ax: matplotlib Axes, axes to draw the colorbar in (default None)
    - square: bool, make cells square-shaped (default False)
    - xticklabels, yticklabels: bool, list, or "auto", axis labels
    - mask: bool array, cells to mask (not plot)
    - ax: matplotlib Axes, axes to draw the heatmap in (default None)
    - **kwargs: additional keyword arguments
      (NOTE(review): seaborn forwards these to matplotlib's pcolormesh --
      confirm against the seaborn API reference)
    Returns:
    matplotlib Axes object
    """
Plot hierarchically-clustered heatmaps with dendrograms.
def clustermap(
    data,
    *,
    pivot_kws=None,
    method="average",
    metric="euclidean",
    z_score=None,
    standard_scale=None,
    figsize=(10, 10),
    cbar_kws=None,
    row_cluster=True,
    col_cluster=True,
    row_linkage=None,
    col_linkage=None,
    row_colors=None,
    col_colors=None,
    mask=None,
    dendrogram_ratio=0.2,
    colors_ratio=0.03,
    cbar_pos=(0.02, 0.8, 0.05, 0.18),
    tree_kws=None,
    **kwargs
):
    """
    Plot a matrix dataset as a hierarchically-clustered heatmap.

    Every argument after ``data`` is keyword-only (bare ``*`` in the signature).

    Parameters:
    - data: 2D array-like, rectangular dataset for clustering
    - pivot_kws: dict, keyword arguments used to pivot the input into a
      rectangular matrix (NOTE(review): applies to long-form input per the
      seaborn docs -- confirm)
    - method: str, linkage method for hierarchical clustering (default "average")
    - metric: str or function, distance metric for clustering (default "euclidean")
    - z_score: int or None, standardize data along axis (0=rows, 1=columns)
    - standard_scale: int or None, normalize data along axis (0=rows, 1=columns)
    - figsize: tuple, figure size in inches (default (10, 10))
    - cbar_kws: dict, keyword arguments for colorbar
    - row_cluster, col_cluster: bool, cluster rows/columns (default True)
    - row_linkage, col_linkage: array-like, precomputed linkage matrices
    - row_colors, col_colors: list or array, colors for row/column labels
    - mask: bool array, cells to mask (not plot)
      (NOTE(review): per the seaborn docs the mask affects display only,
      not the clustering computation -- confirm)
    - dendrogram_ratio: float, proportion of figure for dendrograms (default 0.2)
    - colors_ratio: float, proportion of figure for color annotations (default 0.03)
    - cbar_pos: tuple, colorbar position (left, bottom, width, height)
    - tree_kws: dict, keyword arguments for dendrogram
    - **kwargs: additional arguments passed to heatmap()
    Returns:
    ClusterGrid object with components:
    - .ax_heatmap: heatmap axes
    - .ax_row_dendrogram: row dendrogram axes
    - .ax_col_dendrogram: column dendrogram axes
    - .ax_cbar: colorbar axes
    """
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Load the long-form flights data and pivot it into a month-by-year
# matrix of passenger counts (this is a pivot table, not a correlation matrix)
flights = sns.load_dataset("flights")
flights_pivot = flights.pivot(index="month", columns="year", values="passengers")

# Basic heatmap
sns.heatmap(flights_pivot)
plt.show()

# Correlation matrix with annotations
# (only the numeric columns of `tips` are kept before calling .corr())
tips = sns.load_dataset("tips")
correlation_matrix = tips.select_dtypes(include=[np.number]).corr()
sns.heatmap(
    correlation_matrix,
    annot=True,
    cmap="coolwarm",
    center=0,
    square=True,
    fmt=".2f"
)
plt.show()

# Mask upper triangle of correlation matrix
# (np.triu keeps the main diagonal, so the diagonal is masked as well)
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))
sns.heatmap(
    correlation_matrix,
    mask=mask,
    annot=True,
    cmap="coolwarm",
    center=0,
    square=True,
    linewidths=0.5
)
plt.show()

# Hierarchically clustered heatmap
sns.clustermap(
    flights_pivot,
    cmap="viridis",
    standard_scale=1,  # Normalize columns
    figsize=(10, 8)
)
plt.show()

# Add row and column color annotations
# (only row colors are supplied here: winter months red, all others blue)
row_colors = ["red" if month in ["Dec", "Jan", "Feb"] else "blue"
              for month in flights_pivot.index]
sns.clustermap(
    flights_pivot,
    row_colors=row_colors,
    cmap="Blues",
    z_score=1,  # Z-score normalize columns
    figsize=(12, 8),
    cbar_kws={"label": "Normalized Passengers"}
)
plt.show()

# Custom colormap with centered scaling
# (random data is unseeded, so the figure differs from run to run)
diverging_data = np.random.randn(10, 12)
sns.heatmap(
    diverging_data,
    annot=True,
    fmt=".1f",
    cmap="RdBu_r",
    center=0,
    robust=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.8}
)
plt.show()

# Cluster large dataset with custom parameters
# (random data is unseeded, so the clustering differs from run to run)
large_data = np.random.randn(50, 50)
sns.clustermap(
    large_data,
    method="ward",  # Ward linkage
    metric="euclidean",  # Euclidean distance
    cmap="coolwarm",
    center=0,
    figsize=(15, 15),
    dendrogram_ratio=0.15,
    cbar_pos=(0.02, 0.85, 0.03, 0.12)
)
plt.show()

# Clustering methods
# typing names used by the aliases below; imported here so the aliases can
# actually be evaluated.
from typing import Callable, Literal, Union

# Linkage method names accepted by clustermap(method=...).
ClusteringMethod = Literal[
    "single", "complete", "average", "weighted",
    "centroid", "median", "ward",
]
# Distance metrics
# Either the name of a metric or a user-supplied distance function.
# `callable` is a builtin *function*, not a type, so the previous
# `str | callable` raised TypeError when evaluated; typing.Callable is the
# correct way to spell "any callable" in an annotation.
DistanceMetric = Union[str, Callable[..., float]]  # scipy.spatial.distance metrics
# ClusterGrid return object
class ClusterGrid:
    # Documentation stub: lists the attributes exposed by the object that
    # clustermap() returns. It is not an implementation.
    ax_heatmap: matplotlib.axes.Axes  # Main heatmap
    ax_row_dendrogram: matplotlib.axes.Axes  # Row dendrogram
    ax_col_dendrogram: matplotlib.axes.Axes  # Column dendrogram
    ax_cbar: matplotlib.axes.Axes  # Colorbar
    # NOTE(review): in seaborn these two hold a dendrogram plotter object
    # (exposing e.g. .linkage and .reordered_ind) rather than a plain dict,
    # and appear to stay None when that axis is not clustered -- confirm.
    dendrogram_row: object  # Row clustering info
    dendrogram_col: object  # Column clustering info
    # NOTE(review): seaborn appears to store the input data here in its
    # original order; the clustered order is obtained through
    # dendrogram_row.reordered_ind / dendrogram_col.reordered_ind -- confirm.
    data: pandas.DataFrame  # Data passed to clustermap()
# Install with Tessl CLI
npx tessl i tessl/pypi-seaborn