determine_max_batch_size
torch_sim.autobatching.determine_max_batch_size(state, model, max_atoms=500_000, start_size=1, scale_factor=1.6, oom_error_message='CUDA out of memory')
Determine maximum batch size that fits in GPU memory.
Uses a geometric sequence to efficiently search for the largest number of batches that can be processed without running out of GPU memory. This function incrementally tests larger batch sizes until it encounters an out-of-memory error or reaches the specified maximum atom count.
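Schematically, the search behaves like the minimal sketch below. This is not the library's actual implementation; fits_in_memory is a hypothetical stand-in for replicating the state to the candidate size, running the model, and catching out-of-memory errors, and the rounding of the geometric sequence is an assumption:

import math

def geometric_oom_search(
    atoms_per_state: int,
    fits_in_memory,  # hypothetical: returns False when the model runs out of memory
    max_atoms: int = 500_000,
    start_size: int = 1,
    scale_factor: float = 1.6,
) -> int:
    """Sketch of a geometric search for the largest batch size."""
    # Build the geometric sequence of candidate sizes:
    # start_size, start_size * 1.6, start_size * 1.6**2, ...
    # capped so the total atom count never exceeds max_atoms.
    sizes: list[int] = []
    size = start_size
    while size * atoms_per_state <= max_atoms:
        sizes.append(size)
        size = max(size + 1, math.ceil(size * scale_factor))

    # Test candidates in increasing order; the first failure ends the search.
    largest_ok = 0
    for candidate in sizes:
        if not fits_in_memory(candidate):
            break
        largest_ok = candidate
    return largest_ok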
- Parameters:
  - state (SimState) – State to replicate for testing.
  - model (ModelInterface) – Model to test with.
  - max_atoms (int) – Upper limit on number of atoms to try (for safety). Defaults to 500,000.
  - start_size (int) – Initial batch size to test. Defaults to 1.
  - scale_factor (float) – Factor to multiply batch size by in each iteration. Defaults to 1.6.
  - oom_error_message (str | list[str]) – String or list of strings to match in RuntimeError messages to identify out-of-memory errors. Defaults to “CUDA out of memory”.
- Returns:
Maximum number of batches that fit in GPU memory.
- Return type:
int
- Raises:
RuntimeError – If a RuntimeError occurs that doesn’t match any of the specified OOM error messages.
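For example, passing a list of substrings lets the search treat out-of-memory failures from more than one backend as a stopping condition, while any other RuntimeError still propagates. In this sketch the error substrings are illustrative (the MPS message is an assumption), and sample_state and mps_model are placeholder names:

try:
    max_batches = determine_max_batch_size(
        state=sample_state,
        model=mps_model,
        oom_error_message=["CUDA out of memory", "MPS backend out of memory"],
    )
except RuntimeError:
    # Reached only for RuntimeErrors that match none of the supplied
    # OOM substrings (e.g. a genuine model bug), which are re-raised.
    raise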
Example:
# Find the maximum batch size for a Lennard-Jones model
max_batches = determine_max_batch_size(
    state=sample_state,
    model=lj_model,
    max_atoms=100_000,
)
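The returned value can then drive manual chunking of a larger workload. The sketch below assumes states is a list of single-system SimState objects and process is a hypothetical downstream step (e.g. batching the chunk and calling the model):

# Process a long list of states in chunks that are known to fit in memory.
for start in range(0, len(states), max_batches):
    chunk = states[start : start + max_batches]
    process(chunk)  # hypothetical: batch the chunk and run the model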
Notes
The function returns a batch size slightly smaller than the actual maximum (with a safety margin) to avoid operating too close to memory limits.