Reinforcement Learning Calibration
JaxABM provides reinforcement learning (RL) methods for parameter calibration. These methods are most useful for complex, non-linear parameter spaces where traditional optimization methods struggle.
Why Use RL for Calibration?
Reinforcement Learning approaches calibration as a sequential decision-making problem:
- Adaptive Exploration: Learns where to search in parameter space
- Non-linear Dynamics: Handles complex parameter interactions
- Robust to Noise: Naturally handles stochastic model outputs
- Memory: Learns from previous parameter evaluations
- Continuous Learning: Improves strategy over time
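The framing above can be made concrete with a toy loop. The sketch below is not JaxABM's implementation: `toy_loss`, `run_episode`, and the one-parameter model are hypothetical, and the "agent" is a random policy with greedy acceptance rather than a learned one. It only illustrates how state (current parameters), action (a parameter change), and reward (loss improvement) map onto calibration.

```python
import random


def toy_loss(params, target):
    """Squared error of a hypothetical one-parameter model (output = 2 * x)."""
    return (2.0 * params["x"] - target) ** 2


def run_episode(start, target, steps=200, step_size=0.1, seed=0):
    """Treat calibration as a sequence of (state, action, reward) steps."""
    rng = random.Random(seed)
    state = dict(start)                  # state: current parameter values
    current = toy_loss(state, target)
    history = [current]
    for _ in range(steps):
        action = rng.uniform(-step_size, step_size)   # action: a parameter change
        candidate = {"x": state["x"] + action}
        new = toy_loss(candidate, target)
        reward = current - new                        # reward: reduction in loss
        if reward > 0:                                # keep only improving moves
            state, current = candidate, new
        history.append(current)
    return state, history


final_state, loss_history = run_episode({"x": 0.0}, target=1.0)
```

A learned policy would replace the uniform draw with actions sampled from a network conditioned on the state, which is where the adaptive exploration and memory listed above come from.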
Available RL Methods
Actor-Critic (Recommended)
The most robust and performant RL method for calibration.
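import jaxabm as jx

# model_factory, initial_params, and target_metrics are defined by the user
# (see the complete example at the end of this section)
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics=target_metrics,
    method='actor_critic',
    max_iterations=50
)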
Strengths:
- Excellent convergence properties
- Handles continuous parameter spaces naturally
- Low variance gradient estimates
- Fast learning
Best For:
- Economic models with complex interactions
- High-dimensional parameter spaces
- Models with smooth parameter landscapes
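The low-variance gradient estimates listed above come from the critic, which serves as a baseline for the actor's updates. A minimal sketch of the standard one-step advantage estimate follows; `td_advantage` and the discount `gamma` are illustrative names, not part of the JaxABM API, and the library's internal estimator may differ.

```python
def td_advantage(reward, value, next_value, gamma=0.99, done=False):
    """One-step temporal-difference advantage: r + gamma * V(s') - V(s)."""
    bootstrap = 0.0 if done else gamma * next_value   # no bootstrap at episode end
    return reward + bootstrap - value


# The critic's estimate V(s) is subtracted as a baseline, so the actor is
# updated with "better or worse than expected" rather than the raw return.
advantage = td_advantage(reward=1.0, value=0.5, next_value=0.5, gamma=0.9)
```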
Policy Gradient
Direct policy optimization with enhanced numerical stability.
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics=target_metrics,
    method='policy_gradient',
    learning_rate=0.01
)
Strengths:
- Direct optimization of parameter policy
- Good for continuous actions
- Principled exploration via entropy regularization
Recent Improvements (v0.1.0):
- Fixed numerical instability issues
- Added gradient clipping and value bounds
- Enhanced safety checks for NaN handling
- Improved convergence criteria
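Entropy regularization, listed under Strengths, adds a bonus for keeping the policy spread out so that it does not collapse onto one parameter region too early. The sketch below assumes a one-dimensional Gaussian policy; `gaussian_entropy`, `regularized_objective`, and the weight `beta` are hypothetical names chosen for illustration, not JaxABM parameters.

```python
import math


def gaussian_entropy(sigma):
    """Differential entropy of a 1-D Gaussian policy with std `sigma`."""
    return 0.5 * math.log(2.0 * math.pi * math.e * sigma ** 2)


def regularized_objective(expected_return, sigma, beta=0.01):
    """Expected return plus an entropy bonus weighted by `beta`."""
    return expected_return + beta * gaussian_entropy(sigma)


# With equal expected return, the wider policy scores higher.
narrow = regularized_objective(expected_return=1.0, sigma=0.1)
wide = regularized_objective(expected_return=1.0, sigma=1.0)
```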
Q-Learning
Neural network-based Q-learning for discrete parameter optimization.
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics=target_metrics,
    method='q_learning',
    max_iterations=100
)
Strengths:
- Well-established theoretical foundations
- Good for discrete parameter choices
- Epsilon-greedy exploration
Best For:
- Models with discrete parameter choices
- When you need interpretable action-value functions
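Epsilon-greedy exploration, mentioned above, can be sketched in a few lines. The function name, the example Q-values, and the idea that each action is a discrete parameter move are assumptions for illustration; the source does not describe how JaxABM encodes actions.

```python
import random


def epsilon_greedy(q_values, epsilon, rng):
    """Pick a random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))                      # explore
    return max(range(len(q_values)), key=q_values.__getitem__)   # exploit


rng = random.Random(0)
q_values = [0.1, 0.7, 0.3]   # hypothetical values for three discrete parameter moves
greedy_action = epsilon_greedy(q_values, epsilon=0.0, rng=rng)
```

Implementations commonly decay `epsilon` over time, so that early iterations explore broadly and later ones exploit the learned values.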
Deep Q-Networks (DQN)
Advanced Q-learning with experience replay and target networks.
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics=target_metrics,
    method='dqn',
    max_iterations=100
)
Strengths:
- Stable learning with experience replay
- Handles large state spaces
- Double Q-learning for reduced overestimation
Best For:
- Complex models with large parameter spaces
- When sample efficiency is important
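Two of the ingredients named above, experience replay and double Q-learning, are sketched below in plain Python. `ReplayBuffer` and `double_q_target` are hypothetical helpers written for this illustration; they show the general techniques, not JaxABM's internals.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-size store of past transitions, sampled uniformly at random."""

    def __init__(self, capacity, seed=0):
        self.storage = deque(maxlen=capacity)   # oldest transitions drop out first
        self.rng = random.Random(seed)

    def add(self, transition):
        self.storage.append(transition)

    def sample(self, batch_size):
        return self.rng.sample(list(self.storage), batch_size)


def double_q_target(reward, next_q_online, next_q_target, gamma=0.99, done=False):
    """Online network picks the action; target network evaluates it."""
    if done:
        return reward
    best = max(range(len(next_q_online)), key=next_q_online.__getitem__)
    return reward + gamma * next_q_target[best]


buffer = ReplayBuffer(capacity=3)
for step in range(5):
    buffer.add(("state", step))
target = double_q_target(1.0, [0.1, 0.9], [5.0, 2.0], gamma=0.5)
```

Separating action selection from action evaluation is what reduces overestimation: here the target network's larger value (5.0) is ignored because the online network prefers the other action.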
Configuration Options
Learning Rate Tuning
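# Conservative learning (stable but slow)
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics=target_metrics,
    method='actor_critic',
    learning_rate=0.001
)
# Aggressive learning (fast but may be unstable)
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics=target_metrics,
    method='actor_critic',
    learning_rate=0.01
)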
Exploration Control
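# High exploration (for complex landscapes):
# policy_gradient explores more, via entropy regularization
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics=target_metrics,
    method='policy_gradient',
)
# Conservative exploration (for smooth landscapes):
# actor_critic explores less and exploits more
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics=target_metrics,
    method='actor_critic',
)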
Advanced Techniques
Custom Reward Shaping
The RL methods shape rewards automatically, combining four signals:
# Automatic reward shaping based on:
# 1. Improvement over previous iteration
# 2. Distance to target metrics
# 3. Parameter bound violations
# 4. Convergence bonuses
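The four components listed above could be combined as in the sketch below. The source only names the components, so the function `shaped_reward`, the additive form, and the weights `bound_penalty` and `convergence_bonus` are all assumptions made for illustration.

```python
def shaped_reward(loss, prev_loss, params, bounds, tolerance=1e-3,
                  bound_penalty=1.0, convergence_bonus=1.0):
    """Combine the four reward signals into one scalar (hypothetical weighting)."""
    reward = prev_loss - loss                 # 1. improvement over previous iteration
    reward -= loss                            # 2. distance to target metrics
    for name, value in params.items():        # 3. parameter bound violations
        low, high = bounds[name]
        violation = max(low - value, value - high, 0.0)
        reward -= bound_penalty * violation
    if loss < tolerance:                      # 4. convergence bonus
        reward += convergence_bonus
    return reward


inside = shaped_reward(0.5, 1.0, {"a": 0.5}, {"a": (0.0, 1.0)})
outside = shaped_reward(0.5, 1.0, {"a": 1.5}, {"a": (0.0, 1.0)})
```

In this sketch a parameter that leaves its bounds is penalized in proportion to how far outside it lies, so `outside` is lower than `inside` for the same loss.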
Multi-Objective RL
Handle multiple competing objectives:
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics={
        'accuracy': 0.95,
        'speed': 100.0,
        'robustness': 0.9
    },
    metrics_weights={
        'accuracy': 2.0,    # Prioritize accuracy
        'speed': 1.0,
        'robustness': 1.5
    },
    method='actor_critic'
)
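One common way to fold several weighted objectives into a single calibration loss is a weighted sum of squared errors. The sketch below assumes that form; `weighted_loss` is a hypothetical helper, and the source does not state which loss `ModelCalibrator` actually uses.

```python
def weighted_loss(metrics, targets, weights):
    """Weighted sum of squared errors; missing weights default to 1.0."""
    total = 0.0
    for name, target in targets.items():
        total += weights.get(name, 1.0) * (metrics[name] - target) ** 2
    return total


targets = {"accuracy": 0.95, "speed": 100.0, "robustness": 0.9}
weights = {"accuracy": 2.0, "speed": 1.0, "robustness": 1.5}
simulated = {"accuracy": 0.90, "speed": 100.0, "robustness": 0.9}   # made-up model output
loss = weighted_loss(simulated, targets, weights)
```

Because `speed` is on a scale of about 100 while `accuracy` is below 1, raw squared errors would let `speed` dominate; dividing each error by its target before squaring is one way to compensate.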
Convergence Monitoring
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics=target_metrics,
    method='actor_critic',
    tolerance=1e-4,      # Convergence threshold
    patience=10,         # Early stopping patience
    max_iterations=100
)
# Monitor progress
best_params = calibrator.calibrate(verbose=True)
# Access training history
history = calibrator.get_calibration_history()
print(f"Loss progression: {history['loss']}")
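The interplay of `tolerance` and `patience` can be sketched as a stopping rule over the loss history. This is one plausible reading, assuming `tolerance` is a threshold on the loss and `patience` counts iterations without a new best; `should_stop` is a hypothetical helper, and JaxABM's exact criterion may differ.

```python
def should_stop(losses, tolerance=1e-4, patience=10):
    """Stop when the loss is below tolerance or has not improved for `patience` steps."""
    if not losses:
        return False
    if losses[-1] < tolerance:               # converged
        return True
    if len(losses) <= patience:              # not enough history yet
        return False
    best_before = min(losses[:-patience])
    return min(losses[-patience:]) >= best_before   # no new best recently


converged = should_stop([1.0, 0.5, 1e-5])
stalled = should_stop([1.0, 0.5, 0.6, 0.7], patience=2)
improving = should_stop([1.0, 0.9, 0.8], patience=2)
```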
Performance Optimization
Batch Evaluation
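# Average each evaluation over several simulation runs to reduce noise
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics=target_metrics,
    method='actor_critic',
    num_evaluation_runs=3,   # Average over multiple runs
    evaluation_steps=50      # Longer simulations for stability
)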
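Averaging over several runs matters because a single stochastic simulation gives a noisy metric. The sketch below uses a hypothetical `noisy_model` and `averaged_metric` to show the effect; it assumes independent seeds per run and does not reflect how JaxABM schedules its evaluations.

```python
import random
import statistics


def noisy_model(params, rng):
    """Hypothetical stochastic simulation: true output plus Gaussian noise."""
    return params["x"] + rng.gauss(0.0, 0.1)


def averaged_metric(params, num_evaluation_runs=3, seed=0):
    """Average one metric over independently seeded runs."""
    outputs = [
        noisy_model(params, random.Random(seed + run))
        for run in range(num_evaluation_runs)
    ]
    return statistics.mean(outputs)


estimate = averaged_metric({"x": 1.0}, num_evaluation_runs=200)
```

The standard error of the average shrinks roughly as one over the square root of the number of runs, so going from 3 to 12 runs halves the noise at four times the cost.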
GPU Acceleration
# RL methods automatically use GPU when available
import jax
print(f"Using devices: {jax.devices()}")
# Ensure your model factory also uses GPU
def gpu_model_factory(params):
    # Your model implementation
    return model
Memory Management
# For large models, control memory usage
calibrator = jx.analysis.ModelCalibrator(
    model_factory=model_factory,
    initial_params=initial_params,
    target_metrics=target_metrics,
    method='dqn',
    max_iterations=50,     # Shorter runs
    evaluation_steps=25    # Shorter evaluations
)
Troubleshooting
Common Issues and Solutions
- Slow Convergence
  - Increase learning rate
  - Use Actor-Critic instead of Policy Gradient
  - Check parameter bounds are reasonable
- Unstable Training
  - Decrease learning rate
  - Increase evaluation_steps for more stable estimates
  - Use multiple evaluation runs
- Poor Final Performance
  - Increase max_iterations
  - Check target metrics are achievable
  - Verify model factory is working correctly
- NaN Values (Fixed in v0.1.0)
  - All RL methods now include robust NaN handling
  - Automatic gradient clipping prevents explosions
  - Safe numerical operations throughout
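Gradient clipping with a NaN guard, as referred to in the last item, can be sketched as follows. This is a generic global-norm clip, not JaxABM's code; `clip_gradients`, `max_norm`, and the choice to zero the update when a non-finite value appears are assumptions for illustration.

```python
import math


def clip_gradients(grads, max_norm=1.0):
    """Rescale gradients to at most `max_norm`; skip the update if any is non-finite."""
    if any(not math.isfinite(g) for g in grads):
        return [0.0] * len(grads)            # NaN or inf: apply no update
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)                   # already small enough
    scale = max_norm / norm
    return [g * scale for g in grads]        # same direction, bounded length


clipped = clip_gradients([3.0, 4.0], max_norm=1.0)
guarded = clip_gradients([float("nan"), 1.0])
```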
Debugging Tools
import matplotlib.pyplot as plt

# Enable verbose output
best_params = calibrator.calibrate(verbose=True)
# Plot training progress
calibrator.plot_calibration()
# Examine parameter evolution
history = calibrator.get_calibration_history()
plt.figure(figsize=(12, 4))
# Left panel: loss per iteration
plt.subplot(1, 2, 1)
plt.plot(history['loss'])
plt.title('Loss Over Time')
# Right panel: one line per parameter
plt.subplot(1, 2, 2)
for param_name in history['params'][0].keys():
    values = [p[param_name] for p in history['params']]
    plt.plot(values, label=param_name)
plt.legend()
plt.title('Parameter Evolution')
plt.show()
Best Practices
- Start with Actor-Critic: Most robust method for new problems
- Tune Learning Rates: Start in the 0.001-0.01 range
- Monitor Convergence: Use verbose mode and check loss plots
- Parameter Scaling: Normalize parameters to the [0, 1] range when possible
- Evaluation Budget: Balance accuracy against computational cost
- Reproducibility: Always set random seeds
- Validation: Test final parameters on independent datasets
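The parameter-scaling advice can be applied with a pair of helpers that map between the original bounds and the unit interval. `normalize` and `denormalize` below are hypothetical utilities written for this illustration; whether `ModelCalibrator` rescales parameters internally is not stated in the source.

```python
def normalize(params, bounds):
    """Map each parameter from [low, high] to [0, 1]."""
    return {
        name: (value - bounds[name][0]) / (bounds[name][1] - bounds[name][0])
        for name, value in params.items()
    }


def denormalize(unit_params, bounds):
    """Map each parameter from [0, 1] back to [low, high]."""
    return {
        name: bounds[name][0] + value * (bounds[name][1] - bounds[name][0])
        for name, value in unit_params.items()
    }


bounds = {"growth_rate": (0.0, 0.1), "employment_rate": (0.8, 1.0)}
params = {"growth_rate": 0.05, "employment_rate": 0.95}
unit = normalize(params, bounds)
restored = denormalize(unit, bounds)
```

With all parameters on the same scale, a single learning rate moves each of them by a comparable fraction of its allowed range.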
Example: Economic Model Calibration
A complete example using Actor-Critic to calibrate an economic model:
import jaxabm as jx
import numpy as np

def economic_model_factory(params):
    class EconomicModel:
        def __init__(self, params):
            self.growth_rate = params['growth_rate']
            self.inflation_rate = params['inflation_rate']
            self.employment_rate = params['employment_rate']

        def run(self, steps=100):
            # Simplified economic dynamics: compound growth with noise
            gdp = 1.0
            for _ in range(steps):
                gdp *= (1 + self.growth_rate)
                # Add Gaussian noise drawn from NumPy's global RNG
                gdp += np.random.normal() * 0.01
            # Compute metrics; unemployment is the complement of employment
            unemployment = max(0.0, 1.0 - self.employment_rate)
            inflation = self.inflation_rate * gdp
            return {
                'gdp_growth': [gdp - 1.0],
                'unemployment': [unemployment],
                'inflation': [inflation]
            }

    return EconomicModel(params)
# Set up RL calibration
rl_calibrator = jx.analysis.ModelCalibrator(
    model_factory=economic_model_factory,
    initial_params={
        'growth_rate': 0.02,
        'inflation_rate': 0.03,
        'employment_rate': 0.95
    },
    target_metrics={
        'gdp_growth': 0.025,     # 2.5% growth
        'unemployment': 0.05,    # 5% unemployment
        'inflation': 0.02        # 2% inflation
    },
    param_bounds={
        'growth_rate': (0.0, 0.1),
        'inflation_rate': (0.0, 0.1),
        'employment_rate': (0.8, 1.0)
    },
    method='actor_critic',
    max_iterations=50,
    tolerance=1e-3
)
# Run calibration
optimal_params = rl_calibrator.calibrate(verbose=True)
print(f"Optimal parameters: {optimal_params}")
# Validate results
test_model = economic_model_factory(optimal_params)
results = test_model.run()
print(f"Final metrics: {results}")
This example shows how an RL method can calibrate an economic model with several interacting parameters.