# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np
from scipy.interpolate import interp1d
from scipy.signal.windows import hann

from .._fiff.pick import _picks_to_idx
from ..epochs import BaseEpochs
from ..event import find_events
from ..evoked import Evoked
from ..io import BaseRaw
from ..utils import _check_option, _check_preload, _validate_type, fill_doc


def _time_to_samp(time, sfreq):
    """Convert a time in seconds to a sample index, rounding up.

    ``time * sfreq`` is rounded to 6 decimals before the ceiling is taken so
    that floating-point noise (e.g. ``0.07 * 100 == 7.000000000000001``)
    cannot push an on-grid time one sample too far; genuinely fractional
    sample positions are still rounded up.
    """
    return int(np.ceil(np.round(time * sfreq, 6)))


def _get_window(start, end):
    """Return a ``1 - hann`` taper of length ``abs(end - start)``.

    The two outermost samples on each side come from a 4-point Hann window
    (values 1 and 0.25 after ``1 - hann``); every interior sample is 0, so
    multiplying data by this window zeroes the middle of the span.
    Requires ``abs(end - start) >= 4``.
    """
    edges = hann(4)
    n_middle = np.abs(end - start) - 4
    return 1 - np.r_[edges[:2], np.ones(n_middle), edges[-2:]]


def _fix_artifact(
    data, window, picks, first_samp, last_samp, base_tmin, base_tmax, mode
):
    """Overwrite ``data[picks, first_samp:last_samp]`` in place.

    Parameters
    ----------
    data : ndarray, shape (n_channels, n_times)
        Data to modify in place.
    window : ndarray | None
        Multiplicative window used when ``mode == "window"``; its length must
        equal ``last_samp - first_samp``.
    picks : array of int
        Channel (row) indices to modify.
    first_samp, last_samp : int
        Column span to replace (``last_samp`` is exclusive).
    base_tmin, base_tmax : int
        Column span whose per-channel mean is used when
        ``mode == "constant"``.
    mode : "linear" | "window" | "constant"
        ``"linear"`` interpolates between the values in columns
        ``first_samp`` and ``last_samp``; ``"window"`` multiplies the span by
        ``window``; ``"constant"`` fills the span with the baseline mean.
    """
    if mode == "linear":
        x = np.array([first_samp, last_samp])
        # Endpoint values (n_picks, 2) that anchor the straight line.
        f = interp1d(x, data[:, (first_samp, last_samp)][picks])
        xnew = np.arange(first_samp, last_samp)
        data[picks, first_samp:last_samp] = f(xnew)
    elif mode == "window":
        data[picks, first_samp:last_samp] = (
            data[picks, first_samp:last_samp] * window[np.newaxis, :]
        )
    elif mode == "constant":
        data[picks, first_samp:last_samp] = data[picks, base_tmin:base_tmax].mean(
            axis=1
        )[:, None]


@fill_doc
def fix_stim_artifact(
    inst,
    events=None,
    event_id=None,
    tmin=0.0,
    tmax=0.01,
    *,
    baseline=None,
    mode="linear",
    stim_channel=None,
    picks=None,
):
    """Eliminate stimulation's artifacts from instance.

    .. note:: This function operates in-place, consider passing
              ``inst.copy()`` if this is not desired.

    Parameters
    ----------
    inst : instance of Raw or Epochs or Evoked
        The data.
    events : array, shape (n_events, 3)
        The list of events. Only used when inst is Raw; if None, events are
        detected with :func:`mne.find_events` using ``stim_channel``.
    event_id : int
        The id of the events generating the stimulation artifacts.
        If None, use all events. Only used when inst is Raw.
    tmin : float
        Start time of the interpolation window in seconds. For Raw it is
        relative to each event; for Epochs and Evoked it is relative to
        time zero of the data.
    tmax : float
        End time of the interpolation window in seconds (same reference as
        ``tmin``).
    baseline : None | tuple, shape (2,)
        The baseline to use when ``mode='constant'``, in which case it
        must be non-None and span at least one sample. Times use the same
        reference as ``tmin``.

        .. versionadded:: 1.8
    mode : 'linear' | 'window' | 'constant'
        Way to fill the artifacted time interval.

        ``"linear"``
            Does linear interpolation.
        ``"window"``
            Applies a ``(1 - hanning)`` window.
        ``"constant"``
            Uses baseline average. baseline parameter must be provided.

        .. versionchanged:: 1.8
           Added the ``"constant"`` mode.
    stim_channel : str | None
        Stim channel to use.
    %(picks_all_data)s

    Returns
    -------
    inst : instance of Raw or Evoked or Epochs
        Instance with modified data.
    """
    _check_option("mode", mode, ["linear", "window", "constant"])
    sfreq = inst.info["sfreq"]
    s_start = _time_to_samp(tmin, sfreq)
    s_end = _time_to_samp(tmax, sfreq)
    if mode == "constant":
        _validate_type(
            baseline, (tuple, list), "baseline", extra="when mode='constant'"
        )
        _check_option("len(baseline)", len(baseline), [2])
        for bi, b in enumerate(baseline):
            _validate_type(
                b, "numeric", f"baseline[{bi}]", extra="when mode='constant'"
            )
        b_start = _time_to_samp(baseline[0], sfreq)
        b_end = _time_to_samp(baseline[1], sfreq)
        # An empty baseline would make the mean NaN and silently fill the
        # artifact span with NaN, so reject it up front.
        if b_end <= b_start:
            raise ValueError(
                "baseline must span at least one sample (baseline[0] < "
                f"baseline[1]), got {tuple(baseline)}."
            )
    else:
        # Baseline indices are unused outside "constant" mode.
        b_start = b_end = np.nan

    if (mode == "window") and (s_end - s_start) < 4:
        raise ValueError(
            'Time range is too short. Use a larger interval or set mode to "linear".'
        )
    window = _get_window(s_start, s_end) if mode == "window" else None

    picks = _picks_to_idx(inst.info, picks, "data", exclude=())

    _check_preload(inst, "fix_stim_artifact")
    if isinstance(inst, BaseRaw):
        if events is None:
            events = find_events(inst, stim_channel=stim_channel)
        if len(events) == 0:
            raise ValueError("No events are found")
        if event_id is None:
            events_sel = np.arange(len(events))
        else:
            events_sel = events[:, 2] == event_id
        event_start = events[events_sel, 0]
        data = inst._data
        for event_idx in event_start:
            # Column of this event within inst._data.
            offset = int(event_idx) - inst.first_samp
            _fix_artifact(
                data,
                window,
                picks,
                offset + s_start,
                offset + s_end,
                offset + b_start,
                offset + b_end,
                mode,
            )
    elif isinstance(inst, BaseEpochs):
        if inst.reject is not None:
            raise RuntimeError(
                "Reject is already applied. Use reject=None in the constructor."
            )
        # Sample index (relative to time zero) of the first column of data;
        # inst.tmin lies on the sample grid, so _time_to_samp recovers it
        # exactly even when tmin * sfreq carries floating-point noise.
        e_start = _time_to_samp(inst.tmin, sfreq)
        first_samp = s_start - e_start
        last_samp = s_end - e_start
        base_t1 = b_start - e_start
        base_t2 = b_end - e_start
        # Each `epoch` is a view into inst._data, so edits land in place.
        for epoch in inst._data:
            _fix_artifact(
                epoch, window, picks, first_samp, last_samp, base_t1, base_t2, mode
            )
    elif isinstance(inst, Evoked):
        first_samp = s_start - inst.first
        last_samp = s_end - inst.first
        base_t1 = b_start - inst.first
        base_t2 = b_end - inst.first
        _fix_artifact(
            inst.data, window, picks, first_samp, last_samp, base_t1, base_t2, mode
        )
    else:
        raise TypeError(f"Not a Raw or Epochs or Evoked (got {type(inst)}).")

    return inst