diff --git a/LICENSE b/LICENSE index 2b2c8b6..13d9d2f 100644 --- a/LICENSE +++ b/LICENSE @@ -200,33 +200,4 @@ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY C 17. Interpretation of Sections 15 and 16. If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. -END OF TERMS AND CONDITIONS - -How to Apply These Terms to Your New Programs - -If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. - -To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the “copyright” line and a pointer to where the full notice is found. - - sparks - Copyright (C) 2026 tyler - - This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. - - This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. - - You should have received a copy of the GNU General Public License along with this program. If not, see . - -Also add information on how to contact you by electronic and paper mail. - -If the program does terminal interaction, make it output a short notice like this when it starts in an interactive mode: - - sparks Copyright (C) 2026 tyler - This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. - This is free software, and you are welcome to redistribute it under certain conditions; type `show c' for details. - -The hypothetical commands `show w' and `show c' should show the appropriate parts of the General Public License. Of course, your program's commands might be different; for a GUI interface, you would use an “about box”. - -You should also get your employer (if you work as a programmer) or school, if any, to sign a “copyright disclaimer” for the program, if necessary. For more information on this, and how to apply and follow the GNU GPL, see . - -The GNU General Public License does not permit incorporating your program into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License. But first, please read . +END OF TERMS AND CONDITIONS \ No newline at end of file diff --git a/README.md b/README.md index be50970..1ebb5c9 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,44 @@ -# sparks +SPARKS (Spacial Patterns Analysis, Research, & Knowledge Suite) +================================================================= -Spatial Patterns, Analysis, Research, & Knowledge Suite \ No newline at end of file +SPARKS is a standalone application to extract meaningful data out of video files. + +SPARKS is free and open-source software that runs on Windows, MacOS, and Linux. Please read the information regarding each operating system below. + +Visit the official [SPARKS web site](https://research.dezeeuw.ca/sparks). + +[![Python web site](https://img.shields.io/badge/Made%20with-Python-1f425f.svg)](https://www.python.org) + +# For MacOS Users + +Due to the cost of an Apple Developer account, the application is not certified by Apple. Once the application is extracted and attempted to be launched for the first time you will get a popup stating: + +"Apple could not verify sparks.app is free of malware that may harm your Mac or compromise your privacy.", with the options of "Done" or "Move to Trash". + +The solution around this is to use finder and navigate to the sparks-darwin folder. Once the folder has been located, right click the folder and click the option "New Terminal at Folder". Once the terminal opens, run the following command (you can copy + paste): + +```xattr -dr com.apple.quarantine sparks.app & pid1=$!; xattr -dr com.apple.quarantine sparks_updater.app & pid2=$!; wait $pid1 $pid2; exit``` + +Once the command has been executed and the text "[Process completed]" appears, you may close the terminal window and attempt to open the application again. If you choose to unrestrict the app through Settings > Privacy & Security, the app may not be able to update correctly in the future. + +This only applies for the first time you attempt to run SPARKS. Subsequent times, including after updates, will function correctly as-is. + +# For Windows Users + +Due to the cost of a code signing certificate, the application is not digitally signed. Once the application is extracted and attempted to be launched for the first time you will get a popup stating: + +"Windows protected your PC - Microsoft Defender SmartScreen prevented an unrecognized app from starting. Running this app might put your PC at risk.", with the options of" More info" or "Don't run". + +The solution around this is to click "More info" and then select "Run anyway". + +This only applies for the first time you attempt to run SPARKS. Subsequent times, including after updates, will function correctly as-is. + +# For Linux Users + +There are no conditions for Linux users at this time. + +# Licence + +SPARKS is distributed under the GPL-3.0 license. + +Copyright (C) 2025-2026 Tyler de Zeeuw \ No newline at end of file diff --git a/changelog.md b/changelog.md new file mode 100644 index 0000000..68bddf6 --- /dev/null +++ b/changelog.md @@ -0,0 +1,3 @@ +# Version 0.1.0 + +- Initial preview release. \ No newline at end of file diff --git a/icons/article_24dp_1F1F1F.svg b/icons/article_24dp_1F1F1F.svg new file mode 100644 index 0000000..66793ed --- /dev/null +++ b/icons/article_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/content_copy_24dp_1F1F1F.svg b/icons/content_copy_24dp_1F1F1F.svg new file mode 100644 index 0000000..aeabcb9 --- /dev/null +++ b/icons/content_copy_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/content_cut_24dp_1F1F1F.svg b/icons/content_cut_24dp_1F1F1F.svg new file mode 100644 index 0000000..1a03cdb --- /dev/null +++ b/icons/content_cut_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/content_paste_24dp_1F1F1F.svg b/icons/content_paste_24dp_1F1F1F.svg new file mode 100644 index 0000000..6eea988 --- /dev/null +++ b/icons/content_paste_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/exit_to_app_24dp_1F1F1F.svg b/icons/exit_to_app_24dp_1F1F1F.svg new file mode 100644 index 0000000..7da9840 --- /dev/null +++ b/icons/exit_to_app_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/file_open_24dp_1F1F1F.svg b/icons/file_open_24dp_1F1F1F.svg new file mode 100644 index 0000000..29343b7 --- /dev/null +++ b/icons/file_open_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/folder_24dp_1F1F1F.svg b/icons/folder_24dp_1F1F1F.svg new file mode 100644 index 0000000..c4edc42 --- /dev/null +++ b/icons/folder_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/folder_copy_24dp_1F1F1F.svg b/icons/folder_copy_24dp_1F1F1F.svg new file mode 100644 index 0000000..d903aa9 --- /dev/null +++ b/icons/folder_copy_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/folder_eye_24dp_1F1F1F.svg b/icons/folder_eye_24dp_1F1F1F.svg new file mode 100644 index 0000000..aa5ccc6 --- /dev/null +++ b/icons/folder_eye_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/help_24dp_1F1F1F.svg b/icons/help_24dp_1F1F1F.svg new file mode 100644 index 0000000..ea47319 --- /dev/null +++ b/icons/help_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/info_24dp_1F1F1F.svg b/icons/info_24dp_1F1F1F.svg new file mode 100644 index 0000000..f749f5c --- /dev/null +++ b/icons/info_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/main.icns b/icons/main.icns new file mode 100644 index 0000000..724bfd8 Binary files /dev/null and b/icons/main.icns differ diff --git a/icons/main.ico b/icons/main.ico new file mode 100644 index 0000000..c5ad4f9 Binary files /dev/null and b/icons/main.ico differ diff --git a/icons/remove_24dp_1F1F1F.svg b/icons/remove_24dp_1F1F1F.svg new file mode 100644 index 0000000..ed87e30 --- /dev/null +++ b/icons/remove_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/save_24dp_1F1F1F.svg b/icons/save_24dp_1F1F1F.svg new file mode 100644 index 0000000..a8b8172 --- /dev/null +++ b/icons/save_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/save_as_24dp_1F1F1F.svg b/icons/save_as_24dp_1F1F1F.svg new file mode 100644 index 0000000..671d587 --- /dev/null +++ b/icons/save_as_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/terminal_24dp_1F1F1F.svg b/icons/terminal_24dp_1F1F1F.svg new file mode 100644 index 0000000..0a8e6a6 --- /dev/null +++ b/icons/terminal_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/update_24dp_1F1F1F.svg b/icons/update_24dp_1F1F1F.svg new file mode 100644 index 0000000..c62cfde --- /dev/null +++ b/icons/update_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/icons/updater.icns b/icons/updater.icns new file mode 100644 index 0000000..da2ab20 Binary files /dev/null and b/icons/updater.icns differ diff --git a/icons/updater.ico b/icons/updater.ico new file mode 100644 index 0000000..9ee2b55 Binary files /dev/null and b/icons/updater.ico differ diff --git a/icons/upgrade_24dp_1F1F1F.svg b/icons/upgrade_24dp_1F1F1F.svg new file mode 100644 index 0000000..3640fe1 --- /dev/null +++ b/icons/upgrade_24dp_1F1F1F.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..e717d46 --- /dev/null +++ b/main.py @@ -0,0 +1,3223 @@ +""" +Filename: main.py +Description: SPARKS main executable + +Author: Tyler de Zeeuw +License: GPL-3.0 +""" + +# Built-in imports +import os +import re +import csv +import sys +import math +import json +import time +import shlex +import pickle +import shutil +import zipfile +import platform +import traceback +import subprocess +from itertools import product +from pathlib import Path, PurePosixPath +from datetime import datetime +from multiprocessing import Process, current_process, freeze_support, Manager +import cv2 + +import numpy as np +import pandas as pd + +# External library imports +import psutil +import requests +import torch +import torch.nn as nn +import mediapipe as mp +from torch.utils.data import TensorDataset, DataLoader +from sklearn.utils.class_weight import compute_class_weight +import matplotlib.pyplot as plt + +from PySide6.QtWidgets import ( + QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox, QGridLayout, + QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView, QMenu, QProgressBar, QCheckBox +) +from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize, QPoint +from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel, QStandardItem, QImage +from PySide6.QtSvgWidgets import QSvgWidget # needed to show svgs when app is not frozen + + +CURRENT_VERSION = "1.0.0" + +API_URL = "https://git.research.dezeeuw.ca/api/v1/repos/tyler/sparks/releases" +API_URL_SECONDARY = "https://git.research2.dezeeuw.ca/api/v1/repos/tyler/sparks/releases" + +PLATFORM_NAME = platform.system().lower() + +# Selectable parameters on the right side of the window +SECTIONS = [ + { + "title": "Parameter 1", + "params": [ + {"name": "Number 1", "default": True, "type": bool, "help": "N/A"}, + {"name": "Number 2", "default": True, "type": bool, "help": "N/A"}, + ] + }, +] + + + + + + + +class TrainModelThread(QThread): + update = Signal(str) # new: emits messages for the GUI + finished = Signal(str) + + def __init__(self, csv_paths, model_save_path, parent=None): + super().__init__(parent) + self.csv_paths = csv_paths + self.model_save_path = model_save_path + + def run(self): + + # Load CSVs + X_list, y_list, feature_cols = [], [], None + for path in self.csv_paths: + df = pd.read_csv(path) + if feature_cols is None: + feature_cols = [c for c in df.columns if c.startswith("lm")] + df[feature_cols] = df[feature_cols].ffill().bfill() + X_list.append(df[feature_cols].values) + y_list.append(df[["reach_active", "reach_before_contact"]].values) + + X = np.concatenate(X_list, axis=0) + y = np.concatenate(y_list, axis=0) + + # Windowing + FPS, WINDOW_SEC, STEP_SEC, THRESHOLD, EPOCHS = 60, 0.4, 0.2, 0.5, 150 + window_length = int(FPS * WINDOW_SEC) + step_length = int(FPS * STEP_SEC) + X_windows, y_windows = [], [] + for start in range(0, len(X)-window_length+1, step_length): + end = start+window_length + X_win = X[start:end] + y_win = y[start:end] + label_window = (y_win.sum(axis=0) / len(y_win) >= 0.3).astype(int) # shape: (2,) + X_windows.append(X[start:end]) + y_windows.append(label_window) + + X_windows = np.array(X_windows) + y_windows = np.array(y_windows) + + class_weights = [] + for i in range(2): + cw = compute_class_weight('balanced', classes=np.array([0,1]), y=y_windows[:, i]) + class_weights.append(cw[1]/cw[0]) + pos_weight = torch.tensor(class_weights, dtype=torch.float32) + X_tensor = torch.tensor(np.array(X_windows), dtype=torch.float32) # (N, window_length, features) + y_tensor = torch.tensor(np.array(y_windows), dtype=torch.float32) # (N, 2) + + # LSTM + class WindowLSTM(nn.Module): + def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2): + super().__init__() + self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, + bidirectional=bidirectional) + self.fc = nn.Linear(hidden_size*(2 if bidirectional else 1), output_size) + def forward(self, x): + _, (h, _) = self.lstm(x) + if self.lstm.bidirectional: + h = torch.cat([h[-2], h[-1]], dim=1) + else: + h = h[-1] + return self.fc(h) + + # Hyperparameter sweep + HIDDEN_SIZES = [64] + BATCH_SIZES = [32] + LRs = [0.0005] + OPTIMIZERS = ['adam'] + BIDIRECTIONAL = [False] + + best_f1 = 0 + best_config = None + + for hidden_size, batch_size, lr, opt_name, bi in product(HIDDEN_SIZES,BATCH_SIZES,LRs,OPTIMIZERS,BIDIRECTIONAL): + self.update.emit(f"Training model: hidden={hidden_size}, batch={batch_size}, lr={lr}, opt={opt_name}, bidir={bi}") + + dataset = TensorDataset(X_tensor, y_tensor) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + model = WindowLSTM(input_size=X_windows.shape[2], hidden_size=hidden_size, bidirectional=bi) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) if opt_name=='adam' else torch.optim.SGD(model.parameters(), lr=lr) + criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) + + # Training + for epoch in range(EPOCHS): + model.train() + epoch_loss = 0 + for xb, yb in loader: + optimizer.zero_grad() + logits = model(xb) + loss = criterion(logits, yb) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + + # Evaluation + model.eval() + with torch.no_grad(): + logits = model(X_tensor) + probs = torch.sigmoid(logits) + preds = (probs > THRESHOLD).int() + y_true = y_tensor.int() + + tp = ((preds==1)&(y_true==1)).sum().item() + fp = ((preds==1)&(y_true==0)).sum().item() + fn = ((preds==0)&(y_true==1)).sum().item() + + precision = tp/(tp+fp+1e-8) + recall = tp/(tp+fn+1e-8) + + f1 = 2*precision*recall/(precision+recall+1e-8) + + self.update.emit(f"Epoch {epoch+1}/{EPOCHS} Loss={epoch_loss:.4f} " + f"Precision={precision:.3f} Recall={recall:.3f} F1={f1:.3f}") + + + if f1 > best_f1: + best_f1 = f1 + # When saving the model after training + best_config = (hidden_size, batch_size, lr, opt_name, bi) + torch.save(model.state_dict(), self.model_save_path) + + self.update.emit(f"New best model saved to {os.path.basename(self.model_save_path)}!") + self.finished.emit(f"Training complete!\nBest config: {best_config}\nF1={best_f1:.3f}") + + + + + + + +class TestModelThread(QThread): + update_frame = Signal(np.ndarray) # emit frames with overlay + finished = Signal() + + def __init__(self, video_paths, model_path, fps=60, window_sec=0.4, step_sec=0.2, threshold=0.5): + super().__init__() + self.video_paths = video_paths + self.fps = fps + self.window_sec = window_sec + self.step_sec = step_sec + self.threshold = threshold + self.model_path = model_path + self.window_frames = int(fps*window_sec) + + def run(self): + + from time import time + + class WindowLSTM(torch.nn.Module): + def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size, hidden_size, batch_first=True, bidirectional=bidirectional + ) + self.fc = torch.nn.Linear(hidden_size * (2 if bidirectional else 1), output_size) + + def forward(self, x): + _, (h, _) = self.lstm(x) + if self.lstm.bidirectional: + h = torch.cat([h[-2], h[-1]], dim=1) + else: + h = h[-1] + return self.fc(h) + + # MATCH THE MODEL CONFIG YOU TRAINED WITH + model = WindowLSTM(input_size=63, hidden_size=64, bidirectional=False) + model.load_state_dict(torch.load(self.model_path, map_location="cpu")) + model.eval() + + print("Loaded model.") + + + # ------------------------------ MediaPipe ------------------------------ + mp_hands = mp.solutions.hands + mp_drawing = mp.solutions.drawing_utils + mp_drawing_styles = mp.solutions.drawing_styles + + hands = mp_hands.Hands( + static_image_mode=False, + max_num_hands=2, + min_detection_confidence=0.5, + min_tracking_confidence=0.5, + model_complexity=0 + ) + + # ------------------------------ Open videos ------------------------------ + + cap = cv2.VideoCapture(self.video_paths[0]) + + + # buffer holds last 60 frames of landmarks + window_buffer = [] + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + print(f"Original video dimensions: {original_width} x {original_height}") + print(f"Total frames in video: {total_frames}") + + # Set display dimensions (maintaining aspect ratio) + display_width = 800 + aspect_ratio = original_height / original_width + display_height = int(display_width * aspect_ratio) + print(f"Display dimensions: {display_width} x {display_height}") + + window_name = f"Hand Tracking Test - Press 'q' to quit" + cv2.namedWindow(window_name, cv2.WINDOW_NORMAL) + cv2.resizeWindow(window_name, display_width, display_height) + + # Output window size + display_width = 1000 + + frame_index = 0 # <<< ADDED to match your other script + + while True: + ret, frame = cap.read() + if not ret: + break + + h, w = frame.shape[:2] + + # detect hands + rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + results = hands.process(rgb) + + # find RIGHT hand + features = [] + right_found = False + + if results.multi_hand_landmarks and results.multi_handedness: + for i, handness in enumerate(results.multi_handedness): + label = handness.classification[0].label + if label == "Right": + lm = results.multi_hand_landmarks[i] + right_found = True + + for p in lm.landmark: + features += [p.x, p.y, p.z] + + # draw hand with SAME STYLES + mp_drawing.draw_landmarks( + frame, + lm, + mp_hands.HAND_CONNECTIONS, + mp_drawing_styles.get_default_hand_landmarks_style(), + mp_drawing_styles.get_default_hand_connections_style(), + ) + break + + if not right_found: + features = [math.nan] * 63 + + print(features) + # -------------------------- + # Maintain window buffer + # -------------------------- + window_buffer.append(features) + + # forward fill + arr = np.array(window_buffer, dtype=np.float32) + for c in range(arr.shape[1]): + valid = ~np.isnan(arr[:, c]) + if valid.any(): + arr[:, c][~valid] = np.interp( + np.flatnonzero(~valid), + np.flatnonzero(valid), + arr[:, c][valid], + ) + window_buffer = arr.tolist() + + # only last 60 frames + if len(window_buffer) > self.window_frames: + window_buffer.pop(0) + + prob = 0 + pred_active = 0 + pred_before_contact = 0 + + # -------------------------- + # run inference when full window available + # -------------------------- + if len(window_buffer) == self.window_frames: + x = torch.tensor(window_buffer, dtype=torch.float32).unsqueeze(0) + with torch.no_grad(): + logits = model(x) + probs = torch.sigmoid(logits).squeeze(0).numpy() + + pred_active = int(probs[0] > self.threshold) + pred_before_contact = int(probs[1] > self.threshold) + + color_active = (0, 255, 0) if pred_active else (0, 0, 255) + color_before_contact = (255, 255, 0) if pred_before_contact else (0, 0, 128) + + # Overlay rectangles + cv2.rectangle(frame, (0, 0), (w//2, 40), color_active, -1) + cv2.putText(frame, f"Reach Active: {'YES' if pred_active else 'NO'} ({probs[0]:.2f})", + (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255,255,255), 2) + + cv2.rectangle(frame, (w//2, 0), (w, 40), color_before_contact, -1) + cv2.putText(frame, f"Reach Before Contact: {'YES' if pred_before_contact else 'NO'} ({probs[1]:.2f})", + (w//2 + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255,255,255), 2) + + else: + cv2.rectangle(frame, (0, 0), (w, 40), (0, 128, 255), -1) + cv2.putText( + frame, + f"Collecting window: {len(window_buffer)}/{self.window_frames}", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 1.0, + (255, 255, 255), + 2, + ) + + # ------------------------------ + # Extra overlays to MATCH your big script + # ------------------------------ + info_text = [ + f"Frame: {frame_index}", + f"Time: {frame_index / self.fps:.2f}s", + f"Reach Active: {'YES' if pred_active else 'NO'}", + f"Reach Before Contact: {'YES' if pred_before_contact else 'NO'}", + f"Right Hand: {'DETECTED' if right_found else 'NOT DETECTED'}", + "Press 'q' to quit", + ] + + for i, text in enumerate(info_text): + y = 60 + i * 25 + cv2.putText(frame, text, (10, y), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,255,0), 2) + + # If predicted active, add highlight frame + if pred_active or pred_before_contact: + cv2.rectangle(frame, (0,0), (w,h), (0,255,0), 6) + display_frame = frame.copy() + + # resize for display + resized_frame = cv2.resize(display_frame, (display_width, display_height), interpolation=cv2.INTER_AREA) + + cv2.imshow(window_name, resized_frame) + + frame_index += 1 + + if cv2.waitKey(1) & 0xFF == ord("q"): + break + self.finished.emit() + + + + + +class TestModelThread2(QThread): + finished = Signal(object) + + def __init__(self, csv_path, model_path, fps=60, window_sec=0.4, threshold=0.5): + super().__init__() + self.csv_path = csv_path + self.model_path = model_path + self.fps = fps + self.window_sec = window_sec + self.window_frames = int(fps * window_sec) + self.threshold = threshold + + + def run(self): + try: + self._run_internal() + except Exception as e: + import traceback + traceback.print_exc() + self.error = str(e) + self.finished.emit({"error": str(e)}) + + def _run_internal(self): + + # -------------------------------------------------- + # 1. Model Definition and Load (Copied from TestModelThread) + # -------------------------------------------------- + class WindowLSTM(torch.nn.Module): + def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size, hidden_size, batch_first=True, bidirectional=bidirectional + ) + self.fc = torch.nn.Linear(hidden_size * (2 if bidirectional else 1), output_size) + + def forward(self, x): + _, (h, _) = self.lstm(x) + if self.lstm.bidirectional: + h = torch.cat([h[-2], h[-1]], dim=1) + else: + h = h[-1] + return self.fc(h) + + # Assuming input_size=63, hidden_size=64, bidirectional=False, as in TestModelThread + # If your training params varied, you must update these defaults to match the saved model + model = WindowLSTM(input_size=63, hidden_size=64, bidirectional=False) + model.load_state_dict(torch.load(self.model_path, map_location="cpu")) + model.eval() + + # -------------------------------------------------- + # 2. Load CSV of landmarks (CRITICAL FIX APPLIED HERE) + # -------------------------------------------------- + df = pd.read_csv(self.csv_path) + feature_cols = [c for c in df.columns if c.startswith("lm")] + X_with_nans = df[feature_cols].values + time = df["time_sec"].values if "time_sec" in df.columns else np.arange(len(X_with_nans)) / self.fps + + # -------------------------------------------------- + # 3. Sliding window inference (EXACTLY MATCHING TestModelThread) + # -------------------------------------------------- + window_buffer = [] + preds_active = [] + preds_before_contact = [] + + for i in range(len(X_with_nans)): + + # Pad output + if i < self.window_frames - 1: + preds_active.append(0) + preds_before_contact.append(0) + + # Add features (may contain NaNs) to the buffer + features = X_with_nans[i] + window_buffer.append(features) + + # Keep only last `window_frames` + if len(window_buffer) > self.window_frames: + window_buffer.pop(0) + + # 🚨 CRITICAL: PERFORM INTERPOLATION ON THE BUFFER + # This block *must* run to smooth over lost frames, replicating the video logic. + arr = np.array(window_buffer, dtype=np.float32) + for c in range(arr.shape[1]): + valid = ~np.isnan(arr[:, c]) + if valid.any(): + arr[:, c][~valid] = np.interp( + np.flatnonzero(~valid), + np.flatnonzero(valid), + arr[:, c][valid], + ) + + # 🚨 CRITICAL: Update the buffer back to the interpolated LIST OF LISTS. + # This mirrors the exact data structure used by TestModelThread before calling torch.tensor(). + window_buffer = arr.tolist() + + + # predict only when full window + if len(window_buffer) == self.window_frames: + + # 🚨 CRITICAL: Create tensor from the LIST OF LISTS buffer + x = torch.tensor(window_buffer, dtype=torch.float32).unsqueeze(0) + with torch.no_grad(): + logits = model(x) + probs = torch.sigmoid(logits).squeeze(0).numpy() + + preds_active.append(probs[0]) + preds_before_contact.append(probs[1]) + + + self.finished.emit({ + "time": time, + "active": preds_active, + "before": preds_before_contact, + "threshold": self.threshold, + }) + + + +class TerminalWindow(QWidget): + def __init__(self, parent=None): + super().__init__(parent, Qt.WindowType.Window) + self.setWindowTitle("Terminal - SPARKS") + + self.output_area = QTextEdit() + self.output_area.setReadOnly(True) + + self.input_line = QLineEdit() + self.input_line.returnPressed.connect(self.handle_command) + + layout = QVBoxLayout() + layout.addWidget(self.output_area) + layout.addWidget(self.input_line) + self.setLayout(layout) + + self.commands = { + "hello": self.cmd_hello, + "help": self.cmd_help + } + + def handle_command(self): + command_text = self.input_line.text() + self.input_line.clear() + + self.output_area.append(f"> {command_text}") + parts = command_text.strip().split() + if not parts: + return + + command_name = parts[0] + args = parts[1:] + + func = self.commands.get(command_name) + if func: + try: + result = func(*args) + if result: + self.output_area.append(str(result)) + except Exception as e: + self.output_area.append(f"[Error] {e}") + else: + self.output_area.append(f"[Unknown command] '{command_name}'") + + + def cmd_hello(self, *args): + return "Hello from the terminal!" + + def cmd_help(self, *args): + return f"Available commands: {', '.join(self.commands.keys())}" + + +class UpdateDownloadThread(QThread): + """ + Thread that downloads and extracts an update package and emits a signal on completion or error. + + Args: + download_url (str): URL of the update zip file to download. + latest_version (str): Version string of the latest update. + """ + + update_ready = Signal(str, str) + error_occurred = Signal(str) + + def __init__(self, download_url, latest_version): + super().__init__() + self.download_url = download_url + self.latest_version = latest_version + + def run(self): + try: + local_filename = os.path.basename(self.download_url) + + if PLATFORM_NAME == 'darwin': + tmp_dir = '/tmp/sparkstempupdate' + os.makedirs(tmp_dir, exist_ok=True) + local_path = os.path.join(tmp_dir, local_filename) + else: + local_path = os.path.join(os.getcwd(), local_filename) + + # Download the file + with requests.get(self.download_url, stream=True, timeout=15) as r: + r.raise_for_status() + with open(local_path, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + # Extract folder name (remove .zip) + if PLATFORM_NAME == 'darwin': + extract_folder = os.path.splitext(local_filename)[0] + extract_path = os.path.join(tmp_dir, extract_folder) + + else: + extract_folder = os.path.splitext(local_filename)[0] + extract_path = os.path.join(os.getcwd(), extract_folder) + + # Create the folder if not exists + os.makedirs(extract_path, exist_ok=True) + + # Extract the zip file contents + if PLATFORM_NAME == 'darwin': + subprocess.run(['ditto', '-xk', local_path, extract_path], check=True) + else: + with zipfile.ZipFile(local_path, 'r') as zip_ref: + zip_ref.extractall(extract_path) + + # Remove the zip once extracted and emit a signal + os.remove(local_path) + self.update_ready.emit(self.latest_version, extract_path) + + except Exception as e: + # Emit a signal signifying failure + self.error_occurred.emit(str(e)) + + + +class UpdateCheckThread(QThread): + """ + Thread that checks for updates by querying the API and emits a signal based on the result. + + Signals: + download_requested(str, str): Emitted with (download_url, latest_version) when an update is available. + no_update_available(): Emitted when no update is found or current version is up to date. + error_occurred(str): Emitted with an error message if the update check fails. + """ + + download_requested = Signal(str, str) + no_update_available = Signal() + error_occurred = Signal(str) + + def run(self): + if not getattr(sys, 'frozen', False): + self.error_occurred.emit("Application is not frozen (Development mode).") + return + try: + latest_version, download_url = self.get_latest_release_for_platform() + if not latest_version: + self.no_update_available.emit() + return + + if not download_url: + self.error_occurred.emit(f"No download available for platform '{PLATFORM_NAME}'") + return + + if self.version_compare(latest_version, CURRENT_VERSION) > 0: + self.download_requested.emit(download_url, latest_version) + else: + self.no_update_available.emit() + + except Exception as e: + self.error_occurred.emit(f"Update check failed: {e}") + + def version_compare(self, v1, v2): + def normalize(v): return [int(x) for x in v.split(".")] + return (normalize(v1) > normalize(v2)) - (normalize(v1) < normalize(v2)) + + def get_latest_release_for_platform(self): + urls = [API_URL, API_URL_SECONDARY] + for url in urls: + try: + + response = requests.get(API_URL, timeout=5) + response.raise_for_status() + releases = response.json() + + if not releases: + return None, None + + latest = releases[0] + tag = latest["tag_name"].lstrip("v") + + for asset in latest.get("assets", []): + if PLATFORM_NAME in asset["name"].lower(): + return tag, asset["browser_download_url"] + + return tag, None + except (requests.RequestException, ValueError) as e: + continue + return None, None + + +class LocalPendingUpdateCheckThread(QThread): + """ + Thread that checks for locally pending updates by scanning the download directory and emits a signal accordingly. + + Args: + current_version (str): Current application version. + platform_suffix (str): Platform-specific suffix to identify update folders. + """ + + pending_update_found = Signal(str, str) + no_pending_update = Signal() + + def __init__(self, current_version, platform_suffix): + super().__init__() + self.current_version = current_version + self.platform_suffix = platform_suffix + + def version_compare(self, v1, v2): + def normalize(v): return [int(x) for x in v.split(".")] + return (normalize(v1) > normalize(v2)) - (normalize(v1) < normalize(v2)) + + def run(self): + if PLATFORM_NAME == 'darwin': + cwd = '/tmp/sparkstempupdate' + else: + cwd = os.getcwd() + + pattern = re.compile(r".*-(\d+\.\d+\.\d+)" + re.escape(self.platform_suffix) + r"$") + found = False + + try: + for item in os.listdir(cwd): + folder_path = os.path.join(cwd, item) + if os.path.isdir(folder_path) and item.endswith(self.platform_suffix): + match = pattern.match(item) + if match: + folder_version = match.group(1) + if self.version_compare(folder_version, self.current_version) > 0: + self.pending_update_found.emit(folder_version, folder_path) + found = True + break + except: + pass + + if not found: + self.no_pending_update.emit() + + + + + +class ProgressDialog(QDialog): + def __init__(self, parent=None): + super().__init__(parent) + self.setWindowTitle("Processing Files...") + self.setWindowModality(Qt.WindowModality.ApplicationModal) + self.setMinimumWidth(300) + + layout = QVBoxLayout(self) + + self.message_label = QLabel("Loading and analyzing video files...") + self.progress_bar = QProgressBar() + self.progress_bar.setMaximum(0) # Indeterminate state initially + + self.cancel_button = QPushButton("Cancel") + + layout.addWidget(self.message_label) + layout.addWidget(self.progress_bar) + layout.addWidget(self.cancel_button) + + def update_progress(self, current, total): + if self.progress_bar.maximum() == 0: + self.progress_bar.setMaximum(total) + + self.progress_bar.setValue(current) + self.message_label.setText(f"Processing file {current} of {total}...") + + + + + + + +class FileLoadWorker(QThread): + observations_loaded = Signal(dict, dict, str) + progress_update = Signal(int, int) + loading_finished = Signal() + loading_failed = Signal(str) + + def __init__(self, file_path, observations_root, extract_frame_and_hands_func): + super().__init__() + self.file_path = file_path + self.observations_root = observations_root + self.extract_frame_and_hands = extract_frame_and_hands_func + self.is_running = True + self.previews_data = {} # To store the results of the heavy work + + def resolve_video_path(self, relative_video_path): + """Logic to handle flexible path resolution (basename or full join).""" + + corrected_file_name = os.path.basename(relative_video_path) + video_path = os.path.join(self.observations_root, corrected_file_name) + + if not os.path.exists(video_path): + # Fallback to full relative path join + video_path = os.path.join(self.observations_root, relative_video_path) + + return video_path + + def run(self): + try: + # 1. Load the BORIS JSON (Non-UI work) + with open(self.file_path, "r") as f: + boris = json.load(f) + + observations = boris.get("observations", {}) + if not observations: + self.loading_failed.emit("No observations found in BORIS JSON.") + return + + # 2. Process the file data to gather initial previews (the heavy part) + self.process_previews(observations) + + # 3. Emit the final results back + self.observations_loaded.emit(boris, observations, self.observations_root) + + except Exception as e: + self.loading_failed.emit(f"Loading failed: {str(e)}") + finally: + self.loading_finished.emit() # Ensures dialog closes even on some non-critical path breaks + + def process_previews(self, observations): + total_files = sum(len(obs_data.get("file", {}).items()) for obs_data in observations.values()) + current_index = 0 + + for obs_id, obs_data in observations.items(): + files_dict = obs_data.get("file", {}) + media_info = obs_data.get("media_info", {}) + fps_dict = media_info.get("fps", {}) + # Simplified FPS calculation for the worker + fps_default = float(list(fps_dict.values())[0]) if fps_dict and list(fps_dict.values()) else 30.0 + + print(media_info) + for cam_id, paths in files_dict.items(): + if not self.is_running: + return # Allow cancellation + if not paths: + print(f"Skipping entry for {obs_id}/{cam_id}: paths list is empty (IndexError prevented).") + current_index += 1 + self.progress_update.emit(current_index, total_files) + continue + print(paths) + relative_video_path = paths[0] + video_path = self.resolve_video_path(relative_video_path) + print(video_path) + if os.path.exists(video_path): + # HEAVY OPERATION + frame_rgb, results = self.extract_frame_and_hands(video_path, 0) + + initial_wrists = [] + if results and results.multi_hand_landmarks: + for lm in results.multi_hand_landmarks: + wrist = lm.landmark[0] + initial_wrists.append((wrist.x, wrist.y)) + + # Store the results needed to BUILD the UI later + self.previews_data[(obs_id, cam_id)] = { + "frame_rgb": frame_rgb, + "results": results, + "video_path": video_path, + "fps": fps_default, + "initial_wrists": initial_wrists + } + + current_index += 1 + self.progress_update.emit(current_index, total_files) + + def stop(self): + self.is_running = False + + + + +class AboutWindow(QWidget): + """ + Simple About window displaying basic application information. + + Args: + parent (QWidget, optional): Parent widget of this window. Defaults to None. + """ + + def __init__(self, parent=None): + super().__init__(parent, Qt.WindowType.Window) + self.setWindowTitle("About SPARKS") + self.resize(250, 100) + + layout = QVBoxLayout() + label = QLabel("About SPARKS", self) + label2 = QLabel("Spacial Patterns Analysis, Research, & Knowledge Suite", self) + label3 = QLabel("SPARKS is licensed under the GPL-3.0 licence. For more information, visit https://www.gnu.org/licenses/gpl-3.0.en.html", self) + label4 = QLabel(f"Version v{CURRENT_VERSION}") + + layout.addWidget(label) + layout.addWidget(label2) + layout.addWidget(label3) + layout.addWidget(label4) + + self.setLayout(layout) + + + +class UserGuideWindow(QWidget): + """ + Simple User Guide window displaying basic information on how to use the software. + + Args: + parent (QWidget, optional): Parent widget of this window. Defaults to None. + """ + + def __init__(self, parent=None): + super().__init__(parent, Qt.WindowType.Window) + self.setWindowTitle("User Guide - SPARKS") + self.resize(250, 100) + + layout = QVBoxLayout() + label = QLabel("Not Applicable:", self) + label2 = QLabel("N/A", self) + + layout.addWidget(label) + layout.addWidget(label2) + + self.setLayout(layout) + + + +class ParamSection(QWidget): + """ + A widget section that dynamically creates labeled input fields from parameter metadata. + + Args: + section_data (dict): Dictionary containing section title and list of parameter info. + Expected format: + { + "title": str, + "params": [ + { + "name": str, + "type": type, + "default": any, + "help": str (optional) + }, + ... + ] + } + """ + + def __init__(self, section_data): + super().__init__() + layout = QVBoxLayout() + self.setLayout(layout) + self.widgets = {} + + self.selected_path = None + + # Title label + title_label = QLabel(section_data["title"]) + title_label.setStyleSheet("font-weight: bold; font-size: 14px; margin-top: 10px; margin-bottom: 5px;") + layout.addWidget(title_label) + + # Horizontal line + line = QFrame() + line.setFrameShape(QFrame.Shape.HLine) + line.setFrameShadow(QFrame.Shadow.Sunken) + layout.addWidget(line) + + for param in section_data["params"]: + h_layout = QHBoxLayout() + + label = QLabel(param["name"]) + label.setFixedWidth(180) + + label.setToolTip(param.get("help", "")) + + help_text = param.get("help", "") + + help_btn = QPushButton("?") + help_btn.setFixedWidth(25) + help_btn.setToolTip(help_text) + help_btn.clicked.connect(lambda _, text=help_text: self.show_help_popup(text)) + + + h_layout.addWidget(help_btn) + h_layout.setStretch(0, 1) # Set stretch factor for button (10%) + + h_layout.addWidget(label) + h_layout.setStretch(1, 3) # Set the stretch factor for label (40%) + + + # Create input widget based on type + if param["type"] == bool: + widget = QComboBox() + widget.addItems(["True", "False"]) + widget.setCurrentText(str(param["default"])) + elif param["type"] == int: + widget = QLineEdit() + widget.setValidator(QIntValidator()) + widget.setText(str(param["default"])) + elif param["type"] == float: + widget = QLineEdit() + widget.setValidator(QDoubleValidator()) + widget.setText(str(param["default"])) + elif param["type"] == list: + widget = self._create_multiselect_dropdown(None) + else: + widget = QLineEdit() + widget.setText(str(param["default"])) + + widget.setToolTip(help_text) + + h_layout.addWidget(widget) + h_layout.setStretch(2, 5) # Set stretch factor for input field (50%) + + layout.addLayout(h_layout) + self.widgets[param["name"]] = { + "widget": widget, + "type": param["type"] + } + + def _create_multiselect_dropdown(self, items): + combo = FullClickComboBox() + combo.setView(QListView()) + model = QStandardItemModel() + combo.setModel(model) + combo.setEditable(True) + combo.lineEdit().setReadOnly(True) + combo.lineEdit().setPlaceholderText("Select...") + + dummy_item = QStandardItem("") + dummy_item.setFlags(Qt.ItemIsEnabled) + model.appendRow(dummy_item) + + toggle_item = QStandardItem("Toggle Select All") + toggle_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + toggle_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(toggle_item) + + if items is not None: + for item in items: + standard_item = QStandardItem(item) + standard_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled) + standard_item.setData(Qt.Unchecked, Qt.CheckStateRole) + model.appendRow(standard_item) + + combo.setInsertPolicy(QComboBox.NoInsert) + + + def on_view_clicked(index): + item = model.itemFromIndex(index) + if item.isCheckable(): + new_state = Qt.Checked if item.checkState() == Qt.Unchecked else Qt.Unchecked + item.setCheckState(new_state) + + combo.view().pressed.connect(on_view_clicked) + + self._updating_checkstates = False + + def on_item_changed(item): + if self._updating_checkstates: + return + self._updating_checkstates = True + + normal_items = [model.item(i) for i in range(2, model.rowCount())] # skip dummy and toggle + + if item == toggle_item: + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + if all_checked: + for i in normal_items: + i.setCheckState(Qt.Unchecked) + toggle_item.setCheckState(Qt.Unchecked) + else: + for i in normal_items: + i.setCheckState(Qt.Checked) + toggle_item.setCheckState(Qt.Checked) + + elif item == dummy_item: + pass + + else: + # When normal items change, update toggle item + all_checked = all(i.checkState() == Qt.Checked for i in normal_items) + toggle_item.setCheckState(Qt.Checked if all_checked else Qt.Unchecked) + + self._updating_checkstates = False + + for param_name, info in self.widgets.items(): + if info["widget"] == combo: + self.update_dropdown_label(param_name) + break + + model.itemChanged.connect(on_item_changed) + + combo.setInsertPolicy(QComboBox.NoInsert) + return combo + + def show_help_popup(self, text): + msg = QMessageBox(self) + msg.setWindowTitle("Parameter Info - SPARKS") + msg.setText(text) + msg.exec() + + + def get_param_values(self): + values = {} + for name, info in self.widgets.items(): + widget = info["widget"] + expected_type = info["type"] + + if expected_type == bool: + values[name] = widget.currentText() == "True" + elif expected_type == list: + values[name] = [x.strip() for x in widget.lineEdit().text().split(",") if x.strip()] + else: + raw_text = widget.text() + try: + if expected_type == int: + values[name] = int(raw_text) + elif expected_type == float: + values[name] = float(raw_text) + elif expected_type == str: + values[name] = raw_text + else: + values[name] = raw_text # Fallback + except Exception as e: + raise ValueError(f"Invalid value for {name}: {raw_text}") from e + + return values + + +class FullClickLineEdit(QLineEdit): + def mousePressEvent(self, event): + combo = self.parent() + if isinstance(combo, QComboBox): + combo.showPopup() + super().mousePressEvent(event) + + +class FullClickComboBox(QComboBox): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setLineEdit(FullClickLineEdit(self)) + self.lineEdit().setReadOnly(True) + + +class ParticipantProcessor(QThread): + progress_updated = Signal(int) + frame_ready = Signal(QImage) + finished_processing = Signal(str) + time_updated = Signal(str) + + def __init__(self, obs_id, boris_json, selected_cam_id, selected_hand_idx, + hide_preview=False, output_csv=None, observations_root=None, output_dir=None, initial_wrists=None): + super().__init__() + self.obs_id = obs_id + self.boris_json = boris_json + self.selected_cam_id = selected_cam_id + self.selected_hand_idx = selected_hand_idx + self.hide_preview = hide_preview + self.output_csv = output_csv + self.observations_root = observations_root + self.output_dir = output_dir + self.initial_wrists = initial_wrists + + # Mediapipe hands + self.mp_hands = mp.solutions.hands.Hands( + static_image_mode=False, + max_num_hands=2, + min_detection_confidence=0.5, + min_tracking_confidence=0.5, + model_complexity=0 + ) + self.is_running = True + + def run(self): + observations = self.boris_json.get("observations", {}) + obs_data = observations[self.obs_id] + + # ---------------- Camera / Video ---------------- + cameras_with_files = obs_data.get("file", {}) + media_info = obs_data.get("media_info", {}) + offset_dict = media_info.get("offset", {}) + # paths = files_dict.get(self.cam_id, []) + + video_file_list = cameras_with_files[self.selected_cam_id] + relative_video_path = video_file_list[0] + + # 2. Resolve the FULL path using the stored root and the same logic as before + corrected_file_name = os.path.basename(relative_video_path) + + # Try finding just the file name in the root folder + video_path = os.path.join(self.observations_root, corrected_file_name) + + # if not os.path.exists(video_path): + # # Fallback to the full relative path join + # video_path = os.path.join(self.observations_root, relative_video_path) + + if not os.path.exists(video_path): + print(f"FATAL ERROR: Video file not found at: {video_path}") + self.finished_processing.emit(f"Error: Video not found for {self.obs_id}") + return + + camera_offset = float(offset_dict.get(self.selected_cam_id, 0.0)) + + fps_dict = media_info.get("fps", {}) + fps = float(list(fps_dict.values())[0]) if fps_dict else 30.0 + + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + total_seconds = total_frames / fps + total_time_str = self._format_seconds(total_seconds) + + # ---------------- Trial Segments ---------------- + raw_events = obs_data.get("events", []) + segments = [] + reach_before_contact_segments = [] + + current_start = None + extra_duration = 1.0 + + for ev in raw_events: + if len(ev) < 3: + continue + time_sec = ev[0] + name = ev[2] + if name == "Trial Start": + current_start = time_sec + current_contact = None + elif name == "Contact" and current_start is not None: + current_contact = time_sec + reach_before_contact_segments.append((current_start, current_contact)) + + elif name == "End" and current_start is not None: + adjusted_start = current_start + adjusted_end = time_sec + extra_duration + segments.append((adjusted_start, adjusted_end)) + if current_contact is None: + reach_before_contact_segments.append((current_start, adjusted_end)) + current_start = None + + # ---------------- CSV Header ---------------- + header = ["frame_index", "time_sec", "adjusted_time_sec", "reach_active", "reach_before_contact", "camera_offset"] + for i in range(21): + header += [f"lm{i}_x", f"lm{i}_y", f"lm{i}_z"] + + output_path = os.path.join(self.output_dir, self.output_csv) + print(output_path) + with open(output_path, "w", newline="") as csv_file: + csv_writer = csv.writer(csv_file) + csv_writer.writerow(header) + + # ---------------- Frame Processing ---------------- + frame_index = 0 + while True: + if not self.is_running: + print(f"Processor for {self.obs_id}/{self.cam_id} was cancelled gracefully.") + # Clean up resources before returning + cap.release() + return # Thread exits gracefully + ret, frame = cap.read() + if not ret: + break + + current_time = frame_index / fps + adjusted_time = current_time # you could add extra logic if needed + + # Check if reach is active with offset + is_reach_active = any( + (start - camera_offset) <= current_time <= (end - camera_offset) + for start, end in segments + ) + + is_reach_before_contact = any( + (start - camera_offset) <= current_time <= (end - camera_offset) + for start, end in reach_before_contact_segments + ) + + row = [ + frame_index, current_time, adjusted_time, + int(is_reach_active), int(is_reach_before_contact), + camera_offset + ] + + # ---------------- MediaPipe Hands ---------------- + rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + results = self.mp_hands.process(rgb_frame) + + hand_found = False + if results.multi_hand_landmarks and results.multi_handedness: + # Map detected hands to initial_wrists + hand_mapping = {} + used_indices = set() + if self.initial_wrists: + for i, iw in enumerate(self.initial_wrists): + min_dist = float('inf') + best_idx = None + for j, lm in enumerate(results.multi_hand_landmarks): + if j in used_indices: + continue + wrist = lm.landmark[0] + dist = (wrist.x - iw[0])**2 + (wrist.y - iw[1])**2 + if dist < min_dist: + min_dist = dist + best_idx = j + if best_idx is not None: + hand_mapping[i] = best_idx + used_indices.add(best_idx) + else: + hand_mapping = {i: i for i in range(len(results.multi_hand_landmarks))} + + # Use selected_hand_idx to pick the right hand + mapped_idx = hand_mapping.get(self.selected_hand_idx) + if mapped_idx is not None: + hand_landmarks = results.multi_hand_landmarks[mapped_idx] + hand_found = True + for lm in hand_landmarks.landmark: + row += [lm.x, lm.y, lm.z] + + if not hand_found: + row += [math.nan] * (21 * 3) + + csv_writer.writerow(row) + + # ---------------- Preview ---------------- + if not self.hide_preview: + display_frame = frame.copy() + if results.multi_hand_landmarks and hand_found: + mp.solutions.drawing_utils.draw_landmarks( + display_frame, + hand_landmarks, + mp.solutions.hands.HAND_CONNECTIONS + ) + # Convert to QImage + h, w, _ = display_frame.shape + qimg = QImage(display_frame.data, w, h, 3*w, QImage.Format_RGB888) + self.frame_ready.emit(qimg) + + current_seconds = frame_index / fps + current_time_str = self._format_seconds(current_seconds) + # ---------------- Progress ---------------- + progress = int((frame_index / total_frames) * 100) + self.progress_updated.emit(progress) + self.time_updated.emit(f"{current_time_str}/{total_time_str}") + frame_index += 1 + + cap.release() + self.mp_hands.close() + self.finished_processing.emit(self.obs_id) + + def _format_seconds(self, seconds): + if seconds < 0: + return "00:00" + minutes = math.floor(seconds / 60) + secs = math.floor(seconds % 60) + return f"{minutes:02d}:{secs:02d}" + + def stop(self): + """Sets the flag to stop the thread's execution loop.""" + self.is_running = False + + +class MainApplication(QMainWindow): + """ + Main application window that creates and sets up the UI. + """ + + progress_update_signal = Signal(str, int) + + def __init__(self): + super().__init__() + self.setWindowTitle("SPARKS") + self.setGeometry(100, 100, 1280, 720) + + self.about = None + self.help = None + self.optodes = None + self.events = None + self.terminal = None + self.bubble_widgets = {} + self.param_sections = [] + self.folder_paths = [] + self.section_widget = None + self.first_run = True + + self.selection_widgets = {} + self.camera_state = {} # {(obs_id, cam_id): state info} + self.processing_widgets = {} + + self.files_total = 0 # total number of files to process + self.files_done = set() # set of file paths done (success or fail) + self.files_failed = set() # set of failed file paths + self.files_results = {} # dict for successful results (if needed) + + self.processing_threads = [] # List to hold all active ParticipantProcessor threads + self.processing_widgets = {} + + self.init_ui() + self.create_menu_bar() + + self.platform_suffix = "-" + PLATFORM_NAME + self.pending_update_version = None + self.pending_update_path = None + self.last_clicked_bubble = None + self.installEventFilter(self) + + self.file_metadata = {} + self.current_file = None + + self.worker_thread = None + self.progress_dialog = None + + # Mediapipe hands + self.mp_hands = mp.solutions.hands.Hands( + static_image_mode=True, + max_num_hands=2, + min_detection_confidence=0.5 + ) + + # Start local pending update check thread + self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix) + self.local_check_thread.pending_update_found.connect(self.on_pending_update_found) + self.local_check_thread.no_pending_update.connect(self.on_no_pending_update) + self.local_check_thread.start() + + + def init_ui(self): + + # Central widget and main horizontal layout + central = QWidget() + self.setCentralWidget(central) + + main_layout = QHBoxLayout() + central.setLayout(main_layout) + + # Left container with vertical layout: top left + bottom left + left_container = QWidget() + left_layout = QVBoxLayout() + left_container.setLayout(left_layout) + left_container.setMinimumWidth(300) + + top_left_container = QGroupBox() + top_left_container.setTitle("File information") + top_left_container.setStyleSheet("QGroupBox { font-weight: bold; }") # Style if needed + top_left_container.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding) + + top_left_layout = QHBoxLayout() + + top_left_container.setLayout(top_left_layout) + + # QTextEdit with fixed height, but only 80% width + self.top_left_widget = QTextEdit() + self.top_left_widget.setReadOnly(True) + self.top_left_widget.setPlaceholderText("Logging information will be available here.") + + # Add QTextEdit to the layout with a stretch factor + top_left_layout.addWidget(self.top_left_widget, stretch=4) # 80% + + # Create a vertical box layout for the right 20% + self.right_column_widget = QWidget() + right_column_layout = QVBoxLayout() + self.right_column_widget.setLayout(right_column_layout) + + self.meta_fields = { + "HAND": QLineEdit(), + "GENDER": QLineEdit(), + "GROUP": QLineEdit(), + } + + for key, field in self.meta_fields.items(): + label = QLabel(key.capitalize()) + field.setPlaceholderText(f"Enter {key}") + right_column_layout.addWidget(label) + right_column_layout.addWidget(field) + + label_desc = QLabel('Why are these useful?') + label_desc.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction) + label_desc.setOpenExternalLinks(False) + + def show_info_popup(): + QMessageBox.information(None, "Parameter Info - SPARKS", + "Age: Used to calculate the DPF factor.\nGender: Not currently used. " + "Will be able to sort into groups by gender in the near future.\nGroup: Allows contrast " + "images to be created comparing one group to another once the processing has completed.") + + label_desc.linkActivated.connect(show_info_popup) + right_column_layout.addWidget(label_desc) + + right_column_layout.addStretch() # Push fields to top + + self.right_column_widget.hide() + + # Add right column widget to the top-left layout (takes 20% width) + top_left_layout.addWidget(self.right_column_widget, stretch=1) + + # Add top_left_container to the main left_layout + left_layout.addWidget(top_left_container, stretch=2) + + # Bottom left: the bubbles inside the scroll area + self.bubble_container = QWidget() + self.bubble_layout = QGridLayout() + self.bubble_layout.setAlignment(Qt.AlignmentFlag.AlignTop) + self.bubble_container.setLayout(self.bubble_layout) + + self.scroll_area = QScrollArea() + self.scroll_area.setWidgetResizable(True) + self.scroll_area.setWidget(self.bubble_container) + self.scroll_area.setMinimumHeight(300) + + # Add top left and bottom left to left layout + left_layout.addWidget(self.scroll_area, stretch=8) + + # Right widget (full height on right side) + self.right_container = QWidget() + right_container_layout = QVBoxLayout() + self.right_container.setLayout(right_container_layout) + + # Content widget inside scroll area + self.right_content_widget = QWidget() + right_content_layout = QVBoxLayout() + self.right_content_widget.setLayout(right_content_layout) + + # Option selector dropdown + self.option_selector = QComboBox() + self.option_selector.addItems(["N/A"]) + right_content_layout.addWidget(self.option_selector) + + # Container for the sections + self.rows_container = QWidget() + self.rows_layout = QVBoxLayout() + self.rows_layout.setSpacing(10) + self.rows_container.setLayout(self.rows_layout) + right_content_layout.addWidget(self.rows_container) + + # Spacer at bottom inside scroll area content to push content up + right_content_layout.addStretch() + + # Scroll area for the right side content + self.right_scroll_area = QScrollArea() + self.right_scroll_area.setWidgetResizable(True) + self.right_scroll_area.setWidget(self.right_content_widget) + + # Buttons widget (fixed below the scroll area) + buttons_widget = QWidget() + buttons_layout = QHBoxLayout() + buttons_widget.setLayout(buttons_layout) + buttons_layout.addStretch() + + self.button1 = QPushButton("Process") + self.button2 = QPushButton("Clear") + + buttons_layout.addWidget(self.button1) + buttons_layout.addWidget(self.button2) + + self.button1.setMinimumSize(100, 40) + self.button2.setMinimumSize(100, 40) + + self.button1.setVisible(False) + + self.button1.clicked.connect(self.on_run_task) + self.button2.clicked.connect(self.clear_all) + + # Add scroll area and buttons widget to right container layout + right_container_layout.addWidget(self.right_scroll_area) + right_container_layout.addWidget(buttons_widget) + + # Add left and right containers to main layout + main_layout.addWidget(left_container, stretch=70) + main_layout.addWidget(self.right_container, stretch=30) + + # Set size policy to expand + self.right_container.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + self.right_scroll_area.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + # Store ParamSection widgets + self.option_selector.currentIndexChanged.connect(self.update_sections) + + # Initial build + self.update_sections(0) + + + def create_menu_bar(self): + '''Menu Bar at the top of the screen''' + + menu_bar = self.menuBar() + + def make_action(name, shortcut=None, slot=None, checkable=False, checked=False, icon=None): + action = QAction(name, self) + + if shortcut: + action.setShortcut(QKeySequence(shortcut)) + if slot: + action.triggered.connect(slot) + if checkable: + action.setCheckable(True) + action.setChecked(checked) + if icon: + action.setIcon(QIcon(icon)) + return action + + # File menu and actions + file_menu = menu_bar.addMenu("File") + file_actions = [ + ("Open BORIS file...", "Ctrl+O", self.open_file_dialog, resource_path("icons/file_open_24dp_1F1F1F.svg")), + #("Open Folder...", "Ctrl+Alt+O", self.open_folder_dialog, resource_path("icons/folder_24dp_1F1F1F.svg")), + #("Open Folders...", "Ctrl+Shift+O", self.open_folder_dialog, resource_path("icons/folder_copy_24dp_1F1F1F.svg")), + #("Load Project...", "Ctrl+L", self.load_project, resource_path("icons/article_24dp_1F1F1F.svg")), + #("Save Project...", "Ctrl+S", self.save_project, resource_path("icons/save_24dp_1F1F1F.svg")), + #("Save Project As...", "Ctrl+Shift+S", self.save_project, resource_path("icons/save_as_24dp_1F1F1F.svg")), + ] + + for i, (name, shortcut, slot, icon) in enumerate(file_actions): + file_menu.addAction(make_action(name, shortcut, slot, icon=icon)) + if i == 1: # after the first 3 actions (0,1,2) + file_menu.addSeparator() + + file_menu.addSeparator() + file_menu.addAction(make_action("Exit", "Ctrl+Q", QApplication.instance().quit, icon=resource_path("icons/exit_to_app_24dp_1F1F1F.svg"))) + + # Edit menu + edit_menu = menu_bar.addMenu("Edit") + edit_actions = [ + ("Cut", "Ctrl+X", self.cut_text, resource_path("icons/content_cut_24dp_1F1F1F.svg")), + ("Copy", "Ctrl+C", self.copy_text, resource_path("icons/content_copy_24dp_1F1F1F.svg")), + ("Paste", "Ctrl+V", self.paste_text, resource_path("icons/content_paste_24dp_1F1F1F.svg")) + ] + for name, shortcut, slot, icon in edit_actions: + edit_menu.addAction(make_action(name, shortcut, slot, icon=icon)) + + + model_menu = menu_bar.addMenu("Model") + model_actions = [ + ("Create model from a CSV", "Ctrl+U", self.train_model_csv, resource_path("icons/content_cut_24dp_1F1F1F.svg")), + ("Create model from a folder", "Ctrl+I", self.train_model_folder, resource_path("icons/content_copy_24dp_1F1F1F.svg")), + ("Test model on a video", "Ctrl+O", self.test_model_video, resource_path("icons/content_paste_24dp_1F1F1F.svg")), + ("Test model on a folder", "Ctrl+P", self.test_model_folder, resource_path("icons/content_paste_24dp_1F1F1F.svg")), + ("Test model on a CSV", "Ctrl+P", self.test_model_csv, resource_path("icons/content_paste_24dp_1F1F1F.svg")) + ] + for i, (name, shortcut, slot, icon) in enumerate(model_actions): + model_menu.addAction(make_action(name, shortcut, slot, icon=icon)) + if i == 1: + model_menu.addSeparator() + + # View menu + view_menu = menu_bar.addMenu("View") + toggle_statusbar_action = make_action("Toggle Status Bar", checkable=True, checked=True, slot=None) + view_menu.addAction(toggle_statusbar_action) + + # Options menu (Help & About) + options_menu = menu_bar.addMenu("Options") + + options_actions = [ + ("User Guide", "F1", self.user_guide, resource_path("icons/help_24dp_1F1F1F.svg")), + ("Check for Updates", "F5", self.manual_check_for_updates, resource_path("icons/update_24dp_1F1F1F.svg")), + ("About", "F12", self.about_window, resource_path("icons/info_24dp_1F1F1F.svg")) + ] + + for i, (name, shortcut, slot, icon) in enumerate(options_actions): + options_menu.addAction(make_action(name, shortcut, slot, icon=icon)) + if i == 1 or i == 3: # after the first 2 actions (0,1) + options_menu.addSeparator() + + terminal_menu = menu_bar.addMenu("Terminal") + terminal_actions = [ + ("New Terminal", "Ctrl+Alt+T", self.terminal_gui, resource_path("icons/terminal_24dp_1F1F1F.svg")), + ] + for name, shortcut, slot, icon in terminal_actions: + terminal_menu.addAction(make_action(name, shortcut, slot, icon=icon)) + + self.statusbar = self.statusBar() + self.statusbar.showMessage("Ready") + + + def update_sections(self, index): + # Clear previous sections + for i in reversed(range(self.rows_layout.count())): + widget = self.rows_layout.itemAt(i).widget() + if widget is not None: + widget.deleteLater() + self.param_sections.clear() + + # Add ParamSection widgets from SECTIONS + for section in SECTIONS: + self.section_widget = ParamSection(section) + self.rows_layout.addWidget(self.section_widget) + + self.param_sections.append(self.section_widget) + + + def clear_bubble_widgets(self): + """Clears all widgets from the self.bubble_layout.""" + while self.bubble_layout.count(): + # Get the next item in the layout + item = self.bubble_layout.takeAt(0) + + # Check if the item holds a widget (most common case: QGroupBox) + widget = item.widget() + if widget: + widget.deleteLater() + + # Check if the item is a spacer/stretch item + elif item.spacerItem(): + self.bubble_layout.removeItem(item) + + def clear_all(self): + + self.cancel_task() + + self.right_column_widget.hide() + + # Clear the bubble layout + while self.bubble_layout.count(): + item = self.bubble_layout.takeAt(0) + widget = item.widget() + if widget: + widget.deleteLater() + + # Clear file data + self.bubble_widgets.clear() + self.statusBar().clearMessage() + + self.raw_haemo_dict = None + self.epochs_dict = None + self.fig_bytes_dict = None + self.cha_dict = None + self.contrast_results_dict = None + self.df_ind_dict = None + self.design_matrix_dict = None + self.age_dict = None + self.gender_dict = None + self.group_dict = None + self.valid_dict = None + + # Reset any visible UI elements + self.button1.setVisible(False) + self.top_left_widget.clear() + + + def get_output_directory(self): + """Prompts user to select a parent folder and creates a unique, dated subfolder.""" + + # Open a folder selection dialog to get the parent folder + parent_dir = QFileDialog.getExistingDirectory( + self, + "Select Parent Folder for Results", + os.getcwd() # Start in the user's home directory + ) + + if not parent_dir: + return None # User cancelled the dialog + + # Create the unique subfolder name: sparks_YYYYMMDD_HHMMSS + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + subfolder_name = f"sparks_{timestamp}" + output_path = os.path.join(parent_dir, subfolder_name) + + # Create the directory if it doesn't exist + try: + os.makedirs(output_path, exist_ok=True) + print(f"Created output directory: {output_path}") + return output_path + except Exception as e: + QMessageBox.critical(self, "Error", f"Failed to create output directory: {e}") + return None + + def copy_text(self): + self.top_left_widget.copy() # Trigger copy + self.statusbar.showMessage("Copied to clipboard") # Show status message + + def cut_text(self): + self.top_left_widget.cut() # Trigger cut + self.statusbar.showMessage("Cut to clipboard") # Show status message + + def paste_text(self): + self.top_left_widget.paste() # Trigger paste + self.statusbar.showMessage("Pasted from clipboard") # Show status message + + + def get_training_files(self): + """Opens a dialog to select multiple CSV files for training.""" + file_paths, _ = QFileDialog.getOpenFileNames( + self, + "Select Training CSV Files", + os.getcwd(), # Start in current directory + "CSV Files (*.csv)" # Filter for CSV files + ) + return file_paths + + def get_testing_files(self): + """Opens dialogs to select the model file and the video file.""" + + # 1. Select Model File (.pth) + model_path, _ = QFileDialog.getOpenFileName( + self, + "Select Trained Model File (.pth)", + os.getcwd(), + "Model Files (*.pth)" + ) + if not model_path: return None, None + + # 2. Select Video File + video_path, _ = QFileDialog.getOpenFileName( + self, + "Select Video File for Testing", + os.getcwd(), + "Video Files (*.mp4 *.avi *.mov)" + ) + + return model_path, video_path + + + def get_testing_files2(self): + """Opens dialogs to select the model file and the csv file.""" + + # 1. Select Model File (.pth) + model_path, _ = QFileDialog.getOpenFileName( + self, + "Select Trained Model File (.pth)", + os.getcwd(), + "Model Files (*.pth)" + ) + if not model_path: return None, None + + # 2. Select Video File + csv_path, _ = QFileDialog.getOpenFileName( + self, + "Select Video File for Testing", + os.getcwd(), + "Comma Seperated Files (*.csv)" + ) + + return model_path, csv_path + + def train_model_csv(self): + """Initiates the model training process, including selecting CSVs and save path.""" + + # 1. Get CSV paths + csv_paths = self.get_training_files() # Re-use the existing CSV file selection method + if not csv_paths: + self.statusBar().showMessage("Training cancelled. No CSV files selected.") + return + + # 2. Get Model Save Path (New Pop-up) + default_filename = "best_reach_lstm.pth" + + # The third argument can include a default filename + model_save_path, _ = QFileDialog.getSaveFileName( + self, + "Save Trained Model File", + os.path.join(os.getcwd(), default_filename), # Start in CWD with a default name + "PyTorch Model Files (*.pth);;All Files (*)" + ) + + if not model_save_path: + self.statusBar().showMessage("Training cancelled. Model save location not specified.") + return + + # Ensure correct extension if the user didn't type one + if not model_save_path.lower().endswith('.pth'): + model_save_path += '.pth' + + self.statusBar().showMessage(f"Starting model training on {len(csv_paths)} files. Model will save to: {os.path.basename(model_save_path)}") + + # 3. Start the thread with the new path + self.train_thread = TrainModelThread( + csv_paths=csv_paths, + model_save_path=model_save_path # Pass the path here + ) + + # ... (connect signals, start thread, disable button, etc.) ... + self.train_thread.update.connect(self.log_training_message) + self.train_thread.finished.connect(self.on_train_finished) + self.train_thread.start() + + + def log_training_message(self, message): + """Displays training status messages (e.g., in a QTextEdit).""" + # Assuming you have a QTextEdit named self.log_output (or similar) + # For simplicity, we'll use the status bar for now, but a dedicated log box is better. + self.statusBar().showMessage(f"Training: {message}") + # self.top_left_widget.append(f"[TRAIN] {message}") # Use this if you have a log box + + def on_train_finished(self, final_message): + """Handles completion of the training thread.""" + self.statusBar().showMessage(final_message) + # Show final message in a pop-up or log box + QMessageBox.information(self, "Training Complete", final_message) + + # Clean up and re-enable the button + self.train_thread = None + + def train_model_folder(self): + return + + def test_model_video(self): + model_path, video_path = self.get_testing_files() + + if not model_path or not video_path: + self.statusBar().showMessage("Testing cancelled. Model or Video not selected.") + return + + self.statusBar().showMessage(f"Starting model testing on {os.path.basename(video_path)}...") + + # 1. Create the thread + # Note: TestModelThread expects a list of video paths, so we pass [video_path] + self.test_thread = TestModelThread(video_paths=[video_path], model_path=model_path) + + # 2. Connect signals + # The update_frame signal is not used to update the main GUI here, + # as the thread is responsible for showing the cv2.imshow window. + self.test_thread.finished.connect(self.on_test_finished) + + # 3. Start the thread + self.test_thread.start() + + + + def test_model_csv(self): + model_path, csv_path = self.get_testing_files2() + + if not model_path or not csv_path: + self.statusBar().showMessage("Testing cancelled. Model or Video not selected.") + return + + self.statusBar().showMessage(f"Starting model testing on {os.path.basename(csv_path)}...") + + # 1. Create the thread + # Note: TestModelThread expects a list of video paths, so we pass [video_path] + self.test_thread = TestModelThread2(csv_path=csv_path, model_path=model_path) + + # 2. Connect signals + # The update_frame signal is not used to update the main GUI here, + # as the thread is responsible for showing the cv2.imshow window. + self.test_thread.finished.connect(self.on_csv_analysis_finished) + + # 3. Start the thread + self.test_thread.start() + + + def on_csv_analysis_finished(self, result): + if "error" in result: + QMessageBox.critical(self, "Error", result["error"]) + return + + time = result["time"] + active = result["active"] + before = result["before"] + + import matplotlib.pyplot as plt + + fig, axes = plt.subplots(2, 1, figsize=(14, 6), sharex=True) + + # --- Top plot: Reach Active --- + axes[0].plot(time, active, color="green", label="Reach Active") + axes[0].set_ylabel("Probability") + axes[0].set_title("Reach Active") + axes[0].set_ylim(0, 1) + axes[0].grid(True) + + # --- Bottom plot: Reach Before Contact --- + axes[1].plot(time, before, color="orange", label="Reach Before Contact") + axes[1].set_ylabel("Probability") + axes[1].set_title("Reach Before Contact") + axes[1].set_xlabel("Time (s)") + axes[1].set_ylim(0, 1) + axes[1].grid(True) + + plt.tight_layout() + plt.show() + + self.statusBar().showMessage("Complete.") + + + + def on_test_finished(self): + """Handles cleanup after the testing video window is closed.""" + self.statusBar().showMessage("Model testing complete.") + + # Ensure the OpenCV window is destroyed + try: + cv2.destroyAllWindows() + except: + pass + + # Clean up and re-enable the button + self.test_thread = None + + def test_model_folder(self): + return + + + + + def about_window(self): + if self.about is None or not self.about.isVisible(): + self.about = AboutWindow(self) + self.about.show() + + def user_guide(self): + if self.help is None or not self.help.isVisible(): + self.help = UserGuideWindow(self) + self.help.show() + + def terminal_gui(self): + if self.terminal is None or not self.terminal.isVisible(): + self.terminal = TerminalWindow(self) + self.terminal.show() + + def update_optode_positions(self): + return + if self.optodes is None or not self.optodes.isVisible(): + self.optodes = UpdateOptodesWindow(self) + self.optodes.show() + + + def update_event_markers(self): + return + if self.events is None or not self.events.isVisible(): + self.events = UpdateEventsWindow(self) + self.events.show() + + def open_file_dialog(self): + file_path, _ = QFileDialog.getOpenFileName( + self, "Open File", "", "BORIS Files (*.boris);;All Files (*)" + ) + if file_path: + + self.observations_root = self.resolve_path_to_observations(file_path) + if not self.observations_root: + self.statusBar().showMessage("File loading cancelled by user.") + return # User cancelled the folder selection + + # 2. Initialize and Show Progress Dialog + self.progress_dialog = ProgressDialog(self) + # ⚠️ Connect the cancel button to the task stopper + self.progress_dialog.cancel_button.clicked.connect(self.cancel_task) + self.progress_dialog.show() + + # 3. Initialize and Start Worker Thread + self.worker_thread = FileLoadWorker( + file_path, + self.observations_root, + self.extract_frame_and_hands # Assuming this is a method available to MainApplication + ) + + # 4. Connect Signals to Main Thread Slots + self.worker_thread.progress_update.connect(self.progress_dialog.update_progress) + self.worker_thread.observations_loaded.connect(self.on_files_loaded) + self.worker_thread.loading_finished.connect(self.on_loading_finished) + self.worker_thread.loading_failed.connect(self.on_loading_failed) + + self.worker_thread.start() + + + + def resolve_path_to_observations(self, boris_file_path): + """ + Attempts to find the 'Observations' directory based on the BORIS file location. + Prompts the user if it cannot be found. + """ + # 1. Get the directory containing the BORIS file + boris_dir = os.path.dirname(boris_file_path) + + # 2. Check if the 'Observations' folder is adjacent to the BORIS file + initial_check_path = os.path.join(boris_dir, "Observations") + + if os.path.isdir(initial_check_path): + # Success: The 'Observations' folder is found next to the BORIS file + return initial_check_path + + # 3. If not found, prompt the user to locate the Observations folder + QMessageBox.warning( + self, + "Observations Folder Missing", + "The 'Observations' folder (containing the videos) was not found next to the BORIS file. Please select the correct root folder." + ) + + # Open a folder selection dialog + folder_path = QFileDialog.getExistingDirectory( + self, + "Select Observations Root Folder", + boris_dir # Start the dialog in the BORIS file's directory + ) + + if folder_path: + # Check if the user selected a folder named 'Observations' + if os.path.basename(folder_path) == "Observations": + return folder_path + + # Check if the user selected the *parent* of 'Observations' + elif os.path.isdir(os.path.join(folder_path, "Observations")): + return os.path.join(folder_path, "Observations") + + # If the selected folder is neither, return the selected path and hope it contains the relative paths + else: + return folder_path + + # If the user cancels the dialog, return None + return None + + + + def on_files_loaded(self, boris, observations, root_path): + # 1. Update MainApplication state variables + self.boris = boris + self.observations = observations + self.observations_root = root_path + + # 2. Build the UI grid using the data gathered by the worker + self.build_preview_grid(self.worker_thread.previews_data) + + self.statusBar().showMessage(f"{self.worker_thread.file_path} loaded.") + self.button1.setVisible(True) + + def on_loading_finished(self): + if self.progress_dialog: + self.progress_dialog.accept() # Close the dialog if finished successfully + + def on_loading_failed(self, error_message): + if self.progress_dialog: + self.progress_dialog.reject() # Close the dialog + QMessageBox.critical(self, "Error Loading Files", error_message) + self.statusBar().showMessage("File loading failed.") + + def cancel_task(self): + if self.worker_thread and self.worker_thread.isRunning(): + self.worker_thread.stop() # Tell the thread to stop processing + self.worker_thread.wait() # Wait for the thread to safely exit + if self.progress_dialog: + self.progress_dialog.reject() + self.statusBar().showMessage("File loading cancelled.") + + # Add a dummy implementation for the processing function if it's not shown: + # def extract_frame_and_hands(self, video_path, frame_idx): + # # Placeholder implementation - replace with your actual function + # return None, None + + def build_preview_grid(self, previews_data): + for obs_id, obs_data in self.observations.items(): + + group = QGroupBox(f"Participant / Observation: {obs_id}") + grouplayout = QVBoxLayout(group) + + # Participant-level skip dropdown + participant_dropdown = QComboBox() + participant_dropdown.addItem("Skip this participant", 0) + participant_dropdown.addItem("Process this participant", 1) + participant_dropdown.setCurrentIndex(0) # default: skip + if "participant_selection" not in self.selection_widgets: + self.selection_widgets["participant_selection"] = {} + self.selection_widgets["participant_selection"][obs_id] = participant_dropdown + grouplayout.addWidget(QLabel("Participant Option:")) + grouplayout.addWidget(participant_dropdown) + + files_dict = obs_data.get("file", {}) + media_info = obs_data.get("media_info", {}) + print(media_info) + fps_dict = media_info.get("fps", {}) + fps_default = float(list(fps_dict.values())[0]) if fps_dict else 30.0 + + if obs_id not in self.selection_widgets: + self.selection_widgets[obs_id] = {} + + for cam_id, paths in files_dict.items(): + + # Check if the worker successfully gathered the preview data for this file + state_key = (obs_id, cam_id) + if state_key not in previews_data: + # Skip files the worker couldn't process (e.g., path not found) + continue + + # Retrieve the pre-calculated data + state_data = previews_data[state_key] + frame_rgb = state_data["frame_rgb"] + results = state_data["results"] + + # Only draw the overlay if the worker found data + if frame_rgb is not None: + display_img = self.draw_hand_overlay(frame_rgb, results) + else: + # Use a placeholder image or skip if no data + continue + + # Convert to pixmap + h, w, _ = display_img.shape + qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888) + pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio) + + # -------- UI Row -------- + row = QWidget() + row_layout = QHBoxLayout(row) + + preview_label = QLabel() + preview_label.setPixmap(pix) + row_layout.addWidget(preview_label) + + # Dropdown + dropdown = QComboBox() + dropdown.addItem("Skip this camera", -1) + if results and results.multi_hand_landmarks: + for idx in range(len(results.multi_hand_landmarks)): + dropdown.addItem(f"Use Hand {idx}", idx) + row_layout.addWidget(dropdown) + + # Store dropdown + self.selection_widgets[obs_id][cam_id] = dropdown + + self.camera_state[(obs_id, cam_id)] = { + "path": state_data["video_path"], + "frame_idx": 0, + "fps": state_data["fps"], + "preview_label": preview_label, + "dropdown": dropdown, + "initial_wrists": state_data["initial_wrists"] + } + + # Skip button + skip_btn = QPushButton("Skip 1s → Rescan") + skip_btn.clicked.connect(lambda _, o=obs_id, c=cam_id: self.skip_and_rescan(o, c)) + row_layout.addWidget(skip_btn) + + grouplayout.addWidget(row) + + self.bubble_layout.addWidget(group) + + #self.bubble_layout.addStretch() + + + + + # ============================================ + # FRAME EXTRACTION + MEDIA PIPE DETECTION + # ============================================ + def extract_frame_and_hands(self, video_path, frame_idx): + cap = cv2.VideoCapture(video_path) + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) + ret, frame = cap.read() + cap.release() + + if not ret: + return None, None + + rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + results = self.mp_hands.process(rgb) + return rgb, results + + + # ============================================ + # DRAW OVERLAYED HANDS + BIG LABELS + # ============================================ + def draw_hand_overlay(self, img, results, initial_wrists=None): + draw_img = img.copy() + h, w, _ = draw_img.shape + + COLORS = [ + (0, 255, 0), # Hand 0 green + (255, 0, 255), # Hand 1 magenta + (0, 255, 255), # Hand 2 cyan + ] + + if results and results.multi_hand_landmarks: + # Map detected hands to initial wrists if provided + if initial_wrists: + hand_mapping = {} # initial_idx -> current_idx + used_indices = set() + for i, iw in enumerate(initial_wrists): + min_dist = float('inf') + best_idx = None + for j, lm in enumerate(results.multi_hand_landmarks): + if j in used_indices: + continue # skip already assigned + wrist = lm.landmark[0] + dist = (wrist.x - iw[0])**2 + (wrist.y - iw[1])**2 + if dist < min_dist: + min_dist = dist + best_idx = j + if best_idx is not None: + hand_mapping[i] = best_idx + used_indices.add(best_idx) + else: + hand_mapping = {i: i for i in range(len(results.multi_hand_landmarks))} + + for initial_idx, current_idx in hand_mapping.items(): + lm_obj = results.multi_hand_landmarks[current_idx] + mp.solutions.drawing_utils.draw_landmarks( + draw_img, + lm_obj, + mp.solutions.hands.HAND_CONNECTIONS + ) + wrist = lm_obj.landmark[0] + wx, wy = int(wrist.x * w), int(wrist.y * h) + color = COLORS[initial_idx % len(COLORS)] + + # Big outline + cv2.putText(draw_img, str(initial_idx), (wx, wy - 40), + cv2.FONT_HERSHEY_SIMPLEX, 3.2, + (0, 0, 0), 10, cv2.LINE_AA) + + # Big colored label + cv2.putText(draw_img, str(initial_idx), (wx, wy - 40), + cv2.FONT_HERSHEY_SIMPLEX, 3.2, + color, 6, cv2.LINE_AA) + + return draw_img + + + # ============================================ + # SKIP 1s AND RESCAN FRAME + # ============================================ + def skip_and_rescan(self, obs_id, cam_id): + key = (obs_id, cam_id) + state = self.camera_state[key] + + fps = state["fps"] + state["frame_idx"] += int(fps) + + rgb, results = self.extract_frame_and_hands( + state["path"], state["frame_idx"] + ) + if rgb is None: + return + + display_img = self.draw_hand_overlay(rgb, results, state.get("initial_wrists")) + h, w, _ = display_img.shape + qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888) + pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio) + + # Update preview + state["preview_label"].setPixmap(pix) + + # Update dropdown + dropdown = state["dropdown"] + dropdown.clear() + dropdown.addItem("Skip this camera", -1) + if results and results.multi_hand_landmarks: + for idx in range(len(results.multi_hand_landmarks)): + dropdown.addItem(f"Use Hand {idx}", idx) + + + + # def save_project(self, onCrash=False): + + # if not onCrash: + # filename, _ = QFileDialog.getSaveFileName( + # self, "Save Project", "", "SPARK Project (*.spark)" + # ) + # if not filename: + # return + # else: + # if PLATFORM_NAME == "darwin": + # filename = os.path.join(os.path.dirname(sys.executable), "../../../sparks_autosave.spark") + # else: + # filename = os.path.join(os.getcwd(), "sparks_autosave.spark") + + # try: + # # Ensure the filename has the proper extension + # if not filename.endswith(".spark"): + # filename += ".spark" + + # project_path = Path(filename).resolve() + # project_dir = project_path.parent + + # file_list = [ + # str(PurePosixPath(Path(bubble.file_path).resolve().relative_to(project_dir))) + # for bubble in self.bubble_widgets.values() + # ] + + # progress_states = { + # str(PurePosixPath(Path(bubble.file_path).resolve().relative_to(project_dir))): bubble.current_step + # for bubble in self.bubble_widgets.values() + # } + + # project_data = { + # "file_list": file_list, + # "progress_states": progress_states, + # "raw_haemo_dict": self.raw_haemo_dict, + # "epochs_dict": self.epochs_dict, + # "fig_bytes_dict": self.fig_bytes_dict, + # "cha_dict": self.cha_dict, + # "contrast_results_dict": self.contrast_results_dict, + # "df_ind_dict": self.df_ind_dict, + # "design_matrix_dict": self.design_matrix_dict, + # "age_dict": self.age_dict, + # "gender_dict": self.gender_dict, + # "group_dict": self.group_dict, + # "valid_dict": self.valid_dict, + # } + + # def sanitize(obj): + # if isinstance(obj, Path): + # return str(PurePosixPath(obj)) + # elif isinstance(obj, dict): + # return {sanitize(k): sanitize(v) for k, v in obj.items()} + # elif isinstance(obj, list): + # return [sanitize(i) for i in obj] + # return obj + + # project_data = sanitize(project_data) + + # with open(filename, "wb") as f: + # pickle.dump(project_data, f) + + # QMessageBox.information(self, "Success", f"Project saved to:\n{filename}") + + # except Exception as e: + # if not onCrash: + # QMessageBox.critical(self, "Error", f"Failed to save project:\n{e}") + + + + # def load_project(self): + # filename, _ = QFileDialog.getOpenFileName( + # self, "Load Project", "", "SPARK Project (*.spark)" + # ) + # if not filename: + # return + + # try: + # with open(filename, "rb") as f: + # data = pickle.load(f) + + # self.raw_haemo_dict = data.get("raw_haemo_dict", {}) + # self.epochs_dict = data.get("epochs_dict", {}) + # self.fig_bytes_dict = data.get("fig_bytes_dict", {}) + # self.cha_dict = data.get("cha_dict", {}) + # self.contrast_results_dict = data.get("contrast_results_dict", {}) + # self.df_ind_dict = data.get("df_ind_dict", {}) + # self.design_matrix_dict = data.get("design_matrix_dict", {}) + # self.age_dict = data.get("age_dict", {}) + # self.gender_dict = data.get("gender_dict", {}) + # self.group_dict = data.get("group_dict", {}) + # self.valid_dict = data.get("valid_dict", {}) + + # project_dir = Path(filename).parent + + # # Convert saved relative paths to absolute paths + # file_list = [str((project_dir / Path(rel_path)).resolve()) for rel_path in data["file_list"]] + + # # Also resolve progress_states with updated paths + # raw_progress = data.get("progress_states", {}) + # progress_states = { + # str((project_dir / Path(rel_path)).resolve()): step + # for rel_path, step in raw_progress.items() + # } + + # self.show_files_as_bubbles_from_list(file_list, progress_states, filename) + + # # Re-enable buttons + # # self.button1.setVisible(True) + # self.button3.setVisible(True) + + # QMessageBox.information(self, "Loaded", f"Project loaded from:\n{filename}") + + # except Exception as e: + # QMessageBox.critical(self, "Error", f"Failed to load project:\n{e}") + + + def placeholder(self): + QMessageBox.information(self, "Placeholder", "This feature is not implemented yet.") + + def save_metadata(self, file_path): + if not file_path: + return + + self.file_metadata[file_path] = { + key: field.text() + for key, field in self.meta_fields.items() + } + + def get_all_metadata(self): + # First, make sure current file's edits are saved + + for field in self.meta_fields.values(): + field.clearFocus() + + # Save current file's metadata + if self.current_file: + self.save_metadata(self.current_file) + + return self.file_metadata + + + def cancel_task(self): + self.button1.clicked.disconnect(self.cancel_task) + self.button1.setText("Stopping...") + + if hasattr(self, "result_process") and self.result_process.is_alive(): + parent = psutil.Process(self.result_process.pid) + children = parent.children(recursive=True) + for child in children: + try: + child.kill() + except psutil.NoSuchProcess: + pass + self.result_process.terminate() + self.result_process.join() + + if hasattr(self, "result_timer") and self.result_timer.isActive(): + self.result_timer.stop() + + # if hasattr(self, "result_process") and self.result_process.is_alive(): + # self.result_process.terminate() # Forcefully terminate the process + # self.result_process.join() # Wait for it to properly close + + # # Stop the QTimer if running + # if hasattr(self, "result_timer") and self.result_timer.isActive(): + # self.result_timer.stop() + + + self.statusbar.showMessage("Processing cancelled.") + self.button1.clicked.connect(self.on_run_task) + self.button1.setText("Process") + + + '''MODULE FILE''' + def on_run_task(self): + + self.button1.clicked.disconnect(self.on_run_task) + self.button1.setText("Cancel") + self.button1.clicked.connect(self.cancel_task) + + # Clear previous processing widgets (if any) + # for obs_id, widgets in getattr(self, "processing_widgets", {}).items(): + # for w in widgets.values(): + # if isinstance(w, QWidget): + # w.deleteLater() + + self.clear_bubble_widgets() + + self.processing_widgets = {} + self.processing_threads = [] + + self.output_dir = self.get_output_directory() + if not self.output_dir: + # Revert button state if user cancels + self.button1.clicked.disconnect(self.cancel_task) + self.button1.setText("Process") + self.button1.clicked.connect(self.on_run_task) + self.statusBar().showMessage("Processing cancelled by user.") + return + + selected_files_to_process = [] + + # 2️⃣ Iterate through all participant/camera combinations to find what to process + # We must iterate over all entries in self.selection_widgets (the structure built by build_preview_grid) + for obs_id, cam_dropdowns in self.selection_widgets.items(): + if obs_id == "participant_selection": + continue # skip the participant-level dropdowns + for cam_id, dropdown in cam_dropdowns.items(): + + selected_hand_idx = dropdown.currentData() + + # Check if the selection is NOT 'Skip' (-1) + if selected_hand_idx != -1: + state_data = self.camera_state.get((obs_id, cam_id)) + if state_data: + selected_files_to_process.append({ + "obs_id": obs_id, + "cam_id": cam_id, + "hand_idx": selected_hand_idx, + "path": state_data.get("path", "Unknown") + }) + + if not selected_files_to_process: + self.statusBar().showMessage("No files selected for processing (use dropdowns).") + # Revert button state + self.button1.clicked.disconnect(self.cancel_task) + self.button1.setText("Process") + self.button1.clicked.connect(self.on_run_task) + return + + self.statusBar().showMessage(f"Starting to process {len(selected_files_to_process)} files...") + QApplication.processEvents() # Refresh GUI before starting threads + + for file_info in selected_files_to_process: + obs_id = file_info["obs_id"] + cam_id = file_info["cam_id"] + selected_hand_idx = file_info["hand_idx"] # THIS is fixed per iteration + + # ... (thread logic setup) ... + + # --- PROGRESS ROW UI SETUP (50% Label | 40% Bar | 10% Time) --- + row = QWidget() + row_layout = QHBoxLayout(row) + row_layout.setContentsMargins(0, 0, 0, 0) + + # LEFT SIDE (50%): Filename Label + filename = os.path.basename(file_info["path"]) + filename_label = QLabel(f"P {obs_id}/{cam_id}: {filename}") + filename_label.setWordWrap(True) + row_layout.addWidget(filename_label, stretch=5) # 5 units = 50% + + # MIDDLE (40%): Progress Bar + progress_bar = QProgressBar() + progress_bar.setRange(0, 100) + row_layout.addWidget(progress_bar, stretch=4) # 4 units = 40% + + # RIGHT SIDE (10%): Time Indicator Label + time_label = QLabel("00:00/00:00") # Default placeholder + time_label.setAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter) + row_layout.addWidget(time_label, stretch=1) # 1 unit = 10% + + # Add the combined row to the bubble layout area + row_index = self.bubble_layout.rowCount() + self.bubble_layout.addWidget(row, row_index, 0, 1, 1) + + # Save widgets for thread updates + state_key = (obs_id, cam_id) + self.processing_widgets[state_key] = { + "label": filename_label, + "progress": progress_bar, + # 👇 NEW: Save the time label reference + "time_label": time_label + } + # --- THREAD START --- + output_csv = f"{obs_id}_{cam_id}_processed.csv" + + # ⚠️ Ensure ParticipantProcessor is defined to accept self.observations_root + processor = ParticipantProcessor( + obs_id=obs_id, + boris_json=self.boris, + selected_cam_id=cam_id, + selected_hand_idx=selected_hand_idx, + observations_root=self.observations_root, # CRITICAL PATH FIX + hide_preview=False, # Assuming default for now + output_csv=output_csv, + output_dir=self.output_dir, + initial_wrists=self.camera_state[(obs_id, cam_id)].get("initial_wrists") + ) + + # Connect signals to the UI elements + processor.progress_updated.connect(progress_bar.setValue) + processor.finished_processing.connect( + lambda obs_id=obs_id, cam_id=cam_id: self.on_processing_finished(obs_id, cam_id) + ) + processor.time_updated.connect(time_label.setText) # Connects to THIS time_label + processor.start() + self.processing_threads.append(processor) + + # Set the vertical stretch to push all progress bars to the top + num_rows = self.bubble_layout.rowCount() + self.bubble_layout.setRowStretch(num_rows, 1) + + + def check_for_pipeline_results(self): + return + + def on_processing_finished(self, obs_id, cam_id): + state_key = (obs_id, cam_id) + if state_key in self.processing_widgets: + widgets = self.processing_widgets[state_key] + widgets["progress"].setValue(100) + label = self.processing_widgets[state_key]["label"] + label.setText(f"P {obs_id}/{cam_id}: ✅ COMPLETE") + + # Check if all threads are finished to revert the main button + if not any(t.isRunning() for t in self.processing_threads): + self.cancel_task() # Reverts the button and cleans up + self.statusBar().showMessage("All processing complete!") + + + def cancel_task(self): + """Stops all active processing threads and reverts UI to 'Process' state.""" + + if not self.processing_threads: + return + + self.statusBar().showMessage("Cancelling processing... Please wait.") + + # 1. Stop all threads gracefully + for thread in self.processing_threads: + if thread.isRunning(): + # Assuming ParticipantProcessor has a .stop() method + thread.stop() + thread.wait() # Wait for thread to exit its loop + + # 2. Clean up threads and widgets + self.processing_threads.clear() + + # 3. Revert button and status bar + self.button1.clicked.disconnect(self.cancel_task) + self.button1.setText("Process") + self.button1.clicked.connect(self.on_run_task) + + # 4. Clear the progress bars display + #self.clear_bubble_widgets() + self.statusBar().showMessage("Processing cancelled. Ready to load or process.") + + def show_error_popup(self, title, error_message, traceback_str=""): + msgbox = QMessageBox(self) + msgbox.setIcon(QMessageBox.Warning) + msgbox.setWindowTitle("Warning - SPARKS") + + message = ( + f"SPARKS has encountered an error processing the file {title}.

" + "This error was likely due to incorrect parameters on the right side of the screen and not an error with your data. " + "Processing of the remaining files continues in the background and this participant will be ignored in the analysis. " + "If you think the parameters on the right side are correct for your data, raise an issue here.

" + f"Error message: {error_message}" + ) + + msgbox.setTextFormat(Qt.TextFormat.RichText) + msgbox.setText(message) + msgbox.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction) + + # Add traceback to detailed text + if traceback_str: + msgbox.setDetailedText(traceback_str) + + msgbox.setStandardButtons(QMessageBox.Ok) + msgbox.exec_() + + + def cleanup_after_process(self): + + if hasattr(self, 'result_process'): + self.result_process.join(timeout=0) + if self.result_process.is_alive(): + self.result_process.terminate() + self.result_process.join() + + if hasattr(self, 'result_queue'): + if 'AutoProxy' in repr(self.result_queue): + pass + else: + self.result_queue.close() + self.result_queue.join_thread() + + if hasattr(self, 'progress_queue'): + if 'AutoProxy' in repr(self.progress_queue): + pass + else: + self.progress_queue.close() + self.progress_queue.join_thread() + + # Shutdown manager to kill its server process and clean up + if hasattr(self, 'manager'): + self.manager.shutdown() + + + + + '''UPDATER''' + def manual_check_for_updates(self): + self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix) + self.local_check_thread.pending_update_found.connect(self.on_pending_update_found) + self.local_check_thread.no_pending_update.connect(self.on_no_pending_update) + self.local_check_thread.start() + + def on_pending_update_found(self, version, folder_path): + self.statusBar().showMessage(f"Pending update found: version {version}") + self.pending_update_version = version + self.pending_update_path = folder_path + self.show_pending_update_popup() + + def on_no_pending_update(self): + # No pending update found locally, start server check directly + self.statusBar().showMessage("No pending local update found. Checking server...") + self.start_update_check_thread() + + def show_pending_update_popup(self): + msg_box = QMessageBox(self) + msg_box.setWindowTitle("Pending Update Found") + msg_box.setText(f"A previously downloaded update (version {self.pending_update_version}) is available at:\n{self.pending_update_path}\nWould you like to install it now?") + install_now_button = msg_box.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole) + install_later_button = msg_box.addButton("Install Later", QMessageBox.ButtonRole.RejectRole) + msg_box.exec() + + if msg_box.clickedButton() == install_now_button: + self.install_update(self.pending_update_path) + else: + self.statusBar().showMessage("Pending update available. Install later.") + # After user dismisses, still check the server for new updates + self.start_update_check_thread() + + def start_update_check_thread(self): + self.check_thread = UpdateCheckThread() + self.check_thread.download_requested.connect(self.on_server_update_requested) + self.check_thread.no_update_available.connect(self.on_server_no_update) + self.check_thread.error_occurred.connect(self.on_error) + self.check_thread.start() + + def on_server_no_update(self): + self.statusBar().showMessage("No new updates found on server.", 5000) + + def on_server_update_requested(self, download_url, latest_version): + if self.pending_update_version: + cmp = self.version_compare(latest_version, self.pending_update_version) + if cmp > 0: + # Server version is newer than pending update + self.statusBar().showMessage(f"Newer version {latest_version} available on server. Removing old pending update...") + try: + shutil.rmtree(self.pending_update_path) + self.statusBar().showMessage(f"Deleted old update folder: {self.pending_update_path}") + except Exception as e: + self.statusBar().showMessage(f"Failed to delete old update folder: {e}") + + # Clear pending update info so new download proceeds + self.pending_update_version = None + self.pending_update_path = None + + # Download the new update + self.download_update(download_url, latest_version) + elif cmp == 0: + # Versions equal, no download needed + self.statusBar().showMessage(f"Pending update version {self.pending_update_version} is already latest. No download needed.") + else: + # Server version older than pending? Unlikely but just keep pending update + self.statusBar().showMessage(f"Pending update version {self.pending_update_version} is newer than server version. No action.") + else: + # No pending update, just download + self.download_update(download_url, latest_version) + + def download_update(self, download_url, latest_version): + self.statusBar().showMessage("Downloading update...") + self.download_thread = UpdateDownloadThread(download_url, latest_version) + self.download_thread.update_ready.connect(self.on_update_ready) + self.download_thread.error_occurred.connect(self.on_error) + self.download_thread.start() + + def on_update_ready(self, latest_version, extract_folder): + self.statusBar().showMessage("Update downloaded and extracted.") + + msg_box = QMessageBox(self) + msg_box.setWindowTitle("Update Ready") + msg_box.setText(f"Version {latest_version} has been downloaded and extracted to:\n{extract_folder}\nWould you like to install it now?") + install_now_button = msg_box.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole) + install_later_button = msg_box.addButton("Install Later", QMessageBox.ButtonRole.RejectRole) + + msg_box.exec() + + if msg_box.clickedButton() == install_now_button: + self.install_update(extract_folder) + else: + self.statusBar().showMessage("Update ready. Install later.") + + + def install_update(self, extract_folder): + # Path to updater executable + + if PLATFORM_NAME == 'windows': + updater_path = os.path.join(os.getcwd(), "sparks_updater.exe") + elif PLATFORM_NAME == 'darwin': + if getattr(sys, 'frozen', False): + updater_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks_updater.app") + else: + updater_path = os.path.join(os.getcwd(), "../sparks_updater.app") + + elif PLATFORM_NAME == 'linux': + updater_path = os.path.join(os.getcwd(), "sparks_updater") + else: + updater_path = os.getcwd() + + if not os.path.exists(updater_path): + QMessageBox.critical(self, "Error", f"Updater not found at:\n{updater_path}. The absolute path was {os.path.abspath(updater_path)}") + return + + # Launch updater with extracted folder path as argument + try: + # Pass current app's executable path for updater to relaunch + main_app_executable = os.path.abspath(sys.argv[0]) + + print(f'Launching updater with: "{updater_path}" "{extract_folder}" "{main_app_executable}"') + + if PLATFORM_NAME == 'darwin': + subprocess.Popen(['open', updater_path, '--args', extract_folder, main_app_executable]) + else: + subprocess.Popen([updater_path, f'{extract_folder}', f'{main_app_executable}'], cwd=os.path.dirname(updater_path)) + + # Close the current app so updater can replace files + sys.exit(0) + + except Exception as e: + QMessageBox.critical(self, "Error", f"[Updater Launch Failed]\n{str(e)}\n{traceback.format_exc()}") + + def on_error(self, message): + # print(f"Error: {message}") + self.statusBar().showMessage(f"Error occurred during update process. {message}") + + def version_compare(self, v1, v2): + def normalize(v): return [int(x) for x in v.split(".")] + return (normalize(v1) > normalize(v2)) - (normalize(v1) < normalize(v2)) + + + + def closeEvent(self, event): + # Gracefully shut down multiprocessing children + print("Window is closing. Cleaning up...") + + if hasattr(self, 'manager'): + self.manager.shutdown() + + for child in self.findChildren(QWidget): + if child is not self and child.isVisible(): + child.close() + + kill_child_processes() + + event.accept() + + +def wait_for_process_to_exit(process_name, timeout=10): + """ + Waits for a process with the specified name to exit within a timeout period. + + Args: + process_name (str): Name (or part of the name) of the process to wait for. + timeout (int, optional): Maximum time to wait in seconds. Defaults to 10. + + Returns: + bool: True if the process exited before the timeout, False otherwise. + """ + + print(f"Waiting for {process_name} to exit...") + deadline = time.time() + timeout + while time.time() < deadline: + still_running = False + for proc in psutil.process_iter(['name']): + try: + if proc.info['name'] and process_name.lower() in proc.info['name'].lower(): + still_running = True + print(f"Still running: {proc.info['name']} (PID: {proc.pid})") + break + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + if not still_running: + print(f"{process_name} has exited.") + return True + time.sleep(0.5) + print(f"{process_name} did not exit in time.") + return False + + +def finish_update_if_needed(): + """ + Completes a pending application update if '--finish-update' is present in the command-line arguments. + """ + + if "--finish-update" in sys.argv: + print("Finishing update...") + + if PLATFORM_NAME == 'darwin': + app_dir = '/tmp/sparkstempupdate' + else: + app_dir = os.getcwd() + + # 1. Find update folder + update_folder = None + for entry in os.listdir(app_dir): + entry_path = os.path.join(app_dir, entry) + if os.path.isdir(entry_path) and entry.startswith("sparks-") and entry.endswith("-" + PLATFORM_NAME): + update_folder = os.path.join(app_dir, entry) + break + + if update_folder is None: + print("No update folder found. Skipping update steps.") + return + + if PLATFORM_NAME == 'darwin': + update_folder = os.path.join(update_folder, "sparks-darwin") + + # 2. Wait for sparks_updater to exit + print("Waiting for sparks_updater to exit...") + for proc in psutil.process_iter(['pid', 'name']): + if proc.info['name'] and "sparks_updater" in proc.info['name'].lower(): + try: + proc.wait(timeout=5) + except psutil.TimeoutExpired: + print("Force killing lingering sparks_updater") + proc.kill() + + # 3. Replace the updater + if PLATFORM_NAME == 'windows': + new_updater = os.path.join(update_folder, "sparks_updater.exe") + dest_updater = os.path.join(app_dir, "sparks_updater.exe") + + elif PLATFORM_NAME == 'darwin': + new_updater = os.path.join(update_folder, "sparks_updater.app") + dest_updater = os.path.abspath(os.path.join(sys.executable, "../../../../sparks_updater.app")) + + elif PLATFORM_NAME == 'linux': + new_updater = os.path.join(update_folder, "sparks_updater") + dest_updater = os.path.join(app_dir, "sparks_updater") + + else: + print("No platform??") + new_updater = os.getcwd() + dest_updater = os.getcwd() + + print(f"New updater is {new_updater}") + print(f"Dest updater is {dest_updater}") + + print("Writable?", os.access(dest_updater, os.W_OK)) + print("Executable path:", sys.executable) + print("Trying to copy:", new_updater, "->", dest_updater) + + if os.path.exists(new_updater): + try: + if os.path.exists(dest_updater): + if PLATFORM_NAME == 'darwin': + try: + if os.path.isdir(dest_updater): + shutil.rmtree(dest_updater) + print(f"Deleted directory: {dest_updater}") + else: + os.remove(dest_updater) + print(f"Deleted file: {dest_updater}") + except Exception as e: + print(f"Error deleting {dest_updater}: {e}") + else: + os.remove(dest_updater) + + if PLATFORM_NAME == 'darwin': + wait_for_process_to_exit("sparks_updater", timeout=10) + subprocess.check_call(["ditto", new_updater, dest_updater]) + else: + shutil.copy2(new_updater, dest_updater) + + if PLATFORM_NAME in ('linux', 'darwin'): + os.chmod(dest_updater, 0o755) + + if PLATFORM_NAME == 'darwin': + remove_quarantine(dest_updater) + + print("sparks_updater replaced.") + except Exception as e: + print(f"Failed to replace sparks_updater: {e}") + + # 4. Delete the update folder + try: + if PLATFORM_NAME == 'darwin': + shutil.rmtree(app_dir) + else: + shutil.rmtree(update_folder) + except Exception as e: + print(f"Failed to delete update folder: {e}") + + QMessageBox.information(None, "Update Complete", "The application has been successfully updated.") + sys.argv.remove("--finish-update") + + +def remove_quarantine(app_path): + """ + Removes the macOS quarantine attribute from the specified application path. + """ + + script = f''' + do shell script "xattr -d -r com.apple.quarantine {shlex.quote(app_path)}" with administrator privileges with prompt "SPARKS needs privileges to finish the update. (2/2)" + ''' + try: + subprocess.run(['osascript', '-e', script], check=True) + print("✅ Quarantine attribute removed.") + except subprocess.CalledProcessError as e: + print("❌ Failed to remove quarantine attribute.") + print(e) + + +def resource_path(relative_path): + """ + Get absolute path to resource regardless of running directly or packaged using PyInstaller + """ + + if hasattr(sys, '_MEIPASS'): + # PyInstaller bundle path + base_path = sys._MEIPASS + else: + base_path = os.path.dirname(os.path.abspath(__file__)) + + return os.path.join(base_path, relative_path) + + +def kill_child_processes(): + """ + Goodbye children + """ + + try: + parent = psutil.Process(os.getpid()) + children = parent.children(recursive=True) + for child in children: + try: + child.kill() + except psutil.NoSuchProcess: + pass + psutil.wait_procs(children, timeout=5) + except Exception as e: + print(f"Error killing child processes: {e}") + + +def exception_hook(exc_type, exc_value, exc_traceback): + """ + Method that will display a popup when the program hard crashes containg what went wrong + """ + + error_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) + print(error_msg) # also print to console + + kill_child_processes() + + # Show error message box + # Make sure QApplication exists (or create a minimal one) + app = QApplication.instance() + if app is None: + app = QApplication(sys.argv) + + show_critical_error(error_msg) + + # Exit the app after user acknowledges + sys.exit(1) + +def show_critical_error(error_msg): + msg_box = QMessageBox() + msg_box.setIcon(QMessageBox.Icon.Critical) + msg_box.setWindowTitle("Something went wrong!") + + if PLATFORM_NAME == "darwin": + log_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks.log") + log_path2 = os.path.join(os.path.dirname(sys.executable), "../../../sparks_error.log") + save_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks_autosave.spark") + + else: + log_path = os.path.join(os.getcwd(), "sparks.log") + log_path2 = os.path.join(os.getcwd(), "sparks_error.log") + save_path = os.path.join(os.getcwd(), "sparks_autosave.spark") + + + shutil.copy(log_path, log_path2) + log_path2 = Path(log_path2).absolute().as_posix() + autosave_path = Path(save_path).absolute().as_posix() + log_link = f"file:///{log_path2}" + autosave_link = f"file:///{autosave_path}" + + window.save_project(True) + + message = ( + "SPARKS has encountered an unrecoverable error and needs to close.

" + f"We are sorry for the inconvenience. An autosave was attempted to be saved to {autosave_path}, but it may not have been saved. " + "If the file was saved, it still may not be intact, openable, or contain the correct data. Use the autosave at your discretion.

" + "This unrecoverable error was likely due to an error with SPARKS and not your data.
" + f"Please raise an issue here and attach the error file located at {log_path2}

" + f"
{error_msg}
" + ) + + msg_box.setTextFormat(Qt.TextFormat.RichText) + msg_box.setText(message) + msg_box.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction) + msg_box.setStandardButtons(QMessageBox.StandardButton.Ok) + + msg_box.exec() + +if __name__ == "__main__": + # Redirect exceptions to the popup window + sys.excepthook = exception_hook + + # Set up application logging + if PLATFORM_NAME == "darwin": + log_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks.log") + else: + log_path = os.path.join(os.getcwd(), "sparks.log") + + try: + os.remove(log_path) + except: + pass + + sys.stdout = open(log_path, "a", buffering=1) + sys.stderr = sys.stdout + print(f"\n=== App started at {datetime.now()} ===\n") + + freeze_support() # Required for PyInstaller + multiprocessing + + # Only run GUI in the main process + if current_process().name == 'MainProcess': + app = QApplication(sys.argv) + finish_update_if_needed() + window = MainApplication() + + if PLATFORM_NAME == "darwin": + app.setWindowIcon(QIcon(resource_path("icons/main.icns"))) + window.setWindowIcon(QIcon(resource_path("icons/main.icns"))) + else: + app.setWindowIcon(QIcon(resource_path("icons/main.ico"))) + window.setWindowIcon(QIcon(resource_path("icons/main.ico"))) + window.show() + sys.exit(app.exec()) + +# Not 4000 lines yay! \ No newline at end of file diff --git a/sparks_updater.py b/sparks_updater.py new file mode 100644 index 0000000..d88f83b --- /dev/null +++ b/sparks_updater.py @@ -0,0 +1,255 @@ +""" +Filename: sparks_updater.py +Description: SPARKS updater executable + +Author: Tyler de Zeeuw +License: GPL-3.0 +""" + +# Built-in imports +import os +import sys +import time +import shlex +import psutil +import shutil +import platform +import subprocess +from datetime import datetime + +PLATFORM_NAME = platform.system().lower() + +if PLATFORM_NAME == 'darwin': + LOG_FILE = os.path.join(os.path.dirname(sys.executable), "../../../sparks_updater.log") +else: + LOG_FILE = os.path.join(os.getcwd(), "sparks_updater.log") + + +def log(msg): + with open(LOG_FILE, "a", encoding="utf-8") as f: + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + f.write(f"{timestamp} - {msg}\n") + + +def kill_all_processes_by_executable(exe_path): + terminated_any = False + exe_path = os.path.realpath(exe_path) + + if PLATFORM_NAME == 'windows': + for proc in psutil.process_iter(['pid', 'exe']): + try: + proc_exe = proc.info.get('exe') + if proc_exe and os.path.samefile(os.path.realpath(proc_exe), exe_path): + log(f"Terminating process: PID {proc.pid}") + _terminate_process(proc) + terminated_any = True + except Exception as e: + log(f"Error terminating process (Windows): {e}") + elif PLATFORM_NAME == 'linux': + for proc in psutil.process_iter(['pid', 'cmdline']): + try: + cmdline = proc.info.get('cmdline', []) + if cmdline: + proc_cmd = os.path.realpath(cmdline[0]) + if os.path.samefile(proc_cmd, exe_path): + log(f"Terminating process: PID {proc.pid}") + _terminate_process(proc) + terminated_any = True + except Exception as e: + log(f"Error terminating process (Linux): {e}") + + if not terminated_any: + log(f"No running processes found for {exe_path}") + return terminated_any + + +def _terminate_process(proc): + try: + proc.terminate() + proc.wait(timeout=10) + log(f"Process {proc.pid} terminated gracefully.") + except psutil.TimeoutExpired: + log(f"Process {proc.pid} did not terminate in time. Killing forcefully.") + proc.kill() + proc.wait(timeout=5) + log(f"Process {proc.pid} killed.") + + +def wait_for_unlock(path, timeout=100): + start_time = time.time() + while time.time() - start_time < timeout: + try: + if os.path.isdir(path): + shutil.rmtree(path) + else: + os.remove(path) + log(f"Deleted (after wait): {path}") + return + except Exception as e: + log(f"Still locked: {path} - {e}") + time.sleep(1) + log(f"Failed to delete after wait: {path}") + + +def delete_path(path): + if os.path.exists(path): + try: + if os.path.isdir(path): + shutil.rmtree(path) + log(f"Deleted directory: {path}") + else: + os.remove(path) + log(f"Deleted file: {path}") + except Exception as e: + log(f"Error deleting {path}: {e}") + + +def copy_update_files(src_folder, dest_folder, updater_name): + for item in os.listdir(src_folder): + if item.lower() == updater_name.lower(): + log(f"Skipping updater executable: {item}") + continue + s = os.path.join(src_folder, item) + d = os.path.join(dest_folder, item) + delete_path(d) + try: + if os.path.isdir(s): + shutil.copytree(s, d) + log(f"Copied folder: {s} -> {d}") + else: + shutil.copy2(s, d) + log(f"Copied file: {s} -> {d}") + except Exception as e: + log(f"Error copying {s} -> {d}: {e}") + + +def copy_update_files_darwin(src_folder, dest_folder, updater_name): + + updater_name = updater_name + ".app" + + for item in os.listdir(src_folder): + if item.lower() == updater_name.lower(): + log(f"Skipping updater executable: {item}") + continue + s = os.path.join(src_folder, item) + d = os.path.join(dest_folder, item) + delete_path(d) + try: + if os.path.isdir(s): + subprocess.check_call(["ditto", s, d]) + log(f"Copied folder with ditto: {s} -> {d}") + else: + shutil.copy2(s, d) + log(f"Copied file: {s} -> {d}") + except Exception as e: + log(f"Error copying {s} -> {d}: {e}") + + +def remove_quarantine(app_path): + script = f''' + do shell script "xattr -d -r com.apple.quarantine {shlex.quote(app_path)}" with administrator privileges with prompt "SPARKS needs privileges to finish the update. (1/2)" + ''' + try: + subprocess.run(['osascript', '-e', script], check=True) + print("✅ Quarantine attribute removed.") + except subprocess.CalledProcessError as e: + print("❌ Failed to remove quarantine attribute.") + print(e) + + +def main(): + try: + log(f"[Updater] sys.argv: {sys.argv}") + + if len(sys.argv) != 3: + log("Invalid arguments. Usage: sparks_updater ") + sys.exit(1) + + update_folder = sys.argv[1] + main_exe = sys.argv[2] + + # Interesting naming convention + parent_dir = os.path.dirname(os.path.abspath(main_exe)) + pparent_dir = os.path.dirname(parent_dir) + ppparent_dir = os.path.dirname(pparent_dir) + pppparent_dir = os.path.dirname(ppparent_dir) + + updater_name = os.path.basename(sys.argv[0]) + + log("Updater started.") + log(f"Update folder: {update_folder}") + log(f"Main EXE: {main_exe}") + log(f"Updater EXE: {updater_name}") + if PLATFORM_NAME == 'darwin': + log(f"Main App Folder: {ppparent_dir}") + + # Kill all instances of main app + kill_all_processes_by_executable(main_exe) + + # Wait until main_exe process is fully gone (polling) + for _ in range(20): # wait max 10 seconds + running = False + for proc in psutil.process_iter(['exe', 'cmdline']): + try: + if PLATFORM_NAME == 'windows': + proc_exe = proc.info.get('exe') + if proc_exe and os.path.samefile(os.path.realpath(proc_exe), os.path.realpath(main_exe)): + running = True + break + elif PLATFORM_NAME == 'linux': + cmdline = proc.info.get('cmdline', []) + if cmdline: + proc_cmd = os.path.realpath(cmdline[0]) + if os.path.samefile(proc_cmd, os.path.realpath(main_exe)): + running = True + break + except Exception as e: + log(f"Polling error: {e}") + if not running: + break + time.sleep(0.5) + else: + log("Warning: main executable still running after wait timeout.") + + # Delete old version files + if PLATFORM_NAME == 'darwin': + log(f'Attempting to delete {ppparent_dir}') + delete_path(ppparent_dir) + update_folder = os.path.join(sys.argv[1], "sparks-darwin") + copy_update_files_darwin(update_folder, pppparent_dir, updater_name) + + else: + delete_path(main_exe) + wait_for_unlock(os.path.join(parent_dir, "_internal")) + + # Copy new files excluding the updater itself + copy_update_files(update_folder, parent_dir, updater_name) + + except Exception as e: + log(f"Something went wrong: {e}") + + # Relaunch main app + try: + if PLATFORM_NAME == 'linux': + os.chmod(main_exe, 0o755) + log("Added executable bit") + + if PLATFORM_NAME == 'darwin': + os.chmod(ppparent_dir, 0o755) + log("Added executable bit") + remove_quarantine(ppparent_dir) + log(f"Removed the quarantine flag on {ppparent_dir}") + subprocess.Popen(['open', ppparent_dir, "--args", "--finish-update"]) + else: + subprocess.Popen([main_exe, "--finish-update"], cwd=parent_dir) + + log("Relaunched main app.") + except Exception as e: + log(f"Failed to relaunch main app: {e}") + + log("Updater completed. Exiting.") + sys.exit(0) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/version_main.txt b/version_main.txt new file mode 100644 index 0000000..73c6ead --- /dev/null +++ b/version_main.txt @@ -0,0 +1,29 @@ +VSVersionInfo( + ffi=FixedFileInfo( + filevers=(1, 0, 0, 0), + prodvers=(1, 0, 0, 0), + mask=0x3f, + flags=0x0, + OS=0x4, + fileType=0x1, + subtype=0x0, + date=(0, 0) + ), + kids=[ + StringFileInfo( + [ + StringTable( + '040904B0', + [StringStruct('CompanyName', 'Tyler de Zeeuw'), + StringStruct('FileDescription', 'SPARKS main application'), + StringStruct('FileVersion', '1.0.0.0'), + StringStruct('InternalName', 'sparks.exe'), + StringStruct('LegalCopyright', '© 2025 Tyler de Zeeuw'), + StringStruct('OriginalFilename', 'sparks.exe'), + StringStruct('ProductName', 'SPARKS'), + StringStruct('ProductVersion', '1.0.0.0')]) + ] + ), + VarFileInfo([VarStruct('Translation', [1033, 1200])]) + ] +) diff --git a/version_updater.txt b/version_updater.txt new file mode 100644 index 0000000..1b75fe4 --- /dev/null +++ b/version_updater.txt @@ -0,0 +1,29 @@ +VSVersionInfo( + ffi=FixedFileInfo( + filevers=(1, 0, 0, 0), + prodvers=(1, 0, 0, 0), + mask=0x3f, + flags=0x0, + OS=0x4, + fileType=0x1, + subtype=0x0, + date=(0, 0) + ), + kids=[ + StringFileInfo( + [ + StringTable( + '040904B0', + [StringStruct('CompanyName', 'Tyler de Zeeuw'), + StringStruct('FileDescription', 'SPARKS updater application'), + StringStruct('FileVersion', '1.0.0.0'), + StringStruct('InternalName', 'main.exe'), + StringStruct('LegalCopyright', '© 2025 Tyler de Zeeuw'), + StringStruct('OriginalFilename', 'sparks_updater.exe'), + StringStruct('ProductName', 'SPARKS Updater'), + StringStruct('ProductVersion', '1.0.0.0')]) + ] + ), + VarFileInfo([VarStruct('Translation', [1033, 1200])]) + ] +)