import random
from typing import Dict, Literal, Optional, Sequence
import networkx as nx
import numpy as np
from ..core.agents import FactorAgent, VariableAgent
from ..core.components import Message
from ..policies import damp, damp_factor
from ..policies.cost_reduction import cost_reduction_all_factors_once
from ..policies.splitting import split_all_factors, split_factors
from ..utils.inbox_utils import multiply_messages
from .engine_base import BPEngine
class Engine(BPEngine):
    """The standard belief propagation engine.

    A direct alias for `BPEngine`: no policy hooks are overridden, so it runs
    plain, unmodified belief propagation.
    """

    pass
class SplitEngine(BPEngine):
    """A BP engine that applies the factor splitting policy.

    This engine modifies the factor graph by splitting each factor into two,
    distributing the original cost between them. This can sometimes help with
    convergence.
    """

    def __init__(self, *args, split_factor: float = 0.5, **kwargs):
        """Initializes the SplitEngine.

        Args:
            *args: Positional arguments for the base `BPEngine`.
            split_factor: The proportion of the cost to allocate to the first
                of the two new factors. Defaults to 0.5.
            **kwargs: Keyword arguments for the base `BPEngine`.
        """
        self.split_factor = split_factor
        super().__init__(*args, **kwargs)
        self._name = "SPFGEngine"
        # The two clones receive `split_factor` and `1 - split_factor` shares
        # of the cost. The label previously rendered the same share twice
        # (e.g. "0.6-0.6"), which was only coincidentally right at the 0.5
        # default; it now matches DampingSCFGEngine's "p-(1-p)" convention.
        self._set_name({"split": f"{self.split_factor}-{1 - self.split_factor}"})

    def post_init(self) -> None:
        """Applies the factor splitting policy after initialization."""
        split_all_factors(self.graph, self.split_factor)
class MidRunSplitEngine(BPEngine):
    """A BP engine that applies factor splitting at a chosen iteration.

    The engine runs standard BP until ``split_at_iter``. Immediately before
    that step is computed, it splits the requested factors and either resets
    messages to zero or transfers current factor-to-variable messages into the
    split graph.

    The ``transfer`` mode is an empirical heuristic, not a canonical mid-run
    split. It preserves the aggregate factor-to-variable contribution at the
    variable inbox boundary (and therefore the variable belief, and the Q
    messages to unaffected neighbors). It does not preserve the Q messages
    going into the split clones themselves: each clone's Q picks up the
    sibling clone's transferred share, so each clone's first outgoing R is
    not a canonical continuation of the un-split dynamics. Treat the split
    iteration as a heuristic injection.
    """

    def __init__(
        self,
        *args,
        split_at_iter: int,
        split_factor: float = 0.5,
        split_targets: Sequence[str] | None = None,
        split_fraction: float | None = None,
        split_seed: int | None = None,
        transfer_mode: Literal["reset", "transfer"] = "reset",
        **kwargs,
    ) -> None:
        # Validate eagerly so misconfiguration fails before BPEngine setup.
        if split_at_iter < 0:
            raise ValueError("split_at_iter must be non-negative.")
        if transfer_mode not in {"reset", "transfer"}:
            raise ValueError("transfer_mode must be either 'reset' or 'transfer'.")
        self.split_at_iter = int(split_at_iter)
        self.split_factor = float(split_factor)
        # None means "no explicit target list" (resolution is delegated to
        # `split_factors`).
        self.split_targets = list(split_targets) if split_targets is not None else None
        self.split_fraction = split_fraction
        self.split_seed = split_seed
        self.transfer_mode = transfer_mode
        # Original factor name -> clone agents; populated when the split fires.
        self.split_mapping: Dict[str, list[FactorAgent]] = {}
        # One event dict per applied split, kept for post-run inspection.
        self.split_events: list[dict] = []
        self._split_applied = False
        # Event waiting to be attached to the snapshot of the split iteration.
        self._pending_split_event: dict | None = None
        super().__init__(*args, **kwargs)
        self._name = "MidRunSplitEngine"
        self._set_name(
            {
                "split_at": str(self.split_at_iter),
                "mode": self.transfer_mode,
                "split": str(self.split_factor),
            }
        )

    def step(self, i: int = 0):
        """Run one BP step, applying the mid-run split just before it when due."""
        # `>=` (not `==`) so the split still fires even if iteration indices
        # skip past the configured threshold.
        if not self._split_applied and i >= self.split_at_iter:
            self._apply_midrun_split(i)
        step = super().step(i)
        if self._pending_split_event is not None:
            # Attach split metadata to this iteration's snapshot, if one was
            # recorded for this index.
            snapshot = self._snapshots.get(i)
            if snapshot is not None:
                snapshot.metadata["split_event"] = dict(self._pending_split_event)
            self._pending_split_event = None
        return step

    def _apply_midrun_split(self, iteration: int) -> None:
        """Split the targeted factors and rewire message state accordingly."""
        # Capture R-messages before the graph mutates; `transfer` needs them.
        prior_var_inboxes = self._capture_variable_inboxes()
        self.split_mapping = split_factors(
            self.graph,
            self.split_factor,
            factor_names=self.split_targets,
            split_fraction=self.split_fraction,
            seed=self.split_seed,
        )
        self._split_applied = True
        # The node set changed: rebuild cached views and re-register the
        # computator on the mutated graph.
        self._refresh_graph_views()
        self.graph.set_computator(self.computator)
        if self.transfer_mode == "reset":
            self._reset_messages()
        elif self.transfer_mode == "transfer":
            self._transfer_messages(prior_var_inboxes)
        else:  # pragma: no cover - guarded in __init__
            raise ValueError(f"Unsupported transfer_mode: {self.transfer_mode}")
        event = {
            "iteration": int(iteration),
            "transfer_mode": self.transfer_mode,
            "split_factor": self.split_factor,
            "split_targets": self.split_targets,
            "split_fraction": self.split_fraction,
            "split_seed": self.split_seed,
            "split_mapping": {
                original: [clone.name for clone in clones]
                for original, clones in self.split_mapping.items()
            },
        }
        self.split_events.append(event)
        self._pending_split_event = event

    def _refresh_graph_views(self) -> None:
        """Rebuild sorted variable/factor node lists and the graph diameter."""
        var_set, factor_set = nx.bipartite.sets(self.graph.G)
        self.var_nodes = sorted(var_set, key=lambda node: node.name)
        self.factor_nodes = sorted(factor_set, key=lambda node: node.name)
        self.graph_diameter = nx.diameter(self.graph.G)

    def _capture_variable_inboxes(self) -> Dict[str, list[Message]]:
        """Snapshot each variable's inbox (message copies) keyed by name."""
        captured: Dict[str, list[Message]] = {}
        for variable in getattr(self, "var_nodes", []):
            captured[variable.name] = [msg.copy() for msg in variable.mailer.inbox]
        return captured

    def _clear_agent_state(self) -> None:
        """Empty every agent's mailbox, outbox, and (when present) history."""
        for node in self.graph.G.nodes():
            node.empty_mailbox()
            node.empty_outgoing()
            if hasattr(node, "_history"):
                node._history.clear()

    def _reset_messages(self) -> None:
        """Discard all message state and re-run the engine's message init."""
        self._clear_agent_state()
        self._initialize_messages()

    def _transfer_messages(self, prior_var_inboxes: Dict[str, list[Message]]) -> None:
        """Redistribute prior R messages across the split clones.

        For each split factor F with prior message ``R[F->X]``, the variable
        inbox is rewritten as ``p * R`` from ``F'`` and ``(1 - p) * R`` from
        ``F''``. This preserves the aggregate contribution
        ``R[F'->X] + R[F''->X] == R[F->X]_old``, so X's belief and its Q to
        unaffected neighbors are unchanged on the split iteration.

        The Q messages going into the clones themselves are not preserved:
        ``Q[X->F']`` excludes only ``F'`` from the inbox, so the sibling's
        transferred share ``(1 - p) * R_old`` leaks in and is reflected in
        each clone's first outgoing R. Treat ``transfer`` as a heuristic for
        preserving variable-side state, not as a canonical mid-run split.
        """
        self._clear_agent_state()
        factor_by_name = {factor.name: factor for factor in self.factor_nodes}
        variable_by_name = {variable.name: variable for variable in self.var_nodes}
        for variable_name, old_messages in prior_var_inboxes.items():
            variable = variable_by_name.get(variable_name)
            if variable is None:
                # Variable is absent from the rebuilt graph; nothing to do.
                continue
            for message in old_messages:
                sender_name = getattr(message.sender, "name", "")
                if sender_name in self.split_mapping:
                    # Split factor: weight the old R between the two clones.
                    clones = self.split_mapping[sender_name]
                    weights = [self.split_factor, 1.0 - self.split_factor]
                    for clone, weight in zip(clones, weights):
                        variable.mailer.receive_messages(
                            Message(
                                data=np.asarray(message.data, dtype=float) * weight,
                                sender=clone,
                                recipient=variable,
                            )
                        )
                elif sender_name in factor_by_name:
                    # Untouched factor: forward the old R unchanged.
                    variable.mailer.receive_messages(
                        Message(
                            data=np.copy(message.data),
                            sender=factor_by_name[sender_name],
                            recipient=variable,
                        )
                    )
        self._fill_missing_variable_messages()

    def _fill_missing_variable_messages(self) -> None:
        """Pad each variable inbox with zero messages from missing neighbors."""
        for variable in self.var_nodes:
            existing = {message.sender.name for message in variable.mailer.inbox}
            for neighbor in self.graph.G.neighbors(variable):
                if neighbor.name in existing:
                    continue
                variable.mailer.receive_messages(
                    Message(
                        data=np.zeros(variable.domain),
                        sender=neighbor,
                        recipient=variable,
                    )
                )
class CostReductionOnceEngine(BPEngine):
    """A BP engine that applies a one-time cost reduction policy.

    This engine reduces the costs in the factor tables at the beginning of the
    simulation and then applies a discount to outgoing messages from factors.
    """

    def __init__(self, *args, reduction_factor: float = 0.5, **kwargs):
        """Initializes the CostReductionOnceEngine.

        Args:
            *args: Positional arguments for the base `BPEngine`.
            reduction_factor: The factor by which to reduce costs and
                discount outgoing factor messages. Defaults to 0.5.
            **kwargs: Keyword arguments for the base `BPEngine`.
        """
        self.reduction_factor = reduction_factor
        super().__init__(*args, **kwargs)

    def post_init(self):
        """Applies the one-time cost reduction after initialization."""
        cost_reduction_all_factors_once(self.graph, self.reduction_factor)

    def post_factor_compute(self, factor: FactorAgent, iteration: int):
        """Applies the configured discount to outgoing messages from factors.

        The discount was previously hard-coded to 0.5, which silently ignored
        any non-default ``reduction_factor`` (the two only agreed at the
        default value).
        """
        multiply_messages(factor.outbox, self.reduction_factor)
class DampingEngine(BPEngine):
    """Belief propagation with temporal damping on Q-messages.

    Each newly computed variable -> factor message is blended with the one
    sent on the previous iteration, which helps suppress oscillations and can
    improve convergence.
    """

    def __init__(self, *args, damping_factor: float = 0.9, **kwargs):
        """Set up the engine.

        Args:
            *args: Positional arguments forwarded to `BPEngine`.
            damping_factor: Weight assigned to the previous message.
                Defaults to 0.9.
            **kwargs: Keyword arguments forwarded to `BPEngine`.
        """
        self.damping_factor = damping_factor
        super().__init__(*args, **kwargs)
        self._name = "DampingEngine"
        self._set_name({"damping": str(self.damping_factor)})

    def post_var_compute(self, var: VariableAgent):
        """Damp the freshly computed variable messages, then record them."""
        damp(var, self.damping_factor)
        var.append_last_iteration()
class QRDampingEngine(BPEngine):
    """Belief propagation with damping on both message directions.

    Q damping blends variable -> factor messages with their previous values;
    R damping does the same for factor -> variable messages.
    """

    def __init__(
        self,
        *args,
        q_damping_factor: float = 0.0,
        r_damping_factor: float = 0.0,
        **kwargs,
    ):
        self.q_damping_factor = float(q_damping_factor)
        self.r_damping_factor = float(r_damping_factor)
        # Snapshot tooling and existing expectations rely on a single
        # `damping_factor` attribute; expose whichever direction is active,
        # with Q taking precedence.
        if self.q_damping_factor > 0:
            self.damping_factor = self.q_damping_factor
        else:
            self.damping_factor = self.r_damping_factor
        super().__init__(*args, **kwargs)
        self._name = "QRDampingEngine"
        self._set_name(
            {
                "q_damping": str(self.q_damping_factor),
                "r_damping": str(self.r_damping_factor),
            }
        )

    def post_init(self) -> None:
        """Prime factor history with zero messages so R-damping applies from iter 0."""
        if self.r_damping_factor <= 0:
            return
        for factor in self.graph.factors:
            if factor._history:
                continue
            primed = [
                Message(sender=factor, recipient=nbr, data=np.zeros(factor.domain))
                for nbr in self.graph.G.neighbors(factor)
            ]
            factor._history.append(primed)

    def post_var_compute(self, var: VariableAgent) -> None:
        """Damp outgoing Q-messages (when enabled), then record them."""
        if self.q_damping_factor > 0:
            damp(var, self.q_damping_factor)
        var.append_last_iteration()

    def post_factor_compute(self, factor: FactorAgent, iteration: int) -> None:
        """Damp outgoing R-messages (when enabled), then record them."""
        if self.r_damping_factor > 0:
            damp_factor(factor, self.r_damping_factor)
        factor.append_last_iteration()
class RDampingEngine(BPEngine):
    """Belief propagation with temporal damping on R-messages.

    Whereas `DampingEngine` damps Q-messages (Variable -> Factor), this engine
    blends each newly computed Factor -> Variable message with the one sent on
    the previous iteration, which helps suppress oscillations and can improve
    convergence.
    """

    def __init__(self, *args, damping_factor: float = 0.9, **kwargs):
        """Set up the engine.

        Args:
            *args: Positional arguments forwarded to `BPEngine`.
            damping_factor: Weight assigned to the previous message.
                Defaults to 0.9.
            **kwargs: Keyword arguments forwarded to `BPEngine`.
        """
        self.damping_factor = damping_factor
        super().__init__(*args, **kwargs)
        self._name = "RDampingEngine"
        self._set_name({"damping": str(self.damping_factor)})

    def post_init(self) -> None:
        """Seed factor histories with zeros so damping is active from iteration 0."""
        for factor in self.graph.factors:
            self._init_agent_history(factor)

    def _init_agent_history(self, agent: FactorAgent) -> None:
        """Prime an agent's history with one round of zero messages (no-op if already primed)."""
        if agent._history:
            return
        primed = [
            Message(sender=agent, recipient=nbr, data=np.zeros(agent.domain))
            for nbr in self.graph.G.neighbors(agent)
        ]
        agent._history.append(primed)

    def post_factor_compute(self, factor: FactorAgent, iteration: int):
        """Damp the freshly computed factor messages, then record them."""
        damp_factor(factor, self.damping_factor)
        factor.append_last_iteration()
class DiffusionEngine(BPEngine):
    """A BP engine that applies spatial message diffusion.

    Damping blends messages across time (current vs previous); diffusion
    instead blends messages across space (local vs neighbors) within a single
    iteration. This can help smooth the optimization landscape and improve
    convergence on densely connected graphs.
    """

    def __init__(self, *args, alpha: float = 0.3, **kwargs):
        """Initializes the DiffusionEngine.

        Args:
            *args: Positional arguments for the base `BPEngine`.
            alpha: Diffusion coefficient in [0, 1]. Higher values = more smoothing.
                - alpha=0: no diffusion (pure BP)
                - alpha=0.1-0.3: recommended range for most problems
                - alpha=1: complete averaging (may lose local information)
                Defaults to 0.3.
            **kwargs: Keyword arguments for the base `BPEngine`.

        Raises:
            ValueError: If ``alpha`` is outside [0, 1].
        """
        if not 0 <= alpha <= 1:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        self.alpha = alpha
        super().__init__(*args, **kwargs)
        self._name = "DiffusionEngine"
        self._set_name({"alpha": str(self.alpha)})

    def _diffuse_outbox(self, sender) -> None:
        """Blend each outgoing message with same-recipient messages from sibling senders.

        For every message in ``sender.outbox``, the other nodes adjacent to
        that message's recipient are inspected; the first message each of them
        addresses to the same recipient contributes to a neighbor average,
        which is then mixed into the local message in place.
        """
        for msg in sender.outbox:
            recipient = msg.recipient
            peer_data = []
            for peer in self.graph.G.neighbors(recipient):
                if peer != sender:  # skip the sending node itself
                    # First message this peer addresses to the same recipient.
                    for peer_msg in peer.outbox:
                        if peer_msg.recipient == recipient:
                            peer_data.append(peer_msg.data)
                            break
            if peer_data:
                # Blend: (1 - alpha) * local + alpha * neighbor average.
                avg = np.mean(peer_data, axis=0)
                msg.data = (1 - self.alpha) * msg.data + self.alpha * avg

    def post_var_compute(self, var: VariableAgent) -> None:
        """Apply spatial diffusion to Q-messages (variable -> factor).

        Each message this variable sends to a factor is blended with the
        messages other variables connected to that factor are sending it.
        """
        if self.alpha == 0:
            return  # pure BP, nothing to blend
        self._diffuse_outbox(var)

    def post_factor_compute(self, factor: FactorAgent, iteration: int) -> None:
        """Apply spatial diffusion to R-messages (factor -> variable).

        Each message this factor sends to a variable is blended with the
        messages other factors connected to that variable are sending it.
        """
        if self.alpha == 0:
            return  # pure BP, nothing to blend
        self._diffuse_outbox(factor)
class DampingSCFGEngine(DampingEngine, SplitEngine):
    """A BP engine that combines message damping and factor splitting."""

    def __init__(self, *args, **kwargs):
        """Initializes the DampingSCFGEngine.

        Parameters are inherited from both `DampingEngine` and `SplitEngine`.

        Args:
            *args: Positional arguments for the base engines.
            **kwargs: Keyword arguments for the base engines (e.g.,
                `damping_factor`, `split_factor`).
        """
        kwargs.setdefault("split_factor", 0.6)
        kwargs.setdefault("damping_factor", 0.9)
        super().__init__(*args, **kwargs)
        # The setdefault above guarantees the key exists.
        self.split_factor = kwargs["split_factor"]
        self._name = "DampingSCFG"
        split_label = f"{self.split_factor}-{1 - self.split_factor}"
        self._set_name(
            {"split": split_label, "damping": str(self.damping_factor)}
        )
class DampingCROnceEngine(DampingEngine, CostReductionOnceEngine):
    """A BP engine that combines message damping and one-time cost reduction."""

    def __init__(self, *args, **kwargs):
        """Initializes the DampingCROnceEngine.

        Parameters are inherited from `DampingEngine` and
        `CostReductionOnceEngine`.

        Args:
            *args: Positional arguments for the base engines.
            **kwargs: Keyword arguments for the base engines (e.g.,
                `damping_factor`, `reduction_factor`).
        """
        kwargs.setdefault("reduction_factor", 0.5)
        kwargs.setdefault("damping_factor", 0.9)
        super().__init__(*args, **kwargs)
        # The setdefault above guarantees the key exists.
        self.reduction_factor = kwargs["reduction_factor"]
        self._name = "DampingCROnceEngine"
        reduction_label = f"{self.reduction_factor}-{1 - self.reduction_factor}"
        self._set_name(
            {"reduction": reduction_label, "damping": str(self.damping_factor)}
        )
class TRWEngine(BPEngine):
    """
    Tree-Reweighted Belief Propagation engine (Min-Sum variant).

    The engine keeps the standard Min-Sum computator but automatically:

    1. Samples spanning trees over the variable-only (primal) graph to
       estimate per-factor appearance probabilities ``rho_f``.
    2. Scales each factor's energy table by ``1 / rho_f`` before message
       computation so local costs match the TRW objective.
    3. Re-weights outgoing R-messages from factors by ``rho_f`` so that
       variable updates/beliefs operate on appropriately weighted costs.

    Rho sampling and scaling can be overridden by providing explicit
    ``factor_rhos`` (all > 0). Otherwise the engine performs end-to-end
    TRW reweighting using the current factor graph structure.
    """

    # Floor for sampled rhos: keeps 1/rho finite for factors that never
    # appeared in any sampled tree.
    DEFAULT_MIN_RHO = 1e-6

    def __init__(
        self,
        *args,
        factor_rhos: Optional[Dict[str, float]] = None,
        tree_sample_count: int = 250,
        tree_sampler_seed: Optional[int] = None,
        min_rho: float = DEFAULT_MIN_RHO,
        **kwargs,
    ) -> None:
        """
        Args:
            factor_rhos:
                Optional explicit mapping from factor name to rho_f > 0. When
                omitted, the engine estimates rhos via spanning-tree sampling.
            tree_sample_count:
                Number of spanning trees to sample when estimating rhos.
            tree_sampler_seed:
                Seed forwarded to the tree sampler for reproducibility.
            min_rho:
                Lower bound applied to sampled rhos to keep them strictly
                positive (important for stable scaling).
        """
        self.tree_sample_count = max(1, tree_sample_count)
        self.tree_sampler_seed = tree_sampler_seed
        self.min_rho = max(min_rho, self.DEFAULT_MIN_RHO)
        self._user_defined_rhos = bool(factor_rhos)
        self.factor_rhos: Dict[str, float] = dict(factor_rhos or {})
        super().__init__(*args, **kwargs)
        self._name = "TRWEngine"
        suffix = (
            "custom" if self._user_defined_rhos else f"trees-{self.tree_sample_count}"
        )
        self._set_name({"trw": suffix})

    def post_init(self) -> None:
        """
        Validate rho configuration, sample if needed, and scale costs.

        Called from BPEngine.__init__ after `self.graph` is set but before
        messages are initialized.
        """
        factors = getattr(self.graph, "factors", [])
        if not factors:
            return
        if not self.factor_rhos:
            self.factor_rhos = self._estimate_rhos_via_spanning_trees(factors)
        else:
            # Explicit rhos: any unspecified factor defaults to 1 (no reweighting).
            for factor in factors:
                self.factor_rhos.setdefault(factor.name, 1.0)
        for factor in factors:
            rho = self.factor_rhos.get(factor.name, 1.0)
            if rho <= 0:
                raise ValueError(
                    f"TRWEngine: rho for factor '{factor.name}' must be > 0, got {rho}"
                )
            self._scale_factor_cost_table(factor, rho)

    def post_factor_compute(self, factor: FactorAgent, iteration: int) -> None:
        """Scale outgoing R-messages by rho_f before they are sent."""
        rho = self.factor_rhos.get(factor.name, 1.0)
        if rho == 1.0 or not factor.mailer.outbox:
            return
        for msg in factor.mailer.outbox:
            msg.data = rho * msg.data

    # --- Internal helpers -------------------------------------------------
    def _scale_factor_cost_table(self, factor: FactorAgent, rho: float) -> None:
        """Reset to the original cost table (if saved) and divide by rho."""
        factor.save_original()
        base = (
            factor.original_cost_table
            if factor.original_cost_table is not None
            else factor.cost_table
        )
        if base is None:
            return
        factor.cost_table = base / rho

    def _estimate_rhos_via_spanning_trees(
        self, factors: list[FactorAgent]
    ) -> Dict[str, float]:
        """Compute rho_f by sampling spanning trees on the primal graph."""
        primal_graph, edge_to_factors = self._build_primal_graph()
        if (
            primal_graph.number_of_edges() == 0
            or primal_graph.number_of_nodes() == 0
            or not nx.is_connected(primal_graph)
        ):
            # Fallback: no usable topology information, keep uniform rhos.
            return {factor.name: 1.0 for factor in factors}
        counts = {factor.name: 0 for factor in factors}
        rng = random.Random(self.tree_sampler_seed)
        samples = max(1, self.tree_sample_count)
        for _ in range(samples):
            tree_edges = self._sample_spanning_tree(primal_graph, rng)
            for node_u, node_v in tree_edges:
                # Edge keys are stored with endpoints sorted (see _build_primal_graph).
                key = (node_u, node_v) if node_u <= node_v else (node_v, node_u)
                for factor in edge_to_factors.get(key, []):
                    counts[factor.name] += 1
        rhos: Dict[str, float] = {}
        for factor in factors:
            count = counts.get(factor.name, 0)
            rho = count / samples if count > 0 else 0.0
            if rho <= 0:
                # Never sampled: clamp to the floor rather than dividing by 0.
                rho = self.min_rho
            rhos[factor.name] = rho
        return rhos

    def _build_primal_graph(
        self,
    ) -> tuple[nx.Graph, Dict[tuple[str, str], list[FactorAgent]]]:
        """Construct the variable-only graph used for tree sampling."""
        graph = nx.Graph()
        variables = getattr(self.graph, "variables", [])
        graph.add_nodes_from(var.name for var in variables)
        edge_to_factors: Dict[tuple[str, str], list[FactorAgent]] = {}
        for factor in getattr(self.graph, "factors", []):
            var_names = sorted(factor.connection_number.keys())
            if len(var_names) != 2:
                # Hyper-edges are left unweighted (rho defaults to 1).
                continue
            edge_key = (var_names[0], var_names[1])
            graph.add_edge(*edge_key)
            edge_to_factors.setdefault(edge_key, []).append(factor)
            # Guard against multiple factors per variable pair by treating them
            # as parallel edges (each receives credit whenever the pair is picked).
        return graph, edge_to_factors

    def _sample_spanning_tree(
        self, graph: nx.Graph, rng: random.Random
    ) -> list[tuple[str, str]]:
        """
        Sample a spanning tree using Wilson's algorithm (loop-erased random walk).

        Wilson's method draws a uniformly random spanning tree without needing any
        determinant computations, avoiding numerical issues on large graphs.
        """
        nodes = list(graph.nodes())
        if not nodes:
            return []
        root = rng.choice(nodes)
        tree_nodes = {root}
        unvisited = set(nodes)
        unvisited.discard(root)
        tree_edges: set[tuple[str, str]] = set()
        while unvisited:
            start = rng.choice(tuple(unvisited))
            walk = [start]
            visited_index = {start: 0}
            current = start
            while current not in tree_nodes:
                neighbors = list(graph.neighbors(current))
                if not neighbors:
                    break  # disconnected; handled by caller beforehand
                nxt = rng.choice(neighbors)
                if nxt in visited_index:
                    # Erase the loop by truncating the walk. The erased nodes
                    # must also be removed from `visited_index` (bug fix):
                    # otherwise a later revisit of an erased node would find a
                    # stale index past the truncated walk, the truncating slice
                    # would be a no-op, the node would never be re-appended,
                    # and the returned edge set could fail to span the graph.
                    loop_start = visited_index[nxt]
                    for erased in walk[loop_start + 1 :]:
                        del visited_index[erased]
                    walk = walk[: loop_start + 1]
                else:
                    walk.append(nxt)
                    visited_index[nxt] = len(walk) - 1
                current = nxt
            # Add the walk to the tree (connecting to the existing tree at `current`).
            for u, v in zip(walk, walk[1:]):
                edge = (u, v) if u <= v else (v, u)
                tree_edges.add(edge)
                tree_nodes.add(u)
                tree_nodes.add(v)
                unvisited.discard(u)
                unvisited.discard(v)
            tree_nodes.add(current)
            unvisited.discard(current)
        return list(tree_edges)
class DampingTRWEngine(DampingEngine, TRWEngine):
    """A BP engine that combines TRW reweighting with message damping."""

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("damping_factor", 0.9)
        super().__init__(*args, **kwargs)
        # The setdefault above guarantees the key exists.
        self.damping_factor = kwargs["damping_factor"]
        if getattr(self, "_user_defined_rhos", False):
            trw_suffix = "custom"
        else:
            trw_suffix = f"trees-{self.tree_sample_count}"
        self._name = "DampingTRWEngine"
        self._set_name({"damping": str(self.damping_factor), "trw": trw_suffix})
class MessagePruningEngine(BPEngine):
    """A BP engine that applies a message pruning policy to reduce memory usage."""

    def __init__(
        self,
        *args,
        prune_threshold: float = 1e-4,
        min_iterations: int = 5,
        adaptive_threshold: bool = True,
        **kwargs,
    ):
        """Initializes the MessagePruningEngine.

        Args:
            *args: Positional arguments for the base `BPEngine`.
            prune_threshold: The threshold below which messages are pruned.
            min_iterations: The number of iterations to wait before pruning.
            adaptive_threshold: Whether to adapt the threshold dynamically.
            **kwargs: Keyword arguments for the base `BPEngine`.
        """
        # Store the policy configuration before BPEngine.__init__ runs, since
        # post_init (invoked during base construction) reads these attributes.
        self.prune_threshold = prune_threshold
        self.min_iterations = min_iterations
        self.adaptive_threshold = adaptive_threshold
        super().__init__(*args, **kwargs)

    def post_init(self) -> None:
        """Initializes and sets the message pruning policy on agent mailers."""
        # Imported here rather than at module level -- presumably to avoid a
        # circular import with ..policies; TODO confirm.
        from ..policies.message_pruning import MessagePruningPolicy
        self.pruning_policy = MessagePruningPolicy(
            prune_threshold=self.prune_threshold,
            min_iterations=self.min_iterations,
            adaptive_threshold=self.adaptive_threshold,
        )
        # NOTE(review): the lines visible here only construct the policy; the
        # attachment to agent mailers promised by the docstring is not shown
        # in this excerpt -- it may follow below. Confirm against the full file.