Source code for torch_sim.models.interface

"""Core interfaces for all models in TorchSim.

This module defines the abstract base class that all TorchSim models must implement.
It establishes a common API for interacting with different force and energy models,
ensuring consistent behavior regardless of the underlying implementation. The module
also provides validation utilities to verify model conformance to the interface.

Example::

    # Creating a custom model that implements the interface
    class MyModel(ModelInterface):
        def __init__(self, device=None, dtype=torch.float64):
            self._device = device or torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
            self._dtype = dtype
            self._compute_stress = True
            self._compute_forces = True

        def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs):
            # Implementation that returns energy, forces, and stress
            return {"energy": energy, "forces": forces, "stress": stress}

Notes:
    Models must explicitly declare support for stress computation through the
    compute_stress property, as some integrators require stress calculations.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch

import torch_sim as ts
from torch_sim.state import _CANONICAL_MODEL_KEYS


if TYPE_CHECKING:
    from collections.abc import Callable

    from torch_sim.state import SimState
    from torch_sim.typing import MemoryScaling


VALIDATE_ATOL = 1e-4

_MEMORY_SCALING_PRIORITY: dict[MemoryScaling, int] = {
    "n_atoms": 0,
    "n_atoms_x_density": 1,
    "n_edges": 2,
}


def _accumulate_model_output(
    combined: dict[str, torch.Tensor], output: dict[str, torch.Tensor]
) -> None:
    """Accumulate one model output into a combined output dict.

    Canonical mechanical outputs are additive. Other outputs are treated as
    full updated values, so later models replace earlier ones.
    """
    for key, tensor in output.items():
        if key in combined and key in _CANONICAL_MODEL_KEYS:
            combined[key] = combined[key] + tensor
        else:
            combined[key] = tensor


[docs] class ModelInterface(torch.nn.Module, ABC): """Abstract base class for all simulation models in TorchSim. This interface provides a common structure for all energy and force models, ensuring they implement the required methods and properties. It defines how models should process atomic positions and system information to compute energies, forces, and stresses. Attributes: device (torch.device): Device where the model runs computations. dtype (torch.dtype): Data type used for tensor calculations. compute_stress (bool): Whether the model calculates stress tensors. compute_forces (bool): Whether the model calculates atomic forces. memory_scales_with (MemoryScaling): The metric that the model scales with. "n_atoms" uses only atom count and is suitable for models that have a fixed number of neighbors. "n_atoms_x_density" uses atom count multiplied by number density and is better for models with radial cutoffs. Defaults to "n_atoms_x_density". Examples: ```py # Using a model that implements ModelInterface model = LennardJonesModel(device=torch.device("cuda")) # Forward pass with a simulation state output = model(sim_state) # Access computed properties energy = output["energy"] # Shape: [n_systems] forces = output["forces"] # Shape: [n_atoms, 3] stress = output["stress"] # Shape: [n_systems, 3, 3] ``` """ _device: torch.device _dtype: torch.dtype _compute_stress: bool _compute_forces: bool @property def device(self) -> torch.device: """The device of the model.""" return self._device @device.setter def device(self, device: torch.device) -> None: raise NotImplementedError( "No device setter has been defined for this model" " so the device cannot be changed after initialization." ) @property def dtype(self) -> torch.dtype: """The data type of the model.""" return self._dtype @dtype.setter def dtype(self, dtype: torch.dtype) -> None: raise NotImplementedError( "No dtype setter has been defined for this model" " so the dtype cannot be changed after initialization." ) @property def compute_stress(self) -> bool: """Whether the model computes stresses.""" return self._compute_stress @compute_stress.setter def compute_stress(self, compute_stress: bool) -> None: raise NotImplementedError( "No compute_stress setter has been defined for this model" " so compute_stress cannot be set after initialization." ) @property def compute_forces(self) -> bool: """Whether the model computes forces.""" return self._compute_forces @compute_forces.setter def compute_forces(self, compute_forces: bool) -> None: raise NotImplementedError( "No compute_forces setter has been defined for this model" " so compute_forces cannot be set after initialization." ) @property def memory_scales_with(self) -> MemoryScaling: """The metric that the model scales with. Models with radial neighbor cutoffs scale with "n_atoms_x_density", while models with a fixed number of neighbors scale with "n_atoms". Default is "n_atoms_x_density" because most models are radial cutoff based. """ return getattr(self, "_memory_scales_with", "n_atoms_x_density")
[docs] @abstractmethod def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: """Calculate energies, forces, and stresses for a atomistic system. This is the main computational method that all model implementations must provide. It takes atomic positions and system information as input and returns a dictionary containing computed physical properties. Args: state (SimState): Simulation state containing: - positions: Atomic positions with shape [n_atoms, 3] - cell: Unit cell vectors with shape [n_systems, 3, 3] - system_idx: System indices for each atom with shape [n_atoms] - atomic_numbers: Atomic numbers with shape [n_atoms] (optional) **kwargs: Additional model-specific parameters. Returns: dict[str, torch.Tensor]: Computed properties: - "energy": Potential energy with shape [n_systems] - "forces": Atomic forces with shape [n_atoms, 3] - "stress": Stress tensor with shape [n_systems, 3, 3] (if compute_stress=True) - May include additional model-specific outputs Examples: ```py # Compute energies and forces with a model output = model.forward(state) energy = output["energy"] forces = output["forces"] stress = output.get("stress", None) ``` """
[docs] class SumModel(ModelInterface): """Additive composition of multiple :class:`ModelInterface` models. Calls each child model's :meth:`forward`. Canonical mechanical outputs (energy, forces, stress) are combined additively, while non-canonical outputs are treated as full updated values and later models replace earlier ones. This is the standard way to layer a dispersion correction (e.g. DFT-D3), an Ewald electrostatic term, or a local pair potential on top of a primary machine-learning potential. Args: models: Two or more :class:`ModelInterface` instances that share the same ``device`` and ``dtype``. Raises: ValueError: If fewer than two models are given or if ``device``/``dtype`` do not match across all models. Examples: ```py sum_model = SumModel(mace_model, d3_model) output = sum_model(sim_state) ``` """ def __init__(self, *models: ModelInterface) -> None: """Initialize the sum model. Args: models: Two or more :class:`ModelInterface` instances. All must share the same ``device`` and ``dtype``. """ super().__init__() if len(models) < 2: raise ValueError("SumModel requires at least two child models") first = models[0] for i, m in enumerate(models[1:], start=1): if m.device != first.device: raise ValueError( f"Device mismatch: model 0 has {first.device}, " f"model {i} has {m.device}" ) if m.dtype != first.dtype: raise ValueError( f"Dtype mismatch: model 0 has {first.dtype}, model {i} has {m.dtype}" ) self.models = torch.nn.ModuleList(models) self._device = first.device self._dtype = first.dtype self._compute_stress = all(m.compute_stress for m in models) self._compute_forces = all(m.compute_forces for m in models) def _children(self) -> list[ModelInterface]: """Return child models with proper typing for static analysis.""" return list(self.models.children()) # type: ignore[return-value] @ModelInterface.compute_stress.setter def compute_stress(self, value: bool) -> None: # noqa: FBT001 """Propagate ``compute_stress`` to all child models that support it.""" for m in self._children(): try: m.compute_stress = value except NotImplementedError: if value: raise self._compute_stress = value @ModelInterface.compute_forces.setter def compute_forces(self, value: bool) -> None: # noqa: FBT001 """Propagate ``compute_forces`` to all child models that support it.""" for m in self._children(): try: m.compute_forces = value except NotImplementedError: if value: raise self._compute_forces = value @property def retain_graph(self) -> bool: """Whether any child model retains the computation graph.""" return all(getattr(m, "retain_graph", False) for m in self._children()) @retain_graph.setter def retain_graph(self, value: bool) -> None: for m in self._children(): if hasattr(m, "retain_graph"): m.retain_graph = value # type: ignore[union-attr] @property def memory_scales_with(self) -> MemoryScaling: """Most conservative memory-scaling among all child models.""" best: MemoryScaling = "n_atoms" for m in self._children(): scaling = m.memory_scales_with if _MEMORY_SCALING_PRIORITY[scaling] > _MEMORY_SCALING_PRIORITY[best]: best = scaling return best
[docs] def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: """Sum the outputs of all child models. Each child model is called with the same ``state`` and ``**kwargs``. Canonical mechanical outputs that appear in multiple children are summed element-wise. Non-canonical outputs are replaced by later models so they behave like full state updates rather than deltas. Args: state: Simulation state (see :class:`ModelInterface`). **kwargs: Forwarded to every child model. Returns: Combined output dictionary with summed tensors. """ combined: dict[str, torch.Tensor] = {} for model in self._children(): output = model(state, **kwargs) _accumulate_model_output(combined, output) return combined
[docs] class SerialSumModel(SumModel): """Serial additive composition of multiple :class:`ModelInterface` models. Unlike :class:`SumModel`, child models do not all see the same input state. Instead, each child runs after the previous child's non-canonical outputs have been stored into a cloned :class:`~torch_sim.state.SimState` via :meth:`torch_sim.state.SimState.store_model_extras`. This lets earlier models expose per-atom or per-system features that later models can consume. Energies, forces, and stresses remain additive, while repeated auxiliary outputs are treated as full updated values from the latest stage. Examples: ```py serial_model = SerialSumModel(polarization_model, dispersion_model) output = serial_model(sim_state) ``` """
[docs] def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: """Run child models serially, exposing extras from earlier models.""" combined: dict[str, torch.Tensor] = {} serial_state = state.clone() for model in self._children(): output = model(serial_state, **kwargs) _accumulate_model_output(combined, output) serial_state.store_model_extras(output) return combined
def _check_output_detached( output: dict[str, torch.Tensor], model: ModelInterface ) -> None: """Check that output tensors match the model's graph retention setting. When ``retain_graph`` is absent or ``False``, all tensors must be detached. When ``retain_graph`` is ``True``, all tensors must have ``requires_grad``. Args: output: Model output dictionary mapping keys to tensors. model: The model that produced the output. Raises: ValueError: If tensors are not detached when ``retain_graph`` is ``False``, or lack gradients when ``retain_graph`` is ``True``. """ retain_graph = getattr(model, "retain_graph", False) for key, tensor in output.items(): if not isinstance(tensor, torch.Tensor): continue if retain_graph and not tensor.requires_grad: raise ValueError( f"Output tensor '{key}' does not have gradients but model.retain_graph " "is True. Ensure the tensor is part of the computation graph." ) if not retain_graph and tensor.requires_grad: raise ValueError( f"Output tensor '{key}' is not detached from the computation graph. " "Call .detach() on the tensor before returning it, or set " "model.retain_graph = True if graph retention is intentional." )
[docs] def validate_model_outputs( # noqa: C901, PLR0915 model: ModelInterface, device: torch.device, dtype: torch.dtype, *, check_detached: bool = False, state_modifier: Callable[[SimState], SimState] | None = None, ) -> None: """Validate the outputs of a model implementation against the interface requirements. Runs a series of tests to ensure a model implementation correctly follows the ModelInterface contract. The tests include creating sample systems, running forward passes, and verifying output shapes and consistency. Args: model (ModelInterface): Model implementation to validate. device (torch.device): Device to run the validation tests on. dtype (torch.dtype): Data type to use for validation tensors. check_detached (bool): If ``True``, assert that all output tensors are detached from the autograd graph, unless the model has a ``retain_graph`` attribute set to ``True``. Defaults to ``False`` so that external callers are not immediately broken. state_modifier: If provided, applied to every ``SimState`` created during validation before the model sees it. Must return the (possibly new) state. Raises: AssertionError: If the model doesn't conform to the required interface, including issues with output shapes, types, or behavior consistency. Example:: # Create a new model implementation model = MyCustomModel(device=torch.device("cuda")) # Validate that it correctly implements the interface validate_model_outputs(model, device=torch.device("cuda"), dtype=torch.float64) Notes: This validator creates small test systems (diamond silicon, HCP magnesium, and primitive BCC iron) for validation. It tests both single and multi-batch processing capabilities. """ from ase.build import bulk, molecule def _modify(state: SimState) -> SimState: return state_modifier(state) if state_modifier is not None else state for attr in ("dtype", "device", "compute_stress", "compute_forces"): if not hasattr(model, attr): raise ValueError(f"model.{attr} is not set") try: if not model.compute_stress: model.compute_stress = True stress_computed = True except NotImplementedError: stress_computed = False try: if not model.compute_forces: model.compute_forces = True force_computed = True except NotImplementedError: force_computed = False si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) mg_atoms = bulk("Mg", "hcp", a=3.21, c=5.21).repeat([3, 2, 1]) fe_atoms = bulk("Fe", "bcc", a=2.87) sim_state = _modify( ts.io.atoms_to_state([si_atoms, mg_atoms, fe_atoms], device, dtype) ) og_positions = sim_state.positions.clone() og_cell = sim_state.cell.clone() system_idx = sim_state.system_idx og_system_idx = system_idx.clone() og_atomic_nums = sim_state.atomic_numbers.clone() if check_detached and hasattr(model, "retain_graph"): model.__dict__["retain_graph"] = True _check_output_detached(model.forward(sim_state), model) model.__dict__["retain_graph"] = False model_output = model.forward(sim_state) if check_detached: _check_output_detached(model_output, model) # assert model did not mutate the input if not torch.allclose(og_positions, sim_state.positions): raise ValueError(f"{og_positions=} != {sim_state.positions=}") if not torch.allclose(og_cell, sim_state.cell): raise ValueError(f"{og_cell=} != {sim_state.cell=}") if not torch.allclose(og_system_idx, system_idx): raise ValueError(f"{og_system_idx=} != {sim_state.system_idx=}") if not torch.allclose(og_atomic_nums, sim_state.atomic_numbers): raise ValueError(f"{og_atomic_nums=} != {sim_state.atomic_numbers=}") # assert model output has the correct keys if "energy" not in model_output: raise ValueError("energy not in model output") if force_computed and "forces" not in model_output: raise ValueError("forces not in model output") if stress_computed and "stress" not in model_output: raise ValueError("stress not in model output") # assert model output shapes are correct if model_output["energy"].shape != (3,): raise ValueError(f"{model_output['energy'].shape=} != (3,)") if force_computed and model_output["forces"].shape != (21, 3): raise ValueError(f"{model_output['forces'].shape=} != (21, 3)") if stress_computed and model_output["stress"].shape != (3, 3, 3): raise ValueError(f"{model_output['stress'].shape=} != (3, 3, 3)") # Test single Si system output shapes (8 atoms) si_state = _modify(ts.io.atoms_to_state([si_atoms], device, dtype)) si_model_output = model.forward(si_state) if not torch.allclose( si_model_output["energy"], model_output["energy"][0], atol=VALIDATE_ATOL ): raise ValueError(f"{si_model_output['energy']=} != {model_output['energy'][0]=}") if not torch.allclose( forces := si_model_output["forces"], expected_forces := model_output["forces"][: si_state.n_atoms], atol=VALIDATE_ATOL, ): raise ValueError(f"{forces=} != {expected_forces=}") if si_model_output["energy"].shape != (1,): raise ValueError(f"{si_model_output['energy'].shape=} != (1,)") if force_computed and si_model_output["forces"].shape != (8, 3): raise ValueError(f"{si_model_output['forces'].shape=} != (8, 3)") if stress_computed and si_model_output["stress"].shape != (1, 3, 3): raise ValueError(f"{si_model_output['stress'].shape=} != (1, 3, 3)") # Test single Mg system output shapes (12 atoms) mg_state = _modify(ts.io.atoms_to_state([mg_atoms], device, dtype)) mg_model_output = model.forward(mg_state) if not torch.allclose( mg_model_output["energy"], model_output["energy"][1], atol=VALIDATE_ATOL ): raise ValueError(f"{mg_model_output['energy']=} != {model_output['energy'][1]=}") mg_n = mg_state.n_atoms mg_slice = slice(si_state.n_atoms, si_state.n_atoms + mg_n) if not torch.allclose( forces := mg_model_output["forces"], expected_forces := model_output["forces"][mg_slice], atol=VALIDATE_ATOL, ): raise ValueError(f"{forces=} != {expected_forces=}") if mg_model_output["energy"].shape != (1,): raise ValueError(f"{mg_model_output['energy'].shape=} != (1,)") if force_computed and mg_model_output["forces"].shape != (12, 3): raise ValueError(f"{mg_model_output['forces'].shape=} != (12, 3)") if stress_computed and mg_model_output["stress"].shape != (1, 3, 3): raise ValueError(f"{mg_model_output['stress'].shape=} != (1, 3, 3)") # Test single Fe system output shapes (1 atom) # This catches that models do not squeeze away singleton dimensions. fe_state = _modify(ts.io.atoms_to_state([fe_atoms], device, dtype)) fe_model_output = model.forward(fe_state) if not torch.allclose( fe_model_output["energy"], model_output["energy"][2], atol=VALIDATE_ATOL ): raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][2]=}") if not torch.allclose( forces := fe_model_output["forces"], expected_forces := model_output["forces"][si_state.n_atoms + mg_n :], atol=VALIDATE_ATOL, ): raise ValueError(f"{forces=} != {expected_forces=}") if fe_model_output["energy"].shape != (1,): raise ValueError(f"{fe_model_output['energy'].shape=} != (1,)") if force_computed and fe_model_output["forces"].shape != (1, 3): raise ValueError(f"{fe_model_output['forces'].shape=} != (1, 3)") if stress_computed and fe_model_output["stress"].shape != (1, 3, 3): raise ValueError(f"{fe_model_output['stress'].shape=} != (1, 3, 3)") # Translating one atom by a full lattice vector should not change outputs. # This catches models that fail to apply periodic boundary conditions. shifted_state = si_state.clone() lattice_vec = shifted_state.cell[0, :, 0] # column convention shifted_state.positions[0] = shifted_state.positions[0] + 3 * lattice_vec shifted_output = model.forward(shifted_state) if not torch.allclose( shifted_output["energy"], si_model_output["energy"], atol=VALIDATE_ATOL ): raise ValueError( "Energy changed after translating an atom by a lattice " f"vector: {shifted_output['energy']=} != " f"{si_model_output['energy']=}" ) if force_computed and not torch.allclose( shifted_output["forces"], si_model_output["forces"], atol=VALIDATE_ATOL ): raise ValueError( "Forces changed after translating an atom by a lattice " "vector: max diff = " f"{(shifted_output['forces'] - si_model_output['forces']).abs().max()}" ) if stress_computed and not torch.allclose( shifted_output["stress"], si_model_output["stress"], atol=VALIDATE_ATOL ): raise ValueError( "Stress changed after translating an atom by a lattice " "vector: max diff = " f"{(shifted_output['stress'] - si_model_output['stress']).abs().max()}" ) # Test a non-periodic molecule (benzene) benzene_atoms = molecule("C6H6") benzene_state = _modify(ts.io.atoms_to_state([benzene_atoms], device, dtype)) benzene_output = model.forward(benzene_state) if benzene_output["energy"].shape != (1,): raise ValueError( f"energy shape incorrect for benzene: " f"{benzene_output['energy'].shape=} != (1,)" ) if force_computed and benzene_output["forces"].shape != (12, 3): raise ValueError( f"forces shape incorrect for benzene: " f"{benzene_output['forces'].shape=} != (12, 3)" )