torch_nl_linked_cell
- torch_sim.neighbors.torch_nl.torch_nl_linked_cell(positions, cell, pbc, cutoff, system_idx, self_interaction=False)
Compute the neighbor list for a set of atomic structures using the linked cell algorithm, then apply a strict cutoff.
The atomic positions (positions) should be wrapped inside their respective unit cells.
This is the recommended default for batched neighbor list calculations: the linked cell algorithm has O(N) complexity, so it performs well for systems of various sizes.
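The O(N) behaviour comes from binning atoms into cells and only comparing atoms in adjacent cells. The sketch below illustrates that idea in plain Python for a single, non-periodic structure. It is not the library's implementation: the function name linked_cell_pairs is hypothetical, the cell edge is assumed to equal the cutoff, and "within the cutoff" is assumed to mean a distance less than or equal to the cutoff.

```python
import math
from collections import defaultdict
from itertools import product


def linked_cell_pairs(positions, cutoff):
    """Return sorted (i, j) pairs with i != j and distance <= cutoff.

    Hypothetical helper: single structure, no periodic boundaries.
    """
    # Step 1: bin each atom into a cubic cell whose edge equals the cutoff.
    # Any neighbor within the cutoff must then sit in the same cell or in
    # one of the 26 cells that touch it.
    bins = defaultdict(list)
    for idx, pos in enumerate(positions):
        key = tuple(math.floor(c / cutoff) for c in pos)
        bins[key].append(idx)

    # Step 2: for each occupied cell, scan the 27 surrounding cells for
    # candidates, then keep only pairs that pass the distance check.
    pairs = []
    for (cx, cy, cz), members in bins.items():
        for dx, dy, dz in product((-1, 0, 1), repeat=3):
            # .get avoids inserting empty cells into the defaultdict
            candidates = bins.get((cx + dx, cy + dy, cz + dz), ())
            for i in members:
                for j in candidates:
                    if i != j and math.dist(positions[i], positions[j]) <= cutoff:
                        pairs.append((i, j))
    return sorted(pairs)


positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 5.0, 5.0)]
print(linked_cell_pairs(positions, cutoff=2.0))  # [(0, 1), (1, 0)]
```

Because each atom is compared only against atoms in a fixed number of nearby cells, the work grows linearly with the number of atoms at roughly constant density, instead of quadratically as in an all-pairs search.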
- Parameters:
- positions (torch.Tensor [n_atom, 3]) – A tensor containing the positions of atoms wrapped inside their respective unit cells.
- cell (torch.Tensor [n_systems, 3, 3]) – Unit cell vectors.
- pbc (torch.Tensor [n_systems, 3] bool) – A tensor indicating the periodic boundary conditions to apply.
- cutoff (torch.Tensor) – The cutoff radius used for the neighbor search.
- system_idx (torch.Tensor [n_atom,] torch.long) – A tensor containing the index of the structure to which each atom belongs.
- self_interaction (bool, optional) – A flag to indicate whether to keep the center atoms as their own neighbors. Default is False.
- Returns:
- A tuple containing:
- mapping (torch.Tensor [2, n_neighbors]):
A tensor containing the indices of the neighbor list for the given positions array. mapping[0] corresponds to the central atom indices, and mapping[1] corresponds to the neighbor atom indices.
- system_mapping (torch.Tensor [n_neighbors]):
A tensor mapping the neighbor atoms to their respective structures.
- shifts_idx (torch.Tensor [n_neighbors, 3]):
A tensor containing the cell shift indices used to reconstruct the neighbor atom positions.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Example
>>> import torch
>>> from torch_sim.neighbors.torch_nl import torch_nl_linked_cell
>>> # Create a batched system with 2 structures
>>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 5.0, 5.0]])
>>> # Two cubic cells, shape [n_systems, 3, 3] as listed under Parameters
>>> cell = torch.eye(3).repeat(2, 1, 1) * 10.0
>>> pbc = torch.tensor([[True, True, True], [True, True, True]])
>>> cutoff = torch.tensor(2.0)
>>> # First 2 atoms in system 0, last in system 1
>>> system_idx = torch.tensor([0, 0, 1])
>>> mapping, sys_map, shifts = torch_nl_linked_cell(
...     positions, cell, pbc, cutoff, system_idx
... )
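The three returned tensors are typically combined to obtain pair distances across periodic boundaries. The plain-Python sketch below shows one way this could work, using hand-written values rather than real output from torch_nl_linked_cell. The helper name pair_distances is hypothetical, and two conventions are assumptions not confirmed by this page: the rows of each cell are taken as the lattice vectors, and the shift is added to the neighbor atom's position.

```python
import math


def pair_distances(positions, cells, mapping, system_mapping, shifts_idx):
    """Distance of each neighbor pair, including periodic images.

    Hypothetical helper; inputs are nested lists mirroring the tensor shapes.
    """
    centers, neighbors = mapping  # mapping[0]: central atoms, mapping[1]: neighbors
    distances = []
    for i, j, sys, shift in zip(centers, neighbors, system_mapping, shifts_idx):
        cell = cells[sys]  # 3x3 cell of the system this pair belongs to
        # Cartesian offset of the periodic image (assumes rows are lattice vectors)
        offset = [sum(shift[k] * cell[k][d] for k in range(3)) for d in range(3)]
        # Assumed sign convention: displacement = r_j + offset - r_i
        dr = [positions[j][d] + offset[d] - positions[i][d] for d in range(3)]
        distances.append(math.sqrt(sum(x * x for x in dr)))
    return distances


# One cubic system of edge 10 with two atoms on opposite sides of the boundary
positions = [[0.5, 0.0, 0.0], [9.5, 0.0, 0.0]]
cells = [[[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]]
mapping = [[0, 1], [1, 0]]             # illustrative pairs (0 -> 1) and (1 -> 0)
system_mapping = [0, 0]                # both pairs belong to system 0
shifts_idx = [[-1, 0, 0], [1, 0, 0]]   # illustrative cell shifts

print(pair_distances(positions, cells, mapping, system_mapping, shifts_idx))  # [1.0, 1.0]
```

Without the shifts, the two atoms would appear 9.0 apart; with them, the nearest periodic image is 1.0 away, which is inside a cutoff of 2.0.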
References