initial commit
This commit is contained in:
252
mne/viz/_proj.py
Normal file
252
mne/viz/_proj.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Functions for plotting projectors."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._fiff.pick import _picks_to_idx
|
||||
from ..defaults import DEFAULTS
|
||||
from ..utils import _pl, _validate_type, verbose, warn
|
||||
from .evoked import _plot_evoked
|
||||
from .topomap import _plot_projs_topomap
|
||||
from .utils import _check_type_projs, plt_show
|
||||
|
||||
|
||||
@verbose
|
||||
def plot_projs_joint(
|
||||
projs, evoked, picks_trace=None, *, topomap_kwargs=None, show=True, verbose=None
|
||||
):
|
||||
"""Plot projectors and evoked jointly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
projs : list of Projection
|
||||
The projectors to plot.
|
||||
evoked : instance of Evoked
|
||||
The data to plot. Typically this is the evoked instance created from
|
||||
averaging the epochs used to create the projection.
|
||||
%(picks_plot_projs_joint_trace)s
|
||||
topomap_kwargs : dict | None
|
||||
Keyword arguments to pass to :func:`mne.viz.plot_projs_topomap`.
|
||||
%(show)s
|
||||
%(verbose)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib Figure
|
||||
The figure.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function creates a figure with three columns:
|
||||
|
||||
1. The left shows the evoked data traces before (black) and after (green)
|
||||
projection.
|
||||
2. The center shows the topomaps associated with each of the projectors.
|
||||
3. The right again shows the data traces (black), but this time with:
|
||||
|
||||
1. The data projected onto each projector with a single normalization
|
||||
factor (solid lines). This is useful for seeing the relative power
|
||||
in each projection vector.
|
||||
2. The data projected onto each projector with individual normalization
|
||||
factors (dashed lines). This is useful for visualizing each time
|
||||
course regardless of its power.
|
||||
3. Additional data traces from ``picks_trace`` (solid yellow lines).
|
||||
This is useful for visualizing the "ground truth" of the time
|
||||
course, e.g. the measured EOG or ECG channel time courses.
|
||||
|
||||
.. versionadded:: 1.1
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..evoked import Evoked
|
||||
|
||||
_validate_type(evoked, Evoked, "evoked")
|
||||
_validate_type(topomap_kwargs, (None, dict), "topomap_kwargs")
|
||||
projs = _check_type_projs(projs)
|
||||
topomap_kwargs = dict() if topomap_kwargs is None else topomap_kwargs
|
||||
if picks_trace is not None:
|
||||
picks_trace = _picks_to_idx(evoked.info, picks_trace, allow_empty=False)
|
||||
info = evoked.info
|
||||
ch_types = evoked.get_channel_types(unique=True, only_data_chs=True)
|
||||
proj_by_type = dict() # will be set up like an enumerate key->[pi, proj]
|
||||
ch_names_by_type = dict()
|
||||
used = np.zeros(len(projs), int)
|
||||
for ch_type in ch_types:
|
||||
these_picks = _picks_to_idx(info, ch_type, allow_empty=True)
|
||||
these_chs = [evoked.ch_names[pick] for pick in these_picks]
|
||||
ch_names_by_type[ch_type] = these_chs
|
||||
for pi, proj in enumerate(projs):
|
||||
if not set(these_chs).intersection(proj["data"]["col_names"]):
|
||||
continue
|
||||
if ch_type not in proj_by_type:
|
||||
proj_by_type[ch_type] = list()
|
||||
proj_by_type[ch_type].append([pi, deepcopy(proj)])
|
||||
used[pi] += 1
|
||||
missing = (~used.astype(bool)).sum()
|
||||
if missing:
|
||||
warn(
|
||||
f"{missing} projector{_pl(missing)} had no channel names "
|
||||
"present in epochs"
|
||||
)
|
||||
del projs
|
||||
ch_types = list(proj_by_type) # reduce to number we actually need
|
||||
# room for legend
|
||||
max_proj_per_type = max(len(x) for x in proj_by_type.values())
|
||||
cs_trace = 3
|
||||
cs_topo = 2
|
||||
n_col = max_proj_per_type * cs_topo + 2 * cs_trace
|
||||
n_row = len(ch_types)
|
||||
shape = (n_row, n_col)
|
||||
fig = plt.figure(
|
||||
figsize=(n_col * 1.1 + 0.5, n_row * 1.8 + 0.5), layout="constrained"
|
||||
)
|
||||
ri = 0
|
||||
# pick some sufficiently distinct colors (6 per proj type, e.g., ECG,
|
||||
# should be enough hopefully!)
|
||||
# https://personal.sron.nl/~pault/data/colourschemes.pdf
|
||||
# "Vibrant" color scheme
|
||||
proj_colors = [
|
||||
"#CC3311", # red
|
||||
"#009988", # teal
|
||||
"#0077BB", # blue
|
||||
"#EE3377", # magenta
|
||||
"#EE7733", # orange
|
||||
"#33BBEE", # cyan
|
||||
]
|
||||
trace_color = "#CCBB44" # yellow
|
||||
after_color, after_name = "#228833", "green"
|
||||
type_titles = DEFAULTS["titles"]
|
||||
last_ax = [None] * 2
|
||||
first_ax = dict()
|
||||
pe_kwargs = dict(show=False, draw=False)
|
||||
for ch_type, these_projs in proj_by_type.items():
|
||||
these_idxs, these_projs = zip(*these_projs)
|
||||
ch_names = ch_names_by_type[ch_type]
|
||||
idx = np.where(
|
||||
[np.isin(ch_names, proj["data"]["col_names"]).all() for proj in these_projs]
|
||||
)[0]
|
||||
used[idx] += 1
|
||||
count = len(these_projs)
|
||||
for proj in these_projs:
|
||||
sub_idx = [proj["data"]["col_names"].index(name) for name in ch_names]
|
||||
proj["data"]["data"] = proj["data"]["data"][:, sub_idx]
|
||||
proj["data"]["col_names"] = ch_names
|
||||
ba_ax = plt.subplot2grid(shape, (ri, 0), colspan=cs_trace, fig=fig)
|
||||
topo_axes = [
|
||||
plt.subplot2grid(
|
||||
shape, (ri, ci * cs_topo + cs_trace), colspan=cs_topo, fig=fig
|
||||
)
|
||||
for ci in range(count)
|
||||
]
|
||||
tr_ax = plt.subplot2grid(
|
||||
shape, (ri, n_col - cs_trace), colspan=cs_trace, fig=fig
|
||||
)
|
||||
# topomaps
|
||||
_plot_projs_topomap(these_projs, info=info, axes=topo_axes, **topomap_kwargs)
|
||||
for idx, proj, ax_ in zip(these_idxs, these_projs, topo_axes):
|
||||
ax_.set_title("") # could use proj['desc'] but it's long
|
||||
ax_.set_xlabel(f"projs[{idx}]", fontsize="small")
|
||||
unit = DEFAULTS["units"][ch_type]
|
||||
# traces
|
||||
this_evoked = evoked.copy().pick(ch_names)
|
||||
p = np.concatenate([p["data"]["data"] for p in these_projs])
|
||||
assert p.shape == (len(these_projs), len(this_evoked.data))
|
||||
traces = np.dot(p, this_evoked.data)
|
||||
traces *= np.sign(np.mean(np.dot(this_evoked.data, traces.T), 0))[:, np.newaxis]
|
||||
if picks_trace is not None:
|
||||
ch_traces = evoked.data[picks_trace]
|
||||
ch_traces -= np.mean(ch_traces, axis=1, keepdims=True)
|
||||
ch_traces /= np.abs(ch_traces).max()
|
||||
_plot_evoked(
|
||||
this_evoked, picks="all", axes=[tr_ax], **pe_kwargs, spatial_colors=False
|
||||
)
|
||||
for line in tr_ax.lines:
|
||||
line.set(lw=0.5, zorder=3)
|
||||
for t in list(tr_ax.texts):
|
||||
t.remove()
|
||||
scale = 0.8 * np.abs(tr_ax.get_ylim()).max()
|
||||
hs, labels = list(), list()
|
||||
traces /= np.abs(traces).max() # uniformly scaled
|
||||
for ti, trace in enumerate(traces):
|
||||
hs.append(
|
||||
tr_ax.plot(
|
||||
this_evoked.times,
|
||||
trace * scale,
|
||||
color=proj_colors[ti % len(proj_colors)],
|
||||
zorder=5,
|
||||
)[0]
|
||||
)
|
||||
labels.append(f"projs[{these_idxs[ti]}]")
|
||||
traces /= np.abs(traces).max(1, keepdims=True) # independently
|
||||
for ti, trace in enumerate(traces):
|
||||
tr_ax.plot(
|
||||
this_evoked.times,
|
||||
trace * scale,
|
||||
color=proj_colors[ti % len(proj_colors)],
|
||||
zorder=3.5,
|
||||
ls="--",
|
||||
lw=1.0,
|
||||
alpha=0.75,
|
||||
)
|
||||
if picks_trace is not None:
|
||||
trace_ch = [evoked.ch_names[pick] for pick in picks_trace]
|
||||
if len(picks_trace) == 1:
|
||||
trace_ch = trace_ch[0]
|
||||
hs.append(
|
||||
tr_ax.plot(
|
||||
this_evoked.times,
|
||||
ch_traces.T * scale,
|
||||
color=trace_color,
|
||||
lw=3,
|
||||
zorder=4,
|
||||
alpha=0.75,
|
||||
)[0]
|
||||
)
|
||||
labels.append(str(trace_ch))
|
||||
tr_ax.set(title="", xlabel="", ylabel="")
|
||||
# This will steal space from the subplots in a constrained layout
|
||||
# https://matplotlib.org/3.5.0/tutorials/intermediate/constrainedlayout_guide.html#legends # noqa: E501
|
||||
tr_ax.legend(
|
||||
hs,
|
||||
labels,
|
||||
loc="center left",
|
||||
borderaxespad=0.05,
|
||||
bbox_to_anchor=[1.05, 0.5],
|
||||
)
|
||||
last_ax[1] = tr_ax
|
||||
key = "Projected time course"
|
||||
if key not in first_ax:
|
||||
first_ax[key] = tr_ax
|
||||
# Before and after traces
|
||||
_plot_evoked(this_evoked, picks="all", axes=[ba_ax], **pe_kwargs)
|
||||
for line in ba_ax.lines:
|
||||
line.set(lw=0.5, zorder=3)
|
||||
loff = len(ba_ax.lines)
|
||||
this_proj_evoked = this_evoked.copy().add_proj(these_projs)
|
||||
# with meg='combined' any existing mag projectors (those already part
|
||||
# of evoked before we add_proj above) will have greatly
|
||||
# reduced power, so we ignore the warning about this issue
|
||||
this_proj_evoked.apply_proj(verbose="error")
|
||||
_plot_evoked(this_proj_evoked, picks="all", axes=[ba_ax], **pe_kwargs)
|
||||
for line in ba_ax.lines[loff:]:
|
||||
line.set(lw=0.5, zorder=4, color=after_color)
|
||||
for t in list(ba_ax.texts):
|
||||
t.remove()
|
||||
ba_ax.set(title="", xlabel="")
|
||||
ba_ax.set(ylabel=f"{type_titles[ch_type]}\n{unit}")
|
||||
last_ax[0] = ba_ax
|
||||
key = f"Before (black) and after ({after_name})"
|
||||
if key not in first_ax:
|
||||
first_ax[key] = ba_ax
|
||||
ri += 1
|
||||
for ax in last_ax:
|
||||
ax.set(xlabel="Time (s)")
|
||||
for title, ax in first_ax.items():
|
||||
ax.set_title(title, fontsize="medium")
|
||||
plt_show(show)
|
||||
return fig
|
||||
Reference in New Issue
Block a user