# Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. """IIR and FIR filtering and resampling functions.""" from collections import Counter from copy import deepcopy from functools import partial from math import gcd import numpy as np from scipy import fft, signal from scipy.stats import f as fstat from ._fiff.pick import _picks_to_idx from ._ola import _COLA from .cuda import ( _fft_multiply_repeated, _fft_resample, _setup_cuda_fft_multiply_repeated, _setup_cuda_fft_resample, _smart_pad, ) from .fixes import minimum_phase from .parallel import parallel_func from .utils import ( _check_option, _check_preload, _ensure_int, _pl, _validate_type, logger, sum_squared, verbose, warn, ) # These values from Ifeachor and Jervis. _length_factors = dict(hann=3.1, hamming=3.3, blackman=5.0) def next_fast_len(target): """Find the next fast size of input data to `fft`, for zero-padding, etc. SciPy's FFTPACK has efficient functions for radix {2, 3, 4, 5}, so this returns the next composite of the prime factors 2, 3, and 5 which is greater than or equal to `target`. (These are also known as 5-smooth numbers, regular numbers, or Hamming numbers.) Parameters ---------- target : int Length to start searching from. Must be a positive integer. Returns ------- out : int The first 5-smooth number greater than or equal to `target`. Notes ----- Copied from SciPy with minor modifications. """ from bisect import bisect_left hams = ( 8, 9, 10, 12, 15, 16, 18, 20, 24, 25, 27, 30, 32, 36, 40, 45, 48, 50, 54, 60, 64, 72, 75, 80, 81, 90, 96, 100, 108, 120, 125, 128, 135, 144, 150, 160, 162, 180, 192, 200, 216, 225, 240, 243, 250, 256, 270, 288, 300, 320, 324, 360, 375, 384, 400, 405, 432, 450, 480, 486, 500, 512, 540, 576, 600, 625, 640, 648, 675, 720, 729, 750, 768, 800, 810, 864, 900, 960, 972, 1000, 1024, 1080, 1125, 1152, 1200, 1215, 1250, 1280, 1296, 1350, 1440, 1458, 1500, 1536, 1600, 1620, 1728, 1800, 1875, 1920, 1944, 2000, 2025, 2048, 2160, 2187, 2250, 2304, 2400, 2430, 2500, 2560, 2592, 2700, 2880, 2916, 3000, 3072, 3125, 3200, 3240, 3375, 3456, 3600, 3645, 3750, 3840, 3888, 4000, 4050, 4096, 4320, 4374, 4500, 4608, 4800, 4860, 5000, 5120, 5184, 5400, 5625, 5760, 5832, 6000, 6075, 6144, 6250, 6400, 6480, 6561, 6750, 6912, 7200, 7290, 7500, 7680, 7776, 8000, 8100, 8192, 8640, 8748, 9000, 9216, 9375, 9600, 9720, 10000, ) if target <= 6: return target # Quickly check if it's already a power of 2 if not (target & (target - 1)): return target # Get result quickly for small sizes, since FFT itself is similarly fast. if target <= hams[-1]: return hams[bisect_left(hams, target)] match = float("inf") # Anything found will be smaller p5 = 1 while p5 < target: p35 = p5 while p35 < target: # Ceiling integer division, avoiding conversion to float # (quotient = ceil(target / p35)) quotient = -(-target // p35) p2 = 2 ** int(quotient - 1).bit_length() N = p2 * p35 if N == target: return N elif N < match: match = N p35 *= 3 if p35 == target: return p35 if p35 < match: match = p35 p5 *= 5 if p5 == target: return p5 if p5 < match: match = p5 return match def _overlap_add_filter( x, h, n_fft=None, phase="zero", picks=None, n_jobs=None, copy=True, pad="reflect_limited", ): """Filter the signal x using h with overlap-add FFTs.""" # set up array for filtering, reshape to 2D, operate on last axis x, orig_shape, picks = _prep_for_filtering(x, copy, picks) # Extend the signal by mirroring the edges to reduce transient filter # response _check_zero_phase_length(len(h), phase) if len(h) == 1: return x * h**2 if phase == "zero-double" else x * h n_edge = max(min(len(h), x.shape[1]) - 1, 0) logger.debug(f"Smart-padding with: {n_edge} samples on each edge") n_x = x.shape[1] + 2 * n_edge if phase == "zero-double": h = np.convolve(h, h[::-1]) # Determine FFT length to use min_fft = 2 * len(h) - 1 if n_fft is None: max_fft = n_x if max_fft >= min_fft: # cost function based on number of multiplications N = 2 ** np.arange( np.ceil(np.log2(min_fft)), np.ceil(np.log2(max_fft)) + 1, dtype=int ) cost = ( np.ceil(n_x / (N - len(h) + 1).astype(np.float64)) * N * (np.log2(N) + 1) ) # add a heuristic term to prevent too-long FFT's which are slow # (not predicted by mult. cost alone, 4e-5 exp. determined) cost += 4e-5 * N * n_x n_fft = N[np.argmin(cost)] else: # Use only a single block n_fft = next_fast_len(min_fft) logger.debug(f"FFT block length: {n_fft}") if n_fft < min_fft: raise ValueError( f"n_fft is too short, has to be at least 2 * len(h) - 1 ({min_fft}), got " f"{n_fft}" ) # Figure out if we should use CUDA n_jobs, cuda_dict = _setup_cuda_fft_multiply_repeated(n_jobs, h, n_fft) # Process each row separately picks = _picks_to_idx(len(x), picks) parallel, p_fun, _ = parallel_func(_1d_overlap_filter, n_jobs) if n_jobs == 1: for p in picks: x[p] = _1d_overlap_filter( x[p], len(h), n_edge, phase, cuda_dict, pad, n_fft ) else: data_new = parallel( p_fun(x[p], len(h), n_edge, phase, cuda_dict, pad, n_fft) for p in picks ) for pp, p in enumerate(picks): x[p] = data_new[pp] x.shape = orig_shape return x def _1d_overlap_filter(x, n_h, n_edge, phase, cuda_dict, pad, n_fft): """Do one-dimensional overlap-add FFT FIR filtering.""" # pad to reduce ringing x_ext = _smart_pad(x, (n_edge, n_edge), pad) n_x = len(x_ext) x_filtered = np.zeros_like(x_ext) n_seg = n_fft - n_h + 1 n_segments = int(np.ceil(n_x / float(n_seg))) shift = ((n_h - 1) // 2 if phase.startswith("zero") else 0) + n_edge # Now the actual filtering step is identical for zero-phase (filtfilt-like) # or single-pass for seg_idx in range(n_segments): start = seg_idx * n_seg stop = (seg_idx + 1) * n_seg seg = x_ext[start:stop] seg = np.concatenate([seg, np.zeros(n_fft - len(seg))]) prod = _fft_multiply_repeated(seg, cuda_dict) start_filt = max(0, start - shift) stop_filt = min(start - shift + n_fft, n_x) start_prod = max(0, shift - start) stop_prod = start_prod + stop_filt - start_filt x_filtered[start_filt:stop_filt] += prod[start_prod:stop_prod] # Remove mirrored edges that we added and cast (n_edge can be zero) x_filtered = x_filtered[: n_x - 2 * n_edge].astype(x.dtype) return x_filtered def _filter_attenuation(h, freq, gain): """Compute minimum attenuation at stop frequency.""" _, filt_resp = signal.freqz(h.ravel(), worN=np.pi * freq) filt_resp = np.abs(filt_resp) # use amplitude response filt_resp[np.where(gain == 1)] = 0 idx = np.argmax(filt_resp) att_db = -20 * np.log10(np.maximum(filt_resp[idx], 1e-20)) att_freq = freq[idx] return att_db, att_freq def _prep_for_filtering(x, copy, picks=None): """Set up array as 2D for filtering ease.""" x = _check_filterable(x) if copy is True: x = x.copy() orig_shape = x.shape x = np.atleast_2d(x) picks = _picks_to_idx(x.shape[-2], picks) x.shape = (np.prod(x.shape[:-1]), x.shape[-1]) if len(orig_shape) == 3: n_epochs, n_channels, n_times = orig_shape offset = np.repeat(np.arange(0, n_channels * n_epochs, n_channels), len(picks)) picks = np.tile(picks, n_epochs) + offset elif len(orig_shape) > 3: raise ValueError( "picks argument is not supported for data with more" " than three dimensions" ) assert all(0 <= pick < x.shape[0] for pick in picks) # guaranteed by above return x, orig_shape, picks def _firwin_design(N, freq, gain, window, sfreq): """Construct a FIR filter using firwin.""" assert freq[0] == 0 assert len(freq) > 1 assert len(freq) == len(gain) assert N % 2 == 1 h = np.zeros(N) prev_freq = freq[-1] prev_gain = gain[-1] if gain[-1] == 1: h[N // 2] = 1 # start with "all up" assert prev_gain in (0, 1) for this_freq, this_gain in zip(freq[::-1][1:], gain[::-1][1:]): assert this_gain in (0, 1) if this_gain != prev_gain: # Get the correct N to satisfy the requested transition bandwidth transition = (prev_freq - this_freq) / 2.0 this_N = int(round(_length_factors[window] / transition)) this_N += 1 - this_N % 2 # make it odd if this_N > N: raise ValueError( f"The requested filter length {N} is too short for the requested " f"{transition * sfreq / 2.0:0.2f} Hz transition band, which " f"requires {this_N} samples" ) # Construct a lowpass this_h = signal.firwin( this_N, (prev_freq + this_freq) / 2.0, window=window, pass_zero=True, fs=freq[-1] * 2, ) assert this_h.shape == (this_N,) offset = (N - this_N) // 2 if this_gain == 0: h[offset : N - offset] -= this_h else: h[offset : N - offset] += this_h prev_gain = this_gain prev_freq = this_freq return h def _construct_fir_filter( sfreq, freq, gain, filter_length, phase, fir_window, fir_design ): """Filter signal using gain control points in the frequency domain. The filter impulse response is constructed from a Hann window (window used in "firwin2" function) to avoid ripples in the frequency response (windowing is a smoothing in frequency domain). If x is multi-dimensional, this operates along the last dimension. """ assert freq[0] == 0 if fir_design == "firwin2": fir_design = signal.firwin2 else: assert fir_design == "firwin" fir_design = partial(_firwin_design, sfreq=sfreq) # issue a warning if attenuation is less than this min_att_db = 12 if phase == "minimum-half" else 20 # normalize frequencies freq = np.array(freq) / (sfreq / 2.0) if freq[0] != 0 or freq[-1] != 1: raise ValueError( f"freq must start at 0 and end an Nyquist ({sfreq / 2.0}), got {freq}" ) gain = np.array(gain) # Use overlap-add filter with a fixed length N = _check_zero_phase_length(filter_length, phase, gain[-1]) # construct symmetric (linear phase) filter if phase == "minimum-half": h = fir_design(N * 2 - 1, freq, gain, window=fir_window) h = minimum_phase(h) else: h = fir_design(N, freq, gain, window=fir_window) if phase == "minimum": h = minimum_phase(h, half=False) assert h.size == N att_db, att_freq = _filter_attenuation(h, freq, gain) if phase == "zero-double": att_db += 6 if att_db < min_att_db: att_freq *= sfreq / 2.0 warn( f"Attenuation at stop frequency {att_freq:0.2f} Hz is only {att_db:0.2f} " "dB. Increase filter_length for higher attenuation." ) return h def _check_zero_phase_length(N, phase, gain_nyq=0): N = int(N) if N % 2 == 0: if phase == "zero": raise RuntimeError(f'filter_length must be odd if phase="zero", got {N}') elif phase == "zero-double" and gain_nyq == 1: N += 1 return N def _check_coefficients(system): """Check for filter stability.""" if isinstance(system, tuple): z, p, k = signal.tf2zpk(*system) else: # sos z, p, k = signal.sos2zpk(system) if np.any(np.abs(p) > 1.0): raise RuntimeError( "Filter poles outside unit circle, filter will be " "unstable. Consider using different filter " "coefficients." ) def _iir_filter(x, iir_params, picks, n_jobs, copy, phase="zero"): """Call filtfilt or lfilter.""" # set up array for filtering, reshape to 2D, operate on last axis x, orig_shape, picks = _prep_for_filtering(x, copy, picks) if phase in ("zero", "zero-double"): padlen = min(iir_params["padlen"], x.shape[-1] - 1) if "sos" in iir_params: fun = partial( signal.sosfiltfilt, sos=iir_params["sos"], padlen=padlen, axis=-1 ) _check_coefficients(iir_params["sos"]) else: fun = partial( signal.filtfilt, b=iir_params["b"], a=iir_params["a"], padlen=padlen, axis=-1, ) _check_coefficients((iir_params["b"], iir_params["a"])) else: if "sos" in iir_params: fun = partial(signal.sosfilt, sos=iir_params["sos"], axis=-1) _check_coefficients(iir_params["sos"]) else: fun = partial(signal.lfilter, b=iir_params["b"], a=iir_params["a"], axis=-1) _check_coefficients((iir_params["b"], iir_params["a"])) parallel, p_fun, n_jobs = parallel_func(fun, n_jobs) if n_jobs == 1: for p in picks: x[p] = fun(x=x[p]) else: data_new = parallel(p_fun(x=x[p]) for p in picks) for pp, p in enumerate(picks): x[p] = data_new[pp] x.shape = orig_shape return x def estimate_ringing_samples(system, max_try=100000): """Estimate filter ringing. Parameters ---------- system : tuple | ndarray A tuple of (b, a) or ndarray of second-order sections coefficients. max_try : int Approximate maximum number of samples to try. This will be changed to a multiple of 1000. Returns ------- n : int The approximate ringing. """ if isinstance(system, tuple): # TF kind = "ba" b, a = system zi = [0.0] * (len(a) - 1) else: kind = "sos" sos = system zi = [[0.0] * 2] * len(sos) n_per_chunk = 1000 n_chunks_max = int(np.ceil(max_try / float(n_per_chunk))) x = np.zeros(n_per_chunk) x[0] = 1 last_good = n_per_chunk thresh_val = 0 for ii in range(n_chunks_max): if kind == "ba": h, zi = signal.lfilter(b, a, x, zi=zi) else: h, zi = signal.sosfilt(sos, x, zi=zi) x[0] = 0 # for subsequent iterations we want zero input h = np.abs(h) thresh_val = max(0.001 * np.max(h), thresh_val) idx = np.where(np.abs(h) > thresh_val)[0] if len(idx) > 0: last_good = idx[-1] else: # this iteration had no sufficiently lange values idx = (ii - 1) * n_per_chunk + last_good break else: warn("Could not properly estimate ringing for the filter") idx = n_per_chunk * n_chunks_max return idx _ftype_dict = { "butter": "Butterworth", "cheby1": "Chebyshev I", "cheby2": "Chebyshev II", "ellip": "Cauer/elliptic", "bessel": "Bessel/Thomson", } @verbose def construct_iir_filter( iir_params, f_pass=None, f_stop=None, sfreq=None, btype=None, return_copy=True, *, phase="zero", verbose=None, ): """Use IIR parameters to get filtering coefficients. This function works like a wrapper for iirdesign and iirfilter in scipy.signal to make filter coefficients for IIR filtering. It also estimates the number of padding samples based on the filter ringing. It creates a new iir_params dict (or updates the one passed to the function) with the filter coefficients ('b' and 'a') and an estimate of the padding necessary ('padlen') so IIR filtering can be performed. Parameters ---------- iir_params : dict Dictionary of parameters to use for IIR filtering. * If ``iir_params['sos']`` exists, it will be used as second-order sections to perform IIR filtering. .. versionadded:: 0.13 * Otherwise, if ``iir_params['b']`` and ``iir_params['a']`` exist, these will be used as coefficients to perform IIR filtering. * Otherwise, if ``iir_params['order']`` and ``iir_params['ftype']`` exist, these will be used with `scipy.signal.iirfilter` to make a filter. You should also supply ``iir_params['rs']`` and ``iir_params['rp']`` if using elliptic or Chebychev filters. * Otherwise, if ``iir_params['gpass']`` and ``iir_params['gstop']`` exist, these will be used with `scipy.signal.iirdesign` to design a filter. * ``iir_params['padlen']`` defines the number of samples to pad (and an estimate will be calculated if it is not given). See Notes for more details. * ``iir_params['output']`` defines the system output kind when designing filters, either "sos" or "ba". For 0.13 the default is 'ba' but will change to 'sos' in 0.14. f_pass : float or list of float Frequency for the pass-band. Low-pass and high-pass filters should be a float, band-pass should be a 2-element list of float. f_stop : float or list of float Stop-band frequency (same size as f_pass). Not used if 'order' is specified in iir_params. sfreq : float | None The sample rate. btype : str Type of filter. Should be 'lowpass', 'highpass', or 'bandpass' (or analogous string representations known to :func:`scipy.signal.iirfilter`). return_copy : bool If False, the 'sos', 'b', 'a', and 'padlen' entries in ``iir_params`` will be set inplace (if they weren't already). Otherwise, a new ``iir_params`` instance will be created and returned with these entries. phase : str Phase of the filter. ``phase='zero'`` (default) or equivalently ``'zero-double'`` constructs and applies IIR filter twice, once forward, and once backward (making it non-causal) using :func:`~scipy.signal.filtfilt`; ``phase='forward'`` will apply the filter once in the forward (causal) direction using :func:`~scipy.signal.lfilter`. .. versionadded:: 0.13 %(verbose)s Returns ------- iir_params : dict Updated iir_params dict, with the entries (set only if they didn't exist before) for 'sos' (or 'b', 'a'), and 'padlen' for IIR filtering. See Also -------- mne.filter.filter_data mne.io.Raw.filter Notes ----- This function triages calls to :func:`scipy.signal.iirfilter` and :func:`scipy.signal.iirdesign` based on the input arguments (see linked functions for more details). .. versionchanged:: 0.14 Second-order sections are used in filter design by default (replacing ``output='ba'`` by ``output='sos'``) to help ensure filter stability and reduce numerical error. Examples -------- iir_params can have several forms. Consider constructing a low-pass filter at 40 Hz with 1000 Hz sampling rate. In the most basic (2-parameter) form of iir_params, the order of the filter 'N' and the type of filtering 'ftype' are specified. To get coefficients for a 4th-order Butterworth filter, this would be: >>> iir_params = dict(order=4, ftype='butter', output='sos') # doctest:+SKIP >>> iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low', return_copy=False) # doctest:+SKIP >>> print((2 * len(iir_params['sos']), iir_params['padlen'])) # doctest:+SKIP (4, 82) Filters can also be constructed using filter design methods. To get a 40 Hz Chebyshev type 1 lowpass with specific gain characteristics in the pass and stop bands (assuming the desired stop band is at 45 Hz), this would be a filter with much longer ringing: >>> iir_params = dict(ftype='cheby1', gpass=3, gstop=20, output='sos') # doctest:+SKIP >>> iir_params = construct_iir_filter(iir_params, 40, 50, 1000, 'low') # doctest:+SKIP >>> print((2 * len(iir_params['sos']), iir_params['padlen'])) # doctest:+SKIP (6, 439) Padding and/or filter coefficients can also be manually specified. For a 10-sample moving window with no padding during filtering, for example, one can just do: >>> iir_params = dict(b=np.ones((10)), a=[1, 0], padlen=0) # doctest:+SKIP >>> iir_params = construct_iir_filter(iir_params, return_copy=False) # doctest:+SKIP >>> print((iir_params['b'], iir_params['a'], iir_params['padlen'])) # doctest:+SKIP (array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), [1, 0], 0) For more information, see the tutorials :ref:`disc-filtering` and :ref:`tut-filter-resample`. """ # noqa: E501 known_filters = ( "bessel", "butter", "butterworth", "cauer", "cheby1", "cheby2", "chebyshev1", "chebyshev2", "chebyshevi", "chebyshevii", "ellip", "elliptic", ) if not isinstance(iir_params, dict): raise TypeError(f"iir_params must be a dict, got {type(iir_params)}") # if the filter has been designed, we're good to go Wp = None if "sos" in iir_params: system = iir_params["sos"] output = "sos" elif "a" in iir_params and "b" in iir_params: system = (iir_params["b"], iir_params["a"]) output = "ba" else: output = iir_params.get("output", "sos") _check_option("output", output, ("ba", "sos")) # ensure we have a valid ftype if "ftype" not in iir_params: raise RuntimeError( "ftype must be an entry in iir_params if 'b' and 'a' are not specified." ) ftype = iir_params["ftype"] if ftype not in known_filters: raise RuntimeError( "ftype must be in filter_dict from scipy.signal (e.g., butter, cheby1, " f"etc.) not {ftype}" ) # use order-based design f_pass = np.atleast_1d(f_pass) if f_pass.ndim > 1: raise ValueError(f"frequencies must be 1D, got {f_pass.ndim}D") edge_freqs = ", ".join(f"{f:0.2f}" for f in f_pass) Wp = f_pass / (float(sfreq) / 2) # IT will de designed ftype_nice = _ftype_dict.get(ftype, ftype) _validate_type(phase, str, "phase") _check_option("phase", phase, ("zero", "zero-double", "forward")) if phase in ("zero-double", "zero"): ptype = "zero-phase (two-pass forward and reverse) non-causal" else: ptype = "non-linear phase (one-pass forward) causal" logger.info("") logger.info("IIR filter parameters") logger.info("---------------------") logger.info(f"{ftype_nice} {btype} {ptype} filter:") # SciPy designs forward for -3dB, so forward-backward is -6dB if "order" in iir_params: singleton = btype in ("low", "lowpass", "high", "highpass") use_Wp = Wp.item() if singleton else Wp kwargs = dict( N=iir_params["order"], Wn=use_Wp, btype=btype, ftype=ftype, output=output, ) for key in ("rp", "rs"): if key in iir_params: kwargs[key] = iir_params[key] system = signal.iirfilter(**kwargs) if phase in ("zero", "zero-double"): ptype, pmul = "(effective, after forward-backward)", 2 else: ptype, pmul = "(forward)", 1 logger.info( "- Filter order %d %s", pmul * iir_params["order"] * len(Wp), ptype ) else: # use gpass / gstop design Ws = np.asanyarray(f_stop) / (float(sfreq) / 2) if "gpass" not in iir_params or "gstop" not in iir_params: raise ValueError( "iir_params must have at least 'gstop' and 'gpass' (or N) entries." ) system = signal.iirdesign( Wp, Ws, iir_params["gpass"], iir_params["gstop"], ftype=ftype, output=output, ) if system is None: raise RuntimeError("coefficients could not be created from iir_params") # do some sanity checks _check_coefficients(system) # get the gains at the cutoff frequencies if Wp is not None: if output == "sos": cutoffs = signal.sosfreqz(system, worN=Wp * np.pi)[1] else: cutoffs = signal.freqz(system[0], system[1], worN=Wp * np.pi)[1] cutoffs = 20 * np.log10(np.abs(cutoffs)) # 2 * 20 here because we do forward-backward filtering if phase in ("zero", "zero-double"): cutoffs *= 2 cutoffs = ", ".join([f"{c:0.2f}" for c in cutoffs]) logger.info(f"- Cutoff{_pl(f_pass)} at {edge_freqs} Hz: {cutoffs} dB") # now deal with padding if "padlen" not in iir_params: padlen = estimate_ringing_samples(system) else: padlen = iir_params["padlen"] if return_copy: iir_params = deepcopy(iir_params) iir_params.update(dict(padlen=padlen)) if output == "sos": iir_params.update(sos=system) else: iir_params.update(b=system[0], a=system[1]) logger.info("") return iir_params def _check_method(method, iir_params, extra_types=()): """Parse method arguments.""" allowed_types = ["iir", "fir", "fft"] + list(extra_types) _validate_type(method, "str", "method") _check_option("method", method, allowed_types) if method == "fft": method = "fir" # use the better name if method == "iir": if iir_params is None: iir_params = dict() if len(iir_params) == 0 or (len(iir_params) == 1 and "output" in iir_params): iir_params = dict( order=4, ftype="butter", output=iir_params.get("output", "sos") ) elif iir_params is not None: raise ValueError('iir_params must be None if method != "iir"') return iir_params, method @verbose def filter_data( data, sfreq, l_freq, h_freq, picks=None, filter_length="auto", l_trans_bandwidth="auto", h_trans_bandwidth="auto", n_jobs=None, method="fir", iir_params=None, copy=True, phase="zero", fir_window="hamming", fir_design="firwin", pad="reflect_limited", *, verbose=None, ): """Filter a subset of channels. Parameters ---------- data : ndarray, shape (..., n_times) The data to filter. sfreq : float The sample frequency in Hz. %(l_freq)s %(h_freq)s %(picks_nostr)s Currently this is only supported for 2D (n_channels, n_times) and 3D (n_epochs, n_channels, n_times) arrays. %(filter_length)s %(l_trans_bandwidth)s %(h_trans_bandwidth)s %(n_jobs_fir)s %(method_fir)s %(iir_params)s copy : bool If True, a copy of x, filtered, is returned. Otherwise, it operates on x in place. %(phase)s %(fir_window)s %(fir_design)s %(pad_fir)s The default is ``'reflect_limited'``. .. versionadded:: 0.15 %(verbose)s Returns ------- data : ndarray, shape (..., n_times) The filtered data. See Also -------- construct_iir_filter create_filter mne.io.Raw.filter notch_filter resample Notes ----- Applies a zero-phase low-pass, high-pass, band-pass, or band-stop filter to the channels selected by ``picks``. ``l_freq`` and ``h_freq`` are the frequencies below which and above which, respectively, to filter out of the data. Thus the uses are: * ``l_freq < h_freq``: band-pass filter * ``l_freq > h_freq``: band-stop filter * ``l_freq is not None and h_freq is None``: high-pass filter * ``l_freq is None and h_freq is not None``: low-pass filter .. note:: If n_jobs > 1, more memory is required as ``len(picks) * n_times`` additional time points need to be temporarily stored in memory. For more information, see the tutorials :ref:`disc-filtering` and :ref:`tut-filter-resample` and :func:`mne.filter.create_filter`. """ data = _check_filterable(data) iir_params, method = _check_method(method, iir_params) filt = create_filter( data, sfreq, l_freq, h_freq, filter_length, l_trans_bandwidth, h_trans_bandwidth, method, iir_params, phase, fir_window, fir_design, ) if method in ("fir", "fft"): data = _overlap_add_filter(data, filt, None, phase, picks, n_jobs, copy, pad) else: data = _iir_filter(data, filt, picks, n_jobs, copy, phase) return data @verbose def create_filter( data, sfreq, l_freq, h_freq, filter_length="auto", l_trans_bandwidth="auto", h_trans_bandwidth="auto", method="fir", iir_params=None, phase="zero", fir_window="hamming", fir_design="firwin", verbose=None, ): r"""Create a FIR or IIR filter. ``l_freq`` and ``h_freq`` are the frequencies below which and above which, respectively, to filter out of the data. Thus the uses are: * ``l_freq < h_freq``: band-pass filter * ``l_freq > h_freq``: band-stop filter * ``l_freq is not None and h_freq is None``: high-pass filter * ``l_freq is None and h_freq is not None``: low-pass filter Parameters ---------- data : ndarray, shape (..., n_times) | None The data that will be filtered. This is used for sanity checking only. If None, no sanity checking related to the length of the signal relative to the filter order will be performed. sfreq : float The sample frequency in Hz. %(l_freq)s %(h_freq)s %(filter_length)s %(l_trans_bandwidth)s %(h_trans_bandwidth)s %(method_fir)s %(iir_params)s %(phase)s %(fir_window)s %(fir_design)s %(verbose)s Returns ------- filt : array or dict Will be an array of FIR coefficients for method='fir', and dict with IIR parameters for method='iir'. See Also -------- filter_data Notes ----- .. note:: For FIR filters, the *cutoff frequency*, i.e. the -6 dB point, is in the middle of the transition band (when using phase='zero' and fir_design='firwin'). For IIR filters, the cutoff frequency is given by ``l_freq`` or ``h_freq`` directly, and ``l_trans_bandwidth`` and ``h_trans_bandwidth`` are ignored. **Band-pass filter** The frequency response is (approximately) given by:: 1-| ---------- | /| | \ |H| | / | | \ | / | | \ | / | | \ 0-|---------- | | -------------- | | | | | | 0 Fs1 Fp1 Fp2 Fs2 Nyq Where: * Fs1 = Fp1 - l_trans_bandwidth in Hz * Fs2 = Fp2 + h_trans_bandwidth in Hz **Band-stop filter** The frequency response is (approximately) given by:: 1-|--------- ---------- | \ / |H| | \ / | \ / | \ / 0-| ----------- | | | | | | 0 Fp1 Fs1 Fs2 Fp2 Nyq Where ``Fs1 = Fp1 + l_trans_bandwidth`` and ``Fs2 = Fp2 - h_trans_bandwidth``. Multiple stop bands can be specified using arrays. **Low-pass filter** The frequency response is (approximately) given by:: 1-|------------------------ | \ |H| | \ | \ | \ 0-| ---------------- | | | | 0 Fp Fstop Nyq Where ``Fstop = Fp + trans_bandwidth``. **High-pass filter** The frequency response is (approximately) given by:: 1-| ----------------------- | / |H| | / | / | / 0-|--------- | | | | 0 Fstop Fp Nyq Where ``Fstop = Fp - trans_bandwidth``. .. versionadded:: 0.14 """ sfreq = float(sfreq) if sfreq < 0: raise ValueError("sfreq must be positive") # If no data specified, sanity checking will be skipped if data is None: logger.info( "No data specified. Sanity checks related to the length of the signal " "relative to the filter order will be skipped." ) if h_freq is not None: h_freq = np.array(h_freq, float).ravel() if (h_freq > (sfreq / 2.0)).any(): raise ValueError( f"h_freq ({h_freq}) must be less than the Nyquist frequency " f"{sfreq / 2.0}" ) if l_freq is not None: l_freq = np.array(l_freq, float).ravel() if (l_freq == 0).all(): l_freq = None iir_params, method = _check_method(method, iir_params) if l_freq is None and h_freq is None: ( data, sfreq, _, _, _, _, filter_length, phase, fir_window, fir_design, ) = _triage_filter_params( data, sfreq, None, None, None, None, filter_length, method, phase, fir_window, fir_design, ) if method == "iir": out = dict() if iir_params is None else deepcopy(iir_params) out.update(b=np.array([1.0]), a=np.array([1.0])) else: freq = [0, sfreq / 2.0] gain = [1.0, 1.0] if l_freq is None and h_freq is not None: h_freq = h_freq.item() logger.info(f"Setting up low-pass filter at {h_freq:0.2g} Hz") ( data, sfreq, _, f_p, _, f_s, filter_length, phase, fir_window, fir_design, ) = _triage_filter_params( data, sfreq, None, h_freq, None, h_trans_bandwidth, filter_length, method, phase, fir_window, fir_design, ) if method == "iir": out = construct_iir_filter( iir_params, f_p, f_s, sfreq, "lowpass", phase=phase ) else: # 'fir' freq = [0, f_p, f_s] gain = [1, 1, 0] if f_s != sfreq / 2.0: freq += [sfreq / 2.0] gain += [0] elif l_freq is not None and h_freq is None: l_freq = l_freq.item() logger.info(f"Setting up high-pass filter at {l_freq:0.2g} Hz") ( data, sfreq, pass_, _, stop, _, filter_length, phase, fir_window, fir_design, ) = _triage_filter_params( data, sfreq, l_freq, None, l_trans_bandwidth, None, filter_length, method, phase, fir_window, fir_design, ) if method == "iir": out = construct_iir_filter( iir_params, pass_, stop, sfreq, "highpass", phase=phase ) else: # 'fir' freq = [stop, pass_, sfreq / 2.0] gain = [0, 1, 1] if stop != 0: freq = [0] + freq gain = [0] + gain elif l_freq is not None and h_freq is not None: if (l_freq < h_freq).any(): l_freq, h_freq = l_freq.item(), h_freq.item() logger.info( f"Setting up band-pass filter from {l_freq:0.2g} - {h_freq:0.2g} Hz" ) ( data, sfreq, f_p1, f_p2, f_s1, f_s2, filter_length, phase, fir_window, fir_design, ) = _triage_filter_params( data, sfreq, l_freq, h_freq, l_trans_bandwidth, h_trans_bandwidth, filter_length, method, phase, fir_window, fir_design, ) if method == "iir": out = construct_iir_filter( iir_params, [f_p1, f_p2], [f_s1, f_s2], sfreq, "bandpass", phase=phase, ) else: # 'fir' freq = [f_s1, f_p1, f_p2, f_s2] gain = [0, 1, 1, 0] if f_s2 != sfreq / 2.0: freq += [sfreq / 2.0] gain += [0] if f_s1 != 0: freq = [0] + freq gain = [0] + gain else: # This could possibly be removed after 0.14 release, but might # as well leave it in to sanity check notch_filter if len(l_freq) != len(h_freq): raise ValueError("l_freq and h_freq must be the same length") msg = "Setting up band-stop filter" if len(l_freq) == 1: l_freq, h_freq = l_freq.item(), h_freq.item() msg += f" from {h_freq:0.2g} - {l_freq:0.2g} Hz" logger.info(msg) # Note: order of outputs is intentionally switched here! ( data, sfreq, f_s1, f_s2, f_p1, f_p2, filter_length, phase, fir_window, fir_design, ) = _triage_filter_params( data, sfreq, h_freq, l_freq, h_trans_bandwidth, l_trans_bandwidth, filter_length, method, phase, fir_window, fir_design, bands="arr", reverse=True, ) if method == "iir": if len(f_p1) != 1: raise ValueError( "Multiple stop-bands can only be used with method='fir' " "and method='spectrum_fit'" ) out = construct_iir_filter( iir_params, [f_p1[0], f_p2[0]], [f_s1[0], f_s2[0]], sfreq, "bandstop", phase=phase, ) else: # 'fir' freq = np.r_[f_p1, f_s1, f_s2, f_p2] gain = np.r_[ np.ones_like(f_p1), np.zeros_like(f_s1), np.zeros_like(f_s2), np.ones_like(f_p2), ] order = np.argsort(freq) freq = freq[order] gain = gain[order] if freq[0] != 0: freq = np.r_[[0.0], freq] gain = np.r_[[1.0], gain] if freq[-1] != sfreq / 2.0: freq = np.r_[freq, [sfreq / 2.0]] gain = np.r_[gain, [1.0]] if np.any(np.abs(np.diff(gain, 2)) > 1): raise ValueError("Stop bands are not sufficiently separated.") if method == "fir": out = _construct_fir_filter( sfreq, freq, gain, filter_length, phase, fir_window, fir_design ) return out @verbose def notch_filter( x, Fs, freqs, filter_length="auto", notch_widths=None, trans_bandwidth=1, method="fir", iir_params=None, mt_bandwidth=None, p_value=0.05, picks=None, n_jobs=None, copy=True, phase="zero", fir_window="hamming", fir_design="firwin", pad="reflect_limited", *, verbose=None, ): r"""Notch filter for the signal x. Applies a zero-phase notch filter to the signal x, operating on the last dimension. Parameters ---------- x : array Signal to filter. Fs : float Sampling rate in Hz. freqs : float | array of float | None Frequencies to notch filter in Hz, e.g. np.arange(60, 241, 60). Multiple stop-bands can only be used with method='fir' and method='spectrum_fit'. None can only be used with the mode 'spectrum_fit', where an F test is used to find sinusoidal components. %(filter_length_notch)s notch_widths : float | array of float | None Width of the stop band (centred at each freq in freqs) in Hz. If None, freqs / 200 is used. trans_bandwidth : float Width of the transition band in Hz. Only used for ``method='fir'`` and ``method='iir'``. %(method_fir)s 'spectrum_fit' will use multi-taper estimation of sinusoidal components. If freqs=None and method='spectrum_fit', significant sinusoidal components are detected using an F test, and noted by logging. %(iir_params)s mt_bandwidth : float | None The bandwidth of the multitaper windowing function in Hz. Only used in 'spectrum_fit' mode. p_value : float P-value to use in F-test thresholding to determine significant sinusoidal components to remove when method='spectrum_fit' and freqs=None. Note that this will be Bonferroni corrected for the number of frequencies, so large p-values may be justified. %(picks_nostr)s Only supported for 2D (n_channels, n_times) and 3D (n_epochs, n_channels, n_times) data. %(n_jobs_fir)s copy : bool If True, a copy of x, filtered, is returned. Otherwise, it operates on x in place. %(phase)s %(fir_window)s %(fir_design)s %(pad_fir)s The default is ``'reflect_limited'``. %(verbose)s Returns ------- xf : array The x array filtered. See Also -------- filter_data resample Notes ----- The frequency response is (approximately) given by:: 1-|---------- ----------- | \ / |H| | \ / | \ / | \ / 0-| - | | | | | 0 Fp1 freq Fp2 Nyq For each freq in freqs, where ``Fp1 = freq - trans_bandwidth / 2`` and ``Fs2 = freq + trans_bandwidth / 2``. References ---------- Multi-taper removal is inspired by code from the Chronux toolbox, see www.chronux.org and the book "Observed Brain Dynamics" by Partha Mitra & Hemant Bokil, Oxford University Press, New York, 2008. Please cite this in publications if method 'spectrum_fit' is used. """ x = _check_filterable(x, "notch filtered", "notch_filter") iir_params, method = _check_method(method, iir_params, ["spectrum_fit"]) if freqs is not None: freqs = np.atleast_1d(freqs) elif method != "spectrum_fit": raise ValueError("freqs=None can only be used with method spectrum_fit") # Only have to deal with notch_widths for non-autodetect if freqs is not None: if notch_widths is None: notch_widths = freqs / 200.0 elif np.any(notch_widths < 0): raise ValueError("notch_widths must be >= 0") else: notch_widths = np.atleast_1d(notch_widths) if len(notch_widths) == 1: notch_widths = notch_widths[0] * np.ones_like(freqs) elif len(notch_widths) != len(freqs): raise ValueError( "notch_widths must be None, scalar, or the same length as freqs" ) if method in ("fir", "iir"): # Speed this up by computing the fourier coefficients once tb_2 = trans_bandwidth / 2.0 lows = [freq - nw / 2.0 - tb_2 for freq, nw in zip(freqs, notch_widths)] highs = [freq + nw / 2.0 + tb_2 for freq, nw in zip(freqs, notch_widths)] xf = filter_data( x, Fs, highs, lows, picks, filter_length, tb_2, tb_2, n_jobs, method, iir_params, copy, phase, fir_window, fir_design, pad=pad, ) elif method == "spectrum_fit": xf = _mt_spectrum_proc( x, Fs, freqs, notch_widths, mt_bandwidth, p_value, picks, n_jobs, copy, filter_length, ) return xf def _get_window_thresh(n_times, sfreq, mt_bandwidth, p_value): from .time_frequency.multitaper import _compute_mt_params # figure out what tapers to use window_fun, _, _ = _compute_mt_params( n_times, sfreq, mt_bandwidth, False, False, verbose=False ) # F-stat of 1-p point threshold = fstat.ppf(1 - p_value / n_times, 2, 2 * len(window_fun) - 2) return window_fun, threshold def _mt_spectrum_proc( x, sfreq, line_freqs, notch_widths, mt_bandwidth, p_value, picks, n_jobs, copy, filter_length, ): """Call _mt_spectrum_remove.""" # set up array for filtering, reshape to 2D, operate on last axis x, orig_shape, picks = _prep_for_filtering(x, copy, picks) if isinstance(filter_length, str) and filter_length == "auto": filter_length = "10s" if filter_length is None: filter_length = x.shape[-1] filter_length = min(_to_samples(filter_length, sfreq, "", ""), x.shape[-1]) get_wt = partial( _get_window_thresh, sfreq=sfreq, mt_bandwidth=mt_bandwidth, p_value=p_value ) window_fun, threshold = get_wt(filter_length) parallel, p_fun, n_jobs = parallel_func(_mt_spectrum_remove_win, n_jobs) if n_jobs == 1: freq_list = list() for ii, x_ in enumerate(x): if ii in picks: x[ii], f = _mt_spectrum_remove_win( x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_wt ) freq_list.append(f) else: data_new = parallel( p_fun(x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_wt) for xi, x_ in enumerate(x) if xi in picks ) freq_list = [d[1] for d in data_new] data_new = np.array([d[0] for d in data_new]) x[picks, :] = data_new # report found frequencies, but do some sanitizing first by binning into # 1 Hz bins counts = Counter( sum((np.unique(np.round(ff)).tolist() for f in freq_list for ff in f), list()) ) kind = "Detected" if line_freqs is None else "Removed" found_freqs = ( "\n".join( f" {freq:6.2f} : {counts[freq]:4d} window{_pl(counts[freq])}" for freq in sorted(counts) ) or " None" ) logger.info(f"{kind} notch frequencies (Hz):\n{found_freqs}") x.shape = orig_shape return x def _mt_spectrum_remove_win( x, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh ): n_times = x.shape[-1] n_samples = window_fun.shape[1] n_overlap = (n_samples + 1) // 2 x_out = np.zeros_like(x) rm_freqs = list() idx = [0] # Define how to process a chunk of data def process(x_): out = _mt_spectrum_remove( x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh ) rm_freqs.append(out[1]) return (out[0],) # must return a tuple # Define how to store a chunk of fully processed data (it's trivial) def store(x_): stop = idx[0] + x_.shape[-1] x_out[..., idx[0] : stop] += x_ idx[0] = stop _COLA(process, store, n_times, n_samples, n_overlap, sfreq, verbose=False).feed(x) assert idx[0] == n_times return x_out, rm_freqs def _mt_spectrum_remove( x, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh ): """Use MT-spectrum to remove line frequencies. Based on Chronux. If line_freqs is specified, all freqs within notch_width of each line_freq is set to zero. """ from .time_frequency.multitaper import _mt_spectra assert x.ndim == 1 if x.shape[-1] != window_fun.shape[-1]: window_fun, threshold = get_thresh(x.shape[-1]) # drop the even tapers n_tapers = len(window_fun) tapers_odd = np.arange(0, n_tapers, 2) tapers_even = np.arange(1, n_tapers, 2) tapers_use = window_fun[tapers_odd] # sum tapers for (used) odd prolates across time (n_tapers, 1) H0 = np.sum(tapers_use, axis=1) # sum of squares across tapers (1, ) H0_sq = sum_squared(H0) # make "time" vector rads = 2 * np.pi * (np.arange(x.size) / float(sfreq)) # compute mt_spectrum (returning n_ch, n_tapers, n_freq) x_p, freqs = _mt_spectra(x[np.newaxis, :], window_fun, sfreq) # sum of the product of x_p and H0 across tapers (1, n_freqs) x_p_H0 = np.sum(x_p[:, tapers_odd, :] * H0[np.newaxis, :, np.newaxis], axis=1) # resulting calculated amplitudes for all freqs A = x_p_H0 / H0_sq if line_freqs is None: # figure out which freqs to remove using F stat # estimated coefficient x_hat = A * H0[:, np.newaxis] # numerator for F-statistic num = (n_tapers - 1) * (A * A.conj()).real * H0_sq # denominator for F-statistic den = np.sum(np.abs(x_p[:, tapers_odd, :] - x_hat) ** 2, 1) + np.sum( np.abs(x_p[:, tapers_even, :]) ** 2, 1 ) den[den == 0] = np.inf f_stat = num / den # find frequencies to remove indices = np.where(f_stat > threshold)[1] rm_freqs = freqs[indices] else: # specify frequencies indices_1 = np.unique([np.argmin(np.abs(freqs - lf)) for lf in line_freqs]) indices_2 = [ np.logical_and(freqs > lf - nw / 2.0, freqs < lf + nw / 2.0) for lf, nw in zip(line_freqs, notch_widths) ] indices_2 = np.where(np.any(np.array(indices_2), axis=0))[0] indices = np.unique(np.r_[indices_1, indices_2]) rm_freqs = freqs[indices] fits = list() for ind in indices: c = 2 * A[0, ind] fit = np.abs(c) * np.cos(freqs[ind] * rads + np.angle(c)) fits.append(fit) if len(fits) == 0: datafit = 0.0 else: # fitted sinusoids are summed, and subtracted from data datafit = np.sum(fits, axis=0) return x - datafit, rm_freqs def _check_filterable(x, kind="filtered", alternative="filter"): # Let's be fairly strict about this -- users can easily coerce to ndarray # at their end, and we already should do it internally any time we are # using these low-level functions. At the same time, let's # help people who might accidentally use low-level functions that they # shouldn't use by pushing them in the right direction from .epochs import BaseEpochs from .evoked import Evoked from .io import BaseRaw if isinstance(x, BaseRaw | BaseEpochs | Evoked): try: name = x.__class__.__name__ except Exception: pass else: raise TypeError( "This low-level function only operates on np.ndarray instances. To get " f"a {kind} {name} instance, use a method like `inst_new = inst.copy()." f"{alternative}(...)` instead." ) _validate_type(x, (np.ndarray, list, tuple), f"Data to be {kind}") x = np.asanyarray(x) if x.dtype != np.float64: raise ValueError(f"Data to be {kind} must be real floating, got {x.dtype}") return x def _resamp_ratio_len(up, down, n): ratio = float(up) / down return ratio, max(int(round(ratio * n)), 1) @verbose def resample( x, up=1.0, down=1.0, *, axis=-1, window="auto", n_jobs=None, pad="auto", npad=100, method="fft", verbose=None, ): """Resample an array. Operates along the last dimension of the array. Parameters ---------- x : ndarray Signal to resample. up : float Factor to upsample by. down : float Factor to downsample by. axis : int Axis along which to resample (default is the last axis). %(window_resample)s %(n_jobs_cuda)s ``n_jobs='cuda'`` is only supported when ``method="fft"``. %(pad_resample_auto)s .. versionadded:: 0.15 %(npad_resample)s %(method_resample)s .. versionadded:: 1.7 %(verbose)s Returns ------- y : array The x array resampled. Notes ----- When using ``method="fft"`` (default), this uses (hopefully) intelligent edge padding and frequency-domain windowing improve :func:`scipy.signal.resample`'s resampling method, which we have adapted for our use here. Choices of npad and window have important consequences, and the default choices should work well for most natural signals. """ _validate_type(method, str, "method") _validate_type(pad, str, "pad") _check_option("method", method, ("fft", "polyphase")) # make sure our arithmetic will work x = _check_filterable(x, "resampled", "resample") ratio, final_len = _resamp_ratio_len(up, down, x.shape[axis]) del up, down if axis < 0: axis = x.ndim + axis if x.shape[axis] == 0: warn(f"x has zero length along axis={axis}, returning a copy of x") return x.copy() # prep for resampling along the last axis (swap axis with last then reshape) out_shape = list(x.shape) out_shape.pop(axis) out_shape.append(final_len) x = np.atleast_2d(x.swapaxes(axis, -1).reshape((-1, x.shape[axis]))) # do the resampling using FFT or polyphase methods kwargs = dict(pad=pad, window=window, n_jobs=n_jobs) if method == "fft": y = _resample_fft(x, npad=npad, ratio=ratio, final_len=final_len, **kwargs) else: up, down, kwargs["window"] = _prep_polyphase( ratio, x.shape[-1], final_len, window ) half_len = len(window) // 2 logger.info( f"Polyphase resampling neighborhood: ±{half_len} " f"input sample{_pl(half_len)}" ) y = _resample_polyphase(x, up=up, down=down, **kwargs) assert y.shape[-1] == final_len # restore dimensions (reshape then swap axis with last) y = y.reshape(out_shape).swapaxes(axis, -1) return y def _prep_polyphase(ratio, x_len, final_len, window): if isinstance(window, str) and window == "auto": window = ("kaiser", 5.0) # SciPy default up = final_len down = x_len g_ = gcd(up, down) up = up // g_ down = down // g_ # Figure out our signal neighborhood and design window (adapted from SciPy) if not isinstance(window, list | np.ndarray): # Design a linear-phase low-pass FIR filter max_rate = max(up, down) f_c = 1.0 / max_rate # cutoff of FIR filter (rel. to Nyquist) half_len = 10 * max_rate # reasonable cutoff for sinc-like function window = signal.firwin(2 * half_len + 1, f_c, window=window) return up, down, window def _resample_polyphase(x, *, up, down, pad, window, n_jobs): if pad == "auto": pad = "reflect" kwargs = dict(padtype=pad, window=window, up=up, down=down) _validate_type( n_jobs, (None, "int-like"), "n_jobs", extra="when method='polyphase'" ) parallel, p_fun, n_jobs = parallel_func(signal.resample_poly, n_jobs) if n_jobs == 1: y = signal.resample_poly(x, axis=-1, **kwargs) else: y = np.array(parallel(p_fun(x_, **kwargs) for x_ in x)) return y def _resample_fft(x_flat, *, ratio, final_len, pad, window, npad, n_jobs): x_len = x_flat.shape[-1] pad = "reflect_limited" if pad == "auto" else pad if (isinstance(window, str) and window == "auto") or window is None: window = "boxcar" if isinstance(npad, str): _check_option("npad", npad, ("auto",), extra="when a string") # Figure out reasonable pad that gets us to a power of 2 min_add = min(x_len // 8, 100) * 2 npad = 2 ** int(np.ceil(np.log2(x_len + min_add))) - x_len npad, extra = divmod(npad, 2) npads = np.array([npad, npad + extra], int) else: npad = _ensure_int(npad, "npad", extra="or 'auto'") npads = np.array([npad, npad], int) del npad # prep for resampling now orig_len = x_len + npads.sum() # length after padding new_len = max(int(round(ratio * orig_len)), 1) # length after resampling to_removes = [int(round(ratio * npads[0]))] to_removes.append(new_len - final_len - to_removes[0]) to_removes = np.array(to_removes) # This should hold: # assert np.abs(to_removes[1] - to_removes[0]) <= int(np.ceil(ratio)) # figure out windowing function if callable(window): W = window(fft.fftfreq(orig_len)) elif isinstance(window, np.ndarray) and window.shape == (orig_len,): W = window else: W = fft.ifftshift(signal.get_window(window, orig_len)) W *= float(new_len) / float(orig_len) # figure out if we should use CUDA n_jobs, cuda_dict = _setup_cuda_fft_resample(n_jobs, W, new_len) # do the resampling using an adaptation of scipy's FFT-based resample() # use of the 'flat' window is recommended for minimal ringing parallel, p_fun, n_jobs = parallel_func(_fft_resample, n_jobs) if n_jobs == 1: y = np.zeros((len(x_flat), new_len - to_removes.sum()), dtype=x_flat.dtype) for xi, x_ in enumerate(x_flat): y[xi] = _fft_resample(x_, new_len, npads, to_removes, cuda_dict, pad) else: y = parallel( p_fun(x_, new_len, npads, to_removes, cuda_dict, pad) for x_ in x_flat ) y = np.array(y) return y def _resample_stim_channels(stim_data, up, down): """Resample stim channels, carefully. Parameters ---------- stim_data : array, shape (n_samples,) or (n_stim_channels, n_samples) Stim channels to resample. up : float Factor to upsample by. down : float Factor to downsample by. Returns ------- stim_resampled : array, shape (n_stim_channels, n_samples_resampled) The resampled stim channels. Note ---- The approach taken here is equivalent to the approach in the C-code. See the decimate_stimch function in MNE/mne_browse_raw/save.c """ stim_data = np.atleast_2d(stim_data) n_stim_channels, n_samples = stim_data.shape ratio = float(up) / down resampled_n_samples = int(round(n_samples * ratio)) stim_resampled = np.zeros((n_stim_channels, resampled_n_samples)) # Figure out which points in old data to subsample protect against # out-of-bounds, which can happen (having one sample more than # expected) due to padding sample_picks = np.minimum( (np.arange(resampled_n_samples) / ratio).astype(int), n_samples - 1 ) # Create windows starting from sample_picks[i], ending at sample_picks[i+1] windows = zip(sample_picks, np.r_[sample_picks[1:], n_samples]) # Use the first non-zero value in each window for window_i, window in enumerate(windows): for stim_num, stim in enumerate(stim_data): nonzero = stim[window[0] : window[1]].nonzero()[0] if len(nonzero) > 0: val = stim[window[0] + nonzero[0]] else: val = stim[window[0]] stim_resampled[stim_num, window_i] = val return stim_resampled def detrend(x, order=1, axis=-1): """Detrend the array x. Parameters ---------- x : n-d array Signal to detrend. order : int Fit order. Currently must be '0' or '1'. axis : int Axis of the array to operate on. Returns ------- y : array The x array detrended. Examples -------- As in :func:`scipy.signal.detrend`:: >>> randgen = np.random.RandomState(9) >>> npoints = int(1e3) >>> noise = randgen.randn(npoints) >>> x = 3 + 2*np.linspace(0, 1, npoints) + noise >>> bool((detrend(x) - noise).max() < 0.01) True """ if axis > len(x.shape): raise ValueError(f"x does not have {axis} axes") if order == 0: fit = "constant" elif order == 1: fit = "linear" else: raise ValueError("order must be 0 or 1") y = signal.detrend(x, axis=axis, type=fit) return y # Taken from Ifeachor and Jervis p. 356. # Note that here the passband ripple and stopband attenuation are # rendundant. The scalar passband ripple δp is expressed in dB as # 20 * log10(1+δp), but the scalar stopband ripple δs is expressed in dB as # -20 * log10(δs). So if we know that our stopband attenuation is 53 dB # (Hamming) then δs = 10 ** (53 / -20.), which means that the passband # deviation should be 20 * np.log10(1 + 10 ** (53 / -20.)) == 0.0194. _fir_window_dict = { "hann": dict(name="Hann", ripple=0.0546, attenuation=44), "hamming": dict(name="Hamming", ripple=0.0194, attenuation=53), "blackman": dict(name="Blackman", ripple=0.0017, attenuation=74), } _known_fir_windows = tuple(sorted(_fir_window_dict.keys())) _known_phases_fir = ("linear", "zero", "zero-double", "minimum", "minimum-half") _known_phases_iir = ("zero", "zero-double", "forward") _known_fir_designs = ("firwin", "firwin2") _fir_design_dict = { "firwin": "Windowed time-domain", "firwin2": "Windowed frequency-domain", } def _to_samples(filter_length, sfreq, phase, fir_design): _validate_type(filter_length, (str, "int-like"), "filter_length") if isinstance(filter_length, str): filter_length = filter_length.lower() err_msg = ( "filter_length, if a string, must be a " 'human-readable time, e.g. "10s", or "auto", not ' f'"{filter_length}"' ) if filter_length.lower().endswith("ms"): mult_fact = 1e-3 filter_length = filter_length[:-2] elif filter_length[-1].lower() == "s": mult_fact = 1 filter_length = filter_length[:-1] else: raise ValueError(err_msg) # now get the number try: filter_length = float(filter_length) except ValueError: raise ValueError(err_msg) filter_length = max(int(np.ceil(filter_length * mult_fact * sfreq)), 1) if fir_design == "firwin": filter_length += (filter_length - 1) % 2 filter_length = _ensure_int(filter_length, "filter_length") return filter_length def _triage_filter_params( x, sfreq, l_freq, h_freq, l_trans_bandwidth, h_trans_bandwidth, filter_length, method, phase, fir_window, fir_design, bands="scalar", reverse=False, ): """Validate and automate filter parameter selection.""" _validate_type(phase, "str", "phase") if method == "fir": _check_option("phase", phase, _known_phases_fir, extra="when FIR filtering") else: _check_option("phase", phase, _known_phases_iir, extra="when IIR filtering") _validate_type(fir_window, "str", "fir_window") _check_option("fir_window", fir_window, _known_fir_windows) _validate_type(fir_design, "str", "fir_design") _check_option("fir_design", fir_design, _known_fir_designs) # Helpers for reporting report_phase = "non-linear phase" if phase == "minimum" else "zero-phase" causality = "causal" if phase == "minimum" else "non-causal" if phase == "zero-double": report_pass = "two-pass forward and reverse" else: report_pass = "one-pass" if l_freq is not None: if h_freq is not None: kind = "bandstop" if reverse else "bandpass" else: kind = "highpass" assert not reverse elif h_freq is not None: kind = "lowpass" assert not reverse else: kind = "allpass" def float_array(c): return np.array(c, float).ravel() if bands == "arr": cast = float_array else: cast = float sfreq = float(sfreq) if l_freq is not None: l_freq = cast(l_freq) if np.any(l_freq <= 0): raise ValueError(f"highpass frequency {l_freq} must be greater than zero") if h_freq is not None: h_freq = cast(h_freq) if np.any(h_freq >= sfreq / 2.0): raise ValueError( f"lowpass frequency {h_freq} must be less than Nyquist ({sfreq / 2.0})" ) dB_cutoff = False # meaning, don't try to compute or report if bands == "scalar" or (len(h_freq) == 1 and len(l_freq) == 1): if phase == "zero": dB_cutoff = "-6 dB" elif phase == "zero-double": dB_cutoff = "-12 dB" # we go to the next power of two when in FIR and zero-double mode if method == "iir": # Ignore these parameters, effectively l_stop, h_stop = l_freq, h_freq else: # method == 'fir' l_stop = h_stop = None logger.info("") logger.info("FIR filter parameters") logger.info("---------------------") logger.info( f"Designing a {report_pass}, {report_phase}, {causality} {kind} filter:" ) logger.info(f"- {_fir_design_dict[fir_design]} design ({fir_design}) method") this_dict = _fir_window_dict[fir_window] if fir_design == "firwin": logger.info( "- {name:s} window with {ripple:0.4f} passband ripple " "and {attenuation:d} dB stopband attenuation".format(**this_dict) ) else: logger.info("- {name:s} window".format(**this_dict)) if l_freq is not None: # high-pass component if isinstance(l_trans_bandwidth, str): if l_trans_bandwidth != "auto": raise ValueError( 'l_trans_bandwidth must be "auto" if string, got "' f'{l_trans_bandwidth}"' ) l_trans_bandwidth = np.minimum(np.maximum(0.25 * l_freq, 2.0), l_freq) l_trans_rep = np.array(l_trans_bandwidth, float) if l_trans_rep.size == 1: l_trans_rep = f"{l_trans_rep.item():0.2f}" with np.printoptions(precision=2, floatmode="fixed"): msg = f"- Lower transition bandwidth: {l_trans_rep} Hz" if dB_cutoff: l_freq_rep = np.array(l_freq, float) if l_freq_rep.size == 1: l_freq_rep = f"{l_freq_rep.item():0.2f}" cutoff_rep = np.array(l_freq - l_trans_bandwidth / 2.0, float) if cutoff_rep.size == 1: cutoff_rep = f"{cutoff_rep.item():0.2f}" # Could be an array logger.info(f"- Lower passband edge: {l_freq_rep}") msg += f" ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)" logger.info(msg) l_trans_bandwidth = cast(l_trans_bandwidth) if np.any(l_trans_bandwidth <= 0): raise ValueError( f"l_trans_bandwidth must be positive, got {l_trans_bandwidth}" ) l_stop = l_freq - l_trans_bandwidth if reverse: # band-stop style l_stop += l_trans_bandwidth l_freq += l_trans_bandwidth if np.any(l_stop < 0): raise ValueError( "Filter specification invalid: Lower stop frequency negative (" f"{l_stop:0.2f} Hz). Increase pass frequency or reduce the " "transition bandwidth (l_trans_bandwidth)" ) if h_freq is not None: # low-pass component if isinstance(h_trans_bandwidth, str): if h_trans_bandwidth != "auto": raise ValueError( 'h_trans_bandwidth must be "auto" if ' f'string, got "{h_trans_bandwidth}"' ) h_trans_bandwidth = np.minimum( np.maximum(0.25 * h_freq, 2.0), sfreq / 2.0 - h_freq ) h_trans_rep = np.array(h_trans_bandwidth, float) if h_trans_rep.size == 1: h_trans_rep = f"{h_trans_rep.item():0.2f}" with np.printoptions(precision=2, floatmode="fixed"): msg = f"- Upper transition bandwidth: {h_trans_rep} Hz" if dB_cutoff: h_freq_rep = np.array(h_freq, float) if h_freq_rep.size == 1: h_freq_rep = f"{h_freq_rep.item():0.2f}" cutoff_rep = np.array(h_freq + h_trans_bandwidth / 2.0, float) if cutoff_rep.size == 1: cutoff_rep = f"{cutoff_rep.item():0.2f}" logger.info(f"- Upper passband edge: {h_freq_rep} Hz") msg += f" ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)" logger.info(msg) h_trans_bandwidth = cast(h_trans_bandwidth) if np.any(h_trans_bandwidth <= 0): raise ValueError( f"h_trans_bandwidth must be positive, got {h_trans_bandwidth}" ) h_stop = h_freq + h_trans_bandwidth if reverse: # band-stop style h_stop -= h_trans_bandwidth h_freq -= h_trans_bandwidth if np.any(h_stop > sfreq / 2): raise ValueError( f"Effective band-stop frequency ({h_stop}) is too high (maximum " f"based on Nyquist is {sfreq / 2.0})" ) if isinstance(filter_length, str) and filter_length.lower() == "auto": filter_length = filter_length.lower() h_check = l_check = np.inf if h_freq is not None: h_check = min(np.atleast_1d(h_trans_bandwidth)) if l_freq is not None: l_check = min(np.atleast_1d(l_trans_bandwidth)) mult_fact = 2.0 if fir_design == "firwin2" else 1.0 filter_length = f"{_length_factors[fir_window] * mult_fact / float(min(h_check, l_check))}s" # noqa: E501 next_pow_2 = False # disable old behavior else: next_pow_2 = isinstance(filter_length, str) and phase == "zero-double" filter_length = _to_samples(filter_length, sfreq, phase, fir_design) # use correct type of filter (must be odd length for firwin and for # zero phase) if fir_design == "firwin" or phase == "zero": filter_length += (filter_length - 1) % 2 logger.info( f"- Filter length: {filter_length} samples ({filter_length / sfreq:0.3f} s)" ) logger.info("") if filter_length <= 0: raise ValueError(f"filter_length must be positive, got {filter_length}") if next_pow_2: filter_length = 2 ** int(np.ceil(np.log2(filter_length))) if fir_design == "firwin": filter_length += (filter_length - 1) % 2 # If we have data supplied, do a sanity check if x is not None: x = _check_filterable(x) len_x = x.shape[-1] if method != "fir": filter_length = len_x if filter_length > len_x and not (l_freq is None and h_freq is None): warn( f"filter_length ({filter_length}) is longer than the signal ({len_x}), " "distortion is likely. Reduce filter length or filter a longer signal." ) logger.debug(f"Using filter length: {filter_length}") return ( x, sfreq, l_freq, h_freq, l_stop, h_stop, filter_length, phase, fir_window, fir_design, ) def _check_resamp_noop(sfreq, o_sfreq, rtol=1e-6): if np.isclose(sfreq, o_sfreq, atol=0, rtol=rtol): logger.info( f"Sampling frequency of the instance is already {sfreq}, returning " "unmodified." ) return True return False class FilterMixin: """Object for Epoch/Evoked filtering.""" @verbose def savgol_filter(self, h_freq, verbose=None): """Filter the data using Savitzky-Golay polynomial method. Parameters ---------- h_freq : float Approximate high cut-off frequency in Hz. Note that this is not an exact cutoff, since Savitzky-Golay filtering :footcite:`SavitzkyGolay1964` is done using polynomial fits instead of FIR/IIR filtering. This parameter is thus used to determine the length of the window over which a 5th-order polynomial smoothing is used. %(verbose)s Returns ------- inst : instance of Epochs, Evoked or SourceEstimate The object with the filtering applied. See Also -------- mne.io.Raw.filter Notes ----- For Savitzky-Golay low-pass approximation, see: https://gist.github.com/larsoner/bbac101d50176611136b When working on SourceEstimates the sample rate of the original data is inferred from tstep. .. versionadded:: 0.9.0 References ---------- .. footbibliography:: Examples -------- >>> import mne >>> from os import path as op >>> evoked_fname = op.join(mne.datasets.sample.data_path(), 'MEG', 'sample', 'sample_audvis-ave.fif') # doctest:+SKIP >>> evoked = mne.read_evokeds(evoked_fname, baseline=(None, 0))[0] # doctest:+SKIP >>> evoked.savgol_filter(10.) # low-pass at around 10 Hz # doctest:+SKIP >>> evoked.plot() # doctest:+SKIP """ # noqa: E501 from .source_estimate import _BaseSourceEstimate _check_preload(self, "inst.savgol_filter") if not isinstance(self, _BaseSourceEstimate): s_freq = self.info["sfreq"] else: s_freq = 1 / self.tstep h_freq = float(h_freq) if h_freq >= s_freq / 2.0: raise ValueError("h_freq must be less than half the sample rate") # savitzky-golay filtering window_length = (int(np.round(s_freq / h_freq)) // 2) * 2 + 1 logger.info("Using savgol length %d", window_length) self._data[:] = signal.savgol_filter( self._data, axis=-1, polyorder=5, window_length=window_length ) return self @verbose def filter( self, l_freq, h_freq, picks=None, filter_length="auto", l_trans_bandwidth="auto", h_trans_bandwidth="auto", n_jobs=None, method="fir", iir_params=None, phase="zero", fir_window="hamming", fir_design="firwin", skip_by_annotation=("edge", "bad_acq_skip"), pad="edge", *, verbose=None, ): """Filter a subset of channels/vertices. Parameters ---------- %(l_freq)s %(h_freq)s %(picks_all_data)s %(filter_length)s %(l_trans_bandwidth)s %(h_trans_bandwidth)s %(n_jobs_fir)s %(method_fir)s %(iir_params)s %(phase)s %(fir_window)s %(fir_design)s %(skip_by_annotation)s .. versionadded:: 0.16. %(pad_fir)s %(verbose)s Returns ------- inst : instance of Epochs, Evoked, SourceEstimate, or Raw The filtered data. See Also -------- mne.filter.create_filter mne.Evoked.savgol_filter mne.io.Raw.notch_filter mne.io.Raw.resample mne.filter.create_filter mne.filter.filter_data mne.filter.construct_iir_filter Notes ----- Applies a zero-phase low-pass, high-pass, band-pass, or band-stop filter to the channels selected by ``picks``. The data are modified inplace. The object has to have the data loaded e.g. with ``preload=True`` or ``self.load_data()``. ``l_freq`` and ``h_freq`` are the frequencies below which and above which, respectively, to filter out of the data. Thus the uses are: * ``l_freq < h_freq``: band-pass filter * ``l_freq > h_freq``: band-stop filter * ``l_freq is not None and h_freq is None``: high-pass filter * ``l_freq is None and h_freq is not None``: low-pass filter ``self.info['lowpass']`` and ``self.info['highpass']`` are only updated with picks=None. .. note:: If n_jobs > 1, more memory is required as ``len(picks) * n_times`` additional time points need to be temporarily stored in memory. When working on SourceEstimates the sample rate of the original data is inferred from tstep. For more information, see the tutorials :ref:`disc-filtering` and :ref:`tut-filter-resample` and :func:`mne.filter.create_filter`. .. versionadded:: 0.15 """ from .annotations import _annotations_starts_stops from .io import BaseRaw from .source_estimate import _BaseSourceEstimate _check_preload(self, "inst.filter") if not isinstance(self, _BaseSourceEstimate): update_info, picks = _filt_check_picks(self.info, picks, l_freq, h_freq) s_freq = self.info["sfreq"] else: s_freq = 1.0 / self.tstep if pad is None and method != "iir": pad = "edge" if isinstance(self, BaseRaw): # Deal with annotations onsets, ends = _annotations_starts_stops( self, skip_by_annotation, invert=True ) logger.info( "Filtering raw data in %d contiguous segment%s", len(onsets), _pl(onsets), ) else: onsets, ends = np.array([0]), np.array([self._data.shape[1]]) max_idx = (ends - onsets).argmax() for si, (start, stop) in enumerate(zip(onsets, ends)): # Only output filter params once (for info level), and only warn # once about the length criterion (longest segment is too short) use_verbose = verbose if si == max_idx else "error" filter_data( self._data[:, start:stop], s_freq, l_freq, h_freq, picks, filter_length, l_trans_bandwidth, h_trans_bandwidth, n_jobs, method, iir_params, copy=False, phase=phase, fir_window=fir_window, fir_design=fir_design, pad=pad, verbose=use_verbose, ) # update info if filter is applied to all data channels/vertices, # and it's not a band-stop filter if not isinstance(self, _BaseSourceEstimate): _filt_update_info(self.info, update_info, l_freq, h_freq) return self @verbose def resample( self, sfreq, *, npad="auto", window="auto", n_jobs=None, pad="edge", method="fft", verbose=None, ): """Resample data. If appropriate, an anti-aliasing filter is applied before resampling. See :ref:`resampling-and-decimating` for more information. .. note:: Data must be loaded. Parameters ---------- sfreq : float New sample rate to use. %(npad)s %(window_resample)s %(n_jobs_cuda)s %(pad_resample)s .. versionadded:: 0.15 %(method_resample)s .. versionadded:: 1.7 %(verbose)s Returns ------- inst : instance of Epochs or Evoked The resampled object. See Also -------- mne.io.Raw.resample Notes ----- For some data, it may be more accurate to use npad=0 to reduce artifacts. This is dataset dependent -- check your data! """ from .epochs import BaseEpochs from .evoked import Evoked # Should be guaranteed by our inheritance, and the fact that # mne.io.BaseRaw and _BaseSourceEstimate overrides this method assert isinstance(self, BaseEpochs | Evoked) sfreq = float(sfreq) o_sfreq = self.info["sfreq"] if _check_resamp_noop(sfreq, o_sfreq): return self _check_preload(self, "inst.resample") self._data = resample( self._data, sfreq, o_sfreq, npad=npad, window=window, n_jobs=n_jobs, pad=pad, method=method, ) lowpass = self.info.get("lowpass") lowpass = np.inf if lowpass is None else lowpass with self.info._unlock(): self.info["lowpass"] = min(lowpass, sfreq / 2.0) self.info["sfreq"] = float(sfreq) new_times = ( np.arange(self._data.shape[-1], dtype=np.float64) / sfreq + self.times[0] ) # adjust indirectly affected variables self._set_times(new_times) self._raw_times = self.times self._update_first_last() return self @verbose def apply_hilbert( self, picks=None, envelope=False, n_jobs=None, n_fft="auto", *, verbose=None ): """Compute analytic signal or envelope for a subset of channels/vertices. Parameters ---------- %(picks_all_data_noref)s envelope : bool Compute the envelope signal of each channel/vertex. Default False. See Notes. %(n_jobs)s n_fft : int | None | str Points to use in the FFT for Hilbert transformation. The signal will be padded with zeros before computing Hilbert, then cut back to original length. If None, n == self.n_times. If 'auto', the next highest fast FFT length will be use. %(verbose)s Returns ------- self : instance of Raw, Epochs, Evoked or SourceEstimate The raw object with transformed data. Notes ----- **Parameters** If ``envelope=False``, the analytic signal for the channels/vertices defined in ``picks`` is computed and the data of the Raw object is converted to a complex representation (the analytic signal is complex valued). If ``envelope=True``, the absolute value of the analytic signal for the channels/vertices defined in ``picks`` is computed, resulting in the envelope signal. .. warning: Do not use ``envelope=True`` if you intend to compute an inverse solution from the raw data. If you want to compute the envelope in source space, use ``envelope=False`` and compute the envelope after the inverse solution has been obtained. If envelope=False, more memory is required since the original raw data as well as the analytic signal have temporarily to be stored in memory. If n_jobs > 1, more memory is required as ``len(picks) * n_times`` additional time points need to be temporarily stored in memory. Also note that the ``n_fft`` parameter will allow you to pad the signal with zeros before performing the Hilbert transform. This padding is cut off, but it may result in a slightly different result (particularly around the edges). Use at your own risk. **Analytic signal** The analytic signal "x_a(t)" of "x(t)" is:: x_a = F^{-1}(F(x) 2U) = x + i y where "F" is the Fourier transform, "U" the unit step function, and "y" the Hilbert transform of "x". One usage of the analytic signal is the computation of the envelope signal, which is given by "e(t) = abs(x_a(t))". Due to the linearity of Hilbert transform and the MNE inverse solution, the enevlope in source space can be obtained by computing the analytic signal in sensor space, applying the MNE inverse, and computing the envelope in source space. """ from .source_estimate import _BaseSourceEstimate if not isinstance(self, _BaseSourceEstimate): use_info = self.info else: use_info = len(self._data) _check_preload(self, "inst.apply_hilbert") picks = _picks_to_idx(use_info, picks, exclude=(), with_ref_meg=False) if n_fft is None: n_fft = len(self.times) elif isinstance(n_fft, str): if n_fft != "auto": raise ValueError( f"n_fft must be an integer, string, or None, got {type(n_fft)}" ) n_fft = next_fast_len(len(self.times)) n_fft = int(n_fft) if n_fft < len(self.times): raise ValueError( f"n_fft ({n_fft}) must be at least the number of time points (" f"{len(self.times)})" ) dtype = None if envelope else np.complex128 args, kwargs = (), dict(n_fft=n_fft, envelope=envelope) data_in = self._data if dtype is not None and dtype != self._data.dtype: self._data = self._data.astype(dtype) parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs) if n_jobs == 1: # modify data inplace to save memory for idx in picks: self._data[..., idx, :] = _check_fun( _my_hilbert, data_in[..., idx, :], *args, **kwargs ) else: # use parallel function data_picks_new = parallel( p_fun(_my_hilbert, data_in[..., p, :], *args, **kwargs) for p in picks ) for pp, p in enumerate(picks): self._data[..., p, :] = data_picks_new[pp] return self def _check_fun(fun, d, *args, **kwargs): """Check shapes.""" want_shape = d.shape d = fun(d, *args, **kwargs) if not isinstance(d, np.ndarray): raise TypeError("Return value must be an ndarray") if d.shape != want_shape: raise ValueError(f"Return data must have shape {want_shape} not {d.shape}") return d def _my_hilbert(x, n_fft=None, envelope=False): """Compute Hilbert transform of signals w/ zero padding. Parameters ---------- x : array, shape (n_times) The signal to convert n_fft : int Size of the FFT to perform, must be at least ``len(x)``. The signal will be cut back to original length. envelope : bool Whether to compute amplitude of the hilbert transform in order to return the signal envelope. Returns ------- out : array, shape (n_times) The hilbert transform of the signal, or the envelope. """ n_x = x.shape[-1] out = signal.hilbert(x, N=n_fft, axis=-1)[..., :n_x] if envelope: out = np.abs(out) return out @verbose def design_mne_c_filter( sfreq, l_freq=None, h_freq=40.0, l_trans_bandwidth=None, h_trans_bandwidth=5.0, verbose=None, ): """Create a FIR filter like that used by MNE-C. Parameters ---------- sfreq : float The sample frequency. l_freq : float | None The low filter frequency in Hz, default None. Can be None to avoid high-passing. h_freq : float The high filter frequency in Hz, default 40. Can be None to avoid low-passing. l_trans_bandwidth : float | None Low transition bandwidthin Hz. Can be None (default) to use 3 samples. h_trans_bandwidth : float High transition bandwidth in Hz. %(verbose)s Returns ------- h : ndarray, shape (8193,) The linear-phase (symmetric) FIR filter coefficients. Notes ----- This function is provided mostly for reference purposes. MNE-C uses a frequency-domain filter design technique by creating a linear-phase filter of length 8193. In the frequency domain, the 4197 frequencies are directly constructed, with zeroes in the stop-band and ones in the passband, with squared cosine ramps in between. """ n_freqs = (4096 + 2 * 2048) // 2 + 1 freq_resp = np.ones(n_freqs) l_freq = 0 if l_freq is None else float(l_freq) if l_trans_bandwidth is None: l_width = 3 else: l_width = (int(((n_freqs - 1) * l_trans_bandwidth) / (0.5 * sfreq)) + 1) // 2 l_start = int(((n_freqs - 1) * l_freq) / (0.5 * sfreq)) h_freq = sfreq / 2.0 if h_freq is None else float(h_freq) h_width = (int(((n_freqs - 1) * h_trans_bandwidth) / (0.5 * sfreq)) + 1) // 2 h_start = int(((n_freqs - 1) * h_freq) / (0.5 * sfreq)) logger.info( "filter : %7.3f ... %6.1f Hz bins : %d ... %d of %d " "hpw : %d lpw : %d", l_freq, h_freq, l_start, h_start, n_freqs, l_width, h_width, ) if l_freq > 0: start = l_start - l_width + 1 stop = start + 2 * l_width - 1 if start < 0 or stop >= n_freqs: raise RuntimeError("l_freq too low or l_trans_bandwidth too large") freq_resp[:start] = 0.0 k = np.arange(-l_width + 1, l_width) / float(l_width) + 3.0 freq_resp[start:stop] = np.cos(np.pi / 4.0 * k) ** 2 if h_freq < sfreq / 2.0: start = h_start - h_width + 1 stop = start + 2 * h_width - 1 if start < 0 or stop >= n_freqs: raise RuntimeError("h_freq too high or h_trans_bandwidth too large") k = np.arange(-h_width + 1, h_width) / float(h_width) + 1.0 freq_resp[start:stop] *= np.cos(np.pi / 4.0 * k) ** 2 freq_resp[stop:] = 0.0 # Get the time-domain version of this signal h = fft.irfft(freq_resp, n=2 * len(freq_resp) - 1) h = np.roll(h, n_freqs - 1) # center the impulse like a linear-phase filt return h def _filt_check_picks(info, picks, h_freq, l_freq): update_info = False # This will pick *all* data channels picks = _picks_to_idx(info, picks, "data_or_ica", exclude=()) if h_freq is not None or l_freq is not None: data_picks = _picks_to_idx( info, None, "data_or_ica", exclude=(), allow_empty=True ) if len(data_picks) == 0: logger.info( "No data channels found. The highpass and " "lowpass values in the measurement info will not " "be updated." ) elif np.isin(data_picks, picks).all(): update_info = True else: logger.info( "Filtering a subset of channels. The highpass and " "lowpass values in the measurement info will not " "be updated." ) return update_info, picks def _filt_update_info(info, update_info, l_freq, h_freq): if update_info: if ( h_freq is not None and (l_freq is None or l_freq < h_freq) and (info["lowpass"] is None or h_freq < info["lowpass"]) ): with info._unlock(): info["lowpass"] = float(h_freq) if ( l_freq is not None and (h_freq is None or l_freq < h_freq) and (info["highpass"] is None or l_freq > info["highpass"]) ): with info._unlock(): info["highpass"] = float(l_freq)