Fix zombie processes and add more quality metrics

This commit is contained in:
2026-03-27 08:10:34 -07:00
parent 94be2fb921
commit 70c4c1e779
2 changed files with 235 additions and 166 deletions
+204 -139
View File
@@ -100,14 +100,16 @@ from mne_connectivity import envelope_correlation, spectral_connectivity_epochs,
os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore
PRIMARY_COLORS = {
"SCI only": "skyblue",
"PSP only": "salmon",
"SNR only": "lightgreen",
"Master only": "plum",
"Swing only": "coral",
"SCI only": "skyblue", # Scalp Coupling Index (Standard MNE)
"SNR only": "lightgreen", # Signal-to-Noise Ratio (Your original)
"PSP only": "salmon", # Power Spectral Peak (Original Noise check)
"CV only": "yellow", # Relative Noise (The CV-only check)
"Range only": "coral", # Z-Swing (The Range Outlier check)
"Noise only": "plum", # High-Freq PSD (The Noise check)
"Disp. only": "palegreen", # Sensor Displacement (Variance Drop)
"Multiple": "orangered" # Failed 2+ categories
}
# The Fallback for combinations (the "Multifail" Gray)
COMBINATION_COLOR = "gray"
def get_category_color(label):
@@ -342,27 +344,7 @@ def set_metadata(file_path, metadata: dict[str, Any]) -> None:
if val not in (None, '', [], {}, ()): # check for "empty" values
globals()[key] = val
def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -> None:
def forward_progress():
while True:
try:
msg = progress_queue.get(timeout=1)
if msg == "__done__":
break
gui_queue.put(msg)
except Empty:
continue
except Exception as e:
gui_queue.put({
"type": "error",
"error": f"Forwarding thread crashed: {e}",
"traceback": traceback.format_exc()
})
break
t = threading.Thread(target=forward_progress, daemon=True)
t.start()
def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue, ack_queue: Queue) -> None:
try:
file_paths = config['SNIRF_FILES']
file_params = config['PARAMS']
@@ -373,22 +355,24 @@ def gui_entry(config: dict[str, Any], gui_queue: Queue, progress_queue: Queue) -
file_paths, file_params, file_metadata, progress_queue, gui_queue, max_workers
)
gui_queue.put({"success": True})
gui_queue.put({"type": "FINISHED_SUCCESSFULLY", "success": True})
try:
print("CHILD: Waiting for GUI acknowledgment...")
ack_queue.get(timeout=10)
except:
print("CHILD: Ack timeout, exiting anyway.")
except Exception as e:
gui_queue.put({
"type": "FINISHED_SUCCESSFULLY",
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
})
finally:
# Always send done to the thread and avoid hanging
try:
progress_queue.put("__done__")
except:
pass
t.join(timeout=5) # prevent permanent hang
pass
def process_participant_worker(file_path, file_params, file_metadata, result_queue, progress_queue):
@@ -413,16 +397,18 @@ def process_participant_worker(file_path, file_params, file_metadata, result_que
result_queue.put((file_path, result, None))
except Exception as e:
result_queue.put((file_path, None, str(e)))
try:
result_queue.put((file_path, None, str(e)))
except: pass
finally:
# --- THE FIX: MANDATORY EXIT ---
# Explicitly flush the logs and force the process to terminate
#audit_log.info(f"Worker for {file_name} calling hard exit.")
sys.stdout.flush()
sys.stderr.flush()
# We use os._exit(0) as a nuclear option if sys.exit() is being caught by a try/except
os._exit(0)
try:
sys.stdout.flush()
sys.stderr.flush()
# Give the queue thread a moment to send the data before we vanish
time.sleep(0.2)
except: pass
sys.exit(0)
def process_multiple_participants(file_paths, file_params, file_metadata, progress_queue=None, gui_queue=None, max_workers=6):
@@ -432,16 +418,18 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre
active_processes = []
results_by_file = {}
# We use a manager queue so it handles IPC serialization cleanly
manager = mp.Manager()
result_queue = manager.Queue()
try:
while pending_files or active_processes:
# 1. SPAWN
for p, f_path in active_processes[:]:
if not p.is_alive():
p.join(timeout=0.1)
active_processes.remove((p, f_path))
while len(active_processes) < max_workers and pending_files:
file_path = pending_files.pop(0)
p = mp.Process(
target=process_participant_worker,
args=(file_path, file_params, file_metadata, result_queue, progress_queue)
@@ -449,61 +437,42 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre
p.start()
active_processes.append((p, file_path))
# 2. COLLECT RESULTS (Inner Loop)
while True:
if progress_queue:
while not progress_queue.empty():
try:
prog_msg = progress_queue.get_nowait()
if gui_queue:
# Forward straight to GUI
gui_queue.put(prog_msg, timeout=0.1)
except: break
while not result_queue.empty():
try:
res_path, result, error = result_queue.get(timeout=0.01)
for i, (p, f_path) in enumerate(active_processes):
if f_path == res_path:
p.join(timeout=1) # Clean it up
active_processes.pop(i)
break
if not pending_files and not active_processes:
break
if gui_queue:
# FIX: Added a timeout. If the GUI stops reading, we don't freeze forever.
try:
gui_queue.put({
"type": "file_done",
"file": res_path,
"success": error is None,
"result": result if error is None else None,
"error": error if error else None
}, timeout=2) # 2-second safety valve
except Exception as e:
print(f"DEBUG: gui_queue full or closed! {e}")
gui_queue.put({
"type": "file_done",
"file": res_path,
"success": error is None,
"result": result if error is None else None,
"error": error if error else None
}, timeout=2)
else:
results_by_file[res_path] = result
except:
break # Break OUT of the inner loop if queue is empty
# --- EVERYTHING BELOW IS NOW OUTSIDE THE INNER LOOP ---
# 3. THE CIRCUIT BREAKER
if not pending_files:
alive_count = sum(1 for p, _ in active_processes if p.is_alive())
if alive_count == 0 and len(active_processes) > 0:
print("DEBUG: All workers dead but loop persisting. Breaking.")
break
# 4. CLEANUP DEAD WORKERS
for p, f_path in active_processes[:]:
if not p.is_alive():
p.join(timeout=0.1)
active_processes.remove((p, f_path))
if not pending_files and not active_processes:
print("DEBUG: Loop finished naturally.")
break
# should no longer hang because pipes are being drained
time.sleep(0.1)
print("DEBUG: Loop finished naturally.")
except Exception as e:
print(e)
print(f"MAIN LOOP ERROR: {e}")
finally:
# 1. Kill any workers
# Cleanup
for p, _ in active_processes:
try:
if p.is_alive():
@@ -511,12 +480,6 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre
p.join(timeout=0.1)
except: pass
# 2. Clear Proxy
try:
del result_queue
except: pass
# 3. Force Kill Manager
if manager is not None:
try:
manager.shutdown()
@@ -1037,9 +1000,10 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
constrained_layout=True)
if n_bad == 1: axes = [axes] # Handle single subplot case
axes_flat = axes.flatten()
axes_flat = np.asarray(axes).get_data().flatten() if hasattr(axes, 'get_data') else np.asarray(axes).ravel()
for j in range(n_bad, len(axes_flat)):
axes_flat[j].axis('off')
if j >= n_bad:
axes_flat[j].axis('off')
times = raw.times
for i, bad_idx in enumerate(bad_pairs):
@@ -1188,9 +1152,9 @@ def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5
return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2
def mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing):
print(bad_sci, bad_master, bad_psp, bad_snr, bad_swing)
bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp) | set(bad_master) | set(bad_swing))
def mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_cv, bad_range, bad_noise, bad_disp):
print(bad_sci, bad_snr, bad_psp, bad_cv, bad_range, bad_noise, bad_disp)
bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp) | set(bad_cv) | set(bad_range) | set(bad_noise) | set(bad_disp))
print(f"Automatically marked bad channels based on SNR and SCI: {bads_combined}")
raw.info['bads'].extend(bads_combined)
@@ -1200,8 +1164,10 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing):
(bad_sci, "SCI"),
(bad_psp, "PSP"),
(bad_snr, "SNR"),
(bad_master, "Master"),
(bad_swing, "Swing"),
(bad_cv, "CV"),
(bad_range, "Range"),
(bad_noise, "Noise"),
(bad_disp, "Disp.")
]
# Graph what channels were dropped and why they were dropped
@@ -1294,12 +1260,14 @@ def safe_create_epochs(raw, events, event_dict, tmin, tmax, baseline):
sample collisions are detected.
"""
shift_increment = 1.0 / raw.info['sfreq'] # The duration of exactly one sample
#TODO: User expose this
reject_criteria = dict(hbo=80e-7)
for attempt in range(MAX_SHIFT): # Limit attempts to avoid infinite loops
try:
epochs = Epochs(
raw, events, event_id=event_dict,
tmin=tmin, tmax=tmax, baseline=baseline,
reject=reject_criteria,
preload=True, verbose=False
)
return epochs
@@ -3627,74 +3595,106 @@ def find_flatline_at_end(raw, threshold_ratio=0.05):
# import numpy as np
def master_clean_fnirs(raw, db_limit=-60, threshold_ratio=0.05):
def detect_sensor_displacement(raw, threshold_ratio=0.05):
"""
Identifies channels where signal variance drops significantly.
Returns flagged channel names and a summary figure.
"""
ch_names = raw.ch_names
data = raw.get_data()
sfreq = raw.info['sfreq']
n_samples = data.shape[1]
quarter = n_samples // 4
ratios = []
dead_idx = []
for i in range(len(ch_names)):
start_var = np.var(data[i, :quarter])
end_var = np.var(data[i, -quarter:])
# Handle zero variance (dead from the start)
if start_var == 0:
dead_idx.append(i)
continue
ratio = 0.0
else:
ratio = end_var / start_var
ratios.append(ratio)
ratio = end_var / start_var
if ratio < threshold_ratio:
dead_idx.append(i)
print(f"Flagged {ch_names[i]}: Variance dropped to {ratio:.2%} of original.")
dead_channels = [ch_names[i] for i in dead_idx]
# Pair-kill logic
failed_bases = {ch_names[i].split(' ')[0] for i in dead_idx}
bad_names = [ch for ch in ch_names if ch.split(' ')[0] in failed_bases]
# --- Visualization ---
fig_disp, ax = plt.subplots(figsize=(10, 5), constrained_layout=True)
# Color logic: Coral for channels below the threshold
colors = ['coral' if r < threshold_ratio else 'skyblue' for r in ratios]
ax.bar(range(len(ratios)), ratios, color=colors)
ax.axhline(threshold_ratio, color='red', linestyle='--', label=f'Threshold ({threshold_ratio:.0%})')
ax.set_title("Sensor Displacement Check (Variance Stability)")
ax.set_ylabel("Variance Ratio (End / Start)")
ax.set_xlabel("Channel Index")
ax.set_ylim(0, max(ratios + [threshold_ratio * 2])) # Scale to see the threshold clearly
ax.legend()
plt.close(fig_disp)
print(f"Displacement Check: Flagged {len(failed_bases)} optode pairs.")
return bad_names, fig_disp
def detect_high_freq_noise(raw, db_limit=-60):
"""
Identifies channels with excessive power at high frequencies
(sfreq/4), usually indicating electronic interference.
"""
ch_names = raw.ch_names
sfreq = raw.info['sfreq']
target_freq = sfreq / 4
# Compute PSD
spectrum = raw.compute_psd(fmin=0.1, fmax=sfreq/2)
psd_data, freqs = spectrum.get_data(return_freqs=True)
# Find power near the target frequency
f_idx = np.where((freqs >= target_freq - 0.2) & (freqs <= target_freq + 0.2))[0]
power_at_target = np.mean(psd_data[:, f_idx], axis=1)
abs_threshold = 10 ** (db_limit / 10)
noisy_idx = np.where(power_at_target > abs_threshold)[0]
noisy_channels = [ch_names[i] for i in noisy_idx]
# --- The Pairing Logic ---
# Combine indices of specifically failed channels
failed_idx_list = np.unique(np.concatenate([dead_idx, noisy_idx])).astype(int)
# Extract base names (e.g., "S1_D2") of failed channels
# Assuming "S1_D2 762" -> split by space -> "S1_D2"
failed_bases = {str(ch_names[i]).split(' ')[0] for i in failed_idx_list}
# Flag EVERY channel that matches these bases
all_bads = [ch for ch in ch_names if ch.split(' ')[0] in failed_bases]
print(f"Analysis Complete:")
print(f" - Found {len(dead_channels)} dead/low-var channels: {dead_channels}")
print(f" - Found {len(noisy_channels)} high-noise (> {db_limit}dB) channels: {noisy_channels}")
# Pair-kill logic
failed_bases = {ch_names[i].split(' ')[0] for i in noisy_idx}
bad_names = [ch for ch in ch_names if ch.split(' ')[0] in failed_bases]
# --- Visualization ---
fig_master, ax = plt.subplots(figsize=(8, 4))
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(freqs, 10 * np.log10(psd_data.T), color='gray', alpha=0.2)
# Highlight the specifically noisy ones
if len(noisy_idx) > 0:
ax.plot(freqs, 10 * np.log10(psd_data[noisy_idx].T), color='plum', label='Noisy Pairs')
ax.axhline(db_limit, color='red', linestyle='--', label='Threshold')
ax.set_title("Master Clean: PSD Noise & Dead Channel Detection")
ax.set_title(f"PSD Noise Analysis (Target: {target_freq}Hz)")
ax.set_ylabel("Power (dB)")
ax.set_xlabel("Frequency (Hz)")
ax.legend()
plt.close(fig_master)
plt.close(fig)
return list(set(all_bads)), fig_master
print(f"Noise Check: Flagged {len(failed_bases)} optode pairs.")
return bad_names, fig
def find_physiologically_impossible_channels(raw, picks, threshold=3.0):
def find_bad_channels_range(raw, threshold=4.0):
picks = [ch for ch in raw.ch_names]
data = raw.get_data(picks=picks)
ranges = np.max(data, axis=1) - np.min(data, axis=1)
@@ -3704,7 +3704,7 @@ def find_physiologically_impossible_channels(raw, picks, threshold=3.0):
z_scores = 0.6745 * (ranges - median_range) / (mad if mad > 0 else 1e-15)
# Identify failed bases
failed_indices = np.where(z_scores > threshold)[0]
failed_indices = np.where(np.abs(z_scores) > threshold)[0]
failed_bases = {picks[i].split(' ')[0] for i in failed_indices}
# Flag entire pairs
@@ -3713,10 +3713,11 @@ def find_physiologically_impossible_channels(raw, picks, threshold=3.0):
# --- Visualization ---
fig_swing, ax = plt.subplots(figsize=(8, 4))
# We color bars by the specific Z-score of that individual channel
colors = ['coral' if z > threshold else 'skyblue' for z in z_scores]
colors = ['coral' if np.abs(z) > threshold else 'skyblue' for z in z_scores]
ax.bar(range(len(z_scores)), z_scores, color=colors)
ax.axhline(threshold, color='red', linestyle='--', label='Outlier Threshold')
ax.axhline(-threshold, color='red', linestyle='--')
ax.set_title("Physiological Swing Analysis (Z-Scores)")
ax.set_ylabel("Standardized Deviation")
ax.set_xlabel("Channel Index")
@@ -3726,6 +3727,62 @@ def find_physiologically_impossible_channels(raw, picks, threshold=3.0):
return bad_names, fig_swing
def find_bad_channels_cv(raw, cv_threshold=25.0):
    """
    Identify bad fNIRS channels using only the Coefficient of Variation (CV).

    Each channel is scored as CV = std / |mean| * 100.  Channels whose CV
    exceeds ``cv_threshold`` are flagged, and the whole optode pair is then
    rejected ("pair-kill"): if one wavelength of a pair fails, every channel
    sharing that base name is marked bad.

    Parameters
    ----------
    raw : object
        Raw-like object exposing ``ch_names`` and ``get_data(picks=...)``
        (presumably an ``mne.io.Raw`` — TODO confirm against callers).
    cv_threshold : float
        Maximum acceptable CV, in percent.

    Returns
    -------
    bad_names : list of str
        Names of all channels belonging to a failed optode pair.
    fig_qc : matplotlib.figure.Figure
        Bar chart of per-channel CV against the threshold (already closed,
        intended for later saving/embedding).
    """
    print(f"\n--- Starting CV-Only Quality Check on the channels ---")
    picks = list(raw.ch_names)
    data = raw.get_data(picks=picks)

    # Coefficient of Variation per channel.  Divide by |mean| with a 1e-15
    # floor: taking the absolute value keeps the score positive for channels
    # with a negative baseline (the old `means + 1e-15` form made CV negative
    # there, so such channels could never be flagged), and the floor prevents
    # division by zero for flat channels.
    stds = np.std(data, axis=1)
    means = np.mean(data, axis=1)
    cv_scores = (stds / np.maximum(np.abs(means), 1e-15)) * 100

    # Indices of channels exceeding the threshold
    bad_cv_indices = np.where(cv_scores > cv_threshold)[0]

    # Pair-kill logic: channel names look like "S1_D2 760"; the base is the
    # part before the space.  If one wavelength fails, flag the whole pair.
    failed_bases = {picks[idx].split(' ')[0] for idx in bad_cv_indices}
    bad_names = [ch for ch in picks if ch.split(' ')[0] in failed_bases]

    # Summary prints
    print(f"CV Check: Found {len(bad_cv_indices)} channels exceeding {cv_threshold}% noise threshold.")
    if failed_bases:
        print(f"Flagged {len(failed_bases)} optode pairs for removal:")
        # Map each base to the first channel index carrying it.  Matching on
        # the exact base (split, not startswith) avoids "S1_D2" accidentally
        # matching "S1_D20", and avoids the O(n^2) picks.index() scan.
        base_to_idx = {}
        for i, name in enumerate(picks):
            base_to_idx.setdefault(name.split(' ')[0], i)
        for base in sorted(failed_bases):
            print(f" - {base}: CV = {cv_scores[base_to_idx[base]]:.2f}%")
    else:
        print("All channels passed the CV check.")

    # --- Visualization ---
    fig_qc, ax = plt.subplots(figsize=(10, 5), constrained_layout=True)
    colors = ['coral' if c > cv_threshold else 'skyblue' for c in cv_scores]
    ax.bar(range(len(cv_scores)), cv_scores, color=colors)
    ax.axhline(cv_threshold, color='red', linestyle='--', label=f'Threshold ({cv_threshold}%)')
    ax.set_title("Coefficient of Variation (Relative Noise)")
    ax.set_ylabel("CV %")
    ax.set_xlabel("Channel Index")
    ax.legend()
    # Close so headless batch runs don't accumulate open figures; the caller
    # keeps the returned handle for saving.
    plt.close(fig_qc)

    return bad_names, fig_qc
def hr_calc(raw):
if SHORT_CHANNEL:
short_chans = get_short_channels(raw, max_dist=SHORT_CHANNEL_THRESH)
@@ -3862,11 +3919,6 @@ def process_participant(file_path, progress_callback=None):
# late_stage_bads = find_mid_run_flatlines(raw)
# _ = find_flatline_at_end(raw)
# print(f"Channels that died before the end: {late_stage_bads}")
# picks = [ch for ch in raw.ch_names if 'S1_D2' in ch]
bad_master, fig_master = master_clean_fnirs(raw, db_limit=-60)
bad_swing, fig_swing = find_physiologically_impossible_channels(raw, [ch for ch in raw.ch_names], threshold=4.0)
fig_individual["fig_master"] = fig_master
fig_individual["fig_swing"] = fig_swing
if progress_callback: progress_callback(5)
logger.info("Step 5 Completed.")
@@ -3900,9 +3952,22 @@ def process_participant(file_path, progress_callback=None):
if progress_callback: progress_callback(8)
logger.info("Step 8 Completed.")
# TODO: Add callbacks and user defined parameters
bad_cv, fig_cv = find_bad_channels_cv(raw, cv_threshold=20.0)
fig_individual['cv'] = fig_cv
bad_range, fig_range = find_bad_channels_range(raw, threshold=3.0)
fig_individual['range'] = fig_range
bad_noise, fig_noise = detect_high_freq_noise(raw, db_limit=-60)
fig_individual['psd_noise'] = fig_noise
bad_disp, fig_disp = detect_sensor_displacement(raw, threshold_ratio=0.05)
fig_individual['displacement'] = fig_disp
# Step 9: Bad Channels Handling
if BAD_CHANNELS_HANDLING != "None":
raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing)
raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_cv, bad_range, bad_noise, bad_disp)
if fig_dropped and fig_raw_before is not None:
fig_individual["fig2"] = fig_dropped
fig_individual["fig3"] = fig_raw_before
+28 -24
View File
@@ -5920,11 +5920,12 @@ class MainApplication(QMainWindow):
# Start processing
if current_process().name == 'MainProcess':
self.result_queue = Queue()
self.ack_queue = Queue()
self.progress_queue = Queue()
self.result_process = Process(
target=run_gui_entry_wrapper,
args=(collected_data, self.result_queue, self.progress_queue)
args=(collected_data, self.result_queue, self.progress_queue, self.ack_queue)
)
self.result_process.daemon = False
self.result_process.start()
@@ -5986,6 +5987,29 @@ class MainApplication(QMainWindow):
self.show_error_popup(f"Error: {file_path}", error_msg, msg.get("traceback", ""))
self.statusbar.showMessage(f"Failed: {os.path.basename(file_path)}")
elif isinstance(msg, dict) and msg.get("type") == "FINISHED_SUCCESSFULLY":
# The child has finished its work AND its own cleanup.
# It is now safe for the GUI to stop the timer and clean up.
try:
self.ack_queue.put("ACK")
except: pass
self.result_timer.stop()
self.cleanup_after_process()
success_count = len(self.files_results)
fail_count = self.files_total - success_count
self.statusbar.showMessage(f"Complete: {success_count} succeeded, {fail_count} failed.")
if success_count > 0:
self.button3.setVisible(True)
# Reset the button
try: self.button1.clicked.disconnect()
except: pass
self.button1.setText("Process")
self.button1.clicked.connect(self.on_run_task)
return # Exit the method
elif isinstance(msg, dict) and msg.get("success") is True:
self.statusbar.showMessage("All files processed successfully!")
@@ -6003,22 +6027,6 @@ class MainApplication(QMainWindow):
_, file_path, step_index = msg
self.progress_update_signal.emit(file_path, step_index)
if len(self.files_done) >= self.files_total:
self.result_timer.stop()
self.cleanup_after_process()
success_count = len(self.files_results)
fail_count = self.files_total - success_count
self.statusbar.showMessage(f"Complete: {success_count} succeeded, {fail_count} failed.")
if success_count > 0:
self.button3.setVisible(True)
try:
self.button1.clicked.disconnect()
except: pass
self.button1.setText("Process")
self.button1.clicked.connect(self.on_run_task)
except Exception as e:
print(f"Error in timer loop: {e}")
@@ -6049,7 +6057,7 @@ class MainApplication(QMainWindow):
msgbox.setDetailedText(traceback_str)
msgbox.setStandardButtons(QMessageBox.Ok)
msgbox.exec_()
msgbox.show()
def cleanup_after_process(self):
@@ -6074,10 +6082,6 @@ class MainApplication(QMainWindow):
self.progress_queue.close()
self.progress_queue.join_thread()
# Shutdown manager to kill its server process and clean up
if hasattr(self, 'manager'):
self.manager.shutdown()
def update_file_progress(self, file_path, step_index):
key = os.path.normpath(file_path)
@@ -6204,14 +6208,14 @@ def _extract_metadata_worker(file_name):
print(f"Worker failed on {file_name}: {e}")
return None
def run_gui_entry_wrapper(config, gui_queue, progress_queue):
def run_gui_entry_wrapper(config, gui_queue, progress_queue, ack_queue):
"""
Where the processing happens
"""
try:
import flares
flares.gui_entry(config, gui_queue, progress_queue)
flares.gui_entry(config, gui_queue, progress_queue, ack_queue)
gui_queue.close()
# gui_queue.join_thread()
progress_queue.close()