no more zombie + additional bad channel detection underneath
This commit is contained in:
@@ -99,26 +99,21 @@ from mne_connectivity import envelope_correlation, spectral_connectivity_epochs,
|
||||
# Needs to be set for mne
|
||||
os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore
|
||||
|
||||
# TODO: Tidy this up
|
||||
FIXED_CATEGORY_COLORS = {
|
||||
PRIMARY_COLORS = {
|
||||
"SCI only": "skyblue",
|
||||
"PSP only": "salmon",
|
||||
"SNR only": "lightgreen",
|
||||
"PSP + SCI": "orange",
|
||||
"SCI + SNR": "violet",
|
||||
"PSP + SNR": "gold",
|
||||
"SCI + PSP": "orange",
|
||||
"SNR + SCI": "violet",
|
||||
"SNR + PSP": "gold",
|
||||
"PSP + SNR + SCI": "gray",
|
||||
"SCI + PSP + SNR": "gray",
|
||||
"SCI + SNR + PSP": "gray",
|
||||
"PSP + SCI + SNR": "gray",
|
||||
"PSP + SNR + SCI": "gray",
|
||||
"SNR + SCI + PSP": "gray",
|
||||
"SNR + PSP + SCI": "gray",
|
||||
"Master only": "plum",
|
||||
"Swing only": "coral",
|
||||
}
|
||||
|
||||
# The Fallback for combinations (the "Multifail" Gray)
|
||||
COMBINATION_COLOR = "gray"
|
||||
|
||||
def get_category_color(label):
|
||||
"""Returns the primary color if it's a single failure, otherwise gray."""
|
||||
return PRIMARY_COLORS.get(label, COMBINATION_COLOR)
|
||||
|
||||
|
||||
DOWNSAMPLE: bool
|
||||
DOWNSAMPLE_FREQUENCY: int
|
||||
@@ -434,17 +429,16 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre
|
||||
#audit_log.info(f"--- SESSION START: {len(file_paths)} files ---")
|
||||
|
||||
pending_files = list(file_paths)
|
||||
active_processes = [] # List of tuples: (Process object, file_path)
|
||||
active_processes = []
|
||||
results_by_file = {}
|
||||
|
||||
# We use a manager queue so it handles IPC serialization cleanly
|
||||
manager = mp.Manager()
|
||||
result_queue = manager.Queue()
|
||||
|
||||
# Loop continues as long as there are files to process OR workers still running
|
||||
try:
|
||||
while pending_files or active_processes:
|
||||
|
||||
# 1. SPWAN WORKERS: Only spawn if we are under the limit AND have files left
|
||||
# 1. SPAWN
|
||||
while len(active_processes) < max_workers and pending_files:
|
||||
file_path = pending_files.pop(0)
|
||||
|
||||
@@ -454,39 +448,80 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre
|
||||
)
|
||||
p.start()
|
||||
active_processes.append((p, file_path))
|
||||
#audit_log.info(f"Spawned worker. Active processes: {len(active_processes)}")
|
||||
|
||||
# 2. COLLECT RESULTS: Drain the queue continuously so workers don't deadlock
|
||||
while not result_queue.empty():
|
||||
# 2. COLLECT RESULTS (Inner Loop)
|
||||
while True:
|
||||
try:
|
||||
res_path, result, error = result_queue.get_nowait()
|
||||
# SEND IMMEDIATELY TO THE MAIN GUI
|
||||
|
||||
res_path, result, error = result_queue.get(timeout=0.01)
|
||||
|
||||
for i, (p, f_path) in enumerate(active_processes):
|
||||
if f_path == res_path:
|
||||
p.join(timeout=1) # Clean it up
|
||||
active_processes.pop(i)
|
||||
break
|
||||
|
||||
if not pending_files and not active_processes:
|
||||
break
|
||||
|
||||
if gui_queue:
|
||||
# FIX: Added a timeout. If the GUI stops reading, we don't freeze forever.
|
||||
try:
|
||||
gui_queue.put({
|
||||
"type": "file_done",
|
||||
"file": res_path,
|
||||
"success": error is None,
|
||||
"result": result if error is None else None,
|
||||
"error": error if error else None
|
||||
})
|
||||
}, timeout=2) # 2-second safety valve
|
||||
except Exception as e:
|
||||
print(f"DEBUG: gui_queue full or closed! {e}")
|
||||
else:
|
||||
# Fallback if no GUI queue (e.g., CLI mode)
|
||||
results_by_file[res_path] = result
|
||||
except:
|
||||
break # Break OUT of the inner loop if queue is empty
|
||||
|
||||
except Exception:
|
||||
break # Queue is empty or busy
|
||||
# --- EVERYTHING BELOW IS NOW OUTSIDE THE INNER LOOP ---
|
||||
|
||||
# 3. CLEANUP: Check for finished processes and remove them
|
||||
for p, f_path in active_processes[:]: # Iterate over a slice copy
|
||||
# 3. THE CIRCUIT BREAKER
|
||||
if not pending_files:
|
||||
alive_count = sum(1 for p, _ in active_processes if p.is_alive())
|
||||
if alive_count == 0 and len(active_processes) > 0:
|
||||
print("DEBUG: All workers dead but loop persisting. Breaking.")
|
||||
break
|
||||
|
||||
# 4. CLEANUP DEAD WORKERS
|
||||
for p, f_path in active_processes[:]:
|
||||
if not p.is_alive():
|
||||
p.join() # Formally close the process to free OS resources
|
||||
p.join(timeout=0.1)
|
||||
active_processes.remove((p, f_path))
|
||||
#audit_log.info(f"Worker finished. Active processes dropping to: {len(active_processes)}")
|
||||
|
||||
# Brief pause to prevent this while loop from pegging your CPU to 100%
|
||||
time.sleep(0.5)
|
||||
time.sleep(0.1)
|
||||
|
||||
print("DEBUG: Loop finished naturally.")
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
finally:
|
||||
# 1. Kill any workers
|
||||
for p, _ in active_processes:
|
||||
try:
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
p.join(timeout=0.1)
|
||||
except: pass
|
||||
|
||||
# 2. Clear Proxy
|
||||
try:
|
||||
del result_queue
|
||||
except: pass
|
||||
|
||||
# 3. Force Kill Manager
|
||||
if manager is not None:
|
||||
try:
|
||||
manager.shutdown()
|
||||
except: pass
|
||||
|
||||
#audit_log.info("--- SESSION COMPLETE ---")
|
||||
return results_by_file
|
||||
|
||||
|
||||
@@ -949,7 +984,9 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
|
||||
|
||||
if len(bad_pairs) == 0:
|
||||
print("No bad pairs found. Skipping interpolation.")
|
||||
return raw
|
||||
return raw, None, None
|
||||
|
||||
raw_before_data = raw.get_data().copy()
|
||||
|
||||
# Extract locations (use HbO channel loc as pair location)
|
||||
locs = np.array([raw.info['chs'][hbo_picks[i]]['loc'][:3] for i in range(len(hbo_names))])
|
||||
@@ -990,6 +1027,43 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
|
||||
|
||||
interpolated_pairs.append(bad_base)
|
||||
|
||||
n_bad = len(bad_pairs)
|
||||
n_cols = 4 # Fixed width for horizontal scaling
|
||||
n_rows = int(np.ceil(n_bad / n_cols))
|
||||
|
||||
# Calculate height: 2.5 inches per row is usually enough for readability
|
||||
fig_height = max(4, n_rows * 2.5)
|
||||
fig_compare, axes = plt.subplots(n_rows, n_cols, figsize=(15, fig_height),
|
||||
constrained_layout=True)
|
||||
if n_bad == 1: axes = [axes] # Handle single subplot case
|
||||
|
||||
axes_flat = axes.flatten()
|
||||
for j in range(n_bad, len(axes_flat)):
|
||||
axes_flat[j].axis('off')
|
||||
|
||||
times = raw.times
|
||||
for i, bad_idx in enumerate(bad_pairs):
|
||||
ax = axes_flat[i]
|
||||
base = hbo_names[bad_idx]
|
||||
|
||||
# Plot "Before" (Dirty data) in light gray/red
|
||||
ax.plot(times, raw_before_data[hbo_picks[bad_idx]], color='red', alpha=0.3, label='Original HbO')
|
||||
|
||||
# Plot "After" (Interpolated data) in solid blue/green
|
||||
if base in interpolated_pairs:
|
||||
ax.plot(times, raw._data[hbo_picks[bad_idx]], color='blue', label='Interpolated HbO')
|
||||
status = "SUCCESS"
|
||||
color = "green"
|
||||
else:
|
||||
status = "FAILED (Isolated)"
|
||||
color = "red"
|
||||
|
||||
ax.set_title(f"Channel {base} | Status: {status}", color=color, fontweight='bold')
|
||||
ax.set_ylabel("Amplitude")
|
||||
if i == 0: ax.legend(loc='upper right')
|
||||
|
||||
plt.close(fig_compare)
|
||||
|
||||
if interpolated_pairs:
|
||||
bad_ch_to_remove = []
|
||||
for base_ in interpolated_pairs:
|
||||
@@ -1007,7 +1081,7 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
|
||||
|
||||
fig_raw_after = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After interpolation", show=False)
|
||||
|
||||
return raw, fig_raw_after
|
||||
return raw, fig_raw_after, fig_compare
|
||||
|
||||
|
||||
|
||||
@@ -1114,8 +1188,9 @@ def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5
|
||||
return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2
|
||||
|
||||
|
||||
def mark_bads(raw, bad_sci, bad_snr, bad_psp):
|
||||
bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp))
|
||||
def mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing):
|
||||
print(bad_sci, bad_master, bad_psp, bad_snr, bad_swing)
|
||||
bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp) | set(bad_master) | set(bad_swing))
|
||||
print(f"Automatically marked bad channels based on SNR and SCI: {bads_combined}")
|
||||
|
||||
raw.info['bads'].extend(bads_combined)
|
||||
@@ -1125,6 +1200,8 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp):
|
||||
(bad_sci, "SCI"),
|
||||
(bad_psp, "PSP"),
|
||||
(bad_snr, "SNR"),
|
||||
(bad_master, "Master"),
|
||||
(bad_swing, "Swing"),
|
||||
]
|
||||
|
||||
# Graph what channels were dropped and why they were dropped
|
||||
@@ -1148,8 +1225,7 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp):
|
||||
channel_names.extend(chs_in_cat)
|
||||
category_labels.extend([cat] * len(chs_in_cat))
|
||||
|
||||
colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories}
|
||||
|
||||
colors = {cat: get_category_color(cat) for cat in categories}
|
||||
# Create the figure
|
||||
fig_dropped, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3))) # type: ignore
|
||||
y_pos = range(len(channel_names))
|
||||
@@ -1879,12 +1955,12 @@ def individual_significance(raw_haemo, glm_est):
|
||||
# Merge with mean theta (optional for plotting)
|
||||
mean_theta = activity_ch_summary.groupby('ch_name')['theta'].mean().reset_index()
|
||||
sig_channels = sig_channels.merge(mean_theta, on='ch_name')
|
||||
print(sig_channels)
|
||||
# print(sig_channels)
|
||||
|
||||
|
||||
# For example, take the minimum corrected p-value per channel
|
||||
summary_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index()
|
||||
print(summary_pvals)
|
||||
# print(summary_pvals)
|
||||
|
||||
|
||||
def parse_ch_name(ch_name):
|
||||
@@ -1916,9 +1992,10 @@ def individual_significance(raw_haemo, glm_est):
|
||||
SOURCE_DETECTOR_SEPARATOR = "_"
|
||||
|
||||
t_or_theta = 'theta'
|
||||
for _, row in avg_df.iterrows(): # type: ignore
|
||||
print(f"Source {row['Source']} <-> Detector {row['Detector']}: "
|
||||
f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}")
|
||||
#holy log noise
|
||||
# for _, row in avg_df.iterrows(): # type: ignore
|
||||
# print(f"Source {row['Source']} <-> Detector {row['Detector']}: "
|
||||
# f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}")
|
||||
|
||||
# Extract the cource and detector positions from raw
|
||||
src_pos: dict[int, tuple[float, float]] = {}
|
||||
@@ -3384,6 +3461,270 @@ def plot_heart_rate(
|
||||
return fig1, fig2
|
||||
|
||||
|
||||
# import numpy as np
|
||||
|
||||
# def mark_bads_by_db_threshold(raw, db_limit=-60):
|
||||
# """
|
||||
# Converts a dB threshold to absolute power and marks channels
|
||||
# exceeding it at 6.25 Hz (1/2 Nyquist).
|
||||
# """
|
||||
# # 1. Convert your "Magic Number" from dB to absolute units
|
||||
# abs_threshold = 10 ** (db_limit / 10)
|
||||
|
||||
# # 2. Compute PSD specifically around the frequency of interest
|
||||
# sfreq = raw.info['sfreq']
|
||||
# target_freq = (sfreq / 2) - 1 # 6.25 Hz
|
||||
|
||||
# # We use a small window around 6.25 Hz to get a stable average
|
||||
# spectrum = raw.compute_psd(fmin=target_freq - 0.5, fmax=target_freq + 0.5)
|
||||
# psd_data, freqs = spectrum.get_data(return_freqs=True)
|
||||
|
||||
# # 3. Get the power at that frequency (averaging across the small window)
|
||||
# # psd_data shape: (n_channels, n_freqs)
|
||||
# power_at_6Hz = np.mean(psd_data, axis=1)
|
||||
|
||||
# # 4. Identify the "loud" channels
|
||||
# bad_indices = np.where(power_at_6Hz > abs_threshold)[0]
|
||||
# new_bads = [raw.ch_names[i] for i in bad_indices]
|
||||
|
||||
# # 5. Update raw.info['bads'] without duplicates
|
||||
# # raw.info['bads'] = list(set(raw.info['bads'] + new_bads))
|
||||
|
||||
# print(f"Threshold: {db_limit} dB -> {abs_threshold:.2e} V^2/Hz")
|
||||
# print(f"Newly potentially bad channels: {new_bads}")
|
||||
|
||||
# return new_bads
|
||||
|
||||
|
||||
|
||||
# def find_flat_channels(raw, threshold=1e-15):
|
||||
# """
|
||||
# Identifies channels that are essentially flat lines (zero variance).
|
||||
# """
|
||||
# data = raw.get_data()
|
||||
# # Calculate standard deviation of each channel
|
||||
# std_devs = np.std(data, axis=1)
|
||||
|
||||
# # Identify channels with almost zero movement
|
||||
# flat_idx = np.where(std_devs < threshold)[0]
|
||||
# flat_names = [raw.ch_names[i] for i in flat_idx]
|
||||
|
||||
# return flat_names
|
||||
|
||||
|
||||
# def find_truly_dead_channels(raw, threshold=1e-20):
|
||||
# """
|
||||
# Looks for channels where the signal literally doesn't move
|
||||
# between samples (successive differences are zero).
|
||||
# """
|
||||
# data = raw.get_data()
|
||||
# # Calculate the absolute difference between every sample (t and t+1)
|
||||
# # diffs shape: (n_channels, n_times - 1)
|
||||
# diffs = np.abs(np.diff(data, axis=1))
|
||||
|
||||
# # Calculate the mean 'step' size for each channel
|
||||
# mean_diffs = np.mean(diffs, axis=1)
|
||||
|
||||
# # Identify channels where the signal is 'stuck'
|
||||
# stuck_idx = np.where(mean_diffs < threshold)[0]
|
||||
# stuck_names = [raw.ch_names[i] for i in stuck_idx]
|
||||
|
||||
# return stuck_names, mean_diffs
|
||||
|
||||
|
||||
# def find_mid_run_flatlines(raw, window_size=10.0, threshold=1e-12):
|
||||
# """
|
||||
# Checks for channels that go flat partway through the recording.
|
||||
# window_size: Duration in seconds to check for flatness.
|
||||
# """
|
||||
# sfreq = raw.info['sfreq']
|
||||
# data = raw.get_data()
|
||||
# n_samples_win = int(window_size * sfreq)
|
||||
|
||||
# bad_channels = []
|
||||
|
||||
# for i, ch_name in enumerate(raw.ch_names):
|
||||
# ch_data = data[i]
|
||||
# # Create sliding windows (non-overlapping for speed)
|
||||
# windows = np.array_split(ch_data, len(ch_data) // n_samples_win)
|
||||
|
||||
# # Calculate variance for each window
|
||||
# win_vars = [np.var(w) for w in windows]
|
||||
|
||||
# # If the LAST window (or any significant portion at the end) is flat
|
||||
# if win_vars[-1] < threshold:
|
||||
# bad_channels.append(ch_name)
|
||||
|
||||
# return bad_channels
|
||||
|
||||
|
||||
def find_flatline_at_end(raw, threshold_ratio=0.05):
|
||||
"""
|
||||
Compares the variance of the first 25% of the data to the last 25%.
|
||||
If the end variance is less than 5% of the start variance, mark it bad.
|
||||
"""
|
||||
data = raw.get_data()
|
||||
n_samples = data.shape[1]
|
||||
quarter = n_samples // 4
|
||||
|
||||
bad_channels = []
|
||||
|
||||
for i, ch_name in enumerate(raw.ch_names):
|
||||
start_var = np.var(data[i, :quarter])
|
||||
end_var = np.var(data[i, -quarter:])
|
||||
|
||||
# Avoid division by zero for truly dead channels
|
||||
if start_var == 0:
|
||||
bad_channels.append(ch_name)
|
||||
continue
|
||||
|
||||
ratio = end_var / start_var
|
||||
|
||||
if ratio < threshold_ratio:
|
||||
bad_channels.append(ch_name)
|
||||
print(f"Flagged {ch_name}: Variance dropped to {ratio:.2%} of original.")
|
||||
|
||||
return bad_channels
|
||||
|
||||
|
||||
# def plot_with_death_lines(raw, picks, window_sec=5.0, var_threshold=0.05):
|
||||
# sfreq = raw.info['sfreq']
|
||||
# data, times = raw.get_data(picks=picks, return_times=True)
|
||||
|
||||
# fig, ax = plt.subplots(figsize=(12, 6))
|
||||
|
||||
# for i, ch_name in enumerate(picks):
|
||||
# ch_data = data[i]
|
||||
# ax.plot(times, ch_data, label=ch_name, alpha=0.8)
|
||||
|
||||
# # --- Find the 'Death Point' ---
|
||||
# # Calculate rolling variance in 5-second chunks
|
||||
# win_samples = int(window_sec * sfreq)
|
||||
# initial_var = np.var(ch_data[:win_samples])
|
||||
|
||||
# # Check from the end backwards to find where it "died"
|
||||
# death_time = None
|
||||
# n_samples = len(ch_data)
|
||||
# for start_idx in range(n_samples - win_samples - 1, 0, -win_samples):
|
||||
# current_var = np.var(ch_data[start_idx : start_idx + win_samples])
|
||||
# if current_var > (initial_var * var_threshold):
|
||||
# # The point right after this is where it stays dead
|
||||
# death_time = times[start_idx + win_samples]
|
||||
# break
|
||||
|
||||
# if death_time and death_time < (times[-1] - 10): # Only plot if it died early
|
||||
# ax.axvline(x=death_time, color='red', linestyle='--', alpha=0.5)
|
||||
# ax.text(death_time, ax.get_ylim()[1], 'Signal Loss',
|
||||
# color='red', rotation=90, verticalalignment='top')
|
||||
|
||||
# ax.set_title("HbO/HbR with Automated Signal Loss Detection")
|
||||
# ax.set_ylabel("Concentration (Δ μmol)")
|
||||
# ax.set_xlabel("Time (s)")
|
||||
# ax.legend(loc='lower right')
|
||||
# plt.grid(True, alpha=0.3)
|
||||
# plt.show(block=True)
|
||||
|
||||
|
||||
# import numpy as np
|
||||
|
||||
def master_clean_fnirs(raw, db_limit=-60, threshold_ratio=0.05):
|
||||
ch_names = raw.ch_names
|
||||
data = raw.get_data()
|
||||
sfreq = raw.info['sfreq']
|
||||
|
||||
n_samples = data.shape[1]
|
||||
quarter = n_samples // 4
|
||||
|
||||
dead_idx = []
|
||||
for i in range(len(ch_names)):
|
||||
start_var = np.var(data[i, :quarter])
|
||||
end_var = np.var(data[i, -quarter:])
|
||||
|
||||
if start_var == 0:
|
||||
dead_idx.append(i)
|
||||
continue
|
||||
|
||||
ratio = end_var / start_var
|
||||
if ratio < threshold_ratio:
|
||||
dead_idx.append(i)
|
||||
print(f"Flagged {ch_names[i]}: Variance dropped to {ratio:.2%} of original.")
|
||||
|
||||
dead_channels = [ch_names[i] for i in dead_idx]
|
||||
|
||||
target_freq = sfreq / 4
|
||||
spectrum = raw.compute_psd(fmin=0.1, fmax=sfreq/2)
|
||||
psd_data, freqs = spectrum.get_data(return_freqs=True)
|
||||
f_idx = np.where((freqs >= target_freq - 0.2) & (freqs <= target_freq + 0.2))[0]
|
||||
power_at_target = np.mean(psd_data[:, f_idx], axis=1)
|
||||
|
||||
abs_threshold = 10 ** (db_limit / 10)
|
||||
noisy_idx = np.where(power_at_target > abs_threshold)[0]
|
||||
noisy_channels = [ch_names[i] for i in noisy_idx]
|
||||
|
||||
# --- The Pairing Logic ---
|
||||
# Combine indices of specifically failed channels
|
||||
failed_idx_list = np.unique(np.concatenate([dead_idx, noisy_idx])).astype(int)
|
||||
|
||||
# Extract base names (e.g., "S1_D2") of failed channels
|
||||
# Assuming "S1_D2 762" -> split by space -> "S1_D2"
|
||||
failed_bases = {str(ch_names[i]).split(' ')[0] for i in failed_idx_list}
|
||||
|
||||
# Flag EVERY channel that matches these bases
|
||||
all_bads = [ch for ch in ch_names if ch.split(' ')[0] in failed_bases]
|
||||
|
||||
print(f"Analysis Complete:")
|
||||
print(f" - Found {len(dead_channels)} dead/low-var channels: {dead_channels}")
|
||||
print(f" - Found {len(noisy_channels)} high-noise (> {db_limit}dB) channels: {noisy_channels}")
|
||||
|
||||
# --- Visualization ---
|
||||
fig_master, ax = plt.subplots(figsize=(8, 4))
|
||||
ax.plot(freqs, 10 * np.log10(psd_data.T), color='gray', alpha=0.2)
|
||||
|
||||
# Highlight the specifically noisy ones
|
||||
if len(noisy_idx) > 0:
|
||||
ax.plot(freqs, 10 * np.log10(psd_data[noisy_idx].T), color='plum', label='Noisy Pairs')
|
||||
|
||||
ax.axhline(db_limit, color='red', linestyle='--', label='Threshold')
|
||||
ax.set_title("Master Clean: PSD Noise & Dead Channel Detection")
|
||||
ax.set_ylabel("Power (dB)")
|
||||
ax.set_xlabel("Frequency (Hz)")
|
||||
ax.legend()
|
||||
plt.close(fig_master)
|
||||
|
||||
return list(set(all_bads)), fig_master
|
||||
|
||||
|
||||
def find_physiologically_impossible_channels(raw, picks, threshold=3.0):
|
||||
data = raw.get_data(picks=picks)
|
||||
ranges = np.max(data, axis=1) - np.min(data, axis=1)
|
||||
|
||||
# Calculate Z-Scores
|
||||
median_range = np.median(ranges)
|
||||
mad = np.median(np.abs(ranges - median_range))
|
||||
z_scores = 0.6745 * (ranges - median_range) / (mad if mad > 0 else 1e-15)
|
||||
|
||||
# Identify failed bases
|
||||
failed_indices = np.where(z_scores > threshold)[0]
|
||||
failed_bases = {picks[i].split(' ')[0] for i in failed_indices}
|
||||
|
||||
# Flag entire pairs
|
||||
bad_names = [ch for ch in picks if ch.split(' ')[0] in failed_bases]
|
||||
|
||||
# --- Visualization ---
|
||||
fig_swing, ax = plt.subplots(figsize=(8, 4))
|
||||
# We color bars by the specific Z-score of that individual channel
|
||||
colors = ['coral' if z > threshold else 'skyblue' for z in z_scores]
|
||||
|
||||
ax.bar(range(len(z_scores)), z_scores, color=colors)
|
||||
ax.axhline(threshold, color='red', linestyle='--', label='Outlier Threshold')
|
||||
ax.set_title("Physiological Swing Analysis (Z-Scores)")
|
||||
ax.set_ylabel("Standardized Deviation")
|
||||
ax.set_xlabel("Channel Index")
|
||||
ax.legend()
|
||||
plt.close(fig_swing)
|
||||
|
||||
return bad_names, fig_swing
|
||||
|
||||
|
||||
def hr_calc(raw):
|
||||
if SHORT_CHANNEL:
|
||||
@@ -3414,7 +3755,7 @@ def hr_calc(raw):
|
||||
# --- Parameters for PSD ---
|
||||
desired_bin_hz = 0.1
|
||||
nperseg = int(sfreq / desired_bin_hz)
|
||||
hr_range = (30, 180) # TODO: SHould this not use the user defined values?
|
||||
hr_range = (30, 180) # TODO: Should this not use the user defined values?
|
||||
|
||||
# --- Function to find strongest local peak ---
|
||||
def find_hr_from_psd(ch_data):
|
||||
@@ -3511,6 +3852,22 @@ def process_participant(file_path, progress_callback=None):
|
||||
fig_individual["PSD"] = fig
|
||||
fig_individual['HeartRate_PSD'] = hr1
|
||||
fig_individual['HeartRate_Time'] = hr2
|
||||
|
||||
# --- Run it ---
|
||||
# mark_bads_by_db_threshold(raw, db_limit=-60)
|
||||
# dead_channels = find_flat_channels(raw)
|
||||
# print(f"Dead/Flat channels removed: {dead_channels}")
|
||||
# stuck_channels, movement_scores = find_truly_dead_channels(raw)
|
||||
# print(f"Stuck Channels: {stuck_channels}")
|
||||
# late_stage_bads = find_mid_run_flatlines(raw)
|
||||
# _ = find_flatline_at_end(raw)
|
||||
# print(f"Channels that died before the end: {late_stage_bads}")
|
||||
# picks = [ch for ch in raw.ch_names if 'S1_D2' in ch]
|
||||
bad_master, fig_master = master_clean_fnirs(raw, db_limit=-60)
|
||||
bad_swing, fig_swing = find_physiologically_impossible_channels(raw, [ch for ch in raw.ch_names], threshold=4.0)
|
||||
fig_individual["fig_master"] = fig_master
|
||||
fig_individual["fig_swing"] = fig_swing
|
||||
|
||||
if progress_callback: progress_callback(5)
|
||||
logger.info("Step 5 Completed.")
|
||||
|
||||
@@ -3545,17 +3902,19 @@ def process_participant(file_path, progress_callback=None):
|
||||
|
||||
# Step 9: Bad Channels Handling
|
||||
if BAD_CHANNELS_HANDLING != "None":
|
||||
raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp)
|
||||
raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing)
|
||||
if fig_dropped and fig_raw_before is not None:
|
||||
fig_individual["fig2"] = fig_dropped
|
||||
fig_individual["fig3"] = fig_raw_before
|
||||
if bad_channels:
|
||||
if BAD_CHANNELS_HANDLING == "Interpolate":
|
||||
raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, max_dist=MAX_DIST, min_neighbors=MIN_NEIGHBORS)
|
||||
raw, fig_raw_after, fig_compare = interpolate_fNIRS_bads_weighted_average(raw, max_dist=MAX_DIST, min_neighbors=MIN_NEIGHBORS)
|
||||
fig_individual["fig4"] = fig_raw_after
|
||||
fig_individual["Compare"] = fig_compare
|
||||
elif BAD_CHANNELS_HANDLING == "Remove":
|
||||
pass
|
||||
#TODO: Is there more needed here?
|
||||
#NOTE: testing this
|
||||
raw.pick_types(fnirs=True, exclude='bads')
|
||||
logger.info(f"Physically removed {len(bad_channels)} channels from the dataset.")
|
||||
|
||||
if progress_callback: progress_callback(9)
|
||||
logger.info("Step 9 Completed.")
|
||||
|
||||
@@ -181,6 +181,9 @@ SECTIONS = [
|
||||
# TODO: implement drop
|
||||
{"name": "EPOCH_HANDLING", "default": ["shift"], "type": list, "options": ["shift", "strict"], "help": "What to do if two unique events occur at the same time. Shift will automatically move one event to the first valid free index. Strict will raise an error processing the file. Drop will remove one of the events."},
|
||||
{"name": "MAX_SHIFT", "default": 5, "type": int, "depends_on": "EPOCH_HANDLING", "depends_value": "shift", "help": "Amount of indexes to look ahead and see if there is a valid one to shift to. If none were found, will fall back to 'strict' behaviour."},
|
||||
{"name": "REJECT_BY_ANNOTATIONS", "default": True, "type": bool, "help": "Help."},
|
||||
{"name": "MAX_SHIFT", "default": 5, "type": int, "depends_on": "EPOCH_HANDLING", "depends_value": "shift", "help": "Amount of indexes to look ahead and see if there is a valid one to shift to. If none were found, will fall back to 'strict' behaviour."},
|
||||
{"name": "MAX_SHIFT", "default": 5, "type": int, "depends_on": "EPOCH_HANDLING", "depends_value": "shift", "help": "Amount of indexes to look ahead and see if there is a valid one to shift to. If none were found, will fall back to 'strict' behaviour."},
|
||||
{"name": "T_MIN", "default": -5, "type": int, "help": "Seconds before the epoch to be used."},
|
||||
{"name": "T_MAX", "default": 15, "type": int, "help": "Seconds after the epoch to be used."},
|
||||
]
|
||||
@@ -6201,26 +6204,54 @@ def _extract_metadata_worker(file_name):
|
||||
print(f"Worker failed on {file_name}: {e}")
|
||||
return None
|
||||
|
||||
def run_gui_entry_wrapper(config, gui_queue, progress_queue):
|
||||
"""
|
||||
Where the processing happens
|
||||
"""
|
||||
# def run_gui_entry_wrapper(config, gui_queue, progress_queue):
|
||||
# """
|
||||
# Where the processing happens
|
||||
# """
|
||||
|
||||
# try:
|
||||
# import flares
|
||||
# flares.gui_entry(config, gui_queue, progress_queue)
|
||||
# gui_queue.close()
|
||||
# # gui_queue.join_thread()
|
||||
# progress_queue.close()
|
||||
# # progress_queue.join_thread()
|
||||
# os._exit(0)
|
||||
|
||||
# except Exception as e:
|
||||
# tb_str = traceback.format_exc()
|
||||
# gui_queue.put({
|
||||
# "success": False,
|
||||
# "error": f"Child process crashed: {str(e)}\nTraceback:\n{tb_str}"
|
||||
# })
|
||||
# os._exit(1)
|
||||
|
||||
|
||||
def log_to_file(msg):
|
||||
with open("FLARES_CRITICAL_LOG.txt", "a") as f:
|
||||
f.write(f"DEBUG: {msg}\n")
|
||||
|
||||
def run_gui_entry_wrapper(config, gui_queue, progress_queue):
|
||||
log_to_file("Wrapper Started")
|
||||
try:
|
||||
import flares
|
||||
log_to_file("Flares Imported")
|
||||
|
||||
# This is the line where it likely dies
|
||||
flares.gui_entry(config, gui_queue, progress_queue)
|
||||
gui_queue.close()
|
||||
#gui_queue.join_thread()
|
||||
sys.exit(0)
|
||||
|
||||
log_to_file("gui_entry finished successfully")
|
||||
|
||||
gui_queue.cancel_join_thread()
|
||||
progress_queue.cancel_join_thread()
|
||||
|
||||
log_to_file("Attempting OS EXIT")
|
||||
os._exit(0)
|
||||
|
||||
except Exception as e:
|
||||
tb_str = traceback.format_exc()
|
||||
gui_queue.put({
|
||||
"success": False,
|
||||
"error": f"Child process crashed: {str(e)}\nTraceback:\n{tb_str}"
|
||||
})
|
||||
sys.exit(1)
|
||||
|
||||
err = f"CRASH: {str(e)}\n{traceback.format_exc()}"
|
||||
log_to_file(err)
|
||||
os._exit(1)
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user