torch_sim.models.pair_potential
General batched pair potential and pair forces models.
This module provides PairPotentialModel, a flexible wrapper that turns any
pairwise energy function into a full TorchSim model with forces (via autograd) and
optional stress / per-atom output. It generalises Lennard-Jones, Morse, soft-sphere,
and similar potentials that depend only on pairwise distances and atomic numbers.
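As a rough illustration of the "forces via autograd" idea, the following pure-PyTorch sketch (not TorchSim's implementation) sums pair energies over all i < j pairs and differentiates the total with respect to positions. It assumes a small, non-periodic system and builds the pairs by brute force instead of using a neighbor list. The helper names lj_pair and energy_and_forces are hypothetical.

```python
import torch


def lj_pair(dr, zi, zj, sigma=1.0, epsilon=1.0):  # noqa: ARG001
    """Lennard-Jones pair energy; the atomic numbers are ignored."""
    sr6 = (sigma / dr).pow(6)
    return 4.0 * epsilon * (sr6 * sr6 - sr6)


def energy_and_forces(positions, atomic_numbers, pair_fn):
    """Hypothetical helper: total energy and forces F = -dE/dR (non-periodic)."""
    positions = positions.detach().requires_grad_(True)
    n = positions.shape[0]
    # Brute-force list of unique pairs i < j, standing in for a neighbor list.
    i, j = torch.triu_indices(n, n, offset=1)
    dr = (positions[j] - positions[i]).norm(dim=-1)
    energy = pair_fn(dr, atomic_numbers[i], atomic_numbers[j]).sum()
    (grad,) = torch.autograd.grad(energy, positions)
    return energy.detach(), -grad


positions = torch.tensor(
    [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 1.5, 0.0]], dtype=torch.float64
)
atomic_numbers = torch.tensor([18, 18, 18])
energy, forces = energy_and_forces(positions, atomic_numbers, lj_pair)
# Equal and opposite pair forces cancel, so the net force is ~0.
print(energy.item(), forces.sum(dim=0))
```

Because each pair energy depends only on the distance between the two atoms, its contributions to the forces on i and j are equal and opposite.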
It also provides PairForcesModel for potentials defined directly as forces
(e.g. the asymmetric particle-life interaction) that cannot be expressed as the
gradient of a scalar energy.
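The sketch below shows why such forces need a separate model. It uses a simplified, hypothetical particle-life-style rule, not the exact kernel or API of PairForcesModel: the force on atom i from atom j points along r_j - r_i and scales with attraction[type_i, type_j], and that matrix need not be symmetric. Forces are accumulated over a full neighbor list of ordered pairs.

```python
import torch


def asymmetric_pair_forces(positions, types, attraction, r_max=2.0):
    """Hypothetical direct pair force; may violate Newton's third law.

    The force on i due to j points along (r_j - r_i) with magnitude
    attraction[types[i], types[j]] * (1 - r / r_max) for r < r_max.
    """
    n = positions.shape[0]
    # Full neighbor list: every ordered pair (i, j) with i != j.
    i, j = (~torch.eye(n, dtype=torch.bool)).nonzero(as_tuple=True)
    vec = positions[j] - positions[i]
    dr = vec.norm(dim=-1, keepdim=True)
    strength = attraction[types[i], types[j]].unsqueeze(-1) * (1.0 - dr / r_max)
    strength = torch.where(dr < r_max, strength, torch.zeros_like(strength))
    pair_forces = strength * vec / dr
    # Accumulate the force each atom i receives from all of its neighbors j.
    forces = torch.zeros_like(positions)
    forces.index_add_(0, i, pair_forces)
    return forces


positions = torch.tensor([[0.0, 0.0], [1.0, 0.0]])
types = torch.tensor([0, 1])
# Type 0 is attracted to type 1, but type 1 is repelled by type 0.
attraction = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
forces = asymmetric_pair_forces(positions, types, attraction)
print(forces)
```

Here both particles are pushed in +x, so the net force is nonzero. An energy that depends only on pairwise distances always yields zero net force, so these forces cannot be the gradient of such an energy.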
The pair function signature required by PairPotentialModel is:
pair_fn(distances, atomic_numbers_i, atomic_numbers_j) -> pair_energies,
where all arguments are 1-D tensors of length n_pairs and the return value is a
1-D tensor of pair energies. Additional parameters (e.g., sigma, epsilon)
can be bound using functools.partial().
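For example, a Morse pair function with this signature might look like the following sketch. The name morse_pair and the parameter values are illustrative assumptions, not fitted to any material; functools.partial binds d_e, alpha and r_e so that the remaining call matches pair_fn(distances, atomic_numbers_i, atomic_numbers_j).

```python
import functools

import torch


def morse_pair(dr, zi, zj, d_e, alpha, r_e):  # noqa: ARG001
    """Morse energy d_e * (exp(-2 alpha (r - r_e)) - 2 exp(-alpha (r - r_e)))."""
    x = torch.exp(-alpha * (dr - r_e))
    return d_e * (x * x - 2.0 * x)


# Illustrative parameters only; after binding, the call signature is
# pair_fn(distances, atomic_numbers_i, atomic_numbers_j).
pair_fn = functools.partial(morse_pair, d_e=0.5, alpha=1.5, r_e=2.0)

distances = torch.tensor([1.8, 2.0, 3.0])  # 1-D, length n_pairs
zi = torch.ones(3, dtype=torch.long)
zj = torch.ones(3, dtype=torch.long)
pair_energies = pair_fn(distances, zi, zj)  # 1-D, one energy per pair
```

At r = r_e the Morse energy reaches its minimum of -d_e.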
Notes
The cutoff parameter determines the neighbor list construction range. Pairs beyond the cutoff are excluded from energy/force calculations. If your potential has its own natural cutoff (e.g., the WCA potential), ensure the model’s cutoff is at least as large.
The atomic_numbers_i and atomic_numbers_j arguments are provided for type-dependent potentials, but can be ignored (e.g., with # noqa: ARG001) for type-independent potentials like Lennard-Jones.
The dtype of the SimState must match the model’s dtype. The model will raise a TypeError if they don’t match.
Use reduce_to_half_list=True for symmetric potentials to halve the number of pair evaluations. Only use False for asymmetric interactions or when you need the full neighbor list for other purposes.
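To make the cutoff note concrete, here is a sketch of a WCA pair function in reduced units (sigma = epsilon = 1, an assumption for illustration). Its natural cutoff is 2^(1/6) sigma; the hypothetical model_cutoff stands in for the value you would pass as cutoff= to the model. The unused atomic-number arguments are marked with # noqa: ARG001 as the notes suggest.

```python
import functools

import torch


def wca_pair(dr, zi, zj, sigma=1.0, epsilon=1.0):  # noqa: ARG001
    """Lennard-Jones shifted up by epsilon and truncated at 2^(1/6) sigma."""
    sr6 = (sigma / dr).pow(6)
    energy = 4.0 * epsilon * (sr6 * sr6 - sr6) + epsilon
    natural_cutoff = 2.0 ** (1.0 / 6.0) * sigma
    return torch.where(dr < natural_cutoff, energy, torch.zeros_like(energy))


sigma = 1.0
pair_fn = functools.partial(wca_pair, sigma=sigma, epsilon=1.0)
natural_cutoff = 2.0 ** (1.0 / 6.0) * sigma  # about 1.1225 sigma

# Hypothetical value passed as cutoff= to the model. Pairs beyond it never
# reach pair_fn, so it must not be smaller than the WCA range.
model_cutoff = 1.5
assert model_cutoff >= natural_cutoff

dr = torch.tensor([1.0, 1.5])
zeros = torch.zeros(2, dtype=torch.long)
print(pair_fn(dr, zeros, zeros))  # repulsive inside the range, zero beyond it
```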
Example:
from torch_sim.models.pair_potential import PairPotentialModel
from torch_sim import io
from ase.build import bulk
import functools
import torch

def bmhtf_pair(dr, zi, zj, A, B, C, D, sigma):
    # Born-Meyer-Huggins-Tosi-Fumi (BMHTF) potential for ionic crystals
    # V(r) = A * exp(B * (sigma - r)) - C/r^6 - D/r^8
    exp_term = A * torch.exp(B * (sigma - dr))
    r6_term = C / dr.pow(6)
    r8_term = D / dr.pow(8)
    energy = exp_term - r6_term - r8_term
    return torch.where(dr > 0, energy, torch.zeros_like(energy))

# Na-Cl interaction parameters
fn = functools.partial(
    bmhtf_pair,
    A=20.3548,
    B=3.1546,
    C=674.4793,
    D=837.0770,
    sigma=2.755,
)
model = PairPotentialModel(pair_fn=fn, cutoff=10.0)

# Create NaCl structure using ASE
nacl_atoms = bulk("NaCl", "rocksalt", a=5.64)
sim_state = io.atoms_to_state(nacl_atoms, device=torch.device("cpu"))
results = model(sim_state)
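The example above binds a single parameter set for every pair and ignores zi and zj. A type-dependent pair function can instead look parameters up by atomic number. The sketch below assumes a Lennard-Jones form with Lorentz-Berthelot mixing (arithmetic mean of sigma, geometric mean of epsilon); mixed_lj_pair and the per-element values are hypothetical placeholders, not fitted NaCl parameters.

```python
import functools

import torch


def mixed_lj_pair(dr, zi, zj, sigma_table, epsilon_table):
    """Lennard-Jones with Lorentz-Berthelot mixing of per-element parameters.

    sigma_table and epsilon_table are 1-D tensors indexed by atomic number.
    """
    sigma = 0.5 * (sigma_table[zi] + sigma_table[zj])
    epsilon = torch.sqrt(epsilon_table[zi] * epsilon_table[zj])
    sr6 = (sigma / dr).pow(6)
    return 4.0 * epsilon * (sr6 * sr6 - sr6)


# Placeholder per-element parameters indexed by Z (arbitrary, not fitted).
sigma_table = torch.zeros(18)
epsilon_table = torch.zeros(18)
sigma_table[11], epsilon_table[11] = 2.5, 0.01  # Z = 11 (Na)
sigma_table[17], epsilon_table[17] = 3.5, 0.04  # Z = 17 (Cl)
pair_fn = functools.partial(
    mixed_lj_pair, sigma_table=sigma_table, epsilon_table=epsilon_table
)

zi = torch.tensor([11, 11, 17])  # Na-Na, Na-Cl, Cl-Cl pairs
zj = torch.tensor([11, 17, 17])
dr = torch.tensor([3.0, 3.0, 3.0])
pair_energies = pair_fn(dr, zi, zj)
```

With these placeholder values the mixed Na-Cl sigma is exactly 3.0, so that pair's energy is zero at r = 3.0.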
Functions
Reduce a full neighbor list to a half list.
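A minimal sketch of the half-list reduction for a non-periodic system, assuming the neighbor list is stored as a (2, n_pairs) tensor of (i, j) indices. full_to_half_pairs is a hypothetical name, not necessarily TorchSim's function or signature. Keeping only i < j retains exactly one copy of each unordered pair, so for a symmetric pair energy the half-list total equals half of the full-list total.

```python
import torch


def full_to_half_pairs(mapping):
    """Hypothetical: keep one copy of each unordered pair from a full list.

    mapping is a (2, n_pairs) tensor of (i, j) indices containing both
    (i, j) and (j, i). Assumes no periodic images, so i == j never occurs.
    """
    keep = mapping[0] < mapping[1]
    return mapping[:, keep]


# Full list for three mutually neighboring atoms: 6 ordered pairs.
full = torch.tensor([[0, 0, 1, 1, 2, 2],
                     [1, 2, 0, 2, 0, 1]])
half = full_to_half_pairs(full)  # columns (0, 1), (0, 2), (1, 2)

positions = torch.tensor([[0.0, 0.0, 0.0], [1.1, 0.0, 0.0], [0.0, 1.3, 0.0]])


def pair_energy(mapping):
    """A symmetric toy pair energy evaluated for every listed pair."""
    dr = (positions[mapping[1]] - positions[mapping[0]]).norm(dim=-1)
    return 1.0 / dr**12 - 1.0 / dr**6


# The full list counts each pair twice; the half list counts it once.
e_full = 0.5 * pair_energy(full).sum()
e_half = pair_energy(half).sum()
```

With periodic images an atom can neighbor its own image (i == j with opposite shift vectors), so a real implementation also needs a tie-breaking rule on the shift for those pairs.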
Classes
PairForcesModel: Batched pair model for potentials defined directly as forces.
PairPotentialModel: General batched pair potential model.
bool(x) -> bool