initial commit

This commit is contained in:
2025-08-19 09:13:22 -07:00
parent 28464811d6
commit 0977a3e14d
820 changed files with 1003358 additions and 2 deletions

View File

@@ -0,0 +1,11 @@
"""Plot Cortex Surface."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from ._brain import Brain, _LayeredMesh
from ._scraper import _BrainScraper
from ._linkviewer import _LinkViewer
__all__ = ["Brain"]

4141
mne/viz/_brain/_brain.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,98 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ...utils import warn
from ..ui_events import link
class _LinkViewer:
"""Class to link multiple Brain objects."""
def __init__(self, brains, time=True, camera=False, colorbar=True, picking=False):
self.brains = brains
self.leader = self.brains[0] # select a brain as leader
# check time infos
times = [brain._times for brain in brains]
if time and not all(np.allclose(x, times[0]) for x in times):
warn("stc.times do not match, not linking time")
time = False
if camera:
self.link_cameras()
events_to_link = []
if time:
events_to_link.append("time_change")
if colorbar:
events_to_link.append("colormap_range")
for brain in brains[1:]:
link(self.leader, brain, include_events=events_to_link)
if picking:
def _func_add(*args, **kwargs):
for brain in self.brains:
brain._add_vertex_glyph2(*args, **kwargs)
brain.plotter.update()
def _func_remove(*args, **kwargs):
for brain in self.brains:
brain._remove_vertex_glyph2(*args, **kwargs)
# save initial picked points
initial_points = dict()
for hemi in ("lh", "rh"):
initial_points[hemi] = set()
for brain in self.brains:
initial_points[hemi] |= set(brain.picked_points[hemi])
# link the viewers
for brain in self.brains:
brain.clear_glyphs()
brain._add_vertex_glyph2 = brain._add_vertex_glyph
brain._add_vertex_glyph = _func_add
brain._remove_vertex_glyph2 = brain._remove_vertex_glyph
brain._remove_vertex_glyph = _func_remove
# link the initial points
for hemi in initial_points.keys():
if hemi in brain._layered_meshes:
mesh = brain._layered_meshes[hemi]._polydata
for vertex_id in initial_points[hemi]:
self.leader._add_vertex_glyph(hemi, mesh, vertex_id)
def set_fmin(self, value):
self.leader.update_lut(fmin=value)
def set_fmid(self, value):
self.leader.update_lut(fmid=value)
def set_fmax(self, value):
self.leader.update_lut(fmax=value)
def set_time_point(self, value):
self.leader.set_time_point(value)
def set_playback_speed(self, value):
self.leader.set_playback_speed(value)
def toggle_playback(self):
self.leader.toggle_playback()
def link_cameras(self):
from ..backends._pyvista import _add_camera_callback
def _update_camera(vtk_picker, event):
for brain in self.brains:
brain.plotter.update()
camera = self.leader.plotter.camera
_add_camera_callback(camera, _update_camera)
for brain in self.brains:
for renderer in brain.plotter.renderers:
renderer.camera = camera

102
mne/viz/_brain/_scraper.py Normal file
View File

@@ -0,0 +1,102 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import os
import os.path as op
from ._brain import Brain
class _BrainScraper:
"""Scrape Brain objects."""
def __repr__(self):
return "<BrainScraper>"
def __call__(self, block, block_vars, gallery_conf):
rst = ""
for brain in list(block_vars["example_globals"].values()):
# Only need to process if it's a brain with a time_viewer
# with traces on and shown in the same window, otherwise
# PyVista and matplotlib scrapers can just do the work
if (not isinstance(brain, Brain)) or brain._closed:
continue
from matplotlib import animation
from matplotlib import pyplot as plt
from sphinx_gallery.scrapers import matplotlib_scraper
img = brain.screenshot(time_viewer=True)
dpi = 100.0
figsize = (img.shape[1] / dpi, img.shape[0] / dpi)
fig = plt.figure(figsize=figsize, dpi=dpi)
ax = plt.Axes(fig, [0, 0, 1, 1])
fig.add_axes(ax)
img = ax.imshow(img)
movie_key = "# brain.save_movie"
if movie_key in block[1]:
kwargs = dict()
# Parse our parameters
lines = block[1].splitlines()
for li, line in enumerate(block[1].splitlines()):
if line.startswith(movie_key):
line = line[len(movie_key) :].replace("..., ", "")
for ni in range(1, 5): # should be enough
if len(lines) > li + ni and lines[li + ni].startswith(
"# "
):
line = line + lines[li + ni][1:].strip()
else:
break
assert line.startswith("(") and line.endswith(")")
kwargs.update(eval(f"dict{line}")) # nosec B307
for key, default in [
("time_dilation", 4),
("framerate", 24),
("tmin", None),
("tmax", None),
("interpolation", None),
("time_viewer", False),
]:
if key not in kwargs:
kwargs[key] = default
kwargs.pop("filename", None) # always omit this one
if brain.time_viewer:
assert kwargs["time_viewer"], "Must use time_viewer=True"
frames = brain._make_movie_frames(callback=None, **kwargs)
# Turn them into an animation
def func(frame):
img.set_data(frame)
return [img]
anim = animation.FuncAnimation(
fig,
func=func,
frames=frames,
blit=True,
interval=1000.0 / kwargs["framerate"],
)
# Out to sphinx-gallery:
#
# 1. A static image but hide it (useful for carousel)
if animation.FFMpegWriter.isAvailable():
writer = "ffmpeg"
elif animation.ImageMagickWriter.isAvailable():
writer = "imagemagick"
else:
writer = None
static_fname = next(block_vars["image_path_iterator"])
static_fname = static_fname[:-4] + ".gif"
anim.save(static_fname, writer=writer, dpi=dpi)
rel_fname = op.relpath(static_fname, gallery_conf["src_dir"])
rel_fname = rel_fname.replace(os.sep, "/").lstrip("/")
rst += f"\n.. image:: /{rel_fname}\n :class: hidden\n"
# 2. An animation that will be embedded and visible
block_vars["example_globals"]["_brain_anim_"] = anim
brain.close()
rst += matplotlib_scraper(block, block_vars, gallery_conf)
return rst

185
mne/viz/_brain/colormap.py Normal file
View File

@@ -0,0 +1,185 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
def create_lut(cmap, n_colors=256, center=None):
"""Return a colormap suitable for setting as a LUT."""
from .._3d import _get_cmap
assert not (isinstance(cmap, str) and cmap == "auto")
cmap = _get_cmap(cmap)
lut = np.round(cmap(np.linspace(0, 1, n_colors)) * 255.0).astype(np.int64)
return lut
def scale_sequential_lut(lut_table, fmin, fmid, fmax):
"""Scale a sequential colormap."""
assert fmin <= fmid <= fmax # guaranteed by calculate_lut
lut_table_new = lut_table.copy()
n_colors = lut_table.shape[0]
n_colors2 = n_colors // 2
if fmax == fmin:
fmid_idx = 0
else:
fmid_idx = np.clip(
int(np.round(n_colors * ((fmid - fmin) / (fmax - fmin))) - 1),
0,
n_colors - 2,
)
n_left = fmid_idx + 1
n_right = n_colors - n_left
for i in range(4):
lut_table_new[: fmid_idx + 1, i] = np.interp(
np.linspace(0, n_colors2 - 1, n_left), np.arange(n_colors), lut_table[:, i]
)
lut_table_new[fmid_idx + 1 :, i] = np.interp(
np.linspace(n_colors - 1, n_colors2, n_right)[::-1],
np.arange(n_colors),
lut_table[:, i],
)
return lut_table_new
def get_fill_colors(cols, n_fill):
"""Get the fill colors for the middle of divergent colormaps."""
steps = np.linalg.norm(np.diff(cols[:, :3].astype(float), axis=0), axis=1)
ind = np.flatnonzero(steps[1:-1] > steps[[0, -1]].mean() * 3)
if ind.size > 0:
# choose the two colors between which there is the large step
ind = ind[0] + 1
fillcols = np.r_[
np.tile(cols[ind, :], (n_fill // 2, 1)),
np.tile(cols[ind + 1, :], (n_fill - n_fill // 2, 1)),
]
else:
# choose a color from the middle of the colormap
fillcols = np.tile(cols[int(cols.shape[0] / 2), :], (n_fill, 1))
return fillcols
def calculate_lut(lut_table, alpha, fmin, fmid, fmax, center=None, transparent=True):
"""Transparent color map calculation.
A colormap may be sequential or divergent. When the colormap is
divergent indicate this by providing a value for 'center'. The
meanings of fmin, fmid and fmax are different for sequential and
divergent colormaps. A sequential colormap is characterised by::
[fmin, fmid, fmax]
where fmin and fmax define the edges of the colormap and fmid
will be the value mapped to the center of the originally chosen colormap.
A divergent colormap is characterised by::
[center-fmax, center-fmid, center-fmin, center,
center+fmin, center+fmid, center+fmax]
i.e., values between center-fmin and center+fmin will not be shown
while center-fmid will map to the fmid of the first half of the
original colormap and center-fmid to the fmid of the second half.
Parameters
----------
lim_cmap : Colormap
Color map obtained from _process_mapdata.
alpha : float
Alpha value to apply globally to the overlay. Has no effect with mpl
backend.
fmin : float
Min value in colormap.
fmid : float
Intermediate value in colormap.
fmax : float
Max value in colormap.
center : float or None
If not None, center of a divergent colormap, changes the meaning of
fmin, fmax and fmid.
transparent : boolean
if True: use a linear transparency between fmin and fmid and make
values below fmin fully transparent (symmetrically for divergent
colormaps)
Returns
-------
cmap : matplotlib.ListedColormap
Color map with transparency channel.
"""
if not fmin <= fmid <= fmax:
raise ValueError(f"Must have fmin ({fmin}) <= fmid ({fmid}) <= fmax ({fmax})")
lut_table = create_lut(lut_table)
assert lut_table.dtype.kind == "i"
divergent = center is not None
n_colors = lut_table.shape[0]
# Add transparency if needed
n_colors2 = n_colors // 2
if transparent:
if divergent:
N4 = np.full(4, n_colors // 4)
N4[[0, 3, 1, 2][: np.mod(n_colors, 4)]] += 1
assert N4.sum() == n_colors
lut_table[:, -1] = np.round(
np.hstack(
[
np.full(N4[0], 255.0),
np.linspace(0, 255, N4[1])[::-1],
np.linspace(0, 255, N4[2]),
np.full(N4[3], 255.0),
]
)
)
else:
lut_table[:n_colors2, -1] = np.round(np.linspace(0, 255, n_colors2))
lut_table[n_colors2:, -1] = 255
alpha = float(alpha)
if alpha < 1.0:
lut_table[:, -1] = np.round(lut_table[:, -1] * alpha)
if divergent:
if np.isclose(fmax, fmin, rtol=1e-6, atol=0):
lut_table = np.r_[
lut_table[:1],
get_fill_colors(
lut_table[n_colors2 - 3 : n_colors2 + 3, :], n_colors - 2
),
lut_table[-1:],
]
else:
n_fill = int(round(fmin * n_colors2 / (fmax - fmin))) * 2
lut_table = np.r_[
scale_sequential_lut(
lut_table[:n_colors2, :],
center - fmax,
center - fmid,
center - fmin,
),
get_fill_colors(lut_table[n_colors2 - 3 : n_colors2 + 3, :], n_fill),
scale_sequential_lut(
lut_table[n_colors2:, :][::-1],
center - fmax,
center - fmid,
center - fmin,
)[::-1],
]
else:
lut_table = scale_sequential_lut(lut_table, fmin, fmid, fmax)
n_colors = lut_table.shape[0]
if n_colors != 256:
lut = np.zeros((256, 4))
x = np.linspace(1, n_colors, 256)
for chan in range(4):
lut[:, chan] = np.interp(x, np.arange(1, n_colors + 1), lut_table[:, chan])
lut_table = lut
lut_table = lut_table.astype(np.float64) / 255.0
return lut_table

177
mne/viz/_brain/surface.py Normal file
View File

@@ -0,0 +1,177 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from os import path as path
import numpy as np
from ...surface import _read_patch, complete_surface_info, read_curvature, read_surface
from ...utils import _check_fname, _check_option, _validate_type, get_subjects_dir
class _Surface:
"""Container for a brain surface.
It is used for storing vertices, faces and morphometric data
(curvature) of a hemisphere mesh.
Parameters
----------
subject : string
Name of subject
hemi : {'lh', 'rh'}
Which hemisphere to load
surf : string
Name of the surface to load (eg. inflated, orig ...).
subjects_dir : str | None
If not None, this directory will be used as the subjects directory
instead of the value set using the SUBJECTS_DIR environment variable.
offset : float | None
If 0.0, the surface will be offset such that the medial
wall is aligned with the origin. If None, no offset will
be applied. If != 0.0, an additional offset will be used.
units : str
Can be 'm' or 'mm' (default).
x_dir : ndarray | None
The x direction to use for offset alignment.
Attributes
----------
bin_curv : numpy.ndarray
Curvature values stored as non-negative integers.
coords : numpy.ndarray
nvtx x 3 array of vertex (x, y, z) coordinates.
curv : numpy.ndarray
Vector representation of surface morpometry (curvature) values as
loaded from a file.
grey_curv : numpy.ndarray
Normalized morphometry (curvature) data, used in order to get
a gray cortex.
faces : numpy.ndarray
nfaces x 3 array of defining mesh triangles.
hemi : {'lh', 'rh'}
Which hemisphere to load.
nn : numpy.ndarray
Vertex normals for a triangulated surface.
offset : float | None
If float, align inside edge of each hemisphere to center + offset.
If None, do not change coordinates (default).
subject : string
Name of subject.
surf : string
Name of the surface to load (eg. inflated, orig ...).
units : str
Can be 'm' or 'mm' (default).
"""
def __init__(
self,
subject,
hemi,
surf,
subjects_dir=None,
offset=None,
units="mm",
x_dir=None,
):
x_dir = np.array([1.0, 0, 0]) if x_dir is None else x_dir
assert isinstance(x_dir, np.ndarray)
assert np.isclose(np.linalg.norm(x_dir), 1.0, atol=1e-6)
assert hemi in ("lh", "rh")
_validate_type(offset, (None, "numeric"), "offset")
self.units = _check_option("units", units, ("mm", "m"))
self.subject = subject
self.hemi = hemi
self.surf = surf
self.offset = offset
self.bin_curv = None
self.coords = None
self.curv = None
self.faces = None
self.nn = None
self.labels = dict()
self.x_dir = x_dir
subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True))
self.data_path = path.join(subjects_dir, subject)
if surf == "seghead":
raise ValueError(
"`surf` cannot be seghead, use "
"`mne.viz.Brain.add_head` to plot the seghead"
)
def load_geometry(self):
"""Load geometry of the surface.
Parameters
----------
None
Returns
-------
None
"""
if self.surf == "flat": # special case
fname = path.join(self.data_path, "surf", f"{self.hemi}.cortex.patch.flat")
_check_fname(
fname, overwrite="read", must_exist=True, name="flatmap surface file"
)
coords, faces, orig_faces = _read_patch(fname)
# rotate 90 degrees to get to a more standard orientation
# where X determines the distance between the hemis
coords = coords[:, [1, 0, 2]]
coords[:, 1] *= -1
else:
# allow ?h.pial.T1 if ?h.pial doesn't exist for instance
# end with '' for better file not found error
for img in ("", ".T1", ".T2", ""):
surf_fname = path.join(
self.data_path, "surf", f"{self.hemi}.{self.surf}{img}"
)
if path.isfile(surf_fname):
break
coords, faces = read_surface(surf_fname)
orig_faces = faces
if self.units == "m":
coords /= 1000.0
if self.offset is not None:
x_ = coords @ self.x_dir
if self.hemi == "lh":
coords -= (np.max(x_) + self.offset) * self.x_dir
else:
coords -= (np.min(x_) + self.offset) * self.x_dir
surf = dict(rr=coords, tris=faces)
complete_surface_info(surf, copy=False, verbose=False, do_neighbor_tri=False)
nn = surf["nn"]
self.coords = coords
self.faces = faces
self.orig_faces = orig_faces
self.nn = nn
def __len__(self):
"""Return number of vertices."""
return len(self.coords)
@property
def x(self):
return self.coords[:, 0]
@property
def y(self):
return self.coords[:, 1]
@property
def z(self):
return self.coords[:, 2]
def load_curvature(self):
"""Load in curvature values from the ?h.curv file."""
curv_path = path.join(self.data_path, "surf", f"{self.hemi}.curv")
if path.isfile(curv_path):
self.curv = read_curvature(curv_path, binary=False)
self.bin_curv = np.array(self.curv > 0, np.int64)
else:
self.curv = None
self.bin_curv = None

54
mne/viz/_brain/view.py Normal file
View File

@@ -0,0 +1,54 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
ORIGIN = "auto"
DIST = "auto"
_lh_views_dict = {
"lateral": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"medial": dict(azimuth=0.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"rostral": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"caudal": dict(azimuth=270.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"dorsal": dict(azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, distance=DIST),
"ventral": dict(azimuth=180.0, elevation=180.0, focalpoint=ORIGIN, distance=DIST),
"frontal": dict(azimuth=120.0, elevation=80.0, focalpoint=ORIGIN, distance=DIST),
"parietal": dict(azimuth=-120.0, elevation=60.0, focalpoint=ORIGIN, distance=DIST),
"sagittal": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"coronal": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"axial": dict(
azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, roll=0, distance=DIST
), # noqa: E501
}
_rh_views_dict = {
"lateral": dict(azimuth=180.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
"medial": dict(azimuth=0.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
"rostral": dict(azimuth=-90.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
"caudal": dict(azimuth=90.0, elevation=-90.0, focalpoint=ORIGIN, distance=DIST),
"dorsal": dict(azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, distance=DIST),
"ventral": dict(azimuth=180.0, elevation=180.0, focalpoint=ORIGIN, distance=DIST),
"frontal": dict(azimuth=60.0, elevation=80.0, focalpoint=ORIGIN, distance=DIST),
"parietal": dict(azimuth=-60.0, elevation=60.0, focalpoint=ORIGIN, distance=DIST),
"sagittal": dict(azimuth=180.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"coronal": dict(azimuth=90.0, elevation=90.0, focalpoint=ORIGIN, distance=DIST),
"axial": dict(
azimuth=180.0, elevation=0.0, focalpoint=ORIGIN, roll=0, distance=DIST
),
}
# add short-size version entries into the dict
lh_views_dict = _lh_views_dict.copy()
for k, v in _lh_views_dict.items():
lh_views_dict[k[:3]] = v
lh_views_dict["flat"] = dict(
azimuth=0, elevation=0, focalpoint=ORIGIN, roll=0, distance=DIST
)
rh_views_dict = _rh_views_dict.copy()
for k, v in _rh_views_dict.items():
rh_views_dict[k[:3]] = v
rh_views_dict["flat"] = dict(
azimuth=0, elevation=0, focalpoint=ORIGIN, roll=0, distance=DIST
)
views_dicts = dict(
lh=lh_views_dict, vol=lh_views_dict, both=lh_views_dict, rh=rh_views_dict
)