jaxabm.analysis.ModelCalibrator

class jaxabm.analysis.ModelCalibrator(model_factory, initial_params, target_metrics, param_bounds=None, metrics_weights=None, learning_rate=0.01, max_iterations=100, method='adam', loss_type='mse', evaluation_steps=50, num_evaluation_runs=3, tolerance=1e-06, patience=10, seed=0)[source]

Bases: object

Calibrate model parameters using advanced optimization techniques.

This class automatically tunes model parameters so that simulated metrics match their target values. It supports gradient-based optimizers ("adam", "sgd") as well as the other search strategies selectable through the method argument ("es", "pso", "cem", "bayesian").

model_factory

Function to create model instances

params

Current parameter values

target_metrics

Target values for each metric

metrics_weights

Importance weights for each metric in the loss function

learning_rate

Learning rate for optimization

max_iterations

Maximum number of optimization iterations

method

Calibration method

loss_type

Type of loss function to use

param_bounds

Parameter bounds for each parameter

evaluation_steps

Number of steps to run model for evaluation

num_evaluation_runs

Number of runs to average for robust evaluation

loss_history

History of loss values during calibration

param_history

History of parameter values during calibration

confidence_intervals

Confidence intervals for metrics
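The attributes num_evaluation_runs and confidence_intervals suggest that each candidate parameter set is simulated several times and the resulting metrics are averaged. The sketch below shows one way such a summary could be computed. It is not taken from the jaxabm source: the helper summarize_runs, the metric name employment_rate, and the normal-approximation interval (z = 1.96) are all assumptions made for illustration.

```python
import math
import statistics
from typing import Dict, List, Tuple


def summarize_runs(
    runs: List[Dict[str, float]], z: float = 1.96
) -> Dict[str, Tuple[float, float, float]]:
    """Return (mean, lower, upper) for each metric across repeated runs.

    Hypothetical helper: the interval is a normal approximation, which is
    only a rough guide when the number of runs is small.
    """
    summary = {}
    for name in runs[0]:
        values = [run[name] for run in runs]
        mean = statistics.fmean(values)
        # The sample standard deviation needs at least two runs.
        spread = statistics.stdev(values) if len(values) > 1 else 0.0
        half_width = z * spread / math.sqrt(len(values))
        summary[name] = (mean, mean - half_width, mean + half_width)
    return summary


# Three illustrative runs of the same parameter set (made-up values).
runs = [
    {"employment_rate": 0.90},
    {"employment_rate": 0.94},
    {"employment_rate": 0.92},
]
summary = summarize_runs(runs)
print(summary["employment_rate"])
```

Raising the number of runs narrows the interval at the cost of more simulation time per iteration.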

__init__(model_factory, initial_params, target_metrics, param_bounds=None, metrics_weights=None, learning_rate=0.01, max_iterations=100, method='adam', loss_type='mse', evaluation_steps=50, num_evaluation_runs=3, tolerance=1e-06, patience=10, seed=0)[source]

Initialize model calibrator.

Parameters:
  • model_factory (TypeVar(ModelFactory, bound=Callable[..., Model])) – Function to create model instances

  • initial_params (Dict[str, float]) – Initial parameter values

  • target_metrics (Dict[str, float]) – Target metric values

  • param_bounds (Optional[Dict[str, Tuple[float, float]]]) – Bounds for each parameter as (min, max) tuples

  • metrics_weights (Optional[Dict[str, float]]) – Weights for each metric in the loss function

  • learning_rate (float) – Learning rate for optimization

  • max_iterations (int) – Maximum number of optimization iterations

  • method (str) – Calibration method ("adam", "sgd", "es", "pso", "cem", "bayesian")

  • loss_type (str) – Loss function type ("mse", "mae", "huber", "relative")

  • evaluation_steps (int) – Number of simulation steps for evaluation

  • num_evaluation_runs (int) – Number of runs to average for robust evaluation

  • tolerance (float) – Convergence tolerance

  • patience (int) – Early stopping patience

  • seed (int) – Random seed
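To make the interaction of target_metrics, metrics_weights, and loss_type concrete, the following is a minimal sketch of a weighted per-metric loss. The exact formulas used by jaxabm are not documented here, so the definitions below are assumptions: the Huber threshold delta = 1.0, the "relative" loss as absolute error divided by the target magnitude, a default weight of 1.0 for unlisted metrics, and a weighted sum rather than a weighted mean. The function names and metric names are hypothetical.

```python
from typing import Dict, Optional


def metric_error(
    value: float, target: float, loss_type: str = "mse", delta: float = 1.0
) -> float:
    """Error of a single metric under the named loss (assumed formulas)."""
    diff = value - target
    if loss_type == "mse":
        return diff * diff
    if loss_type == "mae":
        return abs(diff)
    if loss_type == "huber":
        # Quadratic near zero, linear beyond delta (delta is an assumption).
        if abs(diff) <= delta:
            return 0.5 * diff * diff
        return delta * (abs(diff) - 0.5 * delta)
    if loss_type == "relative":
        # Assumed form: absolute error scaled by the target magnitude.
        return abs(diff) / max(abs(target), 1e-12)
    raise ValueError(f"unknown loss_type: {loss_type!r}")


def weighted_loss(
    metrics: Dict[str, float],
    target_metrics: Dict[str, float],
    metrics_weights: Optional[Dict[str, float]] = None,
    loss_type: str = "mse",
) -> float:
    """Weighted sum of per-metric errors; unlisted metrics get weight 1.0."""
    weights = metrics_weights or {}
    return sum(
        weights.get(name, 1.0) * metric_error(metrics[name], target, loss_type)
        for name, target in target_metrics.items()
    )


# Made-up simulated metrics, targets, and weights.
metrics = {"gdp_growth": 0.05, "unemployment": 0.10}
targets = {"gdp_growth": 0.03, "unemployment": 0.05}
weights = {"gdp_growth": 2.0}

for kind in ("mse", "mae", "huber", "relative"):
    print(kind, weighted_loss(metrics, targets, weights, kind))
```

A relative loss of this kind is useful when targets differ by orders of magnitude, since it keeps large-valued metrics from dominating the total.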

Methods

__init__(model_factory, initial_params, ...)

Initialize model calibrator.

calibrate([verbose])

Run calibration process and return optimized parameters.

get_calibration_history()

Get calibration history.

plot_calibration([figsize])

Plot comprehensive calibration results.

calibrate(verbose=True)[source]

Run calibration process and return optimized parameters.

Return type:

Dict[str, float]
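The role of param_bounds, tolerance, and patience in a calibration loop can be illustrated with a toy example. This is a sketch under stated assumptions, not the library's implementation: it uses plain gradient descent with finite-difference gradients in place of the listed methods, it treats tolerance as the minimum loss improvement that resets the patience counter, and toy_calibrate, toy_loss, and the parameter name rate are hypothetical.

```python
from typing import Callable, Dict, List, Tuple

Params = Dict[str, float]


def clip(value: float, bounds: Tuple[float, float]) -> float:
    """Project a value back into its (min, max) interval."""
    low, high = bounds
    return min(max(value, low), high)


def toy_calibrate(
    loss_fn: Callable[[Params], float],
    initial_params: Params,
    param_bounds: Dict[str, Tuple[float, float]],
    learning_rate: float = 0.1,
    max_iterations: int = 100,
    tolerance: float = 1e-6,
    patience: int = 10,
    eps: float = 1e-4,
) -> Tuple[Params, List[float]]:
    """Gradient descent with bound clipping and early stopping (illustrative)."""
    params = dict(initial_params)
    best_loss = loss_fn(params)
    best_params = dict(params)
    loss_history = [best_loss]
    stale = 0  # iterations since the last meaningful improvement
    for _ in range(max_iterations):
        # Central finite-difference gradient, one parameter at a time.
        grads = {}
        for name in params:
            up, down = dict(params), dict(params)
            up[name] += eps
            down[name] -= eps
            grads[name] = (loss_fn(up) - loss_fn(down)) / (2 * eps)
        # Gradient step followed by projection onto the bounds.
        for name in params:
            stepped = params[name] - learning_rate * grads[name]
            params[name] = clip(stepped, param_bounds[name])
        loss = loss_fn(params)
        loss_history.append(loss)
        if best_loss - loss > tolerance:
            best_loss, best_params, stale = loss, dict(params), 0
        else:
            stale += 1
            if stale >= patience:
                break  # early stopping
    return best_params, loss_history


def toy_loss(params: Params) -> float:
    """Stand-in for a simulation loss; minimized at rate = 0.5."""
    return (2.0 * params["rate"] - 1.0) ** 2


best, history = toy_calibrate(toy_loss, {"rate": 0.1}, {"rate": (0.0, 1.0)})
bounded, _ = toy_calibrate(toy_loss, {"rate": 0.1}, {"rate": (0.0, 0.4)})
print(best, len(history), bounded)
```

In the second call the unconstrained optimum lies outside the bounds, so the result sits on the upper bound and the loop ends through the patience counter rather than by reaching max_iterations.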

get_calibration_history()[source]

Get calibration history.

Return type:

Dict[str, List[Any]]

Returns:

Dictionary with 'loss', 'params', and 'confidence_intervals' histories
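As an illustration of how the returned history might be inspected, the sketch below picks the iteration with the lowest loss. The dictionary is hand-written sample data, not output from jaxabm, and it assumes that the 'loss' and 'params' lists are aligned by iteration. The helper best_iteration and the parameter name rate are hypothetical.

```python
from typing import Any, Dict, List, Tuple


def best_iteration(
    history: Dict[str, List[Any]]
) -> Tuple[int, float, Dict[str, float]]:
    """Return (index, loss, params) for the lowest recorded loss.

    Assumes history['loss'][i] and history['params'][i] describe the
    same iteration.
    """
    losses = history["loss"]
    index = min(range(len(losses)), key=losses.__getitem__)
    return index, losses[index], history["params"][index]


# Hand-written sample with the three documented keys.
history = {
    "loss": [0.64, 0.08, 0.02, 0.03],
    "params": [{"rate": 0.10}, {"rate": 0.35}, {"rate": 0.46}, {"rate": 0.44}],
    "confidence_intervals": [],
}

index, loss, params = best_iteration(history)
print(index, loss, params)
```

Checking the best iteration separately is worthwhile because the final iteration is not guaranteed to be the one with the lowest loss.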

plot_calibration(figsize=(15, 10))[source]

Plot comprehensive calibration results.

Parameters:

figsize (Tuple[int, int]) – Figure size as (width, height)

Return type:

Any

Returns:

Matplotlib figure and axes