diff --git a/flares.py b/flares.py index 0d8cd04..a8d4599 100644 --- a/flares.py +++ b/flares.py @@ -100,14 +100,16 @@ from mne_connectivity import envelope_correlation, spectral_connectivity_epochs, os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore PRIMARY_COLORS = { - "SCI only": "skyblue", - "PSP only": "salmon", - "SNR only": "lightgreen", - "Master only": "plum", - "Swing only": "coral", + "SCI only": "skyblue", # Scalp Coupling Index (Standard MNE) + "SNR only": "lightgreen", # Signal-to-Noise Ratio (Your original) + "PSP only": "salmon", # Power Spectral Peak (Original Noise check) + "CV only": "yellow", # Relative Noise (The CV-only check) + "Range only": "coral", # Z-Swing (The Range Outlier check) + "Noise only": "plum", # High-Freq PSD (The Noise check) + "Disp. only": "palegreen", # Sensor Displacement (Variance Drop) + "Multiple": "orangered" # Failed 2+ categories } -# The Fallback for combinations (the "Multifail" Gray) COMBINATION_COLOR = "gray" def get_category_color(label): @@ -342,27 +344,7 @@ def set_metadata(file_path, metadata: dict[str, Any]) -> None: if val not in (None, '', [], {}, ()): # check for "empty" values globals()[key] = val -def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -> None: - def forward_progress(): - while True: - try: - msg = progress_queue.get(timeout=1) - if msg == "__done__": - break - gui_queue.put(msg) - except Empty: - continue - except Exception as e: - gui_queue.put({ - "type": "error", - "error": f"Forwarding thread crashed: {e}", - "traceback": traceback.format_exc() - }) - break - - t = threading.Thread(target=forward_progress, daemon=True) - t.start() - +def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue, ack_queue: Queue) -> None: try: file_paths = config['SNIRF_FILES'] file_params = config['PARAMS'] @@ -373,22 +355,24 @@ def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) - file_paths, file_params, file_metadata, progress_queue, gui_queue, max_workers ) - gui_queue.put({"success": True}) + gui_queue.put({"type": "FINISHED_SUCCESSFULLY", "success": True}) + + try: + print("CHILD: Waiting for GUI acknowledgment...") + ack_queue.get(timeout=10) + except: + print("CHILD: Ack timeout, exiting anyway.") except Exception as e: gui_queue.put({ + "type": "FINISHED_SUCCESSFULLY", "success": False, "error": str(e), "traceback": traceback.format_exc() }) finally: - # Always send done to the thread and avoid hanging - try: - progress_queue.put("__done__") - except: - pass - t.join(timeout=5) # prevent permanent hang + pass def process_participant_worker(file_path, file_params, file_metadata, result_queue, progress_queue): @@ -413,16 +397,18 @@ def process_participant_worker(file_path, file_params, file_metadata, result_que result_queue.put((file_path, result, None)) except Exception as e: - result_queue.put((file_path, None, str(e))) + try: + result_queue.put((file_path, None, str(e))) + except: pass finally: - # --- THE FIX: MANDATORY EXIT --- - # Explicitly flush the logs and force the process to terminate - #audit_log.info(f"Worker for {file_name} calling hard exit.") - sys.stdout.flush() - sys.stderr.flush() - # We use os._exit(0) as a nuclear option if sys.exit() is being caught by a try/except - os._exit(0) + try: + sys.stdout.flush() + sys.stderr.flush() + # Give the queue thread a moment to send the data before we vanish + time.sleep(0.2) + except: pass + sys.exit(0) def process_multiple_participants(file_paths, file_params, file_metadata, progress_queue=None, gui_queue=None, max_workers=6): @@ -432,16 +418,18 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre active_processes = [] results_by_file = {} - # We use a manager queue so it handles IPC serialization cleanly manager = mp.Manager() result_queue = manager.Queue() try: while pending_files or active_processes: - # 1. SPAWN + for p, f_path in active_processes[:]: + if not p.is_alive(): + p.join(timeout=0.1) + active_processes.remove((p, f_path)) + while len(active_processes) < max_workers and pending_files: file_path = pending_files.pop(0) - p = mp.Process( target=process_participant_worker, args=(file_path, file_params, file_metadata, result_queue, progress_queue) @@ -449,61 +437,42 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre p.start() active_processes.append((p, file_path)) - # 2. COLLECT RESULTS (Inner Loop) - while True: + if progress_queue: + while not progress_queue.empty(): + try: + prog_msg = progress_queue.get_nowait() + if gui_queue: + # Forward straight to GUI + gui_queue.put(prog_msg, timeout=0.1) + except: break + + while not result_queue.empty(): try: - res_path, result, error = result_queue.get(timeout=0.01) - - for i, (p, f_path) in enumerate(active_processes): - if f_path == res_path: - p.join(timeout=1) # Clean it up - active_processes.pop(i) - break - - if not pending_files and not active_processes: - break - if gui_queue: - # FIX: Added a timeout. If the GUI stops reading, we don't freeze forever. - try: - gui_queue.put({ - "type": "file_done", - "file": res_path, - "success": error is None, - "result": result if error is None else None, - "error": error if error else None - }, timeout=2) # 2-second safety valve - except Exception as e: - print(f"DEBUG: gui_queue full or closed! {e}") + gui_queue.put({ + "type": "file_done", + "file": res_path, + "success": error is None, + "result": result if error is None else None, + "error": error if error else None + }, timeout=2) else: results_by_file[res_path] = result except: - break # Break OUT of the inner loop if queue is empty - - # --- EVERYTHING BELOW IS NOW OUTSIDE THE INNER LOOP --- - - # 3. THE CIRCUIT BREAKER - if not pending_files: - alive_count = sum(1 for p, _ in active_processes if p.is_alive()) - if alive_count == 0 and len(active_processes) > 0: - print("DEBUG: All workers dead but loop persisting. Breaking.") break - - # 4. CLEANUP DEAD WORKERS - for p, f_path in active_processes[:]: - if not p.is_alive(): - p.join(timeout=0.1) - active_processes.remove((p, f_path)) - + + if not pending_files and not active_processes: + print("DEBUG: Loop finished naturally.") + break + + # should no longer hang because pipes are being drained time.sleep(0.1) - print("DEBUG: Loop finished naturally.") - except Exception as e: - print(e) + print(f"MAIN LOOP ERROR: {e}") finally: - # 1. Kill any workers + # Cleanup for p, _ in active_processes: try: if p.is_alive(): @@ -511,12 +480,6 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre p.join(timeout=0.1) except: pass - # 2. Clear Proxy - try: - del result_queue - except: pass - - # 3. Force Kill Manager if manager is not None: try: manager.shutdown() @@ -1037,9 +1000,10 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2) constrained_layout=True) if n_bad == 1: axes = [axes] # Handle single subplot case - axes_flat = axes.flatten() + axes_flat = np.asarray(axes).get_data().flatten() if hasattr(axes, 'get_data') else np.asarray(axes).ravel() for j in range(n_bad, len(axes_flat)): - axes_flat[j].axis('off') + if j >= n_bad: + axes_flat[j].axis('off') times = raw.times for i, bad_idx in enumerate(bad_pairs): @@ -1188,9 +1152,9 @@ def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5 return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2 -def mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing): - print(bad_sci, bad_master, bad_psp, bad_snr, bad_swing) - bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp) | set(bad_master) | set(bad_swing)) +def mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_cv, bad_range, bad_noise, bad_disp): + print(bad_sci, bad_snr, bad_psp, bad_cv, bad_range, bad_noise, bad_disp) + bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp) | set(bad_cv) | set(bad_range) | set(bad_noise) | set(bad_disp)) print(f"Automatically marked bad channels based on SNR and SCI: {bads_combined}") raw.info['bads'].extend(bads_combined) @@ -1200,8 +1164,10 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing): (bad_sci, "SCI"), (bad_psp, "PSP"), (bad_snr, "SNR"), - (bad_master, "Master"), - (bad_swing, "Swing"), + (bad_cv, "CV"), + (bad_range, "Range"), + (bad_noise, "Noise"), + (bad_disp, "Disp.") ] # Graph what channels were dropped and why they were dropped @@ -1294,12 +1260,14 @@ def safe_create_epochs(raw, events, event_dict, tmin, tmax, baseline): sample collisions are detected. """ shift_increment = 1.0 / raw.info['sfreq'] # The duration of exactly one sample - + #TODO: User expose this + reject_criteria = dict(hbo=80e-7) for attempt in range(MAX_SHIFT): # Limit attempts to avoid infinite loops try: epochs = Epochs( raw, events, event_id=event_dict, tmin=tmin, tmax=tmax, baseline=baseline, + reject=reject_criteria, preload=True, verbose=False ) return epochs @@ -3627,74 +3595,106 @@ def find_flatline_at_end(raw, threshold_ratio=0.05): # import numpy as np -def master_clean_fnirs(raw, db_limit=-60, threshold_ratio=0.05): + + +def detect_sensor_displacement(raw, threshold_ratio=0.05): + """ + Identifies channels where signal variance drops significantly. + Returns flagged channel names and a summary figure. + """ ch_names = raw.ch_names data = raw.get_data() - sfreq = raw.info['sfreq'] - n_samples = data.shape[1] quarter = n_samples // 4 + ratios = [] dead_idx = [] + for i in range(len(ch_names)): start_var = np.var(data[i, :quarter]) end_var = np.var(data[i, -quarter:]) + # Handle zero variance (dead from the start) if start_var == 0: - dead_idx.append(i) - continue + ratio = 0.0 + else: + ratio = end_var / start_var - ratio = end_var / start_var + ratios.append(ratio) + if ratio < threshold_ratio: dead_idx.append(i) print(f"Flagged {ch_names[i]}: Variance dropped to {ratio:.2%} of original.") - - dead_channels = [ch_names[i] for i in dead_idx] + # Pair-kill logic + failed_bases = {ch_names[i].split(' ')[0] for i in dead_idx} + bad_names = [ch for ch in ch_names if ch.split(' ')[0] in failed_bases] + + # --- Visualization --- + fig_disp, ax = plt.subplots(figsize=(10, 5), constrained_layout=True) + + # Color logic: Coral for channels below the threshold + colors = ['coral' if r < threshold_ratio else 'skyblue' for r in ratios] + + ax.bar(range(len(ratios)), ratios, color=colors) + ax.axhline(threshold_ratio, color='red', linestyle='--', label=f'Threshold ({threshold_ratio:.0%})') + + ax.set_title("Sensor Displacement Check (Variance Stability)") + ax.set_ylabel("Variance Ratio (End / Start)") + ax.set_xlabel("Channel Index") + ax.set_ylim(0, max(ratios + [threshold_ratio * 2])) # Scale to see the threshold clearly + ax.legend() + + plt.close(fig_disp) + + print(f"Displacement Check: Flagged {len(failed_bases)} optode pairs.") + return bad_names, fig_disp + + + +def detect_high_freq_noise(raw, db_limit=-60): + """ + Identifies channels with excessive power at high frequencies + (sfreq/4), usually indicating electronic interference. + """ + ch_names = raw.ch_names + sfreq = raw.info['sfreq'] target_freq = sfreq / 4 + + # Compute PSD spectrum = raw.compute_psd(fmin=0.1, fmax=sfreq/2) psd_data, freqs = spectrum.get_data(return_freqs=True) + + # Find power near the target frequency f_idx = np.where((freqs >= target_freq - 0.2) & (freqs <= target_freq + 0.2))[0] power_at_target = np.mean(psd_data[:, f_idx], axis=1) abs_threshold = 10 ** (db_limit / 10) noisy_idx = np.where(power_at_target > abs_threshold)[0] - noisy_channels = [ch_names[i] for i in noisy_idx] - # --- The Pairing Logic --- - # Combine indices of specifically failed channels - failed_idx_list = np.unique(np.concatenate([dead_idx, noisy_idx])).astype(int) - - # Extract base names (e.g., "S1_D2") of failed channels - # Assuming "S1_D2 762" -> split by space -> "S1_D2" - failed_bases = {str(ch_names[i]).split(' ')[0] for i in failed_idx_list} - - # Flag EVERY channel that matches these bases - all_bads = [ch for ch in ch_names if ch.split(' ')[0] in failed_bases] - - print(f"Analysis Complete:") - print(f" - Found {len(dead_channels)} dead/low-var channels: {dead_channels}") - print(f" - Found {len(noisy_channels)} high-noise (> {db_limit}dB) channels: {noisy_channels}") + # Pair-kill logic + failed_bases = {ch_names[i].split(' ')[0] for i in noisy_idx} + bad_names = [ch for ch in ch_names if ch.split(' ')[0] in failed_bases] # --- Visualization --- - fig_master, ax = plt.subplots(figsize=(8, 4)) + fig, ax = plt.subplots(figsize=(8, 4)) ax.plot(freqs, 10 * np.log10(psd_data.T), color='gray', alpha=0.2) - - # Highlight the specifically noisy ones if len(noisy_idx) > 0: ax.plot(freqs, 10 * np.log10(psd_data[noisy_idx].T), color='plum', label='Noisy Pairs') ax.axhline(db_limit, color='red', linestyle='--', label='Threshold') - ax.set_title("Master Clean: PSD Noise & Dead Channel Detection") + ax.set_title(f"PSD Noise Analysis (Target: {target_freq}Hz)") ax.set_ylabel("Power (dB)") - ax.set_xlabel("Frequency (Hz)") ax.legend() - plt.close(fig_master) + plt.close(fig) - return list(set(all_bads)), fig_master + print(f"Noise Check: Flagged {len(failed_bases)} optode pairs.") + return bad_names, fig -def find_physiologically_impossible_channels(raw, picks, threshold=3.0): + +def find_bad_channels_range(raw, threshold=4.0): + picks = [ch for ch in raw.ch_names] data = raw.get_data(picks=picks) ranges = np.max(data, axis=1) - np.min(data, axis=1) @@ -3704,7 +3704,7 @@ def find_physiologically_impossible_channels(raw, picks, threshold=3.0): z_scores = 0.6745 * (ranges - median_range) / (mad if mad > 0 else 1e-15) # Identify failed bases - failed_indices = np.where(z_scores > threshold)[0] + failed_indices = np.where(np.abs(z_scores) > threshold)[0] failed_bases = {picks[i].split(' ')[0] for i in failed_indices} # Flag entire pairs @@ -3713,10 +3713,11 @@ def find_physiologically_impossible_channels(raw, picks, threshold=3.0): # --- Visualization --- fig_swing, ax = plt.subplots(figsize=(8, 4)) # We color bars by the specific Z-score of that individual channel - colors = ['coral' if z > threshold else 'skyblue' for z in z_scores] + colors = ['coral' if np.abs(z) > threshold else 'skyblue' for z in z_scores] ax.bar(range(len(z_scores)), z_scores, color=colors) ax.axhline(threshold, color='red', linestyle='--', label='Outlier Threshold') + ax.axhline(-threshold, color='red', linestyle='--') ax.set_title("Physiological Swing Analysis (Z-Scores)") ax.set_ylabel("Standardized Deviation") ax.set_xlabel("Channel Index") @@ -3726,6 +3727,62 @@ def find_physiologically_impossible_channels(raw, picks, threshold=3.0): return bad_names, fig_swing + +def find_bad_channels_cv(raw, cv_threshold=25.0): + """ + Identifies bad fNIRS channels using only the Coefficient of Variation (CV). + """ + print(f"\n--- Starting CV-Only Quality Check on the channels ---") + + picks = [ch for ch in raw.ch_names] + data = raw.get_data(picks=picks) + + # Calculate CV (Coefficient of Variation) + stds = np.std(data, axis=1) + means = np.mean(data, axis=1) + # Using a small epsilon (1e-15) to prevent division by zero + cv_scores = (stds / (means + 1e-15)) * 100 + + # Find indices that exceed the threshold + bad_cv_indices = np.where(cv_scores > cv_threshold)[0] + + # Pair-kill logic: If one wavelength (HbO or HbR) fails, flag the pair + failed_bases = set() + for idx in bad_cv_indices: + base = picks[idx].split(' ')[0] + failed_bases.add(base) + + bad_names = [ch for ch in picks if ch.split(' ')[0] in failed_bases] + + # Summary Prints + print(f"CV Check: Found {len(bad_cv_indices)} channels exceeding {cv_threshold}% noise threshold.") + if failed_bases: + print(f"Flagged {len(failed_bases)} optode pairs for removal:") + for base in sorted(failed_bases): + # Find the specific CV for this base (using the first channel found for it) + ch_idx = picks.index(next(p for p in picks if p.startswith(base))) + print(f" - {base}: CV = {cv_scores[ch_idx]:.2f}%") + else: + print("All channels passed the CV check.") + + # --- Visualization --- + fig_qc, ax = plt.subplots(figsize=(10, 5), constrained_layout=True) + + colors = ['coral' if c > cv_threshold else 'skyblue' for c in cv_scores] + ax.bar(range(len(cv_scores)), cv_scores, color=colors) + ax.axhline(cv_threshold, color='red', linestyle='--', label=f'Threshold ({cv_threshold}%)') + + ax.set_title("Coefficient of Variation (Relative Noise)") + ax.set_ylabel("CV %") + ax.set_xlabel("Channel Index") + ax.legend() + + plt.close(fig_qc) + + return bad_names, fig_qc + + + def hr_calc(raw): if SHORT_CHANNEL: short_chans = get_short_channels(raw, max_dist=SHORT_CHANNEL_THRESH) @@ -3862,11 +3919,6 @@ def process_participant(file_path, progress_callback=None): # late_stage_bads = find_mid_run_flatlines(raw) # _ = find_flatline_at_end(raw) # print(f"Channels that died before the end: {late_stage_bads}") - # picks = [ch for ch in raw.ch_names if 'S1_D2' in ch] - bad_master, fig_master = master_clean_fnirs(raw, db_limit=-60) - bad_swing, fig_swing = find_physiologically_impossible_channels(raw, [ch for ch in raw.ch_names], threshold=4.0) - fig_individual["fig_master"] = fig_master - fig_individual["fig_swing"] = fig_swing if progress_callback: progress_callback(5) logger.info("Step 5 Completed.") @@ -3900,9 +3952,22 @@ def process_participant(file_path, progress_callback=None): if progress_callback: progress_callback(8) logger.info("Step 8 Completed.") + # TODO: Add callbacks and user defined parameters + bad_cv, fig_cv = find_bad_channels_cv(raw, cv_threshold=20.0) + fig_individual['cv'] = fig_cv + + bad_range, fig_range = find_bad_channels_range(raw, threshold=3.0) + fig_individual['range'] = fig_range + + bad_noise, fig_noise = detect_high_freq_noise(raw, db_limit=-60) + fig_individual['psd_noise'] = fig_noise + + bad_disp, fig_disp = detect_sensor_displacement(raw, threshold_ratio=0.05) + fig_individual['displacement'] = fig_disp + # Step 9: Bad Channels Handling if BAD_CHANNELS_HANDLING != "None": - raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing) + raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_cv, bad_range, bad_noise, bad_disp) if fig_dropped and fig_raw_before is not None: fig_individual["fig2"] = fig_dropped fig_individual["fig3"] = fig_raw_before diff --git a/main.py b/main.py index 50bd689..2fe89e1 100644 --- a/main.py +++ b/main.py @@ -5920,11 +5920,12 @@ class MainApplication(QMainWindow): # Start processing if current_process().name == 'MainProcess': self.result_queue = Queue() + self.ack_queue = Queue() self.progress_queue = Queue() self.result_process = Process( target=run_gui_entry_wrapper, - args=(collected_data, self.result_queue, self.progress_queue) + args=(collected_data, self.result_queue, self.progress_queue, self.ack_queue) ) self.result_process.daemon = False self.result_process.start() @@ -5986,6 +5987,29 @@ class MainApplication(QMainWindow): self.show_error_popup(f"Error: {file_path}", error_msg, msg.get("traceback", "")) self.statusbar.showMessage(f"Failed: {os.path.basename(file_path)}") + elif isinstance(msg, dict) and msg.get("type") == "FINISHED_SUCCESSFULLY": + # The child has finished its work AND its own cleanup. + # It is now safe for the GUI to stop the timer and clean up. + try: + self.ack_queue.put("ACK") + except: pass + self.result_timer.stop() + self.cleanup_after_process() + + success_count = len(self.files_results) + fail_count = self.files_total - success_count + self.statusbar.showMessage(f"Complete: {success_count} succeeded, {fail_count} failed.") + + if success_count > 0: + self.button3.setVisible(True) + + # Reset the button + try: self.button1.clicked.disconnect() + except: pass + self.button1.setText("Process") + self.button1.clicked.connect(self.on_run_task) + return # Exit the method + elif isinstance(msg, dict) and msg.get("success") is True: self.statusbar.showMessage("All files processed successfully!") @@ -6003,22 +6027,6 @@ class MainApplication(QMainWindow): _, file_path, step_index = msg self.progress_update_signal.emit(file_path, step_index) - if len(self.files_done) >= self.files_total: - self.result_timer.stop() - self.cleanup_after_process() - - success_count = len(self.files_results) - fail_count = self.files_total - success_count - self.statusbar.showMessage(f"Complete: {success_count} succeeded, {fail_count} failed.") - - if success_count > 0: - self.button3.setVisible(True) - - try: - self.button1.clicked.disconnect() - except: pass - self.button1.setText("Process") - self.button1.clicked.connect(self.on_run_task) except Exception as e: print(f"Error in timer loop: {e}") @@ -6049,7 +6057,7 @@ class MainApplication(QMainWindow): msgbox.setDetailedText(traceback_str) msgbox.setStandardButtons(QMessageBox.Ok) - msgbox.exec_() + msgbox.show() def cleanup_after_process(self): @@ -6074,10 +6082,6 @@ class MainApplication(QMainWindow): self.progress_queue.close() self.progress_queue.join_thread() - # Shutdown manager to kill its server process and clean up - if hasattr(self, 'manager'): - self.manager.shutdown() - def update_file_progress(self, file_path, step_index): key = os.path.normpath(file_path) @@ -6204,14 +6208,14 @@ def _extract_metadata_worker(file_name): print(f"Worker failed on {file_name}: {e}") return None -def run_gui_entry_wrapper(config, gui_queue, progress_queue): +def run_gui_entry_wrapper(config, gui_queue, progress_queue, ack_queue): """ Where the processing happens """ try: import flares - flares.gui_entry(config, gui_queue, progress_queue) + flares.gui_entry(config, gui_queue, progress_queue, ack_queue) gui_queue.close() # gui_queue.join_thread() progress_queue.close()