initial commit
This commit is contained in:
231
mne/preprocessing/interpolate.py
Normal file
231
mne/preprocessing/interpolate.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Tools for data interpolation."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from itertools import chain
|
||||
|
||||
import numpy as np
|
||||
from scipy.sparse.csgraph import connected_components
|
||||
|
||||
from .._fiff.meas_info import create_info
|
||||
from ..epochs import BaseEpochs, EpochsArray
|
||||
from ..evoked import Evoked, EvokedArray
|
||||
from ..io import BaseRaw, RawArray
|
||||
from ..transforms import _cart_to_sph, _sph_to_cart
|
||||
from ..utils import _ensure_int, _validate_type
|
||||
|
||||
|
||||
def equalize_bads(insts, interp_thresh=1.0, copy=True):
|
||||
"""Interpolate or mark bads consistently for a list of instances.
|
||||
|
||||
Once called on a list of instances, the instances can be concatenated
|
||||
as they will have the same list of bad channels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
insts : list
|
||||
The list of instances (Evoked, Epochs or Raw) to consider
|
||||
for interpolation. Each instance should have marked channels.
|
||||
interp_thresh : float
|
||||
A float between 0 and 1 (default) that specifies the fraction of time
|
||||
a channel should be good to be eventually interpolated for certain
|
||||
instances. For example if 0.5, a channel which is good at least half
|
||||
of the time will be interpolated in the instances where it is marked
|
||||
as bad. If 1 then channels will never be interpolated and if 0 all bad
|
||||
channels will be systematically interpolated.
|
||||
copy : bool
|
||||
If True then the returned instances will be copies.
|
||||
|
||||
Returns
|
||||
-------
|
||||
insts_bads : list
|
||||
The list of instances, with the same channel(s) marked as bad in all of
|
||||
them, possibly with some formerly bad channels interpolated.
|
||||
"""
|
||||
if not 0 <= interp_thresh <= 1:
|
||||
raise ValueError(f"interp_thresh must be between 0 and 1, got {interp_thresh}")
|
||||
|
||||
all_bads = list(set(chain.from_iterable([inst.info["bads"] for inst in insts])))
|
||||
if isinstance(insts[0], BaseEpochs):
|
||||
durations = [len(inst) * len(inst.times) for inst in insts]
|
||||
else:
|
||||
durations = [len(inst.times) for inst in insts]
|
||||
|
||||
good_times = []
|
||||
for ch_name in all_bads:
|
||||
good_times.append(
|
||||
sum(
|
||||
durations[k]
|
||||
for k, inst in enumerate(insts)
|
||||
if ch_name not in inst.info["bads"]
|
||||
)
|
||||
/ np.sum(durations)
|
||||
)
|
||||
|
||||
bads_keep = [ch for k, ch in enumerate(all_bads) if good_times[k] < interp_thresh]
|
||||
if copy:
|
||||
insts = [inst.copy() for inst in insts]
|
||||
|
||||
for inst in insts:
|
||||
if len(set(inst.info["bads"]) - set(bads_keep)):
|
||||
inst.interpolate_bads(exclude=bads_keep)
|
||||
inst.info["bads"] = bads_keep
|
||||
|
||||
return insts
|
||||
|
||||
|
||||
def interpolate_bridged_electrodes(inst, bridged_idx, bad_limit=4):
|
||||
"""Interpolate bridged electrode pairs.
|
||||
|
||||
Because bridged electrodes contain brain signal, it's just that the
|
||||
signal is spatially smeared between the two electrodes, we can
|
||||
make a virtual channel midway between the bridged pairs and use
|
||||
that to aid in interpolation rather than completely discarding the
|
||||
data from the two channels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : instance of Epochs, Evoked, or Raw
|
||||
The data object with channels that are to be interpolated.
|
||||
bridged_idx : list of tuple
|
||||
The indices of channels marked as bridged with each bridged
|
||||
pair stored as a tuple.
|
||||
bad_limit : int
|
||||
The maximum number of electrodes that can be bridged together
|
||||
(included) and interpolated. Above this number, an error will be
|
||||
raised.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
Returns
|
||||
-------
|
||||
inst : instance of Epochs, Evoked, or Raw
|
||||
The modified data object.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.preprocessing.compute_bridged_electrodes
|
||||
"""
|
||||
_validate_type(inst, (BaseRaw, BaseEpochs, Evoked))
|
||||
bad_limit = _ensure_int(bad_limit, "bad_limit")
|
||||
if bad_limit <= 0:
|
||||
raise ValueError(
|
||||
"Argument 'bad_limit' should be a strictly positive "
|
||||
f"integer. Provided {bad_limit} is invalid."
|
||||
)
|
||||
montage = inst.get_montage()
|
||||
if montage is None:
|
||||
raise RuntimeError("No channel positions found in ``inst``")
|
||||
pos = montage.get_positions()
|
||||
if pos["coord_frame"] != "head":
|
||||
raise RuntimeError(
|
||||
f"Montage channel positions must be in ``head`` got {pos['coord_frame']}"
|
||||
)
|
||||
# store bads orig to put back at the end
|
||||
bads_orig = inst.info["bads"]
|
||||
inst.info["bads"] = list()
|
||||
|
||||
# look for group of bad channels
|
||||
nodes = sorted(set(chain(*bridged_idx)))
|
||||
G_dense = np.zeros((len(nodes), len(nodes)))
|
||||
# fill the edges with a weight of 1
|
||||
for bridge in bridged_idx:
|
||||
idx0 = np.searchsorted(nodes, bridge[0])
|
||||
idx1 = np.searchsorted(nodes, bridge[1])
|
||||
G_dense[idx0, idx1] = 1
|
||||
G_dense[idx1, idx0] = 1
|
||||
# look for connected components
|
||||
_, labels = connected_components(G_dense, directed=False)
|
||||
groups_idx = [[nodes[j] for j in np.where(labels == k)[0]] for k in set(labels)]
|
||||
groups_names = [
|
||||
[inst.info.ch_names[k] for k in group_idx] for group_idx in groups_idx
|
||||
]
|
||||
|
||||
# warn for all bridged areas that include too many electrodes
|
||||
for group_names in groups_names:
|
||||
if len(group_names) > bad_limit:
|
||||
raise RuntimeError(
|
||||
f"The channels {', '.join(group_names)} are bridged together "
|
||||
"and form a large area of bridged electrodes. Interpolation "
|
||||
"might be inaccurate."
|
||||
)
|
||||
|
||||
# make virtual channels
|
||||
virtual_chs = dict()
|
||||
bads = set()
|
||||
for k, group_idx in enumerate(groups_idx):
|
||||
group_names = [inst.info.ch_names[k] for k in group_idx]
|
||||
bads = bads.union(group_names)
|
||||
# compute centroid position in spherical "head" coordinates
|
||||
pos_virtual = _find_centroid_sphere(pos["ch_pos"], group_names)
|
||||
# create the virtual channel info and set the position
|
||||
virtual_info = create_info([f"virtual {k + 1}"], inst.info["sfreq"], "eeg")
|
||||
virtual_info["chs"][0]["loc"][:3] = pos_virtual
|
||||
# create virtual channel
|
||||
data = inst.get_data(picks=group_names)
|
||||
if isinstance(inst, BaseRaw):
|
||||
data = np.average(data, axis=0).reshape(1, -1)
|
||||
virtual_ch = RawArray(data, virtual_info, first_samp=inst.first_samp)
|
||||
elif isinstance(inst, BaseEpochs):
|
||||
data = np.average(data, axis=1).reshape(len(data), 1, -1)
|
||||
virtual_ch = EpochsArray(data, virtual_info, tmin=inst.tmin)
|
||||
else: # evoked
|
||||
data = np.average(data, axis=0).reshape(1, -1)
|
||||
virtual_ch = EvokedArray(
|
||||
np.average(data, axis=0).reshape(1, -1),
|
||||
virtual_info,
|
||||
tmin=inst.tmin,
|
||||
nave=inst.nave,
|
||||
kind=inst.kind,
|
||||
)
|
||||
virtual_chs[f"virtual {k + 1}"] = virtual_ch
|
||||
|
||||
# add the virtual channels
|
||||
inst.add_channels(list(virtual_chs.values()), force_update_info=True)
|
||||
|
||||
# use the virtual channels to interpolate
|
||||
inst.info["bads"] = list(bads)
|
||||
inst.interpolate_bads()
|
||||
|
||||
# drop virtual channels
|
||||
inst.drop_channels(list(virtual_chs.keys()))
|
||||
|
||||
inst.info["bads"] = bads_orig
|
||||
return inst
|
||||
|
||||
|
||||
def _find_centroid_sphere(ch_pos, group_names):
|
||||
"""Compute the centroid position between N electrodes.
|
||||
|
||||
The centroid should be determined in spherical "head" coordinates which is
|
||||
more accurante than cutting through the scalp by averaging in cartesian
|
||||
coordinates.
|
||||
|
||||
A simple way is to average the location in cartesian coordinate, convert
|
||||
to spehrical coordinate and replace the radius with the average radius of
|
||||
the N points in spherical coordinates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ch_pos : OrderedDict
|
||||
The position of all channels in cartesian coordinates.
|
||||
group_names : list | tuple
|
||||
The name of the N electrodes used to determine the centroid.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pos_centroid : array of shape (3,)
|
||||
The position of the centroid in cartesian coordinates.
|
||||
"""
|
||||
cartesian_positions = np.array([ch_pos[ch_name] for ch_name in group_names])
|
||||
sphere_positions = _cart_to_sph(cartesian_positions)
|
||||
cartesian_pos_centroid = np.average(cartesian_positions, axis=0)
|
||||
sphere_pos_centroid = _cart_to_sph(cartesian_pos_centroid)
|
||||
# average the radius and overwrite it
|
||||
avg_radius = np.average(sphere_positions, axis=0)[0]
|
||||
sphere_pos_centroid[0, 0] = avg_radius
|
||||
# convert back to cartesian
|
||||
pos_centroid = _sph_to_cart(sphere_pos_centroid)[0, :]
|
||||
return pos_centroid
|
||||
Reference in New Issue
Block a user