diff --git a/changelog.md b/changelog.md index 528c6d1..8a7220e 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,7 @@ +# Version 1.1.0 + +- Changelog details coming soon + # Version 1.0.1 - Two new options have been added when clicking on a participant's file. diff --git a/fNIRS_module.py b/fNIRS_module.py deleted file mode 100644 index c3bcdf1..0000000 --- a/fNIRS_module.py +++ /dev/null @@ -1,4107 +0,0 @@ -""" -Filename: fNIRS_module.py -Description: Core functionality for FLARES - -Author: Tyler de Zeeuw -License: GPL-3.0 -""" - -# Built-in imports -import os -import sys -import time -import logging -import platform -import warnings -import threading -from io import BytesIO -from copy import deepcopy -from pathlib import Path -from zipfile import ZipFile -from datetime import datetime -from itertools import compress -from multiprocessing import Queue -from typing import Any, Optional, cast, Literal, Iterator, Union - -# External library imports -import h5py -import pywt # type: ignore -import qtpy # type: ignore -import xlrd # type: ignore -import psutil -import scooby # type: ignore -import requests -import pyvistaqt # type: ignore -import darkdetect # type: ignore -import numpy as np -import pandas as pd -from PIL import Image -import seaborn as sns -import neurokit2 as nk # type: ignore -from tqdm.auto import tqdm -from pandas import DataFrame -import matplotlib.pyplot as plt -from matplotlib.axes import Axes -from numpy.typing import NDArray -#import vtkmodules.util.data_model -from numpy import floating, float64 -from matplotlib.lines import Line2D -import matplotlib.colors as mcolors -from scipy.stats import ttest_1samp # type: ignore -from matplotlib.figure import Figure -import statsmodels.formula.api as smf # type: ignore -#import vtkmodules.util.execution_model -from nilearn.plotting import plot_design_matrix # type: ignore -from scipy.signal import welch, butter, filtfilt # type: ignore -from matplotlib.colors import LinearSegmentedColormap -from IPython.display import display, Markdown, clear_output # type: ignore -from statsmodels.tools.sm_exceptions import ConvergenceWarning # type: ignore -from concurrent.futures import ProcessPoolExecutor, as_completed -from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas - -# External library imports for mne -import mne -from mne import EvokedArray, Info, read_source_spaces, stc_near_sensors # type: ignore -from mne.source_space import SourceSpaces -from mne.transforms import Transform # type: ignore -from mne.io import BaseRaw, read_raw_snirf # type: ignore -from mne.annotations import Annotations # type: ignore -from mne_nirs.visualisation import plot_glm_group_topo # type: ignore -from mne_nirs.channels import get_long_channels, get_short_channels, picks_pair_to_idx # type: ignore -from mne_nirs.experimental_design import make_first_level_design_matrix # type: ignore -from mne_nirs.statistics import run_glm, statsmodels_to_results # type: ignore -from mne_nirs.signal_enhancement import enhance_negative_correlation, short_channel_regression # type: ignore -from mne.preprocessing.nirs import beer_lambert_law, optical_density, temporal_derivative_distribution_repair, source_detector_distances, short_channels # type: ignore -from mne_nirs.io.fold import fold_channel_specificity # type: ignore -from mne_nirs.preprocessing import peak_power # type: ignore -from mne.viz import Brain -from mne_nirs.statistics._glm_level_first import RegressionResults # type: ignore -from mne.filter import filter_data # type: 
ignore - -CURRENT_VERSION = "1.0.0" - -GUI = False -PLATFORM_NAME = platform.system().lower() - -BASE_SNIRF_FOLDER: str -SNIRF_SUBFOLDERS: list[str] -STIM_DURATION: list[float] -MAX_WORKERS: int - -SECONDS_TO_STRIP: int -DOWNSAMPLE: bool -DOWNSAMPLE_FREQUENCY: int -FORCE_DROP_CHANNELS: list[str] -SOURCE_DETECTOR_SEPARATOR: str - -OPTODE_FILE: bool -OPTODE_FILE_PATH: str -OPTODE_FILE_SEPARATOR: str - -TDDR: bool - -WAVELET: bool -IQR: float - -HEART_RATE: bool -SECONDS_TO_STRIP_HR: int -MAX_LOW_HR: int -MAX_HIGH_HR: int -SMOOTHING_WINDOW_HR: int -HEART_RATE_WINDOW: int -SHORT_CHANNEL: bool -SHORT_CHANNEL_THRESH: float - -SCI: bool -SCI_TIME_WINDOW: int -SCI_THRESHOLD: float -PSP: bool -PSP_TIME_WINDOW: int -PSP_THRESHOLD: float - -# TODO: Implement -SNR: bool -SNR_TIME_WINDOW : int -SNR_THRESHOLD: float - -EXCLUDE_CHANNELS: bool -MAX_BAD_CHANNELS: int -LONG_CHANNEL_THRESH: float - -METADATA: dict - -DRIFT_MODEL: str -DURATION_BETWEEN_ACTIVITIES: int -HRF_MODEL: str -SHORT_CHANNEL_REGRESSION: bool - -N_JOBS: int - -TARGET_ACTIVITY: str -TARGET_CONTROL: str -ROI_GROUP_1: list[list[int]] -ROI_GROUP_2: list[list[int]] -ROI_GROUP_1_NAME: str -ROI_GROUP_2_NAME: str -P_THRESHOLD: float - -SEE_BAD_IMAGES: bool -ABS_T_VALUE: int -ABS_THETA_VALUE: int -ABS_CONTRAST_T_VALUE: int -ABS_CONTRAST_THETA_VALUE: int -ABS_SIGNIFICANCE_T_VALUE: int -ABS_SIGNIFICANCE_THETA_VALUE: int -BRAIN_DISTANCE: float -BRAIN_MODE: str - -EPOCH_REJECT_CRITERIA_THRESH: float -TIME_MIN_THRESH: int -TIME_MAX_THRESH: int - -VERBOSITY: bool - -REJECT_PAIRS = None -FORCE_DROP_ANNOTATIONS = None -FILTER_LOW_PASS = None -FILTER_HIGH_PASS = None -EPOCH_PAIR_TOLERANCE_WINDOW = None - -# FIXME: Shouldn't need each ordering - just order it before checking -FIXED_CATEGORY_COLORS = { - "SCI only": "skyblue", - "PSP only": "salmon", - "SNR only": "lightgreen", - "PSP + SCI": "orange", - "SCI + SNR": "violet", - "PSP + SNR": "gold", - "SCI + PSP": "orange", - "SNR + SCI": "violet", - "SNR + PSP": "gold", - "PSP + SNR + SCI": "gray", - "SCI + PSP + SNR": "gray", - "SCI + SNR + PSP": "gray", - "PSP + SCI + SNR": "gray", - "PSP + SNR + SCI": "gray", - "SNR + SCI + PSP": "gray", - "SNR + PSP + SCI": "gray", -} - - -REQUIRED_KEYS: dict[str, Any] = { - "BASE_SNIRF_FOLDER": str, - "SNIRF_SUBFOLDERS": list, - "STIM_DURATION": list, - "MAX_WORKERS": int, - "SECONDS_TO_STRIP": int, - "DOWNSAMPLE": bool, - "DOWNSAMPLE_FREQUENCY": int, - "FORCE_DROP_CHANNELS": list, - "SOURCE_DETECTOR_SEPARATOR": str, - "OPTODE_FILE": bool, - "OPTODE_FILE_PATH": str, - "OPTODE_FILE_SEPARATOR": str, - "TDDR": bool, - "WAVELET": bool, - "IQR": float, - "HEART_RATE": bool, - "SECONDS_TO_STRIP_HR": int, - "MAX_LOW_HR": int, - "MAX_HIGH_HR": int, - "SMOOTHING_WINDOW_HR": int, - "HEART_RATE_WINDOW": int, - "SHORT_CHANNEL": bool, - "SHORT_CHANNEL_THRESH": float, - "SCI": bool, - "SCI_TIME_WINDOW": int, - "SCI_THRESHOLD": float, - "PSP": bool, - "PSP_TIME_WINDOW": int, - "PSP_THRESHOLD": float, - "SNR": bool, - "SNR_TIME_WINDOW": int, - "SNR_THRESHOLD": float, - "EXCLUDE_CHANNELS": bool, - "MAX_BAD_CHANNELS": int, - "LONG_CHANNEL_THRESH": float, - "METADATA": dict, - "DRIFT_MODEL": str, - "DURATION_BETWEEN_ACTIVITIES": int, - "HRF_MODEL": str, - "SHORT_CHANNEL_REGRESSION": bool, - "N_JOBS": int, - "TARGET_ACTIVITY": str, - "TARGET_CONTROL": str, - "ROI_GROUP_1": list, - "ROI_GROUP_2": list, - "ROI_GROUP_1_NAME": str, - "ROI_GROUP_2_NAME": str, - "P_THRESHOLD": float, - "SEE_BAD_IMAGES": bool, - "ABS_T_VALUE": int, - "ABS_THETA_VALUE": int, - "ABS_CONTRAST_T_VALUE": int, - 
"ABS_CONTRAST_THETA_VALUE": int, - "ABS_SIGNIFICANCE_T_VALUE": int, - "ABS_SIGNIFICANCE_THETA_VALUE": int, - "BRAIN_DISTANCE": float, - "BRAIN_MODE": str, - "EPOCH_REJECT_CRITERIA_THRESH": float, - "TIME_MIN_THRESH": int, - "TIME_MAX_THRESH": int, - "VERBOSITY": bool, - - # "REJECT_PAIRS": bool, - # "FORCE_DROP_ANNOTATIONS": list, - # "FILTER_LOW_PASS": float, - # "FILTER_HIGH_PASS": float, - # "EPOCH_PAIR_TOLERANCE_WINDOW": int, -} - - -# Ensure that we are working in the directory of this file -script_dir = os.path.dirname(os.path.abspath(__file__)) -os.chdir(script_dir) - - -# Configure logging to file with timestamps and realtime flush -if PLATFORM_NAME == 'darwin': - logging.basicConfig( - filename=os.path.join(os.path.dirname(sys.executable), "../../../fnirs_analysis.log"), - level=logging.INFO, - format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - filemode='a' - ) - -else: - logging.basicConfig( - filename='fnirs_analysis.log', - level=logging.INFO, - format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - filemode='a' - ) - -logger = logging.getLogger() - -class ProcessingError(Exception): - def __init__(self, message: str = "Something went wrong!"): - self.message = message - super().__init__(self.message) - - -def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -> None: - try: - print("setting config") - set_config(config, True) - - # Start a thread to forward progress messages back to GUI - def forward_progress(): - while True: - try: - msg = progress_queue.get(timeout=1) - if msg == "__done__": - break - gui_queue.put(msg) - except: - continue - - t = threading.Thread(target=forward_progress, daemon=True) - t.start() - - # Run the actual processing, with progress_queue passed down - print("actual call") - result = run_groups(config, True, progress_queue=progress_queue) - - # Signal end of progress - progress_queue.put("__done__") - t.join() - - gui_queue.put({"success": True, "result": result}) - - - except Exception as e: - import traceback - gui_queue.put({ - "success": False, - "error": str(e), - "traceback": traceback.format_exc() - }) - - -def set_config(config: dict[str, Any], gui: bool = False) -> None: - """ - Validates and applies the given configuration dictionary. - - Parameters - ---------- - config : dict[str, Any] - Dictionary containing configuration keys and their values. - """ - - if gui: - globals().update({"GUI": True}) - - # Ensure all keys are present - for key, expected_type in REQUIRED_KEYS.items(): - if key not in config: - raise KeyError(f"Missing config key: {key}") - - value = config[key] - if not isinstance(value, expected_type): - # Special handling for lists to check list contents - if expected_type == list and isinstance(value, list): - continue # optionally: validate inner types too - raise TypeError(f"Key '{key}' has incorrect type. Expected {expected_type.__name__}, got {type(value).__name__}") - - # Update the global variables to match the values in the config keys - globals().update(config) - - # Ensure that passed through variables are correct or that they actually exist - assert Path(BASE_SNIRF_FOLDER).is_dir(), "BASE_SNIRF_FOLDER was not found. Please check the folder location and try again." - - for folder in SNIRF_SUBFOLDERS: - assert Path(os.path.join(BASE_SNIRF_FOLDER, folder)).is_dir(), f"The subfolder {folder} could not be found. Please check the folder location and try again." 
-
-    assert len(SNIRF_SUBFOLDERS) == len(STIM_DURATION), f"The number of subfolders does not match the number of stim durations. Subfolders: {len(SNIRF_SUBFOLDERS)} Stim durations: {len(STIM_DURATION)}"
-
-    if OPTODE_FILE:
-        path = Path(OPTODE_FILE_PATH)
-        assert path.is_file(), "OPTODE_FILE was specified, but OPTODE_FILE_PATH is not a file."
-        assert path.suffix == ".txt", "OPTODE_FILE_PATH does not end with a .txt extension."
-
-    # Ensure that the BASE_SNIRF_FOLDER is an absolute path - helpful when logging paths later
-    if 'BASE_SNIRF_FOLDER' in globals():
-        abs_path = str(Path(BASE_SNIRF_FOLDER).resolve())
-        globals()['BASE_SNIRF_FOLDER'] = abs_path
-
-    # Suppress MNE's warnings
-    if not VERBOSITY:
-        warnings.filterwarnings("ignore", category=ConvergenceWarning)
-        warnings.filterwarnings("ignore", category=RuntimeWarning)
-
-    logger.info("[Config] Configuration successfully set.")
-
-
-
-def run_groups(config, gui: bool = False, progress_queue=None) -> tuple[dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], dict[str, dict[str, BaseRaw]], dict[str, list[Figure]], dict[str, str], float]:
-    """
-    Process multiple data folders and aggregate results, haemoglobin, figures, and processing details.
-
-    Returns
-    -------
-    tuple[dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], dict[str, dict[str, BaseRaw]], dict[str, list[Figure]], dict[str, str], float]
-        - dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]: Results dataframes grouped by folder.
-        - dict[str, dict[str, BaseRaw]]: Raw haemoglobin data indexed by file ID.
-        - dict[str, list[Figure]]: Figures generated during processing grouped by step.
-        - dict[str, str]: Processing status messages indexed by file ID.
-        - float: Elapsed time in seconds.
-    """
-
-    # Create dictionaries to store our results
-    all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]] = {}
-    all_figures: dict[str, list[Figure]] = {}
-    all_raw_haemo: dict[str, dict[str, BaseRaw]] = {}
-    all_processes: dict[str, str] = {}
-
-    # Variables to store the total number of files to be processed and the number remaining while the program is running
-    total_files = 0
-    files_remaining = {'count': 0}
-
-    start_time = time.time()
-
-    # Iterate over all the folders and determine how many files are in each folder
-    logger.info("Calculating how many files there are...")
-    for folder in SNIRF_SUBFOLDERS:
-        full_path = os.path.join(BASE_SNIRF_FOLDER, folder)
-        num_items = len([
-            f for f in os.listdir(full_path)
-            if os.path.isfile(os.path.join(full_path, f))
-        ])
-        total_files += num_items
-    logger.info(f"Total of {total_files} files.")
-
-    # Set the remaining count to be the total number of files
-    files_remaining['count'] = total_files
-
-    # Iterate over all the folders
-    for folder, stim_duration in zip(SNIRF_SUBFOLDERS, STIM_DURATION):
-        full_path = os.path.join(BASE_SNIRF_FOLDER, folder)
-        try:
-            # Process all participants in the folder
-            logger.info(f"Processing all files in {folder}...")
-            raw_haemo, df_roi, df_cha, df_con, df_design_matrix, figures, process = process_folder(full_path, stim_duration, files_remaining, config, gui, progress_queue=progress_queue)
-
-            # Store the results into the corresponding dictionaries
-            logger.info(f"Storing the results from the {folder} folder...")
-
-            # TODO: This looks yucky
-            try:
-                all_results[folder] = (df_roi, df_cha, df_con, df_design_matrix)
-                logger.info(f"Applied all results.")
-            except:
-                pass
-            try:
-                for step, fig_list in figures.items():
-                    all_figures.setdefault(step, []).extend(fig_list)
-                logger.info(f"Applied all
figures.") - except: - pass - try: - for file_id, raw in raw_haemo.items(): - all_raw_haemo[file_id] = raw - logger.info(f"Applied all haemo.") - except: - pass - try: - for file_id, p in process.items(): - all_processes[file_id] = p - logger.info(f"Applied all processes.") - except: - pass - - except ProcessingError as e: - logger.info(f"Something happened! {e}") - # Something really bad happened. No partial return - raise Exception(e) - - except Exception as e: - logger.info(f"Something happened! {e}") - # Still return a partial analysis even if something goes wrong - return all_results, all_raw_haemo, all_figures, all_processes, time.time() - start_time - - return all_results, all_raw_haemo, all_figures, all_processes, time.time() - start_time - - - -def create_image_montage(images: list[Image.Image], cols: int) -> Optional[Image.Image]: - """ - Creates a grid montage image from a list of PIL Images. - - Parameters - ---------- - images : list[Image.Image] - List of images to arrange in the montage. - cols : int - Number of columns in the montage grid. - - Returns - ------- - Optional[Image.Image] - The combined montage image, or None if the input list of images is empty. - """ - - # Verify that we have images to process - if not images: - return None - - # Calculate the width, height, and rows - logger.info("Calculating the montage parameters...") - widths, heights = zip(*(i.size for i in images)) - max_width = max(widths) - max_height = max(heights) - rows = (len(images) + cols - 1) // cols - - # Create the montage image - logger.info("Creating the montage...") - montage = Image.new('RGBA', (cols * max_width, rows * max_height), (255, 255, 255, 255)) - for idx, image in enumerate(images): - x = (idx % cols) * max_width - y = (idx // cols) * max_height - montage.paste(image, (x, y)) # type: ignore - - return montage - - - -def show_all_images(figures: dict[str, list[BytesIO]], inline: bool = False) -> None: - """ - Displays montages of figures either inline or in separate windows. - - Parameters - ---------- - figures : dict[str, list[Figure]] - Dictionary containing lists of figures categorized by type. - inline : bool, optional - If True, display images inline (e.g., in Jupyter notebooks). Otherwise, opens them in separate windows (default is False). - """ - - if inline: - logger.info("Inline was selected.") - else: - logger.info("Inline was not selected.") - - # If we have less than 4 figures, the columns should be the exact amount of images we have. If we have more, enforce 4 columns - logger.info("Calculating columns...") - if len(figures.get('Raw', [])) < 4: - cols = len(figures.get('Raw', [])) - else: - cols = 4 - - # Iterate over all of the types of figure, create a montage with figures of the same type, and display the resulting image - logger.info("Generating images...") - for _, fig_bytes_list in figures.items(): - pil_images = [] - for b in fig_bytes_list: - try: - img = Image.open(BytesIO(b)).convert("RGB") - pil_images.append(img) - except Exception as e: - logger.warning(f"Could not open image from bytes: {e}") - continue - - montage = create_image_montage(pil_images, cols) - - if montage: - # Determine how to display the images to the user - if inline: - display(montage) - else: - montage.show() - - - -def save_all_images(figures: dict[str, list[Figure]]) -> None: - """ - Saves montages of figures as timestamped PNG files in folder called 'images'. 
- - Parameters - ---------- - figures : dict[str, list[Figure]] - Dictionary containing lists of figures categorized by type. - """ - - # Get the current working directory and create a folder called images if it does not exist - logger.info("Getting the current directory...") - if PLATFORM_NAME == 'darwin': - images_folder = os.path.join(os.path.dirname(sys.executable), "../../../images") - else: - cwd = os.getcwd() - images_folder = os.path.join(cwd, "images") - - logger.info("Attempting to create the images folder...") - os.makedirs(images_folder, exist_ok=True) - - # Generate a timestamp to be appended to the end of the file name - logger.info("Generating the timestamp...") - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - - # If we have less than 4 figures, the columns should be the exact value. If we have more, enforce 4 columns - logger.info("Calculating columns...") - raw_fig_count = len(figures.get('Raw', [])) - if raw_fig_count < 4: - cols = raw_fig_count - else: - cols = 4 - - # Iterate over all of the types of figures, create a montage with figures of the same type, and save the resulting image - logger.info("Generating images...") - for step, fig_bytes_list in figures.items(): - pil_images = [] - for b in fig_bytes_list: - try: - img = Image.open(BytesIO(b)).convert("RGB") - pil_images.append(img) - except Exception as e: - logger.warning(f"Could not open image from bytes: {e}") - continue - - montage = create_image_montage(pil_images, cols) - if montage: - filename = f"{step}_{timestamp}.png" - save_path = os.path.join(images_folder, filename) - montage.save(save_path) # type: ignore - logger.info(f"Saved image to {save_path}") - - logger.info(f"All images have been saved to '{images_folder}'.") - - - -def load_snirf(file_path: str, ID: str, drop_prefixes: list[str]) -> tuple[BaseRaw, Figure]: - """ - Loads a snirf file, optionally drops channels, downsamples, and creates a figure showing the results. - - Parameters - ---------- - file_path : str - Path of the snirf file to load. - ID : str - File name of the the snirf file that was loaded. - drop_prefixes : list[str] - List of channel name prefixes to drop from the data. - - Returns - ------- - tuple[BaseRaw, Figure] - - BaseRaw: The processed data object. - - Figure: The corresponding Matplotlib figure. 
- """ - - logger.info(f"Loading the snirf file ({ID})...") - - # Read the snirf file - raw = read_raw_snirf(file_path, verbose=VERBOSITY) # type: ignore - raw.load_data(verbose=VERBOSITY) # type: ignore - - # Strip the specified amount of seconds from the start of the file - total_duration = getattr(raw, "times")[-1] - if total_duration > SECONDS_TO_STRIP: - raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore - logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.") - else: - logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.") - - # If the user forcibly dropped channels, remove them now before any processing occurs - logger.info("Checking if there are channels to forcibly drop...") - if drop_prefixes: - logger.info("Force dropped channels was specified.") - channels_to_drop = [ch for ch in cast(list[str], getattr(raw, "ch_names")) if any(ch.startswith(prefix) for prefix in drop_prefixes)] - raw.drop_channels(channels_to_drop, "raise") # type: ignore - logger.info("Force dropped channels:", channels_to_drop) - - # If the user wants to downsample, do it right away - logger.info("Checking if we should downsample...") - if DOWNSAMPLE: - logger.info("Downsample was specified.") - sfreq_old = getattr(raw, "info")["sfreq"] - raw.resample(DOWNSAMPLE_FREQUENCY, verbose=VERBOSITY) # type: ignore - sfreq_new = getattr(raw, "info")["sfreq"] - logger.info(f"Finished downsampling. Old frequency: {sfreq_old}. New frequency: {sfreq_new}.") - - # Create a figure for the results - logger.info("Creating the figure...") - fig = cast(Figure, raw.plot(show=False, n_channels=len(getattr(raw, "ch_names")), duration=raw.times[-1]).figure) # type: ignore - fig.suptitle(f"Raw fNIRS Data for {ID}", fontsize=16) # type: ignore - fig.subplots_adjust(top=0.92) - plt.close(fig) - - logger.info("Successfully loaded the snirf file.") - - return raw, fig - - - -def calculate_and_apply_updated_optode_coordinates(data: BaseRaw) -> BaseRaw: - """ - Update optode coordinates on the given MNE Raw data using a specified optode file. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process with new optode coordinates. - - Returns - ------- - BaseRaw - The processed data object with the updated montage applied. 
- """ - - logger.info("Updating optode coordinates...") - - fiducials: dict[str, NDArray[floating[Any]]] = {} - ch_positions: dict[str, NDArray[floating[Any]]] = {} - - # Read the lines from the optode file - logger.info(f"Reading optode file from {OPTODE_FILE_PATH}") - with open(OPTODE_FILE_PATH, 'r') as f: - for line in f: - if line.strip(): - # Split by the semicolon and convert to meters - ch_name, coords_str = line.split(OPTODE_FILE_SEPARATOR) - coords = np.array(list(map(float, coords_str.strip().split()))) * 0.001 - logger.info(f"Read line: {ch_name} with coords (m): {coords}") - - # The key we have is a fiducial - if ch_name.lower() in ['lpa', 'nz', 'rpa']: - fiducials[ch_name.lower()] = coords - - # The key we have is a source or detector - else: - ch_positions[ch_name.upper()] = coords - - - # Create montage with updated coords in head space - logger.info("Creating and applying the montage...") - initial_montage = mne.channels.make_dig_montage(ch_pos=ch_positions, nasion=fiducials.get('nz'), lpa=fiducials.get('lpa'), rpa=fiducials.get('rpa'), coord_frame='head') # type: ignore - data.set_montage(initial_montage, verbose=VERBOSITY) # type: ignore - logger.info("Successfully updated optode coordinates.") - - return data - - - -def calculate_and_apply_tddr(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]: - """ - Applies Temporal Derivative Distribution Repair (TDDR) to the raw data and creates a figure showing the results. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process. - ID : str - File name of the the snirf file that was loaded. - - Returns - ------- - tuple[BaseRaw, Figure] - - BaseRaw: The processed data object. - - Figure: The corresponding Matplotlib figure. - """ - - # Apply TDDR - logger.info("Applying temporal derivative distribution repair...") - raw_with_tddr = cast(BaseRaw, temporal_derivative_distribution_repair(data, verbose=VERBOSITY)) - - # Create a figure for the results - logger.info("Creating the figure...") - fig = cast(Figure, raw_with_tddr.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=data.times[-1]).figure) # type: ignore - fig.suptitle(f"TDDR for {ID}", fontsize=16) # type: ignore - fig.subplots_adjust(top=0.92) - plt.close(fig) - - logger.info("Successfully applied temporal derivative distribution repair.") - return raw_with_tddr, fig - - - -def iqr_threshold(coeffs: NDArray[float64], k: float = 1.5) -> floating[Any]: - - """ - Calculate the interquartile range (IQR) threshold scaled by a factor, k. - - Parameters - ---------- - coeffs : NDArray[float64] - Array of coefficients to compute the IQR from. - k : float, optional - Scaling factor for the IQR (default is 1.5). - - Returns - ------- - floating[Any] - The scaled IQR threshold value. - """ - - # Calculate the IQR - q1 = np.percentile(coeffs, 25) - q3 = np.percentile(coeffs, 75) - iqr = q3 - q1 - - return k * iqr - - - -def wavelet_iqr_denoise(signal: NDArray[float64], wavelet: str = 'db4', level: int = 3) -> NDArray[float64]: - """ - Denoises a signal using wavelet decomposition and IQR-based thresholding on detail coefficients. - - Parameters - ---------- - signal : NDArray[float64] - The input signal array to denoise. - wavelet : str, optional - The type of wavelet to use for decomposition (default is 'db4'). - level : int, optional - Decomposition level for wavelet transform (default is 3). - - Returns - ------- - NDArray[float64] - The denoised signal array, with the same length as the input. 
- """ - - # Decompose the signal using wavelet transform and initialize a list with approximation coefficients - coeffs: list[NDArray[float64]] = pywt.wavedec(signal, wavelet, level=level) # type: ignore - cA = coeffs[0] - denoised_coeffs = [cA] - - # Threshold detail coefficients to reduce noise - for cD in coeffs[1:]: - threshold = iqr_threshold(cD, IQR) - cD_thresh = np.sign(cD) * np.maximum(np.abs(cD) - threshold, 0.0) # np.where((cD < lower) | (cD > upper), 0, cD) - cD_thresh = cD_thresh.astype(float64) - denoised_coeffs.append(cD_thresh) - - # Reconstruct the denoised signal - denoised_signal = cast(NDArray[float64], pywt.waverec(denoised_coeffs, wavelet)) # type: ignore - return denoised_signal[:len(signal)] - - - -def calculate_and_apply_wavelet(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]: - """ - Applies a wavelet IQR denoising filter to the data and generates a plot. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process. - ID : str - File name of the the snirf file that was loaded. - - Returns - ------- - tuple[BaseRaw, Figure] - - BaseRaw: The processed data object. - - Figure: The corresponding Matplotlib figure. - """ - - logger.info("Applying the wavelet filter...") - - # Denoise the data - logger.info("Denoising the data...") - loaded_data: NDArray[float64] = data.get_data(verbose=VERBOSITY) # type: ignore - denoised_data = np.zeros_like(loaded_data) - - logger.info("Calculating the IQR, decomposing the signal, and thresholding the coefficients...") - for ch in range(loaded_data.shape[0]): - denoised_data[ch, :] = wavelet_iqr_denoise(loaded_data[ch, :], wavelet='db4', level=3) - - # Reconstruct the data with the annotations - logger.info("Reconstructing the data with annotations...") - raw_with_tddr_and_wavelet = mne.io.RawArray(denoised_data, cast(mne.Info, data.info), verbose=VERBOSITY) - raw_with_tddr_and_wavelet.set_annotations(data.annotations.copy(), verbose=VERBOSITY) # type: ignore - - # Create a figure for the results - logger.info("Creating the figure...") - fig = cast(Figure, raw_with_tddr_and_wavelet.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=data.times[-1]).figure) # type: ignore - fig.suptitle(f"Wavelet for {ID}", fontsize=16) # type: ignore - fig.subplots_adjust(top=0.92) - plt.close(fig) - - logger.info("Successfully applied the wavelet filter.") - - return raw_with_tddr_and_wavelet, fig - - - -def short_channel_processing_for_hr(data: BaseRaw, short_chans: BaseRaw | None) -> tuple[float, NDArray[float64], NDArray[float64]]: - """ - Extract and trim short-channel fNIRS signal for heart rate analysis. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process. - short_chans : BaseRaw | None - Data object with only short separation channels, or None if unavailable. - - Returns - ------- - tuple[float, NDArray[float64], NDArray[float64]] - - float: Sampling frequency of the signal. - - NDArray[float64]: Trimmed short-channel signal. - - NDArray[float64]: Corresponding time values. - """ - - # Find the short channel (or best candidate) and extract signal data and sampling frequency - logger.info("Extracting the signal and calculating the sampling frequency...") - - # If a short channel exists, use it for our signal. 
Otherwise just take the first channel in the data - # TODO: Find a better way around this - if short_chans is not None: - signal = cast(NDArray[float64], short_chans.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore - else: - signal = cast(NDArray[float64], data.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore - - # Calculate the sampling frequency - sfreq = cast(int, data.info['sfreq']) - - # Trim start and end of the signal to remove edge artifacts - logger.info(f"Removing {SECONDS_TO_STRIP_HR} seconds from the beginning and end of the file...") - strip_samples = int(sfreq * SECONDS_TO_STRIP_HR) - signal_trimmed = signal[strip_samples:-strip_samples] - times_trimmed = cast(NDArray[float64], getattr(data, "times"))[strip_samples:-strip_samples] - - return sfreq, signal_trimmed, times_trimmed - - - -def calculate_heart_rate_neurokit(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[float64], float]: - """ - Calculate and smooth heart rate from a trimmed signal using NeuroKit. - - Parameters - ---------- - sfreq : float - Sampling frequency of the signal. - signal_trimmed : NDArray[float64] - Preprocessed and trimmed fNIRS signal. - - Returns - ------- - tuple[NDArray[float64], float] - - NDArray[float64]: Smoothed heart rate time series (BPM). - - float: Mean heart rate. - """ - - logger.info("Calculating heart rate using NeuroKit...") - - # Filter signal to isolate heart rate frequencies and detect peaks - logger.info("Filtering the signal and detecting peaks...") - signal_filtered = cast(NDArray[float64], nk.signal_filter(signal_trimmed, sampling_rate=sfreq, lowcut=0.8, highcut=2.5)) # type: ignore - peaks_dict = cast(dict[str, Any], nk.signal_findpeaks(signal_filtered)) # type: ignore - peaks = peaks_dict['Peaks'] - hr = cast(NDArray[float64], nk.signal_rate(peaks, sampling_rate=sfreq, desired_length=len(signal_trimmed))) # type: ignore - hr_clean = np.clip(hr, MAX_LOW_HR, MAX_HIGH_HR) - - # Smooth heart rate time series by replacing spikes with local rolling mean and calculate the mean - logger.info("Smoothing the signal and calculating the mean...") - hr_series = pd.Series(hr_clean) - local_median = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).median() - spikes = hr_series > (local_median + 10) - smoothed_values = hr_series.copy() - smoothed_spikes = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).mean() - smoothed_values[spikes] = smoothed_spikes[spikes] - hr_smooth_nk = cast(NDArray[float64], smoothed_values.to_numpy()) # type: ignore - mean_hr_nk = hr_smooth_nk.mean() - - logger.info("Original HR min/max: %f, %f", hr_clean.min(), hr_clean.max()) - logger.info("Smoothed HR min/max:%f, %f", hr_smooth_nk.min(), hr_smooth_nk.max()) - logger.info(f"Estimated mean HR nk: {mean_hr_nk:.1f} BPM") - - logger.info("Successfully calculated heart rate using NeuroKit.") - - return hr_smooth_nk, mean_hr_nk - - - -def calculate_heart_rate_scipy(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]: - """ - Estimate heart rate using spectral analysis on a high-pass filtered signal. - - Parameters - ---------- - sfreq : float - Sampling frequency of the input signal. - signal_trimmed : NDArray[float64] - Trimmed fNIRS signal to analyze. - - Returns - ------- - tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float] - - NDArray[floating[Any]]: Frequencies converted to beats per minute (BPM). 
- - NDArray[float64]: Power spectral density (PSD) of the signal. - - np.ndarray[Any, np.dtype[np.bool_]]: Boolean mask indicating frequencies within heart rate range (30-300 BPM). - - float: Estimated mean heart rate in BPM corresponding to the PSD peak within the range. - """ - - logger.info("Calculating heart rate using SciPy...") - - # Apply a high-pass Butterworth filter to remove slow trends below 0.5 Hz from the trimmed signal (actual data) - logger.info("Applying a butterworth filter...") - b, a = cast(tuple[NDArray[float64], NDArray[float64]], butter(2, 0.5 / (sfreq / 2), btype='high')) - signal_hp = cast(NDArray[float64],filtfilt(b, a, signal_trimmed)) - - # Calculate the Power Spectral Density (PSD) of the filtered signal using Welch's method - logger.info("Calculating the PSD...") - nperseg = min(len(signal_hp), 4096) - frequencies_scipy, psd_scipy = cast(tuple[NDArray[float64], NDArray[float64]], welch(signal_hp, fs=sfreq, nperseg=nperseg, noverlap=nperseg//2)) - - # Convert frequency values to beats per minute (BPM) and set a heart rate range (30-300 BPM) - logger.info("Converting to BPM...") - freq_bpm_scipy = frequencies_scipy * 60 - freq_range_scipy = (freq_bpm_scipy > 30) & (freq_bpm_scipy < 300) - - # Identify the peak frequency within the heart rate range and estimate the mean heart rate in BPM - logger.info("Finding the mean...") - peak_index = np.argmax(psd_scipy[freq_range_scipy]) - mean_hr_scipy = freq_bpm_scipy[freq_range_scipy][peak_index] - - logger.info("Successfully calculated heart rate using SciPy.") - - return freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy - - -def plot_heart_rate( - freq_bpm_scipy: NDArray[floating[Any]], - psd_scipy: NDArray[float64], - freq_range_scipy: np.ndarray[Any, np.dtype[np.bool_]], - mean_hr_scipy: float, - hr_smooth_nk: NDArray[floating[Any]], - mean_hr_nk: float, - times_trimmed: NDArray[floating[Any]], - overruled: bool -) -> tuple[Figure, Figure]: - """ - Generate plots comparing heart rate estimates from SciPy PSD and NeuroKit2. - - Parameters - ---------- - freq_bpm_scipy : NDArray[floating[Any]] - Frequencies in beats per minute from SciPy PSD analysis. - psd_scipy : NDArray[float64] - Power spectral density values corresponding to freq_bpm_scipy. - freq_range_scipy : np.ndarray[Any, np.dtype[np.bool_]] - Boolean mask indicating the heart rate frequency range used in PSD. - mean_hr_scipy : float - Mean heart rate estimated from SciPy PSD peak. - hr_smooth_nk : NDArray[floating[Any]] - Smoothed instantaneous heart rate from NeuroKit2. - mean_hr_nk : float - Mean heart rate estimated from NeuroKit2 data. - times_trimmed : NDArray[floating[Any]] - Time points corresponding to hr_smooth_nk values. - overruled: bool - True if the heart rate from NeuroKit2 is overriding the results from the PSD. - - Returns - ------- - tuple[Figure, Figure] - - Figure showing the PSD and SciPy heart rate estimate. - - Figure showing the time series comparison of heart rates. - """ - - # Create the first plot for the PSD. Add a yellow range to show what we will be filtering to. 
- logger.info("Creating the figure...") - fig1, ax1 = plt.subplots(figsize=(10, 5)) # type: ignore - ax1.set_xlim(30, 300) - ax1.plot(freq_bpm_scipy[freq_range_scipy], psd_scipy[freq_range_scipy]) # type: ignore - ax1.axvline(x=mean_hr_scipy, color='red', linestyle='--', label=f'Mean HR: {mean_hr_scipy:.1f} BPM') # type: ignore - ax1.axvspan(min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW), max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW), color='yellow', alpha=0.3, label=f'HR Range ±{HEART_RATE_WINDOW} BPM') # type: ignore - ax1.set_xlabel('Heart Rate (BPM)') # type: ignore - ax1.set_ylabel('Power Spectral Density') # type: ignore - ax1.set_title('PSD of fNIRS signal - Peak indicates Heart Rate') # type: ignore - ax1.grid(True) # type: ignore - - # Was the value we reported here correct for the data on the graph or was it overruled? - if overruled: - note = ( - '\n' - 'Note: Calculation was bad!\n' - 'Data has been set to match\n' - 'the value from NeuroKit2.' - ) - phantom = Line2D([0], [0], color='none', label=note) - handles, _ = ax1.get_legend_handles_labels() - ax1.legend(handles=handles + [phantom]) # type: ignore - - else: - ax1.legend() # type: ignore - plt.close(fig1) - - # Create the second plot showing the rolling heart rate, as well as the two averages that were calculated - logger.info("Creating the figure...") - fig2, ax2 = plt.subplots(figsize=(14, 6)) # type: ignore - ax2.plot(times_trimmed, hr_smooth_nk, label='Instantaneous HR (NeuroKit2)', color='blue', alpha=0.7) # type: ignore - ax2.axhline(mean_hr_nk, color='red', linestyle='--', label=f'Mean HR NeuroKit2: {mean_hr_nk:.1f} BPM') # type: ignore - ax2.axhline(mean_hr_scipy, color='orange', linestyle=':', label=f'SciPy Welch PSD (HP filtered): {mean_hr_scipy:.1f} BPM') # type: ignore - ax2.set_xlabel('Time (seconds)') # type: ignore - ax2.set_ylabel('Heart Rate (BPM)') # type: ignore - ax2.set_title('Heart Rate Estimates Comparison') # type: ignore - ax2.legend() # type: ignore - ax2.grid(True) # type: ignore - fig2.tight_layout() - plt.close(fig2) - - return fig1, fig2 - - - -def plot_timechannel_quality_metrics(data: BaseRaw, scores: NDArray[float64], times: list[tuple[float]], color_stops: tuple[list[float], list[float]], threshold: float, title: Optional[str] = None) -> tuple[Figure, Figure]: - - """ - Generate two heatmaps visualizing channel quality metrics over time. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process. - scores : NDArray[float64] - A 2D array of quality scores for each channel over time. - times : list[tuple[float]] - List of time boundaries used to label each score column. - color_stops : tuple[list[float], list[float]] - Two lists of color values for custom colormaps. - threshold : float, - Threshold value for the color bar. - title : Optional[str], optional - Base title for the figures, (default is None). - - Returns - ------- - tuple[Figure, Figure] - - Figure: Heatmap of all scores across channels and time. - - Figure: Binary heatmap showing only scores above the threshold. 
- """ - - # Get only the hbo / hbr channels once as we dont need to see the same results twice - half_ch = len(getattr(data, "ch_names")) // 2 - ch_names = getattr(data, "ch_names")[:half_ch] - scores = scores[:half_ch, :] - - # Extract rounded time points to use as column headers - cols = [np.round(t[0]) for t in times] - n_chans = len(ch_names) - vsize = 0.2 * n_chans - - # Create the first figure - fig1, ax1 = plt.subplots(figsize=(10, vsize), layout="constrained") # type: ignore - fig1.suptitle(title + " - All Scores", fontsize=16, fontweight="bold") # type: ignore - - # Create a DataFrame to structure data for the heatmap - data_to_plot = DataFrame( - data=scores, - columns=pd.Index(cols, name="Time (s)"), - index=pd.Index(ch_names, name="Channel"), - ) - - # Define a custom colormap using provided color stops and base colors - base_colors = ['red', 'red', 'yellow', 'green', 'green'] - colors = list(zip(color_stops[0], base_colors[:len(color_stops[0])])) - cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors) - - # Plot heatmap of scores - sns.heatmap( # type: ignore - data=data_to_plot, - cmap=cmap, - vmin=0, - vmax=1, - cbar_kws=dict(label="Score"), - ax=ax1, - ) - - # Add vertical dashed lines at each time boundary, sit the title, and place a black strikethrough through a bad channel - for x in range(1, len(times)): - ax1.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") # type: ignore - ax1.set_title("All Scores", fontweight="bold") # type: ignore - markbad(data, ax1, ch_names) - - # Calculate average score per channel and annotate to the right of the heatmap - avg_sci_subset: pd.Series[float] = data_to_plot.mean(axis=1) # type: ignore - norm = mcolors.Normalize(vmin=0, vmax=1) - text_x = data_to_plot.shape[1] + 0.5 - for i, val in enumerate(avg_sci_subset): - color = cmap(norm(val)) - ax1.text( # type: ignore - text_x, - i + 0.5, - f"{val:.3f}", - va='center', - ha='left', - fontsize=9, - color=color - ) - ax1.set_xlim(right=text_x + 1.5) - - plt.close(fig1) - - # Create the second figure - fig2, ax2 = plt.subplots(figsize=(10, vsize), layout="constrained") # type: ignore - fig2.suptitle(title + " - Scores Above Threshold", fontsize=16, fontweight="bold") # type: ignore - - # Create a DataFrame to structure data for the heatmap - data_to_plot = DataFrame( - data=scores > threshold, - columns=pd.Index(cols, name="Time (s)"), - index=pd.Index(ch_names, name="Channel"), - ) - - # Define a custom colormap using provided color stops and base colors - base_colors = ['red', 'red', 'white', 'white'] - colors = list(zip(color_stops[1], base_colors[:len(color_stops[1])])) - cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors) - - # Plot heatmap of scores - sns.heatmap( # type: ignore - data=data_to_plot, - vmin=0, - vmax=1, - cmap=cmap, - cbar_kws=dict(label="Score"), - ax=ax2, - ) - - # Add vertical dashed lines at each time boundary, sit the title, and place a black strikethrough through a bad channel - for x in range(1, len(times)): - ax2.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") # type: ignore - ax2.set_title("Scores > Threshold", fontweight="bold") # type: ignore - markbad(data, ax2, ch_names) - - plt.close(fig2) - - return fig1, fig2 - - - -def markbad(data: BaseRaw, ax: Axes, ch_names: list[str]) -> None: - """ - Add a strikethrough to a plot for channels marked as bad. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process. 
- ax : Axes - Matplotlib Axes object where the strikethrough lines will be drawn. - ch_names : list[str] - List of channel names corresponding to the y-axis of the plot. - """ - - # Iterate over all the channels - for i, ch in enumerate(ch_names): - - # If it is marked as bad, place a strikethrough on the channel - if ch in data.info["bads"]: - ax.axhline(i + 0.5, ls="solid", lw=4, color="black", zorder=10) # type: ignore - - - -def calculate_scalp_coupling(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]: - """ - Calculate the scalp coupling index (SCI) and identify bad channels based on a threshold. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process. - l_freq : float, optional - Low cutoff frequency for bandpass filtering in Hz (default is 0.7). - h_freq : float, optional - High cutoff frequency for bandpass filtering in Hz (default is 1.5) - - Returns - ------- - tuple[list[str], Figure, Figure] - - list[str]: Channel names identified as bad based on SCI threshold. - - Figure: Heatmap of all SCI scores across time and channels. - - Figure: Binary heatmap of SCI scores exceeding the threshold. - """ - - logger.info("Calculating scalp coupling index...") - - # Compute the SCI - _, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], scalp_coupling_index_windowed_raw(data, time_window=SCI_TIME_WINDOW, l_freq=l_freq, h_freq=h_freq)) - - # Identify channels that don't meet the provided threshold - logger.info("Identifying channels that do not meet the threshold...") - sci = scores.mean(axis=1) - data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD)) - - # Determine the colors based on the threshold, and create the figures - logger.info("Creating the figures...") - color_stops = ([0.0, SCI_THRESHOLD, SCI_THRESHOLD+0.1, 0.8, 1.0], [0.0, SCI_THRESHOLD, SCI_THRESHOLD, 1.0]) - fig1, fig2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, SCI_THRESHOLD, "Scalp Coupling Index") - - logger.info("Successfully calculated scalp coupling index.") - - return list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD)), fig1, fig2 - - - -def scalp_coupling_index_windowed_raw(data: BaseRaw, time_window: float = 3.0, l_freq: float = 0.7, h_freq: float = 1.5, l_trans_bandwidth: float = 0.3, h_trans_bandwidth: float = 0.3) -> tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]]: - """ - Compute windowed scalp coupling index (SCI) across fNIRS channels. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process. - time_window : float, optional - Length of each time window in seconds (default is 3.0). - l_freq : float, optional - Low cutoff frequency for filtering in Hz (default is 0.7). - h_freq : float, optional - High cutoff frequency for filtering in Hz (default is 1.5). - l_trans_bandwidth : float, optional - Transition bandwidth for the low cutoff in Hz (default is 0.3). - h_trans_bandwidth : float, optional - Transition bandwidth for the high cutoff in Hz (default is 0.3). - - Returns - ------- - tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]] - - BaseRaw: The original data object (unchanged). Ensures compatibility with peak_power(). - - NDArray[float64]: Correlation scores for each channel and time window. - - list[tuple[float, float]]: Time intervals for each window in seconds. 
- """ - - # Pick only fNIRS channels and sort them by channel name - picks: NDArray[np.intp] = mne.pick_types(cast(mne.Info, data.info), fnirs=True) # type: ignore - picks = picks[np.argsort([getattr(data, "ch_names")[pick] for pick in picks])] - - # FIXME: This may happen if the heart rate calculation tries to set a value way too low - if l_freq < 0.3: - l_freq = 0.3 - - # Band-pass filter the selected fNIRS channels - filtered_data = cast(NDArray[float64], filter_data( - getattr(data, "_data"), - getattr(data, "info")["sfreq"], - l_freq, - h_freq, - picks=picks, - verbose=False, - l_trans_bandwidth=l_trans_bandwidth, # type: ignore - h_trans_bandwidth=h_trans_bandwidth, # type: ignore - )) - - # Calculate number of samples per time window, the total number of windows, and prepare output variables - window_samples = int(np.ceil(time_window * getattr(data, "info")["sfreq"])) - n_windows = int(np.floor(len(data) / window_samples)) - scores = np.zeros((len(picks), n_windows)) - times: list[tuple[float, float]] = [] - - # Slide through the data in windows to compute scalp coupling index (SCI) - for window in range(n_windows): - start_sample = int(window * window_samples) - end_sample = start_sample + window_samples - end_sample = np.min([end_sample, len(data) - 1]) - - # Track time boundaries for each window - t_start = getattr(data, "times")[start_sample] - t_stop = getattr(data, "times")[end_sample] - times.append((t_start, t_stop)) - - # Iterate through channels in pairs (hbo, hbr). This requires them to be sorted by channel name - for ii in range(0, len(picks), 2): - c1: NDArray[float64] = filtered_data[picks[ii]][start_sample:end_sample] - c2 = filtered_data[picks[ii + 1]][start_sample:end_sample] - - # Ensure the correlation data is valid - if np.std(c1) == 0 or np.std(c2) == 0 or np.any(np.isnan(c1)) or np.any(np.isnan(c2)): - c = 0 - else: - c = np.corrcoef(c1, c2)[0][1] - - # Assign the computed correlation to both channels in the pair - scores[ii, window] = c - scores[ii + 1, window] = c - - scores = scores[np.argsort(picks)] - - return data, scores, times - - - -def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]: - """ - Calculate peak spectral power (PSP) for fNIRS channels and identify bad channels. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process. - l_freq : float, optional - Low cutoff frequency for filtering in Hz (default is 0.7) - h_freq : float, optional - High cutoff frequency for filtering in Hz (default is 1.5) - - Returns - ------- - tuple[list[str], Figure, Figure] - - list[str]: Names of channels below the PSP threshold. - - Figure: Heatmap of all PSP scores. - - Figure: Heatmap of scores above the PSP threshold. 
- """ - - logger.info("Calculating peak spectral power...") - - # Compute the PSP - _, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], peak_power(data, time_window=PSP_TIME_WINDOW, threshold=PSP_THRESHOLD, l_freq=l_freq, h_freq=h_freq, verbose=False)) - - # Identify channels that don't meet the provided threshold - logger.info("Identifying channels that do not meet the threshold...") - psp = scores.mean(axis=1) - data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)) - - # Determine the colors based on the threshold, and create the figures - logger.info("Creating the figures...") - color_stops = ([0.0, PSP_THRESHOLD, PSP_THRESHOLD+0.1, 0.3, 1.0], [0.0, PSP_THRESHOLD, PSP_THRESHOLD, 1.0]) - psp1, psp2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, PSP_THRESHOLD, "Peak Spectral Power") - - logger.info("Successfully calculated peak spectral power.") - - return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2 - - - -def calculate_signal_noise_ratio(data: BaseRaw) -> tuple[list[str], Figure]: - """ - Calculates the signal-to-noise ratio (SNR) for each channel and identifies those below a defined threshold. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process. - - Returns - ------- - tuple[list[str], Figure] - - list[str]: A list of channel names that fall below the SNR threshold and are considered bad. - - Figure: A matplotlib Figure showing the channels' SNR values. - """ - - logger.info("Calculating signal to noise ratio...") - - # Compute the signal-to-noise ratio values - logger.info("Computing the signal to noise power...") - signal_band=(0.01, 0.5) - noise_band=(1.0, 10.0) - data_signal = data.copy().filter(*signal_band, verbose=False) #type: ignore - data_noise = data.copy().filter(*noise_band, verbose=False) #type: ignore - signal_power = np.mean(data_signal.get_data()**2, axis=1) #type: ignore - noise_power = np.mean(data_noise.get_data()**2, axis=1) #type: ignore - - # Calculate the snr using the standard formula for dB - snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps)) - - # TODO: Understand what this does - groups: dict[str, list[str]] = {} - for ch in getattr(data, "ch_names"): - # Look for the space in the channel names and remove the characters after - # This is so we can get both oxy and deoxy to remove, as they will have the same source and detector - base = ch.rsplit(' ', 1)[0] - groups.setdefault(base, []).append(ch) # type: ignore - - # If any of the channels do not meet our threshold, they will get inserted into the bad_channels set - bad_channels: set[str] = set() - for base, ch_list in groups.items(): - if any(s < SNR_THRESHOLD for s, ch in zip(snr, getattr(data, "ch_names")) if ch in ch_list): - bad_channels.update(ch_list) - - # Design and create the figure - logger.info("Creating the figure...") - snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained") # type: ignore - colors = [(0/20, 'red'), (SNR_THRESHOLD/20, 'red'), ((SNR_THRESHOLD+.5)/20, 'yellow'), ((SNR_THRESHOLD+1)/20, 'green'), (20/20, 'green')] - cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors) - norm = mcolors.Normalize(vmin=0, vmax=20) - scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm) # type: ignore - ax.set(xlabel="Channel Number", ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, 20]) - ax.axhline(SNR_THRESHOLD, color='black', 
linestyle='--', alpha=0.3, linewidth=1) # type: ignore
-    cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)") # type: ignore
-    cbar.set_ticks([0, SNR_THRESHOLD, SNR_THRESHOLD+1, 20]) # type: ignore
-    cbar.set_ticklabels(['0', str(SNR_THRESHOLD), str(SNR_THRESHOLD+1), '20']) # type: ignore
-
-    plt.close()
-
-    logger.info("Successfully calculated signal to noise ratio.")
-
-    return list(bad_channels), snr_fig
-
-
-
-def mark_bad_channels(data: BaseRaw, ID: str, bad_channels_sci: set[str], bad_channels_psp: set[str], bad_channels_snr: set[str]) -> tuple[BaseRaw, Figure, int]:
-    """
-    Drops bad channels from the data and generates a bar plot showing which channels were removed and why.
-
-    Parameters
-    ----------
-    data : BaseRaw
-        The loaded data object to process.
-    ID : str
-        File name of the snirf file that was loaded.
-    bad_channels_sci : set[str]
-        Channels marked as bad by the SCI method.
-    bad_channels_psp : set[str]
-        Channels marked as bad by the PSP method.
-    bad_channels_snr : set[str]
-        Channels marked as bad by the SNR method.
-
-    Returns
-    -------
-    tuple[BaseRaw, Figure, int]
-        - BaseRaw: The modified data object with bad channels removed.
-        - Figure: A matplotlib Figure showing the dropped channels categorized by method.
-        - int: The number of channels that were dropped.
-    """
-
-    logger.info("Dropping the channels that were marked bad...")
-
-    # Combine all of the bad channels into one and ensure the short channel is not present
-    bad_channels = bad_channels_sci | bad_channels_psp | bad_channels_snr
-
-    logger.info(f"Channels that were bad on SCI: {bad_channels_sci}")
-    logger.info(f"Channels that were bad on PSP: {bad_channels_psp}")
-    logger.info(f"Channels that were bad on SNR: {bad_channels_snr}")
-    logger.info(f"Total bad channels: {bad_channels}")
-
-    # Add the channels to the bads key and drop the bads key from the data
-    data.info["bads"] = list(bad_channels)
-    data = cast(BaseRaw, data.drop_channels(getattr(data, "info")["bads"])) # type: ignore
-
-    # Organize channels into categories
-    sets = [
-        (bad_channels_sci, "SCI"),
-        (bad_channels_psp, "PSP"),
-        (bad_channels_snr, "SNR"),
-    ]
-
-    # Graph what channels were dropped and why they were dropped
-    channel_categories: dict[str, str] = {}
-
-    for ch in bad_channels:
-        present_in = [name for s, name in sets if ch in s]
-        # Create a label for the category
-        if len(present_in) == 1:
-            label = f"{present_in[0]} only"
-        else:
-            label = " + ".join(sorted(present_in))
-        channel_categories[ch] = label
-
-    # Sort channels alphabetically within categories for nicer visualization
-    logger.info("Sorting the bad channels by type...")
-    categories = sorted(set(channel_categories.values()))
-    channel_names: list[str] = []
-    category_labels: list[str] = []
-    for cat in categories:
-        chs_in_cat = sorted([ch for ch, c in channel_categories.items() if c == cat])
-        channel_names.extend(chs_in_cat)
-        category_labels.extend([cat] * len(chs_in_cat))
-
-    colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories}
-
-    logger.info("Creating the figure...")
-    # Create the figure
-    fig, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3))) # type: ignore
-    y_pos = range(len(channel_names))
-    ax.barh(y_pos, [1]*len(channel_names), color=[colors[cat] for cat in category_labels]) # type: ignore
-    ax.set_yticks(y_pos) # type: ignore
-    ax.set_yticklabels(channel_names) # type: ignore
-    ax.set_xlabel("Marked as Bad") # type: ignore
-    ax.set_title(f"Bad Channels by Method for {ID}") # type: ignore
-    ax.set_xlim(0, 1)
-    ax.set_xticks([]) # type: ignore
-    ax.grid(axis='x', linestyle='--', alpha=0.7) # type: ignore
-
-    # Add a legend denoting why the channels were bad
-    for label, color in colors.items():
-        ax.bar(0, 0, color=color, label=label) # type: ignore
-    ax.legend() # type: ignore
-
-    fig.tight_layout()
-    plt.close(fig)
-
-    logger.info("Successfully dropped the channels that were marked bad.")
-
-    return data, fig, len(bad_channels)
-
-
-
-def calculate_optical_density(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]:
-    """
-    Converts raw intensity data to optical density and generates a plot of the transformed signals.
-
-    Parameters
-    ----------
-    data : BaseRaw
-        The loaded data object to process.
-    ID : str
-        File name of the snirf file that was loaded.
-
-    Returns
-    -------
-    tuple[BaseRaw, Figure]
-        - BaseRaw: The transformed data in optical density format.
-        - Figure: A matplotlib figure displaying the optical density signals across all channels.
-    """
-
-    logger.info("Calculating optical density...")
-
-    # Calculate the optical density from the raw data
-    optical_density_data = cast(BaseRaw, optical_density(data))
-
-    logger.info("Creating the figure...")
-    fig = cast(Figure, optical_density_data.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=getattr(data, "times")[-1]).figure) # type: ignore
-    fig.suptitle(f"Optical density data for {ID}", fontsize=16) # type: ignore
-    fig.subplots_adjust(top=0.92)
-    plt.close(fig)
-
-    logger.info("Successfully calculated optical density.")
-
-    return optical_density_data, fig
-
-
-
-# STEP 9: Haemoglobin concentration
-def calculate_haemoglobin_concentration(optical_density_data: BaseRaw, ID: str, file_path: str) -> tuple[BaseRaw, Figure]:
-    """
-    Calculates haemoglobin concentration from optical density data using the Beer-Lambert law and generates a plot.
-
-    Parameters
-    ----------
-    optical_density_data : BaseRaw
-        The data in optical density format.
-    ID : str
-        File name of the snirf file that was loaded.
-    file_path : str
-        Entire file path of the snirf file that was loaded.
-
-    Returns
-    -------
-    tuple[BaseRaw, Figure]
-        - BaseRaw: The haemoglobin concentration data object.
-        - Figure: A matplotlib figure displaying the haemoglobin concentration signals.
- """
-
- logger.info("Calculating haemoglobin concentration data...")
-
- # Get the haemoglobin concentration using the Beer-Lambert law
- haemoglobin_concentration_data = beer_lambert_law(optical_density_data, ppf=calculate_dpf(file_path))
-
- logger.info("Creating the figure...")
- # Plot the converted haemoglobin data (not the optical density input) so the figure matches its title
- fig = cast(Figure, haemoglobin_concentration_data.plot(show=False, n_channels=len(getattr(haemoglobin_concentration_data, "ch_names")), duration=getattr(haemoglobin_concentration_data, "times")[-1]).figure) # type: ignore
- fig.suptitle(f"Haemoglobin concentration data for {ID}", fontsize=16) # type: ignore
- fig.subplots_adjust(top=0.92)
- plt.close(fig)
-
- logger.info("Successfully calculated haemoglobin concentration data.")
-
- return haemoglobin_concentration_data, fig
-
-
-
- # -------------------------------------- HARDCODED -----------------------------------------------
-
- def extract_normal_epochs(haemoglobin_concentration_data: BaseRaw) -> dict[str, list[Any] | mne.evoked.EvokedArray]:
-
- events, _ = mne.events_from_annotations(haemoglobin_concentration_data, event_id={"Reach": 1, "Start of Rest": 2}, verbose=VERBOSITY) # type: ignore
- event_dict = {"Reach": 1, "Start of Rest": 2}
-
- epochs = mne.Epochs(
- haemoglobin_concentration_data,
- events,
- event_id=event_dict,
- tmin=TIME_MIN_THRESH,
- tmax=TIME_MAX_THRESH,
- reject=dict(hbo=EPOCH_REJECT_CRITERIA_THRESH),
- reject_by_annotation=True,
- proj=True,
- baseline=(None, 0),
- preload=True,
- detrend=None,
- verbose=VERBOSITY,
- )
-
- evoked_dict: dict[str, list[Any] | mne.evoked.EvokedArray] = {
- "Reach/HbO": epochs["Reach"].average(picks="hbo"), # type: ignore
- "Reach/HbR": epochs["Reach"].average(picks="hbr"), # type: ignore
- }
-
- # Rename channels until the encoding of frequency in ch_name is fixed
- for condition in evoked_dict:
- evoked_dict[condition].rename_channels(lambda x: x[:-4]) # type: ignore
-
- return evoked_dict
-
-
-
- def calculate_and_apply_negative_correlation_enhancement(haemoglobin_concentration_data: BaseRaw) -> dict[str, list[Any] | mne.evoked.EvokedArray]:
-
- events, _ = mne.events_from_annotations(haemoglobin_concentration_data, event_id={"Reach": 1, "Start of Rest": 2}, verbose=VERBOSITY) # type: ignore
- event_dict = {"Reach": 1, "Start of Rest": 2}
-
- raw_anti = enhance_negative_correlation(haemoglobin_concentration_data)
-
- epochs_anti = mne.Epochs(
- raw_anti,
- events,
- event_id=event_dict,
- tmin=TIME_MIN_THRESH,
- tmax=TIME_MAX_THRESH,
- reject=dict(hbo=EPOCH_REJECT_CRITERIA_THRESH),
- reject_by_annotation=True,
- proj=True,
- baseline=(None, 0),
- preload=True,
- detrend=None,
- verbose=VERBOSITY,
- )
-
- evoked_dict_anti: dict[str, list[Any] | mne.evoked.EvokedArray] = {
- "Reach/HbO": epochs_anti["Reach"].average(picks="hbo"), # type: ignore
- "Reach/HbR": epochs_anti["Reach"].average(picks="hbr"), # type: ignore
- }
-
- # Rename channels until the encoding of frequency in ch_name is fixed
- for condition in evoked_dict_anti:
- evoked_dict_anti[condition].rename_channels(lambda x: x[:-4]) # type: ignore
-
- return evoked_dict_anti
-
-
-
- def calculate_and_apply_short_channel_correction(optical_density_data: BaseRaw, file_path: str) -> dict[str, list[Any] | mne.evoked.EvokedArray]:
-
- od_corrected = short_channel_regression(optical_density_data, SHORT_CHANNEL_THRESH)
- haemoglobin_concentration_data = beer_lambert_law(od_corrected, ppf=calculate_dpf(file_path))
-
- events, _ = mne.events_from_annotations(haemoglobin_concentration_data, event_id={"Reach": 1, "Start of Rest": 2}, verbose=VERBOSITY) # type: ignore
- event_dict =
{"Reach": 1, "Start of Rest": 2} - - epochs_corr = mne.Epochs( - haemoglobin_concentration_data, - events, - event_id=event_dict, - tmin=TIME_MIN_THRESH, - tmax=TIME_MAX_THRESH, - reject=dict(hbo=EPOCH_REJECT_CRITERIA_THRESH), - reject_by_annotation=True, - proj=True, - baseline=(None, 0), - preload=True, - detrend=None, - verbose=VERBOSITY, - ) - - evoked_dict_corr: dict[str, list[Any] | mne.evoked.EvokedArray] = { - "Reach/HbO": epochs_corr["Reach"].average(picks="hbo"), # type: ignore - "Reach/HbR": epochs_corr["Reach"].average(picks="hbr"), # type: ignore - } - - # Rename channels until the encoding of frequency in ch_name is fixed - for condition in evoked_dict_corr: - evoked_dict_corr[condition].rename_channels(lambda x: x[:-4]) # type: ignore - - return evoked_dict_corr - - - -def signal_enhancement_techniques_images(evoked_dict: dict[str, list[Any] | mne.evoked.EvokedArray], evoked_dict_anti: dict[str, list[Any] | mne.evoked.EvokedArray], evoked_dict_corr:dict[str, list[Any] | mne.evoked.EvokedArray] | None): - - # If we have two images, ensure we only have two columns - if evoked_dict_corr is None: - fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 6)) # type: ignore - - # If we have three images, ensure we have three columns - else: - fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 6)) # type: ignore - - color_dict = dict(HbO="#AA3377", HbR="b") - - # TODO: This is to prevent the warning that we are only plotting one channel. Don't we want all though? - mne.set_log_level('WARNING') # type: ignore - - logger.info("Creating the figure...") - - # Plot the graph for the original data - mne.viz.plot_compare_evokeds( # type: ignore - evoked_dict, - combine="mean", - ci=0.95, # type: ignore - axes=axes[0], - colors=color_dict, - ylim=dict(hbo=[-10, 15]), - show=False, - ) - - # Plot the graph for the enhanced anticorrelation data - mne.viz.plot_compare_evokeds( # type: ignore - evoked_dict_anti, - combine="mean", - ci=0.95, # type: ignore - axes=axes[1], - colors=color_dict, - ylim=dict(hbo=[-10, 15]), - show=False, - ) - - # Plot the graph for short channel regression data, if it exists - if evoked_dict_corr is not None: - mne.viz.plot_compare_evokeds( # type: ignore - evoked_dict_corr, - combine="mean", - ci=0.95, # type: ignore - axes=axes[2], - colors=color_dict, - ylim=dict(hbo=[-10, 15]), - show=False, - ) - - mne.set_log_level('INFO') # type: ignore - - # If we have a short channel, set three titles - if evoked_dict_corr is not None: - for column, condition in enumerate( - ["Original Data", "With Enhanced Anticorrelation", "With Short Regression"] - ): - axes[column].set_title(f"{condition}") - - # If we do not have a short channel, set two titles - else: - for column, condition in enumerate( - ["Original Data", "With Enhanced Anticorrelation"] - ): - axes[column].set_title(f"{condition}") - - plt.close(fig) - - return fig - - - -def create_design_matrix(data: BaseRaw, stim_duration: float, short_chans: BaseRaw | None) -> tuple[DataFrame, Figure]: - """ - Creates a design matrix for first-level analysis including optional short channel regression, and generates a plot. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process. - stim_duration : float - Duration of the stimulus/event in seconds. - short_chans : BaseRaw | None - Data object containing only short channels for systemic component regression, or None if there is no short channels. - - Returns - ------- - tuple[DataFrame, Figure] - - DataFrame: The generated design matrix. 
- - Figure: A matplotlib figure visualizing the design matrix.
- """
-
- # Create the design matrix
- logger.info("Creating the design matrix... (This may take some time)")
-
- # If the HRF model is FIR, calculate some of the extra required parameters before creating the matrix
- if HRF_MODEL == "fir":
- sfreq = getattr(data, "info")["sfreq"]
- fir_delays = range(int(sfreq*15))
-
- design_matrix = make_first_level_design_matrix(
- data,
- stim_dur=0.1,
- hrf_model=HRF_MODEL,
- drift_model=DRIFT_MODEL,
- high_pass=1/(2*DURATION_BETWEEN_ACTIVITIES),
- fir_delays=fir_delays
- )
-
- # Using a canonical hrf model
- else:
- design_matrix = make_first_level_design_matrix(
- data,
- stim_dur=stim_duration,
- hrf_model=HRF_MODEL,
- drift_model=DRIFT_MODEL,
- high_pass=1/(2*DURATION_BETWEEN_ACTIVITIES),
- )
-
- # If we have a short channel, and short channel regression was specified, apply it to the design matrix
- if short_chans is not None:
- if SHORT_CHANNEL_REGRESSION:
- logger.info("Applying short channel regression...")
- for chan in range(len(short_chans.ch_names)): # type: ignore
- design_matrix[f"short_{chan}"] = short_chans.get_data(chan).T # type: ignore
-
- logger.info("Creating the figure...")
- fig, ax1 = plt.subplots(figsize=(10, 6), constrained_layout=True) # type: ignore
- plot_design_matrix(design_matrix, axes=ax1)
- plt.close(fig)
-
- logger.info("Successfully created the design matrix.")
-
- return design_matrix, fig
-
-
-
- def run_GLM_analysis(data: BaseRaw, design_matrix: DataFrame) -> RegressionResults:
- """
- Runs a General Linear Model (GLM) analysis on the provided data using the specified design matrix.
-
- Parameters
- ----------
- data : BaseRaw
- The loaded data object to process.
- design_matrix : DataFrame
- The design matrix specifying regressors for the GLM.
-
- Returns
- -------
- RegressionResults
- The fitted GLM results object containing regression coefficients and statistics.
- """
-
- logger.info("Running the GLM...")
- glm_est = run_glm(data, design_matrix, n_jobs=N_JOBS)
- logger.info("Successfully ran the GLM.")
- return glm_est
-
-
- def calculate_dpf(file_path):
- # Returns one DPF per wavelength, ordered for hbo / hbr.
- # Coefficients are the general age- and wavelength-dependent DPF equation of Scholkmann & Wolf (2013).
- with h5py.File(file_path, 'r') as f:
- wavelengths = f['/nirs/probe/wavelengths'][:]
- logger.info(f"Wavelengths (nm): {wavelengths}")
- wavelengths = sorted(wavelengths, reverse=True)
- data = METADATA.get(file_path)
- if data is None:
- # Fall back to a default age when no metadata is available for this file
- age = 25
- else:
- age = data['Age']
- logger.info(age)
- age = float(age)
- a = 223.3
- b = 0.05624
- c = 0.8493
- d = -5.723e-7
- e = 0.001245
- f = -0.9025
- dpf = []
- for w in wavelengths:
- logger.info(w)
- dpf.append(a + b * (age**c) + d* (w**3) + e * (w**2) + f*w)
- logger.info(dpf)
- return dpf
-
-
- def individual_GLM_analysis(file_path: str, ID: str, stim_duration: float = 5.0, progress_callback=None) -> tuple[BaseRaw, BaseRaw, DataFrame, DataFrame, DataFrame, DataFrame, dict[str, Figure], str, bool, bool]:
- """
- Performs individual-level General Linear Model (GLM) analysis on fNIRS data from a SNIRF file.
-
- Parameters
- ----------
- file_path : str
- Path to the SNIRF file containing the participant's raw data.
- ID : str
- Unique identifier for the participant, used for labeling output.
- stim_duration : float, optional - Duration of the stimulus in seconds for constructing the design matrix (default is 5.0) - - Returns - ------- - tuple[BaseRaw, BaseRaw, DataFrame, DataFrame, DataFrame, DataFrame, dict[str, Figure], str, bool, bool] - - BaseRaw: Processed fNIRS data - - BaseRaw: Full layout raw data prior to bad channel rejection - - DataFrame: Region of interest statistics - - DataFrame: Channel-level GLM statistics - - DataFrame: Contrast results - - DataFrame: Design matrix used for GLM - - dict[str, Figure]: Dictionary of figures generated during processing - - str: Description of processing steps applied - - bool: Whether the GLM successfully ran to completion - - bool: Whether the analysis result is valid based on quality checks - """ - - # Setting up variables to be used later - fig_dict: dict[str, Figure] = {} - bad_channels_sci = [] - bad_channels_psp = [] - bad_channels_snr = [] - mean_hr_nk = 70 - mean_hr_scipy = 70 - num_bad_channels = 0 - valid = True - short_chans = None - roi: DataFrame = DataFrame() - cha: DataFrame = DataFrame() - con: DataFrame = DataFrame() - design_matrix = DataFrame() - - # Load the file, get the sources and detectors, update their position, and calculate the short channel and any large distance channels - # STEP 1 - data, fig = load_snirf(file_path, ID, FORCE_DROP_CHANNELS) - fig_dict['Raw'] = fig - order_of_operations = "Loaded Raw File" - if progress_callback: progress_callback(1) - - # Initalize the participants full layout to be the current data regardless if it will be updated later - raw_full_layout = data - - - logger.info(file_path) - logger.info(ID) - logger.info(METADATA.get(file_path)) - calculate_dpf(file_path) - - try: - # Did the user want to load new channel positions from an optode file? - # STEP 2 - if OPTODE_FILE: - data = calculate_and_apply_updated_optode_coordinates(data) - order_of_operations += " + Updated Optode Placements" - if progress_callback: progress_callback(2) - - # STEP 2.5 - # TODO: remember why i do this - # I think its because i want a participants whole layout to plot without any bads - # but i shouldnt need to do od and bll just check the last three numbers - temp = data.copy() - temp_od = cast(BaseRaw, optical_density(temp, verbose=VERBOSITY)) - raw_full_layout = beer_lambert_law(temp_od, ppf=calculate_dpf(file_path)) - - # If specified, apply TDDR to the data - # STEP 3 - if TDDR: - data, fig = calculate_and_apply_tddr(data, ID) - order_of_operations += " + TDDR Filter" - fig_dict['TDDR'] = fig - if progress_callback: progress_callback(3) - - # If specified, apply a wavelet filter to the data - # STEP 4 - if WAVELET: - data, fig = calculate_and_apply_wavelet(data, ID) - order_of_operations += " + Wavelet Filter" - fig_dict['Wavelet'] = fig - if progress_callback: progress_callback(4) - - # If specified, attempt to get short channels from the data - # STEP 4.5 - if SHORT_CHANNEL: - try: - short_chans = get_short_channels(data, SHORT_CHANNEL_THRESH) - except Exception as e: - raise ProcessingError("SHORT_CHANNEL was specified, but no short channel was found. 
Please ensure the data has a short channel and that SHORT_CHANNEL_THRESH is set correctly.") - pass - else: - pass - - # Ensure that there is no short or really long channels in the data - data = get_long_channels(data, SHORT_CHANNEL_THRESH, LONG_CHANNEL_THRESH) - - # STEP 5 - if HEART_RATE: - sfreq, signal_trimmed, times_trimmed = short_channel_processing_for_hr(data, short_chans) - hr_smooth_nk, mean_hr_nk = calculate_heart_rate_neurokit(sfreq, signal_trimmed) - freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy = calculate_heart_rate_scipy(sfreq, signal_trimmed) - - # HACK: This sucks but looking at the graphs I trust neurokit2 more - overruled = False - if mean_hr_scipy < mean_hr_nk - 15: - mean_hr_scipy = mean_hr_nk - overruled = True - if mean_hr_scipy > mean_hr_nk + 15: - mean_hr_scipy = mean_hr_nk - overruled = True - - hr1, hr2 = plot_heart_rate(freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy, hr_smooth_nk, mean_hr_nk, times_trimmed, overruled) - order_of_operations += " + Heart Rate Calculation" - fig_dict['HeartRate_PSD'] = hr1 - fig_dict['HeartRate_Time'] = hr2 - if progress_callback: progress_callback(5) - - # If specified, calculate and apply SCI - # STEP 6 - if SCI: - bad_channels_sci, sci1, sci2 = calculate_scalp_coupling(data.copy(), min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW) / 60, max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW) / 60) - order_of_operations += " + SCI Calculation" - fig_dict['SCI1'] = sci1 - fig_dict['SCI2'] = sci2 - if progress_callback: progress_callback(6) - - # If specified, calculate and apply PSP - if PSP: - bad_channels_psp, psp1, psp2 = calculate_peak_power(data.copy(), min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW) / 60, max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW) / 60) - order_of_operations += " + PSP Calculation" - fig_dict['PSP1'] = psp1 - fig_dict['PSP2'] = psp2 - if progress_callback: progress_callback(7) - - # If specified, calculate and apply SNR - if SNR: - bad_channels_snr, fig = calculate_signal_noise_ratio(data.copy()) - order_of_operations += " + SNR Calculation" - fig_dict['SNR'] = fig - - # If specified, drop channels that were marked as bad - # STEP 7 - if EXCLUDE_CHANNELS: - data, fig, num_bad_channels = mark_bad_channels(data, ID, set(bad_channels_sci), set(bad_channels_psp), set(bad_channels_snr)) - order_of_operations += " + Excluded Bad Channels" - fig_dict['Bads'] = fig - if progress_callback: progress_callback(7) - - # Calculate the optical density - # STEP 8 - data, fig = calculate_optical_density(data, ID) - order_of_operations += " + Optical Density" - fig_dict['OpticalDensity'] = fig - if progress_callback: progress_callback(8) - - # Mainly for visualization. Could be implemented in the future - # STEP 8.5 - evoked_dict_corr = None - if SHORT_CHANNEL: - short_chans_od = cast(BaseRaw, optical_density(short_chans)) - data_recombined = cast(BaseRaw, data.copy().add_channels([short_chans_od])) # type: ignore - evoked_dict_corr = calculate_and_apply_short_channel_correction(data_recombined.copy(), file_path) - - # Calculate the haemoglobin concentration - # STEP 9 - data, fig = calculate_haemoglobin_concentration(data, ID, file_path) - order_of_operations += " + Haemoglobin Concentration" - fig_dict['HaemoglobinConcentration'] = fig - if progress_callback: progress_callback(9) - - # Mainly for visualization. 
Could be implemented in the future
- # STEP 9.5
- evoked_dict = extract_normal_epochs(data.copy())
- evoked_dict_anti = calculate_and_apply_negative_correlation_enhancement(data.copy())
- fig = signal_enhancement_techniques_images(evoked_dict, evoked_dict_anti, evoked_dict_corr)
- fig_dict['SignalEnhancement'] = fig
-
- # Create the design matrix
- # STEP 10
- # HACK FIXME - Downsampling to 10 is certainly not the best way... right?
- if HRF_MODEL == 'fir':
- data.resample(10, verbose=VERBOSITY) # type: ignore
- if short_chans is not None:
- short_chans.resample(10, verbose=VERBOSITY) # type: ignore
-
- design_matrix, fig = create_design_matrix(data, stim_duration, short_chans)
- order_of_operations += " + Design Matrix"
- fig_dict['DesignMatrix'] = fig
- if progress_callback: progress_callback(10)
-
- # Run the GLM on the design matrix
- # STEP 11
- glm_est: RegressionResults = run_GLM_analysis(data, design_matrix)
- order_of_operations += " + GLM"
- if progress_callback: progress_callback(11)
-
- # Add the regions of interest to the groups
- # STEP 12
- logger.info("Performing the finishing touches...")
- order_of_operations += " + Finishing Touches"
-
- # Extract the channel metrics
- logger.info("Calculating channel results...")
- cha = cast(DataFrame, glm_est.to_dataframe()) # type: ignore
-
- logger.info("Creating groups...")
- if HRF_MODEL == "fir":
- groups = dict(AllChannels=range(len(data.ch_names))) # type: ignore
-
- else:
- groups: dict[str, list[int]] = dict(
- group_1_picks = picks_pair_to_idx(data, ROI_GROUP_1, on_missing="ignore"), # type: ignore
- group_2_picks = picks_pair_to_idx(data, ROI_GROUP_2, on_missing="ignore"), # type: ignore
- )
-
- # Compute region of interest results from the channel data
- logger.info("Calculating region of interest results...")
- roi = glm_est.to_dataframe_region_of_interest(groups, design_matrix.columns, demographic_info=True) # type: ignore
-
- # Create the contrast matrix
- logger.info("Creating the contrast matrix...")
- contrast_matrix = np.eye(design_matrix.shape[1])
- basic_conts = dict(
- [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)]
- )
-
- # Calculate contrast differently depending on the hrf model
- if HRF_MODEL == 'fir':
- # Find all FIR regressors for TARGET_ACTIVITY
- delay_cols = [col for col in design_matrix.columns if col.startswith(f"{TARGET_ACTIVITY}_delay_")]
-
- if not delay_cols:
- raise ValueError(f"No FIR regressors found for condition {TARGET_ACTIVITY}.")
-
- # Average their contrast vectors (sum, then divide by the number of delays)
- fir_contrast = np.sum([basic_conts[col] for col in delay_cols], axis=0)
- fir_contrast /= len(delay_cols)
-
- # Compute contrast
- contrast = glm_est.compute_contrast(fir_contrast) # type: ignore
- con = cast(DataFrame, contrast.to_dataframe()) # type: ignore
-
- else:
- # Create and compute the contrast
- contrast_t = basic_conts[TARGET_ACTIVITY]
- contrast = glm_est.compute_contrast(contrast_t) # type: ignore
- con = cast(DataFrame, contrast.to_dataframe()) # type: ignore
-
- # Add the participant ID to the dataframes
- roi["ID"] = cha["ID"] = con["ID"] = design_matrix["ID"] = ID
-
- # Convert to uM for nicer plotting below.
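- # Editor's note (sketch): the GLM estimates arrive in molar units, so the
- # 1.0e6 factor below rescales them to micromolar for plotting, e.g.
- #   2.5e-6 * 1.0e6  # -> 2.5 uM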
- logger.info("Converting to uM...")
- cha["theta"] = cha["theta"].astype(float) * 1.0e6
- roi["theta"] = roi["theta"].astype(float) * 1.0e6
- con["effect"] = con["effect"].astype(float) * 1.0e6
-
- # If we exceed the maximum allowed bad channels, apply an X over the figures
- logger.info("Checking the number of bad channels...")
- if num_bad_channels >= MAX_BAD_CHANNELS:
- valid = False
- logger.info("Drawing some big X's...")
- for _, fig in fig_dict.items():
- add_x_overlay(fig, 'Too many bad channels!', 'red')
-
- logger.info("Completed individual analysis.")
-
- if progress_callback: progress_callback(12)
-
- # Clear the output for the next participant unless we are told to be verbose
- if not VERBOSITY:
- clear_output(wait=True)
-
-
- # Something really went wrong and we should not continue
- except ProcessingError as e:
- logger.info(f"An error occurred! {e}")
- raise
-
- # Something went wrong at one of the steps. Return what data we gathered, but set the validity of this run to False
- except Exception as e:
- logger.info(f"An error occurred! {e}")
- # Return the PNG-encoded figures here too, matching the success path below
- fig_dict_bytes = convert_fig_dict_to_png_bytes(fig_dict)
- return data, raw_full_layout, roi, cha, con, design_matrix, fig_dict_bytes, order_of_operations, False, False
-
- fig_dict_bytes = convert_fig_dict_to_png_bytes(fig_dict)
-
- return data, raw_full_layout, roi, cha, con, design_matrix, fig_dict_bytes, order_of_operations, True, valid
-
-
-
- def add_x_overlay(fig: Figure, reason: str, color: str) -> None:
- """
- Adds a large 'X' across the figure if the participant met the bad channel criteria.
-
- Parameters
- ----------
- fig : Figure
- Matplotlib figure to draw the X on.
- reason: str
- Why the X is being drawn.
- color: str
- What color the reason should be.
- """
-
- # Draw the big X on the graph
- ax = fig.add_axes([0, 0, 1, 1], zorder=100) # type: ignore
- ax.set_axis_off()
- ax.plot([0, 1], [0, 1], color='red', linewidth=8, transform=fig.transFigure, clip_on=False) # type: ignore
- ax.plot([0, 1], [1, 0], color='red', linewidth=8, transform=fig.transFigure, clip_on=False) # type: ignore
- ax.text(0.5, 0.5, reason, color=color, fontsize=26, fontweight='bold', ha='center', va='center', transform=fig.transFigure, zorder=101, bbox=dict(facecolor='white', alpha=0.8, edgecolor='red', boxstyle='round,pad=0.4')) # type: ignore
-
-
-
- def convert_fig_dict_to_png_bytes(fig_dict):
- # Render each figure to PNG bytes so the results can be pickled back from worker processes
- result = {}
- for key, fig in fig_dict.items():
- buf = BytesIO()
- fig.savefig(buf, format='png')
- buf.seek(0)
- result[key] = buf.read()
- return result
-
-
-
- def process_file_worker(args):
- file_path, file_name, stim_duration, config, gui, progress_queue = args
- try:
- set_config(config, gui)
-
- def progress_callback(step_idx):
- print(f"[Worker] Step {step_idx} for {file_name}")
- if progress_queue:
- progress_queue.put(('progress', file_name, step_idx))
-
- result = individual_GLM_analysis(
- file_path, file_name, stim_duration,
- progress_callback=progress_callback
- )
- return file_name, result, None
- except Exception as e:
- return file_name, None, e
-
-
-
- def process_folder(folder_path: str, stim_duration: float, files_remaining: dict[str, int], config, gui: bool = False, progress_queue=None) -> tuple[dict[str, dict[str, BaseRaw]], DataFrame, DataFrame, DataFrame, DataFrame, dict[str, list[Figure]], dict[str, str]]:
- df_roi = DataFrame()
- df_cha = DataFrame()
- df_con = DataFrame()
- df_design_matrix = DataFrame()
-
- raw_haemo_dict: dict[str, dict[str, BaseRaw]] = {}
- process_dict: dict[str, str] = {}
- figures_by_step:
dict[str, list[Figure]] = { - step: [] for step in [ - 'Raw', 'TDDR', 'Wavelet', 'HeartRate_PSD', 'HeartRate_Time', - 'SCI1', 'SCI2', 'PSP1', 'PSP2', 'SNR', 'Bads', - 'OpticalDensity', 'HaemoglobinConcentration', 'SignalEnhancement', 'DesignMatrix' - ] - } - - file_args = [ - (os.path.join(folder_path, file_name), file_name, stim_duration, config, gui, progress_queue) - for file_name in os.listdir(folder_path) - if os.path.isfile(os.path.join(folder_path, file_name)) - ] - - print("[process_folder] File args:", file_args) - - available_mem = psutil.virtual_memory().available - if (MAX_WORKERS >= available_mem / (1024 ** 3)): - print(f"WARNING: You have set MAX_WORKERS to {MAX_WORKERS}. Each worker should have at least 1GB of system memory. Your device currently has a total of {available_mem / (1024 ** 3):.2f}GB free.\nPlease consider lowering MAX_WORKERS to prevent potential crashing due to insufficient system memory.") - - with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor: - future_to_file = { - executor.submit(process_file_worker, args): args[1] for args in file_args - } - - with tqdm(total=len(file_args), desc="Processing files") as pbar: - for future in as_completed(future_to_file): - file_name = future_to_file[future] - files_remaining['count'] -= 1 - logger.info(f"Files remaining: {files_remaining['count']}") - pbar.update(1) - - try: - file_name, result, error = future.result() - if error: - logger.info(f"Error processing {file_name}: {error}") - continue - - raw_haemo_filtered, raw_haemo_full, roi, channel, contrast, design_matrix, fig_dict, process, finished, valid = result - - if finished and valid: - logger.info(f"Finished processing {file_name}. This participant was valid.") - raw_haemo_dict[file_name] = { - "filtered": raw_haemo_filtered, - "full_layout": raw_haemo_full - } - process_dict[file_name] = process - - for step in figures_by_step: - if step in fig_dict: - figures_by_step[step].append(fig_dict[step]) - - df_roi = pd.concat([df_roi, roi], ignore_index=True) - df_cha = pd.concat([df_cha, channel], ignore_index=True) - df_con = pd.concat([df_con, contrast], ignore_index=True) - df_design_matrix = pd.concat([df_design_matrix, design_matrix], ignore_index=True) - - else: - logger.info(f"Finished processing {file_name}. This participant was NOT valid.") - if SEE_BAD_IMAGES: - for step in figures_by_step: - if step in fig_dict: - figures_by_step[step].append(fig_dict[step]) - except Exception as e: - logger.info(f"Unexpected error processing {file_name}: {e}") - raise - - return raw_haemo_dict, df_roi, df_cha, df_con, df_design_matrix, figures_by_step, process_dict - - - -def verify_channel_positions(data: BaseRaw) -> None: - """ - Visualizes the sensor/channel positions of the raw data for verification. - - Parameters - ---------- - data : BaseRaw - The loaded data object to process. 
- """ - logger.info("Creating the figure...") - data.plot_sensors(show_names=True, to_sphere=True, show=False, verbose=VERBOSITY) # type: ignore - plt.show() # type: ignore - - - -def plot_3d_evoked_array( - inst: Union[BaseRaw, EvokedArray, Info], - statsmodel_df: DataFrame, - picks: Optional[Union[str, list[str]]] = "hbo", - value: str = "Coef.", - background: str = "w", - figure: Optional[object] = None, - clim: Union[str, dict[str, Union[str, list[float]]]] = "auto", - mode: str = "weighted", - colormap: str = "RdBu_r", - surface: str = "pial", - hemi: str = "both", - size: int = 800, - view: Optional[Union[str, dict[str, float]]] = None, - colorbar: bool = True, - distance: float = 0.03, - subjects_dir: Optional[str] = None, - src: Optional[SourceSpaces] = None, - verbose: bool = False, -) -> Brain: - '''Ported from MNE''' - - info: Info = cast(Info, deepcopy(inst if isinstance(inst, Info) else inst.info)) # type: ignore - if not (getattr(info, "ch_names") == list(statsmodel_df["ch_name"].values)): # type: ignore - raise RuntimeError( - 'MNE data structure does not match dataframe ' - f'results.\nMNE = {getattr(info, "ch_names")}.\n' - f'GLM = {list(statsmodel_df["ch_name"].values)}' # type: ignore - ) - - ea = EvokedArray(np.tile(statsmodel_df[value].values.T, (1, 1)).T, info.copy()) # type: ignore - - # TODO: mimic behaviour of other MNE-NIRS glm plotting options - if picks is not None: - ea = ea.pick(picks=picks) # type: ignore - - if subjects_dir is None: - subjects_dir = os.environ["SUBJECTS_DIR"] - if src is None: - fname_src_fs = os.path.join( - subjects_dir, "fsaverage", "bem", "fsaverage-ico-5-src.fif" - ) - src = read_source_spaces(fname_src_fs, verbose=verbose) - - picks = getattr(ea, "info")["ch_names"] - - # Set coord frame - for idx in range(len(getattr(ea, "ch_names"))): - getattr(ea, "info")["chs"][idx]["coord_frame"] = 4 - - # Generate source estimate - kwargs = dict( - evoked=ea, - subject="fsaverage", - trans=Transform('head', 'mri', np.eye(4)), - distance=distance, - mode=mode, - surface=surface, - subjects_dir=subjects_dir, - src=src, - project=True, - ) - - stc = stc_near_sensors(picks=picks, **kwargs, verbose=verbose) # type: ignore - - - from mne import SourceEstimate - assert isinstance(stc, SourceEstimate) # or your specific subclass - - # Produce brain plot - brain: Brain = stc.plot( # type: ignore - src=src, - subjects_dir=subjects_dir, - hemi=hemi, - surface=surface, - initial_time=0, - clim=clim, # type: ignore - size=size, - colormap=colormap, - figure=figure, - background=background, - colorbar=colorbar, - verbose=verbose, - ) - if view is not None: - brain.show_view(view) # type: ignore - - return brain - - - -def brain_3d_visualization(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_haemo: dict[str, dict[str, BaseRaw]], participant_number: int, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None: - - - # Determine if we are visualizing t or theta to set the appropriate limit - if t_or_theta == 't': - clim = dict(kind="value", pos_lims=(0, ABS_T_VALUE/2, ABS_T_VALUE)) - elif t_or_theta == 'theta': - clim = dict(kind="value", pos_lims=(0, ABS_THETA_VALUE/2, ABS_THETA_VALUE)) - - - # Loop over all groups - for index, group_name in enumerate(all_results): - - # We only care for their channel results - (_, df_cha, _, _) = all_results[group_name] - - # Get all activity conditions - for cond in [TARGET_ACTIVITY]: - - if HRF_MODEL == 
'fir': - ch_summary = df_cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python') # type: ignore - - else: - # Filter for the condition and chromophore - ch_summary = df_cha.query("Condition in [@cond] and Chroma == 'hbo'") # type: ignore - - # Determine number of unique participants based on their ID - n_participants = ch_summary["ID"].nunique() - - # WE JUST NEED SOMEONES OPTODE DATA TO PLOT ON THE BRAIN! - # TODO: This should take the average positions of all participants - # We will just take the passed through parameter - participant_to_plot = ch_summary["ID"].unique()[participant_number] # type: ignore - participant_raw_full: BaseRaw = all_haemo[participant_to_plot]["full_layout"] - - # Use ordinary least squares (OLS) if only one participant - if n_participants == 1: - - # t values - if t_or_theta == 't': - ch_model = smf.ols("t ~ -1 + ch_name", ch_summary).fit() # type: ignore - - # theta values - elif t_or_theta == 'theta': - ch_model = smf.ols("theta ~ -1 + ch_name", ch_summary).fit() # type: ignore - - logger.info("OLS model is being used as there is only one participant!") - - # Use mixed effects model if there is multiple participants - else: - - # t values - if t_or_theta == 't': - ch_model = smf.mixedlm("t ~ -1 + ch_name", ch_summary, groups=ch_summary["ID"]).fit(method="nm") # type: ignore - - # theta values - elif t_or_theta == 'theta': - ch_model = smf.mixedlm("theta ~ -1 + ch_name", ch_summary, groups=ch_summary["ID"]).fit(method="nm") # type: ignore - - # Convert model results - model_df = cast(DataFrame, statsmodels_to_results(ch_model, order=ch_summary["ch_name"].unique())) # type: ignore - - valid_channels = ch_summary["ch_name"].unique().tolist() # type: ignore - raw_for_plot = participant_raw_full.copy().pick(picks=valid_channels) # type: ignore - - - brain = plot_3d_evoked_array(raw_for_plot.pick(picks="hbo"), model_df, view="dorsal", distance=BRAIN_DISTANCE, colorbar=True, clim=clim, mode=BRAIN_MODE, size=(800, 700)) # type: ignore - - if show_optodes == 'all' or show_optodes == 'sensors': - brain.add_sensors(getattr(raw_for_plot, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY) # type: ignore - - - # Read and parse the file - if show_optodes == 'all' or show_optodes == 'labels': - positions: list[tuple[str, list[float]]] = [] - with open(OPTODE_FILE_PATH, 'r') as f: - for line in f: - line = line.strip() - if not line or ':' not in line: - continue # skip empty/malformed lines - name, coords = line.split(':', 1) - coords = [float(x) for x in coords.strip().split()] - positions.append((name.strip(), coords)) - - for name, (x, y, z) in positions: - brain._renderer.text3d(x, y, z, name, color=('red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'), scale=0.002) # type: ignore - - # Set the display text for the brain image - # display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nGroup: ' + group_name + '\nCondition: '+ cond + '\nReject Criteria Threshold: ' + str(EPOCH_REJECT_CRITERIA_THRESH) + '\nMin Time Threshold: ' - # + str(TIME_MIN_THRESH) + 's\nMax Time Threshold: ' + str(TIME_MAX_THRESH) + 's\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) + '\nStim Duration: ' - # + str(STIM_DURATION[index]) + 's\nLooking at: ' + t_or_theta + ' values') + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE - if HRF_MODEL == 'fir': - display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nGroup: ' + 
group_name + '\nCondition: '+ cond + '\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) - + '\nLooking at: ' + t_or_theta + ' values') + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE - else: - display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nGroup: ' + group_name + '\nCondition: '+ cond + '\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) + '\nStim Duration: ' - + str(STIM_DURATION[index]) + '\nLooking at: ' + t_or_theta + ' values') + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE - - # Apply the text onto the brain - if show_text: - brain.add_text(0.12, 0.64, display_text, "title", font_size=11, color="k") # type: ignore - - - -def plot_fir_model_results(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_haemo: dict[str, dict[str, BaseRaw]], participant_number: int, t_or_theta: Literal['t', 'theta'] = 'theta') -> None: - - - if HRF_MODEL != 'fir': - logger.info("This method only works when HRF_MODEL is set to 'fir'.") - - else: - for group_name in all_results: - - (df_roi, _, _, df_design_matrix) = all_results[group_name] - first_id = df_design_matrix["ID"].unique()[participant_number] # type: ignore - first_dm = df_design_matrix.query(f"ID == '{first_id}'").copy() # type: ignore - first_dm.index = np.round([0.1 * i for i in range(len(first_dm))], decimals=1) # type: ignore - df_design_matrix = first_dm - - participant = all_haemo[first_id]["full_layout"] - - df_roi["isActivity"] = [TARGET_ACTIVITY in n for n in df_roi["Condition"]] # type: ignore - df_roi["isDelay"] = ["delay" in n for n in df_roi["Condition"]] # type: ignore - df_roi = df_roi.query("isDelay in [True]") # type: ignore - df_roi = df_roi.query("isActivity in [True]") # type: ignore - - - df_roi.loc[:, "TidyCond"] = "" - df_roi.loc[df_roi["isActivity"] == True, "TidyCond"] = TARGET_ACTIVITY # noqa: E712 - # Finally, extract the FIR delay in to its own column in data frame - df_roi.loc[:, "delay"] = [n.split("_")[-1] for n in df_roi.Condition] # type: ignore - - if t_or_theta == 'theta': - lme = smf.mixedlm("theta ~ -1 + delay:TidyCond:Chroma", df_roi, groups=df_roi["ID"]).fit() # type: ignore - elif t_or_theta == 't': - lme = smf.mixedlm("t ~ -1 + delay:TidyCond:Chroma", df_roi, groups=df_roi["ID"]).fit() # type: ignore - - df_sum: DataFrame = statsmodels_to_results(lme) # type: ignore - - - df_sum["delay"] = [int(n) for n in df_sum["delay"]] # type: ignore - df_sum = df_sum.sort_values("delay") # type: ignore - - # logger.info the result for the oxyhaemoglobin data in the Reach condition - df_sum.query(f"TidyCond in ['{TARGET_ACTIVITY}']").query("Chroma in ['hbo']") # type: ignore - - axes1: list[Axes] - fig, axes1 = plt.subplots(nrows=1, ncols=3, figsize=(20, 10)) # type: ignore - - # Extract design matrix columns that correspond to the condition of interest - dm_cond_idxs = np.where([TARGET_ACTIVITY in n for n in df_design_matrix.columns])[0] - dm_cond_colnames: list[str] = [df_design_matrix.columns[i] for i in dm_cond_idxs] - dm_cond = df_design_matrix[dm_cond_colnames] - - # 2. Extract hbo GLM estimates - df_hbo = df_sum.query(f"TidyCond in ['{TARGET_ACTIVITY}']").query("Chroma in ['hbo']") # type: ignore - vals_hbo = [float(v) for v in df_hbo["Coef."]] # type: ignore - - dm_cond_scaled_hbo = dm_cond * vals_hbo - - # 3. 
Extract hbr GLM estimates
- df_hbr = df_sum.query(f"TidyCond in ['{TARGET_ACTIVITY}']").query("Chroma in ['hbr']") # type: ignore
- vals_hbr = [float(v) for v in df_hbr["Coef."]] # type: ignore
-
- dm_cond_scaled_hbr = dm_cond * vals_hbr
-
- # Extract the time scale for plotting.
- # Set time zero to be the onset.
- index_values = cast(NDArray[float64], dm_cond_scaled_hbo.index.to_numpy(dtype=float) - participant.annotations.onset[1]) # type: ignore
-
- # Plot the result
- axes1[0].plot(index_values, np.asarray(dm_cond)) # type: ignore
- axes1[1].plot(index_values, np.asarray(dm_cond_scaled_hbo)) # type: ignore
- axes1[2].plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r") # type: ignore
- axes1[2].plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b") # type: ignore
-
-
- valid_mask = (index_values >= 0) & (index_values <= 15)
- hbo_sum_window = np.sum(dm_cond_scaled_hbo.loc[valid_mask, :], axis=1)
- peak_idx_in_window = np.argmax(hbo_sum_window)
- peak_idx = np.where(valid_mask)[0][peak_idx_in_window]
- peak_time = float(round(index_values[peak_idx], 2)) # type: ignore
-
- axes1[2].axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time') # type: ignore
-
- # Format the plot
- for ax in range(3):
- axes1[ax].set_xlim(-5, 20)
- axes1[ax].set_xlabel("Time (s)") # type: ignore
- axes1[0].set_ylim(-0.2, 1.2)
- axes1[1].set_ylim(-4, 8)
- axes1[2].set_ylim(-4, 8)
- axes1[0].set_title(f"FIR Model for {group_name} (Unscaled by GLM {TARGET_ACTIVITY} estimates) ({t_or_theta})") # type: ignore
- axes1[1].set_title(f"FIR Components for {group_name} (Scaled by GLM {TARGET_ACTIVITY} estimates) ({t_or_theta})") # type: ignore
- axes1[2].set_title(f"Evoked Response for {group_name} ({TARGET_ACTIVITY}) ({t_or_theta})") # type: ignore
- axes1[0].set_ylabel("FIR Model") # type: ignore
- axes1[1].set_ylabel("Oxyhaemoglobin (ΔμMol)") # type: ignore
- axes1[2].set_ylabel("Haemoglobin (ΔμMol)") # type: ignore
- axes1[2].legend(["Oxyhaemoglobin", "Deoxyhaemoglobin", f"Peak {peak_time}s"]) # type: ignore
-
- fig.tight_layout()
-
-
- # We can also extract the 95% confidence intervals of the estimates too
- l95_hbo = [float(v) for v in df_hbo["[0.025"]] # type: ignore
- u95_hbo = [float(v) for v in df_hbo["0.975]"]] # type: ignore
- dm_cond_scaled_hbo_l95 = dm_cond * l95_hbo
- dm_cond_scaled_hbo_u95 = dm_cond * u95_hbo
- l95_hbr = [float(v) for v in df_hbr["[0.025"]] # type: ignore
- u95_hbr = [float(v) for v in df_hbr["0.975]"]] # type: ignore
- dm_cond_scaled_hbr_l95 = dm_cond * l95_hbr
- dm_cond_scaled_hbr_u95 = dm_cond * u95_hbr
-
- axes2: Axes
- fig, axes2 = plt.subplots(nrows=1, ncols=1, figsize=(7, 7)) # type: ignore
-
- # Plot the result
- axes2.plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r") # type: ignore
- axes2.plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b") # type: ignore
- axes2.axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time') # type: ignore
-
- axes2.fill_between( # type: ignore
- index_values,
- np.asarray(np.sum(dm_cond_scaled_hbo_l95, axis=1)),
- np.asarray(np.sum(dm_cond_scaled_hbo_u95, axis=1)),
- facecolor="red",
- alpha=0.25,
- )
- axes2.fill_between( # type: ignore
- index_values,
- np.asarray(np.sum(dm_cond_scaled_hbr_l95, axis=1)),
- np.asarray(np.sum(dm_cond_scaled_hbr_u95, axis=1)),
- facecolor="blue",
- alpha=0.25,
- )
-
- # Format the plot
- axes2.set_xlim(-5, 20)
- axes2.set_ylim(-8, 12)
- axes2.set_title(f"Evoked Response with 95% confidence intervals for {group_name} ({TARGET_ACTIVITY}) ({t_or_theta})") # type: ignore
- axes2.set_ylabel("Haemoglobin (ΔμMol)") # type: ignore
- axes2.legend(["Oxyhaemoglobin", "Deoxyhaemoglobin", f"Peak {peak_time}s"]) # type: ignore
- axes2.set_xlabel("Time (s)") # type: ignore
-
- fig.tight_layout()
-
- plt.show() # type: ignore
-
-
-
- def plot_2d_theta_graph(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None:
- '''This method will create a 2D boxplot showing the theta values for each channel and group as independent ranges on the same graph.\n
- Inputs:\n
- all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n
- '''
-
-
- # Create a list to store the channel results of all groups
- df_all_cha_list: list[DataFrame] = []
-
- # Iterate over each group in all_results
- for group_name, (_, df_cha, _, _) in all_results.items():
- df_cha["group"] = group_name # Add the group name to the data
- df_all_cha_list.append(df_cha) # Append the dataframe to the list
-
- # Combine all the data into a single DataFrame
- df_all_cha = pd.concat(df_all_cha_list, ignore_index=True)
-
- # Filter for the target activity
- if HRF_MODEL == 'fir':
- df_target = df_all_cha[df_all_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_")] # type: ignore
- else:
- df_target = df_all_cha[df_all_cha["Condition"] == TARGET_ACTIVITY]
-
- # Get the number of unique groups to know how many colors are needed for the boxplot
- unique_groups = df_target["group"].nunique()
- palette = sns.color_palette("Set2", unique_groups)
-
- # Create the boxplot
- fig = plt.figure(figsize=(15, 6)) # type: ignore
- sns.boxplot(
- data=df_target,
- x="ch_name",
- y="theta",
- hue="group",
- palette=palette
- )
-
- # Format the boxplot
- plt.title("Theta Coefficients by Channel and Group") # type: ignore
- plt.xticks(rotation=90) # type: ignore
- plt.ylabel("Theta (µM)") # type: ignore
- plt.xlabel("Channel") # type: ignore
- plt.legend(title="Group") # type: ignore
- plt.tight_layout()
- plt.show() # type: ignore
-
-
-
- def plot_individual_theta_averages(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None:
-
- if HRF_MODEL == 'fir':
- logger.info("This method does not work when HRF_MODEL is set to 'fir'.")
- return
- else:
- # Iterate over all the groups
- for group_name in all_results:
-
- # Store the region of interest data
- (df_roi, _, _, _) = all_results[group_name]
-
- # Filter the results down to what we want
- grp_results = df_roi.query(f"Condition in ['{TARGET_ACTIVITY}', '{TARGET_CONTROL}']").copy() # type: ignore
- grp_results = grp_results.query("Chroma in ['hbo']").copy() # type: ignore
-
- # Rename the ROI's to be the friendly name
- roi_label_map = {
- "group_1_picks": ROI_GROUP_1_NAME,
- "group_2_picks": ROI_GROUP_2_NAME,
- }
- grp_results["ROI"] = grp_results["ROI"].replace(roi_label_map) # type: ignore
-
- # Create the catplot
- sns.catplot(
- x="Condition",
- y="theta",
- col="ID",
- hue="ROI",
- data=grp_results,
- col_wrap=5,
- errorbar=None,
- palette="muted",
- height=4,
- s=10,
- dodge=False,
- )
-
- plt.show() # type: ignore
-
-
-
- def plot_group_theta_averages(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None:
- '''This method will create a stripplot showing the theta values for each region of interest for each group.\n
- Inputs:\n
- all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n'''
-
- if HRF_MODEL == 'fir':
- logger.info("This method does not work when HRF_MODEL is set to 'fir'.")
- return
- else:
- # Rename the
ROI's to be the friendly name - roi_label_map = { - "group_1_picks": ROI_GROUP_1_NAME, - "group_2_picks": ROI_GROUP_2_NAME, - } - - # Setup subplot grid - n = len(all_results) - ncols = 2 - nrows = (n + 1) // ncols # round up - fig, axes = cast(tuple[Figure, np.ndarray[Any, Any]], plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 5 * nrows), squeeze=False)) # type: ignore - - index = -1 - # Iterate over all groups - for index, (group_name, ax) in enumerate(zip(all_results, axes.flatten())): - - # Store the region of interest data - (df_roi, _, _, _) = all_results[group_name] - - # Filter the results down to what we want - grp_results = df_roi.query(f"Condition in ['{TARGET_ACTIVITY}', '{TARGET_CONTROL}']").copy() # type: ignore - - # Run a mixedlm model on the data - roi_model = smf.mixedlm("theta ~ -1 + ROI:Condition:Chroma", grp_results, groups=grp_results["ID"]).fit(method="nm") # type: ignore - - # Apply the new friendly names on to the data - df = cast(DataFrame, statsmodels_to_results(roi_model)) - df["ROI"] = df["ROI"].map(roi_label_map) # type: ignore - - # Create a stripplot: - sns.stripplot( - x="Condition", - y="Coef.", - hue="ROI", - data=df.query("Chroma == 'hbo'"), # type: ignore - dodge=False, - jitter=False, - size=5, - palette="muted", - ax=ax, - ) - - # Format the stripplot - ax.set_title(f"Results for {group_name}") - ax.legend(title="ROI", loc="upper right") - - if index == -1: - # No groups, so remove all axes - for ax in axes.flatten(): - fig.delaxes(ax) - # Remove any unused axes and apply final touches - else: - for j in range(index + 1, len(axes.flatten())): - fig.delaxes(axes.flatten()[j]) - - fig.tight_layout() - fig.suptitle("Theta Averages Across Groups", fontsize=16, y=1.02) # type: ignore - plt.show() # type: ignore - - - -def compute_p_group_stats(df_cha: DataFrame, bad_pairs: set[tuple[int, int]], t_or_theta: Literal['t', 'theta'] = 't') -> DataFrame: - - if HRF_MODEL == 'fir': - # Filter: All delays for the target activity - df_activity = df_cha[df_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_") & (df_cha['Chroma'] == 'hbo')] # type: ignore - - # Aggregate across FIR delays *per subject* for each channel - df_agg = (df_activity.groupby(['Source', 'Detector', 'ID'])[['t', 'theta']].mean().reset_index()) # type: ignore - - else: - # Canonical HRF case - df_agg = df_cha[(df_cha['Condition'] == TARGET_ACTIVITY) & (df_cha['Chroma'] == 'hbo')].copy() - - # Filter the channel data down to what we want - grouped = cast(Iterator[tuple[tuple[int, int], Any]], df_agg.groupby(['Source', 'Detector'])) # type: ignore - - # Create an empty list to store the data for our result - data: list[dict[str, Any]] = [] - - # Iterate over the filtered channel data - for (src, det), group in grouped: - - # If it is a bad channel pairing, do not process it - if (src, det) in bad_pairs: - logger.info(f"Skipping bad channel Source {src} - Detector {det}") - continue - - # Drop any missing values that could exist - t_values = group['t'].dropna().values - t_values = np.array(t_values, dtype=float) - theta_values = group['theta'].dropna().values - theta_values = np.array(theta_values, dtype=float) - - # Ensure that we still have our two t values, otherwise do not process this pairing - # TODO: is the t values throwing a warning good enough? 
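- # Editor's note (sketch): scipy.stats.ttest_1samp cannot estimate a standard
- # error from a single observation and returns nan statistics in that case, e.g.
- #   ttest_1samp([1.3], popmean=0)  # -> statistic=nan, pvalue=nan
- # so the explicit n < 2 guard below is safer than relying on the warning alone.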
- if len(t_values) < 2:
- logger.info(f"Skipping Source {src} - Detector {det}: not enough data (n={len(t_values)})")
- continue
-
-
- # NOTE: This is still calculated with t values as it is a t-test
- # Perform one-sample t-test on t-values across subjects
- t_stat, pval = ttest_1samp(t_values, popmean=0)
-
- logger.info(f"One-sample t-test statistic: {t_stat}")
-
- # Store all of the data for this ttest using the mean t-value for visualization
- if t_or_theta == 't':
- data.append({
- 'Source': src,
- 'Detector': det,
- 't_or_theta': np.mean(t_values),
- 'p_value': pval
- })
-
- else:
- data.append({
- 'Source': src,
- 'Detector': det,
- 't_or_theta': np.mean(theta_values),
- 'p_value': pval
- })
-
- # Create a DataFrame with the data and ensure it is not empty
- result = DataFrame(data)
- if result.empty:
- logger.info("No valid channel pairs with enough data for group-level testing.")
-
- return result
-
-
-
- def get_bad_src_det_pairs(raw: BaseRaw) -> set[tuple[int, int]]:
- '''This method figures out the bad source and detector pairings for the 2d t+p graph to prevent them from being plotted.
- Inputs:\n
- raw (RawSNIRF) - Contains all the snirf data for the last participant processed. Only used to get the channels\n
- Outputs:\n
- bad_pairs (set) - Set containing all of the bad pairs of sources and detectors'''
-
- # Create a set to store the bad pairs
- bad_pairs: set[tuple[int, int]] = set()
-
- # Iterate over all the channels in the bads key
- for ch_name in getattr(raw, "info")["bads"]:
- try:
- # Get all characters before the space
- parts = ch_name.split()[0]
-
- # Split with the separator
- src_str, det_str = parts.split(SOURCE_DETECTOR_SEPARATOR)
- src = int(src_str[1:])
- det = int(det_str[1:])
-
- # Add to the set
- bad_pairs.add((src, det))
-
- except Exception as e:
- logger.info(f"Could not parse bad channel '{ch_name}': {e}")
-
- return bad_pairs
-
-
-
- def plot_avg_significant_activity(raw: BaseRaw, all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], t_or_theta: Literal['t', 'theta'] = 't') -> None:
- '''This method plots the average t values for the groups on a 2D graph. p values less than or equal to P_THRESHOLD are solid lines, while greater p values are dashed lines.\n
- Inputs:\n
- raw (RawSNIRF) - Contains all the snirf data for the last participant processed.
Only used to get the channel locations.\n - all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n''' - - # Iterate over all the groups - for group_name in all_results: - (_, df_cha, _, _) = all_results[group_name] - - if HRF_MODEL == 'fir': - mask = df_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_") & (df_cha['Chroma'] == 'hbo') # type: ignore - filtered_df = df_cha[mask] - num_tests = filtered_df.groupby(['Source', 'Detector']).ngroups # type: ignore - else: - num_tests = len(cast(Iterator[tuple[tuple[int, int], Any]], df_cha.query(f"Condition == '{TARGET_ACTIVITY}' and Chroma == 'hbo'").groupby(['Source', 'Detector']))) # type: ignore - - logger.info(f"Number of tests: {num_tests}") - - # Compute average t-value across individuals for each channel pairing - bad_pairs = get_bad_src_det_pairs(raw) - avg_df = compute_p_group_stats(df_cha, bad_pairs, t_or_theta) - - logger.info(f"Average {t_or_theta}-values and p-values for {TARGET_ACTIVITY}:") - for _, row in avg_df.iterrows(): # type: ignore - logger.info(f"Source {row['Source']} <-> Detector {row['Detector']}: " - f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}") - - # Extract the cource and detector positions from raw - src_pos: dict[int, tuple[float, float]] = {} - det_pos: dict[int, tuple[float, float]] = {} - for ch in getattr(raw, "info")["chs"]: - ch_name = ch['ch_name'] - if not ch_name or not ch['loc'].any(): - continue - parts = ch_name.split()[0] - src_str, det_str = parts.split(SOURCE_DETECTOR_SEPARATOR) - src_num = int(src_str[1:]) - det_num = int(det_str[1:]) - src_pos[src_num] = ch['loc'][3:5] - det_pos[det_num] = ch['loc'][6:8] - - # Set up the plot - fig, ax = plt.subplots(figsize=(8, 6)) # type: ignore - - # Plot the sources - for pos in src_pos.values(): - ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3) # type: ignore - - # Plot the detectors - for pos in det_pos.values(): - ax.scatter(pos[0], pos[1], s=120, c='k', marker='s', edgecolors='white', linewidths=1, zorder=3) # type: ignore - - # Ensure that the colors stay within the boundaries even if they are over or under the max/min values - if t_or_theta == 't': - norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_T_VALUE, vmax=ABS_SIGNIFICANCE_T_VALUE) - elif t_or_theta == 'theta': - norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_THETA_VALUE, vmax=ABS_SIGNIFICANCE_THETA_VALUE) - - cmap: mcolors.Colormap = plt.get_cmap('seismic') - - # Plot connections with avg t-values - for row in avg_df.itertuples(): - src: int = cast(int, row.Source) # type: ignore - det: int = cast(int, row.Detector) # type: ignore - tval: float = cast(float, row.t_or_theta) # type: ignore - pval: float = cast(float, row.p_value) # type: ignore - - - if src in src_pos and det in det_pos: - x = [src_pos[src][0], det_pos[det][0]] - y = [src_pos[src][1], det_pos[det][1]] - style = '-' if pval <= P_THRESHOLD else '--' - ax.plot(x, y, linestyle=style, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2) # type: ignore - - # Format the Colorbar - sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) - sm.set_array([]) - cbar = plt.colorbar(sm, ax=ax, shrink=0.85) # type: ignore - cbar.set_label(f'Average {TARGET_ACTIVITY} {t_or_theta} value (hbo)', fontsize=11) # type: ignore - - # Formatting the subplots - ax.set_aspect('equal') - ax.set_title(f"Average {t_or_theta} values for {TARGET_ACTIVITY} (HbO) for {group_name}", fontsize=14) # type: ignore - ax.set_xlabel('X position (m)', 
fontsize=11) # type: ignore - ax.set_ylabel('Y position (m)', fontsize=11) # type: ignore - ax.grid(True, alpha=0.3) # type: ignore - - # Set axis limits to be 1cm more than the optode positions - all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()] - all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()] - ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01) - ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01) - - fig.tight_layout() - plt.show() # type: ignore - - - -def generate_montage_locations(): - """Get standard MNI montage locations in dataframe. - - Data is returned in the same format as the eeg_positions library. - """ - # standard_1020 and standard_1005 are in MNI (fsaverage) space already, - # but we need to undo the scaling that head_scale will do - montage = mne.channels.make_standard_montage( - "standard_1005", head_size=0.09700884729534559 - ) - for d in montage.dig: - d["coord_frame"] = 2003 - montage.dig[:] = montage.dig[3:] - montage.add_mni_fiducials() # now in fsaverage space - coords = pd.DataFrame.from_dict(montage.get_positions()["ch_pos"]).T - coords["label"] = coords.index - coords = coords.rename(columns={0: "x", 1: "y", 2: "z"}) - - return coords.reset_index(drop=True) - - -def _find_closest_standard_location(position, reference, *, out="label"): - """Return closest montage label to coordinates. - - Parameters - ---------- - position : array, shape (3,) - Coordinates. - reference : dataframe - As generated by _generate_montage_locations. - trans_pos : str - Apply a transformation to positions to specified frame. - Use None for no transformation. - """ - from scipy.spatial.distance import cdist - - p0 = np.array(position) - p0.shape = (-1, 3) - # head_mri_t, _ = _get_trans("fsaverage", "head", "mri") - # p0 = apply_trans(head_mri_t, p0) - dists = cdist(p0, np.asarray(reference[["x", "y", "z"]], float)) - - if out == "label": - min_idx = np.argmin(dists) - return reference["label"][min_idx] - else: - assert out == "dists" - return dists - - - -def _source_detector_fold_table(raw, cidx, reference, fold_tbl, interpolate): - src = raw.info["chs"][cidx]["loc"][3:6] - det = raw.info["chs"][cidx]["loc"][6:9] - - ref_lab = list(reference["label"]) - dists = _find_closest_standard_location([src, det], reference, out="dists") - src_min, det_min = np.argmin(dists, axis=1) - src_name, det_name = ref_lab[src_min], ref_lab[det_min] - - tbl = fold_tbl.query("Source == @src_name and Detector == @det_name") - dist = np.linalg.norm(dists[[0, 1], [src_min, det_min]]) - # Try reversing source and detector - if len(tbl) == 0: - tbl = fold_tbl.query("Source == @det_name and Detector == @src_name") - if len(tbl) == 0 and interpolate: - # Try something hopefully not too terrible: pick the one with the - # smallest net distance - good = np.isin(fold_tbl["Source"], reference["label"]) & np.isin( - fold_tbl["Detector"], reference["label"] - ) - assert good.any() - tbl = fold_tbl[good] - assert len(tbl) - src_idx = [ref_lab.index(src) for src in tbl["Source"]] - det_idx = [ref_lab.index(det) for det in tbl["Detector"]] - # Original - tot_dist = np.linalg.norm([dists[0, src_idx], dists[1, det_idx]], axis=0) - assert tot_dist.shape == (len(tbl),) - idx = np.argmin(tot_dist) - dist_1 = tot_dist[idx] - src_1, det_1 = ref_lab[src_idx[idx]], ref_lab[det_idx[idx]] - # And the reverse - tot_dist = np.linalg.norm([dists[0, det_idx], dists[1, src_idx]], axis=0) - idx = np.argmin(tot_dist) - dist_2 = tot_dist[idx] - src_2, det_2 = ref_lab[det_idx[idx]], 
ref_lab[src_idx[idx]] - if dist_1 < dist_2: - new_dist, src_use, det_use = dist_1, src_1, det_1 - else: - new_dist, src_use, det_use = dist_2, det_2, src_2 - - - tbl = fold_tbl.query("Source == @src_use and Detector == @det_use") - tbl = tbl.copy() - tbl["BestSource"] = src_name - tbl["BestDetector"] = det_name - tbl["BestMatchDistance"] = dist - tbl["MatchDistance"] = new_dist - assert len(tbl) - else: - tbl = tbl.copy() - tbl["BestSource"] = src_name - tbl["BestDetector"] = det_name - tbl["BestMatchDistance"] = dist - tbl["MatchDistance"] = dist - - tbl = tbl.copy() # don't get warnings about setting values later - return tbl - -from mne.utils import _check_fname, _validate_type, warn - - -def _read_fold_xls(fname, atlas="Juelich"): - """Read fOLD toolbox xls file. - - The values are then manipulated in to a tidy dataframe. - - Note the xls files are not included as no license is provided. - - Parameters - ---------- - fname : str - Path to xls file. - atlas : str - Requested atlas. - """ - page_reference = {"AAL2": 2, "AICHA": 5, "Brodmann": 8, "Juelich": 11, "Loni": 14} - - tbl = pd.read_excel(fname, sheet_name=page_reference[atlas]) - - # Remove the spacing between rows - empty_rows = np.where(np.isnan(tbl["Specificity"]))[0] - tbl = tbl.drop(empty_rows).reset_index(drop=True) - - # Empty values in the table mean its the same as above - for row_idx in range(1, tbl.shape[0]): - for col_idx, col in enumerate(tbl.columns): - if not isinstance(tbl[col][row_idx], str): - if np.isnan(tbl[col][row_idx]): - tbl.iloc[row_idx, col_idx] = tbl.iloc[row_idx - 1, col_idx] - - tbl["Specificity"] = tbl["Specificity"] * 100 - tbl["brainSens"] = tbl["brainSens"] * 100 - return tbl - -import os.path as op - -def _check_load_fold(fold_files, atlas): - # _validate_type(fold_files, (list, "path-like", None), "fold_files") - if fold_files is None: - fold_files = mne.get_config("MNE_NIRS_FOLD_PATH") - if fold_files is None: - raise ValueError( - "MNE_NIRS_FOLD_PATH not set, either set it using " - "mne.set_config or pass fold_files as str or list" - ) - if not isinstance(fold_files, list): # path-like - fold_files = _check_fname( - fold_files, - overwrite="read", - must_exist=True, - name="fold_files", - need_dir=True, - ) - fold_files = [op.join(fold_files, f"10-{x}.xls") for x in (5, 10)] - - fold_tbl = pd.DataFrame() - for fi, fname in enumerate(fold_files): - fname = _check_fname( - fname, overwrite="read", must_exist=True, name=f"fold_files[{fi}]" - ) - fold_tbl = pd.concat( - [fold_tbl, _read_fold_xls(fname, atlas=atlas)], ignore_index=True - ) - return fold_tbl - - - -def fold_channel_specificity_normal(raw, fold_files=None, atlas="Juelich", interpolate=False): - """Return the landmarks and specificity a channel is sensitive to. 
- - Parameters - - """ # noqa: E501 - _validate_type(raw, BaseRaw, "raw") - - reference_locations = generate_montage_locations() - - fold_tbl = _check_load_fold(fold_files, atlas) - - chan_spec = list() - for cidx in range(len(raw.ch_names)): - tbl = _source_detector_fold_table( - raw, cidx, reference_locations, fold_tbl, interpolate - ) - chan_spec.append(tbl.reset_index(drop=True)) - - return chan_spec - - - -def fold_channels(raw: BaseRaw, all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], fold_path: str) -> None: - - - # Locate the fOLD excel files - mne.set_config('MNE_NIRS_FOLD_PATH', fold_path) # type: ignore - - # Iterate over all of the groups - for group_name in all_results: - - output = None - - # List to store the results - landmark_specificity_data: list[dict[str, Any]] = [] - - # Filter the data to only what we want - hbo_channel_names = cast(list[str], getattr(raw.copy().pick(picks='hbo'), "ch_names")) # type: ignore - - # Format the output to make it slightly easier to read - logger.info("*" * 60) - logger.info(f'Landmark Specificity for {group_name}:') - logger.info("*" * 60) - - if GUI: - - num_channels = len(hbo_channel_names) - rows, cols = 4, 7 # 6 rows and 4 columns of pie charts - fig, axes = plt.subplots(rows, cols, figsize=(16, 10), constrained_layout=True) - axes = axes.flatten() # Flatten the axes array for easier indexing - - # If more pie charts than subplots, create extra subplots - if num_channels > rows * cols: - fig, axes = plt.subplots((num_channels // cols) + 1, cols, figsize=(16, 10), constrained_layout=True) - axes = axes.flatten() - - # Create a list for consistent color mapping - landmarks = [ - "1 - Primary Somatosensory Cortex", - "2 - Primary Somatosensory Cortex", - "3 - Primary Somatosensory Cortex", - "4 - Primary Motor Cortex", - "5 - Somatosensory Association Cortex", - "6 - Pre-Motor and Supplementary Motor Cortex", - "7 - Somatosensory Association Cortex", - "8 - Includes Frontal eye fields", - "9 - Dorsolateral prefrontal cortex", - "10 - Frontopolar area", - "11 - Orbitofrontal area", - "17 - Primary Visual Cortex (V1)", - "18 - Visual Association Cortex (V2)", - "19 - V3", - "20 - Inferior Temporal gyrus", - "21 - Middle Temporal gyrus", - "22 - Superior Temporal Gyrus", - "23 - Ventral Posterior cingulate cortex", - "24 - Ventral Anterior cingulate cortex", - "25 - Subgenual cortex", - "32 - Dorsal anterior cingulate cortex", - "37 - Fusiform gyrus", - "38 - Temporopolar area", - "39 - Angular gyrus, part of Wernicke's area", - "40 - Supramarginal gyrus part of Wernicke's area", - "41 - Primary and Auditory Association Cortex", - "42 - Primary and Auditory Association Cortex", - "43 - Subcentral area", - "44 - pars opercularis, part of Broca's area", - "45 - pars triangularis Broca's area", - "46 - Dorsolateral prefrontal cortex", - "47 - Inferior prefrontal gyrus", - "48 - Retrosubicular area", - "Brain_Outside", - ] - - cmap1 = plt.cm.get_cmap('tab20') # First 20 colors - cmap2 = plt.cm.get_cmap('tab20b') # Next 20 colors - - # Combine the colors from both colormaps - colors = [cmap1(i) for i in range(20)] + [cmap2(i) for i in range(20)] # Total 40 colors - - landmarks.sort(key=lambda x: (int(x.split(" - ")[0]) if x.split(" - ")[0].isdigit() else float('inf'))) - - landmark_color_map = {landmark: colors[i % len(colors)] for i, landmark in enumerate(landmarks)} - - # Iterate over each channel - for idx, channel_name in enumerate(hbo_channel_names): - - # Run the fOLD on the selected channel - channel_data = 
raw.copy().pick(picks=channel_name) # type: ignore - - output = cast(list[DataFrame], fold_channel_specificity_normal(channel_data, interpolate=True, atlas='Brodmann')) - - # Process each DataFrame that fold_channel_specificity returns - for df_data in output: - - # Extract the relevant columns - useful_data = df_data[['Landmark', 'Specificity']] - - # Store the results - landmark_specificity_data.append({ - 'Channel': channel_name, - 'Data': useful_data, - }) - - # logger.info the results - for data in landmark_specificity_data: - logger.info(f"Channel: {data['Channel']}") - logger.info(f"{data['Data']}") - logger.info("-" * 60) - - - # If PLOT_ENABLED is True, plot the results - if GUI: - unique_landmarks = sorted(useful_data['Landmark'].unique()) - color_list = [landmark_color_map[landmark] for landmark in useful_data['Landmark']] - - # Plot specificity for each channel - ax = axes[idx] # Use the correct axis for this channel - - labels = [f'{landmark.split(" - ")[0]}' if landmark != 'Brain_Outside' else 'B' for landmark in useful_data['Landmark']] - - wedges, texts, autotexts = ax.pie( - useful_data['Specificity'], - autopct='%1.1f%%', - startangle=90, - labels=labels, # Add the labels here - labeldistance=1.05, # Adjust label position to avoid overlap with the wedges - colors=color_list) # Ensure color consistency - - ax.set_title(f'{channel_name}') - ax.axis('equal') # Equal aspect ratio ensures the pie chart is circular. - - - # Reset the list for the next particcipant - landmark_specificity_data = [] - - if GUI: - handles = [ - plt.Line2D([0], [0], marker='o', color='w', label=landmark, markersize=10, - markerfacecolor=landmark_color_map[landmark]) - for landmark in landmarks - ] - n_landmarks = len(landmarks) - - # Calculate the figure size based on number of rows and columns - fig_width = 5 - fig_height = n_landmarks / 4 - - # Create a new figure window for the legend - legend_fig = plt.figure(figsize=(fig_width, fig_height)) - legend_axes = legend_fig.add_subplot(111) - legend_axes.axis('off') # Turn off axis for the legend window - legend_axes.legend(handles=handles, loc='center', fontsize=10, title="Landmarks") - - if GUI: - for ax in axes[len(hbo_channel_names):]: - ax.axis('off') - plt.show() - - - - -def brain_landmarks_3d(raw_haemo: BaseRaw, show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all') -> None: - - brain = Brain("fsaverage", background="white", size=(800, 700)) # type: ignore - - # Add optode text labels manually - if show_optodes == 'all' or show_optodes == 'sensors': - brain.add_sensors(getattr(raw_haemo, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY) # type: ignore - - # Read and parse the file - if show_optodes == 'all' or show_optodes == 'labels': - positions: list[tuple[str, list[float]]] = [] - with open(OPTODE_FILE_PATH, 'r') as f: - for line in f: - line = line.strip() - if not line or ':' not in line: - continue # skip empty/malformed lines - name, coords = line.split(':', 1) - coords = [float(x) for x in coords.strip().split()] - positions.append((name.strip(), coords)) - - for name, (x, y, z) in positions: - brain._renderer.text3d(x, y, z, name, color=('red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'), scale=0.002) # type: ignore - - for ch in getattr(raw_haemo, "info")['chs']: - logger.info(ch['ch_name'], ch['loc'][:3]) - - # Add Brodmann labels - labels = cast(list[mne.Label], mne.read_labels_from_annot("fsaverage", 
"PALS_B12_Brodmann", "rh", verbose=VERBOSITY)) # type: ignore - - label_colors = { - "Brodmann.39-rh": "blue", - "Brodmann.40-rh": "green", - "Brodmann.6-rh": "pink", - "Brodmann.7-rh": "orange", - "Brodmann.17-rh": "red", - "Brodmann.1-rh": "yellow", - "Brodmann.2-rh": "yellow", - "Brodmann.3-rh": "yellow", - "Brodmann.18-rh": "red", - "Brodmann.19-rh": "red", - "Brodmann.4-rh": "purple", - "Brodmann.8-rh": "white" - } - - for label in labels: - name = getattr(label, "name", None) - if not isinstance(name, str): - continue - if name in label_colors: - brain.add_label(label, borders=False, color=label_colors[name]) # type: ignore - - - -def data_to_csv(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]): - - logger.info("Getting the current directory...") - if PLATFORM_NAME == 'darwin': - csvs_folder = os.path.join(os.path.dirname(sys.executable), "../../../csvs") - else: - cwd = os.getcwd() - csvs_folder = os.path.join(cwd, "csvs") - - logger.info("Attempting to create the csvs folder...") - os.makedirs(csvs_folder, exist_ok=True) - - # Generate a timestamp to be appended to the end of the file name - logger.info("Generating the timestamp...") - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - - # Iterate over all groups - for group_name in all_results: - - # Get the channel data and generate the file name - (_, df_cha, _, _) = all_results[group_name] - - filename = f"{group_name}_{timestamp}.csv" - save_path = os.path.join(csvs_folder, filename) - - # Filter to just the target condition and store it in the csv - if HRF_MODEL == 'fir': - filtered_df = df_cha[ - df_cha["Condition"].str.startswith(TARGET_ACTIVITY) & - (df_cha["Chroma"] == "hbo") - ] - - # Step 2: Define the aggregation logic - agg_funcs = { - 'df': 'mean', - 'mse': 'mean', - 'p_value': 'mean', - 'se': 'mean', - 't': 'mean', - 'theta': 'mean', - 'Source': 'mean', - 'Detector': 'mean', - 'Significant': lambda x: x.sum() > (len(x) / 2), - 'Chroma': 'first', # assuming all are the same - 'ch_name': 'first', # same ch_name in the group - 'ID': 'first', # same ID in the group - } - - # Step 3: Group and aggregate - averaged_df = ( - filtered_df - .groupby(['ch_name', 'ID'], as_index=False) - .agg(agg_funcs) - ) - - # Step 4: Rename and add 'Condition' as TARGET_ACTIVITY - averaged_df.insert(0, 'Condition', TARGET_ACTIVITY) - - averaged_df["df"] = averaged_df["df"].round().astype(int) - averaged_df["Source"] = averaged_df["Source"].round().astype(int) - averaged_df["Detector"] = averaged_df["Detector"].round().astype(int) - - # Step 5: Reset index and reorder columns - ordered_cols = [ - 'Condition', 'df', 'mse', 'p_value', 'se', 't', 'theta', - 'Source', 'Detector', 'Chroma', 'Significant', 'ch_name', 'ID' - ] - averaged_df = averaged_df[ordered_cols].reset_index(drop=True) - averaged_df = averaged_df.sort_values(by=["ID", "Detector", "Source"]).reset_index(drop=True) - - output_df = averaged_df - else: - output_df = df_cha.query(f"Condition == '{TARGET_ACTIVITY}' and Chroma == 'hbo'") # type: ignore - output_df.to_csv(save_path) - - - -def all_data_to_csv(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]): - - logger.info("Getting the current directory...") - if PLATFORM_NAME == 'darwin': - csvs_folder = os.path.join(os.path.dirname(sys.executable), "../../../csvs") - else: - cwd = os.getcwd() - csvs_folder = os.path.join(cwd, "csvs") - - logger.info("Attempting to create the csvs folder...") - os.makedirs(csvs_folder, exist_ok=True) - - # Generate a timestamp to be appended 
to the end of the file name - logger.info("Generating the timestamp...") - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - - # Iterate over all groups - for group_name in all_results: - - # Get the channel data and generate the file name - (_, df_cha, _, _) = all_results[group_name] - - filename = f"{group_name}_{timestamp}_all.csv" - save_path = os.path.join(csvs_folder, filename) - - # Filter to just the target condition and store it in the csv - if HRF_MODEL == 'fir': - output_df = df_cha - else: - output_df = df_cha.query(f"Condition == '{TARGET_ACTIVITY}' and Chroma == 'hbo'") # type: ignore - output_df.to_csv(save_path) - - - -def brain_3d_contrast(con_model_df: DataFrame, con_model_df_filtered: BaseRaw, common_channels: list[str], first_name: str, second_name: str, first_stim: float, second_stim: float, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None: - # Filter DataFrame to only common channels, and sort by raw order - con_model = con_model_df - - con_model["ch_name"] = pd.Categorical( - con_model["ch_name"], categories=common_channels, ordered=True - ) - con_model = con_model.sort_values("ch_name").reset_index(drop=True) # type: ignore - - - if t_or_theta == 't': - clim=dict(kind="value", pos_lims=(0, ABS_T_VALUE/2, ABS_T_VALUE)) - elif t_or_theta == 'theta': - clim=dict(kind="value", pos_lims=(0, ABS_THETA_VALUE/2, ABS_THETA_VALUE)) - - # Plot brain figure - brain = plot_3d_evoked_array(con_model_df_filtered.copy().pick(picks="hbo"), con_model, view="dorsal", distance=BRAIN_DISTANCE, colorbar=True, mode=BRAIN_MODE, clim=clim, size=(800, 700), verbose=VERBOSITY) # type: ignore - - if show_optodes == 'all' or show_optodes == 'sensors': - brain.add_sensors(getattr(con_model_df_filtered, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY) # type: ignore - - # Read and parse the file - if show_optodes == 'all' or show_optodes == 'labels': - positions: list[tuple[str, list[float]]] = [] - with open(OPTODE_FILE_PATH, 'r') as f: - for line in f: - line = line.strip() - if not line or ':' not in line: - continue # skip empty/malformed lines - name, coords = line.split(':', 1) - coords = [float(x) for x in coords.strip().split()] - positions.append((name.strip(), coords)) - - for name, (x, y, z) in positions: - brain._renderer.text3d(x, y, z, name, color=('red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'), scale=0.002) # type: ignore - - # Set the display text for the brain image - # display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nContrast: ' + first_name + ' - ' + second_name + '\nReject Criteria Threshold: ' + str(EPOCH_REJECT_CRITERIA_THRESH) + '\nMin Time Threshold: ' + - # str(TIME_MIN_THRESH) + 's\nMax Time Threshold: ' + str(TIME_MAX_THRESH) + 's\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) + '\nStim Duration: ' + str(first_stim) + ', ' + - # str(second_stim) + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE + '\nLooking at: ' + t_or_theta + ' values') - - if HRF_MODEL == 'fir': - display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nContrast: ' + first_name + ' - ' + second_name + '\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) - + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE + '\nLooking at: ' + t_or_theta + ' values') - else: - display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + 
'\nContrast: ' + first_name + ' - ' + second_name + '\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION) - + '\nStim Duration: ' + str(first_stim) + ', ' + str(second_stim) + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE + '\nLooking at: ' + t_or_theta + ' values') - - # Apply the text onto the brain - if show_text: - brain.add_text(0.12, 0.70, display_text, "title", font_size=11, color="k") # type: ignore - - - -def plot_2d_3d_contrasts_between_groups(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_raw_haemo: dict[str, dict[str, BaseRaw]], t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None: - - - # Dictionary to store data for each group - group_dfs: dict[str, DataFrame] = {} - - # GET RAW HAEMO OF THE FIRST PARTICIPANT - raw_haemo = all_raw_haemo[list(all_raw_haemo.keys())[0]]["full_layout"] - - # Store all contrasts with the corresponding group name - for group_name, (_, _, df_con, _) in all_results.items(): - group_dfs[group_name] = df_con - group_dfs[group_name]["group"] = group_name - - # Concatenate all groups together - df_combined = pd.concat(group_dfs.values(), ignore_index=True) - - con_summary = df_combined.query("Chroma == 'hbo'").copy() # type: ignore - - valid_channels = cast(DataFrame, (pd.crosstab(con_summary['group'], con_summary['ch_name']) > 1).all()) # type: ignore - valid_channels = valid_channels[valid_channels].index.tolist() - - # Filter data to only these channels - con_summary = con_summary[con_summary['ch_name'].isin(valid_channels)] # type: ignore - - - # # Verify your data looks as expected - # logger.info(con_summary[['group', 'ch_name', 'Chroma', 'effect']].head()) - # logger.info("\nUnique values:") - # logger.info("Groups:", con_summary['group'].unique()) - # logger.info("Channels:", con_summary['ch_name'].unique()) - # logger.info("Chroma:", con_summary['Chroma'].unique()) # Should be just 'hbo' - - model_formula = "effect ~ -1 + group:ch_name:Chroma" - con_model = smf.mixedlm(model_formula, con_summary, groups=con_summary["ID"]).fit(method="nm") # type: ignore - - # logger.info(con_model.summary()) - - - # # Fit the mixed-effects model - # model_formula = "effect ~ -1 + group:ch_name:Chroma" - - # #model_formula = "effect ~ -1 + group + ch_name" - - # con_model = smf.mixedlm( - # model_formula, con_summary_filtered, groups=con_summary_filtered["ID"] - # ).fit(method="nm") - - # Get the t values if we are comparing them - - t_values: pd.Series[float] = pd.Series(dtype=float) - if t_or_theta == 't': - t_values = con_model.tvalues - - # Get all the group names from the dictionary and how many groups we have - group_names = list(group_dfs.keys()) - n_groups = len(group_names) - - # Store DataFrames for each contrast - for i in range(n_groups): - for j in range(i + 1, n_groups): - group1_name = group_names[i] - group2_name = group_names[j] - - if t_or_theta == 't': - # Extract the t-values for both groups - group1_vals = t_values.filter(like=f"group[{group1_name}]") # type: ignore - group2_vals = t_values.filter(like=f"group[{group2_name}]") # type: ignore - vlim_value = ABS_CONTRAST_T_VALUE - - elif t_or_theta == 'theta': - # Extract the coefficients for both groups - group1_vals = con_model.params.filter(like=f"group[{group1_name}]") - group2_vals = con_model.params.filter(like=f"group[{group2_name}]") - vlim_value = ABS_CONTRAST_THETA_VALUE - - # TODO: Does this work for all separators? 
- # Extract channel names - - group1_channels: list[str] = [ - name.split(":")[1].split("[")[1].split("]")[0] - for name in getattr(group1_vals, "index") - ] - group2_channels: list[str] = [ - name.split(":")[1].split("[")[1].split("]")[0] - for name in getattr(group2_vals, "index") - ] - - # Create the DataFrames with channel indices - df_group1 = DataFrame( - {"Coef.": group1_vals.values}, index=group1_channels # type: ignore - ) - df_group2 = DataFrame( - {"Coef.": group2_vals.values}, index=group2_channels # type: ignore - ) - - # Merge the two DataFrames on the channel names - df_contrast = df_group1.join(df_group2, how="inner", lsuffix=f"_{group1_name}", rsuffix=f"_{group2_name}") # type: ignore - - # Compute the contrasts - contrast_1_2 = df_contrast[f"Coef._{group1_name}"] - df_contrast[f"Coef._{group2_name}"] - contrast_2_1 = df_contrast[f"Coef._{group2_name}"] - df_contrast[f"Coef._{group1_name}"] - - # Add the a-b / 1-2 contrast to the DataFrame. The order and names of the keys in the DataFrame are important! - df_contrast["Coef."] = contrast_1_2 - con_model_df_1_2 = DataFrame({ - "ch_name": df_contrast.index, - "Coef.": df_contrast["Coef."], - "Chroma": "hbo" - }) - - - mne_ch_names = getattr(raw_haemo.copy().pick(picks="hbo"), "ch_names") # type: ignore - glm_ch_names = cast(list[DataFrame], con_model_df_1_2["ch_name"].tolist()) - - # Get ordered common channels - common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names] - - # Filter raw data to these channels - con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels) # type: ignore - - # Reindex GLM results to match MNE channel order - con_model_df_1_2 = con_model_df_1_2.set_index("ch_name").loc[common_channels].reset_index() # type: ignore - - # Create the 3d visualization - brain_3d_contrast(con_model_df_1_2, con_model_df_filtered, common_channels, group1_name, group2_name, STIM_DURATION[i], STIM_DURATION[j], t_or_theta, show_optodes, show_text) - - plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_1_2, names=True, res=128, vlim=(-vlim_value, vlim_value)) # type: ignore - - - # TODO: The title currently goes on the colorbar. Low priority - plt.title(f"Contrast: {group1_name} vs {group2_name}") # type: ignore - plt.show() # type: ignore - - # Add the b-a / 2-1 contrast to the DataFrame. The order and names of the keys in the DataFrame are important! - df_contrast["Coef."] = contrast_2_1 - con_model_df_2_1 = DataFrame({ - "ch_name": df_contrast.index, - "Coef.": df_contrast["Coef."], - "Chroma": "hbo" - }) - - - mne_ch_names = getattr(raw_haemo.copy().pick(picks="hbo"), "ch_names") # type: ignore - glm_ch_names = cast(list[DataFrame], con_model_df_2_1["ch_name"].tolist()) - - # Get ordered common channels - common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names] - - # Filter raw data to these channels - con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels) # type: ignore - - # Reindex GLM results to match MNE channel order - con_model_df_2_1 = con_model_df_2_1.set_index("ch_name").loc[common_channels].reset_index() # type: ignore - - # Create the 3d visualization - brain_3d_contrast(con_model_df_2_1, con_model_df_filtered, common_channels, group2_name, group1_name, STIM_DURATION[j], STIM_DURATION[i], t_or_theta, show_optodes, show_text) - - plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_2_1, names=True, res=128, vlim=(-vlim_value, vlim_value)) # type: ignore - - # TODO: The title currently goes on the colorbar. 
Low priority - plt.title(f"Contrast: {group2_name} vs {group1_name}") # type: ignore - plt.show() # type: ignore - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -# TODO: Is any of this still useful? - -def calculate_annotations(raw_haemo_filtered, file_name, output_folder=None, save_images=None): - '''Method that extract the annotations from the data.\n - Input:\n - raw_haemo_filtered (RawSNIRF) - The filtered haemoglobin concentration data\n - file_name (string) - The file name of the current file\n - output_folder (string) - (optional) Where to save the images. Default is None\n - save_images (string) - (optional) Bool to save the images or not. Default is None - Output:\n - events (ndarray) - Array containing row number and what index the event is\n - event_dict (dict) - Contains the names of the events''' - - if output_folder is None: - output_folder = None - if save_images is None: - save_images = None - - # Get when the events occur and what they are called, and display a figure with the result - events, event_dict = mne.events_from_annotations(raw_haemo_filtered) - - # Do we save the image? - if save_images: - fig = mne.viz.plot_events(events, event_id=event_dict, sfreq=raw_haemo_filtered.info["sfreq"], show=False) - save_path = output_folder + "/8. Annotations for " + file_name + ".png" - fig.savefig(save_path) - - return events, event_dict - - - -def calculate_good_epochs(raw_haemo_filtered, events, event_dict, file_name, tmin=None, tmax=None, reject_thresh=None, target_activity=None, target_control=None, output_folder=None, save_images=None): - '''Calculates what epochs are good and creates a graph showing if any are dropped.\n - Input:\n - raw_haemo_filtered (RawSNIRF) - The filtered haemoglobin concentration data\n - events (ndarray) - Array containing row number and what index the event is\n - event_dict (dict) - Contains the names of the events\n - file_name (string) - The file name of the current file\n - tmin (float) - (optional) Time in seconds to display before the event. Default is TIME_MIN_THRESH\n - tmax (float) - (optional) Time in seconds to display after the event. Default is TIME_MAX_THRESH\n - reject_thresh (float) - (optional) Value that determines the threshold for rejecting epochs. Default is EPOCH_REJECT_CRITERIA_THRESH\n - target_activity (string) - (optional) The target activity. Default is TARGET_ACTIVITY\n - target_control (string) - (optional) The target control. Default is TARGET_CONTROL\n - output_folder (string) - (optional) Where to save the images. Default is None\n - save_images (string) - (optional) Bool to save the images or not. 
Default is None - Output:\n - good_epochs (Epochs) - The remaining good epochs\n - all_epochs (Epochs) - All of the epochs''' - - if tmin is None: - tmin = TIME_MIN_THRESH - if tmax is None: - tmax = TIME_MAX_THRESH - if reject_thresh is None: - reject_thresh = EPOCH_REJECT_CRITERIA_THRESH - if target_activity is None: - target_activity = TARGET_ACTIVITY - if target_control is None: - target_control = TARGET_CONTROL - if output_folder is None: - output_folder = None - if save_images is None: - save_images = None - - # Get all the good epochs - good_epochs = mne.Epochs( - raw_haemo_filtered, - events, - event_id=event_dict, - tmin=tmin, - tmax=tmax, - reject=dict(hbo=reject_thresh), - reject_by_annotation=True, - proj=True, - baseline=(None, 0), - preload=True, - detrend=None, - verbose=True, - ) - - # Get all the epochs - all_epochs = mne.Epochs( - raw_haemo_filtered, - events, - event_id=event_dict, - tmin=tmin, - tmax=tmax, - proj=True, - baseline=(None, 0), - preload=True, - detrend=None, - verbose=True, - ) - - if REJECT_PAIRS: - # Calculate which epochs were in all but not in good - all_idx = all_epochs.selection - good_idx = good_epochs.selection - bad_idx = np.setdiff1d(all_idx, good_idx) - - # Split the controls and the activities - event_ids = all_epochs.events[:, 2] - control_id = event_dict[target_control] - activity_id = event_dict[target_activity] - - to_reject_extra = set() - - for i, idx in enumerate(all_idx): - if idx in bad_idx: - ev = event_ids[i] - # If the control was bad, drop the following activity - if ev == control_id and i + 1 < len(all_idx): - if event_ids[i + 1] == activity_id: - to_reject_extra.add(all_idx[i + 1]) - # If the activity was bad, drop the preceding activity - if ev == activity_id and i - 1 >= 0: - if event_ids[i - 1] == control_id: - to_reject_extra.add(all_idx[i - 1]) - - # Create a list to store all the new drops, only adding them if they are currently classified as good - drop_idx_in_good = [ - np.where(good_idx == idx)[0][0] for idx in to_reject_extra if idx in good_idx - ] - - # Drop the pairings of the bad epochs - good_epochs.drop(drop_idx_in_good) - - # Do we save the image? - if save_images: - drop_log_fig = good_epochs.plot_drop_log(show=False) - save_path = output_folder + "/8. Epoch drops for " + file_name + ".png" - drop_log_fig.savefig(save_path) - - return good_epochs, all_epochs - - - -def bad_check(raw_od, max_bad_channels=None): - '''Method to see if we have more bad channels than our allowed threshold.\n - Inputs:\n - raw_od (RawSNIRF) - The optical density data\n - max_bad_channels (int) - (optional) The max amount of bad channels we want to tolerate. Default is MAX_BAD_CHANNELS\n - Output\n - (bool) - True it we had more bad channels than the threshold, False if we did not''' - - if max_bad_channels is None: - max_bad_channels = MAX_BAD_CHANNELS - - # Check if there is more bad channels in the bads key compared to the allowed amount - if len(raw_od.info.get('bads', [])) >= max_bad_channels: - return True - else: - return False - - - -def remove_bad_epoch_pairings(raw_haemo_filtered_minus_short, good_epochs, epoch_pair_tolerance_window=None): - '''Method to apply our new epochs to the loaded data in working memory. This is to ensure that the GLM does not see these epochs. 
- Inputs:\n - raw_haemo_filtered_minus_short (RawSNIRF) - The filtered haemoglobin concentration data\n - good_epochs (Epochs) - The epochs we want the loaded data to take on\n - epoch_pair_tolerance_window (int) - (optional) The amount of data points the paired epoch can deviate from the expected amount. Default is EPOCH_PAIR_TOLERANCE_WINDOW\n - Output:\n - raw_haemo_filtered_good_epochs (RawSNIRF) - The filtered haemoglobin concentration data with only the good epochs''' - - if epoch_pair_tolerance_window is None: - epoch_pair_tolerance_window = EPOCH_PAIR_TOLERANCE_WINDOW - # Copy the input haemoglobin concentration data and drop the bad channels - raw_haemo_filtered_good_epochs = raw_haemo_filtered_minus_short.copy() - raw_haemo_filtered_good_epochs = raw_haemo_filtered_good_epochs.drop_channels(raw_haemo_filtered_good_epochs.info['bads']) - - # Get the event IDs of the good events - good_event_samples = set(good_epochs.events[:, 0]) - logger.info(f"Total good events (epochs): {len(good_event_samples)}") - - # Get the current annotations - raw_annots = raw_haemo_filtered_good_epochs.annotations - - # Create lists to use for processing - clean_descriptions = [] - clean_onsets = [] - clean_durations = [] - dropped = [] - - # Get the frequency of the file - sfreq = raw_haemo_filtered_good_epochs.info['sfreq'] - - for desc, onset, dur in zip(raw_annots.description, raw_annots.onset, raw_annots.duration): - # Convert annotation onset time to sample index - sample = int(onset * sfreq) - - if FORCE_DROP_ANNOTATIONS: - for i in FORCE_DROP_ANNOTATIONS: - if desc == i: - dropped.append((desc, onset)) - continue - - # Check if the annotation is within the tolerance of any good event - matched = any(abs(sample - event_sample) <= epoch_pair_tolerance_window for event_sample in good_event_samples) - - # We found a matching event - if matched: - clean_descriptions.append(desc) - clean_onsets.append(onset) - clean_durations.append(dur) - else: - dropped.append((desc, onset)) - - # Create the new filtered annotations - new_annots = Annotations( - onset=clean_onsets, - duration=clean_durations, - description=clean_descriptions, - ) - - # Assign the new annotations - raw_haemo_filtered_good_epochs.set_annotations(new_annots) - - # logger.info out the results - logger.info(f"Original annotations: {len(raw_annots)}") - logger.info(f"Kept annotations: {len(clean_descriptions)}") - logger.info("Kept annotation types:", set(clean_descriptions)) - if dropped: - logger.info(f"Dropped annotations: {len(dropped)}") - logger.info("Dropped annotations:") - for desc, onset in dropped: - logger.info(f" - {desc} at {onset:.2f}s") - else: - logger.info("No annotations were dropped!") - - return raw_haemo_filtered_good_epochs \ No newline at end of file diff --git a/flares.py b/flares.py new file mode 100644 index 0000000..5eedd78 --- /dev/null +++ b/flares.py @@ -0,0 +1,2994 @@ +""" +Filename: flares.py +Description: Core functionality for FLARES + +Author: Tyler de Zeeuw +License: GPL-3.0 +""" + +# Built-in imports +import os +import sys +import platform +import threading +import logging +from io import BytesIO +from typing import Any, Optional, cast, Literal, Union +from itertools import compress +from copy import deepcopy +from multiprocessing import Queue +import os.path as op +import re +import traceback +from concurrent.futures import ProcessPoolExecutor, as_completed + +# External library imports +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +from matplotlib.figure import Figure +from 
matplotlib.axes import Axes
+from matplotlib.colors import LinearSegmentedColormap
+
+import numpy as np
+from numpy.typing import NDArray
+from numpy import float64
+
+import pandas as pd
+from pandas import DataFrame
+
+import seaborn as sns
+import h5py
+
+from nilearn.plotting import plot_design_matrix # type: ignore
+from nilearn.glm.regression import OLSModel
+
+import statsmodels.formula.api as smf # type: ignore
+from statsmodels.stats.multitest import multipletests
+
+from scipy import stats
+from scipy.spatial.distance import cdist
+
+# External library imports for mne
+from mne import (
+    EvokedArray, SourceEstimate, Info, Epochs, Label,
+    events_from_annotations, read_source_spaces,
+    stc_near_sensors, pick_types, grand_average, get_config, set_config, read_labels_from_annot
+) # type: ignore
+from mne.source_space import SourceSpaces
+from mne.transforms import Transform # type: ignore
+from mne.io import BaseRaw, read_raw_snirf # type: ignore
+from mne.preprocessing.nirs import (
+    beer_lambert_law, optical_density,
+    temporal_derivative_distribution_repair,
+    source_detector_distances, short_channels
+) # type: ignore
+from mne.viz import Brain, plot_events, plot_evoked_topo, plot_compare_evokeds
+from mne.filter import filter_data # type: ignore
+from mne.utils import _check_fname, _validate_type, warn
+from mne.channels import make_standard_montage
+from mne.datasets.sample import data_path
+
+from mne_nirs.visualisation import plot_glm_group_topo # type: ignore
+from mne_nirs.channels import get_long_channels, get_short_channels # type: ignore
+from mne_nirs.experimental_design import make_first_level_design_matrix # type: ignore
+from mne_nirs.statistics import run_glm, statsmodels_to_results # type: ignore
+from mne_nirs.signal_enhancement import (
+    enhance_negative_correlation, short_channel_regression
+) # type: ignore
+from mne_nirs.io.fold import fold_channel_specificity # type: ignore
+from mne_nirs.preprocessing import peak_power # type: ignore
+from mne_nirs.statistics._glm_level_first import RegressionResults # type: ignore
+
+
+os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore
+
+# FIXME: Shouldn't need each ordering - just order it before checking
+FIXED_CATEGORY_COLORS = {
+    "SCI only": "skyblue",
+    "PSP only": "salmon",
+    "SNR only": "lightgreen",
+    "PSP + SCI": "orange",
+    "SCI + SNR": "violet",
+    "PSP + SNR": "gold",
+    "SCI + PSP": "orange",
+    "SNR + SCI": "violet",
+    "SNR + PSP": "gold",
+    "PSP + SNR + SCI": "gray",
+    "SCI + PSP + SNR": "gray",
+    "SCI + SNR + PSP": "gray",
+    "PSP + SCI + SNR": "gray",
+    "SNR + SCI + PSP": "gray",
+    "SNR + PSP + SCI": "gray",
+}
+
+
+AGE: float
+GENDER: str
+
+SECONDS_TO_STRIP: int
+DOWNSAMPLE: bool
+DOWNSAMPLE_FREQUENCY: int
+
+SCI: bool
+SCI_TIME_WINDOW: int
+SCI_THRESHOLD: float
+
+SNR: bool
+# SNR_TIME_WINDOW: int
+SNR_THRESHOLD: float
+
+PSP: bool
+PSP_TIME_WINDOW: int
+PSP_THRESHOLD: float
+
+TDDR: bool
+
+ENHANCE_NEGATIVE_CORRELATION: bool
+
+VERBOSITY = True
+
+AGE = 25
+GENDER = ""
+GROUP = "Default"
+
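+# A minimal configuration sketch (illustrative values, not shipped defaults):
+# each key mirrors one of the module-level globals above and is applied via
+# set_config_me(), defined further down, e.g.
+#
+#     set_config_me({
+#         "SECONDS_TO_STRIP": 10,
+#         "DOWNSAMPLE": True, "DOWNSAMPLE_FREQUENCY": 5,
+#         "SCI": True, "SCI_TIME_WINDOW": 10, "SCI_THRESHOLD": 0.7,
+#         "PSP": True, "PSP_TIME_WINDOW": 10, "PSP_THRESHOLD": 0.1,
+#         "SNR": True, "SNR_THRESHOLD": 2.0,
+#         "TDDR": True, "ENHANCE_NEGATIVE_CORRELATION": False,
+#     })
+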
+REQUIRED_KEYS: dict[str, Any] = {
+    "SECONDS_TO_STRIP": int,
+    "DOWNSAMPLE": bool,
+    "DOWNSAMPLE_FREQUENCY": int,
+
+    "SCI": bool,
+    "SCI_TIME_WINDOW": int,
+    "SCI_THRESHOLD": float,
+
+    "SNR": bool,
+    # "SNR_TIME_WINDOW": int,
+    "SNR_THRESHOLD": float,
+
+    "PSP": bool,
+    "PSP_TIME_WINDOW": int,
+    "PSP_THRESHOLD": float,
+
+    # "REJECT_PAIRS": bool,
+    # "FORCE_DROP_ANNOTATIONS": list,
+    # "FILTER_LOW_PASS": float,
+    # "FILTER_HIGH_PASS": float,
+    # "EPOCH_PAIR_TOLERANCE_WINDOW": int,
+}
+
+
+class ProcessingError(Exception):
+    def __init__(self, message: str = "Something went wrong!"):
+        self.message = message
+        super().__init__(self.message)
+
+
+# Ensure that we are working in the directory of this file
+script_dir = os.path.dirname(os.path.abspath(__file__))
+os.chdir(script_dir)
+
+PLATFORM_NAME = platform.system().lower()
+
+# Configure logging to file with timestamps and realtime flush
+if PLATFORM_NAME == 'darwin':
+    logging.basicConfig(
+        filename=os.path.join(os.path.dirname(sys.executable), "../../../fnirs_analysis.log"),
+        level=logging.INFO,
+        format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S',
+        filemode='a'
+    )
+else:
+    logging.basicConfig(
+        filename='fnirs_analysis.log',
+        level=logging.INFO,
+        format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S',
+        filemode='a'
+    )
+
+logger = logging.getLogger()
+
+
+def set_config_me(config: dict[str, Any]) -> None:
+    """
+    Applies the given configuration dictionary to this module's globals.
+
+    Parameters
+    ----------
+    config : dict[str, Any]
+        Dictionary containing configuration keys and their values.
+    """
+    logger.info("[DEBUG] set_config_me called")
+
+    globals().update(config)
+
+
+def set_metadata(file_path, metadata: dict[str, Any]) -> None:
+    """
+    Resets the participant metadata globals to their defaults, then applies
+    any non-empty values recorded for the given file.
+
+    Parameters
+    ----------
+    file_path : str
+        Path of the SNIRF file whose metadata should be applied.
+    metadata : dict[str, Any]
+        Mapping from file paths to dictionaries with optional
+        "AGE", "GENDER", and "GROUP" entries.
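+
+    Examples
+    --------
+    A minimal sketch with a hypothetical path and values::
+
+        set_metadata("/data/p01.snirf", {
+            "/data/p01.snirf": {"AGE": 31, "GENDER": "F", "GROUP": "Control"},
+        })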
+ """ + logger.info(f"[DEBUG] set_metadata called") + + globals()['AGE'] = 25 + globals()['GENDER'] = "" + globals()['GROUP'] = "Default" + + if metadata.get(file_path) is not None: + + file_metadata = metadata.get(file_path, {}) + + for key in ("AGE", "GENDER", "GROUP"): + val = file_metadata.get(key, None) + if val not in (None, '', [], {}, ()): # check for "empty" values + globals()[key] = val + + +def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -> None: + try: + # Start a thread to forward progress messages back to GUI + def forward_progress(): + while True: + try: + msg = progress_queue.get(timeout=1) + if msg == "__done__": + break + gui_queue.put(msg) + except: + continue + + t = threading.Thread(target=forward_progress, daemon=True) + t.start() + file_paths = config['SNIRF_FILES'] + file_params = config['PARAMS'] + file_metadata = config['METADATA'] + + max_workers = file_params.get("MAX_WORKERS", int(os.cpu_count()/4)) + + # Run the actual processing, with progress_queue passed down + print("actual call") + results = process_multiple_participants(file_paths, file_params, file_metadata, progress_queue, max_workers) + + # Signal end of progress + progress_queue.put("__done__") + t.join() + + gui_queue.put({"success": True, "result": results}) + + + except Exception as e: + gui_queue.put({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + + +def process_participant_worker(args): + file_path, file_params, file_metadata, progress_queue = args + + set_config_me(file_params) + set_metadata(file_path, file_metadata) + logger.info(f"DEBUG: Metadata for {file_path}: AGE={globals().get('AGE')}, GENDER={globals().get('GENDER')}, GROUP={globals().get('GROUP')}") + + def progress_callback(step_idx): + if progress_queue: + progress_queue.put(('progress', file_path, step_idx)) + + try: + result = process_participant(file_path, progress_callback=progress_callback) + return file_path, result, None + except Exception as e: + error_trace = traceback.format_exc() + return file_path, None, (str(e), error_trace) + + +def process_multiple_participants(file_paths, file_params, file_metadata, progress_queue=None, max_workers=None): + results_by_file = {} + + file_args = [(file_path, file_params, file_metadata, progress_queue) for file_path in file_paths] + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(process_participant_worker, arg): arg[0] for arg in file_args} + + for future in as_completed(futures): + file_path = futures[future] + try: + file_path, result, error = future.result() + if error: + print(f"Error processing {file_path}: {error[0]}") + print(error[1]) + continue + results_by_file[file_path] = result + except Exception as e: + print(f"Unexpected error processing {file_path}: {e}") + + return results_by_file + + + + +def markbad(data, ax, ch_names: list[str]) -> None: + """ + Add a strikethrough to a plot for channels marked as bad. + + Parameters + ---------- + data : BaseRaw + The loaded data object to process. + ax : Axes + Matplotlib Axes object where the strikethrough lines will be drawn. + ch_names : list[str] + List of channel names corresponding to the y-axis of the plot. 
+ """ + + # Iterate over all the channels + for i, ch in enumerate(ch_names): + + # If it is marked as bad, place a strikethrough on the channel + if ch in data.info["bads"]: + ax.axhline(i + 0.5, ls="solid", lw=4, color="black", zorder=10) # type: ignore + + + +def plot_timechannel_quality_metrics(data, scores, times: list[tuple[float]], color_stops: tuple[list[float], list[float]], threshold: float, title: Optional[str] = None): + + """ + Generate two heatmaps visualizing channel quality metrics over time. + + Parameters + ---------- + data : BaseRaw + The loaded data object to process. + scores : NDArray[float64] + A 2D array of quality scores for each channel over time. + times : list[tuple[float]] + List of time boundaries used to label each score column. + color_stops : tuple[list[float], list[float]] + Two lists of color values for custom colormaps. + threshold : float, + Threshold value for the color bar. + title : Optional[str], optional + Base title for the figures, (default is None). + + Returns + ------- + tuple[Figure, Figure] + - Figure: Heatmap of all scores across channels and time. + - Figure: Binary heatmap showing only scores above the threshold. + """ + + # Get only the hbo / hbr channels once as we dont need to see the same results twice + half_ch = len(getattr(data, "ch_names")) // 2 + ch_names = getattr(data, "ch_names")[:half_ch] + scores = scores[:half_ch, :] + + # Extract rounded time points to use as column headers + cols = [np.round(t[0]) for t in times] + n_chans = len(ch_names) + vsize = 0.2 * n_chans + + # Create the first figure + fig1, ax1 = plt.subplots(figsize=(10, vsize), layout="constrained") # type: ignore + fig1.suptitle(title + " - All Scores", fontsize=16, fontweight="bold") # type: ignore + + # Create a DataFrame to structure data for the heatmap + data_to_plot = DataFrame( + data=scores, + columns=pd.Index(cols, name="Time (s)"), + index=pd.Index(ch_names, name="Channel"), + ) + + # Define a custom colormap using provided color stops and base colors + base_colors = ['red', 'red', 'yellow', 'green', 'green'] + colors = list(zip(color_stops[0], base_colors[:len(color_stops[0])])) + cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors) + + # Plot heatmap of scores + sns.heatmap( # type: ignore + data=data_to_plot, + cmap=cmap, + vmin=0, + vmax=1, + cbar_kws=dict(label="Score"), + ax=ax1, + ) + + # Add vertical dashed lines at each time boundary, sit the title, and place a black strikethrough through a bad channel + for x in range(1, len(times)): + ax1.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") # type: ignore + ax1.set_title("All Scores", fontweight="bold") # type: ignore + markbad(data, ax1, ch_names) + + # Calculate average score per channel and annotate to the right of the heatmap + avg_sci_subset: pd.Series[float] = data_to_plot.mean(axis=1) # type: ignore + norm = mcolors.Normalize(vmin=0, vmax=1) + text_x = data_to_plot.shape[1] + 0.5 + for i, val in enumerate(avg_sci_subset): + color = cmap(norm(val)) + ax1.text( # type: ignore + text_x, + i + 0.5, + f"{val:.3f}", + va='center', + ha='left', + fontsize=9, + color=color + ) + ax1.set_xlim(right=text_x + 1.5) + + plt.close(fig1) + + # Create the second figure + fig2, ax2 = plt.subplots(figsize=(10, vsize), layout="constrained") # type: ignore + fig2.suptitle(title + " - Scores Above Threshold", fontsize=16, fontweight="bold") # type: ignore + + # Create a DataFrame to structure data for the heatmap + data_to_plot = DataFrame( + data=scores > threshold, + 
columns=pd.Index(cols, name="Time (s)"), + index=pd.Index(ch_names, name="Channel"), + ) + + # Define a custom colormap using provided color stops and base colors + base_colors = ['red', 'red', 'white', 'white'] + colors = list(zip(color_stops[1], base_colors[:len(color_stops[1])])) + cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors) + + # Plot heatmap of scores + sns.heatmap( # type: ignore + data=data_to_plot, + vmin=0, + vmax=1, + cmap=cmap, + cbar_kws=dict(label="Score"), + ax=ax2, + ) + + # Add vertical dashed lines at each time boundary, sit the title, and place a black strikethrough through a bad channel + for x in range(1, len(times)): + ax2.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") # type: ignore + ax2.set_title("Scores > Threshold", fontweight="bold") # type: ignore + markbad(data, ax2, ch_names) + + plt.close(fig2) + + return fig1, fig2 + + +def scalp_coupling_index_windowed_raw(data, time_window: float = 3.0, l_freq: float = 0.7, h_freq: float = 1.5, l_trans_bandwidth: float = 0.3, h_trans_bandwidth: float = 0.3): + """ + Compute windowed scalp coupling index (SCI) across fNIRS channels. + + Parameters + ---------- + data : BaseRaw + The loaded data object to process. + time_window : float, optional + Length of each time window in seconds (default is 3.0). + l_freq : float, optional + Low cutoff frequency for filtering in Hz (default is 0.7). + h_freq : float, optional + High cutoff frequency for filtering in Hz (default is 1.5). + l_trans_bandwidth : float, optional + Transition bandwidth for the low cutoff in Hz (default is 0.3). + h_trans_bandwidth : float, optional + Transition bandwidth for the high cutoff in Hz (default is 0.3). + + Returns + ------- + tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]] + - BaseRaw: The original data object (unchanged). Ensures compatibility with peak_power(). + - NDArray[float64]: Correlation scores for each channel and time window. + - list[tuple[float, float]]: Time intervals for each window in seconds. + """ + + # Pick only fNIRS channels and sort them by channel name + picks: NDArray[np.intp] = pick_types(cast(Info, data.info), fnirs=True) # type: ignore + picks = picks[np.argsort([getattr(data, "ch_names")[pick] for pick in picks])] + + # FIXME: This may happen if the heart rate calculation tries to set a value way too low + if l_freq < 0.3: + l_freq = 0.3 + + # Band-pass filter the selected fNIRS channels + filtered_data = filter_data( + getattr(data, "_data"), + getattr(data, "info")["sfreq"], + l_freq, + h_freq, + picks=picks, + verbose=False, + l_trans_bandwidth=l_trans_bandwidth, # type: ignore + h_trans_bandwidth=h_trans_bandwidth, # type: ignore + ) + + # Calculate number of samples per time window, the total number of windows, and prepare output variables + window_samples = int(np.ceil(time_window * getattr(data, "info")["sfreq"])) + n_windows = int(np.floor(len(data) / window_samples)) + scores = np.zeros((len(picks), n_windows)) + times: list[tuple[float, float]] = [] + + # Slide through the data in windows to compute scalp coupling index (SCI) + for window in range(n_windows): + start_sample = int(window * window_samples) + end_sample = start_sample + window_samples + end_sample = np.min([end_sample, len(data) - 1]) + + # Track time boundaries for each window + t_start = getattr(data, "times")[start_sample] + t_stop = getattr(data, "times")[end_sample] + times.append((t_start, t_stop)) + + # Iterate through channels in pairs (hbo, hbr). 
This requires them to be sorted by channel name + for ii in range(0, len(picks), 2): + c1 = filtered_data[picks[ii]][start_sample:end_sample] + c2 = filtered_data[picks[ii + 1]][start_sample:end_sample] + + # Ensure the correlation data is valid + if np.std(c1) == 0 or np.std(c2) == 0 or np.any(np.isnan(c1)) or np.any(np.isnan(c2)): + c = 0 + else: + c = np.corrcoef(c1, c2)[0][1] + + # Assign the computed correlation to both channels in the pair + scores[ii, window] = c + scores[ii + 1, window] = c + + scores = scores[np.argsort(picks)] + + return data, scores, times + +def calculate_scalp_coupling(data, l_freq: float = 0.7, h_freq: float = 1.5): + """ + Calculate the scalp coupling index (SCI) and identify bad channels based on a threshold. + + Parameters + ---------- + data : BaseRaw + The loaded data object to process. + l_freq : float, optional + Low cutoff frequency for bandpass filtering in Hz (default is 0.7). + h_freq : float, optional + High cutoff frequency for bandpass filtering in Hz (default is 1.5) + + Returns + ------- + tuple[list[str], Figure, Figure] + - list[str]: Channel names identified as bad based on SCI threshold. + - Figure: Heatmap of all SCI scores across time and channels. + - Figure: Binary heatmap of SCI scores exceeding the threshold. + """ + + print("Calculating scalp coupling index...") + + # Compute the SCI + _, scores, times = scalp_coupling_index_windowed_raw(data, time_window=SCI_TIME_WINDOW, l_freq=l_freq, h_freq=h_freq) + + # Identify channels that don't meet the provided threshold + print("Identifying channels that do not meet the threshold...") + sci = scores.mean(axis=1) + data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD)) + + # Determine the colors based on the threshold, and create the figures + print("Creating the figures...") + color_stops = ([0.0, SCI_THRESHOLD, SCI_THRESHOLD+0.1, 0.8, 1.0], [0.0, SCI_THRESHOLD, SCI_THRESHOLD, 1.0]) + fig1, fig2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, SCI_THRESHOLD, "Scalp Coupling Index") + + print("Successfully calculated scalp coupling index.") + + return list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD)), fig1, fig2 + + + +def calculate_signal_noise_ratio(data): + """ + Calculates the signal-to-noise ratio (SNR) for each channel and identifies those below a defined threshold. + + Parameters + ---------- + data : BaseRaw + The loaded data object to process. + + Returns + ------- + tuple[list[str], Figure] + - list[str]: A list of channel names that fall below the SNR threshold and are considered bad. + - Figure: A matplotlib Figure showing the channels' SNR values. 
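+
+    Examples
+    --------
+    A minimal sketch ("participant.snirf" is a placeholder; assumes
+    SNR_THRESHOLD was set via set_config_me and the recording is sampled
+    fast enough for the 1-10 Hz noise band)::
+
+        raw = read_raw_snirf("participant.snirf")
+        bad_snr, snr_fig = calculate_signal_noise_ratio(raw)
+        print(f"{len(bad_snr)} channels fell below the SNR threshold")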
+ """ + + print("Calculating signal to noise ratio...") + + # Compute the signal-to-noise ratio values + print("Computing the signal to noise power...") + signal_band=(0.01, 0.5) + noise_band=(1.0, 10.0) + data_signal = data.copy().filter(*signal_band, verbose=False) #type: ignore + data_noise = data.copy().filter(*noise_band, verbose=False) #type: ignore + signal_power = np.mean(data_signal.get_data()**2, axis=1) #type: ignore + noise_power = np.mean(data_noise.get_data()**2, axis=1) #type: ignore + + # Calculate the snr using the standard formula for dB + snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps)) + + # TODO: Understand what this does + groups: dict[str, list[str]] = {} + for ch in getattr(data, "ch_names"): + # Look for the space in the channel names and remove the characters after + # This is so we can get both oxy and deoxy to remove, as they will have the same source and detector + base = ch.rsplit(' ', 1)[0] + groups.setdefault(base, []).append(ch) # type: ignore + + # If any of the channels do not meet our threshold, they will get inserted into the bad_channels set + bad_channels: set[str] = set() + for base, ch_list in groups.items(): + if any(s < SNR_THRESHOLD for s, ch in zip(snr, getattr(data, "ch_names")) if ch in ch_list): + bad_channels.update(ch_list) + + # Design and create the figure + print("Creating the figure...") + snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained") # type: ignore + colors = [(0/20, 'red'), (SNR_THRESHOLD/20, 'red'), ((SNR_THRESHOLD+.5)/20, 'yellow'), ((SNR_THRESHOLD+1)/20, 'green'), (20/20, 'green')] + cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors) + norm = mcolors.Normalize(vmin=0, vmax=20) + scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm) # type: ignore + ax.set(xlabel="Channel Number", ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, 20]) + ax.axhline(SNR_THRESHOLD, color='black', linestyle='--', alpha=0.3, linewidth=1) # type: ignore + cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)") # type: ignore + cbar.set_ticks([0, SNR_THRESHOLD, SNR_THRESHOLD+1, 20]) # type: ignore + cbar.set_ticklabels(['0', str(SNR_THRESHOLD), str(SNR_THRESHOLD+1), '20']) # type: ignore + + plt.close() + + print("Successfully calculated signal to noise ratio.") + + return list(bad_channels), snr_fig + + + +def build_fnirs_adjacency(raw, threshold_meters=0.03): + """Build an adjacency dictionary for fNIRS channels using 3D distance.""" + # Extract channel positions + ch_locs = [] + ch_names = [] + + for ch in raw.info['chs']: + loc = ch['loc'][:3] # Get x, y, z coordinates + if not np.isnan(loc).any(): + ch_locs.append(loc) + ch_names.append(ch['ch_name']) + + ch_locs = np.array(ch_locs) + + # Compute pairwise distances + dists = cdist(ch_locs, ch_locs) + + # Build adjacency dictionary + adjacency = {} + for i, ch_name in enumerate(ch_names): + neighbors = [ch_names[j] for j in range(len(ch_names)) + if 0 < dists[i, j] < threshold_meters] + adjacency[ch_name] = neighbors + + return adjacency + + + +def get_hbo_hbr_picks(raw): + # Pick all fNIRS channels + fnirs_picks = pick_types(raw.info, fnirs=True, exclude=[]) + + # Extract wavelengths from channel names (expecting something like 'S6_D4 763' or 'S6_D4 841') + wavelengths = [] + for idx in fnirs_picks: + ch_name = raw.ch_names[idx] + # Extract last 3 digits from channel name using regex + match = re.search(r'(\d{3})$', ch_name) + if match: + wavelengths.append(int(match.group(1))) 
+ else: + raise ValueError(f"Channel name '{ch_name}' does not end with 3 digits.") + + wavelengths = np.array(wavelengths) + unique_wavelengths = np.unique(wavelengths) + if len(unique_wavelengths) != 2: + raise RuntimeError(f"Expected exactly 2 distinct wavelengths, found {unique_wavelengths}") + + # Determine which is HbO (larger) and which is HbR (smaller) + hbr_wl = unique_wavelengths.min() + hbo_wl = unique_wavelengths.max() + + print(f"HbR wavelength: {hbr_wl}, HbO wavelength: {hbo_wl}") + + # Find picks corresponding to each wavelength + hbr_picks = [fnirs_picks[i] for i, wl in enumerate(wavelengths) if wl == hbr_wl] + hbo_picks = [fnirs_picks[i] for i, wl in enumerate(wavelengths) if wl == hbo_wl] + + print(f"Found {len(hbr_picks)} HbR channels and {len(hbo_picks)} HbO channels.") + + return hbo_picks, hbr_picks, hbo_wl, hbr_wl + + +def interpolate_fNIRS_bads_weighted_average(raw, bad_channels, max_dist=0.03, min_neighbors=2): + """ + Interpolate bad fNIRS channels using a distance-weighted average of nearby good channels. + + Parameters + ---------- + raw : mne.io.Raw + The raw fNIRS data with bads marked in raw.info['bads']. + max_dist : float + Maximum distance (in meters) to consider for neighboring good channels. + min_neighbors : int + Minimum number of neighbors required to interpolate a bad channel. + + Returns + ------- + raw : mne.io.Raw + Modified raw object with bads interpolated (in-place). + """ + + print("Finding fNIRS channels...") + hbo_picks, hbr_picks, hbo_wl, hbr_wl = get_hbo_hbr_picks(raw) + + + if len(hbo_picks) != len(hbr_picks): + raise RuntimeError("Number of HbO and HbR channels must be the same.") + + # Base names without wavelength for pairing + def base_name(ch_name): + # Strip last 4 chars assuming format ' ' + # e.g. "S6_D6 841" -> "S6_D6" + return ch_name[:-4] + + hbo_names = [base_name(raw.ch_names[i]) for i in hbo_picks] + hbr_names = [base_name(raw.ch_names[i]) for i in hbr_picks] + + # Sanity check: pairs must match + for i in range(len(hbo_names)): + if hbo_names[i] != hbr_names[i]: + raise RuntimeError(f"Channel pairs do not match: {hbo_names[i]} vs {hbr_names[i]}") + + # Identify bad pairs if either channel in pair is bad + bad_pairs = [] + good_pairs = [] + for i, base in enumerate(hbo_names): + hbo_ch = raw.ch_names[hbo_picks[i]] + hbr_ch = raw.ch_names[hbr_picks[i]] + if (hbo_ch in raw.info['bads']) or (hbr_ch in raw.info['bads']): + bad_pairs.append(i) + else: + good_pairs.append(i) + + print(f"Total pairs: {len(hbo_names)}") + print(f"Good pairs: {len(good_pairs)}") + print(f"Bad pairs to interpolate: {len(bad_pairs)}") + + if len(bad_pairs) == 0: + print("No bad pairs found. 
Skipping interpolation.") + return raw + + # Extract locations (use HbO channel loc as pair location) + locs = np.array([raw.info['chs'][hbo_picks[i]]['loc'][:3] for i in range(len(hbo_names))]) + good_locs = locs[good_pairs] + bad_locs = locs[bad_pairs] + + # Compute distance matrix between bad and good pairs + dist_matrix = cdist(bad_locs, good_locs) + + interpolated_pairs = [] + + for i, bad_idx in enumerate(bad_pairs): + bad_base = hbo_names[bad_idx] + distances = dist_matrix[i] + close_idxs = np.where(distances < max_dist)[0] + + print(f"\nInterpolating pair {bad_base} (index {bad_idx})") + print(f" Nearby good pairs found: {len(close_idxs)}") + + if len(close_idxs) < min_neighbors: + print(f" Skipping {bad_base}: not enough neighbors (found {len(close_idxs)} < {min_neighbors})") + continue + + weights = 1 / (distances[close_idxs] + 1e-6) + weights /= weights.sum() + + neighbor_hbo_indices = [hbo_picks[good_pairs[idx]] for idx in close_idxs] + neighbor_hbr_indices = [hbr_picks[good_pairs[idx]] for idx in close_idxs] + + neighbor_hbo_data = raw._data[neighbor_hbo_indices, :] + neighbor_hbr_data = raw._data[neighbor_hbr_indices, :] + + interpolated_hbo = np.average(neighbor_hbo_data, axis=0, weights=weights) + interpolated_hbr = np.average(neighbor_hbr_data, axis=0, weights=weights) + + raw._data[hbo_picks[bad_idx]] = interpolated_hbo + raw._data[hbr_picks[bad_idx]] = interpolated_hbr + + interpolated_pairs.append(bad_base) + + if interpolated_pairs: + bad_ch_to_remove = [] + for base_ in interpolated_pairs: + bad_ch_to_remove.append(base_ + f" {hbr_wl}") # HbR + bad_ch_to_remove.append(base_ + f" {hbo_wl}") # HbO + + raw.info['bads'] = [ch for ch in raw.info['bads'] if ch not in bad_ch_to_remove] + + print("\nInterpolation complete.\n") + + for ch in raw.info['bads']: + print(f"Channel {ch} still marked as bad.") + + print("Bads cleared:", raw.info['bads']) + fig_raw_after = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After interpolation", show=False) + + return raw, fig_raw_after + + + +def calculate_signal_noise_ratio(data): + """ + Calculates the signal-to-noise ratio (SNR) for each channel and identifies those below a defined threshold. + + Parameters + ---------- + data : BaseRaw + The loaded data object to process. + + Returns + ------- + tuple[list[str], Figure] + - list[str]: A list of channel names that fall below the SNR threshold and are considered bad. + - Figure: A matplotlib Figure showing the channels' SNR values. 
+ """ + + print("Calculating signal to noise ratio...") + + # Compute the signal-to-noise ratio values + print("Computing the signal to noise power...") + signal_band=(0.01, 0.5) + noise_band=(1.0, 10.0) + data_signal = data.copy().filter(*signal_band, verbose=False) #type: ignore + data_noise = data.copy().filter(*noise_band, verbose=False) #type: ignore + signal_power = np.mean(data_signal.get_data()**2, axis=1) #type: ignore + noise_power = np.mean(data_noise.get_data()**2, axis=1) #type: ignore + + # Calculate the snr using the standard formula for dB + snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps)) + + # TODO: Understand what this does + groups: dict[str, list[str]] = {} + for ch in getattr(data, "ch_names"): + # Look for the space in the channel names and remove the characters after + # This is so we can get both oxy and deoxy to remove, as they will have the same source and detector + base = ch.rsplit(' ', 1)[0] + groups.setdefault(base, []).append(ch) # type: ignore + + # If any of the channels do not meet our threshold, they will get inserted into the bad_channels set + bad_channels: set[str] = set() + for base, ch_list in groups.items(): + if any(s < SNR_THRESHOLD for s, ch in zip(snr, getattr(data, "ch_names")) if ch in ch_list): + bad_channels.update(ch_list) + + # Design and create the figure + print("Creating the figure...") + snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained") # type: ignore + colors = [(0/20, 'red'), (SNR_THRESHOLD/20, 'red'), ((SNR_THRESHOLD+.5)/20, 'yellow'), ((SNR_THRESHOLD+1)/20, 'green'), (20/20, 'green')] + cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors) + norm = mcolors.Normalize(vmin=0, vmax=20) + scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm) # type: ignore + ax.set(xlabel="Channel Number", ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, 20]) + ax.axhline(SNR_THRESHOLD, color='black', linestyle='--', alpha=0.3, linewidth=1) # type: ignore + cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)") # type: ignore + cbar.set_ticks([0, SNR_THRESHOLD, SNR_THRESHOLD+1, 20]) # type: ignore + cbar.set_ticklabels(['0', str(SNR_THRESHOLD), str(SNR_THRESHOLD+1), '20']) # type: ignore + + plt.close() + + print("Successfully calculated signal to noise ratio.") + + return list(bad_channels), snr_fig + + + + +def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]: + """ + Calculate peak spectral power (PSP) for fNIRS channels and identify bad channels. + + Parameters + ---------- + data : BaseRaw + The loaded data object to process. + l_freq : float, optional + Low cutoff frequency for filtering in Hz (default is 0.7) + h_freq : float, optional + High cutoff frequency for filtering in Hz (default is 1.5) + + Returns + ------- + tuple[list[str], Figure, Figure] + - list[str]: Names of channels below the PSP threshold. + - Figure: Heatmap of all PSP scores. + - Figure: Heatmap of scores above the PSP threshold. 
+ """ + + + # Compute the PSP + _, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], peak_power(data, time_window=PSP_TIME_WINDOW, threshold=PSP_THRESHOLD, l_freq=l_freq, h_freq=h_freq)) + + # Identify channels that don't meet the provided threshold + psp = scores.mean(axis=1) + data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)) + + # Determine the colors based on the threshold, and create the figures + color_stops = ([0.0, PSP_THRESHOLD, PSP_THRESHOLD+0.1, PSP_THRESHOLD+0.2, 1.0], [0.0, PSP_THRESHOLD, PSP_THRESHOLD, 1.0]) + psp1, psp2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, PSP_THRESHOLD, "Peak Spectral Power") + + + return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2 + + +def mark_bads(raw, bad_sci, bad_snr, bad_psp): + bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp)) + print(f"Automatically marked bad channels based on SNR and SCI: {bads_combined}") + + raw.info['bads'].extend(bads_combined) + + # Organize channels into categories + sets = [ + (bad_sci, "SCI"), + (bad_psp, "PSP"), + (bad_snr, "SNR"), + ] + + # Graph what channels were dropped and why they were dropped + channel_categories: dict[str, str] = {} + + for ch in bads_combined: + present_in = [name for s, name in sets if ch in s] + # Create a label for the category + if len(present_in) == 1: + label = f"{present_in[0]} only" + else: + label = " + ".join(sorted(present_in)) + channel_categories[ch] = label + + # Sort channels alphabetically within categories for nicer visualization + categories = sorted(set(channel_categories.values())) + channel_names: list[str] = [] + category_labels: list[str] = [] + for cat in categories: + chs_in_cat = sorted([ch for ch, c in channel_categories.items() if c == cat]) + channel_names.extend(chs_in_cat) + category_labels.extend([cat] * len(chs_in_cat)) + + colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories} + + # Create the figure + fig_dropped, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3))) # type: ignore + y_pos = range(len(channel_names)) + ax.barh(y_pos, [1]*len(channel_names), color=[colors[cat] for cat in category_labels]) # type: ignore + ax.set_yticks(y_pos) # type: ignore + ax.set_yticklabels(channel_names) # type: ignore + ax.set_xlabel("Marked as Bad") # type: ignore + ax.set_title(f"Bad Channels by Method for") # type: ignore + ax.set_xlim(0, 1) + ax.set_xticks([]) # type: ignore + ax.grid(axis='x', linestyle='--', alpha=0.7) # type: ignore + + # Add a legend denoting why the channels were bad + for label, color in colors.items(): + ax.bar(0, 0, color=color, label=label) # type: ignore + ax.legend() # type: ignore + + fig_dropped.tight_layout() + + + raw_before = deepcopy(raw) + bads_channels = [ch for ch in raw.ch_names if ch in raw.info['bads']] + print(bads_channels) + if bads_channels: + fig_raw_before = raw_before.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], picks=bads_channels, title="What they were BEFORE", show=False) + else: + fig_dropped = None + fig_raw_before = None + + return raw, fig_dropped, fig_raw_before, bads_channels + + + + +def filter_the_data(raw_haemo): + # --- STEP 5: Filtering (0.01–0.2 Hz bandpass) --- + fig_filter = raw_haemo.compute_psd(fmax=2).plot( + average=True, xscale="log", color="r", show=False, amplitude=False + ) + + raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.4, h_trans_bandwidth=0.2) + + 
raw_haemo.compute_psd(fmax=2).plot(
+        average=True, xscale="log", axes=fig_filter.axes, color="g", amplitude=False, show=False
+    )
+
+    fig_raw_haemo_filter = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Filtered HbO and HbR", show=False)
+
+    return fig_filter, fig_raw_haemo_filter
+
+
+def epochs_calculations(raw_haemo, events, event_dict):
+    fig_epochs = []  # List to store figures
+
+    # Create epochs from raw data (Epochs is assumed to be imported from mne)
+    epochs = Epochs(raw_haemo,
+                    events,
+                    event_id=event_dict,
+                    tmin=-5,
+                    tmax=15,
+                    baseline=(None, 0))
+
+    # Make a copy of the epochs and drop bad ones; epochs2 is used further
+    # below to build the per-condition evoked averages
+    epochs2 = epochs.copy()
+    epochs2.drop_bad()
+
+    # Plot drop log
+    fig_epochs_dropped = epochs2.plot_drop_log()
+    fig_epochs.append(("fig_epochs_dropped", fig_epochs_dropped))
+
+    # Plot for each condition
+    for idx, condition in enumerate(epochs.event_id.keys()):
+        # Plot images for each condition
+        fig_epochs_data = epochs[condition].plot_image(
+            combine="mean",
+            vmin=-1,
+            vmax=1,
+            ts_args=dict(ylim=dict(hbo=[-1, 1], hbr=[-1, 1])),
+            show=False
+        )
+        for i in fig_epochs_data:
+            # plot_image returns one figure per channel type; title each one
+            ax = i.axes[0]
+            original_title = ax.get_title()
+            ax.set_title(f"{condition}: {original_title}")
+            fig_epochs.append((f"fig_{condition}_data_{idx}", i))  # Store with a unique name
+
+        # Evoked average figure for each condition
+        evoked_avg = epochs[condition].average()
+        clims = dict(hbo=[-1, 1], hbr=[1, -1])
+        condition_fig = evoked_avg.plot_image(clim=clims, show=False)
+
+        for ax in condition_fig.axes:
+            original_title = ax.get_title()
+            ax.set_title(f"{original_title} - {condition}")
+        fig_epochs.append((f"evoked_avg_{condition}", condition_fig))  # Store with a unique name
+
+    # Prepare evokeds and colors for topographic plot
+    evokeds3 = []
+    colors = []
+    conditions = list(epochs.event_id.keys())
+    cmap = plt.cm.get_cmap("tab10", len(conditions))
+
+    for idx, cond in enumerate(conditions):
+        evoked = epochs[cond].average(picks="hbo")
+        evokeds3.append(evoked)
+        colors.append(cmap(idx))
+
+    # Create the topographic plot (plot_evoked_topo from mne.viz)
+    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
+    fig_topo = plot_evoked_topo(evokeds3, color=colors, axes=axes, legend=False, show=False)
+
+    # Build custom legend
+    lines = []
+    for color in colors:
+        line = plt.Line2D([0], [0], color=color, lw=2)
+        lines.append(line)
+
+    fig.legend(lines, conditions, loc="lower right")
+    fig_epochs.append(("evoked_topo", fig_topo))  # Store with a unique name
+
+    # Evoked response for specific condition ("Reach")
+    evoked_stim1 = epochs['Reach'].average()
+
+    fig_evoked_hbo = evoked_stim1.copy().pick(picks='hbo').plot(time_unit='s', show=False)
+    fig_evoked_hbr = evoked_stim1.copy().pick(picks='hbr').plot(time_unit='s', show=False)
+
+    fig_epochs.append(("fig_evoked_hbo", fig_evoked_hbo))  # Store with a unique name
+    fig_epochs.append(("fig_evoked_hbr", fig_evoked_hbr))  # Store with a unique name
+
+    print("Evoked HbO peak amplitude:", evoked_stim1.copy().pick(picks='hbo').data.max())
+
+    evokeds = {}
+    for condition in epochs2.event_id:
+        evokeds[condition] = epochs2[condition].average()
+        print(f"Condition '{condition}': {len(epochs2[condition])} epochs averaged.")
+
+    all_evokeds = {}
+    for condition in epochs.event_id:
+        if condition not in all_evokeds:
+            all_evokeds[condition] = []
+        all_evokeds[condition].append(epochs[condition].average())
+
+    group_aucs = {}
+    # TODO: group averages with a single person?
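+    # all_evokeds maps each condition to the Evoked averages collected in this
+    # run; with a single participant each list holds just that participant's
+    # own average, so the "grand average" reduces to it.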
+    group_averages = {cond: grand_average(evokeds) for cond, evokeds in all_evokeds.items()}
+    for condition, evoked in group_averages.items():
+        group_aucs[condition] = {}
+        for pick in ["hbo", "hbr"]:
+            picks_idx = [i for i, ch in enumerate(evoked.ch_names) if pick in ch]
+            if not picks_idx:
+                continue
+
+            data = evoked.data[picks_idx, :].mean(axis=0)
+            t_start, t_end = 0, 15
+            times_mask = (evoked.times >= t_start) & (evoked.times <= t_end)
+            data_segment = data[times_mask]
+            times_segment = evoked.times[times_mask]
+
+            # Area under the curve over the 0-15 s window (np.trapezoid is the
+            # NumPy 2.x name for np.trapz)
+            auc = np.trapezoid(data_segment, times_segment)
+            group_aucs[condition][pick] = auc
+
+    # Final evoked comparison plot for each condition
+    for condition in conditions:
+        if condition not in evokeds:
+            continue
+        evoked = evokeds[condition]
+
+        fig, ax = plt.subplots(figsize=(6, 5))
+        legend_labels = ["Oxyhaemoglobin"]
+
+        for pick, color in zip(["hbo", "hbr"], ["r", "b"]):
+            plot_compare_evokeds(
+                evoked,
+                combine="mean",
+                picks=pick,
+                axes=ax,
+                show=False,
+                colors=[color],
+                legend=False,
+                title=f"Participant\nCondition: {condition}",
+                ylim=dict(hbo=[-0.5, 1], hbr=[-0.5, 1]),
+                show_sensors=False,
+            )
+            auc_value = group_aucs.get(condition, {}).get(pick, None)
+            if auc_value is not None:
+                label = f"{pick.upper()} AUC: {auc_value * 1e6:.4f} µM·s"
+            else:
+                label = f"{pick.upper()} AUC: N/A"
+            legend_labels.append(label)
+            if len(legend_labels) == 2:
+                legend_labels.append("Deoxyhaemoglobin")
+
+        ax.legend(legend_labels)
+
+        fig_epochs.append((f"fig_{condition}_compare_evokeds", fig))  # Store with a unique name
+
+    return epochs, fig_epochs
+
+
+def make_design_matrix(raw_haemo, short_chans):
+
+    # Resample to 1 Hz and convert to micromolar for the GLM
+    raw_haemo.resample(1, npad="auto")
+    short_chans.resample(1)
+    raw_haemo._data = raw_haemo._data * 1e6
+
+    # Create the design matrix with FIR regressors and short-channel regressors
+    design_matrix = make_first_level_design_matrix(
+        raw=raw_haemo,
+        hrf_model='fir',
+        stim_dur=0.5,
+        fir_delays=range(15),
+        drift_model='cosine',
+        high_pass=0.01,
+        oversampling=1,
+        min_onset=-125,
+        add_regs=short_chans.get_data().T,
+        add_reg_names=short_chans.ch_names
+    )
+
+    print(design_matrix.head())
+    print(design_matrix.columns)
+
+    fig, ax1 = plt.subplots(figsize=(10, 6), constrained_layout=True)
+    _ = plot_design_matrix(design_matrix, axes=ax1)
+
+    return design_matrix, fig
+
+
+def generate_montage_locations():
+    """Get standard MNI montage locations in dataframe.
+
+    Data is returned in the same format as the eeg_positions library.
+    """
+    # standard_1020 and standard_1005 are in MNI (fsaverage) space already,
+    # but we need to undo the scaling that head_scale will do
+    montage = make_standard_montage(
+        "standard_1005", head_size=0.09700884729534559
+    )
+    for d in montage.dig:
+        d["coord_frame"] = 2003
+    montage.dig[:] = montage.dig[3:]
+    montage.add_mni_fiducials()  # now in fsaverage space
+    coords = pd.DataFrame.from_dict(montage.get_positions()["ch_pos"]).T
+    coords["label"] = coords.index
+    coords = coords.rename(columns={0: "x", 1: "y", 2: "z"})
+
+    return coords.reset_index(drop=True)
+
+
+def _find_closest_standard_location(position, reference, *, out="label"):
+    """Return closest montage label to coordinates.
+
+    Parameters
+    ----------
+    position : array, shape (3,)
+        Coordinates.
+    reference : dataframe
+        As generated by generate_montage_locations.
+    out : str
+        'label' returns the closest montage label; 'dists' returns the
+        full distance matrix instead.
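+
+    Returns
+    -------
+    str | ndarray
+        The closest label when out='label', otherwise the matrix of
+        distances between `position` and every reference location.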
+ """ + + p0 = np.array(position) + p0.shape = (-1, 3) + # head_mri_t, _ = _get_trans("fsaverage", "head", "mri") + # p0 = apply_trans(head_mri_t, p0) + dists = cdist(p0, np.asarray(reference[["x", "y", "z"]], float)) + + if out == "label": + min_idx = np.argmin(dists) + return reference["label"][min_idx] + else: + assert out == "dists" + return dists + + + +def _source_detector_fold_table(raw, cidx, reference, fold_tbl, interpolate): + src = raw.info["chs"][cidx]["loc"][3:6] + det = raw.info["chs"][cidx]["loc"][6:9] + + ref_lab = list(reference["label"]) + dists = _find_closest_standard_location([src, det], reference, out="dists") + src_min, det_min = np.argmin(dists, axis=1) + src_name, det_name = ref_lab[src_min], ref_lab[det_min] + + tbl = fold_tbl.query("Source == @src_name and Detector == @det_name") + dist = np.linalg.norm(dists[[0, 1], [src_min, det_min]]) + # Try reversing source and detector + if len(tbl) == 0: + tbl = fold_tbl.query("Source == @det_name and Detector == @src_name") + if len(tbl) == 0 and interpolate: + # Try something hopefully not too terrible: pick the one with the + # smallest net distance + good = np.isin(fold_tbl["Source"], reference["label"]) & np.isin( + fold_tbl["Detector"], reference["label"] + ) + assert good.any() + tbl = fold_tbl[good] + assert len(tbl) + src_idx = [ref_lab.index(src) for src in tbl["Source"]] + det_idx = [ref_lab.index(det) for det in tbl["Detector"]] + # Original + tot_dist = np.linalg.norm([dists[0, src_idx], dists[1, det_idx]], axis=0) + assert tot_dist.shape == (len(tbl),) + idx = np.argmin(tot_dist) + dist_1 = tot_dist[idx] + src_1, det_1 = ref_lab[src_idx[idx]], ref_lab[det_idx[idx]] + # And the reverse + tot_dist = np.linalg.norm([dists[0, det_idx], dists[1, src_idx]], axis=0) + idx = np.argmin(tot_dist) + dist_2 = tot_dist[idx] + src_2, det_2 = ref_lab[det_idx[idx]], ref_lab[src_idx[idx]] + if dist_1 < dist_2: + new_dist, src_use, det_use = dist_1, src_1, det_1 + else: + new_dist, src_use, det_use = dist_2, det_2, src_2 + + + tbl = fold_tbl.query("Source == @src_use and Detector == @det_use") + tbl = tbl.copy() + tbl["BestSource"] = src_name + tbl["BestDetector"] = det_name + tbl["BestMatchDistance"] = dist + tbl["MatchDistance"] = new_dist + assert len(tbl) + else: + tbl = tbl.copy() + tbl["BestSource"] = src_name + tbl["BestDetector"] = det_name + tbl["BestMatchDistance"] = dist + tbl["MatchDistance"] = dist + + tbl = tbl.copy() # don't get warnings about setting values later + return tbl + + + +def _read_fold_xls(fname, atlas="Juelich"): + """Read fOLD toolbox xls file. + + The values are then manipulated in to a tidy dataframe. + + Note the xls files are not included as no license is provided. + + Parameters + ---------- + fname : str + Path to xls file. + atlas : str + Requested atlas. 
+ """ + page_reference = {"AAL2": 2, "AICHA": 5, "Brodmann": 8, "Juelich": 11, "Loni": 14} + + tbl = pd.read_excel(fname, sheet_name=page_reference[atlas]) + + # Remove the spacing between rows + empty_rows = np.where(np.isnan(tbl["Specificity"]))[0] + tbl = tbl.drop(empty_rows).reset_index(drop=True) + + # Empty values in the table mean its the same as above + for row_idx in range(1, tbl.shape[0]): + for col_idx, col in enumerate(tbl.columns): + if not isinstance(tbl[col][row_idx], str): + if np.isnan(tbl[col][row_idx]): + tbl.iloc[row_idx, col_idx] = tbl.iloc[row_idx - 1, col_idx] + + tbl["Specificity"] = tbl["Specificity"] * 100 + tbl["brainSens"] = tbl["brainSens"] * 100 + return tbl + + +def _check_load_fold(fold_files, atlas): + # _validate_type(fold_files, (list, "path-like", None), "fold_files") + if fold_files is None: + fold_files = get_config("MNE_NIRS_FOLD_PATH") + if fold_files is None: + raise ValueError( + "MNE_NIRS_FOLD_PATH not set, either set it using " + "mne.set_config or pass fold_files as str or list" + ) + if not isinstance(fold_files, list): # path-like + fold_files = _check_fname( + fold_files, + overwrite="read", + must_exist=True, + name="fold_files", + need_dir=True, + ) + fold_files = [op.join(fold_files, f"10-{x}.xls") for x in (5, 10)] + + fold_tbl = pd.DataFrame() + for fi, fname in enumerate(fold_files): + fname = _check_fname( + fname, overwrite="read", must_exist=True, name=f"fold_files[{fi}]" + ) + fold_tbl = pd.concat( + [fold_tbl, _read_fold_xls(fname, atlas=atlas)], ignore_index=True + ) + return fold_tbl + + + +def fold_channel_specificity_normal(raw, fold_files=None, atlas="Juelich", interpolate=False): + """Return the landmarks and specificity a channel is sensitive to. + + Parameters + + """ # noqa: E501 + _validate_type(raw, BaseRaw, "raw") + + reference_locations = generate_montage_locations() + + fold_tbl = _check_load_fold(fold_files, atlas) + + chan_spec = list() + for cidx in range(len(raw.ch_names)): + tbl = _source_detector_fold_table( + raw, cidx, reference_locations, fold_tbl, interpolate + ) + chan_spec.append(tbl.reset_index(drop=True)) + + return chan_spec + + + +def resource_path(relative_path): + """ + Get absolute path to resource regardless of running directly or packaged using PyInstaller + """ + + if hasattr(sys, '_MEIPASS'): + # PyInstaller bundle path + base_path = sys._MEIPASS + else: + base_path = os.path.abspath(".") + + return os.path.join(base_path, relative_path) + + + +def fold_channels(raw: BaseRaw) -> None: + + + # Locate the fOLD excel files + set_config('MNE_NIRS_FOLD_PATH', resource_path("../../mne_data/fOLD/fOLD-public-master/Supplementary")) # type: ignore + + output = None + + # List to store the results + landmark_specificity_data: list[dict[str, Any]] = [] + + # Filter the data to only what we want + hbo_channel_names = cast(list[str], getattr(raw.copy().pick(picks='hbo'), "ch_names")) # type: ignore + + # Format the output to make it slightly easier to read + + if True: + + num_channels = len(hbo_channel_names) + rows, cols = 4, 7 # 6 rows and 4 columns of pie charts + fig, axes = plt.subplots(rows, cols, figsize=(16, 10), constrained_layout=True) + axes = axes.flatten() # Flatten the axes array for easier indexing + + # If more pie charts than subplots, create extra subplots + if num_channels > rows * cols: + fig, axes = plt.subplots((num_channels // cols) + 1, cols, figsize=(16, 10), constrained_layout=True) + axes = axes.flatten() + + # Create a list for consistent color mapping + landmarks = [ + "1 - 
Primary Somatosensory Cortex", + "2 - Primary Somatosensory Cortex", + "3 - Primary Somatosensory Cortex", + "4 - Primary Motor Cortex", + "5 - Somatosensory Association Cortex", + "6 - Pre-Motor and Supplementary Motor Cortex", + "7 - Somatosensory Association Cortex", + "8 - Includes Frontal eye fields", + "9 - Dorsolateral prefrontal cortex", + "10 - Frontopolar area", + "11 - Orbitofrontal area", + "17 - Primary Visual Cortex (V1)", + "18 - Visual Association Cortex (V2)", + "19 - V3", + "20 - Inferior Temporal gyrus", + "21 - Middle Temporal gyrus", + "22 - Superior Temporal Gyrus", + "23 - Ventral Posterior cingulate cortex", + "24 - Ventral Anterior cingulate cortex", + "25 - Subgenual cortex", + "32 - Dorsal anterior cingulate cortex", + "37 - Fusiform gyrus", + "38 - Temporopolar area", + "39 - Angular gyrus, part of Wernicke's area", + "40 - Supramarginal gyrus part of Wernicke's area", + "41 - Primary and Auditory Association Cortex", + "42 - Primary and Auditory Association Cortex", + "43 - Subcentral area", + "44 - pars opercularis, part of Broca's area", + "45 - pars triangularis Broca's area", + "46 - Dorsolateral prefrontal cortex", + "47 - Inferior prefrontal gyrus", + "48 - Retrosubicular area", + "Brain_Outside", + ] + + cmap1 = plt.cm.get_cmap('tab20') # First 20 colors + cmap2 = plt.cm.get_cmap('tab20b') # Next 20 colors + + # Combine the colors from both colormaps + colors = [cmap1(i) for i in range(20)] + [cmap2(i) for i in range(20)] # Total 40 colors + + landmarks.sort(key=lambda x: (int(x.split(" - ")[0]) if x.split(" - ")[0].isdigit() else float('inf'))) + + landmark_color_map = {landmark: colors[i % len(colors)] for i, landmark in enumerate(landmarks)} + + # Iterate over each channel + for idx, channel_name in enumerate(hbo_channel_names): + + # Run the fOLD on the selected channel + channel_data = raw.copy().pick(picks=channel_name) # type: ignore + + output = cast(list[DataFrame], fold_channel_specificity_normal(channel_data, interpolate=True, atlas='Brodmann')) + + # Process each DataFrame that fold_channel_specificity returns + for df_data in output: + + # Extract the relevant columns + useful_data = df_data[['Landmark', 'Specificity']] + + # Store the results + landmark_specificity_data.append({ + 'Channel': channel_name, + 'Data': useful_data, + }) + + + # Plot the results + # TODO: Fix this + if True: + unique_landmarks = sorted(useful_data['Landmark'].unique()) + color_list = [landmark_color_map[landmark] for landmark in useful_data['Landmark']] + + # Plot specificity for each channel + ax = axes[idx] + + labels = [f'{landmark.split(" - ")[0]}' if landmark != 'Brain_Outside' else 'B' for landmark in useful_data['Landmark']] + + wedges, texts, autotexts = ax.pie( + useful_data['Specificity'], + autopct='%1.1f%%', + startangle=90, + labels=labels, + labeldistance=1.05, + colors=color_list) + + ax.set_title(f'{channel_name}') + ax.axis('equal') + + landmark_specificity_data = [] + + # TODO: Fix this + if True: + handles = [ + plt.Line2D([0], [0], marker='o', color='w', label=landmark, markersize=10, + markerfacecolor=landmark_color_map[landmark]) + for landmark in landmarks + ] + n_landmarks = len(landmarks) + + # Calculate the figure size based on number of rows and columns + fig_width = 5 + fig_height = n_landmarks / 4 + + # Create a new figure window for the legend + legend_fig = plt.figure(figsize=(fig_width, fig_height)) + legend_axes = legend_fig.add_subplot(111) + legend_axes.axis('off') # Turn off axis for the legend window + 
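+        # Draw every landmark swatch in its own figure so the per-channel pie
+        # charts stay uncluttered; `handles` was built above from the fixed
+        # landmark-to-colour map.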
legend_axes.legend(handles=handles, loc='center', fontsize=10, title="Landmarks") + + for ax in axes[len(hbo_channel_names):]: + ax.axis('off') + + return fig, legend_fig + + + + +def individual_significance(raw_haemo, glm_est): + + # TODO: BAD! + cha = glm_est.to_dataframe() + + ch_summary = cha.query("Condition.str.startswith('Reach_delay_') and Chroma == 'hbo'", engine='python') + + print(ch_summary.head()) + + channel_averages = ch_summary.groupby('ch_name')['theta'].mean().reset_index() + print(channel_averages.head()) + + + reach_ch_summary = ch_summary.query( + "Chroma == 'hbo' and Condition.str.startswith('Reach_delay_')", engine='python' + ) + + # Function to correct p-values per channel + def fdr_correct_per_channel(df): + df = df.copy() + df['pval_fdr'] = multipletests(df['p_value'], method='fdr_bh')[1] + return df + + # Apply FDR correction grouped by channel + corrected = reach_ch_summary.groupby("ch_name", group_keys=False).apply(fdr_correct_per_channel) + + # Determine which channels are significant across any delay + sig_channels = ( + corrected.groupby('ch_name') + .apply(lambda df: (df['pval_fdr'] < 0.05).any()) + .reset_index(name='significant') + ) + + # Merge with mean theta (optional for plotting) + mean_theta = reach_ch_summary.groupby('ch_name')['theta'].mean().reset_index() + sig_channels = sig_channels.merge(mean_theta, on='ch_name') + print(sig_channels) + + + # For example, take the minimum corrected p-value per channel + summary_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index() + print(summary_pvals) + + + def parse_ch_name(ch_name): + # Extract numbers after S and D in names like 'S10_D5 hbo' + match = re.match(r'S(\d+)_D(\d+)', ch_name) + if match: + return int(match.group(1)), int(match.group(2)) + else: + return None, None + + + min_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index() + + # Merge the real p-values into sig_channels / avg_df + avg_df = sig_channels.merge(min_pvals, on='ch_name') + + # Rename columns for consistency + avg_df = avg_df.rename(columns={'theta': 't_or_theta', 'pval_fdr': 'p_value'}) + + # Add Source and Detector columns again + avg_df['Source'], avg_df['Detector'] = zip(*avg_df['ch_name'].map(parse_ch_name)) + + # Keep relevant columns + avg_df = avg_df[['Source', 'Detector', 't_or_theta', 'p_value']].dropna() + + ABS_SIGNIFICANCE_THETA_VALUE = 1 + ABS_SIGNIFICANCE_T_VALUE = 1 + P_THRESHOLD = 0.05 + SOURCE_DETECTOR_SEPARATOR = "_" + Reach = "Reach" + + + t_or_theta = 'theta' + for _, row in avg_df.iterrows(): # type: ignore + print(f"Source {row['Source']} <-> Detector {row['Detector']}: " + f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}") + + # Extract the cource and detector positions from raw + src_pos: dict[int, tuple[float, float]] = {} + det_pos: dict[int, tuple[float, float]] = {} + for ch in getattr(raw_haemo, "info")["chs"]: + ch_name = ch['ch_name'] + if not ch_name or not ch['loc'].any(): + continue + parts = ch_name.split()[0] + src_str, det_str = parts.split(SOURCE_DETECTOR_SEPARATOR) + src_num = int(src_str[1:]) + det_num = int(det_str[1:]) + src_pos[src_num] = ch['loc'][3:5] + det_pos[det_num] = ch['loc'][6:8] + + # Set up the plot + fig, ax = plt.subplots(figsize=(8, 6)) # type: ignore + + # Plot the sources + for pos in src_pos.values(): + ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3) # type: ignore + + # Plot the detectors + for pos in det_pos.values(): + ax.scatter(pos[0], pos[1], s=120, 
c='k', marker='s', edgecolors='white', linewidths=1, zorder=3) # type: ignore + + # Ensure that the colors stay within the boundaries even if they are over or under the max/min values + if t_or_theta == 't': + norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_T_VALUE, vmax=ABS_SIGNIFICANCE_T_VALUE) + elif t_or_theta == 'theta': + norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_THETA_VALUE, vmax=ABS_SIGNIFICANCE_THETA_VALUE) + + cmap: mcolors.Colormap = plt.get_cmap('seismic') + + # Plot connections with avg t-values + for row in avg_df.itertuples(): + src: int = cast(int, row.Source) # type: ignore + det: int = cast(int, row.Detector) # type: ignore + tval: float = cast(float, row.t_or_theta) # type: ignore + pval: float = cast(float, row.p_value) # type: ignore + + + if src in src_pos and det in det_pos: + x = [src_pos[src][0], det_pos[det][0]] + y = [src_pos[src][1], det_pos[det][1]] + style = '-' if pval <= P_THRESHOLD else '--' + ax.plot(x, y, linestyle=style, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2) # type: ignore + + # Format the Colorbar + sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) + sm.set_array([]) + cbar = plt.colorbar(sm, ax=ax, shrink=0.85) # type: ignore + cbar.set_label(f'Average {Reach} {t_or_theta} value (hbo)', fontsize=11) # type: ignore + + # Formatting the subplots + ax.set_aspect('equal') + ax.set_title(f"Average {t_or_theta} values for {Reach} (HbO)", fontsize=14) # type: ignore + ax.set_xlabel('X position (m)', fontsize=11) # type: ignore + ax.set_ylabel('Y position (m)', fontsize=11) # type: ignore + ax.grid(True, alpha=0.3) # type: ignore + + # Set axis limits to be 1cm more than the optode positions + all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()] + all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()] + ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01) + ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01) + + fig.tight_layout() + + + return fig + +# TODO: Hardcoded +def group_significance( + raw_haemo, + all_cha: pd.DataFrame, + condition: str, + correction: str = "fdr_bh" +) -> plt.Figure: + """ + Compute group-level significance using weighted Stouffer's method and plot results. + + Args: + raw_haemo: Raw haemoglobin MNE object (used for optode positions) + all_cha: DataFrame with columns including 'ID', 'Condition', 'p_value', 'theta', 'df', 'ch_name', 'Chroma' + condition: condition prefix, e.g., 'Reach' + correction: p-value correction method ('fdr_bh' or 'bonferroni') + + Returns: + Matplotlib Figure with group-level theta values and significance. 
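+
+    Example (a sketch; assumes all_cha concatenates each participant's GLM
+    dataframe with an added 'ID' column):
+
+        fig = group_significance(raw_haemo, all_cha, condition="Reach")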
+ """ + + assert "ID" in all_cha.columns, "'ID' column missing in input data" + assert len(raw_haemo) >= 1, "At least one raw haemoglobin object is required" + + condition_prefix = f"{condition}_delay" + + # Filter relevant data + ch_summary = all_cha.query( + "Condition.str.startswith(@condition_prefix) and Chroma == 'hbo'", + engine='python' + ).copy() + + + logger.info("=== ch_summary head ===") + logger.info(ch_summary.head()) + + logger.info("\nSummary stats:") + logger.info(f"Total rows: {len(ch_summary)}") + logger.info(f"Unique subjects: {ch_summary['ID'].nunique() if 'ID' in ch_summary.columns else 'ID column missing'}") + logger.info(f"Unique conditions: {ch_summary['Condition'].unique()}") + logger.info(f"Unique channels (Source-Detector pairs): {ch_summary.groupby(['Source', 'Detector']).ngroups}") + + logger.info("\nSample p_values:") + logger.info(ch_summary['p_value'].describe()) + + if ch_summary.empty: + raise ValueError(f"No data found for condition prefix: {condition_prefix}") + + # --- For debugging + logger.info(f"Total rows after filtering for condition '{condition_prefix}': {len(ch_summary)}") + logger.info(f"Unique channels: {ch_summary['ch_name'].nunique()}") + logger.info(f"Participants: {ch_summary['ID'].nunique()}") + + # Step 1: Select the peak regressor (~6s after stimulus onset) + peak_regressor = f"{condition}_delay_6" + peak_data = ch_summary[ch_summary["Condition"] == peak_regressor].copy() + + logger.info(f"\n=== Logging all values for {peak_regressor} ===") + for row in peak_data.itertuples(index=False): + logger.info( + f"Subject: {row.ID}, " + f"Channel: {row.ch_name}, " + f"Source: {row.Source}, Detector: {row.Detector}, " + f"theta: {row.theta:.4f}, " + f"p_value: {row.p_value:.6f}, " + f"df: {row.df}" + ) + + if peak_data.empty: + raise ValueError(f"No data found for peak regressor: {peak_regressor}") + + # Step 2: Combine per-channel stats across subjects + group_results = [] + + for (src, det), group in peak_data.groupby(["Source", "Detector"]): + pvals = group["p_value"].values + thetas = group["theta"].values + dfs = group["df"].values + + # Weighted Stouffer's method + weights = np.sqrt(dfs) + z_scores = norm.isf(pvals) + combined_z = np.sum(weights * z_scores) / np.sqrt(np.sum(weights**2)) + combined_p = norm.sf(combined_z) + + theta_avg = np.average(thetas, weights=weights) + + group_results.append({ + "Source": src, + "Detector": det, + "theta_avg": theta_avg, + "combined_p": combined_p + }) + + # Step 3: Create combined_df + combined_df = pd.DataFrame(group_results) + + # Step 4: Multiple comparisons correction + _, pvals_corr, _, significant = multipletests( + combined_df["combined_p"], alpha=0.05, method=correction + ) + + combined_df["pval_corr"] = pvals_corr + combined_df["significant"] = significant + + logger.info(f"Used peak regressor: {peak_regressor}") + logger.info(f"Channels tested: {len(combined_df)}") + logger.info(f"Significant channels after correction: {combined_df['significant'].sum()}") + # Get optode positions from the first raw file + raw = raw_haemo + src_pos, det_pos = {}, {} + for ch in raw.info["chs"]: + ch_name = ch["ch_name"] + if not ch_name or not ch["loc"].any(): + continue + parts = ch_name.split()[0] + src_str, det_str = parts.split("_") + src_num = int(src_str[1:]) + det_num = int(det_str[1:]) + src_pos[src_num] = ch["loc"][3:5] + det_pos[det_num] = ch["loc"][6:8] + + # Plotting parameters + ABS_SIGNIFICANCE_THETA_VALUE = 1 + P_THRESHOLD = 0.05 + cmap = plt.get_cmap("seismic") + norm = 
mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_THETA_VALUE, vmax=ABS_SIGNIFICANCE_THETA_VALUE)
+
+    fig, ax = plt.subplots(figsize=(8, 6))
+
+    # Plot optodes
+    for pos in src_pos.values():
+        ax.scatter(*pos, s=120, c="k", marker="o", edgecolors="white", linewidths=1, zorder=3)
+    for pos in det_pos.values():
+        ax.scatter(*pos, s=120, c="k", marker="s", edgecolors="white", linewidths=1, zorder=3)
+
+    # Plot connections colored by average theta, solid if significant
+    for row in combined_df.itertuples():
+        src, det = int(row.Source), int(row.Detector)
+        tval, pval = row.theta_avg, row.pval_corr
+        if src in src_pos and det in det_pos:
+            x = [src_pos[src][0], det_pos[det][0]]
+            y = [src_pos[src][1], det_pos[det][1]]
+            linestyle = "-" if pval <= P_THRESHOLD else "--"
+            ax.plot(x, y, linestyle=linestyle, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2)
+
+    # Colorbar
+    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+    sm.set_array([])
+    cbar = plt.colorbar(sm, ax=ax, shrink=0.85)
+    cbar.set_label(f"Average {condition_prefix.rstrip('_')} θ-value (HbO)", fontsize=11)
+
+    # Format axes
+    ax.set_aspect("equal")
+    ax.set_title(f"Group-level θ-values for {condition_prefix.rstrip('_')} (HbO)", fontsize=14)
+    ax.set_xlabel("X position (m)", fontsize=11)
+    ax.set_ylabel("Y position (m)", fontsize=11)
+    ax.grid(True, alpha=0.3)
+
+    all_x = [p[0] for p in src_pos.values()] + [p[0] for p in det_pos.values()]
+    all_y = [p[1] for p in src_pos.values()] + [p[1] for p in det_pos.values()]
+    ax.set_xlim(min(all_x) - 0.01, max(all_x) + 0.01)
+    ax.set_ylim(min(all_y) - 0.01, max(all_y) + 0.01)
+
+    fig.tight_layout()
+
+    return fig
+
+
+def plot_glm_results(file_path, raw_haemo, glm_est, design_matrix):
+
+    dm = design_matrix.copy()
+
+    rois = dict(AllChannels=range(len(raw_haemo.ch_names)))
+    conditions = design_matrix.columns
+    df_individual = glm_est.to_dataframe_region_of_interest(rois, conditions)
+
+    df_individual["ID"] = file_path
+    # df_individual["theta"] = [t * 1.0e6 for t in df_individual["theta"]]
+
+    condition_of_interest = "Reach"  # TODO: Hardcoded
+
+    # Filter for the condition of interest and FIR delays
+    df_individual["isCondition"] = [condition_of_interest in n for n in df_individual["Condition"]]
+    df_individual["isDelay"] = ["delay" in n for n in df_individual["Condition"]]
+    df_individual = df_individual.query("isDelay and isCondition")
+
+    # Remove other conditions from design matrix
+    dm_condition_cols = [col for col in dm.columns if condition_of_interest in col]
+    dm_cond = dm[dm_condition_cols]
+
+    # Add a numeric delay column
+    def extract_delay_number(condition_str):
+        # Extracts the number at the end of a string like 'Reach_delay_5'
+        return int(condition_str.split("_")[-1])
+
+    df_individual["DelayNum"] = df_individual["Condition"].apply(extract_delay_number)
+
+    # Now separate and sort using numeric delay
+    df_hbo = df_individual[df_individual["Chroma"] == "hbo"].sort_values("DelayNum")
+    df_hbr = df_individual[df_individual["Chroma"] == "hbr"].sort_values("DelayNum")
+
+    vals_hbo = df_hbo["theta"].values
+    vals_hbr = df_hbr["theta"].values
+
+    # Create the plot
+    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(19, 10))
+
+    # Scale design matrix components using numpy arrays instead of pandas operations
+    dm_cond_values = dm_cond.values
+    dm_cond_scaled_hbo = dm_cond_values * vals_hbo.reshape(1, -1)
+    dm_cond_scaled_hbr = dm_cond_values * vals_hbr.reshape(1, -1)
+
+    # Create time axis relative to stimulus onset (assumes the second
+    # annotation marks the first stimulus)
+    time = dm_cond.index - np.ceil(raw_haemo.annotations.onset[1])
+
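+    # Left: the raw FIR regressors; middle: the regressors scaled by the HbO
+    # GLM estimates; right: their sums, i.e. the reconstructed evoked response.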
axes[0].plot(time, dm_cond_values) + axes[1].plot(time, dm_cond_scaled_hbo) + axes[2].plot(time, np.sum(dm_cond_scaled_hbo, axis=1), 'r') + axes[2].plot(time, np.sum(dm_cond_scaled_hbr, axis=1), 'b') + + # Format plots + for ax in range(3): + axes[ax].set_xlim(-5, 25) + axes[ax].set_xlabel("Time (s)") + axes[0].set_ylim(-0.2, 1.2) + axes[1].set_ylim(-0.5, 1) + axes[2].set_ylim(-0.5, 1) + axes[0].set_title(f"FIR Model (Unscaled)") + axes[1].set_title(f"FIR Components (Scaled by {condition_of_interest} GLM Estimates)") + axes[2].set_title(f"Evoked Response ({condition_of_interest})") + axes[0].set_ylabel("FIR Model") + axes[1].set_ylabel("Oxyhaemoglobin (ΔμMol)") + axes[2].set_ylabel("Haemoglobin (ΔμMol)") + axes[2].legend(["Oxyhaemoglobin", "Deoxyhaemoglobin"]) + + + print(f"Number of FIR bins: {len(vals_hbo)}") + print(f"Mean theta (HbO): {np.mean(vals_hbo):.4f}") + print(f"Sum of theta (HbO): {np.sum(vals_hbo):.4f}") + print(f"Mean theta (HbR): {np.mean(vals_hbr):.4f}") + print(f"Sum of theta (HbR): {np.sum(vals_hbr):.4f}") + + return fig + + + +def plot_3d_evoked_array( + inst: Union[BaseRaw, EvokedArray, Info], + statsmodel_df: DataFrame, + picks: Optional[Union[str, list[str]]] = "hbo", + value: str = "Coef.", + background: str = "w", + figure: Optional[object] = None, + clim: Union[str, dict[str, Union[str, list[float]]]] = "auto", + mode: str = "weighted", + colormap: str = "RdBu_r", + surface: str = "pial", + hemi: str = "both", + size: int = 800, + view: Optional[Union[str, dict[str, float]]] = None, + colorbar: bool = True, + distance: float = 0.03, + subjects_dir: Optional[str] = None, + src: Optional[SourceSpaces] = None, + verbose: bool = False, +) -> Brain: + '''Ported from MNE''' + + info: Info = cast(Info, deepcopy(inst if isinstance(inst, Info) else inst.info)) # type: ignore + if not (getattr(info, "ch_names") == list(statsmodel_df["ch_name"].values)): # type: ignore + raise RuntimeError( + 'MNE data structure does not match dataframe ' + f'results.\nMNE = {getattr(info, "ch_names")}.\n' + f'GLM = {list(statsmodel_df["ch_name"].values)}' # type: ignore + ) + + ea = EvokedArray(np.tile(statsmodel_df[value].values.T, (1, 1)).T, info.copy()) # type: ignore + + # TODO: mimic behaviour of other MNE-NIRS glm plotting options + if picks is not None: + ea = ea.pick(picks=picks) # type: ignore + + if subjects_dir is None: + subjects_dir = os.environ["SUBJECTS_DIR"] + if src is None: + fname_src_fs = os.path.join( + subjects_dir, "fsaverage", "bem", "fsaverage-ico-5-src.fif" + ) + src = read_source_spaces(fname_src_fs, verbose=verbose) + + picks = getattr(ea, "info")["ch_names"] + + # Set coord frame + for idx in range(len(getattr(ea, "ch_names"))): + getattr(ea, "info")["chs"][idx]["coord_frame"] = 4 + + # Generate source estimate + kwargs = dict( + evoked=ea, + subject="fsaverage", + trans=Transform('head', 'mri', np.eye(4)), + distance=distance, + mode=mode, + surface=surface, + subjects_dir=subjects_dir, + src=src, + project=True, + ) + + stc = stc_near_sensors(picks=picks, **kwargs, verbose=verbose) # type: ignore + + assert isinstance(stc, SourceEstimate) + + # Produce brain plot + brain: Brain = stc.plot( # type: ignore + src=src, + subjects_dir=subjects_dir, + hemi=hemi, + surface=surface, + initial_time=0, + clim=clim, # type: ignore + size=size, + colormap=colormap, + figure=figure, + background=background, + colorbar=colorbar, + verbose=verbose, + ) + if view is not None: + brain.show_view(view) # type: ignore + + return brain + + + +def brain_3d_visualization(raw_haemo, 
df_cha, selected_event, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True, brain_bounds: float = 1.0) -> None: + + clim = dict(kind="value", pos_lims=(0, brain_bounds/2, brain_bounds)) + + # Get all activity conditions + for cond in [f'{selected_event}']: + + if True: + ch_summary = df_cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python') # type: ignore + + # Use ordinary least squares (OLS) if only one participant + # TODO: Fix. + if True: + + # t values + if t_or_theta == 't': + ch_model = smf.ols("t ~ -1 + ch_name", ch_summary).fit() # type: ignore + + # theta values + elif t_or_theta == 'theta': + ch_model = smf.ols("theta ~ -1 + ch_name", ch_summary).fit() # type: ignore + + print("OLS model is being used as there is only one participant!") + + # Convert model results + model_df = cast(DataFrame, statsmodels_to_results(ch_model, order=ch_summary["ch_name"].unique())) # type: ignore + + valid_channels = ch_summary["ch_name"].unique().tolist() # type: ignore + raw_for_plot = raw_haemo.copy().pick(picks=valid_channels) # type: ignore + + brain = plot_3d_evoked_array(raw_for_plot.pick(picks="hbo"), model_df, view="dorsal", distance=0.02, colorbar=True, clim=clim, mode="weighted", size=(800, 700)) # type: ignore + + if show_optodes == 'all' or show_optodes == 'sensors': + brain.add_sensors(getattr(raw_for_plot, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore + + if True: + display_text = ('Folder: ' + '\nGroup: ' + '\nCondition: '+ cond + '\nShort Channel Regression: ' + + '\nLooking at: ' + t_or_theta + ' values') + '\nBrain Distance: ' + + # Apply the text onto the brain + if show_text: + brain.add_text(0.12, 0.64, display_text, "title", font_size=11, color="k") # type: ignore + + return brain + + + + +def brain_landmarks_3d(raw_haemo: BaseRaw, show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_brodmann: bool = True) -> None: + + brain = Brain("fsaverage", background="white", size=(800, 700)) # type: ignore + + distances = source_detector_distances(raw_haemo.info) + + # Add optode text labels manually + if show_optodes == 'all' or show_optodes == 'sensors': + brain.add_sensors(getattr(raw_haemo, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore + + if show_optodes == 'all' or show_optodes == 'labels': + labeled_srcs = set() + labeled_dets = set() + label_counts = {} + + for idx, ch in enumerate(raw_haemo.info['chs']): + ch_name = ch['ch_name'] + + if not ch_name.endswith('hbo'): + continue + + loc = ch['loc'] + logger.info(f"Channel: {ch_name}") + logger.info(f"loc length: {len(loc)}") + logger.info("loc contents:") + for i, val in enumerate(loc): + logger.info(f" loc[{i}]: {val}") + logger.info("-" * 30) + if not ch_name or not ch['loc'].any(): + continue + + parts = ch_name.split()[0] + src_str, det_str = parts.split('_') + + src_num = int(src_str[1:]) + det_num = int(det_str[1:]) + + if src_num not in labeled_srcs: + src_xyz = ch['loc'][3:6] * 1000 + brain._renderer.text3d(src_xyz[0], src_xyz[1], src_xyz[2], src_str, + color='red', scale=0.002) + labeled_srcs.add(src_num) + + if det_num not in labeled_dets: + det_xyz = ch['loc'][6:9] * 1000 + brain._renderer.text3d(det_xyz[0], det_xyz[1], det_xyz[2], det_str, + color='blue', scale=0.002) + labeled_dets.add(det_num) + + # 
Get the source-detector distance for this channel (in meters) + dist_m = distances[idx] + dist_mm = dist_m * 1000 + + label_text = f"{dist_mm:.1f} mm" + label_counts[label_text] = label_counts.get(label_text, 0) + 1 + if label_counts[label_text] > 1: + label_text += f" ({label_counts[label_text]})" + + # Label at channel midpoint + mid_xyz = loc[0:3] * 1000 + + logger.info(f"Channel: {ch_name} | Midpoint (mm): x={mid_xyz[0]:.2f}, y={mid_xyz[1]:.2f}, z={mid_xyz[2]:.2f} | Distance: {dist_mm:.1f} mm") + + brain._renderer.text3d( + mid_xyz[0], mid_xyz[1], mid_xyz[2], + label_text, + color='gray', + scale=0.002 + ) + + + if show_brodmann:# Add Brodmann labels + labels = cast(list[Label], read_labels_from_annot("fsaverage", "PALS_B12_Brodmann", "rh", verbose=False)) # type: ignore + + label_colors = { + "Brodmann.39-rh": "blue", + "Brodmann.40-rh": "green", + "Brodmann.6-rh": "pink", + "Brodmann.7-rh": "orange", + "Brodmann.17-rh": "red", + "Brodmann.1-rh": "yellow", + "Brodmann.2-rh": "yellow", + "Brodmann.3-rh": "yellow", + "Brodmann.18-rh": "red", + "Brodmann.19-rh": "red", + "Brodmann.4-rh": "purple", + "Brodmann.8-rh": "white" + } + + for label in labels: + name = getattr(label, "name", None) + if not isinstance(name, str): + continue + if name in label_colors: + brain.add_label(label, borders=False, color=label_colors[name]) # type: ignore + + + return brain + + +def verify_channel_positions(data: BaseRaw) -> None: + """ + Visualizes the sensor/channel positions of the raw data for verification. + + Parameters + ---------- + data : BaseRaw + The loaded data object to process. + """ + +def convert_fig_dict_to_png_bytes(fig_dict: dict[str, Figure]) -> dict[str, bytes]: + png_dict = {} + for label, fig in fig_dict.items(): + buf = BytesIO() + fig.savefig(buf, format="png", bbox_inches="tight") + buf.seek(0) + png_dict[label] = buf.read() + plt.close(fig) + return png_dict + + + +def brain_3d_contrast(con_model_df: DataFrame, con_model_df_filtered: BaseRaw, common_channels: list[str], first_name: str, second_name: str, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True, brain_bounds: float = 1.0) -> None: + # Filter DataFrame to only common channels, and sort by raw order + con_model = con_model_df + + con_model["ch_name"] = pd.Categorical( + con_model["ch_name"], categories=common_channels, ordered=True + ) + con_model = con_model.sort_values("ch_name").reset_index(drop=True) # type: ignore + + + clim=dict(kind="value", pos_lims=(0, brain_bounds/2, brain_bounds)) + + # Plot brain figure + brain = plot_3d_evoked_array(con_model_df_filtered.copy().pick(picks="hbo"), con_model, view="dorsal", distance=0.02, colorbar=True, mode="weighted", clim=clim, size=(800, 700), verbose=False) # type: ignore + + if show_optodes == 'all' or show_optodes == 'sensors': + brain.add_sensors(getattr(con_model_df_filtered, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore + + display_text = ('Contrast: ' + first_name + ' - ' + second_name + '\nLooking at: ' + t_or_theta + ' values') + + # Apply the text onto the brain + if show_text: + brain.add_text(0.12, 0.70, display_text, "title", font_size=11, color="k") # type: ignore + + + +def plot_2d_3d_contrasts_between_groups( + contrast_df_a: pd.DataFrame, + contrast_df_b: pd.DataFrame, + raw_haemo: BaseRaw, + group_a_name: str, + group_b_name: str, + is_3d: bool = True, + t_or_theta: Literal['t', 
'theta'] = 'theta', + show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', + show_text: bool = True, + brain_bounds: float = 1.0, +) -> None: + + + logger.info("-----") + contrast_df_a = contrast_df_a.copy() + contrast_df_a["group"] = group_a_name + contrast_df_b = contrast_df_b.copy() + contrast_df_b["group"] = group_b_name + logger.info("-----") + + df_combined = pd.concat([contrast_df_a, contrast_df_b], ignore_index=True) + + con_summary = df_combined.query("Chroma == 'hbo'").copy() + logger.info("-----") + + valid_channels = (pd.crosstab(con_summary["group"], con_summary["ch_name"]) > 1).all() + valid_channels = valid_channels[valid_channels].index.tolist() + con_summary = con_summary[con_summary["ch_name"].isin(valid_channels)] + logger.info("-----") + + model_formula = "effect ~ -1 + group:ch_name:Chroma" + con_model = smf.mixedlm(model_formula, con_summary, groups=con_summary["ID"]).fit(method="nm") + logger.info("-----") + + if t_or_theta == "t": + group1_vals = con_model.tvalues.filter(like=f"group[{group_a_name}]") + group2_vals = con_model.tvalues.filter(like=f"group[{group_b_name}]") + else: + group1_vals = con_model.params.filter(like=f"group[{group_a_name}]") + group2_vals = con_model.params.filter(like=f"group[{group_b_name}]") + logger.info("-----") + + group1_channels = [name.split(":")[1].split("[")[1].split("]")[0] for name in group1_vals.index] + group2_channels = [name.split(":")[1].split("[")[1].split("]")[0] for name in group2_vals.index] + + df_group1 = DataFrame({"Coef.": group1_vals.values}, index=group1_channels) + df_group2 = DataFrame({"Coef.": group2_vals.values}, index=group2_channels) + + df_contrast = df_group1.join(df_group2, how="inner", lsuffix=f"_{group_a_name}", rsuffix=f"_{group_b_name}") + logger.info("-----") + + # A - B + df_contrast["Coef."] = df_contrast[f"Coef._{group_a_name}"] - df_contrast[f"Coef._{group_b_name}"] + con_model_df_1_2 = DataFrame({ + "ch_name": df_contrast.index, + "Coef.": df_contrast["Coef."], + "Chroma": "hbo" + }) + logger.info("-----") + + mne_ch_names = raw_haemo.copy().pick(picks="hbo").ch_names + glm_ch_names = con_model_df_1_2["ch_name"].tolist() + common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names] + + con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels) + con_model_df_1_2 = con_model_df_1_2.set_index("ch_name").loc[common_channels].reset_index() + logger.info("-----") + + if is_3d: + brain_3d_contrast( + con_model_df_1_2, + con_model_df_filtered, + common_channels, + group_a_name, + group_b_name, + t_or_theta, + show_optodes, + show_text, + brain_bounds + ) + else: + plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_1_2, names=True, res=128, vlim=(-brain_bounds, brain_bounds)) # type: ignore + + # TODO: The title currently goes on the colorbar. 
Low priority + plt.title(f"Contrast: {group_a_name} vs {group_b_name}") # type: ignore + plt.show() # type: ignore + + # plt.title(f"Contrast: {group_a_name} vs {group_b_name}") + # plt.show() + + # B - A + df_contrast["Coef."] = df_contrast[f"Coef._{group_b_name}"] - df_contrast[f"Coef._{group_a_name}"] + con_model_df_2_1 = DataFrame({ + "ch_name": df_contrast.index, + "Coef.": df_contrast["Coef."], + "Chroma": "hbo" + }) + + glm_ch_names = con_model_df_2_1["ch_name"].tolist() + common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names] + + con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels) + con_model_df_2_1 = con_model_df_2_1.set_index("ch_name").loc[common_channels].reset_index() + + if is_3d: + brain_3d_contrast( + con_model_df_2_1, + con_model_df_filtered, + common_channels, + group_b_name, + group_a_name, + t_or_theta, + show_optodes, + show_text, + brain_bounds + ) + else: + plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_2_1, names=True, res=128, vlim=(-brain_bounds, brain_bounds)) # type: ignore + + # TODO: The title currently goes on the colorbar. Low priority + plt.title(f"Contrast: {group_b_name} vs {group_a_name}") # type: ignore + plt.show() # type: ignore + + + + +def plot_fir_model_results(df, raw_haemo, dm, selected_event, l_bound, u_bound): + + + df["isActivity"] = [f"{selected_event}" in n for n in df["Condition"]] + df["isDelay"] = ["delay" in n for n in df["Condition"]] + df = df.query("isDelay in [True]") + df = df.query("isActivity in [True]") + # Make a new column that stores the condition name for tidier model below + df.loc[:, "TidyCond"] = "" + df.loc[df["isActivity"] == True, "TidyCond"] = f"{selected_event}" # noqa: E712 + # Finally, extract the FIR delay in to its own column in data frame + df.loc[:, "delay"] = [n.split("_")[-1] for n in df.Condition] + + # To simplify this example we will only look at the activity + # condition so we now remove the other conditions from the + # design matrix and GLM results + dm_cols_activity = np.where([f"{selected_event}" in c for c in dm.columns])[0] + dm = dm[[dm.columns[i] for i in dm_cols_activity]] + + lme = smf.mixedlm("theta ~ -1 + delay:TidyCond:Chroma", df, groups=df["ID"]).fit() + + df_sum = statsmodels_to_results(lme) + df_sum["delay"] = [int(n) for n in df_sum["delay"]] + df_sum = df_sum.sort_values("delay") + + # Print the result for the oxyhaemoglobin data in the target condition + df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbo']") + + + fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(19, 10)) + + print("dm columns:", dm.columns.tolist()) + + # Extract design matrix columns that correspond to the condition of interest + dm_cond_idxs = np.where([f"{selected_event}" in n for n in dm.columns])[0] + dm_cond = dm[[dm.columns[i] for i in dm_cond_idxs]] + + # Extract the corresponding estimates from the lme dataframe for hbo + df_hbo = df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbo']") + vals_hbo = [float(v) for v in df_hbo["Coef."]] + + # print("--------------------------------------") + # print(f"dm_cond shape: {dm_cond.shape}") + # print(f"dm_cond columns: {dm_cond.columns.tolist()}") + # print(f"vals_hbo length: {len(vals_hbo)}") + # print(f"vals_hbo sample: {vals_hbo[:5]}") + # print(f"vals_hbo type: {type(vals_hbo)}") + # print(f"vals_hbo element type: {type(vals_hbo[0]) if len(vals_hbo) > 0 else 'N/A'}") + + dm_cond_scaled_hbo = dm_cond * vals_hbo + + # Extract the corresponding estimates from 
the lme dataframe for hbr
+    df_hbr = df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbr']")
+    vals_hbr = [float(v) for v in df_hbr["Coef."]]
+    dm_cond_scaled_hbr = dm_cond * vals_hbr
+
+    first_onset = None
+    for desc, onset in zip(raw_haemo.annotations.description, raw_haemo.annotations.onset):
+        if selected_event in desc:
+            first_onset = onset
+            break
+
+    if first_onset is None:
+        raise ValueError(f"Selected event '{selected_event}' not found in annotations.")
+
+    # Align index values (time axis) to the first occurrence of selected_event
+    index_values = dm_cond_scaled_hbo.index - np.ceil(first_onset)
+    index_values = np.asarray(index_values)
+
+    # Plot the result
+    axes[0].plot(index_values, np.asarray(dm_cond))
+    axes[1].plot(index_values, np.asarray(dm_cond_scaled_hbo))
+    axes[2].plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r")
+    axes[2].plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b")
+
+    valid_mask = (index_values >= 0) & (index_values <= 15)
+    hbo_sum_window = np.sum(dm_cond_scaled_hbo.loc[valid_mask, :], axis=1)
+    peak_idx_in_window = np.argmax(hbo_sum_window)
+    peak_idx = np.where(valid_mask)[0][peak_idx_in_window]
+    peak_time = float(round(index_values[peak_idx], 2))  # type: ignore
+
+    axes[2].axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time')  # type: ignore
+
+    # Format the plot
+    for ax in range(3):
+        axes[ax].set_xlim(-5, 25)
+        axes[ax].set_xlabel("Time (s)")
+    axes[0].set_ylim(-0.1, 1.1)
+    axes[1].set_ylim(l_bound, u_bound)
+    axes[2].set_ylim(l_bound, u_bound)
+    axes[0].set_title("FIR Model (Unscaled by GLM estimates)")
+    axes[1].set_title(f"FIR Components (Scaled by {selected_event} GLM Estimates)")
+    axes[2].set_title(f"Evoked Response {selected_event}")
+    axes[0].set_ylabel("FIR Model")
+    axes[1].set_ylabel("Oxyhaemoglobin (ΔμMol)")
+    axes[2].set_ylabel("Haemoglobin (ΔμMol)")
+    axes[2].legend(["Oxyhaemoglobin", "Deoxyhaemoglobin"])
+
+    # We can also extract the 95% confidence intervals of the estimates too
+    l95_hbo = [float(v) for v in df_hbo["[0.025"]]  # type: ignore
+    u95_hbo = [float(v) for v in df_hbo["0.975]"]]  # type: ignore
+    dm_cond_scaled_hbo_l95 = dm_cond * l95_hbo
+    dm_cond_scaled_hbo_u95 = dm_cond * u95_hbo
+    l95_hbr = [float(v) for v in df_hbr["[0.025"]]  # type: ignore
+    u95_hbr = [float(v) for v in df_hbr["0.975]"]]  # type: ignore
+    dm_cond_scaled_hbr_l95 = dm_cond * l95_hbr
+    dm_cond_scaled_hbr_u95 = dm_cond * u95_hbr
+
+    axes2: Axes
+    fig2, axes2 = plt.subplots(nrows=1, ncols=1, figsize=(7, 7))  # type: ignore
+
+    # Plot the result
+    axes2.plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r")  # type: ignore
+    axes2.plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b")  # type: ignore
+    axes2.axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time')  # type: ignore
+
+    axes2.fill_between(  # type: ignore
+        index_values,
+        np.asarray(np.sum(dm_cond_scaled_hbo_l95, axis=1)),
+        np.asarray(np.sum(dm_cond_scaled_hbo_u95, axis=1)),
+        facecolor="red",
+        alpha=0.25,
+    )
+    axes2.fill_between(  # type: ignore
+        index_values,
+        np.asarray(np.sum(dm_cond_scaled_hbr_l95, axis=1)),
+        np.asarray(np.sum(dm_cond_scaled_hbr_u95, axis=1)),
+        facecolor="blue",
+        alpha=0.25,
+    )
+
+    # Format the plot
+    axes2.set_xlim(-5, 20)
+    axes2.set_ylim(l_bound, u_bound)
+    axes2.set_title(f"Evoked Response with 95% confidence intervals for {selected_event}")  # type: ignore
+    axes2.set_ylabel("Haemoglobin (ΔμMol)")  # type: ignore
+    axes2.legend(["Oxyhaemoglobin", "Deoxyhaemoglobin", f"Peak 
{peak_time}s"]) # type: ignore + axes2.set_xlabel("Time (s)") # type: ignore + + fig2.tight_layout() + + fig.show() + fig2.show() + + + +def load_snirf(file_path: str) -> tuple[BaseRaw, Figure]: + """ + Loads a snirf file, optionally drops channels, downsamples, and creates a figure showing the results. + + Parameters + ---------- + file_path : str + Path of the snirf file to load. + ID : str + File name of the the snirf file that was loaded. + drop_prefixes : list[str] + List of channel name prefixes to drop from the data. + + Returns + ------- + tuple[BaseRaw, Figure] + - BaseRaw: The processed data object. + - Figure: The corresponding Matplotlib figure. + """ + + # Read the snirf file + raw = read_raw_snirf(file_path, preload=True, verbose=VERBOSITY) # type: ignore + raw.load_data(verbose=VERBOSITY) # type: ignore + + # Strip the specified amount of seconds from the start of the file + total_duration = getattr(raw, "times")[-1] + if total_duration > SECONDS_TO_STRIP: + raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore + logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.") + else: + logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.") + + # If the user forcibly dropped channels, remove them now before any processing occurs + # logger.info("Checking if there are channels to forcibly drop...") + # if drop_prefixes: + # logger.info("Force dropped channels was specified.") + # channels_to_drop = [ch for ch in cast(list[str], getattr(raw, "ch_names")) if any(ch.startswith(prefix) for prefix in drop_prefixes)] + # raw.drop_channels(channels_to_drop, "raise") # type: ignore + # logger.info("Force dropped channels:", channels_to_drop) + + # If the user wants to downsample, do it right away + logger.info("Checking if we should downsample...") + if DOWNSAMPLE: + logger.info("Downsample was specified.") + sfreq_old = getattr(raw, "info")["sfreq"] + raw.resample(DOWNSAMPLE_FREQUENCY, verbose=VERBOSITY) # type: ignore + sfreq_new = getattr(raw, "info")["sfreq"] + logger.info(f"Finished downsampling. Old frequency: {sfreq_old}. New frequency: {sfreq_new}.") + + logger.info("Successfully loaded the snirf file.") + + return raw + + +def run_second_level_analysis(df_contrasts, raw, p, bounds): + """ + Perform second-level analysis using contrast data from multiple participants. + + Parameters + ---------- + df_contrasts : pd.DataFrame + Combined contrast results from multiple participants. + Must include: ['ch_name', 'effect', 'ID'] + + Returns + ------- + pd.DataFrame + Group-level t-values, p-values, and mean effect per channel. 
+ """ + + if not all(col in df_contrasts.columns for col in ['ch_name', 'effect', 'ID']): + raise ValueError("Input DataFrame must include 'ch_name', 'effect', and 'ID' columns.") + + channels = df_contrasts['ch_name'].unique() + group_results = [] + + for ch in channels: + ch_data = df_contrasts[df_contrasts['ch_name'] == ch] + + if ch_data['ID'].nunique() < 2: + logger.warning(f"Skipping channel {ch} — not enough subjects.") + continue + + Y = ch_data['effect'].values + design_matrix = np.ones((len(Y), 1)) # intercept-only + model = OLSModel(design_matrix) + result = model.fit(Y) + + t_val = result.t(0).item() + p_val = 2 * stats.t.sf(np.abs(t_val), df=result.df_model) + mean_beta = np.mean(Y) + + group_results.append({ + 'ch_name': ch, + 't_val': t_val, + 'p_val': p_val, + 'mean_beta': mean_beta, + 'n_subjects': len(Y) + }) + + df_group = pd.DataFrame(group_results) + logger.info("Second-level results:\n%s", df_group) + + + # Extract the cource and detector positions from raw + src_pos: dict[int, tuple[float, float]] = {} + det_pos: dict[int, tuple[float, float]] = {} + for ch in getattr(raw, "info")["chs"]: + ch_name = ch['ch_name'] + if not ch_name or not ch['loc'].any(): + continue + parts = ch_name.split()[0] + src_str, det_str = parts.split('_') + src_num = int(src_str[1:]) + det_num = int(det_str[1:]) + src_pos[src_num] = ch['loc'][3:5] + det_pos[det_num] = ch['loc'][6:8] + + # Set up the plot + fig, ax = plt.subplots(figsize=(8, 6)) # type: ignore + + # Plot the sources + for pos in src_pos.values(): + ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3) # type: ignore + + # Plot the detectors + for pos in det_pos.values(): + ax.scatter(pos[0], pos[1], s=120, c='k', marker='s', edgecolors='white', linewidths=1, zorder=3) # type: ignore + + # Ensure that the colors stay within the boundaries even if they are over or under the max/min values + norm = mcolors.Normalize(vmin=-bounds, vmax=bounds) + + cmap: mcolors.Colormap = plt.get_cmap('seismic') + + # Plot connections with avg t-values + for _, row in df_group.iterrows(): + ch = row['ch_name'] + pval = row['p_val'] + tval = row['t_val'] + + if '_' not in ch: + logger.info(f"Skipping channel with unexpected format (no underscore): {ch}") + continue + + src_str, det_str = ch.split('_') + det_parts = det_str.split() + detector_id = det_parts[0] # e.g. 
"D1" + hemo_type = det_parts[1].lower() if len(det_parts) > 1 else '' + + logger.info(f"Parsing channel: {ch} -> src_str: {src_str}, det_str: {detector_id}, hemo_type: {hemo_type}") + + if hemo_type != 'hbo': + logger.info(f"Skipping channel {ch} because hemo_type is not HbO: {hemo_type}") + continue + + try: + src = int(src_str[1:]) + det = int(detector_id[1:]) + logger.info(f"Parsed src: {src}, det: {det}") + + except Exception as e: + logger.info(f"Error parsing source/detector from channel '{ch}': {e}") + continue + + if src in src_pos and det in det_pos: + x = [src_pos[src][0], det_pos[det][0]] + y = [src_pos[src][1], det_pos[det][1]] + style = '-' if pval <= p else '--' + color = cmap(norm(tval)) + logger.info(f"Plotting {ch}: t={tval:.2f}, p={pval:.3f}, color={color}, style={style}") + ax.plot(x, y, linestyle=style, color=color, linewidth=4, alpha=0.9, zorder=2) + + + # Format the Colorbar + sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) + sm.set_array([]) + cbar = plt.colorbar(sm, ax=ax, shrink=0.85) # type: ignore + cbar.set_label(f'Average value (hbo)', fontsize=11) # type: ignore + + # Formatting the subplots + ax.set_aspect('equal') + ax.set_title(f"Average values (HbO)", fontsize=14) # type: ignore + ax.set_xlabel('X position (m)', fontsize=11) # type: ignore + ax.set_ylabel('Y position (m)', fontsize=11) # type: ignore + ax.grid(True, alpha=0.3) # type: ignore + + # Set axis limits to be 1cm more than the optode positions + all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()] + all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()] + ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01) + ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01) + + fig.tight_layout() + plt.show() # type: ignore + + return df_group + + +def calculate_dpf(file_path): + # order is hbo / hbr + with h5py.File(file_path, 'r') as f: + wavelengths = f['/nirs/probe/wavelengths'][:] + logger.info("Wavelengths (nm):", wavelengths) + wavelengths = sorted(wavelengths, reverse=True) + age = float(AGE) + logger.info(f"Their age was {AGE}") + a = 223.3 + b = 0.05624 + c = 0.8493 + d = -5.723e-7 + e = 0.001245 + f = -0.9025 + dpf = [] + for w in wavelengths: + logger.info(w) + dpf.append(a + b * (age**c) + d* (w**3) + e * (w**2) + f*w) + logger.info(dpf) + return dpf + + + +def process_participant(file_path, progress_callback=None): + + fig_individual: dict[str, Figure] = {} + + # Step 1: Load + raw = load_snirf(file_path) + fig_raw = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Loaded Raw", show=False) + fig_individual["Loaded Raw"] = fig_raw + if progress_callback: progress_callback(1) + logger.info("1") + + # Step 1.5: Verify optode positions + fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False) # type: ignore + fig_individual["Plot Sensors"] = fig_optodes + if progress_callback: progress_callback(2) + logger.info("2") + + # Step 2: Downsample + # raw = raw.resample(0.5) # Downsample to 0.5 Hz + + # Step 2: Bad from SCI + bad_sci = [] + if SCI: + bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw) + fig_individual["SCI1"] = fig_sci_1 + fig_individual["SCI2"] = fig_sci_2 + if progress_callback: progress_callback(3) + logger.info("3") + + # Step 2: Bad from SNR + bad_snr = [] + if SNR: + bad_snr, fig_snr = calculate_signal_noise_ratio(raw) + fig_individual["SNR1"] = fig_snr + if progress_callback: progress_callback(4) + logger.info("4") + + # Step 3: Bad from PSP + bad_psp = [] + if PSP: + bad_psp, fig_psp1, 
fig_psp2 = calculate_peak_power(raw) + fig_individual["PSP1"] = fig_psp1 + fig_individual["PSP2"] = fig_psp2 + if progress_callback: progress_callback(5) + logger.info("5") + + # Step 4: Mark the bad channels + raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp) + if fig_dropped and fig_raw_before is not None: + fig_individual["fig2"] = fig_dropped + fig_individual["fig3"] = fig_raw_before + if progress_callback: progress_callback(6) + logger.info("6") + + # Step 5: Interpolate the bad channels + if bad_channels: + raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, bad_channels) + fig_individual["fig4"] = fig_raw_after + if progress_callback: progress_callback(7) + logger.info("7") + + # Step 6: Optical Density + raw_od = optical_density(raw) + fig_raw_od = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Optical Density", show=False) + fig_individual["Optical Density"] = fig_raw_od + if progress_callback: progress_callback(8) + logger.info("8") + + # Step 7: TDDR + if TDDR: + raw_od = temporal_derivative_distribution_repair(raw_od) + fig_raw_od_tddr = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After TDDR (Motion Correction)", show=False) + fig_individual["TDDR"] = fig_raw_od_tddr + if progress_callback: progress_callback(9) + logger.info("9") + + # Step 8: BLL + raw_haemo = beer_lambert_law(raw_od, ppf=calculate_dpf(file_path)) + fig_raw_haemo_bll = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False) + fig_individual["BLL"] = fig_raw_haemo_bll + if progress_callback: progress_callback(10) + logger.info("10") + + # Step 9: ENC + # raw_haemo = enhance_negative_correlation(raw_haemo) + # fig_raw_haemo_enc = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False) + # fig_individual.append(fig_raw_haemo_enc) + + # Step 10: Filter + fig_filter, fig_raw_haemo_filter = filter_the_data(raw_haemo) + fig_individual["filter1"] = fig_filter + fig_individual["filter2"] = fig_raw_haemo_filter + if progress_callback: progress_callback(11) + logger.info("11") + + # Step 11: Get short / long channels + short_chans = get_short_channels(raw_haemo, max_dist=0.015) + fig_short_chans = short_chans.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Short Channels Only", show=False) + fig_individual["short"] = fig_short_chans + raw_haemo = get_long_channels(raw_haemo) + if progress_callback: progress_callback(12) + logger.info("12") + + # Step 12: Events from annotations + events, event_dict = events_from_annotations(raw_haemo) + fig_events = plot_events(events, event_id=event_dict, sfreq=raw_haemo.info["sfreq"], show=False) + fig_individual["events"] = fig_events + if progress_callback: progress_callback(13) + logger.info("13") + + # Step 13: Epoch calculations + epochs, fig_epochs = epochs_calculations(raw_haemo, events, event_dict) + for name, fig in fig_epochs: # Unpack the tuple here + fig_individual[f"epochs_{name}"] = fig # Store only the figure, not the name + if progress_callback: progress_callback(14) + logger.info("14") + + # Step 14: Design Matrix + design_matrix, fig_design_matrix = make_design_matrix(raw_haemo, short_chans) + fig_individual["Design Matrix"] = fig_design_matrix + if progress_callback: progress_callback(15) + logger.info("15") + + # Step 15: Run GLM + glm_est = run_glm(raw_haemo, design_matrix) + # Not used 
AppData\Local\Packages\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\LocalCache\local-packages\Python313\site-packages\nilearn\glm\contrasts.py
+    # Yes used AppData\Local\Packages\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\LocalCache\local-packages\Python313\site-packages\mne_nirs\utils\_io.py
+
+    # The p-value is calculated from this t-statistic using the Student’s t-distribution with appropriate degrees of freedom:
+    #     p_value = 2 * stats.t.cdf(-abs(t_statistic), df)
+    # It is a two-tailed p-value.
+    # It says how likely it is to observe the effect you did (or something more extreme) if the true effect were zero (null hypothesis).
+    # A small p-value (e.g., < 0.05) suggests the effect is unlikely to be zero — it’s "statistically significant."
+    # A large p-value means the data do not provide strong evidence that the effect is different from zero.
+
+    if progress_callback: progress_callback(16)
+    logger.info("16")
+
+    # Step 16: Plot GLM results
+    fig_glm_result = plot_glm_results(file_path, raw_haemo, glm_est, design_matrix)
+    fig_individual["GLM"] = fig_glm_result
+    if progress_callback: progress_callback(17)
+    logger.info("17")
+
+    # Step 17: Plot channel significance
+    fig_significance = individual_significance(raw_haemo, glm_est)
+    fig_individual["Significance"] = fig_significance
+    if progress_callback: progress_callback(18)
+    logger.info("18")
+
+    # Step 18: cha, con, roi
+    cha = glm_est.to_dataframe()
+
+    # HACK: Comment out line 588 (self._renderer.show()) in _brain.py from MNE
+    # brain_thing = brain_3d_visualization(cha, raw_haemo)
+    # brain_individual.append(brain_thing)
+    # C++ objects made this get rendered on the fly
+
+    rois = dict(AllChannels=range(len(raw_haemo.ch_names)))
+    # Calculate ROI for all conditions
+    conditions = design_matrix.columns
+    # Compute output metrics by ROI
+    df_ind = glm_est.to_dataframe_region_of_interest(rois, conditions)
+
+    df_ind["ID"] = file_path
+
+    # Step 19: Fold channels
+    # fig_fold_data, fig_fold_legend = fold_channels(raw_haemo)
+    # fig_individual.append(fig_fold_data)
+    # fig_individual.append(fig_fold_legend)
+
+    logger.info("Design matrix:\n%s", design_matrix)
+
+    contrast_matrix = np.eye(design_matrix.shape[1])
+    basic_conts = dict(
+        [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)]
+    )
+
+    all_delay_cols = [col for col in design_matrix.columns if "_delay_" in col]
+    all_conditions = sorted({col.split("_delay_")[0] for col in all_delay_cols})
+
+    if not all_conditions:
+        raise ValueError("No FIR regressors found in the design matrix.")
+
+    # Build contrast vectors for each condition
+    contrast_dict = {}
+
+    for condition in all_conditions:
+        delay_cols = [col for col in all_delay_cols if col.startswith(f"{condition}_delay_")]
+
+        if not delay_cols:
+            continue  # skip if no columns found (shouldn't happen)
+
+        # Average across all delay regressors for this condition
+        contrast_vector = np.sum([basic_conts[col] for col in delay_cols], axis=0)
+        contrast_vector /= len(delay_cols)
+
+        contrast_dict[condition] = contrast_vector
+
+    # Compute contrast results
+    contrast_results = {}
+
+    for cond, contrast_vector in contrast_dict.items():
+        contrast = glm_est.compute_contrast(contrast_vector)  # type: ignore
+        df = contrast.to_dataframe()
+        df["ID"] = file_path
+        contrast_results[cond] = df
+
+    cha["ID"] = file_path
+
+    fig_bytes = convert_fig_dict_to_png_bytes(fig_individual)
+
+    return raw_haemo, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, AGE, GENDER, GROUP, True
+
+# Not 3000 lines yay!
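+
+# Illustrative sketch only (not called by the pipeline): reproducing the
+# two-tailed p-value described in process_participant() from a t-statistic
+# and its degrees of freedom. The values below are hypothetical.
+if __name__ == "__main__":
+    from scipy import stats
+
+    t_statistic, dof = 2.5, 20
+    p_value = 2 * stats.t.cdf(-abs(t_statistic), dof)
+    print(f"Two-tailed p-value for t={t_statistic}, df={dof}: {p_value:.4f}")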
\ No newline at end of file diff --git a/main.py b/main.py index f0d6ed2..c4a6750 100644 --- a/main.py +++ b/main.py @@ -22,19 +22,27 @@ from pathlib import Path from datetime import datetime from multiprocessing import Process, current_process, freeze_support, Manager +import numpy as np +import pandas as pd + # External library imports -import mne import psutil import requests + +from mne.io import read_raw_snirf from mne.preprocessing.nirs import source_detector_distances +from mne_nirs.io import write_raw_snirf +from mne.channels import make_dig_montage + from PySide6.QtWidgets import ( QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox, QGridLayout, - QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox + QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView, ) -from PySide6.QtCore import QThread, Signal, Qt, QTimer -from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator +from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize +from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel, QStandardItem -CURRENT_VERSION = "1.0.0" + +CURRENT_VERSION = "1.0.1" API_URL = "https://git.research.dezeeuw.ca/api/v1/repos/tyler/flares/releases" PLATFORM_NAME = platform.system().lower() @@ -47,64 +55,39 @@ SECTIONS = [ {"name": "SECONDS_TO_STRIP", "default": 0, "type": int, "help": "Seconds to remove from beginning of file. Setting this to 0 will remove nothing from the file."}, {"name": "DOWNSAMPLE", "default": True, "type": bool, "help": "Downsample snirf files."}, {"name": "DOWNSAMPLE_FREQUENCY", "default": 25, "type": int, "help": "Frequency (Hz) to downsample to. 
If this is set higher than the input data, new data will be interpolated."}, - {"name": "FORCE_DROP_CHANNELS", "default": "", "type": str, "help": "Channels to forcibly drop (comma separated)."}, - {"name": "SOURCE_DETECTOR_SEPARATOR", "default": "_", "type": str, "help": "Separator between source and detector."}, ] }, { - "title": "Update Optode Positions", - "params": [ - {"name": "OPTODE_FILE", "default": True, "type": bool, "help": "Use optode file to update positions."}, - {"name": "OPTODE_FILE_PATH", "default": "", "type": str, "help": "Optode file location."}, - {"name": "OPTODE_FILE_SEPARATOR", "default": ":", "type": str, "help": "Separator in optode file."}, - ] - }, - { - "title": "Temporal Derivative Distribution Repair filtering", - "params": [ - {"name": "TDDR", "default": True, "type": bool, "help": "Apply TDDR filtering."}, - ] - }, - { - "title": "Wavelet filtering", - "params": [ - {"name": "WAVELET", "default": True, "type": bool, "help": "Apply Wavelet filtering."}, - {"name": "IQR", "default": 1.5, "type": float, "help": "Interquartile Range for Wavelet filter."}, - ] - }, - { - "title": "Heart rate", - "params": [ - {"name": "HEART_RATE", "default": True, "type": bool, "help": "Calculate heart rate."}, - {"name": "SECONDS_TO_STRIP_HR", "default": 5, "type": int, "help": "Seconds to strip for HR calculation."}, - {"name": "MAX_LOW_HR", "default": 40, "type": int, "help": "Minimum heart rate value."}, - {"name": "MAX_HIGH_HR", "default": 200, "type": int, "help": "Maximum heart rate value."}, - {"name": "SMOOTHING_WINDOW_HR", "default": 100, "type": int, "help": "Rolling average window for HR."}, - {"name": "HEART_RATE_WINDOW", "default": 25, "type": int, "help": "Range of BPM around average."}, - {"name": "SHORT_CHANNEL", "default": True, "type": bool, "help": "Indicates if data has short channel."}, - {"name": "SHORT_CHANNEL_THRESH", "default": 0.013, "type": float, "help": "Threshold for short channel (m)."}, - ] - }, - { - "title": "Scalp Coupling Index / Peak Spectral Power / Signal to Noise Ratio", + "title": "Scalp Coupling Index", "params": [ {"name": "SCI", "default": True, "type": bool, "help": "Calculate Scalp Coupling Index."}, {"name": "SCI_TIME_WINDOW", "default": 3, "type": int, "help": "SCI time window."}, {"name": "SCI_THRESHOLD", "default": 0.6, "type": float, "help": "SCI threshold (0-1)."}, - {"name": "PSP", "default": True, "type": bool, "help": "Calculate Peak Spectral Power."}, - {"name": "PSP_TIME_WINDOW", "default": 3, "type": int, "help": "PSP time window."}, - {"name": "PSP_THRESHOLD", "default": 0.1, "type": float, "help": "PSP threshold."}, - {"name": "SNR", "default": True, "type": bool, "help": "Calculate Signal to Noise Ratio."}, - {"name": "SNR_TIME_WINDOW", "default": -1, "type": int, "help": "SNR time window."}, - {"name": "SNR_THRESHOLD", "default": 2.0, "type": float, "help": "SNR threshold (dB)."}, ] }, { - "title": "Drop bad channels", + "title": "Signal to Noise Ratio", "params": [ - {"name": "EXCLUDE_CHANNELS", "default": True, "type": bool, "help": "Drop channels failing metrics."}, - {"name": "MAX_BAD_CHANNELS", "default": 15, "type": int, "help": "Max bad channels allowed."}, - {"name": "LONG_CHANNEL_THRESH", "default": 0.045, "type": float, "help": "Max distance (m) for channel."}, + {"name": "SNR", "default": True, "type": bool, "help": "Calculate Signal to Noise Ratio."}, + # {"name": "SNR_TIME_WINDOW", "default": -1, "type": int, "help": "SNR time window."}, + {"name": "SNR_THRESHOLD", "default": 5.0, "type": float, 
"help": "SNR threshold (dB)."}, + ] + }, + { + "title": "Peak Spectral Power", + "params": [ + + {"name": "PSP", "default": True, "type": bool, "help": "Calculate Peak Spectral Power."}, + {"name": "PSP_TIME_WINDOW", "default": 3, "type": int, "help": "PSP time window."}, + {"name": "PSP_THRESHOLD", "default": 0.1, "type": float, "help": "PSP threshold."}, + ] + }, + { + "title": "Bad Channels Handling", + "params": [ + # {"name": "NOT_IMPLEMENTED", "default": True, "type": bool, "help": "Calculate Peak Spectral Power."}, + # {"name": "NOT_IMPLEMENTED", "default": 3, "type": int, "help": "PSP time window."}, + # {"name": "NOT_IMPLEMENTED", "default": 0.1, "type": float, "help": "PSP threshold."}, ] }, { @@ -113,24 +96,54 @@ SECTIONS = [ # Intentionally empty (TODO) ] }, + { + "title": "Temporal Derivative Distribution Repair filtering", + "params": [ + {"name": "TDDR", "default": True, "type": bool, "help": "Apply TDDR filtering."}, + ] + }, { "title": "Haemoglobin Concentration", "params": [ # Intentionally empty (TODO) ] }, + { + "title": "Enhance Negative Correlation", + "params": [ + #{"name": "ENHANCE_NEGATIVE_CORRELATION", "default": False, "type": bool, "help": "Calculate Peak Spectral Power."}, + ] + }, + { + "title": "Filtering", + "params": [ + #{"name": "FILTER", "default": True, "type": bool, "help": "Calculate Peak Spectral Power."}, + ] + }, + { + "title": "Extracting Events", + "params": [ + #{"name": "EVENTS", "default": True, "type": bool, "help": "Calculate Peak Spectral Power."}, + ] + }, + { + "title": "Epoch Calculations", + "params": [ + #{"name": "EVENTS", "default": True, "type": bool, "help": "Calculate Peak Spectral Power."}, + ] + }, { "title": "Design Matrix", "params": [ - {"name": "DRIFT_MODEL", "default": "cosine", "type": str, "help": "Drift model for GLM."}, - {"name": "DURATION_BETWEEN_ACTIVITIES", "default": 35, "type": int, "help": "Time between activities (s)."}, - {"name": "SHORT_CHANNEL_REGRESSION", "default": True, "type": bool, "help": "Use short channel regression."}, + # {"name": "DRIFT_MODEL", "default": "cosine", "type": str, "help": "Drift model for GLM."}, + # {"name": "DURATION_BETWEEN_ACTIVITIES", "default": 35, "type": int, "help": "Time between activities (s)."}, + # {"name": "SHORT_CHANNEL_REGRESSION", "default": True, "type": bool, "help": "Use short channel regression."}, ] }, { "title": "General Linear Model", "params": [ - {"name": "N_JOBS", "default": 1, "type": int, "help": "Number of jobs for GLM processing."}, + #{"name": "N_JOBS", "default": 1, "type": int, "help": "Number of jobs for GLM processing."}, ] }, { @@ -369,6 +382,239 @@ class UserGuideWindow(QWidget): +class UpdateOptodesWindow(QWidget): + + def __init__(self, parent=None): + super().__init__(parent, Qt.WindowType.Window) + self.setWindowTitle("Update optode positions") + self.resize(760, 200) + + self.label_file_a = QLabel("SNIRF file:") + self.line_edit_file_a = QLineEdit() + self.line_edit_file_a.setReadOnly(True) + self.btn_browse_a = QPushButton("Browse .snirf") + self.btn_browse_a.clicked.connect(self.browse_file_a) + + self.label_file_b = QLabel("TXT file:") + self.line_edit_file_b = QLineEdit() + self.line_edit_file_b.setReadOnly(True) + self.btn_browse_b = QPushButton("Browse .txt") + self.btn_browse_b.clicked.connect(self.browse_file_b) + + self.label_suffix = QLabel("Suffix to append to filename:") + self.line_edit_suffix = QLineEdit() + self.line_edit_suffix.setText("flare") + + self.btn_clear = QPushButton("Clear") + self.btn_go = QPushButton("Go") + 
self.btn_clear.clicked.connect(self.clear_files) + self.btn_go.clicked.connect(self.go_action) + + # --- + layout = QVBoxLayout() + self.description = QLabel() + self.description.setTextFormat(Qt.TextFormat.RichText) + self.description.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction) + self.description.setOpenExternalLinks(False) # Handle the click internally + + self.description.setText("Some software when creating snirf files will insert a template of optode positions as the correct position of the optodes for the participant.
" + "This is rarely correct as each head differs slightly in shape or size, and a lot of calculations require the optodes to be in the correct location.
" + "Using a .txt file, we can update the positions in the snirf file to match those of a digitization system such as one from Polhemus or elsewhere.
" + "The .txt file should have the fiducials, detectors, and sources clearly labeled, followed by the x, y, and z coordinates seperated by a space.
" + "An example format of what a digitization text file should look like can be found by clicking here.") + + self.description.linkActivated.connect(self.handle_link_click) + layout.addWidget(self.description) + + help_text_a = "Select the SNIRF (.snirf) file to update with new optode positions." + + file_a_layout = QHBoxLayout() + + # Help button on the left + help_btn_a = QPushButton("?") + help_btn_a.setFixedWidth(25) + help_btn_a.setToolTip(help_text_a) + help_btn_a.clicked.connect(lambda _, text=help_text_a: self.show_help_popup(text)) + file_a_layout.addWidget(help_btn_a) + + # Container for label + line_edit + browse button with tooltip + file_a_container = QWidget() + file_a_container_layout = QHBoxLayout() + file_a_container_layout.setContentsMargins(0, 0, 0, 0) + file_a_container_layout.addWidget(self.label_file_a) + file_a_container_layout.addWidget(self.line_edit_file_a) + file_a_container_layout.addWidget(self.btn_browse_a) + file_a_container.setLayout(file_a_container_layout) + file_a_container.setToolTip(help_text_a) + + file_a_layout.addWidget(file_a_container) + layout.addLayout(file_a_layout) + + + help_text_b = "Provide a .txt file with labeled optodes (e.g., nz, rpa, lpa, d1, s1) and their x, y, z coordinates." + + file_b_layout = QHBoxLayout() + + help_btn_b = QPushButton("?") + help_btn_b.setFixedWidth(25) + help_btn_b.setToolTip(help_text_b) + help_btn_b.clicked.connect(lambda _, text=help_text_b: self.show_help_popup(text)) + file_b_layout.addWidget(help_btn_b) + + file_b_container = QWidget() + file_b_container_layout = QHBoxLayout() + file_b_container_layout.setContentsMargins(0, 0, 0, 0) + file_b_container_layout.addWidget(self.label_file_b) + file_b_container_layout.addWidget(self.line_edit_file_b) + file_b_container_layout.addWidget(self.btn_browse_b) + file_b_container.setLayout(file_b_container_layout) + file_b_container.setToolTip(help_text_b) + + file_b_layout.addWidget(file_b_container) + layout.addLayout(file_b_layout) + + + help_text_suffix = "This text will be appended to the original filename when saving. Default is 'flare'." 
+ + suffix_layout = QHBoxLayout() + + help_btn_suffix = QPushButton("?") + help_btn_suffix.setFixedWidth(25) + help_btn_suffix.setToolTip(help_text_suffix) + help_btn_suffix.clicked.connect(lambda _, text=help_text_suffix: self.show_help_popup(text)) + suffix_layout.addWidget(help_btn_suffix) + + suffix_container = QWidget() + suffix_container_layout = QHBoxLayout() + suffix_container_layout.setContentsMargins(0, 0, 0, 0) + suffix_container_layout.addWidget(self.label_suffix) + suffix_container_layout.addWidget(self.line_edit_suffix) + suffix_container.setLayout(suffix_container_layout) + suffix_container.setToolTip(help_text_suffix) + + suffix_layout.addWidget(suffix_container) + layout.addLayout(suffix_layout) + + + buttons_layout = QHBoxLayout() + buttons_layout.addStretch() + buttons_layout.addWidget(self.btn_clear) + buttons_layout.addWidget(self.btn_go) + layout.addLayout(buttons_layout) + + self.setLayout(layout) + + + def show_help_popup(self, text): + msg = QMessageBox(self) + msg.setWindowTitle("Parameter Info") + msg.setText(text) + msg.exec() + + def handle_link_click(self, link): + if link == "custom_link": + msg = QMessageBox(self) + msg.setWindowTitle("Example Digitization File") + + text = "nz: -1.91 85.175 -31.1525\n" \ + "rpa: 80.3825 -17.1925 -57.2775\n" \ + "lpa: -81.815 -17.1925 -57.965\n" \ + "d1: 0.01 -97.5175 62.5875\n" \ + "d2: 25.125 -103.415 45.045\n" \ + "d3: 49.095 -97.9025 30.2075\n" \ + "s1: 0.01 -112.43 32.595\n" \ + "s2: 30.325 -84.3125 71.8975\n" \ + "s3: 0.01 -70.6875 89.0925\n" + msg.setText(text) + msg.exec() + + def browse_file_a(self): + file_path, _ = QFileDialog.getOpenFileName(self, "Select SNIRF File", "", "SNIRF Files (*.snirf)") + if file_path: + self.line_edit_file_a.setText(file_path) + + def browse_file_b(self): + file_path, _ = QFileDialog.getOpenFileName(self, "Select TXT File", "", "Text Files (*.txt)") + if file_path: + self.line_edit_file_b.setText(file_path) + + def clear_files(self): + self.line_edit_file_a.clear() + self.line_edit_file_b.clear() + + def go_action(self): + file_a = self.line_edit_file_a.text() + file_b = self.line_edit_file_b.text() + suffix = self.line_edit_suffix.text().strip() or "flare" + + if not file_a: + QMessageBox.critical(self, "Missing File", "Please select a SNIRF file.") + return + if not file_b: + QMessageBox.critical(self, "Missing File", "Please select a TXT file.") + return + + # Get original filename without extension + base_name = os.path.splitext(os.path.basename(file_a))[0] + suggested_name = f"{base_name}_{suffix}.snirf" + + # Open save dialog with default name + save_path, _ = QFileDialog.getSaveFileName( + self, + "Save SNIRF File As", + suggested_name, + "SNIRF Files (*.snirf)" + ) + + if not save_path: + print("Save cancelled.") + return + + # Ensure .snirf extension + if not save_path.lower().endswith(".snirf"): + save_path += ".snirf" + + try: + self.update_optode_positions(file_a=file_a, file_b=file_b, save_path=save_path) + except Exception as e: + QMessageBox.critical(self, "Error", f"Failed to write file:\n{e}") + return + + QMessageBox.information(self, "File Saved", f"File was saved to:\n{save_path}") + + + def update_optode_positions(self, file_a, file_b, save_path): + + fiducials = {} + ch_positions = {} + + # Read the lines from the optode file + with open(file_b, 'r') as f: + for line in f: + if line.strip(): + # Split by the semicolon and convert to meters + ch_name, coords_str = line.split(":") + coords = np.array(list(map(float, coords_str.strip().split()))) * 0.001 + + # The 
key we have is a fiducial + if ch_name.lower() in ['lpa', 'nz', 'rpa']: + fiducials[ch_name.lower()] = coords + + # The key we have is a source or detector + else: + ch_positions[ch_name.upper()] = coords + + # Create montage with updated coords in head space + initial_montage = make_dig_montage(ch_pos=ch_positions, nasion=fiducials.get('nz'), lpa=fiducials.get('lpa'), rpa=fiducials.get('rpa'), coord_frame='head') # type: ignore + + # Read the SNIRF file, set the montage, and write it back + raw = read_raw_snirf(file_a, preload=True) + raw.set_montage(initial_montage) + write_raw_snirf(raw, save_path) + + + + class ProgressBubble(QWidget): """ A clickable widget displaying a progress bar made of colored rectangles and a label. @@ -386,6 +632,7 @@ class ProgressBubble(QWidget): self.layout = QVBoxLayout() self.label = QLabel(display_name) + self.base_text = display_name self.label.setAlignment(Qt.AlignmentFlag.AlignCenter) self.label.setStyleSheet(""" QLabel { @@ -418,7 +665,7 @@ class ProgressBubble(QWidget): self.setCursor(Qt.CursorShape.PointingHandCursor) # Resize policy to make bubbles responsive - # TODO: Not only do this once but when window is resized too + # TODO: Not only do this once but when window is resized too. Also just doesnt work self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) def update_progress(self, step_index): @@ -435,6 +682,12 @@ class ProgressBubble(QWidget): self.clicked.emit(self) super().mousePressEvent(event) + def setSuffixText(self, suffix): + if suffix: + self.label.setText(f"{self.base_text} {suffix}") + else: + self.label.setText(self.base_text) + class ParamSection(QWidget): @@ -513,7 +766,7 @@ class ParamSection(QWidget): widget = QLineEdit() widget.setValidator(QDoubleValidator()) widget.setText(str(param["default"])) - else: # str or list, treat as text for now + else: widget = QLineEdit() widget.setText(str(param["default"])) @@ -523,7 +776,10 @@ class ParamSection(QWidget): h_layout.setStretch(2, 5) # Set stretch factor for input field (50%) layout.addLayout(h_layout) - self.widgets[param["name"]] = widget + self.widgets[param["name"]] = { + "widget": widget, + "type": param["type"] + } def show_help_popup(self, text): @@ -533,74 +789,1827 @@ class ParamSection(QWidget): msg.exec() + def get_param_values(self): + values = {} + for name, info in self.widgets.items(): + widget = info["widget"] + expected_type = info["type"] -class ViewerWindow(QWidget): - """ - Window displaying various fNIRS data visualization and analysis options via buttons. + if expected_type == bool: + values[name] = widget.currentText() == "True" + else: + raw_text = widget.text() + try: + if expected_type == int: + values[name] = int(raw_text) + elif expected_type == float: + values[name] = float(raw_text) + elif expected_type == str: + values[name] = raw_text + elif expected_type == list: + # Very basic CSV parsing - fix? + values[name] = [x.strip() for x in raw_text.split(",") if x.strip()] + else: + values[name] = raw_text # Fallback + except Exception as e: + raise ValueError(f"Invalid value for {name}: {raw_text}") from e - Args: - all_results (dict): Analysis results data. - all_haemo (dict): Haemodynamic data per subject. - all_figures (dict): Figures generated from the data. - config_snapshot (dict): Configuration snapshot used for analysis. - parent (QWidget, optional): Parent widget. Defaults to None. 
- """ - - def __init__(self, all_results, all_haemo, all_figures, config_snapshot, parent=None): - try: - super().__init__(parent, Qt.WindowType.Window) + return values + - self.all_results = all_results - self.all_haemo = all_haemo - self.all_figures = all_figures - self.config_snapshot = config_snapshot +class FullClickLineEdit(QLineEdit): + def mousePressEvent(self, event): + combo = self.parent() + if isinstance(combo, QComboBox): + combo.showPopup() + super().mousePressEvent(event) - if not self.all_haemo: - QMessageBox.critical(self, "Data Error", "No haemodynamic data available!") + +class FullClickComboBox(QComboBox): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setLineEdit(FullClickLineEdit(self)) + self.lineEdit().setReadOnly(True) + + + +class ParticipantViewerWidget(QWidget): + def __init__(self, haemo_dict, fig_bytes_dict): + super().__init__() + self.setWindowTitle("FLARES Participant Viewer") + self.haemo_dict = haemo_dict + self.fig_bytes_dict = fig_bytes_dict + + # Create mappings: file_path -> participant label and dropdown display text + self.participant_map = {} # file_path -> "Participant 1" + self.participant_dropdown_items = [] # "Participant 1 (filename)" + + for i, file_path in enumerate(self.haemo_dict.keys(), start=1): + short_label = f"Participant {i}" + display_label = f"{short_label} ({os.path.basename(file_path)})" + self.participant_map[file_path] = short_label + self.participant_dropdown_items.append(display_label) + + self.layout = QVBoxLayout(self) + self.top_bar = QHBoxLayout() + self.layout.addLayout(self.top_bar) + + self.participant_dropdown = self._create_multiselect_dropdown(self.participant_dropdown_items) + self.participant_dropdown.currentIndexChanged.connect(self.update_participant_dropdown_label) + + first_fig_dict = next(iter(self.fig_bytes_dict.values())) + image_label_items = list(first_fig_dict.keys()) + + self.image_index_dropdown = self._create_multiselect_dropdown(image_label_items) + self.image_index_dropdown.currentIndexChanged.connect(self.update_image_index_dropdown_label) + + self.submit_button = QPushButton("Submit") + self.submit_button.clicked.connect(self.show_selected_images) + + self.top_bar.addWidget(QLabel("Participants:")) + self.top_bar.addWidget(self.participant_dropdown) + self.top_bar.addWidget(QLabel("Image Indexes:")) + self.top_bar.addWidget(self.image_index_dropdown) + self.top_bar.addWidget(self.submit_button) + + self.scroll = QScrollArea() + self.scroll.setWidgetResizable(True) + self.scroll_content = QWidget() + self.grid_layout = QGridLayout(self.scroll_content) + self.scroll.setWidget(self.scroll_content) + self.layout.addWidget(self.scroll) + + self.thumb_size = QSize(280, 180) + + self.save_button = QPushButton("Save Displayed Images") + self.save_button.clicked.connect(self.save_displayed_images) + self.top_bar.addWidget(self.save_button) + + self.showMaximized() + + def _create_multiselect_dropdown(self, items): + combo = FullClickComboBox() + combo.setView(QListView()) + model = QStandardItemModel() + combo.setModel(model) + combo.setEditable(True) + combo.lineEdit().setReadOnly(True) + combo.lineEdit().setPlaceholderText("Select...") + + dummy_item = QStandardItem("") + dummy_item.setFlags(Qt.ItemIsEnabled) + model.appendRow(dummy_item) + + toggle_item = QStandardItem("Toggle Select All") + toggle_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + toggle_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(toggle_item) + + for item in items: + 
standard_item = QStandardItem(item) + standard_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + standard_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(standard_item) + + combo.setInsertPolicy(QComboBox.NoInsert) + + + def on_view_clicked(index): + item = model.itemFromIndex(index) + if item.isCheckable(): + new_state = Qt.Checked if item.checkState() == Qt.Unchecked else Qt.Unchecked + item.setCheckState(new_state) + + combo.view().pressed.connect(on_view_clicked) + + self._updating_checkstates = False + + def on_item_changed(item): + if self._updating_checkstates: return - - subjects_dir = resource_path("mne_data/MNE-sample-data/subjects") - os.environ["SUBJECTS_DIR"] = subjects_dir + self._updating_checkstates = True - # TODO: Thread all of this to not freeze main window - import fNIRS_module - fNIRS_module.set_config(self.config_snapshot, True) # Set globals in this process + normal_items = [model.item(i) for i in range(2, model.rowCount())] # skip dummy and toggle - layout = QVBoxLayout() - button_actions = [ - ("Show all Images", lambda: fNIRS_module.show_all_images(self.all_figures)), - ("save_all_images", lambda: fNIRS_module.save_all_images(self.all_figures)), - ("data_to_csv", lambda: fNIRS_module.data_to_csv(self.all_results)), - ("plot_2d_theta_graph", lambda: fNIRS_module.plot_2d_theta_graph(self.all_results)), - ("verify_channel_positions", lambda: fNIRS_module.verify_channel_positions(self.all_haemo[list(self.all_haemo.keys())[0]]["full_layout"])), - ("brain_landmarks_3d", lambda: fNIRS_module.brain_landmarks_3d(self.all_haemo[list(self.all_haemo.keys())[0]]["full_layout"], 'all')), - ("plot_2d_3d_contrasts_between_groups", lambda: fNIRS_module.plot_2d_3d_contrasts_between_groups(self.all_results, self.all_haemo, 'theta', 'all', True)), - ("brain_3d_visualization", lambda: fNIRS_module.brain_3d_visualization(self.all_results, self.all_haemo, 0, 't', 'all', True)), - ("plot_fir_model_results", lambda: fNIRS_module.plot_fir_model_results(self.all_results, self.all_haemo, 0, 'theta')), - ("plot_individual_theta_averages", lambda: fNIRS_module.plot_individual_theta_averages(self.all_results)), - ("plot_group_theta_averages", lambda: fNIRS_module.plot_group_theta_averages(self.all_results)), - ("plot_avg_significant_activity", lambda: fNIRS_module.plot_avg_significant_activity(self.all_haemo[list(self.all_haemo.keys())[0]]["full_layout"], self.all_results, 'theta')), - ("fold_channels", lambda: fNIRS_module.fold_channels(self.all_haemo[list(self.all_haemo.keys())[0]]["full_layout"].copy(), self.all_results, resource_path("mne_data/fOLD/fOLD-public-master/Supplementary"))), + if item == toggle_item: + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + if all_checked: + for i in normal_items: + i.setCheckState(Qt.Unchecked) + toggle_item.setCheckState(Qt.Unchecked) + else: + for i in normal_items: + i.setCheckState(Qt.Checked) + toggle_item.setCheckState(Qt.Checked) + + elif item == dummy_item: + pass + + else: + # When normal items change, update toggle item + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + toggle_item.setCheckState(Qt.Checked if all_checked else Qt.Unchecked) + + # Update label text immediately after change + if combo == self.participant_dropdown: + self.update_participant_dropdown_label() + elif combo == self.image_index_dropdown: + self.update_image_index_dropdown_label() + + self._updating_checkstates = False + + model.itemChanged.connect(on_item_changed) + + 
combo.setInsertPolicy(QComboBox.NoInsert)
+        return combo
+
+    def _get_checked_items(self, combo):
+        checked = []
+        model = combo.model()
+        for i in range(model.rowCount()):
+            item = model.item(i)
+            # Skip dummy and toggle items:
+            if item.text() in ("", "Toggle Select All"):
+                continue
+            if item.checkState() == Qt.Checked:
+                checked.append(item.text())
+        return checked
+
+    def update_participant_dropdown_label(self):
+        selected = self._get_checked_items(self.participant_dropdown)
+        if not selected:
+            self.participant_dropdown.lineEdit().setText("")
+        else:
+            # Extract just "Participant N" from "Participant N (filename)"
+            selected_short = [s.split(" ")[0] + " " + s.split(" ")[1] for s in selected]
+            self.participant_dropdown.lineEdit().setText(", ".join(selected_short))
+
+    def update_image_index_dropdown_label(self):
+        selected = self._get_checked_items(self.image_index_dropdown)
+        if not selected:
+            self.image_index_dropdown.lineEdit().setText("")
+        else:
+            # Only show the index part
+            self.image_index_dropdown.lineEdit().setText(", ".join(selected))
+
+    def show_selected_images(self):
+        # Clear previous images
+        for i in reversed(range(self.grid_layout.count())):
+            widget = self.grid_layout.itemAt(i).widget()
+            if widget:
+                widget.setParent(None)
+
+        selected_display_names = self._get_checked_items(self.participant_dropdown)
+        # Map from display names back to file paths
+        selected_file_paths = []
+        for display_name in selected_display_names:
+            # Find file_path by matching display name
+            for fp, short_label in self.participant_map.items():
+                expected_display = f"{short_label} ({os.path.basename(fp)})"
+                if display_name == expected_display:
+                    selected_file_paths.append(fp)
+                    break
+
+        selected_labels = self._get_checked_items(self.image_index_dropdown)
+
+        row, col = 0, 0
+        for file_path in selected_file_paths:
+            # fig_bytes_dict maps file_path -> {label: png_bytes}, so default to {}
+            fig_dict = self.fig_bytes_dict.get(file_path, {})
+            participant_label = self.participant_map[file_path]
+            for label in selected_labels:
+                fig_bytes = fig_dict.get(label)
+                if not fig_bytes:
+                    continue
+
+                full_pixmap = QPixmap()
+                full_pixmap.loadFromData(fig_bytes)
+
+                thumbnail_pixmap = full_pixmap.scaled(
+                    self.thumb_size,
+                    Qt.AspectRatioMode.KeepAspectRatio,
+                    Qt.TransformationMode.SmoothTransformation
+                )
+
+                container = QWidget()
+                hlayout = QHBoxLayout(container)
+                hlayout.setContentsMargins(0, 0, 0, 0)
+                hlayout.setSpacing(0)
+                hlayout.setAlignment(Qt.AlignmentFlag.AlignCenter)
+
+                image_label = ClickableLabel(full_pixmap, thumbnail_pixmap)
+                image_label.setToolTip(f"{participant_label}\n{label}")
+                hlayout.addWidget(image_label)
+
+                self.grid_layout.addWidget(container, row, col)
+
+                col += 1
+                if col >= 6:
+                    col = 0
+                    row += 1
+
+        # Update dropdown labels after display
+        self.update_participant_dropdown_label()
+        self.update_image_index_dropdown_label()
+
+    def save_displayed_images(self):
+        # Ensure the folder exists
+        save_dir = Path("individual_images")
+        save_dir.mkdir(exist_ok=True)
+
+        selected_display_names = self._get_checked_items(self.participant_dropdown)
+        selected_image_labels = self._get_checked_items(self.image_index_dropdown)
+
+        for display_name in selected_display_names:
+            # Match display name to file path
+            for file_path, short_label in self.participant_map.items():
+                expected_display = f"{short_label} ({os.path.basename(file_path)})"
+                if display_name == expected_display:
+                    fig_dict = self.fig_bytes_dict.get(file_path, {})
+                    for label in selected_image_labels:
+                        if label not in fig_dict:
+                            continue
+                        fig_bytes = fig_dict[label]
+                        timestamp =
datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"{os.path.basename(file_path)}_{label}_{timestamp}.png" + output_path = save_dir / filename + with open(output_path, "wb") as f: + f.write(fig_bytes) + break # file_path matched; stop loop + + QMessageBox.information(self, "Save Complete", f"Images saved to {save_dir.resolve()}") + + +class ClickableLabel(QLabel): + def __init__(self, full_pixmap: QPixmap, thumbnail_pixmap: QPixmap): + super().__init__() + self._pixmap_full = full_pixmap + self.setPixmap(thumbnail_pixmap) + self.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.setFixedSize(thumbnail_pixmap.size()) + self.setStyleSheet("border: 1px solid gray; margin: 2px;") + + def mousePressEvent(self, event): + viewer = QWidget() + viewer.setWindowTitle("Expanded View") + layout = QVBoxLayout(viewer) + label = QLabel() + label.setPixmap(self._pixmap_full) + label.setAlignment(Qt.AlignmentFlag.AlignCenter) + layout.addWidget(label) + viewer.resize(1000, 800) + viewer.show() + self._expanded_viewer = viewer # keep reference alive + + + +class ParticipantBrainViewerWidget(QWidget): + def __init__(self, haemo_dict, cha_dict): + super().__init__() + self.setWindowTitle("FLARES Participant Brain Viewer") + self.haemo_dict = haemo_dict + self.cha_dict = cha_dict + + # Create mappings: file_path -> participant label and dropdown display text + self.participant_map = {} # file_path -> "Participant 1" + self.participant_dropdown_items = [] # "Participant 1 (filename)" + + for i, file_path in enumerate(self.haemo_dict.keys(), start=1): + short_label = f"Participant {i}" + display_label = f"{short_label} ({os.path.basename(file_path)})" + self.participant_map[file_path] = short_label + self.participant_dropdown_items.append(display_label) + + self.layout = QVBoxLayout(self) + self.top_bar = QHBoxLayout() + self.layout.addLayout(self.top_bar) + + self.participant_dropdown = self._create_multiselect_dropdown(self.participant_dropdown_items) + self.participant_dropdown.currentIndexChanged.connect(self.update_participant_dropdown_label) + + self.event_dropdown = QComboBox() + self.event_dropdown.addItem("") + + + self.index_texts = [ + "0 (Brain Landmarks)", + "1 (Brain Activity Visualization)", + # "2 (third image)", + # "3 (fourth image)", ] - for text, func in button_actions: - btn = QPushButton(text) - btn.clicked.connect(self.make_safe_callback(func)) - layout.addWidget(btn) + self.image_index_dropdown = self._create_multiselect_dropdown(self.index_texts) + self.image_index_dropdown.currentIndexChanged.connect(self.update_image_index_dropdown_label) - self.setLayout(layout) + self.submit_button = QPushButton("Submit") + self.submit_button.clicked.connect(self.show_brain_images) - except Exception as e: - QMessageBox.critical(None, "Startup Error", f"ViewerWindow failed:\n{str(e)}") + self.top_bar.addWidget(QLabel("Participants:")) + self.top_bar.addWidget(self.participant_dropdown) + self.top_bar.addWidget(QLabel("Event:")) + self.top_bar.addWidget(self.event_dropdown) + self.top_bar.addWidget(QLabel("Image Indexes:")) + self.top_bar.addWidget(self.image_index_dropdown) + self.top_bar.addWidget(self.submit_button) + + self.scroll = QScrollArea() + self.scroll.setWidgetResizable(True) + self.scroll_content = QWidget() + self.grid_layout = QGridLayout(self.scroll_content) + self.scroll.setWidget(self.scroll_content) + self.layout.addWidget(self.scroll) + + self.thumb_size = QSize(280, 180) + self.showMaximized() + + def _create_multiselect_dropdown(self, items): + combo = FullClickComboBox() + 
combo.setView(QListView()) + model = QStandardItemModel() + combo.setModel(model) + combo.setEditable(True) + combo.lineEdit().setReadOnly(True) + combo.lineEdit().setPlaceholderText("Select...") + + + dummy_item = QStandardItem("") + dummy_item.setFlags(Qt.ItemIsEnabled) + model.appendRow(dummy_item) + + toggle_item = QStandardItem("Toggle Select All") + toggle_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + toggle_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(toggle_item) + + for item in items: + standard_item = QStandardItem(item) + standard_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + standard_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(standard_item) + + combo.setInsertPolicy(QComboBox.NoInsert) + + + def on_view_clicked(index): + item = model.itemFromIndex(index) + if item.isCheckable(): + new_state = Qt.Checked if item.checkState() == Qt.Unchecked else Qt.Unchecked + item.setCheckState(new_state) + + combo.view().pressed.connect(on_view_clicked) + + self._updating_checkstates = False + + def on_item_changed(item): + if self._updating_checkstates: + return + self._updating_checkstates = True + + normal_items = [model.item(i) for i in range(2, model.rowCount())] # skip dummy and toggle + + if item == toggle_item: + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + if all_checked: + for i in normal_items: + i.setCheckState(Qt.Unchecked) + toggle_item.setCheckState(Qt.Unchecked) + else: + for i in normal_items: + i.setCheckState(Qt.Checked) + toggle_item.setCheckState(Qt.Checked) + + elif item == dummy_item: + pass + + else: + # When normal items change, update toggle item + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + toggle_item.setCheckState(Qt.Checked if all_checked else Qt.Unchecked) + + # Update label text immediately after change + if combo == self.participant_dropdown: + self.update_participant_dropdown_label() + elif combo == self.image_index_dropdown: + self.update_image_index_dropdown_label() + + self._updating_checkstates = False + + model.itemChanged.connect(on_item_changed) + + combo.setInsertPolicy(QComboBox.NoInsert) + return combo + + def _get_checked_items(self, combo): + checked = [] + model = combo.model() + for i in range(model.rowCount()): + item = model.item(i) + # Skip dummy and toggle items: + if item.text() in ("", "Toggle Select All"): + continue + if item.checkState() == Qt.Checked: + checked.append(item.text()) + return checked + + def update_participant_dropdown_label(self): + selected = self._get_checked_items(self.participant_dropdown) + if not selected: + self.participant_dropdown.lineEdit().setText("") + else: + # Extract just "Participant N" from "Participant N (filename)" + selected_short = [s.split(" ")[0] + " " + s.split(" ")[1] for s in selected] + self.participant_dropdown.lineEdit().setText(", ".join(selected_short)) + self._update_event_dropdown() + + def update_image_index_dropdown_label(self): + selected = self._get_checked_items(self.image_index_dropdown) + if not selected: + self.image_index_dropdown.lineEdit().setText("") + else: + # Only show the index part + index_labels = [s.split(" ")[0] for s in selected] + self.image_index_dropdown.lineEdit().setText(", ".join(index_labels)) + + def _update_event_dropdown(self): + selected_display_names = self._get_checked_items(self.participant_dropdown) + selected_file_paths = [] + for display_name in selected_display_names: + for fp, short_label in self.participant_map.items(): + 
expected_display = f"{short_label} ({os.path.basename(fp)})" + if display_name == expected_display: + selected_file_paths.append(fp) + break + + if not selected_file_paths: + self.event_dropdown.clear() + self.event_dropdown.addItem("") + return + + annotation_sets = [] + + for file_path in selected_file_paths: + raw = self.haemo_dict.get(file_path) + if raw is None or not hasattr(raw, "annotations"): + continue + annotations = set(raw.annotations.description) + annotation_sets.append(annotations) + + if not annotation_sets: + self.event_dropdown.clear() + self.event_dropdown.addItem("") + return + + shared_annotations = set.intersection(*annotation_sets) + self.event_dropdown.clear() + self.event_dropdown.addItem("") + for ann in sorted(shared_annotations): + self.event_dropdown.addItem(ann) + + def show_brain_images(self): + import flares + + selected_event = self.event_dropdown.currentText() + if selected_event == "": + selected_event = None + + selected_display_names = self._get_checked_items(self.participant_dropdown) + selected_file_paths = [] + for display_name in selected_display_names: + for fp, short_label in self.participant_map.items(): + expected_display = f"{short_label} ({os.path.basename(fp)})" + if display_name == expected_display: + selected_file_paths.append(fp) + break + + selected_indexes = [ + int(s.split(" ")[0]) for s in self._get_checked_items(self.image_index_dropdown) + ] + + + parameterized_indexes = { + 0: [ + { + "key": "show_optodes", + "label": "Determine what is rendered above the brain. Valid values are 'sensors', 'labels', 'none', 'all'.", + "default": "all", + "type": str, + }, + { + "key": "show_brodmann", + "label": "Show common brodmann areas on the brain.", + "default": "True", + "type": bool, + } + ], + 1: [ + { + "key": "show_optodes", + "label": "Determine what is rendered above the brain. Valid values are 'sensors', 'labels', 'none', 'all'.", + "default": "all", + "type": str, + }, + { + "key": "t_or_theta", + "label": "Specify if t values or theta values should be plotted. Valid values are 't', 'theta'", + "default": "theta", + "type": str, + }, + { + "key": "show_text", + "label": "Display informative text on the top left corner. 
THIS DOES NOT WORK AND SHOULD BE LEFT AT FALSE", + "default": "False", + "type": bool, + }, + { + "key": "brain_bounds", + "label": "Graph Upper/Lower Limit", + "default": "1.0", + "type": float, + } + ], + } + + # Inject full_text from index_texts + for idx, params_list in parameterized_indexes.items(): + full_text = self.index_texts[idx] if idx < len(self.index_texts) else f"{idx} (No label found)" + for param_info in params_list: + param_info["full_text"] = full_text + + indexes_needing_params = {idx: parameterized_indexes[idx] for idx in selected_indexes if idx in parameterized_indexes} + + param_values = {} + if indexes_needing_params: + dialog = ParameterInputDialog(indexes_needing_params, parent=self) + if dialog.exec_() == QDialog.Accepted: + param_values = dialog.get_values() + if param_values is None: + return + else: + return + + # Pass the necessary arguments to each method + for file_path in selected_file_paths: + haemo_obj = self.haemo_dict.get(file_path) + + if haemo_obj is None: + continue + + cha = self.cha_dict.get(file_path) + + for idx in selected_indexes: + if idx == 0: + + params = param_values.get(idx, {}) + show_optodes = params.get("show_optodes", None) + show_brodmann = params.get("show_brodmann", None) + + if show_optodes is None or show_brodmann is None: + print(f"Missing parameters for index {idx}, skipping.") + continue + + flares.brain_landmarks_3d(haemo_obj, show_optodes, show_brodmann) + + elif idx == 1: + params = param_values.get(idx, {}) + show_optodes = params.get("show_optodes", None) + t_or_theta = params.get("t_or_theta", None) + show_text = params.get("show_text", None) + brain_bounds = params.get("brain_bounds", None) + + if show_optodes is None or t_or_theta is None or show_text is None or brain_bounds is None: + print(f"Missing parameters for index {idx}, skipping.") + continue + + flares.brain_3d_visualization(haemo_obj, cha, selected_event, t_or_theta=t_or_theta, show_optodes=show_optodes, show_text=show_text, brain_bounds=brain_bounds) + + else: + print(f"No method defined for index {idx}") + + + + +class ClickableLabel(QLabel): + def __init__(self, full_pixmap: QPixmap, thumbnail_pixmap: QPixmap): + super().__init__() + self._pixmap_full = full_pixmap + self.setPixmap(thumbnail_pixmap) + self.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.setFixedSize(thumbnail_pixmap.size()) + self.setStyleSheet("border: 1px solid gray; margin: 2px;") + + def mousePressEvent(self, event): + viewer = QWidget() + viewer.setWindowTitle("Expanded View") + layout = QVBoxLayout(viewer) + label = QLabel() + label.setPixmap(self._pixmap_full) + label.setAlignment(Qt.AlignmentFlag.AlignCenter) + layout.addWidget(label) + viewer.resize(1000, 800) + viewer.show() + self._expanded_viewer = viewer # keep reference alive + + + +class ParameterInputDialog(QDialog): + def __init__(self, params_dict, parent=None): + """ + params_dict format: + { + idx: [ + { + "key": "p_val", + "label": "Significance threshold P-value (e.g. 0.05)", + "default": "0.05", + "type": float, + }, + { + "key": "graph_scale", + "label": "Graph scale factor", + "default": "1", + "type": int, + } + ], + ... + } + """ + super().__init__(parent) + self.setWindowTitle("Input Parameters") + self.params_dict = params_dict + self.inputs = {} # {(idx, param_key): QLineEdit} + + layout = QVBoxLayout(self) + intro_label = QLabel( + "Some methods require parameters to continue:\n" + "Clicking OK will simply use default values if input is left empty." 
+ ) + layout.addWidget(intro_label) + + for idx, param_list in params_dict.items(): + full_text = param_list[0].get('full_text', f"Index [{idx}]") + group_label = QLabel(f"{full_text} requires parameters:") + group_label.setStyleSheet("font-weight: bold; margin-top: 10px;") + layout.addWidget(group_label) + + for param_info in param_list: + label = QLabel(param_info["label"]) + layout.addWidget(label) + + line_edit = QLineEdit(self) + line_edit.setPlaceholderText(str(param_info.get("default", ""))) + layout.addWidget(line_edit) + + self.inputs[(idx, param_info["key"])] = line_edit + + # Buttons + btn_layout = QHBoxLayout() + ok_btn = QPushButton("OK", self) + cancel_btn = QPushButton("Cancel", self) + btn_layout.addWidget(ok_btn) + btn_layout.addWidget(cancel_btn) + layout.addLayout(btn_layout) + + ok_btn.clicked.connect(self.accept) + cancel_btn.clicked.connect(self.reject) + + def get_values(self): + """ + Validate and return values dict in form: + { + idx: { + param_key: value, + ... + }, + ... + } + Returns None if validation fails (error dialog shown). + """ + values = {} + for (idx, param_key), line_edit in self.inputs.items(): + text = line_edit.text().strip() + + # Find param info dict + param_info = None + for p in self.params_dict[idx]: + if p['key'] == param_key: + param_info = p + break + if param_info is None: + # This shouldn't happen, but just in case: + self._show_error(f"Internal error: No param info for index {idx} key '{param_key}'") + return None + + if not text: + text = str(param_info.get('default', '')) + + param_type = param_info.get('type', str) - def make_safe_callback(self, func): - def safe_func(): try: - func() - except Exception as e: - QMessageBox.critical(self, "Error", f"An error occurred:\n{str(e)}") - return safe_func + if param_type == int: + val = int(text) + elif param_type == float: + val = float(text) + elif param_type == bool: + # Convert common bool strings to bool + val_lower = text.lower() + if val_lower in ('true', '1', 'yes', 'y'): + val = True + elif val_lower in ('false', '0', 'no', 'n'): + val = False + else: + raise ValueError(f"Invalid bool value: {text}") + elif param_type == str: + val = text + else: + val = text # fallback + except (ValueError, TypeError): + self._show_error( + f"Invalid input for index {idx} parameter '{param_key}': '{text}'\n" + f"Expected type: {param_type.__name__}" + ) + return None + if idx not in values: + values[idx] = {} + values[idx][param_key] = val + + return values + + def _show_error(self, message): + error_box = QMessageBox(self) + error_box.setIcon(QMessageBox.Critical) + error_box.setWindowTitle("Input Error") + error_box.setText(message) + error_box.exec_() + + + + +class GroupViewerWidget(QWidget): + def __init__(self, haemo_dict, cha, df_ind, design_matrix, contrast_results, group): + super().__init__() + self.setWindowTitle("FLARES Group Viewer") + self.haemo_dict = haemo_dict + self.cha = cha + self.df_ind = df_ind + self.design_matrix = design_matrix + self.contrast_results = contrast_results + self.group = group + + + # Create mappings: file_path -> participant label and dropdown display text + self.participant_map = {} # file_path -> "Participant 1" + self.participant_dropdown_items = [] # "Participant 1 (filename)" + + for i, file_path in enumerate(self.haemo_dict.keys(), start=1): + short_label = f"Participant {i}" + display_label = f"{short_label} ({os.path.basename(file_path)})" + self.participant_map[file_path] = short_label + self.participant_dropdown_items.append(display_label) + + 
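For reference, the text-to-type coercion implemented in ParameterInputDialog.get_values above can be exercised outside Qt. A minimal sketch, assuming plain Python; coerce_param is a hypothetical helper name, and the accepted boolean spellings mirror the dialog's:

    # Sketch of the empty-input fallback and string-to-type coercion used by
    # ParameterInputDialog.get_values (coerce_param is a hypothetical name).
    def coerce_param(text: str, param_type: type, default: str = "") -> object:
        text = text.strip() or default          # empty input falls back to the default
        if param_type is bool:
            lowered = text.lower()
            if lowered in ("true", "1", "yes", "y"):
                return True
            if lowered in ("false", "0", "no", "n"):
                return False
            raise ValueError(f"Invalid bool value: {text}")
        if param_type in (int, float):
            return param_type(text)             # may raise ValueError, as in the dialog
        return text                             # str and fallback

    assert coerce_param("", float, "0.05") == 0.05
    assert coerce_param("Yes", bool) is True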
self.layout = QVBoxLayout(self) + self.top_bar = QHBoxLayout() + self.layout.addLayout(self.top_bar) + + self.group_to_paths = {} + for file_path, group_name in self.group.items(): + self.group_to_paths.setdefault(group_name, []).append(file_path) + + self.group_names = sorted(self.group_to_paths.keys()) + + self.group_dropdown = QComboBox() + self.group_dropdown.addItem("") + self.group_dropdown.addItems(self.group_names) + self.group_dropdown.setCurrentIndex(0) + self.group_dropdown.currentIndexChanged.connect(self.update_participant_list_for_group) + + self.participant_dropdown = self._create_multiselect_dropdown(self.participant_dropdown_items) + self.participant_dropdown.currentIndexChanged.connect(self.update_participant_dropdown_label) + self.participant_dropdown.setEnabled(False) + + self.event_dropdown = QComboBox() + self.event_dropdown.addItem("") + + self.index_texts = [ + "0 (GLM Results)", + "1 (Significance)", + # "2 (third_image)", + # "3 (fourth image)", + ] + + self.image_index_dropdown = self._create_multiselect_dropdown(self.index_texts) + self.image_index_dropdown.currentIndexChanged.connect(self.update_image_index_dropdown_label) + + self.submit_button = QPushButton("Submit") + self.submit_button.clicked.connect(self.show_brain_images) + + self.top_bar.addWidget(QLabel("Group:")) + self.top_bar.addWidget(self.group_dropdown) + self.top_bar.addWidget(QLabel("Participants:")) + self.top_bar.addWidget(self.participant_dropdown) + self.top_bar.addWidget(QLabel("Event:")) + self.top_bar.addWidget(self.event_dropdown) + self.top_bar.addWidget(QLabel("Image Indexes:")) + self.top_bar.addWidget(self.image_index_dropdown) + self.top_bar.addWidget(self.submit_button) + + self.scroll = QScrollArea() + self.scroll.setWidgetResizable(True) + self.scroll_content = QWidget() + self.grid_layout = QGridLayout(self.scroll_content) + self.scroll.setWidget(self.scroll_content) + self.layout.addWidget(self.scroll) + + self.thumb_size = QSize(280, 180) + self.showMaximized() + + def _create_multiselect_dropdown(self, items): + combo = FullClickComboBox() + combo.setView(QListView()) + model = QStandardItemModel() + combo.setModel(model) + combo.setEditable(True) + combo.lineEdit().setReadOnly(True) + combo.lineEdit().setPlaceholderText("Select...") + + + dummy_item = QStandardItem("") + dummy_item.setFlags(Qt.ItemIsEnabled) + model.appendRow(dummy_item) + + toggle_item = QStandardItem("Toggle Select All") + toggle_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + toggle_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(toggle_item) + + for item in items: + standard_item = QStandardItem(item) + standard_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + standard_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(standard_item) + + combo.setInsertPolicy(QComboBox.NoInsert) + + + def on_view_clicked(index): + item = model.itemFromIndex(index) + if item.isCheckable(): + new_state = Qt.Checked if item.checkState() == Qt.Unchecked else Qt.Unchecked + item.setCheckState(new_state) + + combo.view().pressed.connect(on_view_clicked) + + self._updating_checkstates = False + + def on_item_changed(item): + if self._updating_checkstates: + return + self._updating_checkstates = True + + normal_items = [model.item(i) for i in range(2, model.rowCount())] # skip dummy and toggle + + if item == toggle_item: + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + if all_checked: + for i in normal_items: + i.setCheckState(Qt.Unchecked) + 
toggle_item.setCheckState(Qt.Unchecked) + else: + for i in normal_items: + i.setCheckState(Qt.Checked) + toggle_item.setCheckState(Qt.Checked) + + elif item == dummy_item: + pass + + else: + # When normal items change, update toggle item + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + toggle_item.setCheckState(Qt.Checked if all_checked else Qt.Unchecked) + + # Update label text immediately after change + if combo == self.participant_dropdown: + self.update_participant_dropdown_label() + elif combo == self.image_index_dropdown: + self.update_image_index_dropdown_label() + + self._updating_checkstates = False + + model.itemChanged.connect(on_item_changed) + + combo.setInsertPolicy(QComboBox.NoInsert) + return combo + + def _get_checked_items(self, combo): + checked = [] + model = combo.model() + for i in range(model.rowCount()): + item = model.item(i) + # Skip dummy and toggle items: + if item.text() in ("", "Toggle Select All"): + continue + if item.checkState() == Qt.Checked: + checked.append(item.text()) + return checked + + + def update_participant_list_for_group(self): + selected_group = self.group_dropdown.currentText() + model = self.participant_dropdown.model() + model.clear() + self.participant_map.clear() + + # Add dummy and toggle select all items again + dummy_item = QStandardItem("") + dummy_item.setFlags(Qt.ItemIsEnabled) + model.appendRow(dummy_item) + + toggle_item = QStandardItem("Toggle Select All") + toggle_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + toggle_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(toggle_item) + + if selected_group == "": + # Disable participant dropdown when no group selected + self.participant_dropdown.setEnabled(False) + self.update_participant_dropdown_label() + return + + # Enable participant dropdown since a valid group is selected + self.participant_dropdown.setEnabled(True) + + group_file_paths = self.group_to_paths.get(selected_group, []) + for i, file_path in enumerate(group_file_paths, start=1): + short_label = f"Participant {i}" + display_label = f"{short_label} ({os.path.basename(file_path)})" + self.participant_map[file_path] = short_label + + item = QStandardItem(display_label) + item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(item) + + self._connect_select_all_toggle(toggle_item, model) + self.update_participant_dropdown_label() + + + def _connect_select_all_toggle(self, toggle_item, model): + """Helper function to connect the Select All functionality.""" + normal_items = [model.item(i) for i in range(2, model.rowCount())] # skip dummy and toggle + + def on_item_changed(item): + if self._updating_checkstates: + return + self._updating_checkstates = True + + if item == toggle_item: + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + if all_checked: + for i in normal_items: + i.setCheckState(Qt.Unchecked) + toggle_item.setCheckState(Qt.Unchecked) + else: + for i in normal_items: + i.setCheckState(Qt.Checked) + toggle_item.setCheckState(Qt.Checked) + + else: + # When normal items change, update toggle item + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + toggle_item.setCheckState(Qt.Checked if all_checked else Qt.Unchecked) + + # Update label text immediately after change + if self.participant_dropdown: + self.update_participant_dropdown_label() + + self._updating_checkstates = False + + model.itemChanged.connect(on_item_changed) + + def 
update_participant_dropdown_label(self): + selected = self._get_checked_items(self.participant_dropdown) + if not selected: + self.participant_dropdown.lineEdit().setText("") + else: + # Extract just "Participant N" from "Participant N (filename)" + selected_short = [s.split(" ")[0] + " " + s.split(" ")[1] for s in selected] + self.participant_dropdown.lineEdit().setText(", ".join(selected_short)) + self._update_event_dropdown() + + + def update_image_index_dropdown_label(self): + selected = self._get_checked_items(self.image_index_dropdown) + if not selected: + self.image_index_dropdown.lineEdit().setText("") + else: + # Only show the index part + index_labels = [s.split(" ")[0] for s in selected] + self.image_index_dropdown.lineEdit().setText(", ".join(index_labels)) + + + def _update_event_dropdown(self): + selected_display_names = self._get_checked_items(self.participant_dropdown) + selected_file_paths = [] + for display_name in selected_display_names: + for fp, short_label in self.participant_map.items(): + expected_display = f"{short_label} ({os.path.basename(fp)})" + if display_name == expected_display: + selected_file_paths.append(fp) + break + + if not selected_file_paths: + self.event_dropdown.clear() + self.event_dropdown.addItem("") + return + + annotation_sets = [] + + for file_path in selected_file_paths: + raw = self.haemo_dict.get(file_path) + if raw is None or not hasattr(raw, "annotations"): + continue + annotations = set(raw.annotations.description) + annotation_sets.append(annotations) + + if not annotation_sets: + self.event_dropdown.clear() + self.event_dropdown.addItem("") + return + + shared_annotations = set.intersection(*annotation_sets) + self.event_dropdown.clear() + self.event_dropdown.addItem("") + for ann in sorted(shared_annotations): + self.event_dropdown.addItem(ann) + + def show_brain_images(self): + import flares + + selected_event = self.event_dropdown.currentText() + if selected_event == "": + selected_event = None + + selected_display_names = self._get_checked_items(self.participant_dropdown) + selected_file_paths = [] + for display_name in selected_display_names: + for fp, short_label in self.participant_map.items(): + expected_display = f"{short_label} ({os.path.basename(fp)})" + if display_name == expected_display: + selected_file_paths.append(fp) + break + + selected_indexes = [ + int(s.split(" ")[0]) for s in self._get_checked_items(self.image_index_dropdown) + ] + + if not selected_file_paths: + print("No participants selected.") + return + + # Only keep indexes 0 and 1 that need parameters + parameterized_indexes = { + 0: [ + { + "key": "lower_bound", + "label": "Lower bound", + "default": "-0.3", + "type": float, + }, + { + "key": "upper_bound", + "label": "Upper bound", + "default": "0.8", + "type": float, + } + ], + 1: [ + { + "key": "p_value", + "label": "Significance threshold P-value (e.g. 
0.05)", + "default": "0.05", + "type": float, + }, + { + "key": "graph_bounds", + "label": "Graph Upper/Lower Limit", + "default": "3.0", + "type": float, + } + ], + } + + # Inject full_text from index_texts + for idx, params_list in parameterized_indexes.items(): + full_text = self.index_texts[idx] if idx < len(self.index_texts) else f"{idx} (No label found)" + for param_info in params_list: + param_info["full_text"] = full_text + + indexes_needing_params = {idx: parameterized_indexes[idx] for idx in selected_indexes if idx in parameterized_indexes} + + param_values = {} + if indexes_needing_params: + dialog = ParameterInputDialog(indexes_needing_params, parent=self) + if dialog.exec_() == QDialog.Accepted: + param_values = dialog.get_values() + if param_values is None: + return + else: + return + + + all_cha = pd.DataFrame() + for fp in selected_file_paths: + cha_df = self.cha.get(fp) + if cha_df is not None: + all_cha = pd.concat([all_cha, cha_df], ignore_index=True) + + # Pass the necessary arguments to each method + file_path = selected_file_paths[0] + p_haemo = self.haemo_dict.get(file_path) + p_design_matrix = self.design_matrix.get(file_path) + + df_group = pd.DataFrame() + + if selected_file_paths: + for file_path in selected_file_paths: + df = self.df_ind.get(file_path) + if df is not None: + df_group = pd.concat([df_group, df], ignore_index=True) + + + for idx in selected_indexes: + if idx == 0: + params = param_values.get(idx, {}) + lower_bound = params.get("lower_bound", None) + upper_bound = params.get("upper_bound", None) + + if lower_bound is None or upper_bound is None: + print(f"Missing parameters for index {idx}, skipping.") + continue + + + flares.plot_fir_model_results(df_group, p_haemo, p_design_matrix, selected_event, lower_bound, upper_bound) + + elif idx == 1: + params = param_values.get(idx, {}) + p_val = params.get("p_value", None) + graph_bounds = params.get("graph_bounds", None) + + if p_val is None or graph_bounds is None: + print(f"Missing parameters for index {idx}, skipping.") + continue + + all_contrasts = [] + for fp in selected_file_paths: + condition_dfs = self.contrast_results.get(fp, {}) + if selected_event in condition_dfs: + df = condition_dfs[selected_event].copy() + df["ID"] = fp + all_contrasts.append(df) + + if not all_contrasts: + print("No contrast data found for selected participants and event.") + return + + df_contrasts = pd.concat(all_contrasts, ignore_index=True) + flares.run_second_level_analysis(df_contrasts, p_haemo, p_val, graph_bounds) + + elif idx == 3: + pass + + else: + print(f"No method defined for index {idx}") + + + + + +class GroupBrainViewerWidget(QWidget): + def __init__(self, haemo_dict, df_ind, design_matrix, group, contrast_results_dict): + super().__init__() + self.setWindowTitle("Group Brain Viewer") + self.haemo_dict = haemo_dict + self.df_ind = df_ind + self.design_matrix = design_matrix + self.group = group + self.contrast_results_dict = contrast_results_dict + + self.group_to_paths = {} + for file_path, group_name in self.group.items(): + self.group_to_paths.setdefault(group_name, []).append(file_path) + + self.group_names = sorted(self.group_to_paths.keys()) + + self.layout = QVBoxLayout(self) + self.top_bar = QHBoxLayout() + self.layout.addLayout(self.top_bar) + + + self.group_a_dropdown = QComboBox() + self.group_a_dropdown.addItem("") + self.group_a_dropdown.addItems(self.group_names) + self.group_a_dropdown.currentIndexChanged.connect(self._update_group_a_options) + + + self.group_b_dropdown = QComboBox() + 
self.group_b_dropdown.addItem("") + self.group_b_dropdown.addItems(self.group_names) + self.group_b_dropdown.currentIndexChanged.connect(self._update_group_b_options) + + + self.event_dropdown = QComboBox() + self.event_dropdown.addItem("") + + self.participant_dropdown_a = self._create_multiselect_dropdown([]) + self.participant_dropdown_a.lineEdit().setPlaceholderText("Select participants (Group A)") + self.participant_dropdown_a.model().itemChanged.connect(self._on_participants_changed) + + + self.participant_dropdown_b = self._create_multiselect_dropdown([]) + self.participant_dropdown_b.lineEdit().setPlaceholderText("Select participants (Group B)") + self.participant_dropdown_b.model().itemChanged.connect(self._on_participants_changed) + + + self.index_texts = [ + "0 (Contrast Image)", + # "1 (3D Brain Contrast)", + # "2 (third image)", + # "3 (fourth image)", + ] + self.image_index_dropdown = self._create_multiselect_dropdown(self.index_texts) + self.image_index_dropdown.currentIndexChanged.connect(self.update_image_index_dropdown_label) + + + self.submit_button = QPushButton("Submit") + self.submit_button.clicked.connect(self.show_brain_images) + + + self.top_bar.addWidget(QLabel("Group A:")) + self.top_bar.addWidget(self.group_a_dropdown) + self.top_bar.addWidget(QLabel("Participants (Group A):")) + self.top_bar.addWidget(self.participant_dropdown_a) + self.top_bar.addWidget(QLabel("Group B:")) + self.top_bar.addWidget(self.group_b_dropdown) + self.top_bar.addWidget(QLabel("Participants (Group B):")) + self.top_bar.addWidget(self.participant_dropdown_b) + self.top_bar.addWidget(QLabel("Event:")) + self.top_bar.addWidget(self.event_dropdown) + self.top_bar.addWidget(QLabel("Image Indexes:")) + self.top_bar.addWidget(self.image_index_dropdown) + self.top_bar.addWidget(self.submit_button) + + self.scroll = QScrollArea() + self.scroll.setWidgetResizable(True) + self.scroll_content = QWidget() + self.grid_layout = QGridLayout(self.scroll_content) + self.scroll.setWidget(self.scroll_content) + self.layout.addWidget(self.scroll) + + self.thumb_size = QSize(280, 180) + self.showMaximized() + + def _update_group_b_options(self): + selected = self.group_a_dropdown.currentText() + self._refresh_group_dropdown(self.group_b_dropdown, exclude=selected) + self._update_event_dropdown() + group_b = self.group_b_dropdown.currentText() + self.update_participant_list_for_group(group_b, self.participant_dropdown_b) + + def _update_group_a_options(self): + selected = self.group_b_dropdown.currentText() + self._refresh_group_dropdown(self.group_a_dropdown, exclude=selected) + self._update_event_dropdown() + group_a = self.group_a_dropdown.currentText() + self.update_participant_list_for_group(group_a, self.participant_dropdown_a) + + + + def update_participant_list_for_group(self, group_name: str, participant_dropdown: FullClickComboBox): + model = participant_dropdown.model() + model.clear() + + # Maintain separate participant maps for A and B to avoid conflicts + if participant_dropdown == self.participant_dropdown_a: + participant_map = self.participant_map_a = {} + else: + participant_map = self.participant_map_b = {} + + # Add dummy and toggle select all items again + dummy_item = QStandardItem("") + dummy_item.setFlags(Qt.ItemIsEnabled) + model.appendRow(dummy_item) + + toggle_item = QStandardItem("Toggle Select All") + toggle_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + toggle_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(toggle_item) + + if group_name == "": + 
participant_dropdown.setEnabled(False) + self._update_participant_dropdown_label(participant_dropdown) + return + + participant_dropdown.setEnabled(True) + + group_file_paths = self.group_to_paths.get(group_name, []) + for i, file_path in enumerate(group_file_paths, start=1): + short_label = f"Participant {i}" + display_label = f"{short_label} ({os.path.basename(file_path)})" + participant_map[file_path] = short_label + + item = QStandardItem(display_label) + item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(item) + + self._connect_select_all_toggle(toggle_item, model) + self._update_participant_dropdown_label(participant_dropdown) + + def _update_participant_dropdown_label(self, participant_dropdown): + selected = self._get_checked_items(participant_dropdown) + if not selected: + participant_dropdown.lineEdit().setText("") + else: + # Extract just "Participant N" from "Participant N (filename)" + selected_short = [s.split(" ")[0] + " " + s.split(" ")[1] for s in selected] + participant_dropdown.lineEdit().setText(", ".join(selected_short)) + self._update_event_dropdown() + + def _connect_select_all_toggle(self, toggle_item, model): + normal_items = [model.item(i) for i in range(2, model.rowCount())] # skip dummy and toggle + + def on_item_changed(item): + if getattr(self, "_updating_checkstates", False): + return + self._updating_checkstates = True + + if item == toggle_item: + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + if all_checked: + for i in normal_items: + i.setCheckState(Qt.Unchecked) + toggle_item.setCheckState(Qt.Unchecked) + else: + for i in normal_items: + i.setCheckState(Qt.Checked) + toggle_item.setCheckState(Qt.Checked) + else: + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + toggle_item.setCheckState(Qt.Checked if all_checked else Qt.Unchecked) + + # Update label text for participant dropdowns + if hasattr(self, 'participant_dropdown_a') and model == self.participant_dropdown_a.model(): + self._update_participant_dropdown_label(self.participant_dropdown_a) + elif hasattr(self, 'participant_dropdown_b') and model == self.participant_dropdown_b.model(): + self._update_participant_dropdown_label(self.participant_dropdown_b) + + self._updating_checkstates = False + + model.itemChanged.connect(on_item_changed) + + def _on_participants_changed(self, item=None): + self._update_event_dropdown() + + + + def _update_event_dropdown(self): + participants_a = self._get_checked_items(self.participant_dropdown_a) + participants_b = self._get_checked_items(self.participant_dropdown_b) + + if not participants_a or not participants_b: + self.event_dropdown.clear() + self.event_dropdown.addItem("") + return + + selected_file_paths_a = [ + fp for display_name in participants_a + for fp, short_label in self.participant_map_a.items() + if display_name == f"{short_label} ({os.path.basename(fp)})" + ] + + selected_file_paths_b = [ + fp for display_name in participants_b + for fp, short_label in self.participant_map_b.items() + if display_name == f"{short_label} ({os.path.basename(fp)})" + ] + + all_selected_file_paths = set(selected_file_paths_a + selected_file_paths_b) + + if not all_selected_file_paths: + self.event_dropdown.clear() + self.event_dropdown.addItem("") + return + + annotation_sets = [] + for file_path in all_selected_file_paths: + raw = self.haemo_dict.get(file_path) + if raw is None or not hasattr(raw, "annotations"): + continue + 
annotation_sets.append(set(raw.annotations.description)) + + if not annotation_sets: + self.event_dropdown.clear() + self.event_dropdown.addItem("") + return + + shared_annotations = set.intersection(*annotation_sets) + self.event_dropdown.clear() + self.event_dropdown.addItem("") + for ann in sorted(shared_annotations): + self.event_dropdown.addItem(ann) + + def _refresh_group_dropdown(self, dropdown, exclude): + current = dropdown.currentText() + dropdown.blockSignals(True) + dropdown.clear() + dropdown.addItem("") + for group in self.group_names: + if group != exclude: + dropdown.addItem(group) + # Restore previous selection if still valid + if current != "" and current != exclude and dropdown.findText(current) != -1: + dropdown.setCurrentText(current) + else: + dropdown.setCurrentIndex(0) # Reset to "" + dropdown.blockSignals(False) + + def _create_multiselect_dropdown(self, items): + combo = FullClickComboBox() + combo.setView(QListView()) + model = QStandardItemModel() + combo.setModel(model) + combo.setEditable(True) + combo.lineEdit().setReadOnly(True) + combo.lineEdit().setPlaceholderText("Select...") + + dummy_item = QStandardItem("") + dummy_item.setFlags(Qt.ItemIsEnabled) + model.appendRow(dummy_item) + + toggle_item = QStandardItem("Toggle Select All") + toggle_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + toggle_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(toggle_item) + + for item in items: + standard_item = QStandardItem(item) + standard_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + standard_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(standard_item) + + def on_view_clicked(index): + item = model.itemFromIndex(index) + if item.isCheckable(): + new_state = Qt.Checked if item.checkState() == Qt.Unchecked else Qt.Unchecked + item.setCheckState(new_state) + + combo.view().pressed.connect(on_view_clicked) + + self._updating_checkstates = False + + def on_item_changed(item): + if self._updating_checkstates: + return + self._updating_checkstates = True + + normal_items = [model.item(i) for i in range(2, model.rowCount())] + + if item == toggle_item: + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + for i in normal_items: + i.setCheckState(Qt.Unchecked if all_checked else Qt.Checked) + toggle_item.setCheckState(Qt.Unchecked if all_checked else Qt.Checked) + + toggle_item.setCheckState(Qt.Checked if all(i.checkState() == Qt.Checked for i in normal_items) else Qt.Unchecked) + + self.update_image_index_dropdown_label() + self._updating_checkstates = False + + model.itemChanged.connect(on_item_changed) + return combo + + def _get_checked_items(self, combo): + checked = [] + model = combo.model() + for i in range(model.rowCount()): + item = model.item(i) + if item.text() in ("", "Toggle Select All"): + continue + if item.checkState() == Qt.Checked: + checked.append(item.text()) + return checked + + def update_image_index_dropdown_label(self): + selected = self._get_checked_items(self.image_index_dropdown) + if not selected: + self.image_index_dropdown.lineEdit().setText("") + else: + index_labels = [s.split(" ")[0] for s in selected] + self.image_index_dropdown.lineEdit().setText(", ".join(index_labels)) + + + + def _connect_select_all_toggle(self, toggle_item, model): + """Helper function to connect the Select All functionality.""" + normal_items = [model.item(i) for i in range(2, model.rowCount())] # skip dummy and toggle + + def on_item_changed(item): + if self._updating_checkstates: + return + 
self._updating_checkstates = True + + if item == toggle_item: + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + if all_checked: + for i in normal_items: + i.setCheckState(Qt.Unchecked) + toggle_item.setCheckState(Qt.Unchecked) + else: + for i in normal_items: + i.setCheckState(Qt.Checked) + toggle_item.setCheckState(Qt.Checked) + + else: + # When normal items change, update toggle item + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + toggle_item.setCheckState(Qt.Checked if all_checked else Qt.Unchecked) + + # Update label text immediately after change + if hasattr(self, 'participant_dropdown_a') and model == self.participant_dropdown_a.model(): + self._update_participant_dropdown_label(self.participant_dropdown_a) + elif hasattr(self, 'participant_dropdown_b') and model == self.participant_dropdown_b.model(): + self._update_participant_dropdown_label(self.participant_dropdown_b) + + self._updating_checkstates = False + + model.itemChanged.connect(on_item_changed) + + def update_participant_dropdown_label(self): + selected = self._get_checked_items(self.participant_dropdown) + if not selected: + self.participant_dropdown.lineEdit().setText("") + else: + # Extract just "Participant N" from "Participant N (filename)" + selected_short = [s.split(" ")[0] + " " + s.split(" ")[1] for s in selected] + self.participant_dropdown.lineEdit().setText(", ".join(selected_short)) + + + def update_image_index_dropdown_label(self): + selected = self._get_checked_items(self.image_index_dropdown) + if not selected: + self.image_index_dropdown.lineEdit().setText("") + else: + # Only show the index part + index_labels = [s.split(" ")[0] for s in selected] + self.image_index_dropdown.lineEdit().setText(", ".join(index_labels)) + + def _get_file_paths_from_labels(self, labels, group_name): + file_paths = [] + + if group_name == self.group_a_dropdown.currentText(): + participant_map = self.participant_map_a + elif group_name == self.group_b_dropdown.currentText(): + participant_map = self.participant_map_b + else: + return [] + + # Reverse map: display label -> file path + reverse_map = { + f"{label} ({os.path.basename(fp)})": fp + for fp, label in participant_map.items() + } + + for label in labels: + file_path = reverse_map.get(label) + if file_path: + file_paths.append(file_path) + + return file_paths + + def show_brain_images(self): + import flares + + selected_event = self.event_dropdown.currentText() + if selected_event == "": + selected_event = None + + # Group A + participants_a = self._get_checked_items(self.participant_dropdown_a) + file_paths_a = self._get_file_paths_from_labels(participants_a, self.group_a_dropdown.currentText()) + + # Group B + participants_b = self._get_checked_items(self.participant_dropdown_b) + file_paths_b = self._get_file_paths_from_labels(participants_b, self.group_b_dropdown.currentText()) + + selected_indexes = [ + int(s.split(" ")[0]) for s in self._get_checked_items(self.image_index_dropdown) + ] + + + parameterized_indexes = { + 0: [ + { + "key": "show_optodes", + "label": "Determine what is rendered above the brain. Valid values are 'sensors', 'labels', 'none', 'all'.", + "default": "all", + "type": str, + }, + { + "key": "t_or_theta", + "label": "Specify if t values or theta values should be plotted. 
Valid values are 't', 'theta'", + "default": "theta", + "type": str, + }, + { + "key": "show_text", + "label": "Display informative text on the top left corner about the contrast.", + "default": "True", + "type": bool, + }, + { + "key": "brain_bounds", + "label": "Graph Upper/Lower Limit", + "default": "1.0", + "type": float, + }, + { + "key": "is_3d", + "label": "Should we display the results in a 3D interactive window?", + "default": "True", + "type": bool, + } + ], + } + + + # Inject full_text from index_texts + for idx, params_list in parameterized_indexes.items(): + full_text = self.index_texts[idx] if idx < len(self.index_texts) else f"{idx} (No label found)" + for param_info in params_list: + param_info["full_text"] = full_text + + indexes_needing_params = {idx: parameterized_indexes[idx] for idx in selected_indexes if idx in parameterized_indexes} + + param_values = {} + if indexes_needing_params: + dialog = ParameterInputDialog(indexes_needing_params, parent=self) + if dialog.exec_() == QDialog.Accepted: + param_values = dialog.get_values() + if param_values is None: + return + else: + return + + # Build group-level contrast DataFrames + def concat_group_contrasts(file_paths: list[str], event: str | None) -> pd.DataFrame: + group_df = pd.DataFrame() + for fp in file_paths: + print(f"Looking up contrast for: {fp}") + event_con_dict = self.contrast_results_dict.get(fp, {}) + print("Available events for this file:", list(event_con_dict.keys())) + if event and event in event_con_dict: + df = event_con_dict[event] + print(f"Appending contrast df for event: {event}") + group_df = pd.concat([group_df, df], ignore_index=True) + else: + print(f"Event '{event}' not found for {fp}") + return group_df + + print("Selected event:", selected_event) + print("File paths A:", file_paths_a) + print("File paths B:", file_paths_b) + + contrast_df_a = concat_group_contrasts(file_paths_a, selected_event) + contrast_df_b = concat_group_contrasts(file_paths_b, selected_event) + + print("contrast_df_a empty?", contrast_df_a.empty) + print("contrast_df_b empty?", contrast_df_b.empty) + + # Get one person for their layout + rep_raw = None + for fp in file_paths_a + file_paths_b: + rep_raw = self.haemo_dict.get(fp) + if rep_raw: + break + + print(rep_raw) + + # Visualizations + for idx in selected_indexes: + if idx == 0: + params = param_values.get(idx, {}) + show_optodes = params.get("show_optodes", None) + t_or_theta = params.get("t_or_theta", None) + show_text = params.get("show_text", None) + brain_bounds = params.get("brain_bounds", None) + is_3d = params.get("is_3d", None) + + if show_optodes is None or t_or_theta is None or show_text is None or brain_bounds is None or is_3d is None: + print(f"Missing parameters for index {idx}, skipping.") + continue + + if not contrast_df_a.empty and not contrast_df_b.empty and rep_raw: + + flares.plot_2d_3d_contrasts_between_groups( + contrast_df_a, + contrast_df_b, + raw_haemo=rep_raw, + group_a_name=self.group_a_dropdown.currentText(), + group_b_name=self.group_b_dropdown.currentText(), + is_3d=is_3d, + t_or_theta=t_or_theta, + show_optodes=show_optodes, + show_text=show_text, + brain_bounds=brain_bounds + ) + else: + print("no") + + + + +class ViewerLauncherWidget(QWidget): + def __init__(self, haemo_dict, fig_bytes_dict, cha_dict, contrast_results_dict, df_ind, design_matrix, group): + super().__init__() + self.setWindowTitle("Viewer Launcher") + + layout = QVBoxLayout(self) + + btn1 = QPushButton("Open Participant Viewer") + btn1.clicked.connect(lambda: 
self.open_participant_viewer(haemo_dict, fig_bytes_dict)) + + btn2 = QPushButton("Open Participant Brain Viewer") + btn2.clicked.connect(lambda: self.open_participant_brain_viewer(haemo_dict, cha_dict)) + + btn3 = QPushButton("Open Inter-Group Viewer") + btn3.clicked.connect(lambda: self.open_group_viewer(haemo_dict, cha_dict, df_ind, design_matrix, contrast_results_dict, group)) + + btn4 = QPushButton("Open Cross Group Brain Viewer") + btn4.clicked.connect(lambda: self.open_group_brain_viewer(haemo_dict, df_ind, design_matrix, group, contrast_results_dict)) + + + layout.addWidget(btn1) + layout.addWidget(btn2) + layout.addWidget(btn3) + layout.addWidget(btn4) + + def open_participant_viewer(self, haemo_dict, fig_bytes_dict): + self.participant_viewer = ParticipantViewerWidget(haemo_dict, fig_bytes_dict) + self.participant_viewer.show() + + def open_participant_brain_viewer(self, haemo_dict, cha_dict): + self.participant_brain_viewer = ParticipantBrainViewerWidget(haemo_dict, cha_dict) + self.participant_brain_viewer.show() + + def open_group_viewer(self, haemo_dict, cha_dict, df_ind, design_matrix, contrast_results_dict, group): + # Keep a distinct reference per viewer so an already-open window is not garbage-collected + self.group_viewer = GroupViewerWidget(haemo_dict, cha_dict, df_ind, design_matrix, contrast_results_dict, group) + self.group_viewer.show() + + def open_group_brain_viewer(self, haemo_dict, df_ind, design_matrix, group, contrast_results_dict): + self.group_brain_viewer = GroupBrainViewerWidget(haemo_dict, df_ind, design_matrix, group, contrast_results_dict) + self.group_brain_viewer.show() class MainApplication(QMainWindow): @@ -617,6 +2626,7 @@ class MainApplication(QMainWindow): self.about = None self.help = None + self.optodes = None self.bubble_widgets = {} self.param_sections = [] self.folder_paths = [] @@ -627,11 +2637,12 @@ class MainApplication(QMainWindow): self.platform_suffix = "-" + PLATFORM_NAME self.pending_update_version = None self.pending_update_path = None - + self.last_clicked_bubble = None + self.installEventFilter(self) + self.file_metadata = {} self.current_file = None - # Start local pending update check thread self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix) self.local_check_thread.pending_update_found.connect(self.on_pending_update_found) @@ -654,8 +2665,8 @@ left_container.setLayout(left_layout) left_container.setMinimumWidth(300) - top_left_container = QGroupBox() # Optional: add a title inside parentheses - top_left_container.setTitle("File information") # Optional visible title + top_left_container = QGroupBox() + top_left_container.setTitle("File information") top_left_container.setStyleSheet("QGroupBox { font-weight: bold; }") # Style if needed top_left_container.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding) @@ -676,14 +2687,12 @@ right_column_layout = QVBoxLayout() self.right_column_widget.setLayout(right_column_layout) - self.meta_fields = { - "Age": QLineEdit(), - "Gender": QLineEdit(), + "AGE": QLineEdit(), + "GENDER": QLineEdit(), + "GROUP": QLineEdit(), } - - # Inside your top-left container's right column layout: for key, field in self.meta_fields.items(): label = QLabel(key.capitalize()) field.setPlaceholderText(f"Enter {key}") @@ -700,7 +2709,7 @@ # Add top_left_container to the main left_layout left_layout.addWidget(top_left_container) - # Bottom left: the bubbles inside a scroll area + # Bottom left: the bubbles inside 
the scroll area self.bubble_container = QWidget() self.bubble_layout = QGridLayout() self.bubble_layout.setAlignment(Qt.AlignmentFlag.AlignTop) @@ -716,7 +2725,7 @@ self.progress_update_signal.connect(self.update_file_progress) - # Right widget (full height on right side) — example QTextEdit + # Right widget (full height on right side) self.right_container = QWidget() right_container_layout = QVBoxLayout() self.right_container.setLayout(right_container_layout) @@ -769,7 +2778,7 @@ self.button1.clicked.connect(self.on_run_task) self.button2.clicked.connect(self.clear_all) - self.button3.clicked.connect(self.open_viewer_window) + self.button3.clicked.connect(self.open_launcher_window) # Add scroll area and buttons widget to right container layout right_container_layout.addWidget(self.right_scroll_area) @@ -817,7 +2826,7 @@ ("Open Folders...", "Ctrl+Shift+O", self.open_multiple_folders_dialog, resource_path("icons/folder_copy_24dp_1F1F1F.svg")), ("Load Project...", "Ctrl+L", self.load_project, resource_path("icons/article_24dp_1F1F1F.svg")), ("Save Project...", "Ctrl+S", self.save_project, resource_path("icons/save_24dp_1F1F1F.svg")), - ("Save Project As...", "Ctrl+Shift+S", self.save_project, resource_path("icons/save_as_24dp_1F1F1F.svg")), # Maybe connect a separate method if different + ("Save Project As...", "Ctrl+Shift+S", self.save_project, resource_path("icons/save_as_24dp_1F1F1F.svg")), ] for i, (name, shortcut, slot, icon) in enumerate(file_actions): @@ -849,15 +2858,15 @@ options_actions = [ ("User Guide", "F1", self.user_guide, resource_path("icons/help_24dp_1F1F1F.svg")), ("Check for Updates", "F5", self.manual_check_for_updates, resource_path("icons/update_24dp_1F1F1F.svg")), + ("Update optodes in SNIRF file...", "F6", self.update_optode_positions, resource_path("icons/update_24dp_1F1F1F.svg")), ("About", "F12", self.about_window, resource_path("icons/info_24dp_1F1F1F.svg")) ] for i, (name, shortcut, slot, icon) in enumerate(options_actions): options_menu.addAction(make_action(name, shortcut, slot, icon=icon)) - if i == 1: # after the first 2 actions (0,1) + if i == 1 or i == 2: # separators after "Check for Updates" and the optodes action options_menu.addSeparator() - # Optional: status bar self.statusbar = self.statusBar() self.statusbar.showMessage("Ready") @@ -893,24 +2902,28 @@ self.bubble_widgets.clear() self.statusBar().clearMessage() - # Reset any other data variables - self.collected_data_snapshot = None - self.all_results = None - self.all_haemo = None - self.all_figures = None - + self.raw_haemo_dict = None + self.epochs_dict = None + self.fig_bytes_dict = None + self.cha_dict = None + self.contrast_results_dict = None + self.df_ind_dict = None + self.design_matrix_dict = None + self.age_dict = None + self.gender_dict = None + self.group_dict = None + self.valid_dict = None + # Reset any visible UI elements self.button1.setVisible(False) self.button3.setVisible(False) self.top_left_widget.clear() - def open_viewer_window(self): - if not hasattr(self, "all_figures") or self.all_figures is None: - QMessageBox.warning(self, "No Data", "No figures to show yet!") - return - self.viewer_window = ViewerWindow(self.all_results, self.all_haemo, self.all_figures, self.collected_data_snapshot, self) - self.viewer_window.show() + def open_launcher_window(self): + + self.launcher_window = ViewerLauncherWidget(self.raw_haemo_dict, 
self.fig_bytes_dict, self.cha_dict, self.contrast_results_dict, self.df_ind_dict, self.design_matrix_dict, self.group_dict) + self.launcher_window.show() def copy_text(self): @@ -936,6 +2949,11 @@ class MainApplication(QMainWindow): self.help = UserGuideWindow(self) self.help.show() + def update_optode_positions(self): + if self.optodes is None or not self.optodes.isVisible(): + self.optodes = UpdateOptodesWindow(self) + self.optodes.show() + def open_file_dialog(self): file_path, _ = QFileDialog.getOpenFileName( @@ -947,6 +2965,7 @@ class MainApplication(QMainWindow): self.button1.setVisible(True) + def open_folder_dialog(self): folder_path = QFileDialog.getExistingDirectory( self, "Select Folder", "" @@ -997,10 +3016,18 @@ class MainApplication(QMainWindow): "progress_states": { bubble.file_path: bubble.current_step for bubble in self.bubble_widgets.values() }, - "all_results": self.all_results, - "all_haemo": self.all_haemo, - "all_figures": self.all_figures, - "config_snapshot": self.collected_data_snapshot, + + "raw_haemo_dict": self.raw_haemo_dict, + "epochs_dict": self.epochs_dict, + "fig_bytes_dict": self.fig_bytes_dict, + "cha_dict": self.cha_dict, + "contrast_results_dict": self.contrast_results_dict, + "df_ind_dict": self.df_ind_dict, + "design_matrix_dict": self.design_matrix_dict, + "age_dict": self.age_dict, + "gender_dict": self.gender_dict, + "group_dict": self.group_dict, + "valid_dict": self.valid_dict, } with open(filename, "wb") as f: @@ -1014,7 +3041,6 @@ class MainApplication(QMainWindow): def load_project(self): - filename, _ = QFileDialog.getOpenFileName( self, "Load Project", "", "FLARE Project (*.flare)" ) @@ -1025,34 +3051,23 @@ class MainApplication(QMainWindow): with open(filename, "rb") as f: data = pickle.load(f) - self.collected_data_snapshot = data["config_snapshot"] - self.all_results = data["all_results"] - self.all_haemo = data["all_haemo"] - self.all_figures = data["all_figures"] - - for section_widget in self.param_sections: - for name, widget in section_widget.widgets.items(): - if name not in self.collected_data_snapshot: - continue - value = self.collected_data_snapshot[name] - - if isinstance(widget, QComboBox): - widget.setCurrentText("True" if value else "False") - - elif isinstance(widget, QLineEdit): - validator = widget.validator() - - if isinstance(validator, QIntValidator): - widget.setText(str(int(value))) - elif isinstance(validator, QDoubleValidator): - widget.setText(str(float(value))) - else: - widget.setText(str(value)) + self.raw_haemo_dict = data.get("raw_haemo_dict", {}) + self.epochs_dict = data.get("epochs_dict", {}) + self.fig_bytes_dict = data.get("fig_bytes_dict", {}) + self.cha_dict = data.get("cha_dict", {}) + self.contrast_results_dict = data.get("contrast_results_dict", {}) + self.df_ind_dict = data.get("df_ind_dict", {}) + self.design_matrix_dict = data.get("design_matrix_dict", {}) + self.age_dict = data.get("age_dict", {}) + self.gender_dict = data.get("gender_dict", {}) + self.group_dict = data.get("group_dict", {}) + self.valid_dict = data.get("valid_dict", {}) + # Restore bubbles and progress self.show_files_as_bubbles_from_list(data["file_list"], data.get("progress_states", {}), filename) - # Re-enable the buttons - self.button1.setVisible(True) + # Re-enable buttons + # self.button1.setVisible(True) self.button3.setVisible(True) QMessageBox.information(self, "Loaded", f"Project loaded from:\n{filename}") @@ -1120,7 +3135,7 @@ class MainApplication(QMainWindow): self.bubble_widgets = {} - temp_bubble = 
ProgressBubble("Test Bubble", "") # A dummy bubble for measurement + temp_bubble = ProgressBubble("Test Bubble Test Bubble", "") # A dummy bubble for measurement temp_bubble.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) # temp_bubble.setAttribute(Qt.WA_OpaquePaintEvent) # Improves rendering? temp_bubble.adjustSize() # Adjust size after the widget is created @@ -1150,9 +3165,26 @@ class MainApplication(QMainWindow): self.statusBar().showMessage(f"{len(file_list)} files loaded from from {os.path.abspath(filenames)}.") + def get_suffix_from_meta_fields(self): + parts = [] + for key, line_edit in self.meta_fields.items(): + val = line_edit.text().strip() + if val: + parts.append(f"{key}: {val}") + return ", ".join(parts) + def on_bubble_clicked(self, bubble): + + if self.current_file: + self.save_metadata(self.current_file) + + if self.last_clicked_bubble and self.last_clicked_bubble != bubble: + suffix = self.get_suffix_from_meta_fields() + self.last_clicked_bubble.setSuffixText(suffix) + + self.last_clicked_bubble = bubble - # show age / gender + # show age / gender / group self.right_column_widget.show() file_path = bubble.file_path @@ -1198,6 +3230,9 @@ class MainApplication(QMainWindow): # Update current file self.current_file = file_path + if file_path not in self.file_metadata: + self.file_metadata[file_path] = {key: "" for key in self.meta_fields} + # Load new file's metadata into the fields metadata = self.file_metadata.get(file_path, {}) for key, field in self.meta_fields.items(): @@ -1205,6 +3240,23 @@ class MainApplication(QMainWindow): field.setText(metadata.get(key, "")) field.blockSignals(False) + + def eventFilter(self, watched, event): + if event.type() == QEvent.Type.MouseButtonPress: + widget = self.childAt(event.pos()) + if isinstance(widget, ProgressBubble): + pass + else: + if self.last_clicked_bubble: + if not self.last_clicked_bubble.isAncestorOf(widget): + if self.current_file: + self.save_metadata(self.current_file) + suffix = self.get_suffix_from_meta_fields() + self.last_clicked_bubble.setSuffixText(suffix) + self.last_clicked_bubble = None + + return super().eventFilter(watched, event) + def placeholder(self): QMessageBox.information(self, "Placeholder", "This feature is not implemented yet.") @@ -1219,6 +3271,11 @@ class MainApplication(QMainWindow): def get_all_metadata(self): # First, make sure current file's edits are saved + + for field in self.meta_fields.values(): + field.clearFocus() + + # Save current file's metadata if self.current_file: self.save_metadata(self.current_file) @@ -1227,108 +3284,57 @@ class MainApplication(QMainWindow): '''MODULE FILE''' def on_run_task(self): - - all_metadata = self.get_all_metadata() - - for file_path, data in all_metadata.items(): - print(f"{file_path}: {data}") - - collected_data = {} - - # Add all parameter key-value pairs - for section_widget in self.param_sections: - for name, widget in section_widget.widgets.items(): - if isinstance(widget, QComboBox): - val = widget.currentText() == "True" - elif isinstance(widget, QLineEdit): - text = widget.text() - validator = widget.validator() - if isinstance(validator, QIntValidator): - val = int(text or 0) - elif isinstance(validator, QDoubleValidator): - val = float(text or 0.0) - else: - val = text - else: - val = None - collected_data[name] = val # Flattened! 
- + # Collect all selected snirf files in a flat list + snirf_files = [] if hasattr(self, "selected_paths") and self.selected_paths: - # Handle multiple folders - parents = [Path(p).parent for p in self.selected_paths] - base_parents = set(str(p) for p in parents) - if len(base_parents) > 1: - raise ValueError("Selected folders must have the same parent directory") - - - collected_data["BASE_SNIRF_FOLDER"] = base_parents.pop() - collected_data["SNIRF_SUBFOLDERS"] = [Path(p).name for p in self.selected_paths] - collected_data["STIM_DURATION"] = [0 for _ in self.selected_paths] + for path in self.selected_paths: + p = Path(path) + if p.is_dir(): + snirf_files += [str(f) for f in p.glob("*.snirf")] + elif p.is_file() and p.suffix == ".snirf": + snirf_files.append(str(p)) elif hasattr(self, "selected_path") and self.selected_path: - # Handle single folder - selected_path = Path(self.selected_path) - collected_data["BASE_SNIRF_FOLDER"] = str(selected_path.parent) - collected_data["SNIRF_SUBFOLDERS"] = [selected_path.name] - collected_data["STIM_DURATION"] = [0] + p = Path(self.selected_path) + if p.is_dir(): + snirf_files += [str(f) for f in p.glob("*.snirf")] + elif p.is_file() and p.suffix == ".snirf": + snirf_files.append(str(p)) else: - # No folder selected - handle gracefully or raise error - raise ValueError("No folder(s) selected") + raise ValueError("No file(s) selected") - collected_data["METADATA"] = all_metadata - collected_data["HRF_MODEL"] = 'fir' + if not snirf_files: + raise ValueError("No .snirf files found in selection") - collected_data["FORCE_DROP_CHANNELS"] = [] - collected_data["TARGET_ACTIVITY"] = "Reach" - collected_data["TARGET_CONTROL"] = "Start of Rest" - collected_data["ROI_GROUP_1"] = [[1, 1], [1, 2], [2, 1], [2, 4], [3, 1], # Channel pairings for a region of interest. - [2, 2], [4, 3], [4, 4], [5, 5], [6, 4]] - collected_data["ROI_GROUP_2"] = [[6, 5], [6, 8], [7, 7], [7, 8], [8, 5], # Channel pairings for another region of interest. - [8, 6], [9, 6], [9, 7], [10, 7], [10, 8]] - collected_data["ROI_GROUP_1_NAME"] = "Parieto-Ocipital" # Friendly name for the first region of interest group. 
- collected_data["ROI_GROUP_2_NAME"] = "Fronto-Parietal" - collected_data["P_THRESHOLD"] = 0.05 - - collected_data["SEE_BAD_IMAGES"] = True - collected_data["ABS_T_VALUE"] = 6 - collected_data["ABS_THETA_VALUE"] = 10 - collected_data["ABS_CONTRAST_T_VALUE"] = 6 - collected_data["ABS_CONTRAST_THETA_VALUE"] = 10 - collected_data["ABS_SIGNIFICANCE_T_VALUE"] = 6 - collected_data["ABS_SIGNIFICANCE_THETA_VALUE"] = 10 - collected_data["BRAIN_DISTANCE"] = 0.02 - collected_data["BRAIN_MODE"] = "weighted" - - collected_data["EPOCH_REJECT_CRITERIA_THRESH"] = 20e-2 - collected_data["TIME_MIN_THRESH"] = -5 - collected_data["TIME_MAX_THRESH"] = 15 - collected_data["VERBOSITY"] = True - - self.collected_data_snapshot = collected_data.copy() + all_params = {} + for section_widget in self.param_sections: + section_params = section_widget.get_param_values() + all_params.update(section_params) + + collected_data = { + "SNIRF_FILES": snirf_files, + "PARAMS": all_params, # add this line + "METADATA": self.get_all_metadata(), # optionally add metadata if needed + } + # Start processing if current_process().name == 'MainProcess': - - self.manager = Manager() self.result_queue = self.manager.Queue() self.progress_queue = self.manager.Queue() + self.result_process = Process( target=run_gui_entry_wrapper, args=(collected_data, self.result_queue, self.progress_queue) ) - self.result_process.daemon = False - - self.result_process.start() - print("start was called") self.statusbar.showMessage("Running processing in background...") - # Poll the queue periodically self.result_timer = QTimer() self.result_timer.timeout.connect(self.check_for_pipeline_results) self.result_timer.start() @@ -1344,17 +3350,41 @@ class MainApplication(QMainWindow): if isinstance(msg, dict): if msg.get("success"): - all_results, all_haemo, all_figures, all_processes, elapsed_time = msg["result"] + results = msg["result"] # from flares.py - self.all_results = all_results - self.all_haemo = all_haemo - self.all_figures = all_figures - self.all_processes = all_processes - self.elapsed_time = elapsed_time + # Initialize storage + # TODO: Is this check needed? + if not hasattr(self, 'raw_haemo_dict'): + self.raw_haemo_dict = {} + self.epochs_dict = {} + self.fig_bytes_dict = {} + self.cha_dict = {} + self.contrast_results_dict = {} + self.df_ind_dict = {} + self.design_matrix_dict = {} + self.age_dict = {} + self.gender_dict = {} + self.group_dict = {} + self.valid_dict = {} - self.statusbar.showMessage(f"Processing complete! Time elapsed: {elapsed_time:.2f} seconds") + # Combine all results into the dicts + for file_path, (raw_haemo, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, age, gender, group, valid) in results.items(): + self.raw_haemo_dict[file_path] = raw_haemo + self.epochs_dict[file_path] = epochs + self.fig_bytes_dict[file_path] = fig_bytes + self.cha_dict[file_path] = cha + self.contrast_results_dict[file_path] = contrast_results + self.df_ind_dict[file_path] = df_ind + self.design_matrix_dict[file_path] = design_matrix + self.age_dict[file_path] = age + self.gender_dict[file_path] = gender + self.group_dict[file_path] = group + self.valid_dict[file_path] = valid - self.button3.setVisible(True) + # self.statusbar.showMessage(f"Processing complete! 
Time elapsed: {elapsed_time:.2f} seconds") + self.statusbar.showMessage(f"Processing complete!") + + self.button3.setVisible(True) else: error_msg = msg.get("error", "Unknown error") @@ -1367,7 +3397,8 @@ class MainApplication(QMainWindow): return elif isinstance(msg, tuple) and msg[0] == 'progress': - _, file_name, step_index = msg + _, file_path, step_index = msg + file_name = os.path.basename(file_path) # extract file name self.progress_update_signal.emit(file_name, step_index) @@ -1547,14 +3578,8 @@ class MainApplication(QMainWindow): def get_snirf_metadata_mne(self, file_name): - try: - import h5py - print("h5py version:", h5py.__version__) - except Exception as e: - print("Failed to import h5py:", e) - try: - raw = mne.io.read_raw_snirf(file_name, preload=True) + raw = read_raw_snirf(file_name, preload=True) snirf_info = {} @@ -1764,9 +3789,8 @@ def run_gui_entry_wrapper(config, gui_queue, progress_queue): """ try: - - import fNIRS_module - fNIRS_module.gui_entry(config, gui_queue, progress_queue) + import flares + flares.gui_entry(config, gui_queue, progress_queue) sys.exit(0) @@ -1866,4 +3890,6 @@ if __name__ == "__main__": app.setWindowIcon(QIcon(resource_path("icons/main.ico"))) window.setWindowIcon(QIcon(resource_path("icons/main.ico"))) window.show() - sys.exit(app.exec()) \ No newline at end of file + sys.exit(app.exec()) + +# Not 4000 lines yay! \ No newline at end of file
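For orientation, the background-processing handshake added above boils down to: the worker posts ('progress', file_path, step_index) tuples on one queue while running, then a final {"success": ...} dict on another, and the GUI drains both on a timer. A minimal sketch of that protocol without Qt, using a stand-in worker in place of flares.gui_entry:

    import queue
    from multiprocessing import Manager, Process

    # Stand-in for flares.gui_entry: report per-file progress, then a result dict.
    def fake_worker(result_q, progress_q):
        for step in range(3):
            progress_q.put(("progress", "demo.snirf", step))
        result_q.put({"success": True, "result": {}})

    if __name__ == "__main__":
        manager = Manager()
        result_q, progress_q = manager.Queue(), manager.Queue()
        proc = Process(target=fake_worker, args=(result_q, progress_q))
        proc.start()
        proc.join()
        while True:                      # drain, as the QTimer callback does each tick
            try:
                print(progress_q.get_nowait())
            except queue.Empty:
                break
        print(result_q.get_nowait())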