jaxabm.AgentCollection

class jaxabm.AgentCollection(agent_type, num_agents)

Bases: object

Collection of agents of the same type.

This class stores the states of all agents of one type as arrays whose leading axis indexes agents, and provides methods for initializing, updating, and accessing those states.

agent_type

The type of agent in the collection

num_agents

Number of agents in the collection

model_config

Model configuration associated with this collection (set during Model.initialize())

_states

Dictionary of agent state variables, with each variable having shape (num_agents, …)

_key

The JAX PRNGKey used to initialize this collection
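The layout described by _states can be made concrete with a minimal sketch using plain jax.numpy arrays. The variable names wealth and position and the helper agent_view are hypothetical and not part of jaxabm. The only assumption taken from this page is that every state variable has shape (num_agents, …), so one agent's state is a row slice across all variables.

```python
import jax.numpy as jnp

num_agents = 4

# Hypothetical state variables (not jaxabm names): leading axis = agent index.
states = {
    "wealth": jnp.zeros((num_agents,)),      # one scalar per agent
    "position": jnp.zeros((num_agents, 2)),  # one 2-D vector per agent
}

def agent_view(states, i):
    # Hypothetical helper: collect row i of every variable into one agent's state.
    return {name: value[i] for name, value in states.items()}

agent_2 = agent_view(states, 2)
print(agent_2["position"].shape)  # (2,)
```

Storing one array per variable (rather than one object per agent) is what allows a single vectorized operation to touch every agent at once.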

__init__(agent_type, num_agents)

Create an agent collection placeholder.

Agent states are not created here; they are initialized later by init(), which is typically called by Model.initialize().

Parameters:
  • agent_type (AgentType) – The type of agent in this collection (must adhere to the AgentType protocol).

  • num_agents (int) – The number of agents to create in this collection.
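The agent_type argument only has to satisfy the AgentType protocol structurally. The sketch below assumes an agent type exposes init_state and update methods (both named on this page), but the argument order shown, the AgentTypeLike protocol, and the RandomWalker class are hypothetical illustrations rather than jaxabm's API. A plain dict stands in for ModelConfig.

```python
from typing import Any, Dict, Protocol

import jax

AgentState = Dict[str, Any]

class AgentTypeLike(Protocol):
    """Hypothetical stand-in for the AgentType protocol; signatures are assumed."""

    def init_state(self, model_config: Any, key: Any) -> AgentState: ...

    def update(self, state: AgentState, model_state: Dict[str, Any],
               model_config: Any, key: Any) -> AgentState: ...

class RandomWalker:
    """Hypothetical toy agent type: each agent stores a single scalar position."""

    def init_state(self, model_config, key):
        # Start the agent at a random position drawn from N(0, 1).
        return {"position": jax.random.normal(key)}

    def update(self, state, model_state, model_config, key):
        # Gaussian step scaled by a hypothetical "step_size" config entry;
        # returns only the updated state, as the update() docs require.
        step = model_config["step_size"] * jax.random.normal(key)
        return {"position": state["position"] + step}

walker: AgentTypeLike = RandomWalker()  # structural typing: no inheritance needed
config = {"step_size": 0.1}             # plain dict standing in for ModelConfig
state = walker.init_state(config, jax.random.PRNGKey(0))
state = walker.update(state, {}, config, jax.random.PRNGKey(1))
```

Because typing.Protocol is enforced by static type checkers rather than at runtime, RandomWalker does not inherit from anything; it only needs methods with matching names and signatures.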

Methods

__init__(agent_type, num_agents)

Initialize agent collection placeholder.

aggregate(variable[, fn])

Aggregate a state variable across agents.

filter(condition)

Filter agents based on a condition.

get_states()

Get agent states (alias for the states property, kept for backward compatibility).

init(key, model_config)

Initialize agent states.

update(model_state, key, model_config)

Update all agents in the collection using JAX vmap.

Attributes

states

Get agent states.

init(key, model_config)

Initialize agent states.

This method initializes the states of all agents in the collection using the agent type’s init_state method. It is typically called by Model.initialize().

Parameters:
  • key (Any) – JAX PRNG key used for stochastic initialization.

  • model_config (ModelConfig) – Model configuration settings passed from the Model.

Return type:

None
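Initializing a whole collection typically means giving each agent its own random key and vectorizing a per-agent initializer. The following is a minimal sketch of that pattern, assuming a hypothetical per-agent init_state(model_config, key) function and a dict in place of ModelConfig; it is not jaxabm's implementation of init().

```python
import jax

def init_state(model_config, key):
    # Hypothetical per-agent initializer: one uniform wealth draw per agent.
    return {"wealth": jax.random.uniform(key, minval=0.0,
                                         maxval=model_config["max_wealth"])}

def init_collection(key, model_config, num_agents):
    # Split the collection key into one independent key per agent.
    keys = jax.random.split(key, num_agents)
    # vmap over the keys; model_config is closed over and shared by all agents.
    return jax.vmap(lambda k: init_state(model_config, k))(keys)

states = init_collection(jax.random.PRNGKey(0), {"max_wealth": 10.0}, num_agents=5)
print(states["wealth"].shape)  # (5,)
```

Splitting the key rather than reusing it matters: if every agent received the same key, every agent would draw the same initial wealth.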

update(model_state, key, model_config)

Update all agents in the collection using JAX vmap.

This method updates the internal states (self._states) of all agents by applying the agent type’s update method, vectorized with jax.vmap. It assumes that the agent type adheres to the AgentType protocol and that its update method returns only the updated AgentState.

Parameters:
  • model_state (Dict[str, Any]) – Current model state (the environment state and the states of other agents)

  • key (Any) – JAX PRNG key for stochastic updates

  • model_config (ModelConfig) – Model configuration settings

Return type:

None
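The vectorized update described above can be sketched with jax.vmap: agent states and per-agent keys are mapped along axis 0, while the shared model state is broadcast to every agent with in_axes=None. The per-agent rule update(state, model_state, key) and the env/mean_price layout of model_state are hypothetical; model_config is left out for brevity and could be passed with in_axes=None as well.

```python
import jax
import jax.numpy as jnp

def update(state, model_state, key):
    # Hypothetical per-agent rule: drift halfway toward the environment's
    # mean price, plus a small Gaussian shock.
    noise = 0.1 * jax.random.normal(key)
    target = model_state["env"]["mean_price"]
    return {"price": state["price"] + 0.5 * (target - state["price"]) + noise}

def update_collection(states, model_state, key):
    num_agents = states["price"].shape[0]
    keys = jax.random.split(key, num_agents)  # one key per agent
    # Map over agent states (axis 0) and keys (axis 0);
    # broadcast model_state unchanged to every agent (None).
    return jax.vmap(update, in_axes=(0, None, 0))(states, model_state, keys)

states = {"price": jnp.array([1.0, 2.0, 3.0])}
model_state = {"env": {"mean_price": jnp.array(2.0)}}
new_states = update_collection(states, model_state, jax.random.PRNGKey(0))
print(new_states["price"].shape)  # (3,)
```

Writing the rule for a single agent and letting vmap add the batch dimension keeps the per-agent logic free of any num_agents bookkeeping.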

get_states()

Get agent states (alias for the states property, kept for backward compatibility).

Return type:

Dict[str, Any]

Returns:

Dictionary of agent state variables

property states: Dict[str, Array]

Get agent states.

Returns:

Dictionary of agent state variables
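The relationship between the states property and get_states() can be illustrated with a small pure-Python sketch. CollectionSketch is hypothetical; it assumes states simply returns the internal dictionary and that get_states() delegates to it, which is what a backward-compatibility alias would normally do.

```python
from typing import Any, Dict

class CollectionSketch:
    """Hypothetical minimal class illustrating the states / get_states() alias."""

    def __init__(self, states: Dict[str, Any]) -> None:
        self._states = states

    @property
    def states(self) -> Dict[str, Any]:
        # No setter is defined, so assigning to obj.states raises AttributeError.
        return self._states

    def get_states(self) -> Dict[str, Any]:
        # Backward-compatible alias: delegates to the property.
        return self.states

c = CollectionSketch({"wealth": [1.0, 2.0]})
print(c.get_states() is c.states)  # True
```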

aggregate(variable, fn=<function mean>)

Aggregate a state variable across agents.

Parameters:
  • variable (str) – Name of the state variable to aggregate

  • fn (Callable) – Aggregation function (default: mean)

Return type:

Any

Returns:

Aggregated value
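A minimal sketch of aggregation, assuming aggregate applies fn to the full (num_agents, …) array of the named variable and that the default mean is jax.numpy.mean. The standalone aggregate function and the wealth variable are hypothetical.

```python
import jax.numpy as jnp

def aggregate(states, variable, fn=jnp.mean):
    # Hypothetical stand-in: apply fn to the whole array for one variable.
    return fn(states[variable])

states = {"wealth": jnp.array([1.0, 2.0, 3.0, 6.0])}
print(aggregate(states, "wealth"))           # 3.0
print(aggregate(states, "wealth", jnp.sum))  # 12.0
```

Under this assumption fn reduces over every axis by default, so a per-component summary of a vector-valued variable would need something like lambda x: jnp.mean(x, axis=0).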

filter(condition)

Filter agents based on a condition.

This method returns a new agent collection containing only the agents that satisfy the given condition.

Parameters:
  • condition (Callable[[Dict[str, Any]], Array]) – Function that takes the agent states and returns a boolean array

Return type:

AgentCollection

Returns:

New agent collection containing only the agents that satisfy condition
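One way filtering could work is to evaluate condition on the whole state dictionary to obtain a boolean mask over agents, then keep the matching rows of every variable. The sketch below assumes the mask has shape (num_agents,); filter_states and the example variables are hypothetical, and the result is a plain dictionary rather than a new AgentCollection.

```python
import jax
import jax.numpy as jnp

def filter_states(states, condition):
    # condition maps the full states dict to a boolean mask over agents.
    mask = condition(states)
    # Keep the rows (agents) where mask is True, in every state variable.
    return jax.tree_util.tree_map(lambda x: x[mask], states)

states = {
    "wealth": jnp.array([1.0, 5.0, 2.0, 8.0]),
    "position": jnp.arange(8.0).reshape(4, 2),
}
rich = filter_states(states, lambda s: s["wealth"] > 3.0)
print(rich["wealth"])          # [5. 8.]
print(rich["position"].shape)  # (2, 2)
```

Boolean masking produces arrays whose length depends on the data, so this step cannot be traced inside jax.jit; a jit-compatible design would keep all num_agents rows and carry the mask alongside them.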