"""
Filename: main.py
Description: SPARKS main executable

Author: Tyler de Zeeuw
License: GPL-3.0
"""

# Built-in imports
import os
import re
import csv
import sys
import math
import json
import time
import shlex
import pickle
import shutil
import zipfile
import platform
import traceback
import subprocess
from itertools import product
from pathlib import Path, PurePosixPath
from datetime import datetime
from multiprocessing import Process, current_process, freeze_support, Manager

# External library imports
import cv2
import numpy as np
import pandas as pd
import psutil
import requests
import torch
import torch.nn as nn
import mediapipe as mp
from torch.utils.data import TensorDataset, DataLoader
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt

from PySide6.QtWidgets import (
    QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox,
    QGridLayout, QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy,
    QGroupBox, QDialog, QListView, QMenu, QProgressBar, QCheckBox
)
from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize, QPoint
from PySide6.QtGui import (
    QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel,
    QStandardItem, QImage
)
from PySide6.QtSvgWidgets import QSvgWidget  # needed to show svgs when app is not frozen


CURRENT_VERSION = "1.0.0"

API_URL = "https://git.research.dezeeuw.ca/api/v1/repos/tyler/sparks/releases"
API_URL_SECONDARY = "https://git.research2.dezeeuw.ca/api/v1/repos/tyler/sparks/releases"

PLATFORM_NAME = platform.system().lower()

# Selectable parameters on the right side of the window
SECTIONS = [
    {
        "title": "Parameter 1",
        "params": [
            {"name": "Number 1", "default": True, "type": bool, "help": "N/A"},
            {"name": "Number 2", "default": True, "type": bool, "help": "N/A"},
        ]
    },
]


class _WindowLSTM(nn.Module):
    """LSTM classifier over a window of per-frame landmark features.

    The last layer's final hidden state (both directions concatenated when
    bidirectional) is projected to ``output_size`` logits. Shared by training
    and both test threads so saved ``state_dict`` keys always line up.

    Args:
        input_size (int): Number of features per frame.
        hidden_size (int): Number of LSTM hidden units.
        bidirectional (bool): Whether the LSTM is bidirectional.
        output_size (int): Number of output logits.
    """

    def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=bidirectional)
        self.fc = nn.Linear(hidden_size * (2 if bidirectional else 1), output_size)

    def forward(self, x):
        # x: (batch, frames, features) because batch_first=True
        _, (h, _) = self.lstm(x)
        if self.lstm.bidirectional:
            h = torch.cat([h[-2], h[-1]], dim=1)
        else:
            h = h[-1]
        return self.fc(h)


def _fill_nan_gaps(window_buffer):
    """Interpolate NaNs column-by-column over a window of feature rows.

    ``np.interp`` clamps at the ends, so leading/trailing NaNs take the nearest
    valid value; columns with no valid value at all are left as NaN.

    Args:
        window_buffer: Sequence of equal-length per-frame feature rows.

    Returns:
        list[list[float]]: The interpolated window (float32 precision) as a list of lists.
    """
    arr = np.array(window_buffer, dtype=np.float32)
    for c in range(arr.shape[1]):
        valid = ~np.isnan(arr[:, c])
        if valid.any():
            arr[:, c][~valid] = np.interp(
                np.flatnonzero(~valid),
                np.flatnonzero(valid),
                arr[:, c][valid],
            )
    return arr.tolist()


class TrainModelThread(QThread):
    """Train the window-level reach classifier from labelled landmark CSVs.

    Each CSV must contain feature columns whose names start with ``lm`` (taken
    from the first file) and the label columns ``reach_active`` and
    ``reach_before_contact``. The best model by F1 is written to
    ``model_save_path``.

    Signals:
        update(str): Progress / log messages for the GUI.
        finished(str): Final summary message.
    """
    update = Signal(str)  # emits messages for the GUI
    finished = Signal(str)

    def __init__(self, csv_paths, model_save_path, parent=None):
        super().__init__(parent)
        self.csv_paths = csv_paths
        self.model_save_path = model_save_path

    def run(self):
        FPS, WINDOW_SEC, STEP_SEC, THRESHOLD, EPOCHS = 60, 0.4, 0.2, 0.5, 150
        window_length = int(FPS * WINDOW_SEC)
        step_length = int(FPS * STEP_SEC)

        # Load CSVs and window each file separately so that no training window
        # straddles the boundary between two different recordings.
        X_windows, y_windows, feature_cols = [], [], None
        for path in self.csv_paths:
            df = pd.read_csv(path)
            if feature_cols is None:
                feature_cols = [c for c in df.columns if c.startswith("lm")]
            df[feature_cols] = df[feature_cols].ffill().bfill()
            X = df[feature_cols].values
            y = df[["reach_active", "reach_before_contact"]].values

            for start in range(0, len(X) - window_length + 1, step_length):
                end = start + window_length
                y_win = y[start:end]
                # A label is positive for the window when >= 30% of its frames are positive
                label_window = (y_win.sum(axis=0) / len(y_win) >= 0.3).astype(int)  # shape: (2,)
                X_windows.append(X[start:end])
                y_windows.append(label_window)

        if not X_windows:
            self.finished.emit(f"Training aborted: no CSV contains at least {window_length} frames.")
            return

        X_windows = np.array(X_windows)
        y_windows = np.array(y_windows)

        # Per-label pos_weight = ratio of 'balanced' class weights (negatives / positives).
        # compute_class_weight rejects a label that only has one class, so fall back to 1.0.
        class_weights = []
        for i in range(2):
            column = y_windows[:, i]
            if np.unique(column).size == 2:
                cw = compute_class_weight('balanced', classes=np.array([0, 1]), y=column)
                class_weights.append(cw[1] / cw[0])
            else:
                class_weights.append(1.0)
        pos_weight = torch.tensor(class_weights, dtype=torch.float32)

        X_tensor = torch.tensor(X_windows, dtype=torch.float32)  # (N, window_length, features)
        y_tensor = torch.tensor(y_windows, dtype=torch.float32)  # (N, 2)
        y_true = y_tensor.int()
        dataset = TensorDataset(X_tensor, y_tensor)

        # Hyperparameter sweep
        HIDDEN_SIZES = [64]
        BATCH_SIZES = [32]
        LRs = [0.0005]
        OPTIMIZERS = ['adam']
        BIDIRECTIONAL = [False]

        # Start below any achievable F1 so at least one model is always saved.
        best_f1 = -1.0
        best_config = None

        for hidden_size, batch_size, lr, opt_name, bi in product(HIDDEN_SIZES, BATCH_SIZES, LRs, OPTIMIZERS, BIDIRECTIONAL):
            self.update.emit(f"Training model: hidden={hidden_size}, batch={batch_size}, lr={lr}, opt={opt_name}, bidir={bi}")
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
            model = _WindowLSTM(input_size=X_windows.shape[2], hidden_size=hidden_size, bidirectional=bi)
            if opt_name == 'adam':
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                optimizer = torch.optim.SGD(model.parameters(), lr=lr)
            criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

            # Training
            for epoch in range(EPOCHS):
                model.train()
                epoch_loss = 0
                for xb, yb in loader:
                    optimizer.zero_grad()
                    loss = criterion(model(xb), yb)
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()

                # Evaluation (on the training windows themselves; there is no held-out split)
                model.eval()
                with torch.no_grad():
                    probs = torch.sigmoid(model(X_tensor))
                preds = (probs > THRESHOLD).int()
                tp = ((preds == 1) & (y_true == 1)).sum().item()
                fp = ((preds == 1) & (y_true == 0)).sum().item()
                fn = ((preds == 0) & (y_true == 1)).sum().item()
                precision = tp / (tp + fp + 1e-8)
                recall = tp / (tp + fn + 1e-8)
                f1 = 2 * precision * recall / (precision + recall + 1e-8)

                self.update.emit(f"Epoch {epoch+1}/{EPOCHS} Loss={epoch_loss:.4f} "
                                 f"Precision={precision:.3f} Recall={recall:.3f} F1={f1:.3f}")

                if f1 > best_f1:
                    best_f1 = f1
                    best_config = (hidden_size, batch_size, lr, opt_name, bi)
                    torch.save(model.state_dict(), self.model_save_path)
                    self.update.emit(f"New best model saved to {os.path.basename(self.model_save_path)}!")

        self.finished.emit(f"Training complete!\nBest config: {best_config}\nF1={best_f1:.3f}")


class TestModelThread(QThread):
    """Play the first video in ``video_paths`` with live model predictions overlaid.

    The right hand is tracked with MediaPipe each frame; once ``window_frames``
    frames are buffered the model runs on the (gap-interpolated) window and the
    two predicted labels are drawn on the frame. Press 'q' to stop.
    """
    update_frame = Signal(np.ndarray)  # emit frames with overlay
    finished = Signal()

    def __init__(self, video_paths, model_path, fps=60, window_sec=0.4, step_sec=0.2, threshold=0.5):
        super().__init__()
        self.video_paths = video_paths
        self.fps = fps
        self.window_sec = window_sec
        self.step_sec = step_sec
        self.threshold = threshold
        self.model_path = model_path
        self.window_frames = int(fps * window_sec)

    @staticmethod
    def _right_hand_features(frame, hands, mp_hands, mp_drawing, mp_drawing_styles):
        """Detect the right hand in BGR ``frame``, draw it in place, and return its features.

        Returns:
            tuple[list[float], bool]: x, y, z of every landmark of the first "Right"
            hand (or 63 NaNs when none is found), and whether one was found.
        """
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = hands.process(rgb)

        if results.multi_hand_landmarks and results.multi_handedness:
            for i, handedness in enumerate(results.multi_handedness):
                if handedness.classification[0].label != "Right":
                    continue
                lm = results.multi_hand_landmarks[i]
                features = []
                for p in lm.landmark:
                    features += [p.x, p.y, p.z]
                # draw hand with SAME STYLES
                mp_drawing.draw_landmarks(
                    frame,
                    lm,
                    mp_hands.HAND_CONNECTIONS,
                    mp_drawing_styles.get_default_hand_landmarks_style(),
                    mp_drawing_styles.get_default_hand_connections_style(),
                )
                return features, True
        return [math.nan] * 63, False

    def _draw_overlay(self, frame, frame_index, buffered, probs, right_found):
        """Draw prediction banners and the info panel onto ``frame`` in place.

        Args:
            frame: BGR image to draw on.
            frame_index (int): Index of the current frame.
            buffered (int): Number of frames currently in the window buffer.
            probs: Sigmoid outputs ``[active, before_contact]``, or None while the window fills.
            right_found (bool): Whether a right hand was detected in this frame.
        """
        h, w = frame.shape[:2]
        pred_active = 0
        pred_before_contact = 0

        if probs is not None:
            pred_active = int(probs[0] > self.threshold)
            pred_before_contact = int(probs[1] > self.threshold)
            color_active = (0, 255, 0) if pred_active else (0, 0, 255)
            color_before_contact = (255, 255, 0) if pred_before_contact else (0, 0, 128)

            cv2.rectangle(frame, (0, 0), (w // 2, 40), color_active, -1)
            cv2.putText(frame, f"Reach Active: {'YES' if pred_active else 'NO'} ({probs[0]:.2f})", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
            cv2.rectangle(frame, (w // 2, 0), (w, 40), color_before_contact, -1)
            cv2.putText(frame, f"Reach Before Contact: {'YES' if pred_before_contact else 'NO'} ({probs[1]:.2f})",
                        (w // 2 + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
        else:
            cv2.rectangle(frame, (0, 0), (w, 40), (0, 128, 255), -1)
            cv2.putText(
                frame,
                f"Collecting window: {buffered}/{self.window_frames}",
                (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.0,
                (255, 255, 255),
                2,
            )

        info_text = [
            f"Frame: {frame_index}",
            f"Time: {frame_index / self.fps:.2f}s",
            f"Reach Active: {'YES' if pred_active else 'NO'}",
            f"Reach Before Contact: {'YES' if pred_before_contact else 'NO'}",
            f"Right Hand: {'DETECTED' if right_found else 'NOT DETECTED'}",
            "Press 'q' to quit",
        ]
        for i, text in enumerate(info_text):
            cv2.putText(frame, text, (10, 60 + i * 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        # If predicted active, add highlight frame
        if pred_active or pred_before_contact:
            cv2.rectangle(frame, (0, 0), (w, h), (0, 255, 0), 6)

    def run(self):
        # MATCH THE MODEL CONFIG YOU TRAINED WITH
        model = _WindowLSTM(input_size=63, hidden_size=64, bidirectional=False)
        model.load_state_dict(torch.load(self.model_path, map_location="cpu"))
        model.eval()
        print("Loaded model.")

        # ------------------------------ MediaPipe ------------------------------
        mp_hands = mp.solutions.hands
        mp_drawing = mp.solutions.drawing_utils
        mp_drawing_styles = mp.solutions.drawing_styles
        hands = mp_hands.Hands(
            static_image_mode=False,
            max_num_hands=2,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
            model_complexity=0
        )

        # ------------------------------ Open video ------------------------------
        cap = cv2.VideoCapture(self.video_paths[0])
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            if not cap.isOpened() or original_width <= 0 or original_height <= 0:
                # Avoids the divide-by-zero below for unreadable videos
                print(f"Could not open video: {self.video_paths[0]}")
                return
            print(f"Original video dimensions: {original_width} x {original_height}")
            print(f"Total frames in video: {total_frames}")

            # One display width for both the window and the resized frames, so the
            # aspect ratio computed here is the one actually shown.
            display_width = 800
            aspect_ratio = original_height / original_width
            display_height = int(display_width * aspect_ratio)
            print(f"Display dimensions: {display_width} x {display_height}")

            window_name = "Hand Tracking Test - Press 'q' to quit"
            cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
            cv2.resizeWindow(window_name, display_width, display_height)

            window_buffer = []  # last `window_frames` frames of landmarks
            frame_index = 0
            try:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break

                    features, right_found = self._right_hand_features(
                        frame, hands, mp_hands, mp_drawing, mp_drawing_styles
                    )

                    # Maintain the window buffer: fill NaN gaps, then keep only the last window
                    window_buffer.append(features)
                    window_buffer = _fill_nan_gaps(window_buffer)
                    if len(window_buffer) > self.window_frames:
                        window_buffer.pop(0)

                    # Run inference only when a full window is available
                    probs = None
                    if len(window_buffer) == self.window_frames:
                        x = torch.tensor(window_buffer, dtype=torch.float32).unsqueeze(0)
                        with torch.no_grad():
                            probs = torch.sigmoid(model(x)).squeeze(0).numpy()

                    self._draw_overlay(frame, frame_index, len(window_buffer), probs, right_found)

                    resized_frame = cv2.resize(frame, (display_width, display_height), interpolation=cv2.INTER_AREA)
                    cv2.imshow(window_name, resized_frame)
                    frame_index += 1

                    if cv2.waitKey(1) & 0xFF == ord("q"):
                        break
            finally:
                cv2.destroyWindow(window_name)
                cv2.waitKey(1)  # give HighGUI a chance to process the destroy event
        finally:
            cap.release()
            hands.close()
            self.finished.emit()


class TestModelThread2(QThread):
    """Run the model over a landmark CSV offline and emit per-frame probabilities.

    Emits ``finished`` with a dict ``{"time", "active", "before", "threshold"}``
    on success, or ``{"error": message}`` on failure. The first
    ``window_frames - 1`` entries of ``active``/``before`` are 0 padding.
    """
    finished = Signal(object)

    def __init__(self, csv_path, model_path, fps=60, window_sec=0.4, threshold=0.5):
        super().__init__()
        self.csv_path = csv_path
        self.model_path = model_path
        self.fps = fps
        self.window_sec = window_sec
        self.window_frames = int(fps * window_sec)
        self.threshold = threshold

    def run(self):
        try:
            self._run_internal()
        except Exception as e:
            traceback.print_exc()
            self.error = str(e)
            self.finished.emit({"error": str(e)})

    def _run_internal(self):
        # If your training params varied, you must update these to match the saved model
        model = _WindowLSTM(input_size=63, hidden_size=64, bidirectional=False)
        model.load_state_dict(torch.load(self.model_path, map_location="cpu"))
        model.eval()

        df = pd.read_csv(self.csv_path)
        feature_cols = [c for c in df.columns if c.startswith("lm")]
        X_with_nans = df[feature_cols].values
        if "time_sec" in df.columns:
            times = df["time_sec"].values
        else:
            times = np.arange(len(X_with_nans)) / self.fps

        # Sliding-window inference, mirroring the buffering of TestModelThread
        window_buffer = []
        preds_active = []
        preds_before_contact = []
        for i, features in enumerate(X_with_nans):
            # Pad output until the first full window exists
            if i < self.window_frames - 1:
                preds_active.append(0)
                preds_before_contact.append(0)

            window_buffer.append(features)
            if len(window_buffer) > self.window_frames:
                window_buffer.pop(0)
            window_buffer = _fill_nan_gaps(window_buffer)

            if len(window_buffer) == self.window_frames:
                x = torch.tensor(window_buffer, dtype=torch.float32).unsqueeze(0)
                with torch.no_grad():
                    probs = torch.sigmoid(model(x)).squeeze(0).numpy()
                preds_active.append(probs[0])
                preds_before_contact.append(probs[1])

        self.finished.emit({
            "time": times,
            "active": preds_active,
            "before": preds_before_contact,
            "threshold": self.threshold,
        })


class TerminalWindow(QWidget):
    """Minimal in-app terminal: each entered line dispatches to a registered command."""

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("Terminal - SPARKS")

        self.output_area = QTextEdit()
        self.output_area.setReadOnly(True)

        self.input_line = QLineEdit()
        self.input_line.returnPressed.connect(self.handle_command)

        layout = QVBoxLayout()
        layout.addWidget(self.output_area)
        layout.addWidget(self.input_line)
        self.setLayout(layout)

        self.commands = {
            "hello": self.cmd_hello,
            "help": self.cmd_help
        }

    def handle_command(self):
        """Echo the input, run the named command with whitespace-split args, and print its result."""
        command_text = self.input_line.text()
        self.input_line.clear()
        self.output_area.append(f"> {command_text}")

        parts = command_text.strip().split()
        if not parts:
            return

        command_name, args = parts[0], parts[1:]
        func = self.commands.get(command_name)
        if not func:
            self.output_area.append(f"[Unknown command] '{command_name}'")
            return
        try:
            result = func(*args)
            if result:
                self.output_area.append(str(result))
        except Exception as e:
            self.output_area.append(f"[Error] {e}")

    def cmd_hello(self, *args):
        return "Hello from the terminal!"

    def cmd_help(self, *args):
        return f"Available commands: {', '.join(self.commands.keys())}"


class UpdateDownloadThread(QThread):
    """
    Thread that downloads and extracts an update package and emits a signal on completion or error.

    Args:
        download_url (str): URL of the update zip file to download.
        latest_version (str): Version string of the latest update.
    """

    update_ready = Signal(str, str)
    error_occurred = Signal(str)

    def __init__(self, download_url, latest_version):
        super().__init__()
        self.download_url = download_url
        self.latest_version = latest_version

    def run(self):
        try:
            local_filename = os.path.basename(self.download_url)

            # macOS stages updates under /tmp; other platforms use the working directory
            if PLATFORM_NAME == 'darwin':
                base_dir = '/tmp/sparkstempupdate'
                os.makedirs(base_dir, exist_ok=True)
            else:
                base_dir = os.getcwd()
            local_path = os.path.join(base_dir, local_filename)

            # Download the file
            with requests.get(self.download_url, stream=True, timeout=15) as r:
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)

            # Extract folder name (remove .zip) and create it if needed
            extract_path = os.path.join(base_dir, os.path.splitext(local_filename)[0])
            os.makedirs(extract_path, exist_ok=True)

            # Extract the zip file contents (ditto on macOS)
            if PLATFORM_NAME == 'darwin':
                subprocess.run(['ditto', '-xk', local_path, extract_path], check=True)
            else:
                with zipfile.ZipFile(local_path, 'r') as zip_ref:
                    zip_ref.extractall(extract_path)

            # Remove the zip once extracted and emit a signal
            os.remove(local_path)
            self.update_ready.emit(self.latest_version, extract_path)

        except Exception as e:
            # Emit a signal signifying failure
            self.error_occurred.emit(str(e))


class UpdateCheckThread(QThread):
    """
    Thread that checks for updates by querying the API and emits a signal based on the result.

    Signals:
        download_requested(str, str): Emitted with (download_url, latest_version) when an update is available.
        no_update_available(): Emitted when no update is found or current version is up to date.
        error_occurred(str): Emitted with an error message if the update check fails.
    """

    download_requested = Signal(str, str)
    no_update_available = Signal()
    error_occurred = Signal(str)

    def run(self):
        if not getattr(sys, 'frozen', False):
            self.error_occurred.emit("Application is not frozen (Development mode).")
            return
        try:
            latest_version, download_url = self.get_latest_release_for_platform()
            if not latest_version:
                self.no_update_available.emit()
                return

            if not download_url:
                self.error_occurred.emit(f"No download available for platform '{PLATFORM_NAME}'")
                return

            if self.version_compare(latest_version, CURRENT_VERSION) > 0:
                self.download_requested.emit(download_url, latest_version)
            else:
                self.no_update_available.emit()

        except Exception as e:
            self.error_occurred.emit(f"Update check failed: {e}")

    def version_compare(self, v1, v2):
        """Return 1, 0 or -1 as dotted-integer version ``v1`` is >, == or < ``v2``."""
        def normalize(v):
            return [int(x) for x in v.split(".")]
        return (normalize(v1) > normalize(v2)) - (normalize(v1) < normalize(v2))

    def get_latest_release_for_platform(self):
        """Query each release API mirror in turn for the newest release.

        Returns:
            tuple: ``(tag, download_url)`` for this platform's asset; ``(tag, None)``
            when the newest release has no matching asset; ``(None, None)`` when the
            release list is empty or every mirror fails.
        """
        for url in (API_URL, API_URL_SECONDARY):
            try:
                # Request the mirror being iterated (previously always API_URL,
                # so the secondary mirror was never actually tried).
                response = requests.get(url, timeout=5)
                response.raise_for_status()
                releases = response.json()
                if not releases:
                    return None, None

                latest = releases[0]
                tag = latest["tag_name"].lstrip("v")

                for asset in latest.get("assets", []):
                    if PLATFORM_NAME in asset["name"].lower():
                        return tag, asset["browser_download_url"]

                return tag, None
            except (requests.RequestException, ValueError):
                continue  # this mirror failed; fall through to the next one
        return None, None


class LocalPendingUpdateCheckThread(QThread):
    """
    Thread that checks for locally pending updates by scanning the download directory and emits a signal accordingly.

    Args:
        current_version (str): Current application version.
        platform_suffix (str): Platform-specific suffix to identify update folders.
    """

    pending_update_found = Signal(str, str)
    no_pending_update = Signal()

    def __init__(self, current_version, platform_suffix):
        super().__init__()
        self.current_version = current_version
        self.platform_suffix = platform_suffix

    def version_compare(self, v1, v2):
        """Return 1, 0 or -1 as dotted-integer version ``v1`` is >, == or < ``v2``."""
        def normalize(v):
            return [int(x) for x in v.split(".")]
        return (normalize(v1) > normalize(v2)) - (normalize(v1) < normalize(v2))

    def run(self):
        if PLATFORM_NAME == 'darwin':
            cwd = '/tmp/sparkstempupdate'
        else:
            cwd = os.getcwd()

        # Folder names look like "<anything>-X.Y.Z<platform_suffix>"
        pattern = re.compile(r".*-(\d+\.\d+\.\d+)" + re.escape(self.platform_suffix) + r"$")
        found = False

        try:
            for item in os.listdir(cwd):
                folder_path = os.path.join(cwd, item)
                if os.path.isdir(folder_path) and item.endswith(self.platform_suffix):
                    match = pattern.match(item)
                    if match:
                        folder_version = match.group(1)
                        if self.version_compare(folder_version, self.current_version) > 0:
                            self.pending_update_found.emit(folder_version, folder_path)
                            found = True
                            break
        except OSError:
            # Staging directory missing or unreadable: treat as "no pending update".
            pass

        if not found:
            self.no_pending_update.emit()


class ProgressDialog(QDialog):
    """Application-modal dialog showing file-loading progress with a Cancel button."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Processing Files...")
        self.setWindowModality(Qt.WindowModality.ApplicationModal)
        self.setMinimumWidth(300)

        layout = QVBoxLayout(self)
        self.message_label = QLabel("Loading and analyzing video files...")
        self.progress_bar = QProgressBar()
        self.progress_bar.setMaximum(0)  # Indeterminate state initially
        self.cancel_button = QPushButton("Cancel")

        layout.addWidget(self.message_label)
        layout.addWidget(self.progress_bar)
        layout.addWidget(self.cancel_button)

    def update_progress(self, current, total):
        """Switch to a determinate bar on first call and show ``current`` of ``total``."""
        if self.progress_bar.maximum() == 0:
            self.progress_bar.setMaximum(total)
        self.progress_bar.setValue(current)
        self.message_label.setText(f"Processing file {current} of {total}...")


class FileLoadWorker(QThread):
    """Load a BORIS JSON project and pre-compute a first-frame hand preview per camera.

    Results are stored in ``previews_data`` keyed by ``(obs_id, cam_id)``;
    ``observations_loaded`` is emitted on success (not after cancellation) and
    ``loading_finished`` is always emitted at the end.
    """
    observations_loaded = Signal(dict, dict, str)
    progress_update = Signal(int, int)
    loading_finished = Signal()
    loading_failed = Signal(str)

    def __init__(self, file_path, observations_root, extract_frame_and_hands_func):
        super().__init__()
        self.file_path = file_path
        self.observations_root = observations_root
        self.extract_frame_and_hands = extract_frame_and_hands_func
        self.is_running = True
        self.previews_data = {}  # To store the results of the heavy work

    def resolve_video_path(self, relative_video_path):
        """Logic to handle flexible path resolution (basename or full join)."""
        corrected_file_name = os.path.basename(relative_video_path)
        video_path = os.path.join(self.observations_root, corrected_file_name)

        if not os.path.exists(video_path):
            # Fallback to full relative path join
            video_path = os.path.join(self.observations_root, relative_video_path)
        return video_path

    def run(self):
        try:
            # 1. Load the BORIS JSON (Non-UI work)
            with open(self.file_path, "r", encoding="utf-8") as f:
                boris = json.load(f)
            observations = boris.get("observations", {})

            if not observations:
                self.loading_failed.emit("No observations found in BORIS JSON.")
                return

            # 2. Process the file data to gather initial previews (the heavy part)
            self.process_previews(observations)

            # A cancelled load must not be reported as loaded
            if not self.is_running:
                return

            # 3. Emit the final results back
            self.observations_loaded.emit(boris, observations, self.observations_root)

        except Exception as e:
            self.loading_failed.emit(f"Loading failed: {str(e)}")
        finally:
            self.loading_finished.emit()  # Ensures dialog closes on every exit path

    def process_previews(self, observations):
        """Extract the first frame and initial wrist positions for every camera video."""
        total_files = sum(len(obs_data.get("file", {})) for obs_data in observations.values())
        current_index = 0

        for obs_id, obs_data in observations.items():
            files_dict = obs_data.get("file", {})
            media_info = obs_data.get("media_info", {})
            fps_dict = media_info.get("fps", {})
            # First listed FPS, or 30 when none is recorded
            fps_default = float(next(iter(fps_dict.values()))) if fps_dict else 30.0

            for cam_id, paths in files_dict.items():
                if not self.is_running:
                    return  # Allow cancellation

                if not paths:
                    print(f"Skipping entry for {obs_id}/{cam_id}: paths list is empty (IndexError prevented).")
                    current_index += 1
                    self.progress_update.emit(current_index, total_files)
                    continue

                video_path = self.resolve_video_path(paths[0])

                if os.path.exists(video_path):
                    # HEAVY OPERATION
                    frame_rgb, results = self.extract_frame_and_hands(video_path, 0)

                    initial_wrists = []
                    if results and results.multi_hand_landmarks:
                        for lm in results.multi_hand_landmarks:
                            wrist = lm.landmark[0]
                            initial_wrists.append((wrist.x, wrist.y))

                    # Store the results needed to BUILD the UI later
                    self.previews_data[(obs_id, cam_id)] = {
                        "frame_rgb": frame_rgb,
                        "results": results,
                        "video_path": video_path,
                        "fps": fps_default,
                        "initial_wrists": initial_wrists
                    }

                current_index += 1
                self.progress_update.emit(current_index, total_files)

    def stop(self):
        self.is_running = False


class AboutWindow(QWidget):
    """
    Simple About window displaying basic application information.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("About SPARKS")
        self.resize(250, 100)

        layout = QVBoxLayout()
        label = QLabel("About SPARKS", self)
        label2 = QLabel("Spacial Patterns Analysis, Research, & Knowledge Suite", self)
        label3 = QLabel("SPARKS is licensed under the GPL-3.0 licence. For more information, visit https://www.gnu.org/licenses/gpl-3.0.en.html", self)
        label4 = QLabel(f"Version v{CURRENT_VERSION}")

        layout.addWidget(label)
        layout.addWidget(label2)
        layout.addWidget(label3)
        layout.addWidget(label4)
        self.setLayout(layout)


class UserGuideWindow(QWidget):
    """
    Simple User Guide window displaying basic information on how to use the software.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("User Guide - SPARKS")
        self.resize(250, 100)

        layout = QVBoxLayout()
        label = QLabel("Not Applicable:", self)
        label2 = QLabel("N/A", self)

        layout.addWidget(label)
        layout.addWidget(label2)
        self.setLayout(layout)


class ParamSection(QWidget):
    """
    A widget section that dynamically creates labeled input fields from parameter metadata.

    Args:
        section_data (dict): Dictionary containing section title and list of parameter info.
            Expected format:
            {
                "title": str,
                "params": [
                    {
                        "name": str,
                        "type": type,
                        "default": any,
                        "help": str (optional)
                    },
                    ...
                ]
            }
    """

    def __init__(self, section_data):
        super().__init__()
        layout = QVBoxLayout()
        self.setLayout(layout)
        self.widgets = {}
        self.selected_path = None

        # Title label
        title_label = QLabel(section_data["title"])
        title_label.setStyleSheet("font-weight: bold; font-size: 14px; margin-top: 10px; margin-bottom: 5px;")
        layout.addWidget(title_label)

        # Horizontal line
        line = QFrame()
        line.setFrameShape(QFrame.Shape.HLine)
        line.setFrameShadow(QFrame.Shadow.Sunken)
        layout.addWidget(line)

        for param in section_data["params"]:
            h_layout = QHBoxLayout()

            label = QLabel(param["name"])
            label.setFixedWidth(180)
            label.setToolTip(param.get("help", ""))

            help_text = param.get("help", "")
            help_btn = QPushButton("?")
            help_btn.setFixedWidth(25)
            help_btn.setToolTip(help_text)
            help_btn.clicked.connect(lambda _, text=help_text: self.show_help_popup(text))

            h_layout.addWidget(help_btn)
            h_layout.setStretch(0, 1)  # Set stretch factor for button (10%)
            h_layout.addWidget(label)
            h_layout.setStretch(1, 3)  # Set the stretch factor for label (40%)

            # Create input widget based on type
            if param["type"] is bool:
                widget = QComboBox()
                widget.addItems(["True", "False"])
                widget.setCurrentText(str(param["default"]))
            elif param["type"] is int:
                widget = QLineEdit()
                widget.setValidator(QIntValidator())
                widget.setText(str(param["default"]))
            elif param["type"] is float:
                widget = QLineEdit()
                widget.setValidator(QDoubleValidator())
                widget.setText(str(param["default"]))
            elif param["type"] is list:
                widget = self._create_multiselect_dropdown(None)
            else:
                widget = QLineEdit()
                widget.setText(str(param["default"]))

            widget.setToolTip(help_text)
            h_layout.addWidget(widget)
            h_layout.setStretch(2, 5)  # Set stretch factor for input field (50%)
            layout.addLayout(h_layout)

            self.widgets[param["name"]] = {
                "widget": widget,
                "type": param["type"]
            }

    def _create_multiselect_dropdown(self, items):
        """Build a read-only combo box with checkable rows.

        Row 0 is an empty placeholder, row 1 is "Toggle Select All", and the
        remaining rows are ``items``. Changing any check state keeps the toggle
        row in sync and refreshes the label via ``update_dropdown_label``.
        """
        combo = FullClickComboBox()
        combo.setView(QListView())
        model = QStandardItemModel()
        combo.setModel(model)
        combo.setEditable(True)
        combo.lineEdit().setReadOnly(True)
        combo.lineEdit().setPlaceholderText("Select...")

        dummy_item = QStandardItem("")
        dummy_item.setFlags(Qt.ItemIsEnabled)
        model.appendRow(dummy_item)

        toggle_item = QStandardItem("Toggle Select All")
        toggle_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled)
        toggle_item.setData(Qt.Unchecked, Qt.CheckStateRole)
        model.appendRow(toggle_item)

        if items is not None:
            for item in items:
                standard_item = QStandardItem(item)
                standard_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled)
                standard_item.setData(Qt.Unchecked, Qt.CheckStateRole)
                model.appendRow(standard_item)

        combo.setInsertPolicy(QComboBox.NoInsert)

        def on_view_clicked(index):
            item = model.itemFromIndex(index)
            if item.isCheckable():
                new_state = Qt.Checked if item.checkState() == Qt.Unchecked else Qt.Unchecked
                item.setCheckState(new_state)

        combo.view().pressed.connect(on_view_clicked)

        # Re-entrancy guard: setCheckState below re-fires itemChanged
        self._updating_checkstates = False

        def on_item_changed(item):
            if self._updating_checkstates:
                return
            self._updating_checkstates = True

            normal_items = [model.item(i) for i in range(2, model.rowCount())]  # skip dummy and toggle

            if item == toggle_item:
                all_checked = all(i.checkState() == Qt.Checked for i in normal_items)
                new_state = Qt.Unchecked if all_checked else Qt.Checked
                for i in normal_items:
                    i.setCheckState(new_state)
                toggle_item.setCheckState(new_state)
            elif item != dummy_item:
                # When normal items change, update toggle item
                all_checked = all(i.checkState() == Qt.Checked for i in normal_items)
                toggle_item.setCheckState(Qt.Checked if all_checked else Qt.Unchecked)

            self._updating_checkstates = False

            for param_name, info in self.widgets.items():
                if info["widget"] == combo:
                    self.update_dropdown_label(param_name)
                    break

        model.itemChanged.connect(on_item_changed)
        return combo

    def update_dropdown_label(self, param_name):
        """Show the checked options of a multi-select dropdown as comma-separated text.

        This is the same format ``get_param_values`` parses for ``list`` parameters.
        """
        combo = self.widgets[param_name]["widget"]
        model = combo.model()
        checked = [
            model.item(row).text()
            for row in range(2, model.rowCount())  # skip the dummy and "Toggle Select All" rows
            if model.item(row).checkState() == Qt.Checked
        ]
        combo.lineEdit().setText(", ".join(checked))

    def show_help_popup(self, text):
        msg = QMessageBox(self)
        msg.setWindowTitle("Parameter Info - SPARKS")
        msg.setText(text)
        msg.exec()

    def get_param_values(self):
        """Return ``{param name: typed value}`` for every widget in the section.

        Raises:
            ValueError: If an int/float field cannot be parsed.
        """
        values = {}
        for name, info in self.widgets.items():
            widget = info["widget"]
            expected_type = info["type"]

            if expected_type is bool:
                values[name] = widget.currentText() == "True"
            elif expected_type is list:
                values[name] = [x.strip() for x in widget.lineEdit().text().split(",") if x.strip()]
            else:
                raw_text = widget.text()
                try:
                    if expected_type is int:
                        values[name] = int(raw_text)
                    elif expected_type is float:
                        values[name] = float(raw_text)
                    else:
                        values[name] = raw_text  # str and any other type
                except Exception as e:
                    raise ValueError(f"Invalid value for {name}: {raw_text}") from e

        return values


class FullClickLineEdit(QLineEdit):
    """Line edit that opens its parent combo box's popup when clicked anywhere."""

    def mousePressEvent(self, event):
        combo = self.parent()
        if isinstance(combo, QComboBox):
            combo.showPopup()
        super().mousePressEvent(event)


class FullClickComboBox(QComboBox):
    """Combo box with a read-only line edit that opens the popup on any click."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setLineEdit(FullClickLineEdit(self))
        self.lineEdit().setReadOnly(True)


class ParticipantProcessor(QThread):
    progress_updated = Signal(int)
    frame_ready = Signal(QImage)
    finished_processing = Signal(str)
    time_updated = Signal(str)

    def __init__(self, obs_id, boris_json, selected_cam_id, selected_hand_idx,
                 hide_preview=False, output_csv=None, observations_root=None, output_dir=None, initial_wrists=None):
        super().__init__()
        self.obs_id = obs_id
        self.boris_json = boris_json
        self.selected_cam_id = selected_cam_id
        self.selected_hand_idx = selected_hand_idx
        self.hide_preview = hide_preview
        self.output_csv = output_csv
        self.observations_root = observations_root
        self.output_dir = output_dir
        self.initial_wrists = initial_wrists

        # Mediapipe hands
        self.mp_hands = mp.solutions.hands.Hands(
            static_image_mode=False,
            max_num_hands=2,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
            model_complexity=0
        )

        self.is_running = True

    def run(self):
        observations = self.boris_json.get("observations", {})
        obs_data = observations[self.obs_id]

        # ---------------- Camera / Video ----------------
        cameras_with_files = obs_data.get("file", {})
        media_info = obs_data.get("media_info", {})
        offset_dict = media_info.get("offset", {})
        # paths = files_dict.get(self.cam_id, [])

        video_file_list = cameras_with_files[self.selected_cam_id]
        relative_video_path = video_file_list[0]

        # 2. Resolve the FULL path using the stored root and the same logic as before
        corrected_file_name = os.path.basename(relative_video_path)

        # Try finding just the file name in the root folder
        video_path = os.path.join(self.observations_root, corrected_file_name)

        # if not os.path.exists(video_path):
        #     # Fallback to the full relative path join
        #     video_path = os.path.join(self.observations_root, relative_video_path)

        if not os.path.exists(video_path):
            print(f"FATAL ERROR: Video file not found at: {video_path}")
            self.finished_processing.emit(f"Error: Video not found for {self.obs_id}")
            return

        camera_offset = float(offset_dict.get(self.selected_cam_id, 0.0))

        fps_dict = media_info.get("fps", {})
        fps = float(list(fps_dict.values())[0]) if fps_dict else 30.0

        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_seconds = total_frames / fps
        total_time_str = self._format_seconds(total_seconds)

        # ---------------- Trial Segments ----------------
        raw_events = obs_data.get("events", [])
        segments = []
        reach_before_contact_segments = []
        current_start = None
        extra_duration = 1.0

        for ev in raw_events:
            if len(ev) < 3:
                continue
            time_sec = ev[0]
            name = ev[2]

            if name == "Trial Start":
                current_start = time_sec
                current_contact = None
            elif name == "Contact" and current_start is not None:
                current_contact = time_sec
                reach_before_contact_segments.append((current_start, current_contact))
            elif name == "End" and current_start is not None:
                adjusted_start = current_start
adjusted_end = time_sec + extra_duration segments.append((adjusted_start, adjusted_end)) if current_contact is None: reach_before_contact_segments.append((current_start, adjusted_end)) current_start = None # ---------------- CSV Header ---------------- header = ["frame_index", "time_sec", "adjusted_time_sec", "reach_active", "reach_before_contact", "camera_offset"] for i in range(21): header += [f"lm{i}_x", f"lm{i}_y", f"lm{i}_z"] output_path = os.path.join(self.output_dir, self.output_csv) print(output_path) with open(output_path, "w", newline="") as csv_file: csv_writer = csv.writer(csv_file) csv_writer.writerow(header) # ---------------- Frame Processing ---------------- frame_index = 0 while True: if not self.is_running: print(f"Processor for {self.obs_id}/{self.cam_id} was cancelled gracefully.") # Clean up resources before returning cap.release() return # Thread exits gracefully ret, frame = cap.read() if not ret: break current_time = frame_index / fps adjusted_time = current_time # you could add extra logic if needed # Check if reach is active with offset is_reach_active = any( (start - camera_offset) <= current_time <= (end - camera_offset) for start, end in segments ) is_reach_before_contact = any( (start - camera_offset) <= current_time <= (end - camera_offset) for start, end in reach_before_contact_segments ) row = [ frame_index, current_time, adjusted_time, int(is_reach_active), int(is_reach_before_contact), camera_offset ] # ---------------- MediaPipe Hands ---------------- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) results = self.mp_hands.process(rgb_frame) hand_found = False if results.multi_hand_landmarks and results.multi_handedness: # Map detected hands to initial_wrists hand_mapping = {} used_indices = set() if self.initial_wrists: for i, iw in enumerate(self.initial_wrists): min_dist = float('inf') best_idx = None for j, lm in enumerate(results.multi_hand_landmarks): if j in used_indices: continue wrist = lm.landmark[0] dist = (wrist.x - 
iw[0])**2 + (wrist.y - iw[1])**2 if dist < min_dist: min_dist = dist best_idx = j if best_idx is not None: hand_mapping[i] = best_idx used_indices.add(best_idx) else: hand_mapping = {i: i for i in range(len(results.multi_hand_landmarks))} # Use selected_hand_idx to pick the right hand mapped_idx = hand_mapping.get(self.selected_hand_idx) if mapped_idx is not None: hand_landmarks = results.multi_hand_landmarks[mapped_idx] hand_found = True for lm in hand_landmarks.landmark: row += [lm.x, lm.y, lm.z] if not hand_found: row += [math.nan] * (21 * 3) csv_writer.writerow(row) # ---------------- Preview ---------------- if not self.hide_preview: display_frame = frame.copy() if results.multi_hand_landmarks and hand_found: mp.solutions.drawing_utils.draw_landmarks( display_frame, hand_landmarks, mp.solutions.hands.HAND_CONNECTIONS ) # Convert to QImage h, w, _ = display_frame.shape qimg = QImage(display_frame.data, w, h, 3*w, QImage.Format_RGB888) self.frame_ready.emit(qimg) current_seconds = frame_index / fps current_time_str = self._format_seconds(current_seconds) # ---------------- Progress ---------------- progress = int((frame_index / total_frames) * 100) self.progress_updated.emit(progress) self.time_updated.emit(f"{current_time_str}/{total_time_str}") frame_index += 1 cap.release() self.mp_hands.close() self.finished_processing.emit(self.obs_id) def _format_seconds(self, seconds): if seconds < 0: return "00:00" minutes = math.floor(seconds / 60) secs = math.floor(seconds % 60) return f"{minutes:02d}:{secs:02d}" def stop(self): """Sets the flag to stop the thread's execution loop.""" self.is_running = False class MainApplication(QMainWindow): """ Main application window that creates and sets up the UI. 
""" progress_update_signal = Signal(str, int) def __init__(self): super().__init__() self.setWindowTitle("SPARKS") self.setGeometry(100, 100, 1280, 720) self.about = None self.help = None self.optodes = None self.events = None self.terminal = None self.bubble_widgets = {} self.param_sections = [] self.folder_paths = [] self.section_widget = None self.first_run = True self.selection_widgets = {} self.camera_state = {} # {(obs_id, cam_id): state info} self.processing_widgets = {} self.files_total = 0 # total number of files to process self.files_done = set() # set of file paths done (success or fail) self.files_failed = set() # set of failed file paths self.files_results = {} # dict for successful results (if needed) self.processing_threads = [] # List to hold all active ParticipantProcessor threads self.processing_widgets = {} self.init_ui() self.create_menu_bar() self.platform_suffix = "-" + PLATFORM_NAME self.pending_update_version = None self.pending_update_path = None self.last_clicked_bubble = None self.installEventFilter(self) self.file_metadata = {} self.current_file = None self.worker_thread = None self.progress_dialog = None # Mediapipe hands self.mp_hands = mp.solutions.hands.Hands( static_image_mode=True, max_num_hands=2, min_detection_confidence=0.5 ) # Start local pending update check thread self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix) self.local_check_thread.pending_update_found.connect(self.on_pending_update_found) self.local_check_thread.no_pending_update.connect(self.on_no_pending_update) self.local_check_thread.start() def init_ui(self): # Central widget and main horizontal layout central = QWidget() self.setCentralWidget(central) main_layout = QHBoxLayout() central.setLayout(main_layout) # Left container with vertical layout: top left + bottom left left_container = QWidget() left_layout = QVBoxLayout() left_container.setLayout(left_layout) left_container.setMinimumWidth(300) top_left_container = 
QGroupBox() top_left_container.setTitle("File information") top_left_container.setStyleSheet("QGroupBox { font-weight: bold; }") # Style if needed top_left_container.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding) top_left_layout = QHBoxLayout() top_left_container.setLayout(top_left_layout) # QTextEdit with fixed height, but only 80% width self.top_left_widget = QTextEdit() self.top_left_widget.setReadOnly(True) self.top_left_widget.setPlaceholderText("Logging information will be available here.") # Add QTextEdit to the layout with a stretch factor top_left_layout.addWidget(self.top_left_widget, stretch=4) # 80% # Create a vertical box layout for the right 20% self.right_column_widget = QWidget() right_column_layout = QVBoxLayout() self.right_column_widget.setLayout(right_column_layout) self.meta_fields = { "HAND": QLineEdit(), "GENDER": QLineEdit(), "GROUP": QLineEdit(), } for key, field in self.meta_fields.items(): label = QLabel(key.capitalize()) field.setPlaceholderText(f"Enter {key}") right_column_layout.addWidget(label) right_column_layout.addWidget(field) label_desc = QLabel('Why are these useful?') label_desc.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction) label_desc.setOpenExternalLinks(False) def show_info_popup(): QMessageBox.information(None, "Parameter Info - SPARKS", "Age: Used to calculate the DPF factor.\nGender: Not currently used. 
" "Will be able to sort into groups by gender in the near future.\nGroup: Allows contrast " "images to be created comparing one group to another once the processing has completed.") label_desc.linkActivated.connect(show_info_popup) right_column_layout.addWidget(label_desc) right_column_layout.addStretch() # Push fields to top self.right_column_widget.hide() # Add right column widget to the top-left layout (takes 20% width) top_left_layout.addWidget(self.right_column_widget, stretch=1) # Add top_left_container to the main left_layout left_layout.addWidget(top_left_container, stretch=2) # Bottom left: the bubbles inside the scroll area self.bubble_container = QWidget() self.bubble_layout = QGridLayout() self.bubble_layout.setAlignment(Qt.AlignmentFlag.AlignTop) self.bubble_container.setLayout(self.bubble_layout) self.scroll_area = QScrollArea() self.scroll_area.setWidgetResizable(True) self.scroll_area.setWidget(self.bubble_container) self.scroll_area.setMinimumHeight(300) # Add top left and bottom left to left layout left_layout.addWidget(self.scroll_area, stretch=8) # Right widget (full height on right side) self.right_container = QWidget() right_container_layout = QVBoxLayout() self.right_container.setLayout(right_container_layout) # Content widget inside scroll area self.right_content_widget = QWidget() right_content_layout = QVBoxLayout() self.right_content_widget.setLayout(right_content_layout) # Option selector dropdown self.option_selector = QComboBox() self.option_selector.addItems(["N/A"]) right_content_layout.addWidget(self.option_selector) # Container for the sections self.rows_container = QWidget() self.rows_layout = QVBoxLayout() self.rows_layout.setSpacing(10) self.rows_container.setLayout(self.rows_layout) right_content_layout.addWidget(self.rows_container) # Spacer at bottom inside scroll area content to push content up right_content_layout.addStretch() # Scroll area for the right side content self.right_scroll_area = QScrollArea() 
self.right_scroll_area.setWidgetResizable(True) self.right_scroll_area.setWidget(self.right_content_widget) # Buttons widget (fixed below the scroll area) buttons_widget = QWidget() buttons_layout = QHBoxLayout() buttons_widget.setLayout(buttons_layout) buttons_layout.addStretch() self.button1 = QPushButton("Process") self.button2 = QPushButton("Clear") buttons_layout.addWidget(self.button1) buttons_layout.addWidget(self.button2) self.button1.setMinimumSize(100, 40) self.button2.setMinimumSize(100, 40) self.button1.setVisible(False) self.button1.clicked.connect(self.on_run_task) self.button2.clicked.connect(self.clear_all) # Add scroll area and buttons widget to right container layout right_container_layout.addWidget(self.right_scroll_area) right_container_layout.addWidget(buttons_widget) # Add left and right containers to main layout main_layout.addWidget(left_container, stretch=70) main_layout.addWidget(self.right_container, stretch=30) # Set size policy to expand self.right_container.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) self.right_scroll_area.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) # Store ParamSection widgets self.option_selector.currentIndexChanged.connect(self.update_sections) # Initial build self.update_sections(0) def create_menu_bar(self): '''Menu Bar at the top of the screen''' menu_bar = self.menuBar() def make_action(name, shortcut=None, slot=None, checkable=False, checked=False, icon=None): action = QAction(name, self) if shortcut: action.setShortcut(QKeySequence(shortcut)) if slot: action.triggered.connect(slot) if checkable: action.setCheckable(True) action.setChecked(checked) if icon: action.setIcon(QIcon(icon)) return action # File menu and actions file_menu = menu_bar.addMenu("File") file_actions = [ ("Open BORIS file...", "Ctrl+O", self.open_file_dialog, resource_path("icons/file_open_24dp_1F1F1F.svg")), #("Open Folder...", "Ctrl+Alt+O", self.open_folder_dialog, 
resource_path("icons/folder_24dp_1F1F1F.svg")), #("Open Folders...", "Ctrl+Shift+O", self.open_folder_dialog, resource_path("icons/folder_copy_24dp_1F1F1F.svg")), #("Load Project...", "Ctrl+L", self.load_project, resource_path("icons/article_24dp_1F1F1F.svg")), #("Save Project...", "Ctrl+S", self.save_project, resource_path("icons/save_24dp_1F1F1F.svg")), #("Save Project As...", "Ctrl+Shift+S", self.save_project, resource_path("icons/save_as_24dp_1F1F1F.svg")), ] for i, (name, shortcut, slot, icon) in enumerate(file_actions): file_menu.addAction(make_action(name, shortcut, slot, icon=icon)) if i == 1: # after the first 3 actions (0,1,2) file_menu.addSeparator() file_menu.addSeparator() file_menu.addAction(make_action("Exit", "Ctrl+Q", QApplication.instance().quit, icon=resource_path("icons/exit_to_app_24dp_1F1F1F.svg"))) # Edit menu edit_menu = menu_bar.addMenu("Edit") edit_actions = [ ("Cut", "Ctrl+X", self.cut_text, resource_path("icons/content_cut_24dp_1F1F1F.svg")), ("Copy", "Ctrl+C", self.copy_text, resource_path("icons/content_copy_24dp_1F1F1F.svg")), ("Paste", "Ctrl+V", self.paste_text, resource_path("icons/content_paste_24dp_1F1F1F.svg")) ] for name, shortcut, slot, icon in edit_actions: edit_menu.addAction(make_action(name, shortcut, slot, icon=icon)) model_menu = menu_bar.addMenu("Model") model_actions = [ ("Create model from a CSV", "Ctrl+U", self.train_model_csv, resource_path("icons/content_cut_24dp_1F1F1F.svg")), ("Create model from a folder", "Ctrl+I", self.train_model_folder, resource_path("icons/content_copy_24dp_1F1F1F.svg")), ("Test model on a video", "Ctrl+O", self.test_model_video, resource_path("icons/content_paste_24dp_1F1F1F.svg")), ("Test model on a folder", "Ctrl+P", self.test_model_folder, resource_path("icons/content_paste_24dp_1F1F1F.svg")), ("Test model on a CSV", "Ctrl+P", self.test_model_csv, resource_path("icons/content_paste_24dp_1F1F1F.svg")) ] for i, (name, shortcut, slot, icon) in enumerate(model_actions): 
model_menu.addAction(make_action(name, shortcut, slot, icon=icon)) if i == 1: model_menu.addSeparator() # View menu view_menu = menu_bar.addMenu("View") toggle_statusbar_action = make_action("Toggle Status Bar", checkable=True, checked=True, slot=None) view_menu.addAction(toggle_statusbar_action) # Options menu (Help & About) options_menu = menu_bar.addMenu("Options") options_actions = [ ("User Guide", "F1", self.user_guide, resource_path("icons/help_24dp_1F1F1F.svg")), ("Check for Updates", "F5", self.manual_check_for_updates, resource_path("icons/update_24dp_1F1F1F.svg")), ("About", "F12", self.about_window, resource_path("icons/info_24dp_1F1F1F.svg")) ] for i, (name, shortcut, slot, icon) in enumerate(options_actions): options_menu.addAction(make_action(name, shortcut, slot, icon=icon)) if i == 1 or i == 3: # after the first 2 actions (0,1) options_menu.addSeparator() terminal_menu = menu_bar.addMenu("Terminal") terminal_actions = [ ("New Terminal", "Ctrl+Alt+T", self.terminal_gui, resource_path("icons/terminal_24dp_1F1F1F.svg")), ] for name, shortcut, slot, icon in terminal_actions: terminal_menu.addAction(make_action(name, shortcut, slot, icon=icon)) self.statusbar = self.statusBar() self.statusbar.showMessage("Ready") def update_sections(self, index): # Clear previous sections for i in reversed(range(self.rows_layout.count())): widget = self.rows_layout.itemAt(i).widget() if widget is not None: widget.deleteLater() self.param_sections.clear() # Add ParamSection widgets from SECTIONS for section in SECTIONS: self.section_widget = ParamSection(section) self.rows_layout.addWidget(self.section_widget) self.param_sections.append(self.section_widget) def clear_bubble_widgets(self): """Clears all widgets from the self.bubble_layout.""" while self.bubble_layout.count(): # Get the next item in the layout item = self.bubble_layout.takeAt(0) # Check if the item holds a widget (most common case: QGroupBox) widget = item.widget() if widget: widget.deleteLater() # Check if 
the item is a spacer/stretch item elif item.spacerItem(): self.bubble_layout.removeItem(item) def clear_all(self): self.cancel_task() self.right_column_widget.hide() # Clear the bubble layout while self.bubble_layout.count(): item = self.bubble_layout.takeAt(0) widget = item.widget() if widget: widget.deleteLater() # Clear file data self.bubble_widgets.clear() self.statusBar().clearMessage() self.raw_haemo_dict = None self.epochs_dict = None self.fig_bytes_dict = None self.cha_dict = None self.contrast_results_dict = None self.df_ind_dict = None self.design_matrix_dict = None self.age_dict = None self.gender_dict = None self.group_dict = None self.valid_dict = None # Reset any visible UI elements self.button1.setVisible(False) self.top_left_widget.clear() def get_output_directory(self): """Prompts user to select a parent folder and creates a unique, dated subfolder.""" # Open a folder selection dialog to get the parent folder parent_dir = QFileDialog.getExistingDirectory( self, "Select Parent Folder for Results", os.getcwd() # Start in the user's home directory ) if not parent_dir: return None # User cancelled the dialog # Create the unique subfolder name: sparks_YYYYMMDD_HHMMSS timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") subfolder_name = f"sparks_{timestamp}" output_path = os.path.join(parent_dir, subfolder_name) # Create the directory if it doesn't exist try: os.makedirs(output_path, exist_ok=True) print(f"Created output directory: {output_path}") return output_path except Exception as e: QMessageBox.critical(self, "Error", f"Failed to create output directory: {e}") return None def copy_text(self): self.top_left_widget.copy() # Trigger copy self.statusbar.showMessage("Copied to clipboard") # Show status message def cut_text(self): self.top_left_widget.cut() # Trigger cut self.statusbar.showMessage("Cut to clipboard") # Show status message def paste_text(self): self.top_left_widget.paste() # Trigger paste self.statusbar.showMessage("Pasted from 
clipboard") # Show status message def get_training_files(self): """Opens a dialog to select multiple CSV files for training.""" file_paths, _ = QFileDialog.getOpenFileNames( self, "Select Training CSV Files", os.getcwd(), # Start in current directory "CSV Files (*.csv)" # Filter for CSV files ) return file_paths def get_testing_files(self): """Opens dialogs to select the model file and the video file.""" # 1. Select Model File (.pth) model_path, _ = QFileDialog.getOpenFileName( self, "Select Trained Model File (.pth)", os.getcwd(), "Model Files (*.pth)" ) if not model_path: return None, None # 2. Select Video File video_path, _ = QFileDialog.getOpenFileName( self, "Select Video File for Testing", os.getcwd(), "Video Files (*.mp4 *.avi *.mov)" ) return model_path, video_path def get_testing_files2(self): """Opens dialogs to select the model file and the csv file.""" # 1. Select Model File (.pth) model_path, _ = QFileDialog.getOpenFileName( self, "Select Trained Model File (.pth)", os.getcwd(), "Model Files (*.pth)" ) if not model_path: return None, None # 2. Select Video File csv_path, _ = QFileDialog.getOpenFileName( self, "Select Video File for Testing", os.getcwd(), "Comma Seperated Files (*.csv)" ) return model_path, csv_path def train_model_csv(self): """Initiates the model training process, including selecting CSVs and save path.""" # 1. Get CSV paths csv_paths = self.get_training_files() # Re-use the existing CSV file selection method if not csv_paths: self.statusBar().showMessage("Training cancelled. No CSV files selected.") return # 2. Get Model Save Path (New Pop-up) default_filename = "best_reach_lstm.pth" # The third argument can include a default filename model_save_path, _ = QFileDialog.getSaveFileName( self, "Save Trained Model File", os.path.join(os.getcwd(), default_filename), # Start in CWD with a default name "PyTorch Model Files (*.pth);;All Files (*)" ) if not model_save_path: self.statusBar().showMessage("Training cancelled. 
Model save location not specified.") return # Ensure correct extension if the user didn't type one if not model_save_path.lower().endswith('.pth'): model_save_path += '.pth' self.statusBar().showMessage(f"Starting model training on {len(csv_paths)} files. Model will save to: {os.path.basename(model_save_path)}") # 3. Start the thread with the new path self.train_thread = TrainModelThread( csv_paths=csv_paths, model_save_path=model_save_path # Pass the path here ) # ... (connect signals, start thread, disable button, etc.) ... self.train_thread.update.connect(self.log_training_message) self.train_thread.finished.connect(self.on_train_finished) self.train_thread.start() def log_training_message(self, message): """Displays training status messages (e.g., in a QTextEdit).""" # Assuming you have a QTextEdit named self.log_output (or similar) # For simplicity, we'll use the status bar for now, but a dedicated log box is better. self.statusBar().showMessage(f"Training: {message}") # self.top_left_widget.append(f"[TRAIN] {message}") # Use this if you have a log box def on_train_finished(self, final_message): """Handles completion of the training thread.""" self.statusBar().showMessage(final_message) # Show final message in a pop-up or log box QMessageBox.information(self, "Training Complete", final_message) # Clean up and re-enable the button self.train_thread = None def train_model_folder(self): return def test_model_video(self): model_path, video_path = self.get_testing_files() if not model_path or not video_path: self.statusBar().showMessage("Testing cancelled. Model or Video not selected.") return self.statusBar().showMessage(f"Starting model testing on {os.path.basename(video_path)}...") # 1. Create the thread # Note: TestModelThread expects a list of video paths, so we pass [video_path] self.test_thread = TestModelThread(video_paths=[video_path], model_path=model_path) # 2. 
Connect signals # The update_frame signal is not used to update the main GUI here, # as the thread is responsible for showing the cv2.imshow window. self.test_thread.finished.connect(self.on_test_finished) # 3. Start the thread self.test_thread.start() def test_model_csv(self): model_path, csv_path = self.get_testing_files2() if not model_path or not csv_path: self.statusBar().showMessage("Testing cancelled. Model or Video not selected.") return self.statusBar().showMessage(f"Starting model testing on {os.path.basename(csv_path)}...") # 1. Create the thread # Note: TestModelThread expects a list of video paths, so we pass [video_path] self.test_thread = TestModelThread2(csv_path=csv_path, model_path=model_path) # 2. Connect signals # The update_frame signal is not used to update the main GUI here, # as the thread is responsible for showing the cv2.imshow window. self.test_thread.finished.connect(self.on_csv_analysis_finished) # 3. Start the thread self.test_thread.start() def on_csv_analysis_finished(self, result): if "error" in result: QMessageBox.critical(self, "Error", result["error"]) return time = result["time"] active = result["active"] before = result["before"] import matplotlib.pyplot as plt fig, axes = plt.subplots(2, 1, figsize=(14, 6), sharex=True) # --- Top plot: Reach Active --- axes[0].plot(time, active, color="green", label="Reach Active") axes[0].set_ylabel("Probability") axes[0].set_title("Reach Active") axes[0].set_ylim(0, 1) axes[0].grid(True) # --- Bottom plot: Reach Before Contact --- axes[1].plot(time, before, color="orange", label="Reach Before Contact") axes[1].set_ylabel("Probability") axes[1].set_title("Reach Before Contact") axes[1].set_xlabel("Time (s)") axes[1].set_ylim(0, 1) axes[1].grid(True) plt.tight_layout() plt.show() self.statusBar().showMessage("Complete.") def on_test_finished(self): """Handles cleanup after the testing video window is closed.""" self.statusBar().showMessage("Model testing complete.") # Ensure the OpenCV window 
is destroyed try: cv2.destroyAllWindows() except: pass # Clean up and re-enable the button self.test_thread = None def test_model_folder(self): return def about_window(self): if self.about is None or not self.about.isVisible(): self.about = AboutWindow(self) self.about.show() def user_guide(self): if self.help is None or not self.help.isVisible(): self.help = UserGuideWindow(self) self.help.show() def terminal_gui(self): if self.terminal is None or not self.terminal.isVisible(): self.terminal = TerminalWindow(self) self.terminal.show() def update_optode_positions(self): return if self.optodes is None or not self.optodes.isVisible(): self.optodes = UpdateOptodesWindow(self) self.optodes.show() def update_event_markers(self): return if self.events is None or not self.events.isVisible(): self.events = UpdateEventsWindow(self) self.events.show() def open_file_dialog(self): file_path, _ = QFileDialog.getOpenFileName( self, "Open File", "", "BORIS Files (*.boris);;All Files (*)" ) if file_path: self.observations_root = self.resolve_path_to_observations(file_path) if not self.observations_root: self.statusBar().showMessage("File loading cancelled by user.") return # User cancelled the folder selection # 2. Initialize and Show Progress Dialog self.progress_dialog = ProgressDialog(self) # ⚠️ Connect the cancel button to the task stopper self.progress_dialog.cancel_button.clicked.connect(self.cancel_task) self.progress_dialog.show() # 3. Initialize and Start Worker Thread self.worker_thread = FileLoadWorker( file_path, self.observations_root, self.extract_frame_and_hands # Assuming this is a method available to MainApplication ) # 4. 
Connect Signals to Main Thread Slots self.worker_thread.progress_update.connect(self.progress_dialog.update_progress) self.worker_thread.observations_loaded.connect(self.on_files_loaded) self.worker_thread.loading_finished.connect(self.on_loading_finished) self.worker_thread.loading_failed.connect(self.on_loading_failed) self.worker_thread.start() def resolve_path_to_observations(self, boris_file_path): """ Attempts to find the 'Observations' directory based on the BORIS file location. Prompts the user if it cannot be found. """ # 1. Get the directory containing the BORIS file boris_dir = os.path.dirname(boris_file_path) # 2. Check if the 'Observations' folder is adjacent to the BORIS file initial_check_path = os.path.join(boris_dir, "Observations") if os.path.isdir(initial_check_path): # Success: The 'Observations' folder is found next to the BORIS file return initial_check_path # 3. If not found, prompt the user to locate the Observations folder QMessageBox.warning( self, "Observations Folder Missing", "The 'Observations' folder (containing the videos) was not found next to the BORIS file. Please select the correct root folder." ) # Open a folder selection dialog folder_path = QFileDialog.getExistingDirectory( self, "Select Observations Root Folder", boris_dir # Start the dialog in the BORIS file's directory ) if folder_path: # Check if the user selected a folder named 'Observations' if os.path.basename(folder_path) == "Observations": return folder_path # Check if the user selected the *parent* of 'Observations' elif os.path.isdir(os.path.join(folder_path, "Observations")): return os.path.join(folder_path, "Observations") # If the selected folder is neither, return the selected path and hope it contains the relative paths else: return folder_path # If the user cancels the dialog, return None return None def on_files_loaded(self, boris, observations, root_path): # 1. 
Update MainApplication state variables self.boris = boris self.observations = observations self.observations_root = root_path # 2. Build the UI grid using the data gathered by the worker self.build_preview_grid(self.worker_thread.previews_data) self.statusBar().showMessage(f"{self.worker_thread.file_path} loaded.") self.button1.setVisible(True) def on_loading_finished(self): if self.progress_dialog: self.progress_dialog.accept() # Close the dialog if finished successfully def on_loading_failed(self, error_message): if self.progress_dialog: self.progress_dialog.reject() # Close the dialog QMessageBox.critical(self, "Error Loading Files", error_message) self.statusBar().showMessage("File loading failed.") def cancel_task(self): if self.worker_thread and self.worker_thread.isRunning(): self.worker_thread.stop() # Tell the thread to stop processing self.worker_thread.wait() # Wait for the thread to safely exit if self.progress_dialog: self.progress_dialog.reject() self.statusBar().showMessage("File loading cancelled.") # Add a dummy implementation for the processing function if it's not shown: # def extract_frame_and_hands(self, video_path, frame_idx): # # Placeholder implementation - replace with your actual function # return None, None def build_preview_grid(self, previews_data): for obs_id, obs_data in self.observations.items(): group = QGroupBox(f"Participant / Observation: {obs_id}") grouplayout = QVBoxLayout(group) # Participant-level skip dropdown participant_dropdown = QComboBox() participant_dropdown.addItem("Skip this participant", 0) participant_dropdown.addItem("Process this participant", 1) participant_dropdown.setCurrentIndex(0) # default: skip if "participant_selection" not in self.selection_widgets: self.selection_widgets["participant_selection"] = {} self.selection_widgets["participant_selection"][obs_id] = participant_dropdown grouplayout.addWidget(QLabel("Participant Option:")) grouplayout.addWidget(participant_dropdown) files_dict = 
obs_data.get("file", {}) media_info = obs_data.get("media_info", {}) print(media_info) fps_dict = media_info.get("fps", {}) fps_default = float(list(fps_dict.values())[0]) if fps_dict else 30.0 if obs_id not in self.selection_widgets: self.selection_widgets[obs_id] = {} for cam_id, paths in files_dict.items(): # Check if the worker successfully gathered the preview data for this file state_key = (obs_id, cam_id) if state_key not in previews_data: # Skip files the worker couldn't process (e.g., path not found) continue # Retrieve the pre-calculated data state_data = previews_data[state_key] frame_rgb = state_data["frame_rgb"] results = state_data["results"] # Only draw the overlay if the worker found data if frame_rgb is not None: display_img = self.draw_hand_overlay(frame_rgb, results) else: # Use a placeholder image or skip if no data continue # Convert to pixmap h, w, _ = display_img.shape qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888) pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio) # -------- UI Row -------- row = QWidget() row_layout = QHBoxLayout(row) preview_label = QLabel() preview_label.setPixmap(pix) row_layout.addWidget(preview_label) # Dropdown dropdown = QComboBox() dropdown.addItem("Skip this camera", -1) if results and results.multi_hand_landmarks: for idx in range(len(results.multi_hand_landmarks)): dropdown.addItem(f"Use Hand {idx}", idx) row_layout.addWidget(dropdown) # Store dropdown self.selection_widgets[obs_id][cam_id] = dropdown self.camera_state[(obs_id, cam_id)] = { "path": state_data["video_path"], "frame_idx": 0, "fps": state_data["fps"], "preview_label": preview_label, "dropdown": dropdown, "initial_wrists": state_data["initial_wrists"] } # Skip button skip_btn = QPushButton("Skip 1s → Rescan") skip_btn.clicked.connect(lambda _, o=obs_id, c=cam_id: self.skip_and_rescan(o, c)) row_layout.addWidget(skip_btn) grouplayout.addWidget(row) self.bubble_layout.addWidget(group) #self.bubble_layout.addStretch() 
# ============================================ # FRAME EXTRACTION + MEDIA PIPE DETECTION # ============================================ def extract_frame_and_hands(self, video_path, frame_idx): cap = cv2.VideoCapture(video_path) cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) ret, frame = cap.read() cap.release() if not ret: return None, None rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) results = self.mp_hands.process(rgb) return rgb, results # ============================================ # DRAW OVERLAYED HANDS + BIG LABELS # ============================================ def draw_hand_overlay(self, img, results, initial_wrists=None): draw_img = img.copy() h, w, _ = draw_img.shape COLORS = [ (0, 255, 0), # Hand 0 green (255, 0, 255), # Hand 1 magenta (0, 255, 255), # Hand 2 cyan ] if results and results.multi_hand_landmarks: # Map detected hands to initial wrists if provided if initial_wrists: hand_mapping = {} # initial_idx -> current_idx used_indices = set() for i, iw in enumerate(initial_wrists): min_dist = float('inf') best_idx = None for j, lm in enumerate(results.multi_hand_landmarks): if j in used_indices: continue # skip already assigned wrist = lm.landmark[0] dist = (wrist.x - iw[0])**2 + (wrist.y - iw[1])**2 if dist < min_dist: min_dist = dist best_idx = j if best_idx is not None: hand_mapping[i] = best_idx used_indices.add(best_idx) else: hand_mapping = {i: i for i in range(len(results.multi_hand_landmarks))} for initial_idx, current_idx in hand_mapping.items(): lm_obj = results.multi_hand_landmarks[current_idx] mp.solutions.drawing_utils.draw_landmarks( draw_img, lm_obj, mp.solutions.hands.HAND_CONNECTIONS ) wrist = lm_obj.landmark[0] wx, wy = int(wrist.x * w), int(wrist.y * h) color = COLORS[initial_idx % len(COLORS)] # Big outline cv2.putText(draw_img, str(initial_idx), (wx, wy - 40), cv2.FONT_HERSHEY_SIMPLEX, 3.2, (0, 0, 0), 10, cv2.LINE_AA) # Big colored label cv2.putText(draw_img, str(initial_idx), (wx, wy - 40), cv2.FONT_HERSHEY_SIMPLEX, 3.2, color, 6, 
cv2.LINE_AA) return draw_img # ============================================ # SKIP 1s AND RESCAN FRAME # ============================================ def skip_and_rescan(self, obs_id, cam_id): key = (obs_id, cam_id) state = self.camera_state[key] fps = state["fps"] state["frame_idx"] += int(fps) rgb, results = self.extract_frame_and_hands( state["path"], state["frame_idx"] ) if rgb is None: return display_img = self.draw_hand_overlay(rgb, results, state.get("initial_wrists")) h, w, _ = display_img.shape qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888) pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio) # Update preview state["preview_label"].setPixmap(pix) # Update dropdown dropdown = state["dropdown"] dropdown.clear() dropdown.addItem("Skip this camera", -1) if results and results.multi_hand_landmarks: for idx in range(len(results.multi_hand_landmarks)): dropdown.addItem(f"Use Hand {idx}", idx) # def save_project(self, onCrash=False): # if not onCrash: # filename, _ = QFileDialog.getSaveFileName( # self, "Save Project", "", "SPARK Project (*.spark)" # ) # if not filename: # return # else: # if PLATFORM_NAME == "darwin": # filename = os.path.join(os.path.dirname(sys.executable), "../../../sparks_autosave.spark") # else: # filename = os.path.join(os.getcwd(), "sparks_autosave.spark") # try: # # Ensure the filename has the proper extension # if not filename.endswith(".spark"): # filename += ".spark" # project_path = Path(filename).resolve() # project_dir = project_path.parent # file_list = [ # str(PurePosixPath(Path(bubble.file_path).resolve().relative_to(project_dir))) # for bubble in self.bubble_widgets.values() # ] # progress_states = { # str(PurePosixPath(Path(bubble.file_path).resolve().relative_to(project_dir))): bubble.current_step # for bubble in self.bubble_widgets.values() # } # project_data = { # "file_list": file_list, # "progress_states": progress_states, # "raw_haemo_dict": self.raw_haemo_dict, # "epochs_dict": 
self.epochs_dict, # "fig_bytes_dict": self.fig_bytes_dict, # "cha_dict": self.cha_dict, # "contrast_results_dict": self.contrast_results_dict, # "df_ind_dict": self.df_ind_dict, # "design_matrix_dict": self.design_matrix_dict, # "age_dict": self.age_dict, # "gender_dict": self.gender_dict, # "group_dict": self.group_dict, # "valid_dict": self.valid_dict, # } # def sanitize(obj): # if isinstance(obj, Path): # return str(PurePosixPath(obj)) # elif isinstance(obj, dict): # return {sanitize(k): sanitize(v) for k, v in obj.items()} # elif isinstance(obj, list): # return [sanitize(i) for i in obj] # return obj # project_data = sanitize(project_data) # with open(filename, "wb") as f: # pickle.dump(project_data, f) # QMessageBox.information(self, "Success", f"Project saved to:\n{filename}") # except Exception as e: # if not onCrash: # QMessageBox.critical(self, "Error", f"Failed to save project:\n{e}") # def load_project(self): # filename, _ = QFileDialog.getOpenFileName( # self, "Load Project", "", "SPARK Project (*.spark)" # ) # if not filename: # return # try: # with open(filename, "rb") as f: # data = pickle.load(f) # self.raw_haemo_dict = data.get("raw_haemo_dict", {}) # self.epochs_dict = data.get("epochs_dict", {}) # self.fig_bytes_dict = data.get("fig_bytes_dict", {}) # self.cha_dict = data.get("cha_dict", {}) # self.contrast_results_dict = data.get("contrast_results_dict", {}) # self.df_ind_dict = data.get("df_ind_dict", {}) # self.design_matrix_dict = data.get("design_matrix_dict", {}) # self.age_dict = data.get("age_dict", {}) # self.gender_dict = data.get("gender_dict", {}) # self.group_dict = data.get("group_dict", {}) # self.valid_dict = data.get("valid_dict", {}) # project_dir = Path(filename).parent # # Convert saved relative paths to absolute paths # file_list = [str((project_dir / Path(rel_path)).resolve()) for rel_path in data["file_list"]] # # Also resolve progress_states with updated paths # raw_progress = data.get("progress_states", {}) # 
progress_states = { # str((project_dir / Path(rel_path)).resolve()): step # for rel_path, step in raw_progress.items() # } # self.show_files_as_bubbles_from_list(file_list, progress_states, filename) # # Re-enable buttons # # self.button1.setVisible(True) # self.button3.setVisible(True) # QMessageBox.information(self, "Loaded", f"Project loaded from:\n{filename}") # except Exception as e: # QMessageBox.critical(self, "Error", f"Failed to load project:\n{e}") def placeholder(self): QMessageBox.information(self, "Placeholder", "This feature is not implemented yet.") def save_metadata(self, file_path): if not file_path: return self.file_metadata[file_path] = { key: field.text() for key, field in self.meta_fields.items() } def get_all_metadata(self): # First, make sure current file's edits are saved for field in self.meta_fields.values(): field.clearFocus() # Save current file's metadata if self.current_file: self.save_metadata(self.current_file) return self.file_metadata def cancel_task(self): self.button1.clicked.disconnect(self.cancel_task) self.button1.setText("Stopping...") if hasattr(self, "result_process") and self.result_process.is_alive(): parent = psutil.Process(self.result_process.pid) children = parent.children(recursive=True) for child in children: try: child.kill() except psutil.NoSuchProcess: pass self.result_process.terminate() self.result_process.join() if hasattr(self, "result_timer") and self.result_timer.isActive(): self.result_timer.stop() # if hasattr(self, "result_process") and self.result_process.is_alive(): # self.result_process.terminate() # Forcefully terminate the process # self.result_process.join() # Wait for it to properly close # # Stop the QTimer if running # if hasattr(self, "result_timer") and self.result_timer.isActive(): # self.result_timer.stop() self.statusbar.showMessage("Processing cancelled.") self.button1.clicked.connect(self.on_run_task) self.button1.setText("Process") '''MODULE FILE''' def on_run_task(self): 
self.button1.clicked.disconnect(self.on_run_task) self.button1.setText("Cancel") self.button1.clicked.connect(self.cancel_task) # Clear previous processing widgets (if any) # for obs_id, widgets in getattr(self, "processing_widgets", {}).items(): # for w in widgets.values(): # if isinstance(w, QWidget): # w.deleteLater() self.clear_bubble_widgets() self.processing_widgets = {} self.processing_threads = [] self.output_dir = self.get_output_directory() if not self.output_dir: # Revert button state if user cancels self.button1.clicked.disconnect(self.cancel_task) self.button1.setText("Process") self.button1.clicked.connect(self.on_run_task) self.statusBar().showMessage("Processing cancelled by user.") return selected_files_to_process = [] # 2️⃣ Iterate through all participant/camera combinations to find what to process # We must iterate over all entries in self.selection_widgets (the structure built by build_preview_grid) for obs_id, cam_dropdowns in self.selection_widgets.items(): if obs_id == "participant_selection": continue # skip the participant-level dropdowns for cam_id, dropdown in cam_dropdowns.items(): selected_hand_idx = dropdown.currentData() # Check if the selection is NOT 'Skip' (-1) if selected_hand_idx != -1: state_data = self.camera_state.get((obs_id, cam_id)) if state_data: selected_files_to_process.append({ "obs_id": obs_id, "cam_id": cam_id, "hand_idx": selected_hand_idx, "path": state_data.get("path", "Unknown") }) if not selected_files_to_process: self.statusBar().showMessage("No files selected for processing (use dropdowns).") # Revert button state self.button1.clicked.disconnect(self.cancel_task) self.button1.setText("Process") self.button1.clicked.connect(self.on_run_task) return self.statusBar().showMessage(f"Starting to process {len(selected_files_to_process)} files...") QApplication.processEvents() # Refresh GUI before starting threads for file_info in selected_files_to_process: obs_id = file_info["obs_id"] cam_id = file_info["cam_id"] 
selected_hand_idx = file_info["hand_idx"] # THIS is fixed per iteration # ... (thread logic setup) ... # --- PROGRESS ROW UI SETUP (50% Label | 40% Bar | 10% Time) --- row = QWidget() row_layout = QHBoxLayout(row) row_layout.setContentsMargins(0, 0, 0, 0) # LEFT SIDE (50%): Filename Label filename = os.path.basename(file_info["path"]) filename_label = QLabel(f"P {obs_id}/{cam_id}: {filename}") filename_label.setWordWrap(True) row_layout.addWidget(filename_label, stretch=5) # 5 units = 50% # MIDDLE (40%): Progress Bar progress_bar = QProgressBar() progress_bar.setRange(0, 100) row_layout.addWidget(progress_bar, stretch=4) # 4 units = 40% # RIGHT SIDE (10%): Time Indicator Label time_label = QLabel("00:00/00:00") # Default placeholder time_label.setAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter) row_layout.addWidget(time_label, stretch=1) # 1 unit = 10% # Add the combined row to the bubble layout area row_index = self.bubble_layout.rowCount() self.bubble_layout.addWidget(row, row_index, 0, 1, 1) # Save widgets for thread updates state_key = (obs_id, cam_id) self.processing_widgets[state_key] = { "label": filename_label, "progress": progress_bar, # 👇 NEW: Save the time label reference "time_label": time_label } # --- THREAD START --- output_csv = f"{obs_id}_{cam_id}_processed.csv" # ⚠️ Ensure ParticipantProcessor is defined to accept self.observations_root processor = ParticipantProcessor( obs_id=obs_id, boris_json=self.boris, selected_cam_id=cam_id, selected_hand_idx=selected_hand_idx, observations_root=self.observations_root, # CRITICAL PATH FIX hide_preview=False, # Assuming default for now output_csv=output_csv, output_dir=self.output_dir, initial_wrists=self.camera_state[(obs_id, cam_id)].get("initial_wrists") ) # Connect signals to the UI elements processor.progress_updated.connect(progress_bar.setValue) processor.finished_processing.connect( lambda obs_id=obs_id, cam_id=cam_id: self.on_processing_finished(obs_id, cam_id) ) 
processor.time_updated.connect(time_label.setText) # Connects to THIS time_label processor.start() self.processing_threads.append(processor) # Set the vertical stretch to push all progress bars to the top num_rows = self.bubble_layout.rowCount() self.bubble_layout.setRowStretch(num_rows, 1) def check_for_pipeline_results(self): return def on_processing_finished(self, obs_id, cam_id): state_key = (obs_id, cam_id) if state_key in self.processing_widgets: widgets = self.processing_widgets[state_key] widgets["progress"].setValue(100) label = self.processing_widgets[state_key]["label"] label.setText(f"P {obs_id}/{cam_id}: ✅ COMPLETE") # Check if all threads are finished to revert the main button if not any(t.isRunning() for t in self.processing_threads): self.cancel_task() # Reverts the button and cleans up self.statusBar().showMessage("All processing complete!") def cancel_task(self): """Stops all active processing threads and reverts UI to 'Process' state.""" if not self.processing_threads: return self.statusBar().showMessage("Cancelling processing... Please wait.") # 1. Stop all threads gracefully for thread in self.processing_threads: if thread.isRunning(): # Assuming ParticipantProcessor has a .stop() method thread.stop() thread.wait() # Wait for thread to exit its loop # 2. Clean up threads and widgets self.processing_threads.clear() # 3. Revert button and status bar self.button1.clicked.disconnect(self.cancel_task) self.button1.setText("Process") self.button1.clicked.connect(self.on_run_task) # 4. Clear the progress bars display #self.clear_bubble_widgets() self.statusBar().showMessage("Processing cancelled. Ready to load or process.") def show_error_popup(self, title, error_message, traceback_str=""): msgbox = QMessageBox(self) msgbox.setIcon(QMessageBox.Warning) msgbox.setWindowTitle("Warning - SPARKS") message = ( f"SPARKS has encountered an error processing the file {title}.

" "This error was likely due to incorrect parameters on the right side of the screen and not an error with your data. " "Processing of the remaining files continues in the background and this participant will be ignored in the analysis. " "If you think the parameters on the right side are correct for your data, raise an issue here.

" f"Error message: {error_message}" ) msgbox.setTextFormat(Qt.TextFormat.RichText) msgbox.setText(message) msgbox.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction) # Add traceback to detailed text if traceback_str: msgbox.setDetailedText(traceback_str) msgbox.setStandardButtons(QMessageBox.Ok) msgbox.exec_() def cleanup_after_process(self): if hasattr(self, 'result_process'): self.result_process.join(timeout=0) if self.result_process.is_alive(): self.result_process.terminate() self.result_process.join() if hasattr(self, 'result_queue'): if 'AutoProxy' in repr(self.result_queue): pass else: self.result_queue.close() self.result_queue.join_thread() if hasattr(self, 'progress_queue'): if 'AutoProxy' in repr(self.progress_queue): pass else: self.progress_queue.close() self.progress_queue.join_thread() # Shutdown manager to kill its server process and clean up if hasattr(self, 'manager'): self.manager.shutdown() '''UPDATER''' def manual_check_for_updates(self): self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix) self.local_check_thread.pending_update_found.connect(self.on_pending_update_found) self.local_check_thread.no_pending_update.connect(self.on_no_pending_update) self.local_check_thread.start() def on_pending_update_found(self, version, folder_path): self.statusBar().showMessage(f"Pending update found: version {version}") self.pending_update_version = version self.pending_update_path = folder_path self.show_pending_update_popup() def on_no_pending_update(self): # No pending update found locally, start server check directly self.statusBar().showMessage("No pending local update found. 
Checking server...") self.start_update_check_thread() def show_pending_update_popup(self): msg_box = QMessageBox(self) msg_box.setWindowTitle("Pending Update Found") msg_box.setText(f"A previously downloaded update (version {self.pending_update_version}) is available at:\n{self.pending_update_path}\nWould you like to install it now?") install_now_button = msg_box.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole) install_later_button = msg_box.addButton("Install Later", QMessageBox.ButtonRole.RejectRole) msg_box.exec() if msg_box.clickedButton() == install_now_button: self.install_update(self.pending_update_path) else: self.statusBar().showMessage("Pending update available. Install later.") # After user dismisses, still check the server for new updates self.start_update_check_thread() def start_update_check_thread(self): self.check_thread = UpdateCheckThread() self.check_thread.download_requested.connect(self.on_server_update_requested) self.check_thread.no_update_available.connect(self.on_server_no_update) self.check_thread.error_occurred.connect(self.on_error) self.check_thread.start() def on_server_no_update(self): self.statusBar().showMessage("No new updates found on server.", 5000) def on_server_update_requested(self, download_url, latest_version): if self.pending_update_version: cmp = self.version_compare(latest_version, self.pending_update_version) if cmp > 0: # Server version is newer than pending update self.statusBar().showMessage(f"Newer version {latest_version} available on server. 
Removing old pending update...") try: shutil.rmtree(self.pending_update_path) self.statusBar().showMessage(f"Deleted old update folder: {self.pending_update_path}") except Exception as e: self.statusBar().showMessage(f"Failed to delete old update folder: {e}") # Clear pending update info so new download proceeds self.pending_update_version = None self.pending_update_path = None # Download the new update self.download_update(download_url, latest_version) elif cmp == 0: # Versions equal, no download needed self.statusBar().showMessage(f"Pending update version {self.pending_update_version} is already latest. No download needed.") else: # Server version older than pending? Unlikely but just keep pending update self.statusBar().showMessage(f"Pending update version {self.pending_update_version} is newer than server version. No action.") else: # No pending update, just download self.download_update(download_url, latest_version) def download_update(self, download_url, latest_version): self.statusBar().showMessage("Downloading update...") self.download_thread = UpdateDownloadThread(download_url, latest_version) self.download_thread.update_ready.connect(self.on_update_ready) self.download_thread.error_occurred.connect(self.on_error) self.download_thread.start() def on_update_ready(self, latest_version, extract_folder): self.statusBar().showMessage("Update downloaded and extracted.") msg_box = QMessageBox(self) msg_box.setWindowTitle("Update Ready") msg_box.setText(f"Version {latest_version} has been downloaded and extracted to:\n{extract_folder}\nWould you like to install it now?") install_now_button = msg_box.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole) install_later_button = msg_box.addButton("Install Later", QMessageBox.ButtonRole.RejectRole) msg_box.exec() if msg_box.clickedButton() == install_now_button: self.install_update(extract_folder) else: self.statusBar().showMessage("Update ready. 
Install later.") def install_update(self, extract_folder): # Path to updater executable if PLATFORM_NAME == 'windows': updater_path = os.path.join(os.getcwd(), "sparks_updater.exe") elif PLATFORM_NAME == 'darwin': if getattr(sys, 'frozen', False): updater_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks_updater.app") else: updater_path = os.path.join(os.getcwd(), "../sparks_updater.app") elif PLATFORM_NAME == 'linux': updater_path = os.path.join(os.getcwd(), "sparks_updater") else: updater_path = os.getcwd() if not os.path.exists(updater_path): QMessageBox.critical(self, "Error", f"Updater not found at:\n{updater_path}. The absolute path was {os.path.abspath(updater_path)}") return # Launch updater with extracted folder path as argument try: # Pass current app's executable path for updater to relaunch main_app_executable = os.path.abspath(sys.argv[0]) print(f'Launching updater with: "{updater_path}" "{extract_folder}" "{main_app_executable}"') if PLATFORM_NAME == 'darwin': subprocess.Popen(['open', updater_path, '--args', extract_folder, main_app_executable]) else: subprocess.Popen([updater_path, f'{extract_folder}', f'{main_app_executable}'], cwd=os.path.dirname(updater_path)) # Close the current app so updater can replace files sys.exit(0) except Exception as e: QMessageBox.critical(self, "Error", f"[Updater Launch Failed]\n{str(e)}\n{traceback.format_exc()}") def on_error(self, message): # print(f"Error: {message}") self.statusBar().showMessage(f"Error occurred during update process. {message}") def version_compare(self, v1, v2): def normalize(v): return [int(x) for x in v.split(".")] return (normalize(v1) > normalize(v2)) - (normalize(v1) < normalize(v2)) def closeEvent(self, event): # Gracefully shut down multiprocessing children print("Window is closing. 
Cleaning up...") if hasattr(self, 'manager'): self.manager.shutdown() for child in self.findChildren(QWidget): if child is not self and child.isVisible(): child.close() kill_child_processes() event.accept() def wait_for_process_to_exit(process_name, timeout=10): """ Waits for a process with the specified name to exit within a timeout period. Args: process_name (str): Name (or part of the name) of the process to wait for. timeout (int, optional): Maximum time to wait in seconds. Defaults to 10. Returns: bool: True if the process exited before the timeout, False otherwise. """ print(f"Waiting for {process_name} to exit...") deadline = time.time() + timeout while time.time() < deadline: still_running = False for proc in psutil.process_iter(['name']): try: if proc.info['name'] and process_name.lower() in proc.info['name'].lower(): still_running = True print(f"Still running: {proc.info['name']} (PID: {proc.pid})") break except (psutil.NoSuchProcess, psutil.AccessDenied): continue if not still_running: print(f"{process_name} has exited.") return True time.sleep(0.5) print(f"{process_name} did not exit in time.") return False def finish_update_if_needed(): """ Completes a pending application update if '--finish-update' is present in the command-line arguments. """ if "--finish-update" in sys.argv: print("Finishing update...") if PLATFORM_NAME == 'darwin': app_dir = '/tmp/sparkstempupdate' else: app_dir = os.getcwd() # 1. Find update folder update_folder = None for entry in os.listdir(app_dir): entry_path = os.path.join(app_dir, entry) if os.path.isdir(entry_path) and entry.startswith("sparks-") and entry.endswith("-" + PLATFORM_NAME): update_folder = os.path.join(app_dir, entry) break if update_folder is None: print("No update folder found. Skipping update steps.") return if PLATFORM_NAME == 'darwin': update_folder = os.path.join(update_folder, "sparks-darwin") # 2. 
Wait for sparks_updater to exit print("Waiting for sparks_updater to exit...") for proc in psutil.process_iter(['pid', 'name']): if proc.info['name'] and "sparks_updater" in proc.info['name'].lower(): try: proc.wait(timeout=5) except psutil.TimeoutExpired: print("Force killing lingering sparks_updater") proc.kill() # 3. Replace the updater if PLATFORM_NAME == 'windows': new_updater = os.path.join(update_folder, "sparks_updater.exe") dest_updater = os.path.join(app_dir, "sparks_updater.exe") elif PLATFORM_NAME == 'darwin': new_updater = os.path.join(update_folder, "sparks_updater.app") dest_updater = os.path.abspath(os.path.join(sys.executable, "../../../../sparks_updater.app")) elif PLATFORM_NAME == 'linux': new_updater = os.path.join(update_folder, "sparks_updater") dest_updater = os.path.join(app_dir, "sparks_updater") else: print("No platform??") new_updater = os.getcwd() dest_updater = os.getcwd() print(f"New updater is {new_updater}") print(f"Dest updater is {dest_updater}") print("Writable?", os.access(dest_updater, os.W_OK)) print("Executable path:", sys.executable) print("Trying to copy:", new_updater, "->", dest_updater) if os.path.exists(new_updater): try: if os.path.exists(dest_updater): if PLATFORM_NAME == 'darwin': try: if os.path.isdir(dest_updater): shutil.rmtree(dest_updater) print(f"Deleted directory: {dest_updater}") else: os.remove(dest_updater) print(f"Deleted file: {dest_updater}") except Exception as e: print(f"Error deleting {dest_updater}: {e}") else: os.remove(dest_updater) if PLATFORM_NAME == 'darwin': wait_for_process_to_exit("sparks_updater", timeout=10) subprocess.check_call(["ditto", new_updater, dest_updater]) else: shutil.copy2(new_updater, dest_updater) if PLATFORM_NAME in ('linux', 'darwin'): os.chmod(dest_updater, 0o755) if PLATFORM_NAME == 'darwin': remove_quarantine(dest_updater) print("sparks_updater replaced.") except Exception as e: print(f"Failed to replace sparks_updater: {e}") # 4. 
Delete the update folder try: if PLATFORM_NAME == 'darwin': shutil.rmtree(app_dir) else: shutil.rmtree(update_folder) except Exception as e: print(f"Failed to delete update folder: {e}") QMessageBox.information(None, "Update Complete", "The application has been successfully updated.") sys.argv.remove("--finish-update") def remove_quarantine(app_path): """ Removes the macOS quarantine attribute from the specified application path. """ script = f''' do shell script "xattr -d -r com.apple.quarantine {shlex.quote(app_path)}" with administrator privileges with prompt "SPARKS needs privileges to finish the update. (2/2)" ''' try: subprocess.run(['osascript', '-e', script], check=True) print("✅ Quarantine attribute removed.") except subprocess.CalledProcessError as e: print("❌ Failed to remove quarantine attribute.") print(e) def resource_path(relative_path): """ Get absolute path to resource regardless of running directly or packaged using PyInstaller """ if hasattr(sys, '_MEIPASS'): # PyInstaller bundle path base_path = sys._MEIPASS else: base_path = os.path.dirname(os.path.abspath(__file__)) return os.path.join(base_path, relative_path) def kill_child_processes(): """ Goodbye children """ try: parent = psutil.Process(os.getpid()) children = parent.children(recursive=True) for child in children: try: child.kill() except psutil.NoSuchProcess: pass psutil.wait_procs(children, timeout=5) except Exception as e: print(f"Error killing child processes: {e}") def exception_hook(exc_type, exc_value, exc_traceback): """ Method that will display a popup when the program hard crashes containg what went wrong """ error_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) print(error_msg) # also print to console kill_child_processes() # Show error message box # Make sure QApplication exists (or create a minimal one) app = QApplication.instance() if app is None: app = QApplication(sys.argv) show_critical_error(error_msg) # Exit the app after user acknowledges 
sys.exit(1) def show_critical_error(error_msg): msg_box = QMessageBox() msg_box.setIcon(QMessageBox.Icon.Critical) msg_box.setWindowTitle("Something went wrong!") if PLATFORM_NAME == "darwin": log_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks.log") log_path2 = os.path.join(os.path.dirname(sys.executable), "../../../sparks_error.log") save_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks_autosave.spark") else: log_path = os.path.join(os.getcwd(), "sparks.log") log_path2 = os.path.join(os.getcwd(), "sparks_error.log") save_path = os.path.join(os.getcwd(), "sparks_autosave.spark") shutil.copy(log_path, log_path2) log_path2 = Path(log_path2).absolute().as_posix() autosave_path = Path(save_path).absolute().as_posix() log_link = f"file:///{log_path2}" autosave_link = f"file:///{autosave_path}" window.save_project(True) message = ( "SPARKS has encountered an unrecoverable error and needs to close.

" f"We are sorry for the inconvenience. An autosave was attempted to be saved to {autosave_path}, but it may not have been saved. " "If the file was saved, it still may not be intact, openable, or contain the correct data. Use the autosave at your discretion.

" "This unrecoverable error was likely due to an error with SPARKS and not your data.
" f"Please raise an issue here and attach the error file located at {log_path2}

" f"
{error_msg}
" ) msg_box.setTextFormat(Qt.TextFormat.RichText) msg_box.setText(message) msg_box.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction) msg_box.setStandardButtons(QMessageBox.StandardButton.Ok) msg_box.exec() if __name__ == "__main__": # Redirect exceptions to the popup window sys.excepthook = exception_hook # Set up application logging if PLATFORM_NAME == "darwin": log_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks.log") else: log_path = os.path.join(os.getcwd(), "sparks.log") try: os.remove(log_path) except: pass sys.stdout = open(log_path, "a", buffering=1) sys.stderr = sys.stdout print(f"\n=== App started at {datetime.now()} ===\n") freeze_support() # Required for PyInstaller + multiprocessing # Only run GUI in the main process if current_process().name == 'MainProcess': app = QApplication(sys.argv) finish_update_if_needed() window = MainApplication() if PLATFORM_NAME == "darwin": app.setWindowIcon(QIcon(resource_path("icons/main.icns"))) window.setWindowIcon(QIcon(resource_path("icons/main.icns"))) else: app.setWindowIcon(QIcon(resource_path("icons/main.ico"))) window.setWindowIcon(QIcon(resource_path("icons/main.ico"))) window.show() sys.exit(app.exec()) # Not 4000 lines yay!