"""A parallelized simulator for running and comparing multiple engine configurations.
This module provides a `Simulator` class that can run multiple belief propagation
engine configurations across a set of factor graphs in parallel. It uses Python's
`multiprocessing` module to distribute the simulation runs, collects the results,
and provides a simple plotting utility to visualize and compare the performance
of different engines.
"""
import pickle
import random
import sys
import time
import traceback
from multiprocessing import Pool, cpu_count
from typing import Any, Dict, List, Optional, Tuple
import colorlog
import matplotlib.pyplot as plt
import numpy as np
from .configs import Logger
from .configs.global_config_mapping import (
LOG_LEVELS,
SimulatorDefaults,
get_validated_config,
)
from .policies import ConvergenceConfig
def _setup_logger(level: Optional[str] = None) -> Logger:
    """Create the ``"Simulator"`` logger at the requested verbosity.

    Args:
        level: Name of the desired log level. It is upper-cased and looked up
            in ``LOG_LEVELS``. Any non-string value (including ``None``)
            selects ``SimulatorDefaults().default_log_level`` instead.

    Returns:
        A ``Logger`` named ``"Simulator"``. If the logger has no handlers, a
        coloured console handler writing to ``sys.stdout`` is attached using
        the format and colours from the validated ``"logging"`` config.
    """
    settings = get_validated_config("logging")

    if isinstance(level, str):
        level_name = level
    else:
        level_name = SimulatorDefaults().default_log_level

    # Names missing from LOG_LEVELS resolve to the configured default level.
    resolved_level = LOG_LEVELS.get(level_name.upper(), settings["default_level"])

    sim_logger = Logger("Simulator")
    sim_logger.setLevel(resolved_level)

    # Skip handler setup when one is already attached.
    if sim_logger.handlers:
        return sim_logger

    stream_handler = colorlog.StreamHandler(sys.stdout)
    formatter = colorlog.ColoredFormatter(
        settings["console_format"],
        log_colors=settings["console_colors"],
    )
    stream_handler.setFormatter(formatter)
    sim_logger.addHandler(stream_handler)
    return sim_logger
class Simulator:
    """Orchestrates parallel execution of multiple simulation configurations.

    This class takes a set of engine configurations and a list of factor graphs,
    runs each engine on each graph in parallel, collects the cost history from
    each run, and provides methods to visualize the aggregated results.

    Attributes:
        engine_configs (dict): A dictionary mapping engine names to their configurations.
        logger (Logger): A configured logger instance.
        results (dict): A dictionary to store the results of the simulations.
        timeout (int): The timeout in seconds for multiprocessing tasks.
    """

    def __init__(
        self,
        engine_configs: Dict[str, Any],
        log_level: Optional[str] = None,
        *,
        seed: Optional[int] = None,
    ):
        """Initializes the Simulator.

        Args:
            engine_configs: A dictionary where keys are descriptive engine names
                and values are configuration dictionaries for each engine. Each
                configuration must contain a ``"class"`` entry (the engine
                class); every other entry is passed to it as a keyword argument.
            log_level: The logging level for the simulator (e.g., 'INFO', 'DEBUG').
            seed: Optional base seed. When provided, job number ``k`` (in
                submission order) seeds numpy and Python's ``random`` module
                with ``seed + k`` before running its engine.
        """
        self.engine_configs = engine_configs
        self.logger = _setup_logger(log_level)
        self.results: Dict[str, List[List[float]]] = {
            name: [] for name in engine_configs
        }
        self.timeout = SimulatorDefaults().timeout
        self._seed = seed

    def run_simulations(
        self, graphs: List[Any], max_iter: Optional[int] = None
    ) -> Dict[str, List[List[float]]]:
        """Runs all engine configurations on all provided graphs in parallel.

        Results are appended to ``self.results``, so calling this method more
        than once accumulates runs.

        Args:
            graphs: A list of factor graph objects to run simulations on. Each
                graph must be picklable.
            max_iter: The maximum number of iterations for each simulation run.
                A falsy value selects ``SimulatorDefaults().default_max_iter``.

        Returns:
            A dictionary containing the collected results, where keys are engine
            names and values are lists of cost histories for each run. A run
            that raised inside its worker contributes an empty list.
        """
        max_iter = max_iter or SimulatorDefaults().default_max_iter
        self.logger.warning(
            f"Preparing {len(graphs) * len(self.engine_configs)} total simulations."
        )
        simulation_args = self._build_simulation_args(graphs, max_iter)

        start_time = time.time()
        try:
            all_results = self._run_batch_safe(simulation_args, max_workers=cpu_count())
        except Exception as e:
            self.logger.error(
                f"CRITICAL ERROR - All multiprocessing strategies failed: {e}"
            )
            self.logger.error(traceback.format_exc())
            self.logger.warning("Falling back to sequential processing...")
            all_results = self._sequential_fallback(simulation_args)
        total_time = time.time() - start_time
        self.logger.warning(f"All simulations completed in {total_time:.2f} seconds.")

        if len(all_results) != len(simulation_args):
            self.logger.error(
                f"Expected {len(simulation_args)} results, but got {len(all_results)}"
            )

        for _, engine_name, costs in all_results:
            # setdefault: an engine added to engine_configs after construction
            # has no slot in self.results yet; don't lose the finished batch
            # to a KeyError.
            self.results.setdefault(engine_name, []).append(costs)

        for engine_name, costs_list in self.results.items():
            self.logger.warning(f"{engine_name}: {len(costs_list)} runs completed.")
        return self.results

    def plot_results(
        self, max_iter: Optional[int] = None, verbose: bool = False
    ) -> None:
        """Plots the average cost convergence for each engine configuration.

        Args:
            max_iter: Minimum number of iterations shown on the x-axis. Cost
                histories shorter than this (or shorter than the longest run
                of the same engine) are extended by repeating their final
                cost; longer histories are not truncated. A falsy value
                selects ``SimulatorDefaults().default_max_iter``.
            verbose: If True, plots individual simulation runs with transparency
                and standard deviation bands around the average.
        """
        max_iter = max_iter or SimulatorDefaults().default_max_iter
        self.logger.warning(f"Starting plotting... (Verbose: {verbose})")
        plt.figure(figsize=(12, 8))
        colors = plt.cm.viridis(np.linspace(0, 1, len(self.results)))

        for idx, (engine_name, costs_list) in enumerate(self.results.items()):
            # Failed runs are recorded as empty lists; leave them out of the mean.
            valid_costs_list = [c for c in costs_list if c]
            if not valid_costs_list:
                self.logger.error(f"No valid cost data for {engine_name}")
                continue

            max_len = max(max_iter, max(len(c) for c in valid_costs_list))
            # Pad every run to a common length by holding its last cost so the
            # runs can be stacked into one 2-D array.
            padded_costs = np.array(
                [c + [c[-1]] * (max_len - len(c)) for c in valid_costs_list]
            )
            avg_costs = np.mean(padded_costs, axis=0)
            color = colors[idx]

            if verbose:
                for run_costs in padded_costs:
                    plt.plot(run_costs, color=color, alpha=0.2, linewidth=0.5)
                std_costs = np.std(padded_costs, axis=0)
                plt.fill_between(
                    range(max_len),
                    avg_costs - std_costs,
                    avg_costs + std_costs,
                    color=color,
                    alpha=0.1,
                )

            plt.plot(avg_costs, label=f"{engine_name} (Avg)", color=color, linewidth=2)
            self.logger.warning(
                f"Plotted {engine_name}: avg final cost = {avg_costs[-1]:.2f}"
            )

        # The run count in the title is taken from the first engine; fall back
        # to an empty list when no engines are configured instead of raising
        # IndexError.
        first_engine_runs = next(iter(self.results.values()), [])
        plt.title(
            f"Average Costs over {len(first_engine_runs)} Runs",
            fontsize=14,
        )
        plt.xlabel("Iteration", fontsize=12)
        plt.ylabel("Average Cost", fontsize=12)
        plt.legend(fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
        self.logger.warning("Displaying plot.")

    def set_log_level(self, level: str) -> None:
        """Sets the logging level for the simulator's logger.

        Args:
            level: The desired logging level (e.g., 'INFO', 'DEBUG'). Looked up
                case-insensitively in ``LOG_LEVELS``; unknown names are
                reported through the logger and leave the level unchanged.
        """
        log_level = LOG_LEVELS.get(level.upper())
        # Compare against None explicitly so a level whose value is falsy
        # (e.g. 0) is not mistaken for "not found".
        if log_level is not None:
            self.logger.setLevel(log_level)
            self.logger.warning(f"Log level set to {level.upper()}")
        else:
            self.logger.error(f"Invalid log level: {level}")

    def _build_simulation_args(
        self, graphs: List[Any], max_iter: int
    ) -> List[Tuple[Any, ...]]:
        """Builds one argument tuple per (graph, engine) pair.

        Each tuple is ``(graph_index, engine_name, config, pickled_graph,
        max_iter, log_level, job_seed)``, the layout unpacked by
        ``_run_single_simulation``. Graphs are pickled once and the blob is
        shared by all engines run on that graph.
        """
        simulation_args: List[Tuple[Any, ...]] = []
        log_level = self.logger.level
        for graph_index, graph in enumerate(graphs):
            graph_blob = pickle.dumps(graph)
            for name, config in self.engine_configs.items():
                job_seed = None
                if self._seed is not None:
                    # The running job number doubles as the seed offset, so
                    # every job gets a distinct but reproducible seed.
                    job_seed = self._seed + len(simulation_args)
                simulation_args.append(
                    (
                        graph_index,
                        name,
                        config,
                        graph_blob,
                        max_iter,
                        log_level,
                        job_seed,
                    )
                )
        return simulation_args

    @staticmethod
    def _count_batches(total: int, batch_size: int) -> int:
        """Returns how many batches of ``batch_size`` cover ``total`` items."""
        # Ceiling division without floats.
        return -(-total // batch_size)

    @staticmethod
    def _run_single_simulation(args: Tuple) -> Tuple[int, str, List[float]]:
        """Runs a single simulation instance; designed for multiprocessing.

        Args:
            args: ``(graph_index, engine_name, config, graph_data, max_iter,
                log_level, job_seed)`` as produced by ``run_simulations``.

        Returns:
            ``(graph_index, engine_name, costs)`` where ``costs`` holds the
            ``global_cost`` of every engine snapshot that has one. If the run
            raises, the error is logged and ``costs`` is an empty list.
        """
        (
            graph_index,
            engine_name,
            config,
            graph_data,
            max_iter,
            log_level,
            job_seed,
        ) = args

        # run_simulations forwards ``self.logger.level`` (the value previously
        # handed to setLevel), not a level *name*. Stringifying it and looking
        # it up by name would presumably miss and silently fall back to the
        # default level, so apply non-string levels directly.
        if isinstance(log_level, str):
            logger = _setup_logger(log_level)
        else:
            logger = _setup_logger()
            if log_level is not None:
                logger.setLevel(log_level)

        try:
            if job_seed is not None:
                np.random.seed(job_seed)
                random.seed(job_seed)
            # The blob is produced by pickle.dumps in the parent process (see
            # run_simulations); it is not external input.
            fg_copy = pickle.loads(graph_data)
            engine_cls = config["class"]
            engine_params = {k: v for k, v in config.items() if k != "class"}
            engine = engine_cls(
                factor_graph=fg_copy,
                convergence_config=ConvergenceConfig(),
                **engine_params,
            )
            engine.run(max_iter=max_iter)
            costs = [
                float(s.global_cost)
                for s in engine.snapshots
                if s.global_cost is not None
            ]
            logger.info(
                f"Finished: graph {graph_index}, engine {engine_name}. Final cost: {costs[-1] if costs else 'N/A'}"
            )
            return (graph_index, engine_name, costs)
        except Exception as e:
            # Deliberate best-effort: one failed job must not abort the batch.
            logger.error(
                f"Exception in child process for graph {graph_index}, engine {engine_name}: {e}\n{traceback.format_exc()}"
            )
            return (graph_index, engine_name, [])

    def _run_batch_safe(
        self, simulation_args: List[Tuple], max_workers: int
    ) -> List[Tuple]:
        """Runs simulations in parallel with a timeout, falling back to batching.

        All jobs are first submitted to a single pool of ``max_workers``
        processes and awaited for at most ``self.timeout`` seconds. On any
        exception (including the timeout) the whole job list is re-run through
        ``_run_in_batches`` with half as many workers (at least one).
        """
        self.logger.warning(
            f"Attempting full multiprocessing with {max_workers} processes..."
        )
        try:
            with Pool(processes=max_workers) as pool:
                result = pool.map_async(self._run_single_simulation, simulation_args)
                return result.get(timeout=self.timeout)
        except Exception as e:
            self.logger.error(f"Full multiprocessing failed: {e}")
            self.logger.warning("Trying batch processing...")
            return self._run_in_batches(
                simulation_args, max_workers=max(1, max_workers // 2)
            )

    def _run_in_batches(
        self, simulation_args: List[Tuple], batch_size: int = 8, max_workers: int = 4
    ) -> List[Tuple]:
        """Runs simulations in smaller parallel batches as a fallback.

        Args:
            simulation_args: Argument tuples for ``_run_single_simulation``.
            batch_size: Number of jobs handed to each pool; must be positive.
            max_workers: Upper bound on processes per batch.

        Returns:
            Results for all jobs, in submission order.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        if batch_size < 1:
            # A negative step would make the loop below run zero times and
            # silently return no results.
            raise ValueError("batch_size must be a positive integer")

        self.logger.warning(
            f"Starting batch processing with batch_size={batch_size} and max_workers={max_workers}"
        )
        total_batches = self._count_batches(len(simulation_args), batch_size)
        all_results = []
        for i in range(0, len(simulation_args), batch_size):
            batch = simulation_args[i : i + batch_size]
            self.logger.warning(
                f"Running batch {i // batch_size + 1}/{total_batches}..."
            )
            try:
                with Pool(processes=min(max_workers, len(batch))) as pool:
                    all_results.extend(pool.map(self._run_single_simulation, batch))
            except Exception as e:
                self.logger.error(
                    f"Batch failed: {e}. Running sequentially as fallback."
                )
                all_results.extend(self._sequential_fallback(batch))
        return all_results

    def _sequential_fallback(self, simulation_args: List[Tuple]) -> List[Tuple]:
        """Runs all simulations sequentially as a final fallback."""
        self.logger.warning("Running all simulations sequentially as a last resort.")
        return [self._run_single_simulation(args) for args in simulation_args]