Source code for propflow.bp.computators

import functools
import logging
from functools import lru_cache
from typing import List, TypeAlias

import numpy as np

from ..core.components import Message
from ..core.protocols import Computator

# Computator internals log very little: pin this module's logger to CRITICAL
# so records below that level emitted through it are dropped.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.CRITICAL)


class BPComputator(Computator):
    """A vectorized, cache-friendly Belief Propagation computator.

    This class provides an optimized implementation of the core computations
    of belief propagation. It uses dispatch tables for common operations
    (``min``, ``max``, ``sum``, ``add``, ``multiply``) to avoid per-call
    branching, and relies on vectorized numpy operations.

    The variant that is computed (Min-Sum, Max-Sum, Max-Product, Sum-Product,
    ...) is determined by the ``reduce_func`` and ``combine_func`` arguments
    passed during initialization.

    Attributes:
        reduce_func (Callable): Function used to reduce (marginalize) messages,
            e.g. ``np.min`` for Min-Sum or ``np.sum`` for Sum-Product.
        combine_func (Callable): Function used to combine messages, e.g.
            ``np.add`` for Min-Sum or ``np.multiply`` for Max-Product.
    """

    # Function dispatch tables for zero-overhead lookups.
    # reduce_func -> (reducer over given axes, arg-reducer used by get_assignment)
    _REDUCE_DISPATCH = {
        np.min: (np.ndarray.min, np.ndarray.argmin),
        np.max: (np.ndarray.max, np.ndarray.argmax),
        np.sum: (np.ndarray.sum, np.ndarray.argmax),
    }
    # combine_func -> (combine along an axis, inverse operation, identity factory)
    _COMBINE_DISPATCH = {
        np.add: (np.sum, np.subtract, np.zeros),
        np.multiply: (np.prod, np.divide, np.ones),
    }

    def __init__(self, reduce_func=np.min, combine_func=np.add):
        """Initializes the BPComputator.

        Args:
            reduce_func (Callable): The function used to reduce messages
                (e.g., ``np.min`` for Min-Sum). Defaults to ``np.min``.
            combine_func (Callable): The function used to combine messages
                (e.g., ``np.add`` for Min-Sum). Defaults to ``np.add``.
        """
        self.reduce_func = reduce_func
        self.combine_func = combine_func
        # (id(factor), node name) -> axis index, filled by _get_node_dimension.
        self._connection_cache = {}

        # Pre-select optimized functions using the dispatch tables.
        self._reduce_msg, self._argreduce_func = self._setup_reduce_functions(
            reduce_func
        )
        (
            self._combine_axis,
            self._combine_inverse,
            self._belief_identity,
        ) = self._setup_combine_functions(combine_func)

    def _setup_reduce_functions(self, reduce_func):
        """Returns ``(reducer, arg_reducer)`` for ``reduce_func``.

        Known numpy reductions come from ``_REDUCE_DISPATCH``. Any other
        callable is wrapped so it can be called as ``f(x, axis=...)`` and is
        paired with ``np.ndarray.argmax``.
        """
        if reduce_func in self._REDUCE_DISPATCH:
            return self._REDUCE_DISPATCH[reduce_func]
        # Generic fallback for custom reduce functions.
        return (
            lambda x, axis: reduce_func(x, axis=axis),
            np.ndarray.argmax,  # Default to argmax
        )

    def _setup_combine_functions(self, combine_func):
        """Returns ``(axis_combiner, inverse, identity_factory)`` for ``combine_func``.

        Known numpy ufuncs come from ``_COMBINE_DISPATCH``. For any other
        callable, messages are folded pairwise with ``functools.reduce``.
        No inverse is available, so removals always recompute, and
        ``np.ones`` is used as the identity factory.
        """
        if combine_func in self._COMBINE_DISPATCH:
            return self._COMBINE_DISPATCH[combine_func]
        # Generic fallback.
        return (
            lambda x, axis: np.apply_along_axis(
                lambda arr: functools.reduce(combine_func, arr), axis, x
            ),
            None,  # No inverse function available.
            # NOTE(review): ones is only a true identity for multiplicative combines.
            np.ones,
        )

    def _can_use_inverse(self, message):
        """Returns True if the fast inverse can remove ``message`` from an aggregate.

        Removing a non-finite entry (``inf - inf``, ``inf / inf``) or dividing
        by a zero entry yields nan/inf. It does not give the combination of the
        remaining messages, so such messages must use the recomputation path.
        """
        if self._combine_inverse is None:
            return False
        values = np.asarray(message)
        if not np.all(np.isfinite(values)):
            return False
        if self._combine_inverse is np.divide and np.any(values == 0):
            return False
        return True

    def _remove_message_from_aggregate(
        self, agg, message_to_remove, all_messages, axis, cost_table=None
    ):
        """Returns the aggregate of every message except the one at ``axis``.

        Uses the fast inverse operation (subtraction for addition, division
        for multiplication) when it is numerically valid for
        ``message_to_remove``. Otherwise the aggregate is recomputed from
        ``all_messages`` (and ``cost_table``, if given) without that message.

        Args:
            agg (np.ndarray): The current aggregate array.
            message_to_remove (np.ndarray): The message data to remove.
            all_messages (list): All message data, used by the fallback.
            axis (int): The index of the message to remove.
            cost_table (np.ndarray, optional): The cost table for the fallback.

        Returns:
            np.ndarray: The aggregate with the message removed.
        """
        if self._can_use_inverse(message_to_remove):
            return self._combine_inverse(agg, message_to_remove)

        # Fallback: recompute the aggregate without this message.
        if cost_table is not None:
            temp_agg = cost_table.astype(agg.dtype, copy=True)
            for i, msg in enumerate(all_messages):
                if i != axis:
                    self.combine_func(temp_agg, msg, out=temp_agg)
        else:
            temp_agg = self._belief_identity(agg.shape).astype(agg.dtype)
            for i, msg in enumerate(all_messages):
                if i != axis:
                    temp_agg = self.combine_func(temp_agg, msg)
        return temp_agg

    def compute_Q(self, messages: List[Message]) -> List[Message]:
        """Computes outgoing messages from a variable node to factor nodes (Q messages).

        All incoming messages are combined once. Each outgoing message is then
        that aggregate with its recipient's own incoming message removed.

        Args:
            messages: A list of incoming messages from factor nodes.

        Returns:
            A list of computed messages to be sent to factor nodes, in the same
            order as ``messages``.
        """
        early = self._validate(messages=messages)
        if early is not None:
            return early

        variable = messages[0].recipient
        msg_data = np.stack([msg.data for msg in messages])
        total_combined = self._combine_axis(msg_data, axis=0)

        outgoing_messages = []
        for i, msg in enumerate(messages):
            combined_data = self._remove_message_from_aggregate(
                total_combined, msg_data[i], msg_data, i
            )
            outgoing_messages.append(
                Message(data=combined_data, sender=variable, recipient=msg.sender)
            )
        return outgoing_messages

    def compute_R(
        self, cost_table: np.ndarray, incoming_messages: List[Message]
    ) -> List[Message]:
        """Computes outgoing messages from a factor node to variable nodes (R messages).

        Steps:
          1. Broadcast each incoming Q message to the dimensionality of the
             cost table.
          2. Combine the cost table with all broadcast Q messages once.
          3. For each recipient, remove its Q message from the aggregate and
             reduce over every other axis to produce the outgoing R message.

        The working dtype is the numpy promotion of the cost table's dtype and
        every message's dtype. This prevents truncating float messages against
        an integer cost table.

        NOTE(review): ``incoming_messages[i]`` is assumed to correspond to
        axis ``i`` of ``cost_table``; confirm callers order them that way.

        Args:
            cost_table: The factor's cost table.
            incoming_messages: A list of incoming messages from variable nodes.

        Returns:
            A list of computed messages to be sent to variable nodes.
        """
        k = cost_table.ndim
        shape = cost_table.shape
        dtype = functools.reduce(
            np.promote_types,
            (np.asarray(msg.data).dtype for msg in incoming_messages),
            cost_table.dtype,
        )
        combine_func = self.combine_func
        reduce_msg = self._reduce_msg

        b_msgs = []
        axes_cache = []
        for axis, msg in enumerate(incoming_messages):
            q = np.asarray(msg.data, dtype=dtype)
            br = q.reshape([shape[axis] if i == axis else 1 for i in range(k)])
            b_msgs.append(br)
            # Axes to reduce over for the message going back to this variable.
            axes_cache.append(tuple(j for j in range(k) if j != axis))

        agg = cost_table.astype(dtype, copy=True)
        for q in b_msgs:
            combine_func(agg, q, out=agg)

        out = []
        for axis, broadcasted_q in enumerate(b_msgs):
            temp = self._remove_message_from_aggregate(
                agg, broadcasted_q, b_msgs, axis, cost_table
            )
            r_vec = reduce_msg(temp, axis=axes_cache[axis])
            out.append(
                Message(
                    data=r_vec,
                    sender=incoming_messages[axis].recipient,
                    recipient=incoming_messages[axis].sender,
                )
            )
        return out

    def _validate(self, messages=None, cost_table=None, incoming_messages=None):
        """Validates inputs and handles edge cases for compute methods.

        Args:
            messages: Incoming messages for a variable node (``compute_Q``).
            cost_table: Currently unused.
            incoming_messages: Incoming messages for a factor node. If the
                recipient factor has no ``connection_number`` map, one is built
                from the senders' names in list order.

        Returns:
            A list of messages to return immediately (empty input, or a single
            neighbour), or None when the caller should run the full computation.
        """
        if messages is not None:
            if not messages:
                return []
            if len(messages) == 1:
                # A lone neighbour receives the identity of the combine
                # operation (zeros for np.add, ones for np.multiply), not zeros
                # unconditionally.
                only = messages[0]
                data = np.asarray(only.data)
                return [
                    Message(
                        data=self._belief_identity(data.shape, dtype=data.dtype),
                        sender=only.recipient,
                        recipient=only.sender,
                    )
                ]
        if incoming_messages is not None:
            if not incoming_messages:
                return []
            factor = incoming_messages[0].recipient
            if not hasattr(factor, "connection_number") or not factor.connection_number:
                factor.connection_number = {
                    msg.sender.name: i for i, msg in enumerate(incoming_messages)
                }
        return None

    def _get_node_dimension(self, factor, node) -> int:
        """Returns the axis index of ``node`` in ``factor``'s cost table.

        Looks the node up in ``factor.connection_number`` and caches the result
        per ``(id(factor), node.name)``.

        Raises:
            KeyError: If the node is not in the factor's connection map.
        """
        # NOTE(review): the cache is keyed on id(factor); ids can be reused
        # after a factor is garbage-collected.
        cache_key = (id(factor), node.name)
        if cache_key in self._connection_cache:
            return self._connection_cache[cache_key]

        if (
            hasattr(factor, "connection_number")
            and factor.connection_number
            and node.name in factor.connection_number
        ):
            dim = factor.connection_number[node.name]
            self._connection_cache[cache_key] = dim
            return dim

        available_keys = list(getattr(factor, "connection_number", {}).keys())
        raise KeyError(
            f"Node '{node.name}' not found in factor '{factor.name}' connections. "
            f"Available connections: {available_keys}"
        )

    @staticmethod
    @lru_cache(maxsize=1024)
    def _get_broadcast_shape(ct_dim: int, sender_dim: int, msg_len: int) -> tuple:
        """Returns a shape of ``ct_dim`` ones with ``msg_len`` at ``sender_dim``.

        Declared as a staticmethod so the shared LRU cache is not keyed on
        computator instances and does not keep them alive.
        """
        shape = [1] * ct_dim
        shape[sender_dim] = msg_len
        return tuple(shape)

    def get_assignment(self, belief: np.ndarray) -> int:
        """Gets the optimal assignment from a belief vector.

        Uses the pre-selected ``_argreduce_func`` (e.g. ``argmin`` for Min-Sum,
        ``argmax`` otherwise).

        Args:
            belief: The belief vector.

        Returns:
            The index of the optimal assignment.
        """
        return int(self._argreduce_func(belief))

    def compute_belief(self, messages: List[Message], domain: int) -> np.ndarray:
        """Computes the belief of a variable node from incoming messages.

        Args:
            messages: A list of incoming messages.
            domain: The domain size of the variable.

        Returns:
            The combination of all message data, starting from the combine
            identity. With no messages, a uniform vector ``1 / domain``.
        """
        if not messages:
            return np.ones(domain) / domain

        belief = self._belief_identity(domain)
        for message in messages:
            belief = self.combine_func(belief, message.data)
        return belief
# Module-level alias: rebinds the name `Computator` (imported above from
# ..core.protocols) to the concrete BPComputator. Class definitions above this
# line already used the imported protocol; anything reading `Computator` from
# this module afterwards gets BPComputator.
# NOTE(review): this shadows the imported protocol name; confirm that is intended.
Computator: TypeAlias = BPComputator
class MinSumComputator(BPComputator):
    """Min-Sum belief propagation computator.

    Messages are combined with ``np.add`` and reduced with ``np.min``.
    Working in the log (cost) domain, this corresponds to searching for the
    Most Probable Explanation (MPE) of a graphical model.
    """

    def __init__(self):
        """Configure the base computator with min-reduce and add-combine."""
        super().__init__(np.min, np.add)
class MaxSumComputator(BPComputator):
    """Max-Sum belief propagation computator.

    Intended for problems whose goal is to maximize a sum of utilities:
    messages are combined with ``np.add`` and reduced with ``np.max``.
    """

    def __init__(self):
        """Configure the base computator with max-reduce and add-combine."""
        super().__init__(np.max, np.add)
class MaxProductComputator(BPComputator):
    """Max-Product belief propagation computator.

    Messages are combined with ``np.multiply`` and reduced with ``np.max``.
    This corresponds to finding the Most Probable Explanation (MPE) of a
    graphical model.
    """

    def __init__(self):
        """Configure the base computator with max-reduce and multiply-combine."""
        super().__init__(np.max, np.multiply)
class SumProductComputator(BPComputator):
    """Sum-Product belief propagation computator.

    Used to compute marginal probabilities: messages are combined with
    ``np.multiply`` and marginalized (reduced) with ``np.sum``.
    """

    def __init__(self):
        """Configure the base computator with sum-reduce and multiply-combine."""
        super().__init__(np.sum, np.multiply)