Quick Start Guide
This guide will get you up and running with JaxABM in just a few minutes.
Basic Concepts
JaxABM is built around these core concepts:
- Model: The main simulation container
- Agents: Individual entities with behaviors
- AgentCollection: Groups of similar agents
- Environment: Shared state and resources
- Calibration: Parameter optimization
- Sensitivity Analysis: Parameter importance testing
Your First Model
Let’s create a simple random walk model:
import jaxabm as jx
import jax.numpy as jnp
import jax.random as random

# Create model
model = jx.Model()

# Define agent behavior
def step_function(agents, env_state):
    """Agents take random steps."""
    # NOTE: a fixed seed yields the same displacement on every step.
    # In a real model, carry a key in the state and split it each step.
    key = random.PRNGKey(42)
    # Use independent subkeys so dx and dy are not identical
    key_x, key_y = random.split(key)
    # Random movement
    dx = random.normal(key_x, (agents.n_agents,)) * 0.1
    dy = random.normal(key_y, (agents.n_agents,)) * 0.1
    # Update positions
    new_x = agents.state['x'] + dx
    new_y = agents.state['y'] + dy
    return agents.update_state({'x': new_x, 'y': new_y})

# Create agents
agents = jx.AgentCollection(
    "walkers",
    n_agents=100,
    initial_state={
        'x': jnp.zeros(100),
        'y': jnp.zeros(100)
    },
    step_function=step_function
)

# Add agents to model
model.add_agent_collection(agents)

# Run simulation
results = model.run(steps=50)
print(f"Final positions: {results.agents['walkers'].state}")
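The step function above builds its key from a fixed seed on every call, so every step draws the same random numbers. The usual JAX pattern is to split a key and hand a fresh subkey to each draw. The sketch below shows that pattern in plain JAX, without any JaxABM API; `random_walk` and its arguments are illustrative names, not part of the library.

```python
import jax
import jax.numpy as jnp
import jax.random as random


def random_walk(key, n_agents=100, n_steps=50, step_size=0.1):
    """Simulate a 2D random walk, drawing fresh noise at every step."""

    def step(carry, step_key):
        x, y = carry
        # Independent subkeys for the two coordinates
        key_x, key_y = random.split(step_key)
        x = x + random.normal(key_x, (n_agents,)) * step_size
        y = y + random.normal(key_y, (n_agents,)) * step_size
        return (x, y), None

    # One key per step, so no two steps reuse the same random numbers
    step_keys = random.split(key, n_steps)
    init = (jnp.zeros(n_agents), jnp.zeros(n_agents))
    (x, y), _ = jax.lax.scan(step, init, step_keys)
    return x, y


x, y = random_walk(random.PRNGKey(42))
print(x.shape, y.shape)
```

How a key should be threaded through a JaxABM step function depends on the library's own conventions, which this guide does not show.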
Parameter Calibration
JaxABM's calibration tools search for parameter values whose simulated metrics match a set of target values:
import jaxabm as jx
import jax.random as random

# Define model factory
def create_model(params):
    class EconomyModel:
        def __init__(self, params):
            self.growth_rate = params.get('growth_rate', 0.02)
            self.volatility = params.get('volatility', 0.1)
            # Fixed seed so repeated runs are reproducible
            self.key = random.PRNGKey(0)

        def run(self, steps=100):
            # Simulate economic growth
            growth = self.growth_rate + random.normal(self.key, ()) * self.volatility
            # Unemployment is constant here, so only 'gdp' responds to the parameters
            return {'gdp': [growth], 'unemployment': [0.05]}

    return EconomyModel(params)

# Set up calibration
calibrator = jx.analysis.ModelCalibrator(
    model_factory=create_model,
    initial_params={'growth_rate': 0.01, 'volatility': 0.05},
    target_metrics={'gdp': 0.025, 'unemployment': 0.04},
    param_bounds={'growth_rate': (0.0, 0.1), 'volatility': (0.01, 0.5)},
    method='pso'  # Particle Swarm Optimization
)

# Run calibration
best_params = calibrator.calibrate()
print(f"Optimal parameters: {best_params}")
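To make the idea behind `method='pso'` concrete, here is a minimal particle swarm sketch in plain JAX. It is not JaxABM's implementation: the deterministic `toy_metrics` model, the relative squared-error `loss`, and the swarm hyperparameters are all assumptions chosen for illustration. Only the targets and bounds are taken from the example above.

```python
import jax.numpy as jnp
import jax.random as random

TARGETS = jnp.array([0.025, 0.04])   # gdp, unemployment
LOWER = jnp.array([0.0, 0.01])       # growth_rate, volatility
UPPER = jnp.array([0.1, 0.5])


def toy_metrics(params):
    """Hypothetical deterministic stand-in for a simulation run."""
    growth_rate, volatility = params[..., 0], params[..., 1]
    gdp = growth_rate
    unemployment = 0.03 + 0.1 * volatility
    return jnp.stack([gdp, unemployment], axis=-1)


def loss(params):
    """Sum of squared relative errors against the target metrics."""
    return jnp.sum(((toy_metrics(params) - TARGETS) / TARGETS) ** 2, axis=-1)


def pso(key, n_particles=30, n_iters=100, inertia=0.7, c1=1.5, c2=1.5):
    key, k_pos = random.split(key)
    # Start particles uniformly inside the parameter bounds
    pos = LOWER + random.uniform(k_pos, (n_particles, 2)) * (UPPER - LOWER)
    vel = jnp.zeros_like(pos)
    best_pos, best_loss = pos, loss(pos)
    for _ in range(n_iters):
        g_best = best_pos[jnp.argmin(best_loss)]   # best position of the swarm
        key, k1, k2 = random.split(key, 3)
        r1 = random.uniform(k1, pos.shape)
        r2 = random.uniform(k2, pos.shape)
        # Pull each particle toward its own best and the swarm's best
        vel = inertia * vel + c1 * r1 * (best_pos - pos) + c2 * r2 * (g_best - pos)
        pos = jnp.clip(pos + vel, LOWER, UPPER)
        cur = loss(pos)
        improved = cur < best_loss
        best_pos = jnp.where(improved[:, None], pos, best_pos)
        best_loss = jnp.where(improved, cur, best_loss)
    return best_pos[jnp.argmin(best_loss)], jnp.min(best_loss)


best, best_value = pso(random.PRNGKey(0))
print(best, best_value)
```

For this toy model the loss is zero at growth_rate = 0.025 and volatility = 0.1, so the swarm should settle close to that point.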
Reinforcement Learning Calibration
Use RL methods for complex parameter optimization:
# Use Actor-Critic for advanced calibration
rl_calibrator = jx.analysis.ModelCalibrator(
    model_factory=create_model,
    initial_params={'growth_rate': 0.01, 'volatility': 0.05},
    target_metrics={'gdp': 0.025, 'unemployment': 0.04},
    param_bounds={'growth_rate': (0.0, 0.1), 'volatility': (0.01, 0.5)},
    method='actor_critic',  # RL method
    max_iterations=50
)

rl_params = rl_calibrator.calibrate()
Available RL methods:
- q_learning: Q-learning with neural networks
- policy_gradient: Policy gradient methods
- actor_critic: Actor-critic algorithms
- dqn: Deep Q-Networks
Sensitivity Analysis
Analyze parameter importance:
# Set up sensitivity analysis
sensitivity = jx.analysis.SensitivityAnalysis(
    model_factory=create_model,
    param_ranges={
        'growth_rate': (0.0, 0.1),
        'volatility': (0.01, 0.5)
    },
    metrics_of_interest=['gdp', 'unemployment']
)

# Run analysis
results = sensitivity.run()

# Get Sobol indices
indices = sensitivity.sobol_indices()
print(f"Parameter importance: {indices}")

# Plot results
sensitivity.plot()
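A first-order Sobol index measures the share of output variance explained by one parameter alone. The sketch below estimates these indices in plain JAX with the Saltelli (2010) estimator. It is an illustration, not necessarily the estimator JaxABM uses, and `toy_model` is a hypothetical linear function chosen so the exact answer is known. Only the parameter ranges come from the example above.

```python
import jax.numpy as jnp
import jax.random as random

LOWER = jnp.array([0.0, 0.01])   # growth_rate, volatility
UPPER = jnp.array([0.1, 0.5])


def toy_model(params):
    """Hypothetical scalar output, linear in both parameters."""
    return 2.0 * params[:, 0] + 0.5 * params[:, 1]


def first_order_sobol(key, model, n_samples=20000):
    k_a, k_b = random.split(key)
    dim = LOWER.shape[0]
    # Two independent sample matrices drawn uniformly inside the ranges
    a = LOWER + random.uniform(k_a, (n_samples, dim)) * (UPPER - LOWER)
    b = LOWER + random.uniform(k_b, (n_samples, dim)) * (UPPER - LOWER)
    f_a, f_b = model(a), model(b)
    total_var = jnp.var(jnp.concatenate([f_a, f_b]))
    # Centering f_b lowers the estimator's variance without changing its mean
    f_b_centered = f_b - jnp.mean(f_b)
    indices = []
    for i in range(dim):
        # Matrix A with column i replaced by the same column of B
        ab_i = a.at[:, i].set(b[:, i])
        v_i = jnp.mean(f_b_centered * (model(ab_i) - f_a))
        indices.append(v_i / total_var)
    return jnp.array(indices)


sobol = first_order_sobol(random.PRNGKey(0), toy_model)
print(sobol)
```

For this linear toy model the exact indices are about 0.4 for growth_rate and 0.6 for volatility, so the estimates should land near those values.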
Advanced Features
Multi-Agent Interactions
# compute_distances, find_neighbors and compute_social_influence are not
# defined in this guide; supply your own implementations.
def interaction_function(agents, env_state):
    """Agents interact with each other."""
    # Find neighbors
    distances = compute_distances(agents.state['x'], agents.state['y'])
    neighbors = find_neighbors(distances, radius=0.5)
    # Influence behavior
    influence = compute_social_influence(agents.state, neighbors)
    return agents.update_state({'influence': influence})
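The three helpers called in the interaction function are left undefined in this guide. One possible way to write them in plain JAX is sketched below. The `'opinion'` state field and the choice of "mean opinion of neighbors" as the influence measure are assumptions made for illustration only.

```python
import jax.numpy as jnp


def compute_distances(x, y):
    """Pairwise Euclidean distances, shape (n_agents, n_agents)."""
    dx = x[:, None] - x[None, :]
    dy = y[:, None] - y[None, :]
    return jnp.sqrt(dx ** 2 + dy ** 2)


def find_neighbors(distances, radius):
    """Boolean mask: entry [i, j] is True if j lies within radius of i, j != i."""
    n = distances.shape[0]
    return (distances <= radius) & ~jnp.eye(n, dtype=bool)


def compute_social_influence(state, neighbors):
    """Mean 'opinion' of each agent's neighbors, 0 for agents with none."""
    opinion = state['opinion']
    counts = neighbors.sum(axis=1)
    totals = neighbors.astype(opinion.dtype) @ opinion
    return jnp.where(counts > 0, totals / jnp.maximum(counts, 1), 0.0)


state = {
    'x': jnp.array([0.0, 0.3, 2.0]),
    'y': jnp.array([0.0, 0.0, 0.0]),
    'opinion': jnp.array([1.0, 3.0, 5.0]),
}
distances = compute_distances(state['x'], state['y'])
neighbors = find_neighbors(distances, radius=0.5)
influence = compute_social_influence(state, neighbors)
print(influence)  # agents 0 and 1 see each other; agent 2 is isolated: 3, 1, 0
```

The dense distance matrix costs O(n²) memory, which is fine for a few thousand agents but calls for a spatial index or grid binning at larger scales.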
Custom Environment
# Add environment state
model.add_env_state('temperature', initial_value=20.0)
model.add_env_state('resources', initial_value=jnp.ones(100))

# Environment update function
def update_environment(env_state, agents):
    # Climate change
    new_temp = env_state['temperature'] + 0.01
    # Resource depletion
    consumption = jnp.sum(agents['consumers'].state['consumption'])
    new_resources = env_state['resources'] - consumption
    return {'temperature': new_temp, 'resources': new_resources}

model.set_env_step_function(update_environment)
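In the update function above, `consumption` is a scalar total, so JAX broadcasting subtracts the full total from every one of the 100 resource cells. If that is not the intent, one alternative is sketched below as a pure function on plain dictionaries, with no JaxABM API. Spreading consumption evenly across cells and flooring resources at zero are both assumptions, not behavior described by the library.

```python
import jax.numpy as jnp


def update_environment(env_state, consumption):
    """Return a new environment state; the inputs are left unmodified."""
    n_cells = env_state['resources'].shape[0]
    # Spread total consumption evenly over all cells (an assumption)
    per_cell = jnp.sum(consumption) / n_cells
    # Floor at zero so resources never become negative
    new_resources = jnp.maximum(env_state['resources'] - per_cell, 0.0)
    return {
        'temperature': env_state['temperature'] + 0.01,
        'resources': new_resources,
    }


env = {'temperature': 20.0, 'resources': jnp.ones(100)}
consumption = jnp.full(50, 0.2)   # 50 hypothetical consumers, 0.2 units each
new_env = update_environment(env, consumption)
print(new_env['temperature'], new_env['resources'][:3])
```

Returning a fresh dictionary rather than mutating the old one keeps the update compatible with JAX transformations such as `jit`.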
Performance Optimization
from jax import jit, vmap

# JIT compile for speed
@jit
def fast_step_function(agents, env_state):
    # Your agent logic here
    return agents

# Vectorize operations
@vmap
def vectorized_computation(agent_data):
    # Process each agent; replace this placeholder with your own computation
    result = agent_data
    return result
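As a concrete, runnable instance of the pattern above, the sketch below writes a computation for a single agent, maps it over all agents with `vmap`, and compiles the batched version with `jit`. The `agent_energy` function is a hypothetical example, not part of JaxABM.

```python
import jax
import jax.numpy as jnp


def agent_energy(position, velocity):
    """Computation written for a single agent (hypothetical example)."""
    return 0.5 * jnp.sum(velocity ** 2) + jnp.sum(position)


# vmap maps the single-agent function over the leading (agent) axis;
# jit compiles the batched function the first time it is called.
batched_energy = jax.jit(jax.vmap(agent_energy))

positions = jnp.ones((4, 2))                  # 4 agents, 2D positions
velocities = jnp.arange(8.0).reshape(4, 2)    # 4 agents, 2D velocities
energies = batched_energy(positions, velocities)
print(energies)  # one value per agent: 2.5, 8.5, 22.5, 44.5
```

Writing the logic for one agent and letting `vmap` add the batch dimension usually reads more clearly than indexing into batched arrays by hand.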
Next Steps
Now that you have the basics, explore:
- Detailed Tutorials: Tutorials
- Complete Examples: Examples
- Calibration Guide: Model Calibration
- API Reference: API Reference
Common Patterns
- Economic Models: see examples/economic_model for a complete economic simulation
- Epidemiological Models: see examples/sir_model for disease spread modeling
- Social Network Models: see examples/social_network for network-based simulations
- Spatial Models: see examples/spatial_model for geographic simulations