Source code for torch_sim.models.dispersion

"""DFT-D3(BJ) dispersion correction model.

Wraps the ``nvalchemiops`` Warp-accelerated DFT-D3 implementation as a
:class:`~torch_sim.models.interface.ModelInterface`, with full PBC, stress
(virial), and batched system support.

References:
    - Grimme et al., J. Chem. Phys. 132, 154104 (2010).
      https://doi.org/10.1063/1.3382344
    - Grimme et al., J. Comput. Chem. 32, 1456-1465 (2011).
      https://doi.org/10.1002/jcc.21759
    - nvalchemi-toolkit-ops: https://github.com/NVIDIA/nvalchemi-toolkit-ops
"""

from __future__ import annotations

import traceback
import warnings
from typing import TYPE_CHECKING, Any

import torch

from torch_sim._duecredit import dcite
from torch_sim.models.interface import ModelInterface
from torch_sim.neighbors import torchsim_nl
from torch_sim.units import UnitConversion


try:
    from nvalchemiops.torch.interactions.dispersion import (
        D3Parameters,
        dftd3 as nvalchemiops_dftd3,
    )
except ImportError as _nvalchemiops_import_error:
    # ``nvalchemiops`` is optional: warn once at import time (including the
    # traceback, to make install problems debuggable) and bind stand-ins that
    # re-raise the captured error only when D3 functionality is actually used.
    # The error object is bound as a default argument because the ``except``
    # target name is deleted when this clause ends.
    warnings.warn(
        f"nvalchemiops import failed: {traceback.format_exc()}", stacklevel=2
    )

    class D3Parameters:
        """Stand-in for ``D3Parameters`` when ``nvalchemiops`` is unavailable."""

        def __init__(
            self,
            *_a: Any,
            _err: Exception = _nvalchemiops_import_error,
            **_kw: Any,
        ) -> None:
            """Re-raise the captured import error on construction."""
            raise _err

    def nvalchemiops_dftd3(
        *_a: Any, _err: Exception = _nvalchemiops_import_error, **_kw: Any
    ) -> Any:
        """Re-raise the captured import error on any call."""
        raise _err


if TYPE_CHECKING:
    from collections.abc import Callable

    from torch_sim.state import SimState

# Factor converting a force from Hartree/Bohr (kernel output) to eV/Angstrom:
# the energy numerator goes Hartree -> eV and the length denominator Bohr -> Angstrom.
_FORCE_CONV = UnitConversion.Hartree_to_eV / UnitConversion.Bohr_to_Ang


class D3DispersionModel(ModelInterface):
    """DFT-D3(BJ) dispersion correction as a :class:`ModelInterface`.

    Computes DFT-D3 energies, forces, and (optionally) stresses via the
    ``nvalchemiops`` Warp GPU/CPU kernels. All user-facing quantities are in
    metal units (Angstrom / eV); unit conversion to and from atomic units
    (Bohr / Hartree) is handled internally.

    Functional-dependent BJ damping parameters (``a1``, ``a2``, ``s8``, ``s6``)
    can be looked up from the canonical parameter table:
    https://github.com/dftd3/simple-dftd3/blob/main/assets/parameters.toml

    Args:
        a1: BJ damping parameter (dimensionless, functional-dependent).
        a2: BJ damping parameter (in Bohr, functional-dependent).
        s8: C8 scaling factor (dimensionless, functional-dependent).
        s6: C6 scaling factor. Defaults to 1.0.
        d3_params: Reference D3 parameters (rcov, r4r2, c6ab, cn_ref). Passed
            through unchanged (including ``None``) to the ``nvalchemiops`` kernel.
        cutoff: Neighbor-list cutoff in **Angstrom**. Defaults to 95 Bohr
            expressed in Angstrom.
        device: Compute device. Defaults to CUDA if available, else CPU.
        dtype: Floating-point dtype of the returned tensors. Defaults to
            ``torch.float64``.
        compute_forces: Whether to include ``"forces"`` in the output.
            Defaults to True.
        compute_stress: Whether to compute the virial and include ``"stress"``
            in the output. Defaults to True.
        neighbor_list_fn: Neighbor-list constructor. Defaults to ``torchsim_nl``.

    Example::

        model = D3DispersionModel(
            a1=0.4289,
            a2=4.4407,
            s8=0.7875,
            d3_params=params,
            cutoff=50.0,
        )
        results = model(sim_state)
    """

    @dcite("10.1063/1.3382344")
    @dcite("10.1002/jcc.21759")
    def __init__(
        self,
        a1: float,
        a2: float,
        s8: float,
        *,
        s6: float = 1.0,
        d3_params: D3Parameters | None = None,
        cutoff: float = 95.0 * UnitConversion.Bohr_to_Ang,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float64,
        compute_forces: bool = True,
        compute_stress: bool = True,
        neighbor_list_fn: Callable = torchsim_nl,
    ) -> None:
        """Initialize the D3 dispersion model."""
        super().__init__()
        self._device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self._dtype = dtype
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        # NOTE(review): presumably read by torch-sim's batching/memory
        # estimation utilities — confirm against ModelInterface consumers.
        self._memory_scales_with = "n_atoms_x_density"
        self.neighbor_list_fn = neighbor_list_fn
        self.cutoff = cutoff
        self.a1 = a1
        self.a2 = a2
        self.s8 = s8
        self.s6 = s6
        self.d3_params = d3_params

    def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]:
        """Compute D3 dispersion energy, forces, and stress.

        Args:
            state: Simulation state (positions in Angstrom, cell in Angstrom).
            **_kwargs: Unused; accepted for interface compatibility.

        Returns:
            dict with ``"energy"`` [n_systems] in eV, and, when enabled,
            ``"forces"`` [n_atoms, 3] in eV/Angstrom (``compute_forces``) and
            ``"stress"`` [n_systems, 3, 3] in eV/Angstrom^3 (``compute_stress``).
        """
        edge_index, _mapping_system, unit_shifts = self.neighbor_list_fn(
            state.positions,
            state.row_vector_cell,
            state.pbc,
            self.cutoff,
            state.system_idx,
        )

        # ``neighbor_ptr`` below is a CSR row pointer derived from per-atom edge
        # counts, which is only valid if edges are grouped by source atom.
        # ``neighbor_list_fn`` is user-supplied, so enforce that ordering here;
        # a stable sort leaves an already-sorted list (and its shifts) unchanged.
        order = torch.argsort(edge_index[0], stable=True)
        edge_index = edge_index[:, order]
        unit_shifts = unit_shifts[order]

        n_atoms = state.positions.shape[0]
        neighbor_ptr = torch.zeros(
            n_atoms + 1, dtype=torch.int32, device=state.positions.device
        )
        neighbor_ptr[1:] = (
            torch.bincount(edge_index[0], minlength=n_atoms).cumsum(0).to(torch.int32)
        )

        # The kernel works in atomic units: convert lengths Angstrom -> Bohr.
        positions_bohr = state.positions * UnitConversion.Ang_to_Bohr
        cell_bohr = state.row_vector_cell.contiguous() * UnitConversion.Ang_to_Bohr

        d3_out = nvalchemiops_dftd3(
            positions=positions_bohr,
            numbers=state.atomic_numbers.to(torch.int32),
            a1=self.a1,
            a2=self.a2,
            s8=self.s8,
            s6=self.s6,
            d3_params=self.d3_params,
            neighbor_list=edge_index.to(torch.int32),
            neighbor_ptr=neighbor_ptr,
            cell=cell_bohr,
            unit_shifts=unit_shifts.to(torch.int32),
            batch_idx=state.system_idx.to(torch.int32),
            compute_virial=self._compute_stress,
            num_systems=state.n_systems,
        )

        # Convert kernel outputs back to metal units (Hartree -> eV, Bohr -> Ang).
        results: dict[str, torch.Tensor] = {
            "energy": (d3_out[0] * UnitConversion.Hartree_to_eV)
            .to(self._dtype)
            .detach(),
        }
        if self._compute_forces:
            results["forces"] = (d3_out[1] * _FORCE_CONV).to(self._dtype).detach()
        if self._compute_stress:
            # The virial is the last output and only exists when
            # compute_virial=True; index with [-1] to keep type checkers happy.
            # NOTE(review): for non-periodic states the cell volume may be zero
            # or meaningless, which would make this stress inf/nan — confirm.
            volumes = state.volume.unsqueeze(-1).unsqueeze(-1)
            stress = (d3_out[-1] * UnitConversion.Hartree_to_eV) / volumes
            results["stress"] = stress.to(self._dtype).detach()
        return results