"""A Policy for Pruning Redundant Messages in Belief Propagation.
This module provides a policy that can be used to reduce the number of
messages that are processed and stored in a belief propagation simulation.
It works by comparing a new incoming message to the previous message from the
same sender and discarding the new one if the change is below a certain
threshold. This can significantly reduce memory usage and computation time in
simulations where messages quickly stabilize.
"""
import logging
from typing import Dict, Optional

import numpy as np

from ..configs.global_config_mapping import PolicyDefaults
from ..core.agents import FGAgent
from ..core.components import Message
from ..core.protocols import PolicyType
logger = logging.getLogger(__name__)
class MessagePruningPolicy:
    """A policy that prunes messages that have not changed significantly.

    This policy helps to optimize belief propagation by avoiding the processing
    of messages that are redundant. It compares the norm of the difference
    between a new message and the previous message from the same sender
    against a threshold and rejects the new message when the change is smaller.

    Attributes:
        policy_type: The type of the policy (PolicyType.MESSAGE).
        prune_threshold: The base threshold for pruning.
        min_iterations: The number of initial iterations during which
            no pruning will occur.
        adaptive_threshold: If True, the threshold is scaled up for messages
            with a large norm (the scale factor is never below 1.0).
        iteration_count: The current iteration number.
        pruned_count: The total number of pruned messages.
        total_count: The total number of messages passed to
            `should_accept_message`, including those seen during the
            warm-up iterations.
    """

    def __init__(
        self,
        prune_threshold: Optional[float] = None,
        min_iterations: int = 5,
        adaptive_threshold: bool = True,
    ):
        """Initializes the MessagePruningPolicy.

        Args:
            prune_threshold: The base threshold for pruning. If the change
                is less than this, the message is pruned. When None, the
                value of `PolicyDefaults.PRUNING_THRESHOLD` is used.
            min_iterations: The number of iterations to run before pruning
                begins. Defaults to 5.
            adaptive_threshold: Whether to use an adaptive threshold that
                scales with the message magnitude. Defaults to True.
        """
        self.policy_type = PolicyType.MESSAGE
        # NOTE(review): this default is read as an enum-style class member,
        # while `should_accept_message` reads `pruning_magnitude_factor` from
        # a `PolicyDefaults()` instance -- confirm both access styles are
        # valid for the current `PolicyDefaults` definition.
        self.prune_threshold = (
            prune_threshold
            if prune_threshold is not None
            else PolicyDefaults.PRUNING_THRESHOLD.value
        )
        self.min_iterations = min_iterations
        self.adaptive_threshold = adaptive_threshold
        self.iteration_count = 0
        self.pruned_count = 0
        self.total_count = 0

    def should_accept_message(self, agent: "FGAgent", new_message: "Message") -> bool:
        """Determines whether to accept or prune an incoming message.

        Every call increments `total_count`; a pruned message also increments
        `pruned_count`.

        Args:
            agent: The agent receiving the message. Its `mailer` is indexed
                by the sender's name to fetch the previous message.
            new_message: The new `Message` object.

        Returns:
            True if the message should be accepted, False if it should be pruned.
        """
        self.total_count += 1

        # Warm-up phase: accept everything until enough iterations have run.
        if self.iteration_count < self.min_iterations:
            return True

        prev_message = agent.mailer[new_message.sender.name]
        if prev_message is None:
            # Nothing to compare against, so the message cannot be redundant.
            return True

        diff_norm = np.linalg.norm(new_message.data - prev_message.data)

        threshold = self.prune_threshold
        if self.adaptive_threshold:
            msg_magnitude = np.linalg.norm(new_message.data)
            # Scale the threshold with the message magnitude. The factor is
            # clamped to at least 1.0, so the effective threshold is never
            # smaller than `prune_threshold`; it only grows for messages
            # whose scaled magnitude exceeds 1.0.
            dynamic_factor = max(
                1.0, msg_magnitude * PolicyDefaults().pruning_magnitude_factor
            )
            threshold *= dynamic_factor

        if diff_norm < threshold:
            self.pruned_count += 1
            # Lazy %-style arguments: the string is only formatted when the
            # DEBUG level is actually enabled.
            logger.debug(
                "Pruned message %s -> %s, diff: %.6f",
                new_message.sender.name,
                new_message.recipient.name,
                diff_norm,
            )
            return False
        return True

    def step_completed(self) -> None:
        """Signals that a simulation step has completed, incrementing the iteration count."""
        self.iteration_count += 1

    def get_stats(self) -> Dict[str, float]:
        """Returns statistics about the pruning process.

        Returns:
            A dictionary containing the pruning rate, total messages considered,
            number of pruned messages, and total iterations. The pruning rate
            is 0.0 when no messages have been seen yet.
        """
        return {
            # max(..., 1) guards against division by zero before any message.
            "pruning_rate": self.pruned_count / max(self.total_count, 1),
            "total_messages": self.total_count,
            "pruned_messages": self.pruned_count,
            "iterations": self.iteration_count,
        }

    def reset(self) -> None:
        """Resets the policy's counters for a new simulation run.

        Configuration (`prune_threshold`, `min_iterations`,
        `adaptive_threshold`) is left untouched.
        """
        self.iteration_count = 0
        self.pruned_count = 0
        self.total_count = 0