initial commit

2025-08-19 09:13:22 -07:00
parent 28464811d6
commit 0977a3e14d
820 changed files with 1003358 additions and 2 deletions

mne/time_frequency/__init__.py

@@ -0,0 +1,8 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
"""Time frequency analysis tools."""
import lazy_loader as lazy
(__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__)
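# Note (editor's addition): ``lazy.attach_stub`` reads the adjacent ``__init__.pyi``
# stub (the ``__all__`` listing shown in the next file of this diff) and defers the
# actual submodule imports until an attribute such as ``mne.time_frequency.stft`` is
# first accessed, which keeps ``import mne`` fast.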

mne/time_frequency/__init__.pyi

@@ -0,0 +1,81 @@
__all__ = [
"AverageTFR",
"AverageTFRArray",
"BaseTFR",
"CrossSpectralDensity",
"EpochsSpectrum",
"EpochsSpectrumArray",
"EpochsTFR",
"EpochsTFRArray",
"RawTFR",
"RawTFRArray",
"Spectrum",
"SpectrumArray",
"csd_array_fourier",
"csd_array_morlet",
"csd_array_multitaper",
"csd_fourier",
"csd_morlet",
"csd_multitaper",
"csd_tfr",
"dpss_windows",
"fit_iir_model_raw",
"fwhm",
"istft",
"morlet",
"pick_channels_csd",
"psd_array_multitaper",
"psd_array_welch",
"read_csd",
"read_spectrum",
"read_tfrs",
"stft",
"stftfreq",
"tfr_array_morlet",
"tfr_array_multitaper",
"tfr_array_stockwell",
"tfr_morlet",
"tfr_multitaper",
"tfr_stockwell",
"write_tfrs",
]
from ._stft import istft, stft, stftfreq
from ._stockwell import tfr_array_stockwell, tfr_stockwell
from .ar import fit_iir_model_raw
from .csd import (
CrossSpectralDensity,
csd_array_fourier,
csd_array_morlet,
csd_array_multitaper,
csd_fourier,
csd_morlet,
csd_multitaper,
csd_tfr,
pick_channels_csd,
read_csd,
)
from .multitaper import dpss_windows, psd_array_multitaper, tfr_array_multitaper
from .psd import psd_array_welch
from .spectrum import (
EpochsSpectrum,
EpochsSpectrumArray,
Spectrum,
SpectrumArray,
read_spectrum,
)
from .tfr import (
AverageTFR,
AverageTFRArray,
BaseTFR,
EpochsTFR,
EpochsTFRArray,
RawTFR,
RawTFRArray,
fwhm,
morlet,
read_tfrs,
tfr_array_morlet,
tfr_morlet,
tfr_multitaper,
write_tfrs,
)

mne/time_frequency/_stft.py

@@ -0,0 +1,261 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from math import ceil
import numpy as np
from scipy.fft import irfft, rfft, rfftfreq
from ..utils import logger, verbose
@verbose
def stft(x, wsize, tstep=None, verbose=None):
"""STFT Short-Term Fourier Transform using a sine window.
The transformation is designed to be a tight frame that can be
perfectly inverted. It only returns the positive frequencies.
Parameters
----------
x : array, shape (n_signals, n_times)
Containing multi-channels signal.
wsize : int
Length of the STFT window in samples (must be a multiple of 4).
tstep : int
Step between successive windows in samples (must be a multiple of 2,
a divider of wsize, and at most wsize / 2; default: wsize / 2).
%(verbose)s
Returns
-------
X : array, shape (n_signals, wsize // 2 + 1, n_step)
STFT coefficients for positive frequencies with
``n_step = ceil(T / tstep)``.
See Also
--------
istft
stftfreq
"""
if not np.isrealobj(x):
raise ValueError("x is not a real valued array")
if x.ndim == 1:
x = x[None, :]
n_signals, T = x.shape
wsize = int(wsize)
# Errors and warnings
if wsize % 4:
raise ValueError("The window length must be a multiple of 4.")
if tstep is None:
tstep = wsize / 2
tstep = int(tstep)
if (wsize % tstep) or (tstep % 2):
raise ValueError(
"The step size must be a multiple of 2 and a "
"divider of the window length."
)
if tstep > wsize / 2:
raise ValueError("The step size must be smaller than half the window length.")
n_step = int(ceil(T / float(tstep)))
n_freq = wsize // 2 + 1
logger.info(f"Number of frequencies: {n_freq}")
logger.info(f"Number of time steps: {n_step}")
X = np.zeros((n_signals, n_freq, n_step), dtype=np.complex128)
if n_signals == 0:
return X
# Defining sine window
win = np.sin(np.arange(0.5, wsize + 0.5) / wsize * np.pi)
win2 = win**2
swin = np.zeros((n_step - 1) * tstep + wsize)
for t in range(n_step):
swin[t * tstep : t * tstep + wsize] += win2
swin = np.sqrt(wsize * swin)
# Zero-padding and Pre-processing for edges
xp = np.zeros((n_signals, wsize + (n_step - 1) * tstep), dtype=x.dtype)
xp[:, (wsize - tstep) // 2 : (wsize - tstep) // 2 + T] = x
x = xp
for t in range(n_step):
# Framing
wwin = win / swin[t * tstep : t * tstep + wsize]
frame = x[:, t * tstep : t * tstep + wsize] * wwin[None, :]
# FFT
X[:, :, t] = rfft(frame)
return X
def istft(X, tstep=None, Tx=None):
"""ISTFT Inverse Short-Term Fourier Transform using a sine window.
Parameters
----------
X : array, shape (..., wsize // 2 + 1, n_step)
The STFT coefficients for positive frequencies.
tstep : int
Step between successive windows in samples (must be a multiple of 2,
a divider of wsize, and at most wsize / 2; default: wsize / 2).
Tx : int
Length of the returned signal. If None, ``Tx = n_step * tstep``.
Returns
-------
x : array, shape (..., Tx)
Array containing the inverse STFT signal.
See Also
--------
stft
"""
# Errors and warnings
X = np.asarray(X)
if X.ndim < 2:
raise ValueError(f"X must have ndim >= 2, got {X.ndim}")
n_win, n_step = X.shape[-2:]
signal_shape = X.shape[:-2]
if n_win % 2 == 0:
raise ValueError("The number of rows of the STFT matrix must be odd.")
wsize = 2 * (n_win - 1)
if tstep is None:
tstep = wsize / 2
if wsize % tstep:
raise ValueError(
"The step size must be a divider of two times the "
"number of rows of the STFT matrix minus two."
)
if wsize % 2:
raise ValueError("The step size must be a multiple of 2.")
if tstep > wsize / 2:
raise ValueError(
"The step size must be smaller than the number of "
"rows of the STFT matrix minus one."
)
if Tx is None:
Tx = n_step * tstep
T = n_step * tstep
x = np.zeros(signal_shape + (T + wsize - tstep,), dtype=np.float64)
if np.prod(signal_shape) == 0:
return x[..., :Tx]
# Defining sine window
win = np.sin(np.arange(0.5, wsize + 0.5) / wsize * np.pi)
# win = win / norm(win);
# Pre-processing for edges
swin = np.zeros(T + wsize - tstep, dtype=np.float64)
for t in range(n_step):
swin[t * tstep : t * tstep + wsize] += win**2
swin = np.sqrt(swin / wsize)
for t in range(n_step):
# IFFT
frame = irfft(X[..., t], wsize)
# Overlap-add
frame *= win / swin[t * tstep : t * tstep + wsize]
x[..., t * tstep : t * tstep + wsize] += frame
# Truncation
x = x[..., (wsize - tstep) // 2 : (wsize - tstep) // 2 + T + 1]
x = x[..., :Tx].copy()
return x
def stftfreq(wsize, sfreq=None): # noqa: D401
"""Compute frequencies of stft transformation.
Parameters
----------
wsize : int
Size of stft window.
sfreq : float
Sampling frequency. If None, the frequencies are given in cycles per
sample (from 0 to 0.5); otherwise they are given in Hz.
Returns
-------
freqs : array
The positive frequencies returned by stft.
See Also
--------
stft
istft
"""
freqs = rfftfreq(wsize)
if sfreq is not None:
freqs *= float(sfreq)
return freqs
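# Illustrative usage sketch (editor's addition, not part of the MNE source):
# round-trip a random multichannel signal through stft/istft. Because the
# sine-window frames form a tight frame, reconstruction is exact up to
# floating-point error; shapes follow the docstrings above.
_rng = np.random.default_rng(0)
_x = _rng.standard_normal((2, 1000))
_X = stft(_x, wsize=128, tstep=64)     # shape (2, 65, 16): 65 = 128 // 2 + 1, 16 = ceil(1000 / 64)
_freqs = stftfreq(128, sfreq=250.0)    # the 65 positive frequencies in Hz
_x_rec = istft(_X, tstep=64, Tx=1000)  # shape (2, 1000)
assert np.allclose(_x, _x_rec)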
def stft_norm2(X):
"""Compute L2 norm of STFT transform.
It takes into account that stft only returns positive frequencies.
Because a tight frame is used, this quantity is conserved by the STFT.
Parameters
----------
X : 3D complex array
The STFT transforms
Returns
-------
norms2 : array
The squared L2 norm of every row of X.
"""
X2 = (X * X.conj()).real
# compute all L2 coefs and remove first and last frequency once.
norms2 = (
2.0 * X2.sum(axis=2).sum(axis=1)
- np.sum(X2[:, 0, :], axis=1)
- np.sum(X2[:, -1, :], axis=1)
)
return norms2
def stft_norm1(X):
"""Compute L1 norm of STFT transform.
It takes into account that stft only returns positive frequencies.
Parameters
----------
X : 3D complex array
The STFT transforms
Returns
-------
norms : array
The L1 norm of every row of X.
"""
X_abs = np.abs(X)
# compute all L1 coefs and remove first and last frequency once.
norms = (
2.0 * X_abs.sum(axis=(1, 2))
- np.sum(X_abs[:, 0, :], axis=1)
- np.sum(X_abs[:, -1, :], axis=1)
)
return norms

mne/time_frequency/_stockwell.py

@@ -0,0 +1,322 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from copy import deepcopy
import numpy as np
from scipy.fft import fft, fftfreq, ifft
from .._fiff.pick import _pick_data_channels, pick_info
from ..parallel import parallel_func
from ..utils import _validate_type, legacy, logger, verbose
from .tfr import AverageTFRArray, _ensure_slice, _get_data
def _check_input_st(x_in, n_fft):
"""Aux function."""
# determine the FFT length (next power of two) and how much zero padding is needed
n_times = x_in.shape[-1]
def _is_power_of_two(n):
return not (n > 0 and (n & (n - 1)))
if n_fft is None or (not _is_power_of_two(n_fft) and n_times > n_fft):
# Compute next power of 2
n_fft = 2 ** int(np.ceil(np.log2(n_times)))
elif n_fft < n_times:
raise ValueError(
f"n_fft cannot be smaller than signal size. Got {n_fft} < {n_times}."
)
if n_times < n_fft:
logger.info(
f'The input signal is shorter ({x_in.shape[-1]}) than "n_fft" ({n_fft}). '
"Applying zero padding."
)
zero_pad = n_fft - n_times
pad_array = np.zeros(x_in.shape[:-1] + (zero_pad,), x_in.dtype)
x_in = np.concatenate((x_in, pad_array), axis=-1)
else:
zero_pad = 0
return x_in, n_fft, zero_pad
def _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width):
"""Precompute stockwell Gaussian windows (in the freq domain)."""
tw = fftfreq(n_samp, 1.0 / sfreq) / n_samp
tw = np.r_[tw[:1], tw[1:][::-1]]
k = width  # 1 for the classical Stockwell transform
f_range = np.arange(start_f, stop_f, 1)
windows = np.empty((len(f_range), len(tw)), dtype=np.complex128)
for i_f, f in enumerate(f_range):
if f == 0.0:
window = np.ones(len(tw))
else:
window = (f / (np.sqrt(2.0 * np.pi) * k)) * np.exp(
-0.5 * (1.0 / k**2.0) * (f**2.0) * tw**2.0
)
window /= window.sum() # normalisation
windows[i_f] = fft(window)
return windows
def _st(x, start_f, windows):
"""Compute ST based on Ali Moukadem MATLAB code (used in tests)."""
from scipy.fft import fft, ifft
n_samp = x.shape[-1]
ST = np.empty(x.shape[:-1] + (len(windows), n_samp), dtype=np.complex128)
# do the work
Fx = fft(x)
XF = np.concatenate([Fx, Fx], axis=-1)
for i_f, window in enumerate(windows):
f = start_f + i_f
ST[..., i_f, :] = ifft(XF[..., f : f + n_samp] * window)
return ST
def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W):
"""Aux function."""
decim = _ensure_slice(decim)
n_samp = x.shape[-1]
decim_indices = decim.indices(n_samp - zero_pad)
n_out = len(range(*decim_indices))
psd = np.empty((len(W), n_out))
itc = np.empty_like(psd) if compute_itc else None
X = fft(x)
XX = np.concatenate([X, X], axis=-1)
for i_f, window in enumerate(W):
f = start_f + i_f
ST = ifft(XX[:, f : f + n_samp] * window)
TFR = ST[:, slice(*decim_indices)]
TFR_abs = np.abs(TFR)
TFR_abs[TFR_abs == 0] = 1.0
if compute_itc:
TFR /= TFR_abs
itc[i_f] = np.abs(np.mean(TFR, axis=0))
TFR_abs *= TFR_abs
psd[i_f] = np.mean(TFR_abs, axis=0)
return psd, itc
def _compute_freqs_st(fmin, fmax, n_fft, sfreq):
from scipy.fft import fftfreq
freqs = fftfreq(n_fft, 1.0 / sfreq)
if fmin is None:
fmin = freqs[freqs > 0][0]
if fmax is None:
fmax = freqs.max()
start_f = np.abs(freqs - fmin).argmin()
stop_f = np.abs(freqs - fmax).argmin()
freqs = freqs[start_f:stop_f]
return start_f, stop_f, freqs
@verbose
def tfr_array_stockwell(
data,
sfreq,
fmin=None,
fmax=None,
n_fft=None,
width=1.0,
decim=1,
return_itc=False,
n_jobs=None,
*,
verbose=None,
):
"""Compute power and intertrial coherence using Stockwell (S) transform.
Same computation as `~mne.time_frequency.tfr_stockwell`, but operates on
:class:`NumPy arrays <numpy.ndarray>` instead of `~mne.Epochs` objects.
See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
for more information.
Parameters
----------
data : ndarray, shape (n_epochs, n_channels, n_times)
The signal to transform.
sfreq : float
The sampling frequency.
fmin : None, float
The minimum frequency to include. If None, defaults to the minimum FFT
frequency greater than zero.
fmax : None, float
The maximum frequency to include. If None, defaults to the maximum FFT
frequency.
n_fft : int | None
The length of the windows used for FFT. If None, it defaults to the
next power of 2 larger than the signal length.
width : float
The width of the Gaussian window. Values < 1 increase temporal
resolution; values > 1 increase frequency resolution. Defaults to 1.0
(the classical S-transform).
%(decim_tfr)s
return_itc : bool
Return intertrial coherence (ITC) as well as averaged power.
%(n_jobs)s
%(verbose)s
Returns
-------
st_power : ndarray
The power of the Stockwell-transformed data.
The last two dimensions are frequency and time.
itc : ndarray
The intertrial coherence. Only returned if return_itc is True.
freqs : ndarray
The frequencies.
See Also
--------
mne.time_frequency.tfr_stockwell
mne.time_frequency.tfr_multitaper
mne.time_frequency.tfr_array_multitaper
mne.time_frequency.tfr_morlet
mne.time_frequency.tfr_array_morlet
References
----------
.. footbibliography::
"""
_validate_type(data, np.ndarray, "data")
if data.ndim != 3:
raise ValueError(
"data must be 3D with shape (n_epochs, n_channels, n_times), "
f"got {data.shape}"
)
decim = _ensure_slice(decim)
_, n_channels, n_out = data[..., decim].shape
data, n_fft_, zero_pad = _check_input_st(data, n_fft)
start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, n_fft_, sfreq)
W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)
n_freq = stop_f - start_f
psd = np.empty((n_channels, n_freq, n_out))
itc = np.empty((n_channels, n_freq, n_out)) if return_itc else None
parallel, my_st, n_jobs = parallel_func(_st_power_itc, n_jobs, verbose=verbose)
tfrs = parallel(
my_st(data[:, c, :], start_f, return_itc, zero_pad, decim, W)
for c in range(n_channels)
)
for c, (this_psd, this_itc) in enumerate(iter(tfrs)):
psd[c] = this_psd
if this_itc is not None:
itc[c] = this_itc
return psd, itc, freqs
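# Illustrative usage sketch (editor's addition, not part of the MNE source):
# Stockwell power and inter-trial coherence for a small synthetic dataset;
# the input shape is (n_epochs, n_channels, n_times) as documented above.
_rng = np.random.default_rng(42)
_data = _rng.standard_normal((5, 3, 256))  # 5 epochs, 3 channels, 256 samples
_power, _itc, _freqs = tfr_array_stockwell(
    _data, sfreq=128.0, fmin=4.0, fmax=30.0, return_itc=True
)
# _power.shape == _itc.shape == (3, len(_freqs), 256)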
@legacy(alt='.compute_tfr(method="stockwell", freqs="auto")')
@verbose
def tfr_stockwell(
inst,
fmin=None,
fmax=None,
n_fft=None,
width=1.0,
decim=1,
return_itc=False,
n_jobs=None,
verbose=None,
):
"""Compute Time-Frequency Representation (TFR) using Stockwell Transform.
Same computation as `~mne.time_frequency.tfr_array_stockwell`, but operates
on `~mne.Epochs` objects instead of :class:`NumPy arrays <numpy.ndarray>`.
See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
for more information.
Parameters
----------
inst : Epochs | Evoked
The epochs or evoked object.
fmin : None, float
The minimum frequency to include. If None, defaults to the minimum FFT
frequency greater than zero.
fmax : None, float
The maximum frequency to include. If None, defaults to the maximum FFT
frequency.
n_fft : int | None
The length of the windows used for FFT. If None, it defaults to the
next power of 2 larger than the signal length.
width : float
The width of the Gaussian window. Values < 1 increase temporal
resolution; values > 1 increase frequency resolution. Defaults to 1.0
(the classical S-transform).
decim : int
The decimation factor on the time axis. To reduce memory usage.
return_itc : bool
Return intertrial coherence (ITC) as well as averaged power.
n_jobs : int
The number of jobs to run in parallel (over channels).
%(verbose)s
Returns
-------
power : AverageTFR
The averaged power.
itc : AverageTFR
The intertrial coherence. Only returned if return_itc is True.
See Also
--------
mne.time_frequency.tfr_array_stockwell
mne.time_frequency.tfr_multitaper
mne.time_frequency.tfr_array_multitaper
mne.time_frequency.tfr_morlet
mne.time_frequency.tfr_array_morlet
Notes
-----
.. versionadded:: 0.9.0
References
----------
.. footbibliography::
"""
# the verbose decorator is used because the subfunctions are verbose
data = _get_data(inst, return_itc)
picks = _pick_data_channels(inst.info)
info = pick_info(inst.info, picks)
data = data[:, picks, :]
decim = _ensure_slice(decim)
power, itc, freqs = tfr_array_stockwell(
data,
sfreq=info["sfreq"],
fmin=fmin,
fmax=fmax,
n_fft=n_fft,
width=width,
decim=decim,
return_itc=return_itc,
n_jobs=n_jobs,
)
times = inst.times[decim].copy()
nave = len(data)
out = AverageTFRArray(
info=info,
data=power,
times=times,
freqs=freqs,
nave=nave,
method="stockwell-power",
)
if return_itc:
out = (
out,
AverageTFRArray(
info=deepcopy(info),
data=itc,
times=times.copy(),
freqs=freqs.copy(),
nave=nave,
method="stockwell-itc",
),
)
return out
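# Illustrative usage sketch (editor's addition; assumes an existing ``mne.Epochs``
# object named ``epochs``). The function is marked legacy in favour of
# ``epochs.compute_tfr(method="stockwell", freqs="auto")``:
#
#     power, itc = tfr_stockwell(epochs, fmin=4.0, fmax=40.0, return_itc=True)
#     power.plot(picks=[0])  # TFR of the first channel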

mne/time_frequency/ar.py

@@ -0,0 +1,76 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from scipy import linalg
from .._fiff.pick import _picks_by_type, _picks_to_idx, pick_info
from ..defaults import _handle_default
from ..utils import _apply_scaling_array, verbose
def _yule_walker(X, order=1):
"""Compute Yule-Walker (adapted from statsmodels).
Operates in-place.
"""
assert X.ndim == 2
denom = X.shape[-1] - np.arange(order + 1)
r = np.zeros(order + 1, np.float64)
for di, d in enumerate(X):
d -= d.mean()
r[0] += np.dot(d, d)
for k in range(1, order + 1):
r[k] += np.dot(d[0:-k], d[k:])
r /= denom * len(X)
rho = linalg.solve(linalg.toeplitz(r[:-1]), r[1:])
sigmasq = r[0] - (r[1:] * rho).sum()
return rho, np.sqrt(sigmasq)
@verbose
def fit_iir_model_raw(raw, order=2, picks=None, tmin=None, tmax=None, verbose=None):
r"""Fit an AR model to raw data and creates the corresponding IIR filter.
The computed filter is fitted to data from all of the picked channels,
with frequency response given by the standard IIR formula:
.. math::
H(e^{jw}) = \frac{1}{a[0] + a[1]e^{-jw} + ... + a[n]e^{-jnw}}
Parameters
----------
raw : Raw object
An instance of Raw.
order : int
Order of the AR model.
%(picks_good_data)s
tmin : float
The beginning of the time interval in seconds.
tmax : float
The end of the time interval in seconds.
%(verbose)s
Returns
-------
b : ndarray
Numerator filter coefficients.
a : ndarray
Denominator filter coefficients.
"""
start, stop = None, None
if tmin is not None:
start = raw.time_as_index(tmin)[0]
if tmax is not None:
stop = raw.time_as_index(tmax)[0] + 1
picks = _picks_to_idx(raw.info, picks)
data = raw[picks, start:stop][0]
# rescale data to similar levels
picks_list = _picks_by_type(pick_info(raw.info, picks))
scalings = _handle_default("scalings_cov_rank", None)
_apply_scaling_array(data, picks_list=picks_list, scalings=scalings)
# do the fitting
coeffs, _ = _yule_walker(data, order=order)
return np.array([1.0]), np.concatenate(([1.0], -coeffs))
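# Illustrative sketch (editor's addition, not part of the MNE source): check the
# Yule-Walker helper on a synthetic AR(2) process generated with scipy.signal.lfilter;
# fit_iir_model_raw() would then return b = [1.] and a = [1., -rho[0], -rho[1]].
from scipy.signal import lfilter

_rng = np.random.default_rng(0)
_true_a = np.array([1.0, -0.75, 0.3])  # x[n] = 0.75 * x[n-1] - 0.3 * x[n-2] + noise
_sig = lfilter([1.0], _true_a, _rng.standard_normal((1, 10000)))
_rho, _sigma = _yule_walker(_sig, order=2)
# _rho is approximately [0.75, -0.3] and _sigma is approximately 1.0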

mne/time_frequency/csd.py

File diff suppressed because it is too large.

mne/time_frequency/multitaper.py

@@ -0,0 +1,555 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
# Parts of this code were copied from NiTime http://nipy.sourceforge.net/nitime
import numpy as np
from scipy.fft import rfft, rfftfreq
from scipy.integrate import trapezoid
from scipy.signal import get_window
from scipy.signal.windows import dpss as sp_dpss
from ..parallel import parallel_func
from ..utils import _check_option, logger, verbose, warn
def dpss_windows(N, half_nbw, Kmax, *, sym=True, norm=None, low_bias=True):
"""Compute Discrete Prolate Spheroidal Sequences.
Will give sequences of orders [0, Kmax-1] for a given frequency-spacing
multiple NW and sequence length N.
.. note:: Copied from NiTime.
Parameters
----------
N : int
Sequence length.
half_nbw : float
Standardized half bandwidth corresponding to 2 * half_bw = BW*f0
= BW*N/dt but with dt taken as 1.
Kmax : int
Number of DPSS windows to return (orders 0 through Kmax-1).
sym : bool
Whether to generate a symmetric window (``True``, for filter design) or
a periodic window (``False``, for spectral analysis). Default is
``True``.
.. versionadded:: 1.3
norm : 2 | ``'approximate'`` | ``'subsample'`` | None
Window normalization method. If ``'approximate'`` or ``'subsample'``,
windows are normalized by the maximum, and a correction scale-factor
for even-length windows is applied either using
``N**2/(N**2+half_nbw)`` ("approximate") or a FFT-based subsample shift
("subsample"). ``2`` uses the L2 norm. ``None`` (the default) uses
``"approximate"`` when ``Kmax=None`` and ``2`` otherwise.
.. versionadded:: 1.3
low_bias : bool
Keep only tapers with eigenvalues > 0.9.
Returns
-------
v, e : tuple,
The v array contains DPSS windows shaped (Kmax, N).
e are the eigenvalues.
Notes
-----
Tridiagonal form of DPSS calculation from :footcite:`Slepian1978`.
References
----------
.. footbibliography::
"""
dpss, eigvals = sp_dpss(N, half_nbw, Kmax, sym=sym, norm=norm, return_ratios=True)
if low_bias:
idx = eigvals > 0.9
if not idx.any():
warn("Could not properly use low_bias, keeping lowest-bias taper")
idx = [np.argmax(eigvals)]
dpss, eigvals = dpss[idx], eigvals[idx]
assert len(dpss) > 0 # should never happen
assert dpss.shape[1] == N # old nitime bug
return dpss, eigvals
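# Illustrative usage sketch (editor's addition, not part of the MNE source):
# DPSS tapers for a 1000-sample window with the module's default normalized
# half-bandwidth of 4. Up to 2 * 4 - 1 = 7 tapers are well concentrated, and
# low_bias=True drops any taper whose eigenvalue falls below 0.9.
_tapers, _eigvals = dpss_windows(1000, half_nbw=4.0, Kmax=7)
# _tapers.shape == (len(_eigvals), 1000); the leading eigenvalues are close to 1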
def _psd_from_mt_adaptive(x_mt, eigvals, freq_mask, max_iter=250, return_weights=False):
r"""Use iterative procedure to compute the PSD from tapered spectra.
.. note:: Modified from NiTime.
Parameters
----------
x_mt : array, shape=(n_signals, n_tapers, n_freqs)
The DFTs of the tapered sequences (only positive frequencies)
eigvals : array, length n_tapers
The eigenvalues of the DPSS tapers
freq_mask : array
Frequency indices to keep
max_iter : int
Maximum number of iterations for weight computation.
return_weights : bool
Also return the weights
Returns
-------
psd : array, shape=(n_signals, np.sum(freq_mask))
The computed PSDs
weights : array shape=(n_signals, n_tapers, np.sum(freq_mask))
The weights used to combine the tapered spectra
Notes
-----
The weights to use for making the multitaper estimate, such that
:math:`S_{mt} = \sum_{k} |w_k|^2S_k^{mt} / \sum_{k} |w_k|^2`
"""
n_signals, n_tapers, n_freqs = x_mt.shape
if len(eigvals) != n_tapers:
raise ValueError("Need one eigenvalue for each taper")
if n_tapers < 3:
raise ValueError("Not enough tapers to compute adaptive weights.")
rt_eig = np.sqrt(eigvals)
# estimate the variance from an estimate with fixed weights
psd_est = _psd_from_mt(x_mt, rt_eig[np.newaxis, :, np.newaxis])
x_var = trapezoid(psd_est, dx=np.pi / n_freqs) / (2 * np.pi)
del psd_est
# allocate space for output
psd = np.empty((n_signals, np.sum(freq_mask)))
# only keep the frequencies of interest
x_mt = x_mt[:, :, freq_mask]
if return_weights:
weights = np.empty((n_signals, n_tapers, psd.shape[1]))
for i, (xk, var) in enumerate(zip(x_mt, x_var)):
# combine the SDFs in the traditional way in order to estimate
# the variance of the timeseries
# The process is to iteratively switch solving for the following
# two expressions:
# (1) Adaptive Multitaper SDF:
# S^{mt}(f) = [ sum |d_k(f)|^2 S_k(f) ]/ sum |d_k(f)|^2
#
# (2) Weights
# d_k(f) = [sqrt(lam_k) S^{mt}(f)] / [lam_k S^{mt}(f) + E{B_k(f)}]
#
# Where lam_k are the eigenvalues corresponding to the DPSS tapers,
# and the expected value of the broadband bias function
# E{B_k(f)} is replaced by its full-band integration
# (1/2pi) int_{-pi}^{pi} E{B_k(f)} = sig^2(1-lam_k)
# start with an estimate from incomplete data--the first 2 tapers
psd_iter = _psd_from_mt(xk[:2, :], rt_eig[:2, np.newaxis])
err = np.zeros_like(xk)
for n in range(max_iter):
d_k = psd_iter / (
eigvals[:, np.newaxis] * psd_iter + (1 - eigvals[:, np.newaxis]) * var
)
d_k *= rt_eig[:, np.newaxis]
# Test for convergence -- this is overly conservative, since
# iteration only stops when all frequencies have converged.
# A better approach is to iterate separately for each freq, but
# that is a nonvectorized algorithm.
# Take the RMS difference in weights from the previous iterate
# across frequencies. If the maximum RMS error across freqs is
# less than 1e-10, then we're converged
err -= d_k
if np.max(np.mean(err**2, axis=0)) < 1e-10:
break
# update the iterative estimate with this d_k
psd_iter = _psd_from_mt(xk, d_k)
err = d_k
if n == max_iter - 1:
warn("Iterative multi-taper PSD computation did not converge.")
psd[i, :] = psd_iter
if return_weights:
weights[i, :, :] = d_k
if return_weights:
return psd, weights
else:
return psd
def _psd_from_mt(x_mt, weights):
"""Compute PSD from tapered spectra.
Parameters
----------
x_mt : array, shape=(..., n_tapers, n_freqs)
Tapered spectra
weights : array, shape=(n_tapers,)
Weights used to combine the tapered spectra
Returns
-------
psd : array, shape=(..., n_freqs)
The computed PSD
"""
psd = weights * x_mt
psd *= psd.conj()
psd = psd.real.sum(axis=-2)
psd *= 2 / (weights * weights.conj()).real.sum(axis=-2)
return psd
def _csd_from_mt(x_mt, y_mt, weights_x, weights_y):
"""Compute CSD from tapered spectra.
Parameters
----------
x_mt : array, shape=(..., n_tapers, n_freqs)
Tapered spectra for x
y_mt : array, shape=(..., n_tapers, n_freqs)
Tapered spectra for y
weights_x : array, shape=(n_tapers,)
Weights used to combine the tapered spectra of x_mt
weights_y : array, shape=(n_tapers,)
Weights used to combine the tapered spectra of y_mt
Returns
-------
csd: array
The computed CSD
"""
csd = np.sum(weights_x * x_mt * (weights_y * y_mt).conj(), axis=-2)
denom = np.sqrt((weights_x * weights_x.conj()).real.sum(axis=-2)) * np.sqrt(
(weights_y * weights_y.conj()).real.sum(axis=-2)
)
csd *= 2 / denom
return csd
def _mt_spectra(x, dpss, sfreq, n_fft=None, remove_dc=True):
"""Compute tapered spectra.
Parameters
----------
x : array, shape=(..., n_times)
Input signal
dpss : array, shape=(n_tapers, n_times)
The tapers
sfreq : float
The sampling frequency
n_fft : int | None
Length of the FFT. If None, the number of samples in the input signal
will be used.
%(remove_dc)s
Returns
-------
x_mt : array, shape=(..., n_tapers, n_freqs)
The tapered spectra
freqs : array, shape=(n_freqs,)
The frequency points in Hz of the spectra
"""
if n_fft is None:
n_fft = x.shape[-1]
# remove mean (do not use in-place subtraction as it may modify input x)
if remove_dc:
x = x - np.mean(x, axis=-1, keepdims=True)
# only keep positive frequencies
freqs = rfftfreq(n_fft, 1.0 / sfreq)
# The following is equivalent to this, but uses less memory:
# x_mt = fftpack.fft(x[:, np.newaxis, :] * dpss, n=n_fft)
n_tapers = dpss.shape[0] if dpss.ndim > 1 else 1
x_mt = np.zeros(x.shape[:-1] + (n_tapers, len(freqs)), dtype=np.complex128)
for idx, sig in enumerate(x):
x_mt[idx] = rfft(sig[..., np.newaxis, :] * dpss, n=n_fft)
# Adjust DC and maybe Nyquist, depending on one-sided transform
x_mt[..., 0] /= np.sqrt(2.0)
if n_fft % 2 == 0:
x_mt[..., -1] /= np.sqrt(2.0)
return x_mt, freqs
@verbose
def _compute_mt_params(n_times, sfreq, bandwidth, low_bias, adaptive, verbose=None):
"""Triage windowing and multitaper parameters."""
# Compute standardized half-bandwidth
if isinstance(bandwidth, str):
logger.info(f' Using standard spectrum estimation with "{bandwidth}" window')
window_fun = get_window(bandwidth, n_times)[np.newaxis]
return window_fun, np.ones(1), False
if bandwidth is not None:
half_nbw = float(bandwidth) * n_times / (2.0 * sfreq)
else:
half_nbw = 4.0
if half_nbw < 0.5:
raise ValueError(
f"bandwidth value {bandwidth} yields a normalized half-bandwidth of "
f"{half_nbw} < 0.5, use a value of at least {sfreq / n_times}"
)
# Compute DPSS windows
n_tapers_max = int(2 * half_nbw)
window_fun, eigvals = dpss_windows(
n_times, half_nbw, n_tapers_max, sym=False, low_bias=low_bias
)
logger.info(
f" Using multitaper spectrum estimation with {len(eigvals)} DPSS windows"
)
if adaptive and len(eigvals) < 3:
warn(
"Not adaptively combining the spectral estimators due to a "
f"low number of tapers ({len(eigvals)} < 3)."
)
adaptive = False
return window_fun, eigvals, adaptive
@verbose
def psd_array_multitaper(
x,
sfreq,
fmin=0.0,
fmax=np.inf,
bandwidth=None,
adaptive=False,
low_bias=True,
normalization="length",
remove_dc=True,
output="power",
n_jobs=None,
*,
max_iter=150,
verbose=None,
):
r"""Compute power spectral density (PSD) using a multi-taper method.
The power spectral density is computed with DPSS
tapers :footcite:p:`Slepian1978`.
Parameters
----------
x : array, shape=(..., n_times)
The data to compute PSD from.
sfreq : float
The sampling frequency.
%(fmin_fmax_psd)s
bandwidth : float
Frequency bandwidth of the multi-taper window function in Hz. For a
given frequency, frequencies at ``± bandwidth / 2`` are smoothed
together. The default value is a bandwidth of
``8 * (sfreq / n_times)``.
adaptive : bool
Use adaptive weights to combine the tapered spectra into PSD
(slow, use n_jobs >> 1 to speed up computation).
low_bias : bool
Only use tapers with more than 90%% spectral concentration within
bandwidth.
%(normalization)s
%(remove_dc)s
output : str
The format of the returned ``psds`` array, ``'complex'`` or
``'power'``:
* ``'power'`` : the power spectral density is returned.
* ``'complex'`` : the complex fourier coefficients are returned per
taper.
%(n_jobs)s
%(max_iter_multitaper)s
%(verbose)s
Returns
-------
psds : ndarray, shape (..., n_freqs) or (..., n_tapers, n_freqs)
The power spectral densities. All dimensions up to the last (or the
last two if ``output='complex'``) will be the same as input.
freqs : array
The frequency points in Hz of the PSD.
weights : ndarray
The weights used for averaging across tapers. Only returned if
``output='complex'``.
See Also
--------
csd_multitaper
mne.io.Raw.compute_psd
mne.Epochs.compute_psd
mne.Evoked.compute_psd
Notes
-----
.. versionadded:: 0.14.0
References
----------
.. footbibliography::
"""
_check_option("normalization", normalization, ["length", "full"])
# Reshape data so its 2-D for parallelization
ndim_in = x.ndim
x = np.atleast_2d(x)
n_times = x.shape[-1]
dshape = x.shape[:-1]
x = x.reshape(-1, n_times)
dpss, eigvals, adaptive = _compute_mt_params(
n_times, sfreq, bandwidth, low_bias, adaptive
)
n_tapers = len(dpss)
weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis]
# decide which frequencies to keep
freqs = rfftfreq(n_times, 1.0 / sfreq)
freq_mask = (freqs >= fmin) & (freqs <= fmax)
freqs = freqs[freq_mask]
n_freqs = len(freqs)
if output == "complex":
psd = np.zeros((x.shape[0], n_tapers, n_freqs), dtype="complex")
else:
psd = np.zeros((x.shape[0], n_freqs))
# Let's go in up to 50 MB chunks of signals to save memory
n_chunk = max(50000000 // (len(freq_mask) * len(eigvals) * 16), 1)
offsets = np.concatenate((np.arange(0, x.shape[0], n_chunk), [x.shape[0]]))
for start, stop in zip(offsets[:-1], offsets[1:]):
x_mt = _mt_spectra(x[start:stop], dpss, sfreq, remove_dc=remove_dc)[0]
if output == "power":
if not adaptive:
psd[start:stop] = _psd_from_mt(x_mt[:, :, freq_mask], weights)
else:
parallel, my_psd_from_mt_adaptive, n_jobs = parallel_func(
_psd_from_mt_adaptive, n_jobs
)
n_splits = min(stop - start, n_jobs)
out = parallel(
my_psd_from_mt_adaptive(x, eigvals, freq_mask, max_iter)
for x in np.array_split(x_mt, n_splits)
)
psd[start:stop] = np.concatenate(out)
else:
psd[start:stop] = x_mt[:, :, freq_mask]
if normalization == "full":
psd /= sfreq
# Combining/reshaping to original data shape
last_dims = (n_freqs,) if output == "power" else (n_tapers, n_freqs)
psd.shape = dshape + last_dims
if ndim_in == 1:
psd = psd[0]
if output == "complex":
return psd, freqs, weights
else:
return psd, freqs
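# Illustrative usage sketch (editor's addition, not part of the MNE source):
# multitaper PSD of a noisy 10 Hz sinusoid. With the default bandwidth of
# 8 * (sfreq / n_times) = 4 Hz here, the 10 Hz peak is smoothed over about +/- 2 Hz.
_rng = np.random.default_rng(7)
_sfreq = 500.0
_times = np.arange(0, 2, 1 / _sfreq)  # 2 s of data -> n_times = 1000
_sig = np.sin(2 * np.pi * 10 * _times) + _rng.standard_normal(_times.size)
_psd, _freqs = psd_array_multitaper(_sig, _sfreq, fmin=1.0, fmax=50.0, adaptive=True)
# _psd.shape == _freqs.shape, with a broad peak around 10 Hz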
@verbose
def tfr_array_multitaper(
data,
sfreq,
freqs,
n_cycles=7.0,
zero_mean=True,
time_bandwidth=4.0,
use_fft=True,
decim=1,
output="complex",
n_jobs=None,
*,
verbose=None,
):
"""Compute Time-Frequency Representation (TFR) using DPSS tapers.
Same computation as `~mne.time_frequency.tfr_multitaper`, but operates on
:class:`NumPy arrays <numpy.ndarray>` instead of `~mne.Epochs` or
`~mne.Evoked` objects.
Parameters
----------
data : array of shape (n_epochs, n_channels, n_times)
The epochs.
sfreq : float
Sampling frequency of the data in Hz.
%(freqs_tfr_array)s
%(n_cycles_tfr)s
zero_mean : bool
If True, make sure the wavelets have a mean of zero. Defaults to True.
%(time_bandwidth_tfr)s
use_fft : bool
Use the FFT for convolutions or not. Defaults to True.
%(decim_tfr)s
output : str, default 'complex'
* ``'complex'`` : single trial per taper complex values.
* ``'power'`` : single trial power.
* ``'phase'`` : single trial per taper phase.
* ``'avg_power'`` : average of single trial power.
* ``'itc'`` : inter-trial coherence.
* ``'avg_power_itc'`` : average of single trial power and inter-trial
coherence across trials.
%(n_jobs)s
The parallelization is implemented across channels.
%(verbose)s
Returns
-------
out : array
Time frequency transform of ``data``.
- if ``output in ('complex', 'phase')``, array of shape
``(n_epochs, n_chans, n_tapers, n_freqs, n_times)``
- if ``output`` is ``'power'``, array of shape ``(n_epochs, n_chans,
n_freqs, n_times)``
- else, array of shape ``(n_chans, n_freqs, n_times)``
If ``output`` is ``'avg_power_itc'``, the real values in ``out``
contain the average power and the imaginary values contain the
inter-trial coherence: :math:`out = power_{avg} + i * ITC`.
See Also
--------
mne.time_frequency.tfr_multitaper
mne.time_frequency.tfr_morlet
mne.time_frequency.tfr_array_morlet
mne.time_frequency.tfr_stockwell
mne.time_frequency.tfr_array_stockwell
Notes
-----
%(temporal_window_tfr_intro)s
%(temporal_window_tfr_multitaper_notes)s
%(time_bandwidth_tfr_notes)s
.. versionadded:: 0.14.0
"""
from .tfr import _compute_tfr
return _compute_tfr(
data,
freqs,
sfreq=sfreq,
method="multitaper",
n_cycles=n_cycles,
zero_mean=zero_mean,
time_bandwidth=time_bandwidth,
use_fft=use_fft,
decim=decim,
output=output,
n_jobs=n_jobs,
verbose=verbose,
)
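# Illustrative usage sketch (editor's addition, not part of the MNE source):
# single-trial multitaper power for a small synthetic dataset; the output
# shape follows the docstring above.
_rng = np.random.default_rng(3)
_epochs_data = _rng.standard_normal((4, 2, 500))  # 4 epochs, 2 channels, 2 s at 250 Hz
_freqs = np.arange(5.0, 30.0, 2.0)
_power = tfr_array_multitaper(
    _epochs_data, sfreq=250.0, freqs=_freqs, n_cycles=_freqs / 2.0, output="power"
)
# _power.shape == (4, 2, len(_freqs), 500)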

mne/time_frequency/psd.py

@@ -0,0 +1,272 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import warnings
from functools import partial
import numpy as np
from scipy.signal import spectrogram
from ..parallel import parallel_func
from ..utils import _check_option, _ensure_int, logger, verbose
from ..utils.numerics import _mask_to_onsets_offsets
# adapted from SciPy
# https://github.com/scipy/scipy/blob/f71e7fad717801c4476312fe1e23f2dfbb4c9d7f/scipy/signal/_spectral_py.py#L2019 # noqa: E501
def _median_biases(n):
# Compute the biases for 0 to max(n, 1) terms included in a median calc
biases = np.ones(n + 1)
# The original SciPy code is:
#
# def _median_bias(n):
# ii_2 = 2 * np.arange(1., (n - 1) // 2 + 1)
# return 1 + np.sum(1. / (ii_2 + 1) - 1. / ii_2)
#
# This is a sum over (n-1)//2 terms.
# The ii_2 terms here for different n are:
#
# n=0: [] # 0 terms
# n=1: [] # 0 terms
# n=2: [] # 0 terms
# n=3: [2] # 1 term
# n=4: [2] # 1 term
# n=5: [2, 4] # 2 terms
# n=6: [2, 4] # 2 terms
# ...
#
# We can get the terms for 0 through n using a cumulative summation and
# indexing:
if n >= 3:
ii_2 = 2 * np.arange(1, (n - 1) // 2 + 1)
sums = 1 + np.cumsum(1.0 / (ii_2 + 1) - 1.0 / ii_2)
idx = np.arange(2, n) // 2 - 1
biases[3:] = sums[idx]
return biases
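# Quick check (editor's addition): for n = 3 the single term gives
# 1 + (1/3 - 1/2) = 5/6, so _median_biases(3) == [1., 1., 1., 0.8333...],
# matching the alternating series 1 - 1/2 + 1/3 for the median of 3 segments.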
def _decomp_aggregate_mask(epoch, func, average, freq_sl):
_, _, spect = func(epoch)
spect = spect[..., freq_sl, :]
# Do the averaging here (per epoch) to save memory
if average == "mean":
spect = np.nanmean(spect, axis=-1)
elif average == "median":
biases = _median_biases(spect.shape[-1])
idx = (~np.isnan(spect)).sum(-1)
spect = np.nanmedian(spect, axis=-1) / biases[idx]
return spect
def _spect_func(epoch, func, freq_sl, average, *, output="power"):
"""Aux function."""
# Decide if we should split this to save memory or not, since doing
# multiple calls will incur some performance overhead. Eventually we might
# want to write (really, go back to) our own spectrogram implementation
# that, if possible, averages after each transform, but this will incur
# a lot of overhead because of the many Python calls required.
kwargs = dict(func=func, average=average, freq_sl=freq_sl)
if epoch.nbytes > 10e6:
spect = np.apply_along_axis(_decomp_aggregate_mask, -1, epoch, **kwargs)
else:
spect = _decomp_aggregate_mask(epoch, **kwargs)
return spect
def _check_nfft(n, n_fft, n_per_seg, n_overlap):
"""Ensure n_fft, n_per_seg and n_overlap make sense."""
if n_per_seg is None and n_fft > n:
raise ValueError(
"If n_per_seg is None n_fft is not allowed to be > "
"n_times. If you want zero-padding, you have to set "
f"n_per_seg to relevant length. Got n_fft of {n_fft} while"
f" signal length is {n}."
)
n_per_seg = n_fft if n_per_seg is None or n_per_seg > n_fft else n_per_seg
n_per_seg = n if n_per_seg > n else n_per_seg
if n_overlap >= n_per_seg:
raise ValueError(
"n_overlap cannot be greater than n_per_seg (or n_fft). Got n_overlap "
f"of {n_overlap} while n_per_seg is {n_per_seg}."
)
return n_fft, n_per_seg, n_overlap
@verbose
def psd_array_welch(
x,
sfreq,
fmin=0,
fmax=np.inf,
n_fft=256,
n_overlap=0,
n_per_seg=None,
n_jobs=None,
average="mean",
window="hamming",
remove_dc=True,
*,
output="power",
verbose=None,
):
"""Compute power spectral density (PSD) using Welch's method.
Welch's method is described in :footcite:t:`Welch1967`.
Parameters
----------
x : array, shape=(..., n_times)
The data to compute PSD from.
sfreq : float
The sampling frequency.
fmin : float
The lower frequency of interest.
fmax : float
The upper frequency of interest.
n_fft : int
The length of FFT used, must be ``>= n_per_seg`` (default: 256).
The segments will be zero-padded if ``n_fft > n_per_seg``.
n_overlap : int
The number of points of overlap between segments. Will be adjusted
to be <= n_per_seg. The default value is 0.
n_per_seg : int | None
Length of each Welch segment (windowed with a Hamming window). Defaults
to None, which sets n_per_seg equal to n_fft.
%(n_jobs)s
%(average_psd)s
.. versionadded:: 0.19.0
%(window_psd)s
.. versionadded:: 0.22.0
%(remove_dc)s
output : str
The format of the returned ``psds`` array, ``'complex'`` or
``'power'``:
* ``'power'`` : the power spectral density is returned.
* ``'complex'`` : the complex fourier coefficients are returned per
window.
.. versionadded:: 1.4.0
%(verbose)s
Returns
-------
psds : ndarray, shape (..., n_freqs) or (..., n_freqs, n_segments)
The power spectral densities. If ``average='mean'`` or
``average='median'``, the returned array will have the same shape
as the input data plus an additional frequency dimension.
If ``average=None``, the returned array will have the same shape as
the input data plus two additional dimensions corresponding to
frequencies and the unaggregated segments, respectively.
freqs : ndarray, shape (n_freqs,)
The frequencies.
Notes
-----
.. versionadded:: 0.14.0
References
----------
.. footbibliography::
"""
_check_option("average", average, (None, False, "mean", "median"))
_check_option("output", output, ("power", "complex"))
detrend = "constant" if remove_dc else False
mode = "complex" if output == "complex" else "psd"
n_fft = _ensure_int(n_fft, "n_fft")
n_overlap = _ensure_int(n_overlap, "n_overlap")
if n_per_seg is not None:
n_per_seg = _ensure_int(n_per_seg, "n_per_seg")
if average is False:
average = None
dshape = x.shape[:-1]
n_times = x.shape[-1]
x = x.reshape(-1, n_times)
# Prep the PSD
n_fft, n_per_seg, n_overlap = _check_nfft(n_times, n_fft, n_per_seg, n_overlap)
win_size = n_fft / float(sfreq)
logger.info(f"Effective window size : {win_size:0.3f} (s)")
freqs = np.arange(n_fft // 2 + 1, dtype=float) * (sfreq / n_fft)
freq_mask = (freqs >= fmin) & (freqs <= fmax)
if not freq_mask.any():
raise ValueError(f"No frequencies found between fmin={fmin} and fmax={fmax}")
freq_sl = slice(*(np.where(freq_mask)[0][[0, -1]] + [0, 1]))
del freq_mask
freqs = freqs[freq_sl]
# Parallelize across first N-1 dimensions
logger.debug(
f"Spectogram using {n_fft}-point FFT on {n_per_seg} samples with "
f"{n_overlap} overlap and {window} window"
)
parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs)
_func = partial(
spectrogram,
detrend=detrend,
noverlap=n_overlap,
nperseg=n_per_seg,
nfft=n_fft,
fs=sfreq,
window=window,
mode=mode,
)
if np.any(np.isnan(x)):
good_mask = ~np.isnan(x)
# NaNs originate from annot, so must match for all channels. Note that we CANNOT
# use np.testing.assert_allclose() here; it is strict about shapes/broadcasting
assert np.allclose(good_mask, good_mask[[0]], equal_nan=True)
t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0])
x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
# weights reflect the number of samples used from each span. For spans longer
# than `n_per_seg`, trailing samples may be discarded. For spans shorter than
# `n_per_seg`, the wrapped function (`scipy.signal.spectrogram`) automatically
# reduces `n_per_seg` to match the span length (with a warning).
step = n_per_seg - n_overlap
span_lengths = [span.shape[-1] for span in x_splits]
weights = [
w if w < n_per_seg else w - ((w - n_overlap) % step) for w in span_lengths
]
agg_func = partial(np.average, weights=weights)
if n_jobs > 1:
logger.info(
f"Data split into {len(x_splits)} (probably unequal) chunks due to "
'"bad_*" annotations. Parallelization may be sub-optimal.'
)
if (np.array(span_lengths) < n_per_seg).any():
logger.info(
"At least one good data span is shorter than n_per_seg, and will be "
"analyzed with a shorter window than the rest of the file."
)
def func(*args, **kwargs):
# swallow SciPy warnings caused by short good data spans
with warnings.catch_warnings():
warnings.filterwarnings(
action="ignore",
module="scipy",
category=UserWarning,
message=r"nperseg = \d+ is greater than input length",
)
return _func(*args, **kwargs)
else:
x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
agg_func = np.concatenate
func = _func
f_spect = parallel(
my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output)
for d in x_splits
)
psds = agg_func(f_spect, axis=0)
shape = dshape + (len(freqs),)
if average is None:
shape = shape + (-1,)
psds.shape = shape
return psds, freqs
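# Illustrative usage sketch (editor's addition, not part of the MNE source):
# Welch PSD of 8 s of data using 1 s segments (n_fft = n_per_seg = 256 at
# sfreq = 256 Hz) with 50 % overlap, giving a 1 Hz frequency resolution.
_rng = np.random.default_rng(1)
_sfreq = 256.0
_x = _rng.standard_normal((3, int(8 * _sfreq)))  # 3 channels, 8 s of data
_psd, _freqs = psd_array_welch(_x, _sfreq, fmin=1.0, fmax=40.0, n_fft=256, n_overlap=128)
# _psd.shape == (3, len(_freqs))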

File diff suppressed because it is too large.

mne/time_frequency/tfr.py

File diff suppressed because it is too large.