NequIPFrameworkModel

class torch_sim.models.nequip_framework.NequIPFrameworkModel(model=None, *, r_max, type_names, device=None, neighbor_list_fn=vesin_nl_ts, atomic_numbers=None, system_idx=None)

Bases: ModelInterface

NequIP model for energy, force and stress calculations.

This class wraps a NequIP model to compute energies, forces and stresses for atomic systems.

Parameters:
  • model (torch.nn.Module) – The NequIP model to wrap. Must be a torch.nn.Module.

  • r_max (float) – Cutoff radius for neighbor list construction.

  • type_names (list[str]) – List of chemical symbols supported by the model.

  • device (torch.device | None) – Device to run calculations on. Defaults to CUDA if available, otherwise CPU.

  • neighbor_list_fn (Callable) – Function to compute neighbor lists. Defaults to vesin_nl_ts.

  • atomic_numbers (torch.Tensor | None) – Atomic numbers with shape [n_atoms]. If provided at initialization, they must not be provided again in the forward pass.

  • system_idx (torch.Tensor | None) – System indices with shape [n_atoms] indicating which system each atom belongs to. If atomic_numbers is provided without system_idx, all atoms are assumed to belong to a single system.
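The rules above for atomic_numbers and system_idx (supply atomic numbers exactly once, at initialization or in the forward pass; default to a single system at initialization) can be summarized as a small decision function. The sketch below is a hypothetical pure-Python illustration, not the torch_sim implementation: the function name, the use of a plain dict for the state, and the assumption that the forward path requires system_idx alongside atomic numbers (inferred from the Raises section of forward) are all assumptions.

```python
def resolve_inputs(init_atomic_numbers, init_system_idx, state):
    """Hypothetical sketch of where atomic numbers and system indices come from."""
    state_z = state.get("atomic_numbers")
    # Atomic numbers must come from exactly one place.
    if init_atomic_numbers is not None and state_z is not None:
        raise ValueError("atomic_numbers given both at init and in forward")
    if init_atomic_numbers is None and state_z is None:
        raise ValueError("atomic_numbers given neither at init nor in forward")
    if init_atomic_numbers is not None:
        # Init path: without system_idx, every atom is placed in system 0.
        if init_system_idx is None:
            init_system_idx = [0] * len(init_atomic_numbers)
        return init_atomic_numbers, init_system_idx
    # Forward path (assumption): system indices must accompany atomic numbers.
    if state.get("system_idx") is None:
        raise ValueError("system_idx is required when atomic_numbers come from the state")
    return state_z, state["system_idx"]
```

For example, resolve_inputs([1, 1, 8], None, {}) returns ([1, 1, 8], [0, 0, 0]), placing all three atoms in one system.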

setup_from_system_idx(atomic_numbers, system_idx)

Set up internal state from atomic numbers and system indices.

Processes the atomic numbers and system indices to prepare the model for the forward pass. Creates the data structures needed for batched processing of multiple systems.

Parameters:
  • atomic_numbers (torch.Tensor) – Atomic numbers tensor with shape [n_atoms].

  • system_idx (torch.Tensor) – System indices tensor with shape [n_atoms] indicating which system each atom belongs to.

Return type:

None
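As a rough illustration of the kind of bookkeeping such a setup step can involve, the pure-Python sketch below derives the number of systems, the atom count per system, and a per-atom type index based on type_names. Everything here is hypothetical: the function name, the tiny atomic-number-to-symbol table, the returned structures, and the assumption that system indices are contiguous integers starting at 0. The actual method works on tensors and may store different data.

```python
def setup_sketch(atomic_numbers, system_idx, type_names):
    """Hypothetical stand-in for setup_from_system_idx (pure Python)."""
    if len(atomic_numbers) != len(system_idx):
        raise ValueError("atomic_numbers and system_idx must have the same length")
    # Tiny hypothetical lookup table; a real one would cover every element.
    z_to_symbol = {1: "H", 6: "C", 8: "O"}
    # Map each chemical symbol in type_names to its position (type index).
    symbol_to_type = {name: i for i, name in enumerate(type_names)}
    # Assumes system indices are contiguous integers 0 .. n_systems - 1.
    n_systems = max(system_idx) + 1 if system_idx else 0
    n_atoms_per_system = [0] * n_systems
    for s in system_idx:
        n_atoms_per_system[s] += 1
    # Translate atomic numbers into the model's type indices.
    atom_types = [symbol_to_type[z_to_symbol[z]] for z in atomic_numbers]
    return n_systems, n_atoms_per_system, atom_types
```

With a water molecule and a CO2 molecule batched together, setup_sketch([8, 1, 1, 6, 8, 8], [0, 0, 0, 1, 1, 1], ["H", "C", "O"]) gives two systems of three atoms each and type indices [2, 0, 0, 1, 2, 2].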

forward(state)

Compute energies, forces, and stresses for the given atomic systems.

Processes the provided state information and computes energies, forces, and stresses using the underlying NequIP model. Handles batched calculations for multiple systems and constructs the necessary neighbor lists.
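To make the roles of r_max and system_idx in neighbor-list construction concrete, here is a brute-force O(n^2) sketch that collects ordered atom pairs closer than r_max, considering only atoms in the same system. It ignores the cell and periodic boundary conditions entirely and is not how the default vesin_nl_ts works; the function name is hypothetical.

```python
import math


def brute_force_neighbor_pairs(positions, system_idx, r_max):
    """Return ordered (i, j) pairs with i != j, same system, distance < r_max.

    Illustrative O(n^2) sketch without periodic boundary conditions.
    """
    pairs = []
    n_atoms = len(positions)
    for i in range(n_atoms):
        for j in range(n_atoms):
            # Skip self-pairs and atoms that belong to different systems.
            if i == j or system_idx[i] != system_idx[j]:
                continue
            if math.dist(positions[i], positions[j]) < r_max:
                pairs.append((i, j))
    return pairs


positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.0, 0.5), (5.0, 0.0, 0.0)]
print(brute_force_neighbor_pairs(positions, [0, 0, 1, 1], r_max=2.0))  # prints [(0, 1), (1, 0)]
```

Atom 2 lies only 0.5 away from atom 0, but it belongs to a different system, so no pair is formed between them. This separation is what lets several independent systems share one batched forward pass.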

Parameters:

state (SimState | StateDict) – State object containing positions, cell, and other system information. Can be either a SimState object or a dictionary with the relevant fields.

Returns:

Computed properties:
  • 'energy': System energies with shape [n_systems]

  • 'forces': Atomic forces with shape [n_atoms, 3] if compute_forces=True

  • 'stress': System stresses with shape [n_systems, 3, 3] if compute_stress=True

Return type:

dict[str, Tensor]

Raises:
  • ValueError – If atomic numbers are not provided either in the constructor or in the forward pass, or if provided in both places.

  • ValueError – If system indices are not provided when needed.
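The 'energy' entry holds one value per system even though atoms from several systems are processed together. A common way to get there, assumed here rather than taken from the torch_sim source, is to predict per-atom energy contributions and sum them by system index. A minimal pure-Python sketch of that reduction (the function name is hypothetical):

```python
def sum_per_system(per_atom_energy, system_idx, n_systems):
    """Sum per-atom energy contributions into one total per system."""
    totals = [0.0] * n_systems
    for energy, s in zip(per_atom_energy, system_idx):
        totals[s] += energy
    return totals


# Two systems: atoms 0-2 belong to system 0, atoms 3-4 to system 1.
print(sum_per_system([-1.5, -0.25, -0.25, -2.0, -1.0], [0, 0, 0, 1, 1], 2))  # prints [-2.0, -3.0]
```

With PyTorch tensors the same reduction is usually written with Tensor.index_add_, for example torch.zeros(n_systems).index_add_(0, system_idx, per_atom_energy) with an integer system_idx and matching floating-point dtypes, which avoids the Python loop.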