API Reference

This section provides detailed documentation for all JaxABM classes and functions.

Core Components

Package Overview

The main JaxABM package provides the following classes, grouped by area:

Core Model Classes

Model([parameters, seed])

Base class for agent-based models in JaxABM.

AgentCollection(agent_type, num_agents)

Collection of agents of the same type.

ModelConfig([seed, steps, track_history, ...])

Configuration for model execution.

Analysis and Calibration

analysis.ModelCalibrator(model_factory, ...)

Calibrate model parameters using advanced optimization techniques.

analysis.SensitivityAnalysis(model_factory, ...)

Perform sensitivity analysis on model parameters.

analysis.EnsembleCalibrator(model_factory, ...)

Ensemble calibrator that combines multiple optimization methods.

Utilities

Legacy Support

agentpy.Model([parameters, seed])

Base class for agent-based models in JaxABM.

agentpy.AgentList(model, n, agent_class, ...)

Container for managing collections of agents.

agentpy.Parameter(name, bounds[, distribution])

Parameter for sensitivity analysis and model calibration.

Quick Reference

Common Classes

class jaxabm.Model(parameters=None, seed=None)

Bases: object

Base class for agent-based models in JaxABM.

This class provides an AgentPy-like interface for creating and running agent-based models.

Example

```python
from jaxabm import Model

# MyAgent stands for an Agent subclass defined elsewhere.

class MyModel(Model):

    def setup(self):
        self.agents = self.add_agents(10, MyAgent)
        self.env.temperature = 25.0

    def step(self):
        # Agents are stepped automatically by default.
        # Add additional model logic here.
        self.env.temperature += 0.1

        # Record data
        self.record('temperature', self.env.temperature)

# parameters: dictionary of model parameters
model = MyModel(parameters)
results = model.run()
```

setup()

Set up model.

Override this method to set up agents and environment.

Return type:

None

step()

Execute a single time step.

Override this method to define model behavior. By default, it steps all agent lists.

Return type:

None

end()

Execute code at the end of a simulation.

Override this method to define behavior at the end of a simulation.

Return type:

None

update_state(env_state, agent_states, model_params, key)

Update model environment state.

This method is called by the JAX model to update the environment state based on agent states. By default, it returns the environment state unchanged. Override this method to define custom state update logic.

Parameters:
  • env_state (Dict[str, Any]) – Current environment state.

  • agent_states (Dict[str, Dict[str, Any]]) – Current agent states by collection.

  • model_params (Dict[str, Any]) – Model parameters.

  • key (Array) – JAX random key.

Return type:

Dict[str, Any]

Returns:

Updated environment state.
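Since this method is called from the JAX model, treating its inputs as immutable and returning a new dictionary is the safer pattern. The sketch below shows a possible override body as a standalone function with the documented signature. The collection name `agents`, the state variables `temperature` and `energy`, and the parameter `heating_rate` are all hypothetical, and NumPy stands in for `jax.numpy`. In a `Model` subclass the same body would go in a method that also takes `self`.

```python
import numpy as np

def update_state(env_state, agent_states, model_params, key):
    """Hypothetical override: the environment warms in proportion to mean agent energy."""
    # agent_states maps collection name -> {state variable -> array of shape (num_agents, ...)}
    mean_energy = np.mean(agent_states["agents"]["energy"])
    new_env = dict(env_state)  # copy instead of mutating the input
    new_env["temperature"] = env_state["temperature"] + model_params["heating_rate"] * mean_energy
    return new_env  # `key` is unused here; a stochastic update would draw from it

env = {"temperature": 25.0}
agents = {"agents": {"energy": np.array([1.0, 2.0, 3.0])}}
new_env = update_state(env, agents, {"heating_rate": 0.5}, key=None)
print(new_env["temperature"])  # 26.0
print(env["temperature"])      # 25.0 (input left unchanged)
```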

compute_metrics(env_state, agent_states, model_params)

Compute model metrics.

This method is called by the JAX model to compute metrics from model state. By default, it returns an empty dictionary. Override this method to define custom metrics.

Parameters:
  • env_state (Dict[str, Any]) – Current environment state.

  • agent_states (Dict[str, Dict[str, Any]]) – Current agent states by collection.

  • model_params (Dict[str, Any]) – Model parameters.

Return type:

Dict[str, Any]

Returns:

Dictionary of metrics.
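A possible `compute_metrics` override, assuming a collection named `agents` with a `wealth` state variable (both hypothetical). Each key of the returned dictionary names one metric; NumPy again stands in for `jax.numpy`.

```python
import numpy as np

def compute_metrics(env_state, agent_states, model_params):
    """Hypothetical override: summary statistics of one agent collection plus an env value."""
    wealth = agent_states["agents"]["wealth"]  # shape (num_agents,)
    return {
        "mean_wealth": np.mean(wealth),
        "max_wealth": np.max(wealth),
        "temperature": env_state["temperature"],
    }

metrics = compute_metrics(
    {"temperature": 25.0},
    {"agents": {"wealth": np.array([1.0, 4.0, 7.0])}},
    model_params={},
)
print(metrics["mean_wealth"], metrics["max_wealth"])  # 4.0 7.0
```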

add_agents(n, agent_class, name=None, **kwargs)

Add agents to the model.

Parameters:
  • n (int) – Number of agents to add.

  • agent_class (Type[Agent]) – Agent class to use.

  • name (Optional[str]) – Name for this agent collection.

  • **kwargs – Parameters to pass to the agents.

Return type:

AgentList

Returns:

AgentList of created agents.

get_agent(collection_name, agent_id)

Get an agent instance by collection name and ID.

This allows accessing agent instances for calling custom methods.

Parameters:
  • collection_name (str) – Name of the agent collection.

  • agent_id (int) – ID of the agent.

Return type:

Optional[Agent]

Returns:

Agent instance if found, None otherwise.

record(name, value)

Record data for later analysis.

Parameters:
  • name (str) – Name for the recorded data.

  • value (Any) – Value to record.

Return type:

None

run(steps=None)

Run the model for the specified number of steps.

Parameters:

steps (Optional[int]) – Number of steps to run.

Return type:

Results

Returns:

Results object containing simulation results.

batch_run(parameter_ranges, repetitions=1)

Run model with multiple parameter combinations.

Parameters:
  • parameter_ranges (Dict[str, List[Any]]) – Dictionary mapping parameter names to lists of values.

  • repetitions (int) – Number of repetitions for each parameter combination.

Return type:

Dict[str, Results]

Returns:

Dictionary mapping parameter combinations to Results objects.
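The docstring does not say how `parameter_ranges` is expanded or how the string keys of the result are formatted. A minimal sketch, assuming a full Cartesian grid over the listed values; both helper names and the key format are hypothetical:

```python
from itertools import product

def expand_parameter_ranges(parameter_ranges):
    """Hypothetical helper: enumerate every combination of the listed values (a full grid)."""
    names = list(parameter_ranges)
    return [dict(zip(names, values)) for values in product(*(parameter_ranges[n] for n in names))]

def combination_key(params):
    """Hypothetical string key for one combination, e.g. 'growth=0.1,n_agents=50'."""
    return ",".join(f"{name}={value}" for name, value in sorted(params.items()))

ranges = {"n_agents": [50, 100], "growth": [0.1, 0.2, 0.3]}
combos = expand_parameter_ranges(ranges)
print(len(combos))                 # 6
print(combination_key(combos[0]))  # growth=0.1,n_agents=50
```

Under that assumption the grid above has 2 * 3 = 6 combinations, so `repetitions=r` would mean 6 * r simulation runs.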

class jaxabm.AgentCollection(agent_type, num_agents)

Bases: object

Collection of agents of the same type.

This class manages a collection of agents of the same type, providing methods for initialization, updating, and accessing agent states.

agent_type

The type of agent in the collection

num_agents

Number of agents in the collection

model_config

Model configuration associated with this collection (set during Model.initialize)

_states

Dictionary of agent state variables, with each variable having shape (num_agents, …)

_key

The JAX PRNGKey used to initialize this collection

init(key, model_config)

Initialize agent states.

This method initializes the states of all agents in the collection using the agent type’s init_state method. It is typically called by Model.initialize().

Parameters:
  • key (Any) – Random key for stochastic initialization.

  • model_config (ModelConfig) – Model configuration settings passed from the Model.

Return type:

None

update(model_state, key, model_config)

Update all agents in the collection using JAX vmap.

This method updates the internal states (self._states) of all agents using their agent type’s update method, vectorized with jax.vmap. It assumes the agent type adheres to the AgentType protocol and its update method returns only the updated AgentState.

Parameters:
  • model_state (Dict[str, Any]) – Current model state (environment + other agent states)

  • key (Any) – Random key for stochastic updates

  • model_config (ModelConfig) – Model configuration settings

Return type:

None
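`jax.vmap` computes the same result as looping over agents, applying the per-agent update to each agent's slice of the state arrays, and stacking the outputs; it simply does this in one vectorized call. The loop below spells out that reference semantics in NumPy. `update_agent` and its `position`/`target` variables are hypothetical, and the real method additionally threads `key` and `model_config` through. With JAX installed, `jax.vmap(update_agent, in_axes=(0, None))(states, model_state)` would produce the same arrays.

```python
import numpy as np

def update_agent(state, model_state):
    """Hypothetical per-agent update: move halfway toward the environment's target."""
    step = 0.5 * (model_state["target"] - state["position"])
    return {"position": state["position"] + step}

def update_all(states, model_state):
    """Reference semantics of a vmapped update: update each agent's slice, then stack."""
    num_agents = next(iter(states.values())).shape[0]
    per_agent = [
        update_agent({name: values[i] for name, values in states.items()}, model_state)
        for i in range(num_agents)
    ]
    # Re-assemble a dict of arrays with leading axis num_agents
    return {name: np.stack([s[name] for s in per_agent]) for name in per_agent[0]}

states = {"position": np.array([0.0, 2.0, 4.0])}
new_states = update_all(states, {"target": 2.0})
print(new_states["position"])  # [1. 2. 3.]
```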

get_states()

Get agent states (alias for states property for backward compatibility).

Return type:

Dict[str, Any]

Returns:

Dictionary of agent state variables

property states: Dict[str, Array]

Get agent states.

Returns:

Dictionary of agent state variables

aggregate(variable, fn=<function mean>)

Aggregate a state variable across agents.

Parameters:
  • variable (str) – Name of the state variable to aggregate

  • fn (Callable) – Aggregation function (default: mean)

Return type:

Any

Returns:

Aggregated value
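A sketch of the documented behaviour, assuming `fn` is applied to the whole array `states[variable]` of shape `(num_agents, ...)`; the state variables below are hypothetical:

```python
import numpy as np

def aggregate(states, variable, fn=np.mean):
    """Sketch of aggregate(): apply fn to one state variable across all agents."""
    return fn(states[variable])

states = {"wealth": np.array([1.0, 3.0, 8.0]), "active": np.array([True, False, True])}
print(aggregate(states, "wealth"))          # 4.0
print(aggregate(states, "wealth", np.max))  # 8.0
print(aggregate(states, "active", np.sum))  # 2
```

In this sketch a multi-dimensional variable would be reduced over all axes by a plain `mean`; passing something like `lambda x: np.mean(x, axis=0)` keeps per-feature values instead.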

filter(condition)

Filter agents based on a condition.

This method creates a new agent collection with agents that meet the specified condition.

Parameters:

condition (Callable[[Dict[str, Any]], Array]) – Function that takes agent state and returns boolean array

Return type:

AgentCollection

Returns:

New agent collection with filtered agents
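The condition returns one boolean per agent, and the kept agents are the rows where it is `True`. A NumPy sketch of that masking step on a dictionary of state arrays (`filter_states` and the variable names are hypothetical):

```python
import numpy as np

def filter_states(states, condition):
    """Sketch of filter(): keep the rows (agents) where condition(states) is True."""
    mask = np.asarray(condition(states), dtype=bool)  # shape (num_agents,)
    return {name: values[mask] for name, values in states.items()}

states = {"wealth": np.array([1.0, 5.0, 9.0]), "age": np.array([20, 35, 50])}
rich = filter_states(states, lambda s: s["wealth"] > 4.0)
print(rich["age"])  # [35 50]
```

Boolean-mask indexing yields an array whose length depends on the data, so in JAX this kind of selection cannot be traced under `jax.jit`. In jitted code, a fixed-shape alternative such as `jnp.where` over the mask is the usual workaround.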

class jaxabm.analysis.ModelCalibrator(model_factory, initial_params, target_metrics, param_bounds=None, metrics_weights=None, learning_rate=0.01, max_iterations=100, method='adam', loss_type='mse', evaluation_steps=50, num_evaluation_runs=3, tolerance=1e-06, patience=10, seed=0)

Bases: object

Calibrate model parameters using advanced optimization techniques.

This class provides methods for automatically tuning model parameters so that model outputs match the target metric values, using gradient-based optimization (Adam by default) or various reinforcement-learning and evolutionary approaches.

model_factory

Function to create model instances

params

Current parameter values

target_metrics

Target values for each metric

metrics_weights

Importance weights for each metric in the loss function

learning_rate

Learning rate for optimization

max_iterations

Maximum number of optimization iterations

method

Calibration method

loss_type

Type of loss function to use

param_bounds

Parameter bounds for each parameter

evaluation_steps

Number of steps to run model for evaluation

num_evaluation_runs

Number of runs to average for robust evaluation

loss_history

History of loss values during calibration

param_history

History of parameter values during calibration

confidence_intervals

Confidence intervals for metrics
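How the loss combines several metrics is not spelled out above. A plausible sketch is a weighted squared-error loss, the kind suggested by `loss_type='mse'` and `metrics_weights`, assuming missing weights default to 1.0. Whether the library sums or averages over metrics and over the `num_evaluation_runs` runs is not documented here; `weighted_squared_error` is a hypothetical helper and the metric names and values are made up for illustration.

```python
def weighted_squared_error(metrics, target_metrics, metrics_weights=None):
    """Hypothetical loss: sum over target metrics of weight * (value - target) ** 2."""
    weights = metrics_weights or {}
    total = 0.0
    for name, target in target_metrics.items():
        error = metrics[name] - target
        total += weights.get(name, 1.0) * error ** 2  # unweighted metrics count with weight 1.0
    return total

loss = weighted_squared_error(
    metrics={"mean_wealth": 12.0, "gini": 0.30},
    target_metrics={"mean_wealth": 10.0, "gini": 0.35},
    metrics_weights={"mean_wealth": 0.5, "gini": 100.0},
)
print(round(loss, 4))  # 2.25
```

Weights matter when metrics live on different scales: without the weight of 100, the 0.05 miss on `gini` would contribute only 0.0025 and be swamped by the `mean_wealth` term.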

calibrate(verbose=True)

Run the calibration process and return the optimized parameters.

Return type:

Dict[str, float]
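`tolerance` and `patience` appear in the constructor signature without further description. A common reading, assumed here rather than taken from the library, is early stopping: calibration halts once `patience` consecutive iterations fail to improve the best loss by more than `tolerance`. A sketch of that rule (`should_stop` is a hypothetical helper):

```python
def should_stop(loss_history, tolerance=1e-6, patience=10):
    """Hypothetical early-stopping rule: stop once `patience` consecutive iterations
    fail to improve the best loss by more than `tolerance`."""
    best = float("inf")
    stale = 0
    for loss in loss_history:
        if loss < best - tolerance:
            best = loss   # meaningful improvement: reset the counter
            stale = 0
        else:
            stale += 1    # no meaningful improvement this iteration
        if stale >= patience:
            return True
    return False

print(should_stop([5.0, 3.0, 2.0, 2.0, 2.0], patience=2))  # True
print(should_stop([5.0, 3.0, 2.0, 1.0], patience=2))       # False
```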

get_calibration_history()

Get calibration history.

Return type:

Dict[str, List[Any]]

Returns:

Dictionary with 'loss', 'params', and 'confidence_intervals' histories
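A short helper for reading the history, assuming the `'loss'` and `'params'` lists are aligned by iteration and each `params` entry is a dictionary of parameter values. `best_iteration` is hypothetical, and the history values below are illustrative only, not output from the library.

```python
def best_iteration(history):
    """Return (index, params) for the lowest loss in a get_calibration_history()-style dict."""
    losses = history["loss"]
    best = min(range(len(losses)), key=losses.__getitem__)  # index of the minimum loss
    return best, history["params"][best]

# Illustrative values only, shaped like the documented 'loss' and 'params' entries.
history = {
    "loss": [4.0, 1.5, 0.9, 1.1],
    "params": [{"growth": 0.10}, {"growth": 0.14}, {"growth": 0.17}, {"growth": 0.18}],
    "confidence_intervals": [],
}
print(best_iteration(history))  # (2, {'growth': 0.17})
```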

plot_calibration(figsize=(15, 10))

Plot comprehensive calibration results.

Parameters:

figsize (Tuple[int, int]) – Figure size as (width, height)

Return type:

Any

Returns:

Matplotlib figure and axes

class jaxabm.analysis.SensitivityAnalysis(model_factory, param_ranges, metrics_of_interest, num_samples=100, seed=0)

Bases: object

Perform sensitivity analysis on model parameters.

This class provides tools for analyzing how changes in model parameters affect model outputs, using efficient sampling techniques and the calculation of sensitivity indices.

model_factory

Function to create model instances

param_ranges

Dictionary mapping parameter names to (min, max) ranges

metrics_of_interest

List of metric names to analyze

num_samples

Number of parameter samples to generate

key

JAX random key

samples

Generated parameter samples

results

Analysis results (populated after run())
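The docstring mentions efficient sampling techniques without naming one. Latin hypercube sampling is a common choice for this purpose; the sketch below implements it with the standard library for `(min, max)` ranges shaped like `param_ranges`. Whether `SensitivityAnalysis` uses this scheme is an assumption, and `latin_hypercube` is a hypothetical helper.

```python
import random

def latin_hypercube(param_ranges, num_samples, seed=0):
    """Latin hypercube sample: each parameter's range is cut into num_samples equal
    strata, and every stratum is used exactly once per parameter."""
    rng = random.Random(seed)
    samples = [{} for _ in range(num_samples)]
    for name, (low, high) in param_ranges.items():
        strata = list(range(num_samples))
        rng.shuffle(strata)  # random pairing of strata across parameters
        width = (high - low) / num_samples
        for sample, stratum in zip(samples, strata):
            # uniform draw inside the assigned stratum
            sample[name] = low + (stratum + rng.random()) * width
    return samples

samples = latin_hypercube({"growth": (0.0, 1.0), "noise": (10.0, 20.0)}, num_samples=5)
print(sorted(int(s["growth"] * 5) for s in samples))  # [0, 1, 2, 3, 4]
```

Each parameter's range is split into `num_samples` equal strata and every stratum is hit exactly once, which covers each marginal more evenly than independent uniform draws with the same sample budget.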

run(verbose=True)

Run sensitivity analysis.

Parameters:

verbose (bool) – Whether to print progress information

Return type:

Dict[str, Array]

Returns:

Dictionary mapping metric names to arrays of results

sobol_indices()

Calculate sensitivity indices for each parameter and metric.

This is a simplified implementation that calculates correlation-based indices as a proxy for Sobol indices. For a full Sobol analysis, specialized sampling would be required.

Return type:

Dict[str, Dict[str, float]]

Returns:

Dictionary mapping metric names to dictionaries of parameter name -> sensitivity index
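The exact correlation-based formula is not given above. One plausible form is the squared Pearson correlation between each parameter and the metric: for a linear model with independent inputs it coincides with the first-order Sobol index S_i = Var(E[Y | X_i]) / Var(Y), which the toy example reproduces (0.8 and 0.2, summing to 1). For non-linear or interacting models it can differ substantially from true Sobol indices. The helper names are hypothetical.

```python
import math

def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)

def correlation_indices(samples, outputs):
    """Squared correlation of each parameter with one metric, as a first-order-index proxy."""
    return {name: pearson([s[name] for s in samples], outputs) ** 2 for name in samples[0]}

samples = [{"a": a, "b": b} for a, b in [(0, 0), (1, 0), (0, 1), (1, 1)]]
outputs = [2 * s["a"] + s["b"] for s in samples]  # toy linear model y = 2a + b
indices = correlation_indices(samples, outputs)
print({k: round(v, 3) for k, v in indices.items()})  # {'a': 0.8, 'b': 0.2}
```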

plot(metric=None, ax=None, **kwargs)

Plot sensitivity analysis results.

Parameters:
  • metric – Metric to plot. If None, plot Sobol indices for all metrics.

  • ax – Matplotlib axis to use for plotting.

  • **kwargs – Additional keyword arguments to pass to plotting function.

Returns:

Matplotlib axis.

plot_indices(figsize=(10, 6))

Plot the sensitivity indices.

Parameters:

figsize (Tuple[int, int]) – Figure size as (width, height)

Return type:

Any

Returns:

Matplotlib figure and axes