Compare commits
2 Commits
45c6176dba...v1.1.7
| Author | SHA1 | Date |
|---|---|---|
| | 64ed6d2e87 | |
| | 1aa2402d09 | |
changelog.md (16 changed lines)
@@ -1,7 +1,21 @@
+# Version 1.1.7
+
+- Fixed a bug where having both a L_FREQ and H_FREQ would cause only the L_FREQ to be used
+- Changed the default H_FREQ from 0.7 to 0.3
+- Added a PSD graph, along with 2 heart rate images to the individual participant viewer
+  - The PSD graph is used to help calculate the heart rate, whereas the other 2 are currently just for show
+- SCI is now done using a .6hz window around the calculated heart rate compared to a window around an average heart rate
+- Fixed an issue with some epochs figures not showing under the participant analysis
+- Removed SECONDS_TO_STRIP from the preprocessing options
+- Added new parameters to the right side of the screen
+  - These parameters include TRIM, SECONDS_TO_KEEP, OPTODE_PLACEMENT, HEART_RATE, WAVELET, IQR, WAVELET_TYPE, WAVELET_LEVEL, ENHANCE_NEGATIVE_CORRELATION, SHORT_CHANNEL_THRESH, LONG_CHANNEL_THRESH, and DRIFT_MODEL
+- Changed number of rectangles in the progress bar to 25 to account for the new options
+
+
 # Version 1.1.6
 
 - Fixed Process button from appearing when no files are selected
-- Fix for instand child process crash on Windows
+- Fixed a bug that would cause an instant child process crash on Windows
 - Added L_FREQ and H_FREQ parameters for more user control over low and high pass filtering
flares.py (664 changed lines)
@@ -28,10 +28,11 @@ import matplotlib.colors as mcolors
 from matplotlib.figure import Figure
 from matplotlib.axes import Axes
 from matplotlib.colors import LinearSegmentedColormap
+from matplotlib.lines import Line2D
 
 import numpy as np
 from numpy.typing import NDArray
-from numpy import float64
+from numpy import float64, floating
 
 import pandas as pd
 from pandas import DataFrame
@@ -47,11 +48,15 @@ from statsmodels.stats.multitest import multipletests
 
 from scipy import stats
 from scipy.spatial.distance import cdist
+from scipy.signal import welch, butter, filtfilt # type: ignore
+
+import pywt # type: ignore
+import neurokit2 as nk # type: ignore
 
 # Backen visualization needed to be defined for pyinstaller
 import pyvistaqt # type: ignore
-import vtkmodules.util.data_model
-import vtkmodules.util.execution_model
+#import vtkmodules.util.data_model
+#import vtkmodules.util.execution_model
 
 # External library imports for mne
 from mne import (
@@ -61,7 +66,7 @@ from mne import (
 ) # type: ignore
 from mne.source_space import SourceSpaces
 from mne.transforms import Transform # type: ignore
-from mne.io import BaseRaw, read_raw_snirf # type: ignore
+from mne.io import BaseRaw, RawArray, read_raw_snirf # type: ignore
 from mne.preprocessing.nirs import (
     beer_lambert_law, optical_density,
     temporal_derivative_distribution_repair,
@@ -110,10 +115,17 @@ FIXED_CATEGORY_COLORS = {
 AGE: float
 GENDER: str
 
-SECONDS_TO_STRIP: int
+# SECONDS_TO_STRIP: int
 DOWNSAMPLE: bool
 DOWNSAMPLE_FREQUENCY: int
 
+TRIM: bool
+SECONDS_TO_KEEP: float
+
+OPTODE_PLACEMENT: bool
+
+HEART_RATE: bool
+
 SCI: bool
 SCI_TIME_WINDOW: int
 SCI_THRESHOLD: float
@@ -128,18 +140,35 @@ PSP_THRESHOLD: float
 
 TDDR: bool
 
+WAVELET: bool
+IQR: float
+WAVELET_TYPE: str
+WAVELET_LEVEL: int
+
+HEART_RATE = True # True if heart rate should be calculated. This helps the SCI, PSP, and SNR methods to be more accurate.
+SECONDS_TO_STRIP_HR =5 # Amount of seconds to temporarily strip from the data to calculate heart rate more effectively. Useful if participant removed cap while still recording.
+MAX_LOW_HR = 40 # Any heart rate values lower than this will be set to this value.
+MAX_HIGH_HR = 200 # Any heart rate values higher than this will be set to this value.
+SMOOTHING_WINDOW_HR = 100 # Heart rate will be calculated as a rolling average over this many amount of samples.
+HEART_RATE_WINDOW = 25 # Amount of BPM above and below the calculated average to use for a range of resting BPM.
+
+ENHANCE_NEGATIVE_CORRELATION: bool
 
+FILTER: bool
 L_FREQ: float
 H_FREQ: float
 
 SHORT_CHANNEL: bool
+SHORT_CHANNEL_THRESH: float
+LONG_CHANNEL_THRESH: float
 
 REMOVE_EVENTS: list
 
 TIME_WINDOW_START: int
 TIME_WINDOW_END: int
 
+DRIFT_MODEL: str
+
 VERBOSITY = True
 
 # FIXME: Shouldn't need each ordering - just order it before checking
@@ -169,11 +198,17 @@ GROUP = "Default"
 
 REQUIRED_KEYS: dict[str, Any] = {
 
-    "SECONDS_TO_STRIP": int,
+    # "SECONDS_TO_STRIP": int,
     "DOWNSAMPLE": bool,
     "DOWNSAMPLE_FREQUENCY": int,
 
+    "TRIM": bool,
+    "SECONDS_TO_KEEP": float,
+
+    "OPTODE_PLACEMENT": bool,
+
+    "HEART_RATE": bool,
+
     "SCI": bool,
     "SCI_TIME_WINDOW": int,
     "SCI_THRESHOLD": float,
@@ -187,11 +222,23 @@ REQUIRED_KEYS: dict[str, Any] = {
     "PSP_THRESHOLD": float,
 
     "SHORT_CHANNEL": bool,
+    "SHORT_CHANNEL_THRESH": float,
+    "LONG_CHANNEL_THRESH": float,
 
     "REMOVE_EVENTS": list,
     "TIME_WINDOW_START": int,
     "TIME_WINDOW_END": int,
     "L_FREQ": float,
     "H_FREQ": float,
 
     "TDDR": bool,
+    "WAVELET": bool,
+    "IQR": float,
+    "WAVELET_TYPE": str,
+    "WAVELET_LEVEL": int,
+    "FILTER": bool,
+    "DRIFT_MODEL": str,
     # "REJECT_PAIRS": bool,
     # "FORCE_DROP_ANNOTATIONS": list,
     # "FILTER_LOW_PASS": float,
@@ -1071,28 +1118,29 @@ def mark_bads(raw, bad_sci, bad_snr, bad_psp):
 
 def filter_the_data(raw_haemo):
     # --- STEP 5: Filtering (0.01–0.2 Hz bandpass) ---
-    fig_filter = raw_haemo.compute_psd(fmax=2).plot(
-        average=True, xscale="log", color="r", show=False, amplitude=False
+    fig_filter = raw_haemo.compute_psd(fmax=3).plot(
+        average=True, color="r", show=False, amplitude=True
     )
 
     if L_FREQ == 0 and H_FREQ != 0:
         raw_haemo = raw_haemo.filter(l_freq=None, h_freq=H_FREQ, h_trans_bandwidth=0.02)
     elif L_FREQ != 0 and H_FREQ == 0:
         raw_haemo = raw_haemo.filter(l_freq=L_FREQ, h_freq=None, l_trans_bandwidth=0.002)
-    elif L_FREQ != 0 and H_FREQ == 0:
+    elif L_FREQ != 0 and H_FREQ != 0:
         raw_haemo = raw_haemo.filter(l_freq=L_FREQ, h_freq=H_FREQ, l_trans_bandwidth=0.002, h_trans_bandwidth=0.02)
 
     else:
         print("No filter")
     #raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.4, h_trans_bandwidth=0.2)
     #raw_haemo = raw_haemo.filter(l_freq=None, h_freq=0.7, h_trans_bandwidth=0.2)
     #raw_haemo = raw_haemo.filter(0.005, 0.7, h_trans_bandwidth=0.02, l_trans_bandwidth=0.002)
 
-    raw_haemo.compute_psd(fmax=2).plot(
-        average=True, xscale="log", axes=fig_filter.axes, color="g", amplitude=False, show=False
+    raw_haemo.compute_psd(fmax=3).plot(
+        average=True, axes=fig_filter.axes, color="g", amplitude=True, show=False
     )
 
     fig_raw_haemo_filter = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Filtered HbO and HbR", show=False)
 
-    return fig_filter, fig_raw_haemo_filter
+    return raw_haemo, fig_filter, fig_raw_haemo_filter
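Since MNE's `Raw.filter` leaves an edge of the band open when it receives `None`, the fixed branch chain maps 0-valued cutoffs onto the right argument pair. A minimal standalone sketch of that dispatch (hypothetical helper, not part of the commit):

```python
# Hypothetical helper mirroring the fixed branch logic: a cutoff of 0 means
# "disabled", and MNE's Raw.filter() leaves an edge open when it gets None.
def band_limits(l_freq: float, h_freq: float) -> tuple[float | None, float | None] | None:
    if l_freq == 0 and h_freq == 0:
        return None                       # no filtering at all
    return (l_freq or None, h_freq or None)

assert band_limits(0.005, 0.3) == (0.005, 0.3)   # band-pass (v1.1.7 defaults)
assert band_limits(0.0, 0.3) == (None, 0.3)      # low-pass only
assert band_limits(0.005, 0.0) == (0.005, None)  # high-pass only
assert band_limits(0.0, 0.0) is None             # the "No filter" branch
```

The old duplicated `elif L_FREQ != 0 and H_FREQ == 0` meant the both-nonzero case never reached the band-pass call, which is the bug named in the changelog.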
@@ -1119,6 +1167,8 @@ def epochs_calculations(raw_haemo, events, event_dict):
     # Plot for each condition
     fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
     for idx, condition in enumerate(epochs.event_id.keys()):
+        logger.info(condition)
+        logger.info(idx)
         # Plot images for each condition
         fig_epochs_data = epochs[condition].plot_image(
             combine="mean",
@@ -1127,11 +1177,15 @@ def epochs_calculations(raw_haemo, events, event_dict):
             ts_args=dict(ylim=dict(hbo=[-1, 1], hbr=[-1, 1])),
             show=False
         )
-        for i in fig_epochs_data:
-            fig_epochs.append((f"fig_{condition}_data_{idx}", i)) # Store with a unique name
+        for j, fig in enumerate(fig_epochs_data):
+            logger.info("------------------------------------------")
+            logger.info(j)
+            logger.info(fig)
+
+            ax = fig.axes[0]
+            original_title = ax.get_title()
+            ax.set_title(f"{condition}: {original_title}")
+            fig_epochs.append((f"fig_{condition}_data_{idx}_{j}", fig)) # Store with a unique name
 
         # Evoked average figure for each condition
         evoked_avg = epochs[condition].average()
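`plot_image` appears to hand back a list of figures (typically one per channel type), so the old loop stored every figure under the same `idx` key and later figures overwrote earlier ones, matching the "epochs figures not showing" changelog entry. A small sketch of the enumerate-and-retitle idiom on plain Matplotlib figures (hypothetical data, no MNE needed):

```python
import matplotlib.pyplot as plt

condition = "Tapping"                      # hypothetical condition name
figs = []
for ch_type in ("hbo", "hbr"):             # stand-ins for plot_image's output
    fig = plt.figure()
    fig.add_subplot(111).set_title(ch_type)
    figs.append(fig)

fig_epochs = []
for j, fig in enumerate(figs):
    ax = fig.axes[0]
    ax.set_title(f"{condition}: {ax.get_title()}")
    fig_epochs.append((f"fig_{condition}_data_0_{j}", fig))  # _{j} keeps keys unique

print([name for name, _ in fig_epochs])
```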
@@ -1263,7 +1317,7 @@ def make_design_matrix(raw_haemo, short_chans):
         hrf_model='fir',
         stim_dur=0.5,
         fir_delays=range(15),
-        drift_model='cosine',
+        drift_model=DRIFT_MODEL,
         high_pass=0.01,
         oversampling=1,
         min_onset=-125,
@@ -1276,7 +1330,7 @@ def make_design_matrix(raw_haemo, short_chans):
         hrf_model='fir',
         stim_dur=0.5,
         fir_delays=range(15),
-        drift_model='cosine',
+        drift_model=DRIFT_MODEL,
         high_pass=0.01,
         oversampling=1,
         min_onset=-125,
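Both call sites appear to pass the same arguments, so the only behavioural change is that the drift model is now user-configurable. A hedged sketch of the call, assuming mne-nirs' `make_first_level_design_matrix` and a `raw_haemo` object coming from the pipeline ('cosine' and 'polynomial' are the usual nilearn-style drift models):

```python
from mne_nirs.experimental_design import make_first_level_design_matrix

DRIFT_MODEL = "cosine"   # now read from the new right-hand-side parameter

# raw_haemo is assumed to come from the pipeline above.
design_matrix = make_first_level_design_matrix(
    raw_haemo,
    hrf_model='fir',
    stim_dur=0.5,
    fir_delays=range(15),
    drift_model=DRIFT_MODEL,   # previously hard-coded to 'cosine'
    high_pass=0.01,
    oversampling=1,
    min_onset=-125,
)
```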
@@ -2302,8 +2356,8 @@ def brain_landmarks_3d(raw_haemo: BaseRaw, show_optodes: Literal['sensors', 'lab
     "Brodmann.17-lh": "blue",
     "Brodmann.18-lh": "blue",
     "Brodmann.19-lh": "blue",
-    "Brodmann.39-lh": "purple",
-    "Brodmann.40-lh": "pink",
+    "Brodmann.39-lh": "pink",
+    "Brodmann.40-lh": "purple",
     "Brodmann.42-lh": "white",
     "Brodmann.44-lh": "white",
     "Brodmann.48-lh": "white",
@@ -2666,13 +2720,13 @@ def load_snirf(file_path: str) -> tuple[BaseRaw, Figure]:
     raw = read_raw_snirf(file_path, preload=True, verbose=VERBOSITY) # type: ignore
     raw.load_data(verbose=VERBOSITY) # type: ignore
 
-    # Strip the specified amount of seconds from the start of the file
-    total_duration = getattr(raw, "times")[-1]
-    if total_duration > SECONDS_TO_STRIP:
-        raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore
-        logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.")
-    else:
-        logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.")
+    # # Strip the specified amount of seconds from the start of the file
+    # total_duration = getattr(raw, "times")[-1]
+    # if total_duration > SECONDS_TO_STRIP:
+    #     raw.crop(tmin=SECONDS_TO_STRIP, tmax=total_duration, verbose=VERBOSITY) # type: ignore
+    #     logger.info(f"Stripped first {SECONDS_TO_STRIP} second(s) of data.")
+    # else:
+    #     logger.info(f"Data length ({total_duration:.2f}s) less than strip duration; no cropping applied.")
 
     # If the user forcibly dropped channels, remove them now before any processing occurs
     # logger.info("Checking if there are channels to forcibly drop...")
@@ -2863,6 +2917,399 @@ def calculate_dpf(file_path):
 
 
+def iqr_threshold(coeffs: NDArray[float64], k: float = 1.5) -> floating[Any]:
+
+    """
+    Calculate the interquartile range (IQR) threshold scaled by a factor, k.
+
+    Parameters
+    ----------
+    coeffs : NDArray[float64]
+        Array of coefficients to compute the IQR from.
+    k : float, optional
+        Scaling factor for the IQR (default is 1.5).
+
+    Returns
+    -------
+    floating[Any]
+        The scaled IQR threshold value.
+    """
+
+    # Calculate the IQR
+    q1 = np.percentile(coeffs, 25)
+    q3 = np.percentile(coeffs, 75)
+    iqr = q3 - q1
+
+    return k * iqr
+
+
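For intuition: the threshold tracks the spread of the coefficients, not their absolute scale. For standard-normal data the IQR is about 1.35, so the default k = 1.5 yields a threshold near 2.0:

```python
import numpy as np

rng = np.random.default_rng(0)
coeffs = rng.normal(size=10_000)     # synthetic detail coefficients
q1, q3 = np.percentile(coeffs, [25, 75])
print(1.5 * (q3 - q1))               # roughly 2.0 for N(0, 1) input
```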
+def wavelet_iqr_denoise(signal: NDArray[float64], wavelet: str = 'db4', level: int = 3) -> NDArray[float64]:
+    """
+    Denoises a signal using wavelet decomposition and IQR-based thresholding on detail coefficients.
+
+    Parameters
+    ----------
+    signal : NDArray[float64]
+        The input signal array to denoise.
+    wavelet : str, optional
+        The type of wavelet to use for decomposition (default is 'db4').
+    level : int, optional
+        Decomposition level for wavelet transform (default is 3).
+
+    Returns
+    -------
+    NDArray[float64]
+        The denoised signal array, with the same length as the input.
+    """
+
+    # Decompose the signal using wavelet transform and initialize a list with approximation coefficients
+    coeffs: list[NDArray[float64]] = pywt.wavedec(signal, wavelet, level=level) # type: ignore
+    cA = coeffs[0]
+    denoised_coeffs = [cA]
+
+    # Threshold detail coefficients to reduce noise
+    for cD in coeffs[1:]:
+        threshold = iqr_threshold(cD, IQR)
+        cD_thresh = np.sign(cD) * np.maximum(np.abs(cD) - threshold, 0.0) # np.where((cD < lower) | (cD > upper), 0, cD)
+        cD_thresh = cD_thresh.astype(float64)
+        denoised_coeffs.append(cD_thresh)
+
+    # Reconstruct the denoised signal
+    denoised_signal = cast(NDArray[float64], pywt.waverec(denoised_coeffs, wavelet)) # type: ignore
+    return denoised_signal[:len(signal)]
+
+
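A self-contained sketch of the same decompose / soft-threshold / reconstruct chain on a synthetic noisy sine (db4 at level 3, matching the defaults; a fixed k = 1.5 stands in for the configurable IQR):

```python
import numpy as np
import pywt

t = np.linspace(0, 1, 1024)
truth = np.sin(2 * np.pi * 5 * t)
noisy = truth + 0.3 * np.random.default_rng(1).normal(size=t.size)

coeffs = pywt.wavedec(noisy, 'db4', level=3)
denoised = [coeffs[0]]                                     # keep approximation
for cD in coeffs[1:]:
    thr = 1.5 * np.subtract(*np.percentile(cD, [75, 25]))  # k * IQR
    denoised.append(np.sign(cD) * np.maximum(np.abs(cD) - thr, 0.0))

clean = pywt.waverec(denoised, 'db4')[:noisy.size]
print(np.std(noisy - truth), np.std(clean - truth))        # residual noise shrinks
```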
+def calculate_and_apply_wavelet(data: BaseRaw) -> tuple[BaseRaw, Figure]:
+    """
+    Applies a wavelet IQR denoising filter to the data and generates a plot.
+
+    Parameters
+    ----------
+    data : BaseRaw
+        The loaded data object to process.
+    ID : str
+        File name of the the snirf file that was loaded.
+
+    Returns
+    -------
+    tuple[BaseRaw, Figure]
+        - BaseRaw: The processed data object.
+        - Figure: The corresponding Matplotlib figure.
+    """
+
+    logger.info("Applying the wavelet filter...")
+
+    # Denoise the data
+    logger.info("Denoising the data...")
+    loaded_data: NDArray[float64] = data.get_data(verbose=VERBOSITY) # type: ignore
+    denoised_data = np.zeros_like(loaded_data)
+
+    logger.info("Calculating the IQR, decomposing the signal, and thresholding the coefficients...")
+    for ch in range(loaded_data.shape[0]):
+        denoised_data[ch, :] = wavelet_iqr_denoise(loaded_data[ch, :], wavelet=WAVELET_TYPE, level=WAVELET_LEVEL)
+
+    # Reconstruct the data with the annotations
+    logger.info("Reconstructing the data with annotations...")
+    raw_with_tddr_and_wavelet = RawArray(denoised_data, cast(Info, data.info), verbose=VERBOSITY)
+    raw_with_tddr_and_wavelet.set_annotations(data.annotations.copy(), verbose=VERBOSITY) # type: ignore
+
+    # Create a figure for the results
+    logger.info("Creating the figure...")
+    fig = cast(Figure, raw_with_tddr_and_wavelet.plot(show=False, n_channels=len(getattr(data, "ch_names")), duration=data.times[-1]).figure) # type: ignore
+    fig.suptitle(f"Wavelet for ", fontsize=16) # type: ignore
+    fig.subplots_adjust(top=0.92)
+    plt.close(fig)
+
+    logger.info("Successfully applied the wavelet filter.")
+
+    return raw_with_tddr_and_wavelet, fig
+
+
+def short_channel_processing_for_hr(data: BaseRaw, short_chans: BaseRaw | None) -> tuple[float, NDArray[float64], NDArray[float64]]:
+    """
+    Extract and trim short-channel fNIRS signal for heart rate analysis.
+
+    Parameters
+    ----------
+    data : BaseRaw
+        The loaded data object to process.
+    short_chans : BaseRaw | None
+        Data object with only short separation channels, or None if unavailable.
+
+    Returns
+    -------
+    tuple[float, NDArray[float64], NDArray[float64]]
+        - float: Sampling frequency of the signal.
+        - NDArray[float64]: Trimmed short-channel signal.
+        - NDArray[float64]: Corresponding time values.
+    """
+
+    # Find the short channel (or best candidate) and extract signal data and sampling frequency
+    logger.info("Extracting the signal and calculating the sampling frequency...")
+
+    # If a short channel exists, use it for our signal. Otherwise just take the first channel in the data
+    # TODO: Find a better way around this
+    if short_chans is not None:
+        signal = cast(NDArray[float64], short_chans.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore
+    else:
+        signal = cast(NDArray[float64], data.get_data(picks=[0], verbose=VERBOSITY))[0] # type: ignore
+
+    # Calculate the sampling frequency
+    sfreq = cast(int, data.info['sfreq'])
+
+    # Trim start and end of the signal to remove edge artifacts
+    logger.info(f"Removing {SECONDS_TO_STRIP_HR} seconds from the beginning and end of the file...")
+    strip_samples = int(sfreq * SECONDS_TO_STRIP_HR)
+    signal_trimmed = signal[strip_samples:-strip_samples]
+    times_trimmed = cast(NDArray[float64], getattr(data, "times"))[strip_samples:-strip_samples]
+
+    return sfreq, signal_trimmed, times_trimmed
+
+
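Worked numbers for the edge trim: at a 50 Hz sampling rate (value assumed for illustration), the default `SECONDS_TO_STRIP_HR = 5` drops 250 samples from each end:

```python
import numpy as np

sfreq, SECONDS_TO_STRIP_HR = 50.0, 5               # assumed values
signal = np.zeros(60 * int(sfreq))                 # hypothetical 60 s recording
strip_samples = int(sfreq * SECONDS_TO_STRIP_HR)   # 250 samples
trimmed = signal[strip_samples:-strip_samples]
print(len(signal) / sfreq, len(trimmed) / sfreq)   # 60.0 s -> 50.0 s
```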
+def calculate_heart_rate_neurokit(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[float64], float]:
+    """
+    Calculate and smooth heart rate from a trimmed signal using NeuroKit.
+
+    Parameters
+    ----------
+    sfreq : float
+        Sampling frequency of the signal.
+    signal_trimmed : NDArray[float64]
+        Preprocessed and trimmed fNIRS signal.
+
+    Returns
+    -------
+    tuple[NDArray[float64], float]
+        - NDArray[float64]: Smoothed heart rate time series (BPM).
+        - float: Mean heart rate.
+    """
+
+    logger.info("Calculating heart rate using NeuroKit...")
+
+    # Filter signal to isolate heart rate frequencies and detect peaks
+    logger.info("Filtering the signal and detecting peaks...")
+    signal_filtered = cast(NDArray[float64], nk.signal_filter(signal_trimmed, sampling_rate=sfreq, lowcut=0.8, highcut=2.5)) # type: ignore
+    peaks_dict = cast(dict[str, Any], nk.signal_findpeaks(signal_filtered)) # type: ignore
+    peaks = peaks_dict['Peaks']
+    hr = cast(NDArray[float64], nk.signal_rate(peaks, sampling_rate=sfreq, desired_length=len(signal_trimmed))) # type: ignore
+    hr_clean = np.clip(hr, MAX_LOW_HR, MAX_HIGH_HR)
+
+    # Smooth heart rate time series by replacing spikes with local rolling mean and calculate the mean
+    logger.info("Smoothing the signal and calculating the mean...")
+    hr_series = pd.Series(hr_clean)
+    local_median = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).median()
+    spikes = hr_series > (local_median + 10)
+    smoothed_values = hr_series.copy()
+    smoothed_spikes = hr_series.rolling(window=SMOOTHING_WINDOW_HR, center=True, min_periods=1).mean()
+    smoothed_values[spikes] = smoothed_spikes[spikes]
+    hr_smooth_nk = cast(NDArray[float64], smoothed_values.to_numpy()) # type: ignore
+    mean_hr_nk = hr_smooth_nk.mean()
+
+    logger.info("Original HR min/max: %f, %f", hr_clean.min(), hr_clean.max())
+    logger.info("Smoothed HR min/max:%f, %f", hr_smooth_nk.min(), hr_smooth_nk.max())
+    logger.info(f"Estimated mean HR nk: {mean_hr_nk:.1f} BPM")
+
+    logger.info("Successfully calculated heart rate using NeuroKit.")
+
+    return hr_smooth_nk, mean_hr_nk
+
+
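A sketch of the same filter, find-peaks, rate chain recovering a known 72 BPM pulse from a synthetic signal (NeuroKit2 assumed installed, as in the imports above):

```python
import numpy as np
import neurokit2 as nk  # type: ignore

fs = 50.0
t = np.arange(0, 60, 1 / fs)
pulse = np.sin(2 * np.pi * (72 / 60) * t)                  # 1.2 Hz, i.e. 72 BPM
pulse += 0.1 * np.random.default_rng(2).normal(size=t.size)

filtered = nk.signal_filter(pulse, sampling_rate=fs, lowcut=0.8, highcut=2.5)
peaks = nk.signal_findpeaks(filtered)["Peaks"]
hr = nk.signal_rate(peaks, sampling_rate=fs, desired_length=len(pulse))
print(np.median(hr))                                       # close to 72
```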
+def calculate_heart_rate_scipy(sfreq: float, signal_trimmed: NDArray[float64]) -> tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]:
+    """
+    Estimate heart rate using spectral analysis on a high-pass filtered signal.
+
+    Parameters
+    ----------
+    sfreq : float
+        Sampling frequency of the input signal.
+    signal_trimmed : NDArray[float64]
+        Trimmed fNIRS signal to analyze.
+
+    Returns
+    -------
+    tuple[NDArray[floating[Any]], NDArray[float64], np.ndarray[Any, np.dtype[np.bool_]], float]
+        - NDArray[floating[Any]]: Frequencies converted to beats per minute (BPM).
+        - NDArray[float64]: Power spectral density (PSD) of the signal.
+        - np.ndarray[Any, np.dtype[np.bool_]]: Boolean mask indicating frequencies within heart rate range (30-300 BPM).
+        - float: Estimated mean heart rate in BPM corresponding to the PSD peak within the range.
+    """
+
+    logger.info("Calculating heart rate using SciPy...")
+
+    # Apply a high-pass Butterworth filter to remove slow trends below 0.5 Hz from the trimmed signal (actual data)
+    logger.info("Applying a butterworth filter...")
+    b, a = cast(tuple[NDArray[float64], NDArray[float64]], butter(2, 0.5 / (sfreq / 2), btype='high'))
+    signal_hp = cast(NDArray[float64],filtfilt(b, a, signal_trimmed))
+
+    # Calculate the Power Spectral Density (PSD) of the filtered signal using Welch's method
+    logger.info("Calculating the PSD...")
+    nperseg = min(len(signal_hp), 4096)
+    frequencies_scipy, psd_scipy = cast(tuple[NDArray[float64], NDArray[float64]], welch(signal_hp, fs=sfreq, nperseg=nperseg, noverlap=nperseg//2))
+
+    # Convert frequency values to beats per minute (BPM) and set a heart rate range (30-300 BPM)
+    logger.info("Converting to BPM...")
+    freq_bpm_scipy = frequencies_scipy * 60
+    freq_range_scipy = (freq_bpm_scipy > 30) & (freq_bpm_scipy < 300)
+
+    # Identify the peak frequency within the heart rate range and estimate the mean heart rate in BPM
+    logger.info("Finding the mean...")
+    peak_index = np.argmax(psd_scipy[freq_range_scipy])
+    mean_hr_scipy = freq_bpm_scipy[freq_range_scipy][peak_index]
+
+    logger.info("Successfully calculated heart rate using SciPy.")
+
+    return freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy
+
+
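The same Welch-peak estimate, end to end, on a synthetic 1.2 Hz (72 BPM) signal:

```python
import numpy as np
from scipy.signal import welch, butter, filtfilt

fs = 50.0
t = np.arange(0, 120, 1 / fs)
sig = np.sin(2 * np.pi * 1.2 * t) + 0.5 * np.random.default_rng(3).normal(size=t.size)

b, a = butter(2, 0.5 / (fs / 2), btype='high')   # drop slow drifts below 0.5 Hz
sig_hp = filtfilt(b, a, sig)

nperseg = min(len(sig_hp), 4096)
f, psd = welch(sig_hp, fs=fs, nperseg=nperseg, noverlap=nperseg // 2)

bpm = f * 60
in_range = (bpm > 30) & (bpm < 300)
print(bpm[in_range][np.argmax(psd[in_range])])   # close to 72 BPM
```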
+def plot_heart_rate(
+    freq_bpm_scipy: NDArray[floating[Any]],
+    psd_scipy: NDArray[float64],
+    freq_range_scipy: np.ndarray[Any, np.dtype[np.bool_]],
+    mean_hr_scipy: float,
+    hr_smooth_nk: NDArray[floating[Any]],
+    mean_hr_nk: float,
+    times_trimmed: NDArray[floating[Any]],
+    overruled: bool
+) -> tuple[Figure, Figure]:
+    """
+    Generate plots comparing heart rate estimates from SciPy PSD and NeuroKit2.
+
+    Parameters
+    ----------
+    freq_bpm_scipy : NDArray[floating[Any]]
+        Frequencies in beats per minute from SciPy PSD analysis.
+    psd_scipy : NDArray[float64]
+        Power spectral density values corresponding to freq_bpm_scipy.
+    freq_range_scipy : np.ndarray[Any, np.dtype[np.bool_]]
+        Boolean mask indicating the heart rate frequency range used in PSD.
+    mean_hr_scipy : float
+        Mean heart rate estimated from SciPy PSD peak.
+    hr_smooth_nk : NDArray[floating[Any]]
+        Smoothed instantaneous heart rate from NeuroKit2.
+    mean_hr_nk : float
+        Mean heart rate estimated from NeuroKit2 data.
+    times_trimmed : NDArray[floating[Any]]
+        Time points corresponding to hr_smooth_nk values.
+    overruled: bool
+        True if the heart rate from NeuroKit2 is overriding the results from the PSD.
+
+    Returns
+    -------
+    tuple[Figure, Figure]
+        - Figure showing the PSD and SciPy heart rate estimate.
+        - Figure showing the time series comparison of heart rates.
+    """
+
+    # Create the first plot for the PSD. Add a yellow range to show what we will be filtering to.
+    logger.info("Creating the figure...")
+    fig1, ax1 = plt.subplots(figsize=(10, 5)) # type: ignore
+    ax1.set_xlim(30, 300)
+    ax1.plot(freq_bpm_scipy[freq_range_scipy], psd_scipy[freq_range_scipy]) # type: ignore
+    ax1.axvline(x=mean_hr_scipy, color='red', linestyle='--', label=f'Mean HR: {mean_hr_scipy:.1f} BPM') # type: ignore
+    ax1.axvspan(min(mean_hr_nk - HEART_RATE_WINDOW, mean_hr_scipy - HEART_RATE_WINDOW), max(mean_hr_nk + HEART_RATE_WINDOW, mean_hr_scipy + HEART_RATE_WINDOW), color='yellow', alpha=0.3, label=f'HR Range ±{HEART_RATE_WINDOW} BPM') # type: ignore
+    ax1.set_xlabel('Heart Rate (BPM)') # type: ignore
+    ax1.set_ylabel('Power Spectral Density') # type: ignore
+    ax1.set_title('PSD of fNIRS signal - Peak indicates Heart Rate') # type: ignore
+    ax1.grid(True) # type: ignore
+
+    # Was the value we reported here correct for the data on the graph or was it overruled?
+    if overruled:
+        note = (
+            '\n'
+            'Note: Calculation was bad!\n'
+            'Data has been set to match\n'
+            'the value from NeuroKit2.'
+        )
+        phantom = Line2D([0], [0], color='none', label=note)
+        handles, _ = ax1.get_legend_handles_labels()
+        ax1.legend(handles=handles + [phantom]) # type: ignore
+    else:
+        ax1.legend() # type: ignore
+    plt.close(fig1)
+
+    # Create the second plot showing the rolling heart rate, as well as the two averages that were calculated
+    logger.info("Creating the figure...")
+    fig2, ax2 = plt.subplots(figsize=(14, 6)) # type: ignore
+    ax2.plot(times_trimmed, hr_smooth_nk, label='Instantaneous HR (NeuroKit2)', color='blue', alpha=0.7) # type: ignore
+    ax2.axhline(mean_hr_nk, color='red', linestyle='--', label=f'Mean HR NeuroKit2: {mean_hr_nk:.1f} BPM') # type: ignore
+    ax2.axhline(mean_hr_scipy, color='orange', linestyle=':', label=f'SciPy Welch PSD (HP filtered): {mean_hr_scipy:.1f} BPM') # type: ignore
+    ax2.set_xlabel('Time (seconds)') # type: ignore
+    ax2.set_ylabel('Heart Rate (BPM)') # type: ignore
+    ax2.set_title('Heart Rate Estimates Comparison') # type: ignore
+    ax2.legend() # type: ignore
+    ax2.grid(True) # type: ignore
+    fig2.tight_layout()
+    plt.close(fig2)
+
+    return fig1, fig2
+
+
+def hr_calc(raw):
+    if SHORT_CHANNEL:
+        short_chans = get_short_channels(raw, max_dist=SHORT_CHANNEL_THRESH)
+    else:
+        short_chans = None
+    sfreq, signal_trimmed, times_trimmed = short_channel_processing_for_hr(raw, short_chans)
+    hr_smooth_nk, mean_hr_nk = calculate_heart_rate_neurokit(sfreq, signal_trimmed)
+    freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy = calculate_heart_rate_scipy(sfreq, signal_trimmed)
+
+    # HACK: This sucks but looking at the graphs I trust neurokit2 more
+    overruled = False
+    if mean_hr_scipy < mean_hr_nk - 15:
+        mean_hr_scipy = mean_hr_nk
+        overruled = True
+    if mean_hr_scipy > mean_hr_nk + 15:
+        mean_hr_scipy = mean_hr_nk
+        overruled = True
+
+    hr1, hr2 = plot_heart_rate(freq_bpm_scipy, psd_scipy, freq_range_scipy, mean_hr_scipy, hr_smooth_nk, mean_hr_nk, times_trimmed, overruled)
+
+    fig = raw.plot_psd(show=False)
+    raw_filtered = raw.copy().filter(0.5, 3, fir_design='firwin')
+    sfreq = raw.info['sfreq']
+    data = raw_filtered.get_data()
+    channel_names = raw.ch_names
+
+    # --- Parameters for PSD ---
+    desired_bin_hz = 0.1
+    nperseg = int(sfreq / desired_bin_hz)
+    hr_range = (30, 180)
+
+    # --- Function to find strongest local peak ---
+    def find_hr_from_psd(ch_data):
+        f, Pxx = welch(ch_data, sfreq, nperseg=nperseg)
+        mask = (f >= hr_range[0]/60) & (f <= hr_range[1]/60)
+        f_masked = f[mask]
+        Pxx_masked = Pxx[mask]
+        if len(Pxx_masked) < 3:
+            return np.nan
+        peaks = [i for i in range(1, len(Pxx_masked)-1)
+                 if Pxx_masked[i] > Pxx_masked[i-1] and Pxx_masked[i] > Pxx_masked[i+1]]
+        if not peaks:
+            return np.nan
+        best_idx = peaks[np.argmax([Pxx_masked[i] for i in peaks])]
+        return f_masked[best_idx] * 60 # bpm
+
+    # --- Compute HR across all channels ---
+    hr_all_channels = np.array([find_hr_from_psd(data[i, :]) for i in range(len(channel_names))])
+    hr_all_channels = hr_all_channels[~np.isnan(hr_all_channels)]
+    hr_mode = np.round(np.median(hr_all_channels)) # Use median if some NaNs
+
+    print(f"Estimated Heart Rate: {hr_mode} bpm")
+
+    hr_freq = hr_mode / 60 # Hz
+    low = hr_freq - 0.3
+    high = hr_freq + 0.3
+    return fig, hr1, hr2, low, high
+
+
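The returned `low`/`high` bracket the per-channel PSD estimate by ±0.3 Hz, the 0.6 Hz window mentioned in the changelog, and `calculate_scalp_coupling(raw, low, high)` consumes them below. A hedged sketch of how such a band would feed MNE's SCI, assuming the standard `scalp_coupling_index` signature:

```python
from mne.preprocessing.nirs import scalp_coupling_index

hr_mode = 72.0                             # hypothetical PSD estimate, in BPM
hr_freq = hr_mode / 60                     # 1.2 Hz
low, high = hr_freq - 0.3, hr_freq + 0.3   # the 0.6 Hz window

# With a loaded Raw, SCI would then be banded around the participant's own
# pulse instead of a population-average range:
# sci = scalp_coupling_index(raw, l_freq=low, h_freq=high)
# bad = [ch for ch, s in zip(raw.ch_names, sci) if s < SCI_THRESHOLD]
```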
 def process_participant(file_path, progress_callback=None):
 
     fig_individual: dict[str, Figure] = {}
@@ -2874,31 +3321,67 @@ def process_participant(file_path, progress_callback=None):
     if progress_callback: progress_callback(1)
     logger.info("1")
 
-    # Step 1.5: Verify optode positions
-    fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False) # type: ignore
-    fig_individual["Plot Sensors"] = fig_optodes
-
+    if TRIM:
+        if hasattr(raw, 'annotations') and len(raw.annotations) > 0:
+            # Get time of first event
+            first_event_time = raw.annotations.onset[0]
+            trim_time = max(0, first_event_time - SECONDS_TO_KEEP) # Ensure we don't go negative
+            raw.crop(tmin=trim_time)
+            # Shift annotation onsets to match new t=0
+            import mne
+
+            ann = raw.annotations
+            ann_shifted = mne.Annotations(
+                onset=ann.onset - trim_time, # shift to start at zero
+                duration=ann.duration,
+                description=ann.description
+            )
+            data = raw.get_data()
+            info = raw.info.copy()
+            raw = mne.io.RawArray(data, info)
+            raw.set_annotations(ann_shifted)
+
+            logger.info(f"Trimmed raw data: start at {trim_time}s (5s before first event), t=0 at new start")
+        else:
+            logger.warning("No events found, skipping trim step.")
+
+        fig_trimmed = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Trimmed Raw", show=False)
+        fig_individual["Trimmed Raw"] = fig_trimmed
+    if progress_callback: progress_callback(2)
+    logger.info("2")
 
     # Step 2: Downsample
     # raw = raw.resample(0.5) # Downsample to 0.5 Hz
 
-    # Step 2: Bad from SCI
-    bad_sci = []
-    if SCI:
-        bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw)
-        fig_individual["SCI1"] = fig_sci_1
-        fig_individual["SCI2"] = fig_sci_2
+    # Step 1.5: Verify optode positions
+    if OPTODE_PLACEMENT:
+        fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False) # type: ignore
+        fig_individual["Plot Sensors"] = fig_optodes
+    if progress_callback: progress_callback(3)
+    logger.info("3")
+
+    # Step 2: Bad from SCI
+    if HEART_RATE:
+        fig, hr1, hr2, low, high = hr_calc(raw)
+        fig_individual["PSD"] = fig
+        fig_individual['HeartRate_PSD'] = hr1
+        fig_individual['HeartRate_Time'] = hr2
+    if progress_callback: progress_callback(4)
+    logger.info("4")
+
+    bad_sci = []
+    if SCI:
+        bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw, low, high)
+        fig_individual["SCI1"] = fig_sci_1
+        fig_individual["SCI2"] = fig_sci_2
+    if progress_callback: progress_callback(5)
+    logger.info("5")
 
     # Step 2: Bad from SNR
     bad_snr = []
     if SNR:
         bad_snr, fig_snr = calculate_signal_noise_ratio(raw)
         fig_individual["SNR1"] = fig_snr
-    if progress_callback: progress_callback(4)
-    logger.info("4")
+    if progress_callback: progress_callback(6)
+    logger.info("6")
 
     # Step 3: Bad from PSP
     bad_psp = []
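Rebuilding the Raw after `crop` resets t = 0, so the TRIM block shifts every annotation onset by the crop offset. A minimal sketch of that arithmetic with bare `mne.Annotations` (hypothetical onsets):

```python
import numpy as np
import mne

SECONDS_TO_KEEP = 5.0
onsets = np.array([12.0, 40.0, 80.0])        # hypothetical event times (s)
trim_time = max(0.0, onsets[0] - SECONDS_TO_KEEP)

ann_shifted = mne.Annotations(
    onset=onsets - trim_time,                # first event lands at t = 5 s
    duration=np.zeros_like(onsets),
    description=["event"] * len(onsets),
)
print(ann_shifted.onset)                     # [ 5. 33. 73.]
```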
@@ -2906,82 +3389,94 @@ def process_participant(file_path, progress_callback=None):
         bad_psp, fig_psp1, fig_psp2 = calculate_peak_power(raw)
         fig_individual["PSP1"] = fig_psp1
         fig_individual["PSP2"] = fig_psp2
-    if progress_callback: progress_callback(5)
-    logger.info("5")
+    if progress_callback: progress_callback(7)
+    logger.info("7")
 
     # Step 4: Mark the bad channels
     raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp)
     if fig_dropped and fig_raw_before is not None:
         fig_individual["fig2"] = fig_dropped
         fig_individual["fig3"] = fig_raw_before
-    if progress_callback: progress_callback(6)
-    logger.info("6")
+    if progress_callback: progress_callback(8)
+    logger.info("8")
 
     # Step 5: Interpolate the bad channels
     if bad_channels:
         raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, bad_channels)
         fig_individual["fig4"] = fig_raw_after
-    if progress_callback: progress_callback(7)
-    logger.info("7")
+    if progress_callback: progress_callback(9)
+    logger.info("9")
 
     # Step 6: Optical Density
     raw_od = optical_density(raw)
     fig_raw_od = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Optical Density", show=False)
     fig_individual["Optical Density"] = fig_raw_od
-    if progress_callback: progress_callback(8)
-    logger.info("8")
+    if progress_callback: progress_callback(10)
+    logger.info("10")
 
     # Step 7: TDDR
     if TDDR:
         raw_od = temporal_derivative_distribution_repair(raw_od)
         fig_raw_od_tddr = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After TDDR (Motion Correction)", show=False)
         fig_individual["TDDR"] = fig_raw_od_tddr
-    if progress_callback: progress_callback(9)
-    logger.info("9")
+    if progress_callback: progress_callback(11)
+    logger.info("11")
 
+    if WAVELET:
+        raw_od, fig = calculate_and_apply_wavelet(raw_od)
+        fig_individual["Wavelet"] = fig
+    if progress_callback: progress_callback(12)
+    logger.info("12")
 
     # Step 8: BLL
     raw_haemo = beer_lambert_law(raw_od, ppf=calculate_dpf(file_path))
     fig_raw_haemo_bll = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False)
     fig_individual["BLL"] = fig_raw_haemo_bll
-    if progress_callback: progress_callback(10)
-    logger.info("10")
+    if progress_callback: progress_callback(13)
+    logger.info("13")
 
     # Step 9: ENC
-    # raw_haemo = enhance_negative_correlation(raw_haemo)
-    # fig_raw_haemo_enc = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False)
-    # fig_individual.append(fig_raw_haemo_enc)
+    if ENHANCE_NEGATIVE_CORRELATION:
+        raw_haemo = enhance_negative_correlation(raw_haemo)
+        fig_raw_haemo_enc = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False)
+        fig_individual["ENC"] = fig_raw_haemo_enc
+    if progress_callback: progress_callback(14)
+    logger.info("14")
 
     # Step 10: Filter
-    fig_filter, fig_raw_haemo_filter = filter_the_data(raw_haemo)
-    fig_individual["filter1"] = fig_filter
-    fig_individual["filter2"] = fig_raw_haemo_filter
-    if progress_callback: progress_callback(11)
-    logger.info("11")
+    if FILTER:
+        raw_haemo, fig_filter, fig_raw_haemo_filter = filter_the_data(raw_haemo)
+        fig_individual["filter1"] = fig_filter
+        fig_individual["filter2"] = fig_raw_haemo_filter
+    if progress_callback: progress_callback(15)
+    logger.info("15")
 
     # Step 11: Get short / long channels
     if SHORT_CHANNEL:
-        short_chans = get_short_channels(raw_haemo, max_dist=0.02)
+        short_chans = get_short_channels(raw_haemo, max_dist=SHORT_CHANNEL_THRESH)
         fig_short_chans = short_chans.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Short Channels Only", show=False)
         fig_individual["short"] = fig_short_chans
     else:
         short_chans = None
-    raw_haemo = get_long_channels(raw_haemo)
-    if progress_callback: progress_callback(12)
-    logger.info("12")
+    raw_haemo = get_long_channels(raw_haemo, min_dist=SHORT_CHANNEL_THRESH, max_dist=LONG_CHANNEL_THRESH)
+    if progress_callback: progress_callback(16)
+    logger.info("16")
 
     # Step 12: Events from annotations
     events, event_dict = events_from_annotations(raw_haemo)
     fig_events = plot_events(events, event_id=event_dict, sfreq=raw_haemo.info["sfreq"], show=False)
     fig_individual["events"] = fig_events
-    if progress_callback: progress_callback(13)
-    logger.info("13")
+    if progress_callback: progress_callback(17)
+    logger.info("17")
 
     # Step 13: Epoch calculations
     epochs, fig_epochs = epochs_calculations(raw_haemo, events, event_dict)
     for name, fig in fig_epochs: # Unpack the tuple here
         fig_individual[f"epochs_{name}"] = fig # Store only the figure, not the name
-    if progress_callback: progress_callback(14)
-    logger.info("14")
+    if progress_callback: progress_callback(18)
+    logger.info("18")
 
     # Step 14: Design Matrix
     events_to_remove = REMOVE_EVENTS
@@ -2999,8 +3494,8 @@ def process_participant(file_path, progress_callback=None):
 
     design_matrix, fig_design_matrix = make_design_matrix(raw_haemo, short_chans)
     fig_individual["Design Matrix"] = fig_design_matrix
-    if progress_callback: progress_callback(15)
-    logger.info("15")
+    if progress_callback: progress_callback(19)
+    logger.info("19")
 
     # Step 15: Run GLM
     glm_est = run_glm(raw_haemo, design_matrix)
@@ -3015,22 +3510,22 @@ def process_participant(file_path, progress_callback=None):
     # A large p-value means the data do not provide strong evidence that the effect is different from zero.
 
 
-    if progress_callback: progress_callback(16)
-    logger.info("16")
+    if progress_callback: progress_callback(20)
+    logger.info("20")
 
     # Step 16: Plot GLM results
     fig_glm_result = plot_glm_results(file_path, raw_haemo, glm_est, design_matrix)
     for name, fig in fig_glm_result:
         fig_individual[f"GLM {name}"] = fig
-    if progress_callback: progress_callback(17)
-    logger.info("17")
+    if progress_callback: progress_callback(21)
+    logger.info("21")
 
     # Step 17: Plot channel significance
     fig_significance = individual_significance(raw_haemo, glm_est)
     for name, fig in fig_significance:
         fig_individual[f"Significance {name}"] = fig
-    if progress_callback: progress_callback(18)
-    logger.info("18")
+    if progress_callback: progress_callback(22)
+    logger.info("22")
 
     # Step 18: cha, con, roi
     cha = glm_est.to_dataframe()
@@ -3085,8 +3580,8 @@ def process_participant(file_path, progress_callback=None):
 
         contrast_dict[condition] = contrast_vector
 
-    if progress_callback: progress_callback(19)
-    logger.info("19")
+    if progress_callback: progress_callback(23)
+    logger.info("23")
 
     # Compute contrast results
     contrast_results = {}
@@ -3099,15 +3594,20 @@ def process_participant(file_path, progress_callback=None):
 
     cha["ID"] = file_path
 
+    fig_bytes = convert_fig_dict_to_png_bytes(fig_individual)
+    if progress_callback: progress_callback(24)
+    logger.info("24")
 
-    if progress_callback: progress_callback(20)
-    logger.info("20")
-
-    fig_bytes = convert_fig_dict_to_png_bytes(fig_individual)
 
     sanitize_paths_for_pickle(raw_haemo, epochs)
 
+    if progress_callback: progress_callback(25)
+    logger.info("25")
+
     return raw_haemo, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, AGE, GENDER, GROUP, True
 
 
 def sanitize_paths_for_pickle(raw_haemo, epochs):
     # Fix raw_haemo._filenames
     if hasattr(raw_haemo, '_filenames'):
main.py (45 changed lines)
@@ -58,11 +58,30 @@ SECTIONS = [
     {
         "title": "Preprocessing",
         "params": [
-            {"name": "SECONDS_TO_STRIP", "default": 0, "type": int, "help": "Seconds to remove from beginning of all loaded snirf files. Setting this to 0 will remove nothing from the files."},
+            # {"name": "SECONDS_TO_STRIP", "default": 0, "type": int, "help": "Seconds to remove from beginning of all loaded snirf files. Setting this to 0 will remove nothing from the files."},
             {"name": "DOWNSAMPLE", "default": True, "type": bool, "help": "Should the snirf files be downsampled? If this is set to True, DOWNSAMPLE_FREQUENCY will be used as the target frequency to downsample to."},
             {"name": "DOWNSAMPLE_FREQUENCY", "default": 25, "type": int, "help": "Frequency (Hz) to downsample to. If this is set higher than the input data, new data will be interpolated. Only used if DOWNSAMPLE is set to True"},
         ]
     },
+    {
+        "title": "Trimming",
+        "params": [
+            {"name": "TRIM", "default": True, "type": bool, "help": "Trim the file start."},
+            {"name": "SECONDS_TO_KEEP", "default": 5, "type": float, "help": "Seconds to keep at the beginning of all loaded snirf files before the first annotation/event occurs. Calculation is done seperatly on all loaded snirf files. Setting this to 0 will have the first annotation/event be at time point 0."},
+        ]
+    },
+    {
+        "title": "Verify Optode Placement",
+        "params": [
+            {"name": "OPTODE_PLACEMENT", "default": True, "type": bool, "help": "Generate an image for each participant outlining their optode placement."},
+        ]
+    },
+    {
+        "title": "Heart Rate",
+        "params": [
+            {"name": "HEART_RATE", "default": True, "type": bool, "help": "Attempt to calculate the participants heart rate."},
+        ]
+    },
     {
         "title": "Scalp Coupling Index",
         "params": [
@@ -108,6 +127,15 @@ SECTIONS = [
             {"name": "TDDR", "default": True, "type": bool, "help": "Apply Temporal Derivitave Distribution Repair filtering - a method that removes baseline shift and spike artifacts from the data."},
         ]
     },
+    {
+        "title": "Wavelet filtering",
+        "params": [
+            {"name": "WAVELET", "default": True, "type": bool, "help": "Apply Wavelet filtering."},
+            {"name": "IQR", "default": 1.5, "type": float, "help": "Inter-Quartile Range."},
+            {"name": "WAVELET_TYPE", "default": "db4", "type": str, "help": "Wavelet type."},
+            {"name": "WAVELET_LEVEL", "default": 3, "type": int, "help": "Wavelet level."},
+        ]
+    },
     {
         "title": "Haemoglobin Concentration",
         "params": [
@@ -117,22 +145,23 @@ SECTIONS = [
     {
         "title": "Enhance Negative Correlation",
         "params": [
-            #{"name": "ENHANCE_NEGATIVE_CORRELATION", "default": False, "type": bool, "help": "Calculate Peak Spectral Power."},
+            {"name": "ENHANCE_NEGATIVE_CORRELATION", "default": False, "type": bool, "help": "Apply Enhance Negative Correlation."},
         ]
     },
     {
         "title": "Filtering",
         "params": [
+            {"name": "FILTER", "default": True, "type": bool, "help": "Filter the data."},
             {"name": "L_FREQ", "default": 0.005, "type": float, "help": "Any frequencies lower than this value will be removed."},
-            {"name": "H_FREQ", "default": 0.7, "type": float, "help": "Any frequencies higher than this value will be removed."},
-            #{"name": "FILTER", "default": True, "type": bool, "help": "Calculate Peak Spectral Power."},
+            {"name": "H_FREQ", "default": 0.3, "type": float, "help": "Any frequencies higher than this value will be removed."},
         ]
     },
     {
-        "title": "Short Channels",
+        "title": "Short/Long Channels",
         "params": [
             {"name": "SHORT_CHANNEL", "default": True, "type": bool, "help": "This should be set to True if the data has a short channel present in the data."},
+            {"name": "SHORT_CHANNEL_THRESH", "default": 0.015, "type": float, "help": "The maximum distance the short channel can be in metres."},
+            {"name": "LONG_CHANNEL_THRESH", "default": 0.045, "type": float, "help": "The maximum distance the long channel can be in metres."},
         ]
     },
     {
@@ -151,7 +180,7 @@ SECTIONS = [
         "title": "Design Matrix",
         "params": [
             {"name": "REMOVE_EVENTS", "default": "None", "type": list, "help": "Remove events matching the names provided before generating the Design Matrix"},
-            # {"name": "DRIFT_MODEL", "default": "cosine", "type": str, "help": "Drift model for GLM."},
+            {"name": "DRIFT_MODEL", "default": "cosine", "type": str, "help": "Drift model for GLM."},
             # {"name": "DURATION_BETWEEN_ACTIVITIES", "default": 35, "type": int, "help": "Time between activities (s)."},
             # {"name": "SHORT_CHANNEL_REGRESSION", "default": True, "type": bool, "help": "Use short channel regression."},
         ]
@@ -1200,7 +1229,7 @@ class ProgressBubble(QWidget):
         self.progress_layout = QHBoxLayout()
 
         self.rects = []
-        for _ in range(20):
+        for _ in range(25):
             rect = QFrame()
             rect.setFixedSize(10, 18)
             rect.setStyleSheet("background-color: white; border: 1px solid gray;")