initial commit
mne/simulation/__init__.py (Normal file, 8 lines)
@@ -0,0 +1,8 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

"""Data simulation code."""

import lazy_loader as lazy

(__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__)
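
Here ``lazy.attach_stub`` defers the real submodule imports until first attribute access, using the adjacent ``.pyi`` stub as the import map. A rough sketch of the effect::

    import mne.simulation             # cheap: nothing heavy is imported yet
    fn = mne.simulation.simulate_raw  # first access triggers the real import
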
mne/simulation/__init__.pyi (Normal file, 22 lines)
@@ -0,0 +1,22 @@
__all__ = [
    "SourceSimulator",
    "add_chpi",
    "add_ecg",
    "add_eog",
    "add_noise",
    "metrics",
    "select_source_in_label",
    "simulate_evoked",
    "simulate_raw",
    "simulate_sparse_stc",
    "simulate_stc",
]
from . import metrics
from .evoked import add_noise, simulate_evoked
from .raw import add_chpi, add_ecg, add_eog, simulate_raw
from .source import (
    SourceSimulator,
    select_source_in_label,
    simulate_sparse_stc,
    simulate_stc,
)
mne/simulation/_metrics.py (Normal file, 13 lines)
@@ -0,0 +1,13 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np


def _check_stc(stc1, stc2):
    """Check that stcs are compatible."""
    if stc1.data.shape != stc2.data.shape:
        raise ValueError("Data in stcs must have the same size")
    if np.any(stc1.times != stc2.times):
        raise ValueError("Times of two stcs must match.")
mne/simulation/evoked.py (Normal file, 186 lines)
@@ -0,0 +1,186 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import math

import numpy as np
from scipy.signal import lfilter

from .._fiff.pick import pick_info
from ..cov import Covariance, compute_whitener
from ..epochs import BaseEpochs
from ..evoked import Evoked
from ..forward import apply_forward
from ..io import BaseRaw
from ..utils import _check_preload, _validate_type, check_random_state, logger, verbose


@verbose
def simulate_evoked(
    fwd,
    stc,
    info,
    cov=None,
    nave=30,
    iir_filter=None,
    random_state=None,
    use_cps=True,
    verbose=None,
):
    """Generate noisy evoked data.

    .. note:: No projections from ``info`` will be present in the
              output ``evoked``. You can use e.g.
              :func:`evoked.add_proj <mne.Evoked.add_proj>` or
              :func:`evoked.set_eeg_reference <mne.Evoked.set_eeg_reference>`
              to add them afterward as necessary.

    Parameters
    ----------
    fwd : instance of Forward
        A forward solution.
    stc : SourceEstimate object
        The source time courses.
    %(info_not_none)s Used to generate the evoked.
    cov : Covariance object | None
        The noise covariance. If None, no noise is added.
    nave : int
        Number of averaged epochs (defaults to 30).

        .. versionadded:: 0.15.0
    iir_filter : None | array
        IIR filter coefficients (denominator) e.g. [1, -1, 0.2].
    %(random_state)s
    %(use_cps)s

        .. versionadded:: 0.15
    %(verbose)s

    Returns
    -------
    evoked : Evoked object
        The simulated evoked data.

    See Also
    --------
    simulate_raw
    simulate_stc
    simulate_sparse_stc

    Notes
    -----
    To convert between snr and nave, when snr is given instead of nave::

        nave = (1 / 10 ** ((actual_snr - snr) / 20)) ** 2

    where actual_snr is the snr of the generated noise before scaling.

    .. versionadded:: 0.10.0
    """
    evoked = apply_forward(fwd, stc, info, use_cps=use_cps)
    if cov is None:
        return evoked

    if nave < np.inf:
        noise = _simulate_noise_evoked(evoked, cov, iir_filter, random_state)
        evoked.data += noise.data / math.sqrt(nave)
        evoked.nave = np.int64(nave)
    if cov.get("projs", None):
        evoked.add_proj(cov["projs"]).apply_proj()
    return evoked


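A minimal usage sketch (assuming ``fwd``, ``stc``, ``info``, and ``cov`` are already loaded, e.g. from the MNE sample dataset)::

    evoked = simulate_evoked(
        fwd, stc, info, cov, nave=50, iir_filter=[1, -1, 0.2], random_state=0
    )
    # doubling nave halves the power of the added noise
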
def _simulate_noise_evoked(evoked, cov, iir_filter, random_state):
    noise = evoked.copy()
    noise.data[:] = 0
    return _add_noise(noise, cov, iir_filter, random_state, allow_subselection=False)


@verbose
def add_noise(inst, cov, iir_filter=None, random_state=None, verbose=None):
    """Create noise as a multivariate Gaussian.

    The spatial covariance of the noise is given by the cov matrix.

    Parameters
    ----------
    inst : instance of Evoked, Epochs, or Raw
        Instance to which to add noise.
    cov : instance of Covariance
        The noise covariance.
    iir_filter : None | array-like
        IIR filter coefficients (denominator).
    %(random_state)s
    %(verbose)s

    Returns
    -------
    inst : instance of Evoked, Epochs, or Raw
        The instance, modified to have additional noise.

    Notes
    -----
    Only channels in both ``inst.info['ch_names']`` and
    ``cov['names']`` will have noise added to them.

    This function operates in place on ``inst``.

    .. versionadded:: 0.18.0
    """
    # We always allow subselection here
    return _add_noise(inst, cov, iir_filter, random_state)


def _add_noise(inst, cov, iir_filter, random_state, allow_subselection=True):
    """Add noise, possibly with channel subselection."""
    _validate_type(cov, Covariance, "cov")
    _validate_type(
        inst, (BaseRaw, BaseEpochs, Evoked), "inst", "Raw, Epochs, or Evoked"
    )
    _check_preload(inst, "Adding noise")
    data = inst._data
    assert data.ndim in (2, 3)
    if data.ndim == 2:
        data = data[np.newaxis]
    # Subselect if necessary
    info = inst.info
    info._check_consistency()
    picks = gen_picks = slice(None)
    if allow_subselection:
        use_chs = list(set(info["ch_names"]) & set(cov["names"]))
        picks = np.where(np.isin(info["ch_names"], use_chs))[0]
        logger.info(
            "Adding noise to %d/%d channels (%d channels in cov)",
            len(picks),
            len(info["chs"]),
            len(cov["names"]),
        )
        info = pick_info(inst.info, picks)
        info._check_consistency()

        gen_picks = np.arange(info["nchan"])
    for epoch in data:
        epoch[picks] += _generate_noise(
            info, cov, iir_filter, random_state, epoch.shape[1], picks=gen_picks
        )[0]
    return inst


def _generate_noise(
    info, cov, iir_filter, random_state, n_samples, zi=None, picks=None
):
    """Create spatially colored and temporally IIR-filtered noise."""
    rng = check_random_state(random_state)
    _, _, colorer = compute_whitener(
        cov, info, pca=True, return_colorer=True, picks=picks, verbose=False
    )
    noise = np.dot(colorer, rng.standard_normal((colorer.shape[1], n_samples)))
    if iir_filter is not None:
        if zi is None:
            zi = np.zeros((len(colorer), len(iir_filter) - 1))
        noise, zf = lfilter([1], iir_filter, noise, axis=-1, zi=zi)
    else:
        zf = None
    return noise, zf
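
Note that ``iir_filter`` supplies only the denominator ``a`` of ``lfilter([1], a, ...)``, so the white Gaussian draw becomes AR-correlated noise. A hedged sketch (assuming a preloaded ``raw`` and a matching ``cov``)::

    from mne.simulation import add_noise

    raw_noisy = add_noise(raw.copy(), cov, iir_filter=[1, -1, 0.2], random_state=0)
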
mne/simulation/metrics/__init__.py (Normal file, 20 lines)
@@ -0,0 +1,20 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

"""Metrics module for computing stc-based metrics."""

from .metrics import (
    cosine_score,
    region_localization_error,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    peak_position_error,
    source_estimate_quantification,
    spatial_deviation_error,
    _thresholding,
    _check_threshold,
    _uniform_stc,
)
mne/simulation/metrics/metrics.py (Normal file, 572 lines)
@@ -0,0 +1,572 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from functools import partial

import numpy as np
from scipy.spatial.distance import cdist

from ...utils import _check_option, _validate_type, fill_doc


def _check_stc(stc1, stc2):
    """Check that stcs are compatible."""
    if stc1.data.shape != stc2.data.shape:
        raise ValueError("Data in stcs must have the same size")
    if np.any(stc1.times != stc2.times):
        raise ValueError("Times of two stcs must match.")


def source_estimate_quantification(stc1, stc2, metric="rms"):
    """Calculate STC similarities across all sources and times.

    Parameters
    ----------
    stc1 : SourceEstimate
        First source estimate for comparison.
    stc2 : SourceEstimate
        Second source estimate for comparison.
    metric : str
        Metric to calculate, ``'rms'`` or ``'cosine'``.

    Returns
    -------
    score : float | array
        Calculated metric.

    Notes
    -----
    Metric calculation has multiple options:

    * rms: Root mean square of difference between stc data matrices.
    * cosine: Normalized correlation of all elements in stc data matrices.

    .. versionadded:: 0.10.0
    """
    _check_option("metric", metric, ["rms", "cosine"])

    # Check that the data have the same size, which means no comparison
    # between distributed and sparse estimates can be done so far.
    _check_stc(stc1, stc2)
    data1, data2 = stc1.data, stc2.data

    # Calculate root mean square difference between two matrices
    if metric == "rms":
        score = np.sqrt(np.mean((data1 - data2) ** 2))
    # Calculate correlation coefficient between matrix elements
    elif metric == "cosine":
        score = 1.0 - _cosine(data1, data2)
    return score


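For orientation, the ``rms`` branch on toy matrices (plain arrays, not stc objects)::

    import numpy as np

    data1 = np.array([[1.0, 2.0], [3.0, 4.0]])
    data2 = np.array([[1.0, 2.0], [3.0, 2.0]])
    np.sqrt(np.mean((data1 - data2) ** 2))  # -> 1.0
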
def _uniform_stc(stc1, stc2):
    """Unify the vertices of two stcs.

    This function returns the stcs with the same vertices by
    inserting zeros in data for missing vertices.
    """
    if len(stc1.vertices) != len(stc2.vertices):
        raise ValueError(
            "Data in stcs must have the same number of vertex "
            f"components. Got {len(stc1.vertices)} != {len(stc2.vertices)}."
        )
    idx_start1 = 0
    idx_start2 = 0
    stc1 = stc1.copy()
    stc2 = stc2.copy()
    all_data1 = []
    all_data2 = []
    for i, (vert1, vert2) in enumerate(zip(stc1.vertices, stc2.vertices)):
        vert = np.union1d(vert1, vert2)
        data1 = np.zeros([len(vert), stc1.data.shape[1]])
        data2 = np.zeros([len(vert), stc2.data.shape[1]])
        data1[np.searchsorted(vert, vert1)] = stc1.data[
            idx_start1 : idx_start1 + len(vert1)
        ]
        data2[np.searchsorted(vert, vert2)] = stc2.data[
            idx_start2 : idx_start2 + len(vert2)
        ]
        idx_start1 += len(vert1)
        idx_start2 += len(vert2)
        stc1.vertices[i] = vert
        stc2.vertices[i] = vert
        all_data1.append(data1)
        all_data2.append(data2)

    stc1._data = np.concatenate(all_data1, axis=0)
    stc2._data = np.concatenate(all_data2, axis=0)
    return stc1, stc2


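The zero-padding idea, sketched on plain arrays with toy vertex numbers::

    import numpy as np

    vert1, vert2 = np.array([2, 5]), np.array([5, 7])
    vert = np.union1d(vert1, vert2)            # [2, 5, 7]
    data1 = np.zeros((len(vert), 1))
    data1[np.searchsorted(vert, vert1)] = 1.0  # fills rows for vertices 2 and 5
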
def _apply(func, stc_true, stc_est, per_sample):
    """Apply metric to stcs.

    Applies a metric to each pair of columns of stc_true and stc_est
    if per_sample is True. Otherwise it applies it to stc_true and stc_est
    directly.
    """
    if per_sample:
        metric = np.empty(stc_true.data.shape[1])  # one value per time point
        for i in range(stc_true.data.shape[1]):
            metric[i] = func(stc_true.data[:, i : i + 1], stc_est.data[:, i : i + 1])
    else:
        metric = func(stc_true.data, stc_est.data)
    return metric


def _thresholding(stc_true, stc_est, threshold):
    relative = isinstance(threshold, str)
    threshold = _check_threshold(threshold)
    if relative:
        if stc_true is not None:
            stc_true._data[
                np.abs(stc_true._data) <= threshold * np.max(np.abs(stc_true._data))
            ] = 0.0
        stc_est._data[
            np.abs(stc_est._data) <= threshold * np.max(np.abs(stc_est._data))
        ] = 0.0
    else:
        if stc_true is not None:
            stc_true._data[np.abs(stc_true._data) <= threshold] = 0.0
        stc_est._data[np.abs(stc_est._data) <= threshold] = 0.0
    return stc_true, stc_est


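Relative versus absolute thresholding, sketched on a plain array::

    import numpy as np

    d = np.array([0.05, 0.2, 1.0])
    rel = d.copy()
    rel[np.abs(rel) <= 0.5 * np.max(np.abs(rel))] = 0.0  # "50%" -> [0.0, 0.0, 1.0]
    ab = d.copy()
    ab[np.abs(ab) <= 0.1] = 0.0                          # 0.1   -> [0.0, 0.2, 1.0]
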
def _cosine(x, y):
    p = x.ravel()
    q = y.ravel()
    p_norm = np.linalg.norm(p)
    q_norm = np.linalg.norm(q)
    if p_norm * q_norm:
        return (p.T @ q) / (p_norm * q_norm)
    elif p_norm == q_norm:
        return 1
    else:
        return 0


@fill_doc
def cosine_score(stc_true, stc_est, per_sample=True):
    """Compute cosine similarity between 2 source estimates.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    metric = _apply(_cosine, stc_true, stc_est, per_sample=per_sample)
    return metric


def _check_threshold(threshold):
    """Accept a float or a string that ends with %."""
    _validate_type(threshold, ("numeric", str), "threshold")
    if isinstance(threshold, str):
        if not threshold.endswith("%"):
            raise ValueError(
                f'Threshold, if a string, must end with "%". Got {threshold}.'
            )
        threshold = float(threshold[:-1]) / 100.0
    threshold = float(threshold)
    if not 0 <= threshold <= 1:
        raise ValueError(
            "Threshold proportion must be between 0 and 1 (inclusive), but "
            f"got {threshold}"
        )
    return threshold


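Behaviour sketch::

    _check_threshold("90%")  # -> 0.9
    _check_threshold(0.25)   # -> 0.25
    _check_threshold(1.5)    # raises ValueError (must be within [0, 1])
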
def _abs_col_sum(x):
    return np.abs(x).sum(axis=1)


def _dle(p, q, src, stc):
    """Aux function to compute dipole localization error."""
    p = _abs_col_sum(p)
    q = _abs_col_sum(q)
    idx1 = np.nonzero(p)[0]
    idx2 = np.nonzero(q)[0]
    points = []
    for i in range(len(src)):
        points.append(src[i]["rr"][stc.vertices[i]])
    points = np.concatenate(points, axis=0)
    if len(idx1) and len(idx2):
        D = cdist(points[idx1], points[idx2])
        D_min_1 = np.min(D, axis=0)
        D_min_2 = np.min(D, axis=1)
        return (np.mean(D_min_1) + np.mean(D_min_2)) / 2.0
    else:
        return np.inf


@fill_doc
def region_localization_error(stc_true, stc_est, src, threshold="90%", per_sample=True):
    r"""Compute region localization error (RLE) between 2 source estimates.

    .. math::

        RLE = \frac{1}{2Q}\sum_{k \in I} \min_{l \in \hat{I}}{||r_k - r_l||} + \frac{1}{2\hat{Q}}\sum_{l \in \hat{I}} \min_{k \in I}{||r_k - r_l||}

    where :math:`I` and :math:`\hat{I}` denote respectively the original and
    estimated indexes of active sources, :math:`Q` and :math:`\hat{Q}` are
    the numbers of original and estimated active sources.
    :math:`r_k` denotes the position of the k-th source dipole in space
    and :math:`||\cdot||` is an Euclidean norm in :math:`\mathbb{R}^3`.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    src : instance of SourceSpaces
        The source space on which the source estimates are defined.
    threshold : float | str
        The threshold to apply to source estimates before computing
        the dipole localization error. If a string, the threshold is
        a percentage and it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    Papers :footcite:`MaksymenkoEtAl2017` and :footcite:`BeckerEtAl2017`
    use the term Dipole Localization Error (DLE) for the same formula, and
    :footcite:`YaoEtAl2005` uses the term Error Distance (ED).
    To unify the terminology and to avoid confusion with other uses of the
    term DLE for a different metric :footcite:`MolinsEtAl2008`, we use the
    term Region Localization Error (RLE).

    .. versionadded:: 1.2

    References
    ----------
    .. footbibliography::
    """  # noqa: E501
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    stc_true, stc_est = _thresholding(stc_true, stc_est, threshold)
    func = partial(_dle, src=src, stc=stc_true)
    metric = _apply(func, stc_true, stc_est, per_sample=per_sample)
    return metric


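A hedged usage sketch (``stc_true``, ``stc_est``, and a matching ``src`` assumed available)::

    rle = region_localization_error(
        stc_true, stc_est, src, threshold="90%", per_sample=False
    )
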
def _roc_auc_score(p, q):
    from sklearn.metrics import roc_auc_score

    return roc_auc_score(np.abs(p) > 0, np.abs(q))


@fill_doc
def roc_auc_score(stc_true, stc_est, per_sample=True):
    """Compute ROC AUC between 2 source estimates.

    ROC stands for receiver operating characteristic and AUC is the area
    under the curve. When computing this metric the stc_true must be
    thresholded, as any non-zero value will be considered a positive.

    The ROC-AUC metric is computed between amplitudes of the source
    estimates, i.e. after taking the absolute values.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    metric = _apply(_roc_auc_score, stc_true, stc_est, per_sample=per_sample)
    return metric


def _f1_score(p, q):
    from sklearn.metrics import f1_score

    return f1_score(_abs_col_sum(p) > 0, _abs_col_sum(q) > 0)


@fill_doc
def f1_score(stc_true, stc_est, threshold="90%", per_sample=True):
    """Compute the F1 score, also known as balanced F-score or F-measure.

    The F1 score can be interpreted as a weighted average of the precision
    and recall, where an F1 score reaches its best value at 1 and worst score
    at 0. The relative contribution of precision and recall to the F1
    score is equal.
    The formula for the F1 score is::

        F1 = 2 * (precision * recall) / (precision + recall)

    Threshold is used first for data binarization.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    threshold : float | str
        The threshold to apply to source estimates before computing
        the f1 score. If a string, the threshold is
        a percentage and it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    stc_true, stc_est = _thresholding(stc_true, stc_est, threshold)
    metric = _apply(_f1_score, stc_true, stc_est, per_sample=per_sample)
    return metric


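The binarization shared by the f1/precision/recall wrappers, on toy arrays (requires scikit-learn)::

    import numpy as np
    from sklearn.metrics import f1_score

    p = np.array([[0.0], [1.0], [2.0]])  # true: rows 1 and 2 active
    q = np.array([[0.0], [0.0], [3.0]])  # est : row 2 active
    f1_score(np.abs(p).sum(axis=1) > 0, np.abs(q).sum(axis=1) > 0)  # -> 2/3
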
def _precision_score(p, q):
    from sklearn.metrics import precision_score

    return precision_score(_abs_col_sum(p) > 0, _abs_col_sum(q) > 0)


@fill_doc
def precision_score(stc_true, stc_est, threshold="90%", per_sample=True):
    """Compute the precision.

    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
    true positives and ``fp`` the number of false positives. The precision is
    intuitively the ability of the classifier not to label as positive a sample
    that is negative.

    The best value is 1 and the worst value is 0.

    Threshold is used first for data binarization.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    threshold : float | str
        The threshold to apply to source estimates before computing
        the precision. If a string, the threshold is
        a percentage and it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    stc_true, stc_est = _thresholding(stc_true, stc_est, threshold)
    metric = _apply(_precision_score, stc_true, stc_est, per_sample=per_sample)
    return metric


def _recall_score(p, q):
    from sklearn.metrics import recall_score

    return recall_score(_abs_col_sum(p) > 0, _abs_col_sum(q) > 0)


@fill_doc
def recall_score(stc_true, stc_est, threshold="90%", per_sample=True):
    """Compute the recall.

    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
    true positives and ``fn`` the number of false negatives. The recall is
    intuitively the ability of the classifier to find all the positive samples.

    The best value is 1 and the worst value is 0.

    Threshold is used first for data binarization.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    threshold : float | str
        The threshold to apply to source estimates before computing
        the recall. If a string, the threshold is
        a percentage and it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    stc_true, stc_est = _thresholding(stc_true, stc_est, threshold)
    metric = _apply(_recall_score, stc_true, stc_est, per_sample=per_sample)
    return metric


def _prepare_ppe_sd(stc_true, stc_est, src, threshold="50%"):
    stc_true = stc_true.copy()
    stc_est = stc_est.copy()
    n_dipoles = 0
    for i, v in enumerate(stc_true.vertices):
        if len(v):
            n_dipoles += len(v)
            r_true = src[i]["rr"][v]
    if n_dipoles != 1:
        raise ValueError(f"True source must contain only one dipole, got {n_dipoles}.")

    _, stc_est = _thresholding(None, stc_est, threshold)

    r_est = np.empty([0, 3])
    for i, v in enumerate(stc_est.vertices):
        if len(v):
            r_est = np.vstack([r_est, src[i]["rr"][v]])
    return stc_est, r_true, r_est


def _peak_position_error(p, q, r_est, r_true):
    q = _abs_col_sum(q)
    if np.sum(q):
        q /= np.sum(q)
        r_est_mean = np.dot(q, r_est)
        return np.linalg.norm(r_est_mean - r_true)
    else:
        return np.inf


@fill_doc
def peak_position_error(stc_true, stc_est, src, threshold="50%", per_sample=True):
    r"""Compute the peak position error.

    The peak position error measures the distance between the center-of-mass
    of the estimated and the true source.

    .. math::

        PPE = \| \dfrac{\sum_i|s_i|r_{i}}{\sum_i|s_i|}
        - r_{true}\|,

    where :math:`r_{true}` is the true dipole position, and :math:`r_i`
    and :math:`|s_i|` denote respectively the position and amplitude of
    the i-th dipole in the source estimate.

    The threshold is applied to the estimated source to focus the metric
    on strong amplitudes and omit low-amplitude values.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    src : instance of SourceSpaces
        The source space on which the source estimates are defined.
    threshold : float | str
        The threshold to apply to source estimates before computing
        the metric. If a string, the threshold is
        a percentage and it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    These metrics are documented in :footcite:`StenroosHauk2013` and
    :footcite:`LinEtAl2006a`.

    .. versionadded:: 1.2

    References
    ----------
    .. footbibliography::
    """
    stc_est, r_true, r_est = _prepare_ppe_sd(stc_true, stc_est, src, threshold)
    func = partial(_peak_position_error, r_est=r_est, r_true=r_true)
    metric = _apply(func, stc_true, stc_est, per_sample=per_sample)
    return metric


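The weighted center-of-mass step, on toy positions (metres)::

    import numpy as np

    q = np.array([1.0, 3.0])                        # per-dipole amplitudes
    r_est = np.array([[0.00, 0, 0], [0.04, 0, 0]])
    r_true = np.array([0.02, 0, 0])
    com = np.dot(q / q.sum(), r_est)                # [0.03, 0, 0]
    np.linalg.norm(com - r_true)                    # -> 0.01
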
def _spatial_deviation(p, q, r_est, r_true):
    q = _abs_col_sum(q)
    if np.sum(q):
        q /= np.sum(q)
        r_true_tile = np.tile(r_true, (r_est.shape[0], 1))
        r_diff = r_est - r_true_tile
        r_diff_norm = np.sum(r_diff**2, axis=1)
        return np.sqrt(np.dot(q, r_diff_norm))
    else:
        return np.inf


@fill_doc
def spatial_deviation_error(stc_true, stc_est, src, threshold="50%", per_sample=True):
    r"""Compute the spatial deviation.

    The spatial deviation characterizes the spread of the estimated source
    around the true source.

    .. math::

        SD = \dfrac{\sum_i|s_i|\|r_{i} - r_{true}\|^2}{\sum_i|s_i|},

    where :math:`r_{true}` is the true dipole position, and :math:`r_i`
    and :math:`|s_i|` denote respectively the position and amplitude of
    the i-th dipole in the source estimate.

    The threshold is applied to the estimated source to focus the metric
    on strong amplitudes and omit low-amplitude values.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    src : instance of SourceSpaces
        The source space on which the source estimates are defined.
    threshold : float | str
        The threshold to apply to source estimates before computing
        the metric. If a string, the threshold is
        a percentage and it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    These metrics are documented in :footcite:`StenroosHauk2013` and
    :footcite:`LinEtAl2006a`.

    .. versionadded:: 1.2

    References
    ----------
    .. footbibliography::
    """
    stc_est, r_true, r_est = _prepare_ppe_sd(stc_true, stc_est, src, threshold)
    func = partial(_spatial_deviation, r_est=r_est, r_true=r_true)
    metric = _apply(func, stc_true, stc_est, per_sample=per_sample)
    return metric
mne/simulation/raw.py (Normal file, 875 lines)
@@ -0,0 +1,875 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import os
from collections.abc import Iterable
from pathlib import Path

import numpy as np

from .._fiff.constants import FIFF
from .._fiff.meas_info import Info
from .._fiff.pick import pick_channels, pick_channels_forward, pick_info, pick_types
from .._ola import _Interp2
from ..bem import fit_sphere_to_headshape, make_sphere_model, read_bem_solution
from ..chpi import (
    _get_hpi_initial_fit,
    get_chpi_info,
    head_pos_to_trans_rot_t,
    read_head_pos,
)
from ..cov import Covariance, make_ad_hoc_cov, read_cov
from ..event import _get_stim_channel
from ..forward import (
    _compute_forwards,
    _magnetic_dipole_field_vec,
    _merge_fwds,
    _prep_meg_channels,
    _prepare_for_forward,
    _stc_src_sel,
    _to_forward_dict,
    _transform_orig_meg_coils,
    convert_forward_solution,
    restrict_forward_to_stc,
)
from ..io import BaseRaw, RawArray
from ..source_estimate import _BaseSourceEstimate
from ..source_space._source_space import (
    _ensure_src,
    _set_source_space_vertices,
    setup_volume_source_space,
)
from ..surface import _CheckInside
from ..transforms import _get_trans, transform_surface_to
from ..utils import (
    _check_preload,
    _pl,
    _validate_type,
    check_random_state,
    logger,
    verbose,
)
from .source import SourceSimulator


def _check_cov(info, cov):
    """Check that the user provided a valid covariance matrix for the noise."""
    _validate_type(cov, (Covariance, None, dict, str, "path-like"), "cov")
    if isinstance(cov, Covariance) or cov is None:
        pass
    elif isinstance(cov, dict):
        cov = make_ad_hoc_cov(info, cov, verbose=False)
    else:
        if cov == "simple":
            cov = make_ad_hoc_cov(info, None, verbose=False)
        else:
            cov = read_cov(cov, verbose=False)
    return cov


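The accepted ``cov`` forms, sketched (``info`` assumed loaded; the file name is hypothetical)::

    cov = _check_cov(info, "simple")          # ad hoc diagonal covariance
    cov = _check_cov(info, dict(mag=20e-15))  # ad hoc with custom stddevs
    cov = _check_cov(info, "sample-cov.fif")  # read from disk
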
def _check_stc_iterable(stc, info):
    # 1. Check that our STC is iterable (or convert it to one using cycle)
    # 2. Do first iter so we can get the vertex subselection
    # 3. Get the list of verts, which must stay the same across iterations
    if isinstance(stc, _BaseSourceEstimate):
        stc = [stc]
    _validate_type(stc, Iterable, "SourceEstimate, tuple, or iterable")
    stc_enum = enumerate(stc)
    del stc

    try:
        stc_counted = next(stc_enum)
    except StopIteration:
        raise RuntimeError("Iterable did not provide stc[0]")
    _, _, verts = _stc_data_event(stc_counted, 1, info["sfreq"])
    return stc_enum, stc_counted, verts


def _log_ch(start, info, ch):
    """Log channel information."""
    if ch is not None:
        extra, just, ch = " stored on channel:", 50, info["ch_names"][ch]
    else:
        extra, just, ch = " not stored", 0, ""
    logger.info((start + extra).ljust(just) + ch)


def _check_head_pos(head_pos, info, first_samp, times=None):
    if head_pos is None:  # use pos from info['dev_head_t']
        head_pos = dict()
    if isinstance(head_pos, str | Path | os.PathLike):
        head_pos = read_head_pos(head_pos)
    if isinstance(head_pos, np.ndarray):  # can be head_pos quats
        head_pos = head_pos_to_trans_rot_t(head_pos)
    if isinstance(head_pos, tuple):  # can be quats converted to trans, rot, t
        transs, rots, ts = head_pos
        first_time = first_samp / info["sfreq"]
        ts = ts - first_time  # MF files need reref
        dev_head_ts = [
            np.r_[np.c_[r, t[:, np.newaxis]], [[0, 0, 0, 1]]]
            for r, t in zip(rots, transs)
        ]
        del transs, rots
    elif isinstance(head_pos, dict):
        ts = np.array(list(head_pos.keys()), float)
        ts.sort()
        dev_head_ts = [head_pos[float(tt)] for tt in ts]
    else:
        raise TypeError(f"unknown head_pos type {type(head_pos)}")
    bad = ts < 0
    if bad.any():
        raise RuntimeError(
            f"All position times must be >= 0, found {bad.sum()}/{len(bad)} < 0"
        )
    if times is not None:
        bad = ts > times[-1]
        if bad.any():
            raise RuntimeError(
                f"All position times must be <= t_end ({times[-1]:0.1f} "
                f"s), found {bad.sum()}/{len(bad)} bad values (is this a split "
                "file?)"
            )
    # If it starts close to zero, make it zero (else unique(offset) fails)
    if len(ts) > 0 and ts[0] < (0.5 / info["sfreq"]):
        ts[0] = 0.0
    # If it doesn't start at zero, insert one at t=0
    elif len(ts) == 0 or ts[0] > 0:
        ts = np.r_[[0.0], ts]
        dev_head_ts.insert(0, info["dev_head_t"]["trans"])
    dev_head_ts = [
        {"trans": d, "to": info["dev_head_t"]["to"], "from": info["dev_head_t"]["from"]}
        for d in dev_head_ts
    ]
    offsets = np.round(ts * info["sfreq"]).astype(int)
    assert np.array_equal(offsets, np.unique(offsets))
    assert len(offsets) == len(dev_head_ts)
    offsets = list(offsets)
    return dev_head_ts, offsets


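The accepted ``head_pos`` forms, sketched (the file name is hypothetical)::

    dev_head_ts, offsets = _check_head_pos(None, info, 0)          # static position
    dev_head_ts, offsets = _check_head_pos("sub-01.pos", info, 0)  # MaxFilter-style file
    pos = {0.0: info["dev_head_t"]["trans"]}                       # time -> 4x4 transform
    dev_head_ts, offsets = _check_head_pos(pos, info, 0)
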
@verbose
def simulate_raw(
    info,
    stc=None,
    trans=None,
    src=None,
    bem=None,
    head_pos=None,
    mindist=1.0,
    interp="cos2",
    n_jobs=None,
    use_cps=True,
    forward=None,
    first_samp=0,
    max_iter=10000,
    verbose=None,
):
    """Simulate raw data.

    Head movements can optionally be simulated using the ``head_pos``
    parameter.

    Parameters
    ----------
    %(info_not_none)s Used for simulation.

        .. versionchanged:: 0.18
           Support for :class:`mne.Info`.
    stc : iterable | SourceEstimate | SourceSimulator
        The source estimates to use to simulate data. Each must have the same
        sample rate as the raw data, and the vertices of all stcs in the
        iterable must match. Each entry in the iterable can also be a tuple of
        ``(SourceEstimate, ndarray)`` to allow specifying the stim channel
        (e.g., STI001) data accompanying the source estimate.
        See Notes for details.

        .. versionchanged:: 0.18
           Support for tuple, iterable of tuple or `~mne.SourceEstimate`,
           or `~mne.simulation.SourceSimulator`.
    trans : dict | str | None
        Either a transformation filename (usually made using mne_analyze)
        or an info dict (usually opened using read_trans()).
        If string, an ending of ``.fif`` or ``.fif.gz`` will be assumed to
        be in FIF format, any other ending will be assumed to be a text
        file with a 4x4 transformation matrix (like the ``--trans`` MNE-C
        option). If trans is None, an identity transform will be used.
    src : path-like | instance of SourceSpaces | None
        Source space corresponding to the stc. If string, should be a source
        space filename. Can also be an instance of loaded or generated
        SourceSpaces. Can be None if ``forward`` is provided.
    bem : path-like | dict | None
        BEM solution corresponding to the stc. If string, should be a BEM
        solution filename (e.g., "sample-5120-5120-5120-bem-sol.fif").
        Can be None if ``forward`` is provided.
    %(head_pos)s
        See for example :footcite:`LarsonTaulu2017`.
    mindist : float
        Minimum distance between sources and the inner skull boundary
        to use during forward calculation.
    %(interp)s
    %(n_jobs)s
    %(use_cps)s
    forward : instance of Forward | None
        The forward operator to use. If None (default) it will be computed
        using ``bem``, ``trans``, and ``src``. If not None,
        ``bem``, ``trans``, and ``src`` are ignored.

        .. versionadded:: 0.17
    first_samp : int
        The first_samp property in the output Raw instance.

        .. versionadded:: 0.18
    max_iter : int
        The maximum number of STC iterations to allow.
        This is a sanity parameter to prevent accidental blowups.

        .. versionadded:: 0.18
    %(verbose)s

    Returns
    -------
    raw : instance of Raw
        The simulated raw file.

    See Also
    --------
    mne.chpi.read_head_pos
    add_chpi
    add_noise
    add_ecg
    add_eog
    simulate_evoked
    simulate_stc
    simulate_sparse_stc

    Notes
    -----
    **Stim channel encoding**

    By default, the stimulus channel will have the head position number
    (starting at 1) stored in the trigger channel (if available) at the
    t=0 point in each repetition of the ``stc``. If ``stc`` is a tuple of
    ``(SourceEstimate, ndarray)`` the array values will be placed in the
    stim channel aligned with the :class:`mne.SourceEstimate`.

    **Data simulation**

    In the most advanced case where ``stc`` is an iterable of tuples the output
    will be concatenated in time as:

    .. table:: Data alignment and stim channel encoding

       +---------+--------------------------+--------------------------+---------+
       | Channel | Data                                                          |
       +=========+==========================+==========================+=========+
       | M/EEG   | ``fwd @ stc[0][0].data`` | ``fwd @ stc[1][0].data`` | ``...`` |
       +---------+--------------------------+--------------------------+---------+
       | STIM    | ``stc[0][1]``            | ``stc[1][1]``            | ``...`` |
       +---------+--------------------------+--------------------------+---------+
       |         | *time →*                                                      |
       +---------+--------------------------+--------------------------+---------+

    .. versionadded:: 0.10.0

    References
    ----------
    .. footbibliography::
    """  # noqa: E501
    _validate_type(info, Info, "info")

    if len(pick_types(info, meg=False, stim=True)) == 0:
        event_ch = None
    else:
        event_ch = pick_channels(info["ch_names"], _get_stim_channel(None, info))[0]

    if forward is not None:
        if any(x is not None for x in (trans, src, bem, head_pos)):
            raise ValueError(
                "If forward is not None then trans, src, bem, "
                "and head_pos must all be None"
            )
        if not np.allclose(
            forward["info"]["dev_head_t"]["trans"],
            info["dev_head_t"]["trans"],
            atol=1e-6,
        ):
            raise ValueError(
                "The forward meg<->head transform "
                'forward["info"]["dev_head_t"] does not match '
                'the one in raw.info["dev_head_t"]'
            )
        src = forward["src"]

    dev_head_ts, offsets = _check_head_pos(head_pos, info, first_samp, None)

    src = _ensure_src(src, verbose=False)
    if isinstance(bem, str):
        bem = read_bem_solution(bem, verbose=False)

    # Extract necessary info
    meeg_picks = pick_types(info, meg=True, eeg=True, exclude=[])
    logger.info(
        f"Setting up raw simulation: {len(dev_head_ts)} "
        f'position{_pl(dev_head_ts)}, "{interp}" interpolation'
    )

    if isinstance(stc, SourceSimulator) and stc.first_samp != first_samp:
        logger.info("SourceSimulator first_samp does not match argument.")

    stc_enum, stc_counted, verts = _check_stc_iterable(stc, info)
    if forward is not None:
        forward = restrict_forward_to_stc(forward, verts)
        src = forward["src"]
    else:
        _stc_src_sel(src, verts, on_missing="warn", extra="")
        src = _set_source_space_vertices(src.copy(), verts)

    # array used to store result
    raw_datas = list()
    _log_ch("Event information", info, event_ch)
    # don't process these any more if no MEG present
    n = 1
    get_fwd = _SimForwards(
        dev_head_ts,
        offsets,
        info,
        trans,
        src,
        bem,
        mindist,
        n_jobs,
        meeg_picks,
        forward,
        use_cps,
    )
    interper = _Interp2(offsets, get_fwd, interp)

    this_start = 0
    for n in range(max_iter):
        if isinstance(stc_counted[1], list | tuple):
            this_n = stc_counted[1][0].data.shape[1]
        else:
            this_n = stc_counted[1].data.shape[1]
        this_stop = this_start + this_n
        logger.info(
            f"  Interval {this_start / info['sfreq']:0.3f}–"
            f"{this_stop / info['sfreq']:0.3f} s"
        )
        n_doing = this_stop - this_start
        assert n_doing > 0
        this_data = np.zeros((len(info["ch_names"]), n_doing))
        raw_datas.append(this_data)
        # Stim channel
        fwd, fi = interper.feed(this_stop - this_start)
        fi = fi[0]
        stc_data, stim_data, _ = _stc_data_event(
            stc_counted, fi, info["sfreq"], get_fwd.src, None if n == 0 else verts
        )
        if event_ch is not None:
            this_data[event_ch, :] = stim_data[:n_doing]
        this_data[meeg_picks] = np.einsum("svt,vt->st", fwd, stc_data)
        this_start += this_n
        try:
            stc_counted = next(stc_enum)
        except StopIteration:
            logger.info(f"  {n + 1} STC iteration{_pl(n + 1)} provided")
            break
        del fwd
    else:
        raise RuntimeError(f"Maximum number of STC iterations ({n}) exceeded")
    raw_data = np.concatenate(raw_datas, axis=-1)
    raw = RawArray(raw_data, info, first_samp=first_samp, verbose=False)
    raw.set_annotations(raw.annotations)
    logger.info("[done]")
    return raw


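An end-to-end sketch (``info``, ``stc``, ``trans``, ``src``, ``bem``, and ``cov`` assumed to come from the sample dataset)::

    raw_sim = simulate_raw(info, [stc] * 3, trans=trans, src=src, bem=bem)
    raw_sim = add_noise(raw_sim, cov)  # then optionally add_ecg/add_eog/add_chpi
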
@verbose
def add_eog(
    raw, head_pos=None, interp="cos2", n_jobs=None, random_state=None, verbose=None
):
    """Add blink noise to raw data.

    Parameters
    ----------
    raw : instance of Raw
        The raw instance to modify.
    %(head_pos)s
    %(interp)s
    %(n_jobs)s
    %(random_state)s
        The random generator state used for blink, ECG, and sensor noise
        randomization.
    %(verbose)s

    Returns
    -------
    raw : instance of Raw
        The instance, modified in place.

    See Also
    --------
    add_chpi
    add_ecg
    add_noise
    simulate_raw

    Notes
    -----
    The blink artifacts are generated by:

    1. Random activation times are drawn from an inhomogeneous Poisson
       process whose blink rate oscillates between 4.5 blinks/minute
       and 17 blinks/minute based on the low (reading) and high (resting)
       blink rates from :footcite:`BentivoglioEtAl1997`.
    2. The activation kernel is a 250 ms Hanning window.
    3. Two activated dipoles are located in the z=0 plane (in head
       coordinates) at ±30 degrees away from the y axis (nasion).
    4. Activations affect MEG and EEG channels.

    The scale-factor of the activation function was chosen based on
    visual inspection to yield amplitudes generally consistent with those
    seen in experimental data. Noisy versions of the activation will be
    stored in the first EOG channel in the raw instance, if it exists.

    References
    ----------
    .. footbibliography::
    """
    return _add_exg(raw, "blink", head_pos, interp, n_jobs, random_state)


@verbose
def add_ecg(
    raw, head_pos=None, interp="cos2", n_jobs=None, random_state=None, verbose=None
):
    """Add ECG noise to raw data.

    Parameters
    ----------
    raw : instance of Raw
        The raw instance to modify.
    %(head_pos)s
    %(interp)s
    %(n_jobs)s
    %(random_state)s
        The random generator state used for blink, ECG, and sensor noise
        randomization.
    %(verbose)s

    Returns
    -------
    raw : instance of Raw
        The instance, modified in place.

    See Also
    --------
    add_chpi
    add_eog
    add_noise
    simulate_raw

    Notes
    -----
    The ECG artifacts are generated by:

    1. Random inter-beat intervals are drawn from a uniform distribution
       of times corresponding to 40 and 80 beats per minute.
    2. The activation function is the sum of three Hanning windows with
       varying durations and scales to make a more complex waveform.
    3. The activated dipole is located one (estimated) head radius to
       the left (-x) of head center and three head radii below (+z)
       head center; this dipole is oriented in the +x direction.
    4. Activations only affect MEG channels.

    The scale-factor of the activation function was chosen based on
    visual inspection to yield amplitudes generally consistent with those
    seen in experimental data. Noisy versions of the activation will be
    stored in the first ECG channel in the raw instance, if it exists.

    .. versionadded:: 0.18
    """
    return _add_exg(raw, "ecg", head_pos, interp, n_jobs, random_state)


def _add_exg(raw, kind, head_pos, interp, n_jobs, random_state):
    assert isinstance(kind, str) and kind in ("ecg", "blink")
    _validate_type(raw, BaseRaw, "raw")
    _check_preload(raw, f"Adding {kind} noise ")
    rng = check_random_state(random_state)
    info, times, first_samp = raw.info, raw.times, raw.first_samp
    data = raw._data
    meg_picks = pick_types(info, meg=True, eeg=False, exclude=())
    meeg_picks = pick_types(info, meg=True, eeg=True, exclude=())
    R, r0 = fit_sphere_to_headshape(info, units="m", verbose=False)[:2]
    bem = make_sphere_model(
        r0,
        head_radius=R,
        relative_radii=(0.97, 0.98, 0.99, 1.0),
        sigmas=(0.33, 1.0, 0.004, 0.33),
        verbose=False,
    )
    trans = None
    dev_head_ts, offsets = _check_head_pos(head_pos, info, first_samp, times)
    if kind == "blink":
        # place dipoles in the z=0 plane, ±30° from the y axis (nasion)
        exg_rr = np.array(
            [
                [np.cos(np.pi / 3.0), np.sin(np.pi / 3.0), 0.0],
                [-np.cos(np.pi / 3.0), np.sin(np.pi / 3), 0.0],
            ]
        )
        exg_rr /= np.sqrt(np.sum(exg_rr * exg_rr, axis=1, keepdims=True))
        exg_rr *= 0.96 * R
        exg_rr += r0
        # oriented upward
        nn = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
        # Blink times drawn from an inhomogeneous Poisson process
        # by 1) creating the rate and 2) pulling random numbers
        blink_rate = (1 + np.cos(2 * np.pi * 1.0 / 60.0 * times)) / 2.0
        blink_rate *= 12.5 / 60.0
        blink_rate += 4.5 / 60.0
        blink_data = rng.uniform(size=len(times)) < blink_rate / info["sfreq"]
        blink_data = blink_data * (rng.uniform(size=len(times)) + 0.5)  # amps
        # Activation kernel is a simple Hanning window
        blink_kernel = np.hanning(int(0.25 * info["sfreq"]))
        exg_data = np.convolve(blink_data, blink_kernel, "same")[np.newaxis, :] * 1e-7
        # Add rescaled noisy data to EOG ch
        ch = pick_types(info, meg=False, eeg=False, eog=True)
        picks = meeg_picks
        del blink_kernel, blink_rate, blink_data
    else:
        if len(meg_picks) == 0:
            raise RuntimeError(
                "Can only add ECG artifacts if MEG data channels are present"
            )
        exg_rr = np.array([[-R, 0, -3 * R]])
        max_beats = int(np.ceil(times[-1] * 80.0 / 60.0))
        # activation times with intervals drawn from a uniform distribution
        # based on activation rates between 40 and 80 beats per minute
        cardiac_idx = np.cumsum(
            rng.uniform(60.0 / 80.0, 60.0 / 40.0, max_beats) * info["sfreq"]
        ).astype(int)
        cardiac_idx = cardiac_idx[cardiac_idx < len(times)]
        cardiac_data = np.zeros(len(times))
        cardiac_data[cardiac_idx] = 1
        # kernel is the sum of three Hanning windows
        cardiac_kernel = np.concatenate(
            [
                2 * np.hanning(int(0.04 * info["sfreq"])),
                -0.3 * np.hanning(int(0.05 * info["sfreq"])),
                0.2 * np.hanning(int(0.26 * info["sfreq"])),
            ],
            axis=-1,
        )
        exg_data = (
            np.convolve(cardiac_data, cardiac_kernel, "same")[np.newaxis, :] * 15e-8
        )
        # Add rescaled noisy data to ECG ch
        ch = pick_types(info, meg=False, eeg=False, ecg=True)
        picks = meg_picks
        del cardiac_data, cardiac_kernel, max_beats, cardiac_idx
        nn = np.zeros_like(exg_rr)
        nn[:, 0] = 1  # arbitrarily rightward
    del meg_picks, meeg_picks
    noise = rng.standard_normal(exg_data.shape[1]) * 5e-6
    if len(ch) >= 1:
        ch = ch[-1]
        data[ch, :] = exg_data * 1e3 + noise
    else:
        ch = None
    src = setup_volume_source_space(pos=dict(rr=exg_rr, nn=nn), sphere_units="mm")
    _log_ch(f"{kind} simulated and trace", info, ch)
    del ch, nn, noise

    used = np.zeros(len(raw.times), bool)
    get_fwd = _SimForwards(
        dev_head_ts, offsets, info, trans, src, bem, 0.005, n_jobs, picks
    )
    interper = _Interp2(offsets, get_fwd, interp)
    proc_lims = np.concatenate([np.arange(0, len(used), 10000), [len(used)]])
    for start, stop in zip(proc_lims[:-1], proc_lims[1:]):
        fwd, _ = interper.feed(stop - start)
        data[picks, start:stop] += np.einsum("svt,vt->st", fwd, exg_data[:, start:stop])
        assert not used[start:stop].any()
        used[start:stop] = True
    assert used.all()
    return raw


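The blink-generation steps in isolation (toy sampling rate)::

    import numpy as np

    sfreq = 100.0
    times = np.arange(60 * int(sfreq)) / sfreq
    rate = (1 + np.cos(2 * np.pi / 60.0 * times)) / 2.0   # slow 1/min oscillation
    rate = rate * 12.5 / 60.0 + 4.5 / 60.0                # 4.5..17 blinks/min, in Hz
    rng = np.random.default_rng(0)
    blinks = rng.uniform(size=times.size) < rate / sfreq  # Bernoulli draw per sample
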
@verbose
def add_chpi(raw, head_pos=None, interp="cos2", n_jobs=None, verbose=None):
    """Add cHPI activations to raw data.

    Parameters
    ----------
    raw : instance of Raw
        The raw instance to be modified.
    %(head_pos)s
    %(interp)s
    %(n_jobs)s
    %(verbose)s

    Returns
    -------
    raw : instance of Raw
        The instance, modified in place.

    Notes
    -----
    .. versionadded:: 0.18
    """
    _validate_type(raw, BaseRaw, "raw")
    _check_preload(raw, "Adding cHPI signals ")
    info, first_samp, times = raw.info, raw.first_samp, raw.times
    meg_picks = pick_types(info, meg=True, eeg=False, exclude=[])  # for CHPI
    if len(meg_picks) == 0:
        raise RuntimeError("Cannot add cHPI if no MEG picks are present")
    dev_head_ts, offsets = _check_head_pos(head_pos, info, first_samp, times)
    hpi_freqs, hpi_pick, hpi_ons = get_chpi_info(info, on_missing="raise")
    hpi_rrs = _get_hpi_initial_fit(info, verbose="error")
    hpi_nns = hpi_rrs / np.sqrt(np.sum(hpi_rrs * hpi_rrs, axis=1))[:, np.newaxis]
    # turn on cHPI in file
    data = raw._data
    data[hpi_pick, :] = hpi_ons.sum()
    _log_ch("cHPI status bits enabled and", info, hpi_pick)
    sinusoids = 70e-9 * np.sin(
        2 * np.pi * hpi_freqs[:, np.newaxis] * (np.arange(len(times)) / info["sfreq"])
    )
    info = pick_info(info, meg_picks)
    with info._unlock():
        info.update(projs=[], bads=[])  # Ensure no 'projs' or 'bads'
    megcoils = _prep_meg_channels(info, ignore_ref=True)["defs"]
    used = np.zeros(len(raw.times), bool)
    dev_head_ts.append(dev_head_ts[-1])  # ZOH after time ends
    get_fwd = _HPIForwards(offsets, dev_head_ts, megcoils, hpi_rrs, hpi_nns)
    interper = _Interp2(offsets, get_fwd, interp)
    lims = np.concatenate([offsets, [len(raw.times)]])
    for start, stop in zip(lims[:-1], lims[1:]):
        (fwd,) = interper.feed(stop - start)
        data[meg_picks, start:stop] += np.einsum(
            "svt,vt->st", fwd, sinusoids[:, start:stop]
        )
        assert not used[start:stop].any()
        used[start:stop] = True
    assert used.all()
    return raw


class _HPIForwards:
    def __init__(self, offsets, dev_head_ts, megcoils, hpi_rrs, hpi_nns):
        self.offsets = offsets
        self.dev_head_ts = dev_head_ts
        self.hpi_rrs = hpi_rrs
        self.hpi_nns = hpi_nns
        self.megcoils = megcoils
        self.idx = 0

    def __call__(self, offset):
        assert offset == self.offsets[self.idx]
        _transform_orig_meg_coils(self.megcoils, self.dev_head_ts[self.idx])
        fwd = _magnetic_dipole_field_vec(self.hpi_rrs, self.megcoils).T
        # align cHPI magnetic dipoles in approx. radial direction
        fwd = np.array(
            [
                np.dot(fwd[:, 3 * ii : 3 * (ii + 1)], self.hpi_nns[ii])
                for ii in range(len(self.hpi_rrs))
            ]
        ).T
        self.idx += 1
        return (fwd,)


def _stc_data_event(stc_counted, head_idx, sfreq, src=None, verts=None):
    stc_idx, stc = stc_counted
    if isinstance(stc, list | tuple):
        if len(stc) != 2:
            raise ValueError(f"stc, if tuple, must be length 2, got {len(stc)}")
        stc, stim_data = stc
    else:
        stim_data = None
    _validate_type(
        stc,
        _BaseSourceEstimate,
        "stc",
        "SourceEstimate or tuple with first entry SourceEstimate",
    )
    # Convert event data
    if stim_data is None:
        stim_data = np.zeros(len(stc.times), int)
        stim_data[np.argmin(np.abs(stc.times))] = head_idx
    del head_idx
    _validate_type(stim_data, np.ndarray, "stim_data")
    if stim_data.dtype.kind != "i":
        raise ValueError(
            "stim_data in a stc tuple must be an integer ndarray,"
            f" got dtype {stim_data.dtype}"
        )
    if stim_data.shape != (len(stc.times),):
        raise ValueError(
            f"event data had shape {stim_data.shape} but needed to "
            f"be ({len(stc.times)},) to match stc"
        )
    # Validate STC
    if not np.allclose(sfreq, 1.0 / stc.tstep):
        raise ValueError(
            f"stc and info must have same sample rate, "
            f"got {1.0 / stc.tstep} and {sfreq}"
        )
    if len(stc.times) <= 2:  # to ensure event encoding works
        raise ValueError(
            f"stc must have at least three time points, got {len(stc.times)}"
        )
    verts_ = stc.vertices
    if verts is None:
        assert stc_idx == 0
    else:
        if len(verts) != len(verts_) or not all(
            np.array_equal(a, b) for a, b in zip(verts, verts_)
        ):
            raise RuntimeError(
                f"Vertex mismatch for stc[{stc_idx}], all stc.vertices must match"
            )
    stc_data = stc.data
    if src is None:
        assert stc_idx == 0
    else:
        # on_missing depends on whether or not this is the first iteration
        on_missing = "warn" if verts is None else "ignore"
        _, stc_sel, _ = _stc_src_sel(src, stc, on_missing=on_missing)
        stc_data = stc_data[stc_sel]
    return stc_data, stim_data, verts_


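The default event encoding, in isolation (toy stc times)::

    import numpy as np

    stc_times = np.array([-0.1, 0.0, 0.1])
    stim_data = np.zeros(len(stc_times), int)
    stim_data[np.argmin(np.abs(stc_times))] = 2  # head-position index lands at t=0
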
class _SimForwards:
    def __init__(
        self,
        dev_head_ts,
        offsets,
        info,
        trans,
        src,
        bem,
        mindist,
        n_jobs,
        meeg_picks,
        forward=None,
        use_cps=True,
    ):
        self.idx = 0
        self.offsets = offsets
        self.use_cps = use_cps
        self.iter = iter(
            _iter_forward_solutions(
                info, trans, src, bem, dev_head_ts, mindist, n_jobs, forward, meeg_picks
            )
        )

    def __call__(self, offset):
        assert self.offsets[self.idx] == offset
        self.idx += 1
        fwd = next(self.iter)
        self.src = fwd["src"]
        # XXX eventually we could speed this up by allowing the forward
        # solution code to only compute the normal direction
        convert_forward_solution(
            fwd,
            surf_ori=True,
            force_fixed=True,
            use_cps=self.use_cps,
            copy=False,
            verbose=False,
        )
        return fwd["sol"]["data"], np.array(self.idx, float)


def _iter_forward_solutions(
|
||||
info, trans, src, bem, dev_head_ts, mindist, n_jobs, forward, picks
|
||||
):
|
||||
"""Calculate a forward solution for a subject."""
|
||||
logger.info("Setting up forward solutions")
|
||||
info = pick_info(info, picks)
|
||||
with info._unlock():
|
||||
info.update(projs=[], bads=[]) # Ensure no 'projs' or 'bads'
|
||||
mri_head_t, trans = _get_trans(trans)
|
||||
sensors, rr, info, update_kwargs, bem = _prepare_for_forward(
|
||||
src, mri_head_t, info, bem, mindist, n_jobs, allow_bem_none=True, verbose=False
|
||||
)
|
||||
del (src, mindist)
|
||||
|
||||
eegnames = sensors.get("eeg", dict()).get("ch_names", [])
|
||||
if not len(eegnames):
|
||||
eegfwd = None
|
||||
elif forward is not None:
|
||||
eegfwd = pick_channels_forward(forward, eegnames, verbose=False)
|
||||
else:
|
||||
eegels = sensors.get("eeg", dict()).get("defs", [])
|
||||
this_sensors = dict(eeg=dict(ch_names=eegnames, defs=eegels))
|
||||
eegfwd = _compute_forwards(
|
||||
rr, bem=bem, sensors=this_sensors, n_jobs=n_jobs, verbose=False
|
||||
)["eeg"]
|
||||
eegfwd = _to_forward_dict(eegfwd, eegnames)
|
||||
del eegels
|
||||
del eegnames
|
||||
|
||||
# short circuit here if there are no MEG channels (don't need to iterate)
|
||||
if "meg" not in sensors:
|
||||
eegfwd.update(**update_kwargs)
|
||||
for _ in dev_head_ts:
|
||||
yield eegfwd
|
||||
yield eegfwd
|
||||
return
|
||||
|
||||
coord_frame = FIFF.FIFFV_COORD_HEAD
|
||||
if bem is not None and not bem["is_sphere"]:
|
||||
idx = np.where(
|
||||
np.array([s["id"] for s in bem["surfs"]]) == FIFF.FIFFV_BEM_SURF_ID_BRAIN
|
||||
)[0]
|
||||
assert len(idx) == 1
|
||||
# make a copy so it isn't mangled in use
|
||||
bem_surf = transform_surface_to(
|
||||
bem["surfs"][idx[0]], coord_frame, mri_head_t, copy=True
|
||||
)
|
||||
megcoils = sensors["meg"]["defs"]
|
||||
if "eeg" in sensors:
|
||||
del sensors["eeg"]
|
||||
megnames = sensors["meg"]["ch_names"]
|
||||
fwds = dict()
|
||||
if eegfwd is not None:
|
||||
fwds["eeg"] = eegfwd
|
||||
del eegfwd
|
||||
for ti, dev_head_t in enumerate(dev_head_ts):
|
||||
# Could be *slightly* more efficient not to do this N times,
|
||||
# but the cost here is tiny compared to actual fwd calculation
|
||||
logger.info(f"Computing gain matrix for transform #{ti + 1}/{len(dev_head_ts)}")
|
||||
_transform_orig_meg_coils(megcoils, dev_head_t)
|
||||
|
||||
# Make sure our sensors are all outside our BEM
|
||||
coil_rr = np.array([coil["r0"] for coil in megcoils])
|
||||
|
||||
# Compute forward
|
||||
if forward is None:
|
||||
if not bem["is_sphere"]:
|
||||
outside = ~_CheckInside(bem_surf)(coil_rr, n_jobs, verbose=False)
|
||||
elif bem.radius is not None:
|
||||
d = coil_rr - bem["r0"]
|
||||
outside = np.sqrt(np.sum(d * d, axis=1)) > bem.radius
|
||||
else: # only r0 provided
|
||||
outside = np.ones(len(coil_rr), bool)
|
||||
if not outside.all():
|
||||
raise RuntimeError(
|
||||
f"{np.sum(~outside)} MEG sensors collided with inner skull "
|
||||
f"surface for transform {ti}"
|
||||
)
|
||||
megfwd = _compute_forwards(
|
||||
rr, sensors=sensors, bem=bem, n_jobs=n_jobs, verbose=False
|
||||
)["meg"]
|
||||
megfwd = _to_forward_dict(megfwd, megnames)
|
||||
else:
|
||||
megfwd = pick_channels_forward(forward, megnames, verbose=False)
|
||||
fwds["meg"] = megfwd
|
||||
fwd = _merge_fwds(fwds, verbose=False)
|
||||
fwd.update(**update_kwargs)
|
||||
|
||||
yield fwd
|
||||
# need an extra one to fill last buffer
|
||||
yield fwd
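
# Note the generator above yields len(dev_head_ts) + 1 solutions (the extra
# one fills the last buffer). A hedged consumption sketch, names illustrative:
#
#     fwd_gen = _iter_forward_solutions(info, trans, src, bem, dev_head_ts,
#                                       mindist, n_jobs, None, picks)
#     for fwd in fwd_gen:
#         gain = fwd["sol"]["data"]  # (n_channels, n_sources)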
589
mne/simulation/source.py
Normal file
589
mne/simulation/source.py
Normal file
@@ -0,0 +1,589 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np

from ..fixes import rng_uniform
from ..label import Label
from ..source_estimate import SourceEstimate, VolSourceEstimate
from ..source_space._source_space import _ensure_src
from ..surface import _compute_nearest
from ..utils import (
    _check_option,
    _ensure_events,
    _ensure_int,
    _validate_type,
    check_random_state,
    fill_doc,
    warn,
)

@fill_doc
def select_source_in_label(
    src,
    label,
    random_state=None,
    location="random",
    subject=None,
    subjects_dir=None,
    surf="sphere",
):
    """Select source positions using a label.

    Parameters
    ----------
    src : list of dict
        The source space.
    label : Label
        The label.
    %(random_state)s
    location : str
        The label location to choose. Can be 'random' (default) or 'center'
        to use :func:`mne.Label.center_of_mass` (restricting to vertices
        both in the label and in the source space). Note that for 'center'
        mode the label values are used as weights.

        .. versionadded:: 0.13
    subject : str | None
        The subject the label is defined for.
        Only used with ``location='center'``.

        .. versionadded:: 0.13
    %(subjects_dir)s

        .. versionadded:: 0.13
    surf : str
        The surface to use for Euclidean distance center of mass
        finding. The default here is "sphere", which finds the center
        of mass on the spherical surface to help avoid potential issues
        with cortical folding.

        .. versionadded:: 0.13

    Returns
    -------
    lh_vertno : list
        Selected source vertices on the left hemisphere.
    rh_vertno : list
        Selected source vertices on the right hemisphere.
    """
    lh_vertno = list()
    rh_vertno = list()
    _check_option("location", location, ["random", "center"])

    rng = check_random_state(random_state)
    if label.hemi == "lh":
        vertno = lh_vertno
        hemi_idx = 0
    else:
        vertno = rh_vertno
        hemi_idx = 1
    src_sel = np.intersect1d(src[hemi_idx]["vertno"], label.vertices)
    if location == "random":
        idx = src_sel[rng_uniform(rng)(0, len(src_sel), 1)[0]]
    else:  # 'center'
        idx = label.center_of_mass(
            subject, restrict_vertices=src_sel, subjects_dir=subjects_dir, surf=surf
        )
    vertno.append(idx)
    return lh_vertno, rh_vertno
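
# A minimal usage sketch (illustrative, not authoritative): pick one vertex
# at random inside an anatomical label. The file paths are placeholder
# assumptions, not files shipped with this module:
#
#     import mne
#     src = mne.read_source_spaces("sample-oct6-src.fif")
#     label = mne.read_label("Aud-lh.label")
#     lh, rh = select_source_in_label(src, label, random_state=0)
#     # one of lh/rh holds a single vertex number, the other stays empty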


@fill_doc
def simulate_sparse_stc(
    src,
    n_dipoles,
    times,
    data_fun=lambda t: 1e-7 * np.sin(20 * np.pi * t),
    labels=None,
    random_state=None,
    location="random",
    subject=None,
    subjects_dir=None,
    surf="sphere",
):
    """Generate sparse (n_dipoles) sources time courses from data_fun.

    This function randomly selects ``n_dipoles`` vertices in the whole
    cortex or one single vertex (randomly in or in the center of) each
    label if ``labels is not None``. It uses ``data_fun`` to generate
    waveforms for each vertex.

    Parameters
    ----------
    src : instance of SourceSpaces
        The source space.
    n_dipoles : int
        Number of dipoles to simulate.
    times : array
        Time array.
    data_fun : callable
        Function to generate the waveforms. The default is a 100 nAm, 10 Hz
        sinusoid as ``1e-7 * np.sin(20 * pi * t)``. The function should take
        as input the array of time samples in seconds and return an array of
        the same length containing the time courses.
    labels : None | list of Label
        The labels. The default is None, otherwise its size must be n_dipoles.
    %(random_state)s
    location : str
        The label location to choose. Can be ``'random'`` (default) or
        ``'center'`` to use :func:`mne.Label.center_of_mass`. Note that for
        ``'center'`` mode the label values are used as weights.

        .. versionadded:: 0.13
    subject : str | None
        The subject the label is defined for.
        Only used with ``location='center'``.

        .. versionadded:: 0.13
    %(subjects_dir)s

        .. versionadded:: 0.13
    surf : str
        The surface to use for Euclidean distance center of mass
        finding. The default here is "sphere", which finds the center
        of mass on the spherical surface to help avoid potential issues
        with cortical folding.

        .. versionadded:: 0.13

    Returns
    -------
    stc : SourceEstimate
        The generated source time courses.

    See Also
    --------
    simulate_raw
    simulate_evoked
    simulate_stc

    Notes
    -----
    .. versionadded:: 0.10.0
    """
    rng = check_random_state(random_state)
    src = _ensure_src(src, verbose=False)
    subject_src = src._subject
    if subject is None:
        subject = subject_src
    elif subject_src is not None and subject != subject_src:
        raise ValueError(
            f"subject argument ({subject}) did not match the source "
            f"space subject_his_id ({subject_src})"
        )
    data = np.zeros((n_dipoles, len(times)))
    for i_dip in range(n_dipoles):
        data[i_dip, :] = data_fun(times)

    if labels is None:
        # can be vol or surface source space
        offsets = np.linspace(0, n_dipoles, len(src) + 1).astype(int)
        n_dipoles_ss = np.diff(offsets)
        # don't use .choice b/c not on old numpy
        vs = [
            s["vertno"][np.sort(rng.permutation(np.arange(s["nuse"]))[:n])]
            for n, s in zip(n_dipoles_ss, src)
        ]
        datas = data
    elif n_dipoles > len(labels):
        raise ValueError(
            f"Number of labels ({len(labels)}) smaller than n_dipoles ({n_dipoles:d}) "
            "is not allowed."
        )
    else:
        if n_dipoles != len(labels):
            warn(
                "The number of labels is different from the number of "
                f"dipoles. {min(n_dipoles, len(labels))} dipole(s) will be generated."
            )
        labels = labels[:n_dipoles] if n_dipoles < len(labels) else labels

        vertno = [[], []]
        lh_data = [np.empty((0, data.shape[1]))]
        rh_data = [np.empty((0, data.shape[1]))]
        for i, label in enumerate(labels):
            lh_vertno, rh_vertno = select_source_in_label(
                src, label, rng, location, subject, subjects_dir, surf
            )
            vertno[0] += lh_vertno
            vertno[1] += rh_vertno
            if len(lh_vertno) != 0:
                lh_data.append(data[i][np.newaxis])
            elif len(rh_vertno) != 0:
                rh_data.append(data[i][np.newaxis])
            else:
                raise ValueError("No vertno found.")
        vs = [np.array(v) for v in vertno]
        datas = [np.concatenate(d) for d in [lh_data, rh_data]]
        # need to sort each hemi by vertex number
        for ii in range(2):
            order = np.argsort(vs[ii])
            vs[ii] = vs[ii][order]
            if len(order) > 0:  # fix for old numpy
                datas[ii] = datas[ii][order]
        datas = np.concatenate(datas)

    tmin, tstep = times[0], np.diff(times[:2])[0]
    assert datas.shape == data.shape
    cls = SourceEstimate if len(vs) == 2 else VolSourceEstimate
    stc = cls(datas, vertices=vs, tmin=tmin, tstep=tstep, subject=subject)
    return stc
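
# A minimal sketch of simulate_sparse_stc (illustrative only; the source-space
# path is an assumption):
#
#     import numpy as np
#     import mne
#     src = mne.read_source_spaces("sample-oct6-src.fif")
#     times = np.arange(1000) / 1000.0  # 1 s at 1 kHz
#     stc = simulate_sparse_stc(src, n_dipoles=3, times=times, random_state=0)
#     # stc.data has shape (3, 1000): one 10 Hz, 100 nAm sinusoid per dipole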


def simulate_stc(
    src, labels, stc_data, tmin, tstep, value_fun=None, allow_overlap=False
):
    """Simulate sources time courses from waveforms and labels.

    This function generates a source estimate with extended sources by
    filling the labels with the waveforms given in stc_data.

    Parameters
    ----------
    src : instance of SourceSpaces
        The source space.
    labels : list of Label
        The labels.
    stc_data : array, shape (n_labels, n_times)
        The waveforms.
    tmin : float
        The beginning of the timeseries.
    tstep : float
        The time step (1 / sampling frequency).
    value_fun : callable | None
        Function to apply to the label values to obtain the waveform
        scaling for each vertex in the label. If None (default), uniform
        scaling is used.
    allow_overlap : bool
        Allow overlapping labels or not. Default value is False.

        .. versionadded:: 0.18

    Returns
    -------
    stc : SourceEstimate
        The generated source time courses.

    See Also
    --------
    simulate_raw
    simulate_evoked
    simulate_sparse_stc
    """
    if len(labels) != len(stc_data):
        raise ValueError("labels and stc_data must have the same length")

    vertno = [[], []]
    stc_data_extended = [[], []]
    hemi_to_ind = {"lh": 0, "rh": 1}
    for i, label in enumerate(labels):
        hemi_ind = hemi_to_ind[label.hemi]
        src_sel = np.intersect1d(src[hemi_ind]["vertno"], label.vertices)
        if len(src_sel) == 0:
            idx = src[hemi_ind]["inuse"].astype("bool")
            xhs = src[hemi_ind]["rr"][idx]
            rr = src[hemi_ind]["rr"][label.vertices]
            closest_src = _compute_nearest(xhs, rr)
            src_sel = src[hemi_ind]["vertno"][np.unique(closest_src)]

        if value_fun is not None:
            idx_sel = np.searchsorted(label.vertices, src_sel)
            values_sel = np.array([value_fun(v) for v in label.values[idx_sel]])

            data = np.outer(values_sel, stc_data[i])
        else:
            data = np.tile(stc_data[i], (len(src_sel), 1))
        # If overlaps are allowed, deal with them
        if allow_overlap:
            # Search for duplicate vertex indices
            # in the existing vertex list
            duplicates = []
            for src_ind, vertex_ind in enumerate(src_sel):
                ind = np.where(vertex_ind == vertno[hemi_ind])[0]
                if len(ind) > 0:
                    assert len(ind) == 1
                    # Add the new data to the existing one
                    stc_data_extended[hemi_ind][ind[0]] += data[src_ind]
                    duplicates.append(src_ind)
            # Remove the duplicates from both data and selected vertices
            data = np.delete(data, duplicates, axis=0)
            src_sel = list(np.delete(np.array(src_sel), duplicates))
        # Extend the existing list instead of appending it so that we can
        # index its elements
        vertno[hemi_ind].extend(src_sel)
        stc_data_extended[hemi_ind].extend(np.atleast_2d(data))

    vertno = [np.array(v) for v in vertno]
    if not allow_overlap:
        for v, hemi in zip(vertno, ("left", "right")):
            d = len(v) - len(np.unique(v))
            if d > 0:
                raise RuntimeError(
                    f"Labels had {d} overlaps in the {hemi} "
                    "hemisphere, they must be non-overlapping"
                )
    # the data is in the order left, right
    data = list()
    for i in range(2):
        if len(stc_data_extended[i]) != 0:
            stc_data_extended[i] = np.vstack(stc_data_extended[i])
            # Order the indices of each hemisphere
            idx = np.argsort(vertno[i])
            data.append(stc_data_extended[i][idx])
            vertno[i] = vertno[i][idx]

    stc = SourceEstimate(
        np.concatenate(data),
        vertices=vertno,
        tmin=tmin,
        tstep=tstep,
        subject=src._subject,
    )
    return stc
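
# A minimal sketch of simulate_stc (illustrative; the label names and 1 kHz
# sampling are assumptions):
#
#     import numpy as np
#     import mne
#     src = mne.read_source_spaces("sample-oct6-src.fif")
#     labels = [mne.read_label("Aud-lh.label"), mne.read_label("Aud-rh.label")]
#     stc_data = 1e-9 * np.ones((len(labels), 500))  # 500-sample step per label
#     stc = simulate_stc(src, labels, stc_data, tmin=0.0, tstep=1e-3)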


class SourceSimulator:
    """Class to generate simulated Source Estimates.

    Parameters
    ----------
    src : instance of SourceSpaces
        Source space.
    tstep : float
        Time step between successive samples in data. Default is 0.001 s.
    duration : float | None
        Time interval during which the simulation takes place in seconds.
        If None, it is computed using existing events and waveform lengths.
    first_samp : int
        First sample from which the simulation takes place, as an integer.
        Comparable to the :term:`first_samp` property of `~mne.io.Raw` objects.
        Default is 0.

    Attributes
    ----------
    duration : float
        The duration of the simulation in seconds.
    n_times : int
        The number of time samples of the simulation.
    """

    def __init__(self, src, tstep=1e-3, duration=None, first_samp=0):
        if duration is not None and duration < tstep:
            raise ValueError("duration must be None or >= tstep.")
        self.first_samp = _ensure_int(first_samp, "first_samp")
        self._src = src
        self._tstep = tstep
        self._labels = []
        self._waveforms = []
        self._events = np.empty((0, 3), dtype=int)
        self._duration = duration  # if not None, sets # samples
        self._last_samples = []
        self._chk_duration = 1000

    @property
    def duration(self):
        """Duration of the simulation in same units as tstep."""
        if self._duration is not None:
            return self._duration
        return self.n_times * self._tstep

    @property
    def n_times(self):
        """Number of time samples in the simulation."""
        if self._duration is not None:
            return int(self._duration / self._tstep)
        ls = self.first_samp
        if len(self._last_samples) > 0:
            ls = np.max(self._last_samples)
        return ls - self.first_samp + 1  # >= 1

    @property
    def last_samp(self):
        return self.first_samp + self.n_times - 1

    def add_data(self, label, waveform, events):
        """Add data to the simulation.

        Data should be added in the form of a triplet of
        Label (Where) - Waveform(s) (What) - Event(s) (When)

        Parameters
        ----------
        label : instance of Label
            The label (as created for example by mne.read_label). If the label
            does not match any sources in the SourceEstimate, a ValueError is
            raised.
        waveform : array, shape (n_times,) or (n_events, n_times) | list
            The waveform(s) describing the activity on the label vertices.
            If list, it must have the same length as events.
        events : array of int, shape (n_events, 3)
            Events associated to the waveform(s) to specify when the activity
            should occur.
        """
        _validate_type(label, Label, "label")

        # If it is not a list then make it one
        if not isinstance(waveform, list) and np.ndim(waveform) == 2:
            waveform = list(waveform)
        if not isinstance(waveform, list) and np.ndim(waveform) == 1:
            waveform = [waveform]
        if len(waveform) == 1:
            waveform = waveform * len(events)
        # The length is either equal to the length of events, or 1
        if len(waveform) != len(events):
            raise ValueError(
                "Number of waveforms and events should match or "
                f"there should be a single waveform ({len(waveform)} != {len(events)})."
            )
        events = _ensure_events(events).astype(np.int64)
        # Update the last sample possible based on events + waveforms
        self._labels.extend([label] * len(events))
        self._waveforms.extend(waveform)
        self._events = np.concatenate([self._events, events])
        assert self._events.dtype == np.int64
        # First sample per waveform is the first column of events
        # Last is computed below
        self._last_samples = np.array(
            [self._events[i, 0] + len(w) - 1 for i, w in enumerate(self._waveforms)]
        )
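
    # A minimal add_data sketch (illustrative; `sim` is an existing
    # SourceSimulator and the label path is an assumption). Events are
    # (sample, previous_value, event_id) rows; a single waveform is broadcast
    # to every event:
    #
    #     label = mne.read_label("Aud-lh.label")
    #     waveform = 1e-9 * np.hanning(100)  # 100-sample activation
    #     events = np.array([[0, 0, 1], [500, 0, 1]])  # onsets at 0 and 500
    #     sim.add_data(label, waveform, events)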

    def get_stim_channel(self, start_sample=None, stop_sample=None):
        """Get the stim channel from the provided data.

        Returns the stim channel data according to the simulation parameters
        which should be added through the add_data method. If both start_sample
        and stop_sample are not specified, the entire duration is used.

        Parameters
        ----------
        start_sample : int | None
            First sample in chunk. If None, the value of the ``first_samp``
            attribute is used.
        stop_sample : int | None
            The final sample of the returned stc. If None, then all samples
            from start_sample onward are returned.

        Returns
        -------
        stim_data : ndarray of int, shape (n_samples,)
            The stimulation channel data.
        """
        if start_sample is None:
            start_sample = self.first_samp
        if stop_sample is None:
            stop_sample = start_sample + self.n_times - 1
        elif stop_sample < start_sample:
            raise ValueError("Argument stop_sample must be >= start_sample.")
        n_samples = stop_sample - start_sample + 1

        # Initialize the stim data array
        stim_data = np.zeros(n_samples, dtype=np.int64)

        # Select only events in the time chunk
        stim_ind = np.where(
            np.logical_and(
                self._events[:, 0] >= start_sample, self._events[:, 0] < stop_sample
            )
        )[0]

        if len(stim_ind) > 0:
            relative_ind = self._events[stim_ind, 0] - start_sample
            stim_data[relative_ind] = self._events[stim_ind, 2]

        return stim_data

    def get_stc(self, start_sample=None, stop_sample=None):
        """Simulate a SourceEstimate from the provided data.

        Returns a SourceEstimate object constructed according to the simulation
        parameters which should be added through function add_data. If both
        start_sample and stop_sample are not specified, the entire duration is
        used.

        Parameters
        ----------
        start_sample : int | None
            First sample in chunk. If ``None`` the value of the ``first_samp``
            attribute is used. Defaults to ``None``.
        stop_sample : int | None
            The final sample of the returned STC. If ``None``, then all samples
            past ``start_sample`` are returned.

        Returns
        -------
        stc : SourceEstimate object
            The generated source time courses.
        """
        if len(self._labels) == 0:
            raise ValueError(
                "No simulation parameters were found. Please use "
                "function add_data to add simulation parameters."
            )
        if start_sample is None:
            start_sample = self.first_samp
        if stop_sample is None:
            stop_sample = start_sample + self.n_times - 1
        elif stop_sample < start_sample:
            raise ValueError("stop_sample must be >= start_sample.")
        n_samples = stop_sample - start_sample + 1

        # Initialize the stc_data array to span all possible samples
        stc_data = np.zeros((len(self._labels), n_samples))

        # Select only the events that fall within the span
        ind = np.where(
            np.logical_and(
                self._last_samples >= start_sample, self._events[:, 0] <= stop_sample
            )
        )[0]

        # Loop only over the items that are in the time span
        subset_waveforms = [self._waveforms[i] for i in ind]
        for i, (waveform, event) in enumerate(zip(subset_waveforms, self._events[ind])):
            # We retrieve the first and last sample of each waveform
            # according to the corresponding event
            wf_start = event[0]
            wf_stop = self._last_samples[ind[i]]

            # Recover the indices of the event that should be in the chunk
            waveform_ind = np.isin(
                np.arange(wf_start, wf_stop + 1),
                np.arange(start_sample, stop_sample + 1),
            )

            # Recover the indices that correspond to the overlap
            stc_ind = np.isin(
                np.arange(start_sample, stop_sample + 1),
                np.arange(wf_start, wf_stop + 1),
            )

            # Add the resulting waveform chunk to the corresponding label
            stc_data[ind[i]][stc_ind] += waveform[waveform_ind]

        start_sample -= self.first_samp  # STC sample ref is 0
        stc = simulate_stc(
            self._src,
            self._labels,
            stc_data,
            start_sample * self._tstep,
            self._tstep,
            allow_overlap=True,
        )

        return stc
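
    # A brief sketch (illustrative): extract the first second of the
    # simulation as an STC together with its stim channel, assuming `sim`
    # already received data via add_data and tstep is 1 ms:
    #
    #     stc = sim.get_stc(start_sample=0, stop_sample=999)
    #     stim = sim.get_stim_channel(start_sample=0, stop_sample=999)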

    def __iter__(self):
        """Iterate over 1 second STCs."""
        # Loop over chunks of 1 second (i.e., self._chk_duration samples);
        # the chunk size is arbitrary and could be changed via _chk_duration.
        last_sample = self.last_samp
        for start_sample in range(self.first_samp, last_sample + 1, self._chk_duration):
            stop_sample = min(start_sample + self._chk_duration - 1, last_sample)
            yield (
                self.get_stc(start_sample, stop_sample),
                self.get_stim_channel(start_sample, stop_sample),
            )
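

# A hedged end-to-end sketch of the SourceSimulator workflow (illustrative;
# the source-space and label paths are assumptions):
#
#     import numpy as np
#     import mne
#     src = mne.read_source_spaces("sample-oct6-src.fif")
#     label = mne.read_label("Aud-lh.label")
#     sim = SourceSimulator(src, tstep=1e-3)
#     sim.add_data(label, 1e-9 * np.hanning(100), np.array([[0, 0, 1]]))
#     for stc_chunk, stim_chunk in sim:  # 1000-sample (1 s) chunks
#         print(stc_chunk.data.shape, stim_chunk.shape)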