no more zombie processes + additional bad-channel detection underneath

2026-03-24 16:50:30 -07:00
parent f2c236bec5
commit 7ef07678ae
2 changed files with 478 additions and 88 deletions
+433 -74
@@ -99,26 +99,21 @@ from mne_connectivity import envelope_correlation, spectral_connectivity_epochs,
# Needs to be set for mne
os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore
# TODO: Tidy this up
FIXED_CATEGORY_COLORS = {
PRIMARY_COLORS = {
"SCI only": "skyblue",
"PSP only": "salmon",
"SNR only": "lightgreen",
"PSP + SCI": "orange",
"SCI + SNR": "violet",
"PSP + SNR": "gold",
"SCI + PSP": "orange",
"SNR + SCI": "violet",
"SNR + PSP": "gold",
"PSP + SNR + SCI": "gray",
"SCI + PSP + SNR": "gray",
"SCI + SNR + PSP": "gray",
"PSP + SCI + SNR": "gray",
"PSP + SNR + SCI": "gray",
"SNR + SCI + PSP": "gray",
"SNR + PSP + SCI": "gray",
"Master only": "plum",
"Swing only": "coral",
}
# Fallback for combinations (the "multifail" gray)
COMBINATION_COLOR = "gray"
def get_category_color(label):
"""Returns the primary color if it's a single failure, otherwise gray."""
return PRIMARY_COLORS.get(label, COMBINATION_COLOR)
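# Illustrative behavior of the fallback above (labels as defined in the dicts):
#   get_category_color("SCI only")     -> "skyblue"
#   get_category_color("SCI + PSP")    -> "orange"  (explicitly listed combo)
#   get_category_color("SCI + Master") -> "gray"    (any unlisted combination)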
DOWNSAMPLE: bool
DOWNSAMPLE_FREQUENCY: int
@@ -434,59 +429,99 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre
#audit_log.info(f"--- SESSION START: {len(file_paths)} files ---")
pending_files = list(file_paths)
active_processes = [] # List of tuples: (Process object, file_path)
active_processes = []
results_by_file = {}
# We use a manager queue so it handles IPC serialization cleanly
manager = mp.Manager()
result_queue = manager.Queue()
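# (A plain mp.Queue pushes data through a background feeder thread, and
# joining a child that still holds unflushed queue data can deadlock; the
# Manager proxy trades a little speed for safer cross-process cleanup.)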
# Loop continues as long as there are files to process OR workers still running
while pending_files or active_processes:
try:
while pending_files or active_processes:
# 1. SPAWN
while len(active_processes) < max_workers and pending_files:
file_path = pending_files.pop(0)
p = mp.Process(
target=process_participant_worker,
args=(file_path, file_params, file_metadata, result_queue, progress_queue)
)
p.start()
active_processes.append((p, file_path))
# 2. COLLECT RESULTS (Inner Loop)
while True:
try:
res_path, result, error = result_queue.get(timeout=0.01)
for i, (p, f_path) in enumerate(active_processes):
if f_path == res_path:
p.join(timeout=1) # Clean it up
active_processes.pop(i)
break
if not pending_files and not active_processes:
break
if gui_queue:
# FIX: Added a timeout. If the GUI stops reading, we don't freeze forever.
try:
gui_queue.put({
"type": "file_done",
"file": res_path,
"success": error is None,
"result": result if error is None else None,
"error": error if error else None
}, timeout=2) # 2-second safety valve
except Exception as e:
print(f"DEBUG: gui_queue full or closed! {e}")
else:
results_by_file[res_path] = result
except Exception:  # get() timed out (queue.Empty): nothing ready yet
break # Break OUT of the inner loop if queue is empty
# --- EVERYTHING BELOW IS NOW OUTSIDE THE INNER LOOP ---
# 3. THE CIRCUIT BREAKER
if not pending_files:
alive_count = sum(1 for p, _ in active_processes if p.is_alive())
if alive_count == 0 and len(active_processes) > 0:
print("DEBUG: All workers dead but loop persisting. Breaking.")
break
# 4. CLEANUP DEAD WORKERS
for p, f_path in active_processes[:]:
if not p.is_alive():
p.join(timeout=0.1)
active_processes.remove((p, f_path))
time.sleep(0.1)
print("DEBUG: Loop finished naturally.")
# 1. SPAWN WORKERS: Only spawn if we are under the limit AND have files left
while len(active_processes) < max_workers and pending_files:
file_path = pending_files.pop(0)
p = mp.Process(
target=process_participant_worker,
args=(file_path, file_params, file_metadata, result_queue, progress_queue)
)
p.start()
active_processes.append((p, file_path))
#audit_log.info(f"Spawned worker. Active processes: {len(active_processes)}")
# 2. COLLECT RESULTS: Drain the queue continuously so workers don't deadlock
while not result_queue.empty():
except Exception as e:
print(f"DEBUG: supervisor loop crashed: {e}")
finally:
# 1. Kill any workers
for p, _ in active_processes:
try:
res_path, result, error = result_queue.get_nowait()
# SEND IMMEDIATELY TO THE MAIN GUI
if gui_queue:
gui_queue.put({
"type": "file_done",
"file": res_path,
"success": error is None,
"result": result if error is None else None,
"error": error if error else None
})
else:
# Fallback if no GUI queue (e.g., CLI mode)
results_by_file[res_path] = result
except Exception:
break # Queue is empty or busy
if p.is_alive():
p.terminate()
p.join(timeout=0.1)
except Exception: pass
# 3. CLEANUP: Check for finished processes and remove them
for p, f_path in active_processes[:]: # Iterate over a slice copy
if not p.is_alive():
p.join() # Formally close the process to free OS resources
active_processes.remove((p, f_path))
#audit_log.info(f"Worker finished. Active processes dropping to: {len(active_processes)}")
# 2. Clear Proxy
try:
del result_queue
except Exception: pass
# 3. Force Kill Manager
if manager is not None:
try:
manager.shutdown()
except Exception: pass
# Brief pause to prevent this while loop from pegging your CPU to 100%
time.sleep(0.5)
#audit_log.info("--- SESSION COMPLETE ---")
return results_by_file
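A minimal standalone sketch of the zombie-free pattern above (hypothetical `run_workers`, `worker_fn`, and `jobs`, not the pipeline's real API): drain results without blocking, reap dead workers so the loop can terminate, and do the terminate/join/Manager-shutdown dance in `finally` so nothing survives the session.

import multiprocessing as mp
import queue
import time

def run_workers(jobs, worker_fn, max_workers=4):
    manager = mp.Manager()
    results_q = manager.Queue()
    active, results = [], {}

    def drain():
        # Pull every result currently waiting in the queue.
        try:
            while True:
                job, result = results_q.get_nowait()
                results[job] = result
        except queue.Empty:
            pass

    try:
        while jobs or active:
            # Spawn up to the worker limit.
            while len(active) < max_workers and jobs:
                job = jobs.pop(0)
                p = mp.Process(target=worker_fn, args=(job, results_q))
                p.start()
                active.append((p, job))
            drain()
            # Reap finished (or crashed) workers so the loop can end.
            for p, job in active[:]:
                if not p.is_alive():
                    p.join(timeout=0.1)
                    active.remove((p, job))
            time.sleep(0.1)
        drain()  # catch results posted just before the final reap
    finally:
        for p, _ in active:        # kill anything still running
            if p.is_alive():
                p.terminate()
            p.join(timeout=0.1)
        manager.shutdown()         # no orphaned Manager daemon
    return results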
@@ -949,7 +984,9 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
if len(bad_pairs) == 0:
print("No bad pairs found. Skipping interpolation.")
return raw
return raw, None, None
raw_before_data = raw.get_data().copy()
# Extract locations (use HbO channel loc as pair location)
locs = np.array([raw.info['chs'][hbo_picks[i]]['loc'][:3] for i in range(len(hbo_names))])
@@ -990,6 +1027,43 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
interpolated_pairs.append(bad_base)
n_bad = len(bad_pairs)
n_cols = 4 # Fixed width for horizontal scaling
n_rows = int(np.ceil(n_bad / n_cols))
# Calculate height: 2.5 inches per row is usually enough for readability
fig_height = max(4, n_rows * 2.5)
fig_compare, axes = plt.subplots(n_rows, n_cols, figsize=(15, fig_height),
constrained_layout=True)
axes_flat = np.ravel(axes)  # robust whether subplots returns a 1-D or 2-D array
for j in range(n_bad, len(axes_flat)):
axes_flat[j].axis('off')
times = raw.times
for i, bad_idx in enumerate(bad_pairs):
ax = axes_flat[i]
base = hbo_names[bad_idx]
# Plot "Before" (Dirty data) in light gray/red
ax.plot(times, raw_before_data[hbo_picks[bad_idx]], color='red', alpha=0.3, label='Original HbO')
# Plot "After" (Interpolated data) in solid blue/green
if base in interpolated_pairs:
ax.plot(times, raw._data[hbo_picks[bad_idx]], color='blue', label='Interpolated HbO')
status = "SUCCESS"
color = "green"
else:
status = "FAILED (Isolated)"
color = "red"
ax.set_title(f"Channel {base} | Status: {status}", color=color, fontweight='bold')
ax.set_ylabel("Amplitude")
if i == 0: ax.legend(loc='upper right')
plt.close(fig_compare)
if interpolated_pairs:
bad_ch_to_remove = []
for base_ in interpolated_pairs:
@@ -1007,7 +1081,7 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
fig_raw_after = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After interpolation", show=False)
return raw, fig_raw_after
return raw, fig_raw_after, fig_compare
@@ -1114,8 +1188,9 @@ def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5
return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2
def mark_bads(raw, bad_sci, bad_snr, bad_psp):
bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp))
def mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing):
print(bad_sci, bad_master, bad_psp, bad_snr, bad_swing)
bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp) | set(bad_master) | set(bad_swing))
print(f"Automatically marked bad channels based on SNR and SCI: {bads_combined}")
raw.info['bads'].extend(bads_combined)
@@ -1125,6 +1200,8 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp):
(bad_sci, "SCI"),
(bad_psp, "PSP"),
(bad_snr, "SNR"),
(bad_master, "Master"),
(bad_swing, "Swing"),
]
# Graph what channels were dropped and why they were dropped
@@ -1148,8 +1225,7 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp):
channel_names.extend(chs_in_cat)
category_labels.extend([cat] * len(chs_in_cat))
colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories}
colors = {cat: get_category_color(cat) for cat in categories}
# Create the figure
fig_dropped, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3))) # type: ignore
y_pos = range(len(channel_names))
@@ -1879,12 +1955,12 @@ def individual_significance(raw_haemo, glm_est):
# Merge with mean theta (optional for plotting)
mean_theta = activity_ch_summary.groupby('ch_name')['theta'].mean().reset_index()
sig_channels = sig_channels.merge(mean_theta, on='ch_name')
print(sig_channels)
# print(sig_channels)
# For example, take the minimum corrected p-value per channel
summary_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index()
print(summary_pvals)
# print(summary_pvals)
def parse_ch_name(ch_name):
@@ -1916,9 +1992,10 @@ def individual_significance(raw_haemo, glm_est):
SOURCE_DETECTOR_SEPARATOR = "_"
t_or_theta = 'theta'
for _, row in avg_df.iterrows(): # type: ignore
print(f"Source {row['Source']} <-> Detector {row['Detector']}: "
f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}")
# Suppressed below: one print per channel pair is serious log noise
# for _, row in avg_df.iterrows(): # type: ignore
# print(f"Source {row['Source']} <-> Detector {row['Detector']}: "
# f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}")
# Extract the source and detector positions from raw
src_pos: dict[int, tuple[float, float]] = {}
@@ -3384,6 +3461,270 @@ def plot_heart_rate(
return fig1, fig2
# import numpy as np
# def mark_bads_by_db_threshold(raw, db_limit=-60):
# """
# Converts a dB threshold to absolute power and marks channels
# exceeding it at 6.25 Hz (1/2 Nyquist).
# """
# # 1. Convert your "Magic Number" from dB to absolute units
# abs_threshold = 10 ** (db_limit / 10)
# # 2. Compute PSD specifically around the frequency of interest
# sfreq = raw.info['sfreq']
# target_freq = (sfreq / 2) - 1 # 6.25 Hz
# # We use a small window around 6.25 Hz to get a stable average
# spectrum = raw.compute_psd(fmin=target_freq - 0.5, fmax=target_freq + 0.5)
# psd_data, freqs = spectrum.get_data(return_freqs=True)
# # 3. Get the power at that frequency (averaging across the small window)
# # psd_data shape: (n_channels, n_freqs)
# power_at_6Hz = np.mean(psd_data, axis=1)
# # 4. Identify the "loud" channels
# bad_indices = np.where(power_at_6Hz > abs_threshold)[0]
# new_bads = [raw.ch_names[i] for i in bad_indices]
# # 5. Update raw.info['bads'] without duplicates
# # raw.info['bads'] = list(set(raw.info['bads'] + new_bads))
# print(f"Threshold: {db_limit} dB -> {abs_threshold:.2e} V^2/Hz")
# print(f"Newly potentially bad channels: {new_bads}")
# return new_bads
# def find_flat_channels(raw, threshold=1e-15):
# """
# Identifies channels that are essentially flat lines (zero variance).
# """
# data = raw.get_data()
# # Calculate standard deviation of each channel
# std_devs = np.std(data, axis=1)
# # Identify channels with almost zero movement
# flat_idx = np.where(std_devs < threshold)[0]
# flat_names = [raw.ch_names[i] for i in flat_idx]
# return flat_names
# def find_truly_dead_channels(raw, threshold=1e-20):
# """
# Looks for channels where the signal literally doesn't move
# between samples (successive differences are zero).
# """
# data = raw.get_data()
# # Calculate the absolute difference between every sample (t and t+1)
# # diffs shape: (n_channels, n_times - 1)
# diffs = np.abs(np.diff(data, axis=1))
# # Calculate the mean 'step' size for each channel
# mean_diffs = np.mean(diffs, axis=1)
# # Identify channels where the signal is 'stuck'
# stuck_idx = np.where(mean_diffs < threshold)[0]
# stuck_names = [raw.ch_names[i] for i in stuck_idx]
# return stuck_names, mean_diffs
# def find_mid_run_flatlines(raw, window_size=10.0, threshold=1e-12):
# """
# Checks for channels that go flat partway through the recording.
# window_size: Duration in seconds to check for flatness.
# """
# sfreq = raw.info['sfreq']
# data = raw.get_data()
# n_samples_win = int(window_size * sfreq)
# bad_channels = []
# for i, ch_name in enumerate(raw.ch_names):
# ch_data = data[i]
# # Create sliding windows (non-overlapping for speed)
# windows = np.array_split(ch_data, len(ch_data) // n_samples_win)
# # Calculate variance for each window
# win_vars = [np.var(w) for w in windows]
# # If the LAST window (or any significant portion at the end) is flat
# if win_vars[-1] < threshold:
# bad_channels.append(ch_name)
# return bad_channels
def find_flatline_at_end(raw, threshold_ratio=0.05):
"""
Compares the variance of the first 25% of the data to the last 25%.
If the end variance is below `threshold_ratio` (default 5%) of the start variance, the channel is marked bad.
"""
data = raw.get_data()
n_samples = data.shape[1]
quarter = n_samples // 4
bad_channels = []
for i, ch_name in enumerate(raw.ch_names):
start_var = np.var(data[i, :quarter])
end_var = np.var(data[i, -quarter:])
# Avoid division by zero for truly dead channels
if start_var == 0:
bad_channels.append(ch_name)
continue
ratio = end_var / start_var
if ratio < threshold_ratio:
bad_channels.append(ch_name)
print(f"Flagged {ch_name}: Variance dropped to {ratio:.2%} of original.")
return bad_channels
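# Hypothetical smoke test for the variance-ratio check above; the channel
# names and data are made up, and this assumes mne is importable here.
# import numpy as np
# import mne
# rng = np.random.default_rng(0)
# good = rng.standard_normal(1000) * 1e-6
# dead = good.copy(); dead[500:] = dead[500]   # flatlines halfway through
# info = mne.create_info(["S1_D1 hbo", "S2_D2 hbo"], sfreq=10.0, ch_types="misc")
# raw_demo = mne.io.RawArray(np.vstack([good, dead]), info)
# find_flatline_at_end(raw_demo)  # -> ["S2_D2 hbo"]; end variance ~0% of start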
# def plot_with_death_lines(raw, picks, window_sec=5.0, var_threshold=0.05):
# sfreq = raw.info['sfreq']
# data, times = raw.get_data(picks=picks, return_times=True)
# fig, ax = plt.subplots(figsize=(12, 6))
# for i, ch_name in enumerate(picks):
# ch_data = data[i]
# ax.plot(times, ch_data, label=ch_name, alpha=0.8)
# # --- Find the 'Death Point' ---
# # Calculate rolling variance in 5-second chunks
# win_samples = int(window_sec * sfreq)
# initial_var = np.var(ch_data[:win_samples])
# # Check from the end backwards to find where it "died"
# death_time = None
# n_samples = len(ch_data)
# for start_idx in range(n_samples - win_samples - 1, 0, -win_samples):
# current_var = np.var(ch_data[start_idx : start_idx + win_samples])
# if current_var > (initial_var * var_threshold):
# # The point right after this is where it stays dead
# death_time = times[start_idx + win_samples]
# break
# if death_time and death_time < (times[-1] - 10): # Only plot if it died early
# ax.axvline(x=death_time, color='red', linestyle='--', alpha=0.5)
# ax.text(death_time, ax.get_ylim()[1], 'Signal Loss',
# color='red', rotation=90, verticalalignment='top')
# ax.set_title("HbO/HbR with Automated Signal Loss Detection")
# ax.set_ylabel("Concentration (Δ μmol)")
# ax.set_xlabel("Time (s)")
# ax.legend(loc='lower right')
# plt.grid(True, alpha=0.3)
# plt.show(block=True)
# import numpy as np
def master_clean_fnirs(raw, db_limit=-60, threshold_ratio=0.05):
ch_names = raw.ch_names
data = raw.get_data()
sfreq = raw.info['sfreq']
n_samples = data.shape[1]
quarter = n_samples // 4
dead_idx = []
for i in range(len(ch_names)):
start_var = np.var(data[i, :quarter])
end_var = np.var(data[i, -quarter:])
if start_var == 0:
dead_idx.append(i)
continue
ratio = end_var / start_var
if ratio < threshold_ratio:
dead_idx.append(i)
print(f"Flagged {ch_names[i]}: Variance dropped to {ratio:.2%} of original.")
dead_channels = [ch_names[i] for i in dead_idx]
target_freq = sfreq / 4
spectrum = raw.compute_psd(fmin=0.1, fmax=sfreq/2)
psd_data, freqs = spectrum.get_data(return_freqs=True)
f_idx = np.where((freqs >= target_freq - 0.2) & (freqs <= target_freq + 0.2))[0]
power_at_target = np.mean(psd_data[:, f_idx], axis=1)
abs_threshold = 10 ** (db_limit / 10)
noisy_idx = np.where(power_at_target > abs_threshold)[0]
noisy_channels = [ch_names[i] for i in noisy_idx]
# --- The Pairing Logic ---
# Combine indices of specifically failed channels
failed_idx_list = np.unique(np.concatenate([dead_idx, noisy_idx])).astype(int)
# Extract base names (e.g., "S1_D2") of failed channels
# Assuming "S1_D2 762" -> split by space -> "S1_D2"
failed_bases = {str(ch_names[i]).split(' ')[0] for i in failed_idx_list}
# Flag EVERY channel that matches these bases
all_bads = [ch for ch in ch_names if ch.split(' ')[0] in failed_bases]
print(f"Analysis Complete:")
print(f" - Found {len(dead_channels)} dead/low-var channels: {dead_channels}")
print(f" - Found {len(noisy_channels)} high-noise (> {db_limit}dB) channels: {noisy_channels}")
# --- Visualization ---
fig_master, ax = plt.subplots(figsize=(8, 4))
ax.plot(freqs, 10 * np.log10(psd_data.T), color='gray', alpha=0.2)
# Highlight the specifically noisy ones
if len(noisy_idx) > 0:
ax.plot(freqs, 10 * np.log10(psd_data[noisy_idx].T), color='plum', label='Noisy Pairs')
ax.axhline(db_limit, color='red', linestyle='--', label='Threshold')
ax.set_title("Master Clean: PSD Noise & Dead Channel Detection")
ax.set_ylabel("Power (dB)")
ax.set_xlabel("Frequency (Hz)")
ax.legend()
plt.close(fig_master)
return list(set(all_bads)), fig_master
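# Worked numbers for the defaults above (illustrative): db_limit=-60 gives
# abs_threshold = 10 ** (-60 / 10) = 1e-6, so any channel whose mean PSD in
# the +/-0.2 Hz band around sfreq / 4 exceeds 1e-6 is flagged noisy, and
# every channel sharing its source-detector base is marked along with it.
#   bad_master, fig_master = master_clean_fnirs(raw, db_limit=-60)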
def find_physiologically_impossible_channels(raw, picks, threshold=3.0):
data = raw.get_data(picks=picks)
ranges = np.max(data, axis=1) - np.min(data, axis=1)
# Calculate Z-Scores
median_range = np.median(ranges)
mad = np.median(np.abs(ranges - median_range))
z_scores = 0.6745 * (ranges - median_range) / (mad if mad > 0 else 1e-15)
# Identify failed bases
failed_indices = np.where(z_scores > threshold)[0]
failed_bases = {picks[i].split(' ')[0] for i in failed_indices}
# Flag entire pairs
bad_names = [ch for ch in picks if ch.split(' ')[0] in failed_bases]
# --- Visualization ---
fig_swing, ax = plt.subplots(figsize=(8, 4))
# We color bars by the specific Z-score of that individual channel
colors = ['coral' if z > threshold else 'skyblue' for z in z_scores]
ax.bar(range(len(z_scores)), z_scores, color=colors)
ax.axhline(threshold, color='red', linestyle='--', label='Outlier Threshold')
ax.set_title("Physiological Swing Analysis (Z-Scores)")
ax.set_ylabel("Standardized Deviation")
ax.set_xlabel("Channel Index")
ax.legend()
plt.close(fig_swing)
return bad_names, fig_swing
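# Why the 0.6745 factor above: for normally distributed data the MAD is
# about 0.6745 * sigma, so scaling by it puts the robust deviation in
# ordinary z-score units. A worked (made-up) example:
#   ranges = [1.0, 1.1, 0.9, 5.0]  -> median = 1.05, MAD = 0.10
#   z(5.0) = 0.6745 * (5.0 - 1.05) / 0.10 ~= 26.6, far past threshold=3.0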
def hr_calc(raw):
if SHORT_CHANNEL:
@@ -3414,7 +3755,7 @@ def hr_calc(raw):
# --- Parameters for PSD ---
desired_bin_hz = 0.1
nperseg = int(sfreq / desired_bin_hz)
hr_range = (30, 180) # TODO: SHould this not use the user defined values?
hr_range = (30, 180) # TODO: Should this not use the user defined values?
# --- Function to find strongest local peak ---
def find_hr_from_psd(ch_data):
@@ -3511,6 +3852,22 @@ def process_participant(file_path, progress_callback=None):
fig_individual["PSD"] = fig
fig_individual['HeartRate_PSD'] = hr1
fig_individual['HeartRate_Time'] = hr2
# --- Run it ---
# mark_bads_by_db_threshold(raw, db_limit=-60)
# dead_channels = find_flat_channels(raw)
# print(f"Dead/Flat channels removed: {dead_channels}")
# stuck_channels, movement_scores = find_truly_dead_channels(raw)
# print(f"Stuck Channels: {stuck_channels}")
# late_stage_bads = find_mid_run_flatlines(raw)
# _ = find_flatline_at_end(raw)
# print(f"Channels that died before the end: {late_stage_bads}")
# picks = [ch for ch in raw.ch_names if 'S1_D2' in ch]
bad_master, fig_master = master_clean_fnirs(raw, db_limit=-60)
bad_swing, fig_swing = find_physiologically_impossible_channels(raw, list(raw.ch_names), threshold=4.0)
fig_individual["fig_master"] = fig_master
fig_individual["fig_swing"] = fig_swing
if progress_callback: progress_callback(5)
logger.info("Step 5 Completed.")
@@ -3545,17 +3902,19 @@ def process_participant(file_path, progress_callback=None):
# Step 9: Bad Channels Handling
if BAD_CHANNELS_HANDLING != "None":
raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp)
raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing)
if fig_dropped and fig_raw_before is not None:
fig_individual["fig2"] = fig_dropped
fig_individual["fig3"] = fig_raw_before
if bad_channels:
if BAD_CHANNELS_HANDLING == "Interpolate":
raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, max_dist=MAX_DIST, min_neighbors=MIN_NEIGHBORS)
raw, fig_raw_after, fig_compare = interpolate_fNIRS_bads_weighted_average(raw, max_dist=MAX_DIST, min_neighbors=MIN_NEIGHBORS)
fig_individual["fig4"] = fig_raw_after
fig_individual["Compare"] = fig_compare
elif BAD_CHANNELS_HANDLING == "Remove":
pass
#TODO: Is there more needed here?
#NOTE: testing this
raw.pick_types(fnirs=True, exclude='bads')
logger.info(f"Physically removed {len(bad_channels)} channels from the dataset.")
if progress_callback: progress_callback(9)
logger.info("Step 9 Completed.")