OrbModel
- class torch_sim.models.orb.OrbModel(model, *, conservative=None, compute_stress=True, compute_forces=True, system_config=None, max_num_neighbors=None, edge_method=None, half_supercell=None, device=None, dtype=torch.float32)
Bases: ModelInterface
Computes atomistic energies, forces and stresses using an ORB model.
This class wraps an ORB model to compute energies, forces, and stresses for atomistic systems. It handles model initialization, configuration, and provides a forward pass that accepts a SimState object and returns model predictions.
- Variables:
model (Union[DirectForcefieldRegressor, ConservativeForcefieldRegressor]) – The ORB model
system_config (SystemConfig) – Configuration for the atomic system
conservative (bool) – Whether forces and stresses are computed conservatively
implemented_properties (list) – Properties the model can compute
_dtype (torch.dtype) – Data type used for computation
_device (torch.device) – Device on which computation is performed
_edge_method (EdgeCreationMethod) – Method for creating edges in the graph
_max_num_neighbors (int) – Maximum number of neighbors for each atom
_half_supercell (bool) – Whether to use the half-supercell optimization
_memory_scales_with (str) – What the memory usage scales with
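The edge-related attributes (_edge_method, _max_num_neighbors) control how the atomic graph is built. As a rough illustration only, the sketch below is a brute-force neighbor search in plain Python that keeps at most max_num_neighbors of the nearest atoms within a distance cutoff. The function name build_edges and its cutoff argument are hypothetical, periodic images are ignored, and this is not ORB's edge-creation code.

```python
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def build_edges(
    positions: Sequence[Vec3],
    cutoff: float,
    max_num_neighbors: int,
) -> List[Tuple[int, int]]:
    """Hypothetical helper: return (center, neighbor) index pairs.

    Brute-force O(N^2) search for a non-periodic system. For each atom, only
    the `max_num_neighbors` closest atoms within `cutoff` are kept.
    """
    edges: List[Tuple[int, int]] = []
    for i, p_i in enumerate(positions):
        # Gather (distance, index) for every other atom inside the cutoff sphere.
        candidates = [
            (math.dist(p_i, p_j), j)
            for j, p_j in enumerate(positions)
            if j != i and math.dist(p_i, p_j) <= cutoff
        ]
        # Nearest first; equal distances fall back to the lower atom index.
        candidates.sort()
        edges.extend((i, j) for _, j in candidates[:max_num_neighbors])
    return edges


# Four atoms on a line, 1.0 apart (units are arbitrary here).
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
capped = build_edges(positions, cutoff=2.5, max_num_neighbors=1)
uncapped = build_edges(positions, cutoff=2.5, max_num_neighbors=10)
print(capped)         # [(0, 1), (1, 0), (2, 1), (3, 2)]
print(len(uncapped))  # 10
```

Capping the neighbor count bounds the number of edges per atom, which keeps the graph size predictable in dense structures; the trade-off is that some atoms inside the cutoff are dropped.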
- Parameters:
model (DirectForcefieldRegressor | ConservativeForcefieldRegressor | str | Path)
conservative (bool | None)
compute_stress (bool)
compute_forces (bool)
system_config (SystemConfig | None)
max_num_neighbors (int | None)
edge_method (EdgeCreationMethod | None)
half_supercell (bool | None)
dtype (dtype)
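In general terms, a conservative force field derives forces from the energy rather than predicting them separately: the force on atom i is the negative gradient of the energy, F_i = -∂E/∂r_i. The sketch below shows that relationship on a toy Lennard-Jones cluster using central finite differences. It illustrates the concept only and is not ORB code; a neural potential would normally obtain the gradient by automatic differentiation, and the names lj_energy and conservative_forces are hypothetical.

```python
import math
from typing import Callable, List, Sequence

Positions = Sequence[Sequence[float]]


def lj_energy(positions: Positions, epsilon: float = 1.0, sigma: float = 1.0) -> float:
    """Toy total energy: Lennard-Jones pair sum over a non-periodic cluster."""
    energy = 0.0
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            sr6 = (sigma / math.dist(positions[i], positions[j])) ** 6
            energy += 4.0 * epsilon * (sr6 * sr6 - sr6)
    return energy


def conservative_forces(
    energy_fn: Callable[[Positions], float], positions: Positions, h: float = 1e-5
) -> List[List[float]]:
    """Estimate F = -dE/dr for every atom and Cartesian axis by central differences."""
    forces = []
    for i in range(len(positions)):
        f_i = []
        for axis in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][axis] += h
            minus[i][axis] -= h
            f_i.append(-(energy_fn(plus) - energy_fn(minus)) / (2.0 * h))
        forces.append(f_i)
    return forces


# Dimer at r = 1.0 sigma: repulsive, analytic |F| = 24 * epsilon / sigma.
dimer = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
dimer_forces = conservative_forces(lj_energy, dimer)

# For any translation-invariant energy, conservative forces sum to zero.
trimer = [(0.0, 0.0, 0.0), (1.1, 0.0, 0.0), (0.5, 0.9, 0.0)]
net_force = [sum(f[axis] for f in conservative_forces(lj_energy, trimer)) for axis in range(3)]
```

The zero net force on the trimer is a property of gradient-derived forces; a model that predicts forces directly is not guaranteed to satisfy it exactly.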
Examples
>>> model = OrbModel(model=loaded_orb_model, compute_stress=True)
>>> results = model(state)
- forward(state)
Perform forward pass to compute energies, forces, and other properties.
Takes a simulation state and computes the properties implemented by the model, such as energy, forces, and stresses.
- Parameters:
state (SimState | StateDict) – State object containing positions, cells, atomic numbers, and other system information. If a dictionary is provided, it is converted to a SimState.
- Returns:
Model predictions, which may include:
energy (torch.Tensor): Energy with shape [batch_size]
forces (torch.Tensor): Forces with shape [n_atoms, 3]
stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], if compute_stress is True
- Return type:
dict[str, torch.Tensor]
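The returned tensors use two different leading dimensions: energy and stress are indexed by system ([batch_size]), while forces are indexed by atom across the whole batch ([n_atoms, 3]). The minimal stdlib sketch below illustrates that layout, assuming each atom carries the index of the system it belongs to. The names system_index, sum_per_system, and rows_for_system are hypothetical, the numbers are made up, and treating the energy as a sum of per-atom terms is only for illustration, not a statement about how ORB computes it.

```python
from typing import List, Sequence


def sum_per_system(
    per_atom: Sequence[float], system_index: Sequence[int], n_systems: int
) -> List[float]:
    """Reduce a per-atom quantity to one value per system (scatter-add)."""
    totals = [0.0] * n_systems
    for value, s in zip(per_atom, system_index):
        totals[s] += value
    return totals


def rows_for_system(
    per_atom_rows: Sequence[Sequence[float]], system_index: Sequence[int], system: int
) -> List[Sequence[float]]:
    """Select the per-atom rows (e.g. force vectors) that belong to one system."""
    return [row for row, s in zip(per_atom_rows, system_index) if s == system]


# Two systems flattened into one batch: 3 atoms, then 2 atoms.
system_index = [0, 0, 0, 1, 1]
atom_energies = [-1.0, -2.0, -3.0, -0.5, -0.5]  # made-up per-atom contributions
forces = [[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0], [0.0, 0.0, 0.0],
          [0.0, 0.2, 0.0], [0.0, -0.2, 0.0]]    # shape [n_atoms, 3]

energy = sum_per_system(atom_energies, system_index, n_systems=2)  # shape [batch_size]
print(energy)  # [-6.0, -1.0]
print(rows_for_system(forces, system_index, 1))  # [[0.0, 0.2, 0.0], [0.0, -0.2, 0.0]]
```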
Notes
The state is automatically transferred to the model’s device if needed. All output tensors are detached from the computation graph.