initial commit
mne/viz/_3d.py (new file, 4282 lines; diff suppressed because it is too large)

mne/viz/_3d_overlay.py (new file, 184 lines)
@@ -0,0 +1,184 @@
"""Classes to handle overlapping surfaces."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from collections import OrderedDict

import numpy as np

from ..utils import logger


class _Overlay:
    def __init__(self, scalars, colormap, rng, opacity, name):
        self._scalars = scalars
        self._colormap = colormap
        assert rng is not None
        self._rng = rng
        self._opacity = opacity
        self._name = name

    def to_colors(self):
        from matplotlib.colors import Colormap, ListedColormap

        from ._3d import _get_cmap

        if isinstance(self._colormap, str):
            cmap = _get_cmap(self._colormap)
        elif isinstance(self._colormap, Colormap):
            cmap = self._colormap
        else:
            cmap = ListedColormap(
                self._colormap / 255.0, name=str(type(self._colormap))
            )
        logger.debug(
            f"Color mapping {repr(self._name)} with {cmap.name} "
            f"colormap and range {self._rng}"
        )

        rng = self._rng
        assert rng is not None
        scalars = self._norm(rng)

        colors = cmap(scalars)
        if self._opacity is not None:
            colors[:, 3] *= self._opacity
        return colors

    def _norm(self, rng):
        if rng[0] == rng[1]:
            factor = 1 if rng[0] == 0 else 1e-6 * rng[0]
        else:
            factor = rng[1] - rng[0]
        return (self._scalars - rng[0]) / factor


class _LayeredMesh:
    def __init__(self, renderer, vertices, triangles, normals):
        self._renderer = renderer
        self._vertices = vertices
        self._triangles = triangles
        self._normals = normals

        self._polydata = None
        self._actor = None
        self._is_mapped = False

        self._current_colors = None
        self._cached_colors = None
        self._overlays = OrderedDict()

        self._default_scalars = np.ones(vertices.shape)
        self._default_scalars_name = "Data"

    def map(self):
        kwargs = {
            "color": None,
            "pickable": True,
            "rgba": True,
        }
        mesh_data = self._renderer.mesh(
            x=self._vertices[:, 0],
            y=self._vertices[:, 1],
            z=self._vertices[:, 2],
            triangles=self._triangles,
            normals=self._normals,
            scalars=self._default_scalars,
            **kwargs,
        )
        self._actor, self._polydata = mesh_data
        self._is_mapped = True

    def _compute_over(self, B, A):
        assert A.ndim == B.ndim == 2
        assert A.shape[1] == B.shape[1] == 4
        A_w = A[:, 3:]  # * 1
        B_w = B[:, 3:] * (1 - A_w)
        C = A.copy()
        C[:, :3] *= A_w
        C[:, :3] += B[:, :3] * B_w
        C[:, 3:] += B_w
        C[:, :3] /= C[:, 3:]
        return np.clip(C, 0, 1, out=C)

    def _compose_overlays(self):
        B = cache = None
        for overlay in self._overlays.values():
            A = overlay.to_colors()
            if B is None:
                B = A
            else:
                cache = B
                B = self._compute_over(cache, A)
        return B, cache

    def add_overlay(self, scalars, colormap, rng, opacity, name):
        overlay = _Overlay(
            scalars=scalars,
            colormap=colormap,
            rng=rng,
            opacity=opacity,
            name=name,
        )
        self._overlays[name] = overlay
        colors = overlay.to_colors()
        if self._current_colors is None:
            self._current_colors = colors
        else:
            # save previous colors to cache
            self._cached_colors = self._current_colors
            self._current_colors = self._compute_over(self._cached_colors, colors)

        # apply the texture
        self._apply()

    def remove_overlay(self, names):
        to_update = False
        if not isinstance(names, list):
            names = [names]
        for name in names:
            if name in self._overlays:
                del self._overlays[name]
                to_update = True
        if to_update:
            self.update()

    def _apply(self):
        if self._current_colors is None or self._renderer is None:
            return
        self._polydata[self._default_scalars_name] = self._current_colors

    def update(self, colors=None):
        if colors is not None and self._cached_colors is not None:
            self._current_colors = self._compute_over(self._cached_colors, colors)
        else:
            self._current_colors, self._cached_colors = self._compose_overlays()
        self._apply()

    def _clean(self):
        mapper = self._actor.GetMapper()
        mapper.SetLookupTable(None)
        self._actor.SetMapper(None)
        self._actor = None
        self._polydata = None
        self._renderer = None

    def update_overlay(self, name, scalars=None, colormap=None, opacity=None, rng=None):
        overlay = self._overlays.get(name, None)
        if overlay is None:
            return
        if scalars is not None:
            overlay._scalars = scalars
        if colormap is not None:
            overlay._colormap = colormap
        if opacity is not None:
            overlay._opacity = opacity
        if rng is not None:
            overlay._rng = rng
        # partial update: use cache if possible
        if name == list(self._overlays.keys())[-1]:
            self.update(colors=overlay.to_colors())
        else:  # full update
            self.update()
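
Note: `_LayeredMesh._compute_over` is row-wise "A over B" alpha compositing on (n, 4) RGBA arrays. A minimal standalone sketch of the same math (not part of the commit; assumes non-zero combined alpha):

import numpy as np

def over(B, A):
    # composite new layer A on top of existing layer B (RGBA in [0, 1])
    A_a = A[:, 3:]
    B_a = B[:, 3:] * (1 - A_a)
    out = A.copy()
    out[:, :3] = (A[:, :3] * A_a + B[:, :3] * B_a) / (A_a + B_a)
    out[:, 3:] = A_a + B_a
    return np.clip(out, 0, 1)

base = np.array([[1.0, 0.0, 0.0, 1.0]])   # opaque red
layer = np.array([[0.0, 0.0, 1.0, 0.5]])  # half-transparent blue
print(over(base, layer))                  # -> [[0.5, 0.0, 0.5, 1.0]]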

mne/viz/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

"""Visualization routines."""
import lazy_loader as lazy

(__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__)
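
Note: `lazy.attach_stub` wires attribute access to the adjacent `__init__.pyi` stub (next file), so submodules are only imported when first touched. A hedged usage sketch (assumes an installed mne):

import mne.viz

# Accessing a name listed in __init__.pyi triggers the real import lazily:
print(mne.viz.plot_topomap)
print("plot_topomap" in mne.viz.__all__)  # True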

mne/viz/__init__.pyi (new file, 177 lines)
@@ -0,0 +1,177 @@
__all__ = [
    "Brain",
    "ClickableImage",
    "EvokedField",
    "Figure3D",
    "_RAW_CLIP_DEF",
    "_get_plot_ch_type",
    "_get_presser",
    "_plot_sources",
    "_scraper",
    "add_background_image",
    "adjust_axes",
    "backends",
    "centers_to_edges",
    "circular_layout",
    "close_3d_figure",
    "close_all_3d_figures",
    "compare_fiff",
    "concatenate_images",
    "create_3d_figure",
    "eyetracking",
    "get_3d_backend",
    "get_brain_class",
    "get_browser_backend",
    "iter_topography",
    "link_brains",
    "mne_analyze_colormap",
    "plot_alignment",
    "plot_arrowmap",
    "plot_bem",
    "plot_brain_colorbar",
    "plot_bridged_electrodes",
    "plot_ch_adjacency",
    "plot_channel_labels_circle",
    "plot_chpi_snr",
    "plot_compare_evokeds",
    "plot_cov",
    "plot_csd",
    "plot_dipole_amplitudes",
    "plot_dipole_locations",
    "plot_drop_log",
    "plot_epochs",
    "plot_epochs_image",
    "plot_epochs_psd",
    "plot_epochs_psd_topomap",
    "plot_events",
    "plot_evoked",
    "plot_evoked_field",
    "plot_evoked_image",
    "plot_evoked_joint",
    "plot_evoked_topo",
    "plot_evoked_topomap",
    "plot_evoked_white",
    "plot_filter",
    "plot_head_positions",
    "plot_ica_components",
    "plot_ica_overlay",
    "plot_ica_properties",
    "plot_ica_scores",
    "plot_ica_sources",
    "plot_ideal_filter",
    "plot_layout",
    "plot_montage",
    "plot_projs_joint",
    "plot_projs_topomap",
    "plot_raw",
    "plot_raw_psd",
    "plot_raw_psd_topo",
    "plot_regression_weights",
    "plot_sensors",
    "plot_snr_estimate",
    "plot_source_estimates",
    "plot_source_spectrogram",
    "plot_sparse_source_estimates",
    "plot_tfr_topomap",
    "plot_topo_image_epochs",
    "plot_topomap",
    "plot_vector_source_estimates",
    "plot_volume_source_estimates",
    "set_3d_backend",
    "set_3d_options",
    "set_3d_title",
    "set_3d_view",
    "set_browser_backend",
    "snapshot_brain_montage",
    "ui_events",
    "use_3d_backend",
    "use_browser_backend",
]
from . import _scraper, backends, eyetracking, ui_events
from ._3d import (
    link_brains,
    plot_alignment,
    plot_brain_colorbar,
    plot_dipole_locations,
    plot_evoked_field,
    plot_head_positions,
    plot_source_estimates,
    plot_sparse_source_estimates,
    plot_vector_source_estimates,
    plot_volume_source_estimates,
    set_3d_options,
    snapshot_brain_montage,
)
from ._brain import Brain
from ._figure import get_browser_backend, set_browser_backend, use_browser_backend
from ._proj import plot_projs_joint
from .backends._abstract import Figure3D
from .backends.renderer import (
    close_3d_figure,
    close_all_3d_figures,
    create_3d_figure,
    get_3d_backend,
    get_brain_class,
    set_3d_backend,
    set_3d_title,
    set_3d_view,
    use_3d_backend,
)
from .circle import circular_layout, plot_channel_labels_circle
from .epochs import plot_drop_log, plot_epochs, plot_epochs_image, plot_epochs_psd
from .evoked import (
    plot_compare_evokeds,
    plot_evoked,
    plot_evoked_image,
    plot_evoked_joint,
    plot_evoked_topo,
    plot_evoked_white,
    plot_snr_estimate,
)
from .evoked_field import EvokedField
from .ica import (
    _plot_sources,
    plot_ica_overlay,
    plot_ica_properties,
    plot_ica_scores,
    plot_ica_sources,
)
from .misc import (
    _get_presser,
    adjust_axes,
    plot_bem,
    plot_chpi_snr,
    plot_cov,
    plot_csd,
    plot_dipole_amplitudes,
    plot_events,
    plot_filter,
    plot_ideal_filter,
    plot_source_spectrogram,
)
from .montage import plot_montage
from .raw import _RAW_CLIP_DEF, plot_raw, plot_raw_psd, plot_raw_psd_topo
from .topo import iter_topography, plot_topo_image_epochs
from .topomap import (
    plot_arrowmap,
    plot_bridged_electrodes,
    plot_ch_adjacency,
    plot_epochs_psd_topomap,
    plot_evoked_topomap,
    plot_ica_components,
    plot_layout,
    plot_projs_topomap,
    plot_regression_weights,
    plot_tfr_topomap,
    plot_topomap,
)
from .utils import (
    ClickableImage,
    _get_plot_ch_type,
    add_background_image,
    centers_to_edges,
    compare_fiff,
    concatenate_images,
    mne_analyze_colormap,
    plot_sensors,
)

mne/viz/_brain/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""Plot Cortex Surface."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from ._brain import Brain, _LayeredMesh
from ._scraper import _BrainScraper
from ._linkviewer import _LinkViewer

__all__ = ["Brain"]

mne/viz/_brain/_brain.py (new file, 4141 lines; diff suppressed because it is too large)

mne/viz/_brain/_linkviewer.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np

from ...utils import warn
from ..ui_events import link


class _LinkViewer:
    """Class to link multiple Brain objects."""

    def __init__(self, brains, time=True, camera=False, colorbar=True, picking=False):
        self.brains = brains
        self.leader = self.brains[0]  # select a brain as leader

        # check time infos
        times = [brain._times for brain in brains]
        if time and not all(np.allclose(x, times[0]) for x in times):
            warn("stc.times do not match, not linking time")
            time = False

        if camera:
            self.link_cameras()

        events_to_link = []
        if time:
            events_to_link.append("time_change")
        if colorbar:
            events_to_link.append("colormap_range")

        for brain in brains[1:]:
            link(self.leader, brain, include_events=events_to_link)

        if picking:

            def _func_add(*args, **kwargs):
                for brain in self.brains:
                    brain._add_vertex_glyph2(*args, **kwargs)
                    brain.plotter.update()

            def _func_remove(*args, **kwargs):
                for brain in self.brains:
                    brain._remove_vertex_glyph2(*args, **kwargs)

            # save initial picked points
            initial_points = dict()
            for hemi in ("lh", "rh"):
                initial_points[hemi] = set()
                for brain in self.brains:
                    initial_points[hemi] |= set(brain.picked_points[hemi])

            # link the viewers
            for brain in self.brains:
                brain.clear_glyphs()
                brain._add_vertex_glyph2 = brain._add_vertex_glyph
                brain._add_vertex_glyph = _func_add
                brain._remove_vertex_glyph2 = brain._remove_vertex_glyph
                brain._remove_vertex_glyph = _func_remove

            # link the initial points
            for hemi in initial_points.keys():
                if hemi in brain._layered_meshes:
                    mesh = brain._layered_meshes[hemi]._polydata
                    for vertex_id in initial_points[hemi]:
                        self.leader._add_vertex_glyph(hemi, mesh, vertex_id)

    def set_fmin(self, value):
        self.leader.update_lut(fmin=value)

    def set_fmid(self, value):
        self.leader.update_lut(fmid=value)

    def set_fmax(self, value):
        self.leader.update_lut(fmax=value)

    def set_time_point(self, value):
        self.leader.set_time_point(value)

    def set_playback_speed(self, value):
        self.leader.set_playback_speed(value)

    def toggle_playback(self):
        self.leader.toggle_playback()

    def link_cameras(self):
        from ..backends._pyvista import _add_camera_callback

        def _update_camera(vtk_picker, event):
            for brain in self.brains:
                brain.plotter.update()

        camera = self.leader.plotter.camera
        _add_camera_callback(camera, _update_camera)
        for brain in self.brains:
            for renderer in brain.plotter.renderers:
                renderer.camera = camera
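
Note: the public entry point for this class is `mne.viz.link_brains` (exported in __init__.pyi above). A hedged usage sketch; `stc_a`, `stc_b`, and `subjects_dir` are placeholders, and the keyword names are assumed to mirror this class:

import mne

brain_a = stc_a.plot(hemi="lh", subjects_dir=subjects_dir)  # returns a Brain
brain_b = stc_b.plot(hemi="lh", subjects_dir=subjects_dir)
mne.viz.link_brains([brain_a, brain_b], time=True, colorbar=True)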

mne/viz/_brain/_scraper.py (new file, 102 lines)
@@ -0,0 +1,102 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import os
import os.path as op

from ._brain import Brain


class _BrainScraper:
    """Scrape Brain objects."""

    def __repr__(self):
        return "<BrainScraper>"

    def __call__(self, block, block_vars, gallery_conf):
        rst = ""
        for brain in list(block_vars["example_globals"].values()):
            # Only need to process if it's a brain with a time_viewer
            # with traces on and shown in the same window, otherwise
            # PyVista and matplotlib scrapers can just do the work
            if (not isinstance(brain, Brain)) or brain._closed:
                continue
            from matplotlib import animation
            from matplotlib import pyplot as plt
            from sphinx_gallery.scrapers import matplotlib_scraper

            img = brain.screenshot(time_viewer=True)
            dpi = 100.0
            figsize = (img.shape[1] / dpi, img.shape[0] / dpi)
            fig = plt.figure(figsize=figsize, dpi=dpi)
            ax = plt.Axes(fig, [0, 0, 1, 1])
            fig.add_axes(ax)
            img = ax.imshow(img)
            movie_key = "# brain.save_movie"
            if movie_key in block[1]:
                kwargs = dict()
                # Parse our parameters
                lines = block[1].splitlines()
                for li, line in enumerate(block[1].splitlines()):
                    if line.startswith(movie_key):
                        line = line[len(movie_key) :].replace("..., ", "")
                        for ni in range(1, 5):  # should be enough
                            if len(lines) > li + ni and lines[li + ni].startswith(
                                "# "
                            ):
                                line = line + lines[li + ni][1:].strip()
                            else:
                                break
                        assert line.startswith("(") and line.endswith(")")
                        kwargs.update(eval(f"dict{line}"))  # nosec B307
                for key, default in [
                    ("time_dilation", 4),
                    ("framerate", 24),
                    ("tmin", None),
                    ("tmax", None),
                    ("interpolation", None),
                    ("time_viewer", False),
                ]:
                    if key not in kwargs:
                        kwargs[key] = default
                kwargs.pop("filename", None)  # always omit this one
                if brain.time_viewer:
                    assert kwargs["time_viewer"], "Must use time_viewer=True"
                frames = brain._make_movie_frames(callback=None, **kwargs)

                # Turn them into an animation
                def func(frame):
                    img.set_data(frame)
                    return [img]

                anim = animation.FuncAnimation(
                    fig,
                    func=func,
                    frames=frames,
                    blit=True,
                    interval=1000.0 / kwargs["framerate"],
                )

                # Out to sphinx-gallery:
                #
                # 1. A static image but hide it (useful for carousel)
                if animation.FFMpegWriter.isAvailable():
                    writer = "ffmpeg"
                elif animation.ImageMagickWriter.isAvailable():
                    writer = "imagemagick"
                else:
                    writer = None
                static_fname = next(block_vars["image_path_iterator"])
                static_fname = static_fname[:-4] + ".gif"
                anim.save(static_fname, writer=writer, dpi=dpi)
                rel_fname = op.relpath(static_fname, gallery_conf["src_dir"])
                rel_fname = rel_fname.replace(os.sep, "/").lstrip("/")
                rst += f"\n.. image:: /{rel_fname}\n :class: hidden\n"

                # 2. An animation that will be embedded and visible
                block_vars["example_globals"]["_brain_anim_"] = anim

            brain.close()
            rst += matplotlib_scraper(block, block_vars, gallery_conf)
        return rst
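
Note: a scraper like this is normally registered through sphinx-gallery's `image_scrapers` option; a hedged sketch of what a doc/conf.py entry could look like (the tuple contents are illustrative, not taken from this commit):

from mne.viz._brain._scraper import _BrainScraper

sphinx_gallery_conf = {
    # "matplotlib" is sphinx-gallery's built-in scraper; _BrainScraper adds
    # the Brain screenshots/animations produced by gallery examples.
    "image_scrapers": ("matplotlib", _BrainScraper()),
}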

mne/viz/_brain/colormap.py (new file, 185 lines)
@@ -0,0 +1,185 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np


def create_lut(cmap, n_colors=256, center=None):
    """Return a colormap suitable for setting as a LUT."""
    from .._3d import _get_cmap

    assert not (isinstance(cmap, str) and cmap == "auto")
    cmap = _get_cmap(cmap)
    lut = np.round(cmap(np.linspace(0, 1, n_colors)) * 255.0).astype(np.int64)
    return lut


def scale_sequential_lut(lut_table, fmin, fmid, fmax):
    """Scale a sequential colormap."""
    assert fmin <= fmid <= fmax  # guaranteed by calculate_lut
    lut_table_new = lut_table.copy()
    n_colors = lut_table.shape[0]
    n_colors2 = n_colors // 2

    if fmax == fmin:
        fmid_idx = 0
    else:
        fmid_idx = np.clip(
            int(np.round(n_colors * ((fmid - fmin) / (fmax - fmin))) - 1),
            0,
            n_colors - 2,
        )

    n_left = fmid_idx + 1
    n_right = n_colors - n_left
    for i in range(4):
        lut_table_new[: fmid_idx + 1, i] = np.interp(
            np.linspace(0, n_colors2 - 1, n_left), np.arange(n_colors), lut_table[:, i]
        )
        lut_table_new[fmid_idx + 1 :, i] = np.interp(
            np.linspace(n_colors - 1, n_colors2, n_right)[::-1],
            np.arange(n_colors),
            lut_table[:, i],
        )

    return lut_table_new


def get_fill_colors(cols, n_fill):
    """Get the fill colors for the middle of divergent colormaps."""
    steps = np.linalg.norm(np.diff(cols[:, :3].astype(float), axis=0), axis=1)

    ind = np.flatnonzero(steps[1:-1] > steps[[0, -1]].mean() * 3)
    if ind.size > 0:
        # choose the two colors between which there is the large step
        ind = ind[0] + 1
        fillcols = np.r_[
            np.tile(cols[ind, :], (n_fill // 2, 1)),
            np.tile(cols[ind + 1, :], (n_fill - n_fill // 2, 1)),
        ]
    else:
        # choose a color from the middle of the colormap
        fillcols = np.tile(cols[int(cols.shape[0] / 2), :], (n_fill, 1))

    return fillcols


def calculate_lut(lut_table, alpha, fmin, fmid, fmax, center=None, transparent=True):
    """Transparent color map calculation.

    A colormap may be sequential or divergent. When the colormap is
    divergent indicate this by providing a value for 'center'. The
    meanings of fmin, fmid and fmax are different for sequential and
    divergent colormaps. A sequential colormap is characterised by::

        [fmin, fmid, fmax]

    where fmin and fmax define the edges of the colormap and fmid
    will be the value mapped to the center of the originally chosen colormap.
    A divergent colormap is characterised by::

        [center-fmax, center-fmid, center-fmin, center,
         center+fmin, center+fmid, center+fmax]

    i.e., values between center-fmin and center+fmin will not be shown
    while center-fmid will map to the fmid of the first half of the
    original colormap and center+fmid to the fmid of the second half.

    Parameters
    ----------
    lim_cmap : Colormap
        Color map obtained from _process_mapdata.
    alpha : float
        Alpha value to apply globally to the overlay. Has no effect with mpl
        backend.
    fmin : float
        Min value in colormap.
    fmid : float
        Intermediate value in colormap.
    fmax : float
        Max value in colormap.
    center : float or None
        If not None, center of a divergent colormap, changes the meaning of
        fmin, fmax and fmid.
    transparent : boolean
        if True: use a linear transparency between fmin and fmid and make
        values below fmin fully transparent (symmetrically for divergent
        colormaps)

    Returns
    -------
    cmap : matplotlib.ListedColormap
        Color map with transparency channel.
    """
    if not fmin <= fmid <= fmax:
        raise ValueError(f"Must have fmin ({fmin}) <= fmid ({fmid}) <= fmax ({fmax})")
    lut_table = create_lut(lut_table)
    assert lut_table.dtype.kind == "i"
    divergent = center is not None
    n_colors = lut_table.shape[0]

    # Add transparency if needed
    n_colors2 = n_colors // 2
    if transparent:
        if divergent:
            N4 = np.full(4, n_colors // 4)
            N4[[0, 3, 1, 2][: np.mod(n_colors, 4)]] += 1
            assert N4.sum() == n_colors
            lut_table[:, -1] = np.round(
                np.hstack(
                    [
                        np.full(N4[0], 255.0),
                        np.linspace(0, 255, N4[1])[::-1],
                        np.linspace(0, 255, N4[2]),
                        np.full(N4[3], 255.0),
                    ]
                )
            )
        else:
            lut_table[:n_colors2, -1] = np.round(np.linspace(0, 255, n_colors2))
            lut_table[n_colors2:, -1] = 255

    alpha = float(alpha)
    if alpha < 1.0:
        lut_table[:, -1] = np.round(lut_table[:, -1] * alpha)

    if divergent:
        if np.isclose(fmax, fmin, rtol=1e-6, atol=0):
            lut_table = np.r_[
                lut_table[:1],
                get_fill_colors(
                    lut_table[n_colors2 - 3 : n_colors2 + 3, :], n_colors - 2
                ),
                lut_table[-1:],
            ]
        else:
            n_fill = int(round(fmin * n_colors2 / (fmax - fmin))) * 2
            lut_table = np.r_[
                scale_sequential_lut(
                    lut_table[:n_colors2, :],
                    center - fmax,
                    center - fmid,
                    center - fmin,
                ),
                get_fill_colors(lut_table[n_colors2 - 3 : n_colors2 + 3, :], n_fill),
                scale_sequential_lut(
                    lut_table[n_colors2:, :][::-1],
                    center - fmax,
                    center - fmid,
                    center - fmin,
                )[::-1],
            ]
    else:
        lut_table = scale_sequential_lut(lut_table, fmin, fmid, fmax)

    n_colors = lut_table.shape[0]
    if n_colors != 256:
        lut = np.zeros((256, 4))
        x = np.linspace(1, n_colors, 256)
        for chan in range(4):
            lut[:, chan] = np.interp(x, np.arange(1, n_colors + 1), lut_table[:, chan])
        lut_table = lut

    lut_table = lut_table.astype(np.float64) / 255.0
    return lut_table
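
Note: a hedged usage sketch of `calculate_lut` (colormap names here are standard matplotlib ones, chosen for illustration):

from mne.viz._brain.colormap import calculate_lut

# Sequential case: transparency ramps from fmin to fmid, opaque above.
lut = calculate_lut("hot", alpha=1.0, fmin=2.0, fmid=6.0, fmax=12.0)
print(lut.shape)              # (256, 4), RGBA values in [0, 1]
print(lut[0, 3], lut[-1, 3])  # 0.0 (transparent low end), 1.0 (opaque top)

# Divergent case: values within +/- fmin of `center` are filled/hidden.
div = calculate_lut("coolwarm", alpha=1.0, fmin=2.0, fmid=6.0, fmax=12.0, center=0.0)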

mne/viz/_brain/surface.py (new file, 177 lines)
@@ -0,0 +1,177 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from os import path as path

import numpy as np

from ...surface import _read_patch, complete_surface_info, read_curvature, read_surface
from ...utils import _check_fname, _check_option, _validate_type, get_subjects_dir


class _Surface:
    """Container for a brain surface.

    It is used for storing vertices, faces and morphometric data
    (curvature) of a hemisphere mesh.

    Parameters
    ----------
    subject : string
        Name of subject.
    hemi : {'lh', 'rh'}
        Which hemisphere to load.
    surf : string
        Name of the surface to load (eg. inflated, orig ...).
    subjects_dir : str | None
        If not None, this directory will be used as the subjects directory
        instead of the value set using the SUBJECTS_DIR environment variable.
    offset : float | None
        If 0.0, the surface will be offset such that the medial
        wall is aligned with the origin. If None, no offset will
        be applied. If != 0.0, an additional offset will be used.
    units : str
        Can be 'm' or 'mm' (default).
    x_dir : ndarray | None
        The x direction to use for offset alignment.

    Attributes
    ----------
    bin_curv : numpy.ndarray
        Curvature values stored as non-negative integers.
    coords : numpy.ndarray
        nvtx x 3 array of vertex (x, y, z) coordinates.
    curv : numpy.ndarray
        Vector representation of surface morphometry (curvature) values as
        loaded from a file.
    grey_curv : numpy.ndarray
        Normalized morphometry (curvature) data, used in order to get
        a gray cortex.
    faces : numpy.ndarray
        nfaces x 3 array of defining mesh triangles.
    hemi : {'lh', 'rh'}
        Which hemisphere to load.
    nn : numpy.ndarray
        Vertex normals for a triangulated surface.
    offset : float | None
        If float, align inside edge of each hemisphere to center + offset.
        If None, do not change coordinates (default).
    subject : string
        Name of subject.
    surf : string
        Name of the surface to load (eg. inflated, orig ...).
    units : str
        Can be 'm' or 'mm' (default).
    """

    def __init__(
        self,
        subject,
        hemi,
        surf,
        subjects_dir=None,
        offset=None,
        units="mm",
        x_dir=None,
    ):
        x_dir = np.array([1.0, 0, 0]) if x_dir is None else x_dir
        assert isinstance(x_dir, np.ndarray)
        assert np.isclose(np.linalg.norm(x_dir), 1.0, atol=1e-6)
        assert hemi in ("lh", "rh")
        _validate_type(offset, (None, "numeric"), "offset")

        self.units = _check_option("units", units, ("mm", "m"))
        self.subject = subject
        self.hemi = hemi
        self.surf = surf
        self.offset = offset
        self.bin_curv = None
        self.coords = None
        self.curv = None
        self.faces = None
        self.nn = None
        self.labels = dict()
        self.x_dir = x_dir

        subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True))
        self.data_path = path.join(subjects_dir, subject)
        if surf == "seghead":
            raise ValueError(
                "`surf` cannot be seghead, use "
                "`mne.viz.Brain.add_head` to plot the seghead"
            )

    def load_geometry(self):
        """Load geometry of the surface.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        if self.surf == "flat":  # special case
            fname = path.join(self.data_path, "surf", f"{self.hemi}.cortex.patch.flat")
            _check_fname(
                fname, overwrite="read", must_exist=True, name="flatmap surface file"
            )
            coords, faces, orig_faces = _read_patch(fname)
            # rotate 90 degrees to get to a more standard orientation
            # where X determines the distance between the hemis
            coords = coords[:, [1, 0, 2]]
            coords[:, 1] *= -1
        else:
            # allow ?h.pial.T1 if ?h.pial doesn't exist for instance
            # end with '' for better file not found error
            for img in ("", ".T1", ".T2", ""):
                surf_fname = path.join(
                    self.data_path, "surf", f"{self.hemi}.{self.surf}{img}"
                )
                if path.isfile(surf_fname):
                    break
            coords, faces = read_surface(surf_fname)
            orig_faces = faces
        if self.units == "m":
            coords /= 1000.0
        if self.offset is not None:
            x_ = coords @ self.x_dir
            if self.hemi == "lh":
                coords -= (np.max(x_) + self.offset) * self.x_dir
            else:
                coords -= (np.min(x_) + self.offset) * self.x_dir
        surf = dict(rr=coords, tris=faces)
        complete_surface_info(surf, copy=False, verbose=False, do_neighbor_tri=False)
        nn = surf["nn"]
        self.coords = coords
        self.faces = faces
        self.orig_faces = orig_faces
        self.nn = nn

    def __len__(self):
        """Return number of vertices."""
        return len(self.coords)

    @property
    def x(self):
        return self.coords[:, 0]

    @property
    def y(self):
        return self.coords[:, 1]

    @property
    def z(self):
        return self.coords[:, 2]

    def load_curvature(self):
        """Load in curvature values from the ?h.curv file."""
        curv_path = path.join(self.data_path, "surf", f"{self.hemi}.curv")
        if path.isfile(curv_path):
            self.curv = read_curvature(curv_path, binary=False)
            self.bin_curv = np.array(self.curv > 0, np.int64)
        else:
            self.curv = None
            self.bin_curv = None

mne/viz/_brain/view.py (new file, 54 lines)
@@ -0,0 +1,54 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

ORIGIN = "auto"
DIST = "auto"

_lh_views_dict = {
    "lateral": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
    "medial": dict(azimuth=0.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
    "rostral": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
    "caudal": dict(azimuth=270.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
    "dorsal": dict(azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, distance=DIST),
    "ventral": dict(azimuth=180.0, elevation=180.0, focalpoint=ORIGIN, distance=DIST),
    "frontal": dict(azimuth=120.0, elevation=80.0, focalpoint=ORIGIN, distance=DIST),
    "parietal": dict(azimuth=-120.0, elevation=60.0, focalpoint=ORIGIN, distance=DIST),
    "sagittal": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
    "coronal": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
    "axial": dict(
        azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, roll=0, distance=DIST
    ),  # noqa: E501
}
_rh_views_dict = {
    "lateral": dict(azimuth=180.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
    "medial": dict(azimuth=0.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
    "rostral": dict(azimuth=-90.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
    "caudal": dict(azimuth=90.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
    "dorsal": dict(azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, distance=DIST),
    "ventral": dict(azimuth=180.0, elevation=180.0, focalpoint=ORIGIN, distance=DIST),
    "frontal": dict(azimuth=60.0, elevation=80.0, focalpoint=ORIGIN, distance=DIST),
    "parietal": dict(azimuth=-60.0, elevation=60.0, focalpoint=ORIGIN, distance=DIST),
    "sagittal": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
    "coronal": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
    "axial": dict(
        azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, roll=0, distance=DIST
    ),
}
# add short-size version entries into the dict
lh_views_dict = _lh_views_dict.copy()
for k, v in _lh_views_dict.items():
    lh_views_dict[k[:3]] = v
lh_views_dict["flat"] = dict(
    azimuth=0, elevation=0, focalpoint=ORIGIN, roll=0, distance=DIST
)

rh_views_dict = _rh_views_dict.copy()
for k, v in _rh_views_dict.items():
    rh_views_dict[k[:3]] = v
rh_views_dict["flat"] = dict(
    azimuth=0, elevation=0, focalpoint=ORIGIN, roll=0, distance=DIST
)
views_dicts = dict(
    lh=lh_views_dict, vol=lh_views_dict, both=lh_views_dict, rh=rh_views_dict
)
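
Note: these dicts are presumably what Brain view selection resolves against; a hedged sketch showing that the three-letter aliases added above point at the same camera parameters:

from mne.viz._brain.view import views_dicts

print(views_dicts["lh"]["lat"] is views_dicts["lh"]["lateral"])  # True
print(sorted(views_dicts["rh"])[:4])  # alias and full names side by side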

mne/viz/_dipole.py (new file, 214 lines)
@@ -0,0 +1,214 @@
"""Dipole viz specific functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import os.path as op

import numpy as np
from scipy.spatial import ConvexHull

from .._freesurfer import _estimate_talxfm_rigid, _get_head_surface
from ..surface import read_surface
from ..transforms import _get_trans, apply_trans, invert_transform
from ..utils import _check_option, _validate_type, get_subjects_dir
from .utils import _validate_if_list_of_axes, plt_show


def _check_concat_dipoles(dipole):
    from ..dipole import Dipole, _concatenate_dipoles

    if not isinstance(dipole, Dipole):
        dipole = _concatenate_dipoles(dipole)
    return dipole


def _plot_dipole_mri_outlines(
    dipoles,
    *,
    subject,
    trans,
    ax,
    subjects_dir,
    color,
    scale,
    coord_frame,
    show,
    block,
    head_source,
    title,
    surf,
    width,
):
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection, PatchCollection
    from matplotlib.patches import Circle

    extra = 'when mode is "outlines"'
    trans = _get_trans(trans, fro="head", to="mri")[0]
    _check_option(
        "coord_frame", coord_frame, ["head", "mri", "mri_rotated"], extra=extra
    )
    _validate_type(surf, (str, None), "surf")
    _check_option("surf", surf, ("white", "pial", None))
    if ax is None:
        _, ax = plt.subplots(1, 3, figsize=(7, 2.5), squeeze=True, layout="constrained")
    _validate_if_list_of_axes(ax, 3, name="ax")
    dipoles = _check_concat_dipoles(dipoles)
    color = "r" if color is None else color
    scale = 0.03 if scale is None else scale
    width = 0.015 if width is None else width
    fig = ax[0].figure
    surfs = dict()
    hemis = ("lh", "rh")
    if surf is not None:
        for hemi in hemis:
            surfs[hemi] = read_surface(
                op.join(subjects_dir, subject, "surf", f"{hemi}.{surf}"),
                return_dict=True,
            )[2]
            surfs[hemi]["rr"] /= 1000.0
    subjects_dir = get_subjects_dir(subjects_dir)
    if subjects_dir is not None:
        subjects_dir = str(subjects_dir)
    surfs["head"] = _get_head_surface(head_source, subject, subjects_dir)
    del head_source
    mri_trans = head_trans = np.eye(4)
    if coord_frame in ("mri", "mri_rotated"):
        head_trans = trans["trans"]
        if coord_frame == "mri_rotated":
            rot = _estimate_talxfm_rigid(subject, subjects_dir)
            rot[:3, 3] = 0.0
            head_trans = rot @ head_trans
            mri_trans = rot @ mri_trans
    else:
        assert coord_frame == "head"
        mri_trans = invert_transform(trans)["trans"]
    for s in surfs.values():
        s["rr"] = 1000 * apply_trans(mri_trans, s["rr"])
    del mri_trans
    levels = dict()
    if surf is not None:
        use_rr = np.concatenate([surfs[key]["rr"] for key in hemis])
    else:
        use_rr = surfs["head"]["rr"]
    views = [("Axial", "XY"), ("Coronal", "XZ"), ("Sagittal", "YZ")]
    # axial: 25% up the Z axis
    axial = float(np.percentile(use_rr[:, 2], 20.0))
    coronal = float(np.percentile(use_rr[:, 1], 55.0))
    for key in hemis + ("head",):
        levels[key] = dict(Axial=axial, Coronal=coronal)
    if surf is not None:
        levels["rh"]["Sagittal"] = float(np.percentile(surfs["rh"]["rr"][:, 0], 50))
    levels["head"]["Sagittal"] = 0.0
    for ax_, (name, coords) in zip(ax, views):
        idx = list(map(dict(X=0, Y=1, Z=2).get, coords))
        miss = np.setdiff1d(np.arange(3), idx)[0]
        pos = 1000 * apply_trans(head_trans, dipoles.pos)
        ori = 1000 * apply_trans(head_trans, dipoles.ori, move=False)
        lims = dict()
        for ii, char in enumerate(coords):
            lim = surfs["head"]["rr"][:, idx[ii]]
            lim = np.array([lim.min(), lim.max()])
            lims[char] = lim
        ax_.quiver(
            pos[:, idx[0]],
            pos[:, idx[1]],
            scale * ori[:, idx[0]],
            scale * ori[:, idx[1]],
            color=color,
            pivot="middle",
            zorder=5,
            scale_units="xy",
            angles="xy",
            scale=1.0,
            width=width,
            minshaft=0.5,
            headwidth=2.5,
            headlength=2.5,
            headaxislength=2,
        )
        coll = PatchCollection(
            [
                Circle((x, y), radius=scale * 1000 * width * 6)
                for x, y in zip(pos[:, idx[0]], pos[:, idx[1]])
            ],
            linewidths=0.0,
            facecolors=color,
            zorder=6,
        )
        for key, surf in surfs.items():
            try:
                level = levels[key][name]
            except KeyError:
                continue
            if key != "head":
                rrs = surf["rr"][:, idx]
                tris = ConvexHull(rrs).simplices
                segments = LineCollection(
                    rrs[:, [0, 1]][tris],
                    linewidths=1,
                    linestyles="-",
                    colors="k",
                    zorder=3,
                    alpha=0.25,
                )
                ax_.add_collection(segments)
            ax_.tricontour(
                surf["rr"][:, idx[0]],
                surf["rr"][:, idx[1]],
                surf["tris"],
                surf["rr"][:, miss],
                levels=[level],
                colors="k",
                linewidths=1.0,
                linestyles=["-"],
                zorder=4,
                alpha=0.5,
            )
            # TODO: this breaks the PatchCollection in MPL
            # for coll in h.collections:
            #     coll.set_clip_on(False)
        ax_.add_collection(coll)
        ax_.set(
            title=name,
            xlim=lims[coords[0]],
            ylim=lims[coords[1]],
            xlabel=coords[0] + " (mm)",
            ylabel=coords[1] + " (mm)",
        )
        for spine in ax_.spines.values():
            spine.set_visible(False)
        ax_.grid(True, ls=":", zorder=2)
        ax_.set_aspect("equal")

    if title is not None:
        fig.suptitle(title)
    plt_show(show, block=block)

    return fig


def _plot_dipole_3d(dipoles, *, coord_frame, color, fig, trans, scale, mode):
    from .backends.renderer import _get_renderer

    _check_option("coord_frame", coord_frame, ("head", "mri"))
    color = "r" if color is None else color
    scale = 0.005 if scale is None else scale
    renderer = _get_renderer(fig=fig, size=(600, 600))
    pos = dipoles.pos
    ori = dipoles.ori
    if coord_frame != "head":
        trans = _get_trans(trans, fro="head", to=coord_frame)[0]
        pos = apply_trans(trans, pos)
        ori = apply_trans(trans, ori)

    renderer.sphere(center=pos, color=color, scale=scale)
    if mode == "arrow":
        x, y, z = pos.T
        u, v, w = ori.T
        renderer.quiver3d(x, y, z, u, v, w, scale=3 * scale, color=color, mode="arrow")
    renderer.show()
    fig = renderer.scene()
    return fig

mne/viz/_figure.py (new file, 853 lines; listing truncated below)
@@ -0,0 +1,853 @@
"""Base classes and functions for 2D browser backends."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import importlib
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from copy import deepcopy
from itertools import cycle

import numpy as np

from .._fiff.pick import _DATA_CH_TYPES_SPLIT
from ..defaults import _handle_default
from ..filter import _iir_filter, _overlap_add_filter
from ..fixes import _compare_version
from ..utils import (
    _check_option,
    _get_stim_channel,
    _validate_type,
    get_config,
    logger,
    set_config,
    verbose,
)
from .backends._utils import VALID_BROWSE_BACKENDS
from .utils import _get_color_list, _setup_plot_projector, _show_browser

MNE_BROWSER_BACKEND = None
backend = None


class BrowserParams:
    """Container object for 2D browser parameters."""

    def __init__(self, **kwargs):
        # default key to close window
        self.close_key = "escape"
        vars(self).update(**kwargs)


class BrowserBase(ABC):
    """A base class for the 2D browser.

    This class contains all backend-independent attributes and methods.
    """

    def __init__(self, **kwargs):
        from ..epochs import BaseEpochs
        from ..io import BaseRaw
        from ..preprocessing import ICA

        self.backend_name = None

        self._data = None
        self._times = None

        self.mne = BrowserParams(**kwargs)

        inst = kwargs.get("inst", None)
        ica = kwargs.get("ica", None)

        # what kind of data are we dealing with?
        if isinstance(ica, ICA):
            self.mne.instance_type = "ica"
        elif isinstance(inst, BaseRaw):
            self.mne.instance_type = "raw"
        elif isinstance(inst, BaseEpochs):
            self.mne.instance_type = "epochs"
        else:
            raise TypeError(
                f"Expected an instance of Raw, Epochs, or ICA, got {type(inst)}."
            )

        logger.debug(f"Opening {self.mne.instance_type} browser...")

        self.mne.ica_type = None
        if self.mne.instance_type == "ica":
            if isinstance(self.mne.ica_inst, BaseRaw):
                self.mne.ica_type = "raw"
            elif isinstance(self.mne.ica_inst, BaseEpochs):
                self.mne.ica_type = "epochs"
        self.mne.is_epochs = "epochs" in (self.mne.instance_type, self.mne.ica_type)

        # things that always start the same
        self.mne.ch_start = 0
        self.mne.projector = None
        if hasattr(self.mne, "projs"):
            self.mne.projs_active = np.array([p["active"] for p in self.mne.projs])
        self.mne.whitened_ch_names = list()
        if hasattr(self.mne, "noise_cov"):
            self.mne.use_noise_cov = self.mne.noise_cov is not None
        # allow up to 10000 zorder levels for annotations
        self.mne.zorder = dict(
            patch=0,
            grid=1,
            ann=2,
            events=10003,
            bads=10004,
            data=10005,
            mag=10006,
            grad=10007,
            scalebar=10008,
            vline=10009,
        )
        # additional params for epochs (won't affect raw / ICA)
        self.mne.epoch_traces = list()
        self.mne.bad_epochs = list()
        if inst is not None:
            self.mne.sampling_period = np.diff(inst.times[:2])[0] / inst.info["sfreq"]
        # annotations
        self.mne.annotations = list()
        self.mne.hscroll_annotations = list()
        self.mne.annotation_segments = list()
        self.mne.annotation_texts = list()
        self.mne.new_annotation_labels = list()
        self.mne.annotation_segment_colors = dict()
        self.mne.annotation_hover_line = None
        self.mne.draggable_annotations = False
        # lines
        self.mne.event_lines = list()
        self.mne.event_texts = list()
        self.mne.vline_visible = False
        # decim
        self.mne.decim_times = None
        self.mne.decim_data = None
        # scalings
        if hasattr(self.mne, "butterfly"):
            self.mne.scale_factor = 0.5 if self.mne.butterfly else 1.0
        self.mne.scalebars = dict()
        self.mne.scalebar_texts = dict()
        # ancillary child figures
        self.mne.child_figs = list()
        self.mne.fig_help = None
        self.mne.fig_proj = None
        self.mne.fig_histogram = None
        self.mne.fig_selection = None
        self.mne.fig_annotation = None
        # extra attributes for epochs
        if self.mne.is_epochs:
            # add epoch boundaries & center epoch numbers between boundaries
            self.mne.midpoints = (
                np.convolve(self.mne.boundary_times, np.ones(2), mode="valid") / 2
            )

        # initialize picks and projectors
        self._update_picks()
        if not self.mne.instance_type == "ica":
            self._update_projector()

    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # ANNOTATIONS
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

    def _get_annotation_labels(self):
        """Get the unique labels in the raw object and added in the UI."""
        return sorted(
            set(self.mne.inst.annotations.description)
            | set(self.mne.new_annotation_labels)
        )

    def _setup_annotation_colors(self):
        """Set up colors for annotations; init some annotation vars."""
        segment_colors = getattr(self.mne, "annotation_segment_colors", dict())
        labels = self._get_annotation_labels()
        colors, red = _get_color_list(annotations=True)
        color_cycle = cycle(colors)
        for key, color in segment_colors.items():
            if color != red and key in labels:
                next(color_cycle)
        for idx, key in enumerate(labels):
            if key.lower().startswith("bad") or key.lower().startswith("edge"):
                segment_colors[key] = red
            elif key in segment_colors:
                continue
            else:
                segment_colors[key] = next(color_cycle)
        self.mne.annotation_segment_colors = segment_colors
        # init a couple other annotation-related variables
        self.mne.visible_annotations = {label: True for label in labels}
        self.mne.show_hide_annotation_checkboxes = None

    def _update_annotation_segments(self):
        """Update the array of annotation start/end times."""
        from ..annotations import _sync_onset

        self.mne.annotation_segments = np.array([])
        if len(self.mne.inst.annotations):
            annot_start = _sync_onset(self.mne.inst, self.mne.inst.annotations.onset)
            durations = self.mne.inst.annotations.duration.copy()
            durations[durations < 1 / self.mne.info["sfreq"]] = (
                1 / self.mne.info["sfreq"]
            )
            annot_end = annot_start + durations
            self.mne.annotation_segments = np.vstack((annot_start, annot_end)).T

    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # PROJECTOR & BADS
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

    def _update_projector(self):
        """Update the data after projectors (or bads) have changed."""
        inds = np.where(self.mne.projs_on)[0]  # doesn't include "active" projs
        # copy projs from full list (self.mne.projs) to info object
        with self.mne.info._unlock():
            self.mne.info["projs"] = [deepcopy(self.mne.projs[ix]) for ix in inds]
        # compute the projection operator
        proj, wh_chs = _setup_plot_projector(
            self.mne.info, self.mne.noise_cov, True, self.mne.use_noise_cov
        )
        self.mne.whitened_ch_names = list(wh_chs)
        self.mne.projector = proj

    def _toggle_bad_channel(self, idx):
        """Mark/unmark bad channels; `idx` is index of *visible* channels."""
        pick = self.mne.picks[idx]
        ch_name = self.mne.ch_names[pick]
        # add/remove from bads list
        bads = self.mne.info["bads"]
        marked_bad = ch_name not in bads
        if marked_bad:
            bads.append(ch_name)
            color = self.mne.ch_color_bad
        else:
            while ch_name in bads:  # to make sure duplicates are removed
                bads.remove(ch_name)
            # Only mpl-backend has ch_colors
            if hasattr(self.mne, "ch_colors"):
                color = self.mne.ch_colors[idx]
            else:
                color = None
        self.mne.info["bads"] = bads

        self._update_projector()

        return color, pick, marked_bad

    def _toggle_single_channel_annotation(self, ch_pick, annot_idx):
        current_ch_names = list(self.mne.inst.annotations.ch_names[annot_idx])
        if ch_pick in current_ch_names:
            current_ch_names.remove(ch_pick)
        else:
            current_ch_names.append(ch_pick)
        self.mne.inst.annotations.ch_names[annot_idx] = tuple(current_ch_names)

    def _toggle_bad_epoch(self, xtime):
        epoch_num = self._get_epoch_num_from_time(xtime)
        epoch_ix = self.mne.inst.selection.tolist().index(epoch_num)
        if epoch_num in self.mne.bad_epochs:
            self.mne.bad_epochs.remove(epoch_num)
            color = "none"
        else:
            self.mne.bad_epochs.append(epoch_num)
            self.mne.bad_epochs.sort()
            color = self.mne.epoch_color_bad

        return epoch_ix, color

    def _toggle_whitening(self):
        if self.mne.noise_cov is not None:
            self.mne.use_noise_cov = not self.mne.use_noise_cov
            self._update_projector()
            self._update_yaxis_labels()  # add/remove italics
            self._redraw()

    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # MANAGE TRACES
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

    def _update_picks(self):
        """Compute which channel indices to show."""
        if self.mne.butterfly and self.mne.ch_selections is not None:
            selections_dict = self._make_butterfly_selections_dict()
            self.mne.picks = np.concatenate(tuple(selections_dict.values()))
        elif self.mne.butterfly:
            self.mne.picks = self.mne.ch_order
        else:
            _slice = slice(self.mne.ch_start, self.mne.ch_start + self.mne.n_channels)
            self.mne.picks = self.mne.ch_order[_slice]
            self.mne.n_channels = len(self.mne.picks)
        assert isinstance(self.mne.picks, np.ndarray)
        assert self.mne.picks.dtype.kind == "i"

    def _make_butterfly_selections_dict(self):
        """Make an altered copy of the selections dict for butterfly mode."""
        selections_dict = deepcopy(self.mne.ch_selections)
        # remove potential duplicates
        for selection_group in ("Vertex", "Custom"):
            selections_dict.pop(selection_group, None)
        # if present, remove stim channel from non-misc selection groups
        stim_ch = _get_stim_channel(None, self.mne.info, raise_error=False)
        if len(stim_ch):
            stim_pick = self.mne.ch_names.tolist().index(stim_ch[0])
            for _sel, _picks in selections_dict.items():
                if _sel != "Misc":
                    stim_mask = np.isin(_picks, [stim_pick], invert=True)
                    selections_dict[_sel] = np.array(_picks)[stim_mask]
        return selections_dict

    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # MANAGE DATA
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

    def _get_start_stop(self):
        # update time
        start_sec = self.mne.t_start - self.mne.first_time
        stop_sec = start_sec + self.mne.duration
        if self.mne.is_epochs:
            start, stop = np.round(
                np.array([start_sec, stop_sec]) * self.mne.info["sfreq"]
            ).astype(int)
        else:
            start, stop = self.mne.inst.time_as_index((start_sec, stop_sec))

        return start, stop

    def _load_data(self, start=None, stop=None):
        """Retrieve the bit of data we need for plotting."""
        if "raw" in (self.mne.instance_type, self.mne.ica_type):
            # Add additional sample to cover the case sfreq!=1000
            # when the shown time-range wouldn't correspond to duration anymore
            if stop is None:
                return self.mne.inst[:, start:]
            else:
                return self.mne.inst[:, start : stop + 2]
        else:
            # subtract one sample from tstart before searchsorted, to make sure
            # we land on the left side of the boundary time (avoid precision
            # errors)
            ix_start = np.searchsorted(
                self.mne.boundary_times, self.mne.t_start - self.mne.sampling_period
            )
            ix_stop = ix_start + self.mne.n_epochs
            item = slice(ix_start, ix_stop)
            data = np.concatenate(
                self.mne.inst.get_data(item=item, copy=False), axis=-1
            )
            times = np.arange(start, stop) / self.mne.info["sfreq"]
            return data, times

    def _apply_filter(self, data, start, stop, picks):
        """Filter (with same defaults as raw.filter())."""
        starts, stops = self.mne.filter_bounds
        mask = (starts < stop) & (stops > start)
        starts = np.maximum(starts[mask], start) - start
        stops = np.minimum(stops[mask], stop) - start
        for _start, _stop in zip(starts, stops):
            _picks = np.where(np.isin(picks, self.mne.picks_data))[0]
            if len(_picks) == 0:
                break
            this_data = data[_picks, _start:_stop]
            if isinstance(self.mne.filter_coefs, np.ndarray):  # FIR
                this_data = _overlap_add_filter(
                    this_data, self.mne.filter_coefs, copy=False
                )
            else:  # IIR
                this_data = _iir_filter(
                    this_data, self.mne.filter_coefs, None, 1, False
                )
            data[_picks, _start:_stop] = this_data

    def _process_data(self, data, start, stop, picks, thread=None):
        """Update self.mne.data after user interaction."""
        # apply projectors
        if self.mne.projector is not None:
            # thread is the loading-thread only available in Qt-backend
            if thread:
                thread.processText.emit("Applying Projectors...")
            data = self.mne.projector @ data
        # get only the channels we're displaying
        data = data[picks]
        # remove DC
        if self.mne.remove_dc:
            if thread:
                thread.processText.emit("Removing DC...")
            data -= np.nanmean(data, axis=1, keepdims=True)
        # apply filter
        if self.mne.filter_coefs is not None:
            if thread:
                thread.processText.emit("Apply Filter...")
            self._apply_filter(data, start, stop, picks)
        # scale the data for display in a 1-vertical-axis-unit slot
        if thread:
            thread.processText.emit("Scale Data...")
        this_names = self.mne.ch_names[picks]
        this_types = self.mne.ch_types[picks]
        stims = this_types == "stim"
        white = np.logical_and(
            np.isin(this_names, self.mne.whitened_ch_names),
            np.isin(this_names, self.mne.info["bads"], invert=True),
        )
        norms = np.vectorize(self.mne.scalings.__getitem__)(this_types)
        norms[stims] = data[stims].max(axis=-1)
        norms[white] = self.mne.scalings["whitened"]
        norms[norms == 0] = 1
        data /= 2 * norms[:, np.newaxis]

        return data

    def _update_data(self):
        start, stop = self._get_start_stop()
        # get the data
        data, times = self._load_data(start, stop)
        # process the data
        data = self._process_data(data, start, stop, self.mne.picks)
        # set the data as attributes
        self.mne.data = data
        self.mne.times = times

    def _get_epoch_num_from_time(self, time):
        epoch_nums = self.mne.inst.selection
        return epoch_nums[np.searchsorted(self.mne.boundary_times[1:], time)]

    def _redraw(self, update_data=True, annotations=False):
        """Redraw the backend if necessary."""
        if update_data:
            self._update_data()

        self._draw_traces()

        if annotations and not self.mne.is_epochs:
            self._draw_annotations()

    def _close(self, event):
        """Handle close events (via keypress or window [x])."""
        from matplotlib.pyplot import close

        logger.debug(f"Closing {self.mne.instance_type} browser...")
        # write out bad epochs (after converting epoch numbers to indices)
        if self.mne.instance_type == "epochs":
            bad_ixs = np.isin(self.mne.inst.selection, self.mne.bad_epochs).nonzero()[0]
            self.mne.inst.drop(bad_ixs)
            logger.info(
                "The following epochs were marked as bad "
                "and are dropped:\n"
                f"{self.mne.bad_epochs}"
            )
        # write bad channels back to instance (don't do this for proj;
        # proj checkboxes are for viz only and shouldn't modify the instance)
        if self.mne.instance_type in ("raw", "epochs"):
            self.mne.inst.info["bads"] = self.mne.info["bads"]
            logger.info(f"Channels marked as bad:\n{self.mne.info['bads'] or 'none'}")
        # ICA excludes
        elif self.mne.instance_type == "ica":
            self.mne.ica.exclude = [
                self.mne.ica._ica_names.index(ch) for ch in self.mne.info["bads"]
            ]
        # write window size to config
        str_size = ",".join([str(i) for i in self._get_size()])
        set_config("MNE_BROWSE_RAW_SIZE", str_size, set_env=False)
        # Clean up child figures (don't pop(), child figs remove themselves)
        while len(self.mne.child_figs):
            fig = self.mne.child_figs[-1]
            close(fig)
            self._close_event(fig)

    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # CHILD FIGURES
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    @abstractmethod
    def _new_child_figure(self, fig_name, **kwargs):
        pass

    def _create_ch_context_fig(self, idx):
        """Show context figure; idx is index of **visible** channels."""
        inst = self.mne.instance_type
        pick = self.mne.picks[idx]
        if inst == "raw":
            fig = self._create_ch_location_fig(pick)
        elif inst == "ica":
            fig = self._create_ica_properties_fig(pick)
        else:
            fig = self._create_epoch_image_fig(pick)

        return fig

    def _create_ch_location_fig(self, pick):
        """Show channel location figure."""
        from .utils import _channel_type_prettyprint, plot_sensors

        ch_name = self.mne.ch_names[pick]
        ch_type = self.mne.ch_types[pick]
        if ch_type not in _DATA_CH_TYPES_SPLIT:
            return
        # create figure and axes
        title = f"Location of {ch_name}"
        fig = self._new_child_figure(figsize=(4, 4), fig_name=None, window_title=title)
        fig.suptitle(title)
        ax = fig.add_subplot(111)
        title = f"{ch_name} position ({_channel_type_prettyprint[ch_type]})"
        _ = plot_sensors(
            self.mne.info,
            ch_type=ch_type,
            axes=ax,
            title=title,
            kind="select",
            show=False,
        )
        # highlight desired channel & disable interactivity
        inds = np.isin(fig.lasso.ch_names, [ch_name])
        fig.lasso.disconnect()
        fig.lasso.alpha_other = 0.3
        fig.lasso.linewidth_selected = 3
        fig.lasso.style_sensors(inds)

        return fig

    def _create_ica_properties_fig(self, idx):
        """Show ICA properties for the selected component."""
|
||||
from mne.viz.ica import (
|
||||
_create_properties_layout,
|
||||
_fast_plot_ica_properties,
|
||||
_prepare_data_ica_properties,
|
||||
)
|
||||
|
||||
ch_name = self.mne.ch_names[idx]
|
||||
if ch_name not in self.mne.ica._ica_names: # for EOG chans: do nothing
|
||||
return
|
||||
pick = self.mne.ica._ica_names.index(ch_name)
|
||||
title = f"{ch_name} properties"
|
||||
fig = self._new_child_figure(figsize=(7, 6), fig_name=None, window_title=title)
|
||||
fig.suptitle(title)
|
||||
fig, axes = _create_properties_layout(fig=fig)
|
||||
if not hasattr(self.mne, "data_ica_properties"):
|
||||
# Precompute epoch sources only once
|
||||
self.mne.data_ica_properties = _prepare_data_ica_properties(
|
||||
self.mne.ica_inst, self.mne.ica
|
||||
)
|
||||
_fast_plot_ica_properties(
|
||||
self.mne.ica,
|
||||
self.mne.ica_inst,
|
||||
picks=pick,
|
||||
axes=axes,
|
||||
psd_args=self.mne.psd_args,
|
||||
precomputed_data=self.mne.data_ica_properties,
|
||||
show=False,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def _create_epoch_image_fig(self, pick):
|
||||
"""Show epochs image for the selected channel."""
|
||||
from matplotlib.gridspec import GridSpec
|
||||
|
||||
from mne.viz import plot_epochs_image
|
||||
|
||||
ch_name = self.mne.ch_names[pick]
|
||||
title = f"Epochs image ({ch_name})"
|
||||
fig = self._new_child_figure(figsize=(6, 4), fig_name=None, window_title=title)
|
||||
fig.suptitle(title)
|
||||
gs = GridSpec(nrows=3, ncols=10, figure=fig)
|
||||
fig.add_subplot(gs[:2, :9])
|
||||
fig.add_subplot(gs[2, :9])
|
||||
fig.add_subplot(gs[:2, 9])
|
||||
plot_epochs_image(self.mne.inst, picks=pick, fig=fig, show=False)
|
||||
|
||||
return fig
|
||||
|
||||
def _create_epoch_histogram(self):
|
||||
"""Create peak-to-peak histogram of channel amplitudes."""
|
||||
epochs = self.mne.inst
|
||||
data = OrderedDict()
|
||||
ptp = np.ptp(epochs.get_data(copy=False), axis=2)
|
||||
for ch_type in ("eeg", "mag", "grad"):
|
||||
if ch_type in epochs:
|
||||
data[ch_type] = ptp.T[self.mne.ch_types == ch_type].ravel()
|
||||
units = _handle_default("units")
|
||||
titles = _handle_default("titles")
|
||||
colors = _handle_default("color")
|
||||
scalings = _handle_default("scalings")
|
||||
title = "Histogram of peak-to-peak amplitudes"
|
||||
figsize = (4, 1 + 1.5 * len(data))
|
||||
fig = self._new_child_figure(
|
||||
figsize=figsize, fig_name="fig_histogram", window_title=title
|
||||
)
|
||||
for ix, (_ch_type, _data) in enumerate(data.items()):
|
||||
ax = fig.add_subplot(len(data), 1, ix + 1)
|
||||
ax.set(title=titles[_ch_type], xlabel=units[_ch_type], ylabel="Count")
|
||||
# set histogram bin range based on rejection thresholds
|
||||
reject = None
|
||||
_range = None
|
||||
if epochs.reject is not None and _ch_type in epochs.reject:
|
||||
reject = epochs.reject[_ch_type] * scalings[_ch_type]
|
||||
_range = (0.0, reject * 1.1)
|
||||
# plot it
|
||||
ax.hist(
|
||||
_data * scalings[_ch_type],
|
||||
bins=100,
|
||||
color=colors[_ch_type],
|
||||
range=_range,
|
||||
)
|
||||
if reject is not None:
|
||||
ax.plot((reject, reject), (0, ax.get_ylim()[1]), color="r")
|
||||
# finalize
|
||||
fig.suptitle(title, y=0.99)
|
||||
self.mne.fig_histogram = fig
|
||||
|
||||
return fig
|
||||
|
||||
def _close_event(self, fig):
|
||||
"""Look at _close_event in mne.fixes.py for why this exists."""
|
||||
pass
|
||||
|
||||
def fake_keypress(self, key, fig=None): # noqa: D400
|
||||
"""Pass a fake keypress to the figure.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The key to fake (e.g., ``'a'``).
|
||||
fig : instance of Figure
|
||||
The figure to pass the keypress to.
|
||||
"""
|
||||
return self._fake_keypress(key, fig=fig)
|
||||
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
# TEST METHODS
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
|
||||
@abstractmethod
|
||||
def _get_size(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _fake_keypress(self, key, fig):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _fake_click(self, point, fig, axis, xform, button, kind):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _click_ch_name(self, ch_index, button):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _resize_by_factor(self, factor):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_ticklabels(self, orientation):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _update_yaxis_labels(self):
|
||||
pass
|
||||
|
||||
|
||||
def _load_backend(backend_name):
|
||||
global backend
|
||||
if backend_name == "matplotlib":
|
||||
backend = importlib.import_module(name="._mpl_figure", package="mne.viz")
|
||||
else:
|
||||
from mne_qt_browser import _pg_figure as backend
|
||||
|
||||
logger.info(f"Using {backend_name} as 2D backend.")
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
def _get_browser(show, block, **kwargs):
|
||||
"""Instantiate a new MNE browse-style figure."""
|
||||
from .utils import _get_figsize_from_config
|
||||
|
||||
figsize = kwargs.setdefault("figsize", _get_figsize_from_config())
|
||||
if figsize is None or np.any(np.array(figsize) < 8):
|
||||
kwargs["figsize"] = (8, 8)
|
||||
kwargs["splash"] = kwargs.get("splash", True) and show
|
||||
if kwargs.get("theme", None) is None:
|
||||
kwargs["theme"] = get_config("MNE_BROWSER_THEME", "auto")
|
||||
if kwargs.get("overview_mode", None) is None:
|
||||
kwargs["overview_mode"] = get_config("MNE_BROWSER_OVERVIEW_MODE", "channels")
|
||||
|
||||
# Initialize browser backend
|
||||
backend_name = get_browser_backend()
|
||||
# Check mne-qt-browser compatibility
|
||||
if backend_name == "qt":
|
||||
import mne_qt_browser
|
||||
|
||||
from ..epochs import BaseEpochs
|
||||
|
||||
is_ica = kwargs.get("ica", False)
|
||||
is_epochs = isinstance(kwargs.get("inst", False), BaseEpochs)
|
||||
not_compat = _compare_version(mne_qt_browser.__version__, "<", "0.2.0")
|
||||
inst_str = "ICA" if is_ica else "Epochs"
|
||||
if not_compat and (is_ica or is_epochs):
|
||||
logger.info(
|
||||
f'You set the browser-backend to "qt" but your'
|
||||
f" current version {mne_qt_browser.__version__}"
|
||||
f" of mne-qt-browser is too low for {inst_str}."
|
||||
f"Update with pip or conda."
|
||||
f"Defaults to matplotlib."
|
||||
)
|
||||
with use_browser_backend("matplotlib"):
|
||||
# Initialize Browser
|
||||
fig = backend._init_browser(**kwargs)
|
||||
_show_browser(show=show, block=block, fig=fig)
|
||||
return fig
|
||||
|
||||
# Initialize Browser
|
||||
fig = backend._init_browser(**kwargs)
|
||||
_show_browser(show=show, block=block, fig=fig)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _check_browser_backend_name(backend_name):
|
||||
_validate_type(backend_name, str, "backend_name")
|
||||
backend_name = backend_name.lower()
|
||||
backend_name = "qt" if backend_name == "pyqtgraph" else backend_name
|
||||
_check_option("backend_name", backend_name, VALID_BROWSE_BACKENDS)
|
||||
return backend_name
|
||||
|
||||
|
||||
@verbose
|
||||
def set_browser_backend(backend_name, verbose=None):
|
||||
"""Set the 2D browser backend for MNE.
|
||||
|
||||
The backend will be set as specified and operations will use
|
||||
that backend.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend_name : str
|
||||
The 2D browser backend to select. See Notes for the capabilities
|
||||
of each backend (``'qt'``, ``'matplotlib'``). The ``'qt'`` browser
|
||||
requires `mne-qt-browser
|
||||
<https://github.com/mne-tools/mne-qt-browser>`__.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
old_backend_name : str | None
|
||||
The old backend that was in use.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This table shows the capabilities of each backend ("✓" for full support,
|
||||
and "-" for partial support):
|
||||
|
||||
.. table::
|
||||
:widths: auto
|
||||
|
||||
+--------------------------------------+------------+----+
|
||||
| **2D browser function:** | matplotlib | qt |
|
||||
+======================================+============+====+
|
||||
| :func:`plot_raw` | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| :func:`plot_epochs` | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| :func:`plot_ica_sources` | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
+--------------------------------------+------------+----+
|
||||
| **Feature:** |
|
||||
+--------------------------------------+------------+----+
|
||||
| Show Events | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Add/Edit/Remove Annotations | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Toggle Projections | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Butterfly Mode | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Selection Mode | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Smooth Scrolling | | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Overview-Bar (with Z-Score-Mode) | | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
|
||||
.. versionadded:: 0.24
|
||||
"""
|
||||
global MNE_BROWSER_BACKEND
|
||||
old_backend_name = MNE_BROWSER_BACKEND
|
||||
backend_name = _check_browser_backend_name(backend_name)
|
||||
if MNE_BROWSER_BACKEND != backend_name:
|
||||
_load_backend(backend_name)
|
||||
MNE_BROWSER_BACKEND = backend_name
|
||||
|
||||
return old_backend_name
|
||||
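# Minimal usage sketch (not part of this module); assumes an existing Raw
# object named ``raw``:
#
#     import mne
#     old = mne.viz.set_browser_backend("matplotlib")
#     raw.plot()
#     if old is not None:
#         mne.viz.set_browser_backend(old)  # restore the previous backend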
|
||||
|
||||
def _init_browser_backend():
|
||||
global MNE_BROWSER_BACKEND
|
||||
|
||||
# check if MNE_BROWSER_BACKEND is not None and valid or get it from config
|
||||
loaded_backend = MNE_BROWSER_BACKEND or get_config(
|
||||
key="MNE_BROWSER_BACKEND", default=None
|
||||
)
|
||||
if loaded_backend is not None:
|
||||
set_browser_backend(loaded_backend)
|
||||
return MNE_BROWSER_BACKEND
|
||||
else:
|
||||
errors = dict()
|
||||
# Try import of valid browser backends
|
||||
for name in VALID_BROWSE_BACKENDS:
|
||||
try:
|
||||
_load_backend(name)
|
||||
except ImportError as exc:
|
||||
errors[name] = str(exc)
|
||||
else:
|
||||
MNE_BROWSER_BACKEND = name
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Could not load any valid 2D backend:\n"
|
||||
+ "\n".join(f"{key}: {val}" for key, val in errors.items())
|
||||
)
|
||||
|
||||
return MNE_BROWSER_BACKEND
|
||||
|
||||
|
||||
def get_browser_backend():
|
||||
"""Return the 2D backend currently used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
backend_used : str | None
|
||||
The 2D browser backend currently in use. If no backend is found,
|
||||
returns ``None``.
|
||||
"""
|
||||
try:
|
||||
backend_name = _init_browser_backend()
|
||||
except RuntimeError as exc:
|
||||
backend_name = None
|
||||
logger.info(str(exc))
|
||||
return backend_name
|
||||
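# Sketch only: callers can branch on the active backend, e.g. to enable
# features that only the Qt browser provides.
#
#     if get_browser_backend() == "qt":
#         ...  # overview bar and smooth scrolling are available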
|
||||
|
||||
@contextmanager
|
||||
def use_browser_backend(backend_name):
|
||||
"""Create a 2D browser visualization context using the designated backend.
|
||||
|
||||
See :func:`mne.viz.set_browser_backend` for more details on the available
|
||||
2D browser backends and their capabilities.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend_name : {'qt', 'matplotlib'}
|
||||
The 2D browser backend to use in the context.
|
||||
"""
|
||||
old_backend = set_browser_backend(backend_name)
|
||||
try:
|
||||
yield backend
|
||||
finally:
|
||||
if old_backend is not None:
|
||||
try:
|
||||
set_browser_backend(old_backend)
|
||||
except Exception:
|
||||
pass
|
||||
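# Usage sketch: temporarily switch backends for a single plot; ``raw`` is an
# assumed, pre-loaded mne.io.Raw instance.
#
#     from mne.viz import use_browser_backend
#     with use_browser_backend("matplotlib"):
#         raw.plot(block=True)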
2530
mne/viz/_mpl_figure.py
Normal file
File diff suppressed because it is too large
252
mne/viz/_proj.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Functions for plotting projectors."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_to_idx
|
||||
from ..defaults import DEFAULTS
|
||||
from ..utils import _pl, _validate_type, verbose, warn
|
||||
from .evoked import _plot_evoked
|
||||
from .topomap import _plot_projs_topomap
|
||||
from .utils import _check_type_projs, plt_show
|
||||
|
||||
|
||||
@verbose
|
||||
def plot_projs_joint(
|
||||
projs, evoked, picks_trace=None, *, topomap_kwargs=None, show=True, verbose=None
|
||||
):
|
||||
"""Plot projectors and evoked jointly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
projs : list of Projection
|
||||
The projectors to plot.
|
||||
evoked : instance of Evoked
|
||||
The data to plot. Typically this is the evoked instance created from
|
||||
averaging the epochs used to create the projection.
|
||||
%(picks_plot_projs_joint_trace)s
|
||||
topomap_kwargs : dict | None
|
||||
Keyword arguments to pass to :func:`mne.viz.plot_projs_topomap`.
|
||||
%(show)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib Figure
|
||||
The figure.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function creates a figure with three columns:
|
||||
|
||||
1. The left shows the evoked data traces before (black) and after (green)
|
||||
projection.
|
||||
2. The center shows the topomaps associated with each of the projectors.
|
||||
3. The right again shows the data traces (black), but this time with:
|
||||
|
||||
1. The data projected onto each projector with a single normalization
|
||||
factor (solid lines). This is useful for seeing the relative power
|
||||
in each projection vector.
|
||||
2. The data projected onto each projector with individual normalization
|
||||
factors (dashed lines). This is useful for visualizing each time
|
||||
course regardless of its power.
|
||||
3. Additional data traces from ``picks_trace`` (solid yellow lines).
|
||||
This is useful for visualizing the "ground truth" of the time
|
||||
course, e.g. the measured EOG or ECG channel time courses.
|
||||
|
||||
.. versionadded:: 1.1
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..evoked import Evoked
|
||||
|
||||
_validate_type(evoked, Evoked, "evoked")
|
||||
_validate_type(topomap_kwargs, (None, dict), "topomap_kwargs")
|
||||
projs = _check_type_projs(projs)
|
||||
topomap_kwargs = dict() if topomap_kwargs is None else topomap_kwargs
|
||||
if picks_trace is not None:
|
||||
picks_trace = _picks_to_idx(evoked.info, picks_trace, allow_empty=False)
|
||||
info = evoked.info
|
||||
ch_types = evoked.get_channel_types(unique=True, only_data_chs=True)
|
||||
proj_by_type = dict() # will be set up like an enumerate key->[pi, proj]
|
||||
ch_names_by_type = dict()
|
||||
used = np.zeros(len(projs), int)
|
||||
for ch_type in ch_types:
|
||||
these_picks = _picks_to_idx(info, ch_type, allow_empty=True)
|
||||
these_chs = [evoked.ch_names[pick] for pick in these_picks]
|
||||
ch_names_by_type[ch_type] = these_chs
|
||||
for pi, proj in enumerate(projs):
|
||||
if not set(these_chs).intersection(proj["data"]["col_names"]):
|
||||
continue
|
||||
if ch_type not in proj_by_type:
|
||||
proj_by_type[ch_type] = list()
|
||||
proj_by_type[ch_type].append([pi, deepcopy(proj)])
|
||||
used[pi] += 1
|
||||
missing = (~used.astype(bool)).sum()
|
||||
if missing:
|
||||
warn(
|
||||
f"{missing} projector{_pl(missing)} had no channel names "
|
||||
"present in epochs"
|
||||
)
|
||||
del projs
|
||||
ch_types = list(proj_by_type) # reduce to number we actually need
|
||||
# room for legend
|
||||
max_proj_per_type = max(len(x) for x in proj_by_type.values())
|
||||
cs_trace = 3
|
||||
cs_topo = 2
|
||||
n_col = max_proj_per_type * cs_topo + 2 * cs_trace
|
||||
n_row = len(ch_types)
|
||||
shape = (n_row, n_col)
|
||||
fig = plt.figure(
|
||||
figsize=(n_col * 1.1 + 0.5, n_row * 1.8 + 0.5), layout="constrained"
|
||||
)
|
||||
ri = 0
|
||||
# pick some sufficiently distinct colors (6 per proj type, e.g., ECG,
|
||||
# should be enough hopefully!)
|
||||
# https://personal.sron.nl/~pault/data/colourschemes.pdf
|
||||
# "Vibrant" color scheme
|
||||
proj_colors = [
|
||||
"#CC3311", # red
|
||||
"#009988", # teal
|
||||
"#0077BB", # blue
|
||||
"#EE3377", # magenta
|
||||
"#EE7733", # orange
|
||||
"#33BBEE", # cyan
|
||||
]
|
||||
trace_color = "#CCBB44" # yellow
|
||||
after_color, after_name = "#228833", "green"
|
||||
type_titles = DEFAULTS["titles"]
|
||||
last_ax = [None] * 2
|
||||
first_ax = dict()
|
||||
pe_kwargs = dict(show=False, draw=False)
|
||||
for ch_type, these_projs in proj_by_type.items():
|
||||
these_idxs, these_projs = zip(*these_projs)
|
||||
ch_names = ch_names_by_type[ch_type]
|
||||
idx = np.where(
|
||||
[np.isin(ch_names, proj["data"]["col_names"]).all() for proj in these_projs]
|
||||
)[0]
|
||||
used[idx] += 1
|
||||
count = len(these_projs)
|
||||
for proj in these_projs:
|
||||
sub_idx = [proj["data"]["col_names"].index(name) for name in ch_names]
|
||||
proj["data"]["data"] = proj["data"]["data"][:, sub_idx]
|
||||
proj["data"]["col_names"] = ch_names
|
||||
ba_ax = plt.subplot2grid(shape, (ri, 0), colspan=cs_trace, fig=fig)
|
||||
topo_axes = [
|
||||
plt.subplot2grid(
|
||||
shape, (ri, ci * cs_topo + cs_trace), colspan=cs_topo, fig=fig
|
||||
)
|
||||
for ci in range(count)
|
||||
]
|
||||
tr_ax = plt.subplot2grid(
|
||||
shape, (ri, n_col - cs_trace), colspan=cs_trace, fig=fig
|
||||
)
|
||||
# topomaps
|
||||
_plot_projs_topomap(these_projs, info=info, axes=topo_axes, **topomap_kwargs)
|
||||
for idx, proj, ax_ in zip(these_idxs, these_projs, topo_axes):
|
||||
ax_.set_title("") # could use proj['desc'] but it's long
|
||||
ax_.set_xlabel(f"projs[{idx}]", fontsize="small")
|
||||
unit = DEFAULTS["units"][ch_type]
|
||||
# traces
|
||||
this_evoked = evoked.copy().pick(ch_names)
|
||||
p = np.concatenate([p["data"]["data"] for p in these_projs])
|
||||
assert p.shape == (len(these_projs), len(this_evoked.data))
|
||||
traces = np.dot(p, this_evoked.data)
|
||||
traces *= np.sign(np.mean(np.dot(this_evoked.data, traces.T), 0))[:, np.newaxis]
|
||||
if picks_trace is not None:
|
||||
ch_traces = evoked.data[picks_trace]
|
||||
ch_traces -= np.mean(ch_traces, axis=1, keepdims=True)
|
||||
ch_traces /= np.abs(ch_traces).max()
|
||||
_plot_evoked(
|
||||
this_evoked, picks="all", axes=[tr_ax], **pe_kwargs, spatial_colors=False
|
||||
)
|
||||
for line in tr_ax.lines:
|
||||
line.set(lw=0.5, zorder=3)
|
||||
for t in list(tr_ax.texts):
|
||||
t.remove()
|
||||
scale = 0.8 * np.abs(tr_ax.get_ylim()).max()
|
||||
hs, labels = list(), list()
|
||||
traces /= np.abs(traces).max() # uniformly scaled
|
||||
for ti, trace in enumerate(traces):
|
||||
hs.append(
|
||||
tr_ax.plot(
|
||||
this_evoked.times,
|
||||
trace * scale,
|
||||
color=proj_colors[ti % len(proj_colors)],
|
||||
zorder=5,
|
||||
)[0]
|
||||
)
|
||||
labels.append(f"projs[{these_idxs[ti]}]")
|
||||
traces /= np.abs(traces).max(1, keepdims=True) # independently
|
||||
for ti, trace in enumerate(traces):
|
||||
tr_ax.plot(
|
||||
this_evoked.times,
|
||||
trace * scale,
|
||||
color=proj_colors[ti % len(proj_colors)],
|
||||
zorder=3.5,
|
||||
ls="--",
|
||||
lw=1.0,
|
||||
alpha=0.75,
|
||||
)
|
||||
if picks_trace is not None:
|
||||
trace_ch = [evoked.ch_names[pick] for pick in picks_trace]
|
||||
if len(picks_trace) == 1:
|
||||
trace_ch = trace_ch[0]
|
||||
hs.append(
|
||||
tr_ax.plot(
|
||||
this_evoked.times,
|
||||
ch_traces.T * scale,
|
||||
color=trace_color,
|
||||
lw=3,
|
||||
zorder=4,
|
||||
alpha=0.75,
|
||||
)[0]
|
||||
)
|
||||
labels.append(str(trace_ch))
|
||||
tr_ax.set(title="", xlabel="", ylabel="")
|
||||
# This will steal space from the subplots in a constrained layout
|
||||
# https://matplotlib.org/3.5.0/tutorials/intermediate/constrainedlayout_guide.html#legends # noqa: E501
|
||||
tr_ax.legend(
|
||||
hs,
|
||||
labels,
|
||||
loc="center left",
|
||||
borderaxespad=0.05,
|
||||
bbox_to_anchor=[1.05, 0.5],
|
||||
)
|
||||
last_ax[1] = tr_ax
|
||||
key = "Projected time course"
|
||||
if key not in first_ax:
|
||||
first_ax[key] = tr_ax
|
||||
# Before and after traces
|
||||
_plot_evoked(this_evoked, picks="all", axes=[ba_ax], **pe_kwargs)
|
||||
for line in ba_ax.lines:
|
||||
line.set(lw=0.5, zorder=3)
|
||||
loff = len(ba_ax.lines)
|
||||
this_proj_evoked = this_evoked.copy().add_proj(these_projs)
|
||||
# with meg='combined' any existing mag projectors (those already part
|
||||
# of evoked before we add_proj above) will have greatly
|
||||
# reduced power, so we ignore the warning about this issue
|
||||
this_proj_evoked.apply_proj(verbose="error")
|
||||
_plot_evoked(this_proj_evoked, picks="all", axes=[ba_ax], **pe_kwargs)
|
||||
for line in ba_ax.lines[loff:]:
|
||||
line.set(lw=0.5, zorder=4, color=after_color)
|
||||
for t in list(ba_ax.texts):
|
||||
t.remove()
|
||||
ba_ax.set(title="", xlabel="")
|
||||
ba_ax.set(ylabel=f"{type_titles[ch_type]}\n{unit}")
|
||||
last_ax[0] = ba_ax
|
||||
key = f"Before (black) and after ({after_name})"
|
||||
if key not in first_ax:
|
||||
first_ax[key] = ba_ax
|
||||
ri += 1
|
||||
for ax in last_ax:
|
||||
ax.set(xlabel="Time (s)")
|
||||
for title, ax in first_ax.items():
|
||||
ax.set_title(title, fontsize="medium")
|
||||
plt_show(show)
|
||||
return fig
|
||||
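# Hedged usage sketch; ``raw`` (with an ECG channel) is an assumption:
#
#     import mne
#     projs, _ = mne.preprocessing.compute_proj_ecg(raw, n_grad=1, n_mag=1, n_eeg=1)
#     ecg_evoked = mne.preprocessing.create_ecg_epochs(raw).average("all")
#     fig = mne.viz.plot_projs_joint(projs, ecg_evoked, picks_trace="ecg")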
81
mne/viz/_scraper.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ..utils import _pl
|
||||
from .backends._utils import _pixmap_to_ndarray
|
||||
|
||||
|
||||
class _MNEQtBrowserScraper:
|
||||
def __repr__(self):
|
||||
return "<MNEQtBrowserScraper>"
|
||||
|
||||
def __call__(self, block, block_vars, gallery_conf):
|
||||
import mne_qt_browser
|
||||
from sphinx_gallery.scrapers import figure_rst
|
||||
|
||||
if gallery_conf["builder_name"] != "html":
|
||||
return ""
|
||||
img_fnames = list()
|
||||
inst = None
|
||||
n_plot = 0
|
||||
for gui in list(mne_qt_browser._browser_instances):
|
||||
try:
|
||||
scraped = getattr(gui, "_scraped", False)
|
||||
except Exception: # super __init__ not called, perhaps stale?
|
||||
scraped = True
|
||||
if scraped:
|
||||
continue
|
||||
gui._scraped = True # monkey-patch but it's easy enough
|
||||
n_plot += 1
|
||||
img_fnames.append(next(block_vars["image_path_iterator"]))
|
||||
pixmap, inst = _mne_qt_browser_screenshot(gui, inst)
|
||||
pixmap.save(img_fnames[-1])
|
||||
# child figures
|
||||
for fig in gui.mne.child_figs:
|
||||
# For now we only support Selection
|
||||
if not hasattr(fig, "channel_fig"):
|
||||
continue
|
||||
fig = fig.channel_fig
|
||||
img_fnames.append(next(block_vars["image_path_iterator"]))
|
||||
fig.savefig(img_fnames[-1])
|
||||
gui.close()
|
||||
del gui, pixmap
|
||||
if not len(img_fnames):
|
||||
return ""
|
||||
for _ in range(2):
|
||||
inst.processEvents()
|
||||
return figure_rst(img_fnames, gallery_conf["src_dir"], f"Raw plot{_pl(n_plot)}")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _screenshot_mode(browser):
|
||||
if need_zen := browser.mne.scrollbars_visible:
|
||||
browser._toggle_zenmode()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if need_zen:
|
||||
browser._toggle_zenmode()
|
||||
|
||||
|
||||
def _mne_qt_browser_screenshot(browser, inst=None, return_type="pixmap"):
|
||||
from mne_qt_browser._pg_figure import QApplication
|
||||
|
||||
if getattr(browser, "load_thread", None) is not None:
|
||||
if browser.load_thread.isRunning():
|
||||
browser.load_thread.wait(30000)
|
||||
if inst is None:
|
||||
inst = QApplication.instance()
|
||||
# processEvents to make sure our progressBar is updated
|
||||
with _screenshot_mode(browser):
|
||||
for _ in range(2):
|
||||
inst.processEvents()
|
||||
pixmap = browser.grab()
|
||||
assert return_type in ("pixmap", "ndarray")
|
||||
if return_type == "ndarray":
|
||||
return _pixmap_to_ndarray(pixmap)
|
||||
else:
|
||||
return pixmap, inst
|
||||
7
mne/viz/backends/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
"""Visualization backend."""
|
||||
|
||||
from . import renderer
|
||||
1500
mne/viz/backends/_abstract.py
Normal file
File diff suppressed because it is too large
1619
mne/viz/backends/_notebook.py
Normal file
File diff suppressed because it is too large
1358
mne/viz/backends/_pyvista.py
Normal file
File diff suppressed because it is too large
1852
mne/viz/backends/_qt.py
Normal file
File diff suppressed because it is too large
421
mne/viz/backends/_utils.py
Normal file
@@ -0,0 +1,421 @@
|
||||
#
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import collections.abc
|
||||
import functools
|
||||
import os
|
||||
import platform
|
||||
import signal
|
||||
import sys
|
||||
from colorsys import rgb_to_hls
|
||||
from contextlib import contextmanager
|
||||
from ctypes import c_char_p, c_void_p, cdll
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...fixes import _compare_version
|
||||
from ...utils import _check_qt_version, _validate_type, logger, warn
|
||||
from ..utils import _get_cmap
|
||||
|
||||
VALID_BROWSE_BACKENDS = (
|
||||
"qt",
|
||||
"matplotlib",
|
||||
)
|
||||
|
||||
VALID_3D_BACKENDS = (
|
||||
"pyvistaqt", # default 3d backend
|
||||
"notebook",
|
||||
)
|
||||
ALLOWED_QUIVER_MODES = ("2darrow", "arrow", "cone", "cylinder", "sphere", "oct")
|
||||
_ICONS_PATH = Path(__file__).parents[2] / "icons"
|
||||
|
||||
|
||||
def _get_colormap_from_array(
|
||||
colormap=None, normalized_colormap=False, default_colormap="coolwarm"
|
||||
):
|
||||
from matplotlib.colors import ListedColormap
|
||||
|
||||
if colormap is None:
|
||||
cmap = _get_cmap(default_colormap)
|
||||
elif isinstance(colormap, str):
|
||||
cmap = _get_cmap(colormap)
|
||||
elif normalized_colormap:
|
||||
cmap = ListedColormap(colormap)
|
||||
else:
|
||||
cmap = ListedColormap(np.array(colormap) / 255.0)
|
||||
return cmap
|
||||
|
||||
|
||||
def _check_color(color):
|
||||
from matplotlib.colors import colorConverter
|
||||
|
||||
if isinstance(color, str):
|
||||
color = colorConverter.to_rgb(color)
|
||||
elif isinstance(color, collections.abc.Iterable):
|
||||
np_color = np.array(color)
|
||||
if np_color.size % 3 != 0 and np_color.size % 4 != 0:
|
||||
raise ValueError("The expected valid format is RGB or RGBA.")
|
||||
if np_color.dtype in (np.int64, np.int32):
|
||||
if (np_color < 0).any() or (np_color > 255).any():
|
||||
raise ValueError("Values out of range [0, 255].")
|
||||
elif np_color.dtype == np.float64:
|
||||
if (np_color < 0.0).any() or (np_color > 1.0).any():
|
||||
raise ValueError("Values out of range [0.0, 1.0].")
|
||||
else:
|
||||
raise TypeError(
|
||||
"Expected data type is `np.int64`, `np.int32`, or `np.float64` but "
|
||||
f"{np_color.dtype} was given."
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected type is `str` or iterable but {type(color)} was given."
|
||||
)
|
||||
return color
|
||||
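# Illustrative behavior (sketch, not exhaustive): _check_color("red") returns
# (1.0, 0.0, 0.0); _check_color([0, 128, 255]) passes the integer range check
# unchanged; _check_color([0.5, 2.0, 0.1]) raises ValueError because 2.0 is
# outside [0.0, 1.0].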
|
||||
|
||||
def _alpha_blend_background(ctable, background_color):
|
||||
alphas = ctable[:, -1][:, np.newaxis] / 255.0
|
||||
use_table = ctable.copy()
|
||||
use_table[:, -1] = 255.0
|
||||
return (use_table * alphas) + background_color * (1 - alphas)
|
||||
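# Worked example for the blend above (made-up numbers): a half-transparent red
# row [255, 0, 0, 128] over a white background [255, 255, 255, 255] yields
# roughly [255, 127, 127, 255], i.e. value * alpha + background * (1 - alpha)
# per channel with alpha = 128 / 255.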
|
||||
|
||||
@functools.lru_cache(1)
|
||||
def _qt_init_icons():
|
||||
from qtpy.QtGui import QIcon
|
||||
|
||||
QIcon.setThemeSearchPaths([str(_ICONS_PATH)] + QIcon.themeSearchPaths())
|
||||
QIcon.setFallbackThemeName("light")
|
||||
return str(_ICONS_PATH)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _qt_disable_paint(widget):
|
||||
paintEvent = widget.paintEvent
|
||||
widget.paintEvent = lambda *args, **kwargs: None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
widget.paintEvent = paintEvent
|
||||
|
||||
|
||||
_QT_ICON_KEYS = dict(app=None)
|
||||
|
||||
|
||||
def _init_mne_qtapp(enable_icon=True, pg_app=False, splash=False):
|
||||
"""Get QApplication-instance for MNE-Python.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
enable_icon : bool
|
||||
Whether to set an MNE icon for the app.
|
||||
pg_app : bool
|
||||
Whether to create the QApplication with pyqtgraph. For an as yet
|
||||
unexplained reason, the pyqtgraph browser won't show without
|
||||
mkQApp from pyqtgraph.
|
||||
splash : bool | str
|
||||
If not False, display a splash screen. If str, set the message
|
||||
to the given string.
|
||||
|
||||
Returns
|
||||
-------
|
||||
app : ``qtpy.QtWidgets.QApplication``
|
||||
Instance of QApplication.
|
||||
splash : ``qtpy.QtWidgets.QSplashScreen``
|
||||
Instance of QSplashScreen. Only returned if splash is True or a
|
||||
string.
|
||||
"""
|
||||
from qtpy.QtCore import Qt
|
||||
from qtpy.QtGui import QGuiApplication, QIcon, QPixmap
|
||||
from qtpy.QtWidgets import QApplication, QSplashScreen
|
||||
|
||||
app_name = "MNE-Python"
|
||||
organization_name = "MNE"
|
||||
|
||||
# Fix from cbrnr/mnelab for app name in menu bar
|
||||
# This has to come *before* the creation of the QApplication to work.
|
||||
# It also only affects the title bar, not the application dock.
|
||||
# There seems to be no way to change the application dock from "python"
|
||||
# at runtime.
|
||||
if sys.platform.startswith("darwin"):
|
||||
try:
|
||||
# set bundle name on macOS (app name shown in the menu bar)
|
||||
from Foundation import NSBundle
|
||||
|
||||
bundle = NSBundle.mainBundle()
|
||||
info = bundle.localizedInfoDictionary() or bundle.infoDictionary()
|
||||
if "CFBundleName" not in info:
|
||||
info["CFBundleName"] = app_name
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
# First we need to check to make sure the display is valid, otherwise
|
||||
# Qt might segfault on us
|
||||
app = QApplication.instance()
|
||||
if not (app or _display_is_valid()):
|
||||
raise RuntimeError("Cannot connect to a valid display")
|
||||
|
||||
if pg_app:
|
||||
from pyqtgraph import mkQApp
|
||||
|
||||
old_argv = sys.argv
|
||||
try:
|
||||
sys.argv = []
|
||||
app = mkQApp(app_name)
|
||||
finally:
|
||||
sys.argv = old_argv
|
||||
elif not app:
|
||||
app = QApplication([app_name])
|
||||
app.setApplicationName(app_name)
|
||||
app.setOrganizationName(organization_name)
|
||||
qt_version = _check_qt_version(check_usable_display=False)
|
||||
# HiDPI is enabled by default in Qt6, requires to be explicitly set for Qt5
|
||||
if _compare_version(qt_version, "<", "6.0"):
|
||||
app.setAttribute(Qt.AA_UseHighDpiPixmaps)
|
||||
|
||||
if enable_icon or splash:
|
||||
icons_path = _qt_init_icons()
|
||||
|
||||
if (
|
||||
enable_icon
|
||||
and app.windowIcon().cacheKey() != _QT_ICON_KEYS["app"]
|
||||
and app.windowIcon().isNull() # don't overwrite existing icon (e.g. MNELAB)
|
||||
):
|
||||
# Set icon
|
||||
kind = "bigsur_" if platform.mac_ver()[0] >= "10.16" else "default_"
|
||||
icon = QIcon(f"{icons_path}/mne_{kind}icon.png")
|
||||
app.setWindowIcon(icon)
|
||||
_QT_ICON_KEYS["app"] = app.windowIcon().cacheKey()
|
||||
|
||||
out = app
|
||||
if splash:
|
||||
pixmap = QPixmap(f"{icons_path}/mne_splash.png")
|
||||
pixmap.setDevicePixelRatio(QGuiApplication.primaryScreen().devicePixelRatio())
|
||||
args = (pixmap,)
|
||||
if _should_raise_window():
|
||||
args += (Qt.WindowStaysOnTopHint,)
|
||||
qsplash = QSplashScreen(*args)
|
||||
qsplash.setAttribute(Qt.WA_ShowWithoutActivating, True)
|
||||
if isinstance(splash, str):
|
||||
alignment = int(Qt.AlignBottom | Qt.AlignHCenter)
|
||||
qsplash.showMessage(splash, alignment=alignment, color=Qt.white)
|
||||
qsplash.show()
|
||||
app.processEvents()
|
||||
out = (out, qsplash)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _display_is_valid():
|
||||
# Adapted from matplotlib _c_internal_utils.py
|
||||
if sys.platform != "linux":
|
||||
return True
|
||||
if os.getenv("DISPLAY"): # if it's not there, don't bother
|
||||
libX11 = cdll.LoadLibrary("libX11.so.6")
|
||||
libX11.XOpenDisplay.restype = c_void_p
|
||||
libX11.XOpenDisplay.argtypes = [c_char_p]
|
||||
display = libX11.XOpenDisplay(None)
|
||||
if display is not None:
|
||||
libX11.XCloseDisplay.argtypes = [c_void_p]
|
||||
libX11.XCloseDisplay(display)
|
||||
return True
|
||||
# not found, try Wayland
|
||||
if os.getenv("WAYLAND_DISPLAY"):
|
||||
libwayland = cdll.LoadLibrary("libwayland-client.so.0")
|
||||
if libwayland is not None:
|
||||
if all(
|
||||
hasattr(libwayland, f"wl_display_{kind}connect") for kind in ("", "dis")
|
||||
):
|
||||
libwayland.wl_display_connect.restype = c_void_p
|
||||
libwayland.wl_display_connect.argtypes = [c_char_p]
|
||||
display = libwayland.wl_display_connect(None)
|
||||
if display:
|
||||
libwayland.wl_display_disconnect.argtypes = [c_void_p]
|
||||
libwayland.wl_display_disconnect(display)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/5160577/ctrl-c-doesnt-work-with-pyqt
|
||||
def _qt_app_exec(app):
|
||||
# adapted from matplotlib
|
||||
old_signal = signal.getsignal(signal.SIGINT)
|
||||
is_python_signal_handler = old_signal is not None
|
||||
if is_python_signal_handler:
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
try:
|
||||
# Make IPython Console accessible again in Spyder
|
||||
app.lastWindowClosed.connect(app.quit)
|
||||
app.exec_()
|
||||
finally:
|
||||
# reset the SIGINT exception handler
|
||||
if is_python_signal_handler:
|
||||
signal.signal(signal.SIGINT, old_signal)
|
||||
|
||||
|
||||
def _qt_detect_theme():
|
||||
try:
|
||||
import darkdetect
|
||||
|
||||
theme = darkdetect.theme().lower()
|
||||
except ModuleNotFoundError:
|
||||
logger.info(
|
||||
'For automatic theme detection, "darkdetect" has to'
|
||||
" be installed! You can install it with "
|
||||
"`pip install darkdetect`"
|
||||
)
|
||||
theme = "light"
|
||||
except Exception:
|
||||
theme = "light"
|
||||
return theme
|
||||
|
||||
|
||||
def _qt_get_stylesheet(theme):
|
||||
_validate_type(theme, ("path-like",), "theme")
|
||||
theme = str(theme)
|
||||
stylesheet = "" # no stylesheet
|
||||
if theme in ("auto", "dark", "light"):
|
||||
if theme == "auto":
|
||||
return stylesheet
|
||||
assert theme in ("dark", "light")
|
||||
system_theme = _qt_detect_theme()
|
||||
if theme == system_theme:
|
||||
return stylesheet
|
||||
_, api = _check_qt_version(return_api=True)
|
||||
# On macOS or Qt 6, we shouldn't need to set anything when the requested
|
||||
# theme matches that of the current OS state
|
||||
try:
|
||||
import qdarkstyle
|
||||
except ModuleNotFoundError:
|
||||
logger.info(
|
||||
f'To use {theme} mode when in {system_theme} mode, "qdarkstyle" has'
|
||||
"to be installed! You can install it with:\n"
|
||||
"pip install qdarkstyle\n"
|
||||
)
|
||||
else:
|
||||
if api in ("PySide6", "PyQt6") and _compare_version(
|
||||
qdarkstyle.__version__, "<", "3.2.3"
|
||||
):
|
||||
warn(
|
||||
f"Setting theme={repr(theme)} is not supported for {api} in "
|
||||
f"qdarkstyle {qdarkstyle.__version__}, it will be ignored. "
|
||||
"Consider upgrading qdarkstyle to >=3.2.3."
|
||||
)
|
||||
else:
|
||||
stylesheet = qdarkstyle.load_stylesheet(
|
||||
getattr(
|
||||
getattr(qdarkstyle, theme).palette,
|
||||
f"{theme.capitalize()}Palette",
|
||||
)
|
||||
)
|
||||
return stylesheet
|
||||
else:
|
||||
try:
|
||||
file = open(theme)
|
||||
except OSError:
|
||||
warn(
|
||||
"Requested theme file not found, will use light instead: "
|
||||
f"{repr(theme)}"
|
||||
)
|
||||
else:
|
||||
with file as fid:
|
||||
stylesheet = fid.read()
|
||||
return stylesheet
|
||||
|
||||
|
||||
def _should_raise_window():
|
||||
from matplotlib import rcParams
|
||||
|
||||
return rcParams["figure.raise_window"]
|
||||
|
||||
|
||||
def _qt_raise_window(widget):
|
||||
# Set raise_window like matplotlib if possible
|
||||
if _should_raise_window():
|
||||
widget.activateWindow()
|
||||
widget.raise_()
|
||||
|
||||
|
||||
def _qt_is_dark(widget):
|
||||
# Ideally this would use CIELab, but this should be good enough
|
||||
win = widget.window()
|
||||
bgcolor = win.palette().color(win.backgroundRole()).getRgbF()[:3]
|
||||
return rgb_to_hls(*bgcolor)[1] < 0.5
|
||||
|
||||
|
||||
def _pixmap_to_ndarray(pixmap):
|
||||
from qtpy.QtGui import QImage
|
||||
|
||||
img = pixmap.toImage()
|
||||
img = img.convertToFormat(QImage.Format.Format_RGBA8888)
|
||||
ptr = img.bits()
|
||||
count = img.height() * img.width() * 4
|
||||
if hasattr(ptr, "setsize"): # PyQt
|
||||
ptr.setsize(count)
|
||||
data = np.frombuffer(ptr, dtype=np.uint8, count=count).copy()
|
||||
data.shape = (img.height(), img.width(), 4)
|
||||
return data / 255.0
|
||||
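# Sketch: the returned array has shape (height, width, 4) with floats in
# [0, 1], so it can be shown directly with matplotlib (``pixmap`` assumed):
#
#     import matplotlib.pyplot as plt
#     plt.imshow(_pixmap_to_ndarray(pixmap))
#     plt.show()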
|
||||
|
||||
def _notebook_vtk_works():
|
||||
if sys.platform != "linux":
|
||||
return True
|
||||
# check if it's OSMesa -- if it is, continue
|
||||
try:
|
||||
from vtkmodules import vtkRenderingOpenGL2
|
||||
|
||||
vtkRenderingOpenGL2.vtkOSOpenGLRenderWindow
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
return True # has vtkOSOpenGLRenderWindow (OSMesa build)
|
||||
|
||||
# if it's not OSMesa, we need to check display validity
|
||||
if _display_is_valid():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _qt_safe_window(
|
||||
*, splash="figure.splash", window="figure.plotter.app_window", always_close=True
|
||||
):
|
||||
def dec(meth, splash=splash, always_close=always_close):
|
||||
@functools.wraps(meth)
|
||||
def func(self, *args, **kwargs):
|
||||
close_splash = always_close
|
||||
error = False
|
||||
try:
|
||||
meth(self, *args, **kwargs)
|
||||
except Exception:
|
||||
close_splash = error = True
|
||||
raise
|
||||
finally:
|
||||
for attr, do_close in ((splash, close_splash), (window, error)):
|
||||
if attr is None or not do_close:
|
||||
continue
|
||||
parent = self
|
||||
name = attr.split(".")[-1]
|
||||
try:
|
||||
for n in attr.split(".")[:-1]:
|
||||
parent = getattr(parent, n)
|
||||
if name:
|
||||
widget = getattr(parent, name, False)
|
||||
else: # empty string means "self"
|
||||
widget = parent
|
||||
if widget:
|
||||
widget.close()
|
||||
del widget
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
delattr(parent, name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return func
|
||||
|
||||
return dec
|
||||
587
mne/viz/backends/renderer.py
Normal file
@@ -0,0 +1,587 @@
|
||||
"""Core visualization operations."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import importlib
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...utils import (
|
||||
_auto_weakref,
|
||||
_check_option,
|
||||
_validate_type,
|
||||
fill_doc,
|
||||
get_config,
|
||||
logger,
|
||||
verbose,
|
||||
)
|
||||
from .._3d import _get_3d_option
|
||||
from ..utils import safe_event
|
||||
from ._utils import VALID_3D_BACKENDS
|
||||
|
||||
MNE_3D_BACKEND = None
|
||||
MNE_3D_BACKEND_TESTING = False
|
||||
|
||||
|
||||
_backend_name_map = dict(
|
||||
pyvistaqt="._qt",
|
||||
notebook="._notebook",
|
||||
)
|
||||
backend = None
|
||||
|
||||
|
||||
def _reload_backend(backend_name):
|
||||
global backend
|
||||
backend = importlib.import_module(
|
||||
name=_backend_name_map[backend_name], package="mne.viz.backends"
|
||||
)
|
||||
logger.info(f"Using {backend_name} 3d backend.")
|
||||
|
||||
|
||||
def _get_backend():
|
||||
_get_3d_backend()
|
||||
return backend
|
||||
|
||||
|
||||
def _get_renderer(*args, **kwargs):
|
||||
_get_3d_backend()
|
||||
return backend._Renderer(*args, **kwargs)
|
||||
|
||||
|
||||
def _check_3d_backend_name(backend_name):
|
||||
_validate_type(backend_name, str, "backend_name")
|
||||
backend_name = "pyvistaqt" if backend_name == "pyvista" else backend_name
|
||||
_check_option("backend_name", backend_name, VALID_3D_BACKENDS)
|
||||
return backend_name
|
||||
|
||||
|
||||
@verbose
|
||||
def set_3d_backend(backend_name, verbose=None):
|
||||
"""Set the 3D backend for MNE.
|
||||
|
||||
The backend will be set as specified and operations will use
|
||||
that backend.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend_name : str
|
||||
The 3d backend to select. See Notes for the capabilities of each
|
||||
backend (``'pyvistaqt'`` and ``'notebook'``).
|
||||
|
||||
.. versionchanged:: 0.24
|
||||
The ``'pyvista'`` backend was renamed ``'pyvistaqt'``.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
old_backend_name : str | None
|
||||
The old backend that was in use.
|
||||
|
||||
Notes
|
||||
-----
|
||||
To use PyVista, set ``backend_name`` to ``pyvistaqt`` but the value
|
||||
``pyvista`` is still supported for backward compatibility.
|
||||
|
||||
This table shows the capabilities of each backend ("✓" for full support,
|
||||
and "-" for partial support):
|
||||
|
||||
.. table::
|
||||
:widths: auto
|
||||
|
||||
+--------------------------------------+-----------+----------+
|
||||
| **3D function:** | pyvistaqt | notebook |
|
||||
+======================================+===========+==========+
|
||||
| :func:`plot_vector_source_estimates` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`plot_source_estimates` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`plot_alignment` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`plot_sparse_source_estimates` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`plot_evoked_field` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`snapshot_brain_montage` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`link_brains` | ✓ | |
|
||||
+--------------------------------------+-----------+----------+
|
||||
+--------------------------------------+-----------+----------+
|
||||
| **Feature:** |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Large data | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Opacity/transparency | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Support geometric glyph | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Smooth shading | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Subplotting | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Inline plot in Jupyter Notebook | | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Inline plot in JupyterLab | | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Inline plot in Google Colab | | |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Toolbar | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
"""
|
||||
global MNE_3D_BACKEND
|
||||
old_backend_name = MNE_3D_BACKEND
|
||||
backend_name = _check_3d_backend_name(backend_name)
|
||||
if MNE_3D_BACKEND != backend_name:
|
||||
_reload_backend(backend_name)
|
||||
MNE_3D_BACKEND = backend_name
|
||||
return old_backend_name
|
||||
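# Minimal usage sketch, assuming ``stc`` (a SourceEstimate) and
# ``subjects_dir`` already exist:
#
#     import mne
#     mne.viz.set_3d_backend("pyvistaqt")
#     brain = stc.plot(subjects_dir=subjects_dir, hemi="both")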
|
||||
|
||||
def get_3d_backend():
|
||||
"""Return the 3D backend currently used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
backend_used : str | None
|
||||
The 3d backend currently in use. If no backend is found,
|
||||
returns ``None``.
|
||||
|
||||
.. versionchanged:: 0.24
|
||||
The ``'pyvista'`` backend has been renamed ``'pyvistaqt'``, so
|
||||
``'pyvista'`` is no longer returned by this function.
|
||||
"""
|
||||
try:
|
||||
backend = _get_3d_backend()
|
||||
except RuntimeError as exc:
|
||||
backend = None
|
||||
logger.info(str(exc))
|
||||
return backend
|
||||
|
||||
|
||||
def _get_3d_backend():
|
||||
"""Load and return the current 3d backend."""
|
||||
global MNE_3D_BACKEND
|
||||
if MNE_3D_BACKEND is None:
|
||||
MNE_3D_BACKEND = get_config(key="MNE_3D_BACKEND", default=None)
|
||||
if MNE_3D_BACKEND is None: # try them in order
|
||||
errors = dict()
|
||||
for name in VALID_3D_BACKENDS:
|
||||
try:
|
||||
_reload_backend(name)
|
||||
except ImportError as exc:
|
||||
errors[name] = str(exc)
|
||||
else:
|
||||
MNE_3D_BACKEND = name
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Could not load any valid 3D backend\n"
|
||||
+ "\n".join(f"{key}: {val}" for key, val in errors.items())
|
||||
+ "\n".join(
|
||||
(
|
||||
"\n\n install pyvistaqt, using pip or conda:",
|
||||
"'pip install pyvistaqt'",
|
||||
"'conda install -c conda-forge pyvistaqt'",
|
||||
"\n or install ipywidgets, "
|
||||
+ "if using a notebook backend",
|
||||
"'pip install ipywidgets'",
|
||||
"'conda install -c conda-forge ipywidgets'",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
MNE_3D_BACKEND = _check_3d_backend_name(MNE_3D_BACKEND)
|
||||
_reload_backend(MNE_3D_BACKEND)
|
||||
MNE_3D_BACKEND = _check_3d_backend_name(MNE_3D_BACKEND)
|
||||
return MNE_3D_BACKEND
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_3d_backend(backend_name):
|
||||
"""Create a 3d visualization context using the designated backend.
|
||||
|
||||
See :func:`mne.viz.set_3d_backend` for more details on the available
|
||||
3d backends and their capabilities.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend_name : {'pyvistaqt', 'notebook'}
|
||||
The 3d backend to use in the context.
|
||||
"""
|
||||
old_backend = set_3d_backend(backend_name)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if old_backend is not None:
|
||||
try:
|
||||
set_3d_backend(old_backend)
|
||||
except Exception:
|
||||
pass
|
||||
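# Usage sketch: render a single alignment plot without permanently changing
# the configured backend; ``info``, ``trans`` and ``subjects_dir`` are assumed.
#
#     import mne
#     with mne.viz.use_3d_backend("pyvistaqt"):
#         mne.viz.plot_alignment(info, trans=trans, subject="sample",
#                                subjects_dir=subjects_dir)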
|
||||
|
||||
@contextmanager
|
||||
def _use_test_3d_backend(backend_name, interactive=False):
|
||||
"""Create a testing viz context.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend_name : str
|
||||
The 3d backend to use in the context.
|
||||
interactive : bool
|
||||
If True, ensure interactive elements are accessible.
|
||||
"""
|
||||
with _actors_invisible():
|
||||
with use_3d_backend(backend_name):
|
||||
with backend._testing_context(interactive):
|
||||
yield
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _actors_invisible():
|
||||
global MNE_3D_BACKEND_TESTING
|
||||
orig_testing = MNE_3D_BACKEND_TESTING
|
||||
MNE_3D_BACKEND_TESTING = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
MNE_3D_BACKEND_TESTING = orig_testing
|
||||
|
||||
|
||||
@fill_doc
|
||||
def set_3d_view(
|
||||
figure,
|
||||
azimuth=None,
|
||||
elevation=None,
|
||||
focalpoint=None,
|
||||
distance=None,
|
||||
roll=None,
|
||||
):
|
||||
"""Configure the view of the given scene.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
figure : object
|
||||
The scene which is modified.
|
||||
%(azimuth)s
|
||||
%(elevation)s
|
||||
%(focalpoint)s
|
||||
%(distance)s
|
||||
%(roll)s
|
||||
"""
|
||||
backend._set_3d_view(
|
||||
figure=figure,
|
||||
azimuth=azimuth,
|
||||
elevation=elevation,
|
||||
focalpoint=focalpoint,
|
||||
distance=distance,
|
||||
roll=roll,
|
||||
)
|
||||
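# Sketch: orient a scene created by this backend to a lateral view (angles in
# degrees; ``fig`` is a Figure3D returned by create_3d_figure below):
#
#     set_3d_view(fig, azimuth=180, elevation=90, distance=0.6)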
|
||||
|
||||
@fill_doc
|
||||
def set_3d_title(figure, title, size=40, *, color="white", position="upper_left"):
|
||||
"""Configure the title of the given scene.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
figure : object
|
||||
The scene which is modified.
|
||||
title : str
|
||||
The title of the scene.
|
||||
size : int
|
||||
The size of the title.
|
||||
color : matplotlib color
|
||||
The color of the title.
|
||||
|
||||
.. versionadded:: 1.9
|
||||
position : str
|
||||
The position to use, e.g., "upper_left". See
|
||||
:meth:`pyvista.Plotter.add_text` for details.
|
||||
|
||||
.. versionadded:: 1.9
|
||||
|
||||
Returns
|
||||
-------
|
||||
text : object
|
||||
The text object returned by the given backend.
|
||||
|
||||
.. versionadded:: 1.0
|
||||
"""
|
||||
return backend._set_3d_title(
|
||||
figure=figure, title=title, size=size, color=color, position=position
|
||||
)
|
||||
|
||||
|
||||
def create_3d_figure(
|
||||
size,
|
||||
bgcolor=(0, 0, 0),
|
||||
smooth_shading=None,
|
||||
handle=None,
|
||||
*,
|
||||
scene=True,
|
||||
show=False,
|
||||
title="MNE 3D Figure",
|
||||
):
|
||||
"""Return an empty figure based on the current 3d backend.
|
||||
|
||||
.. warning:: Proceed with caution when the renderer object is
|
||||
returned (with ``scene=False``) because the _Renderer
|
||||
API is not necessarily stable enough for production;
|
||||
it's still actively in development.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
size : tuple
|
||||
The dimensions of the 3d figure (width, height).
|
||||
bgcolor : tuple
|
||||
The color of the background.
|
||||
smooth_shading : bool | None
|
||||
Whether to enable smooth shading. If ``None``, uses the config value
|
||||
``MNE_3D_OPTION_SMOOTH_SHADING``. Defaults to ``None``.
|
||||
handle : int | None
|
||||
The figure identifier.
|
||||
scene : bool
|
||||
If True (default), the returned object is the Figure3D. If False,
|
||||
an advanced, undocumented Renderer object is returned (the API is not
|
||||
stable or documented, so this is not recommended).
|
||||
show : bool
|
||||
If True, show the renderer immediately.
|
||||
|
||||
.. versionadded:: 1.0
|
||||
title : str
|
||||
The window title to use (if applicable).
|
||||
|
||||
.. versionadded:: 1.9
|
||||
|
||||
Returns
|
||||
-------
|
||||
figure : instance of Figure3D or ``Renderer``
|
||||
The requested empty figure or renderer, depending on ``scene``.
|
||||
"""
|
||||
_validate_type(smooth_shading, (bool, None), "smooth_shading")
|
||||
if smooth_shading is None:
|
||||
smooth_shading = _get_3d_option("smooth_shading")
|
||||
renderer = _get_renderer(
|
||||
fig=handle,
|
||||
size=size,
|
||||
bgcolor=bgcolor,
|
||||
smooth_shading=smooth_shading,
|
||||
show=show,
|
||||
name=title,
|
||||
)
|
||||
if scene:
|
||||
return renderer.scene()
|
||||
else:
|
||||
return renderer
|
||||
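# Usage sketch combining helpers from this module: make an empty white scene,
# title it, then close it again.
#
#     fig = create_3d_figure(size=(600, 400), bgcolor=(1.0, 1.0, 1.0), show=True)
#     set_3d_title(fig, "Empty scene", size=20)
#     close_3d_figure(fig)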
|
||||
|
||||
def close_3d_figure(figure):
|
||||
"""Close the given scene.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
figure : object
|
||||
The scene which needs to be closed.
|
||||
"""
|
||||
backend._close_3d_figure(figure)
|
||||
|
||||
|
||||
def close_all_3d_figures():
|
||||
"""Close all the scenes of the current 3d backend."""
|
||||
backend._close_all()
|
||||
|
||||
|
||||
def get_brain_class():
|
||||
"""Return the proper Brain class based on the current 3d backend.
|
||||
|
||||
Returns
|
||||
-------
|
||||
brain : object
|
||||
The Brain class corresponding to the current 3d backend.
|
||||
"""
|
||||
from ...viz._brain import Brain
|
||||
|
||||
return Brain
|
||||
|
||||
|
||||
class _TimeInteraction:
|
||||
"""Mixin enabling time interaction controls."""
|
||||
|
||||
def _enable_time_interaction(
|
||||
self,
|
||||
fig,
|
||||
current_time_func,
|
||||
times,
|
||||
init_playback_speed=0.01,
|
||||
playback_speed_range=(0.01, 0.1),
|
||||
):
|
||||
from ..ui_events import (
|
||||
PlaybackSpeed,
|
||||
TimeChange,
|
||||
publish,
|
||||
subscribe,
|
||||
)
|
||||
|
||||
self._fig = fig
|
||||
self._current_time_func = current_time_func
|
||||
self._times = times
|
||||
self._init_time = current_time_func()
|
||||
self._init_playback_speed = init_playback_speed
|
||||
|
||||
if not hasattr(self, "_dock"):
|
||||
self._dock_initialize()
|
||||
|
||||
if not hasattr(self, "_tool_bar") or self._tool_bar is None:
|
||||
self._tool_bar_initialize(name="Toolbar")
|
||||
|
||||
if not hasattr(self, "_widgets"):
|
||||
self._widgets = dict()
|
||||
|
||||
# Dock widgets
|
||||
@_auto_weakref
|
||||
def publish_time_change(time_index):
|
||||
publish(
|
||||
fig,
|
||||
TimeChange(time=np.interp(time_index, np.arange(len(times)), times)),
|
||||
)
|
||||
|
||||
layout = self._dock_add_group_box("")
|
||||
self._widgets["time_slider"] = self._dock_add_slider(
|
||||
name="Time (s)",
|
||||
value=np.interp(current_time_func(), times, np.arange(len(times))),
|
||||
rng=[0, len(times) - 1],
|
||||
double=True,
|
||||
callback=publish_time_change,
|
||||
compact=False,
|
||||
layout=layout,
|
||||
)
|
||||
hlayout = self._dock_add_layout(vertical=False)
|
||||
self._widgets["min_time"] = self._dock_add_label("-", layout=hlayout)
|
||||
self._dock_add_stretch(hlayout)
|
||||
self._widgets["current_time"] = self._dock_add_label(value="x", layout=hlayout)
|
||||
self._dock_add_stretch(hlayout)
|
||||
self._widgets["max_time"] = self._dock_add_label(value="+", layout=hlayout)
|
||||
self._layout_add_widget(layout, hlayout)
|
||||
|
||||
self._widgets["min_time"].set_value(f"{times[0]: .3f}")
|
||||
self._widgets["current_time"].set_value(f"{current_time_func(): .3f}")
|
||||
self._widgets["max_time"].set_value(f"{times[-1]: .3f}")
|
||||
|
||||
@_auto_weakref
|
||||
def publish_playback_speed(speed):
|
||||
publish(fig, PlaybackSpeed(speed=speed))
|
||||
|
||||
self._widgets["playback_speed"] = self._dock_add_spin_box(
|
||||
name="Speed",
|
||||
value=init_playback_speed,
|
||||
rng=playback_speed_range,
|
||||
callback=publish_playback_speed,
|
||||
layout=layout,
|
||||
)
|
||||
|
||||
# Tool bar buttons
|
||||
self._widgets["reset"] = self._tool_bar_add_button(
|
||||
name="reset", desc="Reset", func=self._reset_time
|
||||
)
|
||||
self._widgets["play"] = self._tool_bar_add_play_button(
|
||||
name="play",
|
||||
desc="Play/Pause",
|
||||
func=self._toggle_playback,
|
||||
shortcut=" ",
|
||||
)
|
||||
|
||||
# Configure playback
|
||||
self._playback = False
|
||||
self._playback_initialize(
|
||||
func=self._play,
|
||||
timeout=17,
|
||||
value=np.interp(current_time_func(), times, np.arange(len(times))),
|
||||
rng=[0, len(times) - 1],
|
||||
time_widget=self._widgets["time_slider"],
|
||||
play_widget=self._widgets["play"],
|
||||
)
|
||||
|
||||
# Keyboard shortcuts
|
||||
@_auto_weakref
|
||||
def shift_time(direction):
|
||||
amount = self._widgets["playback_speed"].get_value()
|
||||
publish(
|
||||
self._fig,
|
||||
TimeChange(time=self._current_time_func() + direction * amount),
|
||||
)
|
||||
|
||||
if self.plotter.iren is not None:
|
||||
self.plotter.add_key_event("n", partial(shift_time, direction=1))
|
||||
self.plotter.add_key_event("b", partial(shift_time, direction=-1))
|
||||
|
||||
# Subscribe to relevant UI events
|
||||
subscribe(fig, "time_change", self._on_time_change)
|
||||
subscribe(fig, "playback_speed", self._on_playback_speed)
|
||||
|
||||
def _on_time_change(self, event):
|
||||
"""Respond to time_change UI event."""
|
||||
from ..ui_events import disable_ui_events
|
||||
|
||||
new_time = np.clip(event.time, self._times[0], self._times[-1])
|
||||
new_time_idx = np.interp(new_time, self._times, np.arange(len(self._times)))
|
||||
|
||||
with disable_ui_events(self._fig):
|
||||
self._widgets["time_slider"].set_value(new_time_idx)
|
||||
self._widgets["current_time"].set_value(f"{new_time:.3f}")
|
||||
|
||||
def _on_playback_speed(self, event):
|
||||
"""Respond to playback_speed UI event."""
|
||||
from ..ui_events import disable_ui_events
|
||||
|
||||
with disable_ui_events(self._fig):
|
||||
self._widgets["playback_speed"].set_value(event.speed)
|
||||
|
||||
def _toggle_playback(self, value=None):
|
||||
"""Toggle time playback."""
|
||||
from ..ui_events import TimeChange, publish
|
||||
|
||||
if value is None:
|
||||
self._playback = not self._playback
|
||||
else:
|
||||
self._playback = value
|
||||
|
||||
if self._playback:
|
||||
self._tool_bar_update_button_icon(name="play", icon_name="pause")
|
||||
if self._current_time_func() == self._times[-1]: # start over
|
||||
publish(self._fig, TimeChange(time=self._times[0]))
|
||||
self._last_tick = time.time()
|
||||
else:
|
||||
self._tool_bar_update_button_icon(name="play", icon_name="play")
|
||||
|
||||
def _reset_time(self):
|
||||
"""Reset time and playback speed to initial values."""
|
||||
from ..ui_events import PlaybackSpeed, TimeChange, publish
|
||||
|
||||
publish(self._fig, TimeChange(time=self._init_time))
|
||||
publish(self._fig, PlaybackSpeed(speed=self._init_playback_speed))
|
||||
|
||||
@safe_event
|
||||
def _play(self):
|
||||
if self._playback:
|
||||
try:
|
||||
self._advance()
|
||||
except Exception:
|
||||
self._toggle_playback(value=False)
|
||||
raise
|
||||
|
||||
def _advance(self):
|
||||
from ..ui_events import TimeChange, publish
|
||||
|
||||
this_time = time.time()
|
||||
delta = this_time - self._last_tick
|
||||
self._last_tick = time.time()
|
||||
time_shift = delta * self._widgets["playback_speed"].get_value()
|
||||
new_time = min(self._current_time_func() + time_shift, self._times[-1])
|
||||
publish(self._fig, TimeChange(time=new_time))
|
||||
if new_time == self._times[-1]:
|
||||
self._toggle_playback(value=False)
|
||||
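# Minimal sketch of the playback arithmetic in ``_advance`` above: the current
# time advances by (elapsed wall-clock seconds) * (playback speed, in data
# seconds per wall-clock second), clipped at the last available time point.
# The numbers below are arbitrary.
elapsed = 0.017           # roughly one 17 ms timer tick, in seconds
playback_speed = 0.05     # data seconds advanced per wall-clock second
current_time, last_time = 0.100, 0.400
new_time = min(current_time + elapsed * playback_speed, last_time)
assert abs(new_time - 0.10085) < 1e-9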
469
mne/viz/circle.py
Normal file
@@ -0,0 +1,469 @@
|
||||
"""Functions to plot on circle as for connectivity."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from functools import partial
|
||||
from itertools import cycle
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import _validate_type
|
||||
from .utils import _get_cmap, plt_show
|
||||
|
||||
|
||||
def circular_layout(
|
||||
node_names,
|
||||
node_order,
|
||||
start_pos=90,
|
||||
start_between=True,
|
||||
group_boundaries=None,
|
||||
group_sep=10,
|
||||
):
|
||||
"""Create layout arranging nodes on a circle.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
node_names : list of str
|
||||
Node names.
|
||||
node_order : list of str
|
||||
List with node names defining the order in which the nodes are
|
||||
arranged. Must contain the same elements as node_names, but the order can be
|
||||
different. The nodes are arranged clockwise starting at "start_pos"
|
||||
degrees.
|
||||
start_pos : float
|
||||
Angle in degrees that defines where the first node is plotted.
|
||||
start_between : bool
|
||||
If True, the layout starts with the position between the nodes. This is
|
||||
the same as adding "180. / len(node_names)" to start_pos.
|
||||
group_boundaries : None | array-like
|
||||
List of boundaries between groups at which point a "group_sep" will
|
||||
be inserted. E.g. "[0, len(node_names) / 2]" will create two groups.
|
||||
group_sep : float
|
||||
Group separation angle in degrees. See "group_boundaries".
|
||||
|
||||
Returns
|
||||
-------
|
||||
node_angles : array, shape=(n_node_names,)
|
||||
Node angles in degrees.
|
||||
"""
|
||||
n_nodes = len(node_names)
|
||||
|
||||
if len(node_order) != n_nodes:
|
||||
raise ValueError("node_order has to be the same length as node_names")
|
||||
|
||||
if group_boundaries is not None:
|
||||
boundaries = np.array(group_boundaries, dtype=np.int64)
|
||||
if np.any(boundaries >= n_nodes) or np.any(boundaries < 0):
|
||||
raise ValueError('"group_boundaries" has to be between 0 and n_nodes - 1.')
|
||||
if len(boundaries) > 1 and np.any(np.diff(boundaries) <= 0):
|
||||
raise ValueError('"group_boundaries" must have non-decreasing values.')
|
||||
n_group_sep = len(group_boundaries)
|
||||
else:
|
||||
n_group_sep = 0
|
||||
boundaries = None
|
||||
|
||||
# convert it to a list with indices
|
||||
node_order = [node_order.index(name) for name in node_names]
|
||||
node_order = np.array(node_order)
|
||||
if len(np.unique(node_order)) != n_nodes:
|
||||
raise ValueError("node_order has repeated entries")
|
||||
|
||||
node_sep = (360.0 - n_group_sep * group_sep) / n_nodes
|
||||
|
||||
if start_between:
|
||||
start_pos += node_sep / 2
|
||||
|
||||
if boundaries is not None and boundaries[0] == 0:
|
||||
# special case when a group separator is at the start
|
||||
start_pos += group_sep / 2
|
||||
boundaries = boundaries[1:] if n_group_sep > 1 else None
|
||||
|
||||
node_angles = np.ones(n_nodes, dtype=np.float64) * node_sep
|
||||
node_angles[0] = start_pos
|
||||
if boundaries is not None:
|
||||
node_angles[boundaries] += group_sep
|
||||
|
||||
node_angles = np.cumsum(node_angles)[node_order]
|
||||
|
||||
return node_angles
|
||||
|
||||
|
||||
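# Illustrative sketch for ``circular_layout`` (node names are arbitrary
# examples): two groups of three nodes, separated by a 20-degree gap.
_example_names = ["LH-frontal", "LH-parietal", "LH-occipital",
                  "RH-occipital", "RH-parietal", "RH-frontal"]
_example_order = _example_names[:3] + _example_names[3:][::-1]
_example_angles = circular_layout(_example_names, _example_order, start_pos=90,
                                  group_boundaries=[0, 3], group_sep=20)
# ``_example_angles`` holds one angle in degrees per entry of ``_example_names``.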
def _plot_connectivity_circle_onpick(
|
||||
event,
|
||||
fig=None,
|
||||
ax=None,
|
||||
indices=None,
|
||||
n_nodes=0,
|
||||
node_angles=None,
|
||||
ylim=(9, 10),
|
||||
):
|
||||
"""Isolate connections around a single node when user left clicks a node.
|
||||
|
||||
On right click, resets all connections.
|
||||
"""
|
||||
if event.inaxes != ax:
|
||||
return
|
||||
|
||||
if event.button == 1: # left click
|
||||
# click must be near node radius
|
||||
if not ylim[0] <= event.ydata <= ylim[1]:
|
||||
return
|
||||
|
||||
# all angles in range [0, 2*pi]
|
||||
node_angles = node_angles % (np.pi * 2)
|
||||
node = np.argmin(np.abs(event.xdata - node_angles))
|
||||
|
||||
patches = event.inaxes.patches
|
||||
for ii, (x, y) in enumerate(zip(indices[0], indices[1])):
|
||||
patches[ii].set_visible(node in [x, y])
|
||||
fig.canvas.draw()
|
||||
elif event.button == 3: # right click
|
||||
patches = event.inaxes.patches
|
||||
for ii in range(np.size(indices, axis=1)):
|
||||
patches[ii].set_visible(True)
|
||||
fig.canvas.draw()
|
||||
|
||||
|
||||
def _plot_connectivity_circle(
|
||||
con,
|
||||
node_names,
|
||||
indices=None,
|
||||
n_lines=None,
|
||||
node_angles=None,
|
||||
node_width=None,
|
||||
node_height=None,
|
||||
node_colors=None,
|
||||
facecolor="black",
|
||||
textcolor="white",
|
||||
node_edgecolor="black",
|
||||
linewidth=1.5,
|
||||
colormap="hot",
|
||||
vmin=None,
|
||||
vmax=None,
|
||||
colorbar=True,
|
||||
title=None,
|
||||
colorbar_size=None,
|
||||
colorbar_pos=None,
|
||||
fontsize_title=12,
|
||||
fontsize_names=8,
|
||||
fontsize_colorbar=8,
|
||||
padding=6.0,
|
||||
ax=None,
|
||||
interactive=True,
|
||||
node_linewidth=2.0,
|
||||
show=True,
|
||||
):
|
||||
import matplotlib.patches as m_patches
|
||||
import matplotlib.path as m_path
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.projections.polar import PolarAxes
|
||||
|
||||
_validate_type(ax, (None, PolarAxes))
|
||||
|
||||
n_nodes = len(node_names)
|
||||
|
||||
if node_angles is not None:
|
||||
if len(node_angles) != n_nodes:
|
||||
raise ValueError("node_angles has to be the same length as node_names")
|
||||
# convert it to radians
|
||||
node_angles = node_angles * np.pi / 180
|
||||
else:
|
||||
# uniform layout on unit circle
|
||||
node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
|
||||
|
||||
if node_width is None:
|
||||
# widths correspond to the minimum angle between two nodes
|
||||
dist_mat = node_angles[None, :] - node_angles[:, None]
|
||||
dist_mat[np.diag_indices(n_nodes)] = 1e9
|
||||
node_width = np.min(np.abs(dist_mat))
|
||||
else:
|
||||
node_width = node_width * np.pi / 180
|
||||
|
||||
if node_height is None:
|
||||
node_height = 1.0
|
||||
|
||||
if node_colors is not None:
|
||||
if len(node_colors) < n_nodes:
|
||||
node_colors = cycle(node_colors)
|
||||
else:
|
||||
# assign colors using colormap
|
||||
try:
|
||||
spectral = plt.cm.spectral
|
||||
except AttributeError:
|
||||
spectral = plt.cm.Spectral
|
||||
node_colors = [spectral(i / float(n_nodes)) for i in range(n_nodes)]
|
||||
|
||||
# handle 1D and 2D connectivity information
|
||||
if con.ndim == 1:
|
||||
if indices is None:
|
||||
raise ValueError("indices has to be provided if con.ndim == 1")
|
||||
elif con.ndim == 2:
|
||||
if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
|
||||
raise ValueError("con has to be 1D or a square matrix")
|
||||
# we use the lower-triangular part
|
||||
indices = np.tril_indices(n_nodes, -1)
|
||||
con = con[indices]
|
||||
else:
|
||||
raise ValueError("con has to be 1D or a square matrix")
|
||||
|
||||
# get the colormap
|
||||
colormap = _get_cmap(colormap)
|
||||
|
||||
# Use a polar axes
|
||||
if ax is None:
|
||||
fig = plt.figure(figsize=(8, 8), facecolor=facecolor, layout="constrained")
|
||||
ax = fig.add_subplot(polar=True)
|
||||
else:
|
||||
fig = ax.figure
|
||||
ax.set_facecolor(facecolor)
|
||||
|
||||
# No ticks, we'll put our own
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# Set y axes limit, add additional space if requested
|
||||
ax.set_ylim(0, 10 + padding)
|
||||
|
||||
# Remove the black axes border which may obscure the labels
|
||||
ax.spines["polar"].set_visible(False)
|
||||
|
||||
# Draw lines between connected nodes, only draw the strongest connections
|
||||
if n_lines is not None and len(con) > n_lines:
|
||||
con_thresh = np.sort(np.abs(con).ravel())[-n_lines]
|
||||
else:
|
||||
con_thresh = 0.0
|
||||
|
||||
# get the connections which we are drawing and sort by connection strength
|
||||
# this will allow us to draw the strongest connections first
|
||||
con_abs = np.abs(con)
|
||||
con_draw_idx = np.where(con_abs >= con_thresh)[0]
|
||||
|
||||
con = con[con_draw_idx]
|
||||
con_abs = con_abs[con_draw_idx]
|
||||
indices = [ind[con_draw_idx] for ind in indices]
|
||||
|
||||
# now sort them
|
||||
sort_idx = np.argsort(con_abs)
|
||||
del con_abs
|
||||
con = con[sort_idx]
|
||||
indices = [ind[sort_idx] for ind in indices]
|
||||
|
||||
# Get vmin vmax for color scaling
|
||||
if vmin is None:
|
||||
vmin = np.min(con[np.abs(con) >= con_thresh])
|
||||
if vmax is None:
|
||||
vmax = np.max(con)
|
||||
vrange = vmax - vmin
|
||||
|
||||
# We want to add some "noise" to the start and end position of the
|
||||
# edges: We modulate the noise with the number of connections of the
|
||||
# node and the connection strength, such that the strongest connections
|
||||
# are closer to the node center
|
||||
nodes_n_con = np.zeros((n_nodes), dtype=np.int64)
|
||||
for i, j in zip(indices[0], indices[1]):
|
||||
nodes_n_con[i] += 1
|
||||
nodes_n_con[j] += 1
|
||||
|
||||
# initialize random number generator so plot is reproducible
|
||||
rng = np.random.mtrand.RandomState(0)
|
||||
|
||||
n_con = len(indices[0])
|
||||
noise_max = 0.25 * node_width
|
||||
start_noise = rng.uniform(-noise_max, noise_max, n_con)
|
||||
end_noise = rng.uniform(-noise_max, noise_max, n_con)
|
||||
|
||||
nodes_n_con_seen = np.zeros_like(nodes_n_con)
|
||||
for i, (start, end) in enumerate(zip(indices[0], indices[1])):
|
||||
nodes_n_con_seen[start] += 1
|
||||
nodes_n_con_seen[end] += 1
|
||||
|
||||
start_noise[i] *= (nodes_n_con[start] - nodes_n_con_seen[start]) / float(
|
||||
nodes_n_con[start]
|
||||
)
|
||||
end_noise[i] *= (nodes_n_con[end] - nodes_n_con_seen[end]) / float(
|
||||
nodes_n_con[end]
|
||||
)
|
||||
|
||||
# scale connectivity for colormap (vmin<=>0, vmax<=>1)
|
||||
con_val_scaled = (con - vmin) / vrange
|
||||
|
||||
# Finally, we draw the connections
|
||||
for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
|
||||
# Start point
|
||||
t0, r0 = node_angles[i], 10
|
||||
|
||||
# End point
|
||||
t1, r1 = node_angles[j], 10
|
||||
|
||||
# Some noise in start and end point
|
||||
t0 += start_noise[pos]
|
||||
t1 += end_noise[pos]
|
||||
|
||||
verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
|
||||
codes = [
|
||||
m_path.Path.MOVETO,
|
||||
m_path.Path.CURVE4,
|
||||
m_path.Path.CURVE4,
|
||||
m_path.Path.LINETO,
|
||||
]
|
||||
path = m_path.Path(verts, codes)
|
||||
|
||||
color = colormap(con_val_scaled[pos])
|
||||
|
||||
# Actual line
|
||||
patch = m_patches.PathPatch(
|
||||
path, fill=False, edgecolor=color, linewidth=linewidth, alpha=1.0
|
||||
)
|
||||
ax.add_patch(patch)
|
||||
|
||||
# Draw ring with colored nodes
|
||||
height = np.ones(n_nodes) * node_height
|
||||
bars = ax.bar(
|
||||
node_angles,
|
||||
height,
|
||||
width=node_width,
|
||||
bottom=9,
|
||||
edgecolor=node_edgecolor,
|
||||
lw=node_linewidth,
|
||||
facecolor=".9",
|
||||
align="center",
|
||||
)
|
||||
|
||||
for bar, color in zip(bars, node_colors):
|
||||
bar.set_facecolor(color)
|
||||
|
||||
# Draw node labels
|
||||
angles_deg = 180 * node_angles / np.pi
|
||||
for name, angle_rad, angle_deg in zip(node_names, node_angles, angles_deg):
|
||||
if angle_deg >= 270:
|
||||
ha = "left"
|
||||
else:
|
||||
# Flip the label, so text is always upright
|
||||
angle_deg += 180
|
||||
ha = "right"
|
||||
|
||||
ax.text(
|
||||
angle_rad,
|
||||
9.4 + node_height,
|
||||
name,
|
||||
size=fontsize_names,
|
||||
rotation=angle_deg,
|
||||
rotation_mode="anchor",
|
||||
horizontalalignment=ha,
|
||||
verticalalignment="center",
|
||||
color=textcolor,
|
||||
)
|
||||
|
||||
if title is not None:
|
||||
ax.set_title(title, color=textcolor, fontsize=fontsize_title)
|
||||
|
||||
if colorbar:
|
||||
sm = plt.cm.ScalarMappable(cmap=colormap, norm=plt.Normalize(vmin, vmax))
|
||||
sm.set_array(np.linspace(vmin, vmax))
|
||||
colorbar_kwargs = dict()
|
||||
if colorbar_size is not None:
|
||||
colorbar_kwargs.update(shrink=colorbar_size)
|
||||
if colorbar_pos is not None:
|
||||
colorbar_kwargs.update(anchor=colorbar_pos)
|
||||
cb = fig.colorbar(sm, ax=ax, **colorbar_kwargs)
|
||||
cb_yticks = plt.getp(cb.ax.axes, "yticklabels")
|
||||
cb.ax.tick_params(labelsize=fontsize_colorbar)
|
||||
plt.setp(cb_yticks, color=textcolor)
|
||||
|
||||
# Add callback for interaction
|
||||
if interactive:
|
||||
callback = partial(
|
||||
_plot_connectivity_circle_onpick,
|
||||
fig=fig,
|
||||
ax=ax,
|
||||
indices=indices,
|
||||
n_nodes=n_nodes,
|
||||
node_angles=node_angles,
|
||||
)
|
||||
|
||||
fig.canvas.mpl_connect("button_press_event", callback)
|
||||
|
||||
plt_show(show)
|
||||
return fig, ax
|
||||
|
||||
|
||||
def plot_channel_labels_circle(labels, colors=None, picks=None, **kwargs):
|
||||
"""Plot labels for each channel in a circle plot.
|
||||
|
||||
.. note:: This primarily makes sense for sEEG channels where each
|
||||
channel can be assigned an anatomical label as the electrode
|
||||
passes through various brain areas.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
labels : dict
|
||||
Lists of labels (values) associated with each channel (keys).
|
||||
colors : dict
|
||||
The color (value) for each label (key).
|
||||
picks : list | tuple
|
||||
The channels to consider.
|
||||
**kwargs : kwargs
|
||||
Keyword arguments for
|
||||
:func:`mne_connectivity.viz.plot_connectivity_circle`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib.figure.Figure
|
||||
The figure handle.
|
||||
axes : instance of matplotlib.projections.polar.PolarAxes
|
||||
The subplot handle.
|
||||
"""
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
|
||||
_validate_type(labels, dict, "labels")
|
||||
_validate_type(colors, (dict, None), "colors")
|
||||
_validate_type(picks, (list, tuple, None), "picks")
|
||||
if picks is not None:
|
||||
labels = {k: v for k, v in labels.items() if k in picks}
|
||||
ch_names = list(labels.keys())
|
||||
all_labels = list(set([label for val in labels.values() for label in val]))
|
||||
n_labels = len(all_labels)
|
||||
if colors is not None:
|
||||
for label in all_labels:
|
||||
if label not in colors:
|
||||
raise ValueError(f"No color provided for {label} in `colors`")
|
||||
# update all_labels, there may be unconnected labels in colors
|
||||
all_labels = list(colors.keys())
|
||||
n_labels = len(all_labels)
|
||||
# make colormap
|
||||
label_colors = [colors[label] for label in all_labels]
|
||||
node_colors = ["black"] * len(ch_names) + label_colors
|
||||
label_cmap = LinearSegmentedColormap.from_list(
|
||||
"label_cmap", label_colors, N=len(label_colors)
|
||||
)
|
||||
else:
|
||||
node_colors = None
|
||||
|
||||
node_names = ch_names + all_labels
|
||||
con = np.zeros((len(node_names), len(node_names))) * np.nan
|
||||
for idx, ch_name in enumerate(ch_names):
|
||||
for label in labels[ch_name]:
|
||||
node_idx = node_names.index(label)
|
||||
label_color = all_labels.index(label) / n_labels
|
||||
con[idx, node_idx] = con[node_idx, idx] = label_color # symmetric
|
||||
# plot
|
||||
node_order = ch_names + all_labels[::-1]
|
||||
node_angles = circular_layout(
|
||||
node_names, node_order, start_pos=90, group_boundaries=[0, len(ch_names)]
|
||||
)
|
||||
# provide defaults but don't overwrite
|
||||
if "node_angles" not in kwargs:
|
||||
kwargs.update(node_angles=node_angles)
|
||||
if "colorbar" not in kwargs:
|
||||
kwargs.update(colorbar=False)
|
||||
if "node_colors" not in kwargs:
|
||||
kwargs.update(node_colors=node_colors)
|
||||
if "vmin" not in kwargs:
|
||||
kwargs.update(vmin=0)
|
||||
if "vmax" not in kwargs:
|
||||
kwargs.update(vmax=1)
|
||||
if "colormap" not in kwargs:
|
||||
kwargs.update(colormap=label_cmap)
|
||||
return _plot_connectivity_circle(con, node_names, **kwargs)
|
||||
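# Illustrative sketch for ``plot_channel_labels_circle`` (hypothetical sEEG
# contact names and anatomical labels): each channel is connected to every
# label it passes through, with one color per label.
_example_labels = {
    "LPM 1": ["Hippocampus"],
    "LPM 2": ["Hippocampus", "Amygdala"],
    "LPM 3": ["Amygdala"],
}
_example_colors = {"Hippocampus": "tab:blue", "Amygdala": "tab:orange"}
_example_fig, _example_ax = plot_channel_labels_circle(
    _example_labels, _example_colors, show=False
)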
47
mne/viz/conftest.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import os.path as op
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mne import Epochs, EvokedArray, create_info, events_from_annotations
|
||||
from mne.channels import make_standard_montage
|
||||
from mne.datasets.testing import _pytest_param, data_path
|
||||
from mne.io import read_raw_nirx
|
||||
from mne.preprocessing.nirs import beer_lambert_law, optical_density
|
||||
|
||||
fname_nirx = op.join(
|
||||
data_path(download=False), "NIRx", "nirscout", "nirx_15_2_recording_w_overlap"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fnirs_evoked():
|
||||
"""Create an fnirs evoked structure."""
|
||||
montage = make_standard_montage("biosemi16")
|
||||
ch_names = montage.ch_names
|
||||
ch_types = ["eeg"] * 16
|
||||
info = create_info(ch_names=ch_names, sfreq=20, ch_types=ch_types)
|
||||
evoked_data = np.random.randn(16, 30)
|
||||
evoked = EvokedArray(evoked_data, info=info, tmin=-0.2, nave=4)
|
||||
evoked.set_montage(montage)
|
||||
evoked.set_channel_types(
|
||||
{"Fp1": "hbo", "Fp2": "hbo", "F4": "hbo", "Fz": "hbo"}, verbose="error"
|
||||
)
|
||||
return evoked
|
||||
|
||||
|
||||
@pytest.fixture(params=[_pytest_param()])
|
||||
def fnirs_epochs():
|
||||
"""Create an fnirs epoch structure."""
|
||||
raw_intensity = read_raw_nirx(fname_nirx, preload=False)
|
||||
raw_od = optical_density(raw_intensity)
|
||||
raw_haemo = beer_lambert_law(raw_od, ppf=6.0)
|
||||
evts, _ = events_from_annotations(raw_haemo, event_id={"1.0": 1})
|
||||
evts_dct = {"A": 1}
|
||||
tn, tx = -1, 2
|
||||
epochs = Epochs(raw_haemo, evts, event_id=evts_dct, tmin=tn, tmax=tx)
|
||||
return epochs
|
||||
1160
mne/viz/epochs.py
Normal file
File diff suppressed because it is too large
Load Diff
3207
mne/viz/evoked.py
Normal file
File diff suppressed because it is too large
Load Diff
577
mne/viz/evoked_field.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""Class to draw evoked MEG and EEG fieldlines, with a GUI to control the figure.
|
||||
|
||||
author: Marijn van Vliet <w.m.vanvliet@gmail.com>
|
||||
"""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from scipy.interpolate import interp1d
|
||||
|
||||
from .._fiff.pick import pick_types
|
||||
from ..defaults import DEFAULTS
|
||||
from ..utils import (
|
||||
_auto_weakref,
|
||||
_check_option,
|
||||
_ensure_int,
|
||||
_to_rgb,
|
||||
_validate_type,
|
||||
fill_doc,
|
||||
)
|
||||
from ._3d_overlay import _LayeredMesh
|
||||
from .ui_events import (
|
||||
ColormapRange,
|
||||
Contours,
|
||||
TimeChange,
|
||||
disable_ui_events,
|
||||
publish,
|
||||
subscribe,
|
||||
)
|
||||
from .utils import mne_analyze_colormap
|
||||
|
||||
|
||||
@fill_doc
|
||||
class EvokedField:
|
||||
"""Plot MEG/EEG fields on head surface and helmet in 3D.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
evoked : instance of mne.Evoked
|
||||
The evoked object.
|
||||
surf_maps : list
|
||||
The surface mapping information obtained with make_field_map.
|
||||
time : float | None
|
||||
The time point at which the field map shall be displayed. If None,
|
||||
the average peak latency (across sensor types) is used.
|
||||
time_label : str | None
|
||||
How to print info about the time instant visualized.
|
||||
%(n_jobs)s
|
||||
fig : instance of Figure3D | None
|
||||
If None (default), a new figure will be created, otherwise it will
|
||||
plot into the given figure.
|
||||
|
||||
.. versionadded:: 0.20
|
||||
vmax : float | dict | None
|
||||
Maximum intensity. Can be a dictionary with two entries ``"eeg"`` and ``"meg"``
|
||||
to specify separate values for EEG and MEG fields respectively. Can be
|
||||
``None`` to use the maximum value of the data.
|
||||
|
||||
.. versionadded:: 0.21
|
||||
.. versionadded:: 1.4
|
||||
``vmax`` can be a dictionary to specify separate values for EEG and
|
||||
MEG fields.
|
||||
n_contours : int
|
||||
The number of contours.
|
||||
|
||||
.. versionadded:: 0.21
|
||||
show_density : bool
|
||||
Whether to draw the field density as an overlay on top of the helmet/head
|
||||
surface. Defaults to ``True``.
|
||||
alpha : float | dict | None
|
||||
Opacity of the meshes (between 0 and 1). Can be a dictionary with two
|
||||
entries ``"eeg"`` and ``"meg"`` to specify separate values for EEG and
|
||||
MEG fields respectively. Can be ``None`` to use 1.0 when a single field
|
||||
map is shown, or ``dict(eeg=1.0, meg=0.5)`` when both field maps are shown.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
%(interpolation_brain_time)s
|
||||
|
||||
.. versionadded:: 1.6
|
||||
%(interaction_scene)s
|
||||
Defaults to ``'terrain'``.
|
||||
|
||||
.. versionadded:: 1.1
|
||||
time_viewer : bool | str
|
||||
Display time viewer GUI. Can also be ``"auto"``, which will mean
|
||||
``True`` if there is more than one time point and ``False`` otherwise.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
%(verbose)s
|
||||
|
||||
Notes
|
||||
-----
|
||||
The figure will publish and subscribe to the following UI events:
|
||||
|
||||
* :class:`~mne.viz.ui_events.TimeChange`
|
||||
* :class:`~mne.viz.ui_events.Contours`, ``kind="field_strength_meg" | "field_strength_eeg"``
|
||||
* :class:`~mne.viz.ui_events.ColormapRange`, ``kind="field_strength_meg" | "field_strength_eeg"``
|
||||
""" # noqa
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
evoked,
|
||||
surf_maps,
|
||||
*,
|
||||
time=None,
|
||||
time_label="t = %0.0f ms",
|
||||
n_jobs=None,
|
||||
fig=None,
|
||||
vmax=None,
|
||||
n_contours=21,
|
||||
show_density=True,
|
||||
alpha=None,
|
||||
interpolation="nearest",
|
||||
interaction="terrain",
|
||||
time_viewer="auto",
|
||||
verbose=None,
|
||||
):
|
||||
from .backends.renderer import _get_3d_backend, _get_renderer
|
||||
|
||||
# Setup figure parameters
|
||||
self._evoked = evoked
|
||||
if time is None:
|
||||
types = [t for t in ["eeg", "grad", "mag"] if t in evoked]
|
||||
time = np.mean([evoked.get_peak(ch_type=t)[1] for t in types])
|
||||
self._current_time = time
|
||||
if not evoked.times[0] <= time <= evoked.times[-1]:
|
||||
raise ValueError(f"`time` ({time:0.3f}) must be inside `evoked.times`")
|
||||
self._time_label = time_label
|
||||
|
||||
_validate_type(vmax, (None, "numeric", dict), "vmax")  # validation only; self._vmax is set below
|
||||
self._n_contours = _ensure_int(n_contours, "n_contours")
|
||||
self._time_interpolation = _check_option(
|
||||
"interpolation",
|
||||
interpolation,
|
||||
("linear", "nearest", "zero", "slinear", "quadratic", "cubic"),
|
||||
)
|
||||
self._interaction = _check_option(
|
||||
"interaction", interaction, ["trackball", "terrain"]
|
||||
)
|
||||
|
||||
surf_map_kinds = [surf_map["kind"] for surf_map in surf_maps]
|
||||
if vmax is None:
|
||||
self._vmax = {kind: None for kind in surf_map_kinds}
|
||||
elif isinstance(vmax, dict):
|
||||
for kind in surf_map_kinds:
|
||||
if kind not in vmax:
|
||||
raise ValueError(
|
||||
f'No entry for "{kind}" found in the vmax dictionary'
|
||||
)
|
||||
self._vmax = vmax
|
||||
else: # float value
|
||||
self._vmax = {kind: vmax for kind in surf_map_kinds}
|
||||
|
||||
if alpha is None:
|
||||
self._alpha = {
|
||||
surf_map["kind"]: val for surf_map, val in zip(surf_maps, [1.0, 0.5])
|
||||
}
|
||||
elif isinstance(alpha, dict):
|
||||
for kind in surf_map_kinds:
|
||||
if kind not in alpha:
|
||||
raise ValueError(
|
||||
f'No entry for "{kind}" found in the alpha dictionary'
|
||||
)
|
||||
self._alpha = alpha
|
||||
else: # float value
|
||||
self._alpha = {kind: alpha for kind in surf_map_kinds}
|
||||
|
||||
self._colors = [(0.6, 0.6, 0.6), (1.0, 1.0, 1.0)]
|
||||
self._colormap = mne_analyze_colormap(format="vtk")
|
||||
self._colormap_lines = np.concatenate(
|
||||
[
|
||||
np.tile([0.0, 0.0, 255.0, 255.0], (127, 1)),
|
||||
np.tile([0.0, 0.0, 0.0, 255.0], (2, 1)),
|
||||
np.tile([255.0, 0.0, 0.0, 255.0], (127, 1)),
|
||||
]
|
||||
)
|
||||
self._show_density = show_density
|
||||
|
||||
from ._brain import Brain
|
||||
|
||||
if isinstance(fig, Brain):
|
||||
self._renderer = fig._renderer
|
||||
self._in_brain_figure = True
|
||||
if _get_3d_backend() == "notebook":
|
||||
raise NotImplementedError(
|
||||
"Plotting on top of an existing Brain figure "
|
||||
"is currently not supported inside a notebook."
|
||||
)
|
||||
else:
|
||||
self._renderer = _get_renderer(
|
||||
fig, bgcolor=(0.0, 0.0, 0.0), size=(600, 600)
|
||||
)
|
||||
self._in_brain_figure = False
|
||||
|
||||
self.plotter = self._renderer.plotter
|
||||
self.interaction = interaction
|
||||
|
||||
# Prepare the surface maps
|
||||
self._surf_maps = [
|
||||
self._prepare_surf_map(surf_map, color, self._alpha[surf_map["kind"]])
|
||||
for surf_map, color in zip(surf_maps, self._colors)
|
||||
]
|
||||
|
||||
# Do we want the time viewer?
|
||||
if time_viewer == "auto":
|
||||
time_viewer = len(evoked.times) > 1
|
||||
self.time_viewer = time_viewer
|
||||
|
||||
# Configure UI events
|
||||
@_auto_weakref
|
||||
def current_time_func():
|
||||
return self._current_time
|
||||
|
||||
self._widgets = dict()
|
||||
if self.time_viewer:
|
||||
# Draw widgets only if not inside a figure that already has them.
|
||||
if (
|
||||
not hasattr(self._renderer, "_widgets")
|
||||
or "time_slider" not in self._renderer._widgets
|
||||
):
|
||||
self._renderer._enable_time_interaction(
|
||||
self,
|
||||
current_time_func=current_time_func,
|
||||
times=evoked.times,
|
||||
)
|
||||
if not self._in_brain_figure or "time_slider" not in fig.widgets:
|
||||
# Draw the time label
|
||||
self._time_label = time_label
|
||||
if time_label is not None:
|
||||
if "%" in time_label:
|
||||
time_label = time_label % np.round(1e3 * time)
|
||||
self._time_label_actor = self._renderer.text2d(
|
||||
x_window=0.01, y_window=0.01, text=time_label
|
||||
)
|
||||
self._configure_dock()
|
||||
|
||||
subscribe(self, "time_change", self._on_time_change)
|
||||
subscribe(self, "colormap_range", self._on_colormap_range)
|
||||
subscribe(self, "contours", self._on_contours)
|
||||
|
||||
if not self._in_brain_figure:
|
||||
self._renderer.set_interaction(interaction)
|
||||
self._renderer.set_camera(azimuth=10, elevation=60, distance="auto")
|
||||
self._renderer.show()
|
||||
|
||||
def _prepare_surf_map(self, surf_map, color, alpha):
|
||||
"""Compute all the data required to render a fieldlines map."""
|
||||
if surf_map["kind"] == "eeg":
|
||||
pick = pick_types(self._evoked.info, meg=False, eeg=True)
|
||||
else:
|
||||
pick = pick_types(self._evoked.info, meg=True, eeg=False, ref_meg=False)
|
||||
|
||||
evoked_ch_names = set([self._evoked.ch_names[k] for k in pick])
|
||||
map_ch_names = set(surf_map["ch_names"])
|
||||
if evoked_ch_names != map_ch_names:
|
||||
message = ["Channels in map and data do not match."]
|
||||
diff = map_ch_names - evoked_ch_names
|
||||
if len(diff):
|
||||
message += [f"{list(diff)} not in data file. "]
|
||||
diff = evoked_ch_names - map_ch_names
|
||||
if len(diff):
|
||||
message += [f"{list(diff)} not in map file."]
|
||||
raise RuntimeError(" ".join(message))
|
||||
|
||||
data = surf_map["data"] @ self._evoked.data[pick]
|
||||
data_interp = interp1d(
|
||||
self._evoked.times,
|
||||
data,
|
||||
kind=self._time_interpolation,
|
||||
assume_sorted=True,
|
||||
)
|
||||
current_data = data_interp(self._current_time)
|
||||
|
||||
# Make a solid surface
|
||||
surf = surf_map["surf"]
|
||||
if self._in_brain_figure:
|
||||
surf["rr"] *= 1000
|
||||
map_vmax = self._vmax.get(surf_map["kind"])
|
||||
if map_vmax is None:
|
||||
map_vmax = float(np.max(current_data))
|
||||
mesh = _LayeredMesh(
|
||||
renderer=self._renderer,
|
||||
vertices=surf["rr"],
|
||||
triangles=surf["tris"],
|
||||
normals=surf["nn"],
|
||||
)
|
||||
mesh.map()
|
||||
color = _to_rgb(color, alpha=True)
|
||||
cmap = np.array([(0, 0, 0, 0), color])
|
||||
ctable = np.round(cmap * 255).astype(np.uint8)
|
||||
mesh.add_overlay(
|
||||
scalars=np.ones(len(current_data)),
|
||||
colormap=ctable,
|
||||
rng=[0, 1],
|
||||
opacity=alpha,
|
||||
name="surf",
|
||||
)
|
||||
|
||||
# Show the field density
|
||||
if self._show_density:
|
||||
mesh.add_overlay(
|
||||
scalars=current_data,
|
||||
colormap=self._colormap,
|
||||
rng=[-map_vmax, map_vmax],
|
||||
opacity=1.0,
|
||||
name="field",
|
||||
)
|
||||
|
||||
# And the field lines on top
|
||||
if self._n_contours > 1:
|
||||
contours = np.linspace(-map_vmax, map_vmax, self._n_contours)
|
||||
contours_actor, _ = self._renderer.contour(
|
||||
surface=surf,
|
||||
scalars=current_data,
|
||||
contours=contours,
|
||||
vmin=-map_vmax,
|
||||
vmax=map_vmax,
|
||||
colormap=self._colormap_lines,
|
||||
)
|
||||
else:
|
||||
contours = None # noqa
|
||||
contours_actor = None
|
||||
|
||||
return dict(
|
||||
pick=pick,
|
||||
data=data,
|
||||
data_interp=data_interp,
|
||||
map_kind=surf_map["kind"],
|
||||
mesh=mesh,
|
||||
contours=contours,
|
||||
contours_actor=contours_actor,
|
||||
surf=surf,
|
||||
map_vmax=map_vmax,
|
||||
)
|
||||
|
||||
def _update(self):
|
||||
"""Update the figure to reflect the current settings."""
|
||||
for surf_map in self._surf_maps:
|
||||
current_data = surf_map["data_interp"](self._current_time)
|
||||
surf_map["mesh"].update_overlay(name="field", scalars=current_data)
|
||||
|
||||
if surf_map["contours"] is not None:
|
||||
self._renderer.plotter.remove_actor(
|
||||
surf_map["contours_actor"], render=False
|
||||
)
|
||||
if self._n_contours > 1:
|
||||
surf_map["contours_actor"], _ = self._renderer.contour(
|
||||
surface=surf_map["surf"],
|
||||
scalars=current_data,
|
||||
contours=surf_map["contours"],
|
||||
vmin=-surf_map["map_vmax"],
|
||||
vmax=surf_map["map_vmax"],
|
||||
colormap=self._colormap_lines,
|
||||
)
|
||||
if self._time_label is not None:
|
||||
if hasattr(self, "_time_label_actor"):
|
||||
self._renderer.plotter.remove_actor(
|
||||
self._time_label_actor, render=False
|
||||
)
|
||||
time_label = self._time_label
|
||||
if "%" in self._time_label:
|
||||
time_label = self._time_label % np.round(1e3 * self._current_time)
|
||||
self._time_label_actor = self._renderer.text2d(
|
||||
x_window=0.01, y_window=0.01, text=time_label
|
||||
)
|
||||
|
||||
self._renderer.plotter.update()
|
||||
|
||||
def _configure_dock(self):
|
||||
"""Configure the widgets shown in the dock on the left."""
|
||||
r = self._renderer
|
||||
|
||||
if not hasattr(r, "_dock"):
|
||||
r._dock_initialize()
|
||||
|
||||
# Fieldline configuration
|
||||
layout = r._dock_add_group_box("Fieldlines")
|
||||
|
||||
if self._show_density:
|
||||
r._dock_add_label(value="max value", align=True, layout=layout)
|
||||
|
||||
@_auto_weakref
|
||||
def _callback(vmax, kind, scaling):
|
||||
self.set_vmax(vmax / scaling, kind=kind)
|
||||
|
||||
for surf_map in self._surf_maps:
|
||||
if surf_map["map_kind"] == "meg":
|
||||
scaling = DEFAULTS["scalings"]["grad"]
|
||||
else:
|
||||
scaling = DEFAULTS["scalings"]["eeg"]
|
||||
rng = [0, np.max(np.abs(surf_map["data"])) * scaling]
|
||||
hlayout = r._dock_add_layout(vertical=False)
|
||||
|
||||
self._widgets[f"vmax_slider_{surf_map['map_kind']}"] = (
|
||||
r._dock_add_slider(
|
||||
name=surf_map["map_kind"].upper(),
|
||||
value=surf_map["map_vmax"] * scaling,
|
||||
rng=rng,
|
||||
callback=partial(
|
||||
_callback, kind=surf_map["map_kind"], scaling=scaling
|
||||
),
|
||||
double=True,
|
||||
layout=hlayout,
|
||||
)
|
||||
)
|
||||
self._widgets[f"vmax_spin_{surf_map['map_kind']}"] = (
|
||||
r._dock_add_spin_box(
|
||||
name="",
|
||||
value=surf_map["map_vmax"] * scaling,
|
||||
rng=rng,
|
||||
callback=partial(
|
||||
_callback, kind=surf_map["map_kind"], scaling=scaling
|
||||
),
|
||||
layout=hlayout,
|
||||
)
|
||||
)
|
||||
r._layout_add_widget(layout, hlayout)
|
||||
|
||||
hlayout = r._dock_add_layout(vertical=False)
|
||||
r._dock_add_label(
|
||||
value="Rescale",
|
||||
align=True,
|
||||
layout=hlayout,
|
||||
)
|
||||
r._dock_add_button(
|
||||
name="↺",
|
||||
callback=self._rescale,
|
||||
layout=hlayout,
|
||||
style="toolbutton",
|
||||
)
|
||||
r._layout_add_widget(layout, hlayout)
|
||||
|
||||
self._widgets["contours"] = r._dock_add_spin_box(
|
||||
name="Contour lines",
|
||||
value=21,
|
||||
rng=[0, 99],
|
||||
step=1,
|
||||
double=False,
|
||||
callback=self.set_contours,
|
||||
layout=layout,
|
||||
)
|
||||
r._dock_finalize()
|
||||
|
||||
def _on_time_change(self, event):
|
||||
"""Respond to time_change UI event."""
|
||||
new_time = np.clip(event.time, self._evoked.times[0], self._evoked.times[-1])
|
||||
if new_time == self._current_time:
|
||||
return
|
||||
self._current_time = new_time
|
||||
self._update()
|
||||
|
||||
def _on_colormap_range(self, event):
|
||||
"""Response to the colormap_range UI event."""
|
||||
if event.kind == "field_strength_meg":
|
||||
kind = "meg"
|
||||
elif event.kind == "field_strength_eeg":
|
||||
kind = "eeg"
|
||||
else:
|
||||
return
|
||||
|
||||
for surf_map in self._surf_maps:
|
||||
if surf_map["map_kind"] == kind:
|
||||
break
|
||||
else:
|
||||
# No field map currently shown of the requested type.
|
||||
return
|
||||
|
||||
vmin = event.fmin
|
||||
vmax = event.fmax
|
||||
surf_map["contours"] = np.linspace(vmin, vmax, self._n_contours)
|
||||
|
||||
if self._show_density:
|
||||
surf_map["mesh"].update_overlay(name="field", rng=[vmin, vmax])
|
||||
# Update the GUI widgets
|
||||
if kind == "meg":
|
||||
scaling = DEFAULTS["scalings"]["grad"]
|
||||
else:
|
||||
scaling = DEFAULTS["scalings"]["eeg"]
|
||||
with disable_ui_events(self):
|
||||
widget = self._widgets.get(f"vmax_slider_{kind}", None)
|
||||
if widget is not None:
|
||||
widget.set_value(vmax * scaling)
|
||||
widget = self._widgets.get(f"vmax_spin_{kind}", None)
|
||||
if widget is not None:
|
||||
widget.set_value(vmax * scaling)
|
||||
|
||||
self._update()
|
||||
|
||||
def _on_contours(self, event):
|
||||
"""Respond to the contours UI event."""
|
||||
if event.kind == "field_strength_meg":
|
||||
kind = "meg"
|
||||
elif event.kind == "field_strength_eeg":
|
||||
kind = "eeg"
|
||||
else:
|
||||
return
|
||||
|
||||
for surf_map in self._surf_maps:
|
||||
if surf_map["map_kind"] == kind:
|
||||
break
|
||||
surf_map["contours"] = event.contours
|
||||
self._n_contours = len(event.contours)
|
||||
with disable_ui_events(self):
|
||||
if "contours" in self._widgets:
|
||||
self._widgets["contours"].set_value(len(event.contours))
|
||||
self._update()
|
||||
|
||||
def set_time(self, time):
|
||||
"""Set the time to display (in seconds).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time : float
|
||||
The time to show, in seconds.
|
||||
"""
|
||||
if self._evoked.times[0] <= time <= self._evoked.times[-1]:
|
||||
publish(self, TimeChange(time=time))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Requested time ({time} s) is outside the range of "
|
||||
f"available times ({self._evoked.times[0]}-{self._evoked.times[-1]} s)."
|
||||
)
|
||||
|
||||
def set_contours(self, n_contours):
|
||||
"""Adjust the number of contour lines to use when drawing the fieldlines.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_contours : int
|
||||
The number of contour lines to use.
|
||||
"""
|
||||
for surf_map in self._surf_maps:
|
||||
publish(
|
||||
self,
|
||||
Contours(
|
||||
kind=f"field_strength_{surf_map['map_kind']}",
|
||||
contours=np.linspace(
|
||||
-surf_map["map_vmax"], surf_map["map_vmax"], n_contours
|
||||
).tolist(),
|
||||
),
|
||||
)
|
||||
|
||||
def set_vmax(self, vmax, kind="meg"):
|
||||
"""Change the color range of the density maps.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vmax : float
|
||||
The new maximum value of the color range.
|
||||
kind : 'meg' | 'eeg'
|
||||
Which field map to apply the new color range to.
|
||||
"""
|
||||
_check_option("type", kind, ["eeg", "meg"])
|
||||
for surf_map in self._surf_maps:
|
||||
if surf_map["map_kind"] == kind:
|
||||
publish(
|
||||
self,
|
||||
ColormapRange(
|
||||
kind=f"field_strength_{kind}",
|
||||
fmin=-vmax,
|
||||
fmax=vmax,
|
||||
),
|
||||
)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"No {type.upper()} field map currently shown.")
|
||||
|
||||
def _rescale(self):
|
||||
"""Rescale the fieldlines and density maps to the current time point."""
|
||||
for surf_map in self._surf_maps:
|
||||
current_data = surf_map["data_interp"](self._current_time)
|
||||
vmax = float(np.max(current_data))
|
||||
self.set_vmax(vmax, kind=surf_map["map_kind"])
|
||||
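# Illustrative sketch for ``EvokedField`` (assumes the MNE "sample" dataset is
# available and a 3D backend is installed): build field maps, open the viewer,
# then drive it from code via the UI-event helpers defined above.
import mne

data_path = mne.datasets.sample.data_path()
meg_path = data_path / "MEG" / "sample"
evoked = mne.read_evokeds(meg_path / "sample_audvis-ave.fif", condition="Left Auditory")
maps = mne.make_field_map(
    evoked,
    trans=meg_path / "sample_audvis_raw-trans.fif",
    subject="sample",
    subjects_dir=data_path / "subjects",
)
field = EvokedField(evoked, maps, time=0.1, n_contours=25)
field.set_time(0.12)                  # publishes a TimeChange event
field.set_vmax(1.2e-13, kind="meg")   # publishes a ColormapRange event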
7
mne/viz/eyetracking/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Eye-tracking visualization routines."""
|
||||
#
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from .heatmap import plot_gaze
|
||||
209
mne/viz/eyetracking/heatmap.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
from scipy.ndimage import gaussian_filter
|
||||
|
||||
from ..._fiff.constants import FIFF
|
||||
from ...utils import _validate_type, fill_doc, logger
|
||||
from ..utils import plt_show
|
||||
|
||||
|
||||
@fill_doc
|
||||
def plot_gaze(
|
||||
epochs,
|
||||
*,
|
||||
calibration=None,
|
||||
width=None,
|
||||
height=None,
|
||||
sigma=25,
|
||||
cmap=None,
|
||||
alpha=1.0,
|
||||
vlim=(None, None),
|
||||
axes=None,
|
||||
show=True,
|
||||
):
|
||||
"""Plot a heatmap of eyetracking gaze data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
epochs : instance of Epochs
|
||||
The :class:`~mne.Epochs` object containing eyegaze channels.
|
||||
calibration : instance of Calibration | None
|
||||
An instance of Calibration with information about the screen size, distance,
|
||||
and resolution. If ``None``, you must provide a width and height.
|
||||
width : int
|
||||
The width dimension of the plot canvas, only valid if eyegaze data are in
|
||||
pixels. For example, if the participant screen resolution was 1920x1080, then
|
||||
the width should be 1920.
|
||||
height : int
|
||||
The height dimension of the plot canvas, only valid if eyegaze data are in
|
||||
pixels. For example, if the participant screen resolution was 1920x1080, then
|
||||
the height should be 1080.
|
||||
sigma : float | None
|
||||
The amount of Gaussian smoothing applied to the heatmap data (standard
|
||||
deviation in pixels). If ``None``, no smoothing is applied. Default is 25.
|
||||
%(cmap)s
|
||||
alpha : float
|
||||
The opacity of the heatmap (default is 1).
|
||||
%(vlim_plot_topomap)s
|
||||
%(axes_plot_topomap)s
|
||||
%(show)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of Figure
|
||||
The resulting figure object for the heatmap plot.
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 1.6
|
||||
"""
|
||||
from mne import BaseEpochs
|
||||
from mne._fiff.pick import _picks_to_idx
|
||||
|
||||
from ...preprocessing.eyetracking.utils import (
|
||||
_check_calibration,
|
||||
get_screen_visual_angle,
|
||||
)
|
||||
|
||||
_validate_type(epochs, BaseEpochs, "epochs")
|
||||
_validate_type(alpha, "numeric", "alpha")
|
||||
_validate_type(sigma, ("numeric", None), "sigma")
|
||||
|
||||
# Get the gaze data
|
||||
pos_picks = _picks_to_idx(epochs.info, "eyegaze")
|
||||
gaze_data = epochs.get_data(picks=pos_picks)
|
||||
gaze_ch_loc = np.array([epochs.info["chs"][idx]["loc"] for idx in pos_picks])
|
||||
x_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == -1)[0], :]
|
||||
y_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == 1)[0], :]
|
||||
unit = epochs.info["chs"][pos_picks[0]]["unit"] # assumes all units are the same
|
||||
|
||||
if x_data.shape[1] > 1: # binocular recording. Average across eyes
|
||||
logger.info("Detected binocular recording. Averaging positions across eyes.")
|
||||
x_data = np.nanmean(x_data, axis=1) # shape (n_epochs, n_samples)
|
||||
y_data = np.nanmean(y_data, axis=1)
|
||||
canvas = np.vstack((x_data.flatten(), y_data.flatten())) # shape (2, n_samples)
|
||||
|
||||
# Check that we have the right inputs
|
||||
if calibration is not None:
|
||||
if width is not None or height is not None:
|
||||
raise ValueError(
|
||||
"If a calibration is provided, you cannot provide a width or height"
|
||||
" to plot heatmaps. Please provide only the calibration object."
|
||||
)
|
||||
_check_calibration(calibration)
|
||||
if unit == FIFF.FIFF_UNIT_PX:
|
||||
width, height = calibration["screen_resolution"]
|
||||
elif unit == FIFF.FIFF_UNIT_RAD:
|
||||
width, height = calibration["screen_size"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid unit type: {unit}. gaze data Must be pixels or radians."
|
||||
)
|
||||
else:
|
||||
if width is None or height is None:
|
||||
raise ValueError(
|
||||
"If no calibration is provided, you must provide a width and height"
|
||||
" to plot heatmaps."
|
||||
)
|
||||
|
||||
# Create 2D histogram
|
||||
# We need to set the histogram bins & bounds, and imshow extent, based on the units
|
||||
if unit == FIFF.FIFF_UNIT_PX: # pixel on screen
|
||||
_range = [[0, height], [0, width]]
|
||||
bins_x, bins_y = width, height
|
||||
extent = [0, width, height, 0]
|
||||
elif unit == FIFF.FIFF_UNIT_RAD: # radians of visual angle
|
||||
if not calibration:
|
||||
raise ValueError(
|
||||
"If gaze data are in Radians, you must provide a"
|
||||
" calibration instance to plot heatmaps."
|
||||
)
|
||||
width, height = get_screen_visual_angle(calibration)
|
||||
x_range = [-width / 2, width / 2]
|
||||
y_range = [-height / 2, height / 2]
|
||||
_range = [y_range, x_range]
|
||||
extent = (x_range[0], x_range[1], y_range[0], y_range[1])
|
||||
bins_x, bins_y = calibration["screen_resolution"]
|
||||
|
||||
hist, _, _ = np.histogram2d(
|
||||
canvas[1, :],
|
||||
canvas[0, :],
|
||||
bins=(bins_y, bins_x),
|
||||
range=_range,
|
||||
)
|
||||
# Convert density from samples to seconds
|
||||
hist /= epochs.info["sfreq"]
|
||||
# Smooth the heatmap
|
||||
if sigma:
|
||||
hist = gaussian_filter(hist, sigma=sigma)
|
||||
|
||||
return _plot_heatmap_array(
|
||||
hist,
|
||||
width=width,
|
||||
height=height,
|
||||
cmap=cmap,
|
||||
alpha=alpha,
|
||||
vmin=vlim[0],
|
||||
vmax=vlim[1],
|
||||
extent=extent,
|
||||
axes=axes,
|
||||
show=show,
|
||||
)
|
||||
|
||||
|
||||
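# Illustrative usage sketch for ``plot_gaze``: ``epochs`` is assumed to be an
# existing mne.Epochs object with eyegaze channels in pixel units (e.g. from a
# recording read with ``mne.io.read_raw_eyelink`` on a 1920x1080 screen).
fig = plot_gaze(epochs, width=1920, height=1080, sigma=25, cmap="Reds")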
def _plot_heatmap_array(
|
||||
data,
|
||||
width,
|
||||
height,
|
||||
*,
|
||||
cmap=None,
|
||||
alpha=None,
|
||||
vmin=None,
|
||||
vmax=None,
|
||||
extent=None,
|
||||
axes=None,
|
||||
show=True,
|
||||
):
|
||||
"""Plot a heatmap of eyetracking gaze data from a numpy array."""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Prepare axes
|
||||
if axes is not None:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
_validate_type(axes, Axes, "axes")
|
||||
ax = axes
|
||||
fig = ax.get_figure()
|
||||
else:
|
||||
fig, ax = plt.subplots(constrained_layout=True)
|
||||
|
||||
ax.set_title("Gaze heatmap")
|
||||
ax.set_xlabel("X position")
|
||||
ax.set_ylabel("Y position")
|
||||
|
||||
# Prepare the heatmap
|
||||
alphas = 1 if alpha is None else alpha
|
||||
vmin = np.nanmin(data) if vmin is None else vmin
|
||||
vmax = np.nanmax(data) if vmax is None else vmax
|
||||
if extent is None:
|
||||
extent = [0, width, height, 0]
|
||||
|
||||
# Plot heatmap
|
||||
im = ax.imshow(
|
||||
data,
|
||||
aspect="equal",
|
||||
cmap=cmap,
|
||||
alpha=alphas,
|
||||
extent=extent,
|
||||
origin="upper",
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
)
|
||||
|
||||
# Prepare the colorbar
|
||||
fig.colorbar(im, ax=ax, shrink=0.6, label="Dwell time (seconds)")
|
||||
plt_show(show)
|
||||
return fig
|
||||
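# Minimal standalone sketch of the density computation used by ``plot_gaze``:
# bin gaze samples into a screen-sized 2D histogram, convert counts to dwell
# time in seconds, then smooth. Synthetic numbers; real data come from the
# eyegaze channels of an Epochs object.
import numpy as np
from scipy.ndimage import gaussian_filter

sfreq, width, height = 100.0, 1920, 1080
rng = np.random.default_rng(0)
x = rng.uniform(0, width, 5000)
y = rng.uniform(0, height, 5000)
hist, _, _ = np.histogram2d(
    y, x, bins=(height, width), range=[[0, height], [0, width]]
)
hist /= sfreq                      # samples -> seconds of dwell time
hist = gaussian_filter(hist, sigma=25)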
1460
mne/viz/ica.py
Normal file
File diff suppressed because it is too large
Load Diff
1645
mne/viz/misc.py
Normal file
File diff suppressed because it is too large
Load Diff
135
mne/viz/montage.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
"""Functions to plot EEG sensor montages or digitizer montages."""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from .._fiff._digitization import _get_fid_coords
|
||||
from .._fiff.meas_info import create_info
|
||||
from ..utils import _check_option, _validate_type, logger, verbose
|
||||
from .utils import plot_sensors
|
||||
|
||||
|
||||
@verbose
|
||||
def plot_montage(
|
||||
montage,
|
||||
*,
|
||||
scale=None,
|
||||
scale_factor=None,
|
||||
show_names=True,
|
||||
kind="topomap",
|
||||
show=True,
|
||||
sphere=None,
|
||||
axes=None,
|
||||
verbose=None,
|
||||
):
|
||||
"""Plot a montage.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
montage : instance of DigMontage
|
||||
The montage to visualize.
|
||||
scale : float
|
||||
Determines the scale of the channel points and labels; values < 1 will scale
|
||||
down, whereas values > 1 will scale up. Defaults to None, which implies 1.
|
||||
scale_factor : float
|
||||
Determines the size of the points. Deprecated, use scale instead.
|
||||
show_names : bool | list
|
||||
Whether to display all channel names. If a list, only the channel
|
||||
names in the list are shown. Defaults to True.
|
||||
kind : str
|
||||
Whether to plot the montage as '3d' or 'topomap' (default).
|
||||
show : bool
|
||||
Show figure if True.
|
||||
%(sphere_topomap_auto)s
|
||||
%(axes_montage)s
|
||||
|
||||
.. versionadded:: 1.4
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib.figure.Figure
|
||||
The figure object.
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..channels import DigMontage, make_dig_montage
|
||||
|
||||
if scale_factor is not None:
|
||||
msg = "scale_factor has been deprecated and will be removed. Use scale instead."
|
||||
if scale is not None:
|
||||
raise ValueError(
|
||||
" ".join(["scale and scale_factor cannot be used together.", msg])
|
||||
)
|
||||
logger.info(msg)
|
||||
if scale is None:
|
||||
scale = 1
|
||||
|
||||
_check_option("kind", kind, ["topomap", "3d"])
|
||||
_validate_type(montage, DigMontage, item_name="montage")
|
||||
ch_names = montage.ch_names
|
||||
title = None
|
||||
|
||||
if len(ch_names) == 0:
|
||||
raise RuntimeError("No valid channel positions found.")
|
||||
|
||||
pos = np.array(list(montage._get_ch_pos().values()))
|
||||
|
||||
dists = cdist(pos, pos)
|
||||
|
||||
# only consider upper triangular part by setting the rest to np.nan
|
||||
dists[np.tril_indices(dists.shape[0])] = np.nan
|
||||
dupes = np.argwhere(np.isclose(dists, 0))
|
||||
if dupes.any():
|
||||
montage = deepcopy(montage)
|
||||
n_chans = pos.shape[0]
|
||||
n_dupes = dupes.shape[0]
|
||||
idx = np.setdiff1d(np.arange(len(pos)), dupes[:, 1]).tolist()
|
||||
logger.info(f"{n_dupes} duplicate electrode labels found:")
|
||||
logger.info(", ".join([ch_names[d[0]] + "/" + ch_names[d[1]] for d in dupes]))
|
||||
logger.info(f"Plotting {n_chans - n_dupes} unique labels.")
|
||||
ch_names = [ch_names[i] for i in idx]
|
||||
ch_pos = dict(zip(ch_names, pos[idx, :]))
|
||||
# XXX: this might cause trouble if montage was originally in head
|
||||
fid, _ = _get_fid_coords(montage.dig)
|
||||
montage = make_dig_montage(ch_pos=ch_pos, **fid)
|
||||
|
||||
info = create_info(ch_names, sfreq=256, ch_types="eeg")
|
||||
info.set_montage(montage, on_missing="ignore")
|
||||
fig = plot_sensors(
|
||||
info,
|
||||
kind=kind,
|
||||
show_names=show_names,
|
||||
show=show,
|
||||
title=title,
|
||||
sphere=sphere,
|
||||
axes=axes,
|
||||
)
|
||||
|
||||
if scale_factor is not None:
|
||||
# scale points
|
||||
collection = fig.axes[0].collections[0]
|
||||
collection.set_sizes([scale_factor])
|
||||
elif scale is not None:
|
||||
# scale points
|
||||
collection = fig.axes[0].collections[0]
|
||||
collection.set_sizes([scale * 10])
|
||||
|
||||
# scale labels
|
||||
labels = fig.findobj(match=plt.Text)
|
||||
x_label, y_label = fig.axes[0].xaxis.label, fig.axes[0].yaxis.label
|
||||
z_label = fig.axes[0].zaxis.label if kind == "3d" else None
|
||||
tick_labels = fig.axes[0].get_xticklabels() + fig.axes[0].get_yticklabels()
|
||||
if kind == "3d":
|
||||
tick_labels += fig.axes[0].get_zticklabels()
|
||||
for label in labels:
|
||||
if label not in [x_label, y_label, z_label] + tick_labels:
|
||||
label.set_fontsize(label.get_fontsize() * scale)
|
||||
|
||||
return fig
|
||||
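# Illustrative sketch for ``plot_montage``: visualize a built-in standard
# montage, scaling points and labels up slightly (values are arbitrary).
import mne

_example_montage = mne.channels.make_standard_montage("standard_1020")
_example_fig = plot_montage(
    _example_montage, kind="topomap", show_names=True, scale=1.5, show=False
)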
635
mne/viz/raw.py
Normal file
@@ -0,0 +1,635 @@
|
||||
"""Functions to plot raw M/EEG data."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_to_idx, pick_channels, pick_types
|
||||
from ..defaults import _handle_default
|
||||
from ..filter import create_filter
|
||||
from ..utils import _check_option, _get_stim_channel, _validate_type, legacy, verbose
|
||||
from ..utils.spectrum import _split_psd_kwargs
|
||||
from .utils import (
|
||||
_check_cov,
|
||||
_compute_scalings,
|
||||
_get_channel_plotting_order,
|
||||
_handle_decim,
|
||||
_handle_precompute,
|
||||
_make_event_color_dict,
|
||||
_shorten_path_from_middle,
|
||||
)
|
||||
|
||||
_RAW_CLIP_DEF = 1.5
|
||||
|
||||
|
||||
@verbose
|
||||
def plot_raw(
|
||||
raw,
|
||||
events=None,
|
||||
duration=10.0,
|
||||
start=0.0,
|
||||
n_channels=20,
|
||||
bgcolor="w",
|
||||
color=None,
|
||||
bad_color="lightgray",
|
||||
event_color="cyan",
|
||||
scalings=None,
|
||||
remove_dc=True,
|
||||
order=None,
|
||||
show_options=False,
|
||||
title=None,
|
||||
show=True,
|
||||
block=False,
|
||||
highpass=None,
|
||||
lowpass=None,
|
||||
filtorder=4,
|
||||
clipping=_RAW_CLIP_DEF,
|
||||
show_first_samp=False,
|
||||
proj=True,
|
||||
group_by="type",
|
||||
butterfly=False,
|
||||
decim="auto",
|
||||
noise_cov=None,
|
||||
event_id=None,
|
||||
show_scrollbars=True,
|
||||
show_scalebars=True,
|
||||
time_format="float",
|
||||
precompute=None,
|
||||
use_opengl=None,
|
||||
picks=None,
|
||||
*,
|
||||
theme=None,
|
||||
overview_mode=None,
|
||||
splash=True,
|
||||
verbose=None,
|
||||
):
|
||||
"""Plot raw data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data to plot.
|
||||
events : array | None
|
||||
Events to show with vertical bars.
|
||||
duration : float
|
||||
Time window (s) to plot. The lesser of this value and the duration
|
||||
of the raw file will be used.
|
||||
start : float
|
||||
Initial time to show (can be changed dynamically once plotted). If
|
||||
show_first_samp is True, then it is taken relative to
|
||||
``raw.first_samp``.
|
||||
n_channels : int
|
||||
Number of channels to plot at once. Defaults to 20. The lesser of
|
||||
``n_channels`` and ``len(raw.ch_names)`` will be shown.
|
||||
Has no effect if ``order`` is 'position', 'selection' or 'butterfly'.
|
||||
bgcolor : color object
|
||||
Color of the background.
|
||||
color : dict | color object | None
|
||||
Color for the data traces. If None, defaults to::
|
||||
|
||||
dict(mag='darkblue', grad='b', eeg='k', eog='k', ecg='m',
|
||||
emg='k', ref_meg='steelblue', misc='k', stim='k',
|
||||
resp='k', chpi='k')
|
||||
|
||||
bad_color : color object
|
||||
Color to make bad channels.
|
||||
%(event_color)s
|
||||
Defaults to ``'cyan'``.
|
||||
%(scalings)s
|
||||
remove_dc : bool
|
||||
If True remove DC component when plotting data.
|
||||
order : array of int | None
|
||||
Order in which to plot data. If the array is shorter than the number of
|
||||
channels, only the given channels are plotted. If None (default), all
|
||||
channels are plotted. If ``group_by`` is ``'position'`` or
|
||||
``'selection'``, the ``order`` parameter is used only for selecting the
|
||||
channels to be plotted.
|
||||
show_options : bool
|
||||
If True, a dialog for options related to projection is shown.
|
||||
title : str | None
|
||||
The title of the window. If None, the filename of the raw object
(or '<unknown>' if no filename is available) will be displayed as the title.
|
||||
show : bool
|
||||
Show figure if True.
|
||||
block : bool
|
||||
Whether to halt program execution until the figure is closed.
|
||||
Useful for setting bad channels on the fly by clicking on a line.
|
||||
May not work on all systems / platforms.
|
||||
(Qt backend only) If you run from a script, this needs to
be ``True``, or a Qt event loop needs to be started somewhere
else in the script (e.g., if you want to embed the browser
inside another Qt application).
|
||||
highpass : float | None
|
||||
Highpass to apply when displaying data.
|
||||
lowpass : float | None
|
||||
Lowpass to apply when displaying data.
|
||||
If highpass > lowpass, a bandstop rather than bandpass filter
|
||||
will be applied.
|
||||
filtorder : int
|
||||
Filtering order. 0 will use FIR filtering with MNE defaults.
|
||||
Other values will construct an IIR filter of the given order
|
||||
and apply it with :func:`~scipy.signal.filtfilt` (making the effective
|
||||
order twice ``filtorder``). Filtering may produce some edge artifacts
|
||||
(at the left and right edges) of the signals during display.
|
||||
|
||||
.. versionchanged:: 0.18
|
||||
Support for ``filtorder=0`` to use FIR filtering.
|
||||
clipping : str | float | None
|
||||
If None, channels are allowed to exceed their designated bounds in
|
||||
the plot. If "clamp", then values are clamped to the appropriate
|
||||
range for display, creating step-like artifacts. If "transparent",
|
||||
then excessive values are not shown, creating gaps in the traces.
|
||||
If float, clipping occurs for values beyond the ``clipping`` multiple
|
||||
of their dedicated range, so ``clipping=1.`` is an alias for
|
||||
``clipping='transparent'``.
|
||||
|
||||
.. versionchanged:: 0.21
|
||||
Support for float, and default changed from None to 1.5.
|
||||
show_first_samp : bool
|
||||
If True, show time axis relative to the ``raw.first_samp``.
|
||||
proj : bool
|
||||
Whether to apply projectors prior to plotting (default is ``True``).
|
||||
Individual projectors can be enabled/disabled interactively (see
|
||||
Notes). This argument only affects the plot; use ``raw.apply_proj()``
|
||||
to modify the data stored in the Raw object.
|
||||
%(group_by_browse)s
|
||||
butterfly : bool
|
||||
Whether to start in butterfly mode. Defaults to False.
|
||||
decim : int | 'auto'
|
||||
Amount to decimate the data during display for speed purposes.
|
||||
You should only decimate if the data are sufficiently low-passed,
|
||||
otherwise aliasing can occur. The 'auto' mode (default) uses
|
||||
the decimation that results in a sampling rate at least three times
|
||||
larger than ``min(info['lowpass'], lowpass)`` (e.g., a 40 Hz lowpass
|
||||
will result in at least a 120 Hz displayed sample rate).
|
||||
noise_cov : instance of Covariance | str | None
|
||||
Noise covariance used to whiten the data while plotting.
|
||||
Whitened data channels are scaled by ``scalings['whitened']``,
|
||||
and their channel names are shown in italic.
|
||||
Can be a string to load a covariance from disk.
|
||||
See also :meth:`mne.Evoked.plot_white` for additional inspection
|
||||
of noise covariance properties when whitening evoked data.
|
||||
For data processed with SSS, the effective dependence between
|
||||
magnetometers and gradiometers may introduce differences in scaling,
|
||||
consider using :meth:`mne.Evoked.plot_white`.
|
||||
|
||||
.. versionadded:: 0.16.0
|
||||
event_id : dict | None
|
||||
Event IDs used to show at event markers (default None shows
|
||||
the event numbers).
|
||||
|
||||
.. versionadded:: 0.16.0
|
||||
%(show_scrollbars)s
|
||||
%(show_scalebars)s
|
||||
|
||||
.. versionadded:: 0.20.0
|
||||
%(time_format)s
|
||||
%(precompute)s
|
||||
%(use_opengl)s
|
||||
%(picks_all)s
|
||||
%(theme_pg)s
|
||||
|
||||
.. versionadded:: 1.0
|
||||
%(overview_mode)s
|
||||
|
||||
.. versionadded:: 1.1
|
||||
%(splash)s
|
||||
|
||||
.. versionadded:: 1.6
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
%(browser)s
|
||||
|
||||
Notes
|
||||
-----
|
||||
The arrow keys (up/down/left/right) can typically be used to navigate
|
||||
between channels and time ranges, but this depends on the backend
|
||||
matplotlib is configured to use (e.g., mpl.use('TkAgg') should work). The
|
||||
left/right arrows will scroll by 25%% of ``duration``, whereas
|
||||
shift+left/shift+right will scroll by 100%% of ``duration``. The scaling
|
||||
can be adjusted with - and + (or =) keys. The viewport dimensions can be
|
||||
adjusted with page up/page down and home/end keys. Full screen mode can be
|
||||
toggled with the F11 key, and scrollbars can be hidden/shown by pressing
|
||||
'z'. Right-click a channel label to view its location. To mark or un-mark a
|
||||
channel as bad, click on a channel label or a channel trace. The changes
|
||||
will be reflected immediately in the raw object's ``raw.info['bads']``
|
||||
entry.
|
||||
|
||||
If projectors are present, a button labelled "Prj" in the lower right
|
||||
corner of the plot window opens a secondary control window, which allows
|
||||
enabling/disabling specific projectors individually. This provides a means
|
||||
of interactively observing how each projector would affect the raw data if
|
||||
it were applied.
|
||||
|
||||
Annotation mode is toggled by pressing 'a', butterfly mode by pressing
|
||||
'b', and whitening mode (when ``noise_cov is not None``) by pressing 'w'.
|
||||
By default, the channel means are removed when ``remove_dc`` is set to
|
||||
``True``. This flag can be toggled by pressing 'd'.
|
||||
|
||||
%(notes_2d_backend)s
|
||||
"""
|
||||
from ..annotations import _annotations_starts_stops
|
||||
from ..io import BaseRaw
|
||||
from ._figure import _get_browser
|
||||
|
||||
info = raw.info.copy()
|
||||
sfreq = info["sfreq"]
|
||||
projs = info["projs"]
|
||||
# this will be an attr for which projectors are currently "on" in the plot
|
||||
projs_on = np.full_like(projs, proj, dtype=bool)
|
||||
# disable projs in info if user doesn't want to see them right away
|
||||
if not proj:
|
||||
with info._unlock():
|
||||
info["projs"] = list()
|
||||
|
||||
# handle defaults / check arg validity
|
||||
color = _handle_default("color", color)
|
||||
scalings = _compute_scalings(scalings, raw, remove_dc=remove_dc, duration=duration)
|
||||
if scalings["whitened"] == "auto":
|
||||
scalings["whitened"] = 1.0
|
||||
_validate_type(raw, BaseRaw, "raw", "Raw")
|
||||
decim, picks_data = _handle_decim(info, decim, lowpass)
|
||||
noise_cov = _check_cov(noise_cov, info)
|
||||
units = _handle_default("units", None)
|
||||
unit_scalings = _handle_default("scalings", None)
|
||||
_check_option("group_by", group_by, ("selection", "position", "original", "type"))
|
||||
|
||||
# clipping
|
||||
_validate_type(clipping, (None, "numeric", str), "clipping")
|
||||
if isinstance(clipping, str):
|
||||
_check_option(
|
||||
"clipping", clipping, ("clamp", "transparent"), extra="when a string"
|
||||
)
|
||||
clipping = 1.0 if clipping == "transparent" else clipping
|
||||
elif clipping is not None:
|
||||
clipping = float(clipping)
|
||||
|
||||
# be forgiving if user asks for too much time
|
||||
duration = min(raw.times[-1], float(duration))
|
||||
|
||||
# determine IIR filtering parameters
|
||||
if highpass is not None and highpass <= 0:
|
||||
raise ValueError(f"highpass must be > 0, got {highpass}")
|
||||
if highpass is None and lowpass is None:
|
||||
ba = filt_bounds = None
|
||||
else:
|
||||
filtorder = int(filtorder)
|
||||
if filtorder == 0:
|
||||
method = "fir"
|
||||
iir_params = None
|
||||
else:
|
||||
method = "iir"
|
||||
iir_params = dict(order=filtorder, output="sos", ftype="butter")
|
||||
ba = create_filter(
|
||||
np.zeros((1, int(round(duration * sfreq)))),
|
||||
sfreq,
|
||||
highpass,
|
||||
lowpass,
|
||||
method=method,
|
||||
iir_params=iir_params,
|
||||
)
|
||||
filt_bounds = _annotations_starts_stops(
|
||||
raw, ("edge", "bad_acq_skip"), invert=True
|
||||
)
|
||||
|
||||
# compute event times in seconds
|
||||
if events is not None:
|
||||
event_times = (events[:, 0] - raw.first_samp).astype(float)
|
||||
event_times /= sfreq
|
||||
event_nums = events[:, 2]
|
||||
else:
|
||||
event_times = event_nums = None
|
||||
|
||||
# determine trace order
|
||||
ch_names = np.array(raw.ch_names)
|
||||
ch_types = np.array(raw.get_channel_types())
|
||||
|
||||
picks = _picks_to_idx(info, picks, none="all", exclude=())
|
||||
order = _get_channel_plotting_order(order, ch_types, picks=picks)
|
||||
n_channels = min(info["nchan"], n_channels, len(order))
|
||||
# adjust order based on channel selection, if needed
|
||||
selections = None
|
||||
if group_by in ("selection", "position"):
|
||||
selections = _setup_channel_selections(raw, group_by, order)
|
||||
order = np.concatenate(list(selections.values()))
|
||||
default_selection = list(selections)[0]
|
||||
n_channels = len(selections[default_selection])
|
||||
assert isinstance(order, np.ndarray)
|
||||
assert order.dtype.kind == "i"
|
||||
if order.size == 0:
|
||||
raise RuntimeError("No channels found to plot")
|
||||
|
||||
# handle event colors
|
||||
event_color_dict = _make_event_color_dict(event_color, events, event_id)
|
||||
|
||||
# handle first_samp
|
||||
first_time = raw._first_time if show_first_samp else 0
|
||||
start += first_time
|
||||
event_id_rev = {v: k for k, v in (event_id or {}).items()}
|
||||
|
||||
# generate window title; allow instances without a filename (e.g., ICA)
|
||||
if title is None:
|
||||
title = "<unknown>"
|
||||
fnames = list(raw.filenames)  # mutable copy of the filenames
|
||||
if len(fnames):
|
||||
title = fnames.pop(0)
|
||||
extra = f" ... (+ {len(fnames)} more)" if len(fnames) else ""
|
||||
title = f"{title}{extra}"
|
||||
if len(title) > 60:
|
||||
title = _shorten_path_from_middle(title)
|
||||
elif not isinstance(title, str):
|
||||
raise TypeError(f"title must be None or a string, got a {type(title)}")
|
||||
|
||||
# gather parameters and initialize figure
|
||||
_validate_type(use_opengl, (bool, None), "use_opengl")
|
||||
precompute = _handle_precompute(precompute)
|
||||
params = dict(
|
||||
inst=raw,
|
||||
info=info,
|
||||
# channels and channel order
|
||||
ch_names=ch_names,
|
||||
ch_types=ch_types,
|
||||
ch_order=order,
|
||||
picks=order[:n_channels],
|
||||
n_channels=n_channels,
|
||||
picks_data=picks_data,
|
||||
group_by=group_by,
|
||||
ch_selections=selections,
|
||||
# time
|
||||
t_start=start,
|
||||
duration=duration,
|
||||
n_times=raw.n_times,
|
||||
first_time=first_time,
|
||||
time_format=time_format,
|
||||
decim=decim,
|
||||
# events
|
||||
event_color_dict=event_color_dict,
|
||||
event_times=event_times,
|
||||
event_nums=event_nums,
|
||||
event_id_rev=event_id_rev,
|
||||
# preprocessing
|
||||
projs=projs,
|
||||
projs_on=projs_on,
|
||||
apply_proj=proj,
|
||||
remove_dc=remove_dc,
|
||||
filter_coefs=ba,
|
||||
filter_bounds=filt_bounds,
|
||||
noise_cov=noise_cov,
|
||||
# scalings
|
||||
scalings=scalings,
|
||||
units=units,
|
||||
unit_scalings=unit_scalings,
|
||||
# colors
|
||||
ch_color_bad=bad_color,
|
||||
ch_color_dict=color,
|
||||
# display
|
||||
butterfly=butterfly,
|
||||
clipping=clipping,
|
||||
scrollbars_visible=show_scrollbars,
|
||||
scalebars_visible=show_scalebars,
|
||||
window_title=title,
|
||||
bgcolor=bgcolor,
|
||||
# Qt-specific
|
||||
precompute=precompute,
|
||||
use_opengl=use_opengl,
|
||||
theme=theme,
|
||||
overview_mode=overview_mode,
|
||||
splash=splash,
|
||||
)
|
||||
|
||||
fig = _get_browser(show=show, block=block, **params)
|
||||
|
||||
return fig
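# Illustrative usage sketch (not part of the module API): build a small simulated
# Raw object and open it in the interactive browser. The channel names, sampling
# rate, and durations below are arbitrary example values.
def _example_plot_raw():
    import numpy as np

    import mne

    info = mne.create_info(
        ["EEG 001", "EEG 002", "STI 014"], sfreq=256.0, ch_types=["eeg", "eeg", "stim"]
    )
    data = np.random.RandomState(0).randn(3, 256 * 10) * 1e-6  # 10 s of fake EEG
    raw = mne.io.RawArray(data, info)
    # Show 5 s of data, two channels at a time.
    return mne.viz.plot_raw(raw, duration=5.0, n_channels=2)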
|
||||
|
||||
|
||||
@legacy(alt="Raw.compute_psd().plot()")
|
||||
@verbose
|
||||
def plot_raw_psd(
|
||||
raw,
|
||||
fmin=0,
|
||||
fmax=np.inf,
|
||||
tmin=None,
|
||||
tmax=None,
|
||||
proj=False,
|
||||
n_fft=None,
|
||||
n_overlap=0,
|
||||
reject_by_annotation=True,
|
||||
picks=None,
|
||||
ax=None,
|
||||
color="black",
|
||||
xscale="linear",
|
||||
area_mode="std",
|
||||
area_alpha=0.33,
|
||||
dB=True,
|
||||
estimate="power",
|
||||
show=True,
|
||||
n_jobs=None,
|
||||
average=False,
|
||||
line_alpha=None,
|
||||
spatial_colors=True,
|
||||
sphere=None,
|
||||
window="hamming",
|
||||
exclude="bads",
|
||||
verbose=None,
|
||||
):
|
||||
"""%(plot_psd_doc)s.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw object.
|
||||
%(fmin_fmax_psd)s
|
||||
%(tmin_tmax_psd)s
|
||||
%(proj_psd)s
|
||||
n_fft : int | None
|
||||
Number of points to use in Welch FFT calculations. Default is ``None``,
|
||||
which uses the minimum of 2048 and the number of time points.
|
||||
n_overlap : int
|
||||
The number of points of overlap between blocks. The default value
|
||||
is 0 (no overlap).
|
||||
%(reject_by_annotation_psd)s
|
||||
%(picks_good_data_noref)s
|
||||
%(ax_plot_psd)s
|
||||
%(color_plot_psd)s
|
||||
%(xscale_plot_psd)s
|
||||
%(area_mode_plot_psd)s
|
||||
%(area_alpha_plot_psd)s
|
||||
%(dB_plot_psd)s
|
||||
%(estimate_plot_psd)s
|
||||
%(show)s
|
||||
%(n_jobs)s
|
||||
%(average_plot_psd)s
|
||||
%(line_alpha_plot_psd)s
|
||||
%(spatial_colors_psd)s
|
||||
%(sphere_topomap_auto)s
|
||||
%(window_psd)s
|
||||
|
||||
.. versionadded:: 0.22.0
|
||||
exclude : list of str | 'bads'
|
||||
Channel names to exclude from being shown. If 'bads', the bad channels
|
||||
are excluded. Pass an empty list to plot all channels (including
|
||||
channels marked "bad", if any).
|
||||
|
||||
.. versionadded:: 0.24.0
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of Figure
|
||||
Figure with frequency spectra of the data channels.
|
||||
|
||||
Notes
|
||||
-----
|
||||
%(notes_plot_*_psd_func)s
|
||||
"""
|
||||
from ..time_frequency import Spectrum
|
||||
|
||||
init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot)
|
||||
return raw.compute_psd(**init_kw).plot(**plot_kw)
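# Illustrative sketch of the recommended non-legacy equivalent (see the ``alt``
# given to ``@legacy`` above): compute a Spectrum object once, then plot it.
# ``raw`` is assumed to be an existing Raw instance; the frequency band and FFT
# length are arbitrary example values.
def _example_spectrum_plot(raw):
    spectrum = raw.compute_psd(fmin=1.0, fmax=40.0, n_fft=2048)
    return spectrum.plot(average=True, dB=True)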
|
||||
|
||||
|
||||
@legacy(alt="Raw.compute_psd().plot_topo()")
|
||||
@verbose
|
||||
def plot_raw_psd_topo(
|
||||
raw,
|
||||
tmin=0.0,
|
||||
tmax=None,
|
||||
fmin=0.0,
|
||||
fmax=100.0,
|
||||
proj=False,
|
||||
*,
|
||||
n_fft=2048,
|
||||
n_overlap=0,
|
||||
dB=True,
|
||||
layout=None,
|
||||
color="w",
|
||||
fig_facecolor="k",
|
||||
axis_facecolor="k",
|
||||
axes=None,
|
||||
block=False,
|
||||
show=True,
|
||||
n_jobs=None,
|
||||
verbose=None,
|
||||
):
|
||||
"""Plot power spectral density, separately for each channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of io.Raw
|
||||
The raw instance to use.
|
||||
%(tmin_tmax_psd)s
|
||||
%(fmin_fmax_psd_topo)s
|
||||
%(proj_psd)s
|
||||
n_fft : int
|
||||
Number of points to use in Welch FFT calculations. Defaults to 2048.
|
||||
n_overlap : int
|
||||
The number of points of overlap between blocks. Defaults to 0
|
||||
(no overlap).
|
||||
%(dB_spectrum_plot_topo)s
|
||||
layout : instance of Layout | None
|
||||
Layout instance specifying sensor positions (does not need to be
|
||||
specified for Neuromag data). If ``None`` (default), the layout is
|
||||
inferred from the data.
|
||||
color : str | tuple
|
||||
A matplotlib-compatible color to use for the curves. Defaults to white.
|
||||
fig_facecolor : str | tuple
|
||||
A matplotlib-compatible color to use for the figure background.
|
||||
Defaults to black.
|
||||
axis_facecolor : str | tuple
|
||||
A matplotlib-compatible color to use for the axis background.
|
||||
Defaults to black.
|
||||
%(axes_spectrum_plot_topo)s
|
||||
block : bool
|
||||
Whether to halt program execution until the figure is closed.
|
||||
May not work on all systems / platforms. Defaults to False.
|
||||
%(show)s
|
||||
%(n_jobs)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib.figure.Figure
|
||||
Figure distributing one image per channel across sensor topography.
|
||||
"""
|
||||
from ..time_frequency import Spectrum
|
||||
|
||||
init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot_topo)
|
||||
return raw.compute_psd(**init_kw).plot_topo(**plot_kw)
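# Illustrative sketch of the recommended non-legacy equivalent (see the ``alt``
# given to ``@legacy`` above). ``raw`` is assumed to be an existing Raw instance
# with sensor positions available; the upper frequency is an arbitrary example.
def _example_spectrum_plot_topo(raw):
    spectrum = raw.compute_psd(fmax=100.0)
    return spectrum.plot_topo(dB=True, fig_facecolor="k", axis_facecolor="k")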
|
||||
|
||||
|
||||
def _setup_channel_selections(raw, kind, order):
|
||||
"""Get dictionary of channel groupings."""
|
||||
from ..channels import (
|
||||
_EEG_SELECTIONS,
|
||||
_SELECTIONS,
|
||||
_divide_to_regions,
|
||||
read_vectorview_selection,
|
||||
)
|
||||
|
||||
_check_option("group_by", kind, ("position", "selection"))
|
||||
if kind == "position":
|
||||
selections_dict = _divide_to_regions(raw.info)
|
||||
keys = _SELECTIONS[1:] # omit 'Vertex'
|
||||
else: # kind == 'selection'
|
||||
from ..channels.channels import _get_ch_info
|
||||
|
||||
(
|
||||
has_vv_mag,
|
||||
has_vv_grad,
|
||||
*_,
|
||||
has_neuromag_122_grad,
|
||||
has_csd_coils,
|
||||
) = _get_ch_info(raw.info)
|
||||
if not (has_vv_grad or has_vv_mag or has_neuromag_122_grad):
|
||||
raise ValueError(
|
||||
"order='selection' only works for Neuromag "
|
||||
"data. Use order='position' instead."
|
||||
)
|
||||
selections_dict = OrderedDict()
|
||||
# get stim channel (if any)
|
||||
stim_ch = _get_stim_channel(None, raw.info, raise_error=False)
|
||||
stim_ch = stim_ch if len(stim_ch) else [""]
|
||||
stim_ch = pick_channels(raw.ch_names, stim_ch, ordered=False)
|
||||
# loop over regions
|
||||
keys = np.concatenate([_SELECTIONS, _EEG_SELECTIONS])
|
||||
for key in keys:
|
||||
channels = read_vectorview_selection(key, info=raw.info)
|
||||
picks = pick_channels(raw.ch_names, channels, ordered=False)
|
||||
picks = np.intersect1d(picks, order)
|
||||
if not len(picks):
|
||||
continue # omit empty selections
|
||||
selections_dict[key] = np.concatenate([picks, stim_ch])
|
||||
# add misc channels
|
||||
misc = pick_types(
|
||||
raw.info,
|
||||
meg=False,
|
||||
eeg=False,
|
||||
stim=True,
|
||||
eog=True,
|
||||
ecg=True,
|
||||
emg=True,
|
||||
ref_meg=False,
|
||||
misc=True,
|
||||
resp=True,
|
||||
chpi=True,
|
||||
exci=True,
|
||||
ias=True,
|
||||
syst=True,
|
||||
seeg=False,
|
||||
bio=True,
|
||||
ecog=False,
|
||||
fnirs=False,
|
||||
dbs=False,
|
||||
temperature=True,
|
||||
gsr=True,
|
||||
exclude=(),
|
||||
)
|
||||
if len(misc) and np.isin(misc, order).any():
|
||||
selections_dict["Misc"] = misc
|
||||
return selections_dict
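# Illustrative sketch (not part of the module API): ``group_by='selection'``
# relies on the named Vectorview channel groups read by the public helper below;
# "Left-temporal" is one of the standard selection names.
def _example_vectorview_selection():
    import mne

    return mne.read_vectorview_selection("Left-temporal")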
1309 mne/viz/topo.py Normal file (diff suppressed because it is too large)
4102 mne/viz/topomap.py Normal file (diff suppressed because it is too large)
480 mne/viz/ui_events.py Normal file
@@ -0,0 +1,480 @@
|
||||
"""
|
||||
Event API for inter-figure communication.
|
||||
|
||||
The event API allows figures to communicate with each other, such that a change
|
||||
in one figure can trigger a change in another figure. For example, moving the
|
||||
time cursor in one plot can update the current time in another plot. Another
|
||||
scenario is two drawing routines drawing into the same window, using events to
|
||||
stay in sync.
|
||||
|
||||
Authors: Marijn van Vliet <w.m.vanvliet@gmail.com>
|
||||
"""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from __future__ import annotations # only needed for Python ≤ 3.9
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
|
||||
from matplotlib.colors import Colormap
|
||||
|
||||
from ..utils import _validate_type, fill_doc, logger, verbose, warn
|
||||
|
||||
# Global dict {fig: channel} containing all currently active event channels.
|
||||
_event_channels = weakref.WeakKeyDictionary()
|
||||
|
||||
# The event channels of figures can be linked together. This dict keeps track
|
||||
# of these links. Links are bi-directional, so if {fig1: fig2} exists, then so
|
||||
# must {fig2: fig1}.
|
||||
_event_channel_links = weakref.WeakKeyDictionary()
|
||||
|
||||
# Event channels that are temporarily disabled by the disable_ui_events context
|
||||
# manager.
|
||||
_disabled_event_channels = weakref.WeakSet()
|
||||
|
||||
# Regex pattern used when converting CamelCase to snake_case.
|
||||
# Detects all capital letters that are not at the beginning of a word.
|
||||
_camel_to_snake = re.compile(r"(?<!^)(?=[A-Z])")
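# Example: the pattern turns a CamelCase class name into a snake_case event name,
# e.g. ``ColormapRange`` -> ``colormap_range``.
assert _camel_to_snake.sub("_", "ColormapRange").lower() == "colormap_range"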
|
||||
|
||||
|
||||
# List of events
|
||||
@fill_doc
|
||||
class UIEvent:
|
||||
"""Abstract base class for all events.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
"""
|
||||
|
||||
source = None
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""The name of the event, which is the class name in snake case."""
|
||||
return _camel_to_snake.sub("_", self.__class__.__name__).lower()
|
||||
|
||||
|
||||
@fill_doc
|
||||
class FigureClosing(UIEvent):
|
||||
"""Indicates that the user has requested to close a figure.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@fill_doc
|
||||
class TimeChange(UIEvent):
|
||||
"""Indicates that the user has selected a time.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time : float
|
||||
The new time in seconds.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
time : float
|
||||
The new time in seconds.
|
||||
"""
|
||||
|
||||
time: float
|
||||
|
||||
|
||||
@dataclass
|
||||
@fill_doc
|
||||
class PlaybackSpeed(UIEvent):
|
||||
"""Indicates that the user has selected a different playback speed for videos.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
speed : float
|
||||
The new speed in seconds per frame.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
speed : float
|
||||
The new speed in seconds per frame.
|
||||
"""
|
||||
|
||||
speed: float
|
||||
|
||||
|
||||
@dataclass
|
||||
@fill_doc
|
||||
class ColormapRange(UIEvent):
|
||||
"""Indicates that the user has updated the bounds of the colormap.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kind : str
|
||||
Kind of colormap being updated. The Notes section of the drawing
|
||||
routine publishing this event should mention the possible kinds.
|
||||
ch_type : str
|
||||
Type of sensor the data originates from.
|
||||
%(fmin_fmid_fmax)s
|
||||
%(alpha)s
|
||||
cmap : str
|
||||
The colormap to use. Either string or matplotlib.colors.Colormap
|
||||
instance.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
kind : str
|
||||
Kind of colormap being updated. The Notes section of the drawing
|
||||
routine publishing this event should mention the possible kinds.
|
||||
ch_type : str
|
||||
Type of sensor the data originates from.
|
||||
unit : str
|
||||
The unit of the values.
|
||||
%(ui_event_name_source)s
|
||||
%(fmin_fmid_fmax)s
|
||||
%(alpha)s
|
||||
cmap : str
|
||||
The colormap to use. Either string or matplotlib.colors.Colormap
|
||||
instance.
|
||||
"""
|
||||
|
||||
kind: str
|
||||
ch_type: str | None = None
|
||||
fmin: float | None = None
|
||||
fmid: float | None = None
|
||||
fmax: float | None = None
|
||||
alpha: bool | None = None
|
||||
cmap: Colormap | str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@fill_doc
|
||||
class VertexSelect(UIEvent):
|
||||
"""Indicates that the user has selected a vertex.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hemi : str
|
||||
The hemisphere the vertex was selected on.
|
||||
Can be ``"lh"``, ``"rh"``, or ``"vol"``.
|
||||
vertex_id : int
|
||||
The vertex number (in the high resolution mesh) that was selected.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
hemi : str
|
||||
The hemisphere the vertex was selected on.
|
||||
Can be ``"lh"``, ``"rh"``, or ``"vol"``.
|
||||
vertex_id : int
|
||||
The vertex number (in the high resolution mesh) that was selected.
|
||||
"""
|
||||
|
||||
hemi: str
|
||||
vertex_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
@fill_doc
|
||||
class Contours(UIEvent):
|
||||
"""Indicates that the user has changed the contour lines.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kind : str
|
||||
The kind of contour lines being changed. The Notes section of the
|
||||
drawing routine publishing this event should mention the possible
|
||||
kinds.
|
||||
contours : list of float
|
||||
The new values at which contour lines need to be drawn.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
kind : str
|
||||
The kind of contour lines being changed. The Notes section of the
|
||||
drawing routine publishing this event should mention the possible
|
||||
kinds.
|
||||
contours : list of float
|
||||
The new values at which contour lines need to be drawn.
|
||||
"""
|
||||
|
||||
kind: str
|
||||
contours: list[float]
|
||||
|
||||
|
||||
def _get_event_channel(fig):
|
||||
"""Get the event channel associated with a figure.
|
||||
|
||||
If the event channel doesn't exist yet, it gets created and added to the
|
||||
global ``_event_channels`` dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure to get the event channel for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
channel : dict[event -> list]
|
||||
The event channel. An event channel is a dict mapping string event
names to a list of callbacks representing all subscribers to the
channel.
|
||||
"""
|
||||
import matplotlib
|
||||
|
||||
from ._brain import Brain
|
||||
from .evoked_field import EvokedField
|
||||
|
||||
# Create the event channel if it doesn't exist yet
|
||||
if fig not in _event_channels:
|
||||
# The channel itself is a dict mapping string event names to a list of
|
||||
# subscribers. No subscribers yet for this new event channel.
|
||||
_event_channels[fig] = dict()
|
||||
|
||||
weakfig = weakref.ref(fig)
|
||||
|
||||
# When the figure is closed, its associated event channel should be
|
||||
# deleted. This is a good time to set this up.
|
||||
def delete_event_channel(event=None, *, weakfig=weakfig):
|
||||
"""Delete the event channel (callback function)."""
|
||||
fig = weakfig()
|
||||
if fig is None:
|
||||
return
|
||||
publish(fig, event=FigureClosing()) # Notify subscribers of imminent close
|
||||
logger.debug(f"unlink(({fig})")
|
||||
unlink(fig) # Remove channel from the _event_channel_links dict
|
||||
if fig in _event_channels:
|
||||
logger.debug(f" del _event_channels[{fig}]")
|
||||
del _event_channels[fig]
|
||||
if fig in _disabled_event_channels:
|
||||
logger.debug(f" _disabled_event_channels.remove({fig})")
|
||||
_disabled_event_channels.remove(fig)
|
||||
|
||||
# Hook up the above callback function to the close event of the figure
|
||||
# window. How this is done exactly depends on the various figure types
|
||||
# MNE-Python has.
|
||||
_validate_type(fig, (matplotlib.figure.Figure, Brain, EvokedField), "fig")
|
||||
if isinstance(fig, matplotlib.figure.Figure):
|
||||
fig.canvas.mpl_connect("close_event", delete_event_channel)
|
||||
else:
|
||||
assert hasattr(fig, "_renderer") # figures like Brain, EvokedField, etc.
|
||||
fig._renderer._window_close_connect(delete_event_channel, after=False)
|
||||
|
||||
# Now the event channel exists for sure.
|
||||
return _event_channels[fig]
|
||||
|
||||
|
||||
@verbose
|
||||
def publish(fig, event, *, verbose=None):
|
||||
"""Publish an event to all subscribers of the figure's channel.
|
||||
|
||||
The figure's event channel and all linked event channels are searched for
|
||||
subscribers to the given event. Each subscriber provided a callback
|
||||
function when subscribing, so we call that.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure that publishes the event.
|
||||
event : UIEvent
|
||||
Event to publish.
|
||||
%(verbose)s
|
||||
"""
|
||||
if fig in _disabled_event_channels:
|
||||
return
|
||||
|
||||
# Compile a list of all event channels that the event should be published
|
||||
# on.
|
||||
channels = [_get_event_channel(fig)]
|
||||
links = _event_channel_links.get(fig, None)
|
||||
if links is not None:
|
||||
for linked_fig, (include_events, exclude_events) in links.items():
|
||||
if (include_events is None or event.name in include_events) and (
|
||||
exclude_events is None or event.name not in exclude_events
|
||||
):
|
||||
channels.append(_get_event_channel(linked_fig))
|
||||
|
||||
# Publish the event by calling the registered callback functions.
|
||||
event.source = fig
|
||||
logger.debug(f"Publishing {event} on channel {fig}")
|
||||
for channel in channels:
|
||||
if event.name not in channel:
|
||||
channel[event.name] = set()
|
||||
for callback in channel[event.name]:
|
||||
callback(event=event)
|
||||
|
||||
|
||||
@verbose
|
||||
def subscribe(fig, event_name, callback, *, verbose=None):
|
||||
"""Subscribe to an event on a figure's event channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure whose event channel to subscribe to.
|
||||
event_name : str
|
||||
The name of the event to listen for.
|
||||
callback : callable
|
||||
The function that should be called whenever the event is published.
|
||||
%(verbose)s
|
||||
"""
|
||||
channel = _get_event_channel(fig)
|
||||
logger.debug(f"Subscribing to channel {channel}")
|
||||
if event_name not in channel:
|
||||
channel[event_name] = set()
|
||||
channel[event_name].add(callback)
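# Illustrative sketch (not part of the module API): subscribe a callback to
# ``time_change`` events on a matplotlib figure, then publish one. The figure and
# the printed message are arbitrary example values.
def _example_subscribe_and_publish():
    import matplotlib.pyplot as plt

    fig, _ = plt.subplots()

    def on_time_change(event):
        print(f"new time: {event.time:.3f} s")

    subscribe(fig, "time_change", on_time_change)
    publish(fig, TimeChange(time=1.5))  # calls ``on_time_change``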
|
||||
|
||||
|
||||
@verbose
|
||||
def unsubscribe(fig, event_names, callback=None, *, verbose=None):
|
||||
"""Unsubscribe from an event on a figure's event channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure whose event channel to unsubscribe from.
|
||||
event_names : str | list of str
|
||||
Select which events to stop subscribing to. Can be a single string
|
||||
event name, a list of event names or ``"all"`` which will unsubscribe
|
||||
from all events.
|
||||
callback : callable | None
|
||||
The callback function that should be unsubscribed, leaving all other
|
||||
callback functions that may be subscribed untouched. By default
|
||||
(``None``) all callback functions are unsubscribed from the event.
|
||||
%(verbose)s
|
||||
"""
|
||||
channel = _get_event_channel(fig)
|
||||
|
||||
# Determine which events to unsubscribe for.
|
||||
if event_names == "all":
|
||||
if callback is None:
|
||||
event_names = list(channel.keys())
|
||||
else:
|
||||
event_names = list(k for k, v in channel.items() if callback in v)
|
||||
elif isinstance(event_names, str):
|
||||
event_names = [event_names]
|
||||
|
||||
for event_name in event_names:
|
||||
if event_name not in channel:
|
||||
warn(
|
||||
f'Cannot unsubscribe from event "{event_name}" as we have never '
|
||||
"subscribed to it."
|
||||
)
|
||||
continue
|
||||
|
||||
if callback is None:
|
||||
del channel[event_name]
|
||||
else:
|
||||
# Unsubscribe specific callback function.
|
||||
subscribers = channel[event_name]
|
||||
if callback in subscribers:
|
||||
subscribers.remove(callback)
|
||||
else:
|
||||
warn(
|
||||
f'Cannot unsubscribe {callback} from event "{event_name}" '
|
||||
"as it was never subscribed to it."
|
||||
)
|
||||
if len(subscribers) == 0:
|
||||
del channel[event_name] # keep things tidy
|
||||
|
||||
|
||||
@verbose
|
||||
def link(*figs, include_events=None, exclude_events=None, verbose=None):
|
||||
"""Link the event channels of two figures together.
|
||||
|
||||
When event channels are linked, any events that are published on one
|
||||
channel are simultaneously published on the other channel. Links are
|
||||
bi-directional.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
*figs : tuple of matplotlib.figure.Figure | tuple of Figure3D
|
||||
The figures whose event channel will be linked.
|
||||
include_events : list of str | None
|
||||
Select which events to publish across figures. By default (``None``),
|
||||
both figures will receive all of each other's events. Passing a list of
|
||||
event names will restrict the events being shared across the figures to
|
||||
only the given ones.
|
||||
exclude_events : list of str | None
|
||||
Select which events not to publish across figures. By default (``None``),
|
||||
no events are excluded.
|
||||
%(verbose)s
|
||||
"""
|
||||
if include_events is not None:
|
||||
include_events = set(include_events)
|
||||
if exclude_events is not None:
|
||||
exclude_events = set(exclude_events)
|
||||
|
||||
# Make sure the event channels of the figures are setup properly.
|
||||
for fig in figs:
|
||||
_get_event_channel(fig)
|
||||
if fig not in _event_channel_links:
|
||||
_event_channel_links[fig] = weakref.WeakKeyDictionary()
|
||||
|
||||
# Link the event channels
|
||||
for fig1 in figs:
|
||||
for fig2 in figs:
|
||||
if fig1 is not fig2:
|
||||
_event_channel_links[fig1][fig2] = (include_events, exclude_events)
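# Illustrative sketch (not part of the module API): share only ``time_change``
# events between two matplotlib figures. Links are bi-directional, so publishing
# on either figure reaches subscribers on the other.
def _example_link_figures():
    import matplotlib.pyplot as plt

    fig1, _ = plt.subplots()
    fig2, _ = plt.subplots()
    link(fig1, fig2, include_events=["time_change"])
    publish(fig1, TimeChange(time=0.0))  # also delivered on fig2's channel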
|
||||
|
||||
|
||||
@verbose
|
||||
def unlink(fig, *, verbose=None):
|
||||
"""Remove all links involving the event channel of the given figure.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure whose event channel should be unlinked from all other event
|
||||
channels.
|
||||
%(verbose)s
|
||||
"""
|
||||
linked_figs = _event_channel_links.get(fig)
|
||||
if linked_figs is not None:
|
||||
for linked_fig in linked_figs.keys():
|
||||
del _event_channel_links[linked_fig][fig]
|
||||
if len(_event_channel_links[linked_fig]) == 0:
|
||||
del _event_channel_links[linked_fig]
|
||||
if fig in _event_channel_links: # need to check again because of weak refs
|
||||
del _event_channel_links[fig]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def disable_ui_events(fig):
|
||||
"""Temporarily disable generation of UI events. Use as context manager.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure whose UI event generation should be temporarily disabled.
|
||||
"""
|
||||
_disabled_event_channels.add(fig)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_disabled_event_channels.remove(fig)
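# Illustrative sketch (not part of the module API): silence a figure's UI events
# while updating it programmatically. ``fig`` is assumed to already have an
# event channel.
def _example_disable_ui_events(fig):
    with disable_ui_events(fig):
        publish(fig, TimeChange(time=2.0))  # dropped while the channel is disabled
    publish(fig, TimeChange(time=2.0))  # delivered normally again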
|
||||
|
||||
|
||||
def _cleanup_agg():
|
||||
"""Call close_event for Agg canvases to help our doc build."""
|
||||
import matplotlib.backends.backend_agg
|
||||
import matplotlib.figure
|
||||
|
||||
for key in list(_event_channels): # we might remove keys as we go
|
||||
if isinstance(key, matplotlib.figure.Figure):
|
||||
canvas = key.canvas
|
||||
if isinstance(canvas, matplotlib.backends.backend_agg.FigureCanvasAgg):
|
||||
for cb in key.canvas.callbacks.callbacks["close_event"].values():
|
||||
cb = cb() # get the true ref
|
||||
if cb is not None:
|
||||
cb()
|
||||
2814 mne/viz/utils.py Normal file (diff suppressed because it is too large)