Source code for temporalcv.viz.cv

"""
Cross-validation fold visualization displays.

This module provides sklearn-style Display classes for visualizing
cross-validation fold structures, including gap enforcement for time series.

Examples
--------
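>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> from temporalcv import WalkForwardCV
>>> from temporalcv.viz import CVFoldsDisplay
>>>
>>> X = np.random.randn(200, 5)
>>> y = np.random.randn(200)
>>> cv = WalkForwardCV(n_splits=5, test_size=20, extra_gap=1)
>>> display = CVFoldsDisplay.from_cv(cv, X, y)
>>> display.plot()
>>> plt.show()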
"""

from __future__ import annotations

from typing import Any

import numpy as np
from matplotlib.axes import Axes

from ._base import BaseDisplay
from ._style import (
    COLORS,
    TUFTE_PALETTE,
    apply_tufte_style,
    direct_label,
    set_tufte_labels,
    set_tufte_title,
)

__all__ = ["CVFoldsDisplay"]


class CVFoldsDisplay(BaseDisplay):
    """
    Visualization of cross-validation fold structure.

    Displays train/test splits as horizontal bars, with optional gap
    visualization for time series cross-validation. Index sets that are not
    contiguous (e.g. K-Fold training sets that surround the test block) are
    drawn as one bar per contiguous run of indices.

    Parameters
    ----------
    train_indices : list of array-like
        Training indices for each fold.
    test_indices : list of array-like
        Test indices for each fold.
    gap_indices : list of array-like, optional
        Gap indices for each fold (for walk-forward CV with gap).
        ``None`` or an empty list means "no gap information".
    n_samples : int, optional
        Total number of samples. Inferred from indices if not provided.

    Attributes
    ----------
    ax_ : matplotlib.axes.Axes
        The axes used for plotting.
    figure_ : matplotlib.figure.Figure
        The figure containing the plot.

    Raises
    ------
    ValueError
        If the per-fold lists do not all have the same length.

    See Also
    --------
    temporalcv.WalkForwardCV : Walk-forward cross-validator with gap.
    temporalcv.cv_financial.PurgedKFold : Purged K-Fold for finance.

    Examples
    --------
    >>> from temporalcv import WalkForwardCV
    >>> from temporalcv.viz import CVFoldsDisplay
    >>> import numpy as np
    >>>
    >>> X = np.random.randn(200, 5)
    >>> y = np.random.randn(200)
    >>> cv = WalkForwardCV(n_splits=5, test_size=20, extra_gap=1)
    >>>
    >>> # From cross-validator
    >>> display = CVFoldsDisplay.from_cv(cv, X, y)
    >>> display.plot()
    >>>
    >>> # Or from pre-computed splits
    >>> splits = list(cv.split(X, y))
    >>> display = CVFoldsDisplay.from_splits(splits)
    >>> display.plot()
    """

    def __init__(
        self,
        train_indices: list[np.ndarray],
        test_indices: list[np.ndarray],
        *,
        gap_indices: list[np.ndarray] | None = None,
        n_samples: int | None = None,
    ):
        self.train_indices = [np.asarray(t) for t in train_indices]
        self.test_indices = [np.asarray(t) for t in test_indices]
        self.gap_indices = (
            [np.asarray(g) for g in gap_indices] if gap_indices else None
        )

        # plot() indexes the three lists in lockstep, so a length mismatch
        # would otherwise surface later as an IndexError or dropped folds.
        if len(self.train_indices) != len(self.test_indices):
            raise ValueError(
                "train_indices and test_indices must have the same number of "
                f"folds, got {len(self.train_indices)} and "
                f"{len(self.test_indices)}."
            )
        if self.gap_indices is not None and len(self.gap_indices) != len(
            self.train_indices
        ):
            raise ValueError(
                "gap_indices must have one entry per fold, got "
                f"{len(self.gap_indices)} for {len(self.train_indices)} folds."
            )

        # Infer n_samples
        if n_samples is not None:
            self.n_samples = n_samples
        else:
            # Skip empty arrays: np.max raises on them, and np.concatenate
            # raises when there are no folds at all.
            populated = [
                idx
                for idx in (*self.train_indices, *self.test_indices)
                if idx.size > 0
            ]
            self.n_samples = (
                int(max(np.max(idx) for idx in populated)) + 1 if populated else 0
            )

        self.n_splits = len(self.train_indices)

    @staticmethod
    def _infer_gap(train: np.ndarray, test: np.ndarray) -> np.ndarray:
        """Return the indices strictly between the end of train and the start of test."""
        train = np.asarray(train)
        test = np.asarray(test)
        if train.size == 0 or test.size == 0:
            return np.array([], dtype=int)
        # max/min rather than [-1]/[0] so unsorted index arrays do not
        # produce spurious gaps.
        gap_start = int(np.max(train)) + 1
        gap_end = int(np.min(test))
        if gap_end > gap_start:
            return np.arange(gap_start, gap_end)
        return np.array([], dtype=int)

    @staticmethod
    def _contiguous_runs(indices: np.ndarray) -> list[tuple[int, int]]:
        """Split indices into ``(start, length)`` pairs of consecutive runs."""
        indices = np.asarray(indices)
        if indices.size == 0:
            return []
        ordered = np.unique(indices)  # sorted and de-duplicated
        breaks = np.flatnonzero(np.diff(ordered) != 1) + 1
        return [(int(run[0]), len(run)) for run in np.split(ordered, breaks)]

    def _draw_index_bars(
        self,
        ax: Axes,
        y_pos: float,
        indices: np.ndarray,
        *,
        height: float,
        color: Any,
        alpha: float,
        label: str | None,
    ) -> bool:
        """Draw one bar per contiguous run of ``indices``; return True if any was drawn."""
        runs = self._contiguous_runs(indices)
        if not runs:
            return False
        starts, widths = zip(*runs)
        ax.barh(
            [y_pos] * len(runs),
            widths,
            left=starts,
            height=height,
            color=color,
            alpha=alpha,
            edgecolor="none",
            label=label,
        )
        return True

    @classmethod
    def from_cv(
        cls,
        cv: Any,
        X: np.ndarray,
        y: np.ndarray | None = None,
        *,
        groups: np.ndarray | None = None,
    ) -> CVFoldsDisplay:
        """
        Create display from a cross-validator object.

        Parameters
        ----------
        cv : cross-validator
            A scikit-learn compatible cross-validator with split() method.
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples,), optional
            Target values.
        groups : array-like of shape (n_samples,), optional
            Group labels for GroupKFold-like splitters.

        Returns
        -------
        CVFoldsDisplay
            The display object.

        Examples
        --------
        >>> from temporalcv import WalkForwardCV
        >>> cv = WalkForwardCV(n_splits=5, test_size=20)
        >>> display = CVFoldsDisplay.from_cv(cv, X, y)
        """
        trains = []
        tests = []
        gaps = []

        for train, test in cv.split(X, y, groups):
            trains.append(train)
            tests.append(test)
            # Detect gap (indices between train end and test start)
            gaps.append(cls._infer_gap(train, test))

        # Only pass gap information on if at least one fold has a gap
        has_gaps = any(len(g) > 0 for g in gaps)

        return cls(
            trains,
            tests,
            gap_indices=gaps if has_gaps else None,
            n_samples=len(X),
        )

    @classmethod
    def from_splits(
        cls,
        splits: list[tuple[np.ndarray, np.ndarray]],
        *,
        n_samples: int | None = None,
    ) -> CVFoldsDisplay:
        """
        Create display from pre-computed splits.

        Parameters
        ----------
        splits : list of (train_indices, test_indices) tuples
            Pre-computed splits from cv.split(). Any iterable of pairs is
            accepted, including the generator returned by ``cv.split``.
        n_samples : int, optional
            Total number of samples.

        Returns
        -------
        CVFoldsDisplay
            The display object.

        Examples
        --------
        >>> splits = list(cv.split(X, y))
        >>> display = CVFoldsDisplay.from_splits(splits, n_samples=len(X))
        """
        # Materialise once: a generator would be exhausted after one pass.
        splits = list(splits)

        trains = [train for train, _ in splits]
        tests = [test for _, test in splits]
        gaps = [cls._infer_gap(train, test) for train, test in splits]

        has_gaps = any(len(g) > 0 for g in gaps)

        return cls(
            trains,
            tests,
            gap_indices=gaps if has_gaps else None,
            n_samples=n_samples,
        )

    def plot(
        self,
        *,
        ax: Axes | None = None,
        tufte: bool = True,
        bar_height: float = 0.6,
        show_labels: bool = True,
        title: str | None = None,
    ) -> CVFoldsDisplay:
        """
        Plot the cross-validation fold structure.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates new figure.
        tufte : bool
            If True, apply Tufte styling (default).
        bar_height : float
            Height of each fold bar (0-1).
        show_labels : bool
            If True, show fold labels and sample counts.
        title : str, optional
            Plot title. If None, uses default.

        Returns
        -------
        self
            The display object for method chaining.

        Examples
        --------
        >>> display.plot(title="Walk-Forward CV Folds")
        >>> plt.tight_layout()
        >>> plt.show()
        """
        ax = self._get_ax_or_create(ax, figsize=(10, max(3, self.n_splits * 0.8)))

        if tufte:
            apply_tufte_style(ax)

        # Legend entries already emitted. Labelling the first bar actually
        # drawn of each kind (not just fold 0) keeps the legend complete
        # when the first fold has no gap or an empty set.
        labelled: set[str] = set()

        def draw(name: str, y_pos: int, indices: np.ndarray, color: Any, alpha: float) -> None:
            drawn = self._draw_index_bars(
                ax,
                y_pos,
                indices,
                height=bar_height,
                color=color,
                alpha=alpha,
                label=None if name in labelled else name,
            )
            if drawn:
                labelled.add(name)

        # Plot each fold
        for fold_idx, (train, test) in enumerate(
            zip(self.train_indices, self.test_indices)
        ):
            y_pos = self.n_splits - 1 - fold_idx  # Reverse so fold 1 is on top

            # Training set
            draw("Train", y_pos, train, COLORS["train"], 0.85)

            # Gap (if exists)
            if self.gap_indices is not None:
                draw("Gap", y_pos, self.gap_indices[fold_idx], COLORS["gap"], 0.5)

            # Test set
            draw("Test", y_pos, test, COLORS["test"], 0.85)

            # Direct labels (Tufte principle)
            if show_labels:
                # Fold label on left
                direct_label(
                    ax,
                    -5,
                    y_pos,
                    f"Fold {fold_idx + 1}",
                    offset=(0, 0),
                    ha="right",
                    va="center",
                    fontsize=9,
                )

                # Sample counts (right of test bar); max() rather than
                # test[-1] so unsorted test indices are placed correctly.
                if len(test) > 0:
                    direct_label(
                        ax,
                        int(np.max(test)) + 2,
                        y_pos,
                        f"n={len(train)}/{len(test)}",
                        offset=(0, 0),
                        ha="left",
                        va="center",
                        fontsize=8,
                        color=TUFTE_PALETTE["text_secondary"],
                    )

        # Styling
        ax.set_xlim(-self.n_samples * 0.15, self.n_samples * 1.1)
        ax.set_ylim(-0.5, self.n_splits - 0.5)
        ax.set_yticks([])

        set_tufte_labels(ax, xlabel="Sample Index")

        if title is None:
            gap_text = " (with gap)" if self.gap_indices is not None else ""
            title = f"Cross-Validation Folds{gap_text}"
        set_tufte_title(ax, title)

        # Minimal legend (bottom right, unobtrusive)
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend(
                handles,
                labels,
                loc="lower right",
                frameon=False,
                fontsize=8,
                ncol=len(handles),
            )

        self._finalize_plot(ax)
        return self