Source code for propflow.core.components

from __future__ import annotations
from typing import Optional, Dict
import numpy as np
from typing import List, TypeAlias, TYPE_CHECKING

from .dcop_base import Agent

if TYPE_CHECKING:
    from .agents import FGAgent

# Type alias used in annotations to mark cost tables. At runtime it is simply
# ``numpy.ndarray`` and adds no checking of its own.
CostTable: TypeAlias = np.ndarray


class Message:
    """A single message travelling between two agents during belief propagation.

    Attributes:
        data (np.ndarray): Payload of the message, typically a numpy array of
            costs or beliefs.
        sender (Agent): Agent that produced the message.
        recipient (Agent): Agent the message is addressed to.

    Equality and hashing look only at ``sender.name`` and ``recipient.name``;
    the payload in ``data`` is ignored.
    """

    def __init__(self, data: np.ndarray, sender: Agent, recipient: Agent):
        """Bundle a payload with its sender and recipient.

        Args:
            data (np.ndarray): The message payload.
            sender (Agent): Originating agent.
            recipient (Agent): Destination agent.
        """
        self.data = data
        self.sender = sender
        self.recipient = recipient

    def copy(self) -> Message:
        """Return a new message carrying an independent copy of ``data``.

        Only the payload is duplicated (via ``np.copy``); ``sender`` and
        ``recipient`` still reference the very same agent objects.

        Returns:
            Message: The duplicated message.
        """
        duplicated_payload = np.copy(self.data)
        return Message(
            data=duplicated_payload,
            sender=self.sender,
            recipient=self.recipient,
        )

    def __hash__(self) -> int:
        """Hash on the (sender name, recipient name) pair, ignoring ``data``."""
        endpoint_names = (self.sender.name, self.recipient.name)
        return hash(endpoint_names)

    def __eq__(self, other: object) -> bool:
        """Two messages are equal when sender and recipient names both match."""
        if not isinstance(other, Message):
            return NotImplemented
        same_sender = self.sender.name == other.sender.name
        return same_sender and self.recipient.name == other.recipient.name

    def __ne__(self, other: object) -> bool:
        """Negation of ``__eq__`` (goes through ``==`` so reflection still applies)."""
        are_equal = self == other
        return not are_equal

    def __str__(self) -> str:
        """Human-readable summary naming both endpoints and the payload."""
        return f"Message from {self.sender.name} to {self.recipient.name}: {self.data}"

    def __repr__(self) -> str:
        """Same text as ``__str__``."""
        return self.__str__()
class MailHandler:
    """Handles message passing with deduplication and synchronization.

    Manages an agent's incoming and outgoing messages. Only the latest message
    from each sender is kept; senders are keyed by ``name`` *and* ``type`` so
    that agents sharing a name do not overwrite each other.

    Attributes:
        pruning_policy: Optional object exposing
            ``should_accept_message(owner, message) -> bool``. When set, any
            incoming message it rejects is dropped. ``None`` disables pruning.
    """

    def __init__(self, _domain_size: int):
        """Initializes the MailHandler.

        Args:
            _domain_size (int): Length of the zero vectors created by
                ``set_first_message``.
        """
        self._message_domain_size = _domain_size
        self._incoming: Dict[str, Message] = {}  # key: sender key, value: latest message
        self._outgoing: List[Message] = []
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self._clear_after_staging = True
        # Initialised up front so receive_messages need not probe with hasattr().
        self.pruning_policy = None

    def set_pruning_policy(self, policy) -> None:
        """Install a message pruning policy (``None`` removes it).

        Args:
            policy: An object with a ``should_accept_message(owner, message)``
                method, or ``None``.
        """
        self.pruning_policy = policy

    def _make_key(self, agent: Agent) -> str:
        """Build the inbox key for ``agent`` from its name and type.

        Args:
            agent (Agent): The agent for which to create a key.

        Returns:
            str: ``"<name>_<type>"``.
        """
        return f"{agent.name}_{agent.type}"

    def set_first_message(self, owner: FGAgent, neighbor: FGAgent) -> None:
        """Seed the inbox with an all-zero message from ``neighbor`` to ``owner``.

        This guarantees that a message from every neighbor is present before
        computation begins. Any existing message from that neighbor is replaced.

        Args:
            owner (FGAgent): The agent who owns this mail handler.
            neighbor (FGAgent): The neighbor the placeholder message comes from.
        """
        key = self._make_key(neighbor)
        self._incoming[key] = Message(
            np.zeros(self._message_domain_size),
            neighbor,
            owner,
        )

    def receive_messages(
        self, messages: Message | list[Message] | tuple[Message, ...]
    ) -> None:
        """Receive one message or a list/tuple of messages.

        Each message is offered to the pruning policy (if any). An accepted
        message overwrites any previous message from the same sender.

        Args:
            messages: A single ``Message``, or a list or tuple of them.
        """
        if isinstance(messages, (list, tuple)):
            for message in messages:
                self.receive_messages(message)
            return

        message = messages
        # getattr keeps this safe for instances built without __init__.
        policy = getattr(self, "pruning_policy", None)
        if policy is not None and not policy.should_accept_message(
            message.recipient, message
        ):
            return  # message pruned

        self._incoming[self._make_key(message.sender)] = message

    def send(self) -> None:
        """Deliver every staged outgoing message to its recipient's mailer."""
        for message in self._outgoing:
            message.recipient.mailer.receive_messages(message)

    def stage_sending(self, messages: List[Message]) -> None:
        """Stage messages to be sent on the next ``send()``.

        The messages are copied into a new internal list, so later clearing of
        the outbox never mutates the caller's container.

        Args:
            messages: The messages to send, as any iterable (list, tuple, ...).
        """
        self._outgoing = list(messages)

    def prepare(self) -> None:
        """Clears the outbox, typically after messages have been sent."""
        self._outgoing.clear()

    def clear_inbox(self) -> None:
        """Clears all messages from the inbox."""
        self._incoming.clear()

    def clear_outgoing(self) -> None:
        """Clears all messages from the outbox."""
        self._outgoing.clear()

    @property
    def inbox(self) -> List[Message]:
        """list[Message]: A snapshot list of the incoming messages."""
        return list(self._incoming.values())

    @inbox.setter
    def inbox(self, li: List[Message]) -> None:
        """Replace the inbox with ``li``, keyed by each message's sender.

        Args:
            li (List[Message]): Messages to populate the inbox with.
        """
        self._incoming.clear()
        for msg in li:
            self._incoming[self._make_key(msg.sender)] = msg

    @property
    def outbox(self) -> List[Message]:
        """list[Message]: The internal list of outgoing messages."""
        return self._outgoing

    @outbox.setter
    def outbox(self, li: List[Message]) -> None:
        """Replace the outbox with a copy of ``li``.

        The list is copied, as in ``stage_sending``, because ``prepare()`` and
        ``clear_outgoing()`` clear the internal list in place and would
        otherwise empty the caller's list too.

        Args:
            li (List[Message]): Messages to populate the outbox with.
        """
        self._outgoing = list(li)

    def __getitem__(self, sender_name: str) -> Optional[Message]:
        """Return the first inbox message whose sender has ``sender_name``.

        Args:
            sender_name (str): The name of the sender.

        Returns:
            The matching ``Message``, or ``None`` if there is none.
        """
        for msg in self._incoming.values():
            if msg.sender.name == sender_name:
                return msg
        return None

    def __len__(self) -> int:
        """Returns the number of messages in the inbox."""
        return len(self._incoming)

    def __iter__(self):
        """Iterate over a snapshot of the messages in the inbox."""
        return iter(self.inbox)