Source code for propflow.policies.splitting

"""A Policy for Splitting Factors in a Factor Graph.

This module provides functions that implement the factor splitting policy.
This technique modifies the structure of the factor graph by replacing each
selected factor (by default, every factor) with two new "cloned" factors. The
original cost table is distributed between these two clones. This can be useful
for altering the message-passing dynamics and can sometimes help with
convergence or finding better solutions.
"""

from __future__ import annotations

from copy import deepcopy
import math
import random
from typing import Dict, List, Sequence

import networkx as nx

from ..bp.factor_graph import FactorGraph
from ..configs.global_config_mapping import PolicyDefaults
from ..core.agents import FactorAgent


def _split_factors(
    fg: FactorGraph, factors: List[FactorAgent], p: float
) -> Dict[str, List[FactorAgent]]:
    """Replace each given factor with two clones holding ``p*C`` and ``(1-p)*C``.

    For every factor ``f`` with cost table ``C`` two new factors are built via
    ``f.create_from_cost_table``: ``f'`` with ``p * C`` and ``f''`` with
    ``(1 - p) * C``. Both clones receive a deep copy of ``f.connection_number``,
    the node attributes of ``f`` and an edge (with the same edge data) to every
    neighbour of ``f``. The original factor is then removed from both ``fg.G``
    and ``fg.factors``. The graph is modified in place.

    Args:
        fg: Factor graph to modify; its ``G`` and ``factors`` are mutated.
        factors: Factors to split. May be ``fg.factors`` itself. A factor
            listed more than once is split only once.
        p: Share of the cost table given to the first clone. Not validated
            here; callers are expected to have checked the range.

    Returns:
        Mapping from each original factor's name to ``[f', f'']``.

    Raises:
        KeyError: If a factor is not a node of ``fg.G``.
        ValueError: If a factor is not present in ``fg.factors``.
    """
    G: nx.Graph = fg.G
    split_mapping: Dict[str, List[FactorAgent]] = {}

    # ``dict.fromkeys`` takes a snapshot (``factors`` may be ``fg.factors``,
    # which is mutated below) and drops repeated entries, so the same factor
    # is never split twice.
    for f in dict.fromkeys(factors):
        cost1 = p * f.cost_table  # type: ignore
        cost2 = (1.0 - p) * f.cost_table  # type: ignore

        f1 = f.create_from_cost_table(cost_table=cost1, name=f"{f.name}'")
        f2 = f.create_from_cost_table(cost_table=cost2, name=f"{f.name}''")

        f1.connection_number = deepcopy(f.connection_number)
        f2.connection_number = deepcopy(f.connection_number)

        # Snapshot node attributes and neighbours before touching the graph.
        node_attrs = dict(G.nodes[f])
        neighbours = list(G[f].items())

        # Add the clones as nodes explicitly: relying on ``add_edge`` alone
        # would lose the node attributes and would never insert the clones of
        # a factor that has no neighbours.
        G.add_node(f1, **node_attrs)
        G.add_node(f2, **node_attrs)
        for v, edge_data in neighbours:
            G.add_edge(f1, v, **edge_data)
            G.add_edge(f2, v, **edge_data)

        fg.factors.append(f1)
        fg.factors.append(f2)

        G.remove_node(f)
        fg.factors.remove(f)
        split_mapping[f.name] = [f1, f2]

    return split_mapping


def _select_factors_for_split(
    fg: FactorGraph,
    *,
    factor_names: Sequence[str] | None = None,
    split_fraction: float | None = None,
    seed: int | None = None,
) -> List[FactorAgent]:
    """Choose which factors of ``fg`` should be split.

    Args:
        fg: Factor graph whose ``factors`` are the pool to choose from.
        factor_names: Optional names restricting the candidates. Repeated
            names are collapsed so each factor is selected at most once.
            When ``None``, every factor is a candidate.
        split_fraction: Optional fraction in ``(0, 1]`` of the candidates to
            pick at random. When ``None``, all candidates are returned.
        seed: Seed for the random choice made when ``split_fraction`` is given.

    Returns:
        The selected factors, sorted by name. Empty when there are no
        candidates.

    Raises:
        ValueError: If a requested name does not match any factor, or if
            ``split_fraction`` is outside ``(0, 1]``.
    """
    candidates = list(fg.factors)
    if factor_names is not None:
        by_name = {factor.name: factor for factor in candidates}
        # Collapse repeated names (keeping first-seen order): a duplicate
        # would otherwise yield the same factor twice, inflating the sample
        # size and asking the caller to split one factor more than once.
        requested = list(dict.fromkeys(factor_names))
        missing = [name for name in requested if name not in by_name]
        if missing:
            raise ValueError(f"Unknown factor(s) requested for splitting: {missing}")
        candidates = [by_name[name] for name in requested]

    # Sorting first makes the seeded sample independent of input ordering.
    candidates.sort(key=lambda factor: factor.name)
    if split_fraction is None:
        return candidates

    if not 0.0 < split_fraction <= 1.0:
        raise ValueError("split_fraction must be in the range (0, 1].")

    if not candidates:
        return []

    # Always pick at least one factor, rounding the requested share upwards.
    count = max(1, math.ceil(len(candidates) * split_fraction))
    rng = random.Random(seed)
    return sorted(rng.sample(candidates, count), key=lambda factor: factor.name)


def split_factors(
    fg: FactorGraph,
    p: float | None = None,
    *,
    factor_names: Sequence[str] | None = None,
    split_fraction: float | None = None,
    seed: int | None = None,
) -> Dict[str, List[FactorAgent]]:
    """Split selected factors in-place and return original-to-clone mapping.

    Args:
        fg: The factor graph to modify.
        p: Split proportion allocated to the first clone. When ``None``,
            ``PolicyDefaults().split_factor`` is used.
        factor_names: Optional factor names to restrict the split target set.
            When omitted, all factors are candidates.
        split_fraction: Optional fraction of the candidate factors to split.
            Selection is deterministic for a given seed.
        seed: Seed used when ``split_fraction`` is provided.

    Returns:
        A mapping from original factor name to the two replacement clone
        factors.

    Raises:
        AssertionError: If ``p`` is not strictly between 0 and 1.
        ValueError: If ``factor_names`` contains an unknown name or
            ``split_fraction`` is outside ``(0, 1]``.
    """
    if p is None:
        p = PolicyDefaults().split_factor
    # Raised explicitly instead of via ``assert`` so the range check is still
    # enforced when Python runs with ``-O``; the exception type and message
    # are the same as before.
    if not 0.0 < p < 1.0:  # type: ignore
        raise AssertionError("p must be in (0,1)")
    selected = _select_factors_for_split(
        fg,
        factor_names=factor_names,
        split_fraction=split_fraction,
        seed=seed,
    )
    return _split_factors(fg, selected, p)


def split_all_factors(
    fg: FactorGraph,
    p: float | None = None,
) -> Dict[str, List[FactorAgent]]:
    """Performs an in-place replacement of every factor with two cloned factors.

    Each factor `f` with cost table `C` is replaced by two new factors,
    `f'` and `f''`, with cost tables `p*C` and `(1-p)*C`, respectively.
    The new factors inherit all the connections of the original factor.

    This function directly modifies the provided `FactorGraph` object,
    including its underlying `networkx.Graph` and its list of factors.

    Args:
        fg: The `FactorGraph` to modify.
        p: The splitting proportion, which must be strictly between 0 and 1.
            This determines how the original cost is distributed between the
            two new factors. If None, `PolicyDefaults().split_factor` is used.

    Returns:
        A mapping from original factor name to the two replacement clone
        factors.

    Raises:
        AssertionError: If `p` is not in the range (0, 1).
    """
    if p is None:
        p = PolicyDefaults().split_factor
    # Raised explicitly instead of via ``assert`` so the range check is still
    # enforced when Python runs with ``-O``.
    if not 0.0 < p < 1.0:
        raise AssertionError("p must be in (0,1)")
    # Hand over a snapshot: ``fg.factors`` itself is mutated during the split.
    return _split_factors(fg, list(fg.factors), p)


def split_specific_factors(
    fg: FactorGraph, factors: List[FactorAgent], p: float | None = None
) -> Dict[str, List[FactorAgent]]:
    """Split specific factors in the factor graph.

    Args:
        fg (FactorGraph): The factor graph to modify.
        factors (List[FactorAgent]): The factors to split.
        p (float | None, optional): The splitting proportion allocated to the
            first clone. Defaults to None, in which case
            ``PolicyDefaults().split_factor`` is used.

    Returns:
        A mapping from original factor name to the two replacement clone
        factors.

    Raises:
        AssertionError: If ``p`` is not strictly between 0 and 1.
    """
    if p is None:
        p = PolicyDefaults().split_factor
    # Raised explicitly instead of via ``assert`` so the range check is still
    # enforced when Python runs with ``-O``.
    if not 0.0 < p < 1.0:  # type: ignore
        raise AssertionError("p must be in (0,1)")
    return _split_factors(fg, factors, p)