Source code for propflow.snapshots.visualizer

"""Visualization utilities for belief propagation snapshot trajectories."""

from __future__ import annotations

import math
from pathlib import Path
from typing import Any, Dict, List, Literal, Mapping, Sequence, Tuple
import getpass

from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
import numpy as np

from .types import EngineSnapshot
from propflow.utils.tools.bct import BCTCreator, SnapshotBCTBuilder
from propflow.core.agents import FactorAgent
from propflow.configs.global_config_mapping import domain_value_to_label

FactorLike = str | FactorAgent


class CostTablePlotter:
    """Helper class for plotting and formatting cost tables.

    All members are static; the class only serves as a namespace.
    """

    @staticmethod
    def get_agent_connection_info(
        agent: FactorAgent,
        connections: Dict[FactorAgent, Tuple[str, str]] | None = None,
    ) -> Tuple[str, str]:
        """Get row and column variable names for a factor agent, prompting if needed.

        Resolution order:

        1. An entry for ``agent`` in ``connections``.
        2. ``agent.connection_number`` when it has exactly two entries; the
           names are ordered by their mapped value (smallest first).
        3. An interactive prompt on standard input.

        Args:
            agent: The factor agent whose axes should be named.
            connections: Optional explicit ``agent -> (row_var, col_var)`` map.

        Returns:
            A ``(row_var, col_var)`` tuple.
        """
        # Check explicit connections dict first
        if connections and agent in connections:
            return connections[agent]

        # Try to get from connection_number
        if len(agent.connection_number) == 2:
            # Sort by dimension index to get row (0) then col (1)
            ordered = sorted(
                agent.connection_number.items(), key=lambda item: item[1]
            )
            return ordered[0][0], ordered[1][0]

        # Fallback: prompt user. ``input`` is used rather than ``getpass``
        # because variable names are not secrets: the user must be able to
        # see what they type.
        print(f"Factor '{agent.name}' does not have explicit row/col connection info.")
        row_var = input(f"Enter name for ROW variable (axis 0) of {agent.name}: ")
        col_var = input(f"Enter name for COL variable (axis 1) of {agent.name}: ")
        return row_var, col_var

    @staticmethod
    def prepare_from_agent(
        agent: FactorAgent,
        connections: Dict[FactorAgent, Tuple[str, str]] | None = None,
    ) -> Tuple[np.ndarray, List[str], List[str], str, str]:
        """Prepare cost display data directly from a FactorAgent.

        Args:
            agent: Factor agent providing ``cost_table`` and ``name``.
            connections: Forwarded to :meth:`get_agent_connection_info`.

        Returns:
            ``(matrix, row_labels, col_labels, row_var, col_var)`` where
            ``matrix`` is a float copy/view of the cost table and the labels
            are the display form of the indices ``0..n-1`` of each axis.

        Raises:
            ValueError: If the agent has no cost table or the table is not 2D.
        """
        if agent.cost_table is None:
            raise ValueError(
                f"FactorAgent '{agent.name}' has no cost table initialized."
            )

        matrix = np.asarray(agent.cost_table, dtype=float)
        if matrix.ndim != 2:
            raise ValueError(
                f"Factor '{agent.name}' cost table has shape {matrix.shape}; only 2D tables are supported for plotting."
            )

        row_var, col_var = CostTablePlotter.get_agent_connection_info(
            agent, connections
        )

        # Generic labels derived purely from the table shape.
        n_rows, n_cols = matrix.shape
        row_labels = CostTablePlotter.human_domain_labels(range(n_rows))
        col_labels = CostTablePlotter.human_domain_labels(range(n_cols))

        return matrix, row_labels, col_labels, row_var, col_var

    @staticmethod
    def draw_heatmap(
        ax: plt.Axes,
        matrix: np.ndarray,
        factor: str,
        step: int,
        row_labels: List[str],
        col_labels: List[str],
        row_name: str,
        col_name: str,
        cmap: str,
    ) -> Any:
        """Draw ``matrix`` as a labelled heatmap on ``ax``.

        Column labels are placed on the top edge and a white minor grid
        separates the cells.

        Returns:
            The image object returned by ``ax.imshow`` (not the axes), so the
            caller can attach a colorbar or read ``norm``/``cmap`` from it.
        """
        im = ax.imshow(matrix, aspect="equal", cmap=cmap)
        ax.set_xticks(np.arange(len(col_labels)))
        ax.set_yticks(np.arange(len(row_labels)))
        ax.set_xticklabels(col_labels)
        ax.set_yticklabels(row_labels)
        ax.set_xlabel(col_name)
        ax.set_ylabel(row_name)
        ax.set_title(f"{factor} cost table (step {step})")
        ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
        # Minor ticks sit on the cell borders so the grid outlines each cell.
        ax.set_xticks(np.arange(-0.5, len(col_labels), 1), minor=True)
        ax.set_yticks(np.arange(-0.5, len(row_labels), 1), minor=True)
        ax.grid(which="minor", color="w", linestyle="-", linewidth=1.0, alpha=0.6)
        return im

    @staticmethod
    def format_table(
        matrix: np.ndarray,
        row_labels: Sequence[str],
        col_labels: Sequence[str],
        row_var: str,
        col_var: str,
        *,
        fmt: str = "{:.3g}",
    ) -> str:
        """Return a labeled, aligned text table for a 2D cost matrix.

        The first column holds the row labels under a ``"row_var \\ col_var"``
        header; every cell is left-justified to its column width and columns
        are separated by two spaces. A matrix with zero rows yields only the
        header line.

        Args:
            matrix: 2D array of values (rows are iterated in order).
            row_labels: One label per matrix row.
            col_labels: One label per matrix column.
            row_var: Name of the variable indexed by rows.
            col_var: Name of the variable indexed by columns.
            fmt: ``str.format`` template applied to each value.
        """
        col_names = [str(col) for col in col_labels]
        row_names = [str(row) for row in row_labels]
        corner = f"{row_var} \\ {col_var}"

        value_strings = [[fmt.format(val) for val in row] for row in matrix]

        # ``default=0`` keeps an empty table from raising inside ``max``.
        widths: List[int] = [max(len(corner), max(map(len, row_names), default=0))]
        for col_idx, name in enumerate(col_names):
            widths.append(
                max([len(name), *(len(cells[col_idx]) for cells in value_strings)])
            )

        header_cells = [corner.ljust(widths[0])]
        header_cells += [
            name.ljust(widths[i + 1]) for i, name in enumerate(col_names)
        ]
        lines = ["  ".join(header_cells)]

        for row_name, row_vals in zip(row_names, value_strings):
            cells = [row_name.ljust(widths[0])]
            cells += [val.ljust(widths[i + 1]) for i, val in enumerate(row_vals)]
            lines.append("  ".join(cells))

        return "\n".join(lines)

    @staticmethod
    def human_domain_labels(labels: Sequence[Any]) -> List[str]:
        """Map numeric domain values to letter labels (display-friendly).

        Each label is handled as follows:

        * convertible with ``int()`` -> :meth:`value_label` of that integer;
        * a string ``"prefix:N"`` with integer ``N`` -> ``"prefix:<label>"``;
        * anything else -> ``str(label)``.
        """
        pretty: List[str] = []
        for lbl in labels:
            try:
                idx = int(lbl)
                pretty.append(CostTablePlotter.value_label(idx))
                continue
            except (ValueError, TypeError):
                pass
            if isinstance(lbl, str) and ":" in lbl:
                prefix, suffix = lbl.rsplit(":", 1)
                try:
                    idx = int(suffix)
                    pretty.append(f"{prefix}:{CostTablePlotter.value_label(idx)}")
                    continue
                except ValueError:
                    pass
            pretty.append(str(lbl))
        return pretty

    @staticmethod
    def value_label(value: Any) -> str:
        """Convert numeric assignment/index to a letter label; leave others as-is.

        Non-integer-convertible values and negative integers are returned as
        their string form; non-negative integers are passed (shifted by one)
        to ``domain_value_to_label``.
        """
        try:
            idx = int(value)
        except (ValueError, TypeError):
            return str(value)
        if idx >= 0:
            return domain_value_to_label(idx + 1)  # display: 0→a, 1→b, ...
        return str(idx)


class SnapshotVisualizer:
    """Visualize belief propagation snapshot trajectories."""

    _MAX_AUTO_VARS = 20
    _SMALL_PLOT_THRESHOLD = 8
    _MAX_AUTO_MESSAGE_PAIRS = 6

    def __init__(self, snapshots: Sequence[EngineSnapshot]):
        """Store the snapshots ordered by step and index their variables.

        Args:
            snapshots: A sequence of EngineSnapshot objects.

        Raises:
            ValueError: If snapshots is empty or contains no variables.
        """
        if not snapshots:
            raise ValueError("Snapshots are empty")

        # Work on a private, step-ordered copy so the caller's sequence is untouched.
        ordered = list(snapshots)
        ordered.sort(key=lambda rec: rec.step)
        self._records = ordered
        self._steps = [snapshot.step for snapshot in ordered]
        self._variables = self._collect_variables(ordered)
        # Left unset here; plot_bct obtains its builder via _ensure_bct_builder().
        self._bct_builder: SnapshotBCTBuilder | None = None

        if not self._variables:
            raise ValueError("No variable assignments found in snapshots")
def variables(self) -> List[str]:
    """Return a new, sorted list with every variable name found in the snapshots."""
    names = list(self._variables)
    names.sort()
    return names
def steps(self) -> List[int]:
    """Return a copy of the ordered simulation steps captured in the snapshots."""
    # A fresh list is handed out so callers cannot mutate internal state.
    return [*self._steps]
def argmin_series(
    self, vars_filter: List[str] | None = None
) -> Dict[str, List[int | None]]:
    """Get argmin trajectories for selected variables.

    For every snapshot, all R messages addressed to a variable are summed
    element-wise and the index of the smallest entry is recorded. When a
    snapshot holds no R message for a variable, the value stored in the
    snapshot's ``assignments`` is used instead (``None`` if absent).

    Args:
        vars_filter: Optional list of variable names to include; passed to
            ``_select_variables`` to obtain the final selection.

    Returns:
        Dictionary mapping variable names to their argmin trajectories,
        one entry per snapshot.
    """
    selected = self._select_variables(vars_filter)
    trajectories: Dict[str, List[int | None]] = {name: [] for name in selected}

    for snapshot in self._records:
        # Bucket the incoming messages by recipient variable.
        incoming: Dict[str, List[np.ndarray]] = {}
        for key, message in snapshot.R.items():
            _, recipient = key
            bucket = incoming.get(recipient)
            if bucket is None:
                bucket = incoming[recipient] = []
            bucket.append(np.asarray(message, dtype=float))

        for name in selected:
            stacked = incoming.get(name)
            if stacked:
                choice = int(np.argmin(np.sum(stacked, axis=0)))
            else:
                choice = snapshot.assignments.get(name)
            trajectories[name].append(choice)

    return trajectories
def global_cost_series(
    self,
    *,
    include_missing: bool = False,
    fill_value: float = float("nan"),
) -> Tuple[List[int], List[float]]:
    """Return the global cost trajectory extracted from the snapshots.

    Args:
        include_missing: If True, steps whose ``global_cost`` is ``None`` are
            kept and reported with ``fill_value``; otherwise they are dropped.
        fill_value: Substitute used for a missing cost.

    Returns:
        A tuple ``(steps, costs)`` with matching lengths.

    Raises:
        ValueError: If the resulting cost list is empty.
    """
    steps: List[int] = []
    costs: List[float] = []

    for snapshot in self._records:
        raw = snapshot.global_cost
        if raw is not None:
            value = float(raw)
        elif include_missing:
            value = fill_value
        else:
            # Missing cost and the caller did not ask for placeholders.
            continue
        steps.append(snapshot.step)
        costs.append(value)

    if not costs:
        raise ValueError("No global cost data available in snapshots")
    return steps, costs
def factor_cost_matrix(self, factor: FactorLike, step: int) -> np.ndarray:
    """Return a float copy of a factor's cost table at a given step.

    Args:
        factor: Factor name or agent; resolved through ``_factor_name``.
        step: Snapshot step, looked up through ``_snapshot_by_step``.

    Raises:
        ValueError: If the snapshot has no cost table for the factor. The
            message lists the factors that are available.
    """
    name = self._factor_name(factor)
    snapshot = self._snapshot_by_step(step)
    tables = getattr(snapshot, "cost_tables", {})

    try:
        stored = tables[name] if name in tables else None
    except TypeError:  # pragma: no cover - defensive; mirrors plain lookup
        stored = None
    if name not in tables:
        known = ", ".join(sorted(tables.keys())) or "none"
        raise ValueError(
            f"Snapshot {step} does not contain a cost table for factor '{name}'."
            f" Available factors: {known}"
        )

    # Copy so that callers can modify the result without touching the snapshot.
    return np.asarray(stored, dtype=float).copy()
def factor_cost_labels(
    self, factor: FactorLike, step: int
) -> Tuple[List[str], List[str]]:
    """Return the row and column domain labels of a binary factor at ``step``.

    The factor's variable ordering is read from the snapshot's
    ``cost_labels``. For each of the two variables the labels stored in
    ``record.dom`` are used when present and non-empty; otherwise labels of
    the form ``"<var>:<i>"`` are generated, with the count obtained from
    ``_infer_domain_size``.

    Raises:
        ValueError: If no ordering is stored for the factor, or the factor
            does not have exactly two variables.
    """
    name = self._factor_name(factor)
    snapshot = self._snapshot_by_step(step)
    ordering = getattr(snapshot, "cost_labels", {}).get(name)

    if not ordering:
        raise ValueError(
            f"No variable ordering stored for factor '{name}' at step {step}."
        )
    if len(ordering) != 2:
        raise ValueError(
            "Cost visualisation currently supports binary factors only."
            f" Factor '{name}' has {len(ordering)} variables."
        )

    domains = snapshot.dom

    def _labels_for(variable):
        # Prefer recorded domain labels; synthesise them only when missing.
        recorded = domains.get(variable)
        if recorded:
            return recorded
        size = self._infer_domain_size(variable, snapshot)
        return [f"{variable}:{i}" for i in range(size)]

    row_var, col_var = ordering
    row_labels = _labels_for(row_var)
    col_labels = _labels_for(col_var)
    return row_labels, col_labels
@staticmethod
def plot_agent_cost_table(
    agent: FactorAgent,
    *,
    connections: Dict[FactorAgent, Tuple[str, str]] | None = None,
    fmt: str = "{:.3g}",
    plot: bool = True,
    show: bool = True,
    savepath: str | None = None,
    cmap: str = "viridis",
    annotate: bool = True,
) -> plt.Figure | str:
    """Show the cost table of a FactorAgent without needing any snapshots.

    The text table is always printed to standard output; a heatmap is drawn
    in addition when ``plot`` is true.

    Args:
        agent: The FactorAgent to visualize.
        connections: Optional mapping of FactorAgent objects to
            ``(row_var, col_var)`` tuples, used to name the axes.
        fmt: Numeric format string for the printed table and cell annotations.
        plot: If True, build a matplotlib heatmap; if False, stop after
            printing and return the text.
        show: With ``plot=True``: call ``fig.show()`` when true, otherwise
            close the figure.
        savepath: With ``plot=True``: optional file path for the figure;
            missing parent directories are created.
        cmap: Colormap for the heatmap.
        annotate: Whether to write each value inside its cell.

    Returns:
        The formatted table string (``plot=False``) or the matplotlib Figure
        (``plot=True``).
    """
    prepared = CostTablePlotter.prepare_from_agent(agent, connections)
    matrix, row_labels, col_labels, row_var, col_var = prepared

    # The textual table is emitted regardless of whether a figure follows.
    table_str = CostTablePlotter.format_table(
        matrix, row_labels, col_labels, row_var, col_var, fmt=fmt
    )
    print(table_str)
    if not plot:
        return table_str

    fig, ax = plt.subplots(figsize=(6, 5))
    # Step 0 is a placeholder: a bare agent has no snapshot step.
    im = CostTablePlotter.draw_heatmap(
        ax, matrix, agent.name, 0, row_labels, col_labels, row_var, col_var, cmap
    )

    if annotate:
        to_unit = im.norm
        to_rgba = im.cmap
        for i in range(len(row_labels)):
            for j in range(len(col_labels)):
                val = matrix[i, j]
                r, g, b = to_rgba(to_unit(val))[:3]
                # Weighted RGB luminance decides between light and dark text.
                luminance = 0.299 * r + 0.587 * g + 0.114 * b
                ax.text(
                    j,
                    i,
                    fmt.format(val),
                    ha="center",
                    va="center",
                    color="white" if luminance < 0.6 else "black",
                    fontsize=9,
                )

    fig.colorbar(im, ax=ax, shrink=0.85)
    fig.tight_layout()

    if savepath:
        destination = Path(savepath)
        destination.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(destination, dpi=150)

    if not show:
        plt.close(fig)
    else:
        fig.show()
    return fig
def plot_cost_tables(
    self,
    factor: FactorLike | None = None,
    step: int | None = None,
    *,
    show: bool = True,
    savepath: str | None = None,
    cmap: str = "viridis",
    annotate: bool = True,
    fmt: str = "{:.3g}",
    connections: Dict[FactorAgent, Tuple[str, str]] | None = None,
) -> plt.Figure | str:
    """Plot factor cost tables as heatmaps.

    Three cases are distinguished:

    - ``factor`` is a ``FactorAgent``: the agent's own cost table is drawn
      via :meth:`plot_agent_cost_table` (with ``plot=True``). Snapshot data
      and ``step`` are not consulted in this case.
    - ``factor`` is any other non-``None`` value: the matching table of the
      selected snapshot is drawn via ``_handle_single_factor_plot``.
    - ``factor`` is ``None``: every table of the selected snapshot is drawn
      in a grid via ``_plot_cost_table_grid``.

    Args:
        factor: Optional factor name/agent. If provided, only that factor is shown.
        step: Snapshot step to use. Defaults to the last recorded step.
        show: Whether to display the figure.
        savepath: Optional path to save the figure.
        cmap: Matplotlib colormap name for heatmaps.
        annotate: Whether to overlay numeric values on heatmaps.
        fmt: Numeric format string for tables and annotations.
        connections: Optional mapping of FactorAgent objects to
            ``(row_var, col_var)`` tuples, used to name the axes of agents.

    Returns:
        Whatever the delegated helper returns.

    Raises:
        ValueError: If there are no snapshots, or the selected snapshot has
            no cost tables.
    """
    if not self._records:
        raise ValueError("No snapshots available")

    # A live agent carries its own table, so it bypasses the snapshots.
    if isinstance(factor, FactorAgent):
        return SnapshotVisualizer.plot_agent_cost_table(
            factor,
            connections=connections,
            fmt=fmt,
            plot=True,
            show=show,
            savepath=savepath,
            cmap=cmap,
            annotate=annotate,
        )

    target_step = self._records[-1].step if step is None else step
    snapshot = self._snapshot_by_step(target_step)
    tables = getattr(snapshot, "cost_tables", {})
    if not tables:
        raise ValueError(f"No cost tables available at step {target_step}")

    if factor is None:
        return self._plot_cost_table_grid(
            tables, target_step, cmap, annotate, fmt, show, savepath
        )

    return self._handle_single_factor_plot(
        factor,
        tables,
        target_step,
        fmt,
        connections,
        cmap=cmap,
        annotate=annotate,
        show=show,
        savepath=savepath,
    )
def plot_factor_costs(
    self,
    from_variable: str | Sequence[tuple[str, FactorLike]],
    to_factor: FactorLike | None = None,
    step: int | None = None,
    *,
    mode: Literal["auto", "min", "max"] = "auto",
    cmap: str = "viridis",
    annotate: bool = True,
    show: bool = True,
    savepath: str | None = None,
    return_data: bool = False,
    highlight_color: str = "tab:red",
    text_color: str = "black",
    fmt: str = "{:.3g}",
) -> plt.Figure | Tuple[plt.Figure, np.ndarray, np.ndarray]:
    """Visualise factor cost tables induced by factor→variable messages.

    Two calling conventions are supported:

    - Single panel: ``from_variable`` is a variable name and ``to_factor``
      names the factor.
    - Multiple panels: ``from_variable`` is a list/tuple of
      ``(variable, factor)`` pairs and ``to_factor`` must be omitted.

    Args:
        from_variable: Variable name, or a list/tuple of pairs (see above).
        to_factor: Factor name/agent for the single-panel form.
        step: Snapshot step to draw; required in both forms.
        mode: ``"min"``/``"max"``, or ``"auto"`` to use the result of
            ``_infer_message_mode``.
        cmap: Matplotlib colormap name.
        annotate: Whether to write values into the cells.
        show: Whether to display the figure (single panel: the figure is
            closed instead when false).
        savepath: Optional path to save the figure.
        return_data: Single panel only; also return the two arrays produced
            by ``_render_factor_panel``.
        highlight_color: Forwarded to the panel renderer.
        text_color: Forwarded to the panel renderer.
        fmt: Numeric format string for annotations.

    Returns:
        The figure, or ``(figure, message_arr, winners_mask)`` when
        ``return_data`` is true.

    Raises:
        ValueError: On an inconsistent combination of arguments.
    """
    # A ``str`` is never a list or tuple, so no extra string check is needed.
    if isinstance(from_variable, (list, tuple)):
        if to_factor is not None:
            raise ValueError(
                "When providing multiple factor pairs, omit the 'to_factor' argument."
            )
        if step is None:
            raise ValueError(
                "Step must be provided when plotting multiple factor panels."
            )
        if return_data:
            raise ValueError(
                "return_data is only supported for single factor visualisations."
            )

        pairs: List[tuple[str, FactorLike]] = []
        for item in from_variable:
            if not isinstance(item, (list, tuple)) or len(item) != 2:
                raise ValueError("Each entry must be a (variable, factor) pair.")
            var, fac = item
            pairs.append((str(var), fac))

        real_mode = self._infer_message_mode() if mode == "auto" else mode
        return self._plot_factor_grid(
            pairs,
            step=step,
            mode=real_mode,
            cmap=cmap,
            annotate=annotate,
            show=show,
            savepath=savepath,
            highlight_color=highlight_color,
            text_color=text_color,
            fmt=fmt,
        )

    if to_factor is None:
        raise ValueError(
            "to_factor must be provided when plotting a single factor panel."
        )
    if step is None:
        raise ValueError("step must be provided.")

    real_mode = self._infer_message_mode() if mode == "auto" else mode
    fig, ax = plt.subplots(figsize=(6, 5))
    message_arr, winners_mask, im = self._render_factor_panel(
        ax,
        str(from_variable),
        to_factor,
        step,
        real_mode,
        cmap,
        annotate,
        highlight_color,
        text_color,
        fmt,
    )
    fig.colorbar(im, ax=ax, shrink=0.85)

    subtitle = "argmin" if real_mode == "min" else "argmax"
    factor_name = self._factor_name(to_factor)
    # Separate the two names; concatenated they are unreadable (e.g. "f12x1").
    ax.set_title(f"{factor_name}→{from_variable} ({subtitle} of message)")
    fig.tight_layout()

    if savepath:
        self._save_figure(savepath, fig)

    if show:
        fig.show()
    else:  # pragma: no cover - interactive branch
        plt.close(fig)

    return (fig, message_arr, winners_mask) if return_data else fig
def plot_global_cost(
    self,
    *,
    show: bool = True,
    savepath: str | None = None,
    include_missing: bool = False,
    fill_value: float = float("nan"),
    rolling_window: int | None = None,
    return_data: bool = False,
) -> plt.Figure | Tuple[plt.Figure, Dict[str, Any]]:
    """Plot the evolution of the global cost captured in the snapshots.

    Args:
        show: Display the figure when true; close it otherwise.
        savepath: Optional file path to save the figure.
        include_missing: Forwarded to :meth:`global_cost_series`.
        fill_value: Forwarded to :meth:`global_cost_series`.
        rolling_window: Window size for an overlaid rolling mean. Ignored
            unless greater than 1. When there are fewer cost points than the
            window, nothing is drawn and empty series are reported.
        return_data: If True, return the underlying data alongside the figure.

    Returns:
        The figure, or ``(figure, data)`` where ``data`` has the keys
        ``"steps"``, ``"costs"`` and ``"rolling"``.
    """
    steps, costs = self.global_cost_series(
        include_missing=include_missing,
        fill_value=fill_value,
    )

    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.plot(steps, costs, marker="o", label="Global cost")

    rolling_info: Dict[str, Any] | None = None
    wants_rolling = rolling_window is not None and rolling_window > 1
    if wants_rolling and len(costs) < rolling_window:
        # Not enough points for even one full window.
        rolling_info = {"window": rolling_window, "steps": [], "values": []}
    elif wants_rolling:
        smooth_steps, smooth_values = self._rolling_window_average(
            costs,
            steps,
            rolling_window,
        )
        ax.plot(
            smooth_steps,
            smooth_values,
            linestyle="--",
            color="tab:orange",
            label=f"{rolling_window}-step rolling mean",
        )
        rolling_info = {
            "window": rolling_window,
            "steps": smooth_steps,
            "values": smooth_values,
        }

    ax.set_xlabel("Iteration")
    ax.set_ylabel("Global cost")
    ax.set_title("Global cost trajectory")
    ax.grid(True, alpha=0.3)
    if len(costs) > 1 or rolling_info is not None:
        ax.legend()
    fig.tight_layout()

    if savepath:
        self._save_figure(savepath, fig)

    if not show:
        plt.close(fig)
    else:
        fig.show()

    if not return_data:
        return fig
    return fig, {"steps": steps, "costs": costs, "rolling": rolling_info}
def plot_message_norms(
    self,
    *,
    message_type: Literal["Q", "R"] = "Q",
    pairs: Sequence[tuple[str, str]] | None = None,
    norm: Literal["l2", "l1", "linf"] = "l2",
    show: bool = True,
    savepath: str | None = None,
    return_data: bool = False,
) -> plt.Figure | Tuple[plt.Figure, Dict[str, Any]]:
    """Plot the per-step norms of Q or R messages.

    Args:
        message_type: Select ``"Q"`` (variable→factor) or ``"R"``
            (factor→variable) messages; matched case-insensitively.
        pairs: Optional explicit list of message pairs to include. Each tuple
            is ``(sender, recipient)``. When omitted, the first
            ``_MAX_AUTO_MESSAGE_PAIRS`` available pairs are used.
        norm: Name of the norm used to summarise each message; passed to
            ``_message_norm``.
        show: Display the figure when true; close it otherwise.
        savepath: Optional path to save the figure.
        return_data: If True, include the computed series in the return value.

    Returns:
        The figure, or ``(figure, data)`` where ``data`` has the keys
        ``"steps"``, ``"series"``, ``"message_type"`` and ``"norm"``. A step
        in which a pair has no message is reported as ``NaN``.

    Raises:
        ValueError: If ``message_type`` is invalid, no such messages were
            recorded, or a requested pair is not present.
    """
    msg_type = message_type.upper()
    if msg_type not in {"Q", "R"}:
        raise ValueError("message_type must be 'Q' or 'R'")

    available_pairs = self._collect_message_pairs(msg_type)
    if not available_pairs:
        raise ValueError(f"No {msg_type} messages recorded in the snapshots.")

    if pairs is None:
        target_pairs = available_pairs[: self._MAX_AUTO_MESSAGE_PAIRS]
    else:
        target_pairs = [(str(src), str(dst)) for src, dst in pairs]
        # Hash-based membership instead of scanning the list for every pair.
        known = set(available_pairs)
        if missing := [pair for pair in target_pairs if pair not in known]:
            missing_str = ", ".join(f"{a}->{b}" for a, b in missing)
            raise ValueError(f"Requested message pairs not present: {missing_str}")

    series: Dict[tuple[str, str], List[float]] = {pair: [] for pair in target_pairs}
    for rec in self._records:
        messages = rec.Q if msg_type == "Q" else rec.R
        for pair in target_pairs:
            payload = messages.get(pair)
            if payload is None:
                series[pair].append(float("nan"))
            else:
                series[pair].append(self._message_norm(payload, norm))

    fig, ax = plt.subplots(figsize=(9, 4.5))
    for pair, values in series.items():
        # Keep sender and recipient visually distinct in the legend.
        ax.plot(self._steps, values, marker="o", label=f"{pair[0]}→{pair[1]}")

    direction = "Variable→Factor" if msg_type == "Q" else "Factor→Variable"
    ax.set_xlabel("Iteration")
    ax.set_ylabel(f"{norm.upper()}-norm")
    ax.set_title(f"{msg_type} message norms ({direction})")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()

    if savepath:
        self._save_figure(savepath, fig)

    if show:
        fig.show()
    else:
        plt.close(fig)

    data = {
        "steps": list(self._steps),
        "series": series,
        "message_type": msg_type,
        "norm": norm,
    }
    return (fig, data) if return_data else fig
def plot_assignment_heatmap(
    self,
    vars_filter: List[str] | None = None,
    *,
    show: bool = True,
    savepath: str | None = None,
    cmap: str = "viridis",
    missing_value: float = float("nan"),
    annotate: bool = True,
    value_labels: Mapping[int, str] | Sequence[str] | None = None,
    return_data: bool = False,
) -> plt.Figure | Tuple[plt.Figure, Dict[str, Any]] | None:
    """Plot variable assignments over time as a heatmap.

    Args:
        vars_filter: Optional subset of variables to include.
        show: Whether to display the figure window. When false the figure is
            closed after drawing/saving.
        savepath: Optional path to save the generated figure.
        cmap: Matplotlib colormap name to use for the heatmap.
        missing_value: Value inserted when an assignment is missing for a
            step. Defaults to ``NaN``.
        annotate: Whether to write assignment values inside each cell.
        value_labels: Optional mapping or ordered list that converts
            assignment indices to display labels (e.g. ``{0: "A", 1: "B"}``
            or ``["A", "B", "C"]``).
        return_data: If True, return the data used for plotting alongside the
            figure.

    Returns:
        - ``(figure, payload)`` when ``return_data`` is true, where
          ``payload`` has the keys ``"variables"``, ``"steps"``, ``"matrix"``
          and ``"value_labels"``;
        - otherwise ``None`` when ``show`` is true (the figure has been
          shown), or the figure when ``show`` is false.

    Raises:
        ValueError: If no variables are selected.
    """
    target_vars = self._select_variables(vars_filter)
    if not target_vars:
        raise ValueError("No variables available to plot assignments.")

    steps = self._steps
    matrix = self._prepare_assignment_matrix(target_vars, steps, missing_value)
    label_lookup = self._determine_assignment_labels(value_labels, matrix)

    fig, ax = plt.subplots(
        figsize=self._calculate_heatmap_figsize(steps, target_vars)
    )
    im = self._draw_assignment_heatmap(ax, matrix, steps, target_vars, cmap)

    if annotate:
        self._annotate_assignment_heatmap(
            ax, matrix, target_vars, steps, im, label_lookup
        )

    self._add_colorbar(fig, ax, im, label_lookup, matrix)
    fig.tight_layout()

    if savepath:
        self._save_figure(savepath, fig)

    payload = {
        "variables": target_vars,
        "steps": steps,
        "matrix": matrix,
        "value_labels": label_lookup or None,
    }

    if show:
        fig.show()
        # Explicitly requested data must not be dropped just because the
        # figure was displayed; without return_data the historical ``None``
        # result is kept.
        return (fig, payload) if return_data else None

    plt.close(fig)
    return (fig, payload) if return_data else fig
def plot_argmin_per_variable(
    self,
    vars_filter: List[str] | None = None,
    *,
    figsize: tuple[float, float] | None = None,
    show: bool = True,
    savepath: str | None = None,
    combined_savepath: str | None = None,
    layout: Literal["separate", "combined"] = "separate",
) -> None:
    """Plot argmin trajectories for selected variables.

    A single combined figure is produced when ``layout`` is ``"combined"``
    or when more than ``_SMALL_PLOT_THRESHOLD`` variables are selected;
    otherwise the per-variable layout is used.

    Args:
        vars_filter: Optional list of variable names to plot.
        figsize: Figure size tuple (width, height).
        show: Whether to display the plot.
        savepath: Optional path to save individual variable plots (separate layout).
        combined_savepath: Optional path to save a combined plot.
        layout: ``"separate"`` or ``"combined"`` (case-insensitive).

    Raises:
        ValueError: If ``layout`` is not recognised or there are no steps.
    """
    layout_choice = layout.lower()
    if layout_choice not in ("separate", "combined"):
        raise ValueError("layout must be 'separate' or 'combined'")

    target_vars = self._select_variables(vars_filter, enforce_limit=True)
    series = self.argmin_series(target_vars)
    steps = self._steps
    if not steps:
        raise ValueError("No steps to plot")

    too_many_for_panels = len(target_vars) > self._SMALL_PLOT_THRESHOLD
    if layout_choice != "combined" and not too_many_for_panels:
        self._plot_separate_argmin(
            target_vars, series, steps, figsize, savepath, combined_savepath, show
        )
        return

    self._plot_combined_argmin(
        target_vars,
        series,
        steps,
        figsize,
        savepath,
        combined_savepath,
        show,
        layout_choice,
    )
def _save_figure(self, path, fig):
    """Save ``fig`` to ``path`` at 150 dpi, creating parent directories as needed."""
    destination = Path(path)
    folder = destination.parent
    folder.mkdir(parents=True, exist_ok=True)
    fig.savefig(destination, dpi=150)

# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def plot_bct(
    self,
    variable_name: str,
    iteration: int | None = None,
    *,
    value_index: int | None = None,
    steps_back: int | None = None,
    show: bool = True,
    savepath: str | None = None,
    verbose: bool = False,
) -> BCTCreator:
    """Plot a Backtrack Cost Tree (BCT) for a variable from snapshots.

    Args:
        variable_name: The name of the variable to visualize the BCT for.
        iteration: Iteration to trace back from; combined with ``steps_back``
            and the number of snapshots by ``_resolve_bct_iteration``.
        value_index: Domain value to root the tree at. When ``None`` the
            builder's ``assignment_for`` the resolved step is used, and ``0``
            if that is ``None`` as well.
        steps_back: Optional offset from the end; see ``iteration``.
        show: Forwarded to ``BCTCreator.visualize_bct``.
        savepath: Forwarded to ``BCTCreator.visualize_bct`` as ``save_path``.
        verbose: Forwarded to ``BCTCreator.visualize_bct``.

    Returns:
        The ``BCTCreator`` built for the tree, for further analysis.

    Raises:
        ValueError: If the variable_name is not found in the snapshots.
    """
    if variable_name not in self._variables:
        raise ValueError(f"Variable {variable_name} not found in snapshots")

    builder = self._ensure_bct_builder()
    resolved_step = self._resolve_bct_iteration(
        iteration, steps_back, len(self._records)
    )

    # Root value: explicit request -> recorded assignment -> 0.
    if value_index is not None:
        target_value = value_index
    else:
        target_value = builder.assignment_for(variable_name, resolved_step)
        if target_value is None:
            target_value = 0

    root = builder.belief_root(variable_name, resolved_step, int(target_value))
    creator = BCTCreator(builder.graph, root)
    creator.visualize_bct(show=show, save_path=savepath, verbose=verbose)
    return creator
def _handle_single_factor_plot( self, factor, tables, target_step, fmt, connections, *, cmap="viridis", annotate=True, show=True, savepath=None, ): factor_name = self._factor_name(factor) if factor_name not in tables: available = ", ".join(sorted(tables.keys())) or "none" raise ValueError( f"Factor '{factor_name}' not found in cost tables. Available: {available}" ) matrix, row_labels, col_labels, row_var, col_var = self._prepare_cost_display( np.asarray(tables[factor_name], dtype=float), factor_name, target_step, ) table_str = CostTablePlotter.format_table( matrix, row_labels, col_labels, row_var, col_var, fmt=fmt ) # Plotting logic fig, ax = plt.subplots(figsize=(6, 5)) im = CostTablePlotter.draw_heatmap( ax, matrix, factor_name, target_step, row_labels, col_labels, row_var, col_var, cmap, ) if annotate: self._annotate_heatmap(ax, matrix, row_labels, col_labels, im, fmt) fig.colorbar(im, ax=ax, shrink=0.85) fig.tight_layout() if savepath: save_path = Path(savepath) save_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(save_path, dpi=150) if show: fig.show() else: plt.close(fig) return fig def _plot_cost_table_grid( self, tables, target_step, cmap, annotate, fmt, show, savepath ): factors = sorted(tables.keys()) ncols = min(3, len(factors)) nrows = math.ceil(len(factors) / ncols) fig, axes = plt.subplots( nrows, ncols, figsize=(5.5 * ncols, 4.5 * nrows), constrained_layout=True, squeeze=False, ) flat_axes = axes.flatten() ims: List[plt.AxesImage] = [] for ax, factor_name in zip(flat_axes, factors): matrix, row_labels, col_labels, row_var, col_var = ( self._prepare_cost_display( np.asarray(tables[factor_name], dtype=float), factor_name, target_step, ) ) im = CostTablePlotter.draw_heatmap( ax, matrix, factor_name, target_step, row_labels, col_labels, row_var, col_var, cmap, ) if annotate: self._annotate_heatmap(ax, matrix, row_labels, col_labels, im, fmt) ims.append(im) ax.set_title(factor_name) for ax in flat_axes[len(factors) :]: ax.axis("off") if ims: 
fig.colorbar( ims[0], ax=flat_axes[: len(factors)], shrink=0.85, location="right", pad=0.06, ) if savepath: save_path = Path(savepath) save_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(save_path, dpi=150) if show: fig.show() else: plt.close(fig) return fig def _annotate_heatmap(self, ax, matrix, row_labels, col_labels, im, fmt): norm = im.norm cmap_fn = im.cmap for i, row_label in enumerate(row_labels): for j, _ in enumerate(col_labels): val = matrix[i, j] rgba = cmap_fn(norm(val)) r, g, b = rgba[:3] luminance = 0.299 * r + 0.587 * g + 0.114 * b txt_color = "white" if luminance < 0.6 else "black" ax.text( j, i, fmt.format(val), ha="center", va="center", color=txt_color, fontsize=9, ) def _prepare_assignment_matrix(self, target_vars, steps, missing_value): matrix = np.full((len(target_vars), len(steps)), missing_value, dtype=float) for col, rec in enumerate(self._records): assignments = rec.assignments for row, var in enumerate(target_vars): value = assignments.get(var) if value is None: continue matrix[row, col] = float(value) if np.all(np.isnan(matrix)): raise ValueError("No assignments recorded for the selected variables.") return matrix def _determine_assignment_labels(self, value_labels, matrix): label_lookup: Dict[int, str] = {} if value_labels is not None: if isinstance(value_labels, Mapping): label_lookup = {int(k): str(v) for k, v in value_labels.items()} elif isinstance(value_labels, Sequence) and not isinstance( value_labels, (str, bytes) ): label_lookup = {idx: str(name) for idx, name in enumerate(value_labels)} else: raise TypeError( "value_labels must be a mapping or a sequence of labels" ) else: present_values = { int(round(v)) for v in np.unique(matrix[~np.isnan(matrix)]) } label_lookup = { val: CostTablePlotter.value_label(val) for val in present_values } return label_lookup def _calculate_heatmap_figsize(self, steps, target_vars): fig_width = max(6.0, 0.6 * len(steps)) fig_height = max(4.0, 0.5 * len(target_vars)) return (fig_width, 
fig_height) def _draw_assignment_heatmap(self, ax, matrix, steps, target_vars, cmap): im = ax.imshow(matrix, aspect="auto", cmap=cmap, interpolation="nearest") ax.set_xticks(range(len(steps))) ax.set_xticklabels(steps, rotation=45, ha="right") ax.set_yticks(range(len(target_vars))) ax.set_yticklabels(target_vars) ax.set_xlabel("Iteration") ax.set_ylabel("Variable") ax.set_title("Assignment heatmap") return im def _annotate_assignment_heatmap( self, ax, matrix, target_vars, steps, im, label_lookup ): norm = im.norm cmap_fn = im.cmap for row in range(len(target_vars)): for col in range(len(steps)): value = matrix[row, col] if np.isnan(value): continue int_value = int(round(value)) label = label_lookup.get( int_value, CostTablePlotter.value_label(int_value) ) rgba = cmap_fn(norm(value)) r, g, b = rgba[:3] luminance = 0.299 * r + 0.587 * g + 0.114 * b txt_color = "white" if luminance < 0.6 else "black" ax.text( col, row, label, ha="center", va="center", color=txt_color, fontsize=10, fontweight="bold", ) def _add_colorbar(self, fig, ax, im, label_lookup, matrix): cbar = fig.colorbar(im, ax=ax, shrink=0.85) if label_lookup: unique_values = sorted( {int(round(v)) for v in np.unique(matrix[~np.isnan(matrix)])} ) cbar.set_ticks(unique_values) cbar.set_ticklabels( [ label_lookup.get(val, CostTablePlotter.value_label(val)) for val in unique_values ] ) cbar.set_label("Assignment label") else: cbar.set_label("Assignment index") def _plot_combined_argmin( self, target_vars, series, steps, figsize, savepath, combined_savepath, show, layout_choice, ): fig, ax = plt.subplots(figsize=figsize or (12, 6)) color_cycle = plt.rcParams.get("axes.prop_cycle") palette = ( color_cycle.by_key().get("color", []) if color_cycle is not None else [] ) for idx, var in enumerate(target_vars): color = palette[idx % len(palette)] if palette else None self._plot_single_variable_trace(ax, var, series[var], steps, color) ax.set_xlabel("Iteration") ax.set_ylabel("Argmin index") ax.set_title("Belief 
argmin trajectories") ax.grid(True, alpha=0.3) ax.legend() self._set_integer_ticks(ax, series) plt.tight_layout() if savepath: self._save_figure(savepath, fig) if combined_savepath and layout_choice != "combined": self._save_figure(combined_savepath, fig) if show: fig.show() else: plt.close(fig) def _plot_single_variable_trace(self, ax, var, values, steps, color): xs_all: List[int] = [] ys_all: List[int] = [] segments: List[tuple[list[int], list[int]]] = [] seg_x: List[int] = [] seg_y: List[int] = [] for step, value in zip(steps, values): if value is None: if seg_x: segments.append((seg_x, seg_y)) seg_x, seg_y = [], [] continue xs_all.append(step) ys_all.append(value) seg_x.append(step) seg_y.append(value) if seg_x: segments.append((seg_x, seg_y)) if not xs_all: return label = var for seg_x, seg_y in segments: ax.plot( seg_x, seg_y, color=color, linewidth=1.6, label=label, ) label = None ax.scatter( xs_all, ys_all, color=color, s=60, label=label, ) def _plot_separate_argmin( self, target_vars, series, steps, figsize, savepath, combined_savepath, show ): per_var_fig, axes = plt.subplots( len(target_vars), 1, figsize=figsize or (10, 3 * len(target_vars)) ) if len(target_vars) == 1: axes = [axes] for ax, var in zip(axes, target_vars): self._plot_variable_series(ax, var, steps, series[var]) plt.tight_layout() if savepath: self._save_figure(savepath, per_var_fig) if combined_savepath and len(target_vars) > 1: self._save_figure(combined_savepath, per_var_fig) if show: per_var_fig.show() else: plt.close(per_var_fig) def _prepare_factor_panel_data( self, from_variable: str, factor: FactorLike, step: int ): factor_name = self._factor_name(factor) record = self._snapshot_by_step(step) neighbours = record.N_fac.get(factor_name, []) if from_variable not in neighbours: raise ValueError( f"Variable '{from_variable}' is not connected to factor '{factor_name}' at step {step}." 
) labels = getattr(record, "cost_labels", {}).get(factor_name) if not labels or len(labels) != 2: raise ValueError( "plot_factor_costs currently supports binary factors only. " f"Factor '{factor_name}' has variable ordering {labels}." ) try: target_index = labels.index(from_variable) except ValueError as exc: raise ValueError( f"Variable '{from_variable}' not present in factor '{factor_name}' ordering {labels}." ) from exc other_index = 1 - target_index other_variable = labels[other_index] matrix = self.factor_cost_matrix(factor_name, step) if matrix.ndim != 2: raise ValueError( f"Factor '{factor_name}' cost table has shape {matrix.shape}; only 2D tables are supported." ) row_labels = record.dom.get(from_variable) col_labels = record.dom.get(other_variable) if not row_labels or not col_labels: raise ValueError( "Domain labels missing for factor visualisation: " f"rows={row_labels}, cols={col_labels}" ) row_labels = CostTablePlotter.human_domain_labels(row_labels) col_labels = CostTablePlotter.human_domain_labels(col_labels) return ( factor_name, record, target_index, other_variable, matrix, row_labels, col_labels, ) def _calculate_effective_cost( self, matrix, record, from_variable, factor_name, target_index, row_labels, mode ): aligned = matrix if target_index == 0 else np.swapaxes(matrix, 0, 1) # Get Q message from from_variable only (not both variables) from_q_message = record.Q.get((from_variable, factor_name)) from_msg = ( np.zeros(len(row_labels)) if from_q_message is None else np.asarray(from_q_message, dtype=float) ) # Compute effective cost by adding Q message to from_variable's dimension only if target_index == 0: # from_variable is rows effective = aligned + from_msg[:, None] else: # from_variable is columns effective = aligned + from_msg[None, :] # Compute R message by reducing over the OTHER variable's dimension reduce_axis = 1 - target_index r_message = ( np.min(effective, axis=reduce_axis) if mode == "min" else np.max(effective, axis=reduce_axis) ) 
return effective, r_message, aligned def _compute_winners( self, effective, r_message, target_index, row_labels, col_labels, mode ): # Determine tolerance for comparisons tol = 1e-12 + 1e-9 * max(1.0, np.ptp(effective)) # Find cells that produce each R message value (primary highlighting) winners = np.zeros_like(effective, dtype=bool) if target_index == 0: for i in range(len(row_labels)): winners[i, :] = np.abs(effective[i, :] - r_message[i]) <= tol else: for j in range(len(col_labels)): winners[:, j] = np.abs(effective[:, j] - r_message[j]) <= tol # Find the absolute minimum/maximum of R message (secondary highlighting) r_best = np.min(r_message) if mode == "min" else np.max(r_message) r_best_indices = np.where(np.abs(r_message - r_best) <= tol)[0] # Mark cells that produce the best R value best_winners = np.zeros_like(effective, dtype=bool) if target_index == 0: # R is per row for i in r_best_indices: best_winners[i, :] = winners[i, :] else: # R is per column for j in r_best_indices: best_winners[:, j] = winners[:, j] return winners, best_winners def _draw_factor_panel_heatmap( self, ax, aligned, factor_name, step, row_labels, col_labels, from_variable, other_variable, cmap, winners, best_winners, annotate, text_color, fmt, ): # Draw heatmap showing original cost table (not effective cost) im = CostTablePlotter.draw_heatmap( ax, aligned, factor_name, step, row_labels, col_labels, from_variable, other_variable, cmap, ) # Draw two-level highlighting with improved visibility for i in range(len(row_labels)): for j in range(len(col_labels)): # Primary highlighting (red) for all cells that produce R message values if winners[i, j] and not best_winners[i, j]: rect = Rectangle( (j - 0.45, i - 0.45), 0.9, 0.9, fill=True, facecolor="red", alpha=0.15, linewidth=3.5, edgecolor="red", ) ax.add_patch(rect) # Secondary highlighting (gold/orange) for cells that produce best R value if best_winners[i, j]: rect = Rectangle( (j - 0.45, i - 0.45), 0.9, 0.9, fill=True, 
facecolor="gold", alpha=0.4, linewidth=4, edgecolor="darkorange", ) ax.add_patch(rect) if annotate: ax.text( j, i, fmt.format(aligned[i, j]), ha="center", va="center", color=text_color, fontsize=10, fontweight="bold" if winners[i, j] else "normal", ) ax.set_xlim(-0.5, len(col_labels) - 0.5) return im def _snapshot_by_step(self, step: int) -> EngineSnapshot: try: idx = self._steps.index(step) except ValueError as exc: # pragma: no cover - defensive available = ", ".join(str(s) for s in self._steps) raise ValueError( f"No snapshot recorded for step {step}. Available steps: {available}" ) from exc return self._records[idx] @staticmethod def _factor_name(factor: FactorLike) -> str: if isinstance(factor, FactorAgent): return factor.name if isinstance(factor, str): return factor if name := getattr(factor, "name", None): return str(name) else: raise ValueError( "Factor reference must be a name or an object with a 'name' attribute" ) @staticmethod def _infer_domain_size(var: str, record: EngineSnapshot) -> int: if size := len(record.dom.get(var, [])): return size assignments = record.assignments.get(var) return int(assignments) + 1 if assignments is not None else 0 def _prepare_cost_display( self, matrix: np.ndarray, factor: str, step: int, ) -> Tuple[np.ndarray, List[str], List[str], str, str]: if matrix.ndim != 2: raise ValueError( f"Factor '{factor}' cost table has shape {matrix.shape}; only 2D tables are supported for plotting." 
) row_labels, col_labels = self.factor_cost_labels(factor, step) row_labels = CostTablePlotter.human_domain_labels(row_labels) col_labels = CostTablePlotter.human_domain_labels(col_labels) if matrix.shape != (len(row_labels), len(col_labels)): raise ValueError( "Cost table shape does not match domain sizes: " f"matrix={matrix.shape}, rows={len(row_labels)}, cols={len(col_labels)}" ) record = self._snapshot_by_step(step) row_var, col_var = getattr(record, "cost_labels", {}).get( factor, ["rows", "cols"] ) return matrix, row_labels, col_labels, row_var, col_var def _infer_message_mode(self) -> Literal["min", "max"]: """Infer whether to highlight minima or maxima from snapshot metadata.""" if not self._records: return "min" metadata = getattr(self._records[0], "metadata", {}) or {} name = str(metadata.get("computator", "")).lower() return "max" if "max" in name else "min" def _render_factor_panel( self, ax: plt.Axes, # pyright: ignore[reportPrivateImportUsage] from_variable: str, factor: FactorLike, step: int, mode: Literal["min", "max"], cmap: str, annotate: bool, highlight_color: str, text_color: str, fmt: str, ) -> Tuple[np.ndarray, np.ndarray, plt.AxesImage]: # pyright: ignore[reportPrivateImportUsage] ( factor_name, record, target_index, other_variable, matrix, row_labels, col_labels, ) = self._prepare_factor_panel_data(from_variable, factor, step) effective, r_message, aligned = self._calculate_effective_cost( matrix, record, from_variable, factor_name, target_index, row_labels, mode ) winners, best_winners = self._compute_winners( effective, r_message, target_index, row_labels, col_labels, mode ) im = self._draw_factor_panel_heatmap( ax, aligned, factor_name, step, row_labels, col_labels, from_variable, other_variable, cmap, winners, best_winners, annotate, text_color, fmt, ) # Compute message for return compatibility (sum of incoming R messages) message = record.R.get((factor_name, from_variable)) message_arr = ( np.zeros(len(row_labels)) if message is None 
else np.asarray(message, dtype=float) ) winners_mask = np.any(winners, axis=1) # At least one winner per row return message_arr, winners_mask, im def _plot_factor_grid( self, pairs: Sequence[tuple[str, FactorLike]], *, step: int, mode: Literal["min", "max"], cmap: str, annotate: bool, show: bool, savepath: str | None, highlight_color: str, text_color: str, fmt: str, ) -> plt.Figure: if not pairs: raise ValueError("No factor pairs supplied for plotting.") ncols = min(3, len(pairs)) nrows = math.ceil(len(pairs) / ncols) fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 5 * nrows)) flat_axes = np.atleast_1d(axes).flatten() ims: List[plt.AxesImage] = [] for ax, (var, fac) in zip(flat_axes, pairs): message_arr, winners_mask, im = self._render_factor_panel( ax, var, fac, step, mode, cmap, annotate, highlight_color, text_color, fmt, ) ims.append(im) for ax in flat_axes[len(pairs) :]: ax.axis("off") if ims: fig.colorbar(ims[0], ax=flat_axes[: len(pairs)], shrink=0.85) fig.tight_layout() if savepath: save_path = Path(savepath) save_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(save_path, dpi=150) if show: fig.show() else: # pragma: no cover - interactive branch plt.close(fig) return fig def _set_integer_ticks(self, ax, series: Dict[str, List[int | None]]) -> None: """Set y-axis ticks to integer values found in series.""" if values := [v for seq in series.values() for v in seq if v is not None]: uniques = sorted(set(values)) ax.set_yticks(uniques) ax.set_yticklabels([CostTablePlotter.value_label(v) for v in uniques]) @staticmethod def _rolling_window_average( values: Sequence[float], steps: Sequence[int], window: int, ) -> Tuple[List[int], List[float]]: """Compute a trailing rolling mean and align it with the corresponding steps.""" if window <= 1: return list(steps), [float(v) for v in values] arr = np.asarray(values, dtype=float) if len(arr) < window: raise ValueError("window length exceeds number of available points") cumulative = np.cumsum(arr, 
dtype=float) cumulative[window:] = cumulative[window:] - cumulative[:-window] averages = (cumulative[window - 1 :] / window).tolist() step_list = list(steps) return step_list[window - 1 :], averages def _plot_variable_series( self, ax, var: str, steps: Sequence[int], series: Sequence[int | None] ) -> None: """Plot a single variable's argmin trajectory.""" ax.plot(steps, series, marker="o") ax.set_xlabel("Iteration") ax.set_ylabel("Argmin index") ax.set_title(var) ax.grid(True, alpha=0.3) if valid_values := [value for value in series if value is not None]: uniques = sorted(set(valid_values)) ax.set_yticks(uniques) ax.set_yticklabels([CostTablePlotter.value_label(v) for v in uniques]) def _select_variables( self, vars_filter: List[str] | None, *, enforce_limit: bool = False ) -> List[str]: """Select variables to plot. Args: vars_filter: Optional filter list. enforce_limit: If True, raise error if too many variables. Returns: List of selected variable names. """ if vars_filter: if unknown := [var for var in vars_filter if var not in self._variables]: raise ValueError(f"Unknown variables requested: {', '.join(unknown)}") return list(dict.fromkeys(vars_filter)) if enforce_limit and len(self._variables) > self._MAX_AUTO_VARS: raise ValueError( f"{len(self._variables)} variables available; provide vars_filter to select a subset" ) return sorted(self._variables) @staticmethod def _collect_variables(records: Sequence[EngineSnapshot]) -> set[str]: """Collect all variable names from snapshot assignments.""" vars_set = set() for rec in records: vars_set.update(str(key) for key in rec.assignments.keys()) return vars_set def _collect_message_pairs( self, message_type: Literal["Q", "R"] ) -> List[tuple[str, str]]: """Collect all unique sender/recipient pairs for a message type.""" pairs: set[tuple[str, str]] = set() for rec in self._records: store = rec.Q if message_type == "Q" else rec.R for src, dst in store.keys(): pairs.add((str(src), str(dst))) return sorted(pairs) 
@staticmethod
def _message_norm(values: Any, norm: Literal["l2", "l1", "linf"]) -> float:
    """Compute a vector norm for a message payload.

    The payload is flattened before the norm is taken, so scalars and
    multi-dimensional payloads are treated as plain vectors of their entries.

    Args:
        values: Array-like message payload convertible to floats.
        norm: ``"l2"`` (Euclidean), ``"l1"`` (sum of absolute values) or
            ``"linf"`` (largest absolute value).

    Returns:
        The requested norm as a Python float.

    Raises:
        ValueError: If ``norm`` is not one of the supported names.
    """
    # Flatten first: on a 2-D array np.linalg.norm interprets ord=1 / ord=inf
    # as matrix induced norms (max column / row sum), and it rejects 0-d or
    # >2-D input for those orders. One-dimensional payloads are unaffected.
    flat = np.asarray(values, dtype=float).ravel()
    if norm == "l2":
        return float(np.linalg.norm(flat))
    if norm == "l1":
        return float(np.linalg.norm(flat, ord=1))
    if norm == "linf":
        return float(np.linalg.norm(flat, ord=np.inf))
    raise ValueError(f"Unsupported norm: {norm}")


@staticmethod
def _resolve_bct_iteration(
    iteration: int | None, steps_back: int | None, total_steps: int
) -> int:
    """Resolve the effective iteration index for BCT visualization.

    Args:
        iteration: Requested index; ``None`` means ``-1`` (the last step).
            Negative values count from the end.
        steps_back: When given, overrides ``iteration`` with
            ``total_steps - steps_back`` (floored at 0).
        total_steps: Number of recorded snapshots.

    Returns:
        An index clamped to ``[0, total_steps - 1]``; ``0`` when there are no
        recorded steps (in which case ``steps_back`` is not validated).

    Raises:
        ValueError: If ``steps_back`` is zero or negative.
    """
    if total_steps <= 0:
        return 0

    resolved = -1 if iteration is None else iteration
    if steps_back is not None:
        if steps_back <= 0:
            raise ValueError("steps_back must be positive")
        resolved = max(0, total_steps - steps_back)

    if resolved < 0:
        # Negative indices count from the end, clamped to the first step.
        resolved = max(0, total_steps + resolved)
    return min(resolved, total_steps - 1)


def _ensure_bct_builder(self) -> SnapshotBCTBuilder:
    """Return the cached ``SnapshotBCTBuilder``, creating it on first use."""
    if self._bct_builder is None:
        self._bct_builder = SnapshotBCTBuilder(self._records)
    return self._bct_builder


__all__ = ["SnapshotVisualizer"]