""" Filename: fNIRS_module.py Description: Core functionality for FLARES Author: Tyler de Zeeuw License: GPL-3.0 """ # Built-in imports import os import sys import time import logging import platform import warnings import threading from io import BytesIO from copy import deepcopy from pathlib import Path from zipfile import ZipFile from datetime import datetime from itertools import compress from multiprocessing import Queue from typing import Any, Optional, cast, Literal, Iterator, Union # External library imports import pywt # type: ignore import qtpy # type: ignore import xlrd # type: ignore import psutil import scooby # type: ignore import requests import pyvistaqt # type: ignore import darkdetect # type: ignore import numpy as np import pandas as pd from PIL import Image import seaborn as sns import neurokit2 as nk # type: ignore from tqdm.auto import tqdm from pandas import DataFrame import matplotlib.pyplot as plt from matplotlib.axes import Axes from numpy.typing import NDArray #import vtkmodules.util.data_model from numpy import floating, float64 from matplotlib.lines import Line2D import matplotlib.colors as mcolors from scipy.stats import ttest_1samp # type: ignore from matplotlib.figure import Figure import statsmodels.formula.api as smf # type: ignore #import vtkmodules.util.execution_model from nilearn.plotting import plot_design_matrix # type: ignore from scipy.signal import welch, butter, filtfilt # type: ignore from matplotlib.colors import LinearSegmentedColormap from IPython.display import display, Markdown, clear_output # type: ignore from statsmodels.tools.sm_exceptions import ConvergenceWarning # type: ignore from concurrent.futures import ProcessPoolExecutor, as_completed from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas # External library imports for mne import mne from mne import EvokedArray, Info, read_source_spaces, stc_near_sensors # type: ignore from mne.source_space import SourceSpaces from mne.transforms import Transform # type: ignore from mne.io import BaseRaw, read_raw_snirf # type: ignore from mne.annotations import Annotations # type: ignore from mne_nirs.visualisation import plot_glm_group_topo # type: ignore from mne_nirs.channels import get_long_channels, get_short_channels, picks_pair_to_idx # type: ignore from mne_nirs.experimental_design import make_first_level_design_matrix # type: ignore from mne_nirs.statistics import run_glm, statsmodels_to_results # type: ignore from mne_nirs.signal_enhancement import enhance_negative_correlation, short_channel_regression # type: ignore from mne.preprocessing.nirs import beer_lambert_law, optical_density, temporal_derivative_distribution_repair, source_detector_distances, short_channels # type: ignore from mne_nirs.io.fold import fold_channel_specificity # type: ignore from mne_nirs.preprocessing import peak_power # type: ignore from mne.viz import Brain from mne_nirs.statistics._glm_level_first import RegressionResults # type: ignore from mne.filter import filter_data # type: ignore CURRENT_VERSION = "1.0.0" GUI = False PLATFORM_NAME = platform.system().lower() BASE_SNIRF_FOLDER: str SNIRF_SUBFOLDERS: list[str] STIM_DURATION: list[float] MAX_WORKERS: int SECONDS_TO_STRIP: int DOWNSAMPLE: bool DOWNSAMPLE_FREQUENCY: int FORCE_DROP_CHANNELS: list[str] SOURCE_DETECTOR_SEPARATOR: str OPTODE_FILE: bool OPTODE_FILE_PATH: str OPTODE_FILE_SEPARATOR: str TDDR: bool WAVELET: bool IQR: float HEART_RATE: bool SECONDS_TO_STRIP_HR: int MAX_LOW_HR: int MAX_HIGH_HR: int SMOOTHING_WINDOW_HR: int HEART_RATE_WINDOW: int SHORT_CHANNEL: bool SHORT_CHANNEL_THRESH: float SCI: bool SCI_TIME_WINDOW: int SCI_THRESHOLD: float PSP: bool PSP_TIME_WINDOW: int PSP_THRESHOLD: float # TODO: Implement SNR: bool SNR_TIME_WINDOW : int SNR_THRESHOLD: float EXCLUDE_CHANNELS: bool MAX_BAD_CHANNELS: int LONG_CHANNEL_THRESH: float METADATA: dict DRIFT_MODEL: str DURATION_BETWEEN_ACTIVITIES: int HRF_MODEL: str SHORT_CHANNEL_REGRESSION: bool N_JOBS: int TARGET_ACTIVITY: str TARGET_CONTROL: str ROI_GROUP_1: list[list[int]] ROI_GROUP_2: list[list[int]] ROI_GROUP_1_NAME: str ROI_GROUP_2_NAME: str P_THRESHOLD: float SEE_BAD_IMAGES: bool ABS_T_VALUE: int ABS_THETA_VALUE: int ABS_CONTRAST_T_VALUE: int ABS_CONTRAST_THETA_VALUE: int ABS_SIGNIFICANCE_T_VALUE: int ABS_SIGNIFICANCE_THETA_VALUE: int BRAIN_DISTANCE: float BRAIN_MODE: str EPOCH_REJECT_CRITERIA_THRESH: float TIME_MIN_THRESH: int TIME_MAX_THRESH: int VERBOSITY: bool REJECT_PAIRS = None FORCE_DROP_ANNOTATIONS = None FILTER_LOW_PASS = None FILTER_HIGH_PASS = None EPOCH_PAIR_TOLERANCE_WINDOW = None # FIXME: Shouldn't need each ordering - just order it before checking FIXED_CATEGORY_COLORS = { "SCI only": "skyblue", "PSP only": "salmon", "SNR only": "lightgreen", "PSP + SCI": "orange", "SCI + SNR": "violet", "PSP + SNR": "gold", "SCI + PSP": "orange", "SNR + SCI": "violet", "SNR + PSP": "gold", "PSP + SNR + SCI": "gray", "SCI + PSP + SNR": "gray", "SCI + SNR + PSP": "gray", "PSP + SCI + SNR": "gray", "PSP + SNR + SCI": "gray", "SNR + SCI + PSP": "gray", "SNR + PSP + SCI": "gray", } REQUIRED_KEYS: dict[str, Any] = { "BASE_SNIRF_FOLDER": str, "SNIRF_SUBFOLDERS": list, "STIM_DURATION": list, "MAX_WORKERS": int, "SECONDS_TO_STRIP": int, "DOWNSAMPLE": bool, "DOWNSAMPLE_FREQUENCY": int, "FORCE_DROP_CHANNELS": list, "SOURCE_DETECTOR_SEPARATOR": str, "OPTODE_FILE": bool, "OPTODE_FILE_PATH": str, "OPTODE_FILE_SEPARATOR": str, "TDDR": bool, "WAVELET": bool, "IQR": float, "HEART_RATE": bool, "SECONDS_TO_STRIP_HR": int, "MAX_LOW_HR": int, "MAX_HIGH_HR": int, "SMOOTHING_WINDOW_HR": int, "HEART_RATE_WINDOW": int, "SHORT_CHANNEL": bool, "SHORT_CHANNEL_THRESH": float, "SCI": bool, "SCI_TIME_WINDOW": int, "SCI_THRESHOLD": float, "PSP": bool, "PSP_TIME_WINDOW": int, "PSP_THRESHOLD": float, "SNR": bool, "SNR_TIME_WINDOW": int, "SNR_THRESHOLD": float, "EXCLUDE_CHANNELS": bool, "MAX_BAD_CHANNELS": int, "LONG_CHANNEL_THRESH": float, "METADATA": dict, "DRIFT_MODEL": str, "DURATION_BETWEEN_ACTIVITIES": int, "HRF_MODEL": str, "SHORT_CHANNEL_REGRESSION": bool, "N_JOBS": int, "TARGET_ACTIVITY": str, "TARGET_CONTROL": str, "ROI_GROUP_1": list, "ROI_GROUP_2": list, "ROI_GROUP_1_NAME": str, "ROI_GROUP_2_NAME": str, "P_THRESHOLD": float, "SEE_BAD_IMAGES": bool, "ABS_T_VALUE": int, "ABS_THETA_VALUE": int, "ABS_CONTRAST_T_VALUE": int, "ABS_CONTRAST_THETA_VALUE": int, "ABS_SIGNIFICANCE_T_VALUE": int, "ABS_SIGNIFICANCE_THETA_VALUE": int, "BRAIN_DISTANCE": float, "BRAIN_MODE": str, "EPOCH_REJECT_CRITERIA_THRESH": float, "TIME_MIN_THRESH": int, "TIME_MAX_THRESH": int, "VERBOSITY": bool, # "REJECT_PAIRS": bool, # "FORCE_DROP_ANNOTATIONS": list, # "FILTER_LOW_PASS": float, # "FILTER_HIGH_PASS": float, # "EPOCH_PAIR_TOLERANCE_WINDOW": int, } # Ensure that we are working in the directory of this file script_dir = os.path.dirname(os.path.abspath(__file__)) os.chdir(script_dir) # Configure logging to file with timestamps and realtime flush if PLATFORM_NAME == 'darwin': logging.basicConfig( filename=os.path.join(os.path.dirname(sys.executable), "../../../fnirs_analysis.log"), level=logging.INFO, format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', filemode='a' ) else: logging.basicConfig( filename='fnirs_analysis.log', level=logging.INFO, format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', filemode='a' ) logger = logging.getLogger() class ProcessingError(Exception): def __init__(self, message: str = "Something went wrong!"): self.message = message super().__init__(self.message) def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -> None: try: print("setting config") set_config(config, True) # Start a thread to forward progress messages back to GUI def forward_progress(): while True: try: msg = progress_queue.get(timeout=1) if msg == "__done__": break gui_queue.put(msg) except: continue t = threading.Thread(target=forward_progress, daemon=True) t.start() # Run the actual processing, with progress_queue passed down print("actual call") result = run_groups(config, True, progress_queue=progress_queue) # Signal end of progress progress_queue.put("__done__") t.join() gui_queue.put({"success": True, "result": result}) except Exception as e: import traceback gui_queue.put({ "success": False, "error": str(e), "traceback": traceback.format_exc() }) def set_config(config: dict[str, Any], gui: bool = False) -> None: """ Validates and applies the given configuration dictionary. Parameters ---------- config : dict[str, Any] Dictionary containing configuration keys and their values. """ if gui: globals().update({"GUI": True}) # Ensure all keys are present for key, expected_type in REQUIRED_KEYS.items(): if key not in config: raise KeyError(f"Missing config key: {key}") value = config[key] if not isinstance(value, expected_type): # Special handling for lists to check list contents if expected_type == list and isinstance(value, list): continue # optionally: validate inner types too raise TypeError(f"Key '{key}' has incorrect type. Expected {expected_type.__name__}, got {type(value).__name__}") # Update the global variables to match the values in the config keys globals().update(config) # Ensure that passed through variables are correct or that they actually exist assert Path(BASE_SNIRF_FOLDER).is_dir(), "BASE_SNIRF_FOLDER was not found. Please check the folder location and try again." for folder in SNIRF_SUBFOLDERS: assert Path(os.path.join(BASE_SNIRF_FOLDER, folder)).is_dir(), f"The subfolder {folder} could not be found. Please check the folder location and try again." assert len(SNIRF_SUBFOLDERS) == len(STIM_DURATION), f"The amount of subfolders do not match the amount of stim durations. Subfolders: {len(SNIRF_SUBFOLDERS)} Stim durations: {len(STIM_DURATION)}" if OPTODE_FILE: path = Path(OPTODE_FILE_PATH) assert path.is_file(), "OPTODE_FILE was specified, but OPTODE_FILE_PATH is not a file." assert path.suffix == ".txt", "OPTODE_FILE_PATH does not end with a .txt extension." # Ensure that the BASE_SNIRF_FOLDER is an absolute path - helpful when logger.infoing later if 'BASE_SNIRF_FOLDER' in globals(): abs_path = str(Path(BASE_SNIRF_FOLDER).resolve()) globals()['BASE_SNIRF_FOLDER'] = abs_path # Supress MNE's warnings if not VERBOSITY: warnings.filterwarnings("ignore", category=ConvergenceWarning) warnings.filterwarnings("ignore", category=RuntimeWarning) logger.info("[Config] Configuration successfully set.") def run_groups(config, gui: bool = False, progress_queue=None) -> tuple[dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], dict[str, dict[str, BaseRaw]], dict[str, list[Figure]], dict[str, str], float]: """ Process multiple data folders and aggregate results, haemoglobin, figures, and processing details. Returns ------- tuple[dict[str, tuple[DataFrame, DataFrame, DataFrame]], dict[str, dict[str, BaseRaw]], dict[str, list[Figure]], dict[str, str]] - dict[str, tuple[DataFrame, DataFrame, DataFrame]]: Results dataframes grouped by folder. - dict[str, dict[str, BaseRaw]]: Raw haemoglobin data indexed by file ID. - dict[str, list[Figure]]: Figures generated during processing grouped by step. - dict[str, str]: Processing status messages indexed by file ID. - float: Elapsed time """ # Create dictionaries to store our results all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]] = {} all_figures: dict[str, list[Figure]] = {} all_raw_haemo: dict[str, dict[str, BaseRaw]] = {} all_processes: dict[str, str] = {} # Variables to store our total files to be processed and the remaining amount of files while the program is running total_files = 0 files_remaining = {'count': 0} start_time = time.time() # Iterate over all the folders and determine how many files are in the folder logger.info("Calculating how many files there are...") for folder in SNIRF_SUBFOLDERS: full_path = os.path.join(BASE_SNIRF_FOLDER, folder) num_items = len([ f for f in os.listdir(full_path) if os.path.isfile(os.path.join(full_path, f)) ]) total_files += num_items logger.info(f"Total of {total_files} files.") # Set the remaining count to be the total amount of files files_remaining['count'] = total_files # Iterate over all the folders for folder, stim_duration in zip(SNIRF_SUBFOLDERS, STIM_DURATION): full_path = os.path.join(BASE_SNIRF_FOLDER, folder) try: # Process all participants in the folder logger.info(f"Processing all files in {folder}...") raw_haemo, df_roi, df_cha, df_con, df_design_matrix, figures, process = process_folder(full_path, stim_duration, files_remaining, config, gui, progress_queue=progress_queue) # Store the results into the corresponding dictionaries logger.info(f"Storing the results from the {folder} folder...") # TODO: This looks yucky try: all_results[folder] = (df_roi, df_cha, df_con, df_design_matrix) logger.info(f"Applied all results.") except: pass try: for step, fig_list in figures.items(): all_figures.setdefault(step, []).extend(fig_list) logger.info(f"Applied all figures.") except: pass try: for file_id, raw in raw_haemo.items(): all_raw_haemo[file_id] = raw logger.info(f"Applied all haemo.") except: pass try: for file_id, p in process.items(): all_processes[file_id] = p logger.info(f"Applied all processes.") except: pass except ProcessingError as e: logger.info(f"Something happened! {e}") # Something really bad happened. No partial return raise Exception(e) except Exception as e: logger.info(f"Something happened! {e}") # Still return a partial analysis even if something goes wrong return all_results, all_raw_haemo, all_figures, all_processes, time.time() - start_time return all_results, all_raw_haemo, all_figures, all_processes, time.time() - start_time def create_image_montage(images: list[Image.Image], cols: int) -> Optional[Image.Image]: """ Creates a grid montage image from a list of PIL Images. Parameters ---------- images : list[Image.Image] List of images to arrange in the montage. cols : int Number of columns in the montage grid. Returns ------- Optional[Image.Image] The combined montage image, or None if the input list of images is empty. """ # Verify that we have images to process if not images: return None # Calculate the width, height, and rows logger.info("Calculating the montage parameters...") widths, heights = zip(*(i.size for i in images)) max_width = max(widths) max_height = max(heights) rows = (len(images) + cols - 1) // cols # Create the montage image logger.info("Creating the montage...") montage = Image.new('RGBA', (cols * max_width, rows * max_height), (255, 255, 255, 255)) for idx, image in enumerate(images): x = (idx % cols) * max_width y = (idx // cols) * max_height montage.paste(image, (x, y)) # type: ignore return montage def show_all_images(figures: dict[str, list[BytesIO]], inline: bool = False) -> None: """ Displays montages of figures either inline or in separate windows. Parameters ---------- figures : dict[str, list[Figure]] Dictionary containing lists of figures categorized by type. inline : bool, optional If True, display images inline (e.g., in Jupyter notebooks). Otherwise, opens them in separate windows (default is False). """ if inline: logger.info("Inline was selected.") else: logger.info("Inline was not selected.") # If we have less than 4 figures, the columns should be the exact amount of images we have. If we have more, enforce 4 columns logger.info("Calculating columns...") if len(figures.get('Raw', [])) < 4: cols = len(figures.get('Raw', [])) else: cols = 4 # Iterate over all of the types of figure, create a montage with figures of the same type, and display the resulting image logger.info("Generating images...") for _, fig_bytes_list in figures.items(): pil_images = [] for b in fig_bytes_list: try: img = Image.open(BytesIO(b)).convert("RGB") pil_images.append(img) except Exception as e: logger.warning(f"Could not open image from bytes: {e}") continue montage = create_image_montage(pil_images, cols) if montage: # Determine how to display the images to the user if inline: display(montage) else: montage.show() def save_all_images(figures: dict[str, list[Figure]]) -> None: """ Saves montages of figures as timestamped PNG files in folder called 'images'. Parameters ---------- figures : dict[str, list[Figure]] Dictionary containing lists of figures categorized by type. """ # Get the current working directory and create a folder called images if it does not exist logger.info("Getting the current directory...") if PLATFORM_NAME == 'darwin': images_folder = os.path.join(os.path.dirname(sys.executable), "../../../images") else: cwd = os.getcwd() images_folder = os.path.join(cwd, "images") logger.info("Attempting to create the images folder...") os.makedirs(images_folder, exist_ok=True) # Generate a timestamp to be appended to the end of the file name logger.info("Generating the timestamp...") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # If we have less than 4 figures, the columns should be the exact value. If we have more, enforce 4 columns logger.info("Calculating columns...") raw_fig_count = len(figures.get('Raw', [])) if raw_fig_count < 4: cols = raw_fig_count else: cols = 4 # Iterate over all of the types of figures, create a montage with figures of the same type, and save the resulting image logger.info("Generating images...") for step, fig_bytes_list in figures.items(): pil_images = [] for b in fig_bytes_list: try: img = Image.open(BytesIO(b)).convert("RGB") pil_images.append(img) except Exception as e: logger.warning(f"Could not open image from bytes: {e}") continue montage = create_image_montage(pil_images, cols) if montage: filename = f"{step}_{timestamp}.png" save_path = os.path.join(images_folder, filename) montage.save(save_path) # type: ignore logger.info(f"Saved image to {save_path}") logger.info(f"All images have been saved to '{images_folder}'.") def load_snirf(file_path: str, ID: str, drop_prefixes: list[str]) -> tuple[BaseRaw, Figure]: """ Loads a snirf file, optionally drops channels, downsamples, and creates a figure showing the results. Parameters ---------- file_path : str Path of the snirf file to load. ID : str File name of the the snirf file that was loaded. drop_prefixes : list[str] List of channel name prefixes to drop from the data. Returns ------- tuple[BaseRaw, Figure] - BaseRaw: The processed data object. - Figure: The corresponding Matplotlib figure. """ logger.info(f"Loading the snirf file ({ID})...") # Read the snirf file raw = read_raw_snirf(file_path, verbose=VERBOSITY) # type: ignore raw.load_data(verbose=VERBOSITY) # type: ignore # Strip the specified amount of seconds from the start of the file total_duration = getattr(raw, "times")[-1] if total_duration > SECONDS_TO_STRIP: raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.") else: logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.") # If the user forcibly dropped channels, remove them now before any processing occurs logger.info("Checking if there are channels to forcibly drop...") if drop_prefixes: logger.info("Force dropped channels was specified.") channels_to_drop = [ch for ch in cast(list[str], getattr(raw, "ch_names")) if any(ch.startswith(prefix) for prefix in drop_prefixes)] raw.drop_channels(channels_to_drop, "raise") # type: ignore logger.info("Force dropped channels:", channels_to_drop) # If the user wants to downsample, do it right away logger.info("Checking if we should downsample...") if DOWNSAMPLE: logger.info("Downsample was specified.") sfreq_old = getattr(raw, "info")["sfreq"] raw.resample(DOWNSAMPLE_FREQUENCY, verbose=VERBOSITY) # type: ignore sfreq_new = getattr(raw, "info")["sfreq"] logger.info(f"Finished downsampling. Old frequency: {sfreq_old}. New frequency: {sfreq_new}.") # Create a figure for the results logger.info("Creating the figure...") fig = cast(Figure, raw.plot(show=False, n_channels=len(getattr(raw, "ch_names")), duration=raw.times[-1]).figure) # type: ignore fig.suptitle(f"Raw fNIRS Data for {ID}", fontsize=16) # type: ignore fig.subplots_adjust(top=0.92) plt.close(fig) logger.info("Successfully loaded the snirf file.") return raw, fig def calculate_and_apply_updated_optode_coordinates(data: BaseRaw) -> BaseRaw: """ Update optode coordinates on the given MNE Raw data using a specified optode file. Parameters ---------- data : BaseRaw The loaded data object to process with new optode coordinates. Returns ------- BaseRaw The processed data object with the updated montage applied. """ logger.info("Updating optode coordinates...") fiducials: dict[str, NDArray[floating[Any]]] = {} ch_positions: dict[str, NDArray[floating[Any]]] = {} # Read the lines from the optode file logger.info(f"Reading optode file from {OPTODE_FILE_PATH}") with open(OPTODE_FILE_PATH, 'r') as f: for line in f: if line.strip(): # Split by the semicolon and convert to meters ch_name, coords_str = line.split(OPTODE_FILE_SEPARATOR) coords = np.array(list(map(float, coords_str.strip().split()))) * 0.001 logger.info(f"Read line: {ch_name} with coords (m): {coords}") # The key we have is a fiducial if ch_name.lower() in ['lpa', 'nz', 'rpa']: fiducials[ch_name.lower()] = coords # The key we have is a source or detector else: ch_positions[ch_name.upper()] = coords # Create montage with updated coords in head space logger.info("Creating and applying the montage...") initial_montage = mne.channels.make_dig_montage(ch_pos=ch_positions, nasion=fiducials.get('nz'), lpa=fiducials.get('lpa'), rpa=fiducials.get('rpa'), coord_frame='head') # type: ignore data.set_montage(initial_montage, verbose=VERBOSITY) # type: ignore logger.info("Successfully updated optode coordinates.") return data def calculate_and_apply_tddr(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]: """ Applies Temporal Derivative Distribution Repair (TDDR) to the raw data and creates a figure showing the results. Parameters ---------- data : BaseRaw The loaded data object to process. ID : str File name of the the snirf file that was loaded. Returns ------- tuple[BaseRaw, Figure] - BaseRaw: The processed data object. - Figure: The corresponding Matplotlib figure. """ # Apply TDDR logger.info("Applying temporal derivative distribution repair...") raw_with_tddr = cast(BaseRaw, temporal_derivative_distribution_repair(data, verbose=VERBOSITY)) # Create a figure for the results logger.info("Creating the figure...") fig = cast(Figure, raw_with_tddr.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=data.times[-1]).figure) # type: ignore fig.suptitle(f"TDDR for {ID}", fontsize=16) # type: ignore fig.subplots_adjust(top=0.92) plt.close(fig) logger.info("Successfully applied temporal derivative distribution repair.") return raw_with_tddr, fig def iqr_threshold(coeffs: NDArray[float64], k: float = 1.5) -> floating[Any]: """ Calculate the interquartile range (IQR) threshold scaled by a factor, k. Parameters ---------- coeffs : NDArray[float64] Array of coefficients to compute the IQR from. k : float, optional Scaling factor for the IQR (default is 1.5). Returns ------- floating[Any] The scaled IQR threshold value. """ # Calculate the IQR q1 = np.percentile(coeffs, 25) q3 = np.percentile(coeffs, 75) iqr = q3 - q1 return k * iqr def wavelet_iqr_denoise(signal: NDArray[float64], wavelet: str = 'db4', level: int = 3) -> NDArray[float64]: """ Denoises a signal using wavelet decomposition and IQR-based thresholding on detail coefficients. Parameters ---------- signal : NDArray[float64] The input signal array to denoise. wavelet : str, optional The type of wavelet to use for decomposition (default is 'db4'). level : int, optional Decomposition level for wavelet transform (default is 3). Returns ------- NDArray[float64] The denoised signal array, with the same length as the input. """ # Decompose the signal using wavelet transform and initialize a list with approximation coefficients coeffs: list[NDArray[float64]] = pywt.wavedec(signal, wavelet, level=level) # type: ignore cA = coeffs[0] denoised_coeffs = [cA] # Threshold detail coefficients to reduce noise for cD in coeffs[1:]: threshold = iqr_threshold(cD, IQR) cD_thresh = np.sign(cD) * np.maximum(np.abs(cD) - threshold, 0.0) # np.where((cD < lower) | (cD > upper), 0, cD) cD_thresh = cD_thresh.astype(float64) denoised_coeffs.append(cD_thresh) # Reconstruct the denoised signal denoised_signal = cast(NDArray[float64], pywt.waverec(denoised_coeffs, wavelet)) # type: ignore return denoised_signal[:len(signal)] def calculate_and_apply_wavelet(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]: """ Applies a wavelet IQR denoising filter to the data and generates a plot. Parameters ---------- data : BaseRaw The loaded data object to process. ID : str File name of the the snirf file that was loaded. Returns ------- tuple[BaseRaw, Figure] - BaseRaw: The processed data object. - Figure: The corresponding Matplotlib figure. """ logger.info("Applying the wavelet filter...") # Denoise the data logger.info("Denoising the data...") loaded_data: NDArray[float64] = data.get_data(verbose=VERBOSITY) # type: ignore denoised_data = np.zeros_like(loaded_data) logger.info("Calculating the IQR, decomposing the signal, and thresholding the coefficients...") for ch in range(loaded_data.shape[0]): denoised_data[ch, :] = wavelet_iqr_denoise(loaded_data[ch, :], wavelet='db4', level=3) # Reconstruct the data with the annotations logger.info("Reconstructing the data with annotations...") raw_with_tddr_and_wavelet = mne.io.RawArray(denoised_data, cast(mne.Info, data.info), verbose=VERBOSITY) raw_with_tddr_and_wavelet.set_annotations(data.annotations.copy(), verbose=VERBOSITY) # type: ignore # Create a figure for the results logger.info("Creating the figure...") fig = cast(Figure, raw_with_tddr_and_wavelet.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=data.times[-1]).figure) # type: ignore fig.suptitle(f"Wavelet for {ID}", fontsize=16) # type: ignore fig.subplots_adjust(top=0.92) plt.close(fig) logger.info("Successfully applied the wavelet filter.") return raw_with_tddr_and_wavelet, fig def short_channel_processing_for_hr(data: BaseRaw, short_chans: BaseRaw | None) -> tuple[float, NDArray[float64], NDArray[float64]]: """ Extract and trim short-channel fNIRS signal for heart rate analysis. Parameters ---------- data : BaseRaw The loaded data object to process. short_chans : BaseRaw | None Data object with only short separation channels, or None if unavailable. Returns ------- tuple[float, NDArray[float64], NDArray[float64]] - float: Sampling frequency of the signal. - NDArray[float64]: Trimmed short-channel signal. - NDArray[float64]: Corresponding time values. """ # Find the short channel (or best candidate) and extract signal data and sampling frequency logger.info("Extracting the signal and calculating the sampling frequency...") # If a short channel exists, use it for our signal. Otherwise just take the first channel in the data # TODO: Find a better way around this if short_chans is not None: signal = cast(NDArray[float64], short_chans.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore else: signal = cast(NDArray[float64], data.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore # Calculate the sampling frequency sfreq = cast(int, data.info['sfreq']) # Trim start and end of the signal to remove edge artifacts logger.info(f"Removing {SECONDS_TO_STRIP_HR} seconds from the beginning and end of the file...") strip_samples = int(sfreq * SECONDS_TO_STRIP_HR) signal_trimmed = signal[strip_samples:-strip_samples] times_trimmed = cast(NDArray[float64], getattr(data, "times"))[strip_samples:-strip_samples] return sfreq, signal_trimmed, times_trimmed def calculate_heart_rate_neurokit(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[float64], float]: """ Calculate and smooth heart rate from a trimmed signal using NeuroKit. Parameters ---------- sfreq : float Sampling frequency of the signal. signal_trimmed : NDArray[float64] Preprocessed and trimmed fNIRS signal. Returns ------- tuple[NDArray[float64], float] - NDArray[float64]: Smoothed heart rate time series (BPM). - float: Mean heart rate. """ logger.info("Calculating heart rate using NeuroKit...") # Filter signal to isolate heart rate frequencies and detect peaks logger.info("Filtering the signal and detecting peaks...") signal_filtered = cast(NDArray[float64], nk.signal_filter(signal_trimmed, sampling_rate=sfreq, lowcut=0.8, highcut=2.5)) # type: ignore peaks_dict = cast(dict[str, Any], nk.signal_findpeaks(signal_filtered)) # type: ignore peaks = peaks_dict['Peaks'] hr = cast(NDArray[float64], nk.signal_rate(peaks, sampling_rate=sfreq, desired_length=len(signal_trimmed))) # type: ignore hr_clean = np.clip(hr, MAX_LOW_HR, MAX_HIGH_HR) # Smooth heart rate time series by replacing spikes with local rolling mean and calculate the mean logger.info("Smoothing the signal and calculating the mean...") hr_series = pd.Series(hr_clean) local_median = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).median() spikes = hr_series > (local_median + 10) smoothed_values = hr_series.copy() smoothed_spikes = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).mean() smoothed_values[spikes] = smoothed_spikes[spikes] hr_smooth_nk = cast(NDArray[float64], smoothed_values.to_numpy()) # type: ignore mean_hr_nk = hr_smooth_nk.mean() logger.info("Original HR min/max: %f, %f", hr_clean.min(), hr_clean.max()) logger.info("Smoothed HR min/max:%f, %f", hr_smooth_nk.min(), hr_smooth_nk.max()) logger.info(f"Estimated mean HR nk: {mean_hr_nk:.1f} BPM") logger.info("Successfully calculated heart rate using NeuroKit.") return hr_smooth_nk, mean_hr_nk def calculate_heart_rate_scipy(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]: """ Estimate heart rate using spectral analysis on a high-pass filtered signal. Parameters ---------- sfreq : float Sampling frequency of the input signal. signal_trimmed : NDArray[float64] Trimmed fNIRS signal to analyze. Returns ------- tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float] - NDArray[floating[Any]]: Frequencies converted to beats per minute (BPM). - NDArray[float64]: Power spectral density (PSD) of the signal. - np.ndarray[Any, np.dtype[np.bool_]]: Boolean mask indicating frequencies within heart rate range (30-300 BPM). - float: Estimated mean heart rate in BPM corresponding to the PSD peak within the range. """ logger.info("Calculating heart rate using SciPy...") # Apply a high-pass Butterworth filter to remove slow trends below 0.5 Hz from the trimmed signal (actual data) logger.info("Applying a butterworth filter...") b, a = cast(tuple[NDArray[float64], NDArray[float64]], butter(2, 0.5 / (sfreq / 2), btype='high')) signal_hp = cast(NDArray[float64],filtfilt(b, a, signal_trimmed)) # Calculate the Power Spectral Density (PSD) of the filtered signal using Welch's method logger.info("Calculating the PSD...") nperseg = min(len(signal_hp), 4096) frequencies_scipy, psd_scipy = cast(tuple[NDArray[float64], NDArray[float64]], welch(signal_hp, fs=sfreq, nperseg=nperseg, noverlap=nperseg//2)) # Convert frequency values to beats per minute (BPM) and set a heart rate range (30-300 BPM) logger.info("Converting to BPM...") freq_bpm_scipy = frequencies_scipy * 60 freq_range_scipy = (freq_bpm_scipy > 30) & (freq_bpm_scipy < 300) # Identify the peak frequency within the heart rate range and estimate the mean heart rate in BPM logger.info("Finding the mean...") peak_index = np.argmax(psd_scipy[freq_range_scipy]) mean_hr_scipy = freq_bpm_scipy[freq_range_scipy][peak_index] logger.info("Successfully calculated heart rate using SciPy.") return freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy def plot_heart_rate( freq_bpm_scipy: NDArray[floating[Any]], psd_scipy: NDArray[float64], freq_range_scipy: np.ndarray[Any, np.dtype[np.bool_]], mean_hr_scipy: float, hr_smooth_nk: NDArray[floating[Any]], mean_hr_nk: float, times_trimmed: NDArray[floating[Any]], overruled: bool ) -> tuple[Figure, Figure]: """ Generate plots comparing heart rate estimates from SciPy PSD and NeuroKit2. Parameters ---------- freq_bpm_scipy : NDArray[floating[Any]] Frequencies in beats per minute from SciPy PSD analysis. psd_scipy : NDArray[float64] Power spectral density values corresponding to freq_bpm_scipy. freq_range_scipy : np.ndarray[Any, np.dtype[np.bool_]] Boolean mask indicating the heart rate frequency range used in PSD. mean_hr_scipy : float Mean heart rate estimated from SciPy PSD peak. hr_smooth_nk : NDArray[floating[Any]] Smoothed instantaneous heart rate from NeuroKit2. mean_hr_nk : float Mean heart rate estimated from NeuroKit2 data. times_trimmed : NDArray[floating[Any]] Time points corresponding to hr_smooth_nk values. overruled: bool True if the heart rate from NeuroKit2 is overriding the results from the PSD. Returns ------- tuple[Figure, Figure] - Figure showing the PSD and SciPy heart rate estimate. - Figure showing the time series comparison of heart rates. """ # Create the first plot for the PSD. Add a yellow range to show what we will be filtering to. logger.info("Creating the figure...") fig1, ax1 = plt.subplots(figsize=(10, 5)) # type: ignore ax1.set_xlim(30, 300) ax1.plot(freq_bpm_scipy[freq_range_scipy], psd_scipy[freq_range_scipy]) # type: ignore ax1.axvline(x=mean_hr_scipy, color='red', linestyle='--', label=f'Mean HR: {mean_hr_scipy:.1f} BPM') # type: ignore ax1.axvspan(min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW), max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW), color='yellow', alpha=0.3, label=f'HR Range ±{HEART_RATE_WINDOW} BPM') # type: ignore ax1.set_xlabel('Heart Rate (BPM)') # type: ignore ax1.set_ylabel('Power Spectral Density') # type: ignore ax1.set_title('PSD of fNIRS signal - Peak indicates Heart Rate') # type: ignore ax1.grid(True) # type: ignore # Was the value we reported here correct for the data on the graph or was it overruled? if overruled: note = ( '\n' 'Note: Calculation was bad!\n' 'Data has been set to match\n' 'the value from NeuroKit2.' ) phantom = Line2D([0], [0], color='none', label=note) handles, _ = ax1.get_legend_handles_labels() ax1.legend(handles=handles + [phantom]) # type: ignore else: ax1.legend() # type: ignore plt.close(fig1) # Create the second plot showing the rolling heart rate, as well as the two averages that were calculated logger.info("Creating the figure...") fig2, ax2 = plt.subplots(figsize=(14, 6)) # type: ignore ax2.plot(times_trimmed, hr_smooth_nk, label='Instantaneous HR (NeuroKit2)', color='blue', alpha=0.7) # type: ignore ax2.axhline(mean_hr_nk, color='red', linestyle='--', label=f'Mean HR NeuroKit2: {mean_hr_nk:.1f} BPM') # type: ignore ax2.axhline(mean_hr_scipy, color='orange', linestyle=':', label=f'SciPy Welch PSD (HP filtered): {mean_hr_scipy:.1f} BPM') # type: ignore ax2.set_xlabel('Time (seconds)') # type: ignore ax2.set_ylabel('Heart Rate (BPM)') # type: ignore ax2.set_title('Heart Rate Estimates Comparison') # type: ignore ax2.legend() # type: ignore ax2.grid(True) # type: ignore fig2.tight_layout() plt.close(fig2) return fig1, fig2 def plot_timechannel_quality_metrics(data: BaseRaw, scores: NDArray[float64], times: list[tuple[float]], color_stops: tuple[list[float], list[float]], threshold: float, title: Optional[str] = None) -> tuple[Figure, Figure]: """ Generate two heatmaps visualizing channel quality metrics over time. Parameters ---------- data : BaseRaw The loaded data object to process. scores : NDArray[float64] A 2D array of quality scores for each channel over time. times : list[tuple[float]] List of time boundaries used to label each score column. color_stops : tuple[list[float], list[float]] Two lists of color values for custom colormaps. threshold : float, Threshold value for the color bar. title : Optional[str], optional Base title for the figures, (default is None). Returns ------- tuple[Figure, Figure] - Figure: Heatmap of all scores across channels and time. - Figure: Binary heatmap showing only scores above the threshold. """ # Get only the hbo / hbr channels once as we dont need to see the same results twice half_ch = len(getattr(data, "ch_names")) // 2 ch_names = getattr(data, "ch_names")[:half_ch] scores = scores[:half_ch, :] # Extract rounded time points to use as column headers cols = [np.round(t[0]) for t in times] n_chans = len(ch_names) vsize = 0.2 * n_chans # Create the first figure fig1, ax1 = plt.subplots(figsize=(10, vsize), layout="constrained") # type: ignore fig1.suptitle(title + " - All Scores", fontsize=16, fontweight="bold") # type: ignore # Create a DataFrame to structure data for the heatmap data_to_plot = DataFrame( data=scores, columns=pd.Index(cols, name="Time (s)"), index=pd.Index(ch_names, name="Channel"), ) # Define a custom colormap using provided color stops and base colors base_colors = ['red', 'red', 'yellow', 'green', 'green'] colors = list(zip(color_stops[0], base_colors[:len(color_stops[0])])) cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors) # Plot heatmap of scores sns.heatmap( # type: ignore data=data_to_plot, cmap=cmap, vmin=0, vmax=1, cbar_kws=dict(label="Score"), ax=ax1, ) # Add vertical dashed lines at each time boundary, sit the title, and place a black strikethrough through a bad channel for x in range(1, len(times)): ax1.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") # type: ignore ax1.set_title("All Scores", fontweight="bold") # type: ignore markbad(data, ax1, ch_names) # Calculate average score per channel and annotate to the right of the heatmap avg_sci_subset: pd.Series[float] = data_to_plot.mean(axis=1) # type: ignore norm = mcolors.Normalize(vmin=0, vmax=1) text_x = data_to_plot.shape[1] + 0.5 for i, val in enumerate(avg_sci_subset): color = cmap(norm(val)) ax1.text( # type: ignore text_x, i + 0.5, f"{val:.3f}", va='center', ha='left', fontsize=9, color=color ) ax1.set_xlim(right=text_x + 1.5) plt.close(fig1) # Create the second figure fig2, ax2 = plt.subplots(figsize=(10, vsize), layout="constrained") # type: ignore fig2.suptitle(title + " - Scores Above Threshold", fontsize=16, fontweight="bold") # type: ignore # Create a DataFrame to structure data for the heatmap data_to_plot = DataFrame( data=scores > threshold, columns=pd.Index(cols, name="Time (s)"), index=pd.Index(ch_names, name="Channel"), ) # Define a custom colormap using provided color stops and base colors base_colors = ['red', 'red', 'white', 'white'] colors = list(zip(color_stops[1], base_colors[:len(color_stops[1])])) cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors) # Plot heatmap of scores sns.heatmap( # type: ignore data=data_to_plot, vmin=0, vmax=1, cmap=cmap, cbar_kws=dict(label="Score"), ax=ax2, ) # Add vertical dashed lines at each time boundary, sit the title, and place a black strikethrough through a bad channel for x in range(1, len(times)): ax2.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") # type: ignore ax2.set_title("Scores > Threshold", fontweight="bold") # type: ignore markbad(data, ax2, ch_names) plt.close(fig2) return fig1, fig2 def markbad(data: BaseRaw, ax: Axes, ch_names: list[str]) -> None: """ Add a strikethrough to a plot for channels marked as bad. Parameters ---------- data : BaseRaw The loaded data object to process. ax : Axes Matplotlib Axes object where the strikethrough lines will be drawn. ch_names : list[str] List of channel names corresponding to the y-axis of the plot. """ # Iterate over all the channels for i, ch in enumerate(ch_names): # If it is marked as bad, place a strikethrough on the channel if ch in data.info["bads"]: ax.axhline(i + 0.5, ls="solid", lw=4, color="black", zorder=10) # type: ignore def calculate_scalp_coupling(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]: """ Calculate the scalp coupling index (SCI) and identify bad channels based on a threshold. Parameters ---------- data : BaseRaw The loaded data object to process. l_freq : float, optional Low cutoff frequency for bandpass filtering in Hz (default is 0.7). h_freq : float, optional High cutoff frequency for bandpass filtering in Hz (default is 1.5) Returns ------- tuple[list[str], Figure, Figure] - list[str]: Channel names identified as bad based on SCI threshold. - Figure: Heatmap of all SCI scores across time and channels. - Figure: Binary heatmap of SCI scores exceeding the threshold. """ logger.info("Calculating scalp coupling index...") # Compute the SCI _, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], scalp_coupling_index_windowed_raw(data, time_window=SCI_TIME_WINDOW, l_freq=l_freq, h_freq=h_freq)) # Identify channels that don't meet the provided threshold logger.info("Identifying channels that do not meet the threshold...") sci = scores.mean(axis=1) data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD)) # Determine the colors based on the threshold, and create the figures logger.info("Creating the figures...") color_stops = ([0.0, SCI_THRESHOLD, SCI_THRESHOLD+0.1, 0.8, 1.0], [0.0, SCI_THRESHOLD, SCI_THRESHOLD, 1.0]) fig1, fig2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, SCI_THRESHOLD, "Scalp Coupling Index") logger.info("Successfully calculated scalp coupling index.") return list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD)), fig1, fig2 def scalp_coupling_index_windowed_raw(data: BaseRaw, time_window: float = 3.0, l_freq: float = 0.7, h_freq: float = 1.5, l_trans_bandwidth: float = 0.3, h_trans_bandwidth: float = 0.3) -> tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]]: """ Compute windowed scalp coupling index (SCI) across fNIRS channels. Parameters ---------- data : BaseRaw The loaded data object to process. time_window : float, optional Length of each time window in seconds (default is 3.0). l_freq : float, optional Low cutoff frequency for filtering in Hz (default is 0.7). h_freq : float, optional High cutoff frequency for filtering in Hz (default is 1.5). l_trans_bandwidth : float, optional Transition bandwidth for the low cutoff in Hz (default is 0.3). h_trans_bandwidth : float, optional Transition bandwidth for the high cutoff in Hz (default is 0.3). Returns ------- tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]] - BaseRaw: The original data object (unchanged). Ensures compatibility with peak_power(). - NDArray[float64]: Correlation scores for each channel and time window. - list[tuple[float, float]]: Time intervals for each window in seconds. """ # Pick only fNIRS channels and sort them by channel name picks: NDArray[np.intp] = mne.pick_types(cast(mne.Info, data.info), fnirs=True) # type: ignore picks = picks[np.argsort([getattr(data, "ch_names")[pick] for pick in picks])] # FIXME: This may happen if the heart rate calculation tries to set a value way too low if l_freq < 0.3: l_freq = 0.3 # Band-pass filter the selected fNIRS channels filtered_data = cast(NDArray[float64], filter_data( getattr(data, "_data"), getattr(data, "info")["sfreq"], l_freq, h_freq, picks=picks, verbose=False, l_trans_bandwidth=l_trans_bandwidth, # type: ignore h_trans_bandwidth=h_trans_bandwidth, # type: ignore )) # Calculate number of samples per time window, the total number of windows, and prepare output variables window_samples = int(np.ceil(time_window * getattr(data, "info")["sfreq"])) n_windows = int(np.floor(len(data) / window_samples)) scores = np.zeros((len(picks), n_windows)) times: list[tuple[float, float]] = [] # Slide through the data in windows to compute scalp coupling index (SCI) for window in range(n_windows): start_sample = int(window * window_samples) end_sample = start_sample + window_samples end_sample = np.min([end_sample, len(data) - 1]) # Track time boundaries for each window t_start = getattr(data, "times")[start_sample] t_stop = getattr(data, "times")[end_sample] times.append((t_start, t_stop)) # Iterate through channels in pairs (hbo, hbr). This requires them to be sorted by channel name for ii in range(0, len(picks), 2): c1: NDArray[float64] = filtered_data[picks[ii]][start_sample:end_sample] c2 = filtered_data[picks[ii + 1]][start_sample:end_sample] # Ensure the correlation data is valid if np.std(c1) == 0 or np.std(c2) == 0 or np.any(np.isnan(c1)) or np.any(np.isnan(c2)): c = 0 else: c = np.corrcoef(c1, c2)[0][1] # Assign the computed correlation to both channels in the pair scores[ii, window] = c scores[ii + 1, window] = c scores = scores[np.argsort(picks)] return data, scores, times def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]: """ Calculate peak spectral power (PSP) for fNIRS channels and identify bad channels. Parameters ---------- data : BaseRaw The loaded data object to process. l_freq : float, optional Low cutoff frequency for filtering in Hz (default is 0.7) h_freq : float, optional High cutoff frequency for filtering in Hz (default is 1.5) Returns ------- tuple[list[str], Figure, Figure] - list[str]: Names of channels below the PSP threshold. - Figure: Heatmap of all PSP scores. - Figure: Heatmap of scores above the PSP threshold. """ logger.info("Calculating peak spectral power...") # Compute the PSP _, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], peak_power(data, time_window=PSP_TIME_WINDOW, threshold=PSP_THRESHOLD, l_freq=l_freq, h_freq=h_freq, verbose=False)) # Identify channels that don't meet the provided threshold logger.info("Identifying channels that do not meet the threshold...") psp = scores.mean(axis=1) data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)) # Determine the colors based on the threshold, and create the figures logger.info("Creating the figures...") color_stops = ([0.0, PSP_THRESHOLD, PSP_THRESHOLD+0.1, 0.3, 1.0], [0.0, PSP_THRESHOLD, PSP_THRESHOLD, 1.0]) psp1, psp2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, PSP_THRESHOLD, "Peak Spectral Power") logger.info("Successfully calculated peak spectral power.") return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2 def calculate_signal_noise_ratio(data: BaseRaw) -> tuple[list[str], Figure]: """ Calculates the signal-to-noise ratio (SNR) for each channel and identifies those below a defined threshold. Parameters ---------- data : BaseRaw The loaded data object to process. Returns ------- tuple[list[str], Figure] - list[str]: A list of channel names that fall below the SNR threshold and are considered bad. - Figure: A matplotlib Figure showing the channels' SNR values. """ logger.info("Calculating signal to noise ratio...") # Compute the signal-to-noise ratio values logger.info("Computing the signal to noise power...") signal_band=(0.01, 0.5) noise_band=(1.0, 10.0) data_signal = data.copy().filter(*signal_band, verbose=False) #type: ignore data_noise = data.copy().filter(*noise_band, verbose=False) #type: ignore signal_power = np.mean(data_signal.get_data()**2, axis=1) #type: ignore noise_power = np.mean(data_noise.get_data()**2, axis=1) #type: ignore # Calculate the snr using the standard formula for dB snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps)) # TODO: Understand what this does groups: dict[str, list[str]] = {} for ch in getattr(data, "ch_names"): # Look for the space in the channel names and remove the characters after # This is so we can get both oxy and deoxy to remove, as they will have the same source and detector base = ch.rsplit(' ', 1)[0] groups.setdefault(base, []).append(ch) # type: ignore # If any of the channels do not meet our threshold, they will get inserted into the bad_channels set bad_channels: set[str] = set() for base, ch_list in groups.items(): if any(s < SNR_THRESHOLD for s, ch in zip(snr, getattr(data, "ch_names")) if ch in ch_list): bad_channels.update(ch_list) # Design and create the figure logger.info("Creating the figure...") snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained") # type: ignore colors = [(0/20, 'red'), (SNR_THRESHOLD/20, 'red'), ((SNR_THRESHOLD+.5)/20, 'yellow'), ((SNR_THRESHOLD+1)/20, 'green'), (20/20, 'green')] cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors) norm = mcolors.Normalize(vmin=0, vmax=20) scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm) # type: ignore ax.set(xlabel="Channel Number", ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, 20]) ax.axhline(SNR_THRESHOLD, color='black', linestyle='--', alpha=0.3, linewidth=1) # type: ignore cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)") # type: ignore cbar.set_ticks([0, SNR_THRESHOLD, SNR_THRESHOLD+1, 20]) # type: ignore cbar.set_ticklabels(['0', str(SNR_THRESHOLD), str(SNR_THRESHOLD+1), '20']) # type: ignore plt.close() logger.info("Successfully calculated signal to noise ratio.") return list(bad_channels), snr_fig def mark_bad_channels(data: BaseRaw, ID: str, bad_channels_sci: set[str], bad_channels_psp: set[str], bad_channels_snr: set[str]) -> tuple[BaseRaw, Figure, int]: """ Drops bad channels from the data and generates a bar plot showing which channels were removed and why. Parameters ---------- data : BaseRaw The loaded data object to process. ID : str File name of the the snirf file that was loaded. bad_channels_sci : set[str] Channels marked as bad by the SCI method. bad_channels_psp : set[str] Channels marked as bad by the PSP method. bad_channels_snr : set[str] Channels marked as bad by the SNR method. Returns ------- tuple[BaseRaw, Figure] - BaseRaw: The modified data object with bad channels removed. - Figure: A matplotlib Figure showing the dropped channels categorized by method. """ logger.info("Dropping the channels that were marked bad...") # Combine all of the bad channels into one and ensure the short channel is not present bad_channels = bad_channels_sci | bad_channels_psp | bad_channels_snr logger.info(f"Channels that were bad on SCI: {bad_channels_sci}") logger.info(f"Channels that were bad on PSP: {bad_channels_psp}") logger.info(f"Channels that were bad on SNR: {bad_channels_snr}") logger.info(f"Total bad channels: {bad_channels}") # Add the channles to the bads key and drop the bads key from the data data.info["bads"] = list(bad_channels) data = cast(BaseRaw, data.drop_channels(getattr(data, "info")["bads"])) # type: ignore # Organize channels into categories sets = [ (bad_channels_sci, "SCI"), (bad_channels_psp, "PSP"), (bad_channels_snr, "SNR"), ] # Graph what channels were dropped and why they were dropped channel_categories: dict[str, str] = {} for ch in bad_channels: present_in = [name for s, name in sets if ch in s] # Create a label for the category if len(present_in) == 1: label = f"{present_in[0]} only" else: label = " + ".join(sorted(present_in)) channel_categories[ch] = label # Sort channels alphabetically within categories for nicer visualization logger.info("Sorting the bad channels by type...") categories = sorted(set(channel_categories.values())) channel_names: list[str] = [] category_labels: list[str] = [] for cat in categories: chs_in_cat = sorted([ch for ch, c in channel_categories.items() if c == cat]) channel_names.extend(chs_in_cat) category_labels.extend([cat] * len(chs_in_cat)) colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories} logger.info("Creating the figure...") # Create the figure fig, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3))) # type: ignore y_pos = range(len(channel_names)) ax.barh(y_pos, [1]*len(channel_names), color=[colors[cat] for cat in category_labels]) # type: ignore ax.set_yticks(y_pos) # type: ignore ax.set_yticklabels(channel_names) # type: ignore ax.set_xlabel("Marked as Bad") # type: ignore ax.set_title(f"Bad Channels by Method for {ID}") # type: ignore ax.set_xlim(0, 1) ax.set_xticks([]) # type: ignore ax.grid(axis='x', linestyle='--', alpha=0.7) # type: ignore # Add a legend denoting why the channels were bad for label, color in colors.items(): ax.bar(0, 0, color=color, label=label) # type: ignore ax.legend() # type: ignore fig.tight_layout() plt.close(fig) logger.info("Successfully dropped the channels that were marked bad.") return data, fig, len(bad_channels) def calculate_optical_density(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]: """ Converts raw intensity data to optical density and generates a plot of the transformed signals. Parameters ---------- data : BaseRaw The loaded data object to process. ID : str File name of the the snirf file that was loaded. Returns ------- tuple[BaseRaw, Figure] - BaseRaw: The transformed data in optical density format. - Figure: A matplotlib figure displaying the optical density signals across all channels. """ logger.info("Calculating optical density...") # Calculate the optical density from the raw data optical_density_data = cast(BaseRaw, optical_density(data)) logger.info("Creating the figure...") fig = cast(Figure, optical_density_data.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=getattr(data, "times")[-1]).figure) # type: ignore fig.suptitle(f"Optical density data for {ID}", fontsize=16) # type: ignore fig.subplots_adjust(top=0.92) plt.close(fig) logger.info("Successfully calculated optical density.") return optical_density_data, fig # STEP 9: Haemoglobin concentration def calculate_haemoglobin_concentration(optical_density_data: BaseRaw, ID: str, file_path: str) -> tuple[BaseRaw, Figure]: """ Calculates haemoglobin concentration from optical density data using the Beer-Lambert law and generates a plot. Parameters ---------- optical_density_data : BaseRaw The data in optical density format. ID : str File name of the the snirf file that was loaded. file_path : str Entire file path if snirf file that was loaded. Returns ------- tuple[BaseRaw, Figure] - BaseRaw: The haemoglobin concentration data object. - Figure: A matplotlib figure displaying the haemoglobin concentration signals. """ logger.info("Calculating haemoglobin concentration data...") # Get the haemoglobin concentration using beer lambert law haemoglobin_concentration_data = beer_lambert_law(optical_density_data, ppf=calculate_dpf(file_path)) logger.info("Creating the figure...") fig = cast(Figure, optical_density_data.plot(show=False, n_channels=len(getattr(optical_density_data, "ch_names")), duration=getattr(optical_density_data, "times")[-1]).figure) # type: ignore fig.suptitle(f"Haemoglobin concentration data for {ID}", fontsize=16) # type: ignore fig.subplots_adjust(top=0.92) plt.close(fig) logger.info("Successfully calculated haemoglobin concentration data.") return haemoglobin_concentration_data, fig # -------------------------------------- HARDCODED ----------------------------------------------- def extract_normal_epochs(haemoglobin_concentration_data: BaseRaw) -> dict[str, list[Any] | mne.evoked.EvokedArray]: events, _ = mne.events_from_annotations(haemoglobin_concentration_data, event_id={"Reach": 1, "Start of Rest": 2}, verbose=VERBOSITY) # type: ignore event_dict = {"Reach": 1, "Start of Rest": 2} epochs = mne.Epochs( haemoglobin_concentration_data, events, event_id=event_dict, tmin=TIME_MIN_THRESH, tmax=TIME_MAX_THRESH, reject=dict(hbo=EPOCH_REJECT_CRITERIA_THRESH), reject_by_annotation=True, proj=True, baseline=(None, 0), preload=True, detrend=None, verbose=VERBOSITY, ) evoked_dict: dict[str, list[Any] | mne.evoked.EvokedArray] = { "Reach/HbO": epochs["Reach"].average(picks="hbo"), # type: ignore "Reach/HbR": epochs["Reach"].average(picks="hbr"), # type: ignore } # Rename channels until the encoding of frequency in ch_name is fixed for condition in evoked_dict: evoked_dict[condition].rename_channels(lambda x: x[:-4]) # type: ignore return evoked_dict def calculate_and_apply_negative_correlation_enhancement(haemoglobin_concentration_data: BaseRaw) -> dict[str, list[Any] | mne.evoked.EvokedArray]: events, _ = mne.events_from_annotations(haemoglobin_concentration_data, event_id={"Reach": 1, "Start of Rest": 2}, verbose=VERBOSITY) # type: ignore event_dict = {"Reach": 1, "Start of Rest": 2} raw_anti = enhance_negative_correlation(haemoglobin_concentration_data) epochs_anti = mne.Epochs( raw_anti, events, event_id=event_dict, tmin=TIME_MIN_THRESH, tmax=TIME_MAX_THRESH, reject=dict(hbo=EPOCH_REJECT_CRITERIA_THRESH), reject_by_annotation=True, proj=True, baseline=(None, 0), preload=True, detrend=None, verbose=VERBOSITY, ) evoked_dict_anti: dict[str, list[Any] | mne.evoked.EvokedArray] = { "Reach/HbO": epochs_anti["Reach"].average(picks="hbo"), # type: ignore "Reach/HbR": epochs_anti["Reach"].average(picks="hbr"), # type: ignore } # Rename channels until the encoding of frequency in ch_name is fixed for condition in evoked_dict_anti: evoked_dict_anti[condition].rename_channels(lambda x: x[:-4]) # type: ignore return evoked_dict_anti def calculate_and_apply_short_channel_correction(optical_density_data: BaseRaw, file_path: str) -> dict[str, list[Any] | mne.evoked.EvokedArray]: od_corrected = short_channel_regression(optical_density_data, SHORT_CHANNEL_THRESH) haemoglobin_concentration_data = beer_lambert_law(od_corrected, ppf=calculate_dpf(file_path)) events, _ = mne.events_from_annotations(haemoglobin_concentration_data, event_id={"Reach": 1, "Start of Rest": 2}, verbose=VERBOSITY) # type: ignore event_dict = {"Reach": 1, "Start of Rest": 2} epochs_corr = mne.Epochs( haemoglobin_concentration_data, events, event_id=event_dict, tmin=TIME_MIN_THRESH, tmax=TIME_MAX_THRESH, reject=dict(hbo=EPOCH_REJECT_CRITERIA_THRESH), reject_by_annotation=True, proj=True, baseline=(None, 0), preload=True, detrend=None, verbose=VERBOSITY, ) evoked_dict_corr: dict[str, list[Any] | mne.evoked.EvokedArray] = { "Reach/HbO": epochs_corr["Reach"].average(picks="hbo"), # type: ignore "Reach/HbR": epochs_corr["Reach"].average(picks="hbr"), # type: ignore } # Rename channels until the encoding of frequency in ch_name is fixed for condition in evoked_dict_corr: evoked_dict_corr[condition].rename_channels(lambda x: x[:-4]) # type: ignore return evoked_dict_corr def signal_enhancement_techniques_images(evoked_dict: dict[str, list[Any] | mne.evoked.EvokedArray], evoked_dict_anti: dict[str, list[Any] | mne.evoked.EvokedArray], evoked_dict_corr:dict[str, list[Any] | mne.evoked.EvokedArray] | None): # If we have two images, ensure we only have two columns if evoked_dict_corr is None: fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 6)) # type: ignore # If we have three images, ensure we have three columns else: fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 6)) # type: ignore color_dict = dict(HbO="#AA3377", HbR="b") # TODO: This is to prevent the warning that we are only plotting one channel. Don't we want all though? mne.set_log_level('WARNING') # type: ignore logger.info("Creating the figure...") # Plot the graph for the original data mne.viz.plot_compare_evokeds( # type: ignore evoked_dict, combine="mean", ci=0.95, # type: ignore axes=axes[0], colors=color_dict, ylim=dict(hbo=[-10, 15]), show=False, ) # Plot the graph for the enhanced anticorrelation data mne.viz.plot_compare_evokeds( # type: ignore evoked_dict_anti, combine="mean", ci=0.95, # type: ignore axes=axes[1], colors=color_dict, ylim=dict(hbo=[-10, 15]), show=False, ) # Plot the graph for short channel regression data, if it exists if evoked_dict_corr is not None: mne.viz.plot_compare_evokeds( # type: ignore evoked_dict_corr, combine="mean", ci=0.95, # type: ignore axes=axes[2], colors=color_dict, ylim=dict(hbo=[-10, 15]), show=False, ) mne.set_log_level('INFO') # type: ignore # If we have a short channel, set three titles if evoked_dict_corr is not None: for column, condition in enumerate( ["Original Data", "With Enhanced Anticorrelation", "With Short Regression"] ): axes[column].set_title(f"{condition}") # If we do not have a short channel, set two titles else: for column, condition in enumerate( ["Original Data", "With Enhanced Anticorrelation"] ): axes[column].set_title(f"{condition}") plt.close(fig) return fig def create_design_matrix(data: BaseRaw, stim_duration: float, short_chans: BaseRaw | None) -> tuple[DataFrame, Figure]: """ Creates a design matrix for first-level analysis including optional short channel regression, and generates a plot. Parameters ---------- data : BaseRaw The loaded data object to process. stim_duration : float Duration of the stimulus/event in seconds. short_chans : BaseRaw | None Data object containing only short channels for systemic component regression, or None if there is no short channels. Returns ------- tuple[DataFrame, Figure] - DataFrame: The generated design matrix. - Figure: A matplotlib figure visualizing the design matrix. """ # Create the design martix logger.info("Creating the design matrix... (This may take some time)") # If the design matrix is fir, calculate some of the extra required parameters before creating the matrix if HRF_MODEL == "fir": sfreq = getattr(data, "info")["sfreq"] fir_delays = range(int(sfreq*15)) design_matrix = make_first_level_design_matrix( data, stim_dur=0.1, hrf_model=HRF_MODEL, drift_model=DRIFT_MODEL, high_pass=1/(2*DURATION_BETWEEN_ACTIVITIES), fir_delays=fir_delays ) # Using a canonical hrf model else: design_matrix = make_first_level_design_matrix( data, stim_dur=stim_duration, hrf_model=HRF_MODEL, drift_model=DRIFT_MODEL, high_pass=1/(2*DURATION_BETWEEN_ACTIVITIES), ) # If we have a short channel, and short channel regression was specified, apply it to the design matrix if short_chans is not None: if SHORT_CHANNEL_REGRESSION: logger.info("Applying short channel regression...") for chan in range(len(short_chans.ch_names)): # type: ignore design_matrix[f"short_{chan}"] = short_chans.get_data(chan).T # type: ignore logger.info("Creating the figure...") fig, ax1 = plt.subplots(figsize=(10, 6), constrained_layout=True) # type: ignore plot_design_matrix(design_matrix, axes=ax1) plt.close(fig) logger.info("Successfully created the design matrix.") return design_matrix, fig def run_GLM_analysis(data: BaseRaw, design_matrix: DataFrame) -> RegressionResults: """ Runs a General Linear Model (GLM) analysis on the provided data using the specified design matrix. Parameters ---------- data : BaseRaw The loaded data object to process. design_matrix : DataFrame The design matrix specifying regressors for the GLM. Returns ------- RegressionResults The fitted GLM results object containing regression coefficients and statistics. """ logger.info("Running the GLM...") glm_est = run_glm(data, design_matrix, n_jobs=N_JOBS) logger.info("Successfully ran the GLM.") return glm_est def calculate_dpf(file_path): # order is hbo / hbr import h5py with h5py.File(file_path, 'r') as f: wavelengths = f['/nirs/probe/wavelengths'][:] print("Wavelengths (nm):", wavelengths) wavelengths = sorted(wavelengths, reverse=True) data = METADATA.get(file_path) if data is None: age = 25 else: age = data['Age'] logger.info(age) age = float(age) a = 223.3 b = 0.05624 c = 0.8493 d = -5.723e-7 e = 0.001245 f = -0.9025 dpf = [] for w in wavelengths: logger.info(w) dpf.append(a + b * (age**c) + d* (w**3) + e * (w**2) + f*w) logger.info(dpf) return dpf def individual_GLM_analysis(file_path: str, ID: str, stim_duration: float = 5.0, progress_callback=None) -> tuple[BaseRaw, BaseRaw, DataFrame, DataFrame, DataFrame, DataFrame, dict[str, Figure], str, bool, bool]: """ Performs individual-level General Linear Model (GLM) analysis on fNIRS data from a SNIRF file. Parameters ---------- file_path : str Path to the SNIRF file containing the participant's raw data. ID : str Unique identifier for the participant, used for labeling output. stim_duration : float, optional Duration of the stimulus in seconds for constructing the design matrix (default is 5.0) Returns ------- tuple[BaseRaw, BaseRaw, DataFrame, DataFrame, DataFrame, DataFrame, dict[str, Figure], str, bool, bool] - BaseRaw: Processed fNIRS data - BaseRaw: Full layout raw data prior to bad channel rejection - DataFrame: Region of interest statistics - DataFrame: Channel-level GLM statistics - DataFrame: Contrast results - DataFrame: Design matrix used for GLM - dict[str, Figure]: Dictionary of figures generated during processing - str: Description of processing steps applied - bool: Whether the GLM successfully ran to completion - bool: Whether the analysis result is valid based on quality checks """ # Setting up variables to be used later fig_dict: dict[str, Figure] = {} bad_channels_sci = [] bad_channels_psp = [] bad_channels_snr = [] mean_hr_nk = 70 mean_hr_scipy = 70 num_bad_channels = 0 valid = True short_chans = None roi: DataFrame = DataFrame() cha: DataFrame = DataFrame() con: DataFrame = DataFrame() design_matrix = DataFrame() # Load the file, get the sources and detectors, update their position, and calculate the short channel and any large distance channels # STEP 1 data, fig = load_snirf(file_path, ID, FORCE_DROP_CHANNELS) fig_dict['Raw'] = fig order_of_operations = "Loaded Raw File" if progress_callback: progress_callback(1) # Initalize the participants full layout to be the current data regardless if it will be updated later raw_full_layout = data logger.info(file_path) logger.info(ID) logger.info(METADATA.get(file_path)) calculate_dpf(file_path) try: # Did the user want to load new channel positions from an optode file? # STEP 2 if OPTODE_FILE: data = calculate_and_apply_updated_optode_coordinates(data) order_of_operations += " + Updated Optode Placements" if progress_callback: progress_callback(2) # STEP 2.5 # TODO: remember why i do this # I think its because i want a participants whole layout to plot without any bads # but i shouldnt need to do od and bll just check the last three numbers temp = data.copy() temp_od = cast(BaseRaw, optical_density(temp, verbose=VERBOSITY)) raw_full_layout = beer_lambert_law(temp_od, ppf=calculate_dpf(file_path)) # If specified, apply TDDR to the data # STEP 3 if TDDR: data, fig = calculate_and_apply_tddr(data, ID) order_of_operations += " + TDDR Filter" fig_dict['TDDR'] = fig if progress_callback: progress_callback(3) # If specified, apply a wavelet filter to the data # STEP 4 if WAVELET: data, fig = calculate_and_apply_wavelet(data, ID) order_of_operations += " + Wavelet Filter" fig_dict['Wavelet'] = fig if progress_callback: progress_callback(4) # If specified, attempt to get short channels from the data # STEP 4.5 if SHORT_CHANNEL: try: short_chans = get_short_channels(data, SHORT_CHANNEL_THRESH) except Exception as e: raise ProcessingError("SHORT_CHANNEL was specified, but no short channel was found. Please ensure the data has a short channel and that SHORT_CHANNEL_THRESH is set correctly.") pass else: pass # Ensure that there is no short or really long channels in the data data = get_long_channels(data, SHORT_CHANNEL_THRESH, LONG_CHANNEL_THRESH) # STEP 5 if HEART_RATE: sfreq, signal_trimmed, times_trimmed = short_channel_processing_for_hr(data, short_chans) hr_smooth_nk, mean_hr_nk = calculate_heart_rate_neurokit(sfreq, signal_trimmed) freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy = calculate_heart_rate_scipy(sfreq, signal_trimmed) # HACK: This sucks but looking at the graphs I trust neurokit2 more overruled = False if mean_hr_scipy < mean_hr_nk - 15: mean_hr_scipy = mean_hr_nk overruled = True if mean_hr_scipy > mean_hr_nk + 15: mean_hr_scipy = mean_hr_nk overruled = True hr1, hr2 = plot_heart_rate(freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy, hr_smooth_nk, mean_hr_nk, times_trimmed, overruled) order_of_operations += " + Heart Rate Calculation" fig_dict['HeartRate_PSD'] = hr1 fig_dict['HeartRate_Time'] = hr2 if progress_callback: progress_callback(5) # If specified, calculate and apply SCI # STEP 6 if SCI: bad_channels_sci, sci1, sci2 = calculate_scalp_coupling(data.copy(), min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW) / 60, max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW) / 60) order_of_operations += " + SCI Calculation" fig_dict['SCI1'] = sci1 fig_dict['SCI2'] = sci2 if progress_callback: progress_callback(6) # If specified, calculate and apply PSP if PSP: bad_channels_psp, psp1, psp2 = calculate_peak_power(data.copy(), min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW) / 60, max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW) / 60) order_of_operations += " + PSP Calculation" fig_dict['PSP1'] = psp1 fig_dict['PSP2'] = psp2 if progress_callback: progress_callback(7) # If specified, calculate and apply SNR if SNR: bad_channels_snr, fig = calculate_signal_noise_ratio(data.copy()) order_of_operations += " + SNR Calculation" fig_dict['SNR'] = fig # If specified, drop channels that were marked as bad # STEP 7 if EXCLUDE_CHANNELS: data, fig, num_bad_channels = mark_bad_channels(data, ID, set(bad_channels_sci), set(bad_channels_psp), set(bad_channels_snr)) order_of_operations += " + Excluded Bad Channels" fig_dict['Bads'] = fig if progress_callback: progress_callback(7) # Calculate the optical density # STEP 8 data, fig = calculate_optical_density(data, ID) order_of_operations += " + Optical Density" fig_dict['OpticalDensity'] = fig if progress_callback: progress_callback(8) # Mainly for visualization. Could be implemented in the future # STEP 8.5 evoked_dict_corr = None if SHORT_CHANNEL: short_chans_od = cast(BaseRaw, optical_density(short_chans)) data_recombined = cast(BaseRaw, data.copy().add_channels([short_chans_od])) # type: ignore evoked_dict_corr = calculate_and_apply_short_channel_correction(data_recombined.copy(), file_path) # Calculate the haemoglobin concentration # STEP 9 data, fig = calculate_haemoglobin_concentration(data, ID, file_path) order_of_operations += " + Haemoglobin Concentration" fig_dict['HaemoglobinConcentration'] = fig if progress_callback: progress_callback(9) # Mainly for visualization. Could be implemented in the future # STEP 9.5 evoked_dict = extract_normal_epochs(data.copy()) evoked_dict_anti = calculate_and_apply_negative_correlation_enhancement(data.copy()) fig = signal_enhancement_techniques_images(evoked_dict, evoked_dict_anti, evoked_dict_corr) fig_dict['SignalEnhancement'] = fig # Create the design matrix # STEP 10 # HACK FIXME - Downsampling to 10 is certaintly not the best way... right? if HRF_MODEL == 'fir': data.resample(10, verbose=VERBOSITY) # type: ignore if short_chans is not None: short_chans.resample(10, verbose=VERBOSITY) # type: ignore design_matrix, fig = create_design_matrix(data, stim_duration, short_chans) order_of_operations += " + Design Matrix" fig_dict['DesignMatrix'] = fig if progress_callback: progress_callback(10) # Run the glm on the design matrix # STEP 11 glm_est: RegressionResults = run_GLM_analysis(data, design_matrix) order_of_operations += " + GLM" if progress_callback: progress_callback(11) # Add the regions of interest to the groups # STEP 12 logger.info("Performing the finishing touches...") order_of_operations += " + Finishing Touches" # Extract the channel metrics logger.info("Calculating channel results...") cha = cast(DataFrame, glm_est.to_dataframe()) # type: ignore logger.info("Creating groups...") if HRF_MODEL == "fir": groups = dict(AllChannels=range(len(data.ch_names))) # type: ignore else: groups: dict[str, list[int]] = dict( group_1_picks = picks_pair_to_idx(data, ROI_GROUP_1, on_missing="ignore"), # type: ignore group_2_picks = picks_pair_to_idx(data, ROI_GROUP_2, on_missing="ignore"), # type: ignore ) # Compute region of interest results from the channel data logger.info("Calculating region of intrest results...") roi = glm_est.to_dataframe_region_of_interest(groups, design_matrix.columns, demographic_info=True) # type: ignore # Create the contrast matrix logger.info("Creating the contrast matrix...") contrast_matrix = np.eye(design_matrix.shape[1]) basic_conts = dict( [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)] ) # Calculate contrast differently depending on the hrf model if HRF_MODEL == 'fir': # Find all FIR regressors for TARGET_ACTIVITY delay_cols = [col for col in design_matrix.columns if col.startswith(f"{TARGET_ACTIVITY}_delay_")] if not delay_cols: raise ValueError(f"No FIR regressors found for condition {TARGET_ACTIVITY}.") # Sum or average their contrast vectors fir_contrast = np.sum([basic_conts[col] for col in delay_cols], axis=0) fir_contrast /= len(delay_cols) # Compute contrast contrast = glm_est.compute_contrast(fir_contrast) # type: ignore con = cast(DataFrame, contrast.to_dataframe()) # type: ignore else: # Create and compute the contrast contrast_t = basic_conts[TARGET_ACTIVITY] contrast = glm_est.compute_contrast(contrast_t) # type: ignore con = cast(DataFrame, contrast.to_dataframe()) # type: ignore # Add the participant ID to the dataframes roi["ID"] = cha["ID"] = con["ID"] = design_matrix["ID"] = ID # Convert to uM for nicer plotting below. logger.info("Converting to uM...") cha["theta"] = cha["theta"].astype(float) * 1.0e6 roi["theta"] = roi["theta"].astype(float) * 1.0e6 con["effect"] = con["effect"].astype(float) * 1.0e6 # If we exceed the maximum allowed bad channels, apply an X over the figures logger.info("Checking amount of bad channels...") if num_bad_channels >= MAX_BAD_CHANNELS: valid=False logger.info("Drawing some big X's...") for _, fig in fig_dict.items(): add_x_overlay(fig, 'Too many bad channels!', 'red') logger.info("Completed individual analysis.") if progress_callback: progress_callback(12) # Clear the output for the next participant unless we are told to be verbose if not VERBOSITY: clear_output(wait=True) # Something really went wrong and we should not continue except ProcessingError as e: logger.info("An error occured!", e) raise # Something went wrong at one of the steps. Return what data we gathered, but set the validity of this run to False except Exception as e: logger.info("An error occured!", e) fig_dict_bytes = convert_fig_dict_to_png_bytes(fig_dict) return data, raw_full_layout, roi, cha, con, design_matrix, fig_dict, order_of_operations, False, False fig_dict_bytes = convert_fig_dict_to_png_bytes(fig_dict) return data, raw_full_layout, roi, cha, con, design_matrix, fig_dict_bytes, order_of_operations, True, valid def add_x_overlay(fig: Figure, reason: str, color: str) -> None: """ Adds a large 'X' across the figure if the participant met the bad channel criteria. Parameters ---------- fig : Figure Matplotlib figure to draw the X on. reason: str Why the X is being drawn. color: str What color the reason should be. """ # Draw the big X on the graph ax = fig.add_axes([0, 0, 1, 1], zorder=100) # type: ignore ax.set_axis_off() ax.plot([0, 1], [0, 1], color='red', linewidth=8, transform=fig.transFigure, clip_on=False) # type: ignore ax.plot([0, 1], [1, 0], color='red', linewidth=8, transform=fig.transFigure, clip_on=False) # type: ignore ax.text(0.5, 0.5, reason, color=color, fontsize=26, fontweight='bold', ha='center', va='center', transform=fig.transFigure, zorder=101, bbox=dict(facecolor='white', alpha=0.8, edgecolor='red', boxstyle='round,pad=0.4')) # type: ignore from io import BytesIO def convert_fig_dict_to_png_bytes(fig_dict): result = {} for key, fig in fig_dict.items(): buf = BytesIO() fig.savefig(buf, format='png') buf.seek(0) result[key] = buf.read() return result def process_file_worker(args): file_path, file_name, stim_duration, config, gui, progress_queue = args try: set_config(config, gui) def progress_callback(step_idx): print(f"[Worker] Step {step_idx} for {file_name}") if progress_queue: progress_queue.put(('progress', file_name, step_idx)) result = individual_GLM_analysis( file_path, file_name, stim_duration, progress_callback=progress_callback ) return file_name, result, None except Exception as e: return file_name, None, e def process_folder(folder_path: str, stim_duration: float, files_remaining: dict[str, int], config , gui: bool = False, progress_queue=None) -> tuple[dict[str, dict[str, BaseRaw]], DataFrame, DataFrame, DataFrame, DataFrame, dict[str, list[Figure]], dict[str, str]]: df_roi = DataFrame() df_cha = DataFrame() df_con = DataFrame() df_design_matrix = DataFrame() raw_haemo_dict: dict[str, dict[str, BaseRaw]] = {} process_dict: dict[str, str] = {} figures_by_step: dict[str, list[Figure]] = { step: [] for step in [ 'Raw', 'TDDR', 'Wavelet', 'HeartRate_PSD', 'HeartRate_Time', 'SCI1', 'SCI2', 'PSP1', 'PSP2', 'SNR', 'Bads', 'OpticalDensity', 'HaemoglobinConcentration', 'SignalEnhancement', 'DesignMatrix' ] } file_args = [ (os.path.join(folder_path, file_name), file_name, stim_duration, config, gui, progress_queue) for file_name in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, file_name)) ] print("[process_folder] File args:", file_args) available_mem = psutil.virtual_memory().available if (MAX_WORKERS >= available_mem / (1024 ** 3)): print(f"WARNING: You have set MAX_WORKERS to {MAX_WORKERS}. Each worker should have at least 1GB of system memory. Your device currently has a total of {available_mem / (1024 ** 3):.2f}GB free.\nPlease consider lowering MAX_WORKERS to prevent potential crashing due to insufficient system memory.") with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor: future_to_file = { executor.submit(process_file_worker, args): args[1] for args in file_args } with tqdm(total=len(file_args), desc="Processing files") as pbar: for future in as_completed(future_to_file): file_name = future_to_file[future] files_remaining['count'] -= 1 logger.info(f"Files remaining: {files_remaining['count']}") pbar.update(1) try: file_name, result, error = future.result() if error: logger.info(f"Error processing {file_name}: {error}") continue raw_haemo_filtered, raw_haemo_full, roi, channel, contrast, design_matrix, fig_dict, process, finished, valid = result if finished and valid: logger.info(f"Finished processing {file_name}. This participant was valid.") raw_haemo_dict[file_name] = { "filtered": raw_haemo_filtered, "full_layout": raw_haemo_full } process_dict[file_name] = process for step in figures_by_step: if step in fig_dict: figures_by_step[step].append(fig_dict[step]) df_roi = pd.concat([df_roi, roi], ignore_index=True) df_cha = pd.concat([df_cha, channel], ignore_index=True) df_con = pd.concat([df_con, contrast], ignore_index=True) df_design_matrix = pd.concat([df_design_matrix, design_matrix], ignore_index=True) else: logger.info(f"Finished processing {file_name}. This participant was NOT valid.") if SEE_BAD_IMAGES: for step in figures_by_step: if step in fig_dict: figures_by_step[step].append(fig_dict[step]) except Exception as e: logger.info(f"Unexpected error processing {file_name}: {e}") raise return raw_haemo_dict, df_roi, df_cha, df_con, df_design_matrix, figures_by_step, process_dict def verify_channel_positions(data: BaseRaw) -> None: """ Visualizes the sensor/channel positions of the raw data for verification. Parameters ---------- data : BaseRaw The loaded data object to process. """ logger.info("Creating the figure...") data.plot_sensors(show_names=True, to_sphere=True, show=False, verbose=VERBOSITY) # type: ignore plt.show() # type: ignore def plot_3d_evoked_array( inst: Union[BaseRaw, EvokedArray, Info], statsmodel_df: DataFrame, picks: Optional[Union[str, list[str]]] = "hbo", value: str = "Coef.", background: str = "w", figure: Optional[object] = None, clim: Union[str, dict[str, Union[str, list[float]]]] = "auto", mode: str = "weighted", colormap: str = "RdBu_r", surface: str = "pial", hemi: str = "both", size: int = 800, view: Optional[Union[str, dict[str, float]]] = None, colorbar: bool = True, distance: float = 0.03, subjects_dir: Optional[str] = None, src: Optional[SourceSpaces] = None, verbose: bool = False, ) -> Brain: '''Ported from MNE''' info: Info = cast(Info, deepcopy(inst if isinstance(inst, Info) else inst.info)) # type: ignore if not (getattr(info, "ch_names") == list(statsmodel_df["ch_name"].values)): # type: ignore raise RuntimeError( 'MNE data structure does not match dataframe ' f'results.\nMNE = {getattr(info, "ch_names")}.\n' f'GLM = {list(statsmodel_df["ch_name"].values)}' # type: ignore ) ea = EvokedArray(np.tile(statsmodel_df[value].values.T, (1, 1)).T, info.copy()) # type: ignore # TODO: mimic behaviour of other MNE-NIRS glm plotting options if picks is not None: ea = ea.pick(picks=picks) # type: ignore if subjects_dir is None: subjects_dir = os.environ["SUBJECTS_DIR"] if src is None: fname_src_fs = os.path.join( subjects_dir, "fsaverage", "bem", "fsaverage-ico-5-src.fif" ) src = read_source_spaces(fname_src_fs, verbose=verbose) picks = getattr(ea, "info")["ch_names"] # Set coord frame for idx in range(len(getattr(ea, "ch_names"))): getattr(ea, "info")["chs"][idx]["coord_frame"] = 4 # Generate source estimate kwargs = dict( evoked=ea, subject="fsaverage", trans=Transform('head', 'mri', np.eye(4)), distance=distance, mode=mode, surface=surface, subjects_dir=subjects_dir, src=src, project=True, ) stc = stc_near_sensors(picks=picks, **kwargs, verbose=verbose) # type: ignore from mne import SourceEstimate assert isinstance(stc, SourceEstimate) # or your specific subclass # Produce brain plot brain: Brain = stc.plot( # type: ignore src=src, subjects_dir=subjects_dir, hemi=hemi, surface=surface, initial_time=0, clim=clim, # type: ignore size=size, colormap=colormap, figure=figure, background=background, colorbar=colorbar, verbose=verbose, ) if view is not None: brain.show_view(view) # type: ignore return brain def brain_3d_visualization(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_haemo: dict[str, dict[str, BaseRaw]], participant_number: int, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None: # Determine if we are visualizing t or theta to set the appropriate limit if t_or_theta == 't': clim = dict(kind="value", pos_lims=(0, ABS_T_VALUE/2, ABS_T_VALUE)) elif t_or_theta == 'theta': clim = dict(kind="value", pos_lims=(0, ABS_THETA_VALUE/2, ABS_THETA_VALUE)) # Loop over all groups for index, group_name in enumerate(all_results): # We only care for their channel results (_, df_cha, _, _) = all_results[group_name] # Get all activity conditions for cond in [TARGET_ACTIVITY]: if HRF_MODEL == 'fir': ch_summary = df_cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python') # type: ignore else: # Filter for the condition and chromophore ch_summary = df_cha.query("Condition in [@cond] and Chroma == 'hbo'") # type: ignore # Determine number of unique participants based on their ID n_participants = ch_summary["ID"].nunique() # WE JUST NEED SOMEONES OPTODE DATA TO PLOT ON THE BRAIN! # TODO: This should take the average positions of all participants # We will just take the passed through parameter participant_to_plot = ch_summary["ID"].unique()[participant_number] # type: ignore participant_raw_full: BaseRaw = all_haemo[participant_to_plot]["full_layout"] # Use ordinary least squares (OLS) if only one participant if n_participants == 1: # t values if t_or_theta == 't': ch_model = smf.ols("t ~ -1 + ch_name", ch_summary).fit() # type: ignore # theta values elif t_or_theta == 'theta': ch_model = smf.ols("theta ~ -1 + ch_name", ch_summary).fit() # type: ignore logger.info("OLS model is being used as there is only one participant!") # Use mixed effects model if there is multiple participants else: # t values if t_or_theta == 't': ch_model = smf.mixedlm("t ~ -1 + ch_name", ch_summary, groups=ch_summary["ID"]).fit(method="nm") # type: ignore # theta values elif t_or_theta == 'theta': ch_model = smf.mixedlm("theta ~ -1 + ch_name", ch_summary, groups=ch_summary["ID"]).fit(method="nm") # type: ignore # Convert model results model_df = cast(DataFrame, statsmodels_to_results(ch_model, order=ch_summary["ch_name"].unique())) # type: ignore valid_channels = ch_summary["ch_name"].unique().tolist() # type: ignore raw_for_plot = participant_raw_full.copy().pick(picks=valid_channels) # type: ignore brain = plot_3d_evoked_array(raw_for_plot.pick(picks="hbo"), model_df, view="dorsal", distance=BRAIN_DISTANCE, colorbar=True, clim=clim, mode=BRAIN_MODE, size=(800, 700)) # type: ignore if show_optodes == 'all' or show_optodes == 'sensors': brain.add_sensors(getattr(raw_for_plot, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY) # type: ignore # Read and parse the file if show_optodes == 'all' or show_optodes == 'labels': positions: list[tuple[str, list[float]]] = [] with open(OPTODE_FILE_PATH, 'r') as f: for line in f: line = line.strip() if not line or ':' not in line: continue # skip empty/malformed lines name, coords = line.split(':', 1) coords = [float(x) for x in coords.strip().split()] positions.append((name.strip(), coords)) for name, (x, y, z) in positions: brain._renderer.text3d(x, y, z, name, color=('red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'), scale=0.002) # type: ignore # Set the display text for the brain image # display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nGroup: ' + group_name + '\nCondition: '+ cond + '\nReject Criteria Threshold: ' + str(EPOCH_REJECT_CRITERIA_THRESH) + '\nMin Time Threshold: ' # + str(TIME_MIN_THRESH) + 's\nMax Time Threshold: ' + str(TIME_MAX_THRESH) + 's\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) + '\nStim Duration: ' # + str(STIM_DURATION[index]) + 's\nLooking at: ' + t_or_theta + ' values') + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE if HRF_MODEL == 'fir': display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nGroup: ' + group_name + '\nCondition: '+ cond + '\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) + '\nLooking at: ' + t_or_theta + ' values') + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE else: display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nGroup: ' + group_name + '\nCondition: '+ cond + '\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) + '\nStim Duration: ' + str(STIM_DURATION[index]) + '\nLooking at: ' + t_or_theta + ' values') + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE # Apply the text onto the brain if show_text: brain.add_text(0.12, 0.64, display_text, "title", font_size=11, color="k") # type: ignore def plot_fir_model_results(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_haemo: dict[str, dict[str, BaseRaw]], participant_number: int, t_or_theta: Literal['t', 'theta'] = 'theta') -> None: if HRF_MODEL != 'fir': logger.info("This method only works when HRF_MODEL is set to 'fir'.") else: for group_name in all_results: (df_roi, _, _, df_design_matrix) = all_results[group_name] first_id = df_design_matrix["ID"].unique()[participant_number] # type: ignore first_dm = df_design_matrix.query(f"ID == '{first_id}'").copy() # type: ignore first_dm.index = np.round([0.1 * i for i in range(len(first_dm))], decimals=1) # type: ignore df_design_matrix = first_dm participant = all_haemo[first_id]["full_layout"] df_roi["isActivity"] = [TARGET_ACTIVITY in n for n in df_roi["Condition"]] # type: ignore df_roi["isDelay"] = ["delay" in n for n in df_roi["Condition"]] # type: ignore df_roi = df_roi.query("isDelay in [True]") # type: ignore df_roi = df_roi.query("isActivity in [True]") # type: ignore df_roi.loc[:, "TidyCond"] = "" df_roi.loc[df_roi["isActivity"] == True, "TidyCond"] = TARGET_ACTIVITY # noqa: E712 # Finally, extract the FIR delay in to its own column in data frame df_roi.loc[:, "delay"] = [n.split("_")[-1] for n in df_roi.Condition] # type: ignore if t_or_theta == 'theta': lme = smf.mixedlm("theta ~ -1 + delay:TidyCond:Chroma", df_roi, groups=df_roi["ID"]).fit() # type: ignore elif t_or_theta == 't': lme = smf.mixedlm("t ~ -1 + delay:TidyCond:Chroma", df_roi, groups=df_roi["ID"]).fit() # type: ignore df_sum: DataFrame = statsmodels_to_results(lme) # type: ignore df_sum["delay"] = [int(n) for n in df_sum["delay"]] # type: ignore df_sum = df_sum.sort_values("delay") # type: ignore # logger.info the result for the oxyhaemoglobin data in the Reach condition df_sum.query(f"TidyCond in ['{TARGET_ACTIVITY}']").query("Chroma in ['hbo']") # type: ignore axes1: list[Axes] fig, axes1 = plt.subplots(nrows=1, ncols=3, figsize=(20, 10)) # type: ignore # Extract design matrix columns that correspond to the condition of interest dm_cond_idxs = np.where([TARGET_ACTIVITY in n for n in df_design_matrix.columns])[0] dm_cond_colnames: list[str] = [df_design_matrix.columns[i] for i in dm_cond_idxs] dm_cond = df_design_matrix[dm_cond_colnames] # 2. Extract hbo GLM estimates df_hbo = df_sum.query(f"TidyCond in ['{TARGET_ACTIVITY}']").query("Chroma in ['hbo']") # type: ignore vals_hbo = [float(v) for v in df_hbo["Coef."]] # type: ignore dm_cond_scaled_hbo = dm_cond * vals_hbo # 3. Extract hbr GLM estimates df_hbr = df_sum.query(f"TidyCond in ['{TARGET_ACTIVITY}']").query("Chroma in ['hbr']") # type: ignore vals_hbr = [float(v) for v in df_hbr["Coef."]] # type: ignore dm_cond_scaled_hbr = dm_cond * vals_hbr # Extract the time scale for plotting. # Set time zero to be the onset. index_values = cast(NDArray[float64], dm_cond_scaled_hbo.index.to_numpy(dtype=float) - participant.annotations.onset[1]) # type: ignore # Plot the result axes1[0].plot(index_values, np.asarray(dm_cond)) # type: ignore axes1[1].plot(index_values, np.asarray(dm_cond_scaled_hbo)) # type: ignore axes1[2].plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r") # type: ignore axes1[2].plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b") # type: ignore valid_mask = (index_values >= 0) & (index_values <= 15) hbo_sum_window = np.sum(dm_cond_scaled_hbo.loc[valid_mask, :], axis=1) peak_idx_in_window = np.argmax(hbo_sum_window) peak_idx = np.where(valid_mask)[0][peak_idx_in_window] peak_time = float(round(index_values[peak_idx], 2)) # type: ignore axes1[2].axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time') # type: ignore # Format the plot for ax in range(3): axes1[ax].set_xlim(-5, 20) axes1[ax].set_xlabel("Time (s)") # type: ignore axes1[0].set_ylim(-0.2, 1.2) axes1[1].set_ylim(-4, 8) axes1[2].set_ylim(-4, 8) axes1[0].set_title(f"FIR Model for {group_name} (Unscaled by GLM {TARGET_ACTIVITY} estimates) ({t_or_theta})") # type: ignore axes1[1].set_title(f"FIR Components for {group_name} (Scaled by GLM {TARGET_ACTIVITY} estimates) ({t_or_theta})") # type: ignore axes1[2].set_title(f"Evoked Response for {group_name} ({TARGET_ACTIVITY}) ({t_or_theta})") # type: ignore axes1[0].set_ylabel("FIR Model") # type: ignore axes1[1].set_ylabel("Oyxhaemoglobin (ΔμMol)") # type: ignore axes1[2].set_ylabel("Haemoglobin (ΔμMol)") # type: ignore axes1[2].legend(["Oyxhaemoglobin", "Deoyxhaemoglobin", f"Peak {peak_time}s"]) # type: ignore fig.tight_layout() # We can also extract the 95% confidence intervals of the estimates too l95_hbo = [float(v) for v in df_hbo["[0.025"]] # type: ignore u95_hbo = [float(v) for v in df_hbo["0.975]"]] # type: ignore dm_cond_scaled_hbo_l95 = dm_cond * l95_hbo dm_cond_scaled_hbo_u95 = dm_cond * u95_hbo l95_hbr = [float(v) for v in df_hbr["[0.025"]] # type: ignore u95_hbr = [float(v) for v in df_hbr["0.975]"]] # type: ignore dm_cond_scaled_hbr_l95 = dm_cond * l95_hbr dm_cond_scaled_hbr_u95 = dm_cond * u95_hbr axes2: Axes fig, axes2 = plt.subplots(nrows=1, ncols=1, figsize=(7, 7)) # type: ignore # Plot the result axes2.plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r") # type: ignore axes2.plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b") # type: ignore axes2.axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time') # type: ignore axes2.fill_between( # type: ignore index_values, np.asarray(np.sum(dm_cond_scaled_hbo_l95, axis=1)), np.asarray(np.sum(dm_cond_scaled_hbo_u95, axis=1)), facecolor="red", alpha=0.25, ) axes2.fill_between( # type: ignore index_values, np.asarray(np.sum(dm_cond_scaled_hbr_l95, axis=1)), np.asarray(np.sum(dm_cond_scaled_hbr_u95, axis=1)), facecolor="blue", alpha=0.25, ) # Format the plot axes2.set_xlim(-5, 20) axes2.set_ylim(-8, 12) axes2.set_title(f"Evoked Response with 95% confidence intervals for {group_name} ({TARGET_ACTIVITY}) ({t_or_theta})") # type: ignore axes2.set_ylabel("Haemoglobin (ΔμMol)") # type: ignore axes2.legend(["Oyxhaemoglobin", "Deoyxhaemoglobin", f"Peak {peak_time}s"]) # type: ignore axes2.set_xlabel("Time (s)") # type: ignore fig.tight_layout() plt.show() # type: ignore def plot_2d_theta_graph(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None: '''This method will create a 2d boxplot showing the theta values for each channel and group as independent ranges on the same graph.\n Inputs:\n all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n ''' # Create a list to store the channel results of all groups df_all_cha_list: list[DataFrame] = [] # Iterate over each group in all_results for group_name, (_, df_cha, _, _) in all_results.items(): df_cha["group"] = group_name # Add the group name to the data df_all_cha_list.append(df_cha) # Append the dataframe to the list # Combine all the data into a single DataFrame df_all_cha = pd.concat(df_all_cha_list, ignore_index=True) # Filter for the target activity if HRF_MODEL == 'fir': df_target = df_all_cha[df_all_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_")] # type: ignore else: df_target = df_all_cha[df_all_cha["Condition"] == TARGET_ACTIVITY] # Get the number of unique groups to know how many colors are needed for the boxplot unique_groups = df_target["group"].nunique() palette = sns.color_palette("Set2", unique_groups) # Create the boxplot fig = plt.figure(figsize=(15, 6)) # type: ignore sns.boxplot( data=df_target, x="ch_name", y="theta", hue="group", palette=palette ) # Format the boxplot plt.title("Theta Coefficients by Channel and Group") # type: ignore plt.xticks(rotation=90) # type: ignore plt.ylabel("Theta (µM)") # type: ignore plt.xlabel("Channel") # type: ignore plt.legend(title="Group") # type: ignore plt.tight_layout() plt.show() # type: ignore def plot_individual_theta_averages(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None: if HRF_MODEL == 'fir': logger.info("This method does not work when HRF_MODEL is set to 'fir'.") return else: # Iterate over all the groups for group_name in all_results: # Store the region of interest data (df_roi, _, _, _) = all_results[group_name] # Filter the results down to what we want grp_results = df_roi.query(f"Condition in ['{TARGET_ACTIVITY}', '{TARGET_CONTROL}']").copy() # type: ignore grp_results = grp_results.query("Chroma in ['hbo']").copy() # type: ignore # Rename the ROI's to be the friendly name roi_label_map = { "group_1_picks": ROI_GROUP_1_NAME, "group_2_picks": ROI_GROUP_2_NAME, } grp_results["ROI"] = grp_results["ROI"].replace(roi_label_map) # type: ignore # Create the catplot sns.catplot( x="Condition", y="theta", col="ID", hue="ROI", data=grp_results, col_wrap=5, errorbar=None, palette="muted", height=4, s=10, dodge=False, ) plt.show() # type: ignore def plot_group_theta_averages(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None: '''This method will create a stripplot showing the theta vaules for each region of interest for each group.\n Inputs:\n all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n''' if HRF_MODEL == 'fir': logger.info("This method does not work when HRF_MODEL is set to 'fir'.") return else: # Rename the ROI's to be the friendly name roi_label_map = { "group_1_picks": ROI_GROUP_1_NAME, "group_2_picks": ROI_GROUP_2_NAME, } # Setup subplot grid n = len(all_results) ncols = 2 nrows = (n + 1) // ncols # round up fig, axes = cast(tuple[Figure, np.ndarray[Any, Any]], plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 5 * nrows), squeeze=False)) # type: ignore index = -1 # Iterate over all groups for index, (group_name, ax) in enumerate(zip(all_results, axes.flatten())): # Store the region of interest data (df_roi, _, _, _) = all_results[group_name] # Filter the results down to what we want grp_results = df_roi.query(f"Condition in ['{TARGET_ACTIVITY}', '{TARGET_CONTROL}']").copy() # type: ignore # Run a mixedlm model on the data roi_model = smf.mixedlm("theta ~ -1 + ROI:Condition:Chroma", grp_results, groups=grp_results["ID"]).fit(method="nm") # type: ignore # Apply the new friendly names on to the data df = cast(DataFrame, statsmodels_to_results(roi_model)) df["ROI"] = df["ROI"].map(roi_label_map) # type: ignore # Create a stripplot: sns.stripplot( x="Condition", y="Coef.", hue="ROI", data=df.query("Chroma == 'hbo'"), # type: ignore dodge=False, jitter=False, size=5, palette="muted", ax=ax, ) # Format the stripplot ax.set_title(f"Results for {group_name}") ax.legend(title="ROI", loc="upper right") if index == -1: # No groups, so remove all axes for ax in axes.flatten(): fig.delaxes(ax) # Remove any unused axes and apply final touches else: for j in range(index + 1, len(axes.flatten())): fig.delaxes(axes.flatten()[j]) fig.tight_layout() fig.suptitle("Theta Averages Across Groups", fontsize=16, y=1.02) # type: ignore plt.show() # type: ignore def compute_p_group_stats(df_cha: DataFrame, bad_pairs: set[tuple[int, int]], t_or_theta: Literal['t', 'theta'] = 't') -> DataFrame: if HRF_MODEL == 'fir': # Filter: All delays for the target activity df_activity = df_cha[df_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_") & (df_cha['Chroma'] == 'hbo')] # type: ignore # Aggregate across FIR delays *per subject* for each channel df_agg = (df_activity.groupby(['Source', 'Detector', 'ID'])[['t', 'theta']].mean().reset_index()) # type: ignore else: # Canonical HRF case df_agg = df_cha[(df_cha['Condition'] == TARGET_ACTIVITY) & (df_cha['Chroma'] == 'hbo')].copy() # Filter the channel data down to what we want grouped = cast(Iterator[tuple[tuple[int, int], Any]], df_agg.groupby(['Source', 'Detector'])) # type: ignore # Create an empty list to store the data for our result data: list[dict[str, Any]] = [] # Iterate over the filtered channel data for (src, det), group in grouped: # If it is a bad channel pairing, do not process it if (src, det) in bad_pairs: logger.info(f"Skipping bad channel Source {src} - Detector {det}") continue # Drop any missing values that could exist t_values = group['t'].dropna().values t_values = np.array(t_values, dtype=float) theta_values = group['theta'].dropna().values theta_values = np.array(theta_values, dtype=float) # Ensure that we still have our two t values, otherwise do not process this pairing # TODO: is the t values throwing a warning good enough? if len(t_values) < 2: logger.info(f"Skipping Source {src} - Detector {det}: not enough data (n={len(t_values)})") continue # NOTE: This is still calculated with t values as it is a t-test # Perform one-sample t-test on t-values across subjects shitte, pval = ttest_1samp(t_values, popmean=0) print(shitte) # Store all of the data for this ttest using the mean t-value for visualization if t_or_theta == 't': data.append({ 'Source': src, 'Detector': det, 't_or_theta': np.mean(t_values), 'p_value': pval }) else: data.append({ 'Source': src, 'Detector': det, 't_or_theta': np.mean(theta_values), 'p_value': pval }) # Create a DataFrame with the data and ensure it is not empty result = DataFrame(data) if result.empty: logger.info("No valid channel pairs with enough data for group-level testing.") return result def get_bad_src_det_pairs(raw: BaseRaw) -> set[tuple[int, int]]: '''This method figures out the bad source and detector pairings for the 2d t+p graph to prevent them from being plotted. Inputs:\n raw (RawSNIRF) - Contains all the snirf data for the last participant processed. Only used to get the channels\n Outputs:\n bad_pairs (set) - Set containing all of the bad pairs of sources and detectors''' # Create a set to store the bad pairs bad_pairs: set[tuple[int, int]] = set() # Iterate over all the channels in bads key for ch_name in getattr(raw, "info")["bads"]: try: # Get all characters before the space parts = ch_name.split()[0] # Split with the separator src_str, det_str = parts.split(SOURCE_DETECTOR_SEPARATOR) src = int(src_str[1:]) det = int(det_str[1:]) # Add to the set bad_pairs.add((src, det)) except Exception as e: logger.info(f"Could not parse bad channel '{ch_name}': {e}") return bad_pairs def plot_avg_significant_activity(raw: BaseRaw, all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], t_or_theta: Literal['t', 'theta'] = 't') -> None: '''This method plots the average t values for the groups on a 2D graph. p values less than or equal to P_THRESHOLD are solid lines, while greater p values are dashed lines.\n Inputs:\n raw (RawSNIRF) - Contains all the snirf data for the last participant processed. Only used to get the channel locations.\n all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n''' # Iterate over all the groups for group_name in all_results: (_, df_cha, _, _) = all_results[group_name] if HRF_MODEL == 'fir': mask = df_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_") & (df_cha['Chroma'] == 'hbo') # type: ignore filtered_df = df_cha[mask] num_tests = filtered_df.groupby(['Source', 'Detector']).ngroups # type: ignore else: num_tests = len(cast(Iterator[tuple[tuple[int, int], Any]], df_cha.query(f"Condition == '{TARGET_ACTIVITY}' and Chroma == 'hbo'").groupby(['Source', 'Detector']))) # type: ignore logger.info(f"Number of tests: {num_tests}") # Compute average t-value across individuals for each channel pairing bad_pairs = get_bad_src_det_pairs(raw) avg_df = compute_p_group_stats(df_cha, bad_pairs, t_or_theta) logger.info(f"Average {t_or_theta}-values and p-values for {TARGET_ACTIVITY}:") for _, row in avg_df.iterrows(): # type: ignore logger.info(f"Source {row['Source']} <-> Detector {row['Detector']}: " f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}") # Extract the cource and detector positions from raw src_pos: dict[int, tuple[float, float]] = {} det_pos: dict[int, tuple[float, float]] = {} for ch in getattr(raw, "info")["chs"]: ch_name = ch['ch_name'] if not ch_name or not ch['loc'].any(): continue parts = ch_name.split()[0] src_str, det_str = parts.split(SOURCE_DETECTOR_SEPARATOR) src_num = int(src_str[1:]) det_num = int(det_str[1:]) src_pos[src_num] = ch['loc'][3:5] det_pos[det_num] = ch['loc'][6:8] # Set up the plot fig, ax = plt.subplots(figsize=(8, 6)) # type: ignore # Plot the sources for pos in src_pos.values(): ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3) # type: ignore # Plot the detectors for pos in det_pos.values(): ax.scatter(pos[0], pos[1], s=120, c='k', marker='s', edgecolors='white', linewidths=1, zorder=3) # type: ignore # Ensure that the colors stay within the boundaries even if they are over or under the max/min values if t_or_theta == 't': norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_T_VALUE, vmax=ABS_SIGNIFICANCE_T_VALUE) elif t_or_theta == 'theta': norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_THETA_VALUE, vmax=ABS_SIGNIFICANCE_THETA_VALUE) cmap: mcolors.Colormap = plt.get_cmap('seismic') # Plot connections with avg t-values for row in avg_df.itertuples(): src: int = cast(int, row.Source) # type: ignore det: int = cast(int, row.Detector) # type: ignore tval: float = cast(float, row.t_or_theta) # type: ignore pval: float = cast(float, row.p_value) # type: ignore if src in src_pos and det in det_pos: x = [src_pos[src][0], det_pos[det][0]] y = [src_pos[src][1], det_pos[det][1]] style = '-' if pval <= P_THRESHOLD else '--' ax.plot(x, y, linestyle=style, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2) # type: ignore # Format the Colorbar sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) sm.set_array([]) cbar = plt.colorbar(sm, ax=ax, shrink=0.85) # type: ignore cbar.set_label(f'Average {TARGET_ACTIVITY} {t_or_theta} value (hbo)', fontsize=11) # type: ignore # Formatting the subplots ax.set_aspect('equal') ax.set_title(f"Average {t_or_theta} values for {TARGET_ACTIVITY} (HbO) for {group_name}", fontsize=14) # type: ignore ax.set_xlabel('X position (m)', fontsize=11) # type: ignore ax.set_ylabel('Y position (m)', fontsize=11) # type: ignore ax.grid(True, alpha=0.3) # type: ignore # Set axis limits to be 1cm more than the optode positions all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()] all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()] ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01) ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01) fig.tight_layout() plt.show() # type: ignore def generate_montage_locations(): """Get standard MNI montage locations in dataframe. Data is returned in the same format as the eeg_positions library. """ # standard_1020 and standard_1005 are in MNI (fsaverage) space already, # but we need to undo the scaling that head_scale will do montage = mne.channels.make_standard_montage( "standard_1005", head_size=0.09700884729534559 ) for d in montage.dig: d["coord_frame"] = 2003 montage.dig[:] = montage.dig[3:] montage.add_mni_fiducials() # now in fsaverage space coords = pd.DataFrame.from_dict(montage.get_positions()["ch_pos"]).T coords["label"] = coords.index coords = coords.rename(columns={0: "x", 1: "y", 2: "z"}) return coords.reset_index(drop=True) def _find_closest_standard_location(position, reference, *, out="label"): """Return closest montage label to coordinates. Parameters ---------- position : array, shape (3,) Coordinates. reference : dataframe As generated by _generate_montage_locations. trans_pos : str Apply a transformation to positions to specified frame. Use None for no transformation. """ from scipy.spatial.distance import cdist p0 = np.array(position) p0.shape = (-1, 3) # head_mri_t, _ = _get_trans("fsaverage", "head", "mri") # p0 = apply_trans(head_mri_t, p0) dists = cdist(p0, np.asarray(reference[["x", "y", "z"]], float)) if out == "label": min_idx = np.argmin(dists) return reference["label"][min_idx] else: assert out == "dists" return dists def _source_detector_fold_table(raw, cidx, reference, fold_tbl, interpolate): src = raw.info["chs"][cidx]["loc"][3:6] det = raw.info["chs"][cidx]["loc"][6:9] ref_lab = list(reference["label"]) dists = _find_closest_standard_location([src, det], reference, out="dists") src_min, det_min = np.argmin(dists, axis=1) src_name, det_name = ref_lab[src_min], ref_lab[det_min] tbl = fold_tbl.query("Source == @src_name and Detector == @det_name") dist = np.linalg.norm(dists[[0, 1], [src_min, det_min]]) # Try reversing source and detector if len(tbl) == 0: tbl = fold_tbl.query("Source == @det_name and Detector == @src_name") if len(tbl) == 0 and interpolate: # Try something hopefully not too terrible: pick the one with the # smallest net distance good = np.isin(fold_tbl["Source"], reference["label"]) & np.isin( fold_tbl["Detector"], reference["label"] ) assert good.any() tbl = fold_tbl[good] assert len(tbl) src_idx = [ref_lab.index(src) for src in tbl["Source"]] det_idx = [ref_lab.index(det) for det in tbl["Detector"]] # Original tot_dist = np.linalg.norm([dists[0, src_idx], dists[1, det_idx]], axis=0) assert tot_dist.shape == (len(tbl),) idx = np.argmin(tot_dist) dist_1 = tot_dist[idx] src_1, det_1 = ref_lab[src_idx[idx]], ref_lab[det_idx[idx]] # And the reverse tot_dist = np.linalg.norm([dists[0, det_idx], dists[1, src_idx]], axis=0) idx = np.argmin(tot_dist) dist_2 = tot_dist[idx] src_2, det_2 = ref_lab[det_idx[idx]], ref_lab[src_idx[idx]] if dist_1 < dist_2: new_dist, src_use, det_use = dist_1, src_1, det_1 else: new_dist, src_use, det_use = dist_2, det_2, src_2 tbl = fold_tbl.query("Source == @src_use and Detector == @det_use") tbl = tbl.copy() tbl["BestSource"] = src_name tbl["BestDetector"] = det_name tbl["BestMatchDistance"] = dist tbl["MatchDistance"] = new_dist assert len(tbl) else: tbl = tbl.copy() tbl["BestSource"] = src_name tbl["BestDetector"] = det_name tbl["BestMatchDistance"] = dist tbl["MatchDistance"] = dist tbl = tbl.copy() # don't get warnings about setting values later return tbl from mne.utils import _check_fname, _validate_type, warn def _read_fold_xls(fname, atlas="Juelich"): """Read fOLD toolbox xls file. The values are then manipulated in to a tidy dataframe. Note the xls files are not included as no license is provided. Parameters ---------- fname : str Path to xls file. atlas : str Requested atlas. """ page_reference = {"AAL2": 2, "AICHA": 5, "Brodmann": 8, "Juelich": 11, "Loni": 14} tbl = pd.read_excel(fname, sheet_name=page_reference[atlas]) # Remove the spacing between rows empty_rows = np.where(np.isnan(tbl["Specificity"]))[0] tbl = tbl.drop(empty_rows).reset_index(drop=True) # Empty values in the table mean its the same as above for row_idx in range(1, tbl.shape[0]): for col_idx, col in enumerate(tbl.columns): if not isinstance(tbl[col][row_idx], str): if np.isnan(tbl[col][row_idx]): tbl.iloc[row_idx, col_idx] = tbl.iloc[row_idx - 1, col_idx] tbl["Specificity"] = tbl["Specificity"] * 100 tbl["brainSens"] = tbl["brainSens"] * 100 return tbl import os.path as op def _check_load_fold(fold_files, atlas): # _validate_type(fold_files, (list, "path-like", None), "fold_files") if fold_files is None: fold_files = mne.get_config("MNE_NIRS_FOLD_PATH") if fold_files is None: raise ValueError( "MNE_NIRS_FOLD_PATH not set, either set it using " "mne.set_config or pass fold_files as str or list" ) if not isinstance(fold_files, list): # path-like fold_files = _check_fname( fold_files, overwrite="read", must_exist=True, name="fold_files", need_dir=True, ) fold_files = [op.join(fold_files, f"10-{x}.xls") for x in (5, 10)] fold_tbl = pd.DataFrame() for fi, fname in enumerate(fold_files): fname = _check_fname( fname, overwrite="read", must_exist=True, name=f"fold_files[{fi}]" ) fold_tbl = pd.concat( [fold_tbl, _read_fold_xls(fname, atlas=atlas)], ignore_index=True ) return fold_tbl def fold_channel_specificity_normal(raw, fold_files=None, atlas="Juelich", interpolate=False): """Return the landmarks and specificity a channel is sensitive to. Parameters """ # noqa: E501 _validate_type(raw, BaseRaw, "raw") reference_locations = generate_montage_locations() fold_tbl = _check_load_fold(fold_files, atlas) chan_spec = list() for cidx in range(len(raw.ch_names)): tbl = _source_detector_fold_table( raw, cidx, reference_locations, fold_tbl, interpolate ) chan_spec.append(tbl.reset_index(drop=True)) return chan_spec def fold_channels(raw: BaseRaw, all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], fold_path: str) -> None: # Locate the fOLD excel files mne.set_config('MNE_NIRS_FOLD_PATH', fold_path) # type: ignore # Iterate over all of the groups for group_name in all_results: output = None # List to store the results landmark_specificity_data: list[dict[str, Any]] = [] # Filter the data to only what we want hbo_channel_names = cast(list[str], getattr(raw.copy().pick(picks='hbo'), "ch_names")) # type: ignore # Format the output to make it slightly easier to read logger.info("*" * 60) logger.info(f'Landmark Specificity for {group_name}:') logger.info("*" * 60) if GUI: num_channels = len(hbo_channel_names) rows, cols = 4, 7 # 6 rows and 4 columns of pie charts fig, axes = plt.subplots(rows, cols, figsize=(16, 10), constrained_layout=True) axes = axes.flatten() # Flatten the axes array for easier indexing # If more pie charts than subplots, create extra subplots if num_channels > rows * cols: fig, axes = plt.subplots((num_channels // cols) + 1, cols, figsize=(16, 10), constrained_layout=True) axes = axes.flatten() # Create a list for consistent color mapping landmarks = [ "1 - Primary Somatosensory Cortex", "2 - Primary Somatosensory Cortex", "3 - Primary Somatosensory Cortex", "4 - Primary Motor Cortex", "5 - Somatosensory Association Cortex", "6 - Pre-Motor and Supplementary Motor Cortex", "7 - Somatosensory Association Cortex", "8 - Includes Frontal eye fields", "9 - Dorsolateral prefrontal cortex", "10 - Frontopolar area", "11 - Orbitofrontal area", "17 - Primary Visual Cortex (V1)", "18 - Visual Association Cortex (V2)", "19 - V3", "20 - Inferior Temporal gyrus", "21 - Middle Temporal gyrus", "22 - Superior Temporal Gyrus", "23 - Ventral Posterior cingulate cortex", "24 - Ventral Anterior cingulate cortex", "25 - Subgenual cortex", "32 - Dorsal anterior cingulate cortex", "37 - Fusiform gyrus", "38 - Temporopolar area", "39 - Angular gyrus, part of Wernicke's area", "40 - Supramarginal gyrus part of Wernicke's area", "41 - Primary and Auditory Association Cortex", "42 - Primary and Auditory Association Cortex", "43 - Subcentral area", "44 - pars opercularis, part of Broca's area", "45 - pars triangularis Broca's area", "46 - Dorsolateral prefrontal cortex", "47 - Inferior prefrontal gyrus", "48 - Retrosubicular area", "Brain_Outside", ] cmap1 = plt.cm.get_cmap('tab20') # First 20 colors cmap2 = plt.cm.get_cmap('tab20b') # Next 20 colors # Combine the colors from both colormaps colors = [cmap1(i) for i in range(20)] + [cmap2(i) for i in range(20)] # Total 40 colors landmarks.sort(key=lambda x: (int(x.split(" - ")[0]) if x.split(" - ")[0].isdigit() else float('inf'))) landmark_color_map = {landmark: colors[i % len(colors)] for i, landmark in enumerate(landmarks)} # Iterate over each channel for idx, channel_name in enumerate(hbo_channel_names): # Run the fOLD on the selected channel channel_data = raw.copy().pick(picks=channel_name) # type: ignore output = cast(list[DataFrame], fold_channel_specificity_normal(channel_data, interpolate=True, atlas='Brodmann')) # Process each DataFrame that fold_channel_specificity returns for df_data in output: # Extract the relevant columns useful_data = df_data[['Landmark', 'Specificity']] # Store the results landmark_specificity_data.append({ 'Channel': channel_name, 'Data': useful_data, }) # logger.info the results for data in landmark_specificity_data: logger.info(f"Channel: {data['Channel']}") logger.info(f"{data['Data']}") logger.info("-" * 60) # If PLOT_ENABLED is True, plot the results if GUI: unique_landmarks = sorted(useful_data['Landmark'].unique()) color_list = [landmark_color_map[landmark] for landmark in useful_data['Landmark']] # Plot specificity for each channel ax = axes[idx] # Use the correct axis for this channel labels = [f'{landmark.split(" - ")[0]}' if landmark != 'Brain_Outside' else 'B' for landmark in useful_data['Landmark']] wedges, texts, autotexts = ax.pie( useful_data['Specificity'], autopct='%1.1f%%', startangle=90, labels=labels, # Add the labels here labeldistance=1.05, # Adjust label position to avoid overlap with the wedges colors=color_list) # Ensure color consistency ax.set_title(f'{channel_name}') ax.axis('equal') # Equal aspect ratio ensures the pie chart is circular. # Reset the list for the next particcipant landmark_specificity_data = [] if GUI: handles = [ plt.Line2D([0], [0], marker='o', color='w', label=landmark, markersize=10, markerfacecolor=landmark_color_map[landmark]) for landmark in landmarks ] n_landmarks = len(landmarks) # Calculate the figure size based on number of rows and columns fig_width = 5 fig_height = n_landmarks / 4 # Create a new figure window for the legend legend_fig = plt.figure(figsize=(fig_width, fig_height)) legend_axes = legend_fig.add_subplot(111) legend_axes.axis('off') # Turn off axis for the legend window legend_axes.legend(handles=handles, loc='center', fontsize=10, title="Landmarks") if GUI: for ax in axes[len(hbo_channel_names):]: ax.axis('off') plt.show() def brain_landmarks_3d(raw_haemo: BaseRaw, show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all') -> None: brain = Brain("fsaverage", background="white", size=(800, 700)) # type: ignore # Add optode text labels manually if show_optodes == 'all' or show_optodes == 'sensors': brain.add_sensors(getattr(raw_haemo, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY) # type: ignore # Read and parse the file if show_optodes == 'all' or show_optodes == 'labels': positions: list[tuple[str, list[float]]] = [] with open(OPTODE_FILE_PATH, 'r') as f: for line in f: line = line.strip() if not line or ':' not in line: continue # skip empty/malformed lines name, coords = line.split(':', 1) coords = [float(x) for x in coords.strip().split()] positions.append((name.strip(), coords)) for name, (x, y, z) in positions: brain._renderer.text3d(x, y, z, name, color=('red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'), scale=0.002) # type: ignore for ch in getattr(raw_haemo, "info")['chs']: logger.info(ch['ch_name'], ch['loc'][:3]) # Add Brodmann labels labels = cast(list[mne.Label], mne.read_labels_from_annot("fsaverage", "PALS_B12_Brodmann", "rh", verbose=VERBOSITY)) # type: ignore label_colors = { "Brodmann.39-rh": "blue", "Brodmann.40-rh": "green", "Brodmann.6-rh": "pink", "Brodmann.7-rh": "orange", "Brodmann.17-rh": "red", "Brodmann.1-rh": "yellow", "Brodmann.2-rh": "yellow", "Brodmann.3-rh": "yellow", "Brodmann.18-rh": "red", "Brodmann.19-rh": "red", "Brodmann.4-rh": "purple", "Brodmann.8-rh": "white" } for label in labels: name = getattr(label, "name", None) if not isinstance(name, str): continue if name in label_colors: brain.add_label(label, borders=False, color=label_colors[name]) # type: ignore def data_to_csv(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]): logger.info("Getting the current directory...") if PLATFORM_NAME == 'darwin': csvs_folder = os.path.join(os.path.dirname(sys.executable), "../../../csvs") else: cwd = os.getcwd() csvs_folder = os.path.join(cwd, "csvs") logger.info("Attempting to create the csvs folder...") os.makedirs(csvs_folder, exist_ok=True) # Generate a timestamp to be appended to the end of the file name logger.info("Generating the timestamp...") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # Iterate over all groups for group_name in all_results: # Get the channel data and generate the file name (_, df_cha, _, _) = all_results[group_name] filename = f"{group_name}_{timestamp}.csv" save_path = os.path.join(csvs_folder, filename) # Filter to just the target condition and store it in the csv if HRF_MODEL == 'fir': filtered_df = df_cha[ df_cha["Condition"].str.startswith(TARGET_ACTIVITY) & (df_cha["Chroma"] == "hbo") ] # Step 2: Define the aggregation logic agg_funcs = { 'df': 'mean', 'mse': 'mean', 'p_value': 'mean', 'se': 'mean', 't': 'mean', 'theta': 'mean', 'Source': 'mean', 'Detector': 'mean', 'Significant': lambda x: x.sum() > (len(x) / 2), 'Chroma': 'first', # assuming all are the same 'ch_name': 'first', # same ch_name in the group 'ID': 'first', # same ID in the group } # Step 3: Group and aggregate averaged_df = ( filtered_df .groupby(['ch_name', 'ID'], as_index=False) .agg(agg_funcs) ) # Step 4: Rename and add 'Condition' as TARGET_ACTIVITY averaged_df.insert(0, 'Condition', TARGET_ACTIVITY) averaged_df["df"] = averaged_df["df"].round().astype(int) averaged_df["Source"] = averaged_df["Source"].round().astype(int) averaged_df["Detector"] = averaged_df["Detector"].round().astype(int) # Step 5: Reset index and reorder columns ordered_cols = [ 'Condition', 'df', 'mse', 'p_value', 'se', 't', 'theta', 'Source', 'Detector', 'Chroma', 'Significant', 'ch_name', 'ID' ] averaged_df = averaged_df[ordered_cols].reset_index(drop=True) averaged_df = averaged_df.sort_values(by=["ID", "Detector", "Source"]).reset_index(drop=True) output_df = averaged_df else: output_df = df_cha.query(f"Condition == '{TARGET_ACTIVITY}' and Chroma == 'hbo'") # type: ignore output_df.to_csv(save_path) def all_data_to_csv(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]): logger.info("Getting the current directory...") if PLATFORM_NAME == 'darwin': csvs_folder = os.path.join(os.path.dirname(sys.executable), "../../../csvs") else: cwd = os.getcwd() csvs_folder = os.path.join(cwd, "csvs") logger.info("Attempting to create the csvs folder...") os.makedirs(csvs_folder, exist_ok=True) # Generate a timestamp to be appended to the end of the file name logger.info("Generating the timestamp...") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # Iterate over all groups for group_name in all_results: # Get the channel data and generate the file name (_, df_cha, _, _) = all_results[group_name] filename = f"{group_name}_{timestamp}_all.csv" save_path = os.path.join(csvs_folder, filename) # Filter to just the target condition and store it in the csv if HRF_MODEL == 'fir': output_df = df_cha else: output_df = df_cha.query(f"Condition == '{TARGET_ACTIVITY}' and Chroma == 'hbo'") # type: ignore output_df.to_csv(save_path) def brain_3d_contrast(con_model_df: DataFrame, con_model_df_filtered: BaseRaw, common_channels: list[str], first_name: str, second_name: str, first_stim: float, second_stim: float, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None: # Filter DataFrame to only common channels, and sort by raw order con_model = con_model_df con_model["ch_name"] = pd.Categorical( con_model["ch_name"], categories=common_channels, ordered=True ) con_model = con_model.sort_values("ch_name").reset_index(drop=True) # type: ignore if t_or_theta == 't': clim=dict(kind="value", pos_lims=(0, ABS_T_VALUE/2, ABS_T_VALUE)) elif t_or_theta == 'theta': clim=dict(kind="value", pos_lims=(0, ABS_THETA_VALUE/2, ABS_THETA_VALUE)) # Plot brain figure brain = plot_3d_evoked_array(con_model_df_filtered.copy().pick(picks="hbo"), con_model, view="dorsal", distance=BRAIN_DISTANCE, colorbar=True, mode=BRAIN_MODE, clim=clim, size=(800, 700), verbose=VERBOSITY) # type: ignore if show_optodes == 'all' or show_optodes == 'sensors': brain.add_sensors(getattr(con_model_df_filtered, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY) # type: ignore # Read and parse the file if show_optodes == 'all' or show_optodes == 'labels': positions: list[tuple[str, list[float]]] = [] with open(OPTODE_FILE_PATH, 'r') as f: for line in f: line = line.strip() if not line or ':' not in line: continue # skip empty/malformed lines name, coords = line.split(':', 1) coords = [float(x) for x in coords.strip().split()] positions.append((name.strip(), coords)) for name, (x, y, z) in positions: brain._renderer.text3d(x, y, z, name, color=('red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'), scale=0.002) # type: ignore # Set the display text for the brain image # display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nContrast: ' + first_name + ' - ' + second_name + '\nReject Criteria Threshold: ' + str(EPOCH_REJECT_CRITERIA_THRESH) + '\nMin Time Threshold: ' + # str(TIME_MIN_THRESH) + 's\nMax Time Threshold: ' + str(TIME_MAX_THRESH) + 's\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) + '\nStim Duration: ' + str(first_stim) + ', ' + # str(second_stim) + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE + '\nLooking at: ' + t_or_theta + ' values') if HRF_MODEL == 'fir': display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nContrast: ' + first_name + ' - ' + second_name + '\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE + '\nLooking at: ' + t_or_theta + ' values') else: display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nContrast: ' + first_name + ' - ' + second_name + '\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) + '\nStim Duration: ' + str(first_stim) + ', ' + str(second_stim) + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE + '\nLooking at: ' + t_or_theta + ' values') # Apply the text onto the brain if show_text: brain.add_text(0.12, 0.70, display_text, "title", font_size=11, color="k") # type: ignore def plot_2d_3d_contrasts_between_groups(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_raw_haemo: dict[str, dict[str, BaseRaw]], t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None: # Dictionary to store data for each group group_dfs: dict[str, DataFrame] = {} # GET RAW HAEMO OF THE FIRST PARTICIPANT raw_haemo = all_raw_haemo[list(all_raw_haemo.keys())[0]]["full_layout"] # Store all contrasts with the corresponding group name for group_name, (_, _, df_con, _) in all_results.items(): group_dfs[group_name] = df_con group_dfs[group_name]["group"] = group_name # Concatenate all groups together df_combined = pd.concat(group_dfs.values(), ignore_index=True) con_summary = df_combined.query("Chroma == 'hbo'").copy() # type: ignore valid_channels = cast(DataFrame, (pd.crosstab(con_summary['group'], con_summary['ch_name']) > 1).all()) # type: ignore valid_channels = valid_channels[valid_channels].index.tolist() # Filter data to only these channels con_summary = con_summary[con_summary['ch_name'].isin(valid_channels)] # type: ignore # # Verify your data looks as expected # logger.info(con_summary[['group', 'ch_name', 'Chroma', 'effect']].head()) # logger.info("\nUnique values:") # logger.info("Groups:", con_summary['group'].unique()) # logger.info("Channels:", con_summary['ch_name'].unique()) # logger.info("Chroma:", con_summary['Chroma'].unique()) # Should be just 'hbo' model_formula = "effect ~ -1 + group:ch_name:Chroma" con_model = smf.mixedlm(model_formula, con_summary, groups=con_summary["ID"]).fit(method="nm") # type: ignore # logger.info(con_model.summary()) # # Fit the mixed-effects model # model_formula = "effect ~ -1 + group:ch_name:Chroma" # #model_formula = "effect ~ -1 + group + ch_name" # con_model = smf.mixedlm( # model_formula, con_summary_filtered, groups=con_summary_filtered["ID"] # ).fit(method="nm") # Get the t values if we are comparing them t_values: pd.Series[float] = pd.Series(dtype=float) if t_or_theta == 't': t_values = con_model.tvalues # Get all the group names from the dictionary and how many groups we have group_names = list(group_dfs.keys()) n_groups = len(group_names) # Store DataFrames for each contrast for i in range(n_groups): for j in range(i + 1, n_groups): group1_name = group_names[i] group2_name = group_names[j] if t_or_theta == 't': # Extract the t-values for both groups group1_vals = t_values.filter(like=f"group[{group1_name}]") # type: ignore group2_vals = t_values.filter(like=f"group[{group2_name}]") # type: ignore vlim_value = ABS_CONTRAST_T_VALUE elif t_or_theta == 'theta': # Extract the coefficients for both groups group1_vals = con_model.params.filter(like=f"group[{group1_name}]") group2_vals = con_model.params.filter(like=f"group[{group2_name}]") vlim_value = ABS_CONTRAST_THETA_VALUE # TODO: Does this work for all separators? # Extract channel names group1_channels: list[str] = [ name.split(":")[1].split("[")[1].split("]")[0] for name in getattr(group1_vals, "index") ] group2_channels: list[str] = [ name.split(":")[1].split("[")[1].split("]")[0] for name in getattr(group2_vals, "index") ] # Create the DataFrames with channel indices df_group1 = DataFrame( {"Coef.": group1_vals.values}, index=group1_channels # type: ignore ) df_group2 = DataFrame( {"Coef.": group2_vals.values}, index=group2_channels # type: ignore ) # Merge the two DataFrames on the channel names df_contrast = df_group1.join(df_group2, how="inner", lsuffix=f"_{group1_name}", rsuffix=f"_{group2_name}") # type: ignore # Compute the contrasts contrast_1_2 = df_contrast[f"Coef._{group1_name}"] - df_contrast[f"Coef._{group2_name}"] contrast_2_1 = df_contrast[f"Coef._{group2_name}"] - df_contrast[f"Coef._{group1_name}"] # Add the a-b / 1-2 contrast to the DataFrame. The order and names of the keys in the DataFrame are important! df_contrast["Coef."] = contrast_1_2 con_model_df_1_2 = DataFrame({ "ch_name": df_contrast.index, "Coef.": df_contrast["Coef."], "Chroma": "hbo" }) mne_ch_names = getattr(raw_haemo.copy().pick(picks="hbo"), "ch_names") # type: ignore glm_ch_names = cast(list[DataFrame], con_model_df_1_2["ch_name"].tolist()) # Get ordered common channels common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names] # Filter raw data to these channels con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels) # type: ignore # Reindex GLM results to match MNE channel order con_model_df_1_2 = con_model_df_1_2.set_index("ch_name").loc[common_channels].reset_index() # type: ignore # Create the 3d visualization brain_3d_contrast(con_model_df_1_2, con_model_df_filtered, common_channels, group1_name, group2_name, STIM_DURATION[i], STIM_DURATION[j], t_or_theta, show_optodes, show_text) plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_1_2, names=True, res=128, vlim=(-vlim_value, vlim_value)) # type: ignore # TODO: The title currently goes on the colorbar. Low priority plt.title(f"Contrast: {group1_name} vs {group2_name}") # type: ignore plt.show() # type: ignore # Add the b-a / 2-1 contrast to the DataFrame. The order and names of the keys in the DataFrame are important! df_contrast["Coef."] = contrast_2_1 con_model_df_2_1 = DataFrame({ "ch_name": df_contrast.index, "Coef.": df_contrast["Coef."], "Chroma": "hbo" }) mne_ch_names = getattr(raw_haemo.copy().pick(picks="hbo"), "ch_names") # type: ignore glm_ch_names = cast(list[DataFrame], con_model_df_2_1["ch_name"].tolist()) # Get ordered common channels common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names] # Filter raw data to these channels con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels) # type: ignore # Reindex GLM results to match MNE channel order con_model_df_2_1 = con_model_df_2_1.set_index("ch_name").loc[common_channels].reset_index() # type: ignore # Create the 3d visualization brain_3d_contrast(con_model_df_2_1, con_model_df_filtered, common_channels, group2_name, group1_name, STIM_DURATION[j], STIM_DURATION[i], t_or_theta, show_optodes, show_text) plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_2_1, names=True, res=128, vlim=(-vlim_value, vlim_value)) # type: ignore # TODO: The title currently goes on the colorbar. Low priority plt.title(f"Contrast: {group2_name} vs {group1_name}") # type: ignore plt.show() # type: ignore # TODO: Is any of this still useful? def calculate_annotations(raw_haemo_filtered, file_name, output_folder=None, save_images=None): '''Method that extract the annotations from the data.\n Input:\n raw_haemo_filtered (RawSNIRF) - The filtered haemoglobin concentration data\n file_name (string) - The file name of the current file\n output_folder (string) - (optional) Where to save the images. Default is None\n save_images (string) - (optional) Bool to save the images or not. Default is None Output:\n events (ndarray) - Array containing row number and what index the event is\n event_dict (dict) - Contains the names of the events''' if output_folder is None: output_folder = None if save_images is None: save_images = None # Get when the events occur and what they are called, and display a figure with the result events, event_dict = mne.events_from_annotations(raw_haemo_filtered) # Do we save the image? if save_images: fig = mne.viz.plot_events(events, event_id=event_dict, sfreq=raw_haemo_filtered.info["sfreq"], show=False) save_path = output_folder + "/8. Annotations for " + file_name + ".png" fig.savefig(save_path) return events, event_dict def calculate_good_epochs(raw_haemo_filtered, events, event_dict, file_name, tmin=None, tmax=None, reject_thresh=None, target_activity=None, target_control=None, output_folder=None, save_images=None): '''Calculates what epochs are good and creates a graph showing if any are dropped.\n Input:\n raw_haemo_filtered (RawSNIRF) - The filtered haemoglobin concentration data\n events (ndarray) - Array containing row number and what index the event is\n event_dict (dict) - Contains the names of the events\n file_name (string) - The file name of the current file\n tmin (float) - (optional) Time in seconds to display before the event. Default is TIME_MIN_THRESH\n tmax (float) - (optional) Time in seconds to display after the event. Default is TIME_MAX_THRESH\n reject_thresh (float) - (optional) Value that determines the threshold for rejecting epochs. Default is EPOCH_REJECT_CRITERIA_THRESH\n target_activity (string) - (optional) The target activity. Default is TARGET_ACTIVITY\n target_control (string) - (optional) The target control. Default is TARGET_CONTROL\n output_folder (string) - (optional) Where to save the images. Default is None\n save_images (string) - (optional) Bool to save the images or not. Default is None Output:\n good_epochs (Epochs) - The remaining good epochs\n all_epochs (Epochs) - All of the epochs''' if tmin is None: tmin = TIME_MIN_THRESH if tmax is None: tmax = TIME_MAX_THRESH if reject_thresh is None: reject_thresh = EPOCH_REJECT_CRITERIA_THRESH if target_activity is None: target_activity = TARGET_ACTIVITY if target_control is None: target_control = TARGET_CONTROL if output_folder is None: output_folder = None if save_images is None: save_images = None # Get all the good epochs good_epochs = mne.Epochs( raw_haemo_filtered, events, event_id=event_dict, tmin=tmin, tmax=tmax, reject=dict(hbo=reject_thresh), reject_by_annotation=True, proj=True, baseline=(None, 0), preload=True, detrend=None, verbose=True, ) # Get all the epochs all_epochs = mne.Epochs( raw_haemo_filtered, events, event_id=event_dict, tmin=tmin, tmax=tmax, proj=True, baseline=(None, 0), preload=True, detrend=None, verbose=True, ) if REJECT_PAIRS: # Calculate which epochs were in all but not in good all_idx = all_epochs.selection good_idx = good_epochs.selection bad_idx = np.setdiff1d(all_idx, good_idx) # Split the controls and the activities event_ids = all_epochs.events[:, 2] control_id = event_dict[target_control] activity_id = event_dict[target_activity] to_reject_extra = set() for i, idx in enumerate(all_idx): if idx in bad_idx: ev = event_ids[i] # If the control was bad, drop the following activity if ev == control_id and i + 1 < len(all_idx): if event_ids[i + 1] == activity_id: to_reject_extra.add(all_idx[i + 1]) # If the activity was bad, drop the preceding activity if ev == activity_id and i - 1 >= 0: if event_ids[i - 1] == control_id: to_reject_extra.add(all_idx[i - 1]) # Create a list to store all the new drops, only adding them if they are currently classified as good drop_idx_in_good = [ np.where(good_idx == idx)[0][0] for idx in to_reject_extra if idx in good_idx ] # Drop the pairings of the bad epochs good_epochs.drop(drop_idx_in_good) # Do we save the image? if save_images: drop_log_fig = good_epochs.plot_drop_log(show=False) save_path = output_folder + "/8. Epoch drops for " + file_name + ".png" drop_log_fig.savefig(save_path) return good_epochs, all_epochs def bad_check(raw_od, max_bad_channels=None): '''Method to see if we have more bad channels than our allowed threshold.\n Inputs:\n raw_od (RawSNIRF) - The optical density data\n max_bad_channels (int) - (optional) The max amount of bad channels we want to tolerate. Default is MAX_BAD_CHANNELS\n Output\n (bool) - True it we had more bad channels than the threshold, False if we did not''' if max_bad_channels is None: max_bad_channels = MAX_BAD_CHANNELS # Check if there is more bad channels in the bads key compared to the allowed amount if len(raw_od.info.get('bads', [])) >= max_bad_channels: return True else: return False def remove_bad_epoch_pairings(raw_haemo_filtered_minus_short, good_epochs, epoch_pair_tolerance_window=None): '''Method to apply our new epochs to the loaded data in working memory. This is to ensure that the GLM does not see these epochs. Inputs:\n raw_haemo_filtered_minus_short (RawSNIRF) - The filtered haemoglobin concentration data\n good_epochs (Epochs) - The epochs we want the loaded data to take on\n epoch_pair_tolerance_window (int) - (optional) The amount of data points the paired epoch can deviate from the expected amount. Default is EPOCH_PAIR_TOLERANCE_WINDOW\n Output:\n raw_haemo_filtered_good_epochs (RawSNIRF) - The filtered haemoglobin concentration data with only the good epochs''' if epoch_pair_tolerance_window is None: epoch_pair_tolerance_window = EPOCH_PAIR_TOLERANCE_WINDOW # Copy the input haemoglobin concentration data and drop the bad channels raw_haemo_filtered_good_epochs = raw_haemo_filtered_minus_short.copy() raw_haemo_filtered_good_epochs = raw_haemo_filtered_good_epochs.drop_channels(raw_haemo_filtered_good_epochs.info['bads']) # Get the event IDs of the good events good_event_samples = set(good_epochs.events[:, 0]) logger.info(f"Total good events (epochs): {len(good_event_samples)}") # Get the current annotations raw_annots = raw_haemo_filtered_good_epochs.annotations # Create lists to use for processing clean_descriptions = [] clean_onsets = [] clean_durations = [] dropped = [] # Get the frequency of the file sfreq = raw_haemo_filtered_good_epochs.info['sfreq'] for desc, onset, dur in zip(raw_annots.description, raw_annots.onset, raw_annots.duration): # Convert annotation onset time to sample index sample = int(onset * sfreq) if FORCE_DROP_ANNOTATIONS: for i in FORCE_DROP_ANNOTATIONS: if desc == i: dropped.append((desc, onset)) continue # Check if the annotation is within the tolerance of any good event matched = any(abs(sample - event_sample) <= epoch_pair_tolerance_window for event_sample in good_event_samples) # We found a matching event if matched: clean_descriptions.append(desc) clean_onsets.append(onset) clean_durations.append(dur) else: dropped.append((desc, onset)) # Create the new filtered annotations new_annots = Annotations( onset=clean_onsets, duration=clean_durations, description=clean_descriptions, ) # Assign the new annotations raw_haemo_filtered_good_epochs.set_annotations(new_annots) # logger.info out the results logger.info(f"Original annotations: {len(raw_annots)}") logger.info(f"Kept annotations: {len(clean_descriptions)}") logger.info("Kept annotation types:", set(clean_descriptions)) if dropped: logger.info(f"Dropped annotations: {len(dropped)}") logger.info("Dropped annotations:") for desc, onset in dropped: logger.info(f" - {desc} at {onset:.2f}s") else: logger.info("No annotations were dropped!") return raw_haemo_filtered_good_epochs