initial commit

2025-08-19 09:13:22 -07:00
parent 28464811d6
commit 0977a3e14d
820 changed files with 1003358 additions and 2 deletions

7
mne/utils/__init__.py Normal file

@@ -0,0 +1,7 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import lazy_loader as lazy
(__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__)
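A minimal sketch of the lazy-loading pattern used here (hypothetical package name, not part of the committed file): ``lazy.attach_stub`` reads the adjacent ``.pyi`` stub and defers the real submodule imports until an attribute is first accessed.

# mypkg/__init__.py, with mypkg/__init__.pyi declaring e.g. ``from ._core import helper``
import lazy_loader as lazy

(__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__)
# ``mypkg._core`` is imported only when ``mypkg.helper`` is first accessed.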

384
mne/utils/__init__.pyi Normal file

@@ -0,0 +1,384 @@
__all__ = [
"ArgvSetter",
"Bunch",
"BunchConst",
"BunchConstNamed",
"ClosingStringIO",
"ExtendedTimeMixin",
"GetEpochsMixin",
"ProgressBar",
"SizeMixin",
"TimeMixin",
"_DefaultEventParser",
"_PCA",
"_ReuseCycle",
"_TempDir",
"_apply_scaling_array",
"_apply_scaling_cov",
"_arange_div",
"_array_equal_nan",
"_array_repr",
"_assert_no_instances",
"_auto_weakref",
"_build_data_frame",
"_check_all_same_channel_names",
"_check_ch_locs",
"_check_channels_spatial_filter",
"_check_combine",
"_check_compensation_grade",
"_check_decim",
"_check_depth",
"_check_dict_keys",
"_check_dt",
"_check_edfio_installed",
"_check_eeglabio_installed",
"_check_event_id",
"_check_fname",
"_check_freesurfer_home",
"_check_head_radius",
"_check_if_nan",
"_check_info_inv",
"_check_integer_or_list",
"_check_method_kwargs",
"_check_on_missing",
"_check_one_ch_type",
"_check_option",
"_check_pandas_index_arguments",
"_check_pandas_installed",
"_check_preload",
"_check_pybv_installed",
"_check_pymatreader_installed",
"_check_qt_version",
"_check_range",
"_check_rank",
"_check_sphere",
"_check_src_normal",
"_check_stc_units",
"_check_subject",
"_check_time_format",
"_clean_names",
"_click_ch_name",
"_compute_row_norms",
"_convert_times",
"_custom_lru_cache",
"_doc_special_members",
"_date_to_julian",
"_dt_to_stamp",
"_empty_hash",
"_ensure_events",
"_ensure_int",
"_explain_exception",
"_file_like",
"_freq_mask",
"_gen_events",
"_get_argvalues",
"_get_blas_funcs",
"_get_call_line",
"_get_extra_data_path",
"_get_inst_data",
"_get_numpy_libs",
"_get_root_dir",
"_get_stim_channel",
"_hashable_ndarray",
"_import_h5io_funcs",
"_import_h5py",
"_import_nibabel",
"_import_pymatreader_funcs",
"_is_numeric",
"_julian_to_date",
"_mask_to_onsets_offsets",
"_on_missing",
"_parse_verbose",
"_path_like",
"_pl",
"_prepare_read_metadata",
"_prepare_write_metadata",
"_raw_annot",
"_record_warnings",
"_reg_pinv",
"_reject_data_segments",
"_repeated_svd",
"_replace_md5",
"_require_version",
"_resource_path",
"_safe_input",
"_scale_dataframe_data",
"_scaled_array",
"_set_pandas_dtype",
"_soft_import",
"_stamp_to_dt",
"_suggest",
"_svd_lwork",
"_sym_mat_pow",
"_time_mask",
"_to_rgb",
"_undo_scaling_array",
"_undo_scaling_cov",
"_url_to_local_path",
"_validate_type",
"_verbose_safe_false",
"array_split_idx",
"assert_and_remove_boundary_annot",
"assert_dig_allclose",
"assert_meg_snr",
"assert_object_equal",
"assert_snr",
"assert_stcs_equal",
"buggy_mkl_svd",
"catch_logging",
"check_fname",
"check_random_state",
"check_version",
"compute_corr",
"copy_doc",
"copy_function_doc_to_method_doc",
"create_slices",
"deprecated",
"deprecated_alias",
"eigh",
"fill_doc",
"filter_out_warnings",
"get_config",
"get_config_path",
"get_subjects_dir",
"grand_average",
"has_freesurfer",
"has_mne_c",
"hashfunc",
"int_like",
"legacy",
"linkcode_resolve",
"logger",
"object_diff",
"object_hash",
"object_size",
"open_docs",
"path_like",
"pformat",
"pinv",
"pinvh",
"random_permutation",
"repr_html",
"requires_freesurfer",
"requires_good_network",
"requires_mne",
"requires_mne_mark",
"requires_openmeeg_mark",
"run_command_if_main",
"run_subprocess",
"running_subprocess",
"set_cache_dir",
"set_config",
"set_log_file",
"set_log_level",
"set_memmap_min_size",
"sizeof_fmt",
"split_list",
"sqrtm_sym",
"sum_squared",
"sys_info",
"use_log_level",
"verbose",
"warn",
"wrapped_stdout",
]
from ._bunch import Bunch, BunchConst, BunchConstNamed
from ._logging import (
ClosingStringIO,
_get_call_line,
_parse_verbose,
_record_warnings,
_verbose_safe_false,
catch_logging,
filter_out_warnings,
logger,
set_log_file,
set_log_level,
use_log_level,
verbose,
warn,
wrapped_stdout,
)
from ._testing import (
ArgvSetter,
_click_ch_name,
_raw_annot,
_TempDir,
assert_and_remove_boundary_annot,
assert_dig_allclose,
assert_meg_snr,
assert_object_equal,
assert_snr,
assert_stcs_equal,
buggy_mkl_svd,
has_freesurfer,
has_mne_c,
requires_freesurfer,
requires_good_network,
requires_mne,
requires_mne_mark,
requires_openmeeg_mark,
run_command_if_main,
)
from .check import (
_check_all_same_channel_names,
_check_ch_locs,
_check_channels_spatial_filter,
_check_combine,
_check_compensation_grade,
_check_depth,
_check_dict_keys,
_check_edfio_installed,
_check_eeglabio_installed,
_check_event_id,
_check_fname,
_check_freesurfer_home,
_check_head_radius,
_check_if_nan,
_check_info_inv,
_check_integer_or_list,
_check_method_kwargs,
_check_on_missing,
_check_one_ch_type,
_check_option,
_check_pandas_index_arguments,
_check_pandas_installed,
_check_preload,
_check_pybv_installed,
_check_pymatreader_installed,
_check_qt_version,
_check_range,
_check_rank,
_check_sphere,
_check_src_normal,
_check_stc_units,
_check_subject,
_check_time_format,
_ensure_events,
_ensure_int,
_import_h5io_funcs,
_import_h5py,
_import_nibabel,
_import_pymatreader_funcs,
_is_numeric,
_on_missing,
_path_like,
_require_version,
_safe_input,
_soft_import,
_suggest,
_to_rgb,
_validate_type,
check_fname,
check_random_state,
check_version,
int_like,
path_like,
)
from .config import (
_get_extra_data_path,
_get_numpy_libs,
_get_root_dir,
_get_stim_channel,
get_config,
get_config_path,
get_subjects_dir,
set_cache_dir,
set_config,
set_memmap_min_size,
sys_info,
)
from .dataframe import (
_build_data_frame,
_convert_times,
_scale_dataframe_data,
_set_pandas_dtype,
)
from .docs import (
_doc_special_members,
copy_doc,
copy_function_doc_to_method_doc,
deprecated,
deprecated_alias,
fill_doc,
legacy,
linkcode_resolve,
open_docs,
)
from .fetching import _url_to_local_path
from .linalg import (
_get_blas_funcs,
_repeated_svd,
_svd_lwork,
_sym_mat_pow,
eigh,
pinv,
pinvh,
sqrtm_sym,
)
from .misc import (
_assert_no_instances,
_auto_weakref,
_clean_names,
_DefaultEventParser,
_empty_hash,
_explain_exception,
_file_like,
_get_argvalues,
_pl,
_resource_path,
pformat,
repr_html,
run_subprocess,
running_subprocess,
sizeof_fmt,
)
from .mixin import (
ExtendedTimeMixin,
GetEpochsMixin,
SizeMixin,
TimeMixin,
_check_decim,
_prepare_read_metadata,
_prepare_write_metadata,
)
from .numerics import (
_PCA,
_apply_scaling_array,
_apply_scaling_cov,
_arange_div,
_array_equal_nan,
_array_repr,
_check_dt,
_compute_row_norms,
_custom_lru_cache,
_date_to_julian,
_dt_to_stamp,
_freq_mask,
_gen_events,
_get_inst_data,
_hashable_ndarray,
_julian_to_date,
_mask_to_onsets_offsets,
_reg_pinv,
_reject_data_segments,
_replace_md5,
_ReuseCycle,
_scaled_array,
_stamp_to_dt,
_time_mask,
_undo_scaling_array,
_undo_scaling_cov,
array_split_idx,
compute_corr,
create_slices,
grand_average,
hashfunc,
object_diff,
object_hash,
object_size,
random_permutation,
split_list,
sum_squared,
)
from .progressbar import ProgressBar

104
mne/utils/_bunch.py Normal file

@@ -0,0 +1,104 @@
"""Bunch-related classes."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from copy import deepcopy
###############################################################################
# Create a Bunch class that acts like a struct (mybunch.key = val)
class Bunch(dict):
"""Dictionary-like object that exposes its keys as attributes."""
def __init__(self, **kwargs):
dict.__init__(self, kwargs)
self.__dict__ = self
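A minimal usage sketch (hypothetical values, not part of the committed file): keys and attributes refer to the same storage.

params = Bunch(sfreq=1000.0, n_channels=64)
params.lowpass = 40.0                     # attribute assignment stores a dict key
assert params["lowpass"] == params.lowpass == 40.0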
###############################################################################
# A protected version that prevents overwriting
class BunchConst(Bunch):
"""Class to prevent us from re-defining constants (DRY)."""
def __setitem__(self, key, val): # noqa: D105
if key != "__dict__" and key in self:
raise AttributeError(f"Attribute {repr(key)} already set")
super().__setitem__(key, val)
###############################################################################
# A version that tweaks the __repr__ of its values based on keys
class BunchConstNamed(BunchConst):
"""Class to provide nice __repr__ for our integer constants.
Only supports string keys and int or float values.
"""
def __setattr__(self, attr, val): # noqa: D105
assert isinstance(attr, str)
if isinstance(val, int):
val = NamedInt(attr, val)
elif isinstance(val, float):
val = NamedFloat(attr, val)
else:
assert isinstance(val, BunchConstNamed), type(val)
super().__setattr__(attr, val)
class _Named:
"""Provide shared methods for giving named-representation subclasses."""
def __new__(cls, name, val): # noqa: D102,D105
out = _named_subclass(cls).__new__(cls, val)
out._name = name
return out
def __str__(self): # noqa: D105
return f"{self.__class__.mro()[-2](self)} ({self._name})"
__repr__ = __str__
# see https://stackoverflow.com/a/15774013/2175965
def __copy__(self): # noqa: D105
cls = self.__class__
result = cls.__new__(cls)
result.__dict__.update(self.__dict__)
return result
def __deepcopy__(self, memo): # noqa: D105
cls = self.__class__
result = cls.__new__(cls, self._name, self)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, deepcopy(v, memo))
return result
def __getnewargs__(self): # noqa: D105
return self._name, _named_subclass(self)(self)
def _named_subclass(klass):
if not isinstance(klass, type):
klass = klass.__class__
subklass = klass.mro()[-2]
assert subklass in (int, float)
return subklass
class NamedInt(_Named, int):
"""Int with a name in __repr__."""
pass # noqa
class NamedFloat(_Named, float):
"""Float with a name in __repr__."""
pass # noqa
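A brief sketch of how these classes combine (hypothetical constant name): ``BunchConstNamed`` wraps assigned ints in ``NamedInt``, so the value still behaves like an int but reports its name.

const = BunchConstNamed()
const.FIFFV_POINT_EEG = 3
print(repr(const.FIFFV_POINT_EEG))   # 3 (FIFFV_POINT_EEG)
print(const.FIFFV_POINT_EEG + 1)     # arithmetic still yields a plain int: 4
# const["FIFFV_POINT_EEG"] = 4 would raise AttributeError (BunchConst blocks redefinition)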

527
mne/utils/_logging.py Normal file

@@ -0,0 +1,527 @@
"""Some utility functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import contextlib
import importlib
import inspect
import logging
import os.path as op
import re
import sys
import warnings
from collections.abc import Callable
from io import StringIO
from typing import Any, TypeVar
from decorator import FunctionMaker
from .docs import fill_doc
logger = logging.getLogger("mne") # one selection here used across mne-python
logger.propagate = False # don't propagate (in case of multiple imports)
# class to provide frame information (should be low overhead, just on logger
# calls)
class _FrameFilter(logging.Filter):
def __init__(self):
self.add_frames = 0
def filter(self, record):
record.frame_info = "Unknown"
if self.add_frames:
# 5 is the offset necessary to get out of here and the logging
# module, reversal is to put the oldest at the top
frame_info = _frame_info(5 + self.add_frames)[5:][::-1]
if len(frame_info):
frame_info[-1] = (frame_info[-1] + " :").ljust(30)
if len(frame_info) > 1:
frame_info[0] = "┌" + frame_info[0]
frame_info[-1] = "└" + frame_info[-1]
for ii, info in enumerate(frame_info[1:-1], 1):
frame_info[ii] = "├" + info
record.frame_info = "\n".join(frame_info)
return True
_filter = _FrameFilter()
logger.addFilter(_filter)
# Provide help for static type checkers:
# https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
_FuncT = TypeVar("_FuncT", bound=Callable[..., Any])
def verbose(function: _FuncT) -> _FuncT:
"""Verbose decorator to allow functions to override log-level.
Parameters
----------
function : callable
Function to be decorated by setting the verbosity level.
Returns
-------
dec : callable
The decorated function.
See Also
--------
set_log_level
set_config
Notes
-----
This decorator is used to set the verbose level during a function or method
call, such as :func:`mne.compute_covariance`. The `verbose` keyword
argument can be 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL', True (an
alias for 'INFO'), or False (an alias for 'WARNING'). To set the global
verbosity level for all functions, use :func:`mne.set_log_level`.
This function also serves as a docstring filler.
Examples
--------
You can use the ``verbose`` argument to set the verbose level on the fly::
>>> import mne
>>> cov = mne.compute_raw_covariance(raw, verbose='WARNING') # doctest: +SKIP
>>> cov = mne.compute_raw_covariance(raw, verbose='INFO') # doctest: +SKIP
Using up to 49 segments
Number of samples used : 5880
[done]
""" # noqa: E501
# See https://decorator.readthedocs.io/en/latest/tests.documentation.html
# #dealing-with-third-party-decorators
try:
fill_doc(function)
except TypeError: # nothing to add
pass
# Anything using verbose should have `verbose=None` in the signature.
# This code path will raise an error if this is not the case.
body = """\
def %(name)s(%(signature)s):\n
try:
do_level_change = verbose is not None
except (NameError, UnboundLocalError):
raise RuntimeError('Function/method %%s does not accept verbose '
'parameter' %% (_function_,)) from None
if do_level_change:
with _use_log_level_(verbose):
return _function_(%(shortsignature)s)
else:
return _function_(%(shortsignature)s)"""
evaldict = dict(_use_log_level_=use_log_level, _function_=function)
fm = FunctionMaker(function)
attrs = dict(
__wrapped__=function,
__qualname__=function.__qualname__,
__globals__=function.__globals__,
)
return fm.make(body, evaldict, addsource=True, **attrs)
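Roughly, the generated wrapper is equivalent to the following hand-written sketch (function and argument names hypothetical):

def _compute(data, verbose=None):      # the undecorated function
    logger.info("computing")
    return data

def compute(data, verbose=None):       # what ``@verbose`` generates
    if verbose is not None:
        with use_log_level(verbose):   # temporarily override the log level
            return _compute(data, verbose=verbose)
    return _compute(data, verbose=verbose)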
@fill_doc
class use_log_level:
"""Context manager for logging level.
Parameters
----------
%(verbose)s
%(add_frames)s
See Also
--------
mne.verbose
Notes
-----
See the :ref:`logging documentation <tut-logging>` for details.
Examples
--------
>>> from mne import use_log_level
>>> from mne.utils import logger
>>> with use_log_level(False):
... # Most MNE logger messages are "info" level, False makes them not
... # print:
... logger.info('This message will not be printed')
>>> with use_log_level(True):
... # Using verbose=True in functions, methods, or this context manager
... # will ensure they are printed
... logger.info('This message will be printed!')
This message will be printed!
"""
def __init__(self, verbose=None, *, add_frames=None):
self._level = verbose
self._add_frames = add_frames
self._old_frames = _filter.add_frames
def __enter__(self): # noqa: D105
self._old_level = set_log_level(
self._level, return_old_level=True, add_frames=self._add_frames
)
def __exit__(self, *args): # noqa: D105
add_frames = self._old_frames if self._add_frames is not None else None
set_log_level(self._old_level, add_frames=add_frames)
_LOGGING_TYPES = dict(
DEBUG=logging.DEBUG,
INFO=logging.INFO,
WARNING=logging.WARNING,
ERROR=logging.ERROR,
CRITICAL=logging.CRITICAL,
)
@fill_doc
def set_log_level(verbose=None, return_old_level=False, add_frames=None):
"""Set the logging level.
Parameters
----------
verbose : bool, str, int, or None
The verbosity of messages to print. If a str, it can be either DEBUG,
INFO, WARNING, ERROR, or CRITICAL. Note that these are for
convenience and are equivalent to passing in logging.DEBUG, etc.
For bool, True is the same as 'INFO', False is the same as 'WARNING'.
If None, the environment variable MNE_LOGGING_LEVEL is read, and if
it doesn't exist, defaults to INFO.
return_old_level : bool
If True, return the old verbosity level.
%(add_frames)s
Returns
-------
old_level : int
The old level. Only returned if ``return_old_level`` is True.
"""
old_verbose = logger.level
verbose = _parse_verbose(verbose)
if verbose != old_verbose:
logger.setLevel(verbose)
if add_frames is not None:
_filter.add_frames = int(add_frames)
fmt = "%(frame_info)s " if add_frames else ""
fmt += "%(message)s"
fmt = logging.Formatter(fmt)
for handler in logger.handlers:
handler.setFormatter(fmt)
return old_verbose if return_old_level else None
def _parse_verbose(verbose):
from .check import _check_option, _validate_type
from .config import get_config
_validate_type(verbose, (bool, str, int, None), "verbose")
if verbose is None:
verbose = get_config("MNE_LOGGING_LEVEL", "INFO")
elif isinstance(verbose, bool):
if verbose is True:
verbose = "INFO"
else:
verbose = "WARNING"
if isinstance(verbose, str):
verbose = verbose.upper()
_check_option("verbose", verbose, _LOGGING_TYPES, "(when a string)")
verbose = _LOGGING_TYPES[verbose]
return verbose
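A quick sketch of the resulting mapping (assuming the helper is importable):

assert _parse_verbose(True) == logging.INFO       # True is an alias for INFO
assert _parse_verbose(False) == logging.WARNING   # False is an alias for WARNING
assert _parse_verbose("debug") == logging.DEBUG   # strings are case-insensitive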
def set_log_file(fname=None, output_format="%(message)s", overwrite=None):
"""Set the log to print to a file.
Parameters
----------
fname : path-like | None
Filename of the log to print to. If None, stdout is used.
To suppress log outputs, use set_log_level('WARNING').
output_format : str
Format of the output messages. See the following for examples:
https://docs.python.org/dev/howto/logging.html
e.g., "%(asctime)s - %(levelname)s - %(message)s".
overwrite : bool | None
Overwrite the log file (if it exists). Otherwise, statements
will be appended to the log (default). None is the same as False,
but additionally raises a warning to notify the user that log
entries will be appended.
"""
_remove_close_handlers(logger)
if fname is not None:
if op.isfile(fname) and overwrite is None:
# Don't use warn() here because we just want to
# emit a warnings.warn here (not logger.warn)
warnings.warn(
"Log entries will be appended to the file. Use "
"overwrite=False to avoid this message in the "
"future.",
RuntimeWarning,
stacklevel=2,
)
overwrite = False
mode = "w" if overwrite else "a"
lh = logging.FileHandler(fname, mode=mode)
else:
"""we should just be able to do:
lh = logging.StreamHandler(sys.stdout)
but because doctests uses some magic on stdout, we have to do this:
"""
lh = logging.StreamHandler(WrapStdOut())
lh.setFormatter(logging.Formatter(output_format))
# actually add the stream handler
logger.addHandler(lh)
def _remove_close_handlers(logger):
for h in list(logger.handlers):
# only remove our handlers (get along nicely with nose)
if isinstance(h, logging.FileHandler | logging.StreamHandler):
if isinstance(h, logging.FileHandler):
h.close()
logger.removeHandler(h)
class ClosingStringIO(StringIO):
"""StringIO that closes after getvalue()."""
def getvalue(self, close=True):
"""Get the value."""
out = super().getvalue()
if close:
self.close()
return out
class catch_logging:
"""Store logging.
This will remove all other logging handlers, and return the handler to
stdout when complete.
"""
def __init__(self, verbose=None):
self.verbose = verbose
def __enter__(self): # noqa: D105
if self.verbose is not None:
self._ctx = use_log_level(self.verbose)
else:
self._ctx = contextlib.nullcontext()
self._data = ClosingStringIO()
self._lh = logging.StreamHandler(self._data)
self._lh.setFormatter(logging.Formatter("%(message)s"))
self._lh._mne_file_like = True # monkey patch for warn() use
_remove_close_handlers(logger)
logger.addHandler(self._lh)
self._ctx.__enter__()
return self._data
def __exit__(self, *args): # noqa: D105
self._ctx.__exit__(*args)
logger.removeHandler(self._lh)
set_log_file(None)
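A minimal usage sketch (hypothetical message), as typically done in tests:

with catch_logging(verbose=True) as log:
    logger.info("filtering done")
assert "filtering done" in log.getvalue()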
@contextlib.contextmanager
def _record_warnings():
# this is a helper that mostly acts like pytest.warns(None) did before
# pytest 7
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
yield w
class WrapStdOut:
"""Dynamically wrap to sys.stdout.
This makes packages that monkey-patch sys.stdout (e.g. doctest,
sphinx-gallery) work properly.
"""
def __getattr__(self, name): # noqa: D105
# Even more ridiculous than this class, this must be sys.stdout (not
# just stdout) in order for this to work (tested on OSX and Linux)
if hasattr(sys.stdout, name):
return getattr(sys.stdout, name)
else:
raise AttributeError(f"'file' object has no attribute '{name}'")
_verbose_dec_re = re.compile("^<decorator-gen-[0-9]+>$")
def warn(message, category=RuntimeWarning, module="mne", ignore_namespaces=("mne",)):
"""Emit a warning with trace outside the mne namespace.
This function takes arguments like warnings.warn, and sends messages
using both ``warnings.warn`` and ``logger.warn``. Warnings can be
generated deep within nested function calls. In order to provide a
more helpful warning, this function traverses the stack until it
reaches a frame outside the ``mne`` namespace that caused the error.
Parameters
----------
message : str
Warning message.
category : instance of Warning
The warning class. Defaults to ``RuntimeWarning``.
module : str
The name of the module emitting the warning.
ignore_namespaces : list of str
Namespaces to ignore when traversing the stack.
.. versionadded:: 0.24
"""
root_dirs = [importlib.import_module(ns) for ns in ignore_namespaces]
root_dirs = [op.dirname(ns.__file__) for ns in root_dirs]
frame = None
if logger.level <= logging.WARNING:
frame = inspect.currentframe()
while frame:
fname = frame.f_code.co_filename
lineno = frame.f_lineno
# in verbose dec
if not _verbose_dec_re.search(fname):
# treat tests as scripts
# and don't capture unittest/case.py (assert_raises)
if (
not (
any(fname.startswith(rd) for rd in root_dirs)
or ("unittest" in fname and "case" in fname)
)
or op.basename(op.dirname(fname)) == "tests"
):
break
frame = frame.f_back
del frame
# We need to use this instead of warn(message, category, stacklevel)
# because we move out of the MNE stack, so warnings won't properly
# recognize the module name (and our warnings.simplefilter will fail)
warnings.warn_explicit(
message,
category,
fname,
lineno,
module,
globals().get("__warningregistry__", {}),
)
# To avoid a duplicate warning print, we only emit the logger.warning if
# one of the handlers is a FileHandler. See gh-5592
# But it's also nice to be able to do:
# with mne.utils.use_log_level('warning', add_frames=3):
# so also check our add_frames attribute.
if (
any(
isinstance(h, logging.FileHandler) or getattr(h, "_mne_file_like", False)
for h in logger.handlers
)
or _filter.add_frames
):
logger.warning(message)
def _get_call_line():
"""Get the call line from within a function."""
frame = inspect.currentframe().f_back.f_back
if _verbose_dec_re.search(frame.f_code.co_filename):
frame = frame.f_back
context = inspect.getframeinfo(frame).code_context
context = "unknown" if context is None else context[0].strip()
return context
def filter_out_warnings(warn_record, category=None, match=None):
r"""Remove particular records from ``warn_record``.
This helper takes a list of :class:`warnings.WarningMessage` objects
and removes those matching category and/or text.
Parameters
----------
category : WarningMessage type | None
Class of the message to filter out.
match : str | None
Text or regex that matches the error message to filter out.
"""
regexp = re.compile(".*" if match is None else match)
is_category = [
w.category == category if category is not None else True
for w in warn_record._list
]
is_match = [regexp.match(w.message.args[0]) is not None for w in warn_record._list]
ind = [ind for ind, (c, m) in enumerate(zip(is_category, is_match)) if c and m]
for i in reversed(ind):
warn_record._list.pop(i)
@contextlib.contextmanager
def wrapped_stdout(indent="", cull_newlines=False):
"""Wrap stdout writes to logger.info, with an optional indent prefix.
Parameters
----------
indent : str
The indentation to add.
cull_newlines : bool
If True, cull any new/blank lines at the end.
"""
orig_stdout = sys.stdout
my_out = ClosingStringIO()
sys.stdout = my_out
try:
yield
finally:
sys.stdout = orig_stdout
pending_newlines = 0
for line in my_out.getvalue().split("\n"):
if not line.strip() and cull_newlines:
pending_newlines += 1
continue
for _ in range(pending_newlines):
logger.info("\n")
logger.info(indent + line)
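A minimal usage sketch (hypothetical output): ``print`` calls made inside the block are replayed through ``logger.info`` with the given indent.

with wrapped_stdout(indent="    ", cull_newlines=True):
    print("progress 50%")    # ends up in the MNE log, indented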
def _frame_info(n):
frame = inspect.currentframe()
try:
frame = frame.f_back
infos = list()
for _ in range(n):
try:
name = frame.f_globals["__name__"]
except KeyError: # in our verbose dec
pass
else:
infos.append(f'{name.lstrip("mne.")}:{frame.f_lineno}')
frame = frame.f_back
if frame is None:
break
return infos
except Exception:
return ["unknown"]
finally:
del frame
def _verbose_safe_false(*, level="warning"):
lev = _LOGGING_TYPES[level.upper()]
return lev if logger.level <= lev else None

359
mne/utils/_testing.py Normal file

@@ -0,0 +1,359 @@
"""Testing functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import inspect
import os
import sys
import tempfile
import traceback
from functools import wraps
from shutil import rmtree
from unittest import SkipTest
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal
from scipy import linalg
from ._logging import ClosingStringIO, warn
from .check import check_version
from .misc import run_subprocess
from .numerics import object_diff
def _explain_exception(start=-1, stop=None, prefix="> "):
"""Explain an exception."""
# start=-1 means "only the most recent caller"
etype, value, tb = sys.exc_info()
string = traceback.format_list(traceback.extract_tb(tb)[start:stop])
string = "".join(string).split("\n") + traceback.format_exception_only(etype, value)
string = ":\n" + prefix + ("\n" + prefix).join(string)
return string
class _TempDir(str):
"""Create and auto-destroy temp dir.
This is designed to be used with testing modules. Instances should be
defined inside test functions. Instances defined at module level cannot
guarantee proper destruction of the temporary directory.
When used at module level, the current use of the __del__() method for
cleanup can fail because the rmtree function may be cleaned up before this
object (an alternative could be using the atexit module instead).
"""
def __new__(self): # noqa: D105
new = str.__new__(self, tempfile.mkdtemp(prefix="tmp_mne_tempdir_"))
return new
def __init__(self):
self._path = self.__str__()
def __del__(self): # noqa: D105
rmtree(self._path, ignore_errors=True)
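A minimal usage sketch inside a test function (hypothetical file name):

tempdir = _TempDir()                       # str subclass pointing at a fresh temp dir
fname = os.path.join(tempdir, "raw.fif")   # directory is removed when the object is garbage collected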
def requires_mne(func):
"""Decorate a function as requiring MNE."""
return requires_mne_mark()(func)
def requires_mne_mark():
"""Mark pytest tests that require MNE-C."""
import pytest
return pytest.mark.skipif(not has_mne_c(), reason="Requires MNE-C")
def requires_openmeeg_mark():
"""Mark pytest tests that require OpenMEEG."""
import pytest
return pytest.mark.skipif(
not check_version("openmeeg", "2.5.6"), reason="Requires OpenMEEG >= 2.5.6"
)
def requires_freesurfer(arg):
"""Require Freesurfer."""
import pytest
reason = "Requires Freesurfer"
if isinstance(arg, str):
# Calling as @requires_freesurfer('progname'): return decorator
# after checking for progname existence
reason += f" command: {arg}"
try:
run_subprocess([arg, "--version"])
except Exception:
skip = True
else:
skip = False
return pytest.mark.skipif(skip, reason=reason)
else:
# Calling directly as @requires_freesurfer: return decorated function
# and just check env var existence
return pytest.mark.skipif(not has_freesurfer(), reason="Requires Freesurfer")(
arg
)
def requires_good_network(func):
import pytest
return pytest.mark.skipif(
int(os.environ.get("MNE_SKIP_NETWORK_TESTS", 0)),
reason="MNE_SKIP_NETWORK_TESTS is set",
)(func)
def run_command_if_main():
"""Run a given command if it's __main__."""
local_vars = inspect.currentframe().f_back.f_locals
if local_vars.get("__name__", "") == "__main__":
local_vars["run"]()
class ArgvSetter:
"""Temporarily set sys.argv."""
def __init__(self, args=(), disable_stdout=True, disable_stderr=True):
self.argv = list(("python",) + args)
self.stdout = ClosingStringIO() if disable_stdout else sys.stdout
self.stderr = ClosingStringIO() if disable_stderr else sys.stderr
def __enter__(self): # noqa: D105
self.orig_argv = sys.argv
sys.argv = self.argv
self.orig_stdout = sys.stdout
sys.stdout = self.stdout
self.orig_stderr = sys.stderr
sys.stderr = self.stderr
return self
def __exit__(self, *args): # noqa: D105
sys.argv = self.orig_argv
sys.stdout = self.orig_stdout
sys.stderr = self.orig_stderr
def has_mne_c():
"""Check for MNE-C."""
return "MNE_ROOT" in os.environ
def has_freesurfer():
"""Check for Freesurfer."""
return "FREESURFER_HOME" in os.environ
def buggy_mkl_svd(function):
"""Decorate tests that make calls to SVD and intermittently fail."""
@wraps(function)
def dec(*args, **kwargs):
try:
return function(*args, **kwargs)
except np.linalg.LinAlgError as exp:
if "SVD did not converge" in str(exp):
msg = "Intel MKL SVD convergence error detected, skipping test"
warn(msg)
raise SkipTest(msg)
raise
return dec
def assert_and_remove_boundary_annot(annotations, n=1):
"""Assert that there are boundary annotations and remove them."""
from ..io import BaseRaw
if isinstance(annotations, BaseRaw): # allow either input
annotations = annotations.annotations
for key in ("EDGE", "BAD"):
idx = np.where(annotations.description == f"{key} boundary")[0]
assert len(idx) == n
annotations.delete(idx)
def assert_object_equal(a, b, *, err_msg="Object mismatch"):
"""Assert two objects are equal."""
d = object_diff(a, b)
assert d == "", f"{err_msg}\n{d}"
def _raw_annot(meas_date, orig_time):
from .._fiff.meas_info import create_info
from ..annotations import Annotations, _handle_meas_date
from ..io import RawArray
info = create_info(ch_names=10, sfreq=10.0)
raw = RawArray(data=np.empty((10, 10)), info=info, first_samp=10)
if meas_date is not None:
meas_date = _handle_meas_date(meas_date)
with raw.info._unlock(check_after=True):
raw.info["meas_date"] = meas_date
annot = Annotations([0.5], [0.2], ["dummy"], orig_time)
raw.set_annotations(annotations=annot)
return raw
def _get_data(x, ch_idx):
"""Get the (n_ch, n_times) data array."""
from ..evoked import Evoked
from ..io import BaseRaw
if isinstance(x, BaseRaw):
return x[ch_idx][0]
elif isinstance(x, Evoked):
return x.data[ch_idx]
def _check_snr(actual, desired, picks, min_tol, med_tol, msg, kind="MEG"):
"""Check the SNR of a set of channels."""
actual_data = _get_data(actual, picks)
desired_data = _get_data(desired, picks)
bench_rms = np.sqrt(np.mean(desired_data * desired_data, axis=1))
error = actual_data - desired_data
error_rms = np.sqrt(np.mean(error * error, axis=1))
np.clip(error_rms, 1e-60, np.inf, out=error_rms) # avoid division by zero
snrs = bench_rms / error_rms
# min tol
snr = snrs.min()
bad_count = (snrs < min_tol).sum()
msg = f" ({msg})" if msg != "" else msg
assert bad_count == 0, (
f"SNR (worst {snr:0.2f}) < {min_tol:0.2f} "
f"for {bad_count}/{len(picks)} channels{msg}"
)
# median tol
snr = np.median(snrs)
assert snr >= med_tol, f"{kind} SNR median {snr:0.2f} < {med_tol:0.2f}{msg}"
def assert_meg_snr(
actual, desired, min_tol, med_tol=500.0, chpi_med_tol=500.0, msg=None
):
"""Assert channel SNR of a certain level.
Mostly useful for operations like Maxwell filtering that modify
MEG channels while leaving EEG and others intact.
"""
from .._fiff.pick import pick_types
picks = pick_types(desired.info, meg=True, exclude=[])
picks_desired = pick_types(desired.info, meg=True, exclude=[])
assert_array_equal(picks, picks_desired, err_msg="MEG pick mismatch")
chpis = pick_types(actual.info, meg=False, chpi=True, exclude=[])
chpis_desired = pick_types(desired.info, meg=False, chpi=True, exclude=[])
if chpi_med_tol is not None:
assert_array_equal(chpis, chpis_desired, err_msg="cHPI pick mismatch")
others = np.setdiff1d(
np.arange(len(actual.ch_names)), np.concatenate([picks, chpis])
)
others_desired = np.setdiff1d(
np.arange(len(desired.ch_names)), np.concatenate([picks_desired, chpis_desired])
)
assert_array_equal(others, others_desired, err_msg="Other pick mismatch")
if len(others) > 0: # if non-MEG channels present
assert_allclose(
_get_data(actual, others),
_get_data(desired, others),
atol=1e-11,
rtol=1e-5,
err_msg="non-MEG channel mismatch",
)
_check_snr(actual, desired, picks, min_tol, med_tol, msg, kind="MEG")
if chpi_med_tol is not None and len(chpis) > 0:
_check_snr(actual, desired, chpis, 0.0, chpi_med_tol, msg, kind="cHPI")
def assert_snr(actual, desired, tol):
"""Assert actual and desired arrays are within some SNR tolerance."""
with np.errstate(divide="ignore"): # allow infinite
snr = linalg.norm(desired, ord="fro") / linalg.norm(desired - actual, ord="fro")
assert snr >= tol, f"{snr} < {tol}"
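A quick numeric sketch (hypothetical data): the SNR here is the ratio of Frobenius norms ``||desired|| / ||desired - actual||``, so a copy with small added noise passes a generous tolerance.

rng = np.random.default_rng(0)
desired = rng.standard_normal((4, 100))
actual = desired + 1e-3 * rng.standard_normal((4, 100))
assert_snr(actual, desired, tol=100)   # SNR is roughly 1000 here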
def assert_stcs_equal(stc1, stc2):
"""Check that two STC are equal."""
assert_allclose(stc1.times, stc2.times)
assert_allclose(stc1.data, stc2.data)
assert_array_equal(stc1.vertices[0], stc2.vertices[0])
assert_array_equal(stc1.vertices[1], stc2.vertices[1])
assert_allclose(stc1.tmin, stc2.tmin)
assert_allclose(stc1.tstep, stc2.tstep)
def _dig_sort_key(dig):
"""Sort dig keys."""
return (dig["kind"], dig["ident"])
def assert_dig_allclose(info_py, info_bin, limit=None):
"""Assert dig allclose."""
from .._fiff.constants import FIFF
from .._fiff.meas_info import Info
from ..bem import fit_sphere_to_headshape
from ..channels.montage import DigMontage
# test dig positions
dig_py, dig_bin = info_py, info_bin
if isinstance(dig_py, Info):
assert isinstance(dig_bin, Info)
dig_py, dig_bin = dig_py["dig"], dig_bin["dig"]
else:
assert isinstance(dig_bin, DigMontage)
assert isinstance(dig_py, DigMontage)
dig_py, dig_bin = dig_py.dig, dig_bin.dig
info_py = info_bin = None
assert isinstance(dig_py, list)
assert isinstance(dig_bin, list)
dig_py = sorted(dig_py, key=_dig_sort_key)
dig_bin = sorted(dig_bin, key=_dig_sort_key)
assert len(dig_py) == len(dig_bin)
for ii, (d_py, d_bin) in enumerate(zip(dig_py[:limit], dig_bin[:limit])):
for key in ("ident", "kind", "coord_frame"):
assert d_py[key] == d_bin[key], key
assert_allclose(
d_py["r"],
d_bin["r"],
rtol=1e-5,
atol=1e-5,
err_msg=f"Failure on {ii}:\n{d_py['r']}\n{d_bin['r']}",
)
if any(d["kind"] == FIFF.FIFFV_POINT_EXTRA for d in dig_py) and info_py is not None:
r_bin, o_head_bin, o_dev_bin = fit_sphere_to_headshape(
info_bin, units="m", verbose="error"
)
r_py, o_head_py, o_dev_py = fit_sphere_to_headshape(
info_py, units="m", verbose="error"
)
assert_allclose(r_py, r_bin, atol=1e-6)
assert_allclose(o_dev_py, o_dev_bin, rtol=1e-5, atol=1e-6)
assert_allclose(o_head_py, o_head_bin, rtol=1e-5, atol=1e-6)
def _click_ch_name(fig, ch_index=0, button=1):
"""Click on a channel name in a raw/epochs/ICA browse-style plot."""
from ..viz.utils import _fake_click
fig.canvas.draw()
text = fig.mne.ax_main.get_yticklabels()[ch_index]
bbox = text.get_window_extent()
x = bbox.intervalx.mean()
y = bbox.intervaly.mean()
_fake_click(fig, fig.mne.ax_main, (x, y), xform="pix", button=button)
def _get_suptitle(fig):
"""Get fig suptitle (shim for matplotlib < 3.8.0)."""
# TODO: obsolete when minimum MPL version is 3.8
if check_version("matplotlib", "3.8"):
return fig.get_suptitle()
else:
# unreliable hack; should work in most tests as we rarely use `sup_{x,y}label`
return fig.texts[0].get_text()

14
mne/utils/_typing.py Normal file

@@ -0,0 +1,14 @@
"""Shared objects used for type annotations."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import sys
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing import TypeVar
Self = TypeVar("Self")

1292
mne/utils/check.py Normal file

File diff suppressed because it is too large

917
mne/utils/config.py Normal file

@@ -0,0 +1,917 @@
"""The config functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import atexit
import json
import multiprocessing
import os
import os.path as op
import platform
import shutil
import subprocess
import sys
import tempfile
from functools import lru_cache, partial
from importlib import import_module
from pathlib import Path
from urllib.error import URLError
from urllib.request import urlopen
from packaging.version import parse
from ._logging import logger, warn
from .check import _check_fname, _check_option, _check_qt_version, _validate_type
from .docs import fill_doc
from .misc import _pl
_temp_home_dir = None
class UnknownPlatformError(Exception):
"""Exception raised for unknown platforms."""
def set_cache_dir(cache_dir):
"""Set the directory to be used for temporary file storage.
This directory is used by joblib to store memmapped arrays,
which reduces memory requirements and speeds up parallel
computation.
Parameters
----------
cache_dir : str or None
Directory to use for temporary file storage. None disables
temporary file storage.
"""
if cache_dir is not None and not op.exists(cache_dir):
raise OSError(f"Directory {cache_dir} does not exist")
set_config("MNE_CACHE_DIR", cache_dir, set_env=False)
def set_memmap_min_size(memmap_min_size):
"""Set the minimum size for memmapping of arrays for parallel processing.
Parameters
----------
memmap_min_size : str or None
Threshold on the minimum size of arrays that triggers automated memory
mapping for parallel processing, e.g., '1M' for 1 megabyte.
Use None to disable memmapping of large arrays.
"""
_validate_type(memmap_min_size, (str, None), "memmap_min_size")
if memmap_min_size is not None:
if memmap_min_size[-1] not in ["K", "M", "G"]:
raise ValueError(
"The size has to be given in kilo-, mega-, or "
f"gigabytes, e.g., 100K, 500M, 1G, got {repr(memmap_min_size)}"
)
set_config("MNE_MEMMAP_MIN_SIZE", memmap_min_size, set_env=False)
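A minimal usage sketch (note that both calls write to the user's MNE config file):

set_cache_dir(tempfile.gettempdir())   # let joblib memmap arrays in the temp dir
set_memmap_min_size("1M")              # memmap arrays larger than 1 MB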
# List the known configuration values
_known_config_types = {
"MNE_3D_OPTION_ANTIALIAS": (
"bool, whether to use full-screen antialiasing in 3D plots"
),
"MNE_3D_OPTION_DEPTH_PEELING": "bool, whether to use depth peeling in 3D plots",
"MNE_3D_OPTION_MULTI_SAMPLES": (
"int, number of samples to use for full-screen antialiasing"
),
"MNE_3D_OPTION_SMOOTH_SHADING": ("bool, whether to use smooth shading in 3D plots"),
"MNE_3D_OPTION_THEME": ("str, the color theme (light or dark) to use for 3D plots"),
"MNE_BROWSE_RAW_SIZE": (
"tuple, width and height of the raw browser window (in inches)"
),
"MNE_BROWSER_BACKEND": (
"str, the backend to use for the MNE Browse Raw window (qt or matplotlib)"
),
"MNE_BROWSER_OVERVIEW_MODE": (
"str, the overview mode to use in the MNE Browse Raw window "
"(see mne.viz.plot_raw for valid options)"
),
"MNE_BROWSER_PRECOMPUTE": (
"bool, whether to precompute raw data in the MNE Browse Raw window"
),
"MNE_BROWSER_THEME": "str, the color theme (light or dark) to use for the browser",
"MNE_BROWSER_USE_OPENGL": (
"bool, whether to use OpenGL for rendering in the MNE Browse Raw window"
),
"MNE_CACHE_DIR": "str, path to the cache directory for parallel execution",
"MNE_COREG_ADVANCED_RENDERING": (
"bool, whether to use advanced OpenGL rendering in mne coreg"
),
"MNE_COREG_COPY_ANNOT": (
"bool, whether to copy the annotation files during warping"
),
"MNE_COREG_FULLSCREEN": "bool, whether to use full-screen mode in mne coreg",
"MNE_COREG_GUESS_MRI_SUBJECT": (
"bool, whether to guess the MRI subject in mne coreg"
),
"MNE_COREG_HEAD_HIGH_RES": (
"bool, whether to use high-res head surface in mne coreg"
),
"MNE_COREG_HEAD_OPACITY": ("bool, the head surface opacity to use in mne coreg"),
"MNE_COREG_HEAD_INSIDE": (
"bool, whether to add an opaque inner scalp head surface to help "
"occlude points behind the head in mne coreg"
),
"MNE_COREG_INTERACTION": (
"str, interaction style in mne coreg (trackball or terrain)"
),
"MNE_COREG_MARK_INSIDE": (
"bool, whether to mark points inside the head surface in mne coreg"
),
"MNE_COREG_PREPARE_BEM": (
"bool, whether to prepare the BEM solution after warping in mne coreg"
),
"MNE_COREG_ORIENT_TO_SURFACE": (
"bool, whether to orient the digitization markers to the head surface "
"in mne coreg"
),
"MNE_COREG_SCALE_LABELS": (
"bool, whether to scale the MRI labels during warping in mne coreg"
),
"MNE_COREG_SCALE_BY_DISTANCE": (
"bool, whether to scale the digitization markers by their distance from "
"the scalp in mne coreg"
),
"MNE_COREG_SCENE_SCALE": (
"float, the scale factor of the 3D scene in mne coreg (default 0.16)"
),
"MNE_COREG_WINDOW_HEIGHT": "int, window height for mne coreg",
"MNE_COREG_WINDOW_WIDTH": "int, window width for mne coreg",
"MNE_COREG_SUBJECTS_DIR": "str, path to the subjects directory for mne coreg",
"MNE_CUDA_DEVICE": "int, CUDA device to use for GPU processing",
"MNE_DATA": "str, default data directory",
"MNE_DATASETS_BRAINSTORM_PATH": "str, path for brainstorm data",
"MNE_DATASETS_EEGBCI_PATH": "str, path for EEGBCI data",
"MNE_DATASETS_EPILEPSY_ECOG_PATH": "str, path for epilepsy_ecog data",
"MNE_DATASETS_HF_SEF_PATH": "str, path for HF_SEF data",
"MNE_DATASETS_MEGSIM_PATH": "str, path for MEGSIM data",
"MNE_DATASETS_MISC_PATH": "str, path for misc data",
"MNE_DATASETS_MTRF_PATH": "str, path for MTRF data",
"MNE_DATASETS_SAMPLE_PATH": "str, path for sample data",
"MNE_DATASETS_SOMATO_PATH": "str, path for somato data",
"MNE_DATASETS_MULTIMODAL_PATH": "str, path for multimodal data",
"MNE_DATASETS_FNIRS_MOTOR_PATH": "str, path for fnirs_motor data",
"MNE_DATASETS_OPM_PATH": "str, path for OPM data",
"MNE_DATASETS_SPM_FACE_DATASETS_TESTS": "str, path for spm_face data",
"MNE_DATASETS_SPM_FACE_PATH": "str, path for spm_face data",
"MNE_DATASETS_TESTING_PATH": "str, path for testing data",
"MNE_DATASETS_VISUAL_92_CATEGORIES_PATH": "str, path for visual_92_categories data",
"MNE_DATASETS_KILOWORD_PATH": "str, path for kiloword data",
"MNE_DATASETS_FIELDTRIP_CMC_PATH": "str, path for fieldtrip_cmc data",
"MNE_DATASETS_PHANTOM_KIT_PATH": "str, path for phantom_kit data",
"MNE_DATASETS_PHANTOM_4DBTI_PATH": "str, path for phantom_4dbti data",
"MNE_DATASETS_PHANTOM_KERNEL_PATH": "str, path for phantom_kernel data",
"MNE_DATASETS_LIMO_PATH": "str, path for limo data",
"MNE_DATASETS_REFMEG_NOISE_PATH": "str, path for refmeg_noise data",
"MNE_DATASETS_SSVEP_PATH": "str, path for ssvep data",
"MNE_DATASETS_ERP_CORE_PATH": "str, path for erp_core data",
"MNE_FORCE_SERIAL": "bool, force serial rather than parallel execution",
"MNE_LOGGING_LEVEL": (
"str or int, controls the level of verbosity of any function "
"decorated with @verbose. See "
"https://mne.tools/stable/auto_tutorials/intro/50_configure_mne.html#logging"
),
"MNE_MEMMAP_MIN_SIZE": (
"str, threshold on the minimum size of arrays passed to the workers that "
"triggers automated memory mapping, e.g., 1M or 0.5G"
),
"MNE_REPR_HTML": (
"bool, represent some of our objects with rich HTML in a notebook "
"environment"
),
"MNE_SKIP_NETWORK_TESTS": (
"bool, used in a test decorator (@requires_good_network) to skip "
"tests that include large downloads"
),
"MNE_SKIP_TESTING_DATASET_TESTS": (
"bool, used in test decorators (@requires_spm_data, "
"@requires_bstraw_data) to skip tests that require specific datasets"
),
"MNE_STIM_CHANNEL": "string, the default channel name for mne.find_events",
"MNE_TQDM": (
'str, either "tqdm", "tqdm.auto", or "off". Controls presence/absence '
"of progress bars"
),
"MNE_USE_CUDA": "bool, use GPU for filtering/resampling",
"MNE_USE_NUMBA": (
"bool, use Numba just-in-time compiler for some of our intensive "
"computations"
),
"SUBJECTS_DIR": "path-like, directory of freesurfer MRI files for each subject",
}
# These allow for partial matches, e.g. 'MNE_STIM_CHANNEL_1' is okay key
_known_config_wildcards = (
"MNE_STIM_CHANNEL", # can have multiple stim channels
"MNE_DATASETS_FNIRS", # mne-nirs
"MNE_NIRS", # mne-nirs
"MNE_KIT2FIFF", # mne-kit-gui
"MNE_ICALABEL", # mne-icalabel
"MNE_LSL", # mne-lsl
)
def _load_config(config_path, raise_error=False):
"""Safely load a config file."""
with open(config_path) as fid:
try:
config = json.load(fid)
except ValueError:
# No JSON object could be decoded --> corrupt file?
msg = (
f"The MNE-Python config file ({config_path}) is not a valid JSON "
"file and might be corrupted"
)
if raise_error:
raise RuntimeError(msg)
warn(msg)
config = dict()
return config
def get_config_path(home_dir=None):
r"""Get path to standard mne-python config file.
Parameters
----------
home_dir : str | None
The folder that contains the .mne config folder.
If None, it is found automatically.
Returns
-------
config_path : str
The path to the mne-python configuration file. On windows, this
will be '%USERPROFILE%\.mne\mne-python.json'. On every other
system, this will be ~/.mne/mne-python.json.
"""
val = op.join(_get_extra_data_path(home_dir=home_dir), "mne-python.json")
return val
def get_config(key=None, default=None, raise_error=False, home_dir=None, use_env=True):
"""Read MNE-Python preferences from environment or config file.
Parameters
----------
key : None | str
The preference key to look for. The os environment is searched first,
then the mne-python config file is parsed.
If None, all the config parameters present in environment variables or
the path are returned. If key is an empty string, a list of all valid
keys (but not values) is returned.
default : str | None
Value to return if the key is not found.
raise_error : bool
If True, raise an error if the key is not found (instead of returning
default).
home_dir : str | None
The folder that contains the .mne config folder.
If None, it is found automatically.
use_env : bool
If True, consider env vars, if available.
If False, only use MNE-Python configuration file values.
.. versionadded:: 0.18
Returns
-------
value : dict | str | None
The preference key value.
See Also
--------
set_config
"""
_validate_type(key, (str, type(None)), "key", "string or None")
if key == "":
# These are str->str (immutable) so we should just copy the dict
# itself, no need for deepcopy
return _known_config_types.copy()
# first, check to see if key is in env
if use_env and key is not None and key in os.environ:
return os.environ[key]
# second, look for it in mne-python config file
config_path = get_config_path(home_dir=home_dir)
if not op.isfile(config_path):
config = {}
else:
config = _load_config(config_path)
if key is None:
# update config with environment variables
if use_env:
env_keys = set(config).union(_known_config_types).intersection(os.environ)
config.update({key: os.environ[key] for key in env_keys})
return config
elif raise_error is True and key not in config:
loc_env = "the environment or in the " if use_env else ""
meth_env = (
(f'either os.environ["{key}"] = VALUE for a temporary solution, or ')
if use_env
else ""
)
extra_env = (
" You can also set the environment variable before running python."
if use_env
else ""
)
meth_file = (
f'mne.utils.set_config("{key}", VALUE, set_env=True) for a permanent one'
)
raise KeyError(
f'Key "{key}" not found in {loc_env}'
f"the mne-python config file ({config_path}). "
f"Try {meth_env}{meth_file}.{extra_env}"
)
else:
return config.get(key, default)
def set_config(key, value, home_dir=None, set_env=True):
"""Set a MNE-Python preference key in the config file and environment.
Parameters
----------
key : str
The preference key to set.
value : str | None
The value to assign to the preference key. If None, the key is
deleted.
home_dir : str | None
The folder that contains the .mne config folder.
If None, it is found automatically.
set_env : bool
If True (default), update :data:`os.environ` in addition to
updating the MNE-Python config file.
See Also
--------
get_config
"""
_validate_type(key, "str", "key")
# While JSON allow non-string types, we allow users to override config
# settings using env, which are strings, so we enforce that here
_validate_type(value, (str, "path-like", type(None)), "value")
if value is not None:
value = str(value)
if key not in _known_config_types and not any(
key.startswith(k) for k in _known_config_wildcards
):
warn(f'Setting non-standard config type: "{key}"')
# Read all previous values
config_path = get_config_path(home_dir=home_dir)
if op.isfile(config_path):
config = _load_config(config_path, raise_error=True)
else:
config = dict()
logger.info(
f"Attempting to create new mne-python configuration file:\n{config_path}"
)
if value is None:
config.pop(key, None)
if set_env and key in os.environ:
del os.environ[key]
else:
config[key] = value
if set_env:
os.environ[key] = value
if key == "MNE_BROWSER_BACKEND":
from ..viz._figure import set_browser_backend
set_browser_backend(value)
# Write all values. This may fail if the default directory is not
# writeable.
directory = op.dirname(config_path)
if not op.isdir(directory):
os.mkdir(directory)
with open(config_path, "w") as fid:
json.dump(config, fid, sort_keys=True, indent=0)
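A minimal round-trip sketch (this also writes to the user's config file):

set_config("MNE_LOGGING_LEVEL", "WARNING")
assert get_config("MNE_LOGGING_LEVEL") == "WARNING"
set_config("MNE_LOGGING_LEVEL", None)   # passing None deletes the key again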
def _get_extra_data_path(home_dir=None):
"""Get path to extra data (config, tables, etc.)."""
global _temp_home_dir
if home_dir is None:
home_dir = os.environ.get("_MNE_FAKE_HOME_DIR")
if home_dir is None:
# this has been checked on OSX64, Linux64, and Win32
if "nt" == os.name.lower():
APPDATA_DIR = os.getenv("APPDATA")
USERPROFILE_DIR = os.getenv("USERPROFILE")
if APPDATA_DIR is not None and op.isdir(
op.join(APPDATA_DIR, ".mne")
): # backward-compat
home_dir = APPDATA_DIR
elif USERPROFILE_DIR is not None:
home_dir = USERPROFILE_DIR
else:
raise FileNotFoundError(
"The USERPROFILE environment variable is not set, cannot "
"determine the location of the MNE-Python configuration "
"folder"
)
del APPDATA_DIR, USERPROFILE_DIR
else:
# This is a more robust way of getting the user's home folder on
# Linux platforms (not sure about OSX, Unix or BSD) than checking
# the HOME environment variable. If the user is running some sort
# of script that isn't launched via the command line (e.g. a script
# launched via Upstart) then the HOME environment variable will
# not be set.
if os.getenv("MNE_DONTWRITE_HOME", "") == "true":
if _temp_home_dir is None:
_temp_home_dir = tempfile.mkdtemp()
atexit.register(
partial(shutil.rmtree, _temp_home_dir, ignore_errors=True)
)
home_dir = _temp_home_dir
else:
home_dir = os.path.expanduser("~")
if home_dir is None:
raise ValueError(
"mne-python config file path could "
"not be determined, please report this "
"error to mne-python developers"
)
return op.join(home_dir, ".mne")
def get_subjects_dir(subjects_dir=None, raise_error=False):
"""Safely use subjects_dir input to return SUBJECTS_DIR.
Parameters
----------
subjects_dir : path-like | None
If a value is provided, return subjects_dir. Otherwise, look for
SUBJECTS_DIR config and return the result.
raise_error : bool
If True, raise a KeyError if no value for SUBJECTS_DIR can be found
(instead of returning None).
Returns
-------
value : Path | None
The SUBJECTS_DIR value.
"""
from_config = False
if subjects_dir is None:
subjects_dir = get_config("SUBJECTS_DIR", raise_error=raise_error)
from_config = True
if subjects_dir is not None:
subjects_dir = Path(subjects_dir)
if subjects_dir is not None:
# Emit a nice error or warning if their config is bad
try:
subjects_dir = _check_fname(
fname=subjects_dir,
overwrite="read",
must_exist=True,
need_dir=True,
name="subjects_dir",
)
except FileNotFoundError:
if from_config:
msg = (
"SUBJECTS_DIR in your MNE-Python configuration or environment "
"does not exist, consider using mne.set_config to fix it: "
f"{subjects_dir}"
)
if raise_error:
raise FileNotFoundError(msg) from None
else:
warn(msg)
elif raise_error:
raise
return subjects_dir
@fill_doc
def _get_stim_channel(stim_channel, info, raise_error=True):
"""Determine the appropriate stim_channel.
First, 'MNE_STIM_CHANNEL', 'MNE_STIM_CHANNEL_1', 'MNE_STIM_CHANNEL_2', etc.
are read. If these are not found, it will fall back to 'STI 014' if
present, then fall back to the first channel of type 'stim', if present.
Parameters
----------
stim_channel : str | list of str | None
The stim channel selected by the user.
%(info_not_none)s
Returns
-------
stim_channel : list of str
The name of the stim channel(s) to use
"""
from .._fiff.pick import pick_types
if stim_channel is not None:
if not isinstance(stim_channel, list):
_validate_type(stim_channel, "str", "Stim channel")
stim_channel = [stim_channel]
for channel in stim_channel:
_validate_type(channel, "str", "Each provided stim channel")
return stim_channel
stim_channel = list()
ch_count = 0
ch = get_config("MNE_STIM_CHANNEL")
while ch is not None and ch in info["ch_names"]:
stim_channel.append(ch)
ch_count += 1
ch = get_config(f"MNE_STIM_CHANNEL_{ch_count}")
if ch_count > 0:
return stim_channel
if "STI101" in info["ch_names"]: # combination channel for newer systems
return ["STI101"]
if "STI 014" in info["ch_names"]: # for older systems
return ["STI 014"]
stim_channel = pick_types(info, meg=False, ref_meg=False, stim=True)
if len(stim_channel) == 0 and raise_error:
raise ValueError(
"No stim channels found. Consider specifying them "
"manually using the 'stim_channel' parameter."
)
stim_channel = [info["ch_names"][ch_] for ch_ in stim_channel]
return stim_channel
def _get_root_dir():
"""Get as close to the repo root as possible."""
root_dir = Path(__file__).parents[1]
up_dir = root_dir.parent
if (up_dir / "setup.py").is_file() and all(
(up_dir / x).is_dir() for x in ("mne", "examples", "doc")
):
root_dir = up_dir
return root_dir
def _get_numpy_libs():
bad_lib = "unknown linalg bindings"
try:
from threadpoolctl import threadpool_info
except Exception as exc:
return bad_lib + f" (threadpoolctl module not found: {exc})"
pools = threadpool_info()
rename = dict(
openblas="OpenBLAS",
mkl="MKL",
)
for pool in pools:
if pool["internal_api"] in ("openblas", "mkl"):
return (
f'{rename[pool["internal_api"]]} '
f'{pool["version"]} with '
f'{pool["num_threads"]} thread{_pl(pool["num_threads"])}'
)
return bad_lib
_gpu_cmd = """\
from pyvista import GPUInfo; \
gi = GPUInfo(); \
print(gi.version); \
print(gi.renderer)"""
@lru_cache(maxsize=1)
def _get_gpu_info():
# Once https://github.com/pyvista/pyvista/pull/2250 is merged and PyVista
# does a release, we can triage based on version > 0.33.2
proc = subprocess.run(
[sys.executable, "-c", _gpu_cmd], check=False, capture_output=True
)
out = proc.stdout.decode().strip().replace("\r", "").split("\n")
if proc.returncode or len(out) != 2:
return None, None
return out
def _get_total_memory():
"""Return the total memory of the system in bytes."""
if platform.system() == "Windows":
o = subprocess.check_output(
[
"powershell.exe",
"(Get-CimInstance Win32_ComputerSystem).TotalPhysicalMemory",
]
).decode()
total_memory = int(o)
elif platform.system() == "Linux":
o = subprocess.check_output(["free", "-b"]).decode()
total_memory = int(o.splitlines()[1].split()[1])
elif platform.system() == "Darwin":
o = subprocess.check_output(["sysctl", "hw.memsize"]).decode()
total_memory = int(o.split(":")[1].strip())
else:
raise UnknownPlatformError("Could not determine total memory")
return total_memory
def _get_cpu_brand():
"""Return the CPU brand string."""
if platform.system() == "Windows":
o = subprocess.check_output(
["powershell.exe", "(Get-CimInstance Win32_Processor).Name"]
).decode()
cpu_brand = o.strip().splitlines()[-1]
elif platform.system() == "Linux":
o = subprocess.check_output(["grep", "model name", "/proc/cpuinfo"]).decode()
cpu_brand = o.splitlines()[0].split(": ")[1]
elif platform.system() == "Darwin":
o = subprocess.check_output(["sysctl", "machdep.cpu"]).decode()
cpu_brand = o.split("brand_string: ")[1].strip()
else:
cpu_brand = "?"
return cpu_brand
def sys_info(
fid=None,
show_paths=False,
*,
dependencies="user",
unicode="auto",
check_version=True,
):
"""Print system information.
This function prints system information useful when triaging bugs.
Parameters
----------
fid : file-like | None
The file to write to. Will be passed to :func:`print()`. Can be None to
use :data:`sys.stdout`.
show_paths : bool
If True, print paths for each module.
dependencies : 'user' | 'developer'
Show dependencies relevant for users (default) or for developers
(i.e., output includes additional dependencies).
unicode : bool | "auto"
Include Unicode symbols in output. If "auto", corresponds to True on Linux and
macOS, and False on Windows.
.. versionadded:: 0.24
check_version : bool | float
If True (default), attempt to check that the version of MNE-Python is up to date
with the latest release on GitHub. Can be a float to give a different timeout
(in sec) from the default (2 sec).
.. versionadded:: 1.6
"""
_validate_type(dependencies, str)
_check_option("dependencies", dependencies, ("user", "developer"))
_validate_type(check_version, (bool, "numeric"), "check_version")
_validate_type(unicode, (bool, str), "unicode")
_check_option("unicode", unicode, ("auto", True, False))
if unicode == "auto":
if platform.system() in ("Darwin", "Linux"):
unicode = True
else: # Windows
unicode = False
ljust = 24 if dependencies == "developer" else 21
platform_str = platform.platform()
out = partial(print, end="", file=fid)
out("Platform".ljust(ljust) + platform_str + "\n")
out("Python".ljust(ljust) + str(sys.version).replace("\n", " ") + "\n")
out("Executable".ljust(ljust) + sys.executable + "\n")
try:
cpu_brand = _get_cpu_brand()
except Exception:
cpu_brand = "?"
out("CPU".ljust(ljust) + f"{cpu_brand} ")
out(f"({multiprocessing.cpu_count()} cores)\n")
out("Memory".ljust(ljust))
try:
total_memory = _get_total_memory()
except UnknownPlatformError:
total_memory = "?"
else:
total_memory = f"{total_memory / 1024**3:.1f}" # convert to GiB
out(f"{total_memory} GiB\n")
out("\n")
ljust -= 3 # account for +/- symbols
libs = _get_numpy_libs()
unavailable = []
use_mod_names = (
"# Core",
"mne",
"numpy",
"scipy",
"matplotlib",
"",
"# Numerical (optional)",
"sklearn",
"numba",
"nibabel",
"nilearn",
"dipy",
"openmeeg",
"cupy",
"pandas",
"h5io",
"h5py",
"",
"# Visualization (optional)",
"pyvista",
"pyvistaqt",
"vtk",
"qtpy",
"ipympl",
"pyqtgraph",
"mne-qt-browser",
"ipywidgets",
# "trame", # no version, see https://github.com/Kitware/trame/issues/183
"trame_client",
"trame_server",
"trame_vtk",
"trame_vuetify",
"",
"# Ecosystem (optional)",
"mne-bids",
"mne-nirs",
"mne-features",
"mne-connectivity",
"mne-icalabel",
"mne-bids-pipeline",
"neo",
"eeglabio",
"edfio",
"mffpy",
"pybv",
"",
)
if dependencies == "developer":
use_mod_names += (
"# Testing",
"pytest",
"statsmodels",
"numpydoc",
"flake8",
"jupyter_client",
"nbclient",
"nbformat",
"pydocstyle",
"nitime",
"imageio",
"imageio-ffmpeg",
"snirf",
"",
"# Documentation",
"sphinx",
"sphinx-gallery",
"pydata-sphinx-theme",
"",
"# Infrastructure",
"decorator",
"jinja2",
# "lazy-loader",
"packaging",
"pooch",
"tqdm",
"",
)
try:
unicode = unicode and (sys.stdout.encoding.lower().startswith("utf"))
except Exception: # in case someone overrides sys.stdout in an unsafe way
unicode = False
mne_version_good = True
for mi, mod_name in enumerate(use_mod_names):
# upcoming break
if mod_name == "": # break
if unavailable:
out("└☐ " if unicode else " - ")
out("unavailable".ljust(ljust))
out(f"{', '.join(unavailable)}\n")
unavailable = []
if mi != len(use_mod_names) - 1:
out("\n")
continue
elif mod_name.startswith("# "): # header
mod_name = mod_name.replace("# ", "")
out(f"{mod_name}\n")
continue
pre = ""
last = use_mod_names[mi + 1] == "" and not unavailable
if last:
pre = ""
try:
mod = import_module(mod_name.replace("-", "_"))
except Exception:
unavailable.append(mod_name)
else:
mark = "" if unicode else "+"
mne_extra = ""
if mod_name == "mne" and check_version:
timeout = 2.0 if check_version is True else float(check_version)
mne_version_good, mne_extra = _check_mne_version(timeout)
if mne_version_good is None:
mne_version_good = True
elif not mne_version_good:
mark = "" if unicode else "X"
out(f"{pre}{mark} " if unicode else f" {mark} ")
out(f"{mod_name}".ljust(ljust))
if mod_name == "vtk":
vtk_version = mod.vtkVersion()
# 9.0 dev has VersionFull but 9.0 doesn't
for attr in ("GetVTKVersionFull", "GetVTKVersion"):
if hasattr(vtk_version, attr):
version = getattr(vtk_version, attr)()
if version != "":
out(version)
break
else:
out("unknown")
else:
out(mod.__version__.lstrip("v"))
if mod_name == "numpy":
out(f" ({libs})")
elif mod_name == "qtpy":
version, api = _check_qt_version(return_api=True)
out(f" ({api}={version})")
elif mod_name == "matplotlib":
out(f" (backend={mod.get_backend()})")
elif mod_name == "pyvista":
version, renderer = _get_gpu_info()
if version is None:
out(" (OpenGL unavailable)")
else:
out(f" (OpenGL {version} via {renderer})")
elif mod_name == "mne":
out(f" ({mne_extra})")
# Now comes stuff after the version
if show_paths:
if last:
pre = " "
elif unicode:
pre = ""
else:
pre = " | "
out(f'\n{pre}{" " * ljust}{op.dirname(mod.__file__)}')
out("\n")
if not mne_version_good:
out(
"\nTo update to the latest supported release version to get bugfixes and "
"improvements, visit "
"https://mne.tools/stable/install/updating.html\n"
)
def _get_latest_version(timeout):
# Bandit complains about urlopen, but we know the URL here
url = "https://api.github.com/repos/mne-tools/mne-python/releases/latest"
try:
with urlopen(url, timeout=timeout) as f: # nosec
response = json.load(f)
except (URLError, TimeoutError) as err:
# Triage error type
if "SSL" in str(err):
return "SSL error"
elif "timed out" in str(err):
return f"timeout after {timeout} sec"
else:
return f"unknown error: {err}"
else:
return response["tag_name"].lstrip("v") or "version unknown"
def _check_mne_version(timeout):
rel_ver = _get_latest_version(timeout)
if not rel_ver[0].isnumeric():
return None, (f"unable to check for latest version on GitHub, {rel_ver}")
rel_ver = parse(rel_ver)
this_ver = parse(import_module("mne").__version__)
if this_ver > rel_ver:
return True, f"devel, latest release is {rel_ver}"
if this_ver == rel_ver:
return True, "latest release"
else:
return False, f"outdated, release {rel_ver} is available!"

131
mne/utils/dataframe.py Normal file
View File

@@ -0,0 +1,131 @@
"""inst.to_data_frame() helper functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from inspect import signature
import numpy as np
from ..defaults import _handle_default
from ._logging import logger, verbose
from .check import check_version
@verbose
def _set_pandas_dtype(df, columns, dtype, verbose=None):
"""Try to set the right columns to dtype."""
for column in columns:
df[column] = df[column].astype(dtype)
logger.info(f'Converting "{column}" to "{dtype}"...')
def _scale_dataframe_data(inst, data, picks, scalings):
ch_types = inst.get_channel_types()
ch_types_used = list()
scalings = _handle_default("scalings", scalings)
for tt in scalings.keys():
if tt in ch_types:
ch_types_used.append(tt)
for tt in ch_types_used:
scaling = scalings[tt]
idx = [ii for ii in range(len(picks)) if ch_types[ii] == tt]
if len(idx):
data[:, idx] *= scaling
return data
def _convert_times(times, time_format, meas_date=None, first_time=0):
"""Convert vector of time in seconds to ms, datetime, or timedelta."""
# private function; pandas already checked in calling function
from pandas import to_timedelta
if time_format == "ms":
times = np.round(times * 1e3).astype(np.int64)
elif time_format == "timedelta":
times = to_timedelta(times, unit="s")
elif time_format == "datetime":
times = to_timedelta(times + first_time, unit="s") + meas_date
return times
def _inplace(df, method, **kwargs):
# Handle transition: inplace=True (pandas <1.5) → copy=False (>=1.5)
# and 3.0 warning:
# E DeprecationWarning: The copy keyword is deprecated and will be removed in a
# future version. Copy-on-Write is active in pandas since 3.0 which utilizes a
# lazy copy mechanism that defers copies until necessary. Use .copy() to make
# an eager copy if necessary.
_meth = getattr(df, method) # used for set_index() and rename()
if check_version("pandas", "3.0"):
return _meth(**kwargs)
elif "copy" in signature(_meth).parameters:
return _meth(**kwargs, copy=False)
else:
_meth(**kwargs, inplace=True)
return df
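# Usage sketch (editor's addition, not part of the original file): ``_inplace``
# picks the right keyword for the installed pandas version, so callers can
# stay version-agnostic.
#
#     >>> import pandas as pd                                   # doctest: +SKIP
#     >>> df = pd.DataFrame({"time": [0.0, 0.1], "value": [1.0, 2.0]})  # doctest: +SKIP
#     >>> df = _inplace(df, "set_index", keys="time")            # doctest: +SKIP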
@verbose
def _build_data_frame(
inst,
data,
picks,
long_format,
mindex,
index,
default_index,
col_names=None,
col_kind="channel",
verbose=None,
):
"""Build DataFrame from MNE-object-derived data array."""
# private function; pandas already checked in calling function
from pandas import DataFrame
from ..source_estimate import _BaseSourceEstimate
# build DataFrame
if col_names is None:
col_names = [inst.ch_names[p] for p in picks]
df = DataFrame(data, columns=col_names)
for i, (k, v) in enumerate(mindex):
df.insert(i, k, v)
# build Index
if long_format:
df = _inplace(df, "set_index", keys=default_index)
df.columns.name = col_kind
elif index is not None:
df = _inplace(df, "set_index", keys=index)
if set(index) == set(default_index):
df.columns.name = col_kind
# long format
if long_format:
df = df.stack().reset_index()
df = _inplace(df, "rename", columns={0: "value"})
# add column for channel types (as appropriate)
ch_map = (
None
if isinstance(inst, _BaseSourceEstimate)
else dict(
zip(
np.array(inst.ch_names)[picks],
np.array(inst.get_channel_types())[picks],
)
)
)
if ch_map is not None:
col_index = len(df.columns) - 1
ch_type = df["channel"].map(ch_map)
df.insert(col_index, "ch_type", ch_type)
# restore index
if index is not None:
df = _inplace(df, "set_index", keys=index)
# convert channel/vertex/ch_type columns to factors
to_factor = [
c for c in df.columns.tolist() if c not in ("freq", "time", "value")
]
_set_pandas_dtype(df, to_factor, "category")
return df

5609
mne/utils/docs.py Normal file

File diff suppressed because it is too large Load Diff

19
mne/utils/fetching.py Normal file
View File

@@ -0,0 +1,19 @@
"""File downloading functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import os
def _url_to_local_path(url, path):
"""Mirror a url path in a local destination (keeping folder structure)."""
from urllib import parse, request
destination = parse.urlparse(url).path
# First char should be '/', and it needs to be discarded
if len(destination) < 2 or destination[0] != "/":
raise ValueError("Invalid URL")
destination = os.path.join(path, request.url2pathname(destination)[1:])
return destination
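# Illustrative example (editor's addition, not part of the original file):
# mirror a remote path under a local root, keeping the folder structure
# (shown for a POSIX-style ``path``).
#
#     >>> _url_to_local_path("https://example.com/sub/file.fif", "/tmp/mirror")  # doctest: +SKIP
#     '/tmp/mirror/sub/file.fif'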

243
mne/utils/linalg.py Normal file
View File

@@ -0,0 +1,243 @@
"""Utility functions to speed up linear algebraic operations.
In general, things like np.dot and linalg.svd should be used directly
because they are smart about checking for bad values. However, in cases where
things are done repeatedly (e.g., thousands of times on tiny matrices), the
overhead can become problematic from a performance standpoint. Examples:
- Optimization routines:
- Dipole fitting
- Sparse solving
- cHPI fitting
- Inverse computation
- Beamformers (LCMV/DICS)
- eLORETA minimum norm
Significant performance gains can be achieved by ensuring that inputs
are Fortran contiguous because that's what LAPACK requires. Without this,
inputs will be memcopied.
"""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import functools
import numpy as np
from scipy import linalg
from scipy._lib._util import _asarray_validated
from ..fixes import _safe_svd
# For efficiency, names should be str or tuple of str, dtype a builtin
# NumPy dtype
@functools.lru_cache(None)
def _get_blas_funcs(dtype, names):
return linalg.get_blas_funcs(names, (np.empty(0, dtype),))
@functools.lru_cache(None)
def _get_lapack_funcs(dtype, names):
assert dtype in (np.float64, np.complex128)
x = np.empty(0, dtype)
return linalg.get_lapack_funcs(names, (x,))
###############################################################################
# linalg.svd and linalg.pinv2
def _svd_lwork(shape, dtype=np.float64):
"""Set up SVD calculations on identical-shape float64/complex128 arrays."""
try:
ds = linalg._decomp_svd
except AttributeError: # < 1.8.0
ds = linalg.decomp_svd
gesdd_lwork, gesvd_lwork = _get_lapack_funcs(dtype, ("gesdd_lwork", "gesvd_lwork"))
sdd_lwork = ds._compute_lwork(
gesdd_lwork, *shape, compute_uv=True, full_matrices=False
)
svd_lwork = ds._compute_lwork(
gesvd_lwork, *shape, compute_uv=True, full_matrices=False
)
return sdd_lwork, svd_lwork
def _repeated_svd(x, lwork, overwrite_a=False):
"""Mimic scipy.linalg.svd, avoid lwork and get_lapack_funcs overhead."""
gesdd, gesvd = _get_lapack_funcs(x.dtype, ("gesdd", "gesvd"))
# this has to use overwrite_a=False in case we need to fall back to gesvd
u, s, v, info = gesdd(
x, compute_uv=True, lwork=lwork[0], full_matrices=False, overwrite_a=False
)
if info > 0:
# Fall back to slower gesvd, sometimes gesdd fails
u, s, v, info = gesvd(
x,
compute_uv=True,
lwork=lwork[1],
full_matrices=False,
overwrite_a=overwrite_a,
)
if info > 0:
raise np.linalg.LinAlgError("SVD did not converge")
if info < 0:
raise ValueError(f"illegal value in {-info}-th argument of internal gesdd")
return u, s, v
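# Usage sketch (editor's addition, not part of the original file): precompute
# the LAPACK workspace once, then reuse it for repeated SVDs of identically
# shaped float64/complex128 arrays.
#
#     >>> import numpy as np                            # doctest: +SKIP
#     >>> x = np.random.RandomState(0).randn(6, 4)      # doctest: +SKIP
#     >>> lwork = _svd_lwork(x.shape, np.float64)       # doctest: +SKIP
#     >>> u, s, vt = _repeated_svd(x.copy(), lwork)     # doctest: +SKIP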
###############################################################################
# linalg.eigh
@functools.lru_cache(None)
def _get_evd(dtype):
x = np.empty(0, dtype)
if dtype == np.float64:
driver = "syevd"
else:
assert dtype == np.complex128
driver = "heevd"
(evr,) = linalg.get_lapack_funcs((driver,), (x,))
return evr, driver
def eigh(a, overwrite_a=False, check_finite=True):
"""Efficient wrapper for eigh.
Parameters
----------
a : ndarray, shape (n_components, n_components)
The symmetric array to operate on.
overwrite_a : bool
If True, the contents of a can be overwritten for efficiency.
check_finite : bool
If True, check that all elements are finite.
Returns
-------
w : ndarray, shape (n_components,)
The N eigenvalues, in ascending order, each repeated according to
its multiplicity.
v : ndarray, shape (n_components, n_components)
The normalized eigenvector corresponding to the eigenvalue ``w[i]``
is the column ``v[:, i]``.
"""
# We use SYEVD, see https://github.com/scipy/scipy/issues/9212
if check_finite:
a = _asarray_validated(a, check_finite=check_finite)
evd, driver = _get_evd(a.dtype)
w, v, info = evd(a, lower=1, overwrite_a=overwrite_a)
if info == 0:
return w, v
if info < 0:
raise ValueError(f"illegal value in argument {-info} of internal {driver}")
else:
raise linalg.LinAlgError(
"internal fortran routine failed to converge: "
f"{info} off-diagonal elements of an "
"intermediate tridiagonal form did not converge"
" to zero."
)
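# Usage sketch (editor's addition, not part of the original file): the wrapper
# mirrors ``scipy.linalg.eigh`` for float64/complex128 Hermitian inputs.
#
#     >>> import numpy as np                           # doctest: +SKIP
#     >>> a = np.array([[2.0, 1.0], [1.0, 2.0]])       # doctest: +SKIP
#     >>> w, v = eigh(a)                               # doctest: +SKIP
#     >>> bool(np.allclose(v @ np.diag(w) @ v.T, a))   # doctest: +SKIP
#     True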
def sqrtm_sym(A, rcond=1e-7, inv=False):
"""Compute the sqrt of a positive, semi-definite matrix (or its inverse).
Parameters
----------
A : ndarray, shape (..., n, n)
The array to take the square root of.
rcond : float
The relative condition number used during reconstruction.
inv : bool
If True, compute the inverse of the square root rather than the
square root itself.
Returns
-------
A_sqrt : ndarray, shape (..., n, n)
The (possibly inverted) square root of A.
s : ndarray, shape (..., n)
The original square root singular values (not inverted).
"""
# Same as linalg.sqrtm(C) but faster, also yields the eigenvalues
return _sym_mat_pow(A, -0.5 if inv else 0.5, rcond, return_s=True)
def _sym_mat_pow(A, power, rcond=1e-7, reduce_rank=False, return_s=False):
"""Exponentiate Hermitian matrices with optional rank reduction."""
assert power in (-1, 0.5, -0.5) # only used internally
s, u = np.linalg.eigh(A) # eigenvalues in ascending order
# Is it positive semi-definite? If so, keep real
limit = s[..., -1:] * rcond
if not (s >= -limit).all():  # allow some tiny negative values
raise ValueError("Matrix is not positive semi-definite")
s[s <= limit] = np.inf if power < 0 else 0
if reduce_rank:
# These are ordered smallest to largest, so we set the first one
# to inf -- then the 1. / s below will turn this to zero, as needed.
s[..., 0] = np.inf
if power in (-0.5, 0.5):
np.sqrt(s, out=s)
use_s = 1.0 / s if power < 0 else s
out = np.matmul(u * use_s[..., np.newaxis, :], u.swapaxes(-2, -1).conj())
if return_s:
out = (out, s)
return out
# SciPy deprecation of pinv + pinvh rcond (never worked properly anyway)
def pinvh(a, rtol=None):
"""Compute a pseudo-inverse of a Hermitian matrix.
Parameters
----------
a : ndarray, shape (n, n)
The Hermitian array to invert.
rtol : float | None
The relative tolerance.
Returns
-------
a_pinv : ndarray, shape (n, n)
The pseudo-inverse of a.
"""
s, u = np.linalg.eigh(a)
del a
if rtol is None:
rtol = s.size * np.finfo(s.dtype).eps
maxS = np.max(np.abs(s))
above_cutoff = abs(s) > maxS * rtol
psigma_diag = 1.0 / s[above_cutoff]
u = u[:, above_cutoff]
return (u * psigma_diag) @ u.conj().T
def pinv(a, rtol=None):
"""Compute a pseudo-inverse of a matrix.
Parameters
----------
a : ndarray, shape (n, m)
The array to invert.
rtol : float | None
The relative tolerance.
Returns
-------
a_pinv : ndarray, shape (m, n)
The pseudo-inverse of a.
"""
u, s, vh = _safe_svd(a, full_matrices=False)
del a
maxS = np.max(s)
if rtol is None:
rtol = max(vh.shape + u.shape) * np.finfo(u.dtype).eps
rank = np.sum(s > maxS * rtol)
u = u[:, :rank]
u /= s[:rank]
return (u @ vh[:rank]).conj().T
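# Usage sketch (editor's addition, not part of the original file): both helpers
# should agree with ``numpy.linalg.pinv`` on well-conditioned inputs.
#
#     >>> import numpy as np                                  # doctest: +SKIP
#     >>> a = np.random.RandomState(0).randn(4, 3)            # doctest: +SKIP
#     >>> bool(np.allclose(pinv(a), np.linalg.pinv(a)))       # doctest: +SKIP
#     True
#     >>> h = a.T @ a  # symmetric positive semi-definite
#     >>> bool(np.allclose(pinvh(h), np.linalg.pinv(h)))      # doctest: +SKIP
#     True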

509
mne/utils/misc.py Normal file
View File

@@ -0,0 +1,509 @@
"""Some miscellaneous utility functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import fnmatch
import gc
import hashlib
import inspect
import os
import subprocess
import sys
import traceback
import weakref
from contextlib import ExitStack, contextmanager
from importlib.resources import files
from math import log
from queue import Empty, Queue
from string import Formatter
from textwrap import dedent
from threading import Thread
import numpy as np
from decorator import FunctionMaker
from ._logging import logger, verbose, warn
from .check import _check_option, _validate_type
def _identity_function(x):
return x
# TODO: no longer needed when py3.9 is minimum supported version
def _empty_hash(kind="md5"):
func = getattr(hashlib, kind)
if "usedforsecurity" in inspect.signature(func).parameters:
return func(usedforsecurity=False)
else:
return func()
def _pl(x, non_pl="", pl="s"):
"""Determine if plural should be used."""
len_x = x if isinstance(x, int | np.generic) else len(x)
return non_pl if len_x == 1 else pl
def _explain_exception(start=-1, stop=None, prefix="> "):
"""Explain an exception."""
# start=-1 means "only the most recent caller"
etype, value, tb = sys.exc_info()
string = traceback.format_list(traceback.extract_tb(tb)[start:stop])
string = "".join(string).split("\n") + traceback.format_exception_only(etype, value)
string = ":\n" + prefix + ("\n" + prefix).join(string)
return string
def _sort_keys(x):
"""Sort and return keys of dict."""
keys = list(x.keys()) # note: not thread-safe
idx = np.argsort([str(k) for k in keys])
keys = [keys[ii] for ii in idx]
return keys
class _DefaultEventParser:
"""Parse none standard events."""
def __init__(self):
self.event_ids = dict()
def __call__(self, description, offset=1):
if description not in self.event_ids:
self.event_ids[description] = offset + len(self.event_ids)
return self.event_ids[description]
class _FormatDict(dict):
"""Help pformat() work properly."""
def __missing__(self, key):
return "{" + key + "}"
def pformat(temp, **fmt):
"""Format a template string partially.
Examples
--------
>>> pformat("{a}_{b}", a='x')
'x_{b}'
"""
formatter = Formatter()
mapping = _FormatDict(fmt)
return formatter.vformat(temp, (), mapping)
def _enqueue_output(out, queue):
for line in iter(out.readline, b""):
queue.put(line)
@verbose
def run_subprocess(command, return_code=False, verbose=None, *args, **kwargs):
"""Run command using subprocess.Popen.
Run command and wait for command to complete. If the return code was zero
then return, otherwise raise CalledProcessError.
By default, this will also add stdout= and stderr=subprocess.PIPE
to the call to Popen to suppress printing to the terminal.
Parameters
----------
command : list of str | str
Command to run as subprocess (see subprocess.Popen documentation).
return_code : bool
If True, return the return code instead of raising an error if it's
non-zero.
.. versionadded:: 0.20
%(verbose)s
*args, **kwargs : arguments
Additional arguments to pass to subprocess.Popen.
Returns
-------
stdout : str
Stdout returned by the process.
stderr : str
Stderr returned by the process.
code : int
The return code, only returned if ``return_code == True``.
"""
all_out = ""
all_err = ""
# non-blocking adapted from https://stackoverflow.com/questions/375427/non-blocking-read-on-a-subprocess-pipe-in-python#4896288 # noqa: E501
out_q = Queue()
err_q = Queue()
control_stdout = "stdout" not in kwargs
control_stderr = "stderr" not in kwargs
with running_subprocess(command, *args, **kwargs) as p:
if control_stdout:
out_t = Thread(target=_enqueue_output, args=(p.stdout, out_q))
out_t.daemon = True
out_t.start()
if control_stderr:
err_t = Thread(target=_enqueue_output, args=(p.stderr, err_q))
err_t.daemon = True
err_t.start()
while True:
do_break = p.poll() is not None
# read all current lines without blocking
while True: # process stdout
try:
out = out_q.get(timeout=0.01)
except Empty:
break
else:
out = out.decode("utf-8")
log_out = out.removesuffix("\n")
logger.info(log_out)
all_out += out
while True: # process stderr
try:
err = err_q.get(timeout=0.01)
except Empty:
break
else:
err = err.decode("utf-8")
err_out = err.removesuffix("\n")
# Leave this as logger.warning rather than warn(...) to
# mirror the logger.info above for stdout. This function
# is basically just a version of subprocess.call, and
# shouldn't emit Python warnings due to stderr outputs
# (the calling function can check for stderr output and
# emit a warning if it wants).
logger.warning(err_out)
all_err += err
if do_break:
break
output = (all_out, all_err)
if return_code:
output = output + (p.returncode,)
elif p.returncode:
stdout = all_out if control_stdout else None
stderr = all_err if control_stderr else None
raise subprocess.CalledProcessError(
p.returncode, command, output=stdout, stderr=stderr
)
return output
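# Illustrative usage (editor's addition, not part of the original file); the
# command below is only a placeholder.
#
#     >>> stdout, stderr = run_subprocess(["echo", "hello"])  # doctest: +SKIP
#     >>> stdout.strip()                                      # doctest: +SKIP
#     'hello'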
@contextmanager
def running_subprocess(command, after="wait", verbose=None, *args, **kwargs):
"""Context manager to do something with a command running via Popen.
Parameters
----------
command : list of str | str
Command to run as subprocess (see :class:`python:subprocess.Popen`).
after : str
Can be:
- "wait" to use :meth:`~python:subprocess.Popen.wait`
- "communicate" to use :meth:`~python.subprocess.Popen.communicate`
- "terminate" to use :meth:`~python:subprocess.Popen.terminate`
- "kill" to use :meth:`~python:subprocess.Popen.kill`
%(verbose)s
*args, **kwargs : arguments
Additional arguments to pass to subprocess.Popen.
Returns
-------
p : instance of Popen
The process.
"""
_validate_type(after, str, "after")
_check_option("after", after, ["wait", "terminate", "kill", "communicate"])
contexts = list()
for stdxxx in ("stderr", "stdout"):
if stdxxx not in kwargs:
kwargs[stdxxx] = subprocess.PIPE
contexts.append(stdxxx)
# Check the PATH environment variable. If run_subprocess() is to be called
# frequently this should be refactored so as to only check the path once.
env = kwargs.get("env", os.environ)
if any(p.startswith("~") for p in env["PATH"].split(os.pathsep)):
warn(
"Your PATH environment variable contains at least one path "
'starting with a tilde ("~") character. Such paths are not '
"interpreted correctly from within Python. It is recommended "
'that you use "$HOME" instead of "~".'
)
if isinstance(command, str):
command_str = command
else:
command = [str(s) for s in command]
command_str = " ".join(s for s in command)
logger.info(f"Running subprocess: {command_str}")
try:
p = subprocess.Popen(command, *args, **kwargs)
except Exception:
if isinstance(command, str):
command_name = command.split()[0]
else:
command_name = command[0]
logger.error(f"Command not found: {command_name}")
raise
try:
with ExitStack() as stack:
for context in contexts:
stack.enter_context(getattr(p, context))
yield p
finally:
getattr(p, after)()
p.wait()
def _clean_names(names, remove_whitespace=False, before_dash=True):
"""Remove white-space on topo matching.
This function handles different naming conventions for old vs. new VectorView systems
(`remove_whitespace`) and removes system-specific parts of CTF channel names
(`before_dash`).
Usage
-----
# for new VectorView (only inside layout)
ch_names = _clean_names(epochs.ch_names, remove_whitespace=True)
# for CTF
ch_names = _clean_names(epochs.ch_names, before_dash=True)
"""
cleaned = []
for name in names:
if " " in name and remove_whitespace:
name = name.replace(" ", "")
if "-" in name and before_dash:
name = name.split("-")[0]
if name.endswith("_v"):
name = name[:-2]
cleaned.append(name)
if len(set(cleaned)) != len(names):
# this was probably not a VectorView or CTF dataset, and we now broke the
# dataset by creating duplicates, so let's use the original channel names.
return names
return cleaned
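# Illustrative example (editor's addition, not part of the original file):
# stripping the system-specific suffix from CTF-style channel names.
#
#     >>> _clean_names(["MLC11-2805", "MLC12-2805"], before_dash=True)  # doctest: +SKIP
#     ['MLC11', 'MLC12']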
def _get_argvalues():
"""Return all arguments (except self) and values of read_raw_xxx."""
# call stack
# read_raw_xxx -> <decorator-gen-000> -> BaseRaw.__init__ -> _get_argvalues
# This is equivalent to `frame = inspect.stack(0)[4][0]` but faster
frame = inspect.currentframe()
try:
for _ in range(3):
frame = frame.f_back
fname = frame.f_code.co_filename
if not fnmatch.fnmatch(fname, "*/mne/io/*"):
return None
args, _, _, values = inspect.getargvalues(frame)
finally:
del frame
params = dict()
for arg in args:
params[arg] = values[arg]
params.pop("self", None)
return params
def sizeof_fmt(num):
"""Turn number of bytes into human-readable str.
Parameters
----------
num : int
The number of bytes.
Returns
-------
size : str
The size in human-readable format.
"""
units = ["bytes", "KiB", "MiB", "GiB", "TiB", "PiB"]
decimals = [0, 0, 1, 2, 2, 2]
if num > 1:
exponent = min(int(log(num, 1024)), len(units) - 1)
quotient = float(num) / 1024**exponent
unit = units[exponent]
num_decimals = decimals[exponent]
format_string = f"{{0:.{num_decimals}f}} {{1}}"
return format_string.format(quotient, unit)
if num == 0:
return "0 bytes"
if num == 1:
return "1 byte"
def _file_like(obj):
# An alternative would be::
#
# isinstance(obj, (TextIOBase, BufferedIOBase, RawIOBase, IOBase))
#
# but this might be more robust to file-like objects not properly
# inheriting from these classes:
return all(callable(getattr(obj, name, None)) for name in ("read", "seek"))
def _fullname(obj):
klass = obj.__class__
module = klass.__module__
if module == "builtins":
return klass.__qualname__
return module + "." + klass.__qualname__
def _assert_no_instances(cls, when=""):
__tracebackhide__ = True
n = 0
ref = list()
gc.collect()
objs = gc.get_objects()
for obj in objs:
try:
check = isinstance(obj, cls)
except Exception: # such as a weakref
check = False
if check:
if cls.__name__ == "Brain":
ref.append(f'Brain._cleaned = {getattr(obj, "_cleaned", None)}')
rr = gc.get_referrers(obj)
count = 0
for r in rr:
if (
r is not objs
and r is not globals()
and r is not locals()
and not inspect.isframe(r)
):
if isinstance(r, list | dict | tuple):
rep = f"len={len(r)}"
r_ = gc.get_referrers(r)
types = (_fullname(x) for x in r_)
types = "/".join(sorted(set(x for x in types if x is not None)))
rep += f", {len(r_)} referrers: {types}"
del r_
else:
rep = repr(r)[:100].replace("\n", " ")
# If it's a __closure__, get more information
if rep.startswith("<cell at "):
try:
rep += f" ({repr(r.cell_contents)[:100]})"
except Exception:
pass
name = _fullname(r)
ref.append(f"{name}: {rep}")
count += 1
del r
del rr
n += count > 0
del obj
del objs
gc.collect()
assert n == 0, f"\n{n} {cls.__name__} @ {when}:\n" + "\n".join(ref)
def _resource_path(submodule, filename):
"""Return a full system path to a package resource (AKA a file).
Parameters
----------
submodule : str
An import-style module or submodule name
(e.g., "mne.datasets.testing").
filename : str
The file whose full path you want.
Returns
-------
path : str
The full system path to the requested file.
"""
return files(submodule).joinpath(filename)
def repr_html(f):
"""Decorate _repr_html_ methods.
If a _repr_html_ method is decorated with this decorator, the repr in a
notebook will show HTML or plain text depending on the config value
MNE_REPR_HTML (by default "true", which will render HTML).
Parameters
----------
f : function
The function to decorate.
Returns
-------
wrapper : function
The decorated function.
"""
from ..utils import get_config
def wrapper(*args, **kwargs):
if get_config("MNE_REPR_HTML", "true").lower() == "false":
import html
r = "<pre>" + html.escape(repr(args[0])) + "</pre>"
return r.replace("\n", "<br/>")
else:
return f(*args, **kwargs)
return wrapper
def _auto_weakref(function):
"""Create weakrefs to self (or other free vars in __closure__) then evaluate.
When a nested function is defined within an instance method, and the function makes
use of ``self``, it creates a reference cycle that the Python garbage collector is
not smart enough to resolve, so the parent object is never GC'd. (The reference to
``self`` becomes part of the ``__closure__`` of the nested function).
This decorator allows the nested function to access ``self`` without increasing the
reference counter on ``self``, which will prevent the memory leak. If the referent
is not found (usually because already GC'd) it will short-circuit the decorated
function and return ``None``.
"""
names = function.__code__.co_freevars
assert len(names) == len(function.__closure__)
__weakref_values__ = dict()
evaldict = dict(__weakref_values__=__weakref_values__)
for name, value in zip(names, function.__closure__):
__weakref_values__[name] = weakref.ref(value.cell_contents)
body = dedent(inspect.getsource(function))
body = body.splitlines()
for li, line in enumerate(body):
if line.startswith(" "):
body = body[li:]
break
old_body = "\n".join(body)
body = """\
def %(name)s(%(signature)s):
"""
for name in names:
body += f"""
{name} = __weakref_values__[{repr(name)}]()
if {name} is None:
return
"""
body = body + old_body
fm = FunctionMaker(function)
fun = fm.make(body, evaldict, addsource=True)
fun.__globals__.update(function.__globals__)
assert fun.__closure__ is None, fun.__closure__
return fun
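# Hypothetical sketch (editor's addition, not part of the original file): the
# class and method names below are made up; the point is that the decorated
# nested function reaches ``self`` only through a weak reference.
#
#     class Viewer:
#         def attach(self, widget):
#             @_auto_weakref
#             def on_click():
#                 self.handle_click()  # resolved via weakref; no ref cycle
#             widget.connect(on_click)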

781
mne/utils/mixin.py Normal file
View File

@@ -0,0 +1,781 @@
"""Some utility functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import json
import logging
from collections import OrderedDict
from copy import deepcopy
import numpy as np
from ._logging import verbose, warn
from ._typing import Self
from .check import _check_pandas_installed, _check_preload, _validate_type
from .numerics import _time_mask, object_hash, object_size
logger = logging.getLogger("mne") # one selection here used across mne-python
logger.propagate = False # don't propagate (in case of multiple imports)
class SizeMixin:
"""Estimate MNE object sizes."""
def __eq__(self, other):
"""Compare self to other.
Parameters
----------
other : object
The object to compare to.
Returns
-------
eq : bool
True if the two objects are equal.
"""
return isinstance(other, type(self)) and hash(self) == hash(other)
@property
def _size(self):
"""Estimate the object size."""
try:
size = object_size(self.info)
except Exception:
warn("Could not get size for self.info")
return -1
if hasattr(self, "data"):
size += object_size(self.data)
elif hasattr(self, "_data"):
size += object_size(self._data)
return size
def __hash__(self):
"""Hash the object.
Returns
-------
hash : int
The hash
"""
from ..epochs import BaseEpochs
from ..evoked import Evoked
from ..io import BaseRaw
if isinstance(self, Evoked):
return object_hash(dict(info=self.info, data=self.data))
elif isinstance(self, BaseEpochs | BaseRaw):
_check_preload(self, "Hashing ")
return object_hash(dict(info=self.info, data=self._data))
else:
raise RuntimeError(f"Hashing unknown object type: {type(self)}")
class GetEpochsMixin:
"""Class to add epoch selection and metadata to certain classes."""
def __getitem__(
self: Self,
item,
) -> Self:
"""Return an Epochs object with a copied subset of epochs.
Parameters
----------
item : int | slice | array-like | str
See Notes for use cases.
Returns
-------
epochs : instance of Epochs
The subset of epochs.
Notes
-----
Epochs can be accessed as ``epochs[...]`` in several ways:
1. **Integer or slice:** ``epochs[idx]`` will return an `~mne.Epochs`
object with a subset of epochs chosen by index (supports single
index and Python-style slicing).
2. **String:** ``epochs['name']`` will return an `~mne.Epochs` object
comprising only the epochs labeled ``'name'`` (i.e., epochs created
around events with the label ``'name'``).
If there are no epochs labeled ``'name'`` but there are epochs
labeled with /-separated tags (e.g. ``'name/left'``,
``'name/right'``), then ``epochs['name']`` will select the epochs
with labels that contain that tag (e.g., ``epochs['left']`` selects
epochs labeled ``'audio/left'`` and ``'visual/left'``, but not
``'audio_left'``).
If multiple tags are provided *as a single string* (e.g.,
``epochs['name_1/name_2']``), this selects epochs containing *all*
provided tags. For example, ``epochs['audio/left']`` selects
``'audio/left'`` and ``'audio/quiet/left'``, but not
``'audio/right'``. Note that tag-based selection is insensitive to
order: tags like ``'audio/left'`` and ``'left/audio'`` will be
treated the same way when selecting via tag.
3. **List of strings:** ``epochs[['name_1', 'name_2', ... ]]`` will
return an `~mne.Epochs` object comprising epochs that match *any* of
the provided names (i.e., the list of names is treated as an
inclusive-or condition). If *none* of the provided names match any
epoch labels, a ``KeyError`` will be raised.
If epoch labels are /-separated tags, then providing multiple tags
*as separate list entries* will likewise act as an inclusive-or
filter. For example, ``epochs[['audio', 'left']]`` would select
``'audio/left'``, ``'audio/right'``, and ``'visual/left'``, but not
``'visual/right'``.
4. **Pandas query:** ``epochs['pandas query']`` will return an
`~mne.Epochs` object with a subset of epochs (and matching
metadata) selected by the query called with
``self.metadata.eval``, e.g.::
epochs["col_a > 2 and col_b == 'foo'"]
would return all epochs whose associated ``col_a`` metadata was
greater than two, and whose ``col_b`` metadata was the string 'foo'.
Query-based indexing only works if Pandas is installed and
``self.metadata`` is a :class:`pandas.DataFrame`.
.. versionadded:: 0.16
"""
return self._getitem(item)
def _item_to_select(self, item):
if isinstance(item, str):
item = [item]
# Convert string to indices
if (
isinstance(item, list | tuple)
and len(item) > 0
and isinstance(item[0], str)
):
select = self._keys_to_idx(item)
elif isinstance(item, slice):
select = item
else:
select = np.atleast_1d(item)
if len(select) == 0:
select = np.array([], int)
return select
def _getitem(
self,
item,
reason="IGNORED",
copy=True,
drop_event_id=True,
select_data=True,
return_indices=False,
):
"""
Select epochs from current object.
Parameters
----------
item: slice, array-like, str, or list
see `__getitem__` for details.
reason: str, list/tuple of str
entry in `drop_log` for unselected epochs
copy: bool
return a copy of the current object
drop_event_id: bool
remove non-existing event-ids after selection
select_data: bool
apply selection to data
(use `select_data=False` if subclasses do not have a
valid `_data` field, or data has already been subselected)
return_indices: bool
return the indices of selected epochs from the original object
in addition to the new `Epochs` objects
Returns
-------
`Epochs` or tuple(Epochs, np.ndarray) if `return_indices` is True
subset of epochs (and optionally array with kept epoch indices)
"""
inst = self.copy() if copy else self
if self._data is not None:
np.copyto(inst._data, self._data, casting="no")
del self
select = inst._item_to_select(item)
has_selection = hasattr(inst, "selection")
if has_selection:
key_selection = inst.selection[select]
drop_log = list(inst.drop_log)
if reason is not None:
_validate_type(reason, (list, tuple, str), "reason")
if isinstance(reason, list | tuple):
for r in reason:
_validate_type(r, str, r)
if isinstance(reason, str):
reason = (reason,)
reason = tuple(reason)
for idx in np.setdiff1d(inst.selection, key_selection):
drop_log[idx] = reason
inst.drop_log = tuple(drop_log)
inst.selection = key_selection
del drop_log
inst.events = np.atleast_2d(inst.events[select])
if inst.metadata is not None:
pd = _check_pandas_installed(strict=False)
if pd:
metadata = inst.metadata.iloc[select]
if has_selection:
metadata.index = inst.selection
else:
metadata = np.array(inst.metadata, "object")[select].tolist()
# will reset the index for us
GetEpochsMixin.metadata.fset(inst, metadata, verbose=False)
if inst.preload and select_data:
# ensure that each Epochs instance owns its own data so we can
# resize later if necessary
inst._data = np.require(inst._data[select], requirements=["O"])
if drop_event_id:
# update event id to reflect new content of inst
inst.event_id = {
k: v for k, v in inst.event_id.items() if v in inst.events[:, 2]
}
if return_indices:
return inst, select
else:
return inst
def _keys_to_idx(self, keys):
"""Find entries in event dict."""
from ..event import match_event_names # avoid circular import
keys = keys if isinstance(keys, list | tuple) else [keys]
try:
# Assume it's a condition name
return np.where(
np.any(
np.array(
[
self.events[:, 2] == self.event_id[k]
for k in match_event_names(self.event_id, keys)
]
),
axis=0,
)
)[0]
except KeyError as err:
# Could we in principle use metadata with these Epochs and keys?
if len(keys) != 1 or self.metadata is None:
# If not, raise original error
raise
msg = str(err.args[0]) # message for KeyError
pd = _check_pandas_installed(strict=False)
# See if the query can be done
if pd:
md = self.metadata if hasattr(self, "_metadata") else None
self._check_metadata(metadata=md)
try:
# Try metadata
vals = (
self.metadata.reset_index()
.query(keys[0], engine="python")
.index.values
)
except Exception as exp:
msg += (
" The epochs.metadata Pandas query did not "
f"yield any results: {exp.args[0]}"
)
else:
return vals
else:
# If not, warn this might be a problem
msg += (
" The epochs.metadata Pandas query could not "
"be performed, consider installing Pandas."
)
raise KeyError(msg)
def __len__(self):
"""Return the number of epochs.
Returns
-------
n_epochs : int
The number of remaining epochs.
Notes
-----
This function only works if bad epochs have been dropped.
Examples
--------
This can be used as::
>>> epochs.drop_bad() # doctest: +SKIP
>>> len(epochs) # doctest: +SKIP
43
>>> len(epochs.events) # doctest: +SKIP
43
"""
from ..epochs import BaseEpochs
if isinstance(self, BaseEpochs) and not self._bad_dropped:
raise RuntimeError(
"Since bad epochs have not been dropped, the "
"length of the Epochs is not known. Load the "
"Epochs with preload=True, or call "
"Epochs.drop_bad(). To find the number "
"of events in the Epochs, use "
"len(Epochs.events)."
)
return len(self.events)
def __iter__(self):
"""Facilitate iteration over epochs.
This method resets the object iteration state to the first epoch.
Notes
-----
This enables the use of this Python pattern::
>>> for epoch in epochs: # doctest: +SKIP
>>> print(epoch) # doctest: +SKIP
Where ``epoch`` is given by successive outputs of
:meth:`mne.Epochs.next`.
"""
self._current = 0
self._current_detrend_picks = self._detrend_picks
return self
def __next__(self, return_event_id=False):
"""Iterate over epoch data.
Parameters
----------
return_event_id : bool
If True, return both the epoch data and an event_id.
Returns
-------
epoch : array of shape (n_channels, n_times)
The epoch data.
event_id : int
The event id. Only returned if ``return_event_id`` is ``True``.
"""
if not hasattr(self, "_current_detrend_picks"):
self.__iter__() # ensure we're ready to iterate
if self.preload:
if self._current >= len(self._data):
self._stop_iter()
epoch = self._data[self._current]
self._current += 1
else:
is_good = False
while not is_good:
if self._current >= len(self.events):
self._stop_iter()
epoch_noproj = self._get_epoch_from_raw(self._current)
epoch_noproj = self._detrend_offset_decim(
epoch_noproj, self._current_detrend_picks
)
epoch = self._project_epoch(epoch_noproj)
self._current += 1
is_good, _ = self._is_good_epoch(epoch)
# If delayed-ssp mode, pass 'virgin' data after rejection decision.
if self._do_delayed_proj:
epoch = epoch_noproj
if not return_event_id:
return epoch
else:
return epoch, self.events[self._current - 1][-1]
def _stop_iter(self):
del self._current
del self._current_detrend_picks
raise StopIteration # signal the end
next = __next__ # originally for Python2, now b/c public
def _check_metadata(self, metadata=None, reset_index=False):
"""Check metadata consistency."""
# reset_index=False will not copy!
if metadata is None:
return
else:
pd = _check_pandas_installed(strict=False)
if pd:
_validate_type(metadata, types=pd.DataFrame, item_name="metadata")
if len(metadata) != len(self.events):
raise ValueError(
"metadata must have the same number of "
f"rows ({len(metadata)}) as events ({len(self.events)})"
)
if reset_index:
if hasattr(self, "selection"):
# makes a copy
metadata = metadata.reset_index(drop=True)
metadata.index = self.selection
else:
metadata = deepcopy(metadata)
else:
_validate_type(metadata, types=list, item_name="metadata")
if reset_index:
metadata = deepcopy(metadata)
return metadata
@property
def metadata(self):
"""Get the metadata."""
return self._metadata
@metadata.setter
@verbose
def metadata(self, metadata, verbose=None):
metadata = self._check_metadata(metadata, reset_index=True)
if metadata is not None:
if _check_pandas_installed(strict=False):
n_col = metadata.shape[1]
else:
n_col = len(metadata[0])
n_col = f" with {n_col} columns"
else:
n_col = ""
if hasattr(self, "_metadata") and self._metadata is not None:
action = "Removing" if metadata is None else "Replacing"
action += " existing"
else:
action = "Not setting" if metadata is None else "Adding"
logger.info(f"{action} metadata{n_col}")
self._metadata = metadata
def _check_decim(info, decim, offset, check_filter=True):
"""Check decimation parameters."""
if decim < 1 or decim != int(decim):
raise ValueError("decim must be an integer > 0")
decim = int(decim)
new_sfreq = info["sfreq"] / float(decim)
offset = int(offset)
if not 0 <= offset < decim:
raise ValueError(
f"decim must be at least 0 and less than {decim}, got {offset}"
)
if check_filter:
lowpass = info["lowpass"]
if decim > 1 and lowpass is None:
warn(
"The measurement information indicates data is not low-pass "
f"filtered. The decim={decim} parameter will result in a "
f"sampling frequency of {new_sfreq} Hz, which can cause "
"aliasing artifacts."
)
elif decim > 1 and new_sfreq < 3 * lowpass:
warn(
"The measurement information indicates a low-pass frequency "
f"of {lowpass} Hz. The decim={decim} parameter will result "
f"in a sampling frequency of {new_sfreq} Hz, which can "
"cause aliasing artifacts."
) # > 50% nyquist lim
return decim, offset, new_sfreq
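# Illustrative usage (editor's addition, not part of the original file),
# assuming ``info`` is an :class:`mne.Info` with ``sfreq=1000.`` and a suitable
# low-pass filter setting:
#
#     >>> decim, offset, new_sfreq = _check_decim(info, decim=4, offset=1)  # doctest: +SKIP
#     >>> new_sfreq  # doctest: +SKIP
#     250.0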
class TimeMixin:
"""Class for time operations on any MNE object that has a time axis."""
def time_as_index(self, times, use_rounding=False):
"""Convert time to indices.
Parameters
----------
times : list-like | float | int
List of numbers or a number representing points in time.
use_rounding : bool
If True, use rounding (instead of truncation) when converting
times to indices. This can help avoid non-unique indices.
Returns
-------
index : ndarray
Indices corresponding to the times supplied.
"""
from ..source_estimate import _BaseSourceEstimate
if isinstance(self, _BaseSourceEstimate):
sfreq = 1.0 / self.tstep
else:
sfreq = self.info["sfreq"]
index = (np.atleast_1d(times) - self.times[0]) * sfreq
if use_rounding:
index = np.round(index)
return index.astype(int)
def _handle_tmin_tmax(self, tmin, tmax):
"""Convert seconds to index into data.
Parameters
----------
tmin : int | float | None
Start time of data to get in seconds.
tmax : int | float | None
End time of data to get in seconds.
Returns
-------
start : int
Integer index into data corresponding to tmin.
stop : int
Integer index into data corresponding to tmax.
"""
_validate_type(
tmin,
types=("numeric", None),
item_name="tmin",
type_name="int, float, None",
)
_validate_type(
tmax,
types=("numeric", None),
item_name="tmax",
type_name="int, float, None",
)
# handle tmin/tmax as start and stop indices into data array
n_times = self.times.size
start = 0 if tmin is None else self.time_as_index(tmin)[0]
stop = n_times if tmax is None else self.time_as_index(tmax)[0]
# truncate start/stop to the closed interval [0, n_times]
start = min(max(0, start), n_times)
stop = min(max(0, stop), n_times)
return start, stop
@property
def times(self):
"""Time vector in seconds."""
return self._times_readonly
def _set_times(self, times):
"""Set self._times_readonly (and make it read only)."""
# naming used to indicate that it shouldn't be
# changed directly, but rather via this method
self._times_readonly = times.copy()
self._times_readonly.flags["WRITEABLE"] = False
class ExtendedTimeMixin(TimeMixin):
"""Class for time operations on epochs/evoked-like MNE objects."""
@property
def tmin(self):
"""First time point."""
return self.times[0]
@property
def tmax(self):
"""Last time point."""
return self.times[-1]
@verbose
def crop(self, tmin=None, tmax=None, include_tmax=True, verbose=None):
"""Crop data to a given time interval.
Parameters
----------
tmin : float | None
Start time of selection in seconds.
tmax : float | None
End time of selection in seconds.
%(include_tmax)s
%(verbose)s
Returns
-------
inst : instance of Raw, Epochs, Evoked, AverageTFR, or SourceEstimate
The cropped time-series object, modified in-place.
Notes
-----
%(notes_tmax_included_by_default)s
"""
t_vars = dict(tmin=tmin, tmax=tmax)
for name, t_var in t_vars.items():
_validate_type(
t_var,
types=("numeric", None),
item_name=name,
)
if tmin is None:
tmin = self.tmin
elif tmin < self.tmin:
warn(
f"tmin is not in time interval. tmin is set to "
f"{type(self)}.tmin ({self.tmin:g} s)"
)
tmin = self.tmin
if tmax is None:
tmax = self.tmax
elif tmax > self.tmax:
warn(
f"tmax is not in time interval. tmax is set to "
f"{type(self)}.tmax ({self.tmax:g} s)"
)
tmax = self.tmax
include_tmax = True
mask = _time_mask(
self.times, tmin, tmax, sfreq=self.info["sfreq"], include_tmax=include_tmax
)
self._set_times(self.times[mask])
self._raw_times = self._raw_times[mask]
self._update_first_last()
self._data = self._data[..., mask]
return self
@verbose
def decimate(self, decim, offset=0, *, verbose=None):
"""Decimate the time-series data.
Parameters
----------
%(decim)s
%(offset_decim)s
%(verbose)s
Returns
-------
inst : MNE-object
The decimated object.
See Also
--------
mne.Epochs.resample
mne.io.Raw.resample
Notes
-----
%(decim_notes)s
If ``decim`` is 1, this method does not copy the underlying data.
.. versionadded:: 0.10.0
References
----------
.. footbibliography::
"""
# if epochs have frequencies, they are not in time (EpochsTFR)
# and so do not need to be checked whether they have been
# appropriately filtered to avoid aliasing
from ..epochs import BaseEpochs
from ..evoked import Evoked
from ..time_frequency import BaseTFR
# This should be the list of classes that inherit
_validate_type(self, (BaseEpochs, Evoked, BaseTFR), "inst")
decim, offset, new_sfreq = _check_decim(
self.info, decim, offset, check_filter=not hasattr(self, "freqs")
)
start_idx = int(round(-self._raw_times[0] * (self.info["sfreq"] * self._decim)))
self._decim *= decim
i_start = start_idx % self._decim + offset
decim_slice = slice(i_start, None, self._decim)
with self.info._unlock():
self.info["sfreq"] = new_sfreq
if self.preload:
if decim != 1:
self._data = self._data[..., decim_slice].copy()
self._raw_times = self._raw_times[decim_slice].copy()
else:
self._data = np.ascontiguousarray(self._data)
self._decim_slice = slice(None)
self._decim = 1
else:
self._decim_slice = decim_slice
self._set_times(self._raw_times[self._decim_slice])
self._update_first_last()
return self
def shift_time(self, tshift, relative=True):
"""Shift time scale in epoched or evoked data.
Parameters
----------
tshift : float
The (absolute or relative) time shift in seconds. If ``relative``
is True, positive tshift increases the time value associated with
each sample, while negative tshift decreases it.
relative : bool
If True, increase or decrease time values by ``tshift`` seconds.
Otherwise, shift the time values such that the time of the first
sample equals ``tshift``.
Returns
-------
epochs : MNE-object
The modified instance.
Notes
-----
This method allows you to shift the *time* values associated with each
data sample by an arbitrary amount. It does *not* resample the signal
or change the *data* values in any way.
"""
_check_preload(self, "shift_time")
start = tshift + (self.times[0] if relative else 0.0)
new_times = start + np.arange(len(self.times)) / self.info["sfreq"]
self._set_times(new_times)
self._update_first_last()
return self
def _update_first_last(self):
"""Update self.first and self.last (sample indices)."""
from ..dipole import DipoleFixed
from ..evoked import Evoked
if isinstance(self, Evoked | DipoleFixed):
self.first = int(round(self.times[0] * self.info["sfreq"]))
self.last = len(self.times) + self.first - 1
def _prepare_write_metadata(metadata):
"""Convert metadata to JSON for saving."""
if metadata is not None:
if not isinstance(metadata, list):  # Pandas DataFrame
metadata = metadata.reset_index().to_json(orient="records")
else:  # list of dicts
metadata = json.dumps(metadata)
assert isinstance(metadata, str)
return metadata
def _prepare_read_metadata(metadata):
"""Convert saved metadata back from JSON."""
if metadata is not None:
pd = _check_pandas_installed(strict=False)
# use json.loads because this preserves ordering
# (which is necessary for round-trip equivalence)
metadata = json.loads(metadata, object_pairs_hook=OrderedDict)
assert isinstance(metadata, list)
if pd:
metadata = pd.DataFrame.from_records(metadata)
if "index" in metadata.columns:
metadata.set_index("index", inplace=True)
assert isinstance(metadata, pd.DataFrame)
return metadata
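# Illustrative round trip (editor's addition, not part of the original file):
# metadata is serialized to JSON on write and restored to a DataFrame (when
# pandas is available) on read.
#
#     >>> import pandas as pd                              # doctest: +SKIP
#     >>> md = pd.DataFrame({"cond": ["a", "b"]})          # doctest: +SKIP
#     >>> s = _prepare_write_metadata(md)                  # doctest: +SKIP
#     >>> _prepare_read_metadata(s)["cond"].tolist()       # doctest: +SKIP
#     ['a', 'b']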

1119
mne/utils/numerics.py Normal file

File diff suppressed because it is too large Load Diff

213
mne/utils/progressbar.py Normal file
View File

@@ -0,0 +1,213 @@
"""Some utility functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import logging
import os
import os.path as op
import tempfile
import time
from collections.abc import Iterable
from threading import Thread
import numpy as np
from ._logging import logger
from .check import _check_option
from .config import get_config
class ProgressBar:
"""Generate a command-line progressbar.
Parameters
----------
iterable : iterable | int | None
The iterable to use. Can also be an int for backward compatibility
(acts like ``max_value``).
initial_value : int
Initial value of process, useful when resuming process from a specific
value, defaults to 0.
mesg : str
Message to include at end of progress bar.
max_total_width : int | str
Maximum total message width. Can use "auto" (default) to try to set
a sane value based on the current terminal width.
max_value : int | None
The max value. If None, the length of ``iterable`` will be used.
which_tqdm : str | None
Which tqdm module to use. Can be "tqdm", "tqdm.notebook", or "off".
Defaults to ``None``, which uses the value of the MNE_TQDM environment
variable, or ``"tqdm.auto"`` if that is not set.
**kwargs : dict
Additional keyword arguments for tqdm.
"""
def __init__(
self,
iterable=None,
initial_value=0,
mesg=None,
max_total_width="auto",
max_value=None,
*,
which_tqdm=None,
**kwargs,
):
# The following mimics this, but with configurable module to use
# from ..externals.tqdm import auto
import tqdm
if which_tqdm is None:
which_tqdm = get_config("MNE_TQDM", "tqdm.auto")
_check_option(
"MNE_TQDM", which_tqdm[:5], ("tqdm", "tqdm.", "off"), extra="beginning"
)
logger.debug(f"Using ProgressBar with {which_tqdm}")
if which_tqdm not in ("tqdm", "off"):
try:
__import__(which_tqdm)
except Exception as exc:
raise ValueError(
f"Unknown tqdm backend {repr(which_tqdm)}, got: {exc}"
) from None
tqdm = getattr(tqdm, which_tqdm.split(".", 1)[1])
tqdm = tqdm.tqdm
defaults = dict(
leave=True,
mininterval=0.016,
miniters=1,
smoothing=0.05,
bar_format="{percentage:3.0f}%|{bar}| {desc} : {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt:>11}{postfix}]", # noqa: E501
)
for key, val in defaults.items():
if key not in kwargs:
kwargs.update({key: val})
if isinstance(iterable, Iterable):
self.iterable = iterable
if max_value is None:
self.max_value = len(iterable)
else:
self.max_value = max_value
else: # ignore max_value then
self.max_value = int(iterable)
self.iterable = None
if max_total_width == "auto":
max_total_width = None # tqdm's auto
with tempfile.NamedTemporaryFile("wb", prefix="tmp_mne_prog") as tf:
self._mmap_fname = tf.name
del tf # should remove the file
self._mmap = None
disable = logger.level > logging.INFO or which_tqdm == "off"
self._tqdm = tqdm(
iterable=self.iterable,
desc=mesg,
total=self.max_value,
initial=initial_value,
ncols=max_total_width,
disable=disable,
**kwargs,
)
def update(self, cur_value):
"""Update progressbar with current value of process.
Parameters
----------
cur_value : number
Current value of process. Should be <= max_value (but this is not
enforced). The percent of the progressbar will be computed as
``(cur_value / max_value) * 100``.
"""
self.update_with_increment_value(cur_value - self._tqdm.n)
def update_with_increment_value(self, increment_value):
"""Update progressbar with an increment.
Parameters
----------
increment_value : int
Value of the increment of process. The percent of the progressbar
will be computed as
``((self.cur_value + increment_value) / max_value) * 100``.
"""
try:
self._tqdm.update(increment_value)
except TypeError: # can happen during GC on Windows
pass
def __iter__(self):
"""Iterate to auto-increment the pbar with 1."""
yield from self._tqdm
def subset(self, idx):
"""Make a joblib-friendly index subset updater.
Parameters
----------
idx : ndarray
List of indices for this subset.
Returns
-------
updater : instance of PBSubsetUpdater
Class with a ``.update(ii)`` method.
"""
return _PBSubsetUpdater(self, idx)
def __enter__(self): # noqa: D105
# This should only be used with pb.subset and parallelization
if op.isfile(self._mmap_fname):
os.remove(self._mmap_fname)
# prevent corner cases where self.max_value == 0
self._mmap = np.memmap(
self._mmap_fname, bool, "w+", shape=max(self.max_value, 1)
)
self.update(0) # must be zero as we just created the memmap
# We need to control how the pickled bars exit: remove print statements
self._thread = _UpdateThread(self)
self._thread.start()
return self
def __exit__(self, type_, value, traceback): # noqa: D105
# Restore exit behavior for our one from the main thread
self.update(self._mmap.sum())
self._tqdm.close()
self._thread._mne_run = False
self._thread.join()
self._mmap = None
if op.isfile(self._mmap_fname):
try:
os.remove(self._mmap_fname)
# happens on Windows sometimes
except PermissionError: # pragma: no cover
pass
def __del__(self):
"""Ensure output completes."""
if getattr(self, "_tqdm", None) is not None:
self._tqdm.close()
class _UpdateThread(Thread):
def __init__(self, pb):
super().__init__(daemon=True)
self._mne_run = True
self._mne_pb = pb
def run(self):
while self._mne_run:
self._mne_pb.update(self._mne_pb._mmap.sum())
time.sleep(1.0 / 30.0) # 30 Hz refresh is plenty
class _PBSubsetUpdater:
def __init__(self, pb, idx):
self.mmap = pb._mmap
self.idx = idx
def update(self, ii):
self.mmap[self.idx[ii - 1]] = True
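# Illustrative usage (editor's addition, not part of the original file):
# wrapping an iterable yields a tqdm progress bar that honors the ``MNE_TQDM``
# configuration value.
#
#     >>> for _ in ProgressBar(range(5), mesg="Working"):  # doctest: +SKIP
#     ...     pass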

104
mne/utils/spectrum.py Normal file
View File

@@ -0,0 +1,104 @@
"""Utility functions for spectral and spectrotemporal analysis."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from inspect import currentframe, getargvalues, signature
from ..utils import warn
def _get_instance_type_string(inst):
"""Get string representation of the originating instance type."""
from numpy import ndarray
from ..epochs import BaseEpochs
from ..evoked import Evoked, EvokedArray
from ..io import BaseRaw
parent_classes = inst._inst_type.__bases__
if BaseRaw in parent_classes:
inst_type_str = "Raw"
elif BaseEpochs in parent_classes:
inst_type_str = "Epochs"
elif inst._inst_type in (Evoked, EvokedArray):
inst_type_str = "Evoked"
elif inst._inst_type == ndarray:
inst_type_str = "Array"
else:
raise RuntimeError(
f"Unknown instance type {inst._inst_type} in {type(inst).__name__}"
)
return inst_type_str
def _pop_with_fallback(mapping, key, fallback_fun):
"""Pop from a dict and fallback to a function parameter's default value."""
fallback = signature(fallback_fun).parameters[key].default
return mapping.pop(key, fallback)
def _update_old_psd_kwargs(kwargs):
"""Modify passed-in kwargs to match new API.
NOTE: using plot_raw_psd as fallback (even for epochs) is fine because
their kwargs are the same (and will stay the same: both are @legacy funcs).
"""
from ..viz import plot_raw_psd as fallback_fun
may_change = ("axes", "alpha", "ci_alpha", "amplitude", "ci")
for kwarg in may_change:
if kwarg in kwargs:
warn(
"The legacy plot_psd() method got an unexpected keyword argument "
f"'{kwarg}', which is a parameter of Spectrum.plot(). Try rewriting as "
f"object.compute_psd(...).plot(..., {kwarg}=<whatever>)."
)
kwargs.setdefault("axes", _pop_with_fallback(kwargs, "ax", fallback_fun))
kwargs.setdefault("alpha", _pop_with_fallback(kwargs, "line_alpha", fallback_fun))
kwargs.setdefault(
"ci_alpha", _pop_with_fallback(kwargs, "area_alpha", fallback_fun)
)
est = _pop_with_fallback(kwargs, "estimate", fallback_fun)
kwargs.setdefault("amplitude", est == "amplitude")
area_mode = _pop_with_fallback(kwargs, "area_mode", fallback_fun)
kwargs.setdefault("ci", "sd" if area_mode == "std" else area_mode)
def _split_psd_kwargs(*, plot_fun=None, kwargs=None):
from ..io import BaseRaw
from ..time_frequency import Spectrum
# if no kwargs supplied, get them from calling func
if kwargs is None:
frame = currentframe().f_back
arginfo = getargvalues(frame)
kwargs = {k: v for k, v in arginfo.locals.items() if k in arginfo.args}
if arginfo.keywords is not None: # add in **method_kw
kwargs.update(arginfo.locals[arginfo.keywords])
# for compatibility with `plot_raw_psd`, `plot_epochs_psd` and
# `plot_epochs_psd_topomap` functions (not just the instance methods/mixin)
if "raw" in kwargs:
kwargs["self"] = kwargs.pop("raw")
elif "epochs" in kwargs:
kwargs["self"] = kwargs.pop("epochs")
# `reject_by_annotation` not needed for Epochs or Evoked
if not isinstance(kwargs.pop("self", None), BaseRaw):
kwargs.pop("reject_by_annotation", None)
# handle API changes from .plot_psd(...) to .compute_psd(...).plot(...)
if plot_fun is Spectrum.plot:
_update_old_psd_kwargs(kwargs)
# split off the plotting kwargs
plot_kwargs = {
k: v
for k, v in kwargs.items()
if k in signature(plot_fun).parameters and k != "picks"
}
for k in plot_kwargs:
del kwargs[k]
return kwargs, plot_kwargs