"""Implementations of NVT integrators."""
from dataclasses import dataclass
from typing import Any
import torch
import torch_sim as ts
from torch_sim.integrators.md import (
MDState,
NoseHooverChain,
NoseHooverChainFns,
calculate_momenta,
construct_nose_hoover_chain,
momentum_step,
position_step,
velocity_verlet,
)
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState
from torch_sim.typing import StateDict
def _ou_step(
state: MDState,
dt: float | torch.Tensor,
kT: float | torch.Tensor,
gamma: float | torch.Tensor,
) -> MDState:
"""Apply stochastic noise and friction for Langevin dynamics.
This function implements the Ornstein-Uhlenbeck process for Langevin dynamics,
applying random noise and friction forces to particle momenta. The noise amplitude
is chosen to satisfy the fluctuation-dissipation theorem, ensuring proper
sampling of the canonical ensemble at temperature kT.
Args:
state (MDState): Current system state containing positions, momenta, etc.
dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems]
kT (torch.Tensor): Target temperature in energy units, either scalar or
with shape [n_systems]
gamma (torch.Tensor): Friction coefficient controlling noise strength,
either scalar or with shape [n_systems]
Returns:
MDState: Updated state with new momenta after stochastic step
Notes:
- Implements the "O" step in the BAOAB Langevin integration scheme
- Uses Ornstein-Uhlenbeck process for correct thermal sampling
- Noise amplitude scales with sqrt(mass) for equipartition
- Preserves detailed balance through fluctuation-dissipation relation
- The equation implemented is:
p(t+dt) = c1*p(t) + c2*sqrt(m)*N(0,1)
where c1 = exp(-gamma*dt) and c2 = sqrt(kT*(1-c1²))
"""
c1 = torch.exp(-gamma * dt)
if isinstance(kT, torch.Tensor) and len(kT.shape) > 0:
# kT is a tensor with shape (n_systems,)
kT = kT[state.system_idx]
# Index c1 and c2 with state.system_idx to align shapes with state.momenta
if isinstance(c1, torch.Tensor) and len(c1.shape) > 0:
c1 = c1[state.system_idx]
c2 = torch.sqrt(kT * (1 - torch.square(c1))).unsqueeze(-1)
# Generate random noise from normal distribution
noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype)
new_momenta = (
c1.unsqueeze(-1) * state.momenta
+ c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise
)
state.momenta = new_momenta
return state
[docs]
def nvt_langevin_init(
    state: SimState | StateDict,
    model: ModelInterface,
    *,
    kT: float | torch.Tensor,
    seed: int | None = None,
    **_kwargs: Any,
) -> MDState:
    """Initialize an NVT state from input data for Langevin dynamics.

    Evaluates the model once to obtain the initial energy and forces. If the
    input state does not already carry momenta, they are sampled with
    ``calculate_momenta`` (a Maxwell-Boltzmann draw at temperature ``kT``).

    Args:
        state: Either a SimState object or a dictionary containing positions,
            masses, cell, pbc, and other required state vars. A dictionary is
            converted with ``SimState(**state)``.
        model: Model that computes energies and forces. Its output must be
            indexable with the ``'energy'`` and ``'forces'`` keys.
        kT: Temperature in energy units for initializing momenta,
            either scalar or with shape [n_systems]
        seed: Random seed forwarded to ``calculate_momenta``
        **_kwargs: Accepted for signature compatibility and ignored.

    Returns:
        MDState: Initialized state for NVT integration containing positions,
        momenta, forces, energy, and other required attributes
    """
    if not isinstance(state, SimState):
        state = SimState(**state)

    model_output = model(state)

    # Sample lazily: passing calculate_momenta(...) as the getattr default would
    # evaluate it on every call, even when the state already provides momenta.
    momenta = getattr(state, "momenta", None)
    if momenta is None:
        momenta = calculate_momenta(
            state.positions, state.masses, state.system_idx, kT, seed
        )

    return MDState(
        positions=state.positions,
        momenta=momenta,
        energy=model_output["energy"],
        forces=model_output["forces"],
        masses=state.masses,
        cell=state.cell,
        pbc=state.pbc,
        system_idx=state.system_idx,
        atomic_numbers=state.atomic_numbers,
    )
[docs]
def nvt_langevin_step(
    state: MDState,
    model: ModelInterface,
    *,
    dt: float | torch.Tensor,
    kT: float | torch.Tensor,
    gamma: float | torch.Tensor | None = None,
) -> MDState:
    """Perform one complete Langevin dynamics integration step.

    Implements the BAOAB splitting scheme. The sequence is:

    1. Half momentum update using the current forces (B step)
    2. Half position update (A step)
    3. Full stochastic update with noise and friction (O step)
    4. Half position update (A step)
    5. Model evaluation to refresh energy and forces
    6. Half momentum update using the new forces (B step)

    Args:
        state: Current system state containing positions, momenta, forces
        model: Model that computes energies and forces. Its output must be
            indexable with the ``'energy'`` and ``'forces'`` keys.
        dt: Integration timestep, either scalar or shape [n_systems]
        kT: Target temperature in energy units, either scalar or
            with shape [n_systems]
        gamma: Friction coefficient for Langevin thermostat,
            either scalar or with shape [n_systems]. Defaults to 1/(100*dt).

    Returns:
        MDState: Updated state after one complete Langevin step with new
        positions, momenta, forces, and energy
    """
    device, dtype = model.device, model.dtype

    if gamma is None:
        gamma = 1 / (100 * dt)

    # Promote every non-tensor scalar (float, int, ...) so downstream tensor
    # ops such as torch.exp always receive tensors. Tensors pass through
    # unchanged, keeping whatever device/dtype the caller chose.
    if not isinstance(gamma, torch.Tensor):
        gamma = torch.tensor(gamma, device=device, dtype=dtype)
    if not isinstance(dt, torch.Tensor):
        dt = torch.tensor(dt, device=device, dtype=dtype)

    half_dt = dt / 2
    state = momentum_step(state, half_dt)
    state = position_step(state, half_dt)
    state = _ou_step(state, dt, kT, gamma)
    state = position_step(state, half_dt)

    # Forces must be refreshed at the new positions before the final kick.
    model_output = model(state)
    state.energy = model_output["energy"]
    state.forces = model_output["forces"]

    return momentum_step(state, half_dt)
[docs]
@dataclass(kw_only=True)
class NVTNoseHooverState(MDState):
    """MD state extended with Nose-Hoover chain thermostat data.

    Holds everything ``MDState`` holds plus the thermostat chain variables and
    the functions used to evolve them, for integration in the NVT (constant
    particle number, volume, temperature) ensemble.

    Attributes:
        positions: Particle positions with shape [n_particles, n_dimensions]
        masses: Particle masses with shape [n_particles]
        cell: Simulation cell matrix with shape [n_dimensions, n_dimensions]
        pbc: Whether to use periodic boundary conditions
        momenta: Particle momenta with shape [n_particles, n_dimensions]
        energy: Energy of the system
        forces: Forces on particles with shape [n_particles, n_dimensions]
        chain: State variables for the Nose-Hoover chain thermostat
        _chain_fns: Functions that initialize and advance ``chain``

    Properties:
        velocities: Particle velocities computed as momenta/masses
            Has shape [n_particles, n_dimensions]
    """

    chain: NoseHooverChain
    _chain_fns: NoseHooverChainFns

    # Both thermostat fields are registered as global attributes on top of the
    # ones MDState already declares.
    _global_attributes = (
        MDState._global_attributes | {"chain", "_chain_fns"}  # noqa: SLF001
    )

    @property
    def velocities(self) -> torch.Tensor:
        """Momenta divided by masses, with shape [n_particles, n_dimensions]."""
        per_atom_mass = self.masses.unsqueeze(-1)
        return self.momenta / per_atom_mass

    def get_number_of_degrees_of_freedom(self) -> torch.Tensor:
        """Return the parent's degrees of freedom per system, minus 3.

        The 3 removed degrees of freedom account for center of mass motion.
        """
        return super().get_number_of_degrees_of_freedom() - 3
[docs]
def nvt_nose_hoover_init(
    state: SimState | StateDict,
    model: ModelInterface,
    *,
    kT: torch.Tensor,
    dt: torch.Tensor,
    tau: torch.Tensor | None = None,
    chain_length: int = 3,
    chain_steps: int = 3,
    sy_steps: int = 3,
    seed: int | None = None,
    **kwargs: Any,
) -> NVTNoseHooverState:
    """Initialize the NVT Nose-Hoover state.

    Builds the Nose-Hoover chain functions, evaluates the model once for the
    initial energy and forces, obtains initial momenta, and initializes the
    thermostat chain from the per-system kinetic energy.

    Args:
        state: Initial system state as SimState or dict
        model: Model that computes energies and forces. Its output must be
            indexable with the ``'energy'`` and ``'forces'`` keys.
        kT: Target temperature in energy units
        dt: Integration timestep
        tau: Thermostat relaxation time (defaults to 100*dt)
        chain_length: Number of thermostats in Nose-Hoover chain (default: 3)
        chain_steps: Number of chain integration substeps (default: 3)
        sy_steps: Number of Suzuki-Yoshida steps - must be 1, 3, 5, or 7
            (default: 3)
        seed: Random seed for momenta initialization
        **kwargs: Optional overrides. ``atomic_numbers`` replaces the value
            taken from ``state``; ``momenta`` replaces the sampled momenta.

    Returns:
        Initialized NVTNoseHooverState with positions, momenta, forces,
        and thermostat chain variables
    """
    if tau is None:
        tau = dt * 100.0

    chain_fns = construct_nose_hoover_chain(
        dt, chain_length, chain_steps, sy_steps, tau
    )

    if not isinstance(state, SimState):
        state = SimState(**state)

    atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers)
    model_output = model(state)

    # Sample lazily: passing calculate_momenta(...) as the .get() default would
    # evaluate it on every call, even when the caller supplies momenta.
    momenta = kwargs.get("momenta")
    if momenta is None:
        momenta = calculate_momenta(
            state.positions, state.masses, state.system_idx, kT, seed
        )

    kinetic_energy = ts.calc_kinetic_energy(
        masses=state.masses, momenta=momenta, system_idx=state.system_idx
    )

    # Degrees of freedom per system: n_atoms * n_dimensions.
    # NOTE(review): NVTNoseHooverState.get_number_of_degrees_of_freedom subtracts
    # 3 for center of mass motion, whereas the chain is initialized with the
    # full count here - confirm that this difference is intended.
    n_atoms_per_system = torch.bincount(state.system_idx)
    dof_per_system = n_atoms_per_system * state.positions.shape[-1]

    return NVTNoseHooverState(
        positions=state.positions,
        momenta=momenta,
        energy=model_output["energy"],
        forces=model_output["forces"],
        masses=state.masses,
        cell=state.cell,
        pbc=state.pbc,
        atomic_numbers=atomic_numbers,
        system_idx=state.system_idx,
        chain=chain_fns.initialize(dof_per_system, kinetic_energy, kT),
        _chain_fns=chain_fns,  # kept on the state so the step can reuse them
    )
[docs]
def nvt_nose_hoover_step(
    state: NVTNoseHooverState,
    model: ModelInterface,
    *,
    dt: torch.Tensor,
    kT: torch.Tensor,
) -> NVTNoseHooverState:
    """Advance an NVT system by one Nose-Hoover chain step.

    The thermostat chain is evolved for half a step on either side of a full
    velocity Verlet update of the particles:

    1. Update chain masses for the target temperature
    2. First chain half-step (rescales the momenta)
    3. Full velocity Verlet step
    4. Refresh the chain's kinetic energy from the new momenta
    5. Second chain half-step (rescales the momenta)

    Args:
        state: Current system state containing positions, momenta, forces,
            the thermostat chain, and the chain functions
        model: Model that computes energies and forces
        dt: Integration timestep
        kT: Target temperature in energy units

    Returns:
        The input state, updated in place with the new momenta and chain
    """
    thermostat = state._chain_fns  # noqa: SLF001

    # Chain masses depend on the target temperature, so refresh them first.
    chain_state = thermostat.update_mass(state.chain, kT)

    # First thermostat half-step.
    state.momenta, chain_state = thermostat.half_step(
        state.momenta, chain_state, kT, state.system_idx
    )

    # Particle propagation.
    state = velocity_verlet(state=state, dt=dt, model=model)

    # The chain needs the post-Verlet kinetic energy for its second half-step.
    chain_state.kinetic_energy = ts.calc_kinetic_energy(
        masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
    )

    # Second thermostat half-step.
    state.momenta, chain_state = thermostat.half_step(
        state.momenta, chain_state, kT, state.system_idx
    )

    state.chain = chain_state
    return state
[docs]
def nvt_nose_hoover_invariant(
    state: NVTNoseHooverState,
    kT: torch.Tensor,
) -> torch.Tensor:
    """Calculate the conserved quantity for NVT ensemble with Nose-Hoover thermostat.

    The returned value is the sum, per system, of:

    1. System potential energy (``state.energy``)
    2. System kinetic energy
    3. For each chain link ``i``: ``p_i**2 / (2 * Q_i)`` plus ``kT * x_i``,
       where the first link's position term is additionally weighted by the
       number of degrees of freedom (``n_atoms * n_dimensions``)

    This quantity should remain approximately constant during simulation and is
    useful for validating the thermostat implementation.

    Args:
        state: Current state of the system including chain variables
        kT: Target temperature in energy units

    Returns:
        torch.Tensor: The conserved Hamiltonian of the extended NVT dynamics
    """
    e_kin = ts.calc_kinetic_energy(
        masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
    )

    # Degrees of freedom per system: n_atoms * n_dimensions.
    n_atoms_per_system = torch.bincount(state.system_idx)
    dof = n_atoms_per_system * state.positions.shape[-1]

    chain = state.chain
    e_tot = state.energy + e_kin

    # First chain link couples to all particle degrees of freedom. Single-element
    # chain terms need no explicit expansion: addition broadcasts them.
    first_ke = torch.square(chain.momenta[:, 0]) / (2 * chain.masses[:, 0])
    first_pe = dof * kT * chain.positions[:, 0]
    e_tot = e_tot + first_ke + first_pe

    # Every further link contributes with a single degree of freedom.
    for link in range(1, chain.positions.shape[1]):
        link_ke = torch.square(chain.momenta[:, link]) / (2 * chain.masses[:, link])
        link_pe = kT * chain.positions[:, link]
        e_tot = e_tot + link_ke + link_pe

    return e_tot
[docs]
class NVTVRescaleState(MDState):
    """MD state used by the V-Rescale (stochastic velocity rescaling) thermostat.

    Adds no fields to ``MDState``; it only changes how the degrees of freedom
    are counted for integration in the NVT (constant particle number, volume,
    temperature) ensemble.

    Attributes:
        positions: Particle positions with shape [n_particles, n_dimensions]
        masses: Particle masses with shape [n_particles]
        cell: Simulation cell matrix with shape [n_dimensions, n_dimensions]
        pbc: Whether to use periodic boundary conditions
        momenta: Particle momenta with shape [n_particles, n_dimensions]
        energy: Energy of the system
        forces: Forces on particles with shape [n_particles, n_dimensions]
    """

    def get_number_of_degrees_of_freedom(self) -> torch.Tensor:
        """Return the parent's degrees of freedom per system, minus 3.

        The 3 removed degrees of freedom account for center of mass motion.
        """
        parent_dof = super().get_number_of_degrees_of_freedom()
        return parent_dof - 3
def _vrescale_update(
state: MDState,
tau: float | torch.Tensor,
kT: float | torch.Tensor,
dt: float | torch.Tensor,
) -> MDState:
"""Update the momentum by a scaling factor as described by Eq.A7 Bussi et al.
Note that we don't implement the optimize code from Bussi, which won't be useful
on a high level framework like PyTorch.
Args:
state: Current MD state
tau: Thermostat relaxation time
kT: Target temperature
dt: Integration timestep
Returns:
Updated state with rescaled momenta
"""
device, dtype = state.device, state.dtype
# Convert all inputs to tensors
tau_tensor = torch.as_tensor(tau, device=device, dtype=dtype)
kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype)
dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype)
# Calculate current temperature per system
current_kT = state.calc_kT()
# Calculate degrees of freedom per system
dof = state.get_number_of_degrees_of_freedom()
# Ensure kT and tau have proper batch dimensions
n_systems = current_kT.shape[0]
if kT_tensor.dim() == 0:
kT_tensor = kT_tensor.expand(n_systems)
if tau_tensor.dim() == 0:
tau_tensor = tau_tensor.expand(n_systems)
# Calculate kinetic energies
KE_old = dof * current_kT / 2
KE_new = dof * kT_tensor / 2
# Generate random numbers
r1 = torch.randn(n_systems, device=device, dtype=dtype)
# Sample Gamma((dof - 1)/2, 1/2) = \sum_2^{dof} X_i^2 where X_i ~ N(0,1)
r2 = torch.distributions.Gamma((dof - 1) / 2, torch.ones_like(dof) / 2).sample()
# Calculate scaling coefficients
c1 = torch.exp(-dt_tensor / tau_tensor)
c2 = (1 - c1) * KE_new / KE_old / dof
# Calculate scaling factor
scale = c1 + (c2 * (torch.square(r1) + r2)) + (2 * r1 * torch.sqrt(c1 * c2))
lam = torch.sqrt(scale)
# Apply scaling to momenta - map from system to atom indices
state.momenta = state.momenta * lam[state.system_idx].unsqueeze(-1)
return state
[docs]
def nvt_vrescale_init(
    state: SimState | StateDict,
    model: ModelInterface,
    *,
    kT: float | torch.Tensor,
    seed: int | None = None,
    **_kwargs: Any,
) -> NVTVRescaleState:
    """Initialize an NVT state from input data for velocity rescaling dynamics.

    Evaluates the model once to obtain the initial energy and forces. If the
    input state does not already carry momenta, they are sampled with
    ``calculate_momenta`` (a Maxwell-Boltzmann draw at temperature ``kT``). The
    result is meant to be advanced with the canonical sampling through velocity
    rescaling (CSVR / V-Rescale) thermostat.

    Args:
        state: Either a SimState object or a dictionary containing positions,
            masses, cell, pbc, and other required state vars. A dictionary is
            converted with ``SimState(**state)``.
        model: Model that computes energies and forces. Its output must be
            indexable with the ``'energy'`` and ``'forces'`` keys.
        kT: Temperature in energy units for initializing momenta,
            either scalar or with shape [n_systems]
        seed: Random seed forwarded to ``calculate_momenta``
        **_kwargs: Accepted for signature compatibility and ignored.

    Returns:
        NVTVRescaleState: Initialized state for NVT integration containing
        positions, momenta, forces, energy, and other required attributes
    """
    if not isinstance(state, SimState):
        state = SimState(**state)

    model_output = model(state)

    # Sample lazily: passing calculate_momenta(...) as the getattr default would
    # evaluate it on every call, even when the state already provides momenta.
    momenta = getattr(state, "momenta", None)
    if momenta is None:
        momenta = calculate_momenta(
            state.positions, state.masses, state.system_idx, kT, seed
        )

    return NVTVRescaleState(
        positions=state.positions,
        momenta=momenta,
        energy=model_output["energy"],
        forces=model_output["forces"],
        masses=state.masses,
        cell=state.cell,
        pbc=state.pbc,
        system_idx=state.system_idx,
        atomic_numbers=state.atomic_numbers,
    )
[docs]
def nvt_vrescale_step(
    model: ModelInterface,
    state: NVTVRescaleState,
    *,
    dt: float | torch.Tensor,
    kT: float | torch.Tensor,
    tau: float | torch.Tensor | None = None,
) -> NVTVRescaleState:
    """Perform one complete V-Rescale dynamics integration step.

    The momenta are first rescaled by the stochastic V-Rescale factor, then the
    system is advanced with one velocity Verlet step.

    Note that, unlike the other step functions in this module, ``model`` comes
    before ``state`` in the positional argument order.

    Args:
        model: Model that computes energies and forces
        state: Current system state containing positions, momenta, forces
        dt: Integration timestep, either scalar or shape [n_systems]
        kT: Target temperature in energy units, either scalar or
            with shape [n_systems]
        tau: Thermostat relaxation time controlling the coupling strength,
            either scalar or with shape [n_systems]. Defaults to 100*dt.

    Returns:
        NVTVRescaleState: Updated state after one complete V-Rescale step with
        new positions, momenta, forces, and energy

    References:
        Bussi G, Donadio D, Parrinello M. "Canonical sampling through velocity
        rescaling." The Journal of chemical physics, 126(1), 014101 (2007).
    """
    device, dtype = model.device, model.dtype

    if tau is None:
        tau = 100 * dt

    # Promote every non-tensor scalar (float, int, ...) to a tensor on the
    # model's device/dtype. Tensors pass through unchanged.
    if not isinstance(tau, torch.Tensor):
        tau = torch.tensor(tau, device=device, dtype=dtype)
    if not isinstance(dt, torch.Tensor):
        dt = torch.tensor(dt, device=device, dtype=dtype)
    if not isinstance(kT, torch.Tensor):
        kT = torch.tensor(kT, device=device, dtype=dtype)

    # Thermostat first, then particle propagation.
    state = _vrescale_update(state, tau, kT, dt)
    return velocity_verlet(state=state, dt=dt, model=model)