SumModel
- class torch_sim.models.interface.SumModel(*models)
Bases: ModelInterface

Additive composition of multiple ModelInterface models.

Calls each child model's forward(). Canonical mechanical outputs (energy, forces, stress) are combined additively. Non-canonical outputs are treated as full updated values, so later models replace earlier ones. This is the standard way to layer a dispersion correction (e.g. DFT-D3), an Ewald electrostatic term, or a local pair potential on top of a primary machine-learning potential.
- Parameters:
models (ModelInterface) – Two or more ModelInterface instances that share the same device and dtype.
- Raises:
ValueError – If fewer than two models are given, or if device or dtype does not match across all models.
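The combination and validation rules described above can be illustrated with a small, framework-agnostic sketch. It is not the torch_sim implementation: `ToyModel`, `ToySumModel`, and `CANONICAL_KEYS` are hypothetical names, child outputs are plain Python numbers and lists instead of tensors, and device and dtype are modelled as strings. The merge uses only `+` and assignment, so with torch tensors as canonical values the `+` would be element-wise.

```python
from typing import Any, Callable, Dict

# Hypothetical: the canonical mechanical keys named in the docstring.
CANONICAL_KEYS = {"energy", "forces", "stress"}


class ToyModel:
    """Hypothetical stand-in for a ModelInterface child model."""

    def __init__(
        self,
        fn: Callable[..., Dict[str, Any]],
        device: str = "cpu",
        dtype: str = "float64",
    ) -> None:
        self.fn = fn
        self.device = device  # modelled as a plain string in this sketch
        self.dtype = dtype

    def __call__(self, state: Any, **kwargs: Any) -> Dict[str, Any]:
        return self.fn(state, **kwargs)


class ToySumModel:
    """Hypothetical sketch of the SumModel combination rule (not torch_sim code)."""

    def __init__(self, *models: ToyModel) -> None:
        if len(models) < 2:
            raise ValueError("need at least two models")
        if len({m.device for m in models}) > 1 or len({m.dtype for m in models}) > 1:
            raise ValueError("all models must share the same device and dtype")
        self.models = models

    def __call__(self, state: Any, **kwargs: Any) -> Dict[str, Any]:
        combined: Dict[str, Any] = {}
        for model in self.models:
            # Every child sees the same state and keyword arguments.
            for key, value in model(state, **kwargs).items():
                if key in CANONICAL_KEYS and key in combined:
                    combined[key] = combined[key] + value  # canonical: add
                else:
                    combined[key] = value  # first occurrence, or later model replaces
        return combined


# Usage: a "primary" model plus a correction term (hypothetical outputs).
primary = ToyModel(lambda s: {"energy": -10.0, "charges": [0.1, -0.1]})
correction = ToyModel(lambda s: {"energy": -0.5, "charges": [0.2, -0.2]})
out = ToySumModel(primary, correction)(state=None)

# A dtype mismatch should be rejected at construction time.
try:
    ToySumModel(primary, ToyModel(lambda s: {}, dtype="float32"))
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
```

Here `energy` is summed to -10.5, while `charges` comes only from the correction model, mirroring the "later models replace earlier ones" rule for non-canonical outputs.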
Examples
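```py
from torch_sim.models.interface import SumModel
# mace_model and d3_model: existing ModelInterface instances on the same device and dtype
sum_model = SumModel(mace_model, d3_model)
output = sum_model(sim_state)
```
- property memory_scales_with: MemoryScaling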
Most conservative memory scaling among all child models.
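The docstring does not spell out how "most conservative" is decided. One plausible reading, sketched below, is to rank the possible scaling modes and keep the highest-ranked one across the children. The values `"n_atoms"` and `"n_atoms_x_density"`, their ordering, and the names `RANK` and `most_conservative` are assumptions for illustration, not confirmed torch_sim API.

```python
from typing import Iterable, Literal

# Assumption: the two scaling modes, mapped from least to most conservative.
MemoryScaling = Literal["n_atoms", "n_atoms_x_density"]
RANK = {"n_atoms": 0, "n_atoms_x_density": 1}


def most_conservative(scalings: Iterable[MemoryScaling]) -> MemoryScaling:
    """Return the child scaling assumed to predict the steepest memory growth."""
    return max(scalings, key=RANK.__getitem__)


# Usage: three child models with their (assumed) scaling modes.
children = ["n_atoms", "n_atoms_x_density", "n_atoms"]
combined_scaling = most_conservative(children)
```

Taking the maximum means a single child with steeper assumed growth determines the estimate for the whole composite, which errs toward overestimating memory rather than underestimating it.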
- forward(state, **kwargs)
Combine the outputs of all child models.
Each child model is called with the same state and **kwargs. Canonical mechanical outputs that appear in multiple children are summed element-wise. Non-canonical outputs are replaced by later models, so they behave like full state updates rather than deltas.
- Parameters:
state (SimState) – Simulation state (see ModelInterface).
**kwargs – Forwarded to every child model.
- Returns:
Combined output dictionary: canonical outputs are summed, and each non-canonical output is taken from the last child model that returns it.
- Return type: