"""A Policy for Pruning Redundant Messages in Belief Propagation.
This module provides a policy that can be used to reduce the number of
messages that are processed and stored in a belief propagation simulation.
It works by comparing a new incoming message to the previous message from the
same sender and discarding the new one if the change is below a certain
threshold. This can significantly reduce memory usage and computation time in
simulations where messages quickly stabilize.
"""
from typing import Dict
import numpy as np
from ..core.agents import FGAgent
from ..core.components import Message
from ..core.protocols import PolicyType
from ..configs.global_config_mapping import POLICY_DEFAULTS
from .bp_policies import Policy
import logging
logger = logging.getLogger(__name__)
[docs]
class MessagePruningPolicy(Policy):
"""A policy that prunes messages that have not changed significantly.
This policy helps to optimize belief propagation by avoiding the processing
of messages that are redundant. It compares the norm of the difference
between a new message and the previous message from the same sender.
Attributes:
prune_threshold (float): The threshold for pruning.
min_iterations (int): The number of initial iterations during which
no pruning will occur.
adaptive_threshold (bool): If True, the threshold is scaled by the
magnitude of the message.
iteration_count (int): The current iteration number.
pruned_count (int): The total number of pruned messages.
total_count (int): The total number of messages considered for pruning.
"""
def __init__(
self,
prune_threshold: float = None,
min_iterations: int = 5,
adaptive_threshold: bool = True,
):
"""Initializes the MessagePruningPolicy.
Args:
prune_threshold: The base threshold for pruning. If the change
is less than this, the message is pruned. Defaults to the
value in `POLICY_DEFAULTS`.
min_iterations: The number of iterations to run before pruning
begins. Defaults to 5.
adaptive_threshold: Whether to use an adaptive threshold that
scales with the message magnitude. Defaults to True.
"""
super().__init__(PolicyType.MESSAGE)
self.prune_threshold = (
prune_threshold
if prune_threshold is not None
else POLICY_DEFAULTS["pruning_threshold"]
)
self.min_iterations = min_iterations
self.adaptive_threshold = adaptive_threshold
self.iteration_count = 0
self.pruned_count = 0
self.total_count = 0
[docs]
def should_accept_message(self, agent: FGAgent, new_message: Message) -> bool:
"""Determines whether to accept or prune an incoming message.
Args:
agent: The agent receiving the message.
new_message: The new `Message` object.
Returns:
True if the message should be accepted, False if it should be pruned.
"""
self.total_count += 1
if self.iteration_count < self.min_iterations:
return True
prev_message = agent.mailer[new_message.sender.name]
if prev_message is None:
return True
diff_norm = np.linalg.norm(new_message.data - prev_message.data)
threshold = self.prune_threshold
if self.adaptive_threshold:
msg_magnitude = np.linalg.norm(new_message.data)
threshold *= max(1.0, msg_magnitude * POLICY_DEFAULTS["pruning_magnitude_factor"])
if diff_norm < threshold:
self.pruned_count += 1
logger.debug(
f"Pruned message {new_message.sender.name} -> "
f"{new_message.recipient.name}, diff: {diff_norm:.6f}"
)
return False
return True
[docs]
def step_completed(self) -> None:
"""Signals that a simulation step has completed, incrementing the iteration count."""
self.iteration_count += 1
[docs]
def get_stats(self) -> Dict[str, float]:
"""Returns statistics about the pruning process.
Returns:
A dictionary containing the pruning rate, total messages considered,
number of pruned messages, and total iterations.
"""
return {
"pruning_rate": self.pruned_count / max(self.total_count, 1),
"total_messages": self.total_count,
"pruned_messages": self.pruned_count,
"iterations": self.iteration_count,
}
[docs]
def reset(self) -> None:
"""Resets the policy's internal state for a new simulation run."""
self.iteration_count = 0
self.pruned_count = 0
self.total_count = 0