# Source code for propflow.core.agents

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Callable, Any, Dict

import numpy as np

from .components import Message, CostTable, MailHandler
from .dcop_base import Agent


class FGAgent(Agent, ABC):
    """Abstract base class for belief propagation (BP) nodes.

    Extends the `Agent` class with methods relevant to message passing,
    updating local belief, and retrieving that belief. It serves as a
    foundation for both `VariableAgent` and `FactorAgent` classes.

    Attributes:
        domain (int): The size of the variable domain.
        mailer (MailHandler): Handles incoming and outgoing messages.
    """

    def __init__(self, name: str, node_type: str, domain: int):
        """Initializes an FGAgent.

        Args:
            name (str): The name of the agent.
            node_type (str): The type of the node (e.g., 'variable', 'factor').
            domain (int): The size of the variable domain.
        """
        super().__init__(name, node_type)
        self.domain = domain
        # Snapshots of the outbox, oldest first (see `append_last_iteration`).
        self._history: List[List[Message]] = []
        self._max_history = 10  # Limit history size to prevent memory issues
        self.mailer = MailHandler(domain)

    def receive_message(self, message: Message) -> None:
        """Receives a message and adds it to the mailer's inbox.

        Args:
            message (Message): The message to be received.
        """
        self.mailer.receive_messages(message)

    def send_message(self, message: Message) -> None:
        """Sends outgoing messages via the mailer.

        Args:
            message (Message): The message to be sent.
                NOTE(review): this argument is not used. The call delegates to
                ``self.mailer.send()`` with no arguments, which presumably
                dispatches whatever the mailer has staged. Confirm against
                `MailHandler`.
        """
        self.mailer.send()

    def empty_mailbox(self) -> None:
        """Clears all messages from the mailer's inbox."""
        self.mailer.clear_inbox()

    def empty_outgoing(self) -> None:
        """Clears all messages from the mailer's outbox."""
        self.mailer.clear_outgoing()

    @property
    def inbox(self) -> List[Message]:
        """list[Message]: A list of incoming messages."""
        return self.mailer.inbox

    @property
    def outbox(self) -> List[Message]:
        """list[Message]: A list of outgoing messages."""
        return self.mailer.outbox

    @abstractmethod
    def compute_messages(self) -> List[Message]:
        """Abstract method to compute outgoing messages.

        This method must be implemented by subclasses to define how messages
        are calculated based on the agent's current state and incoming
        messages.

        Returns:
            A list of messages to be sent.
        """
        pass

    @property
    def last_iteration(self) -> List[Message]:
        """list[Message]: The most recent outbox snapshot, or ``[]`` if none."""
        if not self._history:
            return []
        return self._history[-1]

    def last_cycle(self, diameter: int = 1) -> List[Message]:
        """Retrieves the outbox snapshot from ``diameter`` iterations back.

        Args:
            diameter (int): How many stored iterations to look back; ``1`` is
                the most recent snapshot. Defaults to 1.

        Returns:
            The snapshot from the requested iteration. Returns ``[]`` when
            fewer than ``diameter`` snapshots are stored. At most
            ``_max_history`` snapshots are ever retained.

        Raises:
            ValueError: If ``diameter`` is less than 1.
        """
        if diameter < 1:
            # Without this guard, ``-0 == 0`` would silently return the OLDEST
            # snapshot, and negative values would index from the front.
            raise ValueError(f"diameter must be >= 1, got {diameter}")
        if diameter > len(self._history):
            # Not enough history yet (or the request exceeds the cap). Behave
            # like the empty-history case instead of raising IndexError.
            return []
        return self._history[-diameter]

    def append_last_iteration(self) -> None:
        """Appends a copy of the current outbox to the history.

        Each message is stored via ``msg.copy()``. Once the history exceeds
        `_max_history` entries, the oldest snapshot is dropped.
        """
        self._history.append([msg.copy() for msg in self.mailer.outbox])
        if len(self._history) > self._max_history:
            self._history.pop(0)  # Remove oldest to maintain size limit
class VariableAgent(FGAgent):
    """Represents a variable node in a factor graph.

    This agent aggregates messages from neighboring factor nodes to compute
    its belief over its domain.

    Attributes:
        computator: An object that handles the computation of messages and
            beliefs. NOTE(review): not assigned in this class or in
            `FGAgent.__init__`. Presumably provided by the `Agent` base class
            or set externally; confirm.
    """

    def __init__(self, name: str, domain: int):
        """Initializes a VariableAgent.

        Args:
            name (str): The name of the variable (e.g., 'x1').
            domain (int): The size of the variable's domain.
        """
        node_type = "variable"
        super().__init__(name, node_type, domain)

    def compute_messages(self) -> None:
        """Computes outgoing messages to factor nodes.

        When a computator is set and the inbox is non-empty, passes the inbox
        to ``computator.compute_Q`` and stages the result on the mailer via
        ``stage_sending``. Otherwise does nothing.
        """
        if self.computator and self.mailer.inbox:
            messages = self.computator.compute_Q(self.mailer.inbox)
            self.mailer.stage_sending(messages)

    @property
    def belief(self) -> np.ndarray:
        """np.ndarray: The current belief over the variable's domain.

        Delegates to ``computator.compute_belief(inbox, domain)`` when that
        method exists. Otherwise it falls back as follows:

        - With an empty inbox, returns a uniform vector of length ``domain``.
        - Otherwise, returns the element-wise sum of every incoming message's
          ``data``.
        """
        if self.computator and hasattr(self.computator, "compute_belief"):
            return self.computator.compute_belief(self.inbox, self.domain)
        # Fallback when the computator offers no belief method. Adding the
        # messages is the Min-Sum (cost / log-domain) belief update, matching
        # the argmin fallback in `curr_assignment`. Probability-domain
        # sum-product would multiply instead.
        if not self.inbox:
            return np.ones(self.domain) / self.domain  # Uniform belief
        belief = np.zeros(self.domain)
        for message in self.inbox:
            belief += message.data
        return belief

    @property
    def curr_assignment(self) -> int | float:
        """int | float: The current assignment for the variable.

        Delegates to ``computator.get_assignment(belief)`` when available.
        Otherwise returns the index of the minimum belief entry (Min-Sum).
        """
        if self.computator and hasattr(self.computator, "get_assignment"):
            return self.computator.get_assignment(self.belief)
        # Fallback to default MinSum behavior if no computator support
        return int(np.argmin(self.belief))

    def __str__(self) -> str:
        """Returns the uppercase name of the agent."""
        return self.name.upper()

    def __repr__(self) -> str:
        """Returns a string representation of the VariableAgent."""
        return f"VariableAgent({self.name}, domain={self.domain})"
class FactorAgent(FGAgent):
    """A factor node of a factor graph.

    Holds the cost (or utility) function defined over the variable nodes it
    is connected to, and produces messages for those variables from it.

    Attributes:
        cost_table (CostTable | None): Cost of each joint assignment of the
            connected variables; ``None`` until supplied or created.
        connection_number (dict[str, int]): Variable name -> dimension index
            in ``cost_table``.
        ct_creation_func (Callable): Factory invoked by `initiate_cost_table`.
        ct_creation_params (dict): Keyword arguments forwarded to the factory.
    """

    def __init__(
        self,
        name: str,
        domain: int,
        ct_creation_func: Callable,
        param: Dict[str, Any] | None = None,
        cost_table: CostTable | None = None,
    ):
        """Builds a FactorAgent.

        Args:
            name (str): The name of the factor (e.g., 'f12').
            domain (int): The size of the variable domain.
            ct_creation_func (Callable): Factory used to generate the cost
                table.
            param (dict, optional): Keyword arguments for
                ``ct_creation_func``. ``None`` means an empty dict.
            cost_table (CostTable, optional): A pre-built cost table. A copy
                of it is stored, not the object itself.
        """
        super().__init__(name, "factor", domain)
        self.cost_table = cost_table.copy() if cost_table is not None else None
        self.connection_number: Dict[str, int] = {}  # var_name -> dimension
        self.ct_creation_func = ct_creation_func
        self.ct_creation_params = {} if param is None else param
        self._original: np.ndarray | None = None

    @classmethod
    def create_from_cost_table(cls, name: str, cost_table: CostTable) -> FactorAgent:
        """Alternate constructor wrapping an already-built cost table.

        The domain size is taken from the length of the table's first axis.
        The creation factory simply hands back ``cost_table`` and ignores
        any arguments it receives.

        Args:
            name (str): The name of the factor.
            cost_table (CostTable): The cost table to use.

        Returns:
            A new `FactorAgent` instance.
        """

        def _return_given_table(*_args, **_kwargs):
            return cost_table

        return cls(
            name=name,
            domain=cost_table.shape[0],
            ct_creation_func=_return_given_table,
            param=None,
            cost_table=cost_table,
        )

    def compute_messages(self) -> None:
        """Computes messages destined for the connected variable nodes.

        Does nothing unless a computator is set, a cost table exists and the
        inbox is non-empty. When all hold, it calls ``computator.compute_R``
        with the cost table and inbox, then stages the result on the mailer.
        """
        if not (self.computator and self.cost_table is not None and self.inbox):
            return
        outgoing = self.computator.compute_R(
            cost_table=self.cost_table, incoming_messages=self.inbox
        )
        self.mailer.stage_sending(outgoing)

    def initiate_cost_table(self) -> None:
        """Builds ``cost_table`` by calling the creation factory.

        The factory is called as
        ``ct_creation_func(n_connected_vars, domain, **ct_creation_params)``.

        Raises:
            ValueError: If a cost table is already present, or if no variable
                connections have been registered.
        """
        if self.cost_table is not None:
            raise ValueError("Cost table already exists. Cannot create a new one.")
        if not self.connection_number:
            raise ValueError("No connections set. Cannot create cost table.")
        self.cost_table = self.ct_creation_func(
            len(self.connection_number), self.domain, **self.ct_creation_params
        )

    def set_dim_for_variable(self, variable: VariableAgent, dim: int) -> None:
        """Records which cost-table dimension belongs to ``variable``.

        Args:
            variable (VariableAgent): The variable agent to map; keyed by its
                ``name``.
            dim (int): The dimension index in the cost table.
        """
        self.connection_number[variable.name] = dim

    def set_name_for_factor(self) -> None:
        """Renames the factor from the names of its connected variables.

        Connected names are sorted as strings. For every name starting with
        ``"x"``, the text after that prefix is concatenated, giving
        ``"f<suffixes>_"``. Names without the ``"x"`` prefix are skipped.

        Raises:
            ValueError: If no connections are set.
        """
        if not self.connection_number:
            raise ValueError("No connections set. Cannot set name.")
        var_indices = [
            var_name[1:]
            for var_name in sorted(self.connection_number)
            if var_name.startswith("x")
        ]
        self.name = f"f{''.join(var_indices)}_"

    def save_original(self, ct: CostTable | None = None) -> None:
        """Snapshots a cost table into ``_original`` exactly once.

        Does nothing if a snapshot already exists or if ``cost_table`` is
        ``None``. Otherwise copies ``ct`` when it is given, else copies
        ``cost_table``.

        Args:
            ct (CostTable, optional): An external cost table to save instead
                of the current one.
        """
        if self._original is not None or self.cost_table is None:
            return
        self._original = np.copy(self.cost_table if ct is None else ct)

    @property
    def mean_cost(self) -> float:
        """float: Mean of all cost-table entries (``0.0`` with no table)."""
        return 0.0 if self.cost_table is None else float(np.mean(self.cost_table))

    @property
    def total_cost(self) -> float:
        """float: Sum of all cost-table entries (``0.0`` with no table)."""
        return 0.0 if self.cost_table is None else float(np.sum(self.cost_table))

    @property
    def original_cost_table(self) -> np.ndarray | None:
        """np.ndarray | None: The snapshot taken by `save_original`, if any."""
        return self._original

    def __str__(self) -> str:
        """Returns the agent's name in upper case."""
        return self.name.upper()

    def __repr__(self) -> str:
        """Returns a debug representation listing connected variable names."""
        return f"FactorAgent({self.name}, connections={list(self.connection_number.keys())})"