API Reference
This section provides detailed documentation for all JaxABM classes and functions.
Core Components
Package Overview
The main JaxABM package provides the following modules:
Core Model Classes
- Model: Base class for agent-based models in JaxABM.
- AgentCollection: Collection of agents of the same type.
- ModelConfig: Configuration for model execution.
Analysis and Calibration
- ModelCalibrator: Calibrate model parameters using advanced optimization techniques.
- SensitivityAnalysis: Perform sensitivity analysis on model parameters.
- Ensemble calibrator that combines multiple optimization methods.
Utilities
Legacy Support
- Base class for agent-based models in JaxABM.
- Container for managing collections of agents.
- Parameter for sensitivity analysis and model calibration.
Quick Reference
Common Classes
- class jaxabm.Model(parameters=None, seed=None)
Bases: object
Base class for agent-based models in JaxABM.
This class provides an AgentPy-like interface for creating and running agent-based models.
Example
```python
from jaxabm import Model  # MyAgent is assumed to be defined elsewhere


class MyModel(Model):
    def setup(self):
        self.agents = self.add_agents(10, MyAgent)
        self.env.temperature = 25.0

    def step(self):
        # Agents are stepped automatically by default
        # Add additional model logic here
        self.env.temperature += 0.1
        # Record data
        self.record('temperature', self.env.temperature)


parameters = {}  # placeholder: your model's parameters
model = MyModel(parameters)
results = model.run()
```
- step()
Execute a single time step.
Override this method to define model behavior. By default, it steps all agent lists.
- Return type:
- end()
Execute code at the end of a simulation.
Override this method to define behavior at the end of a simulation.
- Return type:
- update_state(env_state, agent_states, model_params, key)
Update model environment state.
This method is called by the JAX model to update the environment state based on agent states. By default, it returns the environment state unchanged. Override this method to define custom state update logic.
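As an illustration, the following sketch shows the kind of pure function an update_state override might contain. It is written as a standalone function so it runs without a model, and it assumes a hypothetical data layout: env_state is a dict of arrays, agent_states maps a collection name to a dict of per-agent arrays, and model_params is a dict. The keys "temperature", "agents", "energy" and "warming_rate" are illustrative placeholders, not part of JaxABM.

```python
import jax
import jax.numpy as jnp


def update_state(env_state, agent_states, model_params, key):
    """Return a new environment state; the inputs are never mutated.

    Hypothetical layout (an assumption for this sketch):
      env_state    = {"temperature": scalar array}
      agent_states = {"agents": {"energy": array of shape (num_agents,)}}
      model_params = {"warming_rate": float}
    """
    mean_energy = jnp.mean(agent_states["agents"]["energy"])
    # Small random perturbation drawn from the supplied PRNG key
    noise = 0.01 * jax.random.normal(key)
    new_temperature = (
        env_state["temperature"]
        + model_params["warming_rate"] * mean_energy
        + noise
    )
    # Copy the dict and replace one entry, keeping the function pure
    return {**env_state, "temperature": new_temperature}


env = {"temperature": jnp.array(25.0)}
agents = {"agents": {"energy": jnp.array([1.0, 2.0, 3.0])}}
params = {"warming_rate": 0.1}
new_env = jax.jit(update_state)(env, agents, params, jax.random.PRNGKey(0))
```

Returning a new dict instead of mutating env_state keeps the function pure, which is what allows it to be traced by jax.jit.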
- compute_metrics(env_state, agent_states, model_params)
Compute model metrics.
This method is called by the JAX model to compute metrics from model state. By default, it returns an empty dictionary. Override this method to define custom metrics.
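A compute_metrics override typically reduces per-agent arrays to scalar summaries. This minimal sketch is a standalone function that assumes a hypothetical layout (an "agents" collection with a "wealth" variable); those names are placeholders, not JaxABM conventions.

```python
import jax.numpy as jnp


def compute_metrics(env_state, agent_states, model_params):
    """Return a flat dict of scalar metrics.

    Hypothetical layout (an assumption for this sketch):
      env_state    = {"temperature": scalar array}
      agent_states = {"agents": {"wealth": array of shape (num_agents,)}}
    """
    wealth = agent_states["agents"]["wealth"]
    return {
        "temperature": env_state["temperature"],
        "mean_wealth": jnp.mean(wealth),
        # Share of agents whose wealth exceeds the population mean
        "share_above_mean": jnp.mean(wealth > jnp.mean(wealth)),
    }


metrics = compute_metrics(
    {"temperature": jnp.array(25.0)},
    {"agents": {"wealth": jnp.array([1.0, 2.0, 3.0, 10.0])}},
    {},
)
```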
- get_agent(collection_name, agent_id)
Get an agent instance by collection name and ID.
Use it to retrieve an individual agent instance, for example to call that agent's custom methods.
- class jaxabm.AgentCollection(agent_type, num_agents)
Bases: object
Collection of agents of the same type.
This class manages a collection of agents of the same type, providing methods for initialization, updating, and accessing agent states.
- agent_type
The type of agent in the collection
- num_agents
Number of agents in the collection
- model_config
Model configuration associated with this collection (set during Model.initialize)
- _states
Dictionary of agent state variables, with each variable having shape (num_agents, …)
- _key
The JAX PRNGKey used to initialize this collection
- init(key, model_config)
Initialize agent states.
This method initializes the states of all agents in the collection using the agent type’s init_state method. It is typically called by Model.initialize().
- Parameters:
key (Any) – Random key for stochastic initialization.
model_config (ModelConfig) – Model configuration settings passed from the Model.
- Return type:
- update(model_state, key, model_config)
Update all agents in the collection using JAX vmap.
This method updates the internal states (self._states) of all agents using their agent type’s update method, vectorized with jax.vmap. It assumes the agent type adheres to the AgentType protocol and its update method returns only the updated AgentState.
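The vmap-based initialization and update can be illustrated with plain JAX. In this sketch RandomWalker is a hypothetical agent type whose init_state and update methods act on a single agent and return a dict of arrays; the real AgentType protocol may use different signatures. jax.vmap maps each method over the leading num_agents axis, producing and consuming the (num_agents, …) layout described for _states, while the shared model state and configuration are passed unmapped with in_axes=None.

```python
import jax
import jax.numpy as jnp


class RandomWalker:
    """Hypothetical agent type; method names and signatures are assumptions."""

    def init_state(self, model_config, key):
        # One agent's state: a scalar position drawn from U(0, 1)
        return {"position": jax.random.uniform(key)}

    def update(self, state, model_state, model_config, key):
        # Move one agent by a step drawn from N(0, step_size^2)
        step = model_config["step_size"] * jax.random.normal(key)
        return {"position": state["position"] + step}


agent_type = RandomWalker()
num_agents = 5
model_config = {"step_size": 0.1}

# init: one key per agent, vmapped over the key axis only
init_keys = jax.random.split(jax.random.PRNGKey(0), num_agents)
states = jax.vmap(agent_type.init_state, in_axes=(None, 0))(model_config, init_keys)

# update: map over per-agent states and keys; model state and config are shared
update_keys = jax.random.split(jax.random.PRNGKey(1), num_agents)
new_states = jax.vmap(agent_type.update, in_axes=(0, None, None, 0))(
    states, {}, model_config, update_keys
)
```

Splitting the key into one subkey per agent gives each agent independent randomness while keeping the whole collection reproducible from a single key.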
- class jaxabm.analysis.ModelCalibrator(model_factory, initial_params, target_metrics, param_bounds=None, metrics_weights=None, learning_rate=0.01, max_iterations=100, method='adam', loss_type='mse', evaluation_steps=50, num_evaluation_runs=3, tolerance=1e-06, patience=10, seed=0)
Bases: object
Calibrate model parameters using advanced optimization techniques.
This class provides methods for automatically tuning model parameters to achieve desired outputs, using gradient-based optimization with Adam, or various reinforcement learning and evolutionary approaches.
- model_factory
Function to create model instances
- params
Current parameter values
- target_metrics
Target values for each metric
- metrics_weights
Importance weights for each metric in the loss function
- learning_rate
Learning rate for optimization
- max_iterations
Maximum number of optimization iterations
- method
Calibration method
- loss_type
Type of loss function to use
- param_bounds
Parameter bounds for each parameter
- evaluation_steps
Number of steps to run model for evaluation
- num_evaluation_runs
Number of runs to average for robust evaluation
- loss_history
History of loss values during calibration
- param_history
History of parameter values during calibration
- confidence_intervals
Confidence intervals for metrics
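To make the gradient-based path concrete, here is a self-contained sketch of calibration with a hand-written Adam update and a weighted MSE loss. It does not call ModelCalibrator. toy_model is a hypothetical differentiable stand-in for running a model for evaluation_steps and reading its metrics, bounds are enforced by clipping, and the loop stops after patience iterations without an improvement larger than tolerance. These choices borrow the constructor's parameter names but are assumptions, not a description of JaxABM's internals.

```python
import jax
import jax.numpy as jnp


def toy_model(params):
    """Hypothetical differentiable stand-in for a model run: params -> metrics."""
    return {
        "mean_wealth": 2.0 * params["growth"] + params["tax"],
        "inequality": params["growth"] * (1.0 - params["tax"]),
    }


def weighted_mse(params, targets, weights):
    metrics = toy_model(params)
    return sum(weights[k] * (metrics[k] - targets[k]) ** 2 for k in targets)


def calibrate(params, targets, weights, bounds, learning_rate=0.05,
              max_iterations=2000, tolerance=1e-6, patience=10,
              b1=0.9, b2=0.999, eps=1e-8):
    loss_and_grad = jax.jit(jax.value_and_grad(weighted_mse))
    m = jax.tree_util.tree_map(jnp.zeros_like, params)  # first-moment estimate
    v = jax.tree_util.tree_map(jnp.zeros_like, params)  # second-moment estimate
    best, stall, loss_history = float("inf"), 0, []
    for t in range(1, max_iterations + 1):
        loss, grads = loss_and_grad(params, targets, weights)
        loss_history.append(float(loss))
        m = jax.tree_util.tree_map(lambda m_, g: b1 * m_ + (1 - b1) * g, m, grads)
        v = jax.tree_util.tree_map(lambda v_, g: b2 * v_ + (1 - b2) * g**2, v, grads)
        # Bias-corrected Adam step for every parameter
        params = jax.tree_util.tree_map(
            lambda p, m_, v_: p - learning_rate * (m_ / (1 - b1**t))
            / (jnp.sqrt(v_ / (1 - b2**t)) + eps),
            params, m, v,
        )
        # Keep each parameter inside its (low, high) bounds by clipping
        params = {k: jnp.clip(p, *bounds[k]) for k, p in params.items()}
        # Early stopping: count iterations without a meaningful improvement
        if best - loss_history[-1] > tolerance:
            best, stall = loss_history[-1], 0
        else:
            stall += 1
            if stall >= patience:
                break
    return params, loss_history


params0 = {"growth": jnp.array(1.0), "tax": jnp.array(0.5)}
targets = {"mean_wealth": 3.2, "inequality": 1.2}  # hypothetical targets
weights = {"mean_wealth": 1.0, "inequality": 1.0}
bounds = {"growth": (0.0, 5.0), "tax": (0.0, 1.0)}
params, history = calibrate(params0, targets, weights, bounds)
```

Writing Adam out by hand keeps the sketch dependency-free; in practice an optimizer library would usually supply the update rule.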
- class jaxabm.analysis.SensitivityAnalysis(model_factory, param_ranges, metrics_of_interest, num_samples=100, seed=0)
Bases: object
Perform sensitivity analysis on model parameters.
This class provides tools for analyzing how changes in model parameters affect model outputs, using efficient sampling techniques and computing sensitivity indices.
- model_factory
Function to create model instances
- param_ranges
Dictionary mapping parameter names to (min, max) ranges
- metrics_of_interest
List of metric names to analyze
- num_samples
Number of parameter samples to generate
- key
JAX random key
- samples
Generated parameter samples
- results
Analysis results (populated after run())
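The docstring does not name the sampling scheme. The sketch below uses plain independent uniform sampling inside each (min, max) range, a swapped-in assumption; the library may use Latin hypercube or another design. sample_params is a hypothetical helper, not part of JaxABM.

```python
import jax
import jax.numpy as jnp


def sample_params(key, param_ranges, num_samples):
    """Draw num_samples points uniformly inside each (min, max) range.

    Returns a dict mapping each parameter name to an array of shape (num_samples,).
    """
    # One independent subkey per parameter
    keys = jax.random.split(key, len(param_ranges))
    return {
        name: jax.random.uniform(k, (num_samples,), minval=lo, maxval=hi)
        for k, (name, (lo, hi)) in zip(keys, param_ranges.items())
    }


param_ranges = {"growth": (0.5, 2.0), "tax": (0.0, 0.4)}
samples = sample_params(jax.random.PRNGKey(0), param_ranges, num_samples=100)
```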
- sobol_indices()
Calculate sensitivity indices for each parameter and metric.
This is a simplified implementation that calculates correlation-based indices as a proxy for Sobol indices. For a full Sobol analysis, specialized sampling would be required.
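One plausible form of a correlation-based proxy is the squared Pearson correlation between each parameter's samples and each metric's outputs, sketched below. The exact formula used by sobol_indices() is not stated here, so treat correlation_indices as a hypothetical helper. Like any correlation measure, it only captures linear relationships and ignores parameter interactions, which is why a full Sobol analysis needs dedicated sampling.

```python
import jax.numpy as jnp


def correlation_indices(samples, outputs):
    """Squared Pearson correlation of each parameter with each metric.

    samples: dict name -> array (num_samples,)
    outputs: dict metric -> array (num_samples,)
    Returns a nested dict indices[metric][param] with values in [0, 1].
    """
    indices = {}
    for metric, y in outputs.items():
        indices[metric] = {}
        for name, x in samples.items():
            r = jnp.corrcoef(x, y)[0, 1]
            indices[metric][name] = float(r**2)
    return indices


# Hypothetical toy data: the metric depends on "growth" only
samples = {
    "growth": jnp.array([0.0, 1.0, 2.0, 3.0]),
    "tax": jnp.array([1.0, 0.0, 0.0, 1.0]),
}
outputs = {"mean_wealth": 2.0 * samples["growth"]}
indices = correlation_indices(samples, outputs)
```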
- plot(metric=None, ax=None, **kwargs)
Plot sensitivity analysis results.
- Parameters:
metric – Metric to plot. If None, plot Sobol indices for all metrics.
ax – Matplotlib axis to use for plotting.
**kwargs – Additional keyword arguments to pass to plotting function.
- Returns:
Matplotlib axis.