Source code for torch_sim.integrators.md

"""Core molecular dynamics state and operations."""

from collections.abc import Callable
from dataclasses import dataclass

import torch

from torch_sim import transforms
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState


@dataclass
class MDState(SimState):
    """Molecular-dynamics state: a SimState augmented with momenta, energy, forces.

    Adds to the base SimState the quantities an MD integrator needs to carry
    from one step to the next. Particle velocities are not stored; they are
    derived on demand from momenta and masses.

    Attributes:
        positions (torch.Tensor): Particle positions [n_particles, n_dim]
        masses (torch.Tensor): Particle masses [n_particles]
        cell (torch.Tensor): Simulation cell matrix [n_systems, n_dim, n_dim]
        pbc (bool): Whether to use periodic boundary conditions
        system_idx (torch.Tensor): System indices [n_particles]
        atomic_numbers (torch.Tensor): Atomic numbers [n_particles]
        momenta (torch.Tensor): Particle momenta [n_particles, n_dim]
        energy (torch.Tensor): Total energy of the system [n_systems]
        forces (torch.Tensor): Forces on particles [n_particles, n_dim]

    Properties:
        velocities (torch.Tensor): Particle velocities [n_particles, n_dim]
        n_systems (int): Number of independent systems in the batch
        device (torch.device): Device on which tensors are stored
        dtype (torch.dtype): Data type of tensors
    """

    momenta: torch.Tensor
    energy: torch.Tensor
    forces: torch.Tensor

    # Register the new per-atom and per-system fields on top of the parent's
    # (left operand stays the parent set so the resulting set type is kept).
    _atom_attributes = SimState._atom_attributes | {"momenta", "forces"}  # noqa: SLF001
    _system_attributes = SimState._system_attributes | {"energy"}  # noqa: SLF001

    @property
    def velocities(self) -> torch.Tensor:
        """Per-particle velocities v = p / m, shape [n_particles, n_dimensions]."""
        # Append a trailing axis so each particle's mass divides every
        # component of its momentum vector.
        per_particle_mass = self.masses[..., None]
        return self.momenta / per_particle_mass
[docs] def calculate_momenta( positions: torch.Tensor, masses: torch.Tensor, system_idx: torch.Tensor, kT: float | torch.Tensor, seed: int | None = None, ) -> torch.Tensor: """Initialize particle momenta based on temperature. Generates random momenta for particles following the Maxwell-Boltzmann distribution at the specified temperature. The center of mass motion is removed to prevent system drift. Args: positions (torch.Tensor): Particle positions [n_particles, n_dim] masses (torch.Tensor): Particle masses [n_particles] system_idx (torch.Tensor): System indices [n_particles] kT (torch.Tensor): Temperature in energy units [n_systems] seed (int, optional): Random seed for reproducibility. Defaults to None. Returns: torch.Tensor: Initialized momenta [n_particles, n_dim] """ device = positions.device dtype = positions.dtype generator = torch.Generator(device=device) if seed is not None: generator.manual_seed(seed) if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: # kT is a tensor with shape (n_systems,) kT = kT[system_idx] # Generate random momenta from normal distribution momenta = torch.randn( positions.shape, device=device, dtype=dtype, generator=generator ) * torch.sqrt(masses * kT).unsqueeze(-1) systemwise_momenta = torch.zeros( size=(int(system_idx[-1]) + 1, momenta.shape[1]), device=device, dtype=dtype ) # create 3 copies of system_idx system_idx_3 = system_idx.view(-1, 1).repeat(1, 3) bincount = torch.bincount(system_idx) mean_momenta = torch.scatter_reduce( systemwise_momenta, dim=0, index=system_idx_3, src=momenta, reduce="sum", ) / bincount.view(-1, 1) return torch.where( torch.repeat_interleave(bincount > 1, bincount).view(-1, 1), momenta - mean_momenta[system_idx], momenta, )
[docs] def momentum_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """Update particle momenta using current forces. This function performs the momentum update step of velocity Verlet integration by applying forces over the timestep dt. It implements the equation: p(t+dt) = p(t) + F(t) * dt Args: state (MDState): Current system state containing forces and momenta dt (torch.Tensor): Integration timestep, either scalar or with shape [n_systems] Returns: MDState: Updated state with new momenta after force application """ new_momenta = state.momenta + state.forces * dt state.momenta = new_momenta return state
[docs] def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """Update particle positions using current velocities. This function performs the position update step of velocity Verlet integration by propagating particles according to their velocities over timestep dt. It implements the equation: r(t+dt) = r(t) + v(t) * dt Args: state (MDState): Current system state containing positions and velocities dt (torch.Tensor): Integration timestep, either scalar or with shape [n_systems] Returns: MDState: Updated state with new positions after propagation """ new_positions = state.positions + state.velocities * dt if state.pbc: # Split positions and cells by system new_positions = transforms.pbc_wrap_batched( new_positions, state.cell, state.system_idx ) state.positions = new_positions return state
[docs] def velocity_verlet[T: MDState](state: T, dt: torch.Tensor, model: ModelInterface) -> T: """Perform one complete velocity Verlet integration step. This function implements the velocity Verlet algorithm, which provides time-reversible integration of the equations of motion. The integration sequence is: 1. Half momentum update 2. Full position update 3. Force update 4. Half momentum update Args: state: Current system state containing positions, momenta, forces dt: Integration timestep model: Neural network model that computes energies and forces Returns: Updated state after one complete velocity Verlet step Notes: - Time-reversible and symplectic integrator of second order accuracy - Conserves energy in the absence of numerical errors - Handles periodic boundary conditions if enabled in state """ dt_2 = dt / 2 state = momentum_step(state, dt_2) state = position_step(state, dt) model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] return momentum_step(state, dt_2)
@dataclass
class NoseHooverChain:
    """Container for the dynamical variables of a Nose-Hoover chain thermostat.

    A Nose-Hoover chain controls temperature deterministically by coupling
    the particle system to a sequence of fictitious thermostat degrees of
    freedom; each link of the chain carries a position, a momentum and a
    mass.

    Attributes:
        positions: Thermostat positions, shape [chain_length]
        momenta: Thermostat momenta, shape [chain_length]
        masses: Thermostat masses, shape [chain_length]
        tau: Thermostat relaxation time (scalar). Longer values give better
            stability but worse temperature control.
        kinetic_energy: Kinetic energy of the coupled system (scalar)
        degrees_of_freedom: Number of degrees of freedom in the coupled system
    """

    # Per-link thermostat variables, one entry per chain element.
    positions: torch.Tensor
    momenta: torch.Tensor
    masses: torch.Tensor
    # Scalars describing the coupling to the particle system.
    tau: torch.Tensor
    kinetic_energy: torch.Tensor
    degrees_of_freedom: int
@dataclass
class NoseHooverChainFns:
    """Bundle of the callables that drive a Nose-Hoover chain thermostat.

    Attributes:
        initialize (Callable): Builds the initial chain state
        half_step (Callable): Integrates the chain over half a timestep
        update_mass (Callable): Recomputes the chain masses
    """

    # Constructor for the initial chain state.
    initialize: Callable
    # Half-timestep propagator for the chain.
    half_step: Callable
    # Recomputes thermostat masses (e.g. after a target temperature change).
    update_mass: Callable
#: Suzuki-Yoshida composition weights for higher-order symplectic integrators.
#:
#: Keys are the number of Suzuki-Yoshida stages (``sy_steps``); values hold
#: one weight per stage, and each stage's substep length is the base substep
#: multiplied by its weight. Every weight sequence is palindromic and sums to
#: one (up to rounding), so the stages together span exactly one substep.
#:
#: These coefficients build high-order operator-splitting schemes by
#: recursively composing lower-order symplectic integrators (e.g. leapfrog)
#: while preserving symplectic structure.
#:
#: References:
#:   - M. Suzuki, *General Decomposition Theory of Ordered Exponentials*,
#:     Proc. Japan Acad. 69, 161 (1993).
#:   - H. Yoshida, *Construction of higher order symplectic integrators*,
#:     Phys. Lett. A 150, 262-268 (1990).
#:   - M. Tuckerman, *Statistical Mechanics: Theory and Molecular Simulation*,
#:     Oxford University Press (2010). Section 4.11
#:
#: :type: dict[int, torch.Tensor]
SUZUKI_YOSHIDA_WEIGHTS = {
    n_stages: torch.tensor(stage_weights)
    for n_stages, stage_weights in (
        (1, [1.0]),
        (3, [0.828981543588751, -0.657963087177502, 0.828981543588751]),
        (
            5,
            [
                0.2967324292201065,
                0.2967324292201065,
                -0.186929716880426,
                0.2967324292201065,
                0.2967324292201065,
            ],
        ),
        (
            7,
            [
                0.784513610477560,
                0.235573213359357,
                -1.17767998417887,
                1.31518632068391,
                -1.17767998417887,
                0.235573213359357,
                0.784513610477560,
            ],
        ),
    )
}
def construct_nose_hoover_chain(
    dt: torch.Tensor,
    chain_length: int,
    chain_steps: int,
    sy_steps: int,
    tau: torch.Tensor,
) -> NoseHooverChainFns:
    """Creates functions to simulate a Nose-Hoover Chain thermostat.

    Implements the direct translation method from Martyna et al. for thermal
    ensemble sampling using Nose-Hoover chains. The chains are updated using a
    symmetric splitting scheme with two half-steps per simulation step.

    The integration uses a multi-timestep approach with Suzuki-Yoshida (SY)
    splitting:
    - The chain evolution is split into nc substeps (chain_steps)
    - Each substep is further split into sy_steps
    - Each SY step has length δi = Δt*wi/nc where wi are the SY weights

    Args:
        dt: Simulation timestep
        chain_length: Number of thermostats in the chain (>= 1). A length of
            1 reduces to a single Nose-Hoover thermostat.
        chain_steps: Number of outer substeps for chain integration
        sy_steps: Number of Suzuki-Yoshida steps (must be 1, 3, 5, or 7)
        tau: Thermostat timescale, in the same time units as dt. Chain masses
            are set to Q = kT * tau**2 (multiplied by the degrees of freedom
            for the first thermostat). Larger values give better stability
            but slower equilibration.

    Returns:
        NoseHooverChainFns containing:
        - initialize: Function to create initial chain state
        - half_step: Function to evolve chain for half timestep
        - update_mass: Function to update chain masses

    Raises:
        ValueError: If ``chain_length < 1`` or ``sy_steps`` has no entry in
            ``SUZUKI_YOSHIDA_WEIGHTS``.

    References:
        Martyna et al. "Nose-Hoover chains: the canonical ensemble via
        continuous dynamics" J. Chem. Phys. 97, 2635 (1992)
    """
    # Validate documented constraints up front instead of failing later with
    # an IndexError/KeyError deep inside the integration.
    if chain_length < 1:
        raise ValueError(f"chain_length must be >= 1, got {chain_length}")
    if sy_steps not in SUZUKI_YOSHIDA_WEIGHTS:
        raise ValueError(
            f"sy_steps must be one of {sorted(SUZUKI_YOSHIDA_WEIGHTS)}, "
            f"got {sy_steps}"
        )

    def init_fn(
        degrees_of_freedom: int, KE: torch.Tensor, kT: torch.Tensor
    ) -> NoseHooverChain:
        """Initialize a Nose-Hoover chain state.

        Args:
            degrees_of_freedom: Number of degrees of freedom in coupled system
            KE: Initial kinetic energy of the system
            kT: Target temperature in energy units

        Returns:
            Initial NoseHooverChain state
        """
        device = KE.device
        dtype = KE.dtype

        xi = torch.zeros(chain_length, dtype=dtype, device=device)
        p_xi = torch.zeros(chain_length, dtype=dtype, device=device)

        # Every thermostat mass is kT * tau^2; the first thermostat couples
        # directly to the particles, so it is scaled by their DOF count.
        Q = (
            kT
            * torch.square(tau)
            * torch.ones(chain_length, dtype=dtype, device=device)
        )
        Q[0] *= degrees_of_freedom

        return NoseHooverChain(xi, p_xi, Q, tau, KE, degrees_of_freedom)

    def substep_fn(
        delta: torch.Tensor, P: torch.Tensor, state: NoseHooverChain, kT: torch.Tensor
    ) -> tuple[torch.Tensor, NoseHooverChain, torch.Tensor]:
        """Perform single update of chain parameters and rescale velocities.

        The input ``state`` is not modified; a new chain state is returned.

        Args:
            delta: Integration timestep for this substep
            P: System momenta to be rescaled
            state: Current chain state
            kT: Target temperature

        Returns:
            Tuple of (rescaled momenta, updated chain state, temperature)
        """
        xi, p_xi, Q, _tau, KE, DOF = (
            state.positions,
            state.momenta,
            state.masses,
            state.tau,
            state.kinetic_energy,
            state.degrees_of_freedom,
        )
        # The updates below write into p_xi element-wise; work on a copy so
        # the caller's chain state is not left half-updated.
        p_xi = p_xi.clone()

        delta_2 = delta / 2.0
        delta_4 = delta_2 / 2.0
        delta_8 = delta_4 / 2.0

        M = chain_length - 1

        if M > 0:
            # Update chain momenta backwards
            G = torch.square(p_xi[M - 1]) / Q[M - 1] - kT
            p_xi[M] += delta_4 * G

            for m in range(M - 1, 0, -1):
                G = torch.square(p_xi[m - 1]) / Q[m - 1] - kT
                scale = torch.exp(-delta_8 * p_xi[m + 1] / Q[m + 1])
                p_xi[m] = scale * (scale * p_xi[m] + delta_4 * G)

            # The first thermostat's kick is damped by the second one.
            scale = torch.exp(-delta_8 * p_xi[1] / Q[1])
        else:
            # Single thermostat: no higher link to damp the kick.
            scale = 1.0

        # Update system coupling
        G = 2.0 * KE - DOF * kT
        p_xi[0] = scale * (scale * p_xi[0] + delta_4 * G)

        # Rescale system momenta
        scale = torch.exp(-delta_2 * p_xi[0] / Q[0])
        KE = KE * torch.square(scale)
        P = P * scale

        # Update positions
        xi = xi + delta_2 * p_xi / Q

        # Update chain momenta forwards (for M == 0 only the final kick runs)
        G = 2.0 * KE - DOF * kT
        for m in range(M):
            scale = torch.exp(-delta_8 * p_xi[m + 1] / Q[m + 1])
            p_xi[m] = scale * (scale * p_xi[m] + delta_4 * G)
            G = torch.square(p_xi[m]) / Q[m] - kT
        p_xi[M] += delta_4 * G

        return P, NoseHooverChain(xi, p_xi, Q, _tau, KE, DOF), kT

    def half_step_chain_fn(
        P: torch.Tensor, state: NoseHooverChain, kT: torch.Tensor
    ) -> tuple[torch.Tensor, NoseHooverChain]:
        """Evolve chain for half timestep using multi-timestep integration.

        Args:
            P: System momenta to be rescaled
            state: Current chain state
            kT: Target temperature

        Returns:
            Tuple of (rescaled momenta, updated chain state)
        """
        if chain_steps == 1 and sy_steps == 1:
            P, state, _ = substep_fn(dt, P, state, kT)
            return P, state

        delta = dt / chain_steps
        weights = SUZUKI_YOSHIDA_WEIGHTS[sy_steps]

        # chain_steps outer substeps, each cycling through the SY weights.
        for step in range(chain_steps * sy_steps):
            d = delta * weights[step % sy_steps]
            P, state, _ = substep_fn(d, P, state, kT)

        return P, state

    def update_chain_mass_fn(
        chain_state: NoseHooverChain, kT: torch.Tensor
    ) -> NoseHooverChain:
        """Update chain masses to maintain target oscillation period.

        Args:
            chain_state: Current chain state
            kT: Target temperature

        Returns:
            Updated chain state with new masses
        """
        device = chain_state.positions.device
        dtype = chain_state.positions.dtype

        # Same mass rule as init_fn, re-evaluated for the given kT.
        Q = (
            kT
            * torch.square(chain_state.tau)
            * torch.ones(chain_length, dtype=dtype, device=device)
        )
        Q[0] *= chain_state.degrees_of_freedom

        return NoseHooverChain(
            chain_state.positions,
            chain_state.momenta,
            Q,
            chain_state.tau,
            chain_state.kinetic_energy,
            chain_state.degrees_of_freedom,
        )

    return NoseHooverChainFns(init_fn, half_step_chain_fn, update_chain_mass_fn)