diff --git a/changelog.md b/changelog.md index aff804e..b30b513 100644 --- a/changelog.md +++ b/changelog.md @@ -2,12 +2,14 @@ - Fixed a bug where having both a L_FREQ and H_FREQ would cause only the L_FREQ to be used - Changed the default H_FREQ from 0.7 to 0.3 -- Removed SECONDS_TO_STRIP from the preprocessing options -- Instead all files are trimmed up until 5 seconds before the first annotation/event in the file - Added a PSD graph, along with 2 heart rate images to the individual participant viewer -- The PSD graph is used to help calculate the heart rate, whereas the other 2 are just for show currently +- The PSD graph is used to help calculate the heart rate, whereas the other 2 are currently just for show - SCI is now done using a .6hz window around the calculated heart rate compared to a window around an average heart rate -- Fixed an issue with some epochs figures not showing +- Fixed an issue with some epochs figures not showing under the participant analysis +- Removed SECONDS_TO_STRIP from the preprocessing options +- Added new parameters to the right side of the screen +- These parameters include TRIM, SECONDS_TO_KEEP, OPTODE_PLACEMENT, HEART_RATE, WAVELET, IQR, WAVELET_TYPE, WAVELET_LEVEL, ENHANCE_NEGATIVE_CORRELATION, SHORT_CHANNEL_THRESH, LONG_CHANNEL_THRESH, and DRIFT_MODEL +- Changed number of rectangles in the progress bar to 25 to account for the new options # Version 1.1.6 diff --git a/flares.py b/flares.py index 5346994..efb12e6 100644 --- a/flares.py +++ b/flares.py @@ -119,6 +119,13 @@ GENDER: str DOWNSAMPLE: bool DOWNSAMPLE_FREQUENCY: int +TRIM: bool +SECONDS_TO_KEEP: float + +OPTODE_PLACEMENT: bool + +HEART_RATE: bool + SCI: bool SCI_TIME_WINDOW: int SCI_THRESHOLD: float @@ -133,27 +140,35 @@ PSP_THRESHOLD: float TDDR: bool -IQR = 1.5 +WAVELET: bool +IQR: float +WAVELET_TYPE: str +WAVELET_LEVEL: int + HEART_RATE = True # True if heart rate should be calculated. This helps the SCI, PSP, and SNR methods to be more accurate. SECONDS_TO_STRIP_HR =5 # Amount of seconds to temporarily strip from the data to calculate heart rate more effectively. Useful if participant removed cap while still recording. MAX_LOW_HR = 40 # Any heart rate values lower than this will be set to this value. MAX_HIGH_HR = 200 # Any heart rate values higher than this will be set to this value. SMOOTHING_WINDOW_HR = 100 # Heart rate will be calculated as a rolling average over this many amount of samples. HEART_RATE_WINDOW = 25 # Amount of BPM above and below the calculated average to use for a range of resting BPM. -SHORT_CHANNEL_THRESH = 0.018 ENHANCE_NEGATIVE_CORRELATION: bool +FILTER: bool L_FREQ: float H_FREQ: float SHORT_CHANNEL: bool +SHORT_CHANNEL_THRESH: float +LONG_CHANNEL_THRESH: float REMOVE_EVENTS: list TIME_WINDOW_START: int TIME_WINDOW_END: int +DRIFT_MODEL: str + VERBOSITY = True # FIXME: Shouldn't need each ordering - just order it before checking @@ -183,11 +198,17 @@ GROUP = "Default" REQUIRED_KEYS: dict[str, Any] = { - # "SECONDS_TO_STRIP": int, "DOWNSAMPLE": bool, "DOWNSAMPLE_FREQUENCY": int, + + "TRIM": bool, + "SECONDS_TO_KEEP": float, + "OPTODE_PLACEMENT": bool, + + "HEART_RATE": bool, + "SCI": bool, "SCI_TIME_WINDOW": int, "SCI_THRESHOLD": float, @@ -201,11 +222,23 @@ REQUIRED_KEYS: dict[str, Any] = { "PSP_THRESHOLD": float, "SHORT_CHANNEL": bool, + "SHORT_CHANNEL_THRESH": float, + "LONG_CHANNEL_THRESH": float, + + "REMOVE_EVENTS": list, "TIME_WINDOW_START": int, "TIME_WINDOW_END": int, "L_FREQ": float, "H_FREQ": float, + + "TDDR": bool, + "WAVELET": bool, + "IQR": float, + "WAVELET_TYPE": str, + "WAVELET_LEVEL": int, + "FILTER": bool, + "DRIFT_MODEL": str, # "REJECT_PAIRS": bool, # "FORCE_DROP_ANNOTATIONS": list, # "FILTER_LOW_PASS": float, @@ -1107,7 +1140,7 @@ def filter_the_data(raw_haemo): fig_raw_haemo_filter = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Filtered HbO and HbR", show=False) - return fig_filter, fig_raw_haemo_filter + return raw_haemo, fig_filter, fig_raw_haemo_filter @@ -1284,7 +1317,7 @@ def make_design_matrix(raw_haemo, short_chans): hrf_model='fir', stim_dur=0.5, fir_delays=range(15), - drift_model='cosine', + drift_model=DRIFT_MODEL, high_pass=0.01, oversampling=1, min_onset=-125, @@ -1297,7 +1330,7 @@ def make_design_matrix(raw_haemo, short_chans): hrf_model='fir', stim_dur=0.5, fir_delays=range(15), - drift_model='cosine', + drift_model=DRIFT_MODEL, high_pass=0.01, oversampling=1, min_onset=-125, @@ -2975,7 +3008,7 @@ def calculate_and_apply_wavelet(data: BaseRaw) -> tuple[BaseRaw, Figure]: logger.info("Calculating the IQR, decomposing the signal, and thresholding the coefficients...") for ch in range(loaded_data.shape[0]): - denoised_data[ch, :] = wavelet_iqr_denoise(loaded_data[ch, :], wavelet='db4', level=3) + denoised_data[ch, :] = wavelet_iqr_denoise(loaded_data[ch, :], wavelet=WAVELET_TYPE, level=WAVELET_LEVEL) # Reconstruct the data with the annotations logger.info("Reconstructing the data with annotations...") @@ -3289,68 +3322,66 @@ def process_participant(file_path, progress_callback=None): logger.info("1") - if hasattr(raw, 'annotations') and len(raw.annotations) > 0: - # Get time of first event - first_event_time = raw.annotations.onset[0] - trim_time = max(0, first_event_time - 5.0) # Ensure we don't go negative - raw.crop(tmin=trim_time) - # Shift annotation onsets to match new t=0 - import mne + if TRIM: + if hasattr(raw, 'annotations') and len(raw.annotations) > 0: + # Get time of first event + first_event_time = raw.annotations.onset[0] + trim_time = max(0, first_event_time - SECONDS_TO_KEEP) # Ensure we don't go negative + raw.crop(tmin=trim_time) + # Shift annotation onsets to match new t=0 + import mne - ann = raw.annotations - ann_shifted = mne.Annotations( - onset=ann.onset - trim_time, # shift to start at zero - duration=ann.duration, - description=ann.description - ) - data = raw.get_data() - info = raw.info.copy() - raw = mne.io.RawArray(data, info) - raw.set_annotations(ann_shifted) + ann = raw.annotations + ann_shifted = mne.Annotations( + onset=ann.onset - trim_time, # shift to start at zero + duration=ann.duration, + description=ann.description + ) + data = raw.get_data() + info = raw.info.copy() + raw = mne.io.RawArray(data, info) + raw.set_annotations(ann_shifted) - logger.info(f"Trimmed raw data: start at {trim_time}s (5s before first event), t=0 at new start") - else: - logger.warning("No events found, skipping trim step.") + logger.info(f"Trimmed raw data: start at {trim_time}s (5s before first event), t=0 at new start") + else: + logger.warning("No events found, skipping trim step.") - fig_trimmed = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Trimmed Raw", show=False) - fig_individual["Trimmed Raw"] = fig_trimmed + fig_trimmed = raw.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Trimmed Raw", show=False) + fig_individual["Trimmed Raw"] = fig_trimmed if progress_callback: progress_callback(2) logger.info("2") # Step 1.5: Verify optode positions - fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False) # type: ignore - fig_individual["Plot Sensors"] = fig_optodes - if progress_callback: progress_callback(2) - logger.info("2") - - # Step 2: Downsample - # raw = raw.resample(0.5) # Downsample to 0.5 Hz + if OPTODE_PLACEMENT: + fig_optodes = raw.plot_sensors(show_names=True, to_sphere=True, show=False) # type: ignore + fig_individual["Plot Sensors"] = fig_optodes + if progress_callback: progress_callback(3) + logger.info("3") # Step 2: Bad from SCI - if True: + if HEART_RATE: fig, hr1, hr2, low, high = hr_calc(raw) fig_individual["PSD"] = fig fig_individual['HeartRate_PSD'] = hr1 fig_individual['HeartRate_Time'] = hr2 - if progress_callback: progress_callback(10) - - if progress_callback: progress_callback(2) + if progress_callback: progress_callback(4) + logger.info("4") bad_sci = [] if SCI: bad_sci, fig_sci_1, fig_sci_2 = calculate_scalp_coupling(raw, low, high) fig_individual["SCI1"] = fig_sci_1 fig_individual["SCI2"] = fig_sci_2 - if progress_callback: progress_callback(3) - logger.info("3") + if progress_callback: progress_callback(5) + logger.info("5") # Step 2: Bad from SNR bad_snr = [] if SNR: bad_snr, fig_snr = calculate_signal_noise_ratio(raw) fig_individual["SNR1"] = fig_snr - if progress_callback: progress_callback(4) - logger.info("4") + if progress_callback: progress_callback(6) + logger.info("6") # Step 3: Bad from PSP bad_psp = [] @@ -3358,88 +3389,94 @@ def process_participant(file_path, progress_callback=None): bad_psp, fig_psp1, fig_psp2 = calculate_peak_power(raw) fig_individual["PSP1"] = fig_psp1 fig_individual["PSP2"] = fig_psp2 - if progress_callback: progress_callback(5) - logger.info("5") + if progress_callback: progress_callback(7) + logger.info("7") # Step 4: Mark the bad channels raw, fig_dropped, fig_raw_before, bad_channels = mark_bads(raw, bad_sci, bad_snr, bad_psp) if fig_dropped and fig_raw_before is not None: fig_individual["fig2"] = fig_dropped fig_individual["fig3"] = fig_raw_before - if progress_callback: progress_callback(6) - logger.info("6") + if progress_callback: progress_callback(8) + logger.info("8") # Step 5: Interpolate the bad channels if bad_channels: raw, fig_raw_after = interpolate_fNIRS_bads_weighted_average(raw, bad_channels) fig_individual["fig4"] = fig_raw_after - if progress_callback: progress_callback(7) - logger.info("7") + if progress_callback: progress_callback(9) + logger.info("9") # Step 6: Optical Density raw_od = optical_density(raw) fig_raw_od = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="Optical Density", show=False) fig_individual["Optical Density"] = fig_raw_od - if progress_callback: progress_callback(8) - logger.info("8") + if progress_callback: progress_callback(10) + logger.info("10") # Step 7: TDDR if TDDR: raw_od = temporal_derivative_distribution_repair(raw_od) fig_raw_od_tddr = raw_od.plot(duration=raw.times[-1], n_channels=raw.info['nchan'], title="After TDDR (Motion Correction)", show=False) fig_individual["TDDR"] = fig_raw_od_tddr - if progress_callback: progress_callback(9) - logger.info("9") + if progress_callback: progress_callback(11) + logger.info("11") - raw_od, fig = calculate_and_apply_wavelet(raw_od) - fig_individual["Wavelet"] = fig - if progress_callback: progress_callback(9) - + if WAVELET: + raw_od, fig = calculate_and_apply_wavelet(raw_od) + fig_individual["Wavelet"] = fig + if progress_callback: progress_callback(12) + logger.info("12") + # Step 8: BLL raw_haemo = beer_lambert_law(raw_od, ppf=calculate_dpf(file_path)) fig_raw_haemo_bll = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False) fig_individual["BLL"] = fig_raw_haemo_bll - if progress_callback: progress_callback(10) - logger.info("10") + if progress_callback: progress_callback(13) + logger.info("13") # Step 9: ENC - # raw_haemo = enhance_negative_correlation(raw_haemo) - # fig_raw_haemo_enc = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False) - # fig_individual.append(fig_raw_haemo_enc) - + if ENHANCE_NEGATIVE_CORRELATION: + raw_haemo = enhance_negative_correlation(raw_haemo) + fig_raw_haemo_enc = raw_haemo.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="HbO and HbR Signals", show=False) + fig_individual["ENC"] = fig_raw_haemo_enc + if progress_callback: progress_callback(14) + logger.info("14") + # Step 10: Filter - fig_filter, fig_raw_haemo_filter = filter_the_data(raw_haemo) - fig_individual["filter1"] = fig_filter - fig_individual["filter2"] = fig_raw_haemo_filter - if progress_callback: progress_callback(11) - logger.info("11") + if FILTER: + raw_haemo, fig_filter, fig_raw_haemo_filter = filter_the_data(raw_haemo) + fig_individual["filter1"] = fig_filter + fig_individual["filter2"] = fig_raw_haemo_filter + if progress_callback: progress_callback(15) + logger.info("15") # Step 11: Get short / long channels if SHORT_CHANNEL: - short_chans = get_short_channels(raw_haemo, max_dist=0.02) + short_chans = get_short_channels(raw_haemo, max_dist=SHORT_CHANNEL_THRESH) fig_short_chans = short_chans.plot(duration=raw_haemo.times[-1], n_channels=raw_haemo.info['nchan'], title="Short Channels Only", show=False) fig_individual["short"] = fig_short_chans else: short_chans = None - raw_haemo = get_long_channels(raw_haemo) - if progress_callback: progress_callback(12) - logger.info("12") + raw_haemo = get_long_channels(raw_haemo, min_dist=SHORT_CHANNEL_THRESH, max_dist=LONG_CHANNEL_THRESH) + if progress_callback: progress_callback(16) + logger.info("16") # Step 12: Events from annotations events, event_dict = events_from_annotations(raw_haemo) fig_events = plot_events(events, event_id=event_dict, sfreq=raw_haemo.info["sfreq"], show=False) fig_individual["events"] = fig_events - if progress_callback: progress_callback(13) - logger.info("13") + if progress_callback: progress_callback(17) + logger.info("17") # Step 13: Epoch calculations epochs, fig_epochs = epochs_calculations(raw_haemo, events, event_dict) for name, fig in fig_epochs: # Unpack the tuple here fig_individual[f"epochs_{name}"] = fig # Store only the figure, not the name - if progress_callback: progress_callback(14) - logger.info("14") + if progress_callback: progress_callback(18) + logger.info("18") # Step 14: Design Matrix events_to_remove = REMOVE_EVENTS @@ -3457,8 +3494,8 @@ def process_participant(file_path, progress_callback=None): design_matrix, fig_design_matrix = make_design_matrix(raw_haemo, short_chans) fig_individual["Design Matrix"] = fig_design_matrix - if progress_callback: progress_callback(15) - logger.info("15") + if progress_callback: progress_callback(19) + logger.info("19") # Step 15: Run GLM glm_est = run_glm(raw_haemo, design_matrix) @@ -3473,22 +3510,22 @@ def process_participant(file_path, progress_callback=None): # A large p-value means the data do not provide strong evidence that the effect is different from zero. - if progress_callback: progress_callback(16) - logger.info("16") + if progress_callback: progress_callback(20) + logger.info("20") # Step 16: Plot GLM results fig_glm_result = plot_glm_results(file_path, raw_haemo, glm_est, design_matrix) for name, fig in fig_glm_result: fig_individual[f"GLM {name}"] = fig - if progress_callback: progress_callback(17) - logger.info("17") + if progress_callback: progress_callback(21) + logger.info("21") # Step 17: Plot channel significance fig_significance = individual_significance(raw_haemo, glm_est) for name, fig in fig_significance: fig_individual[f"Significance {name}"] = fig - if progress_callback: progress_callback(18) - logger.info("18") + if progress_callback: progress_callback(22) + logger.info("22") # Step 18: cha, con, roi cha = glm_est.to_dataframe() @@ -3543,8 +3580,8 @@ def process_participant(file_path, progress_callback=None): contrast_dict[condition] = contrast_vector - if progress_callback: progress_callback(19) - logger.info("19") + if progress_callback: progress_callback(23) + logger.info("23") # Compute contrast results contrast_results = {} @@ -3557,15 +3594,20 @@ def process_participant(file_path, progress_callback=None): cha["ID"] = file_path + if progress_callback: progress_callback(24) + logger.info("24") + + fig_bytes = convert_fig_dict_to_png_bytes(fig_individual) - if progress_callback: progress_callback(20) - logger.info("20") - sanitize_paths_for_pickle(raw_haemo, epochs) + if progress_callback: progress_callback(25) + logger.info("25") + return raw_haemo, epochs, fig_bytes, cha, contrast_results, df_ind, design_matrix, AGE, GENDER, GROUP, True + def sanitize_paths_for_pickle(raw_haemo, epochs): # Fix raw_haemo._filenames if hasattr(raw_haemo, '_filenames'): diff --git a/main.py b/main.py index 8b6e9ec..a5ef4df 100644 --- a/main.py +++ b/main.py @@ -63,6 +63,25 @@ SECTIONS = [ {"name": "DOWNSAMPLE_FREQUENCY", "default": 25, "type": int, "help": "Frequency (Hz) to downsample to. If this is set higher than the input data, new data will be interpolated. Only used if DOWNSAMPLE is set to True"}, ] }, + { + "title": "Trimming", + "params": [ + {"name": "TRIM", "default": True, "type": bool, "help": "Trim the file start."}, + {"name": "SECONDS_TO_KEEP", "default": 5, "type": float, "help": "Seconds to keep at the beginning of all loaded snirf files before the first annotation/event occurs. Calculation is done seperatly on all loaded snirf files. Setting this to 0 will have the first annotation/event be at time point 0."}, + ] + }, + { + "title": "Verify Optode Placement", + "params": [ + {"name": "OPTODE_PLACEMENT", "default": True, "type": bool, "help": "Generate an image for each participant outlining their optode placement."}, + ] + }, + { + "title": "Heart Rate", + "params": [ + {"name": "HEART_RATE", "default": True, "type": bool, "help": "Attempt to calculate the participants heart rate."}, + ] + }, { "title": "Scalp Coupling Index", "params": [ @@ -108,6 +127,15 @@ SECTIONS = [ {"name": "TDDR", "default": True, "type": bool, "help": "Apply Temporal Derivitave Distribution Repair filtering - a method that removes baseline shift and spike artifacts from the data."}, ] }, + { + "title": "Wavelet filtering", + "params": [ + {"name": "WAVELET", "default": True, "type": bool, "help": "Apply Wavelet filtering."}, + {"name": "IQR", "default": 1.5, "type": float, "help": "Inter-Quartile Range."}, + {"name": "WAVELET_TYPE", "default": "db4", "type": str, "help": "Wavelet type."}, + {"name": "WAVELET_LEVEL", "default": 3, "type": int, "help": "Wavelet level."}, + ] + }, { "title": "Haemoglobin Concentration", "params": [ @@ -117,22 +145,23 @@ SECTIONS = [ { "title": "Enhance Negative Correlation", "params": [ - #{"name": "ENHANCE_NEGATIVE_CORRELATION", "default": False, "type": bool, "help": "Calculate Peak Spectral Power."}, + {"name": "ENHANCE_NEGATIVE_CORRELATION", "default": False, "type": bool, "help": "Apply Enhance Negative Correlation."}, ] }, { "title": "Filtering", "params": [ + {"name": "FILTER", "default": True, "type": bool, "help": "Filter the data."}, {"name": "L_FREQ", "default": 0.005, "type": float, "help": "Any frequencies lower than this value will be removed."}, {"name": "H_FREQ", "default": 0.3, "type": float, "help": "Any frequencies higher than this value will be removed."}, - #{"name": "FILTER", "default": True, "type": bool, "help": "Calculate Peak Spectral Power."}, - ] }, { - "title": "Short Channels", + "title": "Short/Long Channels", "params": [ {"name": "SHORT_CHANNEL", "default": True, "type": bool, "help": "This should be set to True if the data has a short channel present in the data."}, + {"name": "SHORT_CHANNEL_THRESH", "default": 0.015, "type": float, "help": "The maximum distance the short channel can be in metres."}, + {"name": "LONG_CHANNEL_THRESH", "default": 0.045, "type": float, "help": "The maximum distance the long channel can be in metres."}, ] }, { @@ -151,7 +180,7 @@ SECTIONS = [ "title": "Design Matrix", "params": [ {"name": "REMOVE_EVENTS", "default": "None", "type": list, "help": "Remove events matching the names provided before generating the Design Matrix"}, - # {"name": "DRIFT_MODEL", "default": "cosine", "type": str, "help": "Drift model for GLM."}, + {"name": "DRIFT_MODEL", "default": "cosine", "type": str, "help": "Drift model for GLM."}, # {"name": "DURATION_BETWEEN_ACTIVITIES", "default": 35, "type": int, "help": "Time between activities (s)."}, # {"name": "SHORT_CHANNEL_REGRESSION", "default": True, "type": bool, "help": "Use short channel regression."}, ] @@ -1200,7 +1229,7 @@ class ProgressBubble(QWidget): self.progress_layout = QHBoxLayout() self.rects = [] - for _ in range(20): + for _ in range(25): rect = QFrame() rect.setFixedSize(10, 18) rect.setStyleSheet("background-color: white; border: 1px solid gray;")