def gradient_descent_init(
    state: SimState,
    model: "ModelInterface",
    *,
    cell_filter: "CellFilter | CellFilterFuncs | None" = None,
    **filter_kwargs: Any,
) -> "OptimState | CellOptimState":
    """Build the starting state for a gradient descent optimization.

    The model is evaluated once on ``state`` and its energy, forces and
    (if present) stress are stored on the returned optimization state.

    Args:
        state: SimState to start the optimization from.
        model: Callable model; its output must provide ``"energy"`` and
            ``"forces"`` and may provide ``"stress"``.
        cell_filter: Cell filter to use for cell optimization, or None to
            optimize positions only.
        **filter_kwargs: Extra keyword arguments forwarded to the cell
            filter's init function.

    Returns:
        An OptimState when ``cell_filter`` is None, otherwise a
        CellOptimState that has been passed through the filter's init
        function.

    Notes:
        Use cell_filter=None for position-only optimization.
        Use cell_filter=UNIT_CELL_FILTER or FRECHET_CELL_FILTER for cell
        optimization.
    """
    # Deferred import: importing at module level would be circular.
    from torch_sim.optimizers import CellOptimState, OptimState

    # Single model evaluation supplies every field shared by both state types.
    results = model(state)
    energy, forces = results["energy"], results["forces"]
    shared_fields = dict(forces=forces, energy=energy, stress=results.get("stress"))

    # Position-only optimization: nothing cell-related to set up.
    if cell_filter is None:
        return OptimState.from_state(state, **shared_fields)

    # Resolve the filter into its (init, step) pair; the pair itself is kept
    # on the state so the step function can retrieve it later.
    filter_funcs = ts.get_cell_filter(cell_filter)
    init_cell, _ = filter_funcs

    cell_fields = {
        **shared_fields,
        "reference_cell": state.cell.clone(),
        "cell_filter": filter_funcs,
    }
    optim_state = CellOptimState.from_state(state, **cell_fields)

    # Let the filter populate its cell-specific attributes in place.
    init_cell(optim_state, model, **filter_kwargs)
    return optim_state
def gradient_descent_step(
    state: "OptimState | CellOptimState",
    model: "ModelInterface",
    *,
    pos_lr: float | torch.Tensor = 0.01,
    cell_lr: float | torch.Tensor = 0.1,
) -> "OptimState | CellOptimState":
    """Advance the optimization by a single gradient descent step.

    Positions are moved along the current forces, the cell is stepped when
    the state is a CellOptimState, and the model is then re-evaluated to
    refresh forces, energy and (if returned) stress on the state.

    Args:
        state: Optimization state to update; it is modified in place.
        model: Callable model exposing ``device`` and ``dtype`` whose output
            provides ``"forces"`` and ``"energy"`` and may provide
            ``"stress"``.
        pos_lr: Learning rate for atomic positions. A scalar is broadcast to
            every system; a tensor is indexed by ``state.system_idx``
            (presumably one entry per system -- confirm against callers).
        cell_lr: Learning rate handed unchanged to the cell filter's step
            function; unused when the state is not a CellOptimState.

    Returns:
        The same ``state`` object after the update.
    """
    from torch_sim.optimizers import CellOptimState

    # Turn the position learning rate into one value per atom.
    per_system_lr = torch.as_tensor(pos_lr, device=model.device, dtype=model.dtype)
    if per_system_lr.ndim == 0:
        per_system_lr = per_system_lr.expand(state.n_systems)
    # Trailing axis lets the rate broadcast against the per-atom force vectors.
    per_atom_lr = per_system_lr[state.system_idx].unsqueeze(-1)

    # Move atoms along the forces.
    state.set_constrained_positions(state.positions + per_atom_lr * state.forces)

    optimizes_cell = isinstance(state, CellOptimState)
    if optimizes_cell:
        # The state carries the filter's (init, step) pair; only step is needed.
        _, advance_cell = state.cell_filter
        advance_cell(state, cell_lr)

    # Re-evaluate the model at the new geometry and write the results back.
    results = model(state)
    state.set_constrained_forces(results["forces"])
    state.energy = results["energy"]
    if "stress" in results:
        state.stress = results["stress"]

    if optimizes_cell:
        # Refresh the cell forces from the new model output.
        cell_filters.compute_cell_forces(results, state)

    return state