Source code for torch_sim.models.mace

"""Wrapper for MACE model in TorchSim.

This module provides a TorchSim wrapper of the MACE model for computing
energies, forces, and stresses for atomistic systems. It integrates the MACE model
with TorchSim's simulation framework, handling batched computations for multiple
systems simultaneously.

The implementation supports various features including:

* Computing energies, forces, and stresses
* Handling periodic boundary conditions (PBC)
* Optional CuEq acceleration for improved performance
* Batched calculations for multiple systems

Notes:
    This module depends on the MACE package and implements the ModelInterface
    for compatibility with the broader TorchSim framework.
"""

import traceback
import warnings
from collections.abc import Callable
from enum import StrEnum
from pathlib import Path
from typing import Any

import torch

import torch_sim as ts
from torch_sim.models.interface import ModelInterface
from torch_sim.neighbors import torchsim_nl


try:
    from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
    from mace.tools import atomic_numbers_to_indices, utils
except (ImportError, ModuleNotFoundError) as exc:
    warnings.warn(f"MACE import failed: {traceback.format_exc()}", stacklevel=2)

    # NOTE(review): a later top-level ``class MaceModel`` in this module rebinds
    # this name -- confirm the placeholder is still reachable when MACE is missing.
    class MaceModel(ModelInterface):
        """MACE model wrapper for torch-sim.

        This class is a placeholder for the MaceModel class.
        It raises an ImportError if MACE is not installed.
        """

        # ``exc`` is unbound as soon as the ``except`` clause ends, so keep the
        # original import failure on the class for use at call time.
        _import_error: ImportError = exc

        def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
            """Dummy init for type checking; always raises the import failure.

            Args:
                err: Exception to raise. Defaults to the error captured when
                    importing MACE failed.
                *_args: Ignored.
                **_kwargs: Ignored.

            Raises:
                ImportError: Always (or ``err`` if the caller passed an exception).
            """
            # A caller using the real constructor signature passes the model as
            # the first positional argument, which lands in ``err``. Raising a
            # non-exception would yield an unrelated TypeError, so fall back to
            # the captured import error instead.
            if not isinstance(err, BaseException):
                err = type(self)._import_error
            raise err


def to_one_hot(
    indices: torch.Tensor, num_classes: int, dtype: torch.dtype
) -> torch.Tensor:
    """Build a one-hot encoding along a new trailing class axis.

    NOTE: this is a modified version of the to_one_hot function in mace.tools,
    consider using upstream version if possible after
    https://github.com/ACEsuit/mace/pull/903/ is merged.

    Args:
        indices: A tensor of shape (N x 1) containing class indices.
        num_classes: An integer specifying the total number of classes.
        dtype: The desired data type of the output tensor.

    Returns:
        torch.Tensor: A tensor of shape (N x num_classes) containing the
            one-hot encodings, allocated on the same device as ``indices``.
    """
    # Replace the trailing index axis with the class axis.
    target_shape = (*indices.shape[:-1], num_classes)
    encoded = torch.zeros(target_shape, device=indices.device, dtype=dtype)
    # In-place scatter: write a 1 at each row's class index.
    encoded.scatter_(dim=-1, index=indices, value=1)
    return encoded
class MaceModel(ModelInterface):
    """Computes energies for multiple systems using a MACE model.

    This class wraps a MACE model to compute energies, forces, and stresses for
    atomic systems within the TorchSim framework. It supports batched calculations
    for multiple systems and handles the necessary transformations between
    TorchSim's data structures and MACE's expected inputs.

    Attributes:
        r_max (float): Cutoff radius for neighbor interactions.
        z_table (utils.AtomicNumberTable): Table mapping atomic numbers to indices.
        model (torch.nn.Module): The underlying MACE neural network model.
        neighbor_list_fn (Callable): Function used to compute neighbor lists.
        atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms].
        system_idx (torch.Tensor): System indices with shape [n_atoms].
        n_systems (int): Number of systems in the batch.
        n_atoms_per_system (list[int]): Number of atoms in each system.
        ptr (torch.Tensor): Pointers to the start of each system in the batch with
            shape [n_systems + 1].
        total_atoms (int): Total number of atoms across all systems.
        node_attrs (torch.Tensor): One-hot encoded atomic types with shape
            [n_atoms, n_elements].
    """

    def __init__(
        self,
        model: str | Path | torch.nn.Module | None = None,
        *,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float64,
        neighbor_list_fn: Callable = torchsim_nl,
        compute_forces: bool = True,
        compute_stress: bool = True,
        enable_cueq: bool = False,
        atomic_numbers: torch.Tensor | None = None,
        system_idx: torch.Tensor | None = None,
    ) -> None:
        """Initialize the MACE model for energy and force calculations.

        Sets up the MACE model for energy, force, and stress calculations within
        the TorchSim framework. Atomic numbers and system indices may be supplied
        here, in which case they are cached once; otherwise they are read from the
        state on each forward pass.

        Args:
            model (str | Path | torch.nn.Module | None): The MACE neural network
                model, either as a path to a saved model (loaded with
                ``torch.load``) or as a loaded torch.nn.Module instance.
            device (torch.device | None): The device to run computations on.
                Defaults to CUDA if available, otherwise CPU.
            dtype (torch.dtype): The data type for tensor operations.
                Defaults to torch.float64.
            neighbor_list_fn (Callable): Function to compute neighbor lists.
                Defaults to torchsim_nl.
            compute_forces (bool): Whether to compute forces. Defaults to True.
            compute_stress (bool): Whether to compute stress. Defaults to True.
            enable_cueq (bool): Whether to enable CuEq acceleration.
                Defaults to False.
            atomic_numbers (torch.Tensor | None): Atomic numbers with shape
                [n_atoms]. If given, the one-hot node attributes are built once
                here and forward only checks the atom count of each state.
            system_idx (torch.Tensor | None): System indices with shape [n_atoms]
                indicating which system each atom belongs to. If given, the
                system pointers are built once here.

        Raises:
            TypeError: If model is neither a path nor a torch.nn.Module (this
                includes the default ``None``), or if the loaded model's
                ``atomic_numbers`` attribute is not a tensor.
            ValueError: If both atomic_numbers and system_idx are given and their
                lengths differ.
        """
        super().__init__()
        self._device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self._dtype = dtype
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self.neighbor_list_fn = neighbor_list_fn
        self._memory_scales_with = "n_atoms_x_density"

        # Load model if provided as path
        if isinstance(model, str | Path):
            # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
            # Only load model files from trusted sources.
            self.model = torch.load(model, map_location=self.device, weights_only=False)
        elif isinstance(model, torch.nn.Module):
            self.model = model.to(self.device)
        else:
            raise TypeError("Model must be a path or torch.nn.Module")

        self.model = self.model.eval()

        # Move all model components to device
        self.model = self.model.to(device=self._device)
        if self.dtype is not None:
            self.model = self.model.to(dtype=self.dtype)

        if enable_cueq:
            print("Converting models to CuEq for acceleration")  # noqa: T201
            self.model = run_e3nn_to_cueq(self.model, device=self.device.type)

        # Set model properties
        self.r_max = self.model.r_max
        atomic_nums = self.model.atomic_numbers
        if not isinstance(atomic_nums, torch.Tensor):
            raise TypeError("MACE model atomic_numbers must be a tensor")
        self.z_table = utils.AtomicNumberTable([int(z) for z in atomic_nums])
        self.model.atomic_numbers = atomic_nums.detach().clone().to(device=self.device)

        # Validate the optional per-atom inputs before building any cached
        # tensors, so an inconsistent pair fails without doing that work.
        if (
            atomic_numbers is not None
            and system_idx is not None
            and system_idx.shape[0] != atomic_numbers.shape[0]
        ):
            raise ValueError(
                f"system_idx length {system_idx.shape[0]} must match "
                f"atomic_numbers length {atomic_numbers.shape[0]}."
            )

        # Remember which inputs were fixed at construction; forward() uses these
        # flags to decide between a cheap shape check and a cache refresh.
        self.atomic_numbers_in_init = atomic_numbers is not None
        self.system_idx_in_init = system_idx is not None
        if atomic_numbers is not None:
            self.atomic_numbers = atomic_numbers
            self._setup_node_attrs(atomic_numbers)
        if system_idx is not None:
            self.system_idx = system_idx
            self._setup_ptr(system_idx)

    def _setup_ptr(self, system_idx: torch.Tensor) -> None:
        """Compute system boundary pointers from system indices.

        Sets ``n_systems``, ``n_atoms_per_system`` and ``ptr`` on the instance.

        Args:
            system_idx (torch.Tensor): System indices tensor with shape [n_atoms].
        """
        counts = torch.bincount(system_idx)
        self.n_systems = len(counts)
        self.n_atoms_per_system = counts.tolist()
        # Prepend a zero so ptr[i]:ptr[i + 1] delimits the atoms of system i.
        self.ptr = torch.cat([counts.new_zeros(1), counts.cumsum(0)])

    def _setup_node_attrs(self, atomic_numbers: torch.Tensor) -> None:
        """Compute one-hot encoded node attributes from atomic numbers.

        Sets ``node_attrs`` on the instance.

        Args:
            atomic_numbers (torch.Tensor): Atomic numbers tensor with shape
                [n_atoms].
        """
        # The index lookup is done on a CPU numpy copy of the atomic numbers;
        # the resulting indices are then placed on the model device.
        self.node_attrs = to_one_hot(
            torch.tensor(
                atomic_numbers_to_indices(
                    atomic_numbers.detach().cpu().numpy(), z_table=self.z_table
                ),
                dtype=torch.long,
                device=self.device,
            ).unsqueeze(-1),
            num_classes=len(self.z_table),
            dtype=self.dtype,
        )

    def forward(  # noqa: C901
        self, state: ts.SimState, **_kwargs: object
    ) -> dict[str, torch.Tensor]:
        """Compute energies, forces, and stresses for the given atomic systems.

        Processes the provided state information and computes energies, forces,
        and stresses using the underlying MACE model. Handles batched calculations
        for multiple systems and constructs the necessary neighbor lists.

        Args:
            state (SimState): State object containing positions, cell, and other
                system information.
            **_kwargs: Unused; accepted for interface compatibility.

        Returns:
            dict[str, torch.Tensor]: Computed properties:
                - 'energy': System energies with shape [n_systems]
                - 'forces': Atomic forces with shape [n_atoms, 3] if
                    compute_forces=True
                - 'stress': System stresses with shape [n_systems, 3, 3] if
                    compute_stress=True

        Raises:
            ValueError: If atomic numbers were given at construction and the
                state has a different number of atoms.
            ValueError: If system indices were given at construction and the
                state's system_idx has a different length.
        """
        if self.atomic_numbers_in_init:
            if state.positions.shape[0] != self.atomic_numbers.shape[0]:
                raise ValueError(
                    f"Expected {self.atomic_numbers.shape[0]} atoms, "
                    f"got {state.positions.shape[0]}."
                )
        elif not hasattr(self, "atomic_numbers") or not torch.equal(
            state.atomic_numbers, self.atomic_numbers
        ):
            # First call, or the species changed: rebuild the one-hot cache.
            self._setup_node_attrs(state.atomic_numbers)
            self.atomic_numbers = state.atomic_numbers

        if self.system_idx_in_init:
            if state.system_idx.shape[0] != self.system_idx.shape[0]:
                raise ValueError(
                    f"Expected system_idx of length {self.system_idx.shape[0]}, "
                    f"got {state.system_idx.shape[0]}."
                )
        elif not hasattr(self, "system_idx") or not torch.equal(
            state.system_idx, self.system_idx
        ):
            # First call, or the batching changed: rebuild the system pointers.
            self._setup_ptr(state.system_idx)
            self.system_idx = state.system_idx

        # Wrap positions into the unit cell when any direction is periodic
        wrapped_positions = (
            ts.transforms.pbc_wrap_batched(
                state.positions,
                state.cell,
                state.system_idx,
                state.pbc,
            )
            if state.pbc.any()
            else state.positions
        )

        # Batched neighbor list from the configured neighbor_list_fn
        edge_index, mapping_system, unit_shifts = self.neighbor_list_fn(
            wrapped_positions,
            state.row_vector_cell,
            state.pbc,
            self.r_max,
            state.system_idx,
        )
        # Convert unit cell shift indices to Cartesian shifts
        shifts = ts.transforms.compute_cell_shifts(
            state.row_vector_cell, unit_shifts, mapping_system
        )

        # Build data dict for MACE model
        data_dict = dict(
            ptr=self.ptr,
            node_attrs=self.node_attrs,
            batch=state.system_idx,
            pbc=state.pbc,
            cell=state.row_vector_cell,
            positions=wrapped_positions,
            edge_index=edge_index,
            unit_shifts=unit_shifts,
            shifts=shifts,
            total_charge=state.charge,
            total_spin=state.spin,
        )

        # Get model output
        out = self.model(
            data_dict,
            compute_force=self.compute_forces,
            compute_stress=self.compute_stress,
        )

        results: dict[str, torch.Tensor] = {}

        # Process energy
        energy = out["energy"]
        if energy is not None:
            results["energy"] = energy.detach()
        else:
            # Keep the fallback in the configured dtype so it matches the other
            # outputs instead of silently using torch's default dtype.
            results["energy"] = torch.zeros(
                self.n_systems, device=self.device, dtype=self.dtype
            )

        # Process forces
        if self.compute_forces:
            forces = out["forces"]
            if forces is not None:
                results["forces"] = forces.detach()

        # Process stress
        if self.compute_stress:
            stress = out["stress"]
            if stress is not None:
                results["stress"] = stress.detach()

        return results
[docs] class MaceUrls(StrEnum): """Checkpoint download URLs for MACE models.""" mace_mp_small = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model" mace_mpa_medium = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" mace_off_small = "https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true"