initial commit
This commit is contained in:
322
mne/time_frequency/_stockwell.py
Normal file
322
mne/time_frequency/_stockwell.py
Normal file
@@ -0,0 +1,322 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from scipy.fft import fft, fftfreq, ifft
|
||||
|
||||
from .._fiff.pick import _pick_data_channels, pick_info
|
||||
from ..parallel import parallel_func
|
||||
from ..utils import _validate_type, legacy, logger, verbose
|
||||
from .tfr import AverageTFRArray, _ensure_slice, _get_data
|
||||
|
||||
|
||||
def _check_input_st(x_in, n_fft):
|
||||
"""Aux function."""
|
||||
# flatten to 2 D and memorize original shape
|
||||
n_times = x_in.shape[-1]
|
||||
|
||||
def _is_power_of_two(n):
|
||||
return not (n > 0 and (n & (n - 1)))
|
||||
|
||||
if n_fft is None or (not _is_power_of_two(n_fft) and n_times > n_fft):
|
||||
# Compute next power of 2
|
||||
n_fft = 2 ** int(np.ceil(np.log2(n_times)))
|
||||
elif n_fft < n_times:
|
||||
raise ValueError(
|
||||
f"n_fft cannot be smaller than signal size. Got {n_fft} < {n_times}."
|
||||
)
|
||||
if n_times < n_fft:
|
||||
logger.info(
|
||||
f'The input signal is shorter ({x_in.shape[-1]}) than "n_fft" ({n_fft}). '
|
||||
"Applying zero padding."
|
||||
)
|
||||
zero_pad = n_fft - n_times
|
||||
pad_array = np.zeros(x_in.shape[:-1] + (zero_pad,), x_in.dtype)
|
||||
x_in = np.concatenate((x_in, pad_array), axis=-1)
|
||||
else:
|
||||
zero_pad = 0
|
||||
return x_in, n_fft, zero_pad
|
||||
|
||||
|
||||
def _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width):
|
||||
"""Precompute stockwell Gaussian windows (in the freq domain)."""
|
||||
tw = fftfreq(n_samp, 1.0 / sfreq) / n_samp
|
||||
tw = np.r_[tw[:1], tw[1:][::-1]]
|
||||
|
||||
k = width # 1 for classical stowckwell transform
|
||||
f_range = np.arange(start_f, stop_f, 1)
|
||||
windows = np.empty((len(f_range), len(tw)), dtype=np.complex128)
|
||||
for i_f, f in enumerate(f_range):
|
||||
if f == 0.0:
|
||||
window = np.ones(len(tw))
|
||||
else:
|
||||
window = (f / (np.sqrt(2.0 * np.pi) * k)) * np.exp(
|
||||
-0.5 * (1.0 / k**2.0) * (f**2.0) * tw**2.0
|
||||
)
|
||||
window /= window.sum() # normalisation
|
||||
windows[i_f] = fft(window)
|
||||
return windows
|
||||
|
||||
|
||||
def _st(x, start_f, windows):
|
||||
"""Compute ST based on Ali Moukadem MATLAB code (used in tests)."""
|
||||
from scipy.fft import fft, ifft
|
||||
|
||||
n_samp = x.shape[-1]
|
||||
ST = np.empty(x.shape[:-1] + (len(windows), n_samp), dtype=np.complex128)
|
||||
# do the work
|
||||
Fx = fft(x)
|
||||
XF = np.concatenate([Fx, Fx], axis=-1)
|
||||
for i_f, window in enumerate(windows):
|
||||
f = start_f + i_f
|
||||
ST[..., i_f, :] = ifft(XF[..., f : f + n_samp] * window)
|
||||
return ST
|
||||
|
||||
|
||||
def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W):
|
||||
"""Aux function."""
|
||||
decim = _ensure_slice(decim)
|
||||
n_samp = x.shape[-1]
|
||||
decim_indices = decim.indices(n_samp - zero_pad)
|
||||
n_out = len(range(*decim_indices))
|
||||
psd = np.empty((len(W), n_out))
|
||||
itc = np.empty_like(psd) if compute_itc else None
|
||||
X = fft(x)
|
||||
XX = np.concatenate([X, X], axis=-1)
|
||||
for i_f, window in enumerate(W):
|
||||
f = start_f + i_f
|
||||
ST = ifft(XX[:, f : f + n_samp] * window)
|
||||
TFR = ST[:, slice(*decim_indices)]
|
||||
TFR_abs = np.abs(TFR)
|
||||
TFR_abs[TFR_abs == 0] = 1.0
|
||||
if compute_itc:
|
||||
TFR /= TFR_abs
|
||||
itc[i_f] = np.abs(np.mean(TFR, axis=0))
|
||||
TFR_abs *= TFR_abs
|
||||
psd[i_f] = np.mean(TFR_abs, axis=0)
|
||||
return psd, itc
|
||||
|
||||
|
||||
def _compute_freqs_st(fmin, fmax, n_fft, sfreq):
|
||||
from scipy.fft import fftfreq
|
||||
|
||||
freqs = fftfreq(n_fft, 1.0 / sfreq)
|
||||
if fmin is None:
|
||||
fmin = freqs[freqs > 0][0]
|
||||
if fmax is None:
|
||||
fmax = freqs.max()
|
||||
|
||||
start_f = np.abs(freqs - fmin).argmin()
|
||||
stop_f = np.abs(freqs - fmax).argmin()
|
||||
freqs = freqs[start_f:stop_f]
|
||||
return start_f, stop_f, freqs
|
||||
|
||||
|
||||
@verbose
|
||||
def tfr_array_stockwell(
|
||||
data,
|
||||
sfreq,
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
n_fft=None,
|
||||
width=1.0,
|
||||
decim=1,
|
||||
return_itc=False,
|
||||
n_jobs=None,
|
||||
*,
|
||||
verbose=None,
|
||||
):
|
||||
"""Compute power and intertrial coherence using Stockwell (S) transform.
|
||||
|
||||
Same computation as `~mne.time_frequency.tfr_stockwell`, but operates on
|
||||
:class:`NumPy arrays <numpy.ndarray>` instead of `~mne.Epochs` objects.
|
||||
|
||||
See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
|
||||
for more information.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : ndarray, shape (n_epochs, n_channels, n_times)
|
||||
The signal to transform.
|
||||
sfreq : float
|
||||
The sampling frequency.
|
||||
fmin : None, float
|
||||
The minimum frequency to include. If None defaults to the minimum fft
|
||||
frequency greater than zero.
|
||||
fmax : None, float
|
||||
The maximum frequency to include. If None defaults to the maximum fft.
|
||||
n_fft : int | None
|
||||
The length of the windows used for FFT. If None, it defaults to the
|
||||
next power of 2 larger than the signal length.
|
||||
width : float
|
||||
The width of the Gaussian window. If < 1, increased temporal
|
||||
resolution, if > 1, increased frequency resolution. Defaults to 1.
|
||||
(classical S-Transform).
|
||||
%(decim_tfr)s
|
||||
return_itc : bool
|
||||
Return intertrial coherence (ITC) as well as averaged power.
|
||||
%(n_jobs)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
st_power : ndarray
|
||||
The multitaper power of the Stockwell transformed data.
|
||||
The last two dimensions are frequency and time.
|
||||
itc : ndarray
|
||||
The intertrial coherence. Only returned if return_itc is True.
|
||||
freqs : ndarray
|
||||
The frequencies.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.time_frequency.tfr_stockwell
|
||||
mne.time_frequency.tfr_multitaper
|
||||
mne.time_frequency.tfr_array_multitaper
|
||||
mne.time_frequency.tfr_morlet
|
||||
mne.time_frequency.tfr_array_morlet
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
_validate_type(data, np.ndarray, "data")
|
||||
if data.ndim != 3:
|
||||
raise ValueError(
|
||||
"data must be 3D with shape (n_epochs, n_channels, n_times), "
|
||||
f"got {data.shape}"
|
||||
)
|
||||
decim = _ensure_slice(decim)
|
||||
_, n_channels, n_out = data[..., decim].shape
|
||||
data, n_fft_, zero_pad = _check_input_st(data, n_fft)
|
||||
start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, n_fft_, sfreq)
|
||||
|
||||
W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)
|
||||
n_freq = stop_f - start_f
|
||||
psd = np.empty((n_channels, n_freq, n_out))
|
||||
itc = np.empty((n_channels, n_freq, n_out)) if return_itc else None
|
||||
|
||||
parallel, my_st, n_jobs = parallel_func(_st_power_itc, n_jobs, verbose=verbose)
|
||||
tfrs = parallel(
|
||||
my_st(data[:, c, :], start_f, return_itc, zero_pad, decim, W)
|
||||
for c in range(n_channels)
|
||||
)
|
||||
for c, (this_psd, this_itc) in enumerate(iter(tfrs)):
|
||||
psd[c] = this_psd
|
||||
if this_itc is not None:
|
||||
itc[c] = this_itc
|
||||
|
||||
return psd, itc, freqs
|
||||
|
||||
|
||||
@legacy(alt='.compute_tfr(method="stockwell", freqs="auto")')
|
||||
@verbose
|
||||
def tfr_stockwell(
|
||||
inst,
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
n_fft=None,
|
||||
width=1.0,
|
||||
decim=1,
|
||||
return_itc=False,
|
||||
n_jobs=None,
|
||||
verbose=None,
|
||||
):
|
||||
"""Compute Time-Frequency Representation (TFR) using Stockwell Transform.
|
||||
|
||||
Same computation as `~mne.time_frequency.tfr_array_stockwell`, but operates
|
||||
on `~mne.Epochs` objects instead of :class:`NumPy arrays <numpy.ndarray>`.
|
||||
|
||||
See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
|
||||
for more information.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst : Epochs | Evoked
|
||||
The epochs or evoked object.
|
||||
fmin : None, float
|
||||
The minimum frequency to include. If None defaults to the minimum fft
|
||||
frequency greater than zero.
|
||||
fmax : None, float
|
||||
The maximum frequency to include. If None defaults to the maximum fft.
|
||||
n_fft : int | None
|
||||
The length of the windows used for FFT. If None, it defaults to the
|
||||
next power of 2 larger than the signal length.
|
||||
width : float
|
||||
The width of the Gaussian window. If < 1, increased temporal
|
||||
resolution, if > 1, increased frequency resolution. Defaults to 1.
|
||||
(classical S-Transform).
|
||||
decim : int
|
||||
The decimation factor on the time axis. To reduce memory usage.
|
||||
return_itc : bool
|
||||
Return intertrial coherence (ITC) as well as averaged power.
|
||||
n_jobs : int
|
||||
The number of jobs to run in parallel (over channels).
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
power : AverageTFR
|
||||
The averaged power.
|
||||
itc : AverageTFR
|
||||
The intertrial coherence. Only returned if return_itc is True.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.time_frequency.tfr_array_stockwell
|
||||
mne.time_frequency.tfr_multitaper
|
||||
mne.time_frequency.tfr_array_multitaper
|
||||
mne.time_frequency.tfr_morlet
|
||||
mne.time_frequency.tfr_array_morlet
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 0.9.0
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
# verbose dec is used b/c subfunctions are verbose
|
||||
data = _get_data(inst, return_itc)
|
||||
picks = _pick_data_channels(inst.info)
|
||||
info = pick_info(inst.info, picks)
|
||||
data = data[:, picks, :]
|
||||
decim = _ensure_slice(decim)
|
||||
power, itc, freqs = tfr_array_stockwell(
|
||||
data,
|
||||
sfreq=info["sfreq"],
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
n_fft=n_fft,
|
||||
width=width,
|
||||
decim=decim,
|
||||
return_itc=return_itc,
|
||||
n_jobs=n_jobs,
|
||||
)
|
||||
times = inst.times[decim].copy()
|
||||
nave = len(data)
|
||||
out = AverageTFRArray(
|
||||
info=info,
|
||||
data=power,
|
||||
times=times,
|
||||
freqs=freqs,
|
||||
nave=nave,
|
||||
method="stockwell-power",
|
||||
)
|
||||
if return_itc:
|
||||
out = (
|
||||
out,
|
||||
AverageTFRArray(
|
||||
info=deepcopy(info),
|
||||
data=itc,
|
||||
times=times.copy(),
|
||||
freqs=freqs.copy(),
|
||||
nave=nave,
|
||||
method="stockwell-itc",
|
||||
),
|
||||
)
|
||||
return out
|
||||
Reference in New Issue
Block a user