initial commit

2025-08-19 09:13:22 -07:00
parent 28464811d6
commit 0977a3e14d
820 changed files with 1003358 additions and 2 deletions

View File

@@ -0,0 +1,9 @@
"""Preprocessing with artifact detection, SSP, and ICA."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import lazy_loader as lazy
(__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__)
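The ``lazy_loader`` call above defers every submodule import until first attribute access; the public names come from the ``__all__`` stub shown in the next file. A minimal sketch of the resulting behaviour (assuming MNE-Python and ``lazy_loader`` are installed):

import mne.preprocessing as pp

print("ICA" in pp.__all__)  # True: the name is declared in the stub, nothing heavy imported yet
ica_cls = pp.ICA            # first attribute access triggers the import of mne.preprocessing.ica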

View File

@@ -0,0 +1,91 @@
__all__ = [
"EOGRegression",
"ICA",
"Xdawn",
"annotate_amplitude",
"annotate_break",
"annotate_movement",
"annotate_muscle_zscore",
"annotate_nan",
"compute_average_dev_head_t",
"compute_bridged_electrodes",
"compute_current_source_density",
"compute_fine_calibration",
"compute_maxwell_basis",
"compute_proj_ecg",
"compute_proj_eog",
"compute_proj_hfc",
"corrmap",
"cortical_signal_suppression",
"create_ecg_epochs",
"create_eog_epochs",
"equalize_bads",
"eyetracking",
"find_bad_channels_lof",
"find_bad_channels_maxwell",
"find_ecg_events",
"find_eog_events",
"fix_stim_artifact",
"get_score_funcs",
"ica_find_ecg_events",
"ica_find_eog_events",
"ieeg",
"infomax",
"interpolate_bridged_electrodes",
"maxwell_filter",
"maxwell_filter_prepare_emptyroom",
"nirs",
"oversampled_temporal_projection",
"peak_finder",
"read_eog_regression",
"read_fine_calibration",
"read_ica",
"read_ica_eeglab",
"realign_raw",
"regress_artifact",
"write_fine_calibration",
]
from . import eyetracking, ieeg, nirs
from ._annotate_amplitude import annotate_amplitude
from ._annotate_nan import annotate_nan
from ._csd import compute_bridged_electrodes, compute_current_source_density
from ._css import cortical_signal_suppression
from ._fine_cal import (
compute_fine_calibration,
read_fine_calibration,
write_fine_calibration,
)
from ._lof import find_bad_channels_lof
from ._peak_finder import peak_finder
from ._regress import EOGRegression, read_eog_regression, regress_artifact
from .artifact_detection import (
annotate_break,
annotate_movement,
annotate_muscle_zscore,
compute_average_dev_head_t,
)
from .ecg import create_ecg_epochs, find_ecg_events
from .eog import create_eog_epochs, find_eog_events
from .hfc import compute_proj_hfc
from .ica import (
ICA,
corrmap,
get_score_funcs,
ica_find_ecg_events,
ica_find_eog_events,
read_ica,
read_ica_eeglab,
)
from .infomax_ import infomax
from .interpolate import equalize_bads, interpolate_bridged_electrodes
from .maxwell import (
compute_maxwell_basis,
find_bad_channels_maxwell,
maxwell_filter,
maxwell_filter_prepare_emptyroom,
)
from .otp import oversampled_temporal_projection
from .realign import realign_raw
from .ssp import compute_proj_ecg, compute_proj_eog
from .stim import fix_stim_artifact
from .xdawn import Xdawn
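Every name in the ``__all__`` list above is importable directly from ``mne.preprocessing``; for example (assuming MNE-Python is installed):

from mne.preprocessing import ICA, annotate_amplitude, maxwell_filter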

View File

@@ -0,0 +1,280 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from .._fiff.pick import _picks_by_type, _picks_to_idx
from ..annotations import (
Annotations,
_adjust_onset_meas_date,
_annotations_starts_stops,
)
from ..fixes import jit
from ..io import BaseRaw
from ..utils import _mask_to_onsets_offsets, _validate_type, logger, verbose
@verbose
def annotate_amplitude(
raw,
peak=None,
flat=None,
bad_percent=5,
min_duration=0.005,
picks=None,
*,
verbose=None,
):
"""Annotate raw data based on peak-to-peak amplitude.
Creates annotations ``BAD_peak`` or ``BAD_flat`` for spans of data where
consecutive samples exceed the threshold in ``peak`` or fall below the
threshold in ``flat`` for more than ``min_duration``.
Channels where more than ``bad_percent`` of the total recording length
should be annotated with either ``BAD_peak`` or ``BAD_flat`` are returned
in ``bads`` instead.
Note that the annotations and the bads are not automatically added to the
:class:`~mne.io.Raw` object; use :meth:`~mne.io.Raw.set_annotations` and
:class:`info['bads'] <mne.Info>` to do so.
Parameters
----------
raw : instance of Raw
The raw data.
peak : float | dict | None
Annotate segments based on **maximum** peak-to-peak signal amplitude
(PTP). Valid **keys** can be any channel type present in the object.
The **values** are floats that set the maximum acceptable PTP. If the
PTP is larger than this threshold, the segment will be annotated.
If float, the maximum acceptable PTP is applied to all channels.
flat : float | dict | None
Annotate segments based on **minimum** peak-to-peak signal amplitude
(PTP). Valid **keys** can be any channel type present in the object.
The **values** are floats that set the minimum acceptable PTP. If the
PTP is smaller than this threshold, the segment will be annotated.
If float, the minimum acceptable PTP is applied to all channels.
bad_percent : float
The percentage of the time a channel can be above or below thresholds.
Below this percentage, :class:`~mne.Annotations` are created.
Above this percentage, the channel involved is returned in ``bads``. Note
the returned ``bads`` are not automatically added to
:class:`info['bads'] <mne.Info>`.
Defaults to ``5``, i.e. 5%%.
min_duration : float
The minimum duration (s) for which consecutive samples must be above
``peak`` or below ``flat`` thresholds to be considered.
For some systems, adjacent time samples with exactly the same value are
not totally uncommon. Defaults to ``0.005`` (5 ms).
%(picks_good_data)s
%(verbose)s
Returns
-------
annotations : instance of Annotations
The annotated bad segments.
bads : list
The channels detected as bad.
Notes
-----
This function does not use a window to detect small peak-to-peak or large
peak-to-peak amplitude changes as the ``reject`` and ``flat`` arguments of
:class:`~mne.Epochs` do. Instead, it looks at the difference between
consecutive samples.
- When used to detect segments below ``flat``, at least ``min_duration``
seconds of consecutive samples must respect
``abs(a[i+1] - a[i]) ≤ flat``.
- When used to detect segments above ``peak``, at least ``min_duration``
seconds of consecutive samples must respect
``abs(a[i+1] - a[i]) ≥ peak``.
Thus, this function does not detect every temporal event with large
peak-to-peak amplitude, but only the ones where the peak-to-peak amplitude
is supra-threshold between consecutive samples. For instance, segments
experiencing a DC shift will not be picked up. Only the edges from the DC
shift will be annotated (and those only if the edge transitions are longer
than ``min_duration``).
This function may perform faster if data is loaded in memory, as it
loads data one channel type at a time (across all time points), which is
typically not an efficient way to read raw data from disk.
.. versionadded:: 1.0
"""
_validate_type(raw, BaseRaw, "raw")
picks_ = _picks_to_idx(raw.info, picks, "data_or_ica", exclude="bads")
peak = _check_ptp(peak, "peak", raw.info, picks_)
flat = _check_ptp(flat, "flat", raw.info, picks_)
if peak is None and flat is None:
raise ValueError(
"At least one of the arguments 'peak' or 'flat' must not be None."
)
bad_percent = _check_bad_percent(bad_percent)
min_duration = _check_min_duration(
min_duration, raw.times.size * 1 / raw.info["sfreq"]
)
min_duration_samples = int(np.round(min_duration * raw.info["sfreq"]))
bads = list()
# grouping picks by channel types to avoid operating on each channel
# individually
picks = {
ch_type: np.intersect1d(picks_of_type, picks_, assume_unique=True)
for ch_type, picks_of_type in _picks_by_type(raw.info, exclude="bads")
if np.intersect1d(picks_of_type, picks_, assume_unique=True).size != 0
}
del picks_ # reusing this variable name in for loop
# skip BAD_acq_skip sections
onsets, ends = _annotations_starts_stops(raw, "bad_acq_skip", invert=True)
index = np.concatenate(
[np.arange(raw.times.size)[onset:end] for onset, end in zip(onsets, ends)]
)
# size matching the diff a[i+1] - a[i]
any_flat = np.zeros(len(raw.times) - 1, bool)
any_peak = np.zeros(len(raw.times) - 1, bool)
# look for discrete difference above or below thresholds
logger.info("Finding segments below or above PTP threshold.")
for ch_type, picks_ in picks.items():
data = np.concatenate(
[raw[picks_, onset:end][0] for onset, end in zip(onsets, ends)], axis=1
)
diff = np.abs(np.diff(data, axis=1))
if flat is not None:
flat_ = diff <= flat[ch_type]
# reject too short segments
flat_ = _reject_short_segments(flat_, min_duration_samples)
# reject channels above maximum bad_percentage
flat_count = flat_.sum(axis=1)
flat_count[np.nonzero(flat_count)] += 1 # offset by 1 due to diff
flat_mean = flat_count / raw.times.size * 100
flat_ch_to_set_bad = picks_[np.where(flat_mean >= bad_percent)[0]]
bads.extend(flat_ch_to_set_bad)
# add onset/offset for annotations
flat_ch_to_annotate = np.where((0 < flat_mean) & (flat_mean < bad_percent))[
0
]
# convert from raw.times[onset:end] - 1 to raw.times[:] - 1
idx = index[np.where(flat_[flat_ch_to_annotate, :])[1]]
any_flat[idx] = True
if peak is not None:
peak_ = diff >= peak[ch_type]
# reject too short segments
peak_ = _reject_short_segments(peak_, min_duration_samples)
# reject channels above maximum bad_percentage
peak_count = peak_.sum(axis=1)
peak_count[np.nonzero(peak_count)] += 1 # offset by 1 due to diff
peak_mean = peak_count / raw.times.size * 100
peak_ch_to_set_bad = picks_[np.where(peak_mean >= bad_percent)[0]]
bads.extend(peak_ch_to_set_bad)
# add onset/offset for annotations
peak_ch_to_annotate = np.where((0 < peak_mean) & (peak_mean < bad_percent))[
0
]
# convert from raw.times[onset:end] - 1 to raw.times[:] - 1
idx = index[np.where(peak_[peak_ch_to_annotate, :])[1]]
any_peak[idx] = True
# annotation for flat
annotation_flat = _create_annotations(any_flat, "flat", raw)
# annotation for peak
annotation_peak = _create_annotations(any_peak, "peak", raw)
# group
annotations = annotation_flat + annotation_peak
# bads
bads = [raw.ch_names[bad] for bad in bads if bad not in raw.info["bads"]]
return annotations, bads
def _check_ptp(ptp, name, info, picks):
"""Check the PTP threhsold argument, and converts it to dict if needed."""
_validate_type(ptp, ("numeric", dict, None))
if ptp is not None and not isinstance(ptp, dict):
if ptp < 0:
raise ValueError(
f"Argument '{name}' should define a positive threshold. "
f"Provided: '{ptp}'."
)
ch_types = set(info.get_channel_types(picks))
ptp = {ch_type: ptp for ch_type in ch_types}
elif isinstance(ptp, dict):
for key, value in ptp.items():
if value < 0:
raise ValueError(
f"Argument '{name}' should define positive thresholds. "
f"Provided for channel type '{key}': '{value}'."
)
return ptp
def _check_bad_percent(bad_percent):
"""Check that bad_percent is a valid percentage and converts to float."""
_validate_type(bad_percent, "numeric", "bad_percent")
bad_percent = float(bad_percent)
if not 0 <= bad_percent <= 100:
raise ValueError(
"Argument 'bad_percent' should define a percentage between 0% "
f"and 100%. Provided: {bad_percent}%."
)
return bad_percent
def _check_min_duration(min_duration, raw_duration):
"""Check that min_duration is a valid duration and converts to float."""
_validate_type(min_duration, "numeric", "min_duration")
min_duration = float(min_duration)
if min_duration < 0:
raise ValueError(
"Argument 'min_duration' should define a positive duration in "
f"seconds. Provided: '{min_duration}' seconds."
)
if min_duration >= raw_duration:
raise ValueError(
"Argument 'min_duration' should define a positive duration in "
f"seconds shorter than the raw duration ({raw_duration} seconds). "
f"Provided: '{min_duration}' seconds."
)
return min_duration
def _reject_short_segments(arr, min_duration_samples):
"""Check if flat or peak segments are longer than the minimum duration."""
assert arr.dtype == np.dtype(bool) and arr.ndim == 2
for k, ch in enumerate(arr):
onsets, offsets = _mask_to_onsets_offsets(ch)
_mark_inner(arr[k], onsets, offsets, min_duration_samples)
return arr
@jit()
def _mark_inner(arr_k, onsets, offsets, min_duration_samples):
"""Inner loop of _reject_short_segments()."""
for start, stop in zip(onsets, offsets):
if stop - start < min_duration_samples:
arr_k[start:stop] = False
def _create_annotations(any_arr, kind, raw):
"""Create the peak of flat annotations from the any_arr."""
assert kind in ("peak", "flat")
starts, stops = _mask_to_onsets_offsets(any_arr)
starts, stops = np.array(starts), np.array(stops)
onsets = starts / raw.info["sfreq"]
durations = (stops - starts) / raw.info["sfreq"]
annot = Annotations(
onsets,
durations,
[f"BAD_{kind}"] * len(onsets),
orig_time=raw.info["meas_date"],
)
_adjust_onset_meas_date(annot, raw)
return annot
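
A hedged usage sketch for ``annotate_amplitude`` (the thresholds and the variable ``raw`` are illustrative; as the docstring notes, the returned annotations and bads must be applied manually):

from mne.preprocessing import annotate_amplitude

annotations, bads = annotate_amplitude(
    raw,                    # assumed: an mne.io.Raw instance (preloading is optional but faster)
    peak=dict(eeg=200e-6),  # illustrative: annotate consecutive-sample jumps above 200 µV
    flat=dict(eeg=1e-6),    # illustrative: annotate spans flatter than 1 µV
    bad_percent=5,
    min_duration=0.005,
)
raw.set_annotations(raw.annotations + annotations)  # annotations are not added automatically
raw.info["bads"].extend(bads)                       # neither are the bad channels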

View File

@@ -0,0 +1,38 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ..annotations import Annotations, _adjust_onset_meas_date
from ..utils import verbose
from .artifact_detection import _annotations_from_mask
@verbose
def annotate_nan(raw, *, verbose=None):
"""Detect segments with NaN and return a new Annotations instance.
Parameters
----------
raw : instance of Raw
Data to find segments with NaN values.
%(verbose)s
Returns
-------
annot : instance of Annotations
New channel-specific annotations for the data.
"""
data, times = raw.get_data(return_times=True)
onsets, durations, ch_names = list(), list(), list()
for row, ch_name in zip(data, raw.ch_names):
annot = _annotations_from_mask(times, np.isnan(row), "BAD_NAN")
onsets.extend(annot.onset)
durations.extend(annot.duration)
ch_names.extend([[ch_name]] * len(annot))
annot = Annotations(
onsets, durations, "BAD_NAN", ch_names=ch_names, orig_time=raw.info["meas_date"]
)
_adjust_onset_meas_date(annot, raw)
return annot
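
A minimal sketch for ``annotate_nan`` (``raw`` is assumed to be an mne.io.Raw instance that may contain NaN samples):

from mne.preprocessing import annotate_nan

annot_nan = annotate_nan(raw)                     # channel-specific BAD_NAN annotations
raw.set_annotations(raw.annotations + annot_nan)  # apply them; not done automatically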

mne/preprocessing/_csd.py Normal file
View File

@@ -0,0 +1,323 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
# Copyright 2003-2010 Jürgen Kayser <rjk23@columbia.edu>
#
# The original CSD Toolbox can be found at
# http://psychophysiology.cpmc.columbia.edu/Software/CSDtoolbox/
#
# Relicensed under BSD-3-Clause and adapted with permission from authors of original GPL
# code.
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.stats import gaussian_kde
from .._fiff.constants import FIFF
from .._fiff.pick import pick_types
from ..bem import fit_sphere_to_headshape
from ..channels.interpolation import _calc_g, _calc_h
from ..epochs import BaseEpochs, make_fixed_length_epochs
from ..evoked import Evoked
from ..io import BaseRaw
from ..utils import _check_preload, _ensure_int, _validate_type, logger, verbose
def _prepare_G(G, lambda2):
G.flat[:: len(G) + 1] += lambda2
# compute the CSD
Gi = np.linalg.inv(G)
TC = Gi.sum(0)
sgi = np.sum(TC) # compute sum total
return Gi, TC, sgi
def _compute_csd(G_precomputed, H, radius):
"""Compute the CSD."""
n_channels = H.shape[0]
data = np.eye(n_channels)
mu = data.mean(0)
Z = data - mu
Gi, TC, sgi = G_precomputed
Cp2 = np.dot(Gi, Z)
c02 = np.sum(Cp2, axis=0) / sgi
C2 = Cp2 - np.dot(TC[:, np.newaxis], c02[np.newaxis, :])
X = np.dot(C2.T, H).T / radius**2
return X
@verbose
def compute_current_source_density(
inst,
sphere="auto",
lambda2=1e-5,
stiffness=4,
n_legendre_terms=50,
copy=True,
*,
verbose=None,
):
"""Get the current source density (CSD) transformation.
Transformation based on spherical spline surface Laplacian
:footcite:`PerrinEtAl1987,PerrinEtAl1989,Cohen2014,KayserTenke2015`.
This function can be used to re-reference the signal using a Laplacian
(LAP) "reference-free" transformation.
Parameters
----------
inst : instance of Raw, Epochs or Evoked
The data to be transformed.
sphere : array-like, shape (4,) | str
The sphere, head-model of the form (x, y, z, r) where x, y, z
is the center of the sphere and r is the radius in meters.
Can also be "auto" to use a digitization-based fit.
lambda2 : float
Regularization parameter, produces smoothness. Defaults to 1e-5.
stiffness : float
Stiffness of the spline.
n_legendre_terms : int
Number of Legendre terms to evaluate.
copy : bool
Whether to overwrite instance data or create a copy.
%(verbose)s
Returns
-------
inst_csd : instance of Raw, Epochs or Evoked
The transformed data. Output type will match input type.
Notes
-----
.. versionadded:: 0.20
References
----------
.. footbibliography::
"""
_validate_type(inst, (BaseEpochs, BaseRaw, Evoked), "inst")
_check_preload(inst, "Computing CSD")
if inst.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_CSD:
raise ValueError("CSD already applied, should not be reapplied")
_validate_type(copy, (bool), "copy")
inst = inst.copy() if copy else inst
picks = pick_types(inst.info, meg=False, eeg=True, exclude=[])
if any([ch in np.array(inst.ch_names)[picks] for ch in inst.info["bads"]]):
raise ValueError(
"CSD cannot be computed with bad EEG channels. Either"
" drop (inst.drop_channels(inst.info['bads']) "
"or interpolate (`inst.interpolate_bads()`) "
"bad EEG channels."
)
if len(picks) == 0:
raise ValueError("No EEG channels found.")
_validate_type(lambda2, "numeric", "lambda2")
if not 0 <= lambda2 < 1:
raise ValueError(f"lambda2 must be between 0 and 1, got {lambda2}")
_validate_type(stiffness, "numeric", "stiffness")
if stiffness < 0:
raise ValueError(f"stiffness must be non-negative got {stiffness}")
n_legendre_terms = _ensure_int(n_legendre_terms, "n_legendre_terms")
if n_legendre_terms < 1:
raise ValueError(
f"n_legendre_terms must be greater than 0, got {n_legendre_terms}"
)
if isinstance(sphere, str) and sphere == "auto":
radius, origin_head, origin_device = fit_sphere_to_headshape(inst.info)
x, y, z = origin_head - origin_device
sphere = (x, y, z, radius)
try:
sphere = np.array(sphere, float)
x, y, z, radius = sphere
except Exception:
raise ValueError(
f'sphere must be "auto" or array-like with shape (4,), got {sphere}'
)
_validate_type(x, "numeric", "x")
_validate_type(y, "numeric", "y")
_validate_type(z, "numeric", "z")
_validate_type(radius, "numeric", "radius")
if radius <= 0:
raise ValueError("sphere radius must be greater than 0, got {radius}")
pos = np.array([inst.info["chs"][pick]["loc"][:3] for pick in picks])
if not np.isfinite(pos).all() or np.isclose(pos, 0.0).all(1).any():
raise ValueError("Zero or infinite position found in chs")
pos -= (x, y, z)
# Project onto a unit sphere to compute the cosine similarity:
pos /= np.linalg.norm(pos, axis=1, keepdims=True)
cos_dist = np.clip(np.dot(pos, pos.T), -1, 1)
# This is equivalent to doing one minus half the squared Euclidean:
# from scipy.spatial.distance import squareform, pdist
# cos_dist = 1 - squareform(pdist(pos, 'sqeuclidean')) / 2.
del pos
G = _calc_g(cos_dist, stiffness=stiffness, n_legendre_terms=n_legendre_terms)
H = _calc_h(cos_dist, stiffness=stiffness, n_legendre_terms=n_legendre_terms)
G_precomputed = _prepare_G(G, lambda2)
trans_csd = _compute_csd(G_precomputed=G_precomputed, H=H, radius=radius)
epochs = [inst._data] if not isinstance(inst, BaseEpochs) else inst._data
for epo in epochs:
epo[picks] = np.dot(trans_csd, epo[picks])
with inst.info._unlock():
inst.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_CSD
for pick in picks:
inst.info["chs"][pick].update(
coil_type=FIFF.FIFFV_COIL_EEG_CSD, unit=FIFF.FIFF_UNIT_V_M2
)
# Remove rejection thresholds for EEG
if isinstance(inst, BaseEpochs):
if inst.reject and "eeg" in inst.reject:
del inst.reject["eeg"]
if inst.flat and "eeg" in inst.flat:
del inst.flat["eeg"]
return inst
@verbose
def compute_bridged_electrodes(
inst,
lm_cutoff=16,
epoch_threshold=0.5,
l_freq=0.5,
h_freq=30,
epoch_duration=2,
bw_method=None,
verbose=None,
):
r"""Compute bridged EEG electrodes using the intrinsic Hjorth algorithm.
First, an electrical distance matrix is computed by taking the pairwise
variance between electrodes. Local minimums in this matrix below
``lm_cutoff`` are indicative of bridging between a pair of electrodes.
Pairs of electrodes are marked as bridged as long as their electrical
distance is below ``lm_cutoff`` on more than the ``epoch_threshold``
proportion of epochs.
Based on :footcite:`TenkeKayser2001,GreischarEtAl2004,DelormeMakeig2004`
and the `EEGLAB implementation
<https://psychophysiology.cpmc.columbia.edu/>`__.
Parameters
----------
inst : instance of Raw, Epochs or Evoked
The data to compute electrode bridging on.
lm_cutoff : float
The distance in :math:`{\mu}V^2` cutoff below which to
search for a local minimum (lm) indicative of bridging.
EEGLAB defaults to 5 :math:`{\mu}V^2`. MNE defaults to
16 :math:`{\mu}V^2` to be conservative based on the distributions in
:footcite:t:`GreischarEtAl2004`.
epoch_threshold : float
The proportion of epochs with electrical distance less than
``lm_cutoff`` in order to consider the channel bridged.
The default is 0.5.
l_freq : float
The low cutoff frequency to use. Default is 0.5 Hz.
h_freq : float
The high cutoff frequency to use. Default is 30 Hz.
epoch_duration : float
The time in seconds to divide the raw into fixed-length epochs
to check for consistent bridging. Only used if ``inst`` is
:class:`mne.io.BaseRaw`. The default is 2 seconds.
bw_method : None
``bw_method`` to pass to :class:`scipy.stats.gaussian_kde`.
%(verbose)s
Returns
-------
bridged_idx : list of tuple
The indices of channels marked as bridged with each bridged
pair stored as a tuple.
ed_matrix : ndarray of float, shape (n_epochs, n_channels, n_channels)
The electrical distance matrix for each pair of EEG electrodes.
Notes
-----
.. versionadded:: 1.1
References
----------
.. footbibliography::
"""
_check_preload(inst, "Computing bridged electrodes")
inst = inst.copy() # don't modify original
picks = pick_types(inst.info, eeg=True)
if len(picks) == 0:
raise RuntimeError("No EEG channels found, cannot compute electrode bridging")
# first, filter
inst.filter(l_freq=l_freq, h_freq=h_freq, picks=picks, verbose=False)
if isinstance(inst, BaseRaw):
inst = make_fixed_length_epochs(
inst, duration=epoch_duration, preload=True, verbose=False
)
# standardize shape
data = inst.get_data(picks=picks)
if isinstance(inst, Evoked):
data = data[np.newaxis, ...] # expand evoked
# next, compute electrical distance matrix, upper triangular
n_epochs = data.shape[0]
ed_matrix = np.zeros((n_epochs, picks.size, picks.size)) * np.nan
for i in range(picks.size):
for j in range(i + 1, picks.size):
ed_matrix[:, i, j] = np.var(data[:, i] - data[:, j], axis=1)
# scale, fill in other half, diagonal
ed_matrix *= 1e12 # scale to muV**2
# initialize bridged indices
bridged_idx = list()
# if not enough values below local minimum cutoff, return no bridges
ed_flat = ed_matrix[~np.isnan(ed_matrix)]
if ed_flat[ed_flat < lm_cutoff].size / n_epochs < epoch_threshold:
return bridged_idx, ed_matrix
# kernel density estimation
kde = gaussian_kde(ed_flat[ed_flat < lm_cutoff], bw_method=bw_method)
with np.errstate(invalid="ignore"):
local_minimum = float(
minimize_scalar(
lambda x: kde(x) if x < lm_cutoff and x > 0 else np.inf
).x.item()
)
logger.info(f"Local minimum {local_minimum} found")
# find electrodes that are below the cutoff local minimum on
# `epochs_threshold` proportion of epochs
for i in range(picks.size):
for j in range(i + 1, picks.size):
bridged_count = np.sum(ed_matrix[:, i, j] < local_minimum)
if bridged_count / n_epochs > epoch_threshold:
logger.info(
"Bridge detected between "
f"{inst.ch_names[picks[i]]} and "
f"{inst.ch_names[picks[j]]}"
)
bridged_idx.append((picks[i], picks[j]))
return bridged_idx, ed_matrix
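
An illustrative sketch combining the two functions above (``raw`` is assumed to be a preloaded Raw with EEG channels, a digitized montage for the ``sphere="auto"`` fit, and no bad EEG channels):

from mne.preprocessing import (
    compute_bridged_electrodes,
    compute_current_source_density,
)

bridged_idx, ed_matrix = compute_bridged_electrodes(raw)  # pairs of likely-bridged electrodes
raw_csd = compute_current_source_density(raw, sphere="auto", lambda2=1e-5)  # surface-Laplacian reference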

mne/preprocessing/_css.py Normal file
View File

@@ -0,0 +1,95 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from .._fiff.pick import _picks_to_idx
from ..evoked import Evoked
from ..utils import _ensure_int, _validate_type, verbose
def _temp_proj(ref_2, ref_1, raw_data, n_proj=6):
# Orthonormalize gradiometer and magnetometer data by a QR decomposition
ref_1_orth = np.linalg.qr(ref_1.T)[0]
ref_2_orth = np.linalg.qr(ref_2.T)[0]
# Calculate cross-correlation
cross_corr = np.dot(ref_1_orth.T, ref_2_orth)
# Channel weights for common temporal subspace by SVD of cross-correlation
ref_1_ch_weights, _, _ = np.linalg.svd(cross_corr)
# Get temporal signals from channel weights
proj_mat = ref_1_orth @ ref_1_ch_weights
# Project out common subspace
filtered_data = raw_data
proj_vec = proj_mat[:, :n_proj]
weights = filtered_data @ proj_vec
filtered_data -= weights @ proj_vec.T
@verbose
def cortical_signal_suppression(
evoked, picks=None, mag_picks=None, grad_picks=None, n_proj=6, *, verbose=None
):
"""Apply cortical signal suppression (CSS) to evoked data.
Parameters
----------
evoked : instance of Evoked
The evoked object to use for CSS. Must contain magnetometer,
gradiometer, and EEG channels.
%(picks_good_data)s
mag_picks : array-like of int
Array of the first set of channel indices that will be used to find
the common temporal subspace. If None (default), all magnetometers will
be used.
grad_picks : array-like of int
Array of the second set of channel indices that will be used to find
the common temporal subspace. If None (default), all gradiometers will
be used.
n_proj : int
The number of projection vectors.
%(verbose)s
Returns
-------
evoked_subcortical : instance of Evoked
The evoked object with contributions from the ``mag_picks`` and ``grad_picks``
channels removed from the ``picks`` channels.
Notes
-----
This method removes the common signal subspace between two sets of
channels (``mag_picks`` and ``grad_picks``) from a set of channels
(``picks``) via a temporal projection using ``n_proj`` number of
projection vectors. In the reference publication :footcite:`Samuelsson2019`,
the joint subspace between magnetometers and gradiometers is used to
suppress the cortical signal in the EEG data. In principle, other
combinations of sensor types (or channels) could be used to suppress
signals from other sources.
References
----------
.. footbibliography::
"""
_validate_type(evoked, Evoked, "evoked")
n_proj = _ensure_int(n_proj, "n_proj")
picks = _picks_to_idx(evoked.info, picks, none="data", exclude="bads")
mag_picks = _picks_to_idx(evoked.info, mag_picks, none="mag", exclude="bads")
grad_picks = _picks_to_idx(evoked.info, grad_picks, none="grad", exclude="bads")
evoked_subcortical = evoked.copy()
# Get data
all_data = evoked.data
mag_data = all_data[mag_picks]
grad_data = all_data[grad_picks]
# Process data with temporal projection algorithm
data = all_data[picks]
_temp_proj(mag_data, grad_data, data, n_proj=n_proj)
evoked_subcortical.data[picks, :] = data
return evoked_subcortical
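
A short usage sketch for ``cortical_signal_suppression`` (``evoked`` is assumed to contain magnetometer, gradiometer, and EEG channels, as required above):

from mne.preprocessing import cortical_signal_suppression

evoked_sub = cortical_signal_suppression(evoked, n_proj=6)  # EEG with the MEG-derived subspace projected out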

View File

@@ -0,0 +1,597 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from collections import defaultdict
from functools import partial
import numpy as np
from scipy.optimize import minimize
from .._fiff.pick import pick_info, pick_types
from .._fiff.tag import _coil_trans_to_loc, _loc_to_coil_trans
from ..bem import _check_origin
from ..io import BaseRaw
from ..transforms import _find_vector_rotation
from ..utils import (
_check_fname,
_check_option,
_clean_names,
_ensure_int,
_pl,
_reg_pinv,
_validate_type,
check_fname,
logger,
verbose,
)
from .maxwell import (
_col_norm_pinv,
_get_grad_point_coilsets,
_prep_fine_cal,
_prep_mf_coils,
_read_cross_talk,
_trans_sss_basis,
)
@verbose
def compute_fine_calibration(
raw,
n_imbalance=3,
t_window=10.0,
ext_order=2,
origin=(0.0, 0.0, 0.0),
cross_talk=None,
calibration=None,
*,
angle_limit=5.0,
err_limit=5.0,
verbose=None,
):
"""Compute fine calibration from empty-room data.
Parameters
----------
raw : instance of Raw
The raw data to use. Should be from an empty-room recording,
and all channels should be good.
n_imbalance : int
Can be 1 or 3 (default), indicating the number of gradiometer
imbalance components. Only used if gradiometers are present.
t_window : float
Time window to use for surface normal rotation in seconds.
Default is 10.
%(ext_order_maxwell)s
Default is 2, which is lower than the default (3) for
:func:`mne.preprocessing.maxwell_filter` because it tends to yield
more stable parameter estimates.
%(origin_maxwell)s
%(cross_talk_maxwell)s
calibration : dict | None
Dictionary with existing calibration. If provided, the magnetometer
imbalances and adjusted normals will be used and only the gradiometer
imbalances will be estimated (see step 2 in Notes below).
angle_limit : float
The maximum permitted angle in degrees between the original and adjusted
magnetometer normals. If the angle is exceeded, the segment is treated as
an outlier and discarded.
.. versionadded:: 1.9
err_limit : float
The maximum error (in percent) for each channel in order for a segment to
be used.
.. versionadded:: 1.9
%(verbose)s
Returns
-------
calibration : dict
Fine calibration data.
count : int
The number of good segments used to compute the magnetometer
parameters.
See Also
--------
mne.preprocessing.maxwell_filter
Notes
-----
This algorithm proceeds in two steps, both optimizing the fit between the
data and a reconstruction of the data based only on an external multipole
expansion:
1. Estimate magnetometer normal directions and scale factors. All
coils (mag and matching grad) are rotated by the adjusted normal
direction.
2. Estimate gradiometer imbalance factors. These add point magnetometers
in just the gradiometer difference direction or in all three directions
(depending on ``n_imbalance``).
Magnetometer normal and coefficient estimation (1) is typically the most
time consuming step. Gradiometer imbalance parameters (2) can be
iteratively reestimated (for example, first using ``n_imbalance=1`` then
subsequently ``n_imbalance=3``) by passing the previous ``calibration``
output to the ``calibration`` input in the second call.
MaxFilter processes at most 120 seconds of data, so consider cropping
your raw instance prior to processing. It also checks that a minimal
number of usable segments (default 5, cf. the returned ``count``) was
included in the estimate.
.. versionadded:: 0.21
"""
n_imbalance = _ensure_int(n_imbalance, "n_imbalance")
_check_option("n_imbalance", n_imbalance, (1, 3))
_validate_type(raw, BaseRaw, "raw")
ext_order = _ensure_int(ext_order, "ext_order")
origin = _check_origin(origin, raw.info, "meg", disp=True)
_check_option("raw.info['bads']", raw.info["bads"], ([],))
_validate_type(err_limit, "numeric", "err_limit")
_validate_type(angle_limit, "numeric", "angle_limit")
for key, val in dict(err_limit=err_limit, angle_limit=angle_limit).items():
if val < 0:
raise ValueError(f"{key} must be greater than or equal to 0, got {val}")
# Fine cal should not include ref channels
picks = pick_types(raw.info, meg=True, ref_meg=False)
if raw.info["dev_head_t"] is not None:
raise ValueError(
'info["dev_head_t"] is not None, suggesting that the '
"data are not from an empty-room recording"
)
info = pick_info(raw.info, picks) # make a copy and pick MEG channels
mag_picks = pick_types(info, meg="mag", exclude=())
grad_picks = pick_types(info, meg="grad", exclude=())
# Get cross-talk
ctc, _ = _read_cross_talk(cross_talk, info["ch_names"])
# Check fine cal
_validate_type(calibration, (dict, None), "calibration")
#
# 1. Rotate surface normals using magnetometer information (if present)
#
cals = np.ones(len(info["ch_names"]))
time_idxs = raw.time_as_index(np.arange(0.0, raw.times[-1], t_window))
if len(time_idxs) <= 1:
time_idxs = np.array([0, len(raw.times)], int)
else:
time_idxs[-1] = len(raw.times)
count = 0
locs = np.array([ch["loc"] for ch in info["chs"]])
zs = locs[mag_picks, -3:].copy()
if calibration is not None:
_, calibration, _ = _prep_fine_cal(info, calibration, ignore_ref=True)
for pi, pick in enumerate(mag_picks):
idx = calibration["ch_names"].index(info["ch_names"][pick])
cals[pick] = calibration["imb_cals"][idx].item()
zs[pi] = calibration["locs"][idx][-3:]
elif len(mag_picks) > 0:
cal_list = list()
z_list = list()
logger.info(
f"Adjusting normals for {len(mag_picks)} magnetometers "
f"(averaging over {len(time_idxs) - 1} time intervals)"
)
for start, stop in zip(time_idxs[:-1], time_idxs[1:]):
logger.info(
f" Processing interval {start / info['sfreq']:0.3f} - "
f"{stop / info['sfreq']:0.3f} s"
)
data = raw[picks, start:stop][0]
if ctc is not None:
data = ctc.dot(data)
z, cal, good = _adjust_mag_normals(
info,
data,
origin,
ext_order,
angle_limit=angle_limit,
err_limit=err_limit,
)
if good:
z_list.append(z)
cal_list.append(cal)
count = len(cal_list)
if count == 0:
raise RuntimeError("No usable segments found")
cals[:] = np.mean(cal_list, axis=0)
zs[:] = np.mean(z_list, axis=0)
if len(mag_picks) > 0:
for ii, new_z in enumerate(zs):
z_loc = locs[mag_picks[ii]]
# Find sensors with same NZ and R0 (should be three for VV)
idxs = _matched_loc_idx(z_loc, locs)
# Rotate the direction vectors to the plane defined by new normal
_rotate_locs(locs, idxs, new_z)
for ci, loc in enumerate(locs):
info["chs"][ci]["loc"][:] = loc
del calibration, zs
#
# 2. Estimate imbalance parameters (always done)
#
if len(grad_picks) > 0:
extra = "X direction" if n_imbalance == 1 else ("XYZ directions")
logger.info(f"Computing imbalance for {len(grad_picks)} gradimeters ({extra})")
imb_list = list()
for start, stop in zip(time_idxs[:-1], time_idxs[1:]):
logger.info(
f" Processing interval {start / info['sfreq']:0.3f} - "
f"{stop / info['sfreq']:0.3f} s"
)
data = raw[picks, start:stop][0]
if ctc is not None:
data = ctc.dot(data)
out = _estimate_imbalance(info, data, cals, n_imbalance, origin, ext_order)
imb_list.append(out)
imb = np.mean(imb_list, axis=0)
else:
imb = np.zeros((len(info["ch_names"]), n_imbalance))
#
# Put in output structure
#
assert len(np.intersect1d(mag_picks, grad_picks)) == 0
imb_cals = [
cals[ii : ii + 1] if ii in mag_picks else imb[ii]
for ii in range(len(info["ch_names"]))
]
ch_names = _clean_names(info["ch_names"], remove_whitespace=True)
calibration = dict(ch_names=ch_names, locs=locs, imb_cals=imb_cals)
return calibration, count
def _matched_loc_idx(mag_loc, all_loc):
return np.where(
[
np.allclose(mag_loc[-3:], loc[-3:]) and np.allclose(mag_loc[:3], loc[:3])
for loc in all_loc
]
)[0]
def _rotate_locs(locs, idxs, new_z):
new_z = new_z / np.linalg.norm(new_z)
old_z = locs[idxs[0]][-3:]
old_z = old_z / np.linalg.norm(old_z)
rot = _find_vector_rotation(old_z, new_z)
for ci in idxs:
this_trans = _loc_to_coil_trans(locs[ci])
this_trans[:3, :3] = np.dot(rot, this_trans[:3, :3])
locs[ci][:] = _coil_trans_to_loc(this_trans)
np.testing.assert_allclose(locs[ci][-3:], new_z, atol=1e-4)
def _vector_angle(x, y):
"""Get the angle between two vectors in degrees."""
return np.abs(
np.arccos(
np.clip(
(x * y).sum(axis=-1)
/ (np.linalg.norm(x, axis=-1) * np.linalg.norm(y, axis=-1)),
-1,
1.0,
)
)
)
def _adjust_mag_normals(info, data, origin, ext_order, *, angle_limit, err_limit):
"""Adjust coil normals using magnetometers and empty-room data."""
# in principle we could allow using just mag or mag+grad, but MF uses
# just mag so let's follow suit
mag_scale = 100.0
picks_use = pick_types(info, meg="mag", exclude="bads")
picks_meg = pick_types(info, meg=True, exclude=())
picks_mag_orig = pick_types(info, meg="mag", exclude="bads")
info = pick_info(info, picks_use) # copy
data = data[picks_use]
cals = np.ones((len(data), 1))
angles = np.zeros(len(cals))
picks_mag = pick_types(info, meg="mag")
data[picks_mag] *= mag_scale
# Transform variables so we're only dealing with good mags
exp = dict(int_order=0, ext_order=ext_order, origin=origin)
all_coils = _prep_mf_coils(info, ignore_ref=True)
S_tot = _trans_sss_basis(exp, all_coils, coil_scale=mag_scale)
first_err = _data_err(data, S_tot, cals)
count = 0
# two passes: first do the worst, then do all in order
zs = np.array([ch["loc"][-3:] for ch in info["chs"]])
zs /= np.linalg.norm(zs, axis=-1, keepdims=True)
orig_zs = zs.copy()
match_idx = dict()
locs = np.array([ch["loc"] for ch in info["chs"]])
for pick in picks_mag:
match_idx[pick] = _matched_loc_idx(locs[pick], locs)
counts = defaultdict(lambda: 0)
for ki, kind in enumerate(("worst first", "in order")):
logger.info(f" Magnetometer normal adjustment ({kind}) ...")
S_tot = _trans_sss_basis(exp, all_coils, coil_scale=mag_scale)
for pick in picks_mag:
err = _data_err(data, S_tot, cals, axis=1)
# First pass: do the worst; second pass: do all in order (up to 3x/sensor)
if ki == 0:
order = list(np.argsort(err[picks_mag]))
cal_idx = 0
while len(order) > 0:
cal_idx = picks_mag[order.pop(-1)]
if counts[cal_idx] < 3:
break
if err[cal_idx] < 2.5:
break # move on to second loop
else:
cal_idx = pick
counts[cal_idx] += 1
assert cal_idx in picks_mag
count += 1
old_z = zs[cal_idx].copy()
objective = partial(
_cal_sss_target,
old_z=old_z,
all_coils=all_coils,
cal_idx=cal_idx,
data=data,
cals=cals,
match_idx=match_idx,
S_tot=S_tot,
origin=origin,
ext_order=ext_order,
)
# Figure out the additive term for z-component
zs[cal_idx] = minimize(
objective,
old_z,
bounds=[(-2, 2)] * 3,
# BFGS is the default for minimize but COBYLA converges faster
method="COBYLA",
# Start with a small relative step because nominal geometry information
# should be fairly accurate to begin with
options=dict(rhobeg=1e-1),
).x
# Do in-place adjustment to all_coils
cals[cal_idx] = 1.0 / np.linalg.norm(zs[cal_idx])
zs[cal_idx] *= cals[cal_idx]
for idx in match_idx[cal_idx]:
_rotate_coil(zs[cal_idx], old_z, all_coils, idx, inplace=True)
# Recalculate S_tot, taking into account rotations
S_tot = _trans_sss_basis(exp, all_coils)
# Report results
old_err = err[cal_idx]
new_err = _data_err(data, S_tot, cals, idx=cal_idx)
angles[cal_idx] = np.abs(
np.rad2deg(_vector_angle(zs[cal_idx], orig_zs[cal_idx]))
)
ch_name = info["ch_names"][cal_idx]
logger.debug(
f" Optimization step {count:3d} "
f"{ch_name} ({counts[cal_idx]}) "
f"res {old_err:5.2f}{new_err:5.2f}% "
f"×{cals[cal_idx, 0]:0.3f} {angles[cal_idx]:0.2f}°"
)
last_err = _data_err(data, S_tot, cals)
# Chunk is usable if all angles and errors are both small
reason = list()
max_angle = np.max(angles)
if max_angle >= angle_limit:
reason.append(f"max angle {max_angle:0.2f} >= {angle_limit:0.1f}°")
each_err = _data_err(data, S_tot, cals, axis=-1)[picks_mag]
n_bad = (each_err > err_limit).sum()
if n_bad:
reason.append(
f"{n_bad} residual{_pl(n_bad)} > {err_limit:0.1f}% "
f"(max: {each_err.max():0.2f}%)"
)
reason = ", ".join(reason)
if reason:
reason = f" ({reason})"
good = not bool(reason)
assert np.allclose(np.linalg.norm(zs, axis=1), 1.0)
logger.info(f" Fit mismatch {first_err:0.2f}{last_err:0.2f}%")
logger.info(f' Data segment {"" if good else "un"}usable{reason}')
# Reformat zs and cals to be the n_mags (including bads)
assert zs.shape == (len(data), 3)
assert cals.shape == (len(data), 1)
imb_cals = np.ones(len(picks_meg))
imb_cals[picks_mag_orig] = cals[:, 0]
return zs, imb_cals, good
def _data_err(data, S_tot, cals, idx=None, axis=None):
if idx is None:
idx = slice(None)
S_tot = S_tot / cals
data_model = np.dot(np.dot(S_tot[idx], _col_norm_pinv(S_tot.copy())[0]), data)
err = 100 * (
np.linalg.norm(data_model - data[idx], axis=axis)
/ np.linalg.norm(data[idx], axis=axis)
)
return err
def _rotate_coil(new_z, old_z, all_coils, idx, inplace=False):
"""Adjust coils."""
# Turn NX and NY to the plane determined by NZ
old_z = old_z / np.linalg.norm(old_z)
new_z = new_z / np.linalg.norm(new_z)
rot = _find_vector_rotation(old_z, new_z) # additional coil rotation
this_sl = all_coils[5][idx]
this_rmag = np.dot(rot, all_coils[0][this_sl].T).T
this_cosmag = np.dot(rot, all_coils[1][this_sl].T).T
if inplace:
all_coils[0][this_sl] = this_rmag
all_coils[1][this_sl] = this_cosmag
subset = (
this_rmag,
this_cosmag,
np.zeros(this_rmag.shape[0], int),
1,
all_coils[4][[idx]],
{0: this_sl},
)
return subset
def _cal_sss_target(
new_z, old_z, all_coils, cal_idx, data, cals, S_tot, origin, ext_order, match_idx
):
"""Evaluate objective function for SSS-based magnetometer calibration."""
cals[cal_idx] = 1.0 / np.linalg.norm(new_z)
exp = dict(int_order=0, ext_order=ext_order, origin=origin)
S_tot = S_tot.copy()
# Rotate necessary coils properly and adjust correct element in c
for idx in match_idx[cal_idx]:
this_coil = _rotate_coil(new_z, old_z, all_coils, idx)
# Replace correct row of S_tot with new value
S_tot[idx] = _trans_sss_basis(exp, this_coil)
# Get the GOF
return _data_err(data, S_tot, cals, idx=cal_idx)
def _estimate_imbalance(info, data, cals, n_imbalance, origin, ext_order):
"""Estimate gradiometer imbalance parameters."""
mag_scale = 100.0
n_iterations = 3
mag_picks = pick_types(info, meg="mag", exclude=())
grad_picks = pick_types(info, meg="grad", exclude=())
data = data.copy()
data[mag_picks, :] *= mag_scale
del mag_picks
grad_imb = np.zeros((len(grad_picks), n_imbalance))
exp = dict(origin=origin, int_order=0, ext_order=ext_order)
all_coils = _prep_mf_coils(info, ignore_ref=True)
grad_point_coils = _get_grad_point_coilsets(info, n_imbalance, ignore_ref=True)
S_orig = _trans_sss_basis(exp, all_coils, coil_scale=mag_scale)
S_orig /= cals[:, np.newaxis]
# Compute point gradiometers for each grad channel
this_cs = np.array([mag_scale], float)
S_pt = np.array(
[_trans_sss_basis(exp, coils, None, this_cs) for coils in grad_point_coils]
)
for k in range(n_iterations):
S_tot = S_orig.copy()
# In theory we could zero out the homogeneous components with:
# S_tot[grad_picks, :3] = 0
# But in practice it doesn't seem to matter
S_recon = S_tot[grad_picks]
# Add influence of point magnetometers
S_tot[grad_picks, :] += np.einsum("ij,ijk->jk", grad_imb.T, S_pt)
# Compute multipolar moments
mm = np.dot(_col_norm_pinv(S_tot.copy())[0], data)
# Use good channels to recalculate
prev_imb = grad_imb.copy()
data_recon = np.dot(S_recon, mm)
assert S_pt.shape == (n_imbalance, len(grad_picks), S_tot.shape[1])
khi_pts = (S_pt @ mm).transpose(1, 2, 0)
assert khi_pts.shape == (len(grad_picks), data.shape[1], n_imbalance)
residual = data[grad_picks] - data_recon
assert residual.shape == (len(grad_picks), data.shape[1])
d = (residual[:, np.newaxis, :] @ khi_pts)[:, 0]
assert d.shape == (len(grad_picks), n_imbalance)
dinv, _, _ = _reg_pinv(khi_pts.swapaxes(-1, -2) @ khi_pts, rcond=1e-6)
assert dinv.shape == (len(grad_picks), n_imbalance, n_imbalance)
grad_imb[:] = (d[:, np.newaxis] @ dinv)[:, 0]
# This code is equivalent but hits a np.linalg.pinv bug on old NumPy:
# grad_imb[:] = np.sum( # dot product across the time dim
# np.linalg.pinv(khi_pts) * residual[:, np.newaxis], axis=-1)
deltas = np.linalg.norm(grad_imb - prev_imb) / max(
np.linalg.norm(grad_imb), np.linalg.norm(prev_imb)
)
logger.debug(
f" Iteration {k + 1}/{n_iterations}: "
f"max ∆ = {100 * deltas.max():7.3f}%"
)
imb = np.zeros((len(data), n_imbalance))
imb[grad_picks] = grad_imb
return imb
def read_fine_calibration(fname):
"""Read fine calibration information from a ``.dat`` file.
The fine calibration typically includes improved sensor locations,
calibration coefficients, and gradiometer imbalance information.
Parameters
----------
fname : path-like
The filename.
Returns
-------
calibration : dict
Fine calibration information. Key-value pairs are:
- ``ch_names``
List of str of the channel names.
- ``locs``
Coil location and orientation parameters.
- ``imb_cals``
For magnetometers, the calibration coefficients.
For gradiometers, one or three imbalance parameters.
"""
# Read new sensor locations
fname = _check_fname(fname, overwrite="read", must_exist=True)
check_fname(fname, "cal", (".dat",))
ch_names, locs, imb_cals = list(), list(), list()
with open(fname) as fid:
for line in fid:
if line[0] in "#\n":
continue
vals = line.strip().split()
if len(vals) not in [14, 16]:
raise RuntimeError(
"Error parsing fine calibration file, "
"should have 14 or 16 entries per line "
f"but found {len(vals)} on line:\n{line}"
)
# `vals` contains channel number
ch_name = vals[0]
if len(ch_name) in (3, 4): # heuristic for Neuromag fix
try:
ch_name = int(ch_name)
except ValueError: # something other than e.g. 113 or 2642
pass
else:
ch_name = f"MEG{int(ch_name):04}"
# (x, y, z), x-norm 3-vec, y-norm 3-vec, z-norm 3-vec
# and 1 or 3 imbalance terms
ch_names.append(ch_name)
locs.append(np.array(vals[1:13], float))
imb_cals.append(np.array(vals[13:], float))
locs = np.array(locs)
return dict(ch_names=ch_names, locs=locs, imb_cals=imb_cals)
def write_fine_calibration(fname, calibration):
"""Write fine calibration information to a ``.dat`` file.
Parameters
----------
fname : path-like
The filename to write out.
calibration : dict
Fine calibration information.
"""
fname = _check_fname(fname, overwrite=True)
check_fname(fname, "cal", (".dat",))
keys = ("ch_names", "locs", "imb_cals")
with open(fname, "wb") as cal_file:
for ch_name, loc, imb_cal in zip(*(calibration[key] for key in keys)):
cal_line = np.concatenate([loc, imb_cal]).round(6)
cal_line = " ".join(f"{c:0.6f}" for c in cal_line)
cal_file.write(f"{ch_name} {cal_line}\n".encode("ASCII"))

mne/preprocessing/_lof.py Normal file
View File

@@ -0,0 +1,98 @@
"""Bad channel detection using Local Outlier Factor (LOF)."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from .._fiff.pick import _picks_to_idx
from ..io.base import BaseRaw
from ..utils import _soft_import, _validate_type, logger, verbose
@verbose
def find_bad_channels_lof(
raw,
n_neighbors=20,
*,
picks=None,
metric="euclidean",
threshold=1.5,
return_scores=False,
verbose=None,
):
"""Find bad channels using Local Outlier Factor (LOF) algorithm.
Parameters
----------
raw : instance of Raw
Raw data to process.
n_neighbors : int
Number of neighbors defining the local neighborhood (default is 20).
Smaller values will lead to higher LOF scores.
%(picks_good_data)s
metric : str
Metric to use for distance computation. Default is "euclidean",
see :func:`sklearn.metrics.pairwise.distance_metrics` for details.
threshold : float
Threshold to define outliers. Theoretical threshold ranges anywhere
between 1.0 and any positive integer. Defaults to 1.5.
It is recommended to treat this as a hyperparameter to optimize.
return_scores : bool
If ``True``, return a dictionary with LOF scores for each
evaluated channel. Default is ``False``.
%(verbose)s
Returns
-------
noisy_chs : list
List of bad M/EEG channels that were automatically detected.
scores : ndarray, shape (n_picks,)
Only returned when ``return_scores`` is ``True``. It contains the
LOF outlier score for each channel in ``picks``.
See Also
--------
maxwell_filter
annotate_amplitude
Notes
-----
See :footcite:`KumaravelEtAl2022` and :footcite:`BreunigEtAl2000` for background on
choosing ``threshold``.
.. versionadded:: 1.7
References
----------
.. footbibliography::
""" # noqa: E501
_soft_import("sklearn", "using LOF detection", strict=True)
from sklearn.neighbors import LocalOutlierFactor
_validate_type(raw, BaseRaw, "raw")
# Get the channel types
channel_types = raw.get_channel_types()
picks = _picks_to_idx(raw.info, picks=picks, none="data", exclude="bads")
picked_ch_types = set(channel_types[p] for p in picks)
# Check if there are different channel types
if len(picked_ch_types) != 1:
raise ValueError(
f"Need exactly one channel type in picks, got {sorted(picked_ch_types)}"
)
ch_names = [raw.ch_names[pick] for pick in picks]
data = raw.get_data(picks=picks)
clf = LocalOutlierFactor(n_neighbors=n_neighbors, metric=metric)
clf.fit_predict(data)
scores_lof = clf.negative_outlier_factor_
bad_channel_indices = [
i for i, v in enumerate(np.abs(scores_lof)) if v >= threshold
]
bads = [ch_names[idx] for idx in bad_channel_indices]
logger.info(f"LOF: Detected bad channel(s): {bads}")
if return_scores:
return bads, scores_lof
else:
return bads
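
A minimal sketch for ``find_bad_channels_lof`` (assumes scikit-learn is installed and ``raw`` holds the data; picks are restricted to a single channel type as required above):

from mne.preprocessing import find_bad_channels_lof

bads, scores = find_bad_channels_lof(
    raw, n_neighbors=20, picks="eeg", threshold=1.5, return_scores=True
)
raw.info["bads"].extend(bads)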

View File

@@ -0,0 +1,184 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ..utils import _pl, logger, verbose
@verbose
def peak_finder(x0, thresh=None, extrema=1, verbose=None):
"""Noise-tolerant fast peak-finding algorithm.
Parameters
----------
x0 : 1d array
A real vector from which the maxima will be found (required).
thresh : float | None
The amount above surrounding data for a peak to be
identified. Larger values mean the algorithm is more selective in
finding peaks. If ``None``, use the default of
``(max(x0) - min(x0)) / 4``.
extrema : {-1, 1}
1 if maxima are desired, -1 if minima are desired
(default = maxima, 1).
%(verbose)s
Returns
-------
peak_loc : array
The indices of the identified peaks in x0.
peak_mag : array
The magnitude of the identified peaks.
Notes
-----
If repeated values are found the first is identified as the peak.
Conversion from initial Matlab code from:
Nathanael C. Yoder (ncyoder@purdue.edu)
Examples
--------
>>> import numpy as np
>>> from mne.preprocessing import peak_finder
>>> t = np.arange(0, 3, 0.01)
>>> x = np.sin(np.pi*t) - np.sin(0.5*np.pi*t)
>>> peak_locs, peak_mags = peak_finder(x) # doctest: +SKIP
>>> peak_locs # doctest: +SKIP
array([36, 260]) # doctest: +SKIP
>>> peak_mags # doctest: +SKIP
array([0.36900026, 1.76007351]) # doctest: +SKIP
"""
x0 = np.asanyarray(x0)
s = x0.size
if x0.ndim >= 2 or s == 0:
raise ValueError("The input data must be a non empty 1D vector")
if thresh is None:
thresh = (np.max(x0) - np.min(x0)) / 4
logger.debug(f"Peak finder automatic threshold: {thresh:0.2g}")
assert extrema in [-1, 1]
if extrema == -1:
x0 = extrema * x0 # Make it so we are finding maxima regardless
dx0 = np.diff(x0) # Find derivative
# This is so we find the first of repeated values
dx0[dx0 == 0] = -np.finfo(float).eps
# Find where the derivative changes sign
ind = np.where(dx0[:-1:] * dx0[1::] < 0)[0] + 1
# Include endpoints in potential peaks and valleys
x = np.concatenate((x0[:1], x0[ind], x0[-1:]))
ind = np.concatenate(([0], ind, [s - 1]))
del x0
# x only has the peaks, valleys, and endpoints
length = x.size
min_mag = np.min(x)
if length > 2: # Function with peaks and valleys
# Set initial parameters for loop
temp_mag = min_mag
found_peak = False
left_min = min_mag
# Deal with the first point a little differently since we tacked it on.
# Calculate the sign of the derivative: since we tacked the first point
# on, it does not necessarily alternate like the rest.
signDx = np.sign(np.diff(x[:3]))
if signDx[0] <= 0: # The first point is larger or equal to the second
ii = -1
if signDx[0] == signDx[1]: # Want alternating signs
x = np.concatenate((x[:1], x[2:]))
ind = np.concatenate((ind[:1], ind[2:]))
length -= 1
else: # First point is smaller than the second
ii = 0
if signDx[0] == signDx[1]: # Want alternating signs
x = x[1:]
ind = ind[1:]
length -= 1
# Preallocate max number of maxima
maxPeaks = int(np.ceil(length / 2.0))
peak_loc = np.zeros(maxPeaks, dtype=np.int64)
peak_mag = np.zeros(maxPeaks)
c_ind = 0
# Loop through extrema which should be peaks and then valleys
while ii < (length - 1):
ii += 1 # This is a peak
# Reset peak finding if we had a peak and the next peak is bigger
# than the last or the left min was small enough to reset.
if found_peak and (
(x[ii] > peak_mag[-1]) or (left_min < peak_mag[-1] - thresh)
):
temp_mag = min_mag
found_peak = False
# Make sure we don't iterate past the length of our vector
if ii == length - 1:
break # We assign the last point differently out of the loop
# Found new peak that was larger than temp mag and threshold larger
# than the minimum to its left.
if (x[ii] > temp_mag) and (x[ii] > left_min + thresh):
temp_loc = ii
temp_mag = x[ii]
ii += 1 # Move onto the valley
# Come down at least thresh from peak
if not found_peak and (temp_mag > (thresh + x[ii])):
found_peak = True # We have found a peak
left_min = x[ii]
peak_loc[c_ind] = temp_loc # Add peak to index
peak_mag[c_ind] = temp_mag
c_ind += 1
elif x[ii] < left_min: # New left minima
left_min = x[ii]
# Check end point
if (x[-1] > temp_mag) and (x[-1] > (left_min + thresh)):
peak_loc[c_ind] = length - 1
peak_mag[c_ind] = x[-1]
c_ind += 1
elif not found_peak and temp_mag > min_mag:
# Check if we still need to add the last point
peak_loc[c_ind] = temp_loc
peak_mag[c_ind] = temp_mag
c_ind += 1
# Create output
peak_inds = ind[peak_loc[:c_ind]]
peak_mags = peak_mag[:c_ind]
else: # This is a monotone function where an endpoint is the only peak
x_ind = np.argmax(x)
peak_mags = x[x_ind]
if peak_mags > (min_mag + thresh):
peak_inds = ind[x_ind]
else:
peak_mags = []
peak_inds = []
# Change sign of data if was finding minima
if extrema < 0:
peak_mags *= -1.0
# ensure output type array
if not isinstance(peak_inds, np.ndarray):
peak_inds = np.atleast_1d(peak_inds).astype("int64")
if not isinstance(peak_mags, np.ndarray):
peak_mags = np.atleast_1d(peak_mags).astype("float64")
# Plot if no output desired
if len(peak_inds) == 0:
logger.info("No significant peaks found")
else:
logger.info(f"Found {len(peak_inds)} significant peak{_pl(peak_inds)}")
return peak_inds, peak_mags

View File

@@ -0,0 +1,390 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from .._fiff.pick import _picks_to_idx, pick_info
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
from ..epochs import BaseEpochs
from ..evoked import Evoked
from ..io import BaseRaw
from ..minimum_norm.inverse import _needs_eeg_average_ref_proj
from ..utils import (
_check_fname,
_check_option,
_check_preload,
_import_h5io_funcs,
_validate_type,
copy_function_doc_to_method_doc,
fill_doc,
verbose,
)
from ..viz import plot_regression_weights
@verbose
def regress_artifact(
inst,
picks=None,
*,
exclude="bads",
picks_artifact="eog",
betas=None,
proj=True,
copy=True,
verbose=None,
):
"""Remove artifacts using regression based on reference channels.
Parameters
----------
inst : instance of Epochs | Raw
The instance to process.
%(picks_good_data)s
exclude : list | 'bads'
List of channels to exclude from the regression, only used when picking
based on types (e.g., exclude="bads" when picks="meg").
Specify ``'bads'`` (the default) to exclude all channels marked as bad.
.. versionadded:: 1.2
picks_artifact : array-like | str
Channel picks to use as predictor/explanatory variables capturing
the artifact of interest (default is "eog").
betas : ndarray, shape (n_picks, n_picks_ref) | None
The regression coefficients to use. If None (default), they will be
estimated from the data.
proj : bool
Whether to automatically apply SSP projection vectors before performing
the regression. Default is ``True``.
copy : bool
If True (default), copy the instance before modifying it.
%(verbose)s
Returns
-------
inst : instance of Epochs | Raw
The processed data.
betas : ndarray, shape (n_picks, n_picks_ref)
The betas used during regression.
Notes
-----
To implement the method outlined in :footcite:`GrattonEtAl1983`,
remove the evoked response from epochs before estimating the
regression coefficients, then apply those regression coefficients to the
original data in two calls like (here for a single-condition ``epochs``
only):
>>> epochs_no_ave = epochs.copy().subtract_evoked() # doctest:+SKIP
>>> _, betas = mne.preprocessing.regress(epochs_no_ave) # doctest:+SKIP
>>> epochs_clean, _ = mne.preprocessing.regress(epochs, betas=betas) # doctest:+SKIP
References
----------
.. footbibliography::
""" # noqa: E501
if betas is None:
model = EOGRegression(
picks=picks, exclude=exclude, picks_artifact=picks_artifact, proj=proj
)
model.fit(inst)
else:
# Create an EOGRegression object and load the given betas into it.
picks = _picks_to_idx(inst.info, picks, exclude=exclude, none="data")
picks_artifact = _picks_to_idx(inst.info, picks_artifact)
want_betas_shape = (len(picks), len(picks_artifact))
_check_option("betas.shape", betas.shape, (want_betas_shape,))
model = EOGRegression(picks, picks_artifact, proj=proj)
model.info_ = inst.info.copy()
model.coef_ = betas
return model.apply(inst, copy=copy), model.coef_
@fill_doc
class EOGRegression:
"""Remove EOG artifact signals from other channels by regression.
Employs linear regression to remove signals captured by some channels,
typically EOG, as described in :footcite:`GrattonEtAl1983`. You can also
choose to fit the regression coefficients on evoked blink/saccade data and
then apply them to continuous data, as described in
:footcite:`CroftBarry2000`.
Parameters
----------
%(picks_good_data)s
exclude : list | 'bads'
List of channels to exclude from the regression, only used when picking
based on types (e.g., exclude="bads" when picks="meg").
Specify ``'bads'`` (the default) to exclude all channels marked as bad.
picks_artifact : array-like | str
Channel picks to use as predictor/explanatory variables capturing
the artifact of interest (default is "eog").
proj : bool
Whether to automatically apply SSP projection vectors before fitting
and applying the regression. Default is ``True``.
Attributes
----------
coef_ : ndarray, shape (n, n)
The regression coefficients. Only available after fitting.
info_ : Info
Channel information corresponding to the regression weights.
Only available after fitting.
picks : array-like | str
Channels to perform the regression on.
exclude : list | 'bads'
Channels to exclude from the regression.
picks_artifact : array-like | str
The channels designated as containing the artifacts of interest.
proj : bool
Whether projections will be applied before performing the regression.
Notes
-----
.. versionadded:: 1.2
References
----------
.. footbibliography::
"""
def __init__(self, picks=None, exclude="bads", picks_artifact="eog", proj=True):
self.picks = picks
self.exclude = exclude
self.picks_artifact = picks_artifact
self.proj = proj
def fit(self, inst):
"""Fit EOG regression coefficients.
Parameters
----------
inst : Raw | Epochs | Evoked
The data on which the EOG regression weights should be fitted.
Returns
-------
self : EOGRegression
The fitted ``EOGRegression`` object. The regression coefficients
are available as the ``.coef_`` attribute.
Notes
-----
If your data contains EEG channels, make sure to apply the desired
reference (see :func:`mne.set_eeg_reference`) before performing EOG
regression.
"""
picks, picks_artifact = self._check_inst(inst)
# Calculate regression coefficients. The data are mean-centered over time so
# that the intercept is handled implicitly.
_check_preload(inst, "artifact regression")
artifact_data = inst._data[..., picks_artifact, :]
ref_data = artifact_data - np.mean(artifact_data, axis=-1, keepdims=True)
if ref_data.ndim == 3:
ref_data = ref_data.transpose(1, 0, 2)
ref_data = ref_data.reshape(len(picks_artifact), -1)
cov_ref = ref_data @ ref_data.T
# Process each channel separately to reduce memory load
coef = np.zeros((len(picks), len(picks_artifact)))
for pi, pick in enumerate(picks):
this_data = inst._data[..., pick, :] # view
# Subtract mean over time from every trial/channel
cov_data = this_data - np.mean(this_data, -1, keepdims=True)
cov_data = cov_data.reshape(1, -1)
# Perform the linear regression
coef[pi] = np.linalg.solve(cov_ref, ref_data @ cov_data.T).T[0]
# Store relevant parameters in the object.
self.coef_ = coef
self.info_ = inst.info.copy()
return self
@fill_doc
def apply(self, inst, copy=True):
"""Apply the regression coefficients to data.
Parameters
----------
inst : Raw | Epochs | Evoked
The data on which to apply the regression.
%(copy_df)s
Returns
-------
inst : Raw | Epochs | Evoked
A version of the data with the artifact channels regressed out.
Notes
-----
Only works after ``.fit()`` has been used.
References
----------
.. footbibliography::
"""
if copy:
inst = inst.copy()
picks, picks_artifact = self._check_inst(inst)
# Check that the channels are compatible with the regression weights.
ref_picks = _picks_to_idx(
self.info_, self.picks, none="data", exclude=self.exclude
)
ref_picks_artifact = _picks_to_idx(self.info_, self.picks_artifact)
if any(
inst.ch_names[ch1] != self.info_["chs"][ch2]["ch_name"]
for ch1, ch2 in zip(picks, ref_picks)
):
raise ValueError(
"Selected data channels are not compatible with "
"the regression weights. Make sure that all data "
"channels are present and in the correct order."
)
if any(
inst.ch_names[ch1] != self.info_["chs"][ch2]["ch_name"]
for ch1, ch2 in zip(picks_artifact, ref_picks_artifact)
):
raise ValueError(
"Selected artifact channels are not compatible "
"with the regression weights. Make sure that all "
"artifact channels are present and in the "
"correct order."
)
_check_preload(inst, "artifact regression")
artifact_data = inst._data[..., picks_artifact, :]
ref_data = artifact_data - np.mean(artifact_data, -1, keepdims=True)
for pi, pick in enumerate(picks):
this_data = inst._data[..., pick, :] # view
this_data -= (self.coef_[pi] @ ref_data).reshape(this_data.shape)
return inst
@copy_function_doc_to_method_doc(plot_regression_weights)
def plot(
self,
ch_type=None,
sensors=True,
show_names=False,
mask=None,
mask_params=None,
contours=6,
outlines="head",
sphere=None,
image_interp=_INTERPOLATION_DEFAULT,
extrapolate=_EXTRAPOLATE_DEFAULT,
border=_BORDER_DEFAULT,
res=64,
size=1,
cmap=None,
vlim=(None, None),
cnorm=None,
axes=None,
colorbar=True,
cbar_fmt="%1.1e",
title=None,
show=True,
):
return plot_regression_weights(
self,
ch_type=ch_type,
sensors=sensors,
show_names=show_names,
mask=mask,
mask_params=mask_params,
contours=contours,
outlines=outlines,
sphere=sphere,
image_interp=image_interp,
extrapolate=extrapolate,
border=border,
res=res,
size=size,
cmap=cmap,
vlim=vlim,
cnorm=cnorm,
axes=axes,
colorbar=colorbar,
cbar_fmt=cbar_fmt,
title=title,
show=show,
)
def _check_inst(self, inst):
"""Perform some sanity checks on the input."""
_validate_type(
inst, (BaseRaw, BaseEpochs, Evoked), "inst", "Raw, Epochs, Evoked"
)
picks = _picks_to_idx(inst.info, self.picks, none="data", exclude=self.exclude)
picks_artifact = _picks_to_idx(inst.info, self.picks_artifact)
all_picks = np.unique(np.concatenate([picks, picks_artifact]))
use_info = pick_info(inst.info, all_picks)
del all_picks
if _needs_eeg_average_ref_proj(use_info):
raise RuntimeError(
"No average reference for the EEG channels has been "
"set. Use inst.set_eeg_reference(projection=True) to do so."
)
if self.proj and not inst.proj:
inst.apply_proj()
if not inst.proj and len(use_info.get("projs", [])) > 0:
raise RuntimeError(
"Projections need to be applied before "
"regression can be performed. Use the "
".apply_proj() method to do so."
)
return picks, picks_artifact
def __repr__(self):
"""Produce a string representation of this object."""
s = "<EOGRegression | "
if hasattr(self, "coef_"):
n_art = self.coef_.shape[1]
plural = "s" if n_art > 1 else ""
s += f"fitted to {n_art} artifact channel{plural}>"
else:
s += "not fitted>"
return s
@fill_doc
def save(self, fname, overwrite=False):
"""Save the regression model to an HDF5 file.
Parameters
----------
fname : path-like
The file to write the regression weights to. Should end in ``.h5``.
%(overwrite)s
"""
_, write_hdf5 = _import_h5io_funcs()
_validate_type(fname, "path-like", "fname")
fname = _check_fname(fname, overwrite=overwrite, name="fname")
write_hdf5(fname, self.__dict__, overwrite=overwrite)
def read_eog_regression(fname):
"""Read an EOG regression model from an HDF5 file.
Parameters
----------
fname : path-like
The file to read the regression model from. Should end in ``.h5``.
Returns
-------
model : EOGRegression
The regression model read from the file.
Notes
-----
.. versionadded:: 1.2
"""
read_hdf5, _ = _import_h5io_funcs()
_validate_type(fname, "path-like", "fname")
fname = _check_fname(fname, overwrite="read", must_exist=True, name="fname")
model = EOGRegression()
model.__dict__.update(read_hdf5(fname))
return model
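# A minimal usage sketch for the class above (not a definitive recipe). It
# assumes ``raw`` below points at a preloaded Raw file with EEG and EOG
# channels; the file names are hypothetical, and the average-reference
# projection is required before ``EOGRegression.fit`` on EEG data.
import mne
from mne.preprocessing import EOGRegression, read_eog_regression

raw = mne.io.read_raw_fif("subject01_raw.fif", preload=True)  # hypothetical path
raw.set_eeg_reference("average", projection=True)
weights = EOGRegression(picks="eeg", picks_artifact="eog").fit(raw)
raw_clean = weights.apply(raw, copy=True)
weights.plot()  # topographic map of the regression coefficients
weights.save("eog_regression-weights.h5", overwrite=True)  # requires h5io
weights_loaded = read_eog_regression("eog_regression-weights.h5")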

View File

@@ -0,0 +1,655 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from scipy.ndimage import distance_transform_edt, label
from scipy.signal import find_peaks
from scipy.stats import zscore
from ..annotations import (
Annotations,
_adjust_onset_meas_date,
_annotations_starts_stops,
annotations_from_events,
)
from ..filter import filter_data
from ..io.base import BaseRaw
from ..transforms import (
Transform,
_angle_between_quats,
_average_quats,
_quat_to_affine,
apply_trans,
quat_to_rot,
)
from ..utils import (
_check_option,
_mask_to_onsets_offsets,
_pl,
_validate_type,
logger,
verbose,
warn,
)
@verbose
def annotate_muscle_zscore(
raw,
threshold=4,
ch_type=None,
min_length_good=0.1,
filter_freq=(110, 140),
n_jobs=None,
verbose=None,
):
"""Create annotations for segments that likely contain muscle artifacts.
Detects data segments containing activity in the frequency range given by
``filter_freq`` whose envelope magnitude exceeds the specified z-score
threshold, when summed across channels and divided by ``sqrt(n_channels)``.
False-positive transient peaks are prevented by low-pass filtering the
resulting z-score time series at 4 Hz. Only a single channel type is used;
if ``ch_type`` is ``None``, the first available type among ``mag``, ``grad``,
and ``eeg`` is selected.
See :footcite:`Muthukumaraswamy2013` for background on choosing
``filter_freq`` and ``threshold``.
Parameters
----------
raw : instance of Raw
Data to estimate segments with muscle artifacts.
threshold : float
The threshold in z-scores for marking segments as containing muscle
activity artifacts.
ch_type : 'mag' | 'grad' | 'eeg' | None
The type of sensors to use. If ``None`` it will take the first type in
``mag``, ``grad``, ``eeg``.
min_length_good : float | None
The shortest allowed duration of "good data" (in seconds) between
adjacent annotations; shorter segments will be incorporated into the
surrounding annotations. ``None`` is equivalent to ``0``.
Default is ``0.1``.
filter_freq : array-like, shape (2,)
The lower and upper frequencies of the band-pass filter.
Default is ``(110, 140)``.
%(n_jobs)s
%(verbose)s
Returns
-------
annot : mne.Annotations
Periods with muscle artifacts annotated as BAD_muscle.
scores_muscle : array
Z-score values averaged across channels for each sample.
References
----------
.. footbibliography::
"""
raw_copy = raw.copy()
if ch_type is None:
raw_ch_type = raw_copy.get_channel_types()
if "mag" in raw_ch_type:
ch_type = "mag"
elif "grad" in raw_ch_type:
ch_type = "grad"
elif "eeg" in raw_ch_type:
ch_type = "eeg"
else:
raise ValueError(
"No M/EEG channel types found, please specify a 'ch_type' or provide "
"M/EEG sensor data."
)
logger.info("Using %s sensors for muscle artifact detection", ch_type)
else:
_check_option("ch_type", ch_type, ["mag", "grad", "eeg"])
raw_copy.pick(ch_type)
raw_copy.filter(
filter_freq[0],
filter_freq[1],
fir_design="firwin",
pad="reflect_limited",
n_jobs=n_jobs,
)
raw_copy.apply_hilbert(envelope=True, n_jobs=n_jobs)
data = raw_copy.get_data(reject_by_annotation="NaN")
nan_mask = ~np.isnan(data[0])
sfreq = raw_copy.info["sfreq"]
art_scores = zscore(data[:, nan_mask], axis=1)
art_scores = art_scores.sum(axis=0) / np.sqrt(art_scores.shape[0])
art_scores = filter_data(art_scores, sfreq, None, 4)
scores_muscle = np.zeros(data.shape[1])
scores_muscle[nan_mask] = art_scores
art_mask = scores_muscle > threshold
# return muscle scores with NaNs
scores_muscle[~nan_mask] = np.nan
# remove artifact free periods shorter than min_length_good
min_length_good = 0 if min_length_good is None else min_length_good
min_samps = min_length_good * sfreq
comps, num_comps = label(art_mask == 0)
for com in range(1, num_comps + 1):
l_idx = np.nonzero(comps == com)[0]
if len(l_idx) < min_samps:
art_mask[l_idx] = True
annot = _annotations_from_mask(
raw_copy.times, art_mask, "BAD_muscle", orig_time=raw.info["meas_date"]
)
_adjust_onset_meas_date(annot, raw)
return annot, scores_muscle
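# Hedged usage sketch for the function above, assuming ``raw`` is a preloaded
# Raw instance containing MEG or EEG data; the threshold is illustrative and
# should be tuned per recording.
from mne.preprocessing import annotate_muscle_zscore

muscle_annot, muscle_scores = annotate_muscle_zscore(
    raw, ch_type="mag", threshold=5, min_length_good=0.2, filter_freq=(110, 140)
)
raw.set_annotations(raw.annotations + muscle_annot)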
def annotate_movement(
raw,
pos,
rotation_velocity_limit=None,
translation_velocity_limit=None,
mean_distance_limit=None,
use_dev_head_trans="average",
):
"""Detect segments with movement.
Detects segments in which the head motion exceeds ``rotation_velocity_limit``,
``translation_velocity_limit``, or ``mean_distance_limit``, and returns
annotations marking those segments as bad.
Parameters
----------
raw : instance of Raw
Data to compute head position.
pos : array, shape (N, 10)
The position and quaternion parameters from cHPI fitting. Obtained
with `mne.chpi` functions.
rotation_velocity_limit : float
Head rotation velocity limit in degrees per second.
translation_velocity_limit : float
Head translation velocity limit in meters per second.
mean_distance_limit : float
Head position limit from mean recording in meters.
use_dev_head_trans : 'average' (default) | 'info'
Identify the device to head transform used to define the
fixed HPI locations for computing moving distances.
If ``average`` the average device to head transform is
computed using ``compute_average_dev_head_t``.
If ``info``, ``raw.info['dev_head_t']`` is used.
Returns
-------
annot : mne.Annotations
Periods with head motion.
hpi_disp : array
Head position over time with respect to the mean head pos.
See Also
--------
compute_average_dev_head_t
"""
sfreq = raw.info["sfreq"]
hp_ts = pos[:, 0].copy() - raw.first_time
dt = np.diff(hp_ts)
hp_ts = np.concatenate([hp_ts, [hp_ts[-1] + 1.0 / sfreq]])
orig_time = raw.info["meas_date"]
annot = Annotations([], [], [], orig_time=orig_time)
# Annotate based on rotational velocity
t_tot = raw.times[-1]
if rotation_velocity_limit is not None:
assert rotation_velocity_limit > 0
# Rotational velocity (radians / s)
r = _angle_between_quats(pos[:-1, 1:4], pos[1:, 1:4])
r /= dt
bad_mask = r >= np.deg2rad(rotation_velocity_limit)
onsets, offsets = _mask_to_onsets_offsets(bad_mask)
onsets, offsets = hp_ts[onsets], hp_ts[offsets]
bad_pct = 100 * (offsets - onsets).sum() / t_tot
logger.info(
"Omitting %5.1f%% (%3d segments): " "ω >= %5.1f°/s (max: %0.1f°/s)",
bad_pct,
len(onsets),
rotation_velocity_limit,
np.rad2deg(r.max()),
)
annot += _annotations_from_mask(
hp_ts, bad_mask, "BAD_mov_rotat_vel", orig_time=orig_time
)
# Annotate based on translational velocity limit
if translation_velocity_limit is not None:
assert translation_velocity_limit > 0
v = np.linalg.norm(np.diff(pos[:, 4:7], axis=0), axis=-1)
v /= dt
bad_mask = v >= translation_velocity_limit
onsets, offsets = _mask_to_onsets_offsets(bad_mask)
onsets, offsets = hp_ts[onsets], hp_ts[offsets]
bad_pct = 100 * (offsets - onsets).sum() / t_tot
logger.info(
"Omitting %5.1f%% (%3d segments): " "v >= %5.4fm/s (max: %5.4fm/s)",
bad_pct,
len(onsets),
translation_velocity_limit,
v.max(),
)
annot += _annotations_from_mask(
hp_ts, bad_mask, "BAD_mov_trans_vel", orig_time=orig_time
)
# Annotate based on displacement from mean head position
disp = []
if mean_distance_limit is not None:
assert mean_distance_limit > 0
# compute dev to head transform for fixed points
use_dev_head_trans = use_dev_head_trans.lower()
if use_dev_head_trans not in ["average", "info"]:
raise ValueError(
"use_dev_head_trans must be either"
f" 'average' or 'info': got '{use_dev_head_trans}'"
)
if use_dev_head_trans == "average":
fixed_dev_head_t = compute_average_dev_head_t(raw, pos)
elif use_dev_head_trans == "info":
fixed_dev_head_t = raw.info["dev_head_t"]
# Get static head pos from file, used to convert quat to cartesian
chpi_pos = sorted(
[d for d in raw.info["hpi_results"][-1]["dig_points"]],
key=lambda x: x["ident"],
)
chpi_pos = np.array([d["r"] for d in chpi_pos])
# Get head pos changes during recording
chpi_pos_mov = np.array(
[apply_trans(_quat_to_affine(quat), chpi_pos) for quat in pos[:, 1:7]]
)
# get fixed position
chpi_pos_fix = apply_trans(fixed_dev_head_t, chpi_pos)
# get movement displacement from mean pos
hpi_disp = chpi_pos_mov - np.tile(chpi_pos_fix, (pos.shape[0], 1, 1))
# get positions above threshold distance
disp = np.sqrt((hpi_disp**2).sum(axis=2))
bad_mask = np.any(disp > mean_distance_limit, axis=1)
onsets, offsets = _mask_to_onsets_offsets(bad_mask)
onsets, offsets = hp_ts[onsets], hp_ts[offsets]
bad_pct = 100 * (offsets - onsets).sum() / t_tot
logger.info(
"Omitting %5.1f%% (%3d segments): " "disp >= %5.4fm (max: %5.4fm)",
bad_pct,
len(onsets),
mean_distance_limit,
disp.max(),
)
annot += _annotations_from_mask(
hp_ts, bad_mask, "BAD_mov_dist", orig_time=orig_time
)
_adjust_onset_meas_date(annot, raw)
return annot, disp
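# Hedged usage sketch, assuming ``pos`` holds cHPI head positions of shape
# (N, 10) (e.g., from mne.chpi.compute_head_pos) and ``raw`` is the matching
# Raw instance; all limits below are illustrative.
from mne.preprocessing import annotate_movement

movement_annot, hpi_disp = annotate_movement(
    raw,
    pos,
    rotation_velocity_limit=30.0,  # degrees / s
    translation_velocity_limit=0.02,  # m / s
    mean_distance_limit=0.005,  # m from the mean head position
)
raw.set_annotations(raw.annotations + movement_annot)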
@verbose
def compute_average_dev_head_t(raw, pos, *, verbose=None):
"""Get new device to head transform based on good segments.
Segments starting with "BAD" annotations are not included for calculating
the mean head position.
Parameters
----------
raw : instance of Raw | list of Raw
Data to compute head position. Can be a list containing multiple raw
instances.
pos : array, shape (N, 10) | list of ndarray
The position and quaternion parameters from cHPI fitting. Can be
a list containing multiple position arrays, one per raw instance passed.
%(verbose)s
Returns
-------
dev_head_t : instance of Transform
New ``dev_head_t`` transformation using the averaged good head positions.
Notes
-----
.. versionchanged:: 1.7
Support for multiple raw instances and position arrays was added.
"""
# Get weighted head pos trans and rot
if not isinstance(raw, list | tuple):
raw = [raw]
if not isinstance(pos, list | tuple):
pos = [pos]
if len(pos) != len(raw):
raise ValueError(
f"Number of head positions ({len(pos)}) must match the number of raw "
f"instances ({len(raw)})"
)
hp = list()
dt = list()
for ri, (r, p) in enumerate(zip(raw, pos)):
_validate_type(r, BaseRaw, f"raw[{ri}]")
_validate_type(p, np.ndarray, f"pos[{ri}]")
hp_, dt_ = _raw_hp_weights(r, p)
hp.append(hp_)
dt.append(dt_)
hp = np.concatenate(hp, axis=0)
dt = np.concatenate(dt, axis=0)
dt /= dt.sum()
best_q = _average_quats(hp[:, 1:4], weights=dt)
trans = np.eye(4)
trans[:3, :3] = quat_to_rot(best_q)
trans[:3, 3] = dt @ hp[:, 4:7]
dist = np.linalg.norm(trans[:3, 3])
if dist > 1: # less than 1 meter is sane
warn(f"Implausible head position detected: {dist} meters from device origin")
dev_head_t = Transform("meg", "head", trans)
return dev_head_t
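# Hedged sketch of the multi-run form added in 1.7: average the device-to-head
# transform over two runs while ignoring spans annotated as BAD. The names
# ``raw_run1``/``pos_run1`` etc. are assumed to exist; the resulting Transform
# can then serve, e.g., as a common destination when Maxwell filtering.
from mne.preprocessing import compute_average_dev_head_t

avg_dev_head_t = compute_average_dev_head_t(
    [raw_run1, raw_run2], [pos_run1, pos_run2]
)
print(avg_dev_head_t)  # inspect the averaged transform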
def _raw_hp_weights(raw, pos):
sfreq = raw.info["sfreq"]
seg_good = np.ones(len(raw.times))
hp = pos.copy()
hp_ts = hp[:, 0] - raw._first_time
# Check rounding issues at 0 time
if hp_ts[0] < 0:
hp_ts[0] = 0
assert hp_ts[1] > 1.0 / sfreq
# Mask out segments if beyond scan time
mask = hp_ts <= raw.times[-1]
if not mask.all():
logger.info(
" Removing %d samples > raw.times[-1] (%s)",
np.sum(~mask),
raw.times[-1],
)
hp = hp[mask]
del mask, hp_ts
# Get time indices
ts = np.concatenate((hp[:, 0], [(raw.last_samp + 1) / sfreq]))
assert (np.diff(ts) > 0).all()
ts -= raw.first_samp / sfreq
idx = raw.time_as_index(ts, use_rounding=True)
del ts
if idx[0] == -1: # annoying rounding errors
idx[0] = 0
assert idx[1] > 0
assert (idx >= 0).all()
assert idx[-1] == len(seg_good)
assert (np.diff(idx) > 0).all()
# Mark times bad that are bad according to annotations
onsets, ends = _annotations_starts_stops(raw, "bad")
for onset, end in zip(onsets, ends):
seg_good[onset:end] = 0
dt = np.diff(np.cumsum(np.concatenate([[0], seg_good]))[idx])
assert (dt >= 0).all()
dt = dt / sfreq
del seg_good, idx
return hp, dt
def _annotations_from_mask(times, mask, annot_name, orig_time=None):
"""Construct annotations from boolean mask of the data."""
mask_tf = distance_transform_edt(mask)
# Overcome the shortcoming of find_peaks
# in finding a marginal peak, by
# inserting 0s at the front and the
# rear, then subtracting in index
ins_mask_tf = np.concatenate((np.zeros(1), mask_tf, np.zeros(1)))
left_midpt_index = find_peaks(ins_mask_tf)[0] - 1
right_midpt_index = (
np.flip(len(ins_mask_tf) - 1 - find_peaks(ins_mask_tf[::-1])[0]) - 1
)
onsets_index = left_midpt_index - mask_tf[left_midpt_index].astype(int) + 1
ends_index = right_midpt_index + mask_tf[right_midpt_index].astype(int)
# Ensure onsets_index >= 0,
# otherwise the duration starts from the beginning
onsets_index[onsets_index < 0] = 0
# Ensure ends_index < len(times),
# otherwise the duration is to the end of times
if len(times) == len(mask):
ends_index[ends_index >= len(times)] = len(times) - 1
# To be consistent with the original code,
# possibly a bug in tests code
else:
ends_index[ends_index >= len(mask)] = len(mask)
onsets = times[onsets_index]
ends = times[ends_index]
durations = ends - onsets
desc = [annot_name] * len(durations)
return Annotations(onsets, durations, desc, orig_time=orig_time)
@verbose
def annotate_break(
raw,
events=None,
min_break_duration=15.0,
t_start_after_previous=5.0,
t_stop_before_next=5.0,
ignore=("bad", "edge"),
*,
verbose=None,
):
"""Create `~mne.Annotations` for breaks in an ongoing recording.
This function first searches for segments in the data that are not
annotated or do not contain any events and are at least
``min_break_duration`` seconds long, and then proceeds to creating
annotations for those break periods.
Parameters
----------
raw : instance of Raw
The continuous data to analyze.
events : None | array, shape (n_events, 3)
If ``None`` (default), operate based solely on the annotations present
in ``raw``. If an events array, ignore any annotations in the raw data,
and operate based on these events only.
min_break_duration : float
The minimum time span in seconds between the offset of one and the
onset of the subsequent annotation (if ``events`` is ``None``) or
between two consecutive events (if ``events`` is an array) to consider
this period a "break". Defaults to 15 seconds.
.. note:: This value defines the minimum duration of a break period in
the data, **not** the minimum duration of the generated
annotations! See also ``t_start_after_previous`` and
``t_stop_before_next`` for details.
t_start_after_previous, t_stop_before_next : float
Specifies how far the to-be-created "break" annotation extends towards
the two annotations or events spanning the break. This can be used to
ensure e.g. that the break annotation doesn't start and end immediately
with a stimulation event. If, for example, your data contains a break
of 30 seconds between two stimuli, and ``t_start_after_previous`` is
set to ``5`` and ``t_stop_before_next`` is set to ``3``, the break
annotation will start 5 seconds after the first stimulus, and end 3
seconds before the second stimulus, yielding an annotated break of
``30 - 5 - 3 = 22`` seconds. Both default to 5 seconds.
.. note:: The beginning and the end of the recording will be annotated
as breaks, too, if the period from recording start until the
first annotation or event (or from last annotation or event
until recording end) is at least ``min_break_duration``
seconds long.
ignore : iterable of str
Annotation descriptions starting with these strings will be ignored by
the break-finding algorithm. The string comparison is case-insensitive,
i.e., ``('bad',)`` and ``('BAD',)`` are equivalent. By default, all
annotation descriptions starting with "bad" and annotations
indicating "edges" (produced by data concatenation) will be
ignored. Pass an empty list or tuple to take all existing annotations
into account. If ``events`` is passed, this parameter has no effect.
%(verbose)s
Returns
-------
break_annotations : instance of Annotations
The break annotations, each with the description ``'BAD_break'``. If
no breaks could be found given the provided function parameters, an
empty `~mne.Annotations` object will be returned.
Notes
-----
.. versionadded:: 0.24
"""
_validate_type(item=raw, item_name="raw", types=BaseRaw, type_name="Raw")
_validate_type(item=events, item_name="events", types=(None, np.ndarray))
if min_break_duration - t_start_after_previous - t_stop_before_next <= 0:
annot_dur = min_break_duration - t_start_after_previous - t_stop_before_next
raise ValueError(
f"The result of "
f"min_break_duration - t_start_after_previous - "
f"t_stop_before_next must be greater than 0, but it is: "
f"{annot_dur}"
)
if events is not None and events.size == 0:
raise ValueError("The events array must not be empty.")
if events is not None or not ignore:
ignore = tuple()
else:
ignore = tuple(ignore)
for item in ignore:
_validate_type(item=item, types="str", item_name='All elements of "ignore"')
if events is None:
annotations = raw.annotations.copy()
if ignore:
logger.info(
f"Ignoring annotations with descriptions starting "
f'with: {", ".join(ignore)}'
)
else:
annotations = annotations_from_events(
events=events, sfreq=raw.info["sfreq"], orig_time=raw.info["meas_date"]
)
if not annotations:
raise ValueError("Could not find (or generate) any annotations in your data.")
# Only keep annotations of interest and extract annotated time periods
# Ignore case
ignore = tuple(i.lower() for i in ignore)
keep_mask = [True] * len(annotations)
for idx, description in enumerate(annotations.description):
description = description.lower()
if any(description.startswith(i) for i in ignore):
keep_mask[idx] = False
annotated_intervals = [
[onset, onset + duration]
for onset, duration in zip(
annotations.onset[keep_mask], annotations.duration[keep_mask]
)
]
# Merge overlapping annotation intervals
# Pre-load `merged_intervals` with the first interval to simplify
# processing
merged_intervals = [annotated_intervals[0]]
for interval in annotated_intervals:
merged_interval_stop = merged_intervals[-1][1]
interval_start, interval_stop = interval
if interval_stop < merged_interval_stop:
# Current interval ends sooner than the merged one; skip it
continue
elif (
interval_start <= merged_interval_stop
and interval_stop >= merged_interval_stop
):
# Expand duration of the merged interval
merged_intervals[-1][1] = interval_stop
else:
# No overlap between the current interval and the existing merged
# time period; proceed to the next interval
merged_intervals.append(interval)
merged_intervals = np.array(merged_intervals)
merged_intervals -= raw.first_time # work in zero-based time
# Now extract the actual break periods
break_onsets = []
break_durations = []
# Handle the time period up until the first annotation
if 0 < merged_intervals[0][0] and merged_intervals[0][0] >= min_break_duration:
onset = 0 # don't add t_start_after_previous here
offset = merged_intervals[0][0] - t_stop_before_next
duration = offset - onset
break_onsets.append(onset)
break_durations.append(duration)
# Handle the time period between first and last annotation
for idx, _ in enumerate(merged_intervals[1:, :], start=1):
this_start = merged_intervals[idx, 0]
previous_stop = merged_intervals[idx - 1, 1]
if this_start - previous_stop < min_break_duration:
continue
onset = previous_stop + t_start_after_previous
offset = this_start - t_stop_before_next
duration = offset - onset
break_onsets.append(onset)
break_durations.append(duration)
# Handle the time period after the last annotation
if (
raw.times[-1] > merged_intervals[-1][1]
and raw.times[-1] - merged_intervals[-1][1] >= min_break_duration
):
onset = merged_intervals[-1][1] + t_start_after_previous
offset = raw.times[-1] # don't subtract t_stop_before_next here
duration = offset - onset
break_onsets.append(onset)
break_durations.append(duration)
# Finally, create the break annotations
break_annotations = Annotations(
onset=break_onsets,
duration=break_durations,
description=["BAD_break"],
orig_time=raw.info["meas_date"],
)
# Log some info
n_breaks = len(break_annotations)
break_times = [
f"{o:.1f} {o + d:.1f} s [{d:.1f} s]"
for o, d in zip(break_annotations.onset, break_annotations.duration)
]
break_times = "\n ".join(break_times)
total_break_dur = sum(break_annotations.duration)
fraction_breaks = total_break_dur / raw.times[-1]
logger.info(
f"\nDetected {n_breaks} break period{_pl(n_breaks)} of >= "
f"{min_break_duration} s duration:\n {break_times}\n"
f"In total, {round(100 * fraction_breaks, 1):.1f}% of the "
f"data ({round(total_break_dur, 1):.1f} s) have been marked "
f"as a break.\n"
)
_adjust_onset_meas_date(break_annotations, raw)
return break_annotations
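# Hedged usage sketch: annotate breaks of at least 30 s between stimulus
# events. The stim channel name is hypothetical; without ``events`` the
# function instead works from the annotations already present in ``raw``.
import mne
from mne.preprocessing import annotate_break

events = mne.find_events(raw, stim_channel="STI 014")  # hypothetical channel name
break_annot = annotate_break(raw, events=events, min_break_duration=30.0)
raw.set_annotations(raw.annotations + break_annot)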

50
mne/preprocessing/bads.py Normal file
View File

@@ -0,0 +1,50 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from scipy.stats import zscore
def _find_outliers(X, threshold=3.0, max_iter=2, tail=0):
"""Find outliers based on iterated Z-scoring.
This procedure compares the absolute z-score against the threshold.
After excluding local outliers, the comparison is repeated until no
local outlier is present any more.
Parameters
----------
X : np.ndarray of float, shape (n_elements,)
The scores for which to find outliers.
threshold : float
The value above which a feature is classified as outlier.
max_iter : int
The maximum number of iterations.
tail : {0, 1, -1}
Whether to search for outliers on both extremes of the z-scores (0),
or on just the positive (1) or negative (-1) side.
Returns
-------
bad_idx : np.ndarray of int, shape (n_outliers,)
The outlier indices.
"""
my_mask = np.zeros(len(X), dtype=bool)
for _ in range(max_iter):
X = np.ma.masked_array(X, my_mask)
if tail == 0:
this_z = np.abs(zscore(X))
elif tail == 1:
this_z = zscore(X)
elif tail == -1:
this_z = -zscore(X)
else:
raise ValueError(f"Tail parameter {tail} not recognised.")
local_bad = this_z > threshold
my_mask = np.max([my_mask, local_bad], 0)
if not np.any(local_bad):
break
bad_idx = np.where(my_mask)[0]
return bad_idx
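# Small sketch of the iterated z-scoring above on synthetic scores; with a
# single extreme value, the first pass should flag index 10 and the second
# pass should find nothing new.
import numpy as np
from mne.preprocessing.bads import _find_outliers

scores = np.array([0.1, -0.2, 0.05, 0.0, 0.15, -0.1, 0.2, -0.05, 0.1, 0.0, 5.0])
print(_find_outliers(scores, threshold=3.0, max_iter=2, tail=0))  # expected: [10]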

167
mne/preprocessing/ctps_.py Normal file
View File

@@ -0,0 +1,167 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import math
import numpy as np
from scipy.signal import hilbert
from scipy.special import logsumexp
def _compute_normalized_phase(data):
"""Compute normalized phase angles.
Parameters
----------
data : ndarray, shape (n_epochs, n_sources, n_times)
The data to compute the phase angles for.
Returns
-------
phase_angles : ndarray, shape (n_epochs, n_sources, n_times)
The normalized phase angles.
"""
return (np.angle(hilbert(data)) + np.pi) / (2 * np.pi)
def ctps(data, is_raw=True):
"""Compute cross-trial-phase-statistics [1].
Note: it is assumed that the sources have already been
appropriately filtered.
Parameters
----------
data : ndarray, shape (n_epochs, n_channels, n_times)
Any kind of data of dimensions trials, traces, features.
is_raw : bool
If True it is assumed that data haven't been transformed to Hilbert
space and phase angles haven't been normalized. Defaults to True.
Returns
-------
ks_dynamics : ndarray, shape (n_sources, n_times)
The Kuiper statistics.
pk_dynamics : ndarray, shape (n_sources, n_times)
The normalized Kuiper index for ICA sources and
time slices.
phase_angles : ndarray, shape (n_epochs, n_sources, n_times) | None
The phase values for epochs, sources and time slices. If ``is_raw``
is False, None is returned.
References
----------
[1] Dammers, J., Schiek, M., Boers, F., Silex, C., Zvyagintsev,
M., Pietrzyk, U., Mathiak, K., 2008. Integration of amplitude
and phase statistics for complete artifact removal in independent
components of neuromagnetic recordings. Biomedical
Engineering, IEEE Transactions on 55 (10), 2353-2362.
"""
if not data.ndim == 3:
raise ValueError(f"Data must have 3 dimensions, not {data.ndim}.")
if is_raw:
phase_angles = _compute_normalized_phase(data)
else:
phase_angles = data # phase angles can be computed externally
# initialize array for results
ks_dynamics = np.zeros_like(phase_angles[0])
pk_dynamics = np.zeros_like(phase_angles[0])
# calculate Kuiper's statistic for each source
for ii, source in enumerate(np.transpose(phase_angles, [1, 0, 2])):
ks, pk = kuiper(source)
pk_dynamics[ii, :] = pk
ks_dynamics[ii, :] = ks
return ks_dynamics, pk_dynamics, phase_angles if is_raw else None
def kuiper(data, dtype=np.float64): # noqa: D401
"""Kuiper's test of uniform distribution.
Parameters
----------
data : ndarray, shape (n_sources,) | (n_sources, n_times)
Empirical distribution.
dtype : str | obj
The data type to be used.
Returns
-------
ks : ndarray
Kuiper's statistic.
pk : ndarray
Normalized probability of Kuiper's statistic [0, 1].
"""
# if data not numpy array, implicitly convert and make to use copied data
# ! sort data array along first axis !
data = np.sort(data, axis=0).astype(dtype)
shape = data.shape
n_dim = len(shape)
n_trials = shape[0]
# create uniform cdf
j1 = (np.arange(n_trials, dtype=dtype) + 1.0) / float(n_trials)
j2 = np.arange(n_trials, dtype=dtype) / float(n_trials)
if n_dim > 1:  # shape (n_trials, n_time_slices): add axis so the CDF broadcasts
j1 = j1[:, np.newaxis]
j2 = j2[:, np.newaxis]
d1 = (j1 - data).max(axis=0)
d2 = (data - j2).max(axis=0)
n_eff = n_trials
d = d1 + d2 # Kuiper's statistic [n_time_slices]
return d, _prob_kuiper(d, n_eff, dtype=dtype)
def _prob_kuiper(d, n_eff, dtype="f8"):
"""Test for statistical significance against uniform distribution.
Parameters
----------
d : float
The Kuiper distance value.
n_eff : int
The effective number of elements.
dtype : str | obj
The data type to be used. Defaults to double precision floats.
Returns
-------
pk_norm : float
The normalized Kuiper value such that 0 < ``pk_norm`` < 1.
References
----------
[1] Stephens MA 1970. Journal of the Royal Statistical Society, ser. B,
vol 32, pp 115-122.
[2] Kuiper NH 1962. Proceedings of the Koninklijke Nederlands Akademie
van Wetenschappen, ser Vol 63 pp 38-47
"""
n_time_slices = np.size(d) # single value or vector
n_points = 100
en = math.sqrt(n_eff)
k_lambda = (en + 0.155 + 0.24 / en) * d # see [1]
l2 = k_lambda**2.0
j2 = (np.arange(n_points) + 1) ** 2
j2 = j2.repeat(n_time_slices).reshape(n_points, n_time_slices)
fact = 4.0 * j2 * l2 - 1.0
# compute normalized pK value in range [0,1]
a = -2.0 * j2 * l2
b = 2.0 * fact
pk_norm = -logsumexp(a, b=b, axis=0) / (2.0 * n_eff)
# check for no difference to uniform cdf
pk_norm = np.where(k_lambda < 0.4, 0.0, pk_norm)
# check for round off errors
pk_norm = np.where(pk_norm > 1.0, 1.0, pk_norm)
return pk_norm
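# Hedged sketch of the cross-trial phase statistics above on synthetic data of
# shape (n_epochs, n_sources, n_times); for random noise the pk values should
# stay near zero (no phase locking across trials).
import numpy as np
from mne.preprocessing.ctps_ import ctps

rng = np.random.default_rng(0)
data = rng.standard_normal((40, 3, 256))
ks, pk, phase_angles = ctps(data, is_raw=True)
print(pk.shape)  # (3, 256)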

539
mne/preprocessing/ecg.py Normal file
View File

@@ -0,0 +1,539 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from .._fiff.meas_info import create_info
from .._fiff.pick import _picks_to_idx, pick_channels, pick_types
from ..annotations import _annotations_starts_stops
from ..epochs import BaseEpochs, Epochs
from ..evoked import Evoked
from ..filter import filter_data
from ..io import BaseRaw, RawArray
from ..utils import int_like, logger, sum_squared, verbose, warn
@verbose
def qrs_detector(
sfreq,
ecg,
thresh_value=0.6,
levels=2.5,
n_thresh=3,
l_freq=5,
h_freq=35,
tstart=0,
filter_length="10s",
verbose=None,
):
"""Detect QRS component in ECG channels.
The QRS complex is the main wave of the heartbeat.
Parameters
----------
sfreq : float
The sampling rate in Hz.
ecg : array
The ECG signal.
thresh_value : float | str
QRS detection threshold. Can also be "auto" for automatic
selection of the threshold.
levels : float
Number of standard deviations from the mean to include for detection.
n_thresh : int
Maximum number of threshold crossings.
l_freq : float
Low cut-off frequency of the band-pass filter applied to the ECG signal.
h_freq : float
High cut-off frequency of the band-pass filter applied to the ECG signal.
%(tstart_ecg)s
%(filter_length_ecg)s
%(verbose)s
Returns
-------
events : array
Indices of ECG peaks.
"""
win_size = int(round((60.0 * sfreq) / 120.0))
filtecg = filter_data(
ecg,
sfreq,
l_freq,
h_freq,
None,
filter_length,
0.5,
0.5,
phase="zero-double",
fir_window="hann",
fir_design="firwin2",
)
ecg_abs = np.abs(filtecg)
init = int(sfreq)
n_samples_start = int(sfreq * tstart)
ecg_abs = ecg_abs[n_samples_start:]
n_points = len(ecg_abs)
maxpt = np.empty(3)
maxpt[0] = np.max(ecg_abs[:init])
maxpt[1] = np.max(ecg_abs[init : init * 2])
maxpt[2] = np.max(ecg_abs[init * 2 : init * 3])
init_max = np.mean(maxpt)
if thresh_value == "auto":
thresh_runs = np.arange(0.3, 1.1, 0.05)
elif isinstance(thresh_value, str):
raise ValueError('threshold value must be "auto" or a float')
else:
thresh_runs = [thresh_value]
# Try a few thresholds (or just one)
clean_events = list()
for thresh_value in thresh_runs:
thresh1 = init_max * thresh_value
numcross = list()
time = list()
rms = list()
ii = 0
while ii < (n_points - win_size):
window = ecg_abs[ii : ii + win_size]
if window[0] > thresh1:
max_time = np.argmax(window)
time.append(ii + max_time)
nx = np.sum(
np.diff(((window > thresh1).astype(np.int64) == 1).astype(int))
)
numcross.append(nx)
rms.append(np.sqrt(sum_squared(window) / window.size))
ii += win_size
else:
ii += 1
if len(rms) == 0:
rms.append(0.0)
time.append(0.0)
time = np.array(time)
rms_mean = np.mean(rms)
rms_std = np.std(rms)
rms_thresh = rms_mean + (rms_std * levels)
b = np.where(rms < rms_thresh)[0]
a = np.array(numcross)[b]
ce = time[b[a < n_thresh]]
ce += n_samples_start
if ce.size > 0: # We actually found an event
clean_events.append(ce)
if clean_events:
# pick the best threshold; first get effective heart rates
rates = np.array(
[60.0 * len(cev) / (len(ecg) / float(sfreq)) for cev in clean_events]
)
# now find heart rates that seem reasonable (infant through adult
# athlete)
idx = np.where(np.logical_and(rates <= 160.0, rates >= 40.0))[0]
if idx.size > 0:
ideal_rate = np.median(rates[idx]) # get close to the median
else:
ideal_rate = 80.0 # get close to a reasonable default
idx = np.argmin(np.abs(rates - ideal_rate))
clean_events = clean_events[idx]
else:
clean_events = np.array([])
return clean_events
@verbose
def find_ecg_events(
raw,
event_id=999,
ch_name=None,
tstart=0.0,
l_freq=5,
h_freq=35,
qrs_threshold="auto",
filter_length="10s",
return_ecg=False,
reject_by_annotation=True,
verbose=None,
):
"""Find ECG events by localizing the R wave peaks.
Parameters
----------
raw : instance of Raw
The raw data.
%(event_id_ecg)s
%(ch_name_ecg)s
%(tstart_ecg)s
%(l_freq_ecg_filter)s
qrs_threshold : float | str
QRS detection threshold, between 0 and 1. Can also be "auto" to
automatically choose the threshold that generates a reasonable
number of heartbeats (40-160 beats / min).
%(filter_length_ecg)s
return_ecg : bool
Return the ECG data. This is especially useful if no ECG channel
is present in the input data, so one will be synthesized (only works if MEG
channels are present in the data). Defaults to ``False``.
%(reject_by_annotation_all)s
.. versionadded:: 0.18
%(verbose)s
Returns
-------
ecg_events : array
The events corresponding to the peaks of the R waves.
ch_ecg : int | None
Index of channel used.
average_pulse : float
The estimated average pulse. If no ECG events could be found, this will
be zero.
ecg : array | None
The ECG data of the synthesized ECG channel, if any. This will only
be returned if ``return_ecg=True`` was passed.
See Also
--------
create_ecg_epochs
compute_proj_ecg
"""
skip_by_annotation = ("edge", "bad") if reject_by_annotation else ()
del reject_by_annotation
idx_ecg = _get_ecg_channel_index(ch_name, raw)
if idx_ecg is not None:
logger.info(f"Using channel {raw.ch_names[idx_ecg]} to identify heart beats.")
ecg = raw.get_data(picks=idx_ecg)
else:
ecg, _ = _make_ecg(raw, start=None, stop=None)
assert ecg.ndim == 2 and ecg.shape[0] == 1
ecg = ecg[0]
# Deal with filtering the same way we do in raw, i.e. filter each good
# segment
onsets, ends = _annotations_starts_stops(
raw, skip_by_annotation, "reject_by_annotation", invert=True
)
ecgs = list()
max_idx = (ends - onsets).argmax()
for si, (start, stop) in enumerate(zip(onsets, ends)):
# Only output filter params once (for info level), and only warn
# once about the length criterion (longest segment is too short)
use_verbose = verbose if si == max_idx else "error"
ecgs.append(
filter_data(
ecg[start:stop],
raw.info["sfreq"],
l_freq,
h_freq,
[0],
filter_length,
0.5,
0.5,
1,
"fir",
None,
copy=False,
phase="zero-double",
fir_window="hann",
fir_design="firwin2",
verbose=use_verbose,
)
)
ecg = np.concatenate(ecgs)
# detecting QRS and generating events. Since not user-controlled, don't
# output filter params here (hardcode verbose=False)
ecg_events = qrs_detector(
raw.info["sfreq"],
ecg,
tstart=tstart,
thresh_value=qrs_threshold,
l_freq=None,
h_freq=None,
verbose=False,
)
# map ECG events back to original times
remap = np.empty(len(ecg), int)
offset = 0
for start, stop in zip(onsets, ends):
this_len = stop - start
assert this_len >= 0
remap[offset : offset + this_len] = np.arange(start, stop)
offset += this_len
assert offset == len(ecg)
if ecg_events.size > 0:
ecg_events = remap[ecg_events]
else:
ecg_events = np.array([])
n_events = len(ecg_events)
duration_sec = len(ecg) / raw.info["sfreq"] - tstart
duration_min = duration_sec / 60.0
average_pulse = n_events / duration_min
logger.info(
f"Number of ECG events detected : {n_events} "
f"(average pulse {average_pulse} / min.)"
)
ecg_events = np.array(
[
ecg_events + raw.first_samp,
np.zeros(n_events, int),
event_id * np.ones(n_events, int),
]
).T
out = (ecg_events, idx_ecg, average_pulse)
ecg = ecg[np.newaxis] # backward compat output 2D
if return_ecg:
out += (ecg,)
return out
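# Hedged usage sketch for the function above, assuming ``raw`` contains either
# an ECG channel or MEG channels (so a synthetic ECG trace can be built when
# ``return_ecg=True``).
from mne.preprocessing import find_ecg_events

ecg_events, ch_ecg, average_pulse, ecg = find_ecg_events(
    raw, qrs_threshold="auto", return_ecg=True
)
print(f"Found {len(ecg_events)} heartbeats (~{average_pulse:.1f} bpm)")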
def _get_ecg_channel_index(ch_name, inst):
"""Get ECG channel index, if no channel found returns None."""
if ch_name is None:
ecg_idx = pick_types(
inst.info,
meg=False,
eeg=False,
stim=False,
eog=False,
ecg=True,
emg=False,
ref_meg=False,
exclude="bads",
)
else:
if ch_name not in inst.ch_names:
raise ValueError(f"{ch_name} not in channel list ({inst.ch_names})")
ecg_idx = pick_channels(inst.ch_names, include=[ch_name])
if len(ecg_idx) == 0:
return None
if len(ecg_idx) > 1:
warn(
f"More than one ECG channel found. Using only {inst.ch_names[ecg_idx[0]]}."
)
return ecg_idx[0]
@verbose
def create_ecg_epochs(
raw,
ch_name=None,
event_id=999,
picks=None,
tmin=-0.5,
tmax=0.5,
l_freq=8,
h_freq=16,
reject=None,
flat=None,
baseline=None,
preload=True,
keep_ecg=False,
reject_by_annotation=True,
decim=1,
verbose=None,
):
"""Conveniently generate epochs around ECG artifact events.
%(create_ecg_epochs)s
.. note:: Filtering is only applied to the ECG channel while finding
events. The resulting ``ecg_epochs`` will have no filtering
applied (i.e., have the same filter properties as the input
``raw`` instance).
Parameters
----------
raw : instance of Raw
The raw data.
%(ch_name_ecg)s
%(event_id_ecg)s
%(picks_all)s
tmin : float
Start time before event.
tmax : float
End time after event.
%(l_freq_ecg_filter)s
%(reject_epochs)s
%(flat)s
%(baseline_epochs)s
preload : bool
Preload epochs or not (default True). Must be True if
keep_ecg is True.
keep_ecg : bool
When ECG is synthetically created (after picking), should it be added
to the epochs? Must be False when synthetic channel is not used.
Defaults to False.
%(reject_by_annotation_epochs)s
.. versionadded:: 0.14.0
%(decim)s
.. versionadded:: 0.21.0
%(verbose)s
Returns
-------
ecg_epochs : instance of Epochs
Data epoched around ECG R wave peaks.
See Also
--------
find_ecg_events
compute_proj_ecg
Notes
-----
If you already have a list of R-peak times, or want to compute R-peaks
outside MNE-Python using a different algorithm, the recommended approach is
to call the :class:`~mne.Epochs` constructor directly, with your R-peaks
formatted as an :term:`events` array (here we also demonstrate the relevant
default values)::
mne.Epochs(raw, r_peak_events_array, tmin=-0.5, tmax=0.5,
baseline=None, preload=True, proj=False) # doctest: +SKIP
"""
has_ecg = "ecg" in raw or ch_name is not None
if keep_ecg and (has_ecg or not preload):
raise ValueError(
"keep_ecg can be True only if the ECG channel is "
"created synthetically and preload=True."
)
events, _, _, ecg = find_ecg_events(
raw,
ch_name=ch_name,
event_id=event_id,
l_freq=l_freq,
h_freq=h_freq,
return_ecg=True,
reject_by_annotation=reject_by_annotation,
)
picks = _picks_to_idx(raw.info, picks, "all", exclude=())
# create epochs around ECG events and baseline (important)
ecg_epochs = Epochs(
raw,
events=events,
event_id=event_id,
tmin=tmin,
tmax=tmax,
proj=False,
flat=flat,
picks=picks,
reject=reject,
baseline=baseline,
reject_by_annotation=reject_by_annotation,
preload=preload,
decim=decim,
)
if keep_ecg:
# We know we have created a synthetic channel and epochs are preloaded
ecg_raw = RawArray(
ecg,
create_info(
ch_names=["ECG-SYN"], sfreq=raw.info["sfreq"], ch_types=["ecg"]
),
first_samp=raw.first_samp,
)
with ecg_raw.info._unlock():
ignore = ["ch_names", "chs", "nchan", "bads"]
for k, v in raw.info.items():
if k not in ignore:
ecg_raw.info[k] = v
syn_epochs = Epochs(
ecg_raw,
events=ecg_epochs.events,
event_id=event_id,
tmin=tmin,
tmax=tmax,
proj=False,
picks=[0],
baseline=baseline,
decim=decim,
preload=True,
)
ecg_epochs = ecg_epochs.add_channels([syn_epochs])
return ecg_epochs
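# Hedged sketch: epoch the data around R peaks and average to inspect the
# cardiac artifact (``raw`` is assumed to be preloaded MEG data).
from mne.preprocessing import create_ecg_epochs

ecg_epochs = create_ecg_epochs(raw, tmin=-0.5, tmax=0.5, baseline=(None, None))
ecg_evoked = ecg_epochs.average()
ecg_evoked.plot_joint()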
@verbose
def _make_ecg(inst, start, stop, reject_by_annotation=False, verbose=None):
"""Create ECG signal from cross channel average."""
if not any(c in inst for c in ["mag", "grad"]):
raise ValueError(
"Generating an artificial ECG channel can only be done for MEG data."
)
for ch in ["mag", "grad"]:
if ch in inst:
break
logger.info(
"Reconstructing ECG signal from {}".format(
{"mag": "Magnetometers", "grad": "Gradiometers"}[ch]
)
)
picks = pick_types(inst.info, meg=ch, eeg=False, ref_meg=False)
# Handle start/stop
msg = (
"integer arguments for the start and stop parameters are "
"not supported for Epochs and Evoked objects. Please "
"consider using float arguments specifying start and stop "
"time in seconds."
)
begin_param_name = "tmin"
if isinstance(start, int_like):
if isinstance(inst, BaseRaw):
# Raw has start param, can just use int
begin_param_name = "start"
else:
raise ValueError(msg)
end_param_name = "tmax"
if isinstance(stop, int_like):
if isinstance(inst, BaseRaw):
# Raw has stop param, can just use int
end_param_name = "stop"
else:
raise ValueError(msg)
kwargs = {begin_param_name: start, end_param_name: stop}
if isinstance(inst, BaseRaw):
reject_by_annotation = "omit" if reject_by_annotation else None
ecg, times = inst.get_data(
picks,
return_times=True,
**kwargs,
reject_by_annotation=reject_by_annotation,
)
elif isinstance(inst, BaseEpochs):
ecg = np.hstack(inst.copy().get_data(picks, **kwargs))
times = inst.times
elif isinstance(inst, Evoked):
ecg = inst.get_data(picks, **kwargs)
times = inst.times
return ecg.mean(0, keepdims=True), times

342
mne/preprocessing/eog.py Normal file
View File

@@ -0,0 +1,342 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from .._fiff.pick import pick_channels, pick_types
from ..epochs import Epochs
from ..filter import filter_data
from ..utils import _pl, _validate_type, logger, verbose
from ._peak_finder import peak_finder
@verbose
def find_eog_events(
raw,
event_id=998,
l_freq=1,
h_freq=10,
filter_length="10s",
ch_name=None,
tstart=0,
reject_by_annotation=False,
thresh=None,
verbose=None,
):
"""Locate EOG artifacts.
.. note:: To control true-positive and true-negative detection rates, you
may adjust the ``thresh`` parameter.
Parameters
----------
raw : instance of Raw
The raw data.
event_id : int
The index to assign to found events.
l_freq : float
Low cut-off frequency to apply to the EOG channel in Hz.
h_freq : float
High cut-off frequency to apply to the EOG channel in Hz.
filter_length : str | int | None
Number of taps to use for filtering.
%(ch_name_eog)s
tstart : float
Start detection after tstart seconds.
reject_by_annotation : bool
Whether to omit data that is annotated as bad.
thresh : float | None
Threshold to trigger the detection of an EOG event. This controls the
thresholding of the underlying peak-finding algorithm. Larger values
mean that fewer peaks (i.e., fewer EOG events) will be detected.
If ``None``, use the default of ``(max(eog) - min(eog)) / 4``,
with ``eog`` being the filtered EOG signal.
%(verbose)s
Returns
-------
eog_events : array
Events.
See Also
--------
create_eog_epochs
compute_proj_eog
"""
# Getting EOG Channel
eog_inds = _get_eog_channel_index(ch_name, raw)
eog_names = np.array(raw.ch_names)[eog_inds] # for logging
logger.info(f"EOG channel index for this subject is: {eog_inds}")
# Reject bad segments.
reject_by_annotation = "omit" if reject_by_annotation else None
eog, times = raw.get_data(
picks=eog_inds, reject_by_annotation=reject_by_annotation, return_times=True
)
times = times * raw.info["sfreq"] + raw.first_samp
eog_events = _find_eog_events(
eog,
ch_names=eog_names,
event_id=event_id,
l_freq=l_freq,
h_freq=h_freq,
sampling_rate=raw.info["sfreq"],
first_samp=raw.first_samp,
filter_length=filter_length,
tstart=tstart,
thresh=thresh,
verbose=verbose,
)
# Map times to corresponding samples.
eog_events[:, 0] = np.round(times[eog_events[:, 0] - raw.first_samp]).astype(int)
return eog_events
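# Hedged sketch: turn detected blink peaks into 500 ms annotations centered on
# each peak, a common pattern for marking blinks before further processing
# (``raw`` is assumed to contain at least one EOG channel).
import mne
from mne.preprocessing import find_eog_events

eog_events = find_eog_events(raw)
onsets = eog_events[:, 0] / raw.info["sfreq"] - 0.25
durations = [0.5] * len(eog_events)
descriptions = ["BAD_blink"] * len(eog_events)
blink_annot = mne.Annotations(
    onsets, durations, descriptions, orig_time=raw.info["meas_date"]
)
raw.set_annotations(blink_annot)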
@verbose
def _find_eog_events(
eog,
*,
ch_names,
event_id,
l_freq,
h_freq,
sampling_rate,
first_samp,
filter_length="10s",
tstart=0.0,
thresh=None,
verbose=None,
):
"""Find EOG events."""
logger.info(
"Filtering the data to remove DC offset to help "
"distinguish blinks from saccades"
)
# filtering to remove dc offset so that we know which is blink and saccades
# hardcode verbose=False to suppress filter param messages (since this
# filter is not under user control)
fmax = np.minimum(45, sampling_rate / 2.0 - 0.75) # protect Nyquist
filteog = np.array(
[
filter_data(
x,
sampling_rate,
2,
fmax,
None,
filter_length,
0.5,
0.5,
phase="zero-double",
fir_window="hann",
fir_design="firwin2",
verbose=False,
)
for x in eog
]
)
temp = np.sqrt(np.sum(filteog**2, axis=1))
indexmax = np.argmax(temp)
if ch_names is not None: # it can be None if called from ica_find_eog_events
logger.info(f"Selecting channel {ch_names[indexmax]} for blink detection")
# easier to detect peaks with filtering.
filteog = filter_data(
eog[indexmax],
sampling_rate,
l_freq,
h_freq,
None,
filter_length,
0.5,
0.5,
phase="zero-double",
fir_window="hann",
fir_design="firwin2",
)
# detecting eog blinks and generating event file
logger.info("Now detecting blinks and generating corresponding events")
temp = filteog - np.mean(filteog)
n_samples_start = int(sampling_rate * tstart)
if np.abs(np.max(temp)) > np.abs(np.min(temp)):
eog_events, _ = peak_finder(filteog[n_samples_start:], thresh, extrema=1)
else:
eog_events, _ = peak_finder(filteog[n_samples_start:], thresh, extrema=-1)
eog_events += n_samples_start
n_events = len(eog_events)
logger.info(f"Number of EOG events detected: {n_events}")
eog_events = np.array(
[
eog_events + first_samp,
np.zeros(n_events, int),
event_id * np.ones(n_events, int),
]
).T
return eog_events
def _get_eog_channel_index(ch_name, inst):
"""Get EOG channel indices."""
_validate_type(ch_name, types=(None, str, list), item_name="ch_name")
if ch_name is None:
eog_inds = pick_types(
inst.info,
meg=False,
eeg=False,
stim=False,
eog=True,
ecg=False,
emg=False,
ref_meg=False,
exclude="bads",
)
if eog_inds.size == 0:
raise RuntimeError("No EOG channel(s) found")
ch_names = [inst.ch_names[i] for i in eog_inds]
elif isinstance(ch_name, str):
ch_names = [ch_name]
else: # it's a list
ch_names = ch_name.copy()
# ensure the specified channels are present in the data
if ch_name is not None:
not_found = [ch_name for ch_name in ch_names if ch_name not in inst.ch_names]
if not_found:
raise ValueError(
f"The specified EOG channel{_pl(not_found)} "
f'cannot be found: {", ".join(not_found)}'
)
eog_inds = pick_channels(inst.ch_names, include=ch_names)
logger.info(f'Using EOG channel{_pl(ch_names)}: {", ".join(ch_names)}')
return eog_inds
@verbose
def create_eog_epochs(
raw,
ch_name=None,
event_id=998,
picks=None,
tmin=-0.5,
tmax=0.5,
l_freq=1,
h_freq=10,
reject=None,
flat=None,
baseline=None,
preload=True,
reject_by_annotation=True,
thresh=None,
decim=1,
verbose=None,
):
"""Conveniently generate epochs around EOG artifact events.
%(create_eog_epochs)s
Parameters
----------
raw : instance of Raw
The raw data.
%(ch_name_eog)s
event_id : int
The index to assign to found events.
%(picks_all)s
tmin : float
Start time before event.
tmax : float
End time after event.
l_freq : float
Low cut-off frequency to apply to the EOG channel while finding events.
h_freq : float
High cut-off frequency to apply to the EOG channel while finding events.
reject : dict | None
Rejection parameters based on peak-to-peak amplitude.
Valid keys are 'grad' | 'mag' | 'eeg' | 'eog' | 'ecg'.
If reject is None then no rejection is done. Example::
reject = dict(grad=4000e-13, # T / m (gradiometers)
mag=4e-12, # T (magnetometers)
eeg=40e-6, # V (EEG channels)
eog=250e-6 # V (EOG channels)
)
flat : dict | None
Rejection parameters based on flatness of signal.
Valid keys are 'grad' | 'mag' | 'eeg' | 'eog' | 'ecg', and values
are floats that set the minimum acceptable peak-to-peak amplitude.
If flat is None then no rejection is done.
baseline : tuple or list of length 2, or None
The time interval to apply rescaling / baseline correction.
If None do not apply it. If baseline is (a, b)
the interval is between "a (s)" and "b (s)".
If a is None the beginning of the data is used
and if b is None then b is set to the end of the interval.
If baseline is equal to (None, None) all the time
interval is used. If None, no correction is applied.
preload : bool
Preload epochs or not.
%(reject_by_annotation_epochs)s
.. versionadded:: 0.14.0
thresh : float
Threshold to trigger EOG event.
%(decim)s
.. versionadded:: 0.21.0
%(verbose)s
Returns
-------
eog_epochs : instance of Epochs
Data epoched around EOG events.
See Also
--------
find_eog_events
compute_proj_eog
Notes
-----
Filtering is only applied to the EOG channel while finding events.
The resulting ``eog_epochs`` will have no filtering applied (i.e., have
the same filter properties as the input ``raw`` instance).
"""
events = find_eog_events(
raw,
ch_name=ch_name,
event_id=event_id,
l_freq=l_freq,
h_freq=h_freq,
reject_by_annotation=reject_by_annotation,
thresh=thresh,
)
# create epochs around EOG events
eog_epochs = Epochs(
raw,
events=events,
event_id=event_id,
tmin=tmin,
tmax=tmax,
proj=False,
reject=reject,
flat=flat,
picks=picks,
baseline=baseline,
preload=preload,
reject_by_annotation=reject_by_annotation,
decim=decim,
)
return eog_epochs
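# Hedged sketch: build blink-locked epochs and average them, e.g., to inspect
# the spatial pattern of the ocular artifact before computing SSP projectors.
from mne.preprocessing import create_eog_epochs

eog_epochs = create_eog_epochs(raw, tmin=-0.5, tmax=0.5)
eog_evoked = eog_epochs.average()
eog_evoked.apply_baseline(baseline=(None, -0.2))
eog_evoked.plot_joint()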

View File

@@ -0,0 +1,10 @@
"""Eye tracking specific preprocessing functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from .eyetracking import set_channel_types_eyetrack, convert_units
from .calibration import Calibration, read_eyelink_calibration
from ._pupillometry import interpolate_blinks
from .utils import get_screen_visual_angle

View File

@@ -0,0 +1,121 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ..._fiff.constants import FIFF
from ...annotations import _annotations_starts_stops
from ...io import BaseRaw
from ...utils import _check_preload, _validate_type, logger, warn
def interpolate_blinks(raw, buffer=0.05, match="BAD_blink", interpolate_gaze=False):
"""Interpolate eyetracking signals during blinks.
This function uses the timing of blink annotations to estimate missing
data. Missing values are then interpolated linearly. Operates in place.
Parameters
----------
raw : instance of Raw
The raw data with at least one ``'pupil'`` or ``'eyegaze'`` channel.
buffer : float | array-like of float, shape ``(2,)``
The time in seconds before and after a blink to consider invalid and
include in the segment to be interpolated over. Default is ``0.05`` seconds
(50 ms). If array-like, the first element is the time before the blink and the
second element is the time after the blink to consider invalid, for example,
``(0.025, .1)``.
match : str | list of str
The description of annotations to interpolate over. If a list, the data within
all annotations that match any of the strings in the list will be interpolated
over. If a ``match`` starts with ``'BAD_'``, that part will be removed from the
annotation description after interpolation. Defaults to ``'BAD_blink'``.
interpolate_gaze : bool
If False, only apply interpolation to ``'pupil'`` channels. If True, interpolate
over ``'eyegaze'`` channels as well. Defaults to False, because eye position can
change in unpredictable ways during blinks.
Returns
-------
self : instance of Raw
Returns the modified instance.
Notes
-----
.. versionadded:: 1.5
"""
_check_preload(raw, "interpolate_blinks")
_validate_type(raw, BaseRaw, "raw")
_validate_type(buffer, (float, tuple, list, np.ndarray), "buffer")
_validate_type(match, (str, tuple, list, np.ndarray), "match")
# determine the buffer around blinks to include in the interpolation
buffer = np.array(buffer, dtype=float)
if buffer.size == 1:
buffer = np.array([buffer, buffer])
if isinstance(match, str):
match = [match]
# get the blink annotations
blink_annots = [annot for annot in raw.annotations if annot["description"] in match]
if not blink_annots:
warn(f"No annotations matching {match} found. Aborting.")
return raw
_interpolate_blinks(raw, buffer, blink_annots, interpolate_gaze=interpolate_gaze)
# remove bad from the annotation description
for desc in match:
if desc.startswith("BAD_"):
logger.info(f"Removing 'BAD_' from {desc}.")
raw.annotations.rename({desc: desc.replace("BAD_", "")})
return raw
def _interpolate_blinks(raw, buffer, blink_annots, interpolate_gaze):
"""Interpolate eyetracking signals during blinks in-place."""
logger.info("Interpolating missing data during blinks...")
pre_buffer, post_buffer = buffer
# iterate over each eyetrack channel and interpolate the blinks
interpolated_chs = []
for ci, ch_info in enumerate(raw.info["chs"]):
if interpolate_gaze: # interpolate over all eyetrack channels
if ch_info["kind"] != FIFF.FIFFV_EYETRACK_CH:
continue
else: # interpolate over pupil channels only
if ch_info["coil_type"] != FIFF.FIFFV_COIL_EYETRACK_PUPIL:
continue
# Create an empty boolean mask
mask = np.zeros_like(raw.times, dtype=bool)
starts, ends = _annotations_starts_stops(raw, "BAD_blink")
starts = np.divide(starts, raw.info["sfreq"])
ends = np.divide(ends, raw.info["sfreq"])
for annot, start, end in zip(blink_annots, starts, ends):
if "ch_names" not in annot or not annot["ch_names"]:
msg = f"Blink annotation missing values for 'ch_names' key: {annot}"
raise ValueError(msg)
start -= pre_buffer
end += post_buffer
if ch_info["ch_name"] not in annot["ch_names"]:
continue # skip if the channel is not in the blink annotation
# Update the mask for times within the current blink period
mask |= (raw.times >= start) & (raw.times <= end)
blink_indices = np.where(mask)[0]
non_blink_indices = np.where(~mask)[0]
# Linear interpolation
interpolated_samples = np.interp(
raw.times[blink_indices],
raw.times[non_blink_indices],
raw._data[ci, non_blink_indices],
)
# Replace the samples at the blink_indices with the interpolated values
raw._data[ci, blink_indices] = interpolated_samples
interpolated_chs.append(ch_info["ch_name"])
if interpolated_chs:
logger.info(
f"Interpolated {len(interpolated_chs)} channels: {interpolated_chs}"
)
else:
warn("No channels were interpolated.")

View File

@@ -0,0 +1,222 @@
"""Eyetracking Calibration(s) class constructor."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from copy import deepcopy
import numpy as np
from ...io.eyelink._utils import _parse_calibration
from ...utils import _check_fname, _validate_type, fill_doc, logger
from ...viz.utils import plt_show
@fill_doc
class Calibration(dict):
"""Eye-tracking calibration info.
This data structure behaves like a dictionary. It contains information regarding a
calibration that was conducted during an eye-tracking recording.
.. note::
When possible, a Calibration instance should be created with a helper function,
such as :func:`~mne.preprocessing.eyetracking.read_eyelink_calibration`.
Parameters
----------
onset : float
The onset of the calibration in seconds. If the calibration was
performed before the recording started, the onset can be
negative.
model : str
The model of the eye-tracking calibration that was applied.
For example ``'H3'`` for a horizontal only 3-point calibration, or ``'HV3'``
for a horizontal and vertical 3-point calibration.
eye : str
The eye that was calibrated. For example, ``'left'``, or ``'right'``.
avg_error : float
The average error in degrees between the calibration positions and the
actual gaze position.
max_error : float
The maximum error in degrees that occurred between the calibration
positions and the actual gaze position.
positions : array-like of float, shape ``(n_calibration_points, 2)``
The x and y coordinates of the calibration points.
offsets : array-like of float, shape ``(n_calibration_points,)``
The error in degrees between the calibration position and the actual
gaze position for each calibration point.
gaze : array-like of float, shape ``(n_calibration_points, 2)``
The x and y coordinates of the actual gaze position for each calibration point.
screen_size : array-like of shape ``(2,)``
The width and height (in meters) of the screen that the eyetracking
data was collected with. For example ``(.531, .298)`` for a monitor with
a display area of 531 x 298 mm.
screen_distance : float
The distance (in meters) from the participant's eyes to the screen.
screen_resolution : array-like of shape ``(2,)``
The resolution (in pixels) of the screen that the eyetracking data
was collected with. For example, ``(1920, 1080)`` for a 1920x1080
resolution display.
"""
def __init__(
self,
*,
onset,
model,
eye,
avg_error,
max_error,
positions,
offsets,
gaze,
screen_size=None,
screen_distance=None,
screen_resolution=None,
):
super().__init__(
onset=onset,
model=model,
eye=eye,
avg_error=avg_error,
max_error=max_error,
screen_size=screen_size,
screen_distance=screen_distance,
screen_resolution=screen_resolution,
positions=positions,
offsets=offsets,
gaze=gaze,
)
def __repr__(self):
"""Return a summary of the Calibration object."""
return (
f"Calibration |\n"
f" onset: {self['onset']} seconds\n"
f" model: {self['model']}\n"
f" eye: {self['eye']}\n"
f" average error: {self['avg_error']} degrees\n"
f" max error: {self['max_error']} degrees\n"
f" screen size: {self['screen_size']} meters\n"
f" screen distance: {self['screen_distance']} meters\n"
f" screen resolution: {self['screen_resolution']} pixels\n"
)
def copy(self):
"""Copy the instance.
Returns
-------
cal : instance of Calibration
The copied Calibration.
"""
return deepcopy(self)
def plot(self, show_offsets=True, axes=None, show=True):
"""Visualize calibration.
Parameters
----------
show_offsets : bool
Whether to display the offset (in visual degrees) of each calibration
point or not. Defaults to ``True``.
axes : instance of matplotlib.axes.Axes | None
Axes to draw the calibration positions to. If ``None`` (default), a new axes
will be created.
show : bool
Whether to show the figure or not. Defaults to ``True``.
Returns
-------
fig : instance of matplotlib.figure.Figure
The resulting figure object for the calibration plot.
"""
import matplotlib.pyplot as plt
msg = "positions and gaze keys must both be 2D numpy arrays."
assert isinstance(self["positions"], np.ndarray), msg
assert isinstance(self["gaze"], np.ndarray), msg
if axes is not None:
from matplotlib.axes import Axes
_validate_type(axes, Axes, "axes")
ax = axes
fig = ax.get_figure()
else: # create new figure and axes
fig, ax = plt.subplots(layout="constrained")
px, py = self["positions"].T
gaze_x, gaze_y = self["gaze"].T
ax.set_title(f"Calibration ({self['eye']} eye)")
ax.set_xlabel("x (pixels)")
ax.set_ylabel("y (pixels)")
# Display avg_error and max_error in the top left corner
text = (
f"avg_error: {self['avg_error']} deg.\nmax_error: {self['max_error']} deg."
)
ax.text(
0,
1.01,
text,
transform=ax.transAxes,
verticalalignment="baseline",
fontsize=8,
)
# Invert y-axis because the origin is in the top left corner
ax.invert_yaxis()
ax.scatter(px, py, color="gray")
ax.scatter(gaze_x, gaze_y, color="red", alpha=0.5)
if show_offsets:
for i in range(len(px)):
x_offset = 0.01 * gaze_x[i] # 1% to the right of the gazepoint
text = ax.text(
x=gaze_x[i] + x_offset,
y=gaze_y[i],
s=self["offsets"][i],
fontsize=8,
ha="left",
va="center",
)
plt_show(show)
return fig
@fill_doc
def read_eyelink_calibration(
fname, screen_size=None, screen_distance=None, screen_resolution=None
):
"""Return info on calibrations collected in an eyelink file.
Parameters
----------
fname : path-like
Path to the eyelink file (.asc).
screen_size : array-like of shape ``(2,)``
The width and height (in meters) of the screen that the eyetracking
data was collected with. For example ``(.531, .298)`` for a monitor with
a display area of 531 x 298 mm. Defaults to ``None``.
screen_distance : float
The distance (in meters) from the participant's eyes to the screen.
Defaults to ``None``.
screen_resolution : array-like of shape ``(2,)``
The resolution (in pixels) of the screen that the eyetracking data
was collected with. For example, ``(1920, 1080)`` for a 1920x1080
resolution display. Defaults to ``None``.
Returns
-------
calibrations : list
A list of :class:`~mne.preprocessing.eyetracking.Calibration` instances, one for
each eye of every calibration that was performed during the recording session.
"""
fname = _check_fname(fname, overwrite="read", must_exist=True, name="fname")
logger.info(f"Reading calibration data from {fname}")
lines = fname.read_text(encoding="ASCII").splitlines()
return _parse_calibration(lines, screen_size, screen_distance, screen_resolution)
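# Hypothetical usage sketch (not part of the original module): reading all
# calibrations from an EyeLink ASCII file and plotting the first one. The
# file name and screen geometry below are illustrative values only.
cals = read_eyelink_calibration(
    "sub-01_task-reading_eyetrack.asc",
    screen_size=(0.531, 0.298),
    screen_distance=0.7,
    screen_resolution=(1920, 1080),
)
if cals:
    print(cals[0])  # summary of onset, eye, model, and errors
    cals[0].plot(show_offsets=True)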

View File

@@ -0,0 +1,327 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ..._fiff.constants import FIFF
from ...epochs import BaseEpochs
from ...evoked import Evoked
from ...io import BaseRaw
from ...utils import _check_option, _validate_type, logger, warn
from .calibration import Calibration
from .utils import _check_calibration
# specific function to set eyetrack channels
def set_channel_types_eyetrack(inst, mapping):
"""Define sensor type for eyetrack channels.
This function can set all eye tracking specific information:
channel type, unit, eye (and x/y component; only for gaze channels)
Supported channel types:
``'eyegaze'`` and ``'pupil'``
Supported units:
``'au'``, ``'px'``, ``'deg'``, ``'rad'`` (for eyegaze)
``'au'``, ``'mm'``, ``'m'`` (for pupil)
Parameters
----------
inst : instance of Raw, Epochs, or Evoked
The data instance.
mapping : dict
A dictionary mapping a channel to a list/tuple including
channel type, unit, eye, [and x/y component] (all as str), e.g.,
``{'l_x': ('eyegaze', 'deg', 'left', 'x')}`` or
``{'r_pupil': ('pupil', 'au', 'right')}``.
Returns
-------
inst : instance of Raw | Epochs | Evoked
The instance, modified in place.
Notes
-----
Setting the channel type to ``'eyegaze'`` or ``'pupil'`` via
``inst.set_channel_types()`` works as well, but cannot correctly set the
unit, eye, and x/y component.
Data will be stored in SI units:
if your data comes in ``deg`` (visual angle) it will be converted to
``rad``, if it is in ``mm`` it will be converted to ``m``.
"""
ch_names = inst.info["ch_names"]
# allowed
valid_types = ["eyegaze", "pupil"] # ch_type
valid_units = {
"px": ["px", "pixel"],
"rad": ["rad", "radian", "radians"],
"deg": ["deg", "degree", "degrees"],
"m": ["m", "meter", "meters"],
"mm": ["mm", "millimeter", "millimeters"],
"au": [None, "none", "au", "arbitrary"],
}
valid_units["all"] = [item for sublist in valid_units.values() for item in sublist]
valid_eye = {"l": ["left", "l"], "r": ["right", "r"]}
valid_eye["all"] = [item for sublist in valid_eye.values() for item in sublist]
valid_xy = {"x": ["x", "h", "horizontal"], "y": ["y", "v", "vertical"]}
valid_xy["all"] = [item for sublist in valid_xy.values() for item in sublist]
# loop over channels
for ch_name, ch_desc in mapping.items():
if ch_name not in ch_names:
raise ValueError(f"This channel name ({ch_name}) doesn't exist in info.")
c_ind = ch_names.index(ch_name)
# set ch_type and unit
ch_type = ch_desc[0].lower()
if ch_type not in valid_types:
raise ValueError(
f"ch_type must be one of {valid_types}. Got '{ch_type}' instead."
)
if ch_type == "eyegaze":
coil_type = FIFF.FIFFV_COIL_EYETRACK_POS
elif ch_type == "pupil":
coil_type = FIFF.FIFFV_COIL_EYETRACK_PUPIL
inst.info["chs"][c_ind]["coil_type"] = coil_type
inst.info["chs"][c_ind]["kind"] = FIFF.FIFFV_EYETRACK_CH
ch_unit = None if (ch_desc[1] is None) else ch_desc[1].lower()
if ch_unit not in valid_units["all"]:
raise ValueError(
"unit must be one of {}. Got '{}' instead.".format(
valid_units["all"], ch_unit
)
)
if ch_unit in valid_units["px"]:
unit_new = FIFF.FIFF_UNIT_PX
elif ch_unit in valid_units["rad"]:
unit_new = FIFF.FIFF_UNIT_RAD
elif ch_unit in valid_units["deg"]: # convert deg to rad (SI)
inst = inst.apply_function(_convert_deg_to_rad, picks=ch_name)
unit_new = FIFF.FIFF_UNIT_RAD
elif ch_unit in valid_units["m"]:
unit_new = FIFF.FIFF_UNIT_M
elif ch_unit in valid_units["mm"]: # convert mm to m (SI)
inst = inst.apply_function(_convert_mm_to_m, picks=ch_name)
unit_new = FIFF.FIFF_UNIT_M
elif ch_unit in valid_units["au"]:
unit_new = FIFF.FIFF_UNIT_NONE
inst.info["chs"][c_ind]["unit"] = unit_new
# set eye (and x/y-component)
loc = np.array(
[
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
]
)
ch_eye = ch_desc[2].lower()
if ch_eye not in valid_eye["all"]:
raise ValueError(
"eye must be one of {}. Got '{}' instead.".format(
valid_eye["all"], ch_eye
)
)
if ch_eye in valid_eye["l"]:
loc[3] = -1
elif ch_eye in valid_eye["r"]:
loc[3] = 1
if ch_type == "eyegaze":
ch_xy = ch_desc[3].lower()
if ch_xy not in valid_xy["all"]:
raise ValueError(
"x/y must be one of {}. Got '{}' instead.".format(
valid_xy["all"], ch_xy
)
)
if ch_xy in valid_xy["x"]:
loc[4] = -1
elif ch_xy in valid_xy["y"]:
loc[4] = 1
inst.info["chs"][c_ind]["loc"] = loc
return inst
def _convert_mm_to_m(array):
return array * 0.001
def _convert_deg_to_rad(array):
return array * np.pi / 180.0
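# Hypothetical usage sketch (not part of the original module): declaring two
# gaze channels (in pixels) and one pupil channel for the right eye on a
# synthetic Raw object. Channel names and data are illustrative; deg/mm data
# would be converted to rad/m automatically as described above.
import numpy as np
import mne

info = mne.create_info(
    ["xpos_right", "ypos_right", "pupil_right"], sfreq=1000.0, ch_types="misc"
)
raw_et = mne.io.RawArray(np.zeros((3, 1000)), info)
mapping = {
    "xpos_right": ("eyegaze", "px", "right", "x"),
    "ypos_right": ("eyegaze", "px", "right", "y"),
    "pupil_right": ("pupil", "au", "right"),
}
raw_et = set_channel_types_eyetrack(raw_et, mapping)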
def convert_units(inst, calibration, to="radians"):
"""Convert Eyegaze data from pixels to radians of visual angle or vice versa.
.. warning::
Currently, depending on the units (pixels or radians), eyegaze channels may not
be reported correctly in visualization functions like :meth:`mne.io.Raw.plot`.
They will be shown correctly in :func:`mne.viz.eyetracking.plot_gaze`.
See :gh:`11879` for more information.
.. Important::
There are important considerations to keep in mind when using this function,
see the Notes section below.
Parameters
----------
inst : instance of Raw, Epochs, or Evoked
The Raw, Epochs, or Evoked instance with eyegaze channels.
calibration : Calibration
Instance of Calibration, containing information about the screen size
(in meters), viewing distance (in meters), and the screen resolution
(in pixels).
to : str
Must be either ``"radians"`` or ``"pixels"``, indicating the desired unit.
Returns
-------
inst : instance of Raw | Epochs | Evoked
The Raw, Epochs, or Evoked instance, modified in place.
Notes
-----
There are at least two important considerations to keep in mind when using this
function:
1. Converting between on-screen pixels and visual angle is not a linear
transformation. If the visual angle subtends less than approximately ``.44``
radians (``25`` degrees), the conversion could be considered to be approximately
linear. However, as the visual angle increases, the conversion becomes
increasingly non-linear. This may lead to unexpected results after converting
between pixels and visual angle.
2. This function assumes that the head is fixed in place and aligned with the center
of the screen, such that gaze to the center of the screen results in a visual
angle of ``0`` radians.
.. versionadded:: 1.7
"""
_validate_type(inst, (BaseRaw, BaseEpochs, Evoked), "inst")
_validate_type(calibration, Calibration, "calibration")
_check_option("to", to, ("radians", "pixels"))
_check_calibration(calibration)
# get screen parameters
screen_size = calibration["screen_size"]
screen_resolution = calibration["screen_resolution"]
dist = calibration["screen_distance"]
# loop through channels and convert units
converted_chs = []
for ch_dict in inst.info["chs"]:
if ch_dict["coil_type"] != FIFF.FIFFV_COIL_EYETRACK_POS:
continue
unit = ch_dict["unit"]
name = ch_dict["ch_name"]
if ch_dict["loc"][4] == -1: # x-coordinate
size = screen_size[0]
res = screen_resolution[0]
elif ch_dict["loc"][4] == 1: # y-coordinate
size = screen_size[1]
res = screen_resolution[1]
else:
raise ValueError(
f"loc array not set properly for channel '{name}'. Index 4 should"
f" be -1 or 1, but got {ch_dict['loc'][4]}"
)
# check unit, convert, and set new unit
if to == "radians":
if unit != FIFF.FIFF_UNIT_PX:
raise ValueError(
f"Data must be in pixels in order to convert to radians."
f" Got {unit} for {name}"
)
inst.apply_function(_pix_to_rad, picks=name, size=size, res=res, dist=dist)
ch_dict["unit"] = FIFF.FIFF_UNIT_RAD
elif to == "pixels":
if unit != FIFF.FIFF_UNIT_RAD:
raise ValueError(
f"Data must be in radians in order to convert to pixels."
f" Got {unit} for {name}"
)
inst.apply_function(_rad_to_pix, picks=name, size=size, res=res, dist=dist)
ch_dict["unit"] = FIFF.FIFF_UNIT_PX
converted_chs.append(name)
if converted_chs:
logger.info(f"Converted {converted_chs} to {to}.")
if to == "radians":
# check if any values are greater than 0.52 radians
# (30 degrees) and warn the user
data = inst.get_data(picks=converted_chs)
if np.any(np.abs(data) > 0.52):
warn(
"Some visual angle values subtend greater than .52 radians "
"(30 degrees), meaning that the conversion between pixels "
"and visual angle may be very non-linear. Take caution when "
"interpreting these values. Max visual angle value in data:"
f" {np.nanmax(data):0.2f} radians.",
UserWarning,
)
else:
warn("Could not find any eyegaze channels. Doing nothing.", UserWarning)
return inst
def _pix_to_rad(data, size, res, dist):
"""Convert pixel coordinates to radians of visual angle.
Parameters
----------
data : array-like, shape (n_samples,)
A vector of pixel coordinates.
size : float
The width or height of the screen, in meters.
res : int
The screen resolution in pixels, along the x or y axis.
dist : float
The viewing distance from the screen, in meters.
Returns
-------
rad : ndarray, shape (n_samples)
the data in radians.
"""
# Center the data so that 0 radians will be the center of the screen
data -= res / 2
# How many meters is the pixel width or height
px_size = size / res
# Convert to radians
return np.arctan((data * px_size) / dist)
def _rad_to_pix(data, size, res, dist):
"""Convert radians of visual angle to pixel coordinates.
See the parameters section of _pix_to_rad for more information.
Returns
-------
pix : ndarray, shape (n_samples)
the data in pixels.
"""
# How many meters is the pixel width or height
px_size = size / res
# 1. calculate length of opposite side of triangle (in meters)
# 2. convert meters to pixel coordinates
# 3. add half of screen resolution to uncenter the pixel data (0,0 is top left)
return np.tan(data) * dist / px_size + res / 2
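# Hypothetical sanity check (not part of the original module): the two helper
# functions above should be inverses of each other for a fixed screen
# geometry. The values are illustrative (0.531 m / 1920 px wide screen viewed
# from 0.7 m).
import numpy as np

px = np.array([0.0, 480.0, 960.0, 1440.0, 1920.0])
rad = _pix_to_rad(px.copy(), size=0.531, res=1920, dist=0.7)  # copy: modified in place
px_back = _rad_to_pix(rad, size=0.531, res=1920, dist=0.7)
np.testing.assert_allclose(px_back, px)  # round trip recovers the pixel values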

View File

@@ -0,0 +1,45 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ...utils import _validate_type
from .calibration import Calibration
def _check_calibration(
calibration, want_keys=("screen_size", "screen_resolution", "screen_distance")
):
missing_keys = []
for key in want_keys:
if calibration.get(key, None) is None:
missing_keys.append(key)
if missing_keys:
raise KeyError(
"Calibration object must have the following keys with valid values:"
f" {', '.join(missing_keys)}"
)
else:
return True
def get_screen_visual_angle(calibration):
"""Calculate the radians of visual angle that the participant screen subtends.
Parameters
----------
calibration : Calibration
An instance of Calibration. Must have valid values for ``"screen_size"`` and
``"screen_distance"`` keys.
Returns
-------
visual angle in radians : ndarray, shape (2,)
The visual angle of the monitor width and height, respectively.
"""
_validate_type(calibration, Calibration, "calibration")
_check_calibration(calibration, want_keys=("screen_size", "screen_distance"))
size = np.array(calibration["screen_size"])
return 2 * np.arctan(size / (2 * calibration["screen_distance"]))
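# Hypothetical usage sketch (not part of the original module): a 0.531 m x
# 0.298 m screen viewed from 0.7 m subtends roughly 2*arctan(0.531/1.4) ~ 0.72
# rad horizontally and 2*arctan(0.298/1.4) ~ 0.42 rad vertically. The
# Calibration values below are illustrative placeholders.
import numpy as np
from mne.preprocessing.eyetracking import Calibration

cal = Calibration(
    onset=0.0, model="HV9", eye="right", avg_error=0.5, max_error=1.0,
    positions=np.zeros((9, 2)), offsets=np.zeros(9), gaze=np.zeros((9, 2)),
    screen_size=(0.531, 0.298), screen_distance=0.7,
    screen_resolution=(1920, 1080),
)
print(get_screen_visual_angle(cal))  # approximately [0.72 0.42] radians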

mne/preprocessing/hfc.py (106 lines, new file)
View File

@@ -0,0 +1,106 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from .._fiff.pick import _picks_to_idx, pick_info
from .._fiff.proj import Projection
from ..utils import verbose
from .maxwell import _prep_mf_coils, _sss_basis
@verbose
def compute_proj_hfc(
info, order=1, picks="meg", exclude="bads", *, accuracy="accurate", verbose=None
):
"""Generate projectors to perform homogeneous/harmonic correction to data.
Remove environmental fields from magnetometer data by assuming it is
explained as a homogeneous :footcite:`TierneyEtAl2021` or harmonic field
:footcite:`TierneyEtAl2022`. Useful for arrays of OPMs.
Parameters
----------
%(info)s
order : int
The order of the spherical harmonic basis set to use. Set to 1 to use
only the homogeneous field component (default), 2 to add gradients, 3
to add quadrature terms, etc.
picks : str | array_like | slice | None
Channels to include. Default of ``'meg'`` (same as None) will select
all non-reference MEG channels. Use ``('meg', 'ref_meg')`` to include
reference sensors as well.
exclude : list | 'bads'
List of channels to exclude from HFC, only used when picking
based on types (e.g., exclude="bads" when picks="meg").
Specify ``'bads'`` (the default) to exclude all channels marked as bad.
accuracy : str
Can be ``"point"``, ``"normal"`` or ``"accurate"`` (default), defines
which level of coil definition accuracy is used to generate model.
%(verbose)s
Returns
-------
%(projs)s
See Also
--------
mne.io.Raw.add_proj
mne.io.Raw.apply_proj
Notes
-----
To apply the projectors to a dataset, use
``inst.add_proj(projs).apply_proj()``.
.. versionadded:: 1.4
References
----------
.. footbibliography::
"""
picks = _picks_to_idx(info, picks, none="meg", exclude=exclude, with_ref_meg=False)
info = pick_info(info, picks)
del picks
exp = dict(origin=(0.0, 0.0, 0.0), int_order=0, ext_order=order)
coils = _prep_mf_coils(info, ignore_ref=False, accuracy=accuracy)
n_chs = len(coils[5])
if n_chs != info["nchan"]:
raise ValueError(
f'Only {n_chs}/{info["nchan"]} picks could be interpreted '
"as MEG channels."
)
S = _sss_basis(exp, coils)
del coils
bad_chans = [
info["ch_names"][pick] for pick in np.where((~np.isfinite(S)).any(axis=1))[0]
]
if bad_chans:
raise ValueError(
"The following channel(s) generate non-finite projectors:\n"
f" {bad_chans}\nPlease exclude from picks!"
)
S /= np.linalg.norm(S, axis=0)
labels = _label_basis(order)
assert len(labels) == S.shape[1]
projs = []
for label, vec in zip(labels, S.T):
proj_data = dict(
col_names=info["ch_names"],
row_names=None,
data=vec[np.newaxis, :],
ncol=info["nchan"],
nrow=1,
)
projs.append(Projection(active=False, data=proj_data, desc=label))
return projs
def _label_basis(order):
"""Give basis vectors names for Projection() class."""
return [
f"HFC: l={L} m={m}"
for L in np.arange(1, order + 1)
for m in np.arange(-1 * L, L + 1)
]
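# Hypothetical usage sketch (not part of the original module): computing and
# applying first-order HFC projectors to an OPM recording. The file name is
# illustrative; order=1 keeps only the homogeneous field component, for which
# _label_basis(1) produces the labels "HFC: l=1 m=-1/0/1".
import mne

raw = mne.io.read_raw_fif("opm_task_raw.fif", preload=True)
projs = compute_proj_hfc(raw.info, order=1)
raw.add_proj(projs).apply_proj()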

mne/preprocessing/ica.py (3552 lines, new file)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,8 @@
"""Intracranial EEG specific preprocessing functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from ._projection import project_sensors_onto_brain
from ._volume import make_montage_volume, warp_montage

View File

@@ -0,0 +1,209 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from itertools import combinations
import numpy as np
from scipy.spatial.distance import pdist, squareform
from ..._fiff.pick import _picks_to_idx
from ...channels import make_dig_montage
from ...surface import (
_compute_nearest,
_read_mri_surface,
_read_patch,
fast_cross_3d,
read_surface,
)
from ...transforms import _cart_to_sph, _ensure_trans, apply_trans, invert_transform
from ...utils import _ensure_int, _validate_type, get_subjects_dir, verbose
@verbose
def project_sensors_onto_brain(
info,
trans,
subject,
subjects_dir=None,
picks=None,
n_neighbors=10,
copy=True,
verbose=None,
):
"""Project sensors onto the brain surface.
Parameters
----------
%(info_not_none)s
%(trans_not_none)s
%(subject)s
%(subjects_dir)s
%(picks_base)s only ``ecog`` channels.
n_neighbors : int
The number of neighbors to use to compute the normal vectors
for the projection. Must be 2 or greater. More neighbors yields a
normal vector with greater averaging, which preserves the grid
structure. Fewer neighbors averages less, which better preserves
contours in the grid.
copy : bool
If ``True``, return a new instance of ``info``, if ``False``
``info`` is modified in place.
%(verbose)s
Returns
-------
%(info_not_none)s
Notes
-----
This is useful in ECoG analysis for compensating for "brain shift"
or shrinking of the brain away from the skull due to changes
in pressure during the craniotomy.
To use the brain surface, a BEM model must be created, e.g. with
:ref:`mne watershed_bem` (using the T1) or :ref:`mne flash_bem`
(using a FLASH scan).
"""
n_neighbors = _ensure_int(n_neighbors, "n_neighbors")
_validate_type(copy, bool, "copy")
if copy:
info = info.copy()
if n_neighbors < 2:
raise ValueError(f"n_neighbors must be 2 or greater, got {n_neighbors}")
subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
try:
surf = _read_mri_surface(subjects_dir / subject / "bem" / "brain.surf")
except FileNotFoundError as err:
raise RuntimeError(
f"{err}\n\nThe brain surface requires generating "
"a BEM using `mne flash_bem` (if you have "
"the FLASH scan) or `mne watershed_bem` (to "
"use the T1)"
) from None
# get channel locations
picks_idx = _picks_to_idx(info, "ecog" if picks is None else picks)
locs = np.array([info["chs"][idx]["loc"][:3] for idx in picks_idx])
trans = _ensure_trans(trans, "head", "mri")
locs = apply_trans(trans, locs)
# compute distances for nearest neighbors
dists = squareform(pdist(locs))
# find angles for brain surface and points
angles = _cart_to_sph(locs)
surf_angles = _cart_to_sph(surf["rr"])
# initialize projected locs
proj_locs = np.zeros(locs.shape) * np.nan
for i, loc in enumerate(locs):
neighbor_pts = locs[np.argsort(dists[i])[: n_neighbors + 1]]
pt1, pt2, pt3 = map(np.array, zip(*combinations(neighbor_pts, 3)))
normals = fast_cross_3d(pt1 - pt2, pt1 - pt3)
normals[normals @ loc < 0] *= -1
normal = np.mean(normals, axis=0)
normal /= np.linalg.norm(normal)
# find the correct orientation brain surface point nearest the line
# https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
use_rr = surf["rr"][
abs(surf_angles[:, 1:] - angles[i, 1:]).sum(axis=1) < np.pi / 4
]
surf_dists = np.linalg.norm(
fast_cross_3d(use_rr - loc, use_rr - loc + normal), axis=1
)
proj_locs[i] = use_rr[np.argmin(surf_dists)]
# back to the "head" coordinate frame for storing in ``raw``
proj_locs = apply_trans(invert_transform(trans), proj_locs)
montage = info.get_montage()
montage_kwargs = (
montage.get_positions() if montage else dict(ch_pos=dict(), coord_frame="head")
)
for idx, loc in zip(picks_idx, proj_locs):
# surface RAS-> head and mm->m
montage_kwargs["ch_pos"][info.ch_names[idx]] = loc
info.set_montage(make_dig_montage(**montage_kwargs))
return info
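# Hypothetical sanity check (not part of the original module) of the
# point-to-line distance used above: for a unit direction ``n``, the distance
# from point ``p`` to the line through ``a`` is ``|cross(p - a, p - a + n)|``,
# because cross(v, v) = 0 leaves only cross(p - a, n).
import numpy as np

a = np.array([0.0, 0.0, 0.0])   # point on the line
n = np.array([0.0, 0.0, 1.0])   # unit direction (the surface normal above)
p = np.array([3.0, 4.0, 10.0])  # query point; true distance to the z-axis is 5
print(np.linalg.norm(np.cross(p - a, p - a + n)))  # 5.0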
@verbose
def _project_sensors_onto_inflated(
info,
trans,
subject,
subjects_dir=None,
picks=None,
max_dist=0.004,
flat=False,
verbose=None,
):
"""Project sensors onto the brain surface.
Parameters
----------
%(info_not_none)s
%(trans_not_none)s
%(subject)s
%(subjects_dir)s
%(picks_base)s only ``seeg`` channels.
%(max_dist_ieeg)s
flat : bool
Whether to project the sensors onto the flat map of the
inflated brain instead of the normal inflated brain.
%(verbose)s
Returns
-------
%(info_not_none)s
Notes
-----
This is useful in sEEG analysis for visualization.
"""
subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
surf_data = dict(lh=dict(), rh=dict())
x_dir = np.array([1.0, 0.0, 0.0])
surfs = ("pial", "inflated")
if flat:
surfs += ("cortex.patch.flat",)
for hemi in ("lh", "rh"):
for surf in surfs:
for img in ("", ".T1", ".T2", ""):
surf_fname = subjects_dir / subject / "surf" / f"{hemi}.{surf}"
if surf_fname.is_file():
break
if surf.split(".")[-1] == "flat":
surf = "flat"
coords, faces, orig_faces = _read_patch(surf_fname)
# rotate 90 degrees to get to a more standard orientation
# where X determines the distance between the hemis
coords = coords[:, [1, 0, 2]]
coords[:, 1] *= -1
else:
coords, faces = read_surface(surf_fname)
if surf in ("inflated", "flat"):
x_ = coords @ x_dir
coords -= np.max(x_) * x_dir if hemi == "lh" else np.min(x_) * x_dir
surf_data[hemi][surf] = (coords / 1000, faces) # mm -> m
# get channel locations
picks_idx = _picks_to_idx(info, "seeg" if picks is None else picks)
locs = np.array([info["chs"][idx]["loc"][:3] for idx in picks_idx])
trans = _ensure_trans(trans, "head", "mri")
locs = apply_trans(trans, locs)
# initialize projected locs
proj_locs = np.zeros(locs.shape) * np.nan
surf = "flat" if flat else "inflated"
for hemi in ("lh", "rh"):
hemi_picks = np.where(locs[:, 0] <= 0 if hemi == "lh" else locs[:, 0] > 0)[0]
# compute distances to pial vertices
nearest, dists = _compute_nearest(
surf_data[hemi]["pial"][0], locs[hemi_picks], return_dists=True
)
mask = dists / 1000 < max_dist
proj_locs[hemi_picks[mask]] = surf_data[hemi][surf][0][nearest[mask]]
# back to the "head" coordinate frame for storing in ``raw``
proj_locs = apply_trans(invert_transform(trans), proj_locs)
montage = info.get_montage()
montage_kwargs = (
montage.get_positions() if montage else dict(ch_pos=dict(), coord_frame="head")
)
for idx, loc in zip(picks_idx, proj_locs):
montage_kwargs["ch_pos"][info.ch_names[idx]] = loc
info.set_montage(make_dig_montage(**montage_kwargs))
return info

View File

@@ -0,0 +1,242 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ...channels import DigMontage, make_dig_montage
from ...surface import _voxel_neighbors
from ...transforms import Transform, _frame_to_str, apply_trans
from ...utils import _check_option, _pl, _require_version, _validate_type, verbose, warn
@verbose
def warp_montage(montage, moving, static, reg_affine, sdr_morph, verbose=None):
"""Warp a montage to a template with image volumes using SDR.
.. note:: This is likely only applicable for channels inside the brain
(intracranial electrodes).
Parameters
----------
montage : instance of mne.channels.DigMontage
The montage object containing the channels.
%(moving)s
%(static)s
%(reg_affine)s
%(sdr_morph)s
%(verbose)s
Returns
-------
montage_warped : mne.channels.DigMontage
The modified montage object containing the channels.
"""
_require_version("nibabel", "warp montage", "2.1.0")
_require_version("dipy", "warping points using SDR", "1.6.0")
from dipy.align.imwarp import DiffeomorphicMap
from nibabel import MGHImage
from nibabel.spatialimages import SpatialImage
_validate_type(moving, SpatialImage, "moving")
_validate_type(static, SpatialImage, "static")
_validate_type(reg_affine, np.ndarray, "reg_affine")
_check_option("reg_affine.shape", reg_affine.shape, ((4, 4),))
_validate_type(sdr_morph, (DiffeomorphicMap, None), "sdr_morph")
_validate_type(montage, DigMontage, "montage")
moving_mgh = MGHImage(np.array(moving.dataobj).astype(np.float32), moving.affine)
static_mgh = MGHImage(np.array(static.dataobj).astype(np.float32), static.affine)
del moving, static
# get montage channel coordinates
ch_dict = montage.get_positions()
if ch_dict["coord_frame"] != "mri":
bad_coord_frames = np.unique([d["coord_frame"] for d in montage.dig])
bad_coord_frames = ", ".join(
[
_frame_to_str[cf] if cf in _frame_to_str else str(cf)
for cf in bad_coord_frames
]
)
raise RuntimeError(
f'Coordinate frame not supported, expected "mri", got {bad_coord_frames}'
)
ch_names = list(ch_dict["ch_pos"].keys())
ch_coords = np.array([ch_dict["ch_pos"][name] for name in ch_names])
ch_coords = apply_trans( # convert to moving voxel space
np.linalg.inv(moving_mgh.header.get_vox2ras_tkr()), ch_coords * 1000
)
# next, to moving scanner RAS
ch_coords = apply_trans(moving_mgh.header.get_vox2ras(), ch_coords)
# now, apply reg_affine
ch_coords = apply_trans(
Transform( # to static ras
fro="ras", to="ras", trans=np.linalg.inv(reg_affine)
),
ch_coords,
)
# now, apply SDR morph
if sdr_morph is not None:
ch_coords = sdr_morph.transform_points(
ch_coords,
coord2world=sdr_morph.domain_grid2world,
world2coord=sdr_morph.domain_world2grid,
)
# back to voxels but now for the static image
ch_coords = apply_trans(np.linalg.inv(static_mgh.header.get_vox2ras()), ch_coords)
# finally, back to surface RAS
ch_coords = apply_trans(static_mgh.header.get_vox2ras_tkr(), ch_coords) / 1000
# make warped montage
montage_warped = make_dig_montage(dict(zip(ch_names, ch_coords)), coord_frame="mri")
return montage_warped
def _warn_missing_chs(info, dig_image, after_warp=False):
"""Warn that channels are missing."""
# ensure that each electrode contact was marked in at least one voxel
missing = set(np.arange(1, len(info.ch_names) + 1)).difference(
set(np.unique(np.array(dig_image.dataobj)))
)
missing_ch = [info.ch_names[idx - 1] for idx in missing]
if missing_ch:
warn(
f"Channel{_pl(missing_ch)} "
f'{", ".join(repr(ch) for ch in missing_ch)} not assigned '
"voxels " + (f" after applying {after_warp}" if after_warp else "")
)
@verbose
def make_montage_volume(
montage,
base_image,
thresh=0.5,
max_peak_dist=1,
voxels_max=100,
use_min=False,
verbose=None,
):
"""Make a volume from intracranial electrode contact locations.
Find areas of the input volume with intensity greater than
a threshold surrounding local extrema near the channel location.
Monotonicity from the peak is enforced to prevent channels
bleeding into each other.
Parameters
----------
montage : instance of mne.channels.DigMontage
The montage object containing the channels.
base_image : path-like | nibabel.spatialimages.SpatialImage
Path to a volumetric scan (e.g. CT) of the subject. Can be in any
format readable by nibabel. Can also be a nibabel image object.
Local extrema (max or min) should be near the montage channel locations.
thresh : float
The threshold relative to the peak to determine the size
of the sensors on the volume.
max_peak_dist : int
The number of voxels away from the channel location to
look in the ``base_image``. This will depend on the accuracy of
the channel locations; the default (one voxel in all directions)
will only work with localizations that are that accurate.
voxels_max : int
The maximum number of voxels for each channel.
use_min : bool
Whether to use hypointensities in the volume as channel locations.
Default False uses hyperintensities.
%(verbose)s
Returns
-------
elec_image : nibabel.spatialimages.SpatialImage
An image in Freesurfer surface RAS space with voxel values
corresponding to the index of the channel. The background
is 0s and this index starts at 1.
"""
_require_version("nibabel", "montage volume", "2.1.0")
import nibabel as nib
_validate_type(montage, DigMontage, "montage")
_validate_type(base_image, nib.spatialimages.SpatialImage, "base_image")
_validate_type(thresh, float, "thresh")
if thresh < 0 or thresh >= 1:
raise ValueError(f"`thresh` must be between 0 and 1, got {thresh}")
_validate_type(max_peak_dist, int, "max_peak_dist")
_validate_type(voxels_max, int, "voxels_max")
_validate_type(use_min, bool, "use_min")
# load image and make sure it's in surface RAS
if not isinstance(base_image, nib.spatialimages.SpatialImage):
base_image = nib.load(base_image)
base_image_mgh = nib.MGHImage(
np.array(base_image.dataobj).astype(np.float32), base_image.affine
)
del base_image
# get montage channel coordinates
ch_dict = montage.get_positions()
if ch_dict["coord_frame"] != "mri":
bad_coord_frames = np.unique([d["coord_frame"] for d in montage.dig])
bad_coord_frames = ", ".join(
[
_frame_to_str[cf] if cf in _frame_to_str else str(cf)
for cf in bad_coord_frames
]
)
raise RuntimeError(
f'Coordinate frame not supported, expected "mri", got {bad_coord_frames}'
)
ch_names = list(ch_dict["ch_pos"].keys())
ch_coords = np.array([ch_dict["ch_pos"][name] for name in ch_names])
# convert to voxel space
ch_coords = apply_trans(
np.linalg.inv(base_image_mgh.header.get_vox2ras_tkr()), ch_coords * 1000
)
# take channel coordinates and use the image to transform them
# into a volume where all the voxels over a threshold nearby
# are labeled with an index
image_data = np.array(base_image_mgh.dataobj)
if use_min:
image_data *= -1
elec_image = np.zeros(base_image_mgh.shape, dtype=int)
for i, ch_coord in enumerate(ch_coords):
if np.isnan(ch_coord).any():
continue
# this looks up to a voxel away, it may be marked imperfectly
volume = _voxel_neighbors(
ch_coord,
image_data,
thresh=thresh,
max_peak_dist=max_peak_dist,
voxels_max=voxels_max,
)
for voxel in volume:
if elec_image[voxel] != 0:
# some voxels ambiguous because the contacts are bridged on
# the image so assign the voxel to the nearest contact location
dist_old = np.sqrt(
(ch_coords[elec_image[voxel] - 1] - voxel) ** 2
).sum()
dist_new = np.sqrt((ch_coord - voxel) ** 2).sum()
if dist_new < dist_old:
elec_image[voxel] = i + 1
else:
elec_image[voxel] = i + 1
# assemble the volume
elec_image = nib.spatialimages.SpatialImage(elec_image, base_image_mgh.affine)
_warn_missing_chs(montage, elec_image, after_warp=False)
return elec_image

View File

@@ -0,0 +1,336 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import math
import numpy as np
from scipy.special import expit
from scipy.stats import kurtosis
from ..utils import check_random_state, logger, random_permutation, verbose
@verbose
def infomax(
data,
weights=None,
l_rate=None,
block=None,
w_change=1e-12,
anneal_deg=60.0,
anneal_step=0.9,
extended=True,
n_subgauss=1,
kurt_size=6000,
ext_blocks=1,
max_iter=200,
random_state=None,
blowup=1e4,
blowup_fac=0.5,
n_small_angle=20,
use_bias=True,
verbose=None,
return_n_iter=False,
):
"""Run (extended) Infomax ICA decomposition on raw data.
Parameters
----------
data : np.ndarray, shape (n_samples, n_features)
The whitened data to unmix.
weights : np.ndarray, shape (n_features, n_features)
The initialized unmixing matrix.
Defaults to None, which means the identity matrix is used.
l_rate : float
This quantity indicates the relative size of the change in weights.
Defaults to ``0.01 / log(n_features ** 2)``.
.. note:: Smaller learning rates will slow down the ICA procedure.
block : int
The block size of randomly chosen data segments.
Defaults to floor(sqrt(n_times / 3.)).
w_change : float
The change at which to stop iteration. Defaults to 1e-12.
anneal_deg : float
The angle (in degrees) at which the learning rate will be reduced.
Defaults to 60.0.
anneal_step : float
The factor by which the learning rate will be reduced once
``anneal_deg`` is exceeded: ``l_rate *= anneal_step.``
Defaults to 0.9.
extended : bool
Whether to use the extended Infomax algorithm or not.
Defaults to True.
n_subgauss : int
The number of subgaussian components. Only considered for extended
Infomax. Defaults to 1.
kurt_size : int
The window size for kurtosis estimation. Only considered for extended
Infomax. Defaults to 6000.
ext_blocks : int
Only considered for extended Infomax. If positive, denotes the number
of blocks after which to recompute the kurtosis, which is used to
estimate the signs of the sources. In this case, the number of
sub-gaussian sources is automatically determined.
If negative, the number of sub-gaussian sources to be used is fixed
and equal to n_subgauss. In this case, the kurtosis is not estimated.
Defaults to 1.
max_iter : int
The maximum number of iterations. Defaults to 200.
%(random_state)s
blowup : float
The maximum difference allowed between two successive estimations of
the unmixing matrix. Defaults to 10000.
blowup_fac : float
The factor by which the learning rate will be reduced if the difference
between two successive estimations of the unmixing matrix exceeds
``blowup``: ``l_rate *= blowup_fac``. Defaults to 0.5.
n_small_angle : int | None
The maximum number of allowed steps in which the angle between two
successive estimations of the unmixing matrix is less than
``anneal_deg``. If None, this parameter is not taken into account to
stop the iterations. Defaults to 20.
use_bias : bool
This quantity indicates if the bias should be computed.
Defaults to True.
%(verbose)s
return_n_iter : bool
Whether to return the number of iterations performed. Defaults to
False.
Returns
-------
unmixing_matrix : np.ndarray, shape (n_features, n_features)
The linear unmixing operator.
n_iter : int
The number of iterations. Only returned if ``return_n_iter=True``.
References
----------
.. [1] A. J. Bell, T. J. Sejnowski. An information-maximization approach to
blind separation and blind deconvolution. Neural Computation, 7(6),
1129-1159, 1995.
.. [2] T. W. Lee, M. Girolami, T. J. Sejnowski. Independent component
analysis using an extended infomax algorithm for mixed subgaussian
and supergaussian sources. Neural Computation, 11(2), 417-441, 1999.
"""
rng = check_random_state(random_state)
# define some default parameters
max_weight = 1e8
restart_fac = 0.9
min_l_rate = 1e-10
degconst = 180.0 / np.pi
# for extended Infomax
extmomentum = 0.5
signsbias = 0.02
signcount_threshold = 25
signcount_step = 2
# check data shape
n_samples, n_features = data.shape
n_features_square = n_features**2
# check input parameters
# heuristic default - may need adjustment for large or tiny data sets
if l_rate is None:
l_rate = 0.01 / math.log(n_features**2.0)
if block is None:
block = int(math.floor(math.sqrt(n_samples / 3.0)))
logger.info(f"Computing{' Extended ' if extended else ' '}Infomax ICA")
# collect parameters
nblock = n_samples // block
lastt = (nblock - 1) * block + 1
# initialize training
if weights is None:
weights = np.identity(n_features, dtype=np.float64)
else:
weights = weights.T
BI = block * np.identity(n_features, dtype=np.float64)
bias = np.zeros((n_features, 1), dtype=np.float64)
onesrow = np.ones((1, block), dtype=np.float64)
startweights = weights.copy()
oldweights = startweights.copy()
step = 0
count_small_angle = 0
wts_blowup = False
blockno = 0
signcount = 0
initial_ext_blocks = ext_blocks # save the initial value in case of reset
# for extended Infomax
if extended:
signs = np.ones(n_features)
for k in range(n_subgauss):
signs[k] = -1
kurt_size = min(kurt_size, n_samples)
old_kurt = np.zeros(n_features, dtype=np.float64)
oldsigns = np.zeros(n_features)
# trainings loop
olddelta, oldchange = 1.0, 0.0
while step < max_iter:
# shuffle data at each step
permute = random_permutation(n_samples, rng)
# ICA training block
# loop across block samples
for t in range(0, lastt, block):
u = np.dot(data[permute[t : t + block], :], weights)
u += np.dot(bias, onesrow).T
if extended:
# extended ICA update
y = np.tanh(u)
weights += l_rate * np.dot(
weights, BI - signs[None, :] * np.dot(u.T, y) - np.dot(u.T, u)
)
if use_bias:
bias += l_rate * np.reshape(
np.sum(y, axis=0, dtype=np.float64) * -2.0, (n_features, 1)
)
else:
# logistic ICA weights update
y = expit(u)
weights += l_rate * np.dot(weights, BI + np.dot(u.T, (1.0 - 2.0 * y)))
if use_bias:
bias += l_rate * np.reshape(
np.sum((1.0 - 2.0 * y), axis=0, dtype=np.float64),
(n_features, 1),
)
# check change limit
max_weight_val = np.max(np.abs(weights))
if max_weight_val > max_weight:
wts_blowup = True
blockno += 1
if wts_blowup:
break
# ICA kurtosis estimation
if extended:
if ext_blocks > 0 and blockno % ext_blocks == 0:
if kurt_size < n_samples:
rp = np.floor(rng.uniform(0, 1, kurt_size) * (n_samples - 1))
tpartact = np.dot(data[rp.astype(int), :], weights).T
else:
tpartact = np.dot(data, weights).T
# estimate kurtosis
kurt = kurtosis(tpartact, axis=1, fisher=True)
if extmomentum != 0:
kurt = extmomentum * old_kurt + (1.0 - extmomentum) * kurt
old_kurt = kurt
# estimate weighted signs
signs = np.sign(kurt + signsbias)
ndiff = (signs - oldsigns != 0).sum()
if ndiff == 0:
signcount += 1
else:
signcount = 0
oldsigns = signs
if signcount >= signcount_threshold:
ext_blocks = np.fix(ext_blocks * signcount_step)
signcount = 0
# here we continue after the for loop over the ICA training blocks
# if weights in bounds:
if not wts_blowup:
oldwtchange = weights - oldweights
step += 1
angledelta = 0.0
delta = oldwtchange.reshape(1, n_features_square)
change = np.sum(delta * delta, dtype=np.float64)
if step > 2:
angledelta = math.acos(
np.sum(delta * olddelta) / math.sqrt(change * oldchange)
)
angledelta *= degconst
if verbose:
logger.info(
"step %d - lrate %5f, wchange %8.8f, angledelta %4.1f deg",
step,
l_rate,
change,
angledelta,
)
# anneal learning rate
oldweights = weights.copy()
if angledelta > anneal_deg:
l_rate *= anneal_step # anneal learning rate
# accumulate angledelta until anneal_deg reaches l_rate
olddelta = delta
oldchange = change
count_small_angle = 0 # reset count when angledelta is large
else:
if step == 1: # on first step only
olddelta = delta # initialize
oldchange = change
if n_small_angle is not None:
count_small_angle += 1
if count_small_angle > n_small_angle:
max_iter = step
# apply stopping rule
if step > 2 and change < w_change:
step = max_iter
elif change > blowup:
l_rate *= blowup_fac
# restart if weights blow up (for lowering l_rate)
else:
step = 0 # start again
wts_blowup = 0 # re-initialize variables
blockno = 1
l_rate *= restart_fac # with lower learning rate
weights = startweights.copy()
oldweights = startweights.copy()
olddelta = np.zeros((1, n_features_square), dtype=np.float64)
bias = np.zeros((n_features, 1), dtype=np.float64)
ext_blocks = initial_ext_blocks
# for extended Infomax
if extended:
signs = np.ones(n_features)
for k in range(n_subgauss):
signs[k] = -1
oldsigns = np.zeros(n_features)
if l_rate > min_l_rate:
if verbose:
logger.info(
f"... lowering learning rate to {l_rate:g}"
"\n... re-starting..."
)
else:
raise ValueError(
"Error in Infomax ICA: unmixing_matrix matrix"
"might not be invertible!"
)
# prepare return values
if return_n_iter:
return weights.T, step
else:
return weights.T
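# Hypothetical usage sketch (not part of the original module): unmixing two
# synthetically mixed, whitened sources with extended Infomax. All data are
# synthetic; in practice ``data`` would be PCA-whitened sensor data, as done
# internally by mne.preprocessing.ICA.
import numpy as np

rng = np.random.RandomState(0)
srcs = np.c_[np.sin(np.linspace(0, 40, 2000)), rng.laplace(size=2000)]
mixed = srcs @ np.array([[1.0, 0.5], [0.4, 1.0]]).T  # mix the two sources
mixed -= mixed.mean(axis=0)
white = mixed @ np.linalg.inv(np.linalg.cholesky(np.cov(mixed.T))).T  # whiten
unmixing = infomax(white, extended=True, random_state=0)
recovered = white @ unmixing.T  # estimated sources (up to scale, order, sign)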

View File

@@ -0,0 +1,231 @@
"""Tools for data interpolation."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from itertools import chain
import numpy as np
from scipy.sparse.csgraph import connected_components
from .._fiff.meas_info import create_info
from ..epochs import BaseEpochs, EpochsArray
from ..evoked import Evoked, EvokedArray
from ..io import BaseRaw, RawArray
from ..transforms import _cart_to_sph, _sph_to_cart
from ..utils import _ensure_int, _validate_type
def equalize_bads(insts, interp_thresh=1.0, copy=True):
"""Interpolate or mark bads consistently for a list of instances.
Once called on a list of instances, the instances can be concatenated
as they will have the same list of bad channels.
Parameters
----------
insts : list
The list of instances (Evoked, Epochs or Raw) to consider
for interpolation. Each instance should have its bad channels marked.
interp_thresh : float
A float between 0 and 1 (default) that specifies the fraction of time
a channel should be good to be eventually interpolated for certain
instances. For example if 0.5, a channel which is good at least half
of the time will be interpolated in the instances where it is marked
as bad. If 1 then channels will never be interpolated and if 0 all bad
channels will be systematically interpolated.
copy : bool
If True then the returned instances will be copies.
Returns
-------
insts_bads : list
The list of instances, with the same channel(s) marked as bad in all of
them, possibly with some formerly bad channels interpolated.
"""
if not 0 <= interp_thresh <= 1:
raise ValueError(f"interp_thresh must be between 0 and 1, got {interp_thresh}")
all_bads = list(set(chain.from_iterable([inst.info["bads"] for inst in insts])))
if isinstance(insts[0], BaseEpochs):
durations = [len(inst) * len(inst.times) for inst in insts]
else:
durations = [len(inst.times) for inst in insts]
good_times = []
for ch_name in all_bads:
good_times.append(
sum(
durations[k]
for k, inst in enumerate(insts)
if ch_name not in inst.info["bads"]
)
/ np.sum(durations)
)
bads_keep = [ch for k, ch in enumerate(all_bads) if good_times[k] < interp_thresh]
if copy:
insts = [inst.copy() for inst in insts]
for inst in insts:
if len(set(inst.info["bads"]) - set(bads_keep)):
inst.interpolate_bads(exclude=bads_keep)
inst.info["bads"] = bads_keep
return insts
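# Hypothetical usage sketch (not part of the original module): harmonizing bad
# channels across two runs before concatenation. File and channel names are
# illustrative.
import mne

raws = [
    mne.io.read_raw_fif("run_01_raw.fif", preload=True),
    mne.io.read_raw_fif("run_02_raw.fif", preload=True),
]
raws[0].info["bads"] = ["EEG 053"]  # bad in the first run only
raws = equalize_bads(raws, interp_thresh=0.5)
raw_all = mne.concatenate_raws(raws)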
def interpolate_bridged_electrodes(inst, bridged_idx, bad_limit=4):
"""Interpolate bridged electrode pairs.
Bridged electrodes still contain brain signal; it is merely spatially
smeared between the two electrodes. We can therefore make a virtual
channel midway between each bridged pair and use it to aid interpolation,
rather than completely discarding the data from the two channels.
Parameters
----------
inst : instance of Epochs, Evoked, or Raw
The data object with channels that are to be interpolated.
bridged_idx : list of tuple
The indices of channels marked as bridged with each bridged
pair stored as a tuple.
bad_limit : int
The maximum number of electrodes that can be bridged together
(included) and interpolated. Above this number, an error will be
raised.
.. versionadded:: 1.2
Returns
-------
inst : instance of Epochs, Evoked, or Raw
The modified data object.
See Also
--------
mne.preprocessing.compute_bridged_electrodes
"""
_validate_type(inst, (BaseRaw, BaseEpochs, Evoked))
bad_limit = _ensure_int(bad_limit, "bad_limit")
if bad_limit <= 0:
raise ValueError(
"Argument 'bad_limit' should be a strictly positive "
f"integer. Provided {bad_limit} is invalid."
)
montage = inst.get_montage()
if montage is None:
raise RuntimeError("No channel positions found in ``inst``")
pos = montage.get_positions()
if pos["coord_frame"] != "head":
raise RuntimeError(
f"Montage channel positions must be in ``head`` got {pos['coord_frame']}"
)
# store bads orig to put back at the end
bads_orig = inst.info["bads"]
inst.info["bads"] = list()
# look for group of bad channels
nodes = sorted(set(chain(*bridged_idx)))
G_dense = np.zeros((len(nodes), len(nodes)))
# fill the edges with a weight of 1
for bridge in bridged_idx:
idx0 = np.searchsorted(nodes, bridge[0])
idx1 = np.searchsorted(nodes, bridge[1])
G_dense[idx0, idx1] = 1
G_dense[idx1, idx0] = 1
# look for connected components
_, labels = connected_components(G_dense, directed=False)
groups_idx = [[nodes[j] for j in np.where(labels == k)[0]] for k in set(labels)]
groups_names = [
[inst.info.ch_names[k] for k in group_idx] for group_idx in groups_idx
]
# raise an error for bridged areas that include too many electrodes
for group_names in groups_names:
if len(group_names) > bad_limit:
raise RuntimeError(
f"The channels {', '.join(group_names)} are bridged together "
"and form a large area of bridged electrodes. Interpolation "
"might be inaccurate."
)
# make virtual channels
virtual_chs = dict()
bads = set()
for k, group_idx in enumerate(groups_idx):
group_names = [inst.info.ch_names[k] for k in group_idx]
bads = bads.union(group_names)
# compute centroid position in spherical "head" coordinates
pos_virtual = _find_centroid_sphere(pos["ch_pos"], group_names)
# create the virtual channel info and set the position
virtual_info = create_info([f"virtual {k + 1}"], inst.info["sfreq"], "eeg")
virtual_info["chs"][0]["loc"][:3] = pos_virtual
# create virtual channel
data = inst.get_data(picks=group_names)
if isinstance(inst, BaseRaw):
data = np.average(data, axis=0).reshape(1, -1)
virtual_ch = RawArray(data, virtual_info, first_samp=inst.first_samp)
elif isinstance(inst, BaseEpochs):
data = np.average(data, axis=1).reshape(len(data), 1, -1)
virtual_ch = EpochsArray(data, virtual_info, tmin=inst.tmin)
else: # evoked
data = np.average(data, axis=0).reshape(1, -1)
virtual_ch = EvokedArray(
np.average(data, axis=0).reshape(1, -1),
virtual_info,
tmin=inst.tmin,
nave=inst.nave,
kind=inst.kind,
)
virtual_chs[f"virtual {k + 1}"] = virtual_ch
# add the virtual channels
inst.add_channels(list(virtual_chs.values()), force_update_info=True)
# use the virtual channels to interpolate
inst.info["bads"] = list(bads)
inst.interpolate_bads()
# drop virtual channels
inst.drop_channels(list(virtual_chs.keys()))
inst.info["bads"] = bads_orig
return inst
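# Hypothetical usage sketch (not part of the original module): detecting
# bridged pairs with compute_bridged_electrodes and repairing them. The file
# name is illustrative, and a montage with "head" positions must be set.
import mne
from mne.preprocessing import compute_bridged_electrodes

raw = mne.io.read_raw_fif("sub-01_task-rest_eeg_raw.fif", preload=True)
bridged_idx, ed_matrix = compute_bridged_electrodes(raw)
raw = interpolate_bridged_electrodes(raw, bridged_idx, bad_limit=4)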
def _find_centroid_sphere(ch_pos, group_names):
"""Compute the centroid position between N electrodes.
The centroid should be determined in spherical "head" coordinates, which is
more accurate than cutting through the scalp by averaging in cartesian
coordinates.
A simple way is to average the locations in cartesian coordinates, convert
to spherical coordinates, and replace the radius with the average radius of
the N points in spherical coordinates.
Parameters
----------
ch_pos : OrderedDict
The position of all channels in cartesian coordinates.
group_names : list | tuple
The name of the N electrodes used to determine the centroid.
Returns
-------
pos_centroid : array of shape (3,)
The position of the centroid in cartesian coordinates.
"""
cartesian_positions = np.array([ch_pos[ch_name] for ch_name in group_names])
sphere_positions = _cart_to_sph(cartesian_positions)
cartesian_pos_centroid = np.average(cartesian_positions, axis=0)
sphere_pos_centroid = _cart_to_sph(cartesian_pos_centroid)
# average the radius and overwrite it
avg_radius = np.average(sphere_positions, axis=0)[0]
sphere_pos_centroid[0, 0] = avg_radius
# convert back to cartesian
pos_centroid = _sph_to_cart(sphere_pos_centroid)[0, :]
return pos_centroid

mne/preprocessing/maxwell.py (2893 lines, new file)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,22 @@
"""NIRS specific preprocessing functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from .nirs import (
short_channels,
source_detector_distances,
_check_channels_ordered,
_channel_frequencies,
_fnirs_spread_bads,
_channel_chromophore,
_validate_nirs_info,
_fnirs_optode_names,
_optode_position,
_reorder_nirx,
)
from ._optical_density import optical_density
from ._beer_lambert_law import beer_lambert_law
from ._scalp_coupling_index import scalp_coupling_index
from ._tddr import temporal_derivative_distribution_repair, tddr

View File

@@ -0,0 +1,115 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import os.path as op
import numpy as np
from scipy.interpolate import interp1d
from scipy.io import loadmat
from ..._fiff.constants import FIFF
from ...io import BaseRaw
from ...utils import _validate_type, pinv, warn
from ..nirs import _validate_nirs_info, source_detector_distances
def beer_lambert_law(raw, ppf=6.0):
r"""Convert NIRS optical density data to haemoglobin concentration.
Parameters
----------
raw : instance of Raw
The optical density data.
ppf : tuple | float
The partial pathlength factors for each wavelength.
.. versionchanged:: 1.7
Support for different factors for the two wavelengths.
Returns
-------
raw : instance of Raw
The modified raw instance.
"""
raw = raw.copy().load_data()
_validate_type(raw, BaseRaw, "raw")
_validate_type(ppf, ("numeric", "array-like"), "ppf")
ppf = np.array(ppf, float)
if ppf.ndim == 0: # upcast single float to shape (2,)
ppf = np.array([ppf, ppf])
if ppf.shape != (2,):
raise ValueError(
f"ppf must be float or array-like of shape (2,), got shape {ppf.shape}"
)
ppf = ppf[:, np.newaxis] # shape (2, 1)
picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert")
# This is the one place we *really* need the actual/accurate frequencies
freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float)
abs_coef = _load_absorption(freqs)
distances = source_detector_distances(raw.info, picks="all")
bad = ~np.isfinite(distances[picks])
bad |= distances[picks] <= 0
if bad.any():
warn(
"Source-detector distances are zero on NaN, some resulting "
"concentrations will be zero. Consider setting a montage "
"with raw.set_montage."
)
distances[picks[bad]] = 0.0
if (distances[picks] > 0.1).any():
warn(
"Source-detector distances are greater than 10 cm. "
"Large distances will result in invalid data, and are "
"likely due to optode locations being stored in a "
" unit other than meters."
)
rename = dict()
for ii, jj in zip(picks[::2], picks[1::2]):
EL = abs_coef * distances[ii] * ppf
iEL = pinv(EL)
raw._data[[ii, jj]] = iEL @ raw._data[[ii, jj]] * 1e-3
# Update channel information
coil_dict = dict(hbo=FIFF.FIFFV_COIL_FNIRS_HBO, hbr=FIFF.FIFFV_COIL_FNIRS_HBR)
for ki, kind in zip((ii, jj), ("hbo", "hbr")):
ch = raw.info["chs"][ki]
ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL)
new_name = f'{ch["ch_name"].split(" ")[0]} {kind}'
rename[ch["ch_name"]] = new_name
raw.rename_channels(rename)
# Validate the format of data after transformation is valid
_validate_nirs_info(raw.info, fnirs="hb")
return raw
def _load_absorption(freqs):
"""Load molar extinction coefficients."""
# Data from https://omlc.org/spectra/hemoglobin/summary.html
# The text was copied to a text file. The text before and
# after the table was deleted. Then the following was run in
# MATLAB:
# extinct_coef=importdata('extinction_coef.txt')
# save('extinction_coef.mat', 'extinct_coef')
#
# Returns data as [[HbO2(freq1), Hb(freq1)],
# [HbO2(freq2), Hb(freq2)]]
extinction_fname = op.join(
op.dirname(__file__), "..", "..", "data", "extinction_coef.mat"
)
a = loadmat(extinction_fname)["extinct_coef"]
interp_hbo = interp1d(a[:, 0], a[:, 1], kind="linear")
interp_hb = interp1d(a[:, 0], a[:, 2], kind="linear")
ext_coef = np.array(
[
[interp_hbo(freqs[0]), interp_hb(freqs[0])],
[interp_hbo(freqs[1]), interp_hb(freqs[1])],
]
)
abs_coef = ext_coef * 0.2303
return abs_coef
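# Hypothetical usage sketch (not part of the original module): the usual fNIRS
# conversion chain from raw intensity to haemoglobin concentration. The NIRx
# folder path is illustrative; ppf=6.0 is the conventional partial pathlength
# factor, applied here to both wavelengths.
import mne
from mne.preprocessing.nirs import optical_density

raw_intensity = mne.io.read_raw_nirx("path/to/nirx_recording", preload=True)
raw_od = optical_density(raw_intensity)
raw_hb = beer_lambert_law(raw_od, ppf=6.0)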

View File

@@ -0,0 +1,53 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ..._fiff.constants import FIFF
from ...io import BaseRaw
from ...utils import _validate_type, verbose, warn
from ..nirs import _validate_nirs_info
@verbose
def optical_density(raw, *, verbose=None):
r"""Convert NIRS raw data to optical density.
Parameters
----------
raw : instance of Raw
The raw data.
%(verbose)s
Returns
-------
raw : instance of Raw
The modified raw instance.
"""
raw = raw.copy().load_data()
_validate_type(raw, BaseRaw, "raw")
picks = _validate_nirs_info(raw.info, fnirs="cw_amplitude")
# The devices measure light intensity. Negative light intensities should
# not occur. If they do it is likely due to hardware or movement issues.
# Set all negative values to abs(x), this also has the benefit of ensuring
# that the means are all greater than zero for the division below.
if np.any(raw._data[picks] <= 0):
warn("Negative intensities encountered. Setting to abs(x)")
min_ = np.inf
for pi in picks:
np.abs(raw._data[pi], out=raw._data[pi])
min_ = min(min_, raw._data[pi].min() or min_)
# avoid == 0
for pi in picks:
np.maximum(raw._data[pi], min_, out=raw._data[pi])
for pi in picks:
data_mean = np.mean(raw._data[pi])
raw._data[pi] /= data_mean
np.log(raw._data[pi], out=raw._data[pi])
raw._data[pi] *= -1
raw.info["chs"][pi]["coil_type"] = FIFF.FIFFV_COIL_FNIRS_OD
return raw

View File

@@ -0,0 +1,69 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ...io import BaseRaw
from ...utils import _validate_type, verbose
from ..nirs import _validate_nirs_info
@verbose
def scalp_coupling_index(
raw,
l_freq=0.7,
h_freq=1.5,
l_trans_bandwidth=0.3,
h_trans_bandwidth=0.3,
verbose=False,
):
r"""Calculate scalp coupling index.
This function calculates the scalp coupling index
:footcite:`pollonini2014auditory`. This is a measure of the quality of the
connection between the optode and the scalp.
Parameters
----------
raw : instance of Raw
The raw data.
%(l_freq)s
%(h_freq)s
%(l_trans_bandwidth)s
%(h_trans_bandwidth)s
%(verbose)s
Returns
-------
sci : array of float
Array containing scalp coupling index for each channel.
References
----------
.. footbibliography::
"""
_validate_type(raw, BaseRaw, "raw")
picks = _validate_nirs_info(raw.info, fnirs="od", which="Scalp coupling index")
raw = raw.copy().pick(picks).load_data()
zero_mask = np.std(raw._data, axis=-1) == 0
filtered_data = raw.filter(
l_freq,
h_freq,
l_trans_bandwidth=l_trans_bandwidth,
h_trans_bandwidth=h_trans_bandwidth,
verbose=verbose,
).get_data()
sci = np.zeros(picks.shape)
for ii in range(0, len(picks), 2):
with np.errstate(invalid="ignore"):
c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1]
if not np.isfinite(c): # someone had std=0
c = 0
sci[ii] = c
sci[ii + 1] = c
sci[zero_mask] = 0
sci = sci[np.argsort(picks)] # restore original order
return sci
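
A minimal numpy/scipy sketch of the idea: band-pass both wavelengths of one source-detector pair to the cardiac band and correlate them. It uses a Butterworth filter instead of the FIR filtering done by raw.filter above, and all signals are synthetic.

import numpy as np
from scipy.signal import butter, filtfilt

sfreq = 10.0                                         # assumed sampling rate in Hz
t = np.arange(0, 60, 1 / sfreq)
rng = np.random.default_rng(0)
cardiac = np.sin(2 * np.pi * 1.2 * t)                # shared ~1.2 Hz cardiac signal
od1 = cardiac + 0.3 * rng.standard_normal(t.size)    # wavelength 1 of the pair
od2 = cardiac + 0.3 * rng.standard_normal(t.size)    # wavelength 2 of the pair

b, a = butter(4, [0.7 / (sfreq / 2), 1.5 / (sfreq / 2)], btype="band")
sci_pair = np.corrcoef(filtfilt(b, a, od1), filtfilt(b, a, od2))[0, 1]
# both channels of the pair would receive this value, as in the loop above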

View File

@@ -0,0 +1,155 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from scipy.signal import butter, filtfilt
from ...io import BaseRaw
from ...utils import _validate_type, verbose
from ..nirs import _validate_nirs_info
@verbose
def temporal_derivative_distribution_repair(raw, *, verbose=None):
"""Apply temporal derivative distribution repair to data.
Applies temporal derivative distribution repair (TDDR) to data
:footcite:`FishburnEtAl2019`. This approach removes baseline shift
and spike artifacts without the need for any user-supplied parameters.
Parameters
----------
raw : instance of Raw
The raw data.
%(verbose)s
Returns
-------
raw : instance of Raw
Data with TDDR applied.
Notes
-----
TDDR was initially designed for optical density fNIRS data, but in MNE it
can also be applied to hemoglobin concentration fNIRS data. We recommend
applying the algorithm to optical density fNIRS data, as intended by the
original author, wherever possible.
There is a shorter alias ``mne.preprocessing.nirs.tddr`` that can be used
instead of this function (e.g. if line length is an issue).
References
----------
.. footbibliography::
"""
raw = raw.copy().load_data()
_validate_type(raw, BaseRaw, "raw")
picks = _validate_nirs_info(raw.info)
if not len(picks):
raise RuntimeError("TDDR should be run on optical density or hemoglobin data.")
for pick in picks:
raw._data[pick] = _TDDR(raw._data[pick], raw.info["sfreq"])
return raw
# provide a short alias
tddr = temporal_derivative_distribution_repair
# Taken from https://github.com/frankfishburn/TDDR/ (MIT license).
# With permission https://github.com/frankfishburn/TDDR/issues/1.
# The only modification is the name, scipy signal import and flake fixes.
def _TDDR(signal, sample_rate):
# This function is the reference implementation for the TDDR algorithm for
# motion correction of fNIRS data, as described in:
#
# Fishburn F.A., Ludlum R.S., Vaidya C.J., & Medvedev A.V. (2019).
# Temporal Derivative Distribution Repair (TDDR): A motion correction
# method for fNIRS. NeuroImage, 184, 171-179.
# https://doi.org/10.1016/j.neuroimage.2018.09.025
#
# Usage:
# signals_corrected = TDDR( signals , sample_rate );
#
# Inputs:
# signals: A [sample x channel] matrix of uncorrected optical density or
# hemoglobin data
# sample_rate: A scalar reflecting the rate of acquisition in Hz
#
# Outputs:
# signals_corrected: A [sample x channel] matrix of corrected optical
# density data
signal = np.array(signal)
if len(signal.shape) != 1:
for ch in range(signal.shape[1]):
signal[:, ch] = _TDDR(signal[:, ch], sample_rate)
return signal
# Preprocess: Separate high and low frequencies
filter_cutoff = 0.5
filter_order = 3
Fc = filter_cutoff * 2 / sample_rate
signal_mean = np.mean(signal)
signal -= signal_mean
if Fc < 1:
fb, fa = butter(filter_order, Fc)
signal_low = filtfilt(fb, fa, signal, padlen=0)
else:
signal_low = signal
signal_high = signal - signal_low
# Initialize
tune = 4.685
D = np.sqrt(np.finfo(signal.dtype).eps)
mu = np.inf
# Step 1. Compute temporal derivative of the signal
deriv = np.diff(signal_low)
# Step 2. Initialize observation weights
w = np.ones(deriv.shape)
# Step 3. Iterative estimation of robust weights
for _ in range(50):
mu0 = mu
# Step 3a. Estimate weighted mean
mu = np.sum(w * deriv) / np.sum(w)
# Step 3b. Calculate absolute residuals of estimate
dev = np.abs(deriv - mu)
# Step 3c. Robust estimate of standard deviation of the residuals
sigma = 1.4826 * np.median(dev)
# Step 3d. Scale deviations by standard deviation and tuning parameter
if sigma == 0:
break
r = dev / (sigma * tune)
# Step 3e. Calculate new weights according to Tukey's biweight function
w = ((1 - r**2) * (r < 1)) ** 2
# Step 3f. Terminate if new estimate is within
# machine-precision of old estimate
if abs(mu - mu0) < D * max(abs(mu), abs(mu0)):
break
# Step 4. Apply robust weights to centered derivative
new_deriv = w * (deriv - mu)
# Step 5. Integrate corrected derivative
signal_low_corrected = np.cumsum(np.insert(new_deriv, 0, 0.0))
# Postprocess: Center the corrected signal
signal_low_corrected = signal_low_corrected - np.mean(signal_low_corrected)
# Postprocess: Merge back with uncorrected high frequency component
signal_corrected = signal_low_corrected + signal_high + signal_mean
return signal_corrected
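
A usage sketch, assuming ``raw`` is a hypothetical preloaded continuous-wave fNIRS recording and that, as the alias note above suggests, these helpers are importable from ``mne.preprocessing.nirs``:

from mne.preprocessing.nirs import (
    optical_density,
    temporal_derivative_distribution_repair,
)

raw_od = optical_density(raw)                        # convert intensities to OD first
raw_clean = temporal_derivative_distribution_repair(raw_od)
# raw_clean = tddr(raw_od) would be equivalent via the short alias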

View File

@@ -0,0 +1,336 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import re
import numpy as np
from ..._fiff.pick import _picks_to_idx, pick_types
from ...utils import _check_option, _validate_type, fill_doc
# Standardized fNIRS channel name regexs
_S_D_F_RE = re.compile(r"S(\d+)_D(\d+) (\d+\.?\d*)")
_S_D_H_RE = re.compile(r"S(\d+)_D(\d+) (\w+)")
@fill_doc
def source_detector_distances(info, picks=None):
r"""Determine the distance between NIRS source and detectors.
Parameters
----------
%(info_not_none)s
%(picks_all_data)s
Returns
-------
dists : array of float
Array containing distances in meters.
Of shape equal to number of channels, or shape of picks if supplied.
"""
return np.array(
[
np.linalg.norm(
np.diff(info["chs"][pick]["loc"][3:9].reshape(2, 3), axis=0)[0]
)
for pick in _picks_to_idx(info, picks, exclude=[])
],
float,
)
@fill_doc
def short_channels(info, threshold=0.01):
r"""Determine which NIRS channels are short.
Channels with a source to detector distance of less than
``threshold`` are reported as short. The default threshold is 0.01 m.
Parameters
----------
%(info_not_none)s
threshold : float
The threshold distance for what is considered short in meters.
Returns
-------
short : array of bool
Array indicating which channels are short.
Of shape equal to number of channels.
"""
return source_detector_distances(info) < threshold
def _channel_frequencies(info):
"""Return the light frequency for each channel."""
# Only valid for fNIRS data before conversion to haemoglobin
picks = _picks_to_idx(
info, ["fnirs_cw_amplitude", "fnirs_od"], exclude=[], allow_empty=True
)
freqs = list()
for pick in picks:
freqs.append(round(float(_S_D_F_RE.match(info["ch_names"][pick]).groups()[2])))
return np.array(freqs, int)
def _channel_chromophore(info):
"""Return the chromophore of each channel."""
# Only valid for fNIRS data after conversion to haemoglobin
picks = _picks_to_idx(info, ["hbo", "hbr"], exclude=[], allow_empty=True)
chroma = []
for ii in picks:
chroma.append(info["ch_names"][ii].split(" ")[1])
return chroma
def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=True):
"""Check channels follow expected fNIRS format.
If the channels are correctly ordered then an array of valid picks
will be returned.
If throw_errors is True then any errors in fNIRS formatting will be
thrown to inform the user. If throw_errors is False then an empty array
will be returned if the channels are not sufficiently formatted.
"""
# Every second channel should be same SD pair
# and have the specified light frequencies.
# All wavelength based fNIRS data.
picks_wave = _picks_to_idx(
info, ["fnirs_cw_amplitude", "fnirs_od"], exclude=[], allow_empty=True
)
# All chromophore fNIRS data
picks_chroma = _picks_to_idx(info, ["hbo", "hbr"], exclude=[], allow_empty=True)
if (len(picks_wave) > 0) & (len(picks_chroma) > 0):
picks = _throw_or_return_empty(
"MNE does not support a combination of amplitude, optical "
"density, and haemoglobin data in the same raw structure.",
throw_errors,
)
# All continuous wave fNIRS data
if len(picks_wave):
error_word = "frequencies"
use_RE = _S_D_F_RE
picks = picks_wave
else:
error_word = "chromophore"
use_RE = _S_D_H_RE
picks = picks_chroma
pair_vals = np.array(pair_vals)
if pair_vals.shape != (2,):
raise ValueError(
f"Exactly two {error_word} must exist in info, got {list(pair_vals)}"
)
# In principle we do not need to require that these be sorted --
# all we need to do is change our sorted() below to make use of a
# pair_vals.index(...) in a sort key -- but in practice we always want
# (hbo, hbr) or (lower_freq, upper_freq) pairings, both of which will
# work with a naive string sort, so let's just enforce sorted-ness here
is_str = pair_vals.dtype.kind == "U"
pair_vals = list(pair_vals)
if is_str:
if pair_vals != ["hbo", "hbr"]:
raise ValueError(
f'The {error_word} in info must be ["hbo", "hbr"], but got '
f"{pair_vals} instead"
)
elif not np.array_equal(np.unique(pair_vals), pair_vals):
raise ValueError(
f"The {error_word} in info must be unique and sorted, but got "
f"got {pair_vals} instead"
)
if len(picks) % 2 != 0:
picks = _throw_or_return_empty(
"NIRS channels not ordered correctly. An even number of NIRS "
f"channels is required. {len(info.ch_names)} channels were"
f"provided",
throw_errors,
)
# Ensure wavelength info exists for waveform data
all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave]
if np.any(np.isnan(all_freqs)):
picks = _throw_or_return_empty(
f"NIRS channels is missing wavelength information in the "
f'info["chs"] structure. The encoded wavelengths are {all_freqs}.',
throw_errors,
)
# Validate the channel naming scheme
for pick in picks:
ch_name_info = use_RE.match(info["chs"][pick]["ch_name"])
if not bool(ch_name_info):
picks = _throw_or_return_empty(
"NIRS channels have specified naming conventions. "
"The provided channel name can not be parsed: "
f"{repr(info.ch_names[pick])}",
throw_errors,
)
break
value = ch_name_info.groups()[2]
if len(picks_wave):
value = value
else: # picks_chroma
if value not in ["hbo", "hbr"]:
picks = _throw_or_return_empty(
"NIRS channels have specified naming conventions."
"Chromophore data must be labeled either hbo or hbr. "
f"The failing channel is {info['chs'][pick]['ch_name']}",
throw_errors,
)
break
# Reorder to be paired (naive sort okay here given validation above)
picks = picks[np.argsort([info["ch_names"][pick] for pick in picks])]
# Validate our paired ordering
for ii, jj in zip(picks[::2], picks[1::2]):
ch1_name = info["chs"][ii]["ch_name"]
ch2_name = info["chs"][jj]["ch_name"]
ch1_re = use_RE.match(ch1_name)
ch2_re = use_RE.match(ch2_name)
ch1_S, ch1_D, ch1_value = ch1_re.groups()[:3]
ch2_S, ch2_D, ch2_value = ch2_re.groups()[:3]
if len(picks_wave):
ch1_value, ch2_value = float(ch1_value), float(ch2_value)
if (
(ch1_S != ch2_S)
or (ch1_D != ch2_D)
or (ch1_value != pair_vals[0])
or (ch2_value != pair_vals[1])
):
picks = _throw_or_return_empty(
"NIRS channels not ordered correctly. Channels must be "
"ordered as source detector pairs with alternating"
f" {error_word} {pair_vals[0]} & {pair_vals[1]}, but got "
f"S{ch1_S}_D{ch1_D} pair "
f"{repr(ch1_name)} and {repr(ch2_name)}",
throw_errors,
)
break
if check_bads:
for ii, jj in zip(picks[::2], picks[1::2]):
want = [info.ch_names[ii], info.ch_names[jj]]
got = list(set(info["bads"]).intersection(want))
if len(got) == 1:
raise RuntimeError(
f"NIRS bad labelling is not consistent, found {got} but "
f"needed {want}"
)
return picks
def _throw_or_return_empty(msg, throw_errors):
if throw_errors:
raise ValueError(msg)
else:
return []
def _validate_nirs_info(
info,
*,
throw_errors=True,
fnirs=None,
which=None,
check_bads=True,
allow_empty=True,
):
"""Apply all checks to fNIRS info. Works on all continuous wave types."""
_validate_type(fnirs, (None, str), "fnirs")
kinds = dict(
od="optical density",
cw_amplitude="continuous wave",
hb="chromophore",
)
_check_option("fnirs", fnirs, (None,) + tuple(kinds))
if fnirs is not None:
kind = kinds[fnirs]
fnirs = ["hbo", "hbr"] if fnirs == "hb" else f"fnirs_{fnirs}"
if not len(pick_types(info, fnirs=fnirs)):
raise RuntimeError(
f"{which} must operate on {kind} data, but none was found."
)
freqs = np.unique(_channel_frequencies(info))
if freqs.size > 0:
pair_vals = freqs
else:
pair_vals = np.unique(_channel_chromophore(info))
out = _check_channels_ordered(
info, pair_vals, throw_errors=throw_errors, check_bads=check_bads
)
return out
def _fnirs_spread_bads(info):
"""Spread bad labeling across fnirs channels."""
# For an optode pair if any component (light frequency or chroma) is marked
# as bad, then they all should be. This function will find any pairs marked
# as bad and spread the bad marking to all components of the optode pair.
picks = _validate_nirs_info(info, check_bads=False)
new_bads = set(info["bads"])
for ii, jj in zip(picks[::2], picks[1::2]):
ch1_name, ch2_name = info.ch_names[ii], info.ch_names[jj]
if ch1_name in new_bads:
new_bads.add(ch2_name)
elif ch2_name in new_bads:
new_bads.add(ch1_name)
info["bads"] = sorted(new_bads)
return info
def _fnirs_optode_names(info):
"""Return list of unique optode names."""
picks_wave = _picks_to_idx(
info, ["fnirs_cw_amplitude", "fnirs_od"], exclude=[], allow_empty=True
)
picks_chroma = _picks_to_idx(info, ["hbo", "hbr"], exclude=[], allow_empty=True)
if len(picks_wave) > 0:
regex = _S_D_F_RE
elif len(picks_chroma) > 0:
regex = _S_D_H_RE
else:
return [], []
sources = np.unique([int(regex.match(ch).groups()[0]) for ch in info.ch_names])
detectors = np.unique([int(regex.match(ch).groups()[1]) for ch in info.ch_names])
src_names = [f"S{s}" for s in sources]
det_names = [f"D{d}" for d in detectors]
return src_names, det_names
def _optode_position(info, optode):
"""Find the position of an optode."""
idx = [optode in a for a in info.ch_names].index(True)
if "S" in optode:
loc_idx = range(3, 6)
elif "D" in optode:
loc_idx = range(6, 9)
return info["chs"][idx]["loc"][loc_idx]
def _reorder_nirx(raw):
# Maybe someday we should make this public like
# mne.preprocessing.nirs.reorder_standard(raw, order='nirx')
info = raw.info
picks = pick_types(info, fnirs=True, exclude=[])
prefixes = [info["ch_names"][pick].split()[0] for pick in picks]
nirs_names = [info["ch_names"][pick] for pick in picks]
nirs_sorted = sorted(
nirs_names,
key=lambda name: (prefixes.index(name.split()[0]), name.split(maxsplit=1)[1]),
)
raw.reorder_channels(nirs_sorted)
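
A short sketch of the two public helpers above; ``info`` is a hypothetical measurement info with standard fNIRS channel names and a montage set, and the import path assumes the module is exposed as ``mne.preprocessing.nirs``:

from mne.preprocessing.nirs import short_channels, source_detector_distances

dists = source_detector_distances(info)            # meters, one value per channel
short_mask = short_channels(info, threshold=0.01)  # True for short-separation channels
print(f"{short_mask.sum()} of {dists.size} channels are short-separation")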

140
mne/preprocessing/otp.py Normal file
View File

@@ -0,0 +1,140 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from functools import partial
import numpy as np
from .._fiff.pick import _picks_to_idx
from .._ola import _COLA, _Storer
from ..surface import _normalize_vectors
from ..utils import logger, verbose
def _svd_cov(cov, data):
"""Use a covariance matrix to compute the SVD faster."""
# This makes use of mathematical equivalences between PCA and SVD
# on zero-mean data
s, u = np.linalg.eigh(cov)
norm = np.ones((s.size,))
mask = s > np.finfo(float).eps * s[-1] # largest is last
s = np.sqrt(s, out=s)
norm[mask] = 1.0 / s[mask]
u *= norm
v = np.dot(u.T[mask], data)
return u, s, v
@verbose
def oversampled_temporal_projection(raw, duration=10.0, picks=None, verbose=None):
"""Denoise MEG channels using leave-one-out temporal projection.
Parameters
----------
raw : instance of Raw
Raw data to denoise.
duration : float | str
The window duration (in seconds; default 10.) to use. Can also
be "min" to use as short a window as possible.
%(picks_all_data)s
%(verbose)s
Returns
-------
raw_clean : instance of Raw
The cleaned data.
Notes
-----
This algorithm is computationally expensive, and can be several times
slower than realtime for conventional M/EEG datasets. It uses a
leave-one-out procedure with parallel temporal projection to remove
individual sensor noise under the assumption that sampled fields
(e.g., MEG and EEG) are oversampled by the sensor array
:footcite:`LarsonTaulu2018`.
OTP can improve sensor noise levels (especially under visual
inspection) and repair some bad channels. This noise reduction is known
to interact with :func:`tSSS <mne.preprocessing.maxwell_filter>` such
that increasing the ``st_correlation`` value will likely be necessary.
Channels marked as bad will not be used to reconstruct good channels,
but good channels will be used to process the bad channels. Depending
on the type of noise present in the bad channels, this might make
them usable again.
Use of this algorithm is covered by a provisional patent.
.. versionadded:: 0.16
References
----------
.. footbibliography::
"""
logger.info("Processing MEG data using oversampled temporal projection")
picks = _picks_to_idx(raw.info, picks, exclude=())
picks_good, picks_bad = list(), list() # these are indices into picks
for ii, pi in enumerate(picks):
if raw.ch_names[pi] in raw.info["bads"]:
picks_bad.append(ii)
else:
picks_good.append(ii)
picks_good = np.array(picks_good, int)
picks_bad = np.array(picks_bad, int)
n_samples = int(round(float(duration) * raw.info["sfreq"]))
if n_samples < len(picks_good) - 1:
raise ValueError(
f"duration ({n_samples / raw.info['sfreq']}) yielded {n_samples} samples, "
f"which is fewer than the number of channels -1 ({len(picks_good) - 1})"
)
n_overlap = n_samples // 2
raw_otp = raw.copy().load_data(verbose=False)
otp = _COLA(
partial(_otp, picks_good=picks_good, picks_bad=picks_bad),
_Storer(raw_otp._data, picks=picks),
len(raw.times),
n_samples,
n_overlap,
raw.info["sfreq"],
)
read_lims = list(range(0, len(raw.times), n_samples)) + [len(raw.times)]
for start, stop in zip(read_lims[:-1], read_lims[1:]):
logger.info(
f" Denoising {raw.times[[start, stop - 1]][0]: 8.2f} "
f"{raw.times[[start, stop - 1]][1]: 8.2f} s"
)
otp.feed(raw[picks, start:stop][0])
return raw_otp
def _otp(data, picks_good, picks_bad):
"""Perform OTP on one segment of data."""
if not np.isfinite(data).all():
raise RuntimeError("non-finite data (inf or nan) found in raw instance")
# demean our data
data_means = np.mean(data, axis=-1, keepdims=True)
data -= data_means
# make a copy
data_good = data[picks_good]
# scale the copy that will be used to form the temporal basis vectors
# so that _orth_svdvals thresholding should work properly with
# different channel types (e.g., M-EEG)
norms = _normalize_vectors(data_good)
cov = np.dot(data_good, data_good.T)
if len(picks_bad) > 0:
full_basis = _svd_cov(cov, data_good)[2]
for mi, pick in enumerate(picks_good):
# operate on original data
idx = list(range(mi)) + list(range(mi + 1, len(data_good)))
# Equivalent: svd(data[idx], full_matrices=False)[2]
t_basis = _svd_cov(cov[np.ix_(idx, idx)], data_good[idx])[2]
x = np.dot(np.dot(data_good[mi], t_basis.T), t_basis)
x *= norms[mi]
x += data_means[pick]
data[pick] = x
for pick in picks_bad:
data[pick] = np.dot(np.dot(data[pick], full_basis.T), full_basis)
data[pick] += data_means[pick]
return [data]
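
A usage sketch, assuming ``raw`` is a hypothetical preloaded MEG recording; the import path follows the file location shown above (mne/preprocessing/otp.py):

from mne.preprocessing.otp import oversampled_temporal_projection

raw_clean = oversampled_temporal_projection(raw, duration=10.0)
# duration="min" would use the shortest admissible window (see docstring above)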

View File

@@ -0,0 +1,132 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from numpy.polynomial.polynomial import Polynomial
from scipy.stats import pearsonr
from ..io import BaseRaw
from ..utils import _validate_type, logger, verbose, warn
@verbose
def realign_raw(raw, other, t_raw, t_other, *, verbose=None):
"""Realign two simultaneous recordings.
Due to clock drift, recordings at a given same sample rate made by two
separate devices simultaneously can become out of sync over time. This
function uses event times captured by both acquisition devices to resample
``other`` to match ``raw``.
Parameters
----------
raw : instance of Raw
The first raw instance.
other : instance of Raw
The second raw instance. It will be resampled to match ``raw``.
t_raw : array-like, shape (n_events,)
The times of shared events in ``raw`` relative to ``raw.times[0]`` (0).
Typically these could be events on some TTL channel such as::
find_events(raw)[:, 0] / raw.info["sfreq"] - raw.first_time
t_other : array-like, shape (n_events,)
The times of shared events in ``other`` relative to ``other.times[0]``.
%(verbose)s
Notes
-----
This function operates inplace. It will:
1. Estimate the zero-order (start offset) and first-order (clock drift)
correction.
2. Crop the start of ``raw`` or ``other``, depending on which started
recording first.
3. Resample ``other`` to match ``raw`` based on the clock drift.
4. Realign the onsets and durations in ``other.annotations``.
5. Crop the end of ``raw`` or ``other``, depending on which stopped
recording first (and the clock drift rate).
This function is primarily designed to work on recordings made at the same
sample rate, but it can also operate on recordings made at different
sample rates to resample and deal with clock drift simultaneously.
.. versionadded:: 0.22
"""
_validate_type(raw, BaseRaw, "raw")
_validate_type(other, BaseRaw, "other")
t_raw = np.array(t_raw, float)
t_other = np.array(t_other, float)
if t_raw.ndim != 1 or t_raw.shape != t_other.shape:
raise ValueError(
"t_raw and t_other must be 1D with the same shape, "
f"got shapes {t_raw.shape} and {t_other.shape}"
)
if len(t_raw) < 20:
warn("Fewer than 20 times passed, results may be unreliable")
# 1. Compute correction factors
poly = Polynomial.fit(x=t_other, y=t_raw, deg=1)
converted = poly.convert(domain=(-1, 1))
[zero_ord, first_ord] = converted.coef
logger.info(
f"Zero order coefficient: {zero_ord} \nFirst order coefficient: {first_ord}"
)
r, p = pearsonr(t_other, t_raw)
msg = f"Linear correlation computed as R={r:0.3f} and p={p:0.2e}"
if p > 0.05 or r <= 0:
raise ValueError(msg + ", cannot resample safely")
if p > 1e-6:
warn(msg + ", results may be unreliable")
else:
logger.info(msg)
dr_ms_s = 1000 * abs(1 - first_ord)
logger.info(
f"Drift rate: {1000 * dr_ms_s:0.1f} μs/s "
f"(total drift over {raw.times[-1]:0.1f} s recording: "
f"{raw.times[-1] * dr_ms_s:0.1f} ms)"
)
# 2. Crop start of recordings to match
if zero_ord > 0: # need to crop start of raw to match other
logger.info(f"Cropping {zero_ord:0.3f} s from the start of raw")
raw.crop(zero_ord, None)
t_raw -= zero_ord
elif zero_ord < 0: # need to crop start of other to match raw
t_crop = -zero_ord / first_ord
logger.info(f"Cropping {t_crop:0.3f} s from the start of other")
other.crop(t_crop, None)
t_other -= t_crop
# 3. Resample data using the first-order term
nan_ch_names = [
ch for ch in other.info["ch_names"] if np.isnan(other.get_data(picks=ch)).any()
]
if len(nan_ch_names) > 0: # Issue warning if any channel in other has nan values
warn(
f"Channel(s) {', '.join(nan_ch_names)} in `other` contain NaN values. "
"Resampling these channels will result in the whole channel being NaN. "
"(If realigning eye-tracking data, consider using interpolate_blinks and "
"passing interpolate_gaze=True)"
)
logger.info("Resampling other")
sfreq_new = raw.info["sfreq"] * first_ord
other.load_data().resample(sfreq_new)
with other.info._unlock():
other.info["sfreq"] = raw.info["sfreq"]
# 4. Realign the onsets and durations in other.annotations
# Must happen before end cropping to avoid losing annotations
logger.info("Correcting annotations in other")
other.annotations.onset *= first_ord
other.annotations.duration *= first_ord
# 5. Crop the end of one of the recordings if necessary
delta = raw.times[-1] - other.times[-1]
msg = f"Cropping {abs(delta):0.3f} s from the end of "
if delta > 0:
logger.info(msg + "raw")
raw.crop(0, other.times[-1])
elif delta < 0:
logger.info(msg + "other")
other.crop(0, raw.times[-1])
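
A usage sketch following the event-time recipe given in the docstring; ``raw`` and ``other`` are hypothetical simultaneous recordings sharing a TTL channel, and ``realign_raw`` is assumed to be importable from ``mne.preprocessing``:

from mne import find_events
from mne.preprocessing import realign_raw

t_raw = find_events(raw)[:, 0] / raw.info["sfreq"] - raw.first_time
t_other = find_events(other)[:, 0] / other.info["sfreq"] - other.first_time
realign_raw(raw, other, t_raw, t_other)   # crops/resamples both instances in place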

605
mne/preprocessing/ssp.py Normal file
View File

@@ -0,0 +1,605 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import copy as cp
import numpy as np
from .._fiff.pick import pick_types
from .._fiff.reference import make_eeg_average_ref_proj
from ..epochs import Epochs
from ..proj import compute_proj_epochs, compute_proj_evoked
from ..utils import _validate_type, logger, verbose, warn
from .ecg import find_ecg_events
from .eog import find_eog_events
def _safe_del_key(dict_, key):
"""Aux function.
Use this function when preparing rejection parameters
instead of directly deleting keys.
"""
if key in dict_:
del dict_[key]
def _compute_exg_proj(
mode,
raw,
raw_event,
tmin,
tmax,
n_grad,
n_mag,
n_eeg,
l_freq,
h_freq,
average,
filter_length,
n_jobs,
ch_name,
reject,
flat,
bads,
avg_ref,
no_proj,
event_id,
exg_l_freq,
exg_h_freq,
tstart,
qrs_threshold,
filter_method,
iir_params,
return_drop_log,
copy,
meg,
verbose,
):
"""Compute SSP/PCA projections for ECG or EOG artifacts."""
raw = raw.copy() if copy else raw
del copy
raw.load_data() # we will filter it later
if no_proj:
projs = []
else:
projs = cp.deepcopy(raw.info["projs"])
logger.info(f"Including {len(projs)} SSP projectors from raw file")
if avg_ref:
eeg_proj = make_eeg_average_ref_proj(raw.info)
projs.append(eeg_proj)
if raw_event is None:
raw_event = raw
assert mode in ("ECG", "EOG") # internal function
logger.info(f"Running {mode} SSP computation")
if mode == "ECG":
events, _, _ = find_ecg_events(
raw_event,
ch_name=ch_name,
event_id=event_id,
l_freq=exg_l_freq,
h_freq=exg_h_freq,
tstart=tstart,
qrs_threshold=qrs_threshold,
filter_length=filter_length,
)
else: # mode == 'EOG':
events = find_eog_events(
raw_event,
event_id=event_id,
l_freq=exg_l_freq,
h_freq=exg_h_freq,
filter_length=filter_length,
ch_name=ch_name,
tstart=tstart,
)
# Check to make sure we actually got at least one usable event
if events.shape[0] < 1:
warn(f"No {mode} events found")
return ([], events) + (([],) if return_drop_log else ())
logger.info("Computing projector")
my_info = cp.deepcopy(raw.info)
my_info["bads"] += bads
# Handle rejection parameters
_validate_type(reject, (None, dict), "reject")
_validate_type(flat, (None, dict), "flat")
if reject is not None: # make sure they didn't pass None
reject = reject.copy() # must make a copy or we modify default!
if (
len(
pick_types(
my_info,
meg="grad",
eeg=False,
eog=False,
ref_meg=False,
exclude="bads",
)
)
== 0
):
_safe_del_key(reject, "grad")
if (
len(
pick_types(
my_info,
meg="mag",
eeg=False,
eog=False,
ref_meg=False,
exclude="bads",
)
)
== 0
):
_safe_del_key(reject, "mag")
if (
len(
pick_types(
my_info,
meg=False,
eeg=True,
eog=False,
ref_meg=False,
exclude="bads",
)
)
== 0
):
_safe_del_key(reject, "eeg")
if (
len(
pick_types(
my_info,
meg=False,
eeg=False,
eog=True,
ref_meg=False,
exclude="bads",
)
)
== 0
):
_safe_del_key(reject, "eog")
if flat is not None: # make sure they didn't pass None
flat = flat.copy()
if (
len(
pick_types(
my_info,
meg="grad",
eeg=False,
eog=False,
ref_meg=False,
exclude="bads",
)
)
== 0
):
_safe_del_key(flat, "grad")
if (
len(
pick_types(
my_info,
meg="mag",
eeg=False,
eog=False,
ref_meg=False,
exclude="bads",
)
)
== 0
):
_safe_del_key(flat, "mag")
if (
len(
pick_types(
my_info,
meg=False,
eeg=True,
eog=False,
ref_meg=False,
exclude="bads",
)
)
== 0
):
_safe_del_key(flat, "eeg")
if (
len(
pick_types(
my_info,
meg=False,
eeg=False,
eog=True,
ref_meg=False,
exclude="bads",
)
)
== 0
):
_safe_del_key(flat, "eog")
# exclude bad channels from projection
# keep reference channels if compensation channels are present
ref_meg = len(my_info["comps"]) > 0
picks = pick_types(
my_info, meg=True, eeg=True, eog=True, ecg=True, ref_meg=ref_meg, exclude="bads"
)
raw.filter(
l_freq,
h_freq,
picks=picks,
filter_length=filter_length,
n_jobs=n_jobs,
method=filter_method,
iir_params=iir_params,
l_trans_bandwidth=0.5,
h_trans_bandwidth=0.5,
phase="zero-double",
fir_design="firwin2",
)
epochs = Epochs(
raw,
events,
None,
tmin,
tmax,
baseline=None,
preload=True,
picks=picks,
reject=reject,
flat=flat,
proj=True,
)
drop_log = epochs.drop_log
if epochs.events.shape[0] < 1:
warn("No good epochs found")
return ([], events) + ((drop_log,) if return_drop_log else ())
if average:
evoked = epochs.average()
ev_projs = compute_proj_evoked(
evoked, n_grad=n_grad, n_mag=n_mag, n_eeg=n_eeg, meg=meg
)
else:
ev_projs = compute_proj_epochs(
epochs, n_grad=n_grad, n_mag=n_mag, n_eeg=n_eeg, n_jobs=n_jobs, meg=meg
)
for p in ev_projs:
p["desc"] = mode + "-" + p["desc"]
projs.extend(ev_projs)
logger.info("Done.")
return (projs, events) + ((drop_log,) if return_drop_log else ())
@verbose
def compute_proj_ecg(
raw,
raw_event=None,
tmin=-0.2,
tmax=0.4,
n_grad=2,
n_mag=2,
n_eeg=2,
l_freq=1.0,
h_freq=35.0,
average=True,
filter_length="10s",
n_jobs=None,
ch_name=None,
reject=dict(grad=2000e-13, mag=3000e-15, eeg=50e-6, eog=250e-6), # noqa: B006
flat=None,
bads=(),
avg_ref=False,
no_proj=False,
event_id=999,
ecg_l_freq=5,
ecg_h_freq=35,
tstart=0.0,
qrs_threshold="auto",
filter_method="fir",
iir_params=None,
copy=True,
return_drop_log=False,
meg="separate",
verbose=None,
):
"""Compute SSP (signal-space projection) vectors for ECG artifacts.
%(compute_proj_ecg)s
.. note:: Raw data will be loaded if it hasn't been preloaded already.
Parameters
----------
raw : mne.io.Raw
Raw input file.
raw_event : mne.io.Raw or None
Raw file to use for event detection (if None, raw is used).
tmin : float
Time before event in seconds.
tmax : float
Time after event in seconds.
n_grad : int
Number of SSP vectors for gradiometers.
n_mag : int
Number of SSP vectors for magnetometers.
n_eeg : int
Number of SSP vectors for EEG.
l_freq : float | None
Filter low cut-off frequency for the data channels in Hz.
h_freq : float | None
Filter high cut-off frequency for the data channels in Hz.
average : bool
Compute SSP after averaging. Default is True.
filter_length : str | int | None
Number of taps to use for filtering.
%(n_jobs)s
ch_name : str | None
Channel to use for ECG detection (Required if no ECG found).
reject : dict | None
Epoch rejection configuration (see Epochs).
flat : dict | None
Epoch flat configuration (see Epochs).
bads : list
List with (additional) bad channels.
avg_ref : bool
Add EEG average reference proj.
no_proj : bool
Exclude the SSP projectors currently in the fiff file.
event_id : int
ID to use for events.
ecg_l_freq : float
Low pass frequency applied to the ECG channel for event detection.
ecg_h_freq : float
High pass frequency applied to the ECG channel for event detection.
tstart : float
Start artifact detection after tstart seconds.
qrs_threshold : float | str
Between 0 and 1. qrs detection threshold. Can also be "auto" to
automatically choose the threshold that generates a reasonable
number of heartbeats (40-160 beats / min).
filter_method : str
Method for filtering ('iir' or 'fir').
iir_params : dict | None
Dictionary of parameters to use for IIR filtering.
See mne.filter.construct_iir_filter for details. If iir_params
is None and method="iir", 4th order Butterworth will be used.
copy : bool
If False, filtering raw data is done in place. Defaults to True.
return_drop_log : bool
If True, return the drop log.
.. versionadded:: 0.15
meg : str
Can be ``'separate'`` (default) or ``'combined'`` to compute projectors
for magnetometers and gradiometers separately or jointly.
If ``'combined'``, ``n_mag == n_grad`` is required and the number of
projectors computed for MEG will be ``n_mag``.
.. versionadded:: 0.18
%(verbose)s
Returns
-------
%(projs)s
ecg_events : ndarray
Detected ECG events.
drop_log : list
The drop log, if requested.
See Also
--------
find_ecg_events
create_ecg_epochs
Notes
-----
Filtering is applied to the ECG channel while finding events using
``ecg_l_freq`` and ``ecg_h_freq``, and then to the ``raw`` instance
using ``l_freq`` and ``h_freq`` before creation of the epochs used to
create the projectors.
"""
return _compute_exg_proj(
"ECG",
raw,
raw_event,
tmin,
tmax,
n_grad,
n_mag,
n_eeg,
l_freq,
h_freq,
average,
filter_length,
n_jobs,
ch_name,
reject,
flat,
bads,
avg_ref,
no_proj,
event_id,
ecg_l_freq,
ecg_h_freq,
tstart,
qrs_threshold,
filter_method,
iir_params,
return_drop_log,
copy,
meg,
verbose,
)
@verbose
def compute_proj_eog(
raw,
raw_event=None,
tmin=-0.2,
tmax=0.2,
n_grad=2,
n_mag=2,
n_eeg=2,
l_freq=1.0,
h_freq=35.0,
average=True,
filter_length="10s",
n_jobs=None,
reject=dict(grad=2000e-13, mag=3000e-15, eeg=500e-6, eog=np.inf), # noqa: B006
flat=None,
bads=(),
avg_ref=False,
no_proj=False,
event_id=998,
eog_l_freq=1,
eog_h_freq=10,
tstart=0.0,
filter_method="fir",
iir_params=None,
ch_name=None,
copy=True,
return_drop_log=False,
meg="separate",
verbose=None,
):
"""Compute SSP (signal-space projection) vectors for EOG artifacts.
%(compute_proj_eog)s
.. note:: Raw data will be loaded if it hasn't been preloaded already.
Parameters
----------
raw : mne.io.Raw
Raw input file.
raw_event : mne.io.Raw or None
Raw file to use for event detection (if None, raw is used).
tmin : float
Time before event in seconds.
tmax : float
Time after event in seconds.
n_grad : int
Number of SSP vectors for gradiometers.
n_mag : int
Number of SSP vectors for magnetometers.
n_eeg : int
Number of SSP vectors for EEG.
l_freq : float | None
Filter low cut-off frequency for the data channels in Hz.
h_freq : float | None
Filter high cut-off frequency for the data channels in Hz.
average : bool
Compute SSP after averaging. Default is True.
filter_length : str | int | None
Number of taps to use for filtering.
%(n_jobs)s
reject : dict | None
Epoch rejection configuration (see Epochs).
flat : dict | None
Epoch flat configuration (see Epochs).
bads : list
List with (additional) bad channels.
avg_ref : bool
Add EEG average reference proj.
no_proj : bool
Exclude the SSP projectors currently in the fiff file.
event_id : int
ID to use for events.
eog_l_freq : float
Low pass frequency applied to the EOG channel for event detection.
eog_h_freq : float
High pass frequency applied to the EOG channel for event detection.
tstart : float
Start artifact detection after tstart seconds.
filter_method : str
Method for filtering ('iir' or 'fir').
iir_params : dict | None
Dictionary of parameters to use for IIR filtering.
See mne.filter.construct_iir_filter for details. If iir_params
is None and method="iir", 4th order Butterworth will be used.
ch_name : str | None
If not None, specify EOG channel name.
copy : bool
If False, filtering raw data is done in place. Defaults to True.
return_drop_log : bool
If True, return the drop log.
.. versionadded:: 0.15
meg : str
Can be 'separate' (default) or 'combined' to compute projectors
for magnetometers and gradiometers separately or jointly.
If 'combined', ``n_mag == n_grad`` is required and the number of
projectors computed for MEG will be ``n_mag``.
.. versionadded:: 0.18
%(verbose)s
Returns
-------
%(projs)s
eog_events : ndarray
Detected EOG events.
drop_log : list
The drop log, if requested.
See Also
--------
find_eog_events
create_eog_epochs
Notes
-----
Filtering is applied to the EOG channel while finding events using
``eog_l_freq`` and ``eog_h_freq``, and then to the ``raw`` instance
using ``l_freq`` and ``h_freq`` before creation of the epochs used to
create the projectors.
"""
return _compute_exg_proj(
"EOG",
raw,
raw_event,
tmin,
tmax,
n_grad,
n_mag,
n_eeg,
l_freq,
h_freq,
average,
filter_length,
n_jobs,
ch_name,
reject,
flat,
bads,
avg_ref,
no_proj,
event_id,
eog_l_freq,
eog_h_freq,
tstart,
"auto",
filter_method,
iir_params,
return_drop_log,
copy,
meg,
verbose,
)
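
A usage sketch for the two public entry points, assuming a hypothetical preloaded ``raw`` with MEG, EEG, ECG, and EOG channels; the import path follows the file location shown above (mne/preprocessing/ssp.py):

from mne.preprocessing.ssp import compute_proj_ecg, compute_proj_eog

projs_ecg, ecg_events = compute_proj_ecg(raw, n_grad=1, n_mag=1, n_eeg=0)
projs_eog, eog_events = compute_proj_eog(raw, n_grad=1, n_mag=1, n_eeg=1)
raw.add_proj(projs_ecg + projs_eog)   # attach the new projectors to the recording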

176
mne/preprocessing/stim.py Normal file
View File

@@ -0,0 +1,176 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from scipy.interpolate import interp1d
from scipy.signal.windows import hann
from .._fiff.pick import _picks_to_idx
from ..epochs import BaseEpochs
from ..event import find_events
from ..evoked import Evoked
from ..io import BaseRaw
from ..utils import _check_option, _check_preload, _validate_type, fill_doc
def _get_window(start, end):
"""Return window which has length as much as parameter start - end."""
window = 1 - np.r_[hann(4)[:2], np.ones(np.abs(end - start) - 4), hann(4)[-2:]].T
return window
def _fix_artifact(
data, window, picks, first_samp, last_samp, base_tmin, base_tmax, mode
):
"""Modify original data by using parameter data."""
if mode == "linear":
x = np.array([first_samp, last_samp])
f = interp1d(x, data[:, (first_samp, last_samp)][picks])
xnew = np.arange(first_samp, last_samp)
interp_data = f(xnew)
data[picks, first_samp:last_samp] = interp_data
if mode == "window":
data[picks, first_samp:last_samp] = (
data[picks, first_samp:last_samp] * window[np.newaxis, :]
)
if mode == "constant":
data[picks, first_samp:last_samp] = data[picks, base_tmin:base_tmax].mean(
axis=1
)[:, None]
@fill_doc
def fix_stim_artifact(
inst,
events=None,
event_id=None,
tmin=0.0,
tmax=0.01,
*,
baseline=None,
mode="linear",
stim_channel=None,
picks=None,
):
"""Eliminate stimulation's artifacts from instance.
.. note:: This function operates in-place, consider passing
``inst.copy()`` if this is not desired.
Parameters
----------
inst : instance of Raw or Epochs or Evoked
The data.
events : array, shape (n_events, 3)
The list of events. Required only when inst is Raw.
event_id : int
The id of the events generating the stimulation artifacts.
If None, read all events. Required only when inst is Raw.
tmin : float
Start time of the interpolation window in seconds.
tmax : float
End time of the interpolation window in seconds.
baseline : None | tuple, shape (2,)
The baseline to use when ``mode='constant'``, in which case it
must be non-None.
.. versionadded:: 1.8
mode : 'linear' | 'window' | 'constant'
Way to fill the artifacted time interval.
``"linear"``
Does linear interpolation.
``"window"``
Applies a ``(1 - hanning)`` window.
``"constant"``
Uses the baseline average; the ``baseline`` parameter must be provided.
.. versionchanged:: 1.8
Added the ``"constant"`` mode.
stim_channel : str | None
Stim channel to use.
%(picks_all_data)s
Returns
-------
inst : instance of Raw or Evoked or Epochs
Instance with modified data.
"""
_check_option("mode", mode, ["linear", "window", "constant"])
s_start = int(np.ceil(inst.info["sfreq"] * tmin))
s_end = int(np.ceil(inst.info["sfreq"] * tmax))
if mode == "constant":
_validate_type(
baseline, (tuple, list), "baseline", extra="when mode='constant'"
)
_check_option("len(baseline)", len(baseline), [2])
for bi, b in enumerate(baseline):
_validate_type(
b, "numeric", f"baseline[{bi}]", extra="when mode='constant'"
)
b_start = int(np.ceil(inst.info["sfreq"] * baseline[0]))
b_end = int(np.ceil(inst.info["sfreq"] * baseline[1]))
else:
b_start = b_end = np.nan
if (mode == "window") and (s_end - s_start) < 4:
raise ValueError(
'Time range is too short. Use a larger interval or set mode to "linear".'
)
window = None
if mode == "window":
window = _get_window(s_start, s_end)
picks = _picks_to_idx(inst.info, picks, "data", exclude=())
_check_preload(inst, "fix_stim_artifact")
if isinstance(inst, BaseRaw):
if events is None:
events = find_events(inst, stim_channel=stim_channel)
if len(events) == 0:
raise ValueError("No events are found")
if event_id is None:
events_sel = np.arange(len(events))
else:
events_sel = events[:, 2] == event_id
event_start = events[events_sel, 0]
data = inst._data
for event_idx in event_start:
first_samp = int(event_idx) - inst.first_samp + s_start
last_samp = int(event_idx) - inst.first_samp + s_end
base_t1 = int(event_idx) - inst.first_samp + b_start
base_t2 = int(event_idx) - inst.first_samp + b_end
_fix_artifact(
data, window, picks, first_samp, last_samp, base_t1, base_t2, mode
)
elif isinstance(inst, BaseEpochs):
if inst.reject is not None:
raise RuntimeError(
"Reject is already applied. Use reject=None in the constructor."
)
e_start = int(np.ceil(inst.info["sfreq"] * inst.tmin))
first_samp = s_start - e_start
last_samp = s_end - e_start
data = inst._data
base_t1 = b_start - e_start
base_t2 = b_end - e_start
for epoch in data:
_fix_artifact(
epoch, window, picks, first_samp, last_samp, base_t1, base_t2, mode
)
elif isinstance(inst, Evoked):
first_samp = s_start - inst.first
last_samp = s_end - inst.first
data = inst.data
base_t1 = b_start - inst.first
base_t2 = b_end - inst.first
_fix_artifact(
data, window, picks, first_samp, last_samp, base_t1, base_t2, mode
)
else:
raise TypeError(f"Not a Raw or Epochs or Evoked (got {type(inst)}).")
return inst
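
A usage sketch; ``raw`` is a hypothetical preloaded recording with a stim channel, and the event id is an assumption. The import path follows the file location shown above (mne/preprocessing/stim.py):

from mne import find_events
from mne.preprocessing.stim import fix_stim_artifact

events = find_events(raw)   # events from the default stim channel
raw_fixed = fix_stim_artifact(
    raw.copy(), events=events, event_id=1, tmin=-0.002, tmax=0.02, mode="linear"
)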

682
mne/preprocessing/xdawn.py Normal file
View File

@@ -0,0 +1,682 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from scipy import linalg
from .._fiff.pick import _pick_data_channels, pick_info
from ..cov import Covariance, _regularized_covariance
from ..decoding import BaseEstimator, TransformerMixin
from ..epochs import BaseEpochs
from ..evoked import Evoked, EvokedArray
from ..io import BaseRaw
from ..utils import _check_option, logger, pinv
def _construct_signal_from_epochs(epochs, events, sfreq, tmin):
"""Reconstruct pseudo continuous signal from epochs."""
n_epochs, n_channels, n_times = epochs.shape
tmax = tmin + n_times / float(sfreq)
start = np.min(events[:, 0]) + int(tmin * sfreq)
stop = np.max(events[:, 0]) + int(tmax * sfreq) + 1
n_samples = stop - start
n_epochs, n_channels, n_times = epochs.shape
events_pos = events[:, 0] - events[0, 0]
raw = np.zeros((n_channels, n_samples))
for idx in range(n_epochs):
onset = events_pos[idx]
offset = onset + n_times
raw[:, onset:offset] = epochs[idx]
return raw
def _least_square_evoked(epochs_data, events, tmin, sfreq):
"""Least square estimation of evoked response from epochs data.
Parameters
----------
epochs_data : array, shape (n_epochs, n_channels, n_times)
The epochs data used to estimate the evoked responses.
events : array, shape (n_events, 3)
The events typically returned by the read_events function.
If some events don't match the events of interest as specified
by event_id, they will be ignored.
tmin : float
Start time before event.
sfreq : float
Sampling frequency.
Returns
-------
evokeds : array, shape (n_class, n_channels, n_times)
A concatenated array of evoked data for each event type.
toeplitz : array, shape (n_class, n_times, n_samples)
A concatenated array of Toeplitz matrices, one for each event type.
"""
n_epochs, n_channels, n_times = epochs_data.shape
tmax = tmin + n_times / float(sfreq)
# Deal with shuffled epochs
events = events.copy()
events[:, 0] -= events[0, 0] + int(tmin * sfreq)
# Construct raw signal
raw = _construct_signal_from_epochs(epochs_data, events, sfreq, tmin)
# Compute the independent evoked responses per condition, while correcting
# for event overlaps.
n_min, n_max = int(tmin * sfreq), int(tmax * sfreq)
window = n_max - n_min
n_samples = raw.shape[1]
toeplitz = list()
classes = np.unique(events[:, 2])
for ii, this_class in enumerate(classes):
# select events by type
sel = events[:, 2] == this_class
# build toeplitz matrix
trig = np.zeros((n_samples,))
ix_trig = (events[sel, 0]) + n_min
trig[ix_trig] = 1
toeplitz.append(linalg.toeplitz(trig[0:window], trig))
# Concatenate toeplitz
toeplitz = np.array(toeplitz)
X = np.concatenate(toeplitz)
# least square estimation
predictor = np.dot(pinv(np.dot(X, X.T)), X)
evokeds = np.dot(predictor, raw.T)
evokeds = np.transpose(np.vsplit(evokeds, len(classes)), (0, 2, 1))
return evokeds, toeplitz
def _fit_xdawn(
epochs_data,
y,
n_components,
reg=None,
signal_cov=None,
events=None,
tmin=0.0,
sfreq=1.0,
method_params=None,
info=None,
):
"""Fit filters and coefs using Xdawn Algorithm.
Xdawn is a spatial filtering method designed to improve the signal
to signal + noise ratio (SSNR) of the event related responses. Xdawn was
originally designed for P300 evoked potential by enhancing the target
response with respect to the non-target response. This implementation is a
generalization to any type of event related response.
Parameters
----------
epochs_data : array, shape (n_epochs, n_channels, n_times)
The epochs data.
y : array, shape (n_epochs)
The epochs class.
n_components : int (default 2)
The number of components to decompose the signals.
reg : float | str | None (default None)
If not None (same as ``'empirical'``, default), allow
regularization for covariance estimation.
If float, shrinkage is used (0 <= shrinkage <= 1).
For str options, ``reg`` will be passed as ``method`` to
:func:`mne.compute_covariance`.
signal_cov : None | Covariance | array, shape (n_channels, n_channels)
The signal covariance used for whitening of the data.
if None, the covariance is estimated from the epochs signal.
events : array, shape (n_epochs, 3)
The epochs events, used to correct for epochs overlap.
tmin : float
Epochs starting time. Only used if events is passed to correct for
epochs overlap.
sfreq : float
Sampling frequency. Only used if events is passed to correct for
epochs overlap.
Returns
-------
filters : array, shape (n_channels, n_channels)
The Xdawn components used to decompose the data for each event type.
Each row corresponds to one component.
patterns : array, shape (n_channels, n_channels)
The Xdawn patterns used to restore the signals for each event type.
evokeds : array, shape (n_class, n_channels, n_times)
The independent evoked responses per condition.
"""
if not isinstance(epochs_data, np.ndarray) or epochs_data.ndim != 3:
raise ValueError("epochs_data must be 3D ndarray")
classes = np.unique(y)
# XXX Eventually this could be made to deal with rank deficiency properly
# by exposing this "rank" parameter, but this will require refactoring
# the linalg.eigh call to operate in the lower-dimension
# subspace, then project back out.
# Retrieve or compute whitening covariance
if signal_cov is None:
signal_cov = _regularized_covariance(
np.hstack(epochs_data), reg, method_params, info, rank="full"
)
elif isinstance(signal_cov, Covariance):
signal_cov = signal_cov.data
if not isinstance(signal_cov, np.ndarray) or (
not np.array_equal(signal_cov.shape, np.tile(epochs_data.shape[1], 2))
):
raise ValueError(
"signal_cov must be None, a covariance instance, "
"or an array of shape (n_chans, n_chans)"
)
# Get prototype events
if events is not None:
evokeds, toeplitzs = _least_square_evoked(epochs_data, events, tmin, sfreq)
else:
evokeds, toeplitzs = list(), list()
for c in classes:
# Prototyped response for each class
evokeds.append(np.mean(epochs_data[y == c, :, :], axis=0))
toeplitzs.append(1.0)
filters = list()
patterns = list()
for evo, toeplitz in zip(evokeds, toeplitzs):
# Estimate covariance matrix of the prototype response
evo = np.dot(evo, toeplitz)
evo_cov = _regularized_covariance(evo, reg, method_params, info, rank="full")
# Fit spatial filters
try:
evals, evecs = linalg.eigh(evo_cov, signal_cov)
except np.linalg.LinAlgError as exp:
raise ValueError(
"Could not compute eigenvalues, ensure "
f"proper regularization ({exp})"
)
evecs = evecs[:, np.argsort(evals)[::-1]] # sort eigenvectors
evecs /= np.apply_along_axis(np.linalg.norm, 0, evecs)
_patterns = np.linalg.pinv(evecs.T)
filters.append(evecs[:, :n_components].T)
patterns.append(_patterns[:, :n_components].T)
filters = np.concatenate(filters, axis=0)
patterns = np.concatenate(patterns, axis=0)
evokeds = np.array(evokeds)
return filters, patterns, evokeds
class _XdawnTransformer(BaseEstimator, TransformerMixin):
"""Implementation of the Xdawn Algorithm compatible with scikit-learn.
Xdawn is a spatial filtering method designed to improve the signal
to signal + noise ratio (SSNR) of the event related responses. Xdawn was
originally designed for P300 evoked potential by enhancing the target
response with respect to the non-target response. This implementation is a
generalization to any type of event related response.
.. note:: _XdawnTransformer does not correct for epochs overlap. To correct
overlaps see ``Xdawn``.
Parameters
----------
n_components : int (default 2)
The number of components to decompose the signals.
reg : float | str | None (default None)
If not None (same as ``'empirical'``, default), allow
regularization for covariance estimation.
If float, shrinkage is used (0 <= shrinkage <= 1).
For str options, ``reg`` will be passed as ``method`` to
:func:`mne.compute_covariance`.
signal_cov : None | Covariance | array, shape (n_channels, n_channels)
The signal covariance used for whitening of the data.
if None, the covariance is estimated from the epochs signal.
method_params : dict | None
Parameters to pass to :func:`mne.compute_covariance`.
.. versionadded:: 0.16
Attributes
----------
classes_ : array, shape (n_classes)
The event indices of the classes.
filters_ : array, shape (n_channels, n_channels)
The Xdawn components used to decompose the data for each event type.
patterns_ : array, shape (n_channels, n_channels)
The Xdawn patterns used to restore the signals for each event type.
"""
def __init__(self, n_components=2, reg=None, signal_cov=None, method_params=None):
"""Init."""
self.n_components = n_components
self.signal_cov = signal_cov
self.reg = reg
self.method_params = method_params
def fit(self, X, y=None):
"""Fit Xdawn spatial filters.
Parameters
----------
X : array, shape (n_epochs, n_channels, n_samples)
The target data.
y : array, shape (n_epochs,) | None
The target labels. If None, Xdawn is fit on the average evoked response.
Returns
-------
self : Xdawn instance
The Xdawn instance.
"""
X, y = self._check_Xy(X, y)
# Main function
self.classes_ = np.unique(y)
self.filters_, self.patterns_, _ = _fit_xdawn(
X,
y,
n_components=self.n_components,
reg=self.reg,
signal_cov=self.signal_cov,
method_params=self.method_params,
)
return self
def transform(self, X):
"""Transform data with spatial filters.
Parameters
----------
X : array, shape (n_epochs, n_channels, n_samples)
The target data.
Returns
-------
X : array, shape (n_epochs, n_components * n_classes, n_samples)
The transformed data.
"""
X, _ = self._check_Xy(X)
# Check size
if self.filters_.shape[1] != X.shape[1]:
raise ValueError(
f"X must have {self.filters_.shape[1]} channels, got {X.shape[1]} "
"instead."
)
# Transform
X = np.dot(self.filters_, X)
X = X.transpose((1, 0, 2))
return X
def inverse_transform(self, X):
"""Remove selected components from the signal.
Given the unmixing matrix, transform data, zero out components,
and inverse transform the data. This procedure will reconstruct
the signals from which the dynamics described by the excluded
components is subtracted.
Parameters
----------
X : array, shape (n_epochs, n_components * n_classes, n_times)
The transformed data.
Returns
-------
X : array, shape (n_epochs, n_channels, n_times)
The inverse transformed data.
"""
# Check size
X, _ = self._check_Xy(X)
n_epochs, n_comp, n_times = X.shape
if n_comp != (self.n_components * len(self.classes_)):
raise ValueError(
f"X must have {self.n_components * len(self.classes_)} components, "
f"got {n_comp} instead."
)
# Transform
return np.dot(self.patterns_.T, X).transpose(1, 0, 2)
def _check_Xy(self, X, y=None):
"""Check X and y types and dimensions."""
# Check data
if not isinstance(X, np.ndarray) or X.ndim != 3:
raise ValueError(
"X must be an array of shape (n_epochs, n_channels, n_samples)."
)
if y is None:
y = np.ones(len(X))
y = np.asarray(y)
if len(X) != len(y):
raise ValueError("X and y must have the same length")
return X, y
class Xdawn(_XdawnTransformer):
"""Implementation of the Xdawn Algorithm.
Xdawn :footcite:`RivetEtAl2009,RivetEtAl2011` is a spatial
filtering method designed to improve the signal to signal + noise
ratio (SSNR) of the ERP responses. Xdawn was originally designed for
P300 evoked potential by enhancing the target response with respect
to the non-target response. This implementation is a generalization
to any type of ERP.
Parameters
----------
n_components : int, (default 2)
The number of components to decompose the signals.
signal_cov : None | Covariance | ndarray, shape (n_channels, n_channels)
(default None). The signal covariance used for whitening of the data.
if None, the covariance is estimated from the epochs signal.
correct_overlap : 'auto' or bool (default 'auto')
Compute the independent evoked responses per condition, while
correcting for event overlaps if any. If 'auto', overlap
correction is enabled only if the events actually overlap.
reg : float | str | None (default None)
If not None (same as ``'empirical'``, default), allow
regularization for covariance estimation.
If float, shrinkage is used (0 <= shrinkage <= 1).
For str options, ``reg`` will be passed as ``method`` to
:func:`mne.compute_covariance`.
Attributes
----------
filters_ : dict of ndarray
If fit, the Xdawn components used to decompose the data for each event
type, else empty. For each event type, the filters are in the rows of
the corresponding array.
patterns_ : dict of ndarray
If fit, the Xdawn patterns used to restore the signals for each event
type, else empty.
evokeds_ : dict of Evoked
If fit, the evoked response for each event type.
event_id_ : dict
The event id.
correct_overlap_ : bool
Whether overlap correction was applied.
See Also
--------
mne.decoding.CSP, mne.decoding.SPoC
Notes
-----
.. versionadded:: 0.10
References
----------
.. footbibliography::
"""
def __init__(
self, n_components=2, signal_cov=None, correct_overlap="auto", reg=None
):
"""Init."""
super().__init__(n_components=n_components, signal_cov=signal_cov, reg=reg)
self.correct_overlap = _check_option(
"correct_overlap", correct_overlap, ["auto", True, False]
)
def fit(self, epochs, y=None):
"""Fit Xdawn from epochs.
Parameters
----------
epochs : instance of Epochs
An instance of Epoch on which Xdawn filters will be fitted.
y : ndarray | None (default None)
If None, ``epochs.events[:, 2]`` is used.
Returns
-------
self : instance of Xdawn
The Xdawn instance.
"""
# Check data
if not isinstance(epochs, BaseEpochs):
raise ValueError("epochs must be an Epochs object.")
picks = _pick_data_channels(epochs.info)
use_info = pick_info(epochs.info, picks)
X = epochs.get_data(picks)
y = epochs.events[:, 2] if y is None else y
self.event_id_ = epochs.event_id
# Check that no baseline was applied with correct overlap
correct_overlap = self.correct_overlap
if correct_overlap == "auto":
# Events are overlapped if the minimal inter-stimulus
# interval is smaller than the time window.
isi = np.diff(np.sort(epochs.events[:, 0]))
window = int((epochs.tmax - epochs.tmin) * epochs.info["sfreq"])
correct_overlap = isi.min() < window
if epochs.baseline and correct_overlap:
raise ValueError("Cannot apply correct_overlap if epochs were baselined.")
events, tmin, sfreq = None, 0.0, 1.0
if correct_overlap:
events = epochs.events
tmin = epochs.tmin
sfreq = epochs.info["sfreq"]
self.correct_overlap_ = correct_overlap
# Note: In this original version of Xdawn we compute and keep all
# components. The selection comes at transform().
n_components = X.shape[1]
# Main fitting function
filters, patterns, evokeds = _fit_xdawn(
X,
y,
n_components=n_components,
reg=self.reg,
signal_cov=self.signal_cov,
events=events,
tmin=tmin,
sfreq=sfreq,
method_params=self.method_params,
info=use_info,
)
# Re-order filters and patterns according to event_id
filters = filters.reshape(-1, n_components, filters.shape[-1])
patterns = patterns.reshape(-1, n_components, patterns.shape[-1])
self.filters_, self.patterns_, self.evokeds_ = dict(), dict(), dict()
idx = np.argsort([value for _, value in epochs.event_id.items()])
for eid, this_filter, this_pattern, this_evo in zip(
epochs.event_id, filters[idx], patterns[idx], evokeds[idx]
):
self.filters_[eid] = this_filter
self.patterns_[eid] = this_pattern
n_events = len(epochs[eid])
evoked = EvokedArray(
this_evo, use_info, tmin=epochs.tmin, comment=eid, nave=n_events
)
self.evokeds_[eid] = evoked
return self
def transform(self, inst):
"""Apply Xdawn dim reduction.
Parameters
----------
inst : Epochs | Evoked | ndarray, shape ([n_epochs, ]n_channels, n_times)
Data on which Xdawn filters will be applied.
Returns
-------
X : ndarray, shape ([n_epochs, ]n_components * n_event_types, n_times)
Spatially filtered signals.
""" # noqa: E501
if isinstance(inst, BaseEpochs):
X = inst.get_data(copy=False)
elif isinstance(inst, Evoked):
X = inst.data
elif isinstance(inst, np.ndarray):
X = inst
if X.ndim not in (2, 3):
raise ValueError(f"X must be 2D or 3D, got {X.ndim}")
else:
raise ValueError("Data input must be of Epoch type or numpy array")
filters = [filt[: self.n_components] for filt in self.filters_.values()]
filters = np.concatenate(filters, axis=0)
X = np.dot(filters, X)
if X.ndim == 3:
X = X.transpose((1, 0, 2))
return X

    def apply(self, inst, event_id=None, include=None, exclude=None):
        """Remove selected components from the signal.

        Given the unmixing matrix, the data are transformed to component
        space, the selected components are zeroed out, and the result is
        projected back to sensor space. This reconstructs the signals with
        the contribution of the excluded components removed.

        Parameters
        ----------
        inst : instance of Raw | Epochs | Evoked
            The data to be processed.
        event_id : dict | list of str | None (default None)
            The event type(s) to reconstruct. If None, a dict of instances is
            returned, one for each event type on which Xdawn has been fitted.
        include : array_like of int | None (default None)
            The indices referring to columns in the unmixing matrix. The
            components to be kept. If None, the first n_components (as defined
            in the Xdawn constructor) will be kept.
        exclude : array_like of int | None (default None)
            The indices referring to columns in the unmixing matrix. The
            components to be zeroed out. If None, all components except the
            first n_components will be excluded.

        Returns
        -------
        out : dict
            A dict of instances (of the same type as the inst input), one for
            each event type in event_id.
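
        Examples
        --------
        A minimal sketch, assuming ``xd`` is an already-fitted Xdawn instance
        and ``epochs`` is the Epochs object it was fitted on (both names, and
        the ``"target"`` condition, are illustrative):

        >>> cleaned = xd.apply(epochs)  # doctest: +SKIP
        >>> cleaned["target"].average().plot()  # doctest: +SKIP

        Each entry of ``cleaned`` is a copy of ``epochs`` reconstructed from
        the first ``n_components`` Xdawn components of that event type.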
"""
if event_id is None:
event_id = self.event_id_
if not isinstance(inst, BaseRaw | BaseEpochs | Evoked):
raise ValueError("Data input must be Raw, Epochs or Evoked type")
picks = _pick_data_channels(inst.info)
# Define the components to keep
default_exclude = list(range(self.n_components, len(inst.ch_names)))
if exclude is None:
exclude = default_exclude
else:
exclude = list(set(list(default_exclude) + list(exclude)))
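        # By default, all components beyond the first ``n_components`` are
        # marked for exclusion; user-supplied ``exclude`` indices are merged
        # into that default set (``include``, if given, takes precedence in
        # ``_pick_sources``).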
if isinstance(inst, BaseRaw):
out = self._apply_raw(
raw=inst,
include=include,
exclude=exclude,
event_id=event_id,
picks=picks,
)
elif isinstance(inst, BaseEpochs):
out = self._apply_epochs(
epochs=inst,
include=include,
picks=picks,
exclude=exclude,
event_id=event_id,
)
elif isinstance(inst, Evoked):
out = self._apply_evoked(
evoked=inst,
include=include,
picks=picks,
exclude=exclude,
event_id=event_id,
)
return out

    def _apply_raw(self, raw, include, exclude, event_id, picks):
"""Aux method."""
if not raw.preload:
raise ValueError("Raw data must be preloaded to apply Xdawn")
raws = dict()
for eid in event_id:
data = raw[picks, :][0]
data = self._pick_sources(data, include, exclude, eid)
raw_r = raw.copy()
raw_r[picks, :] = data
raws[eid] = raw_r
return raws

    def _apply_epochs(self, epochs, include, exclude, event_id, picks):
"""Aux method."""
if not epochs.preload:
raise ValueError("Epochs must be preloaded to apply Xdawn")
# special case where epochs come picked but fit was 'unpicked'.
epochs_dict = dict()
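        # All epochs are concatenated along time, filtered in a single pass,
        # and then split back into per-epoch arrays below.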
data = np.hstack(epochs.get_data(picks))
for eid in event_id:
data_r = self._pick_sources(data, include, exclude, eid)
data_r = np.array(np.split(data_r, len(epochs.events), 1))
epochs_r = epochs.copy().load_data()
epochs_r._data[:, picks, :] = data_r
epochs_dict[eid] = epochs_r
return epochs_dict

    def _apply_evoked(self, evoked, include, exclude, event_id, picks):
"""Aux method."""
data = evoked.data[picks]
evokeds = dict()
for eid in event_id:
data_r = self._pick_sources(data, include, exclude, eid)
evokeds[eid] = evoked.copy()
# restore evoked
evokeds[eid].data[picks] = data_r
return evokeds

    def _pick_sources(self, data, include, exclude, eid):
"""Aux method."""
logger.info("Transforming to Xdawn space")
# Apply unmixing
sources = np.dot(self.filters_[eid], data)
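        # ``include`` takes precedence over ``exclude``: when given, every
        # component not listed in ``include`` is zeroed out.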
if include not in (None, list()):
mask = np.ones(len(sources), dtype=bool)
mask[np.unique(include)] = False
sources[mask] = 0.0
logger.info(f"Zeroing out {int(mask.sum())} Xdawn components")
elif exclude not in (None, list()):
exclude_ = np.unique(exclude)
sources[exclude_] = 0.0
logger.info(f"Zeroing out {len(exclude_)} Xdawn components")
logger.info("Inverse transforming to sensor space")
data = np.dot(self.patterns_[eid].T, sources)
return data

    def inverse_transform(self):
"""Not implemented, see Xdawn.apply() instead."""
# Exists because of _XdawnTransformer
raise NotImplementedError("See Xdawn.apply()")