no more zombie + additional bad channel detection underneath
This commit is contained in:
@@ -99,26 +99,21 @@ from mne_connectivity import envelope_correlation, spectral_connectivity_epochs,
|
||||
# Needs to be set for mne
|
||||
os.environ["SUBJECTS_DIR"] = str(data_path()) + "/subjects" # type: ignore
|
||||
|
||||
# TODO: Tidy this up
|
||||
FIXED_CATEGORY_COLORS = {
|
||||
PRIMARY_COLORS = {
|
||||
"SCI only": "skyblue",
|
||||
"PSP only": "salmon",
|
||||
"SNR only": "lightgreen",
|
||||
"PSP + SCI": "orange",
|
||||
"SCI + SNR": "violet",
|
||||
"PSP + SNR": "gold",
|
||||
"SCI + PSP": "orange",
|
||||
"SNR + SCI": "violet",
|
||||
"SNR + PSP": "gold",
|
||||
"PSP + SNR + SCI": "gray",
|
||||
"SCI + PSP + SNR": "gray",
|
||||
"SCI + SNR + PSP": "gray",
|
||||
"PSP + SCI + SNR": "gray",
|
||||
"PSP + SNR + SCI": "gray",
|
||||
"SNR + SCI + PSP": "gray",
|
||||
"SNR + PSP + SCI": "gray",
|
||||
"Master only": "plum",
|
||||
"Swing only": "coral",
|
||||
}
|
||||
|
||||
# The Fallback for combinations (the "Multifail" Gray)
|
||||
COMBINATION_COLOR = "gray"
|
||||
|
||||
def get_category_color(label):
|
||||
"""Returns the primary color if it's a single failure, otherwise gray."""
|
||||
return PRIMARY_COLORS.get(label, COMBINATION_COLOR)
|
||||
|
||||
|
||||
DOWNSAMPLE: bool
|
||||
DOWNSAMPLE_FREQUENCY: int
|
||||
@@ -434,59 +429,99 @@ def process_multiple_participants(file_paths, file_params, file_metadata, progre
|
||||
#audit_log.info(f"--- SESSION START: {len(file_paths)} files ---")
|
||||
|
||||
pending_files = list(file_paths)
|
||||
active_processes = [] # List of tuples: (Process object, file_path)
|
||||
active_processes = []
|
||||
results_by_file = {}
|
||||
|
||||
# We use a manager queue so it handles IPC serialization cleanly
|
||||
manager = mp.Manager()
|
||||
result_queue = manager.Queue()
|
||||
|
||||
# Loop continues as long as there are files to process OR workers still running
|
||||
while pending_files or active_processes:
|
||||
try:
|
||||
while pending_files or active_processes:
|
||||
# 1. SPAWN
|
||||
while len(active_processes) < max_workers and pending_files:
|
||||
file_path = pending_files.pop(0)
|
||||
|
||||
p = mp.Process(
|
||||
target=process_participant_worker,
|
||||
args=(file_path, file_params, file_metadata, result_queue, progress_queue)
|
||||
)
|
||||
p.start()
|
||||
active_processes.append((p, file_path))
|
||||
|
||||
# 2. COLLECT RESULTS (Inner Loop)
|
||||
while True:
|
||||
try:
|
||||
|
||||
res_path, result, error = result_queue.get(timeout=0.01)
|
||||
|
||||
for i, (p, f_path) in enumerate(active_processes):
|
||||
if f_path == res_path:
|
||||
p.join(timeout=1) # Clean it up
|
||||
active_processes.pop(i)
|
||||
break
|
||||
|
||||
if not pending_files and not active_processes:
|
||||
break
|
||||
|
||||
if gui_queue:
|
||||
# FIX: Added a timeout. If the GUI stops reading, we don't freeze forever.
|
||||
try:
|
||||
gui_queue.put({
|
||||
"type": "file_done",
|
||||
"file": res_path,
|
||||
"success": error is None,
|
||||
"result": result if error is None else None,
|
||||
"error": error if error else None
|
||||
}, timeout=2) # 2-second safety valve
|
||||
except Exception as e:
|
||||
print(f"DEBUG: gui_queue full or closed! {e}")
|
||||
else:
|
||||
results_by_file[res_path] = result
|
||||
except:
|
||||
break # Break OUT of the inner loop if queue is empty
|
||||
|
||||
# --- EVERYTHING BELOW IS NOW OUTSIDE THE INNER LOOP ---
|
||||
|
||||
# 3. THE CIRCUIT BREAKER
|
||||
if not pending_files:
|
||||
alive_count = sum(1 for p, _ in active_processes if p.is_alive())
|
||||
if alive_count == 0 and len(active_processes) > 0:
|
||||
print("DEBUG: All workers dead but loop persisting. Breaking.")
|
||||
break
|
||||
|
||||
# 4. CLEANUP DEAD WORKERS
|
||||
for p, f_path in active_processes[:]:
|
||||
if not p.is_alive():
|
||||
p.join(timeout=0.1)
|
||||
active_processes.remove((p, f_path))
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
print("DEBUG: Loop finished naturally.")
|
||||
|
||||
# 1. SPWAN WORKERS: Only spawn if we are under the limit AND have files left
|
||||
while len(active_processes) < max_workers and pending_files:
|
||||
file_path = pending_files.pop(0)
|
||||
|
||||
p = mp.Process(
|
||||
target=process_participant_worker,
|
||||
args=(file_path, file_params, file_metadata, result_queue, progress_queue)
|
||||
)
|
||||
p.start()
|
||||
active_processes.append((p, file_path))
|
||||
#audit_log.info(f"Spawned worker. Active processes: {len(active_processes)}")
|
||||
|
||||
# 2. COLLECT RESULTS: Drain the queue continuously so workers don't deadlock
|
||||
while not result_queue.empty():
|
||||
except Exception as e:
|
||||
print(e)
|
||||
finally:
|
||||
# 1. Kill any workers
|
||||
for p, _ in active_processes:
|
||||
try:
|
||||
res_path, result, error = result_queue.get_nowait()
|
||||
# SEND IMMEDIATELY TO THE MAIN GUI
|
||||
if gui_queue:
|
||||
gui_queue.put({
|
||||
"type": "file_done",
|
||||
"file": res_path,
|
||||
"success": error is None,
|
||||
"result": result if error is None else None,
|
||||
"error": error if error else None
|
||||
})
|
||||
else:
|
||||
# Fallback if no GUI queue (e.g., CLI mode)
|
||||
results_by_file[res_path] = result
|
||||
|
||||
except Exception:
|
||||
break # Queue is empty or busy
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
p.join(timeout=0.1)
|
||||
except: pass
|
||||
|
||||
# 3. CLEANUP: Check for finished processes and remove them
|
||||
for p, f_path in active_processes[:]: # Iterate over a slice copy
|
||||
if not p.is_alive():
|
||||
p.join() # Formally close the process to free OS resources
|
||||
active_processes.remove((p, f_path))
|
||||
#audit_log.info(f"Worker finished. Active processes dropping to: {len(active_processes)}")
|
||||
# 2. Clear Proxy
|
||||
try:
|
||||
del result_queue
|
||||
except: pass
|
||||
|
||||
# 3. Force Kill Manager
|
||||
if manager is not None:
|
||||
try:
|
||||
manager.shutdown()
|
||||
except: pass
|
||||
|
||||
# Brief pause to prevent this while loop from pegging your CPU to 100%
|
||||
time.sleep(0.5)
|
||||
|
||||
#audit_log.info("--- SESSION COMPLETE ---")
|
||||
return results_by_file
|
||||
|
||||
|
||||
@@ -949,7 +984,9 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
|
||||
|
||||
if len(bad_pairs) == 0:
|
||||
print("No bad pairs found. Skipping interpolation.")
|
||||
return raw
|
||||
return raw, None, None
|
||||
|
||||
raw_before_data = raw.get_data().copy()
|
||||
|
||||
# Extract locations (use HbO channel loc as pair location)
|
||||
locs = np.array([raw.info['chs'][hbo_picks[i]]['loc'][:3] for i in range(len(hbo_names))])
|
||||
@@ -990,6 +1027,43 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
|
||||
|
||||
interpolated_pairs.append(bad_base)
|
||||
|
||||
n_bad = len(bad_pairs)
|
||||
n_cols = 4 # Fixed width for horizontal scaling
|
||||
n_rows = int(np.ceil(n_bad / n_cols))
|
||||
|
||||
# Calculate height: 2.5 inches per row is usually enough for readability
|
||||
fig_height = max(4, n_rows * 2.5)
|
||||
fig_compare, axes = plt.subplots(n_rows, n_cols, figsize=(15, fig_height),
|
||||
constrained_layout=True)
|
||||
if n_bad == 1: axes = [axes] # Handle single subplot case
|
||||
|
||||
axes_flat = axes.flatten()
|
||||
for j in range(n_bad, len(axes_flat)):
|
||||
axes_flat[j].axis('off')
|
||||
|
||||
times = raw.times
|
||||
for i, bad_idx in enumerate(bad_pairs):
|
||||
ax = axes_flat[i]
|
||||
base = hbo_names[bad_idx]
|
||||
|
||||
# Plot "Before" (Dirty data) in light gray/red
|
||||
ax.plot(times, raw_before_data[hbo_picks[bad_idx]], color='red', alpha=0.3, label='Original HbO')
|
||||
|
||||
# Plot "After" (Interpolated data) in solid blue/green
|
||||
if base in interpolated_pairs:
|
||||
ax.plot(times, raw._data[hbo_picks[bad_idx]], color='blue', label='Interpolated HbO')
|
||||
status = "SUCCESS"
|
||||
color = "green"
|
||||
else:
|
||||
status = "FAILED (Isolated)"
|
||||
color = "red"
|
||||
|
||||
ax.set_title(f"Channel {base} | Status: {status}", color=color, fontweight='bold')
|
||||
ax.set_ylabel("Amplitude")
|
||||
if i == 0: ax.legend(loc='upper right')
|
||||
|
||||
plt.close(fig_compare)
|
||||
|
||||
if interpolated_pairs:
|
||||
bad_ch_to_remove = []
|
||||
for base_ in interpolated_pairs:
|
||||
@@ -1007,7 +1081,7 @@ def interpolate_fNIRS_bads_weighted_average(raw, max_dist=0.03, min_neighbors=2)
|
||||
|
||||
fig_raw_after = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After interpolation", show=False)
|
||||
|
||||
return raw, fig_raw_after
|
||||
return raw, fig_raw_after, fig_compare
|
||||
|
||||
|
||||
|
||||
@@ -1114,8 +1188,9 @@ def calculate_peak_power(data: BaseRaw, l_freq: float = 0.7, h_freq: float = 1.5
|
||||
return list(compress(cast(list[str], getattr(data, "ch_names")), psp < PSP_THRESHOLD)), psp1, psp2
|
||||
|
||||
|
||||
def mark_bads(raw, bad_sci, bad_snr, bad_psp):
|
||||
bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp))
|
||||
def mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing):
|
||||
print(bad_sci, bad_master, bad_psp, bad_snr, bad_swing)
|
||||
bads_combined = list(set(bad_snr) | set(bad_sci) | set(bad_psp) | set(bad_master) | set(bad_swing))
|
||||
print(f"Automatically marked bad channels based on SNR and SCI: {bads_combined}")
|
||||
|
||||
raw.info['bads'].extend(bads_combined)
|
||||
@@ -1125,6 +1200,8 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp):
|
||||
(bad_sci, "SCI"),
|
||||
(bad_psp, "PSP"),
|
||||
(bad_snr, "SNR"),
|
||||
(bad_master, "Master"),
|
||||
(bad_swing, "Swing"),
|
||||
]
|
||||
|
||||
# Graph what channels were dropped and why they were dropped
|
||||
@@ -1148,8 +1225,7 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp):
|
||||
channel_names.extend(chs_in_cat)
|
||||
category_labels.extend([cat] * len(chs_in_cat))
|
||||
|
||||
colors = {cat: FIXED_CATEGORY_COLORS[cat] for cat in categories}
|
||||
|
||||
colors = {cat: get_category_color(cat) for cat in categories}
|
||||
# Create the figure
|
||||
fig_dropped, ax = plt.subplots(figsize=(10, max(3, len(channel_names) * 0.3))) # type: ignore
|
||||
y_pos = range(len(channel_names))
|
||||
@@ -1879,12 +1955,12 @@ def individual_significance(raw_haemo, glm_est):
|
||||
# Merge with mean theta (optional for plotting)
|
||||
mean_theta = activity_ch_summary.groupby('ch_name')['theta'].mean().reset_index()
|
||||
sig_channels = sig_channels.merge(mean_theta, on='ch_name')
|
||||
print(sig_channels)
|
||||
# print(sig_channels)
|
||||
|
||||
|
||||
# For example, take the minimum corrected p-value per channel
|
||||
summary_pvals = corrected.groupby('ch_name')['pval_fdr'].min().reset_index()
|
||||
print(summary_pvals)
|
||||
# print(summary_pvals)
|
||||
|
||||
|
||||
def parse_ch_name(ch_name):
|
||||
@@ -1916,9 +1992,10 @@ def individual_significance(raw_haemo, glm_est):
|
||||
SOURCE_DETECTOR_SEPARATOR = "_"
|
||||
|
||||
t_or_theta = 'theta'
|
||||
for _, row in avg_df.iterrows(): # type: ignore
|
||||
print(f"Source {row['Source']} <-> Detector {row['Detector']}: "
|
||||
f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}")
|
||||
#holy log noise
|
||||
# for _, row in avg_df.iterrows(): # type: ignore
|
||||
# print(f"Source {row['Source']} <-> Detector {row['Detector']}: "
|
||||
# f"Avg {t_or_theta}-value = {row['t_or_theta']:.3f}, Avg p-value = {row['p_value']:.3f}")
|
||||
|
||||
# Extract the cource and detector positions from raw
|
||||
src_pos: dict[int, tuple[float, float]] = {}
|
||||
@@ -3384,6 +3461,270 @@ def plot_heart_rate(
|
||||
return fig1, fig2
|
||||
|
||||
|
||||
# import numpy as np
|
||||
|
||||
# def mark_bads_by_db_threshold(raw, db_limit=-60):
|
||||
# """
|
||||
# Converts a dB threshold to absolute power and marks channels
|
||||
# exceeding it at 6.25 Hz (1/2 Nyquist).
|
||||
# """
|
||||
# # 1. Convert your "Magic Number" from dB to absolute units
|
||||
# abs_threshold = 10 ** (db_limit / 10)
|
||||
|
||||
# # 2. Compute PSD specifically around the frequency of interest
|
||||
# sfreq = raw.info['sfreq']
|
||||
# target_freq = (sfreq / 2) - 1 # 6.25 Hz
|
||||
|
||||
# # We use a small window around 6.25 Hz to get a stable average
|
||||
# spectrum = raw.compute_psd(fmin=target_freq - 0.5, fmax=target_freq + 0.5)
|
||||
# psd_data, freqs = spectrum.get_data(return_freqs=True)
|
||||
|
||||
# # 3. Get the power at that frequency (averaging across the small window)
|
||||
# # psd_data shape: (n_channels, n_freqs)
|
||||
# power_at_6Hz = np.mean(psd_data, axis=1)
|
||||
|
||||
# # 4. Identify the "loud" channels
|
||||
# bad_indices = np.where(power_at_6Hz > abs_threshold)[0]
|
||||
# new_bads = [raw.ch_names[i] for i in bad_indices]
|
||||
|
||||
# # 5. Update raw.info['bads'] without duplicates
|
||||
# # raw.info['bads'] = list(set(raw.info['bads'] + new_bads))
|
||||
|
||||
# print(f"Threshold: {db_limit} dB -> {abs_threshold:.2e} V^2/Hz")
|
||||
# print(f"Newly potentially bad channels: {new_bads}")
|
||||
|
||||
# return new_bads
|
||||
|
||||
|
||||
|
||||
# def find_flat_channels(raw, threshold=1e-15):
|
||||
# """
|
||||
# Identifies channels that are essentially flat lines (zero variance).
|
||||
# """
|
||||
# data = raw.get_data()
|
||||
# # Calculate standard deviation of each channel
|
||||
# std_devs = np.std(data, axis=1)
|
||||
|
||||
# # Identify channels with almost zero movement
|
||||
# flat_idx = np.where(std_devs < threshold)[0]
|
||||
# flat_names = [raw.ch_names[i] for i in flat_idx]
|
||||
|
||||
# return flat_names
|
||||
|
||||
|
||||
# def find_truly_dead_channels(raw, threshold=1e-20):
|
||||
# """
|
||||
# Looks for channels where the signal literally doesn't move
|
||||
# between samples (successive differences are zero).
|
||||
# """
|
||||
# data = raw.get_data()
|
||||
# # Calculate the absolute difference between every sample (t and t+1)
|
||||
# # diffs shape: (n_channels, n_times - 1)
|
||||
# diffs = np.abs(np.diff(data, axis=1))
|
||||
|
||||
# # Calculate the mean 'step' size for each channel
|
||||
# mean_diffs = np.mean(diffs, axis=1)
|
||||
|
||||
# # Identify channels where the signal is 'stuck'
|
||||
# stuck_idx = np.where(mean_diffs < threshold)[0]
|
||||
# stuck_names = [raw.ch_names[i] for i in stuck_idx]
|
||||
|
||||
# return stuck_names, mean_diffs
|
||||
|
||||
|
||||
# def find_mid_run_flatlines(raw, window_size=10.0, threshold=1e-12):
|
||||
# """
|
||||
# Checks for channels that go flat partway through the recording.
|
||||
# window_size: Duration in seconds to check for flatness.
|
||||
# """
|
||||
# sfreq = raw.info['sfreq']
|
||||
# data = raw.get_data()
|
||||
# n_samples_win = int(window_size * sfreq)
|
||||
|
||||
# bad_channels = []
|
||||
|
||||
# for i, ch_name in enumerate(raw.ch_names):
|
||||
# ch_data = data[i]
|
||||
# # Create sliding windows (non-overlapping for speed)
|
||||
# windows = np.array_split(ch_data, len(ch_data) // n_samples_win)
|
||||
|
||||
# # Calculate variance for each window
|
||||
# win_vars = [np.var(w) for w in windows]
|
||||
|
||||
# # If the LAST window (or any significant portion at the end) is flat
|
||||
# if win_vars[-1] < threshold:
|
||||
# bad_channels.append(ch_name)
|
||||
|
||||
# return bad_channels
|
||||
|
||||
|
||||
def find_flatline_at_end(raw, threshold_ratio=0.05):
|
||||
"""
|
||||
Compares the variance of the first 25% of the data to the last 25%.
|
||||
If the end variance is less than 5% of the start variance, mark it bad.
|
||||
"""
|
||||
data = raw.get_data()
|
||||
n_samples = data.shape[1]
|
||||
quarter = n_samples // 4
|
||||
|
||||
bad_channels = []
|
||||
|
||||
for i, ch_name in enumerate(raw.ch_names):
|
||||
start_var = np.var(data[i, :quarter])
|
||||
end_var = np.var(data[i, -quarter:])
|
||||
|
||||
# Avoid division by zero for truly dead channels
|
||||
if start_var == 0:
|
||||
bad_channels.append(ch_name)
|
||||
continue
|
||||
|
||||
ratio = end_var / start_var
|
||||
|
||||
if ratio < threshold_ratio:
|
||||
bad_channels.append(ch_name)
|
||||
print(f"Flagged {ch_name}: Variance dropped to {ratio:.2%} of original.")
|
||||
|
||||
return bad_channels
|
||||
|
||||
|
||||
# def plot_with_death_lines(raw, picks, window_sec=5.0, var_threshold=0.05):
|
||||
# sfreq = raw.info['sfreq']
|
||||
# data, times = raw.get_data(picks=picks, return_times=True)
|
||||
|
||||
# fig, ax = plt.subplots(figsize=(12, 6))
|
||||
|
||||
# for i, ch_name in enumerate(picks):
|
||||
# ch_data = data[i]
|
||||
# ax.plot(times, ch_data, label=ch_name, alpha=0.8)
|
||||
|
||||
# # --- Find the 'Death Point' ---
|
||||
# # Calculate rolling variance in 5-second chunks
|
||||
# win_samples = int(window_sec * sfreq)
|
||||
# initial_var = np.var(ch_data[:win_samples])
|
||||
|
||||
# # Check from the end backwards to find where it "died"
|
||||
# death_time = None
|
||||
# n_samples = len(ch_data)
|
||||
# for start_idx in range(n_samples - win_samples - 1, 0, -win_samples):
|
||||
# current_var = np.var(ch_data[start_idx : start_idx + win_samples])
|
||||
# if current_var > (initial_var * var_threshold):
|
||||
# # The point right after this is where it stays dead
|
||||
# death_time = times[start_idx + win_samples]
|
||||
# break
|
||||
|
||||
# if death_time and death_time < (times[-1] - 10): # Only plot if it died early
|
||||
# ax.axvline(x=death_time, color='red', linestyle='--', alpha=0.5)
|
||||
# ax.text(death_time, ax.get_ylim()[1], 'Signal Loss',
|
||||
# color='red', rotation=90, verticalalignment='top')
|
||||
|
||||
# ax.set_title("HbO/HbR with Automated Signal Loss Detection")
|
||||
# ax.set_ylabel("Concentration (Δ μmol)")
|
||||
# ax.set_xlabel("Time (s)")
|
||||
# ax.legend(loc='lower right')
|
||||
# plt.grid(True, alpha=0.3)
|
||||
# plt.show(block=True)
|
||||
|
||||
|
||||
# import numpy as np
|
||||
|
||||
def master_clean_fnirs(raw, db_limit=-60, threshold_ratio=0.05):
|
||||
ch_names = raw.ch_names
|
||||
data = raw.get_data()
|
||||
sfreq = raw.info['sfreq']
|
||||
|
||||
n_samples = data.shape[1]
|
||||
quarter = n_samples // 4
|
||||
|
||||
dead_idx = []
|
||||
for i in range(len(ch_names)):
|
||||
start_var = np.var(data[i, :quarter])
|
||||
end_var = np.var(data[i, -quarter:])
|
||||
|
||||
if start_var == 0:
|
||||
dead_idx.append(i)
|
||||
continue
|
||||
|
||||
ratio = end_var / start_var
|
||||
if ratio < threshold_ratio:
|
||||
dead_idx.append(i)
|
||||
print(f"Flagged {ch_names[i]}: Variance dropped to {ratio:.2%} of original.")
|
||||
|
||||
dead_channels = [ch_names[i] for i in dead_idx]
|
||||
|
||||
target_freq = sfreq / 4
|
||||
spectrum = raw.compute_psd(fmin=0.1, fmax=sfreq/2)
|
||||
psd_data, freqs = spectrum.get_data(return_freqs=True)
|
||||
f_idx = np.where((freqs >= target_freq - 0.2) & (freqs <= target_freq + 0.2))[0]
|
||||
power_at_target = np.mean(psd_data[:, f_idx], axis=1)
|
||||
|
||||
abs_threshold = 10 ** (db_limit / 10)
|
||||
noisy_idx = np.where(power_at_target > abs_threshold)[0]
|
||||
noisy_channels = [ch_names[i] for i in noisy_idx]
|
||||
|
||||
# --- The Pairing Logic ---
|
||||
# Combine indices of specifically failed channels
|
||||
failed_idx_list = np.unique(np.concatenate([dead_idx, noisy_idx])).astype(int)
|
||||
|
||||
# Extract base names (e.g., "S1_D2") of failed channels
|
||||
# Assuming "S1_D2 762" -> split by space -> "S1_D2"
|
||||
failed_bases = {str(ch_names[i]).split(' ')[0] for i in failed_idx_list}
|
||||
|
||||
# Flag EVERY channel that matches these bases
|
||||
all_bads = [ch for ch in ch_names if ch.split(' ')[0] in failed_bases]
|
||||
|
||||
print(f"Analysis Complete:")
|
||||
print(f" - Found {len(dead_channels)} dead/low-var channels: {dead_channels}")
|
||||
print(f" - Found {len(noisy_channels)} high-noise (> {db_limit}dB) channels: {noisy_channels}")
|
||||
|
||||
# --- Visualization ---
|
||||
fig_master, ax = plt.subplots(figsize=(8, 4))
|
||||
ax.plot(freqs, 10 * np.log10(psd_data.T), color='gray', alpha=0.2)
|
||||
|
||||
# Highlight the specifically noisy ones
|
||||
if len(noisy_idx) > 0:
|
||||
ax.plot(freqs, 10 * np.log10(psd_data[noisy_idx].T), color='plum', label='Noisy Pairs')
|
||||
|
||||
ax.axhline(db_limit, color='red', linestyle='--', label='Threshold')
|
||||
ax.set_title("Master Clean: PSD Noise & Dead Channel Detection")
|
||||
ax.set_ylabel("Power (dB)")
|
||||
ax.set_xlabel("Frequency (Hz)")
|
||||
ax.legend()
|
||||
plt.close(fig_master)
|
||||
|
||||
return list(set(all_bads)), fig_master
|
||||
|
||||
|
||||
def find_physiologically_impossible_channels(raw, picks, threshold=3.0):
|
||||
data = raw.get_data(picks=picks)
|
||||
ranges = np.max(data, axis=1) - np.min(data, axis=1)
|
||||
|
||||
# Calculate Z-Scores
|
||||
median_range = np.median(ranges)
|
||||
mad = np.median(np.abs(ranges - median_range))
|
||||
z_scores = 0.6745 * (ranges - median_range) / (mad if mad > 0 else 1e-15)
|
||||
|
||||
# Identify failed bases
|
||||
failed_indices = np.where(z_scores > threshold)[0]
|
||||
failed_bases = {picks[i].split(' ')[0] for i in failed_indices}
|
||||
|
||||
# Flag entire pairs
|
||||
bad_names = [ch for ch in picks if ch.split(' ')[0] in failed_bases]
|
||||
|
||||
# --- Visualization ---
|
||||
fig_swing, ax = plt.subplots(figsize=(8, 4))
|
||||
# We color bars by the specific Z-score of that individual channel
|
||||
colors = ['coral' if z > threshold else 'skyblue' for z in z_scores]
|
||||
|
||||
ax.bar(range(len(z_scores)), z_scores, color=colors)
|
||||
ax.axhline(threshold, color='red', linestyle='--', label='Outlier Threshold')
|
||||
ax.set_title("Physiological Swing Analysis (Z-Scores)")
|
||||
ax.set_ylabel("Standardized Deviation")
|
||||
ax.set_xlabel("Channel Index")
|
||||
ax.legend()
|
||||
plt.close(fig_swing)
|
||||
|
||||
return bad_names, fig_swing
|
||||
|
||||
|
||||
def hr_calc(raw):
|
||||
if SHORT_CHANNEL:
|
||||
@@ -3414,7 +3755,7 @@ def hr_calc(raw):
|
||||
# --- Parameters for PSD ---
|
||||
desired_bin_hz = 0.1
|
||||
nperseg = int(sfreq / desired_bin_hz)
|
||||
hr_range = (30, 180) # TODO: SHould this not use the user defined values?
|
||||
hr_range = (30, 180) # TODO: Should this not use the user defined values?
|
||||
|
||||
# --- Function to find strongest local peak ---
|
||||
def find_hr_from_psd(ch_data):
|
||||
@@ -3511,6 +3852,22 @@ def process_participant(file_path, progress_callback=None):
|
||||
fig_individual["PSD"] = fig
|
||||
fig_individual['HeartRate_PSD'] = hr1
|
||||
fig_individual['HeartRate_Time'] = hr2
|
||||
|
||||
# --- Run it ---
|
||||
# mark_bads_by_db_threshold(raw, db_limit=-60)
|
||||
# dead_channels = find_flat_channels(raw)
|
||||
# print(f"Dead/Flat channels removed: {dead_channels}")
|
||||
# stuck_channels, movement_scores = find_truly_dead_channels(raw)
|
||||
# print(f"Stuck Channels: {stuck_channels}")
|
||||
# late_stage_bads = find_mid_run_flatlines(raw)
|
||||
# _ = find_flatline_at_end(raw)
|
||||
# print(f"Channels that died before the end: {late_stage_bads}")
|
||||
# picks = [ch for ch in raw.ch_names if 'S1_D2' in ch]
|
||||
bad_master, fig_master = master_clean_fnirs(raw, db_limit=-60)
|
||||
bad_swing, fig_swing = find_physiologically_impossible_channels(raw, [ch for ch in raw.ch_names], threshold=4.0)
|
||||
fig_individual["fig_master"] = fig_master
|
||||
fig_individual["fig_swing"] = fig_swing
|
||||
|
||||
if progress_callback: progress_callback(5)
|
||||
logger.info("Step 5 Completed.")
|
||||
|
||||
@@ -3545,17 +3902,19 @@ def process_participant(file_path, progress_callback=None):
|
||||
|
||||
# Step 9: Bad Channels Handling
|
||||
if BAD_CHANNELS_HANDLING != "None":
|
||||
raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp)
|
||||
raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp, bad_master, bad_swing)
|
||||
if fig_dropped and fig_raw_before is not None:
|
||||
fig_individual["fig2"] = fig_dropped
|
||||
fig_individual["fig3"] = fig_raw_before
|
||||
if bad_channels:
|
||||
if BAD_CHANNELS_HANDLING == "Interpolate":
|
||||
raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, max_dist=MAX_DIST, min_neighbors=MIN_NEIGHBORS)
|
||||
raw, fig_raw_after, fig_compare = interpolate_fNIRS_bads_weighted_average(raw, max_dist=MAX_DIST, min_neighbors=MIN_NEIGHBORS)
|
||||
fig_individual["fig4"] = fig_raw_after
|
||||
fig_individual["Compare"] = fig_compare
|
||||
elif BAD_CHANNELS_HANDLING == "Remove":
|
||||
pass
|
||||
#TODO: Is there more needed here?
|
||||
#NOTE: testing this
|
||||
raw.pick_types(fnirs=True, exclude='bads')
|
||||
logger.info(f"Physically removed {len(bad_channels)} channels from the dataset.")
|
||||
|
||||
if progress_callback: progress_callback(9)
|
||||
logger.info("Step 9 Completed.")
|
||||
|
||||
Reference in New Issue
Block a user