initial commit
This commit is contained in:
8
mne/time_frequency/__init__.py
Normal file
8
mne/time_frequency/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
"""Time frequency analysis tools."""
|
||||
import lazy_loader as lazy
|
||||
|
||||
(__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__)
|
||||
81
mne/time_frequency/__init__.pyi
Normal file
81
mne/time_frequency/__init__.pyi
Normal file
@@ -0,0 +1,81 @@
|
||||
__all__ = [
|
||||
"AverageTFR",
|
||||
"AverageTFRArray",
|
||||
"BaseTFR",
|
||||
"CrossSpectralDensity",
|
||||
"EpochsSpectrum",
|
||||
"EpochsSpectrumArray",
|
||||
"EpochsTFR",
|
||||
"EpochsTFRArray",
|
||||
"RawTFR",
|
||||
"RawTFRArray",
|
||||
"Spectrum",
|
||||
"SpectrumArray",
|
||||
"csd_array_fourier",
|
||||
"csd_array_morlet",
|
||||
"csd_array_multitaper",
|
||||
"csd_fourier",
|
||||
"csd_morlet",
|
||||
"csd_multitaper",
|
||||
"csd_tfr",
|
||||
"dpss_windows",
|
||||
"fit_iir_model_raw",
|
||||
"fwhm",
|
||||
"istft",
|
||||
"morlet",
|
||||
"pick_channels_csd",
|
||||
"psd_array_multitaper",
|
||||
"psd_array_welch",
|
||||
"read_csd",
|
||||
"read_spectrum",
|
||||
"read_tfrs",
|
||||
"stft",
|
||||
"stftfreq",
|
||||
"tfr_array_morlet",
|
||||
"tfr_array_multitaper",
|
||||
"tfr_array_stockwell",
|
||||
"tfr_morlet",
|
||||
"tfr_multitaper",
|
||||
"tfr_stockwell",
|
||||
"write_tfrs",
|
||||
]
|
||||
from ._stft import istft, stft, stftfreq
|
||||
from ._stockwell import tfr_array_stockwell, tfr_stockwell
|
||||
from .ar import fit_iir_model_raw
|
||||
from .csd import (
|
||||
CrossSpectralDensity,
|
||||
csd_array_fourier,
|
||||
csd_array_morlet,
|
||||
csd_array_multitaper,
|
||||
csd_fourier,
|
||||
csd_morlet,
|
||||
csd_multitaper,
|
||||
csd_tfr,
|
||||
pick_channels_csd,
|
||||
read_csd,
|
||||
)
|
||||
from .multitaper import dpss_windows, psd_array_multitaper, tfr_array_multitaper
|
||||
from .psd import psd_array_welch
|
||||
from .spectrum import (
|
||||
EpochsSpectrum,
|
||||
EpochsSpectrumArray,
|
||||
Spectrum,
|
||||
SpectrumArray,
|
||||
read_spectrum,
|
||||
)
|
||||
from .tfr import (
|
||||
AverageTFR,
|
||||
AverageTFRArray,
|
||||
BaseTFR,
|
||||
EpochsTFR,
|
||||
EpochsTFRArray,
|
||||
RawTFR,
|
||||
RawTFRArray,
|
||||
fwhm,
|
||||
morlet,
|
||||
read_tfrs,
|
||||
tfr_array_morlet,
|
||||
tfr_morlet,
|
||||
tfr_multitaper,
|
||||
write_tfrs,
|
||||
)
|
||||
261
mne/time_frequency/_stft.py
Normal file
261
mne/time_frequency/_stft.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from math import ceil
|
||||
|
||||
import numpy as np
|
||||
from scipy.fft import irfft, rfft, rfftfreq
|
||||
|
||||
from ..utils import logger, verbose
|
||||
|
||||
|
||||
@verbose
|
||||
def stft(x, wsize, tstep=None, verbose=None):
|
||||
"""STFT Short-Term Fourier Transform using a sine window.
|
||||
|
||||
The transformation is designed to be a tight frame that can be
|
||||
perfectly inverted. It only returns the positive frequencies.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : array, shape (n_signals, n_times)
|
||||
Containing multi-channels signal.
|
||||
wsize : int
|
||||
Length of the STFT window in samples (must be a multiple of 4).
|
||||
tstep : int
|
||||
Step between successive windows in samples (must be a multiple of 2,
|
||||
a divider of wsize and smaller than wsize/2) (default: wsize/2).
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
X : array, shape (n_signals, wsize // 2 + 1, n_step)
|
||||
STFT coefficients for positive frequencies with
|
||||
``n_step = ceil(T / tstep)``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
istft
|
||||
stftfreq
|
||||
"""
|
||||
if not np.isrealobj(x):
|
||||
raise ValueError("x is not a real valued array")
|
||||
|
||||
if x.ndim == 1:
|
||||
x = x[None, :]
|
||||
|
||||
n_signals, T = x.shape
|
||||
wsize = int(wsize)
|
||||
|
||||
# Errors and warnings
|
||||
if wsize % 4:
|
||||
raise ValueError("The window length must be a multiple of 4.")
|
||||
|
||||
if tstep is None:
|
||||
tstep = wsize / 2
|
||||
|
||||
tstep = int(tstep)
|
||||
|
||||
if (wsize % tstep) or (tstep % 2):
|
||||
raise ValueError(
|
||||
"The step size must be a multiple of 2 and a "
|
||||
"divider of the window length."
|
||||
)
|
||||
|
||||
if tstep > wsize / 2:
|
||||
raise ValueError("The step size must be smaller than half the window length.")
|
||||
|
||||
n_step = int(ceil(T / float(tstep)))
|
||||
n_freq = wsize // 2 + 1
|
||||
logger.info(f"Number of frequencies: {n_freq}")
|
||||
logger.info(f"Number of time steps: {n_step}")
|
||||
|
||||
X = np.zeros((n_signals, n_freq, n_step), dtype=np.complex128)
|
||||
|
||||
if n_signals == 0:
|
||||
return X
|
||||
|
||||
# Defining sine window
|
||||
win = np.sin(np.arange(0.5, wsize + 0.5) / wsize * np.pi)
|
||||
win2 = win**2
|
||||
|
||||
swin = np.zeros((n_step - 1) * tstep + wsize)
|
||||
for t in range(n_step):
|
||||
swin[t * tstep : t * tstep + wsize] += win2
|
||||
swin = np.sqrt(wsize * swin)
|
||||
|
||||
# Zero-padding and Pre-processing for edges
|
||||
xp = np.zeros((n_signals, wsize + (n_step - 1) * tstep), dtype=x.dtype)
|
||||
xp[:, (wsize - tstep) // 2 : (wsize - tstep) // 2 + T] = x
|
||||
x = xp
|
||||
|
||||
for t in range(n_step):
|
||||
# Framing
|
||||
wwin = win / swin[t * tstep : t * tstep + wsize]
|
||||
frame = x[:, t * tstep : t * tstep + wsize] * wwin[None, :]
|
||||
# FFT
|
||||
X[:, :, t] = rfft(frame)
|
||||
|
||||
return X
|
||||
|
||||
|
||||
def istft(X, tstep=None, Tx=None):
|
||||
"""ISTFT Inverse Short-Term Fourier Transform using a sine window.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array, shape (..., wsize / 2 + 1, n_step)
|
||||
The STFT coefficients for positive frequencies.
|
||||
tstep : int
|
||||
Step between successive windows in samples (must be a multiple of 2,
|
||||
a divider of wsize and smaller than wsize/2) (default: wsize/2).
|
||||
Tx : int
|
||||
Length of returned signal. If None Tx = n_step * tstep.
|
||||
|
||||
Returns
|
||||
-------
|
||||
x : array, shape (Tx,)
|
||||
Array containing the inverse STFT signal.
|
||||
|
||||
See Also
|
||||
--------
|
||||
stft
|
||||
"""
|
||||
# Errors and warnings
|
||||
X = np.asarray(X)
|
||||
if X.ndim < 2:
|
||||
raise ValueError(f"X must have ndim >= 2, got {X.ndim}")
|
||||
n_win, n_step = X.shape[-2:]
|
||||
signal_shape = X.shape[:-2]
|
||||
if n_win % 2 == 0:
|
||||
raise ValueError("The number of rows of the STFT matrix must be odd.")
|
||||
|
||||
wsize = 2 * (n_win - 1)
|
||||
if tstep is None:
|
||||
tstep = wsize / 2
|
||||
|
||||
if wsize % tstep:
|
||||
raise ValueError(
|
||||
"The step size must be a divider of two times the "
|
||||
"number of rows of the STFT matrix minus two."
|
||||
)
|
||||
|
||||
if wsize % 2:
|
||||
raise ValueError("The step size must be a multiple of 2.")
|
||||
|
||||
if tstep > wsize / 2:
|
||||
raise ValueError(
|
||||
"The step size must be smaller than the number of "
|
||||
"rows of the STFT matrix minus one."
|
||||
)
|
||||
|
||||
if Tx is None:
|
||||
Tx = n_step * tstep
|
||||
|
||||
T = n_step * tstep
|
||||
|
||||
x = np.zeros(signal_shape + (T + wsize - tstep,), dtype=np.float64)
|
||||
|
||||
if np.prod(signal_shape) == 0:
|
||||
return x[..., :Tx]
|
||||
|
||||
# Defining sine window
|
||||
win = np.sin(np.arange(0.5, wsize + 0.5) / wsize * np.pi)
|
||||
# win = win / norm(win);
|
||||
|
||||
# Pre-processing for edges
|
||||
swin = np.zeros(T + wsize - tstep, dtype=np.float64)
|
||||
for t in range(n_step):
|
||||
swin[t * tstep : t * tstep + wsize] += win**2
|
||||
swin = np.sqrt(swin / wsize)
|
||||
|
||||
for t in range(n_step):
|
||||
# IFFT
|
||||
frame = irfft(X[..., t], wsize)
|
||||
# Overlap-add
|
||||
frame *= win / swin[t * tstep : t * tstep + wsize]
|
||||
x[..., t * tstep : t * tstep + wsize] += frame
|
||||
|
||||
# Truncation
|
||||
x = x[..., (wsize - tstep) // 2 : (wsize - tstep) // 2 + T + 1]
|
||||
x = x[..., :Tx].copy()
|
||||
return x
|
||||
|
||||
|
||||
def stftfreq(wsize, sfreq=None): # noqa: D401
|
||||
"""Compute frequencies of stft transformation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wsize : int
|
||||
Size of stft window.
|
||||
sfreq : float
|
||||
Sampling frequency. If None the frequencies are given between 0 and pi
|
||||
otherwise it's given in Hz.
|
||||
|
||||
Returns
|
||||
-------
|
||||
freqs : array
|
||||
The positive frequencies returned by stft.
|
||||
|
||||
See Also
|
||||
--------
|
||||
stft
|
||||
istft
|
||||
"""
|
||||
freqs = rfftfreq(wsize)
|
||||
if sfreq is not None:
|
||||
freqs *= float(sfreq)
|
||||
return freqs
|
||||
|
||||
|
||||
def stft_norm2(X):
|
||||
"""Compute L2 norm of STFT transform.
|
||||
|
||||
It takes into account that stft only return positive frequencies.
|
||||
As we use tight frame this quantity is conserved by the stft.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : 3D complex array
|
||||
The STFT transforms
|
||||
|
||||
Returns
|
||||
-------
|
||||
norms2 : array
|
||||
The squared L2 norm of every row of X.
|
||||
"""
|
||||
X2 = (X * X.conj()).real
|
||||
# compute all L2 coefs and remove first and last frequency once.
|
||||
norms2 = (
|
||||
2.0 * X2.sum(axis=2).sum(axis=1)
|
||||
- np.sum(X2[:, 0, :], axis=1)
|
||||
- np.sum(X2[:, -1, :], axis=1)
|
||||
)
|
||||
return norms2
|
||||
|
||||
|
||||
def stft_norm1(X):
|
||||
"""Compute L1 norm of STFT transform.
|
||||
|
||||
It takes into account that stft only return positive frequencies.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : 3D complex array
|
||||
The STFT transforms
|
||||
|
||||
Returns
|
||||
-------
|
||||
norms : array
|
||||
The L1 norm of every row of X.
|
||||
"""
|
||||
X_abs = np.abs(X)
|
||||
# compute all L1 coefs and remove first and last frequency once.
|
||||
norms = (
|
||||
2.0 * X_abs.sum(axis=(1, 2))
|
||||
- np.sum(X_abs[:, 0, :], axis=1)
|
||||
- np.sum(X_abs[:, -1, :], axis=1)
|
||||
)
|
||||
return norms
|
||||
322
mne/time_frequency/_stockwell.py
Normal file
322
mne/time_frequency/_stockwell.py
Normal file
@@ -0,0 +1,322 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from scipy.fft import fft, fftfreq, ifft
|
||||
|
||||
from .._fiff.pick import _pick_data_channels, pick_info
|
||||
from ..parallel import parallel_func
|
||||
from ..utils import _validate_type, legacy, logger, verbose
|
||||
from .tfr import AverageTFRArray, _ensure_slice, _get_data
|
||||
|
||||
|
||||
def _check_input_st(x_in, n_fft):
|
||||
"""Aux function."""
|
||||
# flatten to 2 D and memorize original shape
|
||||
n_times = x_in.shape[-1]
|
||||
|
||||
def _is_power_of_two(n):
|
||||
return not (n > 0 and (n & (n - 1)))
|
||||
|
||||
if n_fft is None or (not _is_power_of_two(n_fft) and n_times > n_fft):
|
||||
# Compute next power of 2
|
||||
n_fft = 2 ** int(np.ceil(np.log2(n_times)))
|
||||
elif n_fft < n_times:
|
||||
raise ValueError(
|
||||
f"n_fft cannot be smaller than signal size. Got {n_fft} < {n_times}."
|
||||
)
|
||||
if n_times < n_fft:
|
||||
logger.info(
|
||||
f'The input signal is shorter ({x_in.shape[-1]}) than "n_fft" ({n_fft}). '
|
||||
"Applying zero padding."
|
||||
)
|
||||
zero_pad = n_fft - n_times
|
||||
pad_array = np.zeros(x_in.shape[:-1] + (zero_pad,), x_in.dtype)
|
||||
x_in = np.concatenate((x_in, pad_array), axis=-1)
|
||||
else:
|
||||
zero_pad = 0
|
||||
return x_in, n_fft, zero_pad
|
||||
|
||||
|
||||
def _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width):
|
||||
"""Precompute stockwell Gaussian windows (in the freq domain)."""
|
||||
tw = fftfreq(n_samp, 1.0 / sfreq) / n_samp
|
||||
tw = np.r_[tw[:1], tw[1:][::-1]]
|
||||
|
||||
k = width # 1 for classical stowckwell transform
|
||||
f_range = np.arange(start_f, stop_f, 1)
|
||||
windows = np.empty((len(f_range), len(tw)), dtype=np.complex128)
|
||||
for i_f, f in enumerate(f_range):
|
||||
if f == 0.0:
|
||||
window = np.ones(len(tw))
|
||||
else:
|
||||
window = (f / (np.sqrt(2.0 * np.pi) * k)) * np.exp(
|
||||
-0.5 * (1.0 / k**2.0) * (f**2.0) * tw**2.0
|
||||
)
|
||||
window /= window.sum() # normalisation
|
||||
windows[i_f] = fft(window)
|
||||
return windows
|
||||
|
||||
|
||||
def _st(x, start_f, windows):
|
||||
"""Compute ST based on Ali Moukadem MATLAB code (used in tests)."""
|
||||
from scipy.fft import fft, ifft
|
||||
|
||||
n_samp = x.shape[-1]
|
||||
ST = np.empty(x.shape[:-1] + (len(windows), n_samp), dtype=np.complex128)
|
||||
# do the work
|
||||
Fx = fft(x)
|
||||
XF = np.concatenate([Fx, Fx], axis=-1)
|
||||
for i_f, window in enumerate(windows):
|
||||
f = start_f + i_f
|
||||
ST[..., i_f, :] = ifft(XF[..., f : f + n_samp] * window)
|
||||
return ST
|
||||
|
||||
|
||||
def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W):
|
||||
"""Aux function."""
|
||||
decim = _ensure_slice(decim)
|
||||
n_samp = x.shape[-1]
|
||||
decim_indices = decim.indices(n_samp - zero_pad)
|
||||
n_out = len(range(*decim_indices))
|
||||
psd = np.empty((len(W), n_out))
|
||||
itc = np.empty_like(psd) if compute_itc else None
|
||||
X = fft(x)
|
||||
XX = np.concatenate([X, X], axis=-1)
|
||||
for i_f, window in enumerate(W):
|
||||
f = start_f + i_f
|
||||
ST = ifft(XX[:, f : f + n_samp] * window)
|
||||
TFR = ST[:, slice(*decim_indices)]
|
||||
TFR_abs = np.abs(TFR)
|
||||
TFR_abs[TFR_abs == 0] = 1.0
|
||||
if compute_itc:
|
||||
TFR /= TFR_abs
|
||||
itc[i_f] = np.abs(np.mean(TFR, axis=0))
|
||||
TFR_abs *= TFR_abs
|
||||
psd[i_f] = np.mean(TFR_abs, axis=0)
|
||||
return psd, itc
|
||||
|
||||
|
||||
def _compute_freqs_st(fmin, fmax, n_fft, sfreq):
|
||||
from scipy.fft import fftfreq
|
||||
|
||||
freqs = fftfreq(n_fft, 1.0 / sfreq)
|
||||
if fmin is None:
|
||||
fmin = freqs[freqs > 0][0]
|
||||
if fmax is None:
|
||||
fmax = freqs.max()
|
||||
|
||||
start_f = np.abs(freqs - fmin).argmin()
|
||||
stop_f = np.abs(freqs - fmax).argmin()
|
||||
freqs = freqs[start_f:stop_f]
|
||||
return start_f, stop_f, freqs
|
||||
|
||||
|
||||
@verbose
|
||||
def tfr_array_stockwell(
|
||||
data,
|
||||
sfreq,
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
n_fft=None,
|
||||
width=1.0,
|
||||
decim=1,
|
||||
return_itc=False,
|
||||
n_jobs=None,
|
||||
*,
|
||||
verbose=None,
|
||||
):
|
||||
"""Compute power and intertrial coherence using Stockwell (S) transform.
|
||||
|
||||
Same computation as `~mne.time_frequency.tfr_stockwell`, but operates on
|
||||
:class:`NumPy arrays <numpy.ndarray>` instead of `~mne.Epochs` objects.
|
||||
|
||||
See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
|
||||
for more information.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : ndarray, shape (n_epochs, n_channels, n_times)
|
||||
The signal to transform.
|
||||
sfreq : float
|
||||
The sampling frequency.
|
||||
fmin : None, float
|
||||
The minimum frequency to include. If None defaults to the minimum fft
|
||||
frequency greater than zero.
|
||||
fmax : None, float
|
||||
The maximum frequency to include. If None defaults to the maximum fft.
|
||||
n_fft : int | None
|
||||
The length of the windows used for FFT. If None, it defaults to the
|
||||
next power of 2 larger than the signal length.
|
||||
width : float
|
||||
The width of the Gaussian window. If < 1, increased temporal
|
||||
resolution, if > 1, increased frequency resolution. Defaults to 1.
|
||||
(classical S-Transform).
|
||||
%(decim_tfr)s
|
||||
return_itc : bool
|
||||
Return intertrial coherence (ITC) as well as averaged power.
|
||||
%(n_jobs)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
st_power : ndarray
|
||||
The multitaper power of the Stockwell transformed data.
|
||||
The last two dimensions are frequency and time.
|
||||
itc : ndarray
|
||||
The intertrial coherence. Only returned if return_itc is True.
|
||||
freqs : ndarray
|
||||
The frequencies.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.time_frequency.tfr_stockwell
|
||||
mne.time_frequency.tfr_multitaper
|
||||
mne.time_frequency.tfr_array_multitaper
|
||||
mne.time_frequency.tfr_morlet
|
||||
mne.time_frequency.tfr_array_morlet
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
_validate_type(data, np.ndarray, "data")
|
||||
if data.ndim != 3:
|
||||
raise ValueError(
|
||||
"data must be 3D with shape (n_epochs, n_channels, n_times), "
|
||||
f"got {data.shape}"
|
||||
)
|
||||
decim = _ensure_slice(decim)
|
||||
_, n_channels, n_out = data[..., decim].shape
|
||||
data, n_fft_, zero_pad = _check_input_st(data, n_fft)
|
||||
start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, n_fft_, sfreq)
|
||||
|
||||
W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)
|
||||
n_freq = stop_f - start_f
|
||||
psd = np.empty((n_channels, n_freq, n_out))
|
||||
itc = np.empty((n_channels, n_freq, n_out)) if return_itc else None
|
||||
|
||||
parallel, my_st, n_jobs = parallel_func(_st_power_itc, n_jobs, verbose=verbose)
|
||||
tfrs = parallel(
|
||||
my_st(data[:, c, :], start_f, return_itc, zero_pad, decim, W)
|
||||
for c in range(n_channels)
|
||||
)
|
||||
for c, (this_psd, this_itc) in enumerate(iter(tfrs)):
|
||||
psd[c] = this_psd
|
||||
if this_itc is not None:
|
||||
itc[c] = this_itc
|
||||
|
||||
return psd, itc, freqs
|
||||
|
||||
|
||||
@legacy(alt='.compute_tfr(method="stockwell", freqs="auto")')
|
||||
@verbose
|
||||
def tfr_stockwell(
|
||||
inst,
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
n_fft=None,
|
||||
width=1.0,
|
||||
decim=1,
|
||||
return_itc=False,
|
||||
n_jobs=None,
|
||||
verbose=None,
|
||||
):
|
||||
"""Compute Time-Frequency Representation (TFR) using Stockwell Transform.
|
||||
|
||||
Same computation as `~mne.time_frequency.tfr_array_stockwell`, but operates
|
||||
on `~mne.Epochs` objects instead of :class:`NumPy arrays <numpy.ndarray>`.
|
||||
|
||||
See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
|
||||
for more information.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : Epochs | Evoked
|
||||
The epochs or evoked object.
|
||||
fmin : None, float
|
||||
The minimum frequency to include. If None defaults to the minimum fft
|
||||
frequency greater than zero.
|
||||
fmax : None, float
|
||||
The maximum frequency to include. If None defaults to the maximum fft.
|
||||
n_fft : int | None
|
||||
The length of the windows used for FFT. If None, it defaults to the
|
||||
next power of 2 larger than the signal length.
|
||||
width : float
|
||||
The width of the Gaussian window. If < 1, increased temporal
|
||||
resolution, if > 1, increased frequency resolution. Defaults to 1.
|
||||
(classical S-Transform).
|
||||
decim : int
|
||||
The decimation factor on the time axis. To reduce memory usage.
|
||||
return_itc : bool
|
||||
Return intertrial coherence (ITC) as well as averaged power.
|
||||
n_jobs : int
|
||||
The number of jobs to run in parallel (over channels).
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
power : AverageTFR
|
||||
The averaged power.
|
||||
itc : AverageTFR
|
||||
The intertrial coherence. Only returned if return_itc is True.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.time_frequency.tfr_array_stockwell
|
||||
mne.time_frequency.tfr_multitaper
|
||||
mne.time_frequency.tfr_array_multitaper
|
||||
mne.time_frequency.tfr_morlet
|
||||
mne.time_frequency.tfr_array_morlet
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 0.9.0
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
# verbose dec is used b/c subfunctions are verbose
|
||||
data = _get_data(inst, return_itc)
|
||||
picks = _pick_data_channels(inst.info)
|
||||
info = pick_info(inst.info, picks)
|
||||
data = data[:, picks, :]
|
||||
decim = _ensure_slice(decim)
|
||||
power, itc, freqs = tfr_array_stockwell(
|
||||
data,
|
||||
sfreq=info["sfreq"],
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
n_fft=n_fft,
|
||||
width=width,
|
||||
decim=decim,
|
||||
return_itc=return_itc,
|
||||
n_jobs=n_jobs,
|
||||
)
|
||||
times = inst.times[decim].copy()
|
||||
nave = len(data)
|
||||
out = AverageTFRArray(
|
||||
info=info,
|
||||
data=power,
|
||||
times=times,
|
||||
freqs=freqs,
|
||||
nave=nave,
|
||||
method="stockwell-power",
|
||||
)
|
||||
if return_itc:
|
||||
out = (
|
||||
out,
|
||||
AverageTFRArray(
|
||||
info=deepcopy(info),
|
||||
data=itc,
|
||||
times=times.copy(),
|
||||
freqs=freqs.copy(),
|
||||
nave=nave,
|
||||
method="stockwell-itc",
|
||||
),
|
||||
)
|
||||
return out
|
||||
76
mne/time_frequency/ar.py
Normal file
76
mne/time_frequency/ar.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
from scipy import linalg
|
||||
|
||||
from .._fiff.pick import _picks_by_type, _picks_to_idx, pick_info
|
||||
from ..defaults import _handle_default
|
||||
from ..utils import _apply_scaling_array, verbose
|
||||
|
||||
|
||||
def _yule_walker(X, order=1):
|
||||
"""Compute Yule-Walker (adapted from statsmodels).
|
||||
|
||||
Operates in-place.
|
||||
"""
|
||||
assert X.ndim == 2
|
||||
denom = X.shape[-1] - np.arange(order + 1)
|
||||
r = np.zeros(order + 1, np.float64)
|
||||
for di, d in enumerate(X):
|
||||
d -= d.mean()
|
||||
r[0] += np.dot(d, d)
|
||||
for k in range(1, order + 1):
|
||||
r[k] += np.dot(d[0:-k], d[k:])
|
||||
r /= denom * len(X)
|
||||
rho = linalg.solve(linalg.toeplitz(r[:-1]), r[1:])
|
||||
sigmasq = r[0] - (r[1:] * rho).sum()
|
||||
return rho, np.sqrt(sigmasq)
|
||||
|
||||
|
||||
@verbose
|
||||
def fit_iir_model_raw(raw, order=2, picks=None, tmin=None, tmax=None, verbose=None):
|
||||
r"""Fit an AR model to raw data and creates the corresponding IIR filter.
|
||||
|
||||
The computed filter is fitted to data from all of the picked channels,
|
||||
with frequency response given by the standard IIR formula:
|
||||
|
||||
.. math::
|
||||
|
||||
H(e^{jw}) = \frac{1}{a[0] + a[1]e^{-jw} + ... + a[n]e^{-jnw}}
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : Raw object
|
||||
An instance of Raw.
|
||||
order : int
|
||||
Order of the FIR filter.
|
||||
%(picks_good_data)s
|
||||
tmin : float
|
||||
The beginning of time interval in seconds.
|
||||
tmax : float
|
||||
The end of time interval in seconds.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
b : ndarray
|
||||
Numerator filter coefficients.
|
||||
a : ndarray
|
||||
Denominator filter coefficients.
|
||||
"""
|
||||
start, stop = None, None
|
||||
if tmin is not None:
|
||||
start = raw.time_as_index(tmin)[0]
|
||||
if tmax is not None:
|
||||
stop = raw.time_as_index(tmax)[0] + 1
|
||||
picks = _picks_to_idx(raw.info, picks)
|
||||
data = raw[picks, start:stop][0]
|
||||
# rescale data to similar levels
|
||||
picks_list = _picks_by_type(pick_info(raw.info, picks))
|
||||
scalings = _handle_default("scalings_cov_rank", None)
|
||||
_apply_scaling_array(data, picks_list=picks_list, scalings=scalings)
|
||||
# do the fitting
|
||||
coeffs, _ = _yule_walker(data, order=order)
|
||||
return np.array([1.0]), np.concatenate(([1.0], -coeffs))
|
||||
1607
mne/time_frequency/csd.py
Normal file
1607
mne/time_frequency/csd.py
Normal file
File diff suppressed because it is too large
Load Diff
555
mne/time_frequency/multitaper.py
Normal file
555
mne/time_frequency/multitaper.py
Normal file
@@ -0,0 +1,555 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
# Parts of this code were copied from NiTime http://nipy.sourceforge.net/nitime
|
||||
|
||||
import numpy as np
|
||||
from scipy.fft import rfft, rfftfreq
|
||||
from scipy.integrate import trapezoid
|
||||
from scipy.signal import get_window
|
||||
from scipy.signal.windows import dpss as sp_dpss
|
||||
|
||||
from ..parallel import parallel_func
|
||||
from ..utils import _check_option, logger, verbose, warn
|
||||
|
||||
|
||||
def dpss_windows(N, half_nbw, Kmax, *, sym=True, norm=None, low_bias=True):
|
||||
"""Compute Discrete Prolate Spheroidal Sequences.
|
||||
|
||||
Will give of orders [0,Kmax-1] for a given frequency-spacing multiple
|
||||
NW and sequence length N.
|
||||
|
||||
.. note:: Copied from NiTime.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
N : int
|
||||
Sequence length.
|
||||
half_nbw : float
|
||||
Standardized half bandwidth corresponding to 2 * half_bw = BW*f0
|
||||
= BW*N/dt but with dt taken as 1.
|
||||
Kmax : int
|
||||
Number of DPSS windows to return is Kmax (orders 0 through Kmax-1).
|
||||
sym : bool
|
||||
Whether to generate a symmetric window (``True``, for filter design) or
|
||||
a periodic window (``False``, for spectral analysis). Default is
|
||||
``True``.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
norm : 2 | ``'approximate'`` | ``'subsample'`` | None
|
||||
Window normalization method. If ``'approximate'`` or ``'subsample'``,
|
||||
windows are normalized by the maximum, and a correction scale-factor
|
||||
for even-length windows is applied either using
|
||||
``N**2/(N**2+half_nbw)`` ("approximate") or a FFT-based subsample shift
|
||||
("subsample"). ``2`` uses the L2 norm. ``None`` (the default) uses
|
||||
``"approximate"`` when ``Kmax=None`` and ``2`` otherwise.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
low_bias : bool
|
||||
Keep only tapers with eigenvalues > 0.9.
|
||||
|
||||
Returns
|
||||
-------
|
||||
v, e : tuple,
|
||||
The v array contains DPSS windows shaped (Kmax, N).
|
||||
e are the eigenvalues.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Tridiagonal form of DPSS calculation from :footcite:`Slepian1978`.
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
dpss, eigvals = sp_dpss(N, half_nbw, Kmax, sym=sym, norm=norm, return_ratios=True)
|
||||
if low_bias:
|
||||
idx = eigvals > 0.9
|
||||
if not idx.any():
|
||||
warn("Could not properly use low_bias, keeping lowest-bias taper")
|
||||
idx = [np.argmax(eigvals)]
|
||||
dpss, eigvals = dpss[idx], eigvals[idx]
|
||||
assert len(dpss) > 0 # should never happen
|
||||
assert dpss.shape[1] == N # old nitime bug
|
||||
return dpss, eigvals
|
||||
|
||||
|
||||
def _psd_from_mt_adaptive(x_mt, eigvals, freq_mask, max_iter=250, return_weights=False):
|
||||
r"""Use iterative procedure to compute the PSD from tapered spectra.
|
||||
|
||||
.. note:: Modified from NiTime.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x_mt : array, shape=(n_signals, n_tapers, n_freqs)
|
||||
The DFTs of the tapered sequences (only positive frequencies)
|
||||
eigvals : array, length n_tapers
|
||||
The eigenvalues of the DPSS tapers
|
||||
freq_mask : array
|
||||
Frequency indices to keep
|
||||
max_iter : int
|
||||
Maximum number of iterations for weight computation.
|
||||
return_weights : bool
|
||||
Also return the weights
|
||||
|
||||
Returns
|
||||
-------
|
||||
psd : array, shape=(n_signals, np.sum(freq_mask))
|
||||
The computed PSDs
|
||||
weights : array shape=(n_signals, n_tapers, np.sum(freq_mask))
|
||||
The weights used to combine the tapered spectra
|
||||
|
||||
Notes
|
||||
-----
|
||||
The weights to use for making the multitaper estimate, such that
|
||||
:math:`S_{mt} = \sum_{k} |w_k|^2S_k^{mt} / \sum_{k} |w_k|^2`
|
||||
"""
|
||||
n_signals, n_tapers, n_freqs = x_mt.shape
|
||||
|
||||
if len(eigvals) != n_tapers:
|
||||
raise ValueError("Need one eigenvalue for each taper")
|
||||
|
||||
if n_tapers < 3:
|
||||
raise ValueError("Not enough tapers to compute adaptive weights.")
|
||||
|
||||
rt_eig = np.sqrt(eigvals)
|
||||
|
||||
# estimate the variance from an estimate with fixed weights
|
||||
psd_est = _psd_from_mt(x_mt, rt_eig[np.newaxis, :, np.newaxis])
|
||||
x_var = trapezoid(psd_est, dx=np.pi / n_freqs) / (2 * np.pi)
|
||||
del psd_est
|
||||
|
||||
# allocate space for output
|
||||
psd = np.empty((n_signals, np.sum(freq_mask)))
|
||||
|
||||
# only keep the frequencies of interest
|
||||
x_mt = x_mt[:, :, freq_mask]
|
||||
|
||||
if return_weights:
|
||||
weights = np.empty((n_signals, n_tapers, psd.shape[1]))
|
||||
|
||||
for i, (xk, var) in enumerate(zip(x_mt, x_var)):
|
||||
# combine the SDFs in the traditional way in order to estimate
|
||||
# the variance of the timeseries
|
||||
|
||||
# The process is to iteratively switch solving for the following
|
||||
# two expressions:
|
||||
# (1) Adaptive Multitaper SDF:
|
||||
# S^{mt}(f) = [ sum |d_k(f)|^2 S_k(f) ]/ sum |d_k(f)|^2
|
||||
#
|
||||
# (2) Weights
|
||||
# d_k(f) = [sqrt(lam_k) S^{mt}(f)] / [lam_k S^{mt}(f) + E{B_k(f)}]
|
||||
#
|
||||
# Where lam_k are the eigenvalues corresponding to the DPSS tapers,
|
||||
# and the expected value of the broadband bias function
|
||||
# E{B_k(f)} is replaced by its full-band integration
|
||||
# (1/2pi) int_{-pi}^{pi} E{B_k(f)} = sig^2(1-lam_k)
|
||||
|
||||
# start with an estimate from incomplete data--the first 2 tapers
|
||||
psd_iter = _psd_from_mt(xk[:2, :], rt_eig[:2, np.newaxis])
|
||||
|
||||
err = np.zeros_like(xk)
|
||||
for n in range(max_iter):
|
||||
d_k = psd_iter / (
|
||||
eigvals[:, np.newaxis] * psd_iter + (1 - eigvals[:, np.newaxis]) * var
|
||||
)
|
||||
d_k *= rt_eig[:, np.newaxis]
|
||||
# Test for convergence -- this is overly conservative, since
|
||||
# iteration only stops when all frequencies have converged.
|
||||
# A better approach is to iterate separately for each freq, but
|
||||
# that is a nonvectorized algorithm.
|
||||
# Take the RMS difference in weights from the previous iterate
|
||||
# across frequencies. If the maximum RMS error across freqs is
|
||||
# less than 1e-10, then we're converged
|
||||
err -= d_k
|
||||
if np.max(np.mean(err**2, axis=0)) < 1e-10:
|
||||
break
|
||||
|
||||
# update the iterative estimate with this d_k
|
||||
psd_iter = _psd_from_mt(xk, d_k)
|
||||
err = d_k
|
||||
|
||||
if n == max_iter - 1:
|
||||
warn("Iterative multi-taper PSD computation did not converge.")
|
||||
|
||||
psd[i, :] = psd_iter
|
||||
|
||||
if return_weights:
|
||||
weights[i, :, :] = d_k
|
||||
|
||||
if return_weights:
|
||||
return psd, weights
|
||||
else:
|
||||
return psd
|
||||
|
||||
|
||||
def _psd_from_mt(x_mt, weights):
|
||||
"""Compute PSD from tapered spectra.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x_mt : array, shape=(..., n_tapers, n_freqs)
|
||||
Tapered spectra
|
||||
weights : array, shape=(n_tapers,)
|
||||
Weights used to combine the tapered spectra
|
||||
|
||||
Returns
|
||||
-------
|
||||
psd : array, shape=(..., n_freqs)
|
||||
The computed PSD
|
||||
"""
|
||||
psd = weights * x_mt
|
||||
psd *= psd.conj()
|
||||
psd = psd.real.sum(axis=-2)
|
||||
psd *= 2 / (weights * weights.conj()).real.sum(axis=-2)
|
||||
return psd
|
||||
|
||||
|
||||
def _csd_from_mt(x_mt, y_mt, weights_x, weights_y):
|
||||
"""Compute CSD from tapered spectra.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x_mt : array, shape=(..., n_tapers, n_freqs)
|
||||
Tapered spectra for x
|
||||
y_mt : array, shape=(..., n_tapers, n_freqs)
|
||||
Tapered spectra for y
|
||||
weights_x : array, shape=(n_tapers,)
|
||||
Weights used to combine the tapered spectra of x_mt
|
||||
weights_y : array, shape=(n_tapers,)
|
||||
Weights used to combine the tapered spectra of y_mt
|
||||
|
||||
Returns
|
||||
-------
|
||||
csd: array
|
||||
The computed CSD
|
||||
"""
|
||||
csd = np.sum(weights_x * x_mt * (weights_y * y_mt).conj(), axis=-2)
|
||||
denom = np.sqrt((weights_x * weights_x.conj()).real.sum(axis=-2)) * np.sqrt(
|
||||
(weights_y * weights_y.conj()).real.sum(axis=-2)
|
||||
)
|
||||
csd *= 2 / denom
|
||||
return csd
|
||||
|
||||
|
||||
def _mt_spectra(x, dpss, sfreq, n_fft=None, remove_dc=True):
|
||||
"""Compute tapered spectra.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : array, shape=(..., n_times)
|
||||
Input signal
|
||||
dpss : array, shape=(n_tapers, n_times)
|
||||
The tapers
|
||||
sfreq : float
|
||||
The sampling frequency
|
||||
n_fft : int | None
|
||||
Length of the FFT. If None, the number of samples in the input signal
|
||||
will be used.
|
||||
%(remove_dc)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
x_mt : array, shape=(..., n_tapers, n_freqs)
|
||||
The tapered spectra
|
||||
freqs : array, shape=(n_freqs,)
|
||||
The frequency points in Hz of the spectra
|
||||
"""
|
||||
if n_fft is None:
|
||||
n_fft = x.shape[-1]
|
||||
|
||||
# remove mean (do not use in-place subtraction as it may modify input x)
|
||||
if remove_dc:
|
||||
x = x - np.mean(x, axis=-1, keepdims=True)
|
||||
|
||||
# only keep positive frequencies
|
||||
freqs = rfftfreq(n_fft, 1.0 / sfreq)
|
||||
|
||||
# The following is equivalent to this, but uses less memory:
|
||||
# x_mt = fftpack.fft(x[:, np.newaxis, :] * dpss, n=n_fft)
|
||||
n_tapers = dpss.shape[0] if dpss.ndim > 1 else 1
|
||||
x_mt = np.zeros(x.shape[:-1] + (n_tapers, len(freqs)), dtype=np.complex128)
|
||||
for idx, sig in enumerate(x):
|
||||
x_mt[idx] = rfft(sig[..., np.newaxis, :] * dpss, n=n_fft)
|
||||
# Adjust DC and maybe Nyquist, depending on one-sided transform
|
||||
x_mt[..., 0] /= np.sqrt(2.0)
|
||||
if n_fft % 2 == 0:
|
||||
x_mt[..., -1] /= np.sqrt(2.0)
|
||||
return x_mt, freqs
|
||||
|
||||
|
||||
@verbose
|
||||
def _compute_mt_params(n_times, sfreq, bandwidth, low_bias, adaptive, verbose=None):
|
||||
"""Triage windowing and multitaper parameters."""
|
||||
# Compute standardized half-bandwidth
|
||||
if isinstance(bandwidth, str):
|
||||
logger.info(f' Using standard spectrum estimation with "{bandwidth}" window')
|
||||
window_fun = get_window(bandwidth, n_times)[np.newaxis]
|
||||
return window_fun, np.ones(1), False
|
||||
|
||||
if bandwidth is not None:
|
||||
half_nbw = float(bandwidth) * n_times / (2.0 * sfreq)
|
||||
else:
|
||||
half_nbw = 4.0
|
||||
if half_nbw < 0.5:
|
||||
raise ValueError(
|
||||
f"bandwidth value {bandwidth} yields a normalized half-bandwidth of "
|
||||
f"{half_nbw} < 0.5, use a value of at least {sfreq / n_times}"
|
||||
)
|
||||
|
||||
# Compute DPSS windows
|
||||
n_tapers_max = int(2 * half_nbw)
|
||||
window_fun, eigvals = dpss_windows(
|
||||
n_times, half_nbw, n_tapers_max, sym=False, low_bias=low_bias
|
||||
)
|
||||
logger.info(
|
||||
f" Using multitaper spectrum estimation with {len(eigvals)} DPSS windows"
|
||||
)
|
||||
|
||||
if adaptive and len(eigvals) < 3:
|
||||
warn(
|
||||
"Not adaptively combining the spectral estimators due to a "
|
||||
f"low number of tapers ({len(eigvals)} < 3)."
|
||||
)
|
||||
adaptive = False
|
||||
|
||||
return window_fun, eigvals, adaptive
|
||||
|
||||
|
||||
@verbose
|
||||
def psd_array_multitaper(
|
||||
x,
|
||||
sfreq,
|
||||
fmin=0.0,
|
||||
fmax=np.inf,
|
||||
bandwidth=None,
|
||||
adaptive=False,
|
||||
low_bias=True,
|
||||
normalization="length",
|
||||
remove_dc=True,
|
||||
output="power",
|
||||
n_jobs=None,
|
||||
*,
|
||||
max_iter=150,
|
||||
verbose=None,
|
||||
):
|
||||
r"""Compute power spectral density (PSD) using a multi-taper method.
|
||||
|
||||
The power spectral density is computed with DPSS
|
||||
tapers :footcite:p:`Slepian1978`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : array, shape=(..., n_times)
|
||||
The data to compute PSD from.
|
||||
sfreq : float
|
||||
The sampling frequency.
|
||||
%(fmin_fmax_psd)s
|
||||
bandwidth : float
|
||||
Frequency bandwidth of the multi-taper window function in Hz. For a
|
||||
given frequency, frequencies at ``± bandwidth / 2`` are smoothed
|
||||
together. The default value is a bandwidth of
|
||||
``8 * (sfreq / n_times)``.
|
||||
adaptive : bool
|
||||
Use adaptive weights to combine the tapered spectra into PSD
|
||||
(slow, use n_jobs >> 1 to speed up computation).
|
||||
low_bias : bool
|
||||
Only use tapers with more than 90%% spectral concentration within
|
||||
bandwidth.
|
||||
%(normalization)s
|
||||
%(remove_dc)s
|
||||
output : str
|
||||
The format of the returned ``psds`` array, ``'complex'`` or
|
||||
``'power'``:
|
||||
|
||||
* ``'power'`` : the power spectral density is returned.
|
||||
* ``'complex'`` : the complex fourier coefficients are returned per
|
||||
taper.
|
||||
%(n_jobs)s
|
||||
%(max_iter_multitaper)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
psds : ndarray, shape (..., n_freqs) or (..., n_tapers, n_freqs)
|
||||
The power spectral densities. All dimensions up to the last (or the
|
||||
last two if ``output='complex'``) will be the same as input.
|
||||
freqs : array
|
||||
The frequency points in Hz of the PSD.
|
||||
weights : ndarray
|
||||
The weights used for averaging across tapers. Only returned if
|
||||
``output='complex'``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
csd_multitaper
|
||||
mne.io.Raw.compute_psd
|
||||
mne.Epochs.compute_psd
|
||||
mne.Evoked.compute_psd
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 0.14.0
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
_check_option("normalization", normalization, ["length", "full"])
|
||||
|
||||
# Reshape data so its 2-D for parallelization
|
||||
ndim_in = x.ndim
|
||||
x = np.atleast_2d(x)
|
||||
n_times = x.shape[-1]
|
||||
dshape = x.shape[:-1]
|
||||
x = x.reshape(-1, n_times)
|
||||
|
||||
dpss, eigvals, adaptive = _compute_mt_params(
|
||||
n_times, sfreq, bandwidth, low_bias, adaptive
|
||||
)
|
||||
n_tapers = len(dpss)
|
||||
weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis]
|
||||
|
||||
# decide which frequencies to keep
|
||||
freqs = rfftfreq(n_times, 1.0 / sfreq)
|
||||
freq_mask = (freqs >= fmin) & (freqs <= fmax)
|
||||
freqs = freqs[freq_mask]
|
||||
n_freqs = len(freqs)
|
||||
|
||||
if output == "complex":
|
||||
psd = np.zeros((x.shape[0], n_tapers, n_freqs), dtype="complex")
|
||||
else:
|
||||
psd = np.zeros((x.shape[0], n_freqs))
|
||||
|
||||
# Let's go in up to 50 MB chunks of signals to save memory
|
||||
n_chunk = max(50000000 // (len(freq_mask) * len(eigvals) * 16), 1)
|
||||
offsets = np.concatenate((np.arange(0, x.shape[0], n_chunk), [x.shape[0]]))
|
||||
for start, stop in zip(offsets[:-1], offsets[1:]):
|
||||
x_mt = _mt_spectra(x[start:stop], dpss, sfreq, remove_dc=remove_dc)[0]
|
||||
if output == "power":
|
||||
if not adaptive:
|
||||
psd[start:stop] = _psd_from_mt(x_mt[:, :, freq_mask], weights)
|
||||
else:
|
||||
parallel, my_psd_from_mt_adaptive, n_jobs = parallel_func(
|
||||
_psd_from_mt_adaptive, n_jobs
|
||||
)
|
||||
n_splits = min(stop - start, n_jobs)
|
||||
out = parallel(
|
||||
my_psd_from_mt_adaptive(x, eigvals, freq_mask, max_iter)
|
||||
for x in np.array_split(x_mt, n_splits)
|
||||
)
|
||||
psd[start:stop] = np.concatenate(out)
|
||||
else:
|
||||
psd[start:stop] = x_mt[:, :, freq_mask]
|
||||
|
||||
if normalization == "full":
|
||||
psd /= sfreq
|
||||
|
||||
# Combining/reshaping to original data shape
|
||||
last_dims = (n_freqs,) if output == "power" else (n_tapers, n_freqs)
|
||||
psd.shape = dshape + last_dims
|
||||
if ndim_in == 1:
|
||||
psd = psd[0]
|
||||
|
||||
if output == "complex":
|
||||
return psd, freqs, weights
|
||||
else:
|
||||
return psd, freqs
|
||||
|
||||
|
||||
@verbose
|
||||
def tfr_array_multitaper(
|
||||
data,
|
||||
sfreq,
|
||||
freqs,
|
||||
n_cycles=7.0,
|
||||
zero_mean=True,
|
||||
time_bandwidth=4.0,
|
||||
use_fft=True,
|
||||
decim=1,
|
||||
output="complex",
|
||||
n_jobs=None,
|
||||
*,
|
||||
verbose=None,
|
||||
):
|
||||
"""Compute Time-Frequency Representation (TFR) using DPSS tapers.
|
||||
|
||||
Same computation as `~mne.time_frequency.tfr_multitaper`, but operates on
|
||||
:class:`NumPy arrays <numpy.ndarray>` instead of `~mne.Epochs` or
|
||||
`~mne.Evoked` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array of shape (n_epochs, n_channels, n_times)
|
||||
The epochs.
|
||||
sfreq : float
|
||||
Sampling frequency of the data in Hz.
|
||||
%(freqs_tfr_array)s
|
||||
%(n_cycles_tfr)s
|
||||
zero_mean : bool
|
||||
If True, make sure the wavelets have a mean of zero. Defaults to True.
|
||||
%(time_bandwidth_tfr)s
|
||||
use_fft : bool
|
||||
Use the FFT for convolutions or not. Defaults to True.
|
||||
%(decim_tfr)s
|
||||
output : str, default 'complex'
|
||||
|
||||
* ``'complex'`` : single trial per taper complex values.
|
||||
* ``'power'`` : single trial power.
|
||||
* ``'phase'`` : single trial per taper phase.
|
||||
* ``'avg_power'`` : average of single trial power.
|
||||
* ``'itc'`` : inter-trial coherence.
|
||||
* ``'avg_power_itc'`` : average of single trial power and inter-trial
|
||||
coherence across trials.
|
||||
%(n_jobs)s
|
||||
The parallelization is implemented across channels.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : array
|
||||
Time frequency transform of ``data``.
|
||||
|
||||
- if ``output in ('complex',' 'phase')``, array of shape
|
||||
``(n_epochs, n_chans, n_tapers, n_freqs, n_times)``
|
||||
- if ``output`` is ``'power'``, array of shape ``(n_epochs, n_chans,
|
||||
n_freqs, n_times)``
|
||||
- else, array of shape ``(n_chans, n_freqs, n_times)``
|
||||
|
||||
If ``output`` is ``'avg_power_itc'``, the real values in ``out``
|
||||
contain the average power and the imaginary values contain the
|
||||
inter-trial coherence: :math:`out = power_{avg} + i * ITC`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.time_frequency.tfr_multitaper
|
||||
mne.time_frequency.tfr_morlet
|
||||
mne.time_frequency.tfr_array_morlet
|
||||
mne.time_frequency.tfr_stockwell
|
||||
mne.time_frequency.tfr_array_stockwell
|
||||
|
||||
Notes
|
||||
-----
|
||||
%(temporal_window_tfr_intro)s
|
||||
%(temporal_window_tfr_multitaper_notes)s
|
||||
%(time_bandwidth_tfr_notes)s
|
||||
|
||||
.. versionadded:: 0.14.0
|
||||
"""
|
||||
from .tfr import _compute_tfr
|
||||
|
||||
return _compute_tfr(
|
||||
data,
|
||||
freqs,
|
||||
sfreq=sfreq,
|
||||
method="multitaper",
|
||||
n_cycles=n_cycles,
|
||||
zero_mean=zero_mean,
|
||||
time_bandwidth=time_bandwidth,
|
||||
use_fft=use_fft,
|
||||
decim=decim,
|
||||
output=output,
|
||||
n_jobs=n_jobs,
|
||||
verbose=verbose,
|
||||
)
|
||||
272
mne/time_frequency/psd.py
Normal file
272
mne/time_frequency/psd.py
Normal file
@@ -0,0 +1,272 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from scipy.signal import spectrogram
|
||||
|
||||
from ..parallel import parallel_func
|
||||
from ..utils import _check_option, _ensure_int, logger, verbose
|
||||
from ..utils.numerics import _mask_to_onsets_offsets
|
||||
|
||||
|
||||
# adapted from SciPy
|
||||
# https://github.com/scipy/scipy/blob/f71e7fad717801c4476312fe1e23f2dfbb4c9d7f/scipy/signal/_spectral_py.py#L2019 # noqa: E501
|
||||
def _median_biases(n):
|
||||
# Compute the biases for 0 to max(n, 1) terms included in a median calc
|
||||
biases = np.ones(n + 1)
|
||||
# The original SciPy code is:
|
||||
#
|
||||
# def _median_bias(n):
|
||||
# ii_2 = 2 * np.arange(1., (n - 1) // 2 + 1)
|
||||
# return 1 + np.sum(1. / (ii_2 + 1) - 1. / ii_2)
|
||||
#
|
||||
# This is a sum over (n-1)//2 terms.
|
||||
# The ii_2 terms here for different n are:
|
||||
#
|
||||
# n=0: [] # 0 terms
|
||||
# n=1: [] # 0 terms
|
||||
# n=2: [] # 0 terms
|
||||
# n=3: [2] # 1 term
|
||||
# n=4: [2] # 1 term
|
||||
# n=5: [2, 4] # 2 terms
|
||||
# n=6: [2, 4] # 2 terms
|
||||
# ...
|
||||
#
|
||||
# We can get the terms for 0 through n using a cumulative summation and
|
||||
# indexing:
|
||||
if n >= 3:
|
||||
ii_2 = 2 * np.arange(1, (n - 1) // 2 + 1)
|
||||
sums = 1 + np.cumsum(1.0 / (ii_2 + 1) - 1.0 / ii_2)
|
||||
idx = np.arange(2, n) // 2 - 1
|
||||
biases[3:] = sums[idx]
|
||||
return biases
|
||||
|
||||
|
||||
def _decomp_aggregate_mask(epoch, func, average, freq_sl):
|
||||
_, _, spect = func(epoch)
|
||||
spect = spect[..., freq_sl, :]
|
||||
# Do the averaging here (per epoch) to save memory
|
||||
if average == "mean":
|
||||
spect = np.nanmean(spect, axis=-1)
|
||||
elif average == "median":
|
||||
biases = _median_biases(spect.shape[-1])
|
||||
idx = (~np.isnan(spect)).sum(-1)
|
||||
spect = np.nanmedian(spect, axis=-1) / biases[idx]
|
||||
return spect
|
||||
|
||||
|
||||
def _spect_func(epoch, func, freq_sl, average, *, output="power"):
|
||||
"""Aux function."""
|
||||
# Decide if we should split this to save memory or not, since doing
|
||||
# multiple calls will incur some performance overhead. Eventually we might
|
||||
# want to write (really, go back to) our own spectrogram implementation
|
||||
# that, if possible, averages after each transform, but this will incur
|
||||
# a lot of overhead because of the many Python calls required.
|
||||
kwargs = dict(func=func, average=average, freq_sl=freq_sl)
|
||||
if epoch.nbytes > 10e6:
|
||||
spect = np.apply_along_axis(_decomp_aggregate_mask, -1, epoch, **kwargs)
|
||||
else:
|
||||
spect = _decomp_aggregate_mask(epoch, **kwargs)
|
||||
return spect
|
||||
|
||||
|
||||
def _check_nfft(n, n_fft, n_per_seg, n_overlap):
|
||||
"""Ensure n_fft, n_per_seg and n_overlap make sense."""
|
||||
if n_per_seg is None and n_fft > n:
|
||||
raise ValueError(
|
||||
"If n_per_seg is None n_fft is not allowed to be > "
|
||||
"n_times. If you want zero-padding, you have to set "
|
||||
f"n_per_seg to relevant length. Got n_fft of {n_fft} while"
|
||||
f" signal length is {n}."
|
||||
)
|
||||
n_per_seg = n_fft if n_per_seg is None or n_per_seg > n_fft else n_per_seg
|
||||
n_per_seg = n if n_per_seg > n else n_per_seg
|
||||
if n_overlap >= n_per_seg:
|
||||
raise ValueError(
|
||||
"n_overlap cannot be greater than n_per_seg (or n_fft). Got n_overlap "
|
||||
f"of {n_overlap} while n_per_seg is {n_per_seg}."
|
||||
)
|
||||
return n_fft, n_per_seg, n_overlap
|
||||
|
||||
|
||||
@verbose
|
||||
def psd_array_welch(
|
||||
x,
|
||||
sfreq,
|
||||
fmin=0,
|
||||
fmax=np.inf,
|
||||
n_fft=256,
|
||||
n_overlap=0,
|
||||
n_per_seg=None,
|
||||
n_jobs=None,
|
||||
average="mean",
|
||||
window="hamming",
|
||||
remove_dc=True,
|
||||
*,
|
||||
output="power",
|
||||
verbose=None,
|
||||
):
|
||||
"""Compute power spectral density (PSD) using Welch's method.
|
||||
|
||||
Welch's method is described in :footcite:t:`Welch1967`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : array, shape=(..., n_times)
|
||||
The data to compute PSD from.
|
||||
sfreq : float
|
||||
The sampling frequency.
|
||||
fmin : float
|
||||
The lower frequency of interest.
|
||||
fmax : float
|
||||
The upper frequency of interest.
|
||||
n_fft : int
|
||||
The length of FFT used, must be ``>= n_per_seg`` (default: 256).
|
||||
The segments will be zero-padded if ``n_fft > n_per_seg``.
|
||||
n_overlap : int
|
||||
The number of points of overlap between segments. Will be adjusted
|
||||
to be <= n_per_seg. The default value is 0.
|
||||
n_per_seg : int | None
|
||||
Length of each Welch segment (windowed with a Hamming window). Defaults
|
||||
to None, which sets n_per_seg equal to n_fft.
|
||||
%(n_jobs)s
|
||||
%(average_psd)s
|
||||
|
||||
.. versionadded:: 0.19.0
|
||||
%(window_psd)s
|
||||
|
||||
.. versionadded:: 0.22.0
|
||||
%(remove_dc)s
|
||||
|
||||
output : str
|
||||
The format of the returned ``psds`` array, ``'complex'`` or
|
||||
``'power'``:
|
||||
|
||||
* ``'power'`` : the power spectral density is returned.
|
||||
* ``'complex'`` : the complex fourier coefficients are returned per
|
||||
window.
|
||||
|
||||
.. versionadded:: 1.4.0
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
psds : ndarray, shape (..., n_freqs) or (..., n_freqs, n_segments)
|
||||
The power spectral densities. If ``average='mean`` or
|
||||
``average='median'``, the returned array will have the same shape
|
||||
as the input data plus an additional frequency dimension.
|
||||
If ``average=None``, the returned array will have the same shape as
|
||||
the input data plus two additional dimensions corresponding to
|
||||
frequencies and the unaggregated segments, respectively.
|
||||
freqs : ndarray, shape (n_freqs,)
|
||||
The frequencies.
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 0.14.0
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
_check_option("average", average, (None, False, "mean", "median"))
|
||||
_check_option("output", output, ("power", "complex"))
|
||||
detrend = "constant" if remove_dc else False
|
||||
mode = "complex" if output == "complex" else "psd"
|
||||
n_fft = _ensure_int(n_fft, "n_fft")
|
||||
n_overlap = _ensure_int(n_overlap, "n_overlap")
|
||||
if n_per_seg is not None:
|
||||
n_per_seg = _ensure_int(n_per_seg, "n_per_seg")
|
||||
if average is False:
|
||||
average = None
|
||||
|
||||
dshape = x.shape[:-1]
|
||||
n_times = x.shape[-1]
|
||||
x = x.reshape(-1, n_times)
|
||||
|
||||
# Prep the PSD
|
||||
n_fft, n_per_seg, n_overlap = _check_nfft(n_times, n_fft, n_per_seg, n_overlap)
|
||||
win_size = n_fft / float(sfreq)
|
||||
logger.info(f"Effective window size : {win_size:0.3f} (s)")
|
||||
freqs = np.arange(n_fft // 2 + 1, dtype=float) * (sfreq / n_fft)
|
||||
freq_mask = (freqs >= fmin) & (freqs <= fmax)
|
||||
if not freq_mask.any():
|
||||
raise ValueError(f"No frequencies found between fmin={fmin} and fmax={fmax}")
|
||||
freq_sl = slice(*(np.where(freq_mask)[0][[0, -1]] + [0, 1]))
|
||||
del freq_mask
|
||||
freqs = freqs[freq_sl]
|
||||
|
||||
# Parallelize across first N-1 dimensions
|
||||
logger.debug(
|
||||
f"Spectogram using {n_fft}-point FFT on {n_per_seg} samples with "
|
||||
f"{n_overlap} overlap and {window} window"
|
||||
)
|
||||
|
||||
parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs)
|
||||
_func = partial(
|
||||
spectrogram,
|
||||
detrend=detrend,
|
||||
noverlap=n_overlap,
|
||||
nperseg=n_per_seg,
|
||||
nfft=n_fft,
|
||||
fs=sfreq,
|
||||
window=window,
|
||||
mode=mode,
|
||||
)
|
||||
if np.any(np.isnan(x)):
|
||||
good_mask = ~np.isnan(x)
|
||||
# NaNs originate from annot, so must match for all channels. Note that we CANNOT
|
||||
# use np.testing.assert_allclose() here; it is strict about shapes/broadcasting
|
||||
assert np.allclose(good_mask, good_mask[[0]], equal_nan=True)
|
||||
t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0])
|
||||
x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
|
||||
# weights reflect the number of samples used from each span. For spans longer
|
||||
# than `n_per_seg`, trailing samples may be discarded. For spans shorter than
|
||||
# `n_per_seg`, the wrapped function (`scipy.signal.spectrogram`) automatically
|
||||
# reduces `n_per_seg` to match the span length (with a warning).
|
||||
step = n_per_seg - n_overlap
|
||||
span_lengths = [span.shape[-1] for span in x_splits]
|
||||
weights = [
|
||||
w if w < n_per_seg else w - ((w - n_overlap) % step) for w in span_lengths
|
||||
]
|
||||
agg_func = partial(np.average, weights=weights)
|
||||
if n_jobs > 1:
|
||||
logger.info(
|
||||
f"Data split into {len(x_splits)} (probably unequal) chunks due to "
|
||||
'"bad_*" annotations. Parallelization may be sub-optimal.'
|
||||
)
|
||||
if (np.array(span_lengths) < n_per_seg).any():
|
||||
logger.info(
|
||||
"At least one good data span is shorter than n_per_seg, and will be "
|
||||
"analyzed with a shorter window than the rest of the file."
|
||||
)
|
||||
|
||||
def func(*args, **kwargs):
|
||||
# swallow SciPy warnings caused by short good data spans
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
action="ignore",
|
||||
module="scipy",
|
||||
category=UserWarning,
|
||||
message=r"nperseg = \d+ is greater than input length",
|
||||
)
|
||||
return _func(*args, **kwargs)
|
||||
|
||||
else:
|
||||
x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
|
||||
agg_func = np.concatenate
|
||||
func = _func
|
||||
f_spect = parallel(
|
||||
my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output)
|
||||
for d in x_splits
|
||||
)
|
||||
psds = agg_func(f_spect, axis=0)
|
||||
shape = dshape + (len(freqs),)
|
||||
if average is None:
|
||||
shape = shape + (-1,)
|
||||
psds.shape = shape
|
||||
return psds, freqs
|
||||
1716
mne/time_frequency/spectrum.py
Normal file
1716
mne/time_frequency/spectrum.py
Normal file
File diff suppressed because it is too large
Load Diff
4216
mne/time_frequency/tfr.py
Normal file
4216
mne/time_frequency/tfr.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user