FairChemV1Model
- class torch_sim.models.fairchem_legacy.FairChemV1Model(model, neighbor_list_fn=None, *, config_yml=None, model_name=None, local_cache=None, trainer=None, cpu=False, seed=None, dtype=None, compute_stress=False, pbc=True, disable_amp=True)
Bases: ModelInterface
Computes atomistic energies, forces and stresses using a FairChem model.
This class wraps a FairChem model to compute energies, forces, and stresses for atomistic systems. It handles model initialization, checkpoint loading, and provides a forward pass that accepts a SimState object and returns model predictions.
The model can be initialized either from a configuration file or from a pretrained checkpoint, and it works with the model architectures and configurations available in FairChem.
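The wrapping described above can be illustrated with a toy, torch-free sketch of the same contract: a forward pass that accepts either a state object or a plain dictionary, calls an underlying model, and returns a dictionary with energy and forces, adding stress only when stress computation is enabled. `ToyState`, `ToyModelWrapper`, and `dummy_backend` are hypothetical names invented for this illustration; they are not part of torch_sim or FairChem, and the real class additionally handles checkpoint loading, neighbor lists, device placement, and tensors.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyState:
    """Hypothetical stand-in for SimState: per-atom lists plus a per-atom system index."""

    positions: list
    atomic_numbers: list
    system_idx: list


class ToyModelWrapper:
    """Adapts a backend callable to a fixed forward() contract (a sketch, not FairChemV1Model)."""

    def __init__(self, backend: Callable[[ToyState], dict], compute_stress: bool = False):
        self.backend = backend
        self.compute_stress = compute_stress

    def forward(self, state):
        # Accept either a state object or a plain dict, mirroring "SimState | StateDict".
        if isinstance(state, dict):
            state = ToyState(**state)
        raw = self.backend(state)
        results = {"energy": raw["energy"], "forces": raw["forces"]}
        # Expose stress only when the wrapper was configured to compute it.
        if self.compute_stress:
            results["stress"] = raw["stress"]
        return results

    # Calling the wrapper like a function runs forward(), as in `model(state)`.
    __call__ = forward


def dummy_backend(state):
    """Hypothetical backend returning zeros with the documented output shapes."""
    n_systems = max(state.system_idx) + 1
    return {
        "energy": [0.0] * n_systems,
        "forces": [[0.0, 0.0, 0.0] for _ in state.positions],
        "stress": [[[0.0] * 3 for _ in range(3)] for _ in range(n_systems)],
    }


# Two systems in one batch: atoms 0-1 belong to system 0, atom 2 to system 1.
sample = {
    "positions": [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0], [0.0, 0.0, 0.0]],
    "atomic_numbers": [1, 1, 18],
    "system_idx": [0, 0, 1],
}
out = ToyModelWrapper(dummy_backend, compute_stress=True)(sample)
print(sorted(out))  # ['energy', 'forces', 'stress']
```

Keeping the stress gate inside the wrapper means callers always see the same keys for a given configuration, regardless of what the backend happens to return.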
- Variables:
neighbor_list_fn (Callable | None) – Function to compute neighbor lists
config (dict) – Complete model configuration dictionary
trainer – FairChem trainer object that contains the model
data_object (Batch) – Data object containing system information
implemented_properties (list) – Outputs the model can compute
pbc (bool) – Whether periodic boundary conditions are used
_dtype (torch.dtype) – Data type used for computation
_compute_stress (bool) – Whether to compute the stress tensor
_compute_forces (bool) – Whether to compute forces
_device (torch.device) – Device where computation is performed
_reshaped_props (dict) – Properties that need reshaping after computation
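`_reshaped_props` is only described as the properties that need reshaping after computation. A minimal sketch of that idea follows, assuming the backend returns a property such as stress flattened to 9 components per system and that the target shape is `(-1, 3, 3)`, which matches the `[batch_size, 3, 3]` stress shape documented for `forward`. `RESHAPED_PROPS`, `reshape`, and `apply_reshaped_props` are hypothetical names, and plain lists stand in for tensors.

```python
import math

# Hypothetical mapping in the spirit of _reshaped_props: property name -> target shape.
RESHAPED_PROPS = {"stress": (-1, 3, 3)}


def _nest(flat, dims):
    """Split a flat list into nested lists with the given (fully resolved) dimensions."""
    if len(dims) == 1:
        return list(flat)
    if dims[0] == 0:
        return []
    step = len(flat) // dims[0]
    return [_nest(flat[i * step:(i + 1) * step], dims[1:]) for i in range(dims[0])]


def reshape(flat, shape):
    """Reshape a flat list; a single -1 entry is inferred from the list length."""
    known = math.prod(d for d in shape if d != -1)
    dims = tuple(len(flat) // known if d == -1 else d for d in shape)
    if math.prod(dims) != len(flat):
        raise ValueError(f"cannot reshape {len(flat)} values into {shape}")
    return _nest(flat, dims)


def apply_reshaped_props(raw, reshaped_props=RESHAPED_PROPS):
    """Return a copy of raw where each listed property is reshaped; others pass through."""
    out = dict(raw)
    for name, shape in reshaped_props.items():
        if name in out:
            out[name] = reshape(out[name], shape)
    return out


# Two systems, stress flattened to 9 components each (18 values total).
raw = {"energy": [-1.0, -2.0], "stress": [float(v) for v in range(18)]}
result = apply_reshaped_props(raw)
print(result["stress"][1][0])  # [9.0, 10.0, 11.0]
```

Using `-1` for the leading axis lets the same rule serve any batch size, since the number of systems is inferred from the length of the flat output.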
- Parameters:
Examples
>>> model = FairChemV1Model(model="path/to/checkpoint.pt", compute_stress=True)
>>> results = model(state)
- load_checkpoint(checkpoint_path, checkpoint=None)
Load an existing trained model checkpoint.
Loads model parameters from a checkpoint file or dictionary, setting the model to inference mode.
- Parameters:
checkpoint_path (str) – Path to the trained model checkpoint file
checkpoint (dict | None) – A pretrained checkpoint dictionary. If provided, this dictionary is used instead of loading from checkpoint_path.
- Return type:
None
Notes
If loading fails, a message is printed but no exception is raised.
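The two documented behaviours, that a checkpoint dictionary takes precedence over the path and that failures print a message rather than raise, can be mirrored in a small sketch. `load_checkpoint_sketch`, `read_file`, `apply`, and `missing_file` are hypothetical stand-ins, and the printed message text is invented; the real method loads a FairChem checkpoint into its trainer.

```python
def load_checkpoint_sketch(checkpoint_path, checkpoint=None, *, read_file, apply):
    """Hypothetical mirror of the documented semantics; always returns None.

    read_file: stand-in callable that reads a checkpoint dict from a path.
    apply: stand-in callable that installs a checkpoint dict into a model.
    """
    try:
        # A dictionary passed in directly takes precedence over the path.
        if checkpoint is None:
            checkpoint = read_file(checkpoint_path)
        apply(checkpoint)
    except Exception as exc:
        # Failures are reported with a message instead of being raised.
        print(f"Unable to load checkpoint {checkpoint_path!r}: {exc}")


loaded = []


def missing_file(path):
    """Hypothetical reader that always fails, as if the file did not exist."""
    raise FileNotFoundError(path)


# Uses the dict and never touches the (unreadable) path.
load_checkpoint_sketch("ckpt.pt", {"step": 10}, read_file=missing_file, apply=loaded.append)
# Falls back to the path, fails, prints a message, and still returns None.
status = load_checkpoint_sketch("missing.pt", read_file=missing_file, apply=loaded.append)
print(status, loaded)  # None [{'step': 10}]
```

Because the return value is None whether or not loading succeeded, a caller that needs certainty has to inspect the model state afterwards rather than rely on an exception.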
- forward(state)
Perform forward pass to compute energies, forces, and other properties.
Takes a simulation state and computes the properties implemented by the model, such as energy, forces, and stresses.
- Parameters:
state (SimState | StateDict) – State object containing positions, cells, atomic numbers, and other system information. If a dictionary is provided, it will be converted to a SimState.
- Returns:
Dictionary of model predictions, which may include:
- energy (torch.Tensor): Energy with shape [batch_size]
- forces (torch.Tensor): Forces with shape [n_atoms, 3], where n_atoms counts the atoms of all systems in the batch
- stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], if compute_stress is True
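Energy is reported once per system while forces are reported once per atom, so relating the two requires knowing which system each atom belongs to. The sketch below assumes that mapping is available as a per-atom integer list (called `system_idx` here) and that atoms from all systems are concatenated in one `[n_atoms, 3]` array; `split_per_system` is a hypothetical helper, and plain lists with illustrative numbers stand in for tensors.

```python
import math


def split_per_system(forces, system_idx):
    """Group rows of a concatenated [n_atoms, 3] force list by the system each atom belongs to."""
    n_systems = max(system_idx) + 1
    groups = [[] for _ in range(n_systems)]
    for row, sys_id in zip(forces, system_idx):
        groups[sys_id].append(row)
    return groups


# Illustrative numbers only: 2 systems, 5 atoms in total.
energy = [-3.2, -7.9]  # one value per system, shape [batch_size]
forces = [  # one row per atom across the whole batch, shape [n_atoms, 3]
    [1.0, 0.0, 0.0],
    [-1.0, 0.0, 0.0],
    [0.0, 3.0, 4.0],
    [0.0, 0.0, 0.0],
    [0.0, -3.0, -4.0],
]
system_idx = [0, 0, 1, 1, 1]  # assumed per-atom system index

per_system = split_per_system(forces, system_idx)
# Largest force magnitude in each system, aligned with energy[i].
max_force = [max(math.hypot(*row) for row in rows) for rows in per_system]
for e, fmax in zip(energy, max_force):
    print(f"energy={e} max|F|={fmax}")
```

A per-system maximum force like this is a common convergence criterion in batched relaxations, which is one reason the per-atom and per-system outputs need to be aligned.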
- Return type:
Notes
The state is automatically transferred to the model’s device if needed. All output tensors are detached from the computation graph.