diff --git a/changelog.md b/changelog.md index 4dad5e5..aff804e 100644 --- a/changelog.md +++ b/changelog.md @@ -1,7 +1,19 @@ +# Version 1.1.7 + +- Fixed a bug where having both a L_FREQ and H_FREQ would cause only the L_FREQ to be used +- Changed the default H_FREQ from 0.7 to 0.3 +- Removed SECONDS_TO_STRIP from the preprocessing options +- Instead all files are trimmed up until 5 seconds before the first annotation/event in the file +- Added a PSD graph, along with 2 heart rate images to the individual participant viewer +- The PSD graph is used to help calculate the heart rate, whereas the other 2 are just for show currently +- SCI is now done using a .6hz window around the calculated heart rate compared to a window around an average heart rate +- Fixed an issue with some epochs figures not showing + + # Version 1.1.6 - Fixed Process button from appearing when no files are selected -- Fix for instand child process crash on Windows +- Fixed a bug that would cause an instant child process crash on Windows - Added L_FREQ and H_FREQ parameters for more user control over low and high pass filtering diff --git a/flares.py b/flares.py index a4265a0..5346994 100644 --- a/flares.py +++ b/flares.py @@ -28,10 +28,11 @@ import matplotlib.colors as mcolors from matplotlib.figure import Figure from matplotlib.axes import Axes from matplotlib.colors import LinearSegmentedColormap +from matplotlib.lines import Line2D import numpy as np from numpy.typing import NDArray -from numpy import float64 +from numpy import float64, floating import pandas as pd from pandas import DataFrame @@ -47,11 +48,15 @@ from statsmodels.stats.multitest import multipletests from scipy import stats from scipy.spatial.distance import cdist +from scipy.signal import welch, butter, filtfilt # type: ignore + +import pywt # type: ignore +import neurokit2 as nk # type: ignore # Backen visualization needed to be defined for pyinstaller import pyvistaqt # type: ignore -import vtkmodules.util.data_model -import vtkmodules.util.execution_model +#import vtkmodules.util.data_model +#import vtkmodules.util.execution_model # External library imports for mne from mne import ( @@ -61,7 +66,7 @@ from mne import ( ) # type: ignore from mne.source_space import SourceSpaces from mne.transforms import Transform # type: ignore -from mne.io import BaseRaw, read_raw_snirf # type: ignore +from mne.io import BaseRaw, RawArray, read_raw_snirf # type: ignore from mne.preprocessing.nirs import ( beer_lambert_law, optical_density, temporal_derivative_distribution_repair, @@ -110,7 +115,7 @@ FIXED_CATEGORY_COLORS = { AGE: float GENDER: str -SECONDS_TO_STRIP: int +# SECONDS_TO_STRIP: int DOWNSAMPLE: bool DOWNSAMPLE_FREQUENCY: int @@ -128,6 +133,15 @@ PSP_THRESHOLD: float TDDR: bool +IQR = 1.5 +HEART_RATE = True # True if heart rate should be calculated. This helps the SCI, PSP, and SNR methods to be more accurate. +SECONDS_TO_STRIP_HR =5 # Amount of seconds to temporarily strip from the data to calculate heart rate more effectively. Useful if participant removed cap while still recording. +MAX_LOW_HR = 40 # Any heart rate values lower than this will be set to this value. +MAX_HIGH_HR = 200 # Any heart rate values higher than this will be set to this value. +SMOOTHING_WINDOW_HR = 100 # Heart rate will be calculated as a rolling average over this many amount of samples. +HEART_RATE_WINDOW = 25 # Amount of BPM above and below the calculated average to use for a range of resting BPM. +SHORT_CHANNEL_THRESH = 0.018 + ENHANCE_NEGATIVE_CORRELATION: bool L_FREQ: float @@ -170,7 +184,7 @@ GROUP = "Default" REQUIRED_KEYS: dict[str, Any] = { - "SECONDS_TO_STRIP": int, + # "SECONDS_TO_STRIP": int, "DOWNSAMPLE": bool, "DOWNSAMPLE_FREQUENCY": int, @@ -1071,23 +1085,24 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp): def filter_the_data(raw_haemo): # --- STEP 5: Filtering (0.01–0.2 Hz bandpass) --- - fig_filter = raw_haemo.compute_psd(fmax=2).plot( - average=True, xscale="log", color="r", show=False, amplitude=False + fig_filter = raw_haemo.compute_psd(fmax=3).plot( + average=True, color="r", show=False, amplitude=True ) if L_FREQ == 0 and H_FREQ != 0: raw_haemo = raw_haemo.filter(l_freq=None, h_freq=H_FREQ, h_trans_bandwidth=0.02) elif L_FREQ != 0 and H_FREQ == 0: raw_haemo = raw_haemo.filter(l_freq=L_FREQ, h_freq=None, l_trans_bandwidth=0.002) - elif L_FREQ != 0 and H_FREQ == 0: + elif L_FREQ != 0 and H_FREQ != 0: raw_haemo = raw_haemo.filter(l_freq=L_FREQ, h_freq=H_FREQ, l_trans_bandwidth=0.002, h_trans_bandwidth=0.02) - + else: + print("No filter") #raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.4, h_trans_bandwidth=0.2) #raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.7, h_trans_bandwidth=0.2) #raw_haemo = raw_haemo.filter(0.005, 0.7, h_trans_bandwidth=0.02, l_trans_bandwidth=0.002) - raw_haemo.compute_psd(fmax=2).plot( - average=True, xscale="log", axes=fig_filter.axes, color="g", amplitude=False, show=False + raw_haemo.compute_psd(fmax=3).plot( + average=True, axes=fig_filter.axes, color="g", amplitude=True, show=False ) fig_raw_haemo_filter = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Filtered HbO and HbR", show=False) @@ -1119,6 +1134,8 @@ def epochs_calculations(raw_haemo, events, event_dict): # Plot for each condition fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4)) for idx, condition in enumerate(epochs.event_id.keys()): + logger.info(condition) + logger.info(idx) # Plot images for each condition fig_epochs_data = epochs[condition].plot_image( combine="mean", @@ -1127,11 +1144,15 @@ def epochs_calculations(raw_haemo, events, event_dict): ts_args=dict(ylim=dict(hbo=[-1, 1], hbr=[-1, 1])), show=False ) - for i in fig_epochs_data: + for j, fig in enumerate(fig_epochs_data): + logger.info("------------------------------------------") + logger.info(j) + logger.info(fig) + ax = fig.axes[0] original_title = ax.get_title() ax.set_title(f"{condition}: {original_title}") - fig_epochs.append((f"fig_{condition}_data_{idx}", i)) # Store with a unique name + fig_epochs.append((f"fig_{condition}_data_{idx}_{j}", fig)) # Store with a unique name # Evoked average figure for each condition evoked_avg = epochs[condition].average() @@ -2302,8 +2323,8 @@ def brain_landmarks_3d(raw_haemo: BaseRaw, show_optodes: Literal['sensors', 'lab "Brodmann.17-lh": "blue", "Brodmann.18-lh": "blue", "Brodmann.19-lh": "blue", - "Brodmann.39-lh": "purple", - "Brodmann.40-lh": "pink", + "Brodmann.39-lh": "pink", + "Brodmann.40-lh": "purple", "Brodmann.42-lh": "white", "Brodmann.44-lh": "white", "Brodmann.48-lh": "white", @@ -2666,13 +2687,13 @@ def load_snirf(file_path: str) -> tuple[BaseRaw, Figure]: raw = read_raw_snirf(file_path, preload=True, verbose=VERBOSITY) # type: ignore raw.load_data(verbose=VERBOSITY) # type: ignore - # Strip the specified amount of seconds from the start of the file - total_duration = getattr(raw, "times")[-1] - if total_duration > SECONDS_TO_STRIP: - raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore - logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.") - else: - logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.") + # # Strip the specified amount of seconds from the start of the file + # total_duration = getattr(raw, "times")[-1] + # if total_duration > SECONDS_TO_STRIP: + # raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore + # logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.") + # else: + # logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.") # If the user forcibly dropped channels, remove them now before any processing occurs # logger.info("Checking if there are channels to forcibly drop...") @@ -2863,6 +2884,399 @@ def calculate_dpf(file_path): +def iqr_threshold(coeffs: NDArray[float64], k: float = 1.5) -> floating[Any]: + + """ + Calculate the interquartile range (IQR) threshold scaled by a factor, k. + + Parameters + ---------- + coeffs : NDArray[float64] + Array of coefficients to compute the IQR from. + k : float, optional + Scaling factor for the IQR (default is 1.5). + + Returns + ------- + floating[Any] + The scaled IQR threshold value. + """ + + # Calculate the IQR + q1 = np.percentile(coeffs, 25) + q3 = np.percentile(coeffs, 75) + iqr = q3 - q1 + + return k * iqr + + + +def wavelet_iqr_denoise(signal: NDArray[float64], wavelet: str = 'db4', level: int = 3) -> NDArray[float64]: + """ + Denoises a signal using wavelet decomposition and IQR-based thresholding on detail coefficients. + + Parameters + ---------- + signal : NDArray[float64] + The input signal array to denoise. + wavelet : str, optional + The type of wavelet to use for decomposition (default is 'db4'). + level : int, optional + Decomposition level for wavelet transform (default is 3). + + Returns + ------- + NDArray[float64] + The denoised signal array, with the same length as the input. + """ + + # Decompose the signal using wavelet transform and initialize a list with approximation coefficients + coeffs: list[NDArray[float64]] = pywt.wavedec(signal, wavelet, level=level) # type: ignore + cA = coeffs[0] + denoised_coeffs = [cA] + + # Threshold detail coefficients to reduce noise + for cD in coeffs[1:]: + threshold = iqr_threshold(cD, IQR) + cD_thresh = np.sign(cD) * np.maximum(np.abs(cD) - threshold, 0.0) # np.where((cD < lower) | (cD > upper), 0, cD) + cD_thresh = cD_thresh.astype(float64) + denoised_coeffs.append(cD_thresh) + + # Reconstruct the denoised signal + denoised_signal = cast(NDArray[float64], pywt.waverec(denoised_coeffs, wavelet)) # type: ignore + return denoised_signal[:len(signal)] + + + +def calculate_and_apply_wavelet(data: BaseRaw) -> tuple[BaseRaw, Figure]: + """ + Applies a wavelet IQR denoising filter to the data and generates a plot. + + Parameters + ---------- + data : BaseRaw + The loaded data object to process. + ID : str + File name of the the snirf file that was loaded. + + Returns + ------- + tuple[BaseRaw, Figure] + - BaseRaw: The processed data object. + - Figure: The corresponding Matplotlib figure. + """ + + logger.info("Applying the wavelet filter...") + + # Denoise the data + logger.info("Denoising the data...") + loaded_data: NDArray[float64] = data.get_data(verbose=VERBOSITY) # type: ignore + denoised_data = np.zeros_like(loaded_data) + + logger.info("Calculating the IQR, decomposing the signal, and thresholding the coefficients...") + for ch in range(loaded_data.shape[0]): + denoised_data[ch, :] = wavelet_iqr_denoise(loaded_data[ch, :], wavelet='db4', level=3) + + # Reconstruct the data with the annotations + logger.info("Reconstructing the data with annotations...") + raw_with_tddr_and_wavelet = RawArray(denoised_data, cast(Info, data.info), verbose=VERBOSITY) + raw_with_tddr_and_wavelet.set_annotations(data.annotations.copy(), verbose=VERBOSITY) # type: ignore + + # Create a figure for the results + logger.info("Creating the figure...") + fig = cast(Figure, raw_with_tddr_and_wavelet.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=data.times[-1]).figure) # type: ignore + fig.suptitle(f"Wavelet for ", fontsize=16) # type: ignore + fig.subplots_adjust(top=0.92) + plt.close(fig) + + logger.info("Successfully applied the wavelet filter.") + + return raw_with_tddr_and_wavelet, fig + + + +def short_channel_processing_for_hr(data: BaseRaw, short_chans: BaseRaw | None) -> tuple[float, NDArray[float64], NDArray[float64]]: + """ + Extract and trim short-channel fNIRS signal for heart rate analysis. + + Parameters + ---------- + data : BaseRaw + The loaded data object to process. + short_chans : BaseRaw | None + Data object with only short separation channels, or None if unavailable. + + Returns + ------- + tuple[float, NDArray[float64], NDArray[float64]] + - float: Sampling frequency of the signal. + - NDArray[float64]: Trimmed short-channel signal. + - NDArray[float64]: Corresponding time values. + """ + + # Find the short channel (or best candidate) and extract signal data and sampling frequency + logger.info("Extracting the signal and calculating the sampling frequency...") + + # If a short channel exists, use it for our signal. Otherwise just take the first channel in the data + # TODO: Find a better way around this + if short_chans is not None: + signal = cast(NDArray[float64], short_chans.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore + else: + signal = cast(NDArray[float64], data.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore + + # Calculate the sampling frequency + sfreq = cast(int, data.info['sfreq']) + + # Trim start and end of the signal to remove edge artifacts + logger.info(f"Removing {SECONDS_TO_STRIP_HR} seconds from the beginning and end of the file...") + strip_samples = int(sfreq * SECONDS_TO_STRIP_HR) + signal_trimmed = signal[strip_samples:-strip_samples] + times_trimmed = cast(NDArray[float64], getattr(data, "times"))[strip_samples:-strip_samples] + + return sfreq, signal_trimmed, times_trimmed + + + +def calculate_heart_rate_neurokit(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[float64], float]: + """ + Calculate and smooth heart rate from a trimmed signal using NeuroKit. + + Parameters + ---------- + sfreq : float + Sampling frequency of the signal. + signal_trimmed : NDArray[float64] + Preprocessed and trimmed fNIRS signal. + + Returns + ------- + tuple[NDArray[float64], float] + - NDArray[float64]: Smoothed heart rate time series (BPM). + - float: Mean heart rate. + """ + + logger.info("Calculating heart rate using NeuroKit...") + + # Filter signal to isolate heart rate frequencies and detect peaks + logger.info("Filtering the signal and detecting peaks...") + signal_filtered = cast(NDArray[float64], nk.signal_filter(signal_trimmed, sampling_rate=sfreq, lowcut=0.8, highcut=2.5)) # type: ignore + peaks_dict = cast(dict[str, Any], nk.signal_findpeaks(signal_filtered)) # type: ignore + peaks = peaks_dict['Peaks'] + hr = cast(NDArray[float64], nk.signal_rate(peaks, sampling_rate=sfreq, desired_length=len(signal_trimmed))) # type: ignore + hr_clean = np.clip(hr, MAX_LOW_HR, MAX_HIGH_HR) + + # Smooth heart rate time series by replacing spikes with local rolling mean and calculate the mean + logger.info("Smoothing the signal and calculating the mean...") + hr_series = pd.Series(hr_clean) + local_median = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).median() + spikes = hr_series > (local_median + 10) + smoothed_values = hr_series.copy() + smoothed_spikes = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).mean() + smoothed_values[spikes] = smoothed_spikes[spikes] + hr_smooth_nk = cast(NDArray[float64], smoothed_values.to_numpy()) # type: ignore + mean_hr_nk = hr_smooth_nk.mean() + + logger.info("Original HR min/max: %f, %f", hr_clean.min(), hr_clean.max()) + logger.info("Smoothed HR min/max:%f, %f", hr_smooth_nk.min(), hr_smooth_nk.max()) + logger.info(f"Estimated mean HR nk: {mean_hr_nk:.1f} BPM") + + logger.info("Successfully calculated heart rate using NeuroKit.") + + return hr_smooth_nk, mean_hr_nk + + + +def calculate_heart_rate_scipy(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]: + """ + Estimate heart rate using spectral analysis on a high-pass filtered signal. + + Parameters + ---------- + sfreq : float + Sampling frequency of the input signal. + signal_trimmed : NDArray[float64] + Trimmed fNIRS signal to analyze. + + Returns + ------- + tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float] + - NDArray[floating[Any]]: Frequencies converted to beats per minute (BPM). + - NDArray[float64]: Power spectral density (PSD) of the signal. + - np.ndarray[Any, np.dtype[np.bool_]]: Boolean mask indicating frequencies within heart rate range (30-300 BPM). + - float: Estimated mean heart rate in BPM corresponding to the PSD peak within the range. + """ + + logger.info("Calculating heart rate using SciPy...") + + # Apply a high-pass Butterworth filter to remove slow trends below 0.5 Hz from the trimmed signal (actual data) + logger.info("Applying a butterworth filter...") + b, a = cast(tuple[NDArray[float64], NDArray[float64]], butter(2, 0.5 / (sfreq / 2), btype='high')) + signal_hp = cast(NDArray[float64],filtfilt(b, a, signal_trimmed)) + + # Calculate the Power Spectral Density (PSD) of the filtered signal using Welch's method + logger.info("Calculating the PSD...") + nperseg = min(len(signal_hp), 4096) + frequencies_scipy, psd_scipy = cast(tuple[NDArray[float64], NDArray[float64]], welch(signal_hp, fs=sfreq, nperseg=nperseg, noverlap=nperseg//2)) + + # Convert frequency values to beats per minute (BPM) and set a heart rate range (30-300 BPM) + logger.info("Converting to BPM...") + freq_bpm_scipy = frequencies_scipy * 60 + freq_range_scipy = (freq_bpm_scipy > 30) & (freq_bpm_scipy < 300) + + # Identify the peak frequency within the heart rate range and estimate the mean heart rate in BPM + logger.info("Finding the mean...") + peak_index = np.argmax(psd_scipy[freq_range_scipy]) + mean_hr_scipy = freq_bpm_scipy[freq_range_scipy][peak_index] + + logger.info("Successfully calculated heart rate using SciPy.") + + return freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy + + +def plot_heart_rate( + freq_bpm_scipy: NDArray[floating[Any]], + psd_scipy: NDArray[float64], + freq_range_scipy: np.ndarray[Any, np.dtype[np.bool_]], + mean_hr_scipy: float, + hr_smooth_nk: NDArray[floating[Any]], + mean_hr_nk: float, + times_trimmed: NDArray[floating[Any]], + overruled: bool +) -> tuple[Figure, Figure]: + """ + Generate plots comparing heart rate estimates from SciPy PSD and NeuroKit2. + + Parameters + ---------- + freq_bpm_scipy : NDArray[floating[Any]] + Frequencies in beats per minute from SciPy PSD analysis. + psd_scipy : NDArray[float64] + Power spectral density values corresponding to freq_bpm_scipy. + freq_range_scipy : np.ndarray[Any, np.dtype[np.bool_]] + Boolean mask indicating the heart rate frequency range used in PSD. + mean_hr_scipy : float + Mean heart rate estimated from SciPy PSD peak. + hr_smooth_nk : NDArray[floating[Any]] + Smoothed instantaneous heart rate from NeuroKit2. + mean_hr_nk : float + Mean heart rate estimated from NeuroKit2 data. + times_trimmed : NDArray[floating[Any]] + Time points corresponding to hr_smooth_nk values. + overruled: bool + True if the heart rate from NeuroKit2 is overriding the results from the PSD. + + Returns + ------- + tuple[Figure, Figure] + - Figure showing the PSD and SciPy heart rate estimate. + - Figure showing the time series comparison of heart rates. + """ + + # Create the first plot for the PSD. Add a yellow range to show what we will be filtering to. + logger.info("Creating the figure...") + fig1, ax1 = plt.subplots(figsize=(10, 5)) # type: ignore + ax1.set_xlim(30, 300) + ax1.plot(freq_bpm_scipy[freq_range_scipy], psd_scipy[freq_range_scipy]) # type: ignore + ax1.axvline(x=mean_hr_scipy, color='red', linestyle='--', label=f'Mean HR: {mean_hr_scipy:.1f} BPM') # type: ignore + ax1.axvspan(min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW), max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW), color='yellow', alpha=0.3, label=f'HR Range ±{HEART_RATE_WINDOW} BPM') # type: ignore + ax1.set_xlabel('Heart Rate (BPM)') # type: ignore + ax1.set_ylabel('Power Spectral Density') # type: ignore + ax1.set_title('PSD of fNIRS signal - Peak indicates Heart Rate') # type: ignore + ax1.grid(True) # type: ignore + + # Was the value we reported here correct for the data on the graph or was it overruled? + if overruled: + note = ( + '\n' + 'Note: Calculation was bad!\n' + 'Data has been set to match\n' + 'the value from NeuroKit2.' + ) + phantom = Line2D([0], [0], color='none', label=note) + handles, _ = ax1.get_legend_handles_labels() + ax1.legend(handles=handles + [phantom]) # type: ignore + + else: + ax1.legend() # type: ignore + plt.close(fig1) + + # Create the second plot showing the rolling heart rate, as well as the two averages that were calculated + logger.info("Creating the figure...") + fig2, ax2 = plt.subplots(figsize=(14, 6)) # type: ignore + ax2.plot(times_trimmed, hr_smooth_nk, label='Instantaneous HR (NeuroKit2)', color='blue', alpha=0.7) # type: ignore + ax2.axhline(mean_hr_nk, color='red', linestyle='--', label=f'Mean HR NeuroKit2: {mean_hr_nk:.1f} BPM') # type: ignore + ax2.axhline(mean_hr_scipy, color='orange', linestyle=':', label=f'SciPy Welch PSD (HP filtered): {mean_hr_scipy:.1f} BPM') # type: ignore + ax2.set_xlabel('Time (seconds)') # type: ignore + ax2.set_ylabel('Heart Rate (BPM)') # type: ignore + ax2.set_title('Heart Rate Estimates Comparison') # type: ignore + ax2.legend() # type: ignore + ax2.grid(True) # type: ignore + fig2.tight_layout() + plt.close(fig2) + + return fig1, fig2 + + + +def hr_calc(raw): + if SHORT_CHANNEL: + short_chans = get_short_channels(raw, max_dist=SHORT_CHANNEL_THRESH) + else: + short_chans = None + sfreq, signal_trimmed, times_trimmed = short_channel_processing_for_hr(raw, short_chans) + hr_smooth_nk, mean_hr_nk = calculate_heart_rate_neurokit(sfreq, signal_trimmed) + freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy = calculate_heart_rate_scipy(sfreq, signal_trimmed) + + # HACK: This sucks but looking at the graphs I trust neurokit2 more + overruled = False + if mean_hr_scipy < mean_hr_nk - 15: + mean_hr_scipy = mean_hr_nk + overruled = True + if mean_hr_scipy > mean_hr_nk + 15: + mean_hr_scipy = mean_hr_nk + overruled = True + + hr1, hr2 = plot_heart_rate(freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy, hr_smooth_nk, mean_hr_nk, times_trimmed, overruled) + + fig = raw.plot_psd(show=False) + raw_filtered = raw.copy().filter(0.5, 3, fir_design='firwin') + sfreq = raw.info['sfreq'] + data = raw_filtered.get_data() + channel_names = raw.ch_names + + # --- Parameters for PSD --- + desired_bin_hz = 0.1 + nperseg = int(sfreq / desired_bin_hz) + hr_range = (30, 180) + + # --- Function to find strongest local peak --- + def find_hr_from_psd(ch_data): + f, Pxx = welch(ch_data, sfreq, nperseg=nperseg) + mask = (f >= hr_range[0]/60) & (f <= hr_range[1]/60) + f_masked = f[mask] + Pxx_masked = Pxx[mask] + if len(Pxx_masked) < 3: + return np.nan + peaks = [i for i in range(1, len(Pxx_masked)-1) + if Pxx_masked[i] > Pxx_masked[i-1] and Pxx_masked[i] > Pxx_masked[i+1]] + if not peaks: + return np.nan + best_idx = peaks[np.argmax([Pxx_masked[i] for i in peaks])] + return f_masked[best_idx] * 60 # bpm + + # --- Compute HR across all channels --- + hr_all_channels = np.array([find_hr_from_psd(data[i, :]) for i in range(len(channel_names))]) + hr_all_channels = hr_all_channels[~np.isnan(hr_all_channels)] + hr_mode = np.round(np.median(hr_all_channels)) # Use median if some NaNs + + print(f"Estimated Heart Rate: {hr_mode} bpm") + + hr_freq = hr_mode / 60 # Hz + low = hr_freq - 0.3 + high = hr_freq + 0.3 + return fig, hr1, hr2, low, high + + def process_participant(file_path, progress_callback=None): fig_individual: dict[str, Figure] = {} @@ -2873,6 +3287,35 @@ def process_participant(file_path, progress_callback=None): fig_individual["Loaded Raw"] = fig_raw if progress_callback: progress_callback(1) logger.info("1") + + + if hasattr(raw, 'annotations') and len(raw.annotations) > 0: + # Get time of first event + first_event_time = raw.annotations.onset[0] + trim_time = max(0, first_event_time - 5.0) # Ensure we don't go negative + raw.crop(tmin=trim_time) + # Shift annotation onsets to match new t=0 + import mne + + ann = raw.annotations + ann_shifted = mne.Annotations( + onset=ann.onset - trim_time, # shift to start at zero + duration=ann.duration, + description=ann.description + ) + data = raw.get_data() + info = raw.info.copy() + raw = mne.io.RawArray(data, info) + raw.set_annotations(ann_shifted) + + logger.info(f"Trimmed raw data: start at {trim_time}s (5s before first event), t=0 at new start") + else: + logger.warning("No events found, skipping trim step.") + + fig_trimmed = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Trimmed Raw", show=False) + fig_individual["Trimmed Raw"] = fig_trimmed + if progress_callback: progress_callback(2) + logger.info("2") # Step 1.5: Verify optode positions fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False) # type: ignore @@ -2884,9 +3327,18 @@ def process_participant(file_path, progress_callback=None): # raw = raw.resample(0.5) # Downsample to 0.5 Hz # Step 2: Bad from SCI + if True: + fig, hr1, hr2, low, high = hr_calc(raw) + fig_individual["PSD"] = fig + fig_individual['HeartRate_PSD'] = hr1 + fig_individual['HeartRate_Time'] = hr2 + if progress_callback: progress_callback(10) + + if progress_callback: progress_callback(2) + bad_sci = [] if SCI: - bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw) + bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw, low, high) fig_individual["SCI1"] = fig_sci_1 fig_individual["SCI2"] = fig_sci_2 if progress_callback: progress_callback(3) @@ -2938,6 +3390,12 @@ def process_participant(file_path, progress_callback=None): fig_individual["TDDR"] = fig_raw_od_tddr if progress_callback: progress_callback(9) logger.info("9") + + + raw_od, fig = calculate_and_apply_wavelet(raw_od) + fig_individual["Wavelet"] = fig + if progress_callback: progress_callback(9) + # Step 8: BLL raw_haemo = beer_lambert_law(raw_od, ppf=calculate_dpf(file_path)) diff --git a/main.py b/main.py index bf305c1..8b6e9ec 100644 --- a/main.py +++ b/main.py @@ -58,7 +58,7 @@ SECTIONS = [ { "title": "Preprocessing", "params": [ - {"name": "SECONDS_TO_STRIP", "default": 0, "type": int, "help": "Seconds to remove from beginning of all loaded snirf files. Setting this to 0 will remove nothing from the files."}, + # {"name": "SECONDS_TO_STRIP", "default": 0, "type": int, "help": "Seconds to remove from beginning of all loaded snirf files. Setting this to 0 will remove nothing from the files."}, {"name": "DOWNSAMPLE", "default": True, "type": bool, "help": "Should the snirf files be downsampled? If this is set to True, DOWNSAMPLE_FREQUENCY will be used as the target frequency to downsample to."}, {"name": "DOWNSAMPLE_FREQUENCY", "default": 25, "type": int, "help": "Frequency (Hz) to downsample to. If this is set higher than the input data, new data will be interpolated. Only used if DOWNSAMPLE is set to True"}, ] @@ -124,7 +124,7 @@ SECTIONS = [ "title": "Filtering", "params": [ {"name": "L_FREQ", "default": 0.005, "type": float, "help": "Any frequencies lower than this value will be removed."}, - {"name": "H_FREQ", "default": 0.7, "type": float, "help": "Any frequencies higher than this value will be removed."}, + {"name": "H_FREQ", "default": 0.3, "type": float, "help": "Any frequencies higher than this value will be removed."}, #{"name": "FILTER", "default": True, "type": bool, "help": "Calculate Peak Spectral Power."}, ]