Source code for torch_sim.optimizers.gradient_descent

"""Gradient descent optimizer implementation."""

from typing import TYPE_CHECKING, Any

import torch

import torch_sim as ts
from torch_sim.optimizers import cell_filters
from torch_sim.state import SimState
from torch_sim.typing import StateDict


if TYPE_CHECKING:
    from torch_sim.models.interface import ModelInterface
    from torch_sim.optimizers import CellOptimState, OptimState
    from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs


def gradient_descent_init(
    state: SimState | StateDict,
    model: "ModelInterface",
    *,
    cell_filter: "CellFilter | CellFilterFuncs | None" = None,
    **filter_kwargs: Any,
) -> "OptimState | CellOptimState":
    """Initialize a gradient descent optimization state.

    Args:
        state: SimState containing positions, masses, cell, etc.
        model: Model that computes energies, forces, and optionally stress
        cell_filter: Filter for cell optimization (None for position-only
            optimization)
        **filter_kwargs: Additional arguments passed to cell filter initialization

    Returns:
        Initialized OptimState with forces, energy, and optional cell state

    Notes:
        Use cell_filter=None for position-only optimization. Use
        cell_filter=UNIT_CELL_FILTER or FRECHET_CELL_FILTER for cell optimization.
    """
    # Import here to avoid circular imports
    from torch_sim.optimizers import CellOptimState, OptimState

    if not isinstance(state, SimState):
        state = SimState(**state)

    # Get initial forces and energy from the model
    model_output = model(state)
    energy = model_output["energy"]
    forces = model_output["forces"]
    stress = model_output.get("stress")

    # Arguments shared by both state types
    common_args = {
        "positions": state.positions,
        "forces": forces,
        "energy": energy,
        "stress": stress,
        "masses": state.masses,
        "cell": state.cell,
        "pbc": state.pbc,
        "atomic_numbers": state.atomic_numbers,
        "system_idx": state.system_idx,
    }

    if cell_filter is not None:
        # Resolve the filter into its (init_fn, step_fn) pair and create a
        # cell optimization state that carries the pair
        cell_filter_funcs = ts.get_cell_filter(cell_filter)
        init_fn, _step_fn = cell_filter_funcs
        common_args["reference_cell"] = state.cell.clone()
        common_args["cell_filter"] = cell_filter_funcs
        cell_state = CellOptimState(**common_args)

        # Initialize cell-specific attributes
        init_fn(cell_state, model, **filter_kwargs)
        return cell_state

    # Create a regular OptimState without cell optimization
    return OptimState(**common_args)
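
# Usage sketch (illustrative, not part of the module): initializing a state for
# combined position-and-cell relaxation. FRECHET_CELL_FILTER is named in the
# docstring above, but the import path below is an assumption; `state` and
# `model` are assumed to be an existing SimState and ModelInterface instance.
#
#     from torch_sim.optimizers.cell_filters import FRECHET_CELL_FILTER
#
#     opt_state = gradient_descent_init(state, model, cell_filter=FRECHET_CELL_FILTER)
#     # opt_state is a CellOptimState carrying reference_cell and the
#     # (init_fn, step_fn) pair of the chosen cell filter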

def gradient_descent_step(
    state: "OptimState | CellOptimState",
    model: "ModelInterface",
    *,
    pos_lr: float | torch.Tensor = 0.01,
    cell_lr: float | torch.Tensor = 0.1,
) -> "OptimState | CellOptimState":
    """Perform one gradient descent optimization step.

    Updates atomic positions and, if the state carries a cell filter, the cell
    parameters as well.

    Args:
        state: Current optimization state
        model: Model that computes energies, forces, and optionally stress
        pos_lr: Learning rate(s) for atomic positions
        cell_lr: Learning rate(s) for cell optimization (ignored if no cell filter)

    Returns:
        Updated OptimState after one optimization step
    """
    from torch_sim.optimizers import CellOptimState

    device, dtype = model.device, model.dtype

    # Broadcast per-system learning rates to per-atom values
    if isinstance(pos_lr, (int, float)):
        pos_lr = torch.full((state.n_systems,), pos_lr, device=device, dtype=dtype)
    atom_lr = pos_lr[state.system_idx].unsqueeze(-1)

    # Update atomic positions along the forces
    state.positions = state.positions + atom_lr * state.forces

    # Update the cell if using cell optimization
    if isinstance(state, CellOptimState):
        # Compute the cell step and update the cell
        _init_fn, step_fn = state.cell_filter
        step_fn(state, cell_lr)

    # Get updated forces, energy, and stress at the new geometry
    model_output = model(state)
    state.forces = model_output["forces"]
    state.energy = model_output["energy"]
    if "stress" in model_output:
        state.stress = model_output["stress"]

    # Update cell forces from the new model output
    if isinstance(state, CellOptimState):
        cell_filters.compute_cell_forces(model_output, state)

    return state
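
# Usage sketch (illustrative, not part of the module): a plain relaxation loop
# built from the two functions above. `state` and `model` are assumed to exist;
# the step cap and force tolerance are arbitrary example values, not library
# defaults.
#
#     opt_state = gradient_descent_init(state, model)
#     for _ in range(200):
#         opt_state = gradient_descent_step(opt_state, model, pos_lr=0.01)
#         if opt_state.forces.norm(dim=-1).max() < 1e-3:
#             break
#
# Because pos_lr also accepts a tensor of shape (n_systems,), each system in a
# batched state can be relaxed with its own step size.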