Source code for torch_sim.monte_carlo

"""Propagators for Monte Carlo simulations.

This module provides functionality for performing Monte Carlo simulations,
particularly focused on swap Monte Carlo for atomic systems. It includes
implementations of the Metropolis criterion, swap generation, and utility
functions for handling permutations in batched systems.

The `swap_mc_init` and `swap_mc_step` functions can be used
with `integrate` but if a trajectory is being reported, the
`TorchSimTrajectory.write_state` method must be called with `variable_masses=True`.

Examples:
    >>> import torch_sim as ts
    >>> mc_state = ts.swap_mc_init(initial_state, model)
    >>> rng = torch.Generator(device=model.device).manual_seed(42)
    >>> for _ in range(1000):
    ...     mc_state = ts.swap_mc_step(mc_state, model, kT=0.1 * units.energy, rng=rng)
"""

from dataclasses import dataclass

import torch

from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState


@dataclass(kw_only=True)
class SwapMCState(SimState):
    """Simulation state carried through a swap Monte Carlo run.

    Extends ``SimState`` with the bookkeeping that swap Monte Carlo needs on
    top of the base state: the current energy and the permutation that the
    most recent step applied.

    Attributes:
        energy (torch.Tensor): Energy of the system with shape [batch_size]
        last_permutation (torch.Tensor): Last permutation applied to the
            system, with shape [n_atoms], tracking the moves made for
            analysis or reversal
    """

    energy: torch.Tensor
    last_permutation: torch.Tensor

    # Extend the parent's attribute sets with the two fields added here:
    # ``energy`` joins the system attributes, ``last_permutation`` the atom ones.
    _system_attributes = SimState._system_attributes | {"energy"}  # noqa: SLF001
    _atom_attributes = SimState._atom_attributes | {"last_permutation"}  # noqa: SLF001
def generate_swaps(state: SimState, rng: torch.Generator | None = None) -> torch.Tensor:
    """Generate atom swaps for a given batched system.

    Proposes exactly one swap per system. For every system a first atom is
    drawn uniformly, then a second atom is drawn uniformly among the atoms of
    the same system whose atomic number differs from the first one.

    Systems may contain different numbers of atoms (ragged batches). The
    per-system slices are derived from ``system_idx.bincount()``, so atoms of
    one system are expected to be stored contiguously and in increasing
    system order.

    Args:
        state (SimState): The simulation state; only ``system_idx`` and
            ``atomic_numbers`` are read.
        rng (torch.Generator | None, optional): Random number generator for
            reproducibility. Defaults to None.

    Returns:
        torch.Tensor: A tensor of proposed swaps with shape [n_systems, 2],
        where each row contains the global indices of the two atoms to swap.

    Notes:
        If a system contains a single species, every candidate weight for the
        second draw is zero and ``torch.multinomial`` cannot sample from that
        row (on CPU this surfaces as a ``RuntimeError``).
    """
    system_idx = state.system_idx
    atomic_numbers = state.atomic_numbers
    device = system_idx.device

    system_lengths = system_idx.bincount()
    n_systems = len(system_lengths)
    max_length = int(torch.max(system_lengths).item())

    # Global index of the first atom of each system: an *exclusive* prefix sum.
    # Subtracting system_lengths[0] instead is only right when all systems
    # have the same size.
    system_starts = system_lengths.cumsum(dim=0) - system_lengths

    # Padded [n_systems, max_length] weights: 1 for real atoms, 0 for padding.
    slots = torch.arange(max_length, device=device).expand(n_systems, max_length)
    weights = (slots < system_lengths.unsqueeze(1)).float()

    first_index = torch.multinomial(weights, 1, replacement=False, generator=rng)

    # Species of the atom picked in each system, broadcast back onto the atoms.
    first_type = atomic_numbers[first_index.squeeze(1) + system_starts]
    same_type = atomic_numbers == first_type[system_idx]

    # Position of every atom inside its own system (column in ``weights``).
    local_idx = torch.arange(len(system_idx), device=device) - system_starts[system_idx]

    # Atoms sharing the picked species (the picked atom included) may not be
    # chosen as swap partner.
    weights[system_idx[same_type], local_idx[same_type]] = 0.0

    second_index = torch.multinomial(weights, 1, replacement=False, generator=rng)

    # Convert per-system local indices into global atom indices.
    local_swaps = torch.cat([first_index, second_index], dim=1)
    return local_swaps + system_starts.unsqueeze(1)
def swaps_to_permutation(swaps: torch.Tensor, n_atoms: int) -> torch.Tensor:
    """Turn a list of pairwise swaps into a single permutation tensor.

    Starts from the identity permutation and applies the swaps one after
    another, in the order in which they appear in ``swaps``.

    Args:
        swaps (torch.Tensor): Tensor of shape [n_swaps, 2] containing pairs
            of indices to swap
        n_atoms (int): Total number of atoms in the system

    Returns:
        torch.Tensor: Permutation tensor of shape [n_atoms] where
        permutation[i] contains the index of the atom that should be moved
        to position i
    """
    perm = torch.arange(n_atoms, device=swaps.device)
    for left, right in swaps:
        # Clone the left entry before overwriting it so the exchange is exact.
        held = perm[left].clone()
        perm[left] = perm[right]
        perm[right] = held
    return perm
[docs] def metropolis_criterion( energy_new: torch.Tensor, energy_old: torch.Tensor, kT: float, rng: torch.Generator | None = None, ) -> torch.Tensor: """Apply the Metropolis acceptance criterion for Monte Carlo moves. Determines whether proposed moves should be accepted or rejected based on the energy difference and system temperature, following the Boltzmann distribution. Args: energy_new (torch.Tensor): New energy after proposed move of shape [batch_size] energy_old (torch.Tensor): Old energy before proposed move of shape [batch_size] kT (float): Temperature of the system in energy units rng (torch.Generator | None, optional): Random number generator for reproducibility. Defaults to None. Returns: torch.Tensor: Boolean tensor of shape [batch_size] indicating acceptance (True) or rejection (False) for each move Notes: The acceptance probability follows min(1, exp(-ΔE/kT)) according to the standard Metropolis algorithm. """ delta_e = energy_new - energy_old # Calculate acceptance probability: min(1, exp(-ΔE/kT)) p_acceptance = torch.clamp(torch.exp(-delta_e / kT), max=1.0) # Generate random numbers between 0 and 1 using the generator random_values = torch.rand( p_acceptance.shape, generator=rng, device=p_acceptance.device ) # Accept if random value < acceptance probability return random_values < p_acceptance
def swap_mc_init(
    state: SimState,
    model: ModelInterface,
) -> SwapMCState:
    """Build the initial swap Monte Carlo state from a simulation state.

    Evaluates the model once on ``state`` to obtain the starting energy and
    sets the permutation record to the identity. Positions, masses, cell,
    periodic boundary flags, atomic numbers and system indices are taken
    over from ``state``.

    Make sure that if the trajectory is being reported, the
    `TorchSimTrajectory.write_state` method is called with
    `variable_masses=True`.

    Args:
        state: The simulation state to initialize from
        model: Energy model that takes a SimState and returns a dict
            containing 'energy' as a key

    Returns:
        SwapMCState: Initialized state for swap Monte Carlo simulation
        containing positions, energy, and permutation tracking

    Examples:
        >>> mc_state = swap_mc_init(state=initial_state, model=energy_model)
        >>> for _ in range(100):
        >>>     mc_state = swap_mc_step(mc_state, energy_model, kT=0.1)
    """
    initial_energy = model(state)["energy"]
    # No swap has been applied yet, so the record is the identity permutation.
    identity = torch.arange(state.n_atoms, device=state.device)

    return SwapMCState(
        positions=state.positions,
        masses=state.masses,
        cell=state.cell,
        pbc=state.pbc,
        atomic_numbers=state.atomic_numbers,
        system_idx=state.system_idx,
        energy=initial_energy,
        last_permutation=identity,
    )
def swap_mc_step(
    state: SwapMCState,
    model: ModelInterface,
    *,
    kT: float,
    seed: int | None = None,
    rng: torch.Generator | None = None,
) -> SwapMCState:
    """Advance the Monte Carlo state by one swap move per system.

    A swap is proposed in every system, the positions are permuted
    accordingly, the model is re-evaluated, and the Metropolis criterion
    decides per system whether the swap is kept. Swaps that are rejected are
    undone before returning.

    ``state`` is modified in place (positions, energy and last_permutation
    are reassigned) and the same object is returned.

    Args:
        state: The current Monte Carlo state
        model: Energy model that takes a SimState and returns a dict
            containing 'energy' as a key
        kT: Temperature parameter in energy units
        seed: (Deprecated) Seed for the random number generator. If provided
            and `rng` is None, a temporary generator seeded with this value
            is created for this call.
        rng: Optional torch.Generator to drive all randomness for this step.
            Prefer passing a persistent generator across steps for
            reproducibility.

    Returns:
        SwapMCState: Updated Monte Carlo state after applying the step

    Raises:
        ValueError: If the proposed permutation would move an atom into a
            different system.
    """
    # An explicit generator wins; fall back to a throwaway seeded one.
    generator = rng
    if generator is None and seed is not None:
        generator = torch.Generator(device=model.device).manual_seed(seed)

    n_atoms = state.n_atoms
    proposed_swaps = generate_swaps(state, rng=generator)
    proposed_perm = swaps_to_permutation(proposed_swaps, n_atoms)

    if torch.any(state.system_idx != state.system_idx[proposed_perm]):
        raise ValueError("Swaps must be between atoms in the same system")

    previous_energy = state.energy.clone()

    # Tentatively apply every proposed swap and evaluate the new energies.
    state.positions = state.positions[proposed_perm].clone()
    trial_energy = model(state)["energy"]

    keep = metropolis_criterion(trial_energy, previous_energy, kT, rng=generator)

    # Re-applying a rejected swap restores the original ordering for that system.
    undo_perm = swaps_to_permutation(proposed_swaps[~keep], n_atoms)
    state.positions = state.positions[undo_perm]

    state.energy = torch.where(keep, trial_energy, previous_energy)
    state.last_permutation = proposed_perm[undo_perm].clone()
    return state