Source code for torch_sim.models.polarization
"""Electric-field corrections for polarization-aware models."""
import torch
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState
from torch_sim.typing import AtomExtras, SystemExtras
[docs]
class UniformPolarizationModel(ModelInterface):
    """Additive correction for a polarizable system in a constant, uniform E-field.

    The model evaluates no potential of its own. It is meant to be stacked
    behind an upstream model inside
    :class:`~torch_sim.models.interface.SerialSumModel` and reads the
    linear-response quantities that are stored as extras on the state.

    State extras read by :meth:`forward`:

    * ``external_E_field`` -- always required, shape ``(n_systems, 3)``
    * ``total_polarization`` -- required when the field is non-zero
    * ``polarizability`` -- required when the field is non-zero
    * ``born_effective_charges`` -- additionally required when the field is
      non-zero and ``compute_forces`` is enabled

    Any returned ``stress`` is all zeros: no field-induced stress term is
    applied.
    """

    def __init__(
        self,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float64,
        *,
        compute_forces: bool = True,
        compute_stress: bool = True,
        retain_graph: bool = False,
    ) -> None:
        """Configure which corrections the model produces.

        Args:
            device: Device recorded on the model (CPU when falsy/omitted).
                Not used when building outputs in this class; those follow
                the tensors on the state.
            dtype: Dtype recorded on the model. Not used when building
                outputs in this class; those follow ``state.dtype``.
            compute_forces: Emit a per-atom ``forces`` correction.
            compute_stress: Emit an (all-zero) ``stress`` entry.
            retain_graph: Keep returned tensors attached to the autograd graph.
        """
        super().__init__()
        self._device = device if device else torch.device("cpu")
        self._dtype = dtype
        self._compute_forces, self._compute_stress = compute_forces, compute_stress
        self._retain_graph = retain_graph
        # Scaling hint consumed elsewhere in torch_sim (assumed: batching /
        # memory estimation) -- TODO confirm.
        self._memory_scales_with = "n_atoms"

    @ModelInterface.compute_stress.setter
    def compute_stress(self, value: bool) -> None:  # noqa: FBT001
        """Set whether :meth:`forward` includes a ``stress`` entry."""
        self._compute_stress = value

    @ModelInterface.compute_forces.setter
    def compute_forces(self, value: bool) -> None:  # noqa: FBT001
        """Set whether :meth:`forward` includes a ``forces`` entry."""
        self._compute_forces = value

    @property
    def retain_graph(self) -> bool:
        """Whether :meth:`forward` returns tensors still attached to autograd."""
        return self._retain_graph

    @retain_graph.setter
    def retain_graph(self, value: bool) -> None:  # noqa: FBT001
        """Choose whether :meth:`forward` skips detaching its outputs."""
        self._retain_graph = value

    def _finalize_output(
        self, output: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """Return ``output`` as-is, or a detached copy when graphs are dropped.

        With ``retain_graph`` enabled the very same dict object is returned.
        Otherwise a new dict is built in which every tensor value is detached
        and any non-tensor value is passed through unchanged.
        """
        if not self.retain_graph:
            detached: dict[str, torch.Tensor] = {}
            for name, value in output.items():
                detached[name] = (
                    value.detach() if isinstance(value, torch.Tensor) else value
                )
            return detached
        return output

    def _zero_corrections(self, state: SimState) -> dict[str, torch.Tensor]:
        """Build the all-zero correction dict, honouring the compute_* flags."""
        zeros: dict[str, torch.Tensor] = {
            "energy": torch.zeros(
                state.n_systems, device=state.device, dtype=state.dtype
            )
        }
        if self.compute_forces:
            zeros["forces"] = torch.zeros_like(state.positions)
        if self.compute_stress:
            # V1 intentionally applies no field-induced stress correction.
            zeros["stress"] = torch.zeros(
                state.n_systems, 3, 3, device=state.device, dtype=state.dtype
            )
        return zeros

    def _apply_nonzero_field(
        self,
        state: SimState,
        output: dict[str, torch.Tensor],
        field: torch.Tensor,
    ) -> None:
        """Overwrite ``output`` in place with linear-response field corrections.

        With ``E`` the field, ``P0`` the stored total polarization, ``alpha``
        the polarizability and ``Z*`` the Born effective charges, this writes

        - ``energy`` = -E·P0 - 1/2 E·alpha·E (per system)
        - ``total_polarization`` = P0 + alpha·E (per system)
        - ``forces`` = Z*·E (per atom, only when ``compute_forces`` is set)

        Raises:
            ValueError: If any extra needed for the enabled outputs is absent.
        """
        needed = [
            SystemExtras.TOTAL_POLARIZATION.value,
            SystemExtras.POLARIZABILITY.value,
        ]
        if self.compute_forces:
            needed.append(AtomExtras.BORN_EFFECTIVE_CHARGES.value)
        absent = [name for name in needed if not state.has_extras(name)]
        if absent:
            quoted = ", ".join(f"'{name}'" for name in absent)
            raise ValueError(
                f"UniformPolarizationModel requires {quoted} on the state "
                "when external_E_field is non-zero"
            )

        p0 = state.total_polarization
        alpha = state.polarizability

        # Per-system scalars E·P0 and E·alpha·E.
        linear_term = torch.einsum("si,si->s", field, p0)
        quadratic_term = torch.einsum("si,sij,sj->s", field, alpha, field)
        output["energy"] = -linear_term - 0.5 * quadratic_term

        induced = torch.einsum("sij,sj->si", alpha, field)
        output[SystemExtras.TOTAL_POLARIZATION.value] = induced + p0

        if self.compute_forces:
            # Gather each atom's field through system_idx (assumed to map
            # atom -> owning system) before contracting with Z*.
            atom_field = field[state.system_idx]
            output["forces"] = torch.einsum(
                "imn,im->in", state.born_effective_charges, atom_field
            )

    def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
        """Compute the uniform-field corrections for every system in ``state``.

        Args:
            state: Simulation state carrying ``external_E_field`` and, when
                the field is non-zero, the linear-response extras.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``energy`` of shape ``(n_systems,)``.

            ``forces`` (per atom), when ``compute_forces`` is set. These are
            zeros like ``state.positions`` when the field vanishes.

            ``stress`` of shape ``(n_systems, 3, 3)``, always zero, when
            ``compute_stress`` is set.

            ``total_polarization``, only when some field component is
            non-zero.

            Tensors are detached unless ``retain_graph`` is enabled.

        Raises:
            ValueError: If ``external_E_field`` is missing or not shaped
                ``(n_systems, 3)``, or if a required linear-response extra is
                missing while the field is non-zero.
        """
        del kwargs
        output = self._zero_corrections(state)

        if not state.has_extras(SystemExtras.EXTERNAL_E_FIELD.value):
            raise ValueError(
                "UniformPolarizationModel requires 'external_E_field' on the state"
            )
        field = getattr(state, SystemExtras.EXTERNAL_E_FIELD.value)
        if field.shape != (state.n_systems, 3):
            raise ValueError(
                "UniformPolarizationModel requires external_E_field to have shape "
                "(n_systems, 3)"
            )

        # An all-zero field leaves the zero corrections untouched.
        if torch.any(field != 0):
            self._apply_nonzero_field(state, output, field)
        return self._finalize_output(output)
# Names exported by ``from torch_sim.models.polarization import *``.
__all__ = [
    "UniformPolarizationModel",
]