initial commit
This commit is contained in:
417
mne/channels/interpolation.py
Normal file
417
mne/channels/interpolation.py
Normal file
@@ -0,0 +1,417 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
from numpy.polynomial.legendre import legval
|
||||
from scipy.interpolate import RectBivariateSpline
|
||||
from scipy.linalg import pinv
|
||||
from scipy.spatial.distance import pdist, squareform
|
||||
|
||||
from .._fiff.meas_info import _simplify_info
|
||||
from .._fiff.pick import pick_channels, pick_info, pick_types
|
||||
from ..surface import _normalize_vectors
|
||||
from ..utils import _validate_type, logger, verbose, warn
|
||||
|
||||
|
||||
def _calc_h(cosang, stiffness=4, n_legendre_terms=50):
|
||||
"""Calculate spherical spline h function between points on a sphere.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cosang : array-like | float
|
||||
cosine of angles between pairs of points on a spherical surface. This
|
||||
is equivalent to the dot product of unit vectors.
|
||||
stiffness : float
|
||||
stiffnes of the spline. Also referred to as ``m``.
|
||||
n_legendre_terms : int
|
||||
number of Legendre terms to evaluate.
|
||||
"""
|
||||
factors = [
|
||||
(2 * n + 1) / (n ** (stiffness - 1) * (n + 1) ** (stiffness - 1) * 4 * np.pi)
|
||||
for n in range(1, n_legendre_terms + 1)
|
||||
]
|
||||
return legval(cosang, [0] + factors)
|
||||
|
||||
|
||||
def _calc_g(cosang, stiffness=4, n_legendre_terms=50):
|
||||
"""Calculate spherical spline g function between points on a sphere.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cosang : array-like of float, shape(n_channels, n_channels)
|
||||
cosine of angles between pairs of points on a spherical surface. This
|
||||
is equivalent to the dot product of unit vectors.
|
||||
stiffness : float
|
||||
stiffness of the spline.
|
||||
n_legendre_terms : int
|
||||
number of Legendre terms to evaluate.
|
||||
|
||||
Returns
|
||||
-------
|
||||
G : np.ndrarray of float, shape(n_channels, n_channels)
|
||||
The G matrix.
|
||||
"""
|
||||
factors = [
|
||||
(2 * n + 1) / (n**stiffness * (n + 1) ** stiffness * 4 * np.pi)
|
||||
for n in range(1, n_legendre_terms + 1)
|
||||
]
|
||||
return legval(cosang, [0] + factors)
|
||||
|
||||
|
||||
def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
|
||||
"""Compute interpolation matrix based on spherical splines.
|
||||
|
||||
Implementation based on [1]
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos_from : np.ndarray of float, shape(n_good_sensors, 3)
|
||||
The positions to interpolate from.
|
||||
pos_to : np.ndarray of float, shape(n_bad_sensors, 3)
|
||||
The positions to interpolate.
|
||||
alpha : float
|
||||
Regularization parameter. Defaults to 1e-5.
|
||||
|
||||
Returns
|
||||
-------
|
||||
interpolation : np.ndarray of float, shape(len(pos_from), len(pos_to))
|
||||
The interpolation matrix that maps good signals to the location
|
||||
of bad signals.
|
||||
|
||||
References
|
||||
----------
|
||||
[1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
|
||||
Spherical splines for scalp potential and current density mapping.
|
||||
Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.
|
||||
"""
|
||||
pos_from = pos_from.copy()
|
||||
pos_to = pos_to.copy()
|
||||
n_from = pos_from.shape[0]
|
||||
n_to = pos_to.shape[0]
|
||||
|
||||
# normalize sensor positions to sphere
|
||||
_normalize_vectors(pos_from)
|
||||
_normalize_vectors(pos_to)
|
||||
|
||||
# cosine angles between source positions
|
||||
cosang_from = pos_from.dot(pos_from.T)
|
||||
cosang_to_from = pos_to.dot(pos_from.T)
|
||||
G_from = _calc_g(cosang_from)
|
||||
G_to_from = _calc_g(cosang_to_from)
|
||||
assert G_from.shape == (n_from, n_from)
|
||||
assert G_to_from.shape == (n_to, n_from)
|
||||
|
||||
if alpha is not None:
|
||||
G_from.flat[:: len(G_from) + 1] += alpha
|
||||
|
||||
C = np.vstack(
|
||||
[
|
||||
np.hstack([G_from, np.ones((n_from, 1))]),
|
||||
np.hstack([np.ones((1, n_from)), [[0]]]),
|
||||
]
|
||||
)
|
||||
C_inv = pinv(C)
|
||||
|
||||
interpolation = np.hstack([G_to_from, np.ones((n_to, 1))]) @ C_inv[:, :-1]
|
||||
assert interpolation.shape == (n_to, n_from)
|
||||
return interpolation
|
||||
|
||||
|
||||
def _do_interp_dots(inst, interpolation, goods_idx, bads_idx):
|
||||
"""Dot product of channel mapping matrix to channel data."""
|
||||
from ..epochs import BaseEpochs
|
||||
from ..evoked import Evoked
|
||||
from ..io import BaseRaw
|
||||
|
||||
_validate_type(inst, (BaseRaw, BaseEpochs, Evoked), "inst")
|
||||
inst._data[..., bads_idx, :] = np.matmul(
|
||||
interpolation, inst._data[..., goods_idx, :]
|
||||
)
|
||||
|
||||
|
||||
@verbose
|
||||
def _interpolate_bads_eeg(inst, origin, exclude=None, ecog=False, verbose=None):
|
||||
if exclude is None:
|
||||
exclude = list()
|
||||
bads_idx = np.zeros(len(inst.ch_names), dtype=bool)
|
||||
goods_idx = np.zeros(len(inst.ch_names), dtype=bool)
|
||||
|
||||
picks = pick_types(inst.info, meg=False, eeg=not ecog, ecog=ecog, exclude=exclude)
|
||||
inst.info._check_consistency()
|
||||
bads_idx[picks] = [inst.ch_names[ch] in inst.info["bads"] for ch in picks]
|
||||
|
||||
if len(picks) == 0 or bads_idx.sum() == 0:
|
||||
return
|
||||
|
||||
goods_idx[picks] = True
|
||||
goods_idx[bads_idx] = False
|
||||
|
||||
pos = inst._get_channel_positions(picks)
|
||||
|
||||
# Make sure only EEG are used
|
||||
bads_idx_pos = bads_idx[picks]
|
||||
goods_idx_pos = goods_idx[picks]
|
||||
|
||||
# test spherical fit
|
||||
distance = np.linalg.norm(pos - origin, axis=-1)
|
||||
distance = np.mean(distance / np.mean(distance))
|
||||
if np.abs(1.0 - distance) > 0.1:
|
||||
warn(
|
||||
"Your spherical fit is poor, interpolation results are "
|
||||
"likely to be inaccurate."
|
||||
)
|
||||
|
||||
pos_good = pos[goods_idx_pos] - origin
|
||||
pos_bad = pos[bads_idx_pos] - origin
|
||||
logger.info(f"Computing interpolation matrix from {len(pos_good)} sensor positions")
|
||||
interpolation = _make_interpolation_matrix(pos_good, pos_bad)
|
||||
|
||||
logger.info(f"Interpolating {len(pos_bad)} sensors")
|
||||
_do_interp_dots(inst, interpolation, goods_idx, bads_idx)
|
||||
|
||||
|
||||
@verbose
|
||||
def _interpolate_bads_ecog(inst, origin, exclude=None, verbose=None):
|
||||
_interpolate_bads_eeg(inst, origin, exclude=exclude, ecog=True, verbose=verbose)
|
||||
|
||||
|
||||
def _interpolate_bads_meg(
|
||||
inst, mode="accurate", origin=(0.0, 0.0, 0.04), verbose=None, ref_meg=False
|
||||
):
|
||||
return _interpolate_bads_meeg(
|
||||
inst, mode, origin, ref_meg=ref_meg, eeg=False, verbose=verbose
|
||||
)
|
||||
|
||||
|
||||
@verbose
|
||||
def _interpolate_bads_nan(
|
||||
inst,
|
||||
ch_type,
|
||||
ref_meg=False,
|
||||
exclude=(),
|
||||
*,
|
||||
verbose=None,
|
||||
):
|
||||
info = _simplify_info(inst.info)
|
||||
picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **{ch_type: True})
|
||||
use_ch_names = [inst.info["ch_names"][p] for p in picks_type]
|
||||
bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names]
|
||||
if len(bads_type) == 0 or len(picks_type) == 0:
|
||||
return
|
||||
# select the bad channels to be interpolated
|
||||
picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])
|
||||
inst._data[..., picks_bad, :] = np.nan
|
||||
|
||||
|
||||
@verbose
|
||||
def _interpolate_bads_meeg(
|
||||
inst,
|
||||
mode="accurate",
|
||||
origin=(0.0, 0.0, 0.04),
|
||||
meg=True,
|
||||
eeg=True,
|
||||
ref_meg=False,
|
||||
exclude=(),
|
||||
*,
|
||||
method=None,
|
||||
verbose=None,
|
||||
):
|
||||
from ..forward import _map_meg_or_eeg_channels
|
||||
|
||||
if method is None:
|
||||
method = {"meg": "MNE", "eeg": "MNE"}
|
||||
bools = dict(meg=meg, eeg=eeg)
|
||||
info = _simplify_info(inst.info)
|
||||
for ch_type, do in bools.items():
|
||||
if not do:
|
||||
continue
|
||||
kw = dict(meg=False, eeg=False)
|
||||
kw[ch_type] = True
|
||||
picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **kw)
|
||||
picks_good = pick_types(info, ref_meg=ref_meg, exclude="bads", **kw)
|
||||
use_ch_names = [inst.info["ch_names"][p] for p in picks_type]
|
||||
bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names]
|
||||
if len(bads_type) == 0 or len(picks_type) == 0:
|
||||
continue
|
||||
# select the bad channels to be interpolated
|
||||
picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])
|
||||
|
||||
# do MNE based interpolation
|
||||
if ch_type == "eeg":
|
||||
picks_to = picks_type
|
||||
bad_sel = np.isin(picks_type, picks_bad)
|
||||
else:
|
||||
picks_to = picks_bad
|
||||
bad_sel = slice(None)
|
||||
info_from = pick_info(inst.info, picks_good)
|
||||
info_to = pick_info(inst.info, picks_to)
|
||||
mapping = _map_meg_or_eeg_channels(info_from, info_to, mode=mode, origin=origin)
|
||||
mapping = mapping[bad_sel]
|
||||
_do_interp_dots(inst, mapping, picks_good, picks_bad)
|
||||
|
||||
|
||||
@verbose
|
||||
def _interpolate_bads_nirs(inst, exclude=(), verbose=None):
|
||||
from mne.preprocessing.nirs import _validate_nirs_info
|
||||
|
||||
if len(pick_types(inst.info, fnirs=True, exclude=())) == 0:
|
||||
return
|
||||
|
||||
# Returns pick of all nirs and ensures channels are correctly ordered
|
||||
picks_nirs = _validate_nirs_info(inst.info)
|
||||
nirs_ch_names = [inst.info["ch_names"][p] for p in picks_nirs]
|
||||
nirs_ch_names = [ch for ch in nirs_ch_names if ch not in exclude]
|
||||
bads_nirs = [ch for ch in inst.info["bads"] if ch in nirs_ch_names]
|
||||
if len(bads_nirs) == 0:
|
||||
return
|
||||
picks_bad = pick_channels(inst.info["ch_names"], bads_nirs, exclude=[])
|
||||
bads_mask = [p in picks_bad for p in picks_nirs]
|
||||
|
||||
chs = [inst.info["chs"][i] for i in picks_nirs]
|
||||
locs3d = np.array([ch["loc"][:3] for ch in chs])
|
||||
|
||||
dist = pdist(locs3d)
|
||||
dist = squareform(dist)
|
||||
|
||||
for bad in picks_bad:
|
||||
dists_to_bad = dist[bad]
|
||||
# Ignore distances to self
|
||||
dists_to_bad[dists_to_bad == 0] = np.inf
|
||||
# Ignore distances to other bad channels
|
||||
dists_to_bad[bads_mask] = np.inf
|
||||
# Find closest remaining channels for same frequency
|
||||
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
|
||||
inst._data[bad] = inst._data[closest_idx]
|
||||
|
||||
# TODO: this seems like a bug because it does not respect reset_bads
|
||||
inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude]
|
||||
|
||||
return inst
|
||||
|
||||
|
||||
def _find_seeg_electrode_shaft(pos, tol_shaft=0.002, tol_spacing=1):
|
||||
# 1) find nearest neighbor to define the electrode shaft line
|
||||
# 2) find all contacts on the same line
|
||||
# 3) remove contacts with large distances
|
||||
|
||||
dist = squareform(pdist(pos))
|
||||
np.fill_diagonal(dist, np.inf)
|
||||
|
||||
shafts = list()
|
||||
shaft_ts = list()
|
||||
for i, n1 in enumerate(pos):
|
||||
if any([i in shaft for shaft in shafts]):
|
||||
continue
|
||||
n2 = pos[np.argmin(dist[i])] # 1
|
||||
# https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
|
||||
shaft_dists = np.linalg.norm(
|
||||
np.cross((pos - n1), (pos - n2)), axis=1
|
||||
) / np.linalg.norm(n2 - n1)
|
||||
shaft = np.where(shaft_dists < tol_shaft)[0] # 2
|
||||
shaft_prev = None
|
||||
for _ in range(10): # avoid potential cycles
|
||||
if np.array_equal(shaft, shaft_prev):
|
||||
break
|
||||
shaft_prev = shaft
|
||||
# compute median shaft line
|
||||
v = np.median(
|
||||
[
|
||||
pos[i] - pos[j]
|
||||
for idx, i in enumerate(shaft)
|
||||
for j in shaft[idx + 1 :]
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
c = np.median(pos[shaft], axis=0)
|
||||
# recompute distances
|
||||
shaft_dists = np.linalg.norm(
|
||||
np.cross((pos - c), (pos - c + v)), axis=1
|
||||
) / np.linalg.norm(v)
|
||||
shaft = np.where(shaft_dists < tol_shaft)[0]
|
||||
ts = np.array([np.dot(c - n0, v) / np.linalg.norm(v) ** 2 for n0 in pos[shaft]])
|
||||
shaft_order = np.argsort(ts)
|
||||
shaft = shaft[shaft_order]
|
||||
ts = ts[shaft_order]
|
||||
|
||||
# only include the largest group with spacing with the error tolerance
|
||||
# avoid interpolating across spans between contacts
|
||||
t_diffs = np.diff(ts)
|
||||
t_diff_med = np.median(t_diffs)
|
||||
spacing_errors = (t_diffs - t_diff_med) / t_diff_med
|
||||
groups = list()
|
||||
group = [shaft[0]]
|
||||
for j in range(len(shaft) - 1):
|
||||
if spacing_errors[j] > tol_spacing:
|
||||
groups.append(group)
|
||||
group = [shaft[j + 1]]
|
||||
else:
|
||||
group.append(shaft[j + 1])
|
||||
groups.append(group)
|
||||
group = [group for group in groups if i in group][0]
|
||||
ts = ts[np.isin(shaft, group)]
|
||||
shaft = np.array(group, dtype=int)
|
||||
|
||||
shafts.append(shaft)
|
||||
shaft_ts.append(ts)
|
||||
return shafts, shaft_ts
|
||||
|
||||
|
||||
@verbose
|
||||
def _interpolate_bads_seeg(
|
||||
inst, exclude=None, tol_shaft=0.002, tol_spacing=1, verbose=None
|
||||
):
|
||||
if exclude is None:
|
||||
exclude = list()
|
||||
picks = pick_types(inst.info, meg=False, seeg=True, exclude=exclude)
|
||||
inst.info._check_consistency()
|
||||
bads_idx = np.isin(np.array(inst.ch_names)[picks], inst.info["bads"])
|
||||
|
||||
if len(picks) == 0 or bads_idx.sum() == 0:
|
||||
return
|
||||
|
||||
pos = inst._get_channel_positions(picks)
|
||||
|
||||
# Make sure only sEEG are used
|
||||
bads_idx_pos = bads_idx[picks]
|
||||
|
||||
shafts, shaft_ts = _find_seeg_electrode_shaft(
|
||||
pos, tol_shaft=tol_shaft, tol_spacing=tol_spacing
|
||||
)
|
||||
|
||||
# interpolate the bad contacts
|
||||
picks_bad = list(np.where(bads_idx_pos)[0])
|
||||
for shaft, ts in zip(shafts, shaft_ts):
|
||||
bads_shaft = np.array([idx for idx in picks_bad if idx in shaft])
|
||||
if bads_shaft.size == 0:
|
||||
continue
|
||||
goods_shaft = shaft[np.isin(shaft, bads_shaft, invert=True)]
|
||||
if goods_shaft.size < 4: # cubic spline requires 3 channels
|
||||
msg = "No shaft" if shaft.size < 4 else "Not enough good channels"
|
||||
no_shaft_chs = " and ".join(np.array(inst.ch_names)[bads_shaft])
|
||||
raise RuntimeError(
|
||||
f"{msg} found in a line with {no_shaft_chs} "
|
||||
"at least 3 good channels on the same line "
|
||||
f"are required for interpolation, {goods_shaft.size} found. "
|
||||
f"Dropping {no_shaft_chs} is recommended."
|
||||
)
|
||||
logger.debug(
|
||||
f"Interpolating {np.array(inst.ch_names)[bads_shaft]} using "
|
||||
f"data from {np.array(inst.ch_names)[goods_shaft]}"
|
||||
)
|
||||
bads_shaft_idx = np.where(np.isin(shaft, bads_shaft))[0]
|
||||
goods_shaft_idx = np.where(~np.isin(shaft, bads_shaft))[0]
|
||||
|
||||
z = inst._data[..., goods_shaft, :]
|
||||
is_epochs = z.ndim == 3
|
||||
if is_epochs:
|
||||
z = z.swapaxes(0, 1)
|
||||
z = z.reshape(z.shape[0], -1)
|
||||
y = np.arange(z.shape[-1])
|
||||
out = RectBivariateSpline(x=ts[goods_shaft_idx], y=y, z=z)(
|
||||
x=ts[bads_shaft_idx], y=y
|
||||
)
|
||||
if is_epochs:
|
||||
out = out.reshape(bads_shaft.size, inst._data.shape[0], -1)
|
||||
out = out.swapaxes(0, 1)
|
||||
inst._data[..., bads_shaft, :] = out
|
||||
Reference in New Issue
Block a user