FairChemModel
- class torch_sim.models.fairchem.FairChemModel(model, neighbor_list_fn=None, *, model_name=None, model_cache_dir=None, cpu=False, dtype=None, compute_stress=False, task_name=None)
Bases: ModelInterface
FairChem model wrapper for computing atomistic properties.
Wraps FairChem models to compute energies, forces, and stresses. Can be initialized with a model checkpoint path or pretrained model name.
Uses the fairchem-core-2.2.0+ predictor API for batch inference.
- Variables:
predictor – The FairChem predictor for batch inference
task_name (UMATask) – Task type for the model
_device (torch.device) – Device where computation is performed
_dtype (torch.dtype) – Data type used for computation
_compute_stress (bool) – Whether to compute the stress tensor
implemented_properties (list) – Outputs the model can compute
- Parameters:
Examples
>>> model = FairChemModel(model="path/to/checkpoint.pt", compute_stress=True)
>>> results = model(state)
- forward(state)
Compute energies, forces, and other properties.
- Parameters:
state (SimState | StateDict) – State object containing positions, cells, atomic numbers, and other system information. If a dictionary is provided, it is converted to a SimState.
- Returns:
- Dictionary of model predictions, which may include:
energy (torch.Tensor): Energy with shape [batch_size]
forces (torch.Tensor): Forces with shape [n_atoms, 3]
stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3]
- Return type:
dict
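The shapes listed for forward() can be checked with a small helper. The following is a minimal sketch, not part of torch_sim: `expected_output_shapes` and `check_output_shapes` are hypothetical names introduced here for illustration. It assumes that `n_atoms` is the total atom count summed over every system in the batch, that `energy` and `forces` are always returned, and that `stress` is present only when `compute_stress=True`.

```python
def expected_output_shapes(atoms_per_system, compute_stress=False):
    """Return the shapes the forward() docs list for a batched result.

    atoms_per_system: atom counts, one entry per system in the batch.
    Assumption: n_atoms is the total over all systems in the batch.
    """
    batch_size = len(atoms_per_system)
    n_atoms = sum(atoms_per_system)
    shapes = {
        "energy": (batch_size,),   # one energy per system
        "forces": (n_atoms, 3),    # one 3-vector per atom
    }
    if compute_stress:
        # Assumption: stress is only returned when it was requested.
        shapes["stress"] = (batch_size, 3, 3)
    return shapes


def check_output_shapes(results, atoms_per_system, compute_stress=False):
    """Raise ValueError if a listed key is missing or has another shape.

    results: mapping from property name to any object with a .shape
    attribute (for example a torch.Tensor).
    """
    expected = expected_output_shapes(atoms_per_system, compute_stress)
    for key, shape in expected.items():
        if key not in results:
            raise ValueError(f"missing key: {key}")
        actual = tuple(results[key].shape)
        if actual != shape:
            raise ValueError(f"{key}: expected {shape}, got {actual}")


# Two systems with 2 and 3 atoms, stress requested.
shapes = expected_output_shapes([2, 3], compute_stress=True)
```

Under these assumptions, a call such as `check_output_shapes(results, [2, 3], compute_stress=True)` on the dictionary returned by `model(state)` would catch a mismatch between the batch layout and the returned tensors before they are used downstream.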