Source code for propflow.snapshots.step_formatter
"""Step-by-step formatter for belief propagation simulation output.
This module provides tools to format BP simulation steps in an Excel-like
tabular format, showing Q/R messages, cost tables, assignments, and solution
costs per iteration.
"""
from __future__ import annotations
from typing import Dict, Iterable, List, Optional, Sequence, Literal, Tuple
import numpy as np
import re
from .types import EngineSnapshot
def _letter_label(idx: int) -> str:
"""Convert numeric index to letter label (0->a, 1->b, etc.)."""
return chr(ord("a") + idx)
def _format_array(arr: np.ndarray, precision: int = 3) -> str:
"""Format array as comma-separated values."""
values = [
f"{v:.{precision}g}" if isinstance(v, float) else str(int(v))
for v in arr.flatten()
]
return ", ".join(values)
def _normalize_min_zero(arr: np.ndarray) -> np.ndarray:
if arr.size == 0:
return arr
return arr - float(np.min(arr))
def _render_table(headers: Sequence[str], rows: Iterable[Sequence[str]]) -> List[str]:
rows_list = [list(row) for row in rows]
if not rows_list:
return []
col_widths = [len(header) for header in headers]
for row in rows_list:
for idx, value in enumerate(row):
col_widths[idx] = max(col_widths[idx], len(value))
def _format_row(values: Sequence[str]) -> str:
return (
"| "
+ " | ".join(
value.ljust(col_widths[idx]) for idx, value in enumerate(values)
)
+ " |"
)
header_row = _format_row(headers)
sep_row = "|-" + "-|-".join("-" * width for width in col_widths) + "-|"
body_rows = [_format_row(row) for row in rows_list]
return [header_row, sep_row, *body_rows]
def _infer_route_order(variable_names: Sequence[str]) -> List[str]:
def _parse(name: str) -> tuple[int | None, str]:
match = re.search(r"(\\d+)$", name)
if not match:
return (None, name)
return (int(match.group(1)), name[: match.start()])
parsed = []
for name in variable_names:
number, prefix = _parse(name)
parsed.append((prefix.lower(), number, name))
if any(number is not None for _, number, _ in parsed):
parsed.sort(key=lambda item: (item[0], item[1] is None, item[1] or 0, item[2]))
else:
parsed.sort(key=lambda item: item[2])
return [name for _, _, name in parsed]
[docs]
class StepByStepFormatter:
    """Formats BP simulation steps in Excel-like tabular format.

    This class takes a sequence of EngineSnapshots and provides methods to
    format them as readable step-by-step output showing:

    - Cost tables for all factors
    - Q messages (variable -> factor) per iteration
    - R messages (factor -> variable) per iteration
    - Variable assignments and beliefs per iteration
    - Solution cost per iteration

    Example:
        >>> from propflow.snapshots import StepByStepFormatter
        >>> formatter = StepByStepFormatter(engine.snapshot_manager.snapshots)
        >>> print(formatter.format_all_steps())
    """

    def __init__(
        self,
        snapshots: Sequence[EngineSnapshot],
        normalize_messages: bool = True,
        route_filter: Literal["both", "cw", "ccw"] = "both",
        route_order: Optional[Sequence[str]] = None,
        ignore_pairs: Optional[Sequence[Tuple[str, str]]] = None,
    ) -> None:
        """Initialize formatter with snapshot records.

        Args:
            snapshots: EngineSnapshot objects from a BP simulation. When
                several snapshots share a step number, the last one wins and
                the step is listed once.
            normalize_messages: If True, normalize Q/R messages by subtracting
                the minimum value per message before formatting.
            route_filter: Filter messages by route direction (``"cw"``,
                ``"ccw"``, or ``"both"``; case-insensitive). Defaults to
                ``"both"``.
            route_order: Optional explicit variable ordering used to infer
                clockwise/counter-clockwise direction. Must list every
                variable exactly once. If omitted, the order is inferred from
                variable names.
            ignore_pairs: Optional list of ``(sender, recipient)`` message pairs
                to omit from both Q and R output.

        Raises:
            ValueError: If ``snapshots`` is empty, ``route_filter`` is invalid,
                or ``route_order`` has missing, unknown or duplicate variables.
        """
        # Materialize before the emptiness check: a generator object is
        # always truthy even when it yields nothing.
        self._snapshots = list(snapshots)
        if not self._snapshots:
            raise ValueError("Cannot format empty snapshot sequence")
        self._step_to_snapshot = {s.step: s for s in self._snapshots}
        # Derive the step list from the map so a duplicated step number is
        # listed (and counted) once, matching the single snapshot kept for it.
        self._sorted_steps = sorted(self._step_to_snapshot)
        self._normalize_messages = normalize_messages
        self._route_filter = route_filter.lower()
        self._ignore_pairs = set(ignore_pairs or [])
        if self._route_filter not in {"both", "cw", "ccw"}:
            raise ValueError("route_filter must be 'both', 'cw', or 'ccw'")

        # Variable and factor names are taken from the first snapshot.
        first = self._snapshots[0]
        self._variables = sorted(first.dom.keys())
        self._factors = sorted(first.cost_tables.keys()) if first.cost_tables else []

        if route_order is None:
            self._route_order = _infer_route_order(self._variables)
        else:
            self._route_order = self._validated_route_order(route_order)
        self._route_index = {name: idx for idx, name in enumerate(self._route_order)}

        # Factor -> connected variable names; prefer cost_labels, else N_fac.
        if first.cost_labels:
            self._factor_neighbors = {
                name: list(labels) for name, labels in first.cost_labels.items()
            }
        else:
            self._factor_neighbors = {
                name: list(neighbors) for name, neighbors in first.N_fac.items()
            }

    def _validated_route_order(self, route_order: Sequence[str]) -> List[str]:
        """Return ``route_order`` as a list after checking it against the variables.

        Raises:
            ValueError: If variables are missing, unknown, or repeated.
        """
        order = list(route_order)
        missing = set(self._variables) - set(order)
        extra = set(order) - set(self._variables)
        if missing:
            raise ValueError(
                f"route_order is missing variables: {', '.join(sorted(missing))}"
            )
        if extra:
            raise ValueError(
                f"route_order contains unknown variables: {', '.join(sorted(extra))}"
            )
        # A repeated name would inflate the cycle length used for wrap-around
        # and leave the index map pointing at its last occurrence.
        seen = set()
        duplicates = set()
        for name in order:
            if name in seen:
                duplicates.add(name)
            seen.add(name)
        if duplicates:
            raise ValueError(
                "route_order contains duplicate variables: "
                f"{', '.join(sorted(duplicates))}"
            )
        return order

    @property
    def variables(self) -> List[str]:
        """List of variable names in the problem."""
        return list(self._variables)

    @property
    def factors(self) -> List[str]:
        """List of factor names in the problem."""
        return list(self._factors)

    @property
    def steps(self) -> List[int]:
        """Sorted list of distinct step numbers available."""
        return list(self._sorted_steps)

    def _route_direction(
        self, var_name: str, factor_name: str
    ) -> Literal["cw", "ccw", "both"]:
        """Classify the edge ``var_name -> factor_name`` along the route cycle.

        Returns ``"cw"`` if the factor's other variable follows ``var_name`` in
        the route order (wrapping around), ``"ccw"`` if it precedes it, and
        ``"both"`` when the factor is not binary, a variable is not on the
        route, the route has fewer than two variables, or the two variables
        are not adjacent. With exactly two variables both checks match and
        ``"cw"`` takes precedence.
        """
        other = self._other_variable(var_name, factor_name)
        if other is None:
            return "both"
        if var_name not in self._route_index or other not in self._route_index:
            return "both"
        total = len(self._route_order)
        if total < 2:
            return "both"
        var_idx = self._route_index[var_name]
        other_idx = self._route_index[other]
        if (var_idx + 1) % total == other_idx:
            return "cw"
        if (var_idx - 1) % total == other_idx:
            return "ccw"
        return "both"

    def _route_allows(self, var_name: str, factor_name: str) -> bool:
        """Return True if the edge passes the configured route filter."""
        if self._route_filter == "both":
            return True
        return self._route_direction(var_name, factor_name) == self._route_filter

    def _other_variable(self, var_name: str, factor_name: str) -> Optional[str]:
        """Return the other variable of a binary factor, or None.

        None is returned when the factor does not have exactly two neighbors
        or ``var_name`` is not one of them.
        """
        neighbors = self._factor_neighbors.get(factor_name, [])
        if var_name not in neighbors or len(neighbors) != 2:
            return None
        return neighbors[0] if neighbors[1] == var_name else neighbors[1]

    def _route_allows_message(
        self,
        *,
        kind: Literal["Q", "R"],
        sender: str,
        recipient: str,
    ) -> bool:
        """Decide whether a Q or R message passes the route filter.

        A Q message (variable -> factor) is judged by its own edge. An R
        message (factor -> variable) is judged by the edge from the factor's
        *other* variable into the factor, i.e. the flow it continues; R
        messages from non-binary factors are dropped when filtering.
        """
        if self._route_filter == "both":
            return True
        if kind == "Q":
            return self._route_direction(sender, recipient) == self._route_filter
        other = self._other_variable(recipient, sender)
        if other is None:
            return False
        return self._route_direction(other, sender) == self._route_filter

    def _iter_messages(
        self,
        messages: Dict[Tuple[str, str], np.ndarray],
        *,
        kind: Literal["Q", "R"],
    ) -> Iterable[Tuple[str, str, str]]:
        """Yield ``(sender, recipient, formatted_values)`` in key order.

        Pairs listed in ``ignore_pairs`` and messages rejected by the route
        filter are skipped; values are min-zero normalized when enabled.
        """
        # Sort on the key only; the ndarray values must never be compared.
        for (sender, recipient), data in sorted(messages.items(), key=lambda kv: kv[0]):
            if (sender, recipient) in self._ignore_pairs:
                continue
            if not self._route_allows_message(
                kind=kind, sender=sender, recipient=recipient
            ):
                continue
            arr = np.asarray(data)
            if self._normalize_messages:
                arr = _normalize_min_zero(arr)
            yield sender, recipient, _format_array(arr)

    @property
    def domain_size(self) -> int:
        """Domain size of variables (from first snapshot)."""
        first = self._snapshots[0]
        if first.dom:
            return len(next(iter(first.dom.values())))
        return 0

    @staticmethod
    def _value_label(index: int, use_letters: bool) -> str:
        """Label a domain index as a letter (if ``use_letters``) or a number."""
        return _letter_label(index) if use_letters else str(index)

    @staticmethod
    def _assignment_label(assignment: object, use_letters: bool) -> str:
        """Render an assignment; integer indices become letters if requested.

        NumPy integer scalars are treated like Python ints: they are not
        ``int`` subclasses, so a plain ``isinstance(x, int)`` would skip them.
        """
        if use_letters and isinstance(assignment, (int, np.integer)):
            return _letter_label(int(assignment))
        return str(assignment)

    def format_cost_tables(self, use_letters: bool = True) -> str:
        """Format cost tables for all factors.

        Binary factors are printed as a matrix (rows = axis 0, columns =
        axis 1, so rectangular tables are supported), unary factors as one
        ``label: cost`` line, and any other arity as one ``i,j,k: cost`` line
        per table entry.

        Args:
            use_letters: If True, use letter labels (a, b, ...) instead of numbers.

        Returns:
            Formatted string showing all cost tables.
        """
        first = self._snapshots[0]
        if not first.cost_tables:
            return "No cost tables available\n"
        # cost_labels may be empty or None (__init__ then falls back to N_fac).
        cost_labels = first.cost_labels or {}

        lines = ["=" * 60, "COST TABLES", "=" * 60, ""]
        for factor_name in self._factors:
            table = first.cost_tables.get(factor_name)
            if table is None:
                continue
            table = np.asarray(table)
            labels = cost_labels.get(factor_name, [])

            lines.append(f"Factor: {factor_name.upper()}")
            if labels:
                lines.append(f" Connected variables: {', '.join(labels)}")

            if table.ndim == 2:
                # Use both shape entries so tables whose variables have
                # different domain sizes are printed completely.
                n_rows, n_cols = table.shape
                col_labels = [self._value_label(j, use_letters) for j in range(n_cols)]
                lines.append(" " + " ".join(f"{lbl:>5}" for lbl in col_labels))
                for i in range(n_rows):
                    row_vals = " ".join(f"{table[i, j]:>5.3g}" for j in range(n_cols))
                    lines.append(f" {self._value_label(i, use_letters)} {row_vals}")
            elif table.ndim == 1:
                # Unary factor: one "label: cost" entry per domain value.
                lines.append(
                    " "
                    + " ".join(
                        f"{self._value_label(i, use_letters)}: {value:.3g}"
                        for i, value in enumerate(table)
                    )
                )
            else:
                # Higher-arity (or 0-d) factor: element-wise formatting of a
                # sub-array would raise, so list one entry per index tuple.
                for index in np.ndindex(*table.shape):
                    key = ",".join(self._value_label(i, use_letters) for i in index)
                    lines.append(f" {key or '()'}: {table[index]:.3g}")
            lines.append("")
        return "\n".join(lines)

    def _format_message_section(
        self,
        title: str,
        messages: Dict[Tuple[str, str], np.ndarray],
        *,
        kind: Literal["Q", "R"],
        show: str,
    ) -> List[str]:
        """Render one Q or R message section followed by a blank line.

        ``show == "table"`` renders a pipe table; any other value gives one
        ``sender -> recipient: [...]`` line per message. A placeholder line is
        emitted when no message survives the ignore/route filters.
        """
        lines = [title]
        rows = list(self._iter_messages(messages, kind=kind))
        if not rows:
            lines.append(f" (no {kind} messages)")
        elif show == "table":
            lines.extend(_render_table(["Sender", "Recipient", "Message"], rows))
        else:
            lines.extend(
                f" {sender} -> {recipient}: [{formatted}]"
                for sender, recipient, formatted in rows
            )
        lines.append("")
        return lines

    def format_iteration(
        self,
        step: int,
        use_letters: bool = True,
        show: Literal["text", "table"] = "text",
    ) -> str:
        """Format Q/R messages, assignments, and cost for one iteration.

        Args:
            step: The step number to format.
            use_letters: If True, use letter labels for integer assignments.
            show: ``"text"`` (default) prints the existing format; ``"table"``
                renders Q/R messages in a tabular layout. Any other value is
                treated as ``"text"``.

        Returns:
            Formatted string for the iteration, or a "not found" line when the
            step is unknown.
        """
        if step not in self._step_to_snapshot:
            return f"Step {step} not found\n"
        snapshot = self._step_to_snapshot[step]

        lines = ["-" * 60, f"ITERATION {step}", "-" * 60, ""]
        lines.extend(
            self._format_message_section(
                "Q Messages (Variable -> Factor):", snapshot.Q, kind="Q", show=show
            )
        )
        lines.extend(
            self._format_message_section(
                "R Messages (Factor -> Variable):", snapshot.R, kind="R", show=show
            )
        )

        lines.append("Assignments:")
        for var_name in self._variables:
            assignment = snapshot.assignments.get(var_name, "?")
            label = self._assignment_label(assignment, use_letters)
            lines.append(f" {var_name} = {label}")
        lines.append("")

        # Beliefs are optional; the section is omitted when none are stored.
        if snapshot.beliefs:
            lines.append("Beliefs:")
            for var_name in self._variables:
                belief = snapshot.beliefs.get(var_name)
                if belief is not None:
                    lines.append(f" {var_name}: [{_format_array(np.asarray(belief))}]")
            lines.append("")

        if snapshot.global_cost is not None:
            lines.append(f"Solution Cost: {snapshot.global_cost:.3g}")
        else:
            lines.append("Solution Cost: (not computed)")

        # Guard None explicitly: "None != 0" is True and would reach :.3g.
        if snapshot.lambda_ is not None and snapshot.lambda_ != 0:
            lines.append(f"Damping Factor: {snapshot.lambda_:.3g}")
        lines.append("")
        return "\n".join(lines)

    def format_all_steps(
        self,
        include_cost_tables: bool = True,
        show: Literal["text", "table"] = "text",
        use_letters: bool = True,
    ) -> str:
        """Return complete step-by-step output.

        Args:
            include_cost_tables: If True, include cost tables at the beginning.
            show: ``"text"`` (default) prints the existing format; ``"table"``
                renders Q/R messages in a tabular layout.
            use_letters: Forwarded to the cost-table and iteration formatters.

        Returns:
            Complete formatted output for all iterations.
        """
        parts = [
            "=" * 60,
            "BELIEF PROPAGATION STEP-BY-STEP OUTPUT",
            f"Variables: {', '.join(self._variables)}",
            f"Factors: {', '.join(self._factors)}",
            f"Domain size: {self.domain_size}",
            f"Total iterations: {len(self._sorted_steps)}",
            "=" * 60,
            "",
        ]
        if include_cost_tables:
            parts.append(self.format_cost_tables(use_letters=use_letters))
        parts.extend(
            self.format_iteration(step, use_letters=use_letters, show=show)
            for step in self._sorted_steps
        )
        return "\n".join(parts)

    def format_summary(self, use_letters: bool = True) -> str:
        """Return a compact summary of the simulation.

        Args:
            use_letters: If True (default), integer assignments are shown as
                letter labels.

        Returns:
            Summary string with initial/final costs and final assignments.
        """
        lines = ["SIMULATION SUMMARY", "-" * 40]
        if not self._sorted_steps:
            # Not reachable through __init__ (empty input is rejected).
            return "\n".join(lines)

        first_snapshot = self._step_to_snapshot[self._sorted_steps[0]]
        last_snapshot = self._step_to_snapshot[self._sorted_steps[-1]]
        initial_cost = first_snapshot.global_cost
        final_cost = last_snapshot.global_cost

        lines.append(f"Total iterations: {len(self._sorted_steps)}")
        if initial_cost is not None:
            lines.append(f"Initial cost: {initial_cost:.3g}")
        if final_cost is not None:
            lines.append(f"Final cost: {final_cost:.3g}")
        if initial_cost is not None and final_cost is not None:
            lines.append(f"Cost improvement: {initial_cost - final_cost:.3g}")
        lines.append("")

        lines.append("Final assignments:")
        for var_name in self._variables:
            assignment = last_snapshot.assignments.get(var_name, "?")
            label = self._assignment_label(assignment, use_letters)
            lines.append(f" {var_name} = {label}")
        return "\n".join(lines)
# Public API of this module: only the formatter class is exported.
__all__ = [
    "StepByStepFormatter",
]