@@ -21,6 +21,7 @@ import os.path as op
import re
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from queue import Empty
# External library imports
import matplotlib.pyplot as plt
@@ -46,17 +47,18 @@ from nilearn.glm.regression import OLSModel
import statsmodels.formula.api as smf  # type: ignore
from statsmodels.stats.multitest import multipletests
from scipy import stats
from scipy.spatial.distance import cdist
from scipy.signal import welch, butter, filtfilt  # type: ignore
from scipy.stats import pearsonr, zscore, t
import pywt  # type: ignore
import neurokit2 as nk  # type: ignore
# Backend visualization needs to be defined for PyInstaller
import pyvistaqt  # type: ignore
import vtkmodules.util.data_model
import vtkmodules.util.execution_model
import xlrd
# External library imports for mne
from mne import (
@@ -89,9 +91,13 @@ from mne_nirs.io.fold import fold_channel_specificity # type: ignore
from mne_nirs.preprocessing import peak_power  # type: ignore
from mne_nirs.statistics._glm_level_first import RegressionResults  # type: ignore
from mne_connectivity.viz import plot_connectivity_circle
from mne_connectivity import envelope_correlation, spectral_connectivity_epochs, spectral_connectivity_time
# Needs to be set for mne
os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects"  # type: ignore
# TODO: Tidy this up
FIXED_CATEGORY_COLORS = {
    "SCI only": "skyblue",
    "PSP only": "salmon",
@@ -112,10 +118,6 @@ FIXED_CATEGORY_COLORS = {
}
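# A minimal sketch for the TODO above (hypothetical helper, not wired in anywhere):
# canonicalizing the combined label once would avoid listing every ordering by hand.
def _canonical_category(label: str) -> str:
    # "SNR + SCI" and "SCI + SNR" both normalize to "SCI + SNR"
    return " + ".join(sorted(part.strip() for part in label.split("+")))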
DOWNSAMPLE: bool
DOWNSAMPLE_FREQUENCY: int
@@ -123,21 +125,35 @@ TRIM: bool
SECONDS_TO_KEEP: float
OPTODE_PLACEMENT: bool
SHOW_OPTODE_NAMES: bool
SHORT_CHANNEL: bool
SHORT_CHANNEL_THRESH: float
LONG_CHANNEL_THRESH: float
HEART_RATE: bool  # True if heart rate should be calculated; this helps the SCI, PSP, and SNR methods be more accurate.
SECONDS_TO_STRIP_HR: int  # Seconds to temporarily strip from the data so heart rate is calculated more reliably. Useful if the participant removed the cap while still recording.
MAX_LOW_HR: int  # Heart rate values lower than this are clamped to this value.
MAX_HIGH_HR: int  # Heart rate values higher than this are clamped to this value.
SMOOTHING_WINDOW_HR: int  # Heart rate is calculated as a rolling average over this many samples.
HEART_RATE_WINDOW: int  # BPM above and below the calculated average used as the resting-BPM range.
SCI: bool
SCI_TIME_WINDOW: int
SCI_THRESHOLD: float
SNR: bool
# SNR_TIME_WINDOW: int  # TODO: is this needed?
SNR_THRESHOLD: float
PSP: bool
PSP_TIME_WINDOW: int
PSP_THRESHOLD: float
BAD_CHANNELS_HANDLING: str
MAX_DIST: float
MIN_NEIGHBORS: int
TDDR: bool
WAVELET: bool
@@ -145,57 +161,41 @@ IQR: float
WAVELET_TYPE: str
WAVELET_LEVEL: int
ENHANCE_NEGATIVE_CORRELATION: bool
FILTER: bool
L_FREQ: float
H_FREQ: float
L_TRANS_BANDWIDTH: float
H_TRANS_BANDWIDTH: float
RESAMPLE: bool
RESAMPLE_FREQ: int
STIM_DUR: float
HRF_MODEL: str
DRIFT_MODEL: str
HIGH_PASS: float
DRIFT_ORDER: int
FIR_DELAYS: range
MIN_ONSET: int
OVERSAMPLING: int
REMOVE_EVENTS: list
SHORT_CHANNEL_REGRESSION: bool
NOISE_MODEL: str
BINS: int
N_JOBS: int
TIME_WINDOW_START: int
TIME_WINDOW_END: int
MAX_WORKERS: int
VERBOSITY: bool
AGE: int = 25  # Assume 25 if not set from the GUI; this results in a reasonable PPF.
GENDER: str = ""
GROUP: str = "Default"
# These are parameters that are required for the analysis
REQUIRED_KEYS: dict[str, Any] = {
    # "SECONDS_TO_STRIP": int,
@@ -262,7 +262,7 @@ PLATFORM_NAME = platform.system().lower()
# Configure logging to file with timestamps and real-time flush
if PLATFORM_NAME == 'darwin':
    logging.basicConfig(
        filename=os.path.join(os.path.dirname(sys.executable), "../../../fnirs_analysis.log"),  # Needed to get out of the bundled application
        level=logging.INFO,
        format='%(asctime)s - %(processName)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
@@ -320,8 +320,6 @@ def set_metadata(file_path, metadata: dict[str, Any]) -> None:
        val = file_metadata.get(key, None)
        if val not in (None, '', [], {}, ()):  # check for "empty" values
            globals()[key] = val
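# Illustration (hypothetical call): set_metadata(file_path, {"AGE": 31, "GENDER": ""})
# would update the module-level AGE but leave GENDER untouched, since "" counts as empty.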
def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -> None:
    def forward_progress():
@@ -825,7 +823,7 @@ def get_hbo_hbr_picks(raw):
    return hbo_picks, hbr_picks, hbo_wl, hbr_wl
def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2):
"""
"""
Interpolate bad fNIRS channels using a distance-weighted average of nearby good channels.
Interpolate bad fNIRS channels using a distance-weighted average of nearby good channels.
@@ -932,11 +930,12 @@ def interpolate_fNIRS_bads_weighted_average(raw, bad_channels, max_dist=0.03, mi
    raw.info['bads'] = [ch for ch in raw.info['bads'] if ch not in bad_ch_to_remove]
    print("\nInterpolation complete.\n")
    raw.info['bads'] = []
    for ch in raw.info['bads']:
        print(f"Channel {ch} still marked as bad.")
    print("Bads cleared:", raw.info['bads'])
    fig_raw_after = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After interpolation", show=False)
    return raw, fig_raw_after
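# Usage sketch (hypothetical; mirrors the call made in process_participant below,
# assuming the bad channels are already listed in raw.info['bads']):
# raw, fig_after = interpolate_fNIRS_bads_weighted_average(raw, max_dist=MAX_DIST, min_neighbors=MIN_NEIGHBORS)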
@@ -1117,17 +1116,17 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp):
def filter_the_data(raw_haemo):
    # --- STEP 5: Filtering (0.01-0.2 Hz bandpass) ---
    fig_filter = raw_haemo.compute_psd(fmax=3).plot(
        average=True, color="r", show=False, amplitude=True
    )
    if L_FREQ == 0 and H_FREQ != 0:
        raw_haemo = raw_haemo.filter(l_freq=None, h_freq=H_FREQ, h_trans_bandwidth=H_TRANS_BANDWIDTH)
    elif L_FREQ != 0 and H_FREQ == 0:
        raw_haemo = raw_haemo.filter(l_freq=L_FREQ, h_freq=None, l_trans_bandwidth=L_TRANS_BANDWIDTH)
    elif L_FREQ != 0 and H_FREQ != 0:
        raw_haemo = raw_haemo.filter(l_freq=L_FREQ, h_freq=H_FREQ, l_trans_bandwidth=L_TRANS_BANDWIDTH, h_trans_bandwidth=H_TRANS_BANDWIDTH)
    else:
        print("No filter")
    # raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.4, h_trans_bandwidth=0.2)
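# Illustration (values hypothetical): with L_FREQ=0.01, H_FREQ=0.2,
# L_TRANS_BANDWIDTH=0.002 and H_TRANS_BANDWIDTH=0.02, MNE builds a band-pass
# whose response rolls off over 0.002 Hz below the low cutoff and 0.02 Hz
# above the high cutoff.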
@@ -1258,7 +1257,7 @@ def epochs_calculations(raw_haemo, events, event_dict):
            continue
        data = evoked.data[picks_idx, :].mean(axis=0)
        t_start, t_end = 0, 15  # evoked.times is in seconds, so this selects 0-15 s regardless of sampling rate
        times_mask = (evoked.times >= t_start) & (evoked.times <= t_end)
        data_segment = data[times_mask]
        times_segment = evoked.times[times_mask]
@@ -1307,33 +1306,53 @@ def epochs_calculations(raw_haemo, events, event_dict):
def make_design_matrix(raw_haemo, short_chans):
    # 1) Filter unwanted events out of the annotations
    events_to_remove = REMOVE_EVENTS
    filtered_annotations = [ann for ann in raw_haemo.annotations if ann['description'] not in events_to_remove]
    new_annot = Annotations(
        onset=[ann['onset'] for ann in filtered_annotations],
        duration=[ann['duration'] for ann in filtered_annotations],
        description=[ann['description'] for ann in filtered_annotations]
    )
    # Set the new annotations
    raw_haemo.set_annotations(new_annot)
    if RESAMPLE:
        raw_haemo.resample(RESAMPLE_FREQ, npad="auto")
        try:
            short_chans.resample(RESAMPLE_FREQ)  # keep the short channels at the same rate as raw_haemo
        except Exception:
            pass  # short_chans may be None when short channels are disabled
    raw_haemo._data = raw_haemo._data * 1e6
    # 2) Create design matrix
    if SHORT_CHANNEL_REGRESSION:
        design_matrix = make_first_level_design_matrix(
            raw=raw_haemo,
            stim_dur=STIM_DUR,
            hrf_model=HRF_MODEL,
            drift_model=DRIFT_MODEL,
            high_pass=HIGH_PASS,
            drift_order=DRIFT_ORDER,
            fir_delays=FIR_DELAYS,
            add_regs=short_chans.get_data().T,
            add_reg_names=short_chans.ch_names,
            min_onset=MIN_ONSET,
            oversampling=OVERSAMPLING
        )
    else:
        design_matrix = make_first_level_design_matrix(
            raw=raw_haemo,
            stim_dur=STIM_DUR,
            hrf_model=HRF_MODEL,
            drift_model=DRIFT_MODEL,
            high_pass=HIGH_PASS,
            drift_order=DRIFT_ORDER,
            fir_delays=FIR_DELAYS,
            min_onset=MIN_ONSET,
            oversampling=OVERSAMPLING
        )
    print(design_matrix.head())
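    # With hrf_model="fir", the columns come out as e.g. "Tapping_delay_0",
    # "Tapping_delay_1", ... (event name illustrative); the FIR contrast and
    # plotting code further down keys off that "_delay_" naming.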
@@ -1643,8 +1662,11 @@ def fold_channels(raw: BaseRaw) -> None:
    landmark_color_map = {landmark: colors[i % len(colors)] for i, landmark in enumerate(landmarks)}
    # Iterate over each channel
    print(len(hbo_channel_names))
    for idx, channel_name in enumerate(hbo_channel_names):
        print(idx, channel_name)
        # Run the fOLD on the selected channel
        channel_data = raw.copy().pick(picks=channel_name)  # type: ignore
@@ -1687,6 +1709,7 @@ def fold_channels(raw: BaseRaw) -> None:
    landmark_specificity_data = []
    # TODO: Fix this
    if True:
        handles = [
@@ -1709,8 +1732,9 @@ def fold_channels(raw: BaseRaw) -> None:
    for ax in axes[len(hbo_channel_names):]:
        ax.axis('off')
    # plt.show()
    fig_dict = {"main": fig, "legend": legend_fig}
    return convert_fig_dict_to_png_bytes(fig_dict)
@@ -2230,9 +2254,15 @@ def brain_3d_visualization(raw_haemo, df_cha, selected_event, t_or_theta: Litera
    # Get all activity conditions
    for cond in [f'{selected_event}']:
        ch_summary = df_cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python')  # type: ignore
        print(ch_summary)
        if ch_summary.empty:
            # Not a FIR model, so fall back to exact condition names
            print("No data found for this condition.")
            ch_summary = df_cha.query("Condition in [@cond] and Chroma == 'hbo'", engine='python')
    # Use ordinary least squares (OLS) if only one participant
    # TODO: Fix.
    if True:
@@ -2253,6 +2283,9 @@ def brain_3d_visualization(raw_haemo, df_cha, selected_event, t_or_theta: Litera
        valid_channels = ch_summary["ch_name"].unique().tolist()  # type: ignore
        raw_for_plot = raw_haemo.copy().pick(picks=valid_channels)  # type: ignore
        print(f"DEBUG: Model DF rows: {len(model_df)}")
        print(f"DEBUG: Raw channels: {len(raw_for_plot.ch_names)}")
        brain = plot_3d_evoked_array(raw_for_plot.pick(picks="hbo"), model_df, view="dorsal", distance=0.02, colorbar=True, clim=clim, mode="weighted", size=(800, 700))  # type: ignore
        if show_optodes == 'all' or show_optodes == 'sensors':
@@ -2569,7 +2602,10 @@ def plot_fir_model_results(df, raw_haemo, dm, selected_event, l_bound, u_bound):
    dm_cols_activity = np.where([f"{selected_event}" in c for c in dm.columns])[0]
    dm = dm[[dm.columns[i] for i in dm_cols_activity]]
    try:
        lme = smf.mixedlm("theta ~ -1 + delay:TidyCond:Chroma", df, groups=df["ID"]).fit()
    except Exception:
        # Fall back to OLS when the mixed model cannot be fit (e.g. a single participant)
        lme = smf.ols("theta ~ -1 + delay:TidyCond:Chroma", df).fit()  # type: ignore
    df_sum = statsmodels_to_results(lme)
    df_sum["delay"] = [int(n) for n in df_sum["delay"]]
@@ -2785,7 +2821,7 @@ def run_second_level_analysis(df_contrasts, raw, p, bounds):
        result = model.fit(Y)
        t_val = result.t(0).item()
        p_val = 2 * t.sf(np.abs(t_val), df=result.df_model)
        mean_beta = np.mean(Y)
        group_results.append({
@@ -3280,7 +3316,7 @@ def hr_calc(raw):
    # --- Parameters for PSD ---
    desired_bin_hz = 0.1
    nperseg = int(sfreq / desired_bin_hz)
    hr_range = (30, 180)  # TODO: Should this not use the user-defined MAX_LOW_HR / MAX_HIGH_HR values?
    # --- Function to find strongest local peak ---
    def find_hr_from_psd(ch_data):
@@ -3310,18 +3346,21 @@ def hr_calc(raw):
    return fig, hr1, hr2, low, high
def process_participant(file_path, progress_callback=None):
    fig_individual: dict[str, Figure] = {}
    # Step 1: Preprocessing
    raw = load_snirf(file_path)
    fig_raw = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Loaded Raw", show=False)
    fig_individual["Loaded Raw"] = fig_raw
    if progress_callback: progress_callback(1)
    logger.info("Step 1 Completed.")
    # Step 2: Trimming
    # TODO: Clean this into a method
    if TRIM:
        if hasattr(raw, 'annotations') and len(raw.annotations) > 0:
            # Get time of first event
@@ -3329,17 +3368,16 @@ def process_participant(file_path, progress_callback=None):
            trim_time = max(0, first_event_time - SECONDS_TO_KEEP)  # Ensure we don't go negative
            raw.crop(tmin=trim_time)
            # Shift annotation onsets to match new t=0
            ann = raw.annotations
            ann_shifted = Annotations(
                onset=ann.onset - trim_time,  # shift to start at zero
                duration=ann.duration,
                description=ann.description
            )
            data = raw.get_data()
            info = raw.info.copy()
            raw = RawArray(data, info)
            raw.set_annotations(ann_shifted)
            logger.info(f"Trimmed raw data: start at {trim_time}s ({SECONDS_TO_KEEP}s before first event), t=0 at new start")
@@ -3349,185 +3387,180 @@ def process_participant(file_path, progress_callback=None):
        fig_trimmed = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Trimmed Raw", show=False)
        fig_individual["Trimmed Raw"] = fig_trimmed
    if progress_callback: progress_callback(2)
    logger.info("Step 2 Completed.")
    # Step 3: Verify Optode Placement
    if OPTODE_PLACEMENT:
        fig_optodes = raw.plot_sensors(show_names=SHOW_OPTODE_NAMES, to_sphere=True, show=False)  # type: ignore
        fig_individual["Plot Sensors"] = fig_optodes
    if progress_callback: progress_callback(3)
    logger.info("Step 3 Completed.")
    # Step 4: Short/Long Channels
    if SHORT_CHANNEL:
        short_chans = get_short_channels(raw, max_dist=SHORT_CHANNEL_THRESH)
        fig_short_chans = short_chans.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Short Channels Only", show=False)
        fig_individual["short"] = fig_short_chans
    else:
        short_chans = None
    get_long_channels(raw, min_dist=SHORT_CHANNEL_THRESH, max_dist=LONG_CHANNEL_THRESH)  # Don't update the existing raw
    if progress_callback: progress_callback(4)
    logger.info("Step 4 Completed.")
    # Step 5: Heart Rate
    if HEART_RATE:
        fig, hr1, hr2, low, high = hr_calc(raw)
        fig_individual["PSD"] = fig
        fig_individual['HeartRate_PSD'] = hr1
        fig_individual['HeartRate_Time'] = hr2
    if progress_callback: progress_callback(5)
    logger.info("Step 5 Completed.")
    # Step 6: Scalp Coupling Index
    bad_sci = []
    if SCI:
        if HEART_RATE:
            bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw, low, high)
        else:
            bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw)
        fig_individual["SCI1"] = fig_sci_1
        fig_individual["SCI2"] = fig_sci_2
    if progress_callback: progress_callback(6)
    logger.info("Step 6 Completed.")
    # Step 7: Signal to Noise Ratio
    bad_snr = []
    if SNR:
        bad_snr, fig_snr = calculate_signal_noise_ratio(raw)
        fig_individual["SNR1"] = fig_snr
    if progress_callback: progress_callback(7)
    logger.info("Step 7 Completed.")
    # Step 8: Peak Spectral Power
    bad_psp = []
    if PSP:
        bad_psp, fig_psp1, fig_psp2 = calculate_peak_power(raw)
        fig_individual["PSP1"] = fig_psp1
        fig_individual["PSP2"] = fig_psp2
    if progress_callback: progress_callback(8)
    logger.info("Step 8 Completed.")
    # Step 9: Bad Channels Handling
    if BAD_CHANNELS_HANDLING != "None":
        raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp)
        if fig_dropped and fig_raw_before is not None:
            fig_individual["fig2"] = fig_dropped
            fig_individual["fig3"] = fig_raw_before
        if bad_channels:
            if BAD_CHANNELS_HANDLING == "Interpolate":
                raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, max_dist=MAX_DIST, min_neighbors=MIN_NEIGHBORS)
                fig_individual["fig4"] = fig_raw_after
            elif BAD_CHANNELS_HANDLING == "Remove":
                pass  # TODO: Is there more needed here?
    if progress_callback: progress_callback(9)
    logger.info("Step 9 Completed.")
    # Step 10: Optical Density
    raw_od = optical_density(raw)
    fig_raw_od = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Optical Density", show=False)
    fig_individual["Optical Density"] = fig_raw_od
    if progress_callback: progress_callback(10)
    logger.info("Step 10 Completed.")
    # Step 11: Temporal Derivative Distribution Repair Filtering
    if TDDR:
        raw_od = temporal_derivative_distribution_repair(raw_od)
        fig_raw_od_tddr = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After TDDR (Motion Correction)", show=False)
        fig_individual["TDDR"] = fig_raw_od_tddr
    if progress_callback: progress_callback(11)
    logger.info("Step 11 Completed.")
    # Step 12: Wavelet Filtering
    if WAVELET:
        raw_od, fig = calculate_and_apply_wavelet(raw_od)
        fig_individual["Wavelet"] = fig
    if progress_callback: progress_callback(12)
    logger.info("Step 12 Completed.")
    # Step 13: Haemoglobin Concentration
    raw_haemo = beer_lambert_law(raw_od, ppf=calculate_dpf(file_path))
    fig_raw_haemo_bll = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False)
    fig_individual["BLL"] = fig_raw_haemo_bll
    if progress_callback: progress_callback(13)
    logger.info("Step 13 Completed.")
    # Step 14: Enhance Negative Correlation
    if ENHANCE_NEGATIVE_CORRELATION:
        raw_haemo = enhance_negative_correlation(raw_haemo)
        fig_raw_haemo_enc = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Enhance Negative Correlation", show=False)
        fig_individual["ENC"] = fig_raw_haemo_enc
    if progress_callback: progress_callback(14)
    logger.info("Step 14 Completed.")
    # Step 15: Filter
    if FILTER:
        raw_haemo, fig_filter, fig_raw_haemo_filter = filter_the_data(raw_haemo)
        fig_individual["filter1"] = fig_filter
        fig_individual["filter2"] = fig_raw_haemo_filter
    if progress_callback: progress_callback(15)
    logger.info("Step 15 Completed.")
    # Step 16: Extracting Events
    events, event_dict = events_from_annotations(raw_haemo)
    fig_events = plot_events(events, event_id=event_dict, sfreq=raw_haemo.info["sfreq"], show=False)
    fig_individual["events"] = fig_events
    if progress_callback: progress_callback(16)
    logger.info("Step 16 Completed.")
    # Step 17: Epoch Calculations
    epochs, fig_epochs = epochs_calculations(raw_haemo, events, event_dict)
    for name, fig in fig_epochs:
        fig_individual[f"epochs_{name}"] = fig
    if progress_callback: progress_callback(17)
    logger.info("Step 17 Completed.")
    # Step 18: Design Matrix
    design_matrix, fig_design_matrix = make_design_matrix(raw_haemo, short_chans)
    fig_individual["Design Matrix"] = fig_design_matrix
    if progress_callback: progress_callback(18)
    logger.info("Step 18 Completed.")
    # Step 19: Run GLM
    glm_est = run_glm(raw_haemo, design_matrix, noise_model=NOISE_MODEL, bins=BINS, n_jobs=N_JOBS, verbose=VERBOSITY)
    # Not used: AppData\Local\Packages\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\LocalCache\local-packages\Python313\site-packages\nilearn\glm\contrasts.py
    # Yes used: AppData\Local\Packages\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\LocalCache\local-packages\Python313\site-packages\mne_nirs\utils\_io.py
    # The p-value is calculated from this t-statistic using the Student's t-distribution with appropriate degrees of freedom:
    # p_value = 2 * stats.t.cdf(-abs(t_statistic), df)
    # It is a two-tailed p-value.
    # It says how likely it is to observe the effect you did (or something more extreme) if the true effect was zero (null hypothesis).
    # A small p-value (e.g., < 0.05) suggests the effect is unlikely to be zero, i.e. it is "statistically significant."
    # A large p-value means the data do not provide strong evidence that the effect is different from zero.
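    # Worked illustration (values hypothetical): for t_statistic = 2.1 with df = 30,
    # 2 * stats.t.cdf(-abs(2.1), 30) is roughly 0.044, i.e. significant at the 0.05 level.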
    if progress_callback: progress_callback(19)
    logger.info("Step 19 Completed.")
    # Step 20: Generate GLM Results
    if "derivative" not in HRF_MODEL.lower():
        fig_glm_result = plot_glm_results(file_path, raw_haemo, glm_est, design_matrix)
        for name, fig in fig_glm_result:
            fig_individual[f"GLM {name}"] = fig
    if progress_callback: progress_callback(20)
    logger.info("Step 20 Completed.")
    # Step 21: Generate Channel Significance
    if HRF_MODEL == "fir":
        fig_significance = individual_significance(raw_haemo, glm_est)
        for name, fig in fig_significance:
            fig_individual[f"Significance {name}"] = fig
    if progress_callback: progress_callback(21)
    logger.info("Step 21 Completed.")
    # Step 22: Generate Channel, Region of Interest, and Contrast Results
    cha = glm_est.to_dataframe()
    # HACK: Comment out line 588 (self._renderer.show()) in _brain.py from MNE
@@ -3555,6 +3588,7 @@ def process_participant(file_path, progress_callback=None):
        [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)]
    )
    if HRF_MODEL == "fir":
        all_delay_cols = [col for col in design_matrix.columns if "_delay_" in col]
        all_conditions = sorted({col.split("_delay_")[0] for col in all_delay_cols})
@@ -3580,12 +3614,14 @@ def process_participant(file_path, progress_callback=None):
            contrast_dict[condition] = contrast_vector
    if progress_callback: progress_callback(22)
    logger.info("Step 22 Completed.")
    # Step 23: Compute Contrast Results
    contrast_results = {}
    if HRF_MODEL == "fir":
        for cond, contrast_vector in contrast_dict.items():
            contrast = glm_est.compute_contrast(contrast_vector)  # type: ignore
            df = contrast.to_dataframe()
@@ -3594,10 +3630,10 @@ def process_participant(file_path, progress_callback=None):
    cha["ID"] = file_path
    if progress_callback: progress_callback(23)
    logger.info("Step 23 Completed.")
    # Step 24: Finishing Up
    fig_bytes = convert_fig_dict_to_png_bytes(fig_individual)
    sanitize_paths_for_pickle(raw_haemo, epochs)
@@ -3605,7 +3641,17 @@ def process_participant(file_path, progress_callback=None):
    if progress_callback: progress_callback(25)
    logger.info("Step 25 Completed.")
    # TODO: Tidy up
    # Extract the parameters this file was run with. No need to return age, gender, or group?
    config = {
        k: globals()[k]
        for k in __annotations__
        if k in globals() and k != "REQUIRED_KEYS"
    }
    print(config)
    return raw_haemo, config, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, True
def sanitize_paths_for_pickle(raw_haemo, epochs):
@@ -3616,3 +3662,233 @@ def sanitize_paths_for_pickle(raw_haemo, epochs):
    # Fix epochs._raw._filenames
    if hasattr(epochs, '_raw') and hasattr(epochs._raw, '_filenames'):
        epochs._raw._filenames = [str(p) for p in epochs._raw._filenames]
def functional_connectivity_spectral_epochs(epochs, n_lines, vmin):
    # Will crash without this load
    epochs.load_data()
    hbo_epochs = epochs.copy().pick(picks="hbo")
    data = hbo_epochs.get_data()
    names = hbo_epochs.ch_names
    sfreq = hbo_epochs.info["sfreq"]
    con = spectral_connectivity_epochs(
        data,
        method=["coh", "plv"],
        mode="multitaper",
        sfreq=sfreq,
        fmin=0.04,
        fmax=0.2,
        faverage=True,
        verbose=True
    )
    con_coh, con_plv = con
    coh = con_coh.get_data(output="dense").squeeze()
    plv = con_plv.get_data(output="dense").squeeze()
    np.fill_diagonal(coh, 0)
    np.fill_diagonal(plv, 0)
    plot_connectivity_circle(
        coh,
        names,
        title="fNIRS Functional Connectivity (HbO - Coherence)",
        n_lines=n_lines,
        vmin=vmin
    )
def functional_connectivity_spectral_time(epochs, n_lines, vmin):
    # Will crash without this load
    epochs.load_data()
    hbo_epochs = epochs.copy().pick(picks="hbo")
    data = hbo_epochs.get_data()
    names = hbo_epochs.ch_names
    sfreq = hbo_epochs.info["sfreq"]
    freqs = np.linspace(0.04, 0.2, 10)
    n_cycles = freqs * 2
    con = spectral_connectivity_time(
        data,
        freqs=freqs,
        method=["coh", "plv"],
        mode="multitaper",
        sfreq=sfreq,
        fmin=0.04,
        fmax=0.2,
        n_cycles=n_cycles,
        faverage=True,
        verbose=True
    )
    con_coh, con_plv = con
    coh = con_coh.get_data(output="dense").squeeze()
    plv = con_plv.get_data(output="dense").squeeze()
    np.fill_diagonal(coh, 0)
    np.fill_diagonal(plv, 0)
    plot_connectivity_circle(
        coh,
        names,
        title="fNIRS Functional Connectivity (HbO - Coherence)",
        n_lines=n_lines,
        vmin=vmin
    )
def functional_connectivity_envelope(epochs, n_lines, vmin):
    # Will crash without this load
    epochs.load_data()
    hbo_epochs = epochs.copy().pick(picks="hbo")
    data = hbo_epochs.get_data()
    env = envelope_correlation(
        data,
        orthogonalize=False,  # keep zero-lag amplitude coupling; "pairwise" would suppress it
        absolute=True
    )
    env_data = env.get_data(output="dense")
    env_corr = env_data.mean(axis=0)  # average the per-epoch connectivity matrices
    env_corr = np.squeeze(env_corr)
    np.fill_diagonal(env_corr, 0)
    plot_connectivity_circle(
        env_corr,
        hbo_epochs.ch_names,
        title="fNIRS HbO Envelope Correlation (Task Connectivity)",
        n_lines=n_lines,
        vmin=vmin
    )
def functional_connectivity_betas(raw_hbo, n_lines, vmin, event_name=None):
    raw_hbo = raw_hbo.copy().pick(picks="hbo")
    onsets = raw_hbo.annotations.onset
    # 1. Tag every annotation with a trial index.
    # CRITICAL: update the Raw object's annotations so the GLM sees unique events
    ann = raw_hbo.annotations
    new_desc = []
    for i, desc in enumerate(ann.description):
        new_desc.append(f"{desc}__trial_{i:03d}")
    ann.description = np.array(new_desc)
    # 2. Build the design matrix
    # TODO: should use the user-defined design-matrix parameters (HRF_MODEL, FIR_DELAYS, ...)
    design_matrix = make_first_level_design_matrix(
        raw=raw_hbo,
        hrf_model='fir',
        fir_delays=np.arange(0, 12, 1),
        drift_model='cosine',
        drift_order=1
    )
    # 3. Run GLM & extract betas
    glm_results = run_glm(raw_hbo, design_matrix)
    betas = np.array(glm_results.theta())
    reg_names = list(design_matrix.columns)
    n_channels = betas.shape[0]
    # 4. Find unique trial tags (optionally filtered by event)
    trial_tags = sorted({
        col.split("_delay")[0]
        for col in reg_names
        if (
            ("__trial_" in col)
            and (event_name is None or col.startswith(event_name + "__"))
        )
    })
    if len(trial_tags) == 0:
        raise ValueError(f"No trials found for event_name={event_name}")
    # 5. Build the beta series (average across FIR delays per trial)
    beta_series = np.zeros((n_channels, len(trial_tags)))
    for trial_idx, tag in enumerate(trial_tags):
        idx = [
            i for i, col in enumerate(reg_names)
            if col.startswith(f"{tag}_delay")
        ]
        beta_series[:, trial_idx] = np.mean(betas[:, idx], axis=1).flatten()
    # n_channels, n_trials = betas.shape[0], len(onsets)
    # beta_series = np.zeros((n_channels, n_trials))
    # for t in range(n_trials):
    #     trial_indices = [i for i, col in enumerate(reg_names) if col.startswith(f"trial_{t:03d}_delay")]
    #     if trial_indices:
    #         beta_series[:, t] = np.mean(betas[:, trial_indices], axis=1).flatten()
    # 6. Normalize each channel so they are on the same scale.
    # Without this, everything is connected to everything; apparently this is a big issue in fNIRS.
    beta_series = zscore(beta_series, axis=1)
    # 7. Global signal regression: remove the component each channel shares with the mean signal
    global_signal = np.mean(beta_series, axis=0)
    beta_series_clean = np.zeros_like(beta_series)
    for i in range(n_channels):
        slope, _ = np.polyfit(global_signal, beta_series[i, :], 1)
        beta_series_clean[i, :] = beta_series[i, :] - (slope * global_signal)
    # 8. Correlation & strict filtering
    corr_matrix = np.zeros((n_channels, n_channels))
    p_matrix = np.ones((n_channels, n_channels))
    for i in range(n_channels):
        for j in range(i + 1, n_channels):
            r, p = pearsonr(beta_series_clean[i, :], beta_series_clean[j, :])
            corr_matrix[i, j] = corr_matrix[j, i] = r
            p_matrix[i, j] = p_matrix[j, i] = p
    # 9. High-bar thresholding: FDR-corrected significance plus a minimum effect size
    reject, _ = multipletests(p_matrix[np.triu_indices(n_channels, k=1)], method='fdr_bh', alpha=0.05)[:2]
    sig_corr_matrix = np.zeros_like(corr_matrix)
    triu = np.triu_indices(n_channels, k=1)
    for idx, is_sig in enumerate(reject):
        r_val = corr_matrix[triu[0][idx], triu[1][idx]]
        # Only keep the absolute strongest connections
        if is_sig and abs(r_val) > 0.7:
            sig_corr_matrix[triu[0][idx], triu[1][idx]] = r_val
            sig_corr_matrix[triu[1][idx], triu[0][idx]] = r_val
    # 10. Plot
    plot_connectivity_circle(
        sig_corr_matrix,
        raw_hbo.ch_names,
        title="Strictly Filtered Connectivity (TDDR + GSR + Z-Score)",
        n_lines=None,
        vmin=0.7,
        vmax=1.0,
        colormap='hot'  # 'hot' makes positive connections pop
    )
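# Minimal usage sketch (hypothetical; assumes `raw_haemo` / `epochs` as returned
# by process_participant and an interactive matplotlib backend, with
# illustrative thresholds):
# functional_connectivity_spectral_epochs(epochs, n_lines=20, vmin=0.5)
# functional_connectivity_envelope(epochs, n_lines=20, vmin=0.5)
# functional_connectivity_betas(raw_haemo, n_lines=None, vmin=0.7, event_name=None)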