SumModel

class torch_sim.models.interface.SumModel(*models)

Bases: ModelInterface

Additive composition of multiple ModelInterface models.

Calls each child model’s forward(). Canonical mechanical outputs (energy, forces, stress) are combined additively, while non-canonical outputs are treated as full updated values and later models replace earlier ones. This is the standard way to layer a dispersion correction (e.g. DFT-D3), an Ewald electrostatic term, or a local pair potential on top of a primary machine-learning potential.
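The combination rule can be illustrated with a minimal sketch. It assumes the canonical outputs are stored under the keys `"energy"`, `"forces"` and `"stress"`, and it uses plain floats and lists as stand-ins for tensors. `merge_outputs` and the `"charges"` key are hypothetical names used only for illustration, not part of the torch_sim API. Because the rule uses `+`, the same logic would sum torch tensors element-wise.

```python
# Assumed key names for the canonical (additive) outputs.
CANONICAL_KEYS = {"energy", "forces", "stress"}


def merge_outputs(outputs):
    """Hypothetical helper: combine a sequence of per-model output dicts.

    Canonical keys are added together; any other key keeps the value
    from the last dict that contains it.
    """
    combined = {}
    for out in outputs:
        for key, value in out.items():
            if key in CANONICAL_KEYS and key in combined:
                # Additive contribution, e.g. a dispersion energy on top of an ML energy.
                combined[key] = combined[key] + value
            else:
                # First canonical value, or any non-canonical value:
                # store it as-is, overwriting earlier non-canonical values.
                combined[key] = value
    return combined


first_out = {"energy": -10.0, "forces": 0.5, "charges": [0.1, -0.1]}
second_out = {"energy": -0.25, "forces": 0.125, "charges": [0.2, -0.2]}
merged = merge_outputs([first_out, second_out])
print(merged)  # {'energy': -10.25, 'forces': 0.625, 'charges': [0.2, -0.2]}
```

Summing the canonical terms makes their combination independent of model order. The non-canonical `"charges"` entry, by contrast, is simply whatever the last model reported.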

Parameters:

models (ModelInterface) – Two or more ModelInterface instances that share the same device and dtype.

Raises:

ValueError – If fewer than two models are given or if device/dtype do not match across all models.
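The two documented `ValueError` conditions can be sketched as follows. This assumes each model exposes `device` and `dtype` attributes. `FakeModel`, `validate_models` and the error messages are hypothetical, and plain strings stand in for `torch.device` and `torch.dtype` objects.

```python
from dataclasses import dataclass


@dataclass
class FakeModel:
    # Hypothetical stand-in exposing only the attributes the check needs.
    device: str
    dtype: str


def validate_models(*models):
    """Raise ValueError under the conditions documented for SumModel."""
    if len(models) < 2:
        raise ValueError("SumModel needs at least two models")
    device, dtype = models[0].device, models[0].dtype
    for model in models[1:]:
        if model.device != device or model.dtype != dtype:
            raise ValueError("all models must share the same device and dtype")
    return device, dtype


print(validate_models(FakeModel("cpu", "float64"), FakeModel("cpu", "float64")))
# ('cpu', 'float64')

try:
    validate_models(FakeModel("cpu", "float64"), FakeModel("cuda", "float64"))
except ValueError as err:
    print(err)  # all models must share the same device and dtype
```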

Examples

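```py
from torch_sim.models.interface import SumModel
# mace_model, d3_model: existing ModelInterface instances on the same device/dtype
sum_model = SumModel(mace_model, d3_model)
output = sum_model(sim_state)
```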

property compute_stress: bool

Whether the model computes stresses.

property compute_forces: bool

Whether the model computes forces.

property retain_graph: bool

Whether any child model retains the computation graph.

property memory_scales_with: MemoryScaling

Most conservative memory-scaling among all child models.
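A minimal sketch of how two of these properties could be aggregated over the children follows. `retain_graph` becomes true if any child retains its graph. `memory_scales_with` picks the highest-ranked label. The labels `"n_atoms"` and `"n_atoms_x_density"` and their ordering are assumptions, and `ChildInfo` and the helper functions are hypothetical.

```python
from dataclasses import dataclass

# Assumed scaling labels, ordered from least to most memory-hungry.
SCALING_ORDER = ["n_atoms", "n_atoms_x_density"]


@dataclass
class ChildInfo:
    # Hypothetical stand-in for a child model's properties.
    retain_graph: bool
    memory_scales_with: str


def combined_retain_graph(children):
    # True as soon as any child keeps its autograd graph.
    return any(child.retain_graph for child in children)


def combined_memory_scaling(children):
    # Pick the label ranked highest, i.e. the most conservative estimate.
    return max((child.memory_scales_with for child in children), key=SCALING_ORDER.index)


children = [ChildInfo(False, "n_atoms"), ChildInfo(True, "n_atoms_x_density")]
print(combined_retain_graph(children))  # True
print(combined_memory_scaling(children))  # n_atoms_x_density
```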

forward(state, **kwargs)

Sum the outputs of all child models.

Each child model is called with the same state and **kwargs. Canonical mechanical outputs that appear in multiple children are summed element-wise. Non-canonical outputs are replaced by later models so they behave like full state updates rather than deltas.

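Parameters:

state – Simulation state passed unchanged to every child model.

**kwargs – Additional keyword arguments forwarded to every child model.

Returns: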

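Combined output dictionary. Canonical mechanical outputs (energy, forces, stress) are summed across children. Every other key holds the value from the last child model that produced it.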

Return type:

dict[str, Tensor]
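The following sketch of `forward` shows that every child receives the same state and keyword arguments. It also shows that reordering the children leaves the summed energy unchanged but changes which non-canonical value survives. `ToySum`, `base`, `correction`, the `scale` keyword and the `"label"` key are all hypothetical. A float stands in for the simulation state, and the canonical key names are assumed.

```python
CANONICAL_KEYS = {"energy", "forces", "stress"}  # assumed key names


class ToySum:
    """Hypothetical stand-in for SumModel.forward over plain callables."""

    def __init__(self, *models):
        self.models = models

    def forward(self, state, **kwargs):
        combined = {}
        for model in self.models:
            # Every child sees the identical state and keyword arguments.
            for key, value in model(state, **kwargs).items():
                if key in CANONICAL_KEYS and key in combined:
                    combined[key] = combined[key] + value  # additive term
                else:
                    combined[key] = value  # first value, or later one replaces earlier
        return combined


def base(state, scale=1.0):
    return {"energy": scale * state, "label": "base"}


def correction(state, scale=1.0):
    return {"energy": 0.5 * scale * state, "label": "correction"}


print(ToySum(base, correction).forward(2.0, scale=2.0))
# {'energy': 6.0, 'label': 'correction'}
print(ToySum(correction, base).forward(2.0, scale=2.0))
# {'energy': 6.0, 'label': 'base'}
```

The energy is 6.0 in both orders because addition commutes. The `"label"` entry reflects whichever child ran last, which is why child order matters for non-canonical outputs.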