initial commit

This commit is contained in:
2025-08-19 09:13:22 -07:00
parent 28464811d6
commit 0977a3e14d
820 changed files with 1003358 additions and 2 deletions

4282
mne/viz/_3d.py Normal file

File diff suppressed because it is too large Load Diff

184
mne/viz/_3d_overlay.py Normal file
View File

@@ -0,0 +1,184 @@
"""Classes to handle overlapping surfaces."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from collections import OrderedDict
import numpy as np
from ..utils import logger
class _Overlay:
def __init__(self, scalars, colormap, rng, opacity, name):
self._scalars = scalars
self._colormap = colormap
assert rng is not None
self._rng = rng
self._opacity = opacity
self._name = name
def to_colors(self):
from matplotlib.colors import Colormap, ListedColormap
from ._3d import _get_cmap
if isinstance(self._colormap, str):
cmap = _get_cmap(self._colormap)
elif isinstance(self._colormap, Colormap):
cmap = self._colormap
else:
cmap = ListedColormap(
self._colormap / 255.0, name=str(type(self._colormap))
)
logger.debug(
f"Color mapping {repr(self._name)} with {cmap.name} "
f"colormap and range {self._rng}"
)
rng = self._rng
assert rng is not None
scalars = self._norm(rng)
colors = cmap(scalars)
if self._opacity is not None:
colors[:, 3] *= self._opacity
return colors
def _norm(self, rng):
if rng[0] == rng[1]:
factor = 1 if rng[0] == 0 else 1e-6 * rng[0]
else:
factor = rng[1] - rng[0]
return (self._scalars - rng[0]) / factor
class _LayeredMesh:
def __init__(self, renderer, vertices, triangles, normals):
self._renderer = renderer
self._vertices = vertices
self._triangles = triangles
self._normals = normals
self._polydata = None
self._actor = None
self._is_mapped = False
self._current_colors = None
self._cached_colors = None
self._overlays = OrderedDict()
self._default_scalars = np.ones(vertices.shape)
self._default_scalars_name = "Data"
def map(self):
kwargs = {
"color": None,
"pickable": True,
"rgba": True,
}
mesh_data = self._renderer.mesh(
x=self._vertices[:, 0],
y=self._vertices[:, 1],
z=self._vertices[:, 2],
triangles=self._triangles,
normals=self._normals,
scalars=self._default_scalars,
**kwargs,
)
self._actor, self._polydata = mesh_data
self._is_mapped = True
def _compute_over(self, B, A):
assert A.ndim == B.ndim == 2
assert A.shape[1] == B.shape[1] == 4
A_w = A[:, 3:] # * 1
B_w = B[:, 3:] * (1 - A_w)
C = A.copy()
C[:, :3] *= A_w
C[:, :3] += B[:, :3] * B_w
C[:, 3:] += B_w
C[:, :3] /= C[:, 3:]
return np.clip(C, 0, 1, out=C)
def _compose_overlays(self):
B = cache = None
for overlay in self._overlays.values():
A = overlay.to_colors()
if B is None:
B = A
else:
cache = B
B = self._compute_over(cache, A)
return B, cache
def add_overlay(self, scalars, colormap, rng, opacity, name):
overlay = _Overlay(
scalars=scalars,
colormap=colormap,
rng=rng,
opacity=opacity,
name=name,
)
self._overlays[name] = overlay
colors = overlay.to_colors()
if self._current_colors is None:
self._current_colors = colors
else:
# save previous colors to cache
self._cached_colors = self._current_colors
self._current_colors = self._compute_over(self._cached_colors, colors)
# apply the texture
self._apply()
def remove_overlay(self, names):
to_update = False
if not isinstance(names, list):
names = [names]
for name in names:
if name in self._overlays:
del self._overlays[name]
to_update = True
if to_update:
self.update()
def _apply(self):
if self._current_colors is None or self._renderer is None:
return
self._polydata[self._default_scalars_name] = self._current_colors
def update(self, colors=None):
if colors is not None and self._cached_colors is not None:
self._current_colors = self._compute_over(self._cached_colors, colors)
else:
self._current_colors, self._cached_colors = self._compose_overlays()
self._apply()
def _clean(self):
mapper = self._actor.GetMapper()
mapper.SetLookupTable(None)
self._actor.SetMapper(None)
self._actor = None
self._polydata = None
self._renderer = None
def update_overlay(self, name, scalars=None, colormap=None, opacity=None, rng=None):
overlay = self._overlays.get(name, None)
if overlay is None:
return
if scalars is not None:
overlay._scalars = scalars
if colormap is not None:
overlay._colormap = colormap
if opacity is not None:
overlay._opacity = opacity
if rng is not None:
overlay._rng = rng
# partial update: use cache if possible
if name == list(self._overlays.keys())[-1]:
self.update(colors=overlay.to_colors())
else: # full update
self.update()

8
mne/viz/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
"""Visualization routines."""
import lazy_loader as lazy
(__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__)

177
mne/viz/__init__.pyi Normal file
View File

@@ -0,0 +1,177 @@
__all__ = [
"Brain",
"ClickableImage",
"EvokedField",
"Figure3D",
"_RAW_CLIP_DEF",
"_get_plot_ch_type",
"_get_presser",
"_plot_sources",
"_scraper",
"add_background_image",
"adjust_axes",
"backends",
"centers_to_edges",
"circular_layout",
"close_3d_figure",
"close_all_3d_figures",
"compare_fiff",
"concatenate_images",
"create_3d_figure",
"eyetracking",
"get_3d_backend",
"get_brain_class",
"get_browser_backend",
"iter_topography",
"link_brains",
"mne_analyze_colormap",
"plot_alignment",
"plot_arrowmap",
"plot_bem",
"plot_brain_colorbar",
"plot_bridged_electrodes",
"plot_ch_adjacency",
"plot_channel_labels_circle",
"plot_chpi_snr",
"plot_compare_evokeds",
"plot_cov",
"plot_csd",
"plot_dipole_amplitudes",
"plot_dipole_locations",
"plot_drop_log",
"plot_epochs",
"plot_epochs_image",
"plot_epochs_psd",
"plot_epochs_psd_topomap",
"plot_events",
"plot_evoked",
"plot_evoked_field",
"plot_evoked_image",
"plot_evoked_joint",
"plot_evoked_topo",
"plot_evoked_topomap",
"plot_evoked_white",
"plot_filter",
"plot_head_positions",
"plot_ica_components",
"plot_ica_overlay",
"plot_ica_properties",
"plot_ica_scores",
"plot_ica_sources",
"plot_ideal_filter",
"plot_layout",
"plot_montage",
"plot_projs_joint",
"plot_projs_topomap",
"plot_raw",
"plot_raw_psd",
"plot_raw_psd_topo",
"plot_regression_weights",
"plot_sensors",
"plot_snr_estimate",
"plot_source_estimates",
"plot_source_spectrogram",
"plot_sparse_source_estimates",
"plot_tfr_topomap",
"plot_topo_image_epochs",
"plot_topomap",
"plot_vector_source_estimates",
"plot_volume_source_estimates",
"set_3d_backend",
"set_3d_options",
"set_3d_title",
"set_3d_view",
"set_browser_backend",
"snapshot_brain_montage",
"ui_events",
"use_3d_backend",
"use_browser_backend",
]
from . import _scraper, backends, eyetracking, ui_events
from ._3d import (
link_brains,
plot_alignment,
plot_brain_colorbar,
plot_dipole_locations,
plot_evoked_field,
plot_head_positions,
plot_source_estimates,
plot_sparse_source_estimates,
plot_vector_source_estimates,
plot_volume_source_estimates,
set_3d_options,
snapshot_brain_montage,
)
from ._brain import Brain
from ._figure import get_browser_backend, set_browser_backend, use_browser_backend
from ._proj import plot_projs_joint
from .backends._abstract import Figure3D
from .backends.renderer import (
close_3d_figure,
close_all_3d_figures,
create_3d_figure,
get_3d_backend,
get_brain_class,
set_3d_backend,
set_3d_title,
set_3d_view,
use_3d_backend,
)
from .circle import circular_layout, plot_channel_labels_circle
from .epochs import plot_drop_log, plot_epochs, plot_epochs_image, plot_epochs_psd
from .evoked import (
plot_compare_evokeds,
plot_evoked,
plot_evoked_image,
plot_evoked_joint,
plot_evoked_topo,
plot_evoked_white,
plot_snr_estimate,
)
from .evoked_field import EvokedField
from .ica import (
_plot_sources,
plot_ica_overlay,
plot_ica_properties,
plot_ica_scores,
plot_ica_sources,
)
from .misc import (
_get_presser,
adjust_axes,
plot_bem,
plot_chpi_snr,
plot_cov,
plot_csd,
plot_dipole_amplitudes,
plot_events,
plot_filter,
plot_ideal_filter,
plot_source_spectrogram,
)
from .montage import plot_montage
from .raw import _RAW_CLIP_DEF, plot_raw, plot_raw_psd, plot_raw_psd_topo
from .topo import iter_topography, plot_topo_image_epochs
from .topomap import (
plot_arrowmap,
plot_bridged_electrodes,
plot_ch_adjacency,
plot_epochs_psd_topomap,
plot_evoked_topomap,
plot_ica_components,
plot_layout,
plot_projs_topomap,
plot_regression_weights,
plot_tfr_topomap,
plot_topomap,
)
from .utils import (
ClickableImage,
_get_plot_ch_type,
add_background_image,
centers_to_edges,
compare_fiff,
concatenate_images,
mne_analyze_colormap,
plot_sensors,
)

View File

@@ -0,0 +1,11 @@
"""Plot Cortex Surface."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from ._brain import Brain, _LayeredMesh
from ._scraper import _BrainScraper
from ._linkviewer import _LinkViewer
__all__ = ["Brain"]

4141
mne/viz/_brain/_brain.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,98 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ...utils import warn
from ..ui_events import link
class _LinkViewer:
"""Class to link multiple Brain objects."""
def __init__(self, brains, time=True, camera=False, colorbar=True, picking=False):
self.brains = brains
self.leader = self.brains[0] # select a brain as leader
# check time infos
times = [brain._times for brain in brains]
if time and not all(np.allclose(x, times[0]) for x in times):
warn("stc.times do not match, not linking time")
time = False
if camera:
self.link_cameras()
events_to_link = []
if time:
events_to_link.append("time_change")
if colorbar:
events_to_link.append("colormap_range")
for brain in brains[1:]:
link(self.leader, brain, include_events=events_to_link)
if picking:
def _func_add(*args, **kwargs):
for brain in self.brains:
brain._add_vertex_glyph2(*args, **kwargs)
brain.plotter.update()
def _func_remove(*args, **kwargs):
for brain in self.brains:
brain._remove_vertex_glyph2(*args, **kwargs)
# save initial picked points
initial_points = dict()
for hemi in ("lh", "rh"):
initial_points[hemi] = set()
for brain in self.brains:
initial_points[hemi] |= set(brain.picked_points[hemi])
# link the viewers
for brain in self.brains:
brain.clear_glyphs()
brain._add_vertex_glyph2 = brain._add_vertex_glyph
brain._add_vertex_glyph = _func_add
brain._remove_vertex_glyph2 = brain._remove_vertex_glyph
brain._remove_vertex_glyph = _func_remove
# link the initial points
for hemi in initial_points.keys():
if hemi in brain._layered_meshes:
mesh = brain._layered_meshes[hemi]._polydata
for vertex_id in initial_points[hemi]:
self.leader._add_vertex_glyph(hemi, mesh, vertex_id)
def set_fmin(self, value):
self.leader.update_lut(fmin=value)
def set_fmid(self, value):
self.leader.update_lut(fmid=value)
def set_fmax(self, value):
self.leader.update_lut(fmax=value)
def set_time_point(self, value):
self.leader.set_time_point(value)
def set_playback_speed(self, value):
self.leader.set_playback_speed(value)
def toggle_playback(self):
self.leader.toggle_playback()
def link_cameras(self):
from ..backends._pyvista import _add_camera_callback
def _update_camera(vtk_picker, event):
for brain in self.brains:
brain.plotter.update()
camera = self.leader.plotter.camera
_add_camera_callback(camera, _update_camera)
for brain in self.brains:
for renderer in brain.plotter.renderers:
renderer.camera = camera

102
mne/viz/_brain/_scraper.py Normal file
View File

@@ -0,0 +1,102 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import os
import os.path as op
from ._brain import Brain
class _BrainScraper:
"""Scrape Brain objects."""
def __repr__(self):
return "<BrainScraper>"
def __call__(self, block, block_vars, gallery_conf):
rst = ""
for brain in list(block_vars["example_globals"].values()):
# Only need to process if it's a brain with a time_viewer
# with traces on and shown in the same window, otherwise
# PyVista and matplotlib scrapers can just do the work
if (not isinstance(brain, Brain)) or brain._closed:
continue
from matplotlib import animation
from matplotlib import pyplot as plt
from sphinx_gallery.scrapers import matplotlib_scraper
img = brain.screenshot(time_viewer=True)
dpi = 100.0
figsize = (img.shape[1] / dpi, img.shape[0] / dpi)
fig = plt.figure(figsize=figsize, dpi=dpi)
ax = plt.Axes(fig, [0, 0, 1, 1])
fig.add_axes(ax)
img = ax.imshow(img)
movie_key = "# brain.save_movie"
if movie_key in block[1]:
kwargs = dict()
# Parse our parameters
lines = block[1].splitlines()
for li, line in enumerate(block[1].splitlines()):
if line.startswith(movie_key):
line = line[len(movie_key) :].replace("..., ", "")
for ni in range(1, 5): # should be enough
if len(lines) > li + ni and lines[li + ni].startswith(
"# "
):
line = line + lines[li + ni][1:].strip()
else:
break
assert line.startswith("(") and line.endswith(")")
kwargs.update(eval(f"dict{line}")) # nosec B307
for key, default in [
("time_dilation", 4),
("framerate", 24),
("tmin", None),
("tmax", None),
("interpolation", None),
("time_viewer", False),
]:
if key not in kwargs:
kwargs[key] = default
kwargs.pop("filename", None) # always omit this one
if brain.time_viewer:
assert kwargs["time_viewer"], "Must use time_viewer=True"
frames = brain._make_movie_frames(callback=None, **kwargs)
# Turn them into an animation
def func(frame):
img.set_data(frame)
return [img]
anim = animation.FuncAnimation(
fig,
func=func,
frames=frames,
blit=True,
interval=1000.0 / kwargs["framerate"],
)
# Out to sphinx-gallery:
#
# 1. A static image but hide it (useful for carousel)
if animation.FFMpegWriter.isAvailable():
writer = "ffmpeg"
elif animation.ImageMagickWriter.isAvailable():
writer = "imagemagick"
else:
writer = None
static_fname = next(block_vars["image_path_iterator"])
static_fname = static_fname[:-4] + ".gif"
anim.save(static_fname, writer=writer, dpi=dpi)
rel_fname = op.relpath(static_fname, gallery_conf["src_dir"])
rel_fname = rel_fname.replace(os.sep, "/").lstrip("/")
rst += f"\n.. image:: /{rel_fname}\n :class: hidden\n"
# 2. An animation that will be embedded and visible
block_vars["example_globals"]["_brain_anim_"] = anim
brain.close()
rst += matplotlib_scraper(block, block_vars, gallery_conf)
return rst

185
mne/viz/_brain/colormap.py Normal file
View File

@@ -0,0 +1,185 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
def create_lut(cmap, n_colors=256, center=None):
"""Return a colormap suitable for setting as a LUT."""
from .._3d import _get_cmap
assert not (isinstance(cmap, str) and cmap == "auto")
cmap = _get_cmap(cmap)
lut = np.round(cmap(np.linspace(0, 1, n_colors)) * 255.0).astype(np.int64)
return lut
def scale_sequential_lut(lut_table, fmin, fmid, fmax):
"""Scale a sequential colormap."""
assert fmin <= fmid <= fmax # guaranteed by calculate_lut
lut_table_new = lut_table.copy()
n_colors = lut_table.shape[0]
n_colors2 = n_colors // 2
if fmax == fmin:
fmid_idx = 0
else:
fmid_idx = np.clip(
int(np.round(n_colors * ((fmid - fmin) / (fmax - fmin))) - 1),
0,
n_colors - 2,
)
n_left = fmid_idx + 1
n_right = n_colors - n_left
for i in range(4):
lut_table_new[: fmid_idx + 1, i] = np.interp(
np.linspace(0, n_colors2 - 1, n_left), np.arange(n_colors), lut_table[:, i]
)
lut_table_new[fmid_idx + 1 :, i] = np.interp(
np.linspace(n_colors - 1, n_colors2, n_right)[::-1],
np.arange(n_colors),
lut_table[:, i],
)
return lut_table_new
def get_fill_colors(cols, n_fill):
"""Get the fill colors for the middle of divergent colormaps."""
steps = np.linalg.norm(np.diff(cols[:, :3].astype(float), axis=0), axis=1)
ind = np.flatnonzero(steps[1:-1] > steps[[0, -1]].mean() * 3)
if ind.size > 0:
# choose the two colors between which there is the large step
ind = ind[0] + 1
fillcols = np.r_[
np.tile(cols[ind, :], (n_fill // 2, 1)),
np.tile(cols[ind + 1, :], (n_fill - n_fill // 2, 1)),
]
else:
# choose a color from the middle of the colormap
fillcols = np.tile(cols[int(cols.shape[0] / 2), :], (n_fill, 1))
return fillcols
def calculate_lut(lut_table, alpha, fmin, fmid, fmax, center=None, transparent=True):
"""Transparent color map calculation.
A colormap may be sequential or divergent. When the colormap is
divergent indicate this by providing a value for 'center'. The
meanings of fmin, fmid and fmax are different for sequential and
divergent colormaps. A sequential colormap is characterised by::
[fmin, fmid, fmax]
where fmin and fmax define the edges of the colormap and fmid
will be the value mapped to the center of the originally chosen colormap.
A divergent colormap is characterised by::
[center-fmax, center-fmid, center-fmin, center,
center+fmin, center+fmid, center+fmax]
i.e., values between center-fmin and center+fmin will not be shown
while center-fmid will map to the fmid of the first half of the
original colormap and center-fmid to the fmid of the second half.
Parameters
----------
lim_cmap : Colormap
Color map obtained from _process_mapdata.
alpha : float
Alpha value to apply globally to the overlay. Has no effect with mpl
backend.
fmin : float
Min value in colormap.
fmid : float
Intermediate value in colormap.
fmax : float
Max value in colormap.
center : float or None
If not None, center of a divergent colormap, changes the meaning of
fmin, fmax and fmid.
transparent : boolean
if True: use a linear transparency between fmin and fmid and make
values below fmin fully transparent (symmetrically for divergent
colormaps)
Returns
-------
cmap : matplotlib.ListedColormap
Color map with transparency channel.
"""
if not fmin <= fmid <= fmax:
raise ValueError(f"Must have fmin ({fmin}) <= fmid ({fmid}) <= fmax ({fmax})")
lut_table = create_lut(lut_table)
assert lut_table.dtype.kind == "i"
divergent = center is not None
n_colors = lut_table.shape[0]
# Add transparency if needed
n_colors2 = n_colors // 2
if transparent:
if divergent:
N4 = np.full(4, n_colors // 4)
N4[[0, 3, 1, 2][: np.mod(n_colors, 4)]] += 1
assert N4.sum() == n_colors
lut_table[:, -1] = np.round(
np.hstack(
[
np.full(N4[0], 255.0),
np.linspace(0, 255, N4[1])[::-1],
np.linspace(0, 255, N4[2]),
np.full(N4[3], 255.0),
]
)
)
else:
lut_table[:n_colors2, -1] = np.round(np.linspace(0, 255, n_colors2))
lut_table[n_colors2:, -1] = 255
alpha = float(alpha)
if alpha < 1.0:
lut_table[:, -1] = np.round(lut_table[:, -1] * alpha)
if divergent:
if np.isclose(fmax, fmin, rtol=1e-6, atol=0):
lut_table = np.r_[
lut_table[:1],
get_fill_colors(
lut_table[n_colors2 - 3 : n_colors2 + 3, :], n_colors - 2
),
lut_table[-1:],
]
else:
n_fill = int(round(fmin * n_colors2 / (fmax - fmin))) * 2
lut_table = np.r_[
scale_sequential_lut(
lut_table[:n_colors2, :],
center - fmax,
center - fmid,
center - fmin,
),
get_fill_colors(lut_table[n_colors2 - 3 : n_colors2 + 3, :], n_fill),
scale_sequential_lut(
lut_table[n_colors2:, :][::-1],
center - fmax,
center - fmid,
center - fmin,
)[::-1],
]
else:
lut_table = scale_sequential_lut(lut_table, fmin, fmid, fmax)
n_colors = lut_table.shape[0]
if n_colors != 256:
lut = np.zeros((256, 4))
x = np.linspace(1, n_colors, 256)
for chan in range(4):
lut[:, chan] = np.interp(x, np.arange(1, n_colors + 1), lut_table[:, chan])
lut_table = lut
lut_table = lut_table.astype(np.float64) / 255.0
return lut_table

177
mne/viz/_brain/surface.py Normal file
View File

@@ -0,0 +1,177 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from os import path as path
import numpy as np
from ...surface import _read_patch, complete_surface_info, read_curvature, read_surface
from ...utils import _check_fname, _check_option, _validate_type, get_subjects_dir
class _Surface:
"""Container for a brain surface.
It is used for storing vertices, faces and morphometric data
(curvature) of a hemisphere mesh.
Parameters
----------
subject : string
Name of subject
hemi : {'lh', 'rh'}
Which hemisphere to load
surf : string
Name of the surface to load (eg. inflated, orig ...).
subjects_dir : str | None
If not None, this directory will be used as the subjects directory
instead of the value set using the SUBJECTS_DIR environment variable.
offset : float | None
If 0.0, the surface will be offset such that the medial
wall is aligned with the origin. If None, no offset will
be applied. If != 0.0, an additional offset will be used.
units : str
Can be 'm' or 'mm' (default).
x_dir : ndarray | None
The x direction to use for offset alignment.
Attributes
----------
bin_curv : numpy.ndarray
Curvature values stored as non-negative integers.
coords : numpy.ndarray
nvtx x 3 array of vertex (x, y, z) coordinates.
curv : numpy.ndarray
Vector representation of surface morpometry (curvature) values as
loaded from a file.
grey_curv : numpy.ndarray
Normalized morphometry (curvature) data, used in order to get
a gray cortex.
faces : numpy.ndarray
nfaces x 3 array of defining mesh triangles.
hemi : {'lh', 'rh'}
Which hemisphere to load.
nn : numpy.ndarray
Vertex normals for a triangulated surface.
offset : float | None
If float, align inside edge of each hemisphere to center + offset.
If None, do not change coordinates (default).
subject : string
Name of subject.
surf : string
Name of the surface to load (eg. inflated, orig ...).
units : str
Can be 'm' or 'mm' (default).
"""
def __init__(
self,
subject,
hemi,
surf,
subjects_dir=None,
offset=None,
units="mm",
x_dir=None,
):
x_dir = np.array([1.0, 0, 0]) if x_dir is None else x_dir
assert isinstance(x_dir, np.ndarray)
assert np.isclose(np.linalg.norm(x_dir), 1.0, atol=1e-6)
assert hemi in ("lh", "rh")
_validate_type(offset, (None, "numeric"), "offset")
self.units = _check_option("units", units, ("mm", "m"))
self.subject = subject
self.hemi = hemi
self.surf = surf
self.offset = offset
self.bin_curv = None
self.coords = None
self.curv = None
self.faces = None
self.nn = None
self.labels = dict()
self.x_dir = x_dir
subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True))
self.data_path = path.join(subjects_dir, subject)
if surf == "seghead":
raise ValueError(
"`surf` cannot be seghead, use "
"`mne.viz.Brain.add_head` to plot the seghead"
)
def load_geometry(self):
"""Load geometry of the surface.
Parameters
----------
None
Returns
-------
None
"""
if self.surf == "flat": # special case
fname = path.join(self.data_path, "surf", f"{self.hemi}.cortex.patch.flat")
_check_fname(
fname, overwrite="read", must_exist=True, name="flatmap surface file"
)
coords, faces, orig_faces = _read_patch(fname)
# rotate 90 degrees to get to a more standard orientation
# where X determines the distance between the hemis
coords = coords[:, [1, 0, 2]]
coords[:, 1] *= -1
else:
# allow ?h.pial.T1 if ?h.pial doesn't exist for instance
# end with '' for better file not found error
for img in ("", ".T1", ".T2", ""):
surf_fname = path.join(
self.data_path, "surf", f"{self.hemi}.{self.surf}{img}"
)
if path.isfile(surf_fname):
break
coords, faces = read_surface(surf_fname)
orig_faces = faces
if self.units == "m":
coords /= 1000.0
if self.offset is not None:
x_ = coords @ self.x_dir
if self.hemi == "lh":
coords -= (np.max(x_) + self.offset) * self.x_dir
else:
coords -= (np.min(x_) + self.offset) * self.x_dir
surf = dict(rr=coords, tris=faces)
complete_surface_info(surf, copy=False, verbose=False, do_neighbor_tri=False)
nn = surf["nn"]
self.coords = coords
self.faces = faces
self.orig_faces = orig_faces
self.nn = nn
def __len__(self):
"""Return number of vertices."""
return len(self.coords)
@property
def x(self):
return self.coords[:, 0]
@property
def y(self):
return self.coords[:, 1]
@property
def z(self):
return self.coords[:, 2]
def load_curvature(self):
"""Load in curvature values from the ?h.curv file."""
curv_path = path.join(self.data_path, "surf", f"{self.hemi}.curv")
if path.isfile(curv_path):
self.curv = read_curvature(curv_path, binary=False)
self.bin_curv = np.array(self.curv > 0, np.int64)
else:
self.curv = None
self.bin_curv = None

54
mne/viz/_brain/view.py Normal file
View File

@@ -0,0 +1,54 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
ORIGIN = "auto"
DIST = "auto"
_lh_views_dict = {
"lateral": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"medial": dict(azimuth=0.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"rostral": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"caudal": dict(azimuth=270.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"dorsal": dict(azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, distance=DIST),
"ventral": dict(azimuth=180.0, elevation=180.0, focalpoint=ORIGIN, distance=DIST),
"frontal": dict(azimuth=120.0, elevation=80.0, focalpoint=ORIGIN, distance=DIST),
"parietal": dict(azimuth=-120.0, elevation=60.0, focalpoint=ORIGIN, distance=DIST),
"sagittal": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"coronal": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"axial": dict(
azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, roll=0, distance=DIST
), # noqa: E501
}
_rh_views_dict = {
"lateral": dict(azimuth=180.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
"medial": dict(azimuth=0.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
"rostral": dict(azimuth=-90.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
"caudal": dict(azimuth=90.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
"dorsal": dict(azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, distance=DIST),
"ventral": dict(azimuth=180.0, elevation=180.0, focalpoint=ORIGIN, distance=DIST),
"frontal": dict(azimuth=60.0, elevation=80.0, focalpoint=ORIGIN, distance=DIST),
"parietal": dict(azimuth=-60.0, elevation=60.0, focalpoint=ORIGIN, distance=DIST),
"sagittal": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"coronal": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"axial": dict(
azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, roll=0, distance=DIST
),
}
# add short-size version entries into the dict
lh_views_dict = _lh_views_dict.copy()
for k, v in _lh_views_dict.items():
lh_views_dict[k[:3]] = v
lh_views_dict["flat"] = dict(
azimuth=0, elevation=0, focalpoint=ORIGIN, roll=0, distance=DIST
)
rh_views_dict = _rh_views_dict.copy()
for k, v in _rh_views_dict.items():
rh_views_dict[k[:3]] = v
rh_views_dict["flat"] = dict(
azimuth=0, elevation=0, focalpoint=ORIGIN, roll=0, distance=DIST
)
views_dicts = dict(
lh=lh_views_dict, vol=lh_views_dict, both=lh_views_dict, rh=rh_views_dict
)

214
mne/viz/_dipole.py Normal file
View File

@@ -0,0 +1,214 @@
"""Dipole viz specific functions."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import os.path as op
import numpy as np
from scipy.spatial import ConvexHull
from .._freesurfer import _estimate_talxfm_rigid, _get_head_surface
from ..surface import read_surface
from ..transforms import _get_trans, apply_trans, invert_transform
from ..utils import _check_option, _validate_type, get_subjects_dir
from .utils import _validate_if_list_of_axes, plt_show
def _check_concat_dipoles(dipole):
from ..dipole import Dipole, _concatenate_dipoles
if not isinstance(dipole, Dipole):
dipole = _concatenate_dipoles(dipole)
return dipole
def _plot_dipole_mri_outlines(
dipoles,
*,
subject,
trans,
ax,
subjects_dir,
color,
scale,
coord_frame,
show,
block,
head_source,
title,
surf,
width,
):
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.patches import Circle
extra = 'when mode is "outlines"'
trans = _get_trans(trans, fro="head", to="mri")[0]
_check_option(
"coord_frame", coord_frame, ["head", "mri", "mri_rotated"], extra=extra
)
_validate_type(surf, (str, None), "surf")
_check_option("surf", surf, ("white", "pial", None))
if ax is None:
_, ax = plt.subplots(1, 3, figsize=(7, 2.5), squeeze=True, layout="constrained")
_validate_if_list_of_axes(ax, 3, name="ax")
dipoles = _check_concat_dipoles(dipoles)
color = "r" if color is None else color
scale = 0.03 if scale is None else scale
width = 0.015 if width is None else width
fig = ax[0].figure
surfs = dict()
hemis = ("lh", "rh")
if surf is not None:
for hemi in hemis:
surfs[hemi] = read_surface(
op.join(subjects_dir, subject, "surf", f"{hemi}.{surf}"),
return_dict=True,
)[2]
surfs[hemi]["rr"] /= 1000.0
subjects_dir = get_subjects_dir(subjects_dir)
if subjects_dir is not None:
subjects_dir = str(subjects_dir)
surfs["head"] = _get_head_surface(head_source, subject, subjects_dir)
del head_source
mri_trans = head_trans = np.eye(4)
if coord_frame in ("mri", "mri_rotated"):
head_trans = trans["trans"]
if coord_frame == "mri_rotated":
rot = _estimate_talxfm_rigid(subject, subjects_dir)
rot[:3, 3] = 0.0
head_trans = rot @ head_trans
mri_trans = rot @ mri_trans
else:
assert coord_frame == "head"
mri_trans = invert_transform(trans)["trans"]
for s in surfs.values():
s["rr"] = 1000 * apply_trans(mri_trans, s["rr"])
del mri_trans
levels = dict()
if surf is not None:
use_rr = np.concatenate([surfs[key]["rr"] for key in hemis])
else:
use_rr = surfs["head"]["rr"]
views = [("Axial", "XY"), ("Coronal", "XZ"), ("Sagittal", "YZ")]
# axial: 25% up the Z axis
axial = float(np.percentile(use_rr[:, 2], 20.0))
coronal = float(np.percentile(use_rr[:, 1], 55.0))
for key in hemis + ("head",):
levels[key] = dict(Axial=axial, Coronal=coronal)
if surf is not None:
levels["rh"]["Sagittal"] = float(np.percentile(surfs["rh"]["rr"][:, 0], 50))
levels["head"]["Sagittal"] = 0.0
for ax_, (name, coords) in zip(ax, views):
idx = list(map(dict(X=0, Y=1, Z=2).get, coords))
miss = np.setdiff1d(np.arange(3), idx)[0]
pos = 1000 * apply_trans(head_trans, dipoles.pos)
ori = 1000 * apply_trans(head_trans, dipoles.ori, move=False)
lims = dict()
for ii, char in enumerate(coords):
lim = surfs["head"]["rr"][:, idx[ii]]
lim = np.array([lim.min(), lim.max()])
lims[char] = lim
ax_.quiver(
pos[:, idx[0]],
pos[:, idx[1]],
scale * ori[:, idx[0]],
scale * ori[:, idx[1]],
color=color,
pivot="middle",
zorder=5,
scale_units="xy",
angles="xy",
scale=1.0,
width=width,
minshaft=0.5,
headwidth=2.5,
headlength=2.5,
headaxislength=2,
)
coll = PatchCollection(
[
Circle((x, y), radius=scale * 1000 * width * 6)
for x, y in zip(pos[:, idx[0]], pos[:, idx[1]])
],
linewidths=0.0,
facecolors=color,
zorder=6,
)
for key, surf in surfs.items():
try:
level = levels[key][name]
except KeyError:
continue
if key != "head":
rrs = surf["rr"][:, idx]
tris = ConvexHull(rrs).simplices
segments = LineCollection(
rrs[:, [0, 1]][tris],
linewidths=1,
linestyles="-",
colors="k",
zorder=3,
alpha=0.25,
)
ax_.add_collection(segments)
ax_.tricontour(
surf["rr"][:, idx[0]],
surf["rr"][:, idx[1]],
surf["tris"],
surf["rr"][:, miss],
levels=[level],
colors="k",
linewidths=1.0,
linestyles=["-"],
zorder=4,
alpha=0.5,
)
# TODO: this breaks the PatchCollection in MPL
# for coll in h.collections:
# coll.set_clip_on(False)
ax_.add_collection(coll)
ax_.set(
title=name,
xlim=lims[coords[0]],
ylim=lims[coords[1]],
xlabel=coords[0] + " (mm)",
ylabel=coords[1] + " (mm)",
)
for spine in ax_.spines.values():
spine.set_visible(False)
ax_.grid(True, ls=":", zorder=2)
ax_.set_aspect("equal")
if title is not None:
fig.suptitle(title)
plt_show(show, block=block)
return fig
def _plot_dipole_3d(dipoles, *, coord_frame, color, fig, trans, scale, mode):
from .backends.renderer import _get_renderer
_check_option("coord_frame", coord_frame, ("head", "mri"))
color = "r" if color is None else color
scale = 0.005 if scale is None else scale
renderer = _get_renderer(fig=fig, size=(600, 600))
pos = dipoles.pos
ori = dipoles.ori
if coord_frame != "head":
trans = _get_trans(trans, fro="head", to=coord_frame)[0]
pos = apply_trans(trans, pos)
ori = apply_trans(trans, ori)
renderer.sphere(center=pos, color=color, scale=scale)
if mode == "arrow":
x, y, z = pos.T
u, v, w = ori.T
renderer.quiver3d(x, y, z, u, v, w, scale=3 * scale, color=color, mode="arrow")
renderer.show()
fig = renderer.scene()
return fig

853
mne/viz/_figure.py Normal file
View File

@@ -0,0 +1,853 @@
"""Base classes and functions for 2D browser backends."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import importlib
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from copy import deepcopy
from itertools import cycle
import numpy as np
from .._fiff.pick import _DATA_CH_TYPES_SPLIT
from ..defaults import _handle_default
from ..filter import _iir_filter, _overlap_add_filter
from ..fixes import _compare_version
from ..utils import (
_check_option,
_get_stim_channel,
_validate_type,
get_config,
logger,
set_config,
verbose,
)
from .backends._utils import VALID_BROWSE_BACKENDS
from .utils import _get_color_list, _setup_plot_projector, _show_browser
MNE_BROWSER_BACKEND = None
backend = None
class BrowserParams:
"""Container object for 2D browser parameters."""
def __init__(self, **kwargs):
# default key to close window
self.close_key = "escape"
vars(self).update(**kwargs)
class BrowserBase(ABC):
"""A base class containing for the 2D browser.
This class contains all backend-independent attributes and methods.
"""
def __init__(self, **kwargs):
from ..epochs import BaseEpochs
from ..io import BaseRaw
from ..preprocessing import ICA
self.backend_name = None
self._data = None
self._times = None
self.mne = BrowserParams(**kwargs)
inst = kwargs.get("inst", None)
ica = kwargs.get("ica", None)
# what kind of data are we dealing with?
if isinstance(ica, ICA):
self.mne.instance_type = "ica"
elif isinstance(inst, BaseRaw):
self.mne.instance_type = "raw"
elif isinstance(inst, BaseEpochs):
self.mne.instance_type = "epochs"
else:
raise TypeError(
f"Expected an instance of Raw, Epochs, or ICA, got {type(inst)}."
)
logger.debug(f"Opening {self.mne.instance_type} browser...")
self.mne.ica_type = None
if self.mne.instance_type == "ica":
if isinstance(self.mne.ica_inst, BaseRaw):
self.mne.ica_type = "raw"
elif isinstance(self.mne.ica_inst, BaseEpochs):
self.mne.ica_type = "epochs"
self.mne.is_epochs = "epochs" in (self.mne.instance_type, self.mne.ica_type)
# things that always start the same
self.mne.ch_start = 0
self.mne.projector = None
if hasattr(self.mne, "projs"):
self.mne.projs_active = np.array([p["active"] for p in self.mne.projs])
self.mne.whitened_ch_names = list()
if hasattr(self.mne, "noise_cov"):
self.mne.use_noise_cov = self.mne.noise_cov is not None
# allow up to 10000 zorder levels for annotations
self.mne.zorder = dict(
patch=0,
grid=1,
ann=2,
events=10003,
bads=10004,
data=10005,
mag=10006,
grad=10007,
scalebar=10008,
vline=10009,
)
# additional params for epochs (won't affect raw / ICA)
self.mne.epoch_traces = list()
self.mne.bad_epochs = list()
if inst is not None:
self.mne.sampling_period = np.diff(inst.times[:2])[0] / inst.info["sfreq"]
# annotations
self.mne.annotations = list()
self.mne.hscroll_annotations = list()
self.mne.annotation_segments = list()
self.mne.annotation_texts = list()
self.mne.new_annotation_labels = list()
self.mne.annotation_segment_colors = dict()
self.mne.annotation_hover_line = None
self.mne.draggable_annotations = False
# lines
self.mne.event_lines = list()
self.mne.event_texts = list()
self.mne.vline_visible = False
# decim
self.mne.decim_times = None
self.mne.decim_data = None
# scalings
if hasattr(self.mne, "butterfly"):
self.mne.scale_factor = 0.5 if self.mne.butterfly else 1.0
self.mne.scalebars = dict()
self.mne.scalebar_texts = dict()
# ancillary child figures
self.mne.child_figs = list()
self.mne.fig_help = None
self.mne.fig_proj = None
self.mne.fig_histogram = None
self.mne.fig_selection = None
self.mne.fig_annotation = None
# extra attributes for epochs
if self.mne.is_epochs:
# add epoch boundaries & center epoch numbers between boundaries
self.mne.midpoints = (
np.convolve(self.mne.boundary_times, np.ones(2), mode="valid") / 2
)
# initialize picks and projectors
self._update_picks()
if not self.mne.instance_type == "ica":
self._update_projector()
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# ANNOTATIONS
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
def _get_annotation_labels(self):
"""Get the unique labels in the raw object and added in the UI."""
return sorted(
set(self.mne.inst.annotations.description)
| set(self.mne.new_annotation_labels)
)
def _setup_annotation_colors(self):
"""Set up colors for annotations; init some annotation vars."""
segment_colors = getattr(self.mne, "annotation_segment_colors", dict())
labels = self._get_annotation_labels()
colors, red = _get_color_list(annotations=True)
color_cycle = cycle(colors)
for key, color in segment_colors.items():
if color != red and key in labels:
next(color_cycle)
for idx, key in enumerate(labels):
if key.lower().startswith("bad") or key.lower().startswith("edge"):
segment_colors[key] = red
elif key in segment_colors:
continue
else:
segment_colors[key] = next(color_cycle)
self.mne.annotation_segment_colors = segment_colors
# init a couple other annotation-related variables
self.mne.visible_annotations = {label: True for label in labels}
self.mne.show_hide_annotation_checkboxes = None
def _update_annotation_segments(self):
"""Update the array of annotation start/end times."""
from ..annotations import _sync_onset
self.mne.annotation_segments = np.array([])
if len(self.mne.inst.annotations):
annot_start = _sync_onset(self.mne.inst, self.mne.inst.annotations.onset)
durations = self.mne.inst.annotations.duration.copy()
durations[durations < 1 / self.mne.info["sfreq"]] = (
1 / self.mne.info["sfreq"]
)
annot_end = annot_start + durations
self.mne.annotation_segments = np.vstack((annot_start, annot_end)).T
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# PROJECTOR & BADS
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
def _update_projector(self):
"""Update the data after projectors (or bads) have changed."""
inds = np.where(self.mne.projs_on)[0] # doesn't include "active" projs
# copy projs from full list (self.mne.projs) to info object
with self.mne.info._unlock():
self.mne.info["projs"] = [deepcopy(self.mne.projs[ix]) for ix in inds]
# compute the projection operator
proj, wh_chs = _setup_plot_projector(
self.mne.info, self.mne.noise_cov, True, self.mne.use_noise_cov
)
self.mne.whitened_ch_names = list(wh_chs)
self.mne.projector = proj
def _toggle_bad_channel(self, idx):
"""Mark/unmark bad channels; `idx` is index of *visible* channels."""
pick = self.mne.picks[idx]
ch_name = self.mne.ch_names[pick]
# add/remove from bads list
bads = self.mne.info["bads"]
marked_bad = ch_name not in bads
if marked_bad:
bads.append(ch_name)
color = self.mne.ch_color_bad
else:
while ch_name in bads: # to make sure duplicates are removed
bads.remove(ch_name)
# Only mpl-backend has ch_colors
if hasattr(self.mne, "ch_colors"):
color = self.mne.ch_colors[idx]
else:
color = None
self.mne.info["bads"] = bads
self._update_projector()
return color, pick, marked_bad
def _toggle_single_channel_annotation(self, ch_pick, annot_idx):
current_ch_names = list(self.mne.inst.annotations.ch_names[annot_idx])
if ch_pick in current_ch_names:
current_ch_names.remove(ch_pick)
else:
current_ch_names.append(ch_pick)
self.mne.inst.annotations.ch_names[annot_idx] = tuple(current_ch_names)
def _toggle_bad_epoch(self, xtime):
epoch_num = self._get_epoch_num_from_time(xtime)
epoch_ix = self.mne.inst.selection.tolist().index(epoch_num)
if epoch_num in self.mne.bad_epochs:
self.mne.bad_epochs.remove(epoch_num)
color = "none"
else:
self.mne.bad_epochs.append(epoch_num)
self.mne.bad_epochs.sort()
color = self.mne.epoch_color_bad
return epoch_ix, color
def _toggle_whitening(self):
if self.mne.noise_cov is not None:
self.mne.use_noise_cov = not self.mne.use_noise_cov
self._update_projector()
self._update_yaxis_labels() # add/remove italics
self._redraw()
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# MANAGE TRACES
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
def _update_picks(self):
"""Compute which channel indices to show."""
if self.mne.butterfly and self.mne.ch_selections is not None:
selections_dict = self._make_butterfly_selections_dict()
self.mne.picks = np.concatenate(tuple(selections_dict.values()))
elif self.mne.butterfly:
self.mne.picks = self.mne.ch_order
else:
_slice = slice(self.mne.ch_start, self.mne.ch_start + self.mne.n_channels)
self.mne.picks = self.mne.ch_order[_slice]
self.mne.n_channels = len(self.mne.picks)
assert isinstance(self.mne.picks, np.ndarray)
assert self.mne.picks.dtype.kind == "i"
def _make_butterfly_selections_dict(self):
"""Make an altered copy of the selections dict for butterfly mode."""
selections_dict = deepcopy(self.mne.ch_selections)
# remove potential duplicates
for selection_group in ("Vertex", "Custom"):
selections_dict.pop(selection_group, None)
# if present, remove stim channel from non-misc selection groups
stim_ch = _get_stim_channel(None, self.mne.info, raise_error=False)
if len(stim_ch):
stim_pick = self.mne.ch_names.tolist().index(stim_ch[0])
for _sel, _picks in selections_dict.items():
if _sel != "Misc":
stim_mask = np.isin(_picks, [stim_pick], invert=True)
selections_dict[_sel] = np.array(_picks)[stim_mask]
return selections_dict
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# MANAGE DATA
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
def _get_start_stop(self):
# update time
start_sec = self.mne.t_start - self.mne.first_time
stop_sec = start_sec + self.mne.duration
if self.mne.is_epochs:
start, stop = np.round(
np.array([start_sec, stop_sec]) * self.mne.info["sfreq"]
).astype(int)
else:
start, stop = self.mne.inst.time_as_index((start_sec, stop_sec))
return start, stop
def _load_data(self, start=None, stop=None):
"""Retrieve the bit of data we need for plotting."""
if "raw" in (self.mne.instance_type, self.mne.ica_type):
# Add additional sample to cover the case sfreq!=1000
# when the shown time-range wouldn't correspond to duration anymore
if stop is None:
return self.mne.inst[:, start:]
else:
return self.mne.inst[:, start : stop + 2]
else:
# subtract one sample from tstart before searchsorted, to make sure
# we land on the left side of the boundary time (avoid precision
# errors)
ix_start = np.searchsorted(
self.mne.boundary_times, self.mne.t_start - self.mne.sampling_period
)
ix_stop = ix_start + self.mne.n_epochs
item = slice(ix_start, ix_stop)
data = np.concatenate(
self.mne.inst.get_data(item=item, copy=False), axis=-1
)
times = np.arange(start, stop) / self.mne.info["sfreq"]
return data, times
def _apply_filter(self, data, start, stop, picks):
"""Filter (with same defaults as raw.filter())."""
starts, stops = self.mne.filter_bounds
mask = (starts < stop) & (stops > start)
starts = np.maximum(starts[mask], start) - start
stops = np.minimum(stops[mask], stop) - start
for _start, _stop in zip(starts, stops):
_picks = np.where(np.isin(picks, self.mne.picks_data))[0]
if len(_picks) == 0:
break
this_data = data[_picks, _start:_stop]
if isinstance(self.mne.filter_coefs, np.ndarray): # FIR
this_data = _overlap_add_filter(
this_data, self.mne.filter_coefs, copy=False
)
else: # IIR
this_data = _iir_filter(
this_data, self.mne.filter_coefs, None, 1, False
)
data[_picks, _start:_stop] = this_data
def _process_data(self, data, start, stop, picks, thread=None):
"""Update self.mne.data after user interaction."""
# apply projectors
if self.mne.projector is not None:
# thread is the loading-thread only available in Qt-backend
if thread:
thread.processText.emit("Applying Projectors...")
data = self.mne.projector @ data
# get only the channels we're displaying
data = data[picks]
# remove DC
if self.mne.remove_dc:
if thread:
thread.processText.emit("Removing DC...")
data -= np.nanmean(data, axis=1, keepdims=True)
# apply filter
if self.mne.filter_coefs is not None:
if thread:
thread.processText.emit("Apply Filter...")
self._apply_filter(data, start, stop, picks)
# scale the data for display in a 1-vertical-axis-unit slot
if thread:
thread.processText.emit("Scale Data...")
this_names = self.mne.ch_names[picks]
this_types = self.mne.ch_types[picks]
stims = this_types == "stim"
white = np.logical_and(
np.isin(this_names, self.mne.whitened_ch_names),
np.isin(this_names, self.mne.info["bads"], invert=True),
)
norms = np.vectorize(self.mne.scalings.__getitem__)(this_types)
norms[stims] = data[stims].max(axis=-1)
norms[white] = self.mne.scalings["whitened"]
norms[norms == 0] = 1
data /= 2 * norms[:, np.newaxis]
return data
def _update_data(self):
start, stop = self._get_start_stop()
# get the data
data, times = self._load_data(start, stop)
# process the data
data = self._process_data(data, start, stop, self.mne.picks)
# set the data as attributes
self.mne.data = data
self.mne.times = times
def _get_epoch_num_from_time(self, time):
epoch_nums = self.mne.inst.selection
return epoch_nums[np.searchsorted(self.mne.boundary_times[1:], time)]
def _redraw(self, update_data=True, annotations=False):
"""Redraws backend if necessary."""
if update_data:
self._update_data()
self._draw_traces()
if annotations and not self.mne.is_epochs:
self._draw_annotations()
def _close(self, event):
"""Handle close events (via keypress or window [x])."""
from matplotlib.pyplot import close
logger.debug(f"Closing {self.mne.instance_type} browser...")
# write out bad epochs (after converting epoch numbers to indices)
if self.mne.instance_type == "epochs":
bad_ixs = np.isin(self.mne.inst.selection, self.mne.bad_epochs).nonzero()[0]
self.mne.inst.drop(bad_ixs)
logger.info(
"The following epochs were marked as bad "
"and are dropped:\n"
f"{self.mne.bad_epochs}"
)
# write bad channels back to instance (don't do this for proj;
# proj checkboxes are for viz only and shouldn't modify the instance)
if self.mne.instance_type in ("raw", "epochs"):
self.mne.inst.info["bads"] = self.mne.info["bads"]
logger.info(f"Channels marked as bad:\n{self.mne.info['bads'] or 'none'}")
# ICA excludes
elif self.mne.instance_type == "ica":
self.mne.ica.exclude = [
self.mne.ica._ica_names.index(ch) for ch in self.mne.info["bads"]
]
# write window size to config
str_size = ",".join([str(i) for i in self._get_size()])
set_config("MNE_BROWSE_RAW_SIZE", str_size, set_env=False)
# Clean up child figures (don't pop(), child figs remove themselves)
while len(self.mne.child_figs):
fig = self.mne.child_figs[-1]
close(fig)
self._close_event(fig)
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# CHILD FIGURES
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
@abstractmethod
def _new_child_figure(self, fig_name, **kwargs):
pass
def _create_ch_context_fig(self, idx):
"""Show context figure; idx is index of **visible** channels."""
inst = self.mne.instance_type
pick = self.mne.picks[idx]
if inst == "raw":
fig = self._create_ch_location_fig(pick)
elif inst == "ica":
fig = self._create_ica_properties_fig(pick)
else:
fig = self._create_epoch_image_fig(pick)
return fig
def _create_ch_location_fig(self, pick):
"""Show channel location figure."""
from .utils import _channel_type_prettyprint, plot_sensors
ch_name = self.mne.ch_names[pick]
ch_type = self.mne.ch_types[pick]
if ch_type not in _DATA_CH_TYPES_SPLIT:
return
# create figure and axes
title = f"Location of {ch_name}"
fig = self._new_child_figure(figsize=(4, 4), fig_name=None, window_title=title)
fig.suptitle(title)
ax = fig.add_subplot(111)
title = f"{ch_name} position ({_channel_type_prettyprint[ch_type]})"
_ = plot_sensors(
self.mne.info,
ch_type=ch_type,
axes=ax,
title=title,
kind="select",
show=False,
)
# highlight desired channel & disable interactivity
inds = np.isin(fig.lasso.ch_names, [ch_name])
fig.lasso.disconnect()
fig.lasso.alpha_other = 0.3
fig.lasso.linewidth_selected = 3
fig.lasso.style_sensors(inds)
return fig
def _create_ica_properties_fig(self, idx):
"""Show ICA properties for the selected component."""
from mne.viz.ica import (
_create_properties_layout,
_fast_plot_ica_properties,
_prepare_data_ica_properties,
)
ch_name = self.mne.ch_names[idx]
if ch_name not in self.mne.ica._ica_names: # for EOG chans: do nothing
return
pick = self.mne.ica._ica_names.index(ch_name)
title = f"{ch_name} properties"
fig = self._new_child_figure(figsize=(7, 6), fig_name=None, window_title=title)
fig.suptitle(title)
fig, axes = _create_properties_layout(fig=fig)
if not hasattr(self.mne, "data_ica_properties"):
# Precompute epoch sources only once
self.mne.data_ica_properties = _prepare_data_ica_properties(
self.mne.ica_inst, self.mne.ica
)
_fast_plot_ica_properties(
self.mne.ica,
self.mne.ica_inst,
picks=pick,
axes=axes,
psd_args=self.mne.psd_args,
precomputed_data=self.mne.data_ica_properties,
show=False,
)
return fig
def _create_epoch_image_fig(self, pick):
"""Show epochs image for the selected channel."""
from matplotlib.gridspec import GridSpec
from mne.viz import plot_epochs_image
ch_name = self.mne.ch_names[pick]
title = f"Epochs image ({ch_name})"
fig = self._new_child_figure(figsize=(6, 4), fig_name=None, window_title=title)
fig.suptitle = title
gs = GridSpec(nrows=3, ncols=10, figure=fig)
fig.add_subplot(gs[:2, :9])
fig.add_subplot(gs[2, :9])
fig.add_subplot(gs[:2, 9])
plot_epochs_image(self.mne.inst, picks=pick, fig=fig, show=False)
return fig
def _create_epoch_histogram(self):
"""Create peak-to-peak histogram of channel amplitudes."""
epochs = self.mne.inst
data = OrderedDict()
ptp = np.ptp(epochs.get_data(copy=False), axis=2)
for ch_type in ("eeg", "mag", "grad"):
if ch_type in epochs:
data[ch_type] = ptp.T[self.mne.ch_types == ch_type].ravel()
units = _handle_default("units")
titles = _handle_default("titles")
colors = _handle_default("color")
scalings = _handle_default("scalings")
title = "Histogram of peak-to-peak amplitudes"
figsize = (4, 1 + 1.5 * len(data))
fig = self._new_child_figure(
figsize=figsize, fig_name="fig_histogram", window_title=title
)
for ix, (_ch_type, _data) in enumerate(data.items()):
ax = fig.add_subplot(len(data), 1, ix + 1)
ax.set(title=titles[_ch_type], xlabel=units[_ch_type], ylabel="Count")
# set histogram bin range based on rejection thresholds
reject = None
_range = None
if epochs.reject is not None and _ch_type in epochs.reject:
reject = epochs.reject[_ch_type] * scalings[_ch_type]
_range = (0.0, reject * 1.1)
# plot it
ax.hist(
_data * scalings[_ch_type],
bins=100,
color=colors[_ch_type],
range=_range,
)
if reject is not None:
ax.plot((reject, reject), (0, ax.get_ylim()[1]), color="r")
# finalize
fig.suptitle(title, y=0.99)
self.mne.fig_histogram = fig
return fig
def _close_event(self, fig):
"""Look at _close_event in mne.fixes.py for why this exists."""
pass
def fake_keypress(self, key, fig=None): # noqa: D400
"""Pass a fake keypress to the figure.
Parameters
----------
key : str
The key to fake (e.g., ``'a'``).
fig : instance of Figure
The figure to pass the keypress to.
"""
return self._fake_keypress(key, fig=fig)
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# TEST METHODS
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
@abstractmethod
def _get_size(self):
pass
@abstractmethod
def _fake_keypress(self, key, fig):
pass
@abstractmethod
def _fake_click(self, point, fig, axis, xform, button, kind):
pass
@abstractmethod
def _click_ch_name(self, ch_index, button):
pass
@abstractmethod
def _resize_by_factor(self, factor):
pass
@abstractmethod
def _get_ticklabels(self, orientation):
pass
@abstractmethod
def _update_yaxis_labels(self):
pass
def _load_backend(backend_name):
global backend
if backend_name == "matplotlib":
backend = importlib.import_module(name="._mpl_figure", package="mne.viz")
else:
from mne_qt_browser import _pg_figure as backend
logger.info(f"Using {backend_name} as 2D backend.")
return backend
def _get_browser(show, block, **kwargs):
"""Instantiate a new MNE browse-style figure."""
from .utils import _get_figsize_from_config
figsize = kwargs.setdefault("figsize", _get_figsize_from_config())
if figsize is None or np.any(np.array(figsize) < 8):
kwargs["figsize"] = (8, 8)
kwargs["splash"] = kwargs.get("splash", True) and show
if kwargs.get("theme", None) is None:
kwargs["theme"] = get_config("MNE_BROWSER_THEME", "auto")
if kwargs.get("overview_mode", None) is None:
kwargs["overview_mode"] = get_config("MNE_BROWSER_OVERVIEW_MODE", "channels")
# Initialize browser backend
backend_name = get_browser_backend()
# Check mne-qt-browser compatibility
if backend_name == "qt":
import mne_qt_browser
from ..epochs import BaseEpochs
is_ica = kwargs.get("ica", False)
is_epochs = isinstance(kwargs.get("inst", False), BaseEpochs)
not_compat = _compare_version(mne_qt_browser.__version__, "<", "0.2.0")
inst_str = "ICA" if is_ica else "Epochs"
if not_compat and (is_ica or is_epochs):
logger.info(
f'You set the browser-backend to "qt" but your'
f" current version {mne_qt_browser.__version__}"
f" of mne-qt-browser is too low for {inst_str}."
f"Update with pip or conda."
f"Defaults to matplotlib."
)
with use_browser_backend("matplotlib"):
# Initialize Browser
fig = backend._init_browser(**kwargs)
_show_browser(show=show, block=block, fig=fig)
return fig
# Initialize Browser
fig = backend._init_browser(**kwargs)
_show_browser(show=show, block=block, fig=fig)
return fig
def _check_browser_backend_name(backend_name):
_validate_type(backend_name, str, "backend_name")
backend_name = backend_name.lower()
backend_name = "qt" if backend_name == "pyqtgraph" else backend_name
_check_option("backend_name", backend_name, VALID_BROWSE_BACKENDS)
return backend_name
@verbose
def set_browser_backend(backend_name, verbose=None):
"""Set the 2D browser backend for MNE.
The backend will be set as specified and operations will use
that backend.
Parameters
----------
backend_name : str
The 2D browser backend to select. See Notes for the capabilities
of each backend (``'qt'``, ``'matplotlib'``). The ``'qt'`` browser
requires `mne-qt-browser
<https://github.com/mne-tools/mne-qt-browser>`__.
%(verbose)s
Returns
-------
old_backend_name : str | None
The old backend that was in use.
Notes
-----
This table shows the capabilities of each backend ("" for full support,
and "-" for partial support):
.. table::
:widths: auto
+--------------------------------------+------------+----+
| **2D browser function:** | matplotlib | qt |
+======================================+============+====+
| :func:`plot_raw` | ✓ | ✓ |
+--------------------------------------+------------+----+
| :func:`plot_epochs` | ✓ | ✓ |
+--------------------------------------+------------+----+
| :func:`plot_ica_sources` | ✓ | ✓ |
+--------------------------------------+------------+----+
+--------------------------------------+------------+----+
| **Feature:** |
+--------------------------------------+------------+----+
| Show Events | ✓ | ✓ |
+--------------------------------------+------------+----+
| Add/Edit/Remove Annotations | ✓ | ✓ |
+--------------------------------------+------------+----+
| Toggle Projections | ✓ | ✓ |
+--------------------------------------+------------+----+
| Butterfly Mode | ✓ | ✓ |
+--------------------------------------+------------+----+
| Selection Mode | ✓ | ✓ |
+--------------------------------------+------------+----+
| Smooth Scrolling | | ✓ |
+--------------------------------------+------------+----+
| Overview-Bar (with Z-Score-Mode) | | ✓ |
+--------------------------------------+------------+----+
.. versionadded:: 0.24
"""
global MNE_BROWSER_BACKEND
old_backend_name = MNE_BROWSER_BACKEND
backend_name = _check_browser_backend_name(backend_name)
if MNE_BROWSER_BACKEND != backend_name:
_load_backend(backend_name)
MNE_BROWSER_BACKEND = backend_name
return old_backend_name
def _init_browser_backend():
global MNE_BROWSER_BACKEND
# check if MNE_BROWSER_BACKEND is not None and valid or get it from config
loaded_backend = MNE_BROWSER_BACKEND or get_config(
key="MNE_BROWSER_BACKEND", default=None
)
if loaded_backend is not None:
set_browser_backend(loaded_backend)
return MNE_BROWSER_BACKEND
else:
errors = dict()
# Try import of valid browser backends
for name in VALID_BROWSE_BACKENDS:
try:
_load_backend(name)
except ImportError as exc:
errors[name] = str(exc)
else:
MNE_BROWSER_BACKEND = name
break
else:
raise RuntimeError(
"Could not load any valid 2D backend:\n"
+ "\n".join(f"{key}: {val}" for key, val in errors.items())
)
return MNE_BROWSER_BACKEND
def get_browser_backend():
"""Return the 2D backend currently used.
Returns
-------
backend_used : str | None
The 2D browser backend currently in use. If no backend is found,
returns ``None``.
"""
try:
backend_name = _init_browser_backend()
except RuntimeError as exc:
backend_name = None
logger.info(str(exc))
return backend_name
@contextmanager
def use_browser_backend(backend_name):
"""Create a 2D browser visualization context using the designated backend.
See :func:`mne.viz.set_browser_backend` for more details on the available
2D browser backends and their capabilities.
Parameters
----------
backend_name : {'qt', 'matplotlib'}
The 2D browser backend to use in the context.
"""
old_backend = set_browser_backend(backend_name)
try:
yield backend
finally:
if old_backend is not None:
try:
set_browser_backend(old_backend)
except Exception:
pass

2530
mne/viz/_mpl_figure.py Normal file

File diff suppressed because it is too large Load Diff

252
mne/viz/_proj.py Normal file
View File

@@ -0,0 +1,252 @@
"""Functions for plotting projectors."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from copy import deepcopy
import numpy as np
from .._fiff.pick import _picks_to_idx
from ..defaults import DEFAULTS
from ..utils import _pl, _validate_type, verbose, warn
from .evoked import _plot_evoked
from .topomap import _plot_projs_topomap
from .utils import _check_type_projs, plt_show
@verbose
def plot_projs_joint(
projs, evoked, picks_trace=None, *, topomap_kwargs=None, show=True, verbose=None
):
"""Plot projectors and evoked jointly.
Parameters
----------
projs : list of Projection
The projectors to plot.
evoked : instance of Evoked
The data to plot. Typically this is the evoked instance created from
averaging the epochs used to create the projection.
%(picks_plot_projs_joint_trace)s
topomap_kwargs : dict | None
Keyword arguments to pass to :func:`mne.viz.plot_projs_topomap`.
%(show)s
%(verbose)s
Returns
-------
fig : instance of matplotlib Figure
The figure.
Notes
-----
This function creates a figure with three columns:
1. The left shows the evoked data traces before (black) and after (green)
projection.
2. The center shows the topomaps associated with each of the projectors.
3. The right again shows the data traces (black), but this time with:
1. The data projected onto each projector with a single normalization
factor (solid lines). This is useful for seeing the relative power
in each projection vector.
2. The data projected onto each projector with individual normalization
factors (dashed lines). This is useful for visualizing each time
course regardless of its power.
3. Additional data traces from ``picks_trace`` (solid yellow lines).
This is useful for visualizing the "ground truth" of the time
course, e.g. the measured EOG or ECG channel time courses.
.. versionadded:: 1.1
"""
import matplotlib.pyplot as plt
from ..evoked import Evoked
_validate_type(evoked, Evoked, "evoked")
_validate_type(topomap_kwargs, (None, dict), "topomap_kwargs")
projs = _check_type_projs(projs)
topomap_kwargs = dict() if topomap_kwargs is None else topomap_kwargs
if picks_trace is not None:
picks_trace = _picks_to_idx(evoked.info, picks_trace, allow_empty=False)
info = evoked.info
ch_types = evoked.get_channel_types(unique=True, only_data_chs=True)
proj_by_type = dict() # will be set up like an enumerate key->[pi, proj]
ch_names_by_type = dict()
used = np.zeros(len(projs), int)
for ch_type in ch_types:
these_picks = _picks_to_idx(info, ch_type, allow_empty=True)
these_chs = [evoked.ch_names[pick] for pick in these_picks]
ch_names_by_type[ch_type] = these_chs
for pi, proj in enumerate(projs):
if not set(these_chs).intersection(proj["data"]["col_names"]):
continue
if ch_type not in proj_by_type:
proj_by_type[ch_type] = list()
proj_by_type[ch_type].append([pi, deepcopy(proj)])
used[pi] += 1
missing = (~used.astype(bool)).sum()
if missing:
warn(
f"{missing} projector{_pl(missing)} had no channel names "
"present in epochs"
)
del projs
ch_types = list(proj_by_type) # reduce to number we actually need
# room for legend
max_proj_per_type = max(len(x) for x in proj_by_type.values())
cs_trace = 3
cs_topo = 2
n_col = max_proj_per_type * cs_topo + 2 * cs_trace
n_row = len(ch_types)
shape = (n_row, n_col)
fig = plt.figure(
figsize=(n_col * 1.1 + 0.5, n_row * 1.8 + 0.5), layout="constrained"
)
ri = 0
# pick some sufficiently distinct colors (6 per proj type, e.g., ECG,
# should be enough hopefully!)
# https://personal.sron.nl/~pault/data/colourschemes.pdf
# "Vibrant" color scheme
proj_colors = [
"#CC3311", # red
"#009988", # teal
"#0077BB", # blue
"#EE3377", # magenta
"#EE7733", # orange
"#33BBEE", # cyan
]
trace_color = "#CCBB44" # yellow
after_color, after_name = "#228833", "green"
type_titles = DEFAULTS["titles"]
last_ax = [None] * 2
first_ax = dict()
pe_kwargs = dict(show=False, draw=False)
for ch_type, these_projs in proj_by_type.items():
these_idxs, these_projs = zip(*these_projs)
ch_names = ch_names_by_type[ch_type]
idx = np.where(
[np.isin(ch_names, proj["data"]["col_names"]).all() for proj in these_projs]
)[0]
used[idx] += 1
count = len(these_projs)
for proj in these_projs:
sub_idx = [proj["data"]["col_names"].index(name) for name in ch_names]
proj["data"]["data"] = proj["data"]["data"][:, sub_idx]
proj["data"]["col_names"] = ch_names
ba_ax = plt.subplot2grid(shape, (ri, 0), colspan=cs_trace, fig=fig)
topo_axes = [
plt.subplot2grid(
shape, (ri, ci * cs_topo + cs_trace), colspan=cs_topo, fig=fig
)
for ci in range(count)
]
tr_ax = plt.subplot2grid(
shape, (ri, n_col - cs_trace), colspan=cs_trace, fig=fig
)
# topomaps
_plot_projs_topomap(these_projs, info=info, axes=topo_axes, **topomap_kwargs)
for idx, proj, ax_ in zip(these_idxs, these_projs, topo_axes):
ax_.set_title("") # could use proj['desc'] but it's long
ax_.set_xlabel(f"projs[{idx}]", fontsize="small")
unit = DEFAULTS["units"][ch_type]
# traces
this_evoked = evoked.copy().pick(ch_names)
p = np.concatenate([p["data"]["data"] for p in these_projs])
assert p.shape == (len(these_projs), len(this_evoked.data))
traces = np.dot(p, this_evoked.data)
traces *= np.sign(np.mean(np.dot(this_evoked.data, traces.T), 0))[:, np.newaxis]
if picks_trace is not None:
ch_traces = evoked.data[picks_trace]
ch_traces -= np.mean(ch_traces, axis=1, keepdims=True)
ch_traces /= np.abs(ch_traces).max()
_plot_evoked(
this_evoked, picks="all", axes=[tr_ax], **pe_kwargs, spatial_colors=False
)
for line in tr_ax.lines:
line.set(lw=0.5, zorder=3)
for t in list(tr_ax.texts):
t.remove()
scale = 0.8 * np.abs(tr_ax.get_ylim()).max()
hs, labels = list(), list()
traces /= np.abs(traces).max() # uniformly scaled
for ti, trace in enumerate(traces):
hs.append(
tr_ax.plot(
this_evoked.times,
trace * scale,
color=proj_colors[ti % len(proj_colors)],
zorder=5,
)[0]
)
labels.append(f"projs[{these_idxs[ti]}]")
traces /= np.abs(traces).max(1, keepdims=True) # independently
for ti, trace in enumerate(traces):
tr_ax.plot(
this_evoked.times,
trace * scale,
color=proj_colors[ti % len(proj_colors)],
zorder=3.5,
ls="--",
lw=1.0,
alpha=0.75,
)
if picks_trace is not None:
trace_ch = [evoked.ch_names[pick] for pick in picks_trace]
if len(picks_trace) == 1:
trace_ch = trace_ch[0]
hs.append(
tr_ax.plot(
this_evoked.times,
ch_traces.T * scale,
color=trace_color,
lw=3,
zorder=4,
alpha=0.75,
)[0]
)
labels.append(str(trace_ch))
tr_ax.set(title="", xlabel="", ylabel="")
# This will steal space from the subplots in a constrained layout
# https://matplotlib.org/3.5.0/tutorials/intermediate/constrainedlayout_guide.html#legends # noqa: E501
tr_ax.legend(
hs,
labels,
loc="center left",
borderaxespad=0.05,
bbox_to_anchor=[1.05, 0.5],
)
last_ax[1] = tr_ax
key = "Projected time course"
if key not in first_ax:
first_ax[key] = tr_ax
# Before and after traces
_plot_evoked(this_evoked, picks="all", axes=[ba_ax], **pe_kwargs)
for line in ba_ax.lines:
line.set(lw=0.5, zorder=3)
loff = len(ba_ax.lines)
this_proj_evoked = this_evoked.copy().add_proj(these_projs)
# with meg='combined' any existing mag projectors (those already part
# of evoked before we add_proj above) will have greatly
# reduced power, so we ignore the warning about this issue
this_proj_evoked.apply_proj(verbose="error")
_plot_evoked(this_proj_evoked, picks="all", axes=[ba_ax], **pe_kwargs)
for line in ba_ax.lines[loff:]:
line.set(lw=0.5, zorder=4, color=after_color)
for t in list(ba_ax.texts):
t.remove()
ba_ax.set(title="", xlabel="")
ba_ax.set(ylabel=f"{type_titles[ch_type]}\n{unit}")
last_ax[0] = ba_ax
key = f"Before (black) and after ({after_name})"
if key not in first_ax:
first_ax[key] = ba_ax
ri += 1
for ax in last_ax:
ax.set(xlabel="Time (s)")
for title, ax in first_ax.items():
ax.set_title(title, fontsize="medium")
plt_show(show)
return fig

81
mne/viz/_scraper.py Normal file
View File

@@ -0,0 +1,81 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from contextlib import contextmanager
from ..utils import _pl
from .backends._utils import _pixmap_to_ndarray
class _MNEQtBrowserScraper:
def __repr__(self):
return "<MNEQtBrowserScraper>"
def __call__(self, block, block_vars, gallery_conf):
import mne_qt_browser
from sphinx_gallery.scrapers import figure_rst
if gallery_conf["builder_name"] != "html":
return ""
img_fnames = list()
inst = None
n_plot = 0
for gui in list(mne_qt_browser._browser_instances):
try:
scraped = getattr(gui, "_scraped", False)
except Exception: # super __init__ not called, perhaps stale?
scraped = True
if scraped:
continue
gui._scraped = True # monkey-patch but it's easy enough
n_plot += 1
img_fnames.append(next(block_vars["image_path_iterator"]))
pixmap, inst = _mne_qt_browser_screenshot(gui, inst)
pixmap.save(img_fnames[-1])
# child figures
for fig in gui.mne.child_figs:
# For now we only support Selection
if not hasattr(fig, "channel_fig"):
continue
fig = fig.channel_fig
img_fnames.append(next(block_vars["image_path_iterator"]))
fig.savefig(img_fnames[-1])
gui.close()
del gui, pixmap
if not len(img_fnames):
return ""
for _ in range(2):
inst.processEvents()
return figure_rst(img_fnames, gallery_conf["src_dir"], f"Raw plot{_pl(n_plot)}")
@contextmanager
def _screenshot_mode(browser):
if need_zen := browser.mne.scrollbars_visible:
browser._toggle_zenmode()
try:
yield
finally:
if need_zen:
browser._toggle_zenmode()
def _mne_qt_browser_screenshot(browser, inst=None, return_type="pixmap"):
from mne_qt_browser._pg_figure import QApplication
if getattr(browser, "load_thread", None) is not None:
if browser.load_thread.isRunning():
browser.load_thread.wait(30000)
if inst is None:
inst = QApplication.instance()
# processEvents to make sure our progressBar is updated
with _screenshot_mode(browser):
for _ in range(2):
inst.processEvents()
pixmap = browser.grab()
assert return_type in ("pixmap", "ndarray")
if return_type == "ndarray":
return _pixmap_to_ndarray(pixmap)
else:
return pixmap, inst

View File

@@ -0,0 +1,7 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
"""Visualization backend."""
from . import renderer

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

1358
mne/viz/backends/_pyvista.py Normal file

File diff suppressed because it is too large Load Diff

1852
mne/viz/backends/_qt.py Normal file

File diff suppressed because it is too large Load Diff

421
mne/viz/backends/_utils.py Normal file
View File

@@ -0,0 +1,421 @@
#
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import collections.abc
import functools
import os
import platform
import signal
import sys
from colorsys import rgb_to_hls
from contextlib import contextmanager
from ctypes import c_char_p, c_void_p, cdll
from pathlib import Path
import numpy as np
from ...fixes import _compare_version
from ...utils import _check_qt_version, _validate_type, logger, warn
from ..utils import _get_cmap
VALID_BROWSE_BACKENDS = (
"qt",
"matplotlib",
)
VALID_3D_BACKENDS = (
"pyvistaqt", # default 3d backend
"notebook",
)
ALLOWED_QUIVER_MODES = ("2darrow", "arrow", "cone", "cylinder", "sphere", "oct")
_ICONS_PATH = Path(__file__).parents[2] / "icons"
def _get_colormap_from_array(
colormap=None, normalized_colormap=False, default_colormap="coolwarm"
):
from matplotlib.colors import ListedColormap
if colormap is None:
cmap = _get_cmap(default_colormap)
elif isinstance(colormap, str):
cmap = _get_cmap(colormap)
elif normalized_colormap:
cmap = ListedColormap(colormap)
else:
cmap = ListedColormap(np.array(colormap) / 255.0)
return cmap
def _check_color(color):
from matplotlib.colors import colorConverter
if isinstance(color, str):
color = colorConverter.to_rgb(color)
elif isinstance(color, collections.abc.Iterable):
np_color = np.array(color)
if np_color.size % 3 != 0 and np_color.size % 4 != 0:
raise ValueError("The expected valid format is RGB or RGBA.")
if np_color.dtype in (np.int64, np.int32):
if (np_color < 0).any() or (np_color > 255).any():
raise ValueError("Values out of range [0, 255].")
elif np_color.dtype == np.float64:
if (np_color < 0.0).any() or (np_color > 1.0).any():
raise ValueError("Values out of range [0.0, 1.0].")
else:
raise TypeError(
"Expected data type is `np.int64`, `np.int32`, or `np.float64` but "
f"{np_color.dtype} was given."
)
else:
raise TypeError(
f"Expected type is `str` or iterable but {type(color)} was given."
)
return color
def _alpha_blend_background(ctable, background_color):
alphas = ctable[:, -1][:, np.newaxis] / 255.0
use_table = ctable.copy()
use_table[:, -1] = 255.0
return (use_table * alphas) + background_color * (1 - alphas)
@functools.lru_cache(1)
def _qt_init_icons():
from qtpy.QtGui import QIcon
QIcon.setThemeSearchPaths([str(_ICONS_PATH)] + QIcon.themeSearchPaths())
QIcon.setFallbackThemeName("light")
return str(_ICONS_PATH)
@contextmanager
def _qt_disable_paint(widget):
paintEvent = widget.paintEvent
widget.paintEvent = lambda *args, **kwargs: None
try:
yield
finally:
widget.paintEvent = paintEvent
_QT_ICON_KEYS = dict(app=None)
def _init_mne_qtapp(enable_icon=True, pg_app=False, splash=False):
"""Get QApplication-instance for MNE-Python.
Parameter
---------
enable_icon: bool
If to set an MNE-icon for the app.
pg_app: bool
If to create the QApplication with pyqtgraph. For an until know
undiscovered reason the pyqtgraph-browser won't show without
mkQApp from pyqtgraph.
splash : bool | str
If not False, display a splash screen. If str, set the message
to the given string.
Returns
-------
app : ``qtpy.QtWidgets.QApplication``
Instance of QApplication.
splash : ``qtpy.QtWidgets.QSplashScreen``
Instance of QSplashScreen. Only returned if splash is True or a
string.
"""
from qtpy.QtCore import Qt
from qtpy.QtGui import QGuiApplication, QIcon, QPixmap
from qtpy.QtWidgets import QApplication, QSplashScreen
app_name = "MNE-Python"
organization_name = "MNE"
# Fix from cbrnr/mnelab for app name in menu bar
# This has to come *before* the creation of the QApplication to work.
# It also only affects the title bar, not the application dock.
# There seems to be no way to change the application dock from "python"
# at runtime.
if sys.platform.startswith("darwin"):
try:
# set bundle name on macOS (app name shown in the menu bar)
from Foundation import NSBundle
bundle = NSBundle.mainBundle()
info = bundle.localizedInfoDictionary() or bundle.infoDictionary()
if "CFBundleName" not in info:
info["CFBundleName"] = app_name
except ModuleNotFoundError:
pass
# First we need to check to make sure the display is valid, otherwise
# Qt might segfault on us
app = QApplication.instance()
if not (app or _display_is_valid()):
raise RuntimeError("Cannot connect to a valid display")
if pg_app:
from pyqtgraph import mkQApp
old_argv = sys.argv
try:
sys.argv = []
app = mkQApp(app_name)
finally:
sys.argv = old_argv
elif not app:
app = QApplication([app_name])
app.setApplicationName(app_name)
app.setOrganizationName(organization_name)
qt_version = _check_qt_version(check_usable_display=False)
# HiDPI is enabled by default in Qt6, requires to be explicitly set for Qt5
if _compare_version(qt_version, "<", "6.0"):
app.setAttribute(Qt.AA_UseHighDpiPixmaps)
if enable_icon or splash:
icons_path = _qt_init_icons()
if (
enable_icon
and app.windowIcon().cacheKey() != _QT_ICON_KEYS["app"]
and app.windowIcon().isNull() # don't overwrite existing icon (e.g. MNELAB)
):
# Set icon
kind = "bigsur_" if platform.mac_ver()[0] >= "10.16" else "default_"
icon = QIcon(f"{icons_path}/mne_{kind}icon.png")
app.setWindowIcon(icon)
_QT_ICON_KEYS["app"] = app.windowIcon().cacheKey()
out = app
if splash:
pixmap = QPixmap(f"{icons_path}/mne_splash.png")
pixmap.setDevicePixelRatio(QGuiApplication.primaryScreen().devicePixelRatio())
args = (pixmap,)
if _should_raise_window():
args += (Qt.WindowStaysOnTopHint,)
qsplash = QSplashScreen(*args)
qsplash.setAttribute(Qt.WA_ShowWithoutActivating, True)
if isinstance(splash, str):
alignment = int(Qt.AlignBottom | Qt.AlignHCenter)
qsplash.showMessage(splash, alignment=alignment, color=Qt.white)
qsplash.show()
app.processEvents()
out = (out, qsplash)
return out
def _display_is_valid():
# Adapted from matplotilb _c_internal_utils.py
if sys.platform != "linux":
return True
if os.getenv("DISPLAY"): # if it's not there, don't bother
libX11 = cdll.LoadLibrary("libX11.so.6")
libX11.XOpenDisplay.restype = c_void_p
libX11.XOpenDisplay.argtypes = [c_char_p]
display = libX11.XOpenDisplay(None)
if display is not None:
libX11.XCloseDisplay.argtypes = [c_void_p]
libX11.XCloseDisplay(display)
return True
# not found, try Wayland
if os.getenv("WAYLAND_DISPLAY"):
libwayland = cdll.LoadLibrary("libwayland-client.so.0")
if libwayland is not None:
if all(
hasattr(libwayland, f"wl_display_{kind}connect") for kind in ("", "dis")
):
libwayland.wl_display_connect.restype = c_void_p
libwayland.wl_display_connect.argtypes = [c_char_p]
display = libwayland.wl_display_connect(None)
if display:
libwayland.wl_display_disconnect.argtypes = [c_void_p]
libwayland.wl_display_disconnect(display)
return True
return False
# https://stackoverflow.com/questions/5160577/ctrl-c-doesnt-work-with-pyqt
def _qt_app_exec(app):
# adapted from matplotlib
old_signal = signal.getsignal(signal.SIGINT)
is_python_signal_handler = old_signal is not None
if is_python_signal_handler:
signal.signal(signal.SIGINT, signal.SIG_DFL)
try:
# Make IPython Console accessible again in Spyder
app.lastWindowClosed.connect(app.quit)
app.exec_()
finally:
# reset the SIGINT exception handler
if is_python_signal_handler:
signal.signal(signal.SIGINT, old_signal)
def _qt_detect_theme():
try:
import darkdetect
theme = darkdetect.theme().lower()
except ModuleNotFoundError:
logger.info(
'For automatic theme detection, "darkdetect" has to'
" be installed! You can install it with "
"`pip install darkdetect`"
)
theme = "light"
except Exception:
theme = "light"
return theme
def _qt_get_stylesheet(theme):
_validate_type(theme, ("path-like",), "theme")
theme = str(theme)
stylesheet = "" # no stylesheet
if theme in ("auto", "dark", "light"):
if theme == "auto":
return stylesheet
assert theme in ("dark", "light")
system_theme = _qt_detect_theme()
if theme == system_theme:
return stylesheet
_, api = _check_qt_version(return_api=True)
# On macOS or Qt 6, we shouldn't need to set anything when the requested
# theme matches that of the current OS state
try:
import qdarkstyle
except ModuleNotFoundError:
logger.info(
f'To use {theme} mode when in {system_theme} mode, "qdarkstyle" has'
"to be installed! You can install it with:\n"
"pip install qdarkstyle\n"
)
else:
if api in ("PySide6", "PyQt6") and _compare_version(
qdarkstyle.__version__, "<", "3.2.3"
):
warn(
f"Setting theme={repr(theme)} is not supported for {api} in "
f"qdarkstyle {qdarkstyle.__version__}, it will be ignored. "
"Consider upgrading qdarkstyle to >=3.2.3."
)
else:
stylesheet = qdarkstyle.load_stylesheet(
getattr(
getattr(qdarkstyle, theme).palette,
f"{theme.capitalize()}Palette",
)
)
return stylesheet
else:
try:
file = open(theme)
except OSError:
warn(
"Requested theme file not found, will use light instead: "
f"{repr(theme)}"
)
else:
with file as fid:
stylesheet = fid.read()
return stylesheet
def _should_raise_window():
from matplotlib import rcParams
return rcParams["figure.raise_window"]
def _qt_raise_window(widget):
# Set raise_window like matplotlib if possible
if _should_raise_window():
widget.activateWindow()
widget.raise_()
def _qt_is_dark(widget):
# Ideally this would use CIELab, but this should be good enough
win = widget.window()
bgcolor = win.palette().color(win.backgroundRole()).getRgbF()[:3]
return rgb_to_hls(*bgcolor)[1] < 0.5
def _pixmap_to_ndarray(pixmap):
from qtpy.QtGui import QImage
img = pixmap.toImage()
img = img.convertToFormat(QImage.Format.Format_RGBA8888)
ptr = img.bits()
count = img.height() * img.width() * 4
if hasattr(ptr, "setsize"): # PyQt
ptr.setsize(count)
data = np.frombuffer(ptr, dtype=np.uint8, count=count).copy()
data.shape = (img.height(), img.width(), 4)
return data / 255.0
def _notebook_vtk_works():
if sys.platform != "linux":
return True
# check if it's OSMesa -- if it is, continue
try:
from vtkmodules import vtkRenderingOpenGL2
vtkRenderingOpenGL2.vtkOSOpenGLRenderWindow
except Exception:
pass
else:
return True # has vtkOSOpenGLRenderWindow (OSMesa build)
# if it's not OSMesa, we need to check display validity
if _display_is_valid():
return True
return False
def _qt_safe_window(
*, splash="figure.splash", window="figure.plotter.app_window", always_close=True
):
def dec(meth, splash=splash, always_close=always_close):
@functools.wraps(meth)
def func(self, *args, **kwargs):
close_splash = always_close
error = False
try:
meth(self, *args, **kwargs)
except Exception:
close_splash = error = True
raise
finally:
for attr, do_close in ((splash, close_splash), (window, error)):
if attr is None or not do_close:
continue
parent = self
name = attr.split(".")[-1]
try:
for n in attr.split(".")[:-1]:
parent = getattr(parent, n)
if name:
widget = getattr(parent, name, False)
else: # empty string means "self"
widget = parent
if widget:
widget.close()
del widget
except Exception:
pass
finally:
try:
delattr(parent, name)
except Exception:
pass
return func
return dec

View File

@@ -0,0 +1,587 @@
"""Core visualization operations."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import importlib
import time
from contextlib import contextmanager
from functools import partial
import numpy as np
from ...utils import (
_auto_weakref,
_check_option,
_validate_type,
fill_doc,
get_config,
logger,
verbose,
)
from .._3d import _get_3d_option
from ..utils import safe_event
from ._utils import VALID_3D_BACKENDS
MNE_3D_BACKEND = None
MNE_3D_BACKEND_TESTING = False
_backend_name_map = dict(
pyvistaqt="._qt",
notebook="._notebook",
)
backend = None
def _reload_backend(backend_name):
global backend
backend = importlib.import_module(
name=_backend_name_map[backend_name], package="mne.viz.backends"
)
logger.info(f"Using {backend_name} 3d backend.")
def _get_backend():
_get_3d_backend()
return backend
def _get_renderer(*args, **kwargs):
_get_3d_backend()
return backend._Renderer(*args, **kwargs)
def _check_3d_backend_name(backend_name):
_validate_type(backend_name, str, "backend_name")
backend_name = "pyvistaqt" if backend_name == "pyvista" else backend_name
_check_option("backend_name", backend_name, VALID_3D_BACKENDS)
return backend_name
@verbose
def set_3d_backend(backend_name, verbose=None):
"""Set the 3D backend for MNE.
The backend will be set as specified and operations will use
that backend.
Parameters
----------
backend_name : str
The 3d backend to select. See Notes for the capabilities of each
backend (``'pyvistaqt'`` and ``'notebook'``).
.. versionchanged:: 0.24
The ``'pyvista'`` backend was renamed ``'pyvistaqt'``.
%(verbose)s
Returns
-------
old_backend_name : str | None
The old backend that was in use.
Notes
-----
To use PyVista, set ``backend_name`` to ``pyvistaqt`` but the value
``pyvista`` is still supported for backward compatibility.
This table shows the capabilities of each backend ("" for full support,
and "-" for partial support):
.. table::
:widths: auto
+--------------------------------------+-----------+----------+
| **3D function:** | pyvistaqt | notebook |
+======================================+===========+==========+
| :func:`plot_vector_source_estimates` | ✓ | ✓ |
+--------------------------------------+-----------+----------+
| :func:`plot_source_estimates` | ✓ | ✓ |
+--------------------------------------+-----------+----------+
| :func:`plot_alignment` | ✓ | ✓ |
+--------------------------------------+-----------+----------+
| :func:`plot_sparse_source_estimates` | ✓ | ✓ |
+--------------------------------------+-----------+----------+
| :func:`plot_evoked_field` | ✓ | ✓ |
+--------------------------------------+-----------+----------+
| :func:`snapshot_brain_montage` | ✓ | ✓ |
+--------------------------------------+-----------+----------+
| :func:`link_brains` | ✓ | |
+--------------------------------------+-----------+----------+
+--------------------------------------+-----------+----------+
| **Feature:** |
+--------------------------------------+-----------+----------+
| Large data | ✓ | ✓ |
+--------------------------------------+-----------+----------+
| Opacity/transparency | ✓ | ✓ |
+--------------------------------------+-----------+----------+
| Support geometric glyph | ✓ | ✓ |
+--------------------------------------+-----------+----------+
| Smooth shading | ✓ | ✓ |
+--------------------------------------+-----------+----------+
| Subplotting | ✓ | ✓ |
+--------------------------------------+-----------+----------+
| Inline plot in Jupyter Notebook | | ✓ |
+--------------------------------------+-----------+----------+
| Inline plot in JupyterLab | | ✓ |
+--------------------------------------+-----------+----------+
| Inline plot in Google Colab | | |
+--------------------------------------+-----------+----------+
| Toolbar | ✓ | ✓ |
+--------------------------------------+-----------+----------+
"""
global MNE_3D_BACKEND
old_backend_name = MNE_3D_BACKEND
backend_name = _check_3d_backend_name(backend_name)
if MNE_3D_BACKEND != backend_name:
_reload_backend(backend_name)
MNE_3D_BACKEND = backend_name
return old_backend_name
def get_3d_backend():
"""Return the 3D backend currently used.
Returns
-------
backend_used : str | None
The 3d backend currently in use. If no backend is found,
returns ``None``.
.. versionchanged:: 0.24
The ``'pyvista'`` backend has been renamed ``'pyvistaqt'``, so
``'pyvista'`` is no longer returned by this function.
"""
try:
backend = _get_3d_backend()
except RuntimeError as exc:
backend = None
logger.info(str(exc))
return backend
def _get_3d_backend():
"""Load and return the current 3d backend."""
global MNE_3D_BACKEND
if MNE_3D_BACKEND is None:
MNE_3D_BACKEND = get_config(key="MNE_3D_BACKEND", default=None)
if MNE_3D_BACKEND is None: # try them in order
errors = dict()
for name in VALID_3D_BACKENDS:
try:
_reload_backend(name)
except ImportError as exc:
errors[name] = str(exc)
else:
MNE_3D_BACKEND = name
break
else:
raise RuntimeError(
"Could not load any valid 3D backend\n"
+ "\n".join(f"{key}: {val}" for key, val in errors.items())
+ "\n".join(
(
"\n\n install pyvistaqt, using pip or conda:",
"'pip install pyvistaqt'",
"'conda install -c conda-forge pyvistaqt'",
"\n or install ipywidgets, "
+ "if using a notebook backend",
"'pip install ipywidgets'",
"'conda install -c conda-forge ipywidgets'",
)
)
)
else:
MNE_3D_BACKEND = _check_3d_backend_name(MNE_3D_BACKEND)
_reload_backend(MNE_3D_BACKEND)
MNE_3D_BACKEND = _check_3d_backend_name(MNE_3D_BACKEND)
return MNE_3D_BACKEND
@contextmanager
def use_3d_backend(backend_name):
"""Create a 3d visualization context using the designated backend.
See :func:`mne.viz.set_3d_backend` for more details on the available
3d backends and their capabilities.
Parameters
----------
backend_name : {'pyvistaqt', 'notebook'}
The 3d backend to use in the context.
"""
old_backend = set_3d_backend(backend_name)
try:
yield
finally:
if old_backend is not None:
try:
set_3d_backend(old_backend)
except Exception:
pass
@contextmanager
def _use_test_3d_backend(backend_name, interactive=False):
"""Create a testing viz context.
Parameters
----------
backend_name : str
The 3d backend to use in the context.
interactive : bool
If True, ensure interactive elements are accessible.
"""
with _actors_invisible():
with use_3d_backend(backend_name):
with backend._testing_context(interactive):
yield
@contextmanager
def _actors_invisible():
global MNE_3D_BACKEND_TESTING
orig_testing = MNE_3D_BACKEND_TESTING
MNE_3D_BACKEND_TESTING = True
try:
yield
finally:
MNE_3D_BACKEND_TESTING = orig_testing
@fill_doc
def set_3d_view(
figure,
azimuth=None,
elevation=None,
focalpoint=None,
distance=None,
roll=None,
):
"""Configure the view of the given scene.
Parameters
----------
figure : object
The scene which is modified.
%(azimuth)s
%(elevation)s
%(focalpoint)s
%(distance)s
%(roll)s
"""
backend._set_3d_view(
figure=figure,
azimuth=azimuth,
elevation=elevation,
focalpoint=focalpoint,
distance=distance,
roll=roll,
)
@fill_doc
def set_3d_title(figure, title, size=40, *, color="white", position="upper_left"):
"""Configure the title of the given scene.
Parameters
----------
figure : object
The scene which is modified.
title : str
The title of the scene.
size : int
The size of the title.
color : matplotlib color
The color of the title.
.. versionadded:: 1.9
position : str
The position to use, e.g., "upper_left". See
:meth:`pyvista.Plotter.add_text` for details.
.. versionadded:: 1.9
Returns
-------
text : object
The text object returned by the given backend.
.. versionadded:: 1.0
"""
return backend._set_3d_title(
figure=figure, title=title, size=size, color=color, position=position
)
def create_3d_figure(
size,
bgcolor=(0, 0, 0),
smooth_shading=None,
handle=None,
*,
scene=True,
show=False,
title="MNE 3D Figure",
):
"""Return an empty figure based on the current 3d backend.
.. warning:: Proceed with caution when the renderer object is
returned (with ``scene=False``) because the _Renderer
API is not necessarily stable enough for production,
it's still actively in development.
Parameters
----------
size : tuple
The dimensions of the 3d figure (width, height).
bgcolor : tuple
The color of the background.
smooth_shading : bool | None
Whether to enable smooth shading. If ``None``, uses the config value
``MNE_3D_OPTION_SMOOTH_SHADING``. Defaults to ``None``.
handle : int | None
The figure identifier.
scene : bool
If True (default), the returned object is the Figure3D. If False,
an advanced, undocumented Renderer object is returned (the API is not
stable or documented, so this is not recommended).
show : bool
If True, show the renderer immediately.
.. versionadded:: 1.0
title : str
The window title to use (if applicable).
.. versionadded:: 1.9
Returns
-------
figure : instance of Figure3D or ``Renderer``
The requested empty figure or renderer, depending on ``scene``.
"""
_validate_type(smooth_shading, (bool, None), "smooth_shading")
if smooth_shading is None:
smooth_shading = _get_3d_option("smooth_shading")
renderer = _get_renderer(
fig=handle,
size=size,
bgcolor=bgcolor,
smooth_shading=smooth_shading,
show=show,
name=title,
)
if scene:
return renderer.scene()
else:
return renderer
def close_3d_figure(figure):
"""Close the given scene.
Parameters
----------
figure : object
The scene which needs to be closed.
"""
backend._close_3d_figure(figure)
def close_all_3d_figures():
"""Close all the scenes of the current 3d backend."""
backend._close_all()
def get_brain_class():
"""Return the proper Brain class based on the current 3d backend.
Returns
-------
brain : object
The Brain class corresponding to the current 3d backend.
"""
from ...viz._brain import Brain
return Brain
class _TimeInteraction:
"""Mixin enabling time interaction controls."""
def _enable_time_interaction(
self,
fig,
current_time_func,
times,
init_playback_speed=0.01,
playback_speed_range=(0.01, 0.1),
):
from ..ui_events import (
PlaybackSpeed,
TimeChange,
publish,
subscribe,
)
self._fig = fig
self._current_time_func = current_time_func
self._times = times
self._init_time = current_time_func()
self._init_playback_speed = init_playback_speed
if not hasattr(self, "_dock"):
self._dock_initialize()
if not hasattr(self, "_tool_bar") or self._tool_bar is None:
self._tool_bar_initialize(name="Toolbar")
if not hasattr(self, "_widgets"):
self._widgets = dict()
# Dock widgets
@_auto_weakref
def publish_time_change(time_index):
publish(
fig,
TimeChange(time=np.interp(time_index, np.arange(len(times)), times)),
)
layout = self._dock_add_group_box("")
self._widgets["time_slider"] = self._dock_add_slider(
name="Time (s)",
value=np.interp(current_time_func(), times, np.arange(len(times))),
rng=[0, len(times) - 1],
double=True,
callback=publish_time_change,
compact=False,
layout=layout,
)
hlayout = self._dock_add_layout(vertical=False)
self._widgets["min_time"] = self._dock_add_label("-", layout=hlayout)
self._dock_add_stretch(hlayout)
self._widgets["current_time"] = self._dock_add_label(value="x", layout=hlayout)
self._dock_add_stretch(hlayout)
self._widgets["max_time"] = self._dock_add_label(value="+", layout=hlayout)
self._layout_add_widget(layout, hlayout)
self._widgets["min_time"].set_value(f"{times[0]: .3f}")
self._widgets["current_time"].set_value(f"{current_time_func(): .3f}")
self._widgets["max_time"].set_value(f"{times[-1]: .3f}")
@_auto_weakref
def publish_playback_speed(speed):
publish(fig, PlaybackSpeed(speed=speed))
self._widgets["playback_speed"] = self._dock_add_spin_box(
name="Speed",
value=init_playback_speed,
rng=playback_speed_range,
callback=publish_playback_speed,
layout=layout,
)
# Tool bar buttons
self._widgets["reset"] = self._tool_bar_add_button(
name="reset", desc="Reset", func=self._reset_time
)
self._widgets["play"] = self._tool_bar_add_play_button(
name="play",
desc="Play/Pause",
func=self._toggle_playback,
shortcut=" ",
)
# Configure playback
self._playback = False
self._playback_initialize(
func=self._play,
timeout=17,
value=np.interp(current_time_func(), times, np.arange(len(times))),
rng=[0, len(times) - 1],
time_widget=self._widgets["time_slider"],
play_widget=self._widgets["play"],
)
# Keyboard shortcuts
@_auto_weakref
def shift_time(direction):
amount = self._widgets["playback_speed"].get_value()
publish(
self._fig,
TimeChange(time=self._current_time_func() + direction * amount),
)
if self.plotter.iren is not None:
self.plotter.add_key_event("n", partial(shift_time, direction=1))
self.plotter.add_key_event("b", partial(shift_time, direction=-1))
# Subscribe to relevant UI events
subscribe(fig, "time_change", self._on_time_change)
subscribe(fig, "playback_speed", self._on_playback_speed)
def _on_time_change(self, event):
"""Respond to time_change UI event."""
from ..ui_events import disable_ui_events
new_time = np.clip(event.time, self._times[0], self._times[-1])
new_time_idx = np.interp(new_time, self._times, np.arange(len(self._times)))
with disable_ui_events(self._fig):
self._widgets["time_slider"].set_value(new_time_idx)
self._widgets["current_time"].set_value(f"{new_time:.3f}")
def _on_playback_speed(self, event):
"""Respond to playback_speed UI event."""
from ..ui_events import disable_ui_events
with disable_ui_events(self._fig):
self._widgets["playback_speed"].set_value(event.speed)
def _toggle_playback(self, value=None):
"""Toggle time playback."""
from ..ui_events import TimeChange, publish
if value is None:
self._playback = not self._playback
else:
self._playback = value
if self._playback:
self._tool_bar_update_button_icon(name="play", icon_name="pause")
if self._current_time_func() == self._times[-1]: # start over
publish(self._fig, TimeChange(time=self._times[0]))
self._last_tick = time.time()
else:
self._tool_bar_update_button_icon(name="play", icon_name="play")
def _reset_time(self):
"""Reset time and playback speed to initial values."""
from ..ui_events import PlaybackSpeed, TimeChange, publish
publish(self._fig, TimeChange(time=self._init_time))
publish(self._fig, PlaybackSpeed(speed=self._init_playback_speed))
@safe_event
def _play(self):
if self._playback:
try:
self._advance()
except Exception:
self._toggle_playback(value=False)
raise
def _advance(self):
from ..ui_events import TimeChange, publish
this_time = time.time()
delta = this_time - self._last_tick
self._last_tick = time.time()
time_shift = delta * self._widgets["playback_speed"].get_value()
new_time = min(self._current_time_func() + time_shift, self._times[-1])
publish(self._fig, TimeChange(time=new_time))
if new_time == self._times[-1]:
self._toggle_playback(value=False)

469
mne/viz/circle.py Normal file
View File

@@ -0,0 +1,469 @@
"""Functions to plot on circle as for connectivity."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from functools import partial
from itertools import cycle
import numpy as np
from ..utils import _validate_type
from .utils import _get_cmap, plt_show
def circular_layout(
node_names,
node_order,
start_pos=90,
start_between=True,
group_boundaries=None,
group_sep=10,
):
"""Create layout arranging nodes on a circle.
Parameters
----------
node_names : list of str
Node names.
node_order : list of str
List with node names defining the order in which the nodes are
arranged. Must have the elements as node_names but the order can be
different. The nodes are arranged clockwise starting at "start_pos"
degrees.
start_pos : float
Angle in degrees that defines where the first node is plotted.
start_between : bool
If True, the layout starts with the position between the nodes. This is
the same as adding "180. / len(node_names)" to start_pos.
group_boundaries : None | array-like
List of of boundaries between groups at which point a "group_sep" will
be inserted. E.g. "[0, len(node_names) / 2]" will create two groups.
group_sep : float
Group separation angle in degrees. See "group_boundaries".
Returns
-------
node_angles : array, shape=(n_node_names,)
Node angles in degrees.
"""
n_nodes = len(node_names)
if len(node_order) != n_nodes:
raise ValueError("node_order has to be the same length as node_names")
if group_boundaries is not None:
boundaries = np.array(group_boundaries, dtype=np.int64)
if np.any(boundaries >= n_nodes) or np.any(boundaries < 0):
raise ValueError('"group_boundaries" has to be between 0 and n_nodes - 1.')
if len(boundaries) > 1 and np.any(np.diff(boundaries) <= 0):
raise ValueError('"group_boundaries" must have non-decreasing values.')
n_group_sep = len(group_boundaries)
else:
n_group_sep = 0
boundaries = None
# convert it to a list with indices
node_order = [node_order.index(name) for name in node_names]
node_order = np.array(node_order)
if len(np.unique(node_order)) != n_nodes:
raise ValueError("node_order has repeated entries")
node_sep = (360.0 - n_group_sep * group_sep) / n_nodes
if start_between:
start_pos += node_sep / 2
if boundaries is not None and boundaries[0] == 0:
# special case when a group separator is at the start
start_pos += group_sep / 2
boundaries = boundaries[1:] if n_group_sep > 1 else None
node_angles = np.ones(n_nodes, dtype=np.float64) * node_sep
node_angles[0] = start_pos
if boundaries is not None:
node_angles[boundaries] += group_sep
node_angles = np.cumsum(node_angles)[node_order]
return node_angles
def _plot_connectivity_circle_onpick(
event,
fig=None,
ax=None,
indices=None,
n_nodes=0,
node_angles=None,
ylim=(9, 10),
):
"""Isolate connections around a single node when user left clicks a node.
On right click, resets all connections.
"""
if event.inaxes != ax:
return
if event.button == 1: # left click
# click must be near node radius
if not ylim[0] <= event.ydata <= ylim[1]:
return
# all angles in range [0, 2*pi]
node_angles = node_angles % (np.pi * 2)
node = np.argmin(np.abs(event.xdata - node_angles))
patches = event.inaxes.patches
for ii, (x, y) in enumerate(zip(indices[0], indices[1])):
patches[ii].set_visible(node in [x, y])
fig.canvas.draw()
elif event.button == 3: # right click
patches = event.inaxes.patches
for ii in range(np.size(indices, axis=1)):
patches[ii].set_visible(True)
fig.canvas.draw()
def _plot_connectivity_circle(
con,
node_names,
indices=None,
n_lines=None,
node_angles=None,
node_width=None,
node_height=None,
node_colors=None,
facecolor="black",
textcolor="white",
node_edgecolor="black",
linewidth=1.5,
colormap="hot",
vmin=None,
vmax=None,
colorbar=True,
title=None,
colorbar_size=None,
colorbar_pos=None,
fontsize_title=12,
fontsize_names=8,
fontsize_colorbar=8,
padding=6.0,
ax=None,
interactive=True,
node_linewidth=2.0,
show=True,
):
import matplotlib.patches as m_patches
import matplotlib.path as m_path
import matplotlib.pyplot as plt
from matplotlib.projections.polar import PolarAxes
_validate_type(ax, (None, PolarAxes))
n_nodes = len(node_names)
if node_angles is not None:
if len(node_angles) != n_nodes:
raise ValueError("node_angles has to be the same length as node_names")
# convert it to radians
node_angles = node_angles * np.pi / 180
else:
# uniform layout on unit circle
node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
if node_width is None:
# widths correspond to the minimum angle between two nodes
dist_mat = node_angles[None, :] - node_angles[:, None]
dist_mat[np.diag_indices(n_nodes)] = 1e9
node_width = np.min(np.abs(dist_mat))
else:
node_width = node_width * np.pi / 180
if node_height is None:
node_height = 1.0
if node_colors is not None:
if len(node_colors) < n_nodes:
node_colors = cycle(node_colors)
else:
# assign colors using colormap
try:
spectral = plt.cm.spectral
except AttributeError:
spectral = plt.cm.Spectral
node_colors = [spectral(i / float(n_nodes)) for i in range(n_nodes)]
# handle 1D and 2D connectivity information
if con.ndim == 1:
if indices is None:
raise ValueError("indices has to be provided if con.ndim == 1")
elif con.ndim == 2:
if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
raise ValueError("con has to be 1D or a square matrix")
# we use the lower-triangular part
indices = np.tril_indices(n_nodes, -1)
con = con[indices]
else:
raise ValueError("con has to be 1D or a square matrix")
# get the colormap
colormap = _get_cmap(colormap)
# Use a polar axes
if ax is None:
fig = plt.figure(figsize=(8, 8), facecolor=facecolor, layout="constrained")
ax = fig.add_subplot(polar=True)
else:
fig = ax.figure
ax.set_facecolor(facecolor)
# No ticks, we'll put our own
ax.set_xticks([])
ax.set_yticks([])
# Set y axes limit, add additional space if requested
ax.set_ylim(0, 10 + padding)
# Remove the black axes border which may obscure the labels
ax.spines["polar"].set_visible(False)
# Draw lines between connected nodes, only draw the strongest connections
if n_lines is not None and len(con) > n_lines:
con_thresh = np.sort(np.abs(con).ravel())[-n_lines]
else:
con_thresh = 0.0
# get the connections which we are drawing and sort by connection strength
# this will allow us to draw the strongest connections first
con_abs = np.abs(con)
con_draw_idx = np.where(con_abs >= con_thresh)[0]
con = con[con_draw_idx]
con_abs = con_abs[con_draw_idx]
indices = [ind[con_draw_idx] for ind in indices]
# now sort them
sort_idx = np.argsort(con_abs)
del con_abs
con = con[sort_idx]
indices = [ind[sort_idx] for ind in indices]
# Get vmin vmax for color scaling
if vmin is None:
vmin = np.min(con[np.abs(con) >= con_thresh])
if vmax is None:
vmax = np.max(con)
vrange = vmax - vmin
# We want to add some "noise" to the start and end position of the
# edges: We modulate the noise with the number of connections of the
# node and the connection strength, such that the strongest connections
# are closer to the node center
nodes_n_con = np.zeros((n_nodes), dtype=np.int64)
for i, j in zip(indices[0], indices[1]):
nodes_n_con[i] += 1
nodes_n_con[j] += 1
# initialize random number generator so plot is reproducible
rng = np.random.mtrand.RandomState(0)
n_con = len(indices[0])
noise_max = 0.25 * node_width
start_noise = rng.uniform(-noise_max, noise_max, n_con)
end_noise = rng.uniform(-noise_max, noise_max, n_con)
nodes_n_con_seen = np.zeros_like(nodes_n_con)
for i, (start, end) in enumerate(zip(indices[0], indices[1])):
nodes_n_con_seen[start] += 1
nodes_n_con_seen[end] += 1
start_noise[i] *= (nodes_n_con[start] - nodes_n_con_seen[start]) / float(
nodes_n_con[start]
)
end_noise[i] *= (nodes_n_con[end] - nodes_n_con_seen[end]) / float(
nodes_n_con[end]
)
# scale connectivity for colormap (vmin<=>0, vmax<=>1)
con_val_scaled = (con - vmin) / vrange
# Finally, we draw the connections
for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
# Start point
t0, r0 = node_angles[i], 10
# End point
t1, r1 = node_angles[j], 10
# Some noise in start and end point
t0 += start_noise[pos]
t1 += end_noise[pos]
verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
codes = [
m_path.Path.MOVETO,
m_path.Path.CURVE4,
m_path.Path.CURVE4,
m_path.Path.LINETO,
]
path = m_path.Path(verts, codes)
color = colormap(con_val_scaled[pos])
# Actual line
patch = m_patches.PathPatch(
path, fill=False, edgecolor=color, linewidth=linewidth, alpha=1.0
)
ax.add_patch(patch)
# Draw ring with colored nodes
height = np.ones(n_nodes) * node_height
bars = ax.bar(
node_angles,
height,
width=node_width,
bottom=9,
edgecolor=node_edgecolor,
lw=node_linewidth,
facecolor=".9",
align="center",
)
for bar, color in zip(bars, node_colors):
bar.set_facecolor(color)
# Draw node labels
angles_deg = 180 * node_angles / np.pi
for name, angle_rad, angle_deg in zip(node_names, node_angles, angles_deg):
if angle_deg >= 270:
ha = "left"
else:
# Flip the label, so text is always upright
angle_deg += 180
ha = "right"
ax.text(
angle_rad,
9.4 + node_height,
name,
size=fontsize_names,
rotation=angle_deg,
rotation_mode="anchor",
horizontalalignment=ha,
verticalalignment="center",
color=textcolor,
)
if title is not None:
ax.set_title(title, color=textcolor, fontsize=fontsize_title)
if colorbar:
sm = plt.cm.ScalarMappable(cmap=colormap, norm=plt.Normalize(vmin, vmax))
sm.set_array(np.linspace(vmin, vmax))
colorbar_kwargs = dict()
if colorbar_size is not None:
colorbar_kwargs.update(shrink=colorbar_size)
if colorbar_pos is not None:
colorbar_kwargs.update(anchor=colorbar_pos)
cb = fig.colorbar(sm, ax=ax, **colorbar_kwargs)
cb_yticks = plt.getp(cb.ax.axes, "yticklabels")
cb.ax.tick_params(labelsize=fontsize_colorbar)
plt.setp(cb_yticks, color=textcolor)
# Add callback for interaction
if interactive:
callback = partial(
_plot_connectivity_circle_onpick,
fig=fig,
ax=ax,
indices=indices,
n_nodes=n_nodes,
node_angles=node_angles,
)
fig.canvas.mpl_connect("button_press_event", callback)
plt_show(show)
return fig, ax
def plot_channel_labels_circle(labels, colors=None, picks=None, **kwargs):
"""Plot labels for each channel in a circle plot.
.. note:: This primarily makes sense for sEEG channels where each
channel can be assigned an anatomical label as the electrode
passes through various brain areas.
Parameters
----------
labels : dict
Lists of labels (values) associated with each channel (keys).
colors : dict
The color (value) for each label (key).
picks : list | tuple
The channels to consider.
**kwargs : kwargs
Keyword arguments for
:func:`mne_connectivity.viz.plot_connectivity_circle`.
Returns
-------
fig : instance of matplotlib.figure.Figure
The figure handle.
axes : instance of matplotlib.projections.polar.PolarAxes
The subplot handle.
"""
from matplotlib.colors import LinearSegmentedColormap
_validate_type(labels, dict, "labels")
_validate_type(colors, (dict, None), "colors")
_validate_type(picks, (list, tuple, None), "picks")
if picks is not None:
labels = {k: v for k, v in labels.items() if k in picks}
ch_names = list(labels.keys())
all_labels = list(set([label for val in labels.values() for label in val]))
n_labels = len(all_labels)
if colors is not None:
for label in all_labels:
if label not in colors:
raise ValueError(f"No color provided for {label} in `colors`")
# update all_labels, there may be unconnected labels in colors
all_labels = list(colors.keys())
n_labels = len(all_labels)
# make colormap
label_colors = [colors[label] for label in all_labels]
node_colors = ["black"] * len(ch_names) + label_colors
label_cmap = LinearSegmentedColormap.from_list(
"label_cmap", label_colors, N=len(label_colors)
)
else:
node_colors = None
node_names = ch_names + all_labels
con = np.zeros((len(node_names), len(node_names))) * np.nan
for idx, ch_name in enumerate(ch_names):
for label in labels[ch_name]:
node_idx = node_names.index(label)
label_color = all_labels.index(label) / n_labels
con[idx, node_idx] = con[node_idx, idx] = label_color # symmetric
# plot
node_order = ch_names + all_labels[::-1]
node_angles = circular_layout(
node_names, node_order, start_pos=90, group_boundaries=[0, len(ch_names)]
)
# provide defaults but don't overwrite
if "node_angles" not in kwargs:
kwargs.update(node_angles=node_angles)
if "colorbar" not in kwargs:
kwargs.update(colorbar=False)
if "node_colors" not in kwargs:
kwargs.update(node_colors=node_colors)
if "vmin" not in kwargs:
kwargs.update(vmin=0)
if "vmax" not in kwargs:
kwargs.update(vmax=1)
if "colormap" not in kwargs:
kwargs.update(colormap=label_cmap)
return _plot_connectivity_circle(con, node_names, **kwargs)

47
mne/viz/conftest.py Normal file
View File

@@ -0,0 +1,47 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import os.path as op
import numpy as np
import pytest
from mne import Epochs, EvokedArray, create_info, events_from_annotations
from mne.channels import make_standard_montage
from mne.datasets.testing import _pytest_param, data_path
from mne.io import read_raw_nirx
from mne.preprocessing.nirs import beer_lambert_law, optical_density
fname_nirx = op.join(
data_path(download=False), "NIRx", "nirscout", "nirx_15_2_recording_w_overlap"
)
@pytest.fixture()
def fnirs_evoked():
"""Create an fnirs evoked structure."""
montage = make_standard_montage("biosemi16")
ch_names = montage.ch_names
ch_types = ["eeg"] * 16
info = create_info(ch_names=ch_names, sfreq=20, ch_types=ch_types)
evoked_data = np.random.randn(16, 30)
evoked = EvokedArray(evoked_data, info=info, tmin=-0.2, nave=4)
evoked.set_montage(montage)
evoked.set_channel_types(
{"Fp1": "hbo", "Fp2": "hbo", "F4": "hbo", "Fz": "hbo"}, verbose="error"
)
return evoked
@pytest.fixture(params=[_pytest_param()])
def fnirs_epochs():
"""Create an fnirs epoch structure."""
raw_intensity = read_raw_nirx(fname_nirx, preload=False)
raw_od = optical_density(raw_intensity)
raw_haemo = beer_lambert_law(raw_od, ppf=6.0)
evts, _ = events_from_annotations(raw_haemo, event_id={"1.0": 1})
evts_dct = {"A": 1}
tn, tx = -1, 2
epochs = Epochs(raw_haemo, evts, event_id=evts_dct, tmin=tn, tmax=tx)
return epochs

1160
mne/viz/epochs.py Normal file

File diff suppressed because it is too large Load Diff

3207
mne/viz/evoked.py Normal file

File diff suppressed because it is too large Load Diff

577
mne/viz/evoked_field.py Normal file
View File

@@ -0,0 +1,577 @@
"""Class to draw evoked MEG and EEG fieldlines, with a GUI to control the figure.
author: Marijn van Vliet <w.m.vanvliet@gmail.com>
"""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from functools import partial
import numpy as np
from scipy.interpolate import interp1d
from .._fiff.pick import pick_types
from ..defaults import DEFAULTS
from ..utils import (
_auto_weakref,
_check_option,
_ensure_int,
_to_rgb,
_validate_type,
fill_doc,
)
from ._3d_overlay import _LayeredMesh
from .ui_events import (
ColormapRange,
Contours,
TimeChange,
disable_ui_events,
publish,
subscribe,
)
from .utils import mne_analyze_colormap
@fill_doc
class EvokedField:
"""Plot MEG/EEG fields on head surface and helmet in 3D.
Parameters
----------
evoked : instance of mne.Evoked
The evoked object.
surf_maps : list
The surface mapping information obtained with make_field_map.
time : float | None
The time point at which the field map shall be displayed. If None,
the average peak latency (across sensor types) is used.
time_label : str | None
How to print info about the time instant visualized.
%(n_jobs)s
fig : instance of Figure3D | None
If None (default), a new figure will be created, otherwise it will
plot into the given figure.
.. versionadded:: 0.20
vmax : float | dict | None
Maximum intensity. Can be a dictionary with two entries ``"eeg"`` and ``"meg"``
to specify separate values for EEG and MEG fields respectively. Can be
``None`` to use the maximum value of the data.
.. versionadded:: 0.21
.. versionadded:: 1.4
``vmax`` can be a dictionary to specify separate values for EEG and
MEG fields.
n_contours : int
The number of contours.
.. versionadded:: 0.21
show_density : bool
Whether to draw the field density as an overlay on top of the helmet/head
surface. Defaults to ``True``.
alpha : float | dict | None
Opacity of the meshes (between 0 and 1). Can be a dictionary with two
entries ``"eeg"`` and ``"meg"`` to specify separate values for EEG and
MEG fields respectively. Can be ``None`` to use 1.0 when a single field
map is shown, or ``dict(eeg=1.0, meg=0.5)`` when both field maps are shown.
.. versionadded:: 1.4
%(interpolation_brain_time)s
.. versionadded:: 1.6
%(interaction_scene)s
Defaults to ``'terrain'``.
.. versionadded:: 1.1
time_viewer : bool | str
Display time viewer GUI. Can also be ``"auto"``, which will mean
``True`` if there is more than one time point and ``False`` otherwise.
.. versionadded:: 1.6
%(verbose)s
Notes
-----
The figure will publish and subscribe to the following UI events:
* :class:`~mne.viz.ui_events.TimeChange`
* :class:`~mne.viz.ui_events.Contours`, ``kind="field_strength_meg" | "field_strength_eeg"``
* :class:`~mne.viz.ui_events.ColormapRange`, ``kind="field_strength_meg" | "field_strength_eeg"``
""" # noqa
def __init__(
self,
evoked,
surf_maps,
*,
time=None,
time_label="t = %0.0f ms",
n_jobs=None,
fig=None,
vmax=None,
n_contours=21,
show_density=True,
alpha=None,
interpolation="nearest",
interaction="terrain",
time_viewer="auto",
verbose=None,
):
from .backends.renderer import _get_3d_backend, _get_renderer
# Setup figure parameters
self._evoked = evoked
if time is None:
types = [t for t in ["eeg", "grad", "mag"] if t in evoked]
time = np.mean([evoked.get_peak(ch_type=t)[1] for t in types])
self._current_time = time
if not evoked.times[0] <= time <= evoked.times[-1]:
raise ValueError(f"`time` ({time:0.3f}) must be inside `evoked.times`")
self._time_label = time_label
self._vmax = _validate_type(vmax, (None, "numeric", dict), "vmax")
self._n_contours = _ensure_int(n_contours, "n_contours")
self._time_interpolation = _check_option(
"interpolation",
interpolation,
("linear", "nearest", "zero", "slinear", "quadratic", "cubic"),
)
self._interaction = _check_option(
"interaction", interaction, ["trackball", "terrain"]
)
surf_map_kinds = [surf_map["kind"] for surf_map in surf_maps]
if vmax is None:
self._vmax = {kind: None for kind in surf_map_kinds}
elif isinstance(vmax, dict):
for kind in surf_map_kinds:
if kind not in vmax:
raise ValueError(
f'No entry for "{kind}" found in the vmax dictionary'
)
self._vmax = vmax
else: # float value
self._vmax = {kind: vmax for kind in surf_map_kinds}
if alpha is None:
self._alpha = {
surf_map["kind"]: val for surf_map, val in zip(surf_maps, [1.0, 0.5])
}
elif isinstance(alpha, dict):
for kind in surf_map_kinds:
if kind not in alpha:
raise ValueError(
f'No entry for "{kind}" found in the alpha dictionary'
)
self._alpha = alpha
else: # float value
self._alpha = {kind: alpha for kind in surf_map_kinds}
self._colors = [(0.6, 0.6, 0.6), (1.0, 1.0, 1.0)]
self._colormap = mne_analyze_colormap(format="vtk")
self._colormap_lines = np.concatenate(
[
np.tile([0.0, 0.0, 255.0, 255.0], (127, 1)),
np.tile([0.0, 0.0, 0.0, 255.0], (2, 1)),
np.tile([255.0, 0.0, 0.0, 255.0], (127, 1)),
]
)
self._show_density = show_density
from ._brain import Brain
if isinstance(fig, Brain):
self._renderer = fig._renderer
self._in_brain_figure = True
if _get_3d_backend() == "notebook":
raise NotImplementedError(
"Plotting on top of an existing Brain figure "
"is currently not supported inside a notebook."
)
else:
self._renderer = _get_renderer(
fig, bgcolor=(0.0, 0.0, 0.0), size=(600, 600)
)
self._in_brain_figure = False
self.plotter = self._renderer.plotter
self.interaction = interaction
# Prepare the surface maps
self._surf_maps = [
self._prepare_surf_map(surf_map, color, self._alpha[surf_map["kind"]])
for surf_map, color in zip(surf_maps, self._colors)
]
# Do we want the time viewer?
if time_viewer == "auto":
time_viewer = len(evoked.times) > 1
self.time_viewer = time_viewer
# Configure UI events
@_auto_weakref
def current_time_func():
return self._current_time
self._widgets = dict()
if self.time_viewer:
# Draw widgets only if not inside a figure that already has them.
if (
not hasattr(self._renderer, "_widgets")
or "time_slider" not in self._renderer._widgets
):
self._renderer._enable_time_interaction(
self,
current_time_func=current_time_func,
times=evoked.times,
)
if not self._in_brain_figure or "time_slider" not in fig.widgets:
# Draw the time label
self._time_label = time_label
if time_label is not None:
if "%" in time_label:
time_label = time_label % np.round(1e3 * time)
self._time_label_actor = self._renderer.text2d(
x_window=0.01, y_window=0.01, text=time_label
)
self._configure_dock()
subscribe(self, "time_change", self._on_time_change)
subscribe(self, "colormap_range", self._on_colormap_range)
subscribe(self, "contours", self._on_contours)
if not self._in_brain_figure:
self._renderer.set_interaction(interaction)
self._renderer.set_camera(azimuth=10, elevation=60, distance="auto")
self._renderer.show()
def _prepare_surf_map(self, surf_map, color, alpha):
"""Compute all the data required to render a fieldlines map."""
if surf_map["kind"] == "eeg":
pick = pick_types(self._evoked.info, meg=False, eeg=True)
else:
pick = pick_types(self._evoked.info, meg=True, eeg=False, ref_meg=False)
evoked_ch_names = set([self._evoked.ch_names[k] for k in pick])
map_ch_names = set(surf_map["ch_names"])
if evoked_ch_names != map_ch_names:
message = ["Channels in map and data do not match."]
diff = map_ch_names - evoked_ch_names
if len(diff):
message += [f"{list(diff)} not in data file. "]
diff = evoked_ch_names - map_ch_names
if len(diff):
message += [f"{list(diff)} not in map file."]
raise RuntimeError(" ".join(message))
data = surf_map["data"] @ self._evoked.data[pick]
data_interp = interp1d(
self._evoked.times,
data,
kind=self._time_interpolation,
assume_sorted=True,
)
current_data = data_interp(self._current_time)
# Make a solid surface
surf = surf_map["surf"]
if self._in_brain_figure:
surf["rr"] *= 1000
map_vmax = self._vmax.get(surf_map["kind"])
if map_vmax is None:
map_vmax = float(np.max(current_data))
mesh = _LayeredMesh(
renderer=self._renderer,
vertices=surf["rr"],
triangles=surf["tris"],
normals=surf["nn"],
)
mesh.map()
color = _to_rgb(color, alpha=True)
cmap = np.array([(0, 0, 0, 0), color])
ctable = np.round(cmap * 255).astype(np.uint8)
mesh.add_overlay(
scalars=np.ones(len(current_data)),
colormap=ctable,
rng=[0, 1],
opacity=alpha,
name="surf",
)
# Show the field density
if self._show_density:
mesh.add_overlay(
scalars=current_data,
colormap=self._colormap,
rng=[-map_vmax, map_vmax],
opacity=1.0,
name="field",
)
# And the field lines on top
if self._n_contours > 1:
contours = np.linspace(-map_vmax, map_vmax, self._n_contours)
contours_actor, _ = self._renderer.contour(
surface=surf,
scalars=current_data,
contours=contours,
vmin=-map_vmax,
vmax=map_vmax,
colormap=self._colormap_lines,
)
else:
contours = None # noqa
contours_actor = None
return dict(
pick=pick,
data=data,
data_interp=data_interp,
map_kind=surf_map["kind"],
mesh=mesh,
contours=contours,
contours_actor=contours_actor,
surf=surf,
map_vmax=map_vmax,
)
def _update(self):
"""Update the figure to reflect the current settings."""
for surf_map in self._surf_maps:
current_data = surf_map["data_interp"](self._current_time)
surf_map["mesh"].update_overlay(name="field", scalars=current_data)
if surf_map["contours"] is not None:
self._renderer.plotter.remove_actor(
surf_map["contours_actor"], render=False
)
if self._n_contours > 1:
surf_map["contours_actor"], _ = self._renderer.contour(
surface=surf_map["surf"],
scalars=current_data,
contours=surf_map["contours"],
vmin=-surf_map["map_vmax"],
vmax=surf_map["map_vmax"],
colormap=self._colormap_lines,
)
if self._time_label is not None:
if hasattr(self, "_time_label_actor"):
self._renderer.plotter.remove_actor(
self._time_label_actor, render=False
)
time_label = self._time_label
if "%" in self._time_label:
time_label = self._time_label % np.round(1e3 * self._current_time)
self._time_label_actor = self._renderer.text2d(
x_window=0.01, y_window=0.01, text=time_label
)
self._renderer.plotter.update()
def _configure_dock(self):
"""Configure the widgets shown in the dock on the left."""
r = self._renderer
if not hasattr(r, "_dock"):
r._dock_initialize()
# Fieldline configuration
layout = r._dock_add_group_box("Fieldlines")
if self._show_density:
r._dock_add_label(value="max value", align=True, layout=layout)
@_auto_weakref
def _callback(vmax, kind, scaling):
self.set_vmax(vmax / scaling, kind=kind)
for surf_map in self._surf_maps:
if surf_map["map_kind"] == "meg":
scaling = DEFAULTS["scalings"]["grad"]
else:
scaling = DEFAULTS["scalings"]["eeg"]
rng = [0, np.max(np.abs(surf_map["data"])) * scaling]
hlayout = r._dock_add_layout(vertical=False)
self._widgets[f"vmax_slider_{surf_map['map_kind']}"] = (
r._dock_add_slider(
name=surf_map["map_kind"].upper(),
value=surf_map["map_vmax"] * scaling,
rng=rng,
callback=partial(
_callback, kind=surf_map["map_kind"], scaling=scaling
),
double=True,
layout=hlayout,
)
)
self._widgets[f"vmax_spin_{surf_map['map_kind']}"] = (
r._dock_add_spin_box(
name="",
value=surf_map["map_vmax"] * scaling,
rng=rng,
callback=partial(
_callback, kind=surf_map["map_kind"], scaling=scaling
),
layout=hlayout,
)
)
r._layout_add_widget(layout, hlayout)
hlayout = r._dock_add_layout(vertical=False)
r._dock_add_label(
value="Rescale",
align=True,
layout=hlayout,
)
r._dock_add_button(
name="",
callback=self._rescale,
layout=hlayout,
style="toolbutton",
)
r._layout_add_widget(layout, hlayout)
self._widgets["contours"] = r._dock_add_spin_box(
name="Contour lines",
value=21,
rng=[0, 99],
step=1,
double=False,
callback=self.set_contours,
layout=layout,
)
r._dock_finalize()
def _on_time_change(self, event):
"""Respond to time_change UI event."""
new_time = np.clip(event.time, self._evoked.times[0], self._evoked.times[-1])
if new_time == self._current_time:
return
self._current_time = new_time
self._update()
def _on_colormap_range(self, event):
"""Response to the colormap_range UI event."""
if event.kind == "field_strength_meg":
kind = "meg"
elif event.kind == "field_strength_eeg":
kind = "eeg"
else:
return
for surf_map in self._surf_maps:
if surf_map["map_kind"] == kind:
break
else:
# No field map currently shown of the requested type.
return
vmin = event.fmin
vmax = event.fmax
surf_map["contours"] = np.linspace(vmin, vmax, self._n_contours)
if self._show_density:
surf_map["mesh"].update_overlay(name="field", rng=[vmin, vmax])
# Update the GUI widgets
if kind == "meg":
scaling = DEFAULTS["scalings"]["grad"]
else:
scaling = DEFAULTS["scalings"]["eeg"]
with disable_ui_events(self):
widget = self._widgets.get(f"vmax_slider_{kind}", None)
if widget is not None:
widget.set_value(vmax * scaling)
widget = self._widgets.get(f"vmax_spin_{kind}", None)
if widget is not None:
widget.set_value(vmax * scaling)
self._update()
def _on_contours(self, event):
"""Respond to the contours UI event."""
if event.kind == "field_strength_meg":
kind = "meg"
elif event.kind == "field_strength_eeg":
kind = "eeg"
else:
return
for surf_map in self._surf_maps:
if surf_map["map_kind"] == kind:
break
surf_map["contours"] = event.contours
self._n_contours = len(event.contours)
with disable_ui_events(self):
if "contours" in self._widgets:
self._widgets["contours"].set_value(len(event.contours))
self._update()
def set_time(self, time):
"""Set the time to display (in seconds).
Parameters
----------
time : float
The time to show, in seconds.
"""
if self._evoked.times[0] <= time <= self._evoked.times[-1]:
publish(self, TimeChange(time=time))
else:
raise ValueError(
f"Requested time ({time} s) is outside the range of "
f"available times ({self._evoked.times[0]}-{self._evoked.times[-1]} s)."
)
def set_contours(self, n_contours):
"""Adjust the number of contour lines to use when drawing the fieldlines.
Parameters
----------
n_contours : int
The number of contour lines to use.
"""
for surf_map in self._surf_maps:
publish(
self,
Contours(
kind=f"field_strength_{surf_map['map_kind']}",
contours=np.linspace(
-surf_map["map_vmax"], surf_map["map_vmax"], n_contours
).tolist(),
),
)
def set_vmax(self, vmax, kind="meg"):
"""Change the color range of the density maps.
Parameters
----------
vmax : float
The new maximum value of the color range.
kind : 'meg' | 'eeg'
Which field map to apply the new color range to.
"""
_check_option("type", kind, ["eeg", "meg"])
for surf_map in self._surf_maps:
if surf_map["map_kind"] == kind:
publish(
self,
ColormapRange(
kind=f"field_strength_{kind}",
fmin=-vmax,
fmax=vmax,
),
)
break
else:
raise ValueError(f"No {type.upper()} field map currently shown.")
def _rescale(self):
"""Rescale the fieldlines and density maps to the current time point."""
for surf_map in self._surf_maps:
current_data = surf_map["data_interp"](self._current_time)
vmax = float(np.max(current_data))
self.set_vmax(vmax, kind=surf_map["map_kind"])

View File

@@ -0,0 +1,7 @@
"""Eye-tracking visualization routines."""
#
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from .heatmap import plot_gaze

View File

@@ -0,0 +1,209 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from scipy.ndimage import gaussian_filter
from ..._fiff.constants import FIFF
from ...utils import _validate_type, fill_doc, logger
from ..utils import plt_show
@fill_doc
def plot_gaze(
epochs,
*,
calibration=None,
width=None,
height=None,
sigma=25,
cmap=None,
alpha=1.0,
vlim=(None, None),
axes=None,
show=True,
):
"""Plot a heatmap of eyetracking gaze data.
Parameters
----------
epochs : instance of Epochs
The :class:`~mne.Epochs` object containing eyegaze channels.
calibration : instance of Calibration | None
An instance of Calibration with information about the screen size, distance,
and resolution. If ``None``, you must provide a width and height.
width : int
The width dimension of the plot canvas, only valid if eyegaze data are in
pixels. For example, if the participant screen resolution was 1920x1080, then
the width should be 1920.
height : int
The height dimension of the plot canvas, only valid if eyegaze data are in
pixels. For example, if the participant screen resolution was 1920x1080, then
the height should be 1080.
sigma : float | None
The amount of Gaussian smoothing applied to the heatmap data (standard
deviation in pixels). If ``None``, no smoothing is applied. Default is 25.
%(cmap)s
alpha : float
The opacity of the heatmap (default is 1).
%(vlim_plot_topomap)s
%(axes_plot_topomap)s
%(show)s
Returns
-------
fig : instance of Figure
The resulting figure object for the heatmap plot.
Notes
-----
.. versionadded:: 1.6
"""
from mne import BaseEpochs
from mne._fiff.pick import _picks_to_idx
from ...preprocessing.eyetracking.utils import (
_check_calibration,
get_screen_visual_angle,
)
_validate_type(epochs, BaseEpochs, "epochs")
_validate_type(alpha, "numeric", "alpha")
_validate_type(sigma, ("numeric", None), "sigma")
# Get the gaze data
pos_picks = _picks_to_idx(epochs.info, "eyegaze")
gaze_data = epochs.get_data(picks=pos_picks)
gaze_ch_loc = np.array([epochs.info["chs"][idx]["loc"] for idx in pos_picks])
x_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == -1)[0], :]
y_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == 1)[0], :]
unit = epochs.info["chs"][pos_picks[0]]["unit"] # assumes all units are the same
if x_data.shape[1] > 1: # binocular recording. Average across eyes
logger.info("Detected binocular recording. Averaging positions across eyes.")
x_data = np.nanmean(x_data, axis=1) # shape (n_epochs, n_samples)
y_data = np.nanmean(y_data, axis=1)
canvas = np.vstack((x_data.flatten(), y_data.flatten())) # shape (2, n_samples)
# Check that we have the right inputs
if calibration is not None:
if width is not None or height is not None:
raise ValueError(
"If a calibration is provided, you cannot provide a width or height"
" to plot heatmaps. Please provide only the calibration object."
)
_check_calibration(calibration)
if unit == FIFF.FIFF_UNIT_PX:
width, height = calibration["screen_resolution"]
elif unit == FIFF.FIFF_UNIT_RAD:
width, height = calibration["screen_size"]
else:
raise ValueError(
f"Invalid unit type: {unit}. gaze data Must be pixels or radians."
)
else:
if width is None or height is None:
raise ValueError(
"If no calibration is provided, you must provide a width and height"
" to plot heatmaps."
)
# Create 2D histogram
# We need to set the histogram bins & bounds, and imshow extent, based on the units
if unit == FIFF.FIFF_UNIT_PX: # pixel on screen
_range = [[0, height], [0, width]]
bins_x, bins_y = width, height
extent = [0, width, height, 0]
elif unit == FIFF.FIFF_UNIT_RAD: # radians of visual angle
if not calibration:
raise ValueError(
"If gaze data are in Radians, you must provide a"
" calibration instance to plot heatmaps."
)
width, height = get_screen_visual_angle(calibration)
x_range = [-width / 2, width / 2]
y_range = [-height / 2, height / 2]
_range = [y_range, x_range]
extent = (x_range[0], x_range[1], y_range[0], y_range[1])
bins_x, bins_y = calibration["screen_resolution"]
hist, _, _ = np.histogram2d(
canvas[1, :],
canvas[0, :],
bins=(bins_y, bins_x),
range=_range,
)
# Convert density from samples to seconds
hist /= epochs.info["sfreq"]
# Smooth the heatmap
if sigma:
hist = gaussian_filter(hist, sigma=sigma)
return _plot_heatmap_array(
hist,
width=width,
height=height,
cmap=cmap,
alpha=alpha,
vmin=vlim[0],
vmax=vlim[1],
extent=extent,
axes=axes,
show=show,
)
def _plot_heatmap_array(
data,
width,
height,
*,
cmap=None,
alpha=None,
vmin=None,
vmax=None,
extent=None,
axes=None,
show=True,
):
"""Plot a heatmap of eyetracking gaze data from a numpy array."""
import matplotlib.pyplot as plt
# Prepare axes
if axes is not None:
from matplotlib.axes import Axes
_validate_type(axes, Axes, "axes")
ax = axes
fig = ax.get_figure()
else:
fig, ax = plt.subplots(constrained_layout=True)
ax.set_title("Gaze heatmap")
ax.set_xlabel("X position")
ax.set_ylabel("Y position")
# Prepare the heatmap
alphas = 1 if alpha is None else alpha
vmin = np.nanmin(data) if vmin is None else vmin
vmax = np.nanmax(data) if vmax is None else vmax
if extent is None:
extent = [0, width, height, 0]
# Plot heatmap
im = ax.imshow(
data,
aspect="equal",
cmap=cmap,
alpha=alphas,
extent=extent,
origin="upper",
vmin=vmin,
vmax=vmax,
)
# Prepare the colorbar
fig.colorbar(im, ax=ax, shrink=0.6, label="Dwell time (seconds)")
plt_show(show)
return fig

1460
mne/viz/ica.py Normal file

File diff suppressed because it is too large Load Diff

1645
mne/viz/misc.py Normal file

File diff suppressed because it is too large Load Diff

135
mne/viz/montage.py Normal file
View File

@@ -0,0 +1,135 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
"""Functions to plot EEG sensor montages or digitizer montages."""
from copy import deepcopy
import numpy as np
from scipy.spatial.distance import cdist
from .._fiff._digitization import _get_fid_coords
from .._fiff.meas_info import create_info
from ..utils import _check_option, _validate_type, logger, verbose
from .utils import plot_sensors
@verbose
def plot_montage(
montage,
*,
scale=None,
scale_factor=None,
show_names=True,
kind="topomap",
show=True,
sphere=None,
axes=None,
verbose=None,
):
"""Plot a montage.
Parameters
----------
montage : instance of DigMontage
The montage to visualize.
scale : float
Determines the scale of the channel points and labels; values < 1 will scale
down, whereas values > 1 will scale up. Default to None, which implies 1.
scale_factor : float
Determines the size of the points. Deprecated, use scale instead.
show_names : bool | list
Whether to display all channel names. If a list, only the channel
names in the list are shown. Defaults to True.
kind : str
Whether to plot the montage as '3d' or 'topomap' (default).
show : bool
Show figure if True.
%(sphere_topomap_auto)s
%(axes_montage)s
.. versionadded:: 1.4
%(verbose)s
Returns
-------
fig : instance of matplotlib.figure.Figure
The figure object.
"""
import matplotlib.pyplot as plt
from ..channels import DigMontage, make_dig_montage
if scale_factor is not None:
msg = "scale_factor has been deprecated and will be removed. Use scale instead."
if scale is not None:
raise ValueError(
" ".join(["scale and scale_factor cannot be used together.", msg])
)
logger.info(msg)
if scale is None:
scale = 1
_check_option("kind", kind, ["topomap", "3d"])
_validate_type(montage, DigMontage, item_name="montage")
ch_names = montage.ch_names
title = None
if len(ch_names) == 0:
raise RuntimeError("No valid channel positions found.")
pos = np.array(list(montage._get_ch_pos().values()))
dists = cdist(pos, pos)
# only consider upper triangular part by setting the rest to np.nan
dists[np.tril_indices(dists.shape[0])] = np.nan
dupes = np.argwhere(np.isclose(dists, 0))
if dupes.any():
montage = deepcopy(montage)
n_chans = pos.shape[0]
n_dupes = dupes.shape[0]
idx = np.setdiff1d(np.arange(len(pos)), dupes[:, 1]).tolist()
logger.info(f"{n_dupes} duplicate electrode labels found:")
logger.info(", ".join([ch_names[d[0]] + "/" + ch_names[d[1]] for d in dupes]))
logger.info(f"Plotting {n_chans - n_dupes} unique labels.")
ch_names = [ch_names[i] for i in idx]
ch_pos = dict(zip(ch_names, pos[idx, :]))
# XXX: this might cause trouble if montage was originally in head
fid, _ = _get_fid_coords(montage.dig)
montage = make_dig_montage(ch_pos=ch_pos, **fid)
info = create_info(ch_names, sfreq=256, ch_types="eeg")
info.set_montage(montage, on_missing="ignore")
fig = plot_sensors(
info,
kind=kind,
show_names=show_names,
show=show,
title=title,
sphere=sphere,
axes=axes,
)
if scale_factor is not None:
# scale points
collection = fig.axes[0].collections[0]
collection.set_sizes([scale_factor])
elif scale is not None:
# scale points
collection = fig.axes[0].collections[0]
collection.set_sizes([scale * 10])
# scale labels
labels = fig.findobj(match=plt.Text)
x_label, y_label = fig.axes[0].xaxis.label, fig.axes[0].yaxis.label
z_label = fig.axes[0].zaxis.label if kind == "3d" else None
tick_labels = fig.axes[0].get_xticklabels() + fig.axes[0].get_yticklabels()
if kind == "3d":
tick_labels += fig.axes[0].get_zticklabels()
for label in labels:
if label not in [x_label, y_label, z_label] + tick_labels:
label.set_fontsize(label.get_fontsize() * scale)
return fig

635
mne/viz/raw.py Normal file
View File

@@ -0,0 +1,635 @@
"""Functions to plot raw M/EEG data."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from collections import OrderedDict
import numpy as np
from .._fiff.pick import _picks_to_idx, pick_channels, pick_types
from ..defaults import _handle_default
from ..filter import create_filter
from ..utils import _check_option, _get_stim_channel, _validate_type, legacy, verbose
from ..utils.spectrum import _split_psd_kwargs
from .utils import (
_check_cov,
_compute_scalings,
_get_channel_plotting_order,
_handle_decim,
_handle_precompute,
_make_event_color_dict,
_shorten_path_from_middle,
)
_RAW_CLIP_DEF = 1.5
@verbose
def plot_raw(
raw,
events=None,
duration=10.0,
start=0.0,
n_channels=20,
bgcolor="w",
color=None,
bad_color="lightgray",
event_color="cyan",
scalings=None,
remove_dc=True,
order=None,
show_options=False,
title=None,
show=True,
block=False,
highpass=None,
lowpass=None,
filtorder=4,
clipping=_RAW_CLIP_DEF,
show_first_samp=False,
proj=True,
group_by="type",
butterfly=False,
decim="auto",
noise_cov=None,
event_id=None,
show_scrollbars=True,
show_scalebars=True,
time_format="float",
precompute=None,
use_opengl=None,
picks=None,
*,
theme=None,
overview_mode=None,
splash=True,
verbose=None,
):
"""Plot raw data.
Parameters
----------
raw : instance of Raw
The raw data to plot.
events : array | None
Events to show with vertical bars.
duration : float
Time window (s) to plot. The lesser of this value and the duration
of the raw file will be used.
start : float
Initial time to show (can be changed dynamically once plotted). If
show_first_samp is True, then it is taken relative to
``raw.first_samp``.
n_channels : int
Number of channels to plot at once. Defaults to 20. The lesser of
``n_channels`` and ``len(raw.ch_names)`` will be shown.
Has no effect if ``order`` is 'position', 'selection' or 'butterfly'.
bgcolor : color object
Color of the background.
color : dict | color object | None
Color for the data traces. If None, defaults to::
dict(mag='darkblue', grad='b', eeg='k', eog='k', ecg='m',
emg='k', ref_meg='steelblue', misc='k', stim='k',
resp='k', chpi='k')
bad_color : color object
Color to make bad channels.
%(event_color)s
Defaults to ``'cyan'``.
%(scalings)s
remove_dc : bool
If True remove DC component when plotting data.
order : array of int | None
Order in which to plot data. If the array is shorter than the number of
channels, only the given channels are plotted. If None (default), all
channels are plotted. If ``group_by`` is ``'position'`` or
``'selection'``, the ``order`` parameter is used only for selecting the
channels to be plotted.
show_options : bool
If True, a dialog for options related to projection is shown.
title : str | None
The title of the window. If None, and either the filename of the
raw object or '<unknown>' will be displayed as title.
show : bool
Show figure if True.
block : bool
Whether to halt program execution until the figure is closed.
Useful for setting bad channels on the fly by clicking on a line.
May not work on all systems / platforms.
(Only Qt) If you run from a script, this needs to
be ``True`` or a Qt-eventloop needs to be started somewhere
else in the script (e.g. if you want to implement the browser
inside another Qt-Application).
highpass : float | None
Highpass to apply when displaying data.
lowpass : float | None
Lowpass to apply when displaying data.
If highpass > lowpass, a bandstop rather than bandpass filter
will be applied.
filtorder : int
Filtering order. 0 will use FIR filtering with MNE defaults.
Other values will construct an IIR filter of the given order
and apply it with :func:`~scipy.signal.filtfilt` (making the effective
order twice ``filtorder``). Filtering may produce some edge artifacts
(at the left and right edges) of the signals during display.
.. versionchanged:: 0.18
Support for ``filtorder=0`` to use FIR filtering.
clipping : str | float | None
If None, channels are allowed to exceed their designated bounds in
the plot. If "clamp", then values are clamped to the appropriate
range for display, creating step-like artifacts. If "transparent",
then excessive values are not shown, creating gaps in the traces.
If float, clipping occurs for values beyond the ``clipping`` multiple
of their dedicated range, so ``clipping=1.`` is an alias for
``clipping='transparent'``.
.. versionchanged:: 0.21
Support for float, and default changed from None to 1.5.
show_first_samp : bool
If True, show time axis relative to the ``raw.first_samp``.
proj : bool
Whether to apply projectors prior to plotting (default is ``True``).
Individual projectors can be enabled/disabled interactively (see
Notes). This argument only affects the plot; use ``raw.apply_proj()``
to modify the data stored in the Raw object.
%(group_by_browse)s
butterfly : bool
Whether to start in butterfly mode. Defaults to False.
decim : int | 'auto'
Amount to decimate the data during display for speed purposes.
You should only decimate if the data are sufficiently low-passed,
otherwise aliasing can occur. The 'auto' mode (default) uses
the decimation that results in a sampling rate least three times
larger than ``min(info['lowpass'], lowpass)`` (e.g., a 40 Hz lowpass
will result in at least a 120 Hz displayed sample rate).
noise_cov : instance of Covariance | str | None
Noise covariance used to whiten the data while plotting.
Whitened data channels are scaled by ``scalings['whitened']``,
and their channel names are shown in italic.
Can be a string to load a covariance from disk.
See also :meth:`mne.Evoked.plot_white` for additional inspection
of noise covariance properties when whitening evoked data.
For data processed with SSS, the effective dependence between
magnetometers and gradiometers may introduce differences in scaling,
consider using :meth:`mne.Evoked.plot_white`.
.. versionadded:: 0.16.0
event_id : dict | None
Event IDs used to show at event markers (default None shows
the event numbers).
.. versionadded:: 0.16.0
%(show_scrollbars)s
%(show_scalebars)s
.. versionadded:: 0.20.0
%(time_format)s
%(precompute)s
%(use_opengl)s
%(picks_all)s
%(theme_pg)s
.. versionadded:: 1.0
%(overview_mode)s
.. versionadded:: 1.1
%(splash)s
.. versionadded:: 1.6
%(verbose)s
Returns
-------
%(browser)s
Notes
-----
The arrow keys (up/down/left/right) can typically be used to navigate
between channels and time ranges, but this depends on the backend
matplotlib is configured to use (e.g., mpl.use('TkAgg') should work). The
left/right arrows will scroll by 25%% of ``duration``, whereas
shift+left/shift+right will scroll by 100%% of ``duration``. The scaling
can be adjusted with - and + (or =) keys. The viewport dimensions can be
adjusted with page up/page down and home/end keys. Full screen mode can be
toggled with the F11 key, and scrollbars can be hidden/shown by pressing
'z'. Right-click a channel label to view its location. To mark or un-mark a
channel as bad, click on a channel label or a channel trace. The changes
will be reflected immediately in the raw object's ``raw.info['bads']``
entry.
If projectors are present, a button labelled "Prj" in the lower right
corner of the plot window opens a secondary control window, which allows
enabling/disabling specific projectors individually. This provides a means
of interactively observing how each projector would affect the raw data if
it were applied.
Annotation mode is toggled by pressing 'a', butterfly mode by pressing
'b', and whitening mode (when ``noise_cov is not None``) by pressing 'w'.
By default, the channel means are removed when ``remove_dc`` is set to
``True``. This flag can be toggled by pressing 'd'.
%(notes_2d_backend)s
"""
from ..annotations import _annotations_starts_stops
from ..io import BaseRaw
from ._figure import _get_browser
info = raw.info.copy()
sfreq = info["sfreq"]
projs = info["projs"]
# this will be an attr for which projectors are currently "on" in the plot
projs_on = np.full_like(projs, proj, dtype=bool)
# disable projs in info if user doesn't want to see them right away
if not proj:
with info._unlock():
info["projs"] = list()
# handle defaults / check arg validity
color = _handle_default("color", color)
scalings = _compute_scalings(scalings, raw, remove_dc=remove_dc, duration=duration)
if scalings["whitened"] == "auto":
scalings["whitened"] = 1.0
_validate_type(raw, BaseRaw, "raw", "Raw")
decim, picks_data = _handle_decim(info, decim, lowpass)
noise_cov = _check_cov(noise_cov, info)
units = _handle_default("units", None)
unit_scalings = _handle_default("scalings", None)
_check_option("group_by", group_by, ("selection", "position", "original", "type"))
# clipping
_validate_type(clipping, (None, "numeric", str), "clipping")
if isinstance(clipping, str):
_check_option(
"clipping", clipping, ("clamp", "transparent"), extra="when a string"
)
clipping = 1.0 if clipping == "transparent" else clipping
elif clipping is not None:
clipping = float(clipping)
# be forgiving if user asks for too much time
duration = min(raw.times[-1], float(duration))
# determine IIR filtering parameters
if highpass is not None and highpass <= 0:
raise ValueError(f"highpass must be > 0, got {highpass}")
if highpass is None and lowpass is None:
ba = filt_bounds = None
else:
filtorder = int(filtorder)
if filtorder == 0:
method = "fir"
iir_params = None
else:
method = "iir"
iir_params = dict(order=filtorder, output="sos", ftype="butter")
ba = create_filter(
np.zeros((1, int(round(duration * sfreq)))),
sfreq,
highpass,
lowpass,
method=method,
iir_params=iir_params,
)
filt_bounds = _annotations_starts_stops(
raw, ("edge", "bad_acq_skip"), invert=True
)
# compute event times in seconds
if events is not None:
event_times = (events[:, 0] - raw.first_samp).astype(float)
event_times /= sfreq
event_nums = events[:, 2]
else:
event_times = event_nums = None
# determine trace order
ch_names = np.array(raw.ch_names)
ch_types = np.array(raw.get_channel_types())
picks = _picks_to_idx(info, picks, none="all", exclude=())
order = _get_channel_plotting_order(order, ch_types, picks=picks)
n_channels = min(info["nchan"], n_channels, len(order))
# adjust order based on channel selection, if needed
selections = None
if group_by in ("selection", "position"):
selections = _setup_channel_selections(raw, group_by, order)
order = np.concatenate(list(selections.values()))
default_selection = list(selections)[0]
n_channels = len(selections[default_selection])
assert isinstance(order, np.ndarray)
assert order.dtype.kind == "i"
if order.size == 0:
raise RuntimeError("No channels found to plot")
# handle event colors
event_color_dict = _make_event_color_dict(event_color, events, event_id)
# handle first_samp
first_time = raw._first_time if show_first_samp else 0
start += first_time
event_id_rev = {v: k for k, v in (event_id or {}).items()}
# generate window title; allow instances without a filename (e.g., ICA)
if title is None:
title = "<unknown>"
fnames = list(tuple(raw.filenames)) # get a list of a copy of the filenames
if len(fnames):
title = fnames.pop(0)
extra = f" ... (+ {len(fnames)} more)" if len(fnames) else ""
title = f"{title}{extra}"
if len(title) > 60:
title = _shorten_path_from_middle(title)
elif not isinstance(title, str):
raise TypeError(f"title must be None or a string, got a {type(title)}")
# gather parameters and initialize figure
_validate_type(use_opengl, (bool, None), "use_opengl")
precompute = _handle_precompute(precompute)
params = dict(
inst=raw,
info=info,
# channels and channel order
ch_names=ch_names,
ch_types=ch_types,
ch_order=order,
picks=order[:n_channels],
n_channels=n_channels,
picks_data=picks_data,
group_by=group_by,
ch_selections=selections,
# time
t_start=start,
duration=duration,
n_times=raw.n_times,
first_time=first_time,
time_format=time_format,
decim=decim,
# events
event_color_dict=event_color_dict,
event_times=event_times,
event_nums=event_nums,
event_id_rev=event_id_rev,
# preprocessing
projs=projs,
projs_on=projs_on,
apply_proj=proj,
remove_dc=remove_dc,
filter_coefs=ba,
filter_bounds=filt_bounds,
noise_cov=noise_cov,
# scalings
scalings=scalings,
units=units,
unit_scalings=unit_scalings,
# colors
ch_color_bad=bad_color,
ch_color_dict=color,
# display
butterfly=butterfly,
clipping=clipping,
scrollbars_visible=show_scrollbars,
scalebars_visible=show_scalebars,
window_title=title,
bgcolor=bgcolor,
# Qt-specific
precompute=precompute,
use_opengl=use_opengl,
theme=theme,
overview_mode=overview_mode,
splash=splash,
)
fig = _get_browser(show=show, block=block, **params)
return fig
@legacy(alt="Raw.compute_psd().plot()")
@verbose
def plot_raw_psd(
raw,
fmin=0,
fmax=np.inf,
tmin=None,
tmax=None,
proj=False,
n_fft=None,
n_overlap=0,
reject_by_annotation=True,
picks=None,
ax=None,
color="black",
xscale="linear",
area_mode="std",
area_alpha=0.33,
dB=True,
estimate="power",
show=True,
n_jobs=None,
average=False,
line_alpha=None,
spatial_colors=True,
sphere=None,
window="hamming",
exclude="bads",
verbose=None,
):
"""%(plot_psd_doc)s.
Parameters
----------
raw : instance of Raw
The raw object.
%(fmin_fmax_psd)s
%(tmin_tmax_psd)s
%(proj_psd)s
n_fft : int | None
Number of points to use in Welch FFT calculations. Default is ``None``,
which uses the minimum of 2048 and the number of time points.
n_overlap : int
The number of points of overlap between blocks. The default value
is 0 (no overlap).
%(reject_by_annotation_psd)s
%(picks_good_data_noref)s
%(ax_plot_psd)s
%(color_plot_psd)s
%(xscale_plot_psd)s
%(area_mode_plot_psd)s
%(area_alpha_plot_psd)s
%(dB_plot_psd)s
%(estimate_plot_psd)s
%(show)s
%(n_jobs)s
%(average_plot_psd)s
%(line_alpha_plot_psd)s
%(spatial_colors_psd)s
%(sphere_topomap_auto)s
%(window_psd)s
.. versionadded:: 0.22.0
exclude : list of str | 'bads'
Channels names to exclude from being shown. If 'bads', the bad channels
are excluded. Pass an empty list to plot all channels (including
channels marked "bad", if any).
.. versionadded:: 0.24.0
%(verbose)s
Returns
-------
fig : instance of Figure
Figure with frequency spectra of the data channels.
Notes
-----
%(notes_plot_*_psd_func)s
"""
from ..time_frequency import Spectrum
init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot)
return raw.compute_psd(**init_kw).plot(**plot_kw)
@legacy(alt="Raw.compute_psd().plot_topo()")
@verbose
def plot_raw_psd_topo(
raw,
tmin=0.0,
tmax=None,
fmin=0.0,
fmax=100.0,
proj=False,
*,
n_fft=2048,
n_overlap=0,
dB=True,
layout=None,
color="w",
fig_facecolor="k",
axis_facecolor="k",
axes=None,
block=False,
show=True,
n_jobs=None,
verbose=None,
):
"""Plot power spectral density, separately for each channel.
Parameters
----------
raw : instance of io.Raw
The raw instance to use.
%(tmin_tmax_psd)s
%(fmin_fmax_psd_topo)s
%(proj_psd)s
n_fft : int
Number of points to use in Welch FFT calculations. Defaults to 2048.
n_overlap : int
The number of points of overlap between blocks. Defaults to 0
(no overlap).
%(dB_spectrum_plot_topo)s
layout : instance of Layout | None
Layout instance specifying sensor positions (does not need to be
specified for Neuromag data). If ``None`` (default), the layout is
inferred from the data.
color : str | tuple
A matplotlib-compatible color to use for the curves. Defaults to white.
fig_facecolor : str | tuple
A matplotlib-compatible color to use for the figure background.
Defaults to black.
axis_facecolor : str | tuple
A matplotlib-compatible color to use for the axis background.
Defaults to black.
%(axes_spectrum_plot_topo)s
block : bool
Whether to halt program execution until the figure is closed.
May not work on all systems / platforms. Defaults to False.
%(show)s
%(n_jobs)s
%(verbose)s
Returns
-------
fig : instance of matplotlib.figure.Figure
Figure distributing one image per channel across sensor topography.
"""
from ..time_frequency import Spectrum
init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot_topo)
return raw.compute_psd(**init_kw).plot_topo(**plot_kw)
def _setup_channel_selections(raw, kind, order):
"""Get dictionary of channel groupings."""
from ..channels import (
_EEG_SELECTIONS,
_SELECTIONS,
_divide_to_regions,
read_vectorview_selection,
)
_check_option("group_by", kind, ("position", "selection"))
if kind == "position":
selections_dict = _divide_to_regions(raw.info)
keys = _SELECTIONS[1:] # omit 'Vertex'
else: # kind == 'selection'
from ..channels.channels import _get_ch_info
(
has_vv_mag,
has_vv_grad,
*_,
has_neuromag_122_grad,
has_csd_coils,
) = _get_ch_info(raw.info)
if not (has_vv_grad or has_vv_mag or has_neuromag_122_grad):
raise ValueError(
"order='selection' only works for Neuromag "
"data. Use order='position' instead."
)
selections_dict = OrderedDict()
# get stim channel (if any)
stim_ch = _get_stim_channel(None, raw.info, raise_error=False)
stim_ch = stim_ch if len(stim_ch) else [""]
stim_ch = pick_channels(raw.ch_names, stim_ch, ordered=False)
# loop over regions
keys = np.concatenate([_SELECTIONS, _EEG_SELECTIONS])
for key in keys:
channels = read_vectorview_selection(key, info=raw.info)
picks = pick_channels(raw.ch_names, channels, ordered=False)
picks = np.intersect1d(picks, order)
if not len(picks):
continue # omit empty selections
selections_dict[key] = np.concatenate([picks, stim_ch])
# add misc channels
misc = pick_types(
raw.info,
meg=False,
eeg=False,
stim=True,
eog=True,
ecg=True,
emg=True,
ref_meg=False,
misc=True,
resp=True,
chpi=True,
exci=True,
ias=True,
syst=True,
seeg=False,
bio=True,
ecog=False,
fnirs=False,
dbs=False,
temperature=True,
gsr=True,
exclude=(),
)
if len(misc) and np.isin(misc, order).any():
selections_dict["Misc"] = misc
return selections_dict

1309
mne/viz/topo.py Normal file

File diff suppressed because it is too large Load Diff

4102
mne/viz/topomap.py Normal file

File diff suppressed because it is too large Load Diff

480
mne/viz/ui_events.py Normal file
View File

@@ -0,0 +1,480 @@
"""
Event API for inter-figure communication.
The event API allows figures to communicate with each other, such that a change
in one figure can trigger a change in another figure. For example, moving the
time cursor in one plot can update the current time in another plot. Another
scenario is two drawing routines drawing into the same window, using events to
stay in-sync.
Authors: Marijn van Vliet <w.m.vanvliet@gmail.com>
"""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from __future__ import annotations # only needed for Python ≤ 3.9
import contextlib
import re
import weakref
from dataclasses import dataclass
from matplotlib.colors import Colormap
from ..utils import _validate_type, fill_doc, logger, verbose, warn
# Global dict {fig: channel} containing all currently active event channels.
_event_channels = weakref.WeakKeyDictionary()
# The event channels of figures can be linked together. This dict keeps track
# of these links. Links are bi-directional, so if {fig1: fig2} exists, then so
# must {fig2: fig1}.
_event_channel_links = weakref.WeakKeyDictionary()
# Event channels that are temporarily disabled by the disable_ui_events context
# manager.
_disabled_event_channels = weakref.WeakSet()
# Regex pattern used when converting CamelCase to snake_case.
# Detects all capital letters that are not at the beginning of a word.
_camel_to_snake = re.compile(r"(?<!^)(?=[A-Z])")
# List of events
@fill_doc
class UIEvent:
"""Abstract base class for all events.
Attributes
----------
%(ui_event_name_source)s
"""
source = None
@property
def name(self):
"""The name of the event, which is the class name in snake case."""
return _camel_to_snake.sub("_", self.__class__.__name__).lower()
@fill_doc
class FigureClosing(UIEvent):
"""Indicates that the user has requested to close a figure.
Attributes
----------
%(ui_event_name_source)s
"""
pass
@dataclass
@fill_doc
class TimeChange(UIEvent):
"""Indicates that the user has selected a time.
Parameters
----------
time : float
The new time in seconds.
Attributes
----------
%(ui_event_name_source)s
time : float
The new time in seconds.
"""
time: float
@dataclass
@fill_doc
class PlaybackSpeed(UIEvent):
"""Indicates that the user has selected a different playback speed for videos.
Parameters
----------
speed : float
The new speed in seconds per frame.
Attributes
----------
%(ui_event_name_source)s
speed : float
The new speed in seconds per frame.
"""
speed: float
@dataclass
@fill_doc
class ColormapRange(UIEvent):
"""Indicates that the user has updated the bounds of the colormap.
Parameters
----------
kind : str
Kind of colormap being updated. The Notes section of the drawing
routine publishing this event should mention the possible kinds.
ch_type : str
Type of sensor the data originates from.
%(fmin_fmid_fmax)s
%(alpha)s
cmap : str
The colormap to use. Either string or matplotlib.colors.Colormap
instance.
Attributes
----------
kind : str
Kind of colormap being updated. The Notes section of the drawing
routine publishing this event should mention the possible kinds.
ch_type : str
Type of sensor the data originates from.
unit : str
The unit of the values.
%(ui_event_name_source)s
%(fmin_fmid_fmax)s
%(alpha)s
cmap : str
The colormap to use. Either string or matplotlib.colors.Colormap
instance.
"""
kind: str
ch_type: str | None = None
fmin: float | None = None
fmid: float | None = None
fmax: float | None = None
alpha: bool | None = None
cmap: Colormap | str | None = None
@dataclass
@fill_doc
class VertexSelect(UIEvent):
"""Indicates that the user has selected a vertex.
Parameters
----------
hemi : str
The hemisphere the vertex was selected on.
Can be ``"lh"``, ``"rh"``, or ``"vol"``.
vertex_id : int
The vertex number (in the high resolution mesh) that was selected.
Attributes
----------
%(ui_event_name_source)s
hemi : str
The hemisphere the vertex was selected on.
Can be ``"lh"``, ``"rh"``, or ``"vol"``.
vertex_id : int
The vertex number (in the high resolution mesh) that was selected.
"""
hemi: str
vertex_id: int
@dataclass
@fill_doc
class Contours(UIEvent):
"""Indicates that the user has changed the contour lines.
Parameters
----------
kind : str
The kind of contours lines being changed. The Notes section of the
drawing routine publishing this event should mention the possible
kinds.
contours : list of float
The new values at which contour lines need to be drawn.
Attributes
----------
%(ui_event_name_source)s
kind : str
The kind of contours lines being changed. The Notes section of the
drawing routine publishing this event should mention the possible
kinds.
contours : list of float
The new values at which contour lines need to be drawn.
"""
kind: str
contours: list[str]
def _get_event_channel(fig):
"""Get the event channel associated with a figure.
If the event channel doesn't exist yet, it gets created and added to the
global ``_event_channels`` dict.
Parameters
----------
fig : matplotlib.figure.Figure | Figure3D
The figure to get the event channel for.
Returns
-------
channel : dict[event -> list]
The event channel. An event channel is a list mapping string event
names to a list of callback representing all subscribers to the
channel.
"""
import matplotlib
from ._brain import Brain
from .evoked_field import EvokedField
# Create the event channel if it doesn't exist yet
if fig not in _event_channels:
# The channel itself is a dict mapping string event names to a list of
# subscribers. No subscribers yet for this new event channel.
_event_channels[fig] = dict()
weakfig = weakref.ref(fig)
# When the figure is closed, its associated event channel should be
# deleted. This is a good time to set this up.
def delete_event_channel(event=None, *, weakfig=weakfig):
"""Delete the event channel (callback function)."""
fig = weakfig()
if fig is None:
return
publish(fig, event=FigureClosing()) # Notify subscribers of imminent close
logger.debug(f"unlink(({fig})")
unlink(fig) # Remove channel from the _event_channel_links dict
if fig in _event_channels:
logger.debug(f" del _event_channels[{fig}]")
del _event_channels[fig]
if fig in _disabled_event_channels:
logger.debug(f" _disabled_event_channels.remove({fig})")
_disabled_event_channels.remove(fig)
# Hook up the above callback function to the close event of the figure
# window. How this is done exactly depends on the various figure types
# MNE-Python has.
_validate_type(fig, (matplotlib.figure.Figure, Brain, EvokedField), "fig")
if isinstance(fig, matplotlib.figure.Figure):
fig.canvas.mpl_connect("close_event", delete_event_channel)
else:
assert hasattr(fig, "_renderer") # figures like Brain, EvokedField, etc.
fig._renderer._window_close_connect(delete_event_channel, after=False)
# Now the event channel exists for sure.
return _event_channels[fig]
@verbose
def publish(fig, event, *, verbose=None):
"""Publish an event to all subscribers of the figure's channel.
The figure's event channel and all linked event channels are searched for
subscribers to the given event. Each subscriber had provided a callback
function when subscribing, so we call that.
Parameters
----------
fig : matplotlib.figure.Figure | Figure3D
The figure that publishes the event.
event : UIEvent
Event to publish.
%(verbose)s
"""
if fig in _disabled_event_channels:
return
# Compile a list of all event channels that the event should be published
# on.
channels = [_get_event_channel(fig)]
links = _event_channel_links.get(fig, None)
if links is not None:
for linked_fig, (include_events, exclude_events) in links.items():
if (include_events is None or event.name in include_events) and (
exclude_events is None or event.name not in exclude_events
):
channels.append(_get_event_channel(linked_fig))
# Publish the event by calling the registered callback functions.
event.source = fig
logger.debug(f"Publishing {event} on channel {fig}")
for channel in channels:
if event.name not in channel:
channel[event.name] = set()
for callback in channel[event.name]:
callback(event=event)
@verbose
def subscribe(fig, event_name, callback, *, verbose=None):
"""Subscribe to an event on a figure's event channel.
Parameters
----------
fig : matplotlib.figure.Figure | Figure3D
The figure of which event channel to subscribe.
event_name : str
The name of the event to listen for.
callback : callable
The function that should be called whenever the event is published.
%(verbose)s
"""
channel = _get_event_channel(fig)
logger.debug(f"Subscribing to channel {channel}")
if event_name not in channel:
channel[event_name] = set()
channel[event_name].add(callback)
@verbose
def unsubscribe(fig, event_names, callback=None, *, verbose=None):
"""Unsubscribe from an event on a figure's event channel.
Parameters
----------
fig : matplotlib.figure.Figure | Figure3D
The figure of which event channel to unsubscribe from.
event_names : str | list of str
Select which events to stop subscribing to. Can be a single string
event name, a list of event names or ``"all"`` which will unsubscribe
from all events.
callback : callable | None
The callback function that should be unsubscribed, leaving all other
callback functions that may be subscribed untouched. By default
(``None``) all callback functions are unsubscribed from the event.
%(verbose)s
"""
channel = _get_event_channel(fig)
# Determine which events to unsubscribe for.
if event_names == "all":
if callback is None:
event_names = list(channel.keys())
else:
event_names = list(k for k, v in channel.items() if callback in v)
elif isinstance(event_names, str):
event_names = [event_names]
for event_name in event_names:
if event_name not in channel:
warn(
f'Cannot unsubscribe from event "{event_name}" as we have never '
"subscribed to it."
)
continue
if callback is None:
del channel[event_name]
else:
# Unsubscribe specific callback function.
subscribers = channel[event_name]
if callback in subscribers:
subscribers.remove(callback)
else:
warn(
f'Cannot unsubscribe {callback} from event "{event_name}" '
"as it was never subscribed to it."
)
if len(subscribers) == 0:
del channel[event_name] # keep things tidy
@verbose
def link(*figs, include_events=None, exclude_events=None, verbose=None):
"""Link the event channels of two figures together.
When event channels are linked, any events that are published on one
channel are simultaneously published on the other channel. Links are
bi-directional.
Parameters
----------
*figs : tuple of matplotlib.figure.Figure | tuple of Figure3D
The figures whose event channel will be linked.
include_events : list of str | None
Select which events to publish across figures. By default (``None``),
both figures will receive all of each other's events. Passing a list of
event names will restrict the events being shared across the figures to
only the given ones.
exclude_events : list of str | None
Select which events not to publish across figures. By default (``None``),
no events are excluded.
%(verbose)s
"""
if include_events is not None:
include_events = set(include_events)
if exclude_events is not None:
exclude_events = set(exclude_events)
# Make sure the event channels of the figures are setup properly.
for fig in figs:
_get_event_channel(fig)
if fig not in _event_channel_links:
_event_channel_links[fig] = weakref.WeakKeyDictionary()
# Link the event channels
for fig1 in figs:
for fig2 in figs:
if fig1 is not fig2:
_event_channel_links[fig1][fig2] = (include_events, exclude_events)
@verbose
def unlink(fig, *, verbose=None):
"""Remove all links involving the event channel of the given figure.
Parameters
----------
fig : matplotlib.figure.Figure | Figure3D
The figure whose event channel should be unlinked from all other event
channels.
%(verbose)s
"""
linked_figs = _event_channel_links.get(fig)
if linked_figs is not None:
for linked_fig in linked_figs.keys():
del _event_channel_links[linked_fig][fig]
if len(_event_channel_links[linked_fig]) == 0:
del _event_channel_links[linked_fig]
if fig in _event_channel_links: # need to check again because of weak refs
del _event_channel_links[fig]
@contextlib.contextmanager
def disable_ui_events(fig):
"""Temporarily disable generation of UI events. Use as context manager.
Parameters
----------
fig : matplotlib.figure.Figure | Figure3D
The figure whose UI event generation should be temporarily disabled.
"""
_disabled_event_channels.add(fig)
try:
yield
finally:
_disabled_event_channels.remove(fig)
def _cleanup_agg():
"""Call close_event for Agg canvases to help our doc build."""
import matplotlib.backends.backend_agg
import matplotlib.figure
for key in list(_event_channels): # we might remove keys as we go
if isinstance(key, matplotlib.figure.Figure):
canvas = key.canvas
if isinstance(canvas, matplotlib.backends.backend_agg.FigureCanvasAgg):
for cb in key.canvas.callbacks.callbacks["close_event"].values():
cb = cb() # get the true ref
if cb is not None:
cb()

2814
mne/viz/utils.py Normal file

File diff suppressed because it is too large Load Diff