initial commit
This commit is contained in:
+4282
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,184 @@
|
||||
"""Classes to handle overlapping surfaces."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import logger
|
||||
|
||||
|
||||
class _Overlay:
|
||||
def __init__(self, scalars, colormap, rng, opacity, name):
|
||||
self._scalars = scalars
|
||||
self._colormap = colormap
|
||||
assert rng is not None
|
||||
self._rng = rng
|
||||
self._opacity = opacity
|
||||
self._name = name
|
||||
|
||||
def to_colors(self):
|
||||
from matplotlib.colors import Colormap, ListedColormap
|
||||
|
||||
from ._3d import _get_cmap
|
||||
|
||||
if isinstance(self._colormap, str):
|
||||
cmap = _get_cmap(self._colormap)
|
||||
elif isinstance(self._colormap, Colormap):
|
||||
cmap = self._colormap
|
||||
else:
|
||||
cmap = ListedColormap(
|
||||
self._colormap / 255.0, name=str(type(self._colormap))
|
||||
)
|
||||
logger.debug(
|
||||
f"Color mapping {repr(self._name)} with {cmap.name} "
|
||||
f"colormap and range {self._rng}"
|
||||
)
|
||||
|
||||
rng = self._rng
|
||||
assert rng is not None
|
||||
scalars = self._norm(rng)
|
||||
|
||||
colors = cmap(scalars)
|
||||
if self._opacity is not None:
|
||||
colors[:, 3] *= self._opacity
|
||||
return colors
|
||||
|
||||
def _norm(self, rng):
|
||||
if rng[0] == rng[1]:
|
||||
factor = 1 if rng[0] == 0 else 1e-6 * rng[0]
|
||||
else:
|
||||
factor = rng[1] - rng[0]
|
||||
return (self._scalars - rng[0]) / factor
|
||||
|
||||
|
||||
class _LayeredMesh:
|
||||
def __init__(self, renderer, vertices, triangles, normals):
|
||||
self._renderer = renderer
|
||||
self._vertices = vertices
|
||||
self._triangles = triangles
|
||||
self._normals = normals
|
||||
|
||||
self._polydata = None
|
||||
self._actor = None
|
||||
self._is_mapped = False
|
||||
|
||||
self._current_colors = None
|
||||
self._cached_colors = None
|
||||
self._overlays = OrderedDict()
|
||||
|
||||
self._default_scalars = np.ones(vertices.shape)
|
||||
self._default_scalars_name = "Data"
|
||||
|
||||
def map(self):
|
||||
kwargs = {
|
||||
"color": None,
|
||||
"pickable": True,
|
||||
"rgba": True,
|
||||
}
|
||||
mesh_data = self._renderer.mesh(
|
||||
x=self._vertices[:, 0],
|
||||
y=self._vertices[:, 1],
|
||||
z=self._vertices[:, 2],
|
||||
triangles=self._triangles,
|
||||
normals=self._normals,
|
||||
scalars=self._default_scalars,
|
||||
**kwargs,
|
||||
)
|
||||
self._actor, self._polydata = mesh_data
|
||||
self._is_mapped = True
|
||||
|
||||
def _compute_over(self, B, A):
|
||||
assert A.ndim == B.ndim == 2
|
||||
assert A.shape[1] == B.shape[1] == 4
|
||||
A_w = A[:, 3:] # * 1
|
||||
B_w = B[:, 3:] * (1 - A_w)
|
||||
C = A.copy()
|
||||
C[:, :3] *= A_w
|
||||
C[:, :3] += B[:, :3] * B_w
|
||||
C[:, 3:] += B_w
|
||||
C[:, :3] /= C[:, 3:]
|
||||
return np.clip(C, 0, 1, out=C)
|
||||
|
||||
def _compose_overlays(self):
|
||||
B = cache = None
|
||||
for overlay in self._overlays.values():
|
||||
A = overlay.to_colors()
|
||||
if B is None:
|
||||
B = A
|
||||
else:
|
||||
cache = B
|
||||
B = self._compute_over(cache, A)
|
||||
return B, cache
|
||||
|
||||
def add_overlay(self, scalars, colormap, rng, opacity, name):
|
||||
overlay = _Overlay(
|
||||
scalars=scalars,
|
||||
colormap=colormap,
|
||||
rng=rng,
|
||||
opacity=opacity,
|
||||
name=name,
|
||||
)
|
||||
self._overlays[name] = overlay
|
||||
colors = overlay.to_colors()
|
||||
if self._current_colors is None:
|
||||
self._current_colors = colors
|
||||
else:
|
||||
# save previous colors to cache
|
||||
self._cached_colors = self._current_colors
|
||||
self._current_colors = self._compute_over(self._cached_colors, colors)
|
||||
|
||||
# apply the texture
|
||||
self._apply()
|
||||
|
||||
def remove_overlay(self, names):
|
||||
to_update = False
|
||||
if not isinstance(names, list):
|
||||
names = [names]
|
||||
for name in names:
|
||||
if name in self._overlays:
|
||||
del self._overlays[name]
|
||||
to_update = True
|
||||
if to_update:
|
||||
self.update()
|
||||
|
||||
def _apply(self):
|
||||
if self._current_colors is None or self._renderer is None:
|
||||
return
|
||||
self._polydata[self._default_scalars_name] = self._current_colors
|
||||
|
||||
def update(self, colors=None):
|
||||
if colors is not None and self._cached_colors is not None:
|
||||
self._current_colors = self._compute_over(self._cached_colors, colors)
|
||||
else:
|
||||
self._current_colors, self._cached_colors = self._compose_overlays()
|
||||
self._apply()
|
||||
|
||||
def _clean(self):
|
||||
mapper = self._actor.GetMapper()
|
||||
mapper.SetLookupTable(None)
|
||||
self._actor.SetMapper(None)
|
||||
self._actor = None
|
||||
self._polydata = None
|
||||
self._renderer = None
|
||||
|
||||
def update_overlay(self, name, scalars=None, colormap=None, opacity=None, rng=None):
|
||||
overlay = self._overlays.get(name, None)
|
||||
if overlay is None:
|
||||
return
|
||||
if scalars is not None:
|
||||
overlay._scalars = scalars
|
||||
if colormap is not None:
|
||||
overlay._colormap = colormap
|
||||
if opacity is not None:
|
||||
overlay._opacity = opacity
|
||||
if rng is not None:
|
||||
overlay._rng = rng
|
||||
# partial update: use cache if possible
|
||||
if name == list(self._overlays.keys())[-1]:
|
||||
self.update(colors=overlay.to_colors())
|
||||
else: # full update
|
||||
self.update()
|
||||
@@ -0,0 +1,8 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
"""Visualization routines."""
|
||||
import lazy_loader as lazy
|
||||
|
||||
(__getattr__, __dir__, __all__) = lazy.attach_stub(__name__, __file__)
|
||||
@@ -0,0 +1,177 @@
|
||||
__all__ = [
|
||||
"Brain",
|
||||
"ClickableImage",
|
||||
"EvokedField",
|
||||
"Figure3D",
|
||||
"_RAW_CLIP_DEF",
|
||||
"_get_plot_ch_type",
|
||||
"_get_presser",
|
||||
"_plot_sources",
|
||||
"_scraper",
|
||||
"add_background_image",
|
||||
"adjust_axes",
|
||||
"backends",
|
||||
"centers_to_edges",
|
||||
"circular_layout",
|
||||
"close_3d_figure",
|
||||
"close_all_3d_figures",
|
||||
"compare_fiff",
|
||||
"concatenate_images",
|
||||
"create_3d_figure",
|
||||
"eyetracking",
|
||||
"get_3d_backend",
|
||||
"get_brain_class",
|
||||
"get_browser_backend",
|
||||
"iter_topography",
|
||||
"link_brains",
|
||||
"mne_analyze_colormap",
|
||||
"plot_alignment",
|
||||
"plot_arrowmap",
|
||||
"plot_bem",
|
||||
"plot_brain_colorbar",
|
||||
"plot_bridged_electrodes",
|
||||
"plot_ch_adjacency",
|
||||
"plot_channel_labels_circle",
|
||||
"plot_chpi_snr",
|
||||
"plot_compare_evokeds",
|
||||
"plot_cov",
|
||||
"plot_csd",
|
||||
"plot_dipole_amplitudes",
|
||||
"plot_dipole_locations",
|
||||
"plot_drop_log",
|
||||
"plot_epochs",
|
||||
"plot_epochs_image",
|
||||
"plot_epochs_psd",
|
||||
"plot_epochs_psd_topomap",
|
||||
"plot_events",
|
||||
"plot_evoked",
|
||||
"plot_evoked_field",
|
||||
"plot_evoked_image",
|
||||
"plot_evoked_joint",
|
||||
"plot_evoked_topo",
|
||||
"plot_evoked_topomap",
|
||||
"plot_evoked_white",
|
||||
"plot_filter",
|
||||
"plot_head_positions",
|
||||
"plot_ica_components",
|
||||
"plot_ica_overlay",
|
||||
"plot_ica_properties",
|
||||
"plot_ica_scores",
|
||||
"plot_ica_sources",
|
||||
"plot_ideal_filter",
|
||||
"plot_layout",
|
||||
"plot_montage",
|
||||
"plot_projs_joint",
|
||||
"plot_projs_topomap",
|
||||
"plot_raw",
|
||||
"plot_raw_psd",
|
||||
"plot_raw_psd_topo",
|
||||
"plot_regression_weights",
|
||||
"plot_sensors",
|
||||
"plot_snr_estimate",
|
||||
"plot_source_estimates",
|
||||
"plot_source_spectrogram",
|
||||
"plot_sparse_source_estimates",
|
||||
"plot_tfr_topomap",
|
||||
"plot_topo_image_epochs",
|
||||
"plot_topomap",
|
||||
"plot_vector_source_estimates",
|
||||
"plot_volume_source_estimates",
|
||||
"set_3d_backend",
|
||||
"set_3d_options",
|
||||
"set_3d_title",
|
||||
"set_3d_view",
|
||||
"set_browser_backend",
|
||||
"snapshot_brain_montage",
|
||||
"ui_events",
|
||||
"use_3d_backend",
|
||||
"use_browser_backend",
|
||||
]
|
||||
from . import _scraper, backends, eyetracking, ui_events
|
||||
from ._3d import (
|
||||
link_brains,
|
||||
plot_alignment,
|
||||
plot_brain_colorbar,
|
||||
plot_dipole_locations,
|
||||
plot_evoked_field,
|
||||
plot_head_positions,
|
||||
plot_source_estimates,
|
||||
plot_sparse_source_estimates,
|
||||
plot_vector_source_estimates,
|
||||
plot_volume_source_estimates,
|
||||
set_3d_options,
|
||||
snapshot_brain_montage,
|
||||
)
|
||||
from ._brain import Brain
|
||||
from ._figure import get_browser_backend, set_browser_backend, use_browser_backend
|
||||
from ._proj import plot_projs_joint
|
||||
from .backends._abstract import Figure3D
|
||||
from .backends.renderer import (
|
||||
close_3d_figure,
|
||||
close_all_3d_figures,
|
||||
create_3d_figure,
|
||||
get_3d_backend,
|
||||
get_brain_class,
|
||||
set_3d_backend,
|
||||
set_3d_title,
|
||||
set_3d_view,
|
||||
use_3d_backend,
|
||||
)
|
||||
from .circle import circular_layout, plot_channel_labels_circle
|
||||
from .epochs import plot_drop_log, plot_epochs, plot_epochs_image, plot_epochs_psd
|
||||
from .evoked import (
|
||||
plot_compare_evokeds,
|
||||
plot_evoked,
|
||||
plot_evoked_image,
|
||||
plot_evoked_joint,
|
||||
plot_evoked_topo,
|
||||
plot_evoked_white,
|
||||
plot_snr_estimate,
|
||||
)
|
||||
from .evoked_field import EvokedField
|
||||
from .ica import (
|
||||
_plot_sources,
|
||||
plot_ica_overlay,
|
||||
plot_ica_properties,
|
||||
plot_ica_scores,
|
||||
plot_ica_sources,
|
||||
)
|
||||
from .misc import (
|
||||
_get_presser,
|
||||
adjust_axes,
|
||||
plot_bem,
|
||||
plot_chpi_snr,
|
||||
plot_cov,
|
||||
plot_csd,
|
||||
plot_dipole_amplitudes,
|
||||
plot_events,
|
||||
plot_filter,
|
||||
plot_ideal_filter,
|
||||
plot_source_spectrogram,
|
||||
)
|
||||
from .montage import plot_montage
|
||||
from .raw import _RAW_CLIP_DEF, plot_raw, plot_raw_psd, plot_raw_psd_topo
|
||||
from .topo import iter_topography, plot_topo_image_epochs
|
||||
from .topomap import (
|
||||
plot_arrowmap,
|
||||
plot_bridged_electrodes,
|
||||
plot_ch_adjacency,
|
||||
plot_epochs_psd_topomap,
|
||||
plot_evoked_topomap,
|
||||
plot_ica_components,
|
||||
plot_layout,
|
||||
plot_projs_topomap,
|
||||
plot_regression_weights,
|
||||
plot_tfr_topomap,
|
||||
plot_topomap,
|
||||
)
|
||||
from .utils import (
|
||||
ClickableImage,
|
||||
_get_plot_ch_type,
|
||||
add_background_image,
|
||||
centers_to_edges,
|
||||
compare_fiff,
|
||||
concatenate_images,
|
||||
mne_analyze_colormap,
|
||||
plot_sensors,
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Plot Cortex Surface."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from ._brain import Brain, _LayeredMesh
|
||||
from ._scraper import _BrainScraper
|
||||
from ._linkviewer import _LinkViewer
|
||||
|
||||
__all__ = ["Brain"]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,98 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...utils import warn
|
||||
from ..ui_events import link
|
||||
|
||||
|
||||
class _LinkViewer:
|
||||
"""Class to link multiple Brain objects."""
|
||||
|
||||
def __init__(self, brains, time=True, camera=False, colorbar=True, picking=False):
|
||||
self.brains = brains
|
||||
self.leader = self.brains[0] # select a brain as leader
|
||||
|
||||
# check time infos
|
||||
times = [brain._times for brain in brains]
|
||||
if time and not all(np.allclose(x, times[0]) for x in times):
|
||||
warn("stc.times do not match, not linking time")
|
||||
time = False
|
||||
|
||||
if camera:
|
||||
self.link_cameras()
|
||||
|
||||
events_to_link = []
|
||||
if time:
|
||||
events_to_link.append("time_change")
|
||||
if colorbar:
|
||||
events_to_link.append("colormap_range")
|
||||
|
||||
for brain in brains[1:]:
|
||||
link(self.leader, brain, include_events=events_to_link)
|
||||
|
||||
if picking:
|
||||
|
||||
def _func_add(*args, **kwargs):
|
||||
for brain in self.brains:
|
||||
brain._add_vertex_glyph2(*args, **kwargs)
|
||||
brain.plotter.update()
|
||||
|
||||
def _func_remove(*args, **kwargs):
|
||||
for brain in self.brains:
|
||||
brain._remove_vertex_glyph2(*args, **kwargs)
|
||||
|
||||
# save initial picked points
|
||||
initial_points = dict()
|
||||
for hemi in ("lh", "rh"):
|
||||
initial_points[hemi] = set()
|
||||
for brain in self.brains:
|
||||
initial_points[hemi] |= set(brain.picked_points[hemi])
|
||||
|
||||
# link the viewers
|
||||
for brain in self.brains:
|
||||
brain.clear_glyphs()
|
||||
brain._add_vertex_glyph2 = brain._add_vertex_glyph
|
||||
brain._add_vertex_glyph = _func_add
|
||||
brain._remove_vertex_glyph2 = brain._remove_vertex_glyph
|
||||
brain._remove_vertex_glyph = _func_remove
|
||||
|
||||
# link the initial points
|
||||
for hemi in initial_points.keys():
|
||||
if hemi in brain._layered_meshes:
|
||||
mesh = brain._layered_meshes[hemi]._polydata
|
||||
for vertex_id in initial_points[hemi]:
|
||||
self.leader._add_vertex_glyph(hemi, mesh, vertex_id)
|
||||
|
||||
def set_fmin(self, value):
|
||||
self.leader.update_lut(fmin=value)
|
||||
|
||||
def set_fmid(self, value):
|
||||
self.leader.update_lut(fmid=value)
|
||||
|
||||
def set_fmax(self, value):
|
||||
self.leader.update_lut(fmax=value)
|
||||
|
||||
def set_time_point(self, value):
|
||||
self.leader.set_time_point(value)
|
||||
|
||||
def set_playback_speed(self, value):
|
||||
self.leader.set_playback_speed(value)
|
||||
|
||||
def toggle_playback(self):
|
||||
self.leader.toggle_playback()
|
||||
|
||||
def link_cameras(self):
|
||||
from ..backends._pyvista import _add_camera_callback
|
||||
|
||||
def _update_camera(vtk_picker, event):
|
||||
for brain in self.brains:
|
||||
brain.plotter.update()
|
||||
|
||||
camera = self.leader.plotter.camera
|
||||
_add_camera_callback(camera, _update_camera)
|
||||
for brain in self.brains:
|
||||
for renderer in brain.plotter.renderers:
|
||||
renderer.camera = camera
|
||||
@@ -0,0 +1,102 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import os
|
||||
import os.path as op
|
||||
|
||||
from ._brain import Brain
|
||||
|
||||
|
||||
class _BrainScraper:
|
||||
"""Scrape Brain objects."""
|
||||
|
||||
def __repr__(self):
|
||||
return "<BrainScraper>"
|
||||
|
||||
def __call__(self, block, block_vars, gallery_conf):
|
||||
rst = ""
|
||||
for brain in list(block_vars["example_globals"].values()):
|
||||
# Only need to process if it's a brain with a time_viewer
|
||||
# with traces on and shown in the same window, otherwise
|
||||
# PyVista and matplotlib scrapers can just do the work
|
||||
if (not isinstance(brain, Brain)) or brain._closed:
|
||||
continue
|
||||
from matplotlib import animation
|
||||
from matplotlib import pyplot as plt
|
||||
from sphinx_gallery.scrapers import matplotlib_scraper
|
||||
|
||||
img = brain.screenshot(time_viewer=True)
|
||||
dpi = 100.0
|
||||
figsize = (img.shape[1] / dpi, img.shape[0] / dpi)
|
||||
fig = plt.figure(figsize=figsize, dpi=dpi)
|
||||
ax = plt.Axes(fig, [0, 0, 1, 1])
|
||||
fig.add_axes(ax)
|
||||
img = ax.imshow(img)
|
||||
movie_key = "# brain.save_movie"
|
||||
if movie_key in block[1]:
|
||||
kwargs = dict()
|
||||
# Parse our parameters
|
||||
lines = block[1].splitlines()
|
||||
for li, line in enumerate(block[1].splitlines()):
|
||||
if line.startswith(movie_key):
|
||||
line = line[len(movie_key) :].replace("..., ", "")
|
||||
for ni in range(1, 5): # should be enough
|
||||
if len(lines) > li + ni and lines[li + ni].startswith(
|
||||
"# "
|
||||
):
|
||||
line = line + lines[li + ni][1:].strip()
|
||||
else:
|
||||
break
|
||||
assert line.startswith("(") and line.endswith(")")
|
||||
kwargs.update(eval(f"dict{line}")) # nosec B307
|
||||
for key, default in [
|
||||
("time_dilation", 4),
|
||||
("framerate", 24),
|
||||
("tmin", None),
|
||||
("tmax", None),
|
||||
("interpolation", None),
|
||||
("time_viewer", False),
|
||||
]:
|
||||
if key not in kwargs:
|
||||
kwargs[key] = default
|
||||
kwargs.pop("filename", None) # always omit this one
|
||||
if brain.time_viewer:
|
||||
assert kwargs["time_viewer"], "Must use time_viewer=True"
|
||||
frames = brain._make_movie_frames(callback=None, **kwargs)
|
||||
|
||||
# Turn them into an animation
|
||||
def func(frame):
|
||||
img.set_data(frame)
|
||||
return [img]
|
||||
|
||||
anim = animation.FuncAnimation(
|
||||
fig,
|
||||
func=func,
|
||||
frames=frames,
|
||||
blit=True,
|
||||
interval=1000.0 / kwargs["framerate"],
|
||||
)
|
||||
|
||||
# Out to sphinx-gallery:
|
||||
#
|
||||
# 1. A static image but hide it (useful for carousel)
|
||||
if animation.FFMpegWriter.isAvailable():
|
||||
writer = "ffmpeg"
|
||||
elif animation.ImageMagickWriter.isAvailable():
|
||||
writer = "imagemagick"
|
||||
else:
|
||||
writer = None
|
||||
static_fname = next(block_vars["image_path_iterator"])
|
||||
static_fname = static_fname[:-4] + ".gif"
|
||||
anim.save(static_fname, writer=writer, dpi=dpi)
|
||||
rel_fname = op.relpath(static_fname, gallery_conf["src_dir"])
|
||||
rel_fname = rel_fname.replace(os.sep, "/").lstrip("/")
|
||||
rst += f"\n.. image:: /{rel_fname}\n :class: hidden\n"
|
||||
|
||||
# 2. An animation that will be embedded and visible
|
||||
block_vars["example_globals"]["_brain_anim_"] = anim
|
||||
|
||||
brain.close()
|
||||
rst += matplotlib_scraper(block, block_vars, gallery_conf)
|
||||
return rst
|
||||
@@ -0,0 +1,185 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def create_lut(cmap, n_colors=256, center=None):
|
||||
"""Return a colormap suitable for setting as a LUT."""
|
||||
from .._3d import _get_cmap
|
||||
|
||||
assert not (isinstance(cmap, str) and cmap == "auto")
|
||||
cmap = _get_cmap(cmap)
|
||||
lut = np.round(cmap(np.linspace(0, 1, n_colors)) * 255.0).astype(np.int64)
|
||||
return lut
|
||||
|
||||
|
||||
def scale_sequential_lut(lut_table, fmin, fmid, fmax):
|
||||
"""Scale a sequential colormap."""
|
||||
assert fmin <= fmid <= fmax # guaranteed by calculate_lut
|
||||
lut_table_new = lut_table.copy()
|
||||
n_colors = lut_table.shape[0]
|
||||
n_colors2 = n_colors // 2
|
||||
|
||||
if fmax == fmin:
|
||||
fmid_idx = 0
|
||||
else:
|
||||
fmid_idx = np.clip(
|
||||
int(np.round(n_colors * ((fmid - fmin) / (fmax - fmin))) - 1),
|
||||
0,
|
||||
n_colors - 2,
|
||||
)
|
||||
|
||||
n_left = fmid_idx + 1
|
||||
n_right = n_colors - n_left
|
||||
for i in range(4):
|
||||
lut_table_new[: fmid_idx + 1, i] = np.interp(
|
||||
np.linspace(0, n_colors2 - 1, n_left), np.arange(n_colors), lut_table[:, i]
|
||||
)
|
||||
lut_table_new[fmid_idx + 1 :, i] = np.interp(
|
||||
np.linspace(n_colors - 1, n_colors2, n_right)[::-1],
|
||||
np.arange(n_colors),
|
||||
lut_table[:, i],
|
||||
)
|
||||
|
||||
return lut_table_new
|
||||
|
||||
|
||||
def get_fill_colors(cols, n_fill):
|
||||
"""Get the fill colors for the middle of divergent colormaps."""
|
||||
steps = np.linalg.norm(np.diff(cols[:, :3].astype(float), axis=0), axis=1)
|
||||
|
||||
ind = np.flatnonzero(steps[1:-1] > steps[[0, -1]].mean() * 3)
|
||||
if ind.size > 0:
|
||||
# choose the two colors between which there is the large step
|
||||
ind = ind[0] + 1
|
||||
fillcols = np.r_[
|
||||
np.tile(cols[ind, :], (n_fill // 2, 1)),
|
||||
np.tile(cols[ind + 1, :], (n_fill - n_fill // 2, 1)),
|
||||
]
|
||||
else:
|
||||
# choose a color from the middle of the colormap
|
||||
fillcols = np.tile(cols[int(cols.shape[0] / 2), :], (n_fill, 1))
|
||||
|
||||
return fillcols
|
||||
|
||||
|
||||
def calculate_lut(lut_table, alpha, fmin, fmid, fmax, center=None, transparent=True):
|
||||
"""Transparent color map calculation.
|
||||
|
||||
A colormap may be sequential or divergent. When the colormap is
|
||||
divergent indicate this by providing a value for 'center'. The
|
||||
meanings of fmin, fmid and fmax are different for sequential and
|
||||
divergent colormaps. A sequential colormap is characterised by::
|
||||
|
||||
[fmin, fmid, fmax]
|
||||
|
||||
where fmin and fmax define the edges of the colormap and fmid
|
||||
will be the value mapped to the center of the originally chosen colormap.
|
||||
A divergent colormap is characterised by::
|
||||
|
||||
[center-fmax, center-fmid, center-fmin, center,
|
||||
center+fmin, center+fmid, center+fmax]
|
||||
|
||||
i.e., values between center-fmin and center+fmin will not be shown
|
||||
while center-fmid will map to the fmid of the first half of the
|
||||
original colormap and center-fmid to the fmid of the second half.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lim_cmap : Colormap
|
||||
Color map obtained from _process_mapdata.
|
||||
alpha : float
|
||||
Alpha value to apply globally to the overlay. Has no effect with mpl
|
||||
backend.
|
||||
fmin : float
|
||||
Min value in colormap.
|
||||
fmid : float
|
||||
Intermediate value in colormap.
|
||||
fmax : float
|
||||
Max value in colormap.
|
||||
center : float or None
|
||||
If not None, center of a divergent colormap, changes the meaning of
|
||||
fmin, fmax and fmid.
|
||||
transparent : boolean
|
||||
if True: use a linear transparency between fmin and fmid and make
|
||||
values below fmin fully transparent (symmetrically for divergent
|
||||
colormaps)
|
||||
|
||||
Returns
|
||||
-------
|
||||
cmap : matplotlib.ListedColormap
|
||||
Color map with transparency channel.
|
||||
"""
|
||||
if not fmin <= fmid <= fmax:
|
||||
raise ValueError(f"Must have fmin ({fmin}) <= fmid ({fmid}) <= fmax ({fmax})")
|
||||
lut_table = create_lut(lut_table)
|
||||
assert lut_table.dtype.kind == "i"
|
||||
divergent = center is not None
|
||||
n_colors = lut_table.shape[0]
|
||||
|
||||
# Add transparency if needed
|
||||
n_colors2 = n_colors // 2
|
||||
if transparent:
|
||||
if divergent:
|
||||
N4 = np.full(4, n_colors // 4)
|
||||
N4[[0, 3, 1, 2][: np.mod(n_colors, 4)]] += 1
|
||||
assert N4.sum() == n_colors
|
||||
lut_table[:, -1] = np.round(
|
||||
np.hstack(
|
||||
[
|
||||
np.full(N4[0], 255.0),
|
||||
np.linspace(0, 255, N4[1])[::-1],
|
||||
np.linspace(0, 255, N4[2]),
|
||||
np.full(N4[3], 255.0),
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
lut_table[:n_colors2, -1] = np.round(np.linspace(0, 255, n_colors2))
|
||||
lut_table[n_colors2:, -1] = 255
|
||||
|
||||
alpha = float(alpha)
|
||||
if alpha < 1.0:
|
||||
lut_table[:, -1] = np.round(lut_table[:, -1] * alpha)
|
||||
|
||||
if divergent:
|
||||
if np.isclose(fmax, fmin, rtol=1e-6, atol=0):
|
||||
lut_table = np.r_[
|
||||
lut_table[:1],
|
||||
get_fill_colors(
|
||||
lut_table[n_colors2 - 3 : n_colors2 + 3, :], n_colors - 2
|
||||
),
|
||||
lut_table[-1:],
|
||||
]
|
||||
else:
|
||||
n_fill = int(round(fmin * n_colors2 / (fmax - fmin))) * 2
|
||||
lut_table = np.r_[
|
||||
scale_sequential_lut(
|
||||
lut_table[:n_colors2, :],
|
||||
center - fmax,
|
||||
center - fmid,
|
||||
center - fmin,
|
||||
),
|
||||
get_fill_colors(lut_table[n_colors2 - 3 : n_colors2 + 3, :], n_fill),
|
||||
scale_sequential_lut(
|
||||
lut_table[n_colors2:, :][::-1],
|
||||
center - fmax,
|
||||
center - fmid,
|
||||
center - fmin,
|
||||
)[::-1],
|
||||
]
|
||||
else:
|
||||
lut_table = scale_sequential_lut(lut_table, fmin, fmid, fmax)
|
||||
|
||||
n_colors = lut_table.shape[0]
|
||||
if n_colors != 256:
|
||||
lut = np.zeros((256, 4))
|
||||
x = np.linspace(1, n_colors, 256)
|
||||
for chan in range(4):
|
||||
lut[:, chan] = np.interp(x, np.arange(1, n_colors + 1), lut_table[:, chan])
|
||||
lut_table = lut
|
||||
|
||||
lut_table = lut_table.astype(np.float64) / 255.0
|
||||
return lut_table
|
||||
@@ -0,0 +1,177 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from os import path as path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...surface import _read_patch, complete_surface_info, read_curvature, read_surface
|
||||
from ...utils import _check_fname, _check_option, _validate_type, get_subjects_dir
|
||||
|
||||
|
||||
class _Surface:
|
||||
"""Container for a brain surface.
|
||||
|
||||
It is used for storing vertices, faces and morphometric data
|
||||
(curvature) of a hemisphere mesh.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
subject : string
|
||||
Name of subject
|
||||
hemi : {'lh', 'rh'}
|
||||
Which hemisphere to load
|
||||
surf : string
|
||||
Name of the surface to load (eg. inflated, orig ...).
|
||||
subjects_dir : str | None
|
||||
If not None, this directory will be used as the subjects directory
|
||||
instead of the value set using the SUBJECTS_DIR environment variable.
|
||||
offset : float | None
|
||||
If 0.0, the surface will be offset such that the medial
|
||||
wall is aligned with the origin. If None, no offset will
|
||||
be applied. If != 0.0, an additional offset will be used.
|
||||
units : str
|
||||
Can be 'm' or 'mm' (default).
|
||||
x_dir : ndarray | None
|
||||
The x direction to use for offset alignment.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
bin_curv : numpy.ndarray
|
||||
Curvature values stored as non-negative integers.
|
||||
coords : numpy.ndarray
|
||||
nvtx x 3 array of vertex (x, y, z) coordinates.
|
||||
curv : numpy.ndarray
|
||||
Vector representation of surface morpometry (curvature) values as
|
||||
loaded from a file.
|
||||
grey_curv : numpy.ndarray
|
||||
Normalized morphometry (curvature) data, used in order to get
|
||||
a gray cortex.
|
||||
faces : numpy.ndarray
|
||||
nfaces x 3 array of defining mesh triangles.
|
||||
hemi : {'lh', 'rh'}
|
||||
Which hemisphere to load.
|
||||
nn : numpy.ndarray
|
||||
Vertex normals for a triangulated surface.
|
||||
offset : float | None
|
||||
If float, align inside edge of each hemisphere to center + offset.
|
||||
If None, do not change coordinates (default).
|
||||
subject : string
|
||||
Name of subject.
|
||||
surf : string
|
||||
Name of the surface to load (eg. inflated, orig ...).
|
||||
units : str
|
||||
Can be 'm' or 'mm' (default).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
subject,
|
||||
hemi,
|
||||
surf,
|
||||
subjects_dir=None,
|
||||
offset=None,
|
||||
units="mm",
|
||||
x_dir=None,
|
||||
):
|
||||
x_dir = np.array([1.0, 0, 0]) if x_dir is None else x_dir
|
||||
assert isinstance(x_dir, np.ndarray)
|
||||
assert np.isclose(np.linalg.norm(x_dir), 1.0, atol=1e-6)
|
||||
assert hemi in ("lh", "rh")
|
||||
_validate_type(offset, (None, "numeric"), "offset")
|
||||
|
||||
self.units = _check_option("units", units, ("mm", "m"))
|
||||
self.subject = subject
|
||||
self.hemi = hemi
|
||||
self.surf = surf
|
||||
self.offset = offset
|
||||
self.bin_curv = None
|
||||
self.coords = None
|
||||
self.curv = None
|
||||
self.faces = None
|
||||
self.nn = None
|
||||
self.labels = dict()
|
||||
self.x_dir = x_dir
|
||||
|
||||
subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True))
|
||||
self.data_path = path.join(subjects_dir, subject)
|
||||
if surf == "seghead":
|
||||
raise ValueError(
|
||||
"`surf` cannot be seghead, use "
|
||||
"`mne.viz.Brain.add_head` to plot the seghead"
|
||||
)
|
||||
|
||||
def load_geometry(self):
|
||||
"""Load geometry of the surface.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
if self.surf == "flat": # special case
|
||||
fname = path.join(self.data_path, "surf", f"{self.hemi}.cortex.patch.flat")
|
||||
_check_fname(
|
||||
fname, overwrite="read", must_exist=True, name="flatmap surface file"
|
||||
)
|
||||
coords, faces, orig_faces = _read_patch(fname)
|
||||
# rotate 90 degrees to get to a more standard orientation
|
||||
# where X determines the distance between the hemis
|
||||
coords = coords[:, [1, 0, 2]]
|
||||
coords[:, 1] *= -1
|
||||
else:
|
||||
# allow ?h.pial.T1 if ?h.pial doesn't exist for instance
|
||||
# end with '' for better file not found error
|
||||
for img in ("", ".T1", ".T2", ""):
|
||||
surf_fname = path.join(
|
||||
self.data_path, "surf", f"{self.hemi}.{self.surf}{img}"
|
||||
)
|
||||
if path.isfile(surf_fname):
|
||||
break
|
||||
coords, faces = read_surface(surf_fname)
|
||||
orig_faces = faces
|
||||
if self.units == "m":
|
||||
coords /= 1000.0
|
||||
if self.offset is not None:
|
||||
x_ = coords @ self.x_dir
|
||||
if self.hemi == "lh":
|
||||
coords -= (np.max(x_) + self.offset) * self.x_dir
|
||||
else:
|
||||
coords -= (np.min(x_) + self.offset) * self.x_dir
|
||||
surf = dict(rr=coords, tris=faces)
|
||||
complete_surface_info(surf, copy=False, verbose=False, do_neighbor_tri=False)
|
||||
nn = surf["nn"]
|
||||
self.coords = coords
|
||||
self.faces = faces
|
||||
self.orig_faces = orig_faces
|
||||
self.nn = nn
|
||||
|
||||
def __len__(self):
|
||||
"""Return number of vertices."""
|
||||
return len(self.coords)
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return self.coords[:, 0]
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
return self.coords[:, 1]
|
||||
|
||||
@property
|
||||
def z(self):
|
||||
return self.coords[:, 2]
|
||||
|
||||
def load_curvature(self):
|
||||
"""Load in curvature values from the ?h.curv file."""
|
||||
curv_path = path.join(self.data_path, "surf", f"{self.hemi}.curv")
|
||||
if path.isfile(curv_path):
|
||||
self.curv = read_curvature(curv_path, binary=False)
|
||||
self.bin_curv = np.array(self.curv > 0, np.int64)
|
||||
else:
|
||||
self.curv = None
|
||||
self.bin_curv = None
|
||||
@@ -0,0 +1,54 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
ORIGIN = "auto"
|
||||
DIST = "auto"
|
||||
|
||||
_lh_views_dict = {
|
||||
"lateral": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"medial": dict(azimuth=0.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"rostral": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"caudal": dict(azimuth=270.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"dorsal": dict(azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"ventral": dict(azimuth=180.0, elevation=180.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"frontal": dict(azimuth=120.0, elevation=80.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"parietal": dict(azimuth=-120.0, elevation=60.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"sagittal": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"coronal": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"axial": dict(
|
||||
azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, roll=0, distance=DIST
|
||||
), # noqa: E501
|
||||
}
|
||||
_rh_views_dict = {
|
||||
"lateral": dict(azimuth=180.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"medial": dict(azimuth=0.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"rostral": dict(azimuth=-90.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"caudal": dict(azimuth=90.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"dorsal": dict(azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"ventral": dict(azimuth=180.0, elevation=180.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"frontal": dict(azimuth=60.0, elevation=80.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"parietal": dict(azimuth=-60.0, elevation=60.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"sagittal": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"coronal": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
|
||||
"axial": dict(
|
||||
azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, roll=0, distance=DIST
|
||||
),
|
||||
}
|
||||
# add short-size version entries into the dict
|
||||
lh_views_dict = _lh_views_dict.copy()
|
||||
for k, v in _lh_views_dict.items():
|
||||
lh_views_dict[k[:3]] = v
|
||||
lh_views_dict["flat"] = dict(
|
||||
azimuth=0, elevation=0, focalpoint=ORIGIN, roll=0, distance=DIST
|
||||
)
|
||||
|
||||
rh_views_dict = _rh_views_dict.copy()
|
||||
for k, v in _rh_views_dict.items():
|
||||
rh_views_dict[k[:3]] = v
|
||||
rh_views_dict["flat"] = dict(
|
||||
azimuth=0, elevation=0, focalpoint=ORIGIN, roll=0, distance=DIST
|
||||
)
|
||||
views_dicts = dict(
|
||||
lh=lh_views_dict, vol=lh_views_dict, both=lh_views_dict, rh=rh_views_dict
|
||||
)
|
||||
@@ -0,0 +1,214 @@
|
||||
"""Dipole viz specific functions."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import os.path as op
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial import ConvexHull
|
||||
|
||||
from .._freesurfer import _estimate_talxfm_rigid, _get_head_surface
|
||||
from ..surface import read_surface
|
||||
from ..transforms import _get_trans, apply_trans, invert_transform
|
||||
from ..utils import _check_option, _validate_type, get_subjects_dir
|
||||
from .utils import _validate_if_list_of_axes, plt_show
|
||||
|
||||
|
||||
def _check_concat_dipoles(dipole):
|
||||
from ..dipole import Dipole, _concatenate_dipoles
|
||||
|
||||
if not isinstance(dipole, Dipole):
|
||||
dipole = _concatenate_dipoles(dipole)
|
||||
return dipole
|
||||
|
||||
|
||||
def _plot_dipole_mri_outlines(
|
||||
dipoles,
|
||||
*,
|
||||
subject,
|
||||
trans,
|
||||
ax,
|
||||
subjects_dir,
|
||||
color,
|
||||
scale,
|
||||
coord_frame,
|
||||
show,
|
||||
block,
|
||||
head_source,
|
||||
title,
|
||||
surf,
|
||||
width,
|
||||
):
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.collections import LineCollection, PatchCollection
|
||||
from matplotlib.patches import Circle
|
||||
|
||||
extra = 'when mode is "outlines"'
|
||||
trans = _get_trans(trans, fro="head", to="mri")[0]
|
||||
_check_option(
|
||||
"coord_frame", coord_frame, ["head", "mri", "mri_rotated"], extra=extra
|
||||
)
|
||||
_validate_type(surf, (str, None), "surf")
|
||||
_check_option("surf", surf, ("white", "pial", None))
|
||||
if ax is None:
|
||||
_, ax = plt.subplots(1, 3, figsize=(7, 2.5), squeeze=True, layout="constrained")
|
||||
_validate_if_list_of_axes(ax, 3, name="ax")
|
||||
dipoles = _check_concat_dipoles(dipoles)
|
||||
color = "r" if color is None else color
|
||||
scale = 0.03 if scale is None else scale
|
||||
width = 0.015 if width is None else width
|
||||
fig = ax[0].figure
|
||||
surfs = dict()
|
||||
hemis = ("lh", "rh")
|
||||
if surf is not None:
|
||||
for hemi in hemis:
|
||||
surfs[hemi] = read_surface(
|
||||
op.join(subjects_dir, subject, "surf", f"{hemi}.{surf}"),
|
||||
return_dict=True,
|
||||
)[2]
|
||||
surfs[hemi]["rr"] /= 1000.0
|
||||
subjects_dir = get_subjects_dir(subjects_dir)
|
||||
if subjects_dir is not None:
|
||||
subjects_dir = str(subjects_dir)
|
||||
surfs["head"] = _get_head_surface(head_source, subject, subjects_dir)
|
||||
del head_source
|
||||
mri_trans = head_trans = np.eye(4)
|
||||
if coord_frame in ("mri", "mri_rotated"):
|
||||
head_trans = trans["trans"]
|
||||
if coord_frame == "mri_rotated":
|
||||
rot = _estimate_talxfm_rigid(subject, subjects_dir)
|
||||
rot[:3, 3] = 0.0
|
||||
head_trans = rot @ head_trans
|
||||
mri_trans = rot @ mri_trans
|
||||
else:
|
||||
assert coord_frame == "head"
|
||||
mri_trans = invert_transform(trans)["trans"]
|
||||
for s in surfs.values():
|
||||
s["rr"] = 1000 * apply_trans(mri_trans, s["rr"])
|
||||
del mri_trans
|
||||
levels = dict()
|
||||
if surf is not None:
|
||||
use_rr = np.concatenate([surfs[key]["rr"] for key in hemis])
|
||||
else:
|
||||
use_rr = surfs["head"]["rr"]
|
||||
views = [("Axial", "XY"), ("Coronal", "XZ"), ("Sagittal", "YZ")]
|
||||
# axial: 25% up the Z axis
|
||||
axial = float(np.percentile(use_rr[:, 2], 20.0))
|
||||
coronal = float(np.percentile(use_rr[:, 1], 55.0))
|
||||
for key in hemis + ("head",):
|
||||
levels[key] = dict(Axial=axial, Coronal=coronal)
|
||||
if surf is not None:
|
||||
levels["rh"]["Sagittal"] = float(np.percentile(surfs["rh"]["rr"][:, 0], 50))
|
||||
levels["head"]["Sagittal"] = 0.0
|
||||
for ax_, (name, coords) in zip(ax, views):
|
||||
idx = list(map(dict(X=0, Y=1, Z=2).get, coords))
|
||||
miss = np.setdiff1d(np.arange(3), idx)[0]
|
||||
pos = 1000 * apply_trans(head_trans, dipoles.pos)
|
||||
ori = 1000 * apply_trans(head_trans, dipoles.ori, move=False)
|
||||
lims = dict()
|
||||
for ii, char in enumerate(coords):
|
||||
lim = surfs["head"]["rr"][:, idx[ii]]
|
||||
lim = np.array([lim.min(), lim.max()])
|
||||
lims[char] = lim
|
||||
ax_.quiver(
|
||||
pos[:, idx[0]],
|
||||
pos[:, idx[1]],
|
||||
scale * ori[:, idx[0]],
|
||||
scale * ori[:, idx[1]],
|
||||
color=color,
|
||||
pivot="middle",
|
||||
zorder=5,
|
||||
scale_units="xy",
|
||||
angles="xy",
|
||||
scale=1.0,
|
||||
width=width,
|
||||
minshaft=0.5,
|
||||
headwidth=2.5,
|
||||
headlength=2.5,
|
||||
headaxislength=2,
|
||||
)
|
||||
coll = PatchCollection(
|
||||
[
|
||||
Circle((x, y), radius=scale * 1000 * width * 6)
|
||||
for x, y in zip(pos[:, idx[0]], pos[:, idx[1]])
|
||||
],
|
||||
linewidths=0.0,
|
||||
facecolors=color,
|
||||
zorder=6,
|
||||
)
|
||||
for key, surf in surfs.items():
|
||||
try:
|
||||
level = levels[key][name]
|
||||
except KeyError:
|
||||
continue
|
||||
if key != "head":
|
||||
rrs = surf["rr"][:, idx]
|
||||
tris = ConvexHull(rrs).simplices
|
||||
segments = LineCollection(
|
||||
rrs[:, [0, 1]][tris],
|
||||
linewidths=1,
|
||||
linestyles="-",
|
||||
colors="k",
|
||||
zorder=3,
|
||||
alpha=0.25,
|
||||
)
|
||||
ax_.add_collection(segments)
|
||||
ax_.tricontour(
|
||||
surf["rr"][:, idx[0]],
|
||||
surf["rr"][:, idx[1]],
|
||||
surf["tris"],
|
||||
surf["rr"][:, miss],
|
||||
levels=[level],
|
||||
colors="k",
|
||||
linewidths=1.0,
|
||||
linestyles=["-"],
|
||||
zorder=4,
|
||||
alpha=0.5,
|
||||
)
|
||||
# TODO: this breaks the PatchCollection in MPL
|
||||
# for coll in h.collections:
|
||||
# coll.set_clip_on(False)
|
||||
ax_.add_collection(coll)
|
||||
ax_.set(
|
||||
title=name,
|
||||
xlim=lims[coords[0]],
|
||||
ylim=lims[coords[1]],
|
||||
xlabel=coords[0] + " (mm)",
|
||||
ylabel=coords[1] + " (mm)",
|
||||
)
|
||||
for spine in ax_.spines.values():
|
||||
spine.set_visible(False)
|
||||
ax_.grid(True, ls=":", zorder=2)
|
||||
ax_.set_aspect("equal")
|
||||
|
||||
if title is not None:
|
||||
fig.suptitle(title)
|
||||
plt_show(show, block=block)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _plot_dipole_3d(dipoles, *, coord_frame, color, fig, trans, scale, mode):
|
||||
from .backends.renderer import _get_renderer
|
||||
|
||||
_check_option("coord_frame", coord_frame, ("head", "mri"))
|
||||
color = "r" if color is None else color
|
||||
scale = 0.005 if scale is None else scale
|
||||
renderer = _get_renderer(fig=fig, size=(600, 600))
|
||||
pos = dipoles.pos
|
||||
ori = dipoles.ori
|
||||
if coord_frame != "head":
|
||||
trans = _get_trans(trans, fro="head", to=coord_frame)[0]
|
||||
pos = apply_trans(trans, pos)
|
||||
ori = apply_trans(trans, ori)
|
||||
|
||||
renderer.sphere(center=pos, color=color, scale=scale)
|
||||
if mode == "arrow":
|
||||
x, y, z = pos.T
|
||||
u, v, w = ori.T
|
||||
renderer.quiver3d(x, y, z, u, v, w, scale=3 * scale, color=color, mode="arrow")
|
||||
renderer.show()
|
||||
fig = renderer.scene()
|
||||
return fig
|
||||
@@ -0,0 +1,853 @@
|
||||
"""Base classes and functions for 2D browser backends."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from itertools import cycle
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _DATA_CH_TYPES_SPLIT
|
||||
from ..defaults import _handle_default
|
||||
from ..filter import _iir_filter, _overlap_add_filter
|
||||
from ..fixes import _compare_version
|
||||
from ..utils import (
|
||||
_check_option,
|
||||
_get_stim_channel,
|
||||
_validate_type,
|
||||
get_config,
|
||||
logger,
|
||||
set_config,
|
||||
verbose,
|
||||
)
|
||||
from .backends._utils import VALID_BROWSE_BACKENDS
|
||||
from .utils import _get_color_list, _setup_plot_projector, _show_browser
|
||||
|
||||
MNE_BROWSER_BACKEND = None
|
||||
backend = None
|
||||
|
||||
|
||||
class BrowserParams:
|
||||
"""Container object for 2D browser parameters."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# default key to close window
|
||||
self.close_key = "escape"
|
||||
vars(self).update(**kwargs)
|
||||
|
||||
|
||||
class BrowserBase(ABC):
|
||||
"""A base class containing for the 2D browser.
|
||||
|
||||
This class contains all backend-independent attributes and methods.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
from ..epochs import BaseEpochs
|
||||
from ..io import BaseRaw
|
||||
from ..preprocessing import ICA
|
||||
|
||||
self.backend_name = None
|
||||
|
||||
self._data = None
|
||||
self._times = None
|
||||
|
||||
self.mne = BrowserParams(**kwargs)
|
||||
|
||||
inst = kwargs.get("inst", None)
|
||||
ica = kwargs.get("ica", None)
|
||||
|
||||
# what kind of data are we dealing with?
|
||||
if isinstance(ica, ICA):
|
||||
self.mne.instance_type = "ica"
|
||||
elif isinstance(inst, BaseRaw):
|
||||
self.mne.instance_type = "raw"
|
||||
elif isinstance(inst, BaseEpochs):
|
||||
self.mne.instance_type = "epochs"
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected an instance of Raw, Epochs, or ICA, got {type(inst)}."
|
||||
)
|
||||
|
||||
logger.debug(f"Opening {self.mne.instance_type} browser...")
|
||||
|
||||
self.mne.ica_type = None
|
||||
if self.mne.instance_type == "ica":
|
||||
if isinstance(self.mne.ica_inst, BaseRaw):
|
||||
self.mne.ica_type = "raw"
|
||||
elif isinstance(self.mne.ica_inst, BaseEpochs):
|
||||
self.mne.ica_type = "epochs"
|
||||
self.mne.is_epochs = "epochs" in (self.mne.instance_type, self.mne.ica_type)
|
||||
|
||||
# things that always start the same
|
||||
self.mne.ch_start = 0
|
||||
self.mne.projector = None
|
||||
if hasattr(self.mne, "projs"):
|
||||
self.mne.projs_active = np.array([p["active"] for p in self.mne.projs])
|
||||
self.mne.whitened_ch_names = list()
|
||||
if hasattr(self.mne, "noise_cov"):
|
||||
self.mne.use_noise_cov = self.mne.noise_cov is not None
|
||||
# allow up to 10000 zorder levels for annotations
|
||||
self.mne.zorder = dict(
|
||||
patch=0,
|
||||
grid=1,
|
||||
ann=2,
|
||||
events=10003,
|
||||
bads=10004,
|
||||
data=10005,
|
||||
mag=10006,
|
||||
grad=10007,
|
||||
scalebar=10008,
|
||||
vline=10009,
|
||||
)
|
||||
# additional params for epochs (won't affect raw / ICA)
|
||||
self.mne.epoch_traces = list()
|
||||
self.mne.bad_epochs = list()
|
||||
if inst is not None:
|
||||
self.mne.sampling_period = np.diff(inst.times[:2])[0] / inst.info["sfreq"]
|
||||
# annotations
|
||||
self.mne.annotations = list()
|
||||
self.mne.hscroll_annotations = list()
|
||||
self.mne.annotation_segments = list()
|
||||
self.mne.annotation_texts = list()
|
||||
self.mne.new_annotation_labels = list()
|
||||
self.mne.annotation_segment_colors = dict()
|
||||
self.mne.annotation_hover_line = None
|
||||
self.mne.draggable_annotations = False
|
||||
# lines
|
||||
self.mne.event_lines = list()
|
||||
self.mne.event_texts = list()
|
||||
self.mne.vline_visible = False
|
||||
# decim
|
||||
self.mne.decim_times = None
|
||||
self.mne.decim_data = None
|
||||
# scalings
|
||||
if hasattr(self.mne, "butterfly"):
|
||||
self.mne.scale_factor = 0.5 if self.mne.butterfly else 1.0
|
||||
self.mne.scalebars = dict()
|
||||
self.mne.scalebar_texts = dict()
|
||||
# ancillary child figures
|
||||
self.mne.child_figs = list()
|
||||
self.mne.fig_help = None
|
||||
self.mne.fig_proj = None
|
||||
self.mne.fig_histogram = None
|
||||
self.mne.fig_selection = None
|
||||
self.mne.fig_annotation = None
|
||||
# extra attributes for epochs
|
||||
if self.mne.is_epochs:
|
||||
# add epoch boundaries & center epoch numbers between boundaries
|
||||
self.mne.midpoints = (
|
||||
np.convolve(self.mne.boundary_times, np.ones(2), mode="valid") / 2
|
||||
)
|
||||
|
||||
# initialize picks and projectors
|
||||
self._update_picks()
|
||||
if not self.mne.instance_type == "ica":
|
||||
self._update_projector()
|
||||
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
# ANNOTATIONS
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
|
||||
def _get_annotation_labels(self):
|
||||
"""Get the unique labels in the raw object and added in the UI."""
|
||||
return sorted(
|
||||
set(self.mne.inst.annotations.description)
|
||||
| set(self.mne.new_annotation_labels)
|
||||
)
|
||||
|
||||
def _setup_annotation_colors(self):
|
||||
"""Set up colors for annotations; init some annotation vars."""
|
||||
segment_colors = getattr(self.mne, "annotation_segment_colors", dict())
|
||||
labels = self._get_annotation_labels()
|
||||
colors, red = _get_color_list(annotations=True)
|
||||
color_cycle = cycle(colors)
|
||||
for key, color in segment_colors.items():
|
||||
if color != red and key in labels:
|
||||
next(color_cycle)
|
||||
for idx, key in enumerate(labels):
|
||||
if key.lower().startswith("bad") or key.lower().startswith("edge"):
|
||||
segment_colors[key] = red
|
||||
elif key in segment_colors:
|
||||
continue
|
||||
else:
|
||||
segment_colors[key] = next(color_cycle)
|
||||
self.mne.annotation_segment_colors = segment_colors
|
||||
# init a couple other annotation-related variables
|
||||
self.mne.visible_annotations = {label: True for label in labels}
|
||||
self.mne.show_hide_annotation_checkboxes = None
|
||||
|
||||
def _update_annotation_segments(self):
|
||||
"""Update the array of annotation start/end times."""
|
||||
from ..annotations import _sync_onset
|
||||
|
||||
self.mne.annotation_segments = np.array([])
|
||||
if len(self.mne.inst.annotations):
|
||||
annot_start = _sync_onset(self.mne.inst, self.mne.inst.annotations.onset)
|
||||
durations = self.mne.inst.annotations.duration.copy()
|
||||
durations[durations < 1 / self.mne.info["sfreq"]] = (
|
||||
1 / self.mne.info["sfreq"]
|
||||
)
|
||||
annot_end = annot_start + durations
|
||||
self.mne.annotation_segments = np.vstack((annot_start, annot_end)).T
|
||||
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
# PROJECTOR & BADS
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
|
||||
def _update_projector(self):
|
||||
"""Update the data after projectors (or bads) have changed."""
|
||||
inds = np.where(self.mne.projs_on)[0] # doesn't include "active" projs
|
||||
# copy projs from full list (self.mne.projs) to info object
|
||||
with self.mne.info._unlock():
|
||||
self.mne.info["projs"] = [deepcopy(self.mne.projs[ix]) for ix in inds]
|
||||
# compute the projection operator
|
||||
proj, wh_chs = _setup_plot_projector(
|
||||
self.mne.info, self.mne.noise_cov, True, self.mne.use_noise_cov
|
||||
)
|
||||
self.mne.whitened_ch_names = list(wh_chs)
|
||||
self.mne.projector = proj
|
||||
|
||||
def _toggle_bad_channel(self, idx):
|
||||
"""Mark/unmark bad channels; `idx` is index of *visible* channels."""
|
||||
pick = self.mne.picks[idx]
|
||||
ch_name = self.mne.ch_names[pick]
|
||||
# add/remove from bads list
|
||||
bads = self.mne.info["bads"]
|
||||
marked_bad = ch_name not in bads
|
||||
if marked_bad:
|
||||
bads.append(ch_name)
|
||||
color = self.mne.ch_color_bad
|
||||
else:
|
||||
while ch_name in bads: # to make sure duplicates are removed
|
||||
bads.remove(ch_name)
|
||||
# Only mpl-backend has ch_colors
|
||||
if hasattr(self.mne, "ch_colors"):
|
||||
color = self.mne.ch_colors[idx]
|
||||
else:
|
||||
color = None
|
||||
self.mne.info["bads"] = bads
|
||||
|
||||
self._update_projector()
|
||||
|
||||
return color, pick, marked_bad
|
||||
|
||||
def _toggle_single_channel_annotation(self, ch_pick, annot_idx):
|
||||
current_ch_names = list(self.mne.inst.annotations.ch_names[annot_idx])
|
||||
if ch_pick in current_ch_names:
|
||||
current_ch_names.remove(ch_pick)
|
||||
else:
|
||||
current_ch_names.append(ch_pick)
|
||||
self.mne.inst.annotations.ch_names[annot_idx] = tuple(current_ch_names)
|
||||
|
||||
def _toggle_bad_epoch(self, xtime):
|
||||
epoch_num = self._get_epoch_num_from_time(xtime)
|
||||
epoch_ix = self.mne.inst.selection.tolist().index(epoch_num)
|
||||
if epoch_num in self.mne.bad_epochs:
|
||||
self.mne.bad_epochs.remove(epoch_num)
|
||||
color = "none"
|
||||
else:
|
||||
self.mne.bad_epochs.append(epoch_num)
|
||||
self.mne.bad_epochs.sort()
|
||||
color = self.mne.epoch_color_bad
|
||||
|
||||
return epoch_ix, color
|
||||
|
||||
def _toggle_whitening(self):
|
||||
if self.mne.noise_cov is not None:
|
||||
self.mne.use_noise_cov = not self.mne.use_noise_cov
|
||||
self._update_projector()
|
||||
self._update_yaxis_labels() # add/remove italics
|
||||
self._redraw()
|
||||
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
# MANAGE TRACES
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
|
||||
def _update_picks(self):
|
||||
"""Compute which channel indices to show."""
|
||||
if self.mne.butterfly and self.mne.ch_selections is not None:
|
||||
selections_dict = self._make_butterfly_selections_dict()
|
||||
self.mne.picks = np.concatenate(tuple(selections_dict.values()))
|
||||
elif self.mne.butterfly:
|
||||
self.mne.picks = self.mne.ch_order
|
||||
else:
|
||||
_slice = slice(self.mne.ch_start, self.mne.ch_start + self.mne.n_channels)
|
||||
self.mne.picks = self.mne.ch_order[_slice]
|
||||
self.mne.n_channels = len(self.mne.picks)
|
||||
assert isinstance(self.mne.picks, np.ndarray)
|
||||
assert self.mne.picks.dtype.kind == "i"
|
||||
|
||||
def _make_butterfly_selections_dict(self):
|
||||
"""Make an altered copy of the selections dict for butterfly mode."""
|
||||
selections_dict = deepcopy(self.mne.ch_selections)
|
||||
# remove potential duplicates
|
||||
for selection_group in ("Vertex", "Custom"):
|
||||
selections_dict.pop(selection_group, None)
|
||||
# if present, remove stim channel from non-misc selection groups
|
||||
stim_ch = _get_stim_channel(None, self.mne.info, raise_error=False)
|
||||
if len(stim_ch):
|
||||
stim_pick = self.mne.ch_names.tolist().index(stim_ch[0])
|
||||
for _sel, _picks in selections_dict.items():
|
||||
if _sel != "Misc":
|
||||
stim_mask = np.isin(_picks, [stim_pick], invert=True)
|
||||
selections_dict[_sel] = np.array(_picks)[stim_mask]
|
||||
return selections_dict
|
||||
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
# MANAGE DATA
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
|
||||
def _get_start_stop(self):
|
||||
# update time
|
||||
start_sec = self.mne.t_start - self.mne.first_time
|
||||
stop_sec = start_sec + self.mne.duration
|
||||
if self.mne.is_epochs:
|
||||
start, stop = np.round(
|
||||
np.array([start_sec, stop_sec]) * self.mne.info["sfreq"]
|
||||
).astype(int)
|
||||
else:
|
||||
start, stop = self.mne.inst.time_as_index((start_sec, stop_sec))
|
||||
|
||||
return start, stop
|
||||
|
||||
def _load_data(self, start=None, stop=None):
|
||||
"""Retrieve the bit of data we need for plotting."""
|
||||
if "raw" in (self.mne.instance_type, self.mne.ica_type):
|
||||
# Add additional sample to cover the case sfreq!=1000
|
||||
# when the shown time-range wouldn't correspond to duration anymore
|
||||
if stop is None:
|
||||
return self.mne.inst[:, start:]
|
||||
else:
|
||||
return self.mne.inst[:, start : stop + 2]
|
||||
else:
|
||||
# subtract one sample from tstart before searchsorted, to make sure
|
||||
# we land on the left side of the boundary time (avoid precision
|
||||
# errors)
|
||||
ix_start = np.searchsorted(
|
||||
self.mne.boundary_times, self.mne.t_start - self.mne.sampling_period
|
||||
)
|
||||
ix_stop = ix_start + self.mne.n_epochs
|
||||
item = slice(ix_start, ix_stop)
|
||||
data = np.concatenate(
|
||||
self.mne.inst.get_data(item=item, copy=False), axis=-1
|
||||
)
|
||||
times = np.arange(start, stop) / self.mne.info["sfreq"]
|
||||
return data, times
|
||||
|
||||
def _apply_filter(self, data, start, stop, picks):
|
||||
"""Filter (with same defaults as raw.filter())."""
|
||||
starts, stops = self.mne.filter_bounds
|
||||
mask = (starts < stop) & (stops > start)
|
||||
starts = np.maximum(starts[mask], start) - start
|
||||
stops = np.minimum(stops[mask], stop) - start
|
||||
for _start, _stop in zip(starts, stops):
|
||||
_picks = np.where(np.isin(picks, self.mne.picks_data))[0]
|
||||
if len(_picks) == 0:
|
||||
break
|
||||
this_data = data[_picks, _start:_stop]
|
||||
if isinstance(self.mne.filter_coefs, np.ndarray): # FIR
|
||||
this_data = _overlap_add_filter(
|
||||
this_data, self.mne.filter_coefs, copy=False
|
||||
)
|
||||
else: # IIR
|
||||
this_data = _iir_filter(
|
||||
this_data, self.mne.filter_coefs, None, 1, False
|
||||
)
|
||||
data[_picks, _start:_stop] = this_data
|
||||
|
||||
def _process_data(self, data, start, stop, picks, thread=None):
|
||||
"""Update self.mne.data after user interaction."""
|
||||
# apply projectors
|
||||
if self.mne.projector is not None:
|
||||
# thread is the loading-thread only available in Qt-backend
|
||||
if thread:
|
||||
thread.processText.emit("Applying Projectors...")
|
||||
data = self.mne.projector @ data
|
||||
# get only the channels we're displaying
|
||||
data = data[picks]
|
||||
# remove DC
|
||||
if self.mne.remove_dc:
|
||||
if thread:
|
||||
thread.processText.emit("Removing DC...")
|
||||
data -= np.nanmean(data, axis=1, keepdims=True)
|
||||
# apply filter
|
||||
if self.mne.filter_coefs is not None:
|
||||
if thread:
|
||||
thread.processText.emit("Apply Filter...")
|
||||
self._apply_filter(data, start, stop, picks)
|
||||
# scale the data for display in a 1-vertical-axis-unit slot
|
||||
if thread:
|
||||
thread.processText.emit("Scale Data...")
|
||||
this_names = self.mne.ch_names[picks]
|
||||
this_types = self.mne.ch_types[picks]
|
||||
stims = this_types == "stim"
|
||||
white = np.logical_and(
|
||||
np.isin(this_names, self.mne.whitened_ch_names),
|
||||
np.isin(this_names, self.mne.info["bads"], invert=True),
|
||||
)
|
||||
norms = np.vectorize(self.mne.scalings.__getitem__)(this_types)
|
||||
norms[stims] = data[stims].max(axis=-1)
|
||||
norms[white] = self.mne.scalings["whitened"]
|
||||
norms[norms == 0] = 1
|
||||
data /= 2 * norms[:, np.newaxis]
|
||||
|
||||
return data
|
||||
|
||||
def _update_data(self):
|
||||
start, stop = self._get_start_stop()
|
||||
# get the data
|
||||
data, times = self._load_data(start, stop)
|
||||
# process the data
|
||||
data = self._process_data(data, start, stop, self.mne.picks)
|
||||
# set the data as attributes
|
||||
self.mne.data = data
|
||||
self.mne.times = times
|
||||
|
||||
def _get_epoch_num_from_time(self, time):
|
||||
epoch_nums = self.mne.inst.selection
|
||||
return epoch_nums[np.searchsorted(self.mne.boundary_times[1:], time)]
|
||||
|
||||
def _redraw(self, update_data=True, annotations=False):
|
||||
"""Redraws backend if necessary."""
|
||||
if update_data:
|
||||
self._update_data()
|
||||
|
||||
self._draw_traces()
|
||||
|
||||
if annotations and not self.mne.is_epochs:
|
||||
self._draw_annotations()
|
||||
|
||||
def _close(self, event):
|
||||
"""Handle close events (via keypress or window [x])."""
|
||||
from matplotlib.pyplot import close
|
||||
|
||||
logger.debug(f"Closing {self.mne.instance_type} browser...")
|
||||
# write out bad epochs (after converting epoch numbers to indices)
|
||||
if self.mne.instance_type == "epochs":
|
||||
bad_ixs = np.isin(self.mne.inst.selection, self.mne.bad_epochs).nonzero()[0]
|
||||
self.mne.inst.drop(bad_ixs)
|
||||
logger.info(
|
||||
"The following epochs were marked as bad "
|
||||
"and are dropped:\n"
|
||||
f"{self.mne.bad_epochs}"
|
||||
)
|
||||
# write bad channels back to instance (don't do this for proj;
|
||||
# proj checkboxes are for viz only and shouldn't modify the instance)
|
||||
if self.mne.instance_type in ("raw", "epochs"):
|
||||
self.mne.inst.info["bads"] = self.mne.info["bads"]
|
||||
logger.info(f"Channels marked as bad:\n{self.mne.info['bads'] or 'none'}")
|
||||
# ICA excludes
|
||||
elif self.mne.instance_type == "ica":
|
||||
self.mne.ica.exclude = [
|
||||
self.mne.ica._ica_names.index(ch) for ch in self.mne.info["bads"]
|
||||
]
|
||||
# write window size to config
|
||||
str_size = ",".join([str(i) for i in self._get_size()])
|
||||
set_config("MNE_BROWSE_RAW_SIZE", str_size, set_env=False)
|
||||
# Clean up child figures (don't pop(), child figs remove themselves)
|
||||
while len(self.mne.child_figs):
|
||||
fig = self.mne.child_figs[-1]
|
||||
close(fig)
|
||||
self._close_event(fig)
|
||||
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
# CHILD FIGURES
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
@abstractmethod
|
||||
def _new_child_figure(self, fig_name, **kwargs):
|
||||
pass
|
||||
|
||||
def _create_ch_context_fig(self, idx):
|
||||
"""Show context figure; idx is index of **visible** channels."""
|
||||
inst = self.mne.instance_type
|
||||
pick = self.mne.picks[idx]
|
||||
if inst == "raw":
|
||||
fig = self._create_ch_location_fig(pick)
|
||||
elif inst == "ica":
|
||||
fig = self._create_ica_properties_fig(pick)
|
||||
else:
|
||||
fig = self._create_epoch_image_fig(pick)
|
||||
|
||||
return fig
|
||||
|
||||
def _create_ch_location_fig(self, pick):
|
||||
"""Show channel location figure."""
|
||||
from .utils import _channel_type_prettyprint, plot_sensors
|
||||
|
||||
ch_name = self.mne.ch_names[pick]
|
||||
ch_type = self.mne.ch_types[pick]
|
||||
if ch_type not in _DATA_CH_TYPES_SPLIT:
|
||||
return
|
||||
# create figure and axes
|
||||
title = f"Location of {ch_name}"
|
||||
fig = self._new_child_figure(figsize=(4, 4), fig_name=None, window_title=title)
|
||||
fig.suptitle(title)
|
||||
ax = fig.add_subplot(111)
|
||||
title = f"{ch_name} position ({_channel_type_prettyprint[ch_type]})"
|
||||
_ = plot_sensors(
|
||||
self.mne.info,
|
||||
ch_type=ch_type,
|
||||
axes=ax,
|
||||
title=title,
|
||||
kind="select",
|
||||
show=False,
|
||||
)
|
||||
# highlight desired channel & disable interactivity
|
||||
inds = np.isin(fig.lasso.ch_names, [ch_name])
|
||||
fig.lasso.disconnect()
|
||||
fig.lasso.alpha_other = 0.3
|
||||
fig.lasso.linewidth_selected = 3
|
||||
fig.lasso.style_sensors(inds)
|
||||
|
||||
return fig
|
||||
|
||||
def _create_ica_properties_fig(self, idx):
|
||||
"""Show ICA properties for the selected component."""
|
||||
from mne.viz.ica import (
|
||||
_create_properties_layout,
|
||||
_fast_plot_ica_properties,
|
||||
_prepare_data_ica_properties,
|
||||
)
|
||||
|
||||
ch_name = self.mne.ch_names[idx]
|
||||
if ch_name not in self.mne.ica._ica_names: # for EOG chans: do nothing
|
||||
return
|
||||
pick = self.mne.ica._ica_names.index(ch_name)
|
||||
title = f"{ch_name} properties"
|
||||
fig = self._new_child_figure(figsize=(7, 6), fig_name=None, window_title=title)
|
||||
fig.suptitle(title)
|
||||
fig, axes = _create_properties_layout(fig=fig)
|
||||
if not hasattr(self.mne, "data_ica_properties"):
|
||||
# Precompute epoch sources only once
|
||||
self.mne.data_ica_properties = _prepare_data_ica_properties(
|
||||
self.mne.ica_inst, self.mne.ica
|
||||
)
|
||||
_fast_plot_ica_properties(
|
||||
self.mne.ica,
|
||||
self.mne.ica_inst,
|
||||
picks=pick,
|
||||
axes=axes,
|
||||
psd_args=self.mne.psd_args,
|
||||
precomputed_data=self.mne.data_ica_properties,
|
||||
show=False,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def _create_epoch_image_fig(self, pick):
|
||||
"""Show epochs image for the selected channel."""
|
||||
from matplotlib.gridspec import GridSpec
|
||||
|
||||
from mne.viz import plot_epochs_image
|
||||
|
||||
ch_name = self.mne.ch_names[pick]
|
||||
title = f"Epochs image ({ch_name})"
|
||||
fig = self._new_child_figure(figsize=(6, 4), fig_name=None, window_title=title)
|
||||
fig.suptitle = title
|
||||
gs = GridSpec(nrows=3, ncols=10, figure=fig)
|
||||
fig.add_subplot(gs[:2, :9])
|
||||
fig.add_subplot(gs[2, :9])
|
||||
fig.add_subplot(gs[:2, 9])
|
||||
plot_epochs_image(self.mne.inst, picks=pick, fig=fig, show=False)
|
||||
|
||||
return fig
|
||||
|
||||
def _create_epoch_histogram(self):
|
||||
"""Create peak-to-peak histogram of channel amplitudes."""
|
||||
epochs = self.mne.inst
|
||||
data = OrderedDict()
|
||||
ptp = np.ptp(epochs.get_data(copy=False), axis=2)
|
||||
for ch_type in ("eeg", "mag", "grad"):
|
||||
if ch_type in epochs:
|
||||
data[ch_type] = ptp.T[self.mne.ch_types == ch_type].ravel()
|
||||
units = _handle_default("units")
|
||||
titles = _handle_default("titles")
|
||||
colors = _handle_default("color")
|
||||
scalings = _handle_default("scalings")
|
||||
title = "Histogram of peak-to-peak amplitudes"
|
||||
figsize = (4, 1 + 1.5 * len(data))
|
||||
fig = self._new_child_figure(
|
||||
figsize=figsize, fig_name="fig_histogram", window_title=title
|
||||
)
|
||||
for ix, (_ch_type, _data) in enumerate(data.items()):
|
||||
ax = fig.add_subplot(len(data), 1, ix + 1)
|
||||
ax.set(title=titles[_ch_type], xlabel=units[_ch_type], ylabel="Count")
|
||||
# set histogram bin range based on rejection thresholds
|
||||
reject = None
|
||||
_range = None
|
||||
if epochs.reject is not None and _ch_type in epochs.reject:
|
||||
reject = epochs.reject[_ch_type] * scalings[_ch_type]
|
||||
_range = (0.0, reject * 1.1)
|
||||
# plot it
|
||||
ax.hist(
|
||||
_data * scalings[_ch_type],
|
||||
bins=100,
|
||||
color=colors[_ch_type],
|
||||
range=_range,
|
||||
)
|
||||
if reject is not None:
|
||||
ax.plot((reject, reject), (0, ax.get_ylim()[1]), color="r")
|
||||
# finalize
|
||||
fig.suptitle(title, y=0.99)
|
||||
self.mne.fig_histogram = fig
|
||||
|
||||
return fig
|
||||
|
||||
def _close_event(self, fig):
|
||||
"""Look at _close_event in mne.fixes.py for why this exists."""
|
||||
pass
|
||||
|
||||
def fake_keypress(self, key, fig=None): # noqa: D400
|
||||
"""Pass a fake keypress to the figure.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The key to fake (e.g., ``'a'``).
|
||||
fig : instance of Figure
|
||||
The figure to pass the keypress to.
|
||||
"""
|
||||
return self._fake_keypress(key, fig=fig)
|
||||
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
# TEST METHODS
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
|
||||
@abstractmethod
|
||||
def _get_size(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _fake_keypress(self, key, fig):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _fake_click(self, point, fig, axis, xform, button, kind):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _click_ch_name(self, ch_index, button):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _resize_by_factor(self, factor):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_ticklabels(self, orientation):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _update_yaxis_labels(self):
|
||||
pass
|
||||
|
||||
|
||||
def _load_backend(backend_name):
|
||||
global backend
|
||||
if backend_name == "matplotlib":
|
||||
backend = importlib.import_module(name="._mpl_figure", package="mne.viz")
|
||||
else:
|
||||
from mne_qt_browser import _pg_figure as backend
|
||||
|
||||
logger.info(f"Using {backend_name} as 2D backend.")
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
def _get_browser(show, block, **kwargs):
|
||||
"""Instantiate a new MNE browse-style figure."""
|
||||
from .utils import _get_figsize_from_config
|
||||
|
||||
figsize = kwargs.setdefault("figsize", _get_figsize_from_config())
|
||||
if figsize is None or np.any(np.array(figsize) < 8):
|
||||
kwargs["figsize"] = (8, 8)
|
||||
kwargs["splash"] = kwargs.get("splash", True) and show
|
||||
if kwargs.get("theme", None) is None:
|
||||
kwargs["theme"] = get_config("MNE_BROWSER_THEME", "auto")
|
||||
if kwargs.get("overview_mode", None) is None:
|
||||
kwargs["overview_mode"] = get_config("MNE_BROWSER_OVERVIEW_MODE", "channels")
|
||||
|
||||
# Initialize browser backend
|
||||
backend_name = get_browser_backend()
|
||||
# Check mne-qt-browser compatibility
|
||||
if backend_name == "qt":
|
||||
import mne_qt_browser
|
||||
|
||||
from ..epochs import BaseEpochs
|
||||
|
||||
is_ica = kwargs.get("ica", False)
|
||||
is_epochs = isinstance(kwargs.get("inst", False), BaseEpochs)
|
||||
not_compat = _compare_version(mne_qt_browser.__version__, "<", "0.2.0")
|
||||
inst_str = "ICA" if is_ica else "Epochs"
|
||||
if not_compat and (is_ica or is_epochs):
|
||||
logger.info(
|
||||
f'You set the browser-backend to "qt" but your'
|
||||
f" current version {mne_qt_browser.__version__}"
|
||||
f" of mne-qt-browser is too low for {inst_str}."
|
||||
f"Update with pip or conda."
|
||||
f"Defaults to matplotlib."
|
||||
)
|
||||
with use_browser_backend("matplotlib"):
|
||||
# Initialize Browser
|
||||
fig = backend._init_browser(**kwargs)
|
||||
_show_browser(show=show, block=block, fig=fig)
|
||||
return fig
|
||||
|
||||
# Initialize Browser
|
||||
fig = backend._init_browser(**kwargs)
|
||||
_show_browser(show=show, block=block, fig=fig)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _check_browser_backend_name(backend_name):
|
||||
_validate_type(backend_name, str, "backend_name")
|
||||
backend_name = backend_name.lower()
|
||||
backend_name = "qt" if backend_name == "pyqtgraph" else backend_name
|
||||
_check_option("backend_name", backend_name, VALID_BROWSE_BACKENDS)
|
||||
return backend_name
|
||||
|
||||
|
||||
@verbose
|
||||
def set_browser_backend(backend_name, verbose=None):
|
||||
"""Set the 2D browser backend for MNE.
|
||||
|
||||
The backend will be set as specified and operations will use
|
||||
that backend.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend_name : str
|
||||
The 2D browser backend to select. See Notes for the capabilities
|
||||
of each backend (``'qt'``, ``'matplotlib'``). The ``'qt'`` browser
|
||||
requires `mne-qt-browser
|
||||
<https://github.com/mne-tools/mne-qt-browser>`__.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
old_backend_name : str | None
|
||||
The old backend that was in use.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This table shows the capabilities of each backend ("✓" for full support,
|
||||
and "-" for partial support):
|
||||
|
||||
.. table::
|
||||
:widths: auto
|
||||
|
||||
+--------------------------------------+------------+----+
|
||||
| **2D browser function:** | matplotlib | qt |
|
||||
+======================================+============+====+
|
||||
| :func:`plot_raw` | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| :func:`plot_epochs` | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| :func:`plot_ica_sources` | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
+--------------------------------------+------------+----+
|
||||
| **Feature:** |
|
||||
+--------------------------------------+------------+----+
|
||||
| Show Events | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Add/Edit/Remove Annotations | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Toggle Projections | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Butterfly Mode | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Selection Mode | ✓ | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Smooth Scrolling | | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
| Overview-Bar (with Z-Score-Mode) | | ✓ |
|
||||
+--------------------------------------+------------+----+
|
||||
|
||||
.. versionadded:: 0.24
|
||||
"""
|
||||
global MNE_BROWSER_BACKEND
|
||||
old_backend_name = MNE_BROWSER_BACKEND
|
||||
backend_name = _check_browser_backend_name(backend_name)
|
||||
if MNE_BROWSER_BACKEND != backend_name:
|
||||
_load_backend(backend_name)
|
||||
MNE_BROWSER_BACKEND = backend_name
|
||||
|
||||
return old_backend_name
|
||||
|
||||
|
||||
def _init_browser_backend():
|
||||
global MNE_BROWSER_BACKEND
|
||||
|
||||
# check if MNE_BROWSER_BACKEND is not None and valid or get it from config
|
||||
loaded_backend = MNE_BROWSER_BACKEND or get_config(
|
||||
key="MNE_BROWSER_BACKEND", default=None
|
||||
)
|
||||
if loaded_backend is not None:
|
||||
set_browser_backend(loaded_backend)
|
||||
return MNE_BROWSER_BACKEND
|
||||
else:
|
||||
errors = dict()
|
||||
# Try import of valid browser backends
|
||||
for name in VALID_BROWSE_BACKENDS:
|
||||
try:
|
||||
_load_backend(name)
|
||||
except ImportError as exc:
|
||||
errors[name] = str(exc)
|
||||
else:
|
||||
MNE_BROWSER_BACKEND = name
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Could not load any valid 2D backend:\n"
|
||||
+ "\n".join(f"{key}: {val}" for key, val in errors.items())
|
||||
)
|
||||
|
||||
return MNE_BROWSER_BACKEND
|
||||
|
||||
|
||||
def get_browser_backend():
|
||||
"""Return the 2D backend currently used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
backend_used : str | None
|
||||
The 2D browser backend currently in use. If no backend is found,
|
||||
returns ``None``.
|
||||
"""
|
||||
try:
|
||||
backend_name = _init_browser_backend()
|
||||
except RuntimeError as exc:
|
||||
backend_name = None
|
||||
logger.info(str(exc))
|
||||
return backend_name
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_browser_backend(backend_name):
|
||||
"""Create a 2D browser visualization context using the designated backend.
|
||||
|
||||
See :func:`mne.viz.set_browser_backend` for more details on the available
|
||||
2D browser backends and their capabilities.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend_name : {'qt', 'matplotlib'}
|
||||
The 2D browser backend to use in the context.
|
||||
"""
|
||||
old_backend = set_browser_backend(backend_name)
|
||||
try:
|
||||
yield backend
|
||||
finally:
|
||||
if old_backend is not None:
|
||||
try:
|
||||
set_browser_backend(old_backend)
|
||||
except Exception:
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,252 @@
|
||||
"""Functions for plotting projectors."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_to_idx
|
||||
from ..defaults import DEFAULTS
|
||||
from ..utils import _pl, _validate_type, verbose, warn
|
||||
from .evoked import _plot_evoked
|
||||
from .topomap import _plot_projs_topomap
|
||||
from .utils import _check_type_projs, plt_show
|
||||
|
||||
|
||||
@verbose
|
||||
def plot_projs_joint(
|
||||
projs, evoked, picks_trace=None, *, topomap_kwargs=None, show=True, verbose=None
|
||||
):
|
||||
"""Plot projectors and evoked jointly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
projs : list of Projection
|
||||
The projectors to plot.
|
||||
evoked : instance of Evoked
|
||||
The data to plot. Typically this is the evoked instance created from
|
||||
averaging the epochs used to create the projection.
|
||||
%(picks_plot_projs_joint_trace)s
|
||||
topomap_kwargs : dict | None
|
||||
Keyword arguments to pass to :func:`mne.viz.plot_projs_topomap`.
|
||||
%(show)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib Figure
|
||||
The figure.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function creates a figure with three columns:
|
||||
|
||||
1. The left shows the evoked data traces before (black) and after (green)
|
||||
projection.
|
||||
2. The center shows the topomaps associated with each of the projectors.
|
||||
3. The right again shows the data traces (black), but this time with:
|
||||
|
||||
1. The data projected onto each projector with a single normalization
|
||||
factor (solid lines). This is useful for seeing the relative power
|
||||
in each projection vector.
|
||||
2. The data projected onto each projector with individual normalization
|
||||
factors (dashed lines). This is useful for visualizing each time
|
||||
course regardless of its power.
|
||||
3. Additional data traces from ``picks_trace`` (solid yellow lines).
|
||||
This is useful for visualizing the "ground truth" of the time
|
||||
course, e.g. the measured EOG or ECG channel time courses.
|
||||
|
||||
.. versionadded:: 1.1
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..evoked import Evoked
|
||||
|
||||
_validate_type(evoked, Evoked, "evoked")
|
||||
_validate_type(topomap_kwargs, (None, dict), "topomap_kwargs")
|
||||
projs = _check_type_projs(projs)
|
||||
topomap_kwargs = dict() if topomap_kwargs is None else topomap_kwargs
|
||||
if picks_trace is not None:
|
||||
picks_trace = _picks_to_idx(evoked.info, picks_trace, allow_empty=False)
|
||||
info = evoked.info
|
||||
ch_types = evoked.get_channel_types(unique=True, only_data_chs=True)
|
||||
proj_by_type = dict() # will be set up like an enumerate key->[pi, proj]
|
||||
ch_names_by_type = dict()
|
||||
used = np.zeros(len(projs), int)
|
||||
for ch_type in ch_types:
|
||||
these_picks = _picks_to_idx(info, ch_type, allow_empty=True)
|
||||
these_chs = [evoked.ch_names[pick] for pick in these_picks]
|
||||
ch_names_by_type[ch_type] = these_chs
|
||||
for pi, proj in enumerate(projs):
|
||||
if not set(these_chs).intersection(proj["data"]["col_names"]):
|
||||
continue
|
||||
if ch_type not in proj_by_type:
|
||||
proj_by_type[ch_type] = list()
|
||||
proj_by_type[ch_type].append([pi, deepcopy(proj)])
|
||||
used[pi] += 1
|
||||
missing = (~used.astype(bool)).sum()
|
||||
if missing:
|
||||
warn(
|
||||
f"{missing} projector{_pl(missing)} had no channel names "
|
||||
"present in epochs"
|
||||
)
|
||||
del projs
|
||||
ch_types = list(proj_by_type) # reduce to number we actually need
|
||||
# room for legend
|
||||
max_proj_per_type = max(len(x) for x in proj_by_type.values())
|
||||
cs_trace = 3
|
||||
cs_topo = 2
|
||||
n_col = max_proj_per_type * cs_topo + 2 * cs_trace
|
||||
n_row = len(ch_types)
|
||||
shape = (n_row, n_col)
|
||||
fig = plt.figure(
|
||||
figsize=(n_col * 1.1 + 0.5, n_row * 1.8 + 0.5), layout="constrained"
|
||||
)
|
||||
ri = 0
|
||||
# pick some sufficiently distinct colors (6 per proj type, e.g., ECG,
|
||||
# should be enough hopefully!)
|
||||
# https://personal.sron.nl/~pault/data/colourschemes.pdf
|
||||
# "Vibrant" color scheme
|
||||
proj_colors = [
|
||||
"#CC3311", # red
|
||||
"#009988", # teal
|
||||
"#0077BB", # blue
|
||||
"#EE3377", # magenta
|
||||
"#EE7733", # orange
|
||||
"#33BBEE", # cyan
|
||||
]
|
||||
trace_color = "#CCBB44" # yellow
|
||||
after_color, after_name = "#228833", "green"
|
||||
type_titles = DEFAULTS["titles"]
|
||||
last_ax = [None] * 2
|
||||
first_ax = dict()
|
||||
pe_kwargs = dict(show=False, draw=False)
|
||||
for ch_type, these_projs in proj_by_type.items():
|
||||
these_idxs, these_projs = zip(*these_projs)
|
||||
ch_names = ch_names_by_type[ch_type]
|
||||
idx = np.where(
|
||||
[np.isin(ch_names, proj["data"]["col_names"]).all() for proj in these_projs]
|
||||
)[0]
|
||||
used[idx] += 1
|
||||
count = len(these_projs)
|
||||
for proj in these_projs:
|
||||
sub_idx = [proj["data"]["col_names"].index(name) for name in ch_names]
|
||||
proj["data"]["data"] = proj["data"]["data"][:, sub_idx]
|
||||
proj["data"]["col_names"] = ch_names
|
||||
ba_ax = plt.subplot2grid(shape, (ri, 0), colspan=cs_trace, fig=fig)
|
||||
topo_axes = [
|
||||
plt.subplot2grid(
|
||||
shape, (ri, ci * cs_topo + cs_trace), colspan=cs_topo, fig=fig
|
||||
)
|
||||
for ci in range(count)
|
||||
]
|
||||
tr_ax = plt.subplot2grid(
|
||||
shape, (ri, n_col - cs_trace), colspan=cs_trace, fig=fig
|
||||
)
|
||||
# topomaps
|
||||
_plot_projs_topomap(these_projs, info=info, axes=topo_axes, **topomap_kwargs)
|
||||
for idx, proj, ax_ in zip(these_idxs, these_projs, topo_axes):
|
||||
ax_.set_title("") # could use proj['desc'] but it's long
|
||||
ax_.set_xlabel(f"projs[{idx}]", fontsize="small")
|
||||
unit = DEFAULTS["units"][ch_type]
|
||||
# traces
|
||||
this_evoked = evoked.copy().pick(ch_names)
|
||||
p = np.concatenate([p["data"]["data"] for p in these_projs])
|
||||
assert p.shape == (len(these_projs), len(this_evoked.data))
|
||||
traces = np.dot(p, this_evoked.data)
|
||||
traces *= np.sign(np.mean(np.dot(this_evoked.data, traces.T), 0))[:, np.newaxis]
|
||||
if picks_trace is not None:
|
||||
ch_traces = evoked.data[picks_trace]
|
||||
ch_traces -= np.mean(ch_traces, axis=1, keepdims=True)
|
||||
ch_traces /= np.abs(ch_traces).max()
|
||||
_plot_evoked(
|
||||
this_evoked, picks="all", axes=[tr_ax], **pe_kwargs, spatial_colors=False
|
||||
)
|
||||
for line in tr_ax.lines:
|
||||
line.set(lw=0.5, zorder=3)
|
||||
for t in list(tr_ax.texts):
|
||||
t.remove()
|
||||
scale = 0.8 * np.abs(tr_ax.get_ylim()).max()
|
||||
hs, labels = list(), list()
|
||||
traces /= np.abs(traces).max() # uniformly scaled
|
||||
for ti, trace in enumerate(traces):
|
||||
hs.append(
|
||||
tr_ax.plot(
|
||||
this_evoked.times,
|
||||
trace * scale,
|
||||
color=proj_colors[ti % len(proj_colors)],
|
||||
zorder=5,
|
||||
)[0]
|
||||
)
|
||||
labels.append(f"projs[{these_idxs[ti]}]")
|
||||
traces /= np.abs(traces).max(1, keepdims=True) # independently
|
||||
for ti, trace in enumerate(traces):
|
||||
tr_ax.plot(
|
||||
this_evoked.times,
|
||||
trace * scale,
|
||||
color=proj_colors[ti % len(proj_colors)],
|
||||
zorder=3.5,
|
||||
ls="--",
|
||||
lw=1.0,
|
||||
alpha=0.75,
|
||||
)
|
||||
if picks_trace is not None:
|
||||
trace_ch = [evoked.ch_names[pick] for pick in picks_trace]
|
||||
if len(picks_trace) == 1:
|
||||
trace_ch = trace_ch[0]
|
||||
hs.append(
|
||||
tr_ax.plot(
|
||||
this_evoked.times,
|
||||
ch_traces.T * scale,
|
||||
color=trace_color,
|
||||
lw=3,
|
||||
zorder=4,
|
||||
alpha=0.75,
|
||||
)[0]
|
||||
)
|
||||
labels.append(str(trace_ch))
|
||||
tr_ax.set(title="", xlabel="", ylabel="")
|
||||
# This will steal space from the subplots in a constrained layout
|
||||
# https://matplotlib.org/3.5.0/tutorials/intermediate/constrainedlayout_guide.html#legends # noqa: E501
|
||||
tr_ax.legend(
|
||||
hs,
|
||||
labels,
|
||||
loc="center left",
|
||||
borderaxespad=0.05,
|
||||
bbox_to_anchor=[1.05, 0.5],
|
||||
)
|
||||
last_ax[1] = tr_ax
|
||||
key = "Projected time course"
|
||||
if key not in first_ax:
|
||||
first_ax[key] = tr_ax
|
||||
# Before and after traces
|
||||
_plot_evoked(this_evoked, picks="all", axes=[ba_ax], **pe_kwargs)
|
||||
for line in ba_ax.lines:
|
||||
line.set(lw=0.5, zorder=3)
|
||||
loff = len(ba_ax.lines)
|
||||
this_proj_evoked = this_evoked.copy().add_proj(these_projs)
|
||||
# with meg='combined' any existing mag projectors (those already part
|
||||
# of evoked before we add_proj above) will have greatly
|
||||
# reduced power, so we ignore the warning about this issue
|
||||
this_proj_evoked.apply_proj(verbose="error")
|
||||
_plot_evoked(this_proj_evoked, picks="all", axes=[ba_ax], **pe_kwargs)
|
||||
for line in ba_ax.lines[loff:]:
|
||||
line.set(lw=0.5, zorder=4, color=after_color)
|
||||
for t in list(ba_ax.texts):
|
||||
t.remove()
|
||||
ba_ax.set(title="", xlabel="")
|
||||
ba_ax.set(ylabel=f"{type_titles[ch_type]}\n{unit}")
|
||||
last_ax[0] = ba_ax
|
||||
key = f"Before (black) and after ({after_name})"
|
||||
if key not in first_ax:
|
||||
first_ax[key] = ba_ax
|
||||
ri += 1
|
||||
for ax in last_ax:
|
||||
ax.set(xlabel="Time (s)")
|
||||
for title, ax in first_ax.items():
|
||||
ax.set_title(title, fontsize="medium")
|
||||
plt_show(show)
|
||||
return fig
|
||||
@@ -0,0 +1,81 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ..utils import _pl
|
||||
from .backends._utils import _pixmap_to_ndarray
|
||||
|
||||
|
||||
class _MNEQtBrowserScraper:
|
||||
def __repr__(self):
|
||||
return "<MNEQtBrowserScraper>"
|
||||
|
||||
def __call__(self, block, block_vars, gallery_conf):
|
||||
import mne_qt_browser
|
||||
from sphinx_gallery.scrapers import figure_rst
|
||||
|
||||
if gallery_conf["builder_name"] != "html":
|
||||
return ""
|
||||
img_fnames = list()
|
||||
inst = None
|
||||
n_plot = 0
|
||||
for gui in list(mne_qt_browser._browser_instances):
|
||||
try:
|
||||
scraped = getattr(gui, "_scraped", False)
|
||||
except Exception: # super __init__ not called, perhaps stale?
|
||||
scraped = True
|
||||
if scraped:
|
||||
continue
|
||||
gui._scraped = True # monkey-patch but it's easy enough
|
||||
n_plot += 1
|
||||
img_fnames.append(next(block_vars["image_path_iterator"]))
|
||||
pixmap, inst = _mne_qt_browser_screenshot(gui, inst)
|
||||
pixmap.save(img_fnames[-1])
|
||||
# child figures
|
||||
for fig in gui.mne.child_figs:
|
||||
# For now we only support Selection
|
||||
if not hasattr(fig, "channel_fig"):
|
||||
continue
|
||||
fig = fig.channel_fig
|
||||
img_fnames.append(next(block_vars["image_path_iterator"]))
|
||||
fig.savefig(img_fnames[-1])
|
||||
gui.close()
|
||||
del gui, pixmap
|
||||
if not len(img_fnames):
|
||||
return ""
|
||||
for _ in range(2):
|
||||
inst.processEvents()
|
||||
return figure_rst(img_fnames, gallery_conf["src_dir"], f"Raw plot{_pl(n_plot)}")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _screenshot_mode(browser):
|
||||
if need_zen := browser.mne.scrollbars_visible:
|
||||
browser._toggle_zenmode()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if need_zen:
|
||||
browser._toggle_zenmode()
|
||||
|
||||
|
||||
def _mne_qt_browser_screenshot(browser, inst=None, return_type="pixmap"):
|
||||
from mne_qt_browser._pg_figure import QApplication
|
||||
|
||||
if getattr(browser, "load_thread", None) is not None:
|
||||
if browser.load_thread.isRunning():
|
||||
browser.load_thread.wait(30000)
|
||||
if inst is None:
|
||||
inst = QApplication.instance()
|
||||
# processEvents to make sure our progressBar is updated
|
||||
with _screenshot_mode(browser):
|
||||
for _ in range(2):
|
||||
inst.processEvents()
|
||||
pixmap = browser.grab()
|
||||
assert return_type in ("pixmap", "ndarray")
|
||||
if return_type == "ndarray":
|
||||
return _pixmap_to_ndarray(pixmap)
|
||||
else:
|
||||
return pixmap, inst
|
||||
@@ -0,0 +1,7 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
"""Visualization backend."""
|
||||
|
||||
from . import renderer
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,421 @@
|
||||
#
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import collections.abc
|
||||
import functools
|
||||
import os
|
||||
import platform
|
||||
import signal
|
||||
import sys
|
||||
from colorsys import rgb_to_hls
|
||||
from contextlib import contextmanager
|
||||
from ctypes import c_char_p, c_void_p, cdll
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...fixes import _compare_version
|
||||
from ...utils import _check_qt_version, _validate_type, logger, warn
|
||||
from ..utils import _get_cmap
|
||||
|
||||
VALID_BROWSE_BACKENDS = (
|
||||
"qt",
|
||||
"matplotlib",
|
||||
)
|
||||
|
||||
VALID_3D_BACKENDS = (
|
||||
"pyvistaqt", # default 3d backend
|
||||
"notebook",
|
||||
)
|
||||
ALLOWED_QUIVER_MODES = ("2darrow", "arrow", "cone", "cylinder", "sphere", "oct")
|
||||
_ICONS_PATH = Path(__file__).parents[2] / "icons"
|
||||
|
||||
|
||||
def _get_colormap_from_array(
|
||||
colormap=None, normalized_colormap=False, default_colormap="coolwarm"
|
||||
):
|
||||
from matplotlib.colors import ListedColormap
|
||||
|
||||
if colormap is None:
|
||||
cmap = _get_cmap(default_colormap)
|
||||
elif isinstance(colormap, str):
|
||||
cmap = _get_cmap(colormap)
|
||||
elif normalized_colormap:
|
||||
cmap = ListedColormap(colormap)
|
||||
else:
|
||||
cmap = ListedColormap(np.array(colormap) / 255.0)
|
||||
return cmap
|
||||
|
||||
|
||||
def _check_color(color):
|
||||
from matplotlib.colors import colorConverter
|
||||
|
||||
if isinstance(color, str):
|
||||
color = colorConverter.to_rgb(color)
|
||||
elif isinstance(color, collections.abc.Iterable):
|
||||
np_color = np.array(color)
|
||||
if np_color.size % 3 != 0 and np_color.size % 4 != 0:
|
||||
raise ValueError("The expected valid format is RGB or RGBA.")
|
||||
if np_color.dtype in (np.int64, np.int32):
|
||||
if (np_color < 0).any() or (np_color > 255).any():
|
||||
raise ValueError("Values out of range [0, 255].")
|
||||
elif np_color.dtype == np.float64:
|
||||
if (np_color < 0.0).any() or (np_color > 1.0).any():
|
||||
raise ValueError("Values out of range [0.0, 1.0].")
|
||||
else:
|
||||
raise TypeError(
|
||||
"Expected data type is `np.int64`, `np.int32`, or `np.float64` but "
|
||||
f"{np_color.dtype} was given."
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected type is `str` or iterable but {type(color)} was given."
|
||||
)
|
||||
return color
|
||||
|
||||
|
||||
def _alpha_blend_background(ctable, background_color):
|
||||
alphas = ctable[:, -1][:, np.newaxis] / 255.0
|
||||
use_table = ctable.copy()
|
||||
use_table[:, -1] = 255.0
|
||||
return (use_table * alphas) + background_color * (1 - alphas)
|
||||
|
||||
|
||||
@functools.lru_cache(1)
|
||||
def _qt_init_icons():
|
||||
from qtpy.QtGui import QIcon
|
||||
|
||||
QIcon.setThemeSearchPaths([str(_ICONS_PATH)] + QIcon.themeSearchPaths())
|
||||
QIcon.setFallbackThemeName("light")
|
||||
return str(_ICONS_PATH)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _qt_disable_paint(widget):
|
||||
paintEvent = widget.paintEvent
|
||||
widget.paintEvent = lambda *args, **kwargs: None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
widget.paintEvent = paintEvent
|
||||
|
||||
|
||||
_QT_ICON_KEYS = dict(app=None)
|
||||
|
||||
|
||||
def _init_mne_qtapp(enable_icon=True, pg_app=False, splash=False):
|
||||
"""Get QApplication-instance for MNE-Python.
|
||||
|
||||
Parameter
|
||||
---------
|
||||
enable_icon: bool
|
||||
If to set an MNE-icon for the app.
|
||||
pg_app: bool
|
||||
If to create the QApplication with pyqtgraph. For an until know
|
||||
undiscovered reason the pyqtgraph-browser won't show without
|
||||
mkQApp from pyqtgraph.
|
||||
splash : bool | str
|
||||
If not False, display a splash screen. If str, set the message
|
||||
to the given string.
|
||||
|
||||
Returns
|
||||
-------
|
||||
app : ``qtpy.QtWidgets.QApplication``
|
||||
Instance of QApplication.
|
||||
splash : ``qtpy.QtWidgets.QSplashScreen``
|
||||
Instance of QSplashScreen. Only returned if splash is True or a
|
||||
string.
|
||||
"""
|
||||
from qtpy.QtCore import Qt
|
||||
from qtpy.QtGui import QGuiApplication, QIcon, QPixmap
|
||||
from qtpy.QtWidgets import QApplication, QSplashScreen
|
||||
|
||||
app_name = "MNE-Python"
|
||||
organization_name = "MNE"
|
||||
|
||||
# Fix from cbrnr/mnelab for app name in menu bar
|
||||
# This has to come *before* the creation of the QApplication to work.
|
||||
# It also only affects the title bar, not the application dock.
|
||||
# There seems to be no way to change the application dock from "python"
|
||||
# at runtime.
|
||||
if sys.platform.startswith("darwin"):
|
||||
try:
|
||||
# set bundle name on macOS (app name shown in the menu bar)
|
||||
from Foundation import NSBundle
|
||||
|
||||
bundle = NSBundle.mainBundle()
|
||||
info = bundle.localizedInfoDictionary() or bundle.infoDictionary()
|
||||
if "CFBundleName" not in info:
|
||||
info["CFBundleName"] = app_name
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
# First we need to check to make sure the display is valid, otherwise
|
||||
# Qt might segfault on us
|
||||
app = QApplication.instance()
|
||||
if not (app or _display_is_valid()):
|
||||
raise RuntimeError("Cannot connect to a valid display")
|
||||
|
||||
if pg_app:
|
||||
from pyqtgraph import mkQApp
|
||||
|
||||
old_argv = sys.argv
|
||||
try:
|
||||
sys.argv = []
|
||||
app = mkQApp(app_name)
|
||||
finally:
|
||||
sys.argv = old_argv
|
||||
elif not app:
|
||||
app = QApplication([app_name])
|
||||
app.setApplicationName(app_name)
|
||||
app.setOrganizationName(organization_name)
|
||||
qt_version = _check_qt_version(check_usable_display=False)
|
||||
# HiDPI is enabled by default in Qt6, requires to be explicitly set for Qt5
|
||||
if _compare_version(qt_version, "<", "6.0"):
|
||||
app.setAttribute(Qt.AA_UseHighDpiPixmaps)
|
||||
|
||||
if enable_icon or splash:
|
||||
icons_path = _qt_init_icons()
|
||||
|
||||
if (
|
||||
enable_icon
|
||||
and app.windowIcon().cacheKey() != _QT_ICON_KEYS["app"]
|
||||
and app.windowIcon().isNull() # don't overwrite existing icon (e.g. MNELAB)
|
||||
):
|
||||
# Set icon
|
||||
kind = "bigsur_" if platform.mac_ver()[0] >= "10.16" else "default_"
|
||||
icon = QIcon(f"{icons_path}/mne_{kind}icon.png")
|
||||
app.setWindowIcon(icon)
|
||||
_QT_ICON_KEYS["app"] = app.windowIcon().cacheKey()
|
||||
|
||||
out = app
|
||||
if splash:
|
||||
pixmap = QPixmap(f"{icons_path}/mne_splash.png")
|
||||
pixmap.setDevicePixelRatio(QGuiApplication.primaryScreen().devicePixelRatio())
|
||||
args = (pixmap,)
|
||||
if _should_raise_window():
|
||||
args += (Qt.WindowStaysOnTopHint,)
|
||||
qsplash = QSplashScreen(*args)
|
||||
qsplash.setAttribute(Qt.WA_ShowWithoutActivating, True)
|
||||
if isinstance(splash, str):
|
||||
alignment = int(Qt.AlignBottom | Qt.AlignHCenter)
|
||||
qsplash.showMessage(splash, alignment=alignment, color=Qt.white)
|
||||
qsplash.show()
|
||||
app.processEvents()
|
||||
out = (out, qsplash)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _display_is_valid():
|
||||
# Adapted from matplotilb _c_internal_utils.py
|
||||
if sys.platform != "linux":
|
||||
return True
|
||||
if os.getenv("DISPLAY"): # if it's not there, don't bother
|
||||
libX11 = cdll.LoadLibrary("libX11.so.6")
|
||||
libX11.XOpenDisplay.restype = c_void_p
|
||||
libX11.XOpenDisplay.argtypes = [c_char_p]
|
||||
display = libX11.XOpenDisplay(None)
|
||||
if display is not None:
|
||||
libX11.XCloseDisplay.argtypes = [c_void_p]
|
||||
libX11.XCloseDisplay(display)
|
||||
return True
|
||||
# not found, try Wayland
|
||||
if os.getenv("WAYLAND_DISPLAY"):
|
||||
libwayland = cdll.LoadLibrary("libwayland-client.so.0")
|
||||
if libwayland is not None:
|
||||
if all(
|
||||
hasattr(libwayland, f"wl_display_{kind}connect") for kind in ("", "dis")
|
||||
):
|
||||
libwayland.wl_display_connect.restype = c_void_p
|
||||
libwayland.wl_display_connect.argtypes = [c_char_p]
|
||||
display = libwayland.wl_display_connect(None)
|
||||
if display:
|
||||
libwayland.wl_display_disconnect.argtypes = [c_void_p]
|
||||
libwayland.wl_display_disconnect(display)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/5160577/ctrl-c-doesnt-work-with-pyqt
|
||||
def _qt_app_exec(app):
|
||||
# adapted from matplotlib
|
||||
old_signal = signal.getsignal(signal.SIGINT)
|
||||
is_python_signal_handler = old_signal is not None
|
||||
if is_python_signal_handler:
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
try:
|
||||
# Make IPython Console accessible again in Spyder
|
||||
app.lastWindowClosed.connect(app.quit)
|
||||
app.exec_()
|
||||
finally:
|
||||
# reset the SIGINT exception handler
|
||||
if is_python_signal_handler:
|
||||
signal.signal(signal.SIGINT, old_signal)
|
||||
|
||||
|
||||
def _qt_detect_theme():
|
||||
try:
|
||||
import darkdetect
|
||||
|
||||
theme = darkdetect.theme().lower()
|
||||
except ModuleNotFoundError:
|
||||
logger.info(
|
||||
'For automatic theme detection, "darkdetect" has to'
|
||||
" be installed! You can install it with "
|
||||
"`pip install darkdetect`"
|
||||
)
|
||||
theme = "light"
|
||||
except Exception:
|
||||
theme = "light"
|
||||
return theme
|
||||
|
||||
|
||||
def _qt_get_stylesheet(theme):
|
||||
_validate_type(theme, ("path-like",), "theme")
|
||||
theme = str(theme)
|
||||
stylesheet = "" # no stylesheet
|
||||
if theme in ("auto", "dark", "light"):
|
||||
if theme == "auto":
|
||||
return stylesheet
|
||||
assert theme in ("dark", "light")
|
||||
system_theme = _qt_detect_theme()
|
||||
if theme == system_theme:
|
||||
return stylesheet
|
||||
_, api = _check_qt_version(return_api=True)
|
||||
# On macOS or Qt 6, we shouldn't need to set anything when the requested
|
||||
# theme matches that of the current OS state
|
||||
try:
|
||||
import qdarkstyle
|
||||
except ModuleNotFoundError:
|
||||
logger.info(
|
||||
f'To use {theme} mode when in {system_theme} mode, "qdarkstyle" has'
|
||||
"to be installed! You can install it with:\n"
|
||||
"pip install qdarkstyle\n"
|
||||
)
|
||||
else:
|
||||
if api in ("PySide6", "PyQt6") and _compare_version(
|
||||
qdarkstyle.__version__, "<", "3.2.3"
|
||||
):
|
||||
warn(
|
||||
f"Setting theme={repr(theme)} is not supported for {api} in "
|
||||
f"qdarkstyle {qdarkstyle.__version__}, it will be ignored. "
|
||||
"Consider upgrading qdarkstyle to >=3.2.3."
|
||||
)
|
||||
else:
|
||||
stylesheet = qdarkstyle.load_stylesheet(
|
||||
getattr(
|
||||
getattr(qdarkstyle, theme).palette,
|
||||
f"{theme.capitalize()}Palette",
|
||||
)
|
||||
)
|
||||
return stylesheet
|
||||
else:
|
||||
try:
|
||||
file = open(theme)
|
||||
except OSError:
|
||||
warn(
|
||||
"Requested theme file not found, will use light instead: "
|
||||
f"{repr(theme)}"
|
||||
)
|
||||
else:
|
||||
with file as fid:
|
||||
stylesheet = fid.read()
|
||||
return stylesheet
|
||||
|
||||
|
||||
def _should_raise_window():
|
||||
from matplotlib import rcParams
|
||||
|
||||
return rcParams["figure.raise_window"]
|
||||
|
||||
|
||||
def _qt_raise_window(widget):
|
||||
# Set raise_window like matplotlib if possible
|
||||
if _should_raise_window():
|
||||
widget.activateWindow()
|
||||
widget.raise_()
|
||||
|
||||
|
||||
def _qt_is_dark(widget):
|
||||
# Ideally this would use CIELab, but this should be good enough
|
||||
win = widget.window()
|
||||
bgcolor = win.palette().color(win.backgroundRole()).getRgbF()[:3]
|
||||
return rgb_to_hls(*bgcolor)[1] < 0.5
|
||||
|
||||
|
||||
def _pixmap_to_ndarray(pixmap):
|
||||
from qtpy.QtGui import QImage
|
||||
|
||||
img = pixmap.toImage()
|
||||
img = img.convertToFormat(QImage.Format.Format_RGBA8888)
|
||||
ptr = img.bits()
|
||||
count = img.height() * img.width() * 4
|
||||
if hasattr(ptr, "setsize"): # PyQt
|
||||
ptr.setsize(count)
|
||||
data = np.frombuffer(ptr, dtype=np.uint8, count=count).copy()
|
||||
data.shape = (img.height(), img.width(), 4)
|
||||
return data / 255.0
|
||||
|
||||
|
||||
def _notebook_vtk_works():
|
||||
if sys.platform != "linux":
|
||||
return True
|
||||
# check if it's OSMesa -- if it is, continue
|
||||
try:
|
||||
from vtkmodules import vtkRenderingOpenGL2
|
||||
|
||||
vtkRenderingOpenGL2.vtkOSOpenGLRenderWindow
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
return True # has vtkOSOpenGLRenderWindow (OSMesa build)
|
||||
|
||||
# if it's not OSMesa, we need to check display validity
|
||||
if _display_is_valid():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _qt_safe_window(
|
||||
*, splash="figure.splash", window="figure.plotter.app_window", always_close=True
|
||||
):
|
||||
def dec(meth, splash=splash, always_close=always_close):
|
||||
@functools.wraps(meth)
|
||||
def func(self, *args, **kwargs):
|
||||
close_splash = always_close
|
||||
error = False
|
||||
try:
|
||||
meth(self, *args, **kwargs)
|
||||
except Exception:
|
||||
close_splash = error = True
|
||||
raise
|
||||
finally:
|
||||
for attr, do_close in ((splash, close_splash), (window, error)):
|
||||
if attr is None or not do_close:
|
||||
continue
|
||||
parent = self
|
||||
name = attr.split(".")[-1]
|
||||
try:
|
||||
for n in attr.split(".")[:-1]:
|
||||
parent = getattr(parent, n)
|
||||
if name:
|
||||
widget = getattr(parent, name, False)
|
||||
else: # empty string means "self"
|
||||
widget = parent
|
||||
if widget:
|
||||
widget.close()
|
||||
del widget
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
delattr(parent, name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return func
|
||||
|
||||
return dec
|
||||
@@ -0,0 +1,587 @@
|
||||
"""Core visualization operations."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import importlib
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...utils import (
|
||||
_auto_weakref,
|
||||
_check_option,
|
||||
_validate_type,
|
||||
fill_doc,
|
||||
get_config,
|
||||
logger,
|
||||
verbose,
|
||||
)
|
||||
from .._3d import _get_3d_option
|
||||
from ..utils import safe_event
|
||||
from ._utils import VALID_3D_BACKENDS
|
||||
|
||||
MNE_3D_BACKEND = None
|
||||
MNE_3D_BACKEND_TESTING = False
|
||||
|
||||
|
||||
_backend_name_map = dict(
|
||||
pyvistaqt="._qt",
|
||||
notebook="._notebook",
|
||||
)
|
||||
backend = None
|
||||
|
||||
|
||||
def _reload_backend(backend_name):
|
||||
global backend
|
||||
backend = importlib.import_module(
|
||||
name=_backend_name_map[backend_name], package="mne.viz.backends"
|
||||
)
|
||||
logger.info(f"Using {backend_name} 3d backend.")
|
||||
|
||||
|
||||
def _get_backend():
|
||||
_get_3d_backend()
|
||||
return backend
|
||||
|
||||
|
||||
def _get_renderer(*args, **kwargs):
|
||||
_get_3d_backend()
|
||||
return backend._Renderer(*args, **kwargs)
|
||||
|
||||
|
||||
def _check_3d_backend_name(backend_name):
|
||||
_validate_type(backend_name, str, "backend_name")
|
||||
backend_name = "pyvistaqt" if backend_name == "pyvista" else backend_name
|
||||
_check_option("backend_name", backend_name, VALID_3D_BACKENDS)
|
||||
return backend_name
|
||||
|
||||
|
||||
@verbose
|
||||
def set_3d_backend(backend_name, verbose=None):
|
||||
"""Set the 3D backend for MNE.
|
||||
|
||||
The backend will be set as specified and operations will use
|
||||
that backend.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend_name : str
|
||||
The 3d backend to select. See Notes for the capabilities of each
|
||||
backend (``'pyvistaqt'`` and ``'notebook'``).
|
||||
|
||||
.. versionchanged:: 0.24
|
||||
The ``'pyvista'`` backend was renamed ``'pyvistaqt'``.
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
old_backend_name : str | None
|
||||
The old backend that was in use.
|
||||
|
||||
Notes
|
||||
-----
|
||||
To use PyVista, set ``backend_name`` to ``pyvistaqt`` but the value
|
||||
``pyvista`` is still supported for backward compatibility.
|
||||
|
||||
This table shows the capabilities of each backend ("✓" for full support,
|
||||
and "-" for partial support):
|
||||
|
||||
.. table::
|
||||
:widths: auto
|
||||
|
||||
+--------------------------------------+-----------+----------+
|
||||
| **3D function:** | pyvistaqt | notebook |
|
||||
+======================================+===========+==========+
|
||||
| :func:`plot_vector_source_estimates` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`plot_source_estimates` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`plot_alignment` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`plot_sparse_source_estimates` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`plot_evoked_field` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`snapshot_brain_montage` | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| :func:`link_brains` | ✓ | |
|
||||
+--------------------------------------+-----------+----------+
|
||||
+--------------------------------------+-----------+----------+
|
||||
| **Feature:** |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Large data | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Opacity/transparency | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Support geometric glyph | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Smooth shading | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Subplotting | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Inline plot in Jupyter Notebook | | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Inline plot in JupyterLab | | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Inline plot in Google Colab | | |
|
||||
+--------------------------------------+-----------+----------+
|
||||
| Toolbar | ✓ | ✓ |
|
||||
+--------------------------------------+-----------+----------+
|
||||
"""
|
||||
global MNE_3D_BACKEND
|
||||
old_backend_name = MNE_3D_BACKEND
|
||||
backend_name = _check_3d_backend_name(backend_name)
|
||||
if MNE_3D_BACKEND != backend_name:
|
||||
_reload_backend(backend_name)
|
||||
MNE_3D_BACKEND = backend_name
|
||||
return old_backend_name
|
||||
|
||||
|
||||
def get_3d_backend():
|
||||
"""Return the 3D backend currently used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
backend_used : str | None
|
||||
The 3d backend currently in use. If no backend is found,
|
||||
returns ``None``.
|
||||
|
||||
.. versionchanged:: 0.24
|
||||
The ``'pyvista'`` backend has been renamed ``'pyvistaqt'``, so
|
||||
``'pyvista'`` is no longer returned by this function.
|
||||
"""
|
||||
try:
|
||||
backend = _get_3d_backend()
|
||||
except RuntimeError as exc:
|
||||
backend = None
|
||||
logger.info(str(exc))
|
||||
return backend
|
||||
|
||||
|
||||
def _get_3d_backend():
|
||||
"""Load and return the current 3d backend."""
|
||||
global MNE_3D_BACKEND
|
||||
if MNE_3D_BACKEND is None:
|
||||
MNE_3D_BACKEND = get_config(key="MNE_3D_BACKEND", default=None)
|
||||
if MNE_3D_BACKEND is None: # try them in order
|
||||
errors = dict()
|
||||
for name in VALID_3D_BACKENDS:
|
||||
try:
|
||||
_reload_backend(name)
|
||||
except ImportError as exc:
|
||||
errors[name] = str(exc)
|
||||
else:
|
||||
MNE_3D_BACKEND = name
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Could not load any valid 3D backend\n"
|
||||
+ "\n".join(f"{key}: {val}" for key, val in errors.items())
|
||||
+ "\n".join(
|
||||
(
|
||||
"\n\n install pyvistaqt, using pip or conda:",
|
||||
"'pip install pyvistaqt'",
|
||||
"'conda install -c conda-forge pyvistaqt'",
|
||||
"\n or install ipywidgets, "
|
||||
+ "if using a notebook backend",
|
||||
"'pip install ipywidgets'",
|
||||
"'conda install -c conda-forge ipywidgets'",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
MNE_3D_BACKEND = _check_3d_backend_name(MNE_3D_BACKEND)
|
||||
_reload_backend(MNE_3D_BACKEND)
|
||||
MNE_3D_BACKEND = _check_3d_backend_name(MNE_3D_BACKEND)
|
||||
return MNE_3D_BACKEND
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_3d_backend(backend_name):
|
||||
"""Create a 3d visualization context using the designated backend.
|
||||
|
||||
See :func:`mne.viz.set_3d_backend` for more details on the available
|
||||
3d backends and their capabilities.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend_name : {'pyvistaqt', 'notebook'}
|
||||
The 3d backend to use in the context.
|
||||
"""
|
||||
old_backend = set_3d_backend(backend_name)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if old_backend is not None:
|
||||
try:
|
||||
set_3d_backend(old_backend)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _use_test_3d_backend(backend_name, interactive=False):
|
||||
"""Create a testing viz context.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend_name : str
|
||||
The 3d backend to use in the context.
|
||||
interactive : bool
|
||||
If True, ensure interactive elements are accessible.
|
||||
"""
|
||||
with _actors_invisible():
|
||||
with use_3d_backend(backend_name):
|
||||
with backend._testing_context(interactive):
|
||||
yield
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _actors_invisible():
|
||||
global MNE_3D_BACKEND_TESTING
|
||||
orig_testing = MNE_3D_BACKEND_TESTING
|
||||
MNE_3D_BACKEND_TESTING = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
MNE_3D_BACKEND_TESTING = orig_testing
|
||||
|
||||
|
||||
@fill_doc
|
||||
def set_3d_view(
|
||||
figure,
|
||||
azimuth=None,
|
||||
elevation=None,
|
||||
focalpoint=None,
|
||||
distance=None,
|
||||
roll=None,
|
||||
):
|
||||
"""Configure the view of the given scene.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
figure : object
|
||||
The scene which is modified.
|
||||
%(azimuth)s
|
||||
%(elevation)s
|
||||
%(focalpoint)s
|
||||
%(distance)s
|
||||
%(roll)s
|
||||
"""
|
||||
backend._set_3d_view(
|
||||
figure=figure,
|
||||
azimuth=azimuth,
|
||||
elevation=elevation,
|
||||
focalpoint=focalpoint,
|
||||
distance=distance,
|
||||
roll=roll,
|
||||
)
|
||||
|
||||
|
||||
@fill_doc
|
||||
def set_3d_title(figure, title, size=40, *, color="white", position="upper_left"):
|
||||
"""Configure the title of the given scene.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
figure : object
|
||||
The scene which is modified.
|
||||
title : str
|
||||
The title of the scene.
|
||||
size : int
|
||||
The size of the title.
|
||||
color : matplotlib color
|
||||
The color of the title.
|
||||
|
||||
.. versionadded:: 1.9
|
||||
position : str
|
||||
The position to use, e.g., "upper_left". See
|
||||
:meth:`pyvista.Plotter.add_text` for details.
|
||||
|
||||
.. versionadded:: 1.9
|
||||
|
||||
Returns
|
||||
-------
|
||||
text : object
|
||||
The text object returned by the given backend.
|
||||
|
||||
.. versionadded:: 1.0
|
||||
"""
|
||||
return backend._set_3d_title(
|
||||
figure=figure, title=title, size=size, color=color, position=position
|
||||
)
|
||||
|
||||
|
||||
def create_3d_figure(
|
||||
size,
|
||||
bgcolor=(0, 0, 0),
|
||||
smooth_shading=None,
|
||||
handle=None,
|
||||
*,
|
||||
scene=True,
|
||||
show=False,
|
||||
title="MNE 3D Figure",
|
||||
):
|
||||
"""Return an empty figure based on the current 3d backend.
|
||||
|
||||
.. warning:: Proceed with caution when the renderer object is
|
||||
returned (with ``scene=False``) because the _Renderer
|
||||
API is not necessarily stable enough for production,
|
||||
it's still actively in development.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
size : tuple
|
||||
The dimensions of the 3d figure (width, height).
|
||||
bgcolor : tuple
|
||||
The color of the background.
|
||||
smooth_shading : bool | None
|
||||
Whether to enable smooth shading. If ``None``, uses the config value
|
||||
``MNE_3D_OPTION_SMOOTH_SHADING``. Defaults to ``None``.
|
||||
handle : int | None
|
||||
The figure identifier.
|
||||
scene : bool
|
||||
If True (default), the returned object is the Figure3D. If False,
|
||||
an advanced, undocumented Renderer object is returned (the API is not
|
||||
stable or documented, so this is not recommended).
|
||||
show : bool
|
||||
If True, show the renderer immediately.
|
||||
|
||||
.. versionadded:: 1.0
|
||||
title : str
|
||||
The window title to use (if applicable).
|
||||
|
||||
.. versionadded:: 1.9
|
||||
|
||||
Returns
|
||||
-------
|
||||
figure : instance of Figure3D or ``Renderer``
|
||||
The requested empty figure or renderer, depending on ``scene``.
|
||||
"""
|
||||
_validate_type(smooth_shading, (bool, None), "smooth_shading")
|
||||
if smooth_shading is None:
|
||||
smooth_shading = _get_3d_option("smooth_shading")
|
||||
renderer = _get_renderer(
|
||||
fig=handle,
|
||||
size=size,
|
||||
bgcolor=bgcolor,
|
||||
smooth_shading=smooth_shading,
|
||||
show=show,
|
||||
name=title,
|
||||
)
|
||||
if scene:
|
||||
return renderer.scene()
|
||||
else:
|
||||
return renderer
|
||||
|
||||
|
||||
def close_3d_figure(figure):
|
||||
"""Close the given scene.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
figure : object
|
||||
The scene which needs to be closed.
|
||||
"""
|
||||
backend._close_3d_figure(figure)
|
||||
|
||||
|
||||
def close_all_3d_figures():
|
||||
"""Close all the scenes of the current 3d backend."""
|
||||
backend._close_all()
|
||||
|
||||
|
||||
def get_brain_class():
|
||||
"""Return the proper Brain class based on the current 3d backend.
|
||||
|
||||
Returns
|
||||
-------
|
||||
brain : object
|
||||
The Brain class corresponding to the current 3d backend.
|
||||
"""
|
||||
from ...viz._brain import Brain
|
||||
|
||||
return Brain
|
||||
|
||||
|
||||
class _TimeInteraction:
|
||||
"""Mixin enabling time interaction controls."""
|
||||
|
||||
def _enable_time_interaction(
|
||||
self,
|
||||
fig,
|
||||
current_time_func,
|
||||
times,
|
||||
init_playback_speed=0.01,
|
||||
playback_speed_range=(0.01, 0.1),
|
||||
):
|
||||
from ..ui_events import (
|
||||
PlaybackSpeed,
|
||||
TimeChange,
|
||||
publish,
|
||||
subscribe,
|
||||
)
|
||||
|
||||
self._fig = fig
|
||||
self._current_time_func = current_time_func
|
||||
self._times = times
|
||||
self._init_time = current_time_func()
|
||||
self._init_playback_speed = init_playback_speed
|
||||
|
||||
if not hasattr(self, "_dock"):
|
||||
self._dock_initialize()
|
||||
|
||||
if not hasattr(self, "_tool_bar") or self._tool_bar is None:
|
||||
self._tool_bar_initialize(name="Toolbar")
|
||||
|
||||
if not hasattr(self, "_widgets"):
|
||||
self._widgets = dict()
|
||||
|
||||
# Dock widgets
|
||||
@_auto_weakref
|
||||
def publish_time_change(time_index):
|
||||
publish(
|
||||
fig,
|
||||
TimeChange(time=np.interp(time_index, np.arange(len(times)), times)),
|
||||
)
|
||||
|
||||
layout = self._dock_add_group_box("")
|
||||
self._widgets["time_slider"] = self._dock_add_slider(
|
||||
name="Time (s)",
|
||||
value=np.interp(current_time_func(), times, np.arange(len(times))),
|
||||
rng=[0, len(times) - 1],
|
||||
double=True,
|
||||
callback=publish_time_change,
|
||||
compact=False,
|
||||
layout=layout,
|
||||
)
|
||||
hlayout = self._dock_add_layout(vertical=False)
|
||||
self._widgets["min_time"] = self._dock_add_label("-", layout=hlayout)
|
||||
self._dock_add_stretch(hlayout)
|
||||
self._widgets["current_time"] = self._dock_add_label(value="x", layout=hlayout)
|
||||
self._dock_add_stretch(hlayout)
|
||||
self._widgets["max_time"] = self._dock_add_label(value="+", layout=hlayout)
|
||||
self._layout_add_widget(layout, hlayout)
|
||||
|
||||
self._widgets["min_time"].set_value(f"{times[0]: .3f}")
|
||||
self._widgets["current_time"].set_value(f"{current_time_func(): .3f}")
|
||||
self._widgets["max_time"].set_value(f"{times[-1]: .3f}")
|
||||
|
||||
@_auto_weakref
|
||||
def publish_playback_speed(speed):
|
||||
publish(fig, PlaybackSpeed(speed=speed))
|
||||
|
||||
self._widgets["playback_speed"] = self._dock_add_spin_box(
|
||||
name="Speed",
|
||||
value=init_playback_speed,
|
||||
rng=playback_speed_range,
|
||||
callback=publish_playback_speed,
|
||||
layout=layout,
|
||||
)
|
||||
|
||||
# Tool bar buttons
|
||||
self._widgets["reset"] = self._tool_bar_add_button(
|
||||
name="reset", desc="Reset", func=self._reset_time
|
||||
)
|
||||
self._widgets["play"] = self._tool_bar_add_play_button(
|
||||
name="play",
|
||||
desc="Play/Pause",
|
||||
func=self._toggle_playback,
|
||||
shortcut=" ",
|
||||
)
|
||||
|
||||
# Configure playback
|
||||
self._playback = False
|
||||
self._playback_initialize(
|
||||
func=self._play,
|
||||
timeout=17,
|
||||
value=np.interp(current_time_func(), times, np.arange(len(times))),
|
||||
rng=[0, len(times) - 1],
|
||||
time_widget=self._widgets["time_slider"],
|
||||
play_widget=self._widgets["play"],
|
||||
)
|
||||
|
||||
# Keyboard shortcuts
|
||||
@_auto_weakref
|
||||
def shift_time(direction):
|
||||
amount = self._widgets["playback_speed"].get_value()
|
||||
publish(
|
||||
self._fig,
|
||||
TimeChange(time=self._current_time_func() + direction * amount),
|
||||
)
|
||||
|
||||
if self.plotter.iren is not None:
|
||||
self.plotter.add_key_event("n", partial(shift_time, direction=1))
|
||||
self.plotter.add_key_event("b", partial(shift_time, direction=-1))
|
||||
|
||||
# Subscribe to relevant UI events
|
||||
subscribe(fig, "time_change", self._on_time_change)
|
||||
subscribe(fig, "playback_speed", self._on_playback_speed)
|
||||
|
||||
def _on_time_change(self, event):
|
||||
"""Respond to time_change UI event."""
|
||||
from ..ui_events import disable_ui_events
|
||||
|
||||
new_time = np.clip(event.time, self._times[0], self._times[-1])
|
||||
new_time_idx = np.interp(new_time, self._times, np.arange(len(self._times)))
|
||||
|
||||
with disable_ui_events(self._fig):
|
||||
self._widgets["time_slider"].set_value(new_time_idx)
|
||||
self._widgets["current_time"].set_value(f"{new_time:.3f}")
|
||||
|
||||
def _on_playback_speed(self, event):
|
||||
"""Respond to playback_speed UI event."""
|
||||
from ..ui_events import disable_ui_events
|
||||
|
||||
with disable_ui_events(self._fig):
|
||||
self._widgets["playback_speed"].set_value(event.speed)
|
||||
|
||||
def _toggle_playback(self, value=None):
|
||||
"""Toggle time playback."""
|
||||
from ..ui_events import TimeChange, publish
|
||||
|
||||
if value is None:
|
||||
self._playback = not self._playback
|
||||
else:
|
||||
self._playback = value
|
||||
|
||||
if self._playback:
|
||||
self._tool_bar_update_button_icon(name="play", icon_name="pause")
|
||||
if self._current_time_func() == self._times[-1]: # start over
|
||||
publish(self._fig, TimeChange(time=self._times[0]))
|
||||
self._last_tick = time.time()
|
||||
else:
|
||||
self._tool_bar_update_button_icon(name="play", icon_name="play")
|
||||
|
||||
def _reset_time(self):
|
||||
"""Reset time and playback speed to initial values."""
|
||||
from ..ui_events import PlaybackSpeed, TimeChange, publish
|
||||
|
||||
publish(self._fig, TimeChange(time=self._init_time))
|
||||
publish(self._fig, PlaybackSpeed(speed=self._init_playback_speed))
|
||||
|
||||
@safe_event
|
||||
def _play(self):
|
||||
if self._playback:
|
||||
try:
|
||||
self._advance()
|
||||
except Exception:
|
||||
self._toggle_playback(value=False)
|
||||
raise
|
||||
|
||||
def _advance(self):
|
||||
from ..ui_events import TimeChange, publish
|
||||
|
||||
this_time = time.time()
|
||||
delta = this_time - self._last_tick
|
||||
self._last_tick = time.time()
|
||||
time_shift = delta * self._widgets["playback_speed"].get_value()
|
||||
new_time = min(self._current_time_func() + time_shift, self._times[-1])
|
||||
publish(self._fig, TimeChange(time=new_time))
|
||||
if new_time == self._times[-1]:
|
||||
self._toggle_playback(value=False)
|
||||
@@ -0,0 +1,469 @@
|
||||
"""Functions to plot on circle as for connectivity."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from functools import partial
|
||||
from itertools import cycle
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import _validate_type
|
||||
from .utils import _get_cmap, plt_show
|
||||
|
||||
|
||||
def circular_layout(
|
||||
node_names,
|
||||
node_order,
|
||||
start_pos=90,
|
||||
start_between=True,
|
||||
group_boundaries=None,
|
||||
group_sep=10,
|
||||
):
|
||||
"""Create layout arranging nodes on a circle.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
node_names : list of str
|
||||
Node names.
|
||||
node_order : list of str
|
||||
List with node names defining the order in which the nodes are
|
||||
arranged. Must have the elements as node_names but the order can be
|
||||
different. The nodes are arranged clockwise starting at "start_pos"
|
||||
degrees.
|
||||
start_pos : float
|
||||
Angle in degrees that defines where the first node is plotted.
|
||||
start_between : bool
|
||||
If True, the layout starts with the position between the nodes. This is
|
||||
the same as adding "180. / len(node_names)" to start_pos.
|
||||
group_boundaries : None | array-like
|
||||
List of of boundaries between groups at which point a "group_sep" will
|
||||
be inserted. E.g. "[0, len(node_names) / 2]" will create two groups.
|
||||
group_sep : float
|
||||
Group separation angle in degrees. See "group_boundaries".
|
||||
|
||||
Returns
|
||||
-------
|
||||
node_angles : array, shape=(n_node_names,)
|
||||
Node angles in degrees.
|
||||
"""
|
||||
n_nodes = len(node_names)
|
||||
|
||||
if len(node_order) != n_nodes:
|
||||
raise ValueError("node_order has to be the same length as node_names")
|
||||
|
||||
if group_boundaries is not None:
|
||||
boundaries = np.array(group_boundaries, dtype=np.int64)
|
||||
if np.any(boundaries >= n_nodes) or np.any(boundaries < 0):
|
||||
raise ValueError('"group_boundaries" has to be between 0 and n_nodes - 1.')
|
||||
if len(boundaries) > 1 and np.any(np.diff(boundaries) <= 0):
|
||||
raise ValueError('"group_boundaries" must have non-decreasing values.')
|
||||
n_group_sep = len(group_boundaries)
|
||||
else:
|
||||
n_group_sep = 0
|
||||
boundaries = None
|
||||
|
||||
# convert it to a list with indices
|
||||
node_order = [node_order.index(name) for name in node_names]
|
||||
node_order = np.array(node_order)
|
||||
if len(np.unique(node_order)) != n_nodes:
|
||||
raise ValueError("node_order has repeated entries")
|
||||
|
||||
node_sep = (360.0 - n_group_sep * group_sep) / n_nodes
|
||||
|
||||
if start_between:
|
||||
start_pos += node_sep / 2
|
||||
|
||||
if boundaries is not None and boundaries[0] == 0:
|
||||
# special case when a group separator is at the start
|
||||
start_pos += group_sep / 2
|
||||
boundaries = boundaries[1:] if n_group_sep > 1 else None
|
||||
|
||||
node_angles = np.ones(n_nodes, dtype=np.float64) * node_sep
|
||||
node_angles[0] = start_pos
|
||||
if boundaries is not None:
|
||||
node_angles[boundaries] += group_sep
|
||||
|
||||
node_angles = np.cumsum(node_angles)[node_order]
|
||||
|
||||
return node_angles
|
||||
|
||||
|
||||
def _plot_connectivity_circle_onpick(
|
||||
event,
|
||||
fig=None,
|
||||
ax=None,
|
||||
indices=None,
|
||||
n_nodes=0,
|
||||
node_angles=None,
|
||||
ylim=(9, 10),
|
||||
):
|
||||
"""Isolate connections around a single node when user left clicks a node.
|
||||
|
||||
On right click, resets all connections.
|
||||
"""
|
||||
if event.inaxes != ax:
|
||||
return
|
||||
|
||||
if event.button == 1: # left click
|
||||
# click must be near node radius
|
||||
if not ylim[0] <= event.ydata <= ylim[1]:
|
||||
return
|
||||
|
||||
# all angles in range [0, 2*pi]
|
||||
node_angles = node_angles % (np.pi * 2)
|
||||
node = np.argmin(np.abs(event.xdata - node_angles))
|
||||
|
||||
patches = event.inaxes.patches
|
||||
for ii, (x, y) in enumerate(zip(indices[0], indices[1])):
|
||||
patches[ii].set_visible(node in [x, y])
|
||||
fig.canvas.draw()
|
||||
elif event.button == 3: # right click
|
||||
patches = event.inaxes.patches
|
||||
for ii in range(np.size(indices, axis=1)):
|
||||
patches[ii].set_visible(True)
|
||||
fig.canvas.draw()
|
||||
|
||||
|
||||
def _plot_connectivity_circle(
|
||||
con,
|
||||
node_names,
|
||||
indices=None,
|
||||
n_lines=None,
|
||||
node_angles=None,
|
||||
node_width=None,
|
||||
node_height=None,
|
||||
node_colors=None,
|
||||
facecolor="black",
|
||||
textcolor="white",
|
||||
node_edgecolor="black",
|
||||
linewidth=1.5,
|
||||
colormap="hot",
|
||||
vmin=None,
|
||||
vmax=None,
|
||||
colorbar=True,
|
||||
title=None,
|
||||
colorbar_size=None,
|
||||
colorbar_pos=None,
|
||||
fontsize_title=12,
|
||||
fontsize_names=8,
|
||||
fontsize_colorbar=8,
|
||||
padding=6.0,
|
||||
ax=None,
|
||||
interactive=True,
|
||||
node_linewidth=2.0,
|
||||
show=True,
|
||||
):
|
||||
import matplotlib.patches as m_patches
|
||||
import matplotlib.path as m_path
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.projections.polar import PolarAxes
|
||||
|
||||
_validate_type(ax, (None, PolarAxes))
|
||||
|
||||
n_nodes = len(node_names)
|
||||
|
||||
if node_angles is not None:
|
||||
if len(node_angles) != n_nodes:
|
||||
raise ValueError("node_angles has to be the same length as node_names")
|
||||
# convert it to radians
|
||||
node_angles = node_angles * np.pi / 180
|
||||
else:
|
||||
# uniform layout on unit circle
|
||||
node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
|
||||
|
||||
if node_width is None:
|
||||
# widths correspond to the minimum angle between two nodes
|
||||
dist_mat = node_angles[None, :] - node_angles[:, None]
|
||||
dist_mat[np.diag_indices(n_nodes)] = 1e9
|
||||
node_width = np.min(np.abs(dist_mat))
|
||||
else:
|
||||
node_width = node_width * np.pi / 180
|
||||
|
||||
if node_height is None:
|
||||
node_height = 1.0
|
||||
|
||||
if node_colors is not None:
|
||||
if len(node_colors) < n_nodes:
|
||||
node_colors = cycle(node_colors)
|
||||
else:
|
||||
# assign colors using colormap
|
||||
try:
|
||||
spectral = plt.cm.spectral
|
||||
except AttributeError:
|
||||
spectral = plt.cm.Spectral
|
||||
node_colors = [spectral(i / float(n_nodes)) for i in range(n_nodes)]
|
||||
|
||||
# handle 1D and 2D connectivity information
|
||||
if con.ndim == 1:
|
||||
if indices is None:
|
||||
raise ValueError("indices has to be provided if con.ndim == 1")
|
||||
elif con.ndim == 2:
|
||||
if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
|
||||
raise ValueError("con has to be 1D or a square matrix")
|
||||
# we use the lower-triangular part
|
||||
indices = np.tril_indices(n_nodes, -1)
|
||||
con = con[indices]
|
||||
else:
|
||||
raise ValueError("con has to be 1D or a square matrix")
|
||||
|
||||
# get the colormap
|
||||
colormap = _get_cmap(colormap)
|
||||
|
||||
# Use a polar axes
|
||||
if ax is None:
|
||||
fig = plt.figure(figsize=(8, 8), facecolor=facecolor, layout="constrained")
|
||||
ax = fig.add_subplot(polar=True)
|
||||
else:
|
||||
fig = ax.figure
|
||||
ax.set_facecolor(facecolor)
|
||||
|
||||
# No ticks, we'll put our own
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# Set y axes limit, add additional space if requested
|
||||
ax.set_ylim(0, 10 + padding)
|
||||
|
||||
# Remove the black axes border which may obscure the labels
|
||||
ax.spines["polar"].set_visible(False)
|
||||
|
||||
# Draw lines between connected nodes, only draw the strongest connections
|
||||
if n_lines is not None and len(con) > n_lines:
|
||||
con_thresh = np.sort(np.abs(con).ravel())[-n_lines]
|
||||
else:
|
||||
con_thresh = 0.0
|
||||
|
||||
# get the connections which we are drawing and sort by connection strength
|
||||
# this will allow us to draw the strongest connections first
|
||||
con_abs = np.abs(con)
|
||||
con_draw_idx = np.where(con_abs >= con_thresh)[0]
|
||||
|
||||
con = con[con_draw_idx]
|
||||
con_abs = con_abs[con_draw_idx]
|
||||
indices = [ind[con_draw_idx] for ind in indices]
|
||||
|
||||
# now sort them
|
||||
sort_idx = np.argsort(con_abs)
|
||||
del con_abs
|
||||
con = con[sort_idx]
|
||||
indices = [ind[sort_idx] for ind in indices]
|
||||
|
||||
# Get vmin vmax for color scaling
|
||||
if vmin is None:
|
||||
vmin = np.min(con[np.abs(con) >= con_thresh])
|
||||
if vmax is None:
|
||||
vmax = np.max(con)
|
||||
vrange = vmax - vmin
|
||||
|
||||
# We want to add some "noise" to the start and end position of the
|
||||
# edges: We modulate the noise with the number of connections of the
|
||||
# node and the connection strength, such that the strongest connections
|
||||
# are closer to the node center
|
||||
nodes_n_con = np.zeros((n_nodes), dtype=np.int64)
|
||||
for i, j in zip(indices[0], indices[1]):
|
||||
nodes_n_con[i] += 1
|
||||
nodes_n_con[j] += 1
|
||||
|
||||
# initialize random number generator so plot is reproducible
|
||||
rng = np.random.mtrand.RandomState(0)
|
||||
|
||||
n_con = len(indices[0])
|
||||
noise_max = 0.25 * node_width
|
||||
start_noise = rng.uniform(-noise_max, noise_max, n_con)
|
||||
end_noise = rng.uniform(-noise_max, noise_max, n_con)
|
||||
|
||||
nodes_n_con_seen = np.zeros_like(nodes_n_con)
|
||||
for i, (start, end) in enumerate(zip(indices[0], indices[1])):
|
||||
nodes_n_con_seen[start] += 1
|
||||
nodes_n_con_seen[end] += 1
|
||||
|
||||
start_noise[i] *= (nodes_n_con[start] - nodes_n_con_seen[start]) / float(
|
||||
nodes_n_con[start]
|
||||
)
|
||||
end_noise[i] *= (nodes_n_con[end] - nodes_n_con_seen[end]) / float(
|
||||
nodes_n_con[end]
|
||||
)
|
||||
|
||||
# scale connectivity for colormap (vmin<=>0, vmax<=>1)
|
||||
con_val_scaled = (con - vmin) / vrange
|
||||
|
||||
# Finally, we draw the connections
|
||||
for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
|
||||
# Start point
|
||||
t0, r0 = node_angles[i], 10
|
||||
|
||||
# End point
|
||||
t1, r1 = node_angles[j], 10
|
||||
|
||||
# Some noise in start and end point
|
||||
t0 += start_noise[pos]
|
||||
t1 += end_noise[pos]
|
||||
|
||||
verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
|
||||
codes = [
|
||||
m_path.Path.MOVETO,
|
||||
m_path.Path.CURVE4,
|
||||
m_path.Path.CURVE4,
|
||||
m_path.Path.LINETO,
|
||||
]
|
||||
path = m_path.Path(verts, codes)
|
||||
|
||||
color = colormap(con_val_scaled[pos])
|
||||
|
||||
# Actual line
|
||||
patch = m_patches.PathPatch(
|
||||
path, fill=False, edgecolor=color, linewidth=linewidth, alpha=1.0
|
||||
)
|
||||
ax.add_patch(patch)
|
||||
|
||||
# Draw ring with colored nodes
|
||||
height = np.ones(n_nodes) * node_height
|
||||
bars = ax.bar(
|
||||
node_angles,
|
||||
height,
|
||||
width=node_width,
|
||||
bottom=9,
|
||||
edgecolor=node_edgecolor,
|
||||
lw=node_linewidth,
|
||||
facecolor=".9",
|
||||
align="center",
|
||||
)
|
||||
|
||||
for bar, color in zip(bars, node_colors):
|
||||
bar.set_facecolor(color)
|
||||
|
||||
# Draw node labels
|
||||
angles_deg = 180 * node_angles / np.pi
|
||||
for name, angle_rad, angle_deg in zip(node_names, node_angles, angles_deg):
|
||||
if angle_deg >= 270:
|
||||
ha = "left"
|
||||
else:
|
||||
# Flip the label, so text is always upright
|
||||
angle_deg += 180
|
||||
ha = "right"
|
||||
|
||||
ax.text(
|
||||
angle_rad,
|
||||
9.4 + node_height,
|
||||
name,
|
||||
size=fontsize_names,
|
||||
rotation=angle_deg,
|
||||
rotation_mode="anchor",
|
||||
horizontalalignment=ha,
|
||||
verticalalignment="center",
|
||||
color=textcolor,
|
||||
)
|
||||
|
||||
if title is not None:
|
||||
ax.set_title(title, color=textcolor, fontsize=fontsize_title)
|
||||
|
||||
if colorbar:
|
||||
sm = plt.cm.ScalarMappable(cmap=colormap, norm=plt.Normalize(vmin, vmax))
|
||||
sm.set_array(np.linspace(vmin, vmax))
|
||||
colorbar_kwargs = dict()
|
||||
if colorbar_size is not None:
|
||||
colorbar_kwargs.update(shrink=colorbar_size)
|
||||
if colorbar_pos is not None:
|
||||
colorbar_kwargs.update(anchor=colorbar_pos)
|
||||
cb = fig.colorbar(sm, ax=ax, **colorbar_kwargs)
|
||||
cb_yticks = plt.getp(cb.ax.axes, "yticklabels")
|
||||
cb.ax.tick_params(labelsize=fontsize_colorbar)
|
||||
plt.setp(cb_yticks, color=textcolor)
|
||||
|
||||
# Add callback for interaction
|
||||
if interactive:
|
||||
callback = partial(
|
||||
_plot_connectivity_circle_onpick,
|
||||
fig=fig,
|
||||
ax=ax,
|
||||
indices=indices,
|
||||
n_nodes=n_nodes,
|
||||
node_angles=node_angles,
|
||||
)
|
||||
|
||||
fig.canvas.mpl_connect("button_press_event", callback)
|
||||
|
||||
plt_show(show)
|
||||
return fig, ax
|
||||
|
||||
|
||||
def plot_channel_labels_circle(labels, colors=None, picks=None, **kwargs):
|
||||
"""Plot labels for each channel in a circle plot.
|
||||
|
||||
.. note:: This primarily makes sense for sEEG channels where each
|
||||
channel can be assigned an anatomical label as the electrode
|
||||
passes through various brain areas.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
labels : dict
|
||||
Lists of labels (values) associated with each channel (keys).
|
||||
colors : dict
|
||||
The color (value) for each label (key).
|
||||
picks : list | tuple
|
||||
The channels to consider.
|
||||
**kwargs : kwargs
|
||||
Keyword arguments for
|
||||
:func:`mne_connectivity.viz.plot_connectivity_circle`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib.figure.Figure
|
||||
The figure handle.
|
||||
axes : instance of matplotlib.projections.polar.PolarAxes
|
||||
The subplot handle.
|
||||
"""
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
|
||||
_validate_type(labels, dict, "labels")
|
||||
_validate_type(colors, (dict, None), "colors")
|
||||
_validate_type(picks, (list, tuple, None), "picks")
|
||||
if picks is not None:
|
||||
labels = {k: v for k, v in labels.items() if k in picks}
|
||||
ch_names = list(labels.keys())
|
||||
all_labels = list(set([label for val in labels.values() for label in val]))
|
||||
n_labels = len(all_labels)
|
||||
if colors is not None:
|
||||
for label in all_labels:
|
||||
if label not in colors:
|
||||
raise ValueError(f"No color provided for {label} in `colors`")
|
||||
# update all_labels, there may be unconnected labels in colors
|
||||
all_labels = list(colors.keys())
|
||||
n_labels = len(all_labels)
|
||||
# make colormap
|
||||
label_colors = [colors[label] for label in all_labels]
|
||||
node_colors = ["black"] * len(ch_names) + label_colors
|
||||
label_cmap = LinearSegmentedColormap.from_list(
|
||||
"label_cmap", label_colors, N=len(label_colors)
|
||||
)
|
||||
else:
|
||||
node_colors = None
|
||||
|
||||
node_names = ch_names + all_labels
|
||||
con = np.zeros((len(node_names), len(node_names))) * np.nan
|
||||
for idx, ch_name in enumerate(ch_names):
|
||||
for label in labels[ch_name]:
|
||||
node_idx = node_names.index(label)
|
||||
label_color = all_labels.index(label) / n_labels
|
||||
con[idx, node_idx] = con[node_idx, idx] = label_color # symmetric
|
||||
# plot
|
||||
node_order = ch_names + all_labels[::-1]
|
||||
node_angles = circular_layout(
|
||||
node_names, node_order, start_pos=90, group_boundaries=[0, len(ch_names)]
|
||||
)
|
||||
# provide defaults but don't overwrite
|
||||
if "node_angles" not in kwargs:
|
||||
kwargs.update(node_angles=node_angles)
|
||||
if "colorbar" not in kwargs:
|
||||
kwargs.update(colorbar=False)
|
||||
if "node_colors" not in kwargs:
|
||||
kwargs.update(node_colors=node_colors)
|
||||
if "vmin" not in kwargs:
|
||||
kwargs.update(vmin=0)
|
||||
if "vmax" not in kwargs:
|
||||
kwargs.update(vmax=1)
|
||||
if "colormap" not in kwargs:
|
||||
kwargs.update(colormap=label_cmap)
|
||||
return _plot_connectivity_circle(con, node_names, **kwargs)
|
||||
@@ -0,0 +1,47 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import os.path as op
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mne import Epochs, EvokedArray, create_info, events_from_annotations
|
||||
from mne.channels import make_standard_montage
|
||||
from mne.datasets.testing import _pytest_param, data_path
|
||||
from mne.io import read_raw_nirx
|
||||
from mne.preprocessing.nirs import beer_lambert_law, optical_density
|
||||
|
||||
fname_nirx = op.join(
|
||||
data_path(download=False), "NIRx", "nirscout", "nirx_15_2_recording_w_overlap"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fnirs_evoked():
|
||||
"""Create an fnirs evoked structure."""
|
||||
montage = make_standard_montage("biosemi16")
|
||||
ch_names = montage.ch_names
|
||||
ch_types = ["eeg"] * 16
|
||||
info = create_info(ch_names=ch_names, sfreq=20, ch_types=ch_types)
|
||||
evoked_data = np.random.randn(16, 30)
|
||||
evoked = EvokedArray(evoked_data, info=info, tmin=-0.2, nave=4)
|
||||
evoked.set_montage(montage)
|
||||
evoked.set_channel_types(
|
||||
{"Fp1": "hbo", "Fp2": "hbo", "F4": "hbo", "Fz": "hbo"}, verbose="error"
|
||||
)
|
||||
return evoked
|
||||
|
||||
|
||||
@pytest.fixture(params=[_pytest_param()])
|
||||
def fnirs_epochs():
|
||||
"""Create an fnirs epoch structure."""
|
||||
raw_intensity = read_raw_nirx(fname_nirx, preload=False)
|
||||
raw_od = optical_density(raw_intensity)
|
||||
raw_haemo = beer_lambert_law(raw_od, ppf=6.0)
|
||||
evts, _ = events_from_annotations(raw_haemo, event_id={"1.0": 1})
|
||||
evts_dct = {"A": 1}
|
||||
tn, tx = -1, 2
|
||||
epochs = Epochs(raw_haemo, evts, event_id=evts_dct, tmin=tn, tmax=tx)
|
||||
return epochs
|
||||
+1160
File diff suppressed because it is too large
Load Diff
+3207
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,577 @@
|
||||
"""Class to draw evoked MEG and EEG fieldlines, with a GUI to control the figure.
|
||||
|
||||
author: Marijn van Vliet <w.m.vanvliet@gmail.com>
|
||||
"""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from scipy.interpolate import interp1d
|
||||
|
||||
from .._fiff.pick import pick_types
|
||||
from ..defaults import DEFAULTS
|
||||
from ..utils import (
|
||||
_auto_weakref,
|
||||
_check_option,
|
||||
_ensure_int,
|
||||
_to_rgb,
|
||||
_validate_type,
|
||||
fill_doc,
|
||||
)
|
||||
from ._3d_overlay import _LayeredMesh
|
||||
from .ui_events import (
|
||||
ColormapRange,
|
||||
Contours,
|
||||
TimeChange,
|
||||
disable_ui_events,
|
||||
publish,
|
||||
subscribe,
|
||||
)
|
||||
from .utils import mne_analyze_colormap
|
||||
|
||||
|
||||
@fill_doc
|
||||
class EvokedField:
|
||||
"""Plot MEG/EEG fields on head surface and helmet in 3D.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
evoked : instance of mne.Evoked
|
||||
The evoked object.
|
||||
surf_maps : list
|
||||
The surface mapping information obtained with make_field_map.
|
||||
time : float | None
|
||||
The time point at which the field map shall be displayed. If None,
|
||||
the average peak latency (across sensor types) is used.
|
||||
time_label : str | None
|
||||
How to print info about the time instant visualized.
|
||||
%(n_jobs)s
|
||||
fig : instance of Figure3D | None
|
||||
If None (default), a new figure will be created, otherwise it will
|
||||
plot into the given figure.
|
||||
|
||||
.. versionadded:: 0.20
|
||||
vmax : float | dict | None
|
||||
Maximum intensity. Can be a dictionary with two entries ``"eeg"`` and ``"meg"``
|
||||
to specify separate values for EEG and MEG fields respectively. Can be
|
||||
``None`` to use the maximum value of the data.
|
||||
|
||||
.. versionadded:: 0.21
|
||||
.. versionadded:: 1.4
|
||||
``vmax`` can be a dictionary to specify separate values for EEG and
|
||||
MEG fields.
|
||||
n_contours : int
|
||||
The number of contours.
|
||||
|
||||
.. versionadded:: 0.21
|
||||
show_density : bool
|
||||
Whether to draw the field density as an overlay on top of the helmet/head
|
||||
surface. Defaults to ``True``.
|
||||
alpha : float | dict | None
|
||||
Opacity of the meshes (between 0 and 1). Can be a dictionary with two
|
||||
entries ``"eeg"`` and ``"meg"`` to specify separate values for EEG and
|
||||
MEG fields respectively. Can be ``None`` to use 1.0 when a single field
|
||||
map is shown, or ``dict(eeg=1.0, meg=0.5)`` when both field maps are shown.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
%(interpolation_brain_time)s
|
||||
|
||||
.. versionadded:: 1.6
|
||||
%(interaction_scene)s
|
||||
Defaults to ``'terrain'``.
|
||||
|
||||
.. versionadded:: 1.1
|
||||
time_viewer : bool | str
|
||||
Display time viewer GUI. Can also be ``"auto"``, which will mean
|
||||
``True`` if there is more than one time point and ``False`` otherwise.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
%(verbose)s
|
||||
|
||||
Notes
|
||||
-----
|
||||
The figure will publish and subscribe to the following UI events:
|
||||
|
||||
* :class:`~mne.viz.ui_events.TimeChange`
|
||||
* :class:`~mne.viz.ui_events.Contours`, ``kind="field_strength_meg" | "field_strength_eeg"``
|
||||
* :class:`~mne.viz.ui_events.ColormapRange`, ``kind="field_strength_meg" | "field_strength_eeg"``
|
||||
""" # noqa
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
evoked,
|
||||
surf_maps,
|
||||
*,
|
||||
time=None,
|
||||
time_label="t = %0.0f ms",
|
||||
n_jobs=None,
|
||||
fig=None,
|
||||
vmax=None,
|
||||
n_contours=21,
|
||||
show_density=True,
|
||||
alpha=None,
|
||||
interpolation="nearest",
|
||||
interaction="terrain",
|
||||
time_viewer="auto",
|
||||
verbose=None,
|
||||
):
|
||||
from .backends.renderer import _get_3d_backend, _get_renderer
|
||||
|
||||
# Setup figure parameters
|
||||
self._evoked = evoked
|
||||
if time is None:
|
||||
types = [t for t in ["eeg", "grad", "mag"] if t in evoked]
|
||||
time = np.mean([evoked.get_peak(ch_type=t)[1] for t in types])
|
||||
self._current_time = time
|
||||
if not evoked.times[0] <= time <= evoked.times[-1]:
|
||||
raise ValueError(f"`time` ({time:0.3f}) must be inside `evoked.times`")
|
||||
self._time_label = time_label
|
||||
|
||||
self._vmax = _validate_type(vmax, (None, "numeric", dict), "vmax")
|
||||
self._n_contours = _ensure_int(n_contours, "n_contours")
|
||||
self._time_interpolation = _check_option(
|
||||
"interpolation",
|
||||
interpolation,
|
||||
("linear", "nearest", "zero", "slinear", "quadratic", "cubic"),
|
||||
)
|
||||
self._interaction = _check_option(
|
||||
"interaction", interaction, ["trackball", "terrain"]
|
||||
)
|
||||
|
||||
surf_map_kinds = [surf_map["kind"] for surf_map in surf_maps]
|
||||
if vmax is None:
|
||||
self._vmax = {kind: None for kind in surf_map_kinds}
|
||||
elif isinstance(vmax, dict):
|
||||
for kind in surf_map_kinds:
|
||||
if kind not in vmax:
|
||||
raise ValueError(
|
||||
f'No entry for "{kind}" found in the vmax dictionary'
|
||||
)
|
||||
self._vmax = vmax
|
||||
else: # float value
|
||||
self._vmax = {kind: vmax for kind in surf_map_kinds}
|
||||
|
||||
if alpha is None:
|
||||
self._alpha = {
|
||||
surf_map["kind"]: val for surf_map, val in zip(surf_maps, [1.0, 0.5])
|
||||
}
|
||||
elif isinstance(alpha, dict):
|
||||
for kind in surf_map_kinds:
|
||||
if kind not in alpha:
|
||||
raise ValueError(
|
||||
f'No entry for "{kind}" found in the alpha dictionary'
|
||||
)
|
||||
self._alpha = alpha
|
||||
else: # float value
|
||||
self._alpha = {kind: alpha for kind in surf_map_kinds}
|
||||
|
||||
self._colors = [(0.6, 0.6, 0.6), (1.0, 1.0, 1.0)]
|
||||
self._colormap = mne_analyze_colormap(format="vtk")
|
||||
self._colormap_lines = np.concatenate(
|
||||
[
|
||||
np.tile([0.0, 0.0, 255.0, 255.0], (127, 1)),
|
||||
np.tile([0.0, 0.0, 0.0, 255.0], (2, 1)),
|
||||
np.tile([255.0, 0.0, 0.0, 255.0], (127, 1)),
|
||||
]
|
||||
)
|
||||
self._show_density = show_density
|
||||
|
||||
from ._brain import Brain
|
||||
|
||||
if isinstance(fig, Brain):
|
||||
self._renderer = fig._renderer
|
||||
self._in_brain_figure = True
|
||||
if _get_3d_backend() == "notebook":
|
||||
raise NotImplementedError(
|
||||
"Plotting on top of an existing Brain figure "
|
||||
"is currently not supported inside a notebook."
|
||||
)
|
||||
else:
|
||||
self._renderer = _get_renderer(
|
||||
fig, bgcolor=(0.0, 0.0, 0.0), size=(600, 600)
|
||||
)
|
||||
self._in_brain_figure = False
|
||||
|
||||
self.plotter = self._renderer.plotter
|
||||
self.interaction = interaction
|
||||
|
||||
# Prepare the surface maps
|
||||
self._surf_maps = [
|
||||
self._prepare_surf_map(surf_map, color, self._alpha[surf_map["kind"]])
|
||||
for surf_map, color in zip(surf_maps, self._colors)
|
||||
]
|
||||
|
||||
# Do we want the time viewer?
|
||||
if time_viewer == "auto":
|
||||
time_viewer = len(evoked.times) > 1
|
||||
self.time_viewer = time_viewer
|
||||
|
||||
# Configure UI events
|
||||
@_auto_weakref
|
||||
def current_time_func():
|
||||
return self._current_time
|
||||
|
||||
self._widgets = dict()
|
||||
if self.time_viewer:
|
||||
# Draw widgets only if not inside a figure that already has them.
|
||||
if (
|
||||
not hasattr(self._renderer, "_widgets")
|
||||
or "time_slider" not in self._renderer._widgets
|
||||
):
|
||||
self._renderer._enable_time_interaction(
|
||||
self,
|
||||
current_time_func=current_time_func,
|
||||
times=evoked.times,
|
||||
)
|
||||
if not self._in_brain_figure or "time_slider" not in fig.widgets:
|
||||
# Draw the time label
|
||||
self._time_label = time_label
|
||||
if time_label is not None:
|
||||
if "%" in time_label:
|
||||
time_label = time_label % np.round(1e3 * time)
|
||||
self._time_label_actor = self._renderer.text2d(
|
||||
x_window=0.01, y_window=0.01, text=time_label
|
||||
)
|
||||
self._configure_dock()
|
||||
|
||||
subscribe(self, "time_change", self._on_time_change)
|
||||
subscribe(self, "colormap_range", self._on_colormap_range)
|
||||
subscribe(self, "contours", self._on_contours)
|
||||
|
||||
if not self._in_brain_figure:
|
||||
self._renderer.set_interaction(interaction)
|
||||
self._renderer.set_camera(azimuth=10, elevation=60, distance="auto")
|
||||
self._renderer.show()
|
||||
|
||||
def _prepare_surf_map(self, surf_map, color, alpha):
|
||||
"""Compute all the data required to render a fieldlines map."""
|
||||
if surf_map["kind"] == "eeg":
|
||||
pick = pick_types(self._evoked.info, meg=False, eeg=True)
|
||||
else:
|
||||
pick = pick_types(self._evoked.info, meg=True, eeg=False, ref_meg=False)
|
||||
|
||||
evoked_ch_names = set([self._evoked.ch_names[k] for k in pick])
|
||||
map_ch_names = set(surf_map["ch_names"])
|
||||
if evoked_ch_names != map_ch_names:
|
||||
message = ["Channels in map and data do not match."]
|
||||
diff = map_ch_names - evoked_ch_names
|
||||
if len(diff):
|
||||
message += [f"{list(diff)} not in data file. "]
|
||||
diff = evoked_ch_names - map_ch_names
|
||||
if len(diff):
|
||||
message += [f"{list(diff)} not in map file."]
|
||||
raise RuntimeError(" ".join(message))
|
||||
|
||||
data = surf_map["data"] @ self._evoked.data[pick]
|
||||
data_interp = interp1d(
|
||||
self._evoked.times,
|
||||
data,
|
||||
kind=self._time_interpolation,
|
||||
assume_sorted=True,
|
||||
)
|
||||
current_data = data_interp(self._current_time)
|
||||
|
||||
# Make a solid surface
|
||||
surf = surf_map["surf"]
|
||||
if self._in_brain_figure:
|
||||
surf["rr"] *= 1000
|
||||
map_vmax = self._vmax.get(surf_map["kind"])
|
||||
if map_vmax is None:
|
||||
map_vmax = float(np.max(current_data))
|
||||
mesh = _LayeredMesh(
|
||||
renderer=self._renderer,
|
||||
vertices=surf["rr"],
|
||||
triangles=surf["tris"],
|
||||
normals=surf["nn"],
|
||||
)
|
||||
mesh.map()
|
||||
color = _to_rgb(color, alpha=True)
|
||||
cmap = np.array([(0, 0, 0, 0), color])
|
||||
ctable = np.round(cmap * 255).astype(np.uint8)
|
||||
mesh.add_overlay(
|
||||
scalars=np.ones(len(current_data)),
|
||||
colormap=ctable,
|
||||
rng=[0, 1],
|
||||
opacity=alpha,
|
||||
name="surf",
|
||||
)
|
||||
|
||||
# Show the field density
|
||||
if self._show_density:
|
||||
mesh.add_overlay(
|
||||
scalars=current_data,
|
||||
colormap=self._colormap,
|
||||
rng=[-map_vmax, map_vmax],
|
||||
opacity=1.0,
|
||||
name="field",
|
||||
)
|
||||
|
||||
# And the field lines on top
|
||||
if self._n_contours > 1:
|
||||
contours = np.linspace(-map_vmax, map_vmax, self._n_contours)
|
||||
contours_actor, _ = self._renderer.contour(
|
||||
surface=surf,
|
||||
scalars=current_data,
|
||||
contours=contours,
|
||||
vmin=-map_vmax,
|
||||
vmax=map_vmax,
|
||||
colormap=self._colormap_lines,
|
||||
)
|
||||
else:
|
||||
contours = None # noqa
|
||||
contours_actor = None
|
||||
|
||||
return dict(
|
||||
pick=pick,
|
||||
data=data,
|
||||
data_interp=data_interp,
|
||||
map_kind=surf_map["kind"],
|
||||
mesh=mesh,
|
||||
contours=contours,
|
||||
contours_actor=contours_actor,
|
||||
surf=surf,
|
||||
map_vmax=map_vmax,
|
||||
)
|
||||
|
||||
def _update(self):
|
||||
"""Update the figure to reflect the current settings."""
|
||||
for surf_map in self._surf_maps:
|
||||
current_data = surf_map["data_interp"](self._current_time)
|
||||
surf_map["mesh"].update_overlay(name="field", scalars=current_data)
|
||||
|
||||
if surf_map["contours"] is not None:
|
||||
self._renderer.plotter.remove_actor(
|
||||
surf_map["contours_actor"], render=False
|
||||
)
|
||||
if self._n_contours > 1:
|
||||
surf_map["contours_actor"], _ = self._renderer.contour(
|
||||
surface=surf_map["surf"],
|
||||
scalars=current_data,
|
||||
contours=surf_map["contours"],
|
||||
vmin=-surf_map["map_vmax"],
|
||||
vmax=surf_map["map_vmax"],
|
||||
colormap=self._colormap_lines,
|
||||
)
|
||||
if self._time_label is not None:
|
||||
if hasattr(self, "_time_label_actor"):
|
||||
self._renderer.plotter.remove_actor(
|
||||
self._time_label_actor, render=False
|
||||
)
|
||||
time_label = self._time_label
|
||||
if "%" in self._time_label:
|
||||
time_label = self._time_label % np.round(1e3 * self._current_time)
|
||||
self._time_label_actor = self._renderer.text2d(
|
||||
x_window=0.01, y_window=0.01, text=time_label
|
||||
)
|
||||
|
||||
self._renderer.plotter.update()
|
||||
|
||||
def _configure_dock(self):
|
||||
"""Configure the widgets shown in the dock on the left."""
|
||||
r = self._renderer
|
||||
|
||||
if not hasattr(r, "_dock"):
|
||||
r._dock_initialize()
|
||||
|
||||
# Fieldline configuration
|
||||
layout = r._dock_add_group_box("Fieldlines")
|
||||
|
||||
if self._show_density:
|
||||
r._dock_add_label(value="max value", align=True, layout=layout)
|
||||
|
||||
@_auto_weakref
|
||||
def _callback(vmax, kind, scaling):
|
||||
self.set_vmax(vmax / scaling, kind=kind)
|
||||
|
||||
for surf_map in self._surf_maps:
|
||||
if surf_map["map_kind"] == "meg":
|
||||
scaling = DEFAULTS["scalings"]["grad"]
|
||||
else:
|
||||
scaling = DEFAULTS["scalings"]["eeg"]
|
||||
rng = [0, np.max(np.abs(surf_map["data"])) * scaling]
|
||||
hlayout = r._dock_add_layout(vertical=False)
|
||||
|
||||
self._widgets[f"vmax_slider_{surf_map['map_kind']}"] = (
|
||||
r._dock_add_slider(
|
||||
name=surf_map["map_kind"].upper(),
|
||||
value=surf_map["map_vmax"] * scaling,
|
||||
rng=rng,
|
||||
callback=partial(
|
||||
_callback, kind=surf_map["map_kind"], scaling=scaling
|
||||
),
|
||||
double=True,
|
||||
layout=hlayout,
|
||||
)
|
||||
)
|
||||
self._widgets[f"vmax_spin_{surf_map['map_kind']}"] = (
|
||||
r._dock_add_spin_box(
|
||||
name="",
|
||||
value=surf_map["map_vmax"] * scaling,
|
||||
rng=rng,
|
||||
callback=partial(
|
||||
_callback, kind=surf_map["map_kind"], scaling=scaling
|
||||
),
|
||||
layout=hlayout,
|
||||
)
|
||||
)
|
||||
r._layout_add_widget(layout, hlayout)
|
||||
|
||||
hlayout = r._dock_add_layout(vertical=False)
|
||||
r._dock_add_label(
|
||||
value="Rescale",
|
||||
align=True,
|
||||
layout=hlayout,
|
||||
)
|
||||
r._dock_add_button(
|
||||
name="↺",
|
||||
callback=self._rescale,
|
||||
layout=hlayout,
|
||||
style="toolbutton",
|
||||
)
|
||||
r._layout_add_widget(layout, hlayout)
|
||||
|
||||
self._widgets["contours"] = r._dock_add_spin_box(
|
||||
name="Contour lines",
|
||||
value=21,
|
||||
rng=[0, 99],
|
||||
step=1,
|
||||
double=False,
|
||||
callback=self.set_contours,
|
||||
layout=layout,
|
||||
)
|
||||
r._dock_finalize()
|
||||
|
||||
def _on_time_change(self, event):
|
||||
"""Respond to time_change UI event."""
|
||||
new_time = np.clip(event.time, self._evoked.times[0], self._evoked.times[-1])
|
||||
if new_time == self._current_time:
|
||||
return
|
||||
self._current_time = new_time
|
||||
self._update()
|
||||
|
||||
def _on_colormap_range(self, event):
|
||||
"""Response to the colormap_range UI event."""
|
||||
if event.kind == "field_strength_meg":
|
||||
kind = "meg"
|
||||
elif event.kind == "field_strength_eeg":
|
||||
kind = "eeg"
|
||||
else:
|
||||
return
|
||||
|
||||
for surf_map in self._surf_maps:
|
||||
if surf_map["map_kind"] == kind:
|
||||
break
|
||||
else:
|
||||
# No field map currently shown of the requested type.
|
||||
return
|
||||
|
||||
vmin = event.fmin
|
||||
vmax = event.fmax
|
||||
surf_map["contours"] = np.linspace(vmin, vmax, self._n_contours)
|
||||
|
||||
if self._show_density:
|
||||
surf_map["mesh"].update_overlay(name="field", rng=[vmin, vmax])
|
||||
# Update the GUI widgets
|
||||
if kind == "meg":
|
||||
scaling = DEFAULTS["scalings"]["grad"]
|
||||
else:
|
||||
scaling = DEFAULTS["scalings"]["eeg"]
|
||||
with disable_ui_events(self):
|
||||
widget = self._widgets.get(f"vmax_slider_{kind}", None)
|
||||
if widget is not None:
|
||||
widget.set_value(vmax * scaling)
|
||||
widget = self._widgets.get(f"vmax_spin_{kind}", None)
|
||||
if widget is not None:
|
||||
widget.set_value(vmax * scaling)
|
||||
|
||||
self._update()
|
||||
|
||||
def _on_contours(self, event):
|
||||
"""Respond to the contours UI event."""
|
||||
if event.kind == "field_strength_meg":
|
||||
kind = "meg"
|
||||
elif event.kind == "field_strength_eeg":
|
||||
kind = "eeg"
|
||||
else:
|
||||
return
|
||||
|
||||
for surf_map in self._surf_maps:
|
||||
if surf_map["map_kind"] == kind:
|
||||
break
|
||||
surf_map["contours"] = event.contours
|
||||
self._n_contours = len(event.contours)
|
||||
with disable_ui_events(self):
|
||||
if "contours" in self._widgets:
|
||||
self._widgets["contours"].set_value(len(event.contours))
|
||||
self._update()
|
||||
|
||||
def set_time(self, time):
|
||||
"""Set the time to display (in seconds).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time : float
|
||||
The time to show, in seconds.
|
||||
"""
|
||||
if self._evoked.times[0] <= time <= self._evoked.times[-1]:
|
||||
publish(self, TimeChange(time=time))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Requested time ({time} s) is outside the range of "
|
||||
f"available times ({self._evoked.times[0]}-{self._evoked.times[-1]} s)."
|
||||
)
|
||||
|
||||
def set_contours(self, n_contours):
|
||||
"""Adjust the number of contour lines to use when drawing the fieldlines.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_contours : int
|
||||
The number of contour lines to use.
|
||||
"""
|
||||
for surf_map in self._surf_maps:
|
||||
publish(
|
||||
self,
|
||||
Contours(
|
||||
kind=f"field_strength_{surf_map['map_kind']}",
|
||||
contours=np.linspace(
|
||||
-surf_map["map_vmax"], surf_map["map_vmax"], n_contours
|
||||
).tolist(),
|
||||
),
|
||||
)
|
||||
|
||||
def set_vmax(self, vmax, kind="meg"):
|
||||
"""Change the color range of the density maps.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vmax : float
|
||||
The new maximum value of the color range.
|
||||
kind : 'meg' | 'eeg'
|
||||
Which field map to apply the new color range to.
|
||||
"""
|
||||
_check_option("type", kind, ["eeg", "meg"])
|
||||
for surf_map in self._surf_maps:
|
||||
if surf_map["map_kind"] == kind:
|
||||
publish(
|
||||
self,
|
||||
ColormapRange(
|
||||
kind=f"field_strength_{kind}",
|
||||
fmin=-vmax,
|
||||
fmax=vmax,
|
||||
),
|
||||
)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"No {type.upper()} field map currently shown.")
|
||||
|
||||
def _rescale(self):
|
||||
"""Rescale the fieldlines and density maps to the current time point."""
|
||||
for surf_map in self._surf_maps:
|
||||
current_data = surf_map["data_interp"](self._current_time)
|
||||
vmax = float(np.max(current_data))
|
||||
self.set_vmax(vmax, kind=surf_map["map_kind"])
|
||||
@@ -0,0 +1,7 @@
|
||||
"""Eye-tracking visualization routines."""
|
||||
#
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from .heatmap import plot_gaze
|
||||
@@ -0,0 +1,209 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import numpy as np
|
||||
from scipy.ndimage import gaussian_filter
|
||||
|
||||
from ..._fiff.constants import FIFF
|
||||
from ...utils import _validate_type, fill_doc, logger
|
||||
from ..utils import plt_show
|
||||
|
||||
|
||||
@fill_doc
|
||||
def plot_gaze(
|
||||
epochs,
|
||||
*,
|
||||
calibration=None,
|
||||
width=None,
|
||||
height=None,
|
||||
sigma=25,
|
||||
cmap=None,
|
||||
alpha=1.0,
|
||||
vlim=(None, None),
|
||||
axes=None,
|
||||
show=True,
|
||||
):
|
||||
"""Plot a heatmap of eyetracking gaze data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
epochs : instance of Epochs
|
||||
The :class:`~mne.Epochs` object containing eyegaze channels.
|
||||
calibration : instance of Calibration | None
|
||||
An instance of Calibration with information about the screen size, distance,
|
||||
and resolution. If ``None``, you must provide a width and height.
|
||||
width : int
|
||||
The width dimension of the plot canvas, only valid if eyegaze data are in
|
||||
pixels. For example, if the participant screen resolution was 1920x1080, then
|
||||
the width should be 1920.
|
||||
height : int
|
||||
The height dimension of the plot canvas, only valid if eyegaze data are in
|
||||
pixels. For example, if the participant screen resolution was 1920x1080, then
|
||||
the height should be 1080.
|
||||
sigma : float | None
|
||||
The amount of Gaussian smoothing applied to the heatmap data (standard
|
||||
deviation in pixels). If ``None``, no smoothing is applied. Default is 25.
|
||||
%(cmap)s
|
||||
alpha : float
|
||||
The opacity of the heatmap (default is 1).
|
||||
%(vlim_plot_topomap)s
|
||||
%(axes_plot_topomap)s
|
||||
%(show)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of Figure
|
||||
The resulting figure object for the heatmap plot.
|
||||
|
||||
Notes
|
||||
-----
|
||||
.. versionadded:: 1.6
|
||||
"""
|
||||
from mne import BaseEpochs
|
||||
from mne._fiff.pick import _picks_to_idx
|
||||
|
||||
from ...preprocessing.eyetracking.utils import (
|
||||
_check_calibration,
|
||||
get_screen_visual_angle,
|
||||
)
|
||||
|
||||
_validate_type(epochs, BaseEpochs, "epochs")
|
||||
_validate_type(alpha, "numeric", "alpha")
|
||||
_validate_type(sigma, ("numeric", None), "sigma")
|
||||
|
||||
# Get the gaze data
|
||||
pos_picks = _picks_to_idx(epochs.info, "eyegaze")
|
||||
gaze_data = epochs.get_data(picks=pos_picks)
|
||||
gaze_ch_loc = np.array([epochs.info["chs"][idx]["loc"] for idx in pos_picks])
|
||||
x_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == -1)[0], :]
|
||||
y_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == 1)[0], :]
|
||||
unit = epochs.info["chs"][pos_picks[0]]["unit"] # assumes all units are the same
|
||||
|
||||
if x_data.shape[1] > 1: # binocular recording. Average across eyes
|
||||
logger.info("Detected binocular recording. Averaging positions across eyes.")
|
||||
x_data = np.nanmean(x_data, axis=1) # shape (n_epochs, n_samples)
|
||||
y_data = np.nanmean(y_data, axis=1)
|
||||
canvas = np.vstack((x_data.flatten(), y_data.flatten())) # shape (2, n_samples)
|
||||
|
||||
# Check that we have the right inputs
|
||||
if calibration is not None:
|
||||
if width is not None or height is not None:
|
||||
raise ValueError(
|
||||
"If a calibration is provided, you cannot provide a width or height"
|
||||
" to plot heatmaps. Please provide only the calibration object."
|
||||
)
|
||||
_check_calibration(calibration)
|
||||
if unit == FIFF.FIFF_UNIT_PX:
|
||||
width, height = calibration["screen_resolution"]
|
||||
elif unit == FIFF.FIFF_UNIT_RAD:
|
||||
width, height = calibration["screen_size"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid unit type: {unit}. gaze data Must be pixels or radians."
|
||||
)
|
||||
else:
|
||||
if width is None or height is None:
|
||||
raise ValueError(
|
||||
"If no calibration is provided, you must provide a width and height"
|
||||
" to plot heatmaps."
|
||||
)
|
||||
|
||||
# Create 2D histogram
|
||||
# We need to set the histogram bins & bounds, and imshow extent, based on the units
|
||||
if unit == FIFF.FIFF_UNIT_PX: # pixel on screen
|
||||
_range = [[0, height], [0, width]]
|
||||
bins_x, bins_y = width, height
|
||||
extent = [0, width, height, 0]
|
||||
elif unit == FIFF.FIFF_UNIT_RAD: # radians of visual angle
|
||||
if not calibration:
|
||||
raise ValueError(
|
||||
"If gaze data are in Radians, you must provide a"
|
||||
" calibration instance to plot heatmaps."
|
||||
)
|
||||
width, height = get_screen_visual_angle(calibration)
|
||||
x_range = [-width / 2, width / 2]
|
||||
y_range = [-height / 2, height / 2]
|
||||
_range = [y_range, x_range]
|
||||
extent = (x_range[0], x_range[1], y_range[0], y_range[1])
|
||||
bins_x, bins_y = calibration["screen_resolution"]
|
||||
|
||||
hist, _, _ = np.histogram2d(
|
||||
canvas[1, :],
|
||||
canvas[0, :],
|
||||
bins=(bins_y, bins_x),
|
||||
range=_range,
|
||||
)
|
||||
# Convert density from samples to seconds
|
||||
hist /= epochs.info["sfreq"]
|
||||
# Smooth the heatmap
|
||||
if sigma:
|
||||
hist = gaussian_filter(hist, sigma=sigma)
|
||||
|
||||
return _plot_heatmap_array(
|
||||
hist,
|
||||
width=width,
|
||||
height=height,
|
||||
cmap=cmap,
|
||||
alpha=alpha,
|
||||
vmin=vlim[0],
|
||||
vmax=vlim[1],
|
||||
extent=extent,
|
||||
axes=axes,
|
||||
show=show,
|
||||
)
|
||||
|
||||
|
||||
def _plot_heatmap_array(
|
||||
data,
|
||||
width,
|
||||
height,
|
||||
*,
|
||||
cmap=None,
|
||||
alpha=None,
|
||||
vmin=None,
|
||||
vmax=None,
|
||||
extent=None,
|
||||
axes=None,
|
||||
show=True,
|
||||
):
|
||||
"""Plot a heatmap of eyetracking gaze data from a numpy array."""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Prepare axes
|
||||
if axes is not None:
|
||||
from matplotlib.axes import Axes
|
||||
|
||||
_validate_type(axes, Axes, "axes")
|
||||
ax = axes
|
||||
fig = ax.get_figure()
|
||||
else:
|
||||
fig, ax = plt.subplots(constrained_layout=True)
|
||||
|
||||
ax.set_title("Gaze heatmap")
|
||||
ax.set_xlabel("X position")
|
||||
ax.set_ylabel("Y position")
|
||||
|
||||
# Prepare the heatmap
|
||||
alphas = 1 if alpha is None else alpha
|
||||
vmin = np.nanmin(data) if vmin is None else vmin
|
||||
vmax = np.nanmax(data) if vmax is None else vmax
|
||||
if extent is None:
|
||||
extent = [0, width, height, 0]
|
||||
|
||||
# Plot heatmap
|
||||
im = ax.imshow(
|
||||
data,
|
||||
aspect="equal",
|
||||
cmap=cmap,
|
||||
alpha=alphas,
|
||||
extent=extent,
|
||||
origin="upper",
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
)
|
||||
|
||||
# Prepare the colorbar
|
||||
fig.colorbar(im, ax=ax, shrink=0.6, label="Dwell time (seconds)")
|
||||
plt_show(show)
|
||||
return fig
|
||||
+1460
File diff suppressed because it is too large
Load Diff
+1645
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,135 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
"""Functions to plot EEG sensor montages or digitizer montages."""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from .._fiff._digitization import _get_fid_coords
|
||||
from .._fiff.meas_info import create_info
|
||||
from ..utils import _check_option, _validate_type, logger, verbose
|
||||
from .utils import plot_sensors
|
||||
|
||||
|
||||
@verbose
|
||||
def plot_montage(
|
||||
montage,
|
||||
*,
|
||||
scale=None,
|
||||
scale_factor=None,
|
||||
show_names=True,
|
||||
kind="topomap",
|
||||
show=True,
|
||||
sphere=None,
|
||||
axes=None,
|
||||
verbose=None,
|
||||
):
|
||||
"""Plot a montage.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
montage : instance of DigMontage
|
||||
The montage to visualize.
|
||||
scale : float
|
||||
Determines the scale of the channel points and labels; values < 1 will scale
|
||||
down, whereas values > 1 will scale up. Default to None, which implies 1.
|
||||
scale_factor : float
|
||||
Determines the size of the points. Deprecated, use scale instead.
|
||||
show_names : bool | list
|
||||
Whether to display all channel names. If a list, only the channel
|
||||
names in the list are shown. Defaults to True.
|
||||
kind : str
|
||||
Whether to plot the montage as '3d' or 'topomap' (default).
|
||||
show : bool
|
||||
Show figure if True.
|
||||
%(sphere_topomap_auto)s
|
||||
%(axes_montage)s
|
||||
|
||||
.. versionadded:: 1.4
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib.figure.Figure
|
||||
The figure object.
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..channels import DigMontage, make_dig_montage
|
||||
|
||||
if scale_factor is not None:
|
||||
msg = "scale_factor has been deprecated and will be removed. Use scale instead."
|
||||
if scale is not None:
|
||||
raise ValueError(
|
||||
" ".join(["scale and scale_factor cannot be used together.", msg])
|
||||
)
|
||||
logger.info(msg)
|
||||
if scale is None:
|
||||
scale = 1
|
||||
|
||||
_check_option("kind", kind, ["topomap", "3d"])
|
||||
_validate_type(montage, DigMontage, item_name="montage")
|
||||
ch_names = montage.ch_names
|
||||
title = None
|
||||
|
||||
if len(ch_names) == 0:
|
||||
raise RuntimeError("No valid channel positions found.")
|
||||
|
||||
pos = np.array(list(montage._get_ch_pos().values()))
|
||||
|
||||
dists = cdist(pos, pos)
|
||||
|
||||
# only consider upper triangular part by setting the rest to np.nan
|
||||
dists[np.tril_indices(dists.shape[0])] = np.nan
|
||||
dupes = np.argwhere(np.isclose(dists, 0))
|
||||
if dupes.any():
|
||||
montage = deepcopy(montage)
|
||||
n_chans = pos.shape[0]
|
||||
n_dupes = dupes.shape[0]
|
||||
idx = np.setdiff1d(np.arange(len(pos)), dupes[:, 1]).tolist()
|
||||
logger.info(f"{n_dupes} duplicate electrode labels found:")
|
||||
logger.info(", ".join([ch_names[d[0]] + "/" + ch_names[d[1]] for d in dupes]))
|
||||
logger.info(f"Plotting {n_chans - n_dupes} unique labels.")
|
||||
ch_names = [ch_names[i] for i in idx]
|
||||
ch_pos = dict(zip(ch_names, pos[idx, :]))
|
||||
# XXX: this might cause trouble if montage was originally in head
|
||||
fid, _ = _get_fid_coords(montage.dig)
|
||||
montage = make_dig_montage(ch_pos=ch_pos, **fid)
|
||||
|
||||
info = create_info(ch_names, sfreq=256, ch_types="eeg")
|
||||
info.set_montage(montage, on_missing="ignore")
|
||||
fig = plot_sensors(
|
||||
info,
|
||||
kind=kind,
|
||||
show_names=show_names,
|
||||
show=show,
|
||||
title=title,
|
||||
sphere=sphere,
|
||||
axes=axes,
|
||||
)
|
||||
|
||||
if scale_factor is not None:
|
||||
# scale points
|
||||
collection = fig.axes[0].collections[0]
|
||||
collection.set_sizes([scale_factor])
|
||||
elif scale is not None:
|
||||
# scale points
|
||||
collection = fig.axes[0].collections[0]
|
||||
collection.set_sizes([scale * 10])
|
||||
|
||||
# scale labels
|
||||
labels = fig.findobj(match=plt.Text)
|
||||
x_label, y_label = fig.axes[0].xaxis.label, fig.axes[0].yaxis.label
|
||||
z_label = fig.axes[0].zaxis.label if kind == "3d" else None
|
||||
tick_labels = fig.axes[0].get_xticklabels() + fig.axes[0].get_yticklabels()
|
||||
if kind == "3d":
|
||||
tick_labels += fig.axes[0].get_zticklabels()
|
||||
for label in labels:
|
||||
if label not in [x_label, y_label, z_label] + tick_labels:
|
||||
label.set_fontsize(label.get_fontsize() * scale)
|
||||
|
||||
return fig
|
||||
+635
@@ -0,0 +1,635 @@
|
||||
"""Functions to plot raw M/EEG data."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_to_idx, pick_channels, pick_types
|
||||
from ..defaults import _handle_default
|
||||
from ..filter import create_filter
|
||||
from ..utils import _check_option, _get_stim_channel, _validate_type, legacy, verbose
|
||||
from ..utils.spectrum import _split_psd_kwargs
|
||||
from .utils import (
|
||||
_check_cov,
|
||||
_compute_scalings,
|
||||
_get_channel_plotting_order,
|
||||
_handle_decim,
|
||||
_handle_precompute,
|
||||
_make_event_color_dict,
|
||||
_shorten_path_from_middle,
|
||||
)
|
||||
|
||||
_RAW_CLIP_DEF = 1.5
|
||||
|
||||
|
||||
@verbose
|
||||
def plot_raw(
|
||||
raw,
|
||||
events=None,
|
||||
duration=10.0,
|
||||
start=0.0,
|
||||
n_channels=20,
|
||||
bgcolor="w",
|
||||
color=None,
|
||||
bad_color="lightgray",
|
||||
event_color="cyan",
|
||||
scalings=None,
|
||||
remove_dc=True,
|
||||
order=None,
|
||||
show_options=False,
|
||||
title=None,
|
||||
show=True,
|
||||
block=False,
|
||||
highpass=None,
|
||||
lowpass=None,
|
||||
filtorder=4,
|
||||
clipping=_RAW_CLIP_DEF,
|
||||
show_first_samp=False,
|
||||
proj=True,
|
||||
group_by="type",
|
||||
butterfly=False,
|
||||
decim="auto",
|
||||
noise_cov=None,
|
||||
event_id=None,
|
||||
show_scrollbars=True,
|
||||
show_scalebars=True,
|
||||
time_format="float",
|
||||
precompute=None,
|
||||
use_opengl=None,
|
||||
picks=None,
|
||||
*,
|
||||
theme=None,
|
||||
overview_mode=None,
|
||||
splash=True,
|
||||
verbose=None,
|
||||
):
|
||||
"""Plot raw data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw data to plot.
|
||||
events : array | None
|
||||
Events to show with vertical bars.
|
||||
duration : float
|
||||
Time window (s) to plot. The lesser of this value and the duration
|
||||
of the raw file will be used.
|
||||
start : float
|
||||
Initial time to show (can be changed dynamically once plotted). If
|
||||
show_first_samp is True, then it is taken relative to
|
||||
``raw.first_samp``.
|
||||
n_channels : int
|
||||
Number of channels to plot at once. Defaults to 20. The lesser of
|
||||
``n_channels`` and ``len(raw.ch_names)`` will be shown.
|
||||
Has no effect if ``order`` is 'position', 'selection' or 'butterfly'.
|
||||
bgcolor : color object
|
||||
Color of the background.
|
||||
color : dict | color object | None
|
||||
Color for the data traces. If None, defaults to::
|
||||
|
||||
dict(mag='darkblue', grad='b', eeg='k', eog='k', ecg='m',
|
||||
emg='k', ref_meg='steelblue', misc='k', stim='k',
|
||||
resp='k', chpi='k')
|
||||
|
||||
bad_color : color object
|
||||
Color to make bad channels.
|
||||
%(event_color)s
|
||||
Defaults to ``'cyan'``.
|
||||
%(scalings)s
|
||||
remove_dc : bool
|
||||
If True remove DC component when plotting data.
|
||||
order : array of int | None
|
||||
Order in which to plot data. If the array is shorter than the number of
|
||||
channels, only the given channels are plotted. If None (default), all
|
||||
channels are plotted. If ``group_by`` is ``'position'`` or
|
||||
``'selection'``, the ``order`` parameter is used only for selecting the
|
||||
channels to be plotted.
|
||||
show_options : bool
|
||||
If True, a dialog for options related to projection is shown.
|
||||
title : str | None
|
||||
The title of the window. If None, and either the filename of the
|
||||
raw object or '<unknown>' will be displayed as title.
|
||||
show : bool
|
||||
Show figure if True.
|
||||
block : bool
|
||||
Whether to halt program execution until the figure is closed.
|
||||
Useful for setting bad channels on the fly by clicking on a line.
|
||||
May not work on all systems / platforms.
|
||||
(Only Qt) If you run from a script, this needs to
|
||||
be ``True`` or a Qt-eventloop needs to be started somewhere
|
||||
else in the script (e.g. if you want to implement the browser
|
||||
inside another Qt-Application).
|
||||
highpass : float | None
|
||||
Highpass to apply when displaying data.
|
||||
lowpass : float | None
|
||||
Lowpass to apply when displaying data.
|
||||
If highpass > lowpass, a bandstop rather than bandpass filter
|
||||
will be applied.
|
||||
filtorder : int
|
||||
Filtering order. 0 will use FIR filtering with MNE defaults.
|
||||
Other values will construct an IIR filter of the given order
|
||||
and apply it with :func:`~scipy.signal.filtfilt` (making the effective
|
||||
order twice ``filtorder``). Filtering may produce some edge artifacts
|
||||
(at the left and right edges) of the signals during display.
|
||||
|
||||
.. versionchanged:: 0.18
|
||||
Support for ``filtorder=0`` to use FIR filtering.
|
||||
clipping : str | float | None
|
||||
If None, channels are allowed to exceed their designated bounds in
|
||||
the plot. If "clamp", then values are clamped to the appropriate
|
||||
range for display, creating step-like artifacts. If "transparent",
|
||||
then excessive values are not shown, creating gaps in the traces.
|
||||
If float, clipping occurs for values beyond the ``clipping`` multiple
|
||||
of their dedicated range, so ``clipping=1.`` is an alias for
|
||||
``clipping='transparent'``.
|
||||
|
||||
.. versionchanged:: 0.21
|
||||
Support for float, and default changed from None to 1.5.
|
||||
show_first_samp : bool
|
||||
If True, show time axis relative to the ``raw.first_samp``.
|
||||
proj : bool
|
||||
Whether to apply projectors prior to plotting (default is ``True``).
|
||||
Individual projectors can be enabled/disabled interactively (see
|
||||
Notes). This argument only affects the plot; use ``raw.apply_proj()``
|
||||
to modify the data stored in the Raw object.
|
||||
%(group_by_browse)s
|
||||
butterfly : bool
|
||||
Whether to start in butterfly mode. Defaults to False.
|
||||
decim : int | 'auto'
|
||||
Amount to decimate the data during display for speed purposes.
|
||||
You should only decimate if the data are sufficiently low-passed,
|
||||
otherwise aliasing can occur. The 'auto' mode (default) uses
|
||||
the decimation that results in a sampling rate least three times
|
||||
larger than ``min(info['lowpass'], lowpass)`` (e.g., a 40 Hz lowpass
|
||||
will result in at least a 120 Hz displayed sample rate).
|
||||
noise_cov : instance of Covariance | str | None
|
||||
Noise covariance used to whiten the data while plotting.
|
||||
Whitened data channels are scaled by ``scalings['whitened']``,
|
||||
and their channel names are shown in italic.
|
||||
Can be a string to load a covariance from disk.
|
||||
See also :meth:`mne.Evoked.plot_white` for additional inspection
|
||||
of noise covariance properties when whitening evoked data.
|
||||
For data processed with SSS, the effective dependence between
|
||||
magnetometers and gradiometers may introduce differences in scaling,
|
||||
consider using :meth:`mne.Evoked.plot_white`.
|
||||
|
||||
.. versionadded:: 0.16.0
|
||||
event_id : dict | None
|
||||
Event IDs used to show at event markers (default None shows
|
||||
the event numbers).
|
||||
|
||||
.. versionadded:: 0.16.0
|
||||
%(show_scrollbars)s
|
||||
%(show_scalebars)s
|
||||
|
||||
.. versionadded:: 0.20.0
|
||||
%(time_format)s
|
||||
%(precompute)s
|
||||
%(use_opengl)s
|
||||
%(picks_all)s
|
||||
%(theme_pg)s
|
||||
|
||||
.. versionadded:: 1.0
|
||||
%(overview_mode)s
|
||||
|
||||
.. versionadded:: 1.1
|
||||
%(splash)s
|
||||
|
||||
.. versionadded:: 1.6
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
%(browser)s
|
||||
|
||||
Notes
|
||||
-----
|
||||
The arrow keys (up/down/left/right) can typically be used to navigate
|
||||
between channels and time ranges, but this depends on the backend
|
||||
matplotlib is configured to use (e.g., mpl.use('TkAgg') should work). The
|
||||
left/right arrows will scroll by 25%% of ``duration``, whereas
|
||||
shift+left/shift+right will scroll by 100%% of ``duration``. The scaling
|
||||
can be adjusted with - and + (or =) keys. The viewport dimensions can be
|
||||
adjusted with page up/page down and home/end keys. Full screen mode can be
|
||||
toggled with the F11 key, and scrollbars can be hidden/shown by pressing
|
||||
'z'. Right-click a channel label to view its location. To mark or un-mark a
|
||||
channel as bad, click on a channel label or a channel trace. The changes
|
||||
will be reflected immediately in the raw object's ``raw.info['bads']``
|
||||
entry.
|
||||
|
||||
If projectors are present, a button labelled "Prj" in the lower right
|
||||
corner of the plot window opens a secondary control window, which allows
|
||||
enabling/disabling specific projectors individually. This provides a means
|
||||
of interactively observing how each projector would affect the raw data if
|
||||
it were applied.
|
||||
|
||||
Annotation mode is toggled by pressing 'a', butterfly mode by pressing
|
||||
'b', and whitening mode (when ``noise_cov is not None``) by pressing 'w'.
|
||||
By default, the channel means are removed when ``remove_dc`` is set to
|
||||
``True``. This flag can be toggled by pressing 'd'.
|
||||
|
||||
%(notes_2d_backend)s
|
||||
"""
|
||||
from ..annotations import _annotations_starts_stops
|
||||
from ..io import BaseRaw
|
||||
from ._figure import _get_browser
|
||||
|
||||
info = raw.info.copy()
|
||||
sfreq = info["sfreq"]
|
||||
projs = info["projs"]
|
||||
# this will be an attr for which projectors are currently "on" in the plot
|
||||
projs_on = np.full_like(projs, proj, dtype=bool)
|
||||
# disable projs in info if user doesn't want to see them right away
|
||||
if not proj:
|
||||
with info._unlock():
|
||||
info["projs"] = list()
|
||||
|
||||
# handle defaults / check arg validity
|
||||
color = _handle_default("color", color)
|
||||
scalings = _compute_scalings(scalings, raw, remove_dc=remove_dc, duration=duration)
|
||||
if scalings["whitened"] == "auto":
|
||||
scalings["whitened"] = 1.0
|
||||
_validate_type(raw, BaseRaw, "raw", "Raw")
|
||||
decim, picks_data = _handle_decim(info, decim, lowpass)
|
||||
noise_cov = _check_cov(noise_cov, info)
|
||||
units = _handle_default("units", None)
|
||||
unit_scalings = _handle_default("scalings", None)
|
||||
_check_option("group_by", group_by, ("selection", "position", "original", "type"))
|
||||
|
||||
# clipping
|
||||
_validate_type(clipping, (None, "numeric", str), "clipping")
|
||||
if isinstance(clipping, str):
|
||||
_check_option(
|
||||
"clipping", clipping, ("clamp", "transparent"), extra="when a string"
|
||||
)
|
||||
clipping = 1.0 if clipping == "transparent" else clipping
|
||||
elif clipping is not None:
|
||||
clipping = float(clipping)
|
||||
|
||||
# be forgiving if user asks for too much time
|
||||
duration = min(raw.times[-1], float(duration))
|
||||
|
||||
# determine IIR filtering parameters
|
||||
if highpass is not None and highpass <= 0:
|
||||
raise ValueError(f"highpass must be > 0, got {highpass}")
|
||||
if highpass is None and lowpass is None:
|
||||
ba = filt_bounds = None
|
||||
else:
|
||||
filtorder = int(filtorder)
|
||||
if filtorder == 0:
|
||||
method = "fir"
|
||||
iir_params = None
|
||||
else:
|
||||
method = "iir"
|
||||
iir_params = dict(order=filtorder, output="sos", ftype="butter")
|
||||
ba = create_filter(
|
||||
np.zeros((1, int(round(duration * sfreq)))),
|
||||
sfreq,
|
||||
highpass,
|
||||
lowpass,
|
||||
method=method,
|
||||
iir_params=iir_params,
|
||||
)
|
||||
filt_bounds = _annotations_starts_stops(
|
||||
raw, ("edge", "bad_acq_skip"), invert=True
|
||||
)
|
||||
|
||||
# compute event times in seconds
|
||||
if events is not None:
|
||||
event_times = (events[:, 0] - raw.first_samp).astype(float)
|
||||
event_times /= sfreq
|
||||
event_nums = events[:, 2]
|
||||
else:
|
||||
event_times = event_nums = None
|
||||
|
||||
# determine trace order
|
||||
ch_names = np.array(raw.ch_names)
|
||||
ch_types = np.array(raw.get_channel_types())
|
||||
|
||||
picks = _picks_to_idx(info, picks, none="all", exclude=())
|
||||
order = _get_channel_plotting_order(order, ch_types, picks=picks)
|
||||
n_channels = min(info["nchan"], n_channels, len(order))
|
||||
# adjust order based on channel selection, if needed
|
||||
selections = None
|
||||
if group_by in ("selection", "position"):
|
||||
selections = _setup_channel_selections(raw, group_by, order)
|
||||
order = np.concatenate(list(selections.values()))
|
||||
default_selection = list(selections)[0]
|
||||
n_channels = len(selections[default_selection])
|
||||
assert isinstance(order, np.ndarray)
|
||||
assert order.dtype.kind == "i"
|
||||
if order.size == 0:
|
||||
raise RuntimeError("No channels found to plot")
|
||||
|
||||
# handle event colors
|
||||
event_color_dict = _make_event_color_dict(event_color, events, event_id)
|
||||
|
||||
# handle first_samp
|
||||
first_time = raw._first_time if show_first_samp else 0
|
||||
start += first_time
|
||||
event_id_rev = {v: k for k, v in (event_id or {}).items()}
|
||||
|
||||
# generate window title; allow instances without a filename (e.g., ICA)
|
||||
if title is None:
|
||||
title = "<unknown>"
|
||||
fnames = list(tuple(raw.filenames)) # get a list of a copy of the filenames
|
||||
if len(fnames):
|
||||
title = fnames.pop(0)
|
||||
extra = f" ... (+ {len(fnames)} more)" if len(fnames) else ""
|
||||
title = f"{title}{extra}"
|
||||
if len(title) > 60:
|
||||
title = _shorten_path_from_middle(title)
|
||||
elif not isinstance(title, str):
|
||||
raise TypeError(f"title must be None or a string, got a {type(title)}")
|
||||
|
||||
# gather parameters and initialize figure
|
||||
_validate_type(use_opengl, (bool, None), "use_opengl")
|
||||
precompute = _handle_precompute(precompute)
|
||||
params = dict(
|
||||
inst=raw,
|
||||
info=info,
|
||||
# channels and channel order
|
||||
ch_names=ch_names,
|
||||
ch_types=ch_types,
|
||||
ch_order=order,
|
||||
picks=order[:n_channels],
|
||||
n_channels=n_channels,
|
||||
picks_data=picks_data,
|
||||
group_by=group_by,
|
||||
ch_selections=selections,
|
||||
# time
|
||||
t_start=start,
|
||||
duration=duration,
|
||||
n_times=raw.n_times,
|
||||
first_time=first_time,
|
||||
time_format=time_format,
|
||||
decim=decim,
|
||||
# events
|
||||
event_color_dict=event_color_dict,
|
||||
event_times=event_times,
|
||||
event_nums=event_nums,
|
||||
event_id_rev=event_id_rev,
|
||||
# preprocessing
|
||||
projs=projs,
|
||||
projs_on=projs_on,
|
||||
apply_proj=proj,
|
||||
remove_dc=remove_dc,
|
||||
filter_coefs=ba,
|
||||
filter_bounds=filt_bounds,
|
||||
noise_cov=noise_cov,
|
||||
# scalings
|
||||
scalings=scalings,
|
||||
units=units,
|
||||
unit_scalings=unit_scalings,
|
||||
# colors
|
||||
ch_color_bad=bad_color,
|
||||
ch_color_dict=color,
|
||||
# display
|
||||
butterfly=butterfly,
|
||||
clipping=clipping,
|
||||
scrollbars_visible=show_scrollbars,
|
||||
scalebars_visible=show_scalebars,
|
||||
window_title=title,
|
||||
bgcolor=bgcolor,
|
||||
# Qt-specific
|
||||
precompute=precompute,
|
||||
use_opengl=use_opengl,
|
||||
theme=theme,
|
||||
overview_mode=overview_mode,
|
||||
splash=splash,
|
||||
)
|
||||
|
||||
fig = _get_browser(show=show, block=block, **params)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
@legacy(alt="Raw.compute_psd().plot()")
|
||||
@verbose
|
||||
def plot_raw_psd(
|
||||
raw,
|
||||
fmin=0,
|
||||
fmax=np.inf,
|
||||
tmin=None,
|
||||
tmax=None,
|
||||
proj=False,
|
||||
n_fft=None,
|
||||
n_overlap=0,
|
||||
reject_by_annotation=True,
|
||||
picks=None,
|
||||
ax=None,
|
||||
color="black",
|
||||
xscale="linear",
|
||||
area_mode="std",
|
||||
area_alpha=0.33,
|
||||
dB=True,
|
||||
estimate="power",
|
||||
show=True,
|
||||
n_jobs=None,
|
||||
average=False,
|
||||
line_alpha=None,
|
||||
spatial_colors=True,
|
||||
sphere=None,
|
||||
window="hamming",
|
||||
exclude="bads",
|
||||
verbose=None,
|
||||
):
|
||||
"""%(plot_psd_doc)s.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of Raw
|
||||
The raw object.
|
||||
%(fmin_fmax_psd)s
|
||||
%(tmin_tmax_psd)s
|
||||
%(proj_psd)s
|
||||
n_fft : int | None
|
||||
Number of points to use in Welch FFT calculations. Default is ``None``,
|
||||
which uses the minimum of 2048 and the number of time points.
|
||||
n_overlap : int
|
||||
The number of points of overlap between blocks. The default value
|
||||
is 0 (no overlap).
|
||||
%(reject_by_annotation_psd)s
|
||||
%(picks_good_data_noref)s
|
||||
%(ax_plot_psd)s
|
||||
%(color_plot_psd)s
|
||||
%(xscale_plot_psd)s
|
||||
%(area_mode_plot_psd)s
|
||||
%(area_alpha_plot_psd)s
|
||||
%(dB_plot_psd)s
|
||||
%(estimate_plot_psd)s
|
||||
%(show)s
|
||||
%(n_jobs)s
|
||||
%(average_plot_psd)s
|
||||
%(line_alpha_plot_psd)s
|
||||
%(spatial_colors_psd)s
|
||||
%(sphere_topomap_auto)s
|
||||
%(window_psd)s
|
||||
|
||||
.. versionadded:: 0.22.0
|
||||
exclude : list of str | 'bads'
|
||||
Channels names to exclude from being shown. If 'bads', the bad channels
|
||||
are excluded. Pass an empty list to plot all channels (including
|
||||
channels marked "bad", if any).
|
||||
|
||||
.. versionadded:: 0.24.0
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of Figure
|
||||
Figure with frequency spectra of the data channels.
|
||||
|
||||
Notes
|
||||
-----
|
||||
%(notes_plot_*_psd_func)s
|
||||
"""
|
||||
from ..time_frequency import Spectrum
|
||||
|
||||
init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot)
|
||||
return raw.compute_psd(**init_kw).plot(**plot_kw)
|
||||
|
||||
|
||||
@legacy(alt="Raw.compute_psd().plot_topo()")
|
||||
@verbose
|
||||
def plot_raw_psd_topo(
|
||||
raw,
|
||||
tmin=0.0,
|
||||
tmax=None,
|
||||
fmin=0.0,
|
||||
fmax=100.0,
|
||||
proj=False,
|
||||
*,
|
||||
n_fft=2048,
|
||||
n_overlap=0,
|
||||
dB=True,
|
||||
layout=None,
|
||||
color="w",
|
||||
fig_facecolor="k",
|
||||
axis_facecolor="k",
|
||||
axes=None,
|
||||
block=False,
|
||||
show=True,
|
||||
n_jobs=None,
|
||||
verbose=None,
|
||||
):
|
||||
"""Plot power spectral density, separately for each channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw : instance of io.Raw
|
||||
The raw instance to use.
|
||||
%(tmin_tmax_psd)s
|
||||
%(fmin_fmax_psd_topo)s
|
||||
%(proj_psd)s
|
||||
n_fft : int
|
||||
Number of points to use in Welch FFT calculations. Defaults to 2048.
|
||||
n_overlap : int
|
||||
The number of points of overlap between blocks. Defaults to 0
|
||||
(no overlap).
|
||||
%(dB_spectrum_plot_topo)s
|
||||
layout : instance of Layout | None
|
||||
Layout instance specifying sensor positions (does not need to be
|
||||
specified for Neuromag data). If ``None`` (default), the layout is
|
||||
inferred from the data.
|
||||
color : str | tuple
|
||||
A matplotlib-compatible color to use for the curves. Defaults to white.
|
||||
fig_facecolor : str | tuple
|
||||
A matplotlib-compatible color to use for the figure background.
|
||||
Defaults to black.
|
||||
axis_facecolor : str | tuple
|
||||
A matplotlib-compatible color to use for the axis background.
|
||||
Defaults to black.
|
||||
%(axes_spectrum_plot_topo)s
|
||||
block : bool
|
||||
Whether to halt program execution until the figure is closed.
|
||||
May not work on all systems / platforms. Defaults to False.
|
||||
%(show)s
|
||||
%(n_jobs)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib.figure.Figure
|
||||
Figure distributing one image per channel across sensor topography.
|
||||
"""
|
||||
from ..time_frequency import Spectrum
|
||||
|
||||
init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot_topo)
|
||||
return raw.compute_psd(**init_kw).plot_topo(**plot_kw)
|
||||
|
||||
|
||||
def _setup_channel_selections(raw, kind, order):
|
||||
"""Get dictionary of channel groupings."""
|
||||
from ..channels import (
|
||||
_EEG_SELECTIONS,
|
||||
_SELECTIONS,
|
||||
_divide_to_regions,
|
||||
read_vectorview_selection,
|
||||
)
|
||||
|
||||
_check_option("group_by", kind, ("position", "selection"))
|
||||
if kind == "position":
|
||||
selections_dict = _divide_to_regions(raw.info)
|
||||
keys = _SELECTIONS[1:] # omit 'Vertex'
|
||||
else: # kind == 'selection'
|
||||
from ..channels.channels import _get_ch_info
|
||||
|
||||
(
|
||||
has_vv_mag,
|
||||
has_vv_grad,
|
||||
*_,
|
||||
has_neuromag_122_grad,
|
||||
has_csd_coils,
|
||||
) = _get_ch_info(raw.info)
|
||||
if not (has_vv_grad or has_vv_mag or has_neuromag_122_grad):
|
||||
raise ValueError(
|
||||
"order='selection' only works for Neuromag "
|
||||
"data. Use order='position' instead."
|
||||
)
|
||||
selections_dict = OrderedDict()
|
||||
# get stim channel (if any)
|
||||
stim_ch = _get_stim_channel(None, raw.info, raise_error=False)
|
||||
stim_ch = stim_ch if len(stim_ch) else [""]
|
||||
stim_ch = pick_channels(raw.ch_names, stim_ch, ordered=False)
|
||||
# loop over regions
|
||||
keys = np.concatenate([_SELECTIONS, _EEG_SELECTIONS])
|
||||
for key in keys:
|
||||
channels = read_vectorview_selection(key, info=raw.info)
|
||||
picks = pick_channels(raw.ch_names, channels, ordered=False)
|
||||
picks = np.intersect1d(picks, order)
|
||||
if not len(picks):
|
||||
continue # omit empty selections
|
||||
selections_dict[key] = np.concatenate([picks, stim_ch])
|
||||
# add misc channels
|
||||
misc = pick_types(
|
||||
raw.info,
|
||||
meg=False,
|
||||
eeg=False,
|
||||
stim=True,
|
||||
eog=True,
|
||||
ecg=True,
|
||||
emg=True,
|
||||
ref_meg=False,
|
||||
misc=True,
|
||||
resp=True,
|
||||
chpi=True,
|
||||
exci=True,
|
||||
ias=True,
|
||||
syst=True,
|
||||
seeg=False,
|
||||
bio=True,
|
||||
ecog=False,
|
||||
fnirs=False,
|
||||
dbs=False,
|
||||
temperature=True,
|
||||
gsr=True,
|
||||
exclude=(),
|
||||
)
|
||||
if len(misc) and np.isin(misc, order).any():
|
||||
selections_dict["Misc"] = misc
|
||||
return selections_dict
|
||||
+1309
File diff suppressed because it is too large
Load Diff
+4102
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,480 @@
|
||||
"""
|
||||
Event API for inter-figure communication.
|
||||
|
||||
The event API allows figures to communicate with each other, such that a change
|
||||
in one figure can trigger a change in another figure. For example, moving the
|
||||
time cursor in one plot can update the current time in another plot. Another
|
||||
scenario is two drawing routines drawing into the same window, using events to
|
||||
stay in-sync.
|
||||
|
||||
Authors: Marijn van Vliet <w.m.vanvliet@gmail.com>
|
||||
"""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from __future__ import annotations # only needed for Python ≤ 3.9
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
|
||||
from matplotlib.colors import Colormap
|
||||
|
||||
from ..utils import _validate_type, fill_doc, logger, verbose, warn
|
||||
|
||||
# Global dict {fig: channel} containing all currently active event channels.
|
||||
_event_channels = weakref.WeakKeyDictionary()
|
||||
|
||||
# The event channels of figures can be linked together. This dict keeps track
|
||||
# of these links. Links are bi-directional, so if {fig1: fig2} exists, then so
|
||||
# must {fig2: fig1}.
|
||||
_event_channel_links = weakref.WeakKeyDictionary()
|
||||
|
||||
# Event channels that are temporarily disabled by the disable_ui_events context
|
||||
# manager.
|
||||
_disabled_event_channels = weakref.WeakSet()
|
||||
|
||||
# Regex pattern used when converting CamelCase to snake_case.
|
||||
# Detects all capital letters that are not at the beginning of a word.
|
||||
_camel_to_snake = re.compile(r"(?<!^)(?=[A-Z])")
|
||||
|
||||
|
||||
# List of events
|
||||
@fill_doc
|
||||
class UIEvent:
|
||||
"""Abstract base class for all events.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
"""
|
||||
|
||||
source = None
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""The name of the event, which is the class name in snake case."""
|
||||
return _camel_to_snake.sub("_", self.__class__.__name__).lower()
|
||||
|
||||
|
||||
@fill_doc
|
||||
class FigureClosing(UIEvent):
|
||||
"""Indicates that the user has requested to close a figure.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@fill_doc
|
||||
class TimeChange(UIEvent):
|
||||
"""Indicates that the user has selected a time.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time : float
|
||||
The new time in seconds.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
time : float
|
||||
The new time in seconds.
|
||||
"""
|
||||
|
||||
time: float
|
||||
|
||||
|
||||
@dataclass
|
||||
@fill_doc
|
||||
class PlaybackSpeed(UIEvent):
|
||||
"""Indicates that the user has selected a different playback speed for videos.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
speed : float
|
||||
The new speed in seconds per frame.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
speed : float
|
||||
The new speed in seconds per frame.
|
||||
"""
|
||||
|
||||
speed: float
|
||||
|
||||
|
||||
@dataclass
|
||||
@fill_doc
|
||||
class ColormapRange(UIEvent):
|
||||
"""Indicates that the user has updated the bounds of the colormap.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kind : str
|
||||
Kind of colormap being updated. The Notes section of the drawing
|
||||
routine publishing this event should mention the possible kinds.
|
||||
ch_type : str
|
||||
Type of sensor the data originates from.
|
||||
%(fmin_fmid_fmax)s
|
||||
%(alpha)s
|
||||
cmap : str
|
||||
The colormap to use. Either string or matplotlib.colors.Colormap
|
||||
instance.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
kind : str
|
||||
Kind of colormap being updated. The Notes section of the drawing
|
||||
routine publishing this event should mention the possible kinds.
|
||||
ch_type : str
|
||||
Type of sensor the data originates from.
|
||||
unit : str
|
||||
The unit of the values.
|
||||
%(ui_event_name_source)s
|
||||
%(fmin_fmid_fmax)s
|
||||
%(alpha)s
|
||||
cmap : str
|
||||
The colormap to use. Either string or matplotlib.colors.Colormap
|
||||
instance.
|
||||
"""
|
||||
|
||||
kind: str
|
||||
ch_type: str | None = None
|
||||
fmin: float | None = None
|
||||
fmid: float | None = None
|
||||
fmax: float | None = None
|
||||
alpha: bool | None = None
|
||||
cmap: Colormap | str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@fill_doc
|
||||
class VertexSelect(UIEvent):
|
||||
"""Indicates that the user has selected a vertex.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hemi : str
|
||||
The hemisphere the vertex was selected on.
|
||||
Can be ``"lh"``, ``"rh"``, or ``"vol"``.
|
||||
vertex_id : int
|
||||
The vertex number (in the high resolution mesh) that was selected.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
hemi : str
|
||||
The hemisphere the vertex was selected on.
|
||||
Can be ``"lh"``, ``"rh"``, or ``"vol"``.
|
||||
vertex_id : int
|
||||
The vertex number (in the high resolution mesh) that was selected.
|
||||
"""
|
||||
|
||||
hemi: str
|
||||
vertex_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
@fill_doc
|
||||
class Contours(UIEvent):
|
||||
"""Indicates that the user has changed the contour lines.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kind : str
|
||||
The kind of contours lines being changed. The Notes section of the
|
||||
drawing routine publishing this event should mention the possible
|
||||
kinds.
|
||||
contours : list of float
|
||||
The new values at which contour lines need to be drawn.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
%(ui_event_name_source)s
|
||||
kind : str
|
||||
The kind of contours lines being changed. The Notes section of the
|
||||
drawing routine publishing this event should mention the possible
|
||||
kinds.
|
||||
contours : list of float
|
||||
The new values at which contour lines need to be drawn.
|
||||
"""
|
||||
|
||||
kind: str
|
||||
contours: list[str]
|
||||
|
||||
|
||||
def _get_event_channel(fig):
|
||||
"""Get the event channel associated with a figure.
|
||||
|
||||
If the event channel doesn't exist yet, it gets created and added to the
|
||||
global ``_event_channels`` dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure to get the event channel for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
channel : dict[event -> list]
|
||||
The event channel. An event channel is a list mapping string event
|
||||
names to a list of callback representing all subscribers to the
|
||||
channel.
|
||||
"""
|
||||
import matplotlib
|
||||
|
||||
from ._brain import Brain
|
||||
from .evoked_field import EvokedField
|
||||
|
||||
# Create the event channel if it doesn't exist yet
|
||||
if fig not in _event_channels:
|
||||
# The channel itself is a dict mapping string event names to a list of
|
||||
# subscribers. No subscribers yet for this new event channel.
|
||||
_event_channels[fig] = dict()
|
||||
|
||||
weakfig = weakref.ref(fig)
|
||||
|
||||
# When the figure is closed, its associated event channel should be
|
||||
# deleted. This is a good time to set this up.
|
||||
def delete_event_channel(event=None, *, weakfig=weakfig):
|
||||
"""Delete the event channel (callback function)."""
|
||||
fig = weakfig()
|
||||
if fig is None:
|
||||
return
|
||||
publish(fig, event=FigureClosing()) # Notify subscribers of imminent close
|
||||
logger.debug(f"unlink(({fig})")
|
||||
unlink(fig) # Remove channel from the _event_channel_links dict
|
||||
if fig in _event_channels:
|
||||
logger.debug(f" del _event_channels[{fig}]")
|
||||
del _event_channels[fig]
|
||||
if fig in _disabled_event_channels:
|
||||
logger.debug(f" _disabled_event_channels.remove({fig})")
|
||||
_disabled_event_channels.remove(fig)
|
||||
|
||||
# Hook up the above callback function to the close event of the figure
|
||||
# window. How this is done exactly depends on the various figure types
|
||||
# MNE-Python has.
|
||||
_validate_type(fig, (matplotlib.figure.Figure, Brain, EvokedField), "fig")
|
||||
if isinstance(fig, matplotlib.figure.Figure):
|
||||
fig.canvas.mpl_connect("close_event", delete_event_channel)
|
||||
else:
|
||||
assert hasattr(fig, "_renderer") # figures like Brain, EvokedField, etc.
|
||||
fig._renderer._window_close_connect(delete_event_channel, after=False)
|
||||
|
||||
# Now the event channel exists for sure.
|
||||
return _event_channels[fig]
|
||||
|
||||
|
||||
@verbose
|
||||
def publish(fig, event, *, verbose=None):
|
||||
"""Publish an event to all subscribers of the figure's channel.
|
||||
|
||||
The figure's event channel and all linked event channels are searched for
|
||||
subscribers to the given event. Each subscriber had provided a callback
|
||||
function when subscribing, so we call that.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure that publishes the event.
|
||||
event : UIEvent
|
||||
Event to publish.
|
||||
%(verbose)s
|
||||
"""
|
||||
if fig in _disabled_event_channels:
|
||||
return
|
||||
|
||||
# Compile a list of all event channels that the event should be published
|
||||
# on.
|
||||
channels = [_get_event_channel(fig)]
|
||||
links = _event_channel_links.get(fig, None)
|
||||
if links is not None:
|
||||
for linked_fig, (include_events, exclude_events) in links.items():
|
||||
if (include_events is None or event.name in include_events) and (
|
||||
exclude_events is None or event.name not in exclude_events
|
||||
):
|
||||
channels.append(_get_event_channel(linked_fig))
|
||||
|
||||
# Publish the event by calling the registered callback functions.
|
||||
event.source = fig
|
||||
logger.debug(f"Publishing {event} on channel {fig}")
|
||||
for channel in channels:
|
||||
if event.name not in channel:
|
||||
channel[event.name] = set()
|
||||
for callback in channel[event.name]:
|
||||
callback(event=event)
|
||||
|
||||
|
||||
@verbose
|
||||
def subscribe(fig, event_name, callback, *, verbose=None):
|
||||
"""Subscribe to an event on a figure's event channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure of which event channel to subscribe.
|
||||
event_name : str
|
||||
The name of the event to listen for.
|
||||
callback : callable
|
||||
The function that should be called whenever the event is published.
|
||||
%(verbose)s
|
||||
"""
|
||||
channel = _get_event_channel(fig)
|
||||
logger.debug(f"Subscribing to channel {channel}")
|
||||
if event_name not in channel:
|
||||
channel[event_name] = set()
|
||||
channel[event_name].add(callback)
|
||||
|
||||
|
||||
@verbose
|
||||
def unsubscribe(fig, event_names, callback=None, *, verbose=None):
|
||||
"""Unsubscribe from an event on a figure's event channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure of which event channel to unsubscribe from.
|
||||
event_names : str | list of str
|
||||
Select which events to stop subscribing to. Can be a single string
|
||||
event name, a list of event names or ``"all"`` which will unsubscribe
|
||||
from all events.
|
||||
callback : callable | None
|
||||
The callback function that should be unsubscribed, leaving all other
|
||||
callback functions that may be subscribed untouched. By default
|
||||
(``None``) all callback functions are unsubscribed from the event.
|
||||
%(verbose)s
|
||||
"""
|
||||
channel = _get_event_channel(fig)
|
||||
|
||||
# Determine which events to unsubscribe for.
|
||||
if event_names == "all":
|
||||
if callback is None:
|
||||
event_names = list(channel.keys())
|
||||
else:
|
||||
event_names = list(k for k, v in channel.items() if callback in v)
|
||||
elif isinstance(event_names, str):
|
||||
event_names = [event_names]
|
||||
|
||||
for event_name in event_names:
|
||||
if event_name not in channel:
|
||||
warn(
|
||||
f'Cannot unsubscribe from event "{event_name}" as we have never '
|
||||
"subscribed to it."
|
||||
)
|
||||
continue
|
||||
|
||||
if callback is None:
|
||||
del channel[event_name]
|
||||
else:
|
||||
# Unsubscribe specific callback function.
|
||||
subscribers = channel[event_name]
|
||||
if callback in subscribers:
|
||||
subscribers.remove(callback)
|
||||
else:
|
||||
warn(
|
||||
f'Cannot unsubscribe {callback} from event "{event_name}" '
|
||||
"as it was never subscribed to it."
|
||||
)
|
||||
if len(subscribers) == 0:
|
||||
del channel[event_name] # keep things tidy
|
||||
|
||||
|
||||
@verbose
|
||||
def link(*figs, include_events=None, exclude_events=None, verbose=None):
|
||||
"""Link the event channels of two figures together.
|
||||
|
||||
When event channels are linked, any events that are published on one
|
||||
channel are simultaneously published on the other channel. Links are
|
||||
bi-directional.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
*figs : tuple of matplotlib.figure.Figure | tuple of Figure3D
|
||||
The figures whose event channel will be linked.
|
||||
include_events : list of str | None
|
||||
Select which events to publish across figures. By default (``None``),
|
||||
both figures will receive all of each other's events. Passing a list of
|
||||
event names will restrict the events being shared across the figures to
|
||||
only the given ones.
|
||||
exclude_events : list of str | None
|
||||
Select which events not to publish across figures. By default (``None``),
|
||||
no events are excluded.
|
||||
%(verbose)s
|
||||
"""
|
||||
if include_events is not None:
|
||||
include_events = set(include_events)
|
||||
if exclude_events is not None:
|
||||
exclude_events = set(exclude_events)
|
||||
|
||||
# Make sure the event channels of the figures are setup properly.
|
||||
for fig in figs:
|
||||
_get_event_channel(fig)
|
||||
if fig not in _event_channel_links:
|
||||
_event_channel_links[fig] = weakref.WeakKeyDictionary()
|
||||
|
||||
# Link the event channels
|
||||
for fig1 in figs:
|
||||
for fig2 in figs:
|
||||
if fig1 is not fig2:
|
||||
_event_channel_links[fig1][fig2] = (include_events, exclude_events)
|
||||
|
||||
|
||||
@verbose
|
||||
def unlink(fig, *, verbose=None):
|
||||
"""Remove all links involving the event channel of the given figure.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure whose event channel should be unlinked from all other event
|
||||
channels.
|
||||
%(verbose)s
|
||||
"""
|
||||
linked_figs = _event_channel_links.get(fig)
|
||||
if linked_figs is not None:
|
||||
for linked_fig in linked_figs.keys():
|
||||
del _event_channel_links[linked_fig][fig]
|
||||
if len(_event_channel_links[linked_fig]) == 0:
|
||||
del _event_channel_links[linked_fig]
|
||||
if fig in _event_channel_links: # need to check again because of weak refs
|
||||
del _event_channel_links[fig]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def disable_ui_events(fig):
|
||||
"""Temporarily disable generation of UI events. Use as context manager.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fig : matplotlib.figure.Figure | Figure3D
|
||||
The figure whose UI event generation should be temporarily disabled.
|
||||
"""
|
||||
_disabled_event_channels.add(fig)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_disabled_event_channels.remove(fig)
|
||||
|
||||
|
||||
def _cleanup_agg():
|
||||
"""Call close_event for Agg canvases to help our doc build."""
|
||||
import matplotlib.backends.backend_agg
|
||||
import matplotlib.figure
|
||||
|
||||
for key in list(_event_channels): # we might remove keys as we go
|
||||
if isinstance(key, matplotlib.figure.Figure):
|
||||
canvas = key.canvas
|
||||
if isinstance(canvas, matplotlib.backends.backend_agg.FigureCanvasAgg):
|
||||
for cb in key.canvas.callbacks.callbacks["close_event"].values():
|
||||
cb = cb() # get the true ref
|
||||
if cb is not None:
|
||||
cb()
|
||||
+2814
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user