initial commit
mne/utils/__init__.py (new file)
@@ -0,0 +1,7 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import lazy_loader as lazy

(__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__)

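The stub pattern above defers all imports: ``lazy.attach_stub`` reads the adjacent ``__init__.pyi`` (shown next) and resolves each name only on first attribute access. A minimal sketch of the effect (the printed size string is an assumption about ``sizeof_fmt``'s exact formatting):

    import mne.utils

    # Nothing from the submodules is imported yet; first access triggers
    # the `from .misc import sizeof_fmt` recorded in the stub file.
    print(mne.utils.sizeof_fmt(2048))  # e.g. '2.0 KiB'
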
mne/utils/__init__.pyi (new file)
@@ -0,0 +1,384 @@
__all__ = [
    "ArgvSetter",
    "Bunch",
    "BunchConst",
    "BunchConstNamed",
    "ClosingStringIO",
    "ExtendedTimeMixin",
    "GetEpochsMixin",
    "ProgressBar",
    "SizeMixin",
    "TimeMixin",
    "_DefaultEventParser",
    "_PCA",
    "_ReuseCycle",
    "_TempDir",
    "_apply_scaling_array",
    "_apply_scaling_cov",
    "_arange_div",
    "_array_equal_nan",
    "_array_repr",
    "_assert_no_instances",
    "_auto_weakref",
    "_build_data_frame",
    "_check_all_same_channel_names",
    "_check_ch_locs",
    "_check_channels_spatial_filter",
    "_check_combine",
    "_check_compensation_grade",
    "_check_decim",
    "_check_depth",
    "_check_dict_keys",
    "_check_dt",
    "_check_edfio_installed",
    "_check_eeglabio_installed",
    "_check_event_id",
    "_check_fname",
    "_check_freesurfer_home",
    "_check_head_radius",
    "_check_if_nan",
    "_check_info_inv",
    "_check_integer_or_list",
    "_check_method_kwargs",
    "_check_on_missing",
    "_check_one_ch_type",
    "_check_option",
    "_check_pandas_index_arguments",
    "_check_pandas_installed",
    "_check_preload",
    "_check_pybv_installed",
    "_check_pymatreader_installed",
    "_check_qt_version",
    "_check_range",
    "_check_rank",
    "_check_sphere",
    "_check_src_normal",
    "_check_stc_units",
    "_check_subject",
    "_check_time_format",
    "_clean_names",
    "_click_ch_name",
    "_compute_row_norms",
    "_convert_times",
    "_custom_lru_cache",
    "_doc_special_members",
    "_date_to_julian",
    "_dt_to_stamp",
    "_empty_hash",
    "_ensure_events",
    "_ensure_int",
    "_explain_exception",
    "_file_like",
    "_freq_mask",
    "_gen_events",
    "_get_argvalues",
    "_get_blas_funcs",
    "_get_call_line",
    "_get_extra_data_path",
    "_get_inst_data",
    "_get_numpy_libs",
    "_get_root_dir",
    "_get_stim_channel",
    "_hashable_ndarray",
    "_import_h5io_funcs",
    "_import_h5py",
    "_import_nibabel",
    "_import_pymatreader_funcs",
    "_is_numeric",
    "_julian_to_date",
    "_mask_to_onsets_offsets",
    "_on_missing",
    "_parse_verbose",
    "_path_like",
    "_pl",
    "_prepare_read_metadata",
    "_prepare_write_metadata",
    "_raw_annot",
    "_record_warnings",
    "_reg_pinv",
    "_reject_data_segments",
    "_repeated_svd",
    "_replace_md5",
    "_require_version",
    "_resource_path",
    "_safe_input",
    "_scale_dataframe_data",
    "_scaled_array",
    "_set_pandas_dtype",
    "_soft_import",
    "_stamp_to_dt",
    "_suggest",
    "_svd_lwork",
    "_sym_mat_pow",
    "_time_mask",
    "_to_rgb",
    "_undo_scaling_array",
    "_undo_scaling_cov",
    "_url_to_local_path",
    "_validate_type",
    "_verbose_safe_false",
    "array_split_idx",
    "assert_and_remove_boundary_annot",
    "assert_dig_allclose",
    "assert_meg_snr",
    "assert_object_equal",
    "assert_snr",
    "assert_stcs_equal",
    "buggy_mkl_svd",
    "catch_logging",
    "check_fname",
    "check_random_state",
    "check_version",
    "compute_corr",
    "copy_doc",
    "copy_function_doc_to_method_doc",
    "create_slices",
    "deprecated",
    "deprecated_alias",
    "eigh",
    "fill_doc",
    "filter_out_warnings",
    "get_config",
    "get_config_path",
    "get_subjects_dir",
    "grand_average",
    "has_freesurfer",
    "has_mne_c",
    "hashfunc",
    "int_like",
    "legacy",
    "linkcode_resolve",
    "logger",
    "object_diff",
    "object_hash",
    "object_size",
    "open_docs",
    "path_like",
    "pformat",
    "pinv",
    "pinvh",
    "random_permutation",
    "repr_html",
    "requires_freesurfer",
    "requires_good_network",
    "requires_mne",
    "requires_mne_mark",
    "requires_openmeeg_mark",
    "run_command_if_main",
    "run_subprocess",
    "running_subprocess",
    "set_cache_dir",
    "set_config",
    "set_log_file",
    "set_log_level",
    "set_memmap_min_size",
    "sizeof_fmt",
    "split_list",
    "sqrtm_sym",
    "sum_squared",
    "sys_info",
    "use_log_level",
    "verbose",
    "warn",
    "wrapped_stdout",
]
from ._bunch import Bunch, BunchConst, BunchConstNamed
from ._logging import (
    ClosingStringIO,
    _get_call_line,
    _parse_verbose,
    _record_warnings,
    _verbose_safe_false,
    catch_logging,
    filter_out_warnings,
    logger,
    set_log_file,
    set_log_level,
    use_log_level,
    verbose,
    warn,
    wrapped_stdout,
)
from ._testing import (
    ArgvSetter,
    _click_ch_name,
    _raw_annot,
    _TempDir,
    assert_and_remove_boundary_annot,
    assert_dig_allclose,
    assert_meg_snr,
    assert_object_equal,
    assert_snr,
    assert_stcs_equal,
    buggy_mkl_svd,
    has_freesurfer,
    has_mne_c,
    requires_freesurfer,
    requires_good_network,
    requires_mne,
    requires_mne_mark,
    requires_openmeeg_mark,
    run_command_if_main,
)
from .check import (
    _check_all_same_channel_names,
    _check_ch_locs,
    _check_channels_spatial_filter,
    _check_combine,
    _check_compensation_grade,
    _check_depth,
    _check_dict_keys,
    _check_edfio_installed,
    _check_eeglabio_installed,
    _check_event_id,
    _check_fname,
    _check_freesurfer_home,
    _check_head_radius,
    _check_if_nan,
    _check_info_inv,
    _check_integer_or_list,
    _check_method_kwargs,
    _check_on_missing,
    _check_one_ch_type,
    _check_option,
    _check_pandas_index_arguments,
    _check_pandas_installed,
    _check_preload,
    _check_pybv_installed,
    _check_pymatreader_installed,
    _check_qt_version,
    _check_range,
    _check_rank,
    _check_sphere,
    _check_src_normal,
    _check_stc_units,
    _check_subject,
    _check_time_format,
    _ensure_events,
    _ensure_int,
    _import_h5io_funcs,
    _import_h5py,
    _import_nibabel,
    _import_pymatreader_funcs,
    _is_numeric,
    _on_missing,
    _path_like,
    _require_version,
    _safe_input,
    _soft_import,
    _suggest,
    _to_rgb,
    _validate_type,
    check_fname,
    check_random_state,
    check_version,
    int_like,
    path_like,
)
from .config import (
    _get_extra_data_path,
    _get_numpy_libs,
    _get_root_dir,
    _get_stim_channel,
    get_config,
    get_config_path,
    get_subjects_dir,
    set_cache_dir,
    set_config,
    set_memmap_min_size,
    sys_info,
)
from .dataframe import (
    _build_data_frame,
    _convert_times,
    _scale_dataframe_data,
    _set_pandas_dtype,
)
from .docs import (
    _doc_special_members,
    copy_doc,
    copy_function_doc_to_method_doc,
    deprecated,
    deprecated_alias,
    fill_doc,
    legacy,
    linkcode_resolve,
    open_docs,
)
from .fetching import _url_to_local_path
from .linalg import (
    _get_blas_funcs,
    _repeated_svd,
    _svd_lwork,
    _sym_mat_pow,
    eigh,
    pinv,
    pinvh,
    sqrtm_sym,
)
from .misc import (
    _assert_no_instances,
    _auto_weakref,
    _clean_names,
    _DefaultEventParser,
    _empty_hash,
    _explain_exception,
    _file_like,
    _get_argvalues,
    _pl,
    _resource_path,
    pformat,
    repr_html,
    run_subprocess,
    running_subprocess,
    sizeof_fmt,
)
from .mixin import (
    ExtendedTimeMixin,
    GetEpochsMixin,
    SizeMixin,
    TimeMixin,
    _check_decim,
    _prepare_read_metadata,
    _prepare_write_metadata,
)
from .numerics import (
    _PCA,
    _apply_scaling_array,
    _apply_scaling_cov,
    _arange_div,
    _array_equal_nan,
    _array_repr,
    _check_dt,
    _compute_row_norms,
    _custom_lru_cache,
    _date_to_julian,
    _dt_to_stamp,
    _freq_mask,
    _gen_events,
    _get_inst_data,
    _hashable_ndarray,
    _julian_to_date,
    _mask_to_onsets_offsets,
    _reg_pinv,
    _reject_data_segments,
    _replace_md5,
    _ReuseCycle,
    _scaled_array,
    _stamp_to_dt,
    _time_mask,
    _undo_scaling_array,
    _undo_scaling_cov,
    array_split_idx,
    compute_corr,
    create_slices,
    grand_average,
    hashfunc,
    object_diff,
    object_hash,
    object_size,
    random_permutation,
    split_list,
    sum_squared,
)
from .progressbar import ProgressBar

mne/utils/_bunch.py (new file)
@@ -0,0 +1,104 @@
"""Bunch-related classes."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from copy import deepcopy

###############################################################################
# Create a Bunch class that acts like a struct (mybunch.key = val)


class Bunch(dict):
    """Dictionary-like object that exposes its keys as attributes."""

    def __init__(self, **kwargs):
        dict.__init__(self, kwargs)
        self.__dict__ = self


###############################################################################
# A protected version that prevents overwriting


class BunchConst(Bunch):
    """Class to prevent us from re-defining constants (DRY)."""

    def __setitem__(self, key, val):  # noqa: D105
        if key != "__dict__" and key in self:
            raise AttributeError(f"Attribute {repr(key)} already set")
        super().__setitem__(key, val)


###############################################################################
# A version that tweaks the __repr__ of its values based on keys


class BunchConstNamed(BunchConst):
    """Class to provide nice __repr__ for our integer constants.

    Only supports string keys and int or float values.
    """

    def __setattr__(self, attr, val):  # noqa: D105
        assert isinstance(attr, str)
        if isinstance(val, int):
            val = NamedInt(attr, val)
        elif isinstance(val, float):
            val = NamedFloat(attr, val)
        else:
            assert isinstance(val, BunchConstNamed), type(val)
        super().__setattr__(attr, val)


class _Named:
    """Provide shared methods for giving named-representation subclasses."""

    def __new__(cls, name, val):  # noqa: D102,D105
        out = _named_subclass(cls).__new__(cls, val)
        out._name = name
        return out

    def __str__(self):  # noqa: D105
        return f"{self.__class__.mro()[-2](self)} ({self._name})"

    __repr__ = __str__

    # see https://stackoverflow.com/a/15774013/2175965
    def __copy__(self):  # noqa: D105
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def __deepcopy__(self, memo):  # noqa: D105
        cls = self.__class__
        result = cls.__new__(cls, self._name, self)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo))
        return result

    def __getnewargs__(self):  # noqa: D105
        return self._name, _named_subclass(self)(self)


def _named_subclass(klass):
    if not isinstance(klass, type):
        klass = klass.__class__
    subklass = klass.mro()[-2]
    assert subklass in (int, float)
    return subklass


class NamedInt(_Named, int):
    """Int with a name in __repr__."""

    pass  # noqa


class NamedFloat(_Named, float):
    """Float with a name in __repr__."""

    pass  # noqa

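Behavior sketch for the classes above (illustrative values):

    from mne.utils import Bunch, BunchConst, BunchConstNamed

    b = Bunch(sfreq=1000.0)
    b.lowpass = 40.0                 # keys and attributes are interchangeable
    assert b["lowpass"] == b.lowpass == 40.0

    const = BunchConst(FIFFV_POINT_EEG=3)
    try:
        const["FIFFV_POINT_EEG"] = 4         # re-defining a constant
    except AttributeError as err:
        print(err)                           # Attribute 'FIFFV_POINT_EEG' already set

    named = BunchConstNamed()
    named.FIFFV_POINT_EEG = 3                # wrapped as NamedInt
    print(repr(named.FIFFV_POINT_EEG))       # 3 (FIFFV_POINT_EEG)
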
mne/utils/_logging.py (new file)
@@ -0,0 +1,527 @@
"""Some utility functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import contextlib
import importlib
import inspect
import logging
import os.path as op
import re
import sys
import warnings
from collections.abc import Callable
from io import StringIO
from typing import Any, TypeVar

from decorator import FunctionMaker

from .docs import fill_doc

logger = logging.getLogger("mne")  # one selection here used across mne-python
logger.propagate = False  # don't propagate (in case of multiple imports)


# class to provide frame information (should be low overhead, just on logger
# calls)


class _FrameFilter(logging.Filter):
    def __init__(self):
        self.add_frames = 0

    def filter(self, record):
        record.frame_info = "Unknown"
        if self.add_frames:
            # 5 is the offset necessary to get out of here and the logging
            # module, reversal is to put the oldest at the top
            frame_info = _frame_info(5 + self.add_frames)[5:][::-1]
            if len(frame_info):
                frame_info[-1] = (frame_info[-1] + " :").ljust(30)
                if len(frame_info) > 1:
                    frame_info[0] = "┌" + frame_info[0]
                    frame_info[-1] = "└" + frame_info[-1]
                    for ii, info in enumerate(frame_info[1:-1], 1):
                        frame_info[ii] = "├" + info
                record.frame_info = "\n".join(frame_info)
        return True


_filter = _FrameFilter()
logger.addFilter(_filter)


# Provide help for static type checkers:
# https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
_FuncT = TypeVar("_FuncT", bound=Callable[..., Any])


def verbose(function: _FuncT) -> _FuncT:
    """Verbose decorator to allow functions to override log-level.

    Parameters
    ----------
    function : callable
        Function to be decorated by setting the verbosity level.

    Returns
    -------
    dec : callable
        The decorated function.

    See Also
    --------
    set_log_level
    set_config

    Notes
    -----
    This decorator is used to set the verbose level during a function or method
    call, such as :func:`mne.compute_covariance`. The `verbose` keyword
    argument can be 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL', True (an
    alias for 'INFO'), or False (an alias for 'WARNING'). To set the global
    verbosity level for all functions, use :func:`mne.set_log_level`.

    This function also serves as a docstring filler.

    Examples
    --------
    You can use the ``verbose`` argument to set the verbose level on the fly::

        >>> import mne
        >>> cov = mne.compute_raw_covariance(raw, verbose='WARNING')  # doctest: +SKIP
        >>> cov = mne.compute_raw_covariance(raw, verbose='INFO')  # doctest: +SKIP
        Using up to 49 segments
        Number of samples used : 5880
        [done]
    """  # noqa: E501
    # See https://decorator.readthedocs.io/en/latest/tests.documentation.html
    # #dealing-with-third-party-decorators
    try:
        fill_doc(function)
    except TypeError:  # nothing to add
        pass

    # Anything using verbose should have `verbose=None` in the signature.
    # This code path will raise an error if this is not the case.
    body = """\
def %(name)s(%(signature)s):\n
    try:
        do_level_change = verbose is not None
    except (NameError, UnboundLocalError):
        raise RuntimeError('Function/method %%s does not accept verbose '
                           'parameter' %% (_function_,)) from None
    if do_level_change:
        with _use_log_level_(verbose):
            return _function_(%(shortsignature)s)
    else:
        return _function_(%(shortsignature)s)"""
    evaldict = dict(_use_log_level_=use_log_level, _function_=function)
    fm = FunctionMaker(function)
    attrs = dict(
        __wrapped__=function,
        __qualname__=function.__qualname__,
        __globals__=function.__globals__,
    )
    return fm.make(body, evaldict, addsource=True, **attrs)
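
# Usage sketch (illustrative, not part of this module): any callable that
# accepts `verbose=None` can opt in to per-call log-level control.
#
#     @verbose
#     def my_step(data, verbose=None):   # hypothetical function
#         logger.info("processing...")   # silenced when verbose='WARNING'
#         return data
#
#     my_step([1, 2, 3], verbose="WARNING")
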
@fill_doc
class use_log_level:
    """Context manager for logging level.

    Parameters
    ----------
    %(verbose)s
    %(add_frames)s

    See Also
    --------
    mne.verbose

    Notes
    -----
    See the :ref:`logging documentation <tut-logging>` for details.

    Examples
    --------
    >>> from mne import use_log_level
    >>> from mne.utils import logger
    >>> with use_log_level(False):
    ...     # Most MNE logger messages are "info" level, False makes them not
    ...     # print:
    ...     logger.info('This message will not be printed')
    >>> with use_log_level(True):
    ...     # Using verbose=True in functions, methods, or this context manager
    ...     # will ensure they are printed
    ...     logger.info('This message will be printed!')
    This message will be printed!
    """

    def __init__(self, verbose=None, *, add_frames=None):
        self._level = verbose
        self._add_frames = add_frames
        self._old_frames = _filter.add_frames

    def __enter__(self):  # noqa: D105
        self._old_level = set_log_level(
            self._level, return_old_level=True, add_frames=self._add_frames
        )

    def __exit__(self, *args):  # noqa: D105
        add_frames = self._old_frames if self._add_frames is not None else None
        set_log_level(self._old_level, add_frames=add_frames)


_LOGGING_TYPES = dict(
    DEBUG=logging.DEBUG,
    INFO=logging.INFO,
    WARNING=logging.WARNING,
    ERROR=logging.ERROR,
    CRITICAL=logging.CRITICAL,
)


@fill_doc
def set_log_level(verbose=None, return_old_level=False, add_frames=None):
    """Set the logging level.

    Parameters
    ----------
    verbose : bool, str, int, or None
        The verbosity of messages to print. If a str, it can be either DEBUG,
        INFO, WARNING, ERROR, or CRITICAL. Note that these are for
        convenience and are equivalent to passing in logging.DEBUG, etc.
        For bool, True is the same as 'INFO', False is the same as 'WARNING'.
        If None, the environment variable MNE_LOGGING_LEVEL is read, and if
        it doesn't exist, defaults to INFO.
    return_old_level : bool
        If True, return the old verbosity level.
    %(add_frames)s

    Returns
    -------
    old_level : int
        The old level. Only returned if ``return_old_level`` is True.
    """
    old_verbose = logger.level
    verbose = _parse_verbose(verbose)

    if verbose != old_verbose:
        logger.setLevel(verbose)
    if add_frames is not None:
        _filter.add_frames = int(add_frames)
        fmt = "%(frame_info)s " if add_frames else ""
        fmt += "%(message)s"
        fmt = logging.Formatter(fmt)
        for handler in logger.handlers:
            handler.setFormatter(fmt)
    return old_verbose if return_old_level else None


def _parse_verbose(verbose):
    from .check import _check_option, _validate_type
    from .config import get_config

    _validate_type(verbose, (bool, str, int, None), "verbose")
    if verbose is None:
        verbose = get_config("MNE_LOGGING_LEVEL", "INFO")
    elif isinstance(verbose, bool):
        if verbose is True:
            verbose = "INFO"
        else:
            verbose = "WARNING"
    if isinstance(verbose, str):
        verbose = verbose.upper()
        _check_option("verbose", verbose, _LOGGING_TYPES, "(when a string)")
        verbose = _LOGGING_TYPES[verbose]

    return verbose


def set_log_file(fname=None, output_format="%(message)s", overwrite=None):
    """Set the log to print to a file.

    Parameters
    ----------
    fname : path-like | None
        Filename of the log to print to. If None, stdout is used.
        To suppress log outputs, use set_log_level('WARNING').
    output_format : str
        Format of the output messages. See the following for examples:

            https://docs.python.org/dev/howto/logging.html

        e.g., "%(asctime)s - %(levelname)s - %(message)s".
    overwrite : bool | None
        Overwrite the log file (if it exists). Otherwise, statements
        will be appended to the log (default). None is the same as False,
        but additionally raises a warning to notify the user that log
        entries will be appended.
    """
    _remove_close_handlers(logger)
    if fname is not None:
        if op.isfile(fname) and overwrite is None:
            # Don't use warn() here because we just want to
            # emit a warnings.warn here (not logger.warn)
            warnings.warn(
                "Log entries will be appended to the file. Use "
                "overwrite=False to avoid this message in the "
                "future.",
                RuntimeWarning,
                stacklevel=2,
            )
            overwrite = False
        mode = "w" if overwrite else "a"
        lh = logging.FileHandler(fname, mode=mode)
    else:
        """we should just be able to do:
        lh = logging.StreamHandler(sys.stdout)
        but because doctests uses some magic on stdout, we have to do this:
        """
        lh = logging.StreamHandler(WrapStdOut())

    lh.setFormatter(logging.Formatter(output_format))
    # actually add the stream handler
    logger.addHandler(lh)


def _remove_close_handlers(logger):
    for h in list(logger.handlers):
        # only remove our handlers (get along nicely with nose)
        if isinstance(h, logging.FileHandler | logging.StreamHandler):
            if isinstance(h, logging.FileHandler):
                h.close()
            logger.removeHandler(h)


class ClosingStringIO(StringIO):
    """StringIO that closes after getvalue()."""

    def getvalue(self, close=True):
        """Get the value."""
        out = super().getvalue()
        if close:
            self.close()
        return out


class catch_logging:
    """Store logging.

    This will remove all other logging handlers, and return the handler to
    stdout when complete.
    """

    def __init__(self, verbose=None):
        self.verbose = verbose

    def __enter__(self):  # noqa: D105
        if self.verbose is not None:
            self._ctx = use_log_level(self.verbose)
        else:
            self._ctx = contextlib.nullcontext()
        self._data = ClosingStringIO()
        self._lh = logging.StreamHandler(self._data)
        self._lh.setFormatter(logging.Formatter("%(message)s"))
        self._lh._mne_file_like = True  # monkey patch for warn() use
        _remove_close_handlers(logger)
        logger.addHandler(self._lh)
        self._ctx.__enter__()
        return self._data

    def __exit__(self, *args):  # noqa: D105
        self._ctx.__exit__(*args)
        logger.removeHandler(self._lh)
        set_log_file(None)
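
# Usage sketch (illustrative, not part of this module): capture what MNE
# logs during a call, then assert on it.
#
#     with catch_logging(verbose="info") as log:
#         logger.info("fitting model")
#     assert "fitting model" in log.getvalue()  # ClosingStringIO closes here
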
@contextlib.contextmanager
def _record_warnings():
    # this is a helper that mostly acts like pytest.warns(None) did before
    # pytest 7
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        yield w


class WrapStdOut:
    """Dynamically wrap to sys.stdout.

    This makes packages that monkey-patch sys.stdout (e.g., doctest,
    sphinx-gallery) work properly.
    """

    def __getattr__(self, name):  # noqa: D105
        # Even more ridiculous than this class, this must be sys.stdout (not
        # just stdout) in order for this to work (tested on OSX and Linux)
        if hasattr(sys.stdout, name):
            return getattr(sys.stdout, name)
        else:
            raise AttributeError(f"'file' object has no attribute '{name}'")


_verbose_dec_re = re.compile("^<decorator-gen-[0-9]+>$")


def warn(message, category=RuntimeWarning, module="mne", ignore_namespaces=("mne",)):
    """Emit a warning with trace outside the mne namespace.

    This function takes arguments like warnings.warn, and sends messages
    using both ``warnings.warn`` and ``logger.warn``. Warnings can be
    generated deep within nested function calls. In order to provide a
    more helpful warning, this function traverses the stack until it
    reaches a frame outside the ``mne`` namespace that caused the error.

    Parameters
    ----------
    message : str
        Warning message.
    category : instance of Warning
        The warning class. Defaults to ``RuntimeWarning``.
    module : str
        The name of the module emitting the warning.
    ignore_namespaces : list of str
        Namespaces to ignore when traversing the stack.

        .. versionadded:: 0.24
    """
    root_dirs = [importlib.import_module(ns) for ns in ignore_namespaces]
    root_dirs = [op.dirname(ns.__file__) for ns in root_dirs]
    frame = None
    if logger.level <= logging.WARNING:
        frame = inspect.currentframe()
        while frame:
            fname = frame.f_code.co_filename
            lineno = frame.f_lineno
            # in verbose dec
            if not _verbose_dec_re.search(fname):
                # treat tests as scripts
                # and don't capture unittest/case.py (assert_raises)
                if (
                    not (
                        any(fname.startswith(rd) for rd in root_dirs)
                        or ("unittest" in fname and "case" in fname)
                    )
                    or op.basename(op.dirname(fname)) == "tests"
                ):
                    break
            frame = frame.f_back
        del frame
    # We need to use this instead of warn(message, category, stacklevel)
    # because we move out of the MNE stack, so warnings won't properly
    # recognize the module name (and our warnings.simplefilter will fail)
    warnings.warn_explicit(
        message,
        category,
        fname,
        lineno,
        module,
        globals().get("__warningregistry__", {}),
    )
    # To avoid a duplicate warning print, we only emit the logger.warning if
    # one of the handlers is a FileHandler. See gh-5592
    # But it's also nice to be able to do:
    #     with mne.utils.use_log_level('warning', add_frames=3):
    # so also check our add_frames attribute.
    if (
        any(
            isinstance(h, logging.FileHandler) or getattr(h, "_mne_file_like", False)
            for h in logger.handlers
        )
        or _filter.add_frames
    ):
        logger.warning(message)


def _get_call_line():
    """Get the call line from within a function."""
    frame = inspect.currentframe().f_back.f_back
    if _verbose_dec_re.search(frame.f_code.co_filename):
        frame = frame.f_back
    context = inspect.getframeinfo(frame).code_context
    context = "unknown" if context is None else context[0].strip()
    return context


def filter_out_warnings(warn_record, category=None, match=None):
    r"""Remove particular records from ``warn_record``.

    This helper takes a list of :class:`warnings.WarningMessage` objects,
    and removes those matching category and/or text.

    Parameters
    ----------
    category : WarningMessage type | None
        Class of the message to filter out.
    match : str | None
        Text or regex that matches the error message to filter out.
    """
    regexp = re.compile(".*" if match is None else match)
    is_category = [
        w.category == category if category is not None else True
        for w in warn_record._list
    ]
    is_match = [regexp.match(w.message.args[0]) is not None for w in warn_record._list]
    ind = [ind for ind, (c, m) in enumerate(zip(is_category, is_match)) if c and m]

    for i in reversed(ind):
        warn_record._list.pop(i)


@contextlib.contextmanager
def wrapped_stdout(indent="", cull_newlines=False):
    """Wrap stdout writes to logger.info, with an optional indent prefix.

    Parameters
    ----------
    indent : str
        The indentation to add.
    cull_newlines : bool
        If True, cull any new/blank lines at the end.
    """
    orig_stdout = sys.stdout
    my_out = ClosingStringIO()
    sys.stdout = my_out
    try:
        yield
    finally:
        sys.stdout = orig_stdout
        pending_newlines = 0
        for line in my_out.getvalue().split("\n"):
            if not line.strip() and cull_newlines:
                pending_newlines += 1
                continue
            for _ in range(pending_newlines):
                logger.info("\n")
            logger.info(indent + line)
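
# Usage sketch (illustrative, not part of this module): reroute print()
# output from third-party code into logger.info with a prefix.
#
#     with wrapped_stdout(indent="    ", cull_newlines=True):
#         print("50% done")  # emitted as logger.info("    50% done")
#         print()            # trailing blank line is culled
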
def _frame_info(n):
    frame = inspect.currentframe()
    try:
        frame = frame.f_back
        infos = list()
        for _ in range(n):
            try:
                name = frame.f_globals["__name__"]
            except KeyError:  # in our verbose dec
                pass
            else:
                # removeprefix, not lstrip: lstrip("mne.") would strip a
                # character set and mangle names like "numpy"
                infos.append(f'{name.removeprefix("mne.")}:{frame.f_lineno}')
            frame = frame.f_back
            if frame is None:
                break
        return infos
    except Exception:
        return ["unknown"]
    finally:
        del frame


def _verbose_safe_false(*, level="warning"):
    lev = _LOGGING_TYPES[level.upper()]
    return lev if logger.level <= lev else None

mne/utils/_testing.py (new file)
@@ -0,0 +1,359 @@
"""Testing functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import inspect
import os
import sys
import tempfile
import traceback
from functools import wraps
from shutil import rmtree
from unittest import SkipTest

import numpy as np
from numpy.testing import assert_allclose, assert_array_equal
from scipy import linalg

from ._logging import ClosingStringIO, warn
from .check import check_version
from .misc import run_subprocess
from .numerics import object_diff


def _explain_exception(start=-1, stop=None, prefix="> "):
    """Explain an exception."""
    # start=-1 means "only the most recent caller"
    etype, value, tb = sys.exc_info()
    string = traceback.format_list(traceback.extract_tb(tb)[start:stop])
    string = "".join(string).split("\n") + traceback.format_exception_only(etype, value)
    string = ":\n" + prefix + ("\n" + prefix).join(string)
    return string


class _TempDir(str):
    """Create and auto-destroy temp dir.

    This is designed to be used with testing modules. Instances should be
    defined inside test functions. Instances defined at module level can not
    guarantee proper destruction of the temporary directory.

    When used at module level, the current use of the __del__() method for
    cleanup can fail because the rmtree function may be cleaned up before this
    object (an alternative could be using the atexit module instead).
    """

    def __new__(self):  # noqa: D105
        new = str.__new__(self, tempfile.mkdtemp(prefix="tmp_mne_tempdir_"))
        return new

    def __init__(self):
        self._path = self.__str__()

    def __del__(self):  # noqa: D105
        rmtree(self._path, ignore_errors=True)


def requires_mne(func):
    """Decorate a function as requiring MNE."""
    return requires_mne_mark()(func)


def requires_mne_mark():
    """Mark pytest tests that require MNE-C."""
    import pytest

    return pytest.mark.skipif(not has_mne_c(), reason="Requires MNE-C")


def requires_openmeeg_mark():
    """Mark pytest tests that require OpenMEEG."""
    import pytest

    return pytest.mark.skipif(
        not check_version("openmeeg", "2.5.6"), reason="Requires OpenMEEG >= 2.5.6"
    )


def requires_freesurfer(arg):
    """Require Freesurfer."""
    import pytest

    reason = "Requires Freesurfer"
    if isinstance(arg, str):
        # Calling as @requires_freesurfer('progname'): return decorator
        # after checking for progname existence
        reason += f" command: {arg}"
        try:
            run_subprocess([arg, "--version"])
        except Exception:
            skip = True
        else:
            skip = False
        return pytest.mark.skipif(skip, reason=reason)
    else:
        # Calling directly as @requires_freesurfer: return decorated function
        # and just check env var existence
        return pytest.mark.skipif(not has_freesurfer(), reason="Requires Freesurfer")(
            arg
        )


def requires_good_network(func):
    import pytest

    return pytest.mark.skipif(
        int(os.environ.get("MNE_SKIP_NETWORK_TESTS", 0)),
        reason="MNE_SKIP_NETWORK_TESTS is set",
    )(func)


def run_command_if_main():
    """Run a given command if it's __main__."""
    local_vars = inspect.currentframe().f_back.f_locals
    if local_vars.get("__name__", "") == "__main__":
        local_vars["run"]()


class ArgvSetter:
    """Temporarily set sys.argv."""

    def __init__(self, args=(), disable_stdout=True, disable_stderr=True):
        self.argv = list(("python",) + args)
        self.stdout = ClosingStringIO() if disable_stdout else sys.stdout
        self.stderr = ClosingStringIO() if disable_stderr else sys.stderr

    def __enter__(self):  # noqa: D105
        self.orig_argv = sys.argv
        sys.argv = self.argv
        self.orig_stdout = sys.stdout
        sys.stdout = self.stdout
        self.orig_stderr = sys.stderr
        sys.stderr = self.stderr
        return self

    def __exit__(self, *args):  # noqa: D105
        sys.argv = self.orig_argv
        sys.stdout = self.orig_stdout
        sys.stderr = self.orig_stderr
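
# Usage sketch (illustrative, not part of this module): exercise a CLI
# entry point with fake arguments and captured output. The file name is
# arbitrary.
#
#     with ArgvSetter(("--raw", "test_raw.fif")) as out:
#         assert sys.argv == ["python", "--raw", "test_raw.fif"]
#         print("hello")                        # goes to out.stdout
#     assert "hello" in out.stdout.getvalue()
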
def has_mne_c():
    """Check for MNE-C."""
    return "MNE_ROOT" in os.environ


def has_freesurfer():
    """Check for Freesurfer."""
    return "FREESURFER_HOME" in os.environ


def buggy_mkl_svd(function):
    """Decorate tests that make calls to SVD and intermittently fail."""

    @wraps(function)
    def dec(*args, **kwargs):
        try:
            return function(*args, **kwargs)
        except np.linalg.LinAlgError as exp:
            if "SVD did not converge" in str(exp):
                msg = "Intel MKL SVD convergence error detected, skipping test"
                warn(msg)
                raise SkipTest(msg)
            raise

    return dec


def assert_and_remove_boundary_annot(annotations, n=1):
    """Assert that there are boundary annotations and remove them."""
    from ..io import BaseRaw

    if isinstance(annotations, BaseRaw):  # allow either input
        annotations = annotations.annotations
    for key in ("EDGE", "BAD"):
        idx = np.where(annotations.description == f"{key} boundary")[0]
        assert len(idx) == n
        annotations.delete(idx)


def assert_object_equal(a, b, *, err_msg="Object mismatch"):
    """Assert two objects are equal."""
    d = object_diff(a, b)
    assert d == "", f"{err_msg}\n{d}"


def _raw_annot(meas_date, orig_time):
    from .._fiff.meas_info import create_info
    from ..annotations import Annotations, _handle_meas_date
    from ..io import RawArray

    info = create_info(ch_names=10, sfreq=10.0)
    raw = RawArray(data=np.empty((10, 10)), info=info, first_samp=10)
    if meas_date is not None:
        meas_date = _handle_meas_date(meas_date)
    with raw.info._unlock(check_after=True):
        raw.info["meas_date"] = meas_date
    annot = Annotations([0.5], [0.2], ["dummy"], orig_time)
    raw.set_annotations(annotations=annot)
    return raw


def _get_data(x, ch_idx):
    """Get the (n_ch, n_times) data array."""
    from ..evoked import Evoked
    from ..io import BaseRaw

    if isinstance(x, BaseRaw):
        return x[ch_idx][0]
    elif isinstance(x, Evoked):
        return x.data[ch_idx]


def _check_snr(actual, desired, picks, min_tol, med_tol, msg, kind="MEG"):
    """Check the SNR of a set of channels."""
    actual_data = _get_data(actual, picks)
    desired_data = _get_data(desired, picks)
    bench_rms = np.sqrt(np.mean(desired_data * desired_data, axis=1))
    error = actual_data - desired_data
    error_rms = np.sqrt(np.mean(error * error, axis=1))
    np.clip(error_rms, 1e-60, np.inf, out=error_rms)  # avoid division by zero
    snrs = bench_rms / error_rms
    # min tol
    snr = snrs.min()
    bad_count = (snrs < min_tol).sum()
    msg = f" ({msg})" if msg != "" else msg
    assert bad_count == 0, (
        f"SNR (worst {snr:0.2f}) < {min_tol:0.2f} "
        f"for {bad_count}/{len(picks)} channels{msg}"
    )
    # median tol
    snr = np.median(snrs)
    assert snr >= med_tol, f"{kind} SNR median {snr:0.2f} < {med_tol:0.2f}{msg}"


def assert_meg_snr(
    actual, desired, min_tol, med_tol=500.0, chpi_med_tol=500.0, msg=None
):
    """Assert channel SNR of a certain level.

    Mostly useful for operations like Maxwell filtering that modify
    MEG channels while leaving EEG and others intact.
    """
    from .._fiff.pick import pick_types

    # picks for actual and desired must come from their own instances
    picks = pick_types(actual.info, meg=True, exclude=[])
    picks_desired = pick_types(desired.info, meg=True, exclude=[])
    assert_array_equal(picks, picks_desired, err_msg="MEG pick mismatch")
    chpis = pick_types(actual.info, meg=False, chpi=True, exclude=[])
    chpis_desired = pick_types(desired.info, meg=False, chpi=True, exclude=[])
    if chpi_med_tol is not None:
        assert_array_equal(chpis, chpis_desired, err_msg="cHPI pick mismatch")
    others = np.setdiff1d(
        np.arange(len(actual.ch_names)), np.concatenate([picks, chpis])
    )
    others_desired = np.setdiff1d(
        np.arange(len(desired.ch_names)), np.concatenate([picks_desired, chpis_desired])
    )
    assert_array_equal(others, others_desired, err_msg="Other pick mismatch")
    if len(others) > 0:  # if non-MEG channels present
        assert_allclose(
            _get_data(actual, others),
            _get_data(desired, others),
            atol=1e-11,
            rtol=1e-5,
            err_msg="non-MEG channel mismatch",
        )
    _check_snr(actual, desired, picks, min_tol, med_tol, msg, kind="MEG")
    if chpi_med_tol is not None and len(chpis) > 0:
        _check_snr(actual, desired, chpis, 0.0, chpi_med_tol, msg, kind="cHPI")


def assert_snr(actual, desired, tol):
    """Assert actual and desired arrays are within some SNR tolerance."""
    with np.errstate(divide="ignore"):  # allow infinite
        snr = linalg.norm(desired, ord="fro") / linalg.norm(desired - actual, ord="fro")
    assert snr >= tol, f"{snr} < {tol}"
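
# Worked example (illustrative): the SNR here is the ratio of Frobenius
# norms ||desired|| / ||desired - actual||.
#
#     desired = np.eye(2)        # ||desired||_F = sqrt(2)
#     actual = desired + 0.01    # ||desired - actual||_F = 0.02
#     assert_snr(actual, desired, tol=50)  # SNR ~= 70.7, so this passes
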
def assert_stcs_equal(stc1, stc2):
    """Check that two STC are equal."""
    assert_allclose(stc1.times, stc2.times)
    assert_allclose(stc1.data, stc2.data)
    assert_array_equal(stc1.vertices[0], stc2.vertices[0])
    assert_array_equal(stc1.vertices[1], stc2.vertices[1])
    assert_allclose(stc1.tmin, stc2.tmin)
    assert_allclose(stc1.tstep, stc2.tstep)


def _dig_sort_key(dig):
    """Sort dig keys."""
    return (dig["kind"], dig["ident"])


def assert_dig_allclose(info_py, info_bin, limit=None):
    """Assert dig allclose."""
    from .._fiff.constants import FIFF
    from .._fiff.meas_info import Info
    from ..bem import fit_sphere_to_headshape
    from ..channels.montage import DigMontage

    # test dig positions
    dig_py, dig_bin = info_py, info_bin
    if isinstance(dig_py, Info):
        assert isinstance(dig_bin, Info)
        dig_py, dig_bin = dig_py["dig"], dig_bin["dig"]
    else:
        assert isinstance(dig_bin, DigMontage)
        assert isinstance(dig_py, DigMontage)
        dig_py, dig_bin = dig_py.dig, dig_bin.dig
        info_py = info_bin = None
    assert isinstance(dig_py, list)
    assert isinstance(dig_bin, list)
    dig_py = sorted(dig_py, key=_dig_sort_key)
    dig_bin = sorted(dig_bin, key=_dig_sort_key)
    assert len(dig_py) == len(dig_bin)
    for ii, (d_py, d_bin) in enumerate(zip(dig_py[:limit], dig_bin[:limit])):
        for key in ("ident", "kind", "coord_frame"):
            assert d_py[key] == d_bin[key], key
        assert_allclose(
            d_py["r"],
            d_bin["r"],
            rtol=1e-5,
            atol=1e-5,
            err_msg=f"Failure on {ii}:\n{d_py['r']}\n{d_bin['r']}",
        )
    if any(d["kind"] == FIFF.FIFFV_POINT_EXTRA for d in dig_py) and info_py is not None:
        r_bin, o_head_bin, o_dev_bin = fit_sphere_to_headshape(
            info_bin, units="m", verbose="error"
        )
        r_py, o_head_py, o_dev_py = fit_sphere_to_headshape(
            info_py, units="m", verbose="error"
        )
        assert_allclose(r_py, r_bin, atol=1e-6)
        assert_allclose(o_dev_py, o_dev_bin, rtol=1e-5, atol=1e-6)
        assert_allclose(o_head_py, o_head_bin, rtol=1e-5, atol=1e-6)


def _click_ch_name(fig, ch_index=0, button=1):
    """Click on a channel name in a raw/epochs/ICA browse-style plot."""
    from ..viz.utils import _fake_click

    fig.canvas.draw()
    text = fig.mne.ax_main.get_yticklabels()[ch_index]
    bbox = text.get_window_extent()
    x = bbox.intervalx.mean()
    y = bbox.intervaly.mean()
    _fake_click(fig, fig.mne.ax_main, (x, y), xform="pix", button=button)


def _get_suptitle(fig):
    """Get fig suptitle (shim for matplotlib < 3.8.0)."""
    # TODO: obsolete when minimum MPL version is 3.8
    if check_version("matplotlib", "3.8"):
        return fig.get_suptitle()
    else:
        # unreliable hack; should work in most tests as we rarely use `sup_{x,y}label`
        return fig.texts[0].get_text()

mne/utils/_typing.py (new file)
@@ -0,0 +1,14 @@
"""Shared objects used for type annotations."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import sys

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing import TypeVar

    Self = TypeVar("Self")

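Usage sketch for the ``Self`` shim (the class name is hypothetical):

    from mne.utils._typing import Self

    class Builder:
        def set_option(self, value: int) -> Self:
            # typing.Self on Python 3.11+; the TypeVar fallback gives type
            # checkers the same "returns its own type" hint on older versions.
            self.value = value
            return self
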
mne/utils/check.py (new file, 1292 lines)
(File diff suppressed because it is too large.)
mne/utils/config.py (new file)
@@ -0,0 +1,917 @@
"""The config functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import atexit
import json
import multiprocessing
import os
import os.path as op
import platform
import shutil
import subprocess
import sys
import tempfile
from functools import lru_cache, partial
from importlib import import_module
from pathlib import Path
from urllib.error import URLError
from urllib.request import urlopen

from packaging.version import parse

from ._logging import logger, warn
from .check import _check_fname, _check_option, _check_qt_version, _validate_type
from .docs import fill_doc
from .misc import _pl

_temp_home_dir = None


class UnknownPlatformError(Exception):
    """Exception raised for unknown platforms."""


def set_cache_dir(cache_dir):
    """Set the directory to be used for temporary file storage.

    This directory is used by joblib to store memmapped arrays,
    which reduces memory requirements and speeds up parallel
    computation.

    Parameters
    ----------
    cache_dir : str or None
        Directory to use for temporary file storage. None disables
        temporary file storage.
    """
    if cache_dir is not None and not op.exists(cache_dir):
        raise OSError(f"Directory {cache_dir} does not exist")

    set_config("MNE_CACHE_DIR", cache_dir, set_env=False)


def set_memmap_min_size(memmap_min_size):
    """Set the minimum size for memmapping of arrays for parallel processing.

    Parameters
    ----------
    memmap_min_size : str or None
        Threshold on the minimum size of arrays that triggers automated memory
        mapping for parallel processing, e.g., '1M' for 1 megabyte.
        Use None to disable memmapping of large arrays.
    """
    _validate_type(memmap_min_size, (str, None), "memmap_min_size")
    if memmap_min_size is not None:
        if memmap_min_size[-1] not in ["K", "M", "G"]:
            raise ValueError(
                "The size has to be given in kilo-, mega-, or "
                f"gigabytes, e.g., 100K, 500M, 1G, got {repr(memmap_min_size)}"
            )

    set_config("MNE_MEMMAP_MIN_SIZE", memmap_min_size, set_env=False)


# List the known configuration values
_known_config_types = {
    "MNE_3D_OPTION_ANTIALIAS": (
        "bool, whether to use full-screen antialiasing in 3D plots"
    ),
    "MNE_3D_OPTION_DEPTH_PEELING": "bool, whether to use depth peeling in 3D plots",
    "MNE_3D_OPTION_MULTI_SAMPLES": (
        "int, number of samples to use for full-screen antialiasing"
    ),
    "MNE_3D_OPTION_SMOOTH_SHADING": ("bool, whether to use smooth shading in 3D plots"),
    "MNE_3D_OPTION_THEME": ("str, the color theme (light or dark) to use for 3D plots"),
    "MNE_BROWSE_RAW_SIZE": (
        "tuple, width and height of the raw browser window (in inches)"
    ),
    "MNE_BROWSER_BACKEND": (
        "str, the backend to use for the MNE Browse Raw window (qt or matplotlib)"
    ),
    "MNE_BROWSER_OVERVIEW_MODE": (
        "str, the overview mode to use in the MNE Browse Raw window "
        "(see mne.viz.plot_raw for valid options)"
    ),
    "MNE_BROWSER_PRECOMPUTE": (
        "bool, whether to precompute raw data in the MNE Browse Raw window"
    ),
    "MNE_BROWSER_THEME": "str, the color theme (light or dark) to use for the browser",
    "MNE_BROWSER_USE_OPENGL": (
        "bool, whether to use OpenGL for rendering in the MNE Browse Raw window"
    ),
    "MNE_CACHE_DIR": "str, path to the cache directory for parallel execution",
    "MNE_COREG_ADVANCED_RENDERING": (
        "bool, whether to use advanced OpenGL rendering in mne coreg"
    ),
    "MNE_COREG_COPY_ANNOT": (
        "bool, whether to copy the annotation files during warping"
    ),
    "MNE_COREG_FULLSCREEN": "bool, whether to use full-screen mode in mne coreg",
    "MNE_COREG_GUESS_MRI_SUBJECT": (
        "bool, whether to guess the MRI subject in mne coreg"
    ),
    "MNE_COREG_HEAD_HIGH_RES": (
        "bool, whether to use high-res head surface in mne coreg"
    ),
    "MNE_COREG_HEAD_OPACITY": ("bool, the head surface opacity to use in mne coreg"),
    "MNE_COREG_HEAD_INSIDE": (
        "bool, whether to add an opaque inner scalp head surface to help "
        "occlude points behind the head in mne coreg"
    ),
    "MNE_COREG_INTERACTION": (
        "str, interaction style in mne coreg (trackball or terrain)"
    ),
    "MNE_COREG_MARK_INSIDE": (
        "bool, whether to mark points inside the head surface in mne coreg"
    ),
    "MNE_COREG_PREPARE_BEM": (
        "bool, whether to prepare the BEM solution after warping in mne coreg"
    ),
    "MNE_COREG_ORIENT_TO_SURFACE": (
        "bool, whether to orient the digitization markers to the head surface "
        "in mne coreg"
    ),
    "MNE_COREG_SCALE_LABELS": (
        "bool, whether to scale the MRI labels during warping in mne coreg"
    ),
    "MNE_COREG_SCALE_BY_DISTANCE": (
        "bool, whether to scale the digitization markers by their distance from "
        "the scalp in mne coreg"
    ),
    "MNE_COREG_SCENE_SCALE": (
        "float, the scale factor of the 3D scene in mne coreg (default 0.16)"
    ),
    "MNE_COREG_WINDOW_HEIGHT": "int, window height for mne coreg",
    "MNE_COREG_WINDOW_WIDTH": "int, window width for mne coreg",
    "MNE_COREG_SUBJECTS_DIR": "str, path to the subjects directory for mne coreg",
    "MNE_CUDA_DEVICE": "int, CUDA device to use for GPU processing",
    "MNE_DATA": "str, default data directory",
    "MNE_DATASETS_BRAINSTORM_PATH": "str, path for brainstorm data",
    "MNE_DATASETS_EEGBCI_PATH": "str, path for EEGBCI data",
    "MNE_DATASETS_EPILEPSY_ECOG_PATH": "str, path for epilepsy_ecog data",
    "MNE_DATASETS_HF_SEF_PATH": "str, path for HF_SEF data",
    "MNE_DATASETS_MEGSIM_PATH": "str, path for MEGSIM data",
    "MNE_DATASETS_MISC_PATH": "str, path for misc data",
    "MNE_DATASETS_MTRF_PATH": "str, path for MTRF data",
    "MNE_DATASETS_SAMPLE_PATH": "str, path for sample data",
    "MNE_DATASETS_SOMATO_PATH": "str, path for somato data",
    "MNE_DATASETS_MULTIMODAL_PATH": "str, path for multimodal data",
    "MNE_DATASETS_FNIRS_MOTOR_PATH": "str, path for fnirs_motor data",
    "MNE_DATASETS_OPM_PATH": "str, path for OPM data",
    "MNE_DATASETS_SPM_FACE_DATASETS_TESTS": "str, path for spm_face data",
    "MNE_DATASETS_SPM_FACE_PATH": "str, path for spm_face data",
    "MNE_DATASETS_TESTING_PATH": "str, path for testing data",
    "MNE_DATASETS_VISUAL_92_CATEGORIES_PATH": "str, path for visual_92_categories data",
    "MNE_DATASETS_KILOWORD_PATH": "str, path for kiloword data",
    "MNE_DATASETS_FIELDTRIP_CMC_PATH": "str, path for fieldtrip_cmc data",
    "MNE_DATASETS_PHANTOM_KIT_PATH": "str, path for phantom_kit data",
    "MNE_DATASETS_PHANTOM_4DBTI_PATH": "str, path for phantom_4dbti data",
    "MNE_DATASETS_PHANTOM_KERNEL_PATH": "str, path for phantom_kernel data",
    "MNE_DATASETS_LIMO_PATH": "str, path for limo data",
    "MNE_DATASETS_REFMEG_NOISE_PATH": "str, path for refmeg_noise data",
    "MNE_DATASETS_SSVEP_PATH": "str, path for ssvep data",
    "MNE_DATASETS_ERP_CORE_PATH": "str, path for erp_core data",
    "MNE_FORCE_SERIAL": "bool, force serial rather than parallel execution",
    "MNE_LOGGING_LEVEL": (
        "str or int, controls the level of verbosity of any function "
        "decorated with @verbose. See "
        "https://mne.tools/stable/auto_tutorials/intro/50_configure_mne.html#logging"
    ),
    "MNE_MEMMAP_MIN_SIZE": (
        "str, threshold on the minimum size of arrays passed to the workers that "
        "triggers automated memory mapping, e.g., 1M or 0.5G"
    ),
    "MNE_REPR_HTML": (
        "bool, represent some of our objects with rich HTML in a notebook "
        "environment"
    ),
    "MNE_SKIP_NETWORK_TESTS": (
        "bool, used in a test decorator (@requires_good_network) to skip "
        "tests that include large downloads"
    ),
    "MNE_SKIP_TESTING_DATASET_TESTS": (
        "bool, used in test decorators (@requires_spm_data, "
        "@requires_bstraw_data) to skip tests that require specific datasets"
    ),
    "MNE_STIM_CHANNEL": "string, the default channel name for mne.find_events",
    "MNE_TQDM": (
        'str, either "tqdm", "tqdm.auto", or "off". Controls presence/absence '
        "of progress bars"
    ),
    "MNE_USE_CUDA": "bool, use GPU for filtering/resampling",
    "MNE_USE_NUMBA": (
        "bool, use Numba just-in-time compiler for some of our intensive "
        "computations"
    ),
    "SUBJECTS_DIR": "path-like, directory of freesurfer MRI files for each subject",
}

# These allow for partial matches, e.g. 'MNE_STIM_CHANNEL_1' is an okay key
_known_config_wildcards = (
    "MNE_STIM_CHANNEL",  # can have multiple stim channels
    "MNE_DATASETS_FNIRS",  # mne-nirs
    "MNE_NIRS",  # mne-nirs
    "MNE_KIT2FIFF",  # mne-kit-gui
    "MNE_ICALABEL",  # mne-icalabel
    "MNE_LSL",  # mne-lsl
)


def _load_config(config_path, raise_error=False):
    """Safely load a config file."""
    with open(config_path) as fid:
        try:
            config = json.load(fid)
        except ValueError:
            # No JSON object could be decoded --> corrupt file?
            msg = (
                f"The MNE-Python config file ({config_path}) is not a valid JSON "
                "file and might be corrupted"
            )
            if raise_error:
                raise RuntimeError(msg)
            warn(msg)
            config = dict()
    return config


def get_config_path(home_dir=None):
    r"""Get path to standard mne-python config file.

    Parameters
    ----------
    home_dir : str | None
        The folder that contains the .mne config folder.
        If None, it is found automatically.

    Returns
    -------
    config_path : str
        The path to the mne-python configuration file. On Windows, this
        will be '%USERPROFILE%\.mne\mne-python.json'. On every other
        system, this will be ~/.mne/mne-python.json.
    """
    val = op.join(_get_extra_data_path(home_dir=home_dir), "mne-python.json")
    return val


def get_config(key=None, default=None, raise_error=False, home_dir=None, use_env=True):
    """Read MNE-Python preferences from environment or config file.

    Parameters
    ----------
    key : None | str
        The preference key to look for. The os environment is searched first,
        then the mne-python config file is parsed.
        If None, all the config parameters present in environment variables or
        the path are returned. If key is an empty string, a list of all valid
        keys (but not values) is returned.
    default : str | None
        Value to return if the key is not found.
    raise_error : bool
        If True, raise an error if the key is not found (instead of returning
        default).
    home_dir : str | None
        The folder that contains the .mne config folder.
        If None, it is found automatically.
    use_env : bool
        If True, consider env vars, if available.
        If False, only use MNE-Python configuration file values.

        .. versionadded:: 0.18

    Returns
    -------
    value : dict | str | None
        The preference key value.

    See Also
    --------
    set_config
    """
    _validate_type(key, (str, type(None)), "key", "string or None")

    if key == "":
        # These are str->str (immutable) so we should just copy the dict
        # itself, no need for deepcopy
        return _known_config_types.copy()

    # first, check to see if key is in env
    if use_env and key is not None and key in os.environ:
        return os.environ[key]

    # second, look for it in mne-python config file
    config_path = get_config_path(home_dir=home_dir)
    if not op.isfile(config_path):
        config = {}
    else:
        config = _load_config(config_path)

    if key is None:
        # update config with environment variables
        if use_env:
            env_keys = set(config).union(_known_config_types).intersection(os.environ)
            config.update({key: os.environ[key] for key in env_keys})
        return config
    elif raise_error is True and key not in config:
        loc_env = "the environment or in the " if use_env else ""
        meth_env = (
            (f'either os.environ["{key}"] = VALUE for a temporary solution, or ')
            if use_env
            else ""
        )
        extra_env = (
            " You can also set the environment variable before running python."
            if use_env
            else ""
        )
        meth_file = (
            f'mne.utils.set_config("{key}", VALUE, set_env=True) for a permanent one'
        )
        raise KeyError(
            f'Key "{key}" not found in {loc_env}'
            f"the mne-python config file ({config_path}). "
            f"Try {meth_env}{meth_file}.{extra_env}"
        )
    else:
        return config.get(key, default)
|
||||
|
||||
def set_config(key, value, home_dir=None, set_env=True):
    """Set an MNE-Python preference key in the config file and environment.

    Parameters
    ----------
    key : str
        The preference key to set.
    value : str | None
        The value to assign to the preference key. If None, the key is
        deleted.
    home_dir : str | None
        The folder that contains the .mne config folder.
        If None, it is found automatically.
    set_env : bool
        If True (default), update :data:`os.environ` in addition to
        updating the MNE-Python config file.

    See Also
    --------
    get_config
    """
    _validate_type(key, "str", "key")
    # While JSON allows non-string types, we allow users to override config
    # settings using env, which are strings, so we enforce that here
    _validate_type(value, (str, "path-like", type(None)), "value")
    if value is not None:
        value = str(value)

    if key not in _known_config_types and not any(
        key.startswith(k) for k in _known_config_wildcards
    ):
        warn(f'Setting non-standard config type: "{key}"')

    # Read all previous values
    config_path = get_config_path(home_dir=home_dir)
    if op.isfile(config_path):
        config = _load_config(config_path, raise_error=True)
    else:
        config = dict()
        logger.info(
            f"Attempting to create new mne-python configuration file:\n{config_path}"
        )
    if value is None:
        config.pop(key, None)
        if set_env and key in os.environ:
            del os.environ[key]
    else:
        config[key] = value
        if set_env:
            os.environ[key] = value
        if key == "MNE_BROWSER_BACKEND":
            from ..viz._figure import set_browser_backend

            set_browser_backend(value)

    # Write all values. This may fail if the default directory is not
    # writeable.
    directory = op.dirname(config_path)
    if not op.isdir(directory):
        os.mkdir(directory)
    with open(config_path, "w") as fid:
        json.dump(config, fid, sort_keys=True, indent=0)


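# Illustrative sketch (not part of the library API surface): how get_config()
# and set_config() round-trip a preference. "MNE_USE_CUDA" is a real known
# key from _known_config_types; the value shown is an example only.
#
# >>> import mne
# >>> mne.set_config("MNE_USE_CUDA", "false")  # persists to mne-python.json
# >>> mne.get_config("MNE_USE_CUDA", default="false")
# 'false'
# >>> mne.set_config("MNE_USE_CUDA", None)  # passing None deletes the key
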
def _get_extra_data_path(home_dir=None):
    """Get path to extra data (config, tables, etc.)."""
    global _temp_home_dir
    if home_dir is None:
        home_dir = os.environ.get("_MNE_FAKE_HOME_DIR")
    if home_dir is None:
        # this has been checked on OSX64, Linux64, and Win32
        if "nt" == os.name.lower():
            APPDATA_DIR = os.getenv("APPDATA")
            USERPROFILE_DIR = os.getenv("USERPROFILE")
            if APPDATA_DIR is not None and op.isdir(
                op.join(APPDATA_DIR, ".mne")
            ):  # backward-compat
                home_dir = APPDATA_DIR
            elif USERPROFILE_DIR is not None:
                home_dir = USERPROFILE_DIR
            else:
                raise FileNotFoundError(
                    "The USERPROFILE environment variable is not set, cannot "
                    "determine the location of the MNE-Python configuration "
                    "folder"
                )
            del APPDATA_DIR, USERPROFILE_DIR
        else:
            # This is a more robust way of getting the user's home folder on
            # Linux platforms (not sure about OSX, Unix or BSD) than checking
            # the HOME environment variable. If the user is running some sort
            # of script that isn't launched via the command line (e.g. a script
            # launched via Upstart) then the HOME environment variable will
            # not be set.
            if os.getenv("MNE_DONTWRITE_HOME", "") == "true":
                if _temp_home_dir is None:
                    _temp_home_dir = tempfile.mkdtemp()
                    atexit.register(
                        partial(shutil.rmtree, _temp_home_dir, ignore_errors=True)
                    )
                home_dir = _temp_home_dir
            else:
                home_dir = os.path.expanduser("~")

    if home_dir is None:
        raise ValueError(
            "mne-python config file path could "
            "not be determined, please report this "
            "error to mne-python developers"
        )

    return op.join(home_dir, ".mne")


def get_subjects_dir(subjects_dir=None, raise_error=False):
    """Safely use subjects_dir input to return SUBJECTS_DIR.

    Parameters
    ----------
    subjects_dir : path-like | None
        If a value is provided, return subjects_dir. Otherwise, look for
        SUBJECTS_DIR config and return the result.
    raise_error : bool
        If True, raise a KeyError if no value for SUBJECTS_DIR can be found
        (instead of returning None).

    Returns
    -------
    value : Path | None
        The SUBJECTS_DIR value.
    """
    from_config = False
    if subjects_dir is None:
        subjects_dir = get_config("SUBJECTS_DIR", raise_error=raise_error)
        from_config = True
        if subjects_dir is not None:
            subjects_dir = Path(subjects_dir)
    if subjects_dir is not None:
        # Emit a nice error or warning if their config is bad
        try:
            subjects_dir = _check_fname(
                fname=subjects_dir,
                overwrite="read",
                must_exist=True,
                need_dir=True,
                name="subjects_dir",
            )
        except FileNotFoundError:
            if from_config:
                msg = (
                    "SUBJECTS_DIR in your MNE-Python configuration or environment "
                    "does not exist, consider using mne.set_config to fix it: "
                    f"{subjects_dir}"
                )
                if raise_error:
                    raise FileNotFoundError(msg) from None
                else:
                    warn(msg)
            elif raise_error:
                raise

    return subjects_dir


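# Illustrative sketch: get_subjects_dir() resolution order. The path below is
# an example only; values coming from the SUBJECTS_DIR config are validated
# by _check_fname above, so a missing directory warns or raises.
#
# >>> from mne import get_subjects_dir
# >>> get_subjects_dir()  # None if SUBJECTS_DIR is unset and raise_error=False
# >>> get_subjects_dir("/data/subjects")  # returns a validated Path
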
@fill_doc
def _get_stim_channel(stim_channel, info, raise_error=True):
    """Determine the appropriate stim_channel.

    First, 'MNE_STIM_CHANNEL', 'MNE_STIM_CHANNEL_1', 'MNE_STIM_CHANNEL_2', etc.
    are read. If these are not found, it will fall back to 'STI101' (newer
    systems) or 'STI 014' (older systems) if present, then fall back to the
    first channel of type 'stim', if present.

    Parameters
    ----------
    stim_channel : str | list of str | None
        The stim channel selected by the user.
    %(info_not_none)s

    Returns
    -------
    stim_channel : list of str
        The name of the stim channel(s) to use.
    """
    from .._fiff.pick import pick_types

    if stim_channel is not None:
        if not isinstance(stim_channel, list):
            _validate_type(stim_channel, "str", "Stim channel")
            stim_channel = [stim_channel]
        for channel in stim_channel:
            _validate_type(channel, "str", "Each provided stim channel")
        return stim_channel

    stim_channel = list()
    ch_count = 0
    ch = get_config("MNE_STIM_CHANNEL")
    while ch is not None and ch in info["ch_names"]:
        stim_channel.append(ch)
        ch_count += 1
        ch = get_config(f"MNE_STIM_CHANNEL_{ch_count}")
    if ch_count > 0:
        return stim_channel

    if "STI101" in info["ch_names"]:  # combination channel for newer systems
        return ["STI101"]
    if "STI 014" in info["ch_names"]:  # for older systems
        return ["STI 014"]

    stim_channel = pick_types(info, meg=False, ref_meg=False, stim=True)
    if len(stim_channel) == 0 and raise_error:
        raise ValueError(
            "No stim channels found. Consider specifying them "
            "manually using the 'stim_channel' parameter."
        )
    stim_channel = [info["ch_names"][ch_] for ch_ in stim_channel]
    return stim_channel


def _get_root_dir():
    """Get as close to the repo root as possible."""
    root_dir = Path(__file__).parents[1]
    up_dir = root_dir.parent
    if (up_dir / "setup.py").is_file() and all(
        (up_dir / x).is_dir() for x in ("mne", "examples", "doc")
    ):
        root_dir = up_dir
    return root_dir


def _get_numpy_libs():
    bad_lib = "unknown linalg bindings"
    try:
        from threadpoolctl import threadpool_info
    except Exception as exc:
        return bad_lib + f" (threadpoolctl module not found: {exc})"
    pools = threadpool_info()
    rename = dict(
        openblas="OpenBLAS",
        mkl="MKL",
    )
    for pool in pools:
        if pool["internal_api"] in ("openblas", "mkl"):
            return (
                f'{rename[pool["internal_api"]]} '
                f'{pool["version"]} with '
                f'{pool["num_threads"]} thread{_pl(pool["num_threads"])}'
            )
    return bad_lib


_gpu_cmd = """\
|
||||
from pyvista import GPUInfo; \
|
||||
gi = GPUInfo(); \
|
||||
print(gi.version); \
|
||||
print(gi.renderer)"""
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_gpu_info():
|
||||
# Once https://github.com/pyvista/pyvista/pull/2250 is merged and PyVista
|
||||
# does a release, we can triage based on version > 0.33.2
|
||||
proc = subprocess.run(
|
||||
[sys.executable, "-c", _gpu_cmd], check=False, capture_output=True
|
||||
)
|
||||
out = proc.stdout.decode().strip().replace("\r", "").split("\n")
|
||||
if proc.returncode or len(out) != 2:
|
||||
return None, None
|
||||
return out
|
||||
|
||||
|
||||
def _get_total_memory():
    """Return the total memory of the system in bytes."""
    if platform.system() == "Windows":
        o = subprocess.check_output(
            [
                "powershell.exe",
                "(Get-CimInstance Win32_ComputerSystem).TotalPhysicalMemory",
            ]
        ).decode()
        total_memory = int(o)
    elif platform.system() == "Linux":
        o = subprocess.check_output(["free", "-b"]).decode()
        total_memory = int(o.splitlines()[1].split()[1])
    elif platform.system() == "Darwin":
        o = subprocess.check_output(["sysctl", "hw.memsize"]).decode()
        total_memory = int(o.split(":")[1].strip())
    else:
        raise UnknownPlatformError("Could not determine total memory")

    return total_memory


def _get_cpu_brand():
    """Return the CPU brand string."""
    if platform.system() == "Windows":
        o = subprocess.check_output(
            ["powershell.exe", "(Get-CimInstance Win32_Processor).Name"]
        ).decode()
        cpu_brand = o.strip().splitlines()[-1]
    elif platform.system() == "Linux":
        o = subprocess.check_output(["grep", "model name", "/proc/cpuinfo"]).decode()
        cpu_brand = o.splitlines()[0].split(": ")[1]
    elif platform.system() == "Darwin":
        o = subprocess.check_output(["sysctl", "machdep.cpu"]).decode()
        cpu_brand = o.split("brand_string: ")[1].strip()
    else:
        cpu_brand = "?"

    return cpu_brand


def sys_info(
    fid=None,
    show_paths=False,
    *,
    dependencies="user",
    unicode="auto",
    check_version=True,
):
    """Print system information.

    This function prints system information useful when triaging bugs.

    Parameters
    ----------
    fid : file-like | None
        The file to write to. Will be passed to :func:`print()`. Can be None to
        use :data:`sys.stdout`.
    show_paths : bool
        If True, print paths for each module.
    dependencies : 'user' | 'developer'
        Show dependencies relevant for users (default) or for developers
        (i.e., output includes additional dependencies).
    unicode : bool | "auto"
        Include Unicode symbols in output. If "auto", corresponds to True on Linux
        and macOS, and False on Windows.

        .. versionadded:: 0.24
    check_version : bool | float
        If True (default), attempt to check that the version of MNE-Python is up
        to date with the latest release on GitHub. Can be a float to give a
        different timeout (in sec) from the default (2 sec).

        .. versionadded:: 1.6
    """
    _validate_type(dependencies, str)
    _check_option("dependencies", dependencies, ("user", "developer"))
    _validate_type(check_version, (bool, "numeric"), "check_version")
    _validate_type(unicode, (bool, str), "unicode")
    _check_option("unicode", unicode, ("auto", True, False))
    if unicode == "auto":
        if platform.system() in ("Darwin", "Linux"):
            unicode = True
        else:  # Windows
            unicode = False
    ljust = 24 if dependencies == "developer" else 21
    platform_str = platform.platform()

    out = partial(print, end="", file=fid)
    out("Platform".ljust(ljust) + platform_str + "\n")
    out("Python".ljust(ljust) + str(sys.version).replace("\n", " ") + "\n")
    out("Executable".ljust(ljust) + sys.executable + "\n")
    try:
        cpu_brand = _get_cpu_brand()
    except Exception:
        cpu_brand = "?"
    out("CPU".ljust(ljust) + f"{cpu_brand} ")
    out(f"({multiprocessing.cpu_count()} cores)\n")
    out("Memory".ljust(ljust))
    try:
        total_memory = _get_total_memory()
    except UnknownPlatformError:
        total_memory = "?"
    else:
        total_memory = f"{total_memory / 1024**3:.1f}"  # convert to GiB
    out(f"{total_memory} GiB\n")
    out("\n")
    ljust -= 3  # account for +/- symbols
    libs = _get_numpy_libs()
    unavailable = []
    use_mod_names = (
        "# Core",
        "mne",
        "numpy",
        "scipy",
        "matplotlib",
        "",
        "# Numerical (optional)",
        "sklearn",
        "numba",
        "nibabel",
        "nilearn",
        "dipy",
        "openmeeg",
        "cupy",
        "pandas",
        "h5io",
        "h5py",
        "",
        "# Visualization (optional)",
        "pyvista",
        "pyvistaqt",
        "vtk",
        "qtpy",
        "ipympl",
        "pyqtgraph",
        "mne-qt-browser",
        "ipywidgets",
        # "trame",  # no version, see https://github.com/Kitware/trame/issues/183
        "trame_client",
        "trame_server",
        "trame_vtk",
        "trame_vuetify",
        "",
        "# Ecosystem (optional)",
        "mne-bids",
        "mne-nirs",
        "mne-features",
        "mne-connectivity",
        "mne-icalabel",
        "mne-bids-pipeline",
        "neo",
        "eeglabio",
        "edfio",
        "mffpy",
        "pybv",
        "",
    )
    if dependencies == "developer":
        use_mod_names += (
            "# Testing",
            "pytest",
            "statsmodels",
            "numpydoc",
            "flake8",
            "jupyter_client",
            "nbclient",
            "nbformat",
            "pydocstyle",
            "nitime",
            "imageio",
            "imageio-ffmpeg",
            "snirf",
            "",
            "# Documentation",
            "sphinx",
            "sphinx-gallery",
            "pydata-sphinx-theme",
            "",
            "# Infrastructure",
            "decorator",
            "jinja2",
            # "lazy-loader",
            "packaging",
            "pooch",
            "tqdm",
            "",
        )
    try:
        unicode = unicode and (sys.stdout.encoding.lower().startswith("utf"))
    except Exception:  # in case someone overrides sys.stdout in an unsafe way
        unicode = False
    mne_version_good = True
    for mi, mod_name in enumerate(use_mod_names):
        # upcoming break
        if mod_name == "":  # break
            if unavailable:
                out("└☐ " if unicode else " - ")
                out("unavailable".ljust(ljust))
                out(f"{', '.join(unavailable)}\n")
                unavailable = []
            if mi != len(use_mod_names) - 1:
                out("\n")
            continue
        elif mod_name.startswith("# "):  # header
            mod_name = mod_name.replace("# ", "")
            out(f"{mod_name}\n")
            continue
        pre = "├"
        last = use_mod_names[mi + 1] == "" and not unavailable
        if last:
            pre = "└"
        try:
            mod = import_module(mod_name.replace("-", "_"))
        except Exception:
            unavailable.append(mod_name)
        else:
            mark = "☑" if unicode else "+"
            mne_extra = ""
            if mod_name == "mne" and check_version:
                timeout = 2.0 if check_version is True else float(check_version)
                mne_version_good, mne_extra = _check_mne_version(timeout)
                if mne_version_good is None:
                    mne_version_good = True
                elif not mne_version_good:
                    mark = "☒" if unicode else "X"
            out(f"{pre}{mark} " if unicode else f" {mark} ")
            out(f"{mod_name}".ljust(ljust))
            if mod_name == "vtk":
                vtk_version = mod.vtkVersion()
                # 9.0 dev has VersionFull but 9.0 doesn't
                for attr in ("GetVTKVersionFull", "GetVTKVersion"):
                    if hasattr(vtk_version, attr):
                        version = getattr(vtk_version, attr)()
                        if version != "":
                            out(version)
                            break
                else:
                    out("unknown")
            else:
                out(mod.__version__.lstrip("v"))
            if mod_name == "numpy":
                out(f" ({libs})")
            elif mod_name == "qtpy":
                version, api = _check_qt_version(return_api=True)
                out(f" ({api}={version})")
            elif mod_name == "matplotlib":
                out(f" (backend={mod.get_backend()})")
            elif mod_name == "pyvista":
                version, renderer = _get_gpu_info()
                if version is None:
                    out(" (OpenGL unavailable)")
                else:
                    out(f" (OpenGL {version} via {renderer})")
            elif mod_name == "mne":
                out(f" ({mne_extra})")
            # Now comes stuff after the version
            if show_paths:
                if last:
                    pre = "   "
                elif unicode:
                    pre = "│  "
                else:
                    pre = " | "
                out(f'\n{pre}{" " * ljust}{op.dirname(mod.__file__)}')
            out("\n")

    if not mne_version_good:
        out(
            "\nTo update to the latest supported release version and get bugfixes "
            "and improvements, visit "
            "https://mne.tools/stable/install/updating.html\n"
        )


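# Illustrative usage sketch: sys_info() writes to stdout by default, or to
# any file-like object passed as fid. Output contents vary per system.
#
# >>> import mne
# >>> mne.sys_info()  # doctest: +SKIP
# >>> with open("sys_info.txt", "w") as fid:  # doctest: +SKIP
# ...     mne.sys_info(fid=fid, show_paths=True)
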
def _get_latest_version(timeout):
    # Bandit complains about urlopen, but we know the URL here
    url = "https://api.github.com/repos/mne-tools/mne-python/releases/latest"
    try:
        with urlopen(url, timeout=timeout) as f:  # nosec
            response = json.load(f)
    except (URLError, TimeoutError) as err:
        # Triage error type
        if "SSL" in str(err):
            return "SSL error"
        elif "timed out" in str(err):
            return f"timeout after {timeout} sec"
        else:
            return f"unknown error: {err}"
    else:
        return response["tag_name"].lstrip("v") or "version unknown"


def _check_mne_version(timeout):
    rel_ver = _get_latest_version(timeout)
    if not rel_ver[0].isnumeric():
        return None, f"unable to check for latest version on GitHub, {rel_ver}"
    rel_ver = parse(rel_ver)
    this_ver = parse(import_module("mne").__version__)
    if this_ver > rel_ver:
        return True, f"devel, latest release is {rel_ver}"
    if this_ver == rel_ver:
        return True, "latest release"
    else:
        return False, f"outdated, release {rel_ver} is available!"
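
# Illustrative sketch of the comparison above: packaging's parse() orders a
# dev build after the previous release but before the release it precedes,
# which is why a local dev version reports "devel". Versions are examples.
#
# >>> from packaging.version import parse
# >>> parse("1.7.0.dev123") > parse("1.6.1")
# True
# >>> parse("1.7.0.dev123") < parse("1.7.0")
# True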
131
mne/utils/dataframe.py
Normal file
@@ -0,0 +1,131 @@
"""inst.to_data_frame() helper functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from inspect import signature

import numpy as np

from ..defaults import _handle_default
from ._logging import logger, verbose
from .check import check_version


@verbose
def _set_pandas_dtype(df, columns, dtype, verbose=None):
    """Try to set the right columns to dtype."""
    for column in columns:
        df[column] = df[column].astype(dtype)
        logger.info(f'Converting "{column}" to "{dtype}"...')


def _scale_dataframe_data(inst, data, picks, scalings):
    ch_types = inst.get_channel_types()
    ch_types_used = list()
    scalings = _handle_default("scalings", scalings)
    for tt in scalings.keys():
        if tt in ch_types:
            ch_types_used.append(tt)
    for tt in ch_types_used:
        scaling = scalings[tt]
        idx = [ii for ii in range(len(picks)) if ch_types[ii] == tt]
        if len(idx):
            data[:, idx] *= scaling
    return data


def _convert_times(times, time_format, meas_date=None, first_time=0):
    """Convert vector of time in seconds to ms, datetime, or timedelta."""
    # private function; pandas already checked in calling function
    from pandas import to_timedelta

    if time_format == "ms":
        times = np.round(times * 1e3).astype(np.int64)
    elif time_format == "timedelta":
        times = to_timedelta(times, unit="s")
    elif time_format == "datetime":
        times = to_timedelta(times + first_time, unit="s") + meas_date
    return times
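
# Illustrative sketch of _convert_times(): with one sample every 100 ms,
# "ms" yields rounded int64 milliseconds, "timedelta" a pandas
# TimedeltaIndex (output shown for "timedelta" is abbreviated).
#
# >>> times = np.array([0.0, 0.1, 0.2])
# >>> _convert_times(times, "ms")
# array([  0, 100, 200])
# >>> _convert_times(times, "timedelta")  # doctest: +SKIP
# TimedeltaIndex(['0 days 00:00:00', '0 days 00:00:00.100000', ...])
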
def _inplace(df, method, **kwargs):
    # Handle transition: inplace=True (pandas <1.5) → copy=False (>=1.5)
    # and 3.0 warning:
    # E   DeprecationWarning: The copy keyword is deprecated and will be removed in a
    #     future version. Copy-on-Write is active in pandas since 3.0 which utilizes a
    #     lazy copy mechanism that defers copies until necessary. Use .copy() to make
    #     an eager copy if necessary.
    _meth = getattr(df, method)  # used for set_index() and rename()

    if check_version("pandas", "3.0"):
        return _meth(**kwargs)
    elif "copy" in signature(_meth).parameters:
        return _meth(**kwargs, copy=False)
    else:
        _meth(**kwargs, inplace=True)
        return df
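
# Illustrative sketch of _inplace(): callers use the same pattern regardless
# of the installed pandas version, and the helper picks plain keyword args,
# copy=False, or inplace=True as appropriate. Column names are examples.
#
# >>> import pandas as pd
# >>> df = pd.DataFrame({"time": [0, 1], "value": [10.0, 11.0]})
# >>> df = _inplace(df, "set_index", keys="time")
# >>> df = _inplace(df, "rename", columns={"value": "amplitude"})
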
@verbose
def _build_data_frame(
    inst,
    data,
    picks,
    long_format,
    mindex,
    index,
    default_index,
    col_names=None,
    col_kind="channel",
    verbose=None,
):
    """Build DataFrame from MNE-object-derived data array."""
    # private function; pandas already checked in calling function
    from pandas import DataFrame

    from ..source_estimate import _BaseSourceEstimate

    # build DataFrame
    if col_names is None:
        col_names = [inst.ch_names[p] for p in picks]
    df = DataFrame(data, columns=col_names)
    for i, (k, v) in enumerate(mindex):
        df.insert(i, k, v)
    # build Index
    if long_format:
        df = _inplace(df, "set_index", keys=default_index)
        df.columns.name = col_kind
    elif index is not None:
        df = _inplace(df, "set_index", keys=index)
        if set(index) == set(default_index):
            df.columns.name = col_kind
    # long format
    if long_format:
        df = df.stack().reset_index()
        df = _inplace(df, "rename", columns={0: "value"})
        # add column for channel types (as appropriate)
        ch_map = (
            None
            if isinstance(inst, _BaseSourceEstimate)
            else dict(
                zip(
                    np.array(inst.ch_names)[picks],
                    np.array(inst.get_channel_types())[picks],
                )
            )
        )
        if ch_map is not None:
            col_index = len(df.columns) - 1
            ch_type = df["channel"].map(ch_map)
            df.insert(col_index, "ch_type", ch_type)
        # restore index
        if index is not None:
            df = _inplace(df, "set_index", keys=index)
        # convert channel/vertex/ch_type columns to factors
        to_factor = [
            c for c in df.columns.tolist() if c not in ("freq", "time", "value")
        ]
        _set_pandas_dtype(df, to_factor, "category")
    return df
5609
mne/utils/docs.py
Normal file
File diff suppressed because it is too large
19
mne/utils/fetching.py
Normal file
@@ -0,0 +1,19 @@
"""File downloading functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import os


def _url_to_local_path(url, path):
    """Mirror a url path in a local destination (keeping folder structure)."""
    from urllib import parse, request

    destination = parse.urlparse(url).path
    # First char should be '/', and it needs to be discarded
    if len(destination) < 2 or destination[0] != "/":
        raise ValueError("Invalid URL")
    destination = os.path.join(path, request.url2pathname(destination)[1:])
    return destination
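
# Illustrative sketch of _url_to_local_path(): the URL path is mirrored
# under the destination folder (POSIX output shown; on Windows the
# separators differ).
#
# >>> _url_to_local_path("https://example.com/a/b/file.fif", "/tmp/mirror")
# '/tmp/mirror/a/b/file.fif'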
243
mne/utils/linalg.py
Normal file
@@ -0,0 +1,243 @@
"""Utility functions to speed up linear algebraic operations.

In general, things like np.dot and linalg.svd should be used directly
because they are smart about checking for bad values. However, in cases where
things are done repeatedly (e.g., thousands of times on tiny matrices), the
overhead can become problematic from a performance standpoint. Examples:

- Optimization routines:

  - Dipole fitting
  - Sparse solving
  - cHPI fitting

- Inverse computation:

  - Beamformers (LCMV/DICS)
  - eLORETA minimum norm

Significant performance gains can be achieved by ensuring that inputs
are Fortran contiguous because that's what LAPACK requires. Without this,
inputs will be memcopied.
"""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import functools

import numpy as np
from scipy import linalg
from scipy._lib._util import _asarray_validated

from ..fixes import _safe_svd

# For efficiency, names should be str or tuple of str, dtype a builtin
# NumPy dtype


@functools.lru_cache(None)
def _get_blas_funcs(dtype, names):
    return linalg.get_blas_funcs(names, (np.empty(0, dtype),))


@functools.lru_cache(None)
def _get_lapack_funcs(dtype, names):
    assert dtype in (np.float64, np.complex128)
    x = np.empty(0, dtype)
    return linalg.get_lapack_funcs(names, (x,))


###############################################################################
# linalg.svd and linalg.pinv2


def _svd_lwork(shape, dtype=np.float64):
    """Set up SVD calculations on identical-shape float64/complex128 arrays."""
    try:
        ds = linalg._decomp_svd
    except AttributeError:  # < 1.8.0
        ds = linalg.decomp_svd
    gesdd_lwork, gesvd_lwork = _get_lapack_funcs(dtype, ("gesdd_lwork", "gesvd_lwork"))
    sdd_lwork = ds._compute_lwork(
        gesdd_lwork, *shape, compute_uv=True, full_matrices=False
    )
    svd_lwork = ds._compute_lwork(
        gesvd_lwork, *shape, compute_uv=True, full_matrices=False
    )
    return sdd_lwork, svd_lwork


def _repeated_svd(x, lwork, overwrite_a=False):
    """Mimic scipy.linalg.svd, avoid lwork and get_lapack_funcs overhead."""
    gesdd, gesvd = _get_lapack_funcs(x.dtype, ("gesdd", "gesvd"))
    # this has to use overwrite_a=False in case we need to fall back to gesvd
    u, s, v, info = gesdd(
        x, compute_uv=True, lwork=lwork[0], full_matrices=False, overwrite_a=False
    )
    if info > 0:
        # Fall back to slower gesvd, sometimes gesdd fails
        u, s, v, info = gesvd(
            x,
            compute_uv=True,
            lwork=lwork[1],
            full_matrices=False,
            overwrite_a=overwrite_a,
        )
    if info > 0:
        raise np.linalg.LinAlgError("SVD did not converge")
    if info < 0:
        raise ValueError(f"illegal value in {-info}-th argument of internal gesdd")
    return u, s, v


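# Illustrative sketch of the intended hot-loop usage: compute the lwork
# sizes once for a fixed shape, then call _repeated_svd() many times on
# arrays of that shape (Fortran-ordered input avoids a memcopy).
#
# >>> rng = np.random.default_rng(0)
# >>> x = np.asfortranarray(rng.standard_normal((5, 3)))
# >>> lwork = _svd_lwork(x.shape)
# >>> u, s, v = _repeated_svd(x.copy(order="F"), lwork)
# >>> np.allclose(u @ np.diag(s) @ v, x)
# True

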
###############################################################################
# linalg.eigh


@functools.lru_cache(None)
def _get_evd(dtype):
    x = np.empty(0, dtype)
    if dtype == np.float64:
        driver = "syevd"
    else:
        assert dtype == np.complex128
        driver = "heevd"
    (evr,) = linalg.get_lapack_funcs((driver,), (x,))
    return evr, driver


def eigh(a, overwrite_a=False, check_finite=True):
    """Efficient wrapper for eigh.

    Parameters
    ----------
    a : ndarray, shape (n_components, n_components)
        The symmetric array to operate on.
    overwrite_a : bool
        If True, the contents of a can be overwritten for efficiency.
    check_finite : bool
        If True, check that all elements are finite.

    Returns
    -------
    w : ndarray, shape (n_components,)
        The N eigenvalues, in ascending order, each repeated according to
        its multiplicity.
    v : ndarray, shape (n_components, n_components)
        The normalized eigenvector corresponding to the eigenvalue ``w[i]``
        is the column ``v[:, i]``.
    """
    # We use SYEVD, see https://github.com/scipy/scipy/issues/9212
    if check_finite:
        a = _asarray_validated(a, check_finite=check_finite)
    evd, driver = _get_evd(a.dtype)
    w, v, info = evd(a, lower=1, overwrite_a=overwrite_a)
    if info == 0:
        return w, v
    if info < 0:
        raise ValueError(f"illegal value in argument {-info} of internal {driver}")
    else:
        raise linalg.LinAlgError(
            "internal fortran routine failed to converge: "
            f"{info} off-diagonal elements of an "
            "intermediate tridiagonal form did not converge to zero."
        )


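# Illustrative sketch of eigh(): same contract as scipy.linalg.eigh for
# float64/complex128 Hermitian input, with less per-call overhead.
#
# >>> a = np.array([[2.0, 1.0], [1.0, 2.0]])
# >>> w, v = eigh(a)
# >>> np.allclose(v @ np.diag(w) @ v.T, a)
# True

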
def sqrtm_sym(A, rcond=1e-7, inv=False):
    """Compute the sqrt of a positive, semi-definite matrix (or its inverse).

    Parameters
    ----------
    A : ndarray, shape (..., n, n)
        The array to take the square root of.
    rcond : float
        The relative condition number used during reconstruction.
    inv : bool
        If True, compute the inverse of the square root rather than the
        square root itself.

    Returns
    -------
    A_sqrt : ndarray, shape (..., n, n)
        The (possibly inverted) square root of A.
    s : ndarray, shape (..., n)
        The original square root singular values (not inverted).
    """
    # Same as linalg.sqrtm(C) but faster, also yields the eigenvalues
    return _sym_mat_pow(A, -0.5 if inv else 0.5, rcond, return_s=True)


def _sym_mat_pow(A, power, rcond=1e-7, reduce_rank=False, return_s=False):
    """Exponentiate Hermitian matrices with optional rank reduction."""
    assert power in (-1, 0.5, -0.5)  # only used internally
    s, u = np.linalg.eigh(A)  # eigenvalues in ascending order
    # Is it positive semi-definite? If so, keep real
    limit = s[..., -1:] * rcond
    if not (s >= -limit).all():  # allow some tiny negative values
        raise ValueError("Matrix is not positive semi-definite")
    s[s <= limit] = np.inf if power < 0 else 0
    if reduce_rank:
        # These are ordered smallest to largest, so we set the first one
        # to inf -- then the 1. / s below will turn this to zero, as needed.
        s[..., 0] = np.inf
    if power in (-0.5, 0.5):
        np.sqrt(s, out=s)
    use_s = 1.0 / s if power < 0 else s
    out = np.matmul(u * use_s[..., np.newaxis, :], u.swapaxes(-2, -1).conj())
    if return_s:
        out = (out, s)
    return out


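# Illustrative sketch of sqrtm_sym(): the returned square root reproduces
# the input, and inv=True gives the inverse square root. The test matrix is
# an arbitrary positive-definite example.
#
# >>> rng = np.random.default_rng(0)
# >>> x = rng.standard_normal((4, 4))
# >>> A = x @ x.T + 4 * np.eye(4)  # positive definite by construction
# >>> A_sqrt, s = sqrtm_sym(A)
# >>> np.allclose(A_sqrt @ A_sqrt, A)
# True
# >>> A_isqrt, s = sqrtm_sym(A, inv=True)
# >>> np.allclose(A_isqrt @ A @ A_isqrt, np.eye(4))
# True

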
# SciPy deprecation of pinv + pinvh rcond (never worked properly anyway)
def pinvh(a, rtol=None):
    """Compute a pseudo-inverse of a Hermitian matrix.

    Parameters
    ----------
    a : ndarray, shape (n, n)
        The Hermitian array to invert.
    rtol : float | None
        The relative tolerance.

    Returns
    -------
    a_pinv : ndarray, shape (n, n)
        The pseudo-inverse of a.
    """
    s, u = np.linalg.eigh(a)
    del a
    if rtol is None:
        rtol = s.size * np.finfo(s.dtype).eps
    maxS = np.max(np.abs(s))
    above_cutoff = abs(s) > maxS * rtol
    psigma_diag = 1.0 / s[above_cutoff]
    u = u[:, above_cutoff]
    return (u * psigma_diag) @ u.conj().T


def pinv(a, rtol=None):
    """Compute a pseudo-inverse of a matrix.

    Parameters
    ----------
    a : ndarray, shape (n, m)
        The array to invert.
    rtol : float | None
        The relative tolerance.

    Returns
    -------
    a_pinv : ndarray, shape (m, n)
        The pseudo-inverse of a.
    """
    u, s, vh = _safe_svd(a, full_matrices=False)
    del a
    maxS = np.max(s)
    if rtol is None:
        rtol = max(vh.shape + u.shape) * np.finfo(u.dtype).eps
    rank = np.sum(s > maxS * rtol)
    u = u[:, :rank]
    u /= s[:rank]
    return (u @ vh[:rank]).conj().T
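
# Illustrative sketch of pinv()/pinvh(): for a rank-deficient matrix the
# Moore-Penrose condition a @ pinv(a) @ a == a holds; pinvh() is the cheaper
# eigendecomposition-based path when the input is Hermitian.
#
# >>> a = np.array([[1.0, 0.0], [0.0, 0.0]])  # rank 1, Hermitian
# >>> np.allclose(pinv(a), pinvh(a))
# True
# >>> np.allclose(a @ pinv(a) @ a, a)
# True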
509
mne/utils/misc.py
Normal file
@@ -0,0 +1,509 @@
"""Some miscellaneous utility functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import fnmatch
import gc
import hashlib
import inspect
import os
import subprocess
import sys
import traceback
import weakref
from contextlib import ExitStack, contextmanager
from importlib.resources import files
from math import log
from queue import Empty, Queue
from string import Formatter
from textwrap import dedent
from threading import Thread

import numpy as np
from decorator import FunctionMaker

from ._logging import logger, verbose, warn
from .check import _check_option, _validate_type


def _identity_function(x):
    return x


# TODO: no longer needed when py3.9 is minimum supported version
def _empty_hash(kind="md5"):
    func = getattr(hashlib, kind)
    if "usedforsecurity" in inspect.signature(func).parameters:
        return func(usedforsecurity=False)
    else:
        return func()


def _pl(x, non_pl="", pl="s"):
    """Determine if plural should be used."""
    len_x = x if isinstance(x, int | np.generic) else len(x)
    return non_pl if len_x == 1 else pl


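# Illustrative doctest for _pl(): works with ints and sized containers, and
# supports irregular plurals via the non_pl/pl arguments.
#
# >>> f"{3} file{_pl(3)}", f"{1} file{_pl(1)}"
# ('3 files', '1 file')
# >>> f"matri{_pl([1, 2], 'x', 'ces')}"
# 'matrices'

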
def _explain_exception(start=-1, stop=None, prefix="> "):
    """Explain an exception."""
    # start=-1 means "only the most recent caller"
    etype, value, tb = sys.exc_info()
    string = traceback.format_list(traceback.extract_tb(tb)[start:stop])
    string = "".join(string).split("\n") + traceback.format_exception_only(etype, value)
    string = ":\n" + prefix + ("\n" + prefix).join(string)
    return string


def _sort_keys(x):
    """Sort and return keys of dict."""
    keys = list(x.keys())  # note: not thread-safe
    idx = np.argsort([str(k) for k in keys])
    keys = [keys[ii] for ii in idx]
    return keys


class _DefaultEventParser:
    """Parse non-standard events."""

    def __init__(self):
        self.event_ids = dict()

    def __call__(self, description, offset=1):
        if description not in self.event_ids:
            self.event_ids[description] = offset + len(self.event_ids)

        return self.event_ids[description]


class _FormatDict(dict):
    """Help pformat() work properly."""

    def __missing__(self, key):
        return "{" + key + "}"


def pformat(temp, **fmt):
    """Format a template string partially.

    Examples
    --------
    >>> pformat("{a}_{b}", a='x')
    'x_{b}'
    """
    formatter = Formatter()
    mapping = _FormatDict(fmt)
    return formatter.vformat(temp, (), mapping)


def _enqueue_output(out, queue):
    for line in iter(out.readline, b""):
        queue.put(line)


@verbose
def run_subprocess(command, return_code=False, verbose=None, *args, **kwargs):
    """Run command using subprocess.Popen.

    Run command and wait for command to complete. If the return code was zero
    then return, otherwise raise CalledProcessError.
    By default, this will also add stdout= and stderr=subprocess.PIPE
    to the call to Popen to suppress printing to the terminal.

    Parameters
    ----------
    command : list of str | str
        Command to run as subprocess (see subprocess.Popen documentation).
    return_code : bool
        If True, return the return code instead of raising an error if it's
        non-zero.

        .. versionadded:: 0.20
    %(verbose)s
    *args, **kwargs : arguments
        Additional arguments to pass to subprocess.Popen.

    Returns
    -------
    stdout : str
        Stdout returned by the process.
    stderr : str
        Stderr returned by the process.
    code : int
        The return code, only returned if ``return_code == True``.
    """
    all_out = ""
    all_err = ""
    # non-blocking adapted from https://stackoverflow.com/questions/375427/non-blocking-read-on-a-subprocess-pipe-in-python#4896288  # noqa: E501
    out_q = Queue()
    err_q = Queue()
    control_stdout = "stdout" not in kwargs
    control_stderr = "stderr" not in kwargs
    with running_subprocess(command, *args, **kwargs) as p:
        if control_stdout:
            out_t = Thread(target=_enqueue_output, args=(p.stdout, out_q))
            out_t.daemon = True
            out_t.start()
        if control_stderr:
            err_t = Thread(target=_enqueue_output, args=(p.stderr, err_q))
            err_t.daemon = True
            err_t.start()
        while True:
            do_break = p.poll() is not None
            # read all current lines without blocking
            while True:  # process stdout
                try:
                    out = out_q.get(timeout=0.01)
                except Empty:
                    break
                else:
                    out = out.decode("utf-8")
                    log_out = out.removesuffix("\n")
                    logger.info(log_out)
                    all_out += out

            while True:  # process stderr
                try:
                    err = err_q.get(timeout=0.01)
                except Empty:
                    break
                else:
                    err = err.decode("utf-8")
                    err_out = err.removesuffix("\n")

                    # Leave this as logger.warning rather than warn(...) to
                    # mirror the logger.info above for stdout. This function
                    # is basically just a version of subprocess.call, and
                    # shouldn't emit Python warnings due to stderr outputs
                    # (the calling function can check for stderr output and
                    # emit a warning if it wants).
                    logger.warning(err_out)
                    all_err += err

            if do_break:
                break
    output = (all_out, all_err)

    if return_code:
        output = output + (p.returncode,)
    elif p.returncode:
        stdout = all_out if control_stdout else None
        stderr = all_err if control_stderr else None
        raise subprocess.CalledProcessError(
            p.returncode, command, output=stdout, stderr=stderr
        )

    return output


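# Illustrative usage sketch of run_subprocess(); the commands are examples
# only and assume a POSIX shell environment.
#
# >>> stdout, stderr = run_subprocess(["echo", "hello"])  # doctest: +SKIP
# >>> stdout  # doctest: +SKIP
# 'hello\n'
# >>> out, err, code = run_subprocess(  # doctest: +SKIP
# ...     ["ls", "missing"], return_code=True)  # non-zero code, no raise

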
@contextmanager
def running_subprocess(command, after="wait", verbose=None, *args, **kwargs):
    """Context manager to do something with a command running via Popen.

    Parameters
    ----------
    command : list of str | str
        Command to run as subprocess (see :class:`python:subprocess.Popen`).
    after : str
        Can be:

        - "wait" to use :meth:`~python:subprocess.Popen.wait`
        - "communicate" to use :meth:`~python:subprocess.Popen.communicate`
        - "terminate" to use :meth:`~python:subprocess.Popen.terminate`
        - "kill" to use :meth:`~python:subprocess.Popen.kill`

    %(verbose)s
    *args, **kwargs : arguments
        Additional arguments to pass to subprocess.Popen.

    Returns
    -------
    p : instance of Popen
        The process.
    """
    _validate_type(after, str, "after")
    _check_option("after", after, ["wait", "terminate", "kill", "communicate"])
    contexts = list()
    for stdxxx in ("stderr", "stdout"):
        if stdxxx not in kwargs:
            kwargs[stdxxx] = subprocess.PIPE
            contexts.append(stdxxx)

    # Check the PATH environment variable. If run_subprocess() is to be called
    # frequently this should be refactored so as to only check the path once.
    env = kwargs.get("env", os.environ)
    if any(p.startswith("~") for p in env["PATH"].split(os.pathsep)):
        warn(
            "Your PATH environment variable contains at least one path "
            'starting with a tilde ("~") character. Such paths are not '
            "interpreted correctly from within Python. It is recommended "
            'that you use "$HOME" instead of "~".'
        )
    if isinstance(command, str):
        command_str = command
    else:
        command = [str(s) for s in command]
        command_str = " ".join(s for s in command)
    logger.info(f"Running subprocess: {command_str}")
    try:
        p = subprocess.Popen(command, *args, **kwargs)
    except Exception:
        if isinstance(command, str):
            command_name = command.split()[0]
        else:
            command_name = command[0]
        logger.error(f"Command not found: {command_name}")
        raise
    try:
        with ExitStack() as stack:
            for context in contexts:
                stack.enter_context(getattr(p, context))
            yield p
    finally:
        getattr(p, after)()
        p.wait()


def _clean_names(names, remove_whitespace=False, before_dash=True):
    """Clean up channel names for topo matching.

    This function handles different naming conventions for old vs. new VectorView
    systems (`remove_whitespace`) and removes system-specific parts in CTF channel
    names (`before_dash`).

    Usage
    -----
    # for new VectorView (only inside layout)
    ch_names = _clean_names(epochs.ch_names, remove_whitespace=True)

    # for CTF
    ch_names = _clean_names(epochs.ch_names, before_dash=True)
    """
    cleaned = []
    for name in names:
        if " " in name and remove_whitespace:
            name = name.replace(" ", "")
        if "-" in name and before_dash:
            name = name.split("-")[0]
        if name.endswith("_v"):
            name = name[:-2]
        cleaned.append(name)
    if len(set(cleaned)) != len(names):
        # this was probably not a VectorView or CTF dataset, and we now broke the
        # dataset by creating duplicates, so let's use the original channel names.
        return names
    return cleaned


def _get_argvalues():
    """Return all arguments (except self) and values of read_raw_xxx."""
    # call stack
    # read_raw_xxx -> <decorator-gen-000> -> BaseRaw.__init__ -> _get_argvalues

    # This is equivalent to `frame = inspect.stack(0)[4][0]` but faster
    frame = inspect.currentframe()
    try:
        for _ in range(3):
            frame = frame.f_back
        fname = frame.f_code.co_filename
        if not fnmatch.fnmatch(fname, "*/mne/io/*"):
            return None
        args, _, _, values = inspect.getargvalues(frame)
    finally:
        del frame
    params = dict()
    for arg in args:
        params[arg] = values[arg]
    params.pop("self", None)
    return params


def sizeof_fmt(num):
    """Turn number of bytes into human-readable str.

    Parameters
    ----------
    num : int
        The number of bytes.

    Returns
    -------
    size : str
        The size in human-readable format.
    """
    units = ["bytes", "KiB", "MiB", "GiB", "TiB", "PiB"]
    decimals = [0, 0, 1, 2, 2, 2]
    if num > 1:
        exponent = min(int(log(num, 1024)), len(units) - 1)
        quotient = float(num) / 1024**exponent
        unit = units[exponent]
        num_decimals = decimals[exponent]
        format_string = f"{{0:.{num_decimals}f}} {{1}}"
        return format_string.format(quotient, unit)
    if num == 0:
        return "0 bytes"
    if num == 1:
        return "1 byte"


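# Illustrative doctest for sizeof_fmt(): the number of decimals grows with
# the unit, per the decimals table above.
#
# >>> sizeof_fmt(0), sizeof_fmt(1), sizeof_fmt(1024)
# ('0 bytes', '1 byte', '1 KiB')
# >>> sizeof_fmt(5 * 1024**3)
# '5.00 GiB'

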
def _file_like(obj):
    # An alternative would be::
    #
    #     isinstance(obj, (TextIOBase, BufferedIOBase, RawIOBase, IOBase))
    #
    # but this might be more robust to file-like objects not properly
    # inheriting from these classes:
    return all(callable(getattr(obj, name, None)) for name in ("read", "seek"))


def _fullname(obj):
    klass = obj.__class__
    module = klass.__module__
    if module == "builtins":
        return klass.__qualname__
    return module + "." + klass.__qualname__


def _assert_no_instances(cls, when=""):
    __tracebackhide__ = True
    n = 0
    ref = list()
    gc.collect()
    objs = gc.get_objects()
    for obj in objs:
        try:
            check = isinstance(obj, cls)
        except Exception:  # such as a weakref
            check = False
        if check:
            if cls.__name__ == "Brain":
                ref.append(f'Brain._cleaned = {getattr(obj, "_cleaned", None)}')
            rr = gc.get_referrers(obj)
            count = 0
            for r in rr:
                if (
                    r is not objs
                    and r is not globals()
                    and r is not locals()
                    and not inspect.isframe(r)
                ):
                    if isinstance(r, list | dict | tuple):
                        rep = f"len={len(r)}"
                        r_ = gc.get_referrers(r)
                        types = (_fullname(x) for x in r_)
                        types = "/".join(sorted(set(x for x in types if x is not None)))
                        rep += f", {len(r_)} referrers: {types}"
                        del r_
                    else:
                        rep = repr(r)[:100].replace("\n", " ")
                        # If it's a __closure__, get more information
                        if rep.startswith("<cell at "):
                            try:
                                rep += f" ({repr(r.cell_contents)[:100]})"
                            except Exception:
                                pass
                    name = _fullname(r)
                    ref.append(f"{name}: {rep}")
                    count += 1
                del r
            del rr
            n += count > 0
        del obj
    del objs
    gc.collect()
    assert n == 0, f"\n{n} {cls.__name__} @ {when}:\n" + "\n".join(ref)


def _resource_path(submodule, filename):
    """Return a full system path to a package resource (AKA a file).

    Parameters
    ----------
    submodule : str
        An import-style module or submodule name
        (e.g., "mne.datasets.testing").
    filename : str
        The file whose full path you want.

    Returns
    -------
    path : str
        The full system path to the requested file.
    """
    return files(submodule).joinpath(filename)


def repr_html(f):
    """Decorate _repr_html_ methods.

    If a _repr_html_ method is decorated with this decorator, the repr in a
    notebook will show HTML or plain text depending on the config value
    MNE_REPR_HTML (by default "true", which will render HTML).

    Parameters
    ----------
    f : function
        The function to decorate.

    Returns
    -------
    wrapper : function
        The decorated function.
    """
    from ..utils import get_config

    def wrapper(*args, **kwargs):
        if get_config("MNE_REPR_HTML", "true").lower() == "false":
            import html

            r = "<pre>" + html.escape(repr(args[0])) + "</pre>"
            return r.replace("\n", "<br/>")
        else:
            return f(*args, **kwargs)

    return wrapper


def _auto_weakref(function):
    """Create weakrefs to self (or other free vars in __closure__) then evaluate.

    When a nested function is defined within an instance method, and the function makes
    use of ``self``, it creates a reference cycle that the Python garbage collector is
    not smart enough to resolve, so the parent object is never GC'd. (The reference to
    ``self`` becomes part of the ``__closure__`` of the nested function).

    This decorator allows the nested function to access ``self`` without increasing the
    reference counter on ``self``, which will prevent the memory leak. If the referent
    is not found (usually because already GC'd) it will short-circuit the decorated
    function and return ``None``.
    """
    names = function.__code__.co_freevars
    assert len(names) == len(function.__closure__)
    __weakref_values__ = dict()
    evaldict = dict(__weakref_values__=__weakref_values__)
    for name, value in zip(names, function.__closure__):
        __weakref_values__[name] = weakref.ref(value.cell_contents)
    body = dedent(inspect.getsource(function))
    body = body.splitlines()
    for li, line in enumerate(body):
        if line.startswith(" "):
            body = body[li:]
            break
    old_body = "\n".join(body)
    body = """\
def %(name)s(%(signature)s):
"""
    for name in names:
        body += f"""
    {name} = __weakref_values__[{repr(name)}]()
    if {name} is None:
        return
"""
    body = body + old_body
    fm = FunctionMaker(function)
    fun = fm.make(body, evaldict, addsource=True)
    fun.__globals__.update(function.__globals__)
    assert fun.__closure__ is None, fun.__closure__
    return fun
781
mne/utils/mixin.py
Normal file
@@ -0,0 +1,781 @@
"""Some utility functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import json
import logging
from collections import OrderedDict
from copy import deepcopy

import numpy as np

from ._logging import verbose, warn
from ._typing import Self
from .check import _check_pandas_installed, _check_preload, _validate_type
from .numerics import _time_mask, object_hash, object_size

logger = logging.getLogger("mne")  # one selection here used across mne-python
logger.propagate = False  # don't propagate (in case of multiple imports)


class SizeMixin:
    """Estimate MNE object sizes."""

    def __eq__(self, other):
        """Compare self to other.

        Parameters
        ----------
        other : object
            The object to compare to.

        Returns
        -------
        eq : bool
            True if the two objects are equal.
        """
        return isinstance(other, type(self)) and hash(self) == hash(other)

    @property
    def _size(self):
        """Estimate the object size."""
        try:
            size = object_size(self.info)
        except Exception:
            warn("Could not get size for self.info")
            return -1
        if hasattr(self, "data"):
            size += object_size(self.data)
        elif hasattr(self, "_data"):
            size += object_size(self._data)
        return size

    def __hash__(self):
        """Hash the object.

        Returns
        -------
        hash : int
            The hash.
        """
        from ..epochs import BaseEpochs
        from ..evoked import Evoked
        from ..io import BaseRaw

        if isinstance(self, Evoked):
            return object_hash(dict(info=self.info, data=self.data))
        elif isinstance(self, BaseEpochs | BaseRaw):
            _check_preload(self, "Hashing ")
            return object_hash(dict(info=self.info, data=self._data))
        else:
            raise RuntimeError(f"Hashing unknown object type: {type(self)}")


class GetEpochsMixin:
    """Class to add epoch selection and metadata to certain classes."""

    def __getitem__(
        self: Self,
        item,
    ) -> Self:
        """Return an Epochs object with a copied subset of epochs.

        Parameters
        ----------
        item : int | slice | array-like | str
            See Notes for use cases.

        Returns
        -------
        epochs : instance of Epochs
            The subset of epochs.

        Notes
        -----
        Epochs can be accessed as ``epochs[...]`` in several ways:

        1. **Integer or slice:** ``epochs[idx]`` will return an `~mne.Epochs`
           object with a subset of epochs chosen by index (supports single
           index and Python-style slicing).

        2. **String:** ``epochs['name']`` will return an `~mne.Epochs` object
           comprising only the epochs labeled ``'name'`` (i.e., epochs created
           around events with the label ``'name'``).

           If there are no epochs labeled ``'name'`` but there are epochs
           labeled with /-separated tags (e.g. ``'name/left'``,
           ``'name/right'``), then ``epochs['name']`` will select the epochs
           with labels that contain that tag (e.g., ``epochs['left']`` selects
           epochs labeled ``'audio/left'`` and ``'visual/left'``, but not
           ``'audio_left'``).

           If multiple tags are provided *as a single string* (e.g.,
           ``epochs['name_1/name_2']``), this selects epochs containing *all*
           provided tags. For example, ``epochs['audio/left']`` selects
           ``'audio/left'`` and ``'audio/quiet/left'``, but not
           ``'audio/right'``. Note that tag-based selection is insensitive to
           order: tags like ``'audio/left'`` and ``'left/audio'`` will be
           treated the same way when selecting via tag.

        3. **List of strings:** ``epochs[['name_1', 'name_2', ... ]]`` will
           return an `~mne.Epochs` object comprising epochs that match *any* of
           the provided names (i.e., the list of names is treated as an
           inclusive-or condition). If *none* of the provided names match any
           epoch labels, a ``KeyError`` will be raised.

           If epoch labels are /-separated tags, then providing multiple tags
           *as separate list entries* will likewise act as an inclusive-or
           filter. For example, ``epochs[['audio', 'left']]`` would select
           ``'audio/left'``, ``'audio/right'``, and ``'visual/left'``, but not
           ``'visual/right'``.

        4. **Pandas query:** ``epochs['pandas query']`` will return an
           `~mne.Epochs` object with a subset of epochs (and matching
           metadata) selected by the query called with
           ``self.metadata.eval``, e.g.::

               epochs["col_a > 2 and col_b == 'foo'"]

           would return all epochs whose associated ``col_a`` metadata was
           greater than two, and whose ``col_b`` metadata was the string 'foo'.
           Query-based indexing only works if Pandas is installed and
           ``self.metadata`` is a :class:`pandas.DataFrame`.

           .. versionadded:: 0.16
        """
        return self._getitem(item)

    def _item_to_select(self, item):
        if isinstance(item, str):
            item = [item]

        # Convert string to indices
        if (
            isinstance(item, list | tuple)
            and len(item) > 0
            and isinstance(item[0], str)
        ):
            select = self._keys_to_idx(item)
        elif isinstance(item, slice):
            select = item
        else:
            select = np.atleast_1d(item)
            if len(select) == 0:
                select = np.array([], int)
        return select

def _getitem(
|
||||
self,
|
||||
item,
|
||||
reason="IGNORED",
|
||||
copy=True,
|
||||
drop_event_id=True,
|
||||
select_data=True,
|
||||
return_indices=False,
|
||||
):
|
||||
"""
|
||||
Select epochs from current object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
item: slice, array-like, str, or list
|
||||
see `__getitem__` for details.
|
||||
reason: str, list/tuple of str
|
||||
entry in `drop_log` for unselected epochs
|
||||
copy: bool
|
||||
return a copy of the current object
|
||||
drop_event_id: bool
|
||||
remove non-existing event-ids after selection
|
||||
select_data: bool
|
||||
apply selection to data
|
||||
(use `select_data=False` if subclasses do not have a
|
||||
valid `_data` field, or data has already been subselected)
|
||||
return_indices: bool
|
||||
return the indices of selected epochs from the original object
|
||||
in addition to the new `Epochs` objects
|
||||
|
||||
Returns
|
||||
-------
|
||||
`Epochs` or tuple(Epochs, np.ndarray) if `return_indices` is True
|
||||
subset of epochs (and optionally array with kept epoch indices)
|
||||
"""
|
||||
inst = self.copy() if copy else self
|
||||
if self._data is not None:
|
||||
np.copyto(inst._data, self._data, casting="no")
|
||||
del self
|
||||
|
||||
select = inst._item_to_select(item)
|
||||
has_selection = hasattr(inst, "selection")
|
||||
if has_selection:
|
||||
key_selection = inst.selection[select]
|
||||
drop_log = list(inst.drop_log)
|
||||
if reason is not None:
|
||||
_validate_type(reason, (list, tuple, str), "reason")
|
||||
if isinstance(reason, list | tuple):
|
||||
for r in reason:
|
||||
_validate_type(r, str, r)
|
||||
if isinstance(reason, str):
|
||||
reason = (reason,)
|
||||
reason = tuple(reason)
|
||||
for idx in np.setdiff1d(inst.selection, key_selection):
|
||||
drop_log[idx] = reason
|
||||
inst.drop_log = tuple(drop_log)
|
||||
inst.selection = key_selection
|
||||
del drop_log
|
||||
|
||||
inst.events = np.atleast_2d(inst.events[select])
|
||||
if inst.metadata is not None:
|
||||
pd = _check_pandas_installed(strict=False)
|
||||
if pd:
|
||||
metadata = inst.metadata.iloc[select]
|
||||
if has_selection:
|
||||
metadata.index = inst.selection
|
||||
else:
|
||||
metadata = np.array(inst.metadata, "object")[select].tolist()
|
||||
|
||||
# will reset the index for us
|
||||
GetEpochsMixin.metadata.fset(inst, metadata, verbose=False)
|
||||
if inst.preload and select_data:
|
||||
# ensure that each Epochs instance owns its own data so we can
|
||||
# resize later if necessary
|
||||
inst._data = np.require(inst._data[select], requirements=["O"])
|
||||
if drop_event_id:
|
||||
# update event id to reflect new content of inst
|
||||
inst.event_id = {
|
||||
k: v for k, v in inst.event_id.items() if v in inst.events[:, 2]
|
||||
}
|
||||
|
||||
if return_indices:
|
||||
return inst, select
|
||||
else:
|
||||
return inst

    def _keys_to_idx(self, keys):
        """Find entries in event dict."""
        from ..event import match_event_names  # avoid circular import

        keys = keys if isinstance(keys, list | tuple) else [keys]
        try:
            # Assume it's a condition name
            return np.where(
                np.any(
                    np.array(
                        [
                            self.events[:, 2] == self.event_id[k]
                            for k in match_event_names(self.event_id, keys)
                        ]
                    ),
                    axis=0,
                )
            )[0]
        except KeyError as err:
            # Could we in principle use metadata with these Epochs and keys?
            if len(keys) != 1 or self.metadata is None:
                # If not, raise original error
                raise
            msg = str(err.args[0])  # message for KeyError
            pd = _check_pandas_installed(strict=False)
            # See if the query can be done
            if pd:
                md = self.metadata if hasattr(self, "_metadata") else None
                self._check_metadata(metadata=md)
                try:
                    # Try metadata
                    vals = (
                        self.metadata.reset_index()
                        .query(keys[0], engine="python")
                        .index.values
                    )
                except Exception as exp:
                    msg += (
                        " The epochs.metadata Pandas query did not "
                        f"yield any results: {exp.args[0]}"
                    )
                else:
                    return vals
            else:
                # If not, warn this might be a problem
                msg += (
                    " The epochs.metadata Pandas query could not "
                    "be performed, consider installing Pandas."
                )
            raise KeyError(msg)

    def __len__(self):
        """Return the number of epochs.

        Returns
        -------
        n_epochs : int
            The number of remaining epochs.

        Notes
        -----
        This function only works if bad epochs have been dropped.

        Examples
        --------
        This can be used as::

            >>> epochs.drop_bad()  # doctest: +SKIP
            >>> len(epochs)  # doctest: +SKIP
            43
            >>> len(epochs.events)  # doctest: +SKIP
            43
        """
        from ..epochs import BaseEpochs

        if isinstance(self, BaseEpochs) and not self._bad_dropped:
            raise RuntimeError(
                "Since bad epochs have not been dropped, the "
                "length of the Epochs is not known. Load the "
                "Epochs with preload=True, or call "
                "Epochs.drop_bad(). To find the number "
                "of events in the Epochs, use "
                "len(Epochs.events)."
            )
        return len(self.events)

    def __iter__(self):
        """Facilitate iteration over epochs.

        This method resets the object iteration state to the first epoch.

        Notes
        -----
        This enables the use of this Python pattern::

            >>> for epoch in epochs:  # doctest: +SKIP
            >>>     print(epoch)  # doctest: +SKIP

        Where ``epoch`` is given by successive outputs of
        :meth:`mne.Epochs.next`.
        """
        self._current = 0
        self._current_detrend_picks = self._detrend_picks
        return self

    def __next__(self, return_event_id=False):
        """Iterate over epoch data.

        Parameters
        ----------
        return_event_id : bool
            If True, return both the epoch data and an event_id.

        Returns
        -------
        epoch : array of shape (n_channels, n_times)
            The epoch data.
        event_id : int
            The event id. Only returned if ``return_event_id`` is ``True``.
        """
        if not hasattr(self, "_current_detrend_picks"):
            self.__iter__()  # ensure we're ready to iterate
        if self.preload:
            if self._current >= len(self._data):
                self._stop_iter()
            epoch = self._data[self._current]
            self._current += 1
        else:
            is_good = False
            while not is_good:
                if self._current >= len(self.events):
                    self._stop_iter()
                epoch_noproj = self._get_epoch_from_raw(self._current)
                epoch_noproj = self._detrend_offset_decim(
                    epoch_noproj, self._current_detrend_picks
                )
                epoch = self._project_epoch(epoch_noproj)
                self._current += 1
                is_good, _ = self._is_good_epoch(epoch)
            # If delayed-ssp mode, pass 'virgin' data after rejection decision.
            if self._do_delayed_proj:
                epoch = epoch_noproj

        if not return_event_id:
            return epoch
        else:
            return epoch, self.events[self._current - 1][-1]

    def _stop_iter(self):
        del self._current
        del self._current_detrend_picks
        raise StopIteration  # signal the end

    next = __next__  # originally for Python 2 compatibility, kept because it is public

    def _check_metadata(self, metadata=None, reset_index=False):
        """Check metadata consistency."""
        # reset_index=False will not copy!
        if metadata is None:
            return
        else:
            pd = _check_pandas_installed(strict=False)
            if pd:
                _validate_type(metadata, types=pd.DataFrame, item_name="metadata")
                if len(metadata) != len(self.events):
                    raise ValueError(
                        "metadata must have the same number of "
                        f"rows ({len(metadata)}) as events ({len(self.events)})"
                    )
                if reset_index:
                    if hasattr(self, "selection"):
                        # makes a copy
                        metadata = metadata.reset_index(drop=True)
                        metadata.index = self.selection
                    else:
                        metadata = deepcopy(metadata)
            else:
                _validate_type(metadata, types=list, item_name="metadata")
                if reset_index:
                    metadata = deepcopy(metadata)
        return metadata

    @property
    def metadata(self):
        """Get the metadata."""
        return self._metadata

    @metadata.setter
    @verbose
    def metadata(self, metadata, verbose=None):
        metadata = self._check_metadata(metadata, reset_index=True)
        if metadata is not None:
            if _check_pandas_installed(strict=False):
                n_col = metadata.shape[1]
            else:
                n_col = len(metadata[0])
            n_col = f" with {n_col} columns"
        else:
            n_col = ""
        if hasattr(self, "_metadata") and self._metadata is not None:
            action = "Removing" if metadata is None else "Replacing"
            action += " existing"
        else:
            action = "Not setting" if metadata is None else "Adding"
        logger.info(f"{action} metadata{n_col}")
        self._metadata = metadata
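
    # Setter sketch (assumes pandas is installed and ``epochs`` has exactly
    # three events): the setter validates the row count against
    # ``epochs.events`` and re-indexes the frame by ``epochs.selection``::
    #
    #     import pandas as pd
    #     md = pd.DataFrame({"col_a": [1, 2, 3], "col_b": list("xyz")})
    #     epochs.metadata = md   # logs "Adding metadata with 2 columns"
    #     epochs["col_a > 1"]    # metadata queries now work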


def _check_decim(info, decim, offset, check_filter=True):
    """Check decimation parameters."""
    if decim < 1 or decim != int(decim):
        raise ValueError("decim must be an integer > 0")
    decim = int(decim)
    new_sfreq = info["sfreq"] / float(decim)
    offset = int(offset)
    if not 0 <= offset < decim:
        raise ValueError(
            f"offset must be at least 0 and less than decim ({decim}), got {offset}"
        )
    if check_filter:
        lowpass = info["lowpass"]
        if decim > 1 and lowpass is None:
            warn(
                "The measurement information indicates data is not low-pass "
                f"filtered. The decim={decim} parameter will result in a "
                f"sampling frequency of {new_sfreq} Hz, which can cause "
                "aliasing artifacts."
            )
        elif decim > 1 and new_sfreq < 3 * lowpass:
            warn(
                "The measurement information indicates a low-pass frequency "
                f"of {lowpass} Hz. The decim={decim} parameter will result "
                f"in a sampling frequency of {new_sfreq} Hz, which can "
                "cause aliasing artifacts."
            )  # warn when lowpass exceeds ~2/3 of the new Nyquist frequency
    return decim, offset, new_sfreq
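

# Worked example of the aliasing check (hypothetical numbers): with
# info["sfreq"] == 1000.0 and info["lowpass"] == 40.0, decim=5 gives
# new_sfreq == 200.0 Hz and 200 >= 3 * 40 == 120, so no warning is emitted;
# decim=10 gives new_sfreq == 100.0 Hz < 120, which triggers the aliasing
# warning.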


class TimeMixin:
    """Class for time operations on any MNE object that has a time axis."""

    def time_as_index(self, times, use_rounding=False):
        """Convert time to indices.

        Parameters
        ----------
        times : list-like | float | int
            List of numbers or a number representing points in time.
        use_rounding : bool
            If True, use rounding (instead of truncation) when converting
            times to indices. This can help avoid non-unique indices.

        Returns
        -------
        index : ndarray
            Indices corresponding to the times supplied.
        """
        from ..source_estimate import _BaseSourceEstimate

        if isinstance(self, _BaseSourceEstimate):
            sfreq = 1.0 / self.tstep
        else:
            sfreq = self.info["sfreq"]
        index = (np.atleast_1d(times) - self.times[0]) * sfreq
        if use_rounding:
            index = np.round(index)
        return index.astype(int)
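
    # Illustrative arithmetic (hypothetical values): with times[0] == -0.2 s
    # and sfreq == 100.0 Hz, time_as_index(0.0) computes
    # (0.0 - (-0.2)) * 100.0 == 20.0 -> index 20. A time of 0.155 s maps to
    # 35.5: truncation (use_rounding=False) gives 35, while use_rounding=True
    # gives np.round(35.5) == 36.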

    def _handle_tmin_tmax(self, tmin, tmax):
        """Convert seconds to indices into the data.

        Parameters
        ----------
        tmin : int | float | None
            Start time of data to get in seconds.
        tmax : int | float | None
            End time of data to get in seconds.

        Returns
        -------
        start : int
            Integer index into data corresponding to tmin.
        stop : int
            Integer index into data corresponding to tmax.
        """
        _validate_type(
            tmin,
            types=("numeric", None),
            item_name="tmin",
            type_name="int, float, None",
        )
        _validate_type(
            tmax,
            types=("numeric", None),
            item_name="tmax",
            type_name="int, float, None",
        )

        # handle tmin/tmax as start and stop indices into data array
        n_times = self.times.size
        start = 0 if tmin is None else self.time_as_index(tmin)[0]
        stop = n_times if tmax is None else self.time_as_index(tmax)[0]

        # truncate start/stop to the closed interval [0, n_times]
        start = min(max(0, start), n_times)
        stop = min(max(0, stop), n_times)

        return start, stop

    @property
    def times(self):
        """Time vector in seconds."""
        return self._times_readonly

    def _set_times(self, times):
        """Set self._times_readonly (and make it read only)."""
        # naming used to indicate that it shouldn't be
        # changed directly, but rather via this method
        self._times_readonly = times.copy()
        self._times_readonly.flags["WRITEABLE"] = False


class ExtendedTimeMixin(TimeMixin):
    """Class for time operations on epochs/evoked-like MNE objects."""

    @property
    def tmin(self):
        """First time point."""
        return self.times[0]

    @property
    def tmax(self):
        """Last time point."""
        return self.times[-1]

    @verbose
    def crop(self, tmin=None, tmax=None, include_tmax=True, verbose=None):
        """Crop data to a given time interval.

        Parameters
        ----------
        tmin : float | None
            Start time of selection in seconds.
        tmax : float | None
            End time of selection in seconds.
        %(include_tmax)s
        %(verbose)s

        Returns
        -------
        inst : instance of Raw, Epochs, Evoked, AverageTFR, or SourceEstimate
            The cropped time-series object, modified in-place.

        Notes
        -----
        %(notes_tmax_included_by_default)s
        """
        t_vars = dict(tmin=tmin, tmax=tmax)
        for name, t_var in t_vars.items():
            _validate_type(
                t_var,
                types=("numeric", None),
                item_name=name,
            )

        if tmin is None:
            tmin = self.tmin
        elif tmin < self.tmin:
            warn(
                "tmin is not in time interval. tmin is set to "
                f"{type(self)}.tmin ({self.tmin:g} s)"
            )
            tmin = self.tmin

        if tmax is None:
            tmax = self.tmax
        elif tmax > self.tmax:
            warn(
                "tmax is not in time interval. tmax is set to "
                f"{type(self)}.tmax ({self.tmax:g} s)"
            )
            tmax = self.tmax
            include_tmax = True

        mask = _time_mask(
            self.times, tmin, tmax, sfreq=self.info["sfreq"], include_tmax=include_tmax
        )
        self._set_times(self.times[mask])
        self._raw_times = self._raw_times[mask]
        self._update_first_last()
        self._data = self._data[..., mask]

        return self
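
    # Usage sketch (hypothetical ``evoked`` with times from -0.2 to 0.5 s):
    # cropping keeps the end point by default::
    #
    #     evoked.crop(tmin=0.0, tmax=0.3)            # keeps 0.0 ... 0.3 s
    #     evoked.crop(tmax=0.3, include_tmax=False)  # half-open interval
    #
    # The operation is in-place and also returns the instance for chaining.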

    @verbose
    def decimate(self, decim, offset=0, *, verbose=None):
        """Decimate the time-series data.

        Parameters
        ----------
        %(decim)s
        %(offset_decim)s
        %(verbose)s

        Returns
        -------
        inst : MNE-object
            The decimated object.

        See Also
        --------
        mne.Epochs.resample
        mne.io.Raw.resample

        Notes
        -----
        %(decim_notes)s

        If ``decim`` is 1, this method does not copy the underlying data.

        .. versionadded:: 0.10.0

        References
        ----------
        .. footbibliography::
        """
        # if epochs have frequencies, they are not in time (EpochsTFR)
        # and so do not need to be checked whether they have been
        # appropriately filtered to avoid aliasing
        from ..epochs import BaseEpochs
        from ..evoked import Evoked
        from ..time_frequency import BaseTFR

        # This should be the list of classes that inherit this mixin
        _validate_type(self, (BaseEpochs, Evoked, BaseTFR), "inst")
        decim, offset, new_sfreq = _check_decim(
            self.info, decim, offset, check_filter=not hasattr(self, "freqs")
        )
        start_idx = int(round(-self._raw_times[0] * (self.info["sfreq"] * self._decim)))
        self._decim *= decim
        i_start = start_idx % self._decim + offset
        decim_slice = slice(i_start, None, self._decim)
        with self.info._unlock():
            self.info["sfreq"] = new_sfreq

        if self.preload:
            if decim != 1:
                self._data = self._data[..., decim_slice].copy()
                self._raw_times = self._raw_times[decim_slice].copy()
            else:
                self._data = np.ascontiguousarray(self._data)
            self._decim_slice = slice(None)
            self._decim = 1
        else:
            self._decim_slice = decim_slice
        self._set_times(self._raw_times[self._decim_slice])
        self._update_first_last()
        return self
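
    # Usage sketch (hypothetical ``epochs`` sampled at 600.0 Hz): decimation
    # keeps every ``decim``-th sample, so both calls below leave the data at
    # 150.0 Hz::
    #
    #     epochs.decimate(4)              # 600 Hz -> 150 Hz
    #     epochs.decimate(2).decimate(2)  # same result, applied in two steps
    #
    # Low-pass filter first (or check info["lowpass"]) to avoid aliasing.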

    def shift_time(self, tshift, relative=True):
        """Shift time scale in epoched or evoked data.

        Parameters
        ----------
        tshift : float
            The (absolute or relative) time shift in seconds. If ``relative``
            is True, positive tshift increases the time value associated with
            each sample, while negative tshift decreases it.
        relative : bool
            If True, increase or decrease time values by ``tshift`` seconds.
            Otherwise, shift the time values such that the time of the first
            sample equals ``tshift``.

        Returns
        -------
        epochs : MNE-object
            The modified instance.

        Notes
        -----
        This method allows you to shift the *time* values associated with each
        data sample by an arbitrary amount. It does *not* resample the signal
        or change the *data* values in any way.
        """
        _check_preload(self, "shift_time")
        start = tshift + (self.times[0] if relative else 0.0)
        new_times = start + np.arange(len(self.times)) / self.info["sfreq"]
        self._set_times(new_times)
        self._update_first_last()
        return self
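
    # Worked example (hypothetical values): with times starting at -0.1 s,
    # shift_time(0.05, relative=True) moves the first sample to
    # 0.05 + (-0.1) == -0.05 s, whereas shift_time(0.05, relative=False)
    # pins the first sample to exactly 0.05 s regardless of the previous
    # start time. The data array itself is untouched in both cases.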

    def _update_first_last(self):
        """Update self.first and self.last (sample indices)."""
        from ..dipole import DipoleFixed
        from ..evoked import Evoked

        if isinstance(self, Evoked | DipoleFixed):
            self.first = int(round(self.times[0] * self.info["sfreq"]))
            self.last = len(self.times) + self.first - 1


def _prepare_write_metadata(metadata):
    """Convert metadata to JSON for saving."""
    if metadata is not None:
        if not isinstance(metadata, list):  # pandas DataFrame
            metadata = metadata.reset_index().to_json(orient="records")
        else:  # list of dicts
            metadata = json.dumps(metadata)
        assert isinstance(metadata, str)
    return metadata


def _prepare_read_metadata(metadata):
    """Convert saved metadata back from JSON."""
    if metadata is not None:
        pd = _check_pandas_installed(strict=False)
        # use json.loads because this preserves ordering
        # (which is necessary for round-trip equivalence)
        metadata = json.loads(metadata, object_pairs_hook=OrderedDict)
        assert isinstance(metadata, list)
        if pd:
            metadata = pd.DataFrame.from_records(metadata)
            if "index" in metadata.columns:
                metadata.set_index("index", inplace=True)
            assert isinstance(metadata, pd.DataFrame)
    return metadata
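

# Round-trip sketch (assumes pandas is installed): a DataFrame survives
# write/read because its index is stored as a regular "index" column::
#
#     import pandas as pd
#     df = pd.DataFrame({"col_a": [1, 2]}, index=[5, 7])
#     s = _prepare_write_metadata(df)   # -> JSON records string
#     df2 = _prepare_read_metadata(s)   # -> DataFrame indexed by [5, 7] again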
1119
mne/utils/numerics.py
Normal file
File diff suppressed because it is too large
213
mne/utils/progressbar.py
Normal file
@@ -0,0 +1,213 @@
"""Some utility functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import logging
import os
import os.path as op
import tempfile
import time
from collections.abc import Iterable
from threading import Thread

import numpy as np

from ._logging import logger
from .check import _check_option
from .config import get_config


class ProgressBar:
    """Generate a command-line progressbar.

    Parameters
    ----------
    iterable : iterable | int | None
        The iterable to use. Can also be an int for backward compatibility
        (acts like ``max_value``).
    initial_value : int
        Initial value of process, useful when resuming process from a
        specific value; defaults to 0.
    mesg : str
        Message to include at end of progress bar.
    max_total_width : int | str
        Maximum total message width. Can use "auto" (default) to try to set
        a sane value based on the current terminal width.
    max_value : int | None
        The max value. If None, the length of ``iterable`` will be used.
    which_tqdm : str | None
        Which tqdm module to use. Can be "tqdm", "tqdm.notebook", or "off".
        Defaults to ``None``, which uses the value of the MNE_TQDM environment
        variable, or ``"tqdm.auto"`` if that is not set.
    **kwargs : dict
        Additional keyword arguments for tqdm.
    """

    def __init__(
        self,
        iterable=None,
        initial_value=0,
        mesg=None,
        max_total_width="auto",
        max_value=None,
        *,
        which_tqdm=None,
        **kwargs,
    ):
        # The following mimics this, but with a configurable module to use
        # from ..externals.tqdm import auto
        import tqdm

        if which_tqdm is None:
            which_tqdm = get_config("MNE_TQDM", "tqdm.auto")
        _check_option(
            "MNE_TQDM", which_tqdm[:5], ("tqdm", "tqdm.", "off"), extra="beginning"
        )
        logger.debug(f"Using ProgressBar with {which_tqdm}")
        if which_tqdm not in ("tqdm", "off"):
            try:
                __import__(which_tqdm)
            except Exception as exc:
                raise ValueError(
                    f"Unknown tqdm backend {repr(which_tqdm)}, got: {exc}"
                ) from None
            tqdm = getattr(tqdm, which_tqdm.split(".", 1)[1])
        tqdm = tqdm.tqdm
        defaults = dict(
            leave=True,
            mininterval=0.016,
            miniters=1,
            smoothing=0.05,
            bar_format="{percentage:3.0f}%|{bar}| {desc} : {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt:>11}{postfix}]",  # noqa: E501
        )
        for key, val in defaults.items():
            if key not in kwargs:
                kwargs.update({key: val})
        if isinstance(iterable, Iterable):
            self.iterable = iterable
            if max_value is None:
                self.max_value = len(iterable)
            else:
                self.max_value = max_value
        else:  # ignore max_value then
            self.max_value = int(iterable)
            self.iterable = None
        if max_total_width == "auto":
            max_total_width = None  # tqdm's auto
        with tempfile.NamedTemporaryFile("wb", prefix="tmp_mne_prog") as tf:
            self._mmap_fname = tf.name
        del tf  # should remove the file
        self._mmap = None
        disable = logger.level > logging.INFO or which_tqdm == "off"
        self._tqdm = tqdm(
            iterable=self.iterable,
            desc=mesg,
            total=self.max_value,
            initial=initial_value,
            ncols=max_total_width,
            disable=disable,
            **kwargs,
        )

    def update(self, cur_value):
        """Update progressbar with current value of process.

        Parameters
        ----------
        cur_value : number
            Current value of process. Should be <= max_value (but this is not
            enforced). The percent of the progressbar will be computed as
            ``(cur_value / max_value) * 100``.
        """
        self.update_with_increment_value(cur_value - self._tqdm.n)

    def update_with_increment_value(self, increment_value):
        """Update progressbar with an increment.

        Parameters
        ----------
        increment_value : int
            Value of the increment of process. The percent of the progressbar
            will be computed as
            ``((self.cur_value + increment_value) / max_value) * 100``.
        """
        try:
            self._tqdm.update(increment_value)
        except TypeError:  # can happen during GC on Windows
            pass

    def __iter__(self):
        """Iterate to auto-increment the pbar by 1."""
        yield from self._tqdm

    def subset(self, idx):
        """Make a joblib-friendly index subset updater.

        Parameters
        ----------
        idx : ndarray
            List of indices for this subset.

        Returns
        -------
        updater : instance of PBSubsetUpdater
            Class with a ``.update(ii)`` method.
        """
        return _PBSubsetUpdater(self, idx)

    def __enter__(self):  # noqa: D105
        # This should only be used with pb.subset and parallelization
        if op.isfile(self._mmap_fname):
            os.remove(self._mmap_fname)
        # prevent corner cases where self.max_value == 0
        self._mmap = np.memmap(
            self._mmap_fname, bool, "w+", shape=max(self.max_value, 1)
        )
        self.update(0)  # must be zero as we just created the memmap

        # We need to control how the pickled bars exit: remove print statements
        self._thread = _UpdateThread(self)
        self._thread.start()
        return self

    def __exit__(self, type_, value, traceback):  # noqa: D105
        # Restore exit behavior for our one from the main thread
        self.update(self._mmap.sum())
        self._tqdm.close()
        self._thread._mne_run = False
        self._thread.join()
        self._mmap = None
        if op.isfile(self._mmap_fname):
            try:
                os.remove(self._mmap_fname)
            except PermissionError:  # pragma: no cover (happens on Windows sometimes)
                pass

    def __del__(self):
        """Ensure output completes."""
        if getattr(self, "_tqdm", None) is not None:
            self._tqdm.close()
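
    # Usage sketch (illustrative names, not part of the original file): wrap
    # any sized iterable to get a bar, or drive it manually with update()::
    #
    #     for fname in ProgressBar(fnames, mesg="Reading"):
    #         _read(fname)     # auto-increments once per item
    #
    #     pb = ProgressBar(100, mesg="Working")
    #     pb.update(50)        # jump straight to 50%
    #
    # The context-manager form (__enter__/__exit__) is meant for parallel
    # workers, which mark finished items in the shared memmap via
    # pb.subset(idx).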


class _UpdateThread(Thread):
    def __init__(self, pb):
        super().__init__(daemon=True)
        self._mne_run = True
        self._mne_pb = pb

    def run(self):
        while self._mne_run:
            self._mne_pb.update(self._mne_pb._mmap.sum())
            time.sleep(1.0 / 30.0)  # 30 Hz refresh is plenty


class _PBSubsetUpdater:
    def __init__(self, pb, idx):
        self.mmap = pb._mmap
        self.idx = idx

    def update(self, ii):
        self.mmap[self.idx[ii - 1]] = True
104
mne/utils/spectrum.py
Normal file
@@ -0,0 +1,104 @@
"""Utility functions for spectral and spectrotemporal analysis."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from inspect import currentframe, getargvalues, signature

from ..utils import warn


def _get_instance_type_string(inst):
    """Get string representation of the originating instance type."""
    from numpy import ndarray

    from ..epochs import BaseEpochs
    from ..evoked import Evoked, EvokedArray
    from ..io import BaseRaw

    parent_classes = inst._inst_type.__bases__
    if BaseRaw in parent_classes:
        inst_type_str = "Raw"
    elif BaseEpochs in parent_classes:
        inst_type_str = "Epochs"
    elif inst._inst_type in (Evoked, EvokedArray):
        inst_type_str = "Evoked"
    elif inst._inst_type == ndarray:
        inst_type_str = "Array"
    else:
        raise RuntimeError(
            f"Unknown instance type {inst._inst_type} in {type(inst).__name__}"
        )
    return inst_type_str


def _pop_with_fallback(mapping, key, fallback_fun):
    """Pop from a dict, falling back to a function parameter's default value."""
    fallback = signature(fallback_fun).parameters[key].default
    return mapping.pop(key, fallback)
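

# Illustrative sketch (hypothetical function ``f``): the helper pops a key
# from a kwargs dict, substituting the function's own signature default when
# the key is absent::
#
#     def f(x, color="black"):
#         ...
#
#     _pop_with_fallback({"color": "red"}, "color", f)  # -> "red"
#     _pop_with_fallback({}, "color", f)                # -> "black"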


def _update_old_psd_kwargs(kwargs):
    """Modify passed-in kwargs to match the new API.

    NOTE: using plot_raw_psd as the fallback (even for epochs) is fine because
    their kwargs are the same (and will stay the same: both are @legacy funcs).
    """
    from ..viz import plot_raw_psd as fallback_fun

    may_change = ("axes", "alpha", "ci_alpha", "amplitude", "ci")
    for kwarg in may_change:
        if kwarg in kwargs:
            warn(
                "The legacy plot_psd() method got an unexpected keyword argument "
                f"'{kwarg}', which is a parameter of Spectrum.plot(). Try rewriting "
                f"as object.compute_psd(...).plot(..., {kwarg}=<whatever>)."
            )
    kwargs.setdefault("axes", _pop_with_fallback(kwargs, "ax", fallback_fun))
    kwargs.setdefault("alpha", _pop_with_fallback(kwargs, "line_alpha", fallback_fun))
    kwargs.setdefault(
        "ci_alpha", _pop_with_fallback(kwargs, "area_alpha", fallback_fun)
    )
    est = _pop_with_fallback(kwargs, "estimate", fallback_fun)
    kwargs.setdefault("amplitude", est == "amplitude")
    area_mode = _pop_with_fallback(kwargs, "area_mode", fallback_fun)
    kwargs.setdefault("ci", "sd" if area_mode == "std" else area_mode)


def _split_psd_kwargs(*, plot_fun=None, kwargs=None):
    from ..io import BaseRaw
    from ..time_frequency import Spectrum

    # if no kwargs supplied, get them from the calling function
    if kwargs is None:
        frame = currentframe().f_back
        arginfo = getargvalues(frame)
        kwargs = {k: v for k, v in arginfo.locals.items() if k in arginfo.args}
        if arginfo.keywords is not None:  # add in **method_kw
            kwargs.update(arginfo.locals[arginfo.keywords])

    # for compatibility with the `plot_raw_psd`, `plot_epochs_psd` and
    # `plot_epochs_psd_topomap` functions (not just the instance methods/mixin)
    if "raw" in kwargs:
        kwargs["self"] = kwargs.pop("raw")
    elif "epochs" in kwargs:
        kwargs["self"] = kwargs.pop("epochs")

    # `reject_by_annotation` not needed for Epochs or Evoked
    if not isinstance(kwargs.pop("self", None), BaseRaw):
        kwargs.pop("reject_by_annotation", None)

    # handle API changes from .plot_psd(...) to .compute_psd(...).plot(...)
    if plot_fun is Spectrum.plot:
        _update_old_psd_kwargs(kwargs)

    # split off the plotting kwargs
    plot_kwargs = {
        k: v
        for k, v in kwargs.items()
        if k in signature(plot_fun).parameters and k != "picks"
    }
    for k in plot_kwargs:
        del kwargs[k]
    return kwargs, plot_kwargs