torch_sim.neighbors.nbody

Pure-PyTorch triplet and quadruplet interaction index builders.

Uses only standard PyTorch ops (argsort, bincount, repeat_interleave, boolean masking) and is compatible with torch.jit.script. No torch_scatter or torch_sparse dependencies.

build_triplets finds every ordered pair of edges (b→a, c→a) that share a target atom a. This is the angle environment used by three-body potentials (Tersoff, Stillinger-Weber) and message-passing networks (DimeNet).
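To make the pairing rule concrete, here is a brute-force reference in plain Python. It is a sketch of the enumeration only, not the library's tensor implementation. The function name reference_triplets, the (sources, targets) edge layout, and the rule that an edge is never paired with itself are assumptions made for illustration.

```python
from collections import defaultdict


def reference_triplets(sources, targets):
    """Enumerate ordered edge pairs (b->a, c->a) that share target atom a.

    Assumptions (not taken from torch_sim): edge k runs
    sources[k] -> targets[k], and an edge is never paired with itself.
    """
    edges_into = defaultdict(list)  # target atom -> indices of incoming edges
    for edge_idx, a in enumerate(targets):
        edges_into[a].append(edge_idx)

    triplets = []
    for a in sorted(edges_into):
        incoming = edges_into[a]
        for e_ba in incoming:
            for e_ca in incoming:
                if e_ba != e_ca:  # skip pairing an edge with itself
                    triplets.append((e_ba, e_ca))
    return triplets


# Atoms 1 and 2 both point at atom 0; atom 0 points at atom 1.
sources = [1, 2, 0]
targets = [0, 0, 1]
triplets = reference_triplets(sources, targets)
```

Under these assumptions, an atom with n incoming edges contributes n * (n - 1) ordered pairs, so atom 0 yields two triplets and atom 1 yields none.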

build_mixed_triplets does the same across two different edge sets (for example, different cutoffs or connectivity rules). It is used internally by build_quadruplets and can be called directly for architectures with separate embedding and interaction graphs.
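The two-graph case can be sketched the same way. The following plain-Python reference is illustrative only: the function name reference_mixed_triplets, the (source, target, shift) tuple layout, and the rule that skips a pair when both edges describe the same neighbour (same source atom, same periodic shift) are assumptions, not the documented behaviour of build_mixed_triplets.

```python
def reference_mixed_triplets(edges_out, edges_in):
    """Pair each edge of edges_out with every edge of edges_in that has
    the same target atom.

    Each edge is a hypothetical (source, target, shift) tuple, where shift
    is the periodic image offset as a tuple of ints. Pairs that describe
    the same neighbour (same source and same shift) are skipped; this
    exclusion rule is an assumption.
    """
    pairs = []
    for i, (b, a_out, shift_out) in enumerate(edges_out):
        for j, (c, a_in, shift_in) in enumerate(edges_in):
            if a_out != a_in:
                continue  # the edges do not meet at the same atom
            if b == c and shift_out == shift_in:
                continue  # same neighbour in both graphs
            pairs.append((i, j))
    return pairs


zero = (0, 0, 0)
short_edges = [(1, 0, zero)]  # short cutoff: only the close neighbour
long_edges = [(1, 0, zero), (2, 0, zero), (0, 1, zero)]
pairs = reference_mixed_triplets(short_edges, long_edges)
```

Here the single short edge 1→0 is paired only with the long edge 2→0: the long edge 1→0 is the same neighbour, and 0→1 ends at a different atom.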

build_quadruplets builds four-body interactions d→b→a←c from two neighbour lists at different cutoffs. The central bond b→a comes from the “internal” graph (shorter cutoff), while the outer bonds d→b and c→a come from the main graph (longer cutoff):

d ——(main, long)——> b ===(internal, short)===> a <——(main, long)—— c

For each short central bond b→a, every main-graph neighbour d of b is paired with every main-graph neighbour c of a, excluding pairs where c and d are the same atom in the same periodic image. Restricting the central bond to the shorter cutoff favours interactions whose central bond is strong, in contrast to a torsion built with a single uniform cutoff. build_quadruplets is a pure-PyTorch equivalent of GemNet-OC get_quadruplets. Example usage:

mapping, _, shifts = torch_nl_linked_cell(pos, cell, pbc, tensor(5.0), sys_idx)
qmapping, _, qshifts = torch_nl_linked_cell(pos, cell, pbc, tensor(3.0), sys_idx)
trip = build_triplets(mapping, n_atoms)
quad = build_quadruplets(mapping, qmapping, n_atoms, shifts.float(), qshifts.float())
# quad["quad_c_to_a_edge"]      — c→a main-edge index per quadruplet
# quad["quad_d_to_b_trip_idx"]  — index into d_to_b_edge/b_to_a_edge per quadruplet
# quad["quad_c_to_a_trip_idx"]  — index into c_to_a_edge per quadruplet
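For intuition about which d→b→a←c combinations are produced, the following brute-force sketch enumerates them in plain Python for a non-periodic system. It is not the library's implementation: the function name reference_quadruplets and the (source, target) edge tuples are hypothetical, periodic images are ignored, and only the c == d exclusion described above is applied. Any further filtering that the real function performs is not reproduced.

```python
def reference_quadruplets(main_edges, internal_edges):
    """Enumerate (d->b edge, b->a edge, c->a edge) index triples.

    Edges are hypothetical (source, target) tuples without periodic shifts.
    The central bond b->a comes from internal_edges; the outer bonds d->b
    and c->a come from main_edges. Only the c == d case is excluded.
    """
    quads = []
    for k, (b, a) in enumerate(internal_edges):
        into_b = [i for i, (_, t) in enumerate(main_edges) if t == b]
        into_a = [j for j, (_, t) in enumerate(main_edges) if t == a]
        for i in into_b:
            d = main_edges[i][0]
            for j in into_a:
                c = main_edges[j][0]
                if c != d:  # c and d must be different atoms
                    quads.append((i, k, j))
    return quads


# Atom labels: a = 0, b = 1; atoms 2 and 3 are outer neighbours.
internal_edges = [(1, 0)]              # central bond b -> a
main_edges = [(3, 1), (2, 0), (2, 1)]  # 3 -> b, 2 -> a, 2 -> b
quads = reference_quadruplets(main_edges, internal_edges)
```

Atom 2 neighbours both b and a, so the combination with d = 2 and c = 2 is dropped, leaving the single quadruplet 3→1→0←2.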

Functions

build_mixed_triplets

Build triplet indices across two different edge sets sharing the same atoms.

build_quadruplets

Build quadruplet interaction indices d→b→a←c from two edge sets.

build_triplets

Build triplet interaction indices from an edge list.