"""L-BFGS (Limited-memory BFGS) optimizer implementation.
This module provides a batched L-BFGS optimizer for atomic structure relaxation.
L-BFGS is a quasi-Newton method that approximates the inverse Hessian using
a limited history of position and gradient differences, making it memory-efficient
for large systems while achieving superlinear convergence near the minimum.
When cell_filter is active, forces are transformed using the deformation gradient
to work in the same scaled coordinate space as ASE's UnitCellFilter/FrechetCellFilter.
The prev_forces and prev_positions are stored in the scaled/fractional space to match
ASE's behavior exactly.
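
Example:
    A sketch of the scaled-space convention used when a cell filter is active
    (`F` stands for the per-system deformation gradient; the code below applies
    it per atom via `F[state.system_idx]`)::

        forces_scaled = forces @ F               # gradient w.r.t. scaled coords
        frac = torch.linalg.solve(F, positions)  # positions in the reference frame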
"""
from typing import TYPE_CHECKING, Any
import torch
import torch_sim as ts
from torch_sim.optimizers import cell_filters
from torch_sim.optimizers.cell_filters import frechet_cell_filter_init
from torch_sim.state import SimState
from torch_sim.typing import StateDict
if TYPE_CHECKING:
from torch_sim.models.interface import ModelInterface
from torch_sim.optimizers import CellLBFGSState, LBFGSState
from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs
def _compute_atom_idx(system_idx: torch.Tensor, n_systems: int) -> torch.Tensor:
"""Compute per-system atom indices, vectorized.
Args:
system_idx: System index for each atom [N]
n_systems: Number of systems S
Returns:
Tensor [N] with per-system atom indices
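
    Example:
        A small example (atoms must be grouped by system, as in SimState);
        three atoms in system 0 followed by two in system 1:

        >>> _compute_atom_idx(torch.tensor([0, 0, 0, 1, 1]), 2)
        tensor([0, 1, 2, 0, 1])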
"""
device = system_idx.device
counts = torch.bincount(system_idx, minlength=n_systems)
offsets = torch.zeros(n_systems, device=device, dtype=torch.long)
if n_systems > 1:
offsets[1:] = counts[:-1].cumsum(0)
return torch.arange(len(system_idx), device=device) - offsets[system_idx]
def _atoms_to_padded(
x: torch.Tensor,
system_idx: torch.Tensor,
n_systems: int,
max_atoms: int,
) -> torch.Tensor:
"""Convert atom-indexed [N, 3] to padded per-system [S, M, 3].
Args:
x: Tensor of shape [N, 3] where N = total atoms
system_idx: System index for each atom [N]
n_systems: Number of systems S
max_atoms: Maximum atoms per system M
Returns:
Tensor of shape [S, M, 3] with zeros for padding
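
    Example:
        Shape check for two systems with three and two atoms:

        >>> x = torch.randn(5, 3)
        >>> system_idx = torch.tensor([0, 0, 0, 1, 1])
        >>> _atoms_to_padded(x, system_idx, 2, 3).shape
        torch.Size([2, 3, 3])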
"""
device, dtype = x.device, x.dtype
out = torch.zeros((n_systems, max_atoms, 3), device=device, dtype=dtype)
atom_idx = _compute_atom_idx(system_idx, n_systems)
out[system_idx, atom_idx] = x
return out
def _padded_to_atoms(
x: torch.Tensor,
system_idx: torch.Tensor,
) -> torch.Tensor:
"""Convert padded per-system [S, M, 3] to atom-indexed [N, 3].
Args:
x: Tensor of shape [S, M, 3]
system_idx: System index for each atom [N]
Returns:
Tensor of shape [N, 3]
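
    Example:
        Round-tripping through `_atoms_to_padded` recovers the input:

        >>> x = torch.randn(4, 3)
        >>> system_idx = torch.tensor([0, 0, 1, 1])
        >>> padded = _atoms_to_padded(x, system_idx, 2, 2)
        >>> torch.equal(_padded_to_atoms(padded, system_idx), x)
        True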
"""
n_systems = x.shape[0]
atom_idx = _compute_atom_idx(system_idx, n_systems)
return x[system_idx, atom_idx] # [N, 3]
def _per_system_vdot(
a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
"""Compute per-system dot product with padding mask.
Args:
a: Tensor of shape [S, M, 3]
b: Tensor of shape [S, M, 3]
mask: Boolean mask [S, M] where True = valid atom
Returns:
Tensor of shape [S] with per-system dot products
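
    Example:
        Padded entries are excluded from the sum:

        >>> a = torch.ones(2, 3, 3)
        >>> mask = torch.tensor([[True, True, False], [True, False, False]])
        >>> _per_system_vdot(a, a, mask)
        tensor([6., 3.])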
"""
# Element-wise product then sum over atoms and coordinates
prod = (a * b).sum(dim=-1) # [S, M]
prod = prod * mask.float() # Zero out padded atoms
return prod.sum(dim=-1) # [S]
def lbfgs_init(
state: SimState | StateDict,
model: "ModelInterface",
*,
step_size: float = 0.1,
alpha: float | None = None,
cell_filter: "CellFilter | CellFilterFuncs | None" = None,
**filter_kwargs: Any,
) -> "LBFGSState | CellLBFGSState":
r"""Create an initial LBFGSState from a SimState or state dict.
Initializes forces/energy, clears the (s, y) memory, and broadcasts the
fixed step size to all systems.
Shape notation:
N = total atoms across all systems (n_atoms)
S = number of systems (n_systems)
M = max atoms per system (global_max_atoms)
H = history length (starts at 0)
M_ext = M + 3 (extended with cell DOFs per system)
Args:
state: Input state as SimState object or state parameter dict
model: Model that computes energies, forces, and optionally stress
step_size: Fixed per-system step length (damping factor).
If using ASE mode (fixed alpha), set this to 1.0 (or your damping).
If using dynamic mode (default), 0.1 is a safe starting point.
alpha: Initial inverse Hessian stiffness guess (ASE parameter).
If provided (e.g. 70.0), fixes H0 = 1/alpha for all steps (ASE-style).
If None (default), H0 is updated dynamically (Standard L-BFGS).
cell_filter: Filter for cell optimization (None for position-only optimization)
**filter_kwargs: Additional arguments passed to cell filter initialization
Returns:
LBFGSState with initialized optimization tensors, or CellLBFGSState if
cell_filter is provided
Notes:
The optimizer supports two modes of operation:
1. **Standard L-BFGS (default)**: Set `alpha=None`. The inverse Hessian
diagonal $H_0$ is updated dynamically at each step using the scaling
$\gamma_k = (s^T y) / (y^T y)$. This is the standard behavior described
by Nocedal & Wright.
2. **ASE Compatibility Mode**: Set `alpha` (e.g. 70.0) and `step_size=1.0`.
The inverse Hessian diagonal is fixed at $H_0 = 1/\alpha$ throughout the
optimization, and the step is scaled by `step_size` (damping).
This matches `ase.optimize.LBFGS(alpha=70.0, damping=1.0)`.
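
    Example:
        A minimal sketch of a relaxation loop; `state` is an existing
        SimState and `model` any ModelInterface (both assumed given)::

            opt_state = lbfgs_init(state, model, step_size=0.1)
            for _ in range(100):
                opt_state = lbfgs_step(opt_state, model)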
"""
from torch_sim.optimizers import CellLBFGSState, LBFGSState
tensor_args = {"device": model.device, "dtype": model.dtype}
if not isinstance(state, SimState):
state = SimState(**state)
n_systems = state.n_systems # S
# Compute max atoms per system for per-system history storage
counts = state.n_atoms_per_system # [S]
global_max_atoms = int(counts.max().item()) if len(counts) > 0 else 0 # M
max_atoms = counts.clone() # [S] - each system's atom count
# Get initial forces and energy from model
model_output = model(state)
energy = model_output["energy"] # [S]
forces = model_output["forces"] # [N, 3]
stress = model_output.get("stress") # [S, 3, 3] or None
# Initialize empty per-system history tensors
# History shape: [S, H, M, 3] where H=0 at start, M = global_max_atoms
s_history = torch.zeros(
(n_systems, 0, global_max_atoms, 3), **tensor_args
) # [S, 0, M, 3]
y_history = torch.zeros(
(n_systems, 0, global_max_atoms, 3), **tensor_args
) # [S, 0, M, 3]
# Alpha tensor: 0.0 means dynamic, >0 means fixed
alpha_val = 0.0 if alpha is None else alpha
alpha_tensor = torch.full((n_systems,), alpha_val, **tensor_args) # [S]
common_args = {
# Copy SimState attributes
"positions": state.positions.clone(), # [N, 3]
"masses": state.masses.clone(), # [N]
"cell": state.cell.clone(), # [S, 3, 3]
"atomic_numbers": state.atomic_numbers.clone(), # [N]
"system_idx": state.system_idx.clone(), # [N]
"pbc": state.pbc, # [S, 3]
"charge": state.charge, # preserve charge
"spin": state.spin, # preserve spin
"_constraints": state.constraints, # preserve constraints
# Optimization state
"forces": forces, # [N, 3]
"energy": energy, # [S]
"stress": stress, # [S, 3, 3] or None
# L-BFGS specific state
"prev_forces": forces.clone(), # [N, 3]
"prev_positions": state.positions.clone(), # [N, 3]
"s_history": s_history, # [S, 0, M, 3]
"y_history": y_history, # [S, 0, M, 3]
"step_size": torch.full((n_systems,), step_size, **tensor_args), # [S]
"alpha": alpha_tensor, # [S]
"n_iter": torch.zeros((n_systems,), device=model.device, dtype=torch.int32),
"max_atoms": max_atoms, # [S] atoms per system for padding
}
if cell_filter is not None:
        cell_filter_funcs = ts.get_cell_filter(cell_filter)
        init_fn, _step_fn = cell_filter_funcs
# At initialization, deform_grad is identity since reference_cell = current_cell
# Store prev_positions as fractional (same as Cartesian for identity deform_grad)
# Store prev_forces as scaled (same as Cartesian for identity deform_grad)
reference_cell = state.cell.clone() # [S, 3, 3]
cur_deform_grad = cell_filters.deform_grad(
reference_cell.mT, state.cell.mT
) # [S, 3, 3]
# Initial fractional positions = positions
# cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] -> [N, 3]
frac_positions = torch.linalg.solve(
cur_deform_grad[state.system_idx], # [N, 3, 3]
state.positions.unsqueeze(-1), # [N, 3, 1]
).squeeze(-1) # [N, 3]
# Initial scaled forces = forces @ deform_grad = forces
# forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3]
scaled_forces = torch.bmm(
forces.unsqueeze(1), # [N, 1, 3]
cur_deform_grad[state.system_idx], # [N, 3, 3]
).squeeze(1) # [N, 3]
common_args["reference_cell"] = reference_cell # [S, 3, 3]
common_args["cell_filter"] = cell_filter_funcs
# Store fractional positions and scaled forces for ASE compatibility
common_args["prev_positions"] = frac_positions # [N, 3]
common_args["prev_forces"] = scaled_forces # [N, 3]
# Extended per-system history includes cell DOFs (3 "virtual atoms" per system)
# History shape: [S, H, M+3, 3] where M = global_max_atoms
extended_size_per_system = global_max_atoms + 3 # M_ext = M + 3
common_args["s_history"] = torch.zeros(
(n_systems, 0, extended_size_per_system, 3), **tensor_args
) # [S, 0, M_ext, 3]
common_args["y_history"] = torch.zeros(
(n_systems, 0, extended_size_per_system, 3), **tensor_args
) # [S, 0, M_ext, 3]
cell_state = CellLBFGSState(**common_args)
# Initialize cell-specific attributes
    # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S, 1, 1]
init_fn(cell_state, model, **filter_kwargs)
# Store prev_cell_positions and prev_cell_forces for history update
cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3]
cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3]
return cell_state
return LBFGSState(**common_args)
def lbfgs_step( # noqa: PLR0915, C901
state: "LBFGSState | CellLBFGSState",
model: "ModelInterface",
*,
max_history: int = 20,
max_step: float = 0.2,
curvature_eps: float = 1e-12,
) -> "LBFGSState | CellLBFGSState":
r"""Advance one L-BFGS iteration using the two-loop recursion.
    Computes the search direction via the two-loop recursion, applies a
    fixed step with per-system capping of the largest atomic displacement,
    evaluates new forces and energy, and appends the new (s, y) pair to the
    limited-memory history. Curvature is not checked when storing pairs;
    degenerate pairs are neutralized inside the two-loop recursion instead.
When cell_filter is active, forces are transformed using the deformation
gradient to work in the same scaled coordinate space as ASE's cell filters.
The prev_positions are stored as fractional coordinates and prev_forces as
scaled forces, exactly matching ASE's pos0/forces0.
Shape notation:
N = total atoms across all systems (n_atoms)
S = number of systems (n_systems)
M = max atoms per system (history dimension)
H = current history length
M_ext = M + 3 (extended with cell DOFs per system)
Args:
state: Current L-BFGS optimization state
model: Model that computes energies, forces, and optionally stress
max_history: Number of (s, y) pairs retained for the two-loop recursion.
        max_step: Caps the maximum per-atom displacement per iteration
            (applied per system).
        curvature_eps: Curvature threshold |⟨y, s⟩| below which a stored
            history pair is skipped (its ρ is set to zero) in the two-loop
            recursion.
Returns:
Updated LBFGSState after one optimization step
Notes:
- If `state.alpha > 0` (ASE mode), the initial inverse Hessian estimate is
fixed at $H_0 = 1/\alpha$.
- Otherwise (Standard mode), $H_0$ varies at each step based on the
curvature of the most recent history pair.
References:
- Nocedal & Wright, Numerical Optimization (L-BFGS two-loop recursion).
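
    Example:
        A sketch of a driver loop with a force-based stop criterion (the
        0.05 eV/Å threshold is illustrative only, not part of this API)::

            opt_state = lbfgs_init(state, model)
            for _ in range(200):
                opt_state = lbfgs_step(opt_state, model, max_step=0.2)
                if opt_state.forces.norm(dim=-1).max() < 0.05:
                    break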
"""
from torch_sim.optimizers import CellLBFGSState
is_cell_state = isinstance(state, CellLBFGSState)
device, dtype = model.device, model.dtype
eps = 1e-8 if dtype == torch.float32 else 1e-16
n_systems = state.n_systems # S
# Derive max_atoms from history shape: [S, H, M, 3] or [S, H, M_ext, 3]
history_dim = state.s_history.shape[2] # M or M_ext
if is_cell_state:
max_atoms_ext = history_dim # M_ext = M + 3
max_atoms = max_atoms_ext - 3 # M
else:
max_atoms = history_dim # M
max_atoms_ext = max_atoms
# Create valid atom mask for per-system operations: [S, M]
atom_mask = torch.arange(max_atoms, device=device)[None] < state.max_atoms[:, None]
# Extended mask including cell DOFs: [S, M_ext]
if is_cell_state:
ext_mask = torch.cat(
[
atom_mask,
torch.ones((n_systems, 3), device=device, dtype=torch.bool),
],
dim=1,
) # [S, M_ext]
else:
ext_mask = atom_mask # [S, M]
if is_cell_state:
# Get current deformation gradient
# reference_cell.mT: [S, 3, 3], row_vector_cell: [S, 3, 3]
cur_deform_grad = cell_filters.deform_grad(
state.reference_cell.mT, state.row_vector_cell
) # [S, 3, 3]
# Transform forces to scaled coordinates
# forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3]
forces_scaled = torch.bmm(
state.forces.unsqueeze(1), # [N, 1, 3]
cur_deform_grad[state.system_idx], # [N, 3, 3]
).squeeze(1) # [N, 3]
# Current fractional positions
# positions: [N, 3] -> frac_positions: [N, 3]
frac_positions = torch.linalg.solve(
cur_deform_grad[state.system_idx], # [N, 3, 3]
state.positions.unsqueeze(-1), # [N, 3, 1]
).squeeze(-1) # [N, 3]
# Convert to padded per-system format: [S, M, 3]
g_atoms = _atoms_to_padded(-forces_scaled, state.system_idx, n_systems, max_atoms)
# Cell forces: [S, 3, 3] -> [S, 3, 3]
g_cell = -state.cell_forces # [S, 3, 3]
# Extended gradient: [S, M_ext, 3] = [S, M+3, 3]
g = torch.cat([g_atoms, g_cell], dim=1) # [S, M_ext, 3]
else:
# Convert to padded per-system format: [S, M, 3]
g = _atoms_to_padded(-state.forces, state.system_idx, n_systems, max_atoms)
# Two-loop recursion to compute search direction d = -H_k g_k
# History shape: [S, H, M_ext, 3] or [S, H, M, 3]
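    # Recursion (Nocedal & Wright, Algorithm 7.4), vectorized over systems:
    #   q <- g
    #   newest..oldest:  alpha_i = rho_i <s_i, q>;  q <- q - alpha_i y_i
    #   z <- gamma q                      (gamma plays the role of H0)
    #   oldest..newest:  beta = rho_i <y_i, z>;  z <- z + (alpha_i - beta) s_i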
cur_history_len = state.s_history.shape[1] # H
q = g.clone() # [S, M_ext, 3] or [S, M, 3]
alphas: list[torch.Tensor] = [] # list of [S] tensors
# First loop (from newest to oldest)
for i in range(cur_history_len - 1, -1, -1):
s_i = state.s_history[:, i] # [S, M_ext, 3] or [S, M, 3]
y_i = state.y_history[:, i] # [S, M_ext, 3] or [S, M, 3]
# ys = y^T s per system: [S]
ys = _per_system_vdot(y_i, s_i, ext_mask) # [S]
rho = torch.where(
ys.abs() > curvature_eps,
1.0 / (ys + eps),
torch.zeros_like(ys),
) # [S]
sq = _per_system_vdot(s_i, q, ext_mask) # [S]
alpha = rho * sq # [S]
alphas.append(alpha)
# q <- q - alpha * y_i (broadcast alpha to [S, 1, 1])
q = q - alpha.view(-1, 1, 1) * y_i # [S, M_ext, 3]
# Initial H0 scaling: gamma = (s^T y)/(y^T y) using the last pair
if cur_history_len > 0:
s_last = state.s_history[:, -1] # [S, M_ext, 3]
y_last = state.y_history[:, -1] # [S, M_ext, 3]
sy = _per_system_vdot(s_last, y_last, ext_mask) # [S]
yy = _per_system_vdot(y_last, y_last, ext_mask) # [S]
gamma_dynamic = torch.where(
yy.abs() > curvature_eps,
sy / (yy + eps),
torch.ones_like(yy),
) # [S]
else:
gamma_dynamic = torch.ones((n_systems,), device=device, dtype=dtype) # [S]
# Fixed gamma (ASE style: 1/alpha)
# If state.alpha > 0, use that. Else use dynamic.
is_fixed = state.alpha > 1e-6 # [S] bool
gamma_fixed = 1.0 / (state.alpha + eps) # [S]
gamma = torch.where(is_fixed, gamma_fixed, gamma_dynamic) # [S]
# z = gamma * q (broadcast gamma to [S, 1, 1])
z = gamma.view(-1, 1, 1) * q # [S, M_ext, 3]
# Second loop (from oldest to newest)
for i in range(cur_history_len):
s_i = state.s_history[:, i] # [S, M_ext, 3]
y_i = state.y_history[:, i] # [S, M_ext, 3]
ys = _per_system_vdot(y_i, s_i, ext_mask) # [S]
rho = torch.where(
ys.abs() > curvature_eps,
1.0 / (ys + eps),
torch.zeros_like(ys),
) # [S]
yz = _per_system_vdot(y_i, z, ext_mask) # [S]
beta = rho * yz # [S]
alpha_i = alphas[cur_history_len - 1 - i] # [S]
# z <- z + s_i * (alpha - beta)
coeff = (alpha_i - beta).view(-1, 1, 1) # [S, 1, 1]
z = z + s_i * coeff # [S, M_ext, 3]
d = -z # search direction: [S, M_ext, 3]
# Apply step_size scaling per system: [S, 1, 1]
step = state.step_size.view(-1, 1, 1) * d # [S, M_ext, 3]
# Per-system max norm (only over valid atoms/DOFs)
step_norms = torch.linalg.norm(step, dim=-1) # [S, M_ext]
step_norms = step_norms * ext_mask.float() # Zero out padded
sys_max = step_norms.max(dim=1).values # [S]
# Scaling factors per system: <= 1.0
scale = torch.where(
sys_max > max_step,
max_step / (sys_max + eps),
torch.ones_like(sys_max),
) # [S]
step = scale.view(-1, 1, 1) * step # [S, M_ext, 3]
# Split step into position and cell components
if is_cell_state:
step_padded = step[:, :max_atoms] # [S, M, 3]
step_cell = step[:, max_atoms:] # [S, 3, 3]
# Convert padded step to atom-level
step_positions = _padded_to_atoms(step_padded, state.system_idx)
else:
step_padded = step # [S, M, 3]
step_positions = _padded_to_atoms(step_padded, state.system_idx)
# Save previous state for history update
# For cell state: store fractional positions and scaled forces (ASE convention)
if is_cell_state:
state.prev_positions = frac_positions.clone() # [N, 3] (fractional)
state.prev_forces = forces_scaled.clone() # [N, 3] (scaled)
state.prev_cell_positions = state.cell_positions.clone() # [S, 3, 3]
state.prev_cell_forces = state.cell_forces.clone() # [S, 3, 3]
# Apply cell step
dr_cell = step_cell # [S, 3, 3]
cell_positions_new = state.cell_positions + dr_cell # [S, 3, 3]
state.cell_positions = cell_positions_new # [S, 3, 3]
# Determine if Frechet filter
init_fn, _step_fn = state.cell_filter
is_frechet = init_fn is frechet_cell_filter_init
if is_frechet:
# Frechet: deform_grad = exp(cell_positions / cell_factor)
cell_factor_reshaped = state.cell_factor.view(n_systems, 1, 1)
deform_grad_log_new = cell_positions_new / cell_factor_reshaped # [S, 3, 3]
deform_grad_new = torch.matrix_exp(deform_grad_log_new) # [S, 3, 3]
else:
# UnitCell: deform_grad = cell_positions / cell_factor
cell_factor_expanded = state.cell_factor.expand(n_systems, 3, 1)
deform_grad_new = cell_positions_new / cell_factor_expanded # [S, 3, 3]
            # Update cell (column-vector convention): new_cell = deform_grad @ reference_cell,
            # equivalent to ASE's row-vector update new_cell = reference_cell @ deform_grad^T
            # Use set_constrained_cell to apply cell constraints (e.g. FixSymmetry)
new_col_vector_cell = torch.bmm(
deform_grad_new, state.reference_cell
) # [S, 3, 3]
state.set_constrained_cell(new_col_vector_cell, scale_atoms=True)
# Apply position step in fractional space, then convert to Cartesian
new_frac = frac_positions + step_positions # [N, 3]
new_deform_grad = cell_filters.deform_grad(
state.reference_cell.mT, state.row_vector_cell
) # [S, 3, 3]
# new_positions = new_frac @ deform_grad^T
new_positions = torch.bmm(
new_frac.unsqueeze(1), # [N, 1, 3]
new_deform_grad[state.system_idx].transpose(-2, -1), # [N, 3, 3]
).squeeze(1) # [N, 3]
state.set_constrained_positions(new_positions) # [N, 3]
else:
state.prev_positions = state.positions.clone() # [N, 3]
state.prev_forces = state.forces.clone() # [N, 3]
state.set_constrained_positions(state.positions + step_positions) # [N, 3]
# Evaluate new forces/energy
model_output = model(state)
new_forces = model_output["forces"] # [N, 3]
new_energy = model_output["energy"] # [S]
new_stress = model_output.get("stress") # [S, 3, 3] or None
# Update cell forces for next step: [S, 3, 3]
if is_cell_state:
cell_filters.compute_cell_forces(model_output, state)
# Update state
state.set_constrained_forces(new_forces) # [N, 3]
state.energy = new_energy # [S]
state.stress = new_stress # [S, 3, 3] or None
# Build new (s, y) for history in per-system format [S, M_ext, 3] or [S, M, 3]
# s = position difference, y = gradient difference
if is_cell_state:
# Get new scaled forces and fractional positions for history
new_deform_grad = cell_filters.deform_grad(
state.reference_cell.mT, state.row_vector_cell
) # [S, 3, 3]
# new_forces: [N, 3] -> new_forces_scaled: [N, 3]
new_forces_scaled = torch.bmm(
new_forces.unsqueeze(1), # [N, 1, 3]
new_deform_grad[state.system_idx], # [N, 3, 3]
).squeeze(1) # [N, 3]
# positions: [N, 3] -> new_frac_positions: [N, 3]
new_frac_positions = torch.linalg.solve(
new_deform_grad[state.system_idx], # [N, 3, 3]
state.positions.unsqueeze(-1), # [N, 3, 1]
).squeeze(-1) # [N, 3]
# s_new_pos = frac_pos_new - frac_pos_old: [N, 3] -> [S, M, 3]
s_new_pos_atoms = new_frac_positions - state.prev_positions # [N, 3]
s_new_pos = _atoms_to_padded(
s_new_pos_atoms, state.system_idx, n_systems, max_atoms
) # [S, M, 3]
# s_new_cell = cell_pos_new - cell_pos_old: [S, 3, 3]
s_new_cell = state.cell_positions - state.prev_cell_positions # [S, 3, 3]
# Concatenate to extended format: [S, M_ext, 3]
s_new = torch.cat([s_new_pos, s_new_cell], dim=1) # [S, M_ext, 3]
# y_new = grad_diff for positions and cell (gradient = -forces)
# y = grad_new - grad_old = -forces_new - (-forces_old) = forces_old - forces_new
y_new_pos_atoms = -new_forces_scaled - (-state.prev_forces) # [N, 3]
y_new_pos = _atoms_to_padded(
y_new_pos_atoms, state.system_idx, n_systems, max_atoms
) # [S, M, 3]
y_new_cell = -state.cell_forces - (-state.prev_cell_forces) # [S, 3, 3]
y_new = torch.cat([y_new_pos, y_new_cell], dim=1) # [S, M_ext, 3]
else:
# s_new = pos_new - pos_old: [N, 3] -> [S, M, 3]
s_new_atoms = state.positions - state.prev_positions # [N, 3]
s_new = _atoms_to_padded(
s_new_atoms, state.system_idx, n_systems, max_atoms
) # [S, M, 3]
# y_new = grad_diff: [N, 3] -> [S, M, 3]
y_new_atoms = -new_forces - (-state.prev_forces) # [N, 3]
y_new = _atoms_to_padded(
y_new_atoms, state.system_idx, n_systems, max_atoms
) # [S, M, 3]
# Append history and trim if needed
# Note: ASE's L-BFGS doesn't have a curvature check for adding to history.
# Invalid curvatures are handled in the two-loop by checking rho.
# History tensors: [S, H, M_ext, 3] or [S, H, M, 3]
cur_history_len = state.s_history.shape[1] # H
if cur_history_len == 0:
# First entry: [S, 1, M_ext, 3] or [S, 1, M, 3]
s_hist = s_new.unsqueeze(1) # [S, 1, M_ext, 3]
y_hist = y_new.unsqueeze(1) # [S, 1, M_ext, 3]
else:
# Append new entry: [S, H, ...] cat [S, 1, ...] -> [S, H+1, ...]
s_hist = torch.cat([state.s_history, s_new.unsqueeze(1)], dim=1)
y_hist = torch.cat([state.y_history, y_new.unsqueeze(1)], dim=1)
# Trim to max_history
if s_hist.shape[1] > max_history:
s_hist = s_hist[:, -max_history:] # [S, max_history, ...]
y_hist = y_hist[:, -max_history:]
if is_cell_state:
# Store fractional/scaled for next iteration
state.prev_positions = new_frac_positions.clone() # [N, 3] (fractional)
state.prev_forces = new_forces_scaled.clone() # [N, 3] (scaled)
else:
state.prev_forces = new_forces.clone() # [N, 3]
state.prev_positions = state.positions.clone() # [N, 3]
state.s_history = s_hist # [S, H, M_ext, 3] or [S, H, M, 3]
state.y_history = y_hist # [S, H, M_ext, 3] or [S, H, M, 3]
state.n_iter = state.n_iter + 1 # [S]
return state