"""
Metric comparison visualization displays.
This module provides sklearn-style Display classes for visualizing
metric comparisons between models or methods.
Examples
--------
>>> from temporalcv.viz import MetricComparisonDisplay
>>>
>>> results = {
... "Model A": {"MAE": 0.15, "RMSE": 0.22},
... "Model B": {"MAE": 0.12, "RMSE": 0.19},
... }
>>> display = MetricComparisonDisplay.from_dict(results)
>>> display.plot()
"""
from __future__ import annotations
import numpy as np
from matplotlib.axes import Axes
from ._base import BaseDisplay
from ._style import (
COLORS,
TUFTE_PALETTE,
apply_tufte_style,
set_tufte_labels,
set_tufte_title,
)
__all__ = ["MetricComparisonDisplay"]
[docs]
class MetricComparisonDisplay(BaseDisplay):
"""
Visualization comparing metrics across models.
Displays metric comparisons as grouped bar charts with Tufte styling.
Parameters
----------
model_names : list of str
Names of models being compared.
metric_names : list of str
Names of metrics being compared.
values : array-like of shape (n_models, n_metrics)
Metric values for each model.
lower_is_better : dict, optional
Map of metric name to bool indicating if lower is better.
Default: True for all metrics.
Attributes
----------
ax_ : matplotlib.axes.Axes
The axes used for plotting.
figure_ : matplotlib.figure.Figure
The figure containing the plot.
See Also
--------
temporalcv.compare.compare_horizons : Compare across horizons.
temporalcv.compare.compare_models : Compare model performance.
Examples
--------
>>> from temporalcv.viz import MetricComparisonDisplay
>>>
>>> # From dictionary
>>> results = {"Model A": {"MAE": 0.15}, "Model B": {"MAE": 0.12}}
>>> display = MetricComparisonDisplay.from_dict(results)
>>> display.plot()
>>>
>>> # From arrays
>>> display = MetricComparisonDisplay.from_arrays(
... model_names=["A", "B"],
... metric_names=["MAE", "RMSE"],
... values=[[0.15, 0.22], [0.12, 0.19]],
... )
>>> display.plot()
"""
[docs]
def __init__(
self,
model_names: list[str],
metric_names: list[str],
values: np.ndarray,
*,
lower_is_better: dict[str, bool] | None = None,
baseline_idx: int | None = None,
):
self.model_names = list(model_names)
self.metric_names = list(metric_names)
self.values = np.asarray(values)
self.lower_is_better = lower_is_better or dict.fromkeys(metric_names, True)
self.baseline_idx = baseline_idx
self.n_models = len(model_names)
self.n_metrics = len(metric_names)
[docs]
@classmethod
def from_dict(
cls,
results: dict[str, dict[str, float]],
*,
lower_is_better: dict[str, bool] | None = None,
baseline: str | None = None,
) -> MetricComparisonDisplay:
"""
Create display from a nested dictionary.
Parameters
----------
results : dict
Nested dict of {model_name: {metric_name: value}}.
lower_is_better : dict, optional
Map of metric name to bool.
baseline : str, optional
Name of baseline model for relative comparison.
Returns
-------
MetricComparisonDisplay
The display object.
Examples
--------
>>> results = {
... "Baseline": {"MAE": 0.20, "RMSE": 0.28},
... "Model A": {"MAE": 0.15, "RMSE": 0.22},
... }
>>> display = MetricComparisonDisplay.from_dict(results, baseline="Baseline")
"""
model_names = list(results.keys())
metric_names = list(next(iter(results.values())).keys())
values = np.array(
[[results[m].get(metric, np.nan) for metric in metric_names] for m in model_names]
)
baseline_idx = None
if baseline is not None and baseline in model_names:
baseline_idx = model_names.index(baseline)
return cls(
model_names,
metric_names,
values,
lower_is_better=lower_is_better,
baseline_idx=baseline_idx,
)
[docs]
@classmethod
def from_arrays(
cls,
model_names: list[str],
metric_names: list[str],
values: np.ndarray,
*,
lower_is_better: dict[str, bool] | None = None,
baseline_idx: int | None = None,
) -> MetricComparisonDisplay:
"""
Create display from arrays.
Parameters
----------
model_names : list of str
Names of models.
metric_names : list of str
Names of metrics.
values : array-like of shape (n_models, n_metrics)
Metric values.
lower_is_better : dict, optional
Map of metric name to bool.
baseline_idx : int, optional
Index of baseline model.
Returns
-------
MetricComparisonDisplay
The display object.
"""
return cls(
model_names,
metric_names,
values,
lower_is_better=lower_is_better,
baseline_idx=baseline_idx,
)
[docs]
def plot(
self,
*,
ax: Axes | None = None,
tufte: bool = True,
orientation: str = "vertical",
show_values: bool = True,
show_best: bool = True,
title: str | None = None,
metric_idx: int | None = None,
) -> MetricComparisonDisplay:
"""
Plot the metric comparison.
Parameters
----------
ax : matplotlib.axes.Axes, optional
Axes to plot on. If None, creates new figure.
tufte : bool
If True, apply Tufte styling (default).
orientation : str
"vertical" (bars go up) or "horizontal" (bars go right).
show_values : bool
If True, show metric values on bars.
show_best : bool
If True, highlight best model for each metric.
title : str, optional
Plot title.
metric_idx : int, optional
If provided, plot only this metric (useful for single-metric comparison).
Returns
-------
self
The display object for method chaining.
"""
# Determine if plotting single metric or multiple
if metric_idx is not None:
return self._plot_single_metric(
ax=ax,
tufte=tufte,
orientation=orientation,
show_values=show_values,
show_best=show_best,
title=title,
metric_idx=metric_idx,
)
elif self.n_metrics == 1:
return self._plot_single_metric(
ax=ax,
tufte=tufte,
orientation=orientation,
show_values=show_values,
show_best=show_best,
title=title,
metric_idx=0,
)
else:
return self._plot_grouped(
ax=ax,
tufte=tufte,
show_values=show_values,
show_best=show_best,
title=title,
)
def _plot_single_metric(
self,
*,
ax: Axes | None,
tufte: bool,
orientation: str,
show_values: bool,
show_best: bool,
title: str | None,
metric_idx: int,
) -> MetricComparisonDisplay:
"""Plot comparison for a single metric."""
ax = self._get_ax_or_create(ax, figsize=(8, max(3, self.n_models * 0.6)))
if tufte:
apply_tufte_style(ax)
values = self.values[:, metric_idx]
metric_name = self.metric_names[metric_idx]
lower_better = self.lower_is_better.get(metric_name, True)
# Determine best model
best_idx = np.nanargmin(values) if lower_better else np.nanargmax(values)
# Colors: best in green, others in muted blue
colors = [
COLORS["pass"] if i == best_idx and show_best else TUFTE_PALETTE["info"]
for i in range(self.n_models)
]
# Baseline highlighting
if self.baseline_idx is not None and self.baseline_idx != best_idx:
colors[self.baseline_idx] = TUFTE_PALETTE["secondary"]
positions = np.arange(self.n_models)
if orientation == "horizontal":
ax.barh(
positions,
values,
color=colors,
alpha=0.85,
edgecolor="none",
height=0.6,
)
# Value labels
if show_values:
for _i, (pos, val) in enumerate(zip(positions, values)):
ha = "left" if val >= 0 else "right"
offset = val * 0.02 if val >= 0 else val * 0.02
ax.text(
val + offset,
float(pos),
f"{val:.3f}",
va="center",
ha=ha,
fontsize=9,
color=TUFTE_PALETTE["text"],
)
ax.set_yticks(positions)
ax.set_yticklabels(self.model_names)
set_tufte_labels(ax, xlabel=metric_name)
else: # vertical
ax.bar(
positions,
values,
color=colors,
alpha=0.85,
edgecolor="none",
width=0.6,
)
# Value labels
if show_values:
for _i, (pos, val) in enumerate(zip(positions, values)):
ax.text(
float(pos),
val + max(values) * 0.02,
f"{val:.3f}",
ha="center",
va="bottom",
fontsize=9,
color=TUFTE_PALETTE["text"],
)
ax.set_xticks(positions)
ax.set_xticklabels(self.model_names, rotation=0)
set_tufte_labels(ax, ylabel=metric_name)
if title is None:
title = f"{metric_name} Comparison"
set_tufte_title(ax, title)
self._finalize_plot(ax)
return self
def _plot_grouped(
self,
*,
ax: Axes | None,
tufte: bool,
show_values: bool, # noqa: ARG002
show_best: bool,
title: str | None,
) -> MetricComparisonDisplay:
"""Plot grouped bar chart for multiple metrics."""
figsize = (max(8, self.n_models * 2), 5)
ax = self._get_ax_or_create(ax, figsize=figsize)
if tufte:
apply_tufte_style(ax)
# Bar positioning
bar_width = 0.8 / self.n_metrics
positions = np.arange(self.n_models)
# Color palette for metrics (cycle through Tufte colors)
metric_colors = [
TUFTE_PALETTE["info"],
TUFTE_PALETTE["warning"],
TUFTE_PALETTE["success"],
TUFTE_PALETTE["accent"],
]
for m_idx, metric in enumerate(self.metric_names):
offset = (m_idx - self.n_metrics / 2 + 0.5) * bar_width
values = self.values[:, m_idx]
color = metric_colors[m_idx % len(metric_colors)]
ax.bar(
positions + offset,
values,
bar_width * 0.9,
label=metric,
color=color,
alpha=0.85,
edgecolor="none",
)
# Highlight best for each metric
if show_best:
lower_better = self.lower_is_better.get(metric, True)
best_idx = np.nanargmin(values) if lower_better else np.nanargmax(values)
# Add subtle marker for best
ax.scatter(
positions[best_idx] + offset,
values[best_idx] + max(self.values.flat) * 0.03,
marker="v",
color=COLORS["pass"],
s=30,
zorder=5,
)
ax.set_xticks(positions)
ax.set_xticklabels(self.model_names)
set_tufte_labels(ax, ylabel="Metric Value")
if title is None:
title = "Model Comparison"
set_tufte_title(ax, title)
# Minimal legend
ax.legend(
loc="upper right",
frameon=False,
fontsize=9,
)
self._finalize_plot(ax)
return self
[docs]
def plot_relative(
self,
*,
ax: Axes | None = None,
tufte: bool = True,
title: str | None = None,
) -> MetricComparisonDisplay:
"""
Plot metrics relative to baseline (percent improvement).
Requires baseline_idx to be set.
Parameters
----------
ax : matplotlib.axes.Axes, optional
Axes to plot on.
tufte : bool
If True, apply Tufte styling.
title : str, optional
Plot title.
Returns
-------
self
The display object.
"""
if self.baseline_idx is None:
raise ValueError("baseline_idx must be set for relative comparison")
ax = self._get_ax_or_create(ax, figsize=(8, max(3, self.n_models * 0.6)))
if tufte:
apply_tufte_style(ax)
# Compute relative improvement
baseline_values = self.values[self.baseline_idx]
relative = np.zeros_like(self.values)
for m_idx, metric in enumerate(self.metric_names):
lower_better = self.lower_is_better.get(metric, True)
if lower_better:
# Improvement = (baseline - model) / baseline * 100
relative[:, m_idx] = (
(baseline_values[m_idx] - self.values[:, m_idx]) / baseline_values[m_idx] * 100
)
else:
# Improvement = (model - baseline) / baseline * 100
relative[:, m_idx] = (
(self.values[:, m_idx] - baseline_values[m_idx]) / baseline_values[m_idx] * 100
)
# For single metric, use simple bar chart
if self.n_metrics == 1:
positions = np.arange(self.n_models)
values = relative[:, 0]
colors = [
COLORS["pass"] if v > 0 else COLORS["halt"] if v < 0 else TUFTE_PALETTE["secondary"]
for v in values
]
ax.barh(
positions,
values,
color=colors,
alpha=0.85,
edgecolor="none",
height=0.6,
)
# Value labels
for pos, val in zip(positions, values):
ha = "left" if val >= 0 else "right"
ax.text(
val + (1 if val >= 0 else -1),
float(pos),
f"{val:+.1f}%",
va="center",
ha=ha,
fontsize=9,
color=TUFTE_PALETTE["text"],
)
ax.axvline(0, color=TUFTE_PALETTE["spine"], linewidth=0.8, zorder=0)
ax.set_yticks(positions)
ax.set_yticklabels(self.model_names)
set_tufte_labels(ax, xlabel=f"Improvement in {self.metric_names[0]} (%)")
if title is None:
baseline_name = self.model_names[self.baseline_idx]
title = f"Improvement vs {baseline_name}"
set_tufte_title(ax, title)
self._finalize_plot(ax)
return self