initial commit
9  mne/preprocessing/__init__.py  Normal file
@@ -0,0 +1,9 @@
"""Preprocessing with artifact detection, SSP, and ICA."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import lazy_loader as lazy

(__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__)
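The stub file below (``__init__.pyi``) is what ``lazy.attach_stub`` reads to decide which names ``mne.preprocessing`` exposes; the corresponding submodules are only imported on first attribute access. A minimal sketch of the effect (illustrative, not part of the commit):

    import mne.preprocessing as prep

    ica_cls = prep.ICA  # only at this point is the submodule defining ICA imported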
91  mne/preprocessing/__init__.pyi  Normal file
@@ -0,0 +1,91 @@
__all__ = [
    "EOGRegression",
    "ICA",
    "Xdawn",
    "annotate_amplitude",
    "annotate_break",
    "annotate_movement",
    "annotate_muscle_zscore",
    "annotate_nan",
    "compute_average_dev_head_t",
    "compute_bridged_electrodes",
    "compute_current_source_density",
    "compute_fine_calibration",
    "compute_maxwell_basis",
    "compute_proj_ecg",
    "compute_proj_eog",
    "compute_proj_hfc",
    "corrmap",
    "cortical_signal_suppression",
    "create_ecg_epochs",
    "create_eog_epochs",
    "equalize_bads",
    "eyetracking",
    "find_bad_channels_lof",
    "find_bad_channels_maxwell",
    "find_ecg_events",
    "find_eog_events",
    "fix_stim_artifact",
    "get_score_funcs",
    "ica_find_ecg_events",
    "ica_find_eog_events",
    "ieeg",
    "infomax",
    "interpolate_bridged_electrodes",
    "maxwell_filter",
    "maxwell_filter_prepare_emptyroom",
    "nirs",
    "oversampled_temporal_projection",
    "peak_finder",
    "read_eog_regression",
    "read_fine_calibration",
    "read_ica",
    "read_ica_eeglab",
    "realign_raw",
    "regress_artifact",
    "write_fine_calibration",
]
from . import eyetracking, ieeg, nirs
from ._annotate_amplitude import annotate_amplitude
from ._annotate_nan import annotate_nan
from ._csd import compute_bridged_electrodes, compute_current_source_density
from ._css import cortical_signal_suppression
from ._fine_cal import (
    compute_fine_calibration,
    read_fine_calibration,
    write_fine_calibration,
)
from ._lof import find_bad_channels_lof
from ._peak_finder import peak_finder
from ._regress import EOGRegression, read_eog_regression, regress_artifact
from .artifact_detection import (
    annotate_break,
    annotate_movement,
    annotate_muscle_zscore,
    compute_average_dev_head_t,
)
from .ecg import create_ecg_epochs, find_ecg_events
from .eog import create_eog_epochs, find_eog_events
from .hfc import compute_proj_hfc
from .ica import (
    ICA,
    corrmap,
    get_score_funcs,
    ica_find_ecg_events,
    ica_find_eog_events,
    read_ica,
    read_ica_eeglab,
)
from .infomax_ import infomax
from .interpolate import equalize_bads, interpolate_bridged_electrodes
from .maxwell import (
    compute_maxwell_basis,
    find_bad_channels_maxwell,
    maxwell_filter,
    maxwell_filter_prepare_emptyroom,
)
from .otp import oversampled_temporal_projection
from .realign import realign_raw
from .ssp import compute_proj_ecg, compute_proj_eog
from .stim import fix_stim_artifact
from .xdawn import Xdawn
280  mne/preprocessing/_annotate_amplitude.py  Normal file
@@ -0,0 +1,280 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_by_type, _picks_to_idx
|
||||
from ..annotations import (
|
||||
Annotations,
|
||||
_adjust_onset_meas_date,
|
||||
_annotations_starts_stops,
|
||||
)
|
||||
from ..fixes import jit
|
||||
from ..io import BaseRaw
|
||||
from ..utils import _mask_to_onsets_offsets, _validate_type, logger, verbose
|
||||
|
||||
|
||||
@verbose
|
||||
def annotate_amplitude(
|
||||
raw,
|
||||
peak=None,
|
||||
flat=None,
|
||||
bad_percent=5,
|
||||
min_duration=0.005,
|
||||
picks=None,
|
||||
*,
|
||||
verbose=None,
|
||||
):
|
||||
"""Annotate raw data based on peak-to-peak amplitude.
|
||||
|
||||
Creates annotations ``BAD_peak`` or ``BAD_flat`` for spans of data where
|
||||
consecutive samples exceed the threshold in ``peak`` or fall below the
|
||||
threshold in ``flat`` for more than ``min_duration``.
|
||||
Channels where more than ``bad_percent`` of the total recording length
|
||||
should be annotated with either ``BAD_peak`` or ``BAD_flat`` are returned
|
||||
in ``bads`` instead.
|
||||
Note that the annotations and the bads are not automatically added to the
|
||||
:class:`~mne.io.Raw` object; use :meth:`~mne.io.Raw.set_annotations` and
|
||||
:class:`info['bads'] <mne.Info>` to do so.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data.
|
||||
peak : float | dict | None
|
||||
Annotate segments based on **maximum** peak-to-peak signal amplitude
|
||||
(PTP). Valid **keys** can be any channel type present in the object.
|
||||
The **values** are floats that set the maximum acceptable PTP. If the
|
||||
PTP is larger than this threshold, the segment will be annotated.
|
||||
If float, the maximum acceptable PTP is applied to all channels.
|
||||
flat : float | dict | None
|
||||
Annotate segments based on **minimum** peak-to-peak signal amplitude
|
||||
(PTP). Valid **keys** can be any channel type present in the object.
|
||||
The **values** are floats that set the minimum acceptable PTP. If the
|
||||
PTP is smaller than this threshold, the segment will be annotated.
|
||||
If float, the minimum acceptable PTP is applied to all channels.
|
||||
bad_percent : float
|
||||
The percentage of the time a channel can be above or below thresholds.
|
||||
Below this percentage, :class:`~mne.Annotations` are created.
|
||||
Above this percentage, the channel involved is returned in ``bads``. Note
|
||||
the returned ``bads`` are not automatically added to
|
||||
:class:`info['bads'] <mne.Info>`.
|
||||
Defaults to ``5``, i.e. 5%%.
|
||||
min_duration : float
|
||||
The minimum duration (s) required by consecutive samples to be above
``peak`` or below ``flat`` thresholds to be considered.
|
||||
For some systems, adjacent time samples with exactly the same value are
|
||||
not totally uncommon. Defaults to ``0.005`` (5 ms).
|
||||
%(picks_good_data)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
annotations : instance of Annotations
|
||||
The annotated bad segments.
|
||||
bads : list
|
||||
The channels detected as bad.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function does not use a window to detect small peak-to-peak or large
|
||||
peak-to-peak amplitude changes as the ``reject`` and ``flat`` arguments of
:class:`~mne.Epochs` do. Instead, it looks at the difference between
|
||||
consecutive samples.
|
||||
|
||||
- When used to detect segments below ``flat``, at least ``min_duration``
|
||||
seconds of consecutive samples must respect
|
||||
``abs(a[i+1] - a[i]) ≤ flat``.
|
||||
- When used to detect segments above ``peak``, at least ``min_duration``
|
||||
seconds of consecutive samples must respect
|
||||
``abs(a[i+1] - a[i]) ≥ peak``.
|
||||
|
||||
Thus, this function does not detect every temporal event with large
|
||||
peak-to-peak amplitude, but only the ones where the peak-to-peak amplitude
|
||||
is supra-threshold between consecutive samples. For instance, segments
|
||||
experiencing a DC shift will not be picked up. Only the edges from the DC
|
||||
shift will be annotated (and those only if the edge transitions are longer
|
||||
than ``min_duration``).
|
||||
|
||||
This function may perform faster if data is loaded in memory, as it
|
||||
loads data one channel type at a time (across all time points), which is
|
||||
typically not an efficient way to read raw data from disk.
|
||||
|
||||
.. versionadded:: 1.0
|
||||
"""
|
||||
_validate_type(raw, BaseRaw, "raw")
|
||||
picks_ = _picks_to_idx(raw.info, picks, "data_or_ica", exclude="bads")
|
||||
peak = _check_ptp(peak, "peak", raw.info, picks_)
|
||||
flat = _check_ptp(flat, "flat", raw.info, picks_)
|
||||
if peak is None and flat is None:
|
||||
raise ValueError(
|
||||
"At least one of the arguments 'peak' or 'flat' must not be None."
|
||||
)
|
||||
bad_percent = _check_bad_percent(bad_percent)
|
||||
min_duration = _check_min_duration(
|
||||
min_duration, raw.times.size / raw.info["sfreq"]
|
||||
)
|
||||
min_duration_samples = int(np.round(min_duration * raw.info["sfreq"]))
|
||||
bads = list()
|
||||
|
||||
# grouping picks by channel types to avoid operating on each channel
|
||||
# individually
|
||||
picks = {
|
||||
ch_type: np.intersect1d(picks_of_type, picks_, assume_unique=True)
|
||||
for ch_type, picks_of_type in _picks_by_type(raw.info, exclude="bads")
|
||||
if np.intersect1d(picks_of_type, picks_, assume_unique=True).size != 0
|
||||
}
|
||||
del picks_ # reusing this variable name in for loop
|
||||
|
||||
# skip BAD_acq_skip sections
|
||||
onsets, ends = _annotations_starts_stops(raw, "bad_acq_skip", invert=True)
|
||||
index = np.concatenate(
|
||||
[np.arange(raw.times.size)[onset:end] for onset, end in zip(onsets, ends)]
|
||||
)
|
||||
|
||||
# size matching the diff a[i+1] - a[i]
|
||||
any_flat = np.zeros(len(raw.times) - 1, bool)
|
||||
any_peak = np.zeros(len(raw.times) - 1, bool)
|
||||
|
||||
# look for discrete difference above or below thresholds
|
||||
logger.info("Finding segments below or above PTP threshold.")
|
||||
for ch_type, picks_ in picks.items():
|
||||
data = np.concatenate(
|
||||
[raw[picks_, onset:end][0] for onset, end in zip(onsets, ends)], axis=1
|
||||
)
|
||||
diff = np.abs(np.diff(data, axis=1))
|
||||
|
||||
if flat is not None:
|
||||
flat_ = diff <= flat[ch_type]
|
||||
# reject too short segments
|
||||
flat_ = _reject_short_segments(flat_, min_duration_samples)
|
||||
# reject channels above maximum bad_percentage
|
||||
flat_count = flat_.sum(axis=1)
|
||||
flat_count[np.nonzero(flat_count)] += 1 # offset by 1 due to diff
|
||||
flat_mean = flat_count / raw.times.size * 100
|
||||
flat_ch_to_set_bad = picks_[np.where(flat_mean >= bad_percent)[0]]
|
||||
bads.extend(flat_ch_to_set_bad)
|
||||
# add onset/offset for annotations
|
||||
flat_ch_to_annotate = np.where((0 < flat_mean) & (flat_mean < bad_percent))[
|
||||
0
|
||||
]
|
||||
# convert from raw.times[onset:end] - 1 to raw.times[:] - 1
|
||||
idx = index[np.where(flat_[flat_ch_to_annotate, :])[1]]
|
||||
any_flat[idx] = True
|
||||
|
||||
if peak is not None:
|
||||
peak_ = diff >= peak[ch_type]
|
||||
# reject too short segments
|
||||
peak_ = _reject_short_segments(peak_, min_duration_samples)
|
||||
# reject channels above maximum bad_percentage
|
||||
peak_count = peak_.sum(axis=1)
|
||||
peak_count[np.nonzero(peak_count)] += 1 # offset by 1 due to diff
|
||||
peak_mean = peak_count / raw.times.size * 100
|
||||
peak_ch_to_set_bad = picks_[np.where(peak_mean >= bad_percent)[0]]
|
||||
bads.extend(peak_ch_to_set_bad)
|
||||
# add onset/offset for annotations
|
||||
peak_ch_to_annotate = np.where((0 < peak_mean) & (peak_mean < bad_percent))[
|
||||
0
|
||||
]
|
||||
# convert from raw.times[onset:end] - 1 to raw.times[:] - 1
|
||||
idx = index[np.where(peak_[peak_ch_to_annotate, :])[1]]
|
||||
any_peak[idx] = True
|
||||
|
||||
# annotation for flat
|
||||
annotation_flat = _create_annotations(any_flat, "flat", raw)
|
||||
# annotation for peak
|
||||
annotation_peak = _create_annotations(any_peak, "peak", raw)
|
||||
# group
|
||||
annotations = annotation_flat + annotation_peak
|
||||
# bads
|
||||
bads = [raw.ch_names[bad] for bad in bads if bad not in raw.info["bads"]]
|
||||
|
||||
return annotations, bads
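A minimal usage sketch (reviewer note, not part of the committed file): as the docstring says, the returned annotations and bad channels must be applied manually. The file name and EEG thresholds are arbitrary illustration values.

    from mne.io import read_raw_fif
    from mne.preprocessing import annotate_amplitude

    raw = read_raw_fif("sample_raw.fif", preload=True)  # hypothetical recording
    annotations, bads = annotate_amplitude(
        raw,
        peak=dict(eeg=200e-6),
        flat=dict(eeg=1e-6),
        bad_percent=5,
        min_duration=0.005,
    )
    raw.set_annotations(annotations)  # or combine with raw.annotations first
    raw.info["bads"].extend(bads)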
|
||||
|
||||
|
||||
def _check_ptp(ptp, name, info, picks):
|
||||
"""Check the PTP threhsold argument, and converts it to dict if needed."""
|
||||
_validate_type(ptp, ("numeric", dict, None))
|
||||
|
||||
if ptp is not None and not isinstance(ptp, dict):
|
||||
if ptp < 0:
|
||||
raise ValueError(
|
||||
f"Argument '{name}' should define a positive threshold. "
|
||||
f"Provided: '{ptp}'."
|
||||
)
|
||||
ch_types = set(info.get_channel_types(picks))
|
||||
ptp = {ch_type: ptp for ch_type in ch_types}
|
||||
elif isinstance(ptp, dict):
|
||||
for key, value in ptp.items():
|
||||
if value < 0:
|
||||
raise ValueError(
|
||||
f"Argument '{name}' should define positive thresholds. "
|
||||
f"Provided for channel type '{key}': '{value}'."
|
||||
)
|
||||
return ptp
|
||||
|
||||
|
||||
def _check_bad_percent(bad_percent):
|
||||
"""Check that bad_percent is a valid percentage and converts to float."""
|
||||
_validate_type(bad_percent, "numeric", "bad_percent")
|
||||
bad_percent = float(bad_percent)
|
||||
if not 0 <= bad_percent <= 100:
|
||||
raise ValueError(
|
||||
"Argument 'bad_percent' should define a percentage between 0% "
|
||||
f"and 100%. Provided: {bad_percent}%."
|
||||
)
|
||||
return bad_percent
|
||||
|
||||
|
||||
def _check_min_duration(min_duration, raw_duration):
|
||||
"""Check that min_duration is a valid duration and converts to float."""
|
||||
_validate_type(min_duration, "numeric", "min_duration")
|
||||
min_duration = float(min_duration)
|
||||
if min_duration < 0:
|
||||
raise ValueError(
|
||||
"Argument 'min_duration' should define a positive duration in "
|
||||
f"seconds. Provided: '{min_duration}' seconds."
|
||||
)
|
||||
if min_duration >= raw_duration:
|
||||
raise ValueError(
|
||||
"Argument 'min_duration' should define a positive duration in "
|
||||
f"seconds shorter than the raw duration ({raw_duration} seconds). "
|
||||
f"Provided: '{min_duration}' seconds."
|
||||
)
|
||||
return min_duration
|
||||
|
||||
|
||||
def _reject_short_segments(arr, min_duration_samples):
|
||||
"""Check if flat or peak segments are longer than the minimum duration."""
|
||||
assert arr.dtype == np.dtype(bool) and arr.ndim == 2
|
||||
for k, ch in enumerate(arr):
|
||||
onsets, offsets = _mask_to_onsets_offsets(ch)
|
||||
_mark_inner(arr[k], onsets, offsets, min_duration_samples)
|
||||
return arr
|
||||
|
||||
|
||||
@jit()
|
||||
def _mark_inner(arr_k, onsets, offsets, min_duration_samples):
|
||||
"""Inner loop of _reject_short_segments()."""
|
||||
for start, stop in zip(onsets, offsets):
|
||||
if stop - start < min_duration_samples:
|
||||
arr_k[start:stop] = False
|
||||
|
||||
|
||||
def _create_annotations(any_arr, kind, raw):
|
||||
"""Create the peak of flat annotations from the any_arr."""
|
||||
assert kind in ("peak", "flat")
|
||||
starts, stops = _mask_to_onsets_offsets(any_arr)
|
||||
starts, stops = np.array(starts), np.array(stops)
|
||||
onsets = starts / raw.info["sfreq"]
|
||||
durations = (stops - starts) / raw.info["sfreq"]
|
||||
annot = Annotations(
|
||||
onsets,
|
||||
durations,
|
||||
[f"BAD_{kind}"] * len(onsets),
|
||||
orig_time=raw.info["meas_date"],
|
||||
)
|
||||
_adjust_onset_meas_date(annot, raw)
|
||||
return annot
|
||||
38  mne/preprocessing/_annotate_nan.py  Normal file
@@ -0,0 +1,38 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np

from ..annotations import Annotations, _adjust_onset_meas_date
from ..utils import verbose
from .artifact_detection import _annotations_from_mask


@verbose
def annotate_nan(raw, *, verbose=None):
    """Detect segments with NaN and return a new Annotations instance.

    Parameters
    ----------
    raw : instance of Raw
        Data to find segments with NaN values.
    %(verbose)s

    Returns
    -------
    annot : instance of Annotations
        New channel-specific annotations for the data.
    """
    data, times = raw.get_data(return_times=True)
    onsets, durations, ch_names = list(), list(), list()
    for row, ch_name in zip(data, raw.ch_names):
        annot = _annotations_from_mask(times, np.isnan(row), "BAD_NAN")
        onsets.extend(annot.onset)
        durations.extend(annot.duration)
        ch_names.extend([[ch_name]] * len(annot))
    annot = Annotations(
        onsets, durations, "BAD_NAN", ch_names=ch_names, orig_time=raw.info["meas_date"]
    )
    _adjust_onset_meas_date(annot, raw)
    return annot
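A short usage sketch (illustrative; ``raw`` is assumed to be an existing ``mne.io.Raw``): NaN spans become channel-specific ``BAD_NAN`` annotations that downstream rejection can honor.

    from mne.preprocessing import annotate_nan

    annot_nan = annotate_nan(raw)
    raw.set_annotations(annot_nan)  # or add to existing annotations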
323  mne/preprocessing/_csd.py  Normal file
@@ -0,0 +1,323 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
# Copyright 2003-2010 Jürgen Kayser <rjk23@columbia.edu>
|
||||
#
|
||||
# The original CSD Toolbox can be found at
|
||||
# http://psychophysiology.cpmc.columbia.edu/Software/CSDtoolbox/
|
||||
#
|
||||
# Relicensed under BSD-3-Clause and adapted with permission from authors of original GPL
|
||||
# code.
|
||||
|
||||
import numpy as np
|
||||
from scipy.optimize import minimize_scalar
|
||||
from scipy.stats import gaussian_kde
|
||||
|
||||
from .._fiff.constants import FIFF
|
||||
from .._fiff.pick import pick_types
|
||||
from ..bem import fit_sphere_to_headshape
|
||||
from ..channels.interpolation import _calc_g, _calc_h
|
||||
from ..epochs import BaseEpochs, make_fixed_length_epochs
|
||||
from ..evoked import Evoked
|
||||
from ..io import BaseRaw
|
||||
from ..utils import _check_preload, _ensure_int, _validate_type, logger, verbose
|
||||
|
||||
|
||||
def _prepare_G(G, lambda2):
|
||||
G.flat[:: len(G) + 1] += lambda2
|
||||
# compute the CSD
|
||||
Gi = np.linalg.inv(G)
|
||||
|
||||
TC = Gi.sum(0)
|
||||
sgi = np.sum(TC) # compute sum total
|
||||
|
||||
return Gi, TC, sgi
|
||||
|
||||
|
||||
def _compute_csd(G_precomputed, H, radius):
|
||||
"""Compute the CSD."""
|
||||
n_channels = H.shape[0]
|
||||
data = np.eye(n_channels)
|
||||
mu = data.mean(0)
|
||||
Z = data - mu
|
||||
|
||||
Gi, TC, sgi = G_precomputed
|
||||
|
||||
Cp2 = np.dot(Gi, Z)
|
||||
c02 = np.sum(Cp2, axis=0) / sgi
|
||||
C2 = Cp2 - np.dot(TC[:, np.newaxis], c02[np.newaxis, :])
|
||||
X = np.dot(C2.T, H).T / radius**2
|
||||
return X
|
||||
|
||||
|
||||
@verbose
|
||||
def compute_current_source_density(
|
||||
inst,
|
||||
sphere="auto",
|
||||
lambda2=1e-5,
|
||||
stiffness=4,
|
||||
n_legendre_terms=50,
|
||||
copy=True,
|
||||
*,
|
||||
verbose=None,
|
||||
):
|
||||
"""Get the current source density (CSD) transformation.
|
||||
|
||||
Transformation based on spherical spline surface Laplacian
|
||||
:footcite:`PerrinEtAl1987,PerrinEtAl1989,Cohen2014,KayserTenke2015`.
|
||||
|
||||
This function can be used to re-reference the signal using a Laplacian
|
||||
(LAP) "reference-free" transformation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : instance of Raw, Epochs or Evoked
|
||||
The data to be transformed.
|
||||
sphere : array-like, shape (4,) | str
|
||||
The sphere, head-model of the form (x, y, z, r) where x, y, z
|
||||
is the center of the sphere and r is the radius in meters.
|
||||
Can also be "auto" to use a digitization-based fit.
|
||||
lambda2 : float
|
||||
Regularization parameter, produces smoothness. Defaults to 1e-5.
|
||||
stiffness : float
|
||||
Stiffness of the spline.
|
||||
n_legendre_terms : int
|
||||
Number of Legendre terms to evaluate.
|
||||
copy : bool
|
||||
Whether to overwrite instance data or create a copy.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
inst_csd : instance of Raw, Epochs or Evoked
|
||||
The transformed data. Output type will match input type.
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 0.20
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
_validate_type(inst, (BaseEpochs, BaseRaw, Evoked), "inst")
|
||||
_check_preload(inst, "Computing CSD")
|
||||
|
||||
if inst.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_CSD:
|
||||
raise ValueError("CSD already applied, should not be reapplied")
|
||||
|
||||
_validate_type(copy, (bool), "copy")
|
||||
inst = inst.copy() if copy else inst
|
||||
|
||||
picks = pick_types(inst.info, meg=False, eeg=True, exclude=[])
|
||||
|
||||
if any([ch in np.array(inst.ch_names)[picks] for ch in inst.info["bads"]]):
|
||||
raise ValueError(
|
||||
"CSD cannot be computed with bad EEG channels. Either"
|
||||
" drop (inst.drop_channels(inst.info['bads']) "
|
||||
"or interpolate (`inst.interpolate_bads()`) "
|
||||
"bad EEG channels."
|
||||
)
|
||||
|
||||
if len(picks) == 0:
|
||||
raise ValueError("No EEG channels found.")
|
||||
|
||||
_validate_type(lambda2, "numeric", "lambda2")
|
||||
if not 0 <= lambda2 < 1:
|
||||
raise ValueError(f"lambda2 must be between 0 and 1, got {lambda2}")
|
||||
|
||||
_validate_type(stiffness, "numeric", "stiffness")
|
||||
if stiffness < 0:
|
||||
raise ValueError(f"stiffness must be non-negative got {stiffness}")
|
||||
|
||||
n_legendre_terms = _ensure_int(n_legendre_terms, "n_legendre_terms")
|
||||
if n_legendre_terms < 1:
|
||||
raise ValueError(
|
||||
f"n_legendre_terms must be greater than 0, got {n_legendre_terms}"
|
||||
)
|
||||
|
||||
if isinstance(sphere, str) and sphere == "auto":
|
||||
radius, origin_head, origin_device = fit_sphere_to_headshape(inst.info)
|
||||
x, y, z = origin_head - origin_device
|
||||
sphere = (x, y, z, radius)
|
||||
try:
|
||||
sphere = np.array(sphere, float)
|
||||
x, y, z, radius = sphere
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f'sphere must be "auto" or array-like with shape (4,), got {sphere}'
|
||||
)
|
||||
_validate_type(x, "numeric", "x")
|
||||
_validate_type(y, "numeric", "y")
|
||||
_validate_type(z, "numeric", "z")
|
||||
_validate_type(radius, "numeric", "radius")
|
||||
if radius <= 0:
|
||||
raise ValueError("sphere radius must be greater than 0, got {radius}")
|
||||
|
||||
pos = np.array([inst.info["chs"][pick]["loc"][:3] for pick in picks])
|
||||
if not np.isfinite(pos).all() or np.isclose(pos, 0.0).all(1).any():
|
||||
raise ValueError("Zero or infinite position found in chs")
|
||||
pos -= (x, y, z)
|
||||
|
||||
# Project onto a unit sphere to compute the cosine similarity:
|
||||
pos /= np.linalg.norm(pos, axis=1, keepdims=True)
|
||||
cos_dist = np.clip(np.dot(pos, pos.T), -1, 1)
|
||||
# This is equivalent to doing one minus half the squared Euclidean:
|
||||
# from scipy.spatial.distance import squareform, pdist
|
||||
# cos_dist = 1 - squareform(pdist(pos, 'sqeuclidean')) / 2.
|
||||
del pos
|
||||
|
||||
G = _calc_g(cos_dist, stiffness=stiffness, n_legendre_terms=n_legendre_terms)
|
||||
H = _calc_h(cos_dist, stiffness=stiffness, n_legendre_terms=n_legendre_terms)
|
||||
|
||||
G_precomputed = _prepare_G(G, lambda2)
|
||||
|
||||
trans_csd = _compute_csd(G_precomputed=G_precomputed, H=H, radius=radius)
|
||||
|
||||
epochs = [inst._data] if not isinstance(inst, BaseEpochs) else inst._data
|
||||
for epo in epochs:
|
||||
epo[picks] = np.dot(trans_csd, epo[picks])
|
||||
with inst.info._unlock():
|
||||
inst.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_CSD
|
||||
for pick in picks:
|
||||
inst.info["chs"][pick].update(
|
||||
coil_type=FIFF.FIFFV_COIL_EEG_CSD, unit=FIFF.FIFF_UNIT_V_M2
|
||||
)
|
||||
|
||||
# Remove rejection thresholds for EEG
|
||||
if isinstance(inst, BaseEpochs):
|
||||
if inst.reject and "eeg" in inst.reject:
|
||||
del inst.reject["eeg"]
|
||||
if inst.flat and "eeg" in inst.flat:
|
||||
del inst.flat["eeg"]
|
||||
|
||||
return inst
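A usage sketch (not part of the commit), assuming ``evoked`` is an existing ``mne.Evoked`` with EEG channel positions set and no bad EEG channels:

    from mne.preprocessing import compute_current_source_density

    evoked_csd = compute_current_source_density(evoked, sphere="auto", lambda2=1e-5)
    # The returned copy has coil type FIFFV_COIL_EEG_CSD and units of V/m²,
    # so avoid mixing CSD and non-CSD channels in later analyses.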
|
||||
|
||||
|
||||
@verbose
|
||||
def compute_bridged_electrodes(
|
||||
inst,
|
||||
lm_cutoff=16,
|
||||
epoch_threshold=0.5,
|
||||
l_freq=0.5,
|
||||
h_freq=30,
|
||||
epoch_duration=2,
|
||||
bw_method=None,
|
||||
verbose=None,
|
||||
):
|
||||
r"""Compute bridged EEG electrodes using the intrinsic Hjorth algorithm.
|
||||
|
||||
First, an electrical distance matrix is computed by taking the pairwise
|
||||
variance between electrodes. Local minimums in this matrix below
|
||||
``lm_cutoff`` are indicative of bridging between a pair of electrodes.
|
||||
Pairs of electrodes are marked as bridged as long as their electrical
|
||||
distance is below ``lm_cutoff`` on more than the ``epoch_threshold``
|
||||
proportion of epochs.
|
||||
|
||||
Based on :footcite:`TenkeKayser2001,GreischarEtAl2004,DelormeMakeig2004`
|
||||
and the `EEGLAB implementation
|
||||
<https://psychophysiology.cpmc.columbia.edu/>`__.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : instance of Raw, Epochs or Evoked
|
||||
The data to compute electrode bridging on.
|
||||
lm_cutoff : float
|
||||
The distance in :math:`{\mu}V^2` cutoff below which to
|
||||
search for a local minimum (lm) indicative of bridging.
|
||||
EEGLAB defaults to 5 :math:`{\mu}V^2`. MNE defaults to
|
||||
16 :math:`{\mu}V^2` to be conservative based on the distributions in
|
||||
:footcite:t:`GreischarEtAl2004`.
|
||||
epoch_threshold : float
|
||||
The proportion of epochs with electrical distance less than
|
||||
``lm_cutoff`` in order to consider the channel bridged.
|
||||
The default is 0.5.
|
||||
l_freq : float
|
||||
The low cutoff frequency to use. Default is 0.5 Hz.
|
||||
h_freq : float
|
||||
The high cutoff frequency to use. Default is 30 Hz.
|
||||
epoch_duration : float
|
||||
The time in seconds to divide the raw into fixed-length epochs
|
||||
to check for consistent bridging. Only used if ``inst`` is
|
||||
:class:`mne.io.BaseRaw`. The default is 2 seconds.
|
||||
bw_method : None
|
||||
``bw_method`` to pass to :class:`scipy.stats.gaussian_kde`.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
bridged_idx : list of tuple
|
||||
The indices of channels marked as bridged with each bridged
|
||||
pair stored as a tuple.
|
||||
ed_matrix : ndarray of float, shape (n_epochs, n_channels, n_channels)
|
||||
The electrical distance matrix for each pair of EEG electrodes.
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 1.1
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
_check_preload(inst, "Computing bridged electrodes")
|
||||
inst = inst.copy() # don't modify original
|
||||
picks = pick_types(inst.info, eeg=True)
|
||||
if len(picks) == 0:
|
||||
raise RuntimeError("No EEG channels found, cannot compute electrode bridging")
|
||||
# first, filter
|
||||
inst.filter(l_freq=l_freq, h_freq=h_freq, picks=picks, verbose=False)
|
||||
|
||||
if isinstance(inst, BaseRaw):
|
||||
inst = make_fixed_length_epochs(
|
||||
inst, duration=epoch_duration, preload=True, verbose=False
|
||||
)
|
||||
|
||||
# standardize shape
|
||||
data = inst.get_data(picks=picks)
|
||||
if isinstance(inst, Evoked):
|
||||
data = data[np.newaxis, ...] # expand evoked
|
||||
|
||||
# next, compute electrical distance matrix, upper triangular
|
||||
n_epochs = data.shape[0]
|
||||
ed_matrix = np.zeros((n_epochs, picks.size, picks.size)) * np.nan
|
||||
for i in range(picks.size):
|
||||
for j in range(i + 1, picks.size):
|
||||
ed_matrix[:, i, j] = np.var(data[:, i] - data[:, j], axis=1)
|
||||
|
||||
# scale; the lower half and diagonal of the matrix remain NaN
|
||||
ed_matrix *= 1e12 # scale to muV**2
|
||||
|
||||
# initialize bridged indices
|
||||
bridged_idx = list()
|
||||
|
||||
# if not enough values below local minimum cutoff, return no bridges
|
||||
ed_flat = ed_matrix[~np.isnan(ed_matrix)]
|
||||
if ed_flat[ed_flat < lm_cutoff].size / n_epochs < epoch_threshold:
|
||||
return bridged_idx, ed_matrix
|
||||
|
||||
# kernel density estimation
|
||||
kde = gaussian_kde(ed_flat[ed_flat < lm_cutoff], bw_method=bw_method)
|
||||
with np.errstate(invalid="ignore"):
|
||||
local_minimum = float(
|
||||
minimize_scalar(
|
||||
lambda x: kde(x) if x < lm_cutoff and x > 0 else np.inf
|
||||
).x.item()
|
||||
)
|
||||
logger.info(f"Local minimum {local_minimum} found")
|
||||
|
||||
# find electrodes that are below the cutoff local minimum on
|
||||
# `epochs_threshold` proportion of epochs
|
||||
for i in range(picks.size):
|
||||
for j in range(i + 1, picks.size):
|
||||
bridged_count = np.sum(ed_matrix[:, i, j] < local_minimum)
|
||||
if bridged_count / n_epochs > epoch_threshold:
|
||||
logger.info(
|
||||
"Bridge detected between "
|
||||
f"{inst.ch_names[picks[i]]} and "
|
||||
f"{inst.ch_names[picks[j]]}"
|
||||
)
|
||||
bridged_idx.append((picks[i], picks[j]))
|
||||
|
||||
return bridged_idx, ed_matrix
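A follow-up sketch (illustrative): detect bridges on a preloaded EEG Raw and repair them with ``interpolate_bridged_electrodes``, which this commit exposes from ``mne.preprocessing``.

    from mne.preprocessing import (
        compute_bridged_electrodes,
        interpolate_bridged_electrodes,
    )

    bridged_idx, ed_matrix = compute_bridged_electrodes(raw)
    if bridged_idx:
        raw = interpolate_bridged_electrodes(raw.copy(), bridged_idx)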
|
||||
95  mne/preprocessing/_css.py  Normal file
@@ -0,0 +1,95 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np

from .._fiff.pick import _picks_to_idx
from ..evoked import Evoked
from ..utils import _ensure_int, _validate_type, verbose


def _temp_proj(ref_2, ref_1, raw_data, n_proj=6):
    # Orthonormalize gradiometer and magnetometer data by a QR decomposition
    ref_1_orth = np.linalg.qr(ref_1.T)[0]
    ref_2_orth = np.linalg.qr(ref_2.T)[0]

    # Calculate cross-correlation
    cross_corr = np.dot(ref_1_orth.T, ref_2_orth)

    # Channel weights for common temporal subspace by SVD of cross-correlation
    ref_1_ch_weights, _, _ = np.linalg.svd(cross_corr)

    # Get temporal signals from channel weights
    proj_mat = ref_1_orth @ ref_1_ch_weights

    # Project out common subspace
    filtered_data = raw_data
    proj_vec = proj_mat[:, :n_proj]
    weights = filtered_data @ proj_vec
    filtered_data -= weights @ proj_vec.T


@verbose
def cortical_signal_suppression(
    evoked, picks=None, mag_picks=None, grad_picks=None, n_proj=6, *, verbose=None
):
    """Apply cortical signal suppression (CSS) to evoked data.

    Parameters
    ----------
    evoked : instance of Evoked
        The evoked object to use for CSS. Must contain magnetometer,
        gradiometer, and EEG channels.
    %(picks_good_data)s
    mag_picks : array-like of int
        Array of the first set of channel indices that will be used to find
        the common temporal subspace. If None (default), all magnetometers will
        be used.
    grad_picks : array-like of int
        Array of the second set of channel indices that will be used to find
        the common temporal subspace. If None (default), all gradiometers will
        be used.
    n_proj : int
        The number of projection vectors.
    %(verbose)s

    Returns
    -------
    evoked_subcortical : instance of Evoked
        The evoked object with contributions from the ``mag_picks`` and ``grad_picks``
        channels removed from the ``picks`` channels.

    Notes
    -----
    This method removes the common signal subspace between two sets of
    channels (``mag_picks`` and ``grad_picks``) from a set of channels
    (``picks``) via a temporal projection using ``n_proj`` number of
    projection vectors. In the reference publication :footcite:`Samuelsson2019`,
    the joint subspace between magnetometers and gradiometers is used to
    suppress the cortical signal in the EEG data. In principle, other
    combinations of sensor types (or channels) could be used to suppress
    signals from other sources.

    References
    ----------
    .. footbibliography::
    """
    _validate_type(evoked, Evoked, "evoked")
    n_proj = _ensure_int(n_proj, "n_proj")
    picks = _picks_to_idx(evoked.info, picks, none="data", exclude="bads")
    mag_picks = _picks_to_idx(evoked.info, mag_picks, none="mag", exclude="bads")
    grad_picks = _picks_to_idx(evoked.info, grad_picks, none="grad", exclude="bads")
    evoked_subcortical = evoked.copy()

    # Get data
    all_data = evoked.data
    mag_data = all_data[mag_picks]
    grad_data = all_data[grad_picks]

    # Process data with temporal projection algorithm
    data = all_data[picks]
    _temp_proj(mag_data, grad_data, data, n_proj=n_proj)
    evoked_subcortical.data[picks, :] = data

    return evoked_subcortical
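A brief sketch (illustrative only), assuming ``evoked`` contains magnetometer, gradiometer, and EEG channels; the EEG signal explained by the joint MEG subspace is projected out:

    from mne.preprocessing import cortical_signal_suppression

    evoked_sub = cortical_signal_suppression(evoked, n_proj=6)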
597  mne/preprocessing/_fine_cal.py  Normal file
@@ -0,0 +1,597 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from scipy.optimize import minimize
|
||||
|
||||
from .._fiff.pick import pick_info, pick_types
|
||||
from .._fiff.tag import _coil_trans_to_loc, _loc_to_coil_trans
|
||||
from ..bem import _check_origin
|
||||
from ..io import BaseRaw
|
||||
from ..transforms import _find_vector_rotation
|
||||
from ..utils import (
|
||||
_check_fname,
|
||||
_check_option,
|
||||
_clean_names,
|
||||
_ensure_int,
|
||||
_pl,
|
||||
_reg_pinv,
|
||||
_validate_type,
|
||||
check_fname,
|
||||
logger,
|
||||
verbose,
|
||||
)
|
||||
from .maxwell import (
|
||||
_col_norm_pinv,
|
||||
_get_grad_point_coilsets,
|
||||
_prep_fine_cal,
|
||||
_prep_mf_coils,
|
||||
_read_cross_talk,
|
||||
_trans_sss_basis,
|
||||
)
|
||||
|
||||
|
||||
@verbose
|
||||
def compute_fine_calibration(
|
||||
raw,
|
||||
n_imbalance=3,
|
||||
t_window=10.0,
|
||||
ext_order=2,
|
||||
origin=(0.0, 0.0, 0.0),
|
||||
cross_talk=None,
|
||||
calibration=None,
|
||||
*,
|
||||
angle_limit=5.0,
|
||||
err_limit=5.0,
|
||||
verbose=None,
|
||||
):
|
||||
"""Compute fine calibration from empty-room data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data to use. Should be from an empty-room recording,
|
||||
and all channels should be good.
|
||||
n_imbalance : int
|
||||
Can be 1 or 3 (default), indicating the number of gradiometer
|
||||
imbalance components. Only used if gradiometers are present.
|
||||
t_window : float
|
||||
Time window to use for surface normal rotation in seconds.
|
||||
Default is 10.
|
||||
%(ext_order_maxwell)s
|
||||
Default is 2, which is lower than the default (3) for
|
||||
:func:`mne.preprocessing.maxwell_filter` because it tends to yield
|
||||
more stable parameter estimates.
|
||||
%(origin_maxwell)s
|
||||
%(cross_talk_maxwell)s
|
||||
calibration : dict | None
|
||||
Dictionary with existing calibration. If provided, the magnetometer
|
||||
imbalances and adjusted normals will be used and only the gradiometer
|
||||
imbalances will be estimated (see step 2 in Notes below).
|
||||
angle_limit : float
|
||||
The maximum permitted angle in degrees between the original and adjusted
|
||||
magnetometer normals. If the angle is exceeded, the segment is treated as
|
||||
an outlier and discarded.
|
||||
|
||||
.. versionadded:: 1.9
|
||||
err_limit : float
|
||||
The maximum error (in percent) for each channel in order for a segment to
|
||||
be used.
|
||||
|
||||
.. versionadded:: 1.9
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
calibration : dict
|
||||
Fine calibration data.
|
||||
count : int
|
||||
The number of good segments used to compute the magnetometer
|
||||
parameters.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.preprocessing.maxwell_filter
|
||||
|
||||
Notes
|
||||
-----
|
||||
This algorithm proceeds in two steps, both optimizing the fit between the
|
||||
data and a reconstruction of the data based only on an external multipole
|
||||
expansion:
|
||||
|
||||
1. Estimate magnetometer normal directions and scale factors. All
|
||||
coils (mag and matching grad) are rotated by the adjusted normal
|
||||
direction.
|
||||
2. Estimate gradiometer imbalance factors. These add point magnetometers
|
||||
in just the gradiometer difference direction or in all three directions
|
||||
(depending on ``n_imbalance``).
|
||||
|
||||
Magnetometer normal and coefficient estimation (1) is typically the most
|
||||
time consuming step. Gradiometer imbalance parameters (2) can be
|
||||
iteratively reestimated (for example, first using ``n_imbalance=1`` then
|
||||
subsequently ``n_imbalance=3``) by passing the previous ``calibration``
|
||||
output to the ``calibration`` input in the second call.
|
||||
|
||||
MaxFilter processes at most 120 seconds of data, so consider cropping
|
||||
your raw instance prior to processing. It also checks to make sure that
|
||||
there were some minimal usable ``count`` number of segments (default 5)
|
||||
that were included in the estimate.
|
||||
|
||||
.. versionadded:: 0.21
|
||||
"""
|
||||
n_imbalance = _ensure_int(n_imbalance, "n_imbalance")
|
||||
_check_option("n_imbalance", n_imbalance, (1, 3))
|
||||
_validate_type(raw, BaseRaw, "raw")
|
||||
ext_order = _ensure_int(ext_order, "ext_order")
|
||||
origin = _check_origin(origin, raw.info, "meg", disp=True)
|
||||
_check_option("raw.info['bads']", raw.info["bads"], ([],))
|
||||
_validate_type(err_limit, "numeric", "err_limit")
|
||||
_validate_type(angle_limit, "numeric", "angle_limit")
|
||||
for key, val in dict(err_limit=err_limit, angle_limit=angle_limit).items():
|
||||
if val < 0:
|
||||
raise ValueError(f"{key} must be greater than or equal to 0, got {val}")
|
||||
# Fine cal should not include ref channels
|
||||
picks = pick_types(raw.info, meg=True, ref_meg=False)
|
||||
if raw.info["dev_head_t"] is not None:
|
||||
raise ValueError(
|
||||
'info["dev_head_t"] is not None, suggesting that the '
|
||||
"data are not from an empty-room recording"
|
||||
)
|
||||
|
||||
info = pick_info(raw.info, picks) # make a copy and pick MEG channels
|
||||
mag_picks = pick_types(info, meg="mag", exclude=())
|
||||
grad_picks = pick_types(info, meg="grad", exclude=())
|
||||
|
||||
# Get cross-talk
|
||||
ctc, _ = _read_cross_talk(cross_talk, info["ch_names"])
|
||||
|
||||
# Check fine cal
|
||||
_validate_type(calibration, (dict, None), "calibration")
|
||||
|
||||
#
|
||||
# 1. Rotate surface normals using magnetometer information (if present)
|
||||
#
|
||||
cals = np.ones(len(info["ch_names"]))
|
||||
time_idxs = raw.time_as_index(np.arange(0.0, raw.times[-1], t_window))
|
||||
if len(time_idxs) <= 1:
|
||||
time_idxs = np.array([0, len(raw.times)], int)
|
||||
else:
|
||||
time_idxs[-1] = len(raw.times)
|
||||
count = 0
|
||||
locs = np.array([ch["loc"] for ch in info["chs"]])
|
||||
zs = locs[mag_picks, -3:].copy()
|
||||
if calibration is not None:
|
||||
_, calibration, _ = _prep_fine_cal(info, calibration, ignore_ref=True)
|
||||
for pi, pick in enumerate(mag_picks):
|
||||
idx = calibration["ch_names"].index(info["ch_names"][pick])
|
||||
cals[pick] = calibration["imb_cals"][idx].item()
|
||||
zs[pi] = calibration["locs"][idx][-3:]
|
||||
elif len(mag_picks) > 0:
|
||||
cal_list = list()
|
||||
z_list = list()
|
||||
logger.info(
|
||||
f"Adjusting normals for {len(mag_picks)} magnetometers "
|
||||
f"(averaging over {len(time_idxs) - 1} time intervals)"
|
||||
)
|
||||
for start, stop in zip(time_idxs[:-1], time_idxs[1:]):
|
||||
logger.info(
|
||||
f" Processing interval {start / info['sfreq']:0.3f} - "
|
||||
f"{stop / info['sfreq']:0.3f} s"
|
||||
)
|
||||
data = raw[picks, start:stop][0]
|
||||
if ctc is not None:
|
||||
data = ctc.dot(data)
|
||||
z, cal, good = _adjust_mag_normals(
|
||||
info,
|
||||
data,
|
||||
origin,
|
||||
ext_order,
|
||||
angle_limit=angle_limit,
|
||||
err_limit=err_limit,
|
||||
)
|
||||
if good:
|
||||
z_list.append(z)
|
||||
cal_list.append(cal)
|
||||
count = len(cal_list)
|
||||
if count == 0:
|
||||
raise RuntimeError("No usable segments found")
|
||||
cals[:] = np.mean(cal_list, axis=0)
|
||||
zs[:] = np.mean(z_list, axis=0)
|
||||
if len(mag_picks) > 0:
|
||||
for ii, new_z in enumerate(zs):
|
||||
z_loc = locs[mag_picks[ii]]
|
||||
# Find sensors with same NZ and R0 (should be three for VV)
|
||||
idxs = _matched_loc_idx(z_loc, locs)
|
||||
# Rotate the direction vectors to the plane defined by new normal
|
||||
_rotate_locs(locs, idxs, new_z)
|
||||
for ci, loc in enumerate(locs):
|
||||
info["chs"][ci]["loc"][:] = loc
|
||||
del calibration, zs
|
||||
|
||||
#
|
||||
# 2. Estimate imbalance parameters (always done)
|
||||
#
|
||||
if len(grad_picks) > 0:
|
||||
extra = "X direction" if n_imbalance == 1 else ("XYZ directions")
|
||||
logger.info(f"Computing imbalance for {len(grad_picks)} gradimeters ({extra})")
|
||||
imb_list = list()
|
||||
for start, stop in zip(time_idxs[:-1], time_idxs[1:]):
|
||||
logger.info(
|
||||
f" Processing interval {start / info['sfreq']:0.3f} - "
|
||||
f"{stop / info['sfreq']:0.3f} s"
|
||||
)
|
||||
data = raw[picks, start:stop][0]
|
||||
if ctc is not None:
|
||||
data = ctc.dot(data)
|
||||
out = _estimate_imbalance(info, data, cals, n_imbalance, origin, ext_order)
|
||||
imb_list.append(out)
|
||||
imb = np.mean(imb_list, axis=0)
|
||||
else:
|
||||
imb = np.zeros((len(info["ch_names"]), n_imbalance))
|
||||
|
||||
#
|
||||
# Put in output structure
|
||||
#
|
||||
assert len(np.intersect1d(mag_picks, grad_picks)) == 0
|
||||
imb_cals = [
|
||||
cals[ii : ii + 1] if ii in mag_picks else imb[ii]
|
||||
for ii in range(len(info["ch_names"]))
|
||||
]
|
||||
ch_names = _clean_names(info["ch_names"], remove_whitespace=True)
|
||||
calibration = dict(ch_names=ch_names, locs=locs, imb_cals=imb_cals)
|
||||
return calibration, count
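The iterative re-estimation described in the Notes could look like this sketch, assuming ``raw_erm`` is a cropped empty-room recording with no bad channels, ``raw`` is the recording to filter, and the output path is hypothetical:

    from mne.preprocessing import (
        compute_fine_calibration,
        maxwell_filter,
        write_fine_calibration,
    )

    cal, count = compute_fine_calibration(raw_erm, n_imbalance=1, t_window=10.0)
    cal, count = compute_fine_calibration(
        raw_erm, n_imbalance=3, t_window=10.0, calibration=cal
    )
    write_fine_calibration("my_fine_cal.dat", cal)
    raw_sss = maxwell_filter(raw, calibration="my_fine_cal.dat")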
|
||||
|
||||
|
||||
def _matched_loc_idx(mag_loc, all_loc):
|
||||
return np.where(
|
||||
[
|
||||
np.allclose(mag_loc[-3:], loc[-3:]) and np.allclose(mag_loc[:3], loc[:3])
|
||||
for loc in all_loc
|
||||
]
|
||||
)[0]
|
||||
|
||||
|
||||
def _rotate_locs(locs, idxs, new_z):
|
||||
new_z = new_z / np.linalg.norm(new_z)
|
||||
old_z = locs[idxs[0]][-3:]
|
||||
old_z = old_z / np.linalg.norm(old_z)
|
||||
rot = _find_vector_rotation(old_z, new_z)
|
||||
for ci in idxs:
|
||||
this_trans = _loc_to_coil_trans(locs[ci])
|
||||
this_trans[:3, :3] = np.dot(rot, this_trans[:3, :3])
|
||||
locs[ci][:] = _coil_trans_to_loc(this_trans)
|
||||
np.testing.assert_allclose(locs[ci][-3:], new_z, atol=1e-4)
|
||||
|
||||
|
||||
def _vector_angle(x, y):
|
||||
"""Get the angle between two vectors in degrees."""
|
||||
return np.abs(
|
||||
np.arccos(
|
||||
np.clip(
|
||||
(x * y).sum(axis=-1)
|
||||
/ (np.linalg.norm(x, axis=-1) * np.linalg.norm(y, axis=-1)),
|
||||
-1,
|
||||
1.0,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _adjust_mag_normals(info, data, origin, ext_order, *, angle_limit, err_limit):
|
||||
"""Adjust coil normals using magnetometers and empty-room data."""
|
||||
# in principle we could allow using just mag or mag+grad, but MF uses
|
||||
# just mag so let's follow suit
|
||||
mag_scale = 100.0
|
||||
picks_use = pick_types(info, meg="mag", exclude="bads")
|
||||
picks_meg = pick_types(info, meg=True, exclude=())
|
||||
picks_mag_orig = pick_types(info, meg="mag", exclude="bads")
|
||||
info = pick_info(info, picks_use) # copy
|
||||
data = data[picks_use]
|
||||
cals = np.ones((len(data), 1))
|
||||
angles = np.zeros(len(cals))
|
||||
picks_mag = pick_types(info, meg="mag")
|
||||
data[picks_mag] *= mag_scale
|
||||
# Transform variables so we're only dealing with good mags
|
||||
exp = dict(int_order=0, ext_order=ext_order, origin=origin)
|
||||
all_coils = _prep_mf_coils(info, ignore_ref=True)
|
||||
S_tot = _trans_sss_basis(exp, all_coils, coil_scale=mag_scale)
|
||||
first_err = _data_err(data, S_tot, cals)
|
||||
count = 0
|
||||
# two passes: first do the worst, then do all in order
|
||||
zs = np.array([ch["loc"][-3:] for ch in info["chs"]])
|
||||
zs /= np.linalg.norm(zs, axis=-1, keepdims=True)
|
||||
orig_zs = zs.copy()
|
||||
match_idx = dict()
|
||||
locs = np.array([ch["loc"] for ch in info["chs"]])
|
||||
for pick in picks_mag:
|
||||
match_idx[pick] = _matched_loc_idx(locs[pick], locs)
|
||||
counts = defaultdict(lambda: 0)
|
||||
for ki, kind in enumerate(("worst first", "in order")):
|
||||
logger.info(f" Magnetometer normal adjustment ({kind}) ...")
|
||||
S_tot = _trans_sss_basis(exp, all_coils, coil_scale=mag_scale)
|
||||
for pick in picks_mag:
|
||||
err = _data_err(data, S_tot, cals, axis=1)
|
||||
|
||||
# First pass: do worst; second pass: do all in order (up to 3x/sensor)
|
||||
if ki == 0:
|
||||
order = list(np.argsort(err[picks_mag]))
|
||||
cal_idx = 0
|
||||
while len(order) > 0:
|
||||
cal_idx = picks_mag[order.pop(-1)]
|
||||
if counts[cal_idx] < 3:
|
||||
break
|
||||
if err[cal_idx] < 2.5:
|
||||
break # move on to second loop
|
||||
else:
|
||||
cal_idx = pick
|
||||
counts[cal_idx] += 1
|
||||
assert cal_idx in picks_mag
|
||||
count += 1
|
||||
old_z = zs[cal_idx].copy()
|
||||
objective = partial(
|
||||
_cal_sss_target,
|
||||
old_z=old_z,
|
||||
all_coils=all_coils,
|
||||
cal_idx=cal_idx,
|
||||
data=data,
|
||||
cals=cals,
|
||||
match_idx=match_idx,
|
||||
S_tot=S_tot,
|
||||
origin=origin,
|
||||
ext_order=ext_order,
|
||||
)
|
||||
|
||||
# Figure out the additive term for z-component
|
||||
zs[cal_idx] = minimize(
|
||||
objective,
|
||||
old_z,
|
||||
bounds=[(-2, 2)] * 3,
|
||||
# BFGS is the default for minimize but COBYLA converges faster
|
||||
method="COBYLA",
|
||||
# Start with a small relative step because nominal geometry information
|
||||
# should be fairly accurate to begin with
|
||||
options=dict(rhobeg=1e-1),
|
||||
).x
|
||||
|
||||
# Do in-place adjustment to all_coils
|
||||
cals[cal_idx] = 1.0 / np.linalg.norm(zs[cal_idx])
|
||||
zs[cal_idx] *= cals[cal_idx]
|
||||
for idx in match_idx[cal_idx]:
|
||||
_rotate_coil(zs[cal_idx], old_z, all_coils, idx, inplace=True)
|
||||
|
||||
# Recalculate S_tot, taking into account rotations
|
||||
S_tot = _trans_sss_basis(exp, all_coils)
|
||||
|
||||
# Report results
|
||||
old_err = err[cal_idx]
|
||||
new_err = _data_err(data, S_tot, cals, idx=cal_idx)
|
||||
angles[cal_idx] = np.abs(
|
||||
np.rad2deg(_vector_angle(zs[cal_idx], orig_zs[cal_idx]))
|
||||
)
|
||||
ch_name = info["ch_names"][cal_idx]
|
||||
logger.debug(
|
||||
f" Optimization step {count:3d} | "
|
||||
f"{ch_name} ({counts[cal_idx]}) | "
|
||||
f"res {old_err:5.2f}→{new_err:5.2f}% | "
|
||||
f"×{cals[cal_idx, 0]:0.3f} | {angles[cal_idx]:0.2f}°"
|
||||
)
|
||||
last_err = _data_err(data, S_tot, cals)
|
||||
# Chunk is usable if all angles and errors are both small
|
||||
reason = list()
|
||||
max_angle = np.max(angles)
|
||||
if max_angle >= angle_limit:
|
||||
reason.append(f"max angle {max_angle:0.2f} >= {angle_limit:0.1f}°")
|
||||
each_err = _data_err(data, S_tot, cals, axis=-1)[picks_mag]
|
||||
n_bad = (each_err > err_limit).sum()
|
||||
if n_bad:
|
||||
reason.append(
|
||||
f"{n_bad} residual{_pl(n_bad)} > {err_limit:0.1f}% "
|
||||
f"(max: {each_err.max():0.2f}%)"
|
||||
)
|
||||
reason = ", ".join(reason)
|
||||
if reason:
|
||||
reason = f" ({reason})"
|
||||
good = not bool(reason)
|
||||
assert np.allclose(np.linalg.norm(zs, axis=1), 1.0)
|
||||
logger.info(f" Fit mismatch {first_err:0.2f}→{last_err:0.2f}%")
|
||||
logger.info(f' Data segment {"" if good else "un"}usable{reason}')
|
||||
# Reformat zs and cals to be the n_mags (including bads)
|
||||
assert zs.shape == (len(data), 3)
|
||||
assert cals.shape == (len(data), 1)
|
||||
imb_cals = np.ones(len(picks_meg))
|
||||
imb_cals[picks_mag_orig] = cals[:, 0]
|
||||
return zs, imb_cals, good
|
||||
|
||||
|
||||
def _data_err(data, S_tot, cals, idx=None, axis=None):
|
||||
if idx is None:
|
||||
idx = slice(None)
|
||||
S_tot = S_tot / cals
|
||||
data_model = np.dot(np.dot(S_tot[idx], _col_norm_pinv(S_tot.copy())[0]), data)
|
||||
err = 100 * (
|
||||
np.linalg.norm(data_model - data[idx], axis=axis)
|
||||
/ np.linalg.norm(data[idx], axis=axis)
|
||||
)
|
||||
return err
|
||||
|
||||
|
||||
def _rotate_coil(new_z, old_z, all_coils, idx, inplace=False):
|
||||
"""Adjust coils."""
|
||||
# Turn NX and NY to the plane determined by NZ
|
||||
old_z = old_z / np.linalg.norm(old_z)
|
||||
new_z = new_z / np.linalg.norm(new_z)
|
||||
rot = _find_vector_rotation(old_z, new_z) # additional coil rotation
|
||||
this_sl = all_coils[5][idx]
|
||||
this_rmag = np.dot(rot, all_coils[0][this_sl].T).T
|
||||
this_cosmag = np.dot(rot, all_coils[1][this_sl].T).T
|
||||
if inplace:
|
||||
all_coils[0][this_sl] = this_rmag
|
||||
all_coils[1][this_sl] = this_cosmag
|
||||
subset = (
|
||||
this_rmag,
|
||||
this_cosmag,
|
||||
np.zeros(this_rmag.shape[0], int),
|
||||
1,
|
||||
all_coils[4][[idx]],
|
||||
{0: this_sl},
|
||||
)
|
||||
return subset
|
||||
|
||||
|
||||
def _cal_sss_target(
|
||||
new_z, old_z, all_coils, cal_idx, data, cals, S_tot, origin, ext_order, match_idx
|
||||
):
|
||||
"""Evaluate objective function for SSS-based magnetometer calibration."""
|
||||
cals[cal_idx] = 1.0 / np.linalg.norm(new_z)
|
||||
exp = dict(int_order=0, ext_order=ext_order, origin=origin)
|
||||
S_tot = S_tot.copy()
|
||||
# Rotate necessary coils properly and adjust correct element in c
|
||||
for idx in match_idx[cal_idx]:
|
||||
this_coil = _rotate_coil(new_z, old_z, all_coils, idx)
|
||||
# Replace correct row of S_tot with new value
|
||||
S_tot[idx] = _trans_sss_basis(exp, this_coil)
|
||||
# Get the GOF
|
||||
return _data_err(data, S_tot, cals, idx=cal_idx)
|
||||
|
||||
|
||||
def _estimate_imbalance(info, data, cals, n_imbalance, origin, ext_order):
|
||||
"""Estimate gradiometer imbalance parameters."""
|
||||
mag_scale = 100.0
|
||||
n_iterations = 3
|
||||
mag_picks = pick_types(info, meg="mag", exclude=())
|
||||
grad_picks = pick_types(info, meg="grad", exclude=())
|
||||
data = data.copy()
|
||||
data[mag_picks, :] *= mag_scale
|
||||
del mag_picks
|
||||
|
||||
grad_imb = np.zeros((len(grad_picks), n_imbalance))
|
||||
exp = dict(origin=origin, int_order=0, ext_order=ext_order)
|
||||
all_coils = _prep_mf_coils(info, ignore_ref=True)
|
||||
grad_point_coils = _get_grad_point_coilsets(info, n_imbalance, ignore_ref=True)
|
||||
S_orig = _trans_sss_basis(exp, all_coils, coil_scale=mag_scale)
|
||||
S_orig /= cals[:, np.newaxis]
|
||||
# Compute point gradiometers for each grad channel
|
||||
this_cs = np.array([mag_scale], float)
|
||||
S_pt = np.array(
|
||||
[_trans_sss_basis(exp, coils, None, this_cs) for coils in grad_point_coils]
|
||||
)
|
||||
for k in range(n_iterations):
|
||||
S_tot = S_orig.copy()
|
||||
# In theory we could zero out the homogeneous components with:
|
||||
# S_tot[grad_picks, :3] = 0
|
||||
# But in practice it doesn't seem to matter
|
||||
S_recon = S_tot[grad_picks]
|
||||
|
||||
# Add influence of point magnetometers
|
||||
S_tot[grad_picks, :] += np.einsum("ij,ijk->jk", grad_imb.T, S_pt)
|
||||
|
||||
# Compute multipolar moments
|
||||
mm = np.dot(_col_norm_pinv(S_tot.copy())[0], data)
|
||||
|
||||
# Use good channels to recalculate
|
||||
prev_imb = grad_imb.copy()
|
||||
data_recon = np.dot(S_recon, mm)
|
||||
assert S_pt.shape == (n_imbalance, len(grad_picks), S_tot.shape[1])
|
||||
khi_pts = (S_pt @ mm).transpose(1, 2, 0)
|
||||
assert khi_pts.shape == (len(grad_picks), data.shape[1], n_imbalance)
|
||||
residual = data[grad_picks] - data_recon
|
||||
assert residual.shape == (len(grad_picks), data.shape[1])
|
||||
d = (residual[:, np.newaxis, :] @ khi_pts)[:, 0]
|
||||
assert d.shape == (len(grad_picks), n_imbalance)
|
||||
dinv, _, _ = _reg_pinv(khi_pts.swapaxes(-1, -2) @ khi_pts, rcond=1e-6)
|
||||
assert dinv.shape == (len(grad_picks), n_imbalance, n_imbalance)
|
||||
grad_imb[:] = (d[:, np.newaxis] @ dinv)[:, 0]
|
||||
# This code is equivalent but hits a np.linalg.pinv bug on old NumPy:
|
||||
# grad_imb[:] = np.sum( # dot product across the time dim
|
||||
# np.linalg.pinv(khi_pts) * residual[:, np.newaxis], axis=-1)
|
||||
deltas = np.linalg.norm(grad_imb - prev_imb) / max(
|
||||
np.linalg.norm(grad_imb), np.linalg.norm(prev_imb)
|
||||
)
|
||||
logger.debug(
|
||||
f" Iteration {k + 1}/{n_iterations}: "
|
||||
f"max ∆ = {100 * deltas.max():7.3f}%"
|
||||
)
|
||||
imb = np.zeros((len(data), n_imbalance))
|
||||
imb[grad_picks] = grad_imb
|
||||
return imb
|
||||
|
||||
|
||||
def read_fine_calibration(fname):
|
||||
"""Read fine calibration information from a ``.dat`` file.
|
||||
|
||||
The fine calibration typically includes improved sensor locations,
|
||||
calibration coefficients, and gradiometer imbalance information.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : path-like
|
||||
The filename.
|
||||
|
||||
Returns
|
||||
-------
|
||||
calibration : dict
|
||||
Fine calibration information. Key-value pairs are:
|
||||
|
||||
- ``ch_names``
|
||||
List of str of the channel names.
|
||||
- ``locs``
|
||||
Coil location and orientation parameters.
|
||||
- ``imb_cals``
|
||||
For magnetometers, the calibration coefficients.
|
||||
For gradiometers, one or three imbalance parameters.
|
||||
"""
|
||||
# Read new sensor locations
|
||||
fname = _check_fname(fname, overwrite="read", must_exist=True)
|
||||
check_fname(fname, "cal", (".dat",))
|
||||
ch_names, locs, imb_cals = list(), list(), list()
|
||||
with open(fname) as fid:
|
||||
for line in fid:
|
||||
if line[0] in "#\n":
|
||||
continue
|
||||
vals = line.strip().split()
|
||||
if len(vals) not in [14, 16]:
|
||||
raise RuntimeError(
|
||||
"Error parsing fine calibration file, "
|
||||
"should have 14 or 16 entries per line "
|
||||
f"but found {len(vals)} on line:\n{line}"
|
||||
)
|
||||
# `vals` contains channel number
|
||||
ch_name = vals[0]
|
||||
if len(ch_name) in (3, 4): # heuristic for Neuromag fix
|
||||
try:
|
||||
ch_name = int(ch_name)
|
||||
except ValueError: # something other than e.g. 113 or 2642
|
||||
pass
|
||||
else:
|
||||
ch_name = f"MEG{int(ch_name):04}"
|
||||
# (x, y, z), x-norm 3-vec, y-norm 3-vec, z-norm 3-vec
|
||||
# and 1 or 3 imbalance terms
|
||||
ch_names.append(ch_name)
|
||||
locs.append(np.array(vals[1:13], float))
|
||||
imb_cals.append(np.array(vals[13:], float))
|
||||
locs = np.array(locs)
|
||||
return dict(ch_names=ch_names, locs=locs, imb_cals=imb_cals)
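For example (a sketch assuming a standard 306-channel fine-calibration file named ``sss_cal.dat``), the returned dict can be inspected like:

    from mne.preprocessing import read_fine_calibration

    cal = read_fine_calibration("sss_cal.dat")
    print(len(cal["ch_names"]), cal["locs"].shape)  # e.g. 306 and (306, 12)
    print({len(ic) for ic in cal["imb_cals"]})      # e.g. {1} for mags, {1, 3} with grads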
|
||||
|
||||
|
||||
def write_fine_calibration(fname, calibration):
|
||||
"""Write fine calibration information to a ``.dat`` file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : path-like
|
||||
The filename to write out.
|
||||
calibration : dict
|
||||
Fine calibration information.
|
||||
"""
|
||||
fname = _check_fname(fname, overwrite=True)
|
||||
check_fname(fname, "cal", (".dat",))
|
||||
keys = ("ch_names", "locs", "imb_cals")
|
||||
with open(fname, "wb") as cal_file:
|
||||
for ch_name, loc, imb_cal in zip(*(calibration[key] for key in keys)):
|
||||
cal_line = np.concatenate([loc, imb_cal]).round(6)
|
||||
cal_line = " ".join(f"{c:0.6f}" for c in cal_line)
|
||||
cal_file.write(f"{ch_name} {cal_line}\n".encode("ASCII"))
|
||||
98
mne/preprocessing/_lof.py
Normal file
98
mne/preprocessing/_lof.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Bad channel detection using Local Outlier Factor (LOF)."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_to_idx
|
||||
from ..io.base import BaseRaw
|
||||
from ..utils import _soft_import, _validate_type, logger, verbose
|
||||
|
||||
|
||||
@verbose
|
||||
def find_bad_channels_lof(
|
||||
raw,
|
||||
n_neighbors=20,
|
||||
*,
|
||||
picks=None,
|
||||
metric="euclidean",
|
||||
threshold=1.5,
|
||||
return_scores=False,
|
||||
verbose=None,
|
||||
):
|
||||
"""Find bad channels using Local Outlier Factor (LOF) algorithm.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
Raw data to process.
|
||||
n_neighbors : int
|
||||
Number of neighbors defining the local neighborhood (default is 20).
|
||||
Smaller values will lead to higher LOF scores.
|
||||
%(picks_good_data)s
|
||||
metric : str
|
||||
Metric to use for distance computation. Default is “euclidean”,
|
||||
see :func:`sklearn.metrics.pairwise.distance_metrics` for details.
|
||||
threshold : float
|
||||
Threshold to define outliers. Theoretical threshold ranges anywhere
|
||||
between 1.0 and any positive integer. Default: 1.5
|
||||
It is recommended to consider this as an hyperparameter to optimize.
|
||||
return_scores : bool
|
||||
If ``True``, return a dictionary with LOF scores for each
|
||||
evaluated channel. Default is ``False``.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
noisy_chs : list
|
||||
List of bad M/EEG channels that were automatically detected.
|
||||
scores : ndarray, shape (n_picks,)
|
||||
Only returned when ``return_scores`` is ``True``. It contains the
|
||||
LOF outlier score for each channel in ``picks``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
maxwell_filter
|
||||
annotate_amplitude
|
||||
|
||||
Notes
|
||||
-----
|
||||
See :footcite:`KumaravelEtAl2022` and :footcite:`BreunigEtAl2000` for background on
|
||||
choosing ``threshold``.
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
""" # noqa: E501
|
||||
_soft_import("sklearn", "using LOF detection", strict=True)
|
||||
from sklearn.neighbors import LocalOutlierFactor
|
||||
|
||||
_validate_type(raw, BaseRaw, "raw")
|
||||
# Get the channel types
|
||||
channel_types = raw.get_channel_types()
|
||||
picks = _picks_to_idx(raw.info, picks=picks, none="data", exclude="bads")
|
||||
picked_ch_types = set(channel_types[p] for p in picks)
|
||||
|
||||
# Check if there are different channel types
|
||||
if len(picked_ch_types) != 1:
|
||||
raise ValueError(
|
||||
f"Need exactly one channel type in picks, got {sorted(picked_ch_types)}"
|
||||
)
|
||||
ch_names = [raw.ch_names[pick] for pick in picks]
|
||||
data = raw.get_data(picks=picks)
|
||||
clf = LocalOutlierFactor(n_neighbors=n_neighbors, metric=metric)
|
||||
clf.fit_predict(data)
|
||||
scores_lof = clf.negative_outlier_factor_
|
||||
bad_channel_indices = [
|
||||
i for i, v in enumerate(np.abs(scores_lof)) if v >= threshold
|
||||
]
|
||||
bads = [ch_names[idx] for idx in bad_channel_indices]
|
||||
logger.info(f"LOF: Detected bad channel(s): {bads}")
|
||||
if return_scores:
|
||||
return bads, scores_lof
|
||||
else:
|
||||
return bads
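

# ---------------------------------------------------------------------------
# Usage sketch (illustration only; requires scikit-learn): a small synthetic
# EEG recording in which one channel has a much larger amplitude than the
# rest, so LOF should score it as an outlier. Channel names, sampling rate,
# and the amplitude factor are arbitrary demo choices.
if __name__ == "__main__":  # pragma: no cover
    import mne

    rng = np.random.default_rng(0)
    n_channels, sfreq = 25, 250.0
    data = rng.standard_normal((n_channels, int(10 * sfreq))) * 1e-6
    data[0] *= 20  # make the first channel an obvious outlier
    info = mne.create_info(
        [f"EEG{idx:03d}" for idx in range(n_channels)], sfreq, "eeg"
    )
    raw = mne.io.RawArray(data, info)
    bads, scores = find_bad_channels_lof(raw, n_neighbors=20, return_scores=True)
    print(bads, np.abs(scores).round(2))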
|
||||
184
mne/preprocessing/_peak_finder.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import _pl, logger, verbose
|
||||
|
||||
|
||||
@verbose
|
||||
def peak_finder(x0, thresh=None, extrema=1, verbose=None):
|
||||
"""Noise-tolerant fast peak-finding algorithm.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x0 : 1d array
|
||||
A real vector from which the maxima will be found (required).
|
||||
thresh : float | None
|
||||
The amount above surrounding data for a peak to be
|
||||
identified. Larger values mean the algorithm is more selective in
|
||||
finding peaks. If ``None``, use the default of
|
||||
``(max(x0) - min(x0)) / 4``.
|
||||
extrema : {-1, 1}
|
||||
1 if maxima are desired, -1 if minima are desired
|
||||
(default = maxima, 1).
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
peak_loc : array
|
||||
The indices of the identified peaks in x0.
|
||||
peak_mag : array
|
||||
The magnitude of the identified peaks.
|
||||
|
||||
Notes
|
||||
-----
|
||||
If repeated values are found, the first is identified as the peak.
Converted from the original MATLAB code by
Nathanael C. Yoder (ncyoder@purdue.edu).
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> from mne.preprocessing import peak_finder
|
||||
>>> t = np.arange(0, 3, 0.01)
|
||||
>>> x = np.sin(np.pi*t) - np.sin(0.5*np.pi*t)
|
||||
>>> peak_locs, peak_mags = peak_finder(x) # doctest: +SKIP
|
||||
>>> peak_locs # doctest: +SKIP
|
||||
array([36, 260]) # doctest: +SKIP
|
||||
>>> peak_mags # doctest: +SKIP
|
||||
array([0.36900026, 1.76007351]) # doctest: +SKIP
|
||||
"""
|
||||
x0 = np.asanyarray(x0)
|
||||
s = x0.size
|
||||
|
||||
if x0.ndim >= 2 or s == 0:
|
||||
raise ValueError("The input data must be a non empty 1D vector")
|
||||
|
||||
if thresh is None:
|
||||
thresh = (np.max(x0) - np.min(x0)) / 4
|
||||
logger.debug(f"Peak finder automatic threshold: {thresh:0.2g}")
|
||||
|
||||
assert extrema in [-1, 1]
|
||||
|
||||
if extrema == -1:
|
||||
x0 = extrema * x0 # Make it so we are finding maxima regardless
|
||||
|
||||
dx0 = np.diff(x0) # Find derivative
|
||||
# This is so we find the first of repeated values
|
||||
dx0[dx0 == 0] = -np.finfo(float).eps
|
||||
# Find where the derivative changes sign
|
||||
ind = np.where(dx0[:-1:] * dx0[1::] < 0)[0] + 1
|
||||
|
||||
# Include endpoints in potential peaks and valleys
|
||||
x = np.concatenate((x0[:1], x0[ind], x0[-1:]))
|
||||
ind = np.concatenate(([0], ind, [s - 1]))
|
||||
del x0
|
||||
|
||||
# x only has the peaks, valleys, and endpoints
|
||||
length = x.size
|
||||
min_mag = np.min(x)
|
||||
|
||||
if length > 2: # Function with peaks and valleys
|
||||
# Set initial parameters for loop
|
||||
temp_mag = min_mag
|
||||
found_peak = False
|
||||
left_min = min_mag
|
||||
|
||||
# Deal with the first point a little differently since it was tacked on.
# Calculate the sign of the derivative; since we tacked the first point
# on, it does not necessarily alternate like the rest.
|
||||
signDx = np.sign(np.diff(x[:3]))
|
||||
if signDx[0] <= 0: # The first point is larger than or equal to the second
|
||||
ii = -1
|
||||
if signDx[0] == signDx[1]: # Want alternating signs
|
||||
x = np.concatenate((x[:1], x[2:]))
|
||||
ind = np.concatenate((ind[:1], ind[2:]))
|
||||
length -= 1
|
||||
|
||||
else: # First point is smaller than the second
|
||||
ii = 0
|
||||
if signDx[0] == signDx[1]: # Want alternating signs
|
||||
x = x[1:]
|
||||
ind = ind[1:]
|
||||
length -= 1
|
||||
|
||||
# Preallocate max number of maxima
|
||||
maxPeaks = int(np.ceil(length / 2.0))
|
||||
peak_loc = np.zeros(maxPeaks, dtype=np.int64)
|
||||
peak_mag = np.zeros(maxPeaks)
|
||||
c_ind = 0
|
||||
# Loop through extrema which should be peaks and then valleys
|
||||
while ii < (length - 1):
|
||||
ii += 1 # This is a peak
|
||||
# Reset peak finding if we had a peak and the next peak is bigger
|
||||
# than the last or the left min was small enough to reset.
|
||||
if found_peak and (
|
||||
(x[ii] > peak_mag[-1]) or (left_min < peak_mag[-1] - thresh)
|
||||
):
|
||||
temp_mag = min_mag
|
||||
found_peak = False
|
||||
|
||||
# Make sure we don't iterate past the length of our vector
|
||||
if ii == length - 1:
|
||||
break # We assign the last point differently out of the loop
|
||||
|
||||
# Found a new peak that was larger than temp_mag and more than thresh
# above the minimum to its left.
|
||||
if (x[ii] > temp_mag) and (x[ii] > left_min + thresh):
|
||||
temp_loc = ii
|
||||
temp_mag = x[ii]
|
||||
|
||||
ii += 1 # Move onto the valley
|
||||
# Come down at least thresh from peak
|
||||
if not found_peak and (temp_mag > (thresh + x[ii])):
|
||||
found_peak = True # We have found a peak
|
||||
left_min = x[ii]
|
||||
peak_loc[c_ind] = temp_loc # Add peak to index
|
||||
peak_mag[c_ind] = temp_mag
|
||||
c_ind += 1
|
||||
elif x[ii] < left_min: # New left minima
|
||||
left_min = x[ii]
|
||||
|
||||
# Check end point
|
||||
if (x[-1] > temp_mag) and (x[-1] > (left_min + thresh)):
|
||||
peak_loc[c_ind] = length - 1
|
||||
peak_mag[c_ind] = x[-1]
|
||||
c_ind += 1
|
||||
elif not found_peak and temp_mag > min_mag:
|
||||
# Check if we still need to add the last point
|
||||
peak_loc[c_ind] = temp_loc
|
||||
peak_mag[c_ind] = temp_mag
|
||||
c_ind += 1
|
||||
|
||||
# Create output
|
||||
peak_inds = ind[peak_loc[:c_ind]]
|
||||
peak_mags = peak_mag[:c_ind]
|
||||
else: # This is a monotone function where an endpoint is the only peak
|
||||
x_ind = np.argmax(x)
|
||||
peak_mags = x[x_ind]
|
||||
if peak_mags > (min_mag + thresh):
|
||||
peak_inds = ind[x_ind]
|
||||
else:
|
||||
peak_mags = []
|
||||
peak_inds = []
|
||||
|
||||
# Change sign of data if was finding minima
|
||||
if extrema < 0:
|
||||
peak_mags *= -1.0
|
||||
|
||||
# ensure output type array
|
||||
if not isinstance(peak_inds, np.ndarray):
|
||||
peak_inds = np.atleast_1d(peak_inds).astype("int64")
|
||||
|
||||
if not isinstance(peak_mags, np.ndarray):
|
||||
peak_mags = np.atleast_1d(peak_mags).astype("float64")
|
||||
|
||||
# Plot if no output desired
|
||||
if len(peak_inds) == 0:
|
||||
logger.info("No significant peaks found")
|
||||
else:
|
||||
logger.info(f"Found {len(peak_inds)} significant peak{_pl(peak_inds)}")
|
||||
|
||||
return peak_inds, peak_mags
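

# ---------------------------------------------------------------------------
# Usage sketch (illustration only): the same signal as in the docstring
# example, but searching for minima by passing ``extrema=-1`` together with an
# explicit threshold. The threshold value is an arbitrary demo choice.
if __name__ == "__main__":  # pragma: no cover
    t = np.arange(0, 3, 0.01)
    x = np.sin(np.pi * t) - np.sin(0.5 * np.pi * t)
    trough_locs, trough_mags = peak_finder(x, thresh=0.1, extrema=-1)
    print(trough_locs, trough_mags)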
|
||||
390
mne/preprocessing/_regress.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_to_idx, pick_info
|
||||
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
|
||||
from ..epochs import BaseEpochs
|
||||
from ..evoked import Evoked
|
||||
from ..io import BaseRaw
|
||||
from ..minimum_norm.inverse import _needs_eeg_average_ref_proj
|
||||
from ..utils import (
|
||||
_check_fname,
|
||||
_check_option,
|
||||
_check_preload,
|
||||
_import_h5io_funcs,
|
||||
_validate_type,
|
||||
copy_function_doc_to_method_doc,
|
||||
fill_doc,
|
||||
verbose,
|
||||
)
|
||||
from ..viz import plot_regression_weights
|
||||
|
||||
|
||||
@verbose
|
||||
def regress_artifact(
|
||||
inst,
|
||||
picks=None,
|
||||
*,
|
||||
exclude="bads",
|
||||
picks_artifact="eog",
|
||||
betas=None,
|
||||
proj=True,
|
||||
copy=True,
|
||||
verbose=None,
|
||||
):
|
||||
"""Remove artifacts using regression based on reference channels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : instance of Epochs | Raw
|
||||
The instance to process.
|
||||
%(picks_good_data)s
|
||||
exclude : list | 'bads'
|
||||
List of channels to exclude from the regression, only used when picking
|
||||
based on types (e.g., exclude="bads" when picks="meg").
|
||||
Specify ``'bads'`` (the default) to exclude all channels marked as bad.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
picks_artifact : array-like | str
|
||||
Channel picks to use as predictor/explanatory variables capturing
|
||||
the artifact of interest (default is "eog").
|
||||
betas : ndarray, shape (n_picks, n_picks_ref) | None
|
||||
The regression coefficients to use. If None (default), they will be
|
||||
estimated from the data.
|
||||
proj : bool
|
||||
Whether to automatically apply SSP projection vectors before performing
|
||||
the regression. Default is ``True``.
|
||||
copy : bool
|
||||
If True (default), copy the instance before modifying it.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
inst : instance of Epochs | Raw
|
||||
The processed data.
|
||||
betas : ndarray, shape (n_picks, n_picks_ref)
|
||||
The betas used during regression.
|
||||
|
||||
Notes
|
||||
-----
|
||||
To implement the method outlined in :footcite:`GrattonEtAl1983`,
|
||||
remove the evoked response from epochs before estimating the
|
||||
regression coefficients, then apply those regression coefficients to the
|
||||
original data in two calls like (here for a single-condition ``epochs``
|
||||
only):
|
||||
|
||||
>>> epochs_no_ave = epochs.copy().subtract_evoked() # doctest:+SKIP
|
||||
>>> _, betas = mne.preprocessing.regress_artifact(epochs_no_ave) # doctest:+SKIP
>>> epochs_clean, _ = mne.preprocessing.regress_artifact(epochs, betas=betas) # doctest:+SKIP
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
""" # noqa: E501
|
||||
if betas is None:
|
||||
model = EOGRegression(
|
||||
picks=picks, exclude=exclude, picks_artifact=picks_artifact, proj=proj
|
||||
)
|
||||
model.fit(inst)
|
||||
else:
|
||||
# Create an EOGRegression object and load the given betas into it.
|
||||
picks = _picks_to_idx(inst.info, picks, exclude=exclude, none="data")
|
||||
picks_artifact = _picks_to_idx(inst.info, picks_artifact)
|
||||
want_betas_shape = (len(picks), len(picks_artifact))
|
||||
_check_option("betas.shape", betas.shape, (want_betas_shape,))
|
||||
model = EOGRegression(picks, picks_artifact, proj=proj)
|
||||
model.info_ = inst.info.copy()
|
||||
model.coef_ = betas
|
||||
return model.apply(inst, copy=copy), model.coef_
|
||||
|
||||
|
||||
@fill_doc
|
||||
class EOGRegression:
|
||||
"""Remove EOG artifact signals from other channels by regression.
|
||||
|
||||
Employs linear regression to remove signals captured by some channels,
|
||||
typically EOG, as described in :footcite:`GrattonEtAl1983`. You can also
|
||||
choose to fit the regression coefficients on evoked blink/saccade data and
|
||||
then apply them to continuous data, as described in
|
||||
:footcite:`CroftBarry2000`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
%(picks_good_data)s
|
||||
exclude : list | 'bads'
|
||||
List of channels to exclude from the regression, only used when picking
|
||||
based on types (e.g., exclude="bads" when picks="meg").
|
||||
Specify ``'bads'`` (the default) to exclude all channels marked as bad.
|
||||
picks_artifact : array-like | str
|
||||
Channel picks to use as predictor/explanatory variables capturing
|
||||
the artifact of interest (default is "eog").
|
||||
proj : bool
|
||||
Whether to automatically apply SSP projection vectors before fitting
|
||||
and applying the regression. Default is ``True``.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
coef_ : ndarray, shape (n, n)
|
||||
The regression coefficients. Only available after fitting.
|
||||
info_ : Info
|
||||
Channel information corresponding to the regression weights.
|
||||
Only available after fitting.
|
||||
picks : array-like | str
|
||||
Channels to perform the regression on.
|
||||
exclude : list | 'bads'
|
||||
Channels to exclude from the regression.
|
||||
picks_artifact : array-like | str
|
||||
The channels designated as containing the artifacts of interest.
|
||||
proj : bool
|
||||
Whether projections will be applied before performing the regression.
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 1.2
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
|
||||
def __init__(self, picks=None, exclude="bads", picks_artifact="eog", proj=True):
|
||||
self.picks = picks
|
||||
self.exclude = exclude
|
||||
self.picks_artifact = picks_artifact
|
||||
self.proj = proj
|
||||
|
||||
def fit(self, inst):
|
||||
"""Fit EOG regression coefficients.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : Raw | Epochs | Evoked
|
||||
The data on which the EOG regression weights should be fitted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
self : EOGRegression
|
||||
The fitted ``EOGRegression`` object. The regression coefficients
|
||||
are available as the ``.coef_`` attribute.
|
||||
|
||||
Notes
|
||||
-----
|
||||
If your data contains EEG channels, make sure to apply the desired
|
||||
reference (see :func:`mne.set_eeg_reference`) before performing EOG
|
||||
regression.
|
||||
"""
|
||||
picks, picks_artifact = self._check_inst(inst)
|
||||
|
||||
# Calculate regression coefficients. Add a row of ones to also fit the
|
||||
# intercept.
|
||||
_check_preload(inst, "artifact regression")
|
||||
artifact_data = inst._data[..., picks_artifact, :]
|
||||
ref_data = artifact_data - np.mean(artifact_data, axis=-1, keepdims=True)
|
||||
if ref_data.ndim == 3:
|
||||
ref_data = ref_data.transpose(1, 0, 2)
|
||||
ref_data = ref_data.reshape(len(picks_artifact), -1)
|
||||
cov_ref = ref_data @ ref_data.T
|
||||
|
||||
# Process each channel separately to reduce memory load
|
||||
coef = np.zeros((len(picks), len(picks_artifact)))
|
||||
for pi, pick in enumerate(picks):
|
||||
this_data = inst._data[..., pick, :] # view
|
||||
# Subtract mean over time from every trial/channel
|
||||
cov_data = this_data - np.mean(this_data, -1, keepdims=True)
|
||||
cov_data = cov_data.reshape(1, -1)
|
||||
# Perform the linear regression
|
||||
coef[pi] = np.linalg.solve(cov_ref, ref_data @ cov_data.T).T[0]
|
||||
|
||||
# Store relevant parameters in the object.
|
||||
self.coef_ = coef
|
||||
self.info_ = inst.info.copy()
|
||||
return self
|
||||
|
||||
@fill_doc
|
||||
def apply(self, inst, copy=True):
|
||||
"""Apply the regression coefficients to data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : Raw | Epochs | Evoked
|
||||
The data on which to apply the regression.
|
||||
%(copy_df)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
inst : Raw | Epochs | Evoked
|
||||
A version of the data with the artifact channels regressed out.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Only works after ``.fit()`` has been used.
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
if copy:
|
||||
inst = inst.copy()
|
||||
picks, picks_artifact = self._check_inst(inst)
|
||||
|
||||
# Check that the channels are compatible with the regression weights.
|
||||
ref_picks = _picks_to_idx(
|
||||
self.info_, self.picks, none="data", exclude=self.exclude
|
||||
)
|
||||
ref_picks_artifact = _picks_to_idx(self.info_, self.picks_artifact)
|
||||
if any(
|
||||
inst.ch_names[ch1] != self.info_["chs"][ch2]["ch_name"]
|
||||
for ch1, ch2 in zip(picks, ref_picks)
|
||||
):
|
||||
raise ValueError(
|
||||
"Selected data channels are not compatible with "
|
||||
"the regression weights. Make sure that all data "
|
||||
"channels are present and in the correct order."
|
||||
)
|
||||
if any(
|
||||
inst.ch_names[ch1] != self.info_["chs"][ch2]["ch_name"]
|
||||
for ch1, ch2 in zip(picks_artifact, ref_picks_artifact)
|
||||
):
|
||||
raise ValueError(
|
||||
"Selected artifact channels are not compatible "
|
||||
"with the regression weights. Make sure that all "
|
||||
"artifact channels are present and in the "
|
||||
"correct order."
|
||||
)
|
||||
|
||||
_check_preload(inst, "artifact regression")
|
||||
artifact_data = inst._data[..., picks_artifact, :]
|
||||
ref_data = artifact_data - np.mean(artifact_data, -1, keepdims=True)
|
||||
for pi, pick in enumerate(picks):
|
||||
this_data = inst._data[..., pick, :] # view
|
||||
this_data -= (self.coef_[pi] @ ref_data).reshape(this_data.shape)
|
||||
return inst
|
||||
|
||||
@copy_function_doc_to_method_doc(plot_regression_weights)
|
||||
def plot(
|
||||
self,
|
||||
ch_type=None,
|
||||
sensors=True,
|
||||
show_names=False,
|
||||
mask=None,
|
||||
mask_params=None,
|
||||
contours=6,
|
||||
outlines="head",
|
||||
sphere=None,
|
||||
image_interp=_INTERPOLATION_DEFAULT,
|
||||
extrapolate=_EXTRAPOLATE_DEFAULT,
|
||||
border=_BORDER_DEFAULT,
|
||||
res=64,
|
||||
size=1,
|
||||
cmap=None,
|
||||
vlim=(None, None),
|
||||
cnorm=None,
|
||||
axes=None,
|
||||
colorbar=True,
|
||||
cbar_fmt="%1.1e",
|
||||
title=None,
|
||||
show=True,
|
||||
):
|
||||
return plot_regression_weights(
|
||||
self,
|
||||
ch_type=ch_type,
|
||||
sensors=sensors,
|
||||
show_names=show_names,
|
||||
mask=mask,
|
||||
mask_params=mask_params,
|
||||
contours=contours,
|
||||
outlines=outlines,
|
||||
sphere=sphere,
|
||||
image_interp=image_interp,
|
||||
extrapolate=extrapolate,
|
||||
border=border,
|
||||
res=res,
|
||||
size=size,
|
||||
cmap=cmap,
|
||||
vlim=vlim,
|
||||
cnorm=cnorm,
|
||||
axes=axes,
|
||||
colorbar=colorbar,
|
||||
cbar_fmt=cbar_fmt,
|
||||
title=title,
|
||||
show=show,
|
||||
)
|
||||
|
||||
def _check_inst(self, inst):
|
||||
"""Perform some sanity checks on the input."""
|
||||
_validate_type(
|
||||
inst, (BaseRaw, BaseEpochs, Evoked), "inst", "Raw, Epochs, Evoked"
|
||||
)
|
||||
picks = _picks_to_idx(inst.info, self.picks, none="data", exclude=self.exclude)
|
||||
picks_artifact = _picks_to_idx(inst.info, self.picks_artifact)
|
||||
all_picks = np.unique(np.concatenate([picks, picks_artifact]))
|
||||
use_info = pick_info(inst.info, all_picks)
|
||||
del all_picks
|
||||
if _needs_eeg_average_ref_proj(use_info):
|
||||
raise RuntimeError(
|
||||
"No average reference for the EEG channels has been "
|
||||
"set. Use inst.set_eeg_reference(projection=True) to do so."
|
||||
)
|
||||
if self.proj and not inst.proj:
|
||||
inst.apply_proj()
|
||||
if not inst.proj and len(use_info.get("projs", [])) > 0:
|
||||
raise RuntimeError(
|
||||
"Projections need to be applied before "
|
||||
"regression can be performed. Use the "
|
||||
".apply_proj() method to do so."
|
||||
)
|
||||
return picks, picks_artifact
|
||||
|
||||
def __repr__(self):
|
||||
"""Produce a string representation of this object."""
|
||||
s = "<EOGRegression | "
|
||||
if hasattr(self, "coef_"):
|
||||
n_art = self.coef_.shape[1]
|
||||
plural = "s" if n_art > 1 else ""
|
||||
s += f"fitted to {n_art} artifact channel{plural}>"
|
||||
else:
|
||||
s += "not fitted>"
|
||||
return s
|
||||
|
||||
@fill_doc
|
||||
def save(self, fname, overwrite=False):
|
||||
"""Save the regression model to an HDF5 file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : path-like
|
||||
The file to write the regression weights to. Should end in ``.h5``.
|
||||
%(overwrite)s
|
||||
"""
|
||||
_, write_hdf5 = _import_h5io_funcs()
|
||||
_validate_type(fname, "path-like", "fname")
|
||||
fname = _check_fname(fname, overwrite=overwrite, name="fname")
|
||||
write_hdf5(fname, self.__dict__, overwrite=overwrite)
|
||||
|
||||
|
||||
def read_eog_regression(fname):
|
||||
"""Read an EOG regression model from an HDF5 file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : path-like
|
||||
The file to read the regression model from. Should end in ``.h5``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : EOGRegression
|
||||
The regression model read from the file.
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 1.2
|
||||
"""
|
||||
read_hdf5, _ = _import_h5io_funcs()
|
||||
_validate_type(fname, "path-like", "fname")
|
||||
fname = _check_fname(fname, overwrite="read", must_exist=True, name="fname")
|
||||
model = EOGRegression()
|
||||
model.__dict__.update(read_hdf5(fname))
|
||||
return model
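

# ---------------------------------------------------------------------------
# Usage sketch (illustration only): fit the regression on a small synthetic
# recording with three EEG channels contaminated by one EOG channel, then
# remove the artifact. Channel names, mixing weight, and sampling rate are
# arbitrary demo values.
if __name__ == "__main__":  # pragma: no cover
    import mne

    rng = np.random.default_rng(0)
    n_times = 2000
    eog = rng.standard_normal(n_times) * 1e-4
    eeg = rng.standard_normal((3, n_times)) * 1e-6 + 0.1 * eog  # EOG bleeds in
    info = mne.create_info(
        ["EEG 001", "EEG 002", "EEG 003", "EOG 001"], 200.0, ["eeg"] * 3 + ["eog"]
    )
    raw = mne.io.RawArray(np.vstack([eeg, eog[np.newaxis]]), info)
    raw.set_eeg_reference(projection=True)  # required before EOG regression
    weights = EOGRegression(picks="eeg", picks_artifact="eog").fit(raw)
    raw_clean = weights.apply(raw)
    print(weights.coef_.shape)  # (3, 1)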
|
||||
655
mne/preprocessing/artifact_detection.py
Normal file
@@ -0,0 +1,655 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
|
||||
import numpy as np
|
||||
from scipy.ndimage import distance_transform_edt, label
|
||||
from scipy.signal import find_peaks
|
||||
from scipy.stats import zscore
|
||||
|
||||
from ..annotations import (
|
||||
Annotations,
|
||||
_adjust_onset_meas_date,
|
||||
_annotations_starts_stops,
|
||||
annotations_from_events,
|
||||
)
|
||||
from ..filter import filter_data
|
||||
from ..io.base import BaseRaw
|
||||
from ..transforms import (
|
||||
Transform,
|
||||
_angle_between_quats,
|
||||
_average_quats,
|
||||
_quat_to_affine,
|
||||
apply_trans,
|
||||
quat_to_rot,
|
||||
)
|
||||
from ..utils import (
|
||||
_check_option,
|
||||
_mask_to_onsets_offsets,
|
||||
_pl,
|
||||
_validate_type,
|
||||
logger,
|
||||
verbose,
|
||||
warn,
|
||||
)
|
||||
|
||||
|
||||
@verbose
|
||||
def annotate_muscle_zscore(
|
||||
raw,
|
||||
threshold=4,
|
||||
ch_type=None,
|
||||
min_length_good=0.1,
|
||||
filter_freq=(110, 140),
|
||||
n_jobs=None,
|
||||
verbose=None,
|
||||
):
|
||||
"""Create annotations for segments that likely contain muscle artifacts.
|
||||
|
||||
Detects data segments containing activity in the frequency range given by
|
||||
``filter_freq`` whose envelope magnitude exceeds the specified z-score
|
||||
threshold, when summed across channels and divided by ``sqrt(n_channels)``.
|
||||
False-positive transient peaks are prevented by low-pass filtering the
|
||||
resulting z-score time series at 4 Hz. Only operates on a single channel
type; if ``ch_type`` is ``None``, it will select the first type in the list
``mag``, ``grad``, ``eeg``.
|
||||
See :footcite:`Muthukumaraswamy2013` for background on choosing
|
||||
``filter_freq`` and ``threshold``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
Data to estimate segments with muscle artifacts.
|
||||
threshold : float
|
||||
The threshold in z-scores for marking segments as containing muscle
|
||||
activity artifacts.
|
||||
ch_type : 'mag' | 'grad' | 'eeg' | None
|
||||
The type of sensors to use. If ``None`` it will take the first type in
|
||||
``mag``, ``grad``, ``eeg``.
|
||||
min_length_good : float | None
|
||||
The shortest allowed duration of "good data" (in seconds) between
|
||||
adjacent annotations; shorter segments will be incorporated into the
|
||||
surrounding annotations. ``None`` is equivalent to ``0``.
|
||||
Default is ``0.1``.
|
||||
filter_freq : array-like, shape (2,)
|
||||
The lower and upper frequencies of the band-pass filter.
|
||||
Default is ``(110, 140)``.
|
||||
%(n_jobs)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
annot : mne.Annotations
|
||||
Periods with muscle artifacts annotated as BAD_muscle.
|
||||
scores_muscle : array
|
||||
Z-score values averaged across channels for each sample.
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
raw_copy = raw.copy()
|
||||
|
||||
if ch_type is None:
|
||||
raw_ch_type = raw_copy.get_channel_types()
|
||||
if "mag" in raw_ch_type:
|
||||
ch_type = "mag"
|
||||
elif "grad" in raw_ch_type:
|
||||
ch_type = "grad"
|
||||
elif "eeg" in raw_ch_type:
|
||||
ch_type = "eeg"
|
||||
else:
|
||||
raise ValueError(
|
||||
"No M/EEG channel types found, please specify a 'ch_type' or provide "
|
||||
"M/EEG sensor data."
|
||||
)
|
||||
logger.info("Using %s sensors for muscle artifact detection", ch_type)
|
||||
else:
|
||||
_check_option("ch_type", ch_type, ["mag", "grad", "eeg"])
|
||||
raw_copy.pick(ch_type)
|
||||
|
||||
raw_copy.filter(
|
||||
filter_freq[0],
|
||||
filter_freq[1],
|
||||
fir_design="firwin",
|
||||
pad="reflect_limited",
|
||||
n_jobs=n_jobs,
|
||||
)
|
||||
raw_copy.apply_hilbert(envelope=True, n_jobs=n_jobs)
|
||||
|
||||
data = raw_copy.get_data(reject_by_annotation="NaN")
|
||||
nan_mask = ~np.isnan(data[0])
|
||||
sfreq = raw_copy.info["sfreq"]
|
||||
|
||||
art_scores = zscore(data[:, nan_mask], axis=1)
|
||||
art_scores = art_scores.sum(axis=0) / np.sqrt(art_scores.shape[0])
|
||||
art_scores = filter_data(art_scores, sfreq, None, 4)
|
||||
|
||||
scores_muscle = np.zeros(data.shape[1])
|
||||
scores_muscle[nan_mask] = art_scores
|
||||
|
||||
art_mask = scores_muscle > threshold
|
||||
# return muscle scores with NaNs
|
||||
scores_muscle[~nan_mask] = np.nan
|
||||
|
||||
# remove artifact free periods shorter than min_length_good
|
||||
min_length_good = 0 if min_length_good is None else min_length_good
|
||||
min_samps = min_length_good * sfreq
|
||||
comps, num_comps = label(art_mask == 0)
|
||||
for com in range(1, num_comps + 1):
|
||||
l_idx = np.nonzero(comps == com)[0]
|
||||
if len(l_idx) < min_samps:
|
||||
art_mask[l_idx] = True
|
||||
|
||||
annot = _annotations_from_mask(
|
||||
raw_copy.times, art_mask, "BAD_muscle", orig_time=raw.info["meas_date"]
|
||||
)
|
||||
_adjust_onset_meas_date(annot, raw)
|
||||
return annot, scores_muscle
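

# Illustrative sketch (not part of the module API): a typical call pattern for
# the function above. ``raw`` is assumed to be a preloaded Raw object with
# magnetometer data; the threshold and frequency band are the documented
# defaults.
def _demo_annotate_muscle(raw):
    annot_muscle, scores_muscle = annotate_muscle_zscore(
        raw, ch_type="mag", threshold=4, filter_freq=(110, 140)
    )
    raw.set_annotations(raw.annotations + annot_muscle)
    return scores_muscle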
|
||||
|
||||
|
||||
def annotate_movement(
|
||||
raw,
|
||||
pos,
|
||||
rotation_velocity_limit=None,
|
||||
translation_velocity_limit=None,
|
||||
mean_distance_limit=None,
|
||||
use_dev_head_trans="average",
|
||||
):
|
||||
"""Detect segments with movement.
|
||||
|
||||
Detects segments in which head movement exceeds ``rotation_velocity_limit``,
``translation_velocity_limit``, or ``mean_distance_limit``, and returns
annotations marking the bad segments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
Data to compute head position.
|
||||
pos : array, shape (N, 10)
|
||||
The position and quaternion parameters from cHPI fitting. Obtained
|
||||
with `mne.chpi` functions.
|
||||
rotation_velocity_limit : float
|
||||
Head rotation velocity limit in degrees per second.
|
||||
translation_velocity_limit : float
|
||||
Head translation velocity limit in meters per second.
|
||||
mean_distance_limit : float
|
||||
Head position limit from mean recording in meters.
|
||||
use_dev_head_trans : 'average' (default) | 'info'
|
||||
Identify the device to head transform used to define the
|
||||
fixed HPI locations for computing moving distances.
|
||||
If ``average``, the average device to head transform is
|
||||
computed using ``compute_average_dev_head_t``.
|
||||
If ``info``, ``raw.info['dev_head_t']`` is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
annot : mne.Annotations
|
||||
Periods with head motion.
|
||||
hpi_disp : array
|
||||
Head position over time with respect to the mean head pos.
|
||||
|
||||
See Also
|
||||
--------
|
||||
compute_average_dev_head_t
|
||||
"""
|
||||
sfreq = raw.info["sfreq"]
|
||||
hp_ts = pos[:, 0].copy() - raw.first_time
|
||||
dt = np.diff(hp_ts)
|
||||
hp_ts = np.concatenate([hp_ts, [hp_ts[-1] + 1.0 / sfreq]])
|
||||
orig_time = raw.info["meas_date"]
|
||||
annot = Annotations([], [], [], orig_time=orig_time)
|
||||
|
||||
# Annotate based on rotational velocity
|
||||
t_tot = raw.times[-1]
|
||||
if rotation_velocity_limit is not None:
|
||||
assert rotation_velocity_limit > 0
|
||||
# Rotational velocity (radians / s)
|
||||
r = _angle_between_quats(pos[:-1, 1:4], pos[1:, 1:4])
|
||||
r /= dt
|
||||
bad_mask = r >= np.deg2rad(rotation_velocity_limit)
|
||||
onsets, offsets = _mask_to_onsets_offsets(bad_mask)
|
||||
onsets, offsets = hp_ts[onsets], hp_ts[offsets]
|
||||
bad_pct = 100 * (offsets - onsets).sum() / t_tot
|
||||
logger.info(
|
||||
"Omitting %5.1f%% (%3d segments): " "ω >= %5.1f°/s (max: %0.1f°/s)",
|
||||
bad_pct,
|
||||
len(onsets),
|
||||
rotation_velocity_limit,
|
||||
np.rad2deg(r.max()),
|
||||
)
|
||||
annot += _annotations_from_mask(
|
||||
hp_ts, bad_mask, "BAD_mov_rotat_vel", orig_time=orig_time
|
||||
)
|
||||
|
||||
# Annotate based on translational velocity limit
|
||||
if translation_velocity_limit is not None:
|
||||
assert translation_velocity_limit > 0
|
||||
v = np.linalg.norm(np.diff(pos[:, 4:7], axis=0), axis=-1)
|
||||
v /= dt
|
||||
bad_mask = v >= translation_velocity_limit
|
||||
onsets, offsets = _mask_to_onsets_offsets(bad_mask)
|
||||
onsets, offsets = hp_ts[onsets], hp_ts[offsets]
|
||||
bad_pct = 100 * (offsets - onsets).sum() / t_tot
|
||||
logger.info(
|
||||
"Omitting %5.1f%% (%3d segments): " "v >= %5.4fm/s (max: %5.4fm/s)",
|
||||
bad_pct,
|
||||
len(onsets),
|
||||
translation_velocity_limit,
|
||||
v.max(),
|
||||
)
|
||||
annot += _annotations_from_mask(
|
||||
hp_ts, bad_mask, "BAD_mov_trans_vel", orig_time=orig_time
|
||||
)
|
||||
|
||||
# Annotate based on displacement from mean head position
|
||||
disp = []
|
||||
if mean_distance_limit is not None:
|
||||
assert mean_distance_limit > 0
|
||||
|
||||
# compute dev to head transform for fixed points
|
||||
use_dev_head_trans = use_dev_head_trans.lower()
|
||||
if use_dev_head_trans not in ["average", "info"]:
|
||||
raise ValueError(
|
||||
"use_dev_head_trans must be either"
|
||||
f" 'average' or 'info': got '{use_dev_head_trans}'"
|
||||
)
|
||||
|
||||
if use_dev_head_trans == "average":
|
||||
fixed_dev_head_t = compute_average_dev_head_t(raw, pos)
|
||||
elif use_dev_head_trans == "info":
|
||||
fixed_dev_head_t = raw.info["dev_head_t"]
|
||||
|
||||
# Get static head pos from file, used to convert quat to cartesian
|
||||
chpi_pos = sorted(
|
||||
[d for d in raw.info["hpi_results"][-1]["dig_points"]],
|
||||
key=lambda x: x["ident"],
|
||||
)
|
||||
chpi_pos = np.array([d["r"] for d in chpi_pos])
|
||||
|
||||
# Get head pos changes during recording
|
||||
chpi_pos_mov = np.array(
|
||||
[apply_trans(_quat_to_affine(quat), chpi_pos) for quat in pos[:, 1:7]]
|
||||
)
|
||||
|
||||
# get fixed position
|
||||
chpi_pos_fix = apply_trans(fixed_dev_head_t, chpi_pos)
|
||||
|
||||
# get movement displacement from mean pos
|
||||
hpi_disp = chpi_pos_mov - np.tile(chpi_pos_fix, (pos.shape[0], 1, 1))
|
||||
|
||||
# get positions above threshold distance
|
||||
disp = np.sqrt((hpi_disp**2).sum(axis=2))
|
||||
bad_mask = np.any(disp > mean_distance_limit, axis=1)
|
||||
onsets, offsets = _mask_to_onsets_offsets(bad_mask)
|
||||
onsets, offsets = hp_ts[onsets], hp_ts[offsets]
|
||||
bad_pct = 100 * (offsets - onsets).sum() / t_tot
|
||||
logger.info(
|
||||
"Omitting %5.1f%% (%3d segments): " "disp >= %5.4fm (max: %5.4fm)",
|
||||
bad_pct,
|
||||
len(onsets),
|
||||
mean_distance_limit,
|
||||
disp.max(),
|
||||
)
|
||||
annot += _annotations_from_mask(
|
||||
hp_ts, bad_mask, "BAD_mov_dist", orig_time=orig_time
|
||||
)
|
||||
_adjust_onset_meas_date(annot, raw)
|
||||
return annot, disp
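

# Illustrative sketch (not part of the module API): annotate head movement
# from previously computed cHPI head positions. ``raw`` and ``pos`` are
# assumed to come from a cHPI-equipped MEG recording (see the ``mne.chpi``
# functions); the limits below are arbitrary example values.
def _demo_annotate_movement(raw, pos):
    annot_movement, disp = annotate_movement(
        raw,
        pos,
        rotation_velocity_limit=30.0,  # deg/s
        translation_velocity_limit=0.02,  # m/s
        mean_distance_limit=0.005,  # m
    )
    raw.set_annotations(raw.annotations + annot_movement)
    return disp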
|
||||
|
||||
|
||||
@verbose
|
||||
def compute_average_dev_head_t(raw, pos, *, verbose=None):
|
||||
"""Get new device to head transform based on good segments.
|
||||
|
||||
Segments starting with "BAD" annotations are not included for calculating
|
||||
the mean head position.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw | list of Raw
|
||||
Data to compute head position. Can be a list containing multiple raw
|
||||
instances.
|
||||
pos : array, shape (N, 10) | list of ndarray
|
||||
The position and quaternion parameters from cHPI fitting. Can be
|
||||
a list containing multiple position arrays, one per raw instance passed.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
dev_head_t : instance of Transform
|
||||
New ``dev_head_t`` transformation using the averaged good head positions.
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionchanged:: 1.7
|
||||
Support for multiple raw instances and position arrays was added.
|
||||
"""
|
||||
# Get weighted head pos trans and rot
|
||||
if not isinstance(raw, list | tuple):
|
||||
raw = [raw]
|
||||
if not isinstance(pos, list | tuple):
|
||||
pos = [pos]
|
||||
if len(pos) != len(raw):
|
||||
raise ValueError(
|
||||
f"Number of head positions ({len(pos)}) must match the number of raw "
|
||||
f"instances ({len(raw)})"
|
||||
)
|
||||
hp = list()
|
||||
dt = list()
|
||||
for ri, (r, p) in enumerate(zip(raw, pos)):
|
||||
_validate_type(r, BaseRaw, f"raw[{ri}]")
|
||||
_validate_type(p, np.ndarray, f"pos[{ri}]")
|
||||
hp_, dt_ = _raw_hp_weights(r, p)
|
||||
hp.append(hp_)
|
||||
dt.append(dt_)
|
||||
hp = np.concatenate(hp, axis=0)
|
||||
dt = np.concatenate(dt, axis=0)
|
||||
dt /= dt.sum()
|
||||
best_q = _average_quats(hp[:, 1:4], weights=dt)
|
||||
trans = np.eye(4)
|
||||
trans[:3, :3] = quat_to_rot(best_q)
|
||||
trans[:3, 3] = dt @ hp[:, 4:7]
|
||||
dist = np.linalg.norm(trans[:3, 3])
|
||||
if dist > 1: # less than 1 meter is sane
|
||||
warn(f"Implausible head position detected: {dist} meters from device origin")
|
||||
dev_head_t = Transform("meg", "head", trans)
|
||||
return dev_head_t
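

# Illustrative sketch (not part of the module API): average the device-to-head
# transform across two runs of the same session, as supported since 1.7.
# ``raw_1``/``raw_2`` and ``pos_1``/``pos_2`` are assumed to be matching Raw
# objects and head-position arrays.
def _demo_average_dev_head_t(raw_1, raw_2, pos_1, pos_2):
    # One transform describing the average head position across both runs,
    # e.g. to use as a common destination when Maxwell filtering.
    return compute_average_dev_head_t([raw_1, raw_2], [pos_1, pos_2])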
|
||||
|
||||
|
||||
def _raw_hp_weights(raw, pos):
|
||||
sfreq = raw.info["sfreq"]
|
||||
seg_good = np.ones(len(raw.times))
|
||||
hp = pos.copy()
|
||||
hp_ts = hp[:, 0] - raw._first_time
|
||||
|
||||
# Check rounding issues at 0 time
|
||||
if hp_ts[0] < 0:
|
||||
hp_ts[0] = 0
|
||||
assert hp_ts[1] > 1.0 / sfreq
|
||||
|
||||
# Mask out segments if beyond scan time
|
||||
mask = hp_ts <= raw.times[-1]
|
||||
if not mask.all():
|
||||
logger.info(
|
||||
" Removing %d samples > raw.times[-1] (%s)",
|
||||
np.sum(~mask),
|
||||
raw.times[-1],
|
||||
)
|
||||
hp = hp[mask]
|
||||
del mask, hp_ts
|
||||
|
||||
# Get time indices
|
||||
ts = np.concatenate((hp[:, 0], [(raw.last_samp + 1) / sfreq]))
|
||||
assert (np.diff(ts) > 0).all()
|
||||
ts -= raw.first_samp / sfreq
|
||||
idx = raw.time_as_index(ts, use_rounding=True)
|
||||
del ts
|
||||
if idx[0] == -1: # annoying rounding errors
|
||||
idx[0] = 0
|
||||
assert idx[1] > 0
|
||||
assert (idx >= 0).all()
|
||||
assert idx[-1] == len(seg_good)
|
||||
assert (np.diff(idx) > 0).all()
|
||||
|
||||
# Mark times bad that are bad according to annotations
|
||||
onsets, ends = _annotations_starts_stops(raw, "bad")
|
||||
for onset, end in zip(onsets, ends):
|
||||
seg_good[onset:end] = 0
|
||||
dt = np.diff(np.cumsum(np.concatenate([[0], seg_good]))[idx])
|
||||
assert (dt >= 0).all()
|
||||
dt = dt / sfreq
|
||||
del seg_good, idx
|
||||
return hp, dt
|
||||
|
||||
|
||||
def _annotations_from_mask(times, mask, annot_name, orig_time=None):
|
||||
"""Construct annotations from boolean mask of the data."""
|
||||
mask_tf = distance_transform_edt(mask)
|
||||
# Overcome the shortcoming of find_peaks
|
||||
# in finding a marginal peak, by
|
||||
# inserting 0s at the front and the
|
||||
# rear, then subtracting in index
|
||||
ins_mask_tf = np.concatenate((np.zeros(1), mask_tf, np.zeros(1)))
|
||||
left_midpt_index = find_peaks(ins_mask_tf)[0] - 1
|
||||
right_midpt_index = (
|
||||
np.flip(len(ins_mask_tf) - 1 - find_peaks(ins_mask_tf[::-1])[0]) - 1
|
||||
)
|
||||
onsets_index = left_midpt_index - mask_tf[left_midpt_index].astype(int) + 1
|
||||
ends_index = right_midpt_index + mask_tf[right_midpt_index].astype(int)
|
||||
# Ensure onsets_index >= 0,
|
||||
# otherwise the duration starts from the beginning
|
||||
onsets_index[onsets_index < 0] = 0
|
||||
# Ensure ends_index < len(times),
|
||||
# otherwise the duration is to the end of times
|
||||
if len(times) == len(mask):
|
||||
ends_index[ends_index >= len(times)] = len(times) - 1
|
||||
# To be consistent with the original code,
|
||||
# possibly a bug in tests code
|
||||
else:
|
||||
ends_index[ends_index >= len(mask)] = len(mask)
|
||||
onsets = times[onsets_index]
|
||||
ends = times[ends_index]
|
||||
durations = ends - onsets
|
||||
desc = [annot_name] * len(durations)
|
||||
return Annotations(onsets, durations, desc, orig_time=orig_time)
|
||||
|
||||
|
||||
@verbose
|
||||
def annotate_break(
|
||||
raw,
|
||||
events=None,
|
||||
min_break_duration=15.0,
|
||||
t_start_after_previous=5.0,
|
||||
t_stop_before_next=5.0,
|
||||
ignore=("bad", "edge"),
|
||||
*,
|
||||
verbose=None,
|
||||
):
|
||||
"""Create `~mne.Annotations` for breaks in an ongoing recording.
|
||||
|
||||
This function first searches for segments in the data that are not
|
||||
annotated or do not contain any events and are at least
|
||||
``min_break_duration`` seconds long, and then proceeds to creating
|
||||
annotations for those break periods.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The continuous data to analyze.
|
||||
events : None | array, shape (n_events, 3)
|
||||
If ``None`` (default), operate based solely on the annotations present
|
||||
in ``raw``. If an events array, ignore any annotations in the raw data,
|
||||
and operate based on these events only.
|
||||
min_break_duration : float
|
||||
The minimum time span in seconds between the offset of one and the
|
||||
onset of the subsequent annotation (if ``events`` is ``None``) or
|
||||
between two consecutive events (if ``events`` is an array) to consider
|
||||
this period a "break". Defaults to 15 seconds.
|
||||
|
||||
.. note:: This value defines the minimum duration of a break period in
|
||||
the data, **not** the minimum duration of the generated
|
||||
annotations! See also ``t_start_after_previous`` and
|
||||
``t_stop_before_next`` for details.
|
||||
|
||||
t_start_after_previous, t_stop_before_next : float
|
||||
Specifies how far the to-be-created "break" annotation extends towards
|
||||
the two annotations or events spanning the break. This can be used to
|
||||
ensure e.g. that the break annotation doesn't start and end immediately
|
||||
with a stimulation event. If, for example, your data contains a break
|
||||
of 30 seconds between two stimuli, and ``t_start_after_previous`` is
|
||||
set to ``5`` and ``t_stop_before_next`` is set to ``3``, the break
|
||||
annotation will start 5 seconds after the first stimulus, and end 3
|
||||
seconds before the second stimulus, yielding an annotated break of
|
||||
``30 - 5 - 3 = 22`` seconds. Both default to 5 seconds.
|
||||
|
||||
.. note:: The beginning and the end of the recording will be annotated
|
||||
as breaks, too, if the period from recording start until the
|
||||
first annotation or event (or from last annotation or event
|
||||
until recording end) is at least ``min_break_duration``
|
||||
seconds long.
|
||||
|
||||
ignore : iterable of str
|
||||
Annotation descriptions starting with these strings will be ignored by
|
||||
the break-finding algorithm. The string comparison is case-insensitive,
|
||||
i.e., ``('bad',)`` and ``('BAD',)`` are equivalent. By default, all
|
||||
annotation descriptions starting with "bad" and annotations
|
||||
indicating "edges" (produced by data concatenation) will be
|
||||
ignored. Pass an empty list or tuple to take all existing annotations
|
||||
into account. If ``events`` is passed, this parameter has no effect.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
break_annotations : instance of Annotations
|
||||
The break annotations, each with the description ``'BAD_break'``. If
|
||||
no breaks could be found given the provided function parameters, an
|
||||
empty `~mne.Annotations` object will be returned.
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 0.24
|
||||
"""
|
||||
_validate_type(item=raw, item_name="raw", types=BaseRaw, type_name="Raw")
|
||||
_validate_type(item=events, item_name="events", types=(None, np.ndarray))
|
||||
|
||||
if min_break_duration - t_start_after_previous - t_stop_before_next <= 0:
|
||||
annot_dur = min_break_duration - t_start_after_previous - t_stop_before_next
|
||||
raise ValueError(
|
||||
f"The result of "
|
||||
f"min_break_duration - t_start_after_previous - "
|
||||
f"t_stop_before_next must be greater than 0, but it is: "
|
||||
f"{annot_dur}"
|
||||
)
|
||||
|
||||
if events is not None and events.size == 0:
|
||||
raise ValueError("The events array must not be empty.")
|
||||
|
||||
if events is not None or not ignore:
|
||||
ignore = tuple()
|
||||
else:
|
||||
ignore = tuple(ignore)
|
||||
|
||||
for item in ignore:
|
||||
_validate_type(item=item, types="str", item_name='All elements of "ignore"')
|
||||
|
||||
if events is None:
|
||||
annotations = raw.annotations.copy()
|
||||
if ignore:
|
||||
logger.info(
|
||||
f"Ignoring annotations with descriptions starting "
|
||||
f'with: {", ".join(ignore)}'
|
||||
)
|
||||
else:
|
||||
annotations = annotations_from_events(
|
||||
events=events, sfreq=raw.info["sfreq"], orig_time=raw.info["meas_date"]
|
||||
)
|
||||
|
||||
if not annotations:
|
||||
raise ValueError("Could not find (or generate) any annotations in your data.")
|
||||
|
||||
# Only keep annotations of interest and extract annotated time periods
|
||||
# Ignore case
|
||||
ignore = tuple(i.lower() for i in ignore)
|
||||
keep_mask = [True] * len(annotations)
|
||||
for idx, description in enumerate(annotations.description):
|
||||
description = description.lower()
|
||||
if any(description.startswith(i) for i in ignore):
|
||||
keep_mask[idx] = False
|
||||
|
||||
annotated_intervals = [
|
||||
[onset, onset + duration]
|
||||
for onset, duration in zip(
|
||||
annotations.onset[keep_mask], annotations.duration[keep_mask]
|
||||
)
|
||||
]
|
||||
|
||||
# Merge overlapping annotation intervals
|
||||
# Pre-load `merged_intervals` with the first interval to simplify
|
||||
# processing
|
||||
merged_intervals = [annotated_intervals[0]]
|
||||
for interval in annotated_intervals:
|
||||
merged_interval_stop = merged_intervals[-1][1]
|
||||
interval_start, interval_stop = interval
|
||||
|
||||
if interval_stop < merged_interval_stop:
|
||||
# Current interval ends sooner than the merged one; skip it
|
||||
continue
|
||||
elif (
|
||||
interval_start <= merged_interval_stop
|
||||
and interval_stop >= merged_interval_stop
|
||||
):
|
||||
# Expand duration of the merged interval
|
||||
merged_intervals[-1][1] = interval_stop
|
||||
else:
|
||||
# No overlap between the current interval and the existing merged
|
||||
# time period; proceed to the next interval
|
||||
merged_intervals.append(interval)
|
||||
|
||||
merged_intervals = np.array(merged_intervals)
|
||||
merged_intervals -= raw.first_time # work in zero-based time
|
||||
|
||||
# Now extract the actual break periods
|
||||
break_onsets = []
|
||||
break_durations = []
|
||||
|
||||
# Handle the time period up until the first annotation
|
||||
if 0 < merged_intervals[0][0] and merged_intervals[0][0] >= min_break_duration:
|
||||
onset = 0 # don't add t_start_after_previous here
|
||||
offset = merged_intervals[0][0] - t_stop_before_next
|
||||
duration = offset - onset
|
||||
break_onsets.append(onset)
|
||||
break_durations.append(duration)
|
||||
|
||||
# Handle the time period between first and last annotation
|
||||
for idx, _ in enumerate(merged_intervals[1:, :], start=1):
|
||||
this_start = merged_intervals[idx, 0]
|
||||
previous_stop = merged_intervals[idx - 1, 1]
|
||||
if this_start - previous_stop < min_break_duration:
|
||||
continue
|
||||
|
||||
onset = previous_stop + t_start_after_previous
|
||||
offset = this_start - t_stop_before_next
|
||||
duration = offset - onset
|
||||
break_onsets.append(onset)
|
||||
break_durations.append(duration)
|
||||
|
||||
# Handle the time period after the last annotation
|
||||
if (
|
||||
raw.times[-1] > merged_intervals[-1][1]
|
||||
and raw.times[-1] - merged_intervals[-1][1] >= min_break_duration
|
||||
):
|
||||
onset = merged_intervals[-1][1] + t_start_after_previous
|
||||
offset = raw.times[-1] # don't subtract t_stop_before_next here
|
||||
duration = offset - onset
|
||||
break_onsets.append(onset)
|
||||
break_durations.append(duration)
|
||||
|
||||
# Finally, create the break annotations
|
||||
break_annotations = Annotations(
|
||||
onset=break_onsets,
|
||||
duration=break_durations,
|
||||
description=["BAD_break"],
|
||||
orig_time=raw.info["meas_date"],
|
||||
)
|
||||
|
||||
# Log some info
|
||||
n_breaks = len(break_annotations)
|
||||
break_times = [
|
||||
f"{o:.1f} – {o + d:.1f} s [{d:.1f} s]"
|
||||
for o, d in zip(break_annotations.onset, break_annotations.duration)
|
||||
]
|
||||
break_times = "\n ".join(break_times)
|
||||
total_break_dur = sum(break_annotations.duration)
|
||||
fraction_breaks = total_break_dur / raw.times[-1]
|
||||
logger.info(
|
||||
f"\nDetected {n_breaks} break period{_pl(n_breaks)} of >= "
|
||||
f"{min_break_duration} s duration:\n {break_times}\n"
|
||||
f"In total, {round(100 * fraction_breaks, 1):.1f}% of the "
|
||||
f"data ({round(total_break_dur, 1):.1f} s) have been marked "
|
||||
f"as a break.\n"
|
||||
)
|
||||
_adjust_onset_meas_date(break_annotations, raw)
|
||||
|
||||
return break_annotations
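

# ---------------------------------------------------------------------------
# Usage sketch (illustration only): detect the pause between two widely spaced
# events in a synthetic recording. Sampling rate, event samples, and durations
# are arbitrary demo values.
if __name__ == "__main__":  # pragma: no cover
    import mne

    sfreq = 100.0
    info = mne.create_info(["EEG 001"], sfreq, "eeg")
    raw = mne.io.RawArray(np.zeros((1, int(60 * sfreq))), info)
    events = np.array([[100, 0, 1], [5000, 0, 1]])  # events at 1 s and 50 s
    breaks = annotate_break(raw, events=events, min_break_duration=15.0)
    print(breaks)  # expected: one BAD_break from 6 s to 45 s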
|
||||
50
mne/preprocessing/bads.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
from scipy.stats import zscore
|
||||
|
||||
|
||||
def _find_outliers(X, threshold=3.0, max_iter=2, tail=0):
|
||||
"""Find outliers based on iterated Z-scoring.
|
||||
|
||||
This procedure compares the absolute z-score against the threshold.
|
||||
After excluding local outliers, the comparison is repeated until no
|
||||
local outlier is present any more.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : np.ndarray of float, shape (n_elements,)
|
||||
The scores for which to find outliers.
|
||||
threshold : float
|
||||
The value above which a feature is classified as outlier.
|
||||
max_iter : int
|
||||
The maximum number of iterations.
|
||||
tail : {0, 1, -1}
|
||||
Whether to search for outliers on both extremes of the z-scores (0),
|
||||
or on just the positive (1) or negative (-1) side.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bad_idx : np.ndarray of int, shape (n_features,)
|
||||
The outlier indices.
|
||||
"""
|
||||
my_mask = np.zeros(len(X), dtype=bool)
|
||||
for _ in range(max_iter):
|
||||
X = np.ma.masked_array(X, my_mask)
|
||||
if tail == 0:
|
||||
this_z = np.abs(zscore(X))
|
||||
elif tail == 1:
|
||||
this_z = zscore(X)
|
||||
elif tail == -1:
|
||||
this_z = -zscore(X)
|
||||
else:
|
||||
raise ValueError(f"Tail parameter {tail} not recognised.")
|
||||
local_bad = this_z > threshold
|
||||
my_mask = np.max([my_mask, local_bad], 0)
|
||||
if not np.any(local_bad):
|
||||
break
|
||||
|
||||
bad_idx = np.where(my_mask)[0]
|
||||
return bad_idx
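

# ---------------------------------------------------------------------------
# Usage sketch (illustration only): scores drawn from a normal distribution
# with two artificially extreme entries; the iterated z-scoring should flag
# those two indices. The values are arbitrary demo numbers.
if __name__ == "__main__":  # pragma: no cover
    rng = np.random.default_rng(0)
    scores = rng.standard_normal(100)
    scores[[10, 42]] = [15.0, -12.0]  # two obvious outliers
    print(_find_outliers(scores, threshold=5.0))  # expected: [10 42]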
|
||||
167
mne/preprocessing/ctps_.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from scipy.signal import hilbert
|
||||
from scipy.special import logsumexp
|
||||
|
||||
|
||||
def _compute_normalized_phase(data):
|
||||
"""Compute normalized phase angles.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : ndarray, shape (n_epochs, n_sources, n_times)
|
||||
The data to compute the phase angles for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
phase_angles : ndarray, shape (n_epochs, n_sources, n_times)
|
||||
The normalized phase angles.
|
||||
"""
|
||||
return (np.angle(hilbert(data)) + np.pi) / (2 * np.pi)
|
||||
|
||||
|
||||
def ctps(data, is_raw=True):
|
||||
"""Compute cross-trial-phase-statistics [1].
|
||||
|
||||
Note: it is assumed that the sources have already been
appropriately filtered.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : ndarray, shape (n_epochs, n_channels, n_times)
|
||||
Any kind of data of dimensions trials, traces, features.
|
||||
is_raw : bool
|
||||
If True, it is assumed that the data have not been transformed to Hilbert
space and that the phase angles have not been normalized. Defaults to True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ks_dynamics : ndarray, shape (n_sources, n_times)
|
||||
The Kuiper statistics.
pk_dynamics : ndarray, shape (n_sources, n_times)
The normalized Kuiper index for ICA sources and
time slices.
|
||||
phase_angles : ndarray, shape (n_epochs, n_sources, n_times) | None
|
||||
The phase values for epochs, sources and time slices. If ``is_raw``
|
||||
is False, None is returned.
|
||||
|
||||
References
|
||||
----------
|
||||
[1] Dammers, J., Schiek, M., Boers, F., Silex, C., Zvyagintsev,
|
||||
M., Pietrzyk, U., Mathiak, K., 2008. Integration of amplitude
|
||||
and phase statistics for complete artifact removal in independent
|
||||
components of neuromagnetic recordings. Biomedical
|
||||
Engineering, IEEE Transactions on 55 (10), 2353-2362.
|
||||
"""
|
||||
if not data.ndim == 3:
|
||||
raise ValueError(f"Data must have 3 dimensions, not {data.ndim}.")
|
||||
|
||||
if is_raw:
|
||||
phase_angles = _compute_normalized_phase(data)
|
||||
else:
|
||||
phase_angles = data # phase angles can be computed externally
|
||||
|
||||
# initialize array for results
|
||||
ks_dynamics = np.zeros_like(phase_angles[0])
|
||||
pk_dynamics = np.zeros_like(phase_angles[0])
|
||||
|
||||
# calculate Kuiper's statistic for each source
|
||||
for ii, source in enumerate(np.transpose(phase_angles, [1, 0, 2])):
|
||||
ks, pk = kuiper(source)
|
||||
pk_dynamics[ii, :] = pk
|
||||
ks_dynamics[ii, :] = ks
|
||||
|
||||
return ks_dynamics, pk_dynamics, phase_angles if is_raw else None
|
||||
|
||||
|
||||
def kuiper(data, dtype=np.float64): # noqa: D401
|
||||
"""Kuiper's test of uniform distribution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : ndarray, shape (n_sources,) | (n_sources, n_times)
|
||||
Empirical distribution.
|
||||
dtype : str | obj
|
||||
The data type to be used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ks : ndarray
|
||||
Kuiper's statistic.
|
||||
pk : ndarray
|
||||
Normalized probability of Kuiper's statistic [0, 1].
|
||||
"""
|
||||
# if data is not a NumPy array, np.sort implicitly converts it and returns a copy
|
||||
# ! sort data array along first axis !
|
||||
data = np.sort(data, axis=0).astype(dtype)
|
||||
shape = data.shape
|
||||
n_dim = len(shape)
|
||||
n_trials = shape[0]
|
||||
|
||||
# create uniform cdf
|
||||
j1 = (np.arange(n_trials, dtype=dtype) + 1.0) / float(n_trials)
|
||||
j2 = np.arange(n_trials, dtype=dtype) / float(n_trials)
|
||||
if n_dim > 1: # single phase vector (n_trials)
|
||||
j1 = j1[:, np.newaxis]
|
||||
j2 = j2[:, np.newaxis]
|
||||
d1 = (j1 - data).max(axis=0)
|
||||
d2 = (data - j2).max(axis=0)
|
||||
n_eff = n_trials
|
||||
|
||||
d = d1 + d2 # Kuiper's statistic [n_time_slices]
|
||||
|
||||
return d, _prob_kuiper(d, n_eff, dtype=dtype)
|
||||
|
||||
|
||||
def _prob_kuiper(d, n_eff, dtype="f8"):
|
||||
"""Test for statistical significance against uniform distribution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d : float
|
||||
The Kuiper distance value.
|
||||
n_eff : int
|
||||
The effective number of elements.
|
||||
dtype : str | obj
|
||||
The data type to be used. Defaults to double precision floats.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pk_norm : float
|
||||
The normalized Kuiper value such that 0 < ``pk_norm`` < 1.
|
||||
|
||||
References
|
||||
----------
|
||||
[1] Stephens MA 1970. Journal of the Royal Statistical Society, ser. B,
|
||||
vol 32, pp 115-122.
|
||||
|
||||
[2] Kuiper NH 1962. Proceedings of the Koninklijke Nederlands Akademie
|
||||
van Wetenschappen, ser Vol 63 pp 38-47
|
||||
"""
|
||||
n_time_slices = np.size(d) # single value or vector
|
||||
n_points = 100
|
||||
|
||||
en = math.sqrt(n_eff)
|
||||
k_lambda = (en + 0.155 + 0.24 / en) * d # see [1]
|
||||
l2 = k_lambda**2.0
|
||||
j2 = (np.arange(n_points) + 1) ** 2
|
||||
j2 = j2.repeat(n_time_slices).reshape(n_points, n_time_slices)
|
||||
fact = 4.0 * j2 * l2 - 1.0
|
||||
|
||||
# compute normalized pK value in range [0,1]
|
||||
a = -2.0 * j2 * l2
|
||||
b = 2.0 * fact
|
||||
pk_norm = -logsumexp(a, b=b, axis=0) / (2.0 * n_eff)
|
||||
|
||||
# check for no difference to uniform cdf
|
||||
pk_norm = np.where(k_lambda < 0.4, 0.0, pk_norm)
|
||||
|
||||
# check for round off errors
|
||||
pk_norm = np.where(pk_norm > 1.0, 1.0, pk_norm)
|
||||
|
||||
return pk_norm
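

# ---------------------------------------------------------------------------
# Usage sketch (illustration only): one source that is phase-locked across
# epochs (a fixed 10 Hz sine) and one that is not (random noise). The
# phase-locked source should show clearly higher ``pk_dynamics`` values.
# Shapes and frequencies are arbitrary demo choices.
if __name__ == "__main__":  # pragma: no cover
    rng = np.random.default_rng(0)
    n_epochs, n_times, sfreq = 30, 200, 100.0
    times = np.arange(n_times) / sfreq
    locked = np.tile(np.sin(2 * np.pi * 10 * times), (n_epochs, 1))
    noise = rng.standard_normal((n_epochs, n_times))
    data = np.stack([locked, noise], axis=1)  # (n_epochs, n_sources, n_times)
    ks, pk, phases = ctps(data)
    print(pk[0].mean(), pk[1].mean())  # locked source well above the noise source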
|
||||
539
mne/preprocessing/ecg.py
Normal file
@@ -0,0 +1,539 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.meas_info import create_info
|
||||
from .._fiff.pick import _picks_to_idx, pick_channels, pick_types
|
||||
from ..annotations import _annotations_starts_stops
|
||||
from ..epochs import BaseEpochs, Epochs
|
||||
from ..evoked import Evoked
|
||||
from ..filter import filter_data
|
||||
from ..io import BaseRaw, RawArray
|
||||
from ..utils import int_like, logger, sum_squared, verbose, warn
|
||||
|
||||
|
||||
@verbose
|
||||
def qrs_detector(
|
||||
sfreq,
|
||||
ecg,
|
||||
thresh_value=0.6,
|
||||
levels=2.5,
|
||||
n_thresh=3,
|
||||
l_freq=5,
|
||||
h_freq=35,
|
||||
tstart=0,
|
||||
filter_length="10s",
|
||||
verbose=None,
|
||||
):
|
||||
"""Detect QRS component in ECG channels.
|
||||
|
||||
QRS is the main wave on the heart beat.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sfreq : float
|
||||
Sampling rate.
ecg : array
ECG signal.
thresh_value : float | str
QRS detection threshold. Can also be "auto" for automatic
selection of the threshold.
levels : float
Number of standard deviations from the mean to include for detection.
n_thresh : int
Maximum number of crossings.
l_freq : float
Low cutoff frequency of the band-pass filter applied to the ECG.
h_freq : float
High cutoff frequency of the band-pass filter applied to the ECG.
|
||||
%(tstart_ecg)s
|
||||
%(filter_length_ecg)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
events : array
|
||||
Indices of ECG peaks.
|
||||
"""
|
||||
win_size = int(round((60.0 * sfreq) / 120.0))
|
||||
|
||||
filtecg = filter_data(
|
||||
ecg,
|
||||
sfreq,
|
||||
l_freq,
|
||||
h_freq,
|
||||
None,
|
||||
filter_length,
|
||||
0.5,
|
||||
0.5,
|
||||
phase="zero-double",
|
||||
fir_window="hann",
|
||||
fir_design="firwin2",
|
||||
)
|
||||
|
||||
ecg_abs = np.abs(filtecg)
|
||||
init = int(sfreq)
|
||||
|
||||
n_samples_start = int(sfreq * tstart)
|
||||
ecg_abs = ecg_abs[n_samples_start:]
|
||||
|
||||
n_points = len(ecg_abs)
|
||||
|
||||
maxpt = np.empty(3)
|
||||
maxpt[0] = np.max(ecg_abs[:init])
|
||||
maxpt[1] = np.max(ecg_abs[init : init * 2])
|
||||
maxpt[2] = np.max(ecg_abs[init * 2 : init * 3])
|
||||
|
||||
init_max = np.mean(maxpt)
|
||||
|
||||
if thresh_value == "auto":
|
||||
thresh_runs = np.arange(0.3, 1.1, 0.05)
|
||||
elif isinstance(thresh_value, str):
|
||||
raise ValueError('threshold value must be "auto" or a float')
|
||||
else:
|
||||
thresh_runs = [thresh_value]
|
||||
|
||||
# Try a few thresholds (or just one)
|
||||
clean_events = list()
|
||||
for thresh_value in thresh_runs:
|
||||
thresh1 = init_max * thresh_value
|
||||
numcross = list()
|
||||
time = list()
|
||||
rms = list()
|
||||
ii = 0
|
||||
while ii < (n_points - win_size):
|
||||
window = ecg_abs[ii : ii + win_size]
|
||||
if window[0] > thresh1:
|
||||
max_time = np.argmax(window)
|
||||
time.append(ii + max_time)
|
||||
nx = np.sum(
|
||||
np.diff(((window > thresh1).astype(np.int64) == 1).astype(int))
|
||||
)
|
||||
numcross.append(nx)
|
||||
rms.append(np.sqrt(sum_squared(window) / window.size))
|
||||
ii += win_size
|
||||
else:
|
||||
ii += 1
|
||||
|
||||
if len(rms) == 0:
|
||||
rms.append(0.0)
|
||||
time.append(0.0)
|
||||
time = np.array(time)
|
||||
rms_mean = np.mean(rms)
|
||||
rms_std = np.std(rms)
|
||||
rms_thresh = rms_mean + (rms_std * levels)
|
||||
b = np.where(rms < rms_thresh)[0]
|
||||
a = np.array(numcross)[b]
|
||||
ce = time[b[a < n_thresh]]
|
||||
|
||||
ce += n_samples_start
|
||||
if ce.size > 0: # We actually found an event
|
||||
clean_events.append(ce)
|
||||
|
||||
if clean_events:
|
||||
# pick the best threshold; first get effective heart rates
|
||||
rates = np.array(
|
||||
[60.0 * len(cev) / (len(ecg) / float(sfreq)) for cev in clean_events]
|
||||
)
|
||||
|
||||
# now find heart rates that seem reasonable (infant through adult
|
||||
# athlete)
|
||||
idx = np.where(np.logical_and(rates <= 160.0, rates >= 40.0))[0]
|
||||
if idx.size > 0:
|
||||
ideal_rate = np.median(rates[idx]) # get close to the median
|
||||
else:
|
||||
ideal_rate = 80.0 # get close to a reasonable default
|
||||
|
||||
idx = np.argmin(np.abs(rates - ideal_rate))
|
||||
clean_events = clean_events[idx]
|
||||
else:
|
||||
clean_events = np.array([])
|
||||
|
||||
return clean_events
|
||||
|
||||
|
||||
@verbose
|
||||
def find_ecg_events(
|
||||
raw,
|
||||
event_id=999,
|
||||
ch_name=None,
|
||||
tstart=0.0,
|
||||
l_freq=5,
|
||||
h_freq=35,
|
||||
qrs_threshold="auto",
|
||||
filter_length="10s",
|
||||
return_ecg=False,
|
||||
reject_by_annotation=True,
|
||||
verbose=None,
|
||||
):
|
||||
"""Find ECG events by localizing the R wave peaks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data.
|
||||
%(event_id_ecg)s
|
||||
%(ch_name_ecg)s
|
||||
%(tstart_ecg)s
|
||||
%(l_freq_ecg_filter)s
|
||||
qrs_threshold : float | str
|
||||
QRS detection threshold, between 0 and 1. Can also be "auto" to
|
||||
automatically choose the threshold that generates a reasonable
|
||||
number of heartbeats (40-160 beats / min).
|
||||
%(filter_length_ecg)s
|
||||
return_ecg : bool
|
||||
Return the ECG data. This is especially useful if no ECG channel
|
||||
is present in the input data, so one will be synthesized (only works if MEG
|
||||
channels are present in the data). Defaults to ``False``.
|
||||
%(reject_by_annotation_all)s
|
||||
|
||||
.. versionadded:: 0.18
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
ecg_events : array
|
||||
The events corresponding to the peaks of the R waves.
|
||||
ch_ecg : int | None
|
||||
Index of channel used.
|
||||
average_pulse : float
|
||||
The estimated average pulse. If no ECG events could be found, this will
|
||||
be zero.
|
||||
ecg : array | None
|
||||
The ECG data of the synthesized ECG channel, if any. This will only
|
||||
be returned if ``return_ecg=True`` was passed.
|
||||
|
||||
See Also
|
||||
--------
|
||||
create_ecg_epochs
|
||||
compute_proj_ecg
|
||||
"""
|
||||
skip_by_annotation = ("edge", "bad") if reject_by_annotation else ()
|
||||
del reject_by_annotation
|
||||
idx_ecg = _get_ecg_channel_index(ch_name, raw)
|
||||
if idx_ecg is not None:
|
||||
logger.info(f"Using channel {raw.ch_names[idx_ecg]} to identify heart beats.")
|
||||
ecg = raw.get_data(picks=idx_ecg)
|
||||
else:
|
||||
ecg, _ = _make_ecg(raw, start=None, stop=None)
|
||||
assert ecg.ndim == 2 and ecg.shape[0] == 1
|
||||
ecg = ecg[0]
|
||||
# Deal with filtering the same way we do in raw, i.e. filter each good
|
||||
# segment
|
||||
onsets, ends = _annotations_starts_stops(
|
||||
raw, skip_by_annotation, "reject_by_annotation", invert=True
|
||||
)
|
||||
ecgs = list()
|
||||
max_idx = (ends - onsets).argmax()
|
||||
for si, (start, stop) in enumerate(zip(onsets, ends)):
|
||||
# Only output filter params once (for info level), and only warn
|
||||
# once about the length criterion (longest segment is too short)
|
||||
use_verbose = verbose if si == max_idx else "error"
|
||||
ecgs.append(
|
||||
filter_data(
|
||||
ecg[start:stop],
|
||||
raw.info["sfreq"],
|
||||
l_freq,
|
||||
h_freq,
|
||||
[0],
|
||||
filter_length,
|
||||
0.5,
|
||||
0.5,
|
||||
1,
|
||||
"fir",
|
||||
None,
|
||||
copy=False,
|
||||
phase="zero-double",
|
||||
fir_window="hann",
|
||||
fir_design="firwin2",
|
||||
verbose=use_verbose,
|
||||
)
|
||||
)
|
||||
ecg = np.concatenate(ecgs)
|
||||
|
||||
# detecting QRS and generating events. Since not user-controlled, don't
|
||||
# output filter params here (hardcode verbose=False)
|
||||
ecg_events = qrs_detector(
|
||||
raw.info["sfreq"],
|
||||
ecg,
|
||||
tstart=tstart,
|
||||
thresh_value=qrs_threshold,
|
||||
l_freq=None,
|
||||
h_freq=None,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# map ECG events back to original times
|
||||
remap = np.empty(len(ecg), int)
|
||||
offset = 0
|
||||
for start, stop in zip(onsets, ends):
|
||||
this_len = stop - start
|
||||
assert this_len >= 0
|
||||
remap[offset : offset + this_len] = np.arange(start, stop)
|
||||
offset += this_len
|
||||
assert offset == len(ecg)
|
||||
|
||||
if ecg_events.size > 0:
|
||||
ecg_events = remap[ecg_events]
|
||||
else:
|
||||
ecg_events = np.array([])
|
||||
|
||||
n_events = len(ecg_events)
|
||||
duration_sec = len(ecg) / raw.info["sfreq"] - tstart
|
||||
duration_min = duration_sec / 60.0
|
||||
average_pulse = n_events / duration_min
|
||||
logger.info(
|
||||
f"Number of ECG events detected : {n_events} "
|
||||
f"(average pulse {average_pulse} / min.)"
|
||||
)
|
||||
|
||||
ecg_events = np.array(
|
||||
[
|
||||
ecg_events + raw.first_samp,
|
||||
np.zeros(n_events, int),
|
||||
event_id * np.ones(n_events, int),
|
||||
]
|
||||
).T
|
||||
|
||||
out = (ecg_events, idx_ecg, average_pulse)
|
||||
ecg = ecg[np.newaxis] # backward compat output 2D
|
||||
if return_ecg:
|
||||
out += (ecg,)
|
||||
return out
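# A minimal usage sketch (hypothetical file path; assumes ``import mne`` and a
# recording that contains an ECG channel or MEG channels):
#
#     raw = mne.io.read_raw_fif("sample_audvis_raw.fif", preload=True)
#     ecg_events, ch_ecg, pulse = find_ecg_events(raw)
#     print(f"{len(ecg_events)} heartbeats, ~{pulse:.0f} bpm")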
|
||||
|
||||
|
||||
def _get_ecg_channel_index(ch_name, inst):
|
||||
"""Get ECG channel index, if no channel found returns None."""
|
||||
if ch_name is None:
|
||||
ecg_idx = pick_types(
|
||||
inst.info,
|
||||
meg=False,
|
||||
eeg=False,
|
||||
stim=False,
|
||||
eog=False,
|
||||
ecg=True,
|
||||
emg=False,
|
||||
ref_meg=False,
|
||||
exclude="bads",
|
||||
)
|
||||
else:
|
||||
if ch_name not in inst.ch_names:
|
||||
raise ValueError(f"{ch_name} not in channel list ({inst.ch_names})")
|
||||
ecg_idx = pick_channels(inst.ch_names, include=[ch_name])
|
||||
|
||||
if len(ecg_idx) == 0:
|
||||
return None
|
||||
|
||||
if len(ecg_idx) > 1:
|
||||
warn(
|
||||
f"More than one ECG channel found. Using only {inst.ch_names[ecg_idx[0]]}."
|
||||
)
|
||||
|
||||
return ecg_idx[0]
|
||||
|
||||
|
||||
@verbose
|
||||
def create_ecg_epochs(
|
||||
raw,
|
||||
ch_name=None,
|
||||
event_id=999,
|
||||
picks=None,
|
||||
tmin=-0.5,
|
||||
tmax=0.5,
|
||||
l_freq=8,
|
||||
h_freq=16,
|
||||
reject=None,
|
||||
flat=None,
|
||||
baseline=None,
|
||||
preload=True,
|
||||
keep_ecg=False,
|
||||
reject_by_annotation=True,
|
||||
decim=1,
|
||||
verbose=None,
|
||||
):
|
||||
"""Conveniently generate epochs around ECG artifact events.
|
||||
|
||||
%(create_ecg_epochs)s
|
||||
|
||||
.. note:: Filtering is only applied to the ECG channel while finding
|
||||
events. The resulting ``ecg_epochs`` will have no filtering
|
||||
applied (i.e., have the same filter properties as the input
|
||||
``raw`` instance).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data.
|
||||
%(ch_name_ecg)s
|
||||
%(event_id_ecg)s
|
||||
%(picks_all)s
|
||||
tmin : float
|
||||
Start time before event.
|
||||
tmax : float
|
||||
End time after event.
|
||||
%(l_freq_ecg_filter)s
|
||||
%(reject_epochs)s
|
||||
%(flat)s
|
||||
%(baseline_epochs)s
|
||||
preload : bool
|
||||
Preload epochs or not (default True). Must be True if
|
||||
keep_ecg is True.
|
||||
keep_ecg : bool
|
||||
When ECG is synthetically created (after picking), should it be added
|
||||
to the epochs? Must be False when synthetic channel is not used.
|
||||
Defaults to False.
|
||||
%(reject_by_annotation_epochs)s
|
||||
|
||||
.. versionadded:: 0.14.0
|
||||
%(decim)s
|
||||
|
||||
.. versionadded:: 0.21.0
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
ecg_epochs : instance of Epochs
|
||||
Data epoched around ECG R wave peaks.
|
||||
|
||||
See Also
|
||||
--------
|
||||
find_ecg_events
|
||||
compute_proj_ecg
|
||||
|
||||
Notes
|
||||
-----
|
||||
If you already have a list of R-peak times, or want to compute R-peaks
|
||||
outside MNE-Python using a different algorithm, the recommended approach is
|
||||
to call the :class:`~mne.Epochs` constructor directly, with your R-peaks
|
||||
formatted as an :term:`events` array (here we also demonstrate the relevant
|
||||
default values)::
|
||||
|
||||
mne.Epochs(raw, r_peak_events_array, tmin=-0.5, tmax=0.5,
|
||||
baseline=None, preload=True, proj=False) # doctest: +SKIP
|
||||
"""
|
||||
has_ecg = "ecg" in raw or ch_name is not None
|
||||
if keep_ecg and (has_ecg or not preload):
|
||||
raise ValueError(
|
||||
"keep_ecg can be True only if the ECG channel is "
|
||||
"created synthetically and preload=True."
|
||||
)
|
||||
|
||||
events, _, _, ecg = find_ecg_events(
|
||||
raw,
|
||||
ch_name=ch_name,
|
||||
event_id=event_id,
|
||||
l_freq=l_freq,
|
||||
h_freq=h_freq,
|
||||
return_ecg=True,
|
||||
reject_by_annotation=reject_by_annotation,
|
||||
)
|
||||
|
||||
picks = _picks_to_idx(raw.info, picks, "all", exclude=())
|
||||
|
||||
# create epochs around ECG events and baseline (important)
|
||||
ecg_epochs = Epochs(
|
||||
raw,
|
||||
events=events,
|
||||
event_id=event_id,
|
||||
tmin=tmin,
|
||||
tmax=tmax,
|
||||
proj=False,
|
||||
flat=flat,
|
||||
picks=picks,
|
||||
reject=reject,
|
||||
baseline=baseline,
|
||||
reject_by_annotation=reject_by_annotation,
|
||||
preload=preload,
|
||||
decim=decim,
|
||||
)
|
||||
|
||||
if keep_ecg:
|
||||
# We know we have created a synthetic channel and epochs are preloaded
|
||||
ecg_raw = RawArray(
|
||||
ecg,
|
||||
create_info(
|
||||
ch_names=["ECG-SYN"], sfreq=raw.info["sfreq"], ch_types=["ecg"]
|
||||
),
|
||||
first_samp=raw.first_samp,
|
||||
)
|
||||
with ecg_raw.info._unlock():
|
||||
ignore = ["ch_names", "chs", "nchan", "bads"]
|
||||
for k, v in raw.info.items():
|
||||
if k not in ignore:
|
||||
ecg_raw.info[k] = v
|
||||
syn_epochs = Epochs(
|
||||
ecg_raw,
|
||||
events=ecg_epochs.events,
|
||||
event_id=event_id,
|
||||
tmin=tmin,
|
||||
tmax=tmax,
|
||||
proj=False,
|
||||
picks=[0],
|
||||
baseline=baseline,
|
||||
decim=decim,
|
||||
preload=True,
|
||||
)
|
||||
ecg_epochs = ecg_epochs.add_channels([syn_epochs])
|
||||
|
||||
return ecg_epochs
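# Sketch of a typical call (assumes a preloaded Raw instance ``raw``):
#
#     ecg_epochs = create_ecg_epochs(raw, tmin=-0.5, tmax=0.5)
#     ecg_evoked = ecg_epochs.average()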
|
||||
|
||||
|
||||
@verbose
|
||||
def _make_ecg(inst, start, stop, reject_by_annotation=False, verbose=None):
|
||||
"""Create ECG signal from cross channel average."""
|
||||
if not any(c in inst for c in ["mag", "grad"]):
|
||||
raise ValueError(
|
||||
"Generating an artificial ECG channel can only be done for MEG data."
|
||||
)
|
||||
for ch in ["mag", "grad"]:
|
||||
if ch in inst:
|
||||
break
|
||||
logger.info(
|
||||
"Reconstructing ECG signal from {}".format(
|
||||
{"mag": "Magnetometers", "grad": "Gradiometers"}[ch]
|
||||
)
|
||||
)
|
||||
picks = pick_types(inst.info, meg=ch, eeg=False, ref_meg=False)
|
||||
|
||||
# Handle start/stop
|
||||
msg = (
|
||||
"integer arguments for the start and stop parameters are "
|
||||
"not supported for Epochs and Evoked objects. Please "
|
||||
"consider using float arguments specifying start and stop "
|
||||
"time in seconds."
|
||||
)
|
||||
begin_param_name = "tmin"
|
||||
if isinstance(start, int_like):
|
||||
if isinstance(inst, BaseRaw):
|
||||
# Raw has start param, can just use int
|
||||
begin_param_name = "start"
|
||||
else:
|
||||
raise ValueError(msg)
|
||||
|
||||
end_param_name = "tmax"
|
||||
if isinstance(start, int_like):
|
||||
if isinstance(inst, BaseRaw):
|
||||
# Raw has stop param, can just use int
|
||||
end_param_name = "stop"
|
||||
else:
|
||||
raise ValueError(msg)
|
||||
|
||||
kwargs = {begin_param_name: start, end_param_name: stop}
|
||||
|
||||
if isinstance(inst, BaseRaw):
|
||||
reject_by_annotation = "omit" if reject_by_annotation else None
|
||||
ecg, times = inst.get_data(
|
||||
picks,
|
||||
return_times=True,
|
||||
**kwargs,
|
||||
reject_by_annotation=reject_by_annotation,
|
||||
)
|
||||
elif isinstance(inst, BaseEpochs):
|
||||
ecg = np.hstack(inst.copy().get_data(picks, **kwargs))
|
||||
times = inst.times
|
||||
elif isinstance(inst, Evoked):
|
||||
ecg = inst.get_data(picks, **kwargs)
|
||||
times = inst.times
|
||||
return ecg.mean(0, keepdims=True), times
|
||||
342
mne/preprocessing/eog.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import pick_channels, pick_types
|
||||
from ..epochs import Epochs
|
||||
from ..filter import filter_data
|
||||
from ..utils import _pl, _validate_type, logger, verbose
|
||||
from ._peak_finder import peak_finder
|
||||
|
||||
|
||||
@verbose
|
||||
def find_eog_events(
|
||||
raw,
|
||||
event_id=998,
|
||||
l_freq=1,
|
||||
h_freq=10,
|
||||
filter_length="10s",
|
||||
ch_name=None,
|
||||
tstart=0,
|
||||
reject_by_annotation=False,
|
||||
thresh=None,
|
||||
verbose=None,
|
||||
):
|
||||
"""Locate EOG artifacts.
|
||||
|
||||
.. note:: To control true-positive and true-negative detection rates, you
|
||||
may adjust the ``thresh`` parameter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data.
|
||||
event_id : int
|
||||
The index to assign to found events.
|
||||
l_freq : float
|
||||
Low cut-off frequency to apply to the EOG channel in Hz.
|
||||
h_freq : float
|
||||
High cut-off frequency to apply to the EOG channel in Hz.
|
||||
filter_length : str | int | None
|
||||
Number of taps to use for filtering.
|
||||
%(ch_name_eog)s
|
||||
tstart : float
|
||||
Start detection after tstart seconds.
|
||||
reject_by_annotation : bool
|
||||
Whether to omit data that is annotated as bad.
|
||||
thresh : float | None
|
||||
Threshold to trigger the detection of an EOG event. This controls the
|
||||
thresholding of the underlying peak-finding algorithm. Larger values
|
||||
mean that fewer peaks (i.e., fewer EOG events) will be detected.
|
||||
If ``None``, use the default of ``(max(eog) - min(eog)) / 4``,
|
||||
with ``eog`` being the filtered EOG signal.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
eog_events : array
|
||||
Events.
|
||||
|
||||
See Also
|
||||
--------
|
||||
create_eog_epochs
|
||||
compute_proj_eog
|
||||
"""
|
||||
# Getting EOG Channel
|
||||
eog_inds = _get_eog_channel_index(ch_name, raw)
|
||||
eog_names = np.array(raw.ch_names)[eog_inds] # for logging
|
||||
logger.info(f"EOG channel index for this subject is: {eog_inds}")
|
||||
|
||||
# Reject bad segments.
|
||||
reject_by_annotation = "omit" if reject_by_annotation else None
|
||||
eog, times = raw.get_data(
|
||||
picks=eog_inds, reject_by_annotation=reject_by_annotation, return_times=True
|
||||
)
|
||||
times = times * raw.info["sfreq"] + raw.first_samp
|
||||
|
||||
eog_events = _find_eog_events(
|
||||
eog,
|
||||
ch_names=eog_names,
|
||||
event_id=event_id,
|
||||
l_freq=l_freq,
|
||||
h_freq=h_freq,
|
||||
sampling_rate=raw.info["sfreq"],
|
||||
first_samp=raw.first_samp,
|
||||
filter_length=filter_length,
|
||||
tstart=tstart,
|
||||
thresh=thresh,
|
||||
verbose=verbose,
|
||||
)
|
||||
# Map times to corresponding samples.
|
||||
eog_events[:, 0] = np.round(times[eog_events[:, 0] - raw.first_samp]).astype(int)
|
||||
return eog_events
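# Usage sketch (assumes a preloaded Raw instance ``raw`` with at least one EOG
# channel):
#
#     eog_events = find_eog_events(raw)
#     onsets = (eog_events[:, 0] - raw.first_samp) / raw.info["sfreq"]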
|
||||
|
||||
|
||||
@verbose
|
||||
def _find_eog_events(
|
||||
eog,
|
||||
*,
|
||||
ch_names,
|
||||
event_id,
|
||||
l_freq,
|
||||
h_freq,
|
||||
sampling_rate,
|
||||
first_samp,
|
||||
filter_length="10s",
|
||||
tstart=0.0,
|
||||
thresh=None,
|
||||
verbose=None,
|
||||
):
|
||||
"""Find EOG events."""
|
||||
logger.info(
|
||||
"Filtering the data to remove DC offset to help "
|
||||
"distinguish blinks from saccades"
|
||||
)
|
||||
|
||||
# filtering to remove dc offset so that we know which is blink and saccades
|
||||
# hardcode verbose=False to suppress filter param messages (since this
|
||||
# filter is not under user control)
|
||||
fmax = np.minimum(45, sampling_rate / 2.0 - 0.75) # protect Nyquist
|
||||
filteog = np.array(
|
||||
[
|
||||
filter_data(
|
||||
x,
|
||||
sampling_rate,
|
||||
2,
|
||||
fmax,
|
||||
None,
|
||||
filter_length,
|
||||
0.5,
|
||||
0.5,
|
||||
phase="zero-double",
|
||||
fir_window="hann",
|
||||
fir_design="firwin2",
|
||||
verbose=False,
|
||||
)
|
||||
for x in eog
|
||||
]
|
||||
)
|
||||
temp = np.sqrt(np.sum(filteog**2, axis=1))
|
||||
indexmax = np.argmax(temp)
|
||||
if ch_names is not None: # it can be None if called from ica_find_eog_events
|
||||
logger.info(f"Selecting channel {ch_names[indexmax]} for blink detection")
|
||||
|
||||
# easier to detect peaks with filtering.
|
||||
filteog = filter_data(
|
||||
eog[indexmax],
|
||||
sampling_rate,
|
||||
l_freq,
|
||||
h_freq,
|
||||
None,
|
||||
filter_length,
|
||||
0.5,
|
||||
0.5,
|
||||
phase="zero-double",
|
||||
fir_window="hann",
|
||||
fir_design="firwin2",
|
||||
)
|
||||
|
||||
# detecting eog blinks and generating event file
|
||||
|
||||
logger.info("Now detecting blinks and generating corresponding events")
|
||||
|
||||
temp = filteog - np.mean(filteog)
|
||||
n_samples_start = int(sampling_rate * tstart)
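# Blink deflections can be positive or negative depending on electrode
# placement and reference, so search for peaks of whichever polarity dominates.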
|
||||
if np.abs(np.max(temp)) > np.abs(np.min(temp)):
|
||||
eog_events, _ = peak_finder(filteog[n_samples_start:], thresh, extrema=1)
|
||||
else:
|
||||
eog_events, _ = peak_finder(filteog[n_samples_start:], thresh, extrema=-1)
|
||||
|
||||
eog_events += n_samples_start
|
||||
n_events = len(eog_events)
|
||||
logger.info(f"Number of EOG events detected: {n_events}")
|
||||
eog_events = np.array(
|
||||
[
|
||||
eog_events + first_samp,
|
||||
np.zeros(n_events, int),
|
||||
event_id * np.ones(n_events, int),
|
||||
]
|
||||
).T
|
||||
|
||||
return eog_events
|
||||
|
||||
|
||||
def _get_eog_channel_index(ch_name, inst):
|
||||
"""Get EOG channel indices."""
|
||||
_validate_type(ch_name, types=(None, str, list), item_name="ch_name")
|
||||
|
||||
if ch_name is None:
|
||||
eog_inds = pick_types(
|
||||
inst.info,
|
||||
meg=False,
|
||||
eeg=False,
|
||||
stim=False,
|
||||
eog=True,
|
||||
ecg=False,
|
||||
emg=False,
|
||||
ref_meg=False,
|
||||
exclude="bads",
|
||||
)
|
||||
if eog_inds.size == 0:
|
||||
raise RuntimeError("No EOG channel(s) found")
|
||||
ch_names = [inst.ch_names[i] for i in eog_inds]
|
||||
elif isinstance(ch_name, str):
|
||||
ch_names = [ch_name]
|
||||
else: # it's a list
|
||||
ch_names = ch_name.copy()
|
||||
|
||||
# ensure the specified channels are present in the data
|
||||
if ch_name is not None:
|
||||
not_found = [ch_name for ch_name in ch_names if ch_name not in inst.ch_names]
|
||||
if not_found:
|
||||
raise ValueError(
|
||||
f"The specified EOG channel{_pl(not_found)} "
|
||||
f'cannot be found: {", ".join(not_found)}'
|
||||
)
|
||||
|
||||
eog_inds = pick_channels(inst.ch_names, include=ch_names)
|
||||
|
||||
logger.info(f'Using EOG channel{_pl(ch_names)}: {", ".join(ch_names)}')
|
||||
return eog_inds
|
||||
|
||||
|
||||
@verbose
|
||||
def create_eog_epochs(
|
||||
raw,
|
||||
ch_name=None,
|
||||
event_id=998,
|
||||
picks=None,
|
||||
tmin=-0.5,
|
||||
tmax=0.5,
|
||||
l_freq=1,
|
||||
h_freq=10,
|
||||
reject=None,
|
||||
flat=None,
|
||||
baseline=None,
|
||||
preload=True,
|
||||
reject_by_annotation=True,
|
||||
thresh=None,
|
||||
decim=1,
|
||||
verbose=None,
|
||||
):
|
||||
"""Conveniently generate epochs around EOG artifact events.
|
||||
|
||||
%(create_eog_epochs)s
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data.
|
||||
%(ch_name_eog)s
|
||||
event_id : int
|
||||
The index to assign to found events.
|
||||
%(picks_all)s
|
||||
tmin : float
|
||||
Start time before event.
|
||||
tmax : float
|
||||
End time after event.
|
||||
l_freq : float
|
||||
Lower cut-off frequency (Hz) of the band-pass filter applied to the EOG channel while finding events.
|
||||
h_freq : float
|
||||
Upper cut-off frequency (Hz) of the band-pass filter applied to the EOG channel while finding events.
|
||||
reject : dict | None
|
||||
Rejection parameters based on peak-to-peak amplitude.
|
||||
Valid keys are 'grad' | 'mag' | 'eeg' | 'eog' | 'ecg'.
|
||||
If reject is None then no rejection is done. Example::
|
||||
|
||||
reject = dict(grad=4000e-13, # T / m (gradiometers)
|
||||
mag=4e-12, # T (magnetometers)
|
||||
eeg=40e-6, # V (EEG channels)
|
||||
eog=250e-6 # V (EOG channels)
|
||||
)
|
||||
|
||||
flat : dict | None
|
||||
Rejection parameters based on flatness of signal.
|
||||
Valid keys are 'grad' | 'mag' | 'eeg' | 'eog' | 'ecg', and values
|
||||
are floats that set the minimum acceptable peak-to-peak amplitude.
|
||||
If flat is None then no rejection is done.
|
||||
baseline : tuple or list of length 2, or None
|
||||
The time interval to apply rescaling / baseline correction.
|
||||
If None do not apply it. If baseline is (a, b)
|
||||
the interval is between "a (s)" and "b (s)".
|
||||
If a is None the beginning of the data is used
|
||||
and if b is None then b is set to the end of the interval.
|
||||
If baseline is equal to (None, None) all the time
|
||||
interval is used. If None, no correction is applied.
|
||||
preload : bool
|
||||
Preload epochs or not.
|
||||
%(reject_by_annotation_epochs)s
|
||||
|
||||
.. versionadded:: 0.14.0
|
||||
thresh : float
|
||||
Threshold to trigger EOG event.
|
||||
%(decim)s
|
||||
|
||||
.. versionadded:: 0.21.0
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
eog_epochs : instance of Epochs
|
||||
Data epoched around EOG events.
|
||||
|
||||
See Also
|
||||
--------
|
||||
find_eog_events
|
||||
compute_proj_eog
|
||||
|
||||
Notes
|
||||
-----
|
||||
Filtering is only applied to the EOG channel while finding events.
|
||||
The resulting ``eog_epochs`` will have no filtering applied (i.e., have
|
||||
the same filter properties as the input ``raw`` instance).
|
||||
"""
|
||||
events = find_eog_events(
|
||||
raw,
|
||||
ch_name=ch_name,
|
||||
event_id=event_id,
|
||||
l_freq=l_freq,
|
||||
h_freq=h_freq,
|
||||
reject_by_annotation=reject_by_annotation,
|
||||
thresh=thresh,
|
||||
)
|
||||
|
||||
# create epochs around EOG events
|
||||
eog_epochs = Epochs(
|
||||
raw,
|
||||
events=events,
|
||||
event_id=event_id,
|
||||
tmin=tmin,
|
||||
tmax=tmax,
|
||||
proj=False,
|
||||
reject=reject,
|
||||
flat=flat,
|
||||
picks=picks,
|
||||
baseline=baseline,
|
||||
preload=preload,
|
||||
reject_by_annotation=reject_by_annotation,
|
||||
decim=decim,
|
||||
)
|
||||
return eog_epochs
|
||||
10
mne/preprocessing/eyetracking/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Eye tracking specific preprocessing functions."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from .eyetracking import set_channel_types_eyetrack, convert_units
|
||||
from .calibration import Calibration, read_eyelink_calibration
|
||||
from ._pupillometry import interpolate_blinks
|
||||
from .utils import get_screen_visual_angle
|
||||
121
mne/preprocessing/eyetracking/_pupillometry.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..._fiff.constants import FIFF
|
||||
from ...annotations import _annotations_starts_stops
|
||||
from ...io import BaseRaw
|
||||
from ...utils import _check_preload, _validate_type, logger, warn
|
||||
|
||||
|
||||
def interpolate_blinks(raw, buffer=0.05, match="BAD_blink", interpolate_gaze=False):
|
||||
"""Interpolate eyetracking signals during blinks.
|
||||
|
||||
This function uses the timing of blink annotations to estimate missing
|
||||
data. Missing values are then interpolated linearly. Operates in place.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data with at least one ``'pupil'`` or ``'eyegaze'`` channel.
|
||||
buffer : float | array-like of float, shape ``(2,)``
|
||||
The time in seconds before and after a blink to consider invalid and
|
||||
include in the segment to be interpolated over. Default is ``0.05`` seconds
|
||||
(50 ms). If array-like, the first element is the time before the blink and the
|
||||
second element is the time after the blink to consider invalid, for example,
|
||||
``(0.025, .1)``.
|
||||
match : str | list of str
|
||||
The description of annotations to interpolate over. If a list, the data within
|
||||
all annotations that match any of the strings in the list will be interpolated
|
||||
over. If a ``match`` starts with ``'BAD_'``, that part will be removed from the
|
||||
annotation description after interpolation. Defaults to ``'BAD_blink'``.
|
||||
interpolate_gaze : bool
|
||||
If False, only apply interpolation to ``'pupil'`` channels. If True, interpolate
|
||||
over ``'eyegaze'`` channels as well. Defaults to False, because eye position can
|
||||
change in unpredictable ways during blinks.
|
||||
|
||||
Returns
|
||||
-------
|
||||
self : instance of Raw
|
||||
Returns the modified instance.
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 1.5
|
||||
"""
|
||||
_check_preload(raw, "interpolate_blinks")
|
||||
_validate_type(raw, BaseRaw, "raw")
|
||||
_validate_type(buffer, (float, tuple, list, np.ndarray), "buffer")
|
||||
_validate_type(match, (str, tuple, list, np.ndarray), "match")
|
||||
|
||||
# determine the buffer around blinks to include in the interpolation
|
||||
buffer = np.array(buffer, dtype=float)
|
||||
if buffer.size == 1:
|
||||
buffer = np.array([buffer, buffer])
|
||||
|
||||
if isinstance(match, str):
|
||||
match = [match]
|
||||
|
||||
# get the blink annotations
|
||||
blink_annots = [annot for annot in raw.annotations if annot["description"] in match]
|
||||
if not blink_annots:
|
||||
warn(f"No annotations matching {match} found. Aborting.")
|
||||
return raw
|
||||
_interpolate_blinks(raw, buffer, blink_annots, interpolate_gaze=interpolate_gaze)
|
||||
|
||||
# remove bad from the annotation description
|
||||
for desc in match:
|
||||
if desc.startswith("BAD_"):
|
||||
logger.info(f"Removing 'BAD_' from {desc}.")
|
||||
raw.annotations.rename({desc: desc.replace("BAD_", "")})
|
||||
return raw
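# Usage sketch (assumes a preloaded eye-tracking Raw instance ``raw`` carrying
# "BAD_blink" annotations, e.g. as created by the Eyelink reader):
#
#     interpolate_blinks(raw, buffer=(0.05, 0.2), interpolate_gaze=True)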
|
||||
|
||||
|
||||
def _interpolate_blinks(raw, buffer, blink_annots, interpolate_gaze):
|
||||
"""Interpolate eyetracking signals during blinks in-place."""
|
||||
logger.info("Interpolating missing data during blinks...")
|
||||
pre_buffer, post_buffer = buffer
|
||||
# iterate over each eyetrack channel and interpolate the blinks
|
||||
interpolated_chs = []
|
||||
for ci, ch_info in enumerate(raw.info["chs"]):
|
||||
if interpolate_gaze: # interpolate over all eyetrack channels
|
||||
if ch_info["kind"] != FIFF.FIFFV_EYETRACK_CH:
|
||||
continue
|
||||
else: # interpolate over pupil channels only
|
||||
if ch_info["coil_type"] != FIFF.FIFFV_COIL_EYETRACK_PUPIL:
|
||||
continue
|
||||
# Create an empty boolean mask
|
||||
mask = np.zeros_like(raw.times, dtype=bool)
|
||||
starts, ends = _annotations_starts_stops(raw, "BAD_blink")
|
||||
starts = np.divide(starts, raw.info["sfreq"])
|
||||
ends = np.divide(ends, raw.info["sfreq"])
|
||||
for annot, start, end in zip(blink_annots, starts, ends):
|
||||
if "ch_names" not in annot or not annot["ch_names"]:
|
||||
msg = f"Blink annotation missing values for 'ch_names' key: {annot}"
|
||||
raise ValueError(msg)
|
||||
start -= pre_buffer
|
||||
end += post_buffer
|
||||
if ch_info["ch_name"] not in annot["ch_names"]:
|
||||
continue # skip if the channel is not in the blink annotation
|
||||
# Update the mask for times within the current blink period
|
||||
mask |= (raw.times >= start) & (raw.times <= end)
|
||||
blink_indices = np.where(mask)[0]
|
||||
non_blink_indices = np.where(~mask)[0]
|
||||
|
||||
# Linear interpolation
|
||||
interpolated_samples = np.interp(
|
||||
raw.times[blink_indices],
|
||||
raw.times[non_blink_indices],
|
||||
raw._data[ci, non_blink_indices],
|
||||
)
|
||||
# Replace the samples at the blink_indices with the interpolated values
|
||||
raw._data[ci, blink_indices] = interpolated_samples
|
||||
interpolated_chs.append(ch_info["ch_name"])
|
||||
if interpolated_chs:
|
||||
logger.info(
|
||||
f"Interpolated {len(interpolated_chs)} channels: {interpolated_chs}"
|
||||
)
|
||||
else:
|
||||
warn("No channels were interpolated.")
|
||||
222
mne/preprocessing/eyetracking/calibration.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Eyetracking Calibration(s) class constructor."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...io.eyelink._utils import _parse_calibration
|
||||
from ...utils import _check_fname, _validate_type, fill_doc, logger
|
||||
from ...viz.utils import plt_show
|
||||
|
||||
|
||||
@fill_doc
|
||||
class Calibration(dict):
|
||||
"""Eye-tracking calibration info.
|
||||
|
||||
This data structure behaves like a dictionary. It contains information regarding a
|
||||
calibration that was conducted during an eye-tracking recording.
|
||||
|
||||
.. note::
|
||||
When possible, a Calibration instance should be created with a helper function,
|
||||
such as :func:`~mne.preprocessing.eyetracking.read_eyelink_calibration`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
onset : float
|
||||
The onset of the calibration in seconds. If the calibration was
|
||||
performed before the recording started, the onset can be
|
||||
negative.
|
||||
model : str
|
||||
A string, which is the model of the eye-tracking calibration that was applied.
|
||||
For example ``'H3'`` for a horizontal only 3-point calibration, or ``'HV3'``
|
||||
for a horizontal and vertical 3-point calibration.
|
||||
eye : str
|
||||
The eye that was calibrated. For example, ``'left'``, or ``'right'``.
|
||||
avg_error : float
|
||||
The average error in degrees between the calibration positions and the
|
||||
actual gaze position.
|
||||
max_error : float
|
||||
The maximum error in degrees that occurred between the calibration
|
||||
positions and the actual gaze position.
|
||||
positions : array-like of float, shape ``(n_calibration_points, 2)``
|
||||
The x and y coordinates of the calibration points.
|
||||
offsets : array-like of float, shape ``(n_calibration_points,)``
|
||||
The error in degrees between the calibration position and the actual
|
||||
gaze position for each calibration point.
|
||||
gaze : array-like of float, shape ``(n_calibration_points, 2)``
|
||||
The x and y coordinates of the actual gaze position for each calibration point.
|
||||
screen_size : array-like of shape ``(2,)``
|
||||
The width and height (in meters) of the screen that the eyetracking
|
||||
data was collected with. For example ``(.531, .298)`` for a monitor with
|
||||
a display area of 531 x 298 mm.
|
||||
screen_distance : float
|
||||
The distance (in meters) from the participant's eyes to the screen.
|
||||
screen_resolution : array-like of shape ``(2,)``
|
||||
The resolution (in pixels) of the screen that the eyetracking data
|
||||
was collected with. For example, ``(1920, 1080)`` for a 1920x1080
|
||||
resolution display.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
onset,
|
||||
model,
|
||||
eye,
|
||||
avg_error,
|
||||
max_error,
|
||||
positions,
|
||||
offsets,
|
||||
gaze,
|
||||
screen_size=None,
|
||||
screen_distance=None,
|
||||
screen_resolution=None,
|
||||
):
|
||||
super().__init__(
|
||||
onset=onset,
|
||||
model=model,
|
||||
eye=eye,
|
||||
avg_error=avg_error,
|
||||
max_error=max_error,
|
||||
screen_size=screen_size,
|
||||
screen_distance=screen_distance,
|
||||
screen_resolution=screen_resolution,
|
||||
positions=positions,
|
||||
offsets=offsets,
|
||||
gaze=gaze,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a summary of the Calibration object."""
|
||||
return (
|
||||
f"Calibration |\n"
|
||||
f" onset: {self['onset']} seconds\n"
|
||||
f" model: {self['model']}\n"
|
||||
f" eye: {self['eye']}\n"
|
||||
f" average error: {self['avg_error']} degrees\n"
|
||||
f" max error: {self['max_error']} degrees\n"
|
||||
f" screen size: {self['screen_size']} meters\n"
|
||||
f" screen distance: {self['screen_distance']} meters\n"
|
||||
f" screen resolution: {self['screen_resolution']} pixels\n"
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
"""Copy the instance.
|
||||
|
||||
Returns
|
||||
-------
|
||||
cal : instance of Calibration
|
||||
The copied Calibration.
|
||||
"""
|
||||
return deepcopy(self)
|
||||
|
||||
def plot(self, show_offsets=True, axes=None, show=True):
|
||||
"""Visualize calibration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
show_offsets : bool
|
||||
Whether to display the offset (in visual degrees) of each calibration
|
||||
point or not. Defaults to ``True``.
|
||||
axes : instance of matplotlib.axes.Axes | None
|
||||
Axes to draw the calibration positions to. If ``None`` (default), a new axes
|
||||
will be created.
|
||||
show : bool
|
||||
Whether to show the figure or not. Defaults to ``True``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib.figure.Figure
|
||||
The resulting figure object for the calibration plot.
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
msg = "positions and gaze keys must both be 2D numpy arrays."
|
||||
assert isinstance(self["positions"], np.ndarray), msg
|
||||
assert isinstance(self["gaze"], np.ndarray), msg
|
||||
|
||||
if axes is not None:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
_validate_type(axes, Axes, "axes")
|
||||
ax = axes
|
||||
fig = ax.get_figure()
|
||||
else: # create new figure and axes
|
||||
fig, ax = plt.subplots(layout="constrained")
|
||||
px, py = self["positions"].T
|
||||
gaze_x, gaze_y = self["gaze"].T
|
||||
|
||||
ax.set_title(f"Calibration ({self['eye']} eye)")
|
||||
ax.set_xlabel("x (pixels)")
|
||||
ax.set_ylabel("y (pixels)")
|
||||
|
||||
# Display avg_error and max_error in the top left corner
|
||||
text = (
|
||||
f"avg_error: {self['avg_error']} deg.\nmax_error: {self['max_error']} deg."
|
||||
)
|
||||
ax.text(
|
||||
0,
|
||||
1.01,
|
||||
text,
|
||||
transform=ax.transAxes,
|
||||
verticalalignment="baseline",
|
||||
fontsize=8,
|
||||
)
|
||||
|
||||
# Invert y-axis because the origin is in the top left corner
|
||||
ax.invert_yaxis()
|
||||
ax.scatter(px, py, color="gray")
|
||||
ax.scatter(gaze_x, gaze_y, color="red", alpha=0.5)
|
||||
|
||||
if show_offsets:
|
||||
for i in range(len(px)):
|
||||
x_offset = 0.01 * gaze_x[i] # 1% to the right of the gazepoint
|
||||
text = ax.text(
|
||||
x=gaze_x[i] + x_offset,
|
||||
y=gaze_y[i],
|
||||
s=self["offsets"][i],
|
||||
fontsize=8,
|
||||
ha="left",
|
||||
va="center",
|
||||
)
|
||||
|
||||
plt_show(show)
|
||||
return fig
|
||||
|
||||
|
||||
@fill_doc
|
||||
def read_eyelink_calibration(
|
||||
fname, screen_size=None, screen_distance=None, screen_resolution=None
|
||||
):
|
||||
"""Return info on calibrations collected in an eyelink file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : path-like
|
||||
Path to the eyelink file (.asc).
|
||||
screen_size : array-like of shape ``(2,)``
|
||||
The width and height (in meters) of the screen that the eyetracking
|
||||
data was collected with. For example ``(.531, .298)`` for a monitor with
|
||||
a display area of 531 x 298 mm. Defaults to ``None``.
|
||||
screen_distance : float
|
||||
The distance (in meters) from the participant's eyes to the screen.
|
||||
Defaults to ``None``.
|
||||
screen_resolution : array-like of shape ``(2,)``
|
||||
The resolution (in pixels) of the screen that the eyetracking data
|
||||
was collected with. For example, ``(1920, 1080)`` for a 1920x1080
|
||||
resolution display. Defaults to ``None``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
calibrations : list
|
||||
A list of :class:`~mne.preprocessing.eyetracking.Calibration` instances, one for
|
||||
each eye of every calibration that was performed during the recording session.
|
||||
"""
|
||||
fname = _check_fname(fname, overwrite="read", must_exist=True, name="fname")
|
||||
logger.info(f"Reading calibration data from {fname}")
|
||||
lines = fname.read_text(encoding="ASCII").splitlines()
|
||||
return _parse_calibration(lines, screen_size, screen_distance, screen_resolution)
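# Usage sketch (hypothetical file name; screen geometry values are examples):
#
#     cals = read_eyelink_calibration(
#         "sub-01_task-reading_eyetrack.asc",
#         screen_size=(0.531, 0.298),
#         screen_distance=0.7,
#         screen_resolution=(1920, 1080),
#     )
#     cals[0].plot()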
|
||||
327
mne/preprocessing/eyetracking/eyetracking.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..._fiff.constants import FIFF
|
||||
from ...epochs import BaseEpochs
|
||||
from ...evoked import Evoked
|
||||
from ...io import BaseRaw
|
||||
from ...utils import _check_option, _validate_type, logger, warn
|
||||
from .calibration import Calibration
|
||||
from .utils import _check_calibration
|
||||
|
||||
|
||||
# specific function to set eyetrack channels
|
||||
def set_channel_types_eyetrack(inst, mapping):
|
||||
"""Define sensor type for eyetrack channels.
|
||||
|
||||
This function can set all eye tracking specific information:
|
||||
channel type, unit, eye (and x/y component; only for gaze channels)
|
||||
|
||||
Supported channel types:
|
||||
``'eyegaze'`` and ``'pupil'``
|
||||
|
||||
Supported units:
|
||||
``'au'``, ``'px'``, ``'deg'``, ``'rad'`` (for eyegaze)
|
||||
``'au'``, ``'mm'``, ``'m'`` (for pupil)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : instance of Raw, Epochs, or Evoked
|
||||
The data instance.
|
||||
mapping : dict
|
||||
A dictionary mapping a channel to a list/tuple including
|
||||
channel type, unit, eye, [and x/y component] (all as str), e.g.,
|
||||
``{'l_x': ('eyegaze', 'deg', 'left', 'x')}`` or
|
||||
``{'r_pupil': ('pupil', 'au', 'right')}``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
inst : instance of Raw | Epochs | Evoked
|
||||
The instance, modified in place.
|
||||
|
||||
Notes
|
||||
-----
|
||||
``inst.set_channel_types()`` to ``'eyegaze'`` or ``'pupil'``
|
||||
works as well, but cannot correctly set unit, eye and x/y component.
|
||||
|
||||
Data will be stored in SI units:
|
||||
if your data comes in ``deg`` (visual angle) it will be converted to
|
||||
``rad``, if it is in ``mm`` it will be converted to ``m``.
|
||||
"""
|
||||
ch_names = inst.info["ch_names"]
|
||||
|
||||
# allowed
|
||||
valid_types = ["eyegaze", "pupil"] # ch_type
|
||||
valid_units = {
|
||||
"px": ["px", "pixel"],
|
||||
"rad": ["rad", "radian", "radians"],
|
||||
"deg": ["deg", "degree", "degrees"],
|
||||
"m": ["m", "meter", "meters"],
|
||||
"mm": ["mm", "millimeter", "millimeters"],
|
||||
"au": [None, "none", "au", "arbitrary"],
|
||||
}
|
||||
valid_units["all"] = [item for sublist in valid_units.values() for item in sublist]
|
||||
valid_eye = {"l": ["left", "l"], "r": ["right", "r"]}
|
||||
valid_eye["all"] = [item for sublist in valid_eye.values() for item in sublist]
|
||||
valid_xy = {"x": ["x", "h", "horizontal"], "y": ["y", "v", "vertical"]}
|
||||
valid_xy["all"] = [item for sublist in valid_xy.values() for item in sublist]
|
||||
|
||||
# loop over channels
|
||||
for ch_name, ch_desc in mapping.items():
|
||||
if ch_name not in ch_names:
|
||||
raise ValueError(f"This channel name ({ch_name}) doesn't exist in info.")
|
||||
c_ind = ch_names.index(ch_name)
|
||||
|
||||
# set ch_type and unit
|
||||
ch_type = ch_desc[0].lower()
|
||||
if ch_type not in valid_types:
|
||||
raise ValueError(
|
||||
f"ch_type must be one of {valid_types}. Got '{ch_type}' instead."
|
||||
)
|
||||
if ch_type == "eyegaze":
|
||||
coil_type = FIFF.FIFFV_COIL_EYETRACK_POS
|
||||
elif ch_type == "pupil":
|
||||
coil_type = FIFF.FIFFV_COIL_EYETRACK_PUPIL
|
||||
inst.info["chs"][c_ind]["coil_type"] = coil_type
|
||||
inst.info["chs"][c_ind]["kind"] = FIFF.FIFFV_EYETRACK_CH
|
||||
|
||||
ch_unit = None if (ch_desc[1] is None) else ch_desc[1].lower()
|
||||
if ch_unit not in valid_units["all"]:
|
||||
raise ValueError(
|
||||
"unit must be one of {}. Got '{}' instead.".format(
|
||||
valid_units["all"], ch_unit
|
||||
)
|
||||
)
|
||||
if ch_unit in valid_units["px"]:
|
||||
unit_new = FIFF.FIFF_UNIT_PX
|
||||
elif ch_unit in valid_units["rad"]:
|
||||
unit_new = FIFF.FIFF_UNIT_RAD
|
||||
elif ch_unit in valid_units["deg"]: # convert deg to rad (SI)
|
||||
inst = inst.apply_function(_convert_deg_to_rad, picks=ch_name)
|
||||
unit_new = FIFF.FIFF_UNIT_RAD
|
||||
elif ch_unit in valid_units["m"]:
|
||||
unit_new = FIFF.FIFF_UNIT_M
|
||||
elif ch_unit in valid_units["mm"]: # convert mm to m (SI)
|
||||
inst = inst.apply_function(_convert_mm_to_m, picks=ch_name)
|
||||
unit_new = FIFF.FIFF_UNIT_M
|
||||
elif ch_unit in valid_units["au"]:
|
||||
unit_new = FIFF.FIFF_UNIT_NONE
|
||||
inst.info["chs"][c_ind]["unit"] = unit_new
|
||||
|
||||
# set eye (and x/y-component)
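# Convention used below: loc[3] encodes the eye (-1 = left, +1 = right) and,
# for eyegaze channels, loc[4] encodes the component (-1 = x/horizontal,
# +1 = y/vertical).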
|
||||
loc = np.array(
|
||||
[
|
||||
np.nan,
|
||||
np.nan,
|
||||
np.nan,
|
||||
np.nan,
|
||||
np.nan,
|
||||
np.nan,
|
||||
np.nan,
|
||||
np.nan,
|
||||
np.nan,
|
||||
np.nan,
|
||||
np.nan,
|
||||
np.nan,
|
||||
]
|
||||
)
|
||||
|
||||
ch_eye = ch_desc[2].lower()
|
||||
if ch_eye not in valid_eye["all"]:
|
||||
raise ValueError(
|
||||
"eye must be one of {}. Got '{}' instead.".format(
|
||||
valid_eye["all"], ch_eye
|
||||
)
|
||||
)
|
||||
if ch_eye in valid_eye["l"]:
|
||||
loc[3] = -1
|
||||
elif ch_eye in valid_eye["r"]:
|
||||
loc[3] = 1
|
||||
|
||||
if ch_type == "eyegaze":
|
||||
ch_xy = ch_desc[3].lower()
|
||||
if ch_xy not in valid_xy["all"]:
|
||||
raise ValueError(
|
||||
"x/y must be one of {}. Got '{}' instead.".format(
|
||||
valid_xy["all"], ch_xy
|
||||
)
|
||||
)
|
||||
if ch_xy in valid_xy["x"]:
|
||||
loc[4] = -1
|
||||
elif ch_xy in valid_xy["y"]:
|
||||
loc[4] = 1
|
||||
|
||||
inst.info["chs"][c_ind]["loc"] = loc
|
||||
|
||||
return inst
|
||||
|
||||
|
||||
def _convert_mm_to_m(array):
|
||||
return array * 0.001
|
||||
|
||||
|
||||
def _convert_deg_to_rad(array):
|
||||
return array * np.pi / 180.0
|
||||
|
||||
|
||||
def convert_units(inst, calibration, to="radians"):
|
||||
"""Convert Eyegaze data from pixels to radians of visual angle or vice versa.
|
||||
|
||||
.. warning::
|
||||
Currently, depending on the units (pixels or radians), eyegaze channels may not
|
||||
be reported correctly in visualization functions like :meth:`mne.io.Raw.plot`.
|
||||
They will be shown correctly in :func:`mne.viz.eyetracking.plot_gaze`.
|
||||
See :gh:`11879` for more information.
|
||||
|
||||
.. Important::
|
||||
There are important considerations to keep in mind when using this function,
|
||||
see the Notes section below.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : instance of Raw, Epochs, or Evoked
|
||||
The Raw, Epochs, or Evoked instance with eyegaze channels.
|
||||
calibration : Calibration
|
||||
Instance of Calibration, containing information about the screen size
|
||||
(in meters), viewing distance (in meters), and the screen resolution
|
||||
(in pixels).
|
||||
to : str
|
||||
Must be either ``"radians"`` or ``"pixels"``, indicating the desired unit.
|
||||
|
||||
Returns
|
||||
-------
|
||||
inst : instance of Raw | Epochs | Evoked
|
||||
The Raw, Epochs, or Evoked instance, modified in place.
|
||||
|
||||
Notes
|
||||
-----
|
||||
There are at least two important considerations to keep in mind when using this
|
||||
function:
|
||||
|
||||
1. Converting between on-screen pixels and visual angle is not a linear
|
||||
transformation. If the visual angle subtends less than approximately ``.44``
|
||||
radians (``25`` degrees), the conversion can be considered approximately
|
||||
linear. However, as the visual angle increases, the conversion becomes
|
||||
increasingly non-linear. This may lead to unexpected results after converting
|
||||
between pixels and visual angle.
|
||||
|
||||
* This function assumes that the head is fixed in place and aligned with the center
|
||||
of the screen, such that gaze to the center of the screen results in a visual
|
||||
angle of ``0`` radians.
|
||||
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
_validate_type(inst, (BaseRaw, BaseEpochs, Evoked), "inst")
|
||||
_validate_type(calibration, Calibration, "calibration")
|
||||
_check_option("to", to, ("radians", "pixels"))
|
||||
_check_calibration(calibration)
|
||||
|
||||
# get screen parameters
|
||||
screen_size = calibration["screen_size"]
|
||||
screen_resolution = calibration["screen_resolution"]
|
||||
dist = calibration["screen_distance"]
|
||||
|
||||
# loop through channels and convert units
|
||||
converted_chs = []
|
||||
for ch_dict in inst.info["chs"]:
|
||||
if ch_dict["coil_type"] != FIFF.FIFFV_COIL_EYETRACK_POS:
|
||||
continue
|
||||
unit = ch_dict["unit"]
|
||||
name = ch_dict["ch_name"]
|
||||
|
||||
if ch_dict["loc"][4] == -1: # x-coordinate
|
||||
size = screen_size[0]
|
||||
res = screen_resolution[0]
|
||||
elif ch_dict["loc"][4] == 1: # y-coordinate
|
||||
size = screen_size[1]
|
||||
res = screen_resolution[1]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"loc array not set properly for channel '{name}'. Index 4 should"
|
||||
f" be -1 or 1, but got {ch_dict['loc'][4]}"
|
||||
)
|
||||
# check unit, convert, and set new unit
|
||||
if to == "radians":
|
||||
if unit != FIFF.FIFF_UNIT_PX:
|
||||
raise ValueError(
|
||||
f"Data must be in pixels in order to convert to radians."
|
||||
f" Got {unit} for {name}"
|
||||
)
|
||||
inst.apply_function(_pix_to_rad, picks=name, size=size, res=res, dist=dist)
|
||||
ch_dict["unit"] = FIFF.FIFF_UNIT_RAD
|
||||
elif to == "pixels":
|
||||
if unit != FIFF.FIFF_UNIT_RAD:
|
||||
raise ValueError(
|
||||
f"Data must be in radians in order to convert to pixels."
|
||||
f" Got {unit} for {name}"
|
||||
)
|
||||
inst.apply_function(_rad_to_pix, picks=name, size=size, res=res, dist=dist)
|
||||
ch_dict["unit"] = FIFF.FIFF_UNIT_PX
|
||||
converted_chs.append(name)
|
||||
if converted_chs:
|
||||
logger.info(f"Converted {converted_chs} to {to}.")
|
||||
if to == "radians":
|
||||
# check if any values are greater than 0.52 radians
|
||||
# (30 degrees) and warn user
|
||||
data = inst.get_data(picks=converted_chs)
|
||||
if np.any(np.abs(data) > 0.52):
|
||||
warn(
|
||||
"Some visual angle values subtend greater than .52 radians "
|
||||
"(30 degrees), meaning that the conversion between pixels "
|
||||
"and visual angle may be very non-linear. Take caution when "
|
||||
"interpreting these values. Max visual angle value in data:"
|
||||
f" {np.nanmax(data):0.2f} radians.",
|
||||
UserWarning,
|
||||
)
|
||||
else:
|
||||
warn("Could not find any eyegaze channels. Doing nothing.", UserWarning)
|
||||
return inst
|
||||
|
||||
|
||||
def _pix_to_rad(data, size, res, dist):
|
||||
"""Convert pixel coordinates to radians of visual angle.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array-like, shape (n_samples,)
|
||||
A vector of pixel coordinates.
|
||||
size : float
|
||||
The width or height of the screen, in meters.
|
||||
res : int
|
||||
The screen resolution in pixels, along the x or y axis.
|
||||
dist : float
|
||||
The viewing distance from the screen, in meters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rad : ndarray, shape (n_samples)
|
||||
the data in radians.
|
||||
"""
|
||||
# Center the data so that 0 radians will be the center of the screen
|
||||
data -= res / 2
|
||||
# How many meters is the pixel width or height
|
||||
px_size = size / res
|
||||
# Convert to radians
|
||||
return np.arctan((data * px_size) / dist)
|
||||
|
||||
|
||||
def _rad_to_pix(data, size, res, dist):
|
||||
"""Convert radians of visual angle to pixel coordinates.
|
||||
|
||||
See the parameters section of _pix_to_rad for more information.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pix : ndarray, shape (n_samples)
|
||||
the data in pixels.
|
||||
"""
|
||||
# How many meters is the pixel width or height
|
||||
px_size = size / res
|
||||
# 1. calculate length of opposite side of triangle (in meters)
|
||||
# 2. convert meters to pixel coordinates
|
||||
# 3. add half of screen resolution to uncenter the pixel data (0,0 is top left)
|
||||
return np.tan(data) * dist / px_size + res / 2
|
||||
45
mne/preprocessing/eyetracking/utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...utils import _validate_type
|
||||
from .calibration import Calibration
|
||||
|
||||
|
||||
def _check_calibration(
|
||||
calibration, want_keys=("screen_size", "screen_resolution", "screen_distance")
|
||||
):
|
||||
missing_keys = []
|
||||
for key in want_keys:
|
||||
if calibration.get(key, None) is None:
|
||||
missing_keys.append(key)
|
||||
|
||||
if missing_keys:
|
||||
raise KeyError(
|
||||
"Calibration object must have the following keys with valid values:"
|
||||
f" {', '.join(missing_keys)}"
|
||||
)
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def get_screen_visual_angle(calibration):
|
||||
"""Calculate the radians of visual angle that the participant screen subtends.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
calibration : Calibration
|
||||
An instance of Calibration. Must have valid values for ``"screen_size"`` and
|
||||
``"screen_distance"`` keys.
|
||||
|
||||
Returns
|
||||
-------
|
||||
visual angle in radians : ndarray, shape (2,)
|
||||
The visual angle of the monitor width and height, respectively.
|
||||
"""
|
||||
_validate_type(calibration, Calibration, "calibration")
|
||||
_check_calibration(calibration, want_keys=("screen_size", "screen_distance"))
|
||||
size = np.array(calibration["screen_size"])
|
||||
return 2 * np.arctan(size / (2 * calibration["screen_distance"]))
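# Worked example: a 0.531 m x 0.298 m display viewed from 0.70 m subtends about
# 2 * arctan(0.531 / 1.4) ~= 0.73 rad (~42 deg) horizontally and
# 2 * arctan(0.298 / 1.4) ~= 0.42 rad (~24 deg) vertically.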
|
||||
106
mne/preprocessing/hfc.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_to_idx, pick_info
|
||||
from .._fiff.proj import Projection
|
||||
from ..utils import verbose
|
||||
from .maxwell import _prep_mf_coils, _sss_basis
|
||||
|
||||
|
||||
@verbose
|
||||
def compute_proj_hfc(
|
||||
info, order=1, picks="meg", exclude="bads", *, accuracy="accurate", verbose=None
|
||||
):
|
||||
"""Generate projectors to perform homogeneous/harmonic correction to data.
|
||||
|
||||
Remove environmental fields from magnetometer data by assuming it is
|
||||
explained as a homogeneous :footcite:`TierneyEtAl2021` or harmonic field
|
||||
:footcite:`TierneyEtAl2022`. Useful for arrays of OPMs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
%(info)s
|
||||
order : int
|
||||
The order of the spherical harmonic basis set to use. Set to 1 to use
|
||||
only the homogeneous field component (default), 2 to add gradients, 3
|
||||
to add quadrature terms, etc.
|
||||
picks : str | array_like | slice | None
|
||||
Channels to include. Default of ``'meg'`` (same as None) will select
|
||||
all non-reference MEG channels. Use ``('meg', 'ref_meg')`` to include
|
||||
reference sensors as well.
|
||||
exclude : list | 'bads'
|
||||
List of channels to exclude from HFC, only used when picking
|
||||
based on types (e.g., exclude="bads" when picks="meg").
|
||||
Specify ``'bads'`` (the default) to exclude all channels marked as bad.
|
||||
accuracy : str
|
||||
Can be ``"point"``, ``"normal"`` or ``"accurate"`` (default), defines
|
||||
which level of coil definition accuracy is used to generate the model.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
%(projs)s
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.io.Raw.add_proj
|
||||
mne.io.Raw.apply_proj
|
||||
|
||||
Notes
|
||||
-----
|
||||
To apply the projectors to a dataset, use
|
||||
``inst.add_proj(projs).apply_proj()``.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
picks = _picks_to_idx(info, picks, none="meg", exclude=exclude, with_ref_meg=False)
|
||||
info = pick_info(info, picks)
|
||||
del picks
|
||||
exp = dict(origin=(0.0, 0.0, 0.0), int_order=0, ext_order=order)
|
||||
coils = _prep_mf_coils(info, ignore_ref=False, accuracy=accuracy)
|
||||
n_chs = len(coils[5])
|
||||
if n_chs != info["nchan"]:
|
||||
raise ValueError(
|
||||
f'Only {n_chs}/{info["nchan"]} picks could be interpreted '
|
||||
"as MEG channels."
|
||||
)
|
||||
S = _sss_basis(exp, coils)
|
||||
del coils
|
||||
bad_chans = [
|
||||
info["ch_names"][pick] for pick in np.where((~np.isfinite(S)).any(axis=1))[0]
|
||||
]
|
||||
if bad_chans:
|
||||
raise ValueError(
|
||||
"The following channel(s) generate non-finite projectors:\n"
|
||||
f" {bad_chans}\nPlease exclude from picks!"
|
||||
)
|
||||
S /= np.linalg.norm(S, axis=0)
|
||||
labels = _label_basis(order)
|
||||
assert len(labels) == S.shape[1]
|
||||
projs = []
|
||||
for label, vec in zip(labels, S.T):
|
||||
proj_data = dict(
|
||||
col_names=info["ch_names"],
|
||||
row_names=None,
|
||||
data=vec[np.newaxis, :],
|
||||
ncol=info["nchan"],
|
||||
nrow=1,
|
||||
)
|
||||
projs.append(Projection(active=False, data=proj_data, desc=label))
|
||||
return projs
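# Usage sketch for an OPM recording (assumes a preloaded Raw instance ``raw``):
#
#     projs = compute_proj_hfc(raw.info, order=1)
#     raw.add_proj(projs).apply_proj()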
|
||||
|
||||
|
||||
def _label_basis(order):
|
||||
"""Give basis vectors names for Projection() class."""
|
||||
return [
|
||||
f"HFC: l={L} m={m}"
|
||||
for L in np.arange(1, order + 1)
|
||||
for m in np.arange(-1 * L, L + 1)
|
||||
]
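# For a given order this yields sum_{L=1..order} (2L + 1) = order * (order + 2)
# labels, matching the number of columns of the external SSS basis built above.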
|
||||
3552
mne/preprocessing/ica.py
Normal file
File diff suppressed because it is too large
8
mne/preprocessing/ieeg/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Intracranial EEG specific preprocessing functions."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from ._projection import project_sensors_onto_brain
|
||||
from ._volume import make_montage_volume, warp_montage
|
||||
209
mne/preprocessing/ieeg/_projection.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from itertools import combinations
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.distance import pdist, squareform
|
||||
|
||||
from ..._fiff.pick import _picks_to_idx
|
||||
from ...channels import make_dig_montage
|
||||
from ...surface import (
|
||||
_compute_nearest,
|
||||
_read_mri_surface,
|
||||
_read_patch,
|
||||
fast_cross_3d,
|
||||
read_surface,
|
||||
)
|
||||
from ...transforms import _cart_to_sph, _ensure_trans, apply_trans, invert_transform
|
||||
from ...utils import _ensure_int, _validate_type, get_subjects_dir, verbose
|
||||
|
||||
|
||||
@verbose
|
||||
def project_sensors_onto_brain(
|
||||
info,
|
||||
trans,
|
||||
subject,
|
||||
subjects_dir=None,
|
||||
picks=None,
|
||||
n_neighbors=10,
|
||||
copy=True,
|
||||
verbose=None,
|
||||
):
|
||||
"""Project sensors onto the brain surface.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
%(info_not_none)s
|
||||
%(trans_not_none)s
|
||||
%(subject)s
|
||||
%(subjects_dir)s
|
||||
%(picks_base)s only ``ecog`` channels.
|
||||
n_neighbors : int
|
||||
The number of neighbors to use to compute the normal vectors
|
||||
for the projection. Must be 2 or greater. Using more neighbors averages
the normal vector over a larger area, which preserves the overall grid
structure; using fewer neighbors averages less and better preserves
local contours in the grid.
|
||||
copy : bool
|
||||
If ``True``, return a new instance of ``info``, if ``False``
|
||||
``info`` is modified in place.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
%(info_not_none)s
|
||||
|
||||
Notes
|
||||
-----
|
||||
This is useful in ECoG analysis for compensating for "brain shift"
|
||||
or shrinking of the brain away from the skull due to changes
|
||||
in pressure during the craniotomy.
|
||||
|
||||
To use the brain surface, a BEM model must be created e.g. using
|
||||
:ref:`mne watershed_bem` using the T1 or :ref:`mne flash_bem`
|
||||
using a FLASH scan.
|
||||
"""
|
||||
n_neighbors = _ensure_int(n_neighbors, "n_neighbors")
|
||||
_validate_type(copy, bool, "copy")
|
||||
if copy:
|
||||
info = info.copy()
|
||||
if n_neighbors < 2:
|
||||
raise ValueError(f"n_neighbors must be 2 or greater, got {n_neighbors}")
|
||||
subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
|
||||
try:
|
||||
surf = _read_mri_surface(subjects_dir / subject / "bem" / "brain.surf")
|
||||
except FileNotFoundError as err:
|
||||
raise RuntimeError(
|
||||
f"{err}\n\nThe brain surface requires generating "
|
||||
"a BEM using `mne flash_bem` (if you have "
|
||||
"the FLASH scan) or `mne watershed_bem` (to "
|
||||
"use the T1)"
|
||||
) from None
|
||||
# get channel locations
|
||||
picks_idx = _picks_to_idx(info, "ecog" if picks is None else picks)
|
||||
locs = np.array([info["chs"][idx]["loc"][:3] for idx in picks_idx])
|
||||
trans = _ensure_trans(trans, "head", "mri")
|
||||
locs = apply_trans(trans, locs)
|
||||
# compute distances for nearest neighbors
|
||||
dists = squareform(pdist(locs))
|
||||
# find angles for brain surface and points
|
||||
angles = _cart_to_sph(locs)
|
||||
surf_angles = _cart_to_sph(surf["rr"])
|
||||
# initialize projected locs
|
||||
proj_locs = np.zeros(locs.shape) * np.nan
|
||||
for i, loc in enumerate(locs):
|
||||
neighbor_pts = locs[np.argsort(dists[i])[: n_neighbors + 1]]
|
||||
pt1, pt2, pt3 = map(np.array, zip(*combinations(neighbor_pts, 3)))
|
||||
normals = fast_cross_3d(pt1 - pt2, pt1 - pt3)
|
||||
normals[normals @ loc < 0] *= -1
|
||||
normal = np.mean(normals, axis=0)
|
||||
normal /= np.linalg.norm(normal)
|
||||
# find the correct orientation brain surface point nearest the line
|
||||
# https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
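# Added note: for a unit direction vector ``normal`` through the point
# ``loc``, the perpendicular distance from a surface point ``x`` to that
# line is ||(x - loc) x normal||; since ||a x (a + n)|| == ||a x n||, the
# ``fast_cross_3d`` expression below computes exactly this distance for
# every candidate vertex.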
|
||||
use_rr = surf["rr"][
|
||||
abs(surf_angles[:, 1:] - angles[i, 1:]).sum(axis=1) < np.pi / 4
|
||||
]
|
||||
surf_dists = np.linalg.norm(
|
||||
fast_cross_3d(use_rr - loc, use_rr - loc + normal), axis=1
|
||||
)
|
||||
proj_locs[i] = use_rr[np.argmin(surf_dists)]
|
||||
# back to the "head" coordinate frame for storing in ``raw``
|
||||
proj_locs = apply_trans(invert_transform(trans), proj_locs)
|
||||
montage = info.get_montage()
|
||||
montage_kwargs = (
|
||||
montage.get_positions() if montage else dict(ch_pos=dict(), coord_frame="head")
|
||||
)
|
||||
for idx, loc in zip(picks_idx, proj_locs):
|
||||
# surface RAS-> head and mm->m
|
||||
montage_kwargs["ch_pos"][info.ch_names[idx]] = loc
|
||||
info.set_montage(make_dig_montage(**montage_kwargs))
|
||||
return info
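# --- usage sketch (added; not part of the original file) --------------------
# Projecting ECoG grid contacts onto the brain surface. The file names,
# subject name, and subjects_dir are assumptions for the example.
import mne
from mne.preprocessing.ieeg import project_sensors_onto_brain

raw = mne.io.read_raw_fif("ecog_raw.fif")  # hypothetical ECoG recording
trans = mne.read_trans("subject-trans.fif")  # hypothetical head<->MRI transform
info_proj = project_sensors_onto_brain(
    raw.info, trans, subject="subject", subjects_dir="/path/to/subjects_dir"
)  # returns a copy of ``info`` with the projected channel positions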
|
||||
|
||||
|
||||
@verbose
|
||||
def _project_sensors_onto_inflated(
|
||||
info,
|
||||
trans,
|
||||
subject,
|
||||
subjects_dir=None,
|
||||
picks=None,
|
||||
max_dist=0.004,
|
||||
flat=False,
|
||||
verbose=None,
|
||||
):
|
||||
"""Project sensors onto the brain surface.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
%(info_not_none)s
|
||||
%(trans_not_none)s
|
||||
%(subject)s
|
||||
%(subjects_dir)s
|
||||
%(picks_base)s only ``seeg`` channels.
|
||||
%(max_dist_ieeg)s
|
||||
flat : bool
|
||||
Whether to project the sensors onto the flat map of the
|
||||
inflated brain instead of the normal inflated brain.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
%(info_not_none)s
|
||||
|
||||
Notes
|
||||
-----
|
||||
This is useful in sEEG analysis for visualization.
|
||||
"""
|
||||
subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
|
||||
surf_data = dict(lh=dict(), rh=dict())
|
||||
x_dir = np.array([1.0, 0.0, 0.0])
|
||||
surfs = ("pial", "inflated")
|
||||
if flat:
|
||||
surfs += ("cortex.patch.flat",)
|
||||
for hemi in ("lh", "rh"):
|
||||
for surf in surfs:
|
||||
for img in ("", ".T1", ".T2", ""):
|
||||
surf_fname = subjects_dir / subject / "surf" / f"{hemi}.{surf}{img}"
|
||||
if surf_fname.is_file():
|
||||
break
|
||||
if surf.split(".")[-1] == "flat":
|
||||
surf = "flat"
|
||||
coords, faces, orig_faces = _read_patch(surf_fname)
|
||||
# rotate 90 degrees to get to a more standard orientation
|
||||
# where X determines the distance between the hemis
|
||||
coords = coords[:, [1, 0, 2]]
|
||||
coords[:, 1] *= -1
|
||||
else:
|
||||
coords, faces = read_surface(surf_fname)
|
||||
if surf in ("inflated", "flat"):
|
||||
x_ = coords @ x_dir
|
||||
coords -= np.max(x_) * x_dir if hemi == "lh" else np.min(x_) * x_dir
|
||||
surf_data[hemi][surf] = (coords / 1000, faces) # mm -> m
|
||||
# get channel locations
|
||||
picks_idx = _picks_to_idx(info, "seeg" if picks is None else picks)
|
||||
locs = np.array([info["chs"][idx]["loc"][:3] for idx in picks_idx])
|
||||
trans = _ensure_trans(trans, "head", "mri")
|
||||
locs = apply_trans(trans, locs)
|
||||
# initialize projected locs
|
||||
proj_locs = np.zeros(locs.shape) * np.nan
|
||||
surf = "flat" if flat else "inflated"
|
||||
for hemi in ("lh", "rh"):
|
||||
hemi_picks = np.where(locs[:, 0] <= 0 if hemi == "lh" else locs[:, 0] > 0)[0]
|
||||
# compute distances to pial vertices
|
||||
nearest, dists = _compute_nearest(
|
||||
surf_data[hemi]["pial"][0], locs[hemi_picks], return_dists=True
|
||||
)
|
||||
mask = dists / 1000 < max_dist
|
||||
proj_locs[hemi_picks[mask]] = surf_data[hemi][surf][0][nearest[mask]]
|
||||
# back to the "head" coordinate frame for storing in ``raw``
|
||||
proj_locs = apply_trans(invert_transform(trans), proj_locs)
|
||||
montage = info.get_montage()
|
||||
montage_kwargs = (
|
||||
montage.get_positions() if montage else dict(ch_pos=dict(), coord_frame="head")
|
||||
)
|
||||
for idx, loc in zip(picks_idx, proj_locs):
|
||||
montage_kwargs["ch_pos"][info.ch_names[idx]] = loc
|
||||
info.set_montage(make_dig_montage(**montage_kwargs))
|
||||
return info
|
||||
242
mne/preprocessing/ieeg/_volume.py
Normal file
242
mne/preprocessing/ieeg/_volume.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...channels import DigMontage, make_dig_montage
|
||||
from ...surface import _voxel_neighbors
|
||||
from ...transforms import Transform, _frame_to_str, apply_trans
|
||||
from ...utils import _check_option, _pl, _require_version, _validate_type, verbose, warn
|
||||
|
||||
|
||||
@verbose
|
||||
def warp_montage(montage, moving, static, reg_affine, sdr_morph, verbose=None):
|
||||
"""Warp a montage to a template with image volumes using SDR.
|
||||
|
||||
.. note:: This is likely only applicable for channels inside the brain
|
||||
(intracranial electrodes).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
montage : instance of mne.channels.DigMontage
|
||||
The montage object containing the channels.
|
||||
%(moving)s
|
||||
%(static)s
|
||||
%(reg_affine)s
|
||||
%(sdr_morph)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
montage_warped : mne.channels.DigMontage
|
||||
The modified montage object containing the channels.
|
||||
"""
|
||||
_require_version("nibabel", "warp montage", "2.1.0")
|
||||
_require_version("dipy", "warping points using SDR", "1.6.0")
|
||||
|
||||
from dipy.align.imwarp import DiffeomorphicMap
|
||||
from nibabel import MGHImage
|
||||
from nibabel.spatialimages import SpatialImage
|
||||
|
||||
_validate_type(moving, SpatialImage, "moving")
|
||||
_validate_type(static, SpatialImage, "static")
|
||||
_validate_type(reg_affine, np.ndarray, "reg_affine")
|
||||
_check_option("reg_affine.shape", reg_affine.shape, ((4, 4),))
|
||||
_validate_type(sdr_morph, (DiffeomorphicMap, None), "sdr_morph")
|
||||
_validate_type(montage, DigMontage, "montage")
|
||||
|
||||
moving_mgh = MGHImage(np.array(moving.dataobj).astype(np.float32), moving.affine)
|
||||
static_mgh = MGHImage(np.array(static.dataobj).astype(np.float32), static.affine)
|
||||
del moving, static
|
||||
|
||||
# get montage channel coordinates
|
||||
ch_dict = montage.get_positions()
|
||||
if ch_dict["coord_frame"] != "mri":
|
||||
bad_coord_frames = np.unique([d["coord_frame"] for d in montage.dig])
|
||||
bad_coord_frames = ", ".join(
|
||||
[
|
||||
_frame_to_str[cf] if cf in _frame_to_str else str(cf)
|
||||
for cf in bad_coord_frames
|
||||
]
|
||||
)
|
||||
raise RuntimeError(
|
||||
f'Coordinate frame not supported, expected "mri", got {bad_coord_frames}'
|
||||
)
|
||||
ch_names = list(ch_dict["ch_pos"].keys())
|
||||
ch_coords = np.array([ch_dict["ch_pos"][name] for name in ch_names])
|
||||
|
||||
ch_coords = apply_trans( # convert to moving voxel space
|
||||
np.linalg.inv(moving_mgh.header.get_vox2ras_tkr()), ch_coords * 1000
|
||||
)
|
||||
# next, to moving scanner RAS
|
||||
ch_coords = apply_trans(moving_mgh.header.get_vox2ras(), ch_coords)
|
||||
|
||||
# now, apply reg_affine
|
||||
ch_coords = apply_trans(
|
||||
Transform( # to static ras
|
||||
fro="ras", to="ras", trans=np.linalg.inv(reg_affine)
|
||||
),
|
||||
ch_coords,
|
||||
)
|
||||
|
||||
# now, apply SDR morph
|
||||
if sdr_morph is not None:
|
||||
ch_coords = sdr_morph.transform_points(
|
||||
ch_coords,
|
||||
coord2world=sdr_morph.domain_grid2world,
|
||||
world2coord=sdr_morph.domain_world2grid,
|
||||
)
|
||||
|
||||
# back to voxels but now for the static image
|
||||
ch_coords = apply_trans(np.linalg.inv(static_mgh.header.get_vox2ras()), ch_coords)
|
||||
|
||||
# finally, back to surface RAS
|
||||
ch_coords = apply_trans(static_mgh.header.get_vox2ras_tkr(), ch_coords) / 1000
|
||||
|
||||
# make warped montage
|
||||
montage_warped = make_dig_montage(dict(zip(ch_names, ch_coords)), coord_frame="mri")
|
||||
return montage_warped
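# --- usage sketch (added; not part of the original file) --------------------
# Warping a montage from subject space to a template volume. The registration
# is computed with ``mne.transforms.compute_volume_registration``; the file
# names and the single toy contact are assumptions for the example.
import nibabel as nib
import mne
from mne.preprocessing.ieeg import warp_montage

subj_mri = nib.load("subject_T1.mgz")       # hypothetical subject MRI
template_mri = nib.load("template_T1.mgz")  # hypothetical template MRI
reg_affine, sdr_morph = mne.transforms.compute_volume_registration(
    subj_mri, template_mri, pipeline="all"
)
montage = mne.channels.make_dig_montage(
    {"LPM1": [0.02, 0.01, 0.03]}, coord_frame="mri"
)  # toy single-contact montage in surface RAS (meters)
montage_warped = warp_montage(montage, subj_mri, template_mri, reg_affine, sdr_morph)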
|
||||
|
||||
|
||||
def _warn_missing_chs(info, dig_image, after_warp=False):
|
||||
"""Warn that channels are missing."""
|
||||
# ensure that each electrode contact was marked in at least one voxel
|
||||
missing = set(np.arange(1, len(info.ch_names) + 1)).difference(
|
||||
set(np.unique(np.array(dig_image.dataobj)))
|
||||
)
|
||||
missing_ch = [info.ch_names[idx - 1] for idx in missing]
|
||||
if missing_ch:
|
||||
warn(
|
||||
f"Channel{_pl(missing_ch)} "
|
||||
f'{", ".join(repr(ch) for ch in missing_ch)} not assigned '
|
||||
"voxels " + (f" after applying {after_warp}" if after_warp else "")
|
||||
)
|
||||
|
||||
|
||||
@verbose
|
||||
def make_montage_volume(
|
||||
montage,
|
||||
base_image,
|
||||
thresh=0.5,
|
||||
max_peak_dist=1,
|
||||
voxels_max=100,
|
||||
use_min=False,
|
||||
verbose=None,
|
||||
):
|
||||
"""Make a volume from intracranial electrode contact locations.
|
||||
|
||||
Find areas of the input volume with intensity greater than
|
||||
a threshold surrounding local extrema near the channel location.
|
||||
Monotonicity from the peak is enforced to prevent channels
|
||||
bleeding into each other.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
montage : instance of mne.channels.DigMontage
|
||||
The montage object containing the channels.
|
||||
base_image : path-like | nibabel.spatialimages.SpatialImage
|
||||
Path to a volumetric scan (e.g. CT) of the subject. Can be in any
|
||||
format readable by nibabel. Can also be a nibabel image object.
|
||||
Local extrema (max or min) should be nearby montage channel locations.
|
||||
thresh : float
|
||||
The threshold relative to the peak to determine the size
|
||||
of the sensors on the volume.
|
||||
max_peak_dist : int
|
||||
The number of voxels away from the channel location to
|
||||
look in the ``image``. This will depend on the accuracy of
|
||||
the channel locations, the default (one voxel in all directions)
|
||||
will work only with localizations that are that accurate.
|
||||
voxels_max : int
|
||||
The maximum number of voxels for each channel.
|
||||
use_min : bool
|
||||
Whether to use hypointensities in the volume as channel locations.
|
||||
Default False uses hyperintensities.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
elec_image : nibabel.spatialimages.SpatialImage
|
||||
An image in Freesurfer surface RAS space with voxel values
|
||||
corresponding to the index of the channel. The background
|
||||
is 0s and this index starts at 1.
|
||||
"""
|
||||
_require_version("nibabel", "montage volume", "2.1.0")
|
||||
import nibabel as nib
|
||||
|
||||
_validate_type(montage, DigMontage, "montage")
|
||||
_validate_type(base_image, nib.spatialimages.SpatialImage, "base_image")
|
||||
_validate_type(thresh, float, "thresh")
|
||||
if thresh < 0 or thresh >= 1:
|
||||
raise ValueError(f"`thresh` must be between 0 and 1, got {thresh}")
|
||||
_validate_type(max_peak_dist, int, "max_peak_dist")
|
||||
_validate_type(voxels_max, int, "voxels_max")
|
||||
_validate_type(use_min, bool, "use_min")
|
||||
|
||||
# load image and make sure it's in surface RAS
|
||||
if not isinstance(base_image, nib.spatialimages.SpatialImage):
|
||||
base_image = nib.load(base_image)
|
||||
|
||||
base_image_mgh = nib.MGHImage(
|
||||
np.array(base_image.dataobj).astype(np.float32), base_image.affine
|
||||
)
|
||||
del base_image
|
||||
|
||||
# get montage channel coordinates
|
||||
ch_dict = montage.get_positions()
|
||||
if ch_dict["coord_frame"] != "mri":
|
||||
bad_coord_frames = np.unique([d["coord_frame"] for d in montage.dig])
|
||||
bad_coord_frames = ", ".join(
|
||||
[
|
||||
_frame_to_str[cf] if cf in _frame_to_str else str(cf)
|
||||
for cf in bad_coord_frames
|
||||
]
|
||||
)
|
||||
raise RuntimeError(
|
||||
f'Coordinate frame not supported, expected "mri", got {bad_coord_frames}'
|
||||
)
|
||||
|
||||
ch_names = list(ch_dict["ch_pos"].keys())
|
||||
ch_coords = np.array([ch_dict["ch_pos"][name] for name in ch_names])
|
||||
|
||||
# convert to voxel space
|
||||
ch_coords = apply_trans(
|
||||
np.linalg.inv(base_image_mgh.header.get_vox2ras_tkr()), ch_coords * 1000
|
||||
)
|
||||
|
||||
# take channel coordinates and use the image to transform them
|
||||
# into a volume where all the voxels over a threshold nearby
|
||||
# are labeled with an index
|
||||
image_data = np.array(base_image_mgh.dataobj)
|
||||
if use_min:
|
||||
image_data *= -1
|
||||
elec_image = np.zeros(base_image_mgh.shape, dtype=int)
|
||||
for i, ch_coord in enumerate(ch_coords):
|
||||
if np.isnan(ch_coord).any():
|
||||
continue
|
||||
# this looks up to a voxel away, it may be marked imperfectly
|
||||
volume = _voxel_neighbors(
|
||||
ch_coord,
|
||||
image_data,
|
||||
thresh=thresh,
|
||||
max_peak_dist=max_peak_dist,
|
||||
voxels_max=voxels_max,
|
||||
)
|
||||
for voxel in volume:
|
||||
if elec_image[voxel] != 0:
|
||||
# some voxels ambiguous because the contacts are bridged on
|
||||
# the image so assign the voxel to the nearest contact location
|
||||
dist_old = np.sqrt(
|
||||
(ch_coords[elec_image[voxel] - 1] - voxel) ** 2
|
||||
).sum()
|
||||
dist_new = np.sqrt((ch_coord - voxel) ** 2).sum()
|
||||
if dist_new < dist_old:
|
||||
elec_image[voxel] = i + 1
|
||||
else:
|
||||
elec_image[voxel] = i + 1
|
||||
|
||||
# assemble the volume
|
||||
elec_image = nib.spatialimages.SpatialImage(elec_image, base_image_mgh.affine)
|
||||
_warn_missing_chs(montage, elec_image, after_warp=False)
|
||||
|
||||
return elec_image
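# --- usage sketch (added; not part of the original file) --------------------
# Building a channel-index volume from a CT aligned to the FreeSurfer T1.
# The file name and the single toy contact are assumptions for the example.
import nibabel as nib
import mne
from mne.preprocessing.ieeg import make_montage_volume

ct = nib.load("subject_CT_aligned.mgz")  # hypothetical aligned CT scan
montage = mne.channels.make_dig_montage(
    {"LPM1": [0.02, 0.01, 0.03]}, coord_frame="mri"
)  # toy single-contact montage in surface RAS (meters)
elec_image = make_montage_volume(montage, ct, thresh=0.25)
# elec_image voxels hold value k for channel k (0 = background), same grid as ``ct``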
|
||||
336
mne/preprocessing/infomax_.py
Normal file
336
mne/preprocessing/infomax_.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from scipy.special import expit
|
||||
from scipy.stats import kurtosis
|
||||
|
||||
from ..utils import check_random_state, logger, random_permutation, verbose
|
||||
|
||||
|
||||
@verbose
|
||||
def infomax(
|
||||
data,
|
||||
weights=None,
|
||||
l_rate=None,
|
||||
block=None,
|
||||
w_change=1e-12,
|
||||
anneal_deg=60.0,
|
||||
anneal_step=0.9,
|
||||
extended=True,
|
||||
n_subgauss=1,
|
||||
kurt_size=6000,
|
||||
ext_blocks=1,
|
||||
max_iter=200,
|
||||
random_state=None,
|
||||
blowup=1e4,
|
||||
blowup_fac=0.5,
|
||||
n_small_angle=20,
|
||||
use_bias=True,
|
||||
verbose=None,
|
||||
return_n_iter=False,
|
||||
):
|
||||
"""Run (extended) Infomax ICA decomposition on raw data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : np.ndarray, shape (n_samples, n_features)
|
||||
The whitened data to unmix.
|
||||
weights : np.ndarray, shape (n_features, n_features)
|
||||
The initialized unmixing matrix.
|
||||
Defaults to None, which means the identity matrix is used.
|
||||
l_rate : float
|
||||
This quantity indicates the relative size of the change in weights.
|
||||
Defaults to ``0.01 / log(n_features ** 2)``.
|
||||
|
||||
.. note:: Smaller learning rates will slow down the ICA procedure.
|
||||
|
||||
block : int
|
||||
The block size of randomly chosen data segments.
|
||||
Defaults to floor(sqrt(n_times / 3.)).
|
||||
w_change : float
|
||||
The change at which to stop iteration. Defaults to 1e-12.
|
||||
anneal_deg : float
|
||||
The angle (in degrees) at which the learning rate will be reduced.
|
||||
Defaults to 60.0.
|
||||
anneal_step : float
|
||||
The factor by which the learning rate will be reduced once
|
||||
``anneal_deg`` is exceeded: ``l_rate *= anneal_step.``
|
||||
Defaults to 0.9.
|
||||
extended : bool
|
||||
Whether to use the extended Infomax algorithm or not.
|
||||
Defaults to True.
|
||||
n_subgauss : int
|
||||
The number of subgaussian components. Only considered for extended
|
||||
Infomax. Defaults to 1.
|
||||
kurt_size : int
|
||||
The window size for kurtosis estimation. Only considered for extended
|
||||
Infomax. Defaults to 6000.
|
||||
ext_blocks : int
|
||||
Only considered for extended Infomax. If positive, denotes the number
|
||||
of blocks after which to recompute the kurtosis, which is used to
|
||||
estimate the signs of the sources. In this case, the number of
|
||||
sub-gaussian sources is automatically determined.
|
||||
If negative, the number of sub-gaussian sources to be used is fixed
|
||||
and equal to n_subgauss. In this case, the kurtosis is not estimated.
|
||||
Defaults to 1.
|
||||
max_iter : int
|
||||
The maximum number of iterations. Defaults to 200.
|
||||
%(random_state)s
|
||||
blowup : float
|
||||
The maximum difference allowed between two successive estimations of
|
||||
the unmixing matrix. Defaults to 10000.
|
||||
blowup_fac : float
|
||||
The factor by which the learning rate will be reduced if the difference
|
||||
between two successive estimations of the unmixing matrix exceeds
|
||||
``blowup``: ``l_rate *= blowup_fac``. Defaults to 0.5.
|
||||
n_small_angle : int | None
|
||||
The maximum number of allowed steps in which the angle between two
|
||||
successive estimations of the unmixing matrix is less than
|
||||
``anneal_deg``. If None, this parameter is not taken into account to
|
||||
stop the iterations. Defaults to 20.
|
||||
use_bias : bool
|
||||
This quantity indicates if the bias should be computed.
|
||||
Defaults to True.
|
||||
%(verbose)s
|
||||
return_n_iter : bool
|
||||
Whether to return the number of iterations performed. Defaults to
|
||||
False.
|
||||
|
||||
Returns
|
||||
-------
|
||||
unmixing_matrix : np.ndarray, shape (n_features, n_features)
|
||||
The linear unmixing operator.
|
||||
n_iter : int
|
||||
The number of iterations. Only returned if ``return_n_iter=True``.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] A. J. Bell, T. J. Sejnowski. An information-maximization approach to
|
||||
blind separation and blind deconvolution. Neural Computation, 7(6),
|
||||
1129-1159, 1995.
|
||||
.. [2] T. W. Lee, M. Girolami, T. J. Sejnowski. Independent component
|
||||
analysis using an extended infomax algorithm for mixed subgaussian
|
||||
and supergaussian sources. Neural Computation, 11(2), 417-441, 1999.
|
||||
"""
|
||||
rng = check_random_state(random_state)
|
||||
|
||||
# define some default parameters
|
||||
max_weight = 1e8
|
||||
restart_fac = 0.9
|
||||
min_l_rate = 1e-10
|
||||
degconst = 180.0 / np.pi
|
||||
|
||||
# for extended Infomax
|
||||
extmomentum = 0.5
|
||||
signsbias = 0.02
|
||||
signcount_threshold = 25
|
||||
signcount_step = 2
|
||||
|
||||
# check data shape
|
||||
n_samples, n_features = data.shape
|
||||
n_features_square = n_features**2
|
||||
|
||||
# check input parameters
|
||||
# heuristic default - may need adjustment for large or tiny data sets
|
||||
if l_rate is None:
|
||||
l_rate = 0.01 / math.log(n_features**2.0)
|
||||
|
||||
if block is None:
|
||||
block = int(math.floor(math.sqrt(n_samples / 3.0)))
|
||||
|
||||
logger.info(f"Computing{' Extended ' if extended else ' '}Infomax ICA")
|
||||
|
||||
# collect parameters
|
||||
nblock = n_samples // block
|
||||
lastt = (nblock - 1) * block + 1
|
||||
|
||||
# initialize training
|
||||
if weights is None:
|
||||
weights = np.identity(n_features, dtype=np.float64)
|
||||
else:
|
||||
weights = weights.T
|
||||
|
||||
BI = block * np.identity(n_features, dtype=np.float64)
|
||||
bias = np.zeros((n_features, 1), dtype=np.float64)
|
||||
onesrow = np.ones((1, block), dtype=np.float64)
|
||||
startweights = weights.copy()
|
||||
oldweights = startweights.copy()
|
||||
step = 0
|
||||
count_small_angle = 0
|
||||
wts_blowup = False
|
||||
blockno = 0
|
||||
signcount = 0
|
||||
initial_ext_blocks = ext_blocks # save the initial value in case of reset
|
||||
|
||||
# for extended Infomax
|
||||
if extended:
|
||||
signs = np.ones(n_features)
|
||||
|
||||
for k in range(n_subgauss):
|
||||
signs[k] = -1
|
||||
|
||||
kurt_size = min(kurt_size, n_samples)
|
||||
old_kurt = np.zeros(n_features, dtype=np.float64)
|
||||
oldsigns = np.zeros(n_features)
|
||||
|
||||
# trainings loop
|
||||
olddelta, oldchange = 1.0, 0.0
|
||||
while step < max_iter:
|
||||
# shuffle data at each step
|
||||
permute = random_permutation(n_samples, rng)
|
||||
|
||||
# ICA training block
|
||||
# loop across block samples
|
||||
for t in range(0, lastt, block):
|
||||
u = np.dot(data[permute[t : t + block], :], weights)
|
||||
u += np.dot(bias, onesrow).T
|
||||
|
||||
if extended:
|
||||
# extended ICA update
|
||||
y = np.tanh(u)
|
||||
weights += l_rate * np.dot(
|
||||
weights, BI - signs[None, :] * np.dot(u.T, y) - np.dot(u.T, u)
|
||||
)
|
||||
if use_bias:
|
||||
bias += l_rate * np.reshape(
|
||||
np.sum(y, axis=0, dtype=np.float64) * -2.0, (n_features, 1)
|
||||
)
|
||||
|
||||
else:
|
||||
# logistic ICA weights update
|
||||
y = expit(u)
|
||||
weights += l_rate * np.dot(weights, BI + np.dot(u.T, (1.0 - 2.0 * y)))
|
||||
|
||||
if use_bias:
|
||||
bias += l_rate * np.reshape(
|
||||
np.sum((1.0 - 2.0 * y), axis=0, dtype=np.float64),
|
||||
(n_features, 1),
|
||||
)
|
||||
|
||||
# check change limit
|
||||
max_weight_val = np.max(np.abs(weights))
|
||||
if max_weight_val > max_weight:
|
||||
wts_blowup = True
|
||||
|
||||
blockno += 1
|
||||
if wts_blowup:
|
||||
break
|
||||
|
||||
# ICA kurtosis estimation
|
||||
if extended:
|
||||
if ext_blocks > 0 and blockno % ext_blocks == 0:
|
||||
if kurt_size < n_samples:
|
||||
rp = np.floor(rng.uniform(0, 1, kurt_size) * (n_samples - 1))
|
||||
tpartact = np.dot(data[rp.astype(int), :], weights).T
|
||||
else:
|
||||
tpartact = np.dot(data, weights).T
|
||||
|
||||
# estimate kurtosis
|
||||
kurt = kurtosis(tpartact, axis=1, fisher=True)
|
||||
|
||||
if extmomentum != 0:
|
||||
kurt = extmomentum * old_kurt + (1.0 - extmomentum) * kurt
|
||||
old_kurt = kurt
|
||||
|
||||
# estimate weighted signs
|
||||
signs = np.sign(kurt + signsbias)
|
||||
|
||||
ndiff = (signs - oldsigns != 0).sum()
|
||||
if ndiff == 0:
|
||||
signcount += 1
|
||||
else:
|
||||
signcount = 0
|
||||
oldsigns = signs
|
||||
|
||||
if signcount >= signcount_threshold:
|
||||
ext_blocks = np.fix(ext_blocks * signcount_step)
|
||||
signcount = 0
|
||||
|
||||
# here we continue after the for loop over the ICA training blocks
|
||||
# if weights in bounds:
|
||||
if not wts_blowup:
|
||||
oldwtchange = weights - oldweights
|
||||
step += 1
|
||||
angledelta = 0.0
|
||||
delta = oldwtchange.reshape(1, n_features_square)
|
||||
change = np.sum(delta * delta, dtype=np.float64)
|
||||
if step > 2:
|
||||
angledelta = math.acos(
|
||||
np.sum(delta * olddelta) / math.sqrt(change * oldchange)
|
||||
)
|
||||
angledelta *= degconst
|
||||
|
||||
if verbose:
|
||||
logger.info(
|
||||
"step %d - lrate %5f, wchange %8.8f, angledelta %4.1f deg",
|
||||
step,
|
||||
l_rate,
|
||||
change,
|
||||
angledelta,
|
||||
)
|
||||
|
||||
# anneal learning rate
|
||||
oldweights = weights.copy()
|
||||
if angledelta > anneal_deg:
|
||||
l_rate *= anneal_step # anneal learning rate
|
||||
# accumulate angledelta until anneal_deg reaches l_rate
|
||||
olddelta = delta
|
||||
oldchange = change
|
||||
count_small_angle = 0 # reset count when angledelta is large
|
||||
else:
|
||||
if step == 1: # on first step only
|
||||
olddelta = delta # initialize
|
||||
oldchange = change
|
||||
|
||||
if n_small_angle is not None:
|
||||
count_small_angle += 1
|
||||
if count_small_angle > n_small_angle:
|
||||
max_iter = step
|
||||
|
||||
# apply stopping rule
|
||||
if step > 2 and change < w_change:
|
||||
step = max_iter
|
||||
elif change > blowup:
|
||||
l_rate *= blowup_fac
|
||||
|
||||
# restart if weights blow up (for lowering l_rate)
|
||||
else:
|
||||
step = 0 # start again
|
||||
wts_blowup = 0 # re-initialize variables
|
||||
blockno = 1
|
||||
l_rate *= restart_fac # with lower learning rate
|
||||
weights = startweights.copy()
|
||||
oldweights = startweights.copy()
|
||||
olddelta = np.zeros((1, n_features_square), dtype=np.float64)
|
||||
bias = np.zeros((n_features, 1), dtype=np.float64)
|
||||
|
||||
ext_blocks = initial_ext_blocks
|
||||
|
||||
# for extended Infomax
|
||||
if extended:
|
||||
signs = np.ones(n_features)
|
||||
for k in range(n_subgauss):
|
||||
signs[k] = -1
|
||||
oldsigns = np.zeros(n_features)
|
||||
|
||||
if l_rate > min_l_rate:
|
||||
if verbose:
|
||||
logger.info(
|
||||
f"... lowering learning rate to {l_rate:g}"
|
||||
"\n... re-starting..."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Error in Infomax ICA: unmixing_matrix matrix"
|
||||
"might not be invertible!"
|
||||
)
|
||||
|
||||
# prepare return values
|
||||
if return_n_iter:
|
||||
return weights.T, step
|
||||
else:
|
||||
return weights.T
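# --- toy example (added; not part of the original file) ---------------------
# Running infomax on a whitened two-source mixture. The mixing matrix and
# signals are made up for the sketch; the orientation of the returned
# unmixing matrix follows the implementation above (sources = data @ W.T).
import numpy as np
from mne.preprocessing import infomax

rng = np.random.RandomState(0)
n = 2000
S = np.c_[np.sin(np.linspace(0, 30, n)), rng.laplace(size=n)]  # two sources
X = S @ np.array([[1.0, 0.4], [0.3, 1.0]])                     # mixed observations
X -= X.mean(axis=0)
U, s, _ = np.linalg.svd(X, full_matrices=False)
X_white = U * np.sqrt(n)                                       # unit-covariance (whitened) data
W = infomax(X_white, extended=True, random_state=0)
sources_est = X_white @ W.T                                    # estimated sources, one per column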
|
||||
231
mne/preprocessing/interpolate.py
Normal file
231
mne/preprocessing/interpolate.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Tools for data interpolation."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from itertools import chain
|
||||
|
||||
import numpy as np
|
||||
from scipy.sparse.csgraph import connected_components
|
||||
|
||||
from .._fiff.meas_info import create_info
|
||||
from ..epochs import BaseEpochs, EpochsArray
|
||||
from ..evoked import Evoked, EvokedArray
|
||||
from ..io import BaseRaw, RawArray
|
||||
from ..transforms import _cart_to_sph, _sph_to_cart
|
||||
from ..utils import _ensure_int, _validate_type
|
||||
|
||||
|
||||
def equalize_bads(insts, interp_thresh=1.0, copy=True):
|
||||
"""Interpolate or mark bads consistently for a list of instances.
|
||||
|
||||
Once called on a list of instances, the instances can be concatenated
|
||||
as they will have the same list of bad channels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
insts : list
|
||||
The list of instances (Evoked, Epochs or Raw) to consider
|
||||
for interpolation. Each instance should have marked channels.
|
||||
interp_thresh : float
|
||||
A float between 0 and 1 (default) that specifies the fraction of time
|
||||
a channel should be good to be eventually interpolated for certain
|
||||
instances. For example if 0.5, a channel which is good at least half
|
||||
of the time will be interpolated in the instances where it is marked
|
||||
as bad. If 1 then channels will never be interpolated and if 0 all bad
|
||||
channels will be systematically interpolated.
|
||||
copy : bool
|
||||
If True then the returned instances will be copies.
|
||||
|
||||
Returns
|
||||
-------
|
||||
insts_bads : list
|
||||
The list of instances, with the same channel(s) marked as bad in all of
|
||||
them, possibly with some formerly bad channels interpolated.
|
||||
"""
|
||||
if not 0 <= interp_thresh <= 1:
|
||||
raise ValueError(f"interp_thresh must be between 0 and 1, got {interp_thresh}")
|
||||
|
||||
all_bads = list(set(chain.from_iterable([inst.info["bads"] for inst in insts])))
|
||||
if isinstance(insts[0], BaseEpochs):
|
||||
durations = [len(inst) * len(inst.times) for inst in insts]
|
||||
else:
|
||||
durations = [len(inst.times) for inst in insts]
|
||||
|
||||
good_times = []
|
||||
for ch_name in all_bads:
|
||||
good_times.append(
|
||||
sum(
|
||||
durations[k]
|
||||
for k, inst in enumerate(insts)
|
||||
if ch_name not in inst.info["bads"]
|
||||
)
|
||||
/ np.sum(durations)
|
||||
)
|
||||
|
||||
bads_keep = [ch for k, ch in enumerate(all_bads) if good_times[k] < interp_thresh]
|
||||
if copy:
|
||||
insts = [inst.copy() for inst in insts]
|
||||
|
||||
for inst in insts:
|
||||
if len(set(inst.info["bads"]) - set(bads_keep)):
|
||||
inst.interpolate_bads(exclude=bads_keep)
|
||||
inst.info["bads"] = bads_keep
|
||||
|
||||
return insts
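# --- usage sketch (added; not part of the original file) --------------------
# Harmonizing bad channels across runs before concatenation. The file names
# are assumptions for the example.
import mne
from mne.preprocessing import equalize_bads

raws = [
    mne.io.read_raw_fif("run_01_raw.fif", preload=True),  # hypothetical runs
    mne.io.read_raw_fif("run_02_raw.fif", preload=True),
]
raws = equalize_bads(raws, interp_thresh=0.5)  # interpolate channels good >= 50% of the time
raw = mne.concatenate_raws(raws)               # now safe: identical bads across runs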
|
||||
|
||||
|
||||
def interpolate_bridged_electrodes(inst, bridged_idx, bad_limit=4):
|
||||
"""Interpolate bridged electrode pairs.
|
||||
|
||||
Because bridged electrodes still contain brain signal (the signal is merely
smeared spatially between the two electrodes), we can make a virtual
channel midway between each bridged pair and use it to aid interpolation,
rather than completely discarding the data from the two channels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : instance of Epochs, Evoked, or Raw
|
||||
The data object with channels that are to be interpolated.
|
||||
bridged_idx : list of tuple
|
||||
The indices of channels marked as bridged with each bridged
|
||||
pair stored as a tuple.
|
||||
bad_limit : int
|
||||
The maximum number of electrodes that can be bridged together
|
||||
(included) and interpolated. Above this number, an error will be
|
||||
raised.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
Returns
|
||||
-------
|
||||
inst : instance of Epochs, Evoked, or Raw
|
||||
The modified data object.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.preprocessing.compute_bridged_electrodes
|
||||
"""
|
||||
_validate_type(inst, (BaseRaw, BaseEpochs, Evoked))
|
||||
bad_limit = _ensure_int(bad_limit, "bad_limit")
|
||||
if bad_limit <= 0:
|
||||
raise ValueError(
|
||||
"Argument 'bad_limit' should be a strictly positive "
|
||||
f"integer. Provided {bad_limit} is invalid."
|
||||
)
|
||||
montage = inst.get_montage()
|
||||
if montage is None:
|
||||
raise RuntimeError("No channel positions found in ``inst``")
|
||||
pos = montage.get_positions()
|
||||
if pos["coord_frame"] != "head":
|
||||
raise RuntimeError(
|
||||
f"Montage channel positions must be in ``head`` got {pos['coord_frame']}"
|
||||
)
|
||||
# store bads orig to put back at the end
|
||||
bads_orig = inst.info["bads"]
|
||||
inst.info["bads"] = list()
|
||||
|
||||
# look for group of bad channels
|
||||
nodes = sorted(set(chain(*bridged_idx)))
|
||||
G_dense = np.zeros((len(nodes), len(nodes)))
|
||||
# fill the edges with a weight of 1
|
||||
for bridge in bridged_idx:
|
||||
idx0 = np.searchsorted(nodes, bridge[0])
|
||||
idx1 = np.searchsorted(nodes, bridge[1])
|
||||
G_dense[idx0, idx1] = 1
|
||||
G_dense[idx1, idx0] = 1
|
||||
# look for connected components
|
||||
_, labels = connected_components(G_dense, directed=False)
|
||||
groups_idx = [[nodes[j] for j in np.where(labels == k)[0]] for k in set(labels)]
|
||||
groups_names = [
|
||||
[inst.info.ch_names[k] for k in group_idx] for group_idx in groups_idx
|
||||
]
|
||||
|
||||
# raise an error for bridged groups that include too many electrodes
|
||||
for group_names in groups_names:
|
||||
if len(group_names) > bad_limit:
|
||||
raise RuntimeError(
|
||||
f"The channels {', '.join(group_names)} are bridged together "
|
||||
"and form a large area of bridged electrodes. Interpolation "
|
||||
"might be inaccurate."
|
||||
)
|
||||
|
||||
# make virtual channels
|
||||
virtual_chs = dict()
|
||||
bads = set()
|
||||
for k, group_idx in enumerate(groups_idx):
|
||||
group_names = [inst.info.ch_names[k] for k in group_idx]
|
||||
bads = bads.union(group_names)
|
||||
# compute centroid position in spherical "head" coordinates
|
||||
pos_virtual = _find_centroid_sphere(pos["ch_pos"], group_names)
|
||||
# create the virtual channel info and set the position
|
||||
virtual_info = create_info([f"virtual {k + 1}"], inst.info["sfreq"], "eeg")
|
||||
virtual_info["chs"][0]["loc"][:3] = pos_virtual
|
||||
# create virtual channel
|
||||
data = inst.get_data(picks=group_names)
|
||||
if isinstance(inst, BaseRaw):
|
||||
data = np.average(data, axis=0).reshape(1, -1)
|
||||
virtual_ch = RawArray(data, virtual_info, first_samp=inst.first_samp)
|
||||
elif isinstance(inst, BaseEpochs):
|
||||
data = np.average(data, axis=1).reshape(len(data), 1, -1)
|
||||
virtual_ch = EpochsArray(data, virtual_info, tmin=inst.tmin)
|
||||
else: # evoked
|
||||
data = np.average(data, axis=0).reshape(1, -1)
|
||||
virtual_ch = EvokedArray(
|
||||
data,  # already averaged across channels and reshaped to (1, n_times)
|
||||
virtual_info,
|
||||
tmin=inst.tmin,
|
||||
nave=inst.nave,
|
||||
kind=inst.kind,
|
||||
)
|
||||
virtual_chs[f"virtual {k + 1}"] = virtual_ch
|
||||
|
||||
# add the virtual channels
|
||||
inst.add_channels(list(virtual_chs.values()), force_update_info=True)
|
||||
|
||||
# use the virtual channels to interpolate
|
||||
inst.info["bads"] = list(bads)
|
||||
inst.interpolate_bads()
|
||||
|
||||
# drop virtual channels
|
||||
inst.drop_channels(list(virtual_chs.keys()))
|
||||
|
||||
inst.info["bads"] = bads_orig
|
||||
return inst
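# --- usage sketch (added; not part of the original file) --------------------
# Pairing this function with ``compute_bridged_electrodes`` as suggested in
# "See Also". The file name is an assumption for the example.
import mne
from mne.preprocessing import compute_bridged_electrodes, interpolate_bridged_electrodes

raw = mne.io.read_raw_fif("eeg_raw.fif", preload=True)  # hypothetical EEG recording
bridged_idx, ed_matrix = compute_bridged_electrodes(raw)  # detect bridged pairs
raw = interpolate_bridged_electrodes(raw, bridged_idx=bridged_idx)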
|
||||
|
||||
|
||||
def _find_centroid_sphere(ch_pos, group_names):
|
||||
"""Compute the centroid position between N electrodes.
|
||||
|
||||
The centroid should be determined in spherical "head" coordinates which is
|
||||
more accurate than cutting through the scalp by averaging in cartesian
|
||||
coordinates.
|
||||
|
||||
A simple way is to average the locations in cartesian coordinates, convert
|
||||
to spherical coordinates and replace the radius with the average radius of
|
||||
the N points in spherical coordinates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ch_pos : OrderedDict
|
||||
The position of all channels in cartesian coordinates.
|
||||
group_names : list | tuple
|
||||
The name of the N electrodes used to determine the centroid.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pos_centroid : array of shape (3,)
|
||||
The position of the centroid in cartesian coordinates.
|
||||
"""
|
||||
cartesian_positions = np.array([ch_pos[ch_name] for ch_name in group_names])
|
||||
sphere_positions = _cart_to_sph(cartesian_positions)
|
||||
cartesian_pos_centroid = np.average(cartesian_positions, axis=0)
|
||||
sphere_pos_centroid = _cart_to_sph(cartesian_pos_centroid)
|
||||
# average the radius and overwrite it
|
||||
avg_radius = np.average(sphere_positions, axis=0)[0]
|
||||
sphere_pos_centroid[0, 0] = avg_radius
|
||||
# convert back to cartesian
|
||||
pos_centroid = _sph_to_cart(sphere_pos_centroid)[0, :]
|
||||
return pos_centroid
|
||||
2893
mne/preprocessing/maxwell.py
Normal file
2893
mne/preprocessing/maxwell.py
Normal file
File diff suppressed because it is too large
22
mne/preprocessing/nirs/__init__.py
Normal file
22
mne/preprocessing/nirs/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""NIRS specific preprocessing functions."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from .nirs import (
|
||||
short_channels,
|
||||
source_detector_distances,
|
||||
_check_channels_ordered,
|
||||
_channel_frequencies,
|
||||
_fnirs_spread_bads,
|
||||
_channel_chromophore,
|
||||
_validate_nirs_info,
|
||||
_fnirs_optode_names,
|
||||
_optode_position,
|
||||
_reorder_nirx,
|
||||
)
|
||||
from ._optical_density import optical_density
|
||||
from ._beer_lambert_law import beer_lambert_law
|
||||
from ._scalp_coupling_index import scalp_coupling_index
|
||||
from ._tddr import temporal_derivative_distribution_repair, tddr
|
||||
115
mne/preprocessing/nirs/_beer_lambert_law.py
Normal file
115
mne/preprocessing/nirs/_beer_lambert_law.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import os.path as op
|
||||
|
||||
import numpy as np
|
||||
from scipy.interpolate import interp1d
|
||||
from scipy.io import loadmat
|
||||
|
||||
from ..._fiff.constants import FIFF
|
||||
from ...io import BaseRaw
|
||||
from ...utils import _validate_type, pinv, warn
|
||||
from ..nirs import _validate_nirs_info, source_detector_distances
|
||||
|
||||
|
||||
def beer_lambert_law(raw, ppf=6.0):
|
||||
r"""Convert NIRS optical density data to haemoglobin concentration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The optical density data.
|
||||
ppf : tuple | float
|
||||
The partial pathlength factors for each wavelength.
|
||||
|
||||
.. versionchanged:: 1.7
|
||||
Support for different factors for the two wavelengths.
|
||||
|
||||
Returns
|
||||
-------
|
||||
raw : instance of Raw
|
||||
The modified raw instance.
|
||||
"""
|
||||
raw = raw.copy().load_data()
|
||||
_validate_type(raw, BaseRaw, "raw")
|
||||
_validate_type(ppf, ("numeric", "array-like"), "ppf")
|
||||
ppf = np.array(ppf, float)
|
||||
if ppf.ndim == 0: # upcast single float to shape (2,)
|
||||
ppf = np.array([ppf, ppf])
|
||||
if ppf.shape != (2,):
|
||||
raise ValueError(
|
||||
f"ppf must be float or array-like of shape (2,), got shape {ppf.shape}"
|
||||
)
|
||||
ppf = ppf[:, np.newaxis] # shape (2, 1)
|
||||
picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert")
|
||||
# This is the one place we *really* need the actual/accurate frequencies
|
||||
freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float)
|
||||
abs_coef = _load_absorption(freqs)
|
||||
distances = source_detector_distances(raw.info, picks="all")
|
||||
bad = ~np.isfinite(distances[picks])
|
||||
bad |= distances[picks] <= 0
|
||||
if bad.any():
|
||||
warn(
|
||||
"Source-detector distances are zero on NaN, some resulting "
|
||||
"concentrations will be zero. Consider setting a montage "
|
||||
"with raw.set_montage."
|
||||
)
|
||||
distances[picks[bad]] = 0.0
|
||||
if (distances[picks] > 0.1).any():
|
||||
warn(
|
||||
"Source-detector distances are greater than 10 cm. "
|
||||
"Large distances will result in invalid data, and are "
|
||||
"likely due to optode locations being stored in a "
|
||||
" unit other than meters."
|
||||
)
|
||||
rename = dict()
|
||||
for ii, jj in zip(picks[::2], picks[1::2]):
|
||||
EL = abs_coef * distances[ii] * ppf
|
||||
iEL = pinv(EL)
|
||||
|
||||
raw._data[[ii, jj]] = iEL @ raw._data[[ii, jj]] * 1e-3
|
||||
|
||||
# Update channel information
|
||||
coil_dict = dict(hbo=FIFF.FIFFV_COIL_FNIRS_HBO, hbr=FIFF.FIFFV_COIL_FNIRS_HBR)
|
||||
for ki, kind in zip((ii, jj), ("hbo", "hbr")):
|
||||
ch = raw.info["chs"][ki]
|
||||
ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL)
|
||||
new_name = f'{ch["ch_name"].split(" ")[0]} {kind}'
|
||||
rename[ch["ch_name"]] = new_name
|
||||
raw.rename_channels(rename)
|
||||
|
||||
# Validate the format of data after transformation is valid
|
||||
_validate_nirs_info(raw.info, fnirs="hb")
|
||||
return raw
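# Added note (not part of the original file): the per-pair loop above solves
# the modified Beer-Lambert law. For one source-detector pair with optical
# densities OD(l1), OD(l2), separation d and partial pathlength factors ppf,
#     [OD(l1), OD(l2)]^T = (E * d * ppf) @ [HbO, HbR]^T
# where row i of E holds the molar absorption coefficients at wavelength i,
# so concentrations are recovered with the pseudo-inverse ``iEL``.
# A typical call chain (the NIRx folder name is an assumption):
import mne

raw_intensity = mne.io.read_raw_nirx("nirx_recording_folder")  # hypothetical recording
raw_od = mne.preprocessing.nirs.optical_density(raw_intensity)
raw_hb = mne.preprocessing.nirs.beer_lambert_law(raw_od, ppf=6.0)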
|
||||
|
||||
|
||||
def _load_absorption(freqs):
|
||||
"""Load molar extinction coefficients."""
|
||||
# Data from https://omlc.org/spectra/hemoglobin/summary.html
|
||||
# The text was copied to a text file. The text before and
|
||||
# after the table was deleted. Then the following was run in
|
||||
# matlab
|
||||
# extinct_coef=importdata('extinction_coef.txt')
|
||||
# save('extinction_coef.mat', 'extinct_coef')
|
||||
#
|
||||
# Returns data as [[HbO2(freq1), Hb(freq1)],
|
||||
# [HbO2(freq2), Hb(freq2)]]
|
||||
extinction_fname = op.join(
|
||||
op.dirname(__file__), "..", "..", "data", "extinction_coef.mat"
|
||||
)
|
||||
a = loadmat(extinction_fname)["extinct_coef"]
|
||||
|
||||
interp_hbo = interp1d(a[:, 0], a[:, 1], kind="linear")
|
||||
interp_hb = interp1d(a[:, 0], a[:, 2], kind="linear")
|
||||
|
||||
ext_coef = np.array(
|
||||
[
|
||||
[interp_hbo(freqs[0]), interp_hb(freqs[0])],
|
||||
[interp_hbo(freqs[1]), interp_hb(freqs[1])],
|
||||
]
|
||||
)
|
||||
abs_coef = ext_coef * 0.2303
|
||||
|
||||
return abs_coef
|
||||
53
mne/preprocessing/nirs/_optical_density.py
Normal file
53
mne/preprocessing/nirs/_optical_density.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..._fiff.constants import FIFF
|
||||
from ...io import BaseRaw
|
||||
from ...utils import _validate_type, verbose, warn
|
||||
from ..nirs import _validate_nirs_info
|
||||
|
||||
|
||||
@verbose
|
||||
def optical_density(raw, *, verbose=None):
|
||||
r"""Convert NIRS raw data to optical density.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
raw : instance of Raw
|
||||
The modified raw instance.
|
||||
"""
|
||||
raw = raw.copy().load_data()
|
||||
_validate_type(raw, BaseRaw, "raw")
|
||||
picks = _validate_nirs_info(raw.info, fnirs="cw_amplitude")
|
||||
|
||||
# The devices measure light intensity. Negative light intensities should
|
||||
# not occur. If they do it is likely due to hardware or movement issues.
|
||||
# Set all negative values to abs(x), this also has the benefit of ensuring
|
||||
# that the means are all greater than zero for the division below.
|
||||
if np.any(raw._data[picks] <= 0):
|
||||
warn("Negative intensities encountered. Setting to abs(x)")
|
||||
min_ = np.inf
|
||||
for pi in picks:
|
||||
np.abs(raw._data[pi], out=raw._data[pi])
|
||||
min_ = min(min_, raw._data[pi].min() or min_)
|
||||
# avoid == 0
|
||||
for pi in picks:
|
||||
np.maximum(raw._data[pi], min_, out=raw._data[pi])
|
||||
|
||||
for pi in picks:
|
||||
data_mean = np.mean(raw._data[pi])
|
||||
raw._data[pi] /= data_mean
|
||||
np.log(raw._data[pi], out=raw._data[pi])
|
||||
raw._data[pi] *= -1
|
||||
raw.info["chs"][pi]["coil_type"] = FIFF.FIFFV_COIL_FNIRS_OD
|
||||
|
||||
return raw
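# --- numeric check (added; not part of the original file) -------------------
# The conversion applied per channel above is OD = -log(I / mean(I)).
import numpy as np

intensity = np.array([1.00, 1.05, 0.95, 1.00])  # toy light-intensity trace
od = -np.log(intensity / intensity.mean())      # what the loop above computes
assert np.isclose(od.mean(), 0.0, atol=1e-2)    # OD fluctuates around zero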
|
||||
69
mne/preprocessing/nirs/_scalp_coupling_index.py
Normal file
69
mne/preprocessing/nirs/_scalp_coupling_index.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...io import BaseRaw
|
||||
from ...utils import _validate_type, verbose
|
||||
from ..nirs import _validate_nirs_info
|
||||
|
||||
|
||||
@verbose
|
||||
def scalp_coupling_index(
|
||||
raw,
|
||||
l_freq=0.7,
|
||||
h_freq=1.5,
|
||||
l_trans_bandwidth=0.3,
|
||||
h_trans_bandwidth=0.3,
|
||||
verbose=False,
|
||||
):
|
||||
r"""Calculate scalp coupling index.
|
||||
|
||||
This function calculates the scalp coupling index
|
||||
:footcite:`pollonini2014auditory`. This is a measure of the quality of the
|
||||
connection between the optode and the scalp.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data.
|
||||
%(l_freq)s
|
||||
%(h_freq)s
|
||||
%(l_trans_bandwidth)s
|
||||
%(h_trans_bandwidth)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
sci : array of float
|
||||
Array containing scalp coupling index for each channel.
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
_validate_type(raw, BaseRaw, "raw")
|
||||
picks = _validate_nirs_info(raw.info, fnirs="od", which="Scalp coupling index")
|
||||
|
||||
raw = raw.copy().pick(picks).load_data()
|
||||
zero_mask = np.std(raw._data, axis=-1) == 0
|
||||
filtered_data = raw.filter(
|
||||
l_freq,
|
||||
h_freq,
|
||||
l_trans_bandwidth=l_trans_bandwidth,
|
||||
h_trans_bandwidth=h_trans_bandwidth,
|
||||
verbose=verbose,
|
||||
).get_data()
|
||||
|
||||
sci = np.zeros(picks.shape)
|
||||
for ii in range(0, len(picks), 2):
|
||||
with np.errstate(invalid="ignore"):
|
||||
c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1]
|
||||
if not np.isfinite(c): # someone had std=0
|
||||
c = 0
|
||||
sci[ii] = c
|
||||
sci[ii + 1] = c
|
||||
sci[zero_mask] = 0
|
||||
sci = sci[np.argsort(picks)] # restore original order
|
||||
return sci
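# --- usage sketch (added; not part of the original file) --------------------
# Marking poorly coupled channels, as in the fNIRS processing tutorials.
# The SCI threshold of 0.5 and the folder name are assumptions.
from itertools import compress

import mne

raw_intensity = mne.io.read_raw_nirx("nirx_recording_folder")  # hypothetical recording
raw_od = mne.preprocessing.nirs.optical_density(raw_intensity)
sci = mne.preprocessing.nirs.scalp_coupling_index(raw_od)
raw_od.info["bads"] = list(compress(raw_od.ch_names, sci < 0.5))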
|
||||
155
mne/preprocessing/nirs/_tddr.py
Normal file
155
mne/preprocessing/nirs/_tddr.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
|
||||
import numpy as np
|
||||
from scipy.signal import butter, filtfilt
|
||||
|
||||
from ...io import BaseRaw
|
||||
from ...utils import _validate_type, verbose
|
||||
from ..nirs import _validate_nirs_info
|
||||
|
||||
|
||||
@verbose
|
||||
def temporal_derivative_distribution_repair(raw, *, verbose=None):
|
||||
"""Apply temporal derivative distribution repair to data.
|
||||
|
||||
Applies temporal derivative distribution repair (TDDR) to data
|
||||
:footcite:`FishburnEtAl2019`. This approach removes baseline shift
|
||||
and spike artifacts without the need for any user-supplied parameters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
raw : instance of Raw
|
||||
Data with TDDR applied.
|
||||
|
||||
Notes
|
||||
-----
|
||||
TDDR was initially designed to be used on optical density fNIRS data but
|
||||
has been enabled to be applied on hemoglobin concentration fNIRS data as
|
||||
well in MNE. We recommend applying the algorithm to optical density fNIRS
|
||||
data as intended by the original author wherever possible.
|
||||
|
||||
There is a shorter alias ``mne.preprocessing.nirs.tddr`` that can be used
|
||||
instead of this function (e.g. if line length is an issue).
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
raw = raw.copy().load_data()
|
||||
_validate_type(raw, BaseRaw, "raw")
|
||||
picks = _validate_nirs_info(raw.info)
|
||||
|
||||
if not len(picks):
|
||||
raise RuntimeError("TDDR should be run on optical density or hemoglobin data.")
|
||||
for pick in picks:
|
||||
raw._data[pick] = _TDDR(raw._data[pick], raw.info["sfreq"])
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
# provide a short alias
|
||||
tddr = temporal_derivative_distribution_repair
|
||||
|
||||
|
||||
# Taken from https://github.com/frankfishburn/TDDR/ (MIT license).
|
||||
# With permission https://github.com/frankfishburn/TDDR/issues/1.
|
||||
# The only modification is the name, scipy signal import and flake fixes.
|
||||
def _TDDR(signal, sample_rate):
|
||||
# This function is the reference implementation for the TDDR algorithm for
|
||||
# motion correction of fNIRS data, as described in:
|
||||
#
|
||||
# Fishburn F.A., Ludlum R.S., Vaidya C.J., & Medvedev A.V. (2019).
|
||||
# Temporal Derivative Distribution Repair (TDDR): A motion correction
|
||||
# method for fNIRS. NeuroImage, 184, 171-179.
|
||||
# https://doi.org/10.1016/j.neuroimage.2018.09.025
|
||||
#
|
||||
# Usage:
|
||||
# signals_corrected = TDDR( signals , sample_rate );
|
||||
#
|
||||
# Inputs:
|
||||
# signals: A [sample x channel] matrix of uncorrected optical density or
|
||||
# hemoglobin data
|
||||
# sample_rate: A scalar reflecting the rate of acquisition in Hz
|
||||
#
|
||||
# Outputs:
|
||||
# signals_corrected: A [sample x channel] matrix of corrected optical
|
||||
# density data
|
||||
signal = np.array(signal)
|
||||
if len(signal.shape) != 1:
|
||||
for ch in range(signal.shape[1]):
|
||||
signal[:, ch] = _TDDR(signal[:, ch], sample_rate)
|
||||
return signal
|
||||
|
||||
# Preprocess: Separate high and low frequencies
|
||||
filter_cutoff = 0.5
|
||||
filter_order = 3
|
||||
Fc = filter_cutoff * 2 / sample_rate
|
||||
signal_mean = np.mean(signal)
|
||||
signal -= signal_mean
|
||||
if Fc < 1:
|
||||
fb, fa = butter(filter_order, Fc)
|
||||
signal_low = filtfilt(fb, fa, signal, padlen=0)
|
||||
else:
|
||||
signal_low = signal
|
||||
|
||||
signal_high = signal - signal_low
|
||||
|
||||
# Initialize
|
||||
tune = 4.685
|
||||
D = np.sqrt(np.finfo(signal.dtype).eps)
|
||||
mu = np.inf
|
||||
|
||||
# Step 1. Compute temporal derivative of the signal
|
||||
deriv = np.diff(signal_low)
|
||||
|
||||
# Step 2. Initialize observation weights
|
||||
w = np.ones(deriv.shape)
|
||||
|
||||
# Step 3. Iterative estimation of robust weights
|
||||
for _ in range(50):
|
||||
mu0 = mu
|
||||
|
||||
# Step 3a. Estimate weighted mean
|
||||
mu = np.sum(w * deriv) / np.sum(w)
|
||||
|
||||
# Step 3b. Calculate absolute residuals of estimate
|
||||
dev = np.abs(deriv - mu)
|
||||
|
||||
# Step 3c. Robust estimate of standard deviation of the residuals
|
||||
sigma = 1.4826 * np.median(dev)
|
||||
|
||||
# Step 3d. Scale deviations by standard deviation and tuning parameter
|
||||
if sigma == 0:
|
||||
break
|
||||
r = dev / (sigma * tune)
|
||||
|
||||
# Step 3e. Calculate new weights according to Tukey's biweight function
|
||||
w = ((1 - r**2) * (r < 1)) ** 2
|
||||
|
||||
# Step 3f. Terminate if new estimate is within
|
||||
# machine-precision of old estimate
|
||||
if abs(mu - mu0) < D * max(abs(mu), abs(mu0)):
|
||||
break
|
||||
|
||||
# Step 4. Apply robust weights to centered derivative
|
||||
new_deriv = w * (deriv - mu)
|
||||
|
||||
# Step 5. Integrate corrected derivative
|
||||
signal_low_corrected = np.cumsum(np.insert(new_deriv, 0, 0.0))
|
||||
|
||||
# Postprocess: Center the corrected signal
|
||||
signal_low_corrected = signal_low_corrected - np.mean(signal_low_corrected)
|
||||
|
||||
# Postprocess: Merge back with uncorrected high frequency component
|
||||
signal_corrected = signal_low_corrected + signal_high + signal_mean
|
||||
|
||||
return signal_corrected
|
||||
336
mne/preprocessing/nirs/nirs.py
Normal file
336
mne/preprocessing/nirs/nirs.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..._fiff.pick import _picks_to_idx, pick_types
|
||||
from ...utils import _check_option, _validate_type, fill_doc
|
||||
|
||||
# Standardized fNIRS channel name regexes
|
||||
_S_D_F_RE = re.compile(r"S(\d+)_D(\d+) (\d+\.?\d*)")
|
||||
_S_D_H_RE = re.compile(r"S(\d+)_D(\d+) (\w+)")
|
||||
|
||||
|
||||
@fill_doc
|
||||
def source_detector_distances(info, picks=None):
|
||||
r"""Determine the distance between NIRS source and detectors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
%(info_not_none)s
|
||||
%(picks_all_data)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
dists : array of float
|
||||
Array containing distances in meters.
|
||||
Of shape equal to number of channels, or shape of picks if supplied.
|
||||
"""
|
||||
return np.array(
|
||||
[
|
||||
np.linalg.norm(
|
||||
np.diff(info["chs"][pick]["loc"][3:9].reshape(2, 3), axis=0)[0]
|
||||
)
|
||||
for pick in _picks_to_idx(info, picks, exclude=[])
|
||||
],
|
||||
float,
|
||||
)
|
||||
|
||||
|
||||
@fill_doc
|
||||
def short_channels(info, threshold=0.01):
|
||||
r"""Determine which NIRS channels are short.
|
||||
|
||||
Channels with a source to detector distance of less than
|
||||
``threshold`` are reported as short. The default threshold is 0.01 m.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
%(info_not_none)s
|
||||
threshold : float
|
||||
The threshold distance for what is considered short in meters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
short : array of bool
|
||||
Array indicating which channels are short.
|
||||
Of shape equal to number of channels.
|
||||
"""
|
||||
return source_detector_distances(info) < threshold
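# --- usage sketch (added; not part of the original file) --------------------
# Separating short (scalp-coupling) channels from long channels. The folder
# name is an assumption for the example.
import numpy as np
import mne
from mne.preprocessing.nirs import short_channels

raw = mne.io.read_raw_nirx("nirx_recording_folder")  # hypothetical recording
is_short = short_channels(raw.info)                  # one bool per channel, threshold 0.01 m
raw_long = raw.copy().pick(np.where(~is_short)[0])   # keep only long channels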
|
||||
|
||||
|
||||
def _channel_frequencies(info):
|
||||
"""Return the light frequency for each channel."""
|
||||
# Only valid for fNIRS data before conversion to haemoglobin
|
||||
picks = _picks_to_idx(
|
||||
info, ["fnirs_cw_amplitude", "fnirs_od"], exclude=[], allow_empty=True
|
||||
)
|
||||
freqs = list()
|
||||
for pick in picks:
|
||||
freqs.append(round(float(_S_D_F_RE.match(info["ch_names"][pick]).groups()[2])))
|
||||
return np.array(freqs, int)
|
||||
|
||||
|
||||
def _channel_chromophore(info):
|
||||
"""Return the chromophore of each channel."""
|
||||
# Only valid for fNIRS data after conversion to haemoglobin
|
||||
picks = _picks_to_idx(info, ["hbo", "hbr"], exclude=[], allow_empty=True)
|
||||
chroma = []
|
||||
for ii in picks:
|
||||
chroma.append(info["ch_names"][ii].split(" ")[1])
|
||||
return chroma
|
||||
|
||||
|
||||
def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=True):
|
||||
"""Check channels follow expected fNIRS format.
|
||||
|
||||
If the channels are correctly ordered then an array of valid picks
|
||||
will be returned.
|
||||
|
||||
If throw_errors is True then any errors in fNIRS formatting will be
|
||||
thrown to inform the user. If throw_errors is False then an empty array
|
||||
will be returned if the channels are not sufficiently formatted.
|
||||
"""
|
||||
# Every second channel should be same SD pair
|
||||
# and have the specified light frequencies.
|
||||
|
||||
# All wavelength based fNIRS data.
|
||||
picks_wave = _picks_to_idx(
|
||||
info, ["fnirs_cw_amplitude", "fnirs_od"], exclude=[], allow_empty=True
|
||||
)
|
||||
# All chromophore fNIRS data
|
||||
picks_chroma = _picks_to_idx(info, ["hbo", "hbr"], exclude=[], allow_empty=True)
|
||||
|
||||
if (len(picks_wave) > 0) & (len(picks_chroma) > 0):
|
||||
picks = _throw_or_return_empty(
|
||||
"MNE does not support a combination of amplitude, optical "
|
||||
"density, and haemoglobin data in the same raw structure.",
|
||||
throw_errors,
|
||||
)
|
||||
|
||||
# All continuous wave fNIRS data
|
||||
if len(picks_wave):
|
||||
error_word = "frequencies"
|
||||
use_RE = _S_D_F_RE
|
||||
picks = picks_wave
|
||||
else:
|
||||
error_word = "chromophore"
|
||||
use_RE = _S_D_H_RE
|
||||
picks = picks_chroma
|
||||
|
||||
pair_vals = np.array(pair_vals)
|
||||
if pair_vals.shape != (2,):
|
||||
raise ValueError(
|
||||
f"Exactly two {error_word} must exist in info, got {list(pair_vals)}"
|
||||
)
|
||||
# In principle we do not need to require that these be sorted --
|
||||
# all we need to do is change our sorted() below to make use of a
|
||||
# pair_vals.index(...) in a sort key -- but in practice we always want
|
||||
# (hbo, hbr) or (lower_freq, upper_freq) pairings, both of which will
|
||||
# work with a naive string sort, so let's just enforce sorted-ness here
|
||||
is_str = pair_vals.dtype.kind == "U"
|
||||
pair_vals = list(pair_vals)
|
||||
if is_str:
|
||||
if pair_vals != ["hbo", "hbr"]:
|
||||
raise ValueError(
|
||||
f'The {error_word} in info must be ["hbo", "hbr"], but got '
|
||||
f"{pair_vals} instead"
|
||||
)
|
||||
elif not np.array_equal(np.unique(pair_vals), pair_vals):
|
||||
raise ValueError(
|
||||
f"The {error_word} in info must be unique and sorted, but got "
|
||||
f"got {pair_vals} instead"
|
||||
)
|
||||
|
||||
if len(picks) % 2 != 0:
|
||||
picks = _throw_or_return_empty(
|
||||
"NIRS channels not ordered correctly. An even number of NIRS "
|
||||
f"channels is required. {len(info.ch_names)} channels were"
|
||||
f"provided",
|
||||
throw_errors,
|
||||
)
|
||||
|
||||
# Ensure wavelength info exists for waveform data
|
||||
all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave]
|
||||
if np.any(np.isnan(all_freqs)):
|
||||
picks = _throw_or_return_empty(
|
||||
f"NIRS channels is missing wavelength information in the "
|
||||
f'info["chs"] structure. The encoded wavelengths are {all_freqs}.',
|
||||
throw_errors,
|
||||
)
|
||||
|
||||
# Validate the channel naming scheme
|
||||
for pick in picks:
|
||||
ch_name_info = use_RE.match(info["chs"][pick]["ch_name"])
|
||||
if not bool(ch_name_info):
|
||||
picks = _throw_or_return_empty(
|
||||
"NIRS channels have specified naming conventions. "
|
||||
"The provided channel name can not be parsed: "
|
||||
f"{repr(info.ch_names[pick])}",
|
||||
throw_errors,
|
||||
)
|
||||
break
|
||||
value = ch_name_info.groups()[2]
|
||||
if len(picks_wave):
|
||||
value = value
|
||||
else: # picks_chroma
|
||||
if value not in ["hbo", "hbr"]:
|
||||
picks = _throw_or_return_empty(
|
||||
"NIRS channels have specified naming conventions."
|
||||
"Chromophore data must be labeled either hbo or hbr. "
|
||||
f"The failing channel is {info['chs'][pick]['ch_name']}",
|
||||
throw_errors,
|
||||
)
|
||||
break
|
||||
|
||||
# Reorder to be paired (naive sort okay here given validation above)
|
||||
picks = picks[np.argsort([info["ch_names"][pick] for pick in picks])]
|
||||
|
||||
# Validate our paired ordering
|
||||
for ii, jj in zip(picks[::2], picks[1::2]):
|
||||
ch1_name = info["chs"][ii]["ch_name"]
|
||||
ch2_name = info["chs"][jj]["ch_name"]
|
||||
ch1_re = use_RE.match(ch1_name)
|
||||
ch2_re = use_RE.match(ch2_name)
|
||||
ch1_S, ch1_D, ch1_value = ch1_re.groups()[:3]
|
||||
ch2_S, ch2_D, ch2_value = ch2_re.groups()[:3]
|
||||
if len(picks_wave):
|
||||
ch1_value, ch2_value = float(ch1_value), float(ch2_value)
|
||||
if (
|
||||
(ch1_S != ch2_S)
|
||||
or (ch1_D != ch2_D)
|
||||
or (ch1_value != pair_vals[0])
|
||||
or (ch2_value != pair_vals[1])
|
||||
):
|
||||
picks = _throw_or_return_empty(
|
||||
"NIRS channels not ordered correctly. Channels must be "
|
||||
"ordered as source detector pairs with alternating"
|
||||
f" {error_word} {pair_vals[0]} & {pair_vals[1]}, but got "
|
||||
f"S{ch1_S}_D{ch1_D} pair "
|
||||
f"{repr(ch1_name)} and {repr(ch2_name)}",
|
||||
throw_errors,
|
||||
)
|
||||
break
|
||||
|
||||
if check_bads:
|
||||
for ii, jj in zip(picks[::2], picks[1::2]):
|
||||
want = [info.ch_names[ii], info.ch_names[jj]]
|
||||
got = list(set(info["bads"]).intersection(want))
|
||||
if len(got) == 1:
|
||||
raise RuntimeError(
|
||||
f"NIRS bad labelling is not consistent, found {got} but "
|
||||
f"needed {want}"
|
||||
)
|
||||
return picks
|
||||
|
||||
|
||||
def _throw_or_return_empty(msg, throw_errors):
|
||||
if throw_errors:
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def _validate_nirs_info(
|
||||
info,
|
||||
*,
|
||||
throw_errors=True,
|
||||
fnirs=None,
|
||||
which=None,
|
||||
check_bads=True,
|
||||
allow_empty=True,
|
||||
):
|
||||
"""Apply all checks to fNIRS info. Works on all continuous wave types."""
|
||||
_validate_type(fnirs, (None, str), "fnirs")
|
||||
kinds = dict(
|
||||
od="optical density",
|
||||
cw_amplitude="continuous wave",
|
||||
hb="chromophore",
|
||||
)
|
||||
_check_option("fnirs", fnirs, (None,) + tuple(kinds))
|
||||
if fnirs is not None:
|
||||
kind = kinds[fnirs]
|
||||
fnirs = ["hbo", "hbr"] if fnirs == "hb" else f"fnirs_{fnirs}"
|
||||
if not len(pick_types(info, fnirs=fnirs)):
|
||||
raise RuntimeError(
|
||||
f"{which} must operate on {kind} data, but none was found."
|
||||
)
|
||||
freqs = np.unique(_channel_frequencies(info))
|
||||
if freqs.size > 0:
|
||||
pair_vals = freqs
|
||||
else:
|
||||
pair_vals = np.unique(_channel_chromophore(info))
|
||||
out = _check_channels_ordered(
|
||||
info, pair_vals, throw_errors=throw_errors, check_bads=check_bads
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _fnirs_spread_bads(info):
|
||||
"""Spread bad labeling across fnirs channels."""
|
||||
# For an optode pair if any component (light frequency or chroma) is marked
|
||||
# as bad, then they all should be. This function will find any pairs marked
|
||||
# as bad and spread the bad marking to all components of the optode pair.
|
||||
picks = _validate_nirs_info(info, check_bads=False)
|
||||
new_bads = set(info["bads"])
|
||||
for ii, jj in zip(picks[::2], picks[1::2]):
|
||||
ch1_name, ch2_name = info.ch_names[ii], info.ch_names[jj]
|
||||
if ch1_name in new_bads:
|
||||
new_bads.add(ch2_name)
|
||||
elif ch2_name in new_bads:
|
||||
new_bads.add(ch1_name)
|
||||
info["bads"] = sorted(new_bads)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
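# --- illustrative usage sketch (not part of the original source) -------------
# _fnirs_spread_bads marks both members of a source-detector pair as bad as
# soon as either one is. The synthetic hbo/hbr channels below are assumptions
# made only to keep the example self-contained:
if __name__ == "__main__":
    import mne

    demo_info = mne.create_info(
        ["S1_D1 hbo", "S1_D1 hbr", "S2_D1 hbo", "S2_D1 hbr"],
        sfreq=10.0,
        ch_types=["hbo", "hbr", "hbo", "hbr"],
    )
    demo_info["bads"] = ["S1_D1 hbo"]  # only one member of the pair marked bad
    demo_info = _fnirs_spread_bads(demo_info)
    assert demo_info["bads"] == ["S1_D1 hbo", "S1_D1 hbr"]  # pair now consistent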
def _fnirs_optode_names(info):
|
||||
"""Return list of unique optode names."""
|
||||
picks_wave = _picks_to_idx(
|
||||
info, ["fnirs_cw_amplitude", "fnirs_od"], exclude=[], allow_empty=True
|
||||
)
|
||||
picks_chroma = _picks_to_idx(info, ["hbo", "hbr"], exclude=[], allow_empty=True)
|
||||
|
||||
if len(picks_wave) > 0:
|
||||
regex = _S_D_F_RE
|
||||
elif len(picks_chroma) > 0:
|
||||
regex = _S_D_H_RE
|
||||
else:
|
||||
return [], []
|
||||
|
||||
sources = np.unique([int(regex.match(ch).groups()[0]) for ch in info.ch_names])
|
||||
detectors = np.unique([int(regex.match(ch).groups()[1]) for ch in info.ch_names])
|
||||
|
||||
src_names = [f"S{s}" for s in sources]
|
||||
det_names = [f"D{d}" for d in detectors]
|
||||
|
||||
return src_names, det_names
|
||||
|
||||
|
||||
def _optode_position(info, optode):
|
||||
"""Find the position of an optode."""
|
||||
idx = [optode in a for a in info.ch_names].index(True)
|
||||
|
||||
if "S" in optode:
|
||||
loc_idx = range(3, 6)
|
||||
elif "D" in optode:
|
||||
loc_idx = range(6, 9)
|
||||
|
||||
return info["chs"][idx]["loc"][loc_idx]
|
||||
|
||||
|
||||
def _reorder_nirx(raw):
|
||||
# Maybe someday we should make this public like
|
||||
# mne.preprocessing.nirs.reorder_standard(raw, order='nirx')
|
||||
info = raw.info
|
||||
picks = pick_types(info, fnirs=True, exclude=[])
|
||||
prefixes = [info["ch_names"][pick].split()[0] for pick in picks]
|
||||
nirs_names = [info["ch_names"][pick] for pick in picks]
|
||||
nirs_sorted = sorted(
|
||||
nirs_names,
|
||||
key=lambda name: (prefixes.index(name.split()[0]), name.split(maxsplit=1)[1]),
|
||||
)
|
||||
raw.reorder_channels(nirs_sorted)
|
||||
140
mne/preprocessing/otp.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_to_idx
|
||||
from .._ola import _COLA, _Storer
|
||||
from ..surface import _normalize_vectors
|
||||
from ..utils import logger, verbose
|
||||
|
||||
|
||||
def _svd_cov(cov, data):
|
||||
"""Use a covariance matrix to compute the SVD faster."""
|
||||
# This makes use of mathematical equivalences between PCA and SVD
|
||||
# on zero-mean data
|
||||
s, u = np.linalg.eigh(cov)
|
||||
norm = np.ones((s.size,))
|
||||
mask = s > np.finfo(float).eps * s[-1] # largest is last
|
||||
s = np.sqrt(s, out=s)
|
||||
norm[mask] = 1.0 / s[mask]
|
||||
u *= norm
|
||||
v = np.dot(u.T[mask], data)
|
||||
return u, s, v
|
||||
|
||||
|
||||
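# --- illustrative check (not part of the original source) --------------------
# Sketch of why _svd_cov works: for zero-mean data, the eigendecomposition of
# data @ data.T yields the same singular values (up to ordering) as an SVD of
# the data itself.
_demo_rng = np.random.default_rng(0)
_demo = _demo_rng.standard_normal((5, 100))
_demo -= _demo.mean(axis=-1, keepdims=True)
_, _demo_s, _ = _svd_cov(np.dot(_demo, _demo.T), _demo)
_demo_s_ref = np.linalg.svd(_demo, full_matrices=False)[1]
np.testing.assert_allclose(np.sort(_demo_s)[::-1], _demo_s_ref, rtol=1e-7)
# ------------------------------------------------------------------------------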
@verbose
|
||||
def oversampled_temporal_projection(raw, duration=10.0, picks=None, verbose=None):
|
||||
"""Denoise MEG channels using leave-one-out temporal projection.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
Raw data to denoise.
|
||||
duration : float | str
|
||||
The window duration (in seconds; default 10.) to use. Can also
|
||||
be "min" to use as short a window as possible.
|
||||
%(picks_all_data)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
raw_clean : instance of Raw
|
||||
The cleaned data.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This algorithm is computationally expensive, and can be several times
|
||||
slower than realtime for conventional M/EEG datasets. It uses a
|
||||
leave-one-out procedure with parallel temporal projection to remove
|
||||
individual sensor noise under the assumption that sampled fields
|
||||
(e.g., MEG and EEG) are oversampled by the sensor array
|
||||
:footcite:`LarsonTaulu2018`.
|
||||
|
||||
OTP can improve sensor noise levels (especially under visual
|
||||
inspection) and repair some bad channels. This noise reduction is known
|
||||
to interact with :func:`tSSS <mne.preprocessing.maxwell_filter>` such
|
||||
that increasing the ``st_correlation`` value will likely be necessary.
|
||||
|
||||
Channels marked as bad will not be used to reconstruct good channels,
|
||||
but good channels will be used to process the bad channels. Depending
|
||||
on the type of noise present in the bad channels, this might make
|
||||
them usable again.
|
||||
|
||||
Use of this algorithm is covered by a provisional patent.
|
||||
|
||||
.. versionadded:: 0.16
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
logger.info("Processing MEG data using oversampled temporal projection")
|
||||
picks = _picks_to_idx(raw.info, picks, exclude=())
|
||||
picks_good, picks_bad = list(), list() # these are indices into picks
|
||||
for ii, pi in enumerate(picks):
|
||||
if raw.ch_names[pi] in raw.info["bads"]:
|
||||
picks_bad.append(ii)
|
||||
else:
|
||||
picks_good.append(ii)
|
||||
picks_good = np.array(picks_good, int)
|
||||
picks_bad = np.array(picks_bad, int)
|
||||
|
||||
n_samples = int(round(float(duration) * raw.info["sfreq"]))
|
||||
if n_samples < len(picks_good) - 1:
|
||||
raise ValueError(
|
||||
f"duration ({n_samples / raw.info['sfreq']}) yielded {n_samples} samples, "
|
||||
f"which is fewer than the number of channels -1 ({len(picks_good) - 1})"
|
||||
)
|
||||
n_overlap = n_samples // 2
|
||||
raw_otp = raw.copy().load_data(verbose=False)
|
||||
otp = _COLA(
|
||||
partial(_otp, picks_good=picks_good, picks_bad=picks_bad),
|
||||
_Storer(raw_otp._data, picks=picks),
|
||||
len(raw.times),
|
||||
n_samples,
|
||||
n_overlap,
|
||||
raw.info["sfreq"],
|
||||
)
|
||||
read_lims = list(range(0, len(raw.times), n_samples)) + [len(raw.times)]
|
||||
for start, stop in zip(read_lims[:-1], read_lims[1:]):
|
||||
logger.info(
|
||||
f" Denoising {raw.times[[start, stop - 1]][0]: 8.2f} – "
|
||||
f"{raw.times[[start, stop - 1]][1]: 8.2f} s"
|
||||
)
|
||||
otp.feed(raw[picks, start:stop][0])
|
||||
return raw_otp
|
||||
|
||||
|
||||
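# --- illustrative usage sketch (not part of the original source) -------------
# Typical use of the public function above: denoise a preloaded Raw. The
# synthetic magnetometer data here are an assumption made purely to keep the
# example self-contained; real usage would read an actual recording.
if __name__ == "__main__":
    import mne

    demo_rng = np.random.default_rng(1)
    demo_info = mne.create_info(
        [f"MAG{k:02d}" for k in range(6)], sfreq=200.0, ch_types="mag"
    )
    demo_raw = mne.io.RawArray(demo_rng.standard_normal((6, 2000)) * 1e-12, demo_info)
    demo_clean = oversampled_temporal_projection(demo_raw, duration=5.0)
    assert demo_clean.get_data().shape == demo_raw.get_data().shape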
def _otp(data, picks_good, picks_bad):
|
||||
"""Perform OTP on one segment of data."""
|
||||
if not np.isfinite(data).all():
|
||||
raise RuntimeError("non-finite data (inf or nan) found in raw instance")
|
||||
# demean our data
|
||||
data_means = np.mean(data, axis=-1, keepdims=True)
|
||||
data -= data_means
|
||||
# make a copy
|
||||
data_good = data[picks_good]
|
||||
# scale the copy that will be used to form the temporal basis vectors
|
||||
# so that _orth_svdvals thresholding should work properly with
|
||||
# different channel types (e.g., M-EEG)
|
||||
norms = _normalize_vectors(data_good)
|
||||
cov = np.dot(data_good, data_good.T)
|
||||
if len(picks_bad) > 0:
|
||||
full_basis = _svd_cov(cov, data_good)[2]
|
||||
for mi, pick in enumerate(picks_good):
|
||||
# operate on original data
|
||||
idx = list(range(mi)) + list(range(mi + 1, len(data_good)))
|
||||
# Equivalent: svd(data[idx], full_matrices=False)[2]
|
||||
t_basis = _svd_cov(cov[np.ix_(idx, idx)], data_good[idx])[2]
|
||||
x = np.dot(np.dot(data_good[mi], t_basis.T), t_basis)
|
||||
x *= norms[mi]
|
||||
x += data_means[pick]
|
||||
data[pick] = x
|
||||
for pick in picks_bad:
|
||||
data[pick] = np.dot(np.dot(data[pick], full_basis.T), full_basis)
|
||||
data[pick] += data_means[pick]
|
||||
return [data]
|
||||
132
mne/preprocessing/realign.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
from numpy.polynomial.polynomial import Polynomial
|
||||
from scipy.stats import pearsonr
|
||||
|
||||
from ..io import BaseRaw
|
||||
from ..utils import _validate_type, logger, verbose, warn
|
||||
|
||||
|
||||
@verbose
|
||||
def realign_raw(raw, other, t_raw, t_other, *, verbose=None):
|
||||
"""Realign two simultaneous recordings.
|
||||
|
||||
Due to clock drift, recordings made at the same nominal sample rate by two
|
||||
separate devices simultaneously can become out of sync over time. This
|
||||
function uses event times captured by both acquisition devices to resample
|
||||
``other`` to match ``raw``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The first raw instance.
|
||||
other : instance of Raw
|
||||
The second raw instance. It will be resampled to match ``raw``.
|
||||
t_raw : array-like, shape (n_events,)
|
||||
The times of shared events in ``raw`` relative to ``raw.times[0]`` (0).
|
||||
Typically these could be events on some TTL channel such as::
|
||||
|
||||
find_events(raw)[:, 0] / raw.info["sfreq"] - raw.first_time
|
||||
t_other : array-like, shape (n_events,)
|
||||
The times of shared events in ``other`` relative to ``other.times[0]``.
|
||||
%(verbose)s
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function operates inplace. It will:
|
||||
|
||||
1. Estimate the zero-order (start offset) and first-order (clock drift)
|
||||
correction.
|
||||
2. Crop the start of ``raw`` or ``other``, depending on which started
|
||||
recording first.
|
||||
3. Resample ``other`` to match ``raw`` based on the clock drift.
|
||||
4. Realign the onsets and durations in ``other.annotations``.
|
||||
5. Crop the end of ``raw`` or ``other``, depending on which stopped
|
||||
recording first (and the clock drift rate).
|
||||
|
||||
This function is primarily designed to work on recordings made at the same
|
||||
sample rate, but it can also operate on recordings made at different
|
||||
sample rates to resample and deal with clock drift simultaneously.
|
||||
|
||||
.. versionadded:: 0.22
|
||||
"""
|
||||
_validate_type(raw, BaseRaw, "raw")
|
||||
_validate_type(other, BaseRaw, "other")
|
||||
t_raw = np.array(t_raw, float)
|
||||
t_other = np.array(t_other, float)
|
||||
if t_raw.ndim != 1 or t_raw.shape != t_other.shape:
|
||||
raise ValueError(
|
||||
"t_raw and t_other must be 1D with the same shape, "
|
||||
f"got shapes {t_raw.shape} and {t_other.shape}"
|
||||
)
|
||||
if len(t_raw) < 20:
|
||||
warn("Fewer than 20 times passed, results may be unreliable")
|
||||
|
||||
# 1. Compute correction factors
|
||||
poly = Polynomial.fit(x=t_other, y=t_raw, deg=1)
|
||||
converted = poly.convert(domain=(-1, 1))
|
||||
[zero_ord, first_ord] = converted.coef
|
||||
logger.info(
|
||||
f"Zero order coefficient: {zero_ord} \nFirst order coefficient: {first_ord}"
|
||||
)
|
||||
r, p = pearsonr(t_other, t_raw)
|
||||
msg = f"Linear correlation computed as R={r:0.3f} and p={p:0.2e}"
|
||||
if p > 0.05 or r <= 0:
|
||||
raise ValueError(msg + ", cannot resample safely")
|
||||
if p > 1e-6:
|
||||
warn(msg + ", results may be unreliable")
|
||||
else:
|
||||
logger.info(msg)
|
||||
dr_ms_s = 1000 * abs(1 - first_ord)
|
||||
logger.info(
|
||||
f"Drift rate: {1000 * dr_ms_s:0.1f} μs/s "
|
||||
f"(total drift over {raw.times[-1]:0.1f} s recording: "
|
||||
f"{raw.times[-1] * dr_ms_s:0.1f} ms)"
|
||||
)
|
||||
|
||||
# 2. Crop start of recordings to match
|
||||
if zero_ord > 0: # need to crop start of raw to match other
|
||||
logger.info(f"Cropping {zero_ord:0.3f} s from the start of raw")
|
||||
raw.crop(zero_ord, None)
|
||||
t_raw -= zero_ord
|
||||
elif zero_ord < 0: # need to crop start of other to match raw
|
||||
t_crop = -zero_ord / first_ord
|
||||
logger.info(f"Cropping {t_crop:0.3f} s from the start of other")
|
||||
other.crop(t_crop, None)
|
||||
t_other -= t_crop
|
||||
|
||||
# 3. Resample data using the first-order term
|
||||
nan_ch_names = [
|
||||
ch for ch in other.info["ch_names"] if np.isnan(other.get_data(picks=ch)).any()
|
||||
]
|
||||
if len(nan_ch_names) > 0: # Issue warning if any channel in other has nan values
|
||||
warn(
|
||||
f"Channel(s) {', '.join(nan_ch_names)} in `other` contain NaN values. "
|
||||
"Resampling these channels will result in the whole channel being NaN. "
|
||||
"(If realigning eye-tracking data, consider using interpolate_blinks and "
|
||||
"passing interpolate_gaze=True)"
|
||||
)
|
||||
logger.info("Resampling other")
|
||||
sfreq_new = raw.info["sfreq"] * first_ord
|
||||
other.load_data().resample(sfreq_new)
|
||||
with other.info._unlock():
|
||||
other.info["sfreq"] = raw.info["sfreq"]
|
||||
|
||||
# 4. Realign the onsets and durations in other.annotations
|
||||
# Must happen before end cropping to avoid losing annotations
|
||||
logger.info("Correcting annotations in other")
|
||||
other.annotations.onset *= first_ord
|
||||
other.annotations.duration *= first_ord
|
||||
|
||||
# 5. Crop the end of one of the recordings if necessary
|
||||
delta = raw.times[-1] - other.times[-1]
|
||||
msg = f"Cropping {abs(delta):0.3f} s from the end of "
|
||||
if delta > 0:
|
||||
logger.info(msg + "raw")
|
||||
raw.crop(0, other.times[-1])
|
||||
elif delta < 0:
|
||||
logger.info(msg + "other")
|
||||
other.crop(0, raw.times[-1])
|
||||
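# --- illustrative usage sketch (not part of the original source) -------------
# Sketch of realign_raw on two synthetic recordings of the same session, where
# the second device starts 1 s late and its clock runs 0.05 % fast. Only the
# shared event times matter for the fit; the random data are an assumption made
# for illustration.
if __name__ == "__main__":
    import mne

    demo_rng = np.random.default_rng(2)
    demo_info = mne.create_info(["sig"], sfreq=100.0, ch_types="eeg")
    demo_raw = mne.io.RawArray(demo_rng.standard_normal((1, 60_000)) * 1e-6, demo_info)
    demo_other = demo_raw.copy()
    t_raw = np.arange(30) * 15.0 + 2.0   # shared event times seen by device 1
    t_other = (t_raw - 1.0) * 1.0005     # same events on device 2's drifting clock
    realign_raw(demo_raw, demo_other, t_raw, t_other)
    assert abs(demo_raw.times[-1] - demo_other.times[-1]) <= 0.02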
605
mne/preprocessing/ssp.py
Normal file
@@ -0,0 +1,605 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import copy as cp
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import pick_types
|
||||
from .._fiff.reference import make_eeg_average_ref_proj
|
||||
from ..epochs import Epochs
|
||||
from ..proj import compute_proj_epochs, compute_proj_evoked
|
||||
from ..utils import _validate_type, logger, verbose, warn
|
||||
from .ecg import find_ecg_events
|
||||
from .eog import find_eog_events
|
||||
|
||||
|
||||
def _safe_del_key(dict_, key):
|
||||
"""Aux function.
|
||||
|
||||
Use this function when preparing rejection parameters
|
||||
instead of directly deleting keys.
|
||||
"""
|
||||
if key in dict_:
|
||||
del dict_[key]
|
||||
|
||||
|
||||
def _compute_exg_proj(
|
||||
mode,
|
||||
raw,
|
||||
raw_event,
|
||||
tmin,
|
||||
tmax,
|
||||
n_grad,
|
||||
n_mag,
|
||||
n_eeg,
|
||||
l_freq,
|
||||
h_freq,
|
||||
average,
|
||||
filter_length,
|
||||
n_jobs,
|
||||
ch_name,
|
||||
reject,
|
||||
flat,
|
||||
bads,
|
||||
avg_ref,
|
||||
no_proj,
|
||||
event_id,
|
||||
exg_l_freq,
|
||||
exg_h_freq,
|
||||
tstart,
|
||||
qrs_threshold,
|
||||
filter_method,
|
||||
iir_params,
|
||||
return_drop_log,
|
||||
copy,
|
||||
meg,
|
||||
verbose,
|
||||
):
|
||||
"""Compute SSP/PCA projections for ECG or EOG artifacts."""
|
||||
raw = raw.copy() if copy else raw
|
||||
del copy
|
||||
raw.load_data() # we will filter it later
|
||||
|
||||
if no_proj:
|
||||
projs = []
|
||||
else:
|
||||
projs = cp.deepcopy(raw.info["projs"])
|
||||
logger.info(f"Including {len(projs)} SSP projectors from raw file")
|
||||
|
||||
if avg_ref:
|
||||
eeg_proj = make_eeg_average_ref_proj(raw.info)
|
||||
projs.append(eeg_proj)
|
||||
|
||||
if raw_event is None:
|
||||
raw_event = raw
|
||||
|
||||
assert mode in ("ECG", "EOG") # internal function
|
||||
logger.info(f"Running {mode} SSP computation")
|
||||
if mode == "ECG":
|
||||
events, _, _ = find_ecg_events(
|
||||
raw_event,
|
||||
ch_name=ch_name,
|
||||
event_id=event_id,
|
||||
l_freq=exg_l_freq,
|
||||
h_freq=exg_h_freq,
|
||||
tstart=tstart,
|
||||
qrs_threshold=qrs_threshold,
|
||||
filter_length=filter_length,
|
||||
)
|
||||
else: # mode == 'EOG':
|
||||
events = find_eog_events(
|
||||
raw_event,
|
||||
event_id=event_id,
|
||||
l_freq=exg_l_freq,
|
||||
h_freq=exg_h_freq,
|
||||
filter_length=filter_length,
|
||||
ch_name=ch_name,
|
||||
tstart=tstart,
|
||||
)
|
||||
|
||||
# Check to make sure we actually got at least one usable event
|
||||
if events.shape[0] < 1:
|
||||
warn(f"No {mode} events found")
|
||||
return ([], events) + (([],) if return_drop_log else ())
|
||||
|
||||
logger.info("Computing projector")
|
||||
my_info = cp.deepcopy(raw.info)
|
||||
my_info["bads"] += bads
|
||||
|
||||
# Handle rejection parameters
|
||||
_validate_type(reject, (None, dict), "reject")
|
||||
_validate_type(flat, (None, dict), "flat")
|
||||
if reject is not None: # make sure they didn't pass None
|
||||
reject = reject.copy() # must make a copy or we modify default!
|
||||
if (
|
||||
len(
|
||||
pick_types(
|
||||
my_info,
|
||||
meg="grad",
|
||||
eeg=False,
|
||||
eog=False,
|
||||
ref_meg=False,
|
||||
exclude="bads",
|
||||
)
|
||||
)
|
||||
== 0
|
||||
):
|
||||
_safe_del_key(reject, "grad")
|
||||
if (
|
||||
len(
|
||||
pick_types(
|
||||
my_info,
|
||||
meg="mag",
|
||||
eeg=False,
|
||||
eog=False,
|
||||
ref_meg=False,
|
||||
exclude="bads",
|
||||
)
|
||||
)
|
||||
== 0
|
||||
):
|
||||
_safe_del_key(reject, "mag")
|
||||
if (
|
||||
len(
|
||||
pick_types(
|
||||
my_info,
|
||||
meg=False,
|
||||
eeg=True,
|
||||
eog=False,
|
||||
ref_meg=False,
|
||||
exclude="bads",
|
||||
)
|
||||
)
|
||||
== 0
|
||||
):
|
||||
_safe_del_key(reject, "eeg")
|
||||
if (
|
||||
len(
|
||||
pick_types(
|
||||
my_info,
|
||||
meg=False,
|
||||
eeg=False,
|
||||
eog=True,
|
||||
ref_meg=False,
|
||||
exclude="bads",
|
||||
)
|
||||
)
|
||||
== 0
|
||||
):
|
||||
_safe_del_key(reject, "eog")
|
||||
if flat is not None: # make sure they didn't pass None
|
||||
flat = flat.copy()
|
||||
if (
|
||||
len(
|
||||
pick_types(
|
||||
my_info,
|
||||
meg="grad",
|
||||
eeg=False,
|
||||
eog=False,
|
||||
ref_meg=False,
|
||||
exclude="bads",
|
||||
)
|
||||
)
|
||||
== 0
|
||||
):
|
||||
_safe_del_key(flat, "grad")
|
||||
if (
|
||||
len(
|
||||
pick_types(
|
||||
my_info,
|
||||
meg="mag",
|
||||
eeg=False,
|
||||
eog=False,
|
||||
ref_meg=False,
|
||||
exclude="bads",
|
||||
)
|
||||
)
|
||||
== 0
|
||||
):
|
||||
_safe_del_key(flat, "mag")
|
||||
if (
|
||||
len(
|
||||
pick_types(
|
||||
my_info,
|
||||
meg=False,
|
||||
eeg=True,
|
||||
eog=False,
|
||||
ref_meg=False,
|
||||
exclude="bads",
|
||||
)
|
||||
)
|
||||
== 0
|
||||
):
|
||||
_safe_del_key(flat, "eeg")
|
||||
if (
|
||||
len(
|
||||
pick_types(
|
||||
my_info,
|
||||
meg=False,
|
||||
eeg=False,
|
||||
eog=True,
|
||||
ref_meg=False,
|
||||
exclude="bads",
|
||||
)
|
||||
)
|
||||
== 0
|
||||
):
|
||||
_safe_del_key(flat, "eog")
|
||||
|
||||
# exclude bad channels from projection
|
||||
# keep reference channels if compensation channels are present
|
||||
ref_meg = len(my_info["comps"]) > 0
|
||||
picks = pick_types(
|
||||
my_info, meg=True, eeg=True, eog=True, ecg=True, ref_meg=ref_meg, exclude="bads"
|
||||
)
|
||||
|
||||
raw.filter(
|
||||
l_freq,
|
||||
h_freq,
|
||||
picks=picks,
|
||||
filter_length=filter_length,
|
||||
n_jobs=n_jobs,
|
||||
method=filter_method,
|
||||
iir_params=iir_params,
|
||||
l_trans_bandwidth=0.5,
|
||||
h_trans_bandwidth=0.5,
|
||||
phase="zero-double",
|
||||
fir_design="firwin2",
|
||||
)
|
||||
|
||||
epochs = Epochs(
|
||||
raw,
|
||||
events,
|
||||
None,
|
||||
tmin,
|
||||
tmax,
|
||||
baseline=None,
|
||||
preload=True,
|
||||
picks=picks,
|
||||
reject=reject,
|
||||
flat=flat,
|
||||
proj=True,
|
||||
)
|
||||
|
||||
drop_log = epochs.drop_log
|
||||
if epochs.events.shape[0] < 1:
|
||||
warn("No good epochs found")
|
||||
return ([], events) + ((drop_log,) if return_drop_log else ())
|
||||
|
||||
if average:
|
||||
evoked = epochs.average()
|
||||
ev_projs = compute_proj_evoked(
|
||||
evoked, n_grad=n_grad, n_mag=n_mag, n_eeg=n_eeg, meg=meg
|
||||
)
|
||||
else:
|
||||
ev_projs = compute_proj_epochs(
|
||||
epochs, n_grad=n_grad, n_mag=n_mag, n_eeg=n_eeg, n_jobs=n_jobs, meg=meg
|
||||
)
|
||||
|
||||
for p in ev_projs:
|
||||
p["desc"] = mode + "-" + p["desc"]
|
||||
|
||||
projs.extend(ev_projs)
|
||||
logger.info("Done.")
|
||||
return (projs, events) + ((drop_log,) if return_drop_log else ())
|
||||
|
||||
|
||||
@verbose
|
||||
def compute_proj_ecg(
|
||||
raw,
|
||||
raw_event=None,
|
||||
tmin=-0.2,
|
||||
tmax=0.4,
|
||||
n_grad=2,
|
||||
n_mag=2,
|
||||
n_eeg=2,
|
||||
l_freq=1.0,
|
||||
h_freq=35.0,
|
||||
average=True,
|
||||
filter_length="10s",
|
||||
n_jobs=None,
|
||||
ch_name=None,
|
||||
reject=dict(grad=2000e-13, mag=3000e-15, eeg=50e-6, eog=250e-6), # noqa: B006
|
||||
flat=None,
|
||||
bads=(),
|
||||
avg_ref=False,
|
||||
no_proj=False,
|
||||
event_id=999,
|
||||
ecg_l_freq=5,
|
||||
ecg_h_freq=35,
|
||||
tstart=0.0,
|
||||
qrs_threshold="auto",
|
||||
filter_method="fir",
|
||||
iir_params=None,
|
||||
copy=True,
|
||||
return_drop_log=False,
|
||||
meg="separate",
|
||||
verbose=None,
|
||||
):
|
||||
"""Compute SSP (signal-space projection) vectors for ECG artifacts.
|
||||
|
||||
%(compute_proj_ecg)s
|
||||
|
||||
.. note:: Raw data will be loaded if it hasn't been preloaded already.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : mne.io.Raw
|
||||
Raw input file.
|
||||
raw_event : mne.io.Raw or None
|
||||
Raw file to use for event detection (if None, raw is used).
|
||||
tmin : float
|
||||
Time before event in seconds.
|
||||
tmax : float
|
||||
Time after event in seconds.
|
||||
n_grad : int
|
||||
Number of SSP vectors for gradiometers.
|
||||
n_mag : int
|
||||
Number of SSP vectors for magnetometers.
|
||||
n_eeg : int
|
||||
Number of SSP vectors for EEG.
|
||||
l_freq : float | None
|
||||
Filter low cut-off frequency for the data channels in Hz.
|
||||
h_freq : float | None
|
||||
Filter high cut-off frequency for the data channels in Hz.
|
||||
average : bool
|
||||
Compute SSP after averaging. Default is True.
|
||||
filter_length : str | int | None
|
||||
Number of taps to use for filtering.
|
||||
%(n_jobs)s
|
||||
ch_name : str | None
|
||||
Channel to use for ECG detection (required if no ECG channel is found).
|
||||
reject : dict | None
|
||||
Epoch rejection configuration (see Epochs).
|
||||
flat : dict | None
|
||||
Epoch flat configuration (see Epochs).
|
||||
bads : list
|
||||
List with (additional) bad channels.
|
||||
avg_ref : bool
|
||||
Add EEG average reference proj.
|
||||
no_proj : bool
|
||||
Exclude the SSP projectors currently in the fiff file.
|
||||
event_id : int
|
||||
ID to use for events.
|
||||
ecg_l_freq : float
|
||||
Low pass frequency applied to the ECG channel for event detection.
|
||||
ecg_h_freq : float
|
||||
High pass frequency applied to the ECG channel for event detection.
|
||||
tstart : float
|
||||
Start artifact detection after tstart seconds.
|
||||
qrs_threshold : float | str
|
||||
Between 0 and 1. qrs detection threshold. Can also be "auto" to
|
||||
automatically choose the threshold that generates a reasonable
|
||||
number of heartbeats (40-160 beats / min).
|
||||
filter_method : str
|
||||
Method for filtering ('iir' or 'fir').
|
||||
iir_params : dict | None
|
||||
Dictionary of parameters to use for IIR filtering.
|
||||
See mne.filter.construct_iir_filter for details. If iir_params
|
||||
is None and method="iir", 4th order Butterworth will be used.
|
||||
copy : bool
|
||||
If False, filtering raw data is done in place. Defaults to True.
|
||||
return_drop_log : bool
|
||||
If True, return the drop log.
|
||||
|
||||
.. versionadded:: 0.15
|
||||
meg : str
|
||||
Can be ``'separate'`` (default) or ``'combined'`` to compute projectors
|
||||
for magnetometers and gradiometers separately or jointly.
|
||||
If ``'combined'``, ``n_mag == n_grad`` is required and the number of
|
||||
projectors computed for MEG will be ``n_mag``.
|
||||
|
||||
.. versionadded:: 0.18
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
%(projs)s
|
||||
ecg_events : ndarray
|
||||
Detected ECG events.
|
||||
drop_log : list
|
||||
The drop log, if requested.
|
||||
|
||||
See Also
|
||||
--------
|
||||
find_ecg_events
|
||||
create_ecg_epochs
|
||||
|
||||
Notes
|
||||
-----
|
||||
Filtering is applied to the ECG channel while finding events using
|
||||
``ecg_l_freq`` and ``ecg_h_freq``, and then to the ``raw`` instance
|
||||
using ``l_freq`` and ``h_freq`` before creation of the epochs used to
|
||||
create the projectors.
|
||||
"""
|
||||
return _compute_exg_proj(
|
||||
"ECG",
|
||||
raw,
|
||||
raw_event,
|
||||
tmin,
|
||||
tmax,
|
||||
n_grad,
|
||||
n_mag,
|
||||
n_eeg,
|
||||
l_freq,
|
||||
h_freq,
|
||||
average,
|
||||
filter_length,
|
||||
n_jobs,
|
||||
ch_name,
|
||||
reject,
|
||||
flat,
|
||||
bads,
|
||||
avg_ref,
|
||||
no_proj,
|
||||
event_id,
|
||||
ecg_l_freq,
|
||||
ecg_h_freq,
|
||||
tstart,
|
||||
qrs_threshold,
|
||||
filter_method,
|
||||
iir_params,
|
||||
return_drop_log,
|
||||
copy,
|
||||
meg,
|
||||
verbose,
|
||||
)
|
||||
|
||||
|
||||
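# --- illustrative usage sketch (not part of the original source) -------------
# Typical use of compute_proj_ecg: estimate ECG projectors and attach them to
# the raw recording. Using the MNE sample dataset here (and letting the ECG be
# synthesized from MEG, since that file has no dedicated ECG channel) is an
# assumption made only for illustration.
if __name__ == "__main__":
    import mne

    sample_fname = (
        mne.datasets.sample.data_path() / "MEG" / "sample" / "sample_audvis_raw.fif"
    )
    demo_raw = mne.io.read_raw_fif(sample_fname)
    ecg_projs, ecg_events = compute_proj_ecg(demo_raw, n_grad=1, n_mag=1, n_eeg=0)
    demo_raw.add_proj(ecg_projs)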
@verbose
|
||||
def compute_proj_eog(
|
||||
raw,
|
||||
raw_event=None,
|
||||
tmin=-0.2,
|
||||
tmax=0.2,
|
||||
n_grad=2,
|
||||
n_mag=2,
|
||||
n_eeg=2,
|
||||
l_freq=1.0,
|
||||
h_freq=35.0,
|
||||
average=True,
|
||||
filter_length="10s",
|
||||
n_jobs=None,
|
||||
reject=dict(grad=2000e-13, mag=3000e-15, eeg=500e-6, eog=np.inf), # noqa: B006
|
||||
flat=None,
|
||||
bads=(),
|
||||
avg_ref=False,
|
||||
no_proj=False,
|
||||
event_id=998,
|
||||
eog_l_freq=1,
|
||||
eog_h_freq=10,
|
||||
tstart=0.0,
|
||||
filter_method="fir",
|
||||
iir_params=None,
|
||||
ch_name=None,
|
||||
copy=True,
|
||||
return_drop_log=False,
|
||||
meg="separate",
|
||||
verbose=None,
|
||||
):
|
||||
"""Compute SSP (signal-space projection) vectors for EOG artifacts.
|
||||
|
||||
%(compute_proj_eog)s
|
||||
|
||||
.. note:: Raw data will be loaded if it hasn't been preloaded already.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : mne.io.Raw
|
||||
Raw input file.
|
||||
raw_event : mne.io.Raw or None
|
||||
Raw file to use for event detection (if None, raw is used).
|
||||
tmin : float
|
||||
Time before event in seconds.
|
||||
tmax : float
|
||||
Time after event in seconds.
|
||||
n_grad : int
|
||||
Number of SSP vectors for gradiometers.
|
||||
n_mag : int
|
||||
Number of SSP vectors for magnetometers.
|
||||
n_eeg : int
|
||||
Number of SSP vectors for EEG.
|
||||
l_freq : float | None
|
||||
Filter low cut-off frequency for the data channels in Hz.
|
||||
h_freq : float | None
|
||||
Filter high cut-off frequency for the data channels in Hz.
|
||||
average : bool
|
||||
Compute SSP after averaging. Default is True.
|
||||
filter_length : str | int | None
|
||||
Number of taps to use for filtering.
|
||||
%(n_jobs)s
|
||||
reject : dict | None
|
||||
Epoch rejection configuration (see Epochs).
|
||||
flat : dict | None
|
||||
Epoch flat configuration (see Epochs).
|
||||
bads : list
|
||||
List with (additional) bad channels.
|
||||
avg_ref : bool
|
||||
Add EEG average reference proj.
|
||||
no_proj : bool
|
||||
Exclude the SSP projectors currently in the fiff file.
|
||||
event_id : int
|
||||
ID to use for events.
|
||||
eog_l_freq : float
|
||||
Low pass frequency applied to the EOG channel for event detection.
|
||||
eog_h_freq : float
|
||||
High pass frequency applied to the EOG channel for event detection.
|
||||
tstart : float
|
||||
Start artifact detection after tstart seconds.
|
||||
filter_method : str
|
||||
Method for filtering ('iir' or 'fir').
|
||||
iir_params : dict | None
|
||||
Dictionary of parameters to use for IIR filtering.
|
||||
See mne.filter.construct_iir_filter for details. If iir_params
|
||||
is None and method="iir", 4th order Butterworth will be used.
|
||||
ch_name : str | None
|
||||
If not None, specify EOG channel name.
|
||||
copy : bool
|
||||
If False, filtering raw data is done in place. Defaults to True.
|
||||
return_drop_log : bool
|
||||
If True, return the drop log.
|
||||
|
||||
.. versionadded:: 0.15
|
||||
meg : str
|
||||
Can be 'separate' (default) or 'combined' to compute projectors
|
||||
for magnetometers and gradiometers separately or jointly.
|
||||
If 'combined', ``n_mag == n_grad`` is required and the number of
|
||||
projectors computed for MEG will be ``n_mag``.
|
||||
|
||||
.. versionadded:: 0.18
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
%(projs)s
|
||||
eog_events : ndarray
|
||||
Detected EOG events.
|
||||
drop_log : list
|
||||
The drop log, if requested.
|
||||
|
||||
See Also
|
||||
--------
|
||||
find_eog_events
|
||||
create_eog_epochs
|
||||
|
||||
Notes
|
||||
-----
|
||||
Filtering is applied to the EOG channel while finding events using
|
||||
``eog_l_freq`` and ``eog_h_freq``, and then to the ``raw`` instance
|
||||
using ``l_freq`` and ``h_freq`` before creation of the epochs used to
|
||||
create the projectors.
|
||||
"""
|
||||
return _compute_exg_proj(
|
||||
"EOG",
|
||||
raw,
|
||||
raw_event,
|
||||
tmin,
|
||||
tmax,
|
||||
n_grad,
|
||||
n_mag,
|
||||
n_eeg,
|
||||
l_freq,
|
||||
h_freq,
|
||||
average,
|
||||
filter_length,
|
||||
n_jobs,
|
||||
ch_name,
|
||||
reject,
|
||||
flat,
|
||||
bads,
|
||||
avg_ref,
|
||||
no_proj,
|
||||
event_id,
|
||||
eog_l_freq,
|
||||
eog_h_freq,
|
||||
tstart,
|
||||
"auto",
|
||||
filter_method,
|
||||
iir_params,
|
||||
return_drop_log,
|
||||
copy,
|
||||
meg,
|
||||
verbose,
|
||||
)
|
||||
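# --- illustrative usage sketch (not part of the original source) -------------
# compute_proj_eog follows the same pattern as compute_proj_ecg above; the MNE
# sample dataset (which contains an "EOG 061" channel) is assumed here purely
# for illustration.
if __name__ == "__main__":
    import mne

    sample_fname = (
        mne.datasets.sample.data_path() / "MEG" / "sample" / "sample_audvis_raw.fif"
    )
    demo_raw = mne.io.read_raw_fif(sample_fname)
    eog_projs, eog_events = compute_proj_eog(demo_raw, n_grad=1, n_mag=1, n_eeg=1)
    demo_raw.add_proj(eog_projs)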
176
mne/preprocessing/stim.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
from scipy.interpolate import interp1d
|
||||
from scipy.signal.windows import hann
|
||||
|
||||
from .._fiff.pick import _picks_to_idx
|
||||
from ..epochs import BaseEpochs
|
||||
from ..event import find_events
|
||||
from ..evoked import Evoked
|
||||
from ..io import BaseRaw
|
||||
from ..utils import _check_option, _check_preload, _validate_type, fill_doc
|
||||
|
||||
|
||||
def _get_window(start, end):
|
||||
"""Return window which has length as much as parameter start - end."""
|
||||
window = 1 - np.r_[hann(4)[:2], np.ones(np.abs(end - start) - 4), hann(4)[-2:]].T
|
||||
return window
|
||||
|
||||
|
||||
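# --- illustrative check (not part of the original source) --------------------
# For a 10-sample artifact span the helper above returns a (1 - Hann) taper:
# the data are left untouched at the span edges and zeroed in the middle.
_demo_window = _get_window(0, 10)
assert _demo_window.shape == (10,)
assert _demo_window[0] == 1.0 and _demo_window[-1] == 1.0
np.testing.assert_allclose(_demo_window[2:-2], 0.0)
# ------------------------------------------------------------------------------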
def _fix_artifact(
|
||||
data, window, picks, first_samp, last_samp, base_tmin, base_tmax, mode
|
||||
):
|
||||
"""Modify original data by using parameter data."""
|
||||
if mode == "linear":
|
||||
x = np.array([first_samp, last_samp])
|
||||
f = interp1d(x, data[:, (first_samp, last_samp)][picks])
|
||||
xnew = np.arange(first_samp, last_samp)
|
||||
interp_data = f(xnew)
|
||||
data[picks, first_samp:last_samp] = interp_data
|
||||
if mode == "window":
|
||||
data[picks, first_samp:last_samp] = (
|
||||
data[picks, first_samp:last_samp] * window[np.newaxis, :]
|
||||
)
|
||||
if mode == "constant":
|
||||
data[picks, first_samp:last_samp] = data[picks, base_tmin:base_tmax].mean(
|
||||
axis=1
|
||||
)[:, None]
|
||||
|
||||
|
||||
@fill_doc
|
||||
def fix_stim_artifact(
|
||||
inst,
|
||||
events=None,
|
||||
event_id=None,
|
||||
tmin=0.0,
|
||||
tmax=0.01,
|
||||
*,
|
||||
baseline=None,
|
||||
mode="linear",
|
||||
stim_channel=None,
|
||||
picks=None,
|
||||
):
|
||||
"""Eliminate stimulation's artifacts from instance.
|
||||
|
||||
.. note:: This function operates in-place, consider passing
|
||||
``inst.copy()`` if this is not desired.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : instance of Raw or Epochs or Evoked
|
||||
The data.
|
||||
events : array, shape (n_events, 3)
|
||||
The list of events. Required only when inst is Raw.
|
||||
event_id : int
|
||||
The id of the events generating the stimulation artifacts.
|
||||
If None, read all events. Required only when inst is Raw.
|
||||
tmin : float
|
||||
Start time of the interpolation window in seconds.
|
||||
tmax : float
|
||||
End time of the interpolation window in seconds.
|
||||
baseline : None | tuple, shape (2,)
|
||||
The baseline to use when ``mode='constant'``, in which case it
|
||||
must be non-None.
|
||||
|
||||
.. versionadded:: 1.8
|
||||
mode : 'linear' | 'window' | 'constant'
|
||||
Way to fill the artifacted time interval.
|
||||
|
||||
``"linear"``
|
||||
Does linear interpolation.
|
||||
``"window"``
|
||||
Applies a ``(1 - hanning)`` window.
|
||||
``"constant"``
|
||||
Uses the baseline average. The ``baseline`` parameter must be provided.
|
||||
|
||||
.. versionchanged:: 1.8
|
||||
Added the ``"constant"`` mode.
|
||||
stim_channel : str | None
|
||||
Stim channel to use.
|
||||
%(picks_all_data)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
inst : instance of Raw or Evoked or Epochs
|
||||
Instance with modified data.
|
||||
"""
|
||||
_check_option("mode", mode, ["linear", "window", "constant"])
|
||||
s_start = int(np.ceil(inst.info["sfreq"] * tmin))
|
||||
s_end = int(np.ceil(inst.info["sfreq"] * tmax))
|
||||
if mode == "constant":
|
||||
_validate_type(
|
||||
baseline, (tuple, list), "baseline", extra="when mode='constant'"
|
||||
)
|
||||
_check_option("len(baseline)", len(baseline), [2])
|
||||
for bi, b in enumerate(baseline):
|
||||
_validate_type(
|
||||
b, "numeric", f"baseline[{bi}]", extra="when mode='constant'"
|
||||
)
|
||||
b_start = int(np.ceil(inst.info["sfreq"] * baseline[0]))
|
||||
b_end = int(np.ceil(inst.info["sfreq"] * baseline[1]))
|
||||
else:
|
||||
b_start = b_end = np.nan
|
||||
if (mode == "window") and (s_end - s_start) < 4:
|
||||
raise ValueError(
|
||||
'Time range is too short. Use a larger interval or set mode to "linear".'
|
||||
)
|
||||
window = None
|
||||
if mode == "window":
|
||||
window = _get_window(s_start, s_end)
|
||||
|
||||
picks = _picks_to_idx(inst.info, picks, "data", exclude=())
|
||||
|
||||
_check_preload(inst, "fix_stim_artifact")
|
||||
if isinstance(inst, BaseRaw):
|
||||
if events is None:
|
||||
events = find_events(inst, stim_channel=stim_channel)
|
||||
if len(events) == 0:
|
||||
raise ValueError("No events are found")
|
||||
if event_id is None:
|
||||
events_sel = np.arange(len(events))
|
||||
else:
|
||||
events_sel = events[:, 2] == event_id
|
||||
event_start = events[events_sel, 0]
|
||||
data = inst._data
|
||||
for event_idx in event_start:
|
||||
first_samp = int(event_idx) - inst.first_samp + s_start
|
||||
last_samp = int(event_idx) - inst.first_samp + s_end
|
||||
base_t1 = int(event_idx) - inst.first_samp + b_start
|
||||
base_t2 = int(event_idx) - inst.first_samp + b_end
|
||||
_fix_artifact(
|
||||
data, window, picks, first_samp, last_samp, base_t1, base_t2, mode
|
||||
)
|
||||
elif isinstance(inst, BaseEpochs):
|
||||
if inst.reject is not None:
|
||||
raise RuntimeError(
|
||||
"Reject is already applied. Use reject=None in the constructor."
|
||||
)
|
||||
e_start = int(np.ceil(inst.info["sfreq"] * inst.tmin))
|
||||
first_samp = s_start - e_start
|
||||
last_samp = s_end - e_start
|
||||
data = inst._data
|
||||
base_t1 = b_start - e_start
|
||||
base_t2 = b_end - e_start
|
||||
for epoch in data:
|
||||
_fix_artifact(
|
||||
epoch, window, picks, first_samp, last_samp, base_t1, base_t2, mode
|
||||
)
|
||||
|
||||
elif isinstance(inst, Evoked):
|
||||
first_samp = s_start - inst.first
|
||||
last_samp = s_end - inst.first
|
||||
data = inst.data
|
||||
base_t1 = b_start - inst.first
|
||||
base_t2 = b_end - inst.first
|
||||
|
||||
_fix_artifact(
|
||||
data, window, picks, first_samp, last_samp, base_t1, base_t2, mode
|
||||
)
|
||||
|
||||
else:
|
||||
raise TypeError(f"Not a Raw or Epochs or Evoked (got {type(inst)}).")
|
||||
|
||||
return inst
|
||||
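# --- illustrative usage sketch (not part of the original source) -------------
# Typical use of fix_stim_artifact on a preloaded Raw with stimulation events.
# The MNE sample dataset and its "STI 014" trigger channel are assumptions made
# for illustration only.
if __name__ == "__main__":
    import mne

    sample_fname = (
        mne.datasets.sample.data_path() / "MEG" / "sample" / "sample_audvis_raw.fif"
    )
    demo_raw = mne.io.read_raw_fif(sample_fname, preload=True)
    demo_events = mne.find_events(demo_raw, stim_channel="STI 014")
    demo_raw = fix_stim_artifact(
        demo_raw, events=demo_events, tmin=0.0, tmax=0.01, mode="linear"
    )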
682
mne/preprocessing/xdawn.py
Normal file
@@ -0,0 +1,682 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
from scipy import linalg
|
||||
|
||||
from .._fiff.pick import _pick_data_channels, pick_info
|
||||
from ..cov import Covariance, _regularized_covariance
|
||||
from ..decoding import BaseEstimator, TransformerMixin
|
||||
from ..epochs import BaseEpochs
|
||||
from ..evoked import Evoked, EvokedArray
|
||||
from ..io import BaseRaw
|
||||
from ..utils import _check_option, logger, pinv
|
||||
|
||||
|
||||
def _construct_signal_from_epochs(epochs, events, sfreq, tmin):
|
||||
"""Reconstruct pseudo continuous signal from epochs."""
|
||||
n_epochs, n_channels, n_times = epochs.shape
|
||||
tmax = tmin + n_times / float(sfreq)
|
||||
start = np.min(events[:, 0]) + int(tmin * sfreq)
|
||||
stop = np.max(events[:, 0]) + int(tmax * sfreq) + 1
|
||||
|
||||
n_samples = stop - start
|
||||
n_epochs, n_channels, n_times = epochs.shape
|
||||
events_pos = events[:, 0] - events[0, 0]
|
||||
|
||||
raw = np.zeros((n_channels, n_samples))
|
||||
for idx in range(n_epochs):
|
||||
onset = events_pos[idx]
|
||||
offset = onset + n_times
|
||||
raw[:, onset:offset] = epochs[idx]
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def _least_square_evoked(epochs_data, events, tmin, sfreq):
|
||||
"""Least square estimation of evoked response from epochs data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
epochs_data : array, shape (n_epochs, n_channels, n_times)
|
||||
The epochs data from which to estimate the evoked responses.
|
||||
events : array, shape (n_events, 3)
|
||||
The events typically returned by the read_events function.
|
||||
If some events don't match the events of interest as specified
|
||||
by event_id, they will be ignored.
|
||||
tmin : float
|
||||
Start time before event.
|
||||
sfreq : float
|
||||
Sampling frequency.
|
||||
|
||||
Returns
|
||||
-------
|
||||
evokeds : array, shape (n_class, n_components, n_times)
|
||||
A concatenated array of evoked data for each event type.
|
||||
toeplitz : array, shape (n_class * n_components, n_channels)
|
||||
A concatenated array of Toeplitz matrices, one per event type.
|
||||
"""
|
||||
n_epochs, n_channels, n_times = epochs_data.shape
|
||||
tmax = tmin + n_times / float(sfreq)
|
||||
|
||||
# Deal with shuffled epochs
|
||||
events = events.copy()
|
||||
events[:, 0] -= events[0, 0] + int(tmin * sfreq)
|
||||
|
||||
# Construct raw signal
|
||||
raw = _construct_signal_from_epochs(epochs_data, events, sfreq, tmin)
|
||||
|
||||
# Compute the independent evoked responses per condition, while correcting
|
||||
# for event overlaps.
|
||||
n_min, n_max = int(tmin * sfreq), int(tmax * sfreq)
|
||||
window = n_max - n_min
|
||||
n_samples = raw.shape[1]
|
||||
toeplitz = list()
|
||||
classes = np.unique(events[:, 2])
|
||||
for ii, this_class in enumerate(classes):
|
||||
# select events by type
|
||||
sel = events[:, 2] == this_class
|
||||
|
||||
# build toeplitz matrix
|
||||
trig = np.zeros((n_samples,))
|
||||
ix_trig = (events[sel, 0]) + n_min
|
||||
trig[ix_trig] = 1
|
||||
toeplitz.append(linalg.toeplitz(trig[0:window], trig))
|
||||
|
||||
# Concatenate toeplitz
|
||||
toeplitz = np.array(toeplitz)
|
||||
X = np.concatenate(toeplitz)
|
||||
|
||||
# least square estimation
|
||||
predictor = np.dot(pinv(np.dot(X, X.T)), X)
|
||||
evokeds = np.dot(predictor, raw.T)
|
||||
evokeds = np.transpose(np.vsplit(evokeds, len(classes)), (0, 2, 1))
|
||||
return evokeds, toeplitz
|
||||
|
||||
|
||||
def _fit_xdawn(
|
||||
epochs_data,
|
||||
y,
|
||||
n_components,
|
||||
reg=None,
|
||||
signal_cov=None,
|
||||
events=None,
|
||||
tmin=0.0,
|
||||
sfreq=1.0,
|
||||
method_params=None,
|
||||
info=None,
|
||||
):
|
||||
"""Fit filters and coefs using Xdawn Algorithm.
|
||||
|
||||
Xdawn is a spatial filtering method designed to improve the signal
|
||||
to signal + noise ratio (SSNR) of the event related responses. Xdawn was
|
||||
originally designed for P300 evoked potential by enhancing the target
|
||||
response with respect to the non-target response. This implementation is a
|
||||
generalization to any type of event related response.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
epochs_data : array, shape (n_epochs, n_channels, n_times)
|
||||
The epochs data.
|
||||
y : array, shape (n_epochs)
|
||||
The epochs class.
|
||||
n_components : int (default 2)
|
||||
The number of components to decompose the signals.
|
||||
reg : float | str | None (default None)
|
||||
If not None (same as ``'empirical'``, default), allow
|
||||
regularization for covariance estimation.
|
||||
If float, shrinkage is used (0 <= shrinkage <= 1).
|
||||
For str options, ``reg`` will be passed as ``method`` to
|
||||
:func:`mne.compute_covariance`.
|
||||
signal_cov : None | Covariance | array, shape (n_channels, n_channels)
|
||||
The signal covariance used for whitening of the data.
|
||||
if None, the covariance is estimated from the epochs signal.
|
||||
events : array, shape (n_epochs, 3)
|
||||
The epochs events, used to correct for epochs overlap.
|
||||
tmin : float
|
||||
Epochs starting time. Only used if events is passed to correct for
|
||||
epochs overlap.
|
||||
sfreq : float
|
||||
Sampling frequency. Only used if events is passed to correct for
|
||||
epochs overlap.
|
||||
|
||||
Returns
|
||||
-------
|
||||
filters : array, shape (n_channels, n_channels)
|
||||
The Xdawn components used to decompose the data for each event type.
|
||||
Each row corresponds to one component.
|
||||
patterns : array, shape (n_channels, n_channels)
|
||||
The Xdawn patterns used to restore the signals for each event type.
|
||||
evokeds : array, shape (n_class, n_components, n_times)
|
||||
The independent evoked responses per condition.
|
||||
"""
|
||||
if not isinstance(epochs_data, np.ndarray) or epochs_data.ndim != 3:
|
||||
raise ValueError("epochs_data must be 3D ndarray")
|
||||
|
||||
classes = np.unique(y)
|
||||
|
||||
# XXX Eventually this could be made to deal with rank deficiency properly
|
||||
# by exposing this "rank" parameter, but this will require refactoring
|
||||
# the linalg.eigh call to operate in the lower-dimension
|
||||
# subspace, then project back out.
|
||||
|
||||
# Retrieve or compute whitening covariance
|
||||
if signal_cov is None:
|
||||
signal_cov = _regularized_covariance(
|
||||
np.hstack(epochs_data), reg, method_params, info, rank="full"
|
||||
)
|
||||
elif isinstance(signal_cov, Covariance):
|
||||
signal_cov = signal_cov.data
|
||||
if not isinstance(signal_cov, np.ndarray) or (
|
||||
not np.array_equal(signal_cov.shape, np.tile(epochs_data.shape[1], 2))
|
||||
):
|
||||
raise ValueError(
|
||||
"signal_cov must be None, a covariance instance, "
|
||||
"or an array of shape (n_chans, n_chans)"
|
||||
)
|
||||
|
||||
# Get prototype events
|
||||
if events is not None:
|
||||
evokeds, toeplitzs = _least_square_evoked(epochs_data, events, tmin, sfreq)
|
||||
else:
|
||||
evokeds, toeplitzs = list(), list()
|
||||
for c in classes:
|
||||
# Prototyped response for each class
|
||||
evokeds.append(np.mean(epochs_data[y == c, :, :], axis=0))
|
||||
toeplitzs.append(1.0)
|
||||
|
||||
filters = list()
|
||||
patterns = list()
|
||||
for evo, toeplitz in zip(evokeds, toeplitzs):
|
||||
# Estimate covariance matrix of the prototype response
|
||||
evo = np.dot(evo, toeplitz)
|
||||
evo_cov = _regularized_covariance(evo, reg, method_params, info, rank="full")
|
||||
|
||||
# Fit spatial filters
|
||||
try:
|
||||
evals, evecs = linalg.eigh(evo_cov, signal_cov)
|
||||
except np.linalg.LinAlgError as exp:
|
||||
raise ValueError(
|
||||
"Could not compute eigenvalues, ensure "
|
||||
f"proper regularization ({exp})"
|
||||
)
|
||||
evecs = evecs[:, np.argsort(evals)[::-1]] # sort eigenvectors
|
||||
evecs /= np.apply_along_axis(np.linalg.norm, 0, evecs)
|
||||
_patterns = np.linalg.pinv(evecs.T)
|
||||
filters.append(evecs[:, :n_components].T)
|
||||
patterns.append(_patterns[:, :n_components].T)
|
||||
|
||||
filters = np.concatenate(filters, axis=0)
|
||||
patterns = np.concatenate(patterns, axis=0)
|
||||
evokeds = np.array(evokeds)
|
||||
return filters, patterns, evokeds
|
||||
|
||||
|
||||
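# --- illustrative check (not part of the original source) --------------------
# The heart of _fit_xdawn is the generalized eigendecomposition above: each
# spatial filter w maximizes the ratio (w' @ evoked_cov @ w) / (w' @ signal_cov
# @ w). A tiny synthetic check that the top generalized eigenvector attains its
# eigenvalue as that ratio:
_demo_rng = np.random.default_rng(3)
_demo_evo = _demo_rng.standard_normal((4, 200))
_demo_noise = _demo_rng.standard_normal((4, 500))
_demo_evo_cov = _demo_evo @ _demo_evo.T / 200
_demo_sig_cov = _demo_noise @ _demo_noise.T / 500
_demo_vals, _demo_vecs = linalg.eigh(_demo_evo_cov, _demo_sig_cov)
_demo_w = _demo_vecs[:, np.argmax(_demo_vals)]
_demo_ratio = (_demo_w @ _demo_evo_cov @ _demo_w) / (_demo_w @ _demo_sig_cov @ _demo_w)
assert np.isclose(_demo_ratio, _demo_vals.max())
# ------------------------------------------------------------------------------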
class _XdawnTransformer(BaseEstimator, TransformerMixin):
|
||||
"""Implementation of the Xdawn Algorithm compatible with scikit-learn.
|
||||
|
||||
Xdawn is a spatial filtering method designed to improve the signal
|
||||
to signal + noise ratio (SSNR) of the event related responses. Xdawn was
|
||||
originally designed for P300 evoked potential by enhancing the target
|
||||
response with respect to the non-target response. This implementation is a
|
||||
generalization to any type of event related response.
|
||||
|
||||
.. note:: _XdawnTransformer does not correct for epochs overlap. To correct
|
||||
overlaps see ``Xdawn``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_components : int (default 2)
|
||||
The number of components to decompose the signals.
|
||||
reg : float | str | None (default None)
|
||||
If not None (same as ``'empirical'``, default), allow
|
||||
regularization for covariance estimation.
|
||||
If float, shrinkage is used (0 <= shrinkage <= 1).
|
||||
For str options, ``reg`` will be passed as ``method`` to
|
||||
:func:`mne.compute_covariance`.
|
||||
signal_cov : None | Covariance | array, shape (n_channels, n_channels)
|
||||
The signal covariance used for whitening of the data.
|
||||
if None, the covariance is estimated from the epochs signal.
|
||||
method_params : dict | None
|
||||
Parameters to pass to :func:`mne.compute_covariance`.
|
||||
|
||||
.. versionadded:: 0.16
|
||||
|
||||
Attributes
|
||||
----------
|
||||
classes_ : array, shape (n_classes)
|
||||
The event indices of the classes.
|
||||
filters_ : array, shape (n_channels, n_channels)
|
||||
The Xdawn components used to decompose the data for each event type.
|
||||
patterns_ : array, shape (n_channels, n_channels)
|
||||
The Xdawn patterns used to restore the signals for each event type.
|
||||
"""
|
||||
|
||||
def __init__(self, n_components=2, reg=None, signal_cov=None, method_params=None):
|
||||
"""Init."""
|
||||
self.n_components = n_components
|
||||
self.signal_cov = signal_cov
|
||||
self.reg = reg
|
||||
self.method_params = method_params
|
||||
|
||||
def fit(self, X, y=None):
|
||||
"""Fit Xdawn spatial filters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array, shape (n_epochs, n_channels, n_samples)
|
||||
The target data.
|
||||
y : array, shape (n_epochs,) | None
|
||||
The target labels. If None, Xdawn is fitted on the average evoked response.
|
||||
|
||||
Returns
|
||||
-------
|
||||
self : Xdawn instance
|
||||
The Xdawn instance.
|
||||
"""
|
||||
X, y = self._check_Xy(X, y)
|
||||
|
||||
# Main function
|
||||
self.classes_ = np.unique(y)
|
||||
self.filters_, self.patterns_, _ = _fit_xdawn(
|
||||
X,
|
||||
y,
|
||||
n_components=self.n_components,
|
||||
reg=self.reg,
|
||||
signal_cov=self.signal_cov,
|
||||
method_params=self.method_params,
|
||||
)
|
||||
return self
|
||||
|
||||
def transform(self, X):
|
||||
"""Transform data with spatial filters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array, shape (n_epochs, n_channels, n_samples)
|
||||
The target data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
X : array, shape (n_epochs, n_components * n_classes, n_samples)
|
||||
The transformed data.
|
||||
"""
|
||||
X, _ = self._check_Xy(X)
|
||||
|
||||
# Check size
|
||||
if self.filters_.shape[1] != X.shape[1]:
|
||||
raise ValueError(
|
||||
f"X must have {self.filters_.shape[1]} channels, got {X.shape[1]} "
|
||||
"instead."
|
||||
)
|
||||
|
||||
# Transform
|
||||
X = np.dot(self.filters_, X)
|
||||
X = X.transpose((1, 0, 2))
|
||||
return X
|
||||
|
||||
def inverse_transform(self, X):
|
||||
"""Remove selected components from the signal.
|
||||
|
||||
Given the unmixing matrix, transform data, zero out components,
|
||||
and inverse transform the data. This procedure will reconstruct
|
||||
the signals from which the dynamics described by the excluded
|
||||
components are subtracted.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array, shape (n_epochs, n_components * n_classes, n_times)
|
||||
The transformed data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
X : array, shape (n_epochs, n_channels * n_classes, n_times)
|
||||
The inverse transform data.
|
||||
"""
|
||||
# Check size
|
||||
X, _ = self._check_Xy(X)
|
||||
n_epochs, n_comp, n_times = X.shape
|
||||
if n_comp != (self.n_components * len(self.classes_)):
|
||||
raise ValueError(
|
||||
f"X must have {self.n_components * len(self.classes_)} components, "
|
||||
f"got {n_comp} instead."
|
||||
)
|
||||
|
||||
# Transform
|
||||
return np.dot(self.patterns_.T, X).transpose(1, 0, 2)
|
||||
|
||||
def _check_Xy(self, X, y=None):
|
||||
"""Check X and y types and dimensions."""
|
||||
# Check data
|
||||
if not isinstance(X, np.ndarray) or X.ndim != 3:
|
||||
raise ValueError(
|
||||
"X must be an array of shape (n_epochs, n_channels, n_samples)."
|
||||
)
|
||||
if y is None:
|
||||
y = np.ones(len(X))
|
||||
y = np.asarray(y)
|
||||
if len(X) != len(y):
|
||||
raise ValueError("X and y must have the same length")
|
||||
return X, y
|
||||
|
||||
|
||||
class Xdawn(_XdawnTransformer):
|
||||
"""Implementation of the Xdawn Algorithm.
|
||||
|
||||
Xdawn :footcite:`RivetEtAl2009,RivetEtAl2011` is a spatial
|
||||
filtering method designed to improve the signal to signal + noise
|
||||
ratio (SSNR) of the ERP responses. Xdawn was originally designed for
|
||||
P300 evoked potential by enhancing the target response with respect
|
||||
to the non-target response. This implementation is a generalization
|
||||
to any type of ERP.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_components : int, (default 2)
|
||||
The number of components to decompose the signals.
|
||||
signal_cov : None | Covariance | ndarray, shape (n_channels, n_channels)
|
||||
(default None). The signal covariance used for whitening of the data.
|
||||
if None, the covariance is estimated from the epochs signal.
|
||||
correct_overlap : 'auto' or bool (default 'auto')
|
||||
Compute the independent evoked responses per condition, while
|
||||
correcting for event overlaps if any. If 'auto', then
|
||||
overlap_correction = True if the events do overlap.
|
||||
reg : float | str | None (default None)
|
||||
If not None (same as ``'empirical'``, default), allow
|
||||
regularization for covariance estimation.
|
||||
If float, shrinkage is used (0 <= shrinkage <= 1).
|
||||
For str options, ``reg`` will be passed as ``method`` to
|
||||
:func:`mne.compute_covariance`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
filters_ : dict of ndarray
|
||||
If fit, the Xdawn components used to decompose the data for each event
|
||||
type, else empty. For each event type, the filters are in the rows of
|
||||
the corresponding array.
|
||||
patterns_ : dict of ndarray
|
||||
If fit, the Xdawn patterns used to restore the signals for each event
|
||||
type, else empty.
|
||||
evokeds_ : dict of Evoked
|
||||
If fit, the evoked response for each event type.
|
||||
event_id_ : dict
|
||||
The event id.
|
||||
correct_overlap_ : bool
|
||||
Whether overlap correction was applied.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.decoding.CSP, mne.decoding.SPoC
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 0.10
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, n_components=2, signal_cov=None, correct_overlap="auto", reg=None
|
||||
):
|
||||
"""Init."""
|
||||
super().__init__(n_components=n_components, signal_cov=signal_cov, reg=reg)
|
||||
self.correct_overlap = _check_option(
|
||||
"correct_overlap", correct_overlap, ["auto", True, False]
|
||||
)
|
||||
|
||||
def fit(self, epochs, y=None):
|
||||
"""Fit Xdawn from epochs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
epochs : instance of Epochs
|
||||
An instance of Epochs on which Xdawn filters will be fitted.
|
||||
y : ndarray | None (default None)
|
||||
If None, uses epochs.events[:, 2].
|
||||
|
||||
Returns
|
||||
-------
|
||||
self : instance of Xdawn
|
||||
The Xdawn instance.
|
||||
"""
|
||||
# Check data
|
||||
if not isinstance(epochs, BaseEpochs):
|
||||
raise ValueError("epochs must be an Epochs object.")
|
||||
picks = _pick_data_channels(epochs.info)
|
||||
use_info = pick_info(epochs.info, picks)
|
||||
X = epochs.get_data(picks)
|
||||
y = epochs.events[:, 2] if y is None else y
|
||||
self.event_id_ = epochs.event_id
|
||||
|
||||
# Check that no baseline was applied with correct overlap
|
||||
correct_overlap = self.correct_overlap
|
||||
if correct_overlap == "auto":
|
||||
# Events are overlapped if the minimal inter-stimulus
|
||||
# interval is smaller than the time window.
|
||||
isi = np.diff(np.sort(epochs.events[:, 0]))
|
||||
window = int((epochs.tmax - epochs.tmin) * epochs.info["sfreq"])
|
||||
correct_overlap = isi.min() < window

        if epochs.baseline and correct_overlap:
            raise ValueError("Cannot apply correct_overlap if epochs were baselined.")

        events, tmin, sfreq = None, 0.0, 1.0
        if correct_overlap:
            events = epochs.events
            tmin = epochs.tmin
            sfreq = epochs.info["sfreq"]
        self.correct_overlap_ = correct_overlap

        # Note: In this original version of Xdawn we compute and keep all
        # components. The selection comes at transform().
        n_components = X.shape[1]

        # Main fitting function
        filters, patterns, evokeds = _fit_xdawn(
            X,
            y,
            n_components=n_components,
            reg=self.reg,
            signal_cov=self.signal_cov,
            events=events,
            tmin=tmin,
            sfreq=sfreq,
            method_params=self.method_params,
            info=use_info,
        )

        # Re-order filters and patterns according to event_id
        filters = filters.reshape(-1, n_components, filters.shape[-1])
        patterns = patterns.reshape(-1, n_components, patterns.shape[-1])
        self.filters_, self.patterns_, self.evokeds_ = dict(), dict(), dict()
        idx = np.argsort([value for _, value in epochs.event_id.items()])
        for eid, this_filter, this_pattern, this_evo in zip(
            epochs.event_id, filters[idx], patterns[idx], evokeds[idx]
        ):
            self.filters_[eid] = this_filter
            self.patterns_[eid] = this_pattern
            n_events = len(epochs[eid])
            evoked = EvokedArray(
                this_evo, use_info, tmin=epochs.tmin, comment=eid, nave=n_events
            )
            self.evokeds_[eid] = evoked
        return self

    def transform(self, inst):
        """Apply Xdawn dimensionality reduction.

        Parameters
        ----------
        inst : Epochs | Evoked | ndarray, shape ([n_epochs, ]n_channels, n_times)
            Data on which Xdawn filters will be applied.

        Returns
        -------
        X : ndarray, shape ([n_epochs, ]n_components * n_event_types, n_times)
            Spatially filtered signals.
        """  # noqa: E501
        if isinstance(inst, BaseEpochs):
            X = inst.get_data(copy=False)
        elif isinstance(inst, Evoked):
            X = inst.data
        elif isinstance(inst, np.ndarray):
            X = inst
            if X.ndim not in (2, 3):
                raise ValueError(f"X must be 2D or 3D, got {X.ndim}")
        else:
            raise ValueError("Data input must be Epochs, Evoked, or a NumPy array")

        filters = [filt[: self.n_components] for filt in self.filters_.values()]
        filters = np.concatenate(filters, axis=0)
        X = np.dot(filters, X)
        if X.ndim == 3:
            X = X.transpose((1, 0, 2))
        return X
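
    # Shape sketch for ``transform`` (hypothetical numbers): with 2 event
    # types, n_components=2 and epochs data of shape (40, 64, 120), the
    # stacked filters have shape (4, 64) and the output has shape
    # (40, 4, 120); a 2D Evoked/array input yields (4, n_times) instead.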

    def apply(self, inst, event_id=None, include=None, exclude=None):
        """Remove selected components from the signal.

        Given the unmixing matrix, transform the data, zero out the selected
        components, and inverse-transform the data. This procedure
        reconstructs the signals from which the dynamics described by the
        excluded components have been subtracted.

        Parameters
        ----------
        inst : instance of Raw | Epochs | Evoked
            The data to be processed.
        event_id : dict | list of str | None (default None)
            The kind of event to apply. If None, a dict of instances is
            returned, one for each event type Xdawn has been fitted on.
        include : array_like of int | None (default None)
            The indices referring to columns in the unmixing matrix. The
            components to be kept. If None, the first n_components (as defined
            in the Xdawn constructor) will be kept.
        exclude : array_like of int | None (default None)
            The indices referring to columns in the unmixing matrix. The
            components to be zeroed out. If None, all the components except
            the first n_components will be excluded.

        Returns
        -------
        out : dict
            A dict of instances (of the same type as the ``inst`` input), one
            for each event type in ``event_id``.
        """
        if event_id is None:
            event_id = self.event_id_

        if not isinstance(inst, BaseRaw | BaseEpochs | Evoked):
            raise ValueError("Data input must be Raw, Epochs or Evoked type")
        picks = _pick_data_channels(inst.info)

        # Define the components to keep
        default_exclude = list(range(self.n_components, len(inst.ch_names)))
        if exclude is None:
            exclude = default_exclude
        else:
            exclude = list(set(list(default_exclude) + list(exclude)))
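        # For instance (hypothetical numbers): with n_components=2 and a
        # 64-channel recording, default_exclude is [2, ..., 63]; passing
        # exclude=[0] additionally zeroes out the first component.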

        if isinstance(inst, BaseRaw):
            out = self._apply_raw(
                raw=inst,
                include=include,
                exclude=exclude,
                event_id=event_id,
                picks=picks,
            )
        elif isinstance(inst, BaseEpochs):
            out = self._apply_epochs(
                epochs=inst,
                include=include,
                picks=picks,
                exclude=exclude,
                event_id=event_id,
            )
        elif isinstance(inst, Evoked):
            out = self._apply_evoked(
                evoked=inst,
                include=include,
                picks=picks,
                exclude=exclude,
                event_id=event_id,
            )
        return out

    def _apply_raw(self, raw, include, exclude, event_id, picks):
        """Aux method."""
        if not raw.preload:
            raise ValueError("Raw data must be preloaded to apply Xdawn")

        raws = dict()
        for eid in event_id:
            data = raw[picks, :][0]

            data = self._pick_sources(data, include, exclude, eid)

            raw_r = raw.copy()

            raw_r[picks, :] = data
            raws[eid] = raw_r
        return raws

    def _apply_epochs(self, epochs, include, exclude, event_id, picks):
        """Aux method."""
        if not epochs.preload:
            raise ValueError("Epochs must be preloaded to apply Xdawn")

        # special case where epochs come picked but fit was 'unpicked'.
        epochs_dict = dict()
        data = np.hstack(epochs.get_data(picks))
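        # hstack concatenates the epochs along time, giving a 2D array of
        # shape (n_picked_channels, n_epochs * n_times); after filtering it
        # is split back into per-epoch segments in the loop below.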

        for eid in event_id:
            data_r = self._pick_sources(data, include, exclude, eid)
            data_r = np.array(np.split(data_r, len(epochs.events), 1))
            epochs_r = epochs.copy().load_data()
            epochs_r._data[:, picks, :] = data_r
            epochs_dict[eid] = epochs_r

        return epochs_dict

    def _apply_evoked(self, evoked, include, exclude, event_id, picks):
        """Aux method."""
        data = evoked.data[picks]
        evokeds = dict()

        for eid in event_id:
            data_r = self._pick_sources(data, include, exclude, eid)
            evokeds[eid] = evoked.copy()

            # restore evoked
            evokeds[eid].data[picks] = data_r

        return evokeds

    def _pick_sources(self, data, include, exclude, eid):
        """Aux method."""
        logger.info("Transforming to Xdawn space")

        # Apply unmixing
        sources = np.dot(self.filters_[eid], data)
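        # Round trip in matrix form: sources = filters_ @ data, the selected
        # rows of sources are zeroed out, and the sensor-space data are then
        # rebuilt as patterns_.T @ sources (see the dot product below).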

        if include not in (None, list()):
            mask = np.ones(len(sources), dtype=bool)
            mask[np.unique(include)] = False
            sources[mask] = 0.0
            logger.info(f"Zeroing out {int(mask.sum())} Xdawn components")
        elif exclude not in (None, list()):
            exclude_ = np.unique(exclude)
            sources[exclude_] = 0.0
            logger.info(f"Zeroing out {len(exclude_)} Xdawn components")
        logger.info("Inverse transforming to sensor space")
        data = np.dot(self.patterns_[eid].T, sources)

        return data

    def inverse_transform(self):
        """Not implemented, see Xdawn.apply() instead."""
        # Exists because of _XdawnTransformer
        raise NotImplementedError("See Xdawn.apply()")