Source code for torch_sim.models.electrostatics

"""Electrostatics models: DSF, Ewald, and PME.

Wraps the ``nvalchemiops`` Warp-accelerated electrostatics implementations as
:class:`~torch_sim.models.interface.ModelInterface` subclasses, with full PBC,
stress (virial), and batched system support.  Per-atom partial charges are read
from ``state.partial_charges`` (a SimState atom extra).
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from nvalchemiops.torch.interactions.electrostatics import (
    dsf_coulomb,
    ewald_summation,
    particle_mesh_ewald,
)

from torch_sim._duecredit import dcite
from torch_sim.models.interface import ModelInterface
from torch_sim.neighbors import torchsim_nl
from torch_sim.units import UnitConversion


if TYPE_CHECKING:
    from collections.abc import Callable

    from torch_sim.state import SimState


def _zero_result(
    state: SimState,
    dtype: torch.dtype,
    compute_forces: bool,  # noqa: FBT001
    compute_stress: bool,  # noqa: FBT001
) -> dict[str, torch.Tensor]:
    """Build an all-zero result dict for states without periodic boundaries.

    Args:
        state: State supplying ``n_systems``, ``n_atoms`` and the target device
            (taken from ``state.positions``).
        dtype: Floating-point dtype of every returned tensor.
        compute_forces: When true, add a ``"forces"`` entry of shape
            ``[n_atoms, 3]``.
        compute_stress: When true, add a ``"stress"`` entry of shape
            ``[n_systems, 3, 3]``.

    Returns:
        Dict holding ``"energy"`` (shape ``[n_systems]``) plus the requested
        optional entries; every tensor is filled with zeros.
    """
    # All outputs share dtype and device, so collect the options once.
    tensor_opts = {"dtype": dtype, "device": state.positions.device}

    zeros: dict[str, torch.Tensor] = {
        "energy": torch.zeros(state.n_systems, **tensor_opts),
    }
    if compute_forces:
        zeros["forces"] = torch.zeros((state.n_atoms, 3), **tensor_opts)
    if compute_stress:
        zeros["stress"] = torch.zeros((state.n_systems, 3, 3), **tensor_opts)
    return zeros


def _build_csr(
    state: SimState,
    cutoff: float,
    neighbor_list_fn: Callable,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the neighbor search and package its output in CSR form.

    Args:
        state: State whose ``positions``, ``row_vector_cell``, ``pbc`` and
            ``system_idx`` are forwarded (in that order, with ``cutoff``
            before ``system_idx``) to ``neighbor_list_fn``.
        cutoff: Cutoff passed through to ``neighbor_list_fn``.
        neighbor_list_fn: Callable returning
            ``(edge_index, mapping, unit_shifts)``; the mapping is discarded.

    Returns:
        Tuple ``(edge_index, neighbor_ptr, unit_shifts)``, each cast to
        ``int32``. ``neighbor_ptr`` has ``n_atoms + 1`` entries: a leading zero
        followed by the cumulative count of occurrences of each atom index in
        ``edge_index[0]``.
    """
    pairs, _, shifts = neighbor_list_fn(
        state.positions, state.row_vector_cell, state.pbc, cutoff, state.system_idx
    )

    num_atoms = state.positions.shape[0]
    # Number of edges per atom, keyed on the first row of the pair list.
    # NOTE(review): the resulting offsets only index ``edge_index`` correctly
    # if edges arrive grouped by ``edge_index[0]`` — confirm for custom
    # neighbor_list_fn implementations.
    counts = torch.bincount(pairs[0], minlength=num_atoms)

    row_ptr = torch.zeros(
        num_atoms + 1, dtype=torch.int32, device=state.positions.device
    )
    row_ptr[1:] = torch.cumsum(counts, dim=0).to(torch.int32)

    return pairs.to(torch.int32), row_ptr, shifts.to(torch.int32)


class DSFCoulombModel(ModelInterface):
    """Damped Shifted Force (DSF) electrostatics exposed as a :class:`ModelInterface`.

    Delegates to the ``nvalchemiops`` DSF kernel for O(N) electrostatic energy,
    forces, and (optionally) stress. All user-facing quantities are in metal
    units (Angstrom / eV); the Coulomb constant ``ke`` is baked in. Per-atom
    partial charges are read from ``state.partial_charges``.

    Args:
        cutoff: Real-space cutoff in Angstrom.
        alpha: DSF damping parameter. 0.0 gives shifted-force bare Coulomb.
        device: Compute device. Defaults to CUDA if available, else CPU.
        dtype: Floating-point dtype. Defaults to ``torch.float64``.
        compute_forces: Whether to return forces. Defaults to True.
        compute_stress: Whether to return stress. Defaults to True.
        neighbor_list_fn: Neighbor-list constructor. Defaults to ``torchsim_nl``.
    """

    @dcite("10.1063/1.2206581", description="Fennell & Gezelter DSF method")
    def __init__(
        self,
        cutoff: float = 10.0,
        *,
        alpha: float = 0.2,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float64,
        compute_forces: bool = True,
        compute_stress: bool = True,
        neighbor_list_fn: Callable = torchsim_nl,
    ) -> None:
        """Store the DSF parameters and output options on the model."""
        super().__init__()
        # Physical / algorithmic parameters.
        self.cutoff = cutoff
        self.alpha = alpha
        self.neighbor_list_fn = neighbor_list_fn
        # Output configuration.
        self._device = (
            device
            if device
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self._dtype = dtype
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self._memory_scales_with = "n_atoms_x_density"

    def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]:
        """Evaluate the DSF electrostatic energy and requested derivatives.

        Args:
            state: Simulation state carrying ``partial_charges`` as an atom
                extra (shape ``[n_atoms]``).
            **_kwargs: Ignored; present only for interface compatibility.

        Returns:
            Dict with ``"energy"`` [n_systems], plus ``"forces"`` [n_atoms, 3]
            when forces are enabled and ``"stress"`` [n_systems, 3, 3] when
            stress is enabled. All tensors are detached and cast to the
            model dtype.

        Raises:
            ValueError: If ``state`` has no ``partial_charges`` extra.
        """
        if not state.has_extras("partial_charges"):
            raise ValueError("Partial charges are required for DSF Coulomb summation.")

        charges = state.partial_charges
        pair_list, row_ptr, shifts = _build_csr(
            state, self.cutoff, self.neighbor_list_fn
        )
        cell = state.row_vector_cell.contiguous()

        kernel_out = dsf_coulomb(
            positions=state.positions,
            charges=charges,
            cutoff=self.cutoff,
            alpha=self.alpha,
            neighbor_list=pair_list,
            neighbor_ptr=row_ptr,
            unit_shifts=shifts,
            cell=cell,
            batch_idx=state.system_idx.to(torch.int32),
            compute_forces=self._compute_forces,
            compute_virial=self._compute_stress,
            num_systems=state.n_systems,
        )
        # Normalize a bare-energy return to a tuple so indexing below is uniform.
        if not isinstance(kernel_out, tuple):
            kernel_out = (kernel_out,)

        to_ev = UnitConversion.e2_per_Ang_to_eV
        results: dict[str, torch.Tensor] = {
            "energy": (kernel_out[0] * to_ev).to(self._dtype).detach(),
        }
        if self._compute_forces:
            scaled_forces = kernel_out[1] * to_ev  # type: ignore[index]
            results["forces"] = scaled_forces.to(self._dtype).detach()
        if self._compute_stress:
            # The virial is taken as the last kernel output; divide by the
            # per-system volume, broadcast over the trailing 3x3 axes.
            per_system_volume = state.volume[..., None, None]
            scaled_stress = (kernel_out[-1] * to_ev) / per_system_volume
            results["stress"] = scaled_stress.to(self._dtype).detach()
        return results
class EwaldModel(ModelInterface):
    """Classical Ewald summation as a :class:`ModelInterface`.

    Uses the ``nvalchemiops`` Ewald kernel for exact periodic electrostatics.
    The kernel's per-atom energies are aggregated to per-system energies.
    All user-facing quantities are in metal units (Angstrom / eV). Per-atom
    partial charges are read from ``state.partial_charges``.

    The method is defined for periodic systems; a state with no periodic axis
    is not an error, :meth:`forward` simply returns zeros for it.

    Args:
        cutoff: Real-space cutoff in Angstrom.
        accuracy: Target accuracy for auto-estimated Ewald parameters.
        device: Compute device. Defaults to CUDA if available, else CPU.
        dtype: Floating-point dtype. Defaults to ``torch.float64``.
        compute_forces: Whether to return forces. Defaults to True.
        compute_stress: Whether to return stress. Defaults to True.
        neighbor_list_fn: Neighbor-list constructor. Defaults to ``torchsim_nl``.
    """

    @dcite("10.1002/andp.19213690304", description="Ewald summation method")
    def __init__(
        self,
        cutoff: float = 10.0,
        *,
        accuracy: float = 1e-6,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float64,
        compute_forces: bool = True,
        compute_stress: bool = True,
        neighbor_list_fn: Callable = torchsim_nl,
    ) -> None:
        """Initialize the Ewald model."""
        super().__init__()
        self._device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self._dtype = dtype
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self._memory_scales_with = "n_atoms_x_density"
        self.neighbor_list_fn = neighbor_list_fn
        self.cutoff = cutoff
        self.accuracy = accuracy

    def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]:
        """Compute Ewald electrostatic energy, forces, and stress.

        Args:
            state: Simulation state with ``partial_charges`` set as an atom
                extra (shape ``[n_atoms]``).
            **_kwargs: Unused; accepted for interface compatibility.

        Returns:
            Dict with ``"energy"`` [n_systems], plus ``"forces"`` [n_atoms, 3]
            when forces are enabled and ``"stress"`` [n_systems, 3, 3] when
            stress is enabled. All tensors are detached and cast to the model
            dtype. For a state with no periodic axis every entry is zero.

        Raises:
            ValueError: If ``state`` has no ``partial_charges`` extra.
        """
        if not state.has_extras("partial_charges"):
            raise ValueError("Partial charges are required for Ewald summation.")
        if not state.pbc.any():
            return _zero_result(
                state, self._dtype, self._compute_forces, self._compute_stress
            )

        charges = state.partial_charges
        edge_index, neighbor_ptr, unit_shifts = _build_csr(
            state, self.cutoff, self.neighbor_list_fn
        )
        cell = state.row_vector_cell.contiguous()

        out = ewald_summation(
            positions=state.positions,
            charges=charges,
            cell=cell,
            neighbor_list=edge_index,
            neighbor_ptr=neighbor_ptr,
            neighbor_shifts=unit_shifts,
            batch_idx=state.system_idx.to(torch.int32),
            compute_forces=self._compute_forces,
            compute_virial=self._compute_stress,
            accuracy=self.accuracy,
        )
        # Normalize a bare-energy return to a tuple so indexing below is uniform.
        if not isinstance(out, tuple):
            out = (out,)

        # Sum per-atom energies into per-system totals in float64.
        # ``scatter_add_`` requires ``src`` to share the accumulator's dtype,
        # so cast explicitly: a float32 kernel output would otherwise raise.
        # The cast is a no-op when the kernel already returns float64.
        per_atom_energy = (out[0] * UnitConversion.e2_per_Ang_to_eV).to(torch.float64)
        energy = torch.zeros(
            state.n_systems, dtype=torch.float64, device=state.positions.device
        )
        energy.scatter_add_(0, state.system_idx.long(), per_atom_energy)

        results: dict[str, torch.Tensor] = {
            "energy": energy.to(self._dtype).detach(),
        }
        if self._compute_forces:
            forces = out[1] * UnitConversion.e2_per_Ang_to_eV  # type: ignore[index]
            results["forces"] = forces.to(self._dtype).detach()
        if self._compute_stress:
            # The virial is taken as the last kernel output; divide by the
            # per-system volume, broadcast over the trailing 3x3 axes.
            volumes = state.volume.unsqueeze(-1).unsqueeze(-1)
            stress = (out[-1] * UnitConversion.e2_per_Ang_to_eV) / volumes
            results["stress"] = stress.to(self._dtype).detach()
        return results
class PMEModel(ModelInterface):
    """Particle Mesh Ewald electrostatics as a :class:`ModelInterface`.

    Uses the ``nvalchemiops`` PME kernel for O(N log N) periodic electrostatics.
    The kernel's per-atom energies are aggregated to per-system energies.
    All user-facing quantities are in metal units (Angstrom / eV). Per-atom
    partial charges are read from ``state.partial_charges``.

    The method is defined for periodic systems; a state with no periodic axis
    is not an error, :meth:`forward` simply returns zeros for it.

    Args:
        cutoff: Real-space cutoff in Angstrom.
        accuracy: Target accuracy for auto-estimated parameters.
        mesh_spacing: Optional mesh spacing (Angstrom) for automatic mesh sizing.
        mesh_dimensions: Explicit FFT mesh dimensions ``(nx, ny, nz)``.
        spline_order: B-spline interpolation order. Defaults to 4.
        device: Compute device. Defaults to CUDA if available, else CPU.
        dtype: Floating-point dtype. Defaults to ``torch.float64``.
        compute_forces: Whether to return forces. Defaults to True.
        compute_stress: Whether to return stress. Defaults to True.
        neighbor_list_fn: Neighbor-list constructor. Defaults to ``torchsim_nl``.
    """

    @dcite("10.1063/1.464397", description="Darden et al. PME method")
    def __init__(
        self,
        cutoff: float = 10.0,
        *,
        accuracy: float = 1e-6,
        mesh_spacing: float | None = None,
        mesh_dimensions: tuple[int, int, int] | None = None,
        spline_order: int = 4,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float64,
        compute_forces: bool = True,
        compute_stress: bool = True,
        neighbor_list_fn: Callable = torchsim_nl,
    ) -> None:
        """Initialize the PME model."""
        super().__init__()
        self._device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self._dtype = dtype
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self._memory_scales_with = "n_atoms_x_density"
        self.neighbor_list_fn = neighbor_list_fn
        self.cutoff = cutoff
        self.accuracy = accuracy
        self.mesh_spacing = mesh_spacing
        self.mesh_dimensions = mesh_dimensions
        self.spline_order = spline_order

    def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]:
        """Compute PME electrostatic energy, forces, and stress.

        Args:
            state: Simulation state with ``partial_charges`` set as an atom
                extra (shape ``[n_atoms]``).
            **_kwargs: Unused; accepted for interface compatibility.

        Returns:
            Dict with ``"energy"`` [n_systems], plus ``"forces"`` [n_atoms, 3]
            when forces are enabled and ``"stress"`` [n_systems, 3, 3] when
            stress is enabled. All tensors are detached and cast to the model
            dtype. For a state with no periodic axis every entry is zero.

        Raises:
            ValueError: If ``state`` has no ``partial_charges`` extra.
        """
        if not state.has_extras("partial_charges"):
            raise ValueError("Partial charges are required for PME summation.")
        if not state.pbc.any():
            return _zero_result(
                state, self._dtype, self._compute_forces, self._compute_stress
            )

        charges = state.partial_charges
        edge_index, neighbor_ptr, unit_shifts = _build_csr(
            state, self.cutoff, self.neighbor_list_fn
        )
        cell = state.row_vector_cell.contiguous()
        # A batch index is only handed to the kernel for multi-system states.
        batch_idx = state.system_idx.to(torch.int32) if state.n_systems > 1 else None

        pme_kwargs: dict = {
            "positions": state.positions,
            "charges": charges,
            "cell": cell,
            "neighbor_list": edge_index,
            "neighbor_ptr": neighbor_ptr,
            "neighbor_shifts": unit_shifts,
            "batch_idx": batch_idx,
            "compute_forces": self._compute_forces,
            "compute_virial": self._compute_stress,
            "accuracy": self.accuracy,
            "spline_order": self.spline_order,
        }
        # Mesh options are forwarded only when set, leaving the kernel's own
        # defaults in effect otherwise.
        if self.mesh_spacing is not None:
            pme_kwargs["mesh_spacing"] = self.mesh_spacing
        if self.mesh_dimensions is not None:
            pme_kwargs["mesh_dimensions"] = self.mesh_dimensions

        out = particle_mesh_ewald(**pme_kwargs)
        # Normalize a bare-energy return to a tuple so indexing below is uniform.
        if not isinstance(out, tuple):
            out = (out,)

        # Sum per-atom energies into per-system totals in float64.
        # ``scatter_add_`` requires ``src`` to share the accumulator's dtype,
        # so cast explicitly: a float32 kernel output would otherwise raise.
        # The cast is a no-op when the kernel already returns float64.
        per_atom_energy = (out[0] * UnitConversion.e2_per_Ang_to_eV).to(torch.float64)
        energy = torch.zeros(
            state.n_systems, dtype=torch.float64, device=state.positions.device
        )
        energy.scatter_add_(0, state.system_idx.long(), per_atom_energy)

        results: dict[str, torch.Tensor] = {
            "energy": energy.to(self._dtype).detach(),
        }
        if self._compute_forces:
            forces = out[1] * UnitConversion.e2_per_Ang_to_eV  # type: ignore[index]
            results["forces"] = forces.to(self._dtype).detach()
        if self._compute_stress:
            # The virial is taken as the last kernel output; divide by the
            # per-system volume, broadcast over the trailing 3x3 axes.
            volumes = state.volume.unsqueeze(-1).unsqueeze(-1)
            stress = (out[-1] * UnitConversion.e2_per_Ang_to_eV) / volumes
            results["stress"] = stress.to(self._dtype).detach()
        return results