initial commit
This commit is contained in:
140
mne/preprocessing/otp.py
Normal file
140
mne/preprocessing/otp.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_to_idx
|
||||
from .._ola import _COLA, _Storer
|
||||
from ..surface import _normalize_vectors
|
||||
from ..utils import logger, verbose
|
||||
|
||||
|
||||
def _svd_cov(cov, data):
|
||||
"""Use a covariance matrix to compute the SVD faster."""
|
||||
# This makes use of mathematical equivalences between PCA and SVD
|
||||
# on zero-mean data
|
||||
s, u = np.linalg.eigh(cov)
|
||||
norm = np.ones((s.size,))
|
||||
mask = s > np.finfo(float).eps * s[-1] # largest is last
|
||||
s = np.sqrt(s, out=s)
|
||||
norm[mask] = 1.0 / s[mask]
|
||||
u *= norm
|
||||
v = np.dot(u.T[mask], data)
|
||||
return u, s, v
|
||||
|
||||
|
||||
@verbose
|
||||
def oversampled_temporal_projection(raw, duration=10.0, picks=None, verbose=None):
|
||||
"""Denoise MEG channels using leave-one-out temporal projection.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
Raw data to denoise.
|
||||
duration : float | str
|
||||
The window duration (in seconds; default 10.) to use. Can also
|
||||
be "min" to use as short a window as possible.
|
||||
%(picks_all_data)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
raw_clean : instance of Raw
|
||||
The cleaned data.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This algorithm is computationally expensive, and can be several times
|
||||
slower than realtime for conventional M/EEG datasets. It uses a
|
||||
leave-one-out procedure with parallel temporal projection to remove
|
||||
individual sensor noise under the assumption that sampled fields
|
||||
(e.g., MEG and EEG) are oversampled by the sensor array
|
||||
:footcite:`LarsonTaulu2018`.
|
||||
|
||||
OTP can improve sensor noise levels (especially under visual
|
||||
inspection) and repair some bad channels. This noise reduction is known
|
||||
to interact with :func:`tSSS <mne.preprocessing.maxwell_filter>` such
|
||||
that increasing the ``st_correlation`` value will likely be necessary.
|
||||
|
||||
Channels marked as bad will not be used to reconstruct good channels,
|
||||
but good channels will be used to process the bad channels. Depending
|
||||
on the type of noise present in the bad channels, this might make
|
||||
them usable again.
|
||||
|
||||
Use of this algorithm is covered by a provisional patent.
|
||||
|
||||
.. versionadded:: 0.16
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
logger.info("Processing MEG data using oversampled temporal projection")
|
||||
picks = _picks_to_idx(raw.info, picks, exclude=())
|
||||
picks_good, picks_bad = list(), list() # these are indices into picks
|
||||
for ii, pi in enumerate(picks):
|
||||
if raw.ch_names[pi] in raw.info["bads"]:
|
||||
picks_bad.append(ii)
|
||||
else:
|
||||
picks_good.append(ii)
|
||||
picks_good = np.array(picks_good, int)
|
||||
picks_bad = np.array(picks_bad, int)
|
||||
|
||||
n_samples = int(round(float(duration) * raw.info["sfreq"]))
|
||||
if n_samples < len(picks_good) - 1:
|
||||
raise ValueError(
|
||||
f"duration ({n_samples / raw.info['sfreq']}) yielded {n_samples} samples, "
|
||||
f"which is fewer than the number of channels -1 ({len(picks_good) - 1})"
|
||||
)
|
||||
n_overlap = n_samples // 2
|
||||
raw_otp = raw.copy().load_data(verbose=False)
|
||||
otp = _COLA(
|
||||
partial(_otp, picks_good=picks_good, picks_bad=picks_bad),
|
||||
_Storer(raw_otp._data, picks=picks),
|
||||
len(raw.times),
|
||||
n_samples,
|
||||
n_overlap,
|
||||
raw.info["sfreq"],
|
||||
)
|
||||
read_lims = list(range(0, len(raw.times), n_samples)) + [len(raw.times)]
|
||||
for start, stop in zip(read_lims[:-1], read_lims[1:]):
|
||||
logger.info(
|
||||
f" Denoising {raw.times[[start, stop - 1]][0]: 8.2f} – "
|
||||
f"{raw.times[[start, stop - 1]][1]: 8.2f} s"
|
||||
)
|
||||
otp.feed(raw[picks, start:stop][0])
|
||||
return raw_otp
|
||||
|
||||
|
||||
def _otp(data, picks_good, picks_bad):
|
||||
"""Perform OTP on one segment of data."""
|
||||
if not np.isfinite(data).all():
|
||||
raise RuntimeError("non-finite data (inf or nan) found in raw instance")
|
||||
# demean our data
|
||||
data_means = np.mean(data, axis=-1, keepdims=True)
|
||||
data -= data_means
|
||||
# make a copy
|
||||
data_good = data[picks_good]
|
||||
# scale the copy that will be used to form the temporal basis vectors
|
||||
# so that _orth_svdvals thresholding should work properly with
|
||||
# different channel types (e.g., M-EEG)
|
||||
norms = _normalize_vectors(data_good)
|
||||
cov = np.dot(data_good, data_good.T)
|
||||
if len(picks_bad) > 0:
|
||||
full_basis = _svd_cov(cov, data_good)[2]
|
||||
for mi, pick in enumerate(picks_good):
|
||||
# operate on original data
|
||||
idx = list(range(mi)) + list(range(mi + 1, len(data_good)))
|
||||
# Equivalent: svd(data[idx], full_matrices=False)[2]
|
||||
t_basis = _svd_cov(cov[np.ix_(idx, idx)], data_good[idx])[2]
|
||||
x = np.dot(np.dot(data_good[mi], t_basis.T), t_basis)
|
||||
x *= norms[mi]
|
||||
x += data_means[pick]
|
||||
data[pick] = x
|
||||
for pick in picks_bad:
|
||||
data[pick] = np.dot(np.dot(data[pick], full_basis.T), full_basis)
|
||||
data[pick] += data_means[pick]
|
||||
return [data]
|
||||
Reference in New Issue
Block a user