initial commit
This commit is contained in:
168
mne/decoding/time_frequency.py
Normal file
168
mne/decoding/time_frequency.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
from sklearn.base import BaseEstimator, TransformerMixin
|
||||
|
||||
from ..time_frequency.tfr import _compute_tfr
|
||||
from ..utils import _check_option, fill_doc, verbose
|
||||
|
||||
|
||||
@fill_doc
|
||||
class TimeFrequency(TransformerMixin, BaseEstimator):
|
||||
"""Time frequency transformer.
|
||||
|
||||
Time-frequency transform of times series along the last axis.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freqs : array-like of float, shape (n_freqs,)
|
||||
The frequencies.
|
||||
sfreq : float | int, default 1.0
|
||||
Sampling frequency of the data.
|
||||
method : 'multitaper' | 'morlet', default 'morlet'
|
||||
The time-frequency method. 'morlet' convolves a Morlet wavelet.
|
||||
'multitaper' uses Morlet wavelets windowed with multiple DPSS
|
||||
multitapers.
|
||||
n_cycles : float | array of float, default 7.0
|
||||
Number of cycles in the Morlet wavelet. Fixed number
|
||||
or one per frequency.
|
||||
time_bandwidth : float, default None
|
||||
If None and method=multitaper, will be set to 4.0 (3 tapers).
|
||||
Time x (Full) Bandwidth product. Only applies if
|
||||
method == 'multitaper'. The number of good tapers (low-bias) is
|
||||
chosen automatically based on this to equal floor(time_bandwidth - 1).
|
||||
use_fft : bool, default True
|
||||
Use the FFT for convolutions or not.
|
||||
decim : int | slice, default 1
|
||||
To reduce memory usage, decimation factor after time-frequency
|
||||
decomposition.
|
||||
If `int`, returns tfr[..., ::decim].
|
||||
If `slice`, returns tfr[..., decim].
|
||||
|
||||
.. note:: Decimation may create aliasing artifacts, yet decimation
|
||||
is done after the convolutions.
|
||||
|
||||
output : str, default 'complex'
|
||||
* 'complex' : single trial complex.
|
||||
* 'power' : single trial power.
|
||||
* 'phase' : single trial phase.
|
||||
%(n_jobs)s
|
||||
The number of epochs to process at the same time. The parallelization
|
||||
is implemented across channels.
|
||||
%(verbose)s
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.time_frequency.tfr_morlet
|
||||
mne.time_frequency.tfr_multitaper
|
||||
"""
|
||||
|
||||
@verbose
|
||||
def __init__(
|
||||
self,
|
||||
freqs,
|
||||
sfreq=1.0,
|
||||
method="morlet",
|
||||
n_cycles=7.0,
|
||||
time_bandwidth=None,
|
||||
use_fft=True,
|
||||
decim=1,
|
||||
output="complex",
|
||||
n_jobs=1,
|
||||
verbose=None,
|
||||
):
|
||||
"""Init TimeFrequency transformer."""
|
||||
# Check non-average output
|
||||
output = _check_option("output", output, ["complex", "power", "phase"])
|
||||
|
||||
self.freqs = freqs
|
||||
self.sfreq = sfreq
|
||||
self.method = method
|
||||
self.n_cycles = n_cycles
|
||||
self.time_bandwidth = time_bandwidth
|
||||
self.use_fft = use_fft
|
||||
self.decim = decim
|
||||
# Check that output is not an average metric (e.g. ITC)
|
||||
self.output = output
|
||||
self.n_jobs = n_jobs
|
||||
self.verbose = verbose
|
||||
|
||||
def fit_transform(self, X, y=None):
|
||||
"""Time-frequency transform of times series along the last axis.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array, shape (n_samples, n_channels, n_times)
|
||||
The training data samples. The channel dimension can be zero- or
|
||||
1-dimensional.
|
||||
y : None
|
||||
For scikit-learn compatibility purposes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Xt : array, shape (n_samples, n_channels, n_freqs, n_times)
|
||||
The time-frequency transform of the data, where n_channels can be
|
||||
zero- or 1-dimensional.
|
||||
"""
|
||||
return self.fit(X, y).transform(X)
|
||||
|
||||
def fit(self, X, y=None): # noqa: D401
|
||||
"""Do nothing (for scikit-learn compatibility purposes).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array, shape (n_samples, n_channels, n_times)
|
||||
The training data.
|
||||
y : array | None
|
||||
The target values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
self : object
|
||||
Return self.
|
||||
"""
|
||||
return self
|
||||
|
||||
def transform(self, X):
|
||||
"""Time-frequency transform of times series along the last axis.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array, shape (n_samples, n_channels, n_times)
|
||||
The training data samples. The channel dimension can be zero- or
|
||||
1-dimensional.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Xt : array, shape (n_samples, n_channels, n_freqs, n_times)
|
||||
The time-frequency transform of the data, where n_channels can be
|
||||
zero- or 1-dimensional.
|
||||
"""
|
||||
# Ensure 3-dimensional X
|
||||
shape = X.shape[1:-1]
|
||||
if not shape:
|
||||
X = X[:, np.newaxis, :]
|
||||
|
||||
# Compute time-frequency
|
||||
Xt = _compute_tfr(
|
||||
X,
|
||||
freqs=self.freqs,
|
||||
sfreq=self.sfreq,
|
||||
method=self.method,
|
||||
n_cycles=self.n_cycles,
|
||||
zero_mean=True,
|
||||
time_bandwidth=self.time_bandwidth,
|
||||
use_fft=self.use_fft,
|
||||
decim=self.decim,
|
||||
output=self.output,
|
||||
n_jobs=self.n_jobs,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
|
||||
# Back to original shape
|
||||
if not shape:
|
||||
Xt = Xt[:, 0, :]
|
||||
|
||||
return Xt
|
||||
Reference in New Issue
Block a user