"""Implementations of NVE integrators."""
from typing import Any
import torch
from torch_sim.integrators.md import (
MDState,
initialize_momenta,
momentum_step,
position_step,
)
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState
[docs]
def nve_init(
state: SimState,
model: ModelInterface,
*,
kT: float | torch.Tensor,
**_kwargs: Any,
) -> MDState:
"""Initialize an NVE state from input data.
Creates an initial state for NVE molecular dynamics by computing initial
energies and forces, and sampling momenta from a Maxwell-Boltzmann distribution
at the specified temperature.
To seed the RNG set ``state.rng = seed`` before calling.
Args:
model: Neural network model that computes energies and forces.
Must return a dict with 'energy' and 'forces' keys.
state: SimState containing positions, masses, cell, pbc, and other
required state variables
kT: Temperature in energy units for initializing momenta,
scalar or with shape [n_systems]
Returns:
MDState: Initialized state for NVE integration containing positions,
momenta, forces, energy, and other required attributes
Notes:
- Initial velocities sampled from Maxwell-Boltzmann distribution
- Time integration error scales as O(dt²)
"""
model_output = model(state)
momenta = getattr(
state,
"momenta",
initialize_momenta(
state.positions,
state.masses,
state.system_idx,
kT,
state.rng,
),
)
return MDState.from_state(
state,
momenta=momenta,
energy=model_output["energy"],
forces=model_output["forces"],
)
[docs]
def nve_step(
state: MDState, model: ModelInterface, *, dt: float | torch.Tensor, **_kwargs: Any
) -> MDState:
"""Perform one complete NVE (microcanonical) integration step.
This function implements the velocity Verlet algorithm for NVE dynamics,
which provides energy-conserving time evolution. The integration sequence is:
1. Half momentum update using current forces
2. Full position update using updated momenta
3. Force update at new positions
4. Half momentum update using new forces
Args:
model: Neural network model that computes energies and forces.
Must return a dict with 'energy' and 'forces' keys.
state: Current system state containing positions, momenta, forces
dt: Integration timestep, either scalar or shape [n_systems]
Returns:
MDState: Updated state after one complete NVE step with new positions,
momenta, forces, and energy
Notes:
- Uses velocity Verlet algorithm for time reversible integration
- Conserves energy in the absence of numerical errors
- Handles periodic boundary conditions if enabled in state
- Symplectic integrator preserving phase space volume
"""
dt = torch.as_tensor(dt, device=state.device, dtype=state.dtype)
state = momentum_step(state, dt / 2)
state = position_step(state, dt)
model_output = model(state)
state.energy = model_output["energy"]
state.forces = model_output["forces"]
return momentum_step(state, dt / 2)