Source code for jaxabm.agentpy

"""
JaxABM Core Agent-based Modeling Classes with AgentPy-like Interface

This module provides the main classes for the JaxABM framework with an
AgentPy-like interface. This makes it easier to create, run, and analyze
agent-based models while maintaining the performance benefits of JAX.

The main classes are:
- Agent: Base class for creating agents
- AgentList: Container for managing collections of agents
- Environment: Container for environment state and spatial structures
- Model: Base class for creating models
- Results: Container for simulation results and analysis tools
- Parameter: Class for parameter definition (for sensitivity analysis)
- Sample: Class for parameter samples (for batch runs)
- SensitivityAnalyzer: Wrapper for sensitivity analysis
- ModelCalibrator: Wrapper for model calibration
"""

import jax
import jax.numpy as jnp
import numpy as np
import time
import warnings
import matplotlib.pyplot as plt
import os
from typing import Any, Dict, List, Optional, Tuple, Union, Callable, TypeVar, Type, Set

# Import the core JAX-based components that we'll build upon
from .agent import AgentType, AgentCollection
from .core import ModelConfig
from .model import Model as JaxModel
from .utils import convert_to_numpy, format_time, run_parallel_simulations

# Type variables for better type hinting
T = TypeVar('T', bound='Agent')
ModelType = TypeVar('ModelType', bound='Model')


class Agent:
    """Base class for agents in JaxABM.
    
    This class provides an AgentPy-like interface for creating agents. To create
    a custom agent, inherit from this class and override the setup and step methods.
    You can also add custom methods to define additional agent behaviors.
    
    Example:
        ```python
        class MyAgent(Agent):
            def setup(self):
                return {
                    'x': 0,
                    'y': 0
                }
                
            def step(self, model_state):
                # Update agent state
                return {
                    'x': self._state['x'] + 0.1,
                    'y': self._state['y'] + 0.1
                }
                
            def custom_action(self, param):
                # Custom behavior outside of the step function
                self._state['x'] = param
                return self._state['x']
        ```
    """
    
    def __init__(self):
        """Initialize agent."""
        self.id = None
        self.model = None
        self.p = {}  # Parameters
        self._state = {}
    
    def setup(self) -> Dict[str, Any]:
        """Set up agent state.
        
        Override this method to initialize agent state.
        
        Returns:
            A dictionary containing the initial agent state.
        """
        return {}
    
    def step(self, model_state: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Update agent state.
        
        Override this method to define agent behavior.
        
        Args:
            model_state: The current model state.
            
        Returns:
            A dictionary containing the updated agent state.
        """
        return self._state
    
    def update_state(self, new_state: Dict[str, Any]) -> None:
        """Update agent state.
        
        This method allows updating the agent's state from custom methods.
        It's used to ensure state changes from custom methods are properly
        reflected in the underlying model.
        
        Args:
            new_state: New state dictionary to merge with current state.
        """
        if self.model and hasattr(self.model, '_update_agent_state'):
            # If connected to a model, use the model's update mechanism
            self.model._update_agent_state(self, new_state)
        else:
            # Otherwise, just update the local state
            self._state.update(new_state)
    
    def __getattr__(self, name: str) -> Any:
        """Get agent attribute from state.
        
        This allows accessing state variables as attributes, e.g., agent.x
        instead of agent._state['x'].
        
        Args:
            name: Attribute name.
            
        Returns:
            Attribute value if it exists in _state.
        
        Raises:
            AttributeError: If attribute not found in _state.
        """
        if self._state and name in self._state:
            return self._state[name]
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
    
    def __setattr__(self, name: str, value: Any) -> None:
        """Set agent attribute.
        
        Special attributes (id, model, p, _state) are set normally.
        Other attributes are set in the _state dictionary.
        
        Args:
            name: Attribute name.
            value: Attribute value.
        """
        if name in ('id', 'model', 'p', '_state'):
            super().__setattr__(name, value)
        else:
            if not hasattr(self, '_state'):
                super().__setattr__('_state', {})
            
            # Update state
            new_state = {name: value}
            self.update_state(new_state)


class AgentWrapper(AgentType):
    """Adapter that wraps Agent class to implement AgentType protocol."""
    
    def __init__(self, agent_class: Type[Agent], params: Optional[Dict[str, Any]] = None):
        """Initialize agent wrapper.
        
        Args:
            agent_class: The Agent class to wrap.
            params: Parameters to pass to the agent.
        """
        self.agent_class = agent_class
        self.params = params or {}
        
        # Create a template instance to access methods
        self.agent_instance = agent_class()
        if params:
            self.agent_instance.p = params
    
    def init_state(self, model_config: ModelConfig, key: jax.Array) -> Dict[str, Any]:
        """Initialize agent state.
        
        Implements AgentType protocol by calling the Agent.setup method.
        
        Args:
            model_config: Model configuration.
            key: JAX random key.
            
        Returns:
            Initial agent state dictionary.
        """
        # Call the agent's setup method
        state = self.agent_instance.setup()
        
        # Ensure state is a dictionary
        if not isinstance(state, dict):
            if state is None:
                return {}
            else:
                raise ValueError(f"Agent.setup() must return a dictionary, got {type(state)}")
        
        return state
    
    def update(self, state: Dict[str, Any], model_state: Dict[str, Any],
               model_config: ModelConfig, key: jax.Array) -> Dict[str, Any]:
        """Update agent state.
        
        Implements AgentType protocol by calling the Agent.step method.
        
        Args:
            state: Current agent state.
            model_state: Current model state.
            model_config: Model configuration.
            key: JAX random key.
            
        Returns:
            Updated agent state.
        """
        # Update the agent instance's state
        self.agent_instance._state = state
        
        # Call the agent's step method
        new_state = self.agent_instance.step(model_state)
        
        # Ensure new_state is a dictionary
        if not isinstance(new_state, dict):
            if new_state is None:
                return state  # No change
            else:
                raise ValueError(f"Agent.step() must return a dictionary, got {type(new_state)}")
        
        return new_state


[docs] class AgentList: """Container for managing collections of agents. This class provides an AgentPy-like interface for managing groups of agents. Example: ```python # In Model.setup(): self.agents = AgentList(self, 10, MyAgent) # Access agent attributes: x_positions = self.agents.x # Returns array of x values # Filter agents: active_agents = self.agents.select(lambda agents: agents.active) ``` """
[docs] def __init__(self, model: 'Model', n: int, agent_class: Type[Agent], **kwargs): """Initialize agent list. Args: model: The model the agents belong to. n: Number of agents to create. agent_class: The Agent class to use. **kwargs: Parameters to pass to the agents. """ self.model = model self.n = n self.agent_class = agent_class self.params = kwargs # Create AgentWrapper and AgentCollection self.agent_type = AgentWrapper(agent_class, kwargs) self.collection = AgentCollection( agent_type=self.agent_type, num_agents=n ) # Store name for use with model self.name = None
@property def states(self) -> Dict[str, Any]: """Get all agent states. Returns: Dictionary of agent state variables. """ # Try to get states from model if hasattr(self.model, '_jax_model') and self.model._jax_model.state: agent_states = self.model._jax_model.state.get('agents', {}) if self.name and self.name in agent_states: return agent_states[self.name] # Fallback to collection if hasattr(self.collection, 'states'): return self.collection.states return {}
[docs] def __getattr__(self, name: str) -> Any: """Get agent attribute for all agents. This allows getting arrays of attribute values, e.g., agents.x. Args: name: Attribute name. Returns: Array of attribute values. Raises: AttributeError: If attribute not found. """ states = self.states if name in states: return states[name] raise AttributeError(f"'AgentList' object has no attribute '{name}'")
[docs] def __len__(self) -> int: """Get number of agents. Returns: Number of agents. """ return self.n
[docs] def select(self, condition: Callable[[Any], Any]) -> 'AgentList': """Select agents that satisfy a condition. Args: condition: Function that takes agent attributes and returns a boolean array or mask. Returns: New AgentList with selected agents. """ # Create a mask from the condition states = self.states mask = condition(states) # Count selected agents filtered_count = int(jnp.sum(mask)) # Create new AgentList with same agent type filtered = AgentList( model=self.model, n=filtered_count, agent_class=self.agent_class, **self.params ) # Filter the collection (if possible) if hasattr(self.collection, 'filter'): filtered.collection = self.collection.filter(lambda s: mask) return filtered
[docs] def __iter__(self): """Allow iteration over agents. This returns the actual agent instances stored in the model, allowing access to custom methods. Yields: Agent instances with appropriate state. """ # If we have access to the model's stored agent instances, use those if self.model and hasattr(self.model, '_agent_instances'): if self.name in self.model._agent_instances: yield from self.model._agent_instances[self.name] return # Fallback: Create agent instances on-the-fly (less efficient) states = self.states if not states: return # Get an arbitrary state variable to determine agent count any_state = next(iter(states.values())) num_agents = len(any_state) # Create agent instances with appropriate state for i in range(num_agents): agent = self.agent_class() agent.model = self.model agent.id = i agent.p = self.params agent._state = {k: states[k][i] for k in states} yield agent
class Environment: """Environment for agent interactions. This class provides a container for environment state and methods for creating and managing spatial structures like grids and networks. Example: ```python # In Model.setup(): self.env.add_state('temperature', 25.0) # Access environment state: temp = self.env.temperature ``` """ def __init__(self, model: 'Model'): """Initialize environment. Args: model: The model the environment belongs to. """ self.model = model self.state = {} def add_state(self, name: str, value: Any) -> None: """Add state variable to environment. Args: name: State variable name. value: State variable value. """ self.state[name] = value # Add to model if it exists if hasattr(self.model, '_jax_model') and self.model._jax_model: self.model._jax_model.add_env_state(name, value) def __getattr__(self, name: str) -> Any: """Get environment state variable. Args: name: State variable name. Returns: State variable value. Raises: AttributeError: If state variable not found. """ if name in self.state: return self.state[name] # Check model state if environment state isn't available if hasattr(self.model, '_jax_model') and self.model._jax_model and self.model._jax_model.state: env_state = self.model._jax_model.state.get('env', {}) if name in env_state: return env_state[name] raise AttributeError(f"'Environment' object has no attribute '{name}'") def __setattr__(self, name: str, value: Any) -> None: """Set environment attribute. Special attributes (model, state) are set normally. Other attributes are set in the state dictionary and the underlying JAX model if it exists. Args: name: Attribute name. value: Attribute value. """ if name in ('model', 'state'): super().__setattr__(name, value) else: if not hasattr(self, 'state'): super().__setattr__('state', {}) self.state[name] = value # Update JAX model if it exists if hasattr(self, 'model') and hasattr(self.model, '_jax_model') and self.model._jax_model: self.model._jax_model.add_env_state(name, value) class Grid: """Grid environment for spatial agent-based models. This class provides a 2D grid for agent interactions. Example: ```python # In Model.setup(): self.grid = Grid(self, (10, 10)) # Position agents on grid: self.grid.position_agents(self.agents) ``` """ def __init__(self, model: 'Model', shape: Tuple[int, int], periodic: bool = False): """Initialize grid. Args: model: The model the grid belongs to. shape: Grid shape (width, height). periodic: Whether the grid has periodic boundaries. """ self.model = model self.shape = shape self.periodic = periodic # Add grid to model environment self.model.env.add_state('grid_shape', shape) self.model.env.add_state('grid_periodic', periodic) def position_agents(self, agents: AgentList, positions: Optional[jnp.ndarray] = None) -> None: """Position agents on the grid. Args: agents: The agents to position. positions: Optional array of positions (x, y) for each agent. If not provided, agents are positioned randomly. """ n = len(agents) width, height = self.shape # Generate random positions if not provided if positions is None: key = jax.random.PRNGKey(self.model.p.get('seed', 0)) x = jax.random.randint(key, (n,), 0, width) key, subkey = jax.random.split(key) y = jax.random.randint(subkey, (n,), 0, height) positions = jnp.column_stack((x, y)) # Update agent states if hasattr(agents.collection, 'states') and agents.collection.states is not None: # Check if position already exists in states states = agents.collection.states if 'position' in states: # We can't directly modify JAX arrays, so we create a new states dict new_states = {k: v for k, v in states.items()} new_states['position'] = positions agents.collection._states = new_states else: # Add position to states if hasattr(agents.collection, '_states'): agents.collection._states['position'] = positions class Network: """Network environment for agent interactions. This class provides a network structure for agent interactions. Example: ```python # In Model.setup(): self.network = Network(self) # Add edges: self.network.add_edge(agent1, agent2) ``` """ def __init__(self, model: 'Model', directed: bool = False): """Initialize network. Args: model: The model the network belongs to. directed: Whether the network is directed. """ self.model = model self.directed = directed # Add network to model environment self.model.env.add_state('network_directed', directed) self.model.env.add_state('network_edges', jnp.zeros((0, 2), dtype=jnp.int32)) def add_edge(self, from_agent: Union[Agent, int], to_agent: Union[Agent, int]) -> None: """Add an edge to the network. Args: from_agent: The source agent or agent ID. to_agent: The target agent or agent ID. """ # Convert agents to IDs if necessary from_id = from_agent.id if isinstance(from_agent, Agent) else from_agent to_id = to_agent.id if isinstance(to_agent, Agent) else to_agent # Get current edges current_edges = self.model.env.network_edges # Add new edge new_edge = jnp.array([[from_id, to_id]], dtype=jnp.int32) new_edges = jnp.concatenate([current_edges, new_edge], axis=0) # Update model environment self.model.env.add_state('network_edges', new_edges) # Add reverse edge for undirected networks if not self.directed and from_id != to_id: self.add_edge(to_id, from_id) def get_neighbors(self, agent: Union[Agent, int]) -> jnp.ndarray: """Get neighbors of an agent. Args: agent: The agent or agent ID. Returns: Array of neighbor agent IDs. """ # Convert agent to ID if necessary agent_id = agent.id if isinstance(agent, Agent) else agent # Get edges edges = self.model.env.network_edges # Find neighbors if self.directed: # Only outgoing edges mask = edges[:, 0] == agent_id neighbors = edges[mask, 1] else: # Both incoming and outgoing edges mask1 = edges[:, 0] == agent_id mask2 = edges[:, 1] == agent_id neighbors1 = edges[mask1, 1] neighbors2 = edges[mask2, 0] neighbors = jnp.concatenate([neighbors1, neighbors2], axis=0) # Remove duplicates neighbors = jnp.unique(neighbors) return neighbors class Results: """Container for simulation results. This class provides an AgentPy-like interface for accessing and visualizing simulation results. Example: ```python results = model.run() # Plot results: results.plot() # Access specific variables: results.variables.MyAgent.x.plot() ``` """ class VariableContainer: """Container for simulation variables.""" def __init__(self, data: Dict[str, Any]): """Initialize variable container. Args: data: Simulation data. """ self._data = data # Create dynamic attributes for each agent type agent_types = set() for key in data: if key.startswith('agents.'): agent_type = key.split('.')[1] agent_types.add(agent_type) # Create nested containers for agent types for agent_type in agent_types: setattr(self, agent_type, self.AgentContainer(data, agent_type)) class AgentContainer: """Container for agent variables.""" def __init__(self, data: Dict[str, Any], agent_type: str): """Initialize agent container. Args: data: Simulation data. agent_type: Agent type name. """ self._data = data self._agent_type = agent_type # Create dynamic attributes for each agent variable self._variables = set() for key in data: if key.startswith(f'agents.{agent_type}.'): var_name = key.split('.')[-1] self._variables.add(var_name) setattr(self, var_name, self.VariableSeries(data[key], var_name)) class VariableSeries: """Series of variable values over time.""" def __init__(self, values: List[Any], name: str): """Initialize variable series. Args: values: List of values over time. name: Variable name. """ self._values = values self._name = name def plot(self, ax=None, **kwargs): """Plot variable values over time. Args: ax: Matplotlib axis. **kwargs: Additional keyword arguments for plotting. Returns: Matplotlib axis. """ if ax is None: fig, ax = plt.subplots() # Handle different data types if self._values and hasattr(self._values[0], 'shape') and len(self._values[0].shape) > 0: # Matrix data - plot mean data = np.mean(np.array(self._values), axis=1) ax.plot(data, **kwargs) ax.set_ylabel(f'Mean {self._name}') else: # Scalar data ax.plot(self._values, **kwargs) ax.set_ylabel(self._name) ax.set_xlabel('Time') return ax def __getitem__(self, key: int) -> Any: """Get value at time step. Args: key: Time step. Returns: Value at time step. """ return self._values[key] def __len__(self) -> int: """Get number of time steps. Returns: Number of time steps. """ return len(self._values) def __init__(self, data: Dict[str, Any]): """Initialize results. Args: data: Simulation data. """ self._data = data # Convert JAX arrays to numpy self._data = convert_to_numpy(self._data) # Create variable container self.variables = self.VariableContainer(self._data) def plot(self, variables: Optional[List[str]] = None, ax=None, **kwargs): """Plot results. Args: variables: List of variables to plot. ax: Matplotlib axis. **kwargs: Additional keyword arguments for plotting. Returns: Matplotlib axis. """ if ax is None: fig, ax = plt.subplots() if variables is None: # Plot all scalar variables for key, values in self._data.items(): if isinstance(values, list) and all(isinstance(v, (int, float, np.number)) for v in values): ax.plot(values, label=key, **kwargs) else: # Plot specified variables for var in variables: if var in self._data: ax.plot(self._data[var], label=var, **kwargs) ax.legend() ax.set_xlabel('Time') return ax def save(self, filename: str) -> None: """Save results to file. Args: filename: Filename to save results to. """ import pickle with open(filename, 'wb') as f: pickle.dump(self._data, f) @classmethod def load(cls, filename: str) -> 'Results': """Load results from file. Args: filename: Filename to load results from. Returns: Results object. """ import pickle with open(filename, 'rb') as f: data = pickle.load(f) return cls(data)
[docs] class Model: """Base class for agent-based models in JaxABM. This class provides an AgentPy-like interface for creating and running agent-based models. Example: ```python class MyModel(Model): def setup(self): self.agents = self.add_agents(10, MyAgent) self.env.temperature = 25.0 def step(self): # Agents are stepped automatically by default # Add additional model logic here self.env.temperature += 0.1 # Record data self.record('temperature', self.env.temperature) model = MyModel(parameters) results = model.run() ``` """
[docs] def __init__(self, parameters: Optional[Dict[str, Any]] = None, seed: Optional[int] = None): """Initialize model. Args: parameters: Model parameters. seed: Random seed. """ self.p = parameters or {} # Set seed from parameters or argument self.seed = seed if seed is not None else self.p.get('seed', 0) # Set steps from parameters self.steps = self.p.get('steps', 100) # Initialize environment self.env = Environment(self) # Initialize data recording self._recorded_data = {} # Agent lists self._agent_lists = {} # JAX model instance (initialized during run) self._jax_model = None # Initialize state self._current_env_state = {} self._current_agent_states = {} # Track if we're currently in a run self._running = False # Store actual agent instances for custom method access self._agent_instances = {}
[docs] def setup(self) -> None: """Set up model. Override this method to set up agents and environment. """ pass
[docs] def step(self) -> None: """Execute a single time step. Override this method to define model behavior. By default, it steps all agent lists. """ # Default behavior: Let JAX model handle agent stepping pass
[docs] def end(self) -> None: """Execute code at the end of a simulation. Override this method to define behavior at the end of a simulation. """ pass
[docs] def update_state(self, env_state: Dict[str, Any], agent_states: Dict[str, Dict[str, Any]], model_params: Dict[str, Any], key: jax.Array) -> Dict[str, Any]: """Update model environment state. This method is called by the JAX model to update the environment state based on agent states. By default, it returns the environment state unchanged. Override this method to define custom state update logic. Args: env_state: Current environment state. agent_states: Current agent states by collection. model_params: Model parameters. key: JAX random key. Returns: Updated environment state. """ # Store the current state for access in step self._current_env_state = env_state.copy() self._current_agent_states = agent_states # Call user-defined step function self.step() # Create new environment state from current + local env state new_env_state = {**env_state} for name, value in self.env.state.items(): new_env_state[name] = value return new_env_state
[docs] def compute_metrics(self, env_state: Dict[str, Any], agent_states: Dict[str, Dict[str, Any]], model_params: Dict[str, Any]) -> Dict[str, Any]: """Compute model metrics. This method is called by the JAX model to compute metrics from model state. By default, it returns an empty dictionary. Override this method to define custom metrics. Args: env_state: Current environment state. agent_states: Current agent states by collection. model_params: Model parameters. Returns: Dictionary of metrics. """ return {}
[docs] def add_agents(self, n: int, agent_class: Type[Agent], name: Optional[str] = None, **kwargs) -> AgentList: """Add agents to the model. Args: n: Number of agents to add. agent_class: Agent class to use. name: Name for this agent collection. **kwargs: Parameters to pass to the agents. Returns: AgentList of created agents. """ # Create agent list agent_list = AgentList(self, n, agent_class, **kwargs) # Generate name if not provided if name is None: name = agent_class.__name__.lower() + 's' # Store agent list self._agent_lists[name] = agent_list agent_list.name = name # Create and store actual agent instances for custom method access self._agent_instances[name] = [] for i in range(n): agent = agent_class() agent.id = i agent.model = self agent.p = kwargs self._agent_instances[name].append(agent) return agent_list
def _update_agent_state(self, agent: Agent, new_state: Dict[str, Any]) -> None: """Update agent state from custom methods. This method is called when an agent's custom method wants to update the agent's state outside of the step function. Args: agent: The agent to update. new_state: The new state to apply. """ # Find the agent's collection and index for name, agents in self._agent_instances.items(): if agent in agents: index = agents.index(agent) # If model is running, update the JAX model state if self._running and self._jax_model and self._jax_model.state: # Update the agent's state in the JAX model agent_states = self._jax_model.state.get('agents', {}) if name in agent_states: collection_states = agent_states[name] for state_name, value in new_state.items(): if state_name in collection_states: # JAX arrays are immutable, so we need a workaround # This is not ideal for performance but necessary for flexibility current_values = collection_states[state_name] new_values = np.array(current_values) # Convert to numpy for mutability new_values[index] = value # Update the JAX model state collection_states[state_name] = jnp.array(new_values) # Also update the agent's local state agent._state.update(new_state) break
[docs] def get_agent(self, collection_name: str, agent_id: int) -> Optional[Agent]: """Get an agent instance by collection name and ID. This allows accessing agent instances for calling custom methods. Args: collection_name: Name of the agent collection. agent_id: ID of the agent. Returns: Agent instance if found, None otherwise. """ if collection_name in self._agent_instances: agents = self._agent_instances[collection_name] if 0 <= agent_id < len(agents): return agents[agent_id] return None
[docs] def record(self, name: str, value: Any) -> None: """Record data for later analysis. Args: name: Name for the recorded data. value: Value to record. """ self._recorded_data.setdefault(name, []).append(value)
[docs] def run(self, steps: Optional[int] = None) -> Results: """Run the model for the specified number of steps. Args: steps: Number of steps to run. Returns: Results object containing simulation results. """ # Set steps from argument or parameters if steps is not None: self.steps = steps # Mark model as running self._running = True # Create model config config = ModelConfig( steps=self.steps, collect_interval=1, seed=self.seed ) # Initialize the model self.setup() # Check if step method has been overridden self._dynamic_state_update = self.__class__.step != Model.step # Create JAX model self._jax_model = JaxModel( params=self.p, config=config, update_state_fn=self.update_state, metrics_fn=self.compute_metrics ) # Add agent collections to the model for name, agent_list in self._agent_lists.items(): self._jax_model.add_agent_collection(name, agent_list.collection) # Add environment state to the model for name, value in self.env.state.items(): self._jax_model.add_env_state(name, value) # Run the JAX model start_time = time.time() results_dict = self._jax_model.run() end_time = time.time() # Call end method self.end() # Mark model as not running self._running = False # Prepare results results_dict.update(self._recorded_data) # Format simulation time simulation_time = end_time - start_time # Add agent states to results if self._jax_model.state and 'agents' in self._jax_model.state: for name, states in self._jax_model.state['agents'].items(): for state_name, values in states.items(): results_dict[f'agents.{name}.{state_name}'] = values # Create and return Results object results = Results(results_dict) # Print execution time print(f"Simulation executed in {format_time(simulation_time)}") return results
[docs] def batch_run(self, parameter_ranges: Dict[str, List[Any]], repetitions: int = 1) -> Dict[str, Results]: """Run model with multiple parameter combinations. Args: parameter_ranges: Dictionary mapping parameter names to lists of values. repetitions: Number of repetitions for each parameter combination. Returns: Dictionary mapping parameter combinations to Results objects. """ # Generate all parameter combinations import itertools param_names = list(parameter_ranges.keys()) param_values = list(itertools.product(*parameter_ranges.values())) results = {} # Run each parameter combination for values in param_values: # Create parameter dictionary params = {**self.p} # Start with current parameters for name, value in zip(param_names, values): params[name] = value # Run repetitions rep_results = [] for rep in range(repetitions): # Create new model with these parameters model = self.__class__(params, seed=self.seed + rep) # Run model rep_results.append(model.run()) # Store results params_tuple = tuple(values) results[params_tuple] = rep_results return results
# Make classes available at the module level __all__ = ['Agent', 'AgentList', 'Environment', 'Grid', 'Network', 'Model', 'Results', 'Parameter', 'Sample', 'SensitivityAnalyzer', 'ModelCalibrator']
[docs] class Parameter: """Parameter for sensitivity analysis and model calibration. This class defines a parameter with a range of possible values, which can be used for sensitivity analysis or parameter calibration. Example: ```python # Create a parameter for sensitivity analysis p1 = Parameter('growth_rate', bounds=(0.01, 0.1)) # Create a parameter with a distribution p2 = Parameter('initial_population', bounds=(10, 1000), distribution='uniform') ``` """
[docs] def __init__(self, name: str, bounds: Tuple[float, float], distribution: str = 'uniform'): """Initialize parameter. Args: name: Parameter name. bounds: Parameter bounds (min, max). distribution: Distribution for sampling ('uniform', 'normal', etc.). """ self.name = name self.bounds = bounds self.distribution = distribution
[docs] def sample(self, n: int = 1) -> np.ndarray: """Sample parameter values. Args: n: Number of samples. Returns: Array of sampled values. """ # Use numpy for simple random sampling if self.distribution == 'uniform': return np.random.uniform(self.bounds[0], self.bounds[1], size=n) elif self.distribution == 'normal': mean = (self.bounds[0] + self.bounds[1]) / 2 std = (self.bounds[1] - self.bounds[0]) / 4 # Approximate return np.random.normal(mean, std, size=n) else: raise ValueError(f"Unknown distribution: {self.distribution}")
class Sample: """Container for parameter samples. This class stores parameter samples for batch runs or sensitivity analysis. Example: ```python # Create parameters p1 = Parameter('growth_rate', (0.01, 0.1)) p2 = Parameter('initial_population', (10, 1000)) # Create sample sample = Sample([p1, p2], n_samples=10) # Run model with sample results = analyzer.run(sample) ``` """ def __init__(self, parameters: List[Parameter], n_samples: int = 10): """Initialize sample. Args: parameters: List of parameters. n_samples: Number of samples per parameter. """ self.parameters = parameters self.n_samples = n_samples # Sample parameter values self._samples = {} # Use numpy for random sampling for param in parameters: self._samples[param.name] = param.sample(n_samples) def __getitem__(self, index: int) -> Dict[str, float]: """Get parameter set at index. Args: index: Sample index. Returns: Dictionary of parameter values. """ params = {} for param in self.parameters: params[param.name] = self._samples[param.name][index] return params def __len__(self) -> int: """Get number of samples. Returns: Number of samples. """ return self.n_samples class SensitivityAnalyzer: """Wrapper for sensitivity analysis with AgentPy-like interface. This class provides a more user-friendly interface for sensitivity analysis with JaxABM. Example: ```python # Create parameters p1 = Parameter('growth_rate', (0.01, 0.1)) p2 = Parameter('initial_population', (10, 1000)) # Create analyzer analyzer = SensitivityAnalyzer( MyModel, parameters=[p1, p2], n_samples=10, metrics=['population', 'resources'] ) # Run analysis results = analyzer.run() # Calculate sensitivity sensitivity = analyzer.calculate_sensitivity() ``` """ def __init__(self, model_class: Type[Model], parameters: List[Parameter], n_samples: int = 10, metrics: List[str] = None): """Initialize sensitivity analyzer. Args: model_class: Model class to analyze. parameters: List of parameters to vary. n_samples: Number of samples per parameter. metrics: List of metrics to analyze. """ self.model_class = model_class self.parameters = parameters self.n_samples = n_samples self.metrics = metrics or [] # Create sample self.sample = Sample(parameters, n_samples) # Import actual sensitivity analysis from .analysis import SensitivityAnalysis # Create model factory function def model_factory(params=None, config=None): """Create model instance with parameters.""" model = self.model_class(params) return model # Create parameter ranges for sensitivity analysis param_ranges = {} for param in parameters: param_ranges[param.name] = param.bounds # Create sensitivity analysis self.analysis = SensitivityAnalysis( model_factory=model_factory, param_ranges=param_ranges, metrics_of_interest=metrics, num_samples=n_samples ) def run(self) -> Dict[str, Any]: """Run sensitivity analysis. Returns: Dictionary of results. """ results_obj = self.analysis.run() # Convert Results object to dictionary if needed if hasattr(results_obj, '_data'): return results_obj._data return results_obj def calculate_sensitivity(self, method: str = 'sobol') -> Dict[str, Any]: """Calculate sensitivity indices. Args: method: Method for calculating sensitivity indices. 'sobol' or 'morris'. Returns: Dictionary of sensitivity indices. """ if method == 'sobol': return self.analysis.sobol_indices() elif method == 'morris': return self.analysis.morris_indices() else: raise ValueError(f"Unknown method: {method}") def plot(self, metric: Optional[str] = None, ax=None, **kwargs): """Plot sensitivity analysis results. Args: metric: Metric to plot. If None, plot all metrics. ax: Matplotlib axis. **kwargs: Additional keyword arguments for plotting. Returns: Matplotlib axis. """ # For now, we'll just call the underlying analysis plot method return self.analysis.plot(metric, ax, **kwargs) class ModelCalibrator: """Wrapper for model calibration with AgentPy-like interface. This class provides a more user-friendly interface for model calibration with JaxABM. Example: ```python # Create parameters p1 = Parameter('growth_rate', (0.01, 0.1)) p2 = Parameter('initial_population', (10, 1000)) # Create calibrator calibrator = ModelCalibrator( MyModel, parameters=[p1, p2], target_metrics={'population': 500, 'resources': 1000}, metrics_weights={'population': 1.0, 'resources': 0.5} ) # Run calibration optimal_params = calibrator.run() ``` """ def __init__(self, model_class: Type[Model], parameters: List[Parameter], target_metrics: Dict[str, float], metrics_weights: Dict[str, float] = None, learning_rate: float = 0.01, max_iterations: int = 20, method: str = 'gradient'): """Initialize model calibrator. Args: model_class: Model class to calibrate. parameters: List of parameters to optimize. target_metrics: Dictionary of target metrics. metrics_weights: Dictionary of metric weights for loss function. learning_rate: Learning rate for optimization. max_iterations: Maximum number of iterations. method: Optimization method ('gradient' or 'rl'). """ self.model_class = model_class self.parameters = parameters self.target_metrics = target_metrics self.metrics_weights = metrics_weights or {m: 1.0 for m in target_metrics} self.learning_rate = learning_rate self.max_iterations = max_iterations self.method = method # Import actual model calibrator from .analysis import ModelCalibrator as JaxModelCalibrator # Create model factory function def model_factory(params=None, config=None): """Create model instance with parameters.""" model = self.model_class(params) return model # Create initial parameters initial_params = {} for param in parameters: # Start with middle of range initial_params[param.name] = (param.bounds[0] + param.bounds[1]) / 2 # Create calibrator self.calibrator = JaxModelCalibrator( model_factory=model_factory, initial_params=initial_params, target_metrics=target_metrics, metrics_weights=metrics_weights, learning_rate=learning_rate, max_iterations=max_iterations, method=method ) def run(self) -> Dict[str, float]: """Run calibration. Returns: Dictionary of optimized parameters. """ return self.calibrator.calibrate() def plot_progress(self, ax=None, **kwargs): """Plot calibration progress. Args: ax: Matplotlib axis. **kwargs: Additional keyword arguments for plotting. Returns: Matplotlib axis. """ if ax is None: fig, ax = plt.subplots() # Plot loss over iterations if hasattr(self.calibrator, 'loss_history') and self.calibrator.loss_history: ax.plot(self.calibrator.loss_history, **kwargs) ax.set_xlabel('Iteration') ax.set_ylabel('Loss') ax.set_title('Calibration Progress') ax.grid(True) return ax