release worthy?

This commit is contained in:
2026-01-31 23:42:49 -08:00
parent dd2ac058af
commit f1dd9bd184
4 changed files with 805 additions and 251 deletions

260
flares.py
View File

@@ -47,9 +47,9 @@ from nilearn.glm.regression import OLSModel
import statsmodels.formula.api as smf # type: ignore
from statsmodels.stats.multitest import multipletests
from scipy import stats
from scipy.spatial.distance import cdist
from scipy.signal import welch, butter, filtfilt # type: ignore
from scipy.stats import pearsonr, zscore, t
import pywt # type: ignore
import neurokit2 as nk # type: ignore
@@ -91,6 +91,9 @@ from mne_nirs.io.fold import fold_channel_specificity # type: ignore
from mne_nirs.preprocessing import peak_power # type: ignore
from mne_nirs.statistics._glm_level_first import RegressionResults # type: ignore
from mne_connectivity.viz import plot_connectivity_circle
from mne_connectivity import envelope_correlation, spectral_connectivity_epochs, spectral_connectivity_time
# Needs to be set for mne
os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore
@@ -188,9 +191,9 @@ TIME_WINDOW_END: int
MAX_WORKERS: int
VERBOSITY: bool
AGE = 25 # Assume 25 if not set from the GUI. This will result in a reasonable PPF
GENDER = ""
GROUP = "Default"
AGE: int = 25 # Assume 25 if not set from the GUI. This will result in a reasonable PPF
GENDER: str = ""
GROUP: str = "Default"
# These are parameters that are required for the analysis
REQUIRED_KEYS: dict[str, Any] = {
@@ -2818,7 +2821,7 @@ def run_second_level_analysis(df_contrasts, raw, p, bounds):
result = model.fit(Y)
t_val = result.t(0).item()
p_val = 2 * stats.t.sf(np.abs(t_val), df=result.df_model)
p_val = 2 * t.sf(np.abs(t_val), df=result.df_model)
mean_beta = np.mean(Y)
group_results.append({
@@ -3357,6 +3360,7 @@ def process_participant(file_path, progress_callback=None):
logger.info("Step 1 Completed.")
# Step 2: Trimming
# TODO: Clean this into a method
if TRIM:
if hasattr(raw, 'annotations') and len(raw.annotations) > 0:
# Get time of first event
@@ -3636,8 +3640,18 @@ def process_participant(file_path, progress_callback=None):
if progress_callback: progress_callback(25)
logger.info("25")
return raw_haemo, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, AGE, GENDER, GROUP, True
# TODO: Tidy up
# Extract the parameters this file was ran with. No need to return age, gender, group?
config = {
k: globals()[k]
for k in __annotations__
if k in globals() and k != "REQUIRED_KEYS"
}
print(config)
return raw_haemo, config, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, True
def sanitize_paths_for_pickle(raw_haemo, epochs):
@@ -3647,4 +3661,234 @@ def sanitize_paths_for_pickle(raw_haemo, epochs):
# Fix epochs._raw._filenames
if hasattr(epochs, '_raw') and hasattr(epochs._raw, '_filenames'):
epochs._raw._filenames = [str(p) for p in epochs._raw._filenames]
epochs._raw._filenames = [str(p) for p in epochs._raw._filenames]
def functional_connectivity_spectral_epochs(epochs, n_lines, vmin):
    """Plot coherence-based HbO functional connectivity computed over epochs.

    Selects the HbO channels, estimates coherence and PLV across epochs in
    the 0.04-0.2 Hz band with a multitaper estimator, and draws the
    coherence matrix on a connectivity circle plot.

    Parameters
    ----------
    epochs : Epochs object with HbO channels.
    n_lines : number of strongest connections to draw (None for all).
    vmin : lower colour limit for the circle plot.
    """
    # will crash without this load
    epochs.load_data()

    hbo = epochs.copy().pick(picks="hbo")
    signal = hbo.get_data()
    channel_names = hbo.ch_names
    sampling_rate = hbo.info["sfreq"]

    estimates = spectral_connectivity_epochs(
        signal,
        method=["coh", "plv"],
        mode="multitaper",
        sfreq=sampling_rate,
        fmin=0.04,
        fmax=0.2,
        faverage=True,
        verbose=True
    )
    coherence_est, plv_est = estimates

    coherence = coherence_est.get_data(output="dense").squeeze()
    phase_locking = plv_est.get_data(output="dense").squeeze()

    # Zero the self-connections so they do not dominate the plot scale.
    np.fill_diagonal(coherence, 0)
    np.fill_diagonal(phase_locking, 0)

    plot_connectivity_circle(
        coherence,
        channel_names,
        title="fNIRS Functional Connectivity (HbO - Coherence)",
        n_lines=n_lines,
        vmin=vmin
    )
def functional_connectivity_spectral_time(epochs, n_lines, vmin):
    """Plot coherence-based HbO connectivity estimated in the time domain.

    Selects the HbO channels, estimates per-epoch time-resolved coherence
    and PLV over 10 frequencies spanning 0.04-0.2 Hz (multitaper), and
    draws the coherence matrix on a connectivity circle plot.

    Parameters
    ----------
    epochs : Epochs object with HbO channels.
    n_lines : number of strongest connections to draw (None for all).
    vmin : lower colour limit for the circle plot.
    """
    # will crash without this load
    epochs.load_data()

    hbo = epochs.copy().pick(picks="hbo")
    signal = hbo.get_data()
    channel_names = hbo.ch_names
    sampling_rate = hbo.info["sfreq"]

    # Frequency grid inside the typical fNIRS functional band, with the
    # cycle count scaled to frequency for a roughly constant window length.
    freqs = np.linspace(0.04, 0.2, 10)
    n_cycles = freqs * 2

    estimates = spectral_connectivity_time(
        signal,
        freqs=freqs,
        method=["coh", "plv"],
        mode="multitaper",
        sfreq=sampling_rate,
        fmin=0.04,
        fmax=0.2,
        n_cycles=n_cycles,
        faverage=True,
        verbose=True
    )
    coherence_est, plv_est = estimates

    coherence = coherence_est.get_data(output="dense").squeeze()
    phase_locking = plv_est.get_data(output="dense").squeeze()

    # Zero the self-connections so they do not dominate the plot scale.
    np.fill_diagonal(coherence, 0)
    np.fill_diagonal(phase_locking, 0)

    plot_connectivity_circle(
        coherence,
        channel_names,
        title="fNIRS Functional Connectivity (HbO - Coherence)",
        n_lines=n_lines,
        vmin=vmin
    )
def functional_connectivity_envelope(epochs, n_lines, vmin):
    """Plot envelope-correlation connectivity between HbO channels.

    Computes the (non-orthogonalized, absolute) envelope correlation per
    epoch, averages across epochs, and draws the result on a connectivity
    circle plot.

    Parameters
    ----------
    epochs : Epochs object with HbO channels.
    n_lines : number of strongest connections to draw (None for all).
    vmin : lower colour limit for the circle plot.
    """
    # will crash without this load
    epochs.load_data()

    hbo = epochs.copy().pick(picks="hbo")
    signal = hbo.get_data()

    envelope = envelope_correlation(
        signal,
        orthogonalize=False,
        absolute=True
    )

    # Average the per-epoch connectivity matrices into one channel x channel map.
    per_epoch = envelope.get_data(output="dense")
    mean_corr = np.squeeze(per_epoch.mean(axis=0))

    # Zero the self-connections so they do not dominate the plot scale.
    np.fill_diagonal(mean_corr, 0)

    plot_connectivity_circle(
        mean_corr,
        hbo.ch_names,
        title="fNIRS HbO Envelope Correlation (Task Connectivity)",
        n_lines=n_lines,
        vmin=vmin
    )
def functional_connectivity_betas(raw_hbo, n_lines, vmin, event_name=None):
    """Plot beta-series functional connectivity for HbO channels.

    Fits a trial-wise FIR GLM, averages the FIR-delay betas within each
    trial to form a per-channel beta series, z-scores and global-signal-
    regresses the series, correlates channel pairs, FDR-corrects the
    p-values, and plots only significant connections with |r| > 0.7.

    Parameters
    ----------
    raw_hbo : continuous fNIRS Raw object; HbO channels are selected internally.
    n_lines : accepted for interface parity with the other connectivity
        helpers, but the final plot draws all surviving edges (n_lines=None).
    vmin : accepted for interface parity; the plot uses a fixed 0.7 floor
        matching the |r| threshold below.
    event_name : str or None
        If given, restrict the beta series to trials of this event type.

    Raises
    ------
    ValueError
        If no trials match ``event_name``.
    """
    raw_hbo = raw_hbo.copy().pick(picks="hbo")

    # CRITICAL: tag every annotation with a unique trial suffix so the GLM
    # produces one regressor set per trial (a beta series), not per condition.
    ann = raw_hbo.annotations
    ann.description = np.array(
        [f"{desc}__trial_{i:03d}" for i, desc in enumerate(ann.description)]
    )

    # TODO: use the user-defined design parameters instead of hard-coding them.
    design_matrix = make_first_level_design_matrix(
        raw=raw_hbo,
        hrf_model='fir',
        fir_delays=np.arange(0, 12, 1),
        drift_model='cosine',
        drift_order=1
    )

    # Run the GLM and extract the beta weights for every regressor.
    glm_results = run_glm(raw_hbo, design_matrix)
    betas = np.array(glm_results.theta())
    reg_names = list(design_matrix.columns)
    n_channels = betas.shape[0]

    # Unique trial tags (optionally filtered down to one event type).
    trial_tags = sorted({
        col.split("_delay")[0]
        for col in reg_names
        if (
            ("__trial_" in col)
            and (event_name is None or col.startswith(event_name + "__"))
        )
    })
    if len(trial_tags) == 0:
        raise ValueError(f"No trials found for event_name={event_name}")

    # Build the beta series: average across the FIR delays of each trial.
    # NOTE: loop variable deliberately not named 't' — that would shadow
    # scipy.stats.t imported at module level.
    beta_series = np.zeros((n_channels, len(trial_tags)))
    for trial_idx, tag in enumerate(trial_tags):
        delay_cols = [
            i for i, col in enumerate(reg_names)
            if col.startswith(f"{tag}_delay")
        ]
        beta_series[:, trial_idx] = np.mean(betas[:, delay_cols], axis=1).flatten()

    # Normalize each channel so they are on the same scale.
    # Without this, everything is connected to everything. Apparently this is a big issue in fNIRS?
    beta_series = zscore(beta_series, axis=1)

    # Global signal regression: subtract each channel's fitted share of the
    # mean-across-channels series to suppress systemic (non-neural) covariance.
    global_signal = np.mean(beta_series, axis=0)
    beta_series_clean = np.zeros_like(beta_series)
    for i in range(n_channels):
        slope, _ = np.polyfit(global_signal, beta_series[i, :], 1)
        beta_series_clean[i, :] = beta_series[i, :] - (slope * global_signal)

    # Pairwise Pearson correlation with per-pair p-values (symmetric fill).
    corr_matrix = np.zeros((n_channels, n_channels))
    p_matrix = np.ones((n_channels, n_channels))
    for i in range(n_channels):
        for j in range(i + 1, n_channels):
            r, p = pearsonr(beta_series_clean[i, :], beta_series_clean[j, :])
            corr_matrix[i, j] = corr_matrix[j, i] = r
            p_matrix[i, j] = p_matrix[j, i] = p

    # FDR (Benjamini-Hochberg) correction over the upper-triangle p-values.
    triu = np.triu_indices(n_channels, k=1)
    reject = multipletests(p_matrix[triu], method='fdr_bh', alpha=0.05)[0]

    # Keep only connections that survive FDR AND are strong (|r| > 0.7);
    # the high bar is deliberate — only the clearest edges should plot.
    sig_corr_matrix = np.zeros_like(corr_matrix)
    for idx, is_sig in enumerate(reject):
        r_val = corr_matrix[triu[0][idx], triu[1][idx]]
        if is_sig and abs(r_val) > 0.7:
            sig_corr_matrix[triu[0][idx], triu[1][idx]] = r_val
            sig_corr_matrix[triu[1][idx], triu[0][idx]] = r_val

    plot_connectivity_circle(
        sig_corr_matrix,
        raw_hbo.ch_names,
        title="Strictly Filtered Connectivity (TDDR + GSR + Z-Score)",
        n_lines=None,
        vmin=0.7,
        vmax=1.0,
        colormap='hot'  # Use 'hot' to make positive connections pop
    )