torch_sim.neighbors.nbody
Pure-PyTorch triplet and quadruplet interaction index builders.
Uses only standard PyTorch ops (argsort, bincount, repeat_interleave, boolean
masking) and is compatible with torch.jit.script. No torch_scatter or
torch_sparse dependencies.
build_triplets finds every ordered pair of edges (b→a, c→a) that share a target
atom a. These pairs form the angle environment used by three-body potentials
(Tersoff, Stillinger-Weber) and by message-passing networks such as DimeNet.
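The grouping idea behind build_triplets can be sketched with the same standard ops. This is a minimal illustration, not the module's code. triplets_sketch is a hypothetical helper, and it assumes an edge list of shape (2, n_edges) with row 0 holding the source atom and row 1 the target atom, a layout that may differ from what build_triplets expects. It also chooses to drop the pair formed by an edge with itself.

```python
import torch


def triplets_sketch(edge_index: torch.Tensor, n_atoms: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (e_ba, e_ca): all ordered pairs of distinct edges with the same target.

    Hypothetical helper. Assumes edge_index has shape (2, n_edges) with
    row 0 = source atom and row 1 = target atom.
    """
    target = edge_index[1]
    dev = target.device
    # Group edge ids by target atom; offsets[t] is where atom t's group starts in `order`.
    order = torch.argsort(target, stable=True)
    counts = torch.bincount(target, minlength=n_atoms)
    offsets = torch.cumsum(counts, 0) - counts
    # Each edge b->a is paired with every edge in its target's group.
    per_edge = counts[target]
    e_ba = torch.repeat_interleave(torch.arange(target.numel(), device=dev), per_edge)
    # Rank of each candidate inside its group: 0, 1, ..., per_edge - 1.
    start = torch.cumsum(per_edge, 0) - per_edge
    local = torch.arange(e_ba.numel(), device=dev) - torch.repeat_interleave(start, per_edge)
    e_ca = order[offsets[target[e_ba]] + local]
    # Drop the degenerate pair of an edge with itself.
    keep = e_ba != e_ca
    return e_ba[keep], e_ca[keep]
```

Each atom with k incoming edges contributes k(k - 1) ordered pairs, so the output grows quadratically with coordination number; the sort-and-expand approach avoids materialising a dense n_edges × n_edges mask.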
build_mixed_triplets does the same across two different edge sets, for example
sets built with different cutoffs or connectivity rules. It is used internally by
build_quadruplets and can be called directly by architectures with separate
embedding and interaction graphs.
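To show what pairing across two edge sets means, here is a dense reference sketch. mixed_triplets_sketch and its b_row parameter are hypothetical, and the row-0-source / row-1-target layout is an assumption. It compares every edge of one set against every edge of the other, which costs O(E_a × E_b) memory; the module's sorting-based approach avoids that, so treat this only as a readable specification or a test oracle.

```python
import torch


def mixed_triplets_sketch(
    edges_a: torch.Tensor, edges_b: torch.Tensor, b_row: int = 1
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pair edges of set A with edges of set B that touch A's target atom (hypothetical).

    Assumes row 0 = source, row 1 = target. Edge i of A is paired with edge j of B
    when edges_a[1, i] == edges_b[b_row, j], i.e. on B's target (b_row=1)
    or on B's source (b_row=0).
    """
    # (E_a, E_b) boolean mask of matching atoms; dense, so only for small graphs.
    match = edges_a[1][:, None] == edges_b[b_row][None, :]
    idx_a, idx_b = torch.nonzero(match, as_tuple=True)
    return idx_a, idx_b
```

With b_row=1 an A edge x→a is matched to set-B edges that also end at a (the c→a / b→a pairing). With b_row=0 an A edge is matched to set-B edges that start at its target atom, which is how an outer edge d→b attaches to a central edge b→a.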
build_quadruplets builds four-body interactions d→b→a←c from two neighbour
lists at different cutoffs. The central bond b→a comes from the “internal”
graph (shorter cutoff), while the outer bonds d→b and c→a come from the
main graph (longer cutoff):
d ——(main, long)——> b ===(internal, short)===> a <——(main, long)—— c
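For each short central bond b→a, every long-range neighbour d of b is paired with
every long-range neighbour c of a, excluding pairs in which c and d are the same
atom in the same periodic image. Restricting the central bond to the shorter cutoff
biases the model toward quadruplets whose central bond is short (and therefore
strong), unlike a torsion term that applies one uniform cutoff to all three bonds.
This is a pure-PyTorch equivalent of GemNet-OC's get_quadruplets: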
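import torch
from torch_sim.neighbors import torch_nl_linked_cell  # assumed import path
from torch_sim.neighbors.nbody import build_quadruplets, build_triplets
# pos, cell, pbc and sys_idx (per-atom system index) describe the input systems
n_atoms = pos.shape[0]
mapping, _, shifts = torch_nl_linked_cell(pos, cell, pbc, torch.tensor(5.0), sys_idx)  # main graph
qmapping, _, qshifts = torch_nl_linked_cell(pos, cell, pbc, torch.tensor(3.0), sys_idx)  # internal graph
trip = build_triplets(mapping, n_atoms)
quad = build_quadruplets(mapping, qmapping, n_atoms, shifts.float(), qshifts.float())
# quad["quad_c_to_a_edge"]: c→a main-edge index per quadruplet
# quad["quad_d_to_b_trip_idx"]: index into d_to_b_edge/b_to_a_edge per quadruplet
# quad["quad_c_to_a_trip_idx"]: index into c_to_a_edge per quadruplet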
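The pairing rule and the c == d exclusion can be written out as a small dense reference. quadruplets_sketch is hypothetical: it returns three index tensors rather than the dictionary that build_quadruplets produces, and its shift convention (an edge src→tgt with integer shift s places the source in image s relative to the target's cell) is an assumption that may not match torch_nl_linked_cell's. It applies only the c == d same-image exclusion described above and materialises an E_main × E_int × E_main mask, so it is only practical for small systems or as a cross-check.

```python
import torch


def quadruplets_sketch(main_idx, main_shift, int_idx, int_shift):
    """Enumerate quadruplets d->b->a<-c by dense comparison (hypothetical helper).

    Assumed conventions: index tensors have shape (2, E) with row 0 = source and
    row 1 = target; shift tensors have shape (E, 3), and an edge src->tgt with
    shift s places the source atom in image s relative to the target's cell.
    Returns (e_db, e_ba, e_ca): indices into main, internal and main edges.
    """
    src_m, tgt_m = main_idx[0], main_idx[1]
    src_i, tgt_i = int_idx[0], int_idx[1]
    m_db = tgt_m[:, None] == src_i[None, :]  # (E_main, E_int): d->b ends where b->a starts
    m_ca = tgt_m[:, None] == tgt_i[None, :]  # (E_main, E_int): c->a ends where b->a ends
    # (E_main, E_int, E_main) mask; nonzero yields every (d->b, b->a, c->a) combination.
    e_db, e_ba, e_ca = torch.nonzero(m_db[:, :, None] & m_ca.T[None, :, :], as_tuple=True)
    # Images relative to atom a's cell: b sits at the b->a shift, d at b->a + d->b.
    d_image = int_shift[e_ba] + main_shift[e_db]
    c_image = main_shift[e_ca]
    same = (src_m[e_db] == src_m[e_ca]) & (d_image == c_image).all(dim=-1)
    return e_db[~same], e_ba[~same], e_ca[~same]
```

Under this convention, d's image relative to a's cell is the sum of the b→a and d→b shifts, which is why a periodic copy of the same atom on both outer bonds is kept as a genuine quadruplet.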
Functions
build_mixed_triplets | Build triplet indices across two different edge sets sharing the same atoms.
build_quadruplets | Build quadruplet interaction indices.
build_triplets | Build triplet interaction indices from an edge list.