Source code for torch_sim.quantities

"""Functions for computing physical quantities."""

from typing import TYPE_CHECKING

import torch

from torch_sim.units import MetalUnits


if TYPE_CHECKING:
    from torch_sim.integrators.md import MDState
    from torch_sim.optimizers import OptimState


# @torch.jit.script
[docs] def calc_kT( # noqa: N802 *, masses: torch.Tensor, momenta: torch.Tensor | None = None, velocities: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, dof_per_system: torch.Tensor | None = None, ) -> torch.Tensor: """Calculate temperature in energy units from momenta/velocities and masses. Args: momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim) masses (torch.Tensor): Particle masses, shape (n_particles,) velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) system_idx (torch.Tensor | None): Optional tensor indicating system membership of each particle dof_per_system (torch.Tensor | None): Optional tensor indicating degrees of freedom per system Returns: torch.Tensor: Scalar temperature value """ if not ((momenta is not None) ^ (velocities is not None)): raise ValueError("Must pass either one of momenta or velocities") if momenta is None and velocities is not None: # If velocity provided, calculate mv^2 squared_term = torch.square(velocities) * masses.unsqueeze(-1) elif momenta is not None and velocities is None: # If momentum provided, calculate v^2 = p^2/m^2 squared_term = torch.square(momenta) / masses.unsqueeze(-1) else: raise ValueError("Must pass either one of momenta or velocities") if system_idx is None: # Count total degrees of freedom dof = squared_term.numel() return torch.sum(squared_term) / dof # Sum squared terms for each system flattened_squared = torch.sum(squared_term, dim=-1) # Count degrees of freedom per system system_sizes = torch.bincount(system_idx) if dof_per_system is None: dof_per_system = system_sizes * squared_term.shape[-1] # Calculate temperature per system system_sums = torch.segment_reduce( flattened_squared, reduce="sum", lengths=system_sizes ) return system_sums / dof_per_system
[docs] def calc_temperature( *, masses: torch.Tensor, momenta: torch.Tensor | None = None, velocities: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, dof_per_system: torch.Tensor | None = None, units: MetalUnits = MetalUnits.temperature, ) -> torch.Tensor: """Calculate temperature from momenta/velocities and masses. Args: momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim) masses (torch.Tensor): Particle masses, shape (n_particles,) velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) system_idx (torch.Tensor | None): Optional tensor indicating system membership of each particle dof_per_system (torch.Tensor | None): Optional tensor indicating degrees of freedom per system units (object): Units to return the temperature in Returns: torch.Tensor: Temperature value in specified units (default, K) """ kT = calc_kT( masses=masses, momenta=momenta, velocities=velocities, system_idx=system_idx, dof_per_system=dof_per_system, ) return kT / units
# @torch.jit.script
[docs] def calc_kinetic_energy( *, masses: torch.Tensor, momenta: torch.Tensor | None = None, velocities: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, ) -> torch.Tensor: """Computes the kinetic energy of a system. Args: momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim) masses (torch.Tensor): Particle masses, shape (n_particles,) velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) system_idx (torch.Tensor | None): Optional tensor indicating system membership of each particle Returns: If system_idx is None: Scalar tensor containing the total kinetic energy If system_idx is provided: Tensor of kinetic energies per system """ if not ((momenta is not None) ^ (velocities is not None)): raise ValueError("Must pass either one of momenta or velocities") if momenta is None and velocities is not None: # Using velocities squared_term = torch.square(velocities) * masses.unsqueeze(-1) elif momenta is not None and velocities is None: # Using momenta squared_term = torch.square(momenta) / masses.unsqueeze(-1) else: raise ValueError("Must pass either one of momenta or velocities") if system_idx is None: return 0.5 * torch.sum(squared_term) flattened_squared = torch.sum(squared_term, dim=-1) return 0.5 * torch.segment_reduce( flattened_squared, reduce="sum", lengths=torch.bincount(system_idx) )
[docs] def get_pressure( stress: torch.Tensor, kinetic_energy: float | torch.Tensor, volume: torch.Tensor, dim: int = 3, ) -> torch.Tensor: """Compute the pressure from the stress tensor. The stress tensor is defined as 1/volume * dU/de_ij So the pressure is -1/volume * trace(dU/de_ij) """ return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress))
[docs] def calc_heat_flux( momenta: torch.Tensor | None, masses: torch.Tensor, velocities: torch.Tensor | None, energies: torch.Tensor, stresses: torch.Tensor, batch: torch.Tensor | None = None, *, # Force keyword arguments for booleans is_centroid_stress: bool = False, is_virial_only: bool = False, ) -> torch.Tensor: r"""Calculate the heat flux vector. Computes the microscopic heat flux, :math:`\mathbf{J}` defined as: .. math:: \mathbf{J} = \mathbf{J}^c + \mathbf{J}^v where the convective part :math:`\mathbf{J}^c` and virial part :math:`\mathbf{J}^v` are: .. math:: \mathbf{J}^c &= \sum_i \epsilon_i \mathbf{v}_i \\ \mathbf{J}^v &= \sum_i \sum_j \mathbf{S}_{ij} \cdot \mathbf{v}_j where :math:`\epsilon_i` is the per-atom energy (p.e. + k.e.), :math:`\mathbf{v}_i` is velocity, and :math:`\mathbf{S}_{ij}` is the per-atom stress tensor. Args: momenta: Particle momenta, shape (n_particles, n_dim) masses: Particle masses, shape (n_particles,) velocities: Particle velocities, shape (n_particles, n_dim) energies: Per-atom energies (p.e. + k.e.), shape (n_particles,) stresses: Per-atom stress tensor components: - If is_centroid_stress=False: shape (n_particles, 6) for :math:`[\sigma_{xx}, \sigma_{yy}, \sigma_{zz}, \sigma_{xy}, \sigma_{xz}, \sigma_{yz}]` - If is_centroid_stress=True: shape (n_particles, 9) for :math:`[\mathbf{r}_{ix}f_{ix}, \mathbf{r}_{iy}f_{iy}, \mathbf{r}_{iz}f_{iz}, \mathbf{r}_{ix}f_{iy}, \mathbf{r}_{ix}f_{iz}, \mathbf{r}_{iy}f_{iz}, \mathbf{r}_{iy}f_{ix}, \mathbf{r}_{iz}f_{ix}, \mathbf{r}_{iz}f_{iy}]` batch: Optional tensor indicating system membership is_centroid_stress: Whether stress uses centroid formulation is_virial_only: If True, returns only virial part :math:`\mathbf{J}^v` Returns: Heat flux vector of shape (3,) or (n_systems, 3) """ if momenta is not None and velocities is not None: raise ValueError("Must pass either momenta or velocities, not both") if momenta is None and velocities is None: raise ValueError("Must pass either momenta or velocities") # Deduce velocities if velocities is None: velocities = momenta / masses.unsqueeze(-1) convective_flux = energies.unsqueeze(-1) * velocities # Calculate virial flux if is_centroid_stress: # Centroid formulation: r_i[x,y,z] . f_i[x,y,z] virial_x = -( stresses[:, 0] * velocities[:, 0] # r_ix.f_ix.v_x + stresses[:, 3] * velocities[:, 1] # r_ix.f_iy.v_y + stresses[:, 4] * velocities[:, 2] # r_ix.f_iz.v_z ) virial_y = -( stresses[:, 6] * velocities[:, 0] # r_iy.f_ix.v_x + stresses[:, 1] * velocities[:, 1] # r_iy.f_iy.v_y + stresses[:, 5] * velocities[:, 2] # r_iy.f_iz.v_z ) virial_z = -( stresses[:, 7] * velocities[:, 0] # r_iz.f_ix.v_x + stresses[:, 8] * velocities[:, 1] # r_iz.f_iy.v_y + stresses[:, 2] * velocities[:, 2] # r_iz.f_iz.v_z ) else: # Standard stress tensor components virial_x = -( stresses[:, 0] * velocities[:, 0] # s_xx.v_x + stresses[:, 3] * velocities[:, 1] # s_xy.v_y + stresses[:, 4] * velocities[:, 2] # s_xz.v_z ) virial_y = -( stresses[:, 3] * velocities[:, 0] # s_xy.v_x + stresses[:, 1] * velocities[:, 1] # s_yy.v_y + stresses[:, 5] * velocities[:, 2] # s_yz.v_z ) virial_z = -( stresses[:, 4] * velocities[:, 0] # s_xz.v_x + stresses[:, 5] * velocities[:, 1] # s_yz.v_y + stresses[:, 2] * velocities[:, 2] # s_zz.v_z ) virial_flux = torch.stack([virial_x, virial_y, virial_z], dim=-1) if batch is None: # All atoms virial_sum = torch.sum(virial_flux, dim=0) if is_virial_only: return virial_sum conv_sum = torch.sum(convective_flux, dim=0) return conv_sum + virial_sum # All atoms in each system n_systems = int(torch.max(batch) + 1) virial_sum = torch.zeros( (n_systems, 3), device=velocities.device, dtype=velocities.dtype ) virial_sum.scatter_add_(0, batch.unsqueeze(-1).expand(-1, 3), virial_flux) if is_virial_only: return virial_sum conv_sum = torch.zeros( (n_systems, 3), device=velocities.device, dtype=velocities.dtype ) conv_sum.scatter_add_(0, batch.unsqueeze(-1).expand(-1, 3), convective_flux) return conv_sum + virial_sum
[docs] def system_wise_max_force[T: MDState | OptimState](state: T) -> torch.Tensor: """Compute the maximum force per system. Args: state (SimState): State to compute the maximum force per system for. Returns: torch.Tensor: Maximum forces per system """ system_wise_max_force = torch.zeros( state.n_systems, device=state.device, dtype=state.dtype ) max_forces = state.forces.norm(dim=1) return system_wise_max_force.scatter_reduce( dim=0, index=state.system_idx, src=max_forces, reduce="amax" )