jaxabm.analysis.EnsembleCalibrator

class jaxabm.analysis.EnsembleCalibrator(model_factory, initial_params, target_metrics, methods=['adam', 'es', 'pso'], **kwargs)

Bases: object

Ensemble calibrator that combines multiple optimization methods.

This class runs multiple calibration methods in parallel and combines their results to find the best parameter set.
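The ensemble pattern can be sketched in plain Python. This is not jaxabm's implementation, only a minimal sketch under stated assumptions. The toy `simulate` model, the squared-error loss, and the three optimizers are hypothetical. Finite-difference gradient descent, a (1+1) evolution strategy, and random search stand in for 'adam', 'es' and 'pso'. The sketch also runs the methods sequentially for simplicity. What it shows is the core idea: start every method from the same initial parameters, then keep the lowest-loss result.

```python
import random
from typing import Callable, Dict

Params = Dict[str, float]
TARGETS = {"mean_wealth": 1.0, "inequality": 0.1}  # hypothetical target metrics


def simulate(params: Params) -> Params:
    # Hypothetical toy "model": metrics are simple functions of the parameters.
    return {
        "mean_wealth": 2.0 * params["growth"],
        "inequality": params["growth"] * params["risk"],
    }


def loss(params: Params) -> float:
    # Sum of squared errors between simulated and target metrics.
    metrics = simulate(params)
    return sum((metrics[k] - TARGETS[k]) ** 2 for k in TARGETS)


def gradient_descent(f, init, steps=500, lr=0.05, h=1e-5):
    # Stand-in for a gradient-based method, using central finite differences.
    p = dict(init)
    for _ in range(steps):
        grad = {}
        for k in p:
            up, down = dict(p), dict(p)
            up[k] += h
            down[k] -= h
            grad[k] = (f(up) - f(down)) / (2 * h)
        p = {k: p[k] - lr * grad[k] for k in p}
    return p


def one_plus_one_es(f, init, steps=500, sigma=0.1, seed=0):
    # Stand-in for an evolution strategy: keep a Gaussian mutation only if it improves.
    rng = random.Random(seed)
    best, best_loss = dict(init), f(init)
    for _ in range(steps):
        cand = {k: v + sigma * rng.gauss(0.0, 1.0) for k, v in best.items()}
        cand_loss = f(cand)
        if cand_loss < best_loss:
            best, best_loss = cand, cand_loss
    return best


def random_search(f, init, steps=500, scale=1.0, seed=1):
    # Stand-in for a population method: sample uniformly around the start point.
    rng = random.Random(seed)
    best = dict(init)
    for _ in range(steps):
        cand = {k: v + rng.uniform(-scale, scale) for k, v in init.items()}
        if f(cand) < f(best):
            best = cand
    return best


def ensemble_calibrate(f, init, methods: Dict[str, Callable]):
    # Run every method from the same starting point, then keep the lowest-loss result.
    results = {name: {"params": m(f, init)} for name, m in methods.items()}
    for r in results.values():
        r["loss"] = f(r["params"])
    best_name = min(results, key=lambda n: results[n]["loss"])
    return {"methods": results, "best_method": best_name, "best": results[best_name]}


out = ensemble_calibrate(
    loss,
    {"growth": 1.0, "risk": 1.0},
    {"gd": gradient_descent, "es": one_plus_one_es, "random": random_search},
)
print(out["best_method"], out["best"]["params"], out["best"]["loss"])
```

Selecting by final loss keeps the combination step simple. A real ensemble might also weigh convergence speed or robustness across seeds.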

__init__(model_factory, initial_params, target_metrics, methods=['adam', 'es', 'pso'], **kwargs)

Initialize ensemble calibrator.

Parameters:
  • model_factory (ModelFactory, a TypeVar bound to Callable[..., Model]) – Factory function that creates Model instances

  • initial_params (Dict[str, float]) – Initial parameter values

  • target_metrics (Dict[str, float]) – Target metric values

  • methods (List[str]) – Names of the optimization methods to run (default: 'adam', 'es', 'pso')

  • **kwargs – Additional arguments passed to individual calibrators
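A minimal sketch of the argument shapes follows. It assumes the factory receives parameter values as keyword arguments and that the model reports metrics under the same keys as target_metrics. ToyModel, make_model, the run() method, and the metric names are hypothetical. Only the commented EnsembleCalibrator call follows the documented signature, and it needs jaxabm plus a real Model subclass to execute.

```python
from typing import Dict


class ToyModel:
    # Hypothetical stand-in for a jaxabm Model; the real Model API is not shown here.
    def __init__(self, growth: float, risk: float):
        self.growth = growth
        self.risk = risk

    def run(self) -> Dict[str, float]:
        # Returns summary metrics whose names match the keys of target_metrics.
        return {"mean_wealth": 2.0 * self.growth, "inequality": self.growth * self.risk}


def make_model(**params: float) -> ToyModel:
    # model_factory: a callable that builds a fresh model for one parameter set.
    return ToyModel(**params)


initial_params = {"growth": 1.0, "risk": 1.0}             # starting point for every method
target_metrics = {"mean_wealth": 1.0, "inequality": 0.1}  # values the calibration aims for

# With jaxabm installed and a real Model subclass, the call would follow the documented signature:
# from jaxabm.analysis import EnsembleCalibrator
# calibrator = EnsembleCalibrator(make_model, initial_params, target_metrics, methods=["adam", "pso"])
metrics = make_model(**initial_params).run()
print(sorted(metrics) == sorted(target_metrics))
```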

Methods

  • __init__(model_factory, initial_params, ...) – Initialize ensemble calibrator.

  • calibrate([verbose]) – Run ensemble calibration.

  • plot_comparison([figsize]) – Plot comparison of all methods.

calibrate(verbose=True)

Run ensemble calibration.

Parameters:

verbose (bool) – Whether to print progress information

Return type:

Dict[str, Any]

Returns:

Dictionary with results from all methods and the best overall result
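The keys of the returned dictionary are not documented on this page. As a hedged sketch, assume each method's entry exposes its final loss under a hypothetical "loss" key. Under that assumption, the per-method results could be ranked and compared against the best one like this:

```python
from typing import Dict, List, Tuple


def rank_methods(method_results: Dict[str, Dict]) -> List[Tuple[str, float]]:
    # Sort method names by final loss, lowest (best) first.
    # Assumes each entry has a hypothetical "loss" key.
    return sorted(((name, r["loss"]) for name, r in method_results.items()), key=lambda t: t[1])


# Placeholder numbers for illustration only, not real calibrate() output.
per_method = {"adam": {"loss": 0.30}, "es": {"loss": 0.10}, "pso": {"loss": 0.20}}
ranking = rank_methods(per_method)
best_name, best_loss = ranking[0]
for name, value in ranking:
    print(f"{name:>5}: loss={value:.3f} (+{value - best_loss:.3f} vs best)")
```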

plot_comparison(figsize=(15, 10))

Plot comparison of all methods.

Parameters:

figsize (Tuple[int, int]) – Figure size as (width, height)

Return type:

Any

Returns:

Matplotlib figure and axes