FairChemModel

class torch_sim.models.fairchem.FairChemModel(model, neighbor_list_fn=None, *, model_name=None, model_cache_dir=None, cpu=False, dtype=None, compute_stress=False, task_name=None)

Bases: ModelInterface

FairChem model wrapper for computing atomistic properties.

Wraps FairChem models to compute energies, forces, and stresses. It can be initialized from either a model checkpoint path (model) or a pretrained model name (model_name).

Uses the predictor API of fairchem-core 2.2.0 or later for batch inference.
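Because the model can come from two different sources, the sketch below shows one way the choice between a checkpoint path and a pretrained name might be resolved. The helper resolve_model_source is hypothetical and not part of torch_sim or fairchem, and the rule that exactly one source must be given is an assumption; this page does not say how FairChemModel handles receiving both or neither.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_model_source(
    model: Optional[Union[str, Path]] = None,
    model_name: Optional[str] = None,
) -> tuple[str, str]:
    """Hypothetical helper (not a torch_sim/fairchem API).

    Returns ("checkpoint", path) for a local checkpoint or
    ("pretrained", name) for a pretrained model name.
    """
    # Assumption: exactly one of the two sources must be provided.
    if (model is None) == (model_name is None):
        raise ValueError("pass exactly one of `model` or `model_name`")
    if model is not None:
        # Normalize str or Path input to a string path.
        return ("checkpoint", str(Path(model)))
    return ("pretrained", model_name)
```

For example, resolve_model_source(model="path/to/checkpoint.pt") selects the checkpoint branch, while resolve_model_source(model_name=...) selects the pretrained branch.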

Variables:
  • predictor – The FairChem predictor for batch inference

  • task_name (UMATask) – Task type for the model

  • _device (torch.device) – Device where computation is performed

  • _dtype (torch.dtype) – Data type used for computation

  • _compute_stress (bool) – Whether to compute stress tensor

  • implemented_properties (list) – Properties the model can compute

Parameters:
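  • model (str | Path | None) – Path to a model checkpoint

  • neighbor_list_fn (Callable | None)

  • model_name (str | None) – Name of a pretrained model

  • model_cache_dir (str | Path | None)

  • cpu (bool)

  • dtype (dtype | None) – Data type used for computation

  • compute_stress (bool) – Whether to compute the stress tensor

  • task_name (UMATask | str | None) – Task type for the model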

Examples

>>> model = FairChemModel(model="path/to/checkpoint.pt", compute_stress=True)
>>> results = model(state)  # state: a SimState or StateDict

property dtype: dtype

Return the data type used by the model.

property device: device

Return the device where the model is located.

forward(state)

Compute energies, forces, and other properties.

Parameters:

state (SimState | StateDict) – State object containing positions, cells, atomic numbers, and other system information. If a dictionary is provided, it will be converted to a SimState.
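The dictionary-to-state conversion described above follows a common "accept either form" pattern. The sketch below illustrates it with ToyState, a hypothetical stand-in dataclass whose three fields are assumed for illustration; it is not torch_sim's SimState, which carries more information.

```python
from dataclasses import dataclass
from typing import Any, Union


@dataclass
class ToyState:
    """Hypothetical stand-in for a SimState (fields assumed)."""
    positions: list
    cell: list
    atomic_numbers: list


def as_state(state: Union[ToyState, dict[str, Any]]) -> ToyState:
    # Dictionaries are converted into a state object; missing or
    # unexpected keys raise TypeError from the dataclass constructor.
    if isinstance(state, dict):
        return ToyState(**state)
    # State objects pass through unchanged.
    return state
```

Normalizing at the top of forward lets the rest of the method assume a single input type.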

Returns:

Dictionary of model predictions, which may include:
  • energy (torch.Tensor): Energy with shape [batch_size]

  • forces (torch.Tensor): Forces with shape [n_atoms, 3]

  • stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3] (computed when compute_stress=True)

Return type:

dict
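Energy has shape [batch_size] while forces have shape [n_atoms, 3], which suggests that forces for all systems in a batch are concatenated along the first axis. The sketch below splits such a flat array into one block per system. It assumes that atoms are stored contiguously, system by system, in the same order as the energy entries. The helper split_per_system is hypothetical, and plain lists stand in for tensors so the example runs without torch.

```python
from itertools import accumulate
from typing import Sequence


def split_per_system(
    forces: Sequence[Sequence[float]], n_atoms_per_system: Sequence[int]
) -> list[list[Sequence[float]]]:
    """Split a flat [n_atoms, 3] force array into per-system blocks.

    Hypothetical helper; assumes atoms of each system are contiguous and
    ordered like the entries of the [batch_size] energy output.
    """
    if sum(n_atoms_per_system) != len(forces):
        raise ValueError("atom counts do not match the number of force rows")
    # Cumulative counts give the end index of each system's slice.
    ends = list(accumulate(n_atoms_per_system))
    starts = [0] + ends[:-1]
    return [list(forces[s:e]) for s, e in zip(starts, ends)]
```

With real tensors, torch.split(forces, n_atoms_per_system) performs the same slicing along the first dimension.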