diff --git a/changelog.md b/changelog.md
index b7bc2f0..971c035 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,7 +1,9 @@
 # Version 1.2.0
+- This is a save-breaking release due to a new save file format. Please recreate existing project files to ensure compatibility. Fixes [Issue 30](https://git.research.dezeeuw.ca/tyler/flares/issues/30)
 - Added new parameters to the right side of the screen
 - These parameters include SHOW_OPTODE_NAMES, SECONDS_TO_STRIP_HR, MAX_LOW_HR, MAX_HIGH_HR, SMOOTHING_WINDOW_HR, HEART_RATE_WINDOW, BAD_CHANNELS_HANDLING, MAX_DIST, MIN_NEIGHBORS, L_TRANS_BANDWIDTH, H_TRANS_BANDWIDTH, RESAMPLE, RESAMPLE_FREQ, STIM_DUR, HRF_MODEL, HIGH_PASS, DRIFT_ORDER, FIR_DELAYS, MIN_ONSET, OVERSAMPLING, SHORT_CHANNEL_REGRESSION, NOISE_MODEL, BINS, and VERBOSITY.
+- Certain parameters now depend on other parameters and are greyed out when they are not in use
 - All the new parameters have default values matching the underlying values in version 1.1.7
 - The order of the parameters have changed to match the order that the code runs when the Process button is clicked
 - Moved TIME_WINDOW_START and TIME_WINDOW_END to the 'Other' category
@@ -17,11 +19,18 @@
 - Added a new CSV export option to be used by other applications
 - Added support for updating optode positions directly from an .xlsx file from a Polhemius system
 - Fixed an issue where the dropdowns in the Viewer windows would immediately open and close when using a trackpad
-- glover and spm hrf models now function as intended without crashing. Currently, group analysis is still only supported by fir
-- Revamped the fold channels viewer to not hang the application and to better process multiple participants at once
+- The glover and spm hrf models now function as intended without crashing. Currently, group analysis is still only supported by fir. Fixes [Issue 8](https://git.research.dezeeuw.ca/tyler/flares/issues/8)
+- Clicking 'Clear' now properly clears all data. Fixes [Issue 9](https://git.research.dezeeuw.ca/tyler/flares/issues/9)
+- Revamped the fold channels viewer so it no longer hangs the application and handles multiple participants at once more efficiently. Fixes [Issue 34](https://git.research.dezeeuw.ca/tyler/flares/issues/34), [Issue 31](https://git.research.dezeeuw.ca/tyler/flares/issues/31)
 - Added a Preferences menu to the navigation bar
-- Currently, there is only one preference allowing to bypass the warning of 2D data
-- Fixed [Issue 8](https://git.research.dezeeuw.ca/tyler/flares/issues/8), [Issue 9](https://git.research.dezeeuw.ca/tyler/flares/issues/9), [Issue 30](https://git.research.dezeeuw.ca/tyler/flares/issues/30), [Issue 31](https://git.research.dezeeuw.ca/tyler/flares/issues/31), [Issue 34](https://git.research.dezeeuw.ca/tyler/flares/issues/34), [Issue 36](https://git.research.dezeeuw.ca/tyler/flares/issues/36)
+- Two preferences have been added: one to bypass the warning when 2D data is detected, and one to bypass the warning when a save file comes from an earlier, potentially incompatible version
+- Fixed a typo in the CSV export dialogs that referred to saving a SNIRF file
+- Loading a save file now properly restores AGE, GENDER, and GROUP. Fixes [Issue 40](https://git.research.dezeeuw.ca/tyler/flares/issues/40)
+- Saving a project no longer causes the main window to become unresponsive. Fixes [Issue 43](https://git.research.dezeeuw.ca/tyler/flares/issues/43)
+- Memory usage should no longer grow when repeatedly generating large numbers of images.
Fixes [Issue 36](https://git.research.dezeeuw.ca/tyler/flares/issues/36) +- Added a new option in the Analysis window for Functional Connectivity +- Functional connectivity is still in development and the results should currently be taken with a grain of salt +- A warning is displayed when entering the Functional Connectivity Viewer disclosing this # Version 1.1.7 diff --git a/fc.py b/fc.py deleted file mode 100644 index fc730fb..0000000 --- a/fc.py +++ /dev/null @@ -1,207 +0,0 @@ -import mne -import numpy as np - -from mne.preprocessing.nirs import optical_density, beer_lambert_law -from mne_connectivity import spectral_connectivity_epochs -from mne_connectivity.viz import plot_connectivity_circle - -raw = mne.io.read_raw_snirf("E:/CVI_V_Adults_Cor/P21_CVI_V_updated.snirf", preload=True) - -raw.info["bads"] = [] # mark bad channels here if needed - -raw_od = optical_density(raw) -raw_hb = beer_lambert_law(raw_od) - -raw_hbo = raw_hb.copy().pick(picks="hbo") - -raw_hbo.filter( - l_freq=0.01, - h_freq=0.2, - picks="hbo", - verbose=False -) - -events = mne.make_fixed_length_events( - raw_hbo, - duration=30.0 -) - -epochs = mne.Epochs( - raw_hbo, - events, - tmin=0, - tmax=30.0, - baseline=None, - preload=True, - verbose=False -) - -data = epochs.get_data() # (n_epochs, n_channels, n_times) -names = epochs.ch_names -sfreq = epochs.info["sfreq"] - - -con = spectral_connectivity_epochs( - data, - method=["coh", "plv"], - mode="multitaper", - sfreq=sfreq, - fmin=0.04, - fmax=0.2, - faverage=True, - verbose=True -) - - -con_coh, con_plv = con - - -coh = con_coh.get_data(output="dense").squeeze() -plv = con_plv.get_data(output="dense").squeeze() - -np.fill_diagonal(coh, 0) -np.fill_diagonal(plv, 0) - -plot_connectivity_circle( - coh, - names, - title="fNIRS Functional Connectivity (HbO - Coherence)", - n_lines=40 -) - - - -from mne_connectivity import envelope_correlation -env = envelope_correlation( - data, - orthogonalize=False, - absolute=True -) -env_data = env.get_data(output="dense") - -env_corr = env_data.mean(axis=0) - -env_corr = np.squeeze(env_corr) - -np.fill_diagonal(env_corr, 0) - -plot_connectivity_circle( - env_corr, - epochs.ch_names, - title="fNIRS HbO Envelope Correlation (Task Connectivity)", - n_lines=40 -) - - -from mne_nirs.statistics import run_glm -from mne_nirs.experimental_design import make_first_level_design_matrix - -raw_hb.annotations.description = [ - f"Reach_{i}" if d == "Reach" else d - for i, d in enumerate(raw_hb.annotations.description) -] - -design_matrix = make_first_level_design_matrix( - raw_hb, - stim_dur=1.0, # We assume a short burst since duration is unknown - hrf_model='fir', # Finite Impulse Response - fir_delays=np.arange(0, 12) # Look at 0-20 seconds after onset -) - -# 2. Run the GLM -# This calculates the brain's response for every channel -glm_est = run_glm(raw_hb, design_matrix) - -import pandas as pd - -# 3. 
Extract Beta Weights -beta_df = glm_est.to_dataframe() - -print("\n--- DEBUG: Dataframe Info ---") -print(f"Total rows in beta_df: {len(beta_df)}") -print(f"Columns available: {list(beta_df.columns)}") -print(f"Unique Chroma values: {beta_df['Chroma'].unique()}") -print(f"First 5 unique Conditions: {beta_df['Condition'].unique()[:5]}") - -# FIX: Use .str.contains() because FIR conditions are named like 'Reach[5.0]' -# We filter for HbO AND any condition that starts with 'Reach' -hbo_betas = beta_df[(beta_df['Chroma'] == 'hbo') & - (beta_df['Condition'].str.contains('Reach'))] - -hbo_betas = hbo_betas.copy() -hbo_betas[['Trial', 'Delay']] = hbo_betas['Condition'].str.extract(r'(Reach_\d+)_delay_(\d+)') - -# 2. Find which DELAY (time point) is best across ALL trials -# We convert 'Delay' to numeric so we can sort them properly later if needed -hbo_betas['Delay'] = pd.to_numeric(hbo_betas['Delay']) - -# IMPORTANT: We ignore delay 0 and 1 because they are usually 0.0000 (stimulus onset) -# Brain responses in fNIRS usually peak between delays 4 and 8 (4-8 seconds) - -mask = (hbo_betas['Delay'] >= 4) & (hbo_betas['Delay'] <= 8) -hbo_window = hbo_betas[mask] - -if hbo_window.empty: - print("Warning: No data found in the 4-8s window. Check your 'fir_delays' range.") - # Fallback to whatever is available if 4-8 is missing - mean_by_delay = hbo_betas.groupby('Delay')['theta'].mean() -else: - mean_by_delay = hbo_window.groupby('Delay')['theta'].mean() - -peak_delay_num = mean_by_delay.idxmax() - -print(f"\n--- DEBUG: FIR Timing ---") -print(f"Delays analyzed: {list(mean_by_delay.index)}") -print(f"Peak brain response found at delay: {peak_delay_num}") - -# 3. Filter the data to ONLY include that peak delay across ALL trials -peak_df = hbo_betas[hbo_betas['Delay'] == peak_delay_num] - -mne_order = raw_hbo.ch_names - -# 4. Pivot: Rows = Trials (Reach_1, Reach_2...), Columns = Channels -beta_pivot = peak_df.pivot(index='Trial', columns='ch_name', values='theta') - -beta_pivot = beta_pivot.reindex(columns=mne_order) - -print(f"Pivot table shape: {beta_pivot.shape} (Should be something like 30 trials x 26 channels)") - -# 5. Correlation (Now it has a series of data to correlate!) -beta_corr_matrix = beta_pivot.corr().values -np.fill_diagonal(beta_corr_matrix, 0) - -# Replace any NaNs with 0 (occurs if a channel has 0 variance) -beta_corr_matrix = np.nan_to_num(beta_corr_matrix) - -import matplotlib.pyplot as plt -channel_names = beta_pivot.columns.tolist() -# Create the plot -plot_connectivity_circle( - beta_corr_matrix, - channel_names, - n_lines=40, # Show only the top 40 strongest connections - title=f"FIR Beta Series Connectivity)", -) - -# 1. Aggregate the mean response for each delay across all trials and channels -# We want to see the general 'shape' of the Reach response -time_points = np.arange(0, 12) # Matches your fir_delays -average_response = hbo_betas.groupby('Delay')['theta'].mean() - -# 2. 
Plotting -plt.figure(figsize=(10, 6)) -for ch in hbo_betas['ch_name'].unique(): - ch_data = hbo_betas[hbo_betas['ch_name'] == ch].groupby('Delay')['theta'].mean() - plt.plot(time_points, ch_data, color='gray', alpha=0.3) # Individual channels - -# Plot the 'Grand Average' in bold red -plt.plot(time_points, average_response, color='red', linewidth=3, label='Grand Average') - -plt.axvline(x=4, color='green', linestyle='--', label='Window Start (4s)') -plt.axvline(x=8, color='green', linestyle='--', label='Window End (8s)') -plt.title("FIR Hemodynamic Response to 'Reach' (HbO)") -plt.xlabel("Seconds after Stimulus") -plt.ylabel("HbO Concentration (Beta Weight)") -plt.legend() -plt.grid(True, alpha=0.3) -plt.show() \ No newline at end of file diff --git a/flares.py b/flares.py index abc79f3..3c54519 100644 --- a/flares.py +++ b/flares.py @@ -47,9 +47,9 @@ from nilearn.glm.regression import OLSModel import statsmodels.formula.api as smf # type: ignore from statsmodels.stats.multitest import multipletests -from scipy import stats from scipy.spatial.distance import cdist from scipy.signal import welch, butter, filtfilt # type: ignore +from scipy.stats import pearsonr, zscore, t import pywt # type: ignore import neurokit2 as nk # type: ignore @@ -91,6 +91,9 @@ from mne_nirs.io.fold import fold_channel_specificity # type: ignore from mne_nirs.preprocessing import peak_power # type: ignore from mne_nirs.statistics._glm_level_first import RegressionResults # type: ignore +from mne_connectivity.viz import plot_connectivity_circle +from mne_connectivity import envelope_correlation, spectral_connectivity_epochs, spectral_connectivity_time + # Needs to be set for mne os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore @@ -188,9 +191,9 @@ TIME_WINDOW_END: int MAX_WORKERS: int VERBOSITY: bool -AGE = 25 # Assume 25 if not set from the GUI. This will result in a reasonable PPF -GENDER = "" -GROUP = "Default" +AGE: int = 25 # Assume 25 if not set from the GUI. This will result in a reasonable PPF +GENDER: str = "" +GROUP: str = "Default" # These are parameters that are required for the analysis REQUIRED_KEYS: dict[str, Any] = { @@ -2818,7 +2821,7 @@ def run_second_level_analysis(df_contrasts, raw, p, bounds): result = model.fit(Y) t_val = result.t(0).item() - p_val = 2 * stats.t.sf(np.abs(t_val), df=result.df_model) + p_val = 2 * t.sf(np.abs(t_val), df=result.df_model) mean_beta = np.mean(Y) group_results.append({ @@ -3357,6 +3360,7 @@ def process_participant(file_path, progress_callback=None): logger.info("Step 1 Completed.") # Step 2: Trimming + # TODO: Clean this into a method if TRIM: if hasattr(raw, 'annotations') and len(raw.annotations) > 0: # Get time of first event @@ -3636,8 +3640,18 @@ def process_participant(file_path, progress_callback=None): if progress_callback: progress_callback(25) logger.info("25") - - return raw_haemo, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, AGE, GENDER, GROUP, True + + # TODO: Tidy up + # Extract the parameters this file was ran with. No need to return age, gender, group? 
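+    # Note: `config` collects every annotated module-level parameter (AGE, GENDER,
+    # GROUP, VERBOSITY, MAX_WORKERS, ...) keyed by name, e.g.
+    # {"AGE": 25, "GENDER": "", "GROUP": "Default", ...}; REQUIRED_KEYS is skipped.
+    # main.py later passes this dict to restore_sections_from_config to repopulate
+    # the parameter sections when a saved project is loaded.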
+ config = { + k: globals()[k] + for k in __annotations__ + if k in globals() and k != "REQUIRED_KEYS" + } + + print(config) + + return raw_haemo, config, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, True def sanitize_paths_for_pickle(raw_haemo, epochs): @@ -3647,4 +3661,234 @@ def sanitize_paths_for_pickle(raw_haemo, epochs): # Fix epochs._raw._filenames if hasattr(epochs, '_raw') and hasattr(epochs._raw, '_filenames'): - epochs._raw._filenames = [str(p) for p in epochs._raw._filenames] \ No newline at end of file + epochs._raw._filenames = [str(p) for p in epochs._raw._filenames] + + +def functional_connectivity_spectral_epochs(epochs, n_lines, vmin): + + # will crash without this load + epochs.load_data() + hbo_epochs = epochs.copy().pick(picks="hbo") + data = hbo_epochs.get_data() + names = hbo_epochs.ch_names + sfreq = hbo_epochs.info["sfreq"] + con = spectral_connectivity_epochs( + data, + method=["coh", "plv"], + mode="multitaper", + sfreq=sfreq, + fmin=0.04, + fmax=0.2, + faverage=True, + verbose=True + ) + + con_coh, con_plv = con + + coh = con_coh.get_data(output="dense").squeeze() + plv = con_plv.get_data(output="dense").squeeze() + + np.fill_diagonal(coh, 0) + np.fill_diagonal(plv, 0) + + plot_connectivity_circle( + coh, + names, + title="fNIRS Functional Connectivity (HbO - Coherence)", + n_lines=n_lines, + vmin=vmin + ) + + + + + +def functional_connectivity_spectral_time(epochs, n_lines, vmin): + + # will crash without this load + epochs.load_data() + hbo_epochs = epochs.copy().pick(picks="hbo") + data = hbo_epochs.get_data() + names = hbo_epochs.ch_names + sfreq = hbo_epochs.info["sfreq"] + + freqs = np.linspace(0.04, 0.2, 10) + n_cycles = freqs * 2 + + con = spectral_connectivity_time( + data, + freqs=freqs, + method=["coh", "plv"], + mode="multitaper", + sfreq=sfreq, + fmin=0.04, + fmax=0.2, + n_cycles=n_cycles, + faverage=True, + verbose=True + ) + + con_coh, con_plv = con + + coh = con_coh.get_data(output="dense").squeeze() + plv = con_plv.get_data(output="dense").squeeze() + + np.fill_diagonal(coh, 0) + np.fill_diagonal(plv, 0) + + plot_connectivity_circle( + coh, + names, + title="fNIRS Functional Connectivity (HbO - Coherence)", + n_lines=n_lines, + vmin=vmin + ) + + + + +def functional_connectivity_envelope(epochs, n_lines, vmin): + # will crash without this load + + epochs.load_data() + hbo_epochs = epochs.copy().pick(picks="hbo") + data = hbo_epochs.get_data() + + + env = envelope_correlation( + data, + orthogonalize=False, + absolute=True + ) + env_data = env.get_data(output="dense") + + env_corr = env_data.mean(axis=0) + + env_corr = np.squeeze(env_corr) + + np.fill_diagonal(env_corr, 0) + + plot_connectivity_circle( + env_corr, + hbo_epochs.ch_names, + title="fNIRS HbO Envelope Correlation (Task Connectivity)", + n_lines=n_lines, + vmin=vmin + ) + + +def functional_connectivity_betas(raw_hbo, n_lines, vmin, event_name=None): + + raw_hbo = raw_hbo.copy().pick(picks="hbo") + onsets = raw_hbo.annotations.onset + + # CRITICAL: Update the Raw object's annotations so the GLM sees unique events + ann = raw_hbo.annotations + new_desc = [] + + for i, desc in enumerate(ann.description): + new_desc.append(f"{desc}__trial_{i:03d}") + + ann.description = np.array(new_desc) + + + # shoudl use user defiuned!!!! + design_matrix = make_first_level_design_matrix( + raw=raw_hbo, + hrf_model='fir', + fir_delays=np.arange(0, 12, 1), + drift_model='cosine', + drift_order=1 + ) + + + # 3. 
Run GLM & Extract Betas + glm_results = run_glm(raw_hbo, design_matrix) + betas = np.array(glm_results.theta()) + reg_names = list(design_matrix.columns) + + + + + + n_channels = betas.shape[0] + + # ------------------------------------------------------------------ + # 5. Find unique trial tags (optionally filtered by event) + # ------------------------------------------------------------------ + trial_tags = sorted({ + col.split("_delay")[0] + for col in reg_names + if ( + ("__trial_" in col) + and (event_name is None or col.startswith(event_name + "__")) + ) + }) + + if len(trial_tags) == 0: + raise ValueError(f"No trials found for event_name={event_name}") + + # ------------------------------------------------------------------ + # 6. Build beta series (average across FIR delays per trial) + # ------------------------------------------------------------------ + beta_series = np.zeros((n_channels, len(trial_tags))) + + for t, tag in enumerate(trial_tags): + idx = [ + i for i, col in enumerate(reg_names) + if col.startswith(f"{tag}_delay") + ] + beta_series[:, t] = np.mean(betas[:, idx], axis=1).flatten() + + + + # n_channels, n_trials = betas.shape[0], len(onsets) + # beta_series = np.zeros((n_channels, n_trials)) + + # for t in range(n_trials): + # trial_indices = [i for i, col in enumerate(reg_names) if col.startswith(f"trial_{t:03d}_delay")] + # if trial_indices: + # beta_series[:, t] = np.mean(betas[:, trial_indices], axis=1).flatten() + + # Normalize each channel so they are on the same scale + # Without this, everything is connected to everything. Apparently this is a big issue in fNIRS? + beta_series = zscore(beta_series, axis=1) + + global_signal = np.mean(beta_series, axis=0) + beta_series_clean = np.zeros_like(beta_series) + for i in range(n_channels): + slope, _ = np.polyfit(global_signal, beta_series[i, :], 1) + beta_series_clean[i, :] = beta_series[i, :] - (slope * global_signal) + + # 4. Correlation & Strict Filtering + corr_matrix = np.zeros((n_channels, n_channels)) + p_matrix = np.ones((n_channels, n_channels)) + + for i in range(n_channels): + for j in range(i + 1, n_channels): + r, p = pearsonr(beta_series_clean[i, :], beta_series_clean[j, :]) + corr_matrix[i, j] = corr_matrix[j, i] = r + p_matrix[i, j] = p_matrix[j, i] = p + + # 5. High-Bar Thresholding + reject, _ = multipletests(p_matrix[np.triu_indices(n_channels, k=1)], method='fdr_bh', alpha=0.05)[:2] + sig_corr_matrix = np.zeros_like(corr_matrix) + triu = np.triu_indices(n_channels, k=1) + + for idx, is_sig in enumerate(reject): + r_val = corr_matrix[triu[0][idx], triu[1][idx]] + # Only keep the absolute strongest connections + if is_sig and abs(r_val) > 0.7: + sig_corr_matrix[triu[0][idx], triu[1][idx]] = r_val + sig_corr_matrix[triu[1][idx], triu[0][idx]] = r_val + + # 6. 
Plot + plot_connectivity_circle( + sig_corr_matrix, + raw_hbo.ch_names, + title="Strictly Filtered Connectivity (TDDR + GSR + Z-Score)", + n_lines=None, + vmin=0.7, + vmax=1.0, + colormap='hot' # Use 'hot' to make positive connections pop + ) diff --git a/main.py b/main.py index 21dbad6..ef1e5a2 100644 --- a/main.py +++ b/main.py @@ -226,6 +226,44 @@ SECTIONS = [ + + +class SaveProjectThread(QThread): + finished_signal = Signal(str) + error_signal = Signal(str) + + def __init__(self, filename, project_data): + super().__init__() + self.filename = filename + self.project_data = project_data + + def run(self): + try: + import pickle + with open(self.filename, "wb") as f: + pickle.dump(self.project_data, f) + self.finished_signal.emit(self.filename) + except Exception as e: + self.error_signal.emit(str(e)) + + +class SavingOverlay(QDialog): + def __init__(self, parent=None): + super().__init__(parent) + self.setWindowFlags(Qt.WindowType.Dialog | Qt.WindowType.FramelessWindowHint) + self.setModal(True) + self.setWindowModality(Qt.WindowModality.ApplicationModal) + self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground) + + layout = QVBoxLayout() + layout.setAlignment(Qt.AlignmentFlag.AlignCenter) + + label = QLabel("Saving Project…") + label.setStyleSheet("font-size: 18px; color: white; background-color: rgba(0,0,0,150); padding: 20px; border-radius: 10px;") + layout.addWidget(label) + self.setLayout(layout) + + class TerminalWindow(QWidget): def __init__(self, parent=None): super().__init__(parent, Qt.WindowType.Window) @@ -2439,6 +2477,371 @@ class ParticipantBrainViewerWidget(QWidget): + + + +class ParticipantFunctionalConnectivityWidget(QWidget): + def __init__(self, haemo_dict, epochs_dict): + super().__init__() + self.setWindowTitle("FLARES Functional Connectivity Viewer [BETA]") + self.haemo_dict = haemo_dict + self.epochs_dict = epochs_dict + + QMessageBox.warning(self, "Warning - FLARES", f"Functional Connectivity is still in development and the results should currently be taken with a grain of salt. 
" + "By clicking OK, you accept that the images generated may not be factual.") + + # Create mappings: file_path -> participant label and dropdown display text + self.participant_map = {} # file_path -> "Participant 1" + self.participant_dropdown_items = [] # "Participant 1 (filename)" + + for i, file_path in enumerate(self.haemo_dict.keys(), start=1): + short_label = f"Participant {i}" + display_label = f"{short_label} ({os.path.basename(file_path)})" + self.participant_map[file_path] = short_label + self.participant_dropdown_items.append(display_label) + + self.layout = QVBoxLayout(self) + self.top_bar = QHBoxLayout() + self.layout.addLayout(self.top_bar) + + self.participant_dropdown = self._create_multiselect_dropdown(self.participant_dropdown_items) + self.participant_dropdown.currentIndexChanged.connect(self.update_participant_dropdown_label) + + self.event_dropdown = QComboBox() + self.event_dropdown.addItem("") + + + self.index_texts = [ + "0 (Spectral Connectivity Epochs)", + "1 (Envelope Correlation)", + "2 (Betas)", + "3 (Spectral Connectivity Epochs)", + ] + + self.image_index_dropdown = self._create_multiselect_dropdown(self.index_texts) + self.image_index_dropdown.currentIndexChanged.connect(self.update_image_index_dropdown_label) + + self.submit_button = QPushButton("Submit") + self.submit_button.clicked.connect(self.show_brain_images) + + self.top_bar.addWidget(QLabel("Participants:")) + self.top_bar.addWidget(self.participant_dropdown) + self.top_bar.addWidget(QLabel("Event:")) + self.top_bar.addWidget(self.event_dropdown) + self.top_bar.addWidget(QLabel("Image Indexes:")) + self.top_bar.addWidget(self.image_index_dropdown) + self.top_bar.addWidget(self.submit_button) + + self.scroll = QScrollArea() + self.scroll.setWidgetResizable(True) + self.scroll_content = QWidget() + self.grid_layout = QGridLayout(self.scroll_content) + self.scroll.setWidget(self.scroll_content) + self.layout.addWidget(self.scroll) + + self.thumb_size = QSize(280, 180) + self.showMaximized() + + def _create_multiselect_dropdown(self, items): + combo = FullClickComboBox() + combo.setView(QListView()) + model = QStandardItemModel() + combo.setModel(model) + combo.setEditable(True) + combo.lineEdit().setReadOnly(True) + combo.lineEdit().setPlaceholderText("Select...") + + + dummy_item = QStandardItem("") + dummy_item.setFlags(Qt.ItemIsEnabled) + model.appendRow(dummy_item) + + toggle_item = QStandardItem("Toggle Select All") + toggle_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + toggle_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(toggle_item) + + for item in items: + standard_item = QStandardItem(item) + standard_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + standard_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(standard_item) + + combo.setInsertPolicy(QComboBox.NoInsert) + + + def on_view_clicked(index): + item = model.itemFromIndex(index) + if item.isCheckable(): + new_state = Qt.Checked if item.checkState() == Qt.Unchecked else Qt.Unchecked + item.setCheckState(new_state) + + combo.view().pressed.connect(on_view_clicked) + + self._updating_checkstates = False + + def on_item_changed(item): + if self._updating_checkstates: + return + self._updating_checkstates = True + + normal_items = [model.item(i) for i in range(2, model.rowCount())] # skip dummy and toggle + + if item == toggle_item: + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + if all_checked: + for i in normal_items: + i.setCheckState(Qt.Unchecked) + 
toggle_item.setCheckState(Qt.Unchecked) + else: + for i in normal_items: + i.setCheckState(Qt.Checked) + toggle_item.setCheckState(Qt.Checked) + + elif item == dummy_item: + pass + + else: + # When normal items change, update toggle item + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + toggle_item.setCheckState(Qt.Checked if all_checked else Qt.Unchecked) + + # Update label text immediately after change + if combo == self.participant_dropdown: + self.update_participant_dropdown_label() + elif combo == self.image_index_dropdown: + self.update_image_index_dropdown_label() + + self._updating_checkstates = False + + model.itemChanged.connect(on_item_changed) + + combo.setInsertPolicy(QComboBox.NoInsert) + return combo + + def _get_checked_items(self, combo): + checked = [] + model = combo.model() + for i in range(model.rowCount()): + item = model.item(i) + # Skip dummy and toggle items: + if item.text() in ("", "Toggle Select All"): + continue + if item.checkState() == Qt.Checked: + checked.append(item.text()) + return checked + + def update_participant_dropdown_label(self): + selected = self._get_checked_items(self.participant_dropdown) + if not selected: + self.participant_dropdown.lineEdit().setText("") + else: + # Extract just "Participant N" from "Participant N (filename)" + selected_short = [s.split(" ")[0] + " " + s.split(" ")[1] for s in selected] + self.participant_dropdown.lineEdit().setText(", ".join(selected_short)) + self._update_event_dropdown() + + def update_image_index_dropdown_label(self): + selected = self._get_checked_items(self.image_index_dropdown) + if not selected: + self.image_index_dropdown.lineEdit().setText("") + else: + # Only show the index part + index_labels = [s.split(" ")[0] for s in selected] + self.image_index_dropdown.lineEdit().setText(", ".join(index_labels)) + + def _update_event_dropdown(self): + selected_display_names = self._get_checked_items(self.participant_dropdown) + selected_file_paths = [] + for display_name in selected_display_names: + for fp, short_label in self.participant_map.items(): + expected_display = f"{short_label} ({os.path.basename(fp)})" + if display_name == expected_display: + selected_file_paths.append(fp) + break + + if not selected_file_paths: + self.event_dropdown.clear() + self.event_dropdown.addItem("") + return + + annotation_sets = [] + + for file_path in selected_file_paths: + raw = self.haemo_dict.get(file_path) + if raw is None or not hasattr(raw, "annotations"): + continue + annotations = set(raw.annotations.description) + annotation_sets.append(annotations) + + if not annotation_sets: + self.event_dropdown.clear() + self.event_dropdown.addItem("") + return + + shared_annotations = set.intersection(*annotation_sets) + self.event_dropdown.clear() + self.event_dropdown.addItem("") + for ann in sorted(shared_annotations): + self.event_dropdown.addItem(ann) + + def show_brain_images(self): + import flares + + selected_event = self.event_dropdown.currentText() + if selected_event == "": + selected_event = None + + selected_display_names = self._get_checked_items(self.participant_dropdown) + selected_file_paths = [] + for display_name in selected_display_names: + for fp, short_label in self.participant_map.items(): + expected_display = f"{short_label} ({os.path.basename(fp)})" + if display_name == expected_display: + selected_file_paths.append(fp) + break + + selected_indexes = [ + int(s.split(" ")[0]) for s in self._get_checked_items(self.image_index_dropdown) + ] + + + parameterized_indexes = { + 0: 
[ + { + "key": "n_lines", + "label": "", + "default": "20", + "type": int, + }, + { + "key": "vmin", + "label": "", + "default": "0.9", + "type": float, + }, + ], + 1: [ + { + "key": "n_lines", + "label": "", + "default": "20", + "type": int, + }, + { + "key": "vmin", + "label": "", + "default": "0.9", + "type": float, + }, + + ], + 2: [ + { + "key": "n_lines", + "label": "", + "default": "20", + "type": int, + }, + { + "key": "vmin", + "label": "", + "default": "0.9", + "type": float, + }, + + ], + 3: [ + { + "key": "n_lines", + "label": "", + "default": "20", + "type": int, + }, + { + "key": "vmin", + "label": "", + "default": "0.9", + "type": float, + }, + + ], + } + + # Inject full_text from index_texts + for idx, params_list in parameterized_indexes.items(): + full_text = self.index_texts[idx] if idx < len(self.index_texts) else f"{idx} (No label found)" + for param_info in params_list: + param_info["full_text"] = full_text + + indexes_needing_params = {idx: parameterized_indexes[idx] for idx in selected_indexes if idx in parameterized_indexes} + + param_values = {} + if indexes_needing_params: + dialog = ParameterInputDialog(indexes_needing_params, parent=self) + if dialog.exec_() == QDialog.Accepted: + param_values = dialog.get_values() + if param_values is None: + return + else: + return + + # Pass the necessary arguments to each method + for file_path in selected_file_paths: + haemo_obj = self.haemo_dict.get(file_path) + epochs_obj = self.epochs_dict.get(file_path) + + if haemo_obj is None: + raise Exception("How did we get here?") + + + for idx in selected_indexes: + if idx == 0: + + params = param_values.get(idx, {}) + n_lines = params.get("n_lines", None) + vmin = params.get("vmin", None) + + if n_lines is None or vmin is None: + print(f"Missing parameters for index {idx}, skipping.") + continue + flares.functional_connectivity_spectral_epochs(epochs_obj, n_lines, vmin) + + elif idx == 1: + params = param_values.get(idx, {}) + n_lines = params.get("n_lines", None) + vmin = params.get("vmin", None) + + if n_lines is None or vmin is None: + print(f"Missing parameters for index {idx}, skipping.") + continue + flares.functional_connectivity_envelope(epochs_obj, n_lines, vmin) + + elif idx == 2: + params = param_values.get(idx, {}) + n_lines = params.get("n_lines", None) + vmin = params.get("vmin", None) + + if n_lines is None or vmin is None: + print(f"Missing parameters for index {idx}, skipping.") + continue + flares.functional_connectivity_betas(haemo_obj, n_lines, vmin, selected_event) + + elif idx == 3: + params = param_values.get(idx, {}) + n_lines = params.get("n_lines", None) + vmin = params.get("vmin", None) + + if n_lines is None or vmin is None: + print(f"Missing parameters for index {idx}, skipping.") + continue + flares.functional_connectivity_spectral_time(epochs_obj, n_lines, vmin) + + else: + print(f"No method defined for index {idx}") + + + class MultiProgressDialog(QDialog): def __init__(self, parent=None): super().__init__(parent) @@ -3000,7 +3403,7 @@ class ExportDataAsCSVViewerWidget(QWidget): # Open save dialog save_path, _ = QFileDialog.getSaveFileName( self, - "Save SNIRF File As", + "Save CSV File As", suggested_name, "CSV Files (*.csv)" ) @@ -3017,7 +3420,7 @@ class ExportDataAsCSVViewerWidget(QWidget): QMessageBox.information(self, "Success", "CSV file has been saved.") except Exception as e: - QMessageBox.critical(self, "Error", f"Failed to update SNIRF file:\n{e}") + QMessageBox.critical(self, "Error", f"Failed to update CSV file:\n{e}") elif idx == 1: 
@@ -3027,7 +3430,7 @@ class ExportDataAsCSVViewerWidget(QWidget): # Open save dialog save_path, _ = QFileDialog.getSaveFileName( self, - "Save SNIRF File As", + "Save CSV File As", suggested_name, "CSV Files (*.csv)" ) @@ -3071,7 +3474,7 @@ class ExportDataAsCSVViewerWidget(QWidget): win.show() except Exception as e: - QMessageBox.critical(self, "Error", f"Failed to update SNIRF file:\n{e}") + QMessageBox.critical(self, "Error", f"Failed to update CSV file:\n{e}") else: @@ -4263,10 +4666,15 @@ class GroupBrainViewerWidget(QWidget): class ViewerLauncherWidget(QWidget): - def __init__(self, haemo_dict, fig_bytes_dict, cha_dict, contrast_results_dict, df_ind, design_matrix, group): + def __init__(self, haemo_dict, config_dict, fig_bytes_dict, cha_dict, contrast_results_dict, df_ind, design_matrix, epochs_dict): super().__init__() self.setWindowTitle("Viewer Launcher") + group_dict = { + file_path: config.get("GROUP", "Unknown") # default if GROUP missing + for file_path, config in config_dict.items() + } + layout = QVBoxLayout(self) btn1 = QPushButton("Open Participant Viewer") @@ -4278,19 +4686,23 @@ class ViewerLauncherWidget(QWidget): btn3 = QPushButton("Open Participant Fold Channels Viewer") btn3.clicked.connect(lambda: self.open_participant_fold_channels_viewer(haemo_dict, cha_dict)) + btn7 = QPushButton("Open Functional Connectivity Viewer [BETA]") + btn7.clicked.connect(lambda: self.open_participant_functional_connectivity_viewer(haemo_dict, epochs_dict)) + btn4 = QPushButton("Open Inter-Group Viewer") - btn4.clicked.connect(lambda: self.open_group_viewer(haemo_dict, cha_dict, df_ind, design_matrix, contrast_results_dict, group)) + btn4.clicked.connect(lambda: self.open_group_viewer(haemo_dict, cha_dict, df_ind, design_matrix, contrast_results_dict, group_dict)) btn5 = QPushButton("Open Cross Group Brain Viewer") - btn5.clicked.connect(lambda: self.open_group_brain_viewer(haemo_dict, df_ind, design_matrix, group, contrast_results_dict)) + btn5.clicked.connect(lambda: self.open_group_brain_viewer(haemo_dict, df_ind, design_matrix, group_dict, contrast_results_dict)) btn6 = QPushButton("Open Export Data As CSV Viewer") - btn6.clicked.connect(lambda: self.open_export_data_as_csv_viewer(haemo_dict, cha_dict, df_ind, design_matrix, group, contrast_results_dict)) + btn6.clicked.connect(lambda: self.open_export_data_as_csv_viewer(haemo_dict, cha_dict, df_ind, design_matrix, group_dict, contrast_results_dict)) layout.addWidget(btn1) layout.addWidget(btn2) layout.addWidget(btn3) + layout.addWidget(btn7) layout.addWidget(btn4) layout.addWidget(btn5) layout.addWidget(btn6) @@ -4307,6 +4719,10 @@ class ViewerLauncherWidget(QWidget): self.participant_fold_channels_viewer = ParticipantFoldChannelsWidget(haemo_dict, cha_dict) self.participant_fold_channels_viewer.show() + def open_participant_functional_connectivity_viewer(self, haemo_dict, epochs_dict): + self.participant_brain_viewer = ParticipantFunctionalConnectivityWidget(haemo_dict, epochs_dict) + self.participant_brain_viewer.show() + def open_group_viewer(self, haemo_dict, cha_dict, df_ind, design_matrix, contrast_results_dict, group): self.participant_brain_viewer = GroupViewerWidget(haemo_dict, cha_dict, df_ind, design_matrix, contrast_results_dict, group) self.participant_brain_viewer.show() @@ -4343,7 +4759,8 @@ class MainApplication(QMainWindow): self.section_widget = None self.first_run = True self.is_2d_bypass = False - + self.incompatible_save_bypass = False + self.files_total = 0 # total number of files to process self.files_done = 
set() # set of file paths done (success or fail) self.files_failed = set() # set of failed file paths @@ -4593,7 +5010,8 @@ class MainApplication(QMainWindow): preferences_menu = menu_bar.addMenu("Preferences") preferences_actions = [ - ("2D Data Bypass", "Ctrl+B", self.is_2d_bypass_func, resource_path("icons/info_24dp_1F1F1F.svg")) + ("2D Data Bypass", "", self.is_2d_bypass_func, resource_path("icons/info_24dp_1F1F1F.svg")), + ("Incompatible Save Bypass", "", self.incompatable_save_bypass_func, resource_path("icons/info_24dp_1F1F1F.svg")) ] for name, shortcut, slot, icon in preferences_actions: preferences_menu.addAction(make_action(name, shortcut, slot, icon=icon, checkable=True, checked=False)) @@ -4643,15 +5061,13 @@ class MainApplication(QMainWindow): self.statusBar().clearMessage() self.raw_haemo_dict = None + self.config_dict = None self.epochs_dict = None self.fig_bytes_dict = None self.cha_dict = None self.contrast_results_dict = None self.df_ind_dict = None self.design_matrix_dict = None - self.age_dict = None - self.gender_dict = None - self.group_dict = None self.valid_dict = None # Reset any visible UI elements @@ -4662,7 +5078,7 @@ class MainApplication(QMainWindow): def open_launcher_window(self): - self.launcher_window = ViewerLauncherWidget(self.raw_haemo_dict, self.fig_bytes_dict, self.cha_dict, self.contrast_results_dict, self.df_ind_dict, self.design_matrix_dict, self.group_dict) + self.launcher_window = ViewerLauncherWidget(self.raw_haemo_dict, self.config_dict, self.fig_bytes_dict, self.cha_dict, self.contrast_results_dict, self.df_ind_dict, self.design_matrix_dict, self.epochs_dict) self.launcher_window.show() @@ -4681,6 +5097,9 @@ class MainApplication(QMainWindow): def is_2d_bypass_func(self, checked): self.is_2d_bypass = checked + def incompatable_save_bypass_func(self, checked): + self.incompatible_save_bypass = checked + def about_window(self): if self.about is None or not self.about.isVisible(): self.about = AboutWindow(self) @@ -4839,19 +5258,19 @@ class MainApplication(QMainWindow): for bubble in self.bubble_widgets.values() } + version = CURRENT_VERSION project_data = { + "version": version, "file_list": file_list, "progress_states": progress_states, "raw_haemo_dict": self.raw_haemo_dict, + "config_dict": self.config_dict, "epochs_dict": self.epochs_dict, "fig_bytes_dict": self.fig_bytes_dict, "cha_dict": self.cha_dict, "contrast_results_dict": self.contrast_results_dict, "df_ind_dict": self.df_ind_dict, "design_matrix_dict": self.design_matrix_dict, - "age_dict": self.age_dict, - "gender_dict": self.gender_dict, - "group_dict": self.group_dict, "valid_dict": self.valid_dict, } @@ -4866,10 +5285,24 @@ class MainApplication(QMainWindow): project_data = sanitize(project_data) - with open(filename, "wb") as f: - pickle.dump(project_data, f) - - QMessageBox.information(self, "Success", f"Project saved to:\n{filename}") + self.saving_overlay = SavingOverlay(self) + self.saving_overlay.resize(self.size()) # Cover the main window + self.saving_overlay.show() + + # Start the background save thread + self.save_thread = SaveProjectThread(filename, project_data) + + # When finished, close overlay and show success + self.save_thread.finished_signal.connect(lambda f: ( + self.saving_overlay.close(), + QMessageBox.information(self, "Success", f"Project saved to:\n{f}") + )) + self.save_thread.error_signal.connect(lambda e: ( + self.saving_overlay.close(), + QMessageBox.critical(self, "Error", f"Failed to save project:\n{e}") + )) + + self.save_thread.start() except 
Exception as e: if not onCrash: @@ -4888,16 +5321,27 @@ class MainApplication(QMainWindow): with open(filename, "rb") as f: data = pickle.load(f) + # Check for saves prior to 1.2.0 + if "version" not in data: + print(self.incompatible_save_bypass) + if self.incompatible_save_bypass: + QMessageBox.warning(self, "Warning - FLARES", f"This project was saved in an earlier version of FLARES (<=1.1.7) and is potentially not compatible with this version. " + "You are receiving this warning because you have 'Incompatible Save Bypass' turned on. FLARES will now attempt to load the project. It is strongly " + "recommended to recreate the project file.") + else: + QMessageBox.critical(self, "Error - FLARES", f"This project was saved in an earlier version of FLARES (<=1.1.7) and is potentially not compatible with this version. " + "The file can attempt to be loaded if 'Incompatible Save Bypass' is selected in the 'Preferences' menu.") + return + + self.raw_haemo_dict = data.get("raw_haemo_dict", {}) + self.config_dict = data.get("config_dict", {}) self.epochs_dict = data.get("epochs_dict", {}) self.fig_bytes_dict = data.get("fig_bytes_dict", {}) self.cha_dict = data.get("cha_dict", {}) self.contrast_results_dict = data.get("contrast_results_dict", {}) self.df_ind_dict = data.get("df_ind_dict", {}) self.design_matrix_dict = data.get("design_matrix_dict", {}) - self.age_dict = data.get("age_dict", {}) - self.gender_dict = data.get("gender_dict", {}) - self.group_dict = data.get("group_dict", {}) self.valid_dict = data.get("valid_dict", {}) project_dir = Path(filename).parent @@ -4913,7 +5357,27 @@ class MainApplication(QMainWindow): } self.show_files_as_bubbles_from_list(file_list, progress_states, filename) - + + for file_path, config in self.config_dict.items(): + # Only store AGE, GENDER, GROUP + self.file_metadata[file_path] = { + key: str(config.get(key, "")) # convert to str for QLineEdit + for key in ["AGE", "GENDER", "GROUP"] + } + + if self.config_dict: + first_file = next(iter(self.config_dict.keys())) + self.current_file = first_file + + # Update meta fields (AGE/GENDER/GROUP) + for key, field in self.meta_fields.items(): + field.setText(self.file_metadata[first_file][key]) + self.right_column_widget.show() + + # Restore all other constants to the parameter sections + first_config = self.config_dict[first_file] + self.restore_sections_from_config(first_config) + # Re-enable buttons # self.button1.setVisible(True) self.button3.setVisible(True) @@ -4924,6 +5388,54 @@ class MainApplication(QMainWindow): QMessageBox.critical(self, "Error", f"Failed to load project:\n{e}") + def restore_sections_from_config(self, config): + """ + Fill all ParamSection widgets with values from a participant's config. 
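+
+        `config` is the flat parameter dict captured in flares.process_participant
+        (the annotated module-level globals such as AGE, GENDER, GROUP, VERBOSITY);
+        only widgets whose names appear in `config` are updated.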
+ """ + for section_widget in self.param_sections: + widgets_dict = getattr(section_widget, 'widgets', None) + if widgets_dict is None: + continue + + for name, widget_info in widgets_dict.items(): + if name not in config: + continue + + value = config[name] + print(f"Restoring {name} = {value}") + + widget = widget_info["widget"] + w_type = widget_info.get("type") + + # QLineEdit (int, float, str) + if isinstance(widget, QLineEdit): + widget.blockSignals(True) + widget.setText(str(value)) + widget.blockSignals(False) + widget.update() + + # QComboBox (bool, list) + elif isinstance(widget, QComboBox): + widget.blockSignals(True) + widget.setCurrentText(str(value)) + widget.blockSignals(False) + widget.update() + + # QSpinBox (range) + elif isinstance(widget, QSpinBox): + widget.blockSignals(True) + try: + widget.setValue(int(value)) + except Exception: + pass + widget.blockSignals(False) + widget.update() + + # After restoring, make sure dependencies are updated + if hasattr(section_widget, 'update_dependencies'): + section_widget.update_dependencies() + + def show_files_as_bubbles(self, folder_paths): @@ -5366,29 +5878,25 @@ class MainApplication(QMainWindow): # TODO: Is this check needed? Edit: yes very much so if getattr(self, 'raw_haemo_dict', None) is None: self.raw_haemo_dict = {} + self.config_dict = {} self.epochs_dict = {} self.fig_bytes_dict = {} self.cha_dict = {} self.contrast_results_dict = {} self.df_ind_dict = {} self.design_matrix_dict = {} - self.age_dict = {} - self.gender_dict = {} - self.group_dict = {} self.valid_dict = {} # Combine all results into the dicts - for file_path, (raw_haemo, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, age, gender, group, valid) in results.items(): + for file_path, (raw_haemo, config, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, valid) in results.items(): self.raw_haemo_dict[file_path] = raw_haemo + self.config_dict[file_path] = config self.epochs_dict[file_path] = epochs self.fig_bytes_dict[file_path] = fig_bytes self.cha_dict[file_path] = cha self.contrast_results_dict[file_path] = contrast_results self.df_ind_dict[file_path] = df_ind self.design_matrix_dict[file_path] = design_matrix - self.age_dict[file_path] = age - self.gender_dict[file_path] = gender - self.group_dict[file_path] = group self.valid_dict[file_path] = valid # self.statusbar.showMessage(f"Processing complete! Time elapsed: {elapsed_time:.2f} seconds")
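
For reference, a minimal sketch of how the new functional connectivity entry points in flares.py can be driven outside the GUI. The SNIRF path below is hypothetical, the n_lines/vmin values simply mirror the GUI defaults, and process_participant still expects the module-level parameters to be configured first (normally done by main.py before processing):

    import flares

    # As of 1.2.0, process_participant returns (raw_haemo, config, epochs, fig_bytes,
    # cha, contrast_results, df_ind, design_matrix, valid).
    raw_haemo, config, epochs, *_ = flares.process_participant("P01_example.snirf")

    # Coherence/PLV over the HbO epochs, drawn as a connectivity circle.
    flares.functional_connectivity_spectral_epochs(epochs, n_lines=20, vmin=0.9)

    # Envelope correlation over the same epochs.
    flares.functional_connectivity_envelope(epochs, n_lines=20, vmin=0.9)

    # FIR beta-series connectivity on the haemoglobin Raw object, optionally
    # restricted to a single event name.
    flares.functional_connectivity_betas(raw_haemo, n_lines=20, vmin=0.9, event_name=None)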