"""
|
||
Filename: flares.py
|
||
Description: Core functionality for FLARES
|
||
|
||
Author: Tyler de Zeeuw
|
||
License: GPL-3.0
|
||
"""
|
||
|
||

# Built-in imports
import os
import sys
import platform
import threading
import logging
from io import BytesIO
from typing import Any, Optional, cast, Literal, Union
from itertools import compress
from copy import deepcopy
from multiprocessing import Queue
import os.path as op
import re
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed

# External library imports
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D

import numpy as np
from numpy.typing import NDArray
from numpy import float64, floating

import pandas as pd
from pandas import DataFrame

import seaborn as sns
import h5py

from nilearn.plotting import plot_design_matrix  # type: ignore
from nilearn.glm.regression import OLSModel

import statsmodels.formula.api as smf  # type: ignore
from statsmodels.stats.multitest import multipletests

from scipy import stats
from scipy.spatial.distance import cdist
from scipy.signal import welch, butter, filtfilt  # type: ignore

import pywt  # type: ignore
import neurokit2 as nk  # type: ignore

# Backend visualization needs to be defined explicitly for PyInstaller
import pyvistaqt  # type: ignore
# import vtkmodules.util.data_model
# import vtkmodules.util.execution_model

# External library imports for mne
from mne import (
    EvokedArray, SourceEstimate, Info, Epochs, Label, Annotations,
    events_from_annotations, read_source_spaces,
    stc_near_sensors, pick_types, grand_average, get_config, set_config, read_labels_from_annot
)  # type: ignore
from mne.source_space import SourceSpaces
from mne.transforms import Transform  # type: ignore
from mne.io import BaseRaw, RawArray, read_raw_snirf  # type: ignore
from mne.preprocessing.nirs import (
    beer_lambert_law, optical_density,
    temporal_derivative_distribution_repair,
    source_detector_distances, short_channels
)  # type: ignore
from mne.viz import Brain, plot_events, plot_evoked_topo, plot_compare_evokeds
from mne.filter import filter_data  # type: ignore
from mne.utils import _check_fname, _validate_type, warn
from mne.channels import make_standard_montage
from mne.datasets.sample import data_path

from mne_nirs.visualisation import plot_glm_group_topo  # type: ignore
from mne_nirs.channels import get_long_channels, get_short_channels  # type: ignore
from mne_nirs.experimental_design import make_first_level_design_matrix  # type: ignore
from mne_nirs.statistics import run_glm, statsmodels_to_results  # type: ignore
from mne_nirs.signal_enhancement import (
    enhance_negative_correlation, short_channel_regression
)  # type: ignore
from mne_nirs.io.fold import fold_channel_specificity  # type: ignore
from mne_nirs.preprocessing import peak_power  # type: ignore
from mne_nirs.statistics._glm_level_first import RegressionResults  # type: ignore


os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects"  # type: ignore


AGE: float
GENDER: str

# SECONDS_TO_STRIP: int
DOWNSAMPLE: bool
DOWNSAMPLE_FREQUENCY: int

TRIM: bool
SECONDS_TO_KEEP: float

OPTODE_PLACEMENT: bool

HEART_RATE: bool

SCI: bool
SCI_TIME_WINDOW: int
SCI_THRESHOLD: float

SNR: bool
# SNR_TIME_WINDOW : int
SNR_THRESHOLD: float

PSP: bool
PSP_TIME_WINDOW: int
PSP_THRESHOLD: float

TDDR: bool

WAVELET: bool
IQR: float
WAVELET_TYPE: str
WAVELET_LEVEL: int

HEART_RATE = True  # True if heart rate should be calculated. This helps the SCI, PSP, and SNR methods to be more accurate.
SECONDS_TO_STRIP_HR = 5  # Number of seconds to temporarily strip from the data to calculate heart rate more effectively. Useful if the participant removed the cap while still recording.
MAX_LOW_HR = 40  # Any heart rate values lower than this will be set to this value.
MAX_HIGH_HR = 200  # Any heart rate values higher than this will be set to this value.
SMOOTHING_WINDOW_HR = 100  # Heart rate will be calculated as a rolling average over this many samples.
HEART_RATE_WINDOW = 25  # Number of BPM above and below the calculated average to use for a range of resting BPM.

ENHANCE_NEGATIVE_CORRELATION: bool

FILTER: bool
L_FREQ: float
H_FREQ: float

SHORT_CHANNEL: bool
SHORT_CHANNEL_THRESH: float
LONG_CHANNEL_THRESH: float

REMOVE_EVENTS: list

TIME_WINDOW_START: int
TIME_WINDOW_END: int

DRIFT_MODEL: str

VERBOSITY = True

# FIXME: Shouldn't need each ordering - just order it before checking
FIXED_CATEGORY_COLORS = {
    "SCI only": "skyblue",
    "PSP only": "salmon",
    "SNR only": "lightgreen",
    "PSP + SCI": "orange",
    "SCI + SNR": "violet",
    "PSP + SNR": "gold",
    "SCI + PSP": "orange",
    "SNR + SCI": "violet",
    "SNR + PSP": "gold",
    "PSP + SNR + SCI": "gray",
    "SCI + PSP + SNR": "gray",
    "SCI + SNR + PSP": "gray",
    "PSP + SCI + SNR": "gray",
    "SNR + SCI + PSP": "gray",
    "SNR + PSP + SCI": "gray",
}


AGE = 25
GENDER = ""
GROUP = "Default"

REQUIRED_KEYS: dict[str, Any] = {

    # "SECONDS_TO_STRIP": int,
    "DOWNSAMPLE": bool,
    "DOWNSAMPLE_FREQUENCY": int,

    "TRIM": bool,
    "SECONDS_TO_KEEP": float,

    "OPTODE_PLACEMENT": bool,

    "HEART_RATE": bool,

    "SCI": bool,
    "SCI_TIME_WINDOW": int,
    "SCI_THRESHOLD": float,

    "SNR": bool,
    # SNR_TIME_WINDOW : int
    "SNR_THRESHOLD": float,

    "PSP": bool,
    "PSP_TIME_WINDOW": int,
    "PSP_THRESHOLD": float,

    "SHORT_CHANNEL": bool,
    "SHORT_CHANNEL_THRESH": float,
    "LONG_CHANNEL_THRESH": float,

    "REMOVE_EVENTS": list,
    "TIME_WINDOW_START": int,
    "TIME_WINDOW_END": int,
    "L_FREQ": float,
    "H_FREQ": float,

    "TDDR": bool,
    "WAVELET": bool,
    "IQR": float,
    "WAVELET_TYPE": str,
    "WAVELET_LEVEL": int,
    "FILTER": bool,
    "DRIFT_MODEL": str,
    # "REJECT_PAIRS": bool,
    # "FORCE_DROP_ANNOTATIONS": list,
    # "FILTER_LOW_PASS": float,
    # "FILTER_HIGH_PASS": float,
    # "EPOCH_PAIR_TOLERANCE_WINDOW": int,
}


class ProcessingError(Exception):
    def __init__(self, message: str = "Something went wrong!"):
        self.message = message
        super().__init__(self.message)


# Ensure that we are working in the directory of this file
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)

PLATFORM_NAME = platform.system().lower()

# Configure logging to file with timestamps and realtime flush
if PLATFORM_NAME == 'darwin':
    logging.basicConfig(
        filename=os.path.join(os.path.dirname(sys.executable), "../../../fnirs_analysis.log"),
        level=logging.INFO,
        format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        filemode='a'
    )

else:
    logging.basicConfig(
        filename='fnirs_analysis.log',
        level=logging.INFO,
        format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        filemode='a'
    )

logger = logging.getLogger()


def set_config_me(config: dict[str, Any]) -> None:
    """
    Applies the given configuration dictionary to this module's globals.

    Parameters
    ----------
    config : dict[str, Any]
        Dictionary containing configuration keys and their values.
    """
    logger.info("[DEBUG] set_config_me called")

    globals().update(config)
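

# Illustrative sketch (not called by the pipeline): building a configuration dictionary
# shaped like REQUIRED_KEYS and applying it with set_config_me(). The specific values
# below are hypothetical placeholders, not recommended defaults.
def _example_apply_config() -> None:
    example_config = {
        "DOWNSAMPLE": False,
        "DOWNSAMPLE_FREQUENCY": 5,
        "TRIM": False,
        "SECONDS_TO_KEEP": 300.0,
        "SCI": True,
        "SCI_TIME_WINDOW": 5,
        "SCI_THRESHOLD": 0.7,
    }
    # Keys that are omitted simply keep their current module-level values.
    set_config_me(example_config)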


def set_metadata(file_path, metadata: dict[str, Any]) -> None:
    """
    Applies the per-file metadata (age, gender, group) to this module's globals,
    falling back to the defaults when a value is missing or empty.

    Parameters
    ----------
    file_path : str
        Path of the file whose metadata should be applied.
    metadata : dict[str, Any]
        Dictionary mapping file paths to their metadata values.
    """
    logger.info("[DEBUG] set_metadata called")

    globals()['AGE'] = 25
    globals()['GENDER'] = ""
    globals()['GROUP'] = "Default"

    if metadata.get(file_path) is not None:

        file_metadata = metadata.get(file_path, {})

        for key in ("AGE", "GENDER", "GROUP"):
            val = file_metadata.get(key, None)
            if val not in (None, '', [], {}, ()):  # check for "empty" values
                globals()[key] = val


from queue import Empty  # This works with multiprocessing.Manager().Queue()


def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -> None:
    def forward_progress():
        while True:
            try:
                msg = progress_queue.get(timeout=1)
                if msg == "__done__":
                    break
                gui_queue.put(msg)
            except Empty:
                continue
            except Exception as e:
                gui_queue.put({
                    "type": "error",
                    "error": f"Forwarding thread crashed: {e}",
                    "traceback": traceback.format_exc()
                })
                break

    t = threading.Thread(target=forward_progress, daemon=True)
    t.start()

    try:
        file_paths = config['SNIRF_FILES']
        file_params = config['PARAMS']
        file_metadata = config['METADATA']
        max_workers = file_params.get("MAX_WORKERS", max(1, (os.cpu_count() or 1) // 4))

        results = process_multiple_participants(
            file_paths, file_params, file_metadata, progress_queue, max_workers
        )

        gui_queue.put({"success": True, "result": results})

    except Exception as e:
        gui_queue.put({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        })

    finally:
        # Always send done to the thread and avoid hanging
        try:
            progress_queue.put("__done__")
        except Exception:
            pass
        t.join(timeout=5)  # prevent permanent hang
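

# Illustrative sketch (an assumption, based on the Empty-import note above): a GUI front end
# would typically run gui_entry in a separate process with multiprocessing.Manager() queues
# and read result/progress dictionaries back from gui_queue. The file path and parameter
# values below are hypothetical.
def _example_launch_gui_entry() -> None:
    import multiprocessing as mp

    manager = mp.Manager()
    gui_queue, progress_queue = manager.Queue(), manager.Queue()
    config = {
        "SNIRF_FILES": ["/path/to/participant.snirf"],  # hypothetical recording
        "PARAMS": {"MAX_WORKERS": 2},                   # hypothetical parameters
        "METADATA": {},
    }
    worker = mp.Process(target=gui_entry, args=(config, gui_queue, progress_queue))
    worker.start()
    print(gui_queue.get())  # blocks until a success or error dictionary arrives
    worker.join()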


def process_participant_worker(args):
    file_path, file_params, file_metadata, progress_queue = args

    set_config_me(file_params)
    set_metadata(file_path, file_metadata)
    logger.info(f"DEBUG: Metadata for {file_path}: AGE={globals().get('AGE')}, GENDER={globals().get('GENDER')}, GROUP={globals().get('GROUP')}")

    def progress_callback(step_idx):
        if progress_queue:
            progress_queue.put(('progress', file_path, step_idx))

    try:
        result = process_participant(file_path, progress_callback=progress_callback)
        return file_path, result, None
    except Exception as e:
        error_trace = traceback.format_exc()
        return file_path, None, (str(e), error_trace)


def process_multiple_participants(file_paths, file_params, file_metadata, progress_queue=None, max_workers=None):
    results_by_file = {}

    file_args = [(file_path, file_params, file_metadata, progress_queue) for file_path in file_paths]

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process_participant_worker, arg): arg[0] for arg in file_args}

        for future in as_completed(futures):
            file_path = futures[future]
            try:
                file_path, result, error = future.result()
                if error:
                    error_message, error_traceback = error
                    if progress_queue:
                        progress_queue.put({
                            "type": "error",
                            "file": file_path,
                            "error": error_message,
                            "traceback": error_traceback
                        })
                    continue

                results_by_file[file_path] = result
            except Exception as e:
                print(f"Unexpected error processing {file_path}: {e}")

    return results_by_file


def markbad(data, ax, ch_names: list[str]) -> None:
    """
    Add a strikethrough to a plot for channels marked as bad.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    ax : Axes
        Matplotlib Axes object where the strikethrough lines will be drawn.
    ch_names : list[str]
        List of channel names corresponding to the y-axis of the plot.
    """

    # Iterate over all the channels
    for i, ch in enumerate(ch_names):

        # If it is marked as bad, place a strikethrough on the channel
        if ch in data.info["bads"]:
            ax.axhline(i + 0.5, ls="solid", lw=4, color="black", zorder=10)  # type: ignore


def plot_timechannel_quality_metrics(data, scores, times: list[tuple[float]], color_stops: tuple[list[float], list[float]], threshold: float, title: Optional[str] = None):
    """
    Generate two heatmaps visualizing channel quality metrics over time.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    scores : NDArray[float64]
        A 2D array of quality scores for each channel over time.
    times : list[tuple[float]]
        List of time boundaries used to label each score column.
    color_stops : tuple[list[float], list[float]]
        Two lists of color values for custom colormaps.
    threshold : float
        Threshold value for the color bar.
    title : Optional[str], optional
        Base title for the figures (default is None).

    Returns
    -------
    tuple[Figure, Figure]
        - Figure: Heatmap of all scores across channels and time.
        - Figure: Binary heatmap showing only scores above the threshold.
    """

    # Get only the hbo / hbr channels once as we don't need to see the same results twice
    half_ch = len(getattr(data, "ch_names")) // 2
    ch_names = getattr(data, "ch_names")[:half_ch]
    scores = scores[:half_ch, :]

    # Extract rounded time points to use as column headers
    cols = [np.round(t[0]) for t in times]
    n_chans = len(ch_names)
    vsize = 0.2 * n_chans

    # Create the first figure
    fig1, ax1 = plt.subplots(figsize=(10, vsize), layout="constrained")  # type: ignore
    fig1.suptitle(title + " - All Scores", fontsize=16, fontweight="bold")  # type: ignore

    # Create a DataFrame to structure data for the heatmap
    data_to_plot = DataFrame(
        data=scores,
        columns=pd.Index(cols, name="Time (s)"),
        index=pd.Index(ch_names, name="Channel"),
    )

    # Define a custom colormap using provided color stops and base colors
    base_colors = ['red', 'red', 'yellow', 'green', 'green']
    colors = list(zip(color_stops[0], base_colors[:len(color_stops[0])]))
    cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors)

    # Plot heatmap of scores
    sns.heatmap(  # type: ignore
        data=data_to_plot,
        cmap=cmap,
        vmin=0,
        vmax=1,
        cbar_kws=dict(label="Score"),
        ax=ax1,
    )

    # Add vertical dashed lines at each time boundary, set the title, and place a black strikethrough through each bad channel
    for x in range(1, len(times)):
        ax1.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray")  # type: ignore
    ax1.set_title("All Scores", fontweight="bold")  # type: ignore
    markbad(data, ax1, ch_names)

    # Calculate average score per channel and annotate to the right of the heatmap
    avg_sci_subset: pd.Series[float] = data_to_plot.mean(axis=1)  # type: ignore
    norm = mcolors.Normalize(vmin=0, vmax=1)
    text_x = data_to_plot.shape[1] + 0.5
    for i, val in enumerate(avg_sci_subset):
        color = cmap(norm(val))
        ax1.text(  # type: ignore
            text_x,
            i + 0.5,
            f"{val:.3f}",
            va='center',
            ha='left',
            fontsize=9,
            color=color
        )
    ax1.set_xlim(right=text_x + 1.5)

    plt.close(fig1)

    # Create the second figure
    fig2, ax2 = plt.subplots(figsize=(10, vsize), layout="constrained")  # type: ignore
    fig2.suptitle(title + " - Scores Above Threshold", fontsize=16, fontweight="bold")  # type: ignore

    # Create a DataFrame to structure data for the heatmap
    data_to_plot = DataFrame(
        data=scores > threshold,
        columns=pd.Index(cols, name="Time (s)"),
        index=pd.Index(ch_names, name="Channel"),
    )

    # Define a custom colormap using provided color stops and base colors
    base_colors = ['red', 'red', 'white', 'white']
    colors = list(zip(color_stops[1], base_colors[:len(color_stops[1])]))
    cmap = mcolors.LinearSegmentedColormap.from_list('gyr', colors)

    # Plot heatmap of scores
    sns.heatmap(  # type: ignore
        data=data_to_plot,
        vmin=0,
        vmax=1,
        cmap=cmap,
        cbar_kws=dict(label="Score"),
        ax=ax2,
    )

    # Add vertical dashed lines at each time boundary, set the title, and place a black strikethrough through each bad channel
    for x in range(1, len(times)):
        ax2.axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray")  # type: ignore
    ax2.set_title("Scores > Threshold", fontweight="bold")  # type: ignore
    markbad(data, ax2, ch_names)

    plt.close(fig2)

    return fig1, fig2


def scalp_coupling_index_windowed_raw(data, time_window: float = 3.0, l_freq: float = 0.7, h_freq: float = 1.5, l_trans_bandwidth: float = 0.3, h_trans_bandwidth: float = 0.3):
    """
    Compute windowed scalp coupling index (SCI) across fNIRS channels.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    time_window : float, optional
        Length of each time window in seconds (default is 3.0).
    l_freq : float, optional
        Low cutoff frequency for filtering in Hz (default is 0.7).
    h_freq : float, optional
        High cutoff frequency for filtering in Hz (default is 1.5).
    l_trans_bandwidth : float, optional
        Transition bandwidth for the low cutoff in Hz (default is 0.3).
    h_trans_bandwidth : float, optional
        Transition bandwidth for the high cutoff in Hz (default is 0.3).

    Returns
    -------
    tuple[BaseRaw, NDArray[float64], list[tuple[float, float]]]
        - BaseRaw: The original data object (unchanged). Ensures compatibility with peak_power().
        - NDArray[float64]: Correlation scores for each channel and time window.
        - list[tuple[float, float]]: Time intervals for each window in seconds.
    """

    # Pick only fNIRS channels and sort them by channel name
    picks: NDArray[np.intp] = pick_types(cast(Info, data.info), fnirs=True)  # type: ignore
    picks = picks[np.argsort([getattr(data, "ch_names")[pick] for pick in picks])]

    # FIXME: This may happen if the heart rate calculation tries to set a value way too low
    if l_freq < 0.3:
        l_freq = 0.3

    # Band-pass filter the selected fNIRS channels
    filtered_data = filter_data(
        getattr(data, "_data"),
        getattr(data, "info")["sfreq"],
        l_freq,
        h_freq,
        picks=picks,
        verbose=False,
        l_trans_bandwidth=l_trans_bandwidth,  # type: ignore
        h_trans_bandwidth=h_trans_bandwidth,  # type: ignore
    )

    # Calculate number of samples per time window, the total number of windows, and prepare output variables
    window_samples = int(np.ceil(time_window * getattr(data, "info")["sfreq"]))
    n_windows = int(np.floor(len(data) / window_samples))
    scores = np.zeros((len(picks), n_windows))
    times: list[tuple[float, float]] = []

    # Slide through the data in windows to compute scalp coupling index (SCI)
    for window in range(n_windows):
        start_sample = int(window * window_samples)
        end_sample = start_sample + window_samples
        end_sample = np.min([end_sample, len(data) - 1])

        # Track time boundaries for each window
        t_start = getattr(data, "times")[start_sample]
        t_stop = getattr(data, "times")[end_sample]
        times.append((t_start, t_stop))

        # Iterate through channels in pairs (hbo, hbr). This requires them to be sorted by channel name
        for ii in range(0, len(picks), 2):
            c1 = filtered_data[picks[ii]][start_sample:end_sample]
            c2 = filtered_data[picks[ii + 1]][start_sample:end_sample]

            # Ensure the correlation data is valid
            if np.std(c1) == 0 or np.std(c2) == 0 or np.any(np.isnan(c1)) or np.any(np.isnan(c2)):
                c = 0
            else:
                c = np.corrcoef(c1, c2)[0][1]

            # Assign the computed correlation to both channels in the pair
            scores[ii, window] = c
            scores[ii + 1, window] = c

    scores = scores[np.argsort(picks)]

    return data, scores, times
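

# Minimal illustration (not used by the pipeline) of the quantity computed above: the
# windowed SCI is the Pearson correlation between the two band-passed signals of one
# channel pair, so two in-phase cardiac-band oscillations score close to 1.0.
def _example_sci_correlation() -> float:
    t = np.linspace(0, 3, 300)                      # one 3 s window, hypothetical 100 Hz rate
    c1 = np.sin(2 * np.pi * 1.2 * t)                # ~72 bpm cardiac oscillation
    c2 = 0.5 * np.sin(2 * np.pi * 1.2 * t) + 0.01   # same phase, different amplitude and offset
    return float(np.corrcoef(c1, c2)[0][1])         # ~1.0 -> good scalp coupling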


def calculate_scalp_coupling(data, l_freq: float = 0.7, h_freq: float = 1.5):
    """
    Calculate the scalp coupling index (SCI) and identify bad channels based on a threshold.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    l_freq : float, optional
        Low cutoff frequency for bandpass filtering in Hz (default is 0.7).
    h_freq : float, optional
        High cutoff frequency for bandpass filtering in Hz (default is 1.5).

    Returns
    -------
    tuple[list[str], Figure, Figure]
        - list[str]: Channel names identified as bad based on SCI threshold.
        - Figure: Heatmap of all SCI scores across time and channels.
        - Figure: Binary heatmap of SCI scores exceeding the threshold.
    """

    print("Calculating scalp coupling index...")

    # Compute the SCI
    _, scores, times = scalp_coupling_index_windowed_raw(data, time_window=SCI_TIME_WINDOW, l_freq=l_freq, h_freq=h_freq)

    # Identify channels that don't meet the provided threshold
    print("Identifying channels that do not meet the threshold...")
    sci = scores.mean(axis=1)
    data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD))

    # Determine the colors based on the threshold, and create the figures
    print("Creating the figures...")
    color_stops = ([0.0, SCI_THRESHOLD, SCI_THRESHOLD+0.1, 0.8, 1.0], [0.0, SCI_THRESHOLD, SCI_THRESHOLD, 1.0])
    fig1, fig2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, SCI_THRESHOLD, "Scalp Coupling Index")

    print("Successfully calculated scalp coupling index.")

    return list(compress(cast(list[str], getattr(data, "ch_names")), sci < SCI_THRESHOLD)), fig1, fig2


def calculate_signal_noise_ratio(data):
    """
    Calculates the signal-to-noise ratio (SNR) for each channel and identifies those below a defined threshold.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.

    Returns
    -------
    tuple[list[str], Figure]
        - list[str]: A list of channel names that fall below the SNR threshold and are considered bad.
        - Figure: A matplotlib Figure showing the channels' SNR values.
    """

    print("Calculating signal to noise ratio...")

    # Compute the signal-to-noise ratio values
    print("Computing the signal to noise power...")
    signal_band = (0.01, 0.5)
    noise_band = (1.0, 10.0)
    data_signal = data.copy().filter(*signal_band, verbose=False)  # type: ignore
    data_noise = data.copy().filter(*noise_band, verbose=False)  # type: ignore
    signal_power = np.mean(data_signal.get_data()**2, axis=1)  # type: ignore
    noise_power = np.mean(data_noise.get_data()**2, axis=1)  # type: ignore

    # Calculate the SNR using the standard formula for dB
    snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps))

    # TODO: Understand what this does
    groups: dict[str, list[str]] = {}
    for ch in getattr(data, "ch_names"):
        # Look for the space in the channel names and remove the characters after
        # This is so we can get both oxy and deoxy to remove, as they will have the same source and detector
        base = ch.rsplit(' ', 1)[0]
        groups.setdefault(base, []).append(ch)  # type: ignore

    # If any of the channels do not meet our threshold, they will get inserted into the bad_channels set
    bad_channels: set[str] = set()
    for base, ch_list in groups.items():
        if any(s < SNR_THRESHOLD for s, ch in zip(snr, getattr(data, "ch_names")) if ch in ch_list):
            bad_channels.update(ch_list)

    # Design and create the figure
    print("Creating the figure...")
    snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained")  # type: ignore
    colors = [(0/20, 'red'), (SNR_THRESHOLD/20, 'red'), ((SNR_THRESHOLD+.5)/20, 'yellow'), ((SNR_THRESHOLD+1)/20, 'green'), (20/20, 'green')]
    cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors)
    norm = mcolors.Normalize(vmin=0, vmax=20)
    scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm)  # type: ignore
    ax.set(xlabel="Channel Number", ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, 20])
    ax.axhline(SNR_THRESHOLD, color='black', linestyle='--', alpha=0.3, linewidth=1)  # type: ignore
    cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)")  # type: ignore
    cbar.set_ticks([0, SNR_THRESHOLD, SNR_THRESHOLD+1, 20])  # type: ignore
    cbar.set_ticklabels(['0', str(SNR_THRESHOLD), str(SNR_THRESHOLD+1), '20'])  # type: ignore

    plt.close()

    print("Successfully calculated signal to noise ratio.")

    return list(bad_channels), snr_fig
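

# Worked example of the dB formula used above: if the power in the signal band is
# 100x the power in the noise band, the SNR is 10 * log10(100) = 20 dB; equal powers
# give 0 dB. The input powers below are hypothetical.
def _example_snr_db(signal_power: float = 1e-10, noise_power: float = 1e-12) -> float:
    return float(10 * np.log10(signal_power / (noise_power + np.finfo(float).eps)))  # -> ~20.0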


def build_fnirs_adjacency(raw, threshold_meters=0.03):
    """Build an adjacency dictionary for fNIRS channels using 3D distance."""
    # Extract channel positions
    ch_locs = []
    ch_names = []

    for ch in raw.info['chs']:
        loc = ch['loc'][:3]  # Get x, y, z coordinates
        if not np.isnan(loc).any():
            ch_locs.append(loc)
            ch_names.append(ch['ch_name'])

    ch_locs = np.array(ch_locs)

    # Compute pairwise distances
    dists = cdist(ch_locs, ch_locs)

    # Build adjacency dictionary
    adjacency = {}
    for i, ch_name in enumerate(ch_names):
        neighbors = [ch_names[j] for j in range(len(ch_names))
                     if 0 < dists[i, j] < threshold_meters]
        adjacency[ch_name] = neighbors

    return adjacency
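

# Minimal illustration (not used by the pipeline) of the adjacency rule above: two channels
# are neighbours when their 3D positions are closer than threshold_meters (3 cm by default),
# and a channel is never its own neighbour. The coordinates below are hypothetical.
def _example_adjacency_rule() -> bool:
    a = np.array([0.00, 0.0, 0.0])
    b = np.array([0.02, 0.0, 0.0])  # 2 cm apart
    return bool(0 < np.linalg.norm(a - b) < 0.03)  # True -> neighbours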


def get_hbo_hbr_picks(raw):
    # Pick all fNIRS channels
    fnirs_picks = pick_types(raw.info, fnirs=True, exclude=[])

    # Extract wavelengths from channel names (expecting something like 'S6_D4 763' or 'S6_D4 841')
    wavelengths = []
    for idx in fnirs_picks:
        ch_name = raw.ch_names[idx]
        # Extract last 3 digits from channel name using regex
        match = re.search(r'(\d{3})$', ch_name)
        if match:
            wavelengths.append(int(match.group(1)))
        else:
            raise ValueError(f"Channel name '{ch_name}' does not end with 3 digits.")

    wavelengths = np.array(wavelengths)
    unique_wavelengths = np.unique(wavelengths)
    if len(unique_wavelengths) != 2:
        raise RuntimeError(f"Expected exactly 2 distinct wavelengths, found {unique_wavelengths}")

    # Determine which is HbO (larger) and which is HbR (smaller)
    hbr_wl = unique_wavelengths.min()
    hbo_wl = unique_wavelengths.max()

    print(f"HbR wavelength: {hbr_wl}, HbO wavelength: {hbo_wl}")

    # Find picks corresponding to each wavelength
    hbr_picks = [fnirs_picks[i] for i, wl in enumerate(wavelengths) if wl == hbr_wl]
    hbo_picks = [fnirs_picks[i] for i, wl in enumerate(wavelengths) if wl == hbo_wl]

    print(f"Found {len(hbr_picks)} HbR channels and {len(hbo_picks)} HbO channels.")

    return hbo_picks, hbr_picks, hbo_wl, hbr_wl
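

# Minimal illustration (not used by the pipeline) of the wavelength parsing above: the
# trailing three digits of a channel name such as 'S6_D4 763' identify its wavelength.
def _example_parse_wavelength(ch_name: str = "S6_D4 763") -> int:
    match = re.search(r'(\d{3})$', ch_name)
    if match is None:
        raise ValueError(f"Channel name '{ch_name}' does not end with 3 digits.")
    return int(match.group(1))  # -> 763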


def interpolate_fNIRS_bads_weighted_average(raw, bad_channels, max_dist=0.03, min_neighbors=2):
    """
    Interpolate bad fNIRS channels using a distance-weighted average of nearby good channels.

    Parameters
    ----------
    raw : mne.io.Raw
        The raw fNIRS data with bads marked in raw.info['bads'].
    bad_channels : list
        Bad channel names (currently unused; the bads are read from raw.info['bads']).
    max_dist : float
        Maximum distance (in meters) to consider for neighboring good channels.
    min_neighbors : int
        Minimum number of neighbors required to interpolate a bad channel.

    Returns
    -------
    tuple
        The raw object with bads interpolated (in-place) and a figure of the data after
        interpolation (None when there was nothing to interpolate).
    """

    print("Finding fNIRS channels...")
    hbo_picks, hbr_picks, hbo_wl, hbr_wl = get_hbo_hbr_picks(raw)

    if len(hbo_picks) != len(hbr_picks):
        raise RuntimeError("Number of HbO and HbR channels must be the same.")

    # Base names without wavelength for pairing
    def base_name(ch_name):
        # Strip last 4 chars assuming format ' <wavelength>'
        # e.g. "S6_D6 841" -> "S6_D6"
        return ch_name[:-4]

    hbo_names = [base_name(raw.ch_names[i]) for i in hbo_picks]
    hbr_names = [base_name(raw.ch_names[i]) for i in hbr_picks]

    # Sanity check: pairs must match
    for i in range(len(hbo_names)):
        if hbo_names[i] != hbr_names[i]:
            raise RuntimeError(f"Channel pairs do not match: {hbo_names[i]} vs {hbr_names[i]}")

    # Identify bad pairs if either channel in pair is bad
    bad_pairs = []
    good_pairs = []
    for i, base in enumerate(hbo_names):
        hbo_ch = raw.ch_names[hbo_picks[i]]
        hbr_ch = raw.ch_names[hbr_picks[i]]
        if (hbo_ch in raw.info['bads']) or (hbr_ch in raw.info['bads']):
            bad_pairs.append(i)
        else:
            good_pairs.append(i)

    print(f"Total pairs: {len(hbo_names)}")
    print(f"Good pairs: {len(good_pairs)}")
    print(f"Bad pairs to interpolate: {len(bad_pairs)}")

    if len(bad_pairs) == 0:
        print("No bad pairs found. Skipping interpolation.")
        return raw, None

    # Extract locations (use HbO channel loc as pair location)
    locs = np.array([raw.info['chs'][hbo_picks[i]]['loc'][:3] for i in range(len(hbo_names))])
    good_locs = locs[good_pairs]
    bad_locs = locs[bad_pairs]

    # Compute distance matrix between bad and good pairs
    dist_matrix = cdist(bad_locs, good_locs)

    interpolated_pairs = []

    for i, bad_idx in enumerate(bad_pairs):
        bad_base = hbo_names[bad_idx]
        distances = dist_matrix[i]
        close_idxs = np.where(distances < max_dist)[0]

        print(f"\nInterpolating pair {bad_base} (index {bad_idx})")
        print(f"  Nearby good pairs found: {len(close_idxs)}")

        if len(close_idxs) < min_neighbors:
            print(f"  Skipping {bad_base}: not enough neighbors (found {len(close_idxs)} < {min_neighbors})")
            continue

        weights = 1 / (distances[close_idxs] + 1e-6)
        weights /= weights.sum()

        neighbor_hbo_indices = [hbo_picks[good_pairs[idx]] for idx in close_idxs]
        neighbor_hbr_indices = [hbr_picks[good_pairs[idx]] for idx in close_idxs]

        neighbor_hbo_data = raw._data[neighbor_hbo_indices, :]
        neighbor_hbr_data = raw._data[neighbor_hbr_indices, :]

        interpolated_hbo = np.average(neighbor_hbo_data, axis=0, weights=weights)
        interpolated_hbr = np.average(neighbor_hbr_data, axis=0, weights=weights)

        raw._data[hbo_picks[bad_idx]] = interpolated_hbo
        raw._data[hbr_picks[bad_idx]] = interpolated_hbr

        interpolated_pairs.append(bad_base)

    if interpolated_pairs:
        bad_ch_to_remove = []
        for base_ in interpolated_pairs:
            bad_ch_to_remove.append(base_ + f" {hbr_wl}")  # HbR
            bad_ch_to_remove.append(base_ + f" {hbo_wl}")  # HbO

        raw.info['bads'] = [ch for ch in raw.info['bads'] if ch not in bad_ch_to_remove]

    print("\nInterpolation complete.\n")

    for ch in raw.info['bads']:
        print(f"Channel {ch} still marked as bad.")

    print("Bads cleared:", raw.info['bads'])
    fig_raw_after = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After interpolation", show=False)

    return raw, fig_raw_after
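

# Minimal illustration (not used by the pipeline) of the inverse-distance weighting above:
# closer good pairs contribute more, and the weights are normalised so they sum to 1.
# The neighbour distances below are hypothetical.
def _example_idw_weights() -> np.ndarray:
    distances = np.array([0.010, 0.020, 0.028])  # neighbour distances in metres
    weights = 1 / (distances + 1e-6)
    return weights / weights.sum()  # approximately [0.54, 0.27, 0.19]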


def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5) -> tuple[list[str], Figure, Figure]:
    """
    Calculate peak spectral power (PSP) for fNIRS channels and identify bad channels.

    Parameters
    ----------
    data : BaseRaw
        The loaded data object to process.
    l_freq : float, optional
        Low cutoff frequency for filtering in Hz (default is 0.7).
    h_freq : float, optional
        High cutoff frequency for filtering in Hz (default is 1.5).

    Returns
    -------
    tuple[list[str], Figure, Figure]
        - list[str]: Names of channels below the PSP threshold.
        - Figure: Heatmap of all PSP scores.
        - Figure: Heatmap of scores above the PSP threshold.
    """

    # Compute the PSP
    _, scores, times = cast(tuple[NDArray[float64], NDArray[float64], list[tuple[float]]], peak_power(data, time_window=PSP_TIME_WINDOW, threshold=PSP_THRESHOLD, l_freq=l_freq, h_freq=h_freq))

    # Identify channels that don't meet the provided threshold
    psp = scores.mean(axis=1)
    data.info["bads"] = list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD))

    # Determine the colors based on the threshold, and create the figures
    color_stops = ([0.0, PSP_THRESHOLD, PSP_THRESHOLD+0.1, PSP_THRESHOLD+0.2, 1.0], [0.0, PSP_THRESHOLD, PSP_THRESHOLD, 1.0])
    psp1, psp2 = plot_timechannel_quality_metrics(data, scores, times, color_stops, PSP_THRESHOLD, "Peak Spectral Power")

    return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2


def mark_bads(raw, bad_sci, bad_snr, bad_psp):
    bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp))
    print(f"Automatically marked bad channels based on SCI, SNR, and PSP: {bads_combined}")

    raw.info['bads'].extend(bads_combined)

    # Organize channels into categories
    sets = [
        (bad_sci, "SCI"),
        (bad_psp, "PSP"),
        (bad_snr, "SNR"),
    ]

    # Graph what channels were dropped and why they were dropped
    channel_categories: dict[str, str] = {}

    for ch in bads_combined:
        present_in = [name for s, name in sets if ch in s]
        # Create a label for the category
        if len(present_in) == 1:
            label = f"{present_in[0]} only"
        else:
            label = " + ".join(sorted(present_in))
        channel_categories[ch] = label

    # Sort channels alphabetically within categories for nicer visualization
    categories = sorted(set(channel_categories.values()))
    channel_names: list[str] = []
    category_labels: list[str] = []
    for cat in categories:
        chs_in_cat = sorted([ch for ch, c in channel_categories.items() if c == cat])
        channel_names.extend(chs_in_cat)
        category_labels.extend([cat] * len(chs_in_cat))

    colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories}

    # Create the figure
    fig_dropped, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3)))  # type: ignore
    y_pos = range(len(channel_names))
    ax.barh(y_pos, [1]*len(channel_names), color=[colors[cat] for cat in category_labels])  # type: ignore
    ax.set_yticks(y_pos)  # type: ignore
    ax.set_yticklabels(channel_names)  # type: ignore
    ax.set_xlabel("Marked as Bad")  # type: ignore
    ax.set_title("Bad Channels by Method")  # type: ignore
    ax.set_xlim(0, 1)
    ax.set_xticks([])  # type: ignore
    ax.grid(axis='x', linestyle='--', alpha=0.7)  # type: ignore

    # Add a legend denoting why the channels were bad
    for label, color in colors.items():
        ax.bar(0, 0, color=color, label=label)  # type: ignore
    ax.legend()  # type: ignore

    fig_dropped.tight_layout()

    raw_before = deepcopy(raw)
    bads_channels = [ch for ch in raw.ch_names if ch in raw.info['bads']]
    print(bads_channels)
    if bads_channels:
        fig_raw_before = raw_before.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], picks=bads_channels, title="What they were BEFORE", show=False)
    else:
        fig_dropped = None
        fig_raw_before = None

    return raw, fig_dropped, fig_raw_before, bads_channels


def filter_the_data(raw_haemo):
    # --- STEP 5: Filtering (band-pass between L_FREQ and H_FREQ from the configuration) ---
    fig_filter = raw_haemo.compute_psd(fmax=3).plot(
        average=True, color="r", show=False, amplitude=True
    )

    if L_FREQ == 0 and H_FREQ != 0:
        raw_haemo = raw_haemo.filter(l_freq=None, h_freq=H_FREQ, h_trans_bandwidth=0.02)
    elif L_FREQ != 0 and H_FREQ == 0:
        raw_haemo = raw_haemo.filter(l_freq=L_FREQ, h_freq=None, l_trans_bandwidth=0.002)
    elif L_FREQ != 0 and H_FREQ != 0:
        raw_haemo = raw_haemo.filter(l_freq=L_FREQ, h_freq=H_FREQ, l_trans_bandwidth=0.002, h_trans_bandwidth=0.02)
    else:
        print("No filter")
    # raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.4, h_trans_bandwidth=0.2)
    # raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.7, h_trans_bandwidth=0.2)
    # raw_haemo = raw_haemo.filter(0.005, 0.7, h_trans_bandwidth=0.02, l_trans_bandwidth=0.002)

    raw_haemo.compute_psd(fmax=3).plot(
        average=True, axes=fig_filter.axes, color="g", amplitude=True, show=False
    )

    fig_raw_haemo_filter = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Filtered HbO and HbR", show=False)

    return raw_haemo, fig_filter, fig_raw_haemo_filter


def epochs_calculations(raw_haemo, events, event_dict):
    fig_epochs = []  # List to store figures

    # Create epochs from raw data
    epochs = Epochs(raw_haemo,
                    events,
                    event_id=event_dict,
                    tmin=-5,
                    tmax=15,
                    baseline=(None, 0))

    # Make a copy of the epochs and drop bad ones
    epochs2 = epochs.copy()
    epochs2.drop_bad()

    # Plot drop log
    # TODO: Why show this if we never use epochs2?
    fig_epochs_dropped = epochs2.plot_drop_log(show=False)
    fig_epochs.append(("fig_epochs_dropped", fig_epochs_dropped))

    # Plot for each condition
    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
    for idx, condition in enumerate(epochs.event_id.keys()):
        logger.info(condition)
        logger.info(idx)
        # Plot images for each condition
        fig_epochs_data = epochs[condition].plot_image(
            combine="mean",
            vmin=-1,
            vmax=1,
            ts_args=dict(ylim=dict(hbo=[-1, 1], hbr=[-1, 1])),
            show=False
        )
        for j, fig in enumerate(fig_epochs_data):
            logger.info("------------------------------------------")
            logger.info(j)
            logger.info(fig)

            ax = fig.axes[0]
            original_title = ax.get_title()
            ax.set_title(f"{condition}: {original_title}")
            fig_epochs.append((f"fig_{condition}_data_{idx}_{j}", fig))  # Store with a unique name

        # Evoked average figure for each condition
        evoked_avg = epochs[condition].average()
        clims = dict(hbo=[-1, 1], hbr=[1, -1])
        condition_fig = evoked_avg.plot_image(clim=clims, show=False)

        for ax in condition_fig.axes:
            original_title = ax.get_title()
            ax.set_title(f"{original_title} - {condition}")
        fig_epochs.append((f"evoked_avg_{condition}", condition_fig))  # Store with a unique name

    # Prepare evokeds and colors for topographic plot
    evokeds3 = []
    colors = []
    conditions = list(epochs.event_id.keys())
    cmap = plt.get_cmap("tab10", len(conditions))

    for idx, cond in enumerate(conditions):
        evoked = epochs[cond].average(picks="hbo")
        evokeds3.append(evoked)
        colors.append(cmap(idx))

    # Create the topographic plot
    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
    fig_topo = plot_evoked_topo(evokeds3, color=colors, axes=axes, legend=False, show=False)

    # Build custom legend
    lines = []
    for color in colors:
        line = plt.Line2D([0], [0], color=color, lw=2)
        lines.append(line)

    fig.legend(lines, conditions, loc="lower right")
    fig_epochs.append(("evoked_topo", fig_topo))  # Store with a unique name

    unique_annotations = set(raw_haemo.annotations.description)

    for cond in unique_annotations:

        # Evoked response for specific condition ("Activity")
        evoked_stim1 = epochs[cond].average()

        fig_evoked_hbo = evoked_stim1.copy().pick(picks='hbo').plot(time_unit='s', show=False)
        fig_evoked_hbr = evoked_stim1.copy().pick(picks='hbr').plot(time_unit='s', show=False)

        fig_epochs.append((f"fig_evoked_hbo_{cond}", fig_evoked_hbo))  # Store with a unique name
        fig_epochs.append((f"fig_evoked_hbr_{cond}", fig_evoked_hbr))  # Store with a unique name

        print("Evoked HbO peak amplitude:", evoked_stim1.copy().pick(picks='hbo').data.max())

    evokeds = {}
    for condition in epochs2.event_id:
        evokeds[condition] = epochs2[condition].average()
        print(f"Condition '{condition}': {len(epochs2[condition])} epochs averaged.")

    all_evokeds = {}
    for condition in epochs.event_id:
        if condition not in all_evokeds:
            all_evokeds[condition] = []
        all_evokeds[condition].append(epochs[condition].average())

    group_aucs = {}
    # TODO: group averages with a single person?
    group_averages = {cond: grand_average(evokeds) for cond, evokeds in all_evokeds.items()}
    for condition, evoked in group_averages.items():
        group_aucs[condition] = {}
        for pick in ["hbo", "hbr"]:
            picks_idx = [i for i, ch in enumerate(evoked.ch_names) if pick in ch]
            if not picks_idx:
                continue

            data = evoked.data[picks_idx, :].mean(axis=0)
            t_start, t_end = 0, 15
            times_mask = (evoked.times >= t_start) & (evoked.times <= t_end)
            data_segment = data[times_mask]
            times_segment = evoked.times[times_mask]

            auc = np.trapezoid(data_segment, times_segment)
            group_aucs[condition][pick] = auc

    # Final evoked comparison plot for each condition
    for condition in conditions:
        if condition not in evokeds:
            continue
        evoked = evokeds[condition]

        fig, ax = plt.subplots(figsize=(6, 5))
        legend_labels = ["Oxyhaemoglobin"]

        for pick, color in zip(["hbo", "hbr"], ["r", "b"]):
            plot_compare_evokeds(
                evoked,
                combine="mean",
                picks=pick,
                axes=ax,
                show=False,
                colors=[color],
                legend=False,
                title=f"Participant:\nCondition: {condition}",
                ylim=dict(hbo=[-0.5, 1], hbr=[-0.5, 1]),
                show_sensors=False,
            )
            auc_value = group_aucs.get(condition, {}).get(pick, None)
            if auc_value is not None:
                label = f"{pick.upper()} AUC: {auc_value * 1e6:.4f} µM·s"
            else:
                label = f"{pick.upper()} AUC: N/A"
            legend_labels.append(label)
            if len(legend_labels) == 2:
                legend_labels.append("Deoxyhaemoglobin")

        ax.legend(legend_labels)

        fig_epochs.append((f"fig_{condition}_compare_evokeds", fig))  # Store with a unique name

    return epochs, fig_epochs
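

# Worked example (not used by the pipeline) of the AUC computed above: np.trapezoid
# integrates the mean evoked response over the 0-15 s window, so a constant response
# of 1 unit over 15 s integrates to 15 unit-seconds.
def _example_evoked_auc() -> float:
    times = np.linspace(0, 15, 16)
    response = np.ones_like(times)  # hypothetical flat response
    return float(np.trapezoid(response, times))  # -> 15.0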


def make_design_matrix(raw_haemo, short_chans):

    raw_haemo.resample(1, npad="auto")
    raw_haemo._data = raw_haemo._data * 1e6
    # 2) Create design matrix
    if SHORT_CHANNEL:
        short_chans.resample(1)
        design_matrix = make_first_level_design_matrix(
            raw=raw_haemo,
            hrf_model='fir',
            stim_dur=0.5,
            fir_delays=range(15),
            drift_model=DRIFT_MODEL,
            high_pass=0.01,
            oversampling=1,
            min_onset=-125,
            add_regs=short_chans.get_data().T,
            add_reg_names=short_chans.ch_names
        )
    else:
        design_matrix = make_first_level_design_matrix(
            raw=raw_haemo,
            hrf_model='fir',
            stim_dur=0.5,
            fir_delays=range(15),
            drift_model=DRIFT_MODEL,
            high_pass=0.01,
            oversampling=1,
            min_onset=-125,
        )

    print(design_matrix.head())
    print(design_matrix.columns)

    fig, ax1 = plt.subplots(figsize=(10, 6), constrained_layout=True)
    _ = plot_design_matrix(design_matrix, axes=ax1)

    return design_matrix, fig
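

# Illustrative follow-on step (a sketch, not part of this function): the design matrix
# produced above is normally passed to mne_nirs.statistics.run_glm (imported at the top of
# this module) together with the same resampled haemoglobin data to obtain per-channel
# regression results.
def _example_fit_glm(raw_haemo, design_matrix):
    glm_est = run_glm(raw_haemo, design_matrix)
    return glm_est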


def generate_montage_locations():
    """Get standard MNI montage locations in dataframe.

    Data is returned in the same format as the eeg_positions library.
    """
    # standard_1020 and standard_1005 are in MNI (fsaverage) space already,
    # but we need to undo the scaling that head_scale will do
    montage = make_standard_montage(
        "standard_1005", head_size=0.09700884729534559
    )
    for d in montage.dig:
        d["coord_frame"] = 2003
    montage.dig[:] = montage.dig[3:]
    montage.add_mni_fiducials()  # now in fsaverage space
    coords = pd.DataFrame.from_dict(montage.get_positions()["ch_pos"]).T
    coords["label"] = coords.index
    coords = coords.rename(columns={0: "x", 1: "y", 2: "z"})

    return coords.reset_index(drop=True)


def _find_closest_standard_location(position, reference, *, out="label"):
    """Return closest montage label to coordinates.

    Parameters
    ----------
    position : array, shape (3,)
        Coordinates.
    reference : dataframe
        As generated by generate_montage_locations.
    out : str
        "label" to return the closest montage label, or "dists" to return the
        full distance matrix instead.
    """

    p0 = np.array(position)
    p0.shape = (-1, 3)
    # head_mri_t, _ = _get_trans("fsaverage", "head", "mri")
    # p0 = apply_trans(head_mri_t, p0)
    dists = cdist(p0, np.asarray(reference[["x", "y", "z"]], float))

    if out == "label":
        min_idx = np.argmin(dists)
        return reference["label"][min_idx]
    else:
        assert out == "dists"
        return dists


def _source_detector_fold_table(raw, cidx, reference, fold_tbl, interpolate):
    src = raw.info["chs"][cidx]["loc"][3:6]
    det = raw.info["chs"][cidx]["loc"][6:9]

    ref_lab = list(reference["label"])
    dists = _find_closest_standard_location([src, det], reference, out="dists")
    src_min, det_min = np.argmin(dists, axis=1)
    src_name, det_name = ref_lab[src_min], ref_lab[det_min]

    tbl = fold_tbl.query("Source == @src_name and Detector == @det_name")
    dist = np.linalg.norm(dists[[0, 1], [src_min, det_min]])
    # Try reversing source and detector
    if len(tbl) == 0:
        tbl = fold_tbl.query("Source == @det_name and Detector == @src_name")
    if len(tbl) == 0 and interpolate:
        # Try something hopefully not too terrible: pick the one with the
        # smallest net distance
        good = np.isin(fold_tbl["Source"], reference["label"]) & np.isin(
            fold_tbl["Detector"], reference["label"]
        )
        assert good.any()
        tbl = fold_tbl[good]
        assert len(tbl)
        src_idx = [ref_lab.index(src) for src in tbl["Source"]]
        det_idx = [ref_lab.index(det) for det in tbl["Detector"]]
        # Original
        tot_dist = np.linalg.norm([dists[0, src_idx], dists[1, det_idx]], axis=0)
        assert tot_dist.shape == (len(tbl),)
        idx = np.argmin(tot_dist)
        dist_1 = tot_dist[idx]
        src_1, det_1 = ref_lab[src_idx[idx]], ref_lab[det_idx[idx]]
        # And the reverse
        tot_dist = np.linalg.norm([dists[0, det_idx], dists[1, src_idx]], axis=0)
        idx = np.argmin(tot_dist)
        dist_2 = tot_dist[idx]
        src_2, det_2 = ref_lab[det_idx[idx]], ref_lab[src_idx[idx]]
        if dist_1 < dist_2:
            new_dist, src_use, det_use = dist_1, src_1, det_1
        else:
            new_dist, src_use, det_use = dist_2, det_2, src_2

        tbl = fold_tbl.query("Source == @src_use and Detector == @det_use")
        tbl = tbl.copy()
        tbl["BestSource"] = src_name
        tbl["BestDetector"] = det_name
        tbl["BestMatchDistance"] = dist
        tbl["MatchDistance"] = new_dist
        assert len(tbl)
    else:
        tbl = tbl.copy()
        tbl["BestSource"] = src_name
        tbl["BestDetector"] = det_name
        tbl["BestMatchDistance"] = dist
        tbl["MatchDistance"] = dist

    tbl = tbl.copy()  # don't get warnings about setting values later
    return tbl


def _read_fold_xls(fname, atlas="Juelich"):
    """Read fOLD toolbox xls file.

    The values are then manipulated into a tidy dataframe.

    Note the xls files are not included as no license is provided.

    Parameters
    ----------
    fname : str
        Path to xls file.
    atlas : str
        Requested atlas.
    """
    page_reference = {"AAL2": 2, "AICHA": 5, "Brodmann": 8, "Juelich": 11, "Loni": 14}

    tbl = pd.read_excel(fname, sheet_name=page_reference[atlas])

    # Remove the spacing between rows
    empty_rows = np.where(np.isnan(tbl["Specificity"]))[0]
    tbl = tbl.drop(empty_rows).reset_index(drop=True)

    # Empty values in the table mean it is the same as the row above
    for row_idx in range(1, tbl.shape[0]):
        for col_idx, col in enumerate(tbl.columns):
            if not isinstance(tbl[col][row_idx], str):
                if np.isnan(tbl[col][row_idx]):
                    tbl.iloc[row_idx, col_idx] = tbl.iloc[row_idx - 1, col_idx]

    tbl["Specificity"] = tbl["Specificity"] * 100
    tbl["brainSens"] = tbl["brainSens"] * 100
    return tbl


def _check_load_fold(fold_files, atlas):
    # _validate_type(fold_files, (list, "path-like", None), "fold_files")
    if fold_files is None:
        fold_files = get_config("MNE_NIRS_FOLD_PATH")
        if fold_files is None:
            raise ValueError(
                "MNE_NIRS_FOLD_PATH not set, either set it using "
                "mne.set_config or pass fold_files as str or list"
            )
    if not isinstance(fold_files, list):  # path-like
        fold_files = _check_fname(
            fold_files,
            overwrite="read",
            must_exist=True,
            name="fold_files",
            need_dir=True,
        )
        fold_files = [op.join(fold_files, f"10-{x}.xls") for x in (5, 10)]

    fold_tbl = pd.DataFrame()
    for fi, fname in enumerate(fold_files):
        fname = _check_fname(
            fname, overwrite="read", must_exist=True, name=f"fold_files[{fi}]"
        )
        fold_tbl = pd.concat(
            [fold_tbl, _read_fold_xls(fname, atlas=atlas)], ignore_index=True
        )
    return fold_tbl


def fold_channel_specificity_normal(raw, fold_files=None, atlas="Juelich", interpolate=False):
    """Return the landmarks and specificity a channel is sensitive to.

    Parameters
    ----------
    raw : BaseRaw
        The data whose channels should be looked up.
    fold_files : list | path-like | None
        Location of the fOLD xls files (defaults to the MNE_NIRS_FOLD_PATH config value).
    atlas : str
        Requested atlas.
    interpolate : bool
        If True, fall back to the closest available source-detector pair when no exact
        match is found in the fOLD tables.
    """  # noqa: E501
    _validate_type(raw, BaseRaw, "raw")

    reference_locations = generate_montage_locations()

    fold_tbl = _check_load_fold(fold_files, atlas)

    chan_spec = list()
    for cidx in range(len(raw.ch_names)):
        tbl = _source_detector_fold_table(
            raw, cidx, reference_locations, fold_tbl, interpolate
        )
        chan_spec.append(tbl.reset_index(drop=True))

    return chan_spec


def resource_path(relative_path):
    """
    Get absolute path to resource regardless of running directly or packaged using PyInstaller
    """

    if hasattr(sys, '_MEIPASS'):
        # PyInstaller bundle path
        base_path = sys._MEIPASS
    else:
        base_path = os.path.abspath(".")

    return os.path.join(base_path, relative_path)


def fold_channels(raw: BaseRaw) -> tuple[Figure, Figure]:

    # if getattr(sys, 'frozen', False):
    path = os.path.expanduser("~") + "/mne_data/fOLD/fOLD-public-master/Supplementary"
    logger.info(path)
    set_config('MNE_NIRS_FOLD_PATH', resource_path(path))  # type: ignore

    # # Locate the fOLD excel files
    # else:
    #     logger.info("yabba")
    #     set_config('MNE_NIRS_FOLD_PATH', resource_path("../../mne_data/fOLD/fOLD-public-master/Supplementary"))  # type: ignore

    output = None

    # List to store the results
    landmark_specificity_data: list[dict[str, Any]] = []

    # Filter the data to only what we want
    hbo_channel_names = cast(list[str], getattr(raw.copy().pick(picks='hbo'), "ch_names"))  # type: ignore

    # Format the output to make it slightly easier to read
    if True:
        num_channels = len(hbo_channel_names)
        rows, cols = 4, 7  # 4 rows and 7 columns of pie charts
        fig, axes = plt.subplots(rows, cols, figsize=(16, 10), constrained_layout=True)
        axes = axes.flatten()  # Flatten the axes array for easier indexing

        # If more pie charts than subplots, create extra subplots
        if num_channels > rows * cols:
            fig, axes = plt.subplots((num_channels // cols) + 1, cols, figsize=(16, 10), constrained_layout=True)
            axes = axes.flatten()

        # Create a list for consistent color mapping
        landmarks = [
            "1 - Primary Somatosensory Cortex",
            "2 - Primary Somatosensory Cortex",
            "3 - Primary Somatosensory Cortex",
            "4 - Primary Motor Cortex",
            "5 - Somatosensory Association Cortex",
            "6 - Pre-Motor and Supplementary Motor Cortex",
            "7 - Somatosensory Association Cortex",
            "8 - Includes Frontal eye fields",
            "9 - Dorsolateral prefrontal cortex",
            "10 - Frontopolar area",
            "11 - Orbitofrontal area",
            "17 - Primary Visual Cortex (V1)",
            "18 - Visual Association Cortex (V2)",
            "19 - V3",
            "20 - Inferior Temporal gyrus",
            "21 - Middle Temporal gyrus",
            "22 - Superior Temporal Gyrus",
            "23 - Ventral Posterior cingulate cortex",
            "24 - Ventral Anterior cingulate cortex",
            "25 - Subgenual cortex",
            "32 - Dorsal anterior cingulate cortex",
            "37 - Fusiform gyrus",
            "38 - Temporopolar area",
            "39 - Angular gyrus, part of Wernicke's area",
            "40 - Supramarginal gyrus part of Wernicke's area",
            "41 - Primary and Auditory Association Cortex",
            "42 - Primary and Auditory Association Cortex",
            "43 - Subcentral area",
            "44 - pars opercularis, part of Broca's area",
            "45 - pars triangularis Broca's area",
            "46 - Dorsolateral prefrontal cortex",
            "47 - Inferior prefrontal gyrus",
            "48 - Retrosubicular area",
            "Brain_Outside",
        ]

        cmap1 = plt.get_cmap('tab20')  # First 20 colors
        cmap2 = plt.get_cmap('tab20b')  # Next 20 colors

        # Combine the colors from both colormaps
        colors = [cmap1(i) for i in range(20)] + [cmap2(i) for i in range(20)]  # Total 40 colors

        landmarks.sort(key=lambda x: (int(x.split(" - ")[0]) if x.split(" - ")[0].isdigit() else float('inf')))

        landmark_color_map = {landmark: colors[i % len(colors)] for i, landmark in enumerate(landmarks)}

    # Iterate over each channel
    for idx, channel_name in enumerate(hbo_channel_names):

        # Run the fOLD on the selected channel
        channel_data = raw.copy().pick(picks=channel_name)  # type: ignore

        output = cast(list[DataFrame], fold_channel_specificity_normal(channel_data, interpolate=True, atlas='Brodmann'))

        # Process each DataFrame that fold_channel_specificity returns
        for df_data in output:

            # Extract the relevant columns
            useful_data = df_data[['Landmark', 'Specificity']]

            # Store the results
            landmark_specificity_data.append({
                'Channel': channel_name,
                'Data': useful_data,
            })

        # Plot the results
        # TODO: Fix this
        if True:
            unique_landmarks = sorted(useful_data['Landmark'].unique())
            color_list = [landmark_color_map[landmark] for landmark in useful_data['Landmark']]

            # Plot specificity for each channel
            ax = axes[idx]

            labels = [f'{landmark.split(" - ")[0]}' if landmark != 'Brain_Outside' else 'B' for landmark in useful_data['Landmark']]

            wedges, texts, autotexts = ax.pie(
                useful_data['Specificity'],
                autopct='%1.1f%%',
                startangle=90,
                labels=labels,
                labeldistance=1.05,
                colors=color_list)

            ax.set_title(f'{channel_name}')
            ax.axis('equal')

            landmark_specificity_data = []

    # TODO: Fix this
    if True:
        handles = [
            plt.Line2D([0], [0], marker='o', color='w', label=landmark, markersize=10,
                       markerfacecolor=landmark_color_map[landmark])
            for landmark in landmarks
        ]
        n_landmarks = len(landmarks)

        # Calculate the figure size based on number of rows and columns
        fig_width = 5
        fig_height = n_landmarks / 4

        # Create a new figure window for the legend
        legend_fig = plt.figure(figsize=(fig_width, fig_height))
        legend_axes = legend_fig.add_subplot(111)
        legend_axes.axis('off')  # Turn off axis for the legend window
        legend_axes.legend(handles=handles, loc='center', fontsize=10, title="Landmarks")

    for ax in axes[len(hbo_channel_names):]:
        ax.axis('off')

    plt.show()
    return fig, legend_fig
|
||
|
||
|
||
|
||
|
||
def individual_significance(raw_haemo, glm_est):
|
||
|
||
fig_individual_significances = [] # List to store figures
|
||
|
||
# TODO: BAD!
|
||
cha = glm_est.to_dataframe()
|
||
|
||
unique_annotations = set(raw_haemo.annotations.description)
|
||
|
||
for cond in unique_annotations:
|
||
|
||
ch_summary = cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python')
|
||
|
||
print(ch_summary.head())
|
||
|
||
channel_averages = ch_summary.groupby('ch_name')['theta'].mean().reset_index()
|
||
print(channel_averages.head())
|
||
|
||
|
||
activity_ch_summary = ch_summary.query(
|
||
f"Chroma == 'hbo' and Condition.str.startswith('{cond}_delay_')", engine='python'
|
||
)
|
||
|
||
# Function to correct p-values per channel
|
||
def fdr_correct_per_channel(df):
|
||
df = df.copy()
|
||
df['pval_fdr'] = multipletests(df['p_value'], method='fdr_bh')[1]
|
||
return df
|
||
|
||
# Apply FDR correction grouped by channel
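# 'fdr_bh' is the Benjamini-Hochberg procedure; grouping by ch_name means the
# false discovery rate is controlled across the FIR delay regressors within
# each channel rather than across all channels at once.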
|
||
corrected = activity_ch_summary.groupby("ch_name", group_keys=False).apply(fdr_correct_per_channel)
|
||
|
||
# Determine which channels are significant across any delay
|
||
sig_channels = (
|
||
corrected.groupby('ch_name')
|
||
.apply(lambda df: (df['pval_fdr'] < 0.05).any())
|
||
.reset_index(name='significant')
|
||
)
|
||
|
||
# Merge with mean theta (optional for plotting)
|
||
mean_theta = activity_ch_summary.groupby('ch_name')['theta'].mean().reset_index()
|
||
sig_channels = sig_channels.merge(mean_theta, on='ch_name')
|
||
print(sig_channels)
|
||
|
||
|
||
# For example, take the minimum corrected p-value per channel
|
||
summary_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index()
|
||
print(summary_pvals)
|
||
|
||
|
||
def parse_ch_name(ch_name):
|
||
# Extract numbers after S and D in names like 'S10_D5 hbo'
|
||
match = re.match(r'S(\d+)_D(\d+)', ch_name)
|
||
if match:
|
||
return int(match.group(1)), int(match.group(2))
|
||
else:
|
||
return None, None
|
||
|
||
|
||
min_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index()
|
||
|
||
# Merge the real p-values into sig_channels / avg_df
|
||
avg_df = sig_channels.merge(min_pvals, on='ch_name')
|
||
|
||
# Rename columns for consistency
|
||
avg_df = avg_df.rename(columns={'theta': 't_or_theta', 'pval_fdr': 'p_value'})
|
||
|
||
# Add Source and Detector columns again
|
||
avg_df['Source'], avg_df['Detector'] = zip(*avg_df['ch_name'].map(parse_ch_name))
|
||
|
||
# Keep relevant columns
|
||
avg_df = avg_df[['Source', 'Detector', 't_or_theta', 'p_value']].dropna()
|
||
|
||
ABS_SIGNIFICANCE_THETA_VALUE = 1
|
||
ABS_SIGNIFICANCE_T_VALUE = 1
|
||
P_THRESHOLD = 0.05
|
||
SOURCE_DETECTOR_SEPARATOR = "_"
|
||
|
||
t_or_theta = 'theta'
|
||
for _, row in avg_df.iterrows(): # type: ignore
|
||
print(f"Source {row['Source']} <-> Detector {row['Detector']}: "
|
||
f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}")
|
||
|
||
# Extract the source and detector positions from raw
|
||
src_pos: dict[int, tuple[float, float]] = {}
|
||
det_pos: dict[int, tuple[float, float]] = {}
|
||
for ch in getattr(raw_haemo, "info")["chs"]:
|
||
ch_name = ch['ch_name']
|
||
if not ch_name or not ch['loc'].any():
|
||
continue
|
||
parts = ch_name.split()[0]
|
||
src_str, det_str = parts.split(SOURCE_DETECTOR_SEPARATOR)
|
||
src_num = int(src_str[1:])
|
||
det_num = int(det_str[1:])
|
||
src_pos[src_num] = ch['loc'][3:5]
|
||
det_pos[det_num] = ch['loc'][6:8]
|
||
|
||
# Set up the plot
|
||
fig, ax = plt.subplots(figsize=(8, 6)) # type: ignore
|
||
|
||
# Plot the sources
|
||
for pos in src_pos.values():
|
||
ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3) # type: ignore
|
||
|
||
# Plot the detectors
|
||
for pos in det_pos.values():
|
||
ax.scatter(pos[0], pos[1], s=120, c='k', marker='s', edgecolors='white', linewidths=1, zorder=3) # type: ignore
|
||
|
||
# Ensure that the colors stay within the boundaries even if they are over or under the max/min values
|
||
if t_or_theta == 't':
|
||
norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_T_VALUE, vmax=ABS_SIGNIFICANCE_T_VALUE)
|
||
elif t_or_theta == 'theta':
|
||
norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_THETA_VALUE, vmax=ABS_SIGNIFICANCE_THETA_VALUE)
|
||
|
||
cmap: mcolors.Colormap = plt.get_cmap('seismic')
|
||
|
||
# Plot connections with avg t-values
|
||
for row in avg_df.itertuples():
|
||
src: int = cast(int, row.Source) # type: ignore
|
||
det: int = cast(int, row.Detector) # type: ignore
|
||
tval: float = cast(float, row.t_or_theta) # type: ignore
|
||
pval: float = cast(float, row.p_value) # type: ignore
|
||
|
||
|
||
if src in src_pos and det in det_pos:
|
||
x = [src_pos[src][0], det_pos[det][0]]
|
||
y = [src_pos[src][1], det_pos[det][1]]
|
||
style = '-' if pval <= P_THRESHOLD else '--'
|
||
ax.plot(x, y, linestyle=style, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2) # type: ignore
|
||
|
||
# Format the Colorbar
|
||
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
||
sm.set_array([])
|
||
cbar = plt.colorbar(sm, ax=ax, shrink=0.85) # type: ignore
|
||
cbar.set_label(f'Average {cond} {t_or_theta} value (hbo)', fontsize=11) # type: ignore
|
||
|
||
# Formatting the subplots
|
||
ax.set_aspect('equal')
|
||
ax.set_title(f"Average {t_or_theta} values for {cond} (HbO)", fontsize=14) # type: ignore
|
||
ax.set_xlabel('X position (m)', fontsize=11) # type: ignore
|
||
ax.set_ylabel('Y position (m)', fontsize=11) # type: ignore
|
||
ax.grid(True, alpha=0.3) # type: ignore
|
||
|
||
# Set axis limits to be 1cm more than the optode positions
|
||
all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()]
|
||
all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()]
|
||
ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01)
|
||
ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01)
|
||
|
||
fig.tight_layout()
|
||
|
||
fig_individual_significances.append((f"Condition {cond}", fig))
|
||
|
||
return fig_individual_significances
|
||
|
||
# TODO: Hardcoded
|
||
def group_significance(
|
||
raw_haemo,
|
||
all_cha: pd.DataFrame,
|
||
condition: str,
|
||
correction: str = "fdr_bh"
|
||
) -> plt.Figure:
|
||
"""
|
||
Compute group-level significance using weighted Stouffer's method and plot results.
|
||
|
||
Args:
|
||
raw_haemo: Raw haemoglobin MNE object (used for optode positions)
|
||
all_cha: DataFrame with columns including 'ID', 'Condition', 'p_value', 'theta', 'df', 'ch_name', 'Chroma'
|
||
condition: condition prefix, e.g., 'Activity'
|
||
correction: p-value correction method ('fdr_bh' or 'bonferroni')
|
||
|
||
Returns:
|
||
Matplotlib Figure with group-level theta values and significance.
|
||
"""
|
||
|
||
assert "ID" in all_cha.columns, "'ID' column missing in input data"
|
||
assert len(raw_haemo) >= 1, "At least one raw haemoglobin object is required"
|
||
|
||
condition_prefix = f"{condition}_delay"
|
||
|
||
# Filter relevant data
|
||
ch_summary = all_cha.query(
|
||
"Condition.str.startswith(@condition_prefix) and Chroma == 'hbo'",
|
||
engine='python'
|
||
).copy()
|
||
|
||
|
||
logger.info("=== ch_summary head ===")
|
||
logger.info(ch_summary.head())
|
||
|
||
logger.info("\nSummary stats:")
|
||
logger.info(f"Total rows: {len(ch_summary)}")
|
||
logger.info(f"Unique subjects: {ch_summary['ID'].nunique() if 'ID' in ch_summary.columns else 'ID column missing'}")
|
||
logger.info(f"Unique conditions: {ch_summary['Condition'].unique()}")
|
||
logger.info(f"Unique channels (Source-Detector pairs): {ch_summary.groupby(['Source', 'Detector']).ngroups}")
|
||
|
||
logger.info("\nSample p_values:")
|
||
logger.info(ch_summary['p_value'].describe())
|
||
|
||
if ch_summary.empty:
|
||
raise ValueError(f"No data found for condition prefix: {condition_prefix}")
|
||
|
||
# --- For debugging
|
||
logger.info(f"Total rows after filtering for condition '{condition_prefix}': {len(ch_summary)}")
|
||
logger.info(f"Unique channels: {ch_summary['ch_name'].nunique()}")
|
||
logger.info(f"Participants: {ch_summary['ID'].nunique()}")
|
||
|
||
# Step 1: Select the peak regressor (~6s after stimulus onset)
|
||
peak_regressor = f"{condition}_delay_6"
|
||
peak_data = ch_summary[ch_summary["Condition"] == peak_regressor].copy()
|
||
|
||
logger.info(f"\n=== Logging all values for {peak_regressor} ===")
|
||
for row in peak_data.itertuples(index=False):
|
||
logger.info(
|
||
f"Subject: {row.ID}, "
|
||
f"Channel: {row.ch_name}, "
|
||
f"Source: {row.Source}, Detector: {row.Detector}, "
|
||
f"theta: {row.theta:.4f}, "
|
||
f"p_value: {row.p_value:.6f}, "
|
||
f"df: {row.df}"
|
||
)
|
||
|
||
if peak_data.empty:
|
||
raise ValueError(f"No data found for peak regressor: {peak_regressor}")
|
||
|
||
# Step 2: Combine per-channel stats across subjects
|
||
group_results = []
|
||
|
||
for (src, det), group in peak_data.groupby(["Source", "Detector"]):
|
||
pvals = group["p_value"].values
|
||
thetas = group["theta"].values
|
||
dfs = group["df"].values
|
||
|
||
# Weighted Stouffer's method
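# z_i = Phi^-1(1 - p_i) with weights w_i = sqrt(df_i);
# Z = sum(w_i * z_i) / sqrt(sum(w_i^2)) and the combined p-value is 1 - Phi(Z).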
|
||
weights = np.sqrt(dfs)
|
||
z_scores = stats.norm.isf(pvals)
|
||
combined_z = np.sum(weights * z_scores) / np.sqrt(np.sum(weights**2))
|
||
combined_p = stats.norm.sf(combined_z)
|
||
|
||
theta_avg = np.average(thetas, weights=weights)
|
||
|
||
group_results.append({
|
||
"Source": src,
|
||
"Detector": det,
|
||
"theta_avg": theta_avg,
|
||
"combined_p": combined_p
|
||
})
|
||
|
||
# Step 3: Create combined_df
|
||
combined_df = pd.DataFrame(group_results)
|
||
|
||
# Step 4: Multiple comparisons correction
|
||
_, pvals_corr, _, significant = multipletests(
|
||
combined_df["combined_p"], alpha=0.05, method=correction
|
||
)
|
||
|
||
combined_df["pval_corr"] = pvals_corr
|
||
combined_df["significant"] = significant
|
||
|
||
logger.info(f"Used peak regressor: {peak_regressor}")
|
||
logger.info(f"Channels tested: {len(combined_df)}")
|
||
logger.info(f"Significant channels after correction: {combined_df['significant'].sum()}")
|
||
# Get optode positions from the first raw file
|
||
raw = raw_haemo
|
||
src_pos, det_pos = {}, {}
|
||
for ch in raw.info["chs"]:
|
||
ch_name = ch["ch_name"]
|
||
if not ch_name or not ch["loc"].any():
|
||
continue
|
||
parts = ch_name.split()[0]
|
||
src_str, det_str = parts.split("_")
|
||
src_num = int(src_str[1:])
|
||
det_num = int(det_str[1:])
|
||
src_pos[src_num] = ch["loc"][3:5]
|
||
det_pos[det_num] = ch["loc"][6:8]
|
||
|
||
# Plotting parameters
|
||
ABS_SIGNIFICANCE_THETA_VALUE = 1
|
||
P_THRESHOLD = 0.05
|
||
cmap = plt.get_cmap("seismic")
|
||
norm = mcolors.Normalize(vmin=-ABS_SIGNIFICANCE_THETA_VALUE, vmax=ABS_SIGNIFICANCE_THETA_VALUE)
|
||
|
||
fig, ax = plt.subplots(figsize=(8, 6))
|
||
|
||
# Plot optodes
|
||
for pos in src_pos.values():
|
||
ax.scatter(*pos, s=120, c="k", marker="o", edgecolors="white", linewidths=1, zorder=3)
|
||
for pos in det_pos.values():
|
||
ax.scatter(*pos, s=120, c="k", marker="s", edgecolors="white", linewidths=1, zorder=3)
|
||
|
||
# Plot connections colored by average theta, solid if significant
|
||
for row in combined_df.itertuples():
|
||
src, det = int(row.Source), int(row.Detector)
|
||
tval, pval = row.theta_avg, row.pval_corr
|
||
if src in src_pos and det in det_pos:
|
||
x = [src_pos[src][0], det_pos[det][0]]
|
||
y = [src_pos[src][1], det_pos[det][1]]
|
||
linestyle = "-" if pval <= P_THRESHOLD else "--"
|
||
ax.plot(x, y, linestyle=linestyle, color=cmap(norm(tval)), linewidth=4, alpha=0.9, zorder=2)
|
||
|
||
# Colorbar
|
||
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
||
sm.set_array([])
|
||
cbar = plt.colorbar(sm, ax=ax, shrink=0.85)
|
||
cbar.set_label(f"Average {condition_prefix.rstrip('_')} θ-value (HbO)", fontsize=11)
|
||
|
||
# Format axes
|
||
ax.set_aspect("equal")
|
||
ax.set_title(f"Group-level θ-values for {condition_prefix.rstrip('_')} (HbO)", fontsize=14)
|
||
ax.set_xlabel("X position (m)", fontsize=11)
|
||
ax.set_ylabel("Y position (m)", fontsize=11)
|
||
ax.grid(True, alpha=0.3)
|
||
|
||
all_x = [p[0] for p in src_pos.values()] + [p[0] for p in det_pos.values()]
|
||
all_y = [p[1] for p in src_pos.values()] + [p[1] for p in det_pos.values()]
|
||
ax.set_xlim(min(all_x) - 0.01, max(all_x) + 0.01)
|
||
ax.set_ylim(min(all_y) - 0.01, max(all_y) + 0.01)
|
||
|
||
fig.tight_layout()
|
||
fig.show()

return fig
|
||
|
||
|
||
|
||
|
||
def plot_glm_results(file_path, raw_haemo, glm_est, design_matrix):
|
||
|
||
fig_glms = [] # List to store figures
|
||
|
||
dm = design_matrix.copy()
|
||
logger.info(design_matrix.shape)
|
||
logger.info(design_matrix.columns)
|
||
logger.info(design_matrix.head())
|
||
|
||
rois = dict(AllChannels=range(len(raw_haemo.ch_names)))
|
||
conditions = design_matrix.columns
|
||
df_individual = glm_est.to_dataframe_region_of_interest(rois, conditions)
|
||
|
||
df_individual["ID"] = file_path
|
||
# df_individual["theta"] = [t * 1.0e6 for t in df_individual["theta"]]
|
||
|
||
first_onset_for_cond = {}
|
||
for onset, desc in zip(raw_haemo.annotations.onset, raw_haemo.annotations.description):
|
||
if desc not in first_onset_for_cond:
|
||
first_onset_for_cond[desc] = onset
|
||
|
||
# Get unique condition names from annotations (descriptions)
|
||
unique_annotations = set(raw_haemo.annotations.description)
|
||
|
||
for cond in unique_annotations:
|
||
logger.info(cond)
|
||
df_individual_filtered = df_individual.copy()
|
||
|
||
# Filter for the condition of interest and FIR delays
|
||
df_individual_filtered["isCondition"] = [cond in n for n in df_individual_filtered["Condition"]]
|
||
df_individual_filtered["isDelay"] = ["delay" in n for n in df_individual_filtered["Condition"]]
|
||
df_individual_filtered = df_individual_filtered.query("isDelay and isCondition")
|
||
|
||
# Remove other conditions from design matrix
|
||
dm_condition_cols = [col for col in dm.columns if cond in col]
|
||
dm_cond = dm[dm_condition_cols]
|
||
|
||
# Add a numeric delay column
|
||
def extract_delay_number(condition_str):
|
||
# Extracts the number at the end of a string like 'Activity_delay_5'
|
||
return int(condition_str.split("_")[-1])
|
||
|
||
df_individual_filtered["DelayNum"] = df_individual_filtered["Condition"].apply(extract_delay_number)
|
||
|
||
# Now separate and sort using numeric delay
|
||
df_hbo = df_individual_filtered[df_individual_filtered["Chroma"] == "hbo"].sort_values("DelayNum")
|
||
df_hbr = df_individual_filtered[df_individual_filtered["Chroma"] == "hbr"].sort_values("DelayNum")
|
||
|
||
vals_hbo = df_hbo["theta"].values
|
||
vals_hbr = df_hbr["theta"].values
|
||
|
||
# Create the plot
|
||
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(19, 10))
|
||
|
||
# Scale design matrix components using numpy arrays instead of pandas operations
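# Each FIR regressor column is multiplied by its GLM estimate, so the modelled
# evoked response plotted below is the sum over delays: sum_k theta_k * FIR_k(t).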
|
||
dm_cond_values = dm_cond.values
|
||
dm_cond_scaled_hbo = dm_cond_values * vals_hbo.reshape(1, -1)
|
||
dm_cond_scaled_hbr = dm_cond_values * vals_hbr.reshape(1, -1)
|
||
|
||
# Create time axis relative to stimulus onset
|
||
time = dm_cond.index - np.ceil(first_onset_for_cond.get(cond, 0))
|
||
|
||
# Plot
|
||
axes[0].plot(time, dm_cond_values)
|
||
axes[1].plot(time, dm_cond_scaled_hbo)
|
||
axes[2].plot(time, np.sum(dm_cond_scaled_hbo, axis=1), 'r')
|
||
axes[2].plot(time, np.sum(dm_cond_scaled_hbr, axis=1), 'b')
|
||
|
||
# Format plots
|
||
for ax in range(3):
|
||
axes[ax].set_xlim(-5, 25)
|
||
axes[ax].set_xlabel("Time (s)")
|
||
axes[0].set_ylim(-0.2, 1.2)
|
||
axes[1].set_ylim(-0.5, 1)
|
||
axes[2].set_ylim(-0.5, 1)
|
||
axes[0].set_title(f"FIR Model (Unscaled)")
|
||
axes[1].set_title(f"FIR Components (Scaled by {cond} GLM Estimates)")
|
||
axes[2].set_title(f"Evoked Response ({cond})")
|
||
axes[0].set_ylabel("FIR Model")
|
||
axes[1].set_ylabel("Oxyhaemoglobin (ΔμMol)")
|
||
axes[2].set_ylabel("Haemoglobin (ΔμMol)")
|
||
axes[2].legend(["Oxyhaemoglobin", "Deoxyhaemoglobin"])
|
||
|
||
|
||
print(f"Number of FIR bins: {len(vals_hbo)}")
|
||
print(f"Mean theta (HbO): {np.mean(vals_hbo):.4f}")
|
||
print(f"Sum of theta (HbO): {np.sum(vals_hbo):.4f}")
|
||
print(f"Mean theta (HbR): {np.mean(vals_hbr):.4f}")
|
||
print(f"Sum of theta (HbR): {np.sum(vals_hbr):.4f}")
|
||
|
||
fig_glms.append((f"Condition {cond}", fig))
|
||
|
||
return fig_glms
|
||
|
||
|
||
def plot_3d_evoked_array(
|
||
inst: Union[BaseRaw, EvokedArray, Info],
|
||
statsmodel_df: DataFrame,
|
||
picks: Optional[Union[str, list[str]]] = "hbo",
|
||
value: str = "Coef.",
|
||
background: str = "w",
|
||
figure: Optional[object] = None,
|
||
clim: Union[str, dict[str, Union[str, list[float]]]] = "auto",
|
||
mode: str = "weighted",
|
||
colormap: str = "RdBu_r",
|
||
surface: str = "pial",
|
||
hemi: str = "both",
|
||
size: Union[int, tuple[int, int]] = 800,
|
||
view: Optional[Union[str, dict[str, float]]] = None,
|
||
colorbar: bool = True,
|
||
distance: float = 0.03,
|
||
subjects_dir: Optional[str] = None,
|
||
src: Optional[SourceSpaces] = None,
|
||
verbose: bool = False,
|
||
) -> Brain:
|
||
'''Ported from MNE'''
|
||
|
||
info: Info = cast(Info, deepcopy(inst if isinstance(inst, Info) else inst.info)) # type: ignore
|
||
if not (getattr(info, "ch_names") == list(statsmodel_df["ch_name"].values)): # type: ignore
|
||
raise RuntimeError(
|
||
'MNE data structure does not match dataframe '
|
||
f'results.\nMNE = {getattr(info, "ch_names")}.\n'
|
||
f'GLM = {list(statsmodel_df["ch_name"].values)}' # type: ignore
|
||
)
|
||
|
||
ea = EvokedArray(np.tile(statsmodel_df[value].values.T, (1, 1)).T, info.copy()) # type: ignore
|
||
|
||
# TODO: mimic behaviour of other MNE-NIRS glm plotting options
|
||
if picks is not None:
|
||
ea = ea.pick(picks=picks) # type: ignore
|
||
|
||
if subjects_dir is None:
|
||
subjects_dir = os.environ["SUBJECTS_DIR"]
|
||
if src is None:
|
||
fname_src_fs = os.path.join(
|
||
subjects_dir, "fsaverage", "bem", "fsaverage-ico-5-src.fif"
|
||
)
|
||
src = read_source_spaces(fname_src_fs, verbose=verbose)
|
||
|
||
picks = getattr(ea, "info")["ch_names"]
|
||
|
||
# Set coord frame
|
||
for idx in range(len(getattr(ea, "ch_names"))):
|
||
getattr(ea, "info")["chs"][idx]["coord_frame"] = 4
|
||
|
||
# Generate source estimate
|
||
kwargs = dict(
|
||
evoked=ea,
|
||
subject="fsaverage",
|
||
trans=Transform('head', 'mri', np.eye(4)),
|
||
distance=distance,
|
||
mode=mode,
|
||
surface=surface,
|
||
subjects_dir=subjects_dir,
|
||
src=src,
|
||
project=True,
|
||
)
|
||
|
||
stc = stc_near_sensors(picks=picks, **kwargs, verbose=verbose) # type: ignore
|
||
|
||
assert isinstance(stc, SourceEstimate)
|
||
|
||
# Produce brain plot
|
||
brain: Brain = stc.plot( # type: ignore
|
||
src=src,
|
||
subjects_dir=subjects_dir,
|
||
hemi=hemi,
|
||
surface=surface,
|
||
initial_time=0,
|
||
clim=clim, # type: ignore
|
||
size=size,
|
||
colormap=colormap,
|
||
figure=figure,
|
||
background=background,
|
||
colorbar=colorbar,
|
||
verbose=verbose,
|
||
)
|
||
if view is not None:
|
||
brain.show_view(view) # type: ignore
|
||
|
||
return brain
|
||
|
||
|
||
|
||
def brain_3d_visualization(raw_haemo, df_cha, selected_event, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True, brain_bounds: float = 1.0) -> Brain:
|
||
|
||
clim = dict(kind="value", pos_lims=(0, brain_bounds/2, brain_bounds))
|
||
|
||
# Get all activity conditions
|
||
for cond in [f'{selected_event}']:
|
||
|
||
if True:
|
||
ch_summary = df_cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python') # type: ignore
|
||
|
||
# Use ordinary least squares (OLS) if only one participant
|
||
# TODO: Fix.
|
||
if True:
|
||
|
||
# t values
|
||
if t_or_theta == 't':
|
||
ch_model = smf.ols("t ~ -1 + ch_name", ch_summary).fit() # type: ignore
|
||
|
||
# theta values
|
||
elif t_or_theta == 'theta':
|
||
ch_model = smf.ols("theta ~ -1 + ch_name", ch_summary).fit() # type: ignore
|
||
|
||
print("OLS model is being used as there is only one participant!")
|
||
|
||
# Convert model results
|
||
model_df = cast(DataFrame, statsmodels_to_results(ch_model, order=ch_summary["ch_name"].unique())) # type: ignore
|
||
|
||
valid_channels = ch_summary["ch_name"].unique().tolist() # type: ignore
|
||
raw_for_plot = raw_haemo.copy().pick(picks=valid_channels) # type: ignore
|
||
|
||
brain = plot_3d_evoked_array(raw_for_plot.pick(picks="hbo"), model_df, view="dorsal", distance=0.02, colorbar=True, clim=clim, mode="weighted", size=(800, 700)) # type: ignore
|
||
|
||
if show_optodes == 'all' or show_optodes == 'sensors':
|
||
brain.add_sensors(getattr(raw_for_plot, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore
|
||
|
||
if True:
|
||
display_text = ('Folder: ' + '\nGroup: ' + '\nCondition: '+ cond + '\nShort Channel Regression: '
|
||
+ '\nLooking at: ' + t_or_theta + ' values') + '\nBrain Distance: '
|
||
|
||
# Apply the text onto the brain
|
||
if show_text:
|
||
brain.add_text(0.12, 0.64, display_text, "title", font_size=11, color="k") # type: ignore
|
||
|
||
return brain
|
||
|
||
|
||
|
||
|
||
def brain_landmarks_3d(raw_haemo: BaseRaw, show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_brodmann: bool = True) -> Brain:
|
||
|
||
brain = Brain("fsaverage", background="white", size=(800, 700)) # type: ignore
|
||
|
||
distances = source_detector_distances(raw_haemo.info)
|
||
|
||
# Add optode text labels manually
|
||
if show_optodes == 'all' or show_optodes == 'sensors':
|
||
brain.add_sensors(getattr(raw_haemo, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore
|
||
|
||
if show_optodes == 'all' or show_optodes == 'labels':
|
||
labeled_srcs = set()
|
||
labeled_dets = set()
|
||
label_counts = {}
|
||
|
||
for idx, ch in enumerate(raw_haemo.info['chs']):
|
||
ch_name = ch['ch_name']
|
||
|
||
if not ch_name.endswith('hbo'):
|
||
continue
|
||
|
||
loc = ch['loc']
|
||
logger.info(f"Channel: {ch_name}")
|
||
logger.info(f"loc length: {len(loc)}")
|
||
logger.info("loc contents:")
|
||
for i, val in enumerate(loc):
|
||
logger.info(f" loc[{i}]: {val}")
|
||
logger.info("-" * 30)
|
||
if not ch_name or not ch['loc'].any():
|
||
continue
|
||
|
||
parts = ch_name.split()[0]
|
||
src_str, det_str = parts.split('_')
|
||
|
||
src_num = int(src_str[1:])
|
||
det_num = int(det_str[1:])
|
||
|
||
if src_num not in labeled_srcs:
|
||
src_xyz = ch['loc'][3:6] * 1000
|
||
brain._renderer.text3d(src_xyz[0], src_xyz[1], src_xyz[2], src_str,
|
||
color='red', scale=0.002)
|
||
labeled_srcs.add(src_num)
|
||
|
||
if det_num not in labeled_dets:
|
||
det_xyz = ch['loc'][6:9] * 1000
|
||
brain._renderer.text3d(det_xyz[0], det_xyz[1], det_xyz[2], det_str,
|
||
color='blue', scale=0.002)
|
||
labeled_dets.add(det_num)
|
||
|
||
# Get the source-detector distance for this channel (in meters)
|
||
dist_m = distances[idx]
|
||
dist_mm = dist_m * 1000
|
||
|
||
label_text = f"{dist_mm:.1f} mm"
|
||
label_counts[label_text] = label_counts.get(label_text, 0) + 1
|
||
if label_counts[label_text] > 1:
|
||
label_text += f" ({label_counts[label_text]})"
|
||
|
||
# Label at channel midpoint
|
||
mid_xyz = loc[0:3] * 1000
|
||
|
||
logger.info(f"Channel: {ch_name} | Midpoint (mm): x={mid_xyz[0]:.2f}, y={mid_xyz[1]:.2f}, z={mid_xyz[2]:.2f} | Distance: {dist_mm:.1f} mm")
|
||
|
||
brain._renderer.text3d(
|
||
mid_xyz[0], mid_xyz[1], mid_xyz[2],
|
||
label_text,
|
||
color='gray',
|
||
scale=0.002
|
||
)
|
||
|
||
|
||
if show_brodmann:  # Add Brodmann labels
|
||
labels = cast(list[Label], read_labels_from_annot("fsaverage", "PALS_B12_Brodmann", "lh", verbose=False)) # type: ignore
|
||
|
||
label_colors = {
|
||
"Brodmann.1-lh": "red",
|
||
"Brodmann.2-lh": "red",
|
||
"Brodmann.3-lh": "red",
|
||
"Brodmann.4-lh": "orange",
|
||
"Brodmann.5-lh": "green",
|
||
"Brodmann.6-lh": "yellow",
|
||
"Brodmann.7-lh": "green",
|
||
"Brodmann.17-lh": "blue",
|
||
"Brodmann.18-lh": "blue",
|
||
"Brodmann.19-lh": "blue",
|
||
"Brodmann.39-lh": "pink",
|
||
"Brodmann.40-lh": "purple",
|
||
"Brodmann.42-lh": "white",
|
||
"Brodmann.44-lh": "white",
|
||
"Brodmann.48-lh": "white",
|
||
|
||
}
|
||
|
||
for label in labels:
|
||
name = getattr(label, "name", None)
|
||
if not isinstance(name, str):
|
||
continue
|
||
if name in label_colors:
|
||
brain.add_label(label, borders=False, color=label_colors[name]) # type: ignore
|
||
|
||
|
||
return brain
|
||
|
||
|
||
def verify_channel_positions(data: BaseRaw) -> None:
|
||
"""
|
||
Visualizes the sensor/channel positions of the raw data for verification.
|
||
|
||
Parameters
|
||
----------
|
||
data : BaseRaw
|
||
The loaded data object to process.
|
||
"""
|
||
|
||
def convert_fig_dict_to_png_bytes(fig_dict: dict[str, Figure]) -> dict[str, bytes]:
|
||
png_dict = {}
|
||
for label, fig in fig_dict.items():
|
||
buf = BytesIO()
|
||
fig.savefig(buf, format="png", bbox_inches="tight")
|
||
buf.seek(0)
|
||
png_dict[label] = buf.read()
|
||
plt.close(fig)
|
||
return png_dict
|
||
|
||
|
||
|
||
def brain_3d_contrast(con_model_df: DataFrame, con_model_df_filtered: BaseRaw, common_channels: list[str], first_name: str, second_name: str, t_or_theta: Literal['t', 'theta'] = 'theta', show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all', show_text: bool = True, brain_bounds: float = 1.0) -> None:
|
||
# Filter DataFrame to only common channels, and sort by raw order
|
||
con_model = con_model_df
|
||
|
||
con_model["ch_name"] = pd.Categorical(
|
||
con_model["ch_name"], categories=common_channels, ordered=True
|
||
)
|
||
con_model = con_model.sort_values("ch_name").reset_index(drop=True) # type: ignore
|
||
|
||
|
||
clim=dict(kind="value", pos_lims=(0, brain_bounds/2, brain_bounds))
|
||
|
||
# Plot brain figure
|
||
brain = plot_3d_evoked_array(con_model_df_filtered.copy().pick(picks="hbo"), con_model, view="dorsal", distance=0.02, colorbar=True, mode="weighted", clim=clim, size=(800, 700), verbose=False) # type: ignore
|
||
|
||
if show_optodes == 'all' or show_optodes == 'sensors':
|
||
brain.add_sensors(getattr(con_model_df_filtered, "info"), trans=Transform('head', 'mri', np.eye(4)), fnirs=["channels", "pairs", "sources", "detectors"], verbose=False) # type: ignore
|
||
|
||
display_text = ('Contrast: ' + first_name + ' - ' + second_name + '\nLooking at: ' + t_or_theta + ' values')
|
||
|
||
# Apply the text onto the brain
|
||
if show_text:
|
||
brain.add_text(0.12, 0.70, display_text, "title", font_size=11, color="k") # type: ignore
|
||
|
||
|
||
|
||
def plot_2d_3d_contrasts_between_groups(
|
||
contrast_df_a: pd.DataFrame,
|
||
contrast_df_b: pd.DataFrame,
|
||
raw_haemo: BaseRaw,
|
||
group_a_name: str,
|
||
group_b_name: str,
|
||
is_3d: bool = True,
|
||
t_or_theta: Literal['t', 'theta'] = 'theta',
|
||
show_optodes: Literal['sensors', 'labels', 'none', 'all'] = 'all',
|
||
show_text: bool = True,
|
||
brain_bounds: float = 1.0,
|
||
) -> None:
|
||
|
||
|
||
logger.info("-----")
|
||
contrast_df_a = contrast_df_a.copy()
|
||
contrast_df_a["group"] = group_a_name
|
||
contrast_df_b = contrast_df_b.copy()
|
||
contrast_df_b["group"] = group_b_name
|
||
logger.info("-----")
|
||
|
||
df_combined = pd.concat([contrast_df_a, contrast_df_b], ignore_index=True)
|
||
|
||
con_summary = df_combined.query("Chroma == 'hbo'").copy()
|
||
logger.info("-----")
|
||
|
||
valid_channels = (pd.crosstab(con_summary["group"], con_summary["ch_name"]) > 1).all()
|
||
valid_channels = valid_channels[valid_channels].index.tolist()
|
||
con_summary = con_summary[con_summary["ch_name"].isin(valid_channels)]
|
||
logger.info("-----")
|
||
|
||
model_formula = "effect ~ -1 + group:ch_name:Chroma"
|
||
con_model = smf.mixedlm(model_formula, con_summary, groups=con_summary["ID"]).fit(method="nm")
|
||
logger.info("-----")
|
||
|
||
if t_or_theta == "t":
|
||
group1_vals = con_model.tvalues.filter(like=f"group[{group_a_name}]")
|
||
group2_vals = con_model.tvalues.filter(like=f"group[{group_b_name}]")
|
||
else:
|
||
group1_vals = con_model.params.filter(like=f"group[{group_a_name}]")
|
||
group2_vals = con_model.params.filter(like=f"group[{group_b_name}]")
|
||
logger.info("-----")
|
||
|
||
group1_channels = [name.split(":")[1].split("[")[1].split("]")[0] for name in group1_vals.index]
|
||
group2_channels = [name.split(":")[1].split("[")[1].split("]")[0] for name in group2_vals.index]
|
||
|
||
df_group1 = DataFrame({"Coef.": group1_vals.values}, index=group1_channels)
|
||
df_group2 = DataFrame({"Coef.": group2_vals.values}, index=group2_channels)
|
||
|
||
df_contrast = df_group1.join(df_group2, how="inner", lsuffix=f"_{group_a_name}", rsuffix=f"_{group_b_name}")
|
||
logger.info("-----")
|
||
|
||
# A - B
|
||
df_contrast["Coef."] = df_contrast[f"Coef._{group_a_name}"] - df_contrast[f"Coef._{group_b_name}"]
|
||
con_model_df_1_2 = DataFrame({
|
||
"ch_name": df_contrast.index,
|
||
"Coef.": df_contrast["Coef."],
|
||
"Chroma": "hbo"
|
||
})
|
||
logger.info("-----")
|
||
|
||
mne_ch_names = raw_haemo.copy().pick(picks="hbo").ch_names
|
||
glm_ch_names = con_model_df_1_2["ch_name"].tolist()
|
||
common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names]
|
||
|
||
con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels)
|
||
con_model_df_1_2 = con_model_df_1_2.set_index("ch_name").loc[common_channels].reset_index()
|
||
logger.info("-----")
|
||
|
||
if is_3d:
|
||
brain_3d_contrast(
|
||
con_model_df_1_2,
|
||
con_model_df_filtered,
|
||
common_channels,
|
||
group_a_name,
|
||
group_b_name,
|
||
t_or_theta,
|
||
show_optodes,
|
||
show_text,
|
||
brain_bounds
|
||
)
|
||
else:
|
||
plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_1_2, names=True, res=128, vlim=(-brain_bounds, brain_bounds)) # type: ignore
|
||
|
||
# TODO: The title currently goes on the colorbar. Low priority
|
||
plt.title(f"Contrast: {group_a_name} vs {group_b_name}") # type: ignore
|
||
plt.show() # type: ignore
|
||
|
||
# plt.title(f"Contrast: {group_a_name} vs {group_b_name}")
|
||
# plt.show()
|
||
|
||
# B - A
|
||
df_contrast["Coef."] = df_contrast[f"Coef._{group_b_name}"] - df_contrast[f"Coef._{group_a_name}"]
|
||
con_model_df_2_1 = DataFrame({
|
||
"ch_name": df_contrast.index,
|
||
"Coef.": df_contrast["Coef."],
|
||
"Chroma": "hbo"
|
||
})
|
||
|
||
glm_ch_names = con_model_df_2_1["ch_name"].tolist()
|
||
common_channels = [ch for ch in mne_ch_names if ch in glm_ch_names]
|
||
|
||
con_model_df_filtered = raw_haemo.copy().pick(picks=common_channels)
|
||
con_model_df_2_1 = con_model_df_2_1.set_index("ch_name").loc[common_channels].reset_index()
|
||
|
||
if is_3d:
|
||
brain_3d_contrast(
|
||
con_model_df_2_1,
|
||
con_model_df_filtered,
|
||
common_channels,
|
||
group_b_name,
|
||
group_a_name,
|
||
t_or_theta,
|
||
show_optodes,
|
||
show_text,
|
||
brain_bounds
|
||
)
|
||
else:
|
||
plot_glm_group_topo(con_model_df_filtered.copy().pick(picks="hbo"), con_model_df_2_1, names=True, res=128, vlim=(-brain_bounds, brain_bounds)) # type: ignore
|
||
|
||
# TODO: The title currently goes on the colorbar. Low priority
|
||
plt.title(f"Contrast: {group_b_name} vs {group_a_name}") # type: ignore
|
||
plt.show() # type: ignore
|
||
|
||
|
||
|
||
|
||
def plot_fir_model_results(df, raw_haemo, dm, selected_event, l_bound, u_bound):
|
||
|
||
|
||
df["isActivity"] = [f"{selected_event}" in n for n in df["Condition"]]
|
||
df["isDelay"] = ["delay" in n for n in df["Condition"]]
|
||
df = df.query("isDelay in [True]")
|
||
df = df.query("isActivity in [True]")
|
||
# Make a new column that stores the condition name for tidier model below
|
||
df.loc[:, "TidyCond"] = ""
|
||
df.loc[df["isActivity"] == True, "TidyCond"] = f"{selected_event}" # noqa: E712
|
||
# Finally, extract the FIR delay in to its own column in data frame
|
||
df.loc[:, "delay"] = [n.split("_")[-1] for n in df.Condition]
|
||
|
||
# To simplify this example we will only look at the activity
|
||
# condition so we now remove the other conditions from the
|
||
# design matrix and GLM results
|
||
dm_cols_activity = np.where([f"{selected_event}" in c for c in dm.columns])[0]
|
||
dm = dm[[dm.columns[i] for i in dm_cols_activity]]
|
||
|
||
lme = smf.mixedlm("theta ~ -1 + delay:TidyCond:Chroma", df, groups=df["ID"]).fit()
|
||
|
||
df_sum = statsmodels_to_results(lme)
|
||
df_sum["delay"] = [int(n) for n in df_sum["delay"]]
|
||
df_sum = df_sum.sort_values("delay")
|
||
|
||
# Print the result for the oxyhaemoglobin data in the target condition
|
||
df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbo']")
|
||
|
||
|
||
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(19, 10))
|
||
|
||
print("dm columns:", dm.columns.tolist())
|
||
|
||
# Extract design matrix columns that correspond to the condition of interest
|
||
dm_cond_idxs = np.where([f"{selected_event}" in n for n in dm.columns])[0]
|
||
dm_cond = dm[[dm.columns[i] for i in dm_cond_idxs]]
|
||
|
||
# Extract the corresponding estimates from the lme dataframe for hbo
|
||
df_hbo = df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbo']")
|
||
vals_hbo = [float(v) for v in df_hbo["Coef."]]
|
||
|
||
# print("--------------------------------------")
|
||
# print(f"dm_cond shape: {dm_cond.shape}")
|
||
# print(f"dm_cond columns: {dm_cond.columns.tolist()}")
|
||
# print(f"vals_hbo length: {len(vals_hbo)}")
|
||
# print(f"vals_hbo sample: {vals_hbo[:5]}")
|
||
# print(f"vals_hbo type: {type(vals_hbo)}")
|
||
# print(f"vals_hbo element type: {type(vals_hbo[0]) if len(vals_hbo) > 0 else 'N/A'}")
|
||
|
||
dm_cond_scaled_hbo = dm_cond * vals_hbo
|
||
|
||
# Extract the corresponding estimates from the lme dataframe for hbr
|
||
df_hbr = df_sum.query(f"TidyCond in ['{selected_event}']").query("Chroma in ['hbr']")
|
||
vals_hbr = [float(v) for v in df_hbr["Coef."]]
|
||
dm_cond_scaled_hbr = dm_cond * vals_hbr
|
||
|
||
first_onset = None
|
||
for desc, onset in zip(raw_haemo.annotations.description, raw_haemo.annotations.onset):
|
||
if selected_event in desc:
|
||
first_onset = onset
|
||
break
|
||
|
||
if first_onset is None:
|
||
raise ValueError(f"Selected event '{selected_event}' not found in annotations.")
|
||
|
||
# Align index values (time axis) to the first occurrence of selected_event
|
||
index_values = dm_cond_scaled_hbo.index - np.ceil(first_onset)
|
||
index_values = np.asarray(index_values)
|
||
|
||
# Plot the result
|
||
axes[0].plot(index_values, np.asarray(dm_cond))
|
||
axes[1].plot(index_values, np.asarray(dm_cond_scaled_hbo))
|
||
axes[2].plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r")
|
||
axes[2].plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b")
|
||
|
||
valid_mask = (index_values >= 0) & (index_values <= 15)
|
||
hbo_sum_window = np.sum(dm_cond_scaled_hbo.loc[valid_mask, :], axis=1)
|
||
peak_idx_in_window = np.argmax(hbo_sum_window)
|
||
peak_idx = np.where(valid_mask)[0][peak_idx_in_window]
|
||
peak_time = float(round(index_values[peak_idx], 2)) # type: ignore
|
||
|
||
axes[2].axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time') # type: ignore
|
||
|
||
# Format the plot
|
||
for ax in range(3):
|
||
axes[ax].set_xlim(-5, 25)
|
||
axes[ax].set_xlabel("Time (s)")
|
||
axes[0].set_ylim(-0.1, 1.1)
|
||
axes[1].set_ylim(l_bound, u_bound)
|
||
axes[2].set_ylim(l_bound, u_bound)
|
||
axes[0].set_title("FIR Model (Unscaled by GLM estimates)")
|
||
axes[1].set_title(f"FIR Components (Scaled by {selected_event} GLM Estimates)")
|
||
axes[2].set_title(f"Evoked Response {selected_event}")
|
||
axes[0].set_ylabel("FIR Model")
|
||
axes[1].set_ylabel("Oyxhaemoglobin (ΔμMol)")
|
||
axes[2].set_ylabel("Haemoglobin (ΔμMol)")
|
||
axes[2].legend(["Oyxhaemoglobin", "Deoyxhaemoglobin"])
|
||
|
||
# We can also extract the 95% confidence intervals of the estimates too
|
||
l95_hbo = [float(v) for v in df_hbo["[0.025"]] # type: ignore
|
||
u95_hbo = [float(v) for v in df_hbo["0.975]"]] # type: ignore
|
||
dm_cond_scaled_hbo_l95 = dm_cond * l95_hbo
|
||
dm_cond_scaled_hbo_u95 = dm_cond * u95_hbo
|
||
l95_hbr = [float(v) for v in df_hbr["[0.025"]] # type: ignore
|
||
u95_hbr = [float(v) for v in df_hbr["0.975]"]] # type: ignore
|
||
dm_cond_scaled_hbr_l95 = dm_cond * l95_hbr
|
||
dm_cond_scaled_hbr_u95 = dm_cond * u95_hbr
|
||
|
||
axes2: Axes
|
||
fig2, axes2 = plt.subplots(nrows=1, ncols=1, figsize=(7, 7)) # type: ignore
|
||
|
||
# Plot the result
|
||
axes2.plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r") # type: ignore
|
||
axes2.plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b") # type: ignore
|
||
axes2.axvline(x=peak_time, color='k', linestyle='--', linewidth=1.5, label='Peak time') # type: ignore
|
||
|
||
axes2.fill_between( # type: ignore
|
||
index_values,
|
||
np.asarray(np.sum(dm_cond_scaled_hbo_l95, axis=1)),
|
||
np.asarray(np.sum(dm_cond_scaled_hbo_u95, axis=1)),
|
||
facecolor="red",
|
||
alpha=0.25,
|
||
)
|
||
axes2.fill_between( # type: ignore
|
||
index_values,
|
||
np.asarray(np.sum(dm_cond_scaled_hbr_l95, axis=1)),
|
||
np.asarray(np.sum(dm_cond_scaled_hbr_u95, axis=1)),
|
||
facecolor="blue",
|
||
alpha=0.25,
|
||
)
|
||
|
||
# Format the plot
|
||
axes2.set_xlim(-5, 20)
|
||
axes2.set_ylim(l_bound, u_bound)
|
||
axes2.set_title(f"Evoked Response with 95% confidence intervals for )") # type: ignore
|
||
axes2.set_ylabel("Haemoglobin (ΔμMol)") # type: ignore
|
||
axes2.legend(["Oyxhaemoglobin", "Deoyxhaemoglobin", f"Peak {peak_time}s"]) # type: ignore
|
||
axes2.set_xlabel("Time (s)") # type: ignore
|
||
|
||
fig2.tight_layout()
|
||
|
||
fig.show()
|
||
fig2.show()
|
||
|
||
|
||
|
||
def load_snirf(file_path: str) -> BaseRaw:
"""
Loads a snirf file, optionally downsamples it, and returns the loaded data.

Parameters
----------
file_path : str
Path of the snirf file to load.

Returns
-------
BaseRaw
The loaded (and optionally downsampled) data object.
"""
|
||
|
||
# Read the snirf file
|
||
raw = read_raw_snirf(file_path, preload=True, verbose=VERBOSITY) # type: ignore
|
||
raw.load_data(verbose=VERBOSITY) # type: ignore
|
||
|
||
# # Strip the specified amount of seconds from the start of the file
|
||
# total_duration = getattr(raw, "times")[-1]
|
||
# if total_duration > SECONDS_TO_STRIP:
|
||
# raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore
|
||
# logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.")
|
||
# else:
|
||
# logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.")
|
||
|
||
# If the user forcibly dropped channels, remove them now before any processing occurs
|
||
# logger.info("Checking if there are channels to forcibly drop...")
|
||
# if drop_prefixes:
|
||
# logger.info("Force dropped channels was specified.")
|
||
# channels_to_drop = [ch for ch in cast(list[str], getattr(raw, "ch_names")) if any(ch.startswith(prefix) for prefix in drop_prefixes)]
|
||
# raw.drop_channels(channels_to_drop, "raise") # type: ignore
|
||
# logger.info("Force dropped channels:", channels_to_drop)
|
||
|
||
# If the user wants to downsample, do it right away
|
||
logger.info("Checking if we should downsample...")
|
||
if DOWNSAMPLE:
|
||
logger.info("Downsample was specified.")
|
||
sfreq_old = getattr(raw, "info")["sfreq"]
|
||
raw.resample(DOWNSAMPLE_FREQUENCY, verbose=VERBOSITY) # type: ignore
|
||
sfreq_new = getattr(raw, "info")["sfreq"]
|
||
logger.info(f"Finished downsampling. Old frequency: {sfreq_old}. New frequency: {sfreq_new}.")
|
||
|
||
logger.info("Successfully loaded the snirf file.")
|
||
|
||
return raw
|
||
|
||
|
||
def run_second_level_analysis(df_contrasts, raw, p, bounds):
|
||
"""
|
||
Perform second-level analysis using contrast data from multiple participants.
|
||
|
||
Parameters
|
||
----------
|
||
df_contrasts : pd.DataFrame
|
||
Combined contrast results from multiple participants.
|
||
Must include: ['ch_name', 'effect', 'ID']
raw : BaseRaw
The loaded haemoglobin data, used to extract optode positions for plotting.
p : float
Significance threshold; connections with p-values at or below this value are drawn with a solid line.
bounds : float
Symmetric colour-scale limit used to normalise the plotted t-values.
|
||
|
||
Returns
|
||
-------
|
||
pd.DataFrame
|
||
Group-level t-values, p-values, and mean effect per channel.
|
||
"""
|
||
|
||
if not all(col in df_contrasts.columns for col in ['ch_name', 'effect', 'ID']):
|
||
raise ValueError("Input DataFrame must include 'ch_name', 'effect', and 'ID' columns.")
|
||
|
||
channels = df_contrasts['ch_name'].unique()
|
||
group_results = []
|
||
|
||
for ch in channels:
|
||
ch_data = df_contrasts[df_contrasts['ch_name'] == ch]
|
||
|
||
if ch_data['ID'].nunique() < 2:
|
||
logger.warning(f"Skipping channel {ch} — not enough subjects.")
|
||
continue
|
||
|
||
Y = ch_data['effect'].values
|
||
design_matrix = np.ones((len(Y), 1)) # intercept-only
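# An intercept-only design across subjects is equivalent to a one-sample
# t-test of the per-subject contrast effects against zero.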
|
||
model = OLSModel(design_matrix)
|
||
result = model.fit(Y)
|
||
|
||
t_val = result.t(0).item()
|
||
p_val = 2 * stats.t.sf(np.abs(t_val), df=len(Y) - 1)  # residual dof (n - 1) for the intercept-only model
|
||
mean_beta = np.mean(Y)
|
||
|
||
group_results.append({
|
||
'ch_name': ch,
|
||
't_val': t_val,
|
||
'p_val': p_val,
|
||
'mean_beta': mean_beta,
|
||
'n_subjects': len(Y)
|
||
})
|
||
|
||
df_group = pd.DataFrame(group_results)
|
||
logger.info("Second-level results:\n%s", df_group)
|
||
|
||
|
||
# Extract the source and detector positions from raw
|
||
src_pos: dict[int, tuple[float, float]] = {}
|
||
det_pos: dict[int, tuple[float, float]] = {}
|
||
for ch in getattr(raw, "info")["chs"]:
|
||
ch_name = ch['ch_name']
|
||
if not ch_name or not ch['loc'].any():
|
||
continue
|
||
parts = ch_name.split()[0]
|
||
src_str, det_str = parts.split('_')
|
||
src_num = int(src_str[1:])
|
||
det_num = int(det_str[1:])
|
||
src_pos[src_num] = ch['loc'][3:5]
|
||
det_pos[det_num] = ch['loc'][6:8]
|
||
|
||
# Set up the plot
|
||
fig, ax = plt.subplots(figsize=(8, 6)) # type: ignore
|
||
|
||
# Plot the sources
|
||
for pos in src_pos.values():
|
||
ax.scatter(pos[0], pos[1], s=120, c='k', marker='o', edgecolors='white', linewidths=1, zorder=3) # type: ignore
|
||
|
||
# Plot the detectors
|
||
for pos in det_pos.values():
|
||
ax.scatter(pos[0], pos[1], s=120, c='k', marker='s', edgecolors='white', linewidths=1, zorder=3) # type: ignore
|
||
|
||
# Ensure that the colors stay within the boundaries even if they are over or under the max/min values
|
||
norm = mcolors.Normalize(vmin=-bounds, vmax=bounds)
|
||
|
||
cmap: mcolors.Colormap = plt.get_cmap('seismic')
|
||
|
||
# Plot connections with avg t-values
|
||
for _, row in df_group.iterrows():
|
||
ch = row['ch_name']
|
||
pval = row['p_val']
|
||
tval = row['t_val']
|
||
|
||
if '_' not in ch:
|
||
logger.info(f"Skipping channel with unexpected format (no underscore): {ch}")
|
||
continue
|
||
|
||
src_str, det_str = ch.split('_')
|
||
det_parts = det_str.split()
|
||
detector_id = det_parts[0] # e.g. "D1"
|
||
hemo_type = det_parts[1].lower() if len(det_parts) > 1 else ''
|
||
|
||
logger.info(f"Parsing channel: {ch} -> src_str: {src_str}, det_str: {detector_id}, hemo_type: {hemo_type}")
|
||
|
||
if hemo_type != 'hbo':
|
||
logger.info(f"Skipping channel {ch} because hemo_type is not HbO: {hemo_type}")
|
||
continue
|
||
|
||
try:
|
||
src = int(src_str[1:])
|
||
det = int(detector_id[1:])
|
||
logger.info(f"Parsed src: {src}, det: {det}")
|
||
|
||
except Exception as e:
|
||
logger.info(f"Error parsing source/detector from channel '{ch}': {e}")
|
||
continue
|
||
|
||
if src in src_pos and det in det_pos:
|
||
x = [src_pos[src][0], det_pos[det][0]]
|
||
y = [src_pos[src][1], det_pos[det][1]]
|
||
style = '-' if pval <= p else '--'
|
||
color = cmap(norm(tval))
|
||
logger.info(f"Plotting {ch}: t={tval:.2f}, p={pval:.3f}, color={color}, style={style}")
|
||
ax.plot(x, y, linestyle=style, color=color, linewidth=4, alpha=0.9, zorder=2)
|
||
|
||
|
||
# Format the Colorbar
|
||
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
||
sm.set_array([])
|
||
cbar = plt.colorbar(sm, ax=ax, shrink=0.85) # type: ignore
|
||
cbar.set_label(f'Average value (hbo)', fontsize=11) # type: ignore
|
||
|
||
# Formatting the subplots
|
||
ax.set_aspect('equal')
|
||
ax.set_title(f"Average values (HbO)", fontsize=14) # type: ignore
|
||
ax.set_xlabel('X position (m)', fontsize=11) # type: ignore
|
||
ax.set_ylabel('Y position (m)', fontsize=11) # type: ignore
|
||
ax.grid(True, alpha=0.3) # type: ignore
|
||
|
||
# Set axis limits to be 1cm more than the optode positions
|
||
all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()]
|
||
all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()]
|
||
ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01)
|
||
ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01)
|
||
|
||
fig.tight_layout()
|
||
plt.show() # type: ignore
|
||
|
||
return df_group
|
||
|
||
|
||
def calculate_dpf(file_path):
|
||
# order is hbo / hbr
|
||
with h5py.File(file_path, 'r') as f:
|
||
wavelengths = f['/nirs/probe/wavelengths'][:]
|
||
logger.info(f"Wavelengths (nm): {wavelengths}")
|
||
wavelengths = sorted(wavelengths, reverse=True)
|
||
age = float(AGE)
|
||
logger.info(f"Their age was {AGE}")
|
||
a = 223.3
|
||
b = 0.05624
|
||
c = 0.8493
|
||
d = -5.723e-7
|
||
e = 0.001245
|
||
f = -0.9025
|
||
dpf = []
|
||
for w in wavelengths:
|
||
logger.info(w)
|
||
dpf.append(a + b * (age**c) + d* (w**3) + e * (w**2) + f*w)
|
||
logger.info(dpf)
|
||
return dpf
|
||
|
||
|
||
|
||
def iqr_threshold(coeffs: NDArray[float64], k: float = 1.5) -> floating[Any]:
|
||
|
||
"""
|
||
Calculate the interquartile range (IQR) threshold scaled by a factor, k.
|
||
|
||
Parameters
|
||
----------
|
||
coeffs : NDArray[float64]
|
||
Array of coefficients to compute the IQR from.
|
||
k : float, optional
|
||
Scaling factor for the IQR (default is 1.5).
|
||
|
||
Returns
|
||
-------
|
||
floating[Any]
|
||
The scaled IQR threshold value.
|
||
"""
|
||
|
||
# Calculate the IQR
|
||
q1 = np.percentile(coeffs, 25)
|
||
q3 = np.percentile(coeffs, 75)
|
||
iqr = q3 - q1
|
||
|
||
return k * iqr
|
||
|
||
|
||
|
||
def wavelet_iqr_denoise(signal: NDArray[float64], wavelet: str = 'db4', level: int = 3) -> NDArray[float64]:
|
||
"""
|
||
Denoises a signal using wavelet decomposition and IQR-based thresholding on detail coefficients.
|
||
|
||
Parameters
|
||
----------
|
||
signal : NDArray[float64]
|
||
The input signal array to denoise.
|
||
wavelet : str, optional
|
||
The type of wavelet to use for decomposition (default is 'db4').
|
||
level : int, optional
|
||
Decomposition level for wavelet transform (default is 3).
|
||
|
||
Returns
|
||
-------
|
||
NDArray[float64]
|
||
The denoised signal array, with the same length as the input.
|
||
"""
|
||
|
||
# Decompose the signal using wavelet transform and initialize a list with approximation coefficients
|
||
coeffs: list[NDArray[float64]] = pywt.wavedec(signal, wavelet, level=level) # type: ignore
|
||
cA = coeffs[0]
|
||
denoised_coeffs = [cA]
|
||
|
||
# Threshold detail coefficients to reduce noise
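# Soft-thresholding: each detail coefficient is shrunk toward zero by the
# IQR-derived threshold T, i.e. sign(c) * max(|c| - T, 0).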
|
||
for cD in coeffs[1:]:
|
||
threshold = iqr_threshold(cD, IQR)
|
||
cD_thresh = np.sign(cD) * np.maximum(np.abs(cD) - threshold, 0.0)
|
||
cD_thresh = cD_thresh.astype(float64)
|
||
denoised_coeffs.append(cD_thresh)
|
||
|
||
# Reconstruct the denoised signal
|
||
denoised_signal = cast(NDArray[float64], pywt.waverec(denoised_coeffs, wavelet)) # type: ignore
|
||
return denoised_signal[:len(signal)]
|
||
|
||
|
||
|
||
def calculate_and_apply_wavelet(data: BaseRaw) -> tuple[BaseRaw, Figure]:
|
||
"""
|
||
Applies a wavelet IQR denoising filter to the data and generates a plot.
|
||
|
||
Parameters
|
||
----------
|
||
data : BaseRaw
|
||
The loaded data object to process.
|
||
|
||
|
||
Returns
|
||
-------
|
||
tuple[BaseRaw, Figure]
|
||
- BaseRaw: The processed data object.
|
||
- Figure: The corresponding Matplotlib figure.
|
||
"""
|
||
|
||
logger.info("Applying the wavelet filter...")
|
||
|
||
# Denoise the data
|
||
logger.info("Denoising the data...")
|
||
loaded_data: NDArray[float64] = data.get_data(verbose=VERBOSITY) # type: ignore
|
||
denoised_data = np.zeros_like(loaded_data)
|
||
|
||
logger.info("Calculating the IQR, decomposing the signal, and thresholding the coefficients...")
|
||
for ch in range(loaded_data.shape[0]):
|
||
denoised_data[ch, :] = wavelet_iqr_denoise(loaded_data[ch, :], wavelet=WAVELET_TYPE, level=WAVELET_LEVEL)
|
||
|
||
# Reconstruct the data with the annotations
|
||
logger.info("Reconstructing the data with annotations...")
|
||
raw_with_tddr_and_wavelet = RawArray(denoised_data, cast(Info, data.info), verbose=VERBOSITY)
|
||
raw_with_tddr_and_wavelet.set_annotations(data.annotations.copy(), verbose=VERBOSITY) # type: ignore
|
||
|
||
# Create a figure for the results
|
||
logger.info("Creating the figure...")
|
||
fig = cast(Figure, raw_with_tddr_and_wavelet.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=data.times[-1]).figure) # type: ignore
|
||
fig.suptitle(f"Wavelet for ", fontsize=16) # type: ignore
|
||
fig.subplots_adjust(top=0.92)
|
||
plt.close(fig)
|
||
|
||
logger.info("Successfully applied the wavelet filter.")
|
||
|
||
return raw_with_tddr_and_wavelet, fig
|
||
|
||
|
||
|
||
def short_channel_processing_for_hr(data: BaseRaw, short_chans: BaseRaw | None) -> tuple[float, NDArray[float64], NDArray[float64]]:
|
||
"""
|
||
Extract and trim short-channel fNIRS signal for heart rate analysis.
|
||
|
||
Parameters
|
||
----------
|
||
data : BaseRaw
|
||
The loaded data object to process.
|
||
short_chans : BaseRaw | None
|
||
Data object with only short separation channels, or None if unavailable.
|
||
|
||
Returns
|
||
-------
|
||
tuple[float, NDArray[float64], NDArray[float64]]
|
||
- float: Sampling frequency of the signal.
|
||
- NDArray[float64]: Trimmed short-channel signal.
|
||
- NDArray[float64]: Corresponding time values.
|
||
"""
|
||
|
||
# Find the short channel (or best candidate) and extract signal data and sampling frequency
|
||
logger.info("Extracting the signal and calculating the sampling frequency...")
|
||
|
||
# If a short channel exists, use it for our signal. Otherwise just take the first channel in the data
|
||
# TODO: Find a better way around this
|
||
if short_chans is not None:
|
||
signal = cast(NDArray[float64], short_chans.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore
|
||
else:
|
||
signal = cast(NDArray[float64], data.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore
|
||
|
||
# Calculate the sampling frequency
|
||
sfreq = cast(float, data.info['sfreq'])
|
||
|
||
# Trim start and end of the signal to remove edge artifacts
|
||
logger.info(f"Removing {SECONDS_TO_STRIP_HR} seconds from the beginning and end of the file...")
|
||
strip_samples = int(sfreq * SECONDS_TO_STRIP_HR)
|
||
signal_trimmed = signal[strip_samples:-strip_samples]
|
||
times_trimmed = cast(NDArray[float64], getattr(data, "times"))[strip_samples:-strip_samples]
|
||
|
||
return sfreq, signal_trimmed, times_trimmed
|
||
|
||
|
||
|
||
def calculate_heart_rate_neurokit(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[float64], float]:
|
||
"""
|
||
Calculate and smooth heart rate from a trimmed signal using NeuroKit.
|
||
|
||
Parameters
|
||
----------
|
||
sfreq : float
|
||
Sampling frequency of the signal.
|
||
signal_trimmed : NDArray[float64]
|
||
Preprocessed and trimmed fNIRS signal.
|
||
|
||
Returns
|
||
-------
|
||
tuple[NDArray[float64], float]
|
||
- NDArray[float64]: Smoothed heart rate time series (BPM).
|
||
- float: Mean heart rate.
|
||
"""
|
||
|
||
logger.info("Calculating heart rate using NeuroKit...")
|
||
|
||
# Filter signal to isolate heart rate frequencies and detect peaks
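# The 0.8-2.5 Hz band-pass corresponds to roughly 48-150 beats per minute.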
|
||
logger.info("Filtering the signal and detecting peaks...")
|
||
signal_filtered = cast(NDArray[float64], nk.signal_filter(signal_trimmed, sampling_rate=sfreq, lowcut=0.8, highcut=2.5)) # type: ignore
|
||
peaks_dict = cast(dict[str, Any], nk.signal_findpeaks(signal_filtered)) # type: ignore
|
||
peaks = peaks_dict['Peaks']
|
||
hr = cast(NDArray[float64], nk.signal_rate(peaks, sampling_rate=sfreq, desired_length=len(signal_trimmed))) # type: ignore
|
||
hr_clean = np.clip(hr, MAX_LOW_HR, MAX_HIGH_HR)
|
||
|
||
# Smooth heart rate time series by replacing spikes with local rolling mean and calculate the mean
|
||
logger.info("Smoothing the signal and calculating the mean...")
|
||
hr_series = pd.Series(hr_clean)
|
||
local_median = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).median()
|
||
spikes = hr_series > (local_median + 10)
|
||
smoothed_values = hr_series.copy()
|
||
smoothed_spikes = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).mean()
|
||
smoothed_values[spikes] = smoothed_spikes[spikes]
|
||
hr_smooth_nk = cast(NDArray[float64], smoothed_values.to_numpy()) # type: ignore
|
||
mean_hr_nk = hr_smooth_nk.mean()
|
||
|
||
logger.info("Original HR min/max: %f, %f", hr_clean.min(), hr_clean.max())
|
||
logger.info("Smoothed HR min/max:%f, %f", hr_smooth_nk.min(), hr_smooth_nk.max())
|
||
logger.info(f"Estimated mean HR nk: {mean_hr_nk:.1f} BPM")
|
||
|
||
logger.info("Successfully calculated heart rate using NeuroKit.")
|
||
|
||
return hr_smooth_nk, mean_hr_nk
|
||
|
||
|
||
|
||
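# Illustrative sketch for calculate_heart_rate_neurokit above (purely synthetic input; the
# numbers are assumptions): the same NeuroKit chain applied to a clean 1.2 Hz sine sampled at
# 10 Hz should recover roughly 72 BPM.
#   t = np.arange(0, 60, 1 / 10.0)
#   fake_pulse = np.sin(2 * np.pi * 1.2 * t)
#   pk = nk.signal_findpeaks(nk.signal_filter(fake_pulse, sampling_rate=10.0, lowcut=0.8, highcut=2.5))
#   bpm = nk.signal_rate(pk['Peaks'], sampling_rate=10.0, desired_length=len(fake_pulse)).mean()  # ~72
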
def calculate_heart_rate_scipy(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]:
    """
    Estimate heart rate using spectral analysis on a high-pass filtered signal.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of the input signal.
    signal_trimmed : NDArray[float64]
        Trimmed fNIRS signal to analyze.

    Returns
    -------
    tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]
        - NDArray[floating[Any]]: Frequencies converted to beats per minute (BPM).
        - NDArray[float64]: Power spectral density (PSD) of the signal.
        - np.ndarray[Any, np.dtype[np.bool_]]: Boolean mask indicating frequencies within the heart rate range (30-300 BPM).
        - float: Estimated mean heart rate in BPM corresponding to the PSD peak within that range.
    """

    logger.info("Calculating heart rate using SciPy...")

    # Apply a high-pass Butterworth filter to remove slow trends below 0.5 Hz from the trimmed signal (actual data)
    logger.info("Applying a Butterworth filter...")
    b, a = cast(tuple[NDArray[float64], NDArray[float64]], butter(2, 0.5 / (sfreq / 2), btype='high'))
    signal_hp = cast(NDArray[float64], filtfilt(b, a, signal_trimmed))

    # Calculate the Power Spectral Density (PSD) of the filtered signal using Welch's method
    logger.info("Calculating the PSD...")
    nperseg = min(len(signal_hp), 4096)
    frequencies_scipy, psd_scipy = cast(tuple[NDArray[float64], NDArray[float64]], welch(signal_hp, fs=sfreq, nperseg=nperseg, noverlap=nperseg // 2))

    # Convert frequency values to beats per minute (BPM) and set a heart rate range (30-300 BPM)
    logger.info("Converting to BPM...")
    freq_bpm_scipy = frequencies_scipy * 60
    freq_range_scipy = (freq_bpm_scipy > 30) & (freq_bpm_scipy < 300)

    # Identify the peak frequency within the heart rate range and estimate the mean heart rate in BPM
    logger.info("Finding the mean...")
    peak_index = np.argmax(psd_scipy[freq_range_scipy])
    mean_hr_scipy = freq_bpm_scipy[freq_range_scipy][peak_index]

    logger.info("Successfully calculated heart rate using SciPy.")

    return freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy


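# Illustrative note on calculate_heart_rate_scipy above (assumed numbers): the BPM resolution
# of the Welch estimate is roughly 60 * sfreq / nperseg. At sfreq = 7.8 Hz with nperseg = 4096
# the spectrum is sampled about every 60 * 7.8 / 4096 ≈ 0.11 BPM, whereas a short recording of
# only 1024 samples (nperseg falls back to the signal length) gives ≈ 0.46 BPM bins.
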
def plot_heart_rate(
    freq_bpm_scipy: NDArray[floating[Any]],
    psd_scipy: NDArray[float64],
    freq_range_scipy: np.ndarray[Any, np.dtype[np.bool_]],
    mean_hr_scipy: float,
    hr_smooth_nk: NDArray[floating[Any]],
    mean_hr_nk: float,
    times_trimmed: NDArray[floating[Any]],
    overruled: bool
) -> tuple[Figure, Figure]:
    """
    Generate plots comparing heart rate estimates from SciPy PSD and NeuroKit2.

    Parameters
    ----------
    freq_bpm_scipy : NDArray[floating[Any]]
        Frequencies in beats per minute from SciPy PSD analysis.
    psd_scipy : NDArray[float64]
        Power spectral density values corresponding to freq_bpm_scipy.
    freq_range_scipy : np.ndarray[Any, np.dtype[np.bool_]]
        Boolean mask indicating the heart rate frequency range used in the PSD.
    mean_hr_scipy : float
        Mean heart rate estimated from the SciPy PSD peak.
    hr_smooth_nk : NDArray[floating[Any]]
        Smoothed instantaneous heart rate from NeuroKit2.
    mean_hr_nk : float
        Mean heart rate estimated from NeuroKit2 data.
    times_trimmed : NDArray[floating[Any]]
        Time points corresponding to hr_smooth_nk values.
    overruled : bool
        True if the heart rate from NeuroKit2 is overriding the results from the PSD.

    Returns
    -------
    tuple[Figure, Figure]
        - Figure showing the PSD and SciPy heart rate estimate.
        - Figure showing the time series comparison of heart rates.
    """

    # Create the first plot for the PSD. Add a yellow range to show what we will be filtering to.
    logger.info("Creating the figure...")
    fig1, ax1 = plt.subplots(figsize=(10, 5))  # type: ignore
    ax1.set_xlim(30, 300)
    ax1.plot(freq_bpm_scipy[freq_range_scipy], psd_scipy[freq_range_scipy])  # type: ignore
    ax1.axvline(x=mean_hr_scipy, color='red', linestyle='--', label=f'Mean HR: {mean_hr_scipy:.1f} BPM')  # type: ignore
    ax1.axvspan(min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW), max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW), color='yellow', alpha=0.3, label=f'HR Range ±{HEART_RATE_WINDOW} BPM')  # type: ignore
    ax1.set_xlabel('Heart Rate (BPM)')  # type: ignore
    ax1.set_ylabel('Power Spectral Density')  # type: ignore
    ax1.set_title('PSD of fNIRS signal - Peak indicates Heart Rate')  # type: ignore
    ax1.grid(True)  # type: ignore

    # Was the value we reported here correct for the data on the graph, or was it overruled?
    if overruled:
        note = (
            '\n'
            'Note: Calculation was bad!\n'
            'Data has been set to match\n'
            'the value from NeuroKit2.'
        )
        phantom = Line2D([0], [0], color='none', label=note)
        handles, _ = ax1.get_legend_handles_labels()
        ax1.legend(handles=handles + [phantom])  # type: ignore
    else:
        ax1.legend()  # type: ignore
    plt.close(fig1)

    # Create the second plot showing the rolling heart rate, as well as the two averages that were calculated
    logger.info("Creating the figure...")
    fig2, ax2 = plt.subplots(figsize=(14, 6))  # type: ignore
    ax2.plot(times_trimmed, hr_smooth_nk, label='Instantaneous HR (NeuroKit2)', color='blue', alpha=0.7)  # type: ignore
    ax2.axhline(mean_hr_nk, color='red', linestyle='--', label=f'Mean HR NeuroKit2: {mean_hr_nk:.1f} BPM')  # type: ignore
    ax2.axhline(mean_hr_scipy, color='orange', linestyle=':', label=f'SciPy Welch PSD (HP filtered): {mean_hr_scipy:.1f} BPM')  # type: ignore
    ax2.set_xlabel('Time (seconds)')  # type: ignore
    ax2.set_ylabel('Heart Rate (BPM)')  # type: ignore
    ax2.set_title('Heart Rate Estimates Comparison')  # type: ignore
    ax2.legend()  # type: ignore
    ax2.grid(True)  # type: ignore
    fig2.tight_layout()
    plt.close(fig2)

    return fig1, fig2


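# Illustrative note on the legend trick in plot_heart_rate above: the "phantom" entry is an
# invisible Line2D handle whose label carries free text, which lets a note ride inside the
# legend box. A minimal sketch:
#   note_handle = Line2D([0], [0], color='none', label='any text here')
#   ax.legend(handles=ax.get_legend_handles_labels()[0] + [note_handle])
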
def hr_calc(raw):
    """
    Estimate the participant's heart rate and derive a cardiac frequency band.

    Combines three estimates: NeuroKit2 peak detection on a short (or first) channel,
    a Welch PSD peak on the same channel, and the median of per-channel PSD peaks
    across all channels. Returns the raw PSD figure, the two heart-rate figures, and
    the lower/upper bounds (Hz) of a ±0.3 Hz band centred on the estimated cardiac
    frequency.
    """
    if SHORT_CHANNEL:
        short_chans = get_short_channels(raw, max_dist=SHORT_CHANNEL_THRESH)
    else:
        short_chans = None
    sfreq, signal_trimmed, times_trimmed = short_channel_processing_for_hr(raw, short_chans)
    hr_smooth_nk, mean_hr_nk = calculate_heart_rate_neurokit(sfreq, signal_trimmed)
    freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy = calculate_heart_rate_scipy(sfreq, signal_trimmed)

    # HACK: This sucks, but looking at the graphs I trust NeuroKit2 more.
    # If the PSD estimate disagrees with NeuroKit2 by more than 15 BPM in either direction,
    # overrule it with the NeuroKit2 value.
    overruled = False
    if mean_hr_scipy < mean_hr_nk - 15:
        mean_hr_scipy = mean_hr_nk
        overruled = True
    if mean_hr_scipy > mean_hr_nk + 15:
        mean_hr_scipy = mean_hr_nk
        overruled = True

    hr1, hr2 = plot_heart_rate(freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy, hr_smooth_nk, mean_hr_nk, times_trimmed, overruled)

    fig = raw.plot_psd(show=False)
    raw_filtered = raw.copy().filter(0.5, 3, fir_design='firwin')
    sfreq = raw.info['sfreq']
    data = raw_filtered.get_data()
    channel_names = raw.ch_names

    # --- Parameters for the per-channel PSD ---
    desired_bin_hz = 0.1  # 0.1 Hz bins, i.e. roughly 6 BPM resolution
    nperseg = int(sfreq / desired_bin_hz)
    hr_range = (30, 180)

    # --- Function to find the strongest local peak within the cardiac band ---
    def find_hr_from_psd(ch_data):
        f, Pxx = welch(ch_data, sfreq, nperseg=nperseg)
        mask = (f >= hr_range[0] / 60) & (f <= hr_range[1] / 60)
        f_masked = f[mask]
        Pxx_masked = Pxx[mask]
        if len(Pxx_masked) < 3:
            return np.nan
        peaks = [i for i in range(1, len(Pxx_masked) - 1)
                 if Pxx_masked[i] > Pxx_masked[i - 1] and Pxx_masked[i] > Pxx_masked[i + 1]]
        if not peaks:
            return np.nan
        best_idx = peaks[np.argmax([Pxx_masked[i] for i in peaks])]
        return f_masked[best_idx] * 60  # bpm

    # --- Compute HR across all channels ---
    hr_all_channels = np.array([find_hr_from_psd(data[i, :]) for i in range(len(channel_names))])
    hr_all_channels = hr_all_channels[~np.isnan(hr_all_channels)]
    hr_mode = np.round(np.median(hr_all_channels))  # Median across channels is robust to outlier channels

    logger.info(f"Estimated Heart Rate: {hr_mode} bpm")

    # Convert to Hz and build a ±0.3 Hz band around the estimated cardiac frequency
    hr_freq = hr_mode / 60  # Hz
    low = hr_freq - 0.3
    high = hr_freq + 0.3
    return fig, hr1, hr2, low, high


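# Illustrative usage of hr_calc above (the numbers are assumptions): a median per-channel
# estimate of 72 BPM gives hr_freq = 1.2 Hz, so the cardiac band handed to the SCI step is
# roughly 0.9-1.5 Hz.
#   fig, hr1, hr2, low, high = hr_calc(raw)
#   bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw, low, high)
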
def process_participant(file_path, progress_callback=None):
    """
    Run the full single-participant pipeline on one SNIRF file.

    Loads the recording, performs quality control and preprocessing, fits the GLM,
    and returns the processed data, epochs, rendered figures, and per-channel /
    per-condition statistics. If given, progress_callback is called with the number
    of the step that has just completed.
    """

    fig_individual: dict[str, Figure] = {}

    # Step 1: Load
    raw = load_snirf(file_path)
    fig_raw = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Loaded Raw", show=False)
    fig_individual["Loaded Raw"] = fig_raw
    if progress_callback: progress_callback(1)
    logger.info("1")

    # Step 2: Trim to shortly before the first event
    if TRIM:
        if hasattr(raw, 'annotations') and len(raw.annotations) > 0:
            # Get the time of the first event
            first_event_time = raw.annotations.onset[0]
            trim_time = max(0, first_event_time - SECONDS_TO_KEEP)  # Ensure we don't go negative
            raw.crop(tmin=trim_time)

            # Shift annotation onsets to match the new t=0
            # (e.g. an event at 30 s with trim_time = 25 s now has onset 5 s)
            ann = raw.annotations
            ann_shifted = Annotations(
                onset=ann.onset - trim_time,  # shift to start at zero
                duration=ann.duration,
                description=ann.description
            )
            data = raw.get_data()
            info = raw.info.copy()
            raw = RawArray(data, info)
            raw.set_annotations(ann_shifted)

            logger.info(f"Trimmed raw data: start at {trim_time}s ({SECONDS_TO_KEEP}s before first event), t=0 at new start")
        else:
            logger.warning("No events found, skipping trim step.")

    fig_trimmed = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Trimmed Raw", show=False)
    fig_individual["Trimmed Raw"] = fig_trimmed
    if progress_callback: progress_callback(2)
    logger.info("2")

    # Step 3: Verify optode positions
    if OPTODE_PLACEMENT:
        fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False)  # type: ignore
        fig_individual["Plot Sensors"] = fig_optodes
    if progress_callback: progress_callback(3)
    logger.info("3")

    # Step 4: Estimate heart rate (also sets the cardiac band used by the SCI step)
    # Fallback band in case heart-rate estimation is disabled (0.7-1.5 Hz is the conventional cardiac range)
    low, high = 0.7, 1.5
    if HEART_RATE:
        fig, hr1, hr2, low, high = hr_calc(raw)
        fig_individual["PSD"] = fig
        fig_individual['HeartRate_PSD'] = hr1
        fig_individual['HeartRate_Time'] = hr2
    if progress_callback: progress_callback(4)
    logger.info("4")

    # Step 5: Bad channels from SCI
    bad_sci = []
    if SCI:
        bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw, low, high)
        fig_individual["SCI1"] = fig_sci_1
        fig_individual["SCI2"] = fig_sci_2
    if progress_callback: progress_callback(5)
    logger.info("5")

    # Step 6: Bad channels from SNR
    bad_snr = []
    if SNR:
        bad_snr, fig_snr = calculate_signal_noise_ratio(raw)
        fig_individual["SNR1"] = fig_snr
    if progress_callback: progress_callback(6)
    logger.info("6")

    # Step 7: Bad channels from PSP
    bad_psp = []
    if PSP:
        bad_psp, fig_psp1, fig_psp2 = calculate_peak_power(raw)
        fig_individual["PSP1"] = fig_psp1
        fig_individual["PSP2"] = fig_psp2
    if progress_callback: progress_callback(7)
    logger.info("7")

    # Step 8: Mark the bad channels
    raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp)
    if fig_dropped and fig_raw_before is not None:
        fig_individual["fig2"] = fig_dropped
        fig_individual["fig3"] = fig_raw_before
    if progress_callback: progress_callback(8)
    logger.info("8")

    # Step 9: Interpolate the bad channels
    if bad_channels:
        raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, bad_channels)
        fig_individual["fig4"] = fig_raw_after
    if progress_callback: progress_callback(9)
    logger.info("9")

    # Step 10: Optical density
    raw_od = optical_density(raw)
    fig_raw_od = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Optical Density", show=False)
    fig_individual["Optical Density"] = fig_raw_od
    if progress_callback: progress_callback(10)
    logger.info("10")

    # Step 11: TDDR (motion correction)
    if TDDR:
        raw_od = temporal_derivative_distribution_repair(raw_od)
        fig_raw_od_tddr = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After TDDR (Motion Correction)", show=False)
        fig_individual["TDDR"] = fig_raw_od_tddr
    if progress_callback: progress_callback(11)
    logger.info("11")

    # Step 12: Wavelet filtering
    if WAVELET:
        raw_od, fig = calculate_and_apply_wavelet(raw_od)
        fig_individual["Wavelet"] = fig
    if progress_callback: progress_callback(12)
    logger.info("12")

    # Step 13: Beer-Lambert law
    raw_haemo = beer_lambert_law(raw_od, ppf=calculate_dpf(file_path))
    fig_raw_haemo_bll = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False)
    fig_individual["BLL"] = fig_raw_haemo_bll
    if progress_callback: progress_callback(13)
    logger.info("13")

    # Step 14: Enhance negative correlation
    if ENHANCE_NEGATIVE_CORRELATION:
        raw_haemo = enhance_negative_correlation(raw_haemo)
        fig_raw_haemo_enc = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False)
        fig_individual["ENC"] = fig_raw_haemo_enc
    if progress_callback: progress_callback(14)
    logger.info("14")

    # Step 15: Filter
    if FILTER:
        raw_haemo, fig_filter, fig_raw_haemo_filter = filter_the_data(raw_haemo)
        fig_individual["filter1"] = fig_filter
        fig_individual["filter2"] = fig_raw_haemo_filter
    if progress_callback: progress_callback(15)
    logger.info("15")

    # Step 16: Get short / long channels
    if SHORT_CHANNEL:
        short_chans = get_short_channels(raw_haemo, max_dist=SHORT_CHANNEL_THRESH)
        fig_short_chans = short_chans.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Short Channels Only", show=False)
        fig_individual["short"] = fig_short_chans
    else:
        short_chans = None
    raw_haemo = get_long_channels(raw_haemo, min_dist=SHORT_CHANNEL_THRESH, max_dist=LONG_CHANNEL_THRESH)
    if progress_callback: progress_callback(16)
    logger.info("16")

    # Step 17: Events from annotations
    events, event_dict = events_from_annotations(raw_haemo)
    fig_events = plot_events(events, event_id=event_dict, sfreq=raw_haemo.info["sfreq"], show=False)
    fig_individual["events"] = fig_events
    if progress_callback: progress_callback(17)
    logger.info("17")

    # Step 18: Epoch calculations
    epochs, fig_epochs = epochs_calculations(raw_haemo, events, event_dict)
    for name, fig in fig_epochs:  # Each entry is a (name, figure) tuple
        fig_individual[f"epochs_{name}"] = fig  # Store only the figure, keyed by its name
    if progress_callback: progress_callback(18)
    logger.info("18")

    # Step 19: Design matrix
    # Drop unwanted events before building the design matrix
    events_to_remove = REMOVE_EVENTS

    filtered_annotations = [ann for ann in raw.annotations if ann['description'] not in events_to_remove]

    new_annot = Annotations(
        onset=[ann['onset'] for ann in filtered_annotations],
        duration=[ann['duration'] for ann in filtered_annotations],
        description=[ann['description'] for ann in filtered_annotations]
    )

    # Set the new annotations
    raw_haemo.set_annotations(new_annot)

    design_matrix, fig_design_matrix = make_design_matrix(raw_haemo, short_chans)
    fig_individual["Design Matrix"] = fig_design_matrix
    if progress_callback: progress_callback(19)
    logger.info("19")

    # Step 20: Run GLM
    glm_est = run_glm(raw_haemo, design_matrix)
    # Not used: AppData\Local\Packages\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\LocalCache\local-packages\Python313\site-packages\nilearn\glm\contrasts.py
    # Yes used: AppData\Local\Packages\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\LocalCache\local-packages\Python313\site-packages\mne_nirs\utils\_io.py

    # The p-value is calculated from the t-statistic using Student's t-distribution with the
    # appropriate degrees of freedom:
    #     p_value = 2 * stats.t.cdf(-abs(t_statistic), df)
    # It is a two-tailed p-value: it says how likely it is to observe the effect you did (or
    # something more extreme) if the true effect were zero (the null hypothesis).
    # A small p-value (e.g., < 0.05) suggests the effect is unlikely to be zero - it is
    # "statistically significant." A large p-value means the data do not provide strong
    # evidence that the effect differs from zero.

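    # Illustrative check (assumed numbers, not derived from any dataset): with t = 2.0 and
    # df = 30 the two-tailed p-value is
    #     p = 2 * stats.t.cdf(-abs(2.0), 30)   # ~0.055
    # which would narrowly miss the conventional 0.05 threshold.
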
    if progress_callback: progress_callback(20)
    logger.info("20")

    # Step 21: Plot GLM results
    fig_glm_result = plot_glm_results(file_path, raw_haemo, glm_est, design_matrix)
    for name, fig in fig_glm_result:
        fig_individual[f"GLM {name}"] = fig
    if progress_callback: progress_callback(21)
    logger.info("21")

    # Step 22: Plot channel significance
    fig_significance = individual_significance(raw_haemo, glm_est)
    for name, fig in fig_significance:
        fig_individual[f"Significance {name}"] = fig
    if progress_callback: progress_callback(22)
    logger.info("22")

    # Step 23: Per-channel statistics, region-of-interest metrics, and contrasts
    cha = glm_est.to_dataframe()

    # HACK: Comment out line 588 (self._renderer.show()) in _brain.py from MNE
    # brain_thing = brain_3d_visualization(cha, raw_haemo)
    # brain_individual.append(brain_thing)
    # C++ objects made this get rendered on the fly

    rois = dict(AllChannels=range(len(raw_haemo.ch_names)))
    # Calculate ROI metrics for all conditions
    conditions = design_matrix.columns
    # Compute output metrics by ROI
    df_ind = glm_est.to_dataframe_region_of_interest(rois, conditions)

    df_ind["ID"] = file_path

    # Fold channels (currently disabled)
    # fig_fold_data, fig_fold_legend = fold_channels(raw_haemo)
    # fig_individual.append(fig_fold_data)
    # fig_individual.append(fig_fold_legend)

    # Debug: dump the design matrix to stdout
    print(design_matrix)

    # Build one contrast per condition by averaging its FIR delay regressors
    contrast_matrix = np.eye(design_matrix.shape[1])
    basic_conts = dict(
        [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)]
    )

    all_delay_cols = [col for col in design_matrix.columns if "_delay_" in col]
    all_conditions = sorted({col.split("_delay_")[0] for col in all_delay_cols})

    if not all_conditions:
        raise ValueError("No FIR regressors found in the design matrix.")

    # Build contrast vectors for each condition
    contrast_dict = {}

    for condition in all_conditions:
        delay_cols = [
            col for col in all_delay_cols
            if col.startswith(f"{condition}_delay_") and
            TIME_WINDOW_START <= int(col.split("_delay_")[-1]) <= TIME_WINDOW_END
        ]

        if not delay_cols:
            continue  # skip if no columns fall inside the time window (shouldn't happen)

        # Average across all delay regressors for this condition
        contrast_vector = np.sum([basic_conts[col] for col in delay_cols], axis=0)
        contrast_vector /= len(delay_cols)

        contrast_dict[condition] = contrast_vector

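    # Illustrative sketch (assumed column names, not taken from real data): if a condition
    # "Tapping" has FIR regressors Tapping_delay_0 ... Tapping_delay_9 and the time window
    # keeps delays 2-4, the resulting contrast places 1/3 on each of those three columns and
    # 0 elsewhere, i.e. it tests the mean response over that window:
    #     delay_cols      = ["Tapping_delay_2", "Tapping_delay_3", "Tapping_delay_4"]
    #     contrast_vector = sum(basic_conts[c] for c in delay_cols) / 3
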
    if progress_callback: progress_callback(23)
    logger.info("23")

    # Compute contrast results for each condition
    contrast_results = {}

    for cond, contrast_vector in contrast_dict.items():
        contrast = glm_est.compute_contrast(contrast_vector)  # type: ignore
        df = contrast.to_dataframe()
        df["ID"] = file_path
        contrast_results[cond] = df

    cha["ID"] = file_path

    if progress_callback: progress_callback(24)
    logger.info("24")

    # Render all figures to PNG bytes and make the MNE objects safe to pickle
    fig_bytes = convert_fig_dict_to_png_bytes(fig_individual)

    sanitize_paths_for_pickle(raw_haemo, epochs)

    if progress_callback: progress_callback(25)
    logger.info("25")

    return raw_haemo, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, AGE, GENDER, GROUP, True


def sanitize_paths_for_pickle(raw_haemo, epochs):
    """Coerce the private _filenames entries on the Raw and Epochs objects to plain strings so they pickle cleanly (e.g. when passed between processes)."""
    # Fix raw_haemo._filenames
    if hasattr(raw_haemo, '_filenames'):
        raw_haemo._filenames = [str(p) for p in raw_haemo._filenames]

    # Fix epochs._raw._filenames
    if hasattr(epochs, '_raw') and hasattr(epochs._raw, '_filenames'):
        epochs._raw._filenames = [str(p) for p in epochs._raw._filenames]