""" Filename: flares.py Description: Core functionality for FLARES Author: Tyler de Zeeuw License: GPL-3.0 """ # Built-in imports import os import sys import platform import threading import logging from io import BytesIO from typing import Any, Optional, cast, Literal, Union from itertools import compress from copy import deepcopy from multiprocessing import Queue import os.path as op import re import traceback from concurrent.futures import ProcessPoolExecutor, as_completed # External library imports import matplotlib.pyplot as plt import matplotlib.colors as mcolors from matplotlib.figure import Figure from matplotlib.axes import Axes from matplotlib.colors import LinearSegmentedColormap from matplotlib.lines import Line2D import numpy as np from numpy.typing import NDArray from numpy import float64, floating import pandas as pd from pandas import DataFrame import seaborn as sns import h5py from nilearn.plotting import plot_design_matrix # type: ignore from nilearn.glm.regression import OLSModel import statsmodels.formula.api as smf # type: ignore from statsmodels.stats.multitest import multipletests from scipy import stats from scipy.spatial.distance import cdist from scipy.signal import welch, butter, filtfilt # type: ignore import pywt # type: ignore import neurokit2 as nk # type: ignore # Backen visualization needed to be defined for pyinstaller import pyvistaqt # type: ignore #import vtkmodules.util.data_model #import vtkmodules.util.execution_model # External library imports for mne from mne import ( EvokedArray, SourceEstimate, Info, Epochs, Label, Annotations, events_from_annotations, read_source_spaces, stc_near_sensors, pick_types, grand_average, get_config, set_config, read_labels_from_annot ) # type: ignore from mne.source_space import SourceSpaces from mne.transforms import Transform # type: ignore from mne.io import BaseRaw, RawArray, read_raw_snirf # type: ignore from mne.preprocessing.nirs import ( beer_lambert_law, optical_density, temporal_derivative_distribution_repair, source_detector_distances, short_channels ) # type: ignore from mne.viz import Brain, plot_events, plot_evoked_topo, plot_compare_evokeds from mne.filter import filter_data # type: ignore from mne.utils import _check_fname, _validate_type, warn from mne.channels import make_standard_montage from mne.datasets.sample import data_path from mne_nirs.visualisation import plot_glm_group_topo # type: ignore from mne_nirs.channels import get_long_channels, get_short_channels # type: ignore from mne_nirs.experimental_design import make_first_level_design_matrix # type: ignore from mne_nirs.statistics import run_glm, statsmodels_to_results # type: ignore from mne_nirs.signal_enhancement import ( enhance_negative_correlation, short_channel_regression ) # type: ignore from mne_nirs.io.fold import fold_channel_specificity # type: ignore from mne_nirs.preprocessing import peak_power # type: ignore from mne_nirs.statistics._glm_level_first import RegressionResults # type: ignore os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore FIXED_CATEGORY_COLORS = { "SCI only": "skyblue", "PSP only": "salmon", "SNR only": "lightgreen", "PSP + SCI": "orange", "SCI + SNR": "violet", "PSP + SNR": "gold", "SCI + PSP": "orange", "SNR + SCI": "violet", "SNR + PSP": "gold", "PSP + SNR + SCI": "gray", "SCI + PSP + SNR": "gray", "SCI + SNR + PSP": "gray", "PSP + SCI + SNR": "gray", "PSP + SNR + SCI": "gray", "SNR + SCI + PSP": "gray", "SNR + PSP + SCI": "gray", } AGE: float GENDER: str # SECONDS_TO_STRIP: int 
DOWNSAMPLE: bool DOWNSAMPLE_FREQUENCY: int SCI: bool SCI_TIME_WINDOW: int SCI_THRESHOLD: float SNR: bool # SNR_TIME_WINDOW : int SNR_THRESHOLD: float PSP: bool PSP_TIME_WINDOW: int PSP_THRESHOLD: float TDDR: bool IQR = 1.5 HEART_RATE = True # True if heart rate should be calculated. This helps the SCI, PSP, and SNR methods to be more accurate. SECONDS_TO_STRIP_HR =5 # Amount of seconds to temporarily strip from the data to calculate heart rate more effectively. Useful if participant removed cap while still recording. MAX_LOW_HR = 40 # Any heart rate values lower than this will be set to this value. MAX_HIGH_HR = 200 # Any heart rate values higher than this will be set to this value. SMOOTHING_WINDOW_HR = 100 # Heart rate will be calculated as a rolling average over this many amount of samples. HEART_RATE_WINDOW = 25 # Amount of BPM above and below the calculated average to use for a range of resting BPM. SHORT_CHANNEL_THRESH = 0.018 ENHANCE_NEGATIVE_CORRELATION: bool L_FREQ: float H_FREQ: float SHORT_CHANNEL: bool REMOVE_EVENTS: list TIME_WINDOW_START: int TIME_WINDOW_END: int VERBOSITY = True # FIXME: Shouldn't need each ordering - just order it before checking FIXED_CATEGORY_COLORS = { "SCI only": "skyblue", "PSP only": "salmon", "SNR only": "lightgreen", "PSP + SCI": "orange", "SCI + SNR": "violet", "PSP + SNR": "gold", "SCI + PSP": "orange", "SNR + SCI": "violet", "SNR + PSP": "gold", "PSP + SNR + SCI": "gray", "SCI + PSP + SNR": "gray", "SCI + SNR + PSP": "gray", "PSP + SCI + SNR": "gray", "PSP + SNR + SCI": "gray", "SNR + SCI + PSP": "gray", "SNR + PSP + SCI": "gray", } AGE = 25 GENDER = "" GROUP = "Default" REQUIRED_KEYS: dict[str, Any] = { # "SECONDS_TO_STRIP": int, "DOWNSAMPLE": bool, "DOWNSAMPLE_FREQUENCY": int, "SCI": bool, "SCI_TIME_WINDOW": int, "SCI_THRESHOLD": float, "SNR": bool, # SNR_TIME_WINDOW : int "SNR_THRESHOLD": float, "PSP": bool, "PSP_TIME_WINDOW": int, "PSP_THRESHOLD": float, "SHORT_CHANNEL": bool, "REMOVE_EVENTS": list, "TIME_WINDOW_START": int, "TIME_WINDOW_END": int, "L_FREQ": float, "H_FREQ": float, # "REJECT_PAIRS": bool, # "FORCE_DROP_ANNOTATIONS": list, # "FILTER_LOW_PASS": float, # "FILTER_HIGH_PASS": float, # "EPOCH_PAIR_TOLERANCE_WINDOW": int, } class ProcessingError(Exception): def __init__(self, message: str = "Something went wrong!"): self.message = message super().__init__(self.message) # Ensure that we are working in the directory of this file script_dir = os.path.dirname(os.path.abspath(__file__)) os.chdir(script_dir) PLATFORM_NAME = platform.system().lower() # Configure logging to file with timestamps and realtime flush if PLATFORM_NAME == 'darwin': logging.basicConfig( filename=os.path.join(os.path.dirname(sys.executable), "../../../fnirs_analysis.log"), level=logging.INFO, format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', filemode='a' ) else: logging.basicConfig( filename='fnirs_analysis.log', level=logging.INFO, format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', filemode='a' ) logger = logging.getLogger() def set_config_me(config: dict[str, Any]) -> None: """ Validates and applies the given configuration dictionary. Parameters ---------- config : dict[str, Any] Dictionary containing configuration keys and their values. """ logger.info(f"[DEBUG] set_config called") globals().update(config) def set_metadata(file_path, metadata: dict[str, Any]) -> None: """ Validates and applies the given configuration dictionary. 
Parameters ---------- config : dict[str, Any] Dictionary containing configuration keys and their values. """ logger.info(f"[DEBUG] set_metadata called") globals()['AGE'] = 25 globals()['GENDER'] = "" globals()['GROUP'] = "Default" if metadata.get(file_path) is not None: file_metadata = metadata.get(file_path, {}) for key in ("AGE", "GENDER", "GROUP"): val = file_metadata.get(key, None) if val not in (None, '', [], {}, ()): # check for "empty" values globals()[key] = val from queue import Empty # This works with multiprocessing.Manager().Queue() def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -> None: def forward_progress(): while True: try: msg = progress_queue.get(timeout=1) if msg == "__done__": break gui_queue.put(msg) except Empty: continue except Exception as e: gui_queue.put({ "type": "error", "error": f"Forwarding thread crashed: {e}", "traceback": traceback.format_exc() }) break t = threading.Thread(target=forward_progress, daemon=True) t.start() try: file_paths = config['SNIRF_FILES'] file_params = config['PARAMS'] file_metadata = config['METADATA'] max_workers = file_params.get("MAX_WORKERS", int(os.cpu_count()/4)) results = process_multiple_participants( file_paths, file_params, file_metadata, progress_queue, max_workers ) gui_queue.put({"success": True, "result": results}) except Exception as e: gui_queue.put({ "success": False, "error": str(e), "traceback": traceback.format_exc() }) finally: # Always send done to the thread and avoid hanging try: progress_queue.put("__done__") except: pass t.join(timeout=5) # prevent permanent hang def process_participant_worker(args): file_path, file_params, file_metadata, progress_queue = args set_config_me(file_params) set_metadata(file_path, file_metadata) logger.info(f"DEBUG: Metadata for {file_path}: AGE={globals().get('AGE')}, GENDER={globals().get('GENDER')}, GROUP={globals().get('GROUP')}") def progress_callback(step_idx): if progress_queue: progress_queue.put(('progress', file_path, step_idx)) try: result = process_participant(file_path, progress_callback=progress_callback) return file_path, result, None except Exception as e: error_trace = traceback.format_exc() return file_path, None, (str(e), error_trace) def process_multiple_participants(file_paths, file_params, file_metadata, progress_queue=None, max_workers=None): results_by_file = {} file_args = [(file_path, file_params, file_metadata, progress_queue) for file_path in file_paths] with ProcessPoolExecutor(max_workers=max_workers) as executor: futures = {executor.submit(process_participant_worker, arg): arg[0] for arg in file_args} for future in as_completed(futures): file_path = futures[future] try: file_path, result, error = future.result() if error: error_message, error_traceback = error if progress_queue: progress_queue.put({ "type": "error", "file": file_path, "error": error_message, "traceback": error_traceback }) continue results_by_file[file_path] = result except Exception as e: print(f"Unexpected error processing {file_path}: {e}") return results_by_file def markbad(data, ax, ch_names: list[str]) -> None: """ Add a strikethrough to a plot for channels marked as bad. Parameters ---------- data : BaseRaw The loaded data object to process. ax : Axes Matplotlib Axes object where the strikethrough lines will be drawn. ch_names : list[str] List of channel names corresponding to the y-axis of the plot. 
""" # Iterate over all the channels for i, ch in enumerate(ch_names): # If it is marked as bad, place a strikethrough on the channel if ch in data.info["bads"]: ax.axhline(i + 0.5, ls="solid", lw=4, color="black", zorder=10) # type: ignore def plot_timechannel_quality_metrics(data, scores, times: list[tuple[float]], color_stops: tuple[list[float], list[float]], threshold: float, title: Optional[str] = None): """ Generate two heatmaps visualizing channel quality metrics over time. Parameters ---------- data : BaseRaw The loaded data object to process. scores : NDArray[float64] A 2D array of quality scores for each channel over time. times : list[tuple[float]] List of time boundaries used to label each score column. color_stops : tuple[list[float], list[float]] Two lists of color values for custom colormaps. threshold : float, Threshold value for the color bar. title : Optional[str], optional Base title for the figures, (default is None). Returns ------- tuple[Figure, Figure] - Figure: Heatmap of all scores across channels and time. - Figure: Binary heatmap showing only scores above the threshold. """ # Get only the hbo / hbr channels once as we dont need to see the same results twice half_ch = len(getattr(data, "ch_names")) // 2 ch_names = getattr(data, "ch_names")[:half_ch] scores = scores[:half_ch, :] # Extract rounded time points to use as column headers cols = [np.round(t[0]) for t in times] n_chans = len(ch_names) vsize = 0.2 * n_chans # Create the first figure fig1, ax1 = plt.subplots(figsize=(10, vsize), layout="constrained") # type: ignore fig1.suptitle(title + " - All Scores", fontsize=16, fontweight="bold") # type: ignore # Create a DataFrame to structure data for the heatmap data_to_plot = DataFrame( data=scores, columns=pd.Index(cols, name="Time (s)"), index=pd.Index(ch_names, name="Channel"), ) # Define a custom colormap using provided color stops and base colors base_colors = ['red', 'red', 'yellow', 'green', 'green'] colors = list(zip(color_stops[0], base_colors[:len(color_stops[0])])) cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors) # Plot heatmap of scores sns.heatmap( # type: ignore data=data_to_plot, cmap=cmap, vmin=0, vmax=1, cbar_kws=dict(label="Score"), ax=ax1, ) # Add vertical dashed lines at each time boundary, sit the title, and place a black strikethrough through a bad channel for x in range(1, len(times)): ax1.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") # type: ignore ax1.set_title("All Scores", fontweight="bold") # type: ignore markbad(data, ax1, ch_names) # Calculate average score per channel and annotate to the right of the heatmap avg_sci_subset: pd.Series[float] = data_to_plot.mean(axis=1) # type: ignore norm = mcolors.Normalize(vmin=0, vmax=1) text_x = data_to_plot.shape[1] + 0.5 for i, val in enumerate(avg_sci_subset): color = cmap(norm(val)) ax1.text( # type: ignore text_x, i + 0.5, f"{val:.3f}", va='center', ha='left', fontsize=9, color=color ) ax1.set_xlim(right=text_x + 1.5) plt.close(fig1) # Create the second figure fig2, ax2 = plt.subplots(figsize=(10, vsize), layout="constrained") # type: ignore fig2.suptitle(title + " - Scores Above Threshold", fontsize=16, fontweight="bold") # type: ignore # Create a DataFrame to structure data for the heatmap data_to_plot = DataFrame( data=scores > threshold, columns=pd.Index(cols, name="Time (s)"), index=pd.Index(ch_names, name="Channel"), ) # Define a custom colormap using provided color stops and base colors base_colors = ['red', 'red', 'white', 'white'] colors = 
list(zip(color_stops[1], base_colors[:len(color_stops[1])])) cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors) # Plot heatmap of scores sns.heatmap( # type: ignore data=data_to_plot, vmin=0, vmax=1, cmap=cmap, cbar_kws=dict(label="Score"), ax=ax2, ) # Add vertical dashed lines at each time boundary, sit the title, and place a black strikethrough through a bad channel for x in range(1, len(times)): ax2.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") # type: ignore ax2.set_title("Scores > Threshold", fontweight="bold") # type: ignore markbad(data, ax2, ch_names) plt.close(fig2) return fig1, fig2 def scalp_coupling_index_windowed_raw(data, time_window: float = 3.0, l_freq: float = 0.7, h_freq: float = 1.5, l_trans_bandwidth: float = 0.3, h_trans_bandwidth: float = 0.3): """ Compute windowed scalp coupling index (SCI) across fNIRS channels. Parameters ---------- data : BaseRaw The loaded data object to process. time_window : float, optional Length of each time window in seconds (default is 3.0). l_freq : float, optional Low cutoff frequency for filtering in Hz (default is 0.7). h_freq : float, optional High cutoff frequency for filtering in Hz (default is 1.5). l_trans_bandwidth : float, optional Transition bandwidth for the low cutoff in Hz (default is 0.3). h_trans_bandwidth : float, optional Transition bandwidth for the high cutoff in Hz (default is 0.3). Returns ------- tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]] - BaseRaw: The original data object (unchanged). Ensures compatibility with peak_power(). - NDArray[float64]: Correlation scores for each channel and time window. - list[tuple[float, float]]: Time intervals for each window in seconds. """ # Pick only fNIRS channels and sort them by channel name picks: NDArray[np.intp] = pick_types(cast(Info, data.info), fnirs=True) # type: ignore picks = picks[np.argsort([getattr(data, "ch_names")[pick] for pick in picks])] # FIXME: This may happen if the heart rate calculation tries to set a value way too low if l_freq < 0.3: l_freq = 0.3 # Band-pass filter the selected fNIRS channels filtered_data = filter_data( getattr(data, "_data"), getattr(data, "info")["sfreq"], l_freq, h_freq, picks=picks, verbose=False, l_trans_bandwidth=l_trans_bandwidth, # type: ignore h_trans_bandwidth=h_trans_bandwidth, # type: ignore ) # Calculate number of samples per time window, the total number of windows, and prepare output variables window_samples = int(np.ceil(time_window * getattr(data, "info")["sfreq"])) n_windows = int(np.floor(len(data) / window_samples)) scores = np.zeros((len(picks), n_windows)) times: list[tuple[float, float]] = [] # Slide through the data in windows to compute scalp coupling index (SCI) for window in range(n_windows): start_sample = int(window * window_samples) end_sample = start_sample + window_samples end_sample = np.min([end_sample, len(data) - 1]) # Track time boundaries for each window t_start = getattr(data, "times")[start_sample] t_stop = getattr(data, "times")[end_sample] times.append((t_start, t_stop)) # Iterate through channels in pairs (hbo, hbr). 
This requires them to be sorted by channel name for ii in range(0, len(picks), 2): c1 = filtered_data[picks[ii]][start_sample:end_sample] c2 = filtered_data[picks[ii + 1]][start_sample:end_sample] # Ensure the correlation data is valid if np.std(c1) == 0 or np.std(c2) == 0 or np.any(np.isnan(c1)) or np.any(np.isnan(c2)): c = 0 else: c = np.corrcoef(c1, c2)[0][1] # Assign the computed correlation to both channels in the pair scores[ii, window] = c scores[ii + 1, window] = c scores = scores[np.argsort(picks)] return data, scores, times def calculate_scalp_coupling(data, l_freq: float = 0.7, h_freq: float = 1.5): """ Calculate the scalp coupling index (SCI) and identify bad channels based on a threshold. Parameters ---------- data : BaseRaw The loaded data object to process. l_freq : float, optional Low cutoff frequency for bandpass filtering in Hz (default is 0.7). h_freq : float, optional High cutoff frequency for bandpass filtering in Hz (default is 1.5) Returns ------- tuple[list[str], Figure, Figure] - list[str]: Channel names identified as bad based on SCI threshold. - Figure: Heatmap of all SCI scores across time and channels. - Figure: Binary heatmap of SCI scores exceeding the threshold. """ print("Calculating scalp coupling index...") # Compute the SCI _, scores, times = scalp_coupling_index_windowed_raw(data, time_window=SCI_TIME_WINDOW, l_freq=l_freq, h_freq=h_freq) # Identify channels that don't meet the provided threshold print("Identifying channels that do not meet the threshold...") sci = scores.mean(axis=1) data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD)) # Determine the colors based on the threshold, and create the figures print("Creating the figures...") color_stops = ([0.0, SCI_THRESHOLD, SCI_THRESHOLD+0.1, 0.8, 1.0], [0.0, SCI_THRESHOLD, SCI_THRESHOLD, 1.0]) fig1, fig2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, SCI_THRESHOLD, "Scalp Coupling Index") print("Successfully calculated scalp coupling index.") return list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD)), fig1, fig2 def calculate_signal_noise_ratio(data): """ Calculates the signal-to-noise ratio (SNR) for each channel and identifies those below a defined threshold. Parameters ---------- data : BaseRaw The loaded data object to process. Returns ------- tuple[list[str], Figure] - list[str]: A list of channel names that fall below the SNR threshold and are considered bad. - Figure: A matplotlib Figure showing the channels' SNR values. 
""" print("Calculating signal to noise ratio...") # Compute the signal-to-noise ratio values print("Computing the signal to noise power...") signal_band=(0.01, 0.5) noise_band=(1.0, 10.0) data_signal = data.copy().filter(*signal_band, verbose=False) #type: ignore data_noise = data.copy().filter(*noise_band, verbose=False) #type: ignore signal_power = np.mean(data_signal.get_data()**2, axis=1) #type: ignore noise_power = np.mean(data_noise.get_data()**2, axis=1) #type: ignore # Calculate the snr using the standard formula for dB snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps)) # TODO: Understand what this does groups: dict[str, list[str]] = {} for ch in getattr(data, "ch_names"): # Look for the space in the channel names and remove the characters after # This is so we can get both oxy and deoxy to remove, as they will have the same source and detector base = ch.rsplit(' ', 1)[0] groups.setdefault(base, []).append(ch) # type: ignore # If any of the channels do not meet our threshold, they will get inserted into the bad_channels set bad_channels: set[str] = set() for base, ch_list in groups.items(): if any(s < SNR_THRESHOLD for s, ch in zip(snr, getattr(data, "ch_names")) if ch in ch_list): bad_channels.update(ch_list) # Design and create the figure print("Creating the figure...") snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained") # type: ignore colors = [(0/20, 'red'), (SNR_THRESHOLD/20, 'red'), ((SNR_THRESHOLD+.5)/20, 'yellow'), ((SNR_THRESHOLD+1)/20, 'green'), (20/20, 'green')] cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors) norm = mcolors.Normalize(vmin=0, vmax=20) scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm) # type: ignore ax.set(xlabel="Channel Number", ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, 20]) ax.axhline(SNR_THRESHOLD, color='black', linestyle='--', alpha=0.3, linewidth=1) # type: ignore cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)") # type: ignore cbar.set_ticks([0, SNR_THRESHOLD, SNR_THRESHOLD+1, 20]) # type: ignore cbar.set_ticklabels(['0', str(SNR_THRESHOLD), str(SNR_THRESHOLD+1), '20']) # type: ignore plt.close() print("Successfully calculated signal to noise ratio.") return list(bad_channels), snr_fig def build_fnirs_adjacency(raw, threshold_meters=0.03): """Build an adjacency dictionary for fNIRS channels using 3D distance.""" # Extract channel positions ch_locs = [] ch_names = [] for ch in raw.info['chs']: loc = ch['loc'][:3] # Get x, y, z coordinates if not np.isnan(loc).any(): ch_locs.append(loc) ch_names.append(ch['ch_name']) ch_locs = np.array(ch_locs) # Compute pairwise distances dists = cdist(ch_locs, ch_locs) # Build adjacency dictionary adjacency = {} for i, ch_name in enumerate(ch_names): neighbors = [ch_names[j] for j in range(len(ch_names)) if 0 < dists[i, j] < threshold_meters] adjacency[ch_name] = neighbors return adjacency def get_hbo_hbr_picks(raw): # Pick all fNIRS channels fnirs_picks = pick_types(raw.info, fnirs=True, exclude=[]) # Extract wavelengths from channel names (expecting something like 'S6_D4 763' or 'S6_D4 841') wavelengths = [] for idx in fnirs_picks: ch_name = raw.ch_names[idx] # Extract last 3 digits from channel name using regex match = re.search(r'(\d{3})$', ch_name) if match: wavelengths.append(int(match.group(1))) else: raise ValueError(f"Channel name '{ch_name}' does not end with 3 digits.") wavelengths = np.array(wavelengths) unique_wavelengths = np.unique(wavelengths) if 
len(unique_wavelengths) != 2: raise RuntimeError(f"Expected exactly 2 distinct wavelengths, found {unique_wavelengths}") # Determine which is HbO (larger) and which is HbR (smaller) hbr_wl = unique_wavelengths.min() hbo_wl = unique_wavelengths.max() print(f"HbR wavelength: {hbr_wl}, HbO wavelength: {hbo_wl}") # Find picks corresponding to each wavelength hbr_picks = [fnirs_picks[i] for i, wl in enumerate(wavelengths) if wl == hbr_wl] hbo_picks = [fnirs_picks[i] for i, wl in enumerate(wavelengths) if wl == hbo_wl] print(f"Found {len(hbr_picks)} HbR channels and {len(hbo_picks)} HbO channels.") return hbo_picks, hbr_picks, hbo_wl, hbr_wl def interpolate_fNIRS_bads_weighted_average(raw, bad_channels, max_dist=0.03, min_neighbors=2): """ Interpolate bad fNIRS channels using a distance-weighted average of nearby good channels. Parameters ---------- raw : mne.io.Raw The raw fNIRS data with bads marked in raw.info['bads']. max_dist : float Maximum distance (in meters) to consider for neighboring good channels. min_neighbors : int Minimum number of neighbors required to interpolate a bad channel. Returns ------- raw : mne.io.Raw Modified raw object with bads interpolated (in-place). """ print("Finding fNIRS channels...") hbo_picks, hbr_picks, hbo_wl, hbr_wl = get_hbo_hbr_picks(raw) if len(hbo_picks) != len(hbr_picks): raise RuntimeError("Number of HbO and HbR channels must be the same.") # Base names without wavelength for pairing def base_name(ch_name): # Strip last 4 chars assuming format ' ' # e.g. "S6_D6 841" -> "S6_D6" return ch_name[:-4] hbo_names = [base_name(raw.ch_names[i]) for i in hbo_picks] hbr_names = [base_name(raw.ch_names[i]) for i in hbr_picks] # Sanity check: pairs must match for i in range(len(hbo_names)): if hbo_names[i] != hbr_names[i]: raise RuntimeError(f"Channel pairs do not match: {hbo_names[i]} vs {hbr_names[i]}") # Identify bad pairs if either channel in pair is bad bad_pairs = [] good_pairs = [] for i, base in enumerate(hbo_names): hbo_ch = raw.ch_names[hbo_picks[i]] hbr_ch = raw.ch_names[hbr_picks[i]] if (hbo_ch in raw.info['bads']) or (hbr_ch in raw.info['bads']): bad_pairs.append(i) else: good_pairs.append(i) print(f"Total pairs: {len(hbo_names)}") print(f"Good pairs: {len(good_pairs)}") print(f"Bad pairs to interpolate: {len(bad_pairs)}") if len(bad_pairs) == 0: print("No bad pairs found. 
Skipping interpolation.") return raw # Extract locations (use HbO channel loc as pair location) locs = np.array([raw.info['chs'][hbo_picks[i]]['loc'][:3] for i in range(len(hbo_names))]) good_locs = locs[good_pairs] bad_locs = locs[bad_pairs] # Compute distance matrix between bad and good pairs dist_matrix = cdist(bad_locs, good_locs) interpolated_pairs = [] for i, bad_idx in enumerate(bad_pairs): bad_base = hbo_names[bad_idx] distances = dist_matrix[i] close_idxs = np.where(distances < max_dist)[0] print(f"\nInterpolating pair {bad_base} (index {bad_idx})") print(f" Nearby good pairs found: {len(close_idxs)}") if len(close_idxs) < min_neighbors: print(f" Skipping {bad_base}: not enough neighbors (found {len(close_idxs)} < {min_neighbors})") continue weights = 1 / (distances[close_idxs] + 1e-6) weights /= weights.sum() neighbor_hbo_indices = [hbo_picks[good_pairs[idx]] for idx in close_idxs] neighbor_hbr_indices = [hbr_picks[good_pairs[idx]] for idx in close_idxs] neighbor_hbo_data = raw._data[neighbor_hbo_indices, :] neighbor_hbr_data = raw._data[neighbor_hbr_indices, :] interpolated_hbo = np.average(neighbor_hbo_data, axis=0, weights=weights) interpolated_hbr = np.average(neighbor_hbr_data, axis=0, weights=weights) raw._data[hbo_picks[bad_idx]] = interpolated_hbo raw._data[hbr_picks[bad_idx]] = interpolated_hbr interpolated_pairs.append(bad_base) if interpolated_pairs: bad_ch_to_remove = [] for base_ in interpolated_pairs: bad_ch_to_remove.append(base_ + f" {hbr_wl}") # HbR bad_ch_to_remove.append(base_ + f" {hbo_wl}") # HbO raw.info['bads'] = [ch for ch in raw.info['bads'] if ch not in bad_ch_to_remove] print("\nInterpolation complete.\n") for ch in raw.info['bads']: print(f"Channel {ch} still marked as bad.") print("Bads cleared:", raw.info['bads']) fig_raw_after = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After interpolation", show=False) return raw, fig_raw_after def calculate_signal_noise_ratio(data): """ Calculates the signal-to-noise ratio (SNR) for each channel and identifies those below a defined threshold. Parameters ---------- data : BaseRaw The loaded data object to process. Returns ------- tuple[list[str], Figure] - list[str]: A list of channel names that fall below the SNR threshold and are considered bad. - Figure: A matplotlib Figure showing the channels' SNR values. 
""" print("Calculating signal to noise ratio...") # Compute the signal-to-noise ratio values print("Computing the signal to noise power...") signal_band=(0.01, 0.5) noise_band=(1.0, 10.0) data_signal = data.copy().filter(*signal_band, verbose=False) #type: ignore data_noise = data.copy().filter(*noise_band, verbose=False) #type: ignore signal_power = np.mean(data_signal.get_data()**2, axis=1) #type: ignore noise_power = np.mean(data_noise.get_data()**2, axis=1) #type: ignore # Calculate the snr using the standard formula for dB snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps)) # TODO: Understand what this does groups: dict[str, list[str]] = {} for ch in getattr(data, "ch_names"): # Look for the space in the channel names and remove the characters after # This is so we can get both oxy and deoxy to remove, as they will have the same source and detector base = ch.rsplit(' ', 1)[0] groups.setdefault(base, []).append(ch) # type: ignore # If any of the channels do not meet our threshold, they will get inserted into the bad_channels set bad_channels: set[str] = set() for base, ch_list in groups.items(): if any(s < SNR_THRESHOLD for s, ch in zip(snr, getattr(data, "ch_names")) if ch in ch_list): bad_channels.update(ch_list) # Design and create the figure print("Creating the figure...") snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained") # type: ignore colors = [(0/20, 'red'), (SNR_THRESHOLD/20, 'red'), ((SNR_THRESHOLD+.5)/20, 'yellow'), ((SNR_THRESHOLD+1)/20, 'green'), (20/20, 'green')] cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors) norm = mcolors.Normalize(vmin=0, vmax=20) scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm) # type: ignore ax.set(xlabel="Channel Number", ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, 20]) ax.axhline(SNR_THRESHOLD, color='black', linestyle='--', alpha=0.3, linewidth=1) # type: ignore cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)") # type: ignore cbar.set_ticks([0, SNR_THRESHOLD, SNR_THRESHOLD+1, 20]) # type: ignore cbar.set_ticklabels(['0', str(SNR_THRESHOLD), str(SNR_THRESHOLD+1), '20']) # type: ignore plt.close() print("Successfully calculated signal to noise ratio.") return list(bad_channels), snr_fig def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]: """ Calculate peak spectral power (PSP) for fNIRS channels and identify bad channels. Parameters ---------- data : BaseRaw The loaded data object to process. l_freq : float, optional Low cutoff frequency for filtering in Hz (default is 0.7) h_freq : float, optional High cutoff frequency for filtering in Hz (default is 1.5) Returns ------- tuple[list[str], Figure, Figure] - list[str]: Names of channels below the PSP threshold. - Figure: Heatmap of all PSP scores. - Figure: Heatmap of scores above the PSP threshold. 
""" # Compute the PSP _, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], peak_power(data, time_window=PSP_TIME_WINDOW, threshold=PSP_THRESHOLD, l_freq=l_freq, h_freq=h_freq)) # Identify channels that don't meet the provided threshold psp = scores.mean(axis=1) data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)) # Determine the colors based on the threshold, and create the figures color_stops = ([0.0, PSP_THRESHOLD, PSP_THRESHOLD+0.1, PSP_THRESHOLD+0.2, 1.0], [0.0, PSP_THRESHOLD, PSP_THRESHOLD, 1.0]) psp1, psp2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, PSP_THRESHOLD, "Peak Spectral Power") return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2 def mark_bads(raw, bad_sci, bad_snr, bad_psp): bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp)) print(f"Automatically marked bad channels based on SNR and SCI: {bads_combined}") raw.info['bads'].extend(bads_combined) # Organize channels into categories sets = [ (bad_sci, "SCI"), (bad_psp, "PSP"), (bad_snr, "SNR"), ] # Graph what channels were dropped and why they were dropped channel_categories: dict[str, str] = {} for ch in bads_combined: present_in = [name for s, name in sets if ch in s] # Create a label for the category if len(present_in) == 1: label = f"{present_in[0]} only" else: label = " + ".join(sorted(present_in)) channel_categories[ch] = label # Sort channels alphabetically within categories for nicer visualization categories = sorted(set(channel_categories.values())) channel_names: list[str] = [] category_labels: list[str] = [] for cat in categories: chs_in_cat = sorted([ch for ch, c in channel_categories.items() if c == cat]) channel_names.extend(chs_in_cat) category_labels.extend([cat] * len(chs_in_cat)) colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories} # Create the figure fig_dropped, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3))) # type: ignore y_pos = range(len(channel_names)) ax.barh(y_pos, [1]*len(channel_names), color=[colors[cat] for cat in category_labels]) # type: ignore ax.set_yticks(y_pos) # type: ignore ax.set_yticklabels(channel_names) # type: ignore ax.set_xlabel("Marked as Bad") # type: ignore ax.set_title(f"Bad Channels by Method for") # type: ignore ax.set_xlim(0, 1) ax.set_xticks([]) # type: ignore ax.grid(axis='x', linestyle='--', alpha=0.7) # type: ignore # Add a legend denoting why the channels were bad for label, color in colors.items(): ax.bar(0, 0, color=color, label=label) # type: ignore ax.legend() # type: ignore fig_dropped.tight_layout() raw_before = deepcopy(raw) bads_channels = [ch for ch in raw.ch_names if ch in raw.info['bads']] print(bads_channels) if bads_channels: fig_raw_before = raw_before.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], picks=bads_channels, title="What they were BEFORE", show=False) else: fig_dropped = None fig_raw_before = None return raw, fig_dropped, fig_raw_before, bads_channels def filter_the_data(raw_haemo): # --- STEP 5: Filtering (0.01–0.2 Hz bandpass) --- fig_filter = raw_haemo.compute_psd(fmax=3).plot( average=True, color="r", show=False, amplitude=True ) if L_FREQ == 0 and H_FREQ != 0: raw_haemo = raw_haemo.filter(l_freq=None, h_freq=H_FREQ, h_trans_bandwidth=0.02) elif L_FREQ != 0 and H_FREQ == 0: raw_haemo = raw_haemo.filter(l_freq=L_FREQ, h_freq=None, l_trans_bandwidth=0.002) elif L_FREQ != 0 and H_FREQ != 0: raw_haemo = raw_haemo.filter(l_freq=L_FREQ, 
h_freq=H_FREQ, l_trans_bandwidth=0.002, h_trans_bandwidth=0.02) else: print("No filter") #raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.4, h_trans_bandwidth=0.2) #raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.7, h_trans_bandwidth=0.2) #raw_haemo = raw_haemo.filter(0.005, 0.7, h_trans_bandwidth=0.02, l_trans_bandwidth=0.002) raw_haemo.compute_psd(fmax=3).plot( average=True, axes=fig_filter.axes, color="g", amplitude=True, show=False ) fig_raw_haemo_filter = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Filtered HbO and HbR", show=False) return fig_filter, fig_raw_haemo_filter def epochs_calculations(raw_haemo, events, event_dict): fig_epochs = [] # List to store figures # Create epochs from raw data epochs = Epochs(raw_haemo, events, event_id=event_dict, tmin=-5, tmax=15, baseline=(None, 0)) # Make a copy of the epochs and drop bad ones epochs2 = epochs.copy() epochs2.drop_bad() # Plot drop log # TODO: Why show this if we never use epochs2? fig_epochs_dropped = epochs2.plot_drop_log(show=False) fig_epochs.append(("fig_epochs_dropped", fig_epochs_dropped)) # Plot for each condition fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4)) for idx, condition in enumerate(epochs.event_id.keys()): logger.info(condition) logger.info(idx) # Plot images for each condition fig_epochs_data = epochs[condition].plot_image( combine="mean", vmin=-1, vmax=1, ts_args=dict(ylim=dict(hbo=[-1, 1], hbr=[-1, 1])), show=False ) for j, fig in enumerate(fig_epochs_data): logger.info("------------------------------------------") logger.info(j) logger.info(fig) ax = fig.axes[0] original_title = ax.get_title() ax.set_title(f"{condition}: {original_title}") fig_epochs.append((f"fig_{condition}_data_{idx}_{j}", fig)) # Store with a unique name # Evoked average figure for each condition evoked_avg = epochs[condition].average() clims = dict(hbo=[-1, 1], hbr=[1, -1]) condition_fig = evoked_avg.plot_image(clim=clims, show=False) for ax in condition_fig.axes: original_title = ax.get_title() ax.set_title(f"{original_title} - {condition}") fig_epochs.append((f"evoked_avg_{condition}", condition_fig)) # Store with a unique name # Prepare evokeds and colors for topographic plot evokeds3 = [] colors = [] conditions = list(epochs.event_id.keys()) cmap = plt.get_cmap("tab10", len(conditions)) for idx, cond in enumerate(conditions): evoked = epochs[cond].average(picks="hbo") evokeds3.append(evoked) colors.append(cmap(idx)) # Create the topographic plot fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4)) help = plot_evoked_topo(evokeds3, color=colors, axes=axes, legend=False, show=False) # Build custom legend lines = [] for color in colors: line = plt.Line2D([0], [0], color=color, lw=2) lines.append(line) fig.legend(lines, conditions, loc="lower right") fig_epochs.append(("evoked_topo", help)) # Store with a unique name unique_annotations = set(raw_haemo.annotations.description) for cond in unique_annotations: # Evoked response for specific condition ("Activity") evoked_stim1 = epochs[cond].average() fig_evoked_hbo = evoked_stim1.copy().pick(picks='hbo').plot(time_unit='s', show=False) fig_evoked_hbr = evoked_stim1.copy().pick(picks='hbr').plot(time_unit='s', show=False) fig_epochs.append((f"fig_evoked_hbo_{cond}", fig_evoked_hbo)) # Store with a unique name fig_epochs.append((f"fig_evoked_hbr_{cond}", fig_evoked_hbr)) # Store with a unique name print("Evoked HbO peak amplitude:", evoked_stim1.copy().pick(picks='hbo').data.max()) evokeds = {} for condition in 
epochs2.event_id: evokeds[condition] = epochs2[condition].average() print(f"Condition '{condition}': {len(epochs2[condition])} epochs averaged.") all_evokeds = {} for condition in epochs.event_id: if condition not in all_evokeds: all_evokeds[condition] = [] all_evokeds[condition].append(epochs[condition].average()) group_aucs = {} # TODO: group averages with a single person? group_averages = {cond: grand_average(evokeds) for cond, evokeds in all_evokeds.items()} for condition, evoked in group_averages.items(): group_aucs[condition] = {} for pick in ["hbo", "hbr"]: picks_idx = [i for i, ch in enumerate(evoked.ch_names) if pick in ch] if not picks_idx: continue data = evoked.data[picks_idx, :].mean(axis=0) t_start, t_end = 0, 15 times_mask = (evoked.times >= t_start) & (evoked.times <= t_end) data_segment = data[times_mask] times_segment = evoked.times[times_mask] auc = np.trapezoid(data_segment, times_segment) group_aucs[condition][pick] = auc # Final evoked comparison plot for each condition for condition in conditions: if condition not in evokeds: continue evoked = evokeds[condition] fig, ax = plt.subplots(figsize=(6, 5)) legend_labels = ["Oxyhaemoglobin"] for pick, color in zip(["hbo", "hbr"], ["r", "b"]): plot_compare_evokeds( evoked, combine="mean", picks=pick, axes=ax, show=False, colors=[color], legend=False, title=f"Participant: nCondition: {condition}", ylim=dict(hbo=[-0.5, 1], hbr=[-0.5, 1]), show_sensors=False, ) auc_value = group_aucs.get(condition, {}).get(pick, None) if auc_value is not None: label = f"{pick.upper()} AUC: {auc_value * 1e6:.4f} µM·s" else: label = f"{pick.upper()} AUC: N/A" legend_labels.append(label) if len(legend_labels) == 2: legend_labels.append("Deoxyhaemoglobin") ax.legend(legend_labels) fig_epochs.append((f"fig_{condition}_compare_evokeds", fig)) # Store with a unique name return epochs, fig_epochs def make_design_matrix(raw_haemo, short_chans): raw_haemo.resample(1, npad="auto") raw_haemo._data = raw_haemo._data * 1e6 # 2) Create design matrix if SHORT_CHANNEL: short_chans.resample(1) design_matrix = make_first_level_design_matrix( raw=raw_haemo, hrf_model='fir', stim_dur=0.5, fir_delays=range(15), drift_model='cosine', high_pass=0.01, oversampling=1, min_onset=-125, add_regs=short_chans.get_data().T, add_reg_names=short_chans.ch_names ) else: design_matrix = make_first_level_design_matrix( raw=raw_haemo, hrf_model='fir', stim_dur=0.5, fir_delays=range(15), drift_model='cosine', high_pass=0.01, oversampling=1, min_onset=-125, ) print(design_matrix.head()) print(design_matrix.columns) fig, ax1 = plt.subplots(figsize=(10, 6), constrained_layout=True) _ = plot_design_matrix(design_matrix, axes=ax1) return design_matrix, fig def generate_montage_locations(): """Get standard MNI montage locations in dataframe. Data is returned in the same format as the eeg_positions library. """ # standard_1020 and standard_1005 are in MNI (fsaverage) space already, # but we need to undo the scaling that head_scale will do montage = make_standard_montage( "standard_1005", head_size=0.09700884729534559 ) for d in montage.dig: d["coord_frame"] = 2003 montage.dig[:] = montage.dig[3:] montage.add_mni_fiducials() # now in fsaverage space coords = pd.DataFrame.from_dict(montage.get_positions()["ch_pos"]).T coords["label"] = coords.index coords = coords.rename(columns={0: "x", 1: "y", 2: "z"}) return coords.reset_index(drop=True) def _find_closest_standard_location(position, reference, *, out="label"): """Return closest montage label to coordinates. 
Parameters ---------- position : array, shape (3,) Coordinates. reference : dataframe As generated by _generate_montage_locations. trans_pos : str Apply a transformation to positions to specified frame. Use None for no transformation. """ p0 = np.array(position) p0.shape = (-1, 3) # head_mri_t, _ = _get_trans("fsaverage", "head", "mri") # p0 = apply_trans(head_mri_t, p0) dists = cdist(p0, np.asarray(reference[["x", "y", "z"]], float)) if out == "label": min_idx = np.argmin(dists) return reference["label"][min_idx] else: assert out == "dists" return dists def _source_detector_fold_table(raw, cidx, reference, fold_tbl, interpolate): src = raw.info["chs"][cidx]["loc"][3:6] det = raw.info["chs"][cidx]["loc"][6:9] ref_lab = list(reference["label"]) dists = _find_closest_standard_location([src, det], reference, out="dists") src_min, det_min = np.argmin(dists, axis=1) src_name, det_name = ref_lab[src_min], ref_lab[det_min] tbl = fold_tbl.query("Source == @src_name and Detector == @det_name") dist = np.linalg.norm(dists[[0, 1], [src_min, det_min]]) # Try reversing source and detector if len(tbl) == 0: tbl = fold_tbl.query("Source == @det_name and Detector == @src_name") if len(tbl) == 0 and interpolate: # Try something hopefully not too terrible: pick the one with the # smallest net distance good = np.isin(fold_tbl["Source"], reference["label"]) & np.isin( fold_tbl["Detector"], reference["label"] ) assert good.any() tbl = fold_tbl[good] assert len(tbl) src_idx = [ref_lab.index(src) for src in tbl["Source"]] det_idx = [ref_lab.index(det) for det in tbl["Detector"]] # Original tot_dist = np.linalg.norm([dists[0, src_idx], dists[1, det_idx]], axis=0) assert tot_dist.shape == (len(tbl),) idx = np.argmin(tot_dist) dist_1 = tot_dist[idx] src_1, det_1 = ref_lab[src_idx[idx]], ref_lab[det_idx[idx]] # And the reverse tot_dist = np.linalg.norm([dists[0, det_idx], dists[1, src_idx]], axis=0) idx = np.argmin(tot_dist) dist_2 = tot_dist[idx] src_2, det_2 = ref_lab[det_idx[idx]], ref_lab[src_idx[idx]] if dist_1 < dist_2: new_dist, src_use, det_use = dist_1, src_1, det_1 else: new_dist, src_use, det_use = dist_2, det_2, src_2 tbl = fold_tbl.query("Source == @src_use and Detector == @det_use") tbl = tbl.copy() tbl["BestSource"] = src_name tbl["BestDetector"] = det_name tbl["BestMatchDistance"] = dist tbl["MatchDistance"] = new_dist assert len(tbl) else: tbl = tbl.copy() tbl["BestSource"] = src_name tbl["BestDetector"] = det_name tbl["BestMatchDistance"] = dist tbl["MatchDistance"] = dist tbl = tbl.copy() # don't get warnings about setting values later return tbl def _read_fold_xls(fname, atlas="Juelich"): """Read fOLD toolbox xls file. The values are then manipulated in to a tidy dataframe. Note the xls files are not included as no license is provided. Parameters ---------- fname : str Path to xls file. atlas : str Requested atlas. 
""" page_reference = {"AAL2": 2, "AICHA": 5, "Brodmann": 8, "Juelich": 11, "Loni": 14} tbl = pd.read_excel(fname, sheet_name=page_reference[atlas]) # Remove the spacing between rows empty_rows = np.where(np.isnan(tbl["Specificity"]))[0] tbl = tbl.drop(empty_rows).reset_index(drop=True) # Empty values in the table mean its the same as above for row_idx in range(1, tbl.shape[0]): for col_idx, col in enumerate(tbl.columns): if not isinstance(tbl[col][row_idx], str): if np.isnan(tbl[col][row_idx]): tbl.iloc[row_idx, col_idx] = tbl.iloc[row_idx - 1, col_idx] tbl["Specificity"] = tbl["Specificity"] * 100 tbl["brainSens"] = tbl["brainSens"] * 100 return tbl def _check_load_fold(fold_files, atlas): # _validate_type(fold_files, (list, "path-like", None), "fold_files") if fold_files is None: fold_files = get_config("MNE_NIRS_FOLD_PATH") if fold_files is None: raise ValueError( "MNE_NIRS_FOLD_PATH not set, either set it using " "mne.set_config or pass fold_files as str or list" ) if not isinstance(fold_files, list): # path-like fold_files = _check_fname( fold_files, overwrite="read", must_exist=True, name="fold_files", need_dir=True, ) fold_files = [op.join(fold_files, f"10-{x}.xls") for x in (5, 10)] fold_tbl = pd.DataFrame() for fi, fname in enumerate(fold_files): fname = _check_fname( fname, overwrite="read", must_exist=True, name=f"fold_files[{fi}]" ) fold_tbl = pd.concat( [fold_tbl, _read_fold_xls(fname, atlas=atlas)], ignore_index=True ) return fold_tbl def fold_channel_specificity_normal(raw, fold_files=None, atlas="Juelich", interpolate=False): """Return the landmarks and specificity a channel is sensitive to. Parameters """ # noqa: E501 _validate_type(raw, BaseRaw, "raw") reference_locations = generate_montage_locations() fold_tbl = _check_load_fold(fold_files, atlas) chan_spec = list() for cidx in range(len(raw.ch_names)): tbl = _source_detector_fold_table( raw, cidx, reference_locations, fold_tbl, interpolate ) chan_spec.append(tbl.reset_index(drop=True)) return chan_spec def resource_path(relative_path): """ Get absolute path to resource regardless of running directly or packaged using PyInstaller """ if hasattr(sys, '_MEIPASS'): # PyInstaller bundle path base_path = sys._MEIPASS else: base_path = os.path.abspath(".") return os.path.join(base_path, relative_path) def fold_channels(raw: BaseRaw) -> None: # if getattr(sys, 'frozen', False): path = os.path.expanduser("~") + "/mne_data/fOLD/fOLD-public-master/Supplementary" logger.info(path) set_config('MNE_NIRS_FOLD_PATH', resource_path(path)) # type: ignore # # Locate the fOLD excel files # else: # logger.info("yabba") # set_config('MNE_NIRS_FOLD_PATH', resource_path("../../mne_data/fOLD/fOLD-public-master/Supplementary")) # type: ignore output = None # List to store the results landmark_specificity_data: list[dict[str, Any]] = [] # Filter the data to only what we want hbo_channel_names = cast(list[str], getattr(raw.copy().pick(picks='hbo'), "ch_names")) # type: ignore # Format the output to make it slightly easier to read if True: num_channels = len(hbo_channel_names) rows, cols = 4, 7 # 6 rows and 4 columns of pie charts fig, axes = plt.subplots(rows, cols, figsize=(16, 10), constrained_layout=True) axes = axes.flatten() # Flatten the axes array for easier indexing # If more pie charts than subplots, create extra subplots if num_channels > rows * cols: fig, axes = plt.subplots((num_channels // cols) + 1, cols, figsize=(16, 10), constrained_layout=True) axes = axes.flatten() # Create a list for consistent color mapping landmarks = [ "1 - 
Primary Somatosensory Cortex", "2 - Primary Somatosensory Cortex", "3 - Primary Somatosensory Cortex", "4 - Primary Motor Cortex", "5 - Somatosensory Association Cortex", "6 - Pre-Motor and Supplementary Motor Cortex", "7 - Somatosensory Association Cortex", "8 - Includes Frontal eye fields", "9 - Dorsolateral prefrontal cortex", "10 - Frontopolar area", "11 - Orbitofrontal area", "17 - Primary Visual Cortex (V1)", "18 - Visual Association Cortex (V2)", "19 - V3", "20 - Inferior Temporal gyrus", "21 - Middle Temporal gyrus", "22 - Superior Temporal Gyrus", "23 - Ventral Posterior cingulate cortex", "24 - Ventral Anterior cingulate cortex", "25 - Subgenual cortex", "32 - Dorsal anterior cingulate cortex", "37 - Fusiform gyrus", "38 - Temporopolar area", "39 - Angular gyrus, part of Wernicke's area", "40 - Supramarginal gyrus part of Wernicke's area", "41 - Primary and Auditory Association Cortex", "42 - Primary and Auditory Association Cortex", "43 - Subcentral area", "44 - pars opercularis, part of Broca's area", "45 - pars triangularis Broca's area", "46 - Dorsolateral prefrontal cortex", "47 - Inferior prefrontal gyrus", "48 - Retrosubicular area", "Brain_Outside", ] cmap1 = plt.get_cmap('tab20') # First 20 colors cmap2 = plt.get_cmap('tab20b') # Next 20 colors # Combine the colors from both colormaps colors = [cmap1(i) for i in range(20)] + [cmap2(i) for i in range(20)] # Total 40 colors landmarks.sort(key=lambda x: (int(x.split(" - ")[0]) if x.split(" - ")[0].isdigit() else float('inf'))) landmark_color_map = {landmark: colors[i % len(colors)] for i, landmark in enumerate(landmarks)} # Iterate over each channel for idx, channel_name in enumerate(hbo_channel_names): # Run the fOLD on the selected channel channel_data = raw.copy().pick(picks=channel_name) # type: ignore output = cast(list[DataFrame], fold_channel_specificity_normal(channel_data, interpolate=True, atlas='Brodmann')) # Process each DataFrame that fold_channel_specificity returns for df_data in output: # Extract the relevant columns useful_data = df_data[['Landmark', 'Specificity']] # Store the results landmark_specificity_data.append({ 'Channel': channel_name, 'Data': useful_data, }) # Plot the results # TODO: Fix this if True: unique_landmarks = sorted(useful_data['Landmark'].unique()) color_list = [landmark_color_map[landmark] for landmark in useful_data['Landmark']] # Plot specificity for each channel ax = axes[idx] labels = [f'{landmark.split(" - ")[0]}' if landmark != 'Brain_Outside' else 'B' for landmark in useful_data['Landmark']] wedges, texts, autotexts = ax.pie( useful_data['Specificity'], autopct='%1.1f%%', startangle=90, labels=labels, labeldistance=1.05, colors=color_list) ax.set_title(f'{channel_name}') ax.axis('equal') landmark_specificity_data = [] # TODO: Fix this if True: handles = [ plt.Line2D([0], [0], marker='o', color='w', label=landmark, markersize=10, markerfacecolor=landmark_color_map[landmark]) for landmark in landmarks ] n_landmarks = len(landmarks) # Calculate the figure size based on number of rows and columns fig_width = 5 fig_height = n_landmarks / 4 # Create a new figure window for the legend legend_fig = plt.figure(figsize=(fig_width, fig_height)) legend_axes = legend_fig.add_subplot(111) legend_axes.axis('off') # Turn off axis for the legend window legend_axes.legend(handles=handles, loc='center', fontsize=10, title="Landmarks") for ax in axes[len(hbo_channel_names):]: ax.axis('off') plt.show() return fig, legend_fig def individual_significance(raw_haemo, glm_est): 
fig_individual_significances = [] # List to store figures # TODO: BAD! cha = glm_est.to_dataframe() unique_annotations = set(raw_haemo.annotations.description) for cond in unique_annotations: ch_summary = cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python') print(ch_summary.head()) channel_averages = ch_summary.groupby('ch_name')['theta'].mean().reset_index() print(channel_averages.head()) activity_ch_summary = ch_summary.query( f"Chroma == 'hbo' and Condition.str.startswith('{cond}_delay_')", engine='python' ) # Function to correct p-values per channel def fdr_correct_per_channel(df): df = df.copy() df['pval_fdr'] = multipletests(df['p_value'], method='fdr_bh')[1] return df # Apply FDR correction grouped by channel corrected = activity_ch_summary.groupby("ch_name", group_keys=False).apply(fdr_correct_per_channel) # Determine which channels are significant across any delay sig_channels = ( corrected.groupby('ch_name') .apply(lambda df: (df['pval_fdr'] < 0.05).any()) .reset_index(name='significant') ) # Merge with mean theta (optional for plotting) mean_theta = activity_ch_summary.groupby('ch_name')['theta'].mean().reset_index() sig_channels = sig_channels.merge(mean_theta, on='ch_name') print(sig_channels) # For example, take the minimum corrected p-value per channel summary_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index() print(summary_pvals) def parse_ch_name(ch_name): # Extract numbers after S and D in names like 'S10_D5 hbo' match = re.match(r'S(\d+)_D(\d+)', ch_name) if match: return int(match.group(1)), int(match.group(2)) else: return None, None min_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index() # Merge the real p-values into sig_channels / avg_df avg_df = sig_channels.merge(min_pvals, on='ch_name') # Rename columns for consistency avg_df = avg_df.rename(columns={'theta': 't_or_theta', 'pval_fdr': 'p_value'}) # Add Source and Detector columns again avg_df['Source'], avg_df['Detector'] = zip(*avg_df['ch_name'].map(parse_ch_name)) # Keep relevant columns avg_df = avg_df[['Source', 'Detector', 't_or_theta', 'p_value']].dropna() ABS_SIGNIFICANCE_THETA_VALUE = 1 ABS_SIGNIFICANCE_T_VALUE = 1 P_THRESHOLD = 0.05 SOURCE_DETECTOR_SEPARATOR = "_" t_or_theta = 'theta' for _, row in avg_df.iterrows(): # type: ignore print(f"Source {row['Source']} <-> Detector {row['Detector']}: " f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}") # Extract the cource and detector positions from raw src_pos: dict[int, tuple[float, float]] = {} det_pos: dict[int, tuple[float, float]] = {} for ch in getattr(raw_haemo, "info")["chs"]: ch_name = ch['ch_name'] if not ch_name or not ch['loc'].any(): continue parts = ch_name.split()[0] src_str, det_str = parts.split(SOURCE_DETECTOR_SEPARATOR) src_num = int(src_str[1:]) det_num = int(det_str[1:]) src_pos[src_num] = ch['loc'][3:5] det_pos[det_num] = ch['loc'][6:8] # Set up the plot fig, ax = plt.subplots(figsize=(8, 6)) # type: ignore # Plot the sources for pos in src_pos.values(): ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3) # type: ignore # Plot the detectors for pos in det_pos.values(): ax.scatter(pos[0], pos[1], s=120, c='k', marker='s', edgecolors='white', linewidths=1, zorder=3) # type: ignore # Ensure that the colors stay within the boundaries even if they are over or under the max/min values if t_or_theta == 't': norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_T_VALUE, 
vmax=ABS_SIGNIFICANCE_T_VALUE) elif t_or_theta == 'theta': norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_THETA_VALUE, vmax=ABS_SIGNIFICANCE_THETA_VALUE) cmap: mcolors.Colormap = plt.get_cmap('seismic') # Plot connections with avg t-values for row in avg_df.itertuples(): src: int = cast(int, row.Source) # type: ignore det: int = cast(int, row.Detector) # type: ignore tval: float = cast(float, row.t_or_theta) # type: ignore pval: float = cast(float, row.p_value) # type: ignore if src in src_pos and det in det_pos: x = [src_pos[src][0], det_pos[det][0]] y = [src_pos[src][1], det_pos[det][1]] style = '-' if pval <= P_THRESHOLD else '--' ax.plot(x, y, linestyle=style, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2) # type: ignore # Format the Colorbar sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) sm.set_array([]) cbar = plt.colorbar(sm, ax=ax, shrink=0.85) # type: ignore cbar.set_label(f'Average {cond} {t_or_theta} value (hbo)', fontsize=11) # type: ignore # Formatting the subplots ax.set_aspect('equal') ax.set_title(f"Average {t_or_theta} values for {cond} (HbO)", fontsize=14) # type: ignore ax.set_xlabel('X position (m)', fontsize=11) # type: ignore ax.set_ylabel('Y position (m)', fontsize=11) # type: ignore ax.grid(True, alpha=0.3) # type: ignore # Set axis limits to be 1cm more than the optode positions all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()] all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()] ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01) ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01) fig.tight_layout() fig_individual_significances.append((f"Condition {cond}", fig)) return fig_individual_significances # TODO: Hardcoded def group_significance( raw_haemo, all_cha: pd.DataFrame, condition: str, correction: str = "fdr_bh" ) -> plt.Figure: """ Compute group-level significance using weighted Stouffer's method and plot results. Args: raw_haemo: Raw haemoglobin MNE object (used for optode positions) all_cha: DataFrame with columns including 'ID', 'Condition', 'p_value', 'theta', 'df', 'ch_name', 'Chroma' condition: condition prefix, e.g., 'Activity' correction: p-value correction method ('fdr_bh' or 'bonferroni') Returns: Matplotlib Figure with group-level theta values and significance. 
""" assert "ID" in all_cha.columns, "'ID' column missing in input data" assert len(raw_haemo) >= 1, "At least one raw haemoglobin object is required" condition_prefix = f"{condition}_delay" # Filter relevant data ch_summary = all_cha.query( "Condition.str.startswith(@condition_prefix) and Chroma == 'hbo'", engine='python' ).copy() logger.info("=== ch_summary head ===") logger.info(ch_summary.head()) logger.info("\nSummary stats:") logger.info(f"Total rows: {len(ch_summary)}") logger.info(f"Unique subjects: {ch_summary['ID'].nunique() if 'ID' in ch_summary.columns else 'ID column missing'}") logger.info(f"Unique conditions: {ch_summary['Condition'].unique()}") logger.info(f"Unique channels (Source-Detector pairs): {ch_summary.groupby(['Source', 'Detector']).ngroups}") logger.info("\nSample p_values:") logger.info(ch_summary['p_value'].describe()) if ch_summary.empty: raise ValueError(f"No data found for condition prefix: {condition_prefix}") # --- For debugging logger.info(f"Total rows after filtering for condition '{condition_prefix}': {len(ch_summary)}") logger.info(f"Unique channels: {ch_summary['ch_name'].nunique()}") logger.info(f"Participants: {ch_summary['ID'].nunique()}") # Step 1: Select the peak regressor (~6s after stimulus onset) peak_regressor = f"{condition}_delay_6" peak_data = ch_summary[ch_summary["Condition"] == peak_regressor].copy() logger.info(f"\n=== Logging all values for {peak_regressor} ===") for row in peak_data.itertuples(index=False): logger.info( f"Subject: {row.ID}, " f"Channel: {row.ch_name}, " f"Source: {row.Source}, Detector: {row.Detector}, " f"theta: {row.theta:.4f}, " f"p_value: {row.p_value:.6f}, " f"df: {row.df}" ) if peak_data.empty: raise ValueError(f"No data found for peak regressor: {peak_regressor}") # Step 2: Combine per-channel stats across subjects group_results = [] for (src, det), group in peak_data.groupby(["Source", "Detector"]): pvals = group["p_value"].values thetas = group["theta"].values dfs = group["df"].values # Weighted Stouffer's method weights = np.sqrt(dfs) z_scores = norm.isf(pvals) combined_z = np.sum(weights * z_scores) / np.sqrt(np.sum(weights**2)) combined_p = norm.sf(combined_z) theta_avg = np.average(thetas, weights=weights) group_results.append({ "Source": src, "Detector": det, "theta_avg": theta_avg, "combined_p": combined_p }) # Step 3: Create combined_df combined_df = pd.DataFrame(group_results) # Step 4: Multiple comparisons correction _, pvals_corr, _, significant = multipletests( combined_df["combined_p"], alpha=0.05, method=correction ) combined_df["pval_corr"] = pvals_corr combined_df["significant"] = significant logger.info(f"Used peak regressor: {peak_regressor}") logger.info(f"Channels tested: {len(combined_df)}") logger.info(f"Significant channels after correction: {combined_df['significant'].sum()}") # Get optode positions from the first raw file raw = raw_haemo src_pos, det_pos = {}, {} for ch in raw.info["chs"]: ch_name = ch["ch_name"] if not ch_name or not ch["loc"].any(): continue parts = ch_name.split()[0] src_str, det_str = parts.split("_") src_num = int(src_str[1:]) det_num = int(det_str[1:]) src_pos[src_num] = ch["loc"][3:5] det_pos[det_num] = ch["loc"][6:8] # Plotting parameters ABS_SIGNIFICANCE_THETA_VALUE = 1 P_THRESHOLD = 0.05 cmap = plt.get_cmap("seismic") norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_THETA_VALUE, vmax=ABS_SIGNIFICANCE_THETA_VALUE) fig, ax = plt.subplots(figsize=(8, 6)) # Plot optodes for pos in src_pos.values(): ax.scatter(*pos, s=120, c="k", marker="o", edgecolors="white", 
linewidths=1, zorder=3) for pos in det_pos.values(): ax.scatter(*pos, s=120, c="k", marker="s", edgecolors="white", linewidths=1, zorder=3) # Plot connections colored by average theta, solid if significant for row in combined_df.itertuples(): src, det = int(row.Source), int(row.Detector) tval, pval = row.theta_avg, row.pval_corr if src in src_pos and det in det_pos: x = [src_pos[src][0], det_pos[det][0]] y = [src_pos[src][1], det_pos[det][1]] linestyle = "-" if pval <= P_THRESHOLD else "--" ax.plot(x, y, linestyle=linestyle, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2) # Colorbar sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) sm.set_array([]) cbar = plt.colorbar(sm, ax=ax, shrink=0.85) cbar.set_label(f"Average {condition_prefix.rstrip('_')} θ-value (HbO)", fontsize=11) # Format axes ax.set_aspect("equal") ax.set_title(f"Group-level θ-values for {condition_prefix.rstrip('_')} (HbO)", fontsize=14) ax.set_xlabel("X position (m)", fontsize=11) ax.set_ylabel("Y position (m)", fontsize=11) ax.grid(True, alpha=0.3) all_x = [p[0] for p in src_pos.values()] + [p[0] for p in det_pos.values()] all_y = [p[1] for p in src_pos.values()] + [p[1] for p in det_pos.values()] ax.set_xlim(min(all_x) - 0.01, max(all_x) + 0.01) ax.set_ylim(min(all_y) - 0.01, max(all_y) + 0.01) fig.tight_layout() fig.show() def plot_glm_results(file_path, raw_haemo, glm_est, design_matrix): fig_glms = [] # List to store figures dm = design_matrix.copy() logger.info(design_matrix.shape) logger.info(design_matrix.columns) logger.info(design_matrix.head()) rois = dict(AllChannels=range(len(raw_haemo.ch_names))) conditions = design_matrix.columns df_individual = glm_est.to_dataframe_region_of_interest(rois, conditions) df_individual["ID"] = file_path # df_individual["theta"] = [t * 1.0e6 for t in df_individual["theta"]] first_onset_for_cond = {} for onset, desc in zip(raw_haemo.annotations.onset, raw_haemo.annotations.description): if desc not in first_onset_for_cond: first_onset_for_cond[desc] = onset # Get unique condition names from annotations (descriptions) unique_annotations = set(raw_haemo.annotations.description) for cond in unique_annotations: logger.info(cond) df_individual_filtered = df_individual.copy() # Filter for the condition of interest and FIR delays df_individual_filtered["isCondition"] = [cond in n for n in df_individual_filtered["Condition"]] df_individual_filtered["isDelay"] = ["delay" in n for n in df_individual_filtered["Condition"]] df_individual_filtered = df_individual_filtered.query("isDelay and isCondition") # Remove other conditions from design matrix dm_condition_cols = [col for col in dm.columns if cond in col] dm_cond = dm[dm_condition_cols] # Add a numeric delay column def extract_delay_number(condition_str): # Extracts the number at the end of a string like 'Activity_delay_5' return int(condition_str.split("_")[-1]) df_individual_filtered["DelayNum"] = df_individual_filtered["Condition"].apply(extract_delay_number) # Now separate and sort using numeric delay df_hbo = df_individual_filtered[df_individual_filtered["Chroma"] == "hbo"].sort_values("DelayNum") df_hbr = df_individual_filtered[df_individual_filtered["Chroma"] == "hbr"].sort_values("DelayNum") vals_hbo = df_hbo["theta"].values vals_hbr = df_hbr["theta"].values # Create the plot fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(19, 10)) # Scale design matrix components using numpy arrays instead of pandas operations dm_cond_values = dm_cond.values dm_cond_scaled_hbo = dm_cond_values * vals_hbo.reshape(1, -1) 
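
        # Each FIR delay regressor is scaled by its estimated beta (theta). Summing the scaled
        # columns across delays reconstructs the evoked response for this condition,
        # i.e. response(t) = sum_k theta_k * x_k(t), which is what the third panel plots below.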
dm_cond_scaled_hbr = dm_cond_values * vals_hbr.reshape(1, -1) # Create time axis relative to stimulus onset time = dm_cond.index - np.ceil(first_onset_for_cond.get(cond, 0)) # Plot axes[0].plot(time, dm_cond_values) axes[1].plot(time, dm_cond_scaled_hbo) axes[2].plot(time, np.sum(dm_cond_scaled_hbo, axis=1), 'r') axes[2].plot(time, np.sum(dm_cond_scaled_hbr, axis=1), 'b') # Format plots for ax in range(3): axes[ax].set_xlim(-5, 25) axes[ax].set_xlabel("Time (s)") axes[0].set_ylim(-0.2, 1.2) axes[1].set_ylim(-0.5, 1) axes[2].set_ylim(-0.5, 1) axes[0].set_title(f"FIR Model (Unscaled)") axes[1].set_title(f"FIR Components (Scaled by {cond} GLM Estimates)") axes[2].set_title(f"Evoked Response ({cond})") axes[0].set_ylabel("FIR Model") axes[1].set_ylabel("Oxyhaemoglobin (ΔμMol)") axes[2].set_ylabel("Haemoglobin (ΔμMol)") axes[2].legend(["Oxyhaemoglobin", "Deoxyhaemoglobin"]) print(f"Number of FIR bins: {len(vals_hbo)}") print(f"Mean theta (HbO): {np.mean(vals_hbo):.4f}") print(f"Sum of theta (HbO): {np.sum(vals_hbo):.4f}") print(f"Mean theta (HbR): {np.mean(vals_hbr):.4f}") print(f"Sum of theta (HbR): {np.sum(vals_hbr):.4f}") fig_glms.append((f"Condition {cond}", fig)) return fig_glms def plot_3d_evoked_array( inst: Union[BaseRaw, EvokedArray, Info], statsmodel_df: DataFrame, picks: Optional[Union[str, list[str]]] = "hbo", value: str = "Coef.", background: str = "w", figure: Optional[object] = None, clim: Union[str, dict[str, Union[str, list[float]]]] = "auto", mode: str = "weighted", colormap: str = "RdBu_r", surface: str = "pial", hemi: str = "both", size: int = 800, view: Optional[Union[str, dict[str, float]]] = None, colorbar: bool = True, distance: float = 0.03, subjects_dir: Optional[str] = None, src: Optional[SourceSpaces] = None, verbose: bool = False, ) -> Brain: '''Ported from MNE''' info: Info = cast(Info, deepcopy(inst if isinstance(inst, Info) else inst.info)) # type: ignore if not (getattr(info, "ch_names") == list(statsmodel_df["ch_name"].values)): # type: ignore raise RuntimeError( 'MNE data structure does not match dataframe ' f'results.\nMNE = {getattr(info, "ch_names")}.\n' f'GLM = {list(statsmodel_df["ch_name"].values)}' # type: ignore ) ea = EvokedArray(np.tile(statsmodel_df[value].values.T, (1, 1)).T, info.copy()) # type: ignore # TODO: mimic behaviour of other MNE-NIRS glm plotting options if picks is not None: ea = ea.pick(picks=picks) # type: ignore if subjects_dir is None: subjects_dir = os.environ["SUBJECTS_DIR"] if src is None: fname_src_fs = os.path.join( subjects_dir, "fsaverage", "bem", "fsaverage-ico-5-src.fif" ) src = read_source_spaces(fname_src_fs, verbose=verbose) picks = getattr(ea, "info")["ch_names"] # Set coord frame for idx in range(len(getattr(ea, "ch_names"))): getattr(ea, "info")["chs"][idx]["coord_frame"] = 4 # Generate source estimate kwargs = dict( evoked=ea, subject="fsaverage", trans=Transform('head', 'mri', np.eye(4)), distance=distance, mode=mode, surface=surface, subjects_dir=subjects_dir, src=src, project=True, ) stc = stc_near_sensors(picks=picks, **kwargs, verbose=verbose) # type: ignore assert isinstance(stc, SourceEstimate) # Produce brain plot brain: Brain = stc.plot( # type: ignore src=src, subjects_dir=subjects_dir, hemi=hemi, surface=surface, initial_time=0, clim=clim, # type: ignore size=size, colormap=colormap, figure=figure, background=background, colorbar=colorbar, verbose=verbose, ) if view is not None: brain.show_view(view) # type: ignore return brain def brain_3d_visualization(raw_haemo, df_cha, selected_event, t_or_theta: 
Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True, brain_bounds: float = 1.0) -> None: clim = dict(kind="value", pos_lims=(0, brain_bounds/2, brain_bounds)) # Get all activity conditions for cond in [f'{selected_event}']: if True: ch_summary = df_cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python') # type: ignore # Use ordinary least squares (OLS) if only one participant # TODO: Fix. if True: # t values if t_or_theta == 't': ch_model = smf.ols("t ~ -1 + ch_name", ch_summary).fit() # type: ignore # theta values elif t_or_theta == 'theta': ch_model = smf.ols("theta ~ -1 + ch_name", ch_summary).fit() # type: ignore print("OLS model is being used as there is only one participant!") # Convert model results model_df = cast(DataFrame, statsmodels_to_results(ch_model, order=ch_summary["ch_name"].unique())) # type: ignore valid_channels = ch_summary["ch_name"].unique().tolist() # type: ignore raw_for_plot = raw_haemo.copy().pick(picks=valid_channels) # type: ignore brain = plot_3d_evoked_array(raw_for_plot.pick(picks="hbo"), model_df, view="dorsal", distance=0.02, colorbar=True, clim=clim, mode="weighted", size=(800, 700)) # type: ignore if show_optodes == 'all' or show_optodes == 'sensors': brain.add_sensors(getattr(raw_for_plot, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore if True: display_text = ('Folder: ' + '\nGroup: ' + '\nCondition: '+ cond + '\nShort Channel Regression: ' + '\nLooking at: ' + t_or_theta + ' values') + '\nBrain Distance: ' # Apply the text onto the brain if show_text: brain.add_text(0.12, 0.64, display_text, "title", font_size=11, color="k") # type: ignore return brain def brain_landmarks_3d(raw_haemo: BaseRaw, show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_brodmann: bool = True) -> None: brain = Brain("fsaverage", background="white", size=(800, 700)) # type: ignore distances = source_detector_distances(raw_haemo.info) # Add optode text labels manually if show_optodes == 'all' or show_optodes == 'sensors': brain.add_sensors(getattr(raw_haemo, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore if show_optodes == 'all' or show_optodes == 'labels': labeled_srcs = set() labeled_dets = set() label_counts = {} for idx, ch in enumerate(raw_haemo.info['chs']): ch_name = ch['ch_name'] if not ch_name.endswith('hbo'): continue loc = ch['loc'] logger.info(f"Channel: {ch_name}") logger.info(f"loc length: {len(loc)}") logger.info("loc contents:") for i, val in enumerate(loc): logger.info(f" loc[{i}]: {val}") logger.info("-" * 30) if not ch_name or not ch['loc'].any(): continue parts = ch_name.split()[0] src_str, det_str = parts.split('_') src_num = int(src_str[1:]) det_num = int(det_str[1:]) if src_num not in labeled_srcs: src_xyz = ch['loc'][3:6] * 1000 brain._renderer.text3d(src_xyz[0], src_xyz[1], src_xyz[2], src_str, color='red', scale=0.002) labeled_srcs.add(src_num) if det_num not in labeled_dets: det_xyz = ch['loc'][6:9] * 1000 brain._renderer.text3d(det_xyz[0], det_xyz[1], det_xyz[2], det_str, color='blue', scale=0.002) labeled_dets.add(det_num) # Get the source-detector distance for this channel (in meters) dist_m = distances[idx] dist_mm = dist_m * 1000 label_text = f"{dist_mm:.1f} mm" label_counts[label_text] = label_counts.get(label_text, 0) + 1 if 
label_counts[label_text] > 1: label_text += f" ({label_counts[label_text]})" # Label at channel midpoint mid_xyz = loc[0:3] * 1000 logger.info(f"Channel: {ch_name} | Midpoint (mm): x={mid_xyz[0]:.2f}, y={mid_xyz[1]:.2f}, z={mid_xyz[2]:.2f} | Distance: {dist_mm:.1f} mm") brain._renderer.text3d( mid_xyz[0], mid_xyz[1], mid_xyz[2], label_text, color='gray', scale=0.002 ) if show_brodmann:# Add Brodmann labels labels = cast(list[Label], read_labels_from_annot("fsaverage", "PALS_B12_Brodmann", "lh", verbose=False)) # type: ignore label_colors = { "Brodmann.1-lh": "red", "Brodmann.2-lh": "red", "Brodmann.3-lh": "red", "Brodmann.4-lh": "orange", "Brodmann.5-lh": "green", "Brodmann.6-lh": "yellow", "Brodmann.7-lh": "green", "Brodmann.17-lh": "blue", "Brodmann.18-lh": "blue", "Brodmann.19-lh": "blue", "Brodmann.39-lh": "pink", "Brodmann.40-lh": "purple", "Brodmann.42-lh": "white", "Brodmann.44-lh": "white", "Brodmann.48-lh": "white", } for label in labels: name = getattr(label, "name", None) if not isinstance(name, str): continue if name in label_colors: brain.add_label(label, borders=False, color=label_colors[name]) # type: ignore return brain def verify_channel_positions(data: BaseRaw) -> None: """ Visualizes the sensor/channel positions of the raw data for verification. Parameters ---------- data : BaseRaw The loaded data object to process. """ def convert_fig_dict_to_png_bytes(fig_dict: dict[str, Figure]) -> dict[str, bytes]: png_dict = {} for label, fig in fig_dict.items(): buf = BytesIO() fig.savefig(buf, format="png", bbox_inches="tight") buf.seek(0) png_dict[label] = buf.read() plt.close(fig) return png_dict def brain_3d_contrast(con_model_df: DataFrame, con_model_df_filtered: BaseRaw, common_channels: list[str], first_name: str, second_name: str, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True, brain_bounds: float = 1.0) -> None: # Filter DataFrame to only common channels, and sort by raw order con_model = con_model_df con_model["ch_name"] = pd.Categorical( con_model["ch_name"], categories=common_channels, ordered=True ) con_model = con_model.sort_values("ch_name").reset_index(drop=True) # type: ignore clim=dict(kind="value", pos_lims=(0, brain_bounds/2, brain_bounds)) # Plot brain figure brain = plot_3d_evoked_array(con_model_df_filtered.copy().pick(picks="hbo"), con_model, view="dorsal", distance=0.02, colorbar=True, mode="weighted", clim=clim, size=(800, 700), verbose=False) # type: ignore if show_optodes == 'all' or show_optodes == 'sensors': brain.add_sensors(getattr(con_model_df_filtered, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore display_text = ('Contrast: ' + first_name + ' - ' + second_name + '\nLooking at: ' + t_or_theta + ' values') # Apply the text onto the brain if show_text: brain.add_text(0.12, 0.70, display_text, "title", font_size=11, color="k") # type: ignore def plot_2d_3d_contrasts_between_groups( contrast_df_a: pd.DataFrame, contrast_df_b: pd.DataFrame, raw_haemo: BaseRaw, group_a_name: str, group_b_name: str, is_3d: bool = True, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True, brain_bounds: float = 1.0, ) -> None: logger.info("-----") contrast_df_a = contrast_df_a.copy() contrast_df_a["group"] = group_a_name contrast_df_b = contrast_df_b.copy() contrast_df_b["group"] = group_b_name logger.info("-----") 
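
    # Stack both groups into one long-format table, then fit a linear mixed-effects model
    # (effect ~ -1 + group:ch_name:Chroma) with participant ID as the grouping factor.
    # The per-channel differences of the fitted group coefficients (A - B, then B - A)
    # are what get plotted in 2D or 3D below.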
df_combined = pd.concat([contrast_df_a, contrast_df_b], ignore_index=True) con_summary = df_combined.query("Chroma == 'hbo'").copy() logger.info("-----") valid_channels = (pd.crosstab(con_summary["group"], con_summary["ch_name"]) > 1).all() valid_channels = valid_channels[valid_channels].index.tolist() con_summary = con_summary[con_summary["ch_name"].isin(valid_channels)] logger.info("-----") model_formula = "effect ~ -1 + group:ch_name:Chroma" con_model = smf.mixedlm(model_formula, con_summary, groups=con_summary["ID"]).fit(method="nm") logger.info("-----") if t_or_theta == "t": group1_vals = con_model.tvalues.filter(like=f"group[{group_a_name}]") group2_vals = con_model.tvalues.filter(like=f"group[{group_b_name}]") else: group1_vals = con_model.params.filter(like=f"group[{group_a_name}]") group2_vals = con_model.params.filter(like=f"group[{group_b_name}]") logger.info("-----") group1_channels = [name.split(":")[1].split("[")[1].split("]")[0] for name in group1_vals.index] group2_channels = [name.split(":")[1].split("[")[1].split("]")[0] for name in group2_vals.index] df_group1 = DataFrame({"Coef.": group1_vals.values}, index=group1_channels) df_group2 = DataFrame({"Coef.": group2_vals.values}, index=group2_channels) df_contrast = df_group1.join(df_group2, how="inner", lsuffix=f"_{group_a_name}", rsuffix=f"_{group_b_name}") logger.info("-----") # A - B df_contrast["Coef."] = df_contrast[f"Coef._{group_a_name}"] - df_contrast[f"Coef._{group_b_name}"] con_model_df_1_2 = DataFrame({ "ch_name": df_contrast.index, "Coef.": df_contrast["Coef."], "Chroma": "hbo" }) logger.info("-----") mne_ch_names = raw_haemo.copy().pick(picks="hbo").ch_names glm_ch_names = con_model_df_1_2["ch_name"].tolist() common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names] con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels) con_model_df_1_2 = con_model_df_1_2.set_index("ch_name").loc[common_channels].reset_index() logger.info("-----") if is_3d: brain_3d_contrast( con_model_df_1_2, con_model_df_filtered, common_channels, group_a_name, group_b_name, t_or_theta, show_optodes, show_text, brain_bounds ) else: plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_1_2, names=True, res=128, vlim=(-brain_bounds, brain_bounds)) # type: ignore # TODO: The title currently goes on the colorbar. Low priority plt.title(f"Contrast: {group_a_name} vs {group_b_name}") # type: ignore plt.show() # type: ignore # plt.title(f"Contrast: {group_a_name} vs {group_b_name}") # plt.show() # B - A df_contrast["Coef."] = df_contrast[f"Coef._{group_b_name}"] - df_contrast[f"Coef._{group_a_name}"] con_model_df_2_1 = DataFrame({ "ch_name": df_contrast.index, "Coef.": df_contrast["Coef."], "Chroma": "hbo" }) glm_ch_names = con_model_df_2_1["ch_name"].tolist() common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names] con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels) con_model_df_2_1 = con_model_df_2_1.set_index("ch_name").loc[common_channels].reset_index() if is_3d: brain_3d_contrast( con_model_df_2_1, con_model_df_filtered, common_channels, group_b_name, group_a_name, t_or_theta, show_optodes, show_text, brain_bounds ) else: plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_2_1, names=True, res=128, vlim=(-brain_bounds, brain_bounds)) # type: ignore # TODO: The title currently goes on the colorbar. 
Low priority plt.title(f"Contrast: {group_b_name} vs {group_a_name}") # type: ignore plt.show() # type: ignore def plot_fir_model_results(df, raw_haemo, dm, selected_event, l_bound, u_bound): df["isActivity"] = [f"{selected_event}" in n for n in df["Condition"]] df["isDelay"] = ["delay" in n for n in df["Condition"]] df = df.query("isDelay in [True]") df = df.query("isActivity in [True]") # Make a new column that stores the condition name for tidier model below df.loc[:, "TidyCond"] = "" df.loc[df["isActivity"] == True, "TidyCond"] = f"{selected_event}" # noqa: E712 # Finally, extract the FIR delay in to its own column in data frame df.loc[:, "delay"] = [n.split("_")[-1] for n in df.Condition] # To simplify this example we will only look at the activity # condition so we now remove the other conditions from the # design matrix and GLM results dm_cols_activity = np.where([f"{selected_event}" in c for c in dm.columns])[0] dm = dm[[dm.columns[i] for i in dm_cols_activity]] lme = smf.mixedlm("theta ~ -1 + delay:TidyCond:Chroma", df, groups=df["ID"]).fit() df_sum = statsmodels_to_results(lme) df_sum["delay"] = [int(n) for n in df_sum["delay"]] df_sum = df_sum.sort_values("delay") # Print the result for the oxyhaemoglobin data in the target condition df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbo']") fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(19, 10)) print("dm columns:", dm.columns.tolist()) # Extract design matrix columns that correspond to the condition of interest dm_cond_idxs = np.where([f"{selected_event}" in n for n in dm.columns])[0] dm_cond = dm[[dm.columns[i] for i in dm_cond_idxs]] # Extract the corresponding estimates from the lme dataframe for hbo df_hbo = df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbo']") vals_hbo = [float(v) for v in df_hbo["Coef."]] # print("--------------------------------------") # print(f"dm_cond shape: {dm_cond.shape}") # print(f"dm_cond columns: {dm_cond.columns.tolist()}") # print(f"vals_hbo length: {len(vals_hbo)}") # print(f"vals_hbo sample: {vals_hbo[:5]}") # print(f"vals_hbo type: {type(vals_hbo)}") # print(f"vals_hbo element type: {type(vals_hbo[0]) if len(vals_hbo) > 0 else 'N/A'}") dm_cond_scaled_hbo = dm_cond * vals_hbo # Extract the corresponding estimates from the lme dataframe for hbr df_hbr = df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbr']") vals_hbr = [float(v) for v in df_hbr["Coef."]] dm_cond_scaled_hbr = dm_cond * vals_hbr first_onset = None for desc, onset in zip(raw_haemo.annotations.description, raw_haemo.annotations.onset): if selected_event in desc: first_onset = onset break if first_onset is None: raise ValueError(f"Selected event '{selected_event}' not found in annotations.") # Align index values (time axis) to the first occurrence of selected_event index_values = dm_cond_scaled_hbo.index - np.ceil(first_onset) index_values = np.asarray(index_values) # Plot the result axes[0].plot(index_values, np.asarray(dm_cond)) axes[1].plot(index_values, np.asarray(dm_cond_scaled_hbo)) axes[2].plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r") axes[2].plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b") valid_mask = (index_values >= 0) & (index_values <= 15) hbo_sum_window = np.sum(dm_cond_scaled_hbo.loc[valid_mask, :], axis=1) peak_idx_in_window = np.argmax(hbo_sum_window) peak_idx = np.where(valid_mask)[0][peak_idx_in_window] peak_time = float(round(index_values[peak_idx], 2)) # type: ignore axes[2].axvline(x=peak_time, color='k', 
                    linestyle='--', linewidth=1.5, label='Peak time')  # type: ignore

    # Format the plot
    for ax in range(3):
        axes[ax].set_xlim(-5, 25)
        axes[ax].set_xlabel("Time (s)")
    axes[0].set_ylim(-0.1, 1.1)
    axes[1].set_ylim(l_bound, u_bound)
    axes[2].set_ylim(l_bound, u_bound)
    axes[0].set_title("FIR Model (Unscaled by GLM estimates)")
    axes[1].set_title(f"FIR Components (Scaled by {selected_event} GLM Estimates)")
    axes[2].set_title(f"Evoked Response {selected_event}")
    axes[0].set_ylabel("FIR Model")
    axes[1].set_ylabel("Oxyhaemoglobin (ΔμMol)")
    axes[2].set_ylabel("Haemoglobin (ΔμMol)")
    axes[2].legend(["Oxyhaemoglobin", "Deoxyhaemoglobin"])

    # We can also extract the 95% confidence intervals of the estimates too
    l95_hbo = [float(v) for v in df_hbo["[0.025"]]  # type: ignore
    u95_hbo = [float(v) for v in df_hbo["0.975]"]]  # type: ignore
    dm_cond_scaled_hbo_l95 = dm_cond * l95_hbo
    dm_cond_scaled_hbo_u95 = dm_cond * u95_hbo
    l95_hbr = [float(v) for v in df_hbr["[0.025"]]  # type: ignore
    u95_hbr = [float(v) for v in df_hbr["0.975]"]]  # type: ignore
    dm_cond_scaled_hbr_l95 = dm_cond * l95_hbr
    dm_cond_scaled_hbr_u95 = dm_cond * u95_hbr

    axes2: Axes
    fig2, axes2 = plt.subplots(nrows=1, ncols=1, figsize=(7, 7))  # type: ignore

    # Plot the result
    axes2.plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r")  # type: ignore
    axes2.plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b")  # type: ignore
    axes2.axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time')  # type: ignore
    axes2.fill_between(  # type: ignore
        index_values,
        np.asarray(np.sum(dm_cond_scaled_hbo_l95, axis=1)),
        np.asarray(np.sum(dm_cond_scaled_hbo_u95, axis=1)),
        facecolor="red",
        alpha=0.25,
    )
    axes2.fill_between(  # type: ignore
        index_values,
        np.asarray(np.sum(dm_cond_scaled_hbr_l95, axis=1)),
        np.asarray(np.sum(dm_cond_scaled_hbr_u95, axis=1)),
        facecolor="blue",
        alpha=0.25,
    )

    # Format the plot
    axes2.set_xlim(-5, 20)
    axes2.set_ylim(l_bound, u_bound)
    axes2.set_title(f"Evoked Response with 95% confidence intervals for {selected_event}")  # type: ignore
    axes2.set_ylabel("Haemoglobin (ΔμMol)")  # type: ignore
    axes2.legend(["Oxyhaemoglobin", "Deoxyhaemoglobin", f"Peak {peak_time}s"])  # type: ignore
    axes2.set_xlabel("Time (s)")  # type: ignore
    fig2.tight_layout()
    fig.show()
    fig2.show()


def load_snirf(file_path: str) -> BaseRaw:
    """
    Loads a snirf file and optionally downsamples it.

    Parameters
    ----------
    file_path : str
        Path of the snirf file to load.

    Returns
    -------
    BaseRaw
        The loaded (and optionally downsampled) data object.
""" # Read the snirf file raw = read_raw_snirf(file_path, preload=True, verbose=VERBOSITY) # type: ignore raw.load_data(verbose=VERBOSITY) # type: ignore # # Strip the specified amount of seconds from the start of the file # total_duration = getattr(raw, "times")[-1] # if total_duration > SECONDS_TO_STRIP: # raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore # logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.") # else: # logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.") # If the user forcibly dropped channels, remove them now before any processing occurs # logger.info("Checking if there are channels to forcibly drop...") # if drop_prefixes: # logger.info("Force dropped channels was specified.") # channels_to_drop = [ch for ch in cast(list[str], getattr(raw, "ch_names")) if any(ch.startswith(prefix) for prefix in drop_prefixes)] # raw.drop_channels(channels_to_drop, "raise") # type: ignore # logger.info("Force dropped channels:", channels_to_drop) # If the user wants to downsample, do it right away logger.info("Checking if we should downsample...") if DOWNSAMPLE: logger.info("Downsample was specified.") sfreq_old = getattr(raw, "info")["sfreq"] raw.resample(DOWNSAMPLE_FREQUENCY, verbose=VERBOSITY) # type: ignore sfreq_new = getattr(raw, "info")["sfreq"] logger.info(f"Finished downsampling. Old frequency: {sfreq_old}. New frequency: {sfreq_new}.") logger.info("Successfully loaded the snirf file.") return raw def run_second_level_analysis(df_contrasts, raw, p, bounds): """ Perform second-level analysis using contrast data from multiple participants. Parameters ---------- df_contrasts : pd.DataFrame Combined contrast results from multiple participants. Must include: ['ch_name', 'effect', 'ID'] Returns ------- pd.DataFrame Group-level t-values, p-values, and mean effect per channel. 
""" if not all(col in df_contrasts.columns for col in ['ch_name', 'effect', 'ID']): raise ValueError("Input DataFrame must include 'ch_name', 'effect', and 'ID' columns.") channels = df_contrasts['ch_name'].unique() group_results = [] for ch in channels: ch_data = df_contrasts[df_contrasts['ch_name'] == ch] if ch_data['ID'].nunique() < 2: logger.warning(f"Skipping channel {ch} — not enough subjects.") continue Y = ch_data['effect'].values design_matrix = np.ones((len(Y), 1)) # intercept-only model = OLSModel(design_matrix) result = model.fit(Y) t_val = result.t(0).item() p_val = 2 * stats.t.sf(np.abs(t_val), df=result.df_model) mean_beta = np.mean(Y) group_results.append({ 'ch_name': ch, 't_val': t_val, 'p_val': p_val, 'mean_beta': mean_beta, 'n_subjects': len(Y) }) df_group = pd.DataFrame(group_results) logger.info("Second-level results:\n%s", df_group) # Extract the cource and detector positions from raw src_pos: dict[int, tuple[float, float]] = {} det_pos: dict[int, tuple[float, float]] = {} for ch in getattr(raw, "info")["chs"]: ch_name = ch['ch_name'] if not ch_name or not ch['loc'].any(): continue parts = ch_name.split()[0] src_str, det_str = parts.split('_') src_num = int(src_str[1:]) det_num = int(det_str[1:]) src_pos[src_num] = ch['loc'][3:5] det_pos[det_num] = ch['loc'][6:8] # Set up the plot fig, ax = plt.subplots(figsize=(8, 6)) # type: ignore # Plot the sources for pos in src_pos.values(): ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3) # type: ignore # Plot the detectors for pos in det_pos.values(): ax.scatter(pos[0], pos[1], s=120, c='k', marker='s', edgecolors='white', linewidths=1, zorder=3) # type: ignore # Ensure that the colors stay within the boundaries even if they are over or under the max/min values norm = mcolors.Normalize(vmin=-bounds, vmax=bounds) cmap: mcolors.Colormap = plt.get_cmap('seismic') # Plot connections with avg t-values for _, row in df_group.iterrows(): ch = row['ch_name'] pval = row['p_val'] tval = row['t_val'] if '_' not in ch: logger.info(f"Skipping channel with unexpected format (no underscore): {ch}") continue src_str, det_str = ch.split('_') det_parts = det_str.split() detector_id = det_parts[0] # e.g. 
"D1" hemo_type = det_parts[1].lower() if len(det_parts) > 1 else '' logger.info(f"Parsing channel: {ch} -> src_str: {src_str}, det_str: {detector_id}, hemo_type: {hemo_type}") if hemo_type != 'hbo': logger.info(f"Skipping channel {ch} because hemo_type is not HbO: {hemo_type}") continue try: src = int(src_str[1:]) det = int(detector_id[1:]) logger.info(f"Parsed src: {src}, det: {det}") except Exception as e: logger.info(f"Error parsing source/detector from channel '{ch}': {e}") continue if src in src_pos and det in det_pos: x = [src_pos[src][0], det_pos[det][0]] y = [src_pos[src][1], det_pos[det][1]] style = '-' if pval <= p else '--' color = cmap(norm(tval)) logger.info(f"Plotting {ch}: t={tval:.2f}, p={pval:.3f}, color={color}, style={style}") ax.plot(x, y, linestyle=style, color=color, linewidth=4, alpha=0.9, zorder=2) # Format the Colorbar sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) sm.set_array([]) cbar = plt.colorbar(sm, ax=ax, shrink=0.85) # type: ignore cbar.set_label(f'Average value (hbo)', fontsize=11) # type: ignore # Formatting the subplots ax.set_aspect('equal') ax.set_title(f"Average values (HbO)", fontsize=14) # type: ignore ax.set_xlabel('X position (m)', fontsize=11) # type: ignore ax.set_ylabel('Y position (m)', fontsize=11) # type: ignore ax.grid(True, alpha=0.3) # type: ignore # Set axis limits to be 1cm more than the optode positions all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()] all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()] ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01) ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01) fig.tight_layout() plt.show() # type: ignore return df_group def calculate_dpf(file_path): # order is hbo / hbr with h5py.File(file_path, 'r') as f: wavelengths = f['/nirs/probe/wavelengths'][:] logger.info(f"Wavelengths (nm): {wavelengths}") wavelengths = sorted(wavelengths, reverse=True) age = float(AGE) logger.info(f"Their age was {AGE}") a = 223.3 b = 0.05624 c = 0.8493 d = -5.723e-7 e = 0.001245 f = -0.9025 dpf = [] for w in wavelengths: logger.info(w) dpf.append(a + b * (age**c) + d* (w**3) + e * (w**2) + f*w) logger.info(dpf) return dpf def iqr_threshold(coeffs: NDArray[float64], k: float = 1.5) -> floating[Any]: """ Calculate the interquartile range (IQR) threshold scaled by a factor, k. Parameters ---------- coeffs : NDArray[float64] Array of coefficients to compute the IQR from. k : float, optional Scaling factor for the IQR (default is 1.5). Returns ------- floating[Any] The scaled IQR threshold value. """ # Calculate the IQR q1 = np.percentile(coeffs, 25) q3 = np.percentile(coeffs, 75) iqr = q3 - q1 return k * iqr def wavelet_iqr_denoise(signal: NDArray[float64], wavelet: str = 'db4', level: int = 3) -> NDArray[float64]: """ Denoises a signal using wavelet decomposition and IQR-based thresholding on detail coefficients. Parameters ---------- signal : NDArray[float64] The input signal array to denoise. wavelet : str, optional The type of wavelet to use for decomposition (default is 'db4'). level : int, optional Decomposition level for wavelet transform (default is 3). Returns ------- NDArray[float64] The denoised signal array, with the same length as the input. 
""" # Decompose the signal using wavelet transform and initialize a list with approximation coefficients coeffs: list[NDArray[float64]] = pywt.wavedec(signal, wavelet, level=level) # type: ignore cA = coeffs[0] denoised_coeffs = [cA] # Threshold detail coefficients to reduce noise for cD in coeffs[1:]: threshold = iqr_threshold(cD, IQR) cD_thresh = np.sign(cD) * np.maximum(np.abs(cD) - threshold, 0.0) # np.where((cD < lower) | (cD > upper), 0, cD) cD_thresh = cD_thresh.astype(float64) denoised_coeffs.append(cD_thresh) # Reconstruct the denoised signal denoised_signal = cast(NDArray[float64], pywt.waverec(denoised_coeffs, wavelet)) # type: ignore return denoised_signal[:len(signal)] def calculate_and_apply_wavelet(data: BaseRaw) -> tuple[BaseRaw, Figure]: """ Applies a wavelet IQR denoising filter to the data and generates a plot. Parameters ---------- data : BaseRaw The loaded data object to process. ID : str File name of the the snirf file that was loaded. Returns ------- tuple[BaseRaw, Figure] - BaseRaw: The processed data object. - Figure: The corresponding Matplotlib figure. """ logger.info("Applying the wavelet filter...") # Denoise the data logger.info("Denoising the data...") loaded_data: NDArray[float64] = data.get_data(verbose=VERBOSITY) # type: ignore denoised_data = np.zeros_like(loaded_data) logger.info("Calculating the IQR, decomposing the signal, and thresholding the coefficients...") for ch in range(loaded_data.shape[0]): denoised_data[ch, :] = wavelet_iqr_denoise(loaded_data[ch, :], wavelet='db4', level=3) # Reconstruct the data with the annotations logger.info("Reconstructing the data with annotations...") raw_with_tddr_and_wavelet = RawArray(denoised_data, cast(Info, data.info), verbose=VERBOSITY) raw_with_tddr_and_wavelet.set_annotations(data.annotations.copy(), verbose=VERBOSITY) # type: ignore # Create a figure for the results logger.info("Creating the figure...") fig = cast(Figure, raw_with_tddr_and_wavelet.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=data.times[-1]).figure) # type: ignore fig.suptitle(f"Wavelet for ", fontsize=16) # type: ignore fig.subplots_adjust(top=0.92) plt.close(fig) logger.info("Successfully applied the wavelet filter.") return raw_with_tddr_and_wavelet, fig def short_channel_processing_for_hr(data: BaseRaw, short_chans: BaseRaw | None) -> tuple[float, NDArray[float64], NDArray[float64]]: """ Extract and trim short-channel fNIRS signal for heart rate analysis. Parameters ---------- data : BaseRaw The loaded data object to process. short_chans : BaseRaw | None Data object with only short separation channels, or None if unavailable. Returns ------- tuple[float, NDArray[float64], NDArray[float64]] - float: Sampling frequency of the signal. - NDArray[float64]: Trimmed short-channel signal. - NDArray[float64]: Corresponding time values. """ # Find the short channel (or best candidate) and extract signal data and sampling frequency logger.info("Extracting the signal and calculating the sampling frequency...") # If a short channel exists, use it for our signal. 
Otherwise just take the first channel in the data # TODO: Find a better way around this if short_chans is not None: signal = cast(NDArray[float64], short_chans.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore else: signal = cast(NDArray[float64], data.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore # Calculate the sampling frequency sfreq = cast(int, data.info['sfreq']) # Trim start and end of the signal to remove edge artifacts logger.info(f"Removing {SECONDS_TO_STRIP_HR} seconds from the beginning and end of the file...") strip_samples = int(sfreq * SECONDS_TO_STRIP_HR) signal_trimmed = signal[strip_samples:-strip_samples] times_trimmed = cast(NDArray[float64], getattr(data, "times"))[strip_samples:-strip_samples] return sfreq, signal_trimmed, times_trimmed def calculate_heart_rate_neurokit(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[float64], float]: """ Calculate and smooth heart rate from a trimmed signal using NeuroKit. Parameters ---------- sfreq : float Sampling frequency of the signal. signal_trimmed : NDArray[float64] Preprocessed and trimmed fNIRS signal. Returns ------- tuple[NDArray[float64], float] - NDArray[float64]: Smoothed heart rate time series (BPM). - float: Mean heart rate. """ logger.info("Calculating heart rate using NeuroKit...") # Filter signal to isolate heart rate frequencies and detect peaks logger.info("Filtering the signal and detecting peaks...") signal_filtered = cast(NDArray[float64], nk.signal_filter(signal_trimmed, sampling_rate=sfreq, lowcut=0.8, highcut=2.5)) # type: ignore peaks_dict = cast(dict[str, Any], nk.signal_findpeaks(signal_filtered)) # type: ignore peaks = peaks_dict['Peaks'] hr = cast(NDArray[float64], nk.signal_rate(peaks, sampling_rate=sfreq, desired_length=len(signal_trimmed))) # type: ignore hr_clean = np.clip(hr, MAX_LOW_HR, MAX_HIGH_HR) # Smooth heart rate time series by replacing spikes with local rolling mean and calculate the mean logger.info("Smoothing the signal and calculating the mean...") hr_series = pd.Series(hr_clean) local_median = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).median() spikes = hr_series > (local_median + 10) smoothed_values = hr_series.copy() smoothed_spikes = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).mean() smoothed_values[spikes] = smoothed_spikes[spikes] hr_smooth_nk = cast(NDArray[float64], smoothed_values.to_numpy()) # type: ignore mean_hr_nk = hr_smooth_nk.mean() logger.info("Original HR min/max: %f, %f", hr_clean.min(), hr_clean.max()) logger.info("Smoothed HR min/max:%f, %f", hr_smooth_nk.min(), hr_smooth_nk.max()) logger.info(f"Estimated mean HR nk: {mean_hr_nk:.1f} BPM") logger.info("Successfully calculated heart rate using NeuroKit.") return hr_smooth_nk, mean_hr_nk def calculate_heart_rate_scipy(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]: """ Estimate heart rate using spectral analysis on a high-pass filtered signal. Parameters ---------- sfreq : float Sampling frequency of the input signal. signal_trimmed : NDArray[float64] Trimmed fNIRS signal to analyze. Returns ------- tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float] - NDArray[floating[Any]]: Frequencies converted to beats per minute (BPM). - NDArray[float64]: Power spectral density (PSD) of the signal. 
- np.ndarray[Any, np.dtype[np.bool_]]: Boolean mask indicating frequencies within heart rate range (30-300 BPM). - float: Estimated mean heart rate in BPM corresponding to the PSD peak within the range. """ logger.info("Calculating heart rate using SciPy...") # Apply a high-pass Butterworth filter to remove slow trends below 0.5 Hz from the trimmed signal (actual data) logger.info("Applying a butterworth filter...") b, a = cast(tuple[NDArray[float64], NDArray[float64]], butter(2, 0.5 / (sfreq / 2), btype='high')) signal_hp = cast(NDArray[float64],filtfilt(b, a, signal_trimmed)) # Calculate the Power Spectral Density (PSD) of the filtered signal using Welch's method logger.info("Calculating the PSD...") nperseg = min(len(signal_hp), 4096) frequencies_scipy, psd_scipy = cast(tuple[NDArray[float64], NDArray[float64]], welch(signal_hp, fs=sfreq, nperseg=nperseg, noverlap=nperseg//2)) # Convert frequency values to beats per minute (BPM) and set a heart rate range (30-300 BPM) logger.info("Converting to BPM...") freq_bpm_scipy = frequencies_scipy * 60 freq_range_scipy = (freq_bpm_scipy > 30) & (freq_bpm_scipy < 300) # Identify the peak frequency within the heart rate range and estimate the mean heart rate in BPM logger.info("Finding the mean...") peak_index = np.argmax(psd_scipy[freq_range_scipy]) mean_hr_scipy = freq_bpm_scipy[freq_range_scipy][peak_index] logger.info("Successfully calculated heart rate using SciPy.") return freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy def plot_heart_rate( freq_bpm_scipy: NDArray[floating[Any]], psd_scipy: NDArray[float64], freq_range_scipy: np.ndarray[Any, np.dtype[np.bool_]], mean_hr_scipy: float, hr_smooth_nk: NDArray[floating[Any]], mean_hr_nk: float, times_trimmed: NDArray[floating[Any]], overruled: bool ) -> tuple[Figure, Figure]: """ Generate plots comparing heart rate estimates from SciPy PSD and NeuroKit2. Parameters ---------- freq_bpm_scipy : NDArray[floating[Any]] Frequencies in beats per minute from SciPy PSD analysis. psd_scipy : NDArray[float64] Power spectral density values corresponding to freq_bpm_scipy. freq_range_scipy : np.ndarray[Any, np.dtype[np.bool_]] Boolean mask indicating the heart rate frequency range used in PSD. mean_hr_scipy : float Mean heart rate estimated from SciPy PSD peak. hr_smooth_nk : NDArray[floating[Any]] Smoothed instantaneous heart rate from NeuroKit2. mean_hr_nk : float Mean heart rate estimated from NeuroKit2 data. times_trimmed : NDArray[floating[Any]] Time points corresponding to hr_smooth_nk values. overruled: bool True if the heart rate from NeuroKit2 is overriding the results from the PSD. Returns ------- tuple[Figure, Figure] - Figure showing the PSD and SciPy heart rate estimate. - Figure showing the time series comparison of heart rates. """ # Create the first plot for the PSD. Add a yellow range to show what we will be filtering to. 
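    # The shaded band spans from the lower of the two HR estimates minus HEART_RATE_WINDOW to
    # the higher estimate plus HEART_RATE_WINDOW, visualising the assumed resting-HR range.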
logger.info("Creating the figure...") fig1, ax1 = plt.subplots(figsize=(10, 5)) # type: ignore ax1.set_xlim(30, 300) ax1.plot(freq_bpm_scipy[freq_range_scipy], psd_scipy[freq_range_scipy]) # type: ignore ax1.axvline(x=mean_hr_scipy, color='red', linestyle='--', label=f'Mean HR: {mean_hr_scipy:.1f} BPM') # type: ignore ax1.axvspan(min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW), max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW), color='yellow', alpha=0.3, label=f'HR Range ±{HEART_RATE_WINDOW} BPM') # type: ignore ax1.set_xlabel('Heart Rate (BPM)') # type: ignore ax1.set_ylabel('Power Spectral Density') # type: ignore ax1.set_title('PSD of fNIRS signal - Peak indicates Heart Rate') # type: ignore ax1.grid(True) # type: ignore # Was the value we reported here correct for the data on the graph or was it overruled? if overruled: note = ( '\n' 'Note: Calculation was bad!\n' 'Data has been set to match\n' 'the value from NeuroKit2.' ) phantom = Line2D([0], [0], color='none', label=note) handles, _ = ax1.get_legend_handles_labels() ax1.legend(handles=handles + [phantom]) # type: ignore else: ax1.legend() # type: ignore plt.close(fig1) # Create the second plot showing the rolling heart rate, as well as the two averages that were calculated logger.info("Creating the figure...") fig2, ax2 = plt.subplots(figsize=(14, 6)) # type: ignore ax2.plot(times_trimmed, hr_smooth_nk, label='Instantaneous HR (NeuroKit2)', color='blue', alpha=0.7) # type: ignore ax2.axhline(mean_hr_nk, color='red', linestyle='--', label=f'Mean HR NeuroKit2: {mean_hr_nk:.1f} BPM') # type: ignore ax2.axhline(mean_hr_scipy, color='orange', linestyle=':', label=f'SciPy Welch PSD (HP filtered): {mean_hr_scipy:.1f} BPM') # type: ignore ax2.set_xlabel('Time (seconds)') # type: ignore ax2.set_ylabel('Heart Rate (BPM)') # type: ignore ax2.set_title('Heart Rate Estimates Comparison') # type: ignore ax2.legend() # type: ignore ax2.grid(True) # type: ignore fig2.tight_layout() plt.close(fig2) return fig1, fig2 def hr_calc(raw): if SHORT_CHANNEL: short_chans = get_short_channels(raw, max_dist=SHORT_CHANNEL_THRESH) else: short_chans = None sfreq, signal_trimmed, times_trimmed = short_channel_processing_for_hr(raw, short_chans) hr_smooth_nk, mean_hr_nk = calculate_heart_rate_neurokit(sfreq, signal_trimmed) freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy = calculate_heart_rate_scipy(sfreq, signal_trimmed) # HACK: This sucks but looking at the graphs I trust neurokit2 more overruled = False if mean_hr_scipy < mean_hr_nk - 15: mean_hr_scipy = mean_hr_nk overruled = True if mean_hr_scipy > mean_hr_nk + 15: mean_hr_scipy = mean_hr_nk overruled = True hr1, hr2 = plot_heart_rate(freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy, hr_smooth_nk, mean_hr_nk, times_trimmed, overruled) fig = raw.plot_psd(show=False) raw_filtered = raw.copy().filter(0.5, 3, fir_design='firwin') sfreq = raw.info['sfreq'] data = raw_filtered.get_data() channel_names = raw.ch_names # --- Parameters for PSD --- desired_bin_hz = 0.1 nperseg = int(sfreq / desired_bin_hz) hr_range = (30, 180) # --- Function to find strongest local peak --- def find_hr_from_psd(ch_data): f, Pxx = welch(ch_data, sfreq, nperseg=nperseg) mask = (f >= hr_range[0]/60) & (f <= hr_range[1]/60) f_masked = f[mask] Pxx_masked = Pxx[mask] if len(Pxx_masked) < 3: return np.nan peaks = [i for i in range(1, len(Pxx_masked)-1) if Pxx_masked[i] > Pxx_masked[i-1] and Pxx_masked[i] > Pxx_masked[i+1]] if not peaks: return np.nan best_idx = 
peaks[np.argmax([Pxx_masked[i] for i in peaks])] return f_masked[best_idx] * 60 # bpm # --- Compute HR across all channels --- hr_all_channels = np.array([find_hr_from_psd(data[i, :]) for i in range(len(channel_names))]) hr_all_channels = hr_all_channels[~np.isnan(hr_all_channels)] hr_mode = np.round(np.median(hr_all_channels)) # Use median if some NaNs print(f"Estimated Heart Rate: {hr_mode} bpm") hr_freq = hr_mode / 60 # Hz low = hr_freq - 0.3 high = hr_freq + 0.3 return fig, hr1, hr2, low, high def process_participant(file_path, progress_callback=None): fig_individual: dict[str, Figure] = {} # Step 1: Load raw = load_snirf(file_path) fig_raw = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Loaded Raw", show=False) fig_individual["Loaded Raw"] = fig_raw if progress_callback: progress_callback(1) logger.info("1") if hasattr(raw, 'annotations') and len(raw.annotations) > 0: # Get time of first event first_event_time = raw.annotations.onset[0] trim_time = max(0, first_event_time - 5.0) # Ensure we don't go negative raw.crop(tmin=trim_time) # Shift annotation onsets to match new t=0 import mne ann = raw.annotations ann_shifted = mne.Annotations( onset=ann.onset - trim_time, # shift to start at zero duration=ann.duration, description=ann.description ) data = raw.get_data() info = raw.info.copy() raw = mne.io.RawArray(data, info) raw.set_annotations(ann_shifted) logger.info(f"Trimmed raw data: start at {trim_time}s (5s before first event), t=0 at new start") else: logger.warning("No events found, skipping trim step.") fig_trimmed = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Trimmed Raw", show=False) fig_individual["Trimmed Raw"] = fig_trimmed if progress_callback: progress_callback(2) logger.info("2") # Step 1.5: Verify optode positions fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False) # type: ignore fig_individual["Plot Sensors"] = fig_optodes if progress_callback: progress_callback(2) logger.info("2") # Step 2: Downsample # raw = raw.resample(0.5) # Downsample to 0.5 Hz # Step 2: Bad from SCI if True: fig, hr1, hr2, low, high = hr_calc(raw) fig_individual["PSD"] = fig fig_individual['HeartRate_PSD'] = hr1 fig_individual['HeartRate_Time'] = hr2 if progress_callback: progress_callback(10) if progress_callback: progress_callback(2) bad_sci = [] if SCI: bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw, low, high) fig_individual["SCI1"] = fig_sci_1 fig_individual["SCI2"] = fig_sci_2 if progress_callback: progress_callback(3) logger.info("3") # Step 2: Bad from SNR bad_snr = [] if SNR: bad_snr, fig_snr = calculate_signal_noise_ratio(raw) fig_individual["SNR1"] = fig_snr if progress_callback: progress_callback(4) logger.info("4") # Step 3: Bad from PSP bad_psp = [] if PSP: bad_psp, fig_psp1, fig_psp2 = calculate_peak_power(raw) fig_individual["PSP1"] = fig_psp1 fig_individual["PSP2"] = fig_psp2 if progress_callback: progress_callback(5) logger.info("5") # Step 4: Mark the bad channels raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp) if fig_dropped and fig_raw_before is not None: fig_individual["fig2"] = fig_dropped fig_individual["fig3"] = fig_raw_before if progress_callback: progress_callback(6) logger.info("6") # Step 5: Interpolate the bad channels if bad_channels: raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, bad_channels) fig_individual["fig4"] = fig_raw_after if progress_callback: progress_callback(7) logger.info("7") # Step 6: Optical Density 
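    # Convert raw intensities to optical density (per channel, the negative logarithm of the
    # signal relative to its mean intensity) before motion correction and the Beer-Lambert
    # conversion below.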
raw_od = optical_density(raw) fig_raw_od = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Optical Density", show=False) fig_individual["Optical Density"] = fig_raw_od if progress_callback: progress_callback(8) logger.info("8") # Step 7: TDDR if TDDR: raw_od = temporal_derivative_distribution_repair(raw_od) fig_raw_od_tddr = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After TDDR (Motion Correction)", show=False) fig_individual["TDDR"] = fig_raw_od_tddr if progress_callback: progress_callback(9) logger.info("9") raw_od, fig = calculate_and_apply_wavelet(raw_od) fig_individual["Wavelet"] = fig if progress_callback: progress_callback(9) # Step 8: BLL raw_haemo = beer_lambert_law(raw_od, ppf=calculate_dpf(file_path)) fig_raw_haemo_bll = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False) fig_individual["BLL"] = fig_raw_haemo_bll if progress_callback: progress_callback(10) logger.info("10") # Step 9: ENC # raw_haemo = enhance_negative_correlation(raw_haemo) # fig_raw_haemo_enc = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False) # fig_individual.append(fig_raw_haemo_enc) # Step 10: Filter fig_filter, fig_raw_haemo_filter = filter_the_data(raw_haemo) fig_individual["filter1"] = fig_filter fig_individual["filter2"] = fig_raw_haemo_filter if progress_callback: progress_callback(11) logger.info("11") # Step 11: Get short / long channels if SHORT_CHANNEL: short_chans = get_short_channels(raw_haemo, max_dist=0.02) fig_short_chans = short_chans.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Short Channels Only", show=False) fig_individual["short"] = fig_short_chans else: short_chans = None raw_haemo = get_long_channels(raw_haemo) if progress_callback: progress_callback(12) logger.info("12") # Step 12: Events from annotations events, event_dict = events_from_annotations(raw_haemo) fig_events = plot_events(events, event_id=event_dict, sfreq=raw_haemo.info["sfreq"], show=False) fig_individual["events"] = fig_events if progress_callback: progress_callback(13) logger.info("13") # Step 13: Epoch calculations epochs, fig_epochs = epochs_calculations(raw_haemo, events, event_dict) for name, fig in fig_epochs: # Unpack the tuple here fig_individual[f"epochs_{name}"] = fig # Store only the figure, not the name if progress_callback: progress_callback(14) logger.info("14") # Step 14: Design Matrix events_to_remove = REMOVE_EVENTS filtered_annotations = [ann for ann in raw.annotations if ann['description'] not in events_to_remove] new_annot = Annotations( onset=[ann['onset'] for ann in filtered_annotations], duration=[ann['duration'] for ann in filtered_annotations], description=[ann['description'] for ann in filtered_annotations] ) # Set the new annotations raw_haemo.set_annotations(new_annot) design_matrix, fig_design_matrix = make_design_matrix(raw_haemo, short_chans) fig_individual["Design Matrix"] = fig_design_matrix if progress_callback: progress_callback(15) logger.info("15") # Step 15: Run GLM glm_est = run_glm(raw_haemo, design_matrix) # Not used AppData\Local\Packages\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\LocalCache\local-packages\Python313\site-packages\nilearn\glm\contrasts.py # Yes used AppData\Local\Packages\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\LocalCache\local-packages\Python313\site-packages\mne_nirs\utils\_io.py # The p-value is calculated from this 
t-statistic using the Student’s t-distribution with appropriate degrees of freedom. # p_value = 2 * stats.t.cdf(-abs(t_statistic), df) # It is a two-tailed p-value. # It says how likely it is to observe the effect you did (or something more extreme) if the true effect was zero (null hypothesis). # A small p-value (e.g., < 0.05) suggests the effect is unlikely to be zero — it’s "statistically significant." # A large p-value means the data do not provide strong evidence that the effect is different from zero. if progress_callback: progress_callback(16) logger.info("16") # Step 16: Plot GLM results fig_glm_result = plot_glm_results(file_path, raw_haemo, glm_est, design_matrix) for name, fig in fig_glm_result: fig_individual[f"GLM {name}"] = fig if progress_callback: progress_callback(17) logger.info("17") # Step 17: Plot channel significance fig_significance = individual_significance(raw_haemo, glm_est) for name, fig in fig_significance: fig_individual[f"Significance {name}"] = fig if progress_callback: progress_callback(18) logger.info("18") # Step 18: cha, con, roi cha = glm_est.to_dataframe() # HACK: Comment out line 588 (self._renderer.show()) in _brain.py from MNE # brain_thing = brain_3d_visualization(cha, raw_haemo) # brain_individual.append(brain_thing) # C++ objects made this get rendered on the fly rois = dict(AllChannels=range(len(raw_haemo.ch_names))) # Calculate ROI for all conditions conditions = design_matrix.columns # Compute output metrics by ROI df_ind = glm_est.to_dataframe_region_of_interest(rois, conditions) df_ind["ID"] = file_path # Step 18: Fold channels # fig_fold_data, fig_fold_legend = fold_channels(raw_haemo) # fig_individual.append(fig_fold_data) # fig_individual.append(fig_fold_legend) print(design_matrix) contrast_matrix = np.eye(design_matrix.shape[1]) basic_conts = dict( [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)] ) all_delay_cols = [col for col in design_matrix.columns if "_delay_" in col] all_conditions = sorted({col.split("_delay_")[0] for col in all_delay_cols}) if not all_conditions: raise ValueError("No FIR regressors found in the design matrix.") # Build contrast vectors for each condition contrast_dict = {} for condition in all_conditions: delay_cols = [ col for col in all_delay_cols if col.startswith(f"{condition}_delay_") and TIME_WINDOW_START <= int(col.split("_delay_")[-1]) <= TIME_WINDOW_END ] if not delay_cols: continue # skip if no columns found (shouldn't happen?) 
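        # Averaging the unit contrast vectors of the FIR delay regressors that fall inside
        # [TIME_WINDOW_START, TIME_WINDOW_END] gives a contrast that tests the mean beta over
        # that window for this condition (computed below).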
# Average across all delay regressors for this condition contrast_vector = np.sum([basic_conts[col] for col in delay_cols], axis=0) contrast_vector /= len(delay_cols) contrast_dict[condition] = contrast_vector if progress_callback: progress_callback(19) logger.info("19") # Compute contrast results contrast_results = {} for cond, contrast_vector in contrast_dict.items(): contrast = glm_est.compute_contrast(contrast_vector) # type: ignore df = contrast.to_dataframe() df["ID"] = file_path contrast_results[cond] = df cha["ID"] = file_path fig_bytes = convert_fig_dict_to_png_bytes(fig_individual) if progress_callback: progress_callback(20) logger.info("20") sanitize_paths_for_pickle(raw_haemo, epochs) return raw_haemo, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, AGE, GENDER, GROUP, True def sanitize_paths_for_pickle(raw_haemo, epochs): # Fix raw_haemo._filenames if hasattr(raw_haemo, '_filenames'): raw_haemo._filenames = [str(p) for p in raw_haemo._filenames] # Fix epochs._raw._filenames if hasattr(epochs, '_raw') and hasattr(epochs._raw, '_filenames'): epochs._raw._filenames = [str(p) for p in epochs._raw._filenames]
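

# --------------------------------------------------------------------------------------------
# Usage sketch (illustrative only). A minimal, hypothetical example of driving the pipeline
# above from a script: the configuration values and the .snirf path are placeholders rather
# than recommended settings, and the keys simply mirror REQUIRED_KEYS plus TDDR.
if __name__ == "__main__":
    example_config = {
        "DOWNSAMPLE": False,
        "DOWNSAMPLE_FREQUENCY": 5,
        "SCI": True,
        "SCI_TIME_WINDOW": 10,
        "SCI_THRESHOLD": 0.7,
        "SNR": True,
        "SNR_THRESHOLD": 2.0,
        "PSP": True,
        "PSP_TIME_WINDOW": 10,
        "PSP_THRESHOLD": 0.1,
        "TDDR": True,
        "SHORT_CHANNEL": False,
        "REMOVE_EVENTS": [],
        "TIME_WINDOW_START": 0,
        "TIME_WINDOW_END": 15,
        "L_FREQ": 0.01,
        "H_FREQ": 0.1,
    }
    set_config_me(example_config)
    # Hypothetical file path; replace with a real recording before running.
    results = process_participant("example_participant.snirf")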