Source code for torch_sim.integrators.npt

"""Implementations of NPT integrators."""

import logging
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, cast

import torch

import torch_sim as ts
from torch_sim._duecredit import dcite
from torch_sim.integrators.md import (
    MDState,
    NoseHooverChain,
    NoseHooverChainFns,
    construct_nose_hoover_chain,
    initialize_momenta,
    momentum_step,
)
from torch_sim.integrators.nvt import _vrescale_update
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState
from torch_sim.units import MetalUnits


logger = logging.getLogger(__name__)


def _randn_for_state(state: MDState, shape: torch.Size | tuple[int, ...]) -> torch.Tensor:
    """Return standard-normal samples placed on ``state``'s device and dtype.

    The samples are drawn with ``state.rng`` as the torch generator.

    Args:
        state: State supplying the device, dtype and random generator.
        shape: Shape of the returned tensor.

    Returns:
        torch.Tensor: Tensor of i.i.d. N(0, 1) samples with the requested shape.
    """
    sample_options = {
        "device": state.device,
        "dtype": state.dtype,
        "generator": state.rng,
    }
    return torch.randn(shape, **sample_options)


@dataclass(kw_only=True)
class NPTState(MDState):
    """Base state for constant-pressure (NPT) molecular dynamics.

    Adds the stress tensor required by constant-pressure simulations on top of
    :class:`MDState`. Integrator-specific NPT states (for example
    ``NPTLangevinAnisotropicState`` or ``NPTNoseHooverIsotropicState``) derive
    from this class and add their own auxiliary variables.

    Attributes:
        stress (torch.Tensor): Stress tensor [n_systems, n_dim, n_dim]
    """

    # Per-system stress tensor.
    stress: torch.Tensor

    # Register ``stress`` as a per-system attribute in addition to MDState's.
    _system_attributes = MDState._system_attributes | {"stress"}  # noqa: SLF001
@dataclass(kw_only=True)
class NPTLangevinAnisotropicState(NPTState):
    """State for NPT Langevin dynamics with an independent length per dimension.

    Every spatial dimension i carries its own logarithmic strain coordinate
    εi = ln(Li/Li0), which is driven by the matching diagonal pressure
    component P_ii. The scheme is analogous to LAMMPS ``fix press/langevin``
    with ``couple none``. When all three target pressures are equal, the summed
    per-dimension forces equal the isotropic strain force, so the isotropic
    limit is recovered.

    Attributes:
        positions (torch.Tensor): Particle positions [n_particles, n_dim]
        velocities (torch.Tensor): Particle velocities [n_particles, n_dim]
        energy (torch.Tensor): Energy of the system [n_systems]
        forces (torch.Tensor): Forces on particles [n_particles, n_dim]
        masses (torch.Tensor): Particle masses [n_particles]
        cell (torch.Tensor): Simulation cell matrix [n_systems, n_dim, n_dim]
        pbc (bool): Whether to use periodic boundary conditions
        system_idx (torch.Tensor): System indices [n_particles]
        atomic_numbers (torch.Tensor): Atomic numbers [n_particles]
        stress (torch.Tensor): Stress tensor [n_systems, n_dim, n_dim]
        reference_cell (torch.Tensor): Original cell [n_systems, d, d]
        cell_positions (torch.Tensor): Per-dimension strain εi [n_systems, 3]
        cell_velocities (torch.Tensor): dεi/dt [n_systems, 3]
        cell_masses (torch.Tensor): Mass for strain DOFs [n_systems]
        alpha (torch.Tensor): Particle friction [n_systems]
        cell_alpha (torch.Tensor): Cell friction [n_systems]
        b_tau (torch.Tensor): Barostat time constant [n_systems]

    Properties:
        momenta (torch.Tensor): Particle momenta calculated as velocities*masses
            with shape [n_particles, n_dimensions]
        current_cell (torch.Tensor): Cell reconstructed from strain and
            reference_cell
        volume (torch.Tensor): Current volume from cell determinant
        n_systems (int): Number of independent systems in the batch
        device (torch.device): Device on which tensors are stored
        dtype (torch.dtype): Data type of tensors
    """

    # Thermostat / barostat parameters (one value per system).
    alpha: torch.Tensor
    cell_alpha: torch.Tensor
    b_tau: torch.Tensor

    # Cell degrees of freedom.
    reference_cell: torch.Tensor
    cell_positions: torch.Tensor  # (n_systems, 3) per-dimension strain
    cell_velocities: torch.Tensor  # (n_systems, 3)
    cell_masses: torch.Tensor

    _system_attributes = NPTState._system_attributes | {  # noqa: SLF001
        "alpha",
        "b_tau",
        "cell_alpha",
        "cell_masses",
        "cell_positions",
        "cell_velocities",
        "reference_cell",
    }

    @property
    def current_cell(self) -> torch.Tensor:
        """Cell rebuilt from per-dimension strain: row i is exp(εi) · ref[i, :]."""
        row_scale = torch.exp(self.cell_positions)[..., None]  # (n_systems, 3, 1)
        return row_scale * self.reference_cell

    @property
    def volume(self) -> torch.Tensor:
        """Current volume, taken as the determinant of ``cell``."""
        cell_matrix = self.cell
        return torch.linalg.det(cell_matrix)
def _npt_langevin_particle_beta(
    state: "NPTLangevinAnisotropicState | NPTLangevinIsotropicState",
    kT: torch.Tensor,
    dt: torch.Tensor,
) -> torch.Tensor:
    """Calculate random noise term for particle Langevin dynamics.

    This function generates the stochastic force term for the Langevin thermostat
    according to the fluctuation-dissipation theorem, ensuring proper thermal
    sampling at the target temperature. Only particle degrees of freedom are
    involved (not cell DOFs), so it works for both isotropic and anisotropic cell
    dynamics.

    Args:
        state (NPTLangevinAnisotropicState | NPTLangevinIsotropicState): Current
            NPT state
        kT (torch.Tensor): Temperature in energy units, either scalar or shape
            [n_systems]
        dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems]

    Returns:
        torch.Tensor: Random noise term for force calculation [n_particles, n_dim]
    """
    noise = _randn_for_state(state, state.momenta.shape)

    # Broadcast scalar kT / dt to one value per system, then gather every
    # per-system quantity onto atoms. A per-system dt ([n_systems]) must be
    # gathered too, otherwise it cannot broadcast against the per-atom terms.
    batch_kT = kT.expand(state.n_systems) if kT.ndim == 0 else kT
    batch_dt = dt.expand(state.n_systems) if dt.ndim == 0 else dt
    atom_kT = batch_kT[state.system_idx]
    atom_dt = batch_dt[state.system_idx]
    atom_alpha = state.alpha[state.system_idx]

    # The standard deviation is sqrt(2 * alpha * kB*T * dt) for each atom.
    prefactor = torch.sqrt(2 * atom_alpha * atom_kT * atom_dt)
    return prefactor.unsqueeze(-1) * noise


def _npt_langevin_anisotropic_cell_beta(
    state: NPTLangevinAnisotropicState,
    kT: torch.Tensor,
    dt: torch.Tensor,
) -> torch.Tensor:
    """Generate per-dimension noise for cell length fluctuations.

    Args:
        state: Current NPT state
        kT: Temperature in energy units (scalar or [n_systems])
        dt: Timestep (scalar or [n_systems])

    Returns:
        torch.Tensor: Noise [n_systems, 3]
    """
    noise = _randn_for_state(state, (state.n_systems, 3))
    batch_kT = kT if kT.ndim > 0 else kT.expand(state.n_systems)
    dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems)
    # Fluctuation-dissipation amplitude using the cell friction.
    scaling = torch.sqrt(2.0 * state.cell_alpha * batch_kT * dt_expanded)
    return scaling.unsqueeze(-1) * noise


def _npt_langevin_anisotropic_cell_position_step(
    state: NPTLangevinAnisotropicState,
    dt: torch.Tensor,
    strain_force: torch.Tensor,
    cell_beta: torch.Tensor,
) -> NPTLangevinAnisotropicState:
    """GJF position step for per-dimension strain εi.

    εi_{n+1} = εi_n + b·dt·dεi/dt + b·dt²·F_εi/(2Q) + b·dt·β/(2Q)

    Args:
        state: Current NPT state
        dt: Timestep
        strain_force: F_εi [n_systems, 3]
        cell_beta: Noise [n_systems, 3]

    Returns:
        Updated state with new cell_positions (strain)
    """
    Q_2 = (2 * state.cell_masses).unsqueeze(-1)  # (n_systems, 1)
    dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems)
    dt_3 = dt_expanded.unsqueeze(-1) if dt_expanded.ndim > 0 else dt_expanded
    cell_b = 1 / (1 + (state.cell_alpha.unsqueeze(-1) * dt_3) / Q_2)
    c_1 = cell_b * dt_3 * state.cell_velocities
    c_2 = cell_b * dt_3 * dt_3 * strain_force / Q_2
    c_3 = cell_b * dt_3 * cell_beta / Q_2
    state.cell_positions = state.cell_positions + c_1 + c_2 + c_3
    return state


def _npt_langevin_anisotropic_cell_velocity_step(
    state: NPTLangevinAnisotropicState,
    F_eps_n: torch.Tensor,
    dt: torch.Tensor,
    strain_force: torch.Tensor,
    cell_beta: torch.Tensor,
) -> NPTLangevinAnisotropicState:
    """GJF velocity step for per-dimension strain εi.

    dεi/dt_{n+1} = a·dεi/dt_n + dt/(2Q)·(a·F_εi^n + F_εi^{n+1}) + b·β/Q

    Args:
        state: Current NPT state
        F_eps_n: Initial strain force [n_systems, 3]
        dt: Timestep
        strain_force: Final strain force [n_systems, 3]
        cell_beta: Noise (SAME as in position step) [n_systems, 3]

    Returns:
        Updated state with new cell_velocities
    """
    dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems)
    dt_3 = dt_expanded.unsqueeze(-1) if dt_expanded.ndim > 0 else dt_expanded
    Q = state.cell_masses.unsqueeze(-1)  # (n_systems, 1)
    alpha_c = state.cell_alpha.unsqueeze(-1)  # (n_systems, 1)
    a = (1 - (alpha_c * dt_3) / (2 * Q)) / (1 + (alpha_c * dt_3) / (2 * Q))
    b = 1 / (1 + (alpha_c * dt_3) / (2 * Q))
    c_1 = a * state.cell_velocities
    c_2 = dt_3 * ((a * F_eps_n) + strain_force) / (2 * Q)
    c_3 = b * cell_beta / Q
    state.cell_velocities = c_1 + c_2 + c_3
    return state


def _npt_langevin_anisotropic_position_step(
    state: NPTLangevinAnisotropicState,
    eps_old: torch.Tensor,
    dt: torch.Tensor,
    particle_beta: torch.Tensor,
) -> NPTLangevinAnisotropicState:
    """Update particle positions with per-dimension strain scaling.

    Each component of position is scaled by exp(εi_new - εi_old).

    Args:
        state: Current state (cell_positions already updated)
        eps_old: Previous strain [n_systems, 3]
        dt: Timestep
        particle_beta: Noise [n_particles, n_dim]

    Returns:
        Updated state with new positions
    """
    M_2 = 2 * state.masses.unsqueeze(-1)  # (n_atoms, 1)

    # Per-dimension scale factor
    scale = torch.exp(state.cell_positions - eps_old)  # (n_systems, 3)
    scale_atoms = scale[state.system_idx]  # (n_atoms, 3)

    # Damping factor
    alpha_atoms = state.alpha[state.system_idx]
    dt_atoms = dt
    if dt.ndim > 0:
        dt_atoms = dt[state.system_idx]
    b = 1 / (1 + ((alpha_atoms * dt_atoms) / (2 * state.masses)))

    # Scale each position component independently
    c_1 = scale_atoms * state.positions  # (n_atoms, 3)
    # Time step factor: 2·s/(s+1) per dimension
    c_2 = (2 * scale_atoms / (scale_atoms + 1)) * b.unsqueeze(-1) * dt_atoms.unsqueeze(-1)
    c_3 = (
        state.velocities
        + dt_atoms.unsqueeze(-1) * state.forces / M_2
        + particle_beta / M_2
    )
    state.set_constrained_positions(c_1 + c_2 * c_3)
    return state


def _npt_langevin_particle_velocity_step(
    state: "NPTLangevinAnisotropicState | NPTLangevinIsotropicState",
    forces: torch.Tensor,
    dt: torch.Tensor,
    particle_beta: torch.Tensor,
) -> "NPTLangevinAnisotropicState | NPTLangevinIsotropicState":
    """Update the particle velocities in NPT dynamics.

    This function updates particle velocities using a Langevin-type integrator,
    accounting for both deterministic forces and pre-generated thermal noise.
    Only particle degrees of freedom are involved (not cell DOFs), so it works
    for both isotropic and anisotropic cell dynamics.

    Args:
        state (NPTLangevinAnisotropicState | NPTLangevinIsotropicState): Current
            NPT state
        forces: Forces on particles (from before position update)
        dt: Integration timestep, either scalar or with shape [n_systems]
        particle_beta (torch.Tensor): Pre-generated GJF noise term β for particle
            dynamics. Must be the SAME realization used in the position step.
            Shape [n_particles, n_dim]

    Returns:
        Updated state with new velocities (same type as input).
    """
    # Denominator for the update equations.
    M_2 = 2 * state.masses  # shape: (n_atoms,)

    # Map batch parameters to atom level
    alpha_atoms = state.alpha[state.system_idx]
    dt_atoms = dt
    if dt.ndim > 0:
        dt_atoms = dt[state.system_idx]

    # Calculate damping factors for Langevin integration
    a = (1 - (alpha_atoms * dt_atoms) / M_2) / (1 + (alpha_atoms * dt_atoms) / M_2)
    a = a.unsqueeze(-1)
    b = 1 / (1 + (alpha_atoms * dt_atoms) / M_2).unsqueeze(-1)

    # Velocity contribution with damping
    c_1 = a * state.velocities
    # Force contribution: `forces` are the old forces, state.forces the new ones.
    c_2 = dt_atoms.unsqueeze(-1) * ((a * forces) + state.forces) / M_2.unsqueeze(-1)
    # GJF noise term: b * β / m
    c_3 = b * particle_beta / state.masses.unsqueeze(-1)

    new_velocities = c_1 + c_2 + c_3
    # Store as momenta (velocities * masses) so constraints are applied.
    state.set_constrained_momenta(new_velocities * state.masses.unsqueeze(-1))
    return state


def _npt_langevin_anisotropic_compute_cell_force(
    state: NPTLangevinAnisotropicState,
    external_pressure: torch.Tensor,
    kT: torch.Tensor,
) -> torch.Tensor:
    """Compute per-dimension force on the strain coordinates.

    F_εi = V · (P_ii - P_ext_i)

    where P_ii = -σ_ii + N·kT/V is the ii diagonal pressure component.
    The force is in energy units (eV).

    Args:
        state: Current NPT state
        external_pressure: Target pressure per dimension [3] or [n_systems, 3]
        kT: Temperature in energy units (scalar or [n_systems])

    Returns:
        torch.Tensor: Force per dimension [n_systems, 3]
    """
    volumes = state.volume  # (n_systems,)

    # Diagonal stress components sigma_ii
    stress_diag = torch.diagonal(state.stress, dim1=-2, dim2=-1)  # (n_systems, 3)
    # P_ii = -sigma_ii (virial part)
    P_virial_diag = -stress_diag  # (n_systems, 3)

    # Kinetic contribution per dimension: N·kT/V (target temperature)
    batch_kT = kT if kT.ndim > 0 else kT.expand(state.n_systems)
    n_atoms = state.n_atoms_per_system.to(dtype=state.dtype)
    kinetic_pressure = (n_atoms * batch_kT / volumes).unsqueeze(-1)  # (n_systems, 1)

    P_diag = P_virial_diag + kinetic_pressure  # (n_systems, 3)

    # F_εi = V · (P_ii - P_ext_i)
    return volumes.unsqueeze(-1) * (P_diag - external_pressure)
def npt_langevin_anisotropic_init(
    state: SimState,
    model: ModelInterface,
    *,
    kT: float | torch.Tensor,
    dt: float | torch.Tensor,
    alpha: float | torch.Tensor | None = None,
    cell_alpha: float | torch.Tensor | None = None,
    b_tau: float | torch.Tensor | None = None,
    **_kwargs: Any,
) -> NPTLangevinAnisotropicState:
    """Initialize NPT Langevin state with independent per-dimension cell lengths.

    Each spatial dimension gets its own strain DOF εi = ln(Li/Li0), driven by the
    corresponding diagonal pressure component.

    To seed the RNG set ``state.rng = seed`` before calling.

    Args:
        state: SimState containing positions, masses, cell, pbc
        model: Model computing energy, forces, stress
        kT: Target temperature in energy units
        dt: Integration timestep
        alpha: Particle friction. Defaults to 1/(5·dt).
        cell_alpha: Cell friction. Defaults to 1/(30·dt).
        b_tau: Barostat time constant. Defaults to 300·dt.

    Returns:
        NPTLangevinAnisotropicState with εi = 0 for all dimensions
    """
    device, dtype = model.device, model.dtype

    # Defaults scale with the timestep.
    if alpha is None:
        alpha = 1.0 / (5 * dt)
    if cell_alpha is None:
        cell_alpha = 1.0 / (30 * dt)
    if b_tau is None:
        b_tau = 300 * dt

    alpha = torch.as_tensor(alpha, device=device, dtype=dtype)
    cell_alpha = torch.as_tensor(cell_alpha, device=device, dtype=dtype)
    b_tau = torch.as_tensor(b_tau, device=device, dtype=dtype)
    kT = torch.as_tensor(kT, device=device, dtype=dtype)
    dt = torch.as_tensor(dt, device=device, dtype=dtype)

    # Scalar parameters are broadcast to one value per system.
    if alpha.ndim == 0:
        alpha = alpha.expand(state.n_systems)
    if cell_alpha.ndim == 0:
        cell_alpha = cell_alpha.expand(state.n_systems)
    if b_tau.ndim == 0:
        b_tau = b_tau.expand(state.n_systems)

    model_output = model(state)

    # Reuse existing momenta if the input state has them; otherwise sample at kT.
    momenta = getattr(state, "momenta", None)
    if momenta is None:
        momenta = initialize_momenta(
            state.positions,
            state.masses,
            state.system_idx,
            kT,
            state.rng,
        )

    reference_cell = state.cell.clone()
    dim = state.positions.shape[1]

    # εi = 0 at initialization (V = V₀)
    cell_positions = torch.zeros(state.n_systems, dim, device=device, dtype=dtype)
    cell_velocities = torch.zeros(state.n_systems, dim, device=device, dtype=dtype)

    batch_kT = kT.expand(state.n_systems) if kT.ndim == 0 else kT
    # minlength keeps one entry per system even if trailing systems are empty,
    # so the result always matches the [n_systems] shape of kT and b_tau.
    n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems)
    cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau

    if state.constraints:
        msg = (
            "Constraints are present in the system. "
            "Make sure they are compatible with NPT Langevin dynamics. "
            "We recommend not using constraints with NPT dynamics for now."
        )
        warnings.warn(msg, UserWarning, stacklevel=3)
        logger.warning(msg)

    # Create the initial state
    npt_state = NPTLangevinAnisotropicState.from_state(
        state,
        momenta=momenta,
        energy=model_output["energy"],
        forces=model_output["forces"],
        stress=model_output["stress"],
        alpha=alpha,
        b_tau=b_tau,
        reference_cell=reference_cell,
        cell_positions=cell_positions,
        cell_velocities=cell_velocities,
        cell_masses=cell_masses,
        cell_alpha=cell_alpha,
    )
    npt_state.store_model_extras(model_output)
    return npt_state
def _broadcast_anisotropic_external_pressure(
    external_pressure: torch.Tensor, n_systems: int
) -> torch.Tensor:
    """Broadcast a target pressure to per-system, per-dimension shape [n_systems, 3].

    Accepted shapes:
        * scalar: same pressure on every dimension of every system;
        * [3]: per-dimension pressure shared by all systems. This reading
          takes precedence when ``n_systems == 3``;
        * [n_systems] or [1]: one pressure per system, applied to all
          dimensions;
        * any 2D shape broadcastable to [n_systems, 3], e.g. [n_systems, 3].

    Raises:
        ValueError: If the pressure cannot be broadcast to [n_systems, 3].
    """
    if external_pressure.ndim == 0:
        return external_pressure.expand(n_systems, 3)
    if external_pressure.ndim == 1:
        length = external_pressure.shape[0]
        if length == 3:
            return external_pressure.unsqueeze(0).expand(n_systems, 3)
        if length in (1, n_systems):
            return external_pressure.unsqueeze(-1).expand(n_systems, 3)
    elif external_pressure.ndim == 2:
        try:
            return torch.broadcast_to(external_pressure, (n_systems, 3))
        except RuntimeError as err:
            raise ValueError(
                f"External pressure with shape {tuple(external_pressure.shape)} "
                f"cannot be broadcast to (n_systems={n_systems}, 3)."
            ) from err
    raise ValueError(
        f"External pressure with shape {tuple(external_pressure.shape)} is not "
        "supported. Expected a scalar, shape [3], shape [n_systems], or shape "
        "[n_systems, 3]."
    )


@dcite("10.1063/1.4901303")
def npt_langevin_anisotropic_step(
    state: NPTLangevinAnisotropicState,
    model: ModelInterface,
    *,
    dt: float | torch.Tensor,
    kT: float | torch.Tensor,
    external_pressure: float | torch.Tensor,
) -> NPTLangevinAnisotropicState:
    r"""Perform one NPT Langevin step with independent per-dimension cell lengths.

    Implements constant-pressure Langevin dynamics based on Gronbech-Jensen &
    Farago (2014) [4]_ and the LAMMPS ``fix press/langevin`` scheme [5]_.
    Each spatial dimension *i* has its own logarithmic strain
    :math:`\varepsilon_i = \ln(L_i/L_{i,0})` driven by the diagonal pressure
    component :math:`P_{ii}`.

    **Per-dimension strain force:**

    .. math::

        F_{\varepsilon_i} = V \cdot (P_{ii} - P_{\text{ext},i})

    where :math:`P_{ii} = -\sigma_{ii} + N k_B T / V`. With three identical
    target pressures the sum :math:`\sum_i F_{\varepsilon_i}` equals the
    isotropic strain force.

    **Cell reconstruction:**

    .. math::

        \mathbf{h}_i = e^{\varepsilon_i}\,\mathbf{h}_{i,0}

    **Particle scaling (per component):**

    .. math::

        r_{k,i} \to e^{\varepsilon_i^{n+1} - \varepsilon_i^n}\, r_{k,i}

    Args:
        state: Current NPT state
        model: Model computing energy, forces, stress
        dt: Integration timestep
        kT: Target temperature in energy units
        external_pressure: Target pressure — scalar (same for all dims),
            shape [3] (per-dimension), shape [n_systems] (per-system, same for
            all dims), or [n_systems, 3]. A length-3 vector is read as
            per-dimension even when ``n_systems == 3``.

    Returns:
        NPTLangevinAnisotropicState: Updated state

    Raises:
        ValueError: If ``external_pressure`` cannot be broadcast to
            [n_systems, 3].

    References:
        .. [4] Gronbech-Jensen, N. & Farago, O. "Constant pressure and
           temperature discrete-time Langevin molecular dynamics."
           J. Chem. Phys. 141(19) (2014).
        .. [5] LAMMPS fix press/langevin:
           https://docs.lammps.org/fix_press_langevin.html
    """
    device, dtype = model.device, model.dtype

    state.alpha = torch.as_tensor(state.alpha, device=device, dtype=dtype)
    kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype)
    state.cell_alpha = torch.as_tensor(state.cell_alpha, device=device, dtype=dtype)
    dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype)

    # Broadcast external_pressure to (n_systems, 3), rejecting bad shapes early.
    external_pressure_tensor = _broadcast_anisotropic_external_pressure(
        torch.as_tensor(external_pressure, device=device, dtype=dtype),
        state.n_systems,
    )

    batch_kT = kT_tensor.expand(state.n_systems) if kT_tensor.ndim == 0 else kT_tensor

    # Update barostat mass. minlength keeps one entry per system even when
    # trailing systems are empty, matching the [n_systems] shape of b_tau.
    n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems)
    state.cell_masses = (n_atoms_per_system + 1) * batch_kT * torch.square(state.b_tau)

    # Store initial values
    forces = state.forces
    eps_old = state.cell_positions.clone()

    F_eps_n = _npt_langevin_anisotropic_compute_cell_force(
        state=state,
        external_pressure=external_pressure_tensor,
        kT=kT_tensor,
    )

    # Generate GJF noise ONCE; the same realizations are reused in the
    # position and velocity half of the step.
    cell_beta = _npt_langevin_anisotropic_cell_beta(state, kT_tensor, dt_tensor)
    particle_beta = _npt_langevin_particle_beta(state, kT_tensor, dt_tensor)

    # Step 1: Update per-dimension strain
    state = _npt_langevin_anisotropic_cell_position_step(
        state, dt_tensor, F_eps_n, cell_beta
    )

    # Reconstruct cell from updated strain
    state.cell = state.current_cell

    # Step 2: Update particle positions
    state = _npt_langevin_anisotropic_position_step(
        state, eps_old, dt_tensor, particle_beta
    )

    # Recompute model output
    model_output = model(state)
    state.energy = model_output["energy"]
    state.forces = model_output["forces"]
    state.stress = model_output["stress"]
    state.store_model_extras(model_output)

    # Updated strain force
    F_eps_new = _npt_langevin_anisotropic_compute_cell_force(
        state=state,
        external_pressure=external_pressure_tensor,
        kT=kT_tensor,
    )

    # Step 3: Update strain velocities (uses SAME cell_beta)
    state = _npt_langevin_anisotropic_cell_velocity_step(
        state, F_eps_n, dt_tensor, F_eps_new, cell_beta
    )

    # Step 4: Update particle velocities (uses SAME particle_beta)
    return cast(
        "NPTLangevinAnisotropicState",
        _npt_langevin_particle_velocity_step(state, forces, dt_tensor, particle_beta),
    )
# =============================================================================
# NPT Langevin integrator with an isotropic logarithmic strain coordinate
# =============================================================================
@dataclass(kw_only=True)
class NPTLangevinIsotropicState(NPTState):
    """State for NPT Langevin dynamics driven by an isotropic log-strain coordinate.

    The single cell degree of freedom is the dimensionless isotropic strain
    ε = (1/d)·ln(V/V₀). Parameterizing the cell this way keeps V > 0, and the
    conjugate force F_ε = d·V·(P_avg - P_ext) is an energy, which keeps the
    dynamics numerically well scaled.

    Attributes:
        reference_cell (torch.Tensor): Original cell [n_systems, d, d]
        cell_positions (torch.Tensor): Strain ε = (1/d)·ln(V/V₀) [n_systems]
        cell_velocities (torch.Tensor): dε/dt [n_systems]
        cell_masses (torch.Tensor): Mass for strain DOF [n_systems]
        alpha (torch.Tensor): Particle friction [n_systems]
        cell_alpha (torch.Tensor): Cell friction [n_systems]
        b_tau (torch.Tensor): Barostat time constant [n_systems]
    """

    # Thermostat / barostat parameters (one value per system).
    alpha: torch.Tensor
    cell_alpha: torch.Tensor
    b_tau: torch.Tensor

    # Cell degree of freedom.
    reference_cell: torch.Tensor
    cell_positions: torch.Tensor  # strain ε (dimensionless)
    cell_velocities: torch.Tensor  # dε/dt
    cell_masses: torch.Tensor

    _system_attributes = NPTState._system_attributes | {  # noqa: SLF001
        "alpha",
        "b_tau",
        "cell_alpha",
        "cell_masses",
        "cell_positions",
        "cell_velocities",
        "reference_cell",
    }

    @property
    def current_cell(self) -> torch.Tensor:
        """Cell rebuilt from strain: exp(ε) · reference_cell."""
        isotropic_scale = torch.exp(self.cell_positions)[..., None, None]
        return isotropic_scale * self.reference_cell

    @property
    def volume(self) -> torch.Tensor:
        """Current volume V = V₀ · exp(d·ε)."""
        n_dim = self.positions.shape[1]
        reference_volume = torch.linalg.det(self.reference_cell)
        return reference_volume * torch.exp(n_dim * self.cell_positions)
def _compute_isotropic_cell_force(
    state: NPTLangevinIsotropicState,
    external_pressure: float | torch.Tensor,
    kT: float | torch.Tensor,
) -> torch.Tensor:
    """Compute force on the strain coordinate ε.

    F_ε = d · V · (P_avg - P_ext)

    where P_avg = -(1/d)Tr(σ) + NkT/V and d·V is the Jacobian dV/dε.
    This force is in energy units (eV), making it numerically well-scaled.

    Args:
        state: Current strain-based NPT state
        external_pressure: Target pressure (scalar or [n_systems])
        kT: Temperature in energy units (scalar or [n_systems])

    Returns:
        torch.Tensor: Force on strain per system [n_systems]

    Raises:
        ValueError: If ``external_pressure`` has two or more dimensions.
    """
    external_pressure = torch.as_tensor(
        external_pressure, device=state.device, dtype=state.dtype
    )
    # Validate before doing any work.
    if external_pressure.ndim >= 2:
        raise ValueError(
            f"External pressure tensor provided with shape {external_pressure.shape}. "
            "Only scalar or per-system external pressure is supported."
        )
    kT = torch.as_tensor(kT, device=state.device, dtype=state.dtype)

    dim = state.positions.shape[1]
    volumes = state.volume  # (n_systems,)

    # Isotropic virial pressure: P_virial = -(1/d)Tr(stress). The trace sums
    # d diagonal entries, so dividing by d (not a fixed 3) gives their mean.
    stress_trace = torch.einsum("nii->n", state.stress)
    avg_virial_pressure = -stress_trace / dim  # (n_systems,)

    # Kinetic contribution: NkT/V
    batch_kT = kT if kT.ndim > 0 else kT.expand(state.n_systems)
    n_atoms = state.n_atoms_per_system.to(dtype=state.dtype)
    kinetic_pressure = n_atoms * batch_kT / volumes  # (n_systems,)

    P_avg = avg_virial_pressure + kinetic_pressure

    # F_ε = d · V · (P_avg - P_ext)
    return dim * volumes * (P_avg - external_pressure)


def _npt_langevin_isotropic_cell_beta(
    state: NPTLangevinIsotropicState,
    kT: torch.Tensor,
    dt: torch.Tensor,
) -> torch.Tensor:
    """Generate scalar random noise for isotropic strain fluctuations.

    Args:
        state: Current strain-based NPT state
        kT: Temperature in energy units (scalar or [n_systems])
        dt: Timestep (scalar or [n_systems])

    Returns:
        torch.Tensor: Noise [n_systems]
    """
    noise = _randn_for_state(state, (state.n_systems,))
    batch_kT = kT if kT.ndim > 0 else kT.expand(state.n_systems)
    dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems)
    # Fluctuation-dissipation amplitude using the cell friction.
    scaling = torch.sqrt(2.0 * state.cell_alpha * batch_kT * dt_expanded)
    return scaling * noise


def _npt_langevin_isotropic_cell_position_step(
    state: NPTLangevinIsotropicState,
    dt: torch.Tensor,
    strain_force: torch.Tensor,
    cell_beta: torch.Tensor,
) -> NPTLangevinIsotropicState:
    """GJF position step for the strain coordinate ε.

    ε_{n+1} = ε_n + b·dt·dε/dt + b·dt²·F_ε/(2Q) + b·dt·β/(2Q)

    Args:
        state: Current state
        dt: Timestep
        strain_force: F_ε [n_systems]
        cell_beta: Noise term β_c [n_systems]

    Returns:
        Updated state with new cell_positions (strain)
    """
    Q_2 = 2 * state.cell_masses
    dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems)
    cell_b = 1 / (1 + (state.cell_alpha * dt_expanded) / Q_2)
    c_1 = cell_b * dt_expanded * state.cell_velocities
    c_2 = cell_b * dt_expanded * dt_expanded * strain_force / Q_2
    c_3 = cell_b * dt_expanded * cell_beta / Q_2
    state.cell_positions = state.cell_positions + c_1 + c_2 + c_3
    return state


def _npt_langevin_isotropic_cell_velocity_step(
    state: NPTLangevinIsotropicState,
    F_eps_n: torch.Tensor,
    dt: torch.Tensor,
    strain_force: torch.Tensor,
    cell_beta: torch.Tensor,
) -> NPTLangevinIsotropicState:
    """GJF velocity step for the strain coordinate ε.

    dε/dt_{n+1} = a·dε/dt_n + dt/(2Q)·(a·F_ε^n + F_ε^{n+1}) + b·β/Q

    Args:
        state: Current state
        F_eps_n: Initial strain force [n_systems]
        dt: Timestep
        strain_force: Final strain force [n_systems]
        cell_beta: Noise term β_c (SAME as in position step) [n_systems]

    Returns:
        Updated state with new cell_velocities (dε/dt)
    """
    dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems)
    Q = state.cell_masses
    a = (1 - (state.cell_alpha * dt_expanded) / (2 * Q)) / (
        1 + (state.cell_alpha * dt_expanded) / (2 * Q)
    )
    b = 1 / (1 + (state.cell_alpha * dt_expanded) / (2 * Q))
    c_1 = a * state.cell_velocities
    c_2 = dt_expanded * ((a * F_eps_n) + strain_force) / (2 * Q)
    c_3 = b * cell_beta / Q
    state.cell_velocities = c_1 + c_2 + c_3
    return state


def _npt_langevin_isotropic_position_step(
    state: NPTLangevinIsotropicState,
    eps_old: torch.Tensor,
    dt: torch.Tensor,
    particle_beta: torch.Tensor,
) -> NPTLangevinIsotropicState:
    """Update particle positions accounting for strain change.

    Positions are scaled by exp(ε_new - ε_old) for the volume change, then the
    standard GJF position update is applied.

    Args:
        state: Current state (cell_positions already updated to ε_new)
        eps_old: Strain before the cell position step [n_systems]
        dt: Timestep
        particle_beta: Noise [n_particles, n_dim]

    Returns:
        Updated state with new positions
    """
    M_2 = 2 * state.masses.unsqueeze(-1)  # (n_atoms, 1)

    # Scale factor from strain change: L_new/L_old = exp(ε_new - ε_old)
    scale = torch.exp(state.cell_positions - eps_old)  # (n_systems,)
    scale_atoms = scale[state.system_idx]  # (n_atoms,)

    # Damping factor
    alpha_atoms = state.alpha[state.system_idx]
    dt_atoms = dt
    if dt.ndim > 0:
        dt_atoms = dt[state.system_idx]
    b = 1 / (1 + ((alpha_atoms * dt_atoms) / (2 * state.masses)))

    # Scale positions due to volume change
    c_1 = scale_atoms.unsqueeze(-1) * state.positions
    # Time step factor: 2·s/(s+1) where s = scale
    c_2 = (2 * scale_atoms / (scale_atoms + 1)) * b * dt_atoms
    c_3 = (
        state.velocities
        + dt_atoms.unsqueeze(-1) * state.forces / M_2
        + particle_beta / M_2
    )
    state.set_constrained_positions(c_1 + c_2.unsqueeze(-1) * c_3)
    return state
def npt_langevin_isotropic_init(
    state: SimState,
    model: ModelInterface,
    *,
    kT: float | torch.Tensor,
    dt: float | torch.Tensor,
    alpha: float | torch.Tensor | None = None,
    cell_alpha: float | torch.Tensor | None = None,
    b_tau: float | torch.Tensor | None = None,
    **_kwargs: Any,
) -> NPTLangevinIsotropicState:
    """Initialize an NPT Langevin state using logarithmic strain coordinate.

    The strain coordinate ε = (1/d)·ln(V/V₀) provides well-scaled dynamics where
    the conjugate force F_ε = d·V·(P_avg - P_ext) is in energy units.

    To seed the RNG set ``state.rng = seed`` before calling.

    Args:
        state: Initial SimState
        model: Model that computes energy, forces, stress
        kT: Target temperature in energy units
        dt: Integration timestep
        alpha: Particle friction coefficient. Defaults to 1/(5·dt).
        cell_alpha: Cell friction coefficient. Defaults to 1/(30·dt).
        b_tau: Barostat time constant. Defaults to 300·dt.

    Returns:
        NPTLangevinIsotropicState: Initialized state with ε = 0
    """
    device, dtype = model.device, model.dtype

    # Defaults scale with the timestep.
    if alpha is None:
        alpha = 1.0 / (5 * dt)
    if cell_alpha is None:
        cell_alpha = 1.0 / (30 * dt)
    if b_tau is None:
        b_tau = 300 * dt

    alpha = torch.as_tensor(alpha, device=device, dtype=dtype)
    cell_alpha = torch.as_tensor(cell_alpha, device=device, dtype=dtype)
    b_tau = torch.as_tensor(b_tau, device=device, dtype=dtype)
    kT = torch.as_tensor(kT, device=device, dtype=dtype)
    dt = torch.as_tensor(dt, device=device, dtype=dtype)

    # Scalar parameters are broadcast to one value per system.
    if alpha.ndim == 0:
        alpha = alpha.expand(state.n_systems)
    if cell_alpha.ndim == 0:
        cell_alpha = cell_alpha.expand(state.n_systems)
    if b_tau.ndim == 0:
        b_tau = b_tau.expand(state.n_systems)

    model_output = model(state)

    # Reuse existing momenta if the input state has them; otherwise sample at kT.
    momenta = getattr(state, "momenta", None)
    if momenta is None:
        momenta = initialize_momenta(
            state.positions,
            state.masses,
            state.system_idx,
            kT,
            state.rng,
        )

    reference_cell = state.cell.clone()

    # ε = 0 at initialization (V = V₀)
    cell_positions = torch.zeros(state.n_systems, device=device, dtype=dtype)
    cell_velocities = torch.zeros(state.n_systems, device=device, dtype=dtype)

    batch_kT = kT.expand(state.n_systems) if kT.ndim == 0 else kT
    # minlength keeps one entry per system even if trailing systems are empty,
    # so the result always matches the [n_systems] shape of kT and b_tau.
    n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems)
    cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau

    if state.constraints:
        msg = (
            "Constraints are present in the system. "
            "Make sure they are compatible with NPT Langevin dynamics. "
            "We recommend not using constraints with NPT dynamics for now."
        )
        warnings.warn(msg, UserWarning, stacklevel=3)
        logger.warning(msg)

    npt_state = NPTLangevinIsotropicState.from_state(
        state,
        momenta=momenta,
        energy=model_output["energy"],
        forces=model_output["forces"],
        stress=model_output["stress"],
        alpha=alpha,
        b_tau=b_tau,
        reference_cell=reference_cell,
        cell_positions=cell_positions,
        cell_velocities=cell_velocities,
        cell_masses=cell_masses,
        cell_alpha=cell_alpha,
    )
    npt_state.store_model_extras(model_output)
    return npt_state
@dcite("10.1063/1.4901303")
def npt_langevin_isotropic_step(
    state: NPTLangevinIsotropicState,
    model: ModelInterface,
    *,
    dt: float | torch.Tensor,
    kT: float | torch.Tensor,
    external_pressure: float | torch.Tensor,
) -> NPTLangevinIsotropicState:
    r"""Perform one NPT Langevin step using logarithmic strain coordinate.

    Uses the same GJF integrator as :func:`npt_langevin_anisotropic_step` but
    with the cell degree of freedom being the isotropic logarithmic strain
    :math:`\varepsilon = \frac{1}{d}\ln(V/V_0)` instead of the raw volume.

    **Strain force:**

    .. math::

        F_\varepsilon = d \cdot V \cdot (P_{\text{avg}} - P_{\text{ext}})

    where the Jacobian :math:`dV/d\varepsilon = d \cdot V` naturally provides
    a volume factor that makes :math:`F_\varepsilon` an energy (eV), giving
    numerically well-scaled dynamics.

    **Cell reconstruction:**

    .. math::

        V = V_0 \exp(d\,\varepsilon), \quad
        \mathbf{h} = e^\varepsilon \, \mathbf{h}_0

    **Particle scaling:**

    .. math::

        \mathbf{r}_i \to e^{\varepsilon_{n+1} - \varepsilon_n} \, \mathbf{r}_i

    Args:
        state: Current strain-based NPT state
        model: Model computing energy, forces, stress
        dt: Integration timestep
        kT: Target temperature in energy units
        external_pressure: Target pressure (scalar or [n_systems])

    Returns:
        NPTLangevinIsotropicState: Updated state
    """
    device, dtype = model.device, model.dtype

    state.alpha = torch.as_tensor(state.alpha, device=device, dtype=dtype)
    kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype)
    state.cell_alpha = torch.as_tensor(state.cell_alpha, device=device, dtype=dtype)
    dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype)
    external_pressure_tensor = torch.as_tensor(
        external_pressure, device=device, dtype=dtype
    )

    batch_kT = kT_tensor.expand(state.n_systems) if kT_tensor.ndim == 0 else kT_tensor

    # Update barostat mass. minlength keeps one entry per system even when
    # trailing systems are empty, matching the [n_systems] shape of b_tau.
    n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems)
    state.cell_masses = (n_atoms_per_system + 1) * batch_kT * torch.square(state.b_tau)

    # Store initial values
    forces = state.forces
    eps_old = state.cell_positions.clone()

    F_eps_n = _compute_isotropic_cell_force(
        state=state,
        external_pressure=external_pressure_tensor,
        kT=kT_tensor,
    )

    # Generate GJF noise ONCE; the same realizations are reused in the
    # position and velocity half of the step.
    cell_beta = _npt_langevin_isotropic_cell_beta(state, kT_tensor, dt_tensor)
    particle_beta = _npt_langevin_particle_beta(state, kT_tensor, dt_tensor)

    # Step 1: Update strain (cell position step)
    state = _npt_langevin_isotropic_cell_position_step(
        state, dt_tensor, F_eps_n, cell_beta
    )

    # Reconstruct cell from updated strain
    state.cell = state.current_cell

    # Step 2: Update particle positions (with strain-based scaling)
    state = _npt_langevin_isotropic_position_step(
        state, eps_old, dt_tensor, particle_beta
    )

    # Recompute model output
    model_output = model(state)
    state.energy = model_output["energy"]
    state.forces = model_output["forces"]
    state.stress = model_output["stress"]
    state.store_model_extras(model_output)

    # Compute updated strain force
    F_eps_new = _compute_isotropic_cell_force(
        state=state,
        external_pressure=external_pressure_tensor,
        kT=kT_tensor,
    )

    # Step 3: Update strain velocity (uses SAME cell_beta)
    state = _npt_langevin_isotropic_cell_velocity_step(
        state, F_eps_n, dt_tensor, F_eps_new, cell_beta
    )

    # Step 4: Update particle velocities (uses SAME particle_beta)
    return cast(
        "NPTLangevinIsotropicState",
        _npt_langevin_particle_velocity_step(
            state,
            forces,
            dt_tensor,
            particle_beta,
        ),
    )
@dataclass(kw_only=True)
class NPTNoseHooverIsotropicState(NPTState):
    """State information for an NPT system with Nose-Hoover chain thermostats.

    This class represents the complete state of a molecular system being integrated
    in the NPT (constant particle number, pressure, temperature) ensemble using
    Nose-Hoover chain thermostats for both temperature and pressure control.

    The cell dynamics are parameterized using a logarithmic coordinate system where
    cell_position = (1/d)ln(V/V_0), with V being the current volume, V_0 the reference
    volume, and d the spatial dimension. This ensures volume positivity and simplifies
    the equations of motion.

    Attributes:
        positions (torch.Tensor): Particle positions with shape [n_particles, n_dims]
        momenta (torch.Tensor): Particle momenta with shape [n_particles, n_dims]
        forces (torch.Tensor): Forces on particles with shape [n_particles, n_dims]
        stress (torch.Tensor): Stress tensor with shape [n_systems, n_dims, n_dims]
        masses (torch.Tensor): Particle masses with shape [n_particles]
        reference_cell (torch.Tensor): Reference simulation cell matrix with shape
            [n_systems, n_dimensions, n_dimensions]. Used to measure relative volume
            changes.
        cell_position (torch.Tensor): Logarithmic cell coordinate with shape
            [n_systems]. Represents (1/d)ln(V/V_0).
        cell_momentum (torch.Tensor): Cell momentum conjugate to cell_position.
            ``npt_nose_hoover_isotropic_init`` creates it with shape
            [n_systems, 1]; the integrator adds ``force.unsqueeze(-1)`` to it, so
            that trailing singleton dimension is required.
        cell_mass (torch.Tensor): Mass parameter for cell dynamics with shape
            [n_systems]. Controls coupling between volume fluctuations and pressure.
        barostat (NoseHooverChain): Chain thermostat coupled to cell dynamics for
            pressure control
        thermostat (NoseHooverChain): Chain thermostat coupled to particle dynamics
            for temperature control
        barostat_fns (NoseHooverChainFns): Functions for barostat chain updates
        thermostat_fns (NoseHooverChainFns): Functions for thermostat chain updates

    Properties:
        velocities (torch.Tensor): Particle velocities computed as momenta
            divided by masses. Shape: [n_particles, n_dimensions]
        current_cell (torch.Tensor): Current simulation cell matrix derived from
            cell_position. Shape: [n_systems, n_dimensions, n_dimensions]

    Notes:
        - The cell parameterization ensures volume positivity
        - Nose-Hoover chains provide deterministic control of T and P
        - Extended system approach conserves an extended Hamiltonian
        - All cell-related variables carry a leading system (batch) dimension
    """

    # Cell variables, one entry per system
    reference_cell: torch.Tensor  # [n_systems, 3, 3]
    cell_position: torch.Tensor  # [n_systems]
    cell_momentum: torch.Tensor  # [n_systems, 1] as created by the init function
    cell_mass: torch.Tensor  # [n_systems]

    # Thermostat variables
    thermostat: NoseHooverChain
    thermostat_fns: NoseHooverChainFns

    # Barostat variables
    barostat: NoseHooverChain
    barostat_fns: NoseHooverChainFns

    _system_attributes = NPTState._system_attributes | {  # noqa: SLF001
        "reference_cell",
        "cell_position",
        "cell_momentum",
        "cell_mass",
    }
    _global_attributes = NPTState._global_attributes | {  # noqa: SLF001
        "thermostat",
        "barostat",
        "thermostat_fns",
        "barostat_fns",
    }

    @property
    def velocities(self) -> torch.Tensor:
        """Calculate particle velocities from momenta and masses.

        Returns:
            torch.Tensor: Particle velocities with shape [n_particles, n_dimensions]
        """
        return self.momenta / self.masses.unsqueeze(-1)

    @property
    def current_cell(self) -> torch.Tensor:
        """Calculate current simulation cell from cell position.

        With V = V_0 * exp(d * cell_position) the isotropic scale factor is
        (V/V_0)^(1/d) = exp(cell_position), so the cell is
        ``exp(cell_position) * reference_cell``. Evaluating the scale directly
        avoids a determinant, cannot produce 0/0 for a singular reference cell,
        and cannot overflow through exp(d * cell_position).

        Returns:
            torch.Tensor: Current simulation cell matrix with shape
                [n_systems, n_dimensions, n_dimensions]
        """
        scale = torch.exp(self.cell_position)  # [n_systems]
        # Expand scale to [n_systems, 1, 1] for broadcasting
        return scale.unsqueeze(-1).unsqueeze(-1) * self.reference_cell

    def get_number_of_degrees_of_freedom(self) -> torch.Tensor:
        """Calculate degrees of freedom per system.

        Returns the parent class's count minus 3 for the (assumed removed)
        center-of-mass motion.
        """
        dof = super().get_number_of_degrees_of_freedom()
        return dof - 3  # Subtract 3 degrees of freedom for center of mass motion
def _npt_nose_hoover_isotropic_cell_info(
    state: NPTNoseHooverIsotropicState,
) -> tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]:
    """Get the current volume and a function mapping volumes to cells.

    Args:
        state (NPTNoseHooverIsotropicState): Current state of the NPT system

    Returns:
        tuple:
            - torch.Tensor: Current system volume V_0 * exp(d * cell_position),
              shape [n_systems]
            - callable: Function that takes a volume tensor [n_systems] and returns
              ``(V / V_0)^(1/d) * reference_cell`` with shape
              [n_systems, n_dimensions, n_dimensions]

    Notes:
        - Volume changes are measured relative to ``state.reference_cell``
        - Cell scaling is isotropic, so the cell shape is preserved
    """
    dim = state.positions.shape[1]
    ref = state.reference_cell  # [n_systems, dim, dim]
    V_0 = torch.det(ref)  # [n_systems] - Reference volume
    V = V_0 * torch.exp(dim * state.cell_position)  # [n_systems] - Current volume

    def volume_to_cell(V: torch.Tensor) -> torch.Tensor:
        """Compute cell matrices [n_systems, dim, dim] for volumes [n_systems]."""
        scale = torch.pow(V / V_0, 1.0 / dim)  # [n_systems]
        # Expand scale to [n_systems, 1, 1] for broadcasting
        return scale.unsqueeze(-1).unsqueeze(-1) * ref

    return V, volume_to_cell


def _npt_nose_hoover_isotropic_update_cell_mass(
    state: NPTNoseHooverIsotropicState,
    kT: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
) -> NPTNoseHooverIsotropicState:
    """Update the cell mass parameter in an NPT simulation.

    The cell mass is ``d * (N_atoms + 1) * kT * tau_b**2`` per system, where
    ``tau_b`` is ``state.barostat.tau``. Larger masses give slower volume
    fluctuations.

    Args:
        state (NPTNoseHooverIsotropicState): Current state of the NPT system
        kT (torch.Tensor): Target temperature in energy units, either scalar or
            shape [n_systems]
        device (torch.device): Device for the stored cell mass
        dtype (torch.dtype): Data type for the stored cell mass

    Returns:
        NPTNoseHooverIsotropicState: The same state object with ``cell_mass`` set
    """
    dim = state.positions.shape[1]

    # Handle both scalar and batched kT
    kT_system = kT.expand(state.n_systems) if kT.ndim == 0 else kT

    # minlength keeps the result aligned with n_systems even for empty systems
    n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems)
    cell_mass = (
        dim * (n_atoms_per_system + 1) * kT_system * torch.square(state.barostat.tau)
    )

    state.cell_mass = cell_mass.to(device=device, dtype=dtype)
    return state


def _npt_nose_hoover_isotropic_sinhx_x(x: torch.Tensor) -> torch.Tensor:
    """Compute sinh(x)/x elementwise, stable at and near x = 0.

    For |x| < 0.5 the 6-term Taylor series

        sinh(x)/x = 1 + x²/6 + x⁴/120 + x⁶/5040 + x⁸/362880 + x¹⁰/39916800

    is used; its first omitted term is x¹²/13! < 4e-14 there. Beyond that
    range the truncated series loses accuracy quickly, so sinh(x)/x is
    evaluated directly (no cancellation problem away from zero).

    Args:
        x (torch.Tensor): Input tensor

    Returns:
        torch.Tensor: sinh(x)/x with the same shape as ``x`` (exactly 1 at x=0)

    Example:
        >>> x = torch.tensor([0.0, 0.1, 0.2])
        >>> _npt_nose_hoover_isotropic_sinhx_x(x)
        tensor([1.0000, 1.0017, 1.0067])
    """
    series = (
        1
        + torch.pow(x, 2) / 6
        + torch.pow(x, 4) / 120
        + torch.pow(x, 6) / 5040
        + torch.pow(x, 8) / 362_880
        + torch.pow(x, 10) / 39_916_800
    )
    small = torch.abs(x) < 0.5
    # Put 1 where the series is used so the unselected branch never divides by
    # zero (torch.where would otherwise propagate NaN gradients).
    x_safe = torch.where(small, torch.ones_like(x), x)
    return torch.where(small, series, torch.sinh(x_safe) / x_safe)


def _npt_nose_hoover_isotropic_exp_iL1(  # noqa: N802
    state: NPTNoseHooverIsotropicState,
    velocities: torch.Tensor,
    cell_velocity: torch.Tensor,
    dt: torch.Tensor,
) -> torch.Tensor:
    """Apply the exp(iL1) operator for NPT dynamics position updates.

    The update is:
        R_new = R + (exp(x) - 1)R + dt*V*exp(x/2)*sinh(x/2)/(x/2)
    where x = V_b * dt and V_b is the per-system cell velocity.

    Args:
        state (NPTNoseHooverIsotropicState): Current simulation state (positions
            and system_idx are read)
        velocities (torch.Tensor): Particle velocities [n_particles, n_dimensions]
        cell_velocity (torch.Tensor): Cell velocity with shape [n_systems]
        dt (torch.Tensor): Integration timestep

    Returns:
        torch.Tensor: Updated particle positions. No periodic wrapping is applied
            here.
    """
    # Map system-level cell velocities to atom level using system indices
    cell_velocity_atoms = cell_velocity[state.system_idx]  # [n_atoms]

    x = cell_velocity_atoms * dt  # [n_atoms]
    x_2 = x / 2  # [n_atoms]
    sinh_term = _npt_nose_hoover_isotropic_sinhx_x(x_2)  # [n_atoms]

    # Expand dimensions for broadcasting with positions [n_atoms, 3]
    x_expanded = x.unsqueeze(-1)
    x_2_expanded = x_2.unsqueeze(-1)
    sinh_expanded = sinh_term.unsqueeze(-1)

    displacement = (
        state.positions * (torch.exp(x_expanded) - 1)
        + dt * velocities * torch.exp(x_2_expanded) * sinh_expanded
    )
    return state.positions + displacement


def _npt_nose_hoover_isotropic_exp_iL2(  # noqa: N802
    state: NPTNoseHooverIsotropicState,
    alpha: torch.Tensor,
    momenta: torch.Tensor,
    forces: torch.Tensor,
    cell_velocity: torch.Tensor,
    dt_2: torch.Tensor,
) -> torch.Tensor:
    """Apply the exp(iL2) operator for NPT dynamics momentum updates.

    The update is:
        P_new = P*exp(-x) + dt/2 * F * exp(-x/2) * sinh(x/2)/(x/2)
    where x = alpha * V_b * dt/2.

    Args:
        state (NPTNoseHooverIsotropicState): Current simulation state (only
            system_idx is read)
        alpha (torch.Tensor): MTK factor 1 + d/N_f with shape [n_systems]
        momenta (torch.Tensor): Current particle momenta [n_particles, n_dimensions]
        forces (torch.Tensor): Forces on particles [n_particles, n_dimensions]
        cell_velocity (torch.Tensor): Cell velocity with shape [n_systems]
        dt_2 (torch.Tensor): Half timestep (dt/2)

    Returns:
        torch.Tensor: Updated particle momenta
    """
    cell_velocity_atoms = cell_velocity[state.system_idx]  # [n_atoms]
    alpha_atoms = alpha[state.system_idx]  # [n_atoms]

    x = alpha_atoms * cell_velocity_atoms * dt_2  # [n_atoms]
    x_2 = x / 2  # [n_atoms]
    sinh_term = _npt_nose_hoover_isotropic_sinhx_x(x_2)  # [n_atoms]

    # Expand dimensions for broadcasting with momenta [n_atoms, 3]
    x_expanded = x.unsqueeze(-1)
    x_2_expanded = x_2.unsqueeze(-1)
    sinh_expanded = sinh_term.unsqueeze(-1)

    return momenta * torch.exp(-x_expanded) + dt_2 * forces * sinh_expanded * torch.exp(
        -x_2_expanded
    )


def _npt_nose_hoover_isotropic_compute_cell_force(
    alpha: torch.Tensor,
    volume: torch.Tensor,
    positions: torch.Tensor,
    momenta: torch.Tensor,
    masses: torch.Tensor,
    stress: torch.Tensor,
    external_pressure: torch.Tensor,
    system_idx: torch.Tensor,
) -> torch.Tensor:
    """Compute the force on the cell degree of freedom in NPT dynamics.

    F = alpha * 2 * KE - Tr(stress) * V - d * P_ext * V, evaluated per system.

    Args:
        alpha (torch.Tensor): MTK factor 1 + d/N_f, shape [n_systems]
        volume (torch.Tensor): Current system volume with shape [n_systems]
        positions (torch.Tensor): Particle positions [n_particles, n_dimensions];
            only its second dimension (d) is used
        momenta (torch.Tensor): Particle momenta [n_particles, n_dimensions]
        masses (torch.Tensor): Particle masses [n_particles]
        stress (torch.Tensor): Stress tensor [n_systems, n_dimensions, n_dimensions]
            or a single [n_dimensions, n_dimensions] matrix broadcast to all systems
        external_pressure (torch.Tensor): Target external pressure
        system_idx (torch.Tensor): System indices for atoms [n_particles]

    Returns:
        torch.Tensor: Force on the cell degree of freedom with shape [n_systems]
    """
    dim = positions.shape[1]
    n_systems = len(volume)

    # Per-atom kinetic energy |p|^2 / (2m) scattered into its system; systems
    # without atoms keep KE = 0. A single index_add_ replaces a Python loop
    # over systems (which forced one device sync per system).
    ke_per_atom = 0.5 * torch.square(momenta).sum(dim=-1) / masses
    KE_per_system = torch.zeros(
        n_systems, device=ke_per_atom.device, dtype=ke_per_atom.dtype
    ).index_add_(0, system_idx, ke_per_atom)

    if stress.ndim == 3:
        stress_trace = torch.diagonal(stress, dim1=-2, dim2=-1).sum(dim=-1)
    else:
        # Single system case - expand to batch dimension
        stress_trace = torch.trace(stress).unsqueeze(0).expand(n_systems)

    return (
        (alpha * 2 * KE_per_system)
        - (stress_trace * volume)
        - (external_pressure * volume * dim)
    )


def _npt_nose_hoover_isotropic_inner_step(
    state: NPTNoseHooverIsotropicState,
    model: ModelInterface,
    dt: torch.Tensor,
    external_pressure: torch.Tensor,
) -> NPTNoseHooverIsotropicState:
    """Perform one inner step of NPT integration using velocity Verlet algorithm.

    1. Cell momentum and particle momentum updates (half step)
    2. Position and cell position updates (full step)
    3. Force updates with new positions and cell
    4. Final momentum updates (half step)

    Args:
        state (NPTNoseHooverIsotropicState): Current system state
        model (ModelInterface): Model to compute forces and energies
        dt (torch.Tensor): Integration timestep
        external_pressure (torch.Tensor): Target external pressure

    Returns:
        NPTNoseHooverIsotropicState: Updated state after one integration step
    """
    dt_2 = dt / 2
    dim = state.positions.shape[1]

    positions = state.positions
    momenta = state.momenta
    masses = state.masses
    forces = state.forces
    cell_position = state.cell_position  # [n_systems]
    cell_momentum = state.cell_momentum  # [n_systems, 1]
    cell_mass = state.cell_mass  # [n_systems]

    volume, _ = _npt_nose_hoover_isotropic_cell_info(state)

    # alpha = 1 + d / N_f, with N_f from get_number_of_degrees_of_freedom
    alpha = 1 + dim / state.get_number_of_degrees_of_freedom()  # [n_systems]

    # Positions and cell are unchanged since the last force call, so the
    # stored stress is still valid here.
    cell_force_val = _npt_nose_hoover_isotropic_compute_cell_force(
        alpha=alpha,
        volume=volume,
        positions=positions,
        momenta=momenta,
        masses=masses,
        stress=state.stress,
        external_pressure=external_pressure,
        system_idx=state.system_idx,
    )

    # First half step: cell momentum, then particle momenta
    cell_momentum = cell_momentum + dt_2 * cell_force_val.unsqueeze(-1)
    cell_velocities = cell_momentum.squeeze(-1) / cell_mass
    momenta = _npt_nose_hoover_isotropic_exp_iL2(
        state, alpha, momenta, forces, cell_velocities, dt_2
    )

    # Full step: cell coordinate (must be stored before querying the new cell)
    cell_position = cell_position + cell_velocities * dt
    state.cell_position = cell_position
    volume, volume_to_cell = _npt_nose_hoover_isotropic_cell_info(state)
    cell = volume_to_cell(volume)

    # Full step: particle positions, then new forces
    state.set_constrained_momenta(momenta)
    positions = _npt_nose_hoover_isotropic_exp_iL1(
        state, state.velocities, cell_velocities, dt
    )
    state.set_constrained_positions(positions)
    state.cell = cell
    model_output = model(state)

    # Second half step: particle momenta, then cell momentum
    momenta = _npt_nose_hoover_isotropic_exp_iL2(
        state, alpha, momenta, model_output["forces"], cell_velocities, dt_2
    )
    cell_force_val = _npt_nose_hoover_isotropic_compute_cell_force(
        alpha=alpha,
        volume=volume,
        positions=positions,
        momenta=momenta,
        masses=masses,
        stress=model_output["stress"],
        external_pressure=external_pressure,
        system_idx=state.system_idx,
    )
    cell_momentum = cell_momentum + dt_2 * cell_force_val.unsqueeze(-1)

    # Write back the updated state
    state.set_constrained_positions(positions)
    state.set_constrained_momenta(momenta)
    state.forces = model_output["forces"]
    state.stress = model_output["stress"]
    state.energy = model_output["energy"]
    state.store_model_extras(model_output)
    state.cell_position = cell_position
    state.cell_momentum = cell_momentum
    state.cell_mass = cell_mass
    return state
def npt_nose_hoover_isotropic_init(
    state: SimState,
    model: ModelInterface,
    *,
    kT: float | torch.Tensor,
    dt: float | torch.Tensor,
    chain_length: int = 3,
    chain_steps: int = 2,
    sy_steps: int = 3,
    t_tau: float | torch.Tensor | None = None,
    b_tau: float | torch.Tensor | None = None,
    **kwargs: Any,
) -> NPTNoseHooverIsotropicState:
    """Initialize the NPT Nose-Hoover state.

    This function initializes a state for NPT molecular dynamics with Nose-Hoover
    chain thermostats for both temperature and pressure control. It sets up the
    system with particle momenta, cell variables, thermostat/barostat chains and
    an initial model evaluation.

    To seed the RNG set ``state.rng = seed`` before calling.

    Args:
        state: Initial system state as SimState containing positions, masses,
            cell, and PBC information. If ``state.cell`` is a scalar it is
            replaced in place by a batched cubic cell of that edge length.
        model (ModelInterface): Model to compute forces, energies and stress
        kT: Target temperature in energy units
        dt: Integration timestep
        chain_length: Length of Nose-Hoover chains. Defaults to 3.
        chain_steps: Chain integration substeps. Defaults to 2.
        sy_steps: Suzuki-Yoshida integration order. Defaults to 3.
        t_tau: Thermostat relaxation time. Controls how quickly temperature
            equilibrates. Defaults to ``10 * dt``.
        b_tau: Barostat relaxation time. Controls how quickly pressure
            equilibrates. Defaults to ``100 * dt``.
        **kwargs: Optional ``atomic_numbers`` and ``momenta`` overrides.

    Returns:
        NPTNoseHooverIsotropicState: Initialized state containing:
            - Particle positions, momenta, forces, stress and energy
            - Cell position ([n_systems], zeros), momentum ([n_systems, 1], zeros)
              and mass ([n_systems])
            - Reference cell matrix [n_systems, d, d]
            - Thermostat and barostat chain variables and functions

    Notes:
        - Uses separate Nose-Hoover chains for temperature and pressure control
        - Cell mass is ``d * (N_atoms + 1) * kT * b_tau**2`` per system
        - Momenta come from ``kwargs["momenta"]``, else ``state.momenta`` if the
          state has one, else they are drawn with ``initialize_momenta``
        - A warning is emitted if the state carries constraints
    """
    device, dtype = state.device, state.dtype

    dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype)
    kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype)
    t_tau_tensor = torch.as_tensor(
        10 * dt_tensor if t_tau is None else t_tau, device=device, dtype=dtype
    )
    b_tau_tensor = torch.as_tensor(
        100 * dt_tensor if b_tau is None else b_tau, device=device, dtype=dtype
    )

    # Setup thermostats with appropriate timescales
    barostat_fns = construct_nose_hoover_chain(
        dt_tensor, chain_length, chain_steps, sy_steps, b_tau_tensor
    )
    thermostat_fns = construct_nose_hoover_chain(
        dt_tensor, chain_length, chain_steps, sy_steps, t_tau_tensor
    )

    _n_particles, dim = state.positions.shape
    n_systems = state.n_systems
    atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers)

    # cell_momentum is [n_systems, 1]: the chain half_step and the inner step
    # both rely on the trailing singleton dimension.
    cell_position = torch.zeros(n_systems, device=device, dtype=dtype)
    cell_momentum = torch.zeros(n_systems, 1, device=device, dtype=dtype)

    # Handle both scalar and batched kT
    kT_system = kT_tensor.expand(n_systems) if kT_tensor.ndim == 0 else kT_tensor

    # Cell masses per system
    n_atoms_per_system = torch.bincount(state.system_idx, minlength=n_systems)
    cell_mass = dim * (n_atoms_per_system + 1) * kT_system * torch.square(b_tau_tensor)
    cell_mass = cell_mass.to(device=device, dtype=dtype)

    # Barostat chain: one DOF per system; cell kinetic energy per system
    # (zero here because cell_momentum starts at zero)
    dof_barostat = torch.ones(n_systems, device=device, dtype=dtype)
    KE_cell = (cell_momentum.squeeze(-1) ** 2) / (2 * cell_mass)

    # Initialize momenta
    momenta = kwargs.get("momenta")
    if momenta is None:
        momenta = getattr(state, "momenta", None)
    if momenta is None:
        momenta = initialize_momenta(
            state.positions,
            state.masses,
            state.system_idx,
            kT_tensor,
            state.rng,
        )

    # Thermostat chain: particle DOF minus 3 for center-of-mass motion, and the
    # current particle kinetic energy per system
    dof_per_system = state.get_number_of_degrees_of_freedom() - 3
    KE_thermostat = ts.calc_kinetic_energy(
        masses=state.masses, momenta=momenta, system_idx=state.system_idx
    )

    # Build a batched reference cell. Scalar cells are handled first because
    # a Python int/float has no ``.ndim`` attribute.
    if isinstance(state.cell, int | float) or (
        torch.is_tensor(state.cell) and state.cell.ndim == 0
    ):
        cell_matrix = torch.eye(dim, device=device, dtype=dtype) * state.cell
        reference_cell = cell_matrix.unsqueeze(0).expand(n_systems, -1, -1).clone()
        state.cell = reference_cell
    elif state.cell.ndim == 2:
        # Single cell matrix - expand to batch dimension
        reference_cell = state.cell.unsqueeze(0).expand(n_systems, -1, -1).clone()
    else:
        # Already has batch dimension
        reference_cell = state.cell.clone()

    # Get model output
    model_output = model(state)
    forces = model_output["forces"]
    energy = model_output["energy"]
    stress = model_output["stress"]

    if state.constraints:
        # warn if constraints are present
        msg = (
            "Constraints are present in the system. "
            "Make sure they are compatible with NPT Nosé Hoover dynamics."
            "We recommend not using constraints with NPT dynamics for now."
        )
        warnings.warn(msg, UserWarning, stacklevel=3)
        logger.warning(msg)

    # Create initial state
    npt_state = NPTNoseHooverIsotropicState.from_state(
        state,
        momenta=momenta,
        energy=energy,
        forces=forces,
        stress=stress,
        atomic_numbers=atomic_numbers,
        reference_cell=reference_cell,
        cell_position=cell_position,
        cell_momentum=cell_momentum,
        cell_mass=cell_mass,
        barostat=barostat_fns.initialize(dof_barostat, KE_cell, kT_tensor),
        thermostat=thermostat_fns.initialize(dof_per_system, KE_thermostat, kT_tensor),
        barostat_fns=barostat_fns,
        thermostat_fns=thermostat_fns,
    )
    npt_state.store_model_extras(model_output)
    return npt_state
@dcite("10.1080/00268979600100761")
@dcite("10.1088/0305-4470/39/19/S18")
def npt_nose_hoover_isotropic_step(
    state: NPTNoseHooverIsotropicState,
    model: ModelInterface,
    *,
    dt: float | torch.Tensor,
    kT: float | torch.Tensor,
    external_pressure: float | torch.Tensor,
) -> NPTNoseHooverIsotropicState:
    r"""Perform a complete NPT integration step with Nose-Hoover chain thermostats.

    Implements the MTK (Martyna-Tobias-Klein) NPT scheme from Tuckerman et al.
    (2006) [10]_ with Nose-Hoover chains from Martyna et al. (1996) [3]_.

    **Equations of motion** (Tuckerman et al. 2006, Eqs. 1-6):

    .. math::

        \dot{\mathbf{r}}_i &= \frac{\mathbf{p}_i}{m_i}
            + \frac{p_\epsilon}{W}\,\mathbf{r}_i \\
        \dot{\mathbf{p}}_i &= \mathbf{F}_i
            - \alpha\,\frac{p_\epsilon}{W}\,\mathbf{p}_i \\
        \dot{\epsilon} &= \frac{p_\epsilon}{W} \\
        \dot{p}_\epsilon &= G_\epsilon
            = \alpha\,(2K) - \text{Tr}(\boldsymbol{\sigma})\,V
              - P_{\text{ext}}\,V\,d

    where :math:`\epsilon = (1/d)\ln(V/V_0)` is the logarithmic cell coordinate,
    :math:`\boldsymbol{\sigma}` is the model stress tensor (``state.stress``),
    :math:`\alpha = 1 + d/N_f`, :math:`d=3` is spatial dimension, and
    :math:`N_f = 3N - 3` the degrees of freedom. With the stress sign convention
    :math:`P = -\text{Tr}(\boldsymbol{\sigma})/d` (assumed here; confirm for the
    model in use), :math:`-\text{Tr}(\boldsymbol{\sigma})V = d\,V\,P_{\text{vir}}`.

    **Symmetric propagator** (Trotter factorization):

    .. math::

        e^{i\mathcal{L}\Delta t} =
            e^{i\mathcal{L}_{\text{NHC-baro}}\frac{\Delta t}{2}}
            \;e^{i\mathcal{L}_{\text{NHC-part}}\frac{\Delta t}{2}}
            \;e^{i\mathcal{L}_2\frac{\Delta t}{2}}
            \;e^{i\mathcal{L}_1\Delta t}
            \;e^{i\mathcal{L}_2\frac{\Delta t}{2}}
            \;e^{i\mathcal{L}_{\text{NHC-part}}\frac{\Delta t}{2}}
            \;e^{i\mathcal{L}_{\text{NHC-baro}}\frac{\Delta t}{2}}

    **Position update** :math:`e^{i\mathcal{L}_1\Delta t}`:

    .. math::

        \mathbf{r}_i \leftarrow \mathbf{r}_i
            + \bigl(e^{v_\epsilon\Delta t} - 1\bigr)\,\mathbf{r}_i
            + \Delta t\,\mathbf{v}_i\,e^{v_\epsilon\Delta t/2}
              \,\frac{\sinh(v_\epsilon\Delta t/2)}{v_\epsilon\Delta t/2}

    **Momentum update** :math:`e^{i\mathcal{L}_2\Delta t/2}`:

    .. math::

        \mathbf{p}_i \leftarrow \mathbf{p}_i\,e^{-\alpha v_\epsilon\Delta t/2}
            + \frac{\Delta t}{2}\,\mathbf{F}_i\,
              e^{-\alpha v_\epsilon\Delta t/4}
              \,\frac{\sinh(\alpha v_\epsilon\Delta t/4)}
                     {\alpha v_\epsilon\Delta t/4}

    where :math:`v_\epsilon = p_\epsilon / W` is the cell velocity.

    **Variable mapping (equation -> code):**

    ============================================  ============================
    Equation symbol                               Code variable
    ============================================  ============================
    :math:`\mathbf{r}_i` (positions)              ``state.positions``
    :math:`\mathbf{p}_i` (momenta)                ``state.momenta``
    :math:`m_i` (masses)                          ``state.masses``
    :math:`\mathbf{F}_i` (forces)                 ``state.forces``
    :math:`\boldsymbol{\sigma}` (model stress)    ``state.stress``
    :math:`\epsilon` (log-cell coordinate)        ``state.cell_position``
    :math:`p_\epsilon` (cell momentum)            ``state.cell_momentum``
    :math:`W` (cell mass)                         ``state.cell_mass``
    :math:`\alpha` (1 + d/Nf)                     ``alpha`` (local)
    :math:`v_\epsilon` (cell velocity)            ``cell_velocities`` (local)
    :math:`V_0` (reference volume)                ``det(state.reference_cell)``
    :math:`G_\epsilon` (cell force)               ``cell_force_val``
    :math:`P_{\text{ext}}` (target pressure)      ``external_pressure``
    :math:`k_BT` (thermal energy)                 ``kT``
    :math:`\Delta t` (timestep)                   ``dt``
    ============================================  ============================

    If the center of mass motion is removed initially, it remains removed
    throughout the simulation, so the degrees of freedom decreases by 3.

    Args:
        state: Current system state
        model: Model to compute forces, energies and stress
        dt: Integration timestep
        kT: Target temperature
        external_pressure: Target external pressure

    Returns:
        NPTNoseHooverIsotropicState: Updated state after complete integration step

    References:
        .. [10] Tuckerman, M. E., et al. "A Liouville-operator derived
           measure-preserving integrator for molecular dynamics simulations
           in the isothermal-isobaric ensemble." J. Phys. A 39(19),
           5629-5651 (2006).
        .. [3] Martyna, G. J., et al. "Explicit reversible integrators for
           extended systems dynamics." Mol. Phys. 87(5), 1117-1157 (1996).
    """
    device, dtype = model.device, model.dtype
    dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype)
    kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype)
    external_pressure_tensor = torch.as_tensor(
        external_pressure, device=device, dtype=dtype
    )

    # Unpack state variables for clarity
    barostat = state.barostat
    thermostat = state.thermostat

    # Refresh chain and cell masses for the (possibly changed) target kT
    state.barostat = state.barostat_fns.update_mass(barostat, kT_tensor)
    state.thermostat = state.thermostat_fns.update_mass(thermostat, kT_tensor)
    state = _npt_nose_hoover_isotropic_update_cell_mass(state, kT_tensor, device, dtype)

    # First half step of thermostat chains (barostat chain acts per system)
    cell_system_idx = torch.arange(state.n_systems, device=device)
    state.cell_momentum, state.barostat = state.barostat_fns.half_step(
        state.cell_momentum, state.barostat, kT_tensor, cell_system_idx
    )
    state.momenta, state.thermostat = state.thermostat_fns.half_step(
        state.momenta, state.thermostat, kT_tensor, state.system_idx
    )

    # Perform inner NPT step
    state = _npt_nose_hoover_isotropic_inner_step(
        state, model, dt_tensor, external_pressure_tensor
    )

    # Update kinetic energies for thermostats
    KE = ts.calc_kinetic_energy(
        masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
    )
    state.thermostat.kinetic_energy = KE

    KE_cell = (torch.square(state.cell_momentum.squeeze(-1))) / (2 * state.cell_mass)
    state.barostat.kinetic_energy = KE_cell

    # Second half step of thermostat chains (reverse order for symmetry)
    state.momenta, state.thermostat = state.thermostat_fns.half_step(
        state.momenta, state.thermostat, kT_tensor, state.system_idx
    )
    state.cell_momentum, state.barostat = state.barostat_fns.half_step(
        state.cell_momentum, state.barostat, kT_tensor, cell_system_idx
    )
    return state
def _compute_chain_energy(
    chain: NoseHooverChain, kT: torch.Tensor, e_tot: torch.Tensor, dof: torch.Tensor
) -> torch.Tensor:
    """Return the extended-energy contribution of a Nose-Hoover chain.

    Every chain element ``j`` contributes a kinetic term
    ``momenta[:, j]**2 / (2 * masses[:, j])`` plus a potential term
    ``c_j * kT * positions[:, j]``. The coupling ``c_j`` is ``dof`` for the
    first element and ``1`` for every following element.

    Args:
        chain: Chain state. Its ``positions``, ``momenta`` and ``masses`` are
            indexed as ``[:, j]``, with the chain length read from
            ``positions.shape[1]``.
        kT: Target thermal energy.
        e_tot: Used only as a template (shape/dtype/device) for the result.
        dof: Degrees of freedom coupled to the first chain element.

    Returns:
        Tensor shaped like ``e_tot`` holding the summed chain energy.
    """

    def link_energy(j: int, coupling: torch.Tensor) -> torch.Tensor:
        # Kinetic energy of link j plus its coupling-weighted potential term
        kinetic = torch.square(chain.momenta[:, j]) / (2 * chain.masses[:, j])
        return kinetic + coupling * chain.positions[:, j]

    total = torch.zeros_like(e_tot)
    # The leading thermostat is coupled to all `dof` degrees of freedom
    total += link_energy(0, dof * kT)
    for j in range(1, chain.positions.shape[1]):
        total += link_energy(j, kT)
    return total
def npt_nose_hoover_isotropic_invariant(
    state: NPTNoseHooverIsotropicState,
    kT: torch.Tensor,
    external_pressure: torch.Tensor,
) -> torch.Tensor:
    """Evaluate the extended Hamiltonian of isotropic Nose-Hoover NPT dynamics.

    This quantity should stay (approximately) constant along a trajectory, which
    makes it a useful correctness check for NPT simulations. It is the sum of:

    - the potential energy ``state.energy``,
    - the particle kinetic energy, accumulated per system,
    - the thermostat-chain energy, whose first element is coupled to
      ``state.get_number_of_degrees_of_freedom()`` degrees of freedom,
    - the barostat-chain energy, whose first element is coupled to a single
      degree of freedom,
    - the PV work ``external_pressure * det(state.current_cell)``, and
    - the cell kinetic energy ``cell_momentum**2 / (2 * cell_mass)``.

    Args:
        state: Isotropic Nose-Hoover NPT state providing masses, momenta,
            ``system_idx``, ``energy``, ``current_cell``, ``cell_momentum``,
            ``cell_mass``, ``thermostat`` and ``barostat``.
        kT: Target thermal energy (Boltzmann constant x temperature).
        external_pressure: Target external pressure.

    Returns:
        torch.Tensor: One conserved value per system. The shape follows the
        per-system kinetic energy (presumably ``[n_systems]``, i.e. ``[1]``
        rather than a 0-d scalar for a single system -- confirm against
        ``ts.calc_kinetic_energy``).
    """
    particle_dof = state.get_number_of_degrees_of_freedom()
    particle_ke = ts.calc_kinetic_energy(
        masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
    )
    cell_volume = torch.det(state.current_cell)

    hamiltonian = state.energy + particle_ke
    # Thermostat chain: leading link couples to the particle degrees of freedom
    hamiltonian += _compute_chain_energy(state.thermostat, kT, hamiltonian, particle_dof)
    # Barostat chain: leading link couples to the single cell degree of freedom
    hamiltonian += _compute_chain_energy(
        state.barostat, kT, hamiltonian, torch.ones_like(particle_dof)
    )
    # Enthalpic PV work
    hamiltonian += external_pressure * cell_volume
    # Cell kinetic energy; squeeze drops the trailing singleton momentum axis
    cell_p = state.cell_momentum.squeeze()
    hamiltonian += torch.square(cell_p) / (2 * state.cell_mass)
    return hamiltonian
@dataclass(kw_only=True)
class NPTCRescaleState(NPTState):
    """State for NPT dynamics driven by a stochastic cell-rescaling barostat.

    Extends ``NPTState`` with the per-system barostat parameters and with a copy
    of the cell taken at initialization. That copy serves as the reference frame
    for the deviatoric (non-hydrostatic) stress correction.

    Attributes:
        isothermal_compressibility (torch.Tensor): Compressibility [n_systems]
        tau_p (torch.Tensor): Barostat relaxation time [n_systems]
        initial_cell (torch.Tensor): Cell at initialization [n_systems, 3, 3]
        initial_cell_inv (torch.Tensor): Inverse of ``initial_cell``
            [n_systems, 3, 3]
        initial_volume (torch.Tensor): Determinant of ``initial_cell``
            [n_systems]
    """

    # Barostat parameters (one value per system)
    isothermal_compressibility: torch.Tensor  # shape: [n_systems]
    tau_p: torch.Tensor  # shape: [n_systems]
    # Reference cell captured at initialization
    initial_cell: torch.Tensor  # shape: [n_systems, 3, 3]
    initial_cell_inv: torch.Tensor  # shape: [n_systems, 3, 3]
    initial_volume: torch.Tensor  # shape: [n_systems]

    _system_attributes = NPTState._system_attributes | {  # noqa: SLF001
        "isothermal_compressibility",
        "tau_p",
        "initial_cell",
        "initial_cell_inv",
        "initial_volume",
    }

    def get_number_of_degrees_of_freedom(self) -> torch.Tensor:
        """Return the per-system degrees of freedom without center-of-mass motion.

        Returns:
            torch.Tensor: Parent-class count minus 3, shape [n_systems]
        """
        base_dof = super().get_number_of_degrees_of_freedom()
        # The three center-of-mass translations are excluded
        return base_dof - 3
def rotate_gram_schmidt(box: torch.Tensor) -> torch.Tensor:
    """Rotate a batch of 3x3 boxes into lower-triangular form.

    The columns ``a, b, c`` of each input box are treated as the box vectors.
    The output stores the rotated vectors as its rows: ``a`` along x, ``b`` in
    the xy-plane, and ``c`` general. The pairwise dot products of the vectors
    are therefore preserved (``out @ out.mT == box.mT @ box``).

    Args:
        box (torch.Tensor): shape [n_systems, 3, 3]

    Returns:
        torch.Tensor: shape [n_systems, 3, 3] lower-triangular boxes
    """
    col_a, col_b, col_c = box[:, :, 0], box[:, :, 1], box[:, :, 2]

    # Gram-Schmidt coordinates of a, b, c in the rotated frame
    a_x = torch.norm(col_a, dim=1)
    b_x = torch.sum(col_a * col_b, dim=1) / a_x
    b_y = torch.sqrt(torch.sum(col_b * col_b, dim=1) - b_x**2)
    c_x = torch.sum(col_a * col_c, dim=1) / a_x
    c_y = (torch.sum(col_b * col_c, dim=1) - c_x * b_x) / b_y
    c_z = torch.sqrt(torch.sum(col_c * col_c, dim=1) - c_x**2 - c_y**2)

    # Upper-triangular part stays zero
    lower = torch.zeros_like(box)
    lower[:, 0, 0] = a_x
    lower[:, 1, 0] = b_x
    lower[:, 1, 1] = b_y
    lower[:, 2, 0] = c_x
    lower[:, 2, 1] = c_y
    lower[:, 2, 2] = c_z
    return lower
def batch_matrix_vector(
    matrices: torch.Tensor,
    vectors: torch.Tensor,
) -> torch.Tensor:
    """Multiply each vector in a batch by its matching matrix.

    Args:
        matrices (torch.Tensor): shape [batch, n, n]
        vectors (torch.Tensor): shape [batch, n]

    Returns:
        torch.Tensor: shape [batch, n], where ``out[k] = matrices[k] @ vectors[k]``
    """
    # Promote each vector to an (n, 1) column, multiply, then drop that axis again
    column_vectors = vectors[..., None]
    return (matrices @ column_vectors)[..., 0]
def _compute_deviatoric_correction(
    cell: torch.Tensor,
    volume: torch.Tensor,
    initial_cell_inv: torch.Tensor,
    initial_volume: torch.Tensor,
    external_pressure_tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute deviatoric pressure correction for non-hydrostatic external stress.

    Follows the algorithm from Bussi's crescale reference implementation:
    projects the deviatoric part of the external stress through the reference
    cell frame.

    Args:
        cell: Current cell matrix, shape [n_systems, 3, 3]
        volume: Current volume, shape [n_systems]
        initial_cell_inv: Inverse of initial cell, shape [n_systems, 3, 3]
        initial_volume: Initial volume, shape [n_systems]
        external_pressure_tensor: Full external pressure tensor [n_systems, 3, 3]

    Returns:
        pressure_hydro: Hydrostatic pressure scalar [n_systems]
        pressure_dev: Deviatoric pressure correction [n_systems, 3, 3]
        trace_pressure_dev: Trace of pressure_dev [n_systems]
    """
    # Hydrostatic part: one third of the trace of the external stress
    pressure_hydro = torch.einsum("bii->b", external_pressure_tensor) / 3
    I = torch.eye(3, device=cell.device, dtype=cell.dtype)  # noqa: E741
    # Traceless (deviatoric) remainder of the external stress
    stress_dev = external_pressure_tensor - pressure_hydro[:, None, None] * I.expand_as(
        external_pressure_tensor
    )
    # Project to reference coordinates: sigma = V0 * h0_inv^T @ stress_dev @ h0_inv
    sigma = initial_volume[:, None, None] * (
        initial_cell_inv.transpose(-2, -1) @ stress_dev @ initial_cell_inv
    )
    # Symmetrize and project back: pressure_dev = h^T @ 0.5*(sigma+sigma^T) @ h / V
    # NOTE(review): whether h^T X h (rather than h X h^T) is right depends on
    # whether `cell` stores lattice vectors as rows or columns; confirm against
    # the convention used for `state.cell`.
    sigma_sym = 0.5 * (sigma + sigma.transpose(-2, -1))
    pressure_dev = cell.transpose(-2, -1) @ sigma_sym @ cell / volume[:, None, None]
    trace_pressure_dev = torch.einsum("bii->b", pressure_dev)
    return pressure_hydro, pressure_dev, trace_pressure_dev


def _crescale_triclinic_barostat_step(
    state: NPTCRescaleState,
    kT: torch.Tensor,
    dt: torch.Tensor,
    external_pressure: torch.Tensor,
) -> NPTCRescaleState:
    """Apply one full-dt triclinic stochastic cell-rescaling barostat update.

    Step 1 propagates sqrt(V) for dt/2. Step 2 builds a deformation matrix from
    the traceless part of the instantaneous pressure tensor plus traceless
    noise, and brings it to lower-triangular form with ``rotate_gram_schmidt``.
    Step 3 propagates sqrt(V) for another dt/2. Positions, momenta and cell are
    then rescaled.

    When ``external_pressure.ndim >= 2`` it is treated as a full external stress
    tensor ([3, 3], which is broadcast to every system, or [n_systems, 3, 3]).
    The hydrostatic target is then replaced by ``pressure_hydro +
    Tr(pressure_dev) / 3``, and the deformation is driven by
    ``P_int - pressure_dev``.

    Positions are drifted with ``(vscaling + rscaling) @ p * dt / (2 m)``, so
    this step also performs the position update of the integrator.

    Args:
        state: C-rescale state. Positions, momenta and cell are reassigned on it.
        kT: Target thermal energy.
        dt: Full timestep.
        external_pressure: Target pressure (scalar / per-system) or an external
            stress tensor.

    Returns:
        The same ``state`` object after the update.

    Note:
        Random numbers are drawn from ``state.rng`` in a fixed order: sqrt(V)
        noise, then the 3x3 noise matrix, then sqrt(V) noise. Reordering these
        draws changes seeded trajectories.
    """
    volume = torch.det(state.cell)  # shape: (n_systems,)
    P_int = ts.quantities.compute_instantaneous_pressure_tensor(
        momenta=state.momenta,
        masses=state.masses,
        system_idx=state.system_idx,
        stress=state.stress,
        volumes=volume,
    )
    sqrt_vol = torch.sqrt(volume)
    trace_P_int = torch.einsum("bii->b", P_int)
    # Noise amplitude for a dt/2 sqrt(V) sub-step
    prefactor_random = torch.sqrt(
        kT * state.isothermal_compressibility * dt / (4 * state.tau_p)
    )
    prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p)

    # Deviatoric correction for non-hydrostatic external stress
    deviatoric = external_pressure.ndim >= 2
    if deviatoric:
        # Expand [3,3] -> [n_systems, 3, 3] if needed
        ext_p_tensor = external_pressure
        if ext_p_tensor.ndim == 2:
            ext_p_tensor = ext_p_tensor.unsqueeze(0).expand(state.n_systems, -1, -1)
        pressure_hydro, pressure_dev, trace_pressure_dev = _compute_deviatoric_correction(
            cell=state.cell,
            volume=volume,
            initial_cell_inv=state.initial_cell_inv,
            initial_volume=state.initial_volume,
            external_pressure_tensor=ext_p_tensor,
        )
        effective_p_ext = pressure_hydro + trace_pressure_dev / 3
    else:
        effective_p_ext = external_pressure

    ## Step 1: propagate sqrt(volume) for dt/2
    change_sqrt_vol = -prefactor * (
        effective_p_ext - trace_P_int / 3 - kT / (2 * volume)
    ) * dt / 2 + prefactor_random * _randn_for_state(state, sqrt_vol.shape)
    new_sqrt_volume = sqrt_vol + change_sqrt_vol

    ## Step 2: compute deformation matrix
    random_coeff = 2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p)
    prefactor_random_matrix = torch.sqrt(random_coeff) / new_sqrt_volume
    I = torch.eye(  # noqa: E741
        3, device=state.positions.device, dtype=state.positions.dtype
    ).expand_as(P_int)
    # Driving force: traceless part of (P_int - pressure_dev)
    P_drive = P_int
    if deviatoric:
        P_drive = P_int - pressure_dev
    trace_P_drive = torch.einsum("bii->b", P_drive)
    a_tilde = (state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * (
        P_drive - trace_P_drive[:, None, None] / 3 * I
    )
    random_matrix = torch.randn(
        state.n_systems,
        3,
        3,
        device=state.positions.device,
        dtype=state.positions.dtype,
        generator=state.rng,
    )
    # Remove the trace so the noise only changes the shape, not the volume
    random_matrix_tilde = (
        random_matrix - torch.einsum("bii->b", random_matrix)[:, None, None] / 3 * I
    )
    deformation_matrix = torch.matrix_exp(
        a_tilde * dt + prefactor_random_matrix[:, None, None] * random_matrix_tilde
    )
    deformation_matrix = rotate_gram_schmidt(deformation_matrix)

    ## Step 3: propagate sqrt(volume) for dt/2
    # NOTE(review): this reuses `volume`, `prefactor` and `trace_P_int` from
    # before step 1 (not the updated sqrt volume); confirm against the reference.
    new_sqrt_volume += -prefactor * (
        effective_p_ext - trace_P_int / 3 - kT / (2 * volume)
    ) * dt / 2 + prefactor_random * _randn_for_state(state, sqrt_vol.shape)

    # Total scaling = shape deformation x isotropic (V'/V)^(1/3) factor
    rscaling = deformation_matrix * torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).view(
        -1, 1, 1
    )
    vscaling = torch.inverse(rscaling).transpose(-2, -1)

    # Update positions and momenta (barostat + half momentum step)
    state.positions = batch_matrix_vector(
        rscaling[state.system_idx], state.positions
    ) + batch_matrix_vector(
        (vscaling + rscaling)[state.system_idx], state.momenta
    ) * dt / (2 * state.masses.unsqueeze(-1))
    state.momenta = batch_matrix_vector(vscaling[state.system_idx], state.momenta)

    # Right multiply: cell @ rscaling^T preserves fractional coordinates
    # NOTE(review): with positions mapped r -> rscaling @ r, `cell @ rscaling^T`
    # keeps fractional coordinates only if cell ROWS are lattice vectors.
    # `_crescale_anisotropic_barostat_step` below left-multiplies the cell
    # instead (the column-vector choice); confirm which matches `state.cell`.
    state.cell = state.cell @ rscaling.mT
    return state


def _crescale_anisotropic_barostat_step(
    state: NPTCRescaleState,
    kT: torch.Tensor,
    dt: torch.Tensor,
    external_pressure: torch.Tensor,
) -> NPTCRescaleState:
    """Apply one full-dt C-rescale barostat update with per-axis (diagonal) scaling.

    Uses the same dt/2 + dt/2 sqrt(V) propagation as the triclinic variant. The
    deformation is a length-3 vector ``exp(a_tilde * dt + noise)`` built from the
    traceless diagonal of the instantaneous pressure tensor and zero-mean noise,
    so its product is 1 and the total volume change comes only from the
    ``(V'/V)^(1/3)`` factor. ``external_pressure`` enters only through the scalar
    sqrt(V) update; there is no deviatoric-stress branch here.

    Args:
        state: C-rescale state. Positions, momenta and cell are reassigned on it.
        kT: Target thermal energy.
        dt: Full timestep.
        external_pressure: Target (hydrostatic) pressure.

    Returns:
        The same ``state`` object after the update.

    Note:
        Random numbers are drawn from ``state.rng`` in a fixed order; reordering
        the draws changes seeded trajectories.
    """
    volume = torch.det(state.cell)  # shape: (n_systems,)
    P_int = ts.quantities.compute_instantaneous_pressure_tensor(
        momenta=state.momenta,
        masses=state.masses,
        system_idx=state.system_idx,
        stress=state.stress,
        volumes=volume,
    )
    sqrt_vol = torch.sqrt(volume)
    trace_P_int = torch.einsum("bii->b", P_int)
    prefactor_random = torch.sqrt(
        kT * state.isothermal_compressibility * dt / (4 * state.tau_p)
    )
    prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p)
    # Step 1: propagate sqrt(volume) for dt/2
    change_sqrt_vol = -prefactor * (
        external_pressure - trace_P_int / 3 - kT / (2 * volume)
    ) * dt / 2 + prefactor_random * _randn_for_state(state, sqrt_vol.shape)
    new_sqrt_volume = sqrt_vol + change_sqrt_vol

    ## Step 2: compute deformation matrix
    prefactor_random_matrix = (
        torch.sqrt(2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p))
        / new_sqrt_volume
    )

    # Note: it corresponds to using a diagonal isothermal compressibility tensor
    P_int_diagonal = torch.diagonal(P_int, dim1=-2, dim2=-1)
    a_tilde = (state.isothermal_compressibility / (3 * state.tau_p))[:, None] * (
        P_int_diagonal - trace_P_int[:, None] / 3
    )
    random_matrix = torch.randn(
        state.n_systems,
        3,
        device=state.positions.device,
        dtype=state.positions.dtype,
        generator=state.rng,
    )
    # Zero-mean noise keeps the product of the three factors equal to 1
    random_matrix_tilde = random_matrix - torch.mean(random_matrix, dim=1, keepdim=True)
    deformation_matrix = torch.exp(
        a_tilde * dt + prefactor_random_matrix[:, None] * random_matrix_tilde
    )

    ## Step 3: propagate sqrt(volume) for dt/2
    new_sqrt_volume += -prefactor * (
        external_pressure - trace_P_int / 3 - kT / (2 * volume)
    ) * dt / 2 + prefactor_random * _randn_for_state(state, sqrt_vol.shape)

    # Per-axis scaling factors, shape [n_systems, 3]
    rscaling = deformation_matrix * torch.pow(
        (new_sqrt_volume / sqrt_vol), 2 / 3
    ).unsqueeze(-1)

    # Update positions and momenta (barostat + half momentum step)
    state.positions = rscaling[state.system_idx] * state.positions + (
        rscaling + 1 / rscaling
    )[state.system_idx] * state.momenta * dt / (2 * state.masses.unsqueeze(-1))
    state.momenta = (1 / rscaling)[state.system_idx] * state.momenta
    # Left multiply scales the Cartesian rows of the cell (column-vector layout);
    # see the NOTE(review) on the triclinic step about the differing conventions.
    state.cell = torch.diag_embed(rscaling) @ state.cell
    return state
def compute_average_pressure_tensor(
    *,
    degrees_of_freedom: torch.Tensor,
    kT: torch.Tensor,
    stress: torch.Tensor,
    volumes: torch.Tensor,
) -> torch.Tensor:
    """Compute the internal pressure tensor using the equipartition kinetic term.

    Instead of the instantaneous kinetic energy, the kinetic contribution is
    taken as the equipartition value: an isotropic tensor whose diagonal
    entries equal ``degrees_of_freedom * kT / volumes``. The stress tensor is
    then subtracted from it.

    Args:
        degrees_of_freedom (torch.Tensor): Degrees of freedom entering the
            kinetic term, shape (n_systems,)
        kT (torch.Tensor): Thermal energy (k_B * T), shape (n_systems,) or scalar
        stress (torch.Tensor): Stress tensor of the system, shape (n_systems, 3, 3)
        volumes (torch.Tensor): Volumes of the systems, shape (n_systems,)

    Returns:
        torch.Tensor: Internal pressure tensor [n_systems, 3, 3]
    """
    identity = torch.eye(3, device=stress.device, dtype=stress.dtype)
    kinetic_diag = degrees_of_freedom * kT / volumes  # shape: (n_systems,)
    kinetic_tensor = kinetic_diag[:, None, None] * identity.expand(stress.shape[0], 3, 3)
    return kinetic_tensor - stress
def _crescale_triclinic_average_barostat_step(
    state: NPTCRescaleState,
    kT: torch.Tensor,
    dt: torch.Tensor,
    external_pressure: torch.Tensor,
) -> NPTCRescaleState:
    """Apply one full-dt triclinic C-rescale update using the equipartition pressure.

    Uses the same sub-steps as the instantaneous triclinic barostat: sqrt(V) for
    dt/2, a traceless deformation matrix (made lower-triangular with
    ``rotate_gram_schmidt``), then sqrt(V) for dt/2. The internal pressure comes
    from ``compute_average_pressure_tensor`` with
    ``get_number_of_degrees_of_freedom() / 3`` rather than from the
    instantaneous kinetic tensor. Only positions and the cell are rescaled;
    momenta are not modified here. ``external_pressure`` is used directly in the
    scalar sqrt(V) update (no deviatoric-stress branch).

    Args:
        state: C-rescale state. Positions and cell are reassigned on it.
        kT: Target thermal energy.
        dt: Full timestep.
        external_pressure: Target (hydrostatic) pressure.

    Returns:
        The same ``state`` object after the update.

    Note:
        Random numbers are drawn from ``state.rng`` in a fixed order; reordering
        the draws changes seeded trajectories.
    """
    volume = torch.det(state.cell)  # shape: (n_systems,)
    # Kinetic part from equipartition: diagonal = (dof / 3) * kT / V
    P_int = compute_average_pressure_tensor(
        degrees_of_freedom=state.get_number_of_degrees_of_freedom() / 3,
        kT=kT,
        stress=state.stress,
        volumes=volume,
    )
    sqrt_vol = torch.sqrt(volume)
    trace_P_int = torch.einsum("bii->b", P_int)
    prefactor_random = torch.sqrt(
        kT * state.isothermal_compressibility * dt / (4 * state.tau_p)
    )
    prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p)
    # Step 1: propagate sqrt(volume) for dt/2
    change_sqrt_vol = -prefactor * (
        external_pressure - trace_P_int / 3 - kT / (2 * volume)
    ) * dt / 2 + prefactor_random * _randn_for_state(state, sqrt_vol.shape)
    new_sqrt_volume = sqrt_vol + change_sqrt_vol

    ## Step 2: compute deformation matrix
    prefactor_random_matrix = (
        torch.sqrt(2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p))
        / new_sqrt_volume
    )
    a_tilde = (state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * (
        P_int
        - trace_P_int[:, None, None]
        / 3
        * torch.eye(
            3, device=state.positions.device, dtype=state.positions.dtype
        ).expand_as(P_int)
    )
    random_matrix = torch.randn(
        state.n_systems,
        3,
        3,
        device=state.positions.device,
        dtype=state.positions.dtype,
        generator=state.rng,
    )
    # Traceless noise: shape fluctuations only, volume handled by sqrt(V) steps
    random_matrix_tilde = random_matrix - torch.einsum("bii->b", random_matrix)[
        :, None, None
    ] / 3 * torch.eye(
        3, device=state.positions.device, dtype=state.positions.dtype
    ).expand_as(random_matrix)
    deformation_matrix = torch.matrix_exp(
        a_tilde * dt + prefactor_random_matrix[:, None, None] * random_matrix_tilde
    )
    deformation_matrix = rotate_gram_schmidt(deformation_matrix)

    ## Step 3: propagate sqrt(volume) for dt/2
    new_sqrt_volume += -prefactor * (
        external_pressure - trace_P_int / 3 - kT / (2 * volume)
    ) * dt / 2 + prefactor_random * _randn_for_state(state, sqrt_vol.shape)

    rscaling = deformation_matrix * torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).view(
        -1, 1, 1
    )

    # Update positions and momenta (barostat + half momentum step)
    # Drift uses (I + rscaling) because momenta are not rescaled in this variant
    state.positions = batch_matrix_vector(
        rscaling[state.system_idx], state.positions
    ) + batch_matrix_vector(
        (
            torch.eye(
                3, device=state.positions.device, dtype=state.positions.dtype
            ).expand_as(rscaling)
            + rscaling
        )[state.system_idx],
        state.momenta,
    ) * dt / (2 * state.masses.unsqueeze(-1))

    # Right multiply: cell @ rscaling^T preserves fractional coordinates
    # NOTE(review): this holds only if cell ROWS are lattice vectors, whereas
    # `_crescale_anisotropic_barostat_step` left-multiplies the cell; confirm.
    state.cell = state.cell @ rscaling.mT
    return state


def _crescale_isotropic_barostat_step(
    state: NPTCRescaleState,
    kT: torch.Tensor,
    dt: torch.Tensor,
    external_pressure: torch.Tensor,
) -> NPTCRescaleState:
    """Apply one full-dt isotropic C-rescale update.

    Performs a single dt update of sqrt(V). The deterministic drift is taken
    over the full ``dt``, and the noise amplitude is ``sqrt(2)`` times the
    half-step prefactor. The resulting factor ``rscaling = (V'/V)^(1/3)`` is
    then applied to positions, momenta and the cell.

    Args:
        state: C-rescale state. Positions, momenta and cell are reassigned on it.
        kT: Target thermal energy.
        dt: Full timestep.
        external_pressure: Target (hydrostatic) pressure.

    Returns:
        The same ``state`` object after the update.
    """
    volume = torch.det(state.cell)  # shape: (n_systems,)
    P_int = ts.quantities.compute_instantaneous_pressure_tensor(
        momenta=state.momenta,
        masses=state.masses,
        system_idx=state.system_idx,
        stress=state.stress,
        volumes=volume,
    )
    sqrt_vol = torch.sqrt(volume)
    trace_P_int = torch.einsum("bii->b", P_int)
    prefactor_random = torch.sqrt(
        kT * state.isothermal_compressibility * dt / (4 * state.tau_p)
    )
    prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p)
    # Full-dt step: sqrt(2) * prefactor_random = sqrt(kT * beta_T * dt / (2 * tau_p))
    change_sqrt_vol = -prefactor * (
        external_pressure - trace_P_int / 3 - kT / (2 * volume)
    ) * dt + torch.sqrt(
        2 * torch.ones_like(sqrt_vol)
    ) * prefactor_random * _randn_for_state(state, sqrt_vol.shape)
    new_sqrt_volume = sqrt_vol + change_sqrt_vol

    # Update positions and momenta (barostat + half momentum step)
    # SI (S13ab): notice there is a typo in the SI where q_i(t)
    # should be scaled as well by rscaling
    rscaling = torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).unsqueeze(-1)
    state.positions = rscaling[state.system_idx] * state.positions + (
        rscaling + 1 / rscaling
    )[state.system_idx] * state.momenta * (0.5 * dt) / state.masses.unsqueeze(-1)
    state.momenta = (1 / rscaling)[state.system_idx] * state.momenta
    rscaling = rscaling.unsqueeze(-1)  # make [n_systems, 1, 1]
    state.cell = rscaling * state.cell
    return state


def _coerce_crescale_step_inputs(
    state: NPTCRescaleState,
    dt: float | torch.Tensor,
    kT: float | torch.Tensor,
    external_pressure: float | torch.Tensor,
    tau: float | torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Normalize scalar-or-tensor C-rescale step parameters to state tensors.

    Every value is converted with ``torch.as_tensor`` on ``state.device`` /
    ``state.dtype``.

    Returns:
        ``(dt, kT, external_pressure, tau)`` as tensors. When ``tau`` is None it
        defaults to ``dt`` (``1 * dt_tensor``).
    """
    device, dtype = state.device, state.dtype
    dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype)
    kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype)
    external_pressure_tensor = torch.as_tensor(
        external_pressure, device=device, dtype=dtype
    )
    # `is None` check (not `or`) so per-system tensors are accepted
    tau_tensor = torch.as_tensor(
        1 * dt_tensor if tau is None else tau, device=device, dtype=dtype
    )
    return dt_tensor, kT_tensor, external_pressure_tensor, tau_tensor
@dcite("10.1063/5.0020514")
@dcite("10.3390/app12031139")
def npt_crescale_triclinic_step(
    state: NPTCRescaleState,
    model: ModelInterface,
    *,
    dt: float | torch.Tensor,
    kT: float | torch.Tensor,
    external_pressure: float | torch.Tensor,
    tau: float | torch.Tensor | None = None,
) -> NPTCRescaleState:
    r"""Perform one NPT integration step with anisotropic stochastic cell rescaling.

    Implements the anisotropic C-Rescale barostat from Del Tatto et al. (2022) [7]_
    extending the isotropic scheme of Bernetti & Bussi (2020) [6]_. Cell lengths
    and angles can change independently. Uses instantaneous kinetic energy. Both
    positions and momenta are scaled.

    **Trotter splitting:**

        V-Rescale(dt/2) -> B(dt/2) -> Barostat(dt) -> Force eval
        -> B(dt/2) -> V-Rescale(dt/2)

    **Barostat sub-steps** (3-step volume + deformation update):

    Step 1 -- Propagate :math:`\sqrt{V}` for :math:`\Delta t/2` (same SDE as
    isotropic, Eq. 7 of [6]_):

    .. math::

        \Delta\lambda = -\frac{\beta_T\lambda}{2\tau_p}
        \left(P_0 - \frac{\text{Tr}(\mathbf{P}_{\text{int}})}{3}
        - \frac{k_BT}{2V}\right)\frac{\Delta t}{2}
        + \sqrt{\frac{k_BT\beta_T\Delta t}{4\tau_p}}\;R

    Step 2 -- Compute the deviatoric deformation matrix:

    .. math::

        \tilde{\mathbf{A}} &= \frac{\beta_T}{3\tau_p}
        \left(\mathbf{P}_{\text{int}}
        - \frac{\text{Tr}(\mathbf{P}_{\text{int}})}{3}\,\mathbf{I}\right) \\
        \boldsymbol{\mu}_{\text{dev}} &= \exp\bigl(\tilde{\mathbf{A}}\,\Delta t
        + \sigma\,\tilde{\mathbf{R}}\bigr)

    where :math:`\sigma = \sqrt{2\beta_T k_BT\Delta t/(3\tau_p)}\;/\;\sqrt{V'}`
    and :math:`\tilde{\mathbf{R}}` is a traceless random matrix. The result is
    brought to lower-triangular form with ``rotate_gram_schmidt``.

    Step 3 -- Propagate :math:`\sqrt{V}` for :math:`\Delta t/2` (same as step 1).

    **Total scaling and update:**

    .. math::

        \boldsymbol{\mu} &= \boldsymbol{\mu}_{\text{dev}} \cdot (V'/V)^{1/3} \\
        \mathbf{r}_i &\leftarrow \boldsymbol{\mu}\,\mathbf{r}_i
        + (\boldsymbol{\mu}^{-T} + \boldsymbol{\mu})\,
        \frac{\mathbf{p}_i}{2m_i}\,\Delta t \\
        \mathbf{p}_i &\leftarrow \boldsymbol{\mu}^{-T}\,\mathbf{p}_i \\
        \mathbf{h} &\leftarrow \mathbf{h}\,\boldsymbol{\mu}^T

    **Non-hydrostatic target:** if ``external_pressure`` is a ``[3, 3]`` or
    ``[n_systems, 3, 3]`` stress tensor, :math:`P_0` is replaced by its
    hydrostatic part plus a deviatoric correction projected through the cell
    stored at initialization. The deformation is then driven by
    :math:`\mathbf{P}_{\text{int}} - \mathbf{P}_{\text{dev}}`.

    **Variable mapping (equation -> code):**

    - :math:`V` (volume) -> ``volume``
    - :math:`\lambda` (:math:`\sqrt{V}`) -> ``sqrt_vol``
    - :math:`\beta_T` (compressibility) -> ``state.isothermal_compressibility``
    - :math:`\tau_p` (barostat relax. time) -> ``state.tau_p``
    - :math:`P_0` (target pressure) -> ``external_pressure``
    - :math:`\mathbf{P}_{\text{int}}` (pressure tensor) -> ``P_int``
    - :math:`\tilde{\mathbf{A}}` (deviatoric drive) -> ``a_tilde``
    - :math:`\boldsymbol{\mu}_{\text{dev}}` -> ``deformation_matrix``
    - :math:`\boldsymbol{\mu}` (total scaling) -> ``rscaling``
    - :math:`\boldsymbol{\mu}^{-T}` (momentum scaling) -> ``vscaling``
    - :math:`\tilde{\mathbf{R}}` (traceless noise) -> ``random_matrix_tilde``
    - :math:`\sigma` (noise prefactor) -> ``prefactor_random_matrix``
    - :math:`k_BT` (thermal energy) -> ``kT``
    - :math:`\Delta t` (timestep) -> ``dt``
    - :math:`\tau` (thermostat relax. time) -> ``tau`` (V-Rescale)

    Args:
        state: Current system state
        model: Model to compute forces and energies
        dt: Integration timestep
        kT: Target temperature
        external_pressure: Target external pressure. Either a scalar / per-system
            value, or a ``[3, 3]`` / ``[n_systems, 3, 3]`` external stress tensor.
        tau: V-Rescale thermostat relaxation time. If None, defaults to ``dt``
            (see ``_coerce_crescale_step_inputs``).

    Returns:
        NPTCRescaleState: Updated state after one integration step

    References:
        .. [7] Del Tatto, V., et al. "Molecular dynamics of solids at constant
           pressure and stress using anisotropic stochastic cell rescaling."
           Applied Sciences 12(3), 1139 (2022).
        .. [6] Bernetti, M. & Bussi, G. "Pressure control using stochastic cell
           rescaling." J. Chem. Phys. 153, 114107 (2020).
    """
    dt_tensor, kT_tensor, external_pressure_tensor, tau_tensor = (
        _coerce_crescale_step_inputs(state, dt, kT, external_pressure, tau)
    )

    # First half of the thermostat and of the force kick
    state = _vrescale_update(state, tau_tensor, kT_tensor, dt_tensor / 2)
    state = momentum_step(state, dt_tensor / 2)

    # Barostat step (also drifts positions using the half-kicked momenta)
    state = _crescale_triclinic_barostat_step(
        state, kT_tensor, dt_tensor, external_pressure_tensor
    )

    # Forces, energy and stress at the new positions and cell
    model_output = model(state)
    state.forces = model_output["forces"]
    state.energy = model_output["energy"]
    state.stress = model_output["stress"]
    state.store_model_extras(model_output)

    # Final momentum step
    state = momentum_step(state, dt_tensor / 2)

    # Final thermostat step
    return _vrescale_update(state, tau_tensor, kT_tensor, dt_tensor / 2)
@dcite("10.1063/5.0020514")
@dcite("10.3390/app12031139")
def npt_crescale_anisotropic_step(
    state: NPTCRescaleState,
    model: ModelInterface,
    *,
    dt: float | torch.Tensor,
    kT: float | torch.Tensor,
    external_pressure: float | torch.Tensor,
    tau: float | torch.Tensor | None = None,
) -> NPTCRescaleState:
    """Perform one NPT integration step with a per-axis cell rescaling barostat.

    This function performs a single integration step for NPT dynamics using a
    cell rescaling barostat. It updates particle positions, momenta, and the
    simulation cell based on the target temperature and pressure.

    Trotter based splitting:
    1. Half Thermostat (velocity scaling)
    2. Half Update momenta with forces
    3. Barostat (cell rescaling), which also updates positions from the
       half-updated momenta
    4. Compute forces, energy and stress at the new positions and cell
    5. Half Update momenta with forces
    6. Half Thermostat (velocity scaling)

    Only isotropic external pressure is supported (a scalar or one value per
    system). The cell is rescaled with three independent factors, one per
    Cartesian axis. Their product is set by the volume update.

    Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp

    - Time reversible integrator
    - Instantaneous kinetic energy (not the average from equipartition)

    Args:
        state (NPTCRescaleState): Current system state
        model (ModelInterface): Model to compute forces and energies
        dt (torch.Tensor): Integration timestep
        kT (torch.Tensor): Target temperature
        external_pressure (torch.Tensor): Target external pressure
        tau (torch.Tensor | None): V-Rescale thermostat relaxation time, scalar
            or one value per system. If None, defaults to 100*dt.

    Returns:
        NPTCRescaleState: Updated state after one integration step
    """
    device, dtype = model.device, model.dtype
    dt = torch.as_tensor(dt, device=device, dtype=dtype)
    kT = torch.as_tensor(kT, device=device, dtype=dtype)
    external_pressure = torch.as_tensor(external_pressure, device=device, dtype=dtype)
    # Note: would probably be better to have tau in NVTCRescaleState.
    # Compare against None explicitly: `tau or ...` raises for multi-element
    # (per-system) tensors and silently replaces an explicit zero.
    tau = torch.as_tensor(100 * dt if tau is None else tau, device=device, dtype=dtype)

    state = _vrescale_update(state, tau, kT, dt / 2)

    state = momentum_step(state, dt / 2)

    # Barostat step
    state = _crescale_anisotropic_barostat_step(state, kT, dt, external_pressure)

    # Forces
    model_output = model(state)
    state.forces = model_output["forces"]
    state.energy = model_output["energy"]
    state.stress = model_output["stress"]
    state.store_model_extras(model_output)

    # Final momentum step
    state = momentum_step(state, dt / 2)

    # Final thermostat step
    return _vrescale_update(state, tau, kT, dt / 2)
@dcite("10.1063/5.0020514")
@dcite("10.3390/app12031139")
def npt_crescale_triclinic_average_step(
    state: NPTCRescaleState,
    model: ModelInterface,
    *,
    dt: float | torch.Tensor,
    kT: float | torch.Tensor,
    external_pressure: float | torch.Tensor,
    tau: float | torch.Tensor | None = None,
) -> NPTCRescaleState:
    """Perform one NPT integration step with a triclinic cell rescaling barostat.

    This function performs a single integration step for NPT dynamics using a
    cell rescaling barostat. It updates particle positions, momenta, and the
    simulation cell based on the target temperature and pressure.

    Trotter based splitting:
    1. Half Thermostat (velocity scaling)
    2. Half Update momenta with forces
    3. Barostat (cell rescaling), which also updates positions from the
       half-updated momenta
    4. Compute forces, energy and stress at the new positions and cell
    5. Half Update momenta with forces
    6. Half Thermostat (velocity scaling)

    Only isotropic external pressure is supported. Cell lengths and angles can
    change independently. The internal pressure uses the average kinetic energy
    from the equipartition theorem. Only positions (not momenta) are scaled when
    the cell is rescaled.

    Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp

    - Time reversible integrator
    - Average kinetic energy, scaling only positions

    Args:
        state (NPTCRescaleState): Current system state
        model (ModelInterface): Model to compute forces and energies
        dt (torch.Tensor): Integration timestep
        kT (torch.Tensor): Target temperature
        external_pressure (torch.Tensor): Target external pressure
        tau (torch.Tensor | None): V-Rescale thermostat relaxation time, scalar
            or one value per system. If None, defaults to dt.

    Returns:
        NPTCRescaleState: Updated state after one integration step
    """
    device, dtype = model.device, model.dtype
    dt = torch.as_tensor(dt, device=device, dtype=dtype)
    kT = torch.as_tensor(kT, device=device, dtype=dtype)
    external_pressure = torch.as_tensor(external_pressure, device=device, dtype=dtype)
    # Note: would probably be better to have tau in NVTCRescaleState.
    # Compare against None explicitly: `tau or ...` raises for multi-element
    # (per-system) tensors and silently replaces an explicit zero.
    tau = torch.as_tensor(1 * dt if tau is None else tau, device=device, dtype=dtype)

    state = _vrescale_update(state, tau, kT, dt / 2)

    state = momentum_step(state, dt / 2)

    # Barostat step
    state = _crescale_triclinic_average_barostat_step(state, kT, dt, external_pressure)

    # Forces
    model_output = model(state)
    state.forces = model_output["forces"]
    state.energy = model_output["energy"]
    state.stress = model_output["stress"]
    state.store_model_extras(model_output)

    # Final momentum step
    state = momentum_step(state, dt / 2)

    # Final thermostat step
    return _vrescale_update(state, tau, kT, dt / 2)
@dcite("10.1063/5.0020514")
def npt_crescale_isotropic_step(
    state: NPTCRescaleState,
    model: ModelInterface,
    *,
    dt: float | torch.Tensor,
    kT: float | torch.Tensor,
    external_pressure: float | torch.Tensor,
    tau: float | torch.Tensor | None = None,
) -> NPTCRescaleState:
    r"""Perform one NPT integration step with isotropic stochastic cell rescaling.

    Implements isotropic C-Rescale from Bernetti & Bussi (2020) [6]_. Cell shape
    is preserved; cell lengths are scaled equally.

    **Trotter splitting:**

        V-Rescale(dt/2) -> B(dt/2) -> Barostat(dt) -> Force eval
        -> B(dt/2) -> V-Rescale(dt/2)

    **Isotropic volume SDE** (Eq. 7 of [6]_, using :math:`\lambda = \sqrt{V}`):

    .. math::

        d\lambda = -\frac{\beta_T\lambda}{2\tau_p}
        \left(P_0 - \frac{\text{Tr}(\mathbf{P}_{\text{int}})}{3}
        - \frac{k_BT}{2V}\right) dt
        + \sqrt{\frac{k_BT\,\beta_T}{2\tau_p}}\;dW

    where :math:`\beta_T` is the isothermal compressibility and
    :math:`\mathbf{P}_{\text{int}}` is the instantaneous pressure tensor
    (including the kinetic contribution).

    **Position and momentum scaling** (SI Eqs. S13a-b of [6]_, corrected):

    .. math::

        \mathbf{r}_i &\leftarrow \mu\,\mathbf{r}_i
        + (\mu + \mu^{-1})\,\frac{\mathbf{p}_i}{2m_i}\,\Delta t \\
        \mathbf{p}_i &\leftarrow \mu^{-1}\,\mathbf{p}_i \\
        \mathbf{h} &\leftarrow \mu\,\mathbf{h}

    where :math:`\mu = (V'/V)^{1/3}` is the isotropic scaling factor and
    :math:`\mathbf{h}` is the cell matrix.

    **Variable mapping (equation -> code):**

    - :math:`V` (volume) -> ``volume``
    - :math:`\lambda` (:math:`\sqrt{V}`) -> ``sqrt_vol``
    - :math:`\beta_T` (compressibility) -> ``state.isothermal_compressibility``
    - :math:`\tau_p` (barostat relax. time) -> ``state.tau_p``
    - :math:`P_0` (target pressure) -> ``external_pressure``
    - :math:`\mathbf{P}_{\text{int}}` (pressure tensor) -> ``P_int``
    - :math:`\text{Tr}(\mathbf{P}_{\text{int}})` -> ``trace_P_int``
    - :math:`\mu` (scaling factor) -> ``rscaling``
    - :math:`k_BT` (thermal energy) -> ``kT``
    - :math:`\Delta t` (timestep) -> ``dt``
    - :math:`\tau` (thermostat relax. time) -> ``tau`` (V-Rescale)

    Args:
        state: Current system state
        model: Model to compute forces and energies
        dt: Integration timestep
        kT: Target temperature
        external_pressure: Target external pressure
        tau: V-Rescale thermostat relaxation time, scalar or one value per
            system. If None, defaults to dt.

    Returns:
        NPTCRescaleState: Updated state after one integration step

    References:
        .. [6] Bernetti, M. & Bussi, G. "Pressure control using stochastic cell
           rescaling." J. Chem. Phys. 153, 114107 (2020).
           Note: SI Eq. S13a has a typo (positions must also be scaled by mu).
    """
    device, dtype = model.device, model.dtype
    dt = torch.as_tensor(dt, device=device, dtype=dtype)
    kT = torch.as_tensor(kT, device=device, dtype=dtype)
    external_pressure = torch.as_tensor(external_pressure, device=device, dtype=dtype)
    # Note: would probably be better to have tau in NVTCRescaleState.
    # Compare against None explicitly: `tau or ...` raises for multi-element
    # (per-system) tensors and silently replaces an explicit zero.
    tau = torch.as_tensor(1 * dt if tau is None else tau, device=device, dtype=dtype)

    state = _vrescale_update(state, tau, kT, dt / 2)

    state = momentum_step(state, dt / 2)

    # Barostat step
    state = _crescale_isotropic_barostat_step(state, kT, dt, external_pressure)

    # Forces
    model_output = model(state)
    state.forces = model_output["forces"]
    state.energy = model_output["energy"]
    state.stress = model_output["stress"]
    state.store_model_extras(model_output)

    # Final momentum step
    state = momentum_step(state, dt / 2)

    # Final thermostat step
    return _vrescale_update(state, tau, kT, dt / 2)
def npt_crescale_init(
    state: SimState,
    model: ModelInterface,
    *,
    kT: float | torch.Tensor,
    dt: float | torch.Tensor,
    tau_p: float | torch.Tensor | None = None,
    isothermal_compressibility: float | torch.Tensor | None = None,
) -> NPTCRescaleState:
    """Initialize the NPT cell rescaling state.

    This function initializes a state for NPT molecular dynamics with a cell
    rescaling barostat. It evaluates the model once to obtain energy, forces and
    stress. It draws momenta if ``state`` has none, and stores a copy of the
    current cell (with its inverse and volume) as the reference frame for the
    deviatoric-stress correction.

    To seed the RNG set ``state.rng = seed`` before calling.

    Args:
        state: Initial system state as SimState containing positions, masses,
            cell, and PBC information
        model (ModelInterface): Model to compute forces and energies
        kT: Target temperature in energy units, used to draw initial momenta
        dt: Integration timestep (only used for the ``tau_p`` default)
        tau_p: Barostat relaxation time, scalar or one value per system. Controls
            how quickly pressure equilibrates. Defaults to ``3 * dt``.
        isothermal_compressibility: Isothermal compressibility of the system,
            scalar or one value per system. Defaults to
            ``1e-6 / MetalUnits.pressure`` (1e-6 bar^-1).

    Returns:
        NPTCRescaleState: Initialized state with per-system ``tau_p`` and
        ``isothermal_compressibility`` of shape [n_systems].
    """
    device, dtype = model.device, model.dtype

    # Convert all parameters to tensors with correct device and dtype
    dt = torch.as_tensor(dt, device=device, dtype=dtype)
    kT = torch.as_tensor(kT, device=device, dtype=dtype)

    # Set default values if not provided. Compare against None explicitly:
    # `x or default` raises for multi-element (per-system) tensors.
    if tau_p is None:
        tau_p = 3 * dt  # e.g. 3 fs for dt = 1 fs
    if isothermal_compressibility is None:
        # 1e-6 bar^-1 for metals, expressed in (eV/A^3)^-1
        isothermal_compressibility = 1e-6 / MetalUnits.pressure
    tau_p = torch.as_tensor(tau_p, device=device, dtype=dtype)
    isothermal_compressibility = torch.as_tensor(
        isothermal_compressibility, device=device, dtype=dtype
    )

    # Broadcast scalar parameters to one value per system
    if tau_p.ndim == 0:
        tau_p = tau_p.expand(state.n_systems)
    if isothermal_compressibility.ndim == 0:
        isothermal_compressibility = isothermal_compressibility.expand(state.n_systems)

    # Get model output to initialize forces and stress
    model_output = model(state)

    # Initialize momenta if not provided
    momenta = getattr(state, "momenta", None)
    if momenta is None:
        momenta = initialize_momenta(
            state.positions,
            state.masses,
            state.system_idx,
            kT,
            state.rng,
        )

    # Store initial cell for deviatoric correction
    initial_cell = state.cell.clone()
    initial_cell_inv = torch.inverse(initial_cell)
    initial_volume = torch.det(initial_cell)

    # Create the initial state
    npt_state = NPTCRescaleState.from_state(
        state,
        momenta=momenta,
        energy=model_output["energy"],
        forces=model_output["forces"],
        stress=model_output["stress"],
        tau_p=tau_p,
        isothermal_compressibility=isothermal_compressibility,
        initial_cell=initial_cell,
        initial_cell_inv=initial_cell_inv,
        initial_volume=initial_volume,
    )
    npt_state.store_model_extras(model_output)
    return npt_state