initial commit
This commit is contained in:
556
mne/rank.py
Normal file
556
mne/rank.py
Normal file
@@ -0,0 +1,556 @@
|
||||
"""Some utility functions for rank estimation."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
from scipy import linalg
|
||||
|
||||
from ._fiff.meas_info import Info, _simplify_info
|
||||
from ._fiff.pick import _picks_by_type, _picks_to_idx, pick_channels_cov, pick_info
|
||||
from ._fiff.proj import make_projector
|
||||
from .defaults import _handle_default
|
||||
from .utils import (
|
||||
_apply_scaling_cov,
|
||||
_check_on_missing,
|
||||
_check_rank,
|
||||
_compute_row_norms,
|
||||
_on_missing,
|
||||
_pl,
|
||||
_scaled_array,
|
||||
_undo_scaling_cov,
|
||||
_validate_type,
|
||||
fill_doc,
|
||||
logger,
|
||||
verbose,
|
||||
warn,
|
||||
)
|
||||
|
||||
|
||||
@verbose
|
||||
def estimate_rank(
|
||||
data,
|
||||
tol="auto",
|
||||
return_singular=False,
|
||||
norm=True,
|
||||
tol_kind="absolute",
|
||||
verbose=None,
|
||||
):
|
||||
"""Estimate the rank of data.
|
||||
|
||||
This function will normalize the rows of the data (typically
|
||||
channels or vertices) such that non-zero singular values
|
||||
should be close to one.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array
|
||||
Data to estimate the rank of (should be 2-dimensional).
|
||||
%(tol_rank)s
|
||||
return_singular : bool
|
||||
If True, also return the singular values that were used
|
||||
to determine the rank.
|
||||
norm : bool
|
||||
If True, data will be scaled by their estimated row-wise norm.
|
||||
Else data are assumed to be scaled. Defaults to True.
|
||||
%(tol_kind_rank)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
rank : int
|
||||
Estimated rank of the data.
|
||||
s : array
|
||||
If return_singular is True, the singular values that were
|
||||
thresholded to determine the rank are also returned.
|
||||
"""
|
||||
if norm:
|
||||
data = data.copy() # operate on a copy
|
||||
norms = _compute_row_norms(data)
|
||||
data /= norms[:, np.newaxis]
|
||||
s = linalg.svdvals(data)
|
||||
rank = _estimate_rank_from_s(s, tol, tol_kind)
|
||||
if return_singular is True:
|
||||
return rank, s
|
||||
else:
|
||||
return rank
|
||||
|
||||
|
||||
def _estimate_rank_from_s(s, tol="auto", tol_kind="absolute"):
|
||||
"""Estimate the rank of a matrix from its singular values.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
s : ndarray, shape (..., ndim)
|
||||
The singular values of the matrix.
|
||||
tol : float | ``'auto'``
|
||||
Tolerance for singular values to consider non-zero in calculating the
|
||||
rank. Can be 'auto' to use the same thresholding as
|
||||
``scipy.linalg.orth`` (assuming np.float64 datatype) adjusted
|
||||
by a factor of 2.
|
||||
tol_kind : str
|
||||
Can be ``"absolute"`` or ``"relative"``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rank : ndarray, shape (...)
|
||||
The estimated rank.
|
||||
"""
|
||||
s = np.array(s, float)
|
||||
max_s = np.amax(s, axis=-1)
|
||||
if isinstance(tol, str):
|
||||
if tol not in ("auto", "float32"):
|
||||
raise ValueError(f'tol must be "auto" or float, got {repr(tol)}')
|
||||
# XXX this should be float32 probably due to how we save and
|
||||
# load data, but it breaks test_make_inverse_operator (!)
|
||||
# The factor of 2 gets test_compute_covariance_auto_reg[None]
|
||||
# to pass without breaking minimum norm tests. :(
|
||||
# Passing 'float32' is a hack workaround for test_maxfilter_get_rank :(
|
||||
if tol == "float32":
|
||||
eps = np.finfo(np.float32).eps
|
||||
else:
|
||||
eps = np.finfo(np.float64).eps
|
||||
tol = s.shape[-1] * max_s * eps
|
||||
if s.ndim == 1: # typical
|
||||
logger.info(
|
||||
" Using tolerance %0.2g (%0.2g eps * %d dim * %0.2g"
|
||||
" max singular value)",
|
||||
tol,
|
||||
eps,
|
||||
len(s),
|
||||
max_s,
|
||||
)
|
||||
elif not (isinstance(tol, np.ndarray) and tol.dtype.kind == "f"):
|
||||
tol = float(tol)
|
||||
if tol_kind == "relative":
|
||||
tol = tol * max_s
|
||||
|
||||
rank = np.sum(s > tol, axis=-1)
|
||||
return rank
|
||||
|
||||
|
||||
def _estimate_rank_raw(
|
||||
raw, picks=None, tol=1e-4, scalings="norm", with_ref_meg=False, tol_kind="absolute"
|
||||
):
|
||||
"""Aid the transition away from raw.estimate_rank."""
|
||||
if picks is None:
|
||||
picks = _picks_to_idx(raw.info, picks, with_ref_meg=with_ref_meg)
|
||||
# conveniency wrapper to expose the expert "tol" option + scalings options
|
||||
return _estimate_rank_meeg_signals(
|
||||
raw[picks][0], pick_info(raw.info, picks), scalings, tol, False, tol_kind
|
||||
)
|
||||
|
||||
|
||||
@fill_doc
|
||||
def _estimate_rank_meeg_signals(
|
||||
data,
|
||||
info,
|
||||
scalings,
|
||||
tol="auto",
|
||||
return_singular=False,
|
||||
tol_kind="absolute",
|
||||
log_ch_type=None,
|
||||
):
|
||||
"""Estimate rank for M/EEG data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : np.ndarray of float, shape(n_channels, n_samples)
|
||||
The M/EEG signals.
|
||||
%(info_not_none)s
|
||||
scalings : dict | ``'norm'`` | np.ndarray | None
|
||||
The rescaling method to be applied. If dict, it will override the
|
||||
following default dict:
|
||||
|
||||
dict(mag=1e15, grad=1e13, eeg=1e6)
|
||||
|
||||
If ``'norm'`` data will be scaled by channel-wise norms. If array,
|
||||
pre-specified norms will be used. If None, no scaling will be applied.
|
||||
tol : float | str
|
||||
Tolerance. See ``estimate_rank``.
|
||||
return_singular : bool
|
||||
If True, also return the singular values that were used
|
||||
to determine the rank.
|
||||
tol_kind : str
|
||||
Tolerance kind. See ``estimate_rank``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rank : int
|
||||
Estimated rank of the data.
|
||||
s : array
|
||||
If return_singular is True, the singular values that were
|
||||
thresholded to determine the rank are also returned.
|
||||
"""
|
||||
picks_list = _picks_by_type(info)
|
||||
if data.shape[1] < data.shape[0]:
|
||||
ValueError(
|
||||
"You've got fewer samples than channels, your "
|
||||
"rank estimate might be inaccurate."
|
||||
)
|
||||
with _scaled_array(data, picks_list, scalings):
|
||||
out = estimate_rank(
|
||||
data,
|
||||
tol=tol,
|
||||
norm=False,
|
||||
return_singular=return_singular,
|
||||
tol_kind=tol_kind,
|
||||
)
|
||||
rank = out[0] if isinstance(out, tuple) else out
|
||||
if log_ch_type is None:
|
||||
ch_type = " + ".join(list(zip(*picks_list))[0])
|
||||
else:
|
||||
ch_type = log_ch_type
|
||||
logger.info(" Estimated rank (%s): %d", ch_type, rank)
|
||||
return out
|
||||
|
||||
|
||||
@verbose
|
||||
def _estimate_rank_meeg_cov(
|
||||
data,
|
||||
info,
|
||||
scalings,
|
||||
tol="auto",
|
||||
return_singular=False,
|
||||
*,
|
||||
log_ch_type=None,
|
||||
verbose=None,
|
||||
):
|
||||
"""Estimate rank of M/EEG covariance data, given the covariance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : np.ndarray of float, shape (n_channels, n_channels)
|
||||
The M/EEG covariance.
|
||||
%(info_not_none)s
|
||||
scalings : dict | 'norm' | np.ndarray | None
|
||||
The rescaling method to be applied. If dict, it will override the
|
||||
following default dict:
|
||||
|
||||
dict(mag=1e12, grad=1e11, eeg=1e5)
|
||||
|
||||
If 'norm' data will be scaled by channel-wise norms. If array,
|
||||
pre-specified norms will be used. If None, no scaling will be applied.
|
||||
tol : float | str
|
||||
Tolerance. See ``estimate_rank``.
|
||||
return_singular : bool
|
||||
If True, also return the singular values that were used
|
||||
to determine the rank.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rank : int
|
||||
Estimated rank of the data.
|
||||
s : array
|
||||
If return_singular is True, the singular values that were
|
||||
thresholded to determine the rank are also returned.
|
||||
"""
|
||||
picks_list = _picks_by_type(info, exclude=[])
|
||||
scalings = _handle_default("scalings_cov_rank", scalings)
|
||||
_apply_scaling_cov(data, picks_list, scalings)
|
||||
if data.shape[1] < data.shape[0]:
|
||||
ValueError(
|
||||
"You've got fewer samples than channels, your "
|
||||
"rank estimate might be inaccurate."
|
||||
)
|
||||
out = estimate_rank(data, tol=tol, norm=False, return_singular=return_singular)
|
||||
rank = out[0] if isinstance(out, tuple) else out
|
||||
if log_ch_type is None:
|
||||
ch_type_ = " + ".join(list(zip(*picks_list))[0])
|
||||
else:
|
||||
ch_type_ = log_ch_type
|
||||
logger.info(f" Estimated rank ({ch_type_}): {rank}")
|
||||
_undo_scaling_cov(data, picks_list, scalings)
|
||||
return out
|
||||
|
||||
|
||||
@verbose
|
||||
def _get_rank_sss(
|
||||
inst, msg="You should use data-based rank estimate instead", verbose=None
|
||||
):
|
||||
"""Look up rank from SSS data.
|
||||
|
||||
.. note::
|
||||
Throws an error if SSS has not been applied.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : instance of Raw, Epochs or Evoked, or Info
|
||||
Any MNE object with an .info attribute
|
||||
|
||||
Returns
|
||||
-------
|
||||
rank : int
|
||||
The numerical rank as predicted by the number of SSS
|
||||
components.
|
||||
"""
|
||||
# XXX this is too basic for movement compensated data
|
||||
# https://github.com/mne-tools/mne-python/issues/4676
|
||||
info = inst if isinstance(inst, Info) else inst.info
|
||||
del inst
|
||||
|
||||
proc_info = info.get("proc_history", [])
|
||||
if len(proc_info) > 1:
|
||||
logger.info("Found multiple SSS records. Using the first.")
|
||||
if (
|
||||
len(proc_info) == 0
|
||||
or "max_info" not in proc_info[0]
|
||||
or "in_order" not in proc_info[0]["max_info"]["sss_info"]
|
||||
):
|
||||
raise ValueError(
|
||||
f'Could not find Maxfilter information in info["proc_history"]. {msg}'
|
||||
)
|
||||
proc_info = proc_info[0]
|
||||
max_info = proc_info["max_info"]
|
||||
inside = max_info["sss_info"]["in_order"]
|
||||
nfree = (inside + 1) ** 2 - 1
|
||||
nfree -= (
|
||||
len(max_info["sss_info"]["components"][:nfree])
|
||||
- max_info["sss_info"]["components"][:nfree].sum()
|
||||
)
|
||||
return nfree
|
||||
|
||||
|
||||
def _info_rank(info, ch_type, picks, rank):
|
||||
if ch_type in ["meg", "mag", "grad"] and rank != "full":
|
||||
try:
|
||||
return _get_rank_sss(info)
|
||||
except ValueError:
|
||||
pass
|
||||
return len(picks)
|
||||
|
||||
|
||||
def _compute_rank_int(inst, *args, **kwargs):
|
||||
"""Wrap compute_rank but yield an int."""
|
||||
# XXX eventually we should unify how channel types are handled
|
||||
# so that we don't need to do this, or we do it everywhere.
|
||||
# Using pca=True in compute_whitener might help.
|
||||
return sum(compute_rank(inst, *args, **kwargs).values())
|
||||
|
||||
|
||||
@verbose
|
||||
def compute_rank(
|
||||
inst,
|
||||
rank=None,
|
||||
scalings=None,
|
||||
info=None,
|
||||
tol="auto",
|
||||
proj=True,
|
||||
tol_kind="absolute",
|
||||
on_rank_mismatch="ignore",
|
||||
verbose=None,
|
||||
):
|
||||
"""Compute the rank of data or noise covariance.
|
||||
|
||||
This function will normalize the rows of the data (typically
|
||||
channels or vertices) such that non-zero singular values
|
||||
should be close to one. It operates on :term:`data channels` only.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : instance of Raw, Epochs, or Covariance
|
||||
Raw measurements to compute the rank from or the covariance.
|
||||
%(rank_none)s
|
||||
scalings : dict | None (default None)
|
||||
Defaults to ``dict(mag=1e15, grad=1e13, eeg=1e6)``.
|
||||
These defaults will scale different channel types
|
||||
to comparable values.
|
||||
%(info)s Only necessary if ``inst`` is a :class:`mne.Covariance`
|
||||
object (since this does not provide ``inst.info``).
|
||||
%(tol_rank)s
|
||||
proj : bool
|
||||
If True, all projs in ``inst`` and ``info`` will be applied or
|
||||
considered when ``rank=None`` or ``rank='info'``.
|
||||
%(tol_kind_rank)s
|
||||
%(on_rank_mismatch)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
rank : dict
|
||||
Estimated rank of the data for each channel type.
|
||||
To get the total rank, you can use ``sum(rank.values())``.
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 0.18
|
||||
"""
|
||||
return _compute_rank(
|
||||
inst=inst,
|
||||
rank=rank,
|
||||
scalings=scalings,
|
||||
info=info,
|
||||
tol=tol,
|
||||
proj=proj,
|
||||
tol_kind=tol_kind,
|
||||
on_rank_mismatch=on_rank_mismatch,
|
||||
)
|
||||
|
||||
|
||||
@verbose
|
||||
def _compute_rank(
|
||||
inst,
|
||||
rank=None,
|
||||
scalings=None,
|
||||
info=None,
|
||||
*,
|
||||
tol="auto",
|
||||
proj=True,
|
||||
tol_kind="absolute",
|
||||
on_rank_mismatch="ignore",
|
||||
log_ch_type=None,
|
||||
verbose=None,
|
||||
):
|
||||
from .cov import Covariance
|
||||
from .epochs import BaseEpochs
|
||||
from .io import BaseRaw
|
||||
|
||||
rank = _check_rank(rank)
|
||||
scalings = _handle_default("scalings_cov_rank", scalings)
|
||||
_check_on_missing(on_rank_mismatch, "on_rank_mismatch")
|
||||
|
||||
if isinstance(inst, Covariance):
|
||||
inst_type = "covariance"
|
||||
if info is None:
|
||||
raise ValueError("info cannot be None if inst is a Covariance.")
|
||||
# Reset bads as it's already taken into account in inst['names']
|
||||
info = info.copy()
|
||||
info["bads"] = []
|
||||
inst = pick_channels_cov(
|
||||
inst,
|
||||
set(inst["names"]) & set(info["ch_names"]),
|
||||
exclude=info["bads"] + inst["bads"],
|
||||
ordered=False,
|
||||
)
|
||||
if info["ch_names"] != inst["names"]:
|
||||
info = pick_info(
|
||||
info, [info["ch_names"].index(name) for name in inst["names"]]
|
||||
)
|
||||
else:
|
||||
info = inst.info
|
||||
inst_type = "data"
|
||||
logger.info(f"Computing rank from {inst_type} with rank={repr(rank)}")
|
||||
|
||||
_validate_type(rank, (str, dict, None), "rank")
|
||||
if isinstance(rank, str): # string, either 'info' or 'full'
|
||||
rank_type = "info"
|
||||
info_type = rank
|
||||
rank = dict()
|
||||
else: # None or dict
|
||||
rank_type = "estimated"
|
||||
if rank is None:
|
||||
rank = dict()
|
||||
|
||||
simple_info = _simplify_info(info)
|
||||
picks_list = _picks_by_type(info, meg_combined=True, ref_meg=False, exclude="bads")
|
||||
for ch_type, picks in picks_list:
|
||||
est_verbose = None
|
||||
if ch_type in rank:
|
||||
# raise an error of user-supplied rank exceeds number of channels
|
||||
if rank[ch_type] > len(picks):
|
||||
raise ValueError(
|
||||
f"rank[{repr(ch_type)}]={rank[ch_type]} exceeds the number"
|
||||
f" of channels ({len(picks)})"
|
||||
)
|
||||
# special case: if whitening a covariance, check the passed rank
|
||||
# against the estimated one
|
||||
est_verbose = False
|
||||
if not (
|
||||
on_rank_mismatch != "ignore"
|
||||
and rank_type == "estimated"
|
||||
and ch_type == "meg"
|
||||
and isinstance(inst, Covariance)
|
||||
and not inst["diag"]
|
||||
):
|
||||
continue
|
||||
ch_names = [info["ch_names"][pick] for pick in picks]
|
||||
n_chan = len(ch_names)
|
||||
if proj:
|
||||
proj_op, n_proj, _ = make_projector(info["projs"], ch_names)
|
||||
else:
|
||||
proj_op, n_proj = None, 0
|
||||
if log_ch_type is None:
|
||||
ch_type_ = ch_type.upper()
|
||||
else:
|
||||
ch_type_ = log_ch_type
|
||||
if rank_type == "info":
|
||||
# use info
|
||||
this_rank = _info_rank(info, ch_type, picks, info_type)
|
||||
if info_type != "full":
|
||||
this_rank -= n_proj
|
||||
logger.info(
|
||||
f" {ch_type_}: rank {this_rank} after "
|
||||
f"{n_proj} projector{_pl(n_proj)} applied to "
|
||||
f"{n_chan} channel{_pl(n_chan)}"
|
||||
)
|
||||
else:
|
||||
logger.info(f" {ch_type_}: rank {this_rank} from info")
|
||||
else:
|
||||
# Use empirical estimation
|
||||
assert rank_type == "estimated"
|
||||
if isinstance(inst, BaseRaw | BaseEpochs):
|
||||
if isinstance(inst, BaseRaw):
|
||||
data = inst.get_data(picks, reject_by_annotation="omit")
|
||||
else: # isinstance(inst, BaseEpochs):
|
||||
data = np.concatenate(inst.get_data(picks), axis=1)
|
||||
if proj:
|
||||
data = np.dot(proj_op, data)
|
||||
this_rank = _estimate_rank_meeg_signals(
|
||||
data,
|
||||
pick_info(simple_info, picks),
|
||||
scalings,
|
||||
tol,
|
||||
False,
|
||||
tol_kind,
|
||||
log_ch_type=log_ch_type,
|
||||
)
|
||||
else:
|
||||
assert isinstance(inst, Covariance)
|
||||
if inst["diag"]:
|
||||
this_rank = (inst["data"][picks] > 0).sum() - n_proj
|
||||
else:
|
||||
data = inst["data"][picks][:, picks]
|
||||
if proj:
|
||||
data = np.dot(np.dot(proj_op, data), proj_op.T)
|
||||
|
||||
this_rank, sing = _estimate_rank_meeg_cov(
|
||||
data,
|
||||
pick_info(simple_info, picks),
|
||||
scalings,
|
||||
tol,
|
||||
return_singular=True,
|
||||
log_ch_type=log_ch_type,
|
||||
verbose=est_verbose,
|
||||
)
|
||||
if ch_type in rank:
|
||||
ratio = sing[this_rank - 1] / sing[rank[ch_type] - 1]
|
||||
if ratio > 100:
|
||||
msg = (
|
||||
f"The passed rank[{repr(ch_type)}]="
|
||||
f"{rank[ch_type]} exceeds the estimated rank "
|
||||
f"of the noise covariance ({this_rank}) "
|
||||
f"leading to a potential increase in "
|
||||
f"noise during whitening by a factor "
|
||||
f"of {np.sqrt(ratio):0.1g}. Ensure that the "
|
||||
f"rank correctly corresponds to that of the "
|
||||
f"given noise covariance matrix."
|
||||
)
|
||||
_on_missing(on_rank_mismatch, msg, "on_rank_mismatch")
|
||||
continue
|
||||
this_info_rank = _info_rank(info, ch_type, picks, "info")
|
||||
logger.info(
|
||||
f" {ch_type_}: rank {this_rank} computed from "
|
||||
f"{n_chan} data channel{_pl(n_chan)} with "
|
||||
f"{n_proj} projector{_pl(n_proj)}"
|
||||
)
|
||||
if this_rank > this_info_rank:
|
||||
warn(
|
||||
"Something went wrong in the data-driven estimation of the data "
|
||||
"rank as it exceeds the theoretical rank from the info "
|
||||
f"({this_rank} > {this_info_rank}). Consider setting rank "
|
||||
'to "auto" or setting it explicitly as an integer.'
|
||||
)
|
||||
if ch_type not in rank:
|
||||
rank[ch_type] = int(this_rank)
|
||||
|
||||
return rank
|
||||
Reference in New Issue
Block a user