torch_nl_n2
- torch_sim.neighbors.torch_nl.torch_nl_n2(positions, cell, pbc, cutoff, system_idx, self_interaction=False)
Compute the neighbor list for a set of atomic structures using a naive neighbor search, then apply a strict cutoff that keeps only pairs within the cutoff radius.
The atomic positions (positions) should be wrapped inside their respective unit cells.
This implementation uses a naive O(N²) neighbor search, which checks every pair of atoms within a structure. It therefore becomes slow for large systems, but it is simple and works reliably for small to medium systems.
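To make the naive strategy concrete, here is a minimal pure-Python sketch for a single periodic structure. It is not the torch_sim implementation: the function name naive_neighbor_list and its nested-list inputs are hypothetical. It assumes the rows of cell are the lattice vectors, that the cell is non-degenerate, and that positions are already wrapped. It enumerates every atom pair against every periodic image that could lie within the cutoff and keeps only the pairs that pass the distance test.

```python
import itertools
import math


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def _cross(a, b):
    return [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]


def naive_neighbor_list(positions, cell, pbc, cutoff, self_interaction=False):
    """Brute-force neighbor search for ONE structure (hypothetical helper).

    Assumes the rows of `cell` are the lattice vectors, the cell is
    non-degenerate, and `positions` are wrapped into the cell.
    Returns (i, j, shift) triples; the neighbor image of j sits at
    positions[j] + shift @ cell.
    """
    volume = abs(_dot(cell[0], _cross(cell[1], cell[2])))
    n_images = []
    for k in range(3):
        if not pbc[k]:
            n_images.append(0)  # no periodic copies along a free direction
            continue
        # Spacing between lattice planes spanned by the other two vectors.
        normal = _cross(cell[(k + 1) % 3], cell[(k + 2) % 3])
        spacing = volume / math.sqrt(_dot(normal, normal))
        n_images.append(math.ceil(cutoff / spacing))

    pairs = []
    cutoff_sq = cutoff * cutoff
    image_ranges = [range(-m, m + 1) for m in n_images]
    for shift in itertools.product(*image_ranges):
        # Cartesian translation of this periodic image: shift @ cell.
        t = [sum(shift[a] * cell[a][c] for a in range(3)) for c in range(3)]
        for i, r_i in enumerate(positions):  # O(N^2) loop over atom pairs
            for j, r_j in enumerate(positions):
                if i == j and not any(shift) and not self_interaction:
                    continue  # drop an atom paired with itself in the home cell
                d = [r_j[c] + t[c] - r_i[c] for c in range(3)]
                # Strict cutoff: keep only pairs within the radius.
                if _dot(d, d) <= cutoff_sq:
                    pairs.append((i, j, shift))
    return pairs
```

The number of images per periodic direction is ceil(cutoff / d_k), where d_k is the spacing between the lattice planes spanned by the other two vectors; this bound keeps the brute-force loop complete even when the cutoff exceeds the cell size. For example, a single atom in a cubic cell of side 3.0 with cutoff 3.5 yields exactly its six face-sharing images. The batched function additionally records which structure each pair belongs to, which this single-structure sketch omits.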
- Parameters:
- positions (torch.Tensor [n_atom, 3]) – A tensor containing the positions of atoms wrapped inside their respective unit cells.
- cell (torch.Tensor [n_systems, 3, 3]) – Unit cell vectors.
- pbc (torch.Tensor [n_systems, 3] bool) – A tensor indicating the periodic boundary conditions to apply.
- cutoff (torch.Tensor) – The cutoff radius used for the neighbor search.
- system_idx (torch.Tensor [n_atom,] torch.long) – A tensor containing the index of the structure to which each atom belongs.
- self_interaction (bool, optional) – A flag to indicate whether to keep the center atoms as their own neighbors. Default is False.
- Returns:
- mapping (torch.Tensor [2, n_neighbors]):
A tensor containing the indices of the neighbor list for the given positions array. mapping[0] corresponds to the central atom indices, and mapping[1] corresponds to the neighbor atom indices.
- system_mapping (torch.Tensor [n_neighbors]):
A tensor giving, for each neighbor pair, the index of the structure it belongs to.
- shifts_idx (torch.Tensor [n_neighbors, 3]):
A tensor containing the cell shift indices used to reconstruct the neighbor atom positions.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
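The three outputs are enough to recover each pair's displacement vector. The sketch below assumes the row-vector convention (the rows of each cell are its lattice vectors), so the neighbor image of atom j sits at positions[j] + shifts_idx @ cell. The helper names pair_vectors and pair_distances and their nested-list inputs are hypothetical, chosen only to keep the example dependency-free.

```python
import math


def pair_vectors(positions, cells, mapping, system_mapping, shifts_idx):
    """Displacement r_j + shift @ cell - r_i for every neighbor pair.

    Hypothetical helper on nested lists. Assumes row-vector cells (each
    row of cells[s] is a lattice vector of structure s).
    """
    vectors = []
    for k, (i, j) in enumerate(zip(mapping[0], mapping[1])):
        cell = cells[system_mapping[k]]  # cell of the structure this pair lives in
        shift = shifts_idx[k]
        # Integer image indices -> Cartesian translation: shift @ cell.
        t = [sum(shift[a] * cell[a][c] for a in range(3)) for c in range(3)]
        vectors.append([positions[j][c] + t[c] - positions[i][c] for c in range(3)])
    return vectors


def pair_distances(vectors):
    """Euclidean length of each displacement vector."""
    return [math.sqrt(sum(x * x for x in v)) for v in vectors]
```

With tensors, the same computation is roughly positions[mapping[1]] - positions[mapping[0]] + torch.einsum("ni,nij->nj", shifts_idx.to(positions.dtype), cell[system_mapping]), under the same row-vector assumption; distances then follow from .norm(dim=-1).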
Example
>>> import torch
>>> from torch_sim.neighbors.torch_nl import torch_nl_n2
>>> # Create a batched system with 2 structures
>>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 5.0, 5.0]])
>>> cell = torch.eye(3).repeat(2, 1, 1) * 10.0  # Two cells, shape [2, 3, 3]
>>> pbc = torch.tensor([[True, True, True], [True, True, True]])
>>> cutoff = torch.tensor(2.0)
>>> # First 2 atoms in system 0, last in system 1
>>> system_idx = torch.tensor([0, 0, 1])
>>> mapping, sys_map, shifts = torch_nl_n2(
...     positions, cell, pbc, cutoff, system_idx
... )