estimate_max_memory_scaler

torch_sim.autobatching.estimate_max_memory_scaler(states, model, metric_values, **kwargs)[source]

Estimate maximum memory scaling metric that fits in GPU memory.

Tests both minimum and maximum metric states to determine a safe upper bound for the memory scaling metric. This approach ensures the estimated value works for both small, dense systems and large, sparse systems.

Parameters:
  • states (SimState | Sequence[SimState]) – Batched state or list of states. Individual systems are accessed via states[idx] (integer indexing), so only the two extreme states are materialized.

  • model (ModelInterface) – Model to test with, implementing the ModelInterface protocol.

  • metric_values (list[float]) – Corresponding metric values for each state, as calculated by calculate_memory_scalers().

  • **kwargs – Additional keyword arguments passed to determine_max_batch_size.

Returns:

Maximum safe metric value that fits in GPU memory.

Return type:

float

Example:

# Calculate metrics for a set of states
metrics = calculate_memory_scalers(states, memory_scales_with="n_atoms")

# Estimate maximum safe metric value
max_metric = estimate_max_memory_scaler(states, model, metrics)

Notes

This function tests batch sizes with both the smallest and largest systems to find a conservative estimate that works across varying system sizes. The returned value will be the minimum of the two estimates.
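The conservative strategy described above can be sketched in plain Python. This is a simplified stand-in, not the library implementation: `max_batch_for` is a hypothetical callable standing in for the batch-size probing done via determine_max_batch_size, and the metric arithmetic assumes memory scales linearly with the metric within a batch.

```python
def estimate_max_metric(metric_values, max_batch_for):
    """Conservative max-metric estimate: probe the smallest and largest
    systems separately and keep the lower of the two estimates.

    metric_values: per-state memory scaling metrics (e.g. n_atoms).
    max_batch_for(idx): hypothetical probe returning how many copies of
        state idx fit in GPU memory at once.
    """
    lo_idx = min(range(len(metric_values)), key=metric_values.__getitem__)
    hi_idx = max(range(len(metric_values)), key=metric_values.__getitem__)

    # Total metric that fits when batching many small systems...
    est_small = max_batch_for(lo_idx) * metric_values[lo_idx]
    # ...versus when batching a few large systems.
    est_large = max_batch_for(hi_idx) * metric_values[hi_idx]

    # The minimum is safe for both regimes.
    return min(est_small, est_large)


# Toy usage: three systems with metrics 10, 40, and 25; pretend the
# probe found that 8 copies of the smallest and 3 copies of the largest fit.
metrics = [10.0, 40.0, 25.0]
probe = {0: 8, 1: 3}.get
print(estimate_max_metric(metrics, probe))  # min(8 * 10, 3 * 40) = 80.0
```

In the real function the probe runs actual forward passes through the model, so the returned bound reflects the model's true memory footprint rather than an analytic estimate.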