jaxabm.analysis.SensitivityAnalysis
- class jaxabm.analysis.SensitivityAnalysis(model_factory, param_ranges, metrics_of_interest, num_samples=100, seed=0)
Bases: object
Perform sensitivity analysis on model parameters.
This class provides tools for analyzing how changes in model parameters affect model outputs. It samples parameter values within the given ranges, evaluates the model on those samples, and calculates sensitivity indices for each parameter and metric.
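The evaluation step can be pictured as a loop: take each parameter draw, build and run a model with it, and record the metrics of interest. The sketch below is not the library's code. `toy_model` is a hypothetical stand-in for calling `model_factory` and running the resulting model, and it returns metrics directly as a dictionary, which the real `Model` interface may not do.

```python
import numpy as np

def toy_model(growth: float, noise: float) -> dict:
    """Hypothetical stand-in for building and running a model.

    Returns a dict mapping metric name -> scalar output.
    """
    return {"population": 100.0 * (1.0 + growth), "volatility": 2.0 * noise}

def evaluate_samples(model_fn, param_names, samples, metrics_of_interest):
    """Evaluate model_fn on each row of `samples` (shape: n_samples x n_params).

    Returns a dict mapping each metric name to a 1-D array of outputs,
    one entry per sample row.
    """
    outputs = {m: [] for m in metrics_of_interest}
    for row in samples:
        params = dict(zip(param_names, row))  # e.g. {"growth": 0.1, "noise": 0.5}
        metrics = model_fn(**params)
        for m in metrics_of_interest:
            outputs[m].append(metrics[m])
    return {m: np.asarray(v) for m, v in outputs.items()}

# Three hand-picked parameter rows; columns follow the order of param_names.
param_names = ["growth", "noise"]
samples = np.array([[0.0, 0.1], [0.1, 0.5], [0.2, 0.9]])
outputs = evaluate_samples(toy_model, param_names, samples, ["population"])
print(outputs["population"])
```

Only the metrics listed in `metrics_of_interest` are collected; other outputs of the model (here `volatility`) are ignored.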
- model_factory
Function to create model instances
- param_ranges
Dictionary mapping parameter names to (min, max) ranges
- metrics_of_interest
List of metric names to analyze
- num_samples
Number of parameter samples to generate
- key
JAX random key
- samples
Generated parameter samples
- results
Analysis results (populated after run())
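The `samples` attribute holds parameter draws within `param_ranges`. This page does not say which sampling scheme the class uses, so the sketch below assumes plain independent uniform sampling with a NumPy generator seeded from `seed`. The library itself uses a JAX random key and may use a different scheme (for example, Latin hypercube sampling). The helper name `uniform_samples` is hypothetical.

```python
import numpy as np

def uniform_samples(param_ranges, num_samples=100, seed=0):
    """Draw num_samples points uniformly inside the box given by param_ranges.

    Returns (names, samples), where samples has shape (num_samples, len(names))
    and column j corresponds to names[j].
    """
    names = list(param_ranges)
    lows = np.array([param_ranges[n][0] for n in names], dtype=float)
    highs = np.array([param_ranges[n][1] for n in names], dtype=float)
    if np.any(lows >= highs):
        raise ValueError("each range must satisfy min < max")
    rng = np.random.default_rng(seed)
    # Uniform draws in [0, 1), then scaled column-wise into [low, high).
    unit = rng.random((num_samples, len(names)))
    return names, lows + unit * (highs - lows)

names, samples = uniform_samples(
    {"growth": (0.0, 0.2), "noise": (0.1, 1.0)}, num_samples=5, seed=42
)
print(names, samples.shape)  # ['growth', 'noise'] (5, 2)
```

Fixing the seed makes the draws reproducible, which mirrors the role the `seed` argument plays for the class's JAX key.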
- __init__(model_factory, param_ranges, metrics_of_interest, num_samples=100, seed=0)
Initialize sensitivity analysis.
- Parameters:
model_factory (TypeVar(ModelFactory, bound=Callable[..., Model])) – Function to create model instances
param_ranges (Dict[str, Tuple[float, float]]) – Dictionary mapping parameter names to (min, max) ranges
metrics_of_interest (List[str]) – List of metric names to analyze
num_samples (int) – Number of parameter samples to generate
seed (int) – Random seed
Methods
__init__(model_factory, param_ranges, ...[, ...]) – Initialize sensitivity analysis.
plot([metric, ax]) – Plot sensitivity analysis results.
plot_indices([figsize]) – Plot the sensitivity indices.
run([verbose]) – Run sensitivity analysis.
sobol_indices() – Calculate sensitivity indices for each parameter and metric.
- sobol_indices()
Calculate sensitivity indices for each parameter and metric.
This is a simplified implementation: it calculates correlation-based indices as a proxy for Sobol indices. A full Sobol analysis would require a dedicated sampling scheme, such as Saltelli sampling.
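A correlation-based proxy of this kind can be sketched as follows. The exact formula used by `sobol_indices()` is not documented here, so this sketch assumes the index for parameter i and metric y is the squared Pearson correlation r_i^2. For a linear model y = sum_i a_i x_i with independent inputs, r_i^2 = a_i^2 Var(x_i) / Var(y), which equals the first-order Sobol index S_i. For nonlinear or interacting models the two can differ, which is why a full Sobol analysis needs dedicated sampling. The function name `correlation_indices` is hypothetical.

```python
import numpy as np

def correlation_indices(samples, outputs):
    """Squared Pearson correlation between each parameter column and each metric.

    samples: array of shape (n_samples, n_params)
    outputs: dict mapping metric name -> array of shape (n_samples,)
    Returns a dict mapping metric name -> array of shape (n_params,).
    """
    indices = {}
    for metric, y in outputs.items():
        y = np.asarray(y, dtype=float)
        # np.corrcoef returns a 2x2 matrix; entry [0, 1] is the correlation.
        r = np.array(
            [np.corrcoef(samples[:, j], y)[0, 1] for j in range(samples.shape[1])]
        )
        indices[metric] = r ** 2
    return indices

rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=(2000, 2))
# y depends strongly on the first parameter and only weakly on the second.
y = 3.0 * x[:, 0] + 0.5 * x[:, 1]
idx = correlation_indices(x, {"output": y})
print(idx["output"])
```

With both inputs uniform on [0, 1), the linear formula gives S_1 = 9 / 9.25 and S_2 = 0.25 / 9.25, and the estimates land close to these values.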
- plot(metric=None, ax=None, **kwargs)
Plot sensitivity analysis results.
- Parameters:
metric – Metric to plot. If None, plot Sobol indices for all metrics.
ax – Matplotlib axis to use for plotting.
**kwargs – Additional keyword arguments to pass to plotting function.
- Returns:
Matplotlib axis.
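For the single-metric case, a plot like this amounts to a bar chart of per-parameter indices drawn on a caller-supplied or newly created axis, which is then returned. The sketch below uses plain Matplotlib and is not the library's plotting code. `plot_metric_indices` and the index values are hypothetical, and extra keyword arguments are forwarded to `Axes.bar`, in the spirit of the `**kwargs` parameter described above.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

def plot_metric_indices(param_names, indices, metric, ax=None, **kwargs):
    """Bar chart of one metric's sensitivity indices; returns the axis used."""
    if ax is None:
        _, ax = plt.subplots()
    # Extra keyword arguments (color, alpha, ...) go straight to Axes.bar.
    ax.bar(param_names, indices, **kwargs)
    ax.set_xlabel("Parameter")
    ax.set_ylabel("Sensitivity index")
    ax.set_title(f"Sensitivity of {metric}")
    return ax

ax = plot_metric_indices(
    ["growth", "noise"], [0.97, 0.03], "population", color="tab:blue"
)
```

Returning the axis lets a caller place several metrics on a shared figure by passing in axes from `plt.subplots`.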