changes and improvements

This commit is contained in:
2026-01-29 17:23:52 -08:00
parent 7007478c3b
commit f82978e2e8
4 changed files with 371 additions and 149 deletions

103
flares.py
View File

@@ -58,6 +58,7 @@ import neurokit2 as nk # type: ignore
import pyvistaqt # type: ignore
import vtkmodules.util.data_model
import vtkmodules.util.execution_model
import xlrd
# External library imports for mne
from mne import (
@@ -123,8 +124,6 @@ SECONDS_TO_KEEP: float
OPTODE_PLACEMENT: bool
SHOW_OPTODE_NAMES: bool
HEART_RATE: bool
SHORT_CHANNEL: bool
SHORT_CHANNEL_THRESH: float
LONG_CHANNEL_THRESH: float
@@ -928,11 +927,12 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
raw.info['bads'] = [ch for ch in raw.info['bads'] if ch not in bad_ch_to_remove]
print("\nInterpolation complete.\n")
print("Bads cleared:", raw.info['bads'])
raw.info['bads'] = []
for ch in raw.info['bads']:
print(f"Channel {ch} still marked as bad.")
print("Bads cleared:", raw.info['bads'])
fig_raw_after = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After interpolation", show=False)
return raw, fig_raw_after
@@ -1333,7 +1333,7 @@ def make_design_matrix(raw_haemo, short_chans):
drift_model=DRIFT_MODEL,
high_pass=HIGH_PASS,
drift_order=DRIFT_ORDER,
fir_delays=range(15),
fir_delays=FIR_DELAYS,
add_regs=short_chans.get_data().T,
add_reg_names=short_chans.ch_names,
min_onset=MIN_ONSET,
@@ -1347,7 +1347,7 @@ def make_design_matrix(raw_haemo, short_chans):
drift_model=DRIFT_MODEL,
high_pass=HIGH_PASS,
drift_order=DRIFT_ORDER,
fir_delays=range(15),
fir_delays=FIR_DELAYS,
min_onset=MIN_ONSET,
oversampling=OVERSAMPLING
)
@@ -1577,7 +1577,7 @@ def resource_path(relative_path):
def fold_channels(raw: BaseRaw) -> None:
def fold_channels(raw: BaseRaw, progress_callback=None) -> None:
# if getattr(sys, 'frozen', False):
path = os.path.expanduser("~") + "/mne_data/fOLD/fOLD-public-master/Supplementary"
@@ -1659,8 +1659,11 @@ def fold_channels(raw: BaseRaw) -> None:
landmark_color_map = {landmark: colors[i % len(colors)] for i, landmark in enumerate(landmarks)}
# Iterate over each channel
print(len(hbo_channel_names))
for idx, channel_name in enumerate(hbo_channel_names):
print(idx, channel_name)
# Run the fOLD on the selected channel
channel_data = raw.copy().pick(picks=channel_name) # type: ignore
@@ -1703,6 +1706,9 @@ def fold_channels(raw: BaseRaw) -> None:
landmark_specificity_data = []
if progress_callback:
progress_callback(idx + 1)
# TODO: Fix this
if True:
handles = [
@@ -1725,8 +1731,9 @@ def fold_channels(raw: BaseRaw) -> None:
for ax in axes[len(hbo_channel_names):]:
ax.axis('off')
plt.show()
return fig, legend_fig
#plt.show()
fig_dict = {"main": fig, "legend": legend_fig}
return convert_fig_dict_to_png_bytes(fig_dict)
@@ -2246,8 +2253,14 @@ def brain_3d_visualization(raw_haemo, df_cha, selected_event, t_or_theta: Litera
# Get all activity conditions
for cond in [f'{selected_event}']:
if True:
ch_summary = df_cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python') # type: ignore
ch_summary = df_cha.query(f"Condition.str.startswith('{cond}_delay_') and Chroma == 'hbo'", engine='python') # type: ignore
print(ch_summary)
if ch_summary.empty:
# not a FIR model
print("No data found for this condition.")
ch_summary = df_cha.query(f"Condition in [@cond] and Chroma == 'hbo'", engine='python')
# Use ordinary least squares (OLS) if only one participant
# TODO: Fix.
@@ -2269,6 +2282,9 @@ def brain_3d_visualization(raw_haemo, df_cha, selected_event, t_or_theta: Litera
valid_channels = ch_summary["ch_name"].unique().tolist() # type: ignore
raw_for_plot = raw_haemo.copy().pick(picks=valid_channels) # type: ignore
print(f"DEBUG: Model DF rows: {len(model_df)}")
print(f"DEBUG: Raw channels: {len(raw_for_plot.ch_names)}")
brain = plot_3d_evoked_array(raw_for_plot.pick(picks="hbo"), model_df, view="dorsal", distance=0.02, colorbar=True, clim=clim, mode="weighted", size=(800, 700)) # type: ignore
if show_optodes == 'all' or show_optodes == 'sensors':
@@ -3299,7 +3315,7 @@ def hr_calc(raw):
# --- Parameters for PSD ---
desired_bin_hz = 0.1
nperseg = int(sfreq / desired_bin_hz)
hr_range = (30, 180)
hr_range = (30, 180) # TODO: Should this not use the user defined values?
# --- Function to find strongest local peak ---
def find_hr_from_psd(ch_data):
@@ -3527,16 +3543,18 @@ def process_participant(file_path, progress_callback=None):
logger.info("19")
# Step 20: Generate GLM Results
fig_glm_result = plot_glm_results(file_path, raw_haemo, glm_est, design_matrix)
for name, fig in fig_glm_result:
fig_individual[f"GLM {name}"] = fig
if "derivative" not in HRF_MODEL.lower():
fig_glm_result = plot_glm_results(file_path, raw_haemo, glm_est, design_matrix)
for name, fig in fig_glm_result:
fig_individual[f"GLM {name}"] = fig
if progress_callback: progress_callback(20)
logger.info("20")
# Step 21: Generate Channel Significance
fig_significance = individual_significance(raw_haemo, glm_est)
for name, fig in fig_significance:
fig_individual[f"Significance {name}"] = fig
if HRF_MODEL == "fir":
fig_significance = individual_significance(raw_haemo, glm_est)
for name, fig in fig_significance:
fig_individual[f"Significance {name}"] = fig
if progress_callback: progress_callback(21)
logger.info("21")
@@ -3568,30 +3586,31 @@ def process_participant(file_path, progress_callback=None):
[(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)]
)
all_delay_cols = [col for col in design_matrix.columns if "_delay_" in col]
all_conditions = sorted({col.split("_delay_")[0] for col in all_delay_cols})
if HRF_MODEL == "fir":
all_delay_cols = [col for col in design_matrix.columns if "_delay_" in col]
all_conditions = sorted({col.split("_delay_")[0] for col in all_delay_cols})
if not all_conditions:
raise ValueError("No FIR regressors found in the design matrix.")
if not all_conditions:
raise ValueError("No FIR regressors found in the design matrix.")
# Build contrast vectors for each condition
contrast_dict = {}
# Build contrast vectors for each condition
contrast_dict = {}
for condition in all_conditions:
delay_cols = [
col for col in all_delay_cols
if col.startswith(f"{condition}_delay_") and
TIME_WINDOW_START <= int(col.split("_delay_")[-1]) <= TIME_WINDOW_END
]
for condition in all_conditions:
delay_cols = [
col for col in all_delay_cols
if col.startswith(f"{condition}_delay_") and
TIME_WINDOW_START <= int(col.split("_delay_")[-1]) <= TIME_WINDOW_END
]
if not delay_cols:
continue # skip if no columns found (shouldn't happen?)
if not delay_cols:
continue # skip if no columns found (shouldn't happen?)
# Average across all delay regressors for this condition
contrast_vector = np.sum([basic_conts[col] for col in delay_cols], axis=0)
contrast_vector /= len(delay_cols)
# Average across all delay regressors for this condition
contrast_vector = np.sum([basic_conts[col] for col in delay_cols], axis=0)
contrast_vector /= len(delay_cols)
contrast_dict[condition] = contrast_vector
contrast_dict[condition] = contrast_vector
if progress_callback: progress_callback(22)
logger.info("22")
@@ -3599,11 +3618,13 @@ def process_participant(file_path, progress_callback=None):
# Step 23: Compute Contrast Results
contrast_results = {}
for cond, contrast_vector in contrast_dict.items():
contrast = glm_est.compute_contrast(contrast_vector) # type: ignore
df = contrast.to_dataframe()
df["ID"] = file_path
contrast_results[cond] = df
if HRF_MODEL == "fir":
for cond, contrast_vector in contrast_dict.items():
contrast = glm_est.compute_contrast(contrast_vector) # type: ignore
df = contrast.to_dataframe()
df["ID"] = file_path
contrast_results[cond] = df
cha["ID"] = file_path