# Source code for propflow.utils.fg_utils

"""Utilities for creating, loading, and manipulating factor graphs.

This module provides a collection of helper functions and classes for common
tasks related to factor graphs, such as building graphs with specific
topologies (random, cycle, lemniscate, unary-augmented), calculating bounds,
and loading pickled graph objects whose module paths have changed since they
were saved.
"""

import pickle
import random
import sys
from functools import lru_cache
from typing import Any, Callable, Dict, List, Tuple

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib.patches import Patch, Rectangle

from ..bp.factor_graph import FactorGraph
from ..configs.global_config_mapping import get_ct_factory
from ..core.agents import FactorAgent, VariableAgent
from .path_utils import find_project_root

# Put the project root on sys.path (presumably so top-level project modules
# resolve from scripts/notebooks; NOTE(review): confirm this is still needed).
# The membership check keeps a re-executed module (e.g. importlib.reload)
# from appending the same entry again.
project_root = find_project_root()
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Exclusive upper bound for derived graph seeds: the largest signed 64-bit
# integer, so every resolved seed fits the int64 draw in _resolve_graph_seed.
_MAX_SEED = 2**63 - 1


def _make_variable(idx: int, domain: int) -> VariableAgent:
    """Return a new ``VariableAgent`` named ``x<idx>`` with ``domain`` values.

    The ``x`` prefix matters: factor-naming code in this module strips the
    first character of a variable name to recover its index.
    """
    variable_name = f"x{idx}"
    return VariableAgent(domain=domain, name=variable_name)


def _make_factor(
    name: str, domain: int, ct_factory: Callable | str, ct_params: dict
) -> FactorAgent:
    """Return a ``FactorAgent`` whose cost table is created later.

    ``ct_factory`` (a callable, or presumably a registered factory name) is
    resolved through ``get_ct_factory``. The resolved function and
    ``ct_params`` are handed to the agent; neither is invoked here.
    """
    creation_func = get_ct_factory(ct_factory)
    factor = FactorAgent(
        name=name,
        domain=domain,
        ct_creation_func=creation_func,
        param=ct_params,
    )
    return factor


def _build_factor_edge_list(
    edges: List[Tuple[VariableAgent, VariableAgent]],
    domain_size: int,
    ct_factory: Any,
    ct_params: dict,
) -> Dict[FactorAgent, List[VariableAgent]]:
    """Map a freshly created binary factor to each variable pair in ``edges``.

    Each factor is named ``f<i><j>`` by concatenating the two variables' names
    minus their first character (``x3``/``x7`` -> ``f37``). All factors share
    ``domain_size``, ``ct_factory`` and ``ct_params``.

    NOTE(review): the concatenated name is ambiguous once indices reach two
    digits (``(x1, x12)`` and ``(x11, x2)`` both yield ``f112``). If
    ``FactorAgent`` hashes/compares by name, such factors would collide as
    keys of the returned dict; confirm.
    """
    return {
        _make_factor(
            f"f{first.name[1:]}{second.name[1:]}", domain_size, ct_factory, ct_params
        ): [first, second]
        for first, second in edges
    }


def _resolve_graph_seed(seed: int | None) -> int:
    """Return a non-negative integer seed below ``_MAX_SEED``.

    An explicit ``seed`` is reduced modulo ``_MAX_SEED``. Without one, a seed
    is drawn from numpy's legacy global RNG, so an earlier
    ``np.random.seed(...)`` call in user code makes the result reproducible.
    """
    if seed is None:
        # The legacy module-level RNG is the one np.random.seed configures.
        drawn = np.random.randint(0, _MAX_SEED, dtype=np.int64)
        return int(drawn)
    return int(seed) % _MAX_SEED


def _make_connections_density(
    variable_list: List[VariableAgent], density: float, *, seed: int | None = None
) -> List[Tuple[VariableAgent, VariableAgent]]:
    """Return variable pairs forming a connected Erdős–Rényi random graph.

    A G(n, p) graph is sampled over ``len(variable_list)`` nodes with
    ``p = density``. If it has several components, consecutive components are
    bridged by one randomly chosen edge each so the result is connected.
    Both the sampling and the bridge choices are driven by the same resolved
    seed, so equal seeds give equal edge lists.

    Returns:
        The graph's edges as ``(VariableAgent, VariableAgent)`` tuples.
    """
    graph_seed = _resolve_graph_seed(seed)
    bridge_rng = random.Random(graph_seed)
    n = len(variable_list)
    topology = nx.erdos_renyi_graph(n, density, seed=graph_seed)

    # networkx's is_connected rejects an empty graph, hence the size guard.
    if n > 1 and not nx.is_connected(topology):
        parts = list(nx.connected_components(topology))
        for left_part, right_part in zip(parts[:-1], parts[1:]):
            # Choose the left endpoint first, then the right (fixed RNG order).
            topology.add_edge(
                bridge_rng.choice(tuple(left_part)),
                bridge_rng.choice(tuple(right_part)),
            )

    # Swap integer node ids 0..n-1 for the corresponding VariableAgent objects.
    relabelled = nx.relabel_nodes(topology, dict(enumerate(variable_list)))
    return list(relabelled.edges())


class FGBuilder:
    """A builder class providing static methods to construct factor graphs.

    The builders name variables ``x1 .. xN`` and binary factors ``f<i><j>``
    after the two variables they connect.
    """

    @staticmethod
    def build_from_edges(
        variables: List[VariableAgent],
        factors: List[FactorAgent],
        edges: Dict[FactorAgent, List[VariableAgent]],
    ) -> FactorGraph:
        """Builds a factor graph from the provided variables, factors, and edges.

        Args:
            variables: The variable nodes in the graph.
            factors: The factor nodes in the graph.
            edges: Mapping from each factor to the variables it connects.

        Returns:
            FactorGraph: The constructed factor graph.
        """
        return FactorGraph(variables, factors, edges)

    @staticmethod
    def build_random_graph(
        num_vars: int,
        domain_size: int,
        ct_factory: Callable | str,
        ct_params: Dict[str, Any],
        density: float,
        *,
        seed: int | None = None,
    ) -> FactorGraph:
        """Builds a factor graph with random binary constraints.

        The topology is an Erdős–Rényi graph over the variables, patched with
        extra edges so that all variables end up in one connected component.

        Args:
            num_vars: The number of variables in the graph.
            domain_size: The size of the domain for each variable.
            ct_factory: The factory for creating cost tables.
            ct_params: Parameters for the cost table factory (``None`` is
                treated as ``{}``).
            density: The density of the graph (probability of an edge).
            seed: Optional seed controlling the random topology. When omitted,
                a seed is drawn from numpy's legacy global RNG, so calling
                ``np.random.seed(...)`` beforehand still produces deterministic
                graphs. Seeding Python's ``random`` module alone has no effect.

        Returns:
            A `FactorGraph` instance with a random topology.
        """
        params = ct_params or {}
        variables = [_make_variable(i + 1, domain_size) for i in range(num_vars)]
        connections = _make_connections_density(variables, density, seed=seed)
        edges = _build_factor_edge_list(connections, domain_size, ct_factory, params)
        factors = list(edges.keys())
        return FactorGraph(variables, factors, edges)

    @staticmethod
    def build_cycle_graph(
        num_vars: int,
        domain_size: int,
        ct_factory: Callable | str,
        ct_params: Dict[str, Any],
        **kwargs,
    ) -> FactorGraph:
        """Builds a factor graph with a simple cycle topology.

        The graph structure is `x1 – f12 – x2 – ... – xn – fn1 – x1`. With
        exactly two variables this yields two parallel factors, ``f12`` and
        ``f21``, both over ``(x1, x2)``.

        Args:
            num_vars: The number of variables in the cycle. Must be >= 2.
            domain_size: The size of the domain for each variable.
            ct_factory: The factory for creating cost tables.
            ct_params: Parameters for the cost table factory (``None`` is
                treated as ``{}``).
            **kwargs: Catches unused arguments like `density` for API consistency.

        Returns:
            A `FactorGraph` instance with a cycle topology.

        Raises:
            ValueError: If fewer than two variables are requested. A single
                variable would otherwise produce a self-loop factor ``[x1, x1]``.
        """
        if num_vars < 2:
            raise ValueError("Cycle graph requires at least 2 variables.")
        params = ct_params or {}
        variables = [_make_variable(i + 1, domain_size) for i in range(num_vars)]
        # Pair each variable with its successor, wrapping the last back to x1.
        ring = [(variables[j], variables[(j + 1) % num_vars]) for j in range(num_vars)]
        edges = _build_factor_edge_list(ring, domain_size, ct_factory, params)
        factors = list(edges.keys())
        return FactorGraph(variables, factors, edges)

    @staticmethod
    def build_lemniscate_graph(
        num_vars: int,
        domain_size: int,
        ct_factory: Callable | str,
        ct_params: Dict[str, Any],
        **kwargs,
    ) -> FactorGraph:
        """Builds a factor graph with a lemniscate (∞) topology.

        The structure consists of two cycles that share a single central
        variable (``x1``), producing a figure-eight shape. Each loop is
        guaranteed to contain at least two distinct variables in addition to
        the central node.

        Args:
            num_vars: Total number of variables in the graph. Must be >= 5.
            domain_size: The size of the domain for each variable.
            ct_factory: Factory used to create cost tables for the factors.
            ct_params: Parameters forwarded to the cost table factory
                (``None`` is treated as ``{}``).
            **kwargs: Captures unused parameters (e.g., density) for API parity.

        Returns:
            A `FactorGraph` instance shaped like a lemniscate.

        Raises:
            ValueError: If fewer than five variables are provided.
        """
        if num_vars < 5:
            raise ValueError(
                "Lemniscate graph requires at least 5 variables to form two loops."
            )

        variables = [_make_variable(i + 1, domain_size) for i in range(num_vars)]
        center, *remaining = variables
        # len(remaining) >= 4, so the left loop gets floor(n/2) >= 2 variables
        # and the right loop gets the remaining ceil(n/2) >= 2.
        left_size = len(remaining) // 2
        loops = (
            [center, *remaining[:left_size]],
            [center, *remaining[left_size:]],
        )

        edge_pairs: List[Tuple[VariableAgent, VariableAgent]] = []
        for loop in loops:
            # Consecutive pairs, then (last, first) to close the loop at center.
            edge_pairs.extend(zip(loop, loop[1:] + loop[:1]))

        params = ct_params or {}
        edges = _build_factor_edge_list(edge_pairs, domain_size, ct_factory, params)
        factors = list(edges.keys())
        return FactorGraph(variables, factors, edges)

    # Aliases for API compatibility. ``create_leminscate_graph`` keeps a
    # misspelled name that existing callers may rely on.
    create_lemniscate_graph = build_lemniscate_graph
    create_leminscate_graph = build_lemniscate_graph

    @staticmethod
    def build_with_unary_costs(
        base_graph: FactorGraph,
        unary_costs: Dict[str, np.ndarray],
    ) -> FactorGraph:
        """Extends a factor graph with unary constraints (per-variable cost biases).

        Unary factors are single-variable factors that act as priors or biases
        for individual variables. They're useful for adding soft constraints
        or preferences on variable assignments. Each one is named ``u<i>``
        for variable ``x<i>``.

        Args:
            base_graph: An existing factor graph to extend.
            unary_costs: A dictionary mapping variable names to 1D numpy arrays
                of costs. The array length must match the variable's domain size.

        Returns:
            A new FactorGraph with the unary factors added. Only the variable,
            factor and edge containers are copied; the agent objects are
            shared with ``base_graph``. NOTE(review): if ``FactorGraph``
            mutates agents during construction, the two graphs are not
            independent; confirm.

        Raises:
            ValueError: If a name in ``unary_costs`` is not a variable of
                ``base_graph``, or its costs are not a 1D array whose length
                equals that variable's domain.

        Example:
            >>> fg = FGBuilder.build_cycle_graph(3, 2, create_random_int_table, {"low": 0, "high": 10})
            >>> unary = {"x1": np.array([0, 5]), "x2": np.array([3, 0])}
            >>> fg_with_unary = FGBuilder.build_with_unary_costs(fg, unary)
        """

        def make_unary_ct(costs_array):
            # Each factor gets its own closure over its cost array (no
            # late-binding); ct_func ignores its arguments and returns a copy.
            def ct_func(num_vars, domain_size, **kwargs):
                return costs_array.copy()

            return ct_func

        # Copy the containers so the base graph's lists/dicts are not mutated.
        variables = list(base_graph.variables)
        factors = list(base_graph.factors)
        edges = dict(base_graph.edges)

        var_map = {v.name: v for v in variables}
        for var_name, costs in unary_costs.items():
            if var_name not in var_map:
                raise ValueError(f"Variable '{var_name}' not found in graph")
            var = var_map[var_name]

            cost_arr = np.asarray(costs, dtype=float)
            if cost_arr.ndim != 1:
                raise ValueError(f"Unary costs for '{var_name}' must be 1D array")
            if len(cost_arr) != var.domain:
                raise ValueError(
                    f"Unary costs for '{var_name}' must have length {var.domain}, got {len(cost_arr)}"
                )

            # NOTE(review): this name may collide with an existing "u<i>"
            # factor if base_graph already carries unary factors.
            unary_factor = FactorAgent(
                name=f"u{var_name[1:]}",  # u1 for x1, etc
                domain=var.domain,
                ct_creation_func=make_unary_ct(cost_arr),
                param={},
            )
            factors.append(unary_factor)
            edges[unary_factor] = [var]

        return FactorGraph(variables, factors, edges)
def get_message_shape(domain_size: int, connections: int = 2) -> tuple[int, ...]:
    """Calculates the shape of a cost table for a factor.

    Args:
        domain_size: The size of the domain for each connected variable.
        connections: The number of variables connected to the factor.

    Returns:
        A tuple representing the shape of the cost table.
    """
    return (domain_size,) * connections


@lru_cache(maxsize=128)
def get_broadcast_shape(ct_dims: int, domain_size: int, ax: int) -> tuple[int, ...]:
    """Calculates the shape for broadcasting a message into a cost table.

    Returns a ``ct_dims``-long shape of ones with ``domain_size`` at position
    ``ax``. Results are memoized, since the same few shapes recur.
    """
    shape = [1] * ct_dims
    shape[ax] = domain_size
    return tuple(shape)


def generate_random_cost(fg: FactorGraph) -> float:
    """Calculates a total cost based on a random assignment for each factor.

    Indices are drawn from numpy's legacy global RNG, independently per
    factor. NOTE(review): every axis index is drawn from
    ``range(fact.domain)``, which assumes all axes of ``fact.cost_table``
    have length ``fact.domain``; confirm for mixed-domain factors.

    Args:
        fg: The factor graph to evaluate.

    Returns:
        The sum of costs from a random assignment in each factor's cost table.
    """
    cost = 0.0
    for fact in fg.factors:
        random_index = tuple(
            np.random.randint(0, fact.domain, size=fact.cost_table.ndim)
        )
        cost += fact.cost_table[random_index]
    return cost


class SafeUnpickler(pickle.Unpickler):
    """A custom unpickler to handle module path changes during deserialization.

    Overrides ``find_class`` to rewrite module paths that changed between
    pickling and unpickling. Globals that still cannot be resolved are
    replaced by empty placeholder types, with a printed warning, instead of
    raising ``ImportError``/``AttributeError``.

    Warning:
        "Safe" refers only to tolerating moved modules. This class does not
        restrict which globals may be loaded, and unpickling can execute
        arbitrary code. Never use it on data from untrusted sources.
    """

    # Legacy module paths -> their current locations.
    _MODULE_MAPPING = {
        "bp.factor_graph": "propflow.bp.factor_graph",
        "bp.agents": "propflow.core.agents",
        "bp.components": "propflow.core.components",
    }
    # The experiments package was once misspelled as "expiriments".
    _LEGACY_EXPERIMENTS = "expiriments"

    def find_class(self, module: str, name: str) -> Any:
        """Finds a class, handling potential module path changes."""
        module = self._MODULE_MAPPING.get(module, module)
        legacy = self._LEGACY_EXPERIMENTS
        if module == legacy:
            module = "experiments"
        elif module.startswith(legacy + "."):
            module = "experiments." + module[len(legacy) + 1 :]
        try:
            return super().find_class(module, name)
        except (ImportError, AttributeError) as e:
            print(f"Warning: Could not import {module}.{name}: {e}")
            # Placeholder type so the rest of the pickle can still load.
            return type(name, (), {})


def load_pickle_safely(file_path: str) -> Any:
    """Loads a pickle file using the `SafeUnpickler` to tolerate moved modules.

    Warning:
        Unpickling can execute arbitrary code; only load trusted files.

    Args:
        file_path: The path to the pickle file.

    Returns:
        The deserialized object, or `None` if any error occurs (the error is
        printed rather than raised).
    """
    try:
        with open(file_path, "rb") as f:
            return SafeUnpickler(f).load()
    except Exception as e:  # deliberate best-effort: report and return None
        print(f"Error loading pickle: {e}")
        return None


def repair_factor_graph(fg: FactorGraph) -> FactorGraph:
    """Attempts to repair a loaded factor graph by ensuring essential attributes exist.

    This is useful when unpickling older versions of `FactorGraph` objects
    that may be missing attributes added in newer versions. The graph is
    modified in place and also returned.

    - A missing or ``None`` ``fg.G`` is rebuilt as a ``networkx.Graph`` from
      ``fg.variables``/``fg.factors`` and each factor's ``connection_number``.
    - Every node lacking ``mailbox`` gets an empty list.
    - Factor nodes with a missing or ``None`` ``cost_table`` get
      ``initiate_cost_table()`` called if available; failures are printed.

    Args:
        fg: The `FactorGraph` object to repair.

    Returns:
        The repaired `FactorGraph` object.
    """
    if not hasattr(fg, "G") or fg.G is None:
        print("Initializing missing NetworkX graph")
        fg.G = nx.Graph()
        if hasattr(fg, "variables") and hasattr(fg, "factors"):
            fg.G.add_nodes_from(fg.variables)
            fg.G.add_nodes_from(fg.factors)
            for factor in fg.factors:
                if hasattr(factor, "connection_number"):
                    for var, dim in factor.connection_number.items():
                        fg.G.add_edge(factor, var, dim=dim)

    for node in fg.G.nodes():
        if not hasattr(node, "mailbox"):
            node.mailbox = []
        if (
            hasattr(node, "type")
            and node.type == "factor"
            and (not hasattr(node, "cost_table") or node.cost_table is None)
        ):
            try:
                if hasattr(node, "initiate_cost_table"):
                    node.initiate_cost_table()
            except Exception as e:  # best-effort repair: report and continue
                print(f"Could not initialize cost table for {node}: {e}")
    return fg


def get_bound(factor_graph: FactorGraph, reduce_func: Callable = np.min) -> float:
    """Calculates a simple bound on the total cost of the factor graph.

    This is typically used to get a lower bound by summing the minimum values
    from each factor's cost table. Factors without a cost table (missing or
    ``None``) are skipped.

    Args:
        factor_graph: The factor graph to analyze.
        reduce_func: The function to apply to each cost table to get a single
            value (e.g., `np.min` for a lower bound, `np.max` for an upper
            bound). Defaults to `np.min`.

    Returns:
        The calculated bound.
    """
    bound = 0.0
    for factor in factor_graph.factors:
        if hasattr(factor, "cost_table") and factor.cost_table is not None:
            bound += reduce_func(factor.cost_table)
    return bound


def pretty_print_array(
    A,
    *,
    cmap="Blues",
    fmt="{:.2g}",
    annotate=True,
    auto_min=True,
    auto_max=True,
    cell_highlights=None,  # list[(r,c)] or dict[(r,c)]="label"
    row_highlights=None,  # list[row] or dict[row]="label"
    label_colors=None,  # dict[label]->color
    title=None,
    figsize=(6, 4),
    cbar=True,
):
    """
    Render a NumPy array as a heatmap with optional highlighting.

    Parameters
    ----------
    A : array-like (2D)
        Numeric data.
    cmap : str
        Matplotlib colormap name.
    fmt : str
        Number format for annotations, e.g. "{:.2g}".
    annotate : bool
        If True, draw numbers on each cell.
    auto_min : bool
        If True, highlight all global minima (ties included) with label 'min'.
    auto_max : bool
        If True, highlight all global maxima (ties included) with label 'max'.
    cell_highlights : list[(r,c)] | dict[(r,c)] -> str
        Cells to highlight. If list, uses label 'selected'.
        If dict, value is the legend label for that cell group.
    row_highlights : list[int] | dict[int] -> str
        Rows to highlight. If list, labels become 'row {i}'.
        If dict, value is the legend label for that row.
    label_colors : dict[str] -> str
        Custom colors per label. Unknown labels get a color from a fixed
        palette, chosen deterministically from the label text so it is the
        same across runs.
    title : str
        Figure title.
    figsize : (w,h)
        Figure size in inches.
    cbar : bool
        If True, show colorbar.

    Returns
    -------
    fig, ax : Matplotlib figure and axes.

    Raises
    ------
    ValueError
        If ``A`` is not 2D.
    """
    import zlib  # deterministic label hashing for fallback colors

    A = np.asarray(A)
    if A.ndim != 2:
        raise ValueError("A must be 2D")
    nrows, ncols = A.shape

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(A, cmap=cmap, aspect="equal")
    if cbar:
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Major ticks label rows/columns; minor ticks at cell edges draw the grid.
    ax.set_xticks(np.arange(ncols))
    ax.set_yticks(np.arange(nrows))
    ax.set_xticklabels([str(j) for j in range(ncols)])
    ax.set_yticklabels([str(i) for i in range(nrows)])
    ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
    ax.grid(which="minor", color="white", linewidth=1)
    ax.tick_params(which="minor", bottom=False, left=False)

    used_labels = []
    default_label_colors = {
        "selected": "#FFB000",
        "row": "#FFD166",
        "min": "#00C2A0",
        "max": "#E63946",
    }
    if label_colors:
        default_label_colors.update(label_colors)

    # Group highlighted cells by legend label.
    cell_map = {}
    if cell_highlights is not None:
        if isinstance(cell_highlights, dict):
            for (r, c), lab in cell_highlights.items():
                cell_map.setdefault(lab, []).append((int(r), int(c)))
        else:  # assume list of (r,c)
            cell_map.setdefault("selected", []).extend(
                (int(r), int(c)) for r, c in cell_highlights
            )

    # Group highlighted rows by legend label.
    row_map = {}
    if row_highlights is not None:
        if isinstance(row_highlights, dict):
            for r, lab in row_highlights.items():
                row_map.setdefault(lab, []).append(int(r))
        else:  # assume list of rows
            for r in row_highlights:
                row_map.setdefault(f"row {int(r)}", []).append(int(r))

    # Auto min/max, ties included.
    if auto_min:
        mn = np.min(A)
        mins = np.argwhere(A == mn)
        if len(mins):
            cell_map.setdefault("min", []).extend((int(r), int(c)) for r, c in mins)
    if auto_max:
        mx = np.max(A)
        maxs = np.argwhere(A == mx)
        if len(maxs):
            cell_map.setdefault("max", []).extend((int(r), int(c)) for r, c in maxs)

    fallback_cycle = (
        "#FFB000",
        "#6A4C93",
        "#2A9D8F",
        "#E76F51",
        "#118AB2",
        "#EF476F",
    )

    def color_for(label):
        if label in default_label_colors:
            return default_label_colors[label]
        # Built-in hash() of str is salted per process (PYTHONHASHSEED), so
        # use CRC32 of the label text to keep the color stable across runs.
        digest = zlib.crc32(str(label).encode("utf-8"))
        return fallback_cycle[digest % len(fallback_cycle)]

    # Row highlights: translucent full-width bands.
    for lab, rows in row_map.items():
        col = color_for(lab)
        for r in rows:
            rect = Rectangle(
                (-0.5, r - 0.5),
                ncols,
                1,
                linewidth=2,
                edgecolor=col,
                facecolor=col,
                alpha=0.25,
            )
            ax.add_patch(rect)
        used_labels.append((lab, col))

    # Cell highlights: translucent single-cell boxes.
    for lab, cells in cell_map.items():
        col = color_for(lab)
        for r, c in cells:
            rect = Rectangle(
                (c - 0.5, r - 0.5),
                1,
                1,
                linewidth=2,
                edgecolor=col,
                facecolor=col,
                alpha=0.35,
            )
            ax.add_patch(rect)
        used_labels.append((lab, col))

    if annotate:
        norm = im.norm
        for i in range(nrows):
            for j in range(ncols):
                val = A[i, j]
                # Heuristic contrast: white text on dark cells, black on light.
                txt_color = "white" if norm(val) > 0.6 else "black"
                ax.text(
                    j,
                    i,
                    fmt.format(val),
                    ha="center",
                    va="center",
                    color=txt_color,
                    fontsize=10,
                )

    # Legend of highlight colors, deduplicated while preserving order.
    if used_labels:
        seen = set()
        handles = []
        for lab, col in used_labels:
            if (lab, col) in seen:
                continue
            seen.add((lab, col))
            handles.append(Patch(facecolor=col, edgecolor=col, alpha=0.6, label=lab))
        ax.legend(
            handles=handles,
            title="Highlights",
            loc="upper left",
            bbox_to_anchor=(1.02, 1.0),
            borderaxespad=0.0,
            frameon=False,
        )

    if title:
        ax.set_title(title)
    plt.tight_layout()
    return fig, ax