"""Wrapper for ORB models in TorchSim.This module re-exports the ORB package's torch-sim integration for convenientimporting. The actual implementation is maintained in the orb-models package.References: - ORB Models Package: https://github.com/orbital-materials/orb-models"""importtracebackimportwarningsfromtypingimportAnyimporttorchtry:fromorb_models.forcefield.inference.d3_modelimportD3SumModelfromorb_models.forcefield.inference.orb_torchsimimportOrbTorchSimModelimporttorch_simastsfromtorch_sim.elasticimportvoigt_6_to_full_3x3_stress# Re-export with backward-compatible nameclassOrbModel(OrbTorchSimModel):"""ORB model wrapper for torch-sim."""@staticmethoddef_normalize_charge_spin(state:"ts.SimState")->"ts.SimState":"""Provide ORB's optional charge/spin inputs when they are missing."""charge=getattr(state,"charge",None)spin=getattr(state,"spin",None)ifchargeisnotNoneandspinisnotNone:returnstatezeros=torch.zeros(state.n_systems,device=state.device,dtype=state.dtype)returnts.SimState.from_state(state,charge=chargeifchargeisnotNoneelsezeros,spin=spinifspinisnotNoneelsezeros,)def_get_results(self,out:dict[str,torch.Tensor])->dict[str,torch.Tensor]:"""Parses the results into a final output dictionary."""results={}model=(self.model.xc_modelifisinstance(self.model,D3SumModel)elseself.model)heads=getattr(model,"heads",{})no_direct_energy_head="energy"notinheadsno_direct_force_head="forces"notinheadsno_direct_stress_head="stress"notinheadsforpropinself.implemented_properties:ifprop=="free_energy"andno_direct_energy_head:continueifprop=="forces"andno_direct_force_head:continueifprop=="stress"andno_direct_stress_head:continue_prop="energy"ifprop=="free_energy"elseprop# Do not squeeze the output tensors in the case of single atom cells# TODO: remove after https://github.com/orbital-materials/orb-models/pull/158results[prop]=torch.atleast_1d(out[_prop])# Rename certain keys for the conservative modelifself.conservative:ifmodel.forces_nameinresults:results["direct_forces"]=results[model.forces_name]results["forces"]=results[model.grad_forces_name]ifmodel.has_stress:ifmodel.stress_nameinresults:results["direct_stress"]=results[model.stress_name]results["stress"]=results[model.grad_stress_name]# Ensure stress has shape [-1, 3, 3]if"stress"inresultsandresults["stress"].shape[-1]==6:results["stress"]=voigt_6_to_full_3x3_stress(torch.atleast_2d(results["stress"]))returnresultsdefforward(self,*args:Any,**kwargs:Any)->dict[str,Any]:"""Run forward pass, detaching outputs unless retain_graph is True."""ifargsandisinstance(args[0],ts.SimState):args=(self._normalize_charge_spin(args[0]),*args[1:])elifisinstance(kwargs.get("state"),ts.SimState):kwargs["state"]=self._normalize_charge_spin(kwargs["state"])output=super().forward(*args,**kwargs)return{# detach tensors as energy is not detached by defaultk:v.detach()ifhasattr(v,"detach")elsevfork,vinoutput.items()}exceptImportErrorasexc:warnings.warn(f"Orb import failed: {traceback.format_exc()}",stacklevel=2)fromtorch_sim.models.interfaceimportModelInterface
[docs]classOrbModel(ModelInterface):"""ORB model wrapper for torch-sim. NOTE: This class is a placeholder when orb-models is not installed. It raises an ImportError if accessed. """# Capture the original ImportError in a closure-safe default so the# fallback always re-raises the real import failure, even when callers# pass positional/keyword args (e.g. ``OrbModel(orb_ff, adapter, ...)``)# that would otherwise shadow an ``err`` parameter.def__init__(self,*_args:Any,_err:ImportError=exc,**_kwargs:Any)->None:"""Dummy init that re-raises the original import failure."""raise_err