"""
Gate result visualization displays.
This module provides sklearn-style Display classes for visualizing
validation gate results (HALT/WARN/PASS).
Examples
--------
>>> from temporalcv.gates import gate_signal_verification, run_gates
>>> from temporalcv.viz import GateResultDisplay
>>>
>>> result = gate_signal_verification(model, X, y)
>>> display = GateResultDisplay.from_gate(result)
>>> display.plot()
"""
from __future__ import annotations
from typing import Any
import numpy as np
from matplotlib.axes import Axes
from ._base import BaseDisplay
from ._style import (
COLORS,
TUFTE_PALETTE,
apply_tufte_style,
set_tufte_title,
)
# Public names exported by this module.
__all__ = ["GateResultDisplay", "GateComparisonDisplay"]
class GateResultDisplay(BaseDisplay):
    """
    Render a single gate result as a colored status banner.

    The banner shows the gate status (HALT/WARN/PASS) on a colored bar,
    with the gate name above it and, optionally, the gate message below.

    Parameters
    ----------
    name : str
        Gate name.
    status : str
        Gate status ("HALT", "WARN", or "PASS"). Stored upper-cased.
    message : str
        Gate message.
    metrics : dict, optional
        Additional metrics attached to the result. They are stored on the
        display as ``metrics``; ``plot`` does not draw them.

    Attributes
    ----------
    ax_ : matplotlib.axes.Axes
        The axes used for plotting.
    figure_ : matplotlib.figure.Figure
        The figure containing the plot.

    See Also
    --------
    temporalcv.gates.gate_signal_verification : Signal verification gate.
    temporalcv.gates.gate_suspicious_improvement : Improvement gate.

    Examples
    --------
    >>> from temporalcv.gates import gate_signal_verification
    >>> from temporalcv.viz import GateResultDisplay
    >>>
    >>> result = gate_signal_verification(model, X, y, n_shuffles=100)
    >>> display = GateResultDisplay.from_gate(result)
    >>> display.plot()
    """

    def __init__(
        self,
        name: str,
        status: str,
        message: str,
        *,
        metrics: dict[str, Any] | None = None,
    ):
        self.name = name
        # Normalize case so the color lookup in ``plot`` matches.
        self.status = status.upper()
        self.message = message
        self.metrics = metrics or {}

    @classmethod
    def from_gate(cls, gate_result: Any) -> GateResultDisplay:
        """
        Build a display from a GateResult object.

        Parameters
        ----------
        gate_result : GateResult
            Result from a gate function (e.g., gate_signal_verification).
            Its ``status``, ``gate_name`` and ``message`` attributes are
            read; ``details`` is used as the metrics when present and
            non-empty.

        Returns
        -------
        GateResultDisplay
            The display object.

        Examples
        --------
        >>> result = gate_signal_verification(model, X, y)
        >>> display = GateResultDisplay.from_gate(result)
        """
        # Keep only the text after the last "." so an enum rendering such
        # as "GateStatus.HALT" becomes "HALT"; dot-free text is unchanged.
        status_label = str(gate_result.status).rpartition(".")[2]

        # ``details`` is optional on the result; fall back to no metrics.
        details = getattr(gate_result, "details", None)

        return cls(
            name=gate_result.gate_name,
            status=status_label,
            message=gate_result.message,
            metrics=details if details else {},
        )

    def plot(
        self,
        *,
        ax: Axes | None = None,
        tufte: bool = True,
        show_message: bool = True,
    ) -> GateResultDisplay:
        """
        Draw the gate result banner.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates new figure.
        tufte : bool
            If True, apply Tufte styling (default).
        show_message : bool
            If True, show the gate message below the bar. Messages of 60
            characters or more are cut to 57 characters plus "...".

        Returns
        -------
        self
            The display object for method chaining.
        """
        axes = self._get_ax_or_create(ax, figsize=(6, 2))
        if tufte:
            apply_tufte_style(axes)

        # Bar color per known status; anything else gets the neutral color.
        palette = {label: COLORS[label.lower()] for label in ("HALT", "WARN", "PASS")}
        fallback = TUFTE_PALETTE["secondary"]
        bar_color = palette.get(self.status, fallback)

        # Full-width status bar centered on y = 0.
        axes.barh(0, 1, height=0.6, color=bar_color, alpha=0.85, edgecolor="none")

        # Status text, centered on the bar in white.
        axes.text(
            0.5,
            0,
            self.status,
            color="white",
            fontsize=14,
            fontweight="bold",
            ha="center",
            va="center",
        )

        # Gate name, sitting above the bar.
        axes.text(
            0.5,
            0.5,
            self.name,
            color=TUFTE_PALETTE["text"],
            fontsize=10,
            ha="center",
            va="bottom",
        )

        # Gate message, hanging below the bar.
        if show_message and self.message:
            caption = self.message
            if len(caption) >= 60:
                caption = caption[:57] + "..."
            axes.text(
                0.5,
                -0.5,
                caption,
                color=TUFTE_PALETTE["text_secondary"],
                fontsize=8,
                style="italic",
                ha="center",
                va="top",
            )

        # Fixed data window with every axis element hidden.
        axes.set_xlim(0, 1)
        axes.set_ylim(-1, 1)
        axes.axis("off")

        self._finalize_plot(axes)
        return self
class GateComparisonDisplay(BaseDisplay):
    """
    Visualization comparing multiple gate results.

    Draws one colored bar per gate, labelled with its status
    (HALT/WARN/PASS) and name, for a comprehensive view.

    Parameters
    ----------
    names : list of str
        Gate names, one per gate.
    statuses : list of str
        Gate statuses ("HALT", "WARN", or "PASS"), one per gate. Stored
        upper-cased.
    messages : list of str, optional
        Gate messages, one per gate. Defaults to an empty string per gate.

    Attributes
    ----------
    ax_ : matplotlib.axes.Axes
        The axes used for plotting.
    figure_ : matplotlib.figure.Figure
        The figure containing the plot.

    See Also
    --------
    temporalcv.gates.run_gates : Run multiple gates.
    GateResultDisplay : Single gate visualization.

    Examples
    --------
    >>> from temporalcv.gates import (
    ...     gate_signal_verification,
    ...     gate_suspicious_improvement,
    ...     run_gates,
    ... )
    >>> from temporalcv.viz import GateComparisonDisplay
    >>>
    >>> gates = [
    ...     gate_signal_verification(model, X, y),
    ...     gate_suspicious_improvement(model_mae, baseline_mae),
    ... ]
    >>> report = run_gates(gates)
    >>> display = GateComparisonDisplay.from_report(report)
    >>> display.plot()
    """

    # Messages of this length or longer are cut and end in "...", the same
    # limit GateResultDisplay.plot applies to its message.
    _MAX_MESSAGE_CHARS = 60

    def __init__(
        self,
        names: list[str],
        statuses: list[str],
        messages: list[str] | None = None,
    ):
        self.names = names
        # Normalize case so the color lookup in ``plot`` matches.
        self.statuses = [s.upper() for s in statuses]
        self.messages = messages or [""] * len(names)
        self.n_gates = len(names)

    @classmethod
    def from_gates(cls, gate_results: list[Any]) -> GateComparisonDisplay:
        """
        Create display from a list of GateResult objects.

        Parameters
        ----------
        gate_results : list of GateResult
            Results from gate functions. The ``gate_name``, ``status`` and
            ``message`` attributes of each result are read.

        Returns
        -------
        GateComparisonDisplay
            The display object.
        """
        names = []
        statuses = []
        messages = []
        for result in gate_results:
            names.append(result.gate_name)
            # Keep only the text after the last "." so an enum rendering
            # such as "GateStatus.HALT" becomes "HALT".
            statuses.append(str(result.status).rpartition(".")[2])
            messages.append(result.message)
        return cls(names, statuses, messages)

    @classmethod
    def from_report(cls, report: Any) -> GateComparisonDisplay:
        """
        Create display from a GateReport object.

        Parameters
        ----------
        report : GateReport
            Report from run_gates(). Its ``results`` attribute is passed
            to ``from_gates``.

        Returns
        -------
        GateComparisonDisplay
            The display object.
        """
        return cls.from_gates(report.results)

    def plot(
        self,
        *,
        ax: Axes | None = None,
        tufte: bool = True,
        orientation: str = "horizontal",
        show_messages: bool = False,
        title: str | None = None,
    ) -> GateComparisonDisplay:
        """
        Plot the gate comparison.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates new figure.
        tufte : bool
            If True, apply Tufte styling (default).
        orientation : str
            "horizontal" places the gates side by side; any other value
            (conventionally "vertical") stacks them top to bottom.
        show_messages : bool
            If True, draw each non-empty gate message beneath its gate and
            reserve extra room below the bars for it.
        title : str, optional
            Plot title.

        Returns
        -------
        self
            The display object for method chaining.
        """
        horizontal = orientation == "horizontal"

        # Figure size grows with the number of gates along the layout axis.
        if horizontal:
            figsize = (max(6, self.n_gates * 1.5), 2.5)
        else:
            figsize = (4, max(3, self.n_gates * 0.8))

        ax = self._get_ax_or_create(ax, figsize=figsize)
        if tufte:
            apply_tufte_style(ax)

        # Extra room under the bars is only needed when messages are drawn.
        lower = -0.8 if show_messages else -0.5

        if horizontal:
            self._draw_horizontal(ax, show_messages)
            ax.set_xlim(-0.5, self.n_gates - 0.5)
            ax.set_ylim(lower, 1.1)
        else:
            self._draw_vertical(ax, show_messages)
            ax.set_xlim(-0.5, 1.1)
            ax.set_ylim(lower, self.n_gates - 0.5)

        ax.set_xticks([])
        ax.set_yticks([])

        if title:
            set_tufte_title(ax, title)

        # Remove spines for this visualization
        for spine in ax.spines.values():
            spine.set_visible(False)

        self._finalize_plot(ax)
        return self

    @staticmethod
    def _status_colors() -> dict[str, Any]:
        """Return the bar color for each known gate status."""
        return {
            "HALT": COLORS["halt"],
            "WARN": COLORS["warn"],
            "PASS": COLORS["pass"],
        }

    def _draw_horizontal(self, ax: Axes, show_messages: bool) -> None:
        """Draw the gates as side-by-side vertical bars, one per x position."""
        status_colors = self._status_colors()
        fallback = TUFTE_PALETTE["secondary"]

        for i, (name, status) in enumerate(zip(self.names, self.statuses)):
            ax.bar(
                i,
                1,
                width=0.7,
                color=status_colors.get(status, fallback),
                alpha=0.85,
                edgecolor="none",
            )
            # Status label (on bar)
            ax.text(
                i,
                0.5,
                status,
                ha="center",
                va="center",
                fontsize=11,
                fontweight="bold",
                color="white",
            )
            # Gate name (below)
            ax.text(
                i,
                -0.1,
                name,
                ha="center",
                va="top",
                fontsize=9,
                color=TUFTE_PALETTE["text"],
                rotation=0,
            )
            if show_messages:
                # Message (below the gate name)
                self._draw_message(ax, i, -0.4, i)

    def _draw_vertical(self, ax: Axes, show_messages: bool) -> None:
        """Draw the gates as stacked horizontal bars, first gate on top."""
        status_colors = self._status_colors()
        fallback = TUFTE_PALETTE["secondary"]

        for i, (name, status) in enumerate(zip(self.names, self.statuses)):
            # Reverse the order so the first gate ends up at the top.
            y_pos = self.n_gates - 1 - i
            ax.barh(
                y_pos,
                1,
                height=0.6,
                color=status_colors.get(status, fallback),
                alpha=0.85,
                edgecolor="none",
            )
            # Status label (on bar)
            ax.text(
                0.5,
                y_pos,
                status,
                ha="center",
                va="center",
                fontsize=11,
                fontweight="bold",
                color="white",
            )
            # Gate name (left of bar)
            ax.text(
                -0.05,
                y_pos,
                name,
                ha="right",
                va="center",
                fontsize=9,
                color=TUFTE_PALETTE["text"],
            )
            if show_messages:
                # Message (in the gap just under the bar, whose lower edge
                # is at y_pos - 0.3)
                self._draw_message(ax, 0.5, y_pos - 0.35, i)

    def _draw_message(self, ax: Axes, x: float, y: float, index: int) -> None:
        """Draw the message of gate ``index`` hanging below the point (x, y)."""
        # ``messages`` may be shorter than ``names`` when supplied by the
        # caller; a missing or empty entry simply draws nothing.
        message = self.messages[index] if index < len(self.messages) else ""
        if not message:
            return
        if len(message) >= self._MAX_MESSAGE_CHARS:
            message = message[: self._MAX_MESSAGE_CHARS - 3] + "..."
        ax.text(
            x,
            y,
            message,
            ha="center",
            va="top",
            fontsize=8,
            color=TUFTE_PALETTE["text_secondary"],
            style="italic",
        )