Source code for propflow.bp.engines

import random
from typing import Dict, Literal, Optional, Sequence

import networkx as nx
import numpy as np

from ..core.agents import FactorAgent, VariableAgent
from ..core.components import Message
from ..policies import damp, damp_factor
from ..policies.cost_reduction import cost_reduction_all_factors_once
from ..policies.splitting import split_all_factors, split_factors
from ..utils.inbox_utils import multiply_messages
from .engine_base import BPEngine


class Engine(BPEngine):
    """Plain belief propagation engine.

    Adds nothing on top of :class:`BPEngine`. It exists so callers can ask for
    standard, unmodified belief propagation by an explicit, short name.
    """

    pass
class SplitEngine(BPEngine):
    """A BP engine that applies the factor splitting policy.

    After initialization, every factor in the graph is split into two factors
    via :func:`split_all_factors`. The first factor receives ``split_factor``
    of the original cost and the second receives ``1 - split_factor``. This can
    sometimes help with convergence.
    """

    def __init__(self, *args, split_factor: float = 0.5, **kwargs):
        """Initializes the SplitEngine.

        Args:
            *args: Positional arguments for the base `BPEngine`.
            split_factor: The proportion of the cost to allocate to the first
                of the two new factors; the second receives
                ``1 - split_factor``. Defaults to 0.5.
            **kwargs: Keyword arguments for the base `BPEngine`.
        """
        # Must be set before the base constructor runs, because post_init
        # (called from BPEngine.__init__) reads it.
        self.split_factor = split_factor
        super().__init__(*args, **kwargs)
        self._name = "SPFGEngine"
        # Record both shares (p and 1 - p); previously p was printed twice,
        # which made e.g. 0.6/0.4 and 0.6/0.6 indistinguishable in run names.
        self._set_name({"split-": f"{self.split_factor}-{1 - self.split_factor}"})

    def post_init(self) -> None:
        """Applies the factor splitting policy after initialization."""
        split_all_factors(self.graph, self.split_factor)
class MidRunSplitEngine(BPEngine):
    """A BP engine that applies factor splitting at a chosen iteration.

    The engine runs standard BP until ``split_at_iter``. Immediately before
    that step is computed, it splits the requested factors and either resets
    messages to zero or transfers current factor-to-variable messages into the
    split graph.

    The ``transfer`` mode is an empirical heuristic, not a canonical mid-run
    split. It preserves the aggregate factor-to-variable contribution at the
    variable inbox boundary (and therefore the variable belief, and the Q
    messages to unaffected neighbors). It does not preserve the Q messages
    going into the split clones themselves: each clone's Q picks up the sibling
    clone's transferred share, so each clone's first outgoing R is not a
    canonical continuation of the un-split dynamics. Treat the split iteration
    as a heuristic injection.
    """

    def __init__(
        self,
        *args,
        split_at_iter: int,
        split_factor: float = 0.5,
        split_targets: Sequence[str] | None = None,
        split_fraction: float | None = None,
        split_seed: int | None = None,
        transfer_mode: Literal["reset", "transfer"] = "reset",
        **kwargs,
    ) -> None:
        """Initializes the MidRunSplitEngine.

        Args:
            *args: Positional arguments for the base `BPEngine`.
            split_at_iter: Iteration index at which the split is applied (right
                before that step is computed). Must be non-negative.
            split_factor: Split ratio forwarded to :func:`split_factors`. In
                ``transfer`` mode it is also the weight of the first clone's
                transferred message; the second clone gets
                ``1 - split_factor``.
            split_targets: Optional factor names to split, forwarded to
                :func:`split_factors` as ``factor_names``.
            split_fraction: Forwarded to :func:`split_factors`.
            split_seed: Forwarded to :func:`split_factors` as ``seed``.
            transfer_mode: ``"reset"`` re-initializes all messages after the
                split; ``"transfer"`` carries the variable inboxes over (see
                the class docstring for caveats).
            **kwargs: Keyword arguments for the base `BPEngine`.

        Raises:
            ValueError: If ``split_at_iter`` is negative or ``transfer_mode``
                is not ``"reset"`` or ``"transfer"``.
        """
        if split_at_iter < 0:
            raise ValueError("split_at_iter must be non-negative.")
        if transfer_mode not in {"reset", "transfer"}:
            raise ValueError("transfer_mode must be either 'reset' or 'transfer'.")
        self.split_at_iter = int(split_at_iter)
        self.split_factor = float(split_factor)
        self.split_targets = list(split_targets) if split_targets is not None else None
        self.split_fraction = split_fraction
        self.split_seed = split_seed
        self.transfer_mode = transfer_mode
        # Original factor name -> clones created for it by split_factors.
        self.split_mapping: Dict[str, list[FactorAgent]] = {}
        self.split_events: list[dict] = []
        self._split_applied = False
        # Event produced by the split, attached to the next snapshot in step().
        self._pending_split_event: dict | None = None
        super().__init__(*args, **kwargs)
        self._name = "MidRunSplitEngine"
        self._set_name(
            {
                "split_at": str(self.split_at_iter),
                "mode": self.transfer_mode,
                "split": str(self.split_factor),
            }
        )

    def step(self, i: int = 0):
        """Runs BP step ``i``, applying the split first if it is due.

        The split fires once, on the first call with ``i >= split_at_iter``.
        If the base engine recorded a snapshot for iteration ``i``, a copy of
        the split event is stored in its metadata under ``"split_event"``.
        """
        if not self._split_applied and i >= self.split_at_iter:
            self._apply_midrun_split(i)
        step = super().step(i)
        if self._pending_split_event is not None:
            snapshot = self._snapshots.get(i)
            if snapshot is not None:
                snapshot.metadata["split_event"] = dict(self._pending_split_event)
            self._pending_split_event = None
        return step

    def _apply_midrun_split(self, iteration: int) -> None:
        """Splits the target factors and rebuilds message state for ``iteration``."""
        # Snapshot inboxes before the graph changes; transfer mode needs them.
        prior_var_inboxes = self._capture_variable_inboxes()
        self.split_mapping = split_factors(
            self.graph,
            self.split_factor,
            factor_names=self.split_targets,
            split_fraction=self.split_fraction,
            seed=self.split_seed,
        )
        self._split_applied = True
        self._refresh_graph_views()
        # New clone factors need the computator as well.
        self.graph.set_computator(self.computator)

        if self.transfer_mode == "reset":
            self._reset_messages()
        elif self.transfer_mode == "transfer":
            self._transfer_messages(prior_var_inboxes)
        else:  # pragma: no cover - guarded in __init__
            raise ValueError(f"Unsupported transfer_mode: {self.transfer_mode}")

        event = {
            "iteration": int(iteration),
            "transfer_mode": self.transfer_mode,
            "split_factor": self.split_factor,
            # Copy so later mutation of self.split_targets does not rewrite
            # the recorded history.
            "split_targets": (
                list(self.split_targets) if self.split_targets is not None else None
            ),
            "split_fraction": self.split_fraction,
            "split_seed": self.split_seed,
            "split_mapping": {
                original: [clone.name for clone in clones]
                for original, clones in self.split_mapping.items()
            },
        }
        self.split_events.append(event)
        self._pending_split_event = event

    def _refresh_graph_views(self) -> None:
        """Recomputes ``var_nodes``, ``factor_nodes`` and ``graph_diameter``.

        Node lists are sorted by name.
        """
        var_set, factor_set = nx.bipartite.sets(self.graph.G)
        self.var_nodes = sorted(var_set, key=lambda node: node.name)
        self.factor_nodes = sorted(factor_set, key=lambda node: node.name)
        self.graph_diameter = nx.diameter(self.graph.G)

    def _capture_variable_inboxes(self) -> Dict[str, list[Message]]:
        """Returns copies of every variable's inbox, keyed by variable name."""
        captured: Dict[str, list[Message]] = {}
        for variable in getattr(self, "var_nodes", []):
            captured[variable.name] = [msg.copy() for msg in variable.mailer.inbox]
        return captured

    def _clear_agent_state(self) -> None:
        """Empties every node's mailbox and outgoing buffer and its ``_history`` if any."""
        for node in self.graph.G.nodes():
            node.empty_mailbox()
            node.empty_outgoing()
            if hasattr(node, "_history"):
                node._history.clear()

    def _reset_messages(self) -> None:
        """Clears all agent state and re-runs the base message initialization."""
        self._clear_agent_state()
        self._initialize_messages()

    def _transfer_messages(self, prior_var_inboxes: Dict[str, list[Message]]) -> None:
        """Redistribute prior R messages across the split clones.

        For each split factor F with prior message ``R[F->X]``, the variable
        inbox is rewritten as ``p * R`` from ``F'`` and ``(1 - p) * R`` from
        ``F''``. This preserves the aggregate contribution
        ``R[F'->X] + R[F''->X] == R[F->X]_old``, so X's belief and its Q to
        unaffected neighbors are unchanged on the split iteration.

        The Q messages going into the clones themselves are not preserved:
        ``Q[X->F']`` excludes only ``F'`` from the inbox, so the sibling's
        transferred share ``(1 - p) * R_old`` leaks in and is reflected in each
        clone's first outgoing R. Treat ``transfer`` as a heuristic for
        preserving variable-side state, not as a canonical mid-run split.

        Args:
            prior_var_inboxes: Copies of each variable's inbox taken before
                the split, keyed by variable name.
        """
        self._clear_agent_state()
        factor_by_name = {factor.name: factor for factor in self.factor_nodes}
        variable_by_name = {variable.name: variable for variable in self.var_nodes}
        # Clone k receives weights[k] of the original message: p, then 1 - p.
        weights = (self.split_factor, 1.0 - self.split_factor)

        for variable_name, old_messages in prior_var_inboxes.items():
            variable = variable_by_name.get(variable_name)
            if variable is None:
                continue
            for message in old_messages:
                sender_name = getattr(message.sender, "name", "")
                if sender_name in self.split_mapping:
                    clones = self.split_mapping[sender_name]
                    for clone, weight in zip(clones, weights):
                        variable.mailer.receive_messages(
                            Message(
                                data=np.asarray(message.data, dtype=float) * weight,
                                sender=clone,
                                recipient=variable,
                            )
                        )
                elif sender_name in factor_by_name:
                    variable.mailer.receive_messages(
                        Message(
                            data=np.copy(message.data),
                            sender=factor_by_name[sender_name],
                            recipient=variable,
                        )
                    )
                # Messages whose sender is no longer in the graph are dropped;
                # the fill step below gives any uncovered edge a zero message.
        self._fill_missing_variable_messages()

    def _fill_missing_variable_messages(self) -> None:
        """Gives each variable a zero message from every neighbor missing in its inbox."""
        for variable in self.var_nodes:
            existing = {message.sender.name for message in variable.mailer.inbox}
            for neighbor in self.graph.G.neighbors(variable):
                if neighbor.name in existing:
                    continue
                variable.mailer.receive_messages(
                    Message(
                        data=np.zeros(variable.domain),
                        sender=neighbor,
                        recipient=variable,
                    )
                )
class CostReductionOnceEngine(BPEngine):
    """A BP engine that applies a one-time cost reduction policy.

    This engine reduces the costs in the factor tables once, at the beginning
    of the simulation, and then multiplies every outgoing factor message by a
    constant discount after each factor computation.
    """

    def __init__(
        self,
        *args,
        reduction_factor: float = 0.5,
        message_discount: float = 0.5,
        **kwargs,
    ):
        """Initializes the CostReductionOnceEngine.

        Args:
            *args: Positional arguments for the base `BPEngine`.
            reduction_factor: The factor by which to reduce costs. Defaults to
                0.5.
            message_discount: Multiplier applied to every outgoing factor
                message after each factor computation. Defaults to 0.5, the
                value that was previously hard-coded.
            **kwargs: Keyword arguments for the base `BPEngine`.
        """
        # NOTE(review): message_discount is independent of reduction_factor;
        # both default to 0.5. Confirm whether the discount should track the
        # reduction factor for non-default values.
        self.reduction_factor = reduction_factor
        self.message_discount = message_discount
        super().__init__(*args, **kwargs)

    def post_init(self):
        """Applies the one-time cost reduction after initialization."""
        cost_reduction_all_factors_once(self.graph, self.reduction_factor)

    def post_factor_compute(self, factor: FactorAgent, iteration: int):
        """Discounts the factor's outgoing messages by ``message_discount``."""
        multiply_messages(factor.outbox, self.message_discount)
class DampingEngine(BPEngine):
    """BP engine that damps variable-to-factor (Q) messages.

    Each newly computed Q message is blended with the message from the
    previous iteration. This temporal averaging can suppress oscillations and
    help BP converge.
    """

    def __init__(self, *args, damping_factor: float = 0.9, **kwargs):
        """Creates a DampingEngine.

        Args:
            *args: Forwarded unchanged to `BPEngine`.
            damping_factor: Weight of the previous message in the blend.
                Defaults to 0.9.
            **kwargs: Forwarded unchanged to `BPEngine`.
        """
        self.damping_factor = damping_factor
        super().__init__(*args, **kwargs)
        self._name = "DampingEngine"
        name_parts = {"damping": str(self.damping_factor)}
        self._set_name(name_parts)

    def post_var_compute(self, var: VariableAgent):
        """Damps ``var``'s new messages, then calls ``append_last_iteration``.

        The history append is what gives the next iteration's damping a
        previous message to blend with.
        """
        damp(var, self.damping_factor)
        var.append_last_iteration()
class QRDampingEngine(BPEngine):
    """BP engine with independent damping of Q and R messages.

    ``q_damping_factor`` damps variable -> factor (Q) messages and
    ``r_damping_factor`` damps factor -> variable (R) messages. A factor of 0
    disables damping for that direction.
    """

    def __init__(
        self,
        *args,
        q_damping_factor: float = 0.0,
        r_damping_factor: float = 0.0,
        **kwargs,
    ):
        q_factor = float(q_damping_factor)
        r_factor = float(r_damping_factor)
        self.q_damping_factor = q_factor
        self.r_damping_factor = r_factor
        # Snapshot tooling expects a single `damping_factor`: report the Q
        # factor when Q damping is active, otherwise the R factor.
        if q_factor > 0:
            self.damping_factor = q_factor
        else:
            self.damping_factor = r_factor
        super().__init__(*args, **kwargs)
        self._name = "QRDampingEngine"
        self._set_name(
            {
                "q_damping": str(q_factor),
                "r_damping": str(r_factor),
            }
        )

    def post_init(self) -> None:
        """Seed empty factor histories with zero messages when R damping is on.

        This makes R damping effective from iteration 0.
        """
        if self.r_damping_factor <= 0:
            return
        for factor in self.graph.factors:
            if factor._history:
                continue
            factor._history.append(
                [
                    Message(
                        sender=factor,
                        recipient=neighbor,
                        data=np.zeros(factor.domain),
                    )
                    for neighbor in self.graph.G.neighbors(factor)
                ]
            )

    def post_var_compute(self, var: VariableAgent) -> None:
        """Damps Q messages when enabled; always records the variable's iteration."""
        if self.q_damping_factor > 0:
            damp(var, self.q_damping_factor)
        var.append_last_iteration()

    def post_factor_compute(self, factor: FactorAgent, iteration: int) -> None:
        """Damps R messages when enabled; always records the factor's iteration."""
        if self.r_damping_factor > 0:
            damp_factor(factor, self.r_damping_factor)
        factor.append_last_iteration()
class RDampingEngine(BPEngine):
    """BP engine that damps factor-to-variable (R) messages.

    Each newly computed R message is blended with the previous iteration's R
    message, which can suppress oscillations and aid convergence. This is the
    counterpart of DampingEngine, which damps Q (variable -> factor) messages.
    """

    def __init__(self, *args, damping_factor: float = 0.9, **kwargs):
        """Creates an RDampingEngine.

        Args:
            *args: Forwarded unchanged to `BPEngine`.
            damping_factor: Weight of the previous message in the blend.
                Defaults to 0.9.
            **kwargs: Forwarded unchanged to `BPEngine`.
        """
        self.damping_factor = damping_factor
        super().__init__(*args, **kwargs)
        self._name = "RDampingEngine"
        self._set_name({"damping": str(self.damping_factor)})

    def post_init(self) -> None:
        """Seeds every factor's history so damping applies from iteration 0."""
        for factor in self.graph.factors:
            self._init_agent_history(factor)

    def _init_agent_history(self, agent: FactorAgent) -> None:
        """Appends one round of zero messages to ``agent``'s history if it is empty."""
        if not agent._history:
            agent._history.append(
                [
                    Message(
                        sender=agent,
                        recipient=neighbor,
                        data=np.zeros(agent.domain),
                    )
                    for neighbor in self.graph.G.neighbors(agent)
                ]
            )

    def post_factor_compute(self, factor: FactorAgent, iteration: int):
        """Damps the factor's new R messages and records the iteration in its history."""
        damp_factor(factor, self.damping_factor)
        factor.append_last_iteration()
class DiffusionEngine(BPEngine):
    """A BP engine that applies spatial message diffusion.

    Damping blends messages across time (current vs previous). Diffusion
    instead blends messages across space (local vs neighbors) at each
    iteration. This can help smooth the optimization landscape and improve
    convergence on densely connected graphs.
    """

    def __init__(self, *args, alpha: float = 0.3, **kwargs):
        """Initializes the DiffusionEngine.

        Args:
            *args: Positional arguments for the base `BPEngine`.
            alpha: Diffusion coefficient in [0, 1]. Higher values = more smoothing.
                - alpha=0: no diffusion (pure BP)
                - alpha=0.1-0.3: recommended range for most problems
                - alpha=1: complete averaging (may lose local information)
                Defaults to 0.3.
            **kwargs: Keyword arguments for the base `BPEngine`.

        Raises:
            ValueError: If ``alpha`` is outside [0, 1].
        """
        if not 0 <= alpha <= 1:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        self.alpha = alpha
        super().__init__(*args, **kwargs)
        self._name = "DiffusionEngine"
        self._set_name({"alpha": str(self.alpha)})

    def post_var_compute(self, var: VariableAgent) -> None:
        """Apply spatial diffusion to Q-messages (variable → factor).

        Each message this variable sends to a factor is blended with the
        messages that the factor's other variables send to that same factor.
        """
        self._diffuse_outbox(var)

    def post_factor_compute(self, factor: FactorAgent, iteration: int) -> None:
        """Apply spatial diffusion to R-messages (factor → variable).

        Each message this factor sends to a variable is blended with the
        messages that the variable's other factors send to that same variable.
        """
        self._diffuse_outbox(factor)

    def _diffuse_outbox(self, agent) -> None:
        """Blend each outgoing message of ``agent`` with its siblings' messages.

        For a message ``agent -> target``, the siblings are the other graph
        neighbors of ``target``. From each sibling, the first outbox message
        addressed to ``target`` is taken. The message data becomes
        ``(1 - alpha) * local + alpha * mean(sibling messages)``.

        Sibling messages whose shape differs from the local message are
        skipped. Averaging arrays of different lengths is undefined: it crashes
        or mis-broadcasts, for example with Q messages from variables whose
        domains differ in size.
        """
        if self.alpha == 0:
            return  # No diffusion needed.

        for msg in agent.outbox:
            target = msg.recipient
            local_shape = np.shape(msg.data)
            sibling_data = []
            for sibling in self.graph.G.neighbors(target):
                if sibling == agent:
                    continue
                # Only the first message from this sibling to the same target counts.
                for sibling_msg in sibling.outbox:
                    if sibling_msg.recipient == target:
                        if np.shape(sibling_msg.data) == local_shape:
                            sibling_data.append(sibling_msg.data)
                        break

            if sibling_data:
                sibling_avg = np.mean(sibling_data, axis=0)
                # Blend: (1-α) × local + α × neighbor_average
                msg.data = (1 - self.alpha) * msg.data + self.alpha * sibling_avg
class DampingSCFGEngine(DampingEngine, SplitEngine):
    """BP engine combining Q-message damping with factor splitting."""

    def __init__(self, *args, **kwargs):
        """Creates a DampingSCFGEngine.

        Parameters are consumed by the parent engines through the MRO.

        Args:
            *args: Positional arguments for the parent engines.
            **kwargs: Keyword arguments for the parent engines, e.g.
                ``damping_factor`` (default 0.9 here) and ``split_factor``
                (default 0.6 here).
        """
        for key, default in (("split_factor", 0.6), ("damping_factor", 0.9)):
            kwargs.setdefault(key, default)
        super().__init__(*args, **kwargs)
        # The setdefault loop above guarantees the key is present.
        self.split_factor = kwargs["split_factor"]
        self._name = "DampingSCFG"
        share = self.split_factor
        self._set_name(
            {
                "split": f"{share}-{1 - share}",
                "damping": str(self.damping_factor),
            }
        )
class DampingCROnceEngine(DampingEngine, CostReductionOnceEngine):
    """BP engine combining Q-message damping with a one-time cost reduction."""

    def __init__(self, *args, **kwargs):
        """Creates a DampingCROnceEngine.

        Parameters are consumed by the parent engines through the MRO.

        Args:
            *args: Positional arguments for the parent engines.
            **kwargs: Keyword arguments for the parent engines, e.g.
                ``damping_factor`` (default 0.9 here) and
                ``reduction_factor`` (default 0.5 here).
        """
        for key, default in (("reduction_factor", 0.5), ("damping_factor", 0.9)):
            kwargs.setdefault(key, default)
        super().__init__(*args, **kwargs)
        # The setdefault loop above guarantees the key is present.
        self.reduction_factor = kwargs["reduction_factor"]
        self._name = "DampingCROnceEngine"
        factor = self.reduction_factor
        self._set_name(
            {
                "reduction": f"{factor}-{1 - factor}",
                "damping": str(self.damping_factor),
            }
        )
class TRWEngine(BPEngine):
    """
    Tree-Reweighted Belief Propagation engine (Min-Sum variant).

    The engine keeps the standard Min-Sum computator but automatically:

    1. Samples spanning trees over the variable-only (primal) graph to estimate
       per-factor appearance probabilities ``rho_f``.
    2. Scales each factor's energy table by ``1 / rho_f`` before message
       computation so local costs match the TRW objective.
    3. Re-weights outgoing R-messages from factors by ``rho_f`` so that
       variable updates/beliefs operate on appropriately weighted costs.

    Only pairwise factors (exactly two variables) correspond to primal edges.
    Unary and higher-order factors keep ``rho_f = 1`` (unweighted). If the
    primal graph has no edges or is disconnected, every factor keeps
    ``rho_f = 1``.

    Rho sampling and scaling can be overridden by providing explicit
    ``factor_rhos`` (all > 0). Otherwise the engine performs end-to-end TRW
    reweighting using the current factor graph structure.
    """

    DEFAULT_MIN_RHO = 1e-6

    def __init__(
        self,
        *args,
        factor_rhos: Optional[Dict[str, float]] = None,
        tree_sample_count: int = 250,
        tree_sampler_seed: Optional[int] = None,
        min_rho: float = DEFAULT_MIN_RHO,
        **kwargs,
    ) -> None:
        """
        Args:
            factor_rhos: Optional explicit mapping from factor name to
                rho_f > 0. When omitted, the engine estimates rhos via
                spanning-tree sampling. Factors missing from the mapping get 1.0.
            tree_sample_count: Number of spanning trees to sample when
                estimating rhos (at least 1).
            tree_sampler_seed: Seed for the tree sampler's ``random.Random``,
                for reproducibility.
            min_rho: Lower bound applied to sampled rhos to keep them strictly
                positive (important for stable scaling). Never below
                ``DEFAULT_MIN_RHO``.
        """
        self.tree_sample_count = max(1, tree_sample_count)
        self.tree_sampler_seed = tree_sampler_seed
        self.min_rho = max(min_rho, self.DEFAULT_MIN_RHO)
        self._user_defined_rhos = bool(factor_rhos)
        self.factor_rhos: Dict[str, float] = dict(factor_rhos or {})
        super().__init__(*args, **kwargs)
        self._name = "TRWEngine"
        suffix = (
            "custom" if self._user_defined_rhos else f"trees-{self.tree_sample_count}"
        )
        self._set_name({"trw": suffix})

    def post_init(self) -> None:
        """
        Validate rho configuration, sample if needed, and scale costs.

        Called from BPEngine.__init__ after `self.graph` is set but before
        messages are initialized.

        Raises:
            ValueError: If any factor's rho is not strictly positive.
        """
        factors = getattr(self.graph, "factors", [])
        if not factors:
            return

        if not self.factor_rhos:
            self.factor_rhos = self._estimate_rhos_via_spanning_trees(factors)
        else:
            for factor in factors:
                self.factor_rhos.setdefault(factor.name, 1.0)

        for factor in factors:
            rho = self.factor_rhos.get(factor.name, 1.0)
            if rho <= 0:
                raise ValueError(
                    f"TRWEngine: rho for factor '{factor.name}' must be > 0, got {rho}"
                )
            self._scale_factor_cost_table(factor, rho)

    def post_factor_compute(self, factor: FactorAgent, iteration: int) -> None:
        """Scale outgoing R-messages by rho_f before they are sent."""
        rho = self.factor_rhos.get(factor.name, 1.0)
        if rho == 1.0 or not factor.mailer.outbox:
            return
        for msg in factor.mailer.outbox:
            msg.data = rho * msg.data

    # --- Internal helpers -------------------------------------------------
    def _scale_factor_cost_table(self, factor: FactorAgent, rho: float) -> None:
        """Reset to the original cost table (if saved) and divide by rho."""
        factor.save_original()
        base = (
            factor.original_cost_table
            if factor.original_cost_table is not None
            else factor.cost_table
        )
        if base is None:
            return
        factor.cost_table = base / rho

    def _estimate_rhos_via_spanning_trees(
        self, factors: list[FactorAgent]
    ) -> Dict[str, float]:
        """Compute rho_f by sampling spanning trees on the primal graph.

        A pairwise factor's rho is the fraction of sampled trees that contain
        its primal edge, clamped below at ``self.min_rho``. Factors without a
        primal edge (unary or hyper-edge factors) get 1.0. Every factor gets
        1.0 when the primal graph is empty or disconnected.
        """
        primal_graph, edge_to_factors = self._build_primal_graph()
        if (
            primal_graph.number_of_edges() == 0
            or primal_graph.number_of_nodes() == 0
            or not nx.is_connected(primal_graph)
        ):
            # Fallback: no usable topology information, keep uniform rhos.
            return {factor.name: 1.0 for factor in factors}

        # Only factors that own a primal edge take part in sampling.
        counts: Dict[str, int] = {
            factor.name: 0
            for owners in edge_to_factors.values()
            for factor in owners
        }
        rng = random.Random(self.tree_sampler_seed)
        samples = max(1, self.tree_sample_count)
        for _ in range(samples):
            for node_u, node_v in self._sample_spanning_tree(primal_graph, rng):
                key = (node_u, node_v) if node_u <= node_v else (node_v, node_u)
                for factor in edge_to_factors.get(key, []):
                    counts[factor.name] += 1

        rhos: Dict[str, float] = {}
        for factor in factors:
            count = counts.get(factor.name)
            if count is None:
                # Not a pairwise factor: left unweighted, as documented.
                rhos[factor.name] = 1.0
            else:
                rhos[factor.name] = max(count / samples, self.min_rho)
        return rhos

    def _build_primal_graph(
        self,
    ) -> tuple[nx.Graph, Dict[tuple[str, str], list[FactorAgent]]]:
        """Construct the variable-only graph used for tree sampling.

        Returns:
            The primal graph over variable names, and a map from each sorted
            ``(name_u, name_v)`` edge to the pairwise factors that induce it.
        """
        graph = nx.Graph()
        variables = getattr(self.graph, "variables", [])
        graph.add_nodes_from(var.name for var in variables)

        edge_to_factors: Dict[tuple[str, str], list[FactorAgent]] = {}
        for factor in getattr(self.graph, "factors", []):
            var_names = sorted(factor.connection_number.keys())
            if len(var_names) != 2:
                # Unary and hyper-edge factors get no primal edge; their rho
                # stays 1 in _estimate_rhos_via_spanning_trees.
                continue
            edge_key = (var_names[0], var_names[1])
            graph.add_edge(*edge_key)
            # Multiple factors on the same variable pair are treated as
            # parallel edges: each receives credit whenever the pair is picked.
            edge_to_factors.setdefault(edge_key, []).append(factor)

        return graph, edge_to_factors

    def _sample_spanning_tree(
        self, graph: nx.Graph, rng: random.Random
    ) -> list[tuple[str, str]]:
        """
        Sample a spanning tree using Wilson's algorithm (loop-erased random walk).

        Wilson's method draws a uniformly random spanning tree without needing
        any determinant computations, avoiding numerical issues on large graphs.

        Loops are erased implicitly. During each walk only the *last* exit
        taken from every node is remembered; following those pointers from
        the start node yields exactly the loop-erased path. Start nodes are
        processed in the graph's node order, which is deterministic for a given
        graph, so a fixed ``rng`` seed reproduces the same tree. Iterating a
        ``set`` of strings would depend on per-process hash randomization.

        Assumes ``graph`` is connected (checked by the caller); otherwise a
        walk from a component without the root would never terminate.
        """
        nodes = list(graph.nodes())
        if not nodes:
            return []

        root = rng.choice(nodes)
        in_tree = {root}
        tree_edges: set[tuple[str, str]] = set()

        for start in nodes:
            if start in in_tree:
                continue

            # Random walk until hitting the tree, recording last exits.
            next_hop: Dict[str, str] = {}
            current = start
            while current not in in_tree:
                neighbors = list(graph.neighbors(current))
                if not neighbors:
                    break  # isolated node; excluded by the caller's connectivity check
                nxt = rng.choice(neighbors)
                next_hop[current] = nxt
                current = nxt

            # Retrace the loop-erased path and graft it onto the tree.
            node = start
            while node not in in_tree and node in next_hop:
                successor = next_hop[node]
                edge = (node, successor) if node <= successor else (successor, node)
                tree_edges.add(edge)
                in_tree.add(node)
                node = successor
            in_tree.add(node)

        return list(tree_edges)
class DampingTRWEngine(DampingEngine, TRWEngine):
    """BP engine combining TRW reweighting with Q-message damping.

    ``damping_factor`` defaults to 0.9. All TRW options (``factor_rhos``,
    ``tree_sample_count``, ...) are forwarded to the parent engines through
    the MRO.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("damping_factor", 0.9)
        super().__init__(*args, **kwargs)
        # setdefault above guarantees the key is present.
        self.damping_factor = kwargs["damping_factor"]
        uses_custom_rhos = getattr(self, "_user_defined_rhos", False)
        trw_suffix = (
            "custom" if uses_custom_rhos else f"trees-{self.tree_sample_count}"
        )
        self._name = "DampingTRWEngine"
        self._set_name({"damping": str(self.damping_factor), "trw": trw_suffix})
class MessagePruningEngine(BPEngine):
    """A BP engine configured with a message pruning policy to reduce memory usage.

    The policy is built in :meth:`post_init` and stored as
    ``self.pruning_policy``.
    """

    def __init__(
        self,
        *args,
        prune_threshold: float = 1e-4,
        min_iterations: int = 5,
        adaptive_threshold: bool = True,
        **kwargs,
    ):
        """Initializes the MessagePruningEngine.

        Args:
            *args: Positional arguments for the base `BPEngine`.
            prune_threshold: The threshold below which messages are pruned.
            min_iterations: The number of iterations to wait before pruning.
            adaptive_threshold: Whether to adapt the threshold dynamically.
            **kwargs: Keyword arguments for the base `BPEngine`.
        """
        # Stored before the base constructor runs, because post_init (called
        # from BPEngine.__init__) builds the policy from these values.
        self.prune_threshold = prune_threshold
        self.min_iterations = min_iterations
        self.adaptive_threshold = adaptive_threshold
        super().__init__(*args, **kwargs)

    def post_init(self) -> None:
        """Builds the message pruning policy and stores it as ``self.pruning_policy``.

        NOTE(review): the policy is only stored on the engine here; this method
        does not attach it to any agent mailer. Confirm that the base engine (or
        another component) actually consumes ``self.pruning_policy``; otherwise
        pruning never takes effect.
        """
        # Imported lazily, as in the original implementation.
        from ..policies.message_pruning import MessagePruningPolicy

        self.pruning_policy = MessagePruningPolicy(
            prune_threshold=self.prune_threshold,
            min_iterations=self.min_iterations,
            adaptive_threshold=self.adaptive_threshold,
        )