Source code for torch_sim.optimizers.fire

"""FIRE (Fast Inertial Relaxation Engine) optimizer implementation."""

from typing import TYPE_CHECKING, Any, get_args

import torch

import torch_sim as ts
import torch_sim.math as tsm
from torch_sim._duecredit import dcite
from torch_sim.optimizers import CellFireState, cell_filters
from torch_sim.state import SimState


if TYPE_CHECKING:
    from torch_sim.models.interface import ModelInterface
    from torch_sim.optimizers import FireFlavor, FireState
    from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs


[docs] @dcite("10.1103/PhysRevLett.97.170201") def fire_init( state: SimState, model: "ModelInterface", *, dt_start: float | torch.Tensor = 0.1, alpha_start: float | torch.Tensor = 0.1, fire_flavor: "FireFlavor" = "ase_fire", cell_filter: "CellFilter | CellFilterFuncs | None" = None, **filter_kwargs: Any, ) -> "FireState | CellFireState": """Initialize a FIRE optimization state. Creates an optimizer that performs FIRE (Fast Inertial Relaxation Engine) optimization on atomic positions and optionally cell parameters. Args: model: Model that computes energies, forces, and optionally stress state: Input SimState dt_start: Initial timestep per system alpha_start: Initial mixing parameter per system fire_flavor: Optimization flavor ("vv_fire" or "ase_fire") cell_filter: Filter for cell optimization (None for position-only optimization) **filter_kwargs: Additional arguments passed to cell filter initialization Returns: FireState with initialized optimization tensors Notes: - fire_flavor="vv_fire" follows the original paper closely - fire_flavor="ase_fire" mimics the ASE implementation - Use cell_filter=UNIT_CELL_FILTER or FRECHET_CELL_FILTER for cell optimization """ # Import here to avoid circular imports from torch_sim.optimizers import CellFireState, FireFlavor, FireState if fire_flavor not in get_args(FireFlavor): raise ValueError(f"Unknown {fire_flavor=}, must be one of {get_args(FireFlavor)}") device: torch.device = model.device dtype: torch.dtype = model.dtype n_systems = state.n_systems # Get initial forces and energy from model model_output = model(state) energy = model_output["energy"] forces = model_output["forces"] stress = model_output.get("stress") # Setup initial parameters dt_start_t = torch.as_tensor(dt_start, device=device, dtype=dtype) if dt_start_t.ndim == 0: # NOTE: clone needed as this is overwritten/assigned later by a masked_fill dt_start_t = dt_start_t.expand(n_systems).clone() alpha_start_t = torch.as_tensor(alpha_start, device=device, dtype=dtype) if alpha_start_t.ndim == 0: # NOTE: clone needed as this is overwritten/assigned later by a masked_fill alpha_start_t = alpha_start_t.expand(n_systems).clone() # FIRE-specific additional attributes fire_attrs = { "forces": forces, "energy": energy, "stress": stress, "velocities": torch.full( state.positions.shape, torch.nan, device=device, dtype=dtype ), "dt": dt_start_t, "alpha": alpha_start_t, "n_pos": torch.zeros((n_systems,), device=model.device, dtype=torch.int32), } if cell_filter is not None: # Create cell optimization state cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) fire_attrs["reference_cell"] = state.cell.clone() fire_attrs["cell_filter"] = cell_filter_funcs cell_state = CellFireState.from_state(state, **fire_attrs) # Initialize cell-specific attributes init_fn(cell_state, model, **filter_kwargs) # Initialize cell velocities after cell_forces is set cell_state.cell_velocities = torch.full( cell_state.cell_forces.shape, torch.nan, device=device, dtype=dtype ) return cell_state # Create regular FireState without cell optimization return FireState.from_state(state, **fire_attrs)
[docs] def fire_step( state: "FireState | CellFireState", model: "ModelInterface", *, dt_max: float | torch.Tensor = 1.0, n_min: int | torch.Tensor = 5, f_inc: float | torch.Tensor = 1.1, f_dec: float | torch.Tensor = 0.5, alpha_start: float | torch.Tensor = 0.1, f_alpha: float | torch.Tensor = 0.99, max_step: float | torch.Tensor = 0.2, fire_flavor: "FireFlavor" = "ase_fire", ) -> "FireState | CellFireState": """Perform one FIRE optimization step. Args: model: Model that computes energies, forces, and optionally stress state: Current FIRE optimization state dt_max: Maximum allowed timestep n_min: Minimum steps before timestep increase f_inc: Factor for timestep increase when power is positive f_dec: Factor for timestep decrease when power is negative alpha_start: Initial velocity mixing parameter f_alpha: Factor for mixing parameter decrease max_step: Maximum distance an atom can move per iteration fire_flavor: Optimization flavor ("vv_fire" or "ase_fire") Returns: Updated FireState after one optimization step """ # Import here to avoid circular imports from torch_sim.optimizers import FireFlavor, ase_fire_key, vv_fire_key if fire_flavor not in get_args(FireFlavor): raise ValueError(f"Unknown {fire_flavor=}, must be one of {get_args(FireFlavor)}") device, dtype = model.device, model.dtype eps = 1e-8 if dtype == torch.float32 else 1e-16 # Setup parameters dt_max, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step = ( torch.as_tensor(p, device=device, dtype=dtype) for p in (dt_max, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step) ) step_func_kwargs = dict( model=model, dt_max=dt_max, n_min=n_min, f_inc=f_inc, f_dec=f_dec, alpha_start=alpha_start, f_alpha=f_alpha, eps=eps, ) if fire_flavor == ase_fire_key: step_func_kwargs["max_step"] = max_step step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[fire_flavor] return step_func(state, **step_func_kwargs) # ty: ignore[invalid-argument-type]
def _vv_fire_step[T: "FireState | CellFireState"]( state: T, model: "ModelInterface", *, dt_max: torch.Tensor, n_min: torch.Tensor, f_inc: torch.Tensor, f_dec: torch.Tensor, alpha_start: torch.Tensor, f_alpha: torch.Tensor, eps: float, ) -> T: """Perform one Velocity-Verlet based FIRE optimization step.""" n_systems, device, dtype = state.n_systems, state.device, state.dtype # Initialize velocities if NaN nan_velocities = state.velocities.isnan().any(dim=1) if nan_velocities.any(): state.velocities[nan_velocities] = 0 if isinstance(state, CellFireState): nan_cell_vel = state.cell_velocities.isnan().any(dim=(1, 2)) state.cell_velocities[nan_cell_vel] = 0 alpha_start_system = torch.full( (n_systems,), alpha_start.item(), device=device, dtype=dtype ) # First half of velocity update atom_wise_dt = state.dt[state.system_idx].unsqueeze(-1) state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) # Position update state.set_constrained_positions(state.positions + atom_wise_dt * state.velocities) # Cell position updates are handled in the velocity update step above # Get new forces and energy model_output = model(state) state.set_constrained_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] # Update cell forces if isinstance(state, CellFireState): cell_filters.compute_cell_forces(model_output, state) # Second half of velocity update state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if isinstance(state, CellFireState): cell_wise_dt = state.dt.view(n_systems, 1, 1) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) ) # Calculate power system_power = tsm.batched_vdot(state.forces, state.velocities, state.system_idx) if isinstance(state, CellFireState): system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) # Update dt, alpha, n_pos pos_mask_system = system_power > 0.0 neg_mask_system = ~pos_mask_system state.n_pos[pos_mask_system] += 1 inc_mask = (state.n_pos > n_min) & pos_mask_system state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) state.alpha[inc_mask] *= f_alpha state.dt[neg_mask_system] *= f_dec state.alpha[neg_mask_system] = alpha_start_system[neg_mask_system] state.n_pos[neg_mask_system] = 0 # Velocity mixing v_scaling_system = tsm.batched_vdot( state.velocities, state.velocities, state.system_idx ) f_scaling_system = tsm.batched_vdot(state.forces, state.forces, state.system_idx) if isinstance(state, CellFireState): v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) v_scaling_cell = torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) f_scaling_cell = torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell alpha_cell_bc = state.alpha.view(n_systems, 1, 1) state.cell_velocities = torch.where( pos_mask_system.view(n_systems, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, torch.zeros_like(state.cell_velocities), ) v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) state.velocities = torch.where( pos_mask_system[state.system_idx].unsqueeze(-1), (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, torch.zeros_like(state.velocities), ) return state def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 state: T, model: "ModelInterface", *, dt_max: torch.Tensor, n_min: torch.Tensor, f_inc: torch.Tensor, f_dec: torch.Tensor, alpha_start: torch.Tensor, f_alpha: torch.Tensor, max_step: torch.Tensor, eps: float, ) -> T: """Perform one ASE-style FIRE optimization step.""" from torch_sim.optimizers import CellFireState n_systems, device, dtype = state.n_systems, state.device, state.dtype # Per-atom NaN detection before zeroing: needed to decide whether to skip # FIRE mixing (all NaN = first step) vs run it (partial NaN = autobatcher swap). nan_velocities = state.velocities.isnan().any(dim=1) state.velocities.nan_to_num_(nan=0.0) if isinstance(state, CellFireState): state.cell_velocities.nan_to_num_(nan=0.0) if nan_velocities.all(): # First step: all NaN → zero. Use raw forces, skip FIRE mixing (matches ASE). forces = state.forces else: alpha_start_system = torch.full( (n_systems,), alpha_start.item(), device=device, dtype=dtype ) # Transform forces for cell optimization if isinstance(state, CellFireState): cur_deform_grad = cell_filters.deform_grad( state.row_vector_cell, getattr(state, "reference_row_vector_cell", state.row_vector_cell), ) forces = torch.bmm( state.forces.unsqueeze(1), cur_deform_grad[state.system_idx] ).squeeze(1) else: forces = state.forces # Calculate power (newly zeroed systems will have power=0 → neg_mask) system_power = tsm.batched_vdot(forces, state.velocities, state.system_idx) if isinstance(state, CellFireState): system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) # Update dt, alpha, n_pos pos_mask_system = system_power > 0.0 neg_mask_system = ~pos_mask_system inc_mask = (state.n_pos > n_min) & pos_mask_system state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) state.alpha[inc_mask] *= f_alpha state.n_pos[pos_mask_system] += 1 state.dt[neg_mask_system] *= f_dec state.alpha[neg_mask_system] = alpha_start_system[neg_mask_system] state.n_pos[neg_mask_system] = 0 # Velocity mixing BEFORE acceleration (ASE ordering) v_scaling_system = tsm.batched_vdot( state.velocities, state.velocities, state.system_idx ) f_scaling_system = tsm.batched_vdot(forces, forces, state.system_idx) if isinstance(state, CellFireState): v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) v_scaling_cell = torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) f_scaling_cell = torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell alpha_cell_bc = state.alpha.view(n_systems, 1, 1) state.cell_velocities = torch.where( pos_mask_system.view(n_systems, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, torch.zeros_like(state.cell_velocities), ) v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps)) alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) state.velocities = torch.where( pos_mask_system[state.system_idx].unsqueeze(-1), (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, torch.zeros_like(state.velocities), ) # Acceleration (single forward-Euler, no mass for ASE FIRE) state.velocities += forces * state.dt[state.system_idx].unsqueeze(-1) dr_atom = state.velocities * state.dt[state.system_idx].unsqueeze(-1) dr_scaling_system = tsm.batched_vdot(dr_atom, dr_atom, state.system_idx) if isinstance(state, CellFireState): state.cell_velocities += state.cell_forces * state.dt.view(n_systems, 1, 1) dr_cell = state.cell_velocities * state.dt.view(n_systems, 1, 1) dr_scaling_system += dr_cell.pow(2).sum(dim=(1, 2)) dr_scaling_cell = torch.sqrt(dr_scaling_system).view(n_systems, 1, 1) dr_cell = torch.where( dr_scaling_cell > max_step, max_step * dr_cell / (dr_scaling_cell + eps), dr_cell, ) dr_scaling_atom = torch.sqrt(dr_scaling_system)[state.system_idx].unsqueeze(-1) dr_atom = torch.where( dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom, ) # Position updates if isinstance(state, CellFireState): # For cell optimization, handle both atomic and cell position updates # This follows the ASE FIRE implementation pattern # Transform atomic positions to fractional coordinates cur_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell ) frac_positions = torch.linalg.solve( cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) ).squeeze(-1) # Store fractional positions (will transform to Cartesian after cell update) new_frac_positions = frac_positions + dr_atom # Update cell positions directly based on stored cell filter type if hasattr(state, "cell_filter") and state.cell_filter is not None: from torch_sim.optimizers.cell_filters import frechet_cell_filter_init init_fn, _step_fn = state.cell_filter is_frechet = init_fn is frechet_cell_filter_init # Update cell positions cell_positions_new = state.cell_positions + dr_cell state.cell_positions = cell_positions_new if is_frechet: # Frechet: convert from log space to deformation gradient cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1) deform_grad_log_new = cell_positions_new / cell_factor_reshaped deform_grad_new = torch.matrix_exp(deform_grad_log_new) else: # Unit cell: positions are scaled deformation gradient cell_factor_expanded = state.cell_factor.expand(state.n_systems, 3, 1) deform_grad_new = cell_positions_new / cell_factor_expanded # Compute new cell from deformation gradient new_col_vector_cell = torch.bmm(deform_grad_new, state.reference_cell) # Apply cell constraints and scale positions to new cell coordinates # (needed for correct displacement calculation in position constraints) state.set_constrained_cell(new_col_vector_cell, scale_atoms=True) # Transform fractional positions to Cartesian using NEW deformation gradient new_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell ) state.set_constrained_positions( torch.bmm( new_frac_positions.unsqueeze(1), new_deform_grad[state.system_idx].transpose(-2, -1), ).squeeze(1) ) else: state.set_constrained_positions(state.positions + dr_atom) # Get new forces, energy, and stress model_output = model(state) state.set_constrained_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] # Update cell forces if isinstance(state, CellFireState): cell_filters.compute_cell_forces(model_output, state) return state