Source code for torch_sim.math

"""Mathematical operations and utilities. Adapted from https://github.com/abhijeetgangan/torch_matfunc."""

# ruff: noqa: FBT001, FBT002

from typing import Final

import torch

from torch_sim._duecredit import dcite


@torch.jit.script
def torch_divmod(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the floored quotient and the matching remainder of ``a`` over ``b``.

    Args:
        a: Dividend tensor
        b: Divisor tensor

    Returns:
        tuple containing:
            - Quotient tensor (division rounded with ``rounding_mode="floor"``)
            - Remainder tensor
    """
    quotient = torch.div(a, b, rounding_mode="floor")
    remainder = torch.remainder(a, b)
    return quotient, remainder


[docs] def expm_frechet( # noqa: C901 A: torch.Tensor, E: torch.Tensor, method: str | None = None, check_finite: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Frechet derivative of the matrix exponential of A in the direction E. Optimized for batched 3x3 matrices. Also handles single 3x3 matrices by auto-adding a batch dimension. Method notes: - ``SPS`` uses scaling-Pade-squaring for the matrix exponential and its Frechet derivative. See :func:`expm_frechet_sps`. - ``BE`` uses the block matrix identity exp([[A, E], [0, A]]) = [[exp(A), L_exp(A, E)], [0, exp(A)]]. See :func:`expm_frechet_block_enlarge`. Args: A: (B, 3, 3) or (3, 3) tensor. Matrix of which to take the matrix exponential. E: (B, 3, 3) or (3, 3) tensor. Matrix direction in which to take the Frechet derivative. Must have same shape as A. method: str, optional. Choice of algorithm. Should be one of - `SPS` - Scaling-Pade-squaring (default) - `BE` - Block-enlarge check_finite: bool, optional. Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs. Returns: tuple[torch.Tensor, torch.Tensor]: A tuple containing: expm_A: Matrix exponential of A. expm_frechet_AE: Frechet derivative of the matrix exponential of A in the direction E. 
""" if check_finite: if not torch.isfinite(A).all(): raise ValueError("Matrix A contains non-finite values") if not torch.isfinite(E).all(): raise ValueError("Matrix E contains non-finite values") # Convert inputs to torch tensors if they aren't already if not isinstance(A, torch.Tensor): A = torch.tensor(A, dtype=torch.float64) if not isinstance(E, torch.Tensor): E = torch.tensor(E, dtype=torch.float64) if A.shape != E.shape: raise ValueError("expected A and E to be the same shape") if method is None: method = "SPS" if method in ["BE", "blockEnlarge"]: # "blockEnlarge" is deprecated if A.dim() != 3 or A.shape[1] != A.shape[2]: raise ValueError("expected A to be (B, N, N)") return expm_frechet_block_enlarge(A, E) if method == "SPS": return expm_frechet_sps(A, E) raise ValueError(f"Unknown {method=}")
def matrix_exp(A: torch.Tensor) -> torch.Tensor:
    """Return exp(A) as evaluated by PyTorch's built-in matrix exponential.

    Args:
        A: Input matrix

    Returns:
        torch.Tensor: Matrix exponential of A
    """
    exp_of_A = torch.linalg.matrix_exp(A)
    return exp_of_A
@dcite("10.1137/080716426")
def expm_frechet_sps(
    A: torch.Tensor, E: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scaling-Pade-squaring helper for Frechet derivative of exp(A) on 3x3 matrices.

    A single squaring count is derived from the largest matrix norm in the batch
    and applied to every matrix, a degree-13 Pade approximant is evaluated for
    both the exponential and its derivative, and the scaling is then undone by
    repeated squaring.

    Args:
        A: (B, 3, 3) or (3, 3) tensor to exponentiate.
        E: Direction tensor; unsqueezed alongside A when A is (3, 3).

    Returns:
        Tuple ``(R, L)`` where R approximates exp(A) and L the Frechet
        derivative in direction E. A leading batch axis is removed again when
        the input A was (3, 3).

    Raises:
        ValueError: If A is neither (3, 3) nor (B, 3, 3).

    References:
        - Awad H. Al-Mohy and Nicholas J. Higham (2009), "Computing the Fréchet
          Derivative of the Matrix Exponential, with an Application to Condition
          Number Estimation", SIAM J. Matrix Anal. Appl. 30(4):1639-1657.
          https://doi.org/10.1137/080716426
    """
    # Promote a lone 3x3 matrix to a batch of one
    unbatched = A.dim() == 2
    if unbatched:
        if A.shape != (3, 3):
            raise ValueError("expected A to be (3, 3) or (B, 3, 3)")
        A, E = A.unsqueeze(0), E.unsqueeze(0)

    if A.dim() != 3 or A.shape[1:] != (3, 3):
        raise ValueError("expected A, E to be (B, 3, 3) with same shape")

    n_batch = A.shape[0]
    eye_b = (
        torch.eye(3, device=A.device, dtype=A.dtype).unsqueeze(0).expand(n_batch, 3, 3)
    )

    # One squaring count for the whole batch, taken from the largest norm
    norms = torch.norm(A, p=1, dim=(-2, -1))
    log2_ratio = torch.log2(
        torch.clamp(norms.max() / ell_table_61[13], min=1.0, max=2.0**64)
    )
    s = max(0, min(int(torch.ceil(log2_ratio).item()), 64))
    shrink = 2.0**-s
    A = A * shrink
    E = E * shrink

    # Even powers of A and their directional derivatives
    A2 = A @ A
    M2 = A @ E + E @ A
    A4 = A2 @ A2
    M4 = A2 @ M2 + M2 @ A2
    A6 = A2 @ A4
    M6 = A4 @ M2 + M4 @ A2

    # Degree-13 Pade coefficients
    b = (
        64764752532480000.0,
        32382376266240000.0,
        7771770303897600.0,
        1187353796428800.0,
        129060195264000.0,
        10559470521600.0,
        670442572800.0,
        33522128640.0,
        1323241920.0,
        40840800.0,
        960960.0,
        16380.0,
        182.0,
        1.0,
    )

    # Odd (U) and even (V) parts of the approximant
    W1 = b[13] * A6 + b[11] * A4 + b[9] * A2
    W2 = b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * eye_b
    Z1 = b[12] * A6 + b[10] * A4 + b[8] * A2
    Z2 = b[6] * A6 + b[4] * A4 + b[2] * A2 + b[0] * eye_b
    W = A6 @ W1 + W2
    U = A @ W
    V = A6 @ Z1 + Z2

    # Their derivatives in the direction E
    Lw1 = b[13] * M6 + b[11] * M4 + b[9] * M2
    Lw2 = b[7] * M6 + b[5] * M4 + b[3] * M2
    Lz1 = b[12] * M6 + b[10] * M4 + b[8] * M2
    Lz2 = b[6] * M6 + b[4] * M4 + b[2] * M2
    Lw = A6 @ Lw1 + M6 @ W1 + Lw2
    Lu = A @ Lw + E @ W
    Lv = A6 @ Lz1 + M6 @ Z1 + Lz2

    lhs = -U + V
    R = torch.linalg.solve(lhs, U + V)
    L = torch.linalg.solve(lhs, Lu + Lv + (Lu - Lv) @ R)

    # Undo the scaling; both right-hand sides use the pre-update R
    for _ in range(s):
        L, R = R @ L + L @ R, R @ R

    return (R.squeeze(0), L.squeeze(0)) if unbatched else (R, L)
@dcite("10.1137/1.9780898717778")
def expm_frechet_block_enlarge(
    A: torch.Tensor, E: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Block-enlarge helper for Frechet derivative via matrix exponential.

    Assembles the block matrix M = [[A, E], [0, A]], exponentiates it, and reads
    the two results off the top block row:
        - exp(A) is the top-left block
        - L_exp(A, E) is the top-right block

    Reference:
        Nicholas J. Higham (2008), "Functions of Matrices: Theory and
        Computation", SIAM. (Frechet derivative block-matrix identity.)

    Args:
        A: (B, N, N) Batch of input matrices.
        E: (B, N, N) Batch of direction matrices. Must have same shape as A.

    Returns:
        expm_A: Matrix exponential of A
        expm_frechet_AE: torch.Tensor
            Frechet derivative of the matrix exponential of A in the direction E
    """
    n_batch, dim, _ = A.shape
    head, tail = slice(None, dim), slice(dim, None)

    # (B, 2N, 2N) block upper-triangular matrix; the lower-left block stays zero
    enlarged = torch.zeros(
        (n_batch, 2 * dim, 2 * dim), dtype=A.dtype, device=A.device
    )
    enlarged[:, head, head] = A
    enlarged[:, head, tail] = E
    enlarged[:, tail, tail] = A

    # matrix_exp accepts the batched block matrix directly
    exp_enlarged = matrix_exp(enlarged)
    return exp_enlarged[:, head, head], exp_enlarged[:, head, tail]
# Maximal values ell_m of ||2**-s A|| such that the backward error bound
# does not exceed 2**-53.
# Indexed by m; index 0 is a placeholder so that ell_table_61[m] lines up with m.
ell_table_61: Final = (
    None,
    # 1
    2.11e-8,
    3.56e-4,
    1.08e-2,
    6.49e-2,
    2.00e-1,
    4.37e-1,
    7.83e-1,
    1.23e0,
    1.78e0,
    2.42e0,
    # 11
    3.13e0,
    3.90e0,
    4.74e0,
    5.63e0,
    6.56e0,
    7.52e0,
    8.53e0,
    9.56e0,
    1.06e1,
    1.17e1,
)


def _identity_for_t(
    T: torch.Tensor, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Return identity (3, 3) or (n, 3, 3) matching T's batch shape.

    Args:
        T: Matrix of shape (3, 3) or (n, 3, 3); only its rank and leading
            dimension are inspected
        dtype: dtype of the returned identity
        device: Device of the returned identity

    Returns:
        A (3, 3) identity, or an expanded (n, 3, 3) view of one when T is 3-D
    """
    if T.dim() == 3:
        n = T.shape[0]
        # expand() returns a broadcast view rather than n separate copies
        return torch.eye(3, dtype=dtype, device=device).unsqueeze(0).expand(n, -1, -1)
    return torch.eye(3, dtype=dtype, device=device)


def _matrix_log_case1a(T: torch.Tensor, lambda_val: torch.Tensor) -> torch.Tensor:
    """Compute log(T) when q(T) = (T - λI).

    This is the case where T is a scalar multiple of the identity matrix.
    T may be (3, 3) or (n, 3, 3); lambda_val scalar or (n, 1, 1).

    Args:
        T: The matrix whose logarithm is to be computed
        lambda_val: The eigenvalue of T as a tensor

    Returns:
        The logarithm of T, which is log(λ)·I
    """
    # dtype and device of the result follow lambda_val, not T
    dtype, device = lambda_val.dtype, lambda_val.device
    identity = _identity_for_t(T, dtype, device)
    return torch.log(lambda_val) * identity


def _matrix_log_case1b(
    T: torch.Tensor, lambda_val: torch.Tensor, num_tol: float = 1e-16
) -> torch.Tensor:
    """Compute log(T) when q(T) = (T - λI)².

    This is the case where T has a Jordan block of size 2.
    T may be (3, 3) or (n, 3, 3); lambda_val scalar or (n, 1, 1).

    Computes log(λ)·I + (T - λI) / scale, where scale is λ itself when |λ| > 1
    and otherwise |λ| clamped from below by ``num_tol``.

    Args:
        T: The matrix whose logarithm is to be computed
        lambda_val: The eigenvalue of T
        num_tol: Numerical tolerance for stability checks, default=1e-16

    Returns:
        The logarithm of T
    """
    dtype, device = lambda_val.dtype, lambda_val.device
    identity = _identity_for_t(T, dtype, device)
    T_minus_lambdaI = T - lambda_val * identity
    # Lower bound on |λ| guards the division below against a vanishing eigenvalue
    denom = torch.clamp(lambda_val.abs(), min=num_tol)
    # NOTE(review): for |λ| <= 1 the divisor is |λ|, so the sign of a negative
    # λ is not carried into the second term — confirm this is intended.
    scale = torch.where(lambda_val.abs() > 1, lambda_val, denom)
    return torch.log(lambda_val) * identity + T_minus_lambdaI / scale


def _ensure_batched(
    T: torch.Tensor, *eigenvalues: torch.Tensor
) -> tuple[bool, torch.Tensor, tuple[torch.Tensor, ...]]:
    """Ensure T and eigenvalues are in batched form for matrix log computation.

    Batched input is passed through untouched; only a 2-D T triggers reshaping.

    Args:
        T: Matrix of shape (3, 3) or (n, 3, 3)
        *eigenvalues: Scalar or (n, 1, 1) shaped eigenvalue tensors

    Returns:
        Tuple of (unbatched, T, eigenvalues) where unbatched is True if input was
        unbatched, T has shape (n, 3, 3), and eigenvalues have shape (n, 1, 1)
    """
    unbatched = T.dim() == 2
    if unbatched:
        T = T.unsqueeze(0)
        # view(1, 1, 1) requires each eigenvalue to hold exactly one element
        eigenvalues = tuple(ev.view(1, 1, 1) for ev in eigenvalues)
    return unbatched, T, eigenvalues


def _matrix_log_case1c(
    T: torch.Tensor, lambda_val: torch.Tensor, num_tol: float = 1e-16
) -> torch.Tensor:
    """Compute log(T) when q(T) = (T - λI)³.

    This is the case where T has a Jordan block of size 3.
    T may be (3, 3) or (n, 3, 3); lambda_val scalar or (n, 1, 1).

    Computes log(λ)·I + (T - λI)/|λ| - (T - λI)²/(2λ²), with both divisors
    clamped from below by ``num_tol``.

    Args:
        T: The matrix whose logarithm is to be computed
        lambda_val: The eigenvalue of T
        num_tol: Numerical tolerance for stability checks, default=1e-16

    Returns:
        The logarithm of T, with the batch axis removed again if T was (3, 3)
    """
    unbatched, T, (lambda_val,) = _ensure_batched(T, lambda_val)
    dtype, device = lambda_val.dtype, lambda_val.device
    identity = _identity_for_t(T, dtype, device)
    T_minus_lambdaI = T - lambda_val * identity
    T_minus_lambdaI_squared = torch.bmm(T_minus_lambdaI, T_minus_lambdaI)
    lambda_squared = lambda_val * lambda_val
    term1 = torch.log(lambda_val) * identity
    # NOTE(review): the divisor uses |λ|, dropping the sign of a negative λ —
    # confirm this is intended.
    term2 = T_minus_lambdaI / torch.clamp(lambda_val.abs(), min=num_tol)
    term3 = T_minus_lambdaI_squared / torch.clamp(2 * lambda_squared, min=num_tol)
    result = term1 + term2 - term3
    return result.squeeze(0) if unbatched else result


def _matrix_log_case2a(
    T: torch.Tensor, lambda_val: torch.Tensor, mu: torch.Tensor, num_tol: float = 1e-16
) -> torch.Tensor:
    """Compute log(T) when q(T) = (T - λI)(T - μI) with λ≠μ.

    This is the case with two distinct eigenvalues.
    T may be (3, 3) or (n, 3, 3); lambda_val, mu scalar or (n, 1, 1).

    Formula:
        log T = log μ((T - λI)/(μ - λ)) + log λ((T - μI)/(λ - μ))

    Args:
        T: The matrix whose logarithm is to be computed
        lambda_val: The repeated eigenvalue of T
        mu: The non-repeated eigenvalue of T
        num_tol: Numerical tolerance for stability checks, default=1e-16

    Returns:
        The logarithm of T, with the batch axis removed again if T was (3, 3)

    Raises:
        ValueError: If λ and μ are too close
    """
    unbatched, T, (lambda_val, mu) = _ensure_batched(T, lambda_val, mu)
    dtype, device = lambda_val.dtype, lambda_val.device
    identity = _identity_for_t(T, dtype, device)
    # One near-degenerate pair anywhere in the batch rejects the whole batch
    if (torch.abs(lambda_val - mu) < num_tol).any():
        raise ValueError("λ and μ are too close, computation may be unstable")
    T_minus_lambdaI = T - lambda_val * identity
    T_minus_muI = T - mu * identity
    term1 = torch.log(mu) * (T_minus_lambdaI / (mu - lambda_val))
    term2 = torch.log(lambda_val) * (T_minus_muI / (lambda_val - mu))
    result = term1 + term2
    return result.squeeze(0) if unbatched else result


def _matrix_log_case2b(
    T: torch.Tensor, lambda_val: torch.Tensor, mu: torch.Tensor, num_tol: float = 1e-16
) -> torch.Tensor:
    """Compute log(T) when q(T) = (T - μI)(T - λI)² with λ≠μ.

    This is the case with one eigenvalue of multiplicity 2 and one distinct.
    T may be (3, 3) or (n, 3, 3); lambda_val, mu scalar or (n, 1, 1).

    Formula:
        log T = log μ((T - λI)²/(λ - μ)²)
                - log λ((T - μI)(T - (2λ - μ)I)/(λ - μ)²)
                + ((T - λI)(T - μI)/(λ(λ - μ)))

    Args:
        T: The matrix whose logarithm is to be computed
        lambda_val: The repeated eigenvalue of T
        mu: The non-repeated eigenvalue of T
        num_tol: Numerical tolerance for stability checks, default=1e-16

    Returns:
        The logarithm of T, with the batch axis removed again if T was (3, 3)

    Raises:
        ValueError: If λ and μ are too close or λ≈0
    """
    unbatched, T, (lambda_val, mu) = _ensure_batched(T, lambda_val, mu)
    dtype, device = lambda_val.dtype, lambda_val.device
    identity = _identity_for_t(T, dtype, device)
    lambda_minus_mu = lambda_val - mu
    # Both quantities appear as divisors below, so reject near-zero values
    if (torch.abs(lambda_minus_mu) < num_tol).any():
        raise ValueError("λ and μ are too close, computation may be unstable")
    if (torch.abs(lambda_val) < num_tol).any():
        raise ValueError("λ is too close to zero, computation may be unstable")
    lambda_minus_mu_squared = lambda_minus_mu * lambda_minus_mu
    T_minus_lambdaI = T - lambda_val * identity
    T_minus_muI = T - mu * identity
    T_minus_lambdaI_squared = torch.bmm(T_minus_lambdaI, T_minus_lambdaI)
    T_minus_2lambda_plus_muI = T - (2 * lambda_val - mu) * identity
    term2_mat = torch.bmm(T_minus_muI, T_minus_2lambda_plus_muI)
    term1 = torch.log(mu) * (T_minus_lambdaI_squared / lambda_minus_mu_squared)
    term2 = -torch.log(lambda_val) * (term2_mat / lambda_minus_mu_squared)
    term3_mat = torch.bmm(T_minus_lambdaI, T_minus_muI)
    term3 = term3_mat / (lambda_val * lambda_minus_mu)
    result = term1 + term2 + term3
    return result.squeeze(0) if unbatched else result


def _matrix_log_case3(
    T: torch.Tensor,
    lambda_val: torch.Tensor,
    mu: torch.Tensor,
    nu: torch.Tensor,
    num_tol: float = 1e-16,
) -> torch.Tensor:
    """Compute log(T) when q(T) = (T - λI)(T - μI)(T - νI) with λ≠μ≠ν≠λ.

    This is the case with three distinct eigenvalues.
    T may be (3, 3) or (n, 3, 3); lambda_val, mu, nu scalar or (n, 1, 1).

    Formula:
        log T = log λ((T - μI)(T - νI)/((λ - μ)(λ - ν)))
                + log μ((T - λI)(T - νI)/((μ - λ)(μ - ν)))
                + log ν((T - λI)(T - μI)/((ν - λ)(ν - μ)))

    Args:
        T: The matrix whose logarithm is to be computed
        lambda_val: First eigenvalue of T
        mu: Second eigenvalue of T
        nu: Third eigenvalue of T
        num_tol: Numerical tolerance for stability checks, default=1e-16

    Returns:
        The logarithm of T, with the batch axis removed again if T was (3, 3)

    Raises:
        ValueError: If eigenvalues are too close
    """
    unbatched, T, (lambda_val, mu, nu) = _ensure_batched(T, lambda_val, mu, nu)
    dtype, device = lambda_val.dtype, lambda_val.device
    identity = _identity_for_t(T, dtype, device)
    # Smallest pairwise eigenvalue gap; every gap is a divisor further down
    min_diff = torch.minimum(
        torch.minimum(
            torch.abs(lambda_val - mu),
            torch.abs(lambda_val - nu),
        ),
        torch.abs(mu - nu),
    )
    if (min_diff < num_tol).any():
        raise ValueError("Eigenvalues are too close, computation may be unstable")
    T_minus_lambdaI = T - lambda_val * identity
    T_minus_muI = T - mu * identity
    T_minus_nuI = T - nu * identity
    lambda_term_num = torch.bmm(T_minus_muI, T_minus_nuI)
    lambda_term = torch.log(lambda_val) * (
        lambda_term_num / ((lambda_val - mu) * (lambda_val - nu))
    )
    mu_term_num = torch.bmm(T_minus_lambdaI, T_minus_nuI)
    mu_term = torch.log(mu) * (mu_term_num / ((mu - lambda_val) * (mu - nu)))
    nu_term_num = torch.bmm(T_minus_lambdaI, T_minus_muI)
    nu_term = torch.log(nu) * (nu_term_num / ((nu - lambda_val) * (nu - mu)))
    result = lambda_term + mu_term + nu_term
    return result.squeeze(0) if unbatched else result


def _determine_matrix_log_cases(
    T: torch.Tensor,
    sorted_eig: torch.Tensor,
    diff: torch.Tensor,
    n_unique: torch.Tensor,
    valid: torch.Tensor,
    num_tol: float,
) -> torch.Tensor:
    """Determine which matrix log case applies to each system.

    Systems not marked valid keep the fallback index -1.

    Args:
        T: Input matrices of shape (n_systems, 3, 3)
        sorted_eig: Sorted eigenvalues of shape (n_systems, 3)
        diff: Differences between consecutive eigenvalues (n_systems, 2)
        n_unique: Number of unique eigenvalues per system (n_systems,)
        valid: Boolean mask of valid systems (n_systems,)
        num_tol: Numerical tolerance

    Returns:
        Case indices: 0=case1a, 1=case1b, 2=case1c, 3=case2a, 4=case2b, 5=case3,
        -1=fallback
    """
    n_systems = T.shape[0]
    device, dtype_out = T.device, T.dtype
    # Start every system at the fallback index; valid systems are overwritten
    case_indices = torch.full((n_systems,), -1, dtype=torch.long, device=device)
    if not valid.any():
        return case_indices
    eye3 = torch.eye(3, dtype=dtype_out, device=device).unsqueeze(0)
    # Case 1: all eigenvalues equal
    m1 = valid & (n_unique == 1)
    if m1.any():
        lam = sorted_eig[:, 0:1].unsqueeze(-1)
        T_lam = T - lam * eye3
        # Ranks of (T - λI) and (T - λI)² separate the 1a / 1b / 1c sub-cases.
        # They are computed for the whole batch and then restricted by m1.
        rank1 = torch.linalg.matrix_rank(T_lam)
        rank2 = torch.linalg.matrix_rank(torch.bmm(T_lam, T_lam))
        case_indices.masked_fill_(m1 & (rank1 == 0), 0)
        case_indices.masked_fill_(m1 & (rank1 > 0) & (rank2 == 0), 1)
        case_indices.masked_fill_(m1 & (rank1 > 0) & (rank2 > 0), 2)
    # Case 2: two distinct eigenvalues
    m2 = valid & (n_unique == 2)
    if m2.any():
        # If the two smallest sorted eigenvalues coincide, the smallest is the
        # repeated one and the largest is distinct; otherwise the roles swap.
        lam_rep = torch.where(
            diff[:, 0:1] <= num_tol, sorted_eig[:, 0:1], sorted_eig[:, 2:3]
        ).unsqueeze(-1)
        mu_val = torch.where(
            diff[:, 0:1] <= num_tol, sorted_eig[:, 2:3], sorted_eig[:, 0:1]
        ).unsqueeze(-1)
        # Case 2a applies when (T - μI)(T - λI) is numerically zero
        M = torch.bmm(T - mu_val * eye3, T - lam_rep * eye3)
        case2a = m2 & (torch.linalg.norm(M, dim=(-2, -1)) < num_tol)
        case_indices.masked_fill_(case2a, 3)
        case_indices.masked_fill_(m2 & ~case2a, 4)
    # Case 3: three distinct eigenvalues
    case_indices.masked_fill_(valid & (n_unique == 3), 5)
    return case_indices


def _process_matrix_log_case(
    case_int: int,
    idx_t: torch.Tensor,
    T_sub: torch.Tensor,
    sorted_sub: torch.Tensor,
    dtype_out: torch.dtype,
    device: torch.device,
    num_tol: float,
) -> torch.Tensor:
    """Process a single matrix log case for the given indices.

    Args:
        case_int: Case identifier (-1 to 5)
        idx_t: Indices of systems belonging to this case; only its element
            count is used here
        T_sub: Subset of matrices for this case
        sorted_sub: Sorted eigenvalues for this case
        dtype_out: Output dtype
        device: Device for computation
        num_tol: Numerical tolerance

    Returns:
        Computed matrix logarithms for the subset
    """
    if case_int == -1:
        # Fallback to scipy for complex eigenvalues
        n_sub = idx_t.numel()
        result = torch.zeros_like(T_sub)
        # scipy handles one matrix at a time, so loop over the subset
        for i in range(n_sub):
            result[i] = matrix_log_scipy(T_sub[i].cpu()).to(device)
    elif case_int <= 2:
        # Cases 1a, 1b, 1c
        lam = sorted_sub[:, 0:1].unsqueeze(-1).to(dtype_out)
        # Lambdas defer evaluation so that only the selected formula runs
        case1_funcs = {
            0: lambda: _matrix_log_case1a(T_sub, lam),
            1: lambda: _matrix_log_case1b(T_sub, lam, num_tol),
            2: lambda: _matrix_log_case1c(T_sub, lam, num_tol),
        }
        result = case1_funcs[case_int]()
    elif case_int <= 4:
        # Cases 2a, 2b
        # Same repeated/distinct selection rule as in _determine_matrix_log_cases
        d = sorted_sub[:, 1:2] - sorted_sub[:, 0:1]
        lam_rep = (
            torch.where(d <= num_tol, sorted_sub[:, 0:1], sorted_sub[:, 2:3])
            .unsqueeze(-1)
            .to(dtype_out)
        )
        mu_val = (
            torch.where(d <= num_tol, sorted_sub[:, 2:3], sorted_sub[:, 0:1])
            .unsqueeze(-1)
            .to(dtype_out)
        )
        case2_func = _matrix_log_case2a if case_int == 3 else _matrix_log_case2b
        result = case2_func(T_sub, lam_rep, mu_val, num_tol)
    else:
        # Case 3: three distinct eigenvalues
        lam = sorted_sub[:, 0:1].unsqueeze(-1).to(dtype_out)
        mu_val = sorted_sub[:, 1:2].unsqueeze(-1).to(dtype_out)
        nu_val = sorted_sub[:, 2:3].unsqueeze(-1).to(dtype_out)
        result = _matrix_log_case3(T_sub, lam, mu_val, nu_val, num_tol)
    return result


@dcite("10.1007/s10659-008-9169-x")
def _matrix_log_33(T: torch.Tensor, dtype: torch.dtype = torch.float64) -> torch.Tensor:
    """Compute the logarithm of 3x3 matrix T based on its eigenvalue structure.

    The logarithm of this matrix is known exactly, as given in the reference.
    Supports both single matrix (3, 3) and batched input (n_systems, 3, 3).

    Each system is classified by its eigenvalue structure and then evaluated
    with the matching closed-form formula. Systems with complex or non-finite
    eigenvalues are routed to the scipy fallback instead.

    Args:
        T: The matrix whose logarithm is to be computed, shape (3, 3) or
            (n_systems, 3, 3)
        dtype: Selects the numerical tolerance only: 1e-16 for torch.float64,
            1e-8 otherwise. default=torch.float64

    Returns:
        The logarithm of T, same shape as input

    Raises:
        ValueError: If T is not (3, 3) or (n_systems, 3, 3)

    References:
        - https://link.springer.com/article/10.1007/s10659-008-9169-x
    """
    num_tol = 1e-16 if dtype == torch.float64 else 1e-8
    # Handle unbatched input by adding batch dimension
    unbatched = T.dim() == 2
    if unbatched:
        if T.shape != (3, 3):
            raise ValueError("Input must be a 3x3 matrix")
        T = T.unsqueeze(0)
    elif T.shape[1:] != (3, 3):
        raise ValueError("Batched input must have shape (n_systems, 3, 3)")
    device, dtype_out = T.device, T.dtype
    eigenvalues = torch.linalg.eigvals(T)
    # Check for complex eigenvalues - require scipy fallback
    imag_magnitude = torch.abs(torch.imag(eigenvalues))
    has_complex_eig = (imag_magnitude > num_tol).any(dim=1)
    eigenvalues_real = torch.real(eigenvalues)
    # Sort eigenvalues once for all systems
    sorted_eig, _ = torch.sort(eigenvalues_real, dim=1)
    diff = sorted_eig[:, 1:] - sorted_eig[:, :-1]
    # Gaps larger than num_tol between sorted neighbours count as distinct values
    n_unique = 1 + (diff > num_tol).sum(dim=1)
    valid = ~has_complex_eig & torch.isfinite(eigenvalues_real).all(dim=1)
    # Determine case for each system
    case_indices = _determine_matrix_log_cases(
        T, sorted_eig, diff, n_unique, valid, num_tol
    )
    # Process each case
    out = torch.zeros_like(T)
    for case_int in range(-1, 6):
        mask = case_indices == case_int
        if not mask.any():
            continue
        idx_t = mask.nonzero(as_tuple=True)[0]
        out[idx_t] = _process_matrix_log_case(
            case_int, idx_t, T[idx_t], sorted_eig[idx_t], dtype_out, device, num_tol
        )
    return out.squeeze(0) if unbatched else out
def matrix_log_scipy(matrix: torch.Tensor) -> torch.Tensor:
    """Evaluate the matrix logarithm of a square matrix via scipy.linalg.logm.

    The input is detached and copied to host memory for SciPy, and the result is
    rebuilt as a new tensor with the input's dtype and device. If the input has
    ``requires_grad`` set, the flag is set on the returned tensor as well; the
    returned tensor is built from the SciPy result, not from autograd
    operations on the input.

    Args:
        matrix: A square matrix tensor

    Returns:
        torch.Tensor: The matrix logarithm of the input matrix
    """
    from scipy.linalg import logm

    # SciPy operates on NumPy arrays, so drop the graph and move to the CPU
    host_array = matrix.detach().cpu().numpy()
    log_tensor = torch.tensor(
        logm(host_array), dtype=matrix.dtype, device=matrix.device
    )
    # Mirror the input's requires_grad flag on the freshly built tensor
    return log_tensor.requires_grad_() if matrix.requires_grad else log_tensor
def matrix_log_33(
    matrix: torch.Tensor,
    sim_dtype: torch.dtype = torch.float64,
    fallback_warning: bool = False,
) -> torch.Tensor:
    """Compute the matrix logarithm of a square 3x3 matrix.

    Also supports batched input of shape (n_systems, 3, 3). The exact 3x3
    formula is tried first; if it raises ``ValueError`` or ``RuntimeError`` the
    scipy implementation is used instead, one matrix at a time for batches.

    Args:
        matrix: A square 3x3 matrix tensor, or batch of shape (n_systems, 3, 3)
        sim_dtype: Simulation dtype, default=torch.float64
        fallback_warning: Whether to print a warning when falling back to scipy,
            default=False

    Returns:
        The matrix logarithm of the input matrix
    """
    # Work in double precision for stability; cast back only on return
    matrix = matrix.to(torch.float64)
    try:
        return _matrix_log_33(matrix).to(sim_dtype)
    except (ValueError, RuntimeError) as exc:
        if fallback_warning:
            msg = (
                f"Error computing matrix logarithm with _matrix_log_33 {exc} \n"
                "Falling back to scipy"
            )
            print(msg)  # noqa: T201

        # scipy fallback: a single matrix is handled directly
        if matrix.dim() != 3:
            return matrix_log_scipy(matrix).to(sim_dtype)

        # Batched input is processed one matrix at a time
        out = torch.zeros_like(matrix, dtype=sim_dtype)
        for i in range(matrix.shape[0]):
            single_log = matrix_log_scipy(matrix[i].cpu())
            out[i] = single_log.to(matrix.device).to(sim_dtype)
        return out
def batched_vdot(
    x: torch.Tensor, y: torch.Tensor, batch_indices: torch.Tensor
) -> torch.Tensor:
    """Computes batched vdot (sum of element-wise product) for groups of vectors.

    Args:
        x: Tensor of shape [N_total_entities, D] (e.g., forces, velocities).
        y: Tensor of shape [N_total_entities, D].
        batch_indices: Tensor of shape [N_total_entities] indicating batch
            membership.

    Returns:
        Tensor: shape [n_systems], with n_systems = max(batch_indices) + 1, where
            each element is the sum(x_i * y_i) for entities belonging to that
            batch, summed over all components D and all entities in the batch.
            Batches with no entities are zero. An empty input (zero entities)
            yields an empty tensor of shape [0].

    Raises:
        ValueError: If the input shapes are inconsistent or ``batch_indices``
            contains a negative value.
    """
    if (
        x.ndim != 2
        or y.ndim != 2
        or batch_indices.ndim != 1
        or x.shape != y.shape
        or x.shape[0] != batch_indices.shape[0]
    ):
        raise ValueError(
            f"Invalid input shapes: {x.shape=}, {y.shape=}, {batch_indices.shape=}"
        )

    # min()/max() are undefined on an empty tensor; there are no batches to fill
    if batch_indices.numel() == 0:
        return torch.zeros(0, dtype=x.dtype, device=x.device)

    if batch_indices.min() < 0:
        raise ValueError("batch_indices must be non-negative")

    # Reduce over D first, then accumulate each entity into its batch slot
    per_entity = (x * y).sum(dim=1)
    output = torch.zeros(int(batch_indices.max()) + 1, dtype=x.dtype, device=x.device)
    output.scatter_add_(dim=0, index=batch_indices, src=per_entity)
    return output