Source code for torch_sim.state

"""The core state representation.

The main SimState class represents atomistic systems with support for batched
operations and conversion to/from various atomistic formats.
"""

import copy
import importlib
import typing
from collections import defaultdict
from collections.abc import Generator, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self

import torch
from torch._prims_common import DeviceLikeType

import torch_sim as ts
from torch_sim.typing import PRNGLike, StateLike


if TYPE_CHECKING:
    from ase import Atoms
    from phonopy.structure.atoms import PhonopyAtoms
    from pymatgen.core import Structure

from torch_sim.constraints import (
    Constraint,
    _dof_per_system,
    merge_constraints,
    validate_constraints,
)


[docs] def coerce_prng(rng: PRNGLike, device: DeviceLikeType | None) -> torch.Generator: """Coerce an int seed or existing Generator into a ``torch.Generator``. Args: rng: An int seed, an existing ``torch.Generator``, or None for unseeded. device: Target device for the returned generator. Returns: A ``torch.Generator`` on *device*. """ if isinstance(rng, torch.Generator): if rng.device == device: return rng new_generator = torch.Generator(device=device) new_generator.set_state(rng.get_state()) return new_generator if isinstance(rng, int): generator = torch.Generator(device=device) generator.manual_seed(rng) return generator if rng is None: return torch.Generator(device=device) raise ValueError(f"Invalid rng type: {type(rng)}")
[docs] def require_system_idx(system_idx: torch.Tensor | None) -> torch.Tensor: """Return non-null system indices or raise with a clear invariant message.""" if system_idx is None: raise RuntimeError("system_idx is set by SimState.__post_init__") return system_idx
[docs] @dataclass(kw_only=True) class SimState: """State representation for atomistic systems with batched operations support. Contains the fundamental properties needed to describe an atomistic system: positions, masses, unit cell, periodic boundary conditions, and atomic numbers. Supports batched operations where multiple atomistic systems can be processed simultaneously, managed through system indices. States support slicing, cloning, splitting, popping, and movement to other data structures or devices. Slicing is supported through fancy indexing, e.g. `state[[0, 1, 2]]` will return a new state containing only the first three systems. The other operations are available through the `pop`, `split`, `clone`, and `to` methods. Attributes: positions (torch.Tensor): Atomic positions with shape (n_atoms, 3) masses (torch.Tensor): Atomic masses with shape (n_atoms,) cell (torch.Tensor): Unit cell vectors with shape (n_systems, 3, 3). Note that we use a column vector convention, i.e. the cell vectors are stored as `[[a1, b1, c1], [a2, b2, c2], [a3, b3, c3]]` as opposed to the row vector convention `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]` used by ASE. pbc (bool | list[bool] | torch.Tensor): indicates periodic boundary conditions in each axis. If a boolean is provided, all axes are assumed to have the same periodic boundary conditions. atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) system_idx (torch.Tensor): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. constraints (list["Constraint"] | None): List of constraints applied to the system. Constraints affect degrees of freedom and modify positions. Properties: wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary conditions device (torch.device): Device of the positions tensor dtype (torch.dtype): Data type of the positions tensor n_atoms (int): Total number of atoms across all systems n_systems (int): Number of unique systems in the system Notes: - positions, masses, and atomic_numbers must have shape (n_atoms, 3). - cell must be in the conventional matrix form. - system indices must be unique consecutive integers starting from 0. Examples: >>> state = initialize_state( ... [ase_atoms_1, ase_atoms_2, ase_atoms_3], device, dtype ... ) >>> state.n_systems 3 >>> new_state = state[[0, 1]] >>> new_state.n_systems 2 >>> cloned_state = state.clone() """ positions: torch.Tensor masses: torch.Tensor cell: torch.Tensor pbc: torch.Tensor # coerced from bool/list[bool] by __setattr__ atomic_numbers: torch.Tensor charge: torch.Tensor | None = field(default=None) spin: torch.Tensor | None = field(default=None) system_idx: torch.Tensor = field(default=None) # type: ignore[assignment] # coerced from None by __setattr__ _constraints: list["Constraint"] = field(default_factory=lambda: []) # noqa: PIE807 _rng: PRNGLike = field(default=None, repr=False) if TYPE_CHECKING: def __init__( # noqa: D107 self, *, positions: torch.Tensor, masses: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor | list[bool] | bool, atomic_numbers: torch.Tensor, charge: torch.Tensor | None = None, spin: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, _constraints: list[Constraint] | None = None, _rng: PRNGLike = None, ) -> None: ... _atom_attributes: ClassVar[set[str]] = { "positions", "masses", "atomic_numbers", "system_idx", } _system_attributes: ClassVar[set[str]] = {"cell", "charge", "spin"} _global_attributes: ClassVar[set[str]] = {"pbc", "_rng"} @property def rng(self) -> torch.Generator: """Lazily initialized per-state RNG on the state's device.""" self._rng = coerce_prng(self._rng, self.device) return self._rng @rng.setter def rng(self, value: PRNGLike) -> None: self._rng = value def __setattr__(self, name: str, value: object) -> None: """Coerce pbc and system_idx on every assignment.""" if name == "pbc" and not isinstance(value, torch.Tensor): if isinstance(value, bool): value = [value] * 3 value = torch.tensor(value, dtype=torch.bool, device=self.device) elif name == "system_idx": if value is None: if hasattr(self, "positions"): value = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) elif isinstance(value, torch.Tensor): _, counts = torch.unique_consecutive(value, return_counts=True) if not torch.all(counts == torch.bincount(value)): raise ValueError("System indices must be unique consecutive integers") super().__setattr__(name, value) def __post_init__(self) -> None: # noqa: C901 """Initialize the SimState and validate the arguments.""" # Check that positions, masses and atomic numbers have compatible shapes shapes = [ getattr(self, attr).shape[0] for attr in ("positions", "masses", "atomic_numbers") ] if len(set(shapes)) > 1: raise ValueError( f"Incompatible shapes: positions {shapes[0]}, " f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) # Get n_systems from system_idx (now guaranteed to be non-None) _, counts = torch.unique_consecutive(self.system_idx, return_counts=True) n_systems = len(counts) if self.constraints: validate_constraints(self.constraints, state=self) if self.charge is None: self.charge = torch.zeros(n_systems, device=self.device, dtype=self.dtype) elif self.charge.shape[0] != n_systems: raise ValueError(f"Charge must have shape (n_systems={n_systems},)") if self.spin is None: self.spin = torch.zeros(n_systems, device=self.device, dtype=self.dtype) elif self.spin.shape[0] != n_systems: raise ValueError(f"Spin must have shape (n_systems={n_systems},)") if self.cell.ndim != 3: self.cell = self.cell.unsqueeze(0) if self.cell.shape[-2:] != (3, 3): raise ValueError("Cell must have shape (n_systems, 3, 3)") if self.cell.shape[0] != n_systems: raise ValueError( f"Cell must have shape (n_systems={n_systems}, 3, 3), " f"got {self.cell.shape}" ) # if devices aren't all the same, raise an error, in a clean way devices = { attr: getattr(self, attr).device for attr in ( "positions", "masses", "cell", "atomic_numbers", "pbc", "system_idx", ) } if len(set(devices.values())) > 1: raise ValueError("All tensors must be on the same device") @classmethod def _get_all_attributes(cls) -> set[str]: """Get all attributes of the SimState.""" return ( cls._atom_attributes | cls._system_attributes | cls._global_attributes | {"_constraints"} ) @property def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, otherwise returns unwrapped positions with shape (n_atoms, 3). """ if not self.pbc.any(): return self.positions return ts.transforms.pbc_wrap_batched( self.positions, self.cell, self.system_idx, self.pbc ) @property def device(self) -> torch.device: """The device where the tensor data is located.""" return self.positions.device @property def dtype(self) -> torch.dtype: """The data type of the positions tensor.""" return self.positions.dtype @property def n_atoms(self) -> int: """Total number of atoms in the system across all systems.""" return self.positions.shape[0] @property def n_atoms_per_system(self) -> torch.Tensor: """Number of atoms per system.""" return self.system_idx.bincount() @property def n_systems(self) -> int: """Number of systems in the system.""" return torch.unique(self.system_idx).shape[0] @property def volume(self) -> torch.Tensor: """Volume of the system.""" return torch.det(self.cell) @property def attributes(self) -> dict[str, Any]: """Get all public attributes of the state.""" return {attr: getattr(self, attr) for attr in self._get_all_attributes()} @property def column_vector_cell(self) -> torch.Tensor: """Unit cell following the column vector convention.""" return self.cell @column_vector_cell.setter def column_vector_cell(self, value: torch.Tensor) -> None: """Set the unit cell from value following the column vector convention. Args: value: The unit cell as a column vector """ self.cell = value @property def row_vector_cell(self) -> torch.Tensor: """Unit cell following the row vector convention.""" return self.cell.mT @row_vector_cell.setter def row_vector_cell(self, value: torch.Tensor) -> None: """Set the unit cell from value following the row vector convention. Args: value: The unit cell as a row vector """ self.cell = value.mT
[docs] def set_constrained_positions(self, new_positions: torch.Tensor) -> None: """Set the positions and apply constraints if they exist. Args: new_positions: New positions tensor with shape (n_atoms, 3) """ # Apply constraints if they exist for constraint in self.constraints: constraint.adjust_positions(self, new_positions) self.positions = new_positions
[docs] def set_constrained_cell( self, new_cell: torch.Tensor, scale_atoms: bool = False, # noqa: FBT001, FBT002 ) -> None: """Set the cell, apply constraints, and optionally scale atomic positions. Args: new_cell: New cell tensor with shape (n_systems, 3, 3) in column vector convention scale_atoms: Whether to scale atomic positions to preserve fractional coordinates. Defaults to False. """ for constraint in self.constraints: constraint.adjust_cell(self, new_cell) self.set_cell(new_cell, scale_atoms=scale_atoms)
@property def constraints(self) -> list[Constraint]: """Get the constraints for the SimState. Returns: list["Constraint"]: List of constraints applied to the system. """ return self._constraints @constraints.setter def constraints(self, constraints: list[Constraint] | Constraint) -> None: """Set the constraints for the SimState. Args: constraints (list["Constraint"] | None): List of constraints to apply. If None, no constraints are applied. Raises: ValueError: If constraints are invalid or span multiple systems """ # check it is a list if isinstance(constraints, Constraint): constraints = [constraints] # Validate new constraints before adding validate_constraints(constraints, state=self) self._constraints = constraints
[docs] def set_cell( self, cell: torch.Tensor, scale_atoms: bool = False, # noqa: FBT001, FBT002 ) -> None: """Set the unit cell of the system, optionally scaling atomic positions. Torch version of ASE Atoms.set_cell. Args: cell (torch.Tensor): New unit cell with shape (n_systems, 3, 3) scale_atoms (bool, optional): Whether to scale atomic positions according to the change in cell. Defaults to False. """ if cell.shape != self.cell.shape: raise ValueError( f"New cell must have shape {self.cell.shape}, got {cell.shape}" ) if scale_atoms: M = torch.linalg.solve(self.cell.mT, cell.mT) self.positions = torch.bmm( self.positions.unsqueeze(1), M[self.system_idx] ).squeeze(1) self.cell = cell
[docs] def get_number_of_degrees_of_freedom(self) -> torch.Tensor: """Calculate degrees of freedom accounting for constraints. Returns: torch.Tensor: Number of degrees of freedom per system, with shape (n_systems,). Each system starts with 3 * n_atoms_per_system degrees of freedom, minus any degrees removed by constraints. Raises: ValueError: If any system has zero or negative degrees of freedom. This strict behavior is used by simulation routines that require physically valid DOF. """ dof_per_system = _dof_per_system(self, self.constraints) if (dof_per_system <= 0).any(): raise ValueError("Degrees of freedom cannot be zero or negative") return dof_per_system
@staticmethod def _clone_attr(value: object) -> object: """Clone one attribute value, preserving RNG generator state.""" if isinstance(value, torch.Tensor): return value.clone() if isinstance(value, torch.Generator): cloned_generator = torch.Generator(device=value.device) cloned_generator.set_state(value.get_state()) return cloned_generator return copy.deepcopy(value)
[docs] def clone(self) -> Self: """Create a deep copy of the SimState. Creates a new SimState object with identical but independent tensors, allowing modification without affecting the original. Returns: SimState: A new SimState object with the same properties as the original """ attrs: dict[str, Any] = { attr_name: self._clone_attr(attr_value) for attr_name, attr_value in self.attributes.items() } return type(self)(**attrs)
[docs] @classmethod def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: """Create a new state from an existing state with additional attributes. This method copies attributes from the source state that are valid for the target state class, and adds any additional attributes needed. It supports upcasting (SimState -> MDState), downcasting (MDState -> SimState), and cross-casting (MDState -> OptimState) between state types. Args: state: Source state to copy base attributes from **additional_attrs: Additional attributes required by the target state class Returns: New state of the target class with copied and additional attributes Example: >>> from torch_sim.integrators.md import MDState >>> md_state = MDState.from_state( ... sim_state, ... energy=model_output["energy"], ... forces=model_output["forces"], ... momenta=torch.zeros_like(sim_state.positions), ... ) """ attrs = {} for attr_name, attr_value in state.attributes.items(): if attr_name in cls._get_all_attributes(): attrs[attr_name] = cls._clone_attr(attr_value) # Add/override with additional attributes attrs.update(additional_attrs) return cls(**attrs)
[docs] def to_atoms(self) -> list["Atoms"]: """Convert the SimState to a list of ASE Atoms objects. Returns: list[Atoms]: A list of ASE Atoms objects, one per system """ return ts.io.state_to_atoms(self)
[docs] def to_structures(self) -> list["Structure"]: """Convert the SimState to a list of pymatgen Structure objects. Returns: list[Structure]: A list of pymatgen Structure objects, one per system """ return ts.io.state_to_structures(self)
[docs] def to_phonopy(self) -> list["PhonopyAtoms"]: """Convert the SimState to a list of PhonopyAtoms objects. Returns: list[PhonopyAtoms]: A list of PhonopyAtoms objects, one per system """ return ts.io.state_to_phonopy(self)
[docs] def split(self) -> list[Self]: """Split the SimState into a list of single-system SimStates. Divides the current state into separate states, each containing a single system, preserving all properties appropriately for each system. Returns: list[SimState]: A list of SimState objects, one per system """ return _split_state(self)
[docs] def pop(self, system_indices: int | list[int] | slice | torch.Tensor) -> list[Self]: """Pop off states with the specified system indices. This method modifies the original state object by removing the specified systems and returns the removed systems as separate SimState objects. Args: system_indices (int | list[int] | slice | torch.Tensor): The system indices to pop Returns: list[SimState]: Popped SimState objects, one per system index Notes: This method modifies the original SimState in-place. """ system_indices = _normalize_system_indices( system_indices, self.n_systems, self.device ) # Get the modified state and popped states modified_state, popped_states = _pop_states(self, system_indices) # Update all attributes of self with the modified state's attributes for attr_name, attr_value in modified_state.attributes.items(): setattr(self, attr_name, attr_value) return popped_states
[docs] def to( self, device: torch.device | None = None, dtype: torch.dtype | None = None ) -> Self: """Convert the SimState to a new device and/or data type. Args: device (torch.device, optional): The target device. Defaults to current device. dtype (torch.dtype, optional): The target data type. Defaults to current dtype. Returns: SimState: A new SimState with tensors on the specified device and dtype """ return _state_to_device(self, device, dtype)
def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> Self: """Enable standard Python indexing syntax for slicing batches. Args: system_indices (int | list[int] | slice | torch.Tensor): The system indices to include Returns: SimState: A new SimState containing only the specified systems """ # TODO: need to document that slicing is supported # Reuse the existing slice method system_indices = _normalize_system_indices( system_indices, self.n_systems, self.device ) return _slice_state(self, system_indices) def __init_subclass__(cls, **kwargs) -> None: """Enforce that all derived states cannot have tensor attributes that can also be None. This is because torch.concatenate cannot concat between a tensor and a None. Also enforce all of child classes's attributes are specified in _atom_attributes, _system_attributes, or _global_attributes. """ cls._assert_no_tensor_attributes_can_be_none() cls._assert_all_attributes_have_defined_scope() super().__init_subclass__(**kwargs) @classmethod def _assert_no_tensor_attributes_can_be_none(cls) -> None: # We need to use get_type_hints to correctly inspect the types # exceptions exist because the type hint doesn't actually reflect the real type # (since we change their type in the post_init) exceptions = {"system_idx", "charge", "spin"} type_hints = typing.get_type_hints(cls) for attr_name, attr_type_hint in type_hints.items(): origin = typing.get_origin(attr_type_hint) if attr_name in exceptions: continue is_union = origin is typing.Union if not is_union and origin is not None: # For Python 3.10+ `|` syntax, origin is types.UnionType # We check by name to be robust against module reloading/patching issues is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" if is_union: args = typing.get_args(attr_type_hint) if torch.Tensor in args and type(None) in args: raise TypeError( f"Attribute '{attr_name}' in class '{cls.__name__}' is not " "allowed to be of type 'torch.Tensor | None' because torch.cat " "cannot concatenate between a tensor and a None. Please default " "the tensor with dummy values and track the 'None' case." ) @classmethod def _assert_all_attributes_have_defined_scope(cls) -> None: all_defined_attributes = ( cls._atom_attributes | cls._system_attributes | cls._global_attributes ) # 1) assert that no attribute is defined twice in all_defined_attributes duplicates = ( (cls._atom_attributes & cls._system_attributes) | (cls._atom_attributes & cls._global_attributes) | (cls._system_attributes & cls._global_attributes) ) if duplicates: raise TypeError( f"Attributes {duplicates} are declared multiple times in {cls.__name__} " "in _atom_attributes, _system_attributes, or _global_attributes" ) # 2) assert that all attributes are defined in all_defined_attributes all_annotations = {} for parent_cls in cls.mro(): if hasattr(parent_cls, "__annotations__"): all_annotations.update(parent_cls.__annotations__) # Get class namespace attributes (methods, properties, class vars with values) class_namespace = vars(cls) attributes_to_check = set(class_namespace.keys()) | set(all_annotations.keys()) for attr_name in attributes_to_check: is_special_attribute = attr_name.startswith("__") is_private_attribute = attr_name.startswith("_") and not is_special_attribute is_property = attr_name in class_namespace and isinstance( class_namespace.get(attr_name), property ) is_method = hasattr(cls, attr_name) and callable(getattr(cls, attr_name)) is_class_variable = ( # Note: _atom_attributes, _system_attributes, and _global_attributes # are all class variables typing.get_origin(all_annotations.get(attr_name)) is typing.ClassVar ) if ( is_special_attribute or is_private_attribute or is_property or is_method or is_class_variable ): continue if attr_name not in all_defined_attributes: raise TypeError( f"Attribute '{attr_name}' is not defined in {cls.__name__} in any " "of _atom_attributes, _system_attributes, or _global_attributes" )
[docs] @dataclass(kw_only=True) class DeformGradMixin: """Mixin for states that support deformation gradients.""" reference_cell: torch.Tensor _system_attributes: ClassVar[set[str]] = {"reference_cell"} if TYPE_CHECKING: # define this under a TYPE_CHECKING block to avoid it being included in the # dataclass __init__ during runtime row_vector_cell: torch.Tensor @property def reference_row_vector_cell(self) -> torch.Tensor: """Get the original unit cell in terms of row vectors.""" return self.reference_cell.mT @reference_row_vector_cell.setter def reference_row_vector_cell(self, value: torch.Tensor) -> None: """Set the original unit cell in terms of row vectors.""" self.reference_cell = value.mT @staticmethod def _deform_grad( reference_row_vector_cell: torch.Tensor, row_vector_cell: torch.Tensor ) -> torch.Tensor: """Calculate the deformation gradient from original cell to current cell. Returns: The deformation gradient """ return torch.linalg.solve(reference_row_vector_cell, row_vector_cell).transpose( -2, -1 )
[docs] def deform_grad(self) -> torch.Tensor: """Calculate the deformation gradient from original cell to current cell. Returns: The deformation gradient """ return self._deform_grad(self.reference_row_vector_cell, self.row_vector_cell)
def _normalize_system_indices( system_indices: int | Sequence[int] | slice | torch.Tensor, n_systems: int, device: torch.device, ) -> torch.Tensor: """Normalize system indices to handle negative indices and different input types. Converts various system index representations to a consistent tensor format, handling negative indices in the Python style (counting from the end). Args: system_indices (int | list[int] | slice | torch.Tensor): The system indices to normalize n_systems (int): Total number of systems in the system device (torch.device): Device to place the output tensor on Returns: torch.Tensor: Normalized system indices as a tensor Raises: TypeError: If system_indices is of an unsupported type """ if isinstance(system_indices, int): # Handle negative integer indexing if system_indices < 0: system_indices = n_systems + system_indices return torch.tensor([system_indices], device=device) if isinstance(system_indices, list): # Handle negative indices in lists normalized = [idx if idx >= 0 else n_systems + idx for idx in system_indices] return torch.tensor(normalized, device=device) if isinstance(system_indices, slice): # Let PyTorch handle the slice conversion with negative indices return torch.arange(n_systems, device=device)[system_indices] if isinstance(system_indices, torch.Tensor): # Handle negative indices in tensors return torch.where(system_indices < 0, n_systems + system_indices, system_indices) raise TypeError(f"Unsupported index type: {type(system_indices)}") def _state_to_device[T: SimState]( state: T, device: torch.device | None = None, dtype: torch.dtype | None = None ) -> T: """Convert the SimState to a new device and dtype. Creates a new SimState with all tensors moved to the specified device and with the specified data type. Args: state (SimState): The state to convert device (torch.device, optional): The target device. Defaults to current device. dtype (torch.dtype, optional): The target data type. Defaults to current dtype. Returns: SimState: A new SimState with tensors on the specified device and dtype """ if device is None: device = state.device if dtype is None: dtype = state.dtype attrs = state.attributes for attr_name, attr_value in attrs.items(): if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.to(device=device) elif isinstance(attr_value, torch.Generator): attrs[attr_name] = coerce_prng(attr_value, device) if dtype is not None: attrs["positions"] = attrs["positions"].to(dtype=dtype) attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) return type(state)(**attrs)
[docs] def get_attrs_for_scope( state: SimState, scope: Literal["per-atom", "per-system", "global"] ) -> Generator[tuple[str, Any], None, None]: """Get attributes for a given scope. Args: state (SimState): The state to get attributes for scope (Literal["per-atom", "per-system", "global"]): The scope to get attributes for Returns: Generator[tuple[str, Any], None, None]: A generator of attribute names and values """ match scope: case "per-atom": attr_names = state._atom_attributes # noqa: SLF001 case "per-system": attr_names = state._system_attributes # noqa: SLF001 case "global": attr_names = state._global_attributes # noqa: SLF001 case _: raise ValueError(f"Unknown scope: {scope!r}") for attr_name in attr_names: yield attr_name, getattr(state, attr_name)
def _filter_attrs_by_index( state: SimState, atom_indices: torch.Tensor, system_indices: torch.Tensor, ) -> dict: """Filter attributes by integer atom and system indices. Selects subsets of attributes using integer index tensors, preserving the order specified by the indices. Args: state (SimState): The state to filter atom_indices (torch.Tensor): Integer indices of atoms to include system_indices (torch.Tensor): Integer indices of systems to include Returns: dict: Filtered attributes with appropriate handling for each scope """ filtered_attrs = dict(get_attrs_for_scope(state, "global")) # Build inverse maps: old index → new position in the output. # These are used for both constraint remapping and per-atom system_idx. atom_remap = torch.empty(state.n_atoms, dtype=torch.long, device=state.device) atom_remap[atom_indices] = torch.arange(len(atom_indices), device=state.device) if len(system_indices) == 0: system_remap = torch.empty(0, device=state.device, dtype=torch.long) else: max_idx = int(system_indices.max().item()) + 1 system_remap = torch.empty(max_idx, device=state.device, dtype=torch.long) system_remap[system_indices] = torch.arange( len(system_indices), device=state.device ) # select_constraint uses boolean masks (which lose ordering), so we must # remap constraint atom_idx / system_idx afterward to match the actual # output order given by atom_indices / system_indices. atom_mask = torch.zeros(state.n_atoms, dtype=torch.bool, device=state.device) atom_mask[atom_indices] = True system_mask = torch.zeros(state.n_systems, dtype=torch.bool, device=state.device) system_mask[system_indices] = True filtered_attrs["_constraints"] = [ c for con in copy.deepcopy(state.constraints) if (c := con.select_constraint(atom_mask, system_mask)) ] new_atom_idx = atom_remap[torch.where(atom_mask)[0]] new_system_idx = system_remap[torch.where(system_mask)[0]] for c in filtered_attrs["_constraints"]: if hasattr(c, "atom_idx") and isinstance(c.atom_idx, torch.Tensor): c.atom_idx = new_atom_idx[c.atom_idx] # ty: ignore[invalid-assignment] if hasattr(c, "system_idx") and isinstance(c.system_idx, torch.Tensor): c.system_idx = new_system_idx[c.system_idx] # ty: ignore[invalid-assignment] for name, val in get_attrs_for_scope(state, "per-atom"): filtered_attrs[name] = ( system_remap[val[atom_indices]] if name == "system_idx" else val[atom_indices] ) for name, val in get_attrs_for_scope(state, "per-system"): filtered_attrs[name] = ( val[system_indices] if isinstance(val, torch.Tensor) else val ) return filtered_attrs def _split_state[T: SimState](state: T) -> list[T]: """Split a SimState into a list of states, each containing a single system. Divides a multi-system state into individual single-system states, preserving appropriate properties for each system. Args: state (SimState): The SimState to split Returns: list[SimState]: A list of SimState objects, each containing a single system """ system_sizes = state.n_atoms_per_system.tolist() split_per_atom = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name != "system_idx": split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) split_per_system = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): if isinstance(attr_value, torch.Tensor): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) else: # Non-tensor attributes are replicated for each split split_per_system[attr_name] = [attr_value] * state.n_systems global_attrs = dict(get_attrs_for_scope(state, "global")) # Create a state for each system states: list[T] = [] n_systems = len(system_sizes) zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64) cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0))) for sys_idx in range(n_systems): # Build per-system attributes (padded attributes stay padded for consistency) per_system_dict = { attr_name: split_per_system[attr_name][sys_idx] for attr_name in split_per_system } system_attrs = { # Create a system tensor with all zeros for this system "system_idx": torch.zeros( system_sizes[sys_idx], device=state.device, dtype=torch.int64 ), # Add the split per-atom attributes **{ attr_name: split_per_atom[attr_name][sys_idx] for attr_name in split_per_atom }, # Add the split per-system attributes (with unpadding applied) **per_system_dict, # Add the global attributes **global_attrs, } start_idx = int(cumsum_atoms[sys_idx].item()) end_idx = int(cumsum_atoms[sys_idx + 1].item()) atom_idx = torch.arange(start_idx, end_idx, device=state.device) new_constraints: list[Constraint] = [] for constraint in state.constraints: sub = constraint.select_sub_constraint(atom_idx, sys_idx) if sub is not None: new_constraints.append(sub) system_attrs["_constraints"] = new_constraints states.append(type(state)(**system_attrs)) # ty: ignore[invalid-argument-type] return states def _pop_states[T: SimState]( state: T, pop_indices: list[int] | torch.Tensor ) -> tuple[T, list[T]]: """Pop off the states with the specified indices. Extracts and removes the specified system indices from the state. Args: state (SimState): The SimState to modify pop_indices (list[int] | torch.Tensor): The system indices to extract and remove Returns: tuple[SimState, list[SimState]]: A tuple containing: - The modified original state with specified systems removed - A list of the extracted SimStates, one per popped system Notes: Unlike the pop method, this function does not modify the input state. """ if len(pop_indices) == 0: return state, [] if isinstance(pop_indices, list): pop_indices = torch.tensor(pop_indices, device=state.device, dtype=torch.int64) # Derive keep/pop atom and system indices all_systems = torch.arange(state.n_systems, device=state.device) keep_system_indices = all_systems[~torch.isin(all_systems, pop_indices)] pop_atom_mask = torch.isin(state.system_idx, pop_indices) keep_atom_indices = torch.where(~pop_atom_mask)[0] pop_atom_indices = torch.where(pop_atom_mask)[0] # Filter attributes for keep and pop states keep_attrs = _filter_attrs_by_index(state, keep_atom_indices, keep_system_indices) pop_attrs = _filter_attrs_by_index(state, pop_atom_indices, pop_indices) # Create the keep state keep_state: T = type(state)(**keep_attrs) # Create and split the pop state pop_state: T = type(state)(**pop_attrs) pop_states = _split_state(pop_state) return keep_state, pop_states def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor) -> T: """Slice a substate from the SimState containing only the specified system indices. Creates a new SimState containing only the specified systems, preserving the requested order. E.g., system_indices=[3, 1, 4] results in original systems 3, 1, 4 becoming new systems 0, 1, 2. Args: state (SimState): The state to slice system_indices (list[int] | torch.Tensor): System indices to include in the sliced state (order preserved in the result) Returns: SimState: A new SimState object containing only the specified systems Raises: ValueError: If system_indices is empty """ if isinstance(system_indices, list): system_indices = torch.tensor( system_indices, device=state.device, dtype=torch.int64 ) if len(system_indices) == 0: raise ValueError("system_indices cannot be empty") # Build atom indices in requested order (preserves system_indices order) system_indices = system_indices.reshape(-1) cumsum = torch.cat( (state.n_atoms_per_system.new_zeros(1), state.n_atoms_per_system.cumsum(0)) ) atom_indices = torch.cat( [ torch.arange(cumsum[i].item(), cumsum[i + 1].item(), device=state.device) for i in system_indices ] ) filtered_attrs = _filter_attrs_by_index(state, atom_indices, system_indices) return type(state)(**filtered_attrs)
[docs] def concatenate_states[T: SimState]( # noqa: C901, PLR0915 states: Sequence[T], device: torch.device | None = None ) -> T: """Concatenate a list of SimStates into a single SimState. Combines multiple states into a single state with multiple systems. Global properties are taken from the first state, and per-atom and per-system properties are concatenated. Args: states (Sequence[SimState]): A list of SimState objects to concatenate device (torch.device, optional): The device to place the concatenated state on. Defaults to the device of the first state. Returns: SimState: A new SimState containing all input states as separate systems Raises: ValueError: If states is empty TypeError: If not all states are of the same type """ if not states: raise ValueError("Cannot concatenate an empty list of states") # Get the first state to determine properties first_state = states[0] # Ensure all states are of the same class state_class = type(first_state) if not all(isinstance(state, state_class) for state in states): raise TypeError("All states must be of the same type") # Use the target device or default to the first state's device target_device = device or first_state.device # Initialize result with global properties from first state concatenated = dict(get_attrs_for_scope(first_state, "global")) # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) per_system_tensors = defaultdict(list) new_system_indices = [] system_offset = 0 num_atoms_per_state = [] # Process all states in a single pass for state in states: # Move state to target device if needed if state.device != target_device: state = state.to(target_device) # Collect per-atom properties for prop, val in get_attrs_for_scope(state, "per-atom"): if prop == "system_idx": # skip system_idx, it will be handled below continue per_atom_tensors[prop].append(val) # Collect per-system properties for prop, val in get_attrs_for_scope(state, "per-system"): per_system_tensors[prop].append(val) # Update system indices num_systems = state.n_systems new_indices = state.system_idx + system_offset new_system_indices.append(new_indices) num_atoms_per_state.append(state.n_atoms) system_offset += num_systems # Concatenate collected tensors for prop, tensors in per_atom_tensors.items(): # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) # Get padded attributes if defined on the state class padded_attrs = getattr(first_state, "_padded_system_attributes", set()) for prop, tensors in per_system_tensors.items(): # if tensors: if isinstance(tensors[0], torch.Tensor): # TODO(AG): Is there a clean way to handle this? if prop in padded_attrs: # Pad tensors to max size before concatenating # Detect tensor shape to determine padding strategy first_tensor = tensors[0] ndim = first_tensor.ndim if ndim == 3: # Shape [S, D, D] required for BFGS hessian # Pad last two dimensions max_size = max(t.shape[-1] for t in tensors) padded_tensors = [] for t in tensors: if t.shape[-1] < max_size: pad_size = max_size - t.shape[-1] t = torch.nn.functional.pad(t, (0, pad_size, 0, pad_size)) padded_tensors.append(t) concatenated[prop] = torch.cat(padded_tensors, dim=0) elif ndim == 4: # Shape [S, H, M, 3] required for L-BFGS history # Pad dimension 2 (M) to max, and dimension 1 (H) to max max_m = max(t.shape[2] for t in tensors) # max atoms dim max_h = max(t.shape[1] for t in tensors) # max history dim padded_tensors = [] for t in tensors: s_dim, h_dim, m_dim, last_dim = t.shape if h_dim == 0: # Special case: empty history, just create new shape t = torch.zeros( (s_dim, max_h, max_m, last_dim), device=t.device, dtype=t.dtype, ) elif m_dim < max_m or h_dim < max_h: pad_m = max_m - m_dim pad_h = max_h - h_dim # For [S, H, M, 3]: pad M (dim 2) and H (dim 1) t = torch.nn.functional.pad(t, (0, 0, 0, pad_m, 0, pad_h)) padded_tensors.append(t) concatenated[prop] = torch.cat(padded_tensors, dim=0) else: # Unknown shape, just concatenate without padding concatenated[prop] = torch.cat(tensors, dim=0) else: concatenated[prop] = torch.cat(tensors, dim=0) else: # Non-tensor attributes, take first one (they should all be identical) concatenated[prop] = tensors[0] # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) # Merge constraints constraint_lists = [state.constraints for state in states] num_systems_per_state = [state.n_systems for state in states] constraints = merge_constraints( constraint_lists, torch.tensor(num_atoms_per_state, device=target_device), torch.tensor(num_systems_per_state, device=target_device), ) # Create a new instance of the same class return state_class(**concatenated, _constraints=constraints)
[docs] def initialize_state( system: StateLike, device: torch.device | None = None, dtype: torch.dtype | None = None, ) -> SimState: """Initialize state tensors from a atomistic system representation. Converts various atomistic system representations (ASE Atoms, pymatgen Structure, PhonopyAtoms, or existing SimState) to a SimState object. Args: system (StateLike): Input system to convert to state tensors device (torch.device): Device to create tensors on dtype (torch.dtype): Data type for tensor values Returns: SimState: State representation initialized from input system Raises: ValueError: If system type is not supported or if list items have inconsistent types """ # TODO: create a way to pass velocities from pmg and ase if isinstance(system, SimState): return system.clone().to(device, dtype) if isinstance(system, list | tuple) and all(isinstance(s, SimState) for s in system): system: list[SimState] = typing.cast("list[SimState]", system) return ts.concatenate_states(system) converters = [ ("pymatgen.core", "Structure", ts.io.structures_to_state), ("ase", "Atoms", ts.io.atoms_to_state), ("phonopy.structure.atoms", "PhonopyAtoms", ts.io.phonopy_to_state), ] # Try each converter for module_path, class_name, converter_func in converters: try: module = importlib.import_module(module_path) cls = getattr(module, class_name) if isinstance(system, cls) or ( isinstance(system, list | tuple) and all(isinstance(s, cls) for s in system) ): return converter_func(system, device, dtype) except ImportError: continue # remaining code just for informative error all_same_type = ( isinstance(system, list | tuple) and all(isinstance(s, type(system[0])) for s in system) and system ) if isinstance(system, list | tuple) and not all_same_type: raise ValueError( f"All items in list must be of the same type, " f"found {type(system[0])} and {type(system[1])}" ) system_type = ( f"list[{type(system[0])}]" if isinstance(system, list | tuple) else type(system) ) raise ValueError(f"Unsupported {system_type=}")