Files
flares/flares.py
2025-10-20 16:07:18 -07:00

3102 lines
112 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Filename: flares.py
Description: Core functionality for FLARES
Author: Tyler de Zeeuw
License: GPL-3.0
"""
# Built-in imports
import os
import sys
import platform
import threading
import logging
from io import BytesIO
from typing import Any, Optional, cast, Literal, Union
from itertools import compress
from copy import deepcopy
from multiprocessing import Queue
import os.path as op
import re
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
# External library imports
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
from numpy.typing import NDArray
from numpy import float64
import pandas as pd
from pandas import DataFrame
import seaborn as sns
import h5py
from nilearn.plotting import plot_design_matrix # type: ignore
from nilearn.glm.regression import OLSModel
import statsmodels.formula.api as smf # type: ignore
from statsmodels.stats.multitest import multipletests
from scipy import stats
from scipy.spatial.distance import cdist
# Backen visualization needed to be defined for pyinstaller
import pyvistaqt # type: ignore
# import vtkmodules.util.data_model
# import vtkmodules.util.execution_model
# External library imports for mne
from mne import (
EvokedArray, SourceEstimate, Info, Epochs, Label, Annotations,
events_from_annotations, read_source_spaces,
stc_near_sensors, pick_types, grand_average, get_config, set_config, read_labels_from_annot
) # type: ignore
from mne.source_space import SourceSpaces
from mne.transforms import Transform # type: ignore
from mne.io import BaseRaw, read_raw_snirf # type: ignore
from mne.preprocessing.nirs import (
beer_lambert_law, optical_density,
temporal_derivative_distribution_repair,
source_detector_distances, short_channels
) # type: ignore
from mne.viz import Brain, plot_events, plot_evoked_topo, plot_compare_evokeds
from mne.filter import filter_data # type: ignore
from mne.utils import _check_fname, _validate_type, warn
from mne.channels import make_standard_montage
from mne.datasets.sample import data_path
from mne_nirs.visualisation import plot_glm_group_topo # type: ignore
from mne_nirs.channels import get_long_channels, get_short_channels # type: ignore
from mne_nirs.experimental_design import make_first_level_design_matrix # type: ignore
from mne_nirs.statistics import run_glm, statsmodels_to_results # type: ignore
from mne_nirs.signal_enhancement import (
enhance_negative_correlation, short_channel_regression
) # type: ignore
from mne_nirs.io.fold import fold_channel_specificity # type: ignore
from mne_nirs.preprocessing import peak_power # type: ignore
from mne_nirs.statistics._glm_level_first import RegressionResults # type: ignore
os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore
FIXED_CATEGORY_COLORS = {
"SCI only": "skyblue",
"PSP only": "salmon",
"SNR only": "lightgreen",
"PSP + SCI": "orange",
"SCI + SNR": "violet",
"PSP + SNR": "gold",
"SCI + PSP": "orange",
"SNR + SCI": "violet",
"SNR + PSP": "gold",
"PSP + SNR + SCI": "gray",
"SCI + PSP + SNR": "gray",
"SCI + SNR + PSP": "gray",
"PSP + SCI + SNR": "gray",
"PSP + SNR + SCI": "gray",
"SNR + SCI + PSP": "gray",
"SNR + PSP + SCI": "gray",
}
AGE: float
GENDER: str
SECONDS_TO_STRIP: int
DOWNSAMPLE: bool
DOWNSAMPLE_FREQUENCY: int
SCI: bool
SCI_TIME_WINDOW: int
SCI_THRESHOLD: float
SNR: bool
# SNR_TIME_WINDOW : int
SNR_THRESHOLD: float
PSP: bool
PSP_TIME_WINDOW: int
PSP_THRESHOLD: float
TDDR: bool
ENHANCE_NEGATIVE_CORRELATION: bool
SHORT_CHANNEL: bool
REMOVE_EVENTS: list
TIME_WINDOW_START: int
TIME_WINDOW_END: int
VERBOSITY = True
# FIXME: Shouldn't need each ordering - just order it before checking
FIXED_CATEGORY_COLORS = {
"SCI only": "skyblue",
"PSP only": "salmon",
"SNR only": "lightgreen",
"PSP + SCI": "orange",
"SCI + SNR": "violet",
"PSP + SNR": "gold",
"SCI + PSP": "orange",
"SNR + SCI": "violet",
"SNR + PSP": "gold",
"PSP + SNR + SCI": "gray",
"SCI + PSP + SNR": "gray",
"SCI + SNR + PSP": "gray",
"PSP + SCI + SNR": "gray",
"PSP + SNR + SCI": "gray",
"SNR + SCI + PSP": "gray",
"SNR + PSP + SCI": "gray",
}
AGE = 25
GENDER = ""
GROUP = "Default"
REQUIRED_KEYS: dict[str, Any] = {
"SECONDS_TO_STRIP": int,
"DOWNSAMPLE": bool,
"DOWNSAMPLE_FREQUENCY": int,
"SCI": bool,
"SCI_TIME_WINDOW": int,
"SCI_THRESHOLD": float,
"SNR": bool,
# SNR_TIME_WINDOW : int
"SNR_THRESHOLD": float,
"PSP": bool,
"PSP_TIME_WINDOW": int,
"PSP_THRESHOLD": float,
"SHORT_CHANNEL": bool,
"REMOVE_EVENTS": list,
"TIME_WINDOW_START": int,
"TIME_WINDOW_END": int
# "REJECT_PAIRS": bool,
# "FORCE_DROP_ANNOTATIONS": list,
# "FILTER_LOW_PASS": float,
# "FILTER_HIGH_PASS": float,
# "EPOCH_PAIR_TOLERANCE_WINDOW": int,
}
class ProcessingError(Exception):
def __init__(self, message: str = "Something went wrong!"):
self.message = message
super().__init__(self.message)
# Ensure that we are working in the directory of this file
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)
PLATFORM_NAME = platform.system().lower()
# Configure logging to file with timestamps and realtime flush
if PLATFORM_NAME == 'darwin':
logging.basicConfig(
filename=os.path.join(os.path.dirname(sys.executable), "../../../fnirs_analysis.log"),
level=logging.INFO,
format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
filemode='a'
)
else:
logging.basicConfig(
filename='fnirs_analysis.log',
level=logging.INFO,
format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
filemode='a'
)
logger = logging.getLogger()
def set_config_me(config: dict[str, Any]) -> None:
"""
Validates and applies the given configuration dictionary.
Parameters
----------
config : dict[str, Any]
Dictionary containing configuration keys and their values.
"""
logger.info(f"[DEBUG] set_config called")
globals().update(config)
def set_metadata(file_path, metadata: dict[str, Any]) -> None:
"""
Validates and applies the given configuration dictionary.
Parameters
----------
config : dict[str, Any]
Dictionary containing configuration keys and their values.
"""
logger.info(f"[DEBUG] set_metadata called")
globals()['AGE'] = 25
globals()['GENDER'] = ""
globals()['GROUP'] = "Default"
if metadata.get(file_path) is not None:
file_metadata = metadata.get(file_path, {})
for key in ("AGE", "GENDER", "GROUP"):
val = file_metadata.get(key, None)
if val not in (None, '', [], {}, ()): # check for "empty" values
globals()[key] = val
from queue import Empty # This works with multiprocessing.Manager().Queue()
def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -> None:
def forward_progress():
while True:
try:
msg = progress_queue.get(timeout=1)
if msg == "__done__":
break
gui_queue.put(msg)
except Empty:
continue
except Exception as e:
gui_queue.put({
"type": "error",
"error": f"Forwarding thread crashed: {e}",
"traceback": traceback.format_exc()
})
break
t = threading.Thread(target=forward_progress, daemon=True)
t.start()
try:
file_paths = config['SNIRF_FILES']
file_params = config['PARAMS']
file_metadata = config['METADATA']
max_workers = file_params.get("MAX_WORKERS", int(os.cpu_count()/4))
results = process_multiple_participants(
file_paths, file_params, file_metadata, progress_queue, max_workers
)
gui_queue.put({"success": True, "result": results})
except Exception as e:
gui_queue.put({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
})
finally:
# Always send done to the thread and avoid hanging
try:
progress_queue.put("__done__")
except:
pass
t.join(timeout=5) # prevent permanent hang
def process_participant_worker(args):
file_path, file_params, file_metadata, progress_queue = args
set_config_me(file_params)
set_metadata(file_path, file_metadata)
logger.info(f"DEBUG: Metadata for {file_path}: AGE={globals().get('AGE')}, GENDER={globals().get('GENDER')}, GROUP={globals().get('GROUP')}")
def progress_callback(step_idx):
if progress_queue:
progress_queue.put(('progress', file_path, step_idx))
try:
result = process_participant(file_path, progress_callback=progress_callback)
return file_path, result, None
except Exception as e:
error_trace = traceback.format_exc()
return file_path, None, (str(e), error_trace)
def process_multiple_participants(file_paths, file_params, file_metadata, progress_queue=None, max_workers=None):
results_by_file = {}
file_args = [(file_path, file_params, file_metadata, progress_queue) for file_path in file_paths]
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(process_participant_worker, arg): arg[0] for arg in file_args}
for future in as_completed(futures):
file_path = futures[future]
try:
file_path, result, error = future.result()
if error:
error_message, error_traceback = error
if progress_queue:
progress_queue.put({
"type": "error",
"file": file_path,
"error": error_message,
"traceback": error_traceback
})
continue
results_by_file[file_path] = result
except Exception as e:
print(f"Unexpected error processing {file_path}: {e}")
return results_by_file
def markbad(data, ax, ch_names: list[str]) -> None:
"""
Add a strikethrough to a plot for channels marked as bad.
Parameters
----------
data : BaseRaw
The loaded data object to process.
ax : Axes
Matplotlib Axes object where the strikethrough lines will be drawn.
ch_names : list[str]
List of channel names corresponding to the y-axis of the plot.
"""
# Iterate over all the channels
for i, ch in enumerate(ch_names):
# If it is marked as bad, place a strikethrough on the channel
if ch in data.info["bads"]:
ax.axhline(i + 0.5, ls="solid", lw=4, color="black", zorder=10) # type: ignore
def plot_timechannel_quality_metrics(data, scores, times: list[tuple[float]], color_stops: tuple[list[float], list[float]], threshold: float, title: Optional[str] = None):
"""
Generate two heatmaps visualizing channel quality metrics over time.
Parameters
----------
data : BaseRaw
The loaded data object to process.
scores : NDArray[float64]
A 2D array of quality scores for each channel over time.
times : list[tuple[float]]
List of time boundaries used to label each score column.
color_stops : tuple[list[float], list[float]]
Two lists of color values for custom colormaps.
threshold : float,
Threshold value for the color bar.
title : Optional[str], optional
Base title for the figures, (default is None).
Returns
-------
tuple[Figure, Figure]
- Figure: Heatmap of all scores across channels and time.
- Figure: Binary heatmap showing only scores above the threshold.
"""
# Get only the hbo / hbr channels once as we dont need to see the same results twice
half_ch = len(getattr(data, "ch_names")) // 2
ch_names = getattr(data, "ch_names")[:half_ch]
scores = scores[:half_ch, :]
# Extract rounded time points to use as column headers
cols = [np.round(t[0]) for t in times]
n_chans = len(ch_names)
vsize = 0.2 * n_chans
# Create the first figure
fig1, ax1 = plt.subplots(figsize=(10, vsize), layout="constrained") # type: ignore
fig1.suptitle(title + " - All Scores", fontsize=16, fontweight="bold") # type: ignore
# Create a DataFrame to structure data for the heatmap
data_to_plot = DataFrame(
data=scores,
columns=pd.Index(cols, name="Time (s)"),
index=pd.Index(ch_names, name="Channel"),
)
# Define a custom colormap using provided color stops and base colors
base_colors = ['red', 'red', 'yellow', 'green', 'green']
colors = list(zip(color_stops[0], base_colors[:len(color_stops[0])]))
cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors)
# Plot heatmap of scores
sns.heatmap( # type: ignore
data=data_to_plot,
cmap=cmap,
vmin=0,
vmax=1,
cbar_kws=dict(label="Score"),
ax=ax1,
)
# Add vertical dashed lines at each time boundary, sit the title, and place a black strikethrough through a bad channel
for x in range(1, len(times)):
ax1.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") # type: ignore
ax1.set_title("All Scores", fontweight="bold") # type: ignore
markbad(data, ax1, ch_names)
# Calculate average score per channel and annotate to the right of the heatmap
avg_sci_subset: pd.Series[float] = data_to_plot.mean(axis=1) # type: ignore
norm = mcolors.Normalize(vmin=0, vmax=1)
text_x = data_to_plot.shape[1] + 0.5
for i, val in enumerate(avg_sci_subset):
color = cmap(norm(val))
ax1.text( # type: ignore
text_x,
i + 0.5,
f"{val:.3f}",
va='center',
ha='left',
fontsize=9,
color=color
)
ax1.set_xlim(right=text_x + 1.5)
plt.close(fig1)
# Create the second figure
fig2, ax2 = plt.subplots(figsize=(10, vsize), layout="constrained") # type: ignore
fig2.suptitle(title + " - Scores Above Threshold", fontsize=16, fontweight="bold") # type: ignore
# Create a DataFrame to structure data for the heatmap
data_to_plot = DataFrame(
data=scores > threshold,
columns=pd.Index(cols, name="Time (s)"),
index=pd.Index(ch_names, name="Channel"),
)
# Define a custom colormap using provided color stops and base colors
base_colors = ['red', 'red', 'white', 'white']
colors = list(zip(color_stops[1], base_colors[:len(color_stops[1])]))
cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors)
# Plot heatmap of scores
sns.heatmap( # type: ignore
data=data_to_plot,
vmin=0,
vmax=1,
cmap=cmap,
cbar_kws=dict(label="Score"),
ax=ax2,
)
# Add vertical dashed lines at each time boundary, sit the title, and place a black strikethrough through a bad channel
for x in range(1, len(times)):
ax2.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") # type: ignore
ax2.set_title("Scores > Threshold", fontweight="bold") # type: ignore
markbad(data, ax2, ch_names)
plt.close(fig2)
return fig1, fig2
def scalp_coupling_index_windowed_raw(data, time_window: float = 3.0, l_freq: float = 0.7, h_freq: float = 1.5, l_trans_bandwidth: float = 0.3, h_trans_bandwidth: float = 0.3):
"""
Compute windowed scalp coupling index (SCI) across fNIRS channels.
Parameters
----------
data : BaseRaw
The loaded data object to process.
time_window : float, optional
Length of each time window in seconds (default is 3.0).
l_freq : float, optional
Low cutoff frequency for filtering in Hz (default is 0.7).
h_freq : float, optional
High cutoff frequency for filtering in Hz (default is 1.5).
l_trans_bandwidth : float, optional
Transition bandwidth for the low cutoff in Hz (default is 0.3).
h_trans_bandwidth : float, optional
Transition bandwidth for the high cutoff in Hz (default is 0.3).
Returns
-------
tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]]
- BaseRaw: The original data object (unchanged). Ensures compatibility with peak_power().
- NDArray[float64]: Correlation scores for each channel and time window.
- list[tuple[float, float]]: Time intervals for each window in seconds.
"""
# Pick only fNIRS channels and sort them by channel name
picks: NDArray[np.intp] = pick_types(cast(Info, data.info), fnirs=True) # type: ignore
picks = picks[np.argsort([getattr(data, "ch_names")[pick] for pick in picks])]
# FIXME: This may happen if the heart rate calculation tries to set a value way too low
if l_freq < 0.3:
l_freq = 0.3
# Band-pass filter the selected fNIRS channels
filtered_data = filter_data(
getattr(data, "_data"),
getattr(data, "info")["sfreq"],
l_freq,
h_freq,
picks=picks,
verbose=False,
l_trans_bandwidth=l_trans_bandwidth, # type: ignore
h_trans_bandwidth=h_trans_bandwidth, # type: ignore
)
# Calculate number of samples per time window, the total number of windows, and prepare output variables
window_samples = int(np.ceil(time_window * getattr(data, "info")["sfreq"]))
n_windows = int(np.floor(len(data) / window_samples))
scores = np.zeros((len(picks), n_windows))
times: list[tuple[float, float]] = []
# Slide through the data in windows to compute scalp coupling index (SCI)
for window in range(n_windows):
start_sample = int(window * window_samples)
end_sample = start_sample + window_samples
end_sample = np.min([end_sample, len(data) - 1])
# Track time boundaries for each window
t_start = getattr(data, "times")[start_sample]
t_stop = getattr(data, "times")[end_sample]
times.append((t_start, t_stop))
# Iterate through channels in pairs (hbo, hbr). This requires them to be sorted by channel name
for ii in range(0, len(picks), 2):
c1 = filtered_data[picks[ii]][start_sample:end_sample]
c2 = filtered_data[picks[ii + 1]][start_sample:end_sample]
# Ensure the correlation data is valid
if np.std(c1) == 0 or np.std(c2) == 0 or np.any(np.isnan(c1)) or np.any(np.isnan(c2)):
c = 0
else:
c = np.corrcoef(c1, c2)[0][1]
# Assign the computed correlation to both channels in the pair
scores[ii, window] = c
scores[ii + 1, window] = c
scores = scores[np.argsort(picks)]
return data, scores, times
def calculate_scalp_coupling(data, l_freq: float = 0.7, h_freq: float = 1.5):
"""
Calculate the scalp coupling index (SCI) and identify bad channels based on a threshold.
Parameters
----------
data : BaseRaw
The loaded data object to process.
l_freq : float, optional
Low cutoff frequency for bandpass filtering in Hz (default is 0.7).
h_freq : float, optional
High cutoff frequency for bandpass filtering in Hz (default is 1.5)
Returns
-------
tuple[list[str], Figure, Figure]
- list[str]: Channel names identified as bad based on SCI threshold.
- Figure: Heatmap of all SCI scores across time and channels.
- Figure: Binary heatmap of SCI scores exceeding the threshold.
"""
print("Calculating scalp coupling index...")
# Compute the SCI
_, scores, times = scalp_coupling_index_windowed_raw(data, time_window=SCI_TIME_WINDOW, l_freq=l_freq, h_freq=h_freq)
# Identify channels that don't meet the provided threshold
print("Identifying channels that do not meet the threshold...")
sci = scores.mean(axis=1)
data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD))
# Determine the colors based on the threshold, and create the figures
print("Creating the figures...")
color_stops = ([0.0, SCI_THRESHOLD, SCI_THRESHOLD+0.1, 0.8, 1.0], [0.0, SCI_THRESHOLD, SCI_THRESHOLD, 1.0])
fig1, fig2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, SCI_THRESHOLD, "Scalp Coupling Index")
print("Successfully calculated scalp coupling index.")
return list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD)), fig1, fig2
def calculate_signal_noise_ratio(data):
"""
Calculates the signal-to-noise ratio (SNR) for each channel and identifies those below a defined threshold.
Parameters
----------
data : BaseRaw
The loaded data object to process.
Returns
-------
tuple[list[str], Figure]
- list[str]: A list of channel names that fall below the SNR threshold and are considered bad.
- Figure: A matplotlib Figure showing the channels' SNR values.
"""
print("Calculating signal to noise ratio...")
# Compute the signal-to-noise ratio values
print("Computing the signal to noise power...")
signal_band=(0.01, 0.5)
noise_band=(1.0, 10.0)
data_signal = data.copy().filter(*signal_band, verbose=False) #type: ignore
data_noise = data.copy().filter(*noise_band, verbose=False) #type: ignore
signal_power = np.mean(data_signal.get_data()**2, axis=1) #type: ignore
noise_power = np.mean(data_noise.get_data()**2, axis=1) #type: ignore
# Calculate the snr using the standard formula for dB
snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps))
# TODO: Understand what this does
groups: dict[str, list[str]] = {}
for ch in getattr(data, "ch_names"):
# Look for the space in the channel names and remove the characters after
# This is so we can get both oxy and deoxy to remove, as they will have the same source and detector
base = ch.rsplit(' ', 1)[0]
groups.setdefault(base, []).append(ch) # type: ignore
# If any of the channels do not meet our threshold, they will get inserted into the bad_channels set
bad_channels: set[str] = set()
for base, ch_list in groups.items():
if any(s < SNR_THRESHOLD for s, ch in zip(snr, getattr(data, "ch_names")) if ch in ch_list):
bad_channels.update(ch_list)
# Design and create the figure
print("Creating the figure...")
snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained") # type: ignore
colors = [(0/20, 'red'), (SNR_THRESHOLD/20, 'red'), ((SNR_THRESHOLD+.5)/20, 'yellow'), ((SNR_THRESHOLD+1)/20, 'green'), (20/20, 'green')]
cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors)
norm = mcolors.Normalize(vmin=0, vmax=20)
scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm) # type: ignore
ax.set(xlabel="Channel Number", ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, 20])
ax.axhline(SNR_THRESHOLD, color='black', linestyle='--', alpha=0.3, linewidth=1) # type: ignore
cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)") # type: ignore
cbar.set_ticks([0, SNR_THRESHOLD, SNR_THRESHOLD+1, 20]) # type: ignore
cbar.set_ticklabels(['0', str(SNR_THRESHOLD), str(SNR_THRESHOLD+1), '20']) # type: ignore
plt.close()
print("Successfully calculated signal to noise ratio.")
return list(bad_channels), snr_fig
def build_fnirs_adjacency(raw, threshold_meters=0.03):
"""Build an adjacency dictionary for fNIRS channels using 3D distance."""
# Extract channel positions
ch_locs = []
ch_names = []
for ch in raw.info['chs']:
loc = ch['loc'][:3] # Get x, y, z coordinates
if not np.isnan(loc).any():
ch_locs.append(loc)
ch_names.append(ch['ch_name'])
ch_locs = np.array(ch_locs)
# Compute pairwise distances
dists = cdist(ch_locs, ch_locs)
# Build adjacency dictionary
adjacency = {}
for i, ch_name in enumerate(ch_names):
neighbors = [ch_names[j] for j in range(len(ch_names))
if 0 < dists[i, j] < threshold_meters]
adjacency[ch_name] = neighbors
return adjacency
def get_hbo_hbr_picks(raw):
# Pick all fNIRS channels
fnirs_picks = pick_types(raw.info, fnirs=True, exclude=[])
# Extract wavelengths from channel names (expecting something like 'S6_D4 763' or 'S6_D4 841')
wavelengths = []
for idx in fnirs_picks:
ch_name = raw.ch_names[idx]
# Extract last 3 digits from channel name using regex
match = re.search(r'(\d{3})$', ch_name)
if match:
wavelengths.append(int(match.group(1)))
else:
raise ValueError(f"Channel name '{ch_name}' does not end with 3 digits.")
wavelengths = np.array(wavelengths)
unique_wavelengths = np.unique(wavelengths)
if len(unique_wavelengths) != 2:
raise RuntimeError(f"Expected exactly 2 distinct wavelengths, found {unique_wavelengths}")
# Determine which is HbO (larger) and which is HbR (smaller)
hbr_wl = unique_wavelengths.min()
hbo_wl = unique_wavelengths.max()
print(f"HbR wavelength: {hbr_wl}, HbO wavelength: {hbo_wl}")
# Find picks corresponding to each wavelength
hbr_picks = [fnirs_picks[i] for i, wl in enumerate(wavelengths) if wl == hbr_wl]
hbo_picks = [fnirs_picks[i] for i, wl in enumerate(wavelengths) if wl == hbo_wl]
print(f"Found {len(hbr_picks)} HbR channels and {len(hbo_picks)} HbO channels.")
return hbo_picks, hbr_picks, hbo_wl, hbr_wl
def interpolate_fNIRS_bads_weighted_average(raw, bad_channels, max_dist=0.03, min_neighbors=2):
"""
Interpolate bad fNIRS channels using a distance-weighted average of nearby good channels.
Parameters
----------
raw : mne.io.Raw
The raw fNIRS data with bads marked in raw.info['bads'].
max_dist : float
Maximum distance (in meters) to consider for neighboring good channels.
min_neighbors : int
Minimum number of neighbors required to interpolate a bad channel.
Returns
-------
raw : mne.io.Raw
Modified raw object with bads interpolated (in-place).
"""
print("Finding fNIRS channels...")
hbo_picks, hbr_picks, hbo_wl, hbr_wl = get_hbo_hbr_picks(raw)
if len(hbo_picks) != len(hbr_picks):
raise RuntimeError("Number of HbO and HbR channels must be the same.")
# Base names without wavelength for pairing
def base_name(ch_name):
# Strip last 4 chars assuming format ' <wavelength>'
# e.g. "S6_D6 841" -> "S6_D6"
return ch_name[:-4]
hbo_names = [base_name(raw.ch_names[i]) for i in hbo_picks]
hbr_names = [base_name(raw.ch_names[i]) for i in hbr_picks]
# Sanity check: pairs must match
for i in range(len(hbo_names)):
if hbo_names[i] != hbr_names[i]:
raise RuntimeError(f"Channel pairs do not match: {hbo_names[i]} vs {hbr_names[i]}")
# Identify bad pairs if either channel in pair is bad
bad_pairs = []
good_pairs = []
for i, base in enumerate(hbo_names):
hbo_ch = raw.ch_names[hbo_picks[i]]
hbr_ch = raw.ch_names[hbr_picks[i]]
if (hbo_ch in raw.info['bads']) or (hbr_ch in raw.info['bads']):
bad_pairs.append(i)
else:
good_pairs.append(i)
print(f"Total pairs: {len(hbo_names)}")
print(f"Good pairs: {len(good_pairs)}")
print(f"Bad pairs to interpolate: {len(bad_pairs)}")
if len(bad_pairs) == 0:
print("No bad pairs found. Skipping interpolation.")
return raw
# Extract locations (use HbO channel loc as pair location)
locs = np.array([raw.info['chs'][hbo_picks[i]]['loc'][:3] for i in range(len(hbo_names))])
good_locs = locs[good_pairs]
bad_locs = locs[bad_pairs]
# Compute distance matrix between bad and good pairs
dist_matrix = cdist(bad_locs, good_locs)
interpolated_pairs = []
for i, bad_idx in enumerate(bad_pairs):
bad_base = hbo_names[bad_idx]
distances = dist_matrix[i]
close_idxs = np.where(distances < max_dist)[0]
print(f"\nInterpolating pair {bad_base} (index {bad_idx})")
print(f" Nearby good pairs found: {len(close_idxs)}")
if len(close_idxs) < min_neighbors:
print(f" Skipping {bad_base}: not enough neighbors (found {len(close_idxs)} < {min_neighbors})")
continue
weights = 1 / (distances[close_idxs] + 1e-6)
weights /= weights.sum()
neighbor_hbo_indices = [hbo_picks[good_pairs[idx]] for idx in close_idxs]
neighbor_hbr_indices = [hbr_picks[good_pairs[idx]] for idx in close_idxs]
neighbor_hbo_data = raw._data[neighbor_hbo_indices, :]
neighbor_hbr_data = raw._data[neighbor_hbr_indices, :]
interpolated_hbo = np.average(neighbor_hbo_data, axis=0, weights=weights)
interpolated_hbr = np.average(neighbor_hbr_data, axis=0, weights=weights)
raw._data[hbo_picks[bad_idx]] = interpolated_hbo
raw._data[hbr_picks[bad_idx]] = interpolated_hbr
interpolated_pairs.append(bad_base)
if interpolated_pairs:
bad_ch_to_remove = []
for base_ in interpolated_pairs:
bad_ch_to_remove.append(base_ + f" {hbr_wl}") # HbR
bad_ch_to_remove.append(base_ + f" {hbo_wl}") # HbO
raw.info['bads'] = [ch for ch in raw.info['bads'] if ch not in bad_ch_to_remove]
print("\nInterpolation complete.\n")
for ch in raw.info['bads']:
print(f"Channel {ch} still marked as bad.")
print("Bads cleared:", raw.info['bads'])
fig_raw_after = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After interpolation", show=False)
return raw, fig_raw_after
def calculate_signal_noise_ratio(data):
"""
Calculates the signal-to-noise ratio (SNR) for each channel and identifies those below a defined threshold.
Parameters
----------
data : BaseRaw
The loaded data object to process.
Returns
-------
tuple[list[str], Figure]
- list[str]: A list of channel names that fall below the SNR threshold and are considered bad.
- Figure: A matplotlib Figure showing the channels' SNR values.
"""
print("Calculating signal to noise ratio...")
# Compute the signal-to-noise ratio values
print("Computing the signal to noise power...")
signal_band=(0.01, 0.5)
noise_band=(1.0, 10.0)
data_signal = data.copy().filter(*signal_band, verbose=False) #type: ignore
data_noise = data.copy().filter(*noise_band, verbose=False) #type: ignore
signal_power = np.mean(data_signal.get_data()**2, axis=1) #type: ignore
noise_power = np.mean(data_noise.get_data()**2, axis=1) #type: ignore
# Calculate the snr using the standard formula for dB
snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps))
# TODO: Understand what this does
groups: dict[str, list[str]] = {}
for ch in getattr(data, "ch_names"):
# Look for the space in the channel names and remove the characters after
# This is so we can get both oxy and deoxy to remove, as they will have the same source and detector
base = ch.rsplit(' ', 1)[0]
groups.setdefault(base, []).append(ch) # type: ignore
# If any of the channels do not meet our threshold, they will get inserted into the bad_channels set
bad_channels: set[str] = set()
for base, ch_list in groups.items():
if any(s < SNR_THRESHOLD for s, ch in zip(snr, getattr(data, "ch_names")) if ch in ch_list):
bad_channels.update(ch_list)
# Design and create the figure
print("Creating the figure...")
snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained") # type: ignore
colors = [(0/20, 'red'), (SNR_THRESHOLD/20, 'red'), ((SNR_THRESHOLD+.5)/20, 'yellow'), ((SNR_THRESHOLD+1)/20, 'green'), (20/20, 'green')]
cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors)
norm = mcolors.Normalize(vmin=0, vmax=20)
scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm) # type: ignore
ax.set(xlabel="Channel Number", ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, 20])
ax.axhline(SNR_THRESHOLD, color='black', linestyle='--', alpha=0.3, linewidth=1) # type: ignore
cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)") # type: ignore
cbar.set_ticks([0, SNR_THRESHOLD, SNR_THRESHOLD+1, 20]) # type: ignore
cbar.set_ticklabels(['0', str(SNR_THRESHOLD), str(SNR_THRESHOLD+1), '20']) # type: ignore
plt.close()
print("Successfully calculated signal to noise ratio.")
return list(bad_channels), snr_fig
def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]:
"""
Calculate peak spectral power (PSP) for fNIRS channels and identify bad channels.
Parameters
----------
data : BaseRaw
The loaded data object to process.
l_freq : float, optional
Low cutoff frequency for filtering in Hz (default is 0.7)
h_freq : float, optional
High cutoff frequency for filtering in Hz (default is 1.5)
Returns
-------
tuple[list[str], Figure, Figure]
- list[str]: Names of channels below the PSP threshold.
- Figure: Heatmap of all PSP scores.
- Figure: Heatmap of scores above the PSP threshold.
"""
# Compute the PSP
_, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], peak_power(data, time_window=PSP_TIME_WINDOW, threshold=PSP_THRESHOLD, l_freq=l_freq, h_freq=h_freq))
# Identify channels that don't meet the provided threshold
psp = scores.mean(axis=1)
data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD))
# Determine the colors based on the threshold, and create the figures
color_stops = ([0.0, PSP_THRESHOLD, PSP_THRESHOLD+0.1, PSP_THRESHOLD+0.2, 1.0], [0.0, PSP_THRESHOLD, PSP_THRESHOLD, 1.0])
psp1, psp2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, PSP_THRESHOLD, "Peak Spectral Power")
return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2
def mark_bads(raw, bad_sci, bad_snr, bad_psp):
bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp))
print(f"Automatically marked bad channels based on SNR and SCI: {bads_combined}")
raw.info['bads'].extend(bads_combined)
# Organize channels into categories
sets = [
(bad_sci, "SCI"),
(bad_psp, "PSP"),
(bad_snr, "SNR"),
]
# Graph what channels were dropped and why they were dropped
channel_categories: dict[str, str] = {}
for ch in bads_combined:
present_in = [name for s, name in sets if ch in s]
# Create a label for the category
if len(present_in) == 1:
label = f"{present_in[0]} only"
else:
label = " + ".join(sorted(present_in))
channel_categories[ch] = label
# Sort channels alphabetically within categories for nicer visualization
categories = sorted(set(channel_categories.values()))
channel_names: list[str] = []
category_labels: list[str] = []
for cat in categories:
chs_in_cat = sorted([ch for ch, c in channel_categories.items() if c == cat])
channel_names.extend(chs_in_cat)
category_labels.extend([cat] * len(chs_in_cat))
colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories}
# Create the figure
fig_dropped, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3))) # type: ignore
y_pos = range(len(channel_names))
ax.barh(y_pos, [1]*len(channel_names), color=[colors[cat] for cat in category_labels]) # type: ignore
ax.set_yticks(y_pos) # type: ignore
ax.set_yticklabels(channel_names) # type: ignore
ax.set_xlabel("Marked as Bad") # type: ignore
ax.set_title(f"Bad Channels by Method for") # type: ignore
ax.set_xlim(0, 1)
ax.set_xticks([]) # type: ignore
ax.grid(axis='x', linestyle='--', alpha=0.7) # type: ignore
# Add a legend denoting why the channels were bad
for label, color in colors.items():
ax.bar(0, 0, color=color, label=label) # type: ignore
ax.legend() # type: ignore
fig_dropped.tight_layout()
raw_before = deepcopy(raw)
bads_channels = [ch for ch in raw.ch_names if ch in raw.info['bads']]
print(bads_channels)
if bads_channels:
fig_raw_before = raw_before.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], picks=bads_channels, title="What they were BEFORE", show=False)
else:
fig_dropped = None
fig_raw_before = None
return raw, fig_dropped, fig_raw_before, bads_channels
def filter_the_data(raw_haemo):
# --- STEP 5: Filtering (0.010.2 Hz bandpass) ---
fig_filter = raw_haemo.compute_psd(fmax=2).plot(
average=True, xscale="log", color="r", show=False, amplitude=False
)
#raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.4, h_trans_bandwidth=0.2)
raw_haemo = raw_haemo.filter(0.05, 0.7, h_trans_bandwidth=0.2, l_trans_bandwidth=0.02)
raw_haemo.compute_psd(fmax=2).plot(
average=True, xscale="log", axes=fig_filter.axes, color="g", amplitude=False, show=False
)
fig_raw_haemo_filter = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Filtered HbO and HbR", show=False)
return fig_filter, fig_raw_haemo_filter
def epochs_calculations(raw_haemo, events, event_dict):
fig_epochs = [] # List to store figures
# Create epochs from raw data
epochs = Epochs(raw_haemo,
events,
event_id=event_dict,
tmin=-5,
tmax=15,
baseline=(None, 0))
# Make a copy of the epochs and drop bad ones
epochs2 = epochs.copy()
epochs2.drop_bad()
# Plot drop log
# TODO: Why show this if we never use epochs2?
fig_epochs_dropped = epochs2.plot_drop_log(show=False)
fig_epochs.append(("fig_epochs_dropped", fig_epochs_dropped))
# Plot for each condition
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
for idx, condition in enumerate(epochs.event_id.keys()):
# Plot images for each condition
fig_epochs_data = epochs[condition].plot_image(
combine="mean",
vmin=-1,
vmax=1,
ts_args=dict(ylim=dict(hbo=[-1, 1], hbr=[-1, 1])),
show=False
)
for i in fig_epochs_data:
ax = fig.axes[0]
original_title = ax.get_title()
ax.set_title(f"{condition}: {original_title}")
fig_epochs.append((f"fig_{condition}_data_{idx}", i)) # Store with a unique name
# Evoked average figure for each condition
evoked_avg = epochs[condition].average()
clims = dict(hbo=[-1, 1], hbr=[1, -1])
condition_fig = evoked_avg.plot_image(clim=clims, show=False)
for ax in condition_fig.axes:
original_title = ax.get_title()
ax.set_title(f"{original_title} - {condition}")
fig_epochs.append((f"evoked_avg_{condition}", condition_fig)) # Store with a unique name
# Prepare evokeds and colors for topographic plot
evokeds3 = []
colors = []
conditions = list(epochs.event_id.keys())
cmap = plt.get_cmap("tab10", len(conditions))
for idx, cond in enumerate(conditions):
evoked = epochs[cond].average(picks="hbo")
evokeds3.append(evoked)
colors.append(cmap(idx))
# Create the topographic plot
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
help = plot_evoked_topo(evokeds3, color=colors, axes=axes, legend=False, show=False)
# Build custom legend
lines = []
for color in colors:
line = plt.Line2D([0], [0], color=color, lw=2)
lines.append(line)
fig.legend(lines, conditions, loc="lower right")
fig_epochs.append(("evoked_topo", help)) # Store with a unique name
unique_annotations = set(raw_haemo.annotations.description)
for cond in unique_annotations:
# Evoked response for specific condition ("Activity")
evoked_stim1 = epochs[cond].average()
fig_evoked_hbo = evoked_stim1.copy().pick(picks='hbo').plot(time_unit='s', show=False)
fig_evoked_hbr = evoked_stim1.copy().pick(picks='hbr').plot(time_unit='s', show=False)
fig_epochs.append((f"fig_evoked_hbo_{cond}", fig_evoked_hbo)) # Store with a unique name
fig_epochs.append((f"fig_evoked_hbr_{cond}", fig_evoked_hbr)) # Store with a unique name
print("Evoked HbO peak amplitude:", evoked_stim1.copy().pick(picks='hbo').data.max())
evokeds = {}
for condition in epochs2.event_id:
evokeds[condition] = epochs2[condition].average()
print(f"Condition '{condition}': {len(epochs2[condition])} epochs averaged.")
all_evokeds = {}
for condition in epochs.event_id:
if condition not in all_evokeds:
all_evokeds[condition] = []
all_evokeds[condition].append(epochs[condition].average())
group_aucs = {}
# TODO: group averages with a single person?
group_averages = {cond: grand_average(evokeds) for cond, evokeds in all_evokeds.items()}
for condition, evoked in group_averages.items():
group_aucs[condition] = {}
for pick in ["hbo", "hbr"]:
picks_idx = [i for i, ch in enumerate(evoked.ch_names) if pick in ch]
if not picks_idx:
continue
data = evoked.data[picks_idx, :].mean(axis=0)
t_start, t_end = 0, 15
times_mask = (evoked.times >= t_start) & (evoked.times <= t_end)
data_segment = data[times_mask]
times_segment = evoked.times[times_mask]
auc = np.trapezoid(data_segment, times_segment)
group_aucs[condition][pick] = auc
# Final evoked comparison plot for each condition
for condition in conditions:
if condition not in evokeds:
continue
evoked = evokeds[condition]
fig, ax = plt.subplots(figsize=(6, 5))
legend_labels = ["Oxyhaemoglobin"]
for pick, color in zip(["hbo", "hbr"], ["r", "b"]):
plot_compare_evokeds(
evoked,
combine="mean",
picks=pick,
axes=ax,
show=False,
colors=[color],
legend=False,
title=f"Participant: nCondition: {condition}",
ylim=dict(hbo=[-0.5, 1], hbr=[-0.5, 1]),
show_sensors=False,
)
auc_value = group_aucs.get(condition, {}).get(pick, None)
if auc_value is not None:
label = f"{pick.upper()} AUC: {auc_value * 1e6:.4f} µM·s"
else:
label = f"{pick.upper()} AUC: N/A"
legend_labels.append(label)
if len(legend_labels) == 2:
legend_labels.append("Deoxyhaemoglobin")
ax.legend(legend_labels)
fig_epochs.append((f"fig_{condition}_compare_evokeds", fig)) # Store with a unique name
return epochs, fig_epochs
def make_design_matrix(raw_haemo, short_chans):
raw_haemo.resample(1, npad="auto")
raw_haemo._data = raw_haemo._data * 1e6
# 2) Create design matrix
if SHORT_CHANNEL:
short_chans.resample(1)
design_matrix = make_first_level_design_matrix(
raw=raw_haemo,
hrf_model='fir',
stim_dur=0.5,
fir_delays=range(15),
drift_model='cosine',
high_pass=0.01,
oversampling=1,
min_onset=-125,
add_regs=short_chans.get_data().T,
add_reg_names=short_chans.ch_names
)
else:
design_matrix = make_first_level_design_matrix(
raw=raw_haemo,
hrf_model='fir',
stim_dur=0.5,
fir_delays=range(15),
drift_model='cosine',
high_pass=0.01,
oversampling=1,
min_onset=-125,
)
print(design_matrix.head())
print(design_matrix.columns)
fig, ax1 = plt.subplots(figsize=(10, 6), constrained_layout=True)
_ = plot_design_matrix(design_matrix, axes=ax1)
return design_matrix, fig
def generate_montage_locations():
"""Get standard MNI montage locations in dataframe.
Data is returned in the same format as the eeg_positions library.
"""
# standard_1020 and standard_1005 are in MNI (fsaverage) space already,
# but we need to undo the scaling that head_scale will do
montage = make_standard_montage(
"standard_1005", head_size=0.09700884729534559
)
for d in montage.dig:
d["coord_frame"] = 2003
montage.dig[:] = montage.dig[3:]
montage.add_mni_fiducials() # now in fsaverage space
coords = pd.DataFrame.from_dict(montage.get_positions()["ch_pos"]).T
coords["label"] = coords.index
coords = coords.rename(columns={0: "x", 1: "y", 2: "z"})
return coords.reset_index(drop=True)
def _find_closest_standard_location(position, reference, *, out="label"):
"""Return closest montage label to coordinates.
Parameters
----------
position : array, shape (3,)
Coordinates.
reference : dataframe
As generated by _generate_montage_locations.
trans_pos : str
Apply a transformation to positions to specified frame.
Use None for no transformation.
"""
p0 = np.array(position)
p0.shape = (-1, 3)
# head_mri_t, _ = _get_trans("fsaverage", "head", "mri")
# p0 = apply_trans(head_mri_t, p0)
dists = cdist(p0, np.asarray(reference[["x", "y", "z"]], float))
if out == "label":
min_idx = np.argmin(dists)
return reference["label"][min_idx]
else:
assert out == "dists"
return dists
def _source_detector_fold_table(raw, cidx, reference, fold_tbl, interpolate):
src = raw.info["chs"][cidx]["loc"][3:6]
det = raw.info["chs"][cidx]["loc"][6:9]
ref_lab = list(reference["label"])
dists = _find_closest_standard_location([src, det], reference, out="dists")
src_min, det_min = np.argmin(dists, axis=1)
src_name, det_name = ref_lab[src_min], ref_lab[det_min]
tbl = fold_tbl.query("Source == @src_name and Detector == @det_name")
dist = np.linalg.norm(dists[[0, 1], [src_min, det_min]])
# Try reversing source and detector
if len(tbl) == 0:
tbl = fold_tbl.query("Source == @det_name and Detector == @src_name")
if len(tbl) == 0 and interpolate:
# Try something hopefully not too terrible: pick the one with the
# smallest net distance
good = np.isin(fold_tbl["Source"], reference["label"]) & np.isin(
fold_tbl["Detector"], reference["label"]
)
assert good.any()
tbl = fold_tbl[good]
assert len(tbl)
src_idx = [ref_lab.index(src) for src in tbl["Source"]]
det_idx = [ref_lab.index(det) for det in tbl["Detector"]]
# Original
tot_dist = np.linalg.norm([dists[0, src_idx], dists[1, det_idx]], axis=0)
assert tot_dist.shape == (len(tbl),)
idx = np.argmin(tot_dist)
dist_1 = tot_dist[idx]
src_1, det_1 = ref_lab[src_idx[idx]], ref_lab[det_idx[idx]]
# And the reverse
tot_dist = np.linalg.norm([dists[0, det_idx], dists[1, src_idx]], axis=0)
idx = np.argmin(tot_dist)
dist_2 = tot_dist[idx]
src_2, det_2 = ref_lab[det_idx[idx]], ref_lab[src_idx[idx]]
if dist_1 < dist_2:
new_dist, src_use, det_use = dist_1, src_1, det_1
else:
new_dist, src_use, det_use = dist_2, det_2, src_2
tbl = fold_tbl.query("Source == @src_use and Detector == @det_use")
tbl = tbl.copy()
tbl["BestSource"] = src_name
tbl["BestDetector"] = det_name
tbl["BestMatchDistance"] = dist
tbl["MatchDistance"] = new_dist
assert len(tbl)
else:
tbl = tbl.copy()
tbl["BestSource"] = src_name
tbl["BestDetector"] = det_name
tbl["BestMatchDistance"] = dist
tbl["MatchDistance"] = dist
tbl = tbl.copy() # don't get warnings about setting values later
return tbl
def _read_fold_xls(fname, atlas="Juelich"):
"""Read fOLD toolbox xls file.
The values are then manipulated in to a tidy dataframe.
Note the xls files are not included as no license is provided.
Parameters
----------
fname : str
Path to xls file.
atlas : str
Requested atlas.
"""
page_reference = {"AAL2": 2, "AICHA": 5, "Brodmann": 8, "Juelich": 11, "Loni": 14}
tbl = pd.read_excel(fname, sheet_name=page_reference[atlas])
# Remove the spacing between rows
empty_rows = np.where(np.isnan(tbl["Specificity"]))[0]
tbl = tbl.drop(empty_rows).reset_index(drop=True)
# Empty values in the table mean its the same as above
for row_idx in range(1, tbl.shape[0]):
for col_idx, col in enumerate(tbl.columns):
if not isinstance(tbl[col][row_idx], str):
if np.isnan(tbl[col][row_idx]):
tbl.iloc[row_idx, col_idx] = tbl.iloc[row_idx - 1, col_idx]
tbl["Specificity"] = tbl["Specificity"] * 100
tbl["brainSens"] = tbl["brainSens"] * 100
return tbl
def _check_load_fold(fold_files, atlas):
# _validate_type(fold_files, (list, "path-like", None), "fold_files")
if fold_files is None:
fold_files = get_config("MNE_NIRS_FOLD_PATH")
if fold_files is None:
raise ValueError(
"MNE_NIRS_FOLD_PATH not set, either set it using "
"mne.set_config or pass fold_files as str or list"
)
if not isinstance(fold_files, list): # path-like
fold_files = _check_fname(
fold_files,
overwrite="read",
must_exist=True,
name="fold_files",
need_dir=True,
)
fold_files = [op.join(fold_files, f"10-{x}.xls") for x in (5, 10)]
fold_tbl = pd.DataFrame()
for fi, fname in enumerate(fold_files):
fname = _check_fname(
fname, overwrite="read", must_exist=True, name=f"fold_files[{fi}]"
)
fold_tbl = pd.concat(
[fold_tbl, _read_fold_xls(fname, atlas=atlas)], ignore_index=True
)
return fold_tbl
def fold_channel_specificity_normal(raw, fold_files=None, atlas="Juelich", interpolate=False):
"""Return the landmarks and specificity a channel is sensitive to.
Parameters
""" # noqa: E501
_validate_type(raw, BaseRaw, "raw")
reference_locations = generate_montage_locations()
fold_tbl = _check_load_fold(fold_files, atlas)
chan_spec = list()
for cidx in range(len(raw.ch_names)):
tbl = _source_detector_fold_table(
raw, cidx, reference_locations, fold_tbl, interpolate
)
chan_spec.append(tbl.reset_index(drop=True))
return chan_spec
def resource_path(relative_path):
"""
Get absolute path to resource regardless of running directly or packaged using PyInstaller
"""
if hasattr(sys, '_MEIPASS'):
# PyInstaller bundle path
base_path = sys._MEIPASS
else:
base_path = os.path.abspath(".")
return os.path.join(base_path, relative_path)
def fold_channels(raw: BaseRaw) -> None:
# if getattr(sys, 'frozen', False):
path = os.path.expanduser("~") + "/mne_data/fOLD/fOLD-public-master/Supplementary"
logger.info(path)
set_config('MNE_NIRS_FOLD_PATH', resource_path(path)) # type: ignore
# # Locate the fOLD excel files
# else:
# logger.info("yabba")
# set_config('MNE_NIRS_FOLD_PATH', resource_path("../../mne_data/fOLD/fOLD-public-master/Supplementary")) # type: ignore
output = None
# List to store the results
landmark_specificity_data: list[dict[str, Any]] = []
# Filter the data to only what we want
hbo_channel_names = cast(list[str], getattr(raw.copy().pick(picks='hbo'), "ch_names")) # type: ignore
# Format the output to make it slightly easier to read
if True:
num_channels = len(hbo_channel_names)
rows, cols = 4, 7 # 6 rows and 4 columns of pie charts
fig, axes = plt.subplots(rows, cols, figsize=(16, 10), constrained_layout=True)
axes = axes.flatten() # Flatten the axes array for easier indexing
# If more pie charts than subplots, create extra subplots
if num_channels > rows * cols:
fig, axes = plt.subplots((num_channels // cols) + 1, cols, figsize=(16, 10), constrained_layout=True)
axes = axes.flatten()
# Create a list for consistent color mapping
landmarks = [
"1 - Primary Somatosensory Cortex",
"2 - Primary Somatosensory Cortex",
"3 - Primary Somatosensory Cortex",
"4 - Primary Motor Cortex",
"5 - Somatosensory Association Cortex",
"6 - Pre-Motor and Supplementary Motor Cortex",
"7 - Somatosensory Association Cortex",
"8 - Includes Frontal eye fields",
"9 - Dorsolateral prefrontal cortex",
"10 - Frontopolar area",
"11 - Orbitofrontal area",
"17 - Primary Visual Cortex (V1)",
"18 - Visual Association Cortex (V2)",
"19 - V3",
"20 - Inferior Temporal gyrus",
"21 - Middle Temporal gyrus",
"22 - Superior Temporal Gyrus",
"23 - Ventral Posterior cingulate cortex",
"24 - Ventral Anterior cingulate cortex",
"25 - Subgenual cortex",
"32 - Dorsal anterior cingulate cortex",
"37 - Fusiform gyrus",
"38 - Temporopolar area",
"39 - Angular gyrus, part of Wernicke's area",
"40 - Supramarginal gyrus part of Wernicke's area",
"41 - Primary and Auditory Association Cortex",
"42 - Primary and Auditory Association Cortex",
"43 - Subcentral area",
"44 - pars opercularis, part of Broca's area",
"45 - pars triangularis Broca's area",
"46 - Dorsolateral prefrontal cortex",
"47 - Inferior prefrontal gyrus",
"48 - Retrosubicular area",
"Brain_Outside",
]
cmap1 = plt.get_cmap('tab20') # First 20 colors
cmap2 = plt.get_cmap('tab20b') # Next 20 colors
# Combine the colors from both colormaps
colors = [cmap1(i) for i in range(20)] + [cmap2(i) for i in range(20)] # Total 40 colors
landmarks.sort(key=lambda x: (int(x.split(" - ")[0]) if x.split(" - ")[0].isdigit() else float('inf')))
landmark_color_map = {landmark: colors[i % len(colors)] for i, landmark in enumerate(landmarks)}
# Iterate over each channel
for idx, channel_name in enumerate(hbo_channel_names):
# Run the fOLD on the selected channel
channel_data = raw.copy().pick(picks=channel_name) # type: ignore
output = cast(list[DataFrame], fold_channel_specificity_normal(channel_data, interpolate=True, atlas='Brodmann'))
# Process each DataFrame that fold_channel_specificity returns
for df_data in output:
# Extract the relevant columns
useful_data = df_data[['Landmark', 'Specificity']]
# Store the results
landmark_specificity_data.append({
'Channel': channel_name,
'Data': useful_data,
})
# Plot the results
# TODO: Fix this
if True:
unique_landmarks = sorted(useful_data['Landmark'].unique())
color_list = [landmark_color_map[landmark] for landmark in useful_data['Landmark']]
# Plot specificity for each channel
ax = axes[idx]
labels = [f'{landmark.split(" - ")[0]}' if landmark != 'Brain_Outside' else 'B' for landmark in useful_data['Landmark']]
wedges, texts, autotexts = ax.pie(
useful_data['Specificity'],
autopct='%1.1f%%',
startangle=90,
labels=labels,
labeldistance=1.05,
colors=color_list)
ax.set_title(f'{channel_name}')
ax.axis('equal')
landmark_specificity_data = []
# TODO: Fix this
if True:
handles = [
plt.Line2D([0], [0], marker='o', color='w', label=landmark, markersize=10,
markerfacecolor=landmark_color_map[landmark])
for landmark in landmarks
]
n_landmarks = len(landmarks)
# Calculate the figure size based on number of rows and columns
fig_width = 5
fig_height = n_landmarks / 4
# Create a new figure window for the legend
legend_fig = plt.figure(figsize=(fig_width, fig_height))
legend_axes = legend_fig.add_subplot(111)
legend_axes.axis('off') # Turn off axis for the legend window
legend_axes.legend(handles=handles, loc='center', fontsize=10, title="Landmarks")
for ax in axes[len(hbo_channel_names):]:
ax.axis('off')
plt.show()
return fig, legend_fig
def individual_significance(raw_haemo, glm_est):
fig_individual_significances = [] # List to store figures
# TODO: BAD!
cha = glm_est.to_dataframe()
unique_annotations = set(raw_haemo.annotations.description)
for cond in unique_annotations:
ch_summary = cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python')
print(ch_summary.head())
channel_averages = ch_summary.groupby('ch_name')['theta'].mean().reset_index()
print(channel_averages.head())
activity_ch_summary = ch_summary.query(
f"Chroma == 'hbo' and Condition.str.startswith('{cond}_delay_')", engine='python'
)
# Function to correct p-values per channel
def fdr_correct_per_channel(df):
df = df.copy()
df['pval_fdr'] = multipletests(df['p_value'], method='fdr_bh')[1]
return df
# Apply FDR correction grouped by channel
corrected = activity_ch_summary.groupby("ch_name", group_keys=False).apply(fdr_correct_per_channel)
# Determine which channels are significant across any delay
sig_channels = (
corrected.groupby('ch_name')
.apply(lambda df: (df['pval_fdr'] < 0.05).any())
.reset_index(name='significant')
)
# Merge with mean theta (optional for plotting)
mean_theta = activity_ch_summary.groupby('ch_name')['theta'].mean().reset_index()
sig_channels = sig_channels.merge(mean_theta, on='ch_name')
print(sig_channels)
# For example, take the minimum corrected p-value per channel
summary_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index()
print(summary_pvals)
def parse_ch_name(ch_name):
# Extract numbers after S and D in names like 'S10_D5 hbo'
match = re.match(r'S(\d+)_D(\d+)', ch_name)
if match:
return int(match.group(1)), int(match.group(2))
else:
return None, None
min_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index()
# Merge the real p-values into sig_channels / avg_df
avg_df = sig_channels.merge(min_pvals, on='ch_name')
# Rename columns for consistency
avg_df = avg_df.rename(columns={'theta': 't_or_theta', 'pval_fdr': 'p_value'})
# Add Source and Detector columns again
avg_df['Source'], avg_df['Detector'] = zip(*avg_df['ch_name'].map(parse_ch_name))
# Keep relevant columns
avg_df = avg_df[['Source', 'Detector', 't_or_theta', 'p_value']].dropna()
ABS_SIGNIFICANCE_THETA_VALUE = 1
ABS_SIGNIFICANCE_T_VALUE = 1
P_THRESHOLD = 0.05
SOURCE_DETECTOR_SEPARATOR = "_"
t_or_theta = 'theta'
for _, row in avg_df.iterrows(): # type: ignore
print(f"Source {row['Source']} <-> Detector {row['Detector']}: "
f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}")
# Extract the cource and detector positions from raw
src_pos: dict[int, tuple[float, float]] = {}
det_pos: dict[int, tuple[float, float]] = {}
for ch in getattr(raw_haemo, "info")["chs"]:
ch_name = ch['ch_name']
if not ch_name or not ch['loc'].any():
continue
parts = ch_name.split()[0]
src_str, det_str = parts.split(SOURCE_DETECTOR_SEPARATOR)
src_num = int(src_str[1:])
det_num = int(det_str[1:])
src_pos[src_num] = ch['loc'][3:5]
det_pos[det_num] = ch['loc'][6:8]
# Set up the plot
fig, ax = plt.subplots(figsize=(8, 6)) # type: ignore
# Plot the sources
for pos in src_pos.values():
ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3) # type: ignore
# Plot the detectors
for pos in det_pos.values():
ax.scatter(pos[0], pos[1], s=120, c='k', marker='s', edgecolors='white', linewidths=1, zorder=3) # type: ignore
# Ensure that the colors stay within the boundaries even if they are over or under the max/min values
if t_or_theta == 't':
norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_T_VALUE, vmax=ABS_SIGNIFICANCE_T_VALUE)
elif t_or_theta == 'theta':
norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_THETA_VALUE, vmax=ABS_SIGNIFICANCE_THETA_VALUE)
cmap: mcolors.Colormap = plt.get_cmap('seismic')
# Plot connections with avg t-values
for row in avg_df.itertuples():
src: int = cast(int, row.Source) # type: ignore
det: int = cast(int, row.Detector) # type: ignore
tval: float = cast(float, row.t_or_theta) # type: ignore
pval: float = cast(float, row.p_value) # type: ignore
if src in src_pos and det in det_pos:
x = [src_pos[src][0], det_pos[det][0]]
y = [src_pos[src][1], det_pos[det][1]]
style = '-' if pval <= P_THRESHOLD else '--'
ax.plot(x, y, linestyle=style, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2) # type: ignore
# Format the Colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, shrink=0.85) # type: ignore
cbar.set_label(f'Average {cond} {t_or_theta} value (hbo)', fontsize=11) # type: ignore
# Formatting the subplots
ax.set_aspect('equal')
ax.set_title(f"Average {t_or_theta} values for {cond} (HbO)", fontsize=14) # type: ignore
ax.set_xlabel('X position (m)', fontsize=11) # type: ignore
ax.set_ylabel('Y position (m)', fontsize=11) # type: ignore
ax.grid(True, alpha=0.3) # type: ignore
# Set axis limits to be 1cm more than the optode positions
all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()]
all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()]
ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01)
ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01)
fig.tight_layout()
fig_individual_significances.append((f"Condition {cond}", fig))
return fig_individual_significances
# TODO: Hardcoded
def group_significance(
raw_haemo,
all_cha: pd.DataFrame,
condition: str,
correction: str = "fdr_bh"
) -> plt.Figure:
"""
Compute group-level significance using weighted Stouffer's method and plot results.
Args:
raw_haemo: Raw haemoglobin MNE object (used for optode positions)
all_cha: DataFrame with columns including 'ID', 'Condition', 'p_value', 'theta', 'df', 'ch_name', 'Chroma'
condition: condition prefix, e.g., 'Activity'
correction: p-value correction method ('fdr_bh' or 'bonferroni')
Returns:
Matplotlib Figure with group-level theta values and significance.
"""
assert "ID" in all_cha.columns, "'ID' column missing in input data"
assert len(raw_haemo) >= 1, "At least one raw haemoglobin object is required"
condition_prefix = f"{condition}_delay"
# Filter relevant data
ch_summary = all_cha.query(
"Condition.str.startswith(@condition_prefix) and Chroma == 'hbo'",
engine='python'
).copy()
logger.info("=== ch_summary head ===")
logger.info(ch_summary.head())
logger.info("\nSummary stats:")
logger.info(f"Total rows: {len(ch_summary)}")
logger.info(f"Unique subjects: {ch_summary['ID'].nunique() if 'ID' in ch_summary.columns else 'ID column missing'}")
logger.info(f"Unique conditions: {ch_summary['Condition'].unique()}")
logger.info(f"Unique channels (Source-Detector pairs): {ch_summary.groupby(['Source', 'Detector']).ngroups}")
logger.info("\nSample p_values:")
logger.info(ch_summary['p_value'].describe())
if ch_summary.empty:
raise ValueError(f"No data found for condition prefix: {condition_prefix}")
# --- For debugging
logger.info(f"Total rows after filtering for condition '{condition_prefix}': {len(ch_summary)}")
logger.info(f"Unique channels: {ch_summary['ch_name'].nunique()}")
logger.info(f"Participants: {ch_summary['ID'].nunique()}")
# Step 1: Select the peak regressor (~6s after stimulus onset)
peak_regressor = f"{condition}_delay_6"
peak_data = ch_summary[ch_summary["Condition"] == peak_regressor].copy()
logger.info(f"\n=== Logging all values for {peak_regressor} ===")
for row in peak_data.itertuples(index=False):
logger.info(
f"Subject: {row.ID}, "
f"Channel: {row.ch_name}, "
f"Source: {row.Source}, Detector: {row.Detector}, "
f"theta: {row.theta:.4f}, "
f"p_value: {row.p_value:.6f}, "
f"df: {row.df}"
)
if peak_data.empty:
raise ValueError(f"No data found for peak regressor: {peak_regressor}")
# Step 2: Combine per-channel stats across subjects
group_results = []
for (src, det), group in peak_data.groupby(["Source", "Detector"]):
pvals = group["p_value"].values
thetas = group["theta"].values
dfs = group["df"].values
# Weighted Stouffer's method
weights = np.sqrt(dfs)
z_scores = norm.isf(pvals)
combined_z = np.sum(weights * z_scores) / np.sqrt(np.sum(weights**2))
combined_p = norm.sf(combined_z)
theta_avg = np.average(thetas, weights=weights)
group_results.append({
"Source": src,
"Detector": det,
"theta_avg": theta_avg,
"combined_p": combined_p
})
# Step 3: Create combined_df
combined_df = pd.DataFrame(group_results)
# Step 4: Multiple comparisons correction
_, pvals_corr, _, significant = multipletests(
combined_df["combined_p"], alpha=0.05, method=correction
)
combined_df["pval_corr"] = pvals_corr
combined_df["significant"] = significant
logger.info(f"Used peak regressor: {peak_regressor}")
logger.info(f"Channels tested: {len(combined_df)}")
logger.info(f"Significant channels after correction: {combined_df['significant'].sum()}")
# Get optode positions from the first raw file
raw = raw_haemo
src_pos, det_pos = {}, {}
for ch in raw.info["chs"]:
ch_name = ch["ch_name"]
if not ch_name or not ch["loc"].any():
continue
parts = ch_name.split()[0]
src_str, det_str = parts.split("_")
src_num = int(src_str[1:])
det_num = int(det_str[1:])
src_pos[src_num] = ch["loc"][3:5]
det_pos[det_num] = ch["loc"][6:8]
# Plotting parameters
ABS_SIGNIFICANCE_THETA_VALUE = 1
P_THRESHOLD = 0.05
cmap = plt.get_cmap("seismic")
norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_THETA_VALUE, vmax=ABS_SIGNIFICANCE_THETA_VALUE)
fig, ax = plt.subplots(figsize=(8, 6))
# Plot optodes
for pos in src_pos.values():
ax.scatter(*pos, s=120, c="k", marker="o", edgecolors="white", linewidths=1, zorder=3)
for pos in det_pos.values():
ax.scatter(*pos, s=120, c="k", marker="s", edgecolors="white", linewidths=1, zorder=3)
# Plot connections colored by average theta, solid if significant
for row in combined_df.itertuples():
src, det = int(row.Source), int(row.Detector)
tval, pval = row.theta_avg, row.pval_corr
if src in src_pos and det in det_pos:
x = [src_pos[src][0], det_pos[det][0]]
y = [src_pos[src][1], det_pos[det][1]]
linestyle = "-" if pval <= P_THRESHOLD else "--"
ax.plot(x, y, linestyle=linestyle, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2)
# Colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, shrink=0.85)
cbar.set_label(f"Average {condition_prefix.rstrip('_')} θ-value (HbO)", fontsize=11)
# Format axes
ax.set_aspect("equal")
ax.set_title(f"Group-level θ-values for {condition_prefix.rstrip('_')} (HbO)", fontsize=14)
ax.set_xlabel("X position (m)", fontsize=11)
ax.set_ylabel("Y position (m)", fontsize=11)
ax.grid(True, alpha=0.3)
all_x = [p[0] for p in src_pos.values()] + [p[0] for p in det_pos.values()]
all_y = [p[1] for p in src_pos.values()] + [p[1] for p in det_pos.values()]
ax.set_xlim(min(all_x) - 0.01, max(all_x) + 0.01)
ax.set_ylim(min(all_y) - 0.01, max(all_y) + 0.01)
fig.tight_layout()
fig.show()
def plot_glm_results(file_path, raw_haemo, glm_est, design_matrix):
fig_glms = [] # List to store figures
dm = design_matrix.copy()
logger.info(design_matrix.shape)
logger.info(design_matrix.columns)
logger.info(design_matrix.head())
rois = dict(AllChannels=range(len(raw_haemo.ch_names)))
conditions = design_matrix.columns
df_individual = glm_est.to_dataframe_region_of_interest(rois, conditions)
df_individual["ID"] = file_path
# df_individual["theta"] = [t * 1.0e6 for t in df_individual["theta"]]
first_onset_for_cond = {}
for onset, desc in zip(raw_haemo.annotations.onset, raw_haemo.annotations.description):
if desc not in first_onset_for_cond:
first_onset_for_cond[desc] = onset
# Get unique condition names from annotations (descriptions)
unique_annotations = set(raw_haemo.annotations.description)
for cond in unique_annotations:
logger.info(cond)
df_individual_filtered = df_individual.copy()
# Filter for the condition of interest and FIR delays
df_individual_filtered["isCondition"] = [cond in n for n in df_individual_filtered["Condition"]]
df_individual_filtered["isDelay"] = ["delay" in n for n in df_individual_filtered["Condition"]]
df_individual_filtered = df_individual_filtered.query("isDelay and isCondition")
# Remove other conditions from design matrix
dm_condition_cols = [col for col in dm.columns if cond in col]
dm_cond = dm[dm_condition_cols]
# Add a numeric delay column
def extract_delay_number(condition_str):
# Extracts the number at the end of a string like 'Activity_delay_5'
return int(condition_str.split("_")[-1])
df_individual_filtered["DelayNum"] = df_individual_filtered["Condition"].apply(extract_delay_number)
# Now separate and sort using numeric delay
df_hbo = df_individual_filtered[df_individual_filtered["Chroma"] == "hbo"].sort_values("DelayNum")
df_hbr = df_individual_filtered[df_individual_filtered["Chroma"] == "hbr"].sort_values("DelayNum")
vals_hbo = df_hbo["theta"].values
vals_hbr = df_hbr["theta"].values
# Create the plot
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(19, 10))
# Scale design matrix components using numpy arrays instead of pandas operations
dm_cond_values = dm_cond.values
dm_cond_scaled_hbo = dm_cond_values * vals_hbo.reshape(1, -1)
dm_cond_scaled_hbr = dm_cond_values * vals_hbr.reshape(1, -1)
# Create time axis relative to stimulus onset
time = dm_cond.index - np.ceil(first_onset_for_cond.get(cond, 0))
# Plot
axes[0].plot(time, dm_cond_values)
axes[1].plot(time, dm_cond_scaled_hbo)
axes[2].plot(time, np.sum(dm_cond_scaled_hbo, axis=1), 'r')
axes[2].plot(time, np.sum(dm_cond_scaled_hbr, axis=1), 'b')
# Format plots
for ax in range(3):
axes[ax].set_xlim(-5, 25)
axes[ax].set_xlabel("Time (s)")
axes[0].set_ylim(-0.2, 1.2)
axes[1].set_ylim(-0.5, 1)
axes[2].set_ylim(-0.5, 1)
axes[0].set_title(f"FIR Model (Unscaled)")
axes[1].set_title(f"FIR Components (Scaled by {cond} GLM Estimates)")
axes[2].set_title(f"Evoked Response ({cond})")
axes[0].set_ylabel("FIR Model")
axes[1].set_ylabel("Oxyhaemoglobin (ΔμMol)")
axes[2].set_ylabel("Haemoglobin (ΔμMol)")
axes[2].legend(["Oxyhaemoglobin", "Deoxyhaemoglobin"])
print(f"Number of FIR bins: {len(vals_hbo)}")
print(f"Mean theta (HbO): {np.mean(vals_hbo):.4f}")
print(f"Sum of theta (HbO): {np.sum(vals_hbo):.4f}")
print(f"Mean theta (HbR): {np.mean(vals_hbr):.4f}")
print(f"Sum of theta (HbR): {np.sum(vals_hbr):.4f}")
fig_glms.append((f"Condition {cond}", fig))
return fig_glms
def plot_3d_evoked_array(
inst: Union[BaseRaw, EvokedArray, Info],
statsmodel_df: DataFrame,
picks: Optional[Union[str, list[str]]] = "hbo",
value: str = "Coef.",
background: str = "w",
figure: Optional[object] = None,
clim: Union[str, dict[str, Union[str, list[float]]]] = "auto",
mode: str = "weighted",
colormap: str = "RdBu_r",
surface: str = "pial",
hemi: str = "both",
size: int = 800,
view: Optional[Union[str, dict[str, float]]] = None,
colorbar: bool = True,
distance: float = 0.03,
subjects_dir: Optional[str] = None,
src: Optional[SourceSpaces] = None,
verbose: bool = False,
) -> Brain:
'''Ported from MNE'''
info: Info = cast(Info, deepcopy(inst if isinstance(inst, Info) else inst.info)) # type: ignore
if not (getattr(info, "ch_names") == list(statsmodel_df["ch_name"].values)): # type: ignore
raise RuntimeError(
'MNE data structure does not match dataframe '
f'results.\nMNE = {getattr(info, "ch_names")}.\n'
f'GLM = {list(statsmodel_df["ch_name"].values)}' # type: ignore
)
ea = EvokedArray(np.tile(statsmodel_df[value].values.T, (1, 1)).T, info.copy()) # type: ignore
# TODO: mimic behaviour of other MNE-NIRS glm plotting options
if picks is not None:
ea = ea.pick(picks=picks) # type: ignore
if subjects_dir is None:
subjects_dir = os.environ["SUBJECTS_DIR"]
if src is None:
fname_src_fs = os.path.join(
subjects_dir, "fsaverage", "bem", "fsaverage-ico-5-src.fif"
)
src = read_source_spaces(fname_src_fs, verbose=verbose)
picks = getattr(ea, "info")["ch_names"]
# Set coord frame
for idx in range(len(getattr(ea, "ch_names"))):
getattr(ea, "info")["chs"][idx]["coord_frame"] = 4
# Generate source estimate
kwargs = dict(
evoked=ea,
subject="fsaverage",
trans=Transform('head', 'mri', np.eye(4)),
distance=distance,
mode=mode,
surface=surface,
subjects_dir=subjects_dir,
src=src,
project=True,
)
stc = stc_near_sensors(picks=picks, **kwargs, verbose=verbose) # type: ignore
assert isinstance(stc, SourceEstimate)
# Produce brain plot
brain: Brain = stc.plot( # type: ignore
src=src,
subjects_dir=subjects_dir,
hemi=hemi,
surface=surface,
initial_time=0,
clim=clim, # type: ignore
size=size,
colormap=colormap,
figure=figure,
background=background,
colorbar=colorbar,
verbose=verbose,
)
if view is not None:
brain.show_view(view) # type: ignore
return brain
def brain_3d_visualization(raw_haemo, df_cha, selected_event, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True, brain_bounds: float = 1.0) -> None:
clim = dict(kind="value", pos_lims=(0, brain_bounds/2, brain_bounds))
# Get all activity conditions
for cond in [f'{selected_event}']:
if True:
ch_summary = df_cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python') # type: ignore
# Use ordinary least squares (OLS) if only one participant
# TODO: Fix.
if True:
# t values
if t_or_theta == 't':
ch_model = smf.ols("t ~ -1 + ch_name", ch_summary).fit() # type: ignore
# theta values
elif t_or_theta == 'theta':
ch_model = smf.ols("theta ~ -1 + ch_name", ch_summary).fit() # type: ignore
print("OLS model is being used as there is only one participant!")
# Convert model results
model_df = cast(DataFrame, statsmodels_to_results(ch_model, order=ch_summary["ch_name"].unique())) # type: ignore
valid_channels = ch_summary["ch_name"].unique().tolist() # type: ignore
raw_for_plot = raw_haemo.copy().pick(picks=valid_channels) # type: ignore
brain = plot_3d_evoked_array(raw_for_plot.pick(picks="hbo"), model_df, view="dorsal", distance=0.02, colorbar=True, clim=clim, mode="weighted", size=(800, 700)) # type: ignore
if show_optodes == 'all' or show_optodes == 'sensors':
brain.add_sensors(getattr(raw_for_plot, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore
if True:
display_text = ('Folder: ' + '\nGroup: ' + '\nCondition: '+ cond + '\nShort Channel Regression: '
+ '\nLooking at: ' + t_or_theta + ' values') + '\nBrain Distance: '
# Apply the text onto the brain
if show_text:
brain.add_text(0.12, 0.64, display_text, "title", font_size=11, color="k") # type: ignore
return brain
def brain_landmarks_3d(raw_haemo: BaseRaw, show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_brodmann: bool = True) -> None:
brain = Brain("fsaverage", background="white", size=(800, 700)) # type: ignore
distances = source_detector_distances(raw_haemo.info)
# Add optode text labels manually
if show_optodes == 'all' or show_optodes == 'sensors':
brain.add_sensors(getattr(raw_haemo, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore
if show_optodes == 'all' or show_optodes == 'labels':
labeled_srcs = set()
labeled_dets = set()
label_counts = {}
for idx, ch in enumerate(raw_haemo.info['chs']):
ch_name = ch['ch_name']
if not ch_name.endswith('hbo'):
continue
loc = ch['loc']
logger.info(f"Channel: {ch_name}")
logger.info(f"loc length: {len(loc)}")
logger.info("loc contents:")
for i, val in enumerate(loc):
logger.info(f" loc[{i}]: {val}")
logger.info("-" * 30)
if not ch_name or not ch['loc'].any():
continue
parts = ch_name.split()[0]
src_str, det_str = parts.split('_')
src_num = int(src_str[1:])
det_num = int(det_str[1:])
if src_num not in labeled_srcs:
src_xyz = ch['loc'][3:6] * 1000
brain._renderer.text3d(src_xyz[0], src_xyz[1], src_xyz[2], src_str,
color='red', scale=0.002)
labeled_srcs.add(src_num)
if det_num not in labeled_dets:
det_xyz = ch['loc'][6:9] * 1000
brain._renderer.text3d(det_xyz[0], det_xyz[1], det_xyz[2], det_str,
color='blue', scale=0.002)
labeled_dets.add(det_num)
# Get the source-detector distance for this channel (in meters)
dist_m = distances[idx]
dist_mm = dist_m * 1000
label_text = f"{dist_mm:.1f} mm"
label_counts[label_text] = label_counts.get(label_text, 0) + 1
if label_counts[label_text] > 1:
label_text += f" ({label_counts[label_text]})"
# Label at channel midpoint
mid_xyz = loc[0:3] * 1000
logger.info(f"Channel: {ch_name} | Midpoint (mm): x={mid_xyz[0]:.2f}, y={mid_xyz[1]:.2f}, z={mid_xyz[2]:.2f} | Distance: {dist_mm:.1f} mm")
brain._renderer.text3d(
mid_xyz[0], mid_xyz[1], mid_xyz[2],
label_text,
color='gray',
scale=0.002
)
if show_brodmann:# Add Brodmann labels
labels = cast(list[Label], read_labels_from_annot("fsaverage", "PALS_B12_Brodmann", "rh", verbose=False)) # type: ignore
label_colors = {
"Brodmann.39-rh": "blue",
"Brodmann.40-rh": "green",
"Brodmann.6-rh": "pink",
"Brodmann.7-rh": "orange",
"Brodmann.17-rh": "red",
"Brodmann.1-rh": "yellow",
"Brodmann.2-rh": "yellow",
"Brodmann.3-rh": "yellow",
"Brodmann.18-rh": "red",
"Brodmann.19-rh": "red",
"Brodmann.4-rh": "purple",
"Brodmann.8-rh": "white"
}
for label in labels:
name = getattr(label, "name", None)
if not isinstance(name, str):
continue
if name in label_colors:
brain.add_label(label, borders=False, color=label_colors[name]) # type: ignore
return brain
def verify_channel_positions(data: BaseRaw) -> None:
"""
Visualizes the sensor/channel positions of the raw data for verification.
Parameters
----------
data : BaseRaw
The loaded data object to process.
"""
def convert_fig_dict_to_png_bytes(fig_dict: dict[str, Figure]) -> dict[str, bytes]:
png_dict = {}
for label, fig in fig_dict.items():
buf = BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight")
buf.seek(0)
png_dict[label] = buf.read()
plt.close(fig)
return png_dict
def brain_3d_contrast(con_model_df: DataFrame, con_model_df_filtered: BaseRaw, common_channels: list[str], first_name: str, second_name: str, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True, brain_bounds: float = 1.0) -> None:
# Filter DataFrame to only common channels, and sort by raw order
con_model = con_model_df
con_model["ch_name"] = pd.Categorical(
con_model["ch_name"], categories=common_channels, ordered=True
)
con_model = con_model.sort_values("ch_name").reset_index(drop=True) # type: ignore
clim=dict(kind="value", pos_lims=(0, brain_bounds/2, brain_bounds))
# Plot brain figure
brain = plot_3d_evoked_array(con_model_df_filtered.copy().pick(picks="hbo"), con_model, view="dorsal", distance=0.02, colorbar=True, mode="weighted", clim=clim, size=(800, 700), verbose=False) # type: ignore
if show_optodes == 'all' or show_optodes == 'sensors':
brain.add_sensors(getattr(con_model_df_filtered, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore
display_text = ('Contrast: ' + first_name + ' - ' + second_name + '\nLooking at: ' + t_or_theta + ' values')
# Apply the text onto the brain
if show_text:
brain.add_text(0.12, 0.70, display_text, "title", font_size=11, color="k") # type: ignore
def plot_2d_3d_contrasts_between_groups(
contrast_df_a: pd.DataFrame,
contrast_df_b: pd.DataFrame,
raw_haemo: BaseRaw,
group_a_name: str,
group_b_name: str,
is_3d: bool = True,
t_or_theta: Literal['t', 'theta'] = 'theta',
show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all',
show_text: bool = True,
brain_bounds: float = 1.0,
) -> None:
logger.info("-----")
contrast_df_a = contrast_df_a.copy()
contrast_df_a["group"] = group_a_name
contrast_df_b = contrast_df_b.copy()
contrast_df_b["group"] = group_b_name
logger.info("-----")
df_combined = pd.concat([contrast_df_a, contrast_df_b], ignore_index=True)
con_summary = df_combined.query("Chroma == 'hbo'").copy()
logger.info("-----")
valid_channels = (pd.crosstab(con_summary["group"], con_summary["ch_name"]) > 1).all()
valid_channels = valid_channels[valid_channels].index.tolist()
con_summary = con_summary[con_summary["ch_name"].isin(valid_channels)]
logger.info("-----")
model_formula = "effect ~ -1 + group:ch_name:Chroma"
con_model = smf.mixedlm(model_formula, con_summary, groups=con_summary["ID"]).fit(method="nm")
logger.info("-----")
if t_or_theta == "t":
group1_vals = con_model.tvalues.filter(like=f"group[{group_a_name}]")
group2_vals = con_model.tvalues.filter(like=f"group[{group_b_name}]")
else:
group1_vals = con_model.params.filter(like=f"group[{group_a_name}]")
group2_vals = con_model.params.filter(like=f"group[{group_b_name}]")
logger.info("-----")
group1_channels = [name.split(":")[1].split("[")[1].split("]")[0] for name in group1_vals.index]
group2_channels = [name.split(":")[1].split("[")[1].split("]")[0] for name in group2_vals.index]
df_group1 = DataFrame({"Coef.": group1_vals.values}, index=group1_channels)
df_group2 = DataFrame({"Coef.": group2_vals.values}, index=group2_channels)
df_contrast = df_group1.join(df_group2, how="inner", lsuffix=f"_{group_a_name}", rsuffix=f"_{group_b_name}")
logger.info("-----")
# A - B
df_contrast["Coef."] = df_contrast[f"Coef._{group_a_name}"] - df_contrast[f"Coef._{group_b_name}"]
con_model_df_1_2 = DataFrame({
"ch_name": df_contrast.index,
"Coef.": df_contrast["Coef."],
"Chroma": "hbo"
})
logger.info("-----")
mne_ch_names = raw_haemo.copy().pick(picks="hbo").ch_names
glm_ch_names = con_model_df_1_2["ch_name"].tolist()
common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names]
con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels)
con_model_df_1_2 = con_model_df_1_2.set_index("ch_name").loc[common_channels].reset_index()
logger.info("-----")
if is_3d:
brain_3d_contrast(
con_model_df_1_2,
con_model_df_filtered,
common_channels,
group_a_name,
group_b_name,
t_or_theta,
show_optodes,
show_text,
brain_bounds
)
else:
plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_1_2, names=True, res=128, vlim=(-brain_bounds, brain_bounds)) # type: ignore
# TODO: The title currently goes on the colorbar. Low priority
plt.title(f"Contrast: {group_a_name} vs {group_b_name}") # type: ignore
plt.show() # type: ignore
# plt.title(f"Contrast: {group_a_name} vs {group_b_name}")
# plt.show()
# B - A
df_contrast["Coef."] = df_contrast[f"Coef._{group_b_name}"] - df_contrast[f"Coef._{group_a_name}"]
con_model_df_2_1 = DataFrame({
"ch_name": df_contrast.index,
"Coef.": df_contrast["Coef."],
"Chroma": "hbo"
})
glm_ch_names = con_model_df_2_1["ch_name"].tolist()
common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names]
con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels)
con_model_df_2_1 = con_model_df_2_1.set_index("ch_name").loc[common_channels].reset_index()
if is_3d:
brain_3d_contrast(
con_model_df_2_1,
con_model_df_filtered,
common_channels,
group_b_name,
group_a_name,
t_or_theta,
show_optodes,
show_text,
brain_bounds
)
else:
plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_2_1, names=True, res=128, vlim=(-brain_bounds, brain_bounds)) # type: ignore
# TODO: The title currently goes on the colorbar. Low priority
plt.title(f"Contrast: {group_b_name} vs {group_a_name}") # type: ignore
plt.show() # type: ignore
def plot_fir_model_results(df, raw_haemo, dm, selected_event, l_bound, u_bound):
df["isActivity"] = [f"{selected_event}" in n for n in df["Condition"]]
df["isDelay"] = ["delay" in n for n in df["Condition"]]
df = df.query("isDelay in [True]")
df = df.query("isActivity in [True]")
# Make a new column that stores the condition name for tidier model below
df.loc[:, "TidyCond"] = ""
df.loc[df["isActivity"] == True, "TidyCond"] = f"{selected_event}" # noqa: E712
# Finally, extract the FIR delay in to its own column in data frame
df.loc[:, "delay"] = [n.split("_")[-1] for n in df.Condition]
# To simplify this example we will only look at the activity
# condition so we now remove the other conditions from the
# design matrix and GLM results
dm_cols_activity = np.where([f"{selected_event}" in c for c in dm.columns])[0]
dm = dm[[dm.columns[i] for i in dm_cols_activity]]
lme = smf.mixedlm("theta ~ -1 + delay:TidyCond:Chroma", df, groups=df["ID"]).fit()
df_sum = statsmodels_to_results(lme)
df_sum["delay"] = [int(n) for n in df_sum["delay"]]
df_sum = df_sum.sort_values("delay")
# Print the result for the oxyhaemoglobin data in the target condition
df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbo']")
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(19, 10))
print("dm columns:", dm.columns.tolist())
# Extract design matrix columns that correspond to the condition of interest
dm_cond_idxs = np.where([f"{selected_event}" in n for n in dm.columns])[0]
dm_cond = dm[[dm.columns[i] for i in dm_cond_idxs]]
# Extract the corresponding estimates from the lme dataframe for hbo
df_hbo = df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbo']")
vals_hbo = [float(v) for v in df_hbo["Coef."]]
# print("--------------------------------------")
# print(f"dm_cond shape: {dm_cond.shape}")
# print(f"dm_cond columns: {dm_cond.columns.tolist()}")
# print(f"vals_hbo length: {len(vals_hbo)}")
# print(f"vals_hbo sample: {vals_hbo[:5]}")
# print(f"vals_hbo type: {type(vals_hbo)}")
# print(f"vals_hbo element type: {type(vals_hbo[0]) if len(vals_hbo) > 0 else 'N/A'}")
dm_cond_scaled_hbo = dm_cond * vals_hbo
# Extract the corresponding estimates from the lme dataframe for hbr
df_hbr = df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbr']")
vals_hbr = [float(v) for v in df_hbr["Coef."]]
dm_cond_scaled_hbr = dm_cond * vals_hbr
first_onset = None
for desc, onset in zip(raw_haemo.annotations.description, raw_haemo.annotations.onset):
if selected_event in desc:
first_onset = onset
break
if first_onset is None:
raise ValueError(f"Selected event '{selected_event}' not found in annotations.")
# Align index values (time axis) to the first occurrence of selected_event
index_values = dm_cond_scaled_hbo.index - np.ceil(first_onset)
index_values = np.asarray(index_values)
# Plot the result
axes[0].plot(index_values, np.asarray(dm_cond))
axes[1].plot(index_values, np.asarray(dm_cond_scaled_hbo))
axes[2].plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r")
axes[2].plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b")
valid_mask = (index_values >= 0) & (index_values <= 15)
hbo_sum_window = np.sum(dm_cond_scaled_hbo.loc[valid_mask, :], axis=1)
peak_idx_in_window = np.argmax(hbo_sum_window)
peak_idx = np.where(valid_mask)[0][peak_idx_in_window]
peak_time = float(round(index_values[peak_idx], 2)) # type: ignore
axes[2].axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time') # type: ignore
# Format the plot
for ax in range(3):
axes[ax].set_xlim(-5, 25)
axes[ax].set_xlabel("Time (s)")
axes[0].set_ylim(-0.1, 1.1)
axes[1].set_ylim(l_bound, u_bound)
axes[2].set_ylim(l_bound, u_bound)
axes[0].set_title("FIR Model (Unscaled by GLM estimates)")
axes[1].set_title(f"FIR Components (Scaled by {selected_event} GLM Estimates)")
axes[2].set_title(f"Evoked Response {selected_event}")
axes[0].set_ylabel("FIR Model")
axes[1].set_ylabel("Oyxhaemoglobin (ΔμMol)")
axes[2].set_ylabel("Haemoglobin (ΔμMol)")
axes[2].legend(["Oyxhaemoglobin", "Deoyxhaemoglobin"])
# We can also extract the 95% confidence intervals of the estimates too
l95_hbo = [float(v) for v in df_hbo["[0.025"]] # type: ignore
u95_hbo = [float(v) for v in df_hbo["0.975]"]] # type: ignore
dm_cond_scaled_hbo_l95 = dm_cond * l95_hbo
dm_cond_scaled_hbo_u95 = dm_cond * u95_hbo
l95_hbr = [float(v) for v in df_hbr["[0.025"]] # type: ignore
u95_hbr = [float(v) for v in df_hbr["0.975]"]] # type: ignore
dm_cond_scaled_hbr_l95 = dm_cond * l95_hbr
dm_cond_scaled_hbr_u95 = dm_cond * u95_hbr
axes2: Axes
fig2, axes2 = plt.subplots(nrows=1, ncols=1, figsize=(7, 7)) # type: ignore
# Plot the result
axes2.plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r") # type: ignore
axes2.plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b") # type: ignore
axes2.axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time') # type: ignore
axes2.fill_between( # type: ignore
index_values,
np.asarray(np.sum(dm_cond_scaled_hbo_l95, axis=1)),
np.asarray(np.sum(dm_cond_scaled_hbo_u95, axis=1)),
facecolor="red",
alpha=0.25,
)
axes2.fill_between( # type: ignore
index_values,
np.asarray(np.sum(dm_cond_scaled_hbr_l95, axis=1)),
np.asarray(np.sum(dm_cond_scaled_hbr_u95, axis=1)),
facecolor="blue",
alpha=0.25,
)
# Format the plot
axes2.set_xlim(-5, 20)
axes2.set_ylim(l_bound, u_bound)
axes2.set_title(f"Evoked Response with 95% confidence intervals for )") # type: ignore
axes2.set_ylabel("Haemoglobin (ΔμMol)") # type: ignore
axes2.legend(["Oyxhaemoglobin", "Deoyxhaemoglobin", f"Peak {peak_time}s"]) # type: ignore
axes2.set_xlabel("Time (s)") # type: ignore
fig2.tight_layout()
fig.show()
fig2.show()
def load_snirf(file_path: str) -> tuple[BaseRaw, Figure]:
"""
Loads a snirf file, optionally drops channels, downsamples, and creates a figure showing the results.
Parameters
----------
file_path : str
Path of the snirf file to load.
ID : str
File name of the the snirf file that was loaded.
drop_prefixes : list[str]
List of channel name prefixes to drop from the data.
Returns
-------
tuple[BaseRaw, Figure]
- BaseRaw: The processed data object.
- Figure: The corresponding Matplotlib figure.
"""
# Read the snirf file
raw = read_raw_snirf(file_path, preload=True, verbose=VERBOSITY) # type: ignore
raw.load_data(verbose=VERBOSITY) # type: ignore
# Strip the specified amount of seconds from the start of the file
total_duration = getattr(raw, "times")[-1]
if total_duration > SECONDS_TO_STRIP:
raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore
logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.")
else:
logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.")
# If the user forcibly dropped channels, remove them now before any processing occurs
# logger.info("Checking if there are channels to forcibly drop...")
# if drop_prefixes:
# logger.info("Force dropped channels was specified.")
# channels_to_drop = [ch for ch in cast(list[str], getattr(raw, "ch_names")) if any(ch.startswith(prefix) for prefix in drop_prefixes)]
# raw.drop_channels(channels_to_drop, "raise") # type: ignore
# logger.info("Force dropped channels:", channels_to_drop)
# If the user wants to downsample, do it right away
logger.info("Checking if we should downsample...")
if DOWNSAMPLE:
logger.info("Downsample was specified.")
sfreq_old = getattr(raw, "info")["sfreq"]
raw.resample(DOWNSAMPLE_FREQUENCY, verbose=VERBOSITY) # type: ignore
sfreq_new = getattr(raw, "info")["sfreq"]
logger.info(f"Finished downsampling. Old frequency: {sfreq_old}. New frequency: {sfreq_new}.")
logger.info("Successfully loaded the snirf file.")
return raw
def run_second_level_analysis(df_contrasts, raw, p, bounds):
"""
Perform second-level analysis using contrast data from multiple participants.
Parameters
----------
df_contrasts : pd.DataFrame
Combined contrast results from multiple participants.
Must include: ['ch_name', 'effect', 'ID']
Returns
-------
pd.DataFrame
Group-level t-values, p-values, and mean effect per channel.
"""
if not all(col in df_contrasts.columns for col in ['ch_name', 'effect', 'ID']):
raise ValueError("Input DataFrame must include 'ch_name', 'effect', and 'ID' columns.")
channels = df_contrasts['ch_name'].unique()
group_results = []
for ch in channels:
ch_data = df_contrasts[df_contrasts['ch_name'] == ch]
if ch_data['ID'].nunique() < 2:
logger.warning(f"Skipping channel {ch} — not enough subjects.")
continue
Y = ch_data['effect'].values
design_matrix = np.ones((len(Y), 1)) # intercept-only
model = OLSModel(design_matrix)
result = model.fit(Y)
t_val = result.t(0).item()
p_val = 2 * stats.t.sf(np.abs(t_val), df=result.df_model)
mean_beta = np.mean(Y)
group_results.append({
'ch_name': ch,
't_val': t_val,
'p_val': p_val,
'mean_beta': mean_beta,
'n_subjects': len(Y)
})
df_group = pd.DataFrame(group_results)
logger.info("Second-level results:\n%s", df_group)
# Extract the cource and detector positions from raw
src_pos: dict[int, tuple[float, float]] = {}
det_pos: dict[int, tuple[float, float]] = {}
for ch in getattr(raw, "info")["chs"]:
ch_name = ch['ch_name']
if not ch_name or not ch['loc'].any():
continue
parts = ch_name.split()[0]
src_str, det_str = parts.split('_')
src_num = int(src_str[1:])
det_num = int(det_str[1:])
src_pos[src_num] = ch['loc'][3:5]
det_pos[det_num] = ch['loc'][6:8]
# Set up the plot
fig, ax = plt.subplots(figsize=(8, 6)) # type: ignore
# Plot the sources
for pos in src_pos.values():
ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3) # type: ignore
# Plot the detectors
for pos in det_pos.values():
ax.scatter(pos[0], pos[1], s=120, c='k', marker='s', edgecolors='white', linewidths=1, zorder=3) # type: ignore
# Ensure that the colors stay within the boundaries even if they are over or under the max/min values
norm = mcolors.Normalize(vmin=-bounds, vmax=bounds)
cmap: mcolors.Colormap = plt.get_cmap('seismic')
# Plot connections with avg t-values
for _, row in df_group.iterrows():
ch = row['ch_name']
pval = row['p_val']
tval = row['t_val']
if '_' not in ch:
logger.info(f"Skipping channel with unexpected format (no underscore): {ch}")
continue
src_str, det_str = ch.split('_')
det_parts = det_str.split()
detector_id = det_parts[0] # e.g. "D1"
hemo_type = det_parts[1].lower() if len(det_parts) > 1 else ''
logger.info(f"Parsing channel: {ch} -> src_str: {src_str}, det_str: {detector_id}, hemo_type: {hemo_type}")
if hemo_type != 'hbo':
logger.info(f"Skipping channel {ch} because hemo_type is not HbO: {hemo_type}")
continue
try:
src = int(src_str[1:])
det = int(detector_id[1:])
logger.info(f"Parsed src: {src}, det: {det}")
except Exception as e:
logger.info(f"Error parsing source/detector from channel '{ch}': {e}")
continue
if src in src_pos and det in det_pos:
x = [src_pos[src][0], det_pos[det][0]]
y = [src_pos[src][1], det_pos[det][1]]
style = '-' if pval <= p else '--'
color = cmap(norm(tval))
logger.info(f"Plotting {ch}: t={tval:.2f}, p={pval:.3f}, color={color}, style={style}")
ax.plot(x, y, linestyle=style, color=color, linewidth=4, alpha=0.9, zorder=2)
# Format the Colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, shrink=0.85) # type: ignore
cbar.set_label(f'Average value (hbo)', fontsize=11) # type: ignore
# Formatting the subplots
ax.set_aspect('equal')
ax.set_title(f"Average values (HbO)", fontsize=14) # type: ignore
ax.set_xlabel('X position (m)', fontsize=11) # type: ignore
ax.set_ylabel('Y position (m)', fontsize=11) # type: ignore
ax.grid(True, alpha=0.3) # type: ignore
# Set axis limits to be 1cm more than the optode positions
all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()]
all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()]
ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01)
ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01)
fig.tight_layout()
plt.show() # type: ignore
return df_group
def calculate_dpf(file_path):
# order is hbo / hbr
with h5py.File(file_path, 'r') as f:
wavelengths = f['/nirs/probe/wavelengths'][:]
logger.info(f"Wavelengths (nm): {wavelengths}")
wavelengths = sorted(wavelengths, reverse=True)
age = float(AGE)
logger.info(f"Their age was {AGE}")
a = 223.3
b = 0.05624
c = 0.8493
d = -5.723e-7
e = 0.001245
f = -0.9025
dpf = []
for w in wavelengths:
logger.info(w)
dpf.append(a + b * (age**c) + d* (w**3) + e * (w**2) + f*w)
logger.info(dpf)
return dpf
def process_participant(file_path, progress_callback=None):
fig_individual: dict[str, Figure] = {}
# Step 1: Load
raw = load_snirf(file_path)
fig_raw = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Loaded Raw", show=False)
fig_individual["Loaded Raw"] = fig_raw
if progress_callback: progress_callback(1)
logger.info("1")
# Step 1.5: Verify optode positions
fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False) # type: ignore
fig_individual["Plot Sensors"] = fig_optodes
if progress_callback: progress_callback(2)
logger.info("2")
# Step 2: Downsample
# raw = raw.resample(0.5) # Downsample to 0.5 Hz
# Step 2: Bad from SCI
bad_sci = []
if SCI:
bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw)
fig_individual["SCI1"] = fig_sci_1
fig_individual["SCI2"] = fig_sci_2
if progress_callback: progress_callback(3)
logger.info("3")
# Step 2: Bad from SNR
bad_snr = []
if SNR:
bad_snr, fig_snr = calculate_signal_noise_ratio(raw)
fig_individual["SNR1"] = fig_snr
if progress_callback: progress_callback(4)
logger.info("4")
# Step 3: Bad from PSP
bad_psp = []
if PSP:
bad_psp, fig_psp1, fig_psp2 = calculate_peak_power(raw)
fig_individual["PSP1"] = fig_psp1
fig_individual["PSP2"] = fig_psp2
if progress_callback: progress_callback(5)
logger.info("5")
# Step 4: Mark the bad channels
raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp)
if fig_dropped and fig_raw_before is not None:
fig_individual["fig2"] = fig_dropped
fig_individual["fig3"] = fig_raw_before
if progress_callback: progress_callback(6)
logger.info("6")
# Step 5: Interpolate the bad channels
if bad_channels:
raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, bad_channels)
fig_individual["fig4"] = fig_raw_after
if progress_callback: progress_callback(7)
logger.info("7")
# Step 6: Optical Density
raw_od = optical_density(raw)
fig_raw_od = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Optical Density", show=False)
fig_individual["Optical Density"] = fig_raw_od
if progress_callback: progress_callback(8)
logger.info("8")
# Step 7: TDDR
if TDDR:
raw_od = temporal_derivative_distribution_repair(raw_od)
fig_raw_od_tddr = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After TDDR (Motion Correction)", show=False)
fig_individual["TDDR"] = fig_raw_od_tddr
if progress_callback: progress_callback(9)
logger.info("9")
# Step 8: BLL
raw_haemo = beer_lambert_law(raw_od, ppf=calculate_dpf(file_path))
fig_raw_haemo_bll = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False)
fig_individual["BLL"] = fig_raw_haemo_bll
if progress_callback: progress_callback(10)
logger.info("10")
# Step 9: ENC
# raw_haemo = enhance_negative_correlation(raw_haemo)
# fig_raw_haemo_enc = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False)
# fig_individual.append(fig_raw_haemo_enc)
# Step 10: Filter
fig_filter, fig_raw_haemo_filter = filter_the_data(raw_haemo)
fig_individual["filter1"] = fig_filter
fig_individual["filter2"] = fig_raw_haemo_filter
if progress_callback: progress_callback(11)
logger.info("11")
# Step 11: Get short / long channels
if SHORT_CHANNEL:
short_chans = get_short_channels(raw_haemo, max_dist=0.015)
fig_short_chans = short_chans.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Short Channels Only", show=False)
fig_individual["short"] = fig_short_chans
else:
short_chans = None
raw_haemo = get_long_channels(raw_haemo)
if progress_callback: progress_callback(12)
logger.info("12")
# Step 12: Events from annotations
events, event_dict = events_from_annotations(raw_haemo)
fig_events = plot_events(events, event_id=event_dict, sfreq=raw_haemo.info["sfreq"], show=False)
fig_individual["events"] = fig_events
if progress_callback: progress_callback(13)
logger.info("13")
# Step 13: Epoch calculations
epochs, fig_epochs = epochs_calculations(raw_haemo, events, event_dict)
for name, fig in fig_epochs: # Unpack the tuple here
fig_individual[f"epochs_{name}"] = fig # Store only the figure, not the name
if progress_callback: progress_callback(14)
logger.info("14")
# Step 14: Design Matrix
events_to_remove = REMOVE_EVENTS
filtered_annotations = [ann for ann in raw.annotations if ann['description'] not in events_to_remove]
new_annot = Annotations(
onset=[ann['onset'] for ann in filtered_annotations],
duration=[ann['duration'] for ann in filtered_annotations],
description=[ann['description'] for ann in filtered_annotations]
)
# Set the new annotations
raw_haemo.set_annotations(new_annot)
design_matrix, fig_design_matrix = make_design_matrix(raw_haemo, short_chans)
fig_individual["Design Matrix"] = fig_design_matrix
if progress_callback: progress_callback(15)
logger.info("15")
# Step 15: Run GLM
glm_est = run_glm(raw_haemo, design_matrix)
# Not used AppData\Local\Packages\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\LocalCache\local-packages\Python313\site-packages\nilearn\glm\contrasts.py
# Yes used AppData\Local\Packages\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\LocalCache\local-packages\Python313\site-packages\mne_nirs\utils\_io.py
# The p-value is calculated from this t-statistic using the Students t-distribution with appropriate degrees of freedom.
# p_value = 2 * stats.t.cdf(-abs(t_statistic), df)
# It is a two-tailed p-value.
# It says how likely it is to observe the effect you did (or something more extreme) if the true effect was zero (null hypothesis).
# A small p-value (e.g., < 0.05) suggests the effect is unlikely to be zero — its "statistically significant."
# A large p-value means the data do not provide strong evidence that the effect is different from zero.
if progress_callback: progress_callback(16)
logger.info("16")
# Step 16: Plot GLM results
fig_glm_result = plot_glm_results(file_path, raw_haemo, glm_est, design_matrix)
for name, fig in fig_glm_result:
fig_individual[f"GLM {name}"] = fig
if progress_callback: progress_callback(17)
logger.info("17")
# Step 17: Plot channel significance
fig_significance = individual_significance(raw_haemo, glm_est)
for name, fig in fig_significance:
fig_individual[f"Significance {name}"] = fig
if progress_callback: progress_callback(18)
logger.info("18")
# Step 18: cha, con, roi
cha = glm_est.to_dataframe()
# HACK: Comment out line 588 (self._renderer.show()) in _brain.py from MNE
# brain_thing = brain_3d_visualization(cha, raw_haemo)
# brain_individual.append(brain_thing)
# C++ objects made this get rendered on the fly
rois = dict(AllChannels=range(len(raw_haemo.ch_names)))
# Calculate ROI for all conditions
conditions = design_matrix.columns
# Compute output metrics by ROI
df_ind = glm_est.to_dataframe_region_of_interest(rois, conditions)
df_ind["ID"] = file_path
# Step 18: Fold channels
# fig_fold_data, fig_fold_legend = fold_channels(raw_haemo)
# fig_individual.append(fig_fold_data)
# fig_individual.append(fig_fold_legend)
print(design_matrix)
contrast_matrix = np.eye(design_matrix.shape[1])
basic_conts = dict(
[(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)]
)
all_delay_cols = [col for col in design_matrix.columns if "_delay_" in col]
all_conditions = sorted({col.split("_delay_")[0] for col in all_delay_cols})
if not all_conditions:
raise ValueError("No FIR regressors found in the design matrix.")
# Build contrast vectors for each condition
contrast_dict = {}
for condition in all_conditions:
delay_cols = [
col for col in all_delay_cols
if col.startswith(f"{condition}_delay_") and
TIME_WINDOW_START <= int(col.split("_delay_")[-1]) <= TIME_WINDOW_END
]
if not delay_cols:
continue # skip if no columns found (shouldn't happen?)
# Average across all delay regressors for this condition
contrast_vector = np.sum([basic_conts[col] for col in delay_cols], axis=0)
contrast_vector /= len(delay_cols)
contrast_dict[condition] = contrast_vector
if progress_callback: progress_callback(19)
logger.info("19")
# Compute contrast results
contrast_results = {}
for cond, contrast_vector in contrast_dict.items():
contrast = glm_est.compute_contrast(contrast_vector) # type: ignore
df = contrast.to_dataframe()
df["ID"] = file_path
contrast_results[cond] = df
cha["ID"] = file_path
fig_bytes = convert_fig_dict_to_png_bytes(fig_individual)
if progress_callback: progress_callback(20)
logger.info("20")
sanitize_paths_for_pickle(raw_haemo, epochs)
return raw_haemo, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, AGE, GENDER, GROUP, True
def sanitize_paths_for_pickle(raw_haemo, epochs):
# Fix raw_haemo._filenames
if hasattr(raw_haemo, '_filenames'):
raw_haemo._filenames = [str(p) for p in raw_haemo._filenames]
# Fix epochs._raw._filenames
if hasattr(epochs, '_raw') and hasattr(epochs._raw, '_filenames'):
epochs._raw._filenames = [str(p) for p in epochs._raw._filenames]