PairPotentialModel
- class torch_sim.models.pair_potential.PairPotentialModel(pair_fn, *, cutoff, device=None, dtype=torch.float64, compute_forces=True, compute_stress=False, per_atom_energies=False, per_atom_stresses=False, neighbor_list_fn=torchsim_nl, reduce_to_half_list=False, retain_graph=False)
Bases: ModelInterface
General batched pair potential model.
Computes energies, forces, and stresses for any pairwise potential defined by a callable of the form
pair_fn(distances, atomic_numbers_i, atomic_numbers_j) -> pair_energies
where all arguments are 1-D tensors of length n_pairs and the return value is a 1-D tensor of pair energies of the same length. Forces are obtained via autograd by differentiating the total energy with respect to atomic positions (F = -∂E/∂r), so pair_fn only needs to return energies, not derivatives.
When stress is computed, it uses the virial formula
σ = -(1/V) Σ_{ij} r_ij ⊗ f_ij
where V is the cell volume, r_ij is the pair displacement vector, and f_ij is the pair force vector.
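To make the autograd and virial steps concrete, here is a minimal plain-PyTorch sketch. It is not torch_sim's implementation: it handles a single non-periodic cluster, builds an all-pairs half list instead of using a neighbor list, ignores periodic images, and takes the volume V as a given number. It counts each unordered pair once and takes r_ij = r_j - r_i with f_ij = -∂E/∂r_ij (the force on atom j from the pair); these counting and sign conventions are assumptions of the sketch. The function names lj_pair_fn and energy_forces_stress are hypothetical.

```python
import torch


def lj_pair_fn(dr, zi, zj):
    # Lennard-Jones with sigma = epsilon = 1; zi and zj are ignored.
    idr6 = (1.0 / dr) ** 6
    return 4.0 * (idr6**2 - idr6)


def energy_forces_stress(positions, atomic_numbers, pair_fn, cutoff, volume):
    """Energy, forces, and virial stress for one non-periodic cluster.

    Sketch only (assumptions): all-pairs half list, no cell or periodic
    images, and the volume is supplied by the caller.
    """
    positions = positions.detach().requires_grad_(True)
    n_atoms = positions.shape[0]
    # Half list: each unordered pair (i, j) with i < j appears exactly once.
    i_all, j_all = torch.triu_indices(n_atoms, n_atoms, offset=1)
    with torch.no_grad():
        keep = (positions[j_all] - positions[i_all]).norm(dim=1) < cutoff
    i, j = i_all[keep], j_all[keep]

    r_ij = positions[j] - positions[i]  # displacement from atom i to atom j
    dist = r_ij.norm(dim=1)
    energy = pair_fn(dist, atomic_numbers[i], atomic_numbers[j]).sum()

    # One backward pass gives dE/d(positions) and dE/d(r_ij).
    dE_dpos, dE_drij = torch.autograd.grad(energy, [positions, r_ij])
    forces = -dE_dpos  # F = -dE/dr
    f_ij = -dE_drij  # force on atom j exerted through pair (i, j)
    # Virial: sigma = -(1/V) * sum over pairs of the outer product r_ij (x) f_ij
    stress = -torch.einsum("pa,pb->ab", r_ij, f_ij) / volume
    return energy.detach(), forces, stress
```

With this sign convention a compressed (repulsive) configuration gives negative diagonal stress. A production model would replace the all-pairs enumeration with a neighbor list and minimum-image displacements.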
Example:
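    from torch_sim.models.pair_potential import PairPotentialModel

    def lj_fn(dr, zi, zj):
        # Lennard-Jones with sigma = epsilon = 1; the atomic numbers zi, zj are unused
        idr6 = (1.0 / dr) ** 6
        return 4.0 * (idr6**2 - idr6)

    model = PairPotentialModel(pair_fn=lj_fn, cutoff=2.5)
    # sim_state: an existing SimState whose dtype matches the model's (torch.float64 by default)
    results = model(sim_state)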
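Because pair_fn receives the atomic numbers of both atoms in each pair, species-dependent parameters can be looked up inside it. The sketch below builds such a callable for Lennard-Jones with Lorentz-Berthelot mixing (σ_ij = (σ_i + σ_j)/2, ε_ij = √(ε_i ε_j)); this mixing rule is chosen for illustration, not one this class prescribes. make_mixed_lj_fn and its per-atomic-number parameter tables are hypothetical, and any values placed in the tables are placeholders.

```python
import torch


def make_mixed_lj_fn(sigma_by_z, epsilon_by_z):
    """Return a pair_fn(dr, zi, zj) using Lorentz-Berthelot mixing.

    sigma_by_z and epsilon_by_z are hypothetical 1-D tensors indexed by
    atomic number; fill them with real parameters for your system.
    """

    def pair_fn(dr, zi, zj):
        # Per-pair parameters gathered from the tables by atomic number.
        sigma = 0.5 * (sigma_by_z[zi] + sigma_by_z[zj])
        epsilon = torch.sqrt(epsilon_by_z[zi] * epsilon_by_z[zj])
        sr6 = (sigma / dr) ** 6
        return 4.0 * epsilon * (sr6**2 - sr6)  # one energy per pair

    return pair_fn
```

The returned function follows the pair_fn(distances, atomic_numbers_i, atomic_numbers_j) contract, so it could be passed as pair_fn to PairPotentialModel; keeping the parameter tables on the model's device and dtype (torch.float64 by default) avoids mixed-dtype arithmetic.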
- Parameters:
- forward(state, **_kwargs)
Compute pair-potential properties with batched tensor operations.
- Parameters:
- Returns:
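dict with key "energy" (shape [n_systems]) and, optionally, "forces" ([n_atoms, 3]), "stress" ([n_systems, 3, 3]), "energies" ([n_atoms]), and "stresses" ([n_atoms, 3, 3]), corresponding to the compute_forces, compute_stress, per_atom_energies, and per_atom_stresses flags, respectively.
- Raises: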
TypeError – If the SimState’s dtype does not match the model’s dtype.
- Return type:
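The return shapes above imply that pair energies are reduced to per-atom and then per-system totals across a batch. One way to do that with batched tensor operations is Tensor.index_add_, sketched below under two assumptions not stated in this reference: each pair's energy is split equally between its two atoms, and each atom carries a system index (called system_idx here, a hypothetical name). reduce_pair_energies is likewise hypothetical.

```python
import torch


def reduce_pair_energies(pair_energies, i, j, system_idx, n_systems):
    """Reduce half-list pair energies to per-system and per-atom energies.

    Assumption: each pair energy is split equally between atoms i and j.
    system_idx[a] is the (hypothetical) index of the system atom a belongs to.
    """
    n_atoms = system_idx.shape[0]
    opts = dict(dtype=pair_energies.dtype, device=pair_energies.device)
    half = 0.5 * pair_energies
    per_atom = torch.zeros(n_atoms, **opts)
    per_atom.index_add_(0, i, half)  # half of each pair energy to atom i
    per_atom.index_add_(0, j, half)  # the other half to atom j
    per_system = torch.zeros(n_systems, **opts)
    per_system.index_add_(0, system_idx, per_atom)  # sum atoms within each system
    return per_system, per_atom  # shapes [n_systems] and [n_atoms]
```

Splitting each pair energy equally keeps the per-atom energies summing exactly to the per-system totals, whatever convention a given implementation actually uses for the split.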