PairPotentialModel

class torch_sim.models.pair_potential.PairPotentialModel(pair_fn, *, cutoff, device=None, dtype=torch.float64, compute_forces=True, compute_stress=False, per_atom_energies=False, per_atom_stresses=False, neighbor_list_fn=torchsim_nl, reduce_to_half_list=False, retain_graph=False)

Bases: ModelInterface

General batched pair potential model.

Computes energies, forces, and stresses for any pairwise potential defined by a callable pair_fn(distances, atomic_numbers_i, atomic_numbers_j) -> pair_energies, where all three arguments are 1-D tensors of length n_pairs and the return value is a 1-D tensor of per-pair energies. Forces are computed exactly with autograd as the negative gradient of the total energy with respect to atomic positions, F = -∂E/∂r, so pair_fn only needs to return energies.
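To make the autograd step concrete, here is a minimal sketch in plain PyTorch of how forces follow from a pair_fn. It assumes a precomputed list of unique pairs (i, j) and ignores periodic images. The helper energy_and_forces is hypothetical and is not part of the torch_sim API.

```python
import torch


def lj_fn(dr, zi, zj):
    # Lennard-Jones in reduced units (sigma = epsilon = 1); zi and zj are unused
    idr6 = (1.0 / dr) ** 6
    return 4.0 * (idr6**2 - idr6)


def energy_and_forces(positions, atomic_numbers, i, j, pair_fn):
    """Return the total energy and the forces F = -dE/dr.

    Hypothetical helper: i and j list each unique pair once, no periodic images.
    """
    positions = positions.detach().requires_grad_(True)
    # Pair distances |r_j - r_i|, a 1-D tensor of length n_pairs
    dr = (positions[j] - positions[i]).norm(dim=-1)
    energy = pair_fn(dr, atomic_numbers[i], atomic_numbers[j]).sum()
    # Autograd gives dE/dr for every atom; the forces are its negative
    (grad,) = torch.autograd.grad(energy, positions)
    return energy.detach(), -grad


# Two atoms 1.0 apart along x (on the repulsive wall of LJ)
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=torch.float64)
atomic_numbers = torch.tensor([18, 18])
i, j = torch.tensor([0]), torch.tensor([1])
energy, forces = energy_and_forces(positions, atomic_numbers, i, j, lj_fn)
# energy -> 0.0; forces -> [[-24, 0, 0], [24, 0, 0]]
```

At r = 1 the Lennard-Jones energy is zero but its slope is -24, so autograd yields equal and opposite forces of magnitude 24 pushing the atoms apart.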

When compute_stress is enabled, the stress is computed from the virial formula σ = -(1/V) Σ_{ij} r_ij ⊗ f_ij, where V is the cell volume, r_ij is the displacement vector of pair (i, j), and f_ij is the corresponding pair force.
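The following sketch evaluates the virial formula for a single Lennard-Jones dimer and cross-checks it against the strain derivative (1/V) ∂E/∂ε. It assumes r_ij = r_j - r_i, that f_ij is the force on atom j exerted by atom i, and that each unique pair is summed once; the helper virial_stress is hypothetical, and torch_sim's own sign and pair-counting conventions are not confirmed here.

```python
import torch


def virial_stress(disp, pair_forces, volume):
    """sigma = -(1/V) * sum over pairs of disp[p] (outer) pair_forces[p].

    disp:        [n_pairs, 3], r_ij = r_j - r_i, one row per unique pair
    pair_forces: [n_pairs, 3], force on atom j exerted by atom i
    volume:      cell volume V
    """
    # "pa,pb->ab" sums the outer products r_p f_p^T over all pairs p
    return -torch.einsum("pa,pb->ab", disp, pair_forces) / volume


def lj(r):
    # Lennard-Jones in reduced units (sigma = epsilon = 1)
    idr6 = (1.0 / r) ** 6
    return 4.0 * (idr6**2 - idr6)


# Isolated dimer; the cubic box of side 10 only supplies the volume V
box = 10.0
pos = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.3, -0.2]], dtype=torch.float64, requires_grad=True)
disp = pos[1] - pos[0]
(grad,) = torch.autograd.grad(lj(disp.norm()), pos)
force_on_j = -grad[1]
stress = virial_stress(disp.detach()[None], force_on_j[None], box**3)

# Cross-check: (1/V) dE/d(eps) for a homogeneous strain r -> (I + eps) r
eps = torch.zeros(3, 3, dtype=torch.float64, requires_grad=True)
strained = pos.detach() @ (torch.eye(3, dtype=torch.float64) + eps).T
(dE_deps,) = torch.autograd.grad(lj((strained[1] - strained[0]).norm()), eps)
reference = dE_deps / box**3
```

If a neighbor list stores each pair twice, as both (i, j) and (j, i), every term appears twice (r_ji ⊗ f_ji = r_ij ⊗ f_ij), so the sum has to be halved.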

Example:

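from torch_sim.models.pair_potential import PairPotentialModel


def lj_fn(dr, zi, zj):
    # Lennard-Jones in reduced units (sigma = epsilon = 1); zi and zj are unused
    idr6 = (1.0 / dr) ** 6
    return 4.0 * (idr6**2 - idr6)


model = PairPotentialModel(pair_fn=lj_fn, cutoff=2.5)
# sim_state: an existing SimState with dtype torch.float64 (the model's default dtype)
results = model(sim_state)
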
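Because pair_fn receives the atomic numbers of both atoms in each pair, it can look up species-dependent parameters. The sketch below, assuming int64 atomic numbers, combines hypothetical per-element Lennard-Jones parameters using Lorentz-Berthelot mixing (arithmetic mean for σ, geometric mean for ε). The parameter tables and mixed_lj_fn are illustrative only and are not part of torch_sim.

```python
import torch

# Hypothetical per-element parameters indexed by atomic number (reduced units);
# only elements 1 and 2 are filled in for illustration.
SIGMA = torch.zeros(119, dtype=torch.float64)
EPSILON = torch.zeros(119, dtype=torch.float64)
SIGMA[1], EPSILON[1] = 1.0, 1.0
SIGMA[2], EPSILON[2] = 1.2, 0.5


def mixed_lj_fn(dr, zi, zj):
    """Lennard-Jones pair energies with Lorentz-Berthelot mixing.

    dr, zi, zj: 1-D tensors of length n_pairs; zi and zj are int64 atomic numbers.
    """
    sigma_tab = SIGMA.to(device=dr.device, dtype=dr.dtype)
    eps_tab = EPSILON.to(device=dr.device, dtype=dr.dtype)
    sigma = 0.5 * (sigma_tab[zi] + sigma_tab[zj])  # arithmetic mean
    epsilon = torch.sqrt(eps_tab[zi] * eps_tab[zj])  # geometric mean
    sr6 = (sigma / dr) ** 6
    return 4.0 * epsilon * (sr6**2 - sr6)


# Evaluate each pair type at its own minimum, r = 2**(1/6) * sigma_ij
zi = torch.tensor([1, 2, 1])
zj = torch.tensor([1, 2, 2])
dr = 2 ** (1 / 6) * torch.tensor([1.0, 1.2, 1.1], dtype=torch.float64)
energies = mixed_lj_fn(dr, zi, zj)
# approximately [-1.0, -0.5, -0.7071], i.e. -epsilon_ij for each pair
```

A function like this could then be passed as pair_fn, e.g. PairPotentialModel(pair_fn=mixed_lj_fn, cutoff=3.0), where the cutoff value is arbitrary.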
Parameters:
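  • pair_fn (Callable) – Function mapping (distances, atomic_numbers_i, atomic_numbers_j) to per-pair energies, as described above.

  • cutoff (float) – Interaction cutoff distance; pairs farther apart than this do not contribute.

  • device (device | None) – Device on which the model runs.

  • dtype (dtype) – Floating-point dtype of the model; forward() raises TypeError if the SimState uses a different dtype.

  • compute_forces (bool) – Whether to compute forces.

  • compute_stress (bool) – Whether to compute the per-system stress tensor.

  • per_atom_energies (bool) – Whether to compute per-atom energies.

  • per_atom_stresses (bool) – Whether to compute per-atom stress tensors.

  • neighbor_list_fn (Callable) – Function used to build the neighbor list (default: torchsim_nl).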

  • reduce_to_half_list (bool)

  • retain_graph (bool)
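When per_atom_energies is enabled, pair energies have to be distributed over atoms. A common convention, assumed here and not confirmed for torch_sim, assigns half of each pair energy to each of its two atoms, so the per-atom energies sum to the total. The helper split_pair_energies is hypothetical.

```python
import torch


def split_pair_energies(pair_energies, i, j, n_atoms):
    """Give half of each pair energy to each atom of the pair.

    Hypothetical helper; assumes (i[p], j[p]) lists each unique pair once.
    """
    energies = torch.zeros(n_atoms, dtype=pair_energies.dtype, device=pair_energies.device)
    half = 0.5 * pair_energies
    energies.index_add_(0, i, half)  # first atom of each pair
    energies.index_add_(0, j, half)  # second atom of each pair
    return energies


# Three atoms with pairs (0, 1), (1, 2), (0, 2)
i = torch.tensor([0, 1, 0])
j = torch.tensor([1, 2, 2])
pair_energies = torch.tensor([-1.0, -0.5, 2.0], dtype=torch.float64)
energies = split_pair_energies(pair_energies, i, j, n_atoms=3)
# energies -> [0.5, -0.75, 0.75]; energies.sum() equals pair_energies.sum() == 0.5
```

This sketch expects each unique pair once. With a full list that stores both (i, j) and (j, i), one would instead add 0.5 * e to the first atom of each entry only; otherwise every interaction would be counted twice.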

forward(state, **_kwargs)

Compute pair-potential properties with batched tensor operations.

Parameters:
  • state (SimState) – Simulation state.

  • **_kwargs (object) – Unused; accepted for interface compatibility.

Returns:

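dict containing "energy" (shape [n_systems]) and, depending on the constructor flags, "forces" ([n_atoms, 3], compute_forces), "stress" ([n_systems, 3, 3], compute_stress), "energies" ([n_atoms], per_atom_energies), and "stresses" ([n_atoms, 3, 3], per_atom_stresses).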

Raises:

TypeError – If the SimState’s dtype does not match the model’s dtype.

Return type:

dict[str, Tensor]