calculate_memory_scalers
- torch_sim.autobatching.calculate_memory_scalers(state, memory_scales_with='n_atoms_x_density', cutoff=6.0)
Calculate a metric that estimates memory requirements for each system in a state.
Provides different scaling metrics that correlate with memory usage. Models with radial neighbor cutoffs generally scale with “n_atoms_x_density”, while models with a fixed number of neighbors scale with “n_atoms”. For molecular systems, “n_edges” gives the most accurate estimate by computing the actual neighbor list edge count using the provided cutoff. The choice of metric can significantly affect how accurately memory requirements are estimated for different types of simulation systems.
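A minimal, dependency-free sketch of the two cheaper metrics, assuming that “n_atoms_x_density” means n_atoms × (n_atoms / cell volume) for a periodic cell whose rows are the lattice vectors. The helper names (cell_volume, n_atoms_metric, n_atoms_x_density_metric) are hypothetical and are not part of torch_sim. The library may also rescale the volume by a unit factor; that would multiply every system's metric by the same constant and leave their relative ordering unchanged.

```python
def cell_volume(cell):
    """Absolute determinant of a 3x3 cell matrix (rows are lattice vectors)."""
    (a1, a2, a3), (b1, b2, b3), (c1, c2, c3) = cell
    det = (
        a1 * (b2 * c3 - b3 * c2)
        - a2 * (b1 * c3 - b3 * c1)
        + a3 * (b1 * c2 - b2 * c1)
    )
    return abs(det)


def n_atoms_metric(n_atoms):
    # Fixed-neighbor-count models: memory grows roughly linearly with atom count.
    return float(n_atoms)


def n_atoms_x_density_metric(n_atoms, cell):
    # Radial-cutoff models (assumed definition): neighbors per atom grow with
    # number density, so the estimate is n_atoms * (n_atoms / volume).
    density = n_atoms / cell_volume(cell)
    return n_atoms * density
```

With 100 atoms, a 10 Å cubic box gives 100 × 0.1 = 10.0, while a 20 Å cubic box gives 1.25, even though “n_atoms” reports 100 for both: the denser system is expected to build more neighbor pairs per atom.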
Uses vectorized operations for batched periodic states and state[i] indexing for non-periodic systems, so the batched state never needs to be split eagerly into separate per-system states.
- Parameters:
state (SimState) – State to calculate the metric for, with shape information specific to the SimState instance.
memory_scales_with ("n_atoms_x_density" | "n_atoms" | "n_edges") – Type of metric to use. “n_atoms” uses only the atom count and suits models that have a fixed number of neighbors. “n_atoms_x_density” multiplies the atom count by the number density and suits models with radial cutoffs. “n_edges” computes the actual neighbor list edge count; it is the most accurate metric overall but more expensive to compute than the alternatives, and is strongly recommended for molecular systems. Defaults to “n_atoms_x_density”.
cutoff (float) – Neighbor list cutoff distance in Angstroms. Only used when memory_scales_with="n_edges". Should match the model’s cutoff for best accuracy. Defaults to 6.0.
- Returns:
Calculated metric value for each system.
- Return type:
- Raises:
ValueError – If an invalid metric type is provided.
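The “n_edges” metric counts neighbor-list edges within cutoff. The sketch below is a brute-force O(N²) count for a single non-periodic system, assuming directed edges (each close pair counted once as i → j and once as j → i) and a strict distance < cutoff test. count_edges is a hypothetical helper; the library's own neighbor-list code, which presumably also handles periodic images, is not reproduced here. The default of 6.0 mirrors the signature above.

```python
import math
from itertools import permutations


def count_edges(positions, cutoff=6.0):
    """Count directed edges (i -> j, i != j) whose distance is below ``cutoff``.

    Brute-force sketch for a non-periodic system; ``positions`` is a sequence
    of (x, y, z) tuples in the same length unit as ``cutoff`` (Angstroms).
    """
    n_edges = 0
    # permutations(..., 2) yields every ordered pair of distinct atoms,
    # so each neighboring pair contributes two directed edges.
    for p, q in permutations(positions, 2):
        if math.dist(p, q) < cutoff:
            n_edges += 1
    return n_edges
```

For atoms at x = 0, 1.5 and 4.0 Å, a 2.0 Å cutoff yields 2 edges and a 3.0 Å cutoff yields 4. The quadratic pair loop also illustrates why this metric costs more to compute than the atom-count alternatives.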
Example:
from torch_sim.autobatching import calculate_memory_scalers

# `state` is an existing SimState holding one or more systems

# Calculate memory scaling factor based on atom count
metrics = calculate_memory_scalers(state, memory_scales_with="n_atoms")

# Calculate memory scaling factor based on atom count and density
metrics = calculate_memory_scalers(state, memory_scales_with="n_atoms_x_density")

# Calculate memory scaling factor based on actual neighbor list edge count
metrics = calculate_memory_scalers(
    state, memory_scales_with="n_edges", cutoff=5.0
)
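The note about vectorized operations for batched states can be illustrated with a flat layout in which every atom carries the index of the system it belongs to. This layout and the name per_system_density_metric are assumptions for the sketch, not torch_sim API. A single counting pass produces every system's atom count at once, so the batch never has to be split into per-system objects first. In a tensor library the counting step would typically be one bincount-style reduction; plain Python keeps the sketch dependency-free.

```python
from collections import Counter


def per_system_density_metric(system_idx, volumes):
    """n_atoms * (n_atoms / volume) for every system in a flat batched layout.

    ``system_idx[k]`` is the 0-based system that atom k belongs to, and
    ``volumes[s]`` is the cell volume of system s (hypothetical layout).
    """
    # One pass over the flat index yields all per-system atom counts.
    counts = Counter(system_idx)
    # Systems with no atoms get a count of 0 from Counter, hence a metric of 0.
    return [counts[s] * counts[s] / volumes[s] for s in range(len(volumes))]
```

For system_idx = [0, 0, 0, 1, 1] and volumes [10.0, 4.0], the result is [0.9, 1.0].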