import logging
from copy import deepcopy
from typing import Any, Dict, List
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from ..core.agents import FactorAgent, VariableAgent
from .computators import BPComputator
logger = logging.getLogger(__name__)
class FactorGraph:
    """Represents a bipartite factor graph for belief propagation.

    This class encapsulates the structure of a factor graph, which consists of
    variable nodes and factor nodes. It enforces a bipartite structure where
    variables are only connected to factors and vice versa. It uses a
    `networkx.Graph` to manage the underlying connections.

    Attributes:
        variables (List[VariableAgent]): A list of all variable agents in the graph.
        factors (List[FactorAgent]): A list of all factor agents in the graph.
        G (nx.Graph): The underlying networkx graph structure. Variable nodes
            carry the node attribute ``bipartite=0`` and factor nodes
            ``bipartite=1``; every edge stores the variable's dimension index
            in the factor's cost table as ``dim``.
    """

    def __init__(
        self,
        variable_li: List[VariableAgent],
        factor_li: List[FactorAgent],
        edges: Dict[FactorAgent, List[VariableAgent]],
    ):
        """Initializes the factor graph.

        Args:
            variable_li: A list of `VariableAgent` objects.
            factor_li: A list of `FactorAgent` objects.
            edges: A dictionary mapping each factor agent to a list of the
                variable agents it connects to. The position of a variable
                in its list becomes its dimension index in that factor.

        Raises:
            ValueError: If an edge does not join a factor to a variable.
        """
        self.variables = variable_li
        self.factors = factor_li
        self.G = nx.Graph()
        # Tag both sides of the bipartition on the nodes themselves.
        self.G.add_nodes_from(self.variables, bipartite=0)
        self.G.add_nodes_from(self.factors, bipartite=1)
        self._add_edges(edges)
        self._initialize_cost_tables()
        # Snapshot taken after cost-table initialization so the copies carry
        # the initialized tables.
        self._original_factors = deepcopy(factor_li)
        self._lb = None  # Lower bound, exposed through the ``lb`` property.
        # NOTE(review): the upper bound has no public accessor in this class.
        self._ub = None  # Upper bound

    @property
    def lb(self) -> int | float | None:
        """The lower bound of the problem, can be set externally.

        ``None`` until a value has been assigned.
        """
        return self._lb  # type: ignore

    @lb.setter
    def lb(self, value: int | float) -> None:
        """Sets the lower bound of the factor graph.

        Args:
            value: The lower bound value to set.

        Raises:
            ValueError: If the value is not an integer or float.
        """
        if not isinstance(value, (int, float)):
            raise ValueError("Lower bound must be an integer or float.")
        self._lb = value

    def compute_cost(self, factors: List[FactorAgent] | None = None) -> float:
        """Compute the total cost of the current variable assignments.

        For each factor, the variables' ``curr_assignment`` values are placed
        at the dimension indices recorded in ``factor.connection_number`` and
        the resulting table entry is added to the total. The factor's
        ``original_cost_table`` is used whenever it is set; otherwise its
        ``cost_table`` is used.

        A factor contributes nothing when it has no ``cost_table``, no
        connections, a connected variable that is not in ``self.variables``,
        an unfilled dimension (a gap or a ``None`` assignment), or a number of
        dimensions that does not match the table being indexed.

        Args:
            factors: Which factors to sum over. Defaults to ``self.factors``.
                Pass ``self.original_factors`` to evaluate the deep-copied
                snapshot taken at construction instead.

        Returns:
            The summed cost (``0.0`` when no factor contributes).
        """
        if factors is None:
            factors = self.factors
        assignments = {var.name: var.curr_assignment for var in self.variables}
        total = 0.0
        for factor in factors:
            if factor.cost_table is None or not factor.connection_number:
                continue
            indices: List[Any] = []
            complete = True
            for var_name, dim in factor.connection_number.items():
                if var_name not in assignments:
                    # Variable is not part of this graph: entry cannot be located.
                    complete = False
                    break
                # Grow on demand so dimensions may be listed in any order.
                while len(indices) <= dim:
                    indices.append(None)
                indices[dim] = assignments[var_name]
            if not complete or None in indices:
                continue
            # NOTE(review): the original table is preferred whenever it is set,
            # presumably so later edits to ``cost_table`` do not change the
            # reported cost - confirm with the callers that modify tables.
            table = (
                factor.original_cost_table
                if factor.original_cost_table is not None
                else factor.cost_table
            )
            # Validate arity against the table that is actually indexed.
            if len(indices) != table.ndim:
                continue
            total += table[tuple(indices)]
        return total

    @property
    def global_cost(self) -> int | float:
        """Total cost of the current assignment over ``self.factors``.

        Equivalent to ``compute_cost()``; each factor's original cost table
        is used when it is set.
        """
        return self.compute_cost()

    @property
    def curr_assignment(self) -> Dict[VariableAgent, int]:
        """dict: The current assignment for all variables in the graph."""
        return {node: int(node.curr_assignment) for node in self.variables}

    @property
    def edges(self) -> Dict[FactorAgent, List[VariableAgent]]:
        """dict: Reconstructs the edge dictionary mapping factors to variables.

        Each factor's variables are listed in order of their dimension index.
        Connections to names not found in ``self.variables`` are omitted, and
        factors without a ``connection_number`` attribute are left out.
        """
        var_by_name = {v.name: v for v in self.variables}
        edge_dict = {}
        for factor in self.factors:
            if not hasattr(factor, "connection_number"):
                continue
            vars_with_dims = [
                (var_by_name[var_name], dim)
                for var_name, dim in factor.connection_number.items()
                if var_name in var_by_name
            ]
            # Stable sort on the dimension only (agents need not be orderable).
            vars_with_dims.sort(key=lambda pair: pair[1])
            edge_dict[factor] = [var for var, _ in vars_with_dims]
        return edge_dict

    def set_computator(self, computator: BPComputator, **kwargs) -> None:
        """Assigns a computator to all nodes in the graph.

        The same computator instance is shared by every variable and factor.

        Args:
            computator: The computator instance to assign.
            **kwargs: Additional arguments (not currently used).
        """
        for node in self.G.nodes():
            node.computator = computator

    def normalize_messages(self) -> None:
        """Normalizes all incoming messages for all variable nodes.

        Each message in every variable's inbox is shifted by its own minimum,
        so its smallest entry becomes 0. This is a common technique to prevent
        numerical instability in belief propagation.
        """
        for var in self.variables:
            for message in var.mailer.inbox:
                message.data -= np.min(message.data)

    def visualize(
        self,
        layout: str = "bipartite",
        layout_kwargs: Dict[str, Any] | None = None,
        plot: bool = True,
        *,
        pretty: bool = False,
    ) -> Figure | None:
        """Visualizes the factor graph using matplotlib.

        Variable nodes are drawn as circles, and factor nodes are drawn as squares.
        If `plot` is True the figure is shown and the function returns None.
        If `plot` is False the Figure object is returned (and not shown),
        allowing further programmatic use.

        Args:
            layout: Layout algorithm to use (case-insensitive). Supported
                values are ``"bipartite"`` (default), ``"spring"``,
                ``"circular"``, and ``"kamada_kawai"``. Ignored when
                ``pretty=True``.
            layout_kwargs: Optional keyword arguments forwarded to the selected
                NetworkX layout function. For ``"bipartite"``, ``nodes``
                defaults to ``self.variables``.
            plot: If True, call plt.show() and return None. If False, return
                the matplotlib.figure.Figure instance without showing it.
            pretty: When True, render a fixed spring-layout visualization with
                styled nodes/legend useful for presentations.

        Returns:
            ``None`` when ``plot`` is True, otherwise the created Figure.

        Raises:
            ValueError: If ``layout`` is not one of the supported names.
        """
        if pretty:
            fig, ax = plt.subplots(figsize=(8.5, 6.5))
            _plot_factor_graph(self, ax, "Factor Graph")
            return self._finalize_visualization(plot, fig)
        # Copy so the caller's dict is never mutated by setdefault below.
        layout_kwargs = dict(layout_kwargs or {})
        layout = layout.lower()
        if layout == "bipartite":
            layout_kwargs.setdefault("nodes", self.variables)
            pos = nx.bipartite_layout(self.G, **layout_kwargs)
        elif layout == "spring":
            pos = nx.spring_layout(self.G, **layout_kwargs)
        elif layout == "circular":
            pos = nx.circular_layout(self.G, **layout_kwargs)
        elif layout == "kamada_kawai":
            pos = nx.kamada_kawai_layout(self.G, **layout_kwargs)
        else:
            raise ValueError(
                f"Unsupported layout '{layout}'. "
                "Choose from 'bipartite', 'spring', 'circular', or 'kamada_kawai'."
            )
        fig, ax = plt.subplots()
        nx.draw_networkx_nodes(
            self.G,
            pos,
            nodelist=self.variables,
            node_shape="o",
            node_color="lightblue",
            node_size=300,
            ax=ax,
        )
        nx.draw_networkx_nodes(
            self.G,
            pos,
            nodelist=self.factors,
            node_shape="s",
            node_color="lightgreen",
            node_size=300,
            ax=ax,
        )
        nx.draw_networkx_edges(self.G, pos, ax=ax)
        nx.draw_networkx_labels(self.G, pos, ax=ax)
        ax.set_axis_off()
        return self._finalize_visualization(plot, fig)

    def _finalize_visualization(self, plot, fig):
        """Tighten ``fig``'s layout, then show it or hand it back.

        Args:
            plot: If True, call ``plt.show()`` and return ``None``.
            fig: The figure produced by ``visualize``.

        Returns:
            ``None`` when ``plot`` is True, otherwise ``fig``.
        """
        # Operate on the given figure rather than pyplot's "current" figure.
        fig.tight_layout()
        if plot:
            plt.show()
            return None
        return fig

    def _add_edges(self, edges: Dict[FactorAgent, List[VariableAgent]]) -> None:
        """Adds edges and configures factor-variable connections.

        For each factor, the i-th listed variable is recorded in
        ``factor.connection_number[var.name] = i`` and the graph edge gets
        ``dim=i``. The dict is created on the factor if it is missing.

        Args:
            edges: A dictionary mapping factors to the variables they connect.

        Raises:
            ValueError: If an edge does not connect a factor to a variable.
        """
        # Build membership sets once; graph nodes must be hashable anyway.
        variable_set = set(self.variables)
        factor_set = set(self.factors)
        for factor, variables in edges.items():
            if not hasattr(factor, "connection_number"):
                factor.connection_number = {}
            for i, var in enumerate(variables):
                factor_to_var = factor in factor_set and var in variable_set
                var_to_factor = factor in variable_set and var in factor_set
                if not (factor_to_var or var_to_factor):
                    raise ValueError("Edges must connect a factor to a variable.")
                self.G.add_edge(factor, var, dim=i)
                factor.connection_number[var.name] = i
        logger.info("FactorGraph is bipartite: variables <-> factors only.")

    def _initialize_cost_tables(self) -> None:
        """Initializes the cost tables for all factor nodes in the graph.

        Factors without a cost table get one via ``initiate_cost_table()``.
        Factors that already have a table but no ``original_cost_table`` have
        it saved via ``save_original()``.
        """
        for factor in self.factors:
            if getattr(factor, "cost_table", None) is None:
                factor.initiate_cost_table()
                logger.debug("Cost table initialized for factor node: %s", factor.name)
            elif getattr(factor, "original_cost_table", None) is None:
                factor.save_original()

    def get_variable_agents(self) -> List[VariableAgent]:
        """Returns a list of all variable agents in the graph."""
        return self.variables

    def get_factor_agents(self) -> List[FactorAgent]:
        """Returns a list of all factor agents in the graph."""
        return self.factors

    @property
    def diameter(self) -> int:
        """int: The diameter of the factor graph.

        Returns 0 for an empty graph. If the graph is not connected, it
        returns the diameter of the largest connected component.
        """
        if not self.G:
            return 0
        if nx.is_connected(self.G):
            return nx.diameter(self.G)
        # A non-empty graph always has at least one (non-empty) component.
        largest_cc = max(nx.connected_components(self.G), key=len)
        return nx.diameter(self.G.subgraph(largest_cc))

    def __getstate__(self) -> dict:
        """Custom method to control what gets pickled (a shallow copy of __dict__)."""
        return self.__dict__.copy()

    def __setstate__(self, state: dict) -> None:
        """Custom method to control unpickling behavior.

        Restores the attributes and, if the graph ``G`` is missing or None,
        rebuilds it from ``variables``, ``factors`` and each factor's
        ``connection_number``.
        """
        self.__dict__.update(state)
        if not hasattr(self, "G") or self.G is None:
            self.G = nx.Graph()
            if hasattr(self, "variables") and hasattr(self, "factors"):
                self.G.add_nodes_from(self.variables, bipartite=0)
                self.G.add_nodes_from(self.factors, bipartite=1)
                var_name_to_obj = {var.name: var for var in self.variables}
                for factor in self.factors:
                    if hasattr(factor, "connection_number"):
                        for var_name, dim in factor.connection_number.items():
                            if var_name in var_name_to_obj:
                                var = var_name_to_obj[var_name]
                                self.G.add_edge(factor, var, dim=dim)

    @property
    def original_factors(self) -> List[FactorAgent]:
        """list[FactorAgent]: The deep copy of the factor agents taken at construction.

        The same list object is returned on every access.
        """
        return self._original_factors
def _plot_factor_graph(graph: "FactorGraph", ax: Axes, title: str) -> None:
    """Draw ``graph`` onto ``ax`` using a seeded spring layout.

    Backs ``FactorGraph.visualize(pretty=True)``. Variables are drawn as blue
    circles and factors as orange squares. Edges are semi-transparent, nodes
    are labelled by their ``name`` attribute (falling back to ``str(node)``),
    and a legend is placed in the upper-right corner.

    Args:
        graph: The factor graph to render.
        ax: The matplotlib axes to draw into.
        title: Title set on ``ax``.
    """
    nx_graph = graph.G
    node_labels = {}
    for node in nx_graph.nodes:
        node_labels[node] = getattr(node, "name", str(node))
    # Fixed seed: the same graph always renders with the same positions.
    coords = nx.spring_layout(nx_graph, seed=42)
    node_groups = (
        (
            list(graph.variables),
            {"node_color": "#4c72b0", "node_size": 600, "label": "Variables"},
        ),
        (
            list(graph.factors),
            {
                "node_color": "#dd8452",
                "node_shape": "s",
                "node_size": 500,
                "label": "Factors",
            },
        ),
    )
    ax.set_title(title)
    ax.axis("off")
    for members, style in node_groups:
        nx.draw_networkx_nodes(nx_graph, coords, nodelist=members, ax=ax, **style)
    nx.draw_networkx_edges(nx_graph, coords, ax=ax, alpha=0.3)
    nx.draw_networkx_labels(nx_graph, coords, labels=node_labels, font_size=10, ax=ax)
    ax.legend(loc="upper right")