no more zombie + additional bad channel detection underneath

This commit is contained in:
2026-03-24 16:50:30 -07:00
parent f2c236bec5
commit 7ef07678ae
2 changed files with 478 additions and 88 deletions
+410 -51
View File
@@ -99,26 +99,21 @@ from mne_connectivity import envelope_correlation, spectral_connectivity_epochs,
# Needs to be set for mne # Needs to be set for mne
os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore
# TODO: Tidy this up PRIMARY_COLORS = {
FIXED_CATEGORY_COLORS = {
"SCI only": "skyblue", "SCI only": "skyblue",
"PSP only": "salmon", "PSP only": "salmon",
"SNR only": "lightgreen", "SNR only": "lightgreen",
"PSP + SCI": "orange", "Master only": "plum",
"SCI + SNR": "violet", "Swing only": "coral",
"PSP + SNR": "gold",
"SCI + PSP": "orange",
"SNR + SCI": "violet",
"SNR + PSP": "gold",
"PSP + SNR + SCI": "gray",
"SCI + PSP + SNR": "gray",
"SCI + SNR + PSP": "gray",
"PSP + SCI + SNR": "gray",
"PSP + SNR + SCI": "gray",
"SNR + SCI + PSP": "gray",
"SNR + PSP + SCI": "gray",
} }
# The Fallback for combinations (the "Multifail" Gray)
COMBINATION_COLOR = "gray"
def get_category_color(label):
"""Returns the primary color if it's a single failure, otherwise gray."""
return PRIMARY_COLORS.get(label, COMBINATION_COLOR)
DOWNSAMPLE: bool DOWNSAMPLE: bool
DOWNSAMPLE_FREQUENCY: int DOWNSAMPLE_FREQUENCY: int
@@ -434,17 +429,16 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre
#audit_log.info(f"--- SESSION START: {len(file_paths)} files ---") #audit_log.info(f"--- SESSION START: {len(file_paths)} files ---")
pending_files = list(file_paths) pending_files = list(file_paths)
active_processes = [] # List of tuples: (Process object, file_path) active_processes = []
results_by_file = {} results_by_file = {}
# We use a manager queue so it handles IPC serialization cleanly # We use a manager queue so it handles IPC serialization cleanly
manager = mp.Manager() manager = mp.Manager()
result_queue = manager.Queue() result_queue = manager.Queue()
# Loop continues as long as there are files to process OR workers still running try:
while pending_files or active_processes: while pending_files or active_processes:
# 1. SPAWN
# 1. SPWAN WORKERS: Only spawn if we are under the limit AND have files left
while len(active_processes) < max_workers and pending_files: while len(active_processes) < max_workers and pending_files:
file_path = pending_files.pop(0) file_path = pending_files.pop(0)
@@ -454,39 +448,80 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre
) )
p.start() p.start()
active_processes.append((p, file_path)) active_processes.append((p, file_path))
#audit_log.info(f"Spawned worker. Active processes: {len(active_processes)}")
# 2. COLLECT RESULTS: Drain the queue continuously so workers don't deadlock # 2. COLLECT RESULTS (Inner Loop)
while not result_queue.empty(): while True:
try: try:
res_path, result, error = result_queue.get_nowait()
# SEND IMMEDIATELY TO THE MAIN GUI res_path, result, error = result_queue.get(timeout=0.01)
for i, (p, f_path) in enumerate(active_processes):
if f_path == res_path:
p.join(timeout=1) # Clean it up
active_processes.pop(i)
break
if not pending_files and not active_processes:
break
if gui_queue: if gui_queue:
# FIX: Added a timeout. If the GUI stops reading, we don't freeze forever.
try:
gui_queue.put({ gui_queue.put({
"type": "file_done", "type": "file_done",
"file": res_path, "file": res_path,
"success": error is None, "success": error is None,
"result": result if error is None else None, "result": result if error is None else None,
"error": error if error else None "error": error if error else None
}) }, timeout=2) # 2-second safety valve
except Exception as e:
print(f"DEBUG: gui_queue full or closed! {e}")
else: else:
# Fallback if no GUI queue (e.g., CLI mode)
results_by_file[res_path] = result results_by_file[res_path] = result
except:
break # Break OUT of the inner loop if queue is empty
except Exception: # --- EVERYTHING BELOW IS NOW OUTSIDE THE INNER LOOP ---
break # Queue is empty or busy
# 3. CLEANUP: Check for finished processes and remove them # 3. THE CIRCUIT BREAKER
for p, f_path in active_processes[:]: # Iterate over a slice copy if not pending_files:
alive_count = sum(1 for p, _ in active_processes if p.is_alive())
if alive_count == 0 and len(active_processes) > 0:
print("DEBUG: All workers dead but loop persisting. Breaking.")
break
# 4. CLEANUP DEAD WORKERS
for p, f_path in active_processes[:]:
if not p.is_alive(): if not p.is_alive():
p.join() # Formally close the process to free OS resources p.join(timeout=0.1)
active_processes.remove((p, f_path)) active_processes.remove((p, f_path))
#audit_log.info(f"Worker finished. Active processes dropping to: {len(active_processes)}")
# Brief pause to prevent this while loop from pegging your CPU to 100% time.sleep(0.1)
time.sleep(0.5)
print("DEBUG: Loop finished naturally.")
except Exception as e:
print(e)
finally:
# 1. Kill any workers
for p, _ in active_processes:
try:
if p.is_alive():
p.terminate()
p.join(timeout=0.1)
except: pass
# 2. Clear Proxy
try:
del result_queue
except: pass
# 3. Force Kill Manager
if manager is not None:
try:
manager.shutdown()
except: pass
#audit_log.info("--- SESSION COMPLETE ---")
return results_by_file return results_by_file
@@ -949,7 +984,9 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
if len(bad_pairs) == 0: if len(bad_pairs) == 0:
print("No bad pairs found. Skipping interpolation.") print("No bad pairs found. Skipping interpolation.")
return raw return raw, None, None
raw_before_data = raw.get_data().copy()
# Extract locations (use HbO channel loc as pair location) # Extract locations (use HbO channel loc as pair location)
locs = np.array([raw.info['chs'][hbo_picks[i]]['loc'][:3] for i in range(len(hbo_names))]) locs = np.array([raw.info['chs'][hbo_picks[i]]['loc'][:3] for i in range(len(hbo_names))])
@@ -990,6 +1027,43 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
interpolated_pairs.append(bad_base) interpolated_pairs.append(bad_base)
n_bad = len(bad_pairs)
n_cols = 4 # Fixed width for horizontal scaling
n_rows = int(np.ceil(n_bad / n_cols))
# Calculate height: 2.5 inches per row is usually enough for readability
fig_height = max(4, n_rows * 2.5)
fig_compare, axes = plt.subplots(n_rows, n_cols, figsize=(15, fig_height),
constrained_layout=True)
if n_bad == 1: axes = [axes] # Handle single subplot case
axes_flat = axes.flatten()
for j in range(n_bad, len(axes_flat)):
axes_flat[j].axis('off')
times = raw.times
for i, bad_idx in enumerate(bad_pairs):
ax = axes_flat[i]
base = hbo_names[bad_idx]
# Plot "Before" (Dirty data) in light gray/red
ax.plot(times, raw_before_data[hbo_picks[bad_idx]], color='red', alpha=0.3, label='Original HbO')
# Plot "After" (Interpolated data) in solid blue/green
if base in interpolated_pairs:
ax.plot(times, raw._data[hbo_picks[bad_idx]], color='blue', label='Interpolated HbO')
status = "SUCCESS"
color = "green"
else:
status = "FAILED (Isolated)"
color = "red"
ax.set_title(f"Channel {base} | Status: {status}", color=color, fontweight='bold')
ax.set_ylabel("Amplitude")
if i == 0: ax.legend(loc='upper right')
plt.close(fig_compare)
if interpolated_pairs: if interpolated_pairs:
bad_ch_to_remove = [] bad_ch_to_remove = []
for base_ in interpolated_pairs: for base_ in interpolated_pairs:
@@ -1007,7 +1081,7 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
fig_raw_after = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After interpolation", show=False) fig_raw_after = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After interpolation", show=False)
return raw, fig_raw_after return raw, fig_raw_after, fig_compare
@@ -1114,8 +1188,9 @@ def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5
return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2 return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2
def mark_bads(raw, bad_sci, bad_snr, bad_psp): def mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing):
bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp)) print(bad_sci, bad_master, bad_psp, bad_snr, bad_swing)
bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp) | set(bad_master) | set(bad_swing))
print(f"Automatically marked bad channels based on SNR and SCI: {bads_combined}") print(f"Automatically marked bad channels based on SNR and SCI: {bads_combined}")
raw.info['bads'].extend(bads_combined) raw.info['bads'].extend(bads_combined)
@@ -1125,6 +1200,8 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp):
(bad_sci, "SCI"), (bad_sci, "SCI"),
(bad_psp, "PSP"), (bad_psp, "PSP"),
(bad_snr, "SNR"), (bad_snr, "SNR"),
(bad_master, "Master"),
(bad_swing, "Swing"),
] ]
# Graph what channels were dropped and why they were dropped # Graph what channels were dropped and why they were dropped
@@ -1148,8 +1225,7 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp):
channel_names.extend(chs_in_cat) channel_names.extend(chs_in_cat)
category_labels.extend([cat] * len(chs_in_cat)) category_labels.extend([cat] * len(chs_in_cat))
colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories} colors = {cat: get_category_color(cat) for cat in categories}
# Create the figure # Create the figure
fig_dropped, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3))) # type: ignore fig_dropped, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3))) # type: ignore
y_pos = range(len(channel_names)) y_pos = range(len(channel_names))
@@ -1879,12 +1955,12 @@ def individual_significance(raw_haemo, glm_est):
# Merge with mean theta (optional for plotting) # Merge with mean theta (optional for plotting)
mean_theta = activity_ch_summary.groupby('ch_name')['theta'].mean().reset_index() mean_theta = activity_ch_summary.groupby('ch_name')['theta'].mean().reset_index()
sig_channels = sig_channels.merge(mean_theta, on='ch_name') sig_channels = sig_channels.merge(mean_theta, on='ch_name')
print(sig_channels) # print(sig_channels)
# For example, take the minimum corrected p-value per channel # For example, take the minimum corrected p-value per channel
summary_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index() summary_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index()
print(summary_pvals) # print(summary_pvals)
def parse_ch_name(ch_name): def parse_ch_name(ch_name):
@@ -1916,9 +1992,10 @@ def individual_significance(raw_haemo, glm_est):
SOURCE_DETECTOR_SEPARATOR = "_" SOURCE_DETECTOR_SEPARATOR = "_"
t_or_theta = 'theta' t_or_theta = 'theta'
for _, row in avg_df.iterrows(): # type: ignore #holy log noise
print(f"Source {row['Source']} <-> Detector {row['Detector']}: " # for _, row in avg_df.iterrows(): # type: ignore
f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}") # print(f"Source {row['Source']} <-> Detector {row['Detector']}: "
# f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}")
# Extract the cource and detector positions from raw # Extract the cource and detector positions from raw
src_pos: dict[int, tuple[float, float]] = {} src_pos: dict[int, tuple[float, float]] = {}
@@ -3384,6 +3461,270 @@ def plot_heart_rate(
return fig1, fig2 return fig1, fig2
# import numpy as np
# def mark_bads_by_db_threshold(raw, db_limit=-60):
# """
# Converts a dB threshold to absolute power and marks channels
# exceeding it at 6.25 Hz (1/2 Nyquist).
# """
# # 1. Convert your "Magic Number" from dB to absolute units
# abs_threshold = 10 ** (db_limit / 10)
# # 2. Compute PSD specifically around the frequency of interest
# sfreq = raw.info['sfreq']
# target_freq = (sfreq / 2) - 1 # 6.25 Hz
# # We use a small window around 6.25 Hz to get a stable average
# spectrum = raw.compute_psd(fmin=target_freq - 0.5, fmax=target_freq + 0.5)
# psd_data, freqs = spectrum.get_data(return_freqs=True)
# # 3. Get the power at that frequency (averaging across the small window)
# # psd_data shape: (n_channels, n_freqs)
# power_at_6Hz = np.mean(psd_data, axis=1)
# # 4. Identify the "loud" channels
# bad_indices = np.where(power_at_6Hz > abs_threshold)[0]
# new_bads = [raw.ch_names[i] for i in bad_indices]
# # 5. Update raw.info['bads'] without duplicates
# # raw.info['bads'] = list(set(raw.info['bads'] + new_bads))
# print(f"Threshold: {db_limit} dB -> {abs_threshold:.2e} V^2/Hz")
# print(f"Newly potentially bad channels: {new_bads}")
# return new_bads
# def find_flat_channels(raw, threshold=1e-15):
# """
# Identifies channels that are essentially flat lines (zero variance).
# """
# data = raw.get_data()
# # Calculate standard deviation of each channel
# std_devs = np.std(data, axis=1)
# # Identify channels with almost zero movement
# flat_idx = np.where(std_devs < threshold)[0]
# flat_names = [raw.ch_names[i] for i in flat_idx]
# return flat_names
# def find_truly_dead_channels(raw, threshold=1e-20):
# """
# Looks for channels where the signal literally doesn't move
# between samples (successive differences are zero).
# """
# data = raw.get_data()
# # Calculate the absolute difference between every sample (t and t+1)
# # diffs shape: (n_channels, n_times - 1)
# diffs = np.abs(np.diff(data, axis=1))
# # Calculate the mean 'step' size for each channel
# mean_diffs = np.mean(diffs, axis=1)
# # Identify channels where the signal is 'stuck'
# stuck_idx = np.where(mean_diffs < threshold)[0]
# stuck_names = [raw.ch_names[i] for i in stuck_idx]
# return stuck_names, mean_diffs
# def find_mid_run_flatlines(raw, window_size=10.0, threshold=1e-12):
# """
# Checks for channels that go flat partway through the recording.
# window_size: Duration in seconds to check for flatness.
# """
# sfreq = raw.info['sfreq']
# data = raw.get_data()
# n_samples_win = int(window_size * sfreq)
# bad_channels = []
# for i, ch_name in enumerate(raw.ch_names):
# ch_data = data[i]
# # Create sliding windows (non-overlapping for speed)
# windows = np.array_split(ch_data, len(ch_data) // n_samples_win)
# # Calculate variance for each window
# win_vars = [np.var(w) for w in windows]
# # If the LAST window (or any significant portion at the end) is flat
# if win_vars[-1] < threshold:
# bad_channels.append(ch_name)
# return bad_channels
def find_flatline_at_end(raw, threshold_ratio=0.05):
"""
Compares the variance of the first 25% of the data to the last 25%.
If the end variance is less than 5% of the start variance, mark it bad.
"""
data = raw.get_data()
n_samples = data.shape[1]
quarter = n_samples // 4
bad_channels = []
for i, ch_name in enumerate(raw.ch_names):
start_var = np.var(data[i, :quarter])
end_var = np.var(data[i, -quarter:])
# Avoid division by zero for truly dead channels
if start_var == 0:
bad_channels.append(ch_name)
continue
ratio = end_var / start_var
if ratio < threshold_ratio:
bad_channels.append(ch_name)
print(f"Flagged {ch_name}: Variance dropped to {ratio:.2%} of original.")
return bad_channels
# def plot_with_death_lines(raw, picks, window_sec=5.0, var_threshold=0.05):
# sfreq = raw.info['sfreq']
# data, times = raw.get_data(picks=picks, return_times=True)
# fig, ax = plt.subplots(figsize=(12, 6))
# for i, ch_name in enumerate(picks):
# ch_data = data[i]
# ax.plot(times, ch_data, label=ch_name, alpha=0.8)
# # --- Find the 'Death Point' ---
# # Calculate rolling variance in 5-second chunks
# win_samples = int(window_sec * sfreq)
# initial_var = np.var(ch_data[:win_samples])
# # Check from the end backwards to find where it "died"
# death_time = None
# n_samples = len(ch_data)
# for start_idx in range(n_samples - win_samples - 1, 0, -win_samples):
# current_var = np.var(ch_data[start_idx : start_idx + win_samples])
# if current_var > (initial_var * var_threshold):
# # The point right after this is where it stays dead
# death_time = times[start_idx + win_samples]
# break
# if death_time and death_time < (times[-1] - 10): # Only plot if it died early
# ax.axvline(x=death_time, color='red', linestyle='--', alpha=0.5)
# ax.text(death_time, ax.get_ylim()[1], 'Signal Loss',
# color='red', rotation=90, verticalalignment='top')
# ax.set_title("HbO/HbR with Automated Signal Loss Detection")
# ax.set_ylabel("Concentration (Δ μmol)")
# ax.set_xlabel("Time (s)")
# ax.legend(loc='lower right')
# plt.grid(True, alpha=0.3)
# plt.show(block=True)
# import numpy as np
def master_clean_fnirs(raw, db_limit=-60, threshold_ratio=0.05):
ch_names = raw.ch_names
data = raw.get_data()
sfreq = raw.info['sfreq']
n_samples = data.shape[1]
quarter = n_samples // 4
dead_idx = []
for i in range(len(ch_names)):
start_var = np.var(data[i, :quarter])
end_var = np.var(data[i, -quarter:])
if start_var == 0:
dead_idx.append(i)
continue
ratio = end_var / start_var
if ratio < threshold_ratio:
dead_idx.append(i)
print(f"Flagged {ch_names[i]}: Variance dropped to {ratio:.2%} of original.")
dead_channels = [ch_names[i] for i in dead_idx]
target_freq = sfreq / 4
spectrum = raw.compute_psd(fmin=0.1, fmax=sfreq/2)
psd_data, freqs = spectrum.get_data(return_freqs=True)
f_idx = np.where((freqs >= target_freq - 0.2) & (freqs <= target_freq + 0.2))[0]
power_at_target = np.mean(psd_data[:, f_idx], axis=1)
abs_threshold = 10 ** (db_limit / 10)
noisy_idx = np.where(power_at_target > abs_threshold)[0]
noisy_channels = [ch_names[i] for i in noisy_idx]
# --- The Pairing Logic ---
# Combine indices of specifically failed channels
failed_idx_list = np.unique(np.concatenate([dead_idx, noisy_idx])).astype(int)
# Extract base names (e.g., "S1_D2") of failed channels
# Assuming "S1_D2 762" -> split by space -> "S1_D2"
failed_bases = {str(ch_names[i]).split(' ')[0] for i in failed_idx_list}
# Flag EVERY channel that matches these bases
all_bads = [ch for ch in ch_names if ch.split(' ')[0] in failed_bases]
print(f"Analysis Complete:")
print(f" - Found {len(dead_channels)} dead/low-var channels: {dead_channels}")
print(f" - Found {len(noisy_channels)} high-noise (> {db_limit}dB) channels: {noisy_channels}")
# --- Visualization ---
fig_master, ax = plt.subplots(figsize=(8, 4))
ax.plot(freqs, 10 * np.log10(psd_data.T), color='gray', alpha=0.2)
# Highlight the specifically noisy ones
if len(noisy_idx) > 0:
ax.plot(freqs, 10 * np.log10(psd_data[noisy_idx].T), color='plum', label='Noisy Pairs')
ax.axhline(db_limit, color='red', linestyle='--', label='Threshold')
ax.set_title("Master Clean: PSD Noise & Dead Channel Detection")
ax.set_ylabel("Power (dB)")
ax.set_xlabel("Frequency (Hz)")
ax.legend()
plt.close(fig_master)
return list(set(all_bads)), fig_master
def find_physiologically_impossible_channels(raw, picks, threshold=3.0):
data = raw.get_data(picks=picks)
ranges = np.max(data, axis=1) - np.min(data, axis=1)
# Calculate Z-Scores
median_range = np.median(ranges)
mad = np.median(np.abs(ranges - median_range))
z_scores = 0.6745 * (ranges - median_range) / (mad if mad > 0 else 1e-15)
# Identify failed bases
failed_indices = np.where(z_scores > threshold)[0]
failed_bases = {picks[i].split(' ')[0] for i in failed_indices}
# Flag entire pairs
bad_names = [ch for ch in picks if ch.split(' ')[0] in failed_bases]
# --- Visualization ---
fig_swing, ax = plt.subplots(figsize=(8, 4))
# We color bars by the specific Z-score of that individual channel
colors = ['coral' if z > threshold else 'skyblue' for z in z_scores]
ax.bar(range(len(z_scores)), z_scores, color=colors)
ax.axhline(threshold, color='red', linestyle='--', label='Outlier Threshold')
ax.set_title("Physiological Swing Analysis (Z-Scores)")
ax.set_ylabel("Standardized Deviation")
ax.set_xlabel("Channel Index")
ax.legend()
plt.close(fig_swing)
return bad_names, fig_swing
def hr_calc(raw): def hr_calc(raw):
if SHORT_CHANNEL: if SHORT_CHANNEL:
@@ -3414,7 +3755,7 @@ def hr_calc(raw):
# --- Parameters for PSD --- # --- Parameters for PSD ---
desired_bin_hz = 0.1 desired_bin_hz = 0.1
nperseg = int(sfreq / desired_bin_hz) nperseg = int(sfreq / desired_bin_hz)
hr_range = (30, 180) # TODO: SHould this not use the user defined values? hr_range = (30, 180) # TODO: Should this not use the user defined values?
# --- Function to find strongest local peak --- # --- Function to find strongest local peak ---
def find_hr_from_psd(ch_data): def find_hr_from_psd(ch_data):
@@ -3511,6 +3852,22 @@ def process_participant(file_path, progress_callback=None):
fig_individual["PSD"] = fig fig_individual["PSD"] = fig
fig_individual['HeartRate_PSD'] = hr1 fig_individual['HeartRate_PSD'] = hr1
fig_individual['HeartRate_Time'] = hr2 fig_individual['HeartRate_Time'] = hr2
# --- Run it ---
# mark_bads_by_db_threshold(raw, db_limit=-60)
# dead_channels = find_flat_channels(raw)
# print(f"Dead/Flat channels removed: {dead_channels}")
# stuck_channels, movement_scores = find_truly_dead_channels(raw)
# print(f"Stuck Channels: {stuck_channels}")
# late_stage_bads = find_mid_run_flatlines(raw)
# _ = find_flatline_at_end(raw)
# print(f"Channels that died before the end: {late_stage_bads}")
# picks = [ch for ch in raw.ch_names if 'S1_D2' in ch]
bad_master, fig_master = master_clean_fnirs(raw, db_limit=-60)
bad_swing, fig_swing = find_physiologically_impossible_channels(raw, [ch for ch in raw.ch_names], threshold=4.0)
fig_individual["fig_master"] = fig_master
fig_individual["fig_swing"] = fig_swing
if progress_callback: progress_callback(5) if progress_callback: progress_callback(5)
logger.info("Step 5 Completed.") logger.info("Step 5 Completed.")
@@ -3545,17 +3902,19 @@ def process_participant(file_path, progress_callback=None):
# Step 9: Bad Channels Handling # Step 9: Bad Channels Handling
if BAD_CHANNELS_HANDLING != "None": if BAD_CHANNELS_HANDLING != "None":
raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp) raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing)
if fig_dropped and fig_raw_before is not None: if fig_dropped and fig_raw_before is not None:
fig_individual["fig2"] = fig_dropped fig_individual["fig2"] = fig_dropped
fig_individual["fig3"] = fig_raw_before fig_individual["fig3"] = fig_raw_before
if bad_channels: if bad_channels:
if BAD_CHANNELS_HANDLING == "Interpolate": if BAD_CHANNELS_HANDLING == "Interpolate":
raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, max_dist=MAX_DIST, min_neighbors=MIN_NEIGHBORS) raw, fig_raw_after, fig_compare = interpolate_fNIRS_bads_weighted_average(raw, max_dist=MAX_DIST, min_neighbors=MIN_NEIGHBORS)
fig_individual["fig4"] = fig_raw_after fig_individual["fig4"] = fig_raw_after
fig_individual["Compare"] = fig_compare
elif BAD_CHANNELS_HANDLING == "Remove": elif BAD_CHANNELS_HANDLING == "Remove":
pass #NOTE: testing this
#TODO: Is there more needed here? raw.pick_types(fnirs=True, exclude='bads')
logger.info(f"Physically removed {len(bad_channels)} channels from the dataset.")
if progress_callback: progress_callback(9) if progress_callback: progress_callback(9)
logger.info("Step 9 Completed.") logger.info("Step 9 Completed.")
+45 -14
View File
@@ -181,6 +181,9 @@ SECTIONS = [
# TODO: implement drop # TODO: implement drop
{"name": "EPOCH_HANDLING", "default": ["shift"], "type": list, "options": ["shift", "strict"], "help": "What to do if two unique events occur at the same time. Shift will automatically move one event to the first valid free index. Strict will raise an error processing the file. Drop will remove one of the events."}, {"name": "EPOCH_HANDLING", "default": ["shift"], "type": list, "options": ["shift", "strict"], "help": "What to do if two unique events occur at the same time. Shift will automatically move one event to the first valid free index. Strict will raise an error processing the file. Drop will remove one of the events."},
{"name": "MAX_SHIFT", "default": 5, "type": int, "depends_on": "EPOCH_HANDLING", "depends_value": "shift", "help": "Amount of indexes to look ahead and see if there is a valid one to shift to. If none were found, will fall back to 'strict' behaviour."}, {"name": "MAX_SHIFT", "default": 5, "type": int, "depends_on": "EPOCH_HANDLING", "depends_value": "shift", "help": "Amount of indexes to look ahead and see if there is a valid one to shift to. If none were found, will fall back to 'strict' behaviour."},
{"name": "REJECT_BY_ANNOTATIONS", "default": True, "type": bool, "help": "Help."},
{"name": "MAX_SHIFT", "default": 5, "type": int, "depends_on": "EPOCH_HANDLING", "depends_value": "shift", "help": "Amount of indexes to look ahead and see if there is a valid one to shift to. If none were found, will fall back to 'strict' behaviour."},
{"name": "MAX_SHIFT", "default": 5, "type": int, "depends_on": "EPOCH_HANDLING", "depends_value": "shift", "help": "Amount of indexes to look ahead and see if there is a valid one to shift to. If none were found, will fall back to 'strict' behaviour."},
{"name": "T_MIN", "default": -5, "type": int, "help": "Seconds before the epoch to be used."}, {"name": "T_MIN", "default": -5, "type": int, "help": "Seconds before the epoch to be used."},
{"name": "T_MAX", "default": 15, "type": int, "help": "Seconds after the epoch to be used."}, {"name": "T_MAX", "default": 15, "type": int, "help": "Seconds after the epoch to be used."},
] ]
@@ -6201,26 +6204,54 @@ def _extract_metadata_worker(file_name):
print(f"Worker failed on {file_name}: {e}") print(f"Worker failed on {file_name}: {e}")
return None return None
def run_gui_entry_wrapper(config, gui_queue, progress_queue): # def run_gui_entry_wrapper(config, gui_queue, progress_queue):
""" # """
Where the processing happens # Where the processing happens
""" # """
# try:
# import flares
# flares.gui_entry(config, gui_queue, progress_queue)
# gui_queue.close()
# # gui_queue.join_thread()
# progress_queue.close()
# # progress_queue.join_thread()
# os._exit(0)
# except Exception as e:
# tb_str = traceback.format_exc()
# gui_queue.put({
# "success": False,
# "error": f"Child process crashed: {str(e)}\nTraceback:\n{tb_str}"
# })
# os._exit(1)
def log_to_file(msg):
with open("FLARES_CRITICAL_LOG.txt", "a") as f:
f.write(f"DEBUG: {msg}\n")
def run_gui_entry_wrapper(config, gui_queue, progress_queue):
log_to_file("Wrapper Started")
try: try:
import flares import flares
log_to_file("Flares Imported")
# This is the line where it likely dies
flares.gui_entry(config, gui_queue, progress_queue) flares.gui_entry(config, gui_queue, progress_queue)
gui_queue.close()
#gui_queue.join_thread() log_to_file("gui_entry finished successfully")
sys.exit(0)
gui_queue.cancel_join_thread()
progress_queue.cancel_join_thread()
log_to_file("Attempting OS EXIT")
os._exit(0)
except Exception as e: except Exception as e:
tb_str = traceback.format_exc() err = f"CRASH: {str(e)}\n{traceback.format_exc()}"
gui_queue.put({ log_to_file(err)
"success": False, os._exit(1)
"error": f"Child process crashed: {str(e)}\nTraceback:\n{tb_str}"
})
sys.exit(1)
def resource_path(relative_path): def resource_path(relative_path):
""" """