Files
flares/fNIRS_module.py
2025-08-19 13:57:14 -07:00

4107 lines
161 KiB
Python

"""
Filename: fNIRS_module.py
Description: Core functionality for FLARES
Author: Tyler de Zeeuw
License: GPL-3.0
"""
# Built-in imports
import os
import sys
import time
import logging
import platform
import warnings
import threading
from io import BytesIO
from copy import deepcopy
from pathlib import Path
from zipfile import ZipFile
from datetime import datetime
from itertools import compress
from multiprocessing import Queue
from typing import Any, Optional, cast, Literal, Iterator, Union
# External library imports
import pywt # type: ignore
import qtpy # type: ignore
import xlrd # type: ignore
import psutil
import scooby # type: ignore
import requests
import pyvistaqt # type: ignore
import darkdetect # type: ignore
import numpy as np
import pandas as pd
from PIL import Image
import seaborn as sns
import neurokit2 as nk # type: ignore
from tqdm.auto import tqdm
from pandas import DataFrame
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from numpy.typing import NDArray
#import vtkmodules.util.data_model
from numpy import floating, float64
from matplotlib.lines import Line2D
import matplotlib.colors as mcolors
from scipy.stats import ttest_1samp # type: ignore
from matplotlib.figure import Figure
import statsmodels.formula.api as smf # type: ignore
#import vtkmodules.util.execution_model
from nilearn.plotting import plot_design_matrix # type: ignore
from scipy.signal import welch, butter, filtfilt # type: ignore
from matplotlib.colors import LinearSegmentedColormap
from IPython.display import display, Markdown, clear_output # type: ignore
from statsmodels.tools.sm_exceptions import ConvergenceWarning # type: ignore
from concurrent.futures import ProcessPoolExecutor, as_completed
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
# External library imports for mne
import mne
from mne import EvokedArray, Info, read_source_spaces, stc_near_sensors # type: ignore
from mne.source_space import SourceSpaces
from mne.transforms import Transform # type: ignore
from mne.io import BaseRaw, read_raw_snirf # type: ignore
from mne.annotations import Annotations # type: ignore
from mne_nirs.visualisation import plot_glm_group_topo # type: ignore
from mne_nirs.channels import get_long_channels, get_short_channels, picks_pair_to_idx # type: ignore
from mne_nirs.experimental_design import make_first_level_design_matrix # type: ignore
from mne_nirs.statistics import run_glm, statsmodels_to_results # type: ignore
from mne_nirs.signal_enhancement import enhance_negative_correlation, short_channel_regression # type: ignore
from mne.preprocessing.nirs import beer_lambert_law, optical_density, temporal_derivative_distribution_repair, source_detector_distances, short_channels # type: ignore
from mne_nirs.io.fold import fold_channel_specificity # type: ignore
from mne_nirs.preprocessing import peak_power # type: ignore
from mne.viz import Brain
from mne_nirs.statistics._glm_level_first import RegressionResults # type: ignore
from mne.filter import filter_data # type: ignore
# Version string of this module.
CURRENT_VERSION = "1.0.0"
# False unless set_config() was called with gui=True.
GUI = False
# Lower-case OS name from platform.system(); 'darwin' selects the macOS log and image locations.
PLATFORM_NAME = platform.system().lower()
# --- Configuration globals -----------------------------------------------------
# Declared (annotated) here but NOT assigned: set_config() validates a config dict
# against REQUIRED_KEYS and then copies it into the module namespace with
# globals().update(config). Until that has happened, reading any of these raises NameError.
# Folder containing SNIRF_SUBFOLDERS; set_config() stores it as an absolute path.
BASE_SNIRF_FOLDER: str
# Subfolders of BASE_SNIRF_FOLDER; STIM_DURATION must have one entry per subfolder.
SNIRF_SUBFOLDERS: list[str]
STIM_DURATION: list[float]
MAX_WORKERS: int
# Seconds cropped from the start of every recording in load_snirf().
SECONDS_TO_STRIP: int
# When True, load_snirf() resamples the recording to DOWNSAMPLE_FREQUENCY.
DOWNSAMPLE: bool
DOWNSAMPLE_FREQUENCY: int
FORCE_DROP_CHANNELS: list[str]
SOURCE_DETECTOR_SEPARATOR: str
# When True, OPTODE_FILE_PATH must be an existing .txt file; each of its lines is
# split on OPTODE_FILE_SEPARATOR into an optode name and its coordinates.
OPTODE_FILE: bool
OPTODE_FILE_PATH: str
OPTODE_FILE_SEPARATOR: str
TDDR: bool
WAVELET: bool
# Scaling factor k passed to iqr_threshold() by the wavelet denoiser.
IQR: float
HEART_RATE: bool
# Seconds trimmed from both ends of the signal before heart-rate estimation.
SECONDS_TO_STRIP_HR: int
# Lower/upper clip bounds (BPM) applied to the NeuroKit heart-rate series.
MAX_LOW_HR: int
MAX_HIGH_HR: int
# Rolling-window length (in samples) used to smooth the NeuroKit heart-rate series.
SMOOTHING_WINDOW_HR: int
# Half-width (BPM) of the heart-rate band highlighted in the PSD plot.
HEART_RATE_WINDOW: int
SHORT_CHANNEL: bool
SHORT_CHANNEL_THRESH: float
SCI: bool
SCI_TIME_WINDOW: int
SCI_THRESHOLD: float
PSP: bool
PSP_TIME_WINDOW: int
PSP_THRESHOLD: float
# TODO: Implement
SNR: bool
SNR_TIME_WINDOW : int
SNR_THRESHOLD: float
EXCLUDE_CHANNELS: bool
MAX_BAD_CHANNELS: int
LONG_CHANNEL_THRESH: float
METADATA: dict
DRIFT_MODEL: str
DURATION_BETWEEN_ACTIVITIES: int
HRF_MODEL: str
SHORT_CHANNEL_REGRESSION: bool
N_JOBS: int
TARGET_ACTIVITY: str
TARGET_CONTROL: str
ROI_GROUP_1: list[list[int]]
ROI_GROUP_2: list[list[int]]
ROI_GROUP_1_NAME: str
ROI_GROUP_2_NAME: str
P_THRESHOLD: float
SEE_BAD_IMAGES: bool
ABS_T_VALUE: int
ABS_THETA_VALUE: int
ABS_CONTRAST_T_VALUE: int
ABS_CONTRAST_THETA_VALUE: int
ABS_SIGNIFICANCE_T_VALUE: int
ABS_SIGNIFICANCE_THETA_VALUE: int
BRAIN_DISTANCE: float
BRAIN_MODE: str
EPOCH_REJECT_CRITERIA_THRESH: float
TIME_MIN_THRESH: int
TIME_MAX_THRESH: int
# Passed as verbose= to MNE calls; when False, set_config() silences
# ConvergenceWarning and RuntimeWarning.
VERBOSITY: bool
# Optional settings: not checked by set_config() (they are commented out of
# REQUIRED_KEYS), so they stay None unless the config dict supplies them.
# NOTE(review): the annotated types mirror the commented-out REQUIRED_KEYS entries — confirm.
REJECT_PAIRS: Optional[bool] = None
FORCE_DROP_ANNOTATIONS: Optional[list] = None
FILTER_LOW_PASS: Optional[float] = None
FILTER_HIGH_PASS: Optional[float] = None
EPOCH_PAIR_TOLERANCE_WINDOW: Optional[int] = None
# FIXME: Shouldn't need each ordering - just order it before checking
# Fixed colour for each channel-check category label. Every ordering of a
# combination is listed so a label matches regardless of the order in which its
# parts were joined; all orderings of one combination share a colour.
# (A duplicated "PSP + SNR + SCI" entry was removed; insertion order is unchanged.)
FIXED_CATEGORY_COLORS = {
    "SCI only": "skyblue",
    "PSP only": "salmon",
    "SNR only": "lightgreen",
    "PSP + SCI": "orange",
    "SCI + SNR": "violet",
    "PSP + SNR": "gold",
    "SCI + PSP": "orange",
    "SNR + SCI": "violet",
    "SNR + PSP": "gold",
    "PSP + SNR + SCI": "gray",
    "SCI + PSP + SNR": "gray",
    "SCI + SNR + PSP": "gray",
    "PSP + SCI + SNR": "gray",
    "SNR + SCI + PSP": "gray",
    "SNR + PSP + SCI": "gray",
}
# Keys every configuration passed to set_config() must contain, mapped to the type
# its value must be an instance of. set_config() checks them in this order.
REQUIRED_KEYS: dict[str, Any] = dict(
    # Input folders
    BASE_SNIRF_FOLDER=str,
    SNIRF_SUBFOLDERS=list,
    STIM_DURATION=list,
    MAX_WORKERS=int,
    # Loading and preprocessing
    SECONDS_TO_STRIP=int,
    DOWNSAMPLE=bool,
    DOWNSAMPLE_FREQUENCY=int,
    FORCE_DROP_CHANNELS=list,
    SOURCE_DETECTOR_SEPARATOR=str,
    OPTODE_FILE=bool,
    OPTODE_FILE_PATH=str,
    OPTODE_FILE_SEPARATOR=str,
    TDDR=bool,
    WAVELET=bool,
    IQR=float,
    # Heart rate
    HEART_RATE=bool,
    SECONDS_TO_STRIP_HR=int,
    MAX_LOW_HR=int,
    MAX_HIGH_HR=int,
    SMOOTHING_WINDOW_HR=int,
    HEART_RATE_WINDOW=int,
    # Channel checks
    SHORT_CHANNEL=bool,
    SHORT_CHANNEL_THRESH=float,
    SCI=bool,
    SCI_TIME_WINDOW=int,
    SCI_THRESHOLD=float,
    PSP=bool,
    PSP_TIME_WINDOW=int,
    PSP_THRESHOLD=float,
    SNR=bool,
    SNR_TIME_WINDOW=int,
    SNR_THRESHOLD=float,
    EXCLUDE_CHANNELS=bool,
    MAX_BAD_CHANNELS=int,
    LONG_CHANNEL_THRESH=float,
    # Analysis
    METADATA=dict,
    DRIFT_MODEL=str,
    DURATION_BETWEEN_ACTIVITIES=int,
    HRF_MODEL=str,
    SHORT_CHANNEL_REGRESSION=bool,
    N_JOBS=int,
    TARGET_ACTIVITY=str,
    TARGET_CONTROL=str,
    ROI_GROUP_1=list,
    ROI_GROUP_2=list,
    ROI_GROUP_1_NAME=str,
    ROI_GROUP_2_NAME=str,
    P_THRESHOLD=float,
    # Plotting
    SEE_BAD_IMAGES=bool,
    ABS_T_VALUE=int,
    ABS_THETA_VALUE=int,
    ABS_CONTRAST_T_VALUE=int,
    ABS_CONTRAST_THETA_VALUE=int,
    ABS_SIGNIFICANCE_T_VALUE=int,
    ABS_SIGNIFICANCE_THETA_VALUE=int,
    BRAIN_DISTANCE=float,
    BRAIN_MODE=str,
    # Epochs
    EPOCH_REJECT_CRITERIA_THRESH=float,
    TIME_MIN_THRESH=int,
    TIME_MAX_THRESH=int,
    # Output
    VERBOSITY=bool,
    # Optional keys (not validated; copied into the globals only if present):
    # REJECT_PAIRS=bool,
    # FORCE_DROP_ANNOTATIONS=list,
    # FILTER_LOW_PASS=float,
    # FILTER_HIGH_PASS=float,
    # EPOCH_PAIR_TOLERANCE_WINDOW=int,
)
# Work from this file's directory so relative paths (such as the log file below)
# resolve the same way no matter where the process was started.
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)

# Append timestamped records to a log file. On macOS the file lives three levels
# above the executable's directory (NOTE(review): presumably just outside a .app
# bundle — confirm); elsewhere it is created in the working directory set above.
logging.basicConfig(
    filename=(
        os.path.join(os.path.dirname(sys.executable), "../../../fnirs_analysis.log")
        if PLATFORM_NAME == 'darwin'
        else 'fnirs_analysis.log'
    ),
    level=logging.INFO,
    format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    filemode='a',
)
# Root logger, shared by every function in this module.
logger = logging.getLogger()
class ProcessingError(Exception):
    """
    Fatal processing failure.

    run_groups re-raises this error instead of returning partial results.

    Attributes
    ----------
    message : str
        Human-readable description; also the exception's only argument.
    """

    def __init__(self, message: str = "Something went wrong!"):
        super().__init__(message)
        self.message = message
def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -> None:
    """
    GUI entry point: apply the configuration, run the analysis and report the outcome.

    While the analysis runs, a background thread forwards every message placed on
    ``progress_queue`` to ``gui_queue``. It stops when it sees the ``"__done__"``
    sentinel, which is always sent once run_groups finishes, whether it returned
    or raised. After that, exactly one final message is put on ``gui_queue``:
    ``{"success": True, "result": ...}`` or
    ``{"success": False, "error": ..., "traceback": ...}``.

    Parameters
    ----------
    config : dict[str, Any]
        Configuration dictionary passed to set_config and run_groups.
    gui_queue : Queue
        Queue read by the GUI. It receives the forwarded progress messages and the final result.
    progress_queue : Queue
        Queue the analysis writes its progress messages to.
    """
    import queue
    import traceback

    def forward_progress() -> None:
        # The 1 s timeout only lets the loop wake up periodically; an empty queue is
        # expected and not an error. Any other failure (e.g. gui_queue.put failing)
        # ends the thread instead of being silently retried forever.
        while True:
            try:
                msg = progress_queue.get(timeout=1)
            except queue.Empty:
                continue
            if isinstance(msg, str) and msg == "__done__":
                break
            gui_queue.put(msg)

    try:
        logger.info("setting config")
        set_config(config, True)

        forwarder = threading.Thread(target=forward_progress, daemon=True)
        forwarder.start()
        try:
            logger.info("actual call")
            result = run_groups(config, True, progress_queue=progress_queue)
        finally:
            # Always stop the forwarder, even when run_groups raises. This avoids leaking a
            # spinning thread and guarantees that every progress message reaches the GUI
            # before the final success/failure message.
            progress_queue.put("__done__")
            forwarder.join()

        gui_queue.put({"success": True, "result": result})
    except Exception as e:
        gui_queue.put({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc(),
        })
def set_config(config: dict[str, Any], gui: bool = False) -> None:
    """
    Validates the given configuration dictionary and applies it as module-level globals.

    Every key in REQUIRED_KEYS must be present with a value of the expected type.
    An int is also accepted where a float is expected, but a bool is not. The
    folder and file checks run against ``config`` itself *before* any global is
    modified, so a rejected configuration does not leave a half-applied state
    behind; only the GUI flag may already have been set. On success, every key of
    ``config`` is copied into the module namespace, including optional keys that
    are not in REQUIRED_KEYS. BASE_SNIRF_FOLDER is then replaced by its absolute path.

    Parameters
    ----------
    config : dict[str, Any]
        Dictionary containing configuration keys and their values.
    gui : bool, optional
        If True, set the module-level GUI flag to True. False (the default) leaves
        the flag unchanged.

    Raises
    ------
    KeyError
        If a required key is missing.
    TypeError
        If a required key has a value of the wrong type.
    AssertionError
        If BASE_SNIRF_FOLDER or one of SNIRF_SUBFOLDERS is not a directory, if the
        number of subfolders differs from the number of stim durations, or if
        OPTODE_FILE is enabled but OPTODE_FILE_PATH is not an existing .txt file.
    """
    if gui:
        globals()["GUI"] = True

    # Ensure all required keys are present and correctly typed
    for key, expected_type in REQUIRED_KEYS.items():
        if key not in config:
            raise KeyError(f"Missing config key: {key}")
        value = config[key]
        if isinstance(value, expected_type):
            continue
        # Whole numbers are valid for float settings (e.g. "IQR": 2). Bools are
        # excluded because bool is a subclass of int.
        if expected_type is float and isinstance(value, int) and not isinstance(value, bool):
            continue
        raise TypeError(f"Key '{key}' has incorrect type. Expected {expected_type.__name__}, got {type(value).__name__}")

    # Ensure the referenced folders/files actually exist. These checks use explicit
    # raises rather than `assert` so they still run under `python -O`. AssertionError
    # is kept so callers that already catch it are unaffected.
    base_folder = config["BASE_SNIRF_FOLDER"]
    subfolders = config["SNIRF_SUBFOLDERS"]
    stim_durations = config["STIM_DURATION"]
    if not Path(base_folder).is_dir():
        raise AssertionError("BASE_SNIRF_FOLDER was not found. Please check the folder location and try again.")
    for folder in subfolders:
        if not Path(os.path.join(base_folder, folder)).is_dir():
            raise AssertionError(f"The subfolder {folder} could not be found. Please check the folder location and try again.")
    if len(subfolders) != len(stim_durations):
        raise AssertionError(f"The amount of subfolders do not match the amount of stim durations. Subfolders: {len(subfolders)} Stim durations: {len(stim_durations)}")
    if config["OPTODE_FILE"]:
        optode_path = Path(config["OPTODE_FILE_PATH"])
        if not optode_path.is_file():
            raise AssertionError("OPTODE_FILE was specified, but OPTODE_FILE_PATH is not a file.")
        if optode_path.suffix != ".txt":
            raise AssertionError("OPTODE_FILE_PATH does not end with a .txt extension.")

    # Apply the configuration. BASE_SNIRF_FOLDER is stored as an absolute path,
    # which makes later log messages unambiguous.
    globals().update(config)
    globals()["BASE_SNIRF_FOLDER"] = str(Path(base_folder).resolve())

    # When not verbose, silence these warning categories process-wide
    if not config["VERBOSITY"]:
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        warnings.filterwarnings("ignore", category=RuntimeWarning)

    logger.info("[Config] Configuration successfully set.")
def run_groups(config: dict[str, Any], gui: bool = False, progress_queue=None) -> tuple[dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], dict[str, dict[str, BaseRaw]], dict[str, list[Figure]], dict[str, str], float]:
    """
    Process every configured subfolder and aggregate results, haemoglobin data, figures, and processing details.

    Each folder in SNIRF_SUBFOLDERS, paired with its entry in STIM_DURATION, is
    handed to process_folder. A ProcessingError is re-raised unchanged and nothing
    is returned. Any other exception is logged and stops processing, and the
    results gathered so far are returned.

    Parameters
    ----------
    config : dict[str, Any]
        Configuration dictionary, forwarded to process_folder.
    gui : bool, optional
        Forwarded to process_folder (default False).
    progress_queue : optional
        Forwarded to process_folder as ``progress_queue`` (default None).

    Returns
    -------
    tuple
        - dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]: the df_roi, df_cha, df_con and
          df_design_matrix dataframes returned by process_folder, keyed by folder name.
        - dict[str, dict[str, BaseRaw]]: Raw haemoglobin data indexed by file ID (merged across folders).
        - dict[str, list[Figure]]: Figures from every folder grouped by processing step.
        - dict[str, str]: Processing status messages indexed by file ID.
        - float: Elapsed wall-clock time in seconds.

    Raises
    ------
    ProcessingError
        Propagated from process_folder; no partial results are returned in this case.
    """
    # Create dictionaries to store our results
    all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]] = {}
    all_figures: dict[str, list[Figure]] = {}
    all_raw_haemo: dict[str, dict[str, BaseRaw]] = {}
    all_processes: dict[str, str] = {}
    start_time = time.time()

    # Count the files to be processed across all folders.
    # NOTE(review): this counts every regular file, not only .snirf files — confirm it matches what process_folder consumes.
    logger.info("Calculating how many files there are...")
    total_files = 0
    for folder in SNIRF_SUBFOLDERS:
        full_path = os.path.join(BASE_SNIRF_FOLDER, folder)
        total_files += sum(1 for name in os.listdir(full_path) if os.path.isfile(os.path.join(full_path, name)))
    logger.info(f"Total of {total_files} files.")
    # Mutable counter shared with process_folder (presumably decremented as files finish — verify there)
    files_remaining = {'count': total_files}

    for folder, stim_duration in zip(SNIRF_SUBFOLDERS, STIM_DURATION):
        full_path = os.path.join(BASE_SNIRF_FOLDER, folder)
        try:
            # Process all participants in the folder
            logger.info(f"Processing all files in {folder}...")
            raw_haemo, df_roi, df_cha, df_con, df_design_matrix, figures, process = process_folder(full_path, stim_duration, files_remaining, config, gui, progress_queue=progress_queue)

            logger.info(f"Storing the results from the {folder} folder...")
            all_results[folder] = (df_roi, df_cha, df_con, df_design_matrix)
            logger.info("Applied all results.")

            # The remaining merges are best-effort. If a folder returns something unusable
            # (e.g. None), the problem is logged and the other outputs are still stored.
            try:
                for step, fig_list in figures.items():
                    all_figures.setdefault(step, []).extend(fig_list)
                logger.info("Applied all figures.")
            except (AttributeError, TypeError) as e:
                logger.warning(f"Could not store the figures from the {folder} folder: {e}")
            try:
                all_raw_haemo.update(raw_haemo)
                logger.info("Applied all haemo.")
            except (AttributeError, TypeError, ValueError) as e:
                logger.warning(f"Could not store the haemoglobin data from the {folder} folder: {e}")
            try:
                all_processes.update(process)
                logger.info("Applied all processes.")
            except (AttributeError, TypeError, ValueError) as e:
                logger.warning(f"Could not store the processing details from the {folder} folder: {e}")
        except ProcessingError as e:
            logger.exception(f"Something happened! {e}")
            # Something really bad happened. No partial return; keep the original type and traceback.
            raise
        except Exception as e:
            logger.exception(f"Something happened! {e}")
            # Still return a partial analysis even if something goes wrong
            return all_results, all_raw_haemo, all_figures, all_processes, time.time() - start_time

    return all_results, all_raw_haemo, all_figures, all_processes, time.time() - start_time
def create_image_montage(images: list[Image.Image], cols: int) -> Optional[Image.Image]:
    """
    Creates a grid montage image from a list of PIL Images.

    Every grid cell is sized to the largest image width and height. Images are
    pasted left-to-right, then top-to-bottom, at the top-left corner of their cell
    on an opaque white RGBA canvas.

    Parameters
    ----------
    images : list[Image.Image]
        Images to arrange in the montage.
    cols : int
        Number of columns in the montage grid. A value below 1 (e.g. 0) would
        otherwise divide by zero, so it is treated as a single row holding every image.

    Returns
    -------
    Optional[Image.Image]
        The combined montage image, or None if ``images`` is empty.
    """
    # Verify that we have images to process
    if not images:
        return None
    if cols < 1:
        cols = len(images)

    # Calculate the cell size and the number of rows (ceiling division)
    logger.info("Calculating the montage parameters...")
    widths, heights = zip(*(image.size for image in images))
    max_width = max(widths)
    max_height = max(heights)
    rows = (len(images) + cols - 1) // cols

    # Create the montage image
    logger.info("Creating the montage...")
    montage = Image.new('RGBA', (cols * max_width, rows * max_height), (255, 255, 255, 255))
    for idx, image in enumerate(images):
        row, col = divmod(idx, cols)
        montage.paste(image, (col * max_width, row * max_height))  # type: ignore
    return montage
def show_all_images(figures: dict[str, list[bytes]], inline: bool = False) -> None:
    """
    Displays one montage per figure category, either inline or in separate windows.

    Parameters
    ----------
    figures : dict[str, list[bytes]]
        Encoded images (any format PIL can open) grouped by category. Each entry is
        wrapped in a BytesIO and decoded to RGB. Entries that cannot be decoded are
        logged and skipped.
    inline : bool, optional
        If True, display images inline (e.g., in Jupyter notebooks). Otherwise, opens them in separate windows (default is False).
    """
    if inline:
        logger.info("Inline was selected.")
    else:
        logger.info("Inline was not selected.")

    # Use as many columns as there are 'Raw' figures, capped at 4. Without any 'Raw'
    # figures this would be 0 (a division by zero in the montage), so each category
    # then falls back to its own image count, also capped at 4.
    logger.info("Calculating columns...")
    raw_cols = min(len(figures.get('Raw', [])), 4)

    logger.info("Generating images...")
    for fig_bytes_list in figures.values():
        pil_images = []
        for b in fig_bytes_list:
            try:
                pil_images.append(Image.open(BytesIO(b)).convert("RGB"))
            except Exception as e:
                # Best effort: one undecodable entry must not hide the rest
                logger.warning(f"Could not open image from bytes: {e}")
        cols = raw_cols or min(len(pil_images), 4)
        montage = create_image_montage(pil_images, cols)
        if montage is None:
            continue
        # Determine how to display the images to the user
        if inline:
            display(montage)
        else:
            montage.show()
def save_all_images(figures: dict[str, list[bytes]]) -> None:
    """
    Saves one montage per figure category as a timestamped PNG in an 'images' folder.

    On macOS ('darwin') the folder is ``<executable directory>/../../../images``.
    Elsewhere it is ``<current working directory>/images``. The folder is created
    if needed, and files are named ``<category>_<YYYYmmdd_HHMMSS>.png``.

    Parameters
    ----------
    figures : dict[str, list[bytes]]
        Encoded images (any format PIL can open) grouped by category. Entries that
        cannot be decoded are logged and skipped.
    """
    # Work out where the images go and make sure the folder exists
    logger.info("Getting the current directory...")
    if PLATFORM_NAME == 'darwin':
        images_folder = os.path.join(os.path.dirname(sys.executable), "../../../images")
    else:
        images_folder = os.path.join(os.getcwd(), "images")
    logger.info("Attempting to create the images folder...")
    os.makedirs(images_folder, exist_ok=True)

    # Generate a timestamp to be appended to the end of the file name
    logger.info("Generating the timestamp...")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Use as many columns as there are 'Raw' figures, capped at 4. Without any 'Raw'
    # figures this would be 0 (a division by zero in the montage), so each category
    # then falls back to its own image count, also capped at 4.
    logger.info("Calculating columns...")
    raw_cols = min(len(figures.get('Raw', [])), 4)

    logger.info("Generating images...")
    for step, fig_bytes_list in figures.items():
        pil_images = []
        for b in fig_bytes_list:
            try:
                pil_images.append(Image.open(BytesIO(b)).convert("RGB"))
            except Exception as e:
                # Best effort: one undecodable entry must not hide the rest
                logger.warning(f"Could not open image from bytes: {e}")
        cols = raw_cols or min(len(pil_images), 4)
        montage = create_image_montage(pil_images, cols)
        if montage is None:
            continue
        save_path = os.path.join(images_folder, f"{step}_{timestamp}.png")
        montage.save(save_path)  # type: ignore
        logger.info(f"Saved image to {save_path}")

    logger.info(f"All images have been saved to '{images_folder}'.")
def load_snirf(file_path: str, ID: str, drop_prefixes: list[str]) -> tuple[BaseRaw, Figure]:
    """
    Loads a snirf file, optionally drops channels, downsamples, and creates a figure showing the results.

    Processing steps:

    1. Crop the first SECONDS_TO_STRIP seconds. This is skipped when the recording
       is not longer than that.
    2. Drop every channel whose name starts with one of ``drop_prefixes``.
    3. Resample to DOWNSAMPLE_FREQUENCY when DOWNSAMPLE is enabled.

    Parameters
    ----------
    file_path : str
        Path of the snirf file to load.
    ID : str
        File name of the snirf file that was loaded (used in logs and the figure title).
    drop_prefixes : list[str]
        List of channel name prefixes to drop from the data.

    Returns
    -------
    tuple[BaseRaw, Figure]
        - BaseRaw: The processed data object.
        - Figure: The corresponding Matplotlib figure (already closed in pyplot).
    """
    logger.info(f"Loading the snirf file ({ID})...")
    # Read the snirf file
    raw = read_raw_snirf(file_path, verbose=VERBOSITY)  # type: ignore
    raw.load_data(verbose=VERBOSITY)  # type: ignore

    # Strip the specified amount of seconds from the start of the file
    total_duration = getattr(raw, "times")[-1]
    if total_duration > SECONDS_TO_STRIP:
        raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY)  # type: ignore
        logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.")
    else:
        logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.")

    # If the user forcibly dropped channels, remove them now before any processing occurs
    logger.info("Checking if there are channels to forcibly drop...")
    if drop_prefixes:
        logger.info("Force dropped channels was specified.")
        prefixes = tuple(drop_prefixes)
        channels_to_drop = [ch for ch in cast(list[str], getattr(raw, "ch_names")) if ch.startswith(prefixes)]
        raw.drop_channels(channels_to_drop, on_missing="raise")  # type: ignore
        # The %s placeholder is required: an extra argument without one makes the
        # logging module report a formatting error instead of logging the list.
        logger.info("Force dropped channels: %s", channels_to_drop)

    # If the user wants to downsample, do it right away
    logger.info("Checking if we should downsample...")
    if DOWNSAMPLE:
        logger.info("Downsample was specified.")
        sfreq_old = getattr(raw, "info")["sfreq"]
        raw.resample(DOWNSAMPLE_FREQUENCY, verbose=VERBOSITY)  # type: ignore
        sfreq_new = getattr(raw, "info")["sfreq"]
        logger.info(f"Finished downsampling. Old frequency: {sfreq_old}. New frequency: {sfreq_new}.")

    # Create a figure for the results (all channels, whole recording)
    logger.info("Creating the figure...")
    fig = cast(Figure, raw.plot(show=False, n_channels=len(getattr(raw, "ch_names")), duration=raw.times[-1]).figure)  # type: ignore
    fig.suptitle(f"Raw fNIRS Data for {ID}", fontsize=16)  # type: ignore
    fig.subplots_adjust(top=0.92)
    plt.close(fig)

    logger.info("Successfully loaded the snirf file.")
    return raw, fig
def calculate_and_apply_updated_optode_coordinates(data: BaseRaw) -> BaseRaw:
    """
    Update optode coordinates on the given MNE Raw data using the configured optode file.

    Each non-blank line of OPTODE_FILE_PATH has the form
    ``<name><OPTODE_FILE_SEPARATOR><x> <y> <z>``. Only the first separator is
    significant. Coordinates are multiplied by 0.001, i.e. read as millimetres and
    stored in metres. The names 'lpa', 'nz' and 'rpa' (case-insensitive) are used
    as fiducials. Every other name is upper-cased and used as a source/detector
    position. The resulting montage is applied in the 'head' coordinate frame.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process with new optode coordinates.

    Returns
    -------
    BaseRaw
        The processed data object with the updated montage applied.

    Raises
    ------
    ValueError
        If a line has no separator, contains a non-numeric coordinate, or does not
        contain exactly three coordinates.
    """
    logger.info("Updating optode coordinates...")
    fiducials: dict[str, NDArray[floating[Any]]] = {}
    ch_positions: dict[str, NDArray[floating[Any]]] = {}

    # Read the lines from the optode file
    logger.info(f"Reading optode file from {OPTODE_FILE_PATH}")
    with open(OPTODE_FILE_PATH, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            # Split on the first separator only. This keeps a whitespace separator, or
            # one that appears again later on the line, from breaking the name/coordinate split.
            name_part, found, coords_str = stripped.partition(OPTODE_FILE_SEPARATOR)
            if not found:
                raise ValueError(f"Line {line_number} of the optode file has no '{OPTODE_FILE_SEPARATOR}' separator: {stripped!r}")
            ch_name = name_part.strip()
            # Convert from millimetres to metres
            coords = np.array([float(value) for value in coords_str.split()]) * 0.001
            if coords.shape != (3,):
                raise ValueError(f"Line {line_number} of the optode file must contain exactly 3 coordinates, got {coords.size}: {stripped!r}")
            logger.info(f"Read line: {ch_name} with coords (m): {coords}")
            if ch_name.lower() in ('lpa', 'nz', 'rpa'):
                # The key we have is a fiducial
                fiducials[ch_name.lower()] = coords
            else:
                # The key we have is a source or detector
                ch_positions[ch_name.upper()] = coords

    # Create montage with updated coords in head space
    logger.info("Creating and applying the montage...")
    initial_montage = mne.channels.make_dig_montage(ch_pos=ch_positions, nasion=fiducials.get('nz'), lpa=fiducials.get('lpa'), rpa=fiducials.get('rpa'), coord_frame='head')  # type: ignore
    data.set_montage(initial_montage, verbose=VERBOSITY)  # type: ignore
    logger.info("Successfully updated optode coordinates.")
    return data
def calculate_and_apply_tddr(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]:
    """
    Run MNE's temporal derivative distribution repair (TDDR) and plot the corrected signal.

    Parameters
    ----------
    data : BaseRaw
        Recording to correct.
    ID : str
        Name of the source snirf file, shown in the figure title.

    Returns
    -------
    tuple[BaseRaw, Figure]
        - BaseRaw: The TDDR-corrected data returned by MNE.
        - Figure: A plot of every channel over the full recording, already closed in pyplot.
    """
    logger.info("Applying temporal derivative distribution repair...")
    corrected = cast(BaseRaw, temporal_derivative_distribution_repair(data, verbose=VERBOSITY))

    logger.info("Creating the figure...")
    channel_count = len(getattr(data, "ch_names"))
    browser = corrected.plot(show=False, n_channels=channel_count, duration=data.times[-1])  # type: ignore
    fig = cast(Figure, browser.figure)
    fig.suptitle(f"TDDR for {ID}", fontsize=16)  # type: ignore
    fig.subplots_adjust(top=0.92)
    plt.close(fig)

    logger.info("Successfully applied temporal derivative distribution repair.")
    return corrected, fig
def iqr_threshold(coeffs: NDArray[float64], k: float = 1.5) -> floating[Any]:
    """
    Return ``k`` times the interquartile range of ``coeffs``.

    The interquartile range is the 75th percentile minus the 25th percentile,
    computed by ``np.percentile`` over the flattened array.

    Parameters
    ----------
    coeffs : NDArray[float64]
        Values to measure the spread of.
    k : float, optional
        Multiplier applied to the interquartile range (default 1.5).

    Returns
    -------
    floating[Any]
        ``k * (Q3 - Q1)``.
    """
    lower_quartile, upper_quartile = np.percentile(coeffs, [25, 75])
    spread = upper_quartile - lower_quartile
    return k * spread
def wavelet_iqr_denoise(signal: NDArray[float64], wavelet: str = 'db4', level: int = 3) -> NDArray[float64]:
    """
    Denoise a 1-D signal by soft-thresholding its wavelet detail coefficients.

    The signal is decomposed with ``pywt.wavedec``. The approximation band is kept
    unchanged. Each detail band is shrunk toward zero by its own threshold,
    ``iqr_threshold(band, IQR)``, where IQR is the module-level config value. The
    signal is then rebuilt with ``pywt.waverec`` and cut to the input length,
    because the reconstruction may be longer than the input.

    Parameters
    ----------
    signal : NDArray[float64]
        The input signal array to denoise.
    wavelet : str, optional
        The type of wavelet to use for decomposition (default is 'db4').
    level : int, optional
        Decomposition level for wavelet transform (default is 3).

    Returns
    -------
    NDArray[float64]
        The denoised signal, with the same length as the input.
    """
    approximation, *details = cast(list[NDArray[float64]], pywt.wavedec(signal, wavelet, level=level))  # type: ignore

    def soft_threshold(band: NDArray[float64]) -> NDArray[float64]:
        # Pull every coefficient toward zero by `limit`, clamping at zero.
        limit = iqr_threshold(band, IQR)
        return (np.sign(band) * np.maximum(np.abs(band) - limit, 0.0)).astype(float64)

    reconstructed = cast(NDArray[float64], pywt.waverec([approximation, *map(soft_threshold, details)], wavelet))  # type: ignore
    return reconstructed[:len(signal)]
def calculate_and_apply_wavelet(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]:
    """
    Applies a wavelet IQR denoising filter to every channel and generates a plot.

    Each channel row of ``data.get_data()`` is denoised independently with
    ``wavelet_iqr_denoise`` ('db4', level 3). The result is wrapped in a new
    RawArray that reuses ``data.info`` and keeps ``data.first_samp``. It also
    receives a copy of ``data.annotations``, so events stay aligned with the samples.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    ID : str
        File name of the snirf file that was loaded (used in the figure title).

    Returns
    -------
    tuple[BaseRaw, Figure]
        - BaseRaw: The denoised data object.
        - Figure: The corresponding Matplotlib figure (already closed in pyplot).
    """
    logger.info("Applying the wavelet filter...")

    # Denoise the data channel by channel
    logger.info("Denoising the data...")
    loaded_data: NDArray[float64] = data.get_data(verbose=VERBOSITY)  # type: ignore
    denoised_data = np.zeros_like(loaded_data)
    logger.info("Calculating the IQR, decomposing the signal, and thresholding the coefficients...")
    for ch, channel_signal in enumerate(loaded_data):
        denoised_data[ch, :] = wavelet_iqr_denoise(channel_signal, wavelet='db4', level=3)

    # Reconstruct the data with the annotations. first_samp must be carried over: a
    # cropped recording has a non-zero first_samp. When meas_date is set, annotation
    # onsets are measured from it, so a RawArray starting at sample 0 would shift the
    # copied annotations relative to the data.
    logger.info("Reconstructing the data with annotations...")
    raw_with_tddr_and_wavelet = mne.io.RawArray(denoised_data, cast(mne.Info, data.info), first_samp=data.first_samp, verbose=VERBOSITY)
    raw_with_tddr_and_wavelet.set_annotations(data.annotations.copy(), verbose=VERBOSITY)  # type: ignore

    # Create a figure for the results
    logger.info("Creating the figure...")
    fig = cast(Figure, raw_with_tddr_and_wavelet.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=data.times[-1]).figure)  # type: ignore
    fig.suptitle(f"Wavelet for {ID}", fontsize=16)  # type: ignore
    fig.subplots_adjust(top=0.92)
    plt.close(fig)

    logger.info("Successfully applied the wavelet filter.")
    return raw_with_tddr_and_wavelet, fig
def short_channel_processing_for_hr(data: BaseRaw, short_chans: BaseRaw | None) -> tuple[float, NDArray[float64], NDArray[float64]]:
    """
    Extract and trim a single fNIRS channel for heart rate analysis.

    The first channel of ``short_chans`` is used when it is given; otherwise the
    first channel of ``data`` is used. SECONDS_TO_STRIP_HR seconds are removed from
    both the start and the end of that signal and of ``data.times``.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object; supplies the sampling frequency and time axis.
    short_chans : BaseRaw | None
        Data object with only short separation channels, or None if unavailable.

    Returns
    -------
    tuple[float, NDArray[float64], NDArray[float64]]
        - float: Sampling frequency of ``data`` in Hz.
        - NDArray[float64]: Trimmed signal (empty if nothing remains after trimming).
        - NDArray[float64]: Corresponding trimmed time values.
    """
    logger.info("Extracting the signal and calculating the sampling frequency...")
    # If a short channel exists, use it for our signal. Otherwise just take the first channel in the data
    # TODO: Find a better way around this
    source = short_chans if short_chans is not None else data
    signal = cast(NDArray[float64], source.get_data(picks=[0], verbose=VERBOSITY))[0]  # type: ignore

    sfreq = float(data.info['sfreq'])

    # Trim start and end of the signal to remove edge artifacts
    logger.info(f"Removing {SECONDS_TO_STRIP_HR} seconds from the beginning and end of the file...")
    strip_samples = int(sfreq * SECONDS_TO_STRIP_HR)

    def trim(values: NDArray[float64]) -> NDArray[float64]:
        # Explicit stop index: `values[n:-n]` is EMPTY when n == 0, which would
        # throw away the whole signal when SECONDS_TO_STRIP_HR is 0.
        return values[strip_samples:len(values) - strip_samples]

    signal_trimmed = trim(signal)
    times_trimmed = trim(cast(NDArray[float64], getattr(data, "times")))
    if signal_trimmed.size == 0:
        logger.warning(f"No samples remain after removing {SECONDS_TO_STRIP_HR} seconds from each end of the signal.")
    return sfreq, signal_trimmed, times_trimmed
def calculate_heart_rate_neurokit(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[float64], float]:
    """
    Estimate a smoothed heart-rate series with NeuroKit2 and return it together with its mean.

    Processing steps:

    1. Band-pass filter the signal to 0.8-2.5 Hz and detect peaks.
    2. Convert the peaks to an instantaneous rate with the same length as the input.
    3. Clip the rate to [MAX_LOW_HR, MAX_HIGH_HR].
    4. Replace every sample lying more than 10 BPM above the centred rolling median
       (window SMOOTHING_WINDOW_HR) with the centred rolling mean.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of the signal.
    signal_trimmed : NDArray[float64]
        Preprocessed and trimmed fNIRS signal.

    Returns
    -------
    tuple[NDArray[float64], float]
        - NDArray[float64]: Smoothed heart rate time series (BPM).
        - float: Mean of the smoothed series.
    """
    logger.info("Calculating heart rate using NeuroKit...")

    logger.info("Filtering the signal and detecting peaks...")
    band_passed = cast(NDArray[float64], nk.signal_filter(signal_trimmed, sampling_rate=sfreq, lowcut=0.8, highcut=2.5))  # type: ignore
    peak_indices = cast(dict[str, Any], nk.signal_findpeaks(band_passed))['Peaks']  # type: ignore
    instantaneous_hr = cast(NDArray[float64], nk.signal_rate(peak_indices, sampling_rate=sfreq, desired_length=len(signal_trimmed)))  # type: ignore
    hr_clean = np.clip(instantaneous_hr, MAX_LOW_HR, MAX_HIGH_HR)

    logger.info("Smoothing the signal and calculating the mean...")
    series = pd.Series(hr_clean)
    window = series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1)
    is_spike = series > (window.median() + 10)
    hr_smooth_nk = cast(NDArray[float64], series.mask(is_spike, window.mean()).to_numpy())  # type: ignore
    mean_hr_nk = hr_smooth_nk.mean()

    logger.info("Original HR min/max: %f, %f", hr_clean.min(), hr_clean.max())
    logger.info("Smoothed HR min/max:%f, %f", hr_smooth_nk.min(), hr_smooth_nk.max())
    logger.info(f"Estimated mean HR nk: {mean_hr_nk:.1f} BPM")
    logger.info("Successfully calculated heart rate using NeuroKit.")
    return hr_smooth_nk, mean_hr_nk
def calculate_heart_rate_scipy(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]:
    """
    Estimate heart rate using spectral analysis on a high-pass filtered signal.

    Processing steps:

    1. Remove slow trends with a 2nd-order Butterworth high-pass at 0.5 Hz (applied with filtfilt).
    2. Compute Welch's PSD using segments of at most 4096 samples with 50 % overlap.
    3. Return the frequency of the PSD peak inside the 30-300 BPM range.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of the input signal.
    signal_trimmed : NDArray[float64]
        Trimmed fNIRS signal to analyze.

    Returns
    -------
    tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]
        - NDArray[floating[Any]]: Frequencies converted to beats per minute (BPM).
        - NDArray[float64]: Power spectral density (PSD) of the signal.
        - np.ndarray[Any, np.dtype[np.bool_]]: Boolean mask of frequencies strictly between 30 and 300 BPM.
        - float: Heart rate in BPM at the PSD peak within that range.

    Raises
    ------
    ValueError
        If ``sfreq`` is too low for a 0.5 Hz high-pass filter (Nyquist must exceed
        0.5 Hz), or if no PSD frequency falls inside the 30-300 BPM range.
    """
    logger.info("Calculating heart rate using SciPy...")

    # Apply a high-pass Butterworth filter to remove slow trends below 0.5 Hz
    logger.info("Applying a butterworth filter...")
    nyquist = sfreq / 2
    if nyquist <= 0.5:
        raise ValueError(f"A sampling frequency of {sfreq} Hz is too low for the 0.5 Hz high-pass filter; it must be above 1 Hz.")
    b, a = cast(tuple[NDArray[float64], NDArray[float64]], butter(2, 0.5 / nyquist, btype='high'))
    signal_hp = cast(NDArray[float64], filtfilt(b, a, signal_trimmed))

    # Calculate the Power Spectral Density (PSD) of the filtered signal using Welch's method
    logger.info("Calculating the PSD...")
    nperseg = min(len(signal_hp), 4096)
    frequencies_scipy, psd_scipy = cast(tuple[NDArray[float64], NDArray[float64]], welch(signal_hp, fs=sfreq, nperseg=nperseg, noverlap=nperseg // 2))

    # Convert frequency values to beats per minute (BPM) and set a heart rate range (30-300 BPM)
    logger.info("Converting to BPM...")
    freq_bpm_scipy = frequencies_scipy * 60
    freq_range_scipy = (freq_bpm_scipy > 30) & (freq_bpm_scipy < 300)
    if not np.any(freq_range_scipy):
        # Without this check np.argmax would fail with an opaque "empty sequence" error
        raise ValueError("The PSD contains no frequencies between 30 and 300 BPM; the signal is too short or the sampling frequency too low to estimate heart rate.")

    # Identify the peak frequency within the heart rate range
    logger.info("Finding the mean...")
    freqs_in_range = freq_bpm_scipy[freq_range_scipy]
    peak_index = np.argmax(psd_scipy[freq_range_scipy])
    mean_hr_scipy = freqs_in_range[peak_index]

    logger.info("Successfully calculated heart rate using SciPy.")
    return freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy
def plot_heart_rate(
    freq_bpm_scipy: NDArray[floating[Any]],
    psd_scipy: NDArray[float64],
    freq_range_scipy: np.ndarray[Any, np.dtype[np.bool_]],
    mean_hr_scipy: float,
    hr_smooth_nk: NDArray[floating[Any]],
    mean_hr_nk: float,
    times_trimmed: NDArray[floating[Any]],
    overruled: bool
) -> tuple[Figure, Figure]:
    """
    Build the two heart-rate diagnostic figures (SciPy PSD view and NeuroKit2 time series).

    Parameters
    ----------
    freq_bpm_scipy : NDArray[floating[Any]]
        PSD frequency axis expressed in beats per minute.
    psd_scipy : NDArray[float64]
        PSD values aligned with ``freq_bpm_scipy``.
    freq_range_scipy : np.ndarray[Any, np.dtype[np.bool_]]
        Mask selecting the heart-rate band of the PSD that gets plotted.
    mean_hr_scipy : float
        Heart rate taken from the PSD peak.
    hr_smooth_nk : NDArray[floating[Any]]
        Smoothed instantaneous heart rate from NeuroKit2.
    mean_hr_nk : float
        Mean heart rate from NeuroKit2.
    times_trimmed : NDArray[floating[Any]]
        Time stamps (seconds) matching ``hr_smooth_nk``.
    overruled : bool
        When True, a note is added to the PSD legend saying the SciPy value was
        replaced by the NeuroKit2 value.

    Returns
    -------
    tuple[Figure, Figure]
        - The PSD figure with the SciPy peak and the shaded heart-rate window.
        - The time-series figure comparing both heart-rate estimates.
    """
    # ---- Figure 1: in-band PSD, peak marker and shaded window ----
    logger.info("Creating the figure...")
    psd_fig, psd_ax = plt.subplots(figsize=(10, 5)) # type: ignore
    psd_ax.set_xlim(30, 300)
    in_band_bpm = freq_bpm_scipy[freq_range_scipy]
    in_band_psd = psd_scipy[freq_range_scipy]
    psd_ax.plot(in_band_bpm, in_band_psd) # type: ignore
    psd_ax.axvline(x=mean_hr_scipy, color='red', linestyle='--', label=f'Mean HR: {mean_hr_scipy:.1f} BPM') # type: ignore
    # The yellow band spans +/- HEART_RATE_WINDOW around both estimates
    span_low = min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW)
    span_high = max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW)
    psd_ax.axvspan(span_low, span_high, color='yellow', alpha=0.3, label=f'HR Range ±{HEART_RATE_WINDOW} BPM') # type: ignore
    psd_ax.set_xlabel('Heart Rate (BPM)') # type: ignore
    psd_ax.set_ylabel('Power Spectral Density') # type: ignore
    psd_ax.set_title('PSD of fNIRS signal - Peak indicates Heart Rate') # type: ignore
    psd_ax.grid(True) # type: ignore
    if not overruled:
        psd_ax.legend() # type: ignore
    else:
        # An invisible line whose label carries the warning text into the legend
        warning_entry = Line2D(
            [0], [0], color='none',
            label='\nNote: Calculation was bad!\nData has been set to match\nthe value from NeuroKit2.',
        )
        existing_handles = psd_ax.get_legend_handles_labels()[0]
        psd_ax.legend(handles=[*existing_handles, warning_entry]) # type: ignore
    plt.close(psd_fig)

    # ---- Figure 2: NeuroKit2 instantaneous HR with both mean estimates ----
    logger.info("Creating the figure...")
    trace_fig, trace_ax = plt.subplots(figsize=(14, 6)) # type: ignore
    trace_ax.plot(times_trimmed, hr_smooth_nk, label='Instantaneous HR (NeuroKit2)', color='blue', alpha=0.7) # type: ignore
    trace_ax.axhline(mean_hr_nk, color='red', linestyle='--', label=f'Mean HR NeuroKit2: {mean_hr_nk:.1f} BPM') # type: ignore
    trace_ax.axhline(mean_hr_scipy, color='orange', linestyle=':', label=f'SciPy Welch PSD (HP filtered): {mean_hr_scipy:.1f} BPM') # type: ignore
    trace_ax.set_xlabel('Time (seconds)') # type: ignore
    trace_ax.set_ylabel('Heart Rate (BPM)') # type: ignore
    trace_ax.set_title('Heart Rate Estimates Comparison') # type: ignore
    trace_ax.legend() # type: ignore
    trace_ax.grid(True) # type: ignore
    trace_fig.tight_layout()
    plt.close(trace_fig)
    return psd_fig, trace_fig
def _draw_quality_heatmap(data: BaseRaw, frame: DataFrame, cmap: mcolors.Colormap, ch_names: list[str], n_windows: int, vsize: float, suptitle: str, axes_title: str) -> tuple[Figure, Axes]:
    """Draw one channel-by-window quality heatmap with window separators and bad-channel strikethroughs."""
    fig, ax = plt.subplots(figsize=(10, vsize), layout="constrained") # type: ignore
    fig.suptitle(suptitle, fontsize=16, fontweight="bold") # type: ignore
    sns.heatmap( # type: ignore
        data=frame,
        cmap=cmap,
        vmin=0,
        vmax=1,
        cbar_kws=dict(label="Score"),
        ax=ax,
    )
    # Thin dashed separators between consecutive time windows
    for x in range(1, n_windows):
        ax.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") # type: ignore
    ax.set_title(axes_title, fontweight="bold") # type: ignore
    # Strike through rows whose channel is in data.info["bads"]
    markbad(data, ax, ch_names)
    return fig, ax


def plot_timechannel_quality_metrics(data: BaseRaw, scores: NDArray[float64], times: list[tuple[float]], color_stops: tuple[list[float], list[float]], threshold: float, title: Optional[str] = None) -> tuple[Figure, Figure]:
    """
    Generate two heatmaps visualizing channel quality metrics over time.

    Both channels of a source-detector pair receive the same score from the SCI
    and PSP computations, so only one row per pair is plotted.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process. ``info["bads"]`` is used to strike out bad rows.
    scores : NDArray[float64]
        A 2D array of quality scores, one row per channel (in ``data.ch_names`` order)
        and one column per time window.
    times : list[tuple[float]]
        List of time boundaries used to label each score column.
    color_stops : tuple[list[float], list[float]]
        Colormap stop positions for the continuous figure and for the binary figure.
    threshold : float
        Scores strictly above this value are shown as passing in the binary figure.
    title : Optional[str], optional
        Base title for the figures. When None, the figure titles contain only
        the figure description (default is None).

    Returns
    -------
    tuple[Figure, Figure]
        - Figure: Heatmap of all scores across channels and time, with per-channel means.
        - Figure: Binary heatmap showing only scores above the threshold.
    """
    # One row per source-detector pair: the pair members share identical scores.
    # NOTE(review): assumes the two channels of each pair are adjacent in ch_names
    # (the layout MNE uses for fNIRS data) — confirm if other orderings are possible.
    ch_names: list[str] = list(getattr(data, "ch_names"))[::2]
    scores = scores[::2, :]
    # Column headers: window start times rounded to whole seconds
    cols = [np.round(t[0]) for t in times]
    n_windows = len(times)
    vsize = 0.2 * len(ch_names)
    prefix = f"{title} - " if title is not None else ""
    time_index = pd.Index(cols, name="Time (s)")
    channel_index = pd.Index(ch_names, name="Channel")

    # Figure 1: continuous scores on a red -> yellow -> green colormap
    all_scores = DataFrame(data=scores, columns=time_index, index=channel_index)
    score_colors = ['red', 'red', 'yellow', 'green', 'green']
    score_cmap = mcolors.LinearSegmentedColormap.from_list('gyr', list(zip(color_stops[0], score_colors[:len(color_stops[0])])))
    fig1, ax1 = _draw_quality_heatmap(data, all_scores, score_cmap, ch_names, n_windows, vsize, prefix + "All Scores", "All Scores")
    # Annotate each row's mean score to the right of the heatmap, coloured with the same colormap
    norm = mcolors.Normalize(vmin=0, vmax=1)
    text_x = all_scores.shape[1] + 0.5
    for i, val in enumerate(all_scores.mean(axis=1)):
        ax1.text( # type: ignore
            text_x,
            i + 0.5,
            f"{val:.3f}",
            va='center',
            ha='left',
            fontsize=9,
            color=score_cmap(norm(val)),
        )
    ax1.set_xlim(right=text_x + 1.5)
    plt.close(fig1)

    # Figure 2: binary pass/fail map (True where the score exceeds the threshold)
    above_threshold = DataFrame(data=scores > threshold, columns=time_index, index=channel_index)
    binary_colors = ['red', 'red', 'white', 'white']
    binary_cmap = mcolors.LinearSegmentedColormap.from_list('gyr', list(zip(color_stops[1], binary_colors[:len(color_stops[1])])))
    fig2, _ = _draw_quality_heatmap(data, above_threshold, binary_cmap, ch_names, n_windows, vsize, prefix + "Scores Above Threshold", "Scores > Threshold")
    plt.close(fig2)
    return fig1, fig2
def markbad(data: BaseRaw, ax: Axes, ch_names: list[str]) -> None:
    """
    Draw a thick black strikethrough across every plotted row whose channel is bad.

    Parameters
    ----------
    data : BaseRaw
        Data object whose ``info["bads"]`` lists the bad channel names.
    ax : Axes
        Axes holding a one-row-per-channel plot.
    ch_names : list[str]
        Channel names in the same order as the plotted rows.
    """
    # Lazily yield the row index of every channel found in info["bads"]
    rows_to_strike = (
        row for row, name in enumerate(ch_names) if name in data.info["bads"]
    )
    for row in rows_to_strike:
        # Line through the middle of the row (rows span [row, row + 1])
        ax.axhline(row + 0.5, ls="solid", lw=4, color="black", zorder=10) # type: ignore
def calculate_scalp_coupling(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]:
    """
    Score every channel with the windowed scalp coupling index (SCI) and flag poorly coupled ones.

    A channel is flagged when its SCI averaged over all windows is below
    ``SCI_THRESHOLD``. Side effect: ``data.info["bads"]`` is overwritten with the
    flagged channels so that the heatmaps can strike them out.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    l_freq : float, optional
        Low cutoff of the band-pass filter in Hz (default is 0.7).
    h_freq : float, optional
        High cutoff of the band-pass filter in Hz (default is 1.5).

    Returns
    -------
    tuple[list[str], Figure, Figure]
        - list[str]: Names of the channels whose mean SCI is below the threshold.
        - Figure: Heatmap of every SCI score per channel and window.
        - Figure: Binary heatmap of windows whose SCI exceeds the threshold.
    """
    logger.info("Calculating scalp coupling index...")
    # Only the scores and window times are needed from the windowed SCI computation
    sci_result = scalp_coupling_index_windowed_raw(data, time_window=SCI_TIME_WINDOW, l_freq=l_freq, h_freq=h_freq)
    _, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], sci_result)
    logger.info("Identifying channels that do not meet the threshold...")
    failing = scores.mean(axis=1) < SCI_THRESHOLD
    data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), failing))
    logger.info("Creating the figures...")
    gradient_stops = [0.0, SCI_THRESHOLD, SCI_THRESHOLD+0.1, 0.8, 1.0]
    binary_stops = [0.0, SCI_THRESHOLD, SCI_THRESHOLD, 1.0]
    fig1, fig2 = plot_timechannel_quality_metrics(data, scores, times, (gradient_stops, binary_stops), SCI_THRESHOLD, "Scalp Coupling Index")
    logger.info("Successfully calculated scalp coupling index.")
    return list(compress(cast(list[str], getattr(data, "ch_names")), failing)), fig1, fig2
def scalp_coupling_index_windowed_raw(data: BaseRaw, time_window: float = 3.0, l_freq: float = 0.7, h_freq: float = 1.5, l_trans_bandwidth: float = 0.3, h_trans_bandwidth: float = 0.3) -> tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]]:
    """
    Compute windowed scalp coupling index (SCI) across fNIRS channels.

    The fNIRS channels are band-pass filtered, split into consecutive
    non-overlapping windows of ``time_window`` seconds, and in every window the
    Pearson correlation between the two channels of each pair is computed.
    Pairs are formed from adjacent channels after sorting the fNIRS channels by
    name, and both channels of a pair receive the same score.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    time_window : float, optional
        Length of each time window in seconds (default is 3.0).
    l_freq : float, optional
        Low cutoff frequency for filtering in Hz (default is 0.7). Values below
        0.3 Hz are raised to 0.3 Hz.
    h_freq : float, optional
        High cutoff frequency for filtering in Hz (default is 1.5).
    l_trans_bandwidth : float, optional
        Transition bandwidth for the low cutoff in Hz (default is 0.3).
    h_trans_bandwidth : float, optional
        Transition bandwidth for the high cutoff in Hz (default is 0.3).

    Returns
    -------
    tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]]
        - BaseRaw: The original data object (unchanged). Ensures compatibility with peak_power().
        - NDArray[float64]: Correlation scores with one row per fNIRS channel (in
          the channels' original order) and one column per window. A window in
          which either channel is flat or contains NaN scores 0.
        - list[tuple[float, float]]: (start, stop) time of each window in seconds.
    """
    # Pick only fNIRS channels and sort them by channel name so that both members
    # of a source-detector pair become adjacent
    picks: NDArray[np.intp] = mne.pick_types(cast(mne.Info, data.info), fnirs=True) # type: ignore
    picks = picks[np.argsort([getattr(data, "ch_names")[pick] for pick in picks])]
    # FIXME: This may happen if the heart rate calculation tries to set a value way too low
    # The low cutoff is clamped to 0.3 Hz.
    if l_freq < 0.3:
        l_freq = 0.3
    # Band-pass filter the selected fNIRS channels (the full array is returned; only picks are filtered)
    filtered_data = cast(NDArray[float64], filter_data(
        getattr(data, "_data"),
        getattr(data, "info")["sfreq"],
        l_freq,
        h_freq,
        picks=picks,
        verbose=False,
        l_trans_bandwidth=l_trans_bandwidth, # type: ignore
        h_trans_bandwidth=h_trans_bandwidth, # type: ignore
    ))
    # Calculate number of samples per time window, the total number of windows, and prepare output variables.
    # Any trailing samples that do not fill a whole window are ignored.
    window_samples = int(np.ceil(time_window * getattr(data, "info")["sfreq"]))
    n_windows = int(np.floor(len(data) / window_samples))
    scores = np.zeros((len(picks), n_windows))
    times: list[tuple[float, float]] = []
    # Slide through the data in windows to compute scalp coupling index (SCI)
    for window in range(n_windows):
        start_sample = int(window * window_samples)
        end_sample = start_sample + window_samples
        # Clamp so that times[end_sample] stays a valid index; the slices below exclude end_sample
        end_sample = np.min([end_sample, len(data) - 1])
        # Track time boundaries for each window
        t_start = getattr(data, "times")[start_sample]
        t_stop = getattr(data, "times")[end_sample]
        times.append((t_start, t_stop))
        # Iterate through channels in pairs (the two channels of one source-detector pair).
        # This requires them to be sorted by channel name.
        # NOTE(review): an odd number of fNIRS channels would make picks[ii + 1] raise IndexError.
        for ii in range(0, len(picks), 2):
            c1: NDArray[float64] = filtered_data[picks[ii]][start_sample:end_sample]
            c2 = filtered_data[picks[ii + 1]][start_sample:end_sample]
            # Ensure the correlation data is valid (corrcoef is undefined for flat or NaN input)
            if np.std(c1) == 0 or np.std(c2) == 0 or np.any(np.isnan(c1)) or np.any(np.isnan(c2)):
                c = 0
            else:
                c = np.corrcoef(c1, c2)[0][1]
            # Assign the computed correlation to both channels in the pair
            scores[ii, window] = c
            scores[ii + 1, window] = c
    # Rows are currently in name-sorted order; argsort of the sorted picks
    # reorders them back to ascending channel index (the data's channel order)
    scores = scores[np.argsort(picks)]
    return data, scores, times
def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]:
    """
    Score every fNIRS channel with windowed peak spectral power (PSP) and flag weak channels.

    A channel is flagged when its PSP averaged over all windows is below
    ``PSP_THRESHOLD``. Side effect: ``data.info["bads"]`` is overwritten with the
    flagged channels so that the heatmaps can strike them out.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    l_freq : float, optional
        Low cutoff frequency for filtering in Hz (default is 0.7).
    h_freq : float, optional
        High cutoff frequency for filtering in Hz (default is 1.5).

    Returns
    -------
    tuple[list[str], Figure, Figure]
        - list[str]: Names of the channels whose mean PSP is below the threshold.
        - Figure: Heatmap of every PSP score per channel and window.
        - Figure: Binary heatmap of windows whose PSP exceeds the threshold.
    """
    logger.info("Calculating peak spectral power...")
    # Only the scores and window times are needed from peak_power
    psp_result = peak_power(data, time_window=PSP_TIME_WINDOW, threshold=PSP_THRESHOLD, l_freq=l_freq, h_freq=h_freq, verbose=False)
    _, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], psp_result)
    logger.info("Identifying channels that do not meet the threshold...")
    failing = scores.mean(axis=1) < PSP_THRESHOLD
    data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), failing))
    logger.info("Creating the figures...")
    gradient_stops = [0.0, PSP_THRESHOLD, PSP_THRESHOLD+0.1, 0.3, 1.0]
    binary_stops = [0.0, PSP_THRESHOLD, PSP_THRESHOLD, 1.0]
    psp1, psp2 = plot_timechannel_quality_metrics(data, scores, times, (gradient_stops, binary_stops), PSP_THRESHOLD, "Peak Spectral Power")
    logger.info("Successfully calculated peak spectral power.")
    return list(compress(cast(list[str], getattr(data, "ch_names")), failing)), psp1, psp2
def calculate_signal_noise_ratio(data: BaseRaw) -> tuple[list[str], Figure]:
    """
    Calculate the signal-to-noise ratio (SNR) of every channel and flag channels below SNR_THRESHOLD.

    SNR (dB) = 10 * log10(mean power in the 0.01-0.5 Hz band / mean power in the
    1-10 Hz band). When 10 Hz is not below the Nyquist frequency, the noise band
    becomes everything above 1 Hz (high-pass only), because MNE cannot apply a
    low-pass at or above Nyquist. Channels are grouped by the part of their name
    before the last space (the source-detector pair); if any channel of a group
    is below the threshold, the whole group is flagged.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process (not modified; filtering is done on copies).

    Returns
    -------
    tuple[list[str], Figure]
        - list[str]: Sorted names of the channels considered bad.
        - Figure: A matplotlib Figure showing the channels' SNR values.
    """
    logger.info("Calculating signal to noise ratio...")
    # Compute the signal-to-noise ratio values
    logger.info("Computing the signal to noise power...")
    signal_band = (0.01, 0.5)
    noise_band = (1.0, 10.0)
    nyquist = getattr(data, "info")["sfreq"] / 2.0
    # A 10 Hz low-pass is invalid at sampling rates of 20 Hz or less; fall back to a pure high-pass
    noise_h_freq: Optional[float] = noise_band[1] if noise_band[1] < nyquist else None
    data_signal = data.copy().filter(*signal_band, verbose=False) #type: ignore
    data_noise = data.copy().filter(noise_band[0], noise_h_freq, verbose=False) #type: ignore
    signal_power = np.mean(data_signal.get_data()**2, axis=1) #type: ignore
    noise_power = np.mean(data_noise.get_data()**2, axis=1) #type: ignore
    # Calculate the snr using the standard formula for dB (eps avoids division by zero)
    snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps))

    # Group channels by source-detector pair: "S1_D1 760" and "S1_D1 850" share the base "S1_D1",
    # so both members of a pair are removed together
    ch_names: list[str] = list(getattr(data, "ch_names"))
    snr_by_channel = dict(zip(ch_names, snr))
    groups: dict[str, list[str]] = {}
    for ch in ch_names:
        groups.setdefault(ch.rsplit(' ', 1)[0], []).append(ch)
    bad_channels: set[str] = set()
    for members in groups.values():
        if any(snr_by_channel[ch] < SNR_THRESHOLD for ch in members):
            bad_channels.update(members)

    # Design and create the figure
    logger.info("Creating the figure...")
    snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained") # type: ignore
    # Red below the threshold, a short yellow ramp, then green; mapped onto a 0-20 dB scale
    colors = [(0/20, 'red'), (SNR_THRESHOLD/20, 'red'), ((SNR_THRESHOLD+.5)/20, 'yellow'), ((SNR_THRESHOLD+1)/20, 'green'), (20/20, 'green')]
    cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors)
    norm = mcolors.Normalize(vmin=0, vmax=20)
    scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm) # type: ignore
    ax.set(xlabel="Channel Number", ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, 20])
    ax.axhline(SNR_THRESHOLD, color='black', linestyle='--', alpha=0.3, linewidth=1) # type: ignore
    cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)") # type: ignore
    cbar.set_ticks([0, SNR_THRESHOLD, SNR_THRESHOLD+1, 20]) # type: ignore
    cbar.set_ticklabels(['0', str(SNR_THRESHOLD), str(SNR_THRESHOLD+1), '20']) # type: ignore
    plt.close(snr_fig)
    logger.info("Successfully calculated signal to noise ratio.")
    return sorted(bad_channels), snr_fig
def mark_bad_channels(data: BaseRaw, ID: str, bad_channels_sci: set[str], bad_channels_psp: set[str], bad_channels_snr: set[str]) -> tuple[BaseRaw, Figure, int]:
    """
    Drop bad channels from the data and generate a bar plot showing which channels were removed and why.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    ID : str
        File name of the snirf file that was loaded (used in the figure title).
    bad_channels_sci : set[str]
        Channels marked as bad by the SCI method.
    bad_channels_psp : set[str]
        Channels marked as bad by the PSP method.
    bad_channels_snr : set[str]
        Channels marked as bad by the SNR method.

    Returns
    -------
    tuple[BaseRaw, Figure, int]
        - BaseRaw: The data object with the bad channels removed.
        - Figure: A matplotlib Figure showing the dropped channels categorized by method.
        - int: Number of channels that were dropped.
    """
    logger.info("Dropping the channels that were marked bad...")
    # Combine the bad channels reported by every method
    bad_channels = bad_channels_sci | bad_channels_psp | bad_channels_snr
    logger.info(f"Channels that were bad on SCI: {bad_channels_sci}")
    logger.info(f"Channels that were bad on PSP: {bad_channels_psp}")
    logger.info(f"Channels that were bad on SNR: {bad_channels_snr}")
    logger.info(f"Total bad channels: {bad_channels}")
    # Record the channels as bad and drop them. Sorted so info["bads"] does not
    # depend on set iteration order, which can change between runs.
    data.info["bads"] = sorted(bad_channels)
    data = cast(BaseRaw, data.drop_channels(getattr(data, "info")["bads"])) # type: ignore
    # Organize channels into categories
    sets = [
        (bad_channels_sci, "SCI"),
        (bad_channels_psp, "PSP"),
        (bad_channels_snr, "SNR"),
    ]
    # Label every dropped channel with the method(s) that flagged it
    channel_categories: dict[str, str] = {}
    for ch in bad_channels:
        present_in = [name for s, name in sets if ch in s]
        if len(present_in) == 1:
            label = f"{present_in[0]} only"
        else:
            label = " + ".join(sorted(present_in))
        channel_categories[ch] = label
    # Sort channels alphabetically within categories for nicer visualization
    logger.info("Sorting the bad channels by type...")
    categories = sorted(set(channel_categories.values()))
    channel_names: list[str] = []
    category_labels: list[str] = []
    for cat in categories:
        chs_in_cat = sorted(ch for ch, c in channel_categories.items() if c == cat)
        channel_names.extend(chs_in_cat)
        category_labels.extend([cat] * len(chs_in_cat))
    colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories}
    logger.info("Creating the figure...")
    # One full-width bar per dropped channel, coloured by category
    fig, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3))) # type: ignore
    y_pos = range(len(channel_names))
    ax.barh(y_pos, [1]*len(channel_names), color=[colors[cat] for cat in category_labels]) # type: ignore
    ax.set_yticks(y_pos) # type: ignore
    ax.set_yticklabels(channel_names) # type: ignore
    ax.set_xlabel("Marked as Bad") # type: ignore
    ax.set_title(f"Bad Channels by Method for {ID}") # type: ignore
    ax.set_xlim(0, 1)
    ax.set_xticks([]) # type: ignore
    ax.grid(axis='x', linestyle='--', alpha=0.7) # type: ignore
    # Zero-size bars exist only to give each category a legend entry
    for label, color in colors.items():
        ax.bar(0, 0, color=color, label=label) # type: ignore
    if colors:
        # Skip the legend when nothing was dropped (it would have no entries)
        ax.legend() # type: ignore
    fig.tight_layout()
    plt.close(fig)
    logger.info("Successfully dropped the channels that were marked bad.")
    return data, fig, len(bad_channels)
def calculate_optical_density(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]:
    """
    Convert raw intensity data to optical density (OD) and render the OD traces.

    Parameters
    ----------
    data : BaseRaw
        Raw intensity data to convert.
    ID : str
        File name of the snirf file, shown in the figure title.

    Returns
    -------
    tuple[BaseRaw, Figure]
        - BaseRaw: The optical density data.
        - Figure: All channels over the full recording duration.
    """
    logger.info("Calculating optical density...")
    od_data = cast(BaseRaw, optical_density(data))
    logger.info("Creating the figure...")
    # Show every channel and the entire recording in one non-interactive view
    channel_count = len(getattr(data, "ch_names"))
    full_duration = getattr(data, "times")[-1]
    browser = od_data.plot(show=False, n_channels=channel_count, duration=full_duration) # type: ignore
    fig = cast(Figure, browser.figure)
    fig.suptitle(f"Optical density data for {ID}", fontsize=16) # type: ignore
    # Leave room at the top for the suptitle
    fig.subplots_adjust(top=0.92)
    plt.close(fig)
    logger.info("Successfully calculated optical density.")
    return od_data, fig
# STEP 9: Haemoglobin concentration
def calculate_haemoglobin_concentration(optical_density_data: BaseRaw, ID: str, file_path: str) -> tuple[BaseRaw, Figure]:
    """
    Calculate haemoglobin concentration from optical density data using the Beer-Lambert law and plot it.

    Parameters
    ----------
    optical_density_data : BaseRaw
        The data in optical density format.
    ID : str
        File name of the snirf file that was loaded (used in the figure title).
    file_path : str
        Full path of the snirf file that was loaded; used by calculate_dpf to
        derive the partial pathlength factors.

    Returns
    -------
    tuple[BaseRaw, Figure]
        - BaseRaw: The haemoglobin concentration data object.
        - Figure: A matplotlib figure displaying the haemoglobin concentration signals.
    """
    logger.info("Calculating haemoglobin concentration data...")
    # Get the haemoglobin concentration using the Beer-Lambert law with file-specific pathlength factors
    haemoglobin_concentration_data = beer_lambert_law(optical_density_data, ppf=calculate_dpf(file_path))
    logger.info("Creating the figure...")
    # Plot the haemoglobin data itself (not the optical density input) so the figure matches its title
    fig = cast(Figure, haemoglobin_concentration_data.plot(show=False, n_channels=len(getattr(haemoglobin_concentration_data, "ch_names")), duration=getattr(haemoglobin_concentration_data, "times")[-1]).figure) # type: ignore
    fig.suptitle(f"Haemoglobin concentration data for {ID}", fontsize=16) # type: ignore
    # Leave room at the top for the suptitle
    fig.subplots_adjust(top=0.92)
    plt.close(fig)
    logger.info("Successfully calculated haemoglobin concentration data.")
    return haemoglobin_concentration_data, fig
# -------------------------------------- HARDCODED -----------------------------------------------
def _average_reach_epochs(epoch_source: BaseRaw, event_source: BaseRaw) -> dict[str, list[Any] | mne.evoked.EvokedArray]:
    """
    Epoch ``epoch_source`` around the "Reach" events and return HbO/HbR averages.

    Events are read from the annotations of ``event_source``. The returned
    evoked objects have the last four characters of every channel name removed.
    """
    event_id = {"Reach": 1, "Start of Rest": 2}
    events, _ = mne.events_from_annotations(event_source, event_id=dict(event_id), verbose=VERBOSITY) # type: ignore
    epochs = mne.Epochs(
        epoch_source,
        events,
        event_id=event_id,
        tmin=TIME_MIN_THRESH,
        tmax=TIME_MAX_THRESH,
        reject=dict(hbo=EPOCH_REJECT_CRITERIA_THRESH),
        reject_by_annotation=True,
        proj=True,
        baseline=(None, 0),
        preload=True,
        detrend=None,
        verbose=VERBOSITY,
    )
    evoked: dict[str, list[Any] | mne.evoked.EvokedArray] = {
        "Reach/HbO": epochs["Reach"].average(picks="hbo"), # type: ignore
        "Reach/HbR": epochs["Reach"].average(picks="hbr"), # type: ignore
    }
    # Rename channels until the encoding of frequency in ch_name is fixed.
    # Drops the last 4 characters (presumably the " hbo"/" hbr" suffix) so HbO and HbR share channel names.
    for condition in evoked:
        evoked[condition].rename_channels(lambda x: x[:-4]) # type: ignore
    return evoked


def extract_normal_epochs(haemoglobin_concentration_data: BaseRaw) -> dict[str, list[Any] | mne.evoked.EvokedArray]:
    """
    Average the "Reach" epochs of unmodified haemoglobin data.

    Parameters
    ----------
    haemoglobin_concentration_data : BaseRaw
        Haemoglobin concentration data with "Reach" / "Start of Rest" annotations.

    Returns
    -------
    dict[str, list[Any] | mne.evoked.EvokedArray]
        "Reach/HbO" and "Reach/HbR" evoked averages.
    """
    return _average_reach_epochs(haemoglobin_concentration_data, haemoglobin_concentration_data)


def calculate_and_apply_negative_correlation_enhancement(haemoglobin_concentration_data: BaseRaw) -> dict[str, list[Any] | mne.evoked.EvokedArray]:
    """
    Average the "Reach" epochs after negative-correlation enhancement.

    Events are taken from the original (unenhanced) data's annotations.

    Parameters
    ----------
    haemoglobin_concentration_data : BaseRaw
        Haemoglobin concentration data with "Reach" / "Start of Rest" annotations.

    Returns
    -------
    dict[str, list[Any] | mne.evoked.EvokedArray]
        "Reach/HbO" and "Reach/HbR" evoked averages of the enhanced data.
    """
    raw_anti = enhance_negative_correlation(haemoglobin_concentration_data)
    return _average_reach_epochs(raw_anti, haemoglobin_concentration_data)


def calculate_and_apply_short_channel_correction(optical_density_data: BaseRaw, file_path: str) -> dict[str, list[Any] | mne.evoked.EvokedArray]:
    """
    Average the "Reach" epochs after short-channel regression.

    Short-channel regression is applied to the optical density data, which is then
    converted to haemoglobin concentration before epoching.

    Parameters
    ----------
    optical_density_data : BaseRaw
        Optical density data containing short channels.
    file_path : str
        Path of the snirf file, used by calculate_dpf for the partial pathlength factors.

    Returns
    -------
    dict[str, list[Any] | mne.evoked.EvokedArray]
        "Reach/HbO" and "Reach/HbR" evoked averages of the corrected data.
    """
    od_corrected = short_channel_regression(optical_density_data, SHORT_CHANNEL_THRESH)
    haemoglobin_concentration_data = beer_lambert_law(od_corrected, ppf=calculate_dpf(file_path))
    return _average_reach_epochs(haemoglobin_concentration_data, haemoglobin_concentration_data)
def signal_enhancement_techniques_images(evoked_dict: dict[str, list[Any] | mne.evoked.EvokedArray], evoked_dict_anti: dict[str, list[Any] | mne.evoked.EvokedArray], evoked_dict_corr:dict[str, list[Any] | mne.evoked.EvokedArray] | None) -> Figure:
    """
    Plot the averaged HbO/HbR responses side by side for each signal enhancement technique.

    Parameters
    ----------
    evoked_dict : dict[str, list[Any] | mne.evoked.EvokedArray]
        Evoked responses of the original data.
    evoked_dict_anti : dict[str, list[Any] | mne.evoked.EvokedArray]
        Evoked responses after negative-correlation enhancement.
    evoked_dict_corr : dict[str, list[Any] | mne.evoked.EvokedArray] | None
        Evoked responses after short-channel regression, or None when there is no
        short channel (the figure then has two panels instead of three).

    Returns
    -------
    Figure
        One panel per technique, each titled with the technique name.
    """
    panels = [
        ("Original Data", evoked_dict),
        ("With Enhanced Anticorrelation", evoked_dict_anti),
    ]
    if evoked_dict_corr is not None:
        panels.append(("With Short Regression", evoked_dict_corr))
    fig, axes = plt.subplots(nrows=1, ncols=len(panels), figsize=(15, 6)) # type: ignore
    color_dict = dict(HbO="#AA3377", HbR="b")
    # Silence MNE's INFO messages while plotting; the previous level is restored
    # afterwards, even if plotting raises.
    previous_level = mne.set_log_level('WARNING', return_old_level=True) # type: ignore
    try:
        logger.info("Creating the figure...")
        for ax, (_, evokeds) in zip(axes, panels):
            mne.viz.plot_compare_evokeds( # type: ignore
                evokeds,
                combine="mean",
                ci=0.95, # type: ignore
                axes=ax,
                colors=color_dict,
                ylim=dict(hbo=[-10, 15]),
                show=False,
            )
    finally:
        mne.set_log_level(previous_level) # type: ignore
    for ax, (title, _) in zip(axes, panels):
        ax.set_title(title)
    plt.close(fig)
    return fig
def create_design_matrix(data: BaseRaw, stim_duration: float, short_chans: BaseRaw | None) -> tuple[DataFrame, Figure]:
    """
    Build the first-level GLM design matrix (optionally with short-channel regressors) and plot it.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    stim_duration : float
        Stimulus duration in seconds. Only used for canonical HRF models; the FIR
        model always uses a 0.1 s stimulus duration.
    short_chans : BaseRaw | None
        Short-channel data whose signals are appended as regressors when
        SHORT_CHANNEL_REGRESSION is enabled, or None if there are no short channels.

    Returns
    -------
    tuple[DataFrame, Figure]
        - DataFrame: The generated design matrix.
        - Figure: A matplotlib figure visualizing the design matrix.
    """
    logger.info("Creating the design matrix... (This may take some time)")
    # Drift high-pass cutoff shared by both HRF branches
    high_pass_cutoff = 1/(2*DURATION_BETWEEN_ACTIVITIES)
    if HRF_MODEL != "fir":
        # Canonical HRF model
        design_matrix = make_first_level_design_matrix(
            data,
            stim_dur=stim_duration,
            hrf_model=HRF_MODEL,
            drift_model=DRIFT_MODEL,
            high_pass=high_pass_cutoff,
        )
    else:
        # FIR model: one delay regressor per sample over the first 15 seconds
        sfreq = getattr(data, "info")["sfreq"]
        design_matrix = make_first_level_design_matrix(
            data,
            stim_dur=0.1,
            hrf_model=HRF_MODEL,
            drift_model=DRIFT_MODEL,
            high_pass=high_pass_cutoff,
            fir_delays=range(int(sfreq*15)),
        )
    # Append each short channel as a nuisance regressor when requested
    if short_chans is not None and SHORT_CHANNEL_REGRESSION:
        logger.info("Applying short channel regression...")
        for idx, _ in enumerate(short_chans.ch_names): # type: ignore
            design_matrix[f"short_{idx}"] = short_chans.get_data(idx).T # type: ignore
    logger.info("Creating the figure...")
    fig, ax1 = plt.subplots(figsize=(10, 6), constrained_layout=True) # type: ignore
    plot_design_matrix(design_matrix, axes=ax1)
    plt.close(fig)
    logger.info("Successfully created the design matrix.")
    return design_matrix, fig
def run_GLM_analysis(data: BaseRaw, design_matrix: DataFrame) -> RegressionResults:
    """
    Fit a General Linear Model to every channel of ``data``.

    Parameters
    ----------
    data : BaseRaw
        The data whose channels are regressed.
    design_matrix : DataFrame
        Regressors (one column each) aligned with the data's samples.

    Returns
    -------
    RegressionResults
        Fitted GLM results (coefficients and statistics) as returned by run_glm.
    """
    logger.info("Running the GLM...")
    # Parallelism is controlled by the module-level N_JOBS setting
    results = run_glm(data, design_matrix, n_jobs=N_JOBS)
    logger.info("Successfully ran the GLM.")
    return results
def calculate_dpf(file_path: str) -> list[float]:
    """
    Estimate a differential pathlength factor (DPF) for every wavelength in a SNIRF file.

    Uses the general wavelength- and age-dependent equation of Scholkmann & Wolf (2013):
    DPF(lambda, A) = a + b * A**c + d * lambda**3 + e * lambda**2 + f * lambda,
    with lambda in nm and A the participant's age in years.

    Parameters
    ----------
    file_path : str
        Path to the SNIRF file. Wavelengths are read from
        ``/nirs/probe/wavelengths`` and the participant's age is looked up in
        ``METADATA`` using the same path (25 years when there is no entry).

    Returns
    -------
    list[float]
        One DPF per wavelength, ordered from the longest to the shortest wavelength.
    """
    # Local import: h5py is only needed to read the probe wavelengths here
    import h5py
    with h5py.File(file_path, 'r') as f:
        raw_wavelengths = f['/nirs/probe/wavelengths'][:]
    logger.info(f"Wavelengths (nm): {raw_wavelengths}")
    # Longest wavelength first; the original author notes this corresponds to "hbo / hbr" order.
    # NOTE(review): confirm this matches the ppf order expected by beer_lambert_law.
    # Cast to float so the cubic term cannot overflow integer dtypes.
    wavelengths = sorted((float(w) for w in raw_wavelengths), reverse=True)
    participant = METADATA.get(file_path)
    if participant is None:
        age = 25  # Default adult age when no metadata is available for this file
    else:
        age = participant['Age']
    age = float(age)
    logger.info(f"Age used for DPF calculation: {age}")
    # Coefficients of the Scholkmann & Wolf (2013) general DPF equation
    a = 223.3
    b = 0.05624
    c = 0.8493
    d = -5.723e-7
    e = 0.001245
    f = -0.9025
    dpf = [a + b * (age**c) + d * (w**3) + e * (w**2) + f * w for w in wavelengths]
    logger.info(f"DPF per wavelength {wavelengths}: {dpf}")
    return dpf
def individual_GLM_analysis(file_path: str, ID: str, stim_duration: float = 5.0, progress_callback=None) -> tuple[BaseRaw, BaseRaw, DataFrame, DataFrame, DataFrame, DataFrame, dict[str, Figure], str, bool, bool]:
    """
    Performs individual-level General Linear Model (GLM) analysis on fNIRS data from a SNIRF file.

    Parameters
    ----------
    file_path : str
        Path to the SNIRF file containing the participant's raw data.
    ID : str
        Unique identifier for the participant, used for labeling output.
    stim_duration : float, optional
        Duration of the stimulus in seconds for constructing the design matrix (default is 5.0)
    progress_callback : callable, optional
        If given, called with an integer step index (1-12) as processing advances.

    Returns
    -------
    tuple[BaseRaw, BaseRaw, DataFrame, DataFrame, DataFrame, DataFrame, dict[str, Figure], str, bool, bool]
        - BaseRaw: Processed fNIRS data
        - BaseRaw: Full layout raw data prior to bad channel rejection
        - DataFrame: Region of interest statistics
        - DataFrame: Channel-level GLM statistics
        - DataFrame: Contrast results
        - DataFrame: Design matrix used for GLM
        - dict: Figures generated during processing, keyed by step name and encoded
          as PNG bytes by ``convert_fig_dict_to_png_bytes`` (on both return paths)
        - str: Description of processing steps applied
        - bool: Whether the GLM successfully ran to completion
        - bool: Whether the analysis result is valid based on quality checks

    Raises
    ------
    ProcessingError
        Re-raised when a step reports an unrecoverable problem (e.g. SHORT_CHANNEL is
        enabled but no short channel is found). Any other exception raised inside the
        pipeline is logged and returned as a result with both flags set to False.
    """
    # Setting up variables to be used later
    fig_dict: dict[str, Figure] = {}
    bad_channels_sci = []
    bad_channels_psp = []
    bad_channels_snr = []
    # Fallback heart rates (bpm) used for the SCI/PSP band when HEART_RATE is disabled.
    mean_hr_nk = 70
    mean_hr_scipy = 70
    num_bad_channels = 0
    valid = True
    short_chans = None
    roi: DataFrame = DataFrame()
    cha: DataFrame = DataFrame()
    con: DataFrame = DataFrame()
    design_matrix = DataFrame()

    # STEP 1: load the file (sources/detectors, positions, short and long channels).
    data, fig = load_snirf(file_path, ID, FORCE_DROP_CHANNELS)
    fig_dict['Raw'] = fig
    order_of_operations = "Loaded Raw File"
    if progress_callback: progress_callback(1)

    # Initialise the participant's full layout to the current data; it is replaced
    # with a haemoglobin version in STEP 2.5 if that step succeeds.
    raw_full_layout = data
    logger.info(file_path)
    logger.info(ID)
    logger.info(METADATA.get(file_path))

    try:
        # STEP 2: optionally load new channel positions from an optode file.
        if OPTODE_FILE:
            data = calculate_and_apply_updated_optode_coordinates(data)
            order_of_operations += " + Updated Optode Placements"
        if progress_callback: progress_callback(2)

        # STEP 2.5: haemoglobin version of the whole montage, computed before any bad
        # channels are rejected (presumably so the full layout can be plotted later).
        # TODO: only the channel positions should really be needed here, not OD + BLL.
        temp = data.copy()
        temp_od = cast(BaseRaw, optical_density(temp, verbose=VERBOSITY))
        raw_full_layout = beer_lambert_law(temp_od, ppf=calculate_dpf(file_path))

        # STEP 3: optional TDDR motion correction.
        if TDDR:
            data, fig = calculate_and_apply_tddr(data, ID)
            order_of_operations += " + TDDR Filter"
            fig_dict['TDDR'] = fig
        if progress_callback: progress_callback(3)

        # STEP 4: optional wavelet filter.
        if WAVELET:
            data, fig = calculate_and_apply_wavelet(data, ID)
            order_of_operations += " + Wavelet Filter"
            fig_dict['Wavelet'] = fig
        if progress_callback: progress_callback(4)

        # STEP 4.5: extract the short channels when requested; their absence is fatal.
        if SHORT_CHANNEL:
            try:
                short_chans = get_short_channels(data, SHORT_CHANNEL_THRESH)
            except Exception as e:
                raise ProcessingError("SHORT_CHANNEL was specified, but no short channel was found. Please ensure the data has a short channel and that SHORT_CHANNEL_THRESH is set correctly.") from e

        # Ensure that there is no short or really long channels in the data
        data = get_long_channels(data, SHORT_CHANNEL_THRESH, LONG_CHANNEL_THRESH)

        # STEP 5: optional heart-rate estimation (two independent estimators).
        if HEART_RATE:
            sfreq, signal_trimmed, times_trimmed = short_channel_processing_for_hr(data, short_chans)
            hr_smooth_nk, mean_hr_nk = calculate_heart_rate_neurokit(sfreq, signal_trimmed)
            freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy = calculate_heart_rate_scipy(sfreq, signal_trimmed)
            # HACK: neurokit2 looks more trustworthy on the graphs, so a scipy estimate
            # more than 15 bpm away from it is replaced by the neurokit2 value.
            overruled = False
            if mean_hr_scipy < mean_hr_nk - 15:
                mean_hr_scipy = mean_hr_nk
                overruled = True
            if mean_hr_scipy > mean_hr_nk + 15:
                mean_hr_scipy = mean_hr_nk
                overruled = True
            hr1, hr2 = plot_heart_rate(freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy, hr_smooth_nk, mean_hr_nk, times_trimmed, overruled)
            order_of_operations += " + Heart Rate Calculation"
            fig_dict['HeartRate_PSD'] = hr1
            fig_dict['HeartRate_Time'] = hr2
        if progress_callback: progress_callback(5)

        # STEP 6: optional scalp coupling index. The band (Hz) spans both heart-rate
        # estimates widened by HEART_RATE_WINDOW bpm.
        if SCI:
            bad_channels_sci, sci1, sci2 = calculate_scalp_coupling(data.copy(), min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW) / 60, max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW) / 60)
            order_of_operations += " + SCI Calculation"
            fig_dict['SCI1'] = sci1
            fig_dict['SCI2'] = sci2
        if progress_callback: progress_callback(6)

        # Optional peak spectral power over the same heart-rate band.
        if PSP:
            bad_channels_psp, psp1, psp2 = calculate_peak_power(data.copy(), min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW) / 60, max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW) / 60)
            order_of_operations += " + PSP Calculation"
            fig_dict['PSP1'] = psp1
            fig_dict['PSP2'] = psp2
        # NOTE(review): step 7 is reported here and again after channel exclusion.
        if progress_callback: progress_callback(7)

        # Optional signal-to-noise ratio check.
        if SNR:
            bad_channels_snr, fig = calculate_signal_noise_ratio(data.copy())
            order_of_operations += " + SNR Calculation"
            fig_dict['SNR'] = fig

        # STEP 7: optionally drop every channel flagged by SCI, PSP or SNR.
        if EXCLUDE_CHANNELS:
            data, fig, num_bad_channels = mark_bad_channels(data, ID, set(bad_channels_sci), set(bad_channels_psp), set(bad_channels_snr))
            order_of_operations += " + Excluded Bad Channels"
            fig_dict['Bads'] = fig
        if progress_callback: progress_callback(7)

        # STEP 8: optical density.
        data, fig = calculate_optical_density(data, ID)
        order_of_operations += " + Optical Density"
        fig_dict['OpticalDensity'] = fig
        if progress_callback: progress_callback(8)

        # STEP 8.5: short-channel corrected epochs, currently only used for visualization.
        evoked_dict_corr = None
        if SHORT_CHANNEL:
            short_chans_od = cast(BaseRaw, optical_density(short_chans))
            data_recombined = cast(BaseRaw, data.copy().add_channels([short_chans_od]))  # type: ignore
            evoked_dict_corr = calculate_and_apply_short_channel_correction(data_recombined.copy(), file_path)

        # STEP 9: haemoglobin concentration.
        data, fig = calculate_haemoglobin_concentration(data, ID, file_path)
        order_of_operations += " + Haemoglobin Concentration"
        fig_dict['HaemoglobinConcentration'] = fig
        if progress_callback: progress_callback(9)

        # STEP 9.5: comparison of signal-enhancement techniques, visualization only.
        evoked_dict = extract_normal_epochs(data.copy())
        evoked_dict_anti = calculate_and_apply_negative_correlation_enhancement(data.copy())
        fig = signal_enhancement_techniques_images(evoked_dict, evoked_dict_anti, evoked_dict_corr)
        fig_dict['SignalEnhancement'] = fig

        # STEP 10: design matrix.
        # HACK FIXME - Downsampling to 10 is certainly not the best way... right?
        if HRF_MODEL == 'fir':
            data.resample(10, verbose=VERBOSITY)  # type: ignore
            if short_chans is not None:
                short_chans.resample(10, verbose=VERBOSITY)  # type: ignore
        design_matrix, fig = create_design_matrix(data, stim_duration, short_chans)
        order_of_operations += " + Design Matrix"
        fig_dict['DesignMatrix'] = fig
        if progress_callback: progress_callback(10)

        # STEP 11: fit the GLM.
        glm_est: RegressionResults = run_GLM_analysis(data, design_matrix)
        order_of_operations += " + GLM"
        if progress_callback: progress_callback(11)

        # STEP 12: channel, region-of-interest and contrast results.
        logger.info("Performing the finishing touches...")
        order_of_operations += " + Finishing Touches"

        logger.info("Calculating channel results...")
        cha = cast(DataFrame, glm_est.to_dataframe())  # type: ignore

        logger.info("Creating groups...")
        groups: dict[str, Any]
        if HRF_MODEL == "fir":
            groups = dict(AllChannels=range(len(data.ch_names)))  # type: ignore
        else:
            groups = dict(
                group_1_picks=picks_pair_to_idx(data, ROI_GROUP_1, on_missing="ignore"),  # type: ignore
                group_2_picks=picks_pair_to_idx(data, ROI_GROUP_2, on_missing="ignore"),  # type: ignore
            )

        logger.info("Calculating region of intrest results...")
        roi = glm_est.to_dataframe_region_of_interest(groups, design_matrix.columns, demographic_info=True)  # type: ignore

        # One identity row per design-matrix column, so each regressor gets a basic contrast.
        logger.info("Creating the contrast matrix...")
        contrast_matrix = np.eye(design_matrix.shape[1])
        basic_conts = dict(
            [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)]
        )

        if HRF_MODEL == 'fir':
            # Average the contrast vectors of every FIR delay regressor of TARGET_ACTIVITY.
            delay_cols = [col for col in design_matrix.columns if col.startswith(f"{TARGET_ACTIVITY}_delay_")]
            if not delay_cols:
                raise ValueError(f"No FIR regressors found for condition {TARGET_ACTIVITY}.")
            fir_contrast = np.sum([basic_conts[col] for col in delay_cols], axis=0)
            fir_contrast /= len(delay_cols)
            contrast = glm_est.compute_contrast(fir_contrast)  # type: ignore
            con = cast(DataFrame, contrast.to_dataframe())  # type: ignore
        else:
            contrast_t = basic_conts[TARGET_ACTIVITY]
            contrast = glm_est.compute_contrast(contrast_t)  # type: ignore
            con = cast(DataFrame, contrast.to_dataframe())  # type: ignore

        # Add the participant ID to the dataframes
        roi["ID"] = cha["ID"] = con["ID"] = design_matrix["ID"] = ID

        # Convert to uM for nicer plotting below.
        logger.info("Converting to uM...")
        cha["theta"] = cha["theta"].astype(float) * 1.0e6
        roi["theta"] = roi["theta"].astype(float) * 1.0e6
        con["effect"] = con["effect"].astype(float) * 1.0e6

        # Too many bad channels: mark the run invalid and cross out every figure.
        logger.info("Checking amount of bad channels...")
        if num_bad_channels >= MAX_BAD_CHANNELS:
            valid = False
            logger.info("Drawing some big X's...")
            for overlay_fig in fig_dict.values():
                add_x_overlay(overlay_fig, 'Too many bad channels!', 'red')

        logger.info("Completed individual analysis.")
        if progress_callback: progress_callback(12)

        # Clear the output for the next participant unless we are told to be verbose
        if not VERBOSITY:
            clear_output(wait=True)

    # Something really went wrong and we should not continue.
    except ProcessingError as e:
        logger.error("A processing error occurred for %s: %s", ID, e)
        raise
    # Something went wrong at one of the steps: return what was gathered, flagged as
    # unfinished and invalid. Figures are encoded exactly like on the success path.
    except Exception as e:
        logger.exception("An error occurred while processing %s: %s", ID, e)
        fig_dict_bytes = convert_fig_dict_to_png_bytes(fig_dict)
        return data, raw_full_layout, roi, cha, con, design_matrix, fig_dict_bytes, order_of_operations, False, False

    fig_dict_bytes = convert_fig_dict_to_png_bytes(fig_dict)
    return data, raw_full_layout, roi, cha, con, design_matrix, fig_dict_bytes, order_of_operations, True, valid
def add_x_overlay(fig: Figure, reason: str, color: str) -> None:
    """
    Stamp a large red 'X' and a boxed caption across the whole of ``fig``.

    Used to mark figures of a participant that met the bad-channel criteria.

    Parameters
    ----------
    fig : Figure
        Matplotlib figure, annotated in place.
    reason: str
        Caption drawn in a white, red-edged box at the centre of the figure.
    color: str
        Colour of the caption text only; the X strokes and box edge are always red.
    """
    # A transparent, axis-less overlay covering the full canvas, stacked above existing axes.
    overlay = fig.add_axes([0, 0, 1, 1], zorder=100)  # type: ignore
    overlay.set_axis_off()

    # Both diagonals in figure coordinates: bottom-left -> top-right, then top-left -> bottom-right.
    diagonals = (([0, 1], [0, 1]), ([0, 1], [1, 0]))
    for xs, ys in diagonals:
        overlay.plot(xs, ys, color='red', linewidth=8, transform=fig.transFigure, clip_on=False)  # type: ignore

    caption_box = dict(facecolor='white', alpha=0.8, edgecolor='red', boxstyle='round,pad=0.4')
    overlay.text(  # type: ignore
        0.5, 0.5, reason,
        color=color, fontsize=26, fontweight='bold',
        ha='center', va='center',
        transform=fig.transFigure, zorder=101,
        bbox=caption_box,
    )
from io import BytesIO
def convert_fig_dict_to_png_bytes(fig_dict: dict[str, Any]) -> dict[str, bytes]:
    """
    Render every figure in ``fig_dict`` to PNG and return the encoded bytes.

    Presumably used so results coming back from worker processes carry plain
    bytes rather than live Matplotlib figures.

    Parameters
    ----------
    fig_dict : dict[str, Any]
        Mapping of step name to an object with a ``savefig`` method (normally a
        Matplotlib ``Figure``).

    Returns
    -------
    dict[str, bytes]
        The same keys, in the same order, each mapped to its PNG data.
    """
    def _to_png(fig: Any) -> bytes:
        # The context manager releases the buffer as soon as its contents are copied out.
        with BytesIO() as buf:
            fig.savefig(buf, format='png')
            return buf.getvalue()

    return {key: _to_png(fig) for key, fig in fig_dict.items()}
def process_file_worker(args):
    """
    Run ``individual_GLM_analysis`` for a single file, intended as a pool worker.

    Parameters
    ----------
    args : tuple
        ``(file_path, file_name, stim_duration, config, gui, progress_queue)``.
        When ``progress_queue`` is truthy, each step is also posted to it as
        ``('progress', file_name, step_idx)``.

    Returns
    -------
    tuple
        ``(file_name, result, None)`` on success, where ``result`` is the tuple from
        ``individual_GLM_analysis``; ``(file_name, None, exc)`` when any exception
        escaped. The exception is returned, not raised.
    """
    file_path, file_name, stim_duration, config, gui, progress_queue = args

    try:
        # Apply the caller's configuration inside this process before running.
        set_config(config, gui)

        def _report(step_idx):
            print(f"[Worker] Step {step_idx} for {file_name}")
            if progress_queue:
                progress_queue.put(('progress', file_name, step_idx))

        outcome = individual_GLM_analysis(
            file_path,
            file_name,
            stim_duration,
            progress_callback=_report,
        )
    except Exception as exc:
        return file_name, None, exc
    return file_name, outcome, None
def process_folder(folder_path: str, stim_duration: float, files_remaining: dict[str, int], config , gui: bool = False, progress_queue=None) -> tuple[dict[str, dict[str, BaseRaw]], DataFrame, DataFrame, DataFrame, DataFrame, dict[str, list[Figure]], dict[str, str]]:
    """
    Analyse every file in ``folder_path`` in parallel and gather the per-participant results.

    Parameters
    ----------
    folder_path : str
        Folder whose regular files (sub-directories are skipped) are each processed
        with ``process_file_worker``.
    stim_duration : float
        Stimulus duration forwarded to each participant's analysis.
    files_remaining : dict[str, int]
        ``files_remaining['count']`` is decremented once per completed file.
    config
        Configuration applied inside each worker.
    gui : bool, optional
        Forwarded to each worker's configuration call.
    progress_queue : optional
        Queue that receives ``('progress', file_name, step)`` tuples from the workers.

    Returns
    -------
    tuple
        ``(raw_haemo_dict, df_roi, df_cha, df_con, df_design_matrix, figures_by_step, process_dict)``.
        Only participants that finished and passed quality checks contribute data.
        Figures of rejected participants are added only when SEE_BAD_IMAGES is set.
        Figure entries are whatever each worker's figure dict contained.
        Each DataFrame is empty when no participant was valid.

    Raises
    ------
    Exception
        Any error raised while collecting a worker's result (e.g. unpacking it) is
        logged and re-raised. Errors reported by a worker are logged and skipped.
    """
    raw_haemo_dict: dict[str, dict[str, BaseRaw]] = {}
    process_dict: dict[str, str] = {}
    figures_by_step: dict[str, list[Figure]] = {
        step: [] for step in [
            'Raw', 'TDDR', 'Wavelet', 'HeartRate_PSD', 'HeartRate_Time',
            'SCI1', 'SCI2', 'PSP1', 'PSP2', 'SNR', 'Bads',
            'OpticalDensity', 'HaemoglobinConcentration', 'SignalEnhancement', 'DesignMatrix'
        ]
    }
    # Per-participant frames are collected and concatenated once at the end. Repeated
    # pd.concat inside the loop is quadratic and starts from an empty DataFrame.
    roi_frames: list[DataFrame] = []
    cha_frames: list[DataFrame] = []
    con_frames: list[DataFrame] = []
    design_matrix_frames: list[DataFrame] = []

    def _add_figures(fig_dict) -> None:
        # Append this participant's figure for every known step it produced.
        for step, figures in figures_by_step.items():
            if step in fig_dict:
                figures.append(fig_dict[step])

    file_args = [
        (os.path.join(folder_path, file_name), file_name, stim_duration, config, gui, progress_queue)
        for file_name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, file_name))
    ]
    print("[process_folder] File args:", file_args)

    # Rough guard: each worker should have about 1 GB of free memory.
    available_mem = psutil.virtual_memory().available
    if (MAX_WORKERS >= available_mem / (1024 ** 3)):
        print(f"WARNING: You have set MAX_WORKERS to {MAX_WORKERS}. Each worker should have at least 1GB of system memory. Your device currently has a total of {available_mem / (1024 ** 3):.2f}GB free.\nPlease consider lowering MAX_WORKERS to prevent potential crashing due to insufficient system memory.")

    with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_file = {
            executor.submit(process_file_worker, args): args[1] for args in file_args
        }
        with tqdm(total=len(file_args), desc="Processing files") as pbar:
            for future in as_completed(future_to_file):
                file_name = future_to_file[future]
                files_remaining['count'] -= 1
                logger.info(f"Files remaining: {files_remaining['count']}")
                pbar.update(1)
                try:
                    file_name, result, error = future.result()
                    if error is not None:
                        logger.error(f"Error processing {file_name}: {error}")
                        continue

                    (raw_haemo_filtered, raw_haemo_full, roi, channel, contrast,
                     design_matrix, fig_dict, process, finished, valid) = result

                    if finished and valid:
                        logger.info(f"Finished processing {file_name}. This participant was valid.")
                        raw_haemo_dict[file_name] = {
                            "filtered": raw_haemo_filtered,
                            "full_layout": raw_haemo_full
                        }
                        process_dict[file_name] = process
                        _add_figures(fig_dict)
                        roi_frames.append(roi)
                        cha_frames.append(channel)
                        con_frames.append(contrast)
                        design_matrix_frames.append(design_matrix)
                    else:
                        logger.info(f"Finished processing {file_name}. This participant was NOT valid.")
                        if SEE_BAD_IMAGES:
                            _add_figures(fig_dict)
                except Exception as e:
                    logger.error(f"Unexpected error processing {file_name}: {e}")
                    raise

    def _concat(frames: list[DataFrame]) -> DataFrame:
        # pd.concat rejects an empty list; fall back to an empty frame as before.
        return pd.concat(frames, ignore_index=True) if frames else DataFrame()

    return (
        raw_haemo_dict,
        _concat(roi_frames),
        _concat(cha_frames),
        _concat(con_frames),
        _concat(design_matrix_frames),
        figures_by_step,
        process_dict,
    )
def verify_channel_positions(data: BaseRaw) -> None:
    """
    Plot the sensor positions of ``data``, labelled by channel name, for a visual check.

    Parameters
    ----------
    data : BaseRaw
        Recording whose sensors are drawn.
    """
    logger.info("Creating the figure...")
    sensor_plot_options = dict(show_names=True, to_sphere=True, show=False, verbose=VERBOSITY)
    data.plot_sensors(**sensor_plot_options)  # type: ignore
    # plot_sensors was asked not to show the figure itself, so display it explicitly.
    plt.show()  # type: ignore
def plot_3d_evoked_array(
    inst: Union[BaseRaw, EvokedArray, Info],
    statsmodel_df: DataFrame,
    picks: Optional[Union[str, list[str]]] = "hbo",
    value: str = "Coef.",
    background: str = "w",
    figure: Optional[object] = None,
    clim: Union[str, dict[str, Union[str, list[float]]]] = "auto",
    mode: str = "weighted",
    colormap: str = "RdBu_r",
    surface: str = "pial",
    hemi: str = "both",
    size: int = 800,
    view: Optional[Union[str, dict[str, float]]] = None,
    colorbar: bool = True,
    distance: float = 0.03,
    subjects_dir: Optional[str] = None,
    src: Optional[SourceSpaces] = None,
    verbose: bool = False,
) -> Brain:
    """
    Project per-channel statistics onto the fsaverage cortical surface (ported from MNE).

    The ``value`` column of ``statsmodel_df`` is wrapped in a single-sample
    ``EvokedArray`` using the channel info of ``inst``. It is spread onto the
    surface with ``stc_near_sensors`` and rendered with ``SourceEstimate.plot``.

    Parameters
    ----------
    inst : BaseRaw | EvokedArray | Info
        Source of channel metadata. Its channel names must equal
        ``statsmodel_df["ch_name"]`` exactly and in the same order.
    statsmodel_df : DataFrame
        One row per channel with at least ``ch_name`` and the ``value`` column.
    picks : str | list[str] | None
        Channel selection applied before projection (default ``"hbo"``).
    value : str
        Column of ``statsmodel_df`` to plot.
    subjects_dir : str | None
        FreeSurfer subjects directory. Falls back to the ``SUBJECTS_DIR`` environment variable.
    src : SourceSpaces | None
        Source space. Defaults to ``fsaverage/bem/fsaverage-ico-5-src.fif`` under ``subjects_dir``.
    size : int
        Passed through to ``stc.plot``; callers in this module pass a ``(width, height)`` tuple.
    view : str | dict | None
        When given, applied with ``Brain.show_view`` after plotting.
    distance, mode
        Passed to ``stc_near_sensors``.
    surface
        Passed to both ``stc_near_sensors`` and ``stc.plot``.
    background, figure, clim, colormap, hemi, colorbar
        Passed to ``stc.plot``.
    verbose : bool
        Passed to every MNE call.

    Returns
    -------
    Brain
        The brain figure created by ``stc.plot``.

    Raises
    ------
    RuntimeError
        If the channel names of ``inst`` and ``statsmodel_df`` differ.
    KeyError
        If ``subjects_dir`` is None and ``SUBJECTS_DIR`` is not set.
    """
    source_info = inst if isinstance(inst, Info) else inst.info  # type: ignore
    info: Info = cast(Info, deepcopy(source_info))  # type: ignore

    mne_names = getattr(info, "ch_names")
    glm_names = list(statsmodel_df["ch_name"].values)  # type: ignore
    if not (mne_names == glm_names):
        raise RuntimeError(
            'MNE data structure does not match dataframe '
            f'results.\nMNE = {mne_names}.\n'
            f'GLM = {glm_names}'
        )

    # Single time sample: an (n_channels, 1) array of the chosen statistic.
    channel_values = statsmodel_df[value].values
    evoked = EvokedArray(np.tile(channel_values.T, (1, 1)).T, info.copy())  # type: ignore

    # TODO: mimic behaviour of other MNE-NIRS glm plotting options
    if picks is not None:
        evoked = evoked.pick(picks=picks)  # type: ignore

    if subjects_dir is None:
        subjects_dir = os.environ["SUBJECTS_DIR"]

    if src is None:
        src = read_source_spaces(
            os.path.join(subjects_dir, "fsaverage", "bem", "fsaverage-ico-5-src.fif"),
            verbose=verbose,
        )

    evoked_info = getattr(evoked, "info")
    channel_picks = evoked_info["ch_names"]

    # Put every channel in coordinate frame 4 (MNE's head frame), matching the
    # identity head->mri transform passed below.
    for channel in evoked_info["chs"][:len(getattr(evoked, "ch_names"))]:
        channel["coord_frame"] = 4

    stc = stc_near_sensors(  # type: ignore
        picks=channel_picks,
        evoked=evoked,
        subject="fsaverage",
        trans=Transform('head', 'mri', np.eye(4)),
        distance=distance,
        mode=mode,
        surface=surface,
        subjects_dir=subjects_dir,
        src=src,
        project=True,
        verbose=verbose,
    )

    from mne import SourceEstimate
    assert isinstance(stc, SourceEstimate)  # narrows the type for the plot call below

    brain: Brain = stc.plot(  # type: ignore
        src=src,
        subjects_dir=subjects_dir,
        hemi=hemi,
        surface=surface,
        initial_time=0,
        clim=clim,  # type: ignore
        size=size,
        colormap=colormap,
        figure=figure,
        background=background,
        colorbar=colorbar,
        verbose=verbose,
    )

    if view is not None:
        brain.show_view(view)  # type: ignore
    return brain
def _read_optode_label_positions(path: str) -> list[tuple[str, list[float]]]:
    """
    Parse optode label positions from a text file of ``name: x y z`` lines.

    Blank lines and lines without a ':' are skipped; coordinates are parsed as floats.
    """
    positions: list[tuple[str, list[float]]] = []
    with open(path, 'r') as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line or ':' not in line:
                continue  # skip empty/malformed lines
            name, coords = line.split(':', 1)
            positions.append((name.strip(), [float(x) for x in coords.strip().split()]))
    return positions


def brain_3d_visualization(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_haemo: dict[str, dict[str, BaseRaw]], participant_number: int, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None:
    """
    Render HbO channel estimates for TARGET_ACTIVITY on a 3D brain, one plot per group.

    Per-channel estimates come from an OLS fit when a group has a single participant.
    Otherwise they come from a mixed-effects model with participant ID as the grouping factor.

    Parameters
    ----------
    all_results : dict
        Maps group name to ``(df_roi, df_cha, df_con, df_design_matrix)``; only ``df_cha`` is used.
    all_haemo : dict
        Maps participant ID to ``{"full_layout": raw, ...}``. The selected participant's
        full-layout recording supplies the channel positions that are drawn.
    participant_number : int
        Index into each group's unique participant IDs, choosing whose layout is drawn.
    t_or_theta : {'t', 'theta'}
        Statistic to model and display. Also selects ABS_T_VALUE or ABS_THETA_VALUE
        as the colour limit.
    show_optodes : {'sensors', 'labels', 'none', 'all'}
        'sensors' draws MNE sensors, 'labels' draws text labels read from
        OPTODE_FILE_PATH, 'all' draws both, 'none' draws neither.
    show_text : bool
        Whether to overlay a summary of the settings on the brain.

    Raises
    ------
    ValueError
        If ``t_or_theta`` is not ``'t'`` or ``'theta'``.
    """
    # Determine the colour limit for the chosen statistic.
    if t_or_theta == 't':
        limit = ABS_T_VALUE
    elif t_or_theta == 'theta':
        limit = ABS_THETA_VALUE
    else:
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}.")
    clim = dict(kind="value", pos_lims=(0, limit/2, limit))

    # "t ~ -1 + ch_name" or "theta ~ -1 + ch_name": one coefficient per channel, no intercept.
    formula = f"{t_or_theta} ~ -1 + ch_name"

    # Optode labels are read from OPTODE_FILE_PATH on first use and reused for every group.
    optode_positions: Optional[list[tuple[str, list[float]]]] = None

    for index, group_name in enumerate(all_results):
        # Only the channel results are needed here.
        (_, df_cha, _, _) = all_results[group_name]

        for cond in [TARGET_ACTIVITY]:
            if HRF_MODEL == 'fir':
                ch_summary = df_cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python')  # type: ignore
            else:
                ch_summary = df_cha.query("Condition in [@cond] and Chroma == 'hbo'")  # type: ignore

            n_participants = ch_summary["ID"].nunique()

            # Any participant's optode layout is enough to place the channels on the brain.
            # TODO: This should take the average positions of all participants
            participant_to_plot = ch_summary["ID"].unique()[participant_number]  # type: ignore
            participant_raw_full: BaseRaw = all_haemo[participant_to_plot]["full_layout"]

            if n_participants == 1:
                ch_model = smf.ols(formula, ch_summary).fit()  # type: ignore
                logger.info("OLS model is being used as there is only one participant!")
            else:
                ch_model = smf.mixedlm(formula, ch_summary, groups=ch_summary["ID"]).fit(method="nm")  # type: ignore

            model_df = cast(DataFrame, statsmodels_to_results(ch_model, order=ch_summary["ch_name"].unique()))  # type: ignore

            valid_channels = ch_summary["ch_name"].unique().tolist()  # type: ignore
            raw_for_plot = participant_raw_full.copy().pick(picks=valid_channels)  # type: ignore
            brain = plot_3d_evoked_array(raw_for_plot.pick(picks="hbo"), model_df, view="dorsal", distance=BRAIN_DISTANCE, colorbar=True, clim=clim, mode=BRAIN_MODE, size=(800, 700))  # type: ignore

            if show_optodes == 'all' or show_optodes == 'sensors':
                brain.add_sensors(getattr(raw_for_plot, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY)  # type: ignore

            if show_optodes == 'all' or show_optodes == 'labels':
                if optode_positions is None:
                    optode_positions = _read_optode_label_positions(OPTODE_FILE_PATH)
                for name, (x, y, z) in optode_positions:
                    # Names starting with 's' are drawn red and with 'd' blue
                    # (presumably sources and detectors); anything else is gray.
                    label_color = 'red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'
                    brain._renderer.text3d(x, y, z, name, color=label_color, scale=0.002)  # type: ignore

            # Settings summary; the stim duration only applies to the canonical HRF model.
            display_text = ('Folder: ' + str(BASE_SNIRF_FOLDER) + '\nGroup: ' + group_name + '\nCondition: ' + cond
                            + '\nShort Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION))
            if HRF_MODEL != 'fir':
                display_text += '\nStim Duration: ' + str(STIM_DURATION[index])
            display_text += ('\nLooking at: ' + t_or_theta + ' values'
                             + '\nBrain Distance: ' + str(BRAIN_DISTANCE) + '\nBrain Mode: ' + BRAIN_MODE)

            if show_text:
                brain.add_text(0.12, 0.64, display_text, "title", font_size=11, color="k")  # type: ignore
def plot_fir_model_results(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_haemo: dict[str, dict[str, BaseRaw]], participant_number: int, t_or_theta: Literal['t', 'theta'] = 'theta') -> None:
    """
    Plot the FIR-model evoked response to TARGET_ACTIVITY for every group.

    For each group, a mixed-effects model ``<t_or_theta> ~ -1 + delay:TidyCond:Chroma``
    (grouped by participant ID) is fitted to the ROI results of the FIR delay regressors.
    Its coefficients scale the FIR columns of one participant's design matrix. Two
    figures are drawn:
    (1) the unscaled FIR regressors, the HbO-scaled components, and the summed HbO/HbR
        response with the HbO peak between 0 and 15 s marked;
    (2) the summed response with 95% confidence bands.
    Only logs a message when HRF_MODEL is not 'fir'.

    Parameters
    ----------
    all_results : dict
        Maps group name to ``(df_roi, df_cha, df_con, df_design_matrix)``. Only ``df_roi``
        and ``df_design_matrix`` are used, and neither is modified.
    all_haemo : dict
        Maps participant ID to ``{"full_layout": raw, ...}``. The selected participant's
        annotations define time zero.
    participant_number : int
        Index into each group's unique design-matrix IDs, choosing whose design matrix
        and annotations are used for the time axis.
    t_or_theta : {'t', 'theta'}
        Which ROI statistic is modelled.

    Raises
    ------
    ValueError
        If ``t_or_theta`` is not ``'t'`` or ``'theta'``.
    """
    if HRF_MODEL != 'fir':
        logger.info("This method only works when HRF_MODEL is set to 'fir'.")
        return
    if t_or_theta not in ('t', 'theta'):
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}.")

    for group_name in all_results:
        (df_roi, _, _, df_design_matrix) = all_results[group_name]

        # Design matrix of the selected participant. The index is rebuilt in seconds,
        # assuming 0.1 s samples (the FIR pipeline resamples to 10 Hz).
        participant_id = df_design_matrix["ID"].unique()[participant_number]  # type: ignore
        participant_dm = df_design_matrix[df_design_matrix["ID"] == participant_id].copy()
        participant_dm.index = np.round([0.1 * i for i in range(len(participant_dm))], decimals=1)  # type: ignore
        participant = all_haemo[participant_id]["full_layout"]

        # Keep only the TARGET_ACTIVITY FIR delay rows. Filtering into a new frame
        # leaves the caller's df_roi untouched.
        keep = [TARGET_ACTIVITY in n and "delay" in n for n in df_roi["Condition"]]  # type: ignore
        df_fir = df_roi.loc[keep].copy()
        df_fir["TidyCond"] = TARGET_ACTIVITY
        # The FIR delay index is the last "_"-separated token of the condition name.
        df_fir["delay"] = [n.split("_")[-1] for n in df_fir["Condition"]]  # type: ignore

        lme = smf.mixedlm(f"{t_or_theta} ~ -1 + delay:TidyCond:Chroma", df_fir, groups=df_fir["ID"]).fit()  # type: ignore
        df_sum: DataFrame = statsmodels_to_results(lme)  # type: ignore
        df_sum["delay"] = [int(n) for n in df_sum["delay"]]  # type: ignore
        df_sum = df_sum.sort_values("delay")  # type: ignore

        # Design-matrix columns belonging to the condition of interest.
        dm_cond_colnames: list[str] = [c for c in participant_dm.columns if TARGET_ACTIVITY in c]
        dm_cond = participant_dm[dm_cond_colnames]

        # GLM estimates per chromophore, ordered by delay.
        target_sum = df_sum[df_sum["TidyCond"] == TARGET_ACTIVITY]
        df_hbo = target_sum[target_sum["Chroma"] == "hbo"]
        df_hbr = target_sum[target_sum["Chroma"] == "hbr"]
        vals_hbo = [float(v) for v in df_hbo["Coef."]]  # type: ignore
        vals_hbr = [float(v) for v in df_hbr["Coef."]]  # type: ignore
        dm_cond_scaled_hbo = dm_cond * vals_hbo
        dm_cond_scaled_hbr = dm_cond * vals_hbr

        # NOTE(review): time zero is the onset of the participant's *second* annotation
        # (index 1) — confirm this is the intended event.
        index_values = cast(NDArray[float64], dm_cond_scaled_hbo.index.to_numpy(dtype=float) - participant.annotations.onset[1])  # type: ignore

        axes1: list[Axes]
        fig, axes1 = plt.subplots(nrows=1, ncols=3, figsize=(20, 10))  # type: ignore
        axes1[0].plot(index_values, np.asarray(dm_cond))  # type: ignore
        axes1[1].plot(index_values, np.asarray(dm_cond_scaled_hbo))  # type: ignore
        axes1[2].plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r")  # type: ignore
        axes1[2].plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b")  # type: ignore

        # Peak of the summed HbO response within 0-15 s after time zero.
        valid_mask = (index_values >= 0) & (index_values <= 15)
        hbo_sum_window = np.sum(dm_cond_scaled_hbo.loc[valid_mask, :], axis=1)
        peak_idx = np.where(valid_mask)[0][np.argmax(hbo_sum_window)]
        peak_time = float(round(index_values[peak_idx], 2))  # type: ignore
        axes1[2].axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time')  # type: ignore

        for ax in axes1:
            ax.set_xlim(-5, 20)
            ax.set_xlabel("Time (s)")  # type: ignore
        axes1[0].set_ylim(-0.2, 1.2)
        axes1[1].set_ylim(-4, 8)
        axes1[2].set_ylim(-4, 8)
        axes1[0].set_title(f"FIR Model for {group_name} (Unscaled by GLM {TARGET_ACTIVITY} estimates) ({t_or_theta})")  # type: ignore
        axes1[1].set_title(f"FIR Components for {group_name} (Scaled by GLM {TARGET_ACTIVITY} estimates) ({t_or_theta})")  # type: ignore
        axes1[2].set_title(f"Evoked Response for {group_name} ({TARGET_ACTIVITY}) ({t_or_theta})")  # type: ignore
        axes1[0].set_ylabel("FIR Model")  # type: ignore
        axes1[1].set_ylabel("Oxyhaemoglobin (ΔμMol)")  # type: ignore
        axes1[2].set_ylabel("Haemoglobin (ΔμMol)")  # type: ignore
        axes1[2].legend(["Oxyhaemoglobin", "Deoxyhaemoglobin", f"Peak {peak_time}s"])  # type: ignore
        fig.tight_layout()

        # 95% confidence bounds of the estimates, scaled the same way.
        l95_hbo = [float(v) for v in df_hbo["[0.025"]]  # type: ignore
        u95_hbo = [float(v) for v in df_hbo["0.975]"]]  # type: ignore
        l95_hbr = [float(v) for v in df_hbr["[0.025"]]  # type: ignore
        u95_hbr = [float(v) for v in df_hbr["0.975]"]]  # type: ignore
        dm_cond_scaled_hbo_l95 = dm_cond * l95_hbo
        dm_cond_scaled_hbo_u95 = dm_cond * u95_hbo
        dm_cond_scaled_hbr_l95 = dm_cond * l95_hbr
        dm_cond_scaled_hbr_u95 = dm_cond * u95_hbr

        axes2: Axes
        fig, axes2 = plt.subplots(nrows=1, ncols=1, figsize=(7, 7))  # type: ignore
        axes2.plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r")  # type: ignore
        axes2.plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b")  # type: ignore
        axes2.axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time')  # type: ignore
        axes2.fill_between(  # type: ignore
            index_values,
            np.asarray(np.sum(dm_cond_scaled_hbo_l95, axis=1)),
            np.asarray(np.sum(dm_cond_scaled_hbo_u95, axis=1)),
            facecolor="red",
            alpha=0.25,
        )
        axes2.fill_between(  # type: ignore
            index_values,
            np.asarray(np.sum(dm_cond_scaled_hbr_l95, axis=1)),
            np.asarray(np.sum(dm_cond_scaled_hbr_u95, axis=1)),
            facecolor="blue",
            alpha=0.25,
        )
        axes2.set_xlim(-5, 20)
        axes2.set_ylim(-8, 12)
        axes2.set_title(f"Evoked Response with 95% confidence intervals for {group_name} ({TARGET_ACTIVITY}) ({t_or_theta})")  # type: ignore
        axes2.set_ylabel("Haemoglobin (ΔμMol)")  # type: ignore
        axes2.legend(["Oxyhaemoglobin", "Deoxyhaemoglobin", f"Peak {peak_time}s"])  # type: ignore
        axes2.set_xlabel("Time (s)")  # type: ignore
        fig.tight_layout()
        plt.show()  # type: ignore
def plot_2d_theta_graph(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None:
    '''Draw a boxplot of per-channel theta values for TARGET_ACTIVITY, one box per group.

    Inputs:
    all_results (dict) - Maps group name to (df_roi, df_cha, df_con, df_design_matrix).
                         Only df_cha is used, and the caller's DataFrames are not modified.

    With HRF_MODEL == 'fir', every "<TARGET_ACTIVITY>_delay_*" condition is included.
    '''
    # Tag each group's channel results with the group name on a copy, then stack them.
    df_all_cha = pd.concat(
        [df_cha.assign(group=group_name) for group_name, (_, df_cha, _, _) in all_results.items()],
        ignore_index=True,
    )

    if HRF_MODEL == 'fir':
        df_target = df_all_cha[df_all_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_")]  # type: ignore
    else:
        df_target = df_all_cha[df_all_cha["Condition"] == TARGET_ACTIVITY]

    # One colour per group present in the filtered data.
    palette = sns.color_palette("Set2", df_target["group"].nunique())

    plt.figure(figsize=(15, 6))  # type: ignore
    sns.boxplot(
        data=df_target,
        x="ch_name",
        y="theta",
        hue="group",
        palette=palette
    )
    plt.title("Theta Coefficients by Channel and Group")  # type: ignore
    plt.xticks(rotation=90)  # type: ignore
    plt.ylabel("Theta (µM)")  # type: ignore
    plt.xlabel("Channel")  # type: ignore
    plt.legend(title="Group")  # type: ignore
    plt.tight_layout()
    plt.show()  # type: ignore
def plot_individual_theta_averages(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None:
    """
    Plot every participant's HbO ROI theta for TARGET_ACTIVITY and TARGET_CONTROL.

    For each group, one ``sns.catplot`` grid is drawn with a panel per participant ID
    (five per row), coloured by ROI display name. When HRF_MODEL is 'fir', only a
    message is logged.

    Parameters
    ----------
    all_results : dict
        Maps group name to ``(df_roi, df_cha, df_con, df_design_matrix)``; only ``df_roi`` is used.
    """
    if HRF_MODEL == 'fir':
        logger.info("This method does not work when HRF_MODEL is set to 'fir'.")
        return

    for group_name, (df_roi, _, _, _) in all_results.items():
        in_conditions = df_roi.query(f"Condition in ['{TARGET_ACTIVITY}', '{TARGET_CONTROL}']")  # type: ignore
        hbo_rows = in_conditions.query("Chroma in ['hbo']").copy()  # type: ignore

        # Swap the internal ROI keys for their configured display names.
        friendly_roi_names = {
            "group_1_picks": ROI_GROUP_1_NAME,
            "group_2_picks": ROI_GROUP_2_NAME,
        }
        hbo_rows["ROI"] = hbo_rows["ROI"].replace(friendly_roi_names)  # type: ignore

        sns.catplot(
            data=hbo_rows,
            x="Condition",
            y="theta",
            hue="ROI",
            col="ID",
            col_wrap=5,
            height=4,
            palette="muted",
            errorbar=None,
            s=10,
            dodge=False,
        )
        plt.show()  # type: ignore
def plot_group_theta_averages(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None:
    '''This method will create a stripplot showing the theta values for each region of interest for each group.\n
    For each group a mixed-effects model (theta ~ -1 + ROI:Condition:Chroma, grouped by participant ID) is fit to the
    TARGET_ACTIVITY and TARGET_CONTROL rows, and its HbO coefficients are plotted, two groups per figure row.\n
    Nothing is plotted when HRF_MODEL is 'fir' or when all_results is empty.\n
    Inputs:\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n'''
    if HRF_MODEL == 'fir':
        logger.info("This method does not work when HRF_MODEL is set to 'fir'.")
        return
    n = len(all_results)
    if n == 0:
        # plt.subplots() rejects nrows=0, so stop before creating the figure
        logger.info("No groups to plot.")
        return
    # Rename the ROI's to be the friendly name
    roi_label_map = {
        "group_1_picks": ROI_GROUP_1_NAME,
        "group_2_picks": ROI_GROUP_2_NAME,
    }
    # Setup subplot grid: two columns, as many rows as needed (ceiling division)
    ncols = 2
    nrows = (n + ncols - 1) // ncols
    fig, axes = cast(tuple[Figure, np.ndarray[Any, Any]], plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 5 * nrows), squeeze=False))  # type: ignore
    flat_axes = axes.flatten()
    # Iterate over all groups
    for (group_name, (df_roi, _, _, _)), ax in zip(all_results.items(), flat_axes):
        # Boolean mask instead of an f-string query() so quotes in condition names cannot break it
        grp_results = df_roi[df_roi["Condition"].isin([TARGET_ACTIVITY, TARGET_CONTROL])].copy()
        # Run a mixedlm model on the data
        roi_model = smf.mixedlm("theta ~ -1 + ROI:Condition:Chroma", grp_results, groups=grp_results["ID"]).fit(method="nm")  # type: ignore
        df = cast(DataFrame, statsmodels_to_results(roi_model))
        # replace() rather than map(): ROIs without a friendly name keep their label instead of becoming NaN
        df["ROI"] = df["ROI"].replace(roi_label_map)  # type: ignore
        # Create a stripplot:
        sns.stripplot(
            x="Condition",
            y="Coef.",
            hue="ROI",
            data=df.query("Chroma == 'hbo'"),  # type: ignore
            dodge=False,
            jitter=False,
            size=5,
            palette="muted",
            ax=ax,
        )
        # Format the stripplot
        ax.set_title(f"Results for {group_name}")
        ax.legend(title="ROI", loc="upper right")
    # An odd number of groups leaves the last cell empty; remove the unused axes
    for ax in flat_axes[n:]:
        fig.delaxes(ax)
    fig.tight_layout()
    fig.suptitle("Theta Averages Across Groups", fontsize=16, y=1.02)  # type: ignore
    plt.show()  # type: ignore
def compute_p_group_stats(df_cha: DataFrame, bad_pairs: set[tuple[int, int]], t_or_theta: Literal['t', 'theta'] = 't') -> DataFrame:
    '''Run a one-sample t-test (against 0) across participants for every good source-detector pairing.\n
    Only HbO rows of TARGET_ACTIVITY are used. With HRF_MODEL == 'fir' the delay regressors are first averaged per participant.\n
    Inputs:\n
    df_cha (DataFrame) - Channel-level results with Condition, Chroma, Source, Detector, ID, t and theta columns\n
    bad_pairs (set) - (source, detector) pairs that are skipped\n
    t_or_theta (str) - Report the mean t value ('t') or the mean theta value ('theta') per pairing; the test itself always uses t values\n
    Outputs:\n
    result (DataFrame) - One row per tested pairing with Source, Detector, t_or_theta and p_value columns (empty if nothing could be tested)\n'''
    if HRF_MODEL == 'fir':
        # Filter: All delays for the target activity
        df_activity = df_cha[df_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_") & (df_cha['Chroma'] == 'hbo')]  # type: ignore
        # Aggregate across FIR delays *per subject* for each channel
        df_agg = df_activity.groupby(['Source', 'Detector', 'ID'])[['t', 'theta']].mean().reset_index()  # type: ignore
    else:
        # Canonical HRF case
        df_agg = df_cha[(df_cha['Condition'] == TARGET_ACTIVITY) & (df_cha['Chroma'] == 'hbo')].copy()
    # Column whose mean is reported for each pairing
    summary_column = 't' if t_or_theta == 't' else 'theta'
    data: list[dict[str, Any]] = []
    for (src, det), group in df_agg.groupby(['Source', 'Detector']):  # type: ignore
        # If it is a bad channel pairing, do not process it
        if (src, det) in bad_pairs:
            logger.info(f"Skipping bad channel Source {src} - Detector {det}")
            continue
        # Drop any missing values that could exist
        t_values = np.asarray(group['t'].dropna().values, dtype=float)
        # A one-sample t-test needs at least two observations
        if len(t_values) < 2:
            logger.info(f"Skipping Source {src} - Detector {det}: not enough data (n={len(t_values)})")
            continue
        # NOTE: This is still calculated with t values as it is a t-test
        _, pval = ttest_1samp(t_values, popmean=0)
        summary_values = np.asarray(group[summary_column].dropna().values, dtype=float)
        data.append({
            'Source': src,
            'Detector': det,
            't_or_theta': np.mean(summary_values),
            'p_value': pval,
        })
    # Create a DataFrame with the data and report when it is empty
    result = DataFrame(data)
    if result.empty:
        logger.info("No valid channel pairs with enough data for group-level testing.")
    return result
def get_bad_src_det_pairs(raw: BaseRaw) -> set[tuple[int, int]]:
    '''This method figures out the bad source and detector pairings for the 2d t+p graph to prevent them from being plotted.\n
    Each name in info["bads"] is cut at its first whitespace and split on SOURCE_DETECTOR_SEPARATOR. The first character
    of each half is dropped and the rest parsed as an int. Names that cannot be parsed this way are logged and skipped.\n
    Inputs:\n
    raw (RawSNIRF) - Contains all the snirf data for the last participant processed. Only used to get the channels\n
    Outputs:\n
    bad_pairs (set) - Set containing all of the bad pairs of sources and detectors\n'''
    bad_pairs: set[tuple[int, int]] = set()
    for ch_name in getattr(raw, "info")["bads"]:
        try:
            # Only the part before the first whitespace identifies the optode pair
            src_str, det_str = ch_name.split()[0].split(SOURCE_DETECTOR_SEPARATOR)
            # Drop the one-character prefix in front of each number
            pair = (int(src_str[1:]), int(det_str[1:]))
        except (ValueError, IndexError) as e:
            # Only parsing failures are expected here; other errors (e.g. a missing global) should surface
            logger.info(f"Could not parse bad channel '{ch_name}': {e}")
            continue
        bad_pairs.add(pair)
    return bad_pairs
def plot_avg_significant_activity(raw: BaseRaw, all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], t_or_theta: Literal['t', 'theta'] = 't') -> None:
    '''This method plots the average t (or theta) values for the groups on a 2D graph. p values less than or equal to P_THRESHOLD are solid lines, while greater p values are dashed lines.\n
    One figure is drawn per group.\n
    Inputs:\n
    raw (RawSNIRF) - Contains all the snirf data for the last participant processed. Only used to get the channel locations and bad channels.\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n
    t_or_theta (str) - Colour each pairing by its mean t value ('t') or mean theta value ('theta')\n
    Raises:\n
    ValueError - If t_or_theta is not 't' or 'theta'\n'''
    # Validate up front; an invalid value used to surface only as an unbound 'norm' after all the work was done
    if t_or_theta == 't':
        limit = ABS_SIGNIFICANCE_T_VALUE
    elif t_or_theta == 'theta':
        limit = ABS_SIGNIFICANCE_THETA_VALUE
    else:
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}")
    # Ensure that the colors stay within the boundaries even if they are over or under the max/min values
    norm = mcolors.Normalize(vmin=-limit, vmax=limit)
    cmap: mcolors.Colormap = plt.get_cmap('seismic')
    # Bad pairs and optode positions come from raw only, so compute them once rather than per group
    bad_pairs = get_bad_src_det_pairs(raw)
    src_pos: dict[int, tuple[float, float]] = {}
    det_pos: dict[int, tuple[float, float]] = {}
    for ch in getattr(raw, "info")["chs"]:
        ch_name = ch['ch_name']
        if not ch_name or not ch['loc'].any():
            continue
        try:
            src_str, det_str = ch_name.split()[0].split(SOURCE_DETECTOR_SEPARATOR)
            src_num, det_num = int(src_str[1:]), int(det_str[1:])
        except (ValueError, IndexError):
            # Same tolerance as get_bad_src_det_pairs: skip names that do not follow the source/detector pattern
            logger.info(f"Could not parse channel '{ch_name}', it will not be drawn")
            continue
        # x/y of the source (loc[3:5]) and detector (loc[6:8]); z is ignored for the 2D layout
        src_pos[src_num] = ch['loc'][3:5]
        det_pos[det_num] = ch['loc'][6:8]
    all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()]
    all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()]
    # Iterate over all the groups
    for group_name, (_, df_cha, _, _) in all_results.items():
        is_hbo = df_cha['Chroma'] == 'hbo'
        if HRF_MODEL == 'fir':
            is_target = df_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_")  # type: ignore
        else:
            is_target = df_cha['Condition'] == TARGET_ACTIVITY
        num_tests = df_cha[is_target & is_hbo].groupby(['Source', 'Detector']).ngroups  # type: ignore
        logger.info(f"Number of tests: {num_tests}")
        # Compute average t-value across individuals for each channel pairing
        avg_df = compute_p_group_stats(df_cha, bad_pairs, t_or_theta)
        logger.info(f"Average {t_or_theta}-values and p-values for {TARGET_ACTIVITY}:")
        for _, row in avg_df.iterrows():  # type: ignore
            logger.info(f"Source {row['Source']} <-> Detector {row['Detector']}: "
                        f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}")
        # Set up the plot
        fig, ax = plt.subplots(figsize=(8, 6))  # type: ignore
        # Plot the sources (circles) and detectors (squares)
        for pos in src_pos.values():
            ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3)  # type: ignore
        for pos in det_pos.values():
            ax.scatter(pos[0], pos[1], s=120, c='k', marker='s', edgecolors='white', linewidths=1, zorder=3)  # type: ignore
        # Plot connections coloured by their value; solid when significant
        for row in avg_df.itertuples():
            src: int = cast(int, row.Source)  # type: ignore
            det: int = cast(int, row.Detector)  # type: ignore
            tval: float = cast(float, row.t_or_theta)  # type: ignore
            pval: float = cast(float, row.p_value)  # type: ignore
            if src in src_pos and det in det_pos:
                x = [src_pos[src][0], det_pos[det][0]]
                y = [src_pos[src][1], det_pos[det][1]]
                style = '-' if pval <= P_THRESHOLD else '--'
                ax.plot(x, y, linestyle=style, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2)  # type: ignore
        # Format the Colorbar
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, shrink=0.85)  # type: ignore
        cbar.set_label(f'Average {TARGET_ACTIVITY} {t_or_theta} value (hbo)', fontsize=11)  # type: ignore
        # Formatting the subplots
        ax.set_aspect('equal')
        ax.set_title(f"Average {t_or_theta} values for {TARGET_ACTIVITY} (HbO) for {group_name}", fontsize=14)  # type: ignore
        ax.set_xlabel('X position (m)', fontsize=11)  # type: ignore
        ax.set_ylabel('Y position (m)', fontsize=11)  # type: ignore
        ax.grid(True, alpha=0.3)  # type: ignore
        # Set axis limits to be 1cm more than the optode positions (min()/max() fail when nothing was placed)
        if all_x:
            ax.set_xlim(min(all_x) - 0.01, max(all_x) + 0.01)
            ax.set_ylim(min(all_y) - 0.01, max(all_y) + 0.01)
        fig.tight_layout()
        plt.show()  # type: ignore
def generate_montage_locations():
    """Get standard MNI montage locations in dataframe.

    Data is returned in the same format as the eeg_positions library.

    Returns
    -------
    coords : DataFrame
        One row per channel of the ``standard_1005`` montage, with columns
        ``x``, ``y``, ``z`` (from ``montage.get_positions()["ch_pos"]``) and
        ``label`` (the channel name), re-indexed 0..n-1.
    """
    # standard_1020 and standard_1005 are in MNI (fsaverage) space already,
    # but we need to undo the scaling that head_scale will do
    montage = mne.channels.make_standard_montage(
        "standard_1005", head_size=0.09700884729534559
    )
    # Relabel the coordinate frame of every dig point.
    # NOTE(review): 2003 appears to be FIFF.FIFFV_MNE_COORD_MNI_TAL in MNE's constants; confirm.
    for d in montage.dig:
        d["coord_frame"] = 2003
    # Drop the first three dig points in place (presumably the fiducials that
    # add_mni_fiducials() re-adds below; TODO confirm)
    montage.dig[:] = montage.dig[3:]
    montage.add_mni_fiducials()  # now in fsaverage space
    # ch_pos maps name -> xyz; from_dict makes the names columns, so transpose to one row per name
    coords = pd.DataFrame.from_dict(montage.get_positions()["ch_pos"]).T
    coords["label"] = coords.index
    coords = coords.rename(columns={0: "x", 1: "y", 2: "z"})
    return coords.reset_index(drop=True)
def _find_closest_standard_location(position, reference, *, out="label"):
"""Return closest montage label to coordinates.
Parameters
----------
position : array, shape (3,)
Coordinates.
reference : dataframe
As generated by _generate_montage_locations.
trans_pos : str
Apply a transformation to positions to specified frame.
Use None for no transformation.
"""
from scipy.spatial.distance import cdist
p0 = np.array(position)
p0.shape = (-1, 3)
# head_mri_t, _ = _get_trans("fsaverage", "head", "mri")
# p0 = apply_trans(head_mri_t, p0)
dists = cdist(p0, np.asarray(reference[["x", "y", "z"]], float))
if out == "label":
min_idx = np.argmin(dists)
return reference["label"][min_idx]
else:
assert out == "dists"
return dists
def _source_detector_fold_table(raw, cidx, reference, fold_tbl, interpolate):
    """Return the fOLD table rows for channel ``cidx`` of ``raw``.

    The channel's source and detector positions are each snapped to the
    nearest label in ``reference``. The matching (Source, Detector) rows of
    ``fold_tbl`` are looked up, and the reversed pair is tried if that finds
    nothing. If there is still no match and ``interpolate`` is true, the
    table pair with the smallest combined distance to the two positions
    (in either orientation) is used instead.

    Parameters
    ----------
    raw : BaseRaw
        Object whose ``info["chs"][cidx]["loc"]`` holds the optode positions.
    cidx : int
        Channel index into ``raw.info["chs"]``.
    reference : DataFrame
        Standard locations as returned by ``generate_montage_locations``.
    fold_tbl : DataFrame
        fOLD table with at least ``Source`` and ``Detector`` label columns.
    interpolate : bool
        Whether to fall back to the nearest available table pair.

    Returns
    -------
    tbl : DataFrame
        A copy of the matching rows with extra columns:

        - ``BestSource`` / ``BestDetector``: the nearest labels.
        - ``BestMatchDistance``: norm of the two nearest-label distances.
        - ``MatchDistance``: equal to ``BestMatchDistance`` unless
          interpolation was used.

        Empty when nothing matched and ``interpolate`` is false.

    Raises
    ------
    AssertionError
        When interpolating and no table row has both labels in ``reference``.
    """
    # loc[3:6] / loc[6:9] are used as the source / detector xyz positions
    src = raw.info["chs"][cidx]["loc"][3:6]
    det = raw.info["chs"][cidx]["loc"][6:9]
    ref_lab = list(reference["label"])
    # dists has one row per position (row 0 = source, row 1 = detector) and one column per reference label
    dists = _find_closest_standard_location([src, det], reference, out="dists")
    src_min, det_min = np.argmin(dists, axis=1)
    src_name, det_name = ref_lab[src_min], ref_lab[det_min]
    tbl = fold_tbl.query("Source == @src_name and Detector == @det_name")
    dist = np.linalg.norm(dists[[0, 1], [src_min, det_min]])
    # Try reversing source and detector
    if len(tbl) == 0:
        tbl = fold_tbl.query("Source == @det_name and Detector == @src_name")
    if len(tbl) == 0 and interpolate:
        # Try something hopefully not too terrible: pick the one with the
        # smallest net distance
        # Only table rows whose Source and Detector labels both exist in the reference can be scored
        good = np.isin(fold_tbl["Source"], reference["label"]) & np.isin(
            fold_tbl["Detector"], reference["label"]
        )
        assert good.any()
        tbl = fold_tbl[good]
        assert len(tbl)
        src_idx = [ref_lab.index(src) for src in tbl["Source"]]
        det_idx = [ref_lab.index(det) for det in tbl["Detector"]]
        # Original orientation: channel source vs table Source, channel detector vs table Detector
        tot_dist = np.linalg.norm([dists[0, src_idx], dists[1, det_idx]], axis=0)
        assert tot_dist.shape == (len(tbl),)
        idx = np.argmin(tot_dist)
        dist_1 = tot_dist[idx]
        src_1, det_1 = ref_lab[src_idx[idx]], ref_lab[det_idx[idx]]
        # And the reverse: channel source vs table Detector, channel detector vs table Source
        tot_dist = np.linalg.norm([dists[0, det_idx], dists[1, src_idx]], axis=0)
        idx = np.argmin(tot_dist)
        dist_2 = tot_dist[idx]
        src_2, det_2 = ref_lab[det_idx[idx]], ref_lab[src_idx[idx]]
        if dist_1 < dist_2:
            new_dist, src_use, det_use = dist_1, src_1, det_1
        else:
            # src_2/det_2 hold the table's Detector/Source labels, so swap back to query in table order
            new_dist, src_use, det_use = dist_2, det_2, src_2
        tbl = fold_tbl.query("Source == @src_use and Detector == @det_use")
        tbl = tbl.copy()
        tbl["BestSource"] = src_name
        tbl["BestDetector"] = det_name
        tbl["BestMatchDistance"] = dist
        tbl["MatchDistance"] = new_dist
        assert len(tbl)
    else:
        tbl = tbl.copy()
        tbl["BestSource"] = src_name
        tbl["BestDetector"] = det_name
        tbl["BestMatchDistance"] = dist
        tbl["MatchDistance"] = dist
    tbl = tbl.copy()  # don't get warnings about setting values later
    return tbl
from mne.utils import _check_fname, _validate_type, warn
def _read_fold_xls(fname, atlas="Juelich"):
    """Read fOLD toolbox xls file.

    The values are then manipulated in to a tidy dataframe.
    Note the xls files are not included as no license is provided.

    Parameters
    ----------
    fname : str
        Path to xls file.
    atlas : str
        Requested atlas: one of "AAL2", "AICHA", "Brodmann", "Juelich" or
        "Loni" (each maps to a fixed sheet index).

    Returns
    -------
    tbl : DataFrame
        The sheet without its spacer rows (rows lacking a Specificity value).
        Blank cells are filled from the row above, and ``Specificity`` and
        ``brainSens`` are multiplied by 100.
    """
    page_reference = {"AAL2": 2, "AICHA": 5, "Brodmann": 8, "Juelich": 11, "Loni": 14}
    tbl = pd.read_excel(fname, sheet_name=page_reference[atlas])
    # Remove the spacing between rows
    tbl = tbl[tbl["Specificity"].notna()].reset_index(drop=True)
    # Empty values in the table mean its the same as above. ffill does this in a single
    # vectorised pass instead of O(rows * columns) scalar pandas accesses.
    tbl = tbl.ffill()
    tbl["Specificity"] = tbl["Specificity"] * 100
    tbl["brainSens"] = tbl["brainSens"] * 100
    return tbl
import os.path as op
def _check_load_fold(fold_files, atlas):
    """Load and stack the fOLD tables for ``atlas``.

    Parameters
    ----------
    fold_files : path-like | list | None
        A directory containing ``10-5.xls`` and ``10-10.xls``, an explicit
        list of xls paths, or None to use the ``MNE_NIRS_FOLD_PATH`` MNE
        config value.
    atlas : str
        Atlas name forwarded to ``_read_fold_xls``.

    Returns
    -------
    fold_tbl : DataFrame
        All tables concatenated with a fresh index (an empty DataFrame when
        ``fold_files`` is an empty list).

    Raises
    ------
    ValueError
        If ``fold_files`` is None and ``MNE_NIRS_FOLD_PATH`` is not set.
    """
    if fold_files is None:
        fold_files = mne.get_config("MNE_NIRS_FOLD_PATH")
    if fold_files is None:
        raise ValueError(
            "MNE_NIRS_FOLD_PATH not set, either set it using "
            "mne.set_config or pass fold_files as str or list"
        )
    if not isinstance(fold_files, list):  # path-like: a directory holding both fOLD sheets
        fold_files = _check_fname(
            fold_files,
            overwrite="read",
            must_exist=True,
            name="fold_files",
            need_dir=True,
        )
        fold_files = [op.join(fold_files, f"10-{x}.xls") for x in (5, 10)]
    tables = []
    for fi, fname in enumerate(fold_files):
        fname = _check_fname(
            fname, overwrite="read", must_exist=True, name=f"fold_files[{fi}]"
        )
        tables.append(_read_fold_xls(fname, atlas=atlas))
    if not tables:
        return pd.DataFrame()
    # Concatenate once instead of re-copying a growing frame on every iteration
    return pd.concat(tables, ignore_index=True)
def fold_channel_specificity_normal(raw, fold_files=None, atlas="Juelich", interpolate=False):
    """Return the landmarks and specificity each channel of ``raw`` is sensitive to.

    Parameters
    ----------
    raw : BaseRaw
        Data whose channel locations are matched against the fOLD tables.
    fold_files : path-like | list | None
        Forwarded to ``_check_load_fold``.
    atlas : str
        fOLD atlas to read (default "Juelich").
    interpolate : bool
        Forwarded to ``_source_detector_fold_table``; fall back to the nearest
        table pair when there is no exact match.

    Returns
    -------
    chan_spec : list of DataFrame
        One table per channel, in ``raw.ch_names`` order, each re-indexed 0..n-1.
    """  # noqa: E501
    _validate_type(raw, BaseRaw, "raw")
    reference = generate_montage_locations()
    fold_table = _check_load_fold(fold_files, atlas)
    return [
        _source_detector_fold_table(
            raw, channel_index, reference, fold_table, interpolate
        ).reset_index(drop=True)
        for channel_index, _ in enumerate(raw.ch_names)
    ]
def _fold_landmark_colors() -> tuple[list[str], dict[str, Any]]:
    '''Return the known fOLD Brodmann landmarks (numeric order, "Brain_Outside" last) and a fixed colour for each one.\n'''
    landmarks = [
        "1 - Primary Somatosensory Cortex",
        "2 - Primary Somatosensory Cortex",
        "3 - Primary Somatosensory Cortex",
        "4 - Primary Motor Cortex",
        "5 - Somatosensory Association Cortex",
        "6 - Pre-Motor and Supplementary Motor Cortex",
        "7 - Somatosensory Association Cortex",
        "8 - Includes Frontal eye fields",
        "9 - Dorsolateral prefrontal cortex",
        "10 - Frontopolar area",
        "11 - Orbitofrontal area",
        "17 - Primary Visual Cortex (V1)",
        "18 - Visual Association Cortex (V2)",
        "19 - V3",
        "20 - Inferior Temporal gyrus",
        "21 - Middle Temporal gyrus",
        "22 - Superior Temporal Gyrus",
        "23 - Ventral Posterior cingulate cortex",
        "24 - Ventral Anterior cingulate cortex",
        "25 - Subgenual cortex",
        "32 - Dorsal anterior cingulate cortex",
        "37 - Fusiform gyrus",
        "38 - Temporopolar area",
        "39 - Angular gyrus, part of Wernicke's area",
        "40 - Supramarginal gyrus part of Wernicke's area",
        "41 - Primary and Auditory Association Cortex",
        "42 - Primary and Auditory Association Cortex",
        "43 - Subcentral area",
        "44 - pars opercularis, part of Broca's area",
        "45 - pars triangularis Broca's area",
        "46 - Dorsolateral prefrontal cortex",
        "47 - Inferior prefrontal gyrus",
        "48 - Retrosubicular area",
        "Brain_Outside",
    ]
    # 40 colours: tab20 followed by tab20b. plt.get_cmap replaces plt.cm.get_cmap, which Matplotlib 3.9 removed.
    colors = [plt.get_cmap('tab20')(i) for i in range(20)] + [plt.get_cmap('tab20b')(i) for i in range(20)]
    # Numeric landmarks in numeric order; non-numeric ones (Brain_Outside) sort last
    landmarks.sort(key=lambda x: (int(x.split(" - ")[0]) if x.split(" - ")[0].isdigit() else float('inf')))
    return landmarks, {landmark: colors[i % len(colors)] for i, landmark in enumerate(landmarks)}


def fold_channels(raw: BaseRaw, all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], fold_path: str) -> None:
    '''Log the fOLD Brodmann landmark specificity of every HbO channel in raw. When GUI is true, also show it as one pie chart per channel plus a legend figure.\n
    Inputs:\n
    raw (BaseRaw) - Data whose HbO channels are looked up in the fOLD tables\n
    all_results (dict) - Only its keys (group names) are used; the same channels are reported under each group name\n
    fold_path (str) - Directory holding the fOLD xls files; stored in the MNE config as MNE_NIRS_FOLD_PATH\n'''
    # Locate the fOLD excel files
    mne.set_config('MNE_NIRS_FOLD_PATH', fold_path)  # type: ignore
    if not all_results:
        return
    hbo_channel_names = cast(list[str], getattr(raw.copy().pick(picks='hbo'), "ch_names"))  # type: ignore
    # The lookup depends only on raw (and rereads the xls files each call), so run it once per channel
    # instead of once per channel for every group
    fold_results: dict[str, list[DataFrame]] = {
        channel_name: cast(list[DataFrame], fold_channel_specificity_normal(raw.copy().pick(picks=channel_name), interpolate=True, atlas='Brodmann'))  # type: ignore
        for channel_name in hbo_channel_names
    }
    landmarks, landmark_color_map = _fold_landmark_colors()
    for group_name in all_results:
        # Format the output to make it slightly easier to read
        logger.info("*" * 60)
        logger.info(f'Landmark Specificity for {group_name}:')
        logger.info("*" * 60)
        if GUI:
            # 4 rows x 7 columns by default, with extra rows when there are more channels than cells.
            # Sizing before creating the figure avoids leaving an empty 4x7 figure open.
            cols = 7
            rows = max(4, (len(hbo_channel_names) + cols - 1) // cols)
            _, axes = plt.subplots(rows, cols, figsize=(16, 10), constrained_layout=True)
            axes = axes.flatten()
        for idx, channel_name in enumerate(hbo_channel_names):
            useful_data = None
            for df_data in fold_results[channel_name]:
                useful_data = df_data[['Landmark', 'Specificity']]
                logger.info(f"Channel: {channel_name}")
                logger.info(f"{useful_data}")
                logger.info("-" * 60)
            if GUI and useful_data is not None:
                # Landmarks missing from the known list are drawn in grey instead of raising KeyError
                color_list = [landmark_color_map.get(landmark, 'lightgray') for landmark in useful_data['Landmark']]
                labels = [f'{landmark.split(" - ")[0]}' if landmark != 'Brain_Outside' else 'B' for landmark in useful_data['Landmark']]
                ax = axes[idx]
                ax.pie(
                    useful_data['Specificity'],
                    autopct='%1.1f%%',
                    startangle=90,
                    labels=labels,
                    labeldistance=1.05,  # keep labels clear of the wedges
                    colors=color_list)  # consistent colour per landmark across charts
                ax.set_title(f'{channel_name}')
                ax.axis('equal')  # Equal aspect ratio ensures the pie chart is circular.
        if GUI:
            # Separate legend figure listing every known landmark colour
            handles = [
                Line2D([0], [0], marker='o', color='w', label=landmark, markersize=10,
                       markerfacecolor=landmark_color_map[landmark])
                for landmark in landmarks
            ]
            legend_fig = plt.figure(figsize=(5, len(landmarks) / 4))
            legend_axes = legend_fig.add_subplot(111)
            legend_axes.axis('off')
            legend_axes.legend(handles=handles, loc='center', fontsize=10, title="Landmarks")
            # Hide the grid cells that have no channel
            for ax in axes[len(hbo_channel_names):]:
                ax.axis('off')
        plt.show()
def brain_landmarks_3d(raw_haemo: BaseRaw, show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all') -> None:
    '''Show the fsaverage brain with optodes and a fixed set of right-hemisphere Brodmann areas coloured in.\n
    Inputs:\n
    raw_haemo (BaseRaw) - Data whose info supplies the sensors and the logged channel locations\n
    show_optodes (str) - 'sensors' draws the MNE sensors, 'labels' draws the text labels read from OPTODE_FILE_PATH, 'all' draws both, 'none' draws neither\n'''
    brain = Brain("fsaverage", background="white", size=(800, 700))  # type: ignore
    # Add the sensors (channels, pairs, sources and detectors)
    if show_optodes in ('all', 'sensors'):
        brain.add_sensors(getattr(raw_haemo, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY)  # type: ignore
    # Add optode text labels manually from the "name: x y z" lines of OPTODE_FILE_PATH
    if show_optodes in ('all', 'labels'):
        positions: list[tuple[str, list[float]]] = []
        with open(OPTODE_FILE_PATH, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or ':' not in line:
                    continue  # skip empty/malformed lines
                name, coords = line.split(':', 1)
                positions.append((name.strip(), [float(x) for x in coords.strip().split()]))
        for name, (x, y, z) in positions:
            brain._renderer.text3d(x, y, z, name, color=('red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'), scale=0.002)  # type: ignore
    for ch in getattr(raw_haemo, "info")['chs']:
        # Lazy %-formatting: passing the location as a bare second argument made logging try
        # msg % args on a message with no placeholders and report a formatting error instead
        logger.info("%s %s", ch['ch_name'], ch['loc'][:3])
    # Add Brodmann labels
    labels = cast(list[mne.Label], mne.read_labels_from_annot("fsaverage", "PALS_B12_Brodmann", "rh", verbose=VERBOSITY))  # type: ignore
    label_colors = {
        "Brodmann.39-rh": "blue",
        "Brodmann.40-rh": "green",
        "Brodmann.6-rh": "pink",
        "Brodmann.7-rh": "orange",
        "Brodmann.17-rh": "red",
        "Brodmann.1-rh": "yellow",
        "Brodmann.2-rh": "yellow",
        "Brodmann.3-rh": "yellow",
        "Brodmann.18-rh": "red",
        "Brodmann.19-rh": "red",
        "Brodmann.4-rh": "purple",
        "Brodmann.8-rh": "white"
    }
    for label in labels:
        name = getattr(label, "name", None)
        if isinstance(name, str) and name in label_colors:
            brain.add_label(label, borders=False, color=label_colors[name])  # type: ignore
def data_to_csv(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]):
    '''Save each group's HbO channel results for TARGET_ACTIVITY to "<group>_<timestamp>.csv" in a csvs folder.\n
    With HRF_MODEL == 'fir' the per-delay rows are averaged into one row per channel and participant.
    'Significant' becomes True when more than half of the delays were significant.\n
    Inputs:\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n'''
    logger.info("Getting the current directory...")
    if PLATFORM_NAME == 'darwin':
        # Three levels above the executable (NOTE(review): presumably next to a macOS .app bundle; confirm)
        csvs_folder = os.path.join(os.path.dirname(sys.executable), "../../../csvs")
    else:
        csvs_folder = os.path.join(os.getcwd(), "csvs")
    logger.info("Attempting to create the csvs folder...")
    os.makedirs(csvs_folder, exist_ok=True)
    # Generate a timestamp to be appended to the end of the file name
    logger.info("Generating the timestamp...")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    for group_name, (_, df_cha, _, _) in all_results.items():
        save_path = os.path.join(csvs_folder, f"{group_name}_{timestamp}.csv")
        is_hbo = df_cha["Chroma"] == "hbo"
        if HRF_MODEL == 'fir':
            # Same "<activity>_delay_" prefix as the rest of the module; a bare startswith(TARGET_ACTIVITY)
            # also swept in other conditions whose names begin with the target's name
            is_target = df_cha["Condition"].str.startswith(f"{TARGET_ACTIVITY}_delay_")
            agg_funcs = {
                'df': 'mean',
                'mse': 'mean',
                'p_value': 'mean',
                'se': 'mean',
                't': 'mean',
                'theta': 'mean',
                'Source': 'mean',
                'Detector': 'mean',
                'Significant': lambda x: x.sum() > (len(x) / 2),  # majority of delays significant
                'Chroma': 'first',  # assuming all are the same
                'ch_name': 'first',  # same ch_name in the group
                'ID': 'first',  # same ID in the group
            }
            averaged_df = (
                df_cha[is_target & is_hbo]
                .groupby(['ch_name', 'ID'], as_index=False)
                .agg(agg_funcs)
            )
            averaged_df.insert(0, 'Condition', TARGET_ACTIVITY)
            # Averaging turned these integer columns into floats; restore them
            for col in ("df", "Source", "Detector"):
                averaged_df[col] = averaged_df[col].round().astype(int)
            ordered_cols = [
                'Condition', 'df', 'mse', 'p_value', 'se', 't', 'theta',
                'Source', 'Detector', 'Chroma', 'Significant', 'ch_name', 'ID'
            ]
            output_df = averaged_df[ordered_cols].sort_values(by=["ID", "Detector", "Source"]).reset_index(drop=True)
        else:
            output_df = df_cha[(df_cha["Condition"] == TARGET_ACTIVITY) & is_hbo]
        output_df.to_csv(save_path)
def all_data_to_csv(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]):
    '''Save each group's channel results to "<group>_<timestamp>_all.csv" in a csvs folder.\n
    With HRF_MODEL == 'fir' the whole channel table is written. Otherwise only the TARGET_ACTIVITY HbO rows are written.\n
    Inputs:\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n'''
    logger.info("Getting the current directory...")
    if PLATFORM_NAME == 'darwin':
        csvs_folder = os.path.join(os.path.dirname(sys.executable), "../../../csvs")
    else:
        csvs_folder = os.path.join(os.getcwd(), "csvs")
    logger.info("Attempting to create the csvs folder...")
    os.makedirs(csvs_folder, exist_ok=True)
    logger.info("Generating the timestamp...")
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    for group, (_, channel_df, _, _) in all_results.items():
        destination = os.path.join(csvs_folder, f"{group}_{stamp}_all.csv")
        if HRF_MODEL != 'fir':
            channel_df = channel_df.query(f"Condition == '{TARGET_ACTIVITY}' and Chroma == 'hbo'")  # type: ignore
        channel_df.to_csv(destination)
def brain_3d_contrast(con_model_df: DataFrame, con_model_df_filtered: BaseRaw, common_channels: list[str], first_name: str, second_name: str, first_stim: float, second_stim: float, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None:
    '''Render a contrast on the 3D brain (dorsal view) with optional optode markers and a settings caption.\n
    Inputs:\n
    con_model_df (DataFrame) - Contrast values with a ch_name column; it is not modified\n
    con_model_df_filtered (BaseRaw) - Raw data restricted to common_channels; supplies the HbO channels and sensor info\n
    common_channels (list) - Channel order used to sort con_model_df\n
    first_name, second_name (str) - Shown as "first - second" in the caption\n
    first_stim, second_stim (float) - Stimulus durations shown in the caption (omitted for the FIR model)\n
    t_or_theta (str) - Selects the colour limits (ABS_T_VALUE or ABS_THETA_VALUE)\n
    show_optodes (str) - 'sensors', 'labels', 'all' or 'none'\n
    show_text (bool) - Whether to draw the caption\n
    Raises:\n
    ValueError - If t_or_theta is not 't' or 'theta'\n'''
    if t_or_theta == 't':
        limit = ABS_T_VALUE
    elif t_or_theta == 'theta':
        limit = ABS_THETA_VALUE
    else:
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}")
    clim = dict(kind="value", pos_lims=(0, limit / 2, limit))
    # Copy first: previously the Categorical was written back into the caller's DataFrame
    con_model = con_model_df.copy()
    con_model["ch_name"] = pd.Categorical(
        con_model["ch_name"], categories=common_channels, ordered=True
    )
    # Sort rows into the raw channel order
    con_model = con_model.sort_values("ch_name").reset_index(drop=True)  # type: ignore
    # Plot brain figure
    brain = plot_3d_evoked_array(con_model_df_filtered.copy().pick(picks="hbo"), con_model, view="dorsal", distance=BRAIN_DISTANCE, colorbar=True, mode=BRAIN_MODE, clim=clim, size=(800, 700), verbose=VERBOSITY)  # type: ignore
    if show_optodes in ('all', 'sensors'):
        brain.add_sensors(getattr(con_model_df_filtered, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY)  # type: ignore
    # Optode text labels from the "name: x y z" lines of OPTODE_FILE_PATH
    if show_optodes in ('all', 'labels'):
        positions: list[tuple[str, list[float]]] = []
        with open(OPTODE_FILE_PATH, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or ':' not in line:
                    continue  # skip empty/malformed lines
                name, coords = line.split(':', 1)
                positions.append((name.strip(), [float(x) for x in coords.strip().split()]))
        for name, (x, y, z) in positions:
            brain._renderer.text3d(x, y, z, name, color=('red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'), scale=0.002)  # type: ignore
    # Apply the settings caption onto the brain
    if show_text:
        caption = [
            f"Folder: {BASE_SNIRF_FOLDER}",
            f"Contrast: {first_name} - {second_name}",
            f"Short Channel Regression: {SHORT_CHANNEL_REGRESSION}",
        ]
        # Stimulus durations do not apply to the FIR model
        if HRF_MODEL != 'fir':
            caption.append(f"Stim Duration: {first_stim}, {second_stim}")
        caption += [
            f"Brain Distance: {BRAIN_DISTANCE}",
            f"Brain Mode: {BRAIN_MODE}",
            f"Looking at: {t_or_theta} values",
        ]
        brain.add_text(0.12, 0.70, "\n".join(caption), "title", font_size=11, color="k")  # type: ignore
def _contrast_term_channels(term_names: Any) -> list[str]:
    '''Pull the channel name out of model term names shaped like "group[A]:ch_name[S1_D1 hbo]:Chroma[hbo]".\n'''
    # TODO: Does this work for all separators?
    return [name.split(":")[1].split("[")[1].split("]")[0] for name in term_names]


def _render_group_contrast(raw_haemo: BaseRaw, hbo_ch_names: list[str], contrast: pd.Series, first_name: str, second_name: str, first_stim: float, second_stim: float, vlim_value: float, t_or_theta: Literal['t', 'theta'], show_optodes: Literal['sensors', 'labels', 'none', 'all'], show_text: bool) -> None:
    '''Draw the 3D brain and the 2D topomap for one directed contrast (first_name minus second_name).\n
    contrast is indexed by channel name; channels not among raw_haemo's HbO channels are dropped.\n'''
    # The order and names of the keys in the DataFrame are important!
    con_model_df = DataFrame({
        "ch_name": contrast.index,
        "Coef.": contrast,
        "Chroma": "hbo",
    })
    glm_ch_names = set(con_model_df["ch_name"])
    # Get ordered common channels (MNE order)
    common_channels = [ch for ch in hbo_ch_names if ch in glm_ch_names]
    # Filter raw data to these channels
    con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels)  # type: ignore
    # Reindex GLM results to match MNE channel order
    con_model_df = con_model_df.set_index("ch_name").loc[common_channels].reset_index()  # type: ignore
    brain_3d_contrast(con_model_df, con_model_df_filtered, common_channels, first_name, second_name, first_stim, second_stim, t_or_theta, show_optodes, show_text)
    plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df, names=True, res=128, vlim=(-vlim_value, vlim_value))  # type: ignore
    # TODO: The title currently goes on the colorbar. Low priority
    plt.title(f"Contrast: {first_name} vs {second_name}")  # type: ignore
    plt.show()  # type: ignore


def plot_2d_3d_contrasts_between_groups(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_raw_haemo: dict[str, dict[str, BaseRaw]], t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None:
    '''Fit "effect ~ -1 + group:ch_name:Chroma" (mixed model grouped by ID) on the HbO contrast data of all groups.
    Then, for every pair of groups, show both directed contrasts as a 3D brain and a 2D topomap.\n
    Only channels observed more than once in every group are modelled. The layout comes from the first entry of all_raw_haemo.
    STIM_DURATION is indexed by each group's position in all_results.\n
    Inputs:\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group; the df_con frames are not modified\n
    all_raw_haemo (dict) - Per-participant raw data; only the first entry's "full_layout" is used\n
    t_or_theta (str) - Contrast the model t values ('t') or coefficients ('theta')\n
    show_optodes (str), show_text (bool) - Forwarded to brain_3d_contrast\n
    Raises:\n
    ValueError - If t_or_theta is not 't' or 'theta'\n'''
    if t_or_theta == 't':
        vlim_value = ABS_CONTRAST_T_VALUE
    elif t_or_theta == 'theta':
        vlim_value = ABS_CONTRAST_THETA_VALUE
    else:
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}")
    # GET RAW HAEMO OF THE FIRST PARTICIPANT
    raw_haemo = next(iter(all_raw_haemo.values()))["full_layout"]
    # assign() returns a copy, so the caller's df_con frames do not gain a "group" column
    group_dfs: dict[str, DataFrame] = {
        group_name: df_con.assign(group=group_name)
        for group_name, (_, _, df_con, _) in all_results.items()
    }
    con_summary = pd.concat(group_dfs.values(), ignore_index=True).query("Chroma == 'hbo'").copy()  # type: ignore
    # Keep only channels that appear more than once in every group
    channel_ok = (pd.crosstab(con_summary['group'], con_summary['ch_name']) > 1).all()  # type: ignore
    con_summary = con_summary[con_summary['ch_name'].isin(channel_ok[channel_ok].index.tolist())]  # type: ignore
    con_model = smf.mixedlm("effect ~ -1 + group:ch_name:Chroma", con_summary, groups=con_summary["ID"]).fit(method="nm")  # type: ignore
    estimates = con_model.tvalues if t_or_theta == 't' else con_model.params
    # The same HbO channel list is needed for every contrast
    hbo_ch_names = cast(list[str], getattr(raw_haemo.copy().pick(picks="hbo"), "ch_names"))  # type: ignore
    group_names = list(group_dfs)
    for i, group1_name in enumerate(group_names):
        for j in range(i + 1, len(group_names)):
            group2_name = group_names[j]
            group1_vals = estimates.filter(like=f"group[{group1_name}]")  # type: ignore
            group2_vals = estimates.filter(like=f"group[{group2_name}]")  # type: ignore
            # Align both groups' values on channel name
            df_contrast = DataFrame({"Coef.": group1_vals.values}, index=_contrast_term_channels(group1_vals.index)).join(  # type: ignore
                DataFrame({"Coef.": group2_vals.values}, index=_contrast_term_channels(group2_vals.index)),  # type: ignore
                how="inner", lsuffix=f"_{group1_name}", rsuffix=f"_{group2_name}")
            coef_1 = df_contrast[f"Coef._{group1_name}"]
            coef_2 = df_contrast[f"Coef._{group2_name}"]
            # a-b / 1-2 contrast, then b-a / 2-1 contrast
            _render_group_contrast(raw_haemo, hbo_ch_names, coef_1 - coef_2, group1_name, group2_name, STIM_DURATION[i], STIM_DURATION[j], vlim_value, t_or_theta, show_optodes, show_text)
            _render_group_contrast(raw_haemo, hbo_ch_names, coef_2 - coef_1, group2_name, group1_name, STIM_DURATION[j], STIM_DURATION[i], vlim_value, t_or_theta, show_optodes, show_text)
# TODO: Is any of this still useful?
def calculate_annotations(raw_haemo_filtered, file_name, output_folder=None, save_images=None):
    '''Extract the events (from the annotations) of the data and optionally save a plot of them.\n
    Input:\n
    raw_haemo_filtered (RawSNIRF) - The filtered haemoglobin concentration data\n
    file_name (string) - The file name of the current file, used to name the saved image\n
    output_folder (string) - (optional) Where to save the image. Must be given when save_images is truthy. Default is None\n
    save_images (bool) - (optional) Whether to save the events plot. Default is None (do not save)
    Output:\n
    events (ndarray) - Events array as returned by mne.events_from_annotations\n
    event_dict (dict) - Mapping of annotation description to event id'''
    # Convert the annotations into an events array plus a description -> id mapping
    events, event_dict = mne.events_from_annotations(raw_haemo_filtered)
    if save_images:
        fig = mne.viz.plot_events(events, event_id=event_dict, sfreq=raw_haemo_filtered.info["sfreq"], show=False)
        save_path = os.path.join(output_folder, "8. Annotations for " + file_name + ".png")
        fig.savefig(save_path)
        # The figure was created with show=False; close it so pyplot does not keep it alive
        # across repeated calls (one figure per processed file would otherwise accumulate)
        plt.close(fig)
    return events, event_dict
def calculate_good_epochs(raw_haemo_filtered, events, event_dict, file_name, tmin=None, tmax=None, reject_thresh=None, target_activity=None, target_control=None, output_folder=None, save_images=None):
    '''Calculates what epochs are good and optionally saves a graph showing which were dropped.\n
    When the module-level REJECT_PAIRS flag is set, a rejected control epoch also drops the activity epoch
    directly after it, and a rejected activity epoch also drops the control epoch directly before it.\n
    Input:\n
    raw_haemo_filtered (RawSNIRF) - The filtered haemoglobin concentration data\n
    events (ndarray) - Events array (e.g. from calculate_annotations)\n
    event_dict (dict) - Mapping of event name to event id\n
    file_name (string) - The file name of the current file, used to name the saved image\n
    tmin (float) - (optional) Time in seconds to display before the event. Default is TIME_MIN_THRESH\n
    tmax (float) - (optional) Time in seconds to display after the event. Default is TIME_MAX_THRESH\n
    reject_thresh (float) - (optional) HbO rejection threshold passed to mne.Epochs(reject=...). Default is EPOCH_REJECT_CRITERIA_THRESH\n
    target_activity (string) - (optional) Event name of the activity condition. Default is TARGET_ACTIVITY\n
    target_control (string) - (optional) Event name of the control condition. Default is TARGET_CONTROL\n
    output_folder (string) - (optional) Where to save the image. Must be given when save_images is truthy. Default is None\n
    save_images (bool) - (optional) Whether to save the drop-log plot. Default is None (do not save)
    Output:\n
    good_epochs (Epochs) - The remaining good epochs\n
    all_epochs (Epochs) - All of the epochs (no amplitude rejection)'''
    if tmin is None:
        tmin = TIME_MIN_THRESH
    if tmax is None:
        tmax = TIME_MAX_THRESH
    if reject_thresh is None:
        reject_thresh = EPOCH_REJECT_CRITERIA_THRESH
    if target_activity is None:
        target_activity = TARGET_ACTIVITY
    if target_control is None:
        target_control = TARGET_CONTROL

    # Settings shared by both epoch sets, so the two differ only in the amplitude rejection
    epoch_kwargs = dict(
        event_id=event_dict,
        tmin=tmin,
        tmax=tmax,
        proj=True,
        baseline=(None, 0),
        preload=True,
        detrend=None,
        verbose=True,
    )
    # Epochs that survive the HbO rejection threshold (and annotation-based rejection)
    good_epochs = mne.Epochs(
        raw_haemo_filtered,
        events,
        reject=dict(hbo=reject_thresh),
        reject_by_annotation=True,
        **epoch_kwargs,
    )
    # The same epochs without any amplitude rejection
    all_epochs = mne.Epochs(raw_haemo_filtered, events, **epoch_kwargs)

    if REJECT_PAIRS:
        # selection holds, for each kept epoch, its index into the original events array
        all_idx = all_epochs.selection
        good_idx = good_epochs.selection
        # Epochs present in all_epochs but rejected from good_epochs (set for O(1) membership)
        bad_idx = set(np.setdiff1d(all_idx, good_idx).tolist())
        # Event ids aligned position-by-position with all_idx
        event_ids = all_epochs.events[:, 2]
        control_id = event_dict[target_control]
        activity_id = event_dict[target_activity]
        n_epochs = len(all_idx)
        to_reject_extra = set()
        for i, idx in enumerate(all_idx):
            if idx not in bad_idx:
                continue
            ev = event_ids[i]
            # A bad control invalidates the activity that immediately follows it
            if ev == control_id and i + 1 < n_epochs and event_ids[i + 1] == activity_id:
                to_reject_extra.add(all_idx[i + 1])
            # A bad activity invalidates the control that immediately precedes it
            if ev == activity_id and i - 1 >= 0 and event_ids[i - 1] == control_id:
                to_reject_extra.add(all_idx[i - 1])
        # Translate the extra rejections into positions within good_epochs,
        # ignoring any that are not currently classified as good
        drop_idx_in_good = np.flatnonzero(np.isin(good_idx, list(to_reject_extra)))
        # Drop the pairings of the bad epochs
        good_epochs.drop(drop_idx_in_good)

    if save_images:
        drop_log_fig = good_epochs.plot_drop_log(show=False)
        save_path = os.path.join(output_folder, "8. Epoch drops for " + file_name + ".png")
        drop_log_fig.savefig(save_path)
        # Close the figure so pyplot does not keep one alive per processed file
        plt.close(drop_log_fig)
    return good_epochs, all_epochs
def bad_check(raw_od, max_bad_channels=None):
    '''Check whether the recording has reached the allowed number of bad channels.\n
    Inputs:\n
    raw_od (RawSNIRF) - The optical density data; the length of its info['bads'] list is checked\n
    max_bad_channels (int) - (optional) Bad-channel count at which the recording is flagged. Default is MAX_BAD_CHANNELS\n
    Output\n
    (bool) - True if the number of bad channels is greater than or equal to max_bad_channels, False otherwise'''
    if max_bad_channels is None:
        max_bad_channels = MAX_BAD_CHANNELS
    # The comparison is inclusive: reaching the threshold already counts as too many bad channels
    return len(raw_od.info.get('bads', [])) >= max_bad_channels
def remove_bad_epoch_pairings(raw_haemo_filtered_minus_short, good_epochs, epoch_pair_tolerance_window=None):
    '''Apply the surviving epochs to the loaded data by pruning its annotations, so the GLM does not see rejected epochs.\n
    A copy of the data is made and its bad channels are dropped. Each annotation is kept only if its onset lies within
    epoch_pair_tolerance_window samples of an event sample in good_epochs. Annotations whose description is listed
    in FORCE_DROP_ANNOTATIONS are always removed.\n
    Inputs:\n
    raw_haemo_filtered_minus_short (RawSNIRF) - The filtered haemoglobin concentration data (not modified; a copy is used)\n
    good_epochs (Epochs) - The epochs we want the loaded data to take on\n
    epoch_pair_tolerance_window (int) - (optional) Max distance, in samples, between an annotation onset and a good event sample for the annotation to be kept. Default is EPOCH_PAIR_TOLERANCE_WINDOW\n
    Output:\n
    raw_haemo_filtered_good_epochs (RawSNIRF) - Copy of the data without bad channels, carrying only the kept annotations'''
    if epoch_pair_tolerance_window is None:
        epoch_pair_tolerance_window = EPOCH_PAIR_TOLERANCE_WINDOW
    # Copy the input haemoglobin concentration data and drop the bad channels
    raw_haemo_filtered_good_epochs = raw_haemo_filtered_minus_short.copy()
    raw_haemo_filtered_good_epochs = raw_haemo_filtered_good_epochs.drop_channels(raw_haemo_filtered_good_epochs.info['bads'])

    # Unique sample positions (column 0 of the events array) of the epochs that survived rejection
    good_event_samples = np.unique(good_epochs.events[:, 0])
    logger.info(f"Total good events (epochs): {len(good_event_samples)}")

    # Descriptions that must always be removed, regardless of any matching good event
    force_drop = set(FORCE_DROP_ANNOTATIONS) if FORCE_DROP_ANNOTATIONS else set()

    raw_annots = raw_haemo_filtered_good_epochs.annotations
    sfreq = raw_haemo_filtered_good_epochs.info['sfreq']

    clean_descriptions = []
    clean_onsets = []
    clean_durations = []
    dropped = []
    for desc, onset, dur in zip(raw_annots.description, raw_annots.onset, raw_annots.duration):
        # Forced drops skip the matching step entirely (and are recorded exactly once)
        if desc in force_drop:
            dropped.append((desc, onset))
            continue
        # Convert the annotation onset time to the nearest sample index
        sample = int(round(onset * sfreq))
        # Keep the annotation if it is within the tolerance of any good event
        matched = good_event_samples.size > 0 and np.min(np.abs(good_event_samples - sample)) <= epoch_pair_tolerance_window
        if matched:
            clean_descriptions.append(desc)
            clean_onsets.append(onset)
            clean_durations.append(dur)
        else:
            dropped.append((desc, onset))

    # Build the filtered annotations. orig_time is carried over because the copied onsets are
    # expressed relative to raw_annots.orig_time; without it MNE would reinterpret them
    # relative to the start of the data, which can shift them when first_samp != 0.
    new_annots = mne.Annotations(
        onset=clean_onsets,
        duration=clean_durations,
        description=clean_descriptions,
        orig_time=raw_annots.orig_time,
    )
    raw_haemo_filtered_good_epochs.set_annotations(new_annots)

    # Log the results
    logger.info(f"Original annotations: {len(raw_annots)}")
    logger.info(f"Kept annotations: {len(clean_descriptions)}")
    logger.info("Kept annotation types: %s", sorted({str(d) for d in clean_descriptions}))
    if dropped:
        logger.info(f"Dropped annotations: {len(dropped)}")
        logger.info("Dropped annotations:")
        for desc, onset in dropped:
            logger.info(f" - {desc} at {onset:.2f}s")
    else:
        logger.info("No annotations were dropped!")
    return raw_haemo_filtered_good_epochs