Files
sparks/main.py
2026-01-05 10:36:19 -08:00

3223 lines
122 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Filename: main.py
Description: SPARKS main executable
Author: Tyler de Zeeuw
License: GPL-3.0
"""
# Built-in imports
import os
import re
import csv
import sys
import math
import json
import time
import shlex
import pickle
import shutil
import zipfile
import platform
import traceback
import subprocess
from itertools import product
from pathlib import Path, PurePosixPath
from datetime import datetime
from multiprocessing import Process, current_process, freeze_support, Manager
import cv2
import numpy as np
import pandas as pd
# External library imports
import psutil
import requests
import torch
import torch.nn as nn
import mediapipe as mp
from torch.utils.data import TensorDataset, DataLoader
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
from PySide6.QtWidgets import (
QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox, QGridLayout,
QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView, QMenu, QProgressBar, QCheckBox
)
from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize, QPoint
from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel, QStandardItem, QImage
from PySide6.QtSvgWidgets import QSvgWidget # needed to show svgs when app is not frozen
# Version of this build; the update checkers compare release tags against it.
CURRENT_VERSION = "1.0.0"

# Release-listing endpoints, tried in this order by the update checker.
API_URL = "https://git.research.dezeeuw.ca/api/v1/repos/tyler/sparks/releases"
API_URL_SECONDARY = "https://git.research2.dezeeuw.ca/api/v1/repos/tyler/sparks/releases"

# Lower-cased OS name (platform.system(), e.g. "darwin"); used to choose
# platform-specific release assets and temporary update folders.
PLATFORM_NAME = platform.system().lower()

# Selectable parameters on the right side of the window. Each section is
# rendered by ParamSection; every parameter needs name/default/type/help.
SECTIONS = [
    dict(
        title="Parameter 1",
        params=[
            dict(name="Number 1", default=True, type=bool, help="N/A"),
            dict(name="Number 2", default=True, type=bool, help="N/A"),
        ],
    ),
]
class TrainModelThread(QThread):
    """
    Background thread that trains a windowed LSTM reach classifier.

    Each CSV in ``csv_paths`` must contain landmark feature columns whose names
    start with ``"lm"`` (the column list is taken from the first CSV) and the
    label columns ``"reach_active"`` and ``"reach_before_contact"``. Every file is
    cut into overlapping fixed-length windows, a small hyperparameter sweep is
    run, and the weights with the best F1 are written to ``model_save_path``
    with ``torch.save(model.state_dict())``.

    Signals:
        update(str): progress messages for the GUI.
        finished(str): final summary, or a description of why training failed.
    """

    update = Signal(str)  # emits progress messages for the GUI
    finished = Signal(str)

    # Per-frame target columns, in model output order.
    _LABEL_COLS = ["reach_active", "reach_before_contact"]

    def __init__(self, csv_paths, model_save_path, parent=None):
        super().__init__(parent)
        self.csv_paths = csv_paths
        self.model_save_path = model_save_path

    def _load_windows(self, window_length, step_length, label_fraction=0.3):
        """
        Read every CSV and slice it into overlapping windows.

        Windows are built per file, so no window straddles the boundary between
        two unrelated recordings.

        Args:
            window_length (int): frames per window.
            step_length (int): frames between consecutive window starts.
            label_fraction (float): a window is positive for a target when at
                least this fraction of its frames are positive.

        Returns:
            tuple[np.ndarray, np.ndarray]: ``(X, y)`` with shapes
            ``(N, window_length, n_features)`` and ``(N, 2)``; ``N`` may be 0.
        """
        feature_cols = None
        X_windows, y_windows = [], []
        for path in self.csv_paths:
            df = pd.read_csv(path)
            if feature_cols is None:
                feature_cols = [c for c in df.columns if c.startswith("lm")]
            # Fill missing landmark values (NaN) from neighbouring frames.
            X = df[feature_cols].ffill().bfill().values
            y = df[self._LABEL_COLS].values
            for start in range(0, len(X) - window_length + 1, step_length):
                end = start + window_length
                y_win = y[start:end]
                X_windows.append(X[start:end])
                y_windows.append((y_win.sum(axis=0) / len(y_win) >= label_fraction).astype(int))
        return np.array(X_windows), np.array(y_windows)

    @staticmethod
    def _pos_weights(y_windows):
        """
        Return the per-target ``pos_weight`` tensor for ``BCEWithLogitsLoss``.

        With sklearn's 'balanced' weights, ``cw[1] / cw[0]`` equals
        ``#negatives / #positives``. A target whose windows all share one class
        would make ``compute_class_weight`` raise, so it gets a neutral 1.0.
        """
        weights = []
        for i in range(y_windows.shape[1]):
            column = y_windows[:, i]
            if np.unique(column).size < 2:
                weights.append(1.0)
                continue
            cw = compute_class_weight("balanced", classes=np.array([0, 1]), y=column)
            weights.append(cw[1] / cw[0])
        return torch.tensor(weights, dtype=torch.float32)

    def run(self):
        """Thread entry point: train, and always report the outcome via ``finished``."""
        try:
            self._train()
        except Exception as e:
            # Without this, an exception (bad CSV, missing column, ...) would end
            # the thread without ``finished`` ever being emitted.
            traceback.print_exc()
            self.finished.emit(f"Training failed: {e}")

    def _train(self):
        """Window the data, run the hyperparameter sweep and save the best model."""
        FPS, WINDOW_SEC, STEP_SEC, THRESHOLD, EPOCHS = 60, 0.4, 0.2, 0.5, 150
        window_length = int(FPS * WINDOW_SEC)
        step_length = int(FPS * STEP_SEC)

        X_windows, y_windows = self._load_windows(window_length, step_length)
        if len(X_windows) == 0:
            self.finished.emit(
                f"Training aborted: no CSV has at least {window_length} frames."
            )
            return

        pos_weight = self._pos_weights(y_windows)
        X_tensor = torch.tensor(X_windows, dtype=torch.float32)  # (N, window_length, features)
        y_tensor = torch.tensor(y_windows, dtype=torch.float32)  # (N, 2)

        class WindowLSTM(nn.Module):
            """LSTM over one window; classifies from the last layer's final hidden state."""

            def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
                super().__init__()
                self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True,
                                    bidirectional=bidirectional)
                self.fc = nn.Linear(hidden_size * (2 if bidirectional else 1), output_size)

            def forward(self, x):
                _, (h, _) = self.lstm(x)
                if self.lstm.bidirectional:
                    # Concatenate the forward and backward final states.
                    h = torch.cat([h[-2], h[-1]], dim=1)
                else:
                    h = h[-1]
                return self.fc(h)

        # Hyperparameter sweep (currently a single configuration).
        HIDDEN_SIZES = [64]
        BATCH_SIZES = [32]
        LRs = [0.0005]
        OPTIMIZERS = ['adam']
        BIDIRECTIONAL = [False]

        dataset = TensorDataset(X_tensor, y_tensor)
        best_f1 = 0
        best_config = None
        for hidden_size, batch_size, lr, opt_name, bi in product(
            HIDDEN_SIZES, BATCH_SIZES, LRs, OPTIMIZERS, BIDIRECTIONAL
        ):
            self.update.emit(f"Training model: hidden={hidden_size}, batch={batch_size}, lr={lr}, opt={opt_name}, bidir={bi}")
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
            model = WindowLSTM(input_size=X_windows.shape[2], hidden_size=hidden_size, bidirectional=bi)
            if opt_name == 'adam':
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                optimizer = torch.optim.SGD(model.parameters(), lr=lr)
            criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

            for epoch in range(EPOCHS):
                # Training
                model.train()
                epoch_loss = 0
                for xb, yb in loader:
                    optimizer.zero_grad()
                    loss = criterion(model(xb), yb)
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()

                # Evaluation on the training windows themselves (there is no
                # held-out split), so this F1 measures fit, not generalisation.
                model.eval()
                with torch.no_grad():
                    preds = (torch.sigmoid(model(X_tensor)) > THRESHOLD).int()
                y_true = y_tensor.int()
                tp = ((preds == 1) & (y_true == 1)).sum().item()
                fp = ((preds == 1) & (y_true == 0)).sum().item()
                fn = ((preds == 0) & (y_true == 1)).sum().item()
                precision = tp / (tp + fp + 1e-8)
                recall = tp / (tp + fn + 1e-8)
                f1 = 2 * precision * recall / (precision + recall + 1e-8)
                self.update.emit(f"Epoch {epoch+1}/{EPOCHS} Loss={epoch_loss:.4f} "
                                 f"Precision={precision:.3f} Recall={recall:.3f} F1={f1:.3f}")

                if f1 > best_f1:
                    best_f1 = f1
                    best_config = (hidden_size, batch_size, lr, opt_name, bi)
                    torch.save(model.state_dict(), self.model_save_path)
                    self.update.emit(f"New best model saved to {os.path.basename(self.model_save_path)}!")

        if best_config is None:
            self.finished.emit("Training complete, but no configuration reached F1 > 0; no model was saved.")
        else:
            self.finished.emit(f"Training complete!\nBest config: {best_config}\nF1={best_f1:.3f}")
class TestModelThread(QThread):
    """
    Play back a video with live reach predictions from a trained window LSTM.

    MediaPipe Hands runs on every frame of the FIRST path in ``video_paths``.
    The right hand's landmarks (63 values per frame) go into a sliding buffer of
    ``window_frames`` frames. Once the buffer is full it is fed to the model, and
    both predictions are drawn on the frame. Frames are shown in an OpenCV
    window; pressing 'q' stops playback. ``finished`` is emitted however
    playback ends.

    NOTE(review): OpenCV HighGUI calls (namedWindow/imshow/waitKey) run on this
    worker thread; some platforms only allow GUI calls on the main thread.
    Confirm on every target OS.
    """

    update_frame = Signal(np.ndarray)  # emit frames with overlay (declared; not emitted by run())
    finished = Signal()

    def __init__(self, video_paths, model_path, fps=60, window_sec=0.4, step_sec=0.2, threshold=0.5):
        super().__init__()
        self.video_paths = video_paths
        self.fps = fps
        self.window_sec = window_sec
        self.step_sec = step_sec
        self.threshold = threshold
        self.model_path = model_path
        self.window_frames = int(fps * window_sec)

    @staticmethod
    def _interpolate_gaps(buffer):
        """
        Fill NaN entries column by column with linear interpolation over the buffer.

        Leading and trailing gaps take the nearest valid value (``np.interp``
        clamps at the ends). Columns with no valid value stay NaN. Returns a
        new list of lists.
        """
        arr = np.array(buffer, dtype=np.float32)
        for c in range(arr.shape[1]):
            valid = ~np.isnan(arr[:, c])
            if valid.any():
                arr[:, c][~valid] = np.interp(
                    np.flatnonzero(~valid),
                    np.flatnonzero(valid),
                    arr[:, c][valid],
                )
        return arr.tolist()

    def run(self):
        """Thread entry point; emits ``finished`` even if playback raises."""
        try:
            self._play()
        finally:
            self.finished.emit()

    def _play(self):
        """Load the model, open the video and run the detection/display loop."""

        class WindowLSTM(torch.nn.Module):
            def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size, hidden_size, batch_first=True, bidirectional=bidirectional
                )
                self.fc = torch.nn.Linear(hidden_size * (2 if bidirectional else 1), output_size)

            def forward(self, x):
                _, (h, _) = self.lstm(x)
                if self.lstm.bidirectional:
                    h = torch.cat([h[-2], h[-1]], dim=1)
                else:
                    h = h[-1]
                return self.fc(h)

        # MATCH THE MODEL CONFIG YOU TRAINED WITH
        model = WindowLSTM(input_size=63, hidden_size=64, bidirectional=False)
        model.load_state_dict(torch.load(self.model_path, map_location="cpu"))
        model.eval()
        print("Loaded model.")

        # ------------------------------ MediaPipe ------------------------------
        mp_hands = mp.solutions.hands
        mp_drawing = mp.solutions.drawing_utils
        mp_drawing_styles = mp.solutions.drawing_styles
        hands = mp_hands.Hands(
            static_image_mode=False,
            max_num_hands=2,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
            model_complexity=0
        )

        # ------------------------------ Open video ------------------------------
        cap = cv2.VideoCapture(self.video_paths[0])
        window_name = "Hand Tracking Test - Press 'q' to quit"
        window_created = False
        try:
            original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            if not cap.isOpened() or original_width <= 0:
                # Avoids a ZeroDivisionError in the aspect-ratio computation below.
                print(f"Could not open video: {self.video_paths[0]}")
                return
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            print(f"Original video dimensions: {original_width} x {original_height}")
            print(f"Total frames in video: {total_frames}")

            # One display size, used for both the window and the resized frames,
            # so the aspect ratio is preserved.
            display_width = 800
            aspect_ratio = original_height / original_width
            display_height = int(display_width * aspect_ratio)
            print(f"Display dimensions: {display_width} x {display_height}")

            cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
            cv2.resizeWindow(window_name, display_width, display_height)
            window_created = True

            window_buffer = []  # landmark rows of the most recent `window_frames` frames
            frame_index = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                h, w = frame.shape[:2]

                # detect hands
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = hands.process(rgb)

                # find RIGHT hand (first one reported)
                features = []
                right_found = False
                if results.multi_hand_landmarks and results.multi_handedness:
                    for i, handness in enumerate(results.multi_handedness):
                        if handness.classification[0].label != "Right":
                            continue
                        lm = results.multi_hand_landmarks[i]
                        right_found = True
                        for p in lm.landmark:
                            features += [p.x, p.y, p.z]
                        mp_drawing.draw_landmarks(
                            frame,
                            lm,
                            mp_hands.HAND_CONNECTIONS,
                            mp_drawing_styles.get_default_hand_landmarks_style(),
                            mp_drawing_styles.get_default_hand_connections_style(),
                        )
                        break
                if not right_found:
                    features = [math.nan] * 63

                # Maintain window buffer: interpolate gaps (before trimming, as in
                # the original order), then keep only the last `window_frames` rows.
                window_buffer.append(features)
                window_buffer = self._interpolate_gaps(window_buffer)
                if len(window_buffer) > self.window_frames:
                    window_buffer.pop(0)

                pred_active = 0
                pred_before_contact = 0

                # run inference when full window available
                if len(window_buffer) == self.window_frames:
                    x = torch.tensor(window_buffer, dtype=torch.float32).unsqueeze(0)
                    with torch.no_grad():
                        probs = torch.sigmoid(model(x)).squeeze(0).numpy()
                    pred_active = int(probs[0] > self.threshold)
                    pred_before_contact = int(probs[1] > self.threshold)
                    color_active = (0, 255, 0) if pred_active else (0, 0, 255)
                    color_before_contact = (255, 255, 0) if pred_before_contact else (0, 0, 128)

                    cv2.rectangle(frame, (0, 0), (w // 2, 40), color_active, -1)
                    cv2.putText(frame, f"Reach Active: {'YES' if pred_active else 'NO'} ({probs[0]:.2f})",
                                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
                    cv2.rectangle(frame, (w // 2, 0), (w, 40), color_before_contact, -1)
                    cv2.putText(frame, f"Reach Before Contact: {'YES' if pred_before_contact else 'NO'} ({probs[1]:.2f})",
                                (w // 2 + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
                else:
                    cv2.rectangle(frame, (0, 0), (w, 40), (0, 128, 255), -1)
                    cv2.putText(
                        frame,
                        f"Collecting window: {len(window_buffer)}/{self.window_frames}",
                        (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1.0,
                        (255, 255, 255),
                        2,
                    )

                info_text = [
                    f"Frame: {frame_index}",
                    f"Time: {frame_index / self.fps:.2f}s",
                    f"Reach Active: {'YES' if pred_active else 'NO'}",
                    f"Reach Before Contact: {'YES' if pred_before_contact else 'NO'}",
                    f"Right Hand: {'DETECTED' if right_found else 'NOT DETECTED'}",
                    "Press 'q' to quit",
                ]
                for i, text in enumerate(info_text):
                    cv2.putText(frame, text, (10, 60 + i * 25),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                # If predicted active, add highlight frame
                if pred_active or pred_before_contact:
                    cv2.rectangle(frame, (0, 0), (w, h), (0, 255, 0), 6)

                resized_frame = cv2.resize(frame, (display_width, display_height), interpolation=cv2.INTER_AREA)
                cv2.imshow(window_name, resized_frame)
                frame_index += 1
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
        finally:
            # Release the capture, the MediaPipe graph and the preview window.
            cap.release()
            hands.close()
            if window_created:
                cv2.destroyWindow(window_name)
                cv2.waitKey(1)  # let HighGUI process the close request
class TestModelThread2(QThread):
    """
    Offline counterpart of TestModelThread: score a landmark CSV frame by frame.

    Uses the same sliding-window buffering and in-window NaN interpolation as the
    live video test. The result is emitted through ``finished`` as a dict with
    keys ``"time"``, ``"active"``, ``"before"`` and ``"threshold"``, holding one
    score per CSV row. The first ``window_frames - 1`` rows are padded with 0.
    On failure it emits ``{"error": message}`` and stores the message in
    ``self.error``.
    """

    finished = Signal(object)

    def __init__(self, csv_path, model_path, fps=60, window_sec=0.4, threshold=0.5):
        super().__init__()
        self.csv_path = csv_path
        self.model_path = model_path
        self.fps = fps
        self.window_sec = window_sec
        self.window_frames = int(fps * window_sec)
        self.threshold = threshold

    def run(self):
        try:
            self._run_internal()
        except Exception as exc:
            traceback.print_exc()
            message = str(exc)
            self.error = message
            self.finished.emit({"error": message})

    def _run_internal(self):
        # ---- model: must mirror the architecture used when training ----
        class WindowLSTM(torch.nn.Module):
            def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size, hidden_size, batch_first=True, bidirectional=bidirectional
                )
                self.fc = torch.nn.Linear(hidden_size * (2 if bidirectional else 1), output_size)

            def forward(self, x):
                _, (h, _) = self.lstm(x)
                h = torch.cat([h[-2], h[-1]], dim=1) if self.lstm.bidirectional else h[-1]
                return self.fc(h)

        # These sizes must match the saved checkpoint (same values as TestModelThread).
        net = WindowLSTM(input_size=63, hidden_size=64, bidirectional=False)
        net.load_state_dict(torch.load(self.model_path, map_location="cpu"))
        net.eval()

        # ---- data: landmark columns (may contain NaN) and per-row timestamps ----
        table = pd.read_csv(self.csv_path)
        landmark_cols = [col for col in table.columns if col.startswith("lm")]
        raw_rows = table[landmark_cols].values
        if "time_sec" in table.columns:
            timestamps = table["time_sec"].values
        else:
            timestamps = np.arange(len(raw_rows)) / self.fps

        # ---- sliding-window inference, one score per row ----
        warmup = self.window_frames - 1
        buffer = []
        active_scores = []
        before_scores = []
        for idx, row in enumerate(raw_rows):
            if idx < warmup:
                # Not enough history yet: pad so outputs stay aligned with rows.
                active_scores.append(0)
                before_scores.append(0)

            buffer.append(row)
            if len(buffer) > self.window_frames:
                buffer.pop(0)

            # Interpolate NaN gaps inside the window; the filled values are kept
            # in the buffer for later windows, exactly as in the live test.
            block = np.array(buffer, dtype=np.float32)
            for col in range(block.shape[1]):
                present = ~np.isnan(block[:, col])
                if present.any():
                    block[:, col][~present] = np.interp(
                        np.flatnonzero(~present),
                        np.flatnonzero(present),
                        block[:, col][present],
                    )
            buffer = block.tolist()

            if len(buffer) == self.window_frames:
                window = torch.tensor(buffer, dtype=torch.float32).unsqueeze(0)
                with torch.no_grad():
                    scores = torch.sigmoid(net(window)).squeeze(0).numpy()
                active_scores.append(scores[0])
                before_scores.append(scores[1])

        self.finished.emit({
            "time": timestamps,
            "active": active_scores,
            "before": before_scores,
            "threshold": self.threshold,
        })
class TerminalWindow(QWidget):
    """
    Minimal in-app command console.

    Each line entered in the input box is echoed to the read-only output area
    and split on whitespace. The first word selects a handler from
    ``self.commands``; the remaining words are passed to it as positional
    arguments. A truthy return value is printed. Handler exceptions are shown
    as ``[Error] ...``.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("Terminal - SPARKS")

        self.output_area = QTextEdit()
        self.output_area.setReadOnly(True)
        self.input_line = QLineEdit()
        self.input_line.returnPressed.connect(self.handle_command)

        layout = QVBoxLayout()
        for widget in (self.output_area, self.input_line):
            layout.addWidget(widget)
        self.setLayout(layout)

        # Command name -> handler(*args) returning text to show (or None).
        self.commands = {"hello": self.cmd_hello, "help": self.cmd_help}

    def handle_command(self):
        entered = self.input_line.text()
        self.input_line.clear()
        self.output_area.append(f"> {entered}")

        tokens = entered.strip().split()
        if not tokens:
            return
        name, *args = tokens

        handler = self.commands.get(name)
        if not handler:
            self.output_area.append(f"[Unknown command] '{name}'")
            return
        try:
            reply = handler(*args)
            if reply:
                self.output_area.append(str(reply))
        except Exception as e:
            self.output_area.append(f"[Error] {e}")

    def cmd_hello(self, *args):
        return "Hello from the terminal!"

    def cmd_help(self, *args):
        names = ', '.join(self.commands.keys())
        return f"Available commands: {names}"
class UpdateDownloadThread(QThread):
    """
    Thread that downloads and extracts an update package and emits a signal on completion or error.

    On macOS the archive is downloaded into, and extracted under,
    ``/tmp/sparkstempupdate``. On other platforms the current working
    directory is used. After a successful extraction the archive is deleted,
    and failing to delete it is not fatal. On failure, the partial archive and
    any extraction folder this run created are removed. Otherwise they could
    later be picked up by the local pending-update scan.

    Args:
        download_url (str): URL of the update zip file to download.
        latest_version (str): Version string of the latest update.

    Signals:
        update_ready(str, str): ``(latest_version, extract_path)`` on success.
        error_occurred(str): error message on failure.
    """
    update_ready = Signal(str, str)
    error_occurred = Signal(str)

    def __init__(self, download_url, latest_version):
        super().__init__()
        self.download_url = download_url
        self.latest_version = latest_version

    @staticmethod
    def _discard_partial(archive_path, extract_dir):
        """Best-effort removal of a partial download and, if given, a partial extraction folder."""
        if os.path.isfile(archive_path):
            try:
                os.remove(archive_path)
            except OSError:
                pass
        if extract_dir and os.path.isdir(extract_dir):
            shutil.rmtree(extract_dir, ignore_errors=True)

    def run(self):
        from urllib.parse import urlparse

        # Use only the URL path so a query string or fragment never becomes part of the file name.
        local_filename = os.path.basename(urlparse(self.download_url).path)
        if not local_filename:
            self.error_occurred.emit(f"Cannot determine a file name from URL: {self.download_url}")
            return

        base_dir = '/tmp/sparkstempupdate' if PLATFORM_NAME == 'darwin' else os.getcwd()
        local_path = os.path.join(base_dir, local_filename)
        # Extract folder name (archive name without its extension, e.g. ".zip")
        extract_path = os.path.join(base_dir, os.path.splitext(local_filename)[0])
        created_extract_dir = False

        try:
            os.makedirs(base_dir, exist_ok=True)

            # Stream the download to disk in chunks.
            with requests.get(self.download_url, stream=True, timeout=15) as r:
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)

            # Only a folder created by this run is ever deleted on failure.
            created_extract_dir = not os.path.isdir(extract_path)
            os.makedirs(extract_path, exist_ok=True)

            if PLATFORM_NAME == 'darwin':
                # NOTE(review): presumably ditto is used because it keeps file
                # permissions/attributes that zipfile.extractall does not restore.
                subprocess.run(['ditto', '-xk', local_path, extract_path], check=True)
            else:
                with zipfile.ZipFile(local_path, 'r') as zip_ref:
                    zip_ref.extractall(extract_path)
        except Exception as e:
            self._discard_partial(local_path, extract_path if created_extract_dir else None)
            # Emit a signal signifying failure
            self.error_occurred.emit(str(e))
            return

        # The archive is no longer needed; a leftover zip is harmless.
        try:
            os.remove(local_path)
        except OSError:
            pass
        self.update_ready.emit(self.latest_version, extract_path)
class UpdateCheckThread(QThread):
    """
    Thread that checks for updates by querying the API and emits a signal based on the result.

    Only frozen (packaged) builds check; in development mode ``error_occurred``
    is emitted instead.

    Signals:
        download_requested(str, str): Emitted with (download_url, latest_version) when an update is available.
        no_update_available(): Emitted when no update is found or current version is up to date.
        error_occurred(str): Emitted with an error message if the update check fails.
    """
    download_requested = Signal(str, str)
    no_update_available = Signal()
    error_occurred = Signal(str)

    def run(self):
        if not getattr(sys, 'frozen', False):
            self.error_occurred.emit("Application is not frozen (Development mode).")
            return
        try:
            latest_version, download_url = self.get_latest_release_for_platform()
            if not latest_version:
                self.no_update_available.emit()
                return
            if not download_url:
                self.error_occurred.emit(f"No download available for platform '{PLATFORM_NAME}'")
                return
            if self.version_compare(latest_version, CURRENT_VERSION) > 0:
                self.download_requested.emit(download_url, latest_version)
            else:
                self.no_update_available.emit()
        except Exception as e:
            self.error_occurred.emit(f"Update check failed: {e}")

    def version_compare(self, v1, v2):
        """
        Compare dotted-integer version strings.

        Returns 1 if ``v1`` is newer, -1 if older, 0 if equal. Raises
        ValueError if a component is not an integer (e.g. "1.0.0-beta").
        """
        def normalize(v):
            return [int(x) for x in v.split(".")]
        a, b = normalize(v1), normalize(v2)
        return (a > b) - (a < b)

    def get_latest_release_for_platform(self):
        """
        Fetch the first release listed by the release API and find this OS's asset.

        ``API_URL`` is queried first. ``API_URL_SECONDARY`` is used only when
        the first request fails or returns invalid JSON.

        Returns:
            tuple: ``(tag, download_url)``. ``tag`` has any leading "v" stripped,
            and ``download_url`` is None if no asset name contains
            ``PLATFORM_NAME``. ``(None, None)`` if there are no releases or
            every endpoint failed.
        """
        for url in (API_URL, API_URL_SECONDARY):
            try:
                # Query the endpoint for THIS iteration (previously always API_URL,
                # so the secondary server was never tried).
                response = requests.get(url, timeout=5)
                response.raise_for_status()
                releases = response.json()
            except (requests.RequestException, ValueError):
                continue  # unreachable server or non-JSON body: try the next one

            if not releases:
                return None, None
            latest = releases[0]
            tag = latest["tag_name"].lstrip("v")
            for asset in latest.get("assets", []):
                if PLATFORM_NAME in asset["name"].lower():
                    return tag, asset["browser_download_url"]
            return tag, None
        return None, None
class LocalPendingUpdateCheckThread(QThread):
    """
    Thread that checks for locally pending updates by scanning the download directory and emits a signal accordingly.

    The scanned directory is ``/tmp/sparkstempupdate`` on macOS and the current
    working directory elsewhere. A pending update is a sub-folder named
    ``<anything>-X.Y.Z<platform_suffix>`` whose version is newer than
    ``current_version``. If several qualify, the newest is reported.

    Args:
        current_version (str): Current application version.
        platform_suffix (str): Platform-specific suffix to identify update folders.

    Signals:
        pending_update_found(str, str): ``(folder_version, folder_path)``.
        no_pending_update(): nothing newer was found, or the scan could not run.
    """
    pending_update_found = Signal(str, str)
    no_pending_update = Signal()

    def __init__(self, current_version, platform_suffix):
        super().__init__()
        self.current_version = current_version
        self.platform_suffix = platform_suffix

    def version_compare(self, v1, v2):
        """Return 1, 0 or -1 as dotted-integer version v1 is newer than, equal to, or older than v2."""
        def normalize(v):
            return [int(x) for x in v.split(".")]
        a, b = normalize(v1), normalize(v2)
        return (a > b) - (a < b)

    def run(self):
        scan_dir = '/tmp/sparkstempupdate' if PLATFORM_NAME == 'darwin' else os.getcwd()
        pattern = re.compile(r".*-(\d+\.\d+\.\d+)" + re.escape(self.platform_suffix) + r"$")

        best_version = None
        best_path = None
        try:
            for item in os.listdir(scan_dir):
                folder_path = os.path.join(scan_dir, item)
                if not (os.path.isdir(folder_path) and item.endswith(self.platform_suffix)):
                    continue
                match = pattern.match(item)
                if not match:
                    continue
                folder_version = match.group(1)
                if self.version_compare(folder_version, self.current_version) <= 0:
                    continue
                # listdir order is arbitrary, so keep the newest candidate.
                if best_version is None or self.version_compare(folder_version, best_version) > 0:
                    best_version, best_path = folder_version, folder_path
        except (OSError, ValueError):
            # Best effort: a missing/unreadable scan directory (e.g. nothing was
            # ever downloaded) or a non-numeric current_version means "no update".
            pass

        if best_version is not None:
            self.pending_update_found.emit(best_version, best_path)
        else:
            self.no_pending_update.emit()
class ProgressDialog(QDialog):
    """
    Application-modal dialog with a status message, a progress bar and a Cancel button.

    The bar starts with maximum 0, which Qt shows as a busy indicator. The
    first ``update_progress`` call with a non-zero total fixes the range.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Processing Files...")
        self.setWindowModality(Qt.WindowModality.ApplicationModal)
        self.setMinimumWidth(300)

        self.message_label = QLabel("Loading and analyzing video files...")
        self.progress_bar = QProgressBar()
        self.progress_bar.setMaximum(0)  # busy indicator until the total is known
        self.cancel_button = QPushButton("Cancel")

        layout = QVBoxLayout(self)
        for child in (self.message_label, self.progress_bar, self.cancel_button):
            layout.addWidget(child)

    def update_progress(self, current, total):
        """Show ``current`` of ``total`` in both the bar and the message."""
        bar = self.progress_bar
        if not bar.maximum():
            bar.setMaximum(total)
        bar.setValue(current)
        self.message_label.setText(f"Processing file {current} of {total}...")
class FileLoadWorker(QThread):
    """
    Load a BORIS project JSON off the GUI thread and pre-compute per-video previews.

    For each (observation, camera) entry whose video is found under
    ``observations_root``, the worker calls
    ``extract_frame_and_hands_func(video_path, 0)``. It stores the returned
    ``(frame_rgb, results)``, the observation's FPS and the landmark-0 ("wrist")
    positions of any detected hands in ``self.previews_data``, keyed by
    ``(obs_id, cam_id)``.

    Signals:
        observations_loaded(dict, dict, str): ``(boris_json, observations, observations_root)``.
            Not emitted if loading failed or was cancelled with ``stop()``.
        progress_update(int, int): ``(entries processed, total entries)``.
        loading_finished(): always emitted when ``run`` ends.
        loading_failed(str): error description.
    """
    observations_loaded = Signal(dict, dict, str)
    progress_update = Signal(int, int)
    loading_finished = Signal()
    loading_failed = Signal(str)

    def __init__(self, file_path, observations_root, extract_frame_and_hands_func):
        super().__init__()
        self.file_path = file_path
        self.observations_root = observations_root
        self.extract_frame_and_hands = extract_frame_and_hands_func
        self.is_running = True
        self.previews_data = {}  # (obs_id, cam_id) -> preview info used to build the UI

    def resolve_video_path(self, relative_video_path):
        """Logic to handle flexible path resolution (basename or full join)."""
        corrected_file_name = os.path.basename(relative_video_path)
        video_path = os.path.join(self.observations_root, corrected_file_name)
        if not os.path.exists(video_path):
            # Fallback to full relative path join
            video_path = os.path.join(self.observations_root, relative_video_path)
        return video_path

    def run(self):
        try:
            # 1. Load the BORIS JSON (Non-UI work)
            with open(self.file_path, "r") as f:
                boris = json.load(f)
            observations = boris.get("observations", {})
            if not observations:
                self.loading_failed.emit("No observations found in BORIS JSON.")
                return
            # 2. Gather previews (the heavy part)
            self.process_previews(observations)
            # 3. Hand results back, unless the user cancelled: a cancelled load
            #    must not push partial previews into the UI.
            if self.is_running:
                self.observations_loaded.emit(boris, observations, self.observations_root)
        except Exception as e:
            self.loading_failed.emit(f"Loading failed: {str(e)}")
        finally:
            self.loading_finished.emit()  # Ensures the dialog closes on every exit path

    def process_previews(self, observations):
        """
        Build a preview for every video entry in ``observations``.

        Emits ``progress_update`` after each camera entry, whether its video was
        found, missing or had no path. Returns early once ``stop()`` is called.
        """
        total_files = sum(len(obs_data.get("file", {})) for obs_data in observations.values())
        current_index = 0
        for obs_id, obs_data in observations.items():
            files_dict = obs_data.get("file", {})
            fps_dict = obs_data.get("media_info", {}).get("fps", {})
            # One FPS per observation: the first value of media_info["fps"], else 30.
            fps_default = float(next(iter(fps_dict.values()))) if fps_dict else 30.0

            for cam_id, paths in files_dict.items():
                if not self.is_running:
                    return  # Allow cancellation
                if paths:
                    video_path = self.resolve_video_path(paths[0])
                    if os.path.exists(video_path):
                        self._store_preview(obs_id, cam_id, video_path, fps_default)
                else:
                    print(f"Skipping entry for {obs_id}/{cam_id}: paths list is empty (IndexError prevented).")
                current_index += 1
                self.progress_update.emit(current_index, total_files)

    def _store_preview(self, obs_id, cam_id, video_path, fps):
        """Run the slow frame/hand extraction for one video and cache what the UI needs."""
        frame_rgb, results = self.extract_frame_and_hands(video_path, 0)
        initial_wrists = []
        if results and results.multi_hand_landmarks:
            for lm in results.multi_hand_landmarks:
                wrist = lm.landmark[0]
                initial_wrists.append((wrist.x, wrist.y))
        self.previews_data[(obs_id, cam_id)] = {
            "frame_rgb": frame_rgb,
            "results": results,
            "video_path": video_path,
            "fps": fps,
            "initial_wrists": initial_wrists
        }

    def stop(self):
        """Request cancellation; checked before each video is processed."""
        self.is_running = False
class AboutWindow(QWidget):
    """
    Small top-level window showing the application name, licence and version.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("About SPARKS")
        self.resize(250, 100)

        # NOTE(review): "Spacial" may be intended as "Spatial"; text left unchanged.
        labels = (
            QLabel("About SPARKS", self),
            QLabel("Spacial Patterns Analysis, Research, & Knowledge Suite", self),
            QLabel("SPARKS is licensed under the GPL-3.0 licence. For more information, visit https://www.gnu.org/licenses/gpl-3.0.en.html", self),
            QLabel(f"Version v{CURRENT_VERSION}"),
        )
        layout = QVBoxLayout()
        for label in labels:
            layout.addWidget(label)
        self.setLayout(layout)
class UserGuideWindow(QWidget):
    """
    Placeholder top-level window for the user guide (currently shows "N/A").

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("User Guide - SPARKS")
        self.resize(250, 100)

        layout = QVBoxLayout()
        for text in ("Not Applicable:", "N/A"):
            layout.addWidget(QLabel(text, self))
        self.setLayout(layout)
class ParamSection(QWidget):
    """
    A widget section that dynamically creates labeled input fields from parameter metadata.

    Args:
        section_data (dict): Dictionary containing section title and list of parameter info.
            Expected format:
            {
                "title": str,
                "params": [
                    {
                        "name": str,
                        "type": type,
                        "default": any,
                        "help": str (optional)
                    },
                    ...
                ]
            }

    Created inputs are kept in ``self.widgets`` as
    ``{name: {"widget": QWidget, "type": type}}``, and ``get_param_values``
    converts them back to Python values.
    """

    def __init__(self, section_data):
        super().__init__()
        layout = QVBoxLayout()
        self.setLayout(layout)
        self.widgets = {}
        self.selected_path = None
        # Re-entrancy guard shared by the multi-select dropdowns of this section.
        self._updating_checkstates = False

        # Title label
        title_label = QLabel(section_data["title"])
        title_label.setStyleSheet("font-weight: bold; font-size: 14px; margin-top: 10px; margin-bottom: 5px;")
        layout.addWidget(title_label)

        # Horizontal line
        line = QFrame()
        line.setFrameShape(QFrame.Shape.HLine)
        line.setFrameShadow(QFrame.Shadow.Sunken)
        layout.addWidget(line)

        for param in section_data["params"]:
            h_layout = QHBoxLayout()
            help_text = param.get("help", "")

            label = QLabel(param["name"])
            label.setFixedWidth(180)
            label.setToolTip(help_text)

            help_btn = QPushButton("?")
            help_btn.setFixedWidth(25)
            help_btn.setToolTip(help_text)
            # Bind help_text as a default so each button shows its own text.
            help_btn.clicked.connect(lambda _, text=help_text: self.show_help_popup(text))

            # Stretch factors 1 : 3 : 5 for button : label : input.
            h_layout.addWidget(help_btn)
            h_layout.setStretch(0, 1)
            h_layout.addWidget(label)
            h_layout.setStretch(1, 3)

            # Create input widget based on type
            if param["type"] == bool:
                widget = QComboBox()
                widget.addItems(["True", "False"])
                widget.setCurrentText(str(param["default"]))
            elif param["type"] == int:
                widget = QLineEdit()
                widget.setValidator(QIntValidator())
                widget.setText(str(param["default"]))
            elif param["type"] == float:
                widget = QLineEdit()
                widget.setValidator(QDoubleValidator())
                widget.setText(str(param["default"]))
            elif param["type"] == list:
                widget = self._create_multiselect_dropdown(None)
            else:
                widget = QLineEdit()
                widget.setText(str(param["default"]))

            widget.setToolTip(help_text)
            h_layout.addWidget(widget)
            h_layout.setStretch(2, 5)

            layout.addLayout(h_layout)
            self.widgets[param["name"]] = {
                "widget": widget,
                "type": param["type"]
            }

    def _create_multiselect_dropdown(self, items):
        """
        Build a read-only combo box whose rows are checkable.

        Row 0 is the non-checkable "<None Selected>" placeholder, row 1 is
        "Toggle Select All", and rows 2+ are ``items`` (None means no rows).
        """
        combo = FullClickComboBox()
        combo.setView(QListView())
        model = QStandardItemModel()
        combo.setModel(model)
        combo.setEditable(True)
        combo.lineEdit().setReadOnly(True)
        combo.lineEdit().setPlaceholderText("Select...")
        combo.setInsertPolicy(QComboBox.NoInsert)

        dummy_item = QStandardItem("<None Selected>")
        dummy_item.setFlags(Qt.ItemIsEnabled)
        model.appendRow(dummy_item)

        toggle_item = QStandardItem("Toggle Select All")
        toggle_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled)
        toggle_item.setData(Qt.Unchecked, Qt.CheckStateRole)
        model.appendRow(toggle_item)

        if items is not None:
            for item in items:
                standard_item = QStandardItem(item)
                standard_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled)
                standard_item.setData(Qt.Unchecked, Qt.CheckStateRole)
                model.appendRow(standard_item)

        def on_view_clicked(index):
            # Toggle on press so clicking anywhere on a row flips its check box.
            item = model.itemFromIndex(index)
            if item.isCheckable():
                new_state = Qt.Checked if item.checkState() == Qt.Unchecked else Qt.Unchecked
                item.setCheckState(new_state)

        combo.view().pressed.connect(on_view_clicked)

        def on_item_changed(item):
            # setCheckState below re-fires itemChanged; the guard stops recursion.
            if self._updating_checkstates:
                return
            self._updating_checkstates = True
            try:
                normal_items = [model.item(i) for i in range(2, model.rowCount())]  # skip dummy and toggle
                if item == toggle_item:
                    all_checked = all(i.checkState() == Qt.Checked for i in normal_items)
                    new_state = Qt.Unchecked if all_checked else Qt.Checked
                    for i in normal_items:
                        i.setCheckState(new_state)
                    toggle_item.setCheckState(new_state)
                elif item == dummy_item:
                    pass
                else:
                    # When normal items change, update toggle item
                    all_checked = all(i.checkState() == Qt.Checked for i in normal_items)
                    toggle_item.setCheckState(Qt.Checked if all_checked else Qt.Unchecked)
            finally:
                # Reset even if a handler raised, or the dropdown would stop reacting.
                self._updating_checkstates = False

            for param_name, info in self.widgets.items():
                if info["widget"] == combo:
                    self.update_dropdown_label(param_name)
                    break

        model.itemChanged.connect(on_item_changed)
        return combo

    def update_dropdown_label(self, param_name):
        """
        Write the checked rows of a multi-select dropdown into its line edit.

        ``get_param_values`` reads list parameters back by splitting this text
        on commas, so item texts containing commas will not round-trip. Rows 0
        and 1 (placeholder and "Toggle Select All") are skipped.
        """
        info = self.widgets.get(param_name)
        if info is None:
            return
        combo = info["widget"]
        model = combo.model()
        checked = [
            model.item(row).text()
            for row in range(2, model.rowCount())
            if model.item(row).checkState() == Qt.Checked
        ]
        combo.lineEdit().setText(", ".join(checked))

    def show_help_popup(self, text):
        msg = QMessageBox(self)
        msg.setWindowTitle("Parameter Info - SPARKS")
        msg.setText(text)
        msg.exec()

    def get_param_values(self):
        """
        Return ``{name: value}`` converted to each parameter's declared type.

        Raises:
            ValueError: if an int/float field does not parse.
        """
        values = {}
        for name, info in self.widgets.items():
            widget = info["widget"]
            expected_type = info["type"]
            if expected_type == bool:
                values[name] = widget.currentText() == "True"
            elif expected_type == list:
                values[name] = [x.strip() for x in widget.lineEdit().text().split(",") if x.strip()]
            else:
                raw_text = widget.text()
                try:
                    if expected_type == int:
                        values[name] = int(raw_text)
                    elif expected_type == float:
                        values[name] = float(raw_text)
                    else:
                        values[name] = raw_text  # str and any other type: keep the text
                except Exception as e:
                    raise ValueError(f"Invalid value for {name}: {raw_text}") from e
        return values
class FullClickLineEdit(QLineEdit):
    """Line edit that opens its parent combo box's popup when clicked anywhere on it."""

    def mousePressEvent(self, event):
        owner = self.parent()
        if isinstance(owner, QComboBox):
            owner.showPopup()
        # Keep the default press handling (focus, cursor) as well.
        super().mousePressEvent(event)
class FullClickComboBox(QComboBox):
    """Combo box with a read-only FullClickLineEdit, so a click anywhere opens the popup."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        editor = FullClickLineEdit(self)
        self.setLineEdit(editor)
        editor.setReadOnly(True)
class ParticipantProcessor(QThread):
    """Background thread that extracts hand landmarks from one camera of one BORIS observation.

    For every video frame it writes one CSV row to ``output_dir/output_csv``:
    frame index, time in seconds, the same time again (``adjusted_time_sec``),
    two 0/1 flags derived from the observation's BORIS events (reach active,
    reach before contact), the camera offset, and x/y/z for the 21 landmarks of
    the selected hand (NaN when that hand is not detected in the frame).

    Signals:
        progress_updated(int): percentage of frames processed (clamped to 100).
        frame_ready(QImage): annotated RGB preview frame (not sent when hide_preview).
        finished_processing(str): the observation id on success, or an
            "Error: ..." message if the video cannot be found or opened.
            Not emitted when the run is cancelled through stop().
        time_updated(str): playback position as "MM:SS/MM:SS".
    """

    progress_updated = Signal(int)
    frame_ready = Signal(QImage)
    finished_processing = Signal(str)
    time_updated = Signal(str)

    # MediaPipe Hands reports 21 landmarks per hand (one x/y/z triple each).
    NUM_LANDMARKS = 21
    # Seconds added after each "End" event so the reach window covers its tail.
    END_PADDING_SEC = 1.0
    # Frame rate used only when neither OpenCV nor BORIS provides a usable one.
    DEFAULT_FPS = 30.0

    def __init__(self, obs_id, boris_json, selected_cam_id, selected_hand_idx,
                 hide_preview=False, output_csv=None, observations_root=None, output_dir=None, initial_wrists=None):
        super().__init__()
        self.obs_id = obs_id
        self.boris_json = boris_json
        self.selected_cam_id = selected_cam_id
        self.selected_hand_idx = selected_hand_idx
        self.hide_preview = hide_preview
        self.output_csv = output_csv
        self.observations_root = observations_root
        self.output_dir = output_dir
        # Presumably one (x, y, ...) wrist position per hand in MediaPipe's
        # normalised coordinates, used to keep hand identities stable across
        # frames; TODO confirm against the caller.
        self.initial_wrists = initial_wrists
        # Video mode (static_image_mode=False): frames are treated as a stream.
        self.mp_hands = mp.solutions.hands.Hands(
            static_image_mode=False,
            max_num_hands=2,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
            model_complexity=0
        )
        self.is_running = True

    def run(self):
        """Thread entry point: process the video, always release MediaPipe, then report."""
        try:
            outcome = self._process()
        finally:
            # Close MediaPipe on every path (success, error, cancel, exception).
            self.mp_hands.close()
        if outcome is not None:
            self.finished_processing.emit(outcome)

    def _process(self):
        """Locate and open the camera's video, then process it.

        Returns:
            The finished_processing payload (obs id or error text), or None if cancelled.
        """
        obs_data = self.boris_json.get("observations", {})[self.obs_id]
        cameras_with_files = obs_data.get("file", {})
        media_info = obs_data.get("media_info", {})
        offset_dict = media_info.get("offset", {})

        # Only the file name of the camera's first listed video is used; it is
        # looked up directly inside observations_root.
        relative_video_path = cameras_with_files[self.selected_cam_id][0]
        video_path = os.path.join(self.observations_root, os.path.basename(relative_video_path))
        if not os.path.exists(video_path):
            print(f"FATAL ERROR: Video file not found at: {video_path}")
            return f"Error: Video not found for {self.obs_id}"

        camera_offset = float(offset_dict.get(self.selected_cam_id, 0.0))
        fps_dict = media_info.get("fps", {})
        boris_fps = float(list(fps_dict.values())[0]) if fps_dict else self.DEFAULT_FPS

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f"FATAL ERROR: Could not open video: {video_path}")
                return f"Error: Could not open video for {self.obs_id}"
            return self._process_video(cap, obs_data.get("events", []), camera_offset, boris_fps)
        finally:
            cap.release()

    def _process_video(self, cap, raw_events, camera_offset, boris_fps):
        """Write one CSV row per frame read from *cap*.

        Returns:
            self.obs_id when the whole video was processed, None if stop() was called.
        """
        # Prefer the container's frame rate, but never divide by the 0 (or NaN)
        # that some backends report; fall back to BORIS, then to DEFAULT_FPS.
        reported_fps = cap.get(cv2.CAP_PROP_FPS)
        if reported_fps > 0:
            fps = reported_fps
        elif boris_fps > 0:
            fps = boris_fps
        else:
            fps = self.DEFAULT_FPS
        # The frame count can be an estimate, or 0, depending on the backend.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        total_time_str = self._format_seconds(total_frames / fps)

        # Shift the annotated windows into video time once instead of per frame.
        segments, contact_segments = self._build_segments(raw_events)
        reach_windows = [(start - camera_offset, end - camera_offset) for start, end in segments]
        contact_windows = [(start - camera_offset, end - camera_offset) for start, end in contact_segments]

        header = ["frame_index", "time_sec", "adjusted_time_sec", "reach_active", "reach_before_contact", "camera_offset"]
        for i in range(self.NUM_LANDMARKS):
            header += [f"lm{i}_x", f"lm{i}_y", f"lm{i}_z"]
        missing_hand = [math.nan] * (self.NUM_LANDMARKS * 3)

        output_path = os.path.join(self.output_dir, self.output_csv)
        print(output_path)
        with open(output_path, "w", newline="") as csv_file:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(header)

            frame_index = 0
            while True:
                if not self.is_running:
                    print(f"Processor for {self.obs_id}/{self.selected_cam_id} was cancelled gracefully.")
                    return None
                ret, frame = cap.read()
                if not ret:
                    break

                current_time = frame_index / fps
                row = [
                    frame_index, current_time, current_time,
                    int(self._in_any_window(current_time, reach_windows)),
                    int(self._in_any_window(current_time, contact_windows)),
                    camera_offset,
                ]

                # MediaPipe expects RGB; OpenCV decodes BGR.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = self.mp_hands.process(rgb_frame)
                hand_landmarks = None
                if results.multi_hand_landmarks and results.multi_handedness:
                    hand_landmarks = self._select_hand(results.multi_hand_landmarks)
                if hand_landmarks is not None:
                    for lm in hand_landmarks.landmark:
                        row += [lm.x, lm.y, lm.z]
                else:
                    row += missing_hand
                csv_writer.writerow(row)

                if not self.hide_preview:
                    self._emit_preview(frame, hand_landmarks)

                if total_frames > 0:
                    progress = min(100, int(frame_index * 100 / total_frames))
                else:
                    progress = 0
                self.progress_updated.emit(progress)
                self.time_updated.emit(f"{self._format_seconds(current_time)}/{total_time_str}")
                frame_index += 1
        return self.obs_id

    @classmethod
    def _build_segments(cls, raw_events):
        """Turn BORIS events into (start, end) time windows.

        Each event is read as ``ev[0]`` = time and ``ev[2]`` = behaviour name;
        events with fewer than three fields are ignored. "Trial Start" opens a
        trial. "End" closes it and yields a reach window extended by
        END_PADDING_SEC. "Contact" inside an open trial yields a
        reach-before-contact window (trial start -> contact). A trial that ends
        without any contact contributes its whole reach window to that list.

        Returns:
            (reach_segments, reach_before_contact_segments), each a list of tuples.
        """
        segments = []
        contact_segments = []
        current_start = None
        current_contact = None
        for ev in raw_events:
            if len(ev) < 3:
                continue
            time_sec, name = ev[0], ev[2]
            if name == "Trial Start":
                current_start = time_sec
                current_contact = None
            elif current_start is None:
                # "Contact"/"End" outside an open trial are ignored.
                continue
            elif name == "Contact":
                current_contact = time_sec
                contact_segments.append((current_start, current_contact))
            elif name == "End":
                end = time_sec + cls.END_PADDING_SEC
                segments.append((current_start, end))
                if current_contact is None:
                    contact_segments.append((current_start, end))
                current_start = None
        return segments, contact_segments

    @staticmethod
    def _in_any_window(t, windows):
        """Return True if *t* lies inside any inclusive (lo, hi) window."""
        return any(lo <= t <= hi for lo, hi in windows)

    def _select_hand(self, detected_hands):
        """Return the landmarks of the hand at selected_hand_idx, or None.

        With initial_wrists, each reference wrist (in order) greedily claims the
        nearest unclaimed detection by squared distance of landmark 0. Without
        them, detection order is used as-is.
        """
        if self.initial_wrists:
            mapping = {}
            claimed = set()
            for ref_idx, ref_wrist in enumerate(self.initial_wrists):
                best_idx = None
                best_dist = float('inf')
                for det_idx, hand in enumerate(detected_hands):
                    if det_idx in claimed:
                        continue
                    wrist = hand.landmark[0]
                    dist = (wrist.x - ref_wrist[0]) ** 2 + (wrist.y - ref_wrist[1]) ** 2
                    if dist < best_dist:
                        best_dist = dist
                        best_idx = det_idx
                if best_idx is not None:
                    mapping[ref_idx] = best_idx
                    claimed.add(best_idx)
        else:
            mapping = {i: i for i in range(len(detected_hands))}
        mapped_idx = mapping.get(self.selected_hand_idx)
        return None if mapped_idx is None else detected_hands[mapped_idx]

    def _emit_preview(self, frame_bgr, hand_landmarks):
        """Draw *hand_landmarks* (if any) on a copy of the BGR frame and emit it as an RGB QImage."""
        annotated = frame_bgr.copy()
        if hand_landmarks is not None:
            mp.solutions.drawing_utils.draw_landmarks(
                annotated,
                hand_landmarks,
                mp.solutions.hands.HAND_CONNECTIONS
            )
        # Format_RGB888 expects RGB byte order; the decoded frame is BGR.
        rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
        h, w, _ = rgb.shape
        # copy(): the QImage must own its pixels, because the numpy buffer goes
        # away before the queued signal reaches the GUI thread.
        qimg = QImage(rgb.data, w, h, 3 * w, QImage.Format_RGB888).copy()
        self.frame_ready.emit(qimg)

    def _format_seconds(self, seconds):
        """Format a duration as zero-padded "MM:SS"; negative values give "00:00"."""
        if seconds < 0:
            return "00:00"
        minutes = math.floor(seconds / 60)
        secs = math.floor(seconds % 60)
        return f"{minutes:02d}:{secs:02d}"

    def stop(self):
        """Sets the flag to stop the thread's execution loop."""
        self.is_running = False
class MainApplication(QMainWindow):
"""
Main application window that creates and sets up the UI.
"""
progress_update_signal = Signal(str, int)
def __init__(self):
    """Create the main window, build its UI and menus, and start the local update check.

    State read while the UI is built is initialised before init_ui() and
    create_menu_bar(); the remaining runtime state is set afterwards.
    """
    super().__init__()
    self.setWindowTitle("SPARKS")
    self.setGeometry(100, 100, 1280, 720)

    # Secondary windows, created on demand by their menu actions.
    self.about = None
    self.help = None
    self.optodes = None
    self.events = None
    self.terminal = None

    self.bubble_widgets = {}
    self.param_sections = []
    self.folder_paths = []
    self.section_widget = None
    self.first_run = True
    self.selection_widgets = {}
    self.camera_state = {}  # {(obs_id, cam_id): state info}

    # Batch-processing bookkeeping (previously processing_widgets was assigned twice).
    self.processing_widgets = {}
    self.files_total = 0  # total number of files to process
    self.files_done = set()  # set of file paths done (success or fail)
    self.files_failed = set()  # set of failed file paths
    self.files_results = {}  # dict for successful results (if needed)
    self.processing_threads = []  # List to hold all active ParticipantProcessor threads

    self.init_ui()
    self.create_menu_bar()

    self.platform_suffix = "-" + PLATFORM_NAME
    self.pending_update_version = None
    self.pending_update_path = None
    self.last_clicked_bubble = None
    self.installEventFilter(self)
    self.file_metadata = {}
    self.current_file = None
    # Background BORIS loader and its progress dialog (see open_file_dialog).
    self.worker_thread = None
    self.progress_dialog = None

    # Mediapipe hands in still-image mode; presumably used for single-frame
    # previews (e.g. extract_frame_and_hands) - TODO confirm.
    self.mp_hands = mp.solutions.hands.Hands(
        static_image_mode=True,
        max_num_hands=2,
        min_detection_confidence=0.5
    )

    # Start local pending update check thread
    self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix)
    self.local_check_thread.pending_update_found.connect(self.on_pending_update_found)
    self.local_check_thread.no_pending_update.connect(self.on_no_pending_update)
    self.local_check_thread.start()
def init_ui(self):
    """Build the central widget: file-info/log panel and preview bubbles on the left,
    parameter sections with Process/Clear buttons on the right.

    Creates the widgets other methods use: self.top_left_widget (log box),
    self.right_column_widget / self.meta_fields (hidden metadata column),
    self.bubble_layout (preview grid), self.rows_layout (parameter sections),
    self.option_selector, self.button1 (Process) and self.button2 (Clear).
    """
    # Central widget and main horizontal layout
    central = QWidget()
    self.setCentralWidget(central)
    main_layout = QHBoxLayout()
    central.setLayout(main_layout)
    # Left container with vertical layout: top left + bottom left
    left_container = QWidget()
    left_layout = QVBoxLayout()
    left_container.setLayout(left_layout)
    left_container.setMinimumWidth(300)
    top_left_container = QGroupBox()
    top_left_container.setTitle("File information")
    top_left_container.setStyleSheet("QGroupBox { font-weight: bold; }") # Style if needed
    top_left_container.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)
    top_left_layout = QHBoxLayout()
    top_left_container.setLayout(top_left_layout)
    # Read-only log box taking ~80% of the group box width (stretch 4 vs 1)
    self.top_left_widget = QTextEdit()
    self.top_left_widget.setReadOnly(True)
    self.top_left_widget.setPlaceholderText("Logging information will be available here.")
    # Add QTextEdit to the layout with a stretch factor
    top_left_layout.addWidget(self.top_left_widget, stretch=4) # 80%
    # Create a vertical box layout for the right 20%
    self.right_column_widget = QWidget()
    right_column_layout = QVBoxLayout()
    self.right_column_widget.setLayout(right_column_layout)
    # Per-participant metadata inputs; the dict keys double as label text
    self.meta_fields = {
        "HAND": QLineEdit(),
        "GENDER": QLineEdit(),
        "GROUP": QLineEdit(),
    }
    for key, field in self.meta_fields.items():
        label = QLabel(key.capitalize())
        field.setPlaceholderText(f"Enter {key}")
        right_column_layout.addWidget(label)
        right_column_layout.addWidget(field)
    # Link-styled label; clicking it opens a popup instead of a browser
    label_desc = QLabel('<a href="#">Why are these useful?</a>')
    label_desc.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
    label_desc.setOpenExternalLinks(False)
    def show_info_popup():
        # NOTE(review): this text describes an "Age" field and a DPF factor, but the
        # fields above are HAND, GENDER and GROUP - it looks stale; confirm and update.
        # The popup is also unparented (None) rather than parented to this window.
        QMessageBox.information(None, "Parameter Info - SPARKS",
            "Age: Used to calculate the DPF factor.\nGender: Not currently used. "
            "Will be able to sort into groups by gender in the near future.\nGroup: Allows contrast "
            "images to be created comparing one group to another once the processing has completed.")
    label_desc.linkActivated.connect(show_info_popup)
    right_column_layout.addWidget(label_desc)
    right_column_layout.addStretch() # Push fields to top
    # Hidden until something shows it (clear_all hides it again)
    self.right_column_widget.hide()
    # Add right column widget to the top-left layout (takes 20% width)
    top_left_layout.addWidget(self.right_column_widget, stretch=1)
    # Add top_left_container to the main left_layout
    left_layout.addWidget(top_left_container, stretch=2)
    # Bottom left: the bubbles inside the scroll area
    self.bubble_container = QWidget()
    self.bubble_layout = QGridLayout()
    self.bubble_layout.setAlignment(Qt.AlignmentFlag.AlignTop)
    self.bubble_container.setLayout(self.bubble_layout)
    self.scroll_area = QScrollArea()
    self.scroll_area.setWidgetResizable(True)
    self.scroll_area.setWidget(self.bubble_container)
    self.scroll_area.setMinimumHeight(300)
    # Bottom-left scroll area gets stretch 8 vs the top box's 2
    left_layout.addWidget(self.scroll_area, stretch=8)
    # Right widget (full height on right side)
    self.right_container = QWidget()
    right_container_layout = QVBoxLayout()
    self.right_container.setLayout(right_container_layout)
    # Content widget inside scroll area
    self.right_content_widget = QWidget()
    right_content_layout = QVBoxLayout()
    self.right_content_widget.setLayout(right_content_layout)
    # Option selector dropdown (currently a single "N/A" entry)
    self.option_selector = QComboBox()
    self.option_selector.addItems(["N/A"])
    right_content_layout.addWidget(self.option_selector)
    # Container for the sections
    self.rows_container = QWidget()
    self.rows_layout = QVBoxLayout()
    self.rows_layout.setSpacing(10)
    self.rows_container.setLayout(self.rows_layout)
    right_content_layout.addWidget(self.rows_container)
    # Spacer at bottom inside scroll area content to push content up
    right_content_layout.addStretch()
    # Scroll area for the right side content
    self.right_scroll_area = QScrollArea()
    self.right_scroll_area.setWidgetResizable(True)
    self.right_scroll_area.setWidget(self.right_content_widget)
    # Buttons widget (fixed below the scroll area)
    buttons_widget = QWidget()
    buttons_layout = QHBoxLayout()
    buttons_widget.setLayout(buttons_layout)
    buttons_layout.addStretch()
    self.button1 = QPushButton("Process")
    self.button2 = QPushButton("Clear")
    buttons_layout.addWidget(self.button1)
    buttons_layout.addWidget(self.button2)
    self.button1.setMinimumSize(100, 40)
    self.button2.setMinimumSize(100, 40)
    # "Process" stays hidden until a BORIS file has been loaded
    self.button1.setVisible(False)
    self.button1.clicked.connect(self.on_run_task)
    self.button2.clicked.connect(self.clear_all)
    # Add scroll area and buttons widget to right container layout
    right_container_layout.addWidget(self.right_scroll_area)
    right_container_layout.addWidget(buttons_widget)
    # Add left and right containers to main layout (70/30 split)
    main_layout.addWidget(left_container, stretch=70)
    main_layout.addWidget(self.right_container, stretch=30)
    # Set size policy to expand
    self.right_container.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
    self.right_scroll_area.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
    # Rebuild the ParamSection widgets whenever the selector changes
    self.option_selector.currentIndexChanged.connect(self.update_sections)
    # Initial build
    self.update_sections(0)
def create_menu_bar(self):
    """Build the File, Edit, Model, View, Options and Terminal menus and the status bar.

    Every shortcut is unique. Qt treats a key sequence bound to two enabled
    actions in the same window as ambiguous and triggers neither, which
    previously disabled Ctrl+O ("Open BORIS file..." vs "Test model on a video")
    and Ctrl+P (both "Test model" folder/CSV actions).
    """
    menu_bar = self.menuBar()

    def make_action(name, shortcut=None, slot=None, checkable=False, checked=False, icon=None):
        """Create a window-owned QAction with optional shortcut, slot, check state and icon."""
        action = QAction(name, self)
        if shortcut:
            action.setShortcut(QKeySequence(shortcut))
        if slot:
            action.triggered.connect(slot)
        if checkable:
            action.setCheckable(True)
            action.setChecked(checked)
        if icon:
            action.setIcon(QIcon(icon))
        return action

    # File menu and actions
    file_menu = menu_bar.addMenu("File")
    file_actions = [
        ("Open BORIS file...", "Ctrl+O", self.open_file_dialog, resource_path("icons/file_open_24dp_1F1F1F.svg")),
        #("Open Folder...", "Ctrl+Alt+O", self.open_folder_dialog, resource_path("icons/folder_24dp_1F1F1F.svg")),
        #("Open Folders...", "Ctrl+Shift+O", self.open_folder_dialog, resource_path("icons/folder_copy_24dp_1F1F1F.svg")),
        #("Load Project...", "Ctrl+L", self.load_project, resource_path("icons/article_24dp_1F1F1F.svg")),
        #("Save Project...", "Ctrl+S", self.save_project, resource_path("icons/save_24dp_1F1F1F.svg")),
        #("Save Project As...", "Ctrl+Shift+S", self.save_project, resource_path("icons/save_as_24dp_1F1F1F.svg")),
    ]
    for i, (name, shortcut, slot, icon) in enumerate(file_actions):
        file_menu.addAction(make_action(name, shortcut, slot, icon=icon))
        if i == 1:  # separator after the second entry (once the commented-out ones return)
            file_menu.addSeparator()
    file_menu.addSeparator()
    file_menu.addAction(make_action("Exit", "Ctrl+Q", QApplication.instance().quit, icon=resource_path("icons/exit_to_app_24dp_1F1F1F.svg")))

    # Edit menu (operates on the log box)
    edit_menu = menu_bar.addMenu("Edit")
    edit_actions = [
        ("Cut", "Ctrl+X", self.cut_text, resource_path("icons/content_cut_24dp_1F1F1F.svg")),
        ("Copy", "Ctrl+C", self.copy_text, resource_path("icons/content_copy_24dp_1F1F1F.svg")),
        ("Paste", "Ctrl+V", self.paste_text, resource_path("icons/content_paste_24dp_1F1F1F.svg"))
    ]
    for name, shortcut, slot, icon in edit_actions:
        edit_menu.addAction(make_action(name, shortcut, slot, icon=icon))

    # Model menu: each entry needs its own key sequence (see docstring).
    # NOTE(review): these entries reuse the cut/copy/paste icons.
    model_menu = menu_bar.addMenu("Model")
    model_actions = [
        ("Create model from a CSV", "Ctrl+U", self.train_model_csv, resource_path("icons/content_cut_24dp_1F1F1F.svg")),
        ("Create model from a folder", "Ctrl+I", self.train_model_folder, resource_path("icons/content_copy_24dp_1F1F1F.svg")),
        ("Test model on a video", "Ctrl+T", self.test_model_video, resource_path("icons/content_paste_24dp_1F1F1F.svg")),
        ("Test model on a folder", "Ctrl+Shift+P", self.test_model_folder, resource_path("icons/content_paste_24dp_1F1F1F.svg")),
        ("Test model on a CSV", "Ctrl+P", self.test_model_csv, resource_path("icons/content_paste_24dp_1F1F1F.svg"))
    ]
    for i, (name, shortcut, slot, icon) in enumerate(model_actions):
        model_menu.addAction(make_action(name, shortcut, slot, icon=icon))
        if i == 1:  # separate the "create" entries from the "test" entries
            model_menu.addSeparator()

    # View menu
    view_menu = menu_bar.addMenu("View")
    toggle_statusbar_action = make_action("Toggle Status Bar", checkable=True, checked=True, slot=None)
    view_menu.addAction(toggle_statusbar_action)

    # Options menu (Help & About)
    options_menu = menu_bar.addMenu("Options")
    options_actions = [
        ("User Guide", "F1", self.user_guide, resource_path("icons/help_24dp_1F1F1F.svg")),
        ("Check for Updates", "F5", self.manual_check_for_updates, resource_path("icons/update_24dp_1F1F1F.svg")),
        ("About", "F12", self.about_window, resource_path("icons/info_24dp_1F1F1F.svg"))
    ]
    for i, (name, shortcut, slot, icon) in enumerate(options_actions):
        options_menu.addAction(make_action(name, shortcut, slot, icon=icon))
        if i == 1:  # separator before "About"
            options_menu.addSeparator()

    terminal_menu = menu_bar.addMenu("Terminal")
    terminal_actions = [
        ("New Terminal", "Ctrl+Alt+T", self.terminal_gui, resource_path("icons/terminal_24dp_1F1F1F.svg")),
    ]
    for name, shortcut, slot, icon in terminal_actions:
        terminal_menu.addAction(make_action(name, shortcut, slot, icon=icon))

    self.statusbar = self.statusBar()
    self.statusbar.showMessage("Ready")
    # The checkable View entry previously had no slot; make it show/hide the bar.
    toggle_statusbar_action.toggled.connect(self.statusbar.setVisible)
def update_sections(self, index):
    """Rebuild the parameter panel: discard existing rows, then add one ParamSection per SECTIONS entry.

    ``index`` is the option selector's new index; it is accepted for the
    signal signature but not used.
    """
    layout = self.rows_layout
    # Schedule every current row widget for deletion, last to first.
    for position in range(layout.count() - 1, -1, -1):
        old_widget = layout.itemAt(position).widget()
        if old_widget is not None:
            old_widget.deleteLater()
    self.param_sections.clear()

    for section in SECTIONS:
        new_section = ParamSection(section)
        # section_widget ends up referring to the last section built.
        self.section_widget = new_section
        layout.addWidget(new_section)
        self.param_sections.append(new_section)
def clear_bubble_widgets(self):
    """Empty self.bubble_layout, scheduling every contained widget for deletion."""
    layout = self.bubble_layout
    while layout.count() > 0:
        taken = layout.takeAt(0)
        child = taken.widget()
        if child:
            # Typically a QGroupBox preview card.
            child.deleteLater()
            continue
        if taken.spacerItem():
            # takeAt() has already detached the item; removeItem is a harmless extra.
            layout.removeItem(taken)
def clear_all(self):
    """Reset the window: stop any running file load and discard previews, selections and results.

    Also hides the metadata column and the Process button, clears the status
    bar and empties the log box.
    """
    self.cancel_task()
    self.right_column_widget.hide()

    # Same take-and-delete loop as before; reuse the shared helper.
    self.clear_bubble_widgets()

    # Clear file data
    self.bubble_widgets.clear()
    # build_preview_grid() registers its dropdowns here. Those widgets were just
    # scheduled for deletion, so drop the references rather than keep dangling ones.
    self.selection_widgets.clear()
    self.statusBar().clearMessage()

    # NOTE(review): these result attributes are not referenced in this part of
    # the file; they may be carried over from another project - confirm.
    self.raw_haemo_dict = None
    self.epochs_dict = None
    self.fig_bytes_dict = None
    self.cha_dict = None
    self.contrast_results_dict = None
    self.df_ind_dict = None
    self.design_matrix_dict = None
    self.age_dict = None
    self.gender_dict = None
    self.group_dict = None
    self.valid_dict = None

    # Reset any visible UI elements
    self.button1.setVisible(False)
    self.top_left_widget.clear()
def get_output_directory(self):
    """Ask for a parent folder and create a timestamped ``sparks_YYYYMMDD_HHMMSS`` subfolder in it.

    Returns:
        The created subfolder's path, or None if the user cancelled the dialog
        or the folder could not be created (an error box is shown in that case).
    """
    parent_dir = QFileDialog.getExistingDirectory(
        self,
        "Select Parent Folder for Results",
        os.getcwd()  # start browsing in the current working directory
    )
    if not parent_dir:
        return None  # User cancelled the dialog

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(parent_dir, f"sparks_{timestamp}")
    try:
        # exist_ok: two requests within the same second reuse the same folder.
        os.makedirs(output_path, exist_ok=True)
    except OSError as e:
        # Permission problems, read-only media, invalid names, etc.
        QMessageBox.critical(self, "Error", f"Failed to create output directory: {e}")
        return None
    print(f"Created output directory: {output_path}")
    return output_path
def copy_text(self):
    """Copy the log box's selection to the clipboard and say so in the status bar."""
    log_box, bar = self.top_left_widget, self.statusbar
    log_box.copy()
    bar.showMessage("Copied to clipboard")
def cut_text(self):
    """Cut the log box's selection to the clipboard, if the log box is editable.

    QTextEdit.cut() does nothing on a read-only editor (init_ui makes the log
    box read-only), so report that instead of claiming the text was cut.
    """
    if self.top_left_widget.isReadOnly():
        self.statusbar.showMessage("Cannot cut: the log is read-only")
        return
    self.top_left_widget.cut()  # Trigger cut
    self.statusbar.showMessage("Cut to clipboard")
def paste_text(self):
    """Paste the clipboard into the log box, if the log box is editable.

    QTextEdit ignores pastes while read-only (init_ui makes the log box
    read-only), so report that instead of claiming the paste happened.
    """
    if self.top_left_widget.isReadOnly():
        self.statusbar.showMessage("Cannot paste: the log is read-only")
        return
    self.top_left_widget.paste()  # Trigger paste
    self.statusbar.showMessage("Pasted from clipboard")
def get_training_files(self):
    """Ask for one or more training CSV files; return the chosen paths (empty if cancelled)."""
    chosen_paths, _selected_filter = QFileDialog.getOpenFileNames(
        self, "Select Training CSV Files", os.getcwd(), "CSV Files (*.csv)"
    )
    return chosen_paths
def get_testing_files(self):
    """Ask for a trained model (.pth), then for a video to test it on.

    Returns:
        (model_path, video_path). (None, None) if the model prompt is
        cancelled; video_path is falsy if only the video prompt is cancelled.
    """
    model_choice = QFileDialog.getOpenFileName(
        self, "Select Trained Model File (.pth)", os.getcwd(), "Model Files (*.pth)"
    )
    model_path = model_choice[0]
    if not model_path:
        return None, None

    video_choice = QFileDialog.getOpenFileName(
        self, "Select Video File for Testing", os.getcwd(), "Video Files (*.mp4 *.avi *.mov)"
    )
    return model_path, video_choice[0]
def get_testing_files2(self):
    """Ask for a trained model (.pth), then for a CSV file to test it on.

    Returns:
        (model_path, csv_path). (None, None) if the model prompt is cancelled;
        csv_path is falsy if only the CSV prompt is cancelled.
    """
    # 1. Select Model File (.pth)
    model_path, _ = QFileDialog.getOpenFileName(
        self,
        "Select Trained Model File (.pth)",
        os.getcwd(),
        "Model Files (*.pth)"
    )
    if not model_path:
        return None, None
    # 2. Select CSV File (the dialog previously said "Video File" here)
    csv_path, _ = QFileDialog.getOpenFileName(
        self,
        "Select CSV File for Testing",
        os.getcwd(),
        "Comma Separated Files (*.csv)"
    )
    return model_path, csv_path
def train_model_csv(self):
    """Train a model on user-selected CSV files and save it to a user-chosen .pth path.

    Training runs in a TrainModelThread: its ``update`` signal feeds
    log_training_message and ``finished`` feeds on_train_finished. A new run is
    refused while the previous thread is still running, because rebinding
    self.train_thread could destroy a QThread that has not yet finished.
    """
    previous = getattr(self, "train_thread", None)
    if previous is not None and previous.isRunning():
        self.statusBar().showMessage("Training is already in progress.")
        return

    # 1. Get CSV paths
    csv_paths = self.get_training_files()
    if not csv_paths:
        self.statusBar().showMessage("Training cancelled. No CSV files selected.")
        return

    # 2. Get Model Save Path, pre-filled with a default file name in the CWD
    default_filename = "best_reach_lstm.pth"
    model_save_path, _ = QFileDialog.getSaveFileName(
        self,
        "Save Trained Model File",
        os.path.join(os.getcwd(), default_filename),
        "PyTorch Model Files (*.pth);;All Files (*)"
    )
    if not model_save_path:
        self.statusBar().showMessage("Training cancelled. Model save location not specified.")
        return
    # Ensure correct extension if the user didn't type one
    if not model_save_path.lower().endswith('.pth'):
        model_save_path += '.pth'

    self.statusBar().showMessage(f"Starting model training on {len(csv_paths)} files. Model will save to: {os.path.basename(model_save_path)}")

    # 3. Train in the background
    self.train_thread = TrainModelThread(
        csv_paths=csv_paths,
        model_save_path=model_save_path
    )
    self.train_thread.update.connect(self.log_training_message)
    self.train_thread.finished.connect(self.on_train_finished)
    self.train_thread.start()
def log_training_message(self, message):
    """Show a training progress message in the status bar, prefixed with "Training: ".

    A dedicated log view (e.g. self.top_left_widget) could be used instead.
    """
    status = self.statusBar()
    status.showMessage(f"Training: {message}")
def on_train_finished(self, final_message):
    """Report the training thread's final message (status bar + dialog) and forget the thread."""
    status = self.statusBar()
    status.showMessage(final_message)
    QMessageBox.information(self, "Training Complete", final_message)
    # Drop the reference so a new training run can be started.
    self.train_thread = None
def train_model_folder(self):
    """Placeholder for folder-based training; not implemented, does nothing."""
    return None
def test_model_video(self):
    """Run a trained model on a user-selected video in a TestModelThread.

    The thread handles displaying results itself (an OpenCV window, per the
    original notes); on_test_finished cleans up afterwards. Nothing starts
    while a previous test thread is still running, so a live QThread is never
    replaced.
    """
    previous = getattr(self, "test_thread", None)
    if previous is not None and previous.isRunning():
        self.statusBar().showMessage("A model test is already running.")
        return

    model_path, video_path = self.get_testing_files()
    if not model_path or not video_path:
        self.statusBar().showMessage("Testing cancelled. Model or Video not selected.")
        return

    self.statusBar().showMessage(f"Starting model testing on {os.path.basename(video_path)}...")
    # TestModelThread takes a list of videos; pass the single selection.
    self.test_thread = TestModelThread(video_paths=[video_path], model_path=model_path)
    self.test_thread.finished.connect(self.on_test_finished)
    self.test_thread.start()
def test_model_csv(self):
    """Run a trained model over a user-selected CSV in a TestModelThread2 and plot the result.

    The thread's ``finished`` signal delivers a result dict to
    on_csv_analysis_finished. Nothing starts while a previous test thread is
    still running.
    """
    previous = getattr(self, "test_thread", None)
    if previous is not None and previous.isRunning():
        self.statusBar().showMessage("A model test is already running.")
        return

    model_path, csv_path = self.get_testing_files2()
    if not model_path or not csv_path:
        self.statusBar().showMessage("Testing cancelled. Model or CSV not selected.")
        return

    self.statusBar().showMessage(f"Starting model testing on {os.path.basename(csv_path)}...")
    self.test_thread = TestModelThread2(csv_path=csv_path, model_path=model_path)
    self.test_thread.finished.connect(self.on_csv_analysis_finished)
    self.test_thread.start()
def on_csv_analysis_finished(self, result):
    """Plot the probabilities produced by TestModelThread2, or show its error.

    ``result`` must contain either an "error" message or the "time", "active"
    and "before" sequences. These are presumably equal-length per-frame values
    (TODO confirm against TestModelThread2). The y-axes are fixed to [0, 1].
    """
    if "error" in result:
        QMessageBox.critical(self, "Error", result["error"])
        return

    # Named `times` so the local does not shadow the module-level `time` import;
    # `plt` is the module-level matplotlib.pyplot import.
    times = result["time"]
    active = result["active"]
    before = result["before"]

    fig, axes = plt.subplots(2, 1, figsize=(14, 6), sharex=True)

    # --- Top plot: Reach Active ---
    axes[0].plot(times, active, color="green", label="Reach Active")
    axes[0].set_ylabel("Probability")
    axes[0].set_title("Reach Active")
    axes[0].set_ylim(0, 1)
    axes[0].grid(True)

    # --- Bottom plot: Reach Before Contact ---
    axes[1].plot(times, before, color="orange", label="Reach Before Contact")
    axes[1].set_ylabel("Probability")
    axes[1].set_title("Reach Before Contact")
    axes[1].set_xlabel("Time (s)")
    axes[1].set_ylim(0, 1)
    axes[1].grid(True)

    plt.tight_layout()
    plt.show()
    self.statusBar().showMessage("Complete.")
def on_test_finished(self):
    """Clean up after TestModelThread ends: update status, close OpenCV windows, drop the thread."""
    self.statusBar().showMessage("Model testing complete.")
    try:
        cv2.destroyAllWindows()
    except cv2.error:
        # Best effort: OpenCV builds without GUI support raise cv2.error here,
        # and then there are no windows to close anyway.
        pass
    # Drop the reference so a new test can be started.
    self.test_thread = None
def test_model_folder(self):
    """Placeholder for testing a model on a folder; not implemented, does nothing."""
    return None
def about_window(self):
    """Show the About window, creating a new one unless one is already visible."""
    window = self.about
    if window is None or not window.isVisible():
        window = AboutWindow(self)
        self.about = window
    window.show()
def user_guide(self):
    """Show the User Guide window, creating a new one unless one is already visible."""
    window = self.help
    if window is None or not window.isVisible():
        window = UserGuideWindow(self)
        self.help = window
    window.show()
def terminal_gui(self):
    """Show the terminal window, creating a new one unless one is already visible."""
    window = self.terminal
    if window is None or not window.isVisible():
        window = TerminalWindow(self)
        self.terminal = window
    window.show()
def update_optode_positions(self):
    """Currently disabled: does nothing.

    The code that followed the early ``return`` (opening an
    UpdateOptodesWindow) could never run and was removed.
    """
    return
def update_event_markers(self):
    """Currently disabled: does nothing.

    The code that followed the early ``return`` (opening an
    UpdateEventsWindow) could never run and was removed.
    """
    return
def open_file_dialog(self):
    """Let the user pick a BORIS project file and load it in a background FileLoadWorker.

    The 'Observations' video folder is resolved first (the user may be
    prompted). A ProgressDialog is then shown; its Cancel button calls
    cancel_task. The worker's signals drive the dialog and the
    on_files_loaded / on_loading_finished / on_loading_failed slots.
    """
    file_path, _ = QFileDialog.getOpenFileName(
        self, "Open File", "", "BORIS Files (*.boris);;All Files (*)"
    )
    if not file_path:
        return

    # 1. Locate the folder holding the observation videos
    self.observations_root = self.resolve_path_to_observations(file_path)
    if not self.observations_root:
        self.statusBar().showMessage("File loading cancelled by user.")
        return  # User cancelled the folder selection

    # Stop a load that is still running before self.worker_thread is replaced,
    # so a live QThread is never left unowned (no-op when nothing is running).
    self.cancel_task()

    # 2. Initialize and Show Progress Dialog
    self.progress_dialog = ProgressDialog(self)
    self.progress_dialog.cancel_button.clicked.connect(self.cancel_task)
    self.progress_dialog.show()

    # 3. Initialize and Start Worker Thread
    self.worker_thread = FileLoadWorker(
        file_path,
        self.observations_root,
        self.extract_frame_and_hands
    )

    # 4. Connect Signals to Main Thread Slots
    self.worker_thread.progress_update.connect(self.progress_dialog.update_progress)
    self.worker_thread.observations_loaded.connect(self.on_files_loaded)
    self.worker_thread.loading_finished.connect(self.on_loading_finished)
    self.worker_thread.loading_failed.connect(self.on_loading_failed)
    self.worker_thread.start()
def resolve_path_to_observations(self, boris_file_path):
    """Find the folder that holds the observation videos for a BORIS project.

    An 'Observations' directory beside the .boris file is used when present.
    Otherwise the user is warned and asked to choose a folder:
      * a folder named 'Observations' is used as-is;
      * a folder that contains 'Observations' resolves to that child;
      * any other folder is returned unchanged, in the hope that it holds the videos.
    Returns None if the user cancels the folder dialog.
    """
    boris_dir = os.path.dirname(boris_file_path)
    sibling = os.path.join(boris_dir, "Observations")
    if os.path.isdir(sibling):
        return sibling

    QMessageBox.warning(
        self,
        "Observations Folder Missing",
        "The 'Observations' folder (containing the videos) was not found next to the BORIS file. Please select the correct root folder."
    )
    chosen = QFileDialog.getExistingDirectory(
        self,
        "Select Observations Root Folder",
        boris_dir  # start browsing next to the BORIS file
    )
    if not chosen:
        return None

    if os.path.basename(chosen) != "Observations":
        nested = os.path.join(chosen, "Observations")
        if os.path.isdir(nested):
            return nested
    return chosen
def on_files_loaded(self, boris, observations, root_path):
    """Store the worker's parsed BORIS data, build the preview grid and reveal the Process button."""
    self.boris, self.observations, self.observations_root = boris, observations, root_path

    # The worker gathers the preview frames; the grid is built on the GUI thread.
    self.build_preview_grid(self.worker_thread.previews_data)
    self.statusBar().showMessage(f"{self.worker_thread.file_path} loaded.")
    self.button1.setVisible(True)
def on_loading_finished(self):
    """Close the progress dialog as accepted once loading completes successfully."""
    dialog = self.progress_dialog
    if dialog:
        dialog.accept()
def on_loading_failed(self, error_message):
    """Close the progress dialog (rejected), show *error_message* and update the status bar."""
    dialog = self.progress_dialog
    if dialog:
        dialog.reject()
    QMessageBox.critical(self, "Error Loading Files", error_message)
    self.statusBar().showMessage("File loading failed.")
def cancel_task(self):
    """Stop a running FileLoadWorker, close its progress dialog and report the cancellation.

    Does nothing if no worker is running. Blocks the calling (GUI) thread
    until the worker has exited.
    """
    worker = self.worker_thread
    if not (worker and worker.isRunning()):
        return
    worker.stop()  # ask the worker loop to finish
    worker.wait()  # wait for it to exit safely
    if self.progress_dialog:
        self.progress_dialog.reject()
    self.statusBar().showMessage("File loading cancelled.")
# Add a dummy implementation for the processing function if it's not shown:
# def extract_frame_and_hands(self, video_path, frame_idx):
# # Placeholder implementation - replace with your actual function
# return None, None
def build_preview_grid(self, previews_data):
    """Build one group box per observation, with a preview row per camera.

    Each observation group gets a participant-level "skip / process" dropdown.
    Each camera for which the loading worker produced a frame gets a row with:
    - a 350px preview showing the hand overlay;
    - a hand-selection dropdown ("Skip this camera" or "Use Hand N");
    - a "Skip 1s → Rescan" button.

    Args:
        previews_data: dict keyed by ``(obs_id, cam_id)``. Each value holds
            ``"frame_rgb"``, ``"results"``, ``"video_path"``, ``"fps"`` and
            ``"initial_wrists"``, as gathered by the loading worker. Cameras
            missing from it, or whose ``frame_rgb`` is None, are left out.

    Side effects: fills ``self.selection_widgets`` and ``self.camera_state``
    and adds the groups to ``self.bubble_layout``.
    """
    for obs_id, obs_data in self.observations.items():
        group = QGroupBox(f"Participant / Observation: {obs_id}")
        group_layout = QVBoxLayout(group)

        # Participant-level skip dropdown
        participant_dropdown = QComboBox()
        participant_dropdown.addItem("Skip this participant", 0)
        participant_dropdown.addItem("Process this participant", 1)
        participant_dropdown.setCurrentIndex(0)  # default: skip
        self.selection_widgets.setdefault("participant_selection", {})[obs_id] = participant_dropdown
        group_layout.addWidget(QLabel("Participant Option:"))
        group_layout.addWidget(participant_dropdown)

        camera_dropdowns = self.selection_widgets.setdefault(obs_id, {})

        for cam_id in obs_data.get("file", {}):
            state_key = (obs_id, cam_id)
            if state_key not in previews_data:
                # The worker could not process this file (e.g. path not found)
                continue

            state_data = previews_data[state_key]
            frame_rgb = state_data["frame_rgb"]
            results = state_data["results"]
            if frame_rgb is None:
                continue  # no frame could be read: nothing to preview

            display_img = self.draw_hand_overlay(frame_rgb, results)

            # Convert the RGB array to a pixmap (QPixmap.fromImage copies the pixels)
            h, w, _ = display_img.shape
            qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888)
            pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio)

            # -------- UI Row --------
            row = QWidget()
            row_layout = QHBoxLayout(row)

            preview_label = QLabel()
            preview_label.setPixmap(pix)
            row_layout.addWidget(preview_label)

            # Hand selection: -1 means "skip this camera"
            dropdown = QComboBox()
            dropdown.addItem("Skip this camera", -1)
            if results and results.multi_hand_landmarks:
                for idx in range(len(results.multi_hand_landmarks)):
                    dropdown.addItem(f"Use Hand {idx}", idx)
            row_layout.addWidget(dropdown)

            camera_dropdowns[cam_id] = dropdown
            self.camera_state[state_key] = {
                "path": state_data["video_path"],
                "frame_idx": 0,
                "fps": state_data["fps"],
                "preview_label": preview_label,
                "dropdown": dropdown,
                "initial_wrists": state_data["initial_wrists"],
            }

            # `_` swallows the `checked` argument of clicked; the defaults
            # freeze this iteration's ids.
            skip_btn = QPushButton("Skip 1s → Rescan")
            skip_btn.clicked.connect(lambda _, o=obs_id, c=cam_id: self.skip_and_rescan(o, c))
            row_layout.addWidget(skip_btn)

            group_layout.addWidget(row)

        self.bubble_layout.addWidget(group)
# ============================================
# FRAME EXTRACTION + MEDIA PIPE DETECTION
# ============================================
def extract_frame_and_hands(self, video_path, frame_idx):
    """Read a single frame from a video and run hand detection on it.

    Args:
        video_path: path of the video file to open with OpenCV.
        frame_idx: zero-based index of the frame to seek to.

    Returns:
        ``(rgb_frame, results)``, where ``results`` is the output of
        ``self.mp_hands.process``. Returns ``(None, None)`` if the frame
        could not be read.
    """
    capture = cv2.VideoCapture(video_path)
    try:
        capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ok, bgr_frame = capture.read()
    finally:
        capture.release()

    if not ok:
        return None, None

    # OpenCV decodes to BGR; convert to RGB before hand detection.
    rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
    return rgb_frame, self.mp_hands.process(rgb_frame)
# ============================================
# DRAW OVERLAYED HANDS + BIG LABELS
# ============================================
def draw_hand_overlay(self, img, results, initial_wrists=None):
    """Return a copy of ``img`` with the detected hands drawn and numbered.

    Args:
        img: image array unpacked as ``(height, width, channels)``. It is
            copied and never modified.
        results: hand-detection output. Only ``multi_hand_landmarks`` is read,
            and landmark 0 of each hand is used as its wrist.
        initial_wrists: optional sequence of reference wrist positions
            ``(x, y, ...)`` in the same normalized coordinates as the
            landmarks.
            - When given, each reference wrist, in order, claims the closest
              still-unclaimed detected wrist (squared distance). The reference
              index becomes the drawn label. Detections nobody claims are not
              drawn.
            - Without it, every detection is drawn and labelled by its own
              index.

    Each label is drawn large, 40px above the wrist: a black outline first,
    then a colour chosen from a 3-colour palette by label index.
    """
    palette = [
        (0, 255, 0),    # label 0: green
        (255, 0, 255),  # label 1: magenta
        (0, 255, 255),  # label 2: cyan
    ]

    canvas = img.copy()
    h, w, _ = canvas.shape

    hands = results.multi_hand_landmarks if results else None
    if not hands:
        return canvas

    if initial_wrists:
        label_to_detection = {}  # reference index -> detection index
        claimed = set()
        for label, ref in enumerate(initial_wrists):
            best_j, best_d = None, float('inf')
            for j, hand in enumerate(hands):
                if j in claimed:
                    continue
                wrist = hand.landmark[0]
                d = (wrist.x - ref[0])**2 + (wrist.y - ref[1])**2
                if d < best_d:
                    best_j, best_d = j, d
            if best_j is not None:
                label_to_detection[label] = best_j
                claimed.add(best_j)
    else:
        label_to_detection = {j: j for j in range(len(hands))}

    for label, j in label_to_detection.items():
        hand = hands[j]
        mp.solutions.drawing_utils.draw_landmarks(
            canvas,
            hand,
            mp.solutions.hands.HAND_CONNECTIONS
        )

        wrist = hand.landmark[0]
        anchor = (int(wrist.x * w), int(wrist.y * h) - 40)
        text = str(label)
        # Thick black outline first, then the coloured label on top of it
        for colour, thickness in (((0, 0, 0), 10), (palette[label % len(palette)], 6)):
            cv2.putText(canvas, text, anchor,
                        cv2.FONT_HERSHEY_SIMPLEX, 3.2,
                        colour, thickness, cv2.LINE_AA)

    return canvas
# ============================================
# SKIP 1s AND RESCAN FRAME
# ============================================
def skip_and_rescan(self, obs_id, cam_id):
    """Advance a camera preview by about one second and re-run hand detection.

    - Moves the camera's stored frame index forward by ``int(fps)`` frames,
      or by at least one frame when fps is missing or below 1.
    - Redraws the preview with the hand overlay, using the camera's
      ``initial_wrists``.
    - Repopulates its hand dropdown with the hands found in the new frame.

    The stored frame index only moves when the new frame can actually be
    read. Pressing the button at the end of a video therefore no longer
    pushes the index ever further past the last frame; a status message is
    shown instead.
    """
    state = self.camera_state[(obs_id, cam_id)]

    # Always move forward at least one frame so the button never stalls.
    step = max(1, int(state["fps"] or 0))
    next_idx = state["frame_idx"] + step

    rgb, results = self.extract_frame_and_hands(state["path"], next_idx)
    if rgb is None:
        self.statusBar().showMessage("Could not read the next frame (end of video?).")
        return
    state["frame_idx"] = next_idx

    display_img = self.draw_hand_overlay(rgb, results, state.get("initial_wrists"))
    h, w, _ = display_img.shape
    qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888)
    pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio)

    # Update preview
    state["preview_label"].setPixmap(pix)

    # Rebuild the dropdown for the hands found in the new frame
    dropdown = state["dropdown"]
    dropdown.clear()
    dropdown.addItem("Skip this camera", -1)
    if results and results.multi_hand_landmarks:
        for idx in range(len(results.multi_hand_landmarks)):
            dropdown.addItem(f"Use Hand {idx}", idx)
# def save_project(self, onCrash=False):
# if not onCrash:
# filename, _ = QFileDialog.getSaveFileName(
# self, "Save Project", "", "SPARK Project (*.spark)"
# )
# if not filename:
# return
# else:
# if PLATFORM_NAME == "darwin":
# filename = os.path.join(os.path.dirname(sys.executable), "../../../sparks_autosave.spark")
# else:
# filename = os.path.join(os.getcwd(), "sparks_autosave.spark")
# try:
# # Ensure the filename has the proper extension
# if not filename.endswith(".spark"):
# filename += ".spark"
# project_path = Path(filename).resolve()
# project_dir = project_path.parent
# file_list = [
# str(PurePosixPath(Path(bubble.file_path).resolve().relative_to(project_dir)))
# for bubble in self.bubble_widgets.values()
# ]
# progress_states = {
# str(PurePosixPath(Path(bubble.file_path).resolve().relative_to(project_dir))): bubble.current_step
# for bubble in self.bubble_widgets.values()
# }
# project_data = {
# "file_list": file_list,
# "progress_states": progress_states,
# "raw_haemo_dict": self.raw_haemo_dict,
# "epochs_dict": self.epochs_dict,
# "fig_bytes_dict": self.fig_bytes_dict,
# "cha_dict": self.cha_dict,
# "contrast_results_dict": self.contrast_results_dict,
# "df_ind_dict": self.df_ind_dict,
# "design_matrix_dict": self.design_matrix_dict,
# "age_dict": self.age_dict,
# "gender_dict": self.gender_dict,
# "group_dict": self.group_dict,
# "valid_dict": self.valid_dict,
# }
# def sanitize(obj):
# if isinstance(obj, Path):
# return str(PurePosixPath(obj))
# elif isinstance(obj, dict):
# return {sanitize(k): sanitize(v) for k, v in obj.items()}
# elif isinstance(obj, list):
# return [sanitize(i) for i in obj]
# return obj
# project_data = sanitize(project_data)
# with open(filename, "wb") as f:
# pickle.dump(project_data, f)
# QMessageBox.information(self, "Success", f"Project saved to:\n{filename}")
# except Exception as e:
# if not onCrash:
# QMessageBox.critical(self, "Error", f"Failed to save project:\n{e}")
# def load_project(self):
# filename, _ = QFileDialog.getOpenFileName(
# self, "Load Project", "", "SPARK Project (*.spark)"
# )
# if not filename:
# return
# try:
# with open(filename, "rb") as f:
# data = pickle.load(f)
# self.raw_haemo_dict = data.get("raw_haemo_dict", {})
# self.epochs_dict = data.get("epochs_dict", {})
# self.fig_bytes_dict = data.get("fig_bytes_dict", {})
# self.cha_dict = data.get("cha_dict", {})
# self.contrast_results_dict = data.get("contrast_results_dict", {})
# self.df_ind_dict = data.get("df_ind_dict", {})
# self.design_matrix_dict = data.get("design_matrix_dict", {})
# self.age_dict = data.get("age_dict", {})
# self.gender_dict = data.get("gender_dict", {})
# self.group_dict = data.get("group_dict", {})
# self.valid_dict = data.get("valid_dict", {})
# project_dir = Path(filename).parent
# # Convert saved relative paths to absolute paths
# file_list = [str((project_dir / Path(rel_path)).resolve()) for rel_path in data["file_list"]]
# # Also resolve progress_states with updated paths
# raw_progress = data.get("progress_states", {})
# progress_states = {
# str((project_dir / Path(rel_path)).resolve()): step
# for rel_path, step in raw_progress.items()
# }
# self.show_files_as_bubbles_from_list(file_list, progress_states, filename)
# # Re-enable buttons
# # self.button1.setVisible(True)
# self.button3.setVisible(True)
# QMessageBox.information(self, "Loaded", f"Project loaded from:\n{filename}")
# except Exception as e:
# QMessageBox.critical(self, "Error", f"Failed to load project:\n{e}")
def placeholder(self):
    """Tell the user that the feature behind this action does not exist yet."""
    title, text = "Placeholder", "This feature is not implemented yet."
    QMessageBox.information(self, title, text)
def save_metadata(self, file_path):
    """Snapshot the text of every metadata field under ``file_path``.

    Replaces ``self.file_metadata[file_path]`` with ``{field_key: text}``.
    Does nothing when ``file_path`` is empty or None.
    """
    if file_path:
        snapshot = {}
        for name, widget in self.meta_fields.items():
            snapshot[name] = widget.text()
        self.file_metadata[file_path] = snapshot
def get_all_metadata(self):
    """Return the metadata of every file, committing the current file's edits first.

    Removes focus from all metadata fields, then saves the selected file's
    values via ``save_metadata`` (if a file is selected). Returns
    ``self.file_metadata`` itself, not a copy.
    """
    for meta_field in self.meta_fields.values():
        meta_field.clearFocus()

    current = self.current_file
    if current:
        self.save_metadata(current)

    return self.file_metadata
def cancel_task(self):
    """Cancel a multiprocessing run and restore the Process button.

    Steps:
    1. Kill the descendants of ``self.result_process``, then terminate and
       join the process itself.
    2. Stop ``self.result_timer`` if it is active.
    3. Switch ``button1`` from "Cancel" back to "Process".

    NOTE(review): ``cancel_task`` appears to be defined again further down in
    this class. Python keeps only the last definition in a class body, so this
    version is likely never called. Confirm, then remove or merge it.
    """
    self.button1.clicked.disconnect(self.cancel_task)
    self.button1.setText("Stopping...")

    process = getattr(self, "result_process", None)
    if process is not None and process.is_alive():
        try:
            children = psutil.Process(process.pid).children(recursive=True)
        except psutil.NoSuchProcess:
            children = []  # the worker exited right after is_alive()
        for child in children:
            try:
                child.kill()
            except psutil.NoSuchProcess:
                pass  # already gone
        process.terminate()
        process.join()

    timer = getattr(self, "result_timer", None)
    if timer is not None and timer.isActive():
        timer.stop()

    # Use the QMainWindow statusBar() accessor, like the rest of this class.
    self.statusBar().showMessage("Processing cancelled.")
    self.button1.clicked.connect(self.on_run_task)
    self.button1.setText("Process")
'''MODULE FILE'''
def on_run_task(self):
    """Start processing every camera the user selected in the preview grid.

    Steps:
    1. Turn ``button1`` into a "Cancel" button.
    2. Collect the (participant, camera, hand) selections from the
       preview-grid dropdowns. Participants left on "Skip this participant"
       and cameras left on "Skip this camera" are ignored.
    3. Ask for an output directory.
    4. Replace the preview grid with one progress row per file (file label |
       progress bar | time indicator) and start a ``ParticipantProcessor``
       thread for each file.

    If nothing is selected, or the directory dialog is cancelled, the button
    goes back to "Process" and the preview grid is left intact. The user can
    then adjust the selection and try again.
    """
    self.button1.clicked.disconnect(self.on_run_task)
    self.button1.setText("Cancel")
    self.button1.clicked.connect(self.cancel_task)

    def revert_button():
        # Undo the Process -> Cancel swap above when bailing out early.
        self.button1.clicked.disconnect(self.cancel_task)
        self.button1.setText("Process")
        self.button1.clicked.connect(self.on_run_task)

    # 1. Collect selections BEFORE clearing the grid: the dropdowns live in it,
    #    and the early returns below must leave them usable.
    participant_dropdowns = self.selection_widgets.get("participant_selection", {})
    selected_files_to_process = []
    for obs_id, cam_dropdowns in self.selection_widgets.items():
        if obs_id == "participant_selection":
            continue  # participant-level dropdowns, consulted just below
        participant_dropdown = participant_dropdowns.get(obs_id)
        if participant_dropdown is not None and participant_dropdown.currentData() == 0:
            continue  # participant left on "Skip this participant"
        for cam_id, dropdown in cam_dropdowns.items():
            selected_hand_idx = dropdown.currentData()
            if selected_hand_idx == -1:
                continue  # "Skip this camera"
            state_data = self.camera_state.get((obs_id, cam_id))
            if not state_data:
                continue
            selected_files_to_process.append({
                "obs_id": obs_id,
                "cam_id": cam_id,
                "hand_idx": selected_hand_idx,
                "path": state_data.get("path", "Unknown"),
                "initial_wrists": state_data.get("initial_wrists"),
            })

    if not selected_files_to_process:
        self.statusBar().showMessage("No files selected for processing (use dropdowns).")
        revert_button()
        return

    # 2. Ask where results go; cancelling keeps the preview grid as it is.
    self.output_dir = self.get_output_directory()
    if not self.output_dir:
        revert_button()
        self.statusBar().showMessage("Processing cancelled by user.")
        return

    # 3. Swap the preview grid for progress rows and start one thread per file.
    self.clear_bubble_widgets()
    self.processing_widgets = {}
    self.processing_threads = []

    self.statusBar().showMessage(f"Starting to process {len(selected_files_to_process)} files...")
    QApplication.processEvents()  # Refresh GUI before starting threads

    for file_info in selected_files_to_process:
        obs_id = file_info["obs_id"]
        cam_id = file_info["cam_id"]

        # --- Progress row: 50% label | 40% bar | 10% time (stretch 5/4/1) ---
        row = QWidget()
        row_layout = QHBoxLayout(row)
        row_layout.setContentsMargins(0, 0, 0, 0)

        filename = os.path.basename(file_info["path"])
        filename_label = QLabel(f"P {obs_id}/{cam_id}: {filename}")
        filename_label.setWordWrap(True)
        row_layout.addWidget(filename_label, stretch=5)

        progress_bar = QProgressBar()
        progress_bar.setRange(0, 100)
        row_layout.addWidget(progress_bar, stretch=4)

        time_label = QLabel("00:00/00:00")  # placeholder until the thread reports
        time_label.setAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter)
        row_layout.addWidget(time_label, stretch=1)

        row_index = self.bubble_layout.rowCount()
        self.bubble_layout.addWidget(row, row_index, 0, 1, 1)

        # Keep references so thread signals / on_processing_finished can update the row
        self.processing_widgets[(obs_id, cam_id)] = {
            "label": filename_label,
            "progress": progress_bar,
            "time_label": time_label,
        }

        # --- Thread start ---
        processor = ParticipantProcessor(
            obs_id=obs_id,
            boris_json=self.boris,
            selected_cam_id=cam_id,
            selected_hand_idx=file_info["hand_idx"],
            observations_root=self.observations_root,
            hide_preview=False,
            output_csv=f"{obs_id}_{cam_id}_processed.csv",
            output_dir=self.output_dir,
            initial_wrists=file_info["initial_wrists"],
        )

        processor.progress_updated.connect(progress_bar.setValue)
        # Defaults freeze this iteration's ids. NOTE(review): if
        # finished_processing is declared with arguments, they may override
        # these defaults; confirm the signal's signature.
        processor.finished_processing.connect(
            lambda obs_id=obs_id, cam_id=cam_id: self.on_processing_finished(obs_id, cam_id)
        )
        processor.time_updated.connect(time_label.setText)

        processor.start()
        self.processing_threads.append(processor)

    # Give the row after the last one all the stretch so rows stay at the top
    self.bubble_layout.setRowStretch(self.bubble_layout.rowCount(), 1)
def check_for_pipeline_results(self):
    """Hook for polling pipeline results; intentionally a no-op for now."""
    return None
def on_processing_finished(self, obs_id, cam_id):
    """Mark one file's progress row as complete.

    Once no processing thread reports ``isRunning()`` any more, restores the
    UI via ``cancel_task`` and shows "All processing complete!".

    NOTE(review): if ``finished_processing`` is emitted from inside the
    thread's ``run()``, that thread may still report ``isRunning()`` when this
    slot executes. The "all done" branch would then be skipped for the last
    file. Confirm.
    """
    key = (obs_id, cam_id)
    if key in self.processing_widgets:
        row_widgets = self.processing_widgets[key]
        row_widgets["progress"].setValue(100)
        row_widgets["label"].setText(f"P {obs_id}/{cam_id}: ✅ COMPLETE")

    still_running = any(t.isRunning() for t in self.processing_threads)
    if still_running:
        return
    self.cancel_task()  # reverts the button and clears the thread list
    self.statusBar().showMessage("All processing complete!")
def cancel_task(self):
    """Stop background work and return the UI to its idle "Process" state.

    Handles two kinds of work:
    - A running file-loading worker (``self.worker_thread``): it is stopped,
      its progress dialog is rejected and "File loading cancelled." is shown.
    - ``ParticipantProcessor`` threads (``self.processing_threads``): all
      running threads are asked to stop, then waited for. The list is cleared
      and ``button1`` switches from "Cancel" back to "Process".

    Also called by ``on_processing_finished`` once all threads are done.
    Safe to call when nothing is running, including before the first
    processing run, when ``processing_threads`` may not exist yet.

    NOTE(review): this class appears to define ``cancel_task`` several times.
    Python keeps only the last definition in a class body, which is why the
    file-loading cancellation is handled here too. Confirm which definition
    is intended and remove the others.
    """
    worker = getattr(self, "worker_thread", None)
    if worker is not None and worker.isRunning():
        worker.stop()
        worker.wait()
        dialog = getattr(self, "progress_dialog", None)
        if dialog:
            dialog.reject()
        self.statusBar().showMessage("File loading cancelled.")

    threads = getattr(self, "processing_threads", None)
    if not threads:
        return

    self.statusBar().showMessage("Cancelling processing... Please wait.")

    # Signal every thread first, then wait, so they shut down in parallel
    # rather than one after another.
    running = [t for t in threads if t.isRunning()]
    for t in running:
        t.stop()
    for t in running:
        t.wait()

    threads.clear()

    # Revert the Cancel button back to Process
    self.button1.clicked.disconnect(self.cancel_task)
    self.button1.setText("Process")
    self.button1.clicked.connect(self.on_run_task)

    self.statusBar().showMessage("Processing cancelled. Ready to load or process.")
def show_error_popup(self, title, error_message, traceback_str=""):
    """Show a modal warning that processing one file failed.

    Args:
        title: name of the file that failed. Shown in the message text.
        error_message: short error description.
        traceback_str: optional full traceback, placed in the dialog's
            "Show Details" section (plain text).
    """
    import html

    msgbox = QMessageBox(self)
    msgbox.setIcon(QMessageBox.Icon.Warning)
    msgbox.setWindowTitle("Warning - SPARKS")

    # The body is rendered as rich text, so escape the dynamic parts. Python
    # error messages often contain '<'/'>' (e.g. "<class 'int'>"), which Qt
    # would otherwise parse as markup and silently drop.
    message = (
        f"SPARKS has encountered an error processing the file {html.escape(str(title))}.<br><br>"
        "This error was likely due to incorrect parameters on the right side of the screen and not an error with your data. "
        "Processing of the remaining files continues in the background and this participant will be ignored in the analysis. "
        "If you think the parameters on the right side are correct for your data, raise an issue <a href='https://git.research.dezeeuw.ca/tyler/sparks/issues'>here</a>.<br><br>"
        f"Error message: {html.escape(str(error_message))}"
    )
    msgbox.setTextFormat(Qt.TextFormat.RichText)
    msgbox.setText(message)
    msgbox.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)

    # Detailed text is plain text, so the traceback needs no escaping there
    if traceback_str:
        msgbox.setDetailedText(traceback_str)

    msgbox.setStandardButtons(QMessageBox.StandardButton.Ok)
    msgbox.exec()  # exec_() is deprecated in PySide6
def cleanup_after_process(self):
    """Release the multiprocessing resources left behind by a processing run.

    Each step only runs if the attribute exists:
    - Reaps ``result_process``, terminating it if it is still alive.
    - Closes the plain multiprocessing queues among ``result_queue`` and
      ``progress_queue``.
    - Shuts down ``manager`` (which also stops its server process).
    """
    if hasattr(self, 'result_process'):
        proc = self.result_process
        proc.join(timeout=0)  # non-blocking reap if it has already finished
        if proc.is_alive():
            proc.terminate()
            proc.join()

    for attr_name in ('result_queue', 'progress_queue'):
        if not hasattr(self, attr_name):
            continue
        q = getattr(self, attr_name)
        # Manager-created queues are proxies (repr contains "AutoProxy").
        # Presumably they have no close()/join_thread() of their own, so only
        # plain multiprocessing queues are closed here.
        if 'AutoProxy' in repr(q):
            continue
        q.close()
        q.join_thread()

    # Shutdown manager to kill its server process and clean up
    if hasattr(self, 'manager'):
        self.manager.shutdown()
'''UPDATER'''
def manual_check_for_updates(self):
    """Check for an already-downloaded update on disk before asking the server.

    Starts a ``LocalPendingUpdateCheckThread``. Its ``pending_update_found``
    signal goes to ``on_pending_update_found`` and ``no_pending_update`` to
    ``on_no_pending_update``.
    """
    checker = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix)
    checker.pending_update_found.connect(self.on_pending_update_found)
    checker.no_pending_update.connect(self.on_no_pending_update)
    # Held on self (presumably so the thread object outlives this method)
    self.local_check_thread = checker
    checker.start()
def on_pending_update_found(self, version, folder_path):
    """Remember the locally staged update and ask the user whether to install it."""
    self.statusBar().showMessage(f"Pending update found: version {version}")
    self.pending_update_version, self.pending_update_path = version, folder_path
    self.show_pending_update_popup()
def on_no_pending_update(self):
    """Handle "no update staged locally" by going straight to the server check."""
    status = self.statusBar()
    status.showMessage("No pending local update found. Checking server...")
    self.start_update_check_thread()
def show_pending_update_popup(self):
    """Offer to install the previously downloaded update.

    If the user installs, ``install_update`` is called. Otherwise a reminder
    is shown. In either case the server is checked afterwards, in case an
    even newer release exists.
    """
    box = QMessageBox(self)
    box.setWindowTitle("Pending Update Found")
    box.setText(
        f"A previously downloaded update (version {self.pending_update_version}) is available at:\n"
        f"{self.pending_update_path}\n"
        "Would you like to install it now?"
    )
    install_now = box.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole)
    box.addButton("Install Later", QMessageBox.ButtonRole.RejectRole)
    box.exec()

    wants_install = box.clickedButton() == install_now
    if wants_install:
        self.install_update(self.pending_update_path)
    else:
        self.statusBar().showMessage("Pending update available. Install later.")

    self.start_update_check_thread()
def start_update_check_thread(self):
    """Ask the release server for a newer version in a background thread.

    Signal routing:
    - ``download_requested`` -> ``on_server_update_requested``
    - ``no_update_available`` -> ``on_server_no_update``
    - ``error_occurred`` -> ``on_error``
    """
    checker = UpdateCheckThread()
    wiring = (
        (checker.download_requested, self.on_server_update_requested),
        (checker.no_update_available, self.on_server_no_update),
        (checker.error_occurred, self.on_error),
    )
    for signal, slot in wiring:
        signal.connect(slot)
    self.check_thread = checker
    checker.start()
def on_server_no_update(self):
    """Tell the user, for 5 seconds, that the server has nothing newer."""
    bar = self.statusBar()
    bar.showMessage("No new updates found on server.", 5000)
def on_server_update_requested(self, download_url, latest_version):
    """Decide whether the server's release should be downloaded.

    - No update staged locally: download the server release.
    - Server release newer than the staged one: delete the staged folder,
      forget it and download the server release.
    - Equal or older: keep the staged update and only report it.
    """
    pending = self.pending_update_version
    if not pending:
        # No pending update, just download
        self.download_update(download_url, latest_version)
        return

    order = self.version_compare(latest_version, pending)
    if order == 0:
        self.statusBar().showMessage(f"Pending update version {pending} is already latest. No download needed.")
        return
    if order < 0:
        # Server older than what is staged (unlikely): keep the staged update
        self.statusBar().showMessage(f"Pending update version {pending} is newer than server version. No action.")
        return

    # Server release is newer: discard the stale staged update first
    self.statusBar().showMessage(f"Newer version {latest_version} available on server. Removing old pending update...")
    try:
        shutil.rmtree(self.pending_update_path)
        self.statusBar().showMessage(f"Deleted old update folder: {self.pending_update_path}")
    except Exception as e:
        self.statusBar().showMessage(f"Failed to delete old update folder: {e}")

    self.pending_update_version = None
    self.pending_update_path = None
    self.download_update(download_url, latest_version)
def download_update(self, download_url, latest_version):
    """Download and extract ``latest_version`` from ``download_url`` in the background.

    ``update_ready`` goes to ``on_update_ready``; ``error_occurred`` goes to
    ``on_error``.
    """
    self.statusBar().showMessage("Downloading update...")
    downloader = UpdateDownloadThread(download_url, latest_version)
    downloader.update_ready.connect(self.on_update_ready)
    downloader.error_occurred.connect(self.on_error)
    self.download_thread = downloader
    downloader.start()
def on_update_ready(self, latest_version, extract_folder):
    """Ask whether to install a freshly downloaded update right away."""
    self.statusBar().showMessage("Update downloaded and extracted.")

    box = QMessageBox(self)
    box.setWindowTitle("Update Ready")
    box.setText(
        f"Version {latest_version} has been downloaded and extracted to:\n"
        f"{extract_folder}\n"
        "Would you like to install it now?"
    )
    install_now = box.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole)
    box.addButton("Install Later", QMessageBox.ButtonRole.RejectRole)
    box.exec()

    wants_install = box.clickedButton() == install_now
    if wants_install:
        self.install_update(extract_folder)
    else:
        self.statusBar().showMessage("Update ready. Install later.")
def install_update(self, extract_folder):
    """Hand an extracted update to the external ``sparks_updater`` and exit.

    Locates the platform-specific updater and launches it with the update
    folder and this program's path, so the updater can relaunch the app. The
    current process then exits with ``sys.exit(0)``. Shows an error box
    instead if the updater is missing or cannot be started.
    """
    updater_names = {'windows': "sparks_updater.exe", 'linux': "sparks_updater"}
    if PLATFORM_NAME == 'darwin':
        if getattr(sys, 'frozen', False):
            updater_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks_updater.app")
        else:
            updater_path = os.path.join(os.getcwd(), "../sparks_updater.app")
    elif PLATFORM_NAME in updater_names:
        updater_path = os.path.join(os.getcwd(), updater_names[PLATFORM_NAME])
    else:
        updater_path = os.getcwd()

    if not os.path.exists(updater_path):
        QMessageBox.critical(self, "Error", f"Updater not found at:\n{updater_path}. The absolute path was {os.path.abspath(updater_path)}")
        return

    try:
        # Passed along so the updater can relaunch this app afterwards
        main_app_executable = os.path.abspath(sys.argv[0])
        print(f'Launching updater with: "{updater_path}" "{extract_folder}" "{main_app_executable}"')

        if PLATFORM_NAME == 'darwin':
            launch_args = ['open', updater_path, '--args', extract_folder, main_app_executable]
            subprocess.Popen(launch_args)
        else:
            launch_args = [updater_path, f'{extract_folder}', f'{main_app_executable}']
            subprocess.Popen(launch_args, cwd=os.path.dirname(updater_path))

        # Exit so the updater can replace this program's files
        sys.exit(0)
    except Exception as e:
        QMessageBox.critical(self, "Error", f"[Updater Launch Failed]\n{e}\n{traceback.format_exc()}")
def on_error(self, message):
    """Report an error from the update check/download threads in the status bar."""
    text = f"Error occurred during update process. {message}"
    self.statusBar().showMessage(text)
def version_compare(self, v1, v2):
    """Compare two dotted version strings numerically.

    Missing trailing components count as zero, so "1.0" equals "1.0.0".
    Surrounding whitespace and a leading "v"/"V" are ignored.

    Returns:
        1 if v1 > v2, -1 if v1 < v2, 0 if they are equal.

    Raises:
        ValueError: if a component is not an integer.
    """
    def normalize(v):
        text = v.strip()
        if text[:1] in ("v", "V"):
            text = text[1:]
        return [int(part) for part in text.split(".")]

    a, b = normalize(v1), normalize(v2)
    # Pad to equal length: plain list comparison would rank [1, 0] below [1, 0, 0]
    width = max(len(a), len(b))
    a += [0] * (width - len(a))
    b += [0] * (width - len(b))
    return (a > b) - (a < b)
def closeEvent(self, event):
    """Qt close handler: tear down background resources, then accept the close.

    Shuts down the multiprocessing manager (if one was created), closes every
    other child widget that is still visible, and kills child processes.
    """
    print("Window is closing. Cleaning up...")

    if hasattr(self, 'manager'):
        self.manager.shutdown()

    for widget in self.findChildren(QWidget):
        # Visibility is checked lazily, as closing one widget can hide others
        if widget is self or not widget.isVisible():
            continue
        widget.close()

    kill_child_processes()
    event.accept()
def wait_for_process_to_exit(process_name, timeout=10):
    """
    Poll the process table until no process whose name contains ``process_name`` remains.

    Matching is case-insensitive. The table is re-scanned every 0.5 s.

    Args:
        process_name (str): Name (or part of the name) of the process to wait for.
        timeout (int, optional): Maximum time to wait in seconds. Defaults to 10.

    Returns:
        bool: True if no matching process was found before the timeout, False otherwise.
    """
    print(f"Waiting for {process_name} to exit...")
    needle = process_name.lower()
    deadline = time.time() + timeout

    while time.time() < deadline:
        for proc in psutil.process_iter(['name']):
            try:
                name = proc.info['name']
                if name and needle in name.lower():
                    print(f"Still running: {name} (PID: {proc.pid})")
                    break
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
        else:
            # The scan finished without finding a match
            print(f"{process_name} has exited.")
            return True
        time.sleep(0.5)

    print(f"{process_name} did not exit in time.")
    return False
def finish_update_if_needed():
    """
    Completes a pending application update if '--finish-update' is present in the command-line arguments.

    Steps (only when the flag is present):
      1. Locate the extracted update folder: a directory named
         "sparks-*-<platform>" in /tmp/sparkstempupdate on macOS, or in the
         current working directory elsewhere. On macOS the payload sits in a
         nested "sparks-darwin" folder.
      2. Wait up to 5 s for each running sparks_updater process, killing any
         that are still alive after that.
      3. Replace the installed sparks_updater with the copy from the update
         folder. On macOS this uses `ditto` and then removes the quarantine
         attribute. On Linux/macOS the result is made executable.
      4. Delete the update folder (on macOS the whole /tmp staging directory).
    Finally it shows an "Update Complete" message and removes the flag from
    sys.argv.

    Failures in steps 3 and 4 are printed and do not stop the function. The
    "Update Complete" message is therefore shown even if replacing the
    updater failed.

    NOTE(review): if no update folder is found, the function returns early and
    '--finish-update' stays in sys.argv.
    NOTE(review): QMessageBox.information needs an existing QApplication;
    presumably this is called after it is created. Confirm.
    """
    if "--finish-update" in sys.argv:
        print("Finishing update...")
        if PLATFORM_NAME == 'darwin':
            app_dir = '/tmp/sparkstempupdate'
        else:
            app_dir = os.getcwd()
        # 1. Find update folder
        update_folder = None
        for entry in os.listdir(app_dir):
            entry_path = os.path.join(app_dir, entry)
            if os.path.isdir(entry_path) and entry.startswith("sparks-") and entry.endswith("-" + PLATFORM_NAME):
                update_folder = os.path.join(app_dir, entry)
                break
        if update_folder is None:
            print("No update folder found. Skipping update steps.")
            return
        if PLATFORM_NAME == 'darwin':
            # The macOS archive nests its payload one level deeper
            update_folder = os.path.join(update_folder, "sparks-darwin")
        # 2. Wait for sparks_updater to exit
        print("Waiting for sparks_updater to exit...")
        for proc in psutil.process_iter(['pid', 'name']):
            if proc.info['name'] and "sparks_updater" in proc.info['name'].lower():
                try:
                    proc.wait(timeout=5)
                except psutil.TimeoutExpired:
                    # NOTE(review): kill() may raise psutil.NoSuchProcess if the
                    # process exits right after the timeout; not handled here.
                    print("Force killing lingering sparks_updater")
                    proc.kill()
        # 3. Replace the updater
        if PLATFORM_NAME == 'windows':
            new_updater = os.path.join(update_folder, "sparks_updater.exe")
            dest_updater = os.path.join(app_dir, "sparks_updater.exe")
        elif PLATFORM_NAME == 'darwin':
            new_updater = os.path.join(update_folder, "sparks_updater.app")
            # Installed updater bundle is resolved relative to the running executable
            dest_updater = os.path.abspath(os.path.join(sys.executable, "../../../../sparks_updater.app"))
        elif PLATFORM_NAME == 'linux':
            new_updater = os.path.join(update_folder, "sparks_updater")
            dest_updater = os.path.join(app_dir, "sparks_updater")
        else:
            print("No platform??")
            new_updater = os.getcwd()
            dest_updater = os.getcwd()
        # Diagnostics for troubleshooting failed updates
        print(f"New updater is {new_updater}")
        print(f"Dest updater is {dest_updater}")
        print("Writable?", os.access(dest_updater, os.W_OK))
        print("Executable path:", sys.executable)
        print("Trying to copy:", new_updater, "->", dest_updater)
        if os.path.exists(new_updater):
            try:
                # Remove the old updater first (a .app bundle is a directory on macOS)
                if os.path.exists(dest_updater):
                    if PLATFORM_NAME == 'darwin':
                        try:
                            if os.path.isdir(dest_updater):
                                shutil.rmtree(dest_updater)
                                print(f"Deleted directory: {dest_updater}")
                            else:
                                os.remove(dest_updater)
                                print(f"Deleted file: {dest_updater}")
                        except Exception as e:
                            print(f"Error deleting {dest_updater}: {e}")
                    else:
                        os.remove(dest_updater)
                if PLATFORM_NAME == 'darwin':
                    wait_for_process_to_exit("sparks_updater", timeout=10)
                    # ditto copies whole .app bundles, including their metadata
                    subprocess.check_call(["ditto", new_updater, dest_updater])
                else:
                    shutil.copy2(new_updater, dest_updater)
                if PLATFORM_NAME in ('linux', 'darwin'):
                    os.chmod(dest_updater, 0o755)
                if PLATFORM_NAME == 'darwin':
                    remove_quarantine(dest_updater)
                print("sparks_updater replaced.")
            except Exception as e:
                print(f"Failed to replace sparks_updater: {e}")
        # 4. Delete the update folder
        try:
            if PLATFORM_NAME == 'darwin':
                shutil.rmtree(app_dir)
            else:
                shutil.rmtree(update_folder)
        except Exception as e:
            print(f"Failed to delete update folder: {e}")
        QMessageBox.information(None, "Update Complete", "The application has been successfully updated.")
        sys.argv.remove("--finish-update")
def remove_quarantine(app_path):
    """
    Removes the macOS quarantine attribute from the specified application path.

    Runs ``xattr -d -r com.apple.quarantine <app_path>`` through ``osascript``
    with administrator privileges, which shows the system password prompt.
    This is best effort: a failing command is reported on stdout and not
    raised.
    """
    shell_command = f"xattr -d -r com.apple.quarantine {shlex.quote(app_path)}"
    # The shell command is embedded in an AppleScript string literal, so its
    # backslashes and double quotes must be escaped for AppleScript too.
    # shlex.quote() only protects against the shell, and it emits '"' for
    # paths containing an apostrophe.
    applescript_literal = shell_command.replace("\\", "\\\\").replace('"', '\\"')
    script = (
        f'do shell script "{applescript_literal}" with administrator privileges '
        'with prompt "SPARKS needs privileges to finish the update. (2/2)"'
    )
    try:
        subprocess.run(['osascript', '-e', script], check=True)
        print("✅ Quarantine attribute removed.")
    except subprocess.CalledProcessError as e:
        print("❌ Failed to remove quarantine attribute.")
        print(e)
def resource_path(relative_path):
    """
    Resolve a bundled resource to an absolute path.

    Inside a PyInstaller bundle the base directory is ``sys._MEIPASS``.
    Otherwise it is the directory containing this source file.
    """
    base_path = sys._MEIPASS if hasattr(sys, '_MEIPASS') else os.path.dirname(os.path.abspath(__file__))
    return os.path.join(base_path, relative_path)
def kill_child_processes():
    """
    Kill every descendant process of the current process.

    Waits up to 5 seconds for them to exit. Any error is printed rather than
    raised, so this is safe to call from shutdown and crash handlers.
    """
    try:
        descendants = psutil.Process(os.getpid()).children(recursive=True)
        for proc in descendants:
            try:
                proc.kill()
            except psutil.NoSuchProcess:
                pass  # already exited
        psutil.wait_procs(descendants, timeout=5)
    except Exception as e:
        print(f"Error killing child processes: {e}")
def exception_hook(exc_type, exc_value, exc_traceback):
    """
    Last-resort handler for uncaught exceptions (installed as sys.excepthook).

    Steps:
    1. Print the traceback.
    2. Kill child processes.
    3. Show the crash dialog, creating a QApplication first if none exists.
    4. Exit with status 1.

    If showing the dialog itself fails, that secondary error is printed and
    the process still exits.
    """
    error_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
    print(error_msg)  # also print to console
    kill_child_processes()

    # A QApplication must exist before any widget (the message box) is created.
    # The local reference keeps a newly created one alive while the dialog runs.
    app = QApplication.instance()
    if app is None:
        app = QApplication(sys.argv)

    try:
        show_critical_error(error_msg)
    except Exception:
        print("Failed to show the crash dialog:")
        traceback.print_exc()

    # Exit the app after user acknowledges (or after the dialog failed)
    sys.exit(1)
def show_critical_error(error_msg):
    """
    Show a modal "Something went wrong!" dialog for an unrecoverable error.

    Before the dialog is shown, the current log (``sparks.log``) is copied to
    ``sparks_error.log``, and an autosave is attempted via
    ``window.save_project(True)``. Both steps are best-effort. The application
    is already failing, so none of the following may prevent the dialog from
    appearing:

    - a missing log file;
    - a ``window`` that was never created;
    - a save that raises.

    The dialog links to the error log and the autosave file and shows the
    traceback text.

    Args:
        error_msg: Formatted traceback text to include in the dialog.
    """
    import html  # only needed on this crash path

    msg_box = QMessageBox()
    msg_box.setIcon(QMessageBox.Icon.Critical)
    msg_box.setWindowTitle("Something went wrong!")

    if PLATFORM_NAME == "darwin":
        # NOTE(review): presumably sys.executable lives in <App>.app/Contents/MacOS,
        # making ../../.. the folder that contains the .app bundle — confirm.
        base_dir = os.path.join(os.path.dirname(sys.executable), "../../..")
    else:
        base_dir = os.getcwd()
    log_path = os.path.join(base_dir, "sparks.log")
    log_path2 = os.path.join(base_dir, "sparks_error.log")
    save_path = os.path.join(base_dir, "sparks_autosave.spark")

    try:
        shutil.copy(log_path, log_path2)
    except OSError as e:
        print(f"Failed to copy log file to {log_path2}: {e}")

    log_file = Path(log_path2).absolute()
    autosave_file = Path(save_path).absolute()
    log_path2 = log_file.as_posix()
    autosave_path = autosave_file.as_posix()
    # as_uri() builds a valid file URL on every platform and percent-encodes
    # spaces/quotes that would break the href attribute. Prefixing "file:///"
    # by hand would double the leading slash of a POSIX path.
    log_link = log_file.as_uri()
    autosave_link = autosave_file.as_uri()

    try:
        # `window` is the module-level main window assigned in the __main__ block.
        window.save_project(True)
    except Exception as e:
        # e.g. NameError if the crash happened before the window was created.
        print(f"Autosave failed: {e}")

    # Escape everything interpolated into rich text. Tracebacks contain text such
    # as "<module>" that Qt would otherwise parse as markup and silently drop.
    message = (
        "SPARKS has encountered an unrecoverable error and needs to close.<br><br>"
        f"We are sorry for the inconvenience. An autosave was attempted to be saved to <a href='{autosave_link}'>{html.escape(autosave_path)}</a>, but it may not have been saved. "
        "If the file was saved, it still may not be intact, openable, or contain the correct data. Use the autosave at your discretion.<br><br>"
        "This unrecoverable error was likely due to an error with SPARKS and not your data.<br>"
        f"Please raise an issue <a href='https://git.research.dezeeuw.ca/tyler/sparks/issues'>here</a> and attach the error file located at <a href='{log_link}'>{html.escape(log_path2)}</a><br><br>"
        f"<pre>{html.escape(error_msg)}</pre>"
    )
    msg_box.setTextFormat(Qt.TextFormat.RichText)
    msg_box.setText(message)
    msg_box.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
    msg_box.setStandardButtons(QMessageBox.StandardButton.Ok)
    msg_box.exec()
if __name__ == "__main__":
    # Required for PyInstaller + multiprocessing. It must be the first statement
    # here. In a frozen build, multiprocessing children re-execute this script,
    # and freeze_support() is what diverts them into their worker code. Anything
    # placed above it would also run in every child, e.g. deleting and
    # reopening the main process's log file.
    freeze_support()

    # Redirect exceptions to the popup window
    sys.excepthook = exception_hook

    # Set up application logging
    if PLATFORM_NAME == "darwin":
        log_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks.log")
    else:
        log_path = os.path.join(os.getcwd(), "sparks.log")
    try:
        os.remove(log_path)  # start every session with a fresh log
    except OSError:
        pass  # typically FileNotFoundError on first launch
    # Line-buffered so each printed line reaches the file promptly (the crash
    # handler copies this file). UTF-8 is explicit because the platform default
    # (e.g. cp1252 on Windows) cannot encode every character that may be
    # printed, such as the ✅/❌ status messages or non-ASCII paths in
    # tracebacks. An encoding error inside print() would itself crash the app.
    sys.stdout = open(log_path, "a", buffering=1, encoding="utf-8")
    sys.stderr = sys.stdout
    print(f"\n=== App started at {datetime.now()} ===\n")

    # Only run GUI in the main process
    if current_process().name == 'MainProcess':
        app = QApplication(sys.argv)
        finish_update_if_needed()
        window = MainApplication()
        icon_file = "icons/main.icns" if PLATFORM_NAME == "darwin" else "icons/main.ico"
        icon = QIcon(resource_path(icon_file))
        app.setWindowIcon(icon)
        window.setWindowIcon(icon)
        window.show()
        sys.exit(app.exec())
# Not 4000 lines yay!