SimState

class torch_sim.state.SimState(positions, masses, cell, pbc, atomic_numbers, charge=None, spin=None, system_idx=None, _constraints=<factory>)[source]

Bases: object

State representation for atomistic systems with batched operations support.

Contains the fundamental properties needed to describe an atomistic system: positions, masses, unit cell, periodic boundary conditions, and atomic numbers. Supports batched operations where multiple atomistic systems can be processed simultaneously, managed through system indices.

States support slicing, cloning, splitting, popping, conversion to other data structures (ASE, pymatgen, phonopy), and transfer to other devices or dtypes. Slicing uses fancy indexing: e.g. state[[0, 1, 2]] returns a new state containing only the first three systems. The other operations are available through the clone, split, pop, and to methods, and through to_atoms, to_structures, and to_phonopy for conversion.
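To make the slicing behaviour concrete, here is a minimal sketch in plain PyTorch of what selecting systems has to do with the per-atom system indices. The helper select_systems is hypothetical and is not part of torch_sim. It builds an atom mask for the requested systems and renumbers the surviving indices so they are again consecutive from 0 (this sketch keeps the systems in ascending original order). The per-system cell would be selected separately, e.g. with cell[keep].

```python
import torch


def select_systems(system_idx: torch.Tensor, keep: list[int]):
    """Hypothetical helper: atom mask and renumbered system_idx for `keep`."""
    keep_t = torch.tensor(keep, dtype=system_idx.dtype)
    # True for every atom that belongs to one of the requested systems
    atom_mask = torch.isin(system_idx, keep_t)
    # Renumber surviving system indices to 0..len(keep)-1 (ascending original order)
    _, new_system_idx = torch.unique(system_idx[atom_mask], return_inverse=True)
    return atom_mask, new_system_idx


# Four systems with 2, 1, 3 and 2 atoms
system_idx = torch.tensor([0, 0, 1, 2, 2, 2, 3, 3])
positions = torch.arange(24, dtype=torch.float64).reshape(8, 3)

mask, new_system_idx = select_systems(system_idx, [1, 3])
sub_positions = positions[mask]   # rows 2, 6 and 7
# new_system_idx -> tensor([0, 1, 1])
```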

Variables:
  • positions (torch.Tensor) – Atomic positions with shape (n_atoms, 3)

  • masses (torch.Tensor) – Atomic masses with shape (n_atoms,)

  • cell (torch.Tensor) – Unit cell vectors with shape (n_systems, 3, 3). Note that we use a column vector convention, i.e. the cell vectors are stored as [[a1, b1, c1], [a2, b2, c2], [a3, b3, c3]] as opposed to the row vector convention [[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]] used by ASE.

  • pbc (bool | list[bool] | torch.Tensor) – Indicates periodic boundary conditions along each axis. If a single boolean is provided, it applies to all three axes.

  • atomic_numbers (torch.Tensor) – Atomic numbers with shape (n_atoms,)

  • system_idx (torch.Tensor) – Maps each atom to the index of the system it belongs to. Has shape (n_atoms,); its distinct values must be consecutive integers starting from 0.

  • constraints (list["Constraint"] | None) – List of constraints applied to the system. Constraints affect degrees of freedom and modify positions.

Properties:
  • wrap_positions (torch.Tensor) – Positions wrapped according to periodic boundary conditions

  • device (torch.device) – Device of the positions tensor

  • dtype (torch.dtype) – Data type of the positions tensor

  • n_atoms (int) – Total number of atoms across all systems

  • n_systems (int) – Number of systems in the batch

Notes

  • positions must have shape (n_atoms, 3); masses and atomic_numbers must have shape (n_atoms,).

  • cell must have shape (n_systems, 3, 3) and follow the column vector convention described above.

  • The values in system_idx must be consecutive integers starting from 0.
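The shape rules in these notes can be checked mechanically. The sketch below is a hypothetical validator written in plain PyTorch (check_state_shapes is not a torch_sim function); it raises ValueError when a rule is violated and returns the number of systems otherwise. It cannot verify the column vector convention, since that is not visible from the shape alone.

```python
import torch


def check_state_shapes(positions, masses, atomic_numbers, cell, system_idx):
    """Hypothetical validator for the shape rules in the Notes; returns n_systems."""
    n_atoms = positions.shape[0]
    if positions.shape != (n_atoms, 3):
        raise ValueError("positions must have shape (n_atoms, 3)")
    for name, t in (("masses", masses), ("atomic_numbers", atomic_numbers),
                    ("system_idx", system_idx)):
        if t.shape != (n_atoms,):
            raise ValueError(f"{name} must have shape (n_atoms,)")
    # The distinct system indices must be exactly 0, 1, ..., n_systems - 1
    unique_ids = torch.unique(system_idx).tolist()
    n_systems = len(unique_ids)
    if unique_ids != list(range(n_systems)):
        raise ValueError("system_idx values must be consecutive integers starting from 0")
    if cell.shape != (n_systems, 3, 3):
        raise ValueError("cell must have shape (n_systems, 3, 3)")
    return n_systems


# Two systems: an H2 molecule (2 atoms) and a single O atom
positions = torch.zeros(3, 3)
masses = torch.tensor([1.008, 1.008, 15.999])
atomic_numbers = torch.tensor([1, 1, 8])
system_idx = torch.tensor([0, 0, 1])
cell = 10.0 * torch.eye(3).repeat(2, 1, 1)  # two cubic cells, edge length 10
n_systems = check_state_shapes(positions, masses, atomic_numbers, cell, system_idx)
```

Passing system_idx = torch.tensor([0, 0, 2]) instead would raise, because index 1 is missing.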

Examples

>>> state = initialize_state(
...     [ase_atoms_1, ase_atoms_2, ase_atoms_3], device, dtype
... )
>>> state.n_systems
3
>>> new_state = state[[0, 1]]
>>> new_state.n_systems
2
>>> cloned_state = state.clone()
property wrap_positions: Tensor

Atomic positions with shape (n_atoms, 3), wrapped according to periodic boundary conditions if pbc=True; otherwise the unwrapped positions are returned.
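One way to wrap positions under the column vector convention (cell[s] @ frac = cartesian) is to convert to fractional coordinates, take them modulo 1 along periodic axes, and convert back. This is a sketch in plain PyTorch, not torch_sim's implementation; the function name is hypothetical, and pbc is assumed here to be a boolean tensor with one entry per axis.

```python
import torch


def wrap_positions(positions, cell, pbc, system_idx):
    """Sketch of PBC wrapping with column-vector cells (cell[s] @ frac = cart).

    positions: (n_atoms, 3); cell: (n_systems, 3, 3) with columns a, b, c;
    pbc: bool tensor of shape (3,); system_idx: (n_atoms,).
    """
    # Per-atom inverse cell, so that frac = inv(cell) @ cart for each atom
    inv_cell = torch.linalg.inv(cell)[system_idx]
    frac = torch.einsum("nij,nj->ni", inv_cell, positions)
    # Wrap fractional coordinates into [0, 1) only along periodic axes
    frac = torch.where(pbc, torch.remainder(frac, 1.0), frac)
    # Convert back to Cartesian coordinates
    return torch.einsum("nij,nj->ni", cell[system_idx], frac)


cell = (2.0 * torch.eye(3, dtype=torch.float64)).unsqueeze(0)  # one cubic cell
positions = torch.tensor([[2.5, -0.5, 1.0]], dtype=torch.float64)
system_idx = torch.tensor([0])
pbc = torch.tensor([True, True, True])
wrapped = wrap_positions(positions, cell, pbc, system_idx)
# wrapped -> tensor([[0.5000, 1.5000, 1.0000]], dtype=torch.float64)
```

torch.remainder follows the sign of the divisor, so negative fractional coordinates such as -0.25 map to 0.75 rather than staying negative.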

property device: device

The device where the tensor data is located.

property dtype: dtype

The data type of the positions tensor.

property n_atoms: int

Total number of atoms across all systems.

property n_atoms_per_system: Tensor

Number of atoms in each system, with shape (n_systems,).

property n_systems: int

Number of systems in the batch.
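Because system_idx maps every atom to a consecutive system index, the three counting properties above can all be derived from it. The sketch below shows one way to do this in plain PyTorch; it illustrates the relationships and is not necessarily how torch_sim computes them.

```python
import torch

system_idx = torch.tensor([0, 0, 0, 1, 1, 2])

n_atoms = system_idx.numel()                     # 6
n_atoms_per_system = torch.bincount(system_idx)  # tensor([3, 2, 1])
# Valid only because indices are consecutive from 0 (see the Notes above)
n_systems = n_atoms_per_system.numel()           # 3
```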

property volume: Tensor

Volume of each system's unit cell.

property attributes: dict[str, Tensor]

Get all public attributes of the state.

property column_vector_cell: Tensor

Unit cell following the column vector convention.

property row_vector_cell: Tensor

Unit cell following the row vector convention.
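The two conventions differ only by a transpose of the last two dimensions, so converting between an ASE-style row-vector cell and the column-vector cell stored in SimState is a single operation. This plain-PyTorch sketch also shows that the determinant, and therefore the cell volume, does not depend on the convention.

```python
import torch

# ASE-style row-vector cell: each row is a lattice vector (a, b, c)
row_cell = torch.tensor([[[3.0, 0.0, 0.0],
                          [1.0, 4.0, 0.0],
                          [0.0, 0.0, 5.0]]], dtype=torch.float64)

# Column-vector convention: transpose the last two dims, batch dim untouched
col_cell = row_cell.transpose(-2, -1)

a_vec = col_cell[0, :, 0]  # tensor([3., 0., 0.], dtype=torch.float64)
b_vec = col_cell[0, :, 1]  # tensor([1., 4., 0.], dtype=torch.float64)

# det(M) == det(M^T), so the volume is the same in both conventions (60 here)
vol_row = torch.linalg.det(row_cell)
vol_col = torch.linalg.det(col_cell)
```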

set_constrained_positions(new_positions)[source]

Set the positions and apply constraints if they exist.

Parameters:

new_positions (Tensor) – New positions tensor with shape (n_atoms, 3)

Return type:

None
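As an illustration of what "apply constraints" can mean here, the sketch below implements a hypothetical fixed-atom constraint in plain PyTorch: flagged atoms keep their old positions and all others take the proposed ones. The function apply_fixed_atoms and the boolean fixed_mask are assumptions for this example; torch_sim's actual Constraint classes may work differently.

```python
import torch


def apply_fixed_atoms(old_positions, new_positions, fixed_mask):
    """Hypothetical fixed-atom constraint: atoms flagged in fixed_mask keep
    their old positions; all other atoms take the proposed new positions."""
    # Broadcast the (n_atoms,) mask over the xyz dimension
    return torch.where(fixed_mask.unsqueeze(-1), old_positions, new_positions)


old = torch.zeros(3, 3)
proposed = torch.ones(3, 3)
fixed = torch.tensor([True, False, False])
constrained = apply_fixed_atoms(old, proposed, fixed)
# First row stays at zero, the other two rows become ones
```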

set_constrained_cell(new_cell, scale_atoms=False)[source]

Set the cell, apply constraints, and optionally scale atomic positions.

Parameters:
  • new_cell (Tensor) – New cell tensor with shape (n_systems, 3, 3) in column vector convention

  • scale_atoms (bool) – Whether to scale atomic positions to preserve fractional coordinates. Defaults to False.

Return type:

None

property constraints: list[Constraint]

Get the constraints for the SimState.

Returns:

List of constraints applied to the system.

Return type:

list["Constraint"]

set_cell(cell, scale_atoms=False)[source]

Set the unit cell of the system, optionally scaling atomic positions. Torch version of ASE Atoms.set_cell.

Parameters:
  • cell (torch.Tensor) – New unit cell with shape (n_systems, 3, 3)

  • scale_atoms (bool, optional) – Whether to scale atomic positions according to the change in cell. Defaults to False.

Return type:

None
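With scale_atoms=True, positions are mapped so that their fractional coordinates are unchanged. Under the column vector convention this amounts to applying new_cell @ inv(old_cell) to each atom's position. The sketch below is a plain-PyTorch illustration with a hypothetical function name, and it assumes the cells are passed in column-vector form.

```python
import torch


def scale_positions_to_new_cell(positions, old_cell, new_cell, system_idx):
    """Sketch: keep fractional coordinates fixed when the cell changes."""
    # frac = inv(old_cell) @ cart, then cart_new = new_cell @ frac
    transform = new_cell @ torch.linalg.inv(old_cell)  # (n_systems, 3, 3)
    return torch.einsum("nij,nj->ni", transform[system_idx], positions)


old_cell = (2.0 * torch.eye(3, dtype=torch.float64)).unsqueeze(0)
new_cell = (4.0 * torch.eye(3, dtype=torch.float64)).unsqueeze(0)
positions = torch.tensor([[1.0, 0.5, 0.0]], dtype=torch.float64)
scaled = scale_positions_to_new_cell(positions, old_cell, new_cell, torch.tensor([0]))
# Doubling the cell doubles the positions: tensor([[2., 1., 0.]], dtype=torch.float64)
```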

get_number_of_degrees_of_freedom()[source]

Calculate degrees of freedom accounting for constraints.

Returns:

Number of degrees of freedom per system, with shape (n_systems,). Each system starts with 3 * n_atoms_per_system degrees of freedom, minus any degrees removed by constraints.

Return type:

Tensor
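The bookkeeping described above can be sketched with per-system counts. In this plain-PyTorch example the only constraint is a hypothetical one that fully fixes selected atoms, removing 3 degrees of freedom per fixed atom; real constraints may remove a different number.

```python
import torch

system_idx = torch.tensor([0, 0, 0, 1, 1])
# Hypothetical constraint: atoms 0 (system 0) and 3 (system 1) are fully fixed
fixed_mask = torch.tensor([True, False, False, True, False])

n_systems = int(system_idx.max()) + 1
n_atoms_per_system = torch.bincount(system_idx, minlength=n_systems)
n_fixed_per_system = torch.bincount(system_idx[fixed_mask], minlength=n_systems)

# 3 degrees of freedom per atom, minus 3 for every fully fixed atom
dof = 3 * n_atoms_per_system - 3 * n_fixed_per_system  # tensor([6, 3])
```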

clone()[source]

Create a deep copy of the SimState.

Creates a new SimState object with identical but independent tensors, allowing modification without affecting the original.

Returns:

A new SimState object with the same properties as the original

Return type:

SimState

classmethod from_state(state, **additional_attrs)[source]

Create a new state from an existing state with additional attributes.

This method copies attributes from the source state that are valid for the target state class, and adds any additional attributes needed. It supports upcasting (SimState -> MDState), downcasting (MDState -> SimState), and cross-casting (MDState -> OptimState) between state types.

Parameters:
  • state (SimState) – Source state to copy base attributes from

  • **additional_attrs (Any) – Additional attributes required by the target state class

Returns:

New state of the target class with copied and additional attributes

Return type:

Self

Example

>>> from torch_sim.integrators.md import MDState
>>> md_state = MDState.from_state(
...     sim_state,
...     energy=model_output["energy"],
...     forces=model_output["forces"],
...     momenta=torch.zeros_like(sim_state.positions),
... )
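The copying rule behind from_state (take every field the target class declares from the source, then add the extra attributes) can be illustrated with plain dataclasses. BaseState and MomentumState below are hypothetical stand-ins for SimState and MDState, and this sketch shares references rather than cloning tensors; it is not torch_sim's implementation.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class BaseState:  # hypothetical stand-in for SimState
    positions: list
    masses: list


@dataclass
class MomentumState(BaseState):  # hypothetical stand-in for MDState
    momenta: list


def from_state(cls, state, **additional_attrs: Any):
    """Copy every field that `cls` declares from `state`, then add extras."""
    copied = {
        f.name: getattr(state, f.name)
        for f in fields(cls)
        if hasattr(state, f.name) and f.name not in additional_attrs
    }
    return cls(**copied, **additional_attrs)


base = BaseState(positions=[[0.0, 0.0, 0.0]], masses=[1.0])
md = from_state(MomentumState, base, momenta=[[0.0, 0.0, 0.0]])  # upcast
back = from_state(BaseState, md)  # downcast: momenta is dropped
```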
to_atoms()[source]

Convert the SimState to a list of ASE Atoms objects.

Returns:

A list of ASE Atoms objects, one per system

Return type:

list[Atoms]

to_structures()[source]

Convert the SimState to a list of pymatgen Structure objects.

Returns:

A list of pymatgen Structure objects, one per system

Return type:

list[Structure]

to_phonopy()[source]

Convert the SimState to a list of PhonopyAtoms objects.

Returns:

A list of PhonopyAtoms objects, one per system

Return type:

list[PhonopyAtoms]

split()[source]

Split the SimState into a list of single-system SimStates.

Divides the current state into separate states, each containing a single system, preserving all properties appropriately for each system.

Returns:

A list of SimState objects, one per system

Return type:

list[SimState]
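If the atoms of each system are stored contiguously and in system order (an assumption of this sketch), splitting reduces to cutting the per-atom tensors by the per-system atom counts and the per-system tensors one row at a time. The plain-PyTorch example below shows this for positions and cell.

```python
import torch

system_idx = torch.tensor([0, 0, 1, 1, 1])
positions = torch.arange(15, dtype=torch.float32).reshape(5, 3)
cell = torch.eye(3).repeat(2, 1, 1)  # (n_systems, 3, 3)

# Per-atom tensors: split by the number of atoms in each system
counts = torch.bincount(system_idx).tolist()           # [2, 3]
per_system_positions = torch.split(positions, counts)  # shapes (2, 3) and (3, 3)

# Per-system tensors: one slice per system, keeping a batch dim of 1
per_system_cells = torch.split(cell, 1)                # each (1, 3, 3)
```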

pop(system_indices)[source]

Remove the systems with the specified indices from this state and return them.

This method modifies the original state object by removing the specified systems and returns the removed systems as separate SimState objects.

Parameters:

system_indices (int | list[int] | slice | torch.Tensor) – The system indices to pop

Returns:

Popped SimState objects, one per system index

Return type:

list[SimState]

Notes

This method modifies the original SimState in-place.
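Conceptually, popping partitions the atoms into the removed systems and the remaining ones, after which the remaining system indices must be renumbered to stay consecutive from 0. The function below is a hypothetical plain-PyTorch sketch of that partition; unlike pop, it returns new tensors instead of modifying anything in place.

```python
import torch


def pop_systems(system_idx, positions, to_pop):
    """Hypothetical sketch: partition atoms into (remaining, popped) by system."""
    pop_mask = torch.isin(system_idx, torch.tensor(to_pop))
    popped_positions = positions[pop_mask]
    remaining_positions = positions[~pop_mask]
    # Renumber the remaining systems so their indices stay consecutive from 0
    _, remaining_idx = torch.unique(system_idx[~pop_mask], return_inverse=True)
    return remaining_idx, remaining_positions, popped_positions


system_idx = torch.tensor([0, 0, 1, 2, 2])
positions = torch.arange(15, dtype=torch.float32).reshape(5, 3)
remaining_idx, remaining_pos, popped_pos = pop_systems(system_idx, positions, [1])
# remaining_idx -> tensor([0, 0, 1, 1]); popped_pos has shape (1, 3)
```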

to(device=None, dtype=None)[source]

Convert the SimState to a new device and/or data type.

Parameters:
  • device (torch.device, optional) – The target device. Defaults to current device.

  • dtype (torch.dtype, optional) – The target data type. Defaults to current dtype.

Returns:

A new SimState with tensors on the specified device and dtype

Return type:

SimState
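When converting a state, a sensible design is to move every tensor to the target device but cast only floating-point tensors to the target dtype, so that integer data such as atomic_numbers and system_idx stays integer. The sketch below illustrates that idea with a hypothetical helper operating on a dictionary of tensors; whether torch_sim's to method follows exactly this rule is an assumption.

```python
import torch


def to_device_dtype(tensors, device=None, dtype=None):
    """Hypothetical sketch: move all tensors to `device`, but cast only
    floating-point tensors to `dtype` so integer tensors stay integer."""
    out = {}
    for name, t in tensors.items():
        target_dtype = dtype if (dtype is not None and t.is_floating_point()) else t.dtype
        target_device = device if device is not None else t.device
        out[name] = t.to(device=target_device, dtype=target_dtype)
    return out


tensors = {
    "positions": torch.zeros(2, 3, dtype=torch.float32),
    "atomic_numbers": torch.tensor([1, 8]),
}
converted = to_device_dtype(tensors, dtype=torch.float64)
# positions becomes torch.float64; atomic_numbers stays torch.int64
```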