"""
Analysis module for JAX-based agent-based modeling.
This module provides tools for analyzing and calibrating agent-based models
built with the jaxabm framework, including sensitivity analysis and
parameter optimization techniques that leverage JAX's capabilities.
"""
from typing import Any, Dict, List, Optional, Tuple, Callable, Union, TypeVar
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random
# Import Model from .model and ModelConfig from .core
from .model import Model
from .core import ModelConfig
# Type variables for better type annotations
ModelFactory = TypeVar('ModelFactory', bound=Callable[..., Model])
PRNGKey = jax.Array
[docs]
class SensitivityAnalysis:
"""Perform sensitivity analysis on model parameters.
This class provides tools for analyzing how changes in model parameters
affect model outputs, using efficient sampling techniques and
sensitivity indices calculation.
Attributes:
model_factory: Function to create model instances
param_ranges: Dictionary mapping parameter names to (min, max) ranges
metrics_of_interest: List of metric names to analyze
num_samples: Number of parameter samples to generate
key: JAX random key
samples: Generated parameter samples
results: Analysis results (populated after run())
"""
[docs]
def __init__(
self,
model_factory: ModelFactory,
param_ranges: Dict[str, Tuple[float, float]],
metrics_of_interest: List[str],
num_samples: int = 100,
seed: int = 0
):
"""Initialize sensitivity analysis.
Args:
model_factory: Function to create model instances
param_ranges: Dictionary mapping parameter names to (min, max) ranges
metrics_of_interest: List of metric names to analyze
num_samples: Number of parameter samples to generate
seed: Random seed
"""
self.model_factory = model_factory
self.param_ranges = param_ranges
self.metrics_of_interest = metrics_of_interest
self.num_samples = num_samples
self.key = random.PRNGKey(seed)
# Generate samples using Latin Hypercube Sampling
self.samples = self._generate_lhs_samples()
self.results = None
def _generate_lhs_samples(self) -> jax.Array:
"""Generate Latin Hypercube Samples for parameters.
Latin Hypercube Sampling ensures better coverage of the parameter space
than simple random sampling.
Returns:
Array of shape (num_samples, num_parameters) with sampled parameter values
"""
self.key, subkey = random.split(self.key)
# Create normalized LHS samples (0-1)
n_params = len(self.param_ranges)
points = jnp.linspace(0, 1, self.num_samples + 1)[:-1] # n points in [0, 1)
points = points + random.uniform(subkey, (self.num_samples,)) / self.num_samples # Add jitter
# Create a permutation of these points for each parameter
samples = jnp.zeros((self.num_samples, n_params))
for i, param in enumerate(self.param_ranges):
self.key, subkey = random.split(self.key)
perm = random.permutation(subkey, points)
samples = samples.at[:, i].set(perm)
# Scale samples to parameter ranges
for i, (param, (min_val, max_val)) in enumerate(self.param_ranges.items()):
samples = samples.at[:, i].multiply(max_val - min_val)
samples = samples.at[:, i].add(min_val)
return samples
[docs]
def run(self, verbose: bool = True) -> Dict[str, jax.Array]:
"""Run sensitivity analysis.
Args:
verbose: Whether to print progress information
Returns:
Dictionary mapping metric names to arrays of results
"""
if verbose:
print(f"Running sensitivity analysis with {self.num_samples} samples...")
param_names = list(self.param_ranges.keys())
metrics_results = {metric: jnp.zeros(self.num_samples) for metric in self.metrics_of_interest}
# Run model for each parameter sample
for i in range(self.num_samples):
if verbose:
print(f"\nSample {i+1}/{self.num_samples}")
# Construct parameter dictionary for this sample
params = {param: float(self.samples[i, j]) for j, param in enumerate(param_names)}
if verbose:
print(f"Parameters: {', '.join([f'{k}={v:.4f}' for k, v in params.items()])}")
# Create and run model using the factory
# Pass parameters and create a config with the specific seed
seed_value = i + 1000 # Use the sample index for reproducibility
config = ModelConfig(seed=seed_value)
# Assuming model_factory signature is factory(params=..., config=...)
# The factory itself needs to handle adding agents/state.
model = self.model_factory(params=params, config=config)
if verbose:
print("Running model...")
# model.run() now handles initialization internally
results = model.run()
# Extract metrics of interest
if verbose:
print("Results:")
# Handle both dictionary and Results objects
if hasattr(results, '_data'):
results_dict = results._data
else:
results_dict = results
for metric in self.metrics_of_interest:
if metric in results_dict and results_dict[metric] is not None:
# Handle both scalar values and arrays/lists
metric_value = results_dict[metric]
if hasattr(metric_value, '__len__') and not isinstance(metric_value, str):
# It's an array or list, take the last value
value = metric_value[-1]
else:
# It's a scalar value
value = metric_value
metrics_results[metric] = metrics_results[metric].at[i].set(value)
if verbose:
print(f" {metric}: {float(value):.4f}")
self.results = metrics_results
if verbose:
print("\nSensitivity analysis complete!")
return metrics_results
[docs]
def sobol_indices(self) -> Dict[str, Dict[str, float]]:
"""Calculate sensitivity indices for each parameter and metric.
This is a simplified implementation that calculates correlation-based
indices as a proxy for Sobol indices. For a full Sobol analysis,
specialized sampling would be required.
Returns:
Dictionary mapping metric names to dictionaries of parameter name -> sensitivity index
"""
# NOTE: This method calculates squared correlation coefficients as a simplified
# proxy for true Sobol indices. A full Sobol analysis would require
# different sampling techniques (e.g., Saltelli sampling).
if self.results is None:
raise ValueError("Must run sensitivity analysis before calculating indices")
param_names = list(self.param_ranges.keys())
indices = {}
for metric, values in self.results.items():
# Normalize the metric values
values_norm = (values - jnp.mean(values)) / (jnp.std(values) + 1e-8)
# Calculate correlation coefficients as a simple sensitivity measure
metric_indices = {}
for i, param in enumerate(param_names):
# Use correlation coefficient as a simple proxy for sensitivity
param_values = self.samples[:, i]
param_values_norm = (param_values - jnp.mean(param_values)) / (jnp.std(param_values) + 1e-8)
# Calculate correlation coefficient
corr = jnp.mean(param_values_norm * values_norm)
metric_indices[param] = float(corr ** 2) # Square to get something like an R² value
indices[metric] = metric_indices
return indices
[docs]
def plot(self, metric=None, ax=None, **kwargs):
"""Plot sensitivity analysis results.
Args:
metric: Metric to plot. If None, plot sobol indices for all metrics.
ax: Matplotlib axis to use for plotting.
**kwargs: Additional keyword arguments to pass to plotting function.
Returns:
Matplotlib axis.
"""
if ax is None:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
# Get sobol indices
indices = self.sobol_indices()
# Choose metric to plot
if metric is None and self.metrics_of_interest:
metric = self.metrics_of_interest[0]
if metric in indices:
# Get indices for this metric
metric_indices = indices[metric]
# Sort indices by value
sorted_indices = sorted(metric_indices.items(), key=lambda x: x[1], reverse=True)
# Plot bar chart
params = [p for p, _ in sorted_indices]
values = [v for _, v in sorted_indices]
ax.bar(params, values, **kwargs)
ax.set_xlabel('Parameter')
ax.set_ylabel('Sensitivity Index')
ax.set_title(f'Sensitivity Indices for {metric}')
# Rotate x-labels if there are many parameters
if len(params) > 3:
import matplotlib.pyplot as plt
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
return ax
[docs]
def plot_indices(self, figsize: Tuple[int, int] = (10, 6)) -> Any:
"""Plot the sensitivity indices.
Args:
figsize: Figure size as (width, height)
Returns:
Matplotlib figure and axes
"""
try:
import matplotlib.pyplot as plt
import numpy as np
except ImportError:
raise ImportError("Matplotlib is required for plotting. Install it with 'pip install matplotlib'")
indices = self.sobol_indices()
metrics = list(indices.keys())
params = list(indices[metrics[0]].keys())
fig, ax = plt.subplots(figsize=figsize)
x = np.arange(len(metrics))
width = 0.8 / len(params)
for i, param in enumerate(params):
param_values = [indices[metric][param] for metric in metrics]
offset = width * i - width * len(params) / 2 + width / 2
ax.bar(x + offset, param_values, width, label=param)
ax.set_xlabel('Metrics')
ax.set_ylabel('Sensitivity Index')
ax.set_title('Parameter Sensitivity Analysis')
ax.set_xticks(x)
ax.set_xticklabels(metrics)
ax.legend(loc='best')
plt.tight_layout()
return fig, ax
[docs]
class ModelCalibrator:
"""Calibrate model parameters using advanced optimization techniques.
This class provides methods for automatically tuning model parameters
to achieve desired outputs, using gradient-based optimization with Adam,
or various reinforcement learning and evolutionary approaches.
Attributes:
model_factory: Function to create model instances
params: Current parameter values
target_metrics: Target values for each metric
metrics_weights: Importance weights for each metric in the loss function
learning_rate: Learning rate for optimization
max_iterations: Maximum number of optimization iterations
method: Calibration method
loss_type: Type of loss function to use
param_bounds: Parameter bounds for each parameter
evaluation_steps: Number of steps to run model for evaluation
num_evaluation_runs: Number of runs to average for robust evaluation
loss_history: History of loss values during calibration
param_history: History of parameter values during calibration
confidence_intervals: Confidence intervals for metrics
"""
[docs]
def __init__(
self,
model_factory: ModelFactory,
initial_params: Dict[str, float],
target_metrics: Dict[str, float],
param_bounds: Optional[Dict[str, Tuple[float, float]]] = None,
metrics_weights: Optional[Dict[str, float]] = None,
learning_rate: float = 0.01,
max_iterations: int = 100,
method: str = "adam",
loss_type: str = "mse",
evaluation_steps: int = 50,
num_evaluation_runs: int = 3,
tolerance: float = 1e-6,
patience: int = 10,
seed: int = 0
):
"""Initialize model calibrator.
Args:
model_factory: Function to create model instances
initial_params: Initial parameter values
target_metrics: Target metric values
param_bounds: Bounds for each parameter as (min, max) tuples
metrics_weights: Weights for each metric in the loss function
learning_rate: Learning rate for optimization
max_iterations: Maximum number of optimization iterations
method: Calibration method ("adam", "sgd", "es", "pso", "cem", "bayesian")
loss_type: Loss function type ("mse", "mae", "huber", "relative")
evaluation_steps: Number of simulation steps for evaluation
num_evaluation_runs: Number of runs to average for robust evaluation
tolerance: Convergence tolerance
patience: Early stopping patience
seed: Random seed
"""
self.model_factory = model_factory
self.params = initial_params.copy()
self.target_metrics = target_metrics
self.param_bounds = param_bounds or {k: (0.01, 10.0) for k in initial_params}
self.metrics_weights = metrics_weights or {k: 1.0 for k in target_metrics}
self.learning_rate = learning_rate
self.max_iterations = max_iterations
self.method = method
self.loss_type = loss_type
self.evaluation_steps = evaluation_steps
self.num_evaluation_runs = num_evaluation_runs
self.tolerance = tolerance
self.patience = patience
# Initialize random key
self.key = random.PRNGKey(seed)
# History tracking
self.loss_history = []
self.param_history = []
self.confidence_intervals = []
self.best_params = initial_params.copy()
self.best_loss = float('inf')
# Method-specific initialization
self._setup_optimization_method()
def _setup_optimization_method(self):
"""Set up the optimization method."""
if self.method in ["adam", "sgd"]:
self._setup_gradient_optimization()
elif self.method == "es":
self._setup_evolution_strategies()
elif self.method == "pso":
self._setup_particle_swarm()
elif self.method == "cem":
self._setup_cross_entropy()
elif self.method == "bayesian":
self._setup_bayesian_optimization()
elif self.method == "q_learning":
self._setup_q_learning()
elif self.method == "policy_gradient":
self._setup_policy_gradient()
elif self.method == "actor_critic":
self._setup_actor_critic()
elif self.method == "multi_agent_rl":
self._setup_multi_agent_rl()
elif self.method == "dqn":
self._setup_deep_q_network()
else:
raise ValueError(f"Unknown calibration method: {self.method}")
def _compute_loss(self, metrics: Dict[str, float]) -> float:
"""Compute loss based on the specified loss type."""
loss = 0.0
for metric, target in self.target_metrics.items():
if metric not in metrics:
continue
value = metrics[metric]
weight = self.metrics_weights[metric]
if self.loss_type == "mse":
metric_loss = (value - target) ** 2
elif self.loss_type == "mae":
metric_loss = abs(value - target)
elif self.loss_type == "huber":
delta = 1.0
residual = abs(value - target)
# Use JAX's where function instead of if/else for JIT compatibility
metric_loss = jnp.where(
residual <= delta,
0.5 * residual ** 2,
delta * (residual - 0.5 * delta)
)
elif self.loss_type == "relative":
metric_loss = abs(value - target) / (abs(target) + 1e-8)
else:
raise ValueError(f"Unknown loss type: {self.loss_type}")
loss += weight * metric_loss
return loss
def _evaluate_params_robust(self, params: Dict[str, float]) -> Tuple[float, Dict[str, Tuple[float, float]]]:
"""Evaluate parameters with multiple runs for robustness."""
all_metrics = {metric: [] for metric in self.target_metrics}
for run in range(self.num_evaluation_runs):
# Use different seeds for each run
self.key, subkey = random.split(self.key)
seed_value = random.randint(subkey, (), 0, 1_000_000)
config = ModelConfig(seed=seed_value.item())
model = self.model_factory(params=params, config=config)
results = model.run(steps=self.evaluation_steps)
for metric in self.target_metrics:
if metric in results:
# Handle both JAX arrays and lists
if hasattr(results[metric], '__len__') and len(results[metric]) > 0:
all_metrics[metric].append(float(results[metric][-1]))
else:
all_metrics[metric].append(0.0)
else:
all_metrics[metric].append(0.0)
# Compute mean metrics and confidence intervals
mean_metrics = {}
confidence_intervals = {}
for metric, values in all_metrics.items():
values_array = jnp.array(values)
mean_val = float(jnp.mean(values_array))
std_val = float(jnp.std(values_array))
mean_metrics[metric] = mean_val
# 95% confidence interval
ci_half_width = 1.96 * std_val / jnp.sqrt(len(values))
confidence_intervals[metric] = (
mean_val - ci_half_width,
mean_val + ci_half_width
)
# Use normalized loss for better RL performance
loss = self._compute_normalized_loss(mean_metrics)
return loss, confidence_intervals
def _compute_normalized_loss(self, metrics: Dict[str, float]) -> float:
"""Compute normalized loss for better RL optimization (key improvement!)."""
total_loss = 0.0
for metric, target in self.target_metrics.items():
if metric in metrics:
value = metrics[metric]
# Normalize by target to make losses comparable and stable
normalized_error = abs(value - target) / (abs(target) + 1e-8)
total_loss += normalized_error ** 2
return float(total_loss)
def _setup_gradient_optimization(self):
"""Set up gradient-based optimization with Adam or SGD."""
param_names = list(self.params.keys())
# Use a fixed seed for gradient computation to avoid tracer issues
def loss_fn(params_flat):
# Convert flat parameters to dictionary
params = {name: params_flat[i] for i, name in enumerate(param_names)}
# Use a fixed seed for gradient computation (deterministic)
config = ModelConfig(seed=42)
model = self.model_factory(params=params, config=config)
results = model.run(steps=self.evaluation_steps)
# Handle both JAX arrays and lists
metrics = {}
for metric in self.target_metrics:
if metric in results:
if hasattr(results[metric], '__len__') and len(results[metric]) > 0:
# Take the last value, handling both JAX arrays and lists
if hasattr(results[metric], 'at'): # JAX array
metrics[metric] = results[metric][-1]
else: # Python list
metrics[metric] = results[metric][-1]
else:
metrics[metric] = 0.0
else:
metrics[metric] = 0.0
return self._compute_loss(metrics)
self.loss_fn = loss_fn
self.grad_fn = jit(grad(loss_fn))
if self.method == "adam":
# Adam optimizer state
self.adam_m = jnp.zeros(len(param_names)) # First moment
self.adam_v = jnp.zeros(len(param_names)) # Second moment
self.adam_beta1 = 0.9
self.adam_beta2 = 0.999
self.adam_eps = 1e-8
self.adam_t = 0 # Time step
def _setup_evolution_strategies(self):
"""Set up Evolution Strategies (ES) optimization."""
self.es_population_size = 20
self.es_sigma = 0.1 # Mutation strength
self.es_elite_ratio = 0.2 # Fraction of population to keep as elite
# Initialize population
param_names = list(self.params.keys())
n_params = len(param_names)
self.key, subkey = random.split(self.key)
self.es_population = random.normal(subkey, (self.es_population_size, n_params)) * self.es_sigma
# Center population around initial parameters
initial_flat = jnp.array([self.params[name] for name in param_names])
self.es_population = self.es_population + initial_flat[None, :]
# Clip to bounds
for i, (param, (min_val, max_val)) in enumerate(self.param_bounds.items()):
self.es_population = self.es_population.at[:, i].set(
jnp.clip(self.es_population[:, i], min_val, max_val)
)
def _setup_particle_swarm(self):
"""Set up Particle Swarm Optimization (PSO)."""
self.pso_population_size = 20
self.pso_w = 0.7 # Inertia weight
self.pso_c1 = 1.5 # Cognitive parameter
self.pso_c2 = 1.5 # Social parameter
param_names = list(self.params.keys())
n_params = len(param_names)
# Initialize particles
self.key, subkey = random.split(self.key)
self.pso_positions = random.uniform(subkey, (self.pso_population_size, n_params))
# Scale to parameter bounds
for i, (param, (min_val, max_val)) in enumerate(self.param_bounds.items()):
self.pso_positions = self.pso_positions.at[:, i].multiply(max_val - min_val)
self.pso_positions = self.pso_positions.at[:, i].add(min_val)
# Initialize velocities
self.key, subkey = random.split(self.key)
velocity_scale = 0.1
self.pso_velocities = random.normal(subkey, (self.pso_population_size, n_params)) * velocity_scale
# Personal and global best
self.pso_personal_best = self.pso_positions.copy()
self.pso_personal_best_scores = jnp.full(self.pso_population_size, float('inf'))
self.pso_global_best = self.pso_positions[0].copy()
self.pso_global_best_score = float('inf')
def _setup_cross_entropy(self):
"""Set up Cross-Entropy Method (CEM)."""
self.cem_population_size = 50
self.cem_elite_ratio = 0.2
self.cem_noise_decay = 0.99
param_names = list(self.params.keys())
n_params = len(param_names)
# Initialize distribution parameters
self.cem_mean = jnp.array([self.params[name] for name in param_names])
self.cem_std = jnp.ones(n_params) * 0.5
def _setup_bayesian_optimization(self):
"""Set up Bayesian Optimization with Gaussian Process."""
# Simple implementation - in practice, you'd use a library like GPyOpt
self.bo_n_initial = 10
self.bo_acquisition = "ei" # Expected Improvement
param_names = list(self.params.keys())
n_params = len(param_names)
# Generate initial samples
self.key, subkey = random.split(self.key)
self.bo_X = random.uniform(subkey, (self.bo_n_initial, n_params))
# Scale to parameter bounds
for i, (param, (min_val, max_val)) in enumerate(self.param_bounds.items()):
self.bo_X = self.bo_X.at[:, i].multiply(max_val - min_val)
self.bo_X = self.bo_X.at[:, i].add(min_val)
self.bo_y = jnp.full(self.bo_n_initial, float('inf'))
self.bo_evaluated = 0
def _setup_q_learning(self):
"""Set up Improved Q-Learning with continuous action space and better state representation."""
self.ql_learning_rate = 0.001
self.ql_epsilon = 0.2 # Lower initial exploration
self.ql_epsilon_decay = 0.995
self.ql_epsilon_min = 0.02 # Lower minimum exploration
self.ql_gamma = 0.95
self.ql_batch_size = 64 # Larger batch size
self.ql_memory_size = 2000
self.ql_target_update_freq = 50 # Target network updates
param_names = list(self.params.keys())
n_params = len(param_names)
# Enhanced state size: current params + targets + history + gradients
state_size = n_params * 4 # current, normalized, targets, gradients
n_actions = n_params * 5 # 5 different step sizes per parameter
# Improved neural network architecture
self.key, subkey1, subkey2, subkey3, subkey4 = random.split(self.key, 5)
hidden1_size = 256
hidden2_size = 128
hidden3_size = 64
# Main Q-network
self.ql_params = {
'layer1': {
'weights': random.normal(subkey1, (state_size, hidden1_size)) * 0.1,
'bias': jnp.zeros(hidden1_size)
},
'layer2': {
'weights': random.normal(subkey2, (hidden1_size, hidden2_size)) * 0.1,
'bias': jnp.zeros(hidden2_size)
},
'layer3': {
'weights': random.normal(subkey3, (hidden2_size, hidden3_size)) * 0.1,
'bias': jnp.zeros(hidden3_size)
},
'output': {
'weights': random.normal(subkey4, (hidden3_size, n_actions)) * 0.1,
'bias': jnp.zeros(n_actions)
}
}
# Target network (copy of main network)
self.ql_target_params = {k: {kk: vv.copy() for kk, vv in v.items()}
for k, v in self.ql_params.items()}
# Enhanced experience replay with prioritization
self.ql_memory = []
self.ql_param_names = param_names
self.ql_step_sizes = [0.01, 0.03, 0.05, 0.1, 0.2] # Multiple step sizes
# State history for enhanced representation
self.ql_param_history = []
self.ql_loss_history = []
self.ql_gradient_estimates = jnp.zeros(n_params)
# Parameter space normalization
self.ql_param_mins = jnp.array([self.param_bounds[p][0] for p in param_names])
self.ql_param_maxs = jnp.array([self.param_bounds[p][1] for p in param_names])
# Adam optimizer for neural network training
self.ql_adam_m = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.ql_params.items()}
self.ql_adam_v = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.ql_params.items()}
self.ql_adam_t = 0
def _setup_policy_gradient(self):
"""Set up Enhanced Policy Gradient with continuous actions and better architecture."""
self.pg_learning_rate = 0.005
self.pg_baseline_decay = 0.9
self.pg_min_std = 0.05
self.pg_entropy_coeff = 0.02
self.pg_value_coeff = 0.1 # For baseline learning
param_names = list(self.params.keys())
n_params = len(param_names)
# Enhanced state representation
state_size = n_params * 4 # current, normalized, targets, history
# Policy network (actor) - outputs continuous actions
self.key, subkey1, subkey2, subkey3 = random.split(self.key, 4)
hidden_size = 128
self.pg_policy_params = {
'shared_layer1': {
'weights': random.normal(subkey1, (state_size, hidden_size)) * 0.1,
'bias': jnp.zeros(hidden_size)
},
'shared_layer2': {
'weights': random.normal(subkey2, (hidden_size, 64)) * 0.1,
'bias': jnp.zeros(64)
},
'mean_output': {
'weights': random.normal(subkey3, (64, n_params)) * 0.1,
'bias': jnp.zeros(n_params)
},
'std_output': {
'weights': random.normal(subkey3, (64, n_params)) * 0.1,
'bias': jnp.ones(n_params) * jnp.log(0.2) # Initialize to reasonable std
}
}
# Value network (critic) for baseline
self.key, subkey4, subkey5 = random.split(self.key, 3)
self.pg_value_params = {
'layer1': {
'weights': random.normal(subkey4, (state_size, hidden_size)) * 0.1,
'bias': jnp.zeros(hidden_size)
},
'layer2': {
'weights': random.normal(subkey5, (hidden_size, 1)) * 0.1,
'bias': jnp.array([0.0])
}
}
self.pg_param_names = param_names
# Enhanced state tracking
self.pg_state_history = []
self.pg_loss_history = []
# Adam optimizers for both networks
self.pg_policy_adam_m = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.pg_policy_params.items()}
self.pg_policy_adam_v = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.pg_policy_params.items()}
self.pg_value_adam_m = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.pg_value_params.items()}
self.pg_value_adam_v = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.pg_value_params.items()}
self.pg_adam_t = 0
def _setup_actor_critic(self):
"""Set up Advanced Actor-Critic with proper network architecture and training."""
self.ac_actor_lr = 0.003
self.ac_critic_lr = 0.005
self.ac_gamma = 0.95
self.ac_lambda = 0.95 # GAE lambda
self.ac_gradient_clip_norm = 1.0
self.ac_entropy_coeff = 0.02
param_names = list(self.params.keys())
n_params = len(param_names)
state_size = n_params * 4
# Advanced actor network architecture
self.key, *subkeys = random.split(self.key, 9)
hidden_size = 128
self.ac_actor_params = {
'layer1': {
'weights': random.normal(subkeys[0], (state_size, hidden_size)) * 0.1,
'bias': jnp.zeros(hidden_size)
},
'layer2': {
'weights': random.normal(subkeys[1], (hidden_size, 64)) * 0.1,
'bias': jnp.zeros(64)
},
'mean_output': {
'weights': random.normal(subkeys[2], (64, n_params)) * 0.1,
'bias': jnp.zeros(n_params)
},
'std_output': {
'weights': random.normal(subkeys[3], (64, n_params)) * 0.1,
'bias': jnp.ones(n_params) * jnp.log(0.2)
}
}
# Advanced critic network
self.ac_critic_params = {
'layer1': {
'weights': random.normal(subkeys[4], (state_size, hidden_size)) * 0.1,
'bias': jnp.zeros(hidden_size)
},
'layer2': {
'weights': random.normal(subkeys[5], (hidden_size, 64)) * 0.1,
'bias': jnp.zeros(64)
},
'layer3': {
'weights': random.normal(subkeys[6], (64, 32)) * 0.1,
'bias': jnp.zeros(32)
},
'output': {
'weights': random.normal(subkeys[7], (32, 1)) * 0.1,
'bias': jnp.array([0.0])
}
}
self.ac_param_names = param_names
# GAE (Generalized Advantage Estimation) buffers
self.ac_states = []
self.ac_actions = []
self.ac_rewards = []
self.ac_values = []
self.ac_log_probs = []
self.ac_dones = []
# Adam optimizers
self.ac_actor_adam_m = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.ac_actor_params.items()}
self.ac_actor_adam_v = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.ac_actor_params.items()}
self.ac_critic_adam_m = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.ac_critic_params.items()}
self.ac_critic_adam_v = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.ac_critic_params.items()}
self.ac_adam_t = 0
def _setup_deep_q_network(self):
"""Set up Advanced DQN with Double DQN, Dueling Architecture, and Prioritized Experience Replay."""
self.dqn_learning_rate = 0.0005
self.dqn_epsilon = 0.3
self.dqn_epsilon_decay = 0.995
self.dqn_epsilon_min = 0.05
self.dqn_gamma = 0.95
self.dqn_batch_size = 64
self.dqn_memory_size = 3000
self.dqn_target_update_freq = 100
self.dqn_double_dqn = True # Use Double DQN
param_names = list(self.params.keys())
n_params = len(param_names)
state_size = n_params * 4
n_actions = n_params * 7 # More granular actions
# Dueling DQN architecture
self.key, *subkeys = random.split(self.key, 10)
hidden_size = 256
# Shared layers
self.dqn_params = {
'shared1': {
'weights': random.normal(subkeys[0], (state_size, hidden_size)) * 0.1,
'bias': jnp.zeros(hidden_size)
},
'shared2': {
'weights': random.normal(subkeys[1], (hidden_size, 128)) * 0.1,
'bias': jnp.zeros(128)
},
# Value stream
'value1': {
'weights': random.normal(subkeys[2], (128, 64)) * 0.1,
'bias': jnp.zeros(64)
},
'value_output': {
'weights': random.normal(subkeys[3], (64, 1)) * 0.1,
'bias': jnp.array([0.0])
},
# Advantage stream
'advantage1': {
'weights': random.normal(subkeys[4], (128, 64)) * 0.1,
'bias': jnp.zeros(64)
},
'advantage_output': {
'weights': random.normal(subkeys[5], (64, n_actions)) * 0.1,
'bias': jnp.zeros(n_actions)
}
}
# Target network
self.dqn_target_params = {k: {kk: vv.copy() for kk, vv in v.items()}
for k, v in self.dqn_params.items()}
# Prioritized experience replay
self.dqn_memory = []
self.dqn_priorities = []
self.dqn_alpha = 0.6 # Prioritization exponent
self.dqn_beta = 0.4 # Importance sampling exponent
self.dqn_beta_increment = 0.001
self.dqn_param_names = param_names
self.dqn_step_sizes = [0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]
# Action history for diversity regularization
self.dqn_action_history = []
self.dqn_action_history_size = 10
self.dqn_action_regularization = 0.1
# Adam optimizer
self.dqn_adam_m = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.dqn_params.items()}
self.dqn_adam_v = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.dqn_params.items()}
self.dqn_adam_t = 0
def _setup_multi_agent_rl(self):
"""Set up Advanced Multi-Agent RL with coordinated exploration and communication."""
self.marl_learning_rate = 0.002
self.marl_epsilon = 0.25
self.marl_epsilon_decay = 0.995
self.marl_epsilon_min = 0.05
self.marl_communication_dim = 16 # Communication vector size
param_names = list(self.params.keys())
n_agents = len(param_names)
# Each agent has enhanced architecture with communication
self.marl_agents = {}
for i, param in enumerate(param_names):
self.key, *subkeys = random.split(self.key, 8)
# Enhanced state: own param + global state + communication from others
local_state_size = 1 + n_agents * 3 + (n_agents - 1) * self.marl_communication_dim
self.marl_agents[param] = {
# Q-network with communication
'q_params': {
'layer1': {
'weights': random.normal(subkeys[0], (local_state_size, 128)) * 0.1,
'bias': jnp.zeros(128)
},
'layer2': {
'weights': random.normal(subkeys[1], (128, 64)) * 0.1,
'bias': jnp.zeros(64)
},
'q_output': {
'weights': random.normal(subkeys[2], (64, 5)) * 0.1, # 5 actions per agent
'bias': jnp.zeros(5)
}
},
# Communication network
'comm_params': {
'layer1': {
'weights': random.normal(subkeys[3], (local_state_size, 32)) * 0.1,
'bias': jnp.zeros(32)
},
'comm_output': {
'weights': random.normal(subkeys[4], (32, self.marl_communication_dim)) * 0.1,
'bias': jnp.zeros(self.marl_communication_dim)
}
},
'epsilon': self.marl_epsilon,
'memory': [],
'memory_size': 500,
# Adam optimizers
'q_adam_m': {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.marl_agents.get(param, {}).get('q_params', {}).items()},
'q_adam_v': {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.marl_agents.get(param, {}).get('q_params', {}).items()},
'comm_adam_m': {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.marl_agents.get(param, {}).get('comm_params', {}).items()},
'comm_adam_v': {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.marl_agents.get(param, {}).get('comm_params', {}).items()},
'adam_t': 0
}
# Initialize Adam optimizers properly after agent creation
for param in param_names:
agent = self.marl_agents[param]
agent['q_adam_m'] = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in agent['q_params'].items()}
agent['q_adam_v'] = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in agent['q_params'].items()}
agent['comm_adam_m'] = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in agent['comm_params'].items()}
agent['comm_adam_v'] = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in agent['comm_params'].items()}
self.marl_param_names = param_names
self.marl_step_sizes = [0.01, 0.03, 0.05, 0.1, 0.15] # 5 actions per agent
# Global communication buffer
self.marl_global_comm = jnp.zeros((n_agents, self.marl_communication_dim))
def _params_to_state(self, params: Dict[str, float]) -> jax.Array:
"""Convert parameters to normalized state for neural networks."""
param_values = jnp.array([params[name] for name in self.ql_param_names])
normalized = (param_values - self.ql_param_mins) / (self.ql_param_maxs - self.ql_param_mins)
return jnp.clip(normalized, 0.0, 1.0)
def _find_nearest_bin(self, value: float, bins: jax.Array) -> int:
"""Find the nearest bin index for a value."""
return int(jnp.argmin(jnp.abs(bins - value)))
def _dqn_forward(self, state: jax.Array) -> jax.Array:
"""Forward pass through DQN with deeper architecture."""
# Layer 1
h1 = jnp.tanh(jnp.dot(state, self.dqn_params['layer1']['weights']) +
self.dqn_params['layer1']['bias'])
# Layer 2
h2 = jnp.tanh(jnp.dot(h1, self.dqn_params['layer2']['weights']) +
self.dqn_params['layer2']['bias'])
# Layer 3 (output)
q_values = jnp.dot(h2, self.dqn_params['layer3']['weights']) + \
self.dqn_params['layer3']['bias']
return q_values
def _policy_gradient_sample_action(self, state: jax.Array) -> Tuple[jax.Array, float, float]:
"""Sample action from policy with exploration preservation."""
mean = self.pg_policy_params['mean']
# Enforce minimum standard deviation to prevent policy collapse
log_std = jnp.maximum(self.pg_policy_params['log_std'], jnp.log(self.pg_min_std))
std = jnp.exp(log_std)
# Sample from normal distribution
self.key, subkey = random.split(self.key)
action = mean + std * random.normal(subkey, mean.shape)
# Compute log probability
log_prob = -0.5 * jnp.sum(((action - mean) / std) ** 2) - \
0.5 * jnp.sum(jnp.log(2 * jnp.pi * std ** 2))
# Compute entropy for regularization
entropy = 0.5 * jnp.sum(log_std + jnp.log(2 * jnp.pi * jnp.e))
return action, float(log_prob), float(entropy)
def _actor_critic_forward(self, state: jax.Array) -> Tuple[jax.Array, float, float]:
"""Forward pass for robust actor-critic."""
# Actor (policy) with minimum std constraint
mean = self.ac_actor_params['mean']
log_std = jnp.maximum(self.ac_actor_params['log_std'], jnp.log(0.05)) # Min std = 0.05
std = jnp.exp(log_std)
# Sample action
self.key, subkey = random.split(self.key)
action = mean + std * random.normal(subkey, mean.shape)
# Log probability
log_prob = -0.5 * jnp.sum(((action - mean) / std) ** 2) - \
0.5 * jnp.sum(jnp.log(2 * jnp.pi * std ** 2))
# Robust critic (value function) with hidden layer
h1 = jnp.tanh(jnp.dot(state, self.ac_critic_params['layer1_weights']) +
self.ac_critic_params['layer1_bias'])
value = jnp.dot(h1, self.ac_critic_params['layer2_weights']).squeeze() + \
self.ac_critic_params['layer2_bias'][0]
# Clip value to prevent explosion
value = jnp.clip(value, *self.ac_value_clip_range)
return action, float(log_prob), float(value)
def _ql_forward(self, state: jax.Array) -> jax.Array:
"""Forward pass through Q-learning neural network."""
# Layer 1
h1 = jnp.tanh(jnp.dot(state, self.ql_params['layer1']['weights']) +
self.ql_params['layer1']['bias'])
# Layer 2
h2 = jnp.tanh(jnp.dot(h1, self.ql_params['layer2']['weights']) +
self.ql_params['layer2']['bias'])
# Layer 3 (output)
q_values = jnp.dot(h2, self.ql_params['layer3']['weights']) + \
self.ql_params['layer3']['bias']
return q_values
def _ql_forward_improved(self, state: jax.Array, params: Optional[Dict] = None) -> jax.Array:
"""Improved Q-learning forward pass with optional custom parameters."""
if params is None:
params = self.ql_params
# Enhanced forward pass with deeper network
h1 = jnp.tanh(jnp.dot(state, params['layer1']['weights']) +
params['layer1']['bias'])
h2 = jnp.tanh(jnp.dot(h1, params['layer2']['weights']) +
params['layer2']['bias'])
h3 = jnp.tanh(jnp.dot(h2, params['layer3']['weights']) +
params['layer3']['bias'])
q_values = jnp.dot(h3, params['output']['weights']) + params['output']['bias']
return q_values
def _pg_forward_enhanced(self, state: jax.Array) -> Tuple[jax.Array, float, float, float]:
"""Enhanced policy gradient forward pass returning action, mean, std, value."""
# Policy network (actor) - using correct parameter names
h1_policy = jnp.tanh(jnp.dot(state, self.pg_policy_params['shared_layer1']['weights']) +
self.pg_policy_params['shared_layer1']['bias'])
h2_policy = jnp.tanh(jnp.dot(h1_policy, self.pg_policy_params['shared_layer2']['weights']) +
self.pg_policy_params['shared_layer2']['bias'])
# Mean and log_std for continuous actions with gradient clipping
mean = jnp.dot(h2_policy, self.pg_policy_params['mean_output']['weights']) + self.pg_policy_params['mean_output']['bias']
mean = jnp.clip(mean, -10.0, 10.0) # Clip mean to prevent explosion
log_std = jnp.dot(h2_policy, self.pg_policy_params['std_output']['weights']) + self.pg_policy_params['std_output']['bias']
log_std = jnp.clip(log_std, -5.0, 2.0) # More conservative clipping
std = jnp.exp(log_std)
std = jnp.maximum(std, 0.01) # Minimum std to prevent collapse
# Sample action with safe noise
self.key, subkey = random.split(self.key)
noise = random.normal(subkey, mean.shape)
noise = jnp.clip(noise, -3.0, 3.0) # Clip noise to prevent outliers
action = mean + std * noise
# Value network (critic) with safer computation
h1_value = jnp.tanh(jnp.dot(state, self.pg_value_params['layer1']['weights']) +
self.pg_value_params['layer1']['bias'])
value = jnp.dot(h1_value, self.pg_value_params['layer2']['weights']) + self.pg_value_params['layer2']['bias']
value = jnp.clip(value, -100.0, 100.0) # Clip value to prevent explosion
# Ensure outputs are finite
mean = jnp.where(jnp.isfinite(mean), mean, 0.0)
std = jnp.where(jnp.isfinite(std), std, 0.1)
value = jnp.where(jnp.isfinite(value), value, 0.0)
return action, float(mean[0]), float(std[0]), float(value[0])
def _ac_forward_enhanced(self, state: jax.Array) -> Tuple[jax.Array, float, float, float]:
"""Enhanced actor-critic forward pass returning action, log_prob, value, entropy."""
# Actor network - using correct parameter names
h1_actor = jnp.tanh(jnp.dot(state, self.ac_actor_params['layer1']['weights']) +
self.ac_actor_params['layer1']['bias'])
h2_actor = jnp.tanh(jnp.dot(h1_actor, self.ac_actor_params['layer2']['weights']) +
self.ac_actor_params['layer2']['bias'])
# Action probabilities for discrete actions or mean/std for continuous
mean = jnp.dot(h2_actor, self.ac_actor_params['mean_output']['weights']) + self.ac_actor_params['mean_output']['bias']
log_std = jnp.dot(h2_actor, self.ac_actor_params['std_output']['weights']) + self.ac_actor_params['std_output']['bias']
std = jnp.exp(jnp.clip(log_std, -20, 2))
# Sample action
self.key, subkey = random.split(self.key)
action = mean + std * random.normal(subkey, mean.shape)
# Compute log probability
log_prob = -0.5 * ((action - mean) / std) ** 2 - 0.5 * jnp.log(2 * jnp.pi) - jnp.log(std)
log_prob = jnp.sum(log_prob) # Sum over action dimensions
# Compute entropy
entropy = 0.5 * jnp.log(2 * jnp.pi * jnp.e) + jnp.log(std)
entropy = jnp.sum(entropy)
# Critic network
h1_critic = jnp.tanh(jnp.dot(state, self.ac_critic_params['layer1']['weights']) +
self.ac_critic_params['layer1']['bias'])
h2_critic = jnp.tanh(jnp.dot(h1_critic, self.ac_critic_params['layer2']['weights']) +
self.ac_critic_params['layer2']['bias'])
h3_critic = jnp.tanh(jnp.dot(h2_critic, self.ac_critic_params['layer3']['weights']) +
self.ac_critic_params['layer3']['bias'])
value = jnp.dot(h3_critic, self.ac_critic_params['output']['weights']) + self.ac_critic_params['output']['bias']
return action, float(log_prob), float(value[0]), float(entropy)
def _dqn_forward_dueling(self, state: jax.Array, params: Optional[Dict] = None) -> jax.Array:
"""Dueling DQN forward pass with separate value and advantage streams."""
if params is None:
params = self.dqn_params
# Shared layers
h1 = jnp.tanh(jnp.dot(state, params['shared1']['weights']) +
params['shared1']['bias'])
h2 = jnp.tanh(jnp.dot(h1, params['shared2']['weights']) +
params['shared2']['bias'])
# Value stream - using correct parameter names
v_stream = jnp.tanh(jnp.dot(h2, params['value1']['weights']) +
params['value1']['bias'])
state_value = jnp.dot(v_stream, params['value_output']['weights']) + params['value_output']['bias']
# Advantage stream - using correct parameter names
a_stream = jnp.tanh(jnp.dot(h2, params['advantage1']['weights']) +
params['advantage1']['bias'])
advantages = jnp.dot(a_stream, params['advantage_output']['weights']) + params['advantage_output']['bias']
# Combine value and advantages (dueling architecture)
q_values = state_value + (advantages - jnp.mean(advantages, keepdims=True))
return q_values
def _marl_agent_forward(self, agent_params: Dict, state: float) -> jax.Array:
"""Forward pass for individual multi-agent RL agent."""
state_array = jnp.array([state])
h1 = jnp.tanh(jnp.dot(state_array, agent_params['layer1_weights']) +
agent_params['layer1_bias'])
q_values = jnp.dot(h1, agent_params['layer2_weights']) + agent_params['layer2_bias']
return q_values
def _compute_action_diversity_penalty(self, action: int) -> float:
"""Compute penalty for repetitive actions in DQN."""
if len(self.dqn_action_history) == 0:
return 0.0
# Count recent occurrences of this action
recent_actions = self.dqn_action_history[-self.dqn_action_history_size:]
action_count = sum(1 for a in recent_actions if a == action)
# Penalty proportional to frequency
penalty = self.dqn_action_regularization * (action_count / len(recent_actions))
return penalty
def _create_enhanced_state(self, params: Dict[str, float]) -> jax.Array:
"""Create enhanced state representation with normalized params, targets, history, and gradients."""
param_names = list(self.params.keys())
n_params = len(param_names)
# Current normalized parameters
current_params = jnp.array([params[name] for name in param_names])
param_mins = jnp.array([self.param_bounds[p][0] for p in param_names])
param_maxs = jnp.array([self.param_bounds[p][1] for p in param_names])
normalized_params = (current_params - param_mins) / (param_maxs - param_mins)
normalized_params = jnp.clip(normalized_params, 0.0, 1.0)
# Target metrics (normalized)
target_values = jnp.array([self.target_metrics.get(metric, 0.0) for metric in self.target_metrics])
target_norm = target_values / (jnp.abs(target_values) + 1e-8)
# Parameter history features (recent trends)
if hasattr(self, 'ql_param_history') and len(self.ql_param_history) > 0:
recent_params = jnp.array([self.ql_param_history[-1][name] for name in param_names])
param_change = (current_params - recent_params) / (param_maxs - param_mins + 1e-8)
else:
param_change = jnp.zeros(n_params)
# Gradient estimates
if hasattr(self, 'ql_gradient_estimates'):
gradient_norm = self.ql_gradient_estimates / (jnp.abs(self.ql_gradient_estimates) + 1e-8)
else:
gradient_norm = jnp.zeros(n_params)
# Combine all features
enhanced_state = jnp.concatenate([
normalized_params, # Current normalized parameters
target_norm[:n_params] if len(target_norm) >= n_params else jnp.zeros(n_params), # Targets
param_change, # Recent parameter changes
gradient_norm # Gradient information
])
return enhanced_state
def _adam_update(self, params: Dict, gradients: Dict, adam_m: Dict, adam_v: Dict,
adam_t: int, learning_rate: float = 0.001,
beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> Tuple[Dict, Dict, Dict]:
"""Proper Adam optimizer update for neural network parameters."""
updated_params = {}
updated_m = {}
updated_v = {}
for layer_name, layer_params in params.items():
updated_params[layer_name] = {}
updated_m[layer_name] = {}
updated_v[layer_name] = {}
for param_name, param_value in layer_params.items():
grad = gradients[layer_name][param_name]
m = adam_m[layer_name][param_name]
v = adam_v[layer_name][param_name]
# Update biased first moment estimate
m_new = beta1 * m + (1 - beta1) * grad
# Update biased second raw moment estimate
v_new = beta2 * v + (1 - beta2) * (grad ** 2)
# Compute bias-corrected first moment estimate
m_hat = m_new / (1 - beta1 ** adam_t)
# Compute bias-corrected second raw moment estimate
v_hat = v_new / (1 - beta2 ** adam_t)
# Update parameters
updated_params[layer_name][param_name] = param_value - learning_rate * m_hat / (jnp.sqrt(v_hat) + eps)
updated_m[layer_name][param_name] = m_new
updated_v[layer_name][param_name] = v_new
return updated_params, updated_m, updated_v
[docs]
def calibrate(self, verbose: bool = True) -> Dict[str, float]:
"""Run calibration process and return optimized parameters."""
if verbose:
print(f"Starting calibration with {self.method} method...")
print(f"Target metrics: {self.target_metrics}")
print(f"Parameter bounds: {self.param_bounds}")
if self.method in ["adam", "sgd"]:
return self._calibrate_gradient(verbose)
elif self.method == "es":
return self._calibrate_evolution_strategies(verbose)
elif self.method == "pso":
return self._calibrate_particle_swarm(verbose)
elif self.method == "cem":
return self._calibrate_cross_entropy(verbose)
elif self.method == "bayesian":
return self._calibrate_bayesian(verbose)
elif self.method == "q_learning":
return self._calibrate_q_learning(verbose)
elif self.method == "policy_gradient":
return self._calibrate_policy_gradient(verbose)
elif self.method == "actor_critic":
return self._calibrate_actor_critic(verbose)
elif self.method == "multi_agent_rl":
return self._calibrate_multi_agent_rl(verbose)
elif self.method == "dqn":
return self._calibrate_dqn(verbose)
def _calibrate_gradient(self, verbose: bool) -> Dict[str, float]:
"""Gradient-based calibration with Adam or SGD."""
param_names = list(self.params.keys())
params_flat = jnp.array([self.params[name] for name in param_names])
no_improvement_count = 0
for iteration in range(self.max_iterations):
if verbose:
print(f"\nIteration {iteration+1}/{self.max_iterations}")
# Compute gradients
grads = self.grad_fn(params_flat)
# Check for NaN gradients
if jnp.any(jnp.isnan(grads)):
if verbose:
print("NaN gradients detected, stopping optimization")
break
# Apply optimizer update
if self.method == "adam":
self.adam_t += 1
# Update biased first moment estimate
self.adam_m = self.adam_beta1 * self.adam_m + (1 - self.adam_beta1) * grads
# Update biased second raw moment estimate
self.adam_v = self.adam_beta2 * self.adam_v + (1 - self.adam_beta2) * (grads ** 2)
# Compute bias-corrected first moment estimate
m_hat = self.adam_m / (1 - self.adam_beta1 ** self.adam_t)
# Compute bias-corrected second raw moment estimate
v_hat = self.adam_v / (1 - self.adam_beta2 ** self.adam_t)
# Update parameters
params_flat = params_flat - self.learning_rate * m_hat / (jnp.sqrt(v_hat) + self.adam_eps)
else: # SGD
params_flat = params_flat - self.learning_rate * grads
# Clip to bounds
for i, (param, (min_val, max_val)) in enumerate(self.param_bounds.items()):
params_flat = params_flat.at[i].set(jnp.clip(params_flat[i], min_val, max_val))
# Update parameter dictionary
for i, name in enumerate(param_names):
self.params[name] = float(params_flat[i])
# Robust evaluation
loss, ci = self._evaluate_params_robust(self.params)
# Track history
self.param_history.append(self.params.copy())
self.loss_history.append(loss)
self.confidence_intervals.append(ci)
# Update best parameters
if loss < self.best_loss:
self.best_loss = loss
self.best_params = self.params.copy()
no_improvement_count = 0
else:
no_improvement_count += 1
if verbose:
print(f"Loss: {loss:.6f} (best: {self.best_loss:.6f})")
for metric, target in self.target_metrics.items():
if metric in ci:
mean_val = (ci[metric][0] + ci[metric][1]) / 2
ci_width = ci[metric][1] - ci[metric][0]
print(f" {metric}: {mean_val:.4f} ± {ci_width/2:.4f} (target: {target:.4f})")
# Early stopping
if loss < self.tolerance:
if verbose:
print("Converged: loss below tolerance")
break
if no_improvement_count >= self.patience:
if verbose:
print("Early stopping: no improvement")
break
return self.best_params
def _calibrate_evolution_strategies(self, verbose: bool) -> Dict[str, float]:
"""Evolution Strategies calibration."""
param_names = list(self.params.keys())
n_elite = int(self.es_population_size * self.es_elite_ratio)
for iteration in range(self.max_iterations):
if verbose:
print(f"\nIteration {iteration+1}/{self.max_iterations}")
# Evaluate population
fitness_scores = []
for i in range(self.es_population_size):
params = {name: float(self.es_population[i, j]) for j, name in enumerate(param_names)}
loss, _ = self._evaluate_params_robust(params)
fitness_scores.append(loss)
fitness_scores = jnp.array(fitness_scores)
# Select elite
elite_indices = jnp.argsort(fitness_scores)[:n_elite]
elite_population = self.es_population[elite_indices]
# Update best
best_idx = elite_indices[0]
best_loss = fitness_scores[best_idx]
if best_loss < self.best_loss:
self.best_loss = best_loss
self.best_params = {name: float(self.es_population[best_idx, j])
for j, name in enumerate(param_names)}
# Generate new population
self.key, subkey = random.split(self.key)
# Compute elite mean and covariance
elite_mean = jnp.mean(elite_population, axis=0)
# Generate new population around elite mean
noise = random.normal(subkey, self.es_population.shape) * self.es_sigma
self.es_population = elite_mean[None, :] + noise
# Clip to bounds
for i, (param, (min_val, max_val)) in enumerate(self.param_bounds.items()):
self.es_population = self.es_population.at[:, i].set(
jnp.clip(self.es_population[:, i], min_val, max_val)
)
# Track history
self.loss_history.append(float(best_loss))
self.param_history.append(self.best_params.copy())
if verbose:
print(f"Best loss: {best_loss:.6f}")
print(f"Population mean: {elite_mean}")
# Decay mutation strength
self.es_sigma *= 0.995
if best_loss < self.tolerance:
if verbose:
print("Converged: loss below tolerance")
break
return self.best_params
def _calibrate_particle_swarm(self, verbose: bool) -> Dict[str, float]:
"""Particle Swarm Optimization calibration."""
param_names = list(self.params.keys())
for iteration in range(self.max_iterations):
if verbose:
print(f"\nIteration {iteration+1}/{self.max_iterations}")
# Evaluate particles
for i in range(self.pso_population_size):
params = {name: float(self.pso_positions[i, j]) for j, name in enumerate(param_names)}
loss, _ = self._evaluate_params_robust(params)
# Update personal best
if loss < self.pso_personal_best_scores[i]:
self.pso_personal_best_scores = self.pso_personal_best_scores.at[i].set(loss)
self.pso_personal_best = self.pso_personal_best.at[i].set(self.pso_positions[i])
# Update global best
if loss < self.pso_global_best_score:
self.pso_global_best_score = loss
self.pso_global_best = self.pso_positions[i].copy()
self.best_loss = loss
self.best_params = params.copy()
# Update velocities and positions
self.key, subkey1, subkey2 = random.split(self.key, 3)
r1 = random.uniform(subkey1, self.pso_velocities.shape)
r2 = random.uniform(subkey2, self.pso_velocities.shape)
cognitive = self.pso_c1 * r1 * (self.pso_personal_best - self.pso_positions)
social = self.pso_c2 * r2 * (self.pso_global_best[None, :] - self.pso_positions)
self.pso_velocities = (self.pso_w * self.pso_velocities + cognitive + social)
self.pso_positions = self.pso_positions + self.pso_velocities
# Clip to bounds
for i, (param, (min_val, max_val)) in enumerate(self.param_bounds.items()):
self.pso_positions = self.pso_positions.at[:, i].set(
jnp.clip(self.pso_positions[:, i], min_val, max_val)
)
# Track history
self.loss_history.append(float(self.pso_global_best_score))
self.param_history.append(self.best_params.copy())
if verbose:
print(f"Best loss: {self.pso_global_best_score:.6f}")
print(f"Best params: {self.best_params}")
if self.pso_global_best_score < self.tolerance:
if verbose:
print("Converged: loss below tolerance")
break
return self.best_params
def _calibrate_cross_entropy(self, verbose: bool) -> Dict[str, float]:
"""Cross-Entropy Method calibration."""
param_names = list(self.params.keys())
n_elite = int(self.cem_population_size * self.cem_elite_ratio)
for iteration in range(self.max_iterations):
if verbose:
print(f"\nIteration {iteration+1}/{self.max_iterations}")
# Sample population
self.key, subkey = random.split(self.key)
population = random.normal(subkey, (self.cem_population_size, len(param_names)))
population = population * self.cem_std[None, :] + self.cem_mean[None, :]
# Clip to bounds
for i, (param, (min_val, max_val)) in enumerate(self.param_bounds.items()):
population = population.at[:, i].set(jnp.clip(population[:, i], min_val, max_val))
# Evaluate population
fitness_scores = []
for i in range(self.cem_population_size):
params = {name: float(population[i, j]) for j, name in enumerate(param_names)}
loss, _ = self._evaluate_params_robust(params)
fitness_scores.append(loss)
fitness_scores = jnp.array(fitness_scores)
# Select elite
elite_indices = jnp.argsort(fitness_scores)[:n_elite]
elite_population = population[elite_indices]
# Update distribution parameters
self.cem_mean = jnp.mean(elite_population, axis=0)
self.cem_std = jnp.std(elite_population, axis=0) + 1e-6 # Add small epsilon
# Update best
best_idx = elite_indices[0]
best_loss = fitness_scores[best_idx]
if best_loss < self.best_loss:
self.best_loss = best_loss
self.best_params = {name: float(population[best_idx, j])
for j, name in enumerate(param_names)}
# Track history
self.loss_history.append(float(best_loss))
self.param_history.append(self.best_params.copy())
if verbose:
print(f"Best loss: {best_loss:.6f}")
print(f"Distribution mean: {self.cem_mean}")
print(f"Distribution std: {self.cem_std}")
# Decay noise
self.cem_std *= self.cem_noise_decay
if best_loss < self.tolerance:
if verbose:
print("Converged: loss below tolerance")
break
return self.best_params
def _calibrate_bayesian(self, verbose: bool) -> Dict[str, float]:
"""Bayesian Optimization calibration (simplified implementation)."""
param_names = list(self.params.keys())
# Evaluate initial points
for i in range(self.bo_n_initial):
if self.bo_evaluated >= self.bo_n_initial:
break
params = {name: float(self.bo_X[i, j]) for j, name in enumerate(param_names)}
loss, _ = self._evaluate_params_robust(params)
self.bo_y = self.bo_y.at[i].set(loss)
if loss < self.best_loss:
self.best_loss = loss
self.best_params = params.copy()
self.bo_evaluated += 1
# Main optimization loop
for iteration in range(self.max_iterations - self.bo_n_initial):
if verbose:
print(f"\nIteration {iteration+1}/{self.max_iterations - self.bo_n_initial}")
# Simple acquisition function: random search with bias toward good regions
# In practice, you'd use a proper GP and acquisition function
self.key, subkey = random.split(self.key)
# Find best current point
best_idx = jnp.argmin(self.bo_y[:self.bo_evaluated])
best_point = self.bo_X[best_idx]
# Sample around best point with some exploration
noise_scale = 0.1 * (1.0 - iteration / self.max_iterations) # Decay exploration
candidate = best_point + random.normal(subkey, best_point.shape) * noise_scale
# Clip to bounds
for i, (param, (min_val, max_val)) in enumerate(self.param_bounds.items()):
candidate = candidate.at[i].set(jnp.clip(candidate[i], min_val, max_val))
# Evaluate candidate
params = {name: float(candidate[j]) for j, name in enumerate(param_names)}
loss, _ = self._evaluate_params_robust(params)
# Add to dataset
self.bo_X = jnp.concatenate([self.bo_X, candidate[None, :]], axis=0)
self.bo_y = jnp.concatenate([self.bo_y, jnp.array([loss])], axis=0)
self.bo_evaluated += 1
# Update best
if loss < self.best_loss:
self.best_loss = loss
self.best_params = params.copy()
# Track history
self.loss_history.append(float(self.best_loss))
self.param_history.append(self.best_params.copy())
if verbose:
print(f"Best loss: {self.best_loss:.6f}")
print(f"Best params: {self.best_params}")
if self.best_loss < self.tolerance:
if verbose:
print("Converged: loss below tolerance")
break
return self.best_params
def _calibrate_q_learning(self, verbose: bool) -> Dict[str, float]:
"""Fixed Q-Learning calibration with working optimization."""
no_improvement_count = 0
# Initialize Q-table for discrete states/actions
if not hasattr(self, 'ql_q_table'):
self.ql_q_table = {}
# Define simple action space: (param_index, step_size, direction)
param_names = list(self.params.keys())
actions = []
for i, param in enumerate(param_names):
min_val, max_val = self.param_bounds[param]
small_step = (max_val - min_val) * 0.02 # 2% of range
large_step = (max_val - min_val) * 0.1 # 10% of range
actions.extend([
(i, small_step, 1), # small increase
(i, small_step, -1), # small decrease
(i, large_step, 1), # large increase
(i, large_step, -1), # large decrease
])
def get_state_key(params):
"""Convert parameters to discrete state."""
state_parts = []
for param, value in params.items():
min_val, max_val = self.param_bounds[param]
normalized = (value - min_val) / (max_val - min_val)
bin_idx = int(jnp.clip(normalized * 10, 0, 9)) # 10 bins
state_parts.append(str(bin_idx))
return "_".join(state_parts)
for iteration in range(self.max_iterations):
if verbose:
print(f"\nIteration {iteration+1}/{self.max_iterations}")
# Get current state
state_key = get_state_key(self.params)
# Initialize Q-values for new states
if state_key not in self.ql_q_table:
self.ql_q_table[state_key] = jnp.zeros(len(actions))
# Epsilon-greedy action selection
self.key, subkey = random.split(self.key)
if random.uniform(subkey) < self.ql_epsilon:
action_idx = int(random.randint(subkey, (), 0, len(actions)))
else:
action_idx = int(jnp.argmax(self.ql_q_table[state_key]))
# Apply action
param_idx, step_size, direction = actions[action_idx]
param_name = param_names[param_idx]
new_params = self.params.copy()
min_val, max_val = self.param_bounds[param_name]
new_value = self.params[param_name] + direction * step_size
new_params[param_name] = float(jnp.clip(new_value, min_val, max_val))
# Evaluate new parameters
old_loss = self.loss_history[-1] if self.loss_history else self._evaluate_params_robust(self.params)[0]
new_loss, ci = self._evaluate_params_robust(new_params)
# Compute reward (improvement with proper scaling)
improvement = old_loss - new_loss
reward = improvement * 100.0 # Scale for better learning
# Update Q-value
next_state_key = get_state_key(new_params)
if next_state_key not in self.ql_q_table:
self.ql_q_table[next_state_key] = jnp.zeros(len(actions))
# Q-learning update
current_q = self.ql_q_table[state_key][action_idx]
max_next_q = jnp.max(self.ql_q_table[next_state_key])
target_q = reward + self.ql_gamma * max_next_q
# Update Q-table
updated_q_values = self.ql_q_table[state_key].at[action_idx].set(
current_q + self.ql_learning_rate * (target_q - current_q)
)
self.ql_q_table[state_key] = updated_q_values
# Update parameters if improvement
if new_loss < old_loss:
self.params = new_params
# Track history
self.param_history.append(self.params.copy())
self.loss_history.append(new_loss)
self.confidence_intervals.append(ci)
# Update best parameters
if new_loss < self.best_loss:
self.best_loss = new_loss
self.best_params = self.params.copy()
no_improvement_count = 0
else:
no_improvement_count += 1
if verbose:
print(f"Loss: {new_loss:.6f} (best: {self.best_loss:.6f})")
print(f"Action: {action_idx} (param: {param_name}, step: {step_size:.4f}, dir: {direction})")
print(f"Reward: {reward:.6f}, Epsilon: {self.ql_epsilon:.3f}")
# Decay exploration
self.ql_epsilon = max(self.ql_epsilon * self.ql_epsilon_decay, self.ql_epsilon_min)
# Early stopping
if new_loss < self.tolerance:
if verbose:
print("Converged: loss below tolerance")
break
if no_improvement_count >= self.patience:
if verbose:
print("Early stopping: no improvement")
break
return self.best_params
def _train_ql_network_improved(self):
"""Improved Q-learning neural network training with proper Adam optimization."""
# Sample random batch
self.key, subkey = random.split(self.key)
batch_indices = random.choice(subkey, len(self.ql_memory),
shape=(self.ql_batch_size,), replace=False)
batch = [self.ql_memory[i] for i in batch_indices]
# Prepare batch data
states = jnp.array([exp[0] for exp in batch])
actions = jnp.array([exp[1] for exp in batch])
rewards = jnp.array([exp[2] for exp in batch])
next_states = jnp.array([exp[3] for exp in batch])
# Compute target Q-values using target network (Double DQN)
next_q_values_main = jnp.array([self._ql_forward_improved(next_state) for next_state in next_states])
next_q_values_target = jnp.array([self._ql_forward_improved(next_state, self.ql_target_params)
for next_state in next_states])
# Double DQN: use main network to select actions, target network to evaluate
next_actions = jnp.argmax(next_q_values_main, axis=1)
target_q_values = rewards + self.ql_gamma * next_q_values_target[jnp.arange(len(batch)), next_actions]
# Compute current Q-values
current_q_values = jnp.array([self._ql_forward_improved(state) for state in states])
current_q_selected = current_q_values[jnp.arange(len(batch)), actions]
# Compute TD errors
td_errors = target_q_values - current_q_selected
# Compute gradients (simplified backpropagation)
gradients = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.ql_params.items()}
for i in range(len(batch)):
state = states[i]
action = actions[i]
td_error = td_errors[i]
# Forward pass to get activations
h1 = jnp.tanh(jnp.dot(state, self.ql_params['layer1']['weights']) +
self.ql_params['layer1']['bias'])
h2 = jnp.tanh(jnp.dot(h1, self.ql_params['layer2']['weights']) +
self.ql_params['layer2']['bias'])
h3 = jnp.tanh(jnp.dot(h2, self.ql_params['layer3']['weights']) +
self.ql_params['layer3']['bias'])
# Backpropagation (simplified)
# Output layer gradients
grad_output_w = jnp.zeros_like(self.ql_params['output']['weights'])
grad_output_b = jnp.zeros_like(self.ql_params['output']['bias'])
grad_output_w = grad_output_w.at[:, action].add(td_error * h3)
grad_output_b = grad_output_b.at[action].add(td_error)
# Layer 3 gradients
delta3 = jnp.zeros(h3.shape[0])
delta3 = delta3.at[:].add(td_error * self.ql_params['output']['weights'][:, action])
delta3 = delta3 * (1 - h3**2) # tanh derivative
grad_layer3_w = jnp.outer(h2, delta3)
grad_layer3_b = delta3
# Accumulate gradients
gradients['output']['weights'] += grad_output_w / len(batch)
gradients['output']['bias'] += grad_output_b / len(batch)
gradients['layer3']['weights'] += grad_layer3_w / len(batch)
gradients['layer3']['bias'] += grad_layer3_b / len(batch)
# Apply Adam optimization
self.ql_adam_t += 1
self.ql_params, self.ql_adam_m, self.ql_adam_v = self._adam_update(
self.ql_params, gradients, self.ql_adam_m, self.ql_adam_v,
self.ql_adam_t, self.ql_learning_rate
)
def _calibrate_policy_gradient(self, verbose: bool) -> Dict[str, float]:
"""Enhanced Policy Gradient calibration with improved actor-critic architecture."""
no_improvement_count = 0
for iteration in range(self.max_iterations):
if verbose:
print(f"\nIteration {iteration+1}/{self.max_iterations}")
# Create enhanced state representation
state = self._create_enhanced_state(self.params)
# Store state history
self.pg_state_history.append(state)
if len(self.pg_state_history) > 20: # Keep recent history
self.pg_state_history.pop(0)
# Sample action from enhanced policy
action, mean, std, value = self._pg_forward_enhanced(state)
# Convert action to parameter dictionary and clip to bounds
new_params = {}
for i, param_name in enumerate(self.pg_param_names):
min_val, max_val = self.param_bounds[param_name]
new_params[param_name] = float(jnp.clip(action[i], min_val, max_val))
# Evaluate parameters
loss, ci = self._evaluate_params_robust(new_params)
# Shaped reward function with safety checks
old_loss = self.loss_history[-1] if self.loss_history else float('inf')
improvement = old_loss - loss
reward = jnp.clip(improvement * 10.0, -50.0, 50.0) # Clip reward to prevent explosion
# Add target-based reward shaping with safety
target_reward = 0.0
for metric, target in self.target_metrics.items():
if metric in ci:
current_value = (ci[metric][0] + ci[metric][1]) / 2
if jnp.isfinite(current_value) and jnp.isfinite(target):
# Reward based on inverse distance to target
distance = abs(current_value - target) / (abs(target) + 1e-8)
target_reward += 1.0 / (1.0 + distance) # Higher reward for closer values
reward = jnp.clip(reward + target_reward, -100.0, 100.0)
# Compute log probability for the action taken with safety checks
action_safe = jnp.where(jnp.isfinite(action), action, 0.0)
mean_safe = jnp.where(jnp.isfinite(mean), mean, 0.0)
std_safe = jnp.maximum(jnp.where(jnp.isfinite(std), std, 0.1), 0.01)
log_prob = -0.5 * jnp.sum(((action_safe - mean_safe) / std_safe) ** 2) - \
0.5 * jnp.sum(jnp.log(2 * jnp.pi * std_safe ** 2))
log_prob = jnp.clip(log_prob, -50.0, 50.0)
# Compute entropy for exploration bonus with safety
entropy = 0.5 * jnp.sum(jnp.log(2 * jnp.pi * jnp.e * std_safe ** 2))
entropy = jnp.clip(entropy, -50.0, 50.0)
# Store episode data
self.pg_loss_history.append(loss)
# Advantage estimation (TD error) with safety checks
if len(self.pg_loss_history) > 1:
# Simple advantage estimate with clipping
advantage = jnp.clip(reward - value, -50.0, 50.0)
else:
advantage = jnp.clip(reward, -50.0, 50.0)
# Check for NaN/inf in key values
if not jnp.isfinite(advantage):
advantage = 0.0
if not jnp.isfinite(log_prob):
log_prob = 0.0
if not jnp.isfinite(entropy):
entropy = 0.0
# Policy gradient update with entropy regularization and safety
policy_loss = -float(log_prob) * advantage - self.pg_entropy_coeff * entropy
value_loss = (reward - value) ** 2
# Skip update if any critical values are NaN
if not (jnp.isfinite(policy_loss) and jnp.isfinite(value_loss)):
if verbose:
print(f"Skipping update due to NaN values: policy_loss={policy_loss}, value_loss={value_loss}")
continue
# Compute gradients for policy network (simplified)
policy_gradients = self._compute_policy_gradients_enhanced(
state, action, mean, std, advantage, entropy
)
# Compute gradients for value network
value_gradients = self._compute_value_gradients_enhanced(state, reward, value)
# Apply Adam optimization for both networks
self.pg_adam_t += 1
# Update policy network
self.pg_policy_params, self.pg_policy_adam_m, self.pg_policy_adam_v = self._adam_update(
self.pg_policy_params, policy_gradients, self.pg_policy_adam_m,
self.pg_policy_adam_v, self.pg_adam_t, self.pg_learning_rate
)
# Update value network
self.pg_value_params, self.pg_value_adam_m, self.pg_value_adam_v = self._adam_update(
self.pg_value_params, value_gradients, self.pg_value_adam_m,
self.pg_value_adam_v, self.pg_adam_t, self.pg_learning_rate * self.pg_value_coeff
)
# Update current parameters to the new sampled parameters
self.params = new_params.copy()
# Track history
self.param_history.append(self.params.copy())
self.loss_history.append(loss)
self.confidence_intervals.append(ci)
# Update best parameters
if loss < self.best_loss:
self.best_loss = loss
self.best_params = self.params.copy()
no_improvement_count = 0
else:
no_improvement_count += 1
if verbose:
print(f"Loss: {loss:.6f} (best: {self.best_loss:.6f})")
print(f"Reward: {reward:.6f}, Advantage: {advantage:.6f}")
print(f"Value: {value:.6f}, Entropy: {entropy:.6f}")
print(f"Policy std: {jnp.mean(std):.4f}")
# Early stopping
if loss < self.tolerance:
if verbose:
print("Converged: loss below tolerance")
break
if no_improvement_count >= self.patience:
if verbose:
print("Early stopping: no improvement")
break
return self.best_params
def _compute_policy_gradients_enhanced(self, state: jax.Array, action: jax.Array,
mean: jax.Array, std: jax.Array,
advantage: float, entropy: float) -> Dict:
"""Compute policy gradients with entropy regularization and safety checks."""
gradients = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.pg_policy_params.items()}
# Safety checks for inputs
action = jnp.where(jnp.isfinite(action), action, 0.0)
mean = jnp.where(jnp.isfinite(mean), mean, 0.0)
std = jnp.maximum(jnp.where(jnp.isfinite(std), std, 0.1), 0.01)
advantage = jnp.where(jnp.isfinite(advantage), advantage, 0.0)
# Forward pass to get activations
h1 = jnp.tanh(jnp.dot(state, self.pg_policy_params['shared_layer1']['weights']) +
self.pg_policy_params['shared_layer1']['bias'])
h2 = jnp.tanh(jnp.dot(h1, self.pg_policy_params['shared_layer2']['weights']) +
self.pg_policy_params['shared_layer2']['bias'])
# Gradients for mean output with safety
grad_log_prob_mean = (action - mean) / (std ** 2)
grad_log_prob_mean = jnp.clip(grad_log_prob_mean, -10.0, 10.0)
policy_grad_mean = advantage * grad_log_prob_mean
policy_grad_mean = jnp.clip(policy_grad_mean, -5.0, 5.0)
gradients['mean_output']['weights'] = jnp.outer(h2, policy_grad_mean)
gradients['mean_output']['bias'] = policy_grad_mean
# Gradients for std output (log_std) with safety
grad_log_prob_std = ((action - mean) ** 2 / (std ** 2) - 1) / std
grad_log_prob_std = jnp.clip(grad_log_prob_std, -10.0, 10.0)
entropy_grad_std = self.pg_entropy_coeff / std
entropy_grad_std = jnp.clip(entropy_grad_std, -1.0, 1.0)
policy_grad_std = advantage * grad_log_prob_std + entropy_grad_std
policy_grad_std = jnp.clip(policy_grad_std, -5.0, 5.0)
gradients['std_output']['weights'] = jnp.outer(h2, policy_grad_std)
gradients['std_output']['bias'] = policy_grad_std
# Ensure all gradients are finite
for layer_name in gradients:
for param_name in gradients[layer_name]:
gradients[layer_name][param_name] = jnp.where(
jnp.isfinite(gradients[layer_name][param_name]),
gradients[layer_name][param_name],
0.0
)
return gradients
def _compute_value_gradients_enhanced(self, state: jax.Array, target: float, value: float) -> Dict:
"""Compute value network gradients."""
gradients = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.pg_value_params.items()}
# Value prediction error
value_error = target - value
# Forward pass to get activations
h1 = jnp.tanh(jnp.dot(state, self.pg_value_params['layer1']['weights']) +
self.pg_value_params['layer1']['bias'])
# Gradients for value network
gradients['layer2']['weights'] = jnp.outer(h1, jnp.array([value_error]))
gradients['layer2']['bias'] = jnp.array([value_error])
return gradients
def _calibrate_actor_critic(self, verbose: bool) -> Dict[str, float]:
"""Enhanced Actor-Critic calibration with proper network architecture and GAE."""
no_improvement_count = 0
for iteration in range(self.max_iterations):
if verbose:
print(f"\nIteration {iteration+1}/{self.max_iterations}")
# Current enhanced state
state = self._create_enhanced_state(self.params)
# Forward pass through enhanced actor-critic
action, log_prob, value, entropy = self._ac_forward_enhanced(state)
# Convert action to parameter dictionary and clip to bounds
new_params = {}
for i, param_name in enumerate(self.ac_param_names):
min_val, max_val = self.param_bounds[param_name]
new_params[param_name] = float(jnp.clip(action[i], min_val, max_val))
# Evaluate parameters
loss, ci = self._evaluate_params_robust(new_params)
# Shaped reward
old_loss = self.loss_history[-1] if self.loss_history else float('inf')
improvement = old_loss - loss
reward = improvement * 10.0
# Add target-based bonus
target_bonus = 0.0
for metric, target in self.target_metrics.items():
if metric in ci:
current_value = (ci[metric][0] + ci[metric][1]) / 2
distance = abs(current_value - target) / (abs(target) + 1e-8)
target_bonus += 1.0 / (1.0 + distance)
reward += target_bonus
# Store experience for GAE
self.ac_states.append(state)
self.ac_actions.append(action)
self.ac_rewards.append(reward)
self.ac_values.append(value)
self.ac_log_probs.append(log_prob)
self.ac_dones.append(False) # Not episodic
# Compute advantage and train networks
if len(self.ac_states) >= 5: # Train every few steps
advantages = self._compute_gae(self.ac_rewards, self.ac_values, self.ac_dones)
# Actor loss (policy gradient with entropy bonus)
actor_loss = -float(log_prob) * advantages[-1] - self.ac_entropy_coeff * entropy
# Critic loss (value prediction error)
critic_loss = (reward - value) ** 2
# Compute gradients
actor_gradients = self._compute_actor_gradients_enhanced(state, action, advantages[-1], entropy)
critic_gradients = self._compute_critic_gradients_enhanced(state, reward, value)
# Apply Adam optimization
self.ac_adam_t += 1
# Update actor
self.ac_actor_params, self.ac_actor_adam_m, self.ac_actor_adam_v = self._adam_update(
self.ac_actor_params, actor_gradients, self.ac_actor_adam_m,
self.ac_actor_adam_v, self.ac_adam_t, self.ac_actor_lr
)
# Update critic
self.ac_critic_params, self.ac_critic_adam_m, self.ac_critic_adam_v = self._adam_update(
self.ac_critic_params, critic_gradients, self.ac_critic_adam_m,
self.ac_critic_adam_v, self.ac_adam_t, self.ac_critic_lr
)
# Clear buffers
self.ac_states = self.ac_states[-1:] # Keep last state
self.ac_actions = self.ac_actions[-1:]
self.ac_rewards = self.ac_rewards[-1:]
self.ac_values = self.ac_values[-1:]
self.ac_log_probs = self.ac_log_probs[-1:]
self.ac_dones = self.ac_dones[-1:]
# Update parameters
self.params = new_params
# Track history
self.param_history.append(self.params.copy())
self.loss_history.append(loss)
self.confidence_intervals.append(ci)
# Update best parameters
if loss < self.best_loss:
self.best_loss = loss
self.best_params = self.params.copy()
no_improvement_count = 0
else:
no_improvement_count += 1
if verbose:
print(f"Loss: {loss:.6f} (best: {self.best_loss:.6f})")
print(f"Reward: {reward:.6f}")
print(f"Value: {value:.6f}, Entropy: {entropy:.6f}")
# Early stopping
if loss < self.tolerance:
if verbose:
print("Converged: loss below tolerance")
break
if no_improvement_count >= self.patience:
if verbose:
print("Early stopping: no improvement")
break
return self.best_params
def _compute_gae(self, rewards: List[float], values: List[float], dones: List[bool]) -> List[float]:
"""Compute Generalized Advantage Estimation."""
advantages = []
gae = 0
for i in reversed(range(len(rewards))):
if i == len(rewards) - 1:
next_value = 0.0
else:
next_value = values[i + 1]
delta = rewards[i] + self.ac_gamma * next_value * (1 - dones[i]) - values[i]
gae = delta + self.ac_gamma * self.ac_lambda * (1 - dones[i]) * gae
advantages.insert(0, gae)
return advantages
def _compute_actor_gradients_enhanced(self, state: jax.Array, action: jax.Array,
advantage: float, entropy: float) -> Dict:
"""Compute actor gradients with enhanced architecture."""
gradients = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.ac_actor_params.items()}
# Forward pass to get activations
a1 = jnp.tanh(jnp.dot(state, self.ac_actor_params['layer1']['weights']) +
self.ac_actor_params['layer1']['bias'])
a2 = jnp.tanh(jnp.dot(a1, self.ac_actor_params['layer2']['weights']) +
self.ac_actor_params['layer2']['bias'])
mean = jnp.dot(a2, self.ac_actor_params['mean_output']['weights']) + \
self.ac_actor_params['mean_output']['bias']
log_std = jnp.dot(a2, self.ac_actor_params['std_output']['weights']) + \
self.ac_actor_params['std_output']['bias']
std = jnp.exp(jnp.maximum(log_std, jnp.log(0.05)))
# Policy gradients
grad_log_prob_mean = (action - mean) / (std ** 2)
policy_grad_mean = advantage * grad_log_prob_mean
grad_log_prob_std = ((action - mean) ** 2 / (std ** 2) - 1) / std
entropy_grad_std = self.ac_entropy_coeff / std
policy_grad_std = advantage * grad_log_prob_std + entropy_grad_std
# Output layer gradients
gradients['mean_output']['weights'] = jnp.outer(a2, policy_grad_mean)
gradients['mean_output']['bias'] = policy_grad_mean
gradients['std_output']['weights'] = jnp.outer(a2, policy_grad_std)
gradients['std_output']['bias'] = policy_grad_std
return gradients
def _compute_critic_gradients_enhanced(self, state: jax.Array, target: float, value: float) -> Dict:
"""Compute critic gradients with enhanced architecture."""
gradients = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.ac_critic_params.items()}
# Value prediction error
value_error = target - value
# Forward pass to get activations
c1 = jnp.tanh(jnp.dot(state, self.ac_critic_params['layer1']['weights']) +
self.ac_critic_params['layer1']['bias'])
c2 = jnp.tanh(jnp.dot(c1, self.ac_critic_params['layer2']['weights']) +
self.ac_critic_params['layer2']['bias'])
c3 = jnp.tanh(jnp.dot(c2, self.ac_critic_params['layer3']['weights']) +
self.ac_critic_params['layer3']['bias'])
# Output layer gradients
gradients['output']['weights'] = jnp.outer(c3, jnp.array([value_error]))
gradients['output']['bias'] = jnp.array([value_error])
return gradients
def _calibrate_multi_agent_rl(self, verbose: bool) -> Dict[str, float]:
"""Enhanced Multi-Agent RL calibration with communication and coordination."""
no_improvement_count = 0
for iteration in range(self.max_iterations):
if verbose:
print(f"\nIteration {iteration+1}/{self.max_iterations}")
# Create enhanced global state
global_state = self._create_enhanced_state(self.params)
# Each agent selects an action for its parameter
new_params = self.params.copy()
agent_actions = {}
agent_communications = {}
# First pass: generate communications
for param_name in self.marl_param_names:
agent = self.marl_agents[param_name]
# Create local state (parameter + global info)
min_val, max_val = self.param_bounds[param_name]
normalized_param = (self.params[param_name] - min_val) / (max_val - min_val)
local_state = jnp.array([normalized_param])
# Get communication from other agents (zeros for first iteration)
other_comms = self.marl_global_comm[self.marl_param_names.index(param_name)]
# Forward pass to get Q-values and communication
q_values, comm_output = self._marl_agent_forward_enhanced(
agent, local_state, other_comms.reshape(1, -1)
)
agent_communications[param_name] = comm_output
# Update global communication buffer
for i, param_name in enumerate(self.marl_param_names):
self.marl_global_comm = self.marl_global_comm.at[i].set(
agent_communications[param_name]
)
# Second pass: select actions with updated communications
for param_name in self.marl_param_names:
agent = self.marl_agents[param_name]
# Create enhanced local state
min_val, max_val = self.param_bounds[param_name]
normalized_param = (self.params[param_name] - min_val) / (max_val - min_val)
# Include global state and communications from other agents
other_comms = []
for other_param in self.marl_param_names:
if other_param != param_name:
other_comms.append(agent_communications[other_param])
if other_comms:
comm_input = jnp.concatenate(other_comms)
else:
comm_input = jnp.zeros(self.marl_communication_dim)
# Enhanced local state: own param + global features + communications
local_state = jnp.concatenate([
jnp.array([normalized_param]),
global_state[:len(self.marl_param_names)], # Global param state
global_state[len(self.marl_param_names):2*len(self.marl_param_names)], # Targets
comm_input
])
# Action selection with epsilon-greedy
self.key, subkey = random.split(self.key)
if random.uniform(subkey) < agent['epsilon']:
action = random.randint(subkey, (), 0, len(self.marl_step_sizes))
else:
q_values, _ = self._marl_agent_forward_enhanced(
agent, local_state.reshape(1, -1), jnp.zeros((1, self.marl_communication_dim))
)
action = jnp.argmax(q_values)
agent_actions[param_name] = int(action)
# Apply action
step_size = self.marl_step_sizes[action]
step = step_size * (max_val - min_val)
# Random direction
self.key, subkey = random.split(self.key)
direction = 1 if random.uniform(subkey) < 0.5 else -1
new_value = self.params[param_name] + direction * step
new_params[param_name] = float(jnp.clip(new_value, min_val, max_val))
# Evaluate joint action
loss, ci = self._evaluate_params_robust(new_params)
# Shaped reward for all agents
old_loss = self.loss_history[-1] if self.loss_history else float('inf')
improvement = old_loss - loss
reward = improvement * 10.0
# Coordination bonus (all agents get same reward)
target_bonus = 0.0
for metric, target in self.target_metrics.items():
if metric in ci:
current_value = (ci[metric][0] + ci[metric][1]) / 2
distance = abs(current_value - target) / (abs(target) + 1e-8)
target_bonus += 1.0 / (1.0 + distance)
reward += target_bonus
# Update each agent's neural network
for param_name in self.marl_param_names:
agent = self.marl_agents[param_name]
action = agent_actions[param_name]
# States for this agent
min_val, max_val = self.param_bounds[param_name]
current_norm = (self.params[param_name] - min_val) / (max_val - min_val)
next_norm = (new_params[param_name] - min_val) / (max_val - min_val)
current_state = jnp.array([current_norm])
next_state = jnp.array([next_norm])
# Store experience
experience = (current_state, action, reward, next_state)
agent['memory'].append(experience)
if len(agent['memory']) > agent['memory_size']:
agent['memory'].pop(0)
# Train agent's network
if len(agent['memory']) >= 16: # Mini-batch size
self._train_marl_agent_enhanced(agent, param_name)
# Decay exploration
agent['epsilon'] = max(agent['epsilon'] * self.marl_epsilon_decay, self.marl_epsilon_min)
# Update parameters
self.params = new_params
# Track history
self.param_history.append(self.params.copy())
self.loss_history.append(loss)
self.confidence_intervals.append(ci)
# Update best parameters
if loss < self.best_loss:
self.best_loss = loss
self.best_params = self.params.copy()
no_improvement_count = 0
else:
no_improvement_count += 1
if verbose:
print(f"Loss: {loss:.6f} (best: {self.best_loss:.6f})")
print(f"Reward: {reward:.6f}")
print(f"Actions: {agent_actions}")
avg_epsilon = sum(agent['epsilon'] for agent in self.marl_agents.values()) / len(self.marl_agents)
print(f"Avg Epsilon: {avg_epsilon:.3f}")
# Early stopping
if loss < self.tolerance:
if verbose:
print("Converged: loss below tolerance")
break
if no_improvement_count >= self.patience:
if verbose:
print("Early stopping: no improvement")
break
return self.best_params
def _train_marl_agent_enhanced(self, agent: Dict, param_name: str):
"""Enhanced training for multi-agent RL with proper Adam optimization."""
# Sample random batch
self.key, subkey = random.split(self.key)
batch_size = min(16, len(agent['memory']))
batch_indices = random.choice(subkey, len(agent['memory']),
shape=(batch_size,), replace=False)
batch = [agent['memory'][i] for i in batch_indices]
# Prepare batch data
states = jnp.array([exp[0] for exp in batch])
actions = jnp.array([exp[1] for exp in batch])
rewards = jnp.array([exp[2] for exp in batch])
next_states = jnp.array([exp[3] for exp in batch])
# Compute target Q-values
next_q_values = []
for next_state in next_states:
comm_input = jnp.zeros((1, self.marl_communication_dim))
q_vals, _ = self._marl_agent_forward_enhanced(agent, next_state.reshape(1, -1), comm_input)
next_q_values.append(q_vals)
next_q_values = jnp.array(next_q_values)
target_q_values = rewards + 0.95 * jnp.max(next_q_values, axis=1)
# Compute current Q-values
current_q_values = []
for state in states:
comm_input = jnp.zeros((1, self.marl_communication_dim))
q_vals, _ = self._marl_agent_forward_enhanced(agent, state.reshape(1, -1), comm_input)
current_q_values.append(q_vals)
current_q_values = jnp.array(current_q_values)
# Compute TD errors
td_errors = target_q_values - current_q_values[jnp.arange(len(batch)), actions]
# Compute gradients (simplified)
gradients = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in agent['q_params'].items()}
for i in range(len(batch)):
state = states[i]
action = actions[i]
td_error = td_errors[i]
# Simple gradient for Q-output layer
grad_w = jnp.zeros_like(agent['q_params']['q_output']['weights'])
grad_b = jnp.zeros_like(agent['q_params']['q_output']['bias'])
# Assume state features for gradient computation
grad_w = grad_w.at[:, action].add(td_error * state[0]) # Simplified
grad_b = grad_b.at[action].add(td_error)
gradients['q_output']['weights'] += grad_w / len(batch)
gradients['q_output']['bias'] += grad_b / len(batch)
# Apply Adam optimization
agent['adam_t'] += 1
agent['q_params'], agent['q_adam_m'], agent['q_adam_v'] = self._adam_update(
agent['q_params'], gradients, agent['q_adam_m'],
agent['q_adam_v'], agent['adam_t'], self.marl_learning_rate
)
def _calibrate_dqn(self, verbose: bool) -> Dict[str, float]:
"""Advanced DQN calibration with Dueling architecture and prioritized experience replay."""
no_improvement_count = 0
target_update_counter = 0
for iteration in range(self.max_iterations):
if verbose:
print(f"\nIteration {iteration+1}/{self.max_iterations}")
# Current enhanced state
state = self._create_enhanced_state(self.params)
# Epsilon-greedy action selection with diversity penalty
self.key, subkey = random.split(self.key)
if random.uniform(subkey) < self.dqn_epsilon:
# Explore: random action with diversity bias
n_actions = len(self.dqn_param_names) * len(self.dqn_step_sizes)
action = random.randint(subkey, (), 0, n_actions)
else:
# Exploit: best action from DQN with diversity penalty
q_values = self._dqn_forward_dueling(state)
# Apply action diversity penalty
for i in range(len(q_values)):
penalty = self._compute_action_diversity_penalty(i)
q_values = q_values.at[i].add(-penalty)
action = jnp.argmax(q_values)
# Track action for diversity
self.dqn_action_history.append(int(action))
if len(self.dqn_action_history) > self.dqn_action_history_size:
self.dqn_action_history.pop(0)
# Apply action to parameters
new_params = self.params.copy()
param_idx = int(action) // len(self.dqn_step_sizes)
step_idx = int(action) % len(self.dqn_step_sizes)
param_name = self.dqn_param_names[param_idx]
# Update parameter
min_val, max_val = self.param_bounds[param_name]
step_size = self.dqn_step_sizes[step_idx]
step = step_size * (max_val - min_val)
# Random direction
self.key, subkey = random.split(self.key)
direction = 1 if random.uniform(subkey) < 0.5 else -1
new_value = self.params[param_name] + direction * step
new_params[param_name] = float(jnp.clip(new_value, min_val, max_val))
# Evaluate new parameters
loss, ci = self._evaluate_params_robust(new_params)
# Shaped reward
old_loss = self.loss_history[-1] if self.loss_history else float('inf')
improvement = old_loss - loss
reward = improvement * 10.0
# Store experience
next_state = self._create_enhanced_state(new_params)
experience = (state, int(action), reward, next_state, False) # not done
self.dqn_memory.append(experience)
if len(self.dqn_memory) > self.dqn_memory_size:
self.dqn_memory.pop(0)
# Train DQN if we have enough experiences
if len(self.dqn_memory) >= self.dqn_batch_size:
self._train_dqn_improved()
# Update target network periodically
target_update_counter += 1
if target_update_counter >= self.dqn_target_update_freq:
self.dqn_target_params = {k: {kk: vv.copy() for kk, vv in v.items()}
for k, v in self.dqn_params.items()}
target_update_counter = 0
# Update parameters
self.params = new_params
# Track history
self.param_history.append(self.params.copy())
self.loss_history.append(loss)
self.confidence_intervals.append(ci)
# Update best parameters
if loss < self.best_loss:
self.best_loss = loss
self.best_params = self.params.copy()
no_improvement_count = 0
else:
no_improvement_count += 1
if verbose:
print(f"Loss: {loss:.6f} (best: {self.best_loss:.6f})")
print(f"Action: {action}, Reward: {reward:.6f}")
print(f"Epsilon: {self.dqn_epsilon:.3f}")
# Decay exploration with minimum
self.dqn_epsilon = max(self.dqn_epsilon * self.dqn_epsilon_decay, self.dqn_epsilon_min)
# Early stopping
if loss < self.tolerance:
if verbose:
print("Converged: loss below tolerance")
break
if no_improvement_count >= self.patience:
if verbose:
print("Early stopping: no improvement")
break
return self.best_params
def _train_dqn_improved(self):
"""Train the DQN using improved experience replay."""
# Sample random batch
self.key, subkey = random.split(self.key)
batch_indices = random.choice(subkey, len(self.dqn_memory),
shape=(self.dqn_batch_size,), replace=False)
batch = [self.dqn_memory[i] for i in batch_indices]
# Prepare batch data
states = jnp.array([exp[0] for exp in batch])
actions = jnp.array([exp[1] for exp in batch])
rewards = jnp.array([exp[2] for exp in batch])
next_states = jnp.array([exp[3] for exp in batch])
# Compute target Q-values using Double DQN
next_q_values_main = jnp.array([self._dqn_forward_dueling(next_state) for next_state in next_states])
next_q_values_target = jnp.array([self._dqn_forward_dueling(next_state, self.dqn_target_params)
for next_state in next_states])
# Double DQN: use main network to select actions, target network to evaluate
next_actions = jnp.argmax(next_q_values_main, axis=1)
target_q_values = rewards + self.dqn_gamma * next_q_values_target[jnp.arange(len(batch)), next_actions]
# Compute current Q-values
current_q_values = jnp.array([self._dqn_forward_dueling(state) for state in states])
current_q_selected = current_q_values[jnp.arange(len(batch)), actions]
# Compute TD errors
td_errors = target_q_values - current_q_selected
# Simplified gradient update (focusing on advantage stream)
gradients = {k: {kk: jnp.zeros_like(vv) for kk, vv in v.items()}
for k, v in self.dqn_params.items()}
for i in range(len(batch)):
state = states[i]
action = actions[i]
td_error = td_errors[i]
# Forward pass to get activations
h1 = jnp.relu(jnp.dot(state, self.dqn_params['shared1']['weights']) +
self.dqn_params['shared1']['bias'])
h2 = jnp.relu(jnp.dot(h1, self.dqn_params['shared2']['weights']) +
self.dqn_params['shared2']['bias'])
a1 = jnp.relu(jnp.dot(h2, self.dqn_params['advantage1']['weights']) +
self.dqn_params['advantage1']['bias'])
# Gradients for advantage output
grad_adv_w = jnp.zeros_like(self.dqn_params['advantage_output']['weights'])
grad_adv_b = jnp.zeros_like(self.dqn_params['advantage_output']['bias'])
grad_adv_w = grad_adv_w.at[:, action].add(td_error * a1)
grad_adv_b = grad_adv_b.at[action].add(td_error)
# Accumulate gradients
gradients['advantage_output']['weights'] += grad_adv_w / len(batch)
gradients['advantage_output']['bias'] += grad_adv_b / len(batch)
# Apply Adam optimization
self.dqn_adam_t += 1
self.dqn_params, self.dqn_adam_m, self.dqn_adam_v = self._adam_update(
self.dqn_params, gradients, self.dqn_adam_m, self.dqn_adam_v,
self.dqn_adam_t, self.dqn_learning_rate
)
[docs]
def get_calibration_history(self) -> Dict[str, List[Any]]:
"""Get calibration history.
Returns:
Dictionary with 'loss', 'params', and 'confidence_intervals' histories
"""
return {
"loss": self.loss_history,
"params": self.param_history,
"confidence_intervals": self.confidence_intervals
}
[docs]
def plot_calibration(self, figsize: Tuple[int, int] = (15, 10)) -> Any:
"""Plot comprehensive calibration results.
Args:
figsize: Figure size as (width, height)
Returns:
Matplotlib figure and axes
"""
try:
import matplotlib.pyplot as plt
import numpy as np
except ImportError:
raise ImportError("Matplotlib is required for plotting. Install it with 'pip install matplotlib'")
fig, axes = plt.subplots(2, 2, figsize=figsize)
# Plot loss history
axes[0, 0].plot(self.loss_history, 'b-', linewidth=2)
axes[0, 0].set_title(f"Loss History ({self.method.upper()})")
axes[0, 0].set_xlabel("Iteration")
axes[0, 0].set_ylabel("Loss")
axes[0, 0].grid(True, alpha=0.3)
axes[0, 0].set_yscale('log')
# Plot parameter evolution
param_names = list(self.params.keys())
for param in param_names:
values = [params[param] for params in self.param_history]
axes[0, 1].plot(values, label=param, linewidth=2)
axes[0, 1].set_title("Parameter Evolution")
axes[0, 1].set_xlabel("Iteration")
axes[0, 1].set_ylabel("Parameter Value")
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)
# Plot final metrics vs targets
if self.confidence_intervals:
final_ci = self.confidence_intervals[-1]
metrics = list(self.target_metrics.keys())
targets = [self.target_metrics[m] for m in metrics]
current_means = [(final_ci[m][0] + final_ci[m][1]) / 2 for m in metrics]
current_errors = [(final_ci[m][1] - final_ci[m][0]) / 2 for m in metrics]
x_pos = np.arange(len(metrics))
axes[1, 0].bar(x_pos - 0.2, targets, 0.4, label='Target', alpha=0.7)
axes[1, 0].errorbar(x_pos + 0.2, current_means, yerr=current_errors,
fmt='o', capsize=5, label='Current ± CI')
axes[1, 0].set_title("Final Metrics vs Targets")
axes[1, 0].set_xlabel("Metrics")
axes[1, 0].set_ylabel("Value")
axes[1, 0].set_xticks(x_pos)
axes[1, 0].set_xticklabels(metrics)
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)
# Plot convergence analysis
if len(self.loss_history) > 10:
# Moving average of loss
window = min(10, len(self.loss_history) // 4)
moving_avg = np.convolve(self.loss_history, np.ones(window)/window, mode='valid')
axes[1, 1].plot(range(window-1, len(self.loss_history)), moving_avg,
'r-', linewidth=2, label=f'Moving Avg (window={window})')
axes[1, 1].plot(self.loss_history, 'b-', alpha=0.3, label='Raw Loss')
axes[1, 1].set_title("Convergence Analysis")
axes[1, 1].set_xlabel("Iteration")
axes[1, 1].set_ylabel("Loss")
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].set_yscale('log')
plt.tight_layout()
return fig, axes
[docs]
class EnsembleCalibrator:
"""Ensemble calibrator that combines multiple optimization methods.
This class runs multiple calibration methods in parallel and combines
their results to find the best parameter set.
"""
[docs]
def __init__(
self,
model_factory: ModelFactory,
initial_params: Dict[str, float],
target_metrics: Dict[str, float],
methods: List[str] = ["adam", "es", "pso"],
**kwargs
):
"""Initialize ensemble calibrator.
Args:
model_factory: Function to create model instances
initial_params: Initial parameter values
target_metrics: Target metric values
methods: List of optimization methods to use
**kwargs: Additional arguments passed to individual calibrators
"""
self.model_factory = model_factory
self.initial_params = initial_params
self.target_metrics = target_metrics
self.methods = methods
self.kwargs = kwargs
self.calibrators = {}
self.results = {}
[docs]
def calibrate(self, verbose: bool = True) -> Dict[str, Any]:
"""Run ensemble calibration.
Args:
verbose: Whether to print progress information
Returns:
Dictionary with results from all methods and the best overall result
"""
if verbose:
print(f"Running ensemble calibration with methods: {self.methods}")
best_loss = float('inf')
best_method = None
best_params = None
for method in self.methods:
if verbose:
print(f"\n{'='*50}")
print(f"Running {method.upper()} calibration")
print(f"{'='*50}")
# Create calibrator for this method
calibrator = ModelCalibrator(
model_factory=self.model_factory,
initial_params=self.initial_params.copy(),
target_metrics=self.target_metrics,
method=method,
**self.kwargs
)
# Run calibration
try:
final_params = calibrator.calibrate(verbose=verbose)
final_loss = calibrator.best_loss
# Store results
self.calibrators[method] = calibrator
self.results[method] = {
'params': final_params,
'loss': final_loss,
'history': calibrator.get_calibration_history()
}
# Update best
if final_loss < best_loss:
best_loss = final_loss
best_method = method
best_params = final_params
if verbose:
print(f"{method.upper()} final loss: {final_loss:.6f}")
except Exception as e:
if verbose:
print(f"Error in {method}: {e}")
continue
# Store best overall result
self.results['best'] = {
'method': best_method,
'params': best_params,
'loss': best_loss
}
if verbose:
print(f"\n{'='*50}")
print("ENSEMBLE RESULTS")
print(f"{'='*50}")
print(f"Best method: {best_method}")
print(f"Best loss: {best_loss:.6f}")
print(f"Best parameters: {best_params}")
return self.results
[docs]
def plot_comparison(self, figsize: Tuple[int, int] = (15, 10)) -> Any:
"""Plot comparison of all methods.
Args:
figsize: Figure size as (width, height)
Returns:
Matplotlib figure and axes
"""
try:
import matplotlib.pyplot as plt
import numpy as np
except ImportError:
raise ImportError("Matplotlib is required for plotting. Install it with 'pip install matplotlib'")
if not self.results:
raise ValueError("Must run calibration before plotting")
fig, axes = plt.subplots(2, 2, figsize=figsize)
# Plot loss histories
for method in self.methods:
if method in self.results:
loss_history = self.results[method]['history']['loss']
axes[0, 0].plot(loss_history, label=method.upper(), linewidth=2)
axes[0, 0].set_title("Loss History Comparison")
axes[0, 0].set_xlabel("Iteration")
axes[0, 0].set_ylabel("Loss")
axes[0, 0].set_yscale('log')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
# Plot final losses
methods_with_results = [m for m in self.methods if m in self.results]
final_losses = [self.results[m]['loss'] for m in methods_with_results]
axes[0, 1].bar(methods_with_results, final_losses, alpha=0.7)
axes[0, 1].set_title("Final Loss Comparison")
axes[0, 1].set_xlabel("Method")
axes[0, 1].set_ylabel("Final Loss")
axes[0, 1].set_yscale('log')
axes[0, 1].grid(True, alpha=0.3)
# Plot parameter convergence for best method
if 'best' in self.results:
best_method = self.results['best']['method']
if best_method in self.results:
param_history = self.results[best_method]['history']['params']
param_names = list(self.initial_params.keys())
for param in param_names:
values = [params[param] for params in param_history]
axes[1, 0].plot(values, label=param, linewidth=2)
axes[1, 0].set_title(f"Parameter Evolution (Best: {best_method.upper()})")
axes[1, 0].set_xlabel("Iteration")
axes[1, 0].set_ylabel("Parameter Value")
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)
# Plot convergence speed comparison
for method in self.methods:
if method in self.results:
loss_history = self.results[method]['history']['loss']
# Normalize to show convergence speed
if len(loss_history) > 1:
normalized = np.array(loss_history) / loss_history[0]
axes[1, 1].plot(normalized, label=method.upper(), linewidth=2)
axes[1, 1].set_title("Convergence Speed Comparison")
axes[1, 1].set_xlabel("Iteration")
axes[1, 1].set_ylabel("Normalized Loss")
axes[1, 1].set_yscale('log')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)
plt.tight_layout()
return fig, axes
def compare_calibration_methods(
model_factory: ModelFactory,
initial_params: Dict[str, float],
target_metrics: Dict[str, float],
methods: List[str] = ["adam", "sgd", "es", "pso", "cem"],
param_bounds: Optional[Dict[str, Tuple[float, float]]] = None,
max_iterations: int = 50,
verbose: bool = True,
**kwargs
) -> Dict[str, Any]:
"""Compare different calibration methods on the same problem.
Args:
model_factory: Function to create model instances
initial_params: Initial parameter values
target_metrics: Target metric values
methods: List of methods to compare
param_bounds: Parameter bounds for each parameter
max_iterations: Maximum iterations for each method
verbose: Whether to print progress
**kwargs: Additional arguments passed to ModelCalibrator
Returns:
Dictionary with comparison results
"""
ensemble = EnsembleCalibrator(
model_factory=model_factory,
initial_params=initial_params,
target_metrics=target_metrics,
methods=methods,
param_bounds=param_bounds,
max_iterations=max_iterations,
**kwargs
)
results = ensemble.calibrate(verbose=verbose)
if verbose:
print("\n" + "="*60)
print("CALIBRATION METHOD COMPARISON SUMMARY")
print("="*60)
# Sort methods by performance
method_performance = []
for method in methods:
if method in results:
method_performance.append((method, results[method]['loss']))
method_performance.sort(key=lambda x: x[1])
print(f"{'Rank':<6} {'Method':<10} {'Final Loss':<15} {'Improvement':<15}")
print("-" * 60)
for i, (method, loss) in enumerate(method_performance):
if i == 0:
improvement = "Best"
else:
best_loss = method_performance[0][1]
improvement = f"{loss/best_loss:.2f}x worse"
print(f"{i+1:<6} {method.upper():<10} {loss:<15.6f} {improvement:<15}")
return results
# Example usage function
def create_calibration_example():
"""Create an example demonstrating the improved calibration capabilities."""
# This is a placeholder example - in practice, you'd use your actual model factory
def example_model_factory(params, config):
"""Example model factory for demonstration."""
# This would be replaced with your actual model creation logic
class ExampleModel:
def __init__(self, params, config):
self.params = params
self.config = config
def run(self, steps=50):
# Simulate some model behavior
import jax.numpy as jnp
import jax.random as random
key = random.PRNGKey(self.config.seed)
# Simulate metrics that depend on parameters
noise = random.normal(key, (steps,)) * 0.1
# Example: metric depends on parameter values
metric1_base = self.params.get('param1', 1.0) * 2.0
metric2_base = self.params.get('param2', 1.0) ** 2
metric1_values = [float(metric1_base + noise[i]) for i in range(steps)]
metric2_values = [float(metric2_base + noise[i]) for i in range(steps)]
return {
'metric1': metric1_values,
'metric2': metric2_values
}
return ExampleModel(params, config)
# Example usage
initial_params = {'param1': 1.0, 'param2': 2.0}
target_metrics = {'metric1': 3.0, 'metric2': 4.0}
param_bounds = {'param1': (0.1, 5.0), 'param2': (0.1, 5.0)}
print("Example: Advanced Model Calibration")
print("="*50)
# Single method calibration
print("\n1. Single Method Calibration (Adam)")
calibrator = ModelCalibrator(
model_factory=example_model_factory,
initial_params=initial_params,
target_metrics=target_metrics,
param_bounds=param_bounds,
method="adam",
max_iterations=30
)
result = calibrator.calibrate()
print(f"Final parameters: {result}")
# Ensemble calibration
print("\n2. Ensemble Calibration")
ensemble_results = compare_calibration_methods(
model_factory=example_model_factory,
initial_params=initial_params,
target_metrics=target_metrics,
methods=["adam", "es", "pso"],
max_iterations=20,
verbose=True
)
return {
'single_method': result,
'ensemble': ensemble_results
}
if __name__ == "__main__":
# Run example if script is executed directly
example_results = create_calibration_example()