initial commit
This commit is contained in:
213
mne/utils/progressbar.py
Normal file
213
mne/utils/progressbar.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Some utility functions."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import os.path as op
|
||||
import tempfile
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._logging import logger
|
||||
from .check import _check_option
|
||||
from .config import get_config
|
||||
|
||||
|
||||
class ProgressBar:
|
||||
"""Generate a command-line progressbar.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
iterable : iterable | int | None
|
||||
The iterable to use. Can also be an int for backward compatibility
|
||||
(acts like ``max_value``).
|
||||
initial_value : int
|
||||
Initial value of process, useful when resuming process from a specific
|
||||
value, defaults to 0.
|
||||
mesg : str
|
||||
Message to include at end of progress bar.
|
||||
max_total_width : int | str
|
||||
Maximum total message width. Can use "auto" (default) to try to set
|
||||
a sane value based on the current terminal width.
|
||||
max_value : int | None
|
||||
The max value. If None, the length of ``iterable`` will be used.
|
||||
which_tqdm : str | None
|
||||
Which tqdm module to use. Can be "tqdm", "tqdm.notebook", or "off".
|
||||
Defaults to ``None``, which uses the value of the MNE_TQDM environment
|
||||
variable, or ``"tqdm.auto"`` if that is not set.
|
||||
**kwargs : dict
|
||||
Additional keyword arguments for tqdm.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
iterable=None,
|
||||
initial_value=0,
|
||||
mesg=None,
|
||||
max_total_width="auto",
|
||||
max_value=None,
|
||||
*,
|
||||
which_tqdm=None,
|
||||
**kwargs,
|
||||
):
|
||||
# The following mimics this, but with configurable module to use
|
||||
# from ..externals.tqdm import auto
|
||||
import tqdm
|
||||
|
||||
if which_tqdm is None:
|
||||
which_tqdm = get_config("MNE_TQDM", "tqdm.auto")
|
||||
_check_option(
|
||||
"MNE_TQDM", which_tqdm[:5], ("tqdm", "tqdm.", "off"), extra="beginning"
|
||||
)
|
||||
logger.debug(f"Using ProgressBar with {which_tqdm}")
|
||||
if which_tqdm not in ("tqdm", "off"):
|
||||
try:
|
||||
__import__(which_tqdm)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"Unknown tqdm backend {repr(which_tqdm)}, got: {exc}"
|
||||
) from None
|
||||
tqdm = getattr(tqdm, which_tqdm.split(".", 1)[1])
|
||||
tqdm = tqdm.tqdm
|
||||
defaults = dict(
|
||||
leave=True,
|
||||
mininterval=0.016,
|
||||
miniters=1,
|
||||
smoothing=0.05,
|
||||
bar_format="{percentage:3.0f}%|{bar}| {desc} : {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt:>11}{postfix}]", # noqa: E501
|
||||
)
|
||||
for key, val in defaults.items():
|
||||
if key not in kwargs:
|
||||
kwargs.update({key: val})
|
||||
if isinstance(iterable, Iterable):
|
||||
self.iterable = iterable
|
||||
if max_value is None:
|
||||
self.max_value = len(iterable)
|
||||
else:
|
||||
self.max_value = max_value
|
||||
else: # ignore max_value then
|
||||
self.max_value = int(iterable)
|
||||
self.iterable = None
|
||||
if max_total_width == "auto":
|
||||
max_total_width = None # tqdm's auto
|
||||
with tempfile.NamedTemporaryFile("wb", prefix="tmp_mne_prog") as tf:
|
||||
self._mmap_fname = tf.name
|
||||
del tf # should remove the file
|
||||
self._mmap = None
|
||||
disable = logger.level > logging.INFO or which_tqdm == "off"
|
||||
self._tqdm = tqdm(
|
||||
iterable=self.iterable,
|
||||
desc=mesg,
|
||||
total=self.max_value,
|
||||
initial=initial_value,
|
||||
ncols=max_total_width,
|
||||
disable=disable,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def update(self, cur_value):
|
||||
"""Update progressbar with current value of process.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cur_value : number
|
||||
Current value of process. Should be <= max_value (but this is not
|
||||
enforced). The percent of the progressbar will be computed as
|
||||
``(cur_value / max_value) * 100``.
|
||||
"""
|
||||
self.update_with_increment_value(cur_value - self._tqdm.n)
|
||||
|
||||
def update_with_increment_value(self, increment_value):
|
||||
"""Update progressbar with an increment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
increment_value : int
|
||||
Value of the increment of process. The percent of the progressbar
|
||||
will be computed as
|
||||
``(self.cur_value + increment_value / max_value) * 100``.
|
||||
"""
|
||||
try:
|
||||
self._tqdm.update(increment_value)
|
||||
except TypeError: # can happen during GC on Windows
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate to auto-increment the pbar with 1."""
|
||||
yield from self._tqdm
|
||||
|
||||
def subset(self, idx):
|
||||
"""Make a joblib-friendly index subset updater.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
idx : ndarray
|
||||
List of indices for this subset.
|
||||
|
||||
Returns
|
||||
-------
|
||||
updater : instance of PBSubsetUpdater
|
||||
Class with a ``.update(ii)`` method.
|
||||
"""
|
||||
return _PBSubsetUpdater(self, idx)
|
||||
|
||||
def __enter__(self): # noqa: D105
|
||||
# This should only be used with pb.subset and parallelization
|
||||
if op.isfile(self._mmap_fname):
|
||||
os.remove(self._mmap_fname)
|
||||
# prevent corner cases where self.max_value == 0
|
||||
self._mmap = np.memmap(
|
||||
self._mmap_fname, bool, "w+", shape=max(self.max_value, 1)
|
||||
)
|
||||
self.update(0) # must be zero as we just created the memmap
|
||||
|
||||
# We need to control how the pickled bars exit: remove print statements
|
||||
self._thread = _UpdateThread(self)
|
||||
self._thread.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, type_, value, traceback): # noqa: D105
|
||||
# Restore exit behavior for our one from the main thread
|
||||
self.update(self._mmap.sum())
|
||||
self._tqdm.close()
|
||||
self._thread._mne_run = False
|
||||
self._thread.join()
|
||||
self._mmap = None
|
||||
if op.isfile(self._mmap_fname):
|
||||
try:
|
||||
os.remove(self._mmap_fname)
|
||||
# happens on Windows sometimes
|
||||
except PermissionError: # pragma: no cover
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
"""Ensure output completes."""
|
||||
if getattr(self, "_tqdm", None) is not None:
|
||||
self._tqdm.close()
|
||||
|
||||
|
||||
class _UpdateThread(Thread):
|
||||
def __init__(self, pb):
|
||||
super().__init__(daemon=True)
|
||||
self._mne_run = True
|
||||
self._mne_pb = pb
|
||||
|
||||
def run(self):
|
||||
while self._mne_run:
|
||||
self._mne_pb.update(self._mne_pb._mmap.sum())
|
||||
time.sleep(1.0 / 30.0) # 30 Hz refresh is plenty
|
||||
|
||||
|
||||
class _PBSubsetUpdater:
|
||||
def __init__(self, pb, idx):
|
||||
self.mmap = pb._mmap
|
||||
self.idx = idx
|
||||
|
||||
def update(self, ii):
|
||||
self.mmap[self.idx[ii - 1]] = True
|
||||
Reference in New Issue
Block a user