more parameters

2025-11-03 16:56:05 -08:00
parent 1aa2402d09
commit 64ed6d2e87
3 changed files with 173 additions and 100 deletions

flares.py

@@ -119,6 +119,13 @@ GENDER: str
DOWNSAMPLE: bool
DOWNSAMPLE_FREQUENCY: int
TRIM: bool
SECONDS_TO_KEEP: float
OPTODE_PLACEMENT: bool
HEART_RATE: bool
SCI: bool
SCI_TIME_WINDOW: int
SCI_THRESHOLD: float
@@ -133,27 +140,35 @@ PSP_THRESHOLD: float
TDDR: bool
IQR = 1.5
WAVELET: bool
IQR: float
WAVELET_TYPE: str
WAVELET_LEVEL: int
HEART_RATE = True # True if heart rate should be calculated. This helps the SCI, PSP, and SNR methods be more accurate.
SECONDS_TO_STRIP_HR = 5 # Number of seconds to temporarily strip from the data so heart rate can be calculated more reliably. Useful if the participant removed the cap while still recording.
MAX_LOW_HR = 40 # Any heart rate values lower than this will be set to this value.
MAX_HIGH_HR = 200 # Any heart rate values higher than this will be set to this value.
SMOOTHING_WINDOW_HR = 100 # Heart rate will be calculated as a rolling average over this many samples.
HEART_RATE_WINDOW = 25 # Number of BPM above and below the calculated average to use as the resting-BPM range.
SHORT_CHANNEL_THRESH = 0.018
ENHANCE_NEGATIVE_CORRELATION: bool
FILTER: bool
L_FREQ: float
H_FREQ: float
SHORT_CHANNEL: bool
SHORT_CHANNEL_THRESH: float
LONG_CHANNEL_THRESH: float
REMOVE_EVENTS: list
TIME_WINDOW_START: int
TIME_WINDOW_END: int
DRIFT_MODEL: str
VERBOSITY = True
# FIXME: Shouldn't need each ordering - just order it before checking
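For orientation, the heart-rate parameters above are the kind of values applied to a per-sample BPM series before deriving the resting range that hr_calc returns further down. A minimal sketch, assuming a NumPy BPM array; the helper clamp_and_smooth_hr is illustrative, not the module's own hr_calc:

import numpy as np
import pandas as pd

MAX_LOW_HR = 40            # floor: lower estimates are clamped up to this value
MAX_HIGH_HR = 200          # ceiling: higher estimates are clamped down to this value
SMOOTHING_WINDOW_HR = 100  # rolling-average window, in samples
HEART_RATE_WINDOW = 25     # +/- BPM around the smoothed mean used as the resting range

def clamp_and_smooth_hr(bpm: np.ndarray) -> tuple[pd.Series, float, float]:
    """Clamp implausible BPM values, smooth them, and derive a resting-BPM band."""
    bpm = np.clip(bpm, MAX_LOW_HR, MAX_HIGH_HR)
    smoothed = pd.Series(bpm).rolling(SMOOTHING_WINDOW_HR, min_periods=1).mean()
    mean_bpm = smoothed.mean()
    return smoothed, mean_bpm - HEART_RATE_WINDOW, mean_bpm + HEART_RATE_WINDOW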
@@ -183,11 +198,17 @@ GROUP = "Default"
REQUIRED_KEYS: dict[str, Any] = {
# "SECONDS_TO_STRIP": int,
"DOWNSAMPLE": bool,
"DOWNSAMPLE_FREQUENCY": int,
"TRIM": bool,
"SECONDS_TO_KEEP": float,
"OPTODE_PLACEMENT": bool,
"HEART_RATE": bool,
"SCI": bool,
"SCI_TIME_WINDOW": int,
"SCI_THRESHOLD": float,
@@ -201,11 +222,23 @@ REQUIRED_KEYS: dict[str, Any] = {
"PSP_THRESHOLD": float,
"SHORT_CHANNEL": bool,
"SHORT_CHANNEL_THRESH": float,
"LONG_CHANNEL_THRESH": float,
"REMOVE_EVENTS": list,
"TIME_WINDOW_START": int,
"TIME_WINDOW_END": int,
"L_FREQ": float,
"H_FREQ": float,
"TDDR": bool,
"WAVELET": bool,
"IQR": float,
"WAVELET_TYPE": str,
"WAVELET_LEVEL": int,
"FILTER": bool,
"DRIFT_MODEL": str,
# "REJECT_PAIRS": bool,
# "FORCE_DROP_ANNOTATIONS": list,
# "FILTER_LOW_PASS": float,
@@ -1107,7 +1140,7 @@ def filter_the_data(raw_haemo):
fig_raw_haemo_filter = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Filtered HbO and HbR", show=False)
return fig_filter, fig_raw_haemo_filter
return raw_haemo, fig_filter, fig_raw_haemo_filter
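With L_FREQ and H_FREQ now configurable and filter_the_data returning the filtered raw_haemo as well, the core of this step presumably reduces to an MNE band-pass call along these lines (a sketch; the numeric defaults are assumptions, and the real function also builds the two figures returned above):

from mne.io import BaseRaw

def bandpass_haemo(raw_haemo: BaseRaw, l_freq: float = 0.01, h_freq: float = 0.1) -> BaseRaw:
    """Band-pass the HbO/HbR signals in place and return the modified object."""
    # mne.io.Raw.filter applies a zero-phase FIR band-pass between l_freq and h_freq.
    return raw_haemo.filter(l_freq=l_freq, h_freq=h_freq)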
@@ -1284,7 +1317,7 @@ def make_design_matrix(raw_haemo, short_chans):
hrf_model='fir',
stim_dur=0.5,
fir_delays=range(15),
drift_model='cosine',
drift_model=DRIFT_MODEL,
high_pass=0.01,
oversampling=1,
min_onset=-125,
@@ -1297,7 +1330,7 @@ def make_design_matrix(raw_haemo, short_chans):
hrf_model='fir',
stim_dur=0.5,
fir_delays=range(15),
drift_model='cosine',
drift_model=DRIFT_MODEL,
high_pass=0.01,
oversampling=1,
min_onset=-125,
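These keyword arguments match mne_nirs.experimental_design.make_first_level_design_matrix, which make_design_matrix presumably wraps; making drift_model a config value lets the same call switch between, e.g., 'cosine' and 'polynomial' drift regressors without touching the code:

from mne_nirs.experimental_design import make_first_level_design_matrix

def build_design_matrix(raw_haemo, drift_model: str = "cosine"):
    """Mirror of the call in the hunk above, with the drift model passed in."""
    return make_first_level_design_matrix(
        raw_haemo,                 # haemoglobin Raw object whose annotations define the events
        hrf_model="fir",
        stim_dur=0.5,
        fir_delays=range(15),
        drift_model=drift_model,   # DRIFT_MODEL in the config; hard-coded to 'cosine' before this commit
        high_pass=0.01,
        oversampling=1,
        min_onset=-125,
    )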
@@ -2975,7 +3008,7 @@ def calculate_and_apply_wavelet(data: BaseRaw) -> tuple[BaseRaw, Figure]:
logger.info("Calculating the IQR, decomposing the signal, and thresholding the coefficients...")
for ch in range(loaded_data.shape[0]):
denoised_data[ch, :] = wavelet_iqr_denoise(loaded_data[ch, :], wavelet='db4', level=3)
denoised_data[ch, :] = wavelet_iqr_denoise(loaded_data[ch, :], wavelet=WAVELET_TYPE, level=WAVELET_LEVEL)
# Reconstruct the data with the annotations
logger.info("Reconstructing the data with annotations...")
@@ -3289,68 +3322,66 @@ def process_participant(file_path, progress_callback=None):
logger.info("1")
if hasattr(raw, 'annotations') and len(raw.annotations) > 0:
# Get time of first event
first_event_time = raw.annotations.onset[0]
trim_time = max(0, first_event_time - 5.0) # Ensure we don't go negative
raw.crop(tmin=trim_time)
# Shift annotation onsets to match new t=0
import mne
if TRIM:
if hasattr(raw, 'annotations') and len(raw.annotations) > 0:
# Get time of first event
first_event_time = raw.annotations.onset[0]
trim_time = max(0, first_event_time - SECONDS_TO_KEEP) # Ensure we don't go negative
raw.crop(tmin=trim_time)
# Shift annotation onsets to match new t=0
import mne
ann = raw.annotations
ann_shifted = mne.Annotations(
onset=ann.onset - trim_time, # shift to start at zero
duration=ann.duration,
description=ann.description
)
data = raw.get_data()
info = raw.info.copy()
raw = mne.io.RawArray(data, info)
raw.set_annotations(ann_shifted)
ann = raw.annotations
ann_shifted = mne.Annotations(
onset=ann.onset - trim_time, # shift to start at zero
duration=ann.duration,
description=ann.description
)
data = raw.get_data()
info = raw.info.copy()
raw = mne.io.RawArray(data, info)
raw.set_annotations(ann_shifted)
logger.info(f"Trimmed raw data: start at {trim_time}s (5s before first event), t=0 at new start")
else:
logger.warning("No events found, skipping trim step.")
logger.info(f"Trimmed raw data: start at {trim_time}s (5s before first event), t=0 at new start")
else:
logger.warning("No events found, skipping trim step.")
fig_trimmed = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Trimmed Raw", show=False)
fig_individual["Trimmed Raw"] = fig_trimmed
fig_trimmed = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Trimmed Raw", show=False)
fig_individual["Trimmed Raw"] = fig_trimmed
if progress_callback: progress_callback(2)
logger.info("2")
# Step 1.5: Verify optode positions
fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False) # type: ignore
fig_individual["Plot Sensors"] = fig_optodes
if progress_callback: progress_callback(2)
logger.info("2")
# Step 2: Downsample
# raw = raw.resample(0.5) # Downsample to 0.5 Hz
if OPTODE_PLACEMENT:
fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False) # type: ignore
fig_individual["Plot Sensors"] = fig_optodes
if progress_callback: progress_callback(3)
logger.info("3")
# Step 2: Bad from SCI
if True:
if HEART_RATE:
fig, hr1, hr2, low, high = hr_calc(raw)
fig_individual["PSD"] = fig
fig_individual['HeartRate_PSD'] = hr1
fig_individual['HeartRate_Time'] = hr2
if progress_callback: progress_callback(10)
if progress_callback: progress_callback(2)
if progress_callback: progress_callback(4)
logger.info("4")
bad_sci = []
if SCI:
bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw, low, high)
fig_individual["SCI1"] = fig_sci_1
fig_individual["SCI2"] = fig_sci_2
if progress_callback: progress_callback(3)
logger.info("3")
if progress_callback: progress_callback(5)
logger.info("5")
# Step 2: Bad from SNR
bad_snr = []
if SNR:
bad_snr, fig_snr = calculate_signal_noise_ratio(raw)
fig_individual["SNR1"] = fig_snr
if progress_callback: progress_callback(4)
logger.info("4")
if progress_callback: progress_callback(6)
logger.info("6")
# Step 3: Bad from PSP
bad_psp = []
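Untangling the trim change from the interleaved before/after lines above: with TRIM enabled, the recording is cropped to SECONDS_TO_KEEP seconds before the first annotation and the onsets are re-zeroed. The same sequence of calls, restated in one place (a sketch, not a verbatim excerpt):

import mne

def trim_to_first_event(raw: mne.io.BaseRaw, seconds_to_keep: float) -> mne.io.BaseRaw:
    """Crop so the recording starts seconds_to_keep before the first annotation."""
    if len(raw.annotations) == 0:
        return raw  # no events, nothing to anchor the trim to
    trim_time = max(0.0, raw.annotations.onset[0] - seconds_to_keep)
    raw.crop(tmin=trim_time)
    # Rebuild the annotations so onsets are measured from the new t=0.
    shifted = mne.Annotations(
        onset=raw.annotations.onset - trim_time,
        duration=raw.annotations.duration,
        description=raw.annotations.description,
    )
    trimmed = mne.io.RawArray(raw.get_data(), raw.info.copy())
    trimmed.set_annotations(shifted)
    return trimmed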
@@ -3358,88 +3389,94 @@ def process_participant(file_path, progress_callback=None):
bad_psp, fig_psp1, fig_psp2 = calculate_peak_power(raw)
fig_individual["PSP1"] = fig_psp1
fig_individual["PSP2"] = fig_psp2
if progress_callback: progress_callback(5)
logger.info("5")
if progress_callback: progress_callback(7)
logger.info("7")
# Step 4: Mark the bad channels
raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp)
if fig_dropped and fig_raw_before is not None:
fig_individual["fig2"] = fig_dropped
fig_individual["fig3"] = fig_raw_before
if progress_callback: progress_callback(6)
logger.info("6")
if progress_callback: progress_callback(8)
logger.info("8")
# Step 5: Interpolate the bad channels
if bad_channels:
raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, bad_channels)
fig_individual["fig4"] = fig_raw_after
if progress_callback: progress_callback(7)
logger.info("7")
if progress_callback: progress_callback(9)
logger.info("9")
# Step 6: Optical Density
raw_od = optical_density(raw)
fig_raw_od = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Optical Density", show=False)
fig_individual["Optical Density"] = fig_raw_od
if progress_callback: progress_callback(8)
logger.info("8")
if progress_callback: progress_callback(10)
logger.info("10")
# Step 7: TDDR
if TDDR:
raw_od = temporal_derivative_distribution_repair(raw_od)
fig_raw_od_tddr = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After TDDR (Motion Correction)", show=False)
fig_individual["TDDR"] = fig_raw_od_tddr
if progress_callback: progress_callback(9)
logger.info("9")
if progress_callback: progress_callback(11)
logger.info("11")
raw_od, fig = calculate_and_apply_wavelet(raw_od)
fig_individual["Wavelet"] = fig
if progress_callback: progress_callback(9)
if WAVELET:
raw_od, fig = calculate_and_apply_wavelet(raw_od)
fig_individual["Wavelet"] = fig
if progress_callback: progress_callback(12)
logger.info("12")
# Step 8: BLL
raw_haemo = beer_lambert_law(raw_od, ppf=calculate_dpf(file_path))
fig_raw_haemo_bll = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False)
fig_individual["BLL"] = fig_raw_haemo_bll
if progress_callback: progress_callback(10)
logger.info("10")
if progress_callback: progress_callback(13)
logger.info("13")
# Step 9: ENC
# raw_haemo = enhance_negative_correlation(raw_haemo)
# fig_raw_haemo_enc = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False)
# fig_individual.append(fig_raw_haemo_enc)
if ENHANCE_NEGATIVE_CORRELATION:
raw_haemo = enhance_negative_correlation(raw_haemo)
fig_raw_haemo_enc = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False)
fig_individual["ENC"] = fig_raw_haemo_enc
if progress_callback: progress_callback(14)
logger.info("14")
# Step 10: Filter
fig_filter, fig_raw_haemo_filter = filter_the_data(raw_haemo)
fig_individual["filter1"] = fig_filter
fig_individual["filter2"] = fig_raw_haemo_filter
if progress_callback: progress_callback(11)
logger.info("11")
if FILTER:
raw_haemo, fig_filter, fig_raw_haemo_filter = filter_the_data(raw_haemo)
fig_individual["filter1"] = fig_filter
fig_individual["filter2"] = fig_raw_haemo_filter
if progress_callback: progress_callback(15)
logger.info("15")
# Step 11: Get short / long channels
if SHORT_CHANNEL:
short_chans = get_short_channels(raw_haemo, max_dist=0.02)
short_chans = get_short_channels(raw_haemo, max_dist=SHORT_CHANNEL_THRESH)
fig_short_chans = short_chans.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Short Channels Only", show=False)
fig_individual["short"] = fig_short_chans
else:
short_chans = None
raw_haemo = get_long_channels(raw_haemo)
if progress_callback: progress_callback(12)
logger.info("12")
raw_haemo = get_long_channels(raw_haemo, min_dist=SHORT_CHANNEL_THRESH, max_dist=LONG_CHANNEL_THRESH)
if progress_callback: progress_callback(16)
logger.info("16")
# Step 12: Events from annotations
events, event_dict = events_from_annotations(raw_haemo)
fig_events = plot_events(events, event_id=event_dict, sfreq=raw_haemo.info["sfreq"], show=False)
fig_individual["events"] = fig_events
if progress_callback: progress_callback(13)
logger.info("13")
if progress_callback: progress_callback(17)
logger.info("17")
# Step 13: Epoch calculations
epochs, fig_epochs = epochs_calculations(raw_haemo, events, event_dict)
for name, fig in fig_epochs: # Unpack the tuple here
fig_individual[f"epochs_{name}"] = fig # Store only the figure, not the name
if progress_callback: progress_callback(14)
logger.info("14")
if progress_callback: progress_callback(18)
logger.info("18")
# Step 14: Design Matrix
events_to_remove = REMOVE_EVENTS
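get_short_channels and get_long_channels are the mne_nirs.channels helpers; the commit replaces the hard-coded 0.02 m short-channel cutoff with SHORT_CHANNEL_THRESH and bounds the long channels with LONG_CHANNEL_THRESH. Roughly (distances in metres; the 0.045 m long-channel value is an assumption, not taken from the diff):

from mne_nirs.channels import get_long_channels, get_short_channels

SHORT_CHANNEL_THRESH = 0.018  # metres; separations below this count as short channels
LONG_CHANNEL_THRESH = 0.045   # metres; assumed upper bound for usable long channels

def split_channels(raw_haemo):
    """Return (short_chans, long-channel raw) using the configurable distance cutoffs."""
    short_chans = get_short_channels(raw_haemo, max_dist=SHORT_CHANNEL_THRESH)
    long_raw = get_long_channels(raw_haemo, min_dist=SHORT_CHANNEL_THRESH,
                                 max_dist=LONG_CHANNEL_THRESH)
    return short_chans, long_raw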
@@ -3457,8 +3494,8 @@ def process_participant(file_path, progress_callback=None):
design_matrix, fig_design_matrix = make_design_matrix(raw_haemo, short_chans)
fig_individual["Design Matrix"] = fig_design_matrix
if progress_callback: progress_callback(15)
logger.info("15")
if progress_callback: progress_callback(19)
logger.info("19")
# Step 15: Run GLM
glm_est = run_glm(raw_haemo, design_matrix)
@@ -3473,22 +3510,22 @@ def process_participant(file_path, progress_callback=None):
# A large p-value means the data do not provide strong evidence that the effect is different from zero.
if progress_callback: progress_callback(16)
logger.info("16")
if progress_callback: progress_callback(20)
logger.info("20")
# Step 16: Plot GLM results
fig_glm_result = plot_glm_results(file_path, raw_haemo, glm_est, design_matrix)
for name, fig in fig_glm_result:
fig_individual[f"GLM {name}"] = fig
if progress_callback: progress_callback(17)
logger.info("17")
if progress_callback: progress_callback(21)
logger.info("21")
# Step 17: Plot channel significance
fig_significance = individual_significance(raw_haemo, glm_est)
for name, fig in fig_significance:
fig_individual[f"Significance {name}"] = fig
if progress_callback: progress_callback(18)
logger.info("18")
if progress_callback: progress_callback(22)
logger.info("22")
# Step 18: cha, con, roi
cha = glm_est.to_dataframe()
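run_glm here is mne_nirs.statistics.run_glm, and the returned estimate exposes to_dataframe() for the channel-level statistics collected as cha. A short sketch of that step (the noise-model value reflects the library default, not something set in this commit):

from mne_nirs.statistics import run_glm

def fit_glm(raw_haemo, design_matrix, file_path):
    """Fit a per-channel GLM and return channel-level results as a DataFrame."""
    glm_est = run_glm(raw_haemo, design_matrix, noise_model="ar1")  # 'ar1' is the default
    cha = glm_est.to_dataframe()
    cha["ID"] = file_path  # tag rows with the participant file, as in the hunk below
    return glm_est, cha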
@@ -3543,8 +3580,8 @@ def process_participant(file_path, progress_callback=None):
contrast_dict[condition] = contrast_vector
if progress_callback: progress_callback(19)
logger.info("19")
if progress_callback: progress_callback(23)
logger.info("23")
# Compute contrast results
contrast_results = {}
@@ -3557,15 +3594,20 @@ def process_participant(file_path, progress_callback=None):
cha["ID"] = file_path
if progress_callback: progress_callback(24)
logger.info("24")
fig_bytes = convert_fig_dict_to_png_bytes(fig_individual)
if progress_callback: progress_callback(20)
logger.info("20")
sanitize_paths_for_pickle(raw_haemo, epochs)
if progress_callback: progress_callback(25)
logger.info("25")
return raw_haemo, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, AGE, GENDER, GROUP, True
def sanitize_paths_for_pickle(raw_haemo, epochs):
# Fix raw_haemo._filenames
if hasattr(raw_haemo, '_filenames'):