initial commit
This commit is contained in:
98
mne/preprocessing/_lof.py
Normal file
98
mne/preprocessing/_lof.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Bad channel detection using Local Outlier Factor (LOF)."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_to_idx
|
||||
from ..io.base import BaseRaw
|
||||
from ..utils import _soft_import, _validate_type, logger, verbose
|
||||
|
||||
|
||||
@verbose
|
||||
def find_bad_channels_lof(
|
||||
raw,
|
||||
n_neighbors=20,
|
||||
*,
|
||||
picks=None,
|
||||
metric="euclidean",
|
||||
threshold=1.5,
|
||||
return_scores=False,
|
||||
verbose=None,
|
||||
):
|
||||
"""Find bad channels using Local Outlier Factor (LOF) algorithm.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
Raw data to process.
|
||||
n_neighbors : int
|
||||
Number of neighbors defining the local neighborhood (default is 20).
|
||||
Smaller values will lead to higher LOF scores.
|
||||
%(picks_good_data)s
|
||||
metric : str
|
||||
Metric to use for distance computation. Default is “euclidean”,
|
||||
see :func:`sklearn.metrics.pairwise.distance_metrics` for details.
|
||||
threshold : float
|
||||
Threshold to define outliers. Theoretical threshold ranges anywhere
|
||||
between 1.0 and any positive integer. Default: 1.5
|
||||
It is recommended to consider this as an hyperparameter to optimize.
|
||||
return_scores : bool
|
||||
If ``True``, return a dictionary with LOF scores for each
|
||||
evaluated channel. Default is ``False``.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
noisy_chs : list
|
||||
List of bad M/EEG channels that were automatically detected.
|
||||
scores : ndarray, shape (n_picks,)
|
||||
Only returned when ``return_scores`` is ``True``. It contains the
|
||||
LOF outlier score for each channel in ``picks``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
maxwell_filter
|
||||
annotate_amplitude
|
||||
|
||||
Notes
|
||||
-----
|
||||
See :footcite:`KumaravelEtAl2022` and :footcite:`BreunigEtAl2000` for background on
|
||||
choosing ``threshold``.
|
||||
|
||||
.. versionadded:: 1.7
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
""" # noqa: E501
|
||||
_soft_import("sklearn", "using LOF detection", strict=True)
|
||||
from sklearn.neighbors import LocalOutlierFactor
|
||||
|
||||
_validate_type(raw, BaseRaw, "raw")
|
||||
# Get the channel types
|
||||
channel_types = raw.get_channel_types()
|
||||
picks = _picks_to_idx(raw.info, picks=picks, none="data", exclude="bads")
|
||||
picked_ch_types = set(channel_types[p] for p in picks)
|
||||
|
||||
# Check if there are different channel types
|
||||
if len(picked_ch_types) != 1:
|
||||
raise ValueError(
|
||||
f"Need exactly one channel type in picks, got {sorted(picked_ch_types)}"
|
||||
)
|
||||
ch_names = [raw.ch_names[pick] for pick in picks]
|
||||
data = raw.get_data(picks=picks)
|
||||
clf = LocalOutlierFactor(n_neighbors=n_neighbors, metric=metric)
|
||||
clf.fit_predict(data)
|
||||
scores_lof = clf.negative_outlier_factor_
|
||||
bad_channel_indices = [
|
||||
i for i, v in enumerate(np.abs(scores_lof)) if v >= threshold
|
||||
]
|
||||
bads = [ch_names[idx] for idx in bad_channel_indices]
|
||||
logger.info(f"LOF: Detected bad channel(s): {bads}")
|
||||
if return_scores:
|
||||
return bads, scores_lof
|
||||
else:
|
||||
return bads
|
||||
Reference in New Issue
Block a user