initial commit
This commit is contained in:
243
mne/utils/linalg.py
Normal file
243
mne/utils/linalg.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""Utility functions to speed up linear algebraic operations.
|
||||
|
||||
In general, things like np.dot and linalg.svd should be used directly
|
||||
because they are smart about checking for bad values. However, in cases where
|
||||
things are done repeatedly (e.g., thousands of times on tiny matrices), the
|
||||
overhead can become problematic from a performance standpoint. Examples:
|
||||
|
||||
- Optimization routines:
|
||||
- Dipole fitting
|
||||
- Sparse solving
|
||||
- cHPI fitting
|
||||
- Inverse computation
|
||||
- Beamformers (LCMV/DICS)
|
||||
- eLORETA minimum norm
|
||||
|
||||
Significant performance gains can be achieved by ensuring that inputs
|
||||
are Fortran contiguous because that's what LAPACK requires. Without this,
|
||||
inputs will be memcopied.
|
||||
"""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
from scipy import linalg
|
||||
from scipy._lib._util import _asarray_validated
|
||||
|
||||
from ..fixes import _safe_svd
|
||||
|
||||
# For efficiency, names should be str or tuple of str, dtype a builtin
|
||||
# NumPy dtype
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _get_blas_funcs(dtype, names):
|
||||
return linalg.get_blas_funcs(names, (np.empty(0, dtype),))
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _get_lapack_funcs(dtype, names):
|
||||
assert dtype in (np.float64, np.complex128)
|
||||
x = np.empty(0, dtype)
|
||||
return linalg.get_lapack_funcs(names, (x,))
|
||||
|
||||
|
||||
###############################################################################
|
||||
# linalg.svd and linalg.pinv2
|
||||
|
||||
|
||||
def _svd_lwork(shape, dtype=np.float64):
|
||||
"""Set up SVD calculations on identical-shape float64/complex128 arrays."""
|
||||
try:
|
||||
ds = linalg._decomp_svd
|
||||
except AttributeError: # < 1.8.0
|
||||
ds = linalg.decomp_svd
|
||||
gesdd_lwork, gesvd_lwork = _get_lapack_funcs(dtype, ("gesdd_lwork", "gesvd_lwork"))
|
||||
sdd_lwork = ds._compute_lwork(
|
||||
gesdd_lwork, *shape, compute_uv=True, full_matrices=False
|
||||
)
|
||||
svd_lwork = ds._compute_lwork(
|
||||
gesvd_lwork, *shape, compute_uv=True, full_matrices=False
|
||||
)
|
||||
return sdd_lwork, svd_lwork
|
||||
|
||||
|
||||
def _repeated_svd(x, lwork, overwrite_a=False):
|
||||
"""Mimic scipy.linalg.svd, avoid lwork and get_lapack_funcs overhead."""
|
||||
gesdd, gesvd = _get_lapack_funcs(x.dtype, ("gesdd", "gesvd"))
|
||||
# this has to use overwrite_a=False in case we need to fall back to gesvd
|
||||
u, s, v, info = gesdd(
|
||||
x, compute_uv=True, lwork=lwork[0], full_matrices=False, overwrite_a=False
|
||||
)
|
||||
if info > 0:
|
||||
# Fall back to slower gesvd, sometimes gesdd fails
|
||||
u, s, v, info = gesvd(
|
||||
x,
|
||||
compute_uv=True,
|
||||
lwork=lwork[1],
|
||||
full_matrices=False,
|
||||
overwrite_a=overwrite_a,
|
||||
)
|
||||
if info > 0:
|
||||
raise np.linalg.LinAlgError("SVD did not converge")
|
||||
if info < 0:
|
||||
raise ValueError(f"illegal value in {-info}-th argument of internal gesdd")
|
||||
return u, s, v
|
||||
|
||||
|
||||
###############################################################################
|
||||
# linalg.eigh
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _get_evd(dtype):
|
||||
x = np.empty(0, dtype)
|
||||
if dtype == np.float64:
|
||||
driver = "syevd"
|
||||
else:
|
||||
assert dtype == np.complex128
|
||||
driver = "heevd"
|
||||
(evr,) = linalg.get_lapack_funcs((driver,), (x,))
|
||||
return evr, driver
|
||||
|
||||
|
||||
def eigh(a, overwrite_a=False, check_finite=True):
|
||||
"""Efficient wrapper for eigh.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : ndarray, shape (n_components, n_components)
|
||||
The symmetric array operate on.
|
||||
overwrite_a : bool
|
||||
If True, the contents of a can be overwritten for efficiency.
|
||||
check_finite : bool
|
||||
If True, check that all elements are finite.
|
||||
|
||||
Returns
|
||||
-------
|
||||
w : ndarray, shape (n_components,)
|
||||
The N eigenvalues, in ascending order, each repeated according to
|
||||
its multiplicity.
|
||||
v : ndarray, shape (n_components, n_components)
|
||||
The normalized eigenvector corresponding to the eigenvalue ``w[i]``
|
||||
is the column ``v[:, i]``.
|
||||
"""
|
||||
# We use SYEVD, see https://github.com/scipy/scipy/issues/9212
|
||||
if check_finite:
|
||||
a = _asarray_validated(a, check_finite=check_finite)
|
||||
evd, driver = _get_evd(a.dtype)
|
||||
w, v, info = evd(a, lower=1, overwrite_a=overwrite_a)
|
||||
if info == 0:
|
||||
return w, v
|
||||
if info < 0:
|
||||
raise ValueError(f"illegal value in argument {-info} of internal {driver}")
|
||||
else:
|
||||
raise linalg.LinAlgError(
|
||||
"internal fortran routine failed to converge: "
|
||||
f"{info} off-diagonal elements of an "
|
||||
"intermediate tridiagonal form did not converge"
|
||||
" to zero."
|
||||
)
|
||||
|
||||
|
||||
def sqrtm_sym(A, rcond=1e-7, inv=False):
|
||||
"""Compute the sqrt of a positive, semi-definite matrix (or its inverse).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A : ndarray, shape (..., n, n)
|
||||
The array to take the square root of.
|
||||
rcond : float
|
||||
The relative condition number used during reconstruction.
|
||||
inv : bool
|
||||
If True, compute the inverse of the square root rather than the
|
||||
square root itself.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A_sqrt : ndarray, shape (..., n, n)
|
||||
The (possibly inverted) square root of A.
|
||||
s : ndarray, shape (..., n)
|
||||
The original square root singular values (not inverted).
|
||||
"""
|
||||
# Same as linalg.sqrtm(C) but faster, also yields the eigenvalues
|
||||
return _sym_mat_pow(A, -0.5 if inv else 0.5, rcond, return_s=True)
|
||||
|
||||
|
||||
def _sym_mat_pow(A, power, rcond=1e-7, reduce_rank=False, return_s=False):
|
||||
"""Exponentiate Hermitian matrices with optional rank reduction."""
|
||||
assert power in (-1, 0.5, -0.5) # only used internally
|
||||
s, u = np.linalg.eigh(A) # eigenvalues in ascending order
|
||||
# Is it positive semi-defidite? If so, keep real
|
||||
limit = s[..., -1:] * rcond
|
||||
if not (s >= -limit).all(): # allow some tiny small negative ones
|
||||
raise ValueError("Matrix is not positive semi-definite")
|
||||
s[s <= limit] = np.inf if power < 0 else 0
|
||||
if reduce_rank:
|
||||
# These are ordered smallest to largest, so we set the first one
|
||||
# to inf -- then the 1. / s below will turn this to zero, as needed.
|
||||
s[..., 0] = np.inf
|
||||
if power in (-0.5, 0.5):
|
||||
np.sqrt(s, out=s)
|
||||
use_s = 1.0 / s if power < 0 else s
|
||||
out = np.matmul(u * use_s[..., np.newaxis, :], u.swapaxes(-2, -1).conj())
|
||||
if return_s:
|
||||
out = (out, s)
|
||||
return out
|
||||
|
||||
|
||||
# SciPy deprecation of pinv + pinvh rcond (never worked properly anyway)
|
||||
def pinvh(a, rtol=None):
|
||||
"""Compute a pseudo-inverse of a Hermitian matrix.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : ndarray, shape (n, n)
|
||||
The Hermitian array to invert.
|
||||
rtol : float | None
|
||||
The relative tolerance.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a_pinv : ndarray, shape (n, n)
|
||||
The pseudo-inverse of a.
|
||||
"""
|
||||
s, u = np.linalg.eigh(a)
|
||||
del a
|
||||
if rtol is None:
|
||||
rtol = s.size * np.finfo(s.dtype).eps
|
||||
maxS = np.max(np.abs(s))
|
||||
above_cutoff = abs(s) > maxS * rtol
|
||||
psigma_diag = 1.0 / s[above_cutoff]
|
||||
u = u[:, above_cutoff]
|
||||
return (u * psigma_diag) @ u.conj().T
|
||||
|
||||
|
||||
def pinv(a, rtol=None):
|
||||
"""Compute a pseudo-inverse of a matrix.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : ndarray, shape (n, m)
|
||||
The array to invert.
|
||||
rtol : float | None
|
||||
The relative tolerance.
|
||||
|
||||
Returns
|
||||
-------
|
||||
a_pinv : ndarray, shape (m, n)
|
||||
The pseudo-inverse of a.
|
||||
"""
|
||||
u, s, vh = _safe_svd(a, full_matrices=False)
|
||||
del a
|
||||
maxS = np.max(s)
|
||||
if rtol is None:
|
||||
rtol = max(vh.shape + u.shape) * np.finfo(u.dtype).eps
|
||||
rank = np.sum(s > maxS * rtol)
|
||||
u = u[:, :rank]
|
||||
u /= s[:rank]
|
||||
return (u @ vh[:rank]).conj().T
|
||||
Reference in New Issue
Block a user