torch_nl_n2

torch_sim.neighbors.torch_nl.torch_nl_n2(positions, cell, pbc, cutoff, system_idx, self_interaction=False)

Compute the neighbor list for a batch of atomic structures. A naive search first collects candidate atom pairs, and a strict cutoff is then applied so that only pairs within the cutoff distance are returned.

The atomic positions (positions) should be wrapped inside their respective unit cells.
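If positions are not yet wrapped, they can be folded back into the cell beforehand. The sketch below is a hypothetical helper (wrap_positions is not part of torch_sim). It assumes the rows of each cell matrix are the lattice vectors and wraps only along periodic axes.

```python
import torch


def wrap_positions(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    system_idx: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical helper: wrap each atom into its own system's cell.

    Shapes follow the parameter list: positions [n_atom, 3], cell [n_systems, 3, 3]
    (rows assumed to be lattice vectors), pbc [n_systems, 3], system_idx [n_atom].
    """
    atom_cell = cell[system_idx]  # [n_atom, 3, 3]: the cell of each atom's system
    # Cartesian -> fractional: r = f @ cell, so f = r @ cell^-1
    frac = torch.einsum("ni,nij->nj", positions, torch.linalg.inv(atom_cell))
    # Fold fractional coordinates into [0, 1) only along periodic axes
    frac = torch.where(pbc[system_idx], torch.remainder(frac, 1.0), frac)
    # Fractional -> Cartesian
    return torch.einsum("nj,njk->nk", frac, atom_cell)
```

Working in fractional coordinates keeps the same code valid for non-orthogonal (triclinic) cells.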

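This implementation uses a naive O(N²) neighbor search, where N is the number of atoms in a structure: every pair of atoms in each structure, including their periodic images, is checked. The approach is simple and reliable, but it becomes slow for large systems and is best suited to small and medium-sized ones.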
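To illustrate the idea, here is a minimal brute-force sketch for a single periodic system. It is not torch_sim's implementation; the function name naive_neighbor_list is hypothetical, the cell rows are assumed to be lattice vectors, and the cutoff boundary is assumed to be inclusive. It enumerates enough periodic images along each axis, then tests all N² pairs against every image.

```python
import torch


def naive_neighbor_list(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cutoff: float,
    self_interaction: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical brute-force neighbor list for ONE periodic system.

    positions: [n, 3], assumed already wrapped into the cell.
    cell: [3, 3], rows are lattice vectors (assumed convention).
    pbc: [3] bool. Returns (mapping [2, n_pairs], shifts_idx [n_pairs, 3]).
    """
    n = positions.shape[0]
    # Interplanar spacing d_k = 1 / |column k of cell^-1|; with wrapped positions,
    # image shifts up to ceil(cutoff / d_k) along each periodic axis are enough.
    inv = torch.linalg.inv(cell)
    n_img = torch.ceil(cutoff * torch.linalg.norm(inv, dim=0)).long()
    n_img = torch.where(pbc, n_img, torch.zeros_like(n_img))
    ranges = [torch.arange(-k, k + 1) for k in n_img.tolist()]
    shifts = torch.cartesian_prod(*ranges).to(positions.dtype)  # [n_shift, 3]
    offsets = shifts @ cell  # Cartesian offset of each periodic image

    # dr[s, i, j] = x_j + offset_s - x_i: all N^2 pairs for every image (O(N^2)).
    dr = (
        positions[None, None, :, :]
        + offsets[:, None, None, :]
        - positions[None, :, None, :]
    )
    keep = torch.linalg.norm(dr, dim=-1) <= cutoff  # [n_shift, n, n]
    if not self_interaction:
        # Remove i == j in the home cell only; periodic self-images are kept.
        home = (shifts == 0).all(dim=1)
        eye = torch.eye(n, dtype=torch.bool)
        keep &= ~(home[:, None, None] & eye[None, :, :])
    s_idx, i_idx, j_idx = torch.nonzero(keep, as_tuple=True)
    return torch.stack([i_idx, j_idx]), shifts[s_idx]
```

The [n_shift, n, n, 3] displacement tensor is what makes this approach quadratic in memory as well as time, which is why it suits small and medium-sized systems.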

Parameters:
  • positions (torch.Tensor [n_atom, 3]) – A tensor containing the positions of atoms wrapped inside their respective unit cells.

  • cell (torch.Tensor [n_systems, 3, 3]) – Unit cell vectors.

  • pbc (torch.Tensor [n_systems, 3] bool) – Per-system flags indicating whether periodic boundary conditions apply along each of the three cell vectors.

  • cutoff (torch.Tensor) – The cutoff radius used for the neighbor search.

  • system_idx (torch.Tensor [n_atom,] torch.long) – A tensor containing the index of the structure to which each atom belongs.

  • self_interaction (bool, optional) – A flag to indicate whether to keep the center atoms as their own neighbors. Default is False.
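When atoms are stored system by system, system_idx can be built from the per-system atom counts. The helper below is hypothetical (not a torch_sim function) and uses only standard torch calls; torch.bincount recovers the counts from system_idx.

```python
import torch


def system_idx_from_counts(n_atoms_per_system: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: label contiguous atoms with their system index."""
    n_systems = n_atoms_per_system.shape[0]
    # torch.arange yields int64 (torch.long), matching the documented dtype
    return torch.repeat_interleave(torch.arange(n_systems), n_atoms_per_system)


# Two structures with 2 and 1 atoms
system_idx = system_idx_from_counts(torch.tensor([2, 1]))
print(system_idx)  # tensor([0, 0, 1])
# The reverse direction: number of atoms per system
print(torch.bincount(system_idx))  # tensor([2, 1])
```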

Returns:

mapping (torch.Tensor [2, n_neighbors]):

A tensor containing the indices of the neighbor list for the given positions array. mapping[0] corresponds to the central atom indices, and mapping[1] corresponds to the neighbor atom indices.

system_mapping (torch.Tensor [n_neighbors]):

A tensor giving, for each neighbor pair, the index of the structure it belongs to.

shifts_idx (torch.Tensor [n_neighbors, 3]):

A tensor containing the cell shift indices, in units of the lattice vectors, used to reconstruct the positions of the neighbor atoms' periodic images.

Return type:

tuple[Tensor, Tensor, Tensor]
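The three outputs together are enough to recover each pair's displacement vector and distance. The sketch below assumes the convention r_ij = x_j + shifts_idx_ij · cell − x_i, with cell rows as lattice vectors and shifts_idx treated as a row vector. This convention is an assumption for illustration, and neighbor_vectors is a hypothetical helper name.

```python
import torch


def neighbor_vectors(
    positions: torch.Tensor,  # [n_atom, 3]
    cell: torch.Tensor,  # [n_systems, 3, 3], rows are lattice vectors (assumed)
    mapping: torch.Tensor,  # [2, n_neighbors]
    system_mapping: torch.Tensor,  # [n_neighbors]
    shifts_idx: torch.Tensor,  # [n_neighbors, 3]
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: displacement vectors and distances for each pair."""
    i, j = mapping  # central atoms, neighbor atoms
    # Cartesian image offset: shift (row vector) times the cell of that pair's system
    offsets = torch.einsum(
        "nk,nkl->nl", shifts_idx.to(positions.dtype), cell[system_mapping]
    )
    dr = positions[j] + offsets - positions[i]
    return dr, torch.linalg.norm(dr, dim=-1)
```

For example, an atom at x = 9.5 in a cubic cell of side 10 is a neighbor of an atom at the origin through the image shifted by (-1, 0, 0), which gives a displacement of (-0.5, 0, 0).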

Example

>>> import torch
>>> from torch_sim.neighbors.torch_nl import torch_nl_n2
>>> # Create a batched system with 2 structures
>>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 5.0, 5.0]])
>>> # Two cubic cells of side 10.0, shape [n_systems, 3, 3] = [2, 3, 3]
>>> cell = torch.eye(3).repeat(2, 1, 1) * 10.0
>>> pbc = torch.tensor([[True, True, True], [True, True, True]])
>>> cutoff = torch.tensor(2.0)
>>> # First 2 atoms in system 0, last in system 1
>>> system_idx = torch.tensor([0, 0, 1])
>>> mapping, sys_map, shifts = torch_nl_n2(
...     positions, cell, pbc, cutoff, system_idx
... )
