Source code for torch_sim.models.particle_life

"""Particle life model.

Thin wrapper around :class:`~torch_sim.models.pair_potential.PairForcesModel` with
the :func:`particle_life_pair_force` force function
baked in.

Example::

    model = ParticleLifeModel(sigma=1.0, epsilon=1.0, beta=0.3, cutoff=1.0)
    results = model(sim_state)
"""

from __future__ import annotations

import functools
from collections.abc import Callable  # noqa: TC003

import torch

from torch_sim.models.pair_potential import PairForcesModel
from torch_sim.neighbors import torchsim_nl


def particle_life_pair_force(
    dr: torch.Tensor,
    zi: torch.Tensor,  # noqa: ARG001
    zj: torch.Tensor,  # noqa: ARG001
    A: torch.Tensor | float = 1.0,
    beta: torch.Tensor | float = 0.3,
    sigma: torch.Tensor | float = 1.0,
) -> torch.Tensor:
    """Asymmetric particle-life scalar force magnitude.

    This is a *force* function (not an energy), intended for use with
    :class:`PairForcesModel`.

    The profile is piecewise linear in ``dr``:

    * ``dr < beta``: ``dr / beta - 1``, i.e. ``-1`` at ``dr = 0`` rising to
      ``0`` at ``dr = beta``. This branch does not depend on ``A``.
    * ``beta <= dr < sigma``:
      ``A * (1 - |2 * dr - sigma - beta| / (sigma - beta))``, a triangle that
      is ``0`` at ``beta`` and ``sigma`` and reaches ``A`` at the midpoint
      ``(beta + sigma) / 2``.
    * ``dr >= sigma``: ``0``.

    For the default ``sigma = 1.0`` the outer branch reduces to the familiar
    ``A * (1 - |2 * dr - 1 - beta| / (1 - beta))``.

    Args:
        dr: Pairwise distances, shape [n_pairs].
        zi: Atomic numbers of first atoms (unused).
        zj: Atomic numbers of second atoms (unused).
        A: Interaction amplitude. Defaults to 1.0.
        beta: Inner radius. Defaults to 0.3.
        sigma: Outer radius / cutoff. Defaults to 1.0.

    Returns:
        Scalar force magnitudes, shape [n_pairs].
    """
    inner_mask = dr < beta
    outer_mask = (dr >= beta) & (dr < sigma)

    inner_force = dr / beta - 1.0
    # Express the triangle in terms of ``sigma`` (not a hard-coded 1.0) so it
    # spans exactly [beta, sigma] and stays consistent with ``outer_mask``.
    outer_force = A * (1.0 - torch.abs(2.0 * dr - sigma - beta) / (sigma - beta))

    # The masks are disjoint, so summing the two masked branches selects at
    # most one non-zero contribution per pair; everything else stays zero.
    zeros = torch.zeros_like(dr)
    return torch.where(inner_mask, inner_force, zeros) + torch.where(
        outer_mask, outer_force, zeros
    )
class ParticleLifeModel(PairForcesModel):
    """Asymmetric particle-life force model.

    A :class:`PairForcesModel` whose force function is always
    :func:`particle_life_pair_force`; callers only choose ``sigma``,
    ``epsilon`` (the amplitude) and ``beta``.

    Example::

        model = ParticleLifeModel(
            sigma=1.0,
            epsilon=1.0,
            beta=0.3,
            cutoff=1.0,
        )
        results = model(sim_state)
    """

    def __init__(
        self,
        sigma: float = 1.0,
        epsilon: float = 1.0,
        beta: float = 0.3,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float64,
        *,
        compute_stress: bool = False,
        per_atom_stresses: bool = False,
        neighbor_list_fn: Callable = torchsim_nl,
        cutoff: float | None = None,
        **kwargs: object,  # noqa: ARG002
    ) -> None:
        """Initialize the particle life model.

        Args:
            sigma: Outer radius of the interaction. Defaults to 1.0.
            epsilon: Interaction amplitude, passed to the force function as
                its ``A`` parameter. Defaults to 1.0.
            beta: Inner radius of the interaction. Defaults to 0.3.
            device: Device for computations. Forwarded unchanged to
                :class:`PairForcesModel`, so ``None`` leaves the choice to the
                base class.
            dtype: Floating-point dtype. Defaults to torch.float64.
            compute_stress: Whether to compute stress tensor. Defaults to False.
            per_atom_stresses: Whether to return per-atom stresses. Defaults to
                False.
            neighbor_list_fn: Neighbor-list constructor. Defaults to torchsim_nl.
            cutoff: Interaction cutoff. Defaults to 2.5 * sigma.
            **kwargs: Additional keyword arguments. They are accepted and
                ignored (never forwarded to the base class); this is how legacy
                options such as ``compute_forces``, ``per_atom_energies`` and
                ``use_neighbor_list`` are tolerated for backward compatibility.
        """
        # Record the raw parameters on the instance before the base class is
        # initialised.
        self.sigma_param, self.epsilon, self.beta = sigma, epsilon, beta

        # Fall back to the documented default only when no cutoff was supplied.
        effective_cutoff = 2.5 * sigma if cutoff is None else cutoff

        # Freeze the model parameters into the pair-force callable so the base
        # class only has to provide (dr, zi, zj).
        bound_pair_force = functools.partial(
            particle_life_pair_force, A=epsilon, beta=beta, sigma=sigma
        )

        super().__init__(
            force_fn=bound_pair_force,
            cutoff=effective_cutoff,
            device=device,
            dtype=dtype,
            compute_stress=compute_stress,
            per_atom_stresses=per_atom_stresses,
            neighbor_list_fn=neighbor_list_fn,
            # NOTE(review): the full neighbor list is requested, presumably
            # because the interaction is asymmetric — confirm against
            # PairForcesModel.
            reduce_to_half_list=False,
        )