4107 lines
161 KiB
Python
4107 lines
161 KiB
Python
"""
|
|
Filename: fNIRS_module.py
|
|
Description: Core functionality for FLARES
|
|
|
|
Author: Tyler de Zeeuw
|
|
License: GPL-3.0
|
|
"""
|
|
|
|
# Built-in imports
|
|
import os
|
|
import sys
|
|
import time
|
|
import logging
|
|
import platform
|
|
import warnings
|
|
import threading
|
|
from io import BytesIO
|
|
from copy import deepcopy
|
|
from pathlib import Path
|
|
from zipfile import ZipFile
|
|
from datetime import datetime
|
|
from itertools import compress
|
|
from multiprocessing import Queue
|
|
from typing import Any, Optional, cast, Literal, Iterator, Union
|
|
|
|
# External library imports
|
|
import h5py
|
|
import pywt # type: ignore
|
|
import qtpy # type: ignore
|
|
import xlrd # type: ignore
|
|
import psutil
|
|
import scooby # type: ignore
|
|
import requests
|
|
import pyvistaqt # type: ignore
|
|
import darkdetect # type: ignore
|
|
import numpy as np
|
|
import pandas as pd
|
|
from PIL import Image
|
|
import seaborn as sns
|
|
import neurokit2 as nk # type: ignore
|
|
from tqdm.auto import tqdm
|
|
from pandas import DataFrame
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.axes import Axes
|
|
from numpy.typing import NDArray
|
|
#import vtkmodules.util.data_model
|
|
from numpy import floating, float64
|
|
from matplotlib.lines import Line2D
|
|
import matplotlib.colors as mcolors
|
|
from scipy.stats import ttest_1samp # type: ignore
|
|
from matplotlib.figure import Figure
|
|
import statsmodels.formula.api as smf # type: ignore
|
|
#import vtkmodules.util.execution_model
|
|
from nilearn.plotting import plot_design_matrix # type: ignore
|
|
from scipy.signal import welch, butter, filtfilt # type: ignore
|
|
from matplotlib.colors import LinearSegmentedColormap
|
|
from IPython.display import display, Markdown, clear_output # type: ignore
|
|
from statsmodels.tools.sm_exceptions import ConvergenceWarning # type: ignore
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
|
|
|
# External library imports for mne
|
|
import mne
|
|
from mne import EvokedArray, Info, read_source_spaces, stc_near_sensors # type: ignore
|
|
from mne.source_space import SourceSpaces
|
|
from mne.transforms import Transform # type: ignore
|
|
from mne.io import BaseRaw, read_raw_snirf # type: ignore
|
|
from mne.annotations import Annotations # type: ignore
|
|
from mne_nirs.visualisation import plot_glm_group_topo # type: ignore
|
|
from mne_nirs.channels import get_long_channels, get_short_channels, picks_pair_to_idx # type: ignore
|
|
from mne_nirs.experimental_design import make_first_level_design_matrix # type: ignore
|
|
from mne_nirs.statistics import run_glm, statsmodels_to_results # type: ignore
|
|
from mne_nirs.signal_enhancement import enhance_negative_correlation, short_channel_regression # type: ignore
|
|
from mne.preprocessing.nirs import beer_lambert_law, optical_density, temporal_derivative_distribution_repair, source_detector_distances, short_channels # type: ignore
|
|
from mne_nirs.io.fold import fold_channel_specificity # type: ignore
|
|
from mne_nirs.preprocessing import peak_power # type: ignore
|
|
from mne.viz import Brain
|
|
from mne_nirs.statistics._glm_level_first import RegressionResults # type: ignore
|
|
from mne.filter import filter_data # type: ignore
|
|
|
|
# Version string of this module.
CURRENT_VERSION = "1.0.0"

# Set to True by set_config() when it is called with gui=True.
GUI = False
# Lower-cased platform.system(); the value 'darwin' selects the macOS-specific log and image paths.
PLATFORM_NAME = platform.system().lower()

# ---------------------------------------------------------------------------
# Configuration globals
#
# The bare annotations below only declare names for type checkers; nothing is
# bound at import time. set_config() validates a config dictionary against
# REQUIRED_KEYS and then injects its values here via globals().update(config).
# ---------------------------------------------------------------------------

# BASE_SNIRF_FOLDER must be an existing directory, every SNIRF_SUBFOLDERS entry a directory
# inside it, and STIM_DURATION must hold one entry per subfolder (all checked in set_config).
BASE_SNIRF_FOLDER: str
SNIRF_SUBFOLDERS: list[str]
STIM_DURATION: list[float]
MAX_WORKERS: int

# load_snirf() crops SECONDS_TO_STRIP seconds from the start of each recording and, when
# DOWNSAMPLE is set, resamples to DOWNSAMPLE_FREQUENCY.
SECONDS_TO_STRIP: int
DOWNSAMPLE: bool
DOWNSAMPLE_FREQUENCY: int
FORCE_DROP_CHANNELS: list[str]  # presumably passed to load_snirf() as drop_prefixes - TODO confirm
SOURCE_DETECTOR_SEPARATOR: str

# When OPTODE_FILE is set, OPTODE_FILE_PATH must be an existing .txt file (checked in
# set_config); each of its lines is split on OPTODE_FILE_SEPARATOR.
OPTODE_FILE: bool
OPTODE_FILE_PATH: str
OPTODE_FILE_SEPARATOR: str

TDDR: bool

# IQR is the scaling factor handed to iqr_threshold() by wavelet_iqr_denoise().
WAVELET: bool
IQR: float

# Heart-rate settings: SECONDS_TO_STRIP_HR is trimmed from both ends of the signal,
# MAX_LOW_HR / MAX_HIGH_HR are the clipping bounds and SMOOTHING_WINDOW_HR the rolling window.
HEART_RATE: bool
SECONDS_TO_STRIP_HR: int
MAX_LOW_HR: int
MAX_HIGH_HR: int
SMOOTHING_WINDOW_HR: int
HEART_RATE_WINDOW: int
SHORT_CHANNEL: bool
SHORT_CHANNEL_THRESH: float

SCI: bool
SCI_TIME_WINDOW: int
SCI_THRESHOLD: float
PSP: bool
PSP_TIME_WINDOW: int
PSP_THRESHOLD: float

# TODO: Implement
SNR: bool
SNR_TIME_WINDOW : int
SNR_THRESHOLD: float

EXCLUDE_CHANNELS: bool
MAX_BAD_CHANNELS: int
LONG_CHANNEL_THRESH: float

METADATA: dict

DRIFT_MODEL: str
DURATION_BETWEEN_ACTIVITIES: int
HRF_MODEL: str
SHORT_CHANNEL_REGRESSION: bool

N_JOBS: int

TARGET_ACTIVITY: str
TARGET_CONTROL: str
ROI_GROUP_1: list[list[int]]
ROI_GROUP_2: list[list[int]]
ROI_GROUP_1_NAME: str
ROI_GROUP_2_NAME: str
P_THRESHOLD: float

SEE_BAD_IMAGES: bool
ABS_T_VALUE: int
ABS_THETA_VALUE: int
ABS_CONTRAST_T_VALUE: int
ABS_CONTRAST_THETA_VALUE: int
ABS_SIGNIFICANCE_T_VALUE: int
ABS_SIGNIFICANCE_THETA_VALUE: int
BRAIN_DISTANCE: float
BRAIN_MODE: str

EPOCH_REJECT_CRITERIA_THRESH: float
TIME_MIN_THRESH: int
TIME_MAX_THRESH: int

# Passed as verbose= to the MNE calls; when False, set_config() also silences
# ConvergenceWarning and RuntimeWarning.
VERBOSITY: bool

# Optional settings: their REQUIRED_KEYS entries are commented out, so they keep this
# None default unless the config dictionary supplies a value.
REJECT_PAIRS = None
FORCE_DROP_ANNOTATIONS = None
FILTER_LOW_PASS = None
FILTER_HIGH_PASS = None
EPOCH_PAIR_TOLERANCE_WINDOW = None
|
|
|
|
# FIXME: Shouldn't need each ordering - just order it before checking
# Colour name for each quality-check category label. Every ordering of a combined
# label is listed once so that a lookup succeeds however the label was assembled.
FIXED_CATEGORY_COLORS = {
    # Flagged by a single check
    "SCI only": "skyblue",
    "PSP only": "salmon",
    "SNR only": "lightgreen",
    # Flagged by two checks
    "PSP + SCI": "orange",
    "SCI + SNR": "violet",
    "PSP + SNR": "gold",
    "SCI + PSP": "orange",
    "SNR + SCI": "violet",
    "SNR + PSP": "gold",
    # Flagged by all three checks (six orderings; the duplicated key was removed)
    "PSP + SNR + SCI": "gray",
    "SCI + PSP + SNR": "gray",
    "SCI + SNR + PSP": "gray",
    "PSP + SCI + SNR": "gray",
    "SNR + SCI + PSP": "gray",
    "SNR + PSP + SCI": "gray",
}
|
|
|
|
|
|
# Schema for the configuration dictionary: set_config() raises KeyError when one of these
# keys is missing and TypeError when its value does not have the type given here.
REQUIRED_KEYS: dict[str, Any] = {
    # Input folders and loading
    "BASE_SNIRF_FOLDER": str,
    "SNIRF_SUBFOLDERS": list,
    "STIM_DURATION": list,
    "MAX_WORKERS": int,
    "SECONDS_TO_STRIP": int,
    "DOWNSAMPLE": bool,
    "DOWNSAMPLE_FREQUENCY": int,
    "FORCE_DROP_CHANNELS": list,
    "SOURCE_DETECTOR_SEPARATOR": str,
    # Optode coordinate file
    "OPTODE_FILE": bool,
    "OPTODE_FILE_PATH": str,
    "OPTODE_FILE_SEPARATOR": str,
    # TDDR / wavelet
    "TDDR": bool,
    "WAVELET": bool,
    "IQR": float,
    # Heart rate and short channels
    "HEART_RATE": bool,
    "SECONDS_TO_STRIP_HR": int,
    "MAX_LOW_HR": int,
    "MAX_HIGH_HR": int,
    "SMOOTHING_WINDOW_HR": int,
    "HEART_RATE_WINDOW": int,
    "SHORT_CHANNEL": bool,
    "SHORT_CHANNEL_THRESH": float,
    # SCI / PSP / SNR settings
    "SCI": bool,
    "SCI_TIME_WINDOW": int,
    "SCI_THRESHOLD": float,
    "PSP": bool,
    "PSP_TIME_WINDOW": int,
    "PSP_THRESHOLD": float,
    "SNR": bool,
    "SNR_TIME_WINDOW": int,
    "SNR_THRESHOLD": float,
    "EXCLUDE_CHANNELS": bool,
    "MAX_BAD_CHANNELS": int,
    "LONG_CHANNEL_THRESH": float,
    "METADATA": dict,
    # Model settings
    "DRIFT_MODEL": str,
    "DURATION_BETWEEN_ACTIVITIES": int,
    "HRF_MODEL": str,
    "SHORT_CHANNEL_REGRESSION": bool,
    "N_JOBS": int,
    # Targets, ROI groups and threshold
    "TARGET_ACTIVITY": str,
    "TARGET_CONTROL": str,
    "ROI_GROUP_1": list,
    "ROI_GROUP_2": list,
    "ROI_GROUP_1_NAME": str,
    "ROI_GROUP_2_NAME": str,
    "P_THRESHOLD": float,
    # Image / brain display settings
    "SEE_BAD_IMAGES": bool,
    "ABS_T_VALUE": int,
    "ABS_THETA_VALUE": int,
    "ABS_CONTRAST_T_VALUE": int,
    "ABS_CONTRAST_THETA_VALUE": int,
    "ABS_SIGNIFICANCE_T_VALUE": int,
    "ABS_SIGNIFICANCE_THETA_VALUE": int,
    "BRAIN_DISTANCE": float,
    "BRAIN_MODE": str,
    # Epoch settings
    "EPOCH_REJECT_CRITERIA_THRESH": float,
    "TIME_MIN_THRESH": int,
    "TIME_MAX_THRESH": int,
    "VERBOSITY": bool,

    # Not required: these names default to None at module level.
    # "REJECT_PAIRS": bool,
    # "FORCE_DROP_ANNOTATIONS": list,
    # "FILTER_LOW_PASS": float,
    # "FILTER_HIGH_PASS": float,
    # "EPOCH_PAIR_TOLERANCE_WINDOW": int,
}
|
|
|
|
|
|
# Make the folder containing this file the working directory
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)


# Send log records (timestamp, process name, level, message) to a file opened in append mode.
# On macOS the file is placed three levels above the interpreter's folder; on every other
# platform it is created in the working directory.
logging.basicConfig(
    filename=(
        os.path.join(os.path.dirname(sys.executable), "../../../fnirs_analysis.log")
        if PLATFORM_NAME == 'darwin'
        else 'fnirs_analysis.log'
    ),
    level=logging.INFO,
    format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    filemode='a',
)

# The root logger is shared by every function in this module
logger = logging.getLogger()
|
|
|
|
class ProcessingError(Exception):
    """
    Exception that carries a human-readable message.

    run_groups() re-raises this error instead of returning partial results.

    Parameters
    ----------
    message : str, optional
        Text describing the failure; also stored on the ``message`` attribute.
    """

    def __init__(self, message: str = "Something went wrong!"):
        super().__init__(message)
        self.message = message
|
|
|
|
|
|
def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -> None:
    """
    Entry point for a run that is launched from the GUI.

    Applies the configuration, runs every group, relays progress messages from
    progress_queue to gui_queue while processing is running, and finally posts exactly one
    outcome dictionary to gui_queue.

    Parameters
    ----------
    config : dict[str, Any]
        Configuration dictionary passed to set_config() and run_groups().
    gui_queue : Queue
        Receives the forwarded progress messages followed by either
        {"success": True, "result": ...} or {"success": False, "error": ..., "traceback": ...}.
    progress_queue : Queue
        Queue handed to run_groups(); the string "__done__" placed on it stops the forwarder.
    """

    # Queue.get(timeout=...) signals a timeout with queue.Empty
    from queue import Empty

    def forward_progress() -> None:
        """Copy messages from progress_queue to gui_queue until "__done__" is received."""
        while True:
            try:
                msg = progress_queue.get(timeout=1)
                if msg == "__done__":
                    break
                gui_queue.put(msg)
            except Empty:
                # Nothing arrived within the timeout; keep waiting
                continue
            except Exception:
                # A broken queue fails immediately on every call, so retrying would spin forever
                logger.exception("Progress forwarding stopped unexpectedly.")
                break

    try:
        print("setting config")
        set_config(config, True)

        # Start a thread to forward progress messages back to GUI
        t = threading.Thread(target=forward_progress, daemon=True)
        t.start()

        try:
            # Run the actual processing, with progress_queue passed down
            print("actual call")
            result = run_groups(config, True, progress_queue=progress_queue)
        finally:
            # Signal end of progress even when run_groups() raised, so the forwarder always stops
            progress_queue.put("__done__")
            t.join()

        gui_queue.put({"success": True, "result": result})

    except Exception as e:
        import traceback

        gui_queue.put({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        })
|
|
|
|
|
|
def set_config(config: dict[str, Any], gui: bool = False) -> None:
    """
    Validates and applies the given configuration dictionary.

    Every key in REQUIRED_KEYS must be present with a value of the expected type. The whole
    dictionary (including any extra keys) is then copied into this module's globals, the
    input paths are checked, and BASE_SNIRF_FOLDER is replaced by its absolute form.

    Parameters
    ----------
    config : dict[str, Any]
        Dictionary containing configuration keys and their values.
    gui : bool, optional
        If True, the module-level GUI flag is set to True (default is False).

    Raises
    ------
    KeyError
        If a required key is missing from config.
    TypeError
        If a required key holds a value of the wrong type. Whole numbers are accepted for
        float settings.
    AssertionError
        If a folder or the optode file cannot be found, or the number of subfolders differs
        from the number of stim durations.
    """

    def _require(condition: bool, message: str) -> None:
        # Raised explicitly (instead of using `assert`) so the checks still run under
        # `python -O`; AssertionError is kept so existing callers see the same exception type.
        if not condition:
            raise AssertionError(message)

    if gui:
        globals().update({"GUI": True})

    # Ensure all keys are present and correctly typed
    for key, expected_type in REQUIRED_KEYS.items():
        if key not in config:
            raise KeyError(f"Missing config key: {key}")

        value = config[key]
        if isinstance(value, expected_type):
            continue

        # A whole number is a valid value for a float setting; bool is deliberately excluded
        if expected_type is float and isinstance(value, int) and not isinstance(value, bool):
            continue

        raise TypeError(f"Key '{key}' has incorrect type. Expected {expected_type.__name__}, got {type(value).__name__}")

    # Update the global variables to match the values in the config keys
    globals().update(config)

    # Ensure that passed through variables are correct or that they actually exist
    _require(Path(BASE_SNIRF_FOLDER).is_dir(), "BASE_SNIRF_FOLDER was not found. Please check the folder location and try again.")

    for folder in SNIRF_SUBFOLDERS:
        _require(Path(os.path.join(BASE_SNIRF_FOLDER, folder)).is_dir(), f"The subfolder {folder} could not be found. Please check the folder location and try again.")

    _require(len(SNIRF_SUBFOLDERS) == len(STIM_DURATION), f"The amount of subfolders do not match the amount of stim durations. Subfolders: {len(SNIRF_SUBFOLDERS)} Stim durations: {len(STIM_DURATION)}")

    if OPTODE_FILE:
        path = Path(OPTODE_FILE_PATH)
        _require(path.is_file(), "OPTODE_FILE was specified, but OPTODE_FILE_PATH is not a file.")
        _require(path.suffix == ".txt", "OPTODE_FILE_PATH does not end with a .txt extension.")

    # Store BASE_SNIRF_FOLDER as an absolute path - helpful when logging later
    globals()['BASE_SNIRF_FOLDER'] = str(Path(BASE_SNIRF_FOLDER).resolve())

    # Suppress convergence and runtime warnings unless verbose output was requested
    if not VERBOSITY:
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        warnings.filterwarnings("ignore", category=RuntimeWarning)

    logger.info("[Config] Configuration successfully set.")
|
|
|
|
|
|
|
|
def run_groups(config, gui: bool = False, progress_queue=None) -> tuple[dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], dict[str, dict[str, BaseRaw]], dict[str, list[Figure]], dict[str, str], float]:
    """
    Process multiple data folders and aggregate results, haemoglobin, figures, and processing details.

    Parameters
    ----------
    config
        Configuration dictionary, forwarded unchanged to process_folder().
    gui : bool, optional
        Forwarded unchanged to process_folder() (default is False).
    progress_queue : optional
        Forwarded unchanged to process_folder() (default is None).

    Returns
    -------
    tuple
        - dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]: The (df_roi, df_cha, df_con, df_design_matrix) dataframes grouped by folder.
        - dict: Raw haemoglobin data indexed by file ID.
        - dict[str, list[Figure]]: Figures generated during processing grouped by step.
        - dict[str, str]: Processing status messages indexed by file ID.
        - float: Elapsed time in seconds.

        If a folder fails with anything other than ProcessingError, the results gathered so
        far are returned and the remaining folders are not processed.

    Raises
    ------
    Exception
        If process_folder() raises ProcessingError; the original error is attached as the cause.
    """

    # Create dictionaries to store our results
    all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]] = {}
    all_figures: dict[str, list[Figure]] = {}
    all_raw_haemo: dict[str, dict[str, BaseRaw]] = {}
    all_processes: dict[str, str] = {}

    start_time = time.time()

    # Iterate over all the folders and determine how many files are in the folder
    logger.info("Calculating how many files there are...")
    total_files = 0
    for folder in SNIRF_SUBFOLDERS:
        full_path = os.path.join(BASE_SNIRF_FOLDER, folder)
        total_files += sum(
            1 for f in os.listdir(full_path)
            if os.path.isfile(os.path.join(full_path, f))
        )
    logger.info(f"Total of {total_files} files.")

    # Mutable counter handed to process_folder(); it starts at the total amount of files
    files_remaining = {'count': total_files}

    # Iterate over all the folders
    for folder, stim_duration in zip(SNIRF_SUBFOLDERS, STIM_DURATION):
        full_path = os.path.join(BASE_SNIRF_FOLDER, folder)
        try:
            # Process all participants in the folder
            logger.info(f"Processing all files in {folder}...")
            raw_haemo, df_roi, df_cha, df_con, df_design_matrix, figures, process = process_folder(full_path, stim_duration, files_remaining, config, gui, progress_queue=progress_queue)

            # Store the results into the corresponding dictionaries
            logger.info(f"Storing the results from the {folder} folder...")
            all_results[folder] = (df_roi, df_cha, df_con, df_design_matrix)
            logger.info("Applied all results.")

            # Each merge is best-effort: a malformed part is logged and skipped so the other
            # parts of this folder, and the remaining folders, are still collected.
            try:
                for step, fig_list in figures.items():
                    all_figures.setdefault(step, []).extend(fig_list)
                logger.info("Applied all figures.")
            except Exception:
                logger.exception(f"Could not store the figures from the {folder} folder.")

            try:
                all_raw_haemo.update(raw_haemo.items())
                logger.info("Applied all haemo.")
            except Exception:
                logger.exception(f"Could not store the haemo data from the {folder} folder.")

            try:
                all_processes.update(process.items())
                logger.info("Applied all processes.")
            except Exception:
                logger.exception(f"Could not store the process details from the {folder} folder.")

        except ProcessingError as e:
            # Something really bad happened. No partial return
            logger.exception(f"Something happened! {e}")
            raise Exception(e) from e

        except Exception as e:
            # Still return a partial analysis even if something goes wrong
            logger.exception(f"Something happened! {e}")
            return all_results, all_raw_haemo, all_figures, all_processes, time.time() - start_time

    return all_results, all_raw_haemo, all_figures, all_processes, time.time() - start_time
|
|
|
|
|
|
|
|
def create_image_montage(images: list[Image.Image], cols: int) -> Optional[Image.Image]:
    """
    Creates a grid montage image from a list of PIL Images.

    Every cell has the size of the largest image; images are pasted at the top-left corner of
    their cell, filling the grid row by row on a white background.

    Parameters
    ----------
    images : list[Image.Image]
        List of images to arrange in the montage.
    cols : int
        Number of columns in the montage grid. Values below 1 are treated as 1.

    Returns
    -------
    Optional[Image.Image]
        The combined montage image, or None if the input list of images is empty.
    """

    # Verify that we have images to process
    if not images:
        return None

    # A column count below 1 would divide by zero when computing the rows
    cols = max(1, cols)

    # Calculate the width, height, and rows
    logger.info("Calculating the montage parameters...")
    widths, heights = zip(*(i.size for i in images))
    max_width = max(widths)
    max_height = max(heights)
    rows = (len(images) + cols - 1) // cols  # ceiling division

    # Create the montage image
    logger.info("Creating the montage...")
    montage = Image.new('RGBA', (cols * max_width, rows * max_height), (255, 255, 255, 255))
    for idx, image in enumerate(images):
        row, col = divmod(idx, cols)
        montage.paste(image, (col * max_width, row * max_height))  # type: ignore

    return montage
|
|
|
|
|
|
|
|
def show_all_images(figures: dict[str, list[bytes]], inline: bool = False) -> None:
    """
    Displays montages of figures either inline or in separate windows.

    Parameters
    ----------
    figures : dict[str, list[bytes]]
        Dictionary containing lists of encoded images (each entry is wrapped in BytesIO and
        opened with PIL) categorized by type.
    inline : bool, optional
        If True, display images inline (e.g., in Jupyter notebooks). Otherwise, opens them in separate windows (default is False).
    """

    if inline:
        logger.info("Inline was selected.")
    else:
        logger.info("Inline was not selected.")

    # If we have less than 4 'Raw' figures, the columns should be the exact amount of images we have. If we have more, enforce 4 columns
    logger.info("Calculating columns...")
    raw_cols = min(4, len(figures.get('Raw', [])))

    # Iterate over all of the types of figure, create a montage with figures of the same type, and display the resulting image
    logger.info("Generating images...")
    for fig_bytes_list in figures.values():
        pil_images = []
        for b in fig_bytes_list:
            try:
                pil_images.append(Image.open(BytesIO(b)).convert("RGB"))
            except Exception as e:
                # An unreadable image is skipped rather than aborting the whole montage
                logger.warning(f"Could not open image from bytes: {e}")

        # Without any 'Raw' figures the column count would be 0 and the montage would divide
        # by zero, so fall back to this category's own image count
        cols = raw_cols or min(4, len(pil_images))
        montage = create_image_montage(pil_images, cols)

        if montage:
            # Determine how to display the images to the user
            if inline:
                display(montage)
            else:
                montage.show()
|
|
|
|
|
|
|
|
def save_all_images(figures: dict[str, list[bytes]]) -> None:
    """
    Saves montages of figures as timestamped PNG files in folder called 'images'.

    On macOS the folder is located three levels above the interpreter's folder; on other
    platforms it is created in the current working directory.

    Parameters
    ----------
    figures : dict[str, list[bytes]]
        Dictionary containing lists of encoded images (each entry is wrapped in BytesIO and
        opened with PIL) categorized by type. The category name becomes part of the file name.
    """

    # Get the current working directory and create a folder called images if it does not exist
    logger.info("Getting the current directory...")
    if PLATFORM_NAME == 'darwin':
        images_folder = os.path.join(os.path.dirname(sys.executable), "../../../images")
    else:
        images_folder = os.path.join(os.getcwd(), "images")

    logger.info("Attempting to create the images folder...")
    os.makedirs(images_folder, exist_ok=True)

    # Generate a timestamp to be appended to the end of the file name
    logger.info("Generating the timestamp...")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # If we have less than 4 'Raw' figures, the columns should be the exact value. If we have more, enforce 4 columns
    logger.info("Calculating columns...")
    raw_cols = min(4, len(figures.get('Raw', [])))

    # Iterate over all of the types of figures, create a montage with figures of the same type, and save the resulting image
    logger.info("Generating images...")
    for step, fig_bytes_list in figures.items():
        pil_images = []
        for b in fig_bytes_list:
            try:
                pil_images.append(Image.open(BytesIO(b)).convert("RGB"))
            except Exception as e:
                # An unreadable image is skipped rather than aborting the whole montage
                logger.warning(f"Could not open image from bytes: {e}")

        # Without any 'Raw' figures the column count would be 0 and the montage would divide
        # by zero, so fall back to this category's own image count
        cols = raw_cols or min(4, len(pil_images))
        montage = create_image_montage(pil_images, cols)
        if montage:
            filename = f"{step}_{timestamp}.png"
            save_path = os.path.join(images_folder, filename)
            montage.save(save_path)  # type: ignore
            logger.info(f"Saved image to {save_path}")

    logger.info(f"All images have been saved to '{images_folder}'.")
|
|
|
|
|
|
|
|
def load_snirf(file_path: str, ID: str, drop_prefixes: list[str]) -> tuple[BaseRaw, Figure]:
    """
    Loads a snirf file, optionally drops channels, downsamples, and creates a figure showing the results.

    Parameters
    ----------
    file_path : str
        Path of the snirf file to load.
    ID : str
        File name of the the snirf file that was loaded.
    drop_prefixes : list[str]
        List of channel name prefixes to drop from the data.

    Returns
    -------
    tuple[BaseRaw, Figure]
        - BaseRaw: The processed data object.
        - Figure: The corresponding Matplotlib figure.
    """

    logger.info(f"Loading the snirf file ({ID})...")

    # Read the snirf file
    raw = read_raw_snirf(file_path, verbose=VERBOSITY)  # type: ignore
    raw.load_data(verbose=VERBOSITY)  # type: ignore

    # Strip the specified amount of seconds from the start of the file
    total_duration = getattr(raw, "times")[-1]
    if total_duration > SECONDS_TO_STRIP:
        raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY)  # type: ignore
        logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.")
    else:
        logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.")

    # If the user forcibly dropped channels, remove them now before any processing occurs
    logger.info("Checking if there are channels to forcibly drop...")
    if drop_prefixes:
        logger.info("Force dropped channels was specified.")
        # str.startswith accepts a tuple, so every prefix is checked in one call
        prefixes = tuple(drop_prefixes)
        channels_to_drop = [ch for ch in cast(list[str], getattr(raw, "ch_names")) if ch.startswith(prefixes)]
        raw.drop_channels(channels_to_drop, "raise")  # type: ignore
        # The list must go through a %s placeholder; an extra positional argument without one
        # makes the logging module report a formatting error instead of the message
        logger.info("Force dropped channels: %s", channels_to_drop)

    # If the user wants to downsample, do it right away
    logger.info("Checking if we should downsample...")
    if DOWNSAMPLE:
        logger.info("Downsample was specified.")
        sfreq_old = getattr(raw, "info")["sfreq"]
        raw.resample(DOWNSAMPLE_FREQUENCY, verbose=VERBOSITY)  # type: ignore
        sfreq_new = getattr(raw, "info")["sfreq"]
        logger.info(f"Finished downsampling. Old frequency: {sfreq_old}. New frequency: {sfreq_new}.")

    # Create a figure for the results (all channels, full duration) and close it so it is not shown
    logger.info("Creating the figure...")
    fig = cast(Figure, raw.plot(show=False, n_channels=len(getattr(raw, "ch_names")), duration=raw.times[-1]).figure)  # type: ignore
    fig.suptitle(f"Raw fNIRS Data for {ID}", fontsize=16)  # type: ignore
    fig.subplots_adjust(top=0.92)
    plt.close(fig)

    logger.info("Successfully loaded the snirf file.")

    return raw, fig
|
|
|
|
|
|
|
|
def calculate_and_apply_updated_optode_coordinates(data: BaseRaw) -> BaseRaw:
    """
    Update optode coordinates on the given MNE Raw data using a specified optode file.

    Each non-blank line of OPTODE_FILE_PATH holds a name, OPTODE_FILE_SEPARATOR, and
    whitespace-separated coordinates. The names 'lpa', 'nz' and 'rpa' (any case) are used as
    fiducials; every other name is upper-cased and used as a channel position.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process with new optode coordinates. It is modified in place.

    Returns
    -------
    BaseRaw
        The same data object with the updated montage applied.

    Raises
    ------
    ValueError
        If a line does not contain the separator or a coordinate is not a number.
    """

    logger.info("Updating optode coordinates...")

    fiducials: dict[str, NDArray[floating[Any]]] = {}
    ch_positions: dict[str, NDArray[floating[Any]]] = {}

    # Read the lines from the optode file
    logger.info(f"Reading optode file from {OPTODE_FILE_PATH}")
    with open(OPTODE_FILE_PATH, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue

            # Split on the first separator only, so a separator that also occurs between the
            # coordinates (e.g. a space) does not break the name / coordinates split
            ch_name, separator, coords_str = line.partition(OPTODE_FILE_SEPARATOR)
            if not separator:
                raise ValueError(f"Line {line_number} of {OPTODE_FILE_PATH} does not contain the separator {OPTODE_FILE_SEPARATOR!r}.")

            # Surrounding whitespace is not part of the name ("nz " must still match 'nz')
            ch_name = ch_name.strip()

            # Scale by 0.001: the file values are assumed to be millimetres, converted to metres
            coords = np.array([float(value) for value in coords_str.split()]) * 0.001
            logger.info(f"Read line: {ch_name} with coords (m): {coords}")

            if ch_name.lower() in ['lpa', 'nz', 'rpa']:
                # The key we have is a fiducial
                fiducials[ch_name.lower()] = coords
            else:
                # The key we have is a source or detector
                ch_positions[ch_name.upper()] = coords

    # Create montage with updated coords in head space
    logger.info("Creating and applying the montage...")
    initial_montage = mne.channels.make_dig_montage(ch_pos=ch_positions, nasion=fiducials.get('nz'), lpa=fiducials.get('lpa'), rpa=fiducials.get('rpa'), coord_frame='head')  # type: ignore
    data.set_montage(initial_montage, verbose=VERBOSITY)  # type: ignore
    logger.info("Successfully updated optode coordinates.")

    return data
|
|
|
|
|
|
|
|
def calculate_and_apply_tddr(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]:
    """
    Runs Temporal Derivative Distribution Repair (TDDR) on the data and plots the outcome.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    ID : str
        File name of the snirf file, used in the figure title.

    Returns
    -------
    tuple[BaseRaw, Figure]
        - BaseRaw: The data object returned by the TDDR step.
        - Figure: A closed Matplotlib figure showing every channel over the full duration.
    """

    # Repair the signal
    logger.info("Applying temporal derivative distribution repair...")
    repaired = cast(BaseRaw, temporal_derivative_distribution_repair(data, verbose=VERBOSITY))

    # Plot all channels over the whole recording, title the figure, and close it so it is not shown
    logger.info("Creating the figure...")
    channel_count = len(getattr(data, "ch_names"))
    browser = repaired.plot(show=False, n_channels=channel_count, duration=data.times[-1])  # type: ignore
    tddr_figure = cast(Figure, browser.figure)
    tddr_figure.suptitle(f"TDDR for {ID}", fontsize=16)  # type: ignore
    tddr_figure.subplots_adjust(top=0.92)
    plt.close(tddr_figure)

    logger.info("Successfully applied temporal derivative distribution repair.")
    return repaired, tddr_figure
|
|
|
|
|
|
|
|
def iqr_threshold(coeffs: NDArray[float64], k: float = 1.5) -> floating[Any]:
    """
    Return the interquartile range of the coefficients multiplied by k.

    Parameters
    ----------
    coeffs : NDArray[float64]
        Array of coefficients whose 25th and 75th percentiles are taken.
    k : float, optional
        Multiplier applied to the interquartile range (default is 1.5).

    Returns
    -------
    floating[Any]
        k times the difference between the 75th and the 25th percentile.
    """

    upper_quartile = np.percentile(coeffs, 75)
    lower_quartile = np.percentile(coeffs, 25)

    return k * (upper_quartile - lower_quartile)
|
|
|
|
|
|
|
|
def wavelet_iqr_denoise(signal: NDArray[float64], wavelet: str = 'db4', level: int = 3) -> NDArray[float64]:
    """
    Denoise a signal by shrinking its wavelet detail coefficients with an IQR-based threshold.

    Parameters
    ----------
    signal : NDArray[float64]
        The input signal array to denoise.
    wavelet : str, optional
        Name of the wavelet used for the decomposition (default is 'db4').
    level : int, optional
        Number of decomposition levels (default is 3).

    Returns
    -------
    NDArray[float64]
        The reconstructed signal, cut to the length of the input.
    """

    # Decompose: the first entry holds the approximation coefficients, the rest are detail coefficients
    approximation, *details = pywt.wavedec(signal, wavelet, level=level)  # type: ignore

    # The approximation is kept untouched; every detail band is soft-thresholded, i.e. each
    # magnitude is reduced by the band's threshold and clamped at zero, keeping its sign
    shrunk_coeffs = [approximation]
    for detail in details:
        cutoff = iqr_threshold(detail, IQR)
        reduced_magnitude = np.maximum(np.abs(detail) - cutoff, 0.0)
        shrunk_coeffs.append((np.sign(detail) * reduced_magnitude).astype(float64))

    # Rebuild the signal; only the first len(signal) samples of the reconstruction are returned
    rebuilt = cast(NDArray[float64], pywt.waverec(shrunk_coeffs, wavelet))  # type: ignore
    return rebuilt[:len(signal)]
|
|
|
|
|
|
|
|
def calculate_and_apply_wavelet(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]:
    """
    Applies a wavelet IQR denoising filter to the data and generates a plot.

    Every channel is denoised independently with wavelet_iqr_denoise() ('db4', level 3) and a
    new RawArray is built from the result, carrying over the measurement info, the first
    sample offset, and a copy of the annotations. The input object is not modified.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    ID : str
        File name of the the snirf file that was loaded.

    Returns
    -------
    tuple[BaseRaw, Figure]
        - BaseRaw: The processed data object.
        - Figure: The corresponding Matplotlib figure.
    """

    logger.info("Applying the wavelet filter...")

    # Denoise the data
    logger.info("Denoising the data...")
    loaded_data: NDArray[float64] = data.get_data(verbose=VERBOSITY)  # type: ignore
    denoised_data = np.zeros_like(loaded_data)

    logger.info("Calculating the IQR, decomposing the signal, and thresholding the coefficients...")
    for ch, channel_signal in enumerate(loaded_data):
        denoised_data[ch, :] = wavelet_iqr_denoise(channel_signal, wavelet='db4', level=3)

    # Reconstruct the data with the annotations. first_samp is carried over because the
    # annotation onsets account for it: once the recording has been cropped (as load_snirf
    # does), a RawArray starting at sample 0 would place every annotation too late.
    logger.info("Reconstructing the data with annotations...")
    raw_with_tddr_and_wavelet = mne.io.RawArray(denoised_data, cast(mne.Info, data.info), first_samp=data.first_samp, verbose=VERBOSITY)
    raw_with_tddr_and_wavelet.set_annotations(data.annotations.copy(), verbose=VERBOSITY)  # type: ignore

    # Create a figure for the results (all channels, full duration) and close it so it is not shown
    logger.info("Creating the figure...")
    fig = cast(Figure, raw_with_tddr_and_wavelet.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=data.times[-1]).figure)  # type: ignore
    fig.suptitle(f"Wavelet for {ID}", fontsize=16)  # type: ignore
    fig.subplots_adjust(top=0.92)
    plt.close(fig)

    logger.info("Successfully applied the wavelet filter.")

    return raw_with_tddr_and_wavelet, fig
|
|
|
|
|
|
|
|
def short_channel_processing_for_hr(data: BaseRaw, short_chans: BaseRaw | None) -> tuple[float, NDArray[float64], NDArray[float64]]:
    """
    Extract and trim short-channel fNIRS signal for heart rate analysis.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process. Its sampling frequency and time axis are always used.
    short_chans : BaseRaw | None
        Data object with only short separation channels, or None if unavailable.

    Returns
    -------
    tuple[float, NDArray[float64], NDArray[float64]]
        - float: Sampling frequency of the signal.
        - NDArray[float64]: Signal of the first channel with SECONDS_TO_STRIP_HR seconds removed from each end.
        - NDArray[float64]: Corresponding time values.
    """

    # Find the short channel (or best candidate) and extract signal data and sampling frequency
    logger.info("Extracting the signal and calculating the sampling frequency...")

    # If a short channel exists, use it for our signal. Otherwise just take the first channel in the data
    # TODO: Find a better way around this
    source = short_chans if short_chans is not None else data
    signal = cast(NDArray[float64], source.get_data(picks=[0], verbose=VERBOSITY))[0]  # type: ignore

    # Read the sampling frequency
    sfreq = cast(float, data.info['sfreq'])

    # Trim start and end of the signal to remove edge artifacts
    logger.info(f"Removing {SECONDS_TO_STRIP_HR} seconds from the beginning and end of the file...")
    strip_samples = int(sfreq * SECONDS_TO_STRIP_HR)
    times = cast(NDArray[float64], getattr(data, "times"))

    # The end index is computed explicitly: a slice ending at -0 would be empty when nothing is stripped
    signal_trimmed = signal[strip_samples:len(signal) - strip_samples]
    times_trimmed = times[strip_samples:len(times) - strip_samples]

    return sfreq, signal_trimmed, times_trimmed
|
|
|
|
|
|
|
|
def calculate_heart_rate_neurokit(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[float64], float]:
    """
    Derive a smoothed heart rate trace and its mean from a trimmed signal with NeuroKit.

    The signal is band-pass filtered (0.8 - 2.5), peaks are detected, and the peak
    rate is interpolated to the length of the input. The rate is clipped to
    [MAX_LOW_HR, MAX_HIGH_HR]; samples that exceed the local rolling median by
    more than 10 are replaced with the local rolling mean.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of the signal.
    signal_trimmed : NDArray[float64]
        Preprocessed and trimmed fNIRS signal.

    Returns
    -------
    tuple[NDArray[float64], float]
        - NDArray[float64]: Smoothed heart rate time series (BPM).
        - float: Mean heart rate.
    """

    logger.info("Calculating heart rate using NeuroKit...")

    # Isolate the heart rate band, then locate the peaks of the filtered trace
    logger.info("Filtering the signal and detecting peaks...")
    band_limited = cast(NDArray[float64], nk.signal_filter(signal_trimmed, sampling_rate=sfreq, lowcut=0.8, highcut=2.5))  # type: ignore
    peak_info = cast(dict[str, Any], nk.signal_findpeaks(band_limited))  # type: ignore
    rate = cast(NDArray[float64], nk.signal_rate(peak_info['Peaks'], sampling_rate=sfreq, desired_length=len(signal_trimmed)))  # type: ignore
    hr_clean = np.clip(rate, MAX_LOW_HR, MAX_HIGH_HR)

    # Replace upward spikes with the centred rolling mean, then average the result
    logger.info("Smoothing the signal and calculating the mean...")
    hr_series = pd.Series(hr_clean)
    rolling = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1)
    is_spike = hr_series > (rolling.median() + 10)
    despiked = hr_series.mask(is_spike, rolling.mean())
    hr_smooth_nk = cast(NDArray[float64], despiked.to_numpy())  # type: ignore
    mean_hr_nk = hr_smooth_nk.mean()

    logger.info("Original HR min/max: %f, %f", hr_clean.min(), hr_clean.max())
    logger.info("Smoothed HR min/max:%f, %f", hr_smooth_nk.min(), hr_smooth_nk.max())
    logger.info(f"Estimated mean HR nk: {mean_hr_nk:.1f} BPM")

    logger.info("Successfully calculated heart rate using NeuroKit.")

    return hr_smooth_nk, mean_hr_nk
|
|
|
|
|
|
|
|
def calculate_heart_rate_scipy(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]:
    """
    Estimate heart rate from the Welch PSD peak of a high-pass filtered signal.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of the input signal.
    signal_trimmed : NDArray[float64]
        Trimmed fNIRS signal to analyze.

    Returns
    -------
    tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]
        - NDArray[floating[Any]]: Frequencies converted to beats per minute (BPM).
        - NDArray[float64]: Power spectral density (PSD) of the signal.
        - np.ndarray[Any, np.dtype[np.bool_]]: Boolean mask indicating frequencies within heart rate range (30-300 BPM).
        - float: Estimated mean heart rate in BPM corresponding to the PSD peak within the range.
    """

    logger.info("Calculating heart rate using SciPy...")

    # Second-order high-pass Butterworth at 0.5 Hz (expressed relative to Nyquist) to drop slow trends
    logger.info("Applying a butterworth filter...")
    nyquist = sfreq / 2
    coeffs = cast(tuple[NDArray[float64], NDArray[float64]], butter(2, 0.5 / nyquist, btype='high'))
    numerator, denominator = coeffs
    detrended = cast(NDArray[float64], filtfilt(numerator, denominator, signal_trimmed))

    # Welch PSD with segments of at most 4096 samples and 50% overlap
    logger.info("Calculating the PSD...")
    segment_length = min(len(detrended), 4096)
    spectrum = cast(
        tuple[NDArray[float64], NDArray[float64]],
        welch(detrended, fs=sfreq, nperseg=segment_length, noverlap=segment_length // 2),
    )
    frequencies_hz, psd_scipy = spectrum

    # Express the frequency axis in BPM and keep the 30-300 BPM band (exclusive bounds)
    logger.info("Converting to BPM...")
    freq_bpm_scipy = frequencies_hz * 60
    freq_range_scipy = np.logical_and(freq_bpm_scipy > 30, freq_bpm_scipy < 300)

    # The BPM value at the strongest PSD bin inside the band is the heart rate estimate
    logger.info("Finding the mean...")
    bpm_in_band = freq_bpm_scipy[freq_range_scipy]
    psd_in_band = psd_scipy[freq_range_scipy]
    mean_hr_scipy = bpm_in_band[np.argmax(psd_in_band)]

    logger.info("Successfully calculated heart rate using SciPy.")

    return freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy
|
|
|
|
|
|
def plot_heart_rate(
    freq_bpm_scipy: NDArray[floating[Any]],
    psd_scipy: NDArray[float64],
    freq_range_scipy: np.ndarray[Any, np.dtype[np.bool_]],
    mean_hr_scipy: float,
    hr_smooth_nk: NDArray[floating[Any]],
    mean_hr_nk: float,
    times_trimmed: NDArray[floating[Any]],
    overruled: bool
) -> tuple[Figure, Figure]:
    """
    Build the two heart rate figures: the SciPy PSD view and the NeuroKit2 time series view.

    Both figures are closed with ``plt.close`` before being returned, so they are
    not shown by pyplot but remain usable as Figure objects.

    Parameters
    ----------
    freq_bpm_scipy : NDArray[floating[Any]]
        Frequencies in beats per minute from SciPy PSD analysis.
    psd_scipy : NDArray[float64]
        Power spectral density values corresponding to freq_bpm_scipy.
    freq_range_scipy : np.ndarray[Any, np.dtype[np.bool_]]
        Boolean mask indicating the heart rate frequency range used in PSD.
    mean_hr_scipy : float
        Mean heart rate estimated from SciPy PSD peak.
    hr_smooth_nk : NDArray[floating[Any]]
        Smoothed instantaneous heart rate from NeuroKit2.
    mean_hr_nk : float
        Mean heart rate estimated from NeuroKit2 data.
    times_trimmed : NDArray[floating[Any]]
        Time points corresponding to hr_smooth_nk values.
    overruled: bool
        True if the heart rate from NeuroKit2 is overriding the results from the PSD.

    Returns
    -------
    tuple[Figure, Figure]
        - Figure showing the PSD and SciPy heart rate estimate.
        - Figure showing the time series comparison of heart rates.
    """

    # First figure: PSD restricted to the heart rate band, with a yellow span covering
    # both estimates widened by HEART_RATE_WINDOW on each side.
    logger.info("Creating the figure...")
    fig1, ax1 = plt.subplots(figsize=(10, 5))  # type: ignore
    ax1.set_xlim(30, 300)
    ax1.plot(freq_bpm_scipy[freq_range_scipy], psd_scipy[freq_range_scipy])  # type: ignore
    ax1.axvline(x=mean_hr_scipy, color='red', linestyle='--', label=f'Mean HR: {mean_hr_scipy:.1f} BPM')  # type: ignore
    span_low = min(mean_hr_nk, mean_hr_scipy) - HEART_RATE_WINDOW
    span_high = max(mean_hr_nk, mean_hr_scipy) + HEART_RATE_WINDOW
    ax1.axvspan(span_low, span_high, color='yellow', alpha=0.3, label=f'HR Range ±{HEART_RATE_WINDOW} BPM')  # type: ignore
    ax1.set_xlabel('Heart Rate (BPM)')  # type: ignore
    ax1.set_ylabel('Power Spectral Density')  # type: ignore
    ax1.set_title('PSD of fNIRS signal - Peak indicates Heart Rate')  # type: ignore
    ax1.grid(True)  # type: ignore

    # When the PSD estimate was overruled, append an invisible legend entry carrying a note
    if not overruled:
        ax1.legend()  # type: ignore
    else:
        warning_text = (
            '\n'
            'Note: Calculation was bad!\n'
            'Data has been set to match\n'
            'the value from NeuroKit2.'
        )
        note_entry = Line2D([0], [0], color='none', label=warning_text)
        existing_handles = ax1.get_legend_handles_labels()[0]
        ax1.legend(handles=[*existing_handles, note_entry])  # type: ignore
    plt.close(fig1)

    # Second figure: instantaneous heart rate with both mean estimates as horizontal lines
    logger.info("Creating the figure...")
    fig2, ax2 = plt.subplots(figsize=(14, 6))  # type: ignore
    ax2.plot(times_trimmed, hr_smooth_nk, label='Instantaneous HR (NeuroKit2)', color='blue', alpha=0.7)  # type: ignore
    ax2.axhline(mean_hr_nk, color='red', linestyle='--', label=f'Mean HR NeuroKit2: {mean_hr_nk:.1f} BPM')  # type: ignore
    ax2.axhline(mean_hr_scipy, color='orange', linestyle=':', label=f'SciPy Welch PSD (HP filtered): {mean_hr_scipy:.1f} BPM')  # type: ignore
    ax2.set_xlabel('Time (seconds)')  # type: ignore
    ax2.set_ylabel('Heart Rate (BPM)')  # type: ignore
    ax2.set_title('Heart Rate Estimates Comparison')  # type: ignore
    ax2.legend()  # type: ignore
    ax2.grid(True)  # type: ignore
    fig2.tight_layout()
    plt.close(fig2)

    return fig1, fig2
|
|
|
|
|
|
|
|
def _draw_quality_heatmap(data: BaseRaw, frame: DataFrame, cmap: Any, n_windows: int, vsize: float, suptitle: str, axes_title: str, ch_names: list[str]) -> tuple[Figure, Axes]:
    """Draw one channel-by-time heatmap on a new figure and return the figure with its axes."""

    fig, ax = plt.subplots(figsize=(10, vsize), layout="constrained")  # type: ignore
    fig.suptitle(suptitle, fontsize=16, fontweight="bold")  # type: ignore

    # Plot heatmap of scores on a fixed 0..1 color scale
    sns.heatmap(  # type: ignore
        data=frame,
        cmap=cmap,
        vmin=0,
        vmax=1,
        cbar_kws=dict(label="Score"),
        ax=ax,
    )

    # Add vertical dashed lines at each time boundary, set the title, and place a black strikethrough through a bad channel
    for x in range(1, n_windows):
        ax.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray")  # type: ignore
    ax.set_title(axes_title, fontweight="bold")  # type: ignore
    markbad(data, ax, ch_names)

    return fig, ax


def plot_timechannel_quality_metrics(data: BaseRaw, scores: NDArray[float64], times: list[tuple[float]], color_stops: tuple[list[float], list[float]], threshold: float, title: Optional[str] = None) -> tuple[Figure, Figure]:
    """
    Generate two heatmaps visualizing channel quality metrics over time.

    Only the first half of the channels (and the matching score rows) is plotted.
    Both figures are closed with ``plt.close`` before being returned.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    scores : NDArray[float64]
        A 2D array of quality scores for each channel over time.
    times : list[tuple[float]]
        List of time boundaries used to label each score column.
    color_stops : tuple[list[float], list[float]]
        Two lists of color values for custom colormaps.
    threshold : float,
        Threshold value for the color bar.
    title : Optional[str], optional
        Base title for the figures, (default is None). When None, the figure
        titles consist of the suffix alone.

    Returns
    -------
    tuple[Figure, Figure]
        - Figure: Heatmap of all scores across channels and time.
        - Figure: Binary heatmap showing only scores above the threshold.
    """

    # Get only the hbo / hbr channels once as we dont need to see the same results twice
    half_ch = len(getattr(data, "ch_names")) // 2
    ch_names = getattr(data, "ch_names")[:half_ch]
    scores = scores[:half_ch, :]

    # Extract rounded time points to use as column headers
    cols = [np.round(t[0]) for t in times]
    n_chans = len(ch_names)
    vsize = 0.2 * n_chans

    # title defaults to None, so build the prefix defensively instead of concatenating None with a str
    prefix = "" if title is None else title + " - "

    # Create a DataFrame to structure data for the first heatmap
    all_scores = DataFrame(
        data=scores,
        columns=pd.Index(cols, name="Time (s)"),
        index=pd.Index(ch_names, name="Channel"),
    )

    # Define a custom colormap using provided color stops and base colors
    base_colors = ['red', 'red', 'yellow', 'green', 'green']
    colors = list(zip(color_stops[0], base_colors[:len(color_stops[0])]))
    cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors)

    # Create the first figure
    fig1, ax1 = _draw_quality_heatmap(data, all_scores, cmap, len(times), vsize, prefix + "All Scores", "All Scores", ch_names)

    # Calculate average score per channel and annotate to the right of the heatmap
    avg_sci_subset: pd.Series[float] = all_scores.mean(axis=1)  # type: ignore
    norm = mcolors.Normalize(vmin=0, vmax=1)
    text_x = all_scores.shape[1] + 0.5
    for i, val in enumerate(avg_sci_subset):
        ax1.text(  # type: ignore
            text_x,
            i + 0.5,
            f"{val:.3f}",
            va='center',
            ha='left',
            fontsize=9,
            color=cmap(norm(val))
        )
    # Widen the x axis so the annotations are not clipped
    ax1.set_xlim(right=text_x + 1.5)

    plt.close(fig1)

    # Create a DataFrame of booleans (score above threshold) for the second heatmap
    above_threshold = DataFrame(
        data=scores > threshold,
        columns=pd.Index(cols, name="Time (s)"),
        index=pd.Index(ch_names, name="Channel"),
    )

    # Define a custom colormap using provided color stops and base colors
    base_colors = ['red', 'red', 'white', 'white']
    colors = list(zip(color_stops[1], base_colors[:len(color_stops[1])]))
    cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors)

    # Create the second figure
    fig2, _ = _draw_quality_heatmap(data, above_threshold, cmap, len(times), vsize, prefix + "Scores Above Threshold", "Scores > Threshold", ch_names)

    plt.close(fig2)

    return fig1, fig2
|
|
|
|
|
|
|
|
def markbad(data: BaseRaw, ax: Axes, ch_names: list[str]) -> None:
    """
    Strike through the rows of a plot that belong to channels marked as bad.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process. Its ``info["bads"]`` entry is read.
    ax : Axes
        Matplotlib Axes object where the strikethrough lines will be drawn.
    ch_names : list[str]
        List of channel names corresponding to the y-axis of the plot.
    """

    flagged = data.info["bads"]

    for row, name in enumerate(ch_names):
        # Skip channels that are not marked as bad
        if name not in flagged:
            continue

        # Draw a thick black line through the middle of the channel's row
        ax.axhline(row + 0.5, ls="solid", lw=4, color="black", zorder=10)  # type: ignore
|
|
|
|
|
|
|
|
def calculate_scalp_coupling(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]:
    """
    Compute the scalp coupling index (SCI) and flag channels whose mean SCI is below SCI_THRESHOLD.

    As a side effect, ``data.info["bads"]`` is overwritten with the flagged channels.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    l_freq : float, optional
        Low cutoff frequency for bandpass filtering in Hz (default is 0.7).
    h_freq : float, optional
        High cutoff frequency for bandpass filtering in Hz (default is 1.5)

    Returns
    -------
    tuple[list[str], Figure, Figure]
        - list[str]: Channel names identified as bad based on SCI threshold.
        - Figure: Heatmap of all SCI scores across time and channels.
        - Figure: Binary heatmap of SCI scores exceeding the threshold.
    """

    logger.info("Calculating scalp coupling index...")

    # Windowed SCI: one score per channel and time window
    windowed = scalp_coupling_index_windowed_raw(data, time_window=SCI_TIME_WINDOW, l_freq=l_freq, h_freq=h_freq)
    _, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], windowed)

    # A channel fails when its SCI averaged over all windows is under the threshold
    logger.info("Identifying channels that do not meet the threshold...")
    below_threshold = scores.mean(axis=1) < SCI_THRESHOLD
    channel_names = cast(list[str], getattr(data, "ch_names"))
    data.info["bads"] = list(compress(channel_names, below_threshold))

    # Color stops are anchored on the threshold; the plots read data.info["bads"] set above
    logger.info("Creating the figures...")
    gradient_stops = [0.0, SCI_THRESHOLD, SCI_THRESHOLD+0.1, 0.8, 1.0]
    binary_stops = [0.0, SCI_THRESHOLD, SCI_THRESHOLD, 1.0]
    fig1, fig2 = plot_timechannel_quality_metrics(data, scores, times, (gradient_stops, binary_stops), SCI_THRESHOLD, "Scalp Coupling Index")

    logger.info("Successfully calculated scalp coupling index.")

    return list(compress(channel_names, below_threshold)), fig1, fig2
|
|
|
|
|
|
|
|
def scalp_coupling_index_windowed_raw(data: BaseRaw, time_window: float = 3.0, l_freq: float = 0.7, h_freq: float = 1.5, l_trans_bandwidth: float = 0.3, h_trans_bandwidth: float = 0.3) -> tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]]:
    """
    Compute the scalp coupling index (SCI) of fNIRS channel pairs in consecutive time windows.

    Channels are sorted by name and processed two at a time; the correlation of the
    band-pass filtered pair within each window is stored for both channels of the pair.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    time_window : float, optional
        Length of each time window in seconds (default is 3.0).
    l_freq : float, optional
        Low cutoff frequency for filtering in Hz (default is 0.7).
        Values below 0.3 are raised to 0.3.
    h_freq : float, optional
        High cutoff frequency for filtering in Hz (default is 1.5).
    l_trans_bandwidth : float, optional
        Transition bandwidth for the low cutoff in Hz (default is 0.3).
    h_trans_bandwidth : float, optional
        Transition bandwidth for the high cutoff in Hz (default is 0.3).

    Returns
    -------
    tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]]
        - BaseRaw: The original data object (unchanged). Ensures compatibility with peak_power().
        - NDArray[float64]: Correlation scores for each channel and time window.
        - list[tuple[float, float]]: Time intervals for each window in seconds.
    """

    channel_names = getattr(data, "ch_names")
    sfreq = getattr(data, "info")["sfreq"]

    # Pick only fNIRS channels and sort them by channel name
    picks: NDArray[np.intp] = mne.pick_types(cast(mne.Info, data.info), fnirs=True)  # type: ignore
    name_order = np.argsort([channel_names[pick] for pick in picks])
    picks = picks[name_order]

    # FIXME: This may happen if the heart rate calculation tries to set a value way too low
    l_freq = max(l_freq, 0.3)

    # Band-pass filter the selected fNIRS channels
    filtered_data = cast(NDArray[float64], filter_data(
        getattr(data, "_data"),
        sfreq,
        l_freq,
        h_freq,
        picks=picks,
        verbose=False,
        l_trans_bandwidth=l_trans_bandwidth,  # type: ignore
        h_trans_bandwidth=h_trans_bandwidth,  # type: ignore
    ))

    # Samples per window, number of whole windows, and the output containers
    n_samples = len(data)
    window_samples = int(np.ceil(time_window * sfreq))
    n_windows = int(np.floor(n_samples / window_samples))
    scores = np.zeros((len(picks), n_windows))
    times: list[tuple[float, float]] = []
    sample_times = getattr(data, "times")

    for window in range(n_windows):
        # Window bounds; the end is clamped to the last valid sample index
        start_sample = int(window * window_samples)
        end_sample = min(start_sample + window_samples, n_samples - 1)

        # Track time boundaries for each window
        times.append((sample_times[start_sample], sample_times[end_sample]))

        # Iterate through channels in pairs (hbo, hbr). This requires them to be sorted by channel name
        for ii in range(0, len(picks), 2):
            first: NDArray[float64] = filtered_data[picks[ii]][start_sample:end_sample]
            second = filtered_data[picks[ii + 1]][start_sample:end_sample]

            # A flat or NaN-containing segment has no usable correlation, so score it as 0
            unusable = np.std(first) == 0 or np.std(second) == 0 or np.any(np.isnan(first)) or np.any(np.isnan(second))
            correlation = 0 if unusable else np.corrcoef(first, second)[0][1]

            # Assign the computed correlation to both channels in the pair
            scores[ii:ii + 2, window] = correlation

    # Undo the name sort so the rows follow ascending channel index again
    scores = scores[np.argsort(picks)]

    return data, scores, times
|
|
|
|
|
|
|
|
def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]:
    """
    Compute the peak spectral power (PSP) and flag channels whose mean PSP is below PSP_THRESHOLD.

    As a side effect, ``data.info["bads"]`` is overwritten with the flagged channels.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    l_freq : float, optional
        Low cutoff frequency for filtering in Hz (default is 0.7)
    h_freq : float, optional
        High cutoff frequency for filtering in Hz (default is 1.5)

    Returns
    -------
    tuple[list[str], Figure, Figure]
        - list[str]: Names of channels below the PSP threshold.
        - Figure: Heatmap of all PSP scores.
        - Figure: Heatmap of scores above the PSP threshold.
    """

    logger.info("Calculating peak spectral power...")

    # Windowed PSP: one score per channel and time window
    windowed = peak_power(data, time_window=PSP_TIME_WINDOW, threshold=PSP_THRESHOLD, l_freq=l_freq, h_freq=h_freq, verbose=False)
    _, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], windowed)

    # A channel fails when its PSP averaged over all windows is under the threshold
    logger.info("Identifying channels that do not meet the threshold...")
    below_threshold = scores.mean(axis=1) < PSP_THRESHOLD
    channel_names = cast(list[str], getattr(data, "ch_names"))
    data.info["bads"] = list(compress(channel_names, below_threshold))

    # Color stops are anchored on the threshold; the plots read data.info["bads"] set above
    logger.info("Creating the figures...")
    gradient_stops = [0.0, PSP_THRESHOLD, PSP_THRESHOLD+0.1, 0.3, 1.0]
    binary_stops = [0.0, PSP_THRESHOLD, PSP_THRESHOLD, 1.0]
    psp1, psp2 = plot_timechannel_quality_metrics(data, scores, times, (gradient_stops, binary_stops), PSP_THRESHOLD, "Peak Spectral Power")

    logger.info("Successfully calculated peak spectral power.")

    return list(compress(channel_names, below_threshold)), psp1, psp2
|
|
|
|
|
|
|
|
def calculate_signal_noise_ratio(data: BaseRaw) -> tuple[list[str], Figure]:
    """
    Calculates the signal-to-noise ratio (SNR) for each channel and identifies those below a defined threshold.

    The SNR is the ratio, in dB, of the mean squared amplitude of the data filtered to the
    signal band (0.01 - 0.5) over that of the data filtered to the noise band (1.0 - 10.0).
    Channels sharing the same name prefix (everything before the last space) are treated
    as one group: if any member is below SNR_THRESHOLD, the whole group is reported.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process. It is copied before filtering.

    Returns
    -------
    tuple[list[str], Figure]
        - list[str]: A list of channel names that fall below the SNR threshold and are considered bad,
          in the order they appear in ``data.ch_names``.
        - Figure: A matplotlib Figure showing the channels' SNR values.
    """

    logger.info("Calculating signal to noise ratio...")

    # Compute the signal-to-noise ratio values
    logger.info("Computing the signal to noise power...")
    signal_band = (0.01, 0.5)
    noise_band = (1.0, 10.0)
    data_signal = data.copy().filter(*signal_band, verbose=False)  # type: ignore
    data_noise = data.copy().filter(*noise_band, verbose=False)  # type: ignore
    signal_power = np.mean(data_signal.get_data()**2, axis=1)  # type: ignore
    noise_power = np.mean(data_noise.get_data()**2, axis=1)  # type: ignore

    # Calculate the snr using the standard formula for dB (eps guards against division by zero)
    snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps))

    ch_names = cast(list[str], getattr(data, "ch_names"))

    # Channel names end with a space-separated suffix; stripping it leaves a prefix that is
    # shared by the channels of one source-detector pair, so both are removed together.
    # First collect the prefixes that have at least one channel below the threshold...
    failing_groups = {
        ch.rsplit(' ', 1)[0]
        for value, ch in zip(snr, ch_names)
        if value < SNR_THRESHOLD
    }
    # ...then report every channel that belongs to one of those prefixes, in channel order
    bad_channels = [ch for ch in ch_names if ch.rsplit(' ', 1)[0] in failing_groups]

    # Design and create the figure
    logger.info("Creating the figure...")
    snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained")  # type: ignore
    colors = [(0/20, 'red'), (SNR_THRESHOLD/20, 'red'), ((SNR_THRESHOLD+.5)/20, 'yellow'), ((SNR_THRESHOLD+1)/20, 'green'), (20/20, 'green')]
    cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors)
    norm = mcolors.Normalize(vmin=0, vmax=20)
    scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm)  # type: ignore
    ax.set(xlabel="Channel Number", ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, 20])
    ax.axhline(SNR_THRESHOLD, color='black', linestyle='--', alpha=0.3, linewidth=1)  # type: ignore
    cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)")  # type: ignore
    cbar.set_ticks([0, SNR_THRESHOLD, SNR_THRESHOLD+1, 20])  # type: ignore
    cbar.set_ticklabels(['0', str(SNR_THRESHOLD), str(SNR_THRESHOLD+1), '20'])  # type: ignore

    # Close this specific figure rather than whichever figure happens to be current
    plt.close(snr_fig)

    logger.info("Successfully calculated signal to noise ratio.")

    return bad_channels, snr_fig
|
|
|
|
|
|
|
|
def mark_bad_channels(data: BaseRaw, ID: str, bad_channels_sci: set[str], bad_channels_psp: set[str], bad_channels_snr: set[str]) -> tuple[BaseRaw, Figure, int]:
    """
    Drops bad channels from the data and generates a bar plot showing which channels were removed and why.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process. Channels are dropped from this object.
    ID : str
        File name of the the snirf file that was loaded.
    bad_channels_sci : set[str]
        Channels marked as bad by the SCI method. Any iterable of names is accepted.
    bad_channels_psp : set[str]
        Channels marked as bad by the PSP method. Any iterable of names is accepted.
    bad_channels_snr : set[str]
        Channels marked as bad by the SNR method. Any iterable of names is accepted.

    Returns
    -------
    tuple[BaseRaw, Figure, int]
        - BaseRaw: The modified data object with bad channels removed.
        - Figure: A matplotlib Figure showing the dropped channels categorized by method.
        - int: The number of distinct channels that were dropped.
    """

    logger.info("Dropping the channels that were marked bad...")

    # The SCI / PSP / SNR calculators return lists, so normalise to sets before taking the union
    bad_channels_sci = set(bad_channels_sci)
    bad_channels_psp = set(bad_channels_psp)
    bad_channels_snr = set(bad_channels_snr)

    # Combine all of the bad channels into one set
    bad_channels = bad_channels_sci | bad_channels_psp | bad_channels_snr

    logger.info(f"Channels that were bad on SCI: {bad_channels_sci}")
    logger.info(f"Channels that were bad on PSP: {bad_channels_psp}")
    logger.info(f"Channels that were bad on SNR: {bad_channels_snr}")
    logger.info(f"Total bad channels: {bad_channels}")

    # Add the channels to the bads key and drop the bads key from the data
    data.info["bads"] = list(bad_channels)
    data = cast(BaseRaw, data.drop_channels(getattr(data, "info")["bads"]))  # type: ignore

    # Organize channels into categories
    sets = [
        (bad_channels_sci, "SCI"),
        (bad_channels_psp, "PSP"),
        (bad_channels_snr, "SNR"),
    ]

    # Label every dropped channel with the method(s) that flagged it
    channel_categories: dict[str, str] = {}
    for ch in bad_channels:
        present_in = [name for s, name in sets if ch in s]
        if len(present_in) == 1:
            label = f"{present_in[0]} only"
        else:
            label = " + ".join(sorted(present_in))
        channel_categories[ch] = label

    # Sort channels by category, then alphabetically within categories, for nicer visualization
    logger.info("Sorting the bad channels by type...")
    categories = sorted(set(channel_categories.values()))
    ordered = sorted(channel_categories.items(), key=lambda item: (item[1], item[0]))
    channel_names: list[str] = [ch for ch, _ in ordered]
    category_labels: list[str] = [cat for _, cat in ordered]

    # Every category label must have an entry in FIXED_CATEGORY_COLORS
    colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories}

    logger.info("Creating the figure...")
    # Create the figure: one full-width bar per dropped channel, colored by category
    fig, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3)))  # type: ignore
    y_pos = range(len(channel_names))
    ax.barh(y_pos, [1]*len(channel_names), color=[colors[cat] for cat in category_labels])  # type: ignore
    ax.set_yticks(y_pos)  # type: ignore
    ax.set_yticklabels(channel_names)  # type: ignore
    ax.set_xlabel("Marked as Bad")  # type: ignore
    ax.set_title(f"Bad Channels by Method for {ID}")  # type: ignore
    ax.set_xlim(0, 1)
    ax.set_xticks([])  # type: ignore
    ax.grid(axis='x', linestyle='--', alpha=0.7)  # type: ignore

    # Add a legend denoting why the channels were bad (zero-sized bars act as legend handles)
    for label, color in colors.items():
        ax.bar(0, 0, color=color, label=label)  # type: ignore
    ax.legend()  # type: ignore

    fig.tight_layout()
    plt.close(fig)

    logger.info("Successfully dropped the channels that were marked bad.")

    return data, fig, len(bad_channels)
|
|
|
|
|
|
|
|
def calculate_optical_density(data: BaseRaw, ID: str) -> tuple[BaseRaw, Figure]:
    """
    Convert raw intensity data to optical density and plot the converted signals.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    ID : str
        File name of the the snirf file that was loaded.

    Returns
    -------
    tuple[BaseRaw, Figure]
        - BaseRaw: The transformed data in optical density format.
        - Figure: A matplotlib figure displaying the optical density signals across all channels.
    """

    logger.info("Calculating optical density...")

    # Convert the raw data to optical density
    optical_density_data = cast(BaseRaw, optical_density(data))

    # Plot every channel over the full duration of the recording in a single view
    logger.info("Creating the figure...")
    channel_count = len(getattr(data, "ch_names"))
    full_duration = getattr(data, "times")[-1]
    browser = optical_density_data.plot(show=False, n_channels=channel_count, duration=full_duration)  # type: ignore
    fig = cast(Figure, browser.figure)  # type: ignore
    fig.suptitle(f"Optical density data for {ID}", fontsize=16)  # type: ignore
    fig.subplots_adjust(top=0.92)
    plt.close(fig)

    logger.info("Successfully calculated optical density.")

    return optical_density_data, fig
|
|
|
|
|
|
|
|
# STEP 9: Haemoglobin concentration
|
|
def calculate_haemoglobin_concentration(optical_density_data: BaseRaw, ID: str, file_path: str) -> tuple[BaseRaw, Figure]:
    """
    Calculates haemoglobin concentration from optical density data using the Beer-Lambert law and generates a plot.

    Parameters
    ----------
    optical_density_data : BaseRaw
        The data in optical density format.
    ID : str
        File name of the the snirf file that was loaded.
    file_path : str
        Entire file path if snirf file that was loaded. Passed to calculate_dpf()
        to obtain the ``ppf`` argument of the Beer-Lambert conversion.

    Returns
    -------
    tuple[BaseRaw, Figure]
        - BaseRaw: The haemoglobin concentration data object.
        - Figure: A matplotlib figure displaying the haemoglobin concentration signals.
    """

    logger.info("Calculating haemoglobin concentration data...")

    # Get the haemoglobin concentration using beer lambert law
    haemoglobin_concentration_data = beer_lambert_law(optical_density_data, ppf=calculate_dpf(file_path))

    # Plot the converted haemoglobin data (not the optical density input) so the figure
    # matches its title; all channels over the full duration are shown in a single view.
    logger.info("Creating the figure...")
    fig = cast(Figure, haemoglobin_concentration_data.plot(show=False, n_channels=len(getattr(haemoglobin_concentration_data, "ch_names")), duration=getattr(haemoglobin_concentration_data, "times")[-1]).figure)  # type: ignore
    fig.suptitle(f"Haemoglobin concentration data for {ID}", fontsize=16)  # type: ignore
    fig.subplots_adjust(top=0.92)
    plt.close(fig)

    logger.info("Successfully calculated haemoglobin concentration data.")

    return haemoglobin_concentration_data, fig
|
|
|
|
|
|
|
|
# -------------------------------------- HARDCODED -----------------------------------------------
|
|
|
|
def extract_normal_epochs(haemoglobin_concentration_data: BaseRaw) -> dict[str, list[Any] | mne.evoked.EvokedArray]:
    """
    Epoch the haemoglobin data around the annotated events and average the "Reach" epochs.

    Events are read from the "Reach" and "Start of Rest" annotations. Epochs span
    TIME_MIN_THRESH to TIME_MAX_THRESH, are baseline corrected over (None, 0), and are
    rejected using EPOCH_REJECT_CRITERIA_THRESH on hbo. The last four characters are
    removed from every channel name of the averaged results.

    Parameters
    ----------
    haemoglobin_concentration_data : BaseRaw
        The haemoglobin concentration data object.

    Returns
    -------
    dict[str, list[Any] | mne.evoked.EvokedArray]
        Averaged "Reach" responses keyed by "Reach/HbO" and "Reach/HbR".
    """

    event_dict = {"Reach": 1, "Start of Rest": 2}
    events, _ = mne.events_from_annotations(haemoglobin_concentration_data, event_id=event_dict, verbose=VERBOSITY)  # type: ignore

    epoch_settings: dict[str, Any] = dict(
        event_id=event_dict,
        tmin=TIME_MIN_THRESH,
        tmax=TIME_MAX_THRESH,
        reject=dict(hbo=EPOCH_REJECT_CRITERIA_THRESH),
        reject_by_annotation=True,
        proj=True,
        baseline=(None, 0),
        preload=True,
        detrend=None,
        verbose=VERBOSITY,
    )
    epochs = mne.Epochs(haemoglobin_concentration_data, events, **epoch_settings)

    # Average the "Reach" epochs separately for each chromophore
    reach_epochs = epochs["Reach"]
    evoked_dict: dict[str, list[Any] | mne.evoked.EvokedArray] = {
        "Reach/HbO": reach_epochs.average(picks="hbo"),  # type: ignore
        "Reach/HbR": reach_epochs.average(picks="hbr"),  # type: ignore
    }

    # Rename channels until the encoding of frequency in ch_name is fixed
    for evoked in evoked_dict.values():
        evoked.rename_channels(lambda name: name[:-4])  # type: ignore

    return evoked_dict
|
|
|
|
|
|
|
|
def calculate_and_apply_negative_correlation_enhancement(haemoglobin_concentration_data: BaseRaw) -> dict[str, list[Any] | mne.evoked.EvokedArray]:
    """
    Applies negative correlation enhancement, then epochs and averages the "Reach" response.

    The events are read from the annotations of the data that was passed in; the epochs
    are cut from the output of `enhance_negative_correlation`.

    Parameters
    ----------
    haemoglobin_concentration_data : BaseRaw
        Haemoglobin concentration data annotated with "Reach" and "Start of Rest".

    Returns
    -------
    dict[str, list[Any] | mne.evoked.EvokedArray]
        The averaged "Reach" response of the enhanced data, keyed by "Reach/HbO" and "Reach/HbR".
    """

    # HARDCODED: annotation descriptions and the integer event codes they map to
    annotation_codes = {"Reach": 1, "Start of Rest": 2}
    events, _ = mne.events_from_annotations(haemoglobin_concentration_data, event_id=dict(annotation_codes), verbose=VERBOSITY) # type: ignore

    raw_anti = enhance_negative_correlation(haemoglobin_concentration_data)

    epoch_settings = dict(
        event_id=dict(annotation_codes),
        tmin=TIME_MIN_THRESH,
        tmax=TIME_MAX_THRESH,
        reject=dict(hbo=EPOCH_REJECT_CRITERIA_THRESH),
        reject_by_annotation=True,
        proj=True,
        baseline=(None, 0),
        preload=True,
        detrend=None,
        verbose=VERBOSITY,
    )
    epochs_anti = mne.Epochs(raw_anti, events, **epoch_settings)

    # Average the "Reach" epochs separately for each chromophore
    reach_epochs = epochs_anti["Reach"]
    evoked_dict_anti: dict[str, list[Any] | mne.evoked.EvokedArray] = {
        "Reach/HbO": reach_epochs.average(picks="hbo"), # type: ignore
        "Reach/HbR": reach_epochs.average(picks="hbr"), # type: ignore
    }

    # Rename channels until the encoding of frequency in ch_name is fixed
    # (drops the last four characters of every channel name)
    for averaged in evoked_dict_anti.values():
        averaged.rename_channels(lambda name: name[:-4]) # type: ignore

    return evoked_dict_anti
|
|
|
|
|
|
|
|
def calculate_and_apply_short_channel_correction(optical_density_data: BaseRaw, file_path: str) -> dict[str, list[Any] | mne.evoked.EvokedArray]:
    """
    Applies short channel regression to optical density data, then epochs and averages the "Reach" response.

    The corrected optical density is converted to haemoglobin concentration with
    `beer_lambert_law`, using `calculate_dpf(file_path)` as the `ppf` argument.

    Parameters
    ----------
    optical_density_data : BaseRaw
        Optical density data passed to `short_channel_regression`.
    file_path : str
        Path of the participant's file, forwarded to `calculate_dpf`.

    Returns
    -------
    dict[str, list[Any] | mne.evoked.EvokedArray]
        The averaged "Reach" response of the corrected data, keyed by "Reach/HbO" and "Reach/HbR".
    """

    od_corrected = short_channel_regression(optical_density_data, SHORT_CHANNEL_THRESH)
    haemoglobin_concentration_data = beer_lambert_law(od_corrected, ppf=calculate_dpf(file_path))

    # HARDCODED: annotation descriptions and the integer event codes they map to
    annotation_codes = {"Reach": 1, "Start of Rest": 2}
    events, _ = mne.events_from_annotations(haemoglobin_concentration_data, event_id=dict(annotation_codes), verbose=VERBOSITY) # type: ignore

    epoch_settings = dict(
        event_id=dict(annotation_codes),
        tmin=TIME_MIN_THRESH,
        tmax=TIME_MAX_THRESH,
        reject=dict(hbo=EPOCH_REJECT_CRITERIA_THRESH),
        reject_by_annotation=True,
        proj=True,
        baseline=(None, 0),
        preload=True,
        detrend=None,
        verbose=VERBOSITY,
    )
    epochs_corr = mne.Epochs(haemoglobin_concentration_data, events, **epoch_settings)

    # Average the "Reach" epochs separately for each chromophore
    reach_epochs = epochs_corr["Reach"]
    evoked_dict_corr: dict[str, list[Any] | mne.evoked.EvokedArray] = {
        "Reach/HbO": reach_epochs.average(picks="hbo"), # type: ignore
        "Reach/HbR": reach_epochs.average(picks="hbr"), # type: ignore
    }

    # Rename channels until the encoding of frequency in ch_name is fixed
    # (drops the last four characters of every channel name)
    for averaged in evoked_dict_corr.values():
        averaged.rename_channels(lambda name: name[:-4]) # type: ignore

    return evoked_dict_corr
|
|
|
|
|
|
|
|
def signal_enhancement_techniques_images(evoked_dict: dict[str, list[Any] | mne.evoked.EvokedArray], evoked_dict_anti: dict[str, list[Any] | mne.evoked.EvokedArray], evoked_dict_corr:dict[str, list[Any] | mne.evoked.EvokedArray] | None) -> Figure:
    """
    Plots the averaged responses of each signal enhancement technique side by side.

    Parameters
    ----------
    evoked_dict : dict[str, list[Any] | mne.evoked.EvokedArray]
        Averaged responses of the original data.
    evoked_dict_anti : dict[str, list[Any] | mne.evoked.EvokedArray]
        Averaged responses after negative correlation enhancement.
    evoked_dict_corr : dict[str, list[Any] | mne.evoked.EvokedArray] | None
        Averaged responses after short channel regression, or None when there is
        no short channel. When None, only two panels are drawn.

    Returns
    -------
    Figure
        The figure holding one panel per technique. It is closed before being
        returned so that it is not displayed implicitly.
    """

    # One panel per technique; the short channel panel only exists when it was computed
    panels = [
        ("Original Data", evoked_dict),
        ("With Enhanced Anticorrelation", evoked_dict_anti),
    ]
    if evoked_dict_corr is not None:
        panels.append(("With Short Regression", evoked_dict_corr))

    fig, axes = plt.subplots(nrows=1, ncols=len(panels), figsize=(15, 6)) # type: ignore

    color_dict = dict(HbO="#AA3377", HbR="b")

    logger.info("Creating the figure...")

    # TODO: This is to prevent the warning that we are only plotting one channel. Don't we want all though?
    # The context manager restores whatever log level was active before, even if plotting raises,
    # instead of unconditionally forcing the level to 'INFO' afterwards.
    with mne.use_log_level('WARNING'): # type: ignore
        for axis, (title, evokeds) in zip(axes, panels):
            mne.viz.plot_compare_evokeds( # type: ignore
                evokeds,
                combine="mean",
                ci=0.95, # type: ignore
                axes=axis,
                colors=color_dict,
                ylim=dict(hbo=[-10, 15]),
                show=False,
            )

            # Set the title after plotting so it is the one left on the panel
            axis.set_title(f"{title}")

    plt.close(fig)

    return fig
|
|
|
|
|
|
|
|
def create_design_matrix(data: BaseRaw, stim_duration: float, short_chans: BaseRaw | None) -> tuple[DataFrame, Figure]:
    """
    Builds the first-level design matrix, optionally adds short channel regressors, and plots it.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    stim_duration : float
        Duration of the stimulus/event in seconds. Not used when HRF_MODEL is "fir",
        where a fixed stimulus duration of 0.1 is passed instead.
    short_chans : BaseRaw | None
        Data object containing only short channels for systemic component regression,
        or None if there are no short channels.

    Returns
    -------
    tuple[DataFrame, Figure]
        - DataFrame: The generated design matrix.
        - Figure: A matplotlib figure visualizing the design matrix (already closed).
    """

    # Create the design matrix
    logger.info("Creating the design matrix... (This may take some time)")

    # Settings shared by both HRF variants
    matrix_settings: dict[str, Any] = dict(
        hrf_model=HRF_MODEL,
        drift_model=DRIFT_MODEL,
        high_pass=1/(2*DURATION_BETWEEN_ACTIVITIES),
    )

    if HRF_MODEL == "fir":
        # FIR needs an explicit list of delays: one per sample for sfreq * 15 samples
        sfreq = getattr(data, "info")["sfreq"]
        matrix_settings["stim_dur"] = 0.1
        matrix_settings["fir_delays"] = range(int(sfreq*15))
    else:
        # Using a canonical hrf model
        matrix_settings["stim_dur"] = stim_duration

    design_matrix = make_first_level_design_matrix(data, **matrix_settings)

    # If we have a short channel, and short channel regression was specified, apply it to the design matrix
    if short_chans is not None and SHORT_CHANNEL_REGRESSION:
        logger.info("Applying short channel regression...")
        for chan, _ in enumerate(short_chans.ch_names): # type: ignore
            design_matrix[f"short_{chan}"] = short_chans.get_data(chan).T # type: ignore

    logger.info("Creating the figure...")
    fig, matrix_axes = plt.subplots(figsize=(10, 6), constrained_layout=True) # type: ignore
    plot_design_matrix(design_matrix, axes=matrix_axes)
    plt.close(fig)

    logger.info("Successfully created the design matrix.")

    return design_matrix, fig
|
|
|
|
|
|
|
|
def run_GLM_analysis(data: BaseRaw, design_matrix: DataFrame) -> RegressionResults:
    """
    Fits a General Linear Model (GLM) to the data with the given design matrix.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    design_matrix : DataFrame
        The design matrix specifying the regressors for the GLM.

    Returns
    -------
    RegressionResults
        The object returned by `run_glm` for this data and design matrix.
    """

    logger.info("Running the GLM...")

    # N_JOBS is forwarded as the `n_jobs` argument of run_glm
    fitted_glm = run_glm(data, design_matrix, n_jobs=N_JOBS)

    logger.info("Successfully ran the GLM.")
    return fitted_glm
|
|
|
|
|
|
def calculate_dpf(file_path: str) -> list[float]:
    """
    Calculates an age- and wavelength-dependent differential pathlength factor (DPF) for a file.

    The wavelengths are read from the ``/nirs/probe/wavelengths`` dataset of the file and
    processed in descending order. The age comes from the ``'Age'`` entry of
    ``METADATA[file_path]``; when the file has no metadata entry an age of 25 is assumed.

    Parameters
    ----------
    file_path : str
        Path to the HDF5 (SNIRF) file. Also the key used to look up ``METADATA``.

    Returns
    -------
    list[float]
        One DPF value per wavelength, longest wavelength first (hbo / hbr order).

    Raises
    ------
    KeyError
        If the metadata entry exists but has no ``'Age'`` key.
    ValueError
        If the age cannot be converted to a float.
    """

    with h5py.File(file_path, 'r') as snirf_file:
        wavelengths = snirf_file['/nirs/probe/wavelengths'][:]
    logger.info("Wavelengths (nm): %s", wavelengths)

    # order is hbo / hbr
    wavelengths = sorted(wavelengths, reverse=True)

    metadata = METADATA.get(file_path)
    age = 25 if metadata is None else metadata['Age']
    age = float(age)
    logger.info("Age used for the DPF calculation: %s", age)

    # Coefficients of DPF = alpha + beta * age**gamma + delta * w**3 + epsilon * w**2 + zeta * w
    # NOTE(review): these look like the Scholkmann & Wolf (2013) general DPF equation - confirm.
    alpha = 223.3
    beta = 0.05624
    gamma = 0.8493
    delta = -5.723e-7
    epsilon = 0.001245
    zeta = -0.9025

    dpf = [
        alpha + beta * (age**gamma) + delta * (w**3) + epsilon * (w**2) + zeta * w
        for w in wavelengths
    ]
    logger.info("DPF per wavelength %s: %s", wavelengths, dpf)

    return dpf
|
|
|
|
|
|
def individual_GLM_analysis(file_path: str, ID: str, stim_duration: float = 5.0, progress_callback=None) -> tuple[BaseRaw, BaseRaw, DataFrame, DataFrame, DataFrame, DataFrame, dict[str, bytes], str, bool, bool]:
    """
    Performs individual-level General Linear Model (GLM) analysis on fNIRS data from a SNIRF file.

    Parameters
    ----------
    file_path : str
        Path to the SNIRF file containing the participant's raw data.
    ID : str
        Unique identifier for the participant, used for labeling output.
    stim_duration : float, optional
        Duration of the stimulus in seconds for constructing the design matrix (default is 5.0)
    progress_callback : callable, optional
        When given, it is called with an integer step index (1 to 12) as steps finish.

    Returns
    -------
    tuple[BaseRaw, BaseRaw, DataFrame, DataFrame, DataFrame, DataFrame, dict[str, bytes], str, bool, bool]
        - BaseRaw: Processed fNIRS data
        - BaseRaw: Full layout raw data prior to bad channel rejection
        - DataFrame: Region of interest statistics
        - DataFrame: Channel-level GLM statistics
        - DataFrame: Contrast results
        - DataFrame: Design matrix used for GLM
        - dict[str, bytes]: PNG-encoded figures generated during processing, keyed by step name
        - str: Description of processing steps applied
        - bool: Whether the GLM successfully ran to completion
        - bool: Whether the analysis result is valid based on quality checks

    Raises
    ------
    ProcessingError
        If SHORT_CHANNEL is set but no short channel could be extracted, or if any step
        raises a ProcessingError itself. Every other exception raised inside the pipeline
        is logged and the partial results are returned with both flags set to False.
    """

    # Setting up variables to be used later
    fig_dict: dict[str, Figure] = {}
    bad_channels_sci = []
    bad_channels_psp = []
    bad_channels_snr = []
    mean_hr_nk = 70
    mean_hr_scipy = 70
    num_bad_channels = 0
    valid = True
    short_chans = None
    roi: DataFrame = DataFrame()
    cha: DataFrame = DataFrame()
    con: DataFrame = DataFrame()
    design_matrix = DataFrame()

    # Load the file, get the sources and detectors, update their position, and calculate the short channel and any large distance channels
    # STEP 1
    data, fig = load_snirf(file_path, ID, FORCE_DROP_CHANNELS)
    fig_dict['Raw'] = fig
    order_of_operations = "Loaded Raw File"
    if progress_callback: progress_callback(1)

    # Initalize the participants full layout to be the current data regardless if it will be updated later
    raw_full_layout = data

    logger.info(file_path)
    logger.info(ID)
    logger.info(METADATA.get(file_path))
    # The return value is discarded here; the call only logs the values it computes
    calculate_dpf(file_path)

    try:
        # Did the user want to load new channel positions from an optode file?
        # STEP 2
        # NOTE(review): the progress callbacks of optional steps are assumed to be part of
        # the step's `if` body (reported only when the step runs) - confirm.
        if OPTODE_FILE:
            data = calculate_and_apply_updated_optode_coordinates(data)
            order_of_operations += " + Updated Optode Placements"
            if progress_callback: progress_callback(2)

        # STEP 2.5
        # TODO: remember why i do this
        # I think its because i want a participants whole layout to plot without any bads
        # but i shouldnt need to do od and bll just check the last three numbers
        temp = data.copy()
        temp_od = cast(BaseRaw, optical_density(temp, verbose=VERBOSITY))
        raw_full_layout = beer_lambert_law(temp_od, ppf=calculate_dpf(file_path))

        # If specified, apply TDDR to the data
        # STEP 3
        if TDDR:
            data, fig = calculate_and_apply_tddr(data, ID)
            order_of_operations += " + TDDR Filter"
            fig_dict['TDDR'] = fig
            if progress_callback: progress_callback(3)

        # If specified, apply a wavelet filter to the data
        # STEP 4
        if WAVELET:
            data, fig = calculate_and_apply_wavelet(data, ID)
            order_of_operations += " + Wavelet Filter"
            fig_dict['Wavelet'] = fig
            if progress_callback: progress_callback(4)

        # If specified, attempt to get short channels from the data
        # STEP 4.5
        if SHORT_CHANNEL:
            try:
                short_chans = get_short_channels(data, SHORT_CHANNEL_THRESH)
            except Exception as e:
                # Chain the original exception so its traceback is not lost
                raise ProcessingError("SHORT_CHANNEL was specified, but no short channel was found. Please ensure the data has a short channel and that SHORT_CHANNEL_THRESH is set correctly.") from e

        # Ensure that there is no short or really long channels in the data
        data = get_long_channels(data, SHORT_CHANNEL_THRESH, LONG_CHANNEL_THRESH)

        # STEP 5
        if HEART_RATE:
            sfreq, signal_trimmed, times_trimmed = short_channel_processing_for_hr(data, short_chans)
            hr_smooth_nk, mean_hr_nk = calculate_heart_rate_neurokit(sfreq, signal_trimmed)
            freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy = calculate_heart_rate_scipy(sfreq, signal_trimmed)

            # HACK: This sucks but looking at the graphs I trust neurokit2 more
            overruled = False
            if mean_hr_scipy < mean_hr_nk - 15:
                mean_hr_scipy = mean_hr_nk
                overruled = True
            if mean_hr_scipy > mean_hr_nk + 15:
                mean_hr_scipy = mean_hr_nk
                overruled = True

            hr1, hr2 = plot_heart_rate(freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy, hr_smooth_nk, mean_hr_nk, times_trimmed, overruled)
            order_of_operations += " + Heart Rate Calculation"
            fig_dict['HeartRate_PSD'] = hr1
            fig_dict['HeartRate_Time'] = hr2
            if progress_callback: progress_callback(5)

        # If specified, calculate and apply SCI
        # STEP 6
        if SCI:
            bad_channels_sci, sci1, sci2 = calculate_scalp_coupling(data.copy(), min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW) / 60, max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW) / 60)
            order_of_operations += " + SCI Calculation"
            fig_dict['SCI1'] = sci1
            fig_dict['SCI2'] = sci2
            if progress_callback: progress_callback(6)

        # If specified, calculate and apply PSP
        if PSP:
            bad_channels_psp, psp1, psp2 = calculate_peak_power(data.copy(), min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW) / 60, max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW) / 60)
            order_of_operations += " + PSP Calculation"
            fig_dict['PSP1'] = psp1
            fig_dict['PSP2'] = psp2
            if progress_callback: progress_callback(7)

        # If specified, calculate and apply SNR
        if SNR:
            bad_channels_snr, fig = calculate_signal_noise_ratio(data.copy())
            order_of_operations += " + SNR Calculation"
            fig_dict['SNR'] = fig

        # If specified, drop channels that were marked as bad
        # STEP 7
        if EXCLUDE_CHANNELS:
            data, fig, num_bad_channels = mark_bad_channels(data, ID, set(bad_channels_sci), set(bad_channels_psp), set(bad_channels_snr))
            order_of_operations += " + Excluded Bad Channels"
            fig_dict['Bads'] = fig
            if progress_callback: progress_callback(7)

        # Calculate the optical density
        # STEP 8
        data, fig = calculate_optical_density(data, ID)
        order_of_operations += " + Optical Density"
        fig_dict['OpticalDensity'] = fig
        if progress_callback: progress_callback(8)

        # Mainly for visualization. Could be implemented in the future
        # STEP 8.5
        evoked_dict_corr = None
        if SHORT_CHANNEL:
            short_chans_od = cast(BaseRaw, optical_density(short_chans))
            data_recombined = cast(BaseRaw, data.copy().add_channels([short_chans_od])) # type: ignore
            evoked_dict_corr = calculate_and_apply_short_channel_correction(data_recombined.copy(), file_path)

        # Calculate the haemoglobin concentration
        # STEP 9
        data, fig = calculate_haemoglobin_concentration(data, ID, file_path)
        order_of_operations += " + Haemoglobin Concentration"
        fig_dict['HaemoglobinConcentration'] = fig
        if progress_callback: progress_callback(9)

        # Mainly for visualization. Could be implemented in the future
        # STEP 9.5
        evoked_dict = extract_normal_epochs(data.copy())
        evoked_dict_anti = calculate_and_apply_negative_correlation_enhancement(data.copy())
        fig = signal_enhancement_techniques_images(evoked_dict, evoked_dict_anti, evoked_dict_corr)
        fig_dict['SignalEnhancement'] = fig

        # Create the design matrix
        # STEP 10
        # HACK FIXME - Downsampling to 10 is certaintly not the best way... right?
        if HRF_MODEL == 'fir':
            data.resample(10, verbose=VERBOSITY) # type: ignore
            if short_chans is not None:
                short_chans.resample(10, verbose=VERBOSITY) # type: ignore

        design_matrix, fig = create_design_matrix(data, stim_duration, short_chans)
        order_of_operations += " + Design Matrix"
        fig_dict['DesignMatrix'] = fig
        if progress_callback: progress_callback(10)

        # Run the glm on the design matrix
        # STEP 11
        glm_est: RegressionResults = run_GLM_analysis(data, design_matrix)
        order_of_operations += " + GLM"
        if progress_callback: progress_callback(11)

        # Add the regions of interest to the groups
        # STEP 12
        logger.info("Performing the finishing touches...")
        order_of_operations += " + Finishing Touches"

        # Extract the channel metrics
        logger.info("Calculating channel results...")
        cha = cast(DataFrame, glm_est.to_dataframe()) # type: ignore

        logger.info("Creating groups...")
        if HRF_MODEL == "fir":
            groups = dict(AllChannels=range(len(data.ch_names))) # type: ignore

        else:
            groups: dict[str, list[int]] = dict(
                group_1_picks = picks_pair_to_idx(data, ROI_GROUP_1, on_missing="ignore"), # type: ignore
                group_2_picks = picks_pair_to_idx(data, ROI_GROUP_2, on_missing="ignore"), # type: ignore
            )

        # Compute region of interest results from the channel data
        logger.info("Calculating region of intrest results...")
        roi = glm_est.to_dataframe_region_of_interest(groups, design_matrix.columns, demographic_info=True) # type: ignore

        # Create the contrast matrix: one unit vector per design matrix column
        logger.info("Creating the contrast matrix...")
        contrast_matrix = np.eye(design_matrix.shape[1])
        basic_conts = dict(
            [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)]
        )

        # Calculate contrast differently depending on the hrf model
        if HRF_MODEL == 'fir':
            # Find all FIR regressors for TARGET_ACTIVITY
            delay_cols = [col for col in design_matrix.columns if col.startswith(f"{TARGET_ACTIVITY}_delay_")]

            if not delay_cols:
                raise ValueError(f"No FIR regressors found for condition {TARGET_ACTIVITY}.")

            # Average the contrast vectors of all delays
            fir_contrast = np.sum([basic_conts[col] for col in delay_cols], axis=0)
            fir_contrast /= len(delay_cols)

            # Compute contrast
            contrast = glm_est.compute_contrast(fir_contrast) # type: ignore
            con = cast(DataFrame, contrast.to_dataframe()) # type: ignore

        else:
            # Create and compute the contrast
            contrast_t = basic_conts[TARGET_ACTIVITY]
            contrast = glm_est.compute_contrast(contrast_t) # type: ignore
            con = cast(DataFrame, contrast.to_dataframe()) # type: ignore

        # Add the participant ID to the dataframes
        roi["ID"] = cha["ID"] = con["ID"] = design_matrix["ID"] = ID

        # Convert to uM for nicer plotting below.
        logger.info("Converting to uM...")
        cha["theta"] = cha["theta"].astype(float) * 1.0e6
        roi["theta"] = roi["theta"].astype(float) * 1.0e6
        con["effect"] = con["effect"].astype(float) * 1.0e6

        # If we exceed the maximum allowed bad channels, apply an X over the figures
        logger.info("Checking amount of bad channels...")
        if num_bad_channels >= MAX_BAD_CHANNELS:
            valid = False
            logger.info("Drawing some big X's...")
            for fig in fig_dict.values():
                add_x_overlay(fig, 'Too many bad channels!', 'red')

        logger.info("Completed individual analysis.")

        if progress_callback: progress_callback(12)

        # Clear the output for the next participant unless we are told to be verbose
        if not VERBOSITY:
            clear_output(wait=True)

    # Something really went wrong and we should not continue
    except ProcessingError as e:
        # The exception is passed through a %s placeholder; without one, logging
        # fails to format the record and the message is lost
        logger.error("An error occured! %s", e)
        raise

    # Something went wrong at one of the steps. Return what data we gathered, but set the validity of this run to False
    except Exception as e:
        logger.exception("An error occured! %s", e)
        fig_dict_bytes = convert_fig_dict_to_png_bytes(fig_dict)
        # Return the PNG bytes here as well, so both exits hand back the same kind of figure dict
        return data, raw_full_layout, roi, cha, con, design_matrix, fig_dict_bytes, order_of_operations, False, False

    fig_dict_bytes = convert_fig_dict_to_png_bytes(fig_dict)

    return data, raw_full_layout, roi, cha, con, design_matrix, fig_dict_bytes, order_of_operations, True, valid
|
|
|
|
|
|
|
|
def add_x_overlay(fig: Figure, reason: str, color: str) -> None:
    """
    Draws a large red 'X' and an explanatory label across a whole figure.

    Parameters
    ----------
    fig : Figure
        Matplotlib figure to draw the X on. It is modified in place.
    reason: str
        Text shown in the middle of the figure explaining why the X was drawn.
    color: str
        Color of the reason text. The X itself and the text box border are always red.
    """

    # An invisible axes covering the full figure, stacked above the existing content
    overlay = fig.add_axes([0, 0, 1, 1], zorder=100) # type: ignore
    overlay.set_axis_off()

    # The two diagonals of the X, given in figure coordinates
    for y_ends in ([0, 1], [1, 0]):
        overlay.plot([0, 1], y_ends, color='red', linewidth=8, transform=fig.transFigure, clip_on=False) # type: ignore

    # The reason, centered on a semi-transparent white box so it stays readable
    label_box = dict(facecolor='white', alpha=0.8, edgecolor='red', boxstyle='round,pad=0.4')
    overlay.text(0.5, 0.5, reason, color=color, fontsize=26, fontweight='bold', ha='center', va='center', transform=fig.transFigure, zorder=101, bbox=label_box) # type: ignore
|
|
|
|
|
|
|
|
from io import BytesIO
|
|
|
|
def convert_fig_dict_to_png_bytes(fig_dict):
    """
    Renders every figure of a dictionary to PNG and returns the encoded bytes.

    Parameters
    ----------
    fig_dict : dict
        Maps a name to an object with a `savefig(buffer, format=...)` method.

    Returns
    -------
    dict
        Maps each name of `fig_dict` to the bytes written by `savefig` with format 'png'.
    """

    png_by_name = {}
    for name, figure in fig_dict.items():
        # Save into an in-memory buffer and keep everything that was written to it
        with BytesIO() as stream:
            figure.savefig(stream, format='png')
            png_by_name[name] = stream.getvalue()
    return png_by_name
|
|
|
|
|
|
|
|
def process_file_worker(args):
    """
    Runs the individual analysis for one file and reports the outcome instead of raising.

    Parameters
    ----------
    args : tuple
        `(file_path, file_name, stim_duration, config, gui, progress_queue)`.
        `config` and `gui` are handed to `set_config` before the analysis starts.
        When `progress_queue` is truthy, every finished step is put on it as
        `('progress', file_name, step_idx)`.

    Returns
    -------
    tuple
        `(file_name, result, None)` on success, where `result` is whatever
        `individual_GLM_analysis` returned, or `(file_name, None, exception)` when
        `set_config` or the analysis raised.
    """

    file_path, file_name, stim_duration, config, gui, progress_queue = args

    def report_step(step_idx):
        # Echo the step to stdout and, when a queue was supplied, forward it
        print(f"[Worker] Step {step_idx} for {file_name}")
        if progress_queue:
            progress_queue.put(('progress', file_name, step_idx))

    try:
        set_config(config, gui)
        outcome = individual_GLM_analysis(
            file_path, file_name, stim_duration,
            progress_callback=report_step
        )
    except Exception as e:
        # Hand the exception back to the caller rather than raising it here
        return file_name, None, e

    return file_name, outcome, None
|
|
|
|
|
|
|
|
def process_folder(folder_path: str, stim_duration: float, files_remaining: dict[str, int], config , gui: bool = False, progress_queue=None) -> tuple[dict[str, dict[str, BaseRaw]], DataFrame, DataFrame, DataFrame, DataFrame, dict[str, list[Figure]], dict[str, str]]:
    """
    Processes every file of a folder in worker processes and gathers the results.

    Parameters
    ----------
    folder_path : str
        Folder whose regular files are each submitted to `process_file_worker`.
    stim_duration : float
        Stimulus duration forwarded to every worker.
    files_remaining : dict[str, int]
        Its 'count' entry is decremented once per completed file.
    config
        Configuration forwarded to every worker.
    gui : bool, optional
        Forwarded to every worker (default is False).
    progress_queue : optional
        Forwarded to every worker (default is None).

    Returns
    -------
    tuple
        - dict[str, dict[str, BaseRaw]]: "filtered" and "full_layout" data of each valid participant
        - DataFrame: Region of interest results of all valid participants
        - DataFrame: Channel results of all valid participants
        - DataFrame: Contrast results of all valid participants
        - DataFrame: Design matrices of all valid participants
        - dict[str, list[Figure]]: The collected figure entries of each processing step
        - dict[str, str]: The processing description of each valid participant
    """

    step_names = [
        'Raw', 'TDDR', 'Wavelet', 'HeartRate_PSD', 'HeartRate_Time',
        'SCI1', 'SCI2', 'PSP1', 'PSP2', 'SNR', 'Bads',
        'OpticalDensity', 'HaemoglobinConcentration', 'SignalEnhancement', 'DesignMatrix'
    ]
    figures_by_step: dict[str, list[Figure]] = {step: [] for step in step_names}
    raw_haemo_dict: dict[str, dict[str, BaseRaw]] = {}
    process_dict: dict[str, str] = {}

    df_roi = DataFrame()
    df_cha = DataFrame()
    df_con = DataFrame()
    df_design_matrix = DataFrame()

    def keep_figures(participant_figures) -> None:
        # File every figure this participant produced under its step name
        for step, collected in figures_by_step.items():
            if step in participant_figures:
                collected.append(participant_figures[step])

    # One argument tuple per regular file in the folder
    file_args = []
    for entry in os.listdir(folder_path):
        entry_path = os.path.join(folder_path, entry)
        if os.path.isfile(entry_path):
            file_args.append((entry_path, entry, stim_duration, config, gui, progress_queue))

    print("[process_folder] File args:", file_args)

    # Warn when there is less than 1GB of free memory per worker
    available_mem = psutil.virtual_memory().available
    if (MAX_WORKERS >= available_mem / (1024 ** 3)):
        print(f"WARNING: You have set MAX_WORKERS to {MAX_WORKERS}. Each worker should have at least 1GB of system memory. Your device currently has a total of {available_mem / (1024 ** 3):.2f}GB free.\nPlease consider lowering MAX_WORKERS to prevent potential crashing due to insufficient system memory.")

    with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_file = {}
        for worker_args in file_args:
            future_to_file[executor.submit(process_file_worker, worker_args)] = worker_args[1]

        with tqdm(total=len(file_args), desc="Processing files") as pbar:
            for future in as_completed(future_to_file):
                file_name = future_to_file[future]
                files_remaining['count'] -= 1
                logger.info(f"Files remaining: {files_remaining['count']}")
                pbar.update(1)

                try:
                    file_name, result, error = future.result()
                    if error:
                        logger.info(f"Error processing {file_name}: {error}")
                        continue

                    (raw_haemo_filtered, raw_haemo_full, roi, channel, contrast,
                     design_matrix, fig_dict, process, finished, valid) = result

                    if not (finished and valid):
                        logger.info(f"Finished processing {file_name}. This participant was NOT valid.")
                        # Figures of rejected participants are only kept on request
                        if SEE_BAD_IMAGES:
                            keep_figures(fig_dict)
                        continue

                    logger.info(f"Finished processing {file_name}. This participant was valid.")
                    raw_haemo_dict[file_name] = {
                        "filtered": raw_haemo_filtered,
                        "full_layout": raw_haemo_full
                    }
                    process_dict[file_name] = process

                    keep_figures(fig_dict)

                    df_roi = pd.concat([df_roi, roi], ignore_index=True)
                    df_cha = pd.concat([df_cha, channel], ignore_index=True)
                    df_con = pd.concat([df_con, contrast], ignore_index=True)
                    df_design_matrix = pd.concat([df_design_matrix, design_matrix], ignore_index=True)

                except Exception as e:
                    logger.info(f"Unexpected error processing {file_name}: {e}")
                    raise

    return raw_haemo_dict, df_roi, df_cha, df_con, df_design_matrix, figures_by_step, process_dict
|
|
|
|
|
|
|
|
def verify_channel_positions(data: BaseRaw) -> None:
    """
    Shows the sensor/channel positions of the data so they can be checked by eye.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object whose sensors are plotted.
    """

    logger.info("Creating the figure...")

    # Build the sensor plot without showing it, then display it through pyplot
    sensor_plot_options = dict(show_names=True, to_sphere=True, show=False, verbose=VERBOSITY)
    data.plot_sensors(**sensor_plot_options) # type: ignore
    plt.show() # type: ignore
|
|
|
|
|
|
|
|
def plot_3d_evoked_array(
    inst: Union[BaseRaw, EvokedArray, Info],
    statsmodel_df: DataFrame,
    picks: Optional[Union[str, list[str]]] = "hbo",
    value: str = "Coef.",
    background: str = "w",
    figure: Optional[object] = None,
    clim: Union[str, dict[str, Union[str, list[float]]]] = "auto",
    mode: str = "weighted",
    colormap: str = "RdBu_r",
    surface: str = "pial",
    hemi: str = "both",
    size: int = 800,
    view: Optional[Union[str, dict[str, float]]] = None,
    colorbar: bool = True,
    distance: float = 0.03,
    subjects_dir: Optional[str] = None,
    src: Optional[SourceSpaces] = None,
    verbose: bool = False,
) -> Brain:
    '''
    Projects one column of a statsmodels results table onto the fsaverage brain (ported from MNE).

    The channel names of `inst` must equal the `ch_name` column of `statsmodel_df`, in the
    same order, otherwise a RuntimeError is raised. When `subjects_dir` is None it is read
    from the SUBJECTS_DIR environment variable; when `src` is None the fsaverage
    "fsaverage-ico-5-src.fif" source space is read from `subjects_dir`.
    '''

    info: Info = cast(Info, deepcopy(inst if isinstance(inst, Info) else inst.info)) # type: ignore

    # The dataframe rows must line up one-to-one with the channels of the MNE object
    mne_names = getattr(info, "ch_names")
    glm_names = list(statsmodel_df["ch_name"].values) # type: ignore
    if not (mne_names == glm_names):
        raise RuntimeError(
            'MNE data structure does not match dataframe '
            f'results.\nMNE = {mne_names}.\n'
            f'GLM = {glm_names}'
        )

    # Wrap the requested column in an EvokedArray built on the copied info
    ea = EvokedArray(np.tile(statsmodel_df[value].values.T, (1, 1)).T, info.copy()) # type: ignore

    # TODO: mimic behaviour of other MNE-NIRS glm plotting options
    if picks is not None:
        ea = ea.pick(picks=picks) # type: ignore

    if subjects_dir is None:
        subjects_dir = os.environ["SUBJECTS_DIR"]
    if src is None:
        src = read_source_spaces(
            os.path.join(subjects_dir, "fsaverage", "bem", "fsaverage-ico-5-src.fif"),
            verbose=verbose,
        )

    # From here on, plot every channel that is left in the evoked array
    picks = getattr(ea, "info")["ch_names"]

    # Set coord frame
    for idx, _ in enumerate(getattr(ea, "ch_names")):
        getattr(ea, "info")["chs"][idx]["coord_frame"] = 4

    # Generate source estimate
    stc = stc_near_sensors( # type: ignore
        picks=picks,
        evoked=ea,
        subject="fsaverage",
        trans=Transform('head', 'mri', np.eye(4)),
        distance=distance,
        mode=mode,
        surface=surface,
        subjects_dir=subjects_dir,
        src=src,
        project=True,
        verbose=verbose,
    )

    from mne import SourceEstimate
    assert isinstance(stc, SourceEstimate) # or your specific subclass

    # Produce brain plot
    brain: Brain = stc.plot( # type: ignore
        src=src,
        subjects_dir=subjects_dir,
        hemi=hemi,
        surface=surface,
        initial_time=0,
        clim=clim, # type: ignore
        size=size,
        colormap=colormap,
        figure=figure,
        background=background,
        colorbar=colorbar,
        verbose=verbose,
    )
    if view is not None:
        brain.show_view(view) # type: ignore

    return brain
|
|
|
|
|
|
|
|
def _read_optode_label_positions(optode_file_path: str) -> list[tuple[str, list[float]]]:
    """Parses `name: x y z ...` lines of an optode file, skipping empty lines and lines without a colon."""
    positions: list[tuple[str, list[float]]] = []
    with open(optode_file_path, 'r') as optode_file:
        for line in optode_file:
            line = line.strip()
            if not line or ':' not in line:
                continue # skip empty/malformed lines
            name, coords = line.split(':', 1)
            positions.append((name.strip(), [float(x) for x in coords.strip().split()]))
    return positions


def brain_3d_visualization(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_haemo: dict[str, dict[str, BaseRaw]], participant_number: int, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None:
    """
    Plots the channel-level results of every group on a 3D brain.

    Parameters
    ----------
    all_results : dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]
        Results per group name. Only the second DataFrame (channel results) is used.
    all_haemo : dict[str, dict[str, BaseRaw]]
        Data per participant ID. The "full_layout" entry supplies the optode layout.
    participant_number : int
        Index into the unique participant IDs, selecting whose layout is plotted.
    t_or_theta : Literal['t', 'theta'], optional
        Which value to model and plot (default is 'theta').
    show_optodes : Literal['sensors', 'labels', 'none', 'all'], optional
        Whether to add the sensors, the optode labels read from OPTODE_FILE_PATH,
        both, or neither (default is 'all').
    show_text : bool, optional
        Whether to add a text block describing the plot (default is True).

    Raises
    ------
    ValueError
        If `t_or_theta` is neither 't' nor 'theta'.
    """

    # Determine if we are visualizing t or theta to set the appropriate limit.
    # Any other value is rejected here, instead of failing later on an unbound variable.
    if t_or_theta == 't':
        clim = dict(kind="value", pos_lims=(0, ABS_T_VALUE/2, ABS_T_VALUE))
    elif t_or_theta == 'theta':
        clim = dict(kind="value", pos_lims=(0, ABS_THETA_VALUE/2, ABS_THETA_VALUE))
    else:
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}")

    # One coefficient per channel and no intercept, for the selected value
    formula = f"{t_or_theta} ~ -1 + ch_name"

    # Loop over all groups
    for index, group_name in enumerate(all_results):

        # We only care for their channel results
        (_, df_cha, _, _) = all_results[group_name]

        # Get all activity conditions
        # NOTE: the non-FIR query below refers to the local name `cond` via `@cond`; do not rename it
        for cond in [TARGET_ACTIVITY]:

            if HRF_MODEL == 'fir':
                ch_summary = df_cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python') # type: ignore

            else:
                # Filter for the condition and chromophore
                ch_summary = df_cha.query("Condition in [@cond] and Chroma == 'hbo'") # type: ignore

            # Determine number of unique participants based on their ID
            n_participants = ch_summary["ID"].nunique()

            # WE JUST NEED SOMEONES OPTODE DATA TO PLOT ON THE BRAIN!
            # TODO: This should take the average positions of all participants
            # We will just take the passed through parameter
            participant_to_plot = ch_summary["ID"].unique()[participant_number] # type: ignore
            participant_raw_full: BaseRaw = all_haemo[participant_to_plot]["full_layout"]

            # Use ordinary least squares (OLS) if only one participant
            if n_participants == 1:
                ch_model = smf.ols(formula, ch_summary).fit() # type: ignore
                logger.info("OLS model is being used as there is only one participant!")

            # Use mixed effects model if there is multiple participants
            else:
                ch_model = smf.mixedlm(formula, ch_summary, groups=ch_summary["ID"]).fit(method="nm") # type: ignore

            # Convert model results
            model_df = cast(DataFrame, statsmodels_to_results(ch_model, order=ch_summary["ch_name"].unique())) # type: ignore

            valid_channels = ch_summary["ch_name"].unique().tolist() # type: ignore
            raw_for_plot = participant_raw_full.copy().pick(picks=valid_channels) # type: ignore

            brain = plot_3d_evoked_array(raw_for_plot.pick(picks="hbo"), model_df, view="dorsal", distance=BRAIN_DISTANCE, colorbar=True, clim=clim, mode=BRAIN_MODE, size=(800, 700)) # type: ignore

            if show_optodes == 'all' or show_optodes == 'sensors':
                brain.add_sensors(getattr(raw_for_plot, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY) # type: ignore

            # Read and parse the file
            if show_optodes == 'all' or show_optodes == 'labels':
                # Names starting with 's' are drawn red, with 'd' blue, anything else gray
                for name, (x, y, z) in _read_optode_label_positions(OPTODE_FILE_PATH):
                    brain._renderer.text3d(x, y, z, name, color=('red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'), scale=0.002) # type: ignore

            # Apply the text onto the brain
            if show_text:
                # Set the display text for the brain image
                text_lines = [
                    f"Folder: {BASE_SNIRF_FOLDER}",
                    f"Group: {group_name}",
                    f"Condition: {cond}",
                    f"Short Channel Regression: {SHORT_CHANNEL_REGRESSION}",
                ]
                # The stimulus duration is only shown for non-FIR models
                if HRF_MODEL != 'fir':
                    text_lines.append(f"Stim Duration: {STIM_DURATION[index]}")
                text_lines += [
                    f"Looking at: {t_or_theta} values",
                    f"Brain Distance: {BRAIN_DISTANCE}",
                    f"Brain Mode: {BRAIN_MODE}",
                ]
                display_text = "\n".join(text_lines)

                brain.add_text(0.12, 0.64, display_text, "title", font_size=11, color="k") # type: ignore
|
|
|
|
|
|
|
|
def _scale_fir_regressors(dm_cond: DataFrame, estimates: DataFrame, column: str) -> DataFrame:
    '''Multiply each FIR regressor column of dm_cond by the matching value taken from estimates[column].'''
    # The multiplication is positional: value i scales regressor column i.
    # NOTE(review): this assumes the estimate rows and the regressor columns are in the same delay order - confirm.
    return dm_cond * [float(v) for v in estimates[column]] # type: ignore


def plot_fir_model_results(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_haemo: dict[str, dict[str, BaseRaw]], participant_number: int, t_or_theta: Literal['t', 'theta'] = 'theta') -> None:
    '''This method plots the FIR regressors and the evoked response reconstructed from them for every group.\n
    Two figures are created per group: a three panel figure (unscaled regressors, regressors scaled by the HbO
    estimates, summed HbO/HbR response with the peak time marked) and a single panel figure showing the summed
    responses with their 95% confidence bands. Nothing is plotted unless HRF_MODEL is 'fir'.\n
    Inputs:\n
        all_results (dict) - Per group tuple; element 0 is used as the ROI results and element 3 as the design matrices\n
        all_haemo (dict) - Keyed by participant ID; the "full_layout" entry supplies the annotations used to place time zero\n
        participant_number (int) - Position, within the unique IDs of the group, of the participant whose design matrix is plotted\n
        t_or_theta (str) - Column ('t' or 'theta') the mixed model is fitted on\n
    Raises:\n
        ValueError - If t_or_theta is neither 't' nor 'theta'\n'''

    if HRF_MODEL != 'fir':
        logger.info("This method only works when HRF_MODEL is set to 'fir'.")
        return

    # Fail early with a clear message instead of hitting an unbound model variable later on
    if t_or_theta not in ('t', 'theta'):
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}")

    for group_name in all_results:

        (df_roi, _, _, df_design_matrix) = all_results[group_name]

        # Keep only the design matrix of the requested participant and index it by time.
        # NOTE(review): the 0.1 step presumably reflects a 10 Hz sampling rate - confirm against the preprocessing.
        selected_id = df_design_matrix["ID"].unique()[participant_number] # type: ignore
        df_design_matrix = df_design_matrix.query(f"ID == '{selected_id}'").copy() # type: ignore
        df_design_matrix.index = np.round([0.1 * i for i in range(len(df_design_matrix))], decimals=1) # type: ignore

        participant = all_haemo[selected_id]["full_layout"]

        # Keep the FIR delay rows of the target activity. A copy is taken so that the columns added
        # below never end up in the DataFrame stored in all_results.
        is_target_delay = [("delay" in n) and (TARGET_ACTIVITY in n) for n in df_roi["Condition"]] # type: ignore
        df_roi = df_roi.loc[is_target_delay].copy()
        df_roi["TidyCond"] = TARGET_ACTIVITY

        # Extract the FIR delay in to its own column in the data frame
        df_roi["delay"] = [n.split("_")[-1] for n in df_roi["Condition"]] # type: ignore

        lme = smf.mixedlm(f"{t_or_theta} ~ -1 + delay:TidyCond:Chroma", df_roi, groups=df_roi["ID"]).fit() # type: ignore

        df_sum: DataFrame = statsmodels_to_results(lme) # type: ignore
        df_sum["delay"] = [int(n) for n in df_sum["delay"]] # type: ignore
        df_sum = df_sum.sort_values("delay") # type: ignore

        # Extract design matrix columns that correspond to the condition of interest
        dm_cond_colnames: list[str] = [name for name in df_design_matrix.columns if TARGET_ACTIVITY in name]
        dm_cond = df_design_matrix[dm_cond_colnames]

        # Extract the hbo and hbr estimates for the condition of interest
        df_target = df_sum.query(f"TidyCond in ['{TARGET_ACTIVITY}']") # type: ignore
        df_hbo = df_target.query("Chroma in ['hbo']") # type: ignore
        df_hbr = df_target.query("Chroma in ['hbr']") # type: ignore

        dm_cond_scaled_hbo = _scale_fir_regressors(dm_cond, df_hbo, "Coef.")
        dm_cond_scaled_hbr = _scale_fir_regressors(dm_cond, df_hbr, "Coef.")
        hbo_response = np.sum(dm_cond_scaled_hbo, axis=1)
        hbr_response = np.sum(dm_cond_scaled_hbr, axis=1)

        # Extract the time scale for plotting, shifted by the onset of the annotation at position 1.
        # NOTE(review): presumably that annotation is the first occurrence of the activity - confirm.
        index_values = cast(NDArray[float64], dm_cond_scaled_hbo.index.to_numpy(dtype=float) - participant.annotations.onset[1]) # type: ignore

        # Locate the HbO peak inside the 0 to 15 second window
        valid_mask = (index_values >= 0) & (index_values <= 15)
        hbo_sum_window = np.sum(dm_cond_scaled_hbo.loc[valid_mask, :], axis=1)
        peak_idx_in_window = np.argmax(hbo_sum_window)
        peak_idx = np.where(valid_mask)[0][peak_idx_in_window]
        peak_time = float(round(index_values[peak_idx], 2)) # type: ignore

        # Plot the regressors, the scaled regressors and the summed responses
        axes1: list[Axes]
        fig_components, axes1 = plt.subplots(nrows=1, ncols=3, figsize=(20, 10)) # type: ignore
        axes1[0].plot(index_values, np.asarray(dm_cond)) # type: ignore
        axes1[1].plot(index_values, np.asarray(dm_cond_scaled_hbo)) # type: ignore
        axes1[2].plot(index_values, hbo_response, "r") # type: ignore
        axes1[2].plot(index_values, hbr_response, "b") # type: ignore
        axes1[2].axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time') # type: ignore

        # Format the plot
        for panel in axes1:
            panel.set_xlim(-5, 20)
            panel.set_xlabel("Time (s)") # type: ignore
        axes1[0].set_ylim(-0.2, 1.2)
        axes1[1].set_ylim(-4, 8)
        axes1[2].set_ylim(-4, 8)
        axes1[0].set_title(f"FIR Model for {group_name} (Unscaled by GLM {TARGET_ACTIVITY} estimates) ({t_or_theta})") # type: ignore
        axes1[1].set_title(f"FIR Components for {group_name} (Scaled by GLM {TARGET_ACTIVITY} estimates) ({t_or_theta})") # type: ignore
        axes1[2].set_title(f"Evoked Response for {group_name} ({TARGET_ACTIVITY}) ({t_or_theta})") # type: ignore
        axes1[0].set_ylabel("FIR Model") # type: ignore
        axes1[1].set_ylabel("Oxyhaemoglobin (ΔμMol)") # type: ignore
        axes1[2].set_ylabel("Haemoglobin (ΔμMol)") # type: ignore
        axes1[2].legend(["Oxyhaemoglobin", "Deoxyhaemoglobin", f"Peak {peak_time}s"]) # type: ignore

        fig_components.tight_layout()

        # Scale the regressors by the bounds of the 95% confidence intervals of the estimates too
        hbo_l95 = np.asarray(np.sum(_scale_fir_regressors(dm_cond, df_hbo, "[0.025"), axis=1))
        hbo_u95 = np.asarray(np.sum(_scale_fir_regressors(dm_cond, df_hbo, "0.975]"), axis=1))
        hbr_l95 = np.asarray(np.sum(_scale_fir_regressors(dm_cond, df_hbr, "[0.025"), axis=1))
        hbr_u95 = np.asarray(np.sum(_scale_fir_regressors(dm_cond, df_hbr, "0.975]"), axis=1))

        axes2: Axes
        fig_ci, axes2 = plt.subplots(nrows=1, ncols=1, figsize=(7, 7)) # type: ignore

        # Plot the result
        axes2.plot(index_values, hbo_response, "r") # type: ignore
        axes2.plot(index_values, hbr_response, "b") # type: ignore
        axes2.axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time') # type: ignore
        axes2.fill_between(index_values, hbo_l95, hbo_u95, facecolor="red", alpha=0.25) # type: ignore
        axes2.fill_between(index_values, hbr_l95, hbr_u95, facecolor="blue", alpha=0.25) # type: ignore

        # Format the plot
        axes2.set_xlim(-5, 20)
        axes2.set_ylim(-8, 12)
        axes2.set_title(f"Evoked Response with 95% confidence intervals for {group_name} ({TARGET_ACTIVITY}) ({t_or_theta})") # type: ignore
        axes2.set_ylabel("Haemoglobin (ΔμMol)") # type: ignore
        axes2.legend(["Oxyhaemoglobin", "Deoxyhaemoglobin", f"Peak {peak_time}s"]) # type: ignore
        axes2.set_xlabel("Time (s)") # type: ignore

        fig_ci.tight_layout()

        plt.show() # type: ignore
|
|
|
|
|
|
|
|
def plot_2d_theta_graph(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None:
    '''This method will create a 2d boxplot showing the theta values for each channel and group as independent ranges on the same graph.\n
    Inputs:\n
        all_results (dict) - Per group tuple; only element 1 (the channel results, df_cha) is used\n
    '''

    # Nothing to concatenate or plot without at least one group
    if not all_results:
        logger.info("No groups available to plot theta values for.")
        return

    # Tag the channel results of every group with the group name. assign() returns a new DataFrame,
    # so the DataFrames stored in all_results are not given an extra column as a side effect.
    df_all_cha_list: list[DataFrame] = [
        df_cha.assign(group=group_name)
        for group_name, (_, df_cha, _, _) in all_results.items()
    ]

    # Combine all the data into a single DataFrame
    df_all_cha = pd.concat(df_all_cha_list, ignore_index=True)

    # Filter for the target activity
    if HRF_MODEL == 'fir':
        df_target = df_all_cha[df_all_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_")] # type: ignore
    else:
        df_target = df_all_cha[df_all_cha["Condition"] == TARGET_ACTIVITY]

    # Get the number of unique groups to know how many colors are needed for the boxplot
    unique_groups = df_target["group"].nunique()
    palette = sns.color_palette("Set2", unique_groups)

    # Create the boxplot
    plt.figure(figsize=(15, 6)) # type: ignore
    sns.boxplot(
        data=df_target,
        x="ch_name",
        y="theta",
        hue="group",
        palette=palette
    )

    # Format the boxplot
    plt.title("Theta Coefficients by Channel and Group") # type: ignore
    plt.xticks(rotation=90) # type: ignore
    plt.ylabel("Theta (µM)") # type: ignore
    plt.xlabel("Channel") # type: ignore
    plt.legend(title="Group") # type: ignore
    plt.tight_layout()
    plt.show() # type: ignore
|
|
|
|
|
|
|
|
def plot_individual_theta_averages(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None:
    '''This method will create one catplot per group showing the hbo theta values of every participant, split by condition and coloured by region of interest.\n
    Inputs:\n
        all_results (dict) - Per group tuple; only element 0 (the region of interest results, df_roi) is used\n'''

    if HRF_MODEL == 'fir':
        logger.info("This method does not work when HRF_MODEL is set to 'fir'.")
        return

    # Friendly names for the regions of interest
    roi_label_map = {
        "group_1_picks": ROI_GROUP_1_NAME,
        "group_2_picks": ROI_GROUP_2_NAME,
    }

    for group_results in all_results.values():

        # Region of interest data of this group
        df_roi = group_results[0]

        # Keep the target and control conditions, then only the hbo rows
        selected = df_roi.query(f"Condition in ['{TARGET_ACTIVITY}', '{TARGET_CONTROL}']").copy() # type: ignore
        selected = selected.query("Chroma in ['hbo']").copy() # type: ignore

        # Swap the internal ROI names for the friendly ones
        selected["ROI"] = selected["ROI"].replace(roi_label_map) # type: ignore

        # One panel per participant ID, five panels per row
        sns.catplot(
            data=selected,
            x="Condition",
            y="theta",
            hue="ROI",
            col="ID",
            col_wrap=5,
            errorbar=None,
            palette="muted",
            height=4,
            s=10,
            dodge=False,
        )

        plt.show() # type: ignore
|
|
|
|
|
|
|
|
def plot_group_theta_averages(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]) -> None:
    '''This method will create a stripplot showing the theta values for each region of interest for each group.\n
    A mixed model ("theta ~ -1 + ROI:Condition:Chroma", grouped by ID) is fitted per group and its hbo coefficients are plotted, one subplot per group.\n
    Inputs:\n
        all_results (dict) - Per group tuple; only element 0 (the region of interest results, df_roi) is used\n'''

    if HRF_MODEL == 'fir':
        logger.info("This method does not work when HRF_MODEL is set to 'fir'.")
        return

    # Without groups the subplot grid would have zero rows, which plt.subplots rejects
    if not all_results:
        logger.info("No groups available to plot theta averages for.")
        return

    # Rename the ROI's to be the friendly name
    roi_label_map = {
        "group_1_picks": ROI_GROUP_1_NAME,
        "group_2_picks": ROI_GROUP_2_NAME,
    }

    # Setup subplot grid
    ncols = 2
    nrows = (len(all_results) + 1) // ncols # round up
    fig, axes = cast(tuple[Figure, np.ndarray[Any, Any]], plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 5 * nrows), squeeze=False)) # type: ignore
    flat_axes = axes.flatten()

    # Iterate over all groups
    for (group_name, group_results), ax in zip(all_results.items(), flat_axes):

        # Store the region of interest data
        df_roi = group_results[0]

        # Filter the results down to what we want
        grp_results = df_roi.query(f"Condition in ['{TARGET_ACTIVITY}', '{TARGET_CONTROL}']").copy() # type: ignore

        # Run a mixedlm model on the data
        roi_model = smf.mixedlm("theta ~ -1 + ROI:Condition:Chroma", grp_results, groups=grp_results["ID"]).fit(method="nm") # type: ignore

        # Apply the new friendly names on to the data
        df = cast(DataFrame, statsmodels_to_results(roi_model))
        df["ROI"] = df["ROI"].map(roi_label_map) # type: ignore

        # Create a stripplot:
        sns.stripplot(
            x="Condition",
            y="Coef.",
            hue="ROI",
            data=df.query("Chroma == 'hbo'"), # type: ignore
            dodge=False,
            jitter=False,
            size=5,
            palette="muted",
            ax=ax,
        )

        # Format the stripplot
        ax.set_title(f"Results for {group_name}")
        ax.legend(title="ROI", loc="upper right")

    # Remove any unused axes (an odd number of groups leaves one empty cell) and apply final touches
    for unused_ax in flat_axes[len(all_results):]:
        fig.delaxes(unused_ax)

    fig.tight_layout()
    fig.suptitle("Theta Averages Across Groups", fontsize=16, y=1.02) # type: ignore
    plt.show() # type: ignore
|
|
|
|
|
|
|
|
def compute_p_group_stats(df_cha: DataFrame, bad_pairs: set[tuple[int, int]], t_or_theta: Literal['t', 'theta'] = 't') -> DataFrame:
    '''This method runs a one-sample t-test (against zero) on the hbo t values of every source and detector pairing for the target activity.\n
    When HRF_MODEL is 'fir', the t and theta values are first averaged over all FIR delays per participant.\n
    Inputs:\n
        df_cha (DataFrame) - Channel results; the Condition, Chroma, Source, Detector, t and theta columns are read (plus ID when HRF_MODEL is 'fir')\n
        bad_pairs (set) - (source, detector) pairings that must be skipped\n
        t_or_theta (str) - 't' stores the mean t value per pairing, anything else stores the mean theta value\n
    Outputs:\n
        result (DataFrame) - One row per tested pairing with the columns Source, Detector, t_or_theta and p_value\n'''

    if HRF_MODEL == 'fir':
        # Filter: All delays for the target activity
        df_activity = df_cha[df_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_") & (df_cha['Chroma'] == 'hbo')] # type: ignore

        # Aggregate across FIR delays *per subject* for each channel
        df_agg = (df_activity.groupby(['Source', 'Detector', 'ID'])[['t', 'theta']].mean().reset_index()) # type: ignore

    else:
        # Canonical HRF case
        df_agg = df_cha[(df_cha['Condition'] == TARGET_ACTIVITY) & (df_cha['Chroma'] == 'hbo')].copy()

    # Group the filtered channel data by pairing
    grouped = cast(Iterator[tuple[tuple[int, int], Any]], df_agg.groupby(['Source', 'Detector'])) # type: ignore

    # Create an empty list to store the data for our result
    data: list[dict[str, Any]] = []

    for (src, det), group in grouped:

        # If it is a bad channel pairing, do not process it
        if (src, det) in bad_pairs:
            logger.info(f"Skipping bad channel Source {src} - Detector {det}")
            continue

        # Drop any missing values that could exist
        t_values = np.array(group['t'].dropna().values, dtype=float)
        theta_values = np.array(group['theta'].dropna().values, dtype=float)

        # A t-test needs at least two t values, otherwise do not process this pairing
        # TODO: is the t values throwing a warning good enough?
        if len(t_values) < 2:
            logger.info(f"Skipping Source {src} - Detector {det}: not enough data (n={len(t_values)})")
            continue

        # NOTE: This is still calculated with t values as it is a t-test
        # Perform one-sample t-test on t-values across subjects; only the p value is needed
        _, pval = ttest_1samp(t_values, popmean=0)

        # Store the mean of the requested statistic for visualization together with the p value
        plotted_values = t_values if t_or_theta == 't' else theta_values
        data.append({
            'Source': src,
            'Detector': det,
            't_or_theta': np.mean(plotted_values),
            'p_value': pval
        })

    # Naming the columns explicitly keeps them available even when no pairing could be tested
    result = DataFrame(data, columns=['Source', 'Detector', 't_or_theta', 'p_value'])
    if result.empty:
        logger.info("No valid channel pairs with enough data for group-level testing.")

    return result
|
|
|
|
|
|
|
|
def get_bad_src_det_pairs(raw: BaseRaw) -> set[tuple[int, int]]:
    '''This method collects the source and detector numbers of every channel marked as bad, so that those pairings can be kept out of the 2d t+p graph.\n
    Inputs:\n
        raw (RawSNIRF) - Only its info["bads"] list of channel names is read\n
    Outputs:\n
        bad_pairs (set) - Set of (source, detector) number tuples; names that cannot be parsed are logged and left out'''

    bad_pairs: set[tuple[int, int]] = set()

    for ch_name in getattr(raw, "info")["bads"]:
        try:
            # Everything before the first whitespace holds the optode pair
            optode_pair = ch_name.split()[0]

            # The two halves are a one character prefix followed by the optode number
            src_str, det_str = optode_pair.split(SOURCE_DETECTOR_SEPARATOR)
            bad_pairs.add((int(src_str[1:]), int(det_str[1:])))

        except Exception as e:
            logger.info(f"Could not parse bad channel '{ch_name}': {e}")

    return bad_pairs
|
|
|
|
|
|
|
|
def plot_avg_significant_activity(raw: BaseRaw, all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], t_or_theta: Literal['t', 'theta'] = 't') -> None:
    '''This method plots the average t or theta values for the groups on a 2D graph, one figure per group. p values less than or equal to P_THRESHOLD are solid lines, while greater p values are dashed lines.\n
    Inputs:\n
        raw (RawSNIRF) - Only used to get the bad channels and the channel locations.\n
        all_results (dict) - Per group tuple; only element 1 (the channel results, df_cha) is used\n
        t_or_theta (str) - Which statistic ('t' or 'theta') is averaged and colour coded\n
    Raises:\n
        ValueError - If t_or_theta is neither 't' nor 'theta'\n'''

    # The colour scale is clipped symmetrically around zero at the limit of the chosen statistic
    if t_or_theta == 't':
        colour_limit = ABS_SIGNIFICANCE_T_VALUE
    elif t_or_theta == 'theta':
        colour_limit = ABS_SIGNIFICANCE_THETA_VALUE
    else:
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}")

    # Everything derived from raw is identical for all groups, so it is computed once
    bad_pairs = get_bad_src_det_pairs(raw)

    # Extract the source and detector positions from raw
    src_pos: dict[int, tuple[float, float]] = {}
    det_pos: dict[int, tuple[float, float]] = {}
    for ch in getattr(raw, "info")["chs"]:
        ch_name = ch['ch_name']
        if not ch_name or not ch['loc'].any():
            continue
        parts = ch_name.split()[0]
        src_str, det_str = parts.split(SOURCE_DETECTOR_SEPARATOR)
        # loc[3:5] and loc[6:8] are taken as the x/y of the source and of the detector
        src_pos[int(src_str[1:])] = ch['loc'][3:5]
        det_pos[int(det_str[1:])] = ch['loc'][6:8]

    all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()]
    all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()]

    # Iterate over all the groups
    for group_name in all_results:
        (_, df_cha, _, _) = all_results[group_name]

        # Count the source and detector pairings that have hbo data for the target activity
        if HRF_MODEL == 'fir':
            mask = df_cha['Condition'].str.startswith(f"{TARGET_ACTIVITY}_delay_") & (df_cha['Chroma'] == 'hbo') # type: ignore
        else:
            mask = (df_cha['Condition'] == TARGET_ACTIVITY) & (df_cha['Chroma'] == 'hbo')
        num_tests = df_cha[mask].groupby(['Source', 'Detector']).ngroups # type: ignore

        logger.info(f"Number of tests: {num_tests}")

        # Compute average value across individuals for each channel pairing
        avg_df = compute_p_group_stats(df_cha, bad_pairs, t_or_theta)

        logger.info(f"Average {t_or_theta}-values and p-values for {TARGET_ACTIVITY}:")
        for _, row in avg_df.iterrows(): # type: ignore
            logger.info(f"Source {row['Source']} <-> Detector {row['Detector']}: "
                        f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}")

        # Set up the plot
        fig, ax = plt.subplots(figsize=(8, 6)) # type: ignore

        # Plot the sources
        for pos in src_pos.values():
            ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3) # type: ignore

        # Plot the detectors
        for pos in det_pos.values():
            ax.scatter(pos[0], pos[1], s=120, c='k', marker='s', edgecolors='white', linewidths=1, zorder=3) # type: ignore

        # Ensure that the colors stay within the boundaries even if they are over or under the max/min values
        norm = mcolors.Normalize(vmin=-colour_limit, vmax=colour_limit)
        cmap: mcolors.Colormap = plt.get_cmap('seismic')

        # Plot connections with the average values
        for row in avg_df.itertuples():
            src: int = cast(int, row.Source) # type: ignore
            det: int = cast(int, row.Detector) # type: ignore
            tval: float = cast(float, row.t_or_theta) # type: ignore
            pval: float = cast(float, row.p_value) # type: ignore

            if src in src_pos and det in det_pos:
                x = [src_pos[src][0], det_pos[det][0]]
                y = [src_pos[src][1], det_pos[det][1]]
                style = '-' if pval <= P_THRESHOLD else '--'
                ax.plot(x, y, linestyle=style, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2) # type: ignore

        # Format the Colorbar
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, shrink=0.85) # type: ignore
        cbar.set_label(f'Average {TARGET_ACTIVITY} {t_or_theta} value (hbo)', fontsize=11) # type: ignore

        # Formatting the subplots
        ax.set_aspect('equal')
        ax.set_title(f"Average {t_or_theta} values for {TARGET_ACTIVITY} (HbO) for {group_name}", fontsize=14) # type: ignore
        ax.set_xlabel('X position (m)', fontsize=11) # type: ignore
        ax.set_ylabel('Y position (m)', fontsize=11) # type: ignore
        ax.grid(True, alpha=0.3) # type: ignore

        # Set axis limits to be 1cm more than the optode positions (skipped when no optode has a location)
        if all_x and all_y:
            ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01)
            ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01)

        fig.tight_layout()
        plt.show() # type: ignore
|
|
|
|
|
|
|
|
def generate_montage_locations():
    """Get standard MNI montage locations in dataframe.

    Data is returned in the same format as the eeg_positions library.

    Returns
    -------
    coords : dataframe
        One row per channel position of the "standard_1005" montage with
        the columns "x", "y", "z" and "label", on a fresh integer index.
    """
    # standard_1020 and standard_1005 are in MNI (fsaverage) space already,
    # but we need to undo the scaling that head_scale will do
    montage = mne.channels.make_standard_montage(
        "standard_1005", head_size=0.09700884729534559
    )

    # NOTE(review): 2003 is presumably the FIFF constant of the MNI Talairach
    # coordinate frame - confirm against mne's FIFF constants.
    for dig_point in montage.dig:
        dig_point["coord_frame"] = 2003

    # Drop the first three digitisation entries in place (presumably the
    # original fiducials) before the MNI fiducials are added.
    montage.dig[:] = montage.dig[3:]
    montage.add_mni_fiducials()  # now in fsaverage space

    # One row per channel, coordinate columns named x/y/z, label taken from
    # the channel name that from_dict placed on the index
    channel_positions = montage.get_positions()["ch_pos"]
    coords = pd.DataFrame.from_dict(channel_positions).T
    coords = coords.rename(columns={0: "x", 1: "y", 2: "z"})
    coords["label"] = coords.index

    return coords.reset_index(drop=True)
|
|
|
|
|
|
def _find_closest_standard_location(position, reference, *, out="label"):
    """Return closest montage label to coordinates.

    Parameters
    ----------
    position : array-like, shape (3,) or (n, 3)
        Coordinates. The values are arranged into rows of three.
    reference : dataframe
        Reference locations with the columns "x", "y", "z" and "label",
        as generated by generate_montage_locations.
    out : str
        "label" returns the label of the reference location closest to any
        of the given positions, "dists" returns all pairwise distances.

    Returns
    -------
    label or dists
        For out="label" the value of the "label" column of the closest
        reference row. For out="dists" an array of shape
        (n_positions, n_reference) with the Euclidean distances.

    Raises
    ------
    ValueError
        If out is neither "label" nor "dists".
    """
    from scipy.spatial.distance import cdist

    # Validate explicitly: an assert would disappear when running with -O
    if out not in ("label", "dists"):
        raise ValueError(f'out must be "label" or "dists", got {out!r}')

    p0 = np.array(position, dtype=float).reshape(-1, 3)
    # head_mri_t, _ = _get_trans("fsaverage", "head", "mri")
    # p0 = apply_trans(head_mri_t, p0)
    dists = cdist(p0, np.asarray(reference[["x", "y", "z"]], float))

    if out == "dists":
        return dists

    # argmin works on the flattened matrix; unravel it to recover the reference
    # column, then look the label up by position so that the result does not
    # depend on how the reference dataframe happens to be indexed
    _, ref_idx = np.unravel_index(np.argmin(dists), dists.shape)
    return reference["label"].iloc[ref_idx]
|
|
|
|
|
|
|
|
def _source_detector_fold_table(raw, cidx, reference, fold_tbl, interpolate):
    """Return the fOLD table rows matching the optode pair of one channel.

    Parameters
    ----------
    raw : object
        Its info["chs"][cidx]["loc"] is read; elements 3:6 are used as the
        source position and 6:9 as the detector position.
    cidx : int
        Index of the channel within raw.info["chs"].
    reference : dataframe
        Reference locations with "x", "y", "z" and "label" columns.
    fold_tbl : dataframe
        fOLD table with (at least) "Source" and "Detector" columns.
    interpolate : bool
        When no row exists for the closest reference labels (in either
        order), fall back to the table row with the smallest combined
        optode distance.

    Returns
    -------
    tbl : dataframe
        Copy of the matching rows (possibly empty when interpolate is False)
        with the added columns "BestSource", "BestDetector",
        "BestMatchDistance" and "MatchDistance".
    """
    optode_loc = raw.info["chs"][cidx]["loc"]
    src = optode_loc[3:6]
    det = optode_loc[6:9]

    labels = list(reference["label"])
    dists = _find_closest_standard_location([src, det], reference, out="dists")

    # Closest reference location of the source (row 0) and detector (row 1)
    src_min, det_min = np.argmin(dists, axis=1)
    src_name = labels[src_min]
    det_name = labels[det_min]
    dist = np.linalg.norm(dists[[0, 1], [src_min, det_min]])

    # NOTE: the "@name" references in the query strings are resolved from the
    # local variables of this function, so those names must stay as they are.
    tbl = fold_tbl.query("Source == @src_name and Detector == @det_name")
    if len(tbl) == 0:
        # Try reversing source and detector
        tbl = fold_tbl.query("Source == @det_name and Detector == @src_name")

    match_dist = dist
    if len(tbl) == 0 and interpolate:
        # Try something hopefully not too terrible: pick the one with the
        # smallest net distance
        known = np.isin(fold_tbl["Source"], reference["label"]) & np.isin(
            fold_tbl["Detector"], reference["label"]
        )
        assert known.any()
        candidates = fold_tbl[known]
        assert len(candidates)
        src_idx = [labels.index(name) for name in candidates["Source"]]
        det_idx = [labels.index(name) for name in candidates["Detector"]]

        # Table source near our source, table detector near our detector
        forward = np.linalg.norm([dists[0, src_idx], dists[1, det_idx]], axis=0)
        assert forward.shape == (len(candidates),)
        best_forward = np.argmin(forward)

        # And the reverse: table detector near our source, table source near
        # our detector
        swapped = np.linalg.norm([dists[0, det_idx], dists[1, src_idx]], axis=0)
        best_swapped = np.argmin(swapped)

        # Either way the chosen row is addressed by its own Source/Detector labels
        if forward[best_forward] < swapped[best_swapped]:
            chosen, match_dist = best_forward, forward[best_forward]
        else:
            chosen, match_dist = best_swapped, swapped[best_swapped]
        src_use = labels[src_idx[chosen]]
        det_use = labels[det_idx[chosen]]

        tbl = fold_tbl.query("Source == @src_use and Detector == @det_use")
        assert len(tbl)

    # Work on a copy so that annotating does not warn about (or write into) fold_tbl
    tbl = tbl.copy()
    tbl["BestSource"] = src_name
    tbl["BestDetector"] = det_name
    tbl["BestMatchDistance"] = dist
    tbl["MatchDistance"] = match_dist
    return tbl
|
|
|
|
from mne.utils import _check_fname, _validate_type, warn
|
|
|
|
|
|
def _read_fold_xls(fname, atlas="Juelich"):
    """Read fOLD toolbox xls file.

    The values are then manipulated in to a tidy dataframe.

    Note the xls files are not included as no license is provided.

    Parameters
    ----------
    fname : str
        Path to xls file.
    atlas : str
        Requested atlas. One of "AAL2", "AICHA", "Brodmann", "Juelich"
        or "Loni"; any other value raises a KeyError.

    Returns
    -------
    tbl : dataframe
        The requested sheet without its spacer rows, with empty cells filled
        from the row above and the "Specificity" and "brainSens" columns
        multiplied by 100.
    """
    # Index of the sheet that holds each atlas
    page_reference = {"AAL2": 2, "AICHA": 5, "Brodmann": 8, "Juelich": 11, "Loni": 14}

    tbl = pd.read_excel(fname, sheet_name=page_reference[atlas])

    # Remove the spacing between rows
    tbl = tbl[tbl["Specificity"].notna()].reset_index(drop=True)

    # Empty values in the table mean its the same as above. A forward fill
    # does this for the whole table at once, instead of visiting every cell.
    tbl = tbl.ffill()

    tbl["Specificity"] = tbl["Specificity"] * 100
    tbl["brainSens"] = tbl["brainSens"] * 100
    return tbl
|
|
|
|
import os.path as op
|
|
|
|
def _check_load_fold(fold_files, atlas):
    """Load the fOLD tables of the given files into one dataframe.

    Parameters
    ----------
    fold_files : list | path-like | None
        A list of xls files, or a directory that contains "10-5.xls" and
        "10-10.xls". None falls back to the MNE_NIRS_FOLD_PATH config value.
    atlas : str
        Requested atlas, passed on to _read_fold_xls.

    Returns
    -------
    fold_tbl : dataframe
        The tables of all files stacked on a fresh integer index (an empty
        dataframe when an empty list is given).

    Raises
    ------
    ValueError
        If fold_files is None and MNE_NIRS_FOLD_PATH is not set.
    """
    # _validate_type(fold_files, (list, "path-like", None), "fold_files")
    if fold_files is None:
        fold_files = mne.get_config("MNE_NIRS_FOLD_PATH")
        if fold_files is None:
            raise ValueError(
                "MNE_NIRS_FOLD_PATH not set, either set it using "
                "mne.set_config or pass fold_files as str or list"
            )
    if not isinstance(fold_files, list):  # path-like
        fold_files = _check_fname(
            fold_files,
            overwrite="read",
            must_exist=True,
            name="fold_files",
            need_dir=True,
        )
        fold_files = [op.join(fold_files, f"10-{x}.xls") for x in (5, 10)]

    # Read every file first and concatenate once: this avoids re-copying the
    # accumulated table per file and keeps an empty seed frame out of the concat
    tables = []
    for fi, fname in enumerate(fold_files):
        fname = _check_fname(
            fname, overwrite="read", must_exist=True, name=f"fold_files[{fi}]"
        )
        tables.append(_read_fold_xls(fname, atlas=atlas))

    if not tables:
        return pd.DataFrame()
    return pd.concat(tables, ignore_index=True)
|
|
|
|
|
|
|
|
def fold_channel_specificity_normal(raw, fold_files=None, atlas="Juelich", interpolate=False):
    """Return the landmarks and specificity a channel is sensitive to.

    Parameters
    ----------
    raw : BaseRaw
        The data; one table is produced per entry of raw.ch_names.
    fold_files : list | path-like | None
        Location of the fOLD xls files, handed to _check_load_fold.
    atlas : str
        Requested atlas, handed to _check_load_fold.
    interpolate : bool
        Handed to _source_detector_fold_table for every channel.

    Returns
    -------
    chan_spec : list of dataframe
        The fOLD table rows of every channel, in channel order, each on a
        fresh integer index.
    """  # noqa: E501
    _validate_type(raw, BaseRaw, "raw")

    # Both lookups are shared by all channels
    reference_locations = generate_montage_locations()
    fold_tbl = _check_load_fold(fold_files, atlas)

    return [
        _source_detector_fold_table(
            raw, channel_index, reference_locations, fold_tbl, interpolate
        ).reset_index(drop=True)
        for channel_index in range(len(raw.ch_names))
    ]
|
|
|
|
|
|
|
|
def fold_channels(raw: BaseRaw, all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], fold_path: str) -> None:
    '''This method logs, for every group, the fOLD landmark specificity (Brodmann atlas) of each hbo channel. When GUI is set, it also draws one pie chart per channel plus a separate legend figure.\n
    Inputs:\n
        raw (RawSNIRF) - Its hbo channels are looked up in the fOLD tables\n
        all_results (dict) - Only the group names (keys) are used\n
        fold_path (str) - Stored as the MNE_NIRS_FOLD_PATH config value so that the fOLD excel files can be located\n'''

    # Locate the fOLD excel files
    mne.set_config('MNE_NIRS_FOLD_PATH', fold_path) # type: ignore

    if not all_results:
        return

    # The specificity only depends on raw, so it is computed once for all hbo channels
    # (one montage generation and one read of the excel files) and reused for every group.
    raw_hbo = raw.copy().pick(picks='hbo') # type: ignore
    hbo_channel_names = cast(list[str], getattr(raw_hbo, "ch_names"))
    specificity = cast(list[DataFrame], fold_channel_specificity_normal(raw_hbo, interpolate=True, atlas='Brodmann'))
    channel_tables: list[tuple[str, DataFrame]] = [
        (channel_name, df_data[['Landmark', 'Specificity']])
        for channel_name, df_data in zip(hbo_channel_names, specificity)
    ]

    landmarks: list[str] = []
    landmark_color_map: dict[str, Any] = {}
    if GUI:
        # Create a list for consistent color mapping
        landmarks = [
            "1 - Primary Somatosensory Cortex",
            "2 - Primary Somatosensory Cortex",
            "3 - Primary Somatosensory Cortex",
            "4 - Primary Motor Cortex",
            "5 - Somatosensory Association Cortex",
            "6 - Pre-Motor and Supplementary Motor Cortex",
            "7 - Somatosensory Association Cortex",
            "8 - Includes Frontal eye fields",
            "9 - Dorsolateral prefrontal cortex",
            "10 - Frontopolar area",
            "11 - Orbitofrontal area",
            "17 - Primary Visual Cortex (V1)",
            "18 - Visual Association Cortex (V2)",
            "19 - V3",
            "20 - Inferior Temporal gyrus",
            "21 - Middle Temporal gyrus",
            "22 - Superior Temporal Gyrus",
            "23 - Ventral Posterior cingulate cortex",
            "24 - Ventral Anterior cingulate cortex",
            "25 - Subgenual cortex",
            "32 - Dorsal anterior cingulate cortex",
            "37 - Fusiform gyrus",
            "38 - Temporopolar area",
            "39 - Angular gyrus, part of Wernicke's area",
            "40 - Supramarginal gyrus part of Wernicke's area",
            "41 - Primary and Auditory Association Cortex",
            "42 - Primary and Auditory Association Cortex",
            "43 - Subcentral area",
            "44 - pars opercularis, part of Broca's area",
            "45 - pars triangularis Broca's area",
            "46 - Dorsolateral prefrontal cortex",
            "47 - Inferior prefrontal gyrus",
            "48 - Retrosubicular area",
            "Brain_Outside",
        ]

        # Numbered landmarks first in numeric order, anything without a number last
        landmarks.sort(key=lambda x: (int(x.split(" - ")[0]) if x.split(" - ")[0].isdigit() else float('inf')))

        # Combine the 20 colors of two colormaps so that every landmark gets its own color.
        # plt.get_cmap is used because matplotlib.cm.get_cmap no longer exists in recent Matplotlib releases.
        cmap1 = plt.get_cmap('tab20')
        cmap2 = plt.get_cmap('tab20b')
        colors = [cmap1(i) for i in range(20)] + [cmap2(i) for i in range(20)] # Total 40 colors
        landmark_color_map = {landmark: colors[i % len(colors)] for i, landmark in enumerate(landmarks)}

    # Iterate over all of the groups
    for group_name in all_results:

        # Format the output to make it slightly easier to read
        logger.info("*" * 60)
        logger.info(f'Landmark Specificity for {group_name}:')
        logger.info("*" * 60)

        axes: Any = None
        if GUI:
            # 7 pie charts per row and at least 4 rows; more rows when the channels do not fit.
            # Only one figure is created, so no empty spare figure is left behind.
            cols = 7
            rows = max(4, -(-len(channel_tables) // cols))
            _, axes = plt.subplots(rows, cols, figsize=(16, 10), constrained_layout=True)
            axes = axes.flatten() # Flatten the axes array for easier indexing

        # Iterate over each channel
        for idx, (channel_name, useful_data) in enumerate(channel_tables):

            # Log the results
            logger.info(f"Channel: {channel_name}")
            logger.info(f"{useful_data}")
            logger.info("-" * 60)

            if GUI:
                # Landmarks that are not part of the list above are drawn in a neutral color
                color_list = [landmark_color_map.get(landmark, 'lightgray') for landmark in useful_data['Landmark']]

                # The wedge label is the number in front of the landmark name, or 'B' for outside of the brain
                labels = [f'{landmark.split(" - ")[0]}' if landmark != 'Brain_Outside' else 'B' for landmark in useful_data['Landmark']]

                # Plot specificity for each channel
                ax = axes[idx] # Use the correct axis for this channel
                ax.pie(
                    useful_data['Specificity'],
                    autopct='%1.1f%%',
                    startangle=90,
                    labels=labels,
                    labeldistance=1.05, # Adjust label position to avoid overlap with the wedges
                    colors=color_list) # Ensure color consistency

                ax.set_title(f'{channel_name}')
                ax.axis('equal') # Equal aspect ratio ensures the pie chart is circular.

        if GUI:
            handles = [
                Line2D([0], [0], marker='o', color='w', label=landmark, markersize=10,
                       markerfacecolor=landmark_color_map[landmark])
                for landmark in landmarks
            ]

            # Size the legend figure based on the number of landmarks
            fig_width = 5
            fig_height = len(landmarks) / 4

            # Create a new figure window for the legend
            legend_fig = plt.figure(figsize=(fig_width, fig_height))
            legend_axes = legend_fig.add_subplot(111)
            legend_axes.axis('off') # Turn off axis for the legend window
            legend_axes.legend(handles=handles, loc='center', fontsize=10, title="Landmarks")

            # Hide the subplots that did not receive a pie chart
            for ax in axes[len(channel_tables):]:
                ax.axis('off')
            plt.show()
|
|
|
|
|
|
|
|
|
|
def brain_landmarks_3d(raw_haemo: BaseRaw, show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all') -> None:
    """
    Render the fsaverage brain with the optodes and a set of right-hemisphere Brodmann areas.

    Args:
        raw_haemo: Recording whose ``info`` provides the sensors to draw and the channel
            names/locations that get logged.
        show_optodes: 'sensors' draws the sensors from ``raw_haemo.info``, 'labels' draws the
            text labels listed in OPTODE_FILE_PATH, 'all' draws both and 'none' draws neither.
    """
    brain = Brain("fsaverage", background="white", size=(800, 700)) # type: ignore
    info = getattr(raw_haemo, "info")

    # Draw the sensors using an identity head -> mri transform
    if show_optodes in ('all', 'sensors'):
        brain.add_sensors(info, trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY) # type: ignore

    # Add optode text labels manually. Each usable line of the file is "<name>: <x> <y> <z>"
    if show_optodes in ('all', 'labels'):
        positions: list[tuple[str, list[float]]] = []
        with open(OPTODE_FILE_PATH, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or ':' not in line:
                    continue # skip empty/malformed lines
                name, coords = line.split(':', 1)
                positions.append((name.strip(), [float(x) for x in coords.strip().split()]))

        # Names starting with 's' are drawn in red, 'd' in blue, anything else in gray
        for name, (x, y, z) in positions:
            brain._renderer.text3d(x, y, z, name, color=('red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'), scale=0.002) # type: ignore

    # Log every channel name with the first three entries of its location.
    # The values are passed as lazy %-args; passing them without placeholders makes
    # the logging module report a formatting error instead of the data.
    for ch in info['chs']:
        logger.info("%s %s", ch['ch_name'], ch['loc'][:3])

    # Add Brodmann labels
    labels = cast(list[mne.Label], mne.read_labels_from_annot("fsaverage", "PALS_B12_Brodmann", "rh", verbose=VERBOSITY)) # type: ignore

    # Only the areas listed here are drawn; several areas intentionally share a color
    label_colors = {
        "Brodmann.39-rh": "blue",
        "Brodmann.40-rh": "green",
        "Brodmann.6-rh": "pink",
        "Brodmann.7-rh": "orange",
        "Brodmann.17-rh": "red",
        "Brodmann.1-rh": "yellow",
        "Brodmann.2-rh": "yellow",
        "Brodmann.3-rh": "yellow",
        "Brodmann.18-rh": "red",
        "Brodmann.19-rh": "red",
        "Brodmann.4-rh": "purple",
        "Brodmann.8-rh": "white",
    }

    for label in labels:
        name = getattr(label, "name", None)
        if not isinstance(name, str):
            continue
        if name in label_colors:
            brain.add_label(label, borders=False, color=label_colors[name]) # type: ignore
|
|
|
|
|
|
|
|
def data_to_csv(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]):
    """
    Write one timestamped CSV per group with the HbO channel results for TARGET_ACTIVITY.

    Files are written to a ``csvs`` folder (created when missing) and are named
    ``<group_name>_<YYYYmmdd_HHMMSS>.csv``. On macOS the folder is resolved relative to the
    directory of ``sys.executable``; on other platforms it lives in the current working
    directory.

    When HRF_MODEL is 'fir', every HbO row whose condition starts with TARGET_ACTIVITY is
    collapsed into one row per (ch_name, ID): numeric columns are averaged and 'Significant'
    becomes a strict-majority vote. Otherwise only the rows whose condition equals
    TARGET_ACTIVITY are written.

    Args:
        all_results: Maps a group name to a 4-tuple of DataFrames; only the second entry
            (the channel-level results) is used.
    """
    logger.info("Getting the current directory...")
    if PLATFORM_NAME == 'darwin':
        # NOTE(review): presumably climbs out of the application bundle - confirm
        csvs_folder = os.path.join(os.path.dirname(sys.executable), "../../../csvs")
    else:
        csvs_folder = os.path.join(os.getcwd(), "csvs")

    logger.info("Attempting to create the csvs folder...")
    os.makedirs(csvs_folder, exist_ok=True)

    # Generate a timestamp to be appended to the end of the file name
    logger.info("Generating the timestamp...")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Averaged columns, plus the columns that are identical inside a (ch_name, ID) group
    agg_funcs = {
        'df': 'mean',
        'mse': 'mean',
        'p_value': 'mean',
        'se': 'mean',
        't': 'mean',
        'theta': 'mean',
        'Source': 'mean',
        'Detector': 'mean',
        'Significant': lambda x: x.sum() > (len(x) / 2),
        'Chroma': 'first',   # assuming all are the same
        'ch_name': 'first',  # same ch_name in the group
        'ID': 'first',       # same ID in the group
    }
    ordered_cols = [
        'Condition', 'df', 'mse', 'p_value', 'se', 't', 'theta',
        'Source', 'Detector', 'Chroma', 'Significant', 'ch_name', 'ID'
    ]

    # Iterate over all groups
    for group_name, (_, df_cha, _, _) in all_results.items():
        save_path = os.path.join(csvs_folder, f"{group_name}_{timestamp}.csv")
        is_hbo = df_cha["Chroma"] == "hbo"

        if HRF_MODEL == 'fir':
            # Keep every condition that belongs to the target activity
            filtered_df = df_cha[df_cha["Condition"].str.startswith(TARGET_ACTIVITY) & is_hbo]

            # Collapse them into a single row per channel and participant
            averaged_df = (
                filtered_df
                .groupby(['ch_name', 'ID'], as_index=False)
                .agg(agg_funcs)
            )
            averaged_df.insert(0, 'Condition', TARGET_ACTIVITY)

            # Averaging turned these integer columns into floats; restore them
            for int_col in ("df", "Source", "Detector"):
                averaged_df[int_col] = averaged_df[int_col].round().astype(int)

            averaged_df = averaged_df[ordered_cols].reset_index(drop=True)
            output_df = averaged_df.sort_values(by=["ID", "Detector", "Source"]).reset_index(drop=True)
        else:
            # A boolean mask avoids building a query string, which breaks when
            # TARGET_ACTIVITY contains a quote character
            output_df = df_cha[(df_cha["Condition"] == TARGET_ACTIVITY) & is_hbo]

        output_df.to_csv(save_path)
|
|
|
|
|
|
|
|
def all_data_to_csv(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]]):
    """
    Write one timestamped ``<group_name>_<YYYYmmdd_HHMMSS>_all.csv`` file per group.

    Files are written to a ``csvs`` folder (created when missing). On macOS the folder is
    resolved relative to the directory of ``sys.executable``; on other platforms it lives in
    the current working directory.

    When HRF_MODEL is 'fir' the channel-level table is written unfiltered. Otherwise only
    the HbO rows whose condition equals TARGET_ACTIVITY are written.

    Args:
        all_results: Maps a group name to a 4-tuple of DataFrames; only the second entry
            (the channel-level results) is used.
    """
    logger.info("Getting the current directory...")
    if PLATFORM_NAME == 'darwin':
        # NOTE(review): presumably climbs out of the application bundle - confirm
        csvs_folder = os.path.join(os.path.dirname(sys.executable), "../../../csvs")
    else:
        csvs_folder = os.path.join(os.getcwd(), "csvs")

    logger.info("Attempting to create the csvs folder...")
    os.makedirs(csvs_folder, exist_ok=True)

    # Generate a timestamp to be appended to the end of the file name
    logger.info("Generating the timestamp...")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Iterate over all groups
    for group_name, (_, df_cha, _, _) in all_results.items():
        save_path = os.path.join(csvs_folder, f"{group_name}_{timestamp}_all.csv")

        if HRF_MODEL == 'fir':
            # The FIR export keeps every row
            output_df = df_cha
        else:
            # A boolean mask avoids building a query string, which breaks when
            # TARGET_ACTIVITY contains a quote character
            output_df = df_cha[(df_cha["Condition"] == TARGET_ACTIVITY) & (df_cha["Chroma"] == "hbo")]

        output_df.to_csv(save_path)
|
|
|
|
|
|
|
|
def brain_3d_contrast(con_model_df: DataFrame, con_model_df_filtered: BaseRaw, common_channels: list[str], first_name: str, second_name: str, first_stim: float, second_stim: float, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None:
    """
    Plot a contrast between two groups on the 3D brain.

    Args:
        con_model_df: Contrast values with at least a ``ch_name`` column. Its ``ch_name``
            column is converted to an ordered categorical in place.
        con_model_df_filtered: Recording (despite its name) restricted to the channels to
            plot; its HbO channels and ``info`` are used.
        common_channels: Channel names in the order the rows should be sorted.
        first_name: Name of the first group of the contrast (shown in the overlay text).
        second_name: Name of the second group of the contrast (shown in the overlay text).
        first_stim: Stimulus duration of the first group (shown unless HRF_MODEL is 'fir').
        second_stim: Stimulus duration of the second group (shown unless HRF_MODEL is 'fir').
        t_or_theta: Selects whether ABS_T_VALUE or ABS_THETA_VALUE bounds the color scale.
        show_optodes: 'sensors' draws the sensors, 'labels' draws the text labels listed in
            OPTODE_FILE_PATH, 'all' draws both and 'none' draws neither.
        show_text: Whether to draw the descriptive overlay text.

    Raises:
        ValueError: If ``t_or_theta`` is neither 't' nor 'theta'.
    """
    # Pick the color limit first so an unsupported value fails with a clear message
    # instead of an unbound `clim` further down
    if t_or_theta == 't':
        clim_max = ABS_T_VALUE
    elif t_or_theta == 'theta':
        clim_max = ABS_THETA_VALUE
    else:
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}")
    clim = dict(kind="value", pos_lims=(0, clim_max / 2, clim_max))

    # Sort the rows into the order given by common_channels.
    # NOTE: the categorical column is written into the caller's DataFrame
    con_model = con_model_df
    con_model["ch_name"] = pd.Categorical(
        con_model["ch_name"], categories=common_channels, ordered=True
    )
    con_model = con_model.sort_values("ch_name").reset_index(drop=True) # type: ignore

    # Plot brain figure
    brain = plot_3d_evoked_array(con_model_df_filtered.copy().pick(picks="hbo"), con_model, view="dorsal", distance=BRAIN_DISTANCE, colorbar=True, mode=BRAIN_MODE, clim=clim, size=(800, 700), verbose=VERBOSITY) # type: ignore

    if show_optodes in ('all', 'sensors'):
        brain.add_sensors(getattr(con_model_df_filtered, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=VERBOSITY) # type: ignore

    # Read and parse the optode file. Each usable line is "<name>: <x> <y> <z>"
    if show_optodes in ('all', 'labels'):
        positions: list[tuple[str, list[float]]] = []
        with open(OPTODE_FILE_PATH, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or ':' not in line:
                    continue # skip empty/malformed lines
                name, coords = line.split(':', 1)
                positions.append((name.strip(), [float(x) for x in coords.strip().split()]))

        # Names starting with 's' are drawn in red, 'd' in blue, anything else in gray
        for name, (x, y, z) in positions:
            brain._renderer.text3d(x, y, z, name, color=('red' if name.startswith('s') else 'blue' if name.startswith('d') else 'gray'), scale=0.002) # type: ignore

    # Apply the text onto the brain
    if show_text:
        text_lines = [
            'Folder: ' + str(BASE_SNIRF_FOLDER),
            'Contrast: ' + first_name + ' - ' + second_name,
            'Short Channel Regression: ' + str(SHORT_CHANNEL_REGRESSION),
        ]
        # The stimulus durations are only reported for non-FIR models
        if HRF_MODEL != 'fir':
            text_lines.append('Stim Duration: ' + str(first_stim) + ', ' + str(second_stim))
        text_lines += [
            'Brain Distance: ' + str(BRAIN_DISTANCE),
            'Brain Mode: ' + BRAIN_MODE,
            'Looking at: ' + t_or_theta + ' values',
        ]
        brain.add_text(0.12, 0.70, '\n'.join(text_lines), "title", font_size=11, color="k") # type: ignore
|
|
|
|
|
|
|
|
def _plot_directed_group_contrast(raw_haemo: BaseRaw, contrast: "pd.Series", first_name: str, second_name: str, first_stim: float, second_stim: float, vlim_value: float, t_or_theta: Literal['t', 'theta'], show_optodes: Literal['sensors', 'labels', 'none', 'all'], show_text: bool) -> None:
    """Render the 3D brain and the 2D topography for one directed contrast (first - second)."""
    # The order and names of the keys in the DataFrame are important!
    con_model_df = DataFrame({
        "ch_name": contrast.index,
        "Coef.": contrast,
        "Chroma": "hbo"
    })

    mne_ch_names = getattr(raw_haemo.copy().pick(picks="hbo"), "ch_names") # type: ignore
    glm_ch_names = set(con_model_df["ch_name"].tolist())

    # Get ordered common channels (recording order, restricted to the modelled channels)
    common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names]

    # Filter raw data to these channels
    raw_common = raw_haemo.copy().pick(picks=common_channels) # type: ignore

    # Reindex GLM results to match MNE channel order
    con_model_df = con_model_df.set_index("ch_name").loc[common_channels].reset_index() # type: ignore

    # Create the 3d visualization
    brain_3d_contrast(con_model_df, raw_common, common_channels, first_name, second_name, first_stim, second_stim, t_or_theta, show_optodes, show_text)

    # Create the 2d visualization
    plot_glm_group_topo(raw_common.copy().pick(picks="hbo"), con_model_df, names=True, res=128, vlim=(-vlim_value, vlim_value)) # type: ignore

    # TODO: The title currently goes on the colorbar. Low priority
    plt.title(f"Contrast: {first_name} vs {second_name}") # type: ignore
    plt.show() # type: ignore


def plot_2d_3d_contrasts_between_groups(all_results: dict[str, tuple[DataFrame, DataFrame, DataFrame, DataFrame]], all_raw_haemo: dict[str, dict[str, BaseRaw]], t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True) -> None:
    """
    Fit one mixed-effects model over all groups and plot every pairwise group contrast.

    For each pair of groups both directions (a - b and b - a) are rendered, each as a 3D
    brain figure and a 2D topographic map.

    Args:
        all_results: Maps a group name to a 4-tuple of DataFrames; only the third entry
            (the contrast results) is used. A ``group`` column is added to it in place.
        all_raw_haemo: Nested mapping of recordings; the "full_layout" entry of the first
            item supplies the channel layout for plotting.
        t_or_theta: 't' contrasts the model's t-values, 'theta' contrasts its coefficients.
        show_optodes: Forwarded to brain_3d_contrast.
        show_text: Forwarded to brain_3d_contrast.

    Raises:
        ValueError: If ``t_or_theta`` is neither 't' nor 'theta'.
    """
    # Fail before the model fit rather than on an unbound variable inside the loop
    if t_or_theta not in ('t', 'theta'):
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}")

    # Recording that provides the channel layout (first entry of all_raw_haemo)
    raw_haemo = all_raw_haemo[list(all_raw_haemo.keys())[0]]["full_layout"]

    # Store all contrasts with the corresponding group name
    group_dfs: dict[str, DataFrame] = {}
    for group_name, (_, _, df_con, _) in all_results.items():
        # NOTE: the "group" column is written into the caller's DataFrame
        df_con["group"] = group_name
        group_dfs[group_name] = df_con

    # Concatenate all groups together and keep the HbO rows
    df_combined = pd.concat(group_dfs.values(), ignore_index=True)
    con_summary = df_combined.query("Chroma == 'hbo'").copy() # type: ignore

    # Keep only the channels that occur more than once in every group
    valid_channels = cast(DataFrame, (pd.crosstab(con_summary['group'], con_summary['ch_name']) > 1).all()) # type: ignore
    valid_channels = valid_channels[valid_channels].index.tolist()
    con_summary = con_summary[con_summary['ch_name'].isin(valid_channels)] # type: ignore

    # Fit the mixed-effects model with the participant ID as the grouping factor
    model_formula = "effect ~ -1 + group:ch_name:Chroma"
    con_model = smf.mixedlm(model_formula, con_summary, groups=con_summary["ID"]).fit(method="nm") # type: ignore

    # Pick the per-term estimates to contrast and the matching color limit
    if t_or_theta == 't':
        estimates = con_model.tvalues
        vlim_value = ABS_CONTRAST_T_VALUE
    else:
        estimates = con_model.params
        vlim_value = ABS_CONTRAST_THETA_VALUE

    # Get all the group names from the dictionary
    group_names = list(group_dfs.keys())

    for i, group1_name in enumerate(group_names):
        for j in range(i + 1, len(group_names)):
            group2_name = group_names[j]

            # Extract the model terms that belong to each group
            group1_vals = estimates.filter(like=f"group[{group1_name}]") # type: ignore
            group2_vals = estimates.filter(like=f"group[{group2_name}]") # type: ignore

            # TODO: Does this work for all separators?
            # Extract channel names: the text inside the brackets of the second ':' part
            group1_channels: list[str] = [
                name.split(":")[1].split("[")[1].split("]")[0]
                for name in getattr(group1_vals, "index")
            ]
            group2_channels: list[str] = [
                name.split(":")[1].split("[")[1].split("]")[0]
                for name in getattr(group2_vals, "index")
            ]

            # Create the DataFrames with channel indices
            df_group1 = DataFrame({"Coef.": group1_vals.values}, index=group1_channels) # type: ignore
            df_group2 = DataFrame({"Coef.": group2_vals.values}, index=group2_channels) # type: ignore

            # Merge the two DataFrames on the channel names
            df_contrast = df_group1.join(df_group2, how="inner", lsuffix=f"_{group1_name}", rsuffix=f"_{group2_name}") # type: ignore

            # Compute the contrasts
            contrast_1_2 = df_contrast[f"Coef._{group1_name}"] - df_contrast[f"Coef._{group2_name}"]
            contrast_2_1 = df_contrast[f"Coef._{group2_name}"] - df_contrast[f"Coef._{group1_name}"]

            # a-b / 1-2 contrast, then b-a / 2-1 contrast.
            # NOTE(review): STIM_DURATION is indexed by group position - confirm it is ordered like all_results
            _plot_directed_group_contrast(raw_haemo, contrast_1_2, group1_name, group2_name, STIM_DURATION[i], STIM_DURATION[j], vlim_value, t_or_theta, show_optodes, show_text)
            _plot_directed_group_contrast(raw_haemo, contrast_2_1, group2_name, group1_name, STIM_DURATION[j], STIM_DURATION[i], vlim_value, t_or_theta, show_optodes, show_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: Is any of this still useful?
|
|
|
|
def calculate_annotations(raw_haemo_filtered, file_name, output_folder=None, save_images=None):
    """
    Extract the events from the annotations of the data.

    Args:
        raw_haemo_filtered (RawSNIRF): The filtered haemoglobin concentration data.
        file_name (str): The file name of the current file, used in the image name.
        output_folder (str): (optional) Where to save the image. Required when
            ``save_images`` is truthy. Default is None.
        save_images (bool): (optional) Whether to save a plot of the events. Default is None.

    Returns:
        events (ndarray): The events returned by ``mne.events_from_annotations``.
        event_dict (dict): The event names mapped to their ids.

    Raises:
        ValueError: If ``save_images`` is truthy but no ``output_folder`` was given.
    """
    # Fail before any work is done instead of crashing on `None + str` when saving
    if save_images and output_folder is None:
        raise ValueError("output_folder must be provided when save_images is enabled")

    # Get when the events occur and what they are called
    events, event_dict = mne.events_from_annotations(raw_haemo_filtered)

    # Do we save the image?
    if save_images:
        fig = mne.viz.plot_events(events, event_id=event_dict, sfreq=raw_haemo_filtered.info["sfreq"], show=False)
        save_path = os.path.join(output_folder, "8. Annotations for " + file_name + ".png")
        fig.savefig(save_path)

    return events, event_dict
|
|
|
|
|
|
|
|
def calculate_good_epochs(raw_haemo_filtered, events, event_dict, file_name, tmin=None, tmax=None, reject_thresh=None, target_activity=None, target_control=None, output_folder=None, save_images=None):
    """
    Calculate which epochs are good and optionally save a graph of the dropped ones.

    Args:
        raw_haemo_filtered (RawSNIRF): The filtered haemoglobin concentration data.
        events (ndarray): The events to epoch around.
        event_dict (dict): The event names mapped to their ids.
        file_name (str): The file name of the current file, used in the image name.
        tmin (float): (optional) Epoch start relative to the event. Default is TIME_MIN_THRESH.
        tmax (float): (optional) Epoch end relative to the event. Default is TIME_MAX_THRESH.
        reject_thresh (float): (optional) HbO rejection threshold. Default is
            EPOCH_REJECT_CRITERIA_THRESH.
        target_activity (str): (optional) The target activity. Default is TARGET_ACTIVITY.
        target_control (str): (optional) The target control. Default is TARGET_CONTROL.
        output_folder (str): (optional) Where to save the image. Required when
            ``save_images`` is truthy. Default is None.
        save_images (bool): (optional) Whether to save the drop log. Default is None.

    Returns:
        good_epochs (Epochs): The remaining good epochs.
        all_epochs (Epochs): The epochs created without the HbO rejection threshold.

    Raises:
        ValueError: If ``save_images`` is truthy but no ``output_folder`` was given.
    """
    # Fail before any work is done instead of crashing on `None + str` when saving
    if save_images and output_folder is None:
        raise ValueError("output_folder must be provided when save_images is enabled")

    if tmin is None:
        tmin = TIME_MIN_THRESH
    if tmax is None:
        tmax = TIME_MAX_THRESH
    if reject_thresh is None:
        reject_thresh = EPOCH_REJECT_CRITERIA_THRESH
    if target_activity is None:
        target_activity = TARGET_ACTIVITY
    if target_control is None:
        target_control = TARGET_CONTROL

    # Settings shared by both sets of epochs
    epoch_kwargs = dict(
        event_id=event_dict,
        tmin=tmin,
        tmax=tmax,
        proj=True,
        baseline=(None, 0),
        preload=True,
        detrend=None,
        verbose=True,
    )

    # Get all the good epochs
    good_epochs = mne.Epochs(
        raw_haemo_filtered,
        events,
        reject=dict(hbo=reject_thresh),
        reject_by_annotation=True,
        **epoch_kwargs,
    )

    # Get all the epochs
    all_epochs = mne.Epochs(raw_haemo_filtered, events, **epoch_kwargs)

    if REJECT_PAIRS:
        # Calculate which epochs were in all but not in good
        all_idx = all_epochs.selection
        good_idx = good_epochs.selection
        bad_idx = np.setdiff1d(all_idx, good_idx)

        # Split the controls and the activities
        event_ids = all_epochs.events[:, 2]
        control_id = event_dict[target_control]
        activity_id = event_dict[target_activity]

        to_reject_extra = set()

        for i, idx in enumerate(all_idx):
            if idx not in bad_idx:
                continue
            ev = event_ids[i]
            # If the control was bad, drop the following activity
            if ev == control_id and i + 1 < len(all_idx):
                if event_ids[i + 1] == activity_id:
                    to_reject_extra.add(all_idx[i + 1])
            # If the activity was bad, drop the preceding control
            if ev == activity_id and i - 1 >= 0:
                if event_ids[i - 1] == control_id:
                    to_reject_extra.add(all_idx[i - 1])

        # Translate the extra drops into positions inside good_epochs, only keeping
        # the ones that are currently classified as good
        drop_idx_in_good = [
            np.where(good_idx == idx)[0][0] for idx in to_reject_extra if idx in good_idx
        ]

        # Drop the pairings of the bad epochs
        good_epochs.drop(drop_idx_in_good)

    # Do we save the image?
    if save_images:
        drop_log_fig = good_epochs.plot_drop_log(show=False)
        save_path = os.path.join(output_folder, "8. Epoch drops for " + file_name + ".png")
        drop_log_fig.savefig(save_path)

    return good_epochs, all_epochs
|
|
|
|
|
|
|
|
def bad_check(raw_od, max_bad_channels=None):
    """
    Check whether the number of bad channels has reached the allowed threshold.

    Args:
        raw_od (RawSNIRF): The optical density data; its ``info['bads']`` list is counted
            (a missing 'bads' key counts as no bad channels).
        max_bad_channels (int): (optional) The number of bad channels at which the data is
            flagged. Default is MAX_BAD_CHANNELS.

    Returns:
        bool: True if the number of bad channels is greater than or equal to the threshold,
        False otherwise.
    """
    if max_bad_channels is None:
        max_bad_channels = MAX_BAD_CHANNELS

    # Reaching the threshold already counts as too many, not only exceeding it
    return len(raw_od.info.get('bads', [])) >= max_bad_channels
|
|
|
|
|
|
|
|
def remove_bad_epoch_pairings(raw_haemo_filtered_minus_short, good_epochs, epoch_pair_tolerance_window=None):
    """
    Keep only the annotations that line up with a good epoch, so the GLM does not see the rest.

    Works on a copy of the input from which the channels marked as bad are dropped.

    Args:
        raw_haemo_filtered_minus_short (RawSNIRF): The filtered haemoglobin concentration data.
        good_epochs (Epochs): The epochs whose events the annotations are matched against.
        epoch_pair_tolerance_window (int): (optional) How many samples an annotation may
            deviate from a good event and still match it. Default is
            EPOCH_PAIR_TOLERANCE_WINDOW.

    Returns:
        raw_haemo_filtered_good_epochs (RawSNIRF): The copy carrying only the kept annotations.
    """
    if epoch_pair_tolerance_window is None:
        epoch_pair_tolerance_window = EPOCH_PAIR_TOLERANCE_WINDOW

    # Copy the input haemoglobin concentration data and drop the bad channels
    raw_haemo_filtered_good_epochs = raw_haemo_filtered_minus_short.copy()
    raw_haemo_filtered_good_epochs = raw_haemo_filtered_good_epochs.drop_channels(raw_haemo_filtered_good_epochs.info['bads'])

    # Get the sample indices of the good events
    good_event_samples = set(good_epochs.events[:, 0])
    logger.info(f"Total good events (epochs): {len(good_event_samples)}")

    # Get the current annotations
    raw_annots = raw_haemo_filtered_good_epochs.annotations

    # Create lists to use for processing
    clean_descriptions = []
    clean_onsets = []
    clean_durations = []
    dropped = []

    # Get the frequency of the file
    sfreq = raw_haemo_filtered_good_epochs.info['sfreq']

    for desc, onset, dur in zip(raw_annots.description, raw_annots.onset, raw_annots.duration):
        # Force-dropped annotations are discarded outright. The check lives at this level
        # so that `continue` skips the annotation itself rather than an inner loop
        if FORCE_DROP_ANNOTATIONS and any(desc == forced for forced in FORCE_DROP_ANNOTATIONS):
            dropped.append((desc, onset))
            continue

        # Convert annotation onset time to sample index
        sample = int(onset * sfreq)

        # Check if the annotation is within the tolerance of any good event
        matched = any(abs(sample - event_sample) <= epoch_pair_tolerance_window for event_sample in good_event_samples)

        if matched:
            clean_descriptions.append(desc)
            clean_onsets.append(onset)
            clean_durations.append(dur)
        else:
            dropped.append((desc, onset))

    # Create the new filtered annotations
    new_annots = Annotations(
        onset=clean_onsets,
        duration=clean_durations,
        description=clean_descriptions,
    )

    # Assign the new annotations
    raw_haemo_filtered_good_epochs.set_annotations(new_annots)

    # Log the results
    logger.info(f"Original annotations: {len(raw_annots)}")
    logger.info(f"Kept annotations: {len(clean_descriptions)}")
    # The set is passed as a %-arg; without a placeholder logging reports a formatting error
    logger.info("Kept annotation types: %s", set(clean_descriptions))
    if dropped:
        logger.info(f"Dropped annotations: {len(dropped)}")
        logger.info("Dropped annotations:")
        for desc, onset in dropped:
            logger.info(f" - {desc} at {onset:.2f}s")
    else:
        logger.info("No annotations were dropped!")

    return raw_haemo_filtered_good_epochs