jaxabm.analysis.ModelCalibrator
- class jaxabm.analysis.ModelCalibrator(model_factory, initial_params, target_metrics, param_bounds=None, metrics_weights=None, learning_rate=0.01, max_iterations=100, method='adam', loss_type='mse', evaluation_steps=50, num_evaluation_runs=3, tolerance=1e-06, patience=10, seed=0)
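Bases: object
Calibrate model parameters using advanced optimization techniques.
This class provides methods for automatically tuning model parameters so that simulated outputs match target metrics. It supports gradient-based optimization (Adam or SGD) and gradient-free approaches drawn from reinforcement learning and evolutionary computation (evolution strategies, particle swarm optimization, the cross-entropy method), as well as Bayesian optimization.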
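The gradient-based path can be pictured as the standard Adam update applied to each named parameter, followed by projection into `param_bounds`. This is a minimal pure-Python sketch, not jaxabm's implementation. `adam_step` and `clip_to_bounds` are hypothetical helpers, the gradients are supplied by hand, and enforcing bounds by clipping after every step is an assumption.

```python
import math
from typing import Dict, Optional, Tuple

Params = Dict[str, float]
Bounds = Dict[str, Tuple[float, float]]

def clip_to_bounds(params: Params, bounds: Optional[Bounds]) -> Params:
    """Project each bounded parameter into its (min, max) interval; others pass through."""
    if not bounds:
        return dict(params)
    return {k: min(max(v, bounds[k][0]), bounds[k][1]) if k in bounds else v
            for k, v in params.items()}

def adam_step(params: Params, grads: Params, m: Params, v: Params, t: int,
              learning_rate: float = 0.01, beta1: float = 0.9,
              beta2: float = 0.999, eps: float = 1e-8):
    """One Adam update per named parameter; t is the 1-based step count."""
    new_params, new_m, new_v = {}, {}, {}
    for k, p in params.items():
        g = grads[k]
        new_m[k] = beta1 * m.get(k, 0.0) + (1 - beta1) * g      # running mean of gradients
        new_v[k] = beta2 * v.get(k, 0.0) + (1 - beta2) * g * g  # running mean of squared gradients
        m_hat = new_m[k] / (1 - beta1 ** t)                      # bias-corrected first moment
        v_hat = new_v[k] / (1 - beta2 ** t)                      # bias-corrected second moment
        new_params[k] = p - learning_rate * m_hat / (math.sqrt(v_hat) + eps)
    return new_params, new_m, new_v

# Toy objective (x - 3)^2 with gradient 2 * (x - 3), and x constrained to [0, 2.5].
params, m, v = {"x": 0.0}, {}, {}
for t in range(1, 501):
    grads = {"x": 2.0 * (params["x"] - 3.0)}
    params, m, v = adam_step(params, grads, m, v, t, learning_rate=0.1)
    params = clip_to_bounds(params, {"x": (0.0, 2.5)})
print(params)  # → {'x': 2.5}, because the unconstrained optimum x = 3 lies outside the box
```

Clipping after each step is the simplest way to respect bounds. A common alternative is to reparameterize each bounded parameter (for example, a sigmoid mapped onto its interval), which keeps gradients informative near the edges.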
Attributes
- model_factory
Function to create model instances
- params
Current parameter values
- target_metrics
Target values for each metric
- metrics_weights
Importance weights for each metric in the loss function
- learning_rate
Learning rate for optimization
- max_iterations
Maximum number of optimization iterations
- method
Calibration method: "adam", "sgd", "es", "pso", "cem", or "bayesian"
- loss_type
Type of loss function to use: "mse", "mae", "huber", or "relative"
- param_bounds
Bounds for each parameter, as (min, max) tuples
- evaluation_steps
Number of steps to run model for evaluation
- num_evaluation_runs
Number of runs to average for robust evaluation
- loss_history
History of loss values during calibration
- param_history
History of parameter values during calibration
- confidence_intervals
Confidence intervals for metrics
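The averaging described by `evaluation_steps` and `num_evaluation_runs` can be sketched as follows. This is a hypothetical illustration rather than jaxabm code. The simulator signature `(params, n_steps, rng)`, the helper `evaluate_params`, the toy `random_walk` model, and the choice of one seed per run as `seed + run_index` are all assumptions.

```python
import random
from statistics import fmean
from typing import Callable, Dict

Params = Dict[str, float]
Metrics = Dict[str, float]
# Hypothetical simulator signature: (params, n_steps, rng) -> metrics.
Simulator = Callable[[Params, int, random.Random], Metrics]

def evaluate_params(params: Params, simulate: Simulator, evaluation_steps: int = 50,
                    num_evaluation_runs: int = 3, seed: int = 0) -> Metrics:
    """Run the simulator num_evaluation_runs times and average each metric."""
    runs = [simulate(params, evaluation_steps, random.Random(seed + i))  # one RNG stream per run
            for i in range(num_evaluation_runs)]
    return {name: fmean(run[name] for run in runs) for name in runs[0]}

def random_walk(params: Params, n_steps: int, rng: random.Random) -> Metrics:
    """Toy stochastic model: a random walk with drift and Gaussian noise."""
    pos = 0.0
    for _ in range(n_steps):
        pos += params["drift"] + rng.gauss(0.0, params["noise"])
    return {"final_position": pos}

print(evaluate_params({"drift": 0.5, "noise": 1.0}, random_walk))
```

Deriving a fixed seed per run makes each evaluation reproducible, so differences in loss between two parameter sets are not caused by changing random streams.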
- __init__(model_factory, initial_params, target_metrics, param_bounds=None, metrics_weights=None, learning_rate=0.01, max_iterations=100, method='adam', loss_type='mse', evaluation_steps=50, num_evaluation_runs=3, tolerance=1e-06, patience=10, seed=0)
Initialize model calibrator.
- Parameters:
  - model_factory (TypeVar(ModelFactory, bound=Callable[..., Model])) – Function to create model instances
  - initial_params (Dict[str, float]) – Initial parameter values
  - target_metrics – Target values for each metric
  - param_bounds (Optional[Dict[str, Tuple[float, float]]]) – Bounds for each parameter as (min, max) tuples. Defaults to None.
  - metrics_weights (Optional[Dict[str, float]]) – Weights for each metric in the loss function. Defaults to None.
  - learning_rate (float) – Learning rate for optimization. Defaults to 0.01.
  - max_iterations (int) – Maximum number of optimization iterations. Defaults to 100.
  - method (str) – Calibration method: "adam", "sgd", "es", "pso", "cem", or "bayesian". Defaults to "adam".
  - loss_type (str) – Loss function type: "mse", "mae", "huber", or "relative". Defaults to "mse".
  - evaluation_steps (int) – Number of simulation steps for evaluation. Defaults to 50.
  - num_evaluation_runs (int) – Number of runs to average for robust evaluation. Defaults to 3.
  - tolerance (float) – Convergence tolerance. Defaults to 1e-06.
  - patience (int) – Early-stopping patience. Defaults to 10.
  - seed (int) – Random seed. Defaults to 0.
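The four `loss_type` options and `metrics_weights` might combine as in the sketch below. The exact formulas are assumptions, not taken from jaxabm: the Huber threshold `delta=1.0`, the squared form of the relative error, a default weight of 1.0 for metrics absent from `metrics_weights`, and summing (rather than averaging) over metrics are all illustrative choices. The metric names `mean_wealth` and `gini` are hypothetical.

```python
from typing import Dict, Optional

def metric_loss(actual: float, target: float, loss_type: str = "mse", delta: float = 1.0) -> float:
    """Per-metric loss for one of the four loss_type strings (formulas assumed)."""
    err = actual - target
    if loss_type == "mse":
        return err ** 2
    if loss_type == "mae":
        return abs(err)
    if loss_type == "huber":
        # Quadratic for small errors, linear beyond delta.
        a = abs(err)
        return 0.5 * err ** 2 if a <= delta else delta * (a - 0.5 * delta)
    if loss_type == "relative":
        # Squared error relative to the target's magnitude; the epsilon avoids division by zero.
        return (err / (abs(target) + 1e-8)) ** 2
    raise ValueError(f"unknown loss_type: {loss_type!r}")

def weighted_loss(metrics: Dict[str, float], targets: Dict[str, float],
                  weights: Optional[Dict[str, float]] = None, loss_type: str = "mse") -> float:
    """Weighted sum over the target metrics; metrics without a weight count with 1.0."""
    weights = weights or {}
    return sum(weights.get(name, 1.0) * metric_loss(metrics[name], target, loss_type)
               for name, target in targets.items())

simulated = {"mean_wealth": 12.0, "gini": 0.45}
targets = {"mean_wealth": 10.0, "gini": 0.40}
for lt in ("mse", "mae", "huber", "relative"):
    print(lt, weighted_loss(simulated, targets, {"mean_wealth": 0.5}, loss_type=lt))
```

The "relative" form puts metrics on very different scales (here, a wealth level and an inequality index) on a comparable footing. With "mse", the large-magnitude metric would otherwise dominate unless it is down-weighted.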
Methods
- __init__(model_factory, initial_params, ...) – Initialize model calibrator.
- calibrate([verbose]) – Run calibration process and return optimized parameters.
- Get calibration history.
- plot_calibration([figsize]) – Plot comprehensive calibration results.
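A gradient-free run of `calibrate()` could look like the following cross-entropy-method (CEM) loop. The loop records per-iteration `loss_history` and `param_history` and applies `tolerance`/`patience` early stopping. This is a minimal sketch under stated assumptions, not jaxabm's implementation. The function `cem_calibrate`, the population size, the elite fraction, and the stopping rule are all hypothetical. The stopping rule used here halts after `patience` consecutive iterations without an improvement larger than `tolerance`.

```python
import math
import random
from typing import Callable, Dict, List, Tuple

Params = Dict[str, float]
Bounds = Dict[str, Tuple[float, float]]

def cem_calibrate(loss_fn: Callable[[Params], float], bounds: Bounds,
                  max_iterations: int = 100, population: int = 50,
                  elite_frac: float = 0.2, tolerance: float = 1e-6,
                  patience: int = 10, seed: int = 0):
    """Cross-entropy method over a box; returns (best_params, loss_history, param_history)."""
    rng = random.Random(seed)
    names = sorted(bounds)
    # Start at the centre of the box with a wide Gaussian spread.
    mu = {k: (bounds[k][0] + bounds[k][1]) / 2 for k in names}
    sigma = {k: (bounds[k][1] - bounds[k][0]) / 4 for k in names}
    n_elite = max(1, int(population * elite_frac))
    loss_history: List[float] = []
    param_history: List[Params] = []
    best_loss, best_params, stale = math.inf, dict(mu), 0
    for _ in range(max_iterations):
        # Sample candidates, clip them into the box, and score them.
        samples = []
        for _ in range(population):
            cand = {k: min(max(rng.gauss(mu[k], sigma[k]), bounds[k][0]), bounds[k][1])
                    for k in names}
            samples.append((loss_fn(cand), cand))
        samples.sort(key=lambda s: s[0])
        elites = [c for _, c in samples[:n_elite]]
        # Refit the sampling distribution to the elite set (small floor keeps sigma > 0).
        for k in names:
            vals = [e[k] for e in elites]
            mu[k] = sum(vals) / n_elite
            sigma[k] = math.sqrt(sum((x - mu[k]) ** 2 for x in vals) / n_elite) + 1e-8
        iter_loss, iter_params = samples[0]
        loss_history.append(iter_loss)
        param_history.append(iter_params)
        # Early stopping: count iterations that fail to beat the best loss by > tolerance.
        if iter_loss < best_loss - tolerance:
            best_loss, best_params, stale = iter_loss, iter_params, 0
        else:
            stale += 1
            if stale >= patience:
                break
    return best_params, loss_history, param_history

def loss(p: Params) -> float:
    # Toy calibration target: optimum at a = 1, b = -2.
    return (p["a"] - 1.0) ** 2 + (p["b"] + 2.0) ** 2

best, losses, history = cem_calibrate(loss, {"a": (-5.0, 5.0), "b": (-5.0, 5.0)}, seed=0)
print(best, len(losses))
```

CEM needs only loss evaluations, not gradients. That makes it a natural fit when the model's outputs are noisy or not differentiable with respect to its parameters.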