updates to hr

2025-10-31 21:21:10 -07:00
parent 45c6176dba
commit 1aa2402d09
3 changed files with 497 additions and 27 deletions

View File

@@ -1,7 +1,19 @@
# Version 1.1.7
- Fixed a bug where setting both an L_FREQ and an H_FREQ caused only the L_FREQ to be used
- Changed the default H_FREQ from 0.7 to 0.3
- Removed SECONDS_TO_STRIP from the preprocessing options
- Instead, each file is now trimmed so it starts 5 seconds before its first annotation/event
- Added a PSD graph, along with 2 heart rate images to the individual participant viewer
- The PSD graph is used to help calculate the heart rate, whereas the other 2 are currently just for show
- SCI is now computed using a 0.6 Hz window around the calculated heart rate instead of a window around an average heart rate (see the sketch after this changelog)
- Fixed an issue with some epochs figures not showing
# Version 1.1.6
- Fixed the Process button appearing when no files are selected
- Fix for instand child process crash on Windows
- Fixed a bug that would cause an instant child process crash on Windows
- Added L_FREQ and H_FREQ parameters for more user control over low and high pass filtering
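A minimal sketch (illustration only, not part of the diff) of how the new SCI band is derived from the estimated heart rate, where `hr_bpm` stands in for the heart rate taken from the PSD peak:

```python
hr_freq = hr_bpm / 60   # convert BPM to Hz
low = hr_freq - 0.3     # lower edge of the 0.6 Hz window
high = hr_freq + 0.3    # upper edge of the 0.6 Hz window
```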

flares.py
View File

@@ -28,10 +28,11 @@ import matplotlib.colors as mcolors
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D
import numpy as np
from numpy.typing import NDArray
from numpy import float64
from numpy import float64, floating
import pandas as pd
from pandas import DataFrame
@@ -47,11 +48,15 @@ from statsmodels.stats.multitest import multipletests
from scipy import stats
from scipy.spatial.distance import cdist
from scipy.signal import welch, butter, filtfilt # type: ignore
import pywt # type: ignore
import neurokit2 as nk # type: ignore
# Backend visualization needs to be imported explicitly for PyInstaller
import pyvistaqt # type: ignore
import vtkmodules.util.data_model
import vtkmodules.util.execution_model
#import vtkmodules.util.data_model
#import vtkmodules.util.execution_model
# External library imports for mne
from mne import (
@@ -61,7 +66,7 @@ from mne import (
) # type: ignore
from mne.source_space import SourceSpaces
from mne.transforms import Transform # type: ignore
from mne.io import BaseRaw, read_raw_snirf # type: ignore
from mne.io import BaseRaw, RawArray, read_raw_snirf # type: ignore
from mne.preprocessing.nirs import (
beer_lambert_law, optical_density,
temporal_derivative_distribution_repair,
@@ -110,7 +115,7 @@ FIXED_CATEGORY_COLORS = {
AGE: float
GENDER: str
SECONDS_TO_STRIP: int
# SECONDS_TO_STRIP: int
DOWNSAMPLE: bool
DOWNSAMPLE_FREQUENCY: int
@@ -128,6 +133,15 @@ PSP_THRESHOLD: float
TDDR: bool
IQR = 1.5
HEART_RATE = True # True if heart rate should be calculated. This helps the SCI, PSP, and SNR methods be more accurate.
SECONDS_TO_STRIP_HR = 5 # Number of seconds to temporarily strip from the start and end of the data to calculate heart rate more effectively. Useful if the participant removed the cap while still recording.
MAX_LOW_HR = 40 # Any heart rate values lower than this will be set to this value.
MAX_HIGH_HR = 200 # Any heart rate values higher than this will be set to this value.
SMOOTHING_WINDOW_HR = 100 # Heart rate is calculated as a rolling average over this many samples.
HEART_RATE_WINDOW = 25 # Number of BPM above and below the calculated average to use as the resting BPM range.
SHORT_CHANNEL_THRESH = 0.018
ENHANCE_NEGATIVE_CORRELATION: bool
L_FREQ: float
@@ -170,7 +184,7 @@ GROUP = "Default"
REQUIRED_KEYS: dict[str, Any] = {
"SECONDS_TO_STRIP": int,
# "SECONDS_TO_STRIP": int,
"DOWNSAMPLE": bool,
"DOWNSAMPLE_FREQUENCY": int,
@@ -1071,23 +1085,24 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp):
def filter_the_data(raw_haemo):
# --- STEP 5: Filtering (0.01-0.2 Hz bandpass) ---
fig_filter = raw_haemo.compute_psd(fmax=2).plot(
average=True, xscale="log", color="r", show=False, amplitude=False
fig_filter = raw_haemo.compute_psd(fmax=3).plot(
average=True, color="r", show=False, amplitude=True
)
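# A cutoff of 0 disables that side of the filter, giving low-pass only, high-pass only, band-pass, or no filtering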
if L_FREQ == 0 and H_FREQ != 0:
raw_haemo = raw_haemo.filter(l_freq=None, h_freq=H_FREQ, h_trans_bandwidth=0.02)
elif L_FREQ != 0 and H_FREQ == 0:
raw_haemo = raw_haemo.filter(l_freq=L_FREQ, h_freq=None, l_trans_bandwidth=0.002)
elif L_FREQ != 0 and H_FREQ == 0:
elif L_FREQ != 0 and H_FREQ != 0:
raw_haemo = raw_haemo.filter(l_freq=L_FREQ, h_freq=H_FREQ, l_trans_bandwidth=0.002, h_trans_bandwidth=0.02)
else:
print("No filter")
#raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.4, h_trans_bandwidth=0.2)
#raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.7, h_trans_bandwidth=0.2)
#raw_haemo = raw_haemo.filter(0.005, 0.7, h_trans_bandwidth=0.02, l_trans_bandwidth=0.002)
raw_haemo.compute_psd(fmax=2).plot(
average=True, xscale="log", axes=fig_filter.axes, color="g", amplitude=False, show=False
raw_haemo.compute_psd(fmax=3).plot(
average=True, axes=fig_filter.axes, color="g", amplitude=True, show=False
)
fig_raw_haemo_filter = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Filtered HbO and HbR", show=False)
@@ -1119,6 +1134,8 @@ def epochs_calculations(raw_haemo, events, event_dict):
# Plot for each condition
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
for idx, condition in enumerate(epochs.event_id.keys()):
logger.info(condition)
logger.info(idx)
# Plot images for each condition
fig_epochs_data = epochs[condition].plot_image(
combine="mean",
@@ -1127,11 +1144,15 @@ def epochs_calculations(raw_haemo, events, event_dict):
ts_args=dict(ylim=dict(hbo=[-1, 1], hbr=[-1, 1])),
show=False
)
for i in fig_epochs_data:
for j, fig in enumerate(fig_epochs_data):
logger.info("------------------------------------------")
logger.info(j)
logger.info(fig)
ax = fig.axes[0]
original_title = ax.get_title()
ax.set_title(f"{condition}: {original_title}")
fig_epochs.append((f"fig_{condition}_data_{idx}", i)) # Store with a unique name
fig_epochs.append((f"fig_{condition}_data_{idx}_{j}", fig)) # Store with a unique name
# Evoked average figure for each condition
evoked_avg = epochs[condition].average()
@@ -2302,8 +2323,8 @@ def brain_landmarks_3d(raw_haemo: BaseRaw, show_optodes: Literal['sensors', 'lab
"Brodmann.17-lh": "blue",
"Brodmann.18-lh": "blue",
"Brodmann.19-lh": "blue",
"Brodmann.39-lh": "purple",
"Brodmann.40-lh": "pink",
"Brodmann.39-lh": "pink",
"Brodmann.40-lh": "purple",
"Brodmann.42-lh": "white",
"Brodmann.44-lh": "white",
"Brodmann.48-lh": "white",
@@ -2666,13 +2687,13 @@ def load_snirf(file_path: str) -> tuple[BaseRaw, Figure]:
raw = read_raw_snirf(file_path, preload=True, verbose=VERBOSITY) # type: ignore
raw.load_data(verbose=VERBOSITY) # type: ignore
# Strip the specified amount of seconds from the start of the file
total_duration = getattr(raw, "times")[-1]
if total_duration > SECONDS_TO_STRIP:
raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore
logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.")
else:
logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.")
# # Strip the specified amount of seconds from the start of the file
# total_duration = getattr(raw, "times")[-1]
# if total_duration > SECONDS_TO_STRIP:
# raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore
# logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.")
# else:
# logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.")
# If the user forcibly dropped channels, remove them now before any processing occurs
# logger.info("Checking if there are channels to forcibly drop...")
@@ -2863,6 +2884,399 @@ def calculate_dpf(file_path):
def iqr_threshold(coeffs: NDArray[float64], k: float = 1.5) -> floating[Any]:
"""
Calculate the interquartile range (IQR) threshold scaled by a factor, k.
Parameters
----------
coeffs : NDArray[float64]
Array of coefficients to compute the IQR from.
k : float, optional
Scaling factor for the IQR (default is 1.5).
Returns
-------
floating[Any]
The scaled IQR threshold value.
"""
# Calculate the IQR
q1 = np.percentile(coeffs, 25)
q3 = np.percentile(coeffs, 75)
iqr = q3 - q1
return k * iqr
def wavelet_iqr_denoise(signal: NDArray[float64], wavelet: str = 'db4', level: int = 3) -> NDArray[float64]:
"""
Denoises a signal using wavelet decomposition and IQR-based thresholding on detail coefficients.
Parameters
----------
signal : NDArray[float64]
The input signal array to denoise.
wavelet : str, optional
The type of wavelet to use for decomposition (default is 'db4').
level : int, optional
Decomposition level for wavelet transform (default is 3).
Returns
-------
NDArray[float64]
The denoised signal array, with the same length as the input.
"""
# Decompose the signal using wavelet transform and initialize a list with approximation coefficients
coeffs: list[NDArray[float64]] = pywt.wavedec(signal, wavelet, level=level) # type: ignore
cA = coeffs[0]
denoised_coeffs = [cA]
# Threshold detail coefficients to reduce noise
for cD in coeffs[1:]:
threshold = iqr_threshold(cD, IQR)
cD_thresh = np.sign(cD) * np.maximum(np.abs(cD) - threshold, 0.0) # np.where((cD < lower) | (cD > upper), 0, cD)
cD_thresh = cD_thresh.astype(float64)
denoised_coeffs.append(cD_thresh)
# Reconstruct the denoised signal
denoised_signal = cast(NDArray[float64], pywt.waverec(denoised_coeffs, wavelet)) # type: ignore
return denoised_signal[:len(signal)]
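# Illustrative usage (assumes `raw` is a loaded BaseRaw): denoise one channel with the defaults used below
# denoised = wavelet_iqr_denoise(raw.get_data()[0], wavelet='db4', level=3)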
def calculate_and_apply_wavelet(data: BaseRaw) -> tuple[BaseRaw, Figure]:
"""
Applies a wavelet IQR denoising filter to the data and generates a plot.
Parameters
----------
data : BaseRaw
The loaded data object to process.
Returns
-------
tuple[BaseRaw, Figure]
- BaseRaw: The processed data object.
- Figure: The corresponding Matplotlib figure.
"""
logger.info("Applying the wavelet filter...")
# Denoise the data
logger.info("Denoising the data...")
loaded_data: NDArray[float64] = data.get_data(verbose=VERBOSITY) # type: ignore
denoised_data = np.zeros_like(loaded_data)
logger.info("Calculating the IQR, decomposing the signal, and thresholding the coefficients...")
for ch in range(loaded_data.shape[0]):
denoised_data[ch, :] = wavelet_iqr_denoise(loaded_data[ch, :], wavelet='db4', level=3)
# Reconstruct the data with the annotations
logger.info("Reconstructing the data with annotations...")
raw_with_tddr_and_wavelet = RawArray(denoised_data, cast(Info, data.info), verbose=VERBOSITY)
raw_with_tddr_and_wavelet.set_annotations(data.annotations.copy(), verbose=VERBOSITY) # type: ignore
# Create a figure for the results
logger.info("Creating the figure...")
fig = cast(Figure, raw_with_tddr_and_wavelet.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=data.times[-1]).figure) # type: ignore
fig.suptitle(f"Wavelet for ", fontsize=16) # type: ignore
fig.subplots_adjust(top=0.92)
plt.close(fig)
logger.info("Successfully applied the wavelet filter.")
return raw_with_tddr_and_wavelet, fig
def short_channel_processing_for_hr(data: BaseRaw, short_chans: BaseRaw | None) -> tuple[float, NDArray[float64], NDArray[float64]]:
"""
Extract and trim short-channel fNIRS signal for heart rate analysis.
Parameters
----------
data : BaseRaw
The loaded data object to process.
short_chans : BaseRaw | None
Data object with only short separation channels, or None if unavailable.
Returns
-------
tuple[float, NDArray[float64], NDArray[float64]]
- float: Sampling frequency of the signal.
- NDArray[float64]: Trimmed short-channel signal.
- NDArray[float64]: Corresponding time values.
"""
# Find the short channel (or best candidate) and extract signal data and sampling frequency
logger.info("Extracting the signal and calculating the sampling frequency...")
# If a short channel exists, use it for our signal. Otherwise just take the first channel in the data
# TODO: Find a better way around this
if short_chans is not None:
signal = cast(NDArray[float64], short_chans.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore
else:
signal = cast(NDArray[float64], data.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore
# Calculate the sampling frequency
sfreq = cast(float, data.info['sfreq'])
# Trim start and end of the signal to remove edge artifacts
logger.info(f"Removing {SECONDS_TO_STRIP_HR} seconds from the beginning and end of the file...")
strip_samples = int(sfreq * SECONDS_TO_STRIP_HR)
signal_trimmed = signal[strip_samples:-strip_samples]
times_trimmed = cast(NDArray[float64], getattr(data, "times"))[strip_samples:-strip_samples]
return sfreq, signal_trimmed, times_trimmed
def calculate_heart_rate_neurokit(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[float64], float]:
"""
Calculate and smooth heart rate from a trimmed signal using NeuroKit.
Parameters
----------
sfreq : float
Sampling frequency of the signal.
signal_trimmed : NDArray[float64]
Preprocessed and trimmed fNIRS signal.
Returns
-------
tuple[NDArray[float64], float]
- NDArray[float64]: Smoothed heart rate time series (BPM).
- float: Mean heart rate.
"""
logger.info("Calculating heart rate using NeuroKit...")
# Filter signal to isolate heart rate frequencies and detect peaks
logger.info("Filtering the signal and detecting peaks...")
signal_filtered = cast(NDArray[float64], nk.signal_filter(signal_trimmed, sampling_rate=sfreq, lowcut=0.8, highcut=2.5)) # type: ignore
peaks_dict = cast(dict[str, Any], nk.signal_findpeaks(signal_filtered)) # type: ignore
peaks = peaks_dict['Peaks']
hr = cast(NDArray[float64], nk.signal_rate(peaks, sampling_rate=sfreq, desired_length=len(signal_trimmed))) # type: ignore
hr_clean = np.clip(hr, MAX_LOW_HR, MAX_HIGH_HR)
# Smooth heart rate time series by replacing spikes with local rolling mean and calculate the mean
logger.info("Smoothing the signal and calculating the mean...")
hr_series = pd.Series(hr_clean)
local_median = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).median()
spikes = hr_series > (local_median + 10)
smoothed_values = hr_series.copy()
smoothed_spikes = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).mean()
smoothed_values[spikes] = smoothed_spikes[spikes]
hr_smooth_nk = cast(NDArray[float64], smoothed_values.to_numpy()) # type: ignore
mean_hr_nk = hr_smooth_nk.mean()
logger.info("Original HR min/max: %f, %f", hr_clean.min(), hr_clean.max())
logger.info("Smoothed HR min/max:%f, %f", hr_smooth_nk.min(), hr_smooth_nk.max())
logger.info(f"Estimated mean HR nk: {mean_hr_nk:.1f} BPM")
logger.info("Successfully calculated heart rate using NeuroKit.")
return hr_smooth_nk, mean_hr_nk
def calculate_heart_rate_scipy(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]:
"""
Estimate heart rate using spectral analysis on a high-pass filtered signal.
Parameters
----------
sfreq : float
Sampling frequency of the input signal.
signal_trimmed : NDArray[float64]
Trimmed fNIRS signal to analyze.
Returns
-------
tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]
- NDArray[floating[Any]]: Frequencies converted to beats per minute (BPM).
- NDArray[float64]: Power spectral density (PSD) of the signal.
- np.ndarray[Any, np.dtype[np.bool_]]: Boolean mask indicating frequencies within heart rate range (30-300 BPM).
- float: Estimated mean heart rate in BPM corresponding to the PSD peak within the range.
"""
logger.info("Calculating heart rate using SciPy...")
# Apply a high-pass Butterworth filter to remove slow trends below 0.5 Hz from the trimmed signal (actual data)
logger.info("Applying a butterworth filter...")
b, a = cast(tuple[NDArray[float64], NDArray[float64]], butter(2, 0.5 / (sfreq / 2), btype='high'))
signal_hp = cast(NDArray[float64], filtfilt(b, a, signal_trimmed))
# Calculate the Power Spectral Density (PSD) of the filtered signal using Welch's method
logger.info("Calculating the PSD...")
nperseg = min(len(signal_hp), 4096)
frequencies_scipy, psd_scipy = cast(tuple[NDArray[float64], NDArray[float64]], welch(signal_hp, fs=sfreq, nperseg=nperseg, noverlap=nperseg//2))
# Convert frequency values to beats per minute (BPM) and set a heart rate range (30-300 BPM)
logger.info("Converting to BPM...")
freq_bpm_scipy = frequencies_scipy * 60
freq_range_scipy = (freq_bpm_scipy > 30) & (freq_bpm_scipy < 300)
# Identify the peak frequency within the heart rate range and estimate the mean heart rate in BPM
logger.info("Finding the mean...")
peak_index = np.argmax(psd_scipy[freq_range_scipy])
mean_hr_scipy = freq_bpm_scipy[freq_range_scipy][peak_index]
logger.info("Successfully calculated heart rate using SciPy.")
return freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy
def plot_heart_rate(
freq_bpm_scipy: NDArray[floating[Any]],
psd_scipy: NDArray[float64],
freq_range_scipy: np.ndarray[Any, np.dtype[np.bool_]],
mean_hr_scipy: float,
hr_smooth_nk: NDArray[floating[Any]],
mean_hr_nk: float,
times_trimmed: NDArray[floating[Any]],
overruled: bool
) -> tuple[Figure, Figure]:
"""
Generate plots comparing heart rate estimates from SciPy PSD and NeuroKit2.
Parameters
----------
freq_bpm_scipy : NDArray[floating[Any]]
Frequencies in beats per minute from SciPy PSD analysis.
psd_scipy : NDArray[float64]
Power spectral density values corresponding to freq_bpm_scipy.
freq_range_scipy : np.ndarray[Any, np.dtype[np.bool_]]
Boolean mask indicating the heart rate frequency range used in PSD.
mean_hr_scipy : float
Mean heart rate estimated from SciPy PSD peak.
hr_smooth_nk : NDArray[floating[Any]]
Smoothed instantaneous heart rate from NeuroKit2.
mean_hr_nk : float
Mean heart rate estimated from NeuroKit2 data.
times_trimmed : NDArray[floating[Any]]
Time points corresponding to hr_smooth_nk values.
overruled: bool
True if the heart rate from NeuroKit2 is overriding the results from the PSD.
Returns
-------
tuple[Figure, Figure]
- Figure showing the PSD and SciPy heart rate estimate.
- Figure showing the time series comparison of heart rates.
"""
# Create the first plot for the PSD. Add a yellow range to show what we will be filtering to.
logger.info("Creating the figure...")
fig1, ax1 = plt.subplots(figsize=(10, 5)) # type: ignore
ax1.set_xlim(30, 300)
ax1.plot(freq_bpm_scipy[freq_range_scipy], psd_scipy[freq_range_scipy]) # type: ignore
ax1.axvline(x=mean_hr_scipy, color='red', linestyle='--', label=f'Mean HR: {mean_hr_scipy:.1f} BPM') # type: ignore
ax1.axvspan(min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW), max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW), color='yellow', alpha=0.3, label=f'HR Range ±{HEART_RATE_WINDOW} BPM') # type: ignore
ax1.set_xlabel('Heart Rate (BPM)') # type: ignore
ax1.set_ylabel('Power Spectral Density') # type: ignore
ax1.set_title('PSD of fNIRS signal - Peak indicates Heart Rate') # type: ignore
ax1.grid(True) # type: ignore
# Was the value we reported here correct for the data on the graph or was it overruled?
if overruled:
note = (
'\n'
'Note: Calculation was bad!\n'
'Data has been set to match\n'
'the value from NeuroKit2.'
)
phantom = Line2D([0], [0], color='none', label=note)
handles, _ = ax1.get_legend_handles_labels()
ax1.legend(handles=handles + [phantom]) # type: ignore
else:
ax1.legend() # type: ignore
plt.close(fig1)
# Create the second plot showing the rolling heart rate, as well as the two averages that were calculated
logger.info("Creating the figure...")
fig2, ax2 = plt.subplots(figsize=(14, 6)) # type: ignore
ax2.plot(times_trimmed, hr_smooth_nk, label='Instantaneous HR (NeuroKit2)', color='blue', alpha=0.7) # type: ignore
ax2.axhline(mean_hr_nk, color='red', linestyle='--', label=f'Mean HR NeuroKit2: {mean_hr_nk:.1f} BPM') # type: ignore
ax2.axhline(mean_hr_scipy, color='orange', linestyle=':', label=f'SciPy Welch PSD (HP filtered): {mean_hr_scipy:.1f} BPM') # type: ignore
ax2.set_xlabel('Time (seconds)') # type: ignore
ax2.set_ylabel('Heart Rate (BPM)') # type: ignore
ax2.set_title('Heart Rate Estimates Comparison') # type: ignore
ax2.legend() # type: ignore
ax2.grid(True) # type: ignore
fig2.tight_layout()
plt.close(fig2)
return fig1, fig2
def hr_calc(raw):
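"""Estimate heart rate two ways (NeuroKit2 and a Welch PSD peak), build the comparison figures, and return (psd_fig, hr_psd_fig, hr_time_fig, low, high), where low and high bound the heart-rate band in Hz that SCI uses."""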
if SHORT_CHANNEL:
short_chans = get_short_channels(raw, max_dist=SHORT_CHANNEL_THRESH)
else:
short_chans = None
sfreq, signal_trimmed, times_trimmed = short_channel_processing_for_hr(raw, short_chans)
hr_smooth_nk, mean_hr_nk = calculate_heart_rate_neurokit(sfreq, signal_trimmed)
freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy = calculate_heart_rate_scipy(sfreq, signal_trimmed)
# HACK: This sucks but looking at the graphs I trust neurokit2 more
overruled = False
if mean_hr_scipy < mean_hr_nk - 15:
mean_hr_scipy = mean_hr_nk
overruled = True
if mean_hr_scipy > mean_hr_nk + 15:
mean_hr_scipy = mean_hr_nk
overruled = True
hr1, hr2 = plot_heart_rate(freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy, hr_smooth_nk, mean_hr_nk, times_trimmed, overruled)
fig = raw.plot_psd(show=False)
raw_filtered = raw.copy().filter(0.5, 3, fir_design='firwin')
sfreq = raw.info['sfreq']
data = raw_filtered.get_data()
channel_names = raw.ch_names
# --- Parameters for PSD ---
desired_bin_hz = 0.1
nperseg = int(sfreq / desired_bin_hz)
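# nperseg = sfreq / bin width, so the Welch bins are ~0.1 Hz (~6 BPM) wide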
hr_range = (30, 180)
# --- Function to find strongest local peak ---
def find_hr_from_psd(ch_data):
f, Pxx = welch(ch_data, sfreq, nperseg=nperseg)
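# welch() returns frequencies in Hz, so convert the 30-180 BPM search range to Hz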
mask = (f >= hr_range[0]/60) & (f <= hr_range[1]/60)
f_masked = f[mask]
Pxx_masked = Pxx[mask]
if len(Pxx_masked) < 3:
return np.nan
peaks = [i for i in range(1, len(Pxx_masked)-1)
if Pxx_masked[i] > Pxx_masked[i-1] and Pxx_masked[i] > Pxx_masked[i+1]]
if not peaks:
return np.nan
best_idx = peaks[np.argmax([Pxx_masked[i] for i in peaks])]
return f_masked[best_idx] * 60 # bpm
# --- Compute HR across all channels ---
hr_all_channels = np.array([find_hr_from_psd(data[i, :]) for i in range(len(channel_names))])
hr_all_channels = hr_all_channels[~np.isnan(hr_all_channels)]
hr_mode = np.round(np.median(hr_all_channels)) # Median across channels is robust to noisy channels
logger.info(f"Estimated Heart Rate: {hr_mode} bpm")
hr_freq = hr_mode / 60 # Hz
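# Center a 0.6 Hz-wide band (±0.3 Hz) on the heart-rate frequency; SCI filters to this band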
low = hr_freq - 0.3
high = hr_freq + 0.3
return fig, hr1, hr2, low, high
def process_participant(file_path, progress_callback=None):
fig_individual: dict[str, Figure] = {}
@@ -2873,6 +3287,35 @@ def process_participant(file_path, progress_callback=None):
fig_individual["Loaded Raw"] = fig_raw
if progress_callback: progress_callback(1)
logger.info("1")
if hasattr(raw, 'annotations') and len(raw.annotations) > 0:
# Get time of first event
first_event_time = raw.annotations.onset[0]
trim_time = max(0, first_event_time - 5.0) # Ensure we don't go negative
raw.crop(tmin=trim_time)
# Shift annotation onsets to match new t=0
import mne
ann = raw.annotations
ann_shifted = mne.Annotations(
onset=ann.onset - trim_time, # shift to start at zero
duration=ann.duration,
description=ann.description
)
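# Rebuild the Raw object from the cropped data so the time axis restarts at zero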
data = raw.get_data()
info = raw.info.copy()
raw = mne.io.RawArray(data, info)
raw.set_annotations(ann_shifted)
logger.info(f"Trimmed raw data: start at {trim_time}s (5s before first event), t=0 at new start")
else:
logger.warning("No events found, skipping trim step.")
fig_trimmed = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Trimmed Raw", show=False)
fig_individual["Trimmed Raw"] = fig_trimmed
if progress_callback: progress_callback(2)
logger.info("2")
# Step 1.5: Verify optode positions
fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False) # type: ignore
@@ -2884,9 +3327,18 @@ def process_participant(file_path, progress_callback=None):
# raw = raw.resample(0.5) # Downsample to 0.5 Hz
# Step 2: Bad from SCI
if True:
fig, hr1, hr2, low, high = hr_calc(raw)
fig_individual["PSD"] = fig
fig_individual['HeartRate_PSD'] = hr1
fig_individual['HeartRate_Time'] = hr2
if progress_callback: progress_callback(10)
if progress_callback: progress_callback(2)
bad_sci = []
if SCI:
bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw)
bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw, low, high)
fig_individual["SCI1"] = fig_sci_1
fig_individual["SCI2"] = fig_sci_2
if progress_callback: progress_callback(3)
@@ -2938,6 +3390,12 @@ def process_participant(file_path, progress_callback=None):
fig_individual["TDDR"] = fig_raw_od_tddr
if progress_callback: progress_callback(9)
logger.info("9")
raw_od, fig = calculate_and_apply_wavelet(raw_od)
fig_individual["Wavelet"] = fig
if progress_callback: progress_callback(9)
# Step 8: BLL
raw_haemo = beer_lambert_law(raw_od, ppf=calculate_dpf(file_path))

View File

@@ -58,7 +58,7 @@ SECTIONS = [
{
"title": "Preprocessing",
"params": [
{"name": "SECONDS_TO_STRIP", "default": 0, "type": int, "help": "Seconds to remove from beginning of all loaded snirf files. Setting this to 0 will remove nothing from the files."},
# {"name": "SECONDS_TO_STRIP", "default": 0, "type": int, "help": "Seconds to remove from beginning of all loaded snirf files. Setting this to 0 will remove nothing from the files."},
{"name": "DOWNSAMPLE", "default": True, "type": bool, "help": "Should the snirf files be downsampled? If this is set to True, DOWNSAMPLE_FREQUENCY will be used as the target frequency to downsample to."},
{"name": "DOWNSAMPLE_FREQUENCY", "default": 25, "type": int, "help": "Frequency (Hz) to downsample to. If this is set higher than the input data, new data will be interpolated. Only used if DOWNSAMPLE is set to True"},
]
@@ -124,7 +124,7 @@ SECTIONS = [
"title": "Filtering",
"params": [
{"name": "L_FREQ", "default": 0.005, "type": float, "help": "Any frequencies lower than this value will be removed."},
{"name": "H_FREQ", "default": 0.7, "type": float, "help": "Any frequencies higher than this value will be removed."},
{"name": "H_FREQ", "default": 0.3, "type": float, "help": "Any frequencies higher than this value will be removed."},
#{"name": "FILTER", "default": True, "type": bool, "help": "Calculate Peak Spectral Power."},
]