torch_sim.models.pair_potential

General batched pair potential and pair forces models.

This module provides PairPotentialModel, a flexible wrapper that turns any pairwise energy function into a full TorchSim model with forces (via autograd) and optional stress / per-atom output. It generalises Lennard-Jones, Morse, soft-sphere, and similar potentials that depend only on pairwise distances and atomic numbers.
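PairPotentialModel obtains forces from the summed pair energies via autograd. The framework-free sketch below mimics that idea under stated assumptions. Central finite differences stand in for autograd. A hard-coded Lennard-Jones pair function is used. There is no cutoff and no periodic boundary handling. The helper names (lj_pair, total_energy, forces_fd) are hypothetical and not part of TorchSim.

```python
import math


def lj_pair(r, sigma=1.0, epsilon=1.0):
    """Lennard-Jones pair energy V(r) = 4*eps*((s/r)^12 - (s/r)^6)."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 * sr6 - sr6)


def total_energy(positions, pair_fn):
    """Sum pair energies over each unordered pair i < j (no cutoff, no PBC)."""
    e = 0.0
    n = len(positions)
    for i in range(n):
        for j in range(i + 1, n):
            e += pair_fn(math.dist(positions[i], positions[j]))
    return e


def forces_fd(positions, pair_fn, h=1e-6):
    """Forces F = -dE/dx by central finite differences (stand-in for autograd)."""
    forces = []
    for i in range(len(positions)):
        f_i = []
        for k in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][k] += h
            minus[i][k] -= h
            de = total_energy(plus, pair_fn) - total_energy(minus, pair_fn)
            f_i.append(-de / (2 * h))
        forces.append(f_i)
    return forces


positions = [(0.0, 0.0, 0.0), (1.1, 0.0, 0.0), (0.0, 1.2, 0.0)]
F = forces_fd(positions, lj_pair)
# Forces derived from a distance-only pair energy obey Newton's third law,
# so they sum to (numerically) zero.
net = [sum(f[k] for f in F) for k in range(3)]
print(net)
```

Because every pair energy depends only on a distance, the resulting forces come in equal and opposite pairs. This is the property that PairForcesModel gives up.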

It also provides PairForcesModel for potentials defined directly as forces (e.g. the asymmetric particle-life interaction) that cannot be expressed as the gradient of a scalar energy.
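To see why such forces cannot come from a scalar energy, consider a simplified, hypothetical particle-life-style rule. In this rule, the pull of particle i toward particle j depends on an interaction matrix K indexed by the two particle types, and K is not symmetric. The matrix values, the linear falloff, and the r_max parameter are illustrative assumptions, not the rule used by TorchSim.

```python
import math

# Hypothetical 2-type interaction matrix: K[a][b] is how strongly a type-a
# particle is pulled toward a type-b particle. K is deliberately NOT symmetric.
K = [[0.0, 1.0],
     [-0.5, 0.0]]


def pair_force_on_i(xi, xj, ti, tj, r_max=2.0):
    """Force on particle i due to particle j under the simplified rule."""
    d = [b - a for a, b in zip(xi, xj)]  # vector pointing from i to j
    r = math.hypot(*d)
    if r == 0.0 or r >= r_max:
        return [0.0] * len(d)
    strength = K[ti][tj] * (1.0 - r / r_max)  # linear falloff to zero at r_max
    return [strength * c / r for c in d]  # positive strength pulls i toward j


x0, x1 = (0.0, 0.0), (1.0, 0.0)
f01 = pair_force_on_i(x0, x1, 0, 1)  # force on the type-0 particle
f10 = pair_force_on_i(x1, x0, 1, 0)  # force on the type-1 particle
net = [a + b for a, b in zip(f01, f10)]
print(f01, f10, net)
```

The two forces are not equal and opposite, so the net pair force is nonzero. Any energy that depends only on pairwise distances would make it vanish, which is why such interactions need a model defined directly in terms of forces.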

The pair function signature required by PairPotentialModel is: pair_fn(distances, atomic_numbers_i, atomic_numbers_j) -> pair_energies, where all arguments are 1-D tensors of length n_pairs and the return value is a 1-D tensor of pair energies. Additional parameters (e.g., sigma, epsilon) can be bound using functools.partial().
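The binding pattern can be sketched without TorchSim. In the sketch below, plain Python lists stand in for the 1-D tensors that PairPotentialModel actually passes, so this illustrates only the signature and the functools.partial step. The function name lj_pair is illustrative.

```python
import functools


def lj_pair(distances, zi, zj, sigma, epsilon):  # noqa: ARG001 (zi, zj unused)
    """Lennard-Jones energy for each pair; inputs are equal-length sequences."""
    energies = []
    for r in distances:
        sr6 = (sigma / r) ** 6
        energies.append(4.0 * epsilon * (sr6 * sr6 - sr6))
    return energies


# Bind the extra parameters so the result has the 3-argument pair_fn signature.
pair_fn = functools.partial(lj_pair, sigma=1.0, epsilon=1.0)

distances = [1.0, 2 ** (1 / 6), 2.0]
zi = [18, 18, 18]
zj = [18, 18, 18]
print(pair_fn(distances, zi, zj))
```

After binding, pair_fn accepts exactly (distances, atomic_numbers_i, atomic_numbers_j) and returns one energy per pair. The energy is zero at r = sigma and reaches the minimum value of -epsilon at r = 2^(1/6) sigma.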

Notes

  • The cutoff parameter determines the neighbor list construction range. Pairs beyond the cutoff are excluded from energy/force calculations. If your potential has its own natural cutoff (e.g., WCA potential), ensure the model’s cutoff is at least as large.

  • The atomic_numbers_i and atomic_numbers_j arguments are provided for type-dependent potentials, but can be ignored (e.g., with # noqa: ARG001) for type-independent potentials like Lennard-Jones.

  • The dtype of the SimState must match the model’s dtype. The model will raise a TypeError if they don’t match.

  • Use reduce_to_half_list=True for symmetric potentials: each unordered pair is then evaluated once instead of twice, which roughly halves the pair computation. Use False only for asymmetric interactions or when you need the full neighbor list for other purposes.
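The effect of the cutoff and of the half list can be illustrated with a brute-force neighbor list. This is a sketch that assumes an isolated cluster with no periodic images. The function full_neighbor_list is hypothetical and is not how TorchSim builds its neighbor lists. Pairs farther apart than the cutoff never appear. A full list contains each pair in both directions, so its summed energy must be halved to match the half list.

```python
import itertools
import math


def full_neighbor_list(positions, cutoff):
    """Every ordered pair (i, j), i != j, closer than cutoff (no periodic images)."""
    pairs = []
    for i, j in itertools.permutations(range(len(positions)), 2):
        if math.dist(positions[i], positions[j]) < cutoff:
            pairs.append((i, j))
    return pairs


def lj(r):
    """Lennard-Jones pair energy with sigma = epsilon = 1."""
    sr6 = r ** -6
    return 4.0 * (sr6 * sr6 - sr6)


# Atom 3 is farther than the cutoff from every other atom, so it has no pairs.
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.5, 0.0), (5.0, 0.0, 0.0)]
full = full_neighbor_list(positions, cutoff=2.5)
half = [(i, j) for i, j in full if i < j]  # keep each unordered pair once

e_full = 0.5 * sum(lj(math.dist(positions[i], positions[j])) for i, j in full)
e_half = sum(lj(math.dist(positions[i], positions[j])) for i, j in half)
print(len(full), len(half), e_full, e_half)
```

For a symmetric potential the two energies agree, but the half list evaluates the pair function only half as often.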

Example:

from torch_sim.models.pair_potential import PairPotentialModel
from torch_sim import io
from ase.build import bulk
import functools
import torch


def bmhtf_pair(dr, zi, zj, A, B, C, D, sigma):
    # Born-Mayer-Huggins-Tosi-Fumi (BMHTF) potential for ionic crystals
    # V(r) = A * exp(B * (sigma - r)) - C/r^6 - D/r^8
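    # Evaluate on a safe distance where dr <= 0 so the masked-out branch
    # cannot put inf/NaN into the energy or its autograd gradient.
    safe_dr = torch.where(dr > 0, dr, torch.ones_like(dr))
    exp_term = A * torch.exp(B * (sigma - safe_dr))
    r6_term = C / safe_dr.pow(6)
    r8_term = D / safe_dr.pow(8)
    energy = exp_term - r6_term - r8_term
    return torch.where(dr > 0, energy, torch.zeros_like(energy))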


# Na-Cl interaction parameters
fn = functools.partial(
    bmhtf_pair,
    A=20.3548,
    B=3.1546,
    C=674.4793,
    D=837.0770,
    sigma=2.755,
)
model = PairPotentialModel(pair_fn=fn, cutoff=10.0)

# Create NaCl structure using ASE
nacl_atoms = bulk("NaCl", "rocksalt", a=5.64)
sim_state = io.atoms_to_state(nacl_atoms, device=torch.device("cpu"))
results = model(sim_state)
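The example above applies a single parameter set to every pair, regardless of species. Because pair_fn also receives atomic_numbers_i and atomic_numbers_j, a species-dependent potential can look up its parameters per pair. The sketch below is framework-free and uses plain lists instead of tensors. It keeps only the Born-Mayer repulsive term and uses placeholder parameter values that are not fitted Na/Cl parameters. The names PARAMS and typed_repulsion are hypothetical.

```python
import functools
import math

# Hypothetical parameters keyed by the sorted pair of atomic numbers.
# Values are placeholders for illustration only.
PARAMS = {
    (11, 11): {"A": 1.0, "B": 3.0, "sigma": 2.0},
    (11, 17): {"A": 2.0, "B": 3.0, "sigma": 2.5},
    (17, 17): {"A": 3.0, "B": 3.0, "sigma": 3.0},
}


def typed_repulsion(distances, zi, zj, params):
    """Born-Mayer repulsion A*exp(B*(sigma - r)) with per-species-pair parameters."""
    energies = []
    for r, a, b in zip(distances, zi, zj):
        p = params[tuple(sorted((a, b)))]  # sorted key: (Na, Cl) == (Cl, Na)
        energies.append(p["A"] * math.exp(p["B"] * (p["sigma"] - r)))
    return energies


pair_fn = functools.partial(typed_repulsion, params=PARAMS)
print(pair_fn([2.0, 2.5, 3.0], [11, 11, 17], [11, 17, 17]))
```

Sorting the key makes the lookup symmetric in (i, j). That symmetry is what a potential needs if it is to be used with reduce_to_half_list=True.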

Functions

full_to_half_list

Reduce a full neighbor list to a half list.
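With periodic boundaries, a full neighbor list typically stores each interaction twice: as (i, j, S) and as (j, i, -S), where S is the integer cell shift of the neighbor image. The sketch below shows one common way to keep exactly one of each mirrored pair, including self-image pairs (i == j). The tuple layout and the lexicographic tie-break are assumptions for illustration, not a description of how full_to_half_list is implemented.

```python
def full_to_half(pairs):
    """Keep one of each mirrored (i, j, S) / (j, i, -S) pair.

    pairs: list of (i, j, shift), where shift is a tuple of integer cell offsets.
    Hypothetical convention: keep i < j; for self-image pairs (i == j) keep the
    one whose shift is lexicographically greater than the zero shift.
    """
    half = []
    for i, j, shift in pairs:
        zero = tuple(0 for _ in shift)
        if i < j or (i == j and shift > zero):
            half.append((i, j, shift))
    return half


full = [
    (0, 1, (0, 0, 0)), (1, 0, (0, 0, 0)),    # neighbors in the same cell
    (0, 1, (1, 0, 0)), (1, 0, (-1, 0, 0)),   # neighbor in an adjacent cell
    (0, 0, (0, 0, 1)), (0, 0, (0, 0, -1)),   # atom 0 and its own periodic image
]
half = full_to_half(full)
print(half)
```

For every nonzero shift S, exactly one of S and -S compares greater than the zero tuple. The self-image pairs are therefore halved correctly instead of being dropped or double-counted.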

Classes

PairForcesModel

Batched pair model for potentials defined directly as forces.

PairPotentialModel

General batched pair potential model.
