3223 lines
122 KiB
Python
3223 lines
122 KiB
Python
"""
|
||
Filename: main.py
|
||
Description: SPARKS main executable
|
||
|
||
Author: Tyler de Zeeuw
|
||
License: GPL-3.0
|
||
"""
|
||
|
||
# Built-in imports
|
||
import os
|
||
import re
|
||
import csv
|
||
import sys
|
||
import math
|
||
import json
|
||
import time
|
||
import shlex
|
||
import pickle
|
||
import shutil
|
||
import zipfile
|
||
import platform
|
||
import traceback
|
||
import subprocess
|
||
from itertools import product
|
||
from pathlib import Path, PurePosixPath
|
||
from datetime import datetime
|
||
from multiprocessing import Process, current_process, freeze_support, Manager
|
||
import cv2
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
# External library imports
|
||
import psutil
|
||
import requests
|
||
import torch
|
||
import torch.nn as nn
|
||
import mediapipe as mp
|
||
from torch.utils.data import TensorDataset, DataLoader
|
||
from sklearn.utils.class_weight import compute_class_weight
|
||
import matplotlib.pyplot as plt
|
||
|
||
from PySide6.QtWidgets import (
|
||
QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox, QGridLayout,
|
||
QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView, QMenu, QProgressBar, QCheckBox
|
||
)
|
||
from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize, QPoint
|
||
from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel, QStandardItem, QImage
|
||
from PySide6.QtSvgWidgets import QSvgWidget # needed to show svgs when app is not frozen
|
||
|
||
|
||
# Version of this build; the update checker compares release tags against it.
CURRENT_VERSION = "1.0.0"

# Release feeds queried by the update checker (primary, then secondary).
API_URL = "https://git.research.dezeeuw.ca/api/v1/repos/tyler/sparks/releases"
API_URL_SECONDARY = "https://git.research2.dezeeuw.ca/api/v1/repos/tyler/sparks/releases"

# Lower-cased operating system name as reported by platform.system().
PLATFORM_NAME = platform.system().lower()

# Selectable parameters on the right side of the window
SECTIONS = [
    {
        "title": "Parameter 1",
        "params": [
            {"name": "Number 1", "default": True, "type": bool, "help": "N/A"},
            {"name": "Number 2", "default": True, "type": bool, "help": "N/A"},
        ],
    },
]
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
class TrainModelThread(QThread):
    """
    Thread that trains a sliding-window LSTM classifier on landmark CSV files.

    Feature columns are the columns whose names start with "lm" (taken from the
    first CSV that is read); the label columns are "reach_active" and
    "reach_before_contact".

    Signals:
        update(str): Progress messages for the GUI.
        finished(str): Final summary, or a "Training failed: ..." message when
            training raised an exception.

    Args:
        csv_paths: Iterable of CSV file paths to train on.
        model_save_path: File the best model's state_dict is written to.
        parent (optional): Parent object passed to QThread. Defaults to None.
    """

    update = Signal(str)  # new: emits messages for the GUI
    finished = Signal(str)

    def __init__(self, csv_paths, model_save_path, parent=None):
        super().__init__(parent)
        self.csv_paths = csv_paths
        self.model_save_path = model_save_path

    def run(self):
        """Train the model; always emit `finished`, even when training fails."""
        try:
            self._train()
        except Exception as exc:
            # Without this, any failure (bad CSV, too little data, a label class
            # that never occurs) would end the thread without a `finished` signal.
            import traceback
            traceback.print_exc()
            self.finished.emit(f"Training failed: {exc}")

    def _train(self):
        """Load the CSVs, build windows, run the sweep and emit the summary."""
        FPS, WINDOW_SEC, STEP_SEC, THRESHOLD, EPOCHS = 60, 0.4, 0.2, 0.5, 150
        LABEL_FRACTION = 0.3  # a window is positive when >= 30% of its frames are

        # Load CSVs
        X_list, y_list, feature_cols = [], [], None
        for path in self.csv_paths:
            df = pd.read_csv(path)
            if feature_cols is None:
                feature_cols = [c for c in df.columns if c.startswith("lm")]
            # Fill gaps forward first, then backward for leading gaps.
            df[feature_cols] = df[feature_cols].ffill().bfill()
            X_list.append(df[feature_cols].values)
            y_list.append(df[["reach_active", "reach_before_contact"]].values)

        if not X_list:
            raise ValueError("No CSV files were provided for training.")

        # NOTE(review): the files are concatenated before windowing, so a window
        # can straddle the boundary between two CSVs.
        X = np.concatenate(X_list, axis=0)
        y = np.concatenate(y_list, axis=0)

        # Windowing
        window_length = int(FPS * WINDOW_SEC)
        step_length = int(FPS * STEP_SEC)
        X_windows, y_windows = [], []
        for start in range(0, len(X) - window_length + 1, step_length):
            end = start + window_length
            y_win = y[start:end]
            label_window = (y_win.sum(axis=0) / len(y_win) >= LABEL_FRACTION).astype(int)  # shape: (2,)
            X_windows.append(X[start:end])
            y_windows.append(label_window)

        if not X_windows:
            raise ValueError(
                f"Not enough data: {len(X)} rows, but one window needs {window_length}."
            )

        X_windows = np.array(X_windows)
        y_windows = np.array(y_windows)

        # Per-label positive weight = weight(class 1) / weight(class 0).
        class_weights = []
        for i in range(2):
            cw = compute_class_weight('balanced', classes=np.array([0, 1]), y=y_windows[:, i])
            class_weights.append(cw[1] / cw[0])
        pos_weight = torch.tensor(class_weights, dtype=torch.float32)

        X_tensor = torch.tensor(X_windows, dtype=torch.float32)  # (N, window_length, features)
        y_tensor = torch.tensor(y_windows, dtype=torch.float32)  # (N, 2)

        # LSTM
        class WindowLSTM(nn.Module):
            def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
                super().__init__()
                self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True,
                                    bidirectional=bidirectional)
                self.fc = nn.Linear(hidden_size * (2 if bidirectional else 1), output_size)

            def forward(self, x):
                _, (h, _) = self.lstm(x)
                if self.lstm.bidirectional:
                    h = torch.cat([h[-2], h[-1]], dim=1)
                else:
                    h = h[-1]
                return self.fc(h)

        # Hyperparameter sweep
        HIDDEN_SIZES = [64]
        BATCH_SIZES = [32]
        LRs = [0.0005]
        OPTIMIZERS = ['adam']
        BIDIRECTIONAL = [False]

        best_f1 = 0
        best_config = None

        for hidden_size, batch_size, lr, opt_name, bi in product(
                HIDDEN_SIZES, BATCH_SIZES, LRs, OPTIMIZERS, BIDIRECTIONAL):
            self.update.emit(f"Training model: hidden={hidden_size}, batch={batch_size}, lr={lr}, opt={opt_name}, bidir={bi}")

            dataset = TensorDataset(X_tensor, y_tensor)
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
            model = WindowLSTM(input_size=X_windows.shape[2], hidden_size=hidden_size, bidirectional=bi)
            if opt_name == 'adam':
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                optimizer = torch.optim.SGD(model.parameters(), lr=lr)
            criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

            f1 = 0.0
            # Training
            for epoch in range(EPOCHS):
                model.train()
                epoch_loss = 0
                for xb, yb in loader:
                    optimizer.zero_grad()
                    logits = model(xb)
                    loss = criterion(logits, yb)
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()

                # Evaluation (on the same windows the model was trained on)
                model.eval()
                with torch.no_grad():
                    logits = model(X_tensor)
                    probs = torch.sigmoid(logits)
                    preds = (probs > THRESHOLD).int()
                    y_true = y_tensor.int()

                tp = ((preds == 1) & (y_true == 1)).sum().item()
                fp = ((preds == 1) & (y_true == 0)).sum().item()
                fn = ((preds == 0) & (y_true == 1)).sum().item()

                # The small epsilon avoids division by zero when a count is 0.
                precision = tp / (tp + fp + 1e-8)
                recall = tp / (tp + fn + 1e-8)
                f1 = 2 * precision * recall / (precision + recall + 1e-8)

                self.update.emit(f"Epoch {epoch+1}/{EPOCHS} Loss={epoch_loss:.4f} "
                                 f"Precision={precision:.3f} Recall={recall:.3f} F1={f1:.3f}")

            # NOTE(review): the original indentation was lost; the checkpoint is
            # assumed to run once per configuration, after its last epoch.
            # Confirm against the repository before relying on this.
            if f1 > best_f1:
                best_f1 = f1
                best_config = (hidden_size, batch_size, lr, opt_name, bi)
                torch.save(model.state_dict(), self.model_save_path)
                self.update.emit(f"New best model saved to {os.path.basename(self.model_save_path)}!")

        self.finished.emit(f"Training complete!\nBest config: {best_config}\nF1={best_f1:.3f}")
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
class TestModelThread(QThread):
    """
    Thread that plays a video in an OpenCV window with live model predictions.

    Each frame is run through MediaPipe Hands; the right hand's landmarks are
    appended to a sliding window, and once the window holds `window_frames`
    rows it is fed to the saved LSTM. The two sigmoid outputs are drawn on the
    frame as "Reach Active" and "Reach Before Contact".

    Signals:
        update_frame(np.ndarray): Declared for frames with overlay; this class
            does not emit it.
        finished(): Emitted when playback ends, including on failure.

    Args:
        video_paths: Sequence of video paths; only the first entry is played.
        model_path: Path of the saved state_dict to load.
        fps (int): Frame rate used for the window length and the time overlay.
        window_sec (float): Window length in seconds.
        step_sec (float): Stored on the instance; not used by this class.
        threshold (float): Probability above which a prediction counts as YES.
    """

    update_frame = Signal(np.ndarray)  # emit frames with overlay
    finished = Signal()

    def __init__(self, video_paths, model_path, fps=60, window_sec=0.4, step_sec=0.2, threshold=0.5):
        super().__init__()
        self.video_paths = video_paths
        self.fps = fps
        self.window_sec = window_sec
        self.step_sec = step_sec
        self.threshold = threshold
        self.model_path = model_path
        self.window_frames = int(fps * window_sec)

    def run(self):
        """Play the first video; `finished` is emitted even if playback fails."""
        try:
            self._process_video()
        finally:
            self.finished.emit()

    @staticmethod
    def _build_model():
        """Create an untrained WindowLSTM with the configuration used in training."""

        class WindowLSTM(torch.nn.Module):
            def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size, hidden_size, batch_first=True, bidirectional=bidirectional
                )
                self.fc = torch.nn.Linear(hidden_size * (2 if bidirectional else 1), output_size)

            def forward(self, x):
                _, (h, _) = self.lstm(x)
                if self.lstm.bidirectional:
                    h = torch.cat([h[-2], h[-1]], dim=1)
                else:
                    h = h[-1]
                return self.fc(h)

        # MATCH THE MODEL CONFIG YOU TRAINED WITH
        return WindowLSTM(input_size=63, hidden_size=64, bidirectional=False)

    @staticmethod
    def _interpolate_gaps(buffer):
        """Return `buffer` as a list of lists with NaNs linearly interpolated per column.

        Columns that contain no valid value at all are left as NaN.
        """
        arr = np.array(buffer, dtype=np.float32)
        for c in range(arr.shape[1]):
            valid = ~np.isnan(arr[:, c])
            if valid.any():
                arr[:, c][~valid] = np.interp(
                    np.flatnonzero(~valid),
                    np.flatnonzero(valid),
                    arr[:, c][valid],
                )
        return arr.tolist()

    def _process_video(self):
        """Load the model, open the video and the display window, then play."""
        if not self.video_paths:
            print("No video was supplied for testing.")
            return

        model = self._build_model()
        model.load_state_dict(torch.load(self.model_path, map_location="cpu"))
        model.eval()
        print("Loaded model.")

        # ------------------------------ Open video ------------------------------
        cap = cv2.VideoCapture(self.video_paths[0])
        try:
            if not cap.isOpened():
                # An unopened capture reports width 0, which previously caused
                # a ZeroDivisionError in the aspect-ratio computation.
                print(f"Could not open video: {self.video_paths[0]}")
                return

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            print(f"Original video dimensions: {original_width} x {original_height}")
            print(f"Total frames in video: {total_frames}")

            if original_width <= 0 or original_height <= 0:
                print("Video reports invalid dimensions; cannot display it.")
                return

            # Set display dimensions (maintaining aspect ratio). The height is
            # derived from the same width that is used for resizing, so frames
            # are no longer stretched.
            display_width = 1000
            aspect_ratio = original_height / original_width
            display_height = int(display_width * aspect_ratio)
            print(f"Display dimensions: {display_width} x {display_height}")

            window_name = "Hand Tracking Test - Press 'q' to quit"

            # ------------------------------ MediaPipe ------------------------------
            with mp.solutions.hands.Hands(
                static_image_mode=False,
                max_num_hands=2,
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5,
                model_complexity=0,
            ) as hands:
                cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
                try:
                    cv2.resizeWindow(window_name, display_width, display_height)
                    self._play(cap, hands, model, window_name, (display_width, display_height))
                finally:
                    cv2.destroyWindow(window_name)
        finally:
            cap.release()

    def _play(self, cap, hands, model, window_name, display_size):
        """Read frames until the video ends or 'q' is pressed, drawing predictions."""
        mp_hands = mp.solutions.hands
        mp_drawing = mp.solutions.drawing_utils
        mp_drawing_styles = mp.solutions.drawing_styles

        # Holds the most recent landmark rows (at most `window_frames` of them).
        window_buffer = []
        frame_index = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            h, w = frame.shape[:2]

            # detect hands
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = hands.process(rgb)

            # find RIGHT hand
            features = []
            right_found = False
            if results.multi_hand_landmarks and results.multi_handedness:
                for i, handness in enumerate(results.multi_handedness):
                    if handness.classification[0].label != "Right":
                        continue
                    lm = results.multi_hand_landmarks[i]
                    right_found = True
                    for p in lm.landmark:
                        features += [p.x, p.y, p.z]

                    # draw hand with SAME STYLES
                    mp_drawing.draw_landmarks(
                        frame,
                        lm,
                        mp_hands.HAND_CONNECTIONS,
                        mp_drawing_styles.get_default_hand_landmarks_style(),
                        mp_drawing_styles.get_default_hand_connections_style(),
                    )
                    break

            if not right_found:
                # 63 matches the model's input_size; presumably 21 landmarks x 3.
                features = [math.nan] * 63

            # Maintain window buffer: append, interpolate missing values over the
            # whole buffer, then drop the oldest row once it is over-full.
            window_buffer.append(features)
            window_buffer = self._interpolate_gaps(window_buffer)
            if len(window_buffer) > self.window_frames:
                window_buffer.pop(0)

            pred_active = 0
            pred_before_contact = 0

            # run inference when full window available
            if len(window_buffer) == self.window_frames:
                x = torch.tensor(window_buffer, dtype=torch.float32).unsqueeze(0)
                with torch.no_grad():
                    logits = model(x)
                    probs = torch.sigmoid(logits).squeeze(0).numpy()

                pred_active = int(probs[0] > self.threshold)
                pred_before_contact = int(probs[1] > self.threshold)

                color_active = (0, 255, 0) if pred_active else (0, 0, 255)
                color_before_contact = (255, 255, 0) if pred_before_contact else (0, 0, 128)

                # Overlay rectangles
                cv2.rectangle(frame, (0, 0), (w // 2, 40), color_active, -1)
                cv2.putText(frame, f"Reach Active: {'YES' if pred_active else 'NO'} ({probs[0]:.2f})",
                            (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

                cv2.rectangle(frame, (w // 2, 0), (w, 40), color_before_contact, -1)
                cv2.putText(frame, f"Reach Before Contact: {'YES' if pred_before_contact else 'NO'} ({probs[1]:.2f})",
                            (w // 2 + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
            else:
                cv2.rectangle(frame, (0, 0), (w, 40), (0, 128, 255), -1)
                cv2.putText(
                    frame,
                    f"Collecting window: {len(window_buffer)}/{self.window_frames}",
                    (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.0,
                    (255, 255, 255),
                    2,
                )

            # Extra informational overlay lines
            info_text = [
                f"Frame: {frame_index}",
                f"Time: {frame_index / self.fps:.2f}s",
                f"Reach Active: {'YES' if pred_active else 'NO'}",
                f"Reach Before Contact: {'YES' if pred_before_contact else 'NO'}",
                f"Right Hand: {'DETECTED' if right_found else 'NOT DETECTED'}",
                "Press 'q' to quit",
            ]
            for i, text in enumerate(info_text):
                y = 60 + i * 25
                cv2.putText(frame, text, (10, y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

            # If predicted active, add highlight frame
            if pred_active or pred_before_contact:
                cv2.rectangle(frame, (0, 0), (w, h), (0, 255, 0), 6)

            # resize for display
            resized_frame = cv2.resize(frame, display_size, interpolation=cv2.INTER_AREA)
            cv2.imshow(window_name, resized_frame)

            frame_index += 1

            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
|
||
|
||
|
||
|
||
|
||
|
||
class TestModelThread2(QThread):
    """
    Thread that runs the saved LSTM over a landmark CSV with a sliding window.

    Signals:
        finished(object): Emitted with a dict holding "time", "active",
            "before" and "threshold", or with {"error": message} when the run
            raised an exception.

    Args:
        csv_path: CSV file whose columns starting with "lm" are the features.
        model_path: Path of the saved state_dict to load.
        fps (int): Frame rate; used for the window length and fallback time axis.
        window_sec (float): Window length in seconds.
        threshold (float): Passed through unchanged in the result dict.
    """

    finished = Signal(object)

    def __init__(self, csv_path, model_path, fps=60, window_sec=0.4, threshold=0.5):
        super().__init__()
        self.csv_path = csv_path
        self.model_path = model_path
        self.fps = fps
        self.window_sec = window_sec
        self.window_frames = int(fps * window_sec)
        self.threshold = threshold

    def run(self):
        """Run the inference; report any exception through `finished`."""
        try:
            self._run_internal()
        except Exception as exc:
            import traceback
            traceback.print_exc()
            message = str(exc)
            self.error = message
            self.finished.emit({"error": message})

    def _run_internal(self):
        """Load model and CSV, slide the window over every row, emit the results."""

        # 1. Model definition and load
        class WindowLSTM(torch.nn.Module):
            def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size, hidden_size, batch_first=True, bidirectional=bidirectional
                )
                directions = 2 if bidirectional else 1
                self.fc = torch.nn.Linear(hidden_size * directions, output_size)

            def forward(self, x):
                _, (hidden, _) = self.lstm(x)
                if self.lstm.bidirectional:
                    last = torch.cat([hidden[-2], hidden[-1]], dim=1)
                else:
                    last = hidden[-1]
                return self.fc(last)

        # These values must match the configuration the saved model was trained with.
        model = WindowLSTM(input_size=63, hidden_size=64, bidirectional=False)
        state = torch.load(self.model_path, map_location="cpu")
        model.load_state_dict(state)
        model.eval()

        # 2. Load the landmark CSV (values may contain NaNs)
        df = pd.read_csv(self.csv_path)
        feature_cols = [name for name in df.columns if name.startswith("lm")]
        rows = df[feature_cols].values
        if "time_sec" in df.columns:
            timestamps = df["time_sec"].values
        else:
            timestamps = np.arange(len(rows)) / self.fps

        # 3. Sliding-window inference
        window = []
        active_scores = []
        before_scores = []
        warmup = self.window_frames - 1  # rows seen before the first full window

        for index, row in enumerate(rows):
            # Pad the outputs with 0 until a full window is available.
            if index < warmup:
                active_scores.append(0)
                before_scores.append(0)

            window.append(row)
            if len(window) > self.window_frames:
                window.pop(0)

            # Linearly interpolate NaNs column by column; a column with no
            # valid value at all is left untouched.
            grid = np.array(window, dtype=np.float32)
            for column in grid.T:  # each `column` is a view into `grid`
                missing = np.isnan(column)
                if not missing.all():
                    column[missing] = np.interp(
                        np.flatnonzero(missing),
                        np.flatnonzero(~missing),
                        column[~missing],
                    )

            # Keep the interpolated values in the buffer as a list of lists.
            window = grid.tolist()

            if len(window) != self.window_frames:
                continue

            batch = torch.tensor(window, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                scores = torch.sigmoid(model(batch)).squeeze(0).numpy()

            active_scores.append(scores[0])
            before_scores.append(scores[1])

        self.finished.emit({
            "time": timestamps,
            "active": active_scores,
            "before": before_scores,
            "threshold": self.threshold,
        })
|
||
|
||
|
||
|
||
class TerminalWindow(QWidget):
    """
    Small command window: a read-only output pane above a single input line.

    Commands are looked up by their first word in `self.commands`; the
    remaining words are passed to the handler as positional arguments.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("Terminal - SPARKS")

        self.output_area = QTextEdit()
        self.output_area.setReadOnly(True)

        self.input_line = QLineEdit()
        self.input_line.returnPressed.connect(self.handle_command)

        box = QVBoxLayout()
        for widget in (self.output_area, self.input_line):
            box.addWidget(widget)
        self.setLayout(box)

        # Command name -> handler
        self.commands = {
            "hello": self.cmd_hello,
            "help": self.cmd_help
        }

    def handle_command(self):
        """Echo the entered line, then run the matching command (if any)."""
        entered = self.input_line.text()
        self.input_line.clear()
        self.output_area.append(f"> {entered}")

        words = entered.strip().split()
        if not words:
            return

        command_name, args = words[0], words[1:]
        handler = self.commands.get(command_name)
        if not handler:
            self.output_area.append(f"[Unknown command] '{command_name}'")
            return

        try:
            outcome = handler(*args)
            if outcome:
                self.output_area.append(str(outcome))
        except Exception as e:
            self.output_area.append(f"[Error] {e}")

    def cmd_hello(self, *args):
        """Return a fixed greeting; arguments are ignored."""
        return "Hello from the terminal!"

    def cmd_help(self, *args):
        """Return the comma-separated list of registered command names."""
        names = ', '.join(self.commands.keys())
        return f"Available commands: {names}"
|
||
|
||
|
||
class UpdateDownloadThread(QThread):
    """
    Thread that downloads an update archive, extracts it and reports the result.

    On macOS the archive is staged under /tmp/sparkstempupdate and unpacked
    with `ditto`; on other platforms the current working directory and
    `zipfile` are used.

    Signals:
        update_ready(str, str): Emitted with (latest_version, extract_path)
            once the archive has been extracted and removed.
        error_occurred(str): Emitted with the error message on any failure.

    Args:
        download_url (str): URL of the update zip file to download.
        latest_version (str): Version string of the latest update.
    """

    update_ready = Signal(str, str)
    error_occurred = Signal(str)

    def __init__(self, download_url, latest_version):
        super().__init__()
        self.download_url = download_url
        self.latest_version = latest_version

    def run(self):
        """Download, extract, clean up, and emit the matching signal."""
        on_mac = PLATFORM_NAME == 'darwin'

        def staging_dir():
            # Evaluated at each use, like the original os.getcwd() calls.
            return '/tmp/sparkstempupdate' if on_mac else os.getcwd()

        try:
            archive_name = os.path.basename(self.download_url)

            if on_mac:
                os.makedirs(staging_dir(), exist_ok=True)
            archive_path = os.path.join(staging_dir(), archive_name)

            # Stream the download to disk in 8 KiB chunks.
            with requests.get(self.download_url, stream=True, timeout=15) as response:
                response.raise_for_status()
                with open(archive_path, 'wb') as out:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            out.write(chunk)

            # The extraction folder is the archive name without its extension.
            folder_name, _ = os.path.splitext(archive_name)
            extract_path = os.path.join(staging_dir(), folder_name)
            os.makedirs(extract_path, exist_ok=True)

            if on_mac:
                subprocess.run(['ditto', '-xk', archive_path, extract_path], check=True)
            else:
                with zipfile.ZipFile(archive_path, 'r') as archive:
                    archive.extractall(extract_path)

            # Remove the zip once extracted and announce the update.
            os.remove(archive_path)
            self.update_ready.emit(self.latest_version, extract_path)

        except Exception as e:
            # Any failure along the way is reported as a message.
            self.error_occurred.emit(str(e))
|
||
|
||
|
||
|
||
class UpdateCheckThread(QThread):
    """
    Thread that checks for updates by querying the API and emits a signal based on the result.

    Signals:
        download_requested(str, str): Emitted with (download_url, latest_version) when an update is available.
        no_update_available(): Emitted when no update is found or current version is up to date.
        error_occurred(str): Emitted with an error message if the update check fails.
    """

    download_requested = Signal(str, str)
    no_update_available = Signal()
    error_occurred = Signal(str)

    def run(self):
        """Query the release feeds and emit exactly one of the three signals."""
        if not getattr(sys, 'frozen', False):
            self.error_occurred.emit("Application is not frozen (Development mode).")
            return
        try:
            latest_version, download_url = self.get_latest_release_for_platform()
            if not latest_version:
                self.no_update_available.emit()
                return

            if not download_url:
                self.error_occurred.emit(f"No download available for platform '{PLATFORM_NAME}'")
                return

            if self.version_compare(latest_version, CURRENT_VERSION) > 0:
                self.download_requested.emit(download_url, latest_version)
            else:
                self.no_update_available.emit()

        except Exception as e:
            self.error_occurred.emit(f"Update check failed: {e}")

    def version_compare(self, v1, v2):
        """Compare dotted numeric versions; return 1, 0 or -1.

        Raises:
            ValueError: If a component of either version is not an integer.
        """
        def normalize(v):
            return [int(x) for x in v.split(".")]

        a, b = normalize(v1), normalize(v2)
        return (a > b) - (a < b)

    def get_latest_release_for_platform(self):
        """Return (tag, download_url) for the newest release.

        Each feed URL is tried in order; a feed that fails with a request or
        JSON error is skipped. The first feed that answers decides the result:
        (None, None) for an empty release list, (tag, None) when no asset name
        contains the platform name. (None, None) is also returned when every
        feed failed.
        """
        for url in (API_URL, API_URL_SECONDARY):
            try:
                # Query the feed for this iteration. Previously API_URL was
                # requested every time, so the secondary feed was never used.
                response = requests.get(url, timeout=5)
                response.raise_for_status()
                releases = response.json()
            except (requests.RequestException, ValueError):
                continue

            if not releases:
                return None, None

            latest = releases[0]
            tag = latest["tag_name"].lstrip("v")

            for asset in latest.get("assets", []):
                if PLATFORM_NAME in asset["name"].lower():
                    return tag, asset["browser_download_url"]

            return tag, None

        return None, None
|
||
|
||
|
||
class LocalPendingUpdateCheckThread(QThread):
    """
    Thread that checks for locally pending updates by scanning the download directory and emits a signal accordingly.

    Signals:
        pending_update_found(str, str): Emitted with (folder_version, folder_path)
            for the first folder whose version is newer than the current one.
        no_pending_update(): Emitted when no such folder was found.

    Args:
        current_version (str): Current application version.
        platform_suffix (str): Platform-specific suffix to identify update folders.
    """

    pending_update_found = Signal(str, str)
    no_pending_update = Signal()

    def __init__(self, current_version, platform_suffix):
        super().__init__()
        self.current_version = current_version
        self.platform_suffix = platform_suffix

    def version_compare(self, v1, v2):
        """Compare dotted numeric versions; return 1, 0 or -1."""
        def normalize(v):
            return [int(x) for x in v.split(".")]

        a, b = normalize(v1), normalize(v2)
        return (a > b) - (a < b)

    def run(self):
        """Scan the update directory and emit the matching signal."""
        if PLATFORM_NAME == 'darwin':
            cwd = '/tmp/sparkstempupdate'
        else:
            cwd = os.getcwd()

        # Folder names look like "<anything>-<major.minor.patch><platform_suffix>".
        pattern = re.compile(r".*-(\d+\.\d+\.\d+)" + re.escape(self.platform_suffix) + r"$")
        found = False

        try:
            for item in os.listdir(cwd):
                folder_path = os.path.join(cwd, item)
                if not (os.path.isdir(folder_path) and item.endswith(self.platform_suffix)):
                    continue
                match = pattern.match(item)
                if not match:
                    continue
                folder_version = match.group(1)
                if self.version_compare(folder_version, self.current_version) > 0:
                    self.pending_update_found.emit(folder_version, folder_path)
                    found = True
                    break
        except OSError:
            # Expected when the directory does not exist or cannot be read:
            # there is simply no pending update.
            pass
        except Exception:
            # Still best-effort, but unexpected errors are no longer hidden
            # (the previous bare `except:` also swallowed SystemExit and
            # KeyboardInterrupt).
            import traceback
            traceback.print_exc()

        if not found:
            self.no_pending_update.emit()
|
||
|
||
|
||
|
||
|
||
|
||
class ProgressDialog(QDialog):
    """
    Application-modal dialog with a message, a progress bar and a Cancel button.

    The bar starts indeterminate (maximum 0) and gets its real maximum from the
    first call to `update_progress`. The Cancel button is created here but not
    connected to anything in this class.

    Args:
        parent (QWidget, optional): Parent widget of this dialog. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Processing Files...")
        self.setWindowModality(Qt.WindowModality.ApplicationModal)
        self.setMinimumWidth(300)

        self.message_label = QLabel("Loading and analyzing video files...")

        self.progress_bar = QProgressBar()
        self.progress_bar.setMaximum(0)  # Indeterminate state initially

        self.cancel_button = QPushButton("Cancel")

        column = QVBoxLayout(self)
        for widget in (self.message_label, self.progress_bar, self.cancel_button):
            column.addWidget(widget)

    def update_progress(self, current, total):
        """Show `current` of `total`, leaving the indeterminate state if needed."""
        bar = self.progress_bar
        still_indeterminate = bar.maximum() == 0
        if still_indeterminate:
            bar.setMaximum(total)

        bar.setValue(current)
        self.message_label.setText(f"Processing file {current} of {total}...")
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
class FileLoadWorker(QThread):
    """
    Thread that loads a BORIS JSON file and prepares a preview for each video.

    Results of the preview pass are stored in `self.previews_data`, keyed by
    (obs_id, cam_id).

    Signals:
        observations_loaded(dict, dict, str): Emitted with the full BORIS data,
            its "observations" mapping and the observations root.
        progress_update(int, int): Emitted with (current, total) per processed entry.
        loading_finished(): Emitted at the end of `run`, whatever the outcome.
        loading_failed(str): Emitted with a message when loading fails.

    Args:
        file_path: Path of the BORIS JSON file.
        observations_root: Directory the video paths are resolved against.
        extract_frame_and_hands_func: Callable invoked as func(video_path, 0);
            it is expected to return (frame_rgb, results).
    """

    observations_loaded = Signal(dict, dict, str)
    progress_update = Signal(int, int)
    loading_finished = Signal()
    loading_failed = Signal(str)

    def __init__(self, file_path, observations_root, extract_frame_and_hands_func):
        super().__init__()
        self.file_path = file_path
        self.observations_root = observations_root
        self.extract_frame_and_hands = extract_frame_and_hands_func
        self.is_running = True
        self.previews_data = {}  # To store the results of the heavy work

    def resolve_video_path(self, relative_video_path):
        """Resolve a video path: try root + basename first, else root + relative path."""
        by_basename = os.path.join(self.observations_root, os.path.basename(relative_video_path))
        if os.path.exists(by_basename):
            return by_basename
        # Fallback to full relative path join
        return os.path.join(self.observations_root, relative_video_path)

    def run(self):
        """Load the JSON, build the previews and emit the results."""
        try:
            # 1. Load the BORIS JSON (Non-UI work)
            with open(self.file_path, "r") as handle:
                boris = json.load(handle)

            observations = boris.get("observations", {})
            if observations:
                # 2. Gather the initial previews (the heavy part)
                self.process_previews(observations)
                # 3. Emit the final results back
                self.observations_loaded.emit(boris, observations, self.observations_root)
            else:
                self.loading_failed.emit("No observations found in BORIS JSON.")

        except Exception as e:
            self.loading_failed.emit(f"Loading failed: {str(e)}")
        finally:
            # Always emitted, so a waiting dialog can close.
            self.loading_finished.emit()

    def process_previews(self, observations):
        """Extract the first frame and wrist positions for each camera's first video."""
        total_files = sum(len(entry.get("file", {})) for entry in observations.values())
        done = 0

        for obs_id, obs_data in observations.items():
            files_dict = obs_data.get("file", {})
            media_info = obs_data.get("media_info", {})
            fps_dict = media_info.get("fps", {})
            # Simplified FPS: first listed value, or 30.0 when none is given.
            if fps_dict:
                fps_default = float(next(iter(fps_dict.values())))
            else:
                fps_default = 30.0

            print(media_info)
            for cam_id, paths in files_dict.items():
                if not self.is_running:
                    return  # Allow cancellation

                if not paths:
                    print(f"Skipping entry for {obs_id}/{cam_id}: paths list is empty (IndexError prevented).")
                    done += 1
                    self.progress_update.emit(done, total_files)
                    continue

                print(paths)
                video_path = self.resolve_video_path(paths[0])
                print(video_path)

                if os.path.exists(video_path):
                    # HEAVY OPERATION
                    frame_rgb, results = self.extract_frame_and_hands(video_path, 0)

                    # Landmark 0 of every detected hand is recorded as the wrist.
                    initial_wrists = []
                    if results and results.multi_hand_landmarks:
                        for hand in results.multi_hand_landmarks:
                            wrist = hand.landmark[0]
                            initial_wrists.append((wrist.x, wrist.y))

                    # Store the results needed to BUILD the UI later
                    self.previews_data[(obs_id, cam_id)] = {
                        "frame_rgb": frame_rgb,
                        "results": results,
                        "video_path": video_path,
                        "fps": fps_default,
                        "initial_wrists": initial_wrists
                    }

                done += 1
                self.progress_update.emit(done, total_files)

    def stop(self):
        """Ask the preview loop to stop before its next entry."""
        self.is_running = False
|
||
|
||
|
||
|
||
|
||
class AboutWindow(QWidget):
    """
    About window listing the application name, licence note and version.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("About SPARKS")
        self.resize(250, 100)

        column = QVBoxLayout()

        static_lines = (
            "About SPARKS",
            "Spacial Patterns Analysis, Research, & Knowledge Suite",
            "SPARKS is licensed under the GPL-3.0 licence. For more information, visit https://www.gnu.org/licenses/gpl-3.0.en.html",
        )
        for text in static_lines:
            column.addWidget(QLabel(text, self))

        # The version label is created without an explicit parent.
        column.addWidget(QLabel(f"Version v{CURRENT_VERSION}"))

        self.setLayout(column)
|
||
|
||
|
||
|
||
class UserGuideWindow(QWidget):
    """
    User Guide window; currently it only shows placeholder text.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("User Guide - SPARKS")
        self.resize(250, 100)

        column = QVBoxLayout()
        for text in ("Not Applicable:", "N/A"):
            column.addWidget(QLabel(text, self))
        self.setLayout(column)
|
||
|
||
|
||
|
||
class ParamSection(QWidget):
    """
    A widget section that dynamically creates labeled input fields from parameter metadata.

    Each parameter becomes one row holding a "?" help button, a name label and an
    input widget chosen from the declared type: bool -> True/False combo box,
    int / float -> validated line edit, list -> checkable multi-select dropdown,
    anything else -> plain line edit.

    Args:
        section_data (dict): Dictionary containing section title and list of parameter info.
            Expected format:
            {
                "title": str,
                "params": [
                    {
                        "name": str,
                        "type": type,
                        "default": any,
                        "help": str (optional)
                    },
                    ...
                ]
            }
    """

    def __init__(self, section_data):
        super().__init__()
        layout = QVBoxLayout()
        self.setLayout(layout)
        # Maps parameter name -> {"widget": input widget, "type": declared type}.
        self.widgets = {}

        self.selected_path = None
        # Re-entrancy guard for the multi-select check-state handlers.
        self._updating_checkstates = False

        # Title label
        title_label = QLabel(section_data["title"])
        title_label.setStyleSheet("font-weight: bold; font-size: 14px; margin-top: 10px; margin-bottom: 5px;")
        layout.addWidget(title_label)

        # Horizontal line
        line = QFrame()
        line.setFrameShape(QFrame.Shape.HLine)
        line.setFrameShadow(QFrame.Shadow.Sunken)
        layout.addWidget(line)

        for param in section_data["params"]:
            h_layout = QHBoxLayout()
            help_text = param.get("help", "")

            label = QLabel(param["name"])
            label.setFixedWidth(180)
            label.setToolTip(help_text)

            help_btn = QPushButton("?")
            help_btn.setFixedWidth(25)
            help_btn.setToolTip(help_text)
            # Bind help_text as a default argument so every button keeps its own
            # text instead of the value from the last loop iteration.
            help_btn.clicked.connect(lambda _, text=help_text: self.show_help_popup(text))

            # Relative stretch factors are 1 : 3 : 5 (button : label : input).
            h_layout.addWidget(help_btn)
            h_layout.setStretch(0, 1)

            h_layout.addWidget(label)
            h_layout.setStretch(1, 3)

            widget = self._create_input_widget(param)
            widget.setToolTip(help_text)

            h_layout.addWidget(widget)
            h_layout.setStretch(2, 5)

            layout.addLayout(h_layout)
            self.widgets[param["name"]] = {
                "widget": widget,
                "type": param["type"]
            }

    def _create_input_widget(self, param):
        """Return the input widget matching param["type"], pre-filled with its default."""
        param_type = param["type"]
        if param_type is bool:
            widget = QComboBox()
            widget.addItems(["True", "False"])
            widget.setCurrentText(str(param["default"]))
            return widget
        if param_type is list:
            # The dropdown starts without selectable entries; the default is not used here.
            return self._create_multiselect_dropdown(None)

        widget = QLineEdit()
        if param_type is int:
            widget.setValidator(QIntValidator())
        elif param_type is float:
            widget.setValidator(QDoubleValidator())
        widget.setText(str(param["default"]))
        return widget

    def _create_multiselect_dropdown(self, items):
        """
        Build a combo box whose entries are checkable.

        Row 0 is a non-checkable "<None Selected>" placeholder, row 1 is a
        "Toggle Select All" entry, and any rows after that are the real choices.

        Args:
            items (iterable of str or None): Labels of the selectable entries.

        Returns:
            FullClickComboBox: The configured combo box.
        """
        combo = FullClickComboBox()
        combo.setView(QListView())
        model = QStandardItemModel()
        combo.setModel(model)
        combo.setEditable(True)
        combo.lineEdit().setReadOnly(True)
        combo.lineEdit().setPlaceholderText("Select...")
        combo.setInsertPolicy(QComboBox.InsertPolicy.NoInsert)

        checkable_flags = Qt.ItemFlag.ItemIsUserCheckable | Qt.ItemFlag.ItemIsEnabled

        dummy_item = QStandardItem("<None Selected>")
        dummy_item.setFlags(Qt.ItemFlag.ItemIsEnabled)
        model.appendRow(dummy_item)

        toggle_item = QStandardItem("Toggle Select All")
        toggle_item.setFlags(checkable_flags)
        toggle_item.setCheckState(Qt.CheckState.Unchecked)
        model.appendRow(toggle_item)

        if items is not None:
            for item in items:
                standard_item = QStandardItem(item)
                standard_item.setFlags(checkable_flags)
                standard_item.setCheckState(Qt.CheckState.Unchecked)
                model.appendRow(standard_item)

        def on_view_clicked(index):
            # Flip the check state of whichever checkable row was pressed.
            item = model.itemFromIndex(index)
            if item is not None and item.isCheckable():
                new_state = (
                    Qt.CheckState.Checked
                    if item.checkState() == Qt.CheckState.Unchecked
                    else Qt.CheckState.Unchecked
                )
                item.setCheckState(new_state)

        combo.view().pressed.connect(on_view_clicked)

        self._updating_checkstates = False

        def on_item_changed(item):
            # setCheckState below re-emits itemChanged; the flag stops recursion.
            if self._updating_checkstates:
                return
            self._updating_checkstates = True
            try:
                normal_items = [model.item(i) for i in range(2, model.rowCount())]  # skip dummy and toggle
                all_checked = all(i.checkState() == Qt.CheckState.Checked for i in normal_items)

                if item == toggle_item:
                    # The toggle clears everything when all rows are checked,
                    # otherwise it checks everything.
                    target = Qt.CheckState.Unchecked if all_checked else Qt.CheckState.Checked
                    for i in normal_items:
                        i.setCheckState(target)
                    toggle_item.setCheckState(target)
                elif item == dummy_item:
                    pass
                else:
                    # When normal items change, update toggle item
                    toggle_item.setCheckState(
                        Qt.CheckState.Checked if all_checked else Qt.CheckState.Unchecked
                    )
            finally:
                # Always release the guard so one failure cannot mute later changes.
                self._updating_checkstates = False

            for param_name, info in self.widgets.items():
                if info["widget"] == combo:
                    self.update_dropdown_label(param_name)
                    break

        model.itemChanged.connect(on_item_changed)

        return combo

    def update_dropdown_label(self, param_name):
        """
        Write the checked entries of a multi-select dropdown into its line edit.

        The text is the comma-separated list that get_param_values() parses for
        list-typed parameters; it is empty when nothing is checked so the
        placeholder text shows instead.

        Args:
            param_name (str): Name of the parameter whose dropdown should be refreshed.
        """
        info = self.widgets.get(param_name)
        if info is None:
            return
        combo = info["widget"]
        if not isinstance(combo, QComboBox) or combo.lineEdit() is None:
            return
        model = combo.model()
        if not isinstance(model, QStandardItemModel):
            return

        # Rows 0 and 1 are the placeholder and the select-all toggle.
        checked = [
            model.item(row).text()
            for row in range(2, model.rowCount())
            if model.item(row).checkState() == Qt.CheckState.Checked
        ]
        combo.lineEdit().setText(", ".join(checked))

    def show_help_popup(self, text):
        """Show the help text of a parameter in a modal message box."""
        msg = QMessageBox(self)
        msg.setWindowTitle("Parameter Info - SPARKS")
        msg.setText(text)
        msg.exec()

    def get_param_values(self):
        """
        Read the current value of every parameter widget.

        Returns:
            dict: Parameter name -> value converted to the declared type. bool
            parameters come from the combo box text, list parameters from the
            comma-separated line-edit text, int / float are parsed from text
            and every other type is returned as the raw text.

        Raises:
            ValueError: If an int or float field does not contain a valid number.
        """
        values = {}
        for name, info in self.widgets.items():
            widget = info["widget"]
            expected_type = info["type"]

            if expected_type is bool:
                values[name] = widget.currentText() == "True"
            elif expected_type is list:
                values[name] = [x.strip() for x in widget.lineEdit().text().split(",") if x.strip()]
            else:
                raw_text = widget.text()
                try:
                    if expected_type is int:
                        values[name] = int(raw_text)
                    elif expected_type is float:
                        values[name] = float(raw_text)
                    else:
                        values[name] = raw_text  # str and any unknown type
                except (TypeError, ValueError) as e:
                    raise ValueError(f"Invalid value for {name}: {raw_text}") from e

        return values
|
||
|
||
|
||
class FullClickLineEdit(QLineEdit):
    """Line edit that opens its parent combo box's popup when clicked."""

    def mousePressEvent(self, event):
        """Show the parent combo's popup (if the parent is a QComboBox), then handle the press normally."""
        owner = self.parent()
        if isinstance(owner, QComboBox):
            owner.showPopup()
        super().mousePressEvent(event)
|
||
|
||
|
||
class FullClickComboBox(QComboBox):
    """Combo box using a read-only FullClickLineEdit as its line edit."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        clickable_edit = FullClickLineEdit(self)
        self.setLineEdit(clickable_edit)
        # The text is display-only; read it back through lineEdit().
        self.lineEdit().setReadOnly(True)
|
||
|
||
|
||
class ParticipantProcessor(QThread):
    """
    Worker thread that runs MediaPipe hand tracking over one observation video
    and writes one CSV row per frame (timing, reach flags and 21 landmarks).

    Signals:
        progress_updated (int): Percentage of frames processed so far.
        frame_ready (QImage): Preview frame, emitted only when previews are enabled.
        finished_processing (str): The observation id on success, or a message
            starting with "Error:" when the video cannot be used. Nothing is
            emitted when the run is cancelled through stop().
        time_updated (str): "MM:SS/MM:SS" current / total video time.

    Args:
        obs_id: Key of the observation inside boris_json["observations"].
        boris_json (dict): Parsed BORIS project.
        selected_cam_id: Key of the camera inside the observation's "file" mapping.
        selected_hand_idx (int): Index of the hand to record (an index into
            initial_wrists when given, otherwise into MediaPipe's detection order).
        hide_preview (bool): When True no preview frames are emitted.
        output_csv (str): File name of the CSV to write.
        observations_root (str): Folder in which the video file is looked up.
        output_dir (str): Folder in which the CSV is written.
        initial_wrists (sequence, optional): One (x, y) wrist position per hand,
            used to keep hand indices stable between frames.
    """

    progress_updated = Signal(int)
    frame_ready = Signal(QImage)
    finished_processing = Signal(str)
    time_updated = Signal(str)

    def __init__(self, obs_id, boris_json, selected_cam_id, selected_hand_idx,
                 hide_preview=False, output_csv=None, observations_root=None, output_dir=None, initial_wrists=None):
        super().__init__()
        self.obs_id = obs_id
        self.boris_json = boris_json
        self.selected_cam_id = selected_cam_id
        self.selected_hand_idx = selected_hand_idx
        self.hide_preview = hide_preview
        self.output_csv = output_csv
        self.observations_root = observations_root
        self.output_dir = output_dir
        self.initial_wrists = initial_wrists

        # Mediapipe hands
        self.mp_hands = mp.solutions.hands.Hands(
            static_image_mode=False,
            max_num_hands=2,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
            model_complexity=0
        )
        # Cleared by stop(); checked once per frame in run().
        self.is_running = True

    @staticmethod
    def _build_segments(raw_events, extra_duration=1.0):
        """
        Turn BORIS events into (start, end) time segments.

        Events are indexed as [time, ?, name, ...]; entries shorter than three
        items are ignored. A trial runs from "Trial Start" to "End" plus
        extra_duration seconds. The reach-before-contact segment ends at
        "Contact", or at the padded trial end when the trial has no contact.

        Returns:
            tuple: (segments, reach_before_contact_segments), two lists of
            (start, end) tuples.
        """
        segments = []
        reach_before_contact_segments = []
        current_start = None
        current_contact = None

        for ev in raw_events:
            if len(ev) < 3:
                continue
            time_sec = ev[0]
            name = ev[2]
            if name == "Trial Start":
                current_start = time_sec
                current_contact = None
            elif current_start is None:
                # Contact / End outside of a trial carry no information.
                continue
            elif name == "Contact":
                current_contact = time_sec
                reach_before_contact_segments.append((current_start, current_contact))
            elif name == "End":
                adjusted_end = time_sec + extra_duration
                segments.append((current_start, adjusted_end))
                if current_contact is None:
                    reach_before_contact_segments.append((current_start, adjusted_end))
                current_start = None

        return segments, reach_before_contact_segments

    @staticmethod
    def _in_any_segment(current_time, segments, camera_offset):
        """Return True when current_time lies inside any segment shifted back by camera_offset."""
        return any(
            (start - camera_offset) <= current_time <= (end - camera_offset)
            for start, end in segments
        )

    def _select_hand(self, results):
        """
        Pick the landmarks of the hand identified by selected_hand_idx.

        With initial_wrists, each reference wrist (in order) claims the nearest
        detected hand that has not been claimed yet; without them MediaPipe's own
        ordering is used.

        Returns:
            The selected hand's landmark list, or None when it was not detected.
        """
        detected = results.multi_hand_landmarks
        if not detected or not results.multi_handedness:
            return None

        hand_mapping = {}
        if self.initial_wrists:
            used_indices = set()
            for i, iw in enumerate(self.initial_wrists):
                min_dist = float('inf')
                best_idx = None
                for j, lm in enumerate(detected):
                    if j in used_indices:
                        continue
                    wrist = lm.landmark[0]
                    # Squared distance is enough for choosing the nearest hand.
                    dist = (wrist.x - iw[0])**2 + (wrist.y - iw[1])**2
                    if dist < min_dist:
                        min_dist = dist
                        best_idx = j
                if best_idx is not None:
                    hand_mapping[i] = best_idx
                    used_indices.add(best_idx)
        else:
            hand_mapping = {i: i for i in range(len(detected))}

        mapped_idx = hand_mapping.get(self.selected_hand_idx)
        if mapped_idx is None:
            return None
        return detected[mapped_idx]

    def run(self):
        """
        Process the selected camera's video frame by frame.

        The capture and the MediaPipe graph are released on every exit path
        (success, error or cancellation) before finished_processing is emitted
        for a successful run.
        """
        cap = None
        completed = False
        try:
            observations = self.boris_json.get("observations", {})
            obs_data = observations[self.obs_id]

            # ---------------- Camera / Video ----------------
            cameras_with_files = obs_data.get("file", {})
            media_info = obs_data.get("media_info", {})
            offset_dict = media_info.get("offset", {})

            # Only the first media file of the camera is used, and it is looked
            # up by bare file name inside observations_root.
            relative_video_path = cameras_with_files[self.selected_cam_id][0]
            corrected_file_name = os.path.basename(relative_video_path)
            video_path = os.path.join(self.observations_root, corrected_file_name)

            if not os.path.exists(video_path):
                print(f"FATAL ERROR: Video file not found at: {video_path}")
                self.finished_processing.emit(f"Error: Video not found for {self.obs_id}")
                return

            camera_offset = float(offset_dict.get(self.selected_cam_id, 0.0))

            fps_dict = media_info.get("fps", {})
            boris_fps = float(list(fps_dict.values())[0]) if fps_dict else 30.0

            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                print(f"FATAL ERROR: Could not open video file: {video_path}")
                self.finished_processing.emit(f"Error: Could not open video for {self.obs_id}")
                return

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:
                # The container gave no usable frame rate: fall back to the
                # BORIS value (or 30) so the time maths below cannot divide by zero.
                fps = boris_fps if boris_fps > 0 else 30.0
            total_seconds = total_frames / fps
            total_time_str = self._format_seconds(total_seconds)

            # ---------------- Trial Segments ----------------
            segments, reach_before_contact_segments = self._build_segments(
                obs_data.get("events", []), extra_duration=1.0
            )

            # ---------------- CSV Header ----------------
            header = ["frame_index", "time_sec", "adjusted_time_sec", "reach_active", "reach_before_contact", "camera_offset"]
            for i in range(21):
                header += [f"lm{i}_x", f"lm{i}_y", f"lm{i}_z"]

            output_path = os.path.join(self.output_dir, self.output_csv)
            print(output_path)
            with open(output_path, "w", newline="") as csv_file:
                csv_writer = csv.writer(csv_file)
                csv_writer.writerow(header)

                # ---------------- Frame Processing ----------------
                frame_index = 0
                while True:
                    if not self.is_running:
                        print(f"Processor for {self.obs_id}/{self.selected_cam_id} was cancelled gracefully.")
                        return  # resources are released in the finally block
                    ret, frame = cap.read()
                    if not ret:
                        break

                    current_time = frame_index / fps
                    adjusted_time = current_time  # currently identical to time_sec

                    is_reach_active = self._in_any_segment(current_time, segments, camera_offset)
                    is_reach_before_contact = self._in_any_segment(
                        current_time, reach_before_contact_segments, camera_offset
                    )

                    row = [
                        frame_index, current_time, adjusted_time,
                        int(is_reach_active), int(is_reach_before_contact),
                        camera_offset
                    ]

                    # ---------------- MediaPipe Hands ----------------
                    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    results = self.mp_hands.process(rgb_frame)

                    hand_landmarks = self._select_hand(results)
                    if hand_landmarks is not None:
                        for lm in hand_landmarks.landmark:
                            row += [lm.x, lm.y, lm.z]
                    else:
                        # Keep the column count constant when the hand is missing.
                        row += [math.nan] * (21 * 3)

                    csv_writer.writerow(row)

                    # ---------------- Preview ----------------
                    if not self.hide_preview:
                        display_frame = frame.copy()
                        if hand_landmarks is not None:
                            mp.solutions.drawing_utils.draw_landmarks(
                                display_frame,
                                hand_landmarks,
                                mp.solutions.hands.HAND_CONNECTIONS
                            )
                        h, w, _ = display_frame.shape
                        # NOTE(review): display_frame is in OpenCV's BGR order but is
                        # tagged RGB888 here - confirm whether the receiving slot
                        # swaps the channels before changing the format.
                        qimg = QImage(display_frame.data, w, h, 3 * w, QImage.Format_RGB888)
                        # Emit a deep copy: the numpy buffer is replaced on the next
                        # iteration while the signal may still be queued.
                        self.frame_ready.emit(qimg.copy())

                    current_time_str = self._format_seconds(current_time)

                    # ---------------- Progress ----------------
                    progress = int((frame_index / total_frames) * 100) if total_frames > 0 else 0
                    self.progress_updated.emit(progress)
                    self.time_updated.emit(f"{current_time_str}/{total_time_str}")
                    frame_index += 1

            completed = True
        finally:
            if cap is not None:
                cap.release()
            self.mp_hands.close()

        if completed:
            self.finished_processing.emit(self.obs_id)

    def _format_seconds(self, seconds):
        """Format a duration in seconds as "MM:SS"; negative values give "00:00"."""
        if seconds < 0:
            return "00:00"
        minutes = math.floor(seconds / 60)
        secs = math.floor(seconds % 60)
        return f"{minutes:02d}:{secs:02d}"

    def stop(self):
        """Sets the flag to stop the thread's execution loop."""
        self.is_running = False
|
||
|
||
|
||
class MainApplication(QMainWindow):
|
||
"""
|
||
Main application window that creates and sets up the UI.
|
||
"""
|
||
|
||
progress_update_signal = Signal(str, int)
|
||
|
||
def __init__(self):
    """Create the main window, build its UI and start the local pending-update check."""
    super().__init__()
    self.setWindowTitle("SPARKS")
    self.setGeometry(100, 100, 1280, 720)

    # Secondary windows are created on demand.
    self.about = self.help = self.optodes = self.events = self.terminal = None

    # UI bookkeeping read by init_ui() / update_sections().
    self.bubble_widgets = {}
    self.param_sections = []
    self.folder_paths = []
    self.section_widget = None
    self.first_run = True

    self.selection_widgets = {}
    self.camera_state = {}  # {(obs_id, cam_id): state info}
    self.processing_widgets = {}

    # Batch progress tracking.
    self.files_total = 0  # total number of files to process
    self.files_done = set()  # set of file paths done (success or fail)
    self.files_failed = set()  # set of failed file paths
    self.files_results = {}  # dict for successful results (if needed)

    self.processing_threads = []  # List to hold all active ParticipantProcessor threads

    self.init_ui()
    self.create_menu_bar()

    # Update state.
    self.platform_suffix = "-" + PLATFORM_NAME
    self.pending_update_version = None
    self.pending_update_path = None
    self.last_clicked_bubble = None
    self.installEventFilter(self)

    self.file_metadata = {}
    self.current_file = None

    self.worker_thread = None
    self.progress_dialog = None

    # Mediapipe hands (still-image mode, unlike the per-video processor threads)
    self.mp_hands = mp.solutions.hands.Hands(
        static_image_mode=True,
        max_num_hands=2,
        min_detection_confidence=0.5
    )

    # Start local pending update check thread
    update_check = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix)
    self.local_check_thread = update_check
    update_check.pending_update_found.connect(self.on_pending_update_found)
    update_check.no_pending_update.connect(self.on_no_pending_update)
    update_check.start()
|
||
|
||
|
||
def init_ui(self):
    """
    Build the central widget.

    Left side: a "File information" group (read-only log view plus a hidden
    metadata column) above a scrollable grid of bubbles. Right side: a
    scrollable parameter panel above the Process / Clear buttons.
    """
    # ---- Root: horizontal split ----
    central = QWidget()
    self.setCentralWidget(central)
    main_layout = QHBoxLayout(central)

    # ---- Left half ----
    left_container = QWidget()
    left_layout = QVBoxLayout(left_container)
    left_container.setMinimumWidth(300)

    top_left_container = QGroupBox("File information")
    top_left_container.setStyleSheet("QGroupBox { font-weight: bold; }")  # Style if needed
    top_left_container.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Expanding)
    top_left_layout = QHBoxLayout(top_left_container)

    # Read-only log view, stretch 4 of 5.
    self.top_left_widget = QTextEdit()
    self.top_left_widget.setReadOnly(True)
    self.top_left_widget.setPlaceholderText("Logging information will be available here.")
    top_left_layout.addWidget(self.top_left_widget, stretch=4)

    # Metadata column, stretch 1 of 5; hidden until needed.
    self.right_column_widget = QWidget()
    right_column_layout = QVBoxLayout(self.right_column_widget)

    self.meta_fields = {key: QLineEdit() for key in ("HAND", "GENDER", "GROUP")}
    for key, field in self.meta_fields.items():
        field.setPlaceholderText(f"Enter {key}")
        right_column_layout.addWidget(QLabel(key.capitalize()))
        right_column_layout.addWidget(field)

    # The link opens an explanatory popup instead of a URL.
    label_desc = QLabel('<a href="#">Why are these useful?</a>')
    label_desc.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
    label_desc.setOpenExternalLinks(False)

    def show_info_popup():
        QMessageBox.information(None, "Parameter Info - SPARKS",
                                "Age: Used to calculate the DPF factor.\nGender: Not currently used. "
                                "Will be able to sort into groups by gender in the near future.\nGroup: Allows contrast "
                                "images to be created comparing one group to another once the processing has completed.")

    label_desc.linkActivated.connect(show_info_popup)
    right_column_layout.addWidget(label_desc)
    right_column_layout.addStretch()  # Push fields to top

    self.right_column_widget.hide()
    top_left_layout.addWidget(self.right_column_widget, stretch=1)

    left_layout.addWidget(top_left_container, stretch=2)

    # Scrollable bubble grid below the file-information group.
    self.bubble_container = QWidget()
    self.bubble_layout = QGridLayout()
    self.bubble_layout.setAlignment(Qt.AlignmentFlag.AlignTop)
    self.bubble_container.setLayout(self.bubble_layout)

    self.scroll_area = QScrollArea()
    self.scroll_area.setWidgetResizable(True)
    self.scroll_area.setWidget(self.bubble_container)
    self.scroll_area.setMinimumHeight(300)
    left_layout.addWidget(self.scroll_area, stretch=8)

    # ---- Right half ----
    self.right_container = QWidget()
    right_container_layout = QVBoxLayout(self.right_container)

    # Content shown inside the right scroll area.
    self.right_content_widget = QWidget()
    right_content_layout = QVBoxLayout(self.right_content_widget)

    self.option_selector = QComboBox()
    self.option_selector.addItem("N/A")
    right_content_layout.addWidget(self.option_selector)

    # Holder for the ParamSection widgets created by update_sections().
    self.rows_container = QWidget()
    self.rows_layout = QVBoxLayout()
    self.rows_layout.setSpacing(10)
    self.rows_container.setLayout(self.rows_layout)
    right_content_layout.addWidget(self.rows_container)

    # A trailing stretch keeps the content at the top.
    right_content_layout.addStretch()

    self.right_scroll_area = QScrollArea()
    self.right_scroll_area.setWidgetResizable(True)
    self.right_scroll_area.setWidget(self.right_content_widget)

    # Button row fixed below the scroll area, right-aligned by a leading stretch.
    buttons_widget = QWidget()
    buttons_layout = QHBoxLayout(buttons_widget)
    buttons_layout.addStretch()

    self.button1 = QPushButton("Process")
    self.button2 = QPushButton("Clear")
    for button in (self.button1, self.button2):
        buttons_layout.addWidget(button)
        button.setMinimumSize(100, 40)

    self.button1.hide()

    self.button1.clicked.connect(self.on_run_task)
    self.button2.clicked.connect(self.clear_all)

    right_container_layout.addWidget(self.right_scroll_area)
    right_container_layout.addWidget(buttons_widget)

    # ---- Assemble: 70 / 30 split ----
    main_layout.addWidget(left_container, stretch=70)
    main_layout.addWidget(self.right_container, stretch=30)

    expanding = QSizePolicy.Policy.Expanding
    self.right_container.setSizePolicy(expanding, expanding)
    self.right_scroll_area.setSizePolicy(expanding, expanding)

    # Rebuild the parameter sections whenever the selected option changes.
    self.option_selector.currentIndexChanged.connect(self.update_sections)

    # Initial build
    self.update_sections(0)
|
||
|
||
|
||
def create_menu_bar(self):
    '''
    Build the menu bar (File, Edit, Model, View, Options, Terminal) and the status bar.

    Every keyboard shortcut is unique across the menus; Qt treats a shortcut
    shared by two actions as ambiguous and triggers neither of them.
    '''

    menu_bar = self.menuBar()

    def make_action(name, shortcut=None, slot=None, checkable=False, checked=False, icon=None):
        # Create a QAction owned by this window with the optional extras applied.
        action = QAction(name, self)

        if shortcut:
            action.setShortcut(QKeySequence(shortcut))
        if slot:
            action.triggered.connect(slot)
        if checkable:
            action.setCheckable(True)
            action.setChecked(checked)
        if icon:
            action.setIcon(QIcon(icon))
        return action

    def populate(menu, specs, separators_after=()):
        # Add one action per (name, shortcut, slot, icon) spec, inserting a
        # separator after each listed position.
        for position, (name, shortcut, slot, icon) in enumerate(specs):
            menu.addAction(make_action(name, shortcut, slot, icon=icon))
            if position in separators_after:
                menu.addSeparator()

    # File menu and actions
    file_menu = menu_bar.addMenu("File")
    file_actions = [
        ("Open BORIS file...", "Ctrl+O", self.open_file_dialog, resource_path("icons/file_open_24dp_1F1F1F.svg")),
    ]
    populate(file_menu, file_actions)

    file_menu.addSeparator()
    file_menu.addAction(make_action("Exit", "Ctrl+Q", QApplication.instance().quit, icon=resource_path("icons/exit_to_app_24dp_1F1F1F.svg")))

    # Edit menu
    edit_menu = menu_bar.addMenu("Edit")
    edit_actions = [
        ("Cut", "Ctrl+X", self.cut_text, resource_path("icons/content_cut_24dp_1F1F1F.svg")),
        ("Copy", "Ctrl+C", self.copy_text, resource_path("icons/content_copy_24dp_1F1F1F.svg")),
        ("Paste", "Ctrl+V", self.paste_text, resource_path("icons/content_paste_24dp_1F1F1F.svg"))
    ]
    populate(edit_menu, edit_actions)

    # Model menu: training actions, a separator, then testing actions.
    # The testing shortcuts were changed so none collides with "Ctrl+O" (File)
    # or with each other.
    model_menu = menu_bar.addMenu("Model")
    model_actions = [
        ("Create model from a CSV", "Ctrl+U", self.train_model_csv, resource_path("icons/content_cut_24dp_1F1F1F.svg")),
        ("Create model from a folder", "Ctrl+I", self.train_model_folder, resource_path("icons/content_copy_24dp_1F1F1F.svg")),
        ("Test model on a video", "Ctrl+T", self.test_model_video, resource_path("icons/content_paste_24dp_1F1F1F.svg")),
        ("Test model on a folder", "Ctrl+P", self.test_model_folder, resource_path("icons/content_paste_24dp_1F1F1F.svg")),
        ("Test model on a CSV", "Ctrl+Shift+P", self.test_model_csv, resource_path("icons/content_paste_24dp_1F1F1F.svg"))
    ]
    populate(model_menu, model_actions, separators_after=(1,))

    # View menu
    view_menu = menu_bar.addMenu("View")
    # triggered(bool) carries the new checked state, which maps directly onto
    # the status bar's visibility.
    toggle_statusbar_action = make_action(
        "Toggle Status Bar",
        checkable=True,
        checked=True,
        slot=lambda checked: self.statusBar().setVisible(checked),
    )
    view_menu.addAction(toggle_statusbar_action)

    # Options menu (Help & About)
    options_menu = menu_bar.addMenu("Options")
    options_actions = [
        ("User Guide", "F1", self.user_guide, resource_path("icons/help_24dp_1F1F1F.svg")),
        ("Check for Updates", "F5", self.manual_check_for_updates, resource_path("icons/update_24dp_1F1F1F.svg")),
        ("About", "F12", self.about_window, resource_path("icons/info_24dp_1F1F1F.svg"))
    ]
    populate(options_menu, options_actions, separators_after=(1,))

    terminal_menu = menu_bar.addMenu("Terminal")
    terminal_actions = [
        ("New Terminal", "Ctrl+Alt+T", self.terminal_gui, resource_path("icons/terminal_24dp_1F1F1F.svg")),
    ]
    populate(terminal_menu, terminal_actions)

    self.statusbar = self.statusBar()
    self.statusbar.showMessage("Ready")
|
||
|
||
|
||
def update_sections(self, index):
    """
    Rebuild the parameter panel from SECTIONS.

    Args:
        index (int): Index delivered by the option selector; it is not used,
            every call rebuilds the same sections.
    """
    # Schedule the widgets of the previous build for deletion, last to first.
    for position in range(self.rows_layout.count() - 1, -1, -1):
        stale = self.rows_layout.itemAt(position).widget()
        if stale is not None:
            stale.deleteLater()
    self.param_sections.clear()

    # One ParamSection per entry; self.section_widget ends up as the last one.
    for section in SECTIONS:
        section_widget = ParamSection(section)
        self.section_widget = section_widget
        self.rows_layout.addWidget(section_widget)
        self.param_sections.append(section_widget)
|
||
|
||
|
||
def clear_bubble_widgets(self):
    """Empty self.bubble_layout, scheduling every contained widget for deletion."""
    grid = self.bubble_layout
    while grid.count():
        # takeAt() detaches the first remaining item from the layout.
        taken = grid.takeAt(0)

        child = taken.widget()
        if child:
            child.deleteLater()
            continue

        # Non-widget entries: spacer items are passed to removeItem as well.
        if taken.spacerItem():
            grid.removeItem(taken)
|
||
|
||
def clear_all(self):
    """Cancel any running task and reset the window to its empty state."""
    self.cancel_task()

    self.right_column_widget.hide()

    # Detach every item from the bubble grid and delete its widget, if any.
    grid = self.bubble_layout
    while grid.count():
        bubble = grid.takeAt(0).widget()
        if bubble:
            bubble.deleteLater()

    # Clear file data
    self.bubble_widgets.clear()
    self.statusBar().clearMessage()

    # Per-run result holders are all reset to None.
    for attribute in (
        "raw_haemo_dict",
        "epochs_dict",
        "fig_bytes_dict",
        "cha_dict",
        "contrast_results_dict",
        "df_ind_dict",
        "design_matrix_dict",
        "age_dict",
        "gender_dict",
        "group_dict",
        "valid_dict",
    ):
        setattr(self, attribute, None)

    # Reset any visible UI elements
    self.button1.setVisible(False)
    self.top_left_widget.clear()
|
||
|
||
|
||
def get_output_directory(self):
    """
    Ask for a parent folder and create a timestamped results subfolder in it.

    Returns:
        str or None: Path of the "sparks_YYYYMMDD_HHMMSS" folder, or None when
        the dialog is cancelled or the folder cannot be created.
    """
    # The dialog opens in the current working directory.
    chosen_parent = QFileDialog.getExistingDirectory(
        self,
        "Select Parent Folder for Results",
        os.getcwd()
    )
    if not chosen_parent:
        return None  # User cancelled the dialog

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(chosen_parent, f"sparks_{stamp}")

    try:
        os.makedirs(output_path, exist_ok=True)
        print(f"Created output directory: {output_path}")
    except Exception as e:
        QMessageBox.critical(self, "Error", f"Failed to create output directory: {e}")
        return None
    else:
        return output_path
|
||
|
||
def copy_text(self):
    """Copy the log view's selection to the clipboard and report it in the status bar."""
    log_view = self.top_left_widget
    log_view.copy()
    self.statusbar.showMessage("Copied to clipboard")
|
||
|
||
def cut_text(self):
    """Invoke cut on the log view and report it in the status bar."""
    log_view = self.top_left_widget
    log_view.cut()
    self.statusbar.showMessage("Cut to clipboard")
|
||
|
||
def paste_text(self):
    """Invoke paste on the log view and report it in the status bar."""
    log_view = self.top_left_widget
    log_view.paste()
    self.statusbar.showMessage("Pasted from clipboard")
|
||
|
||
|
||
def get_training_files(self):
    """
    Let the user pick one or more training CSV files.

    Returns:
        list[str]: The selected paths; empty when the dialog is cancelled.
    """
    start_dir = os.getcwd()
    selection = QFileDialog.getOpenFileNames(
        self,
        "Select Training CSV Files",
        start_dir,
        "CSV Files (*.csv)"
    )
    # getOpenFileNames returns (paths, selected_filter); only the paths are needed.
    return selection[0]
|
||
|
||
def get_testing_files(self):
    """
    Ask for a trained model file and then a video file.

    Returns:
        tuple: (model_path, video_path). (None, None) when the model dialog is
        cancelled; video_path is an empty string when only the video dialog is
        cancelled.
    """
    start_dir = os.getcwd()

    # Step 1: the trained model (.pth)
    model_path, _ = QFileDialog.getOpenFileName(
        self,
        "Select Trained Model File (.pth)",
        start_dir,
        "Model Files (*.pth)"
    )
    if not model_path:
        return None, None

    # Step 2: the video to test on
    video_path, _ = QFileDialog.getOpenFileName(
        self,
        "Select Video File for Testing",
        os.getcwd(),
        "Video Files (*.mp4 *.avi *.mov)"
    )
    return model_path, video_path
|
||
|
||
|
||
def get_testing_files2(self):
    """
    Opens dialogs to select the model file and the csv file.

    Returns:
        tuple: (model_path, csv_path). (None, None) when the model dialog is
        cancelled; csv_path is an empty string when only the CSV dialog is
        cancelled.
    """

    # 1. Select Model File (.pth)
    model_path, _ = QFileDialog.getOpenFileName(
        self,
        "Select Trained Model File (.pth)",
        os.getcwd(),
        "Model Files (*.pth)"
    )
    if not model_path:
        return None, None

    # 2. Select CSV File (the dialog title and filter label now name CSV
    # files; they previously still referred to a video file)
    csv_path, _ = QFileDialog.getOpenFileName(
        self,
        "Select CSV File for Testing",
        os.getcwd(),
        "Comma Separated Files (*.csv)"
    )

    return model_path, csv_path
|
||
|
||
def train_model_csv(self):
    """
    Initiates the model training process, including selecting CSVs and save path.

    Asks for the training CSV files and for the location of the resulting
    model file, then starts a TrainModelThread. Nothing is started when a
    dialog is cancelled or when a previous training run is still active.
    """

    # Replacing self.train_thread while the old thread is still running would
    # drop the only reference to a running thread, so refuse to start a second run.
    # NOTE(review): assumes TrainModelThread is a QThread (isRunning) - confirm.
    active_thread = getattr(self, "train_thread", None)
    if active_thread is not None and active_thread.isRunning():
        self.statusBar().showMessage("Training is already running. Please wait for it to finish.")
        return

    # 1. Get CSV paths
    csv_paths = self.get_training_files()
    if not csv_paths:
        self.statusBar().showMessage("Training cancelled. No CSV files selected.")
        return

    # 2. Get Model Save Path
    default_filename = "best_reach_lstm.pth"

    # The third argument can include a default filename
    model_save_path, _ = QFileDialog.getSaveFileName(
        self,
        "Save Trained Model File",
        os.path.join(os.getcwd(), default_filename),  # Start in CWD with a default name
        "PyTorch Model Files (*.pth);;All Files (*)"
    )

    if not model_save_path:
        self.statusBar().showMessage("Training cancelled. Model save location not specified.")
        return

    # Ensure correct extension if the user didn't type one
    if not model_save_path.lower().endswith('.pth'):
        model_save_path += '.pth'

    self.statusBar().showMessage(f"Starting model training on {len(csv_paths)} files. Model will save to: {os.path.basename(model_save_path)}")

    # 3. Start the thread with the chosen path
    self.train_thread = TrainModelThread(
        csv_paths=csv_paths,
        model_save_path=model_save_path
    )

    # Progress messages go to the status bar; completion shows a popup.
    self.train_thread.update.connect(self.log_training_message)
    self.train_thread.finished.connect(self.on_train_finished)
    self.train_thread.start()
|
||
|
||
|
||
def log_training_message(self, message):
    """Show a training progress message in the status bar.

    Args:
        message: Text to display; it is prefixed with ``"Training: "``.
    """
    status_text = f"Training: {message}"
    self.statusBar().showMessage(status_text)
|
||
|
||
def on_train_finished(self, final_message):
    """Report the end of training and drop the reference to the thread.

    Args:
        final_message: Summary text shown in the status bar and in an
            information pop-up.
    """
    status_bar = self.statusBar()
    status_bar.showMessage(final_message)
    QMessageBox.information(self, "Training Complete", final_message)

    # Release the finished thread object.
    self.train_thread = None
|
||
|
||
def train_model_folder(self):
    """Placeholder for folder-based training; currently does nothing."""
    return None
|
||
|
||
def test_model_video(self):
|
||
model_path, video_path = self.get_testing_files()
|
||
|
||
if not model_path or not video_path:
|
||
self.statusBar().showMessage("Testing cancelled. Model or Video not selected.")
|
||
return
|
||
|
||
self.statusBar().showMessage(f"Starting model testing on {os.path.basename(video_path)}...")
|
||
|
||
# 1. Create the thread
|
||
# Note: TestModelThread expects a list of video paths, so we pass [video_path]
|
||
self.test_thread = TestModelThread(video_paths=[video_path], model_path=model_path)
|
||
|
||
# 2. Connect signals
|
||
# The update_frame signal is not used to update the main GUI here,
|
||
# as the thread is responsible for showing the cv2.imshow window.
|
||
self.test_thread.finished.connect(self.on_test_finished)
|
||
|
||
# 3. Start the thread
|
||
self.test_thread.start()
|
||
|
||
|
||
|
||
def test_model_csv(self):
|
||
model_path, csv_path = self.get_testing_files2()
|
||
|
||
if not model_path or not csv_path:
|
||
self.statusBar().showMessage("Testing cancelled. Model or Video not selected.")
|
||
return
|
||
|
||
self.statusBar().showMessage(f"Starting model testing on {os.path.basename(csv_path)}...")
|
||
|
||
# 1. Create the thread
|
||
# Note: TestModelThread expects a list of video paths, so we pass [video_path]
|
||
self.test_thread = TestModelThread2(csv_path=csv_path, model_path=model_path)
|
||
|
||
# 2. Connect signals
|
||
# The update_frame signal is not used to update the main GUI here,
|
||
# as the thread is responsible for showing the cv2.imshow window.
|
||
self.test_thread.finished.connect(self.on_csv_analysis_finished)
|
||
|
||
# 3. Start the thread
|
||
self.test_thread.start()
|
||
|
||
|
||
def on_csv_analysis_finished(self, result):
    """Plot the two probability traces produced by the CSV analysis.

    Args:
        result: Mapping that either contains an ``"error"`` entry (shown in
            a critical message box) or the entries ``"time"``, ``"active"``
            and ``"before"`` that are plotted as two stacked line charts.
    """
    if "error" in result:
        QMessageBox.critical(self, "Error", result["error"])
        return

    timestamps = result["time"]
    active_prob = result["active"]
    before_prob = result["before"]

    import matplotlib.pyplot as plt

    fig, (ax_active, ax_before) = plt.subplots(2, 1, figsize=(14, 6), sharex=True)

    # Both panels share the same formatting; only data, colour and title differ.
    panels = (
        (ax_active, active_prob, "green", "Reach Active"),
        (ax_before, before_prob, "orange", "Reach Before Contact"),
    )
    for axis, series, colour, title in panels:
        axis.plot(timestamps, series, color=colour, label=title)
        axis.set_ylabel("Probability")
        axis.set_title(title)
        axis.set_ylim(0, 1)
        axis.grid(True)

    # The x-axis is shared, so only the bottom panel carries the label.
    ax_before.set_xlabel("Time (s)")

    plt.tight_layout()
    plt.show()

    self.statusBar().showMessage("Complete.")
|
||
|
||
|
||
|
||
def on_test_finished(self):
    """Handle cleanup after the model-testing thread has finished.

    Posts a completion message, makes a best-effort attempt to close any
    OpenCV windows, and drops the reference to the finished thread.
    """
    self.statusBar().showMessage("Model testing complete.")

    # Closing the preview windows is best-effort: a failure here must not
    # prevent the cleanup below. ``Exception`` (rather than a bare except)
    # keeps KeyboardInterrupt / SystemExit propagating.
    try:
        cv2.destroyAllWindows()
    except Exception:
        pass

    # Release the finished thread object.
    self.test_thread = None
|
||
|
||
def test_model_folder(self):
|
||
return
|
||
|
||
|
||
|
||
|
||
def about_window(self):
    """Open the About window unless one is already visible."""
    current = self.about
    if current is not None and current.isVisible():
        return
    self.about = AboutWindow(self)
    self.about.show()
|
||
|
||
def user_guide(self):
    """Open the user-guide window unless one is already visible."""
    current = self.help
    if current is not None and current.isVisible():
        return
    self.help = UserGuideWindow(self)
    self.help.show()
|
||
|
||
def terminal_gui(self):
    """Open the terminal window unless one is already visible."""
    current = self.terminal
    if current is not None and current.isVisible():
        return
    self.terminal = TerminalWindow(self)
    self.terminal.show()
|
||
|
||
def update_optode_positions(self):
    """Disabled menu action; intentionally does nothing.

    The previous body returned unconditionally before code that opened an
    ``UpdateOptodesWindow`` stored on ``self.optodes``; that unreachable
    code has been removed.
    """
    # TODO: re-enable by showing UpdateOptodesWindow(self) once supported.
    return None
|
||
|
||
|
||
def update_event_markers(self):
    """Disabled menu action; intentionally does nothing.

    The previous body returned unconditionally before code that opened an
    ``UpdateEventsWindow`` stored on ``self.events``; that unreachable code
    has been removed.
    """
    # TODO: re-enable by showing UpdateEventsWindow(self) once supported.
    return None
|
||
|
||
def open_file_dialog(self):
    """Let the user pick a BORIS file and load it in a background worker.

    Steps: choose the file, resolve the matching observations folder, show
    a progress dialog, then start a ``FileLoadWorker`` whose signals update
    the dialog and hand the loaded data back to this window.
    """
    file_path, _ = QFileDialog.getOpenFileName(
        self, "Open File", "", "BORIS Files (*.boris);;All Files (*)"
    )
    if not file_path:
        return

    self.observations_root = self.resolve_path_to_observations(file_path)
    if not self.observations_root:
        self.statusBar().showMessage("File loading cancelled by user.")
        return  # User cancelled the folder selection

    # Show the progress dialog.
    self.progress_dialog = ProgressDialog(self)
    # The file defines ``cancel_task`` more than once and the last definition
    # (which stops the processing threads) is the one ``self.cancel_task``
    # resolves to, so the dialog gets its own dedicated slot.
    self.progress_dialog.cancel_button.clicked.connect(self._cancel_file_loading)
    self.progress_dialog.show()

    # Create the worker thread.
    self.worker_thread = FileLoadWorker(
        file_path,
        self.observations_root,
        self.extract_frame_and_hands,
    )

    # Route worker signals to main-thread slots.
    self.worker_thread.progress_update.connect(self.progress_dialog.update_progress)
    self.worker_thread.observations_loaded.connect(self.on_files_loaded)
    self.worker_thread.loading_finished.connect(self.on_loading_finished)
    self.worker_thread.loading_failed.connect(self.on_loading_failed)

    self.worker_thread.start()


def _cancel_file_loading(self):
    """Stop the file-load worker and close the progress dialog."""
    if self.worker_thread and self.worker_thread.isRunning():
        self.worker_thread.stop()  # Ask the worker to stop processing
        self.worker_thread.wait()  # Block until it has exited
    if self.progress_dialog:
        self.progress_dialog.reject()
    self.statusBar().showMessage("File loading cancelled.")
|
||
|
||
|
||
|
||
def resolve_path_to_observations(self, boris_file_path):
    """Locate the 'Observations' directory that belongs to a BORIS file.

    Args:
        boris_file_path: Path of the selected BORIS file.

    Returns:
        The ``Observations`` folder next to the BORIS file when it exists.
        Otherwise the user is asked to pick a folder: the result is that
        folder if it is named ``Observations``, its ``Observations``
        sub-folder if one exists, or the picked folder itself as a fallback.
        ``None`` when the folder dialog is cancelled.
    """
    project_dir = os.path.dirname(boris_file_path)

    # Fast path: the folder sits right beside the BORIS file.
    sibling = os.path.join(project_dir, "Observations")
    if os.path.isdir(sibling):
        return sibling

    QMessageBox.warning(
        self,
        "Observations Folder Missing",
        "The 'Observations' folder (containing the videos) was not found next to the BORIS file. Please select the correct root folder."
    )

    # Start browsing from the BORIS file's own directory.
    chosen = QFileDialog.getExistingDirectory(
        self,
        "Select Observations Root Folder",
        project_dir,
    )
    if not chosen:
        return None

    if os.path.basename(chosen) == "Observations":
        return chosen

    # The user may have picked the parent of 'Observations'; otherwise fall
    # back to the picked folder and rely on it holding the relative paths.
    nested = os.path.join(chosen, "Observations")
    return nested if os.path.isdir(nested) else chosen
|
||
|
||
|
||
|
||
def on_files_loaded(self, boris, observations, root_path):
    """Store the loaded project data and build the preview grid.

    Args:
        boris: Loaded BORIS data, stored on ``self.boris``.
        observations: Observation mapping, stored on ``self.observations``.
        root_path: Observations root folder, stored on
            ``self.observations_root``.
    """
    self.boris, self.observations, self.observations_root = boris, observations, root_path

    # The worker keeps the preview data it gathered while loading.
    worker = self.worker_thread
    self.build_preview_grid(worker.previews_data)

    self.statusBar().showMessage(f"{worker.file_path} loaded.")
    self.button1.setVisible(True)
|
||
|
||
def on_loading_finished(self):
    """Close the progress dialog (accepted) once loading has finished."""
    dialog = self.progress_dialog
    if not dialog:
        return
    dialog.accept()
|
||
|
||
def on_loading_failed(self, error_message):
    """Close the progress dialog and report a loading failure.

    Args:
        error_message: Text shown in the critical message box.
    """
    dialog = self.progress_dialog
    if dialog:
        dialog.reject()

    QMessageBox.critical(self, "Error Loading Files", error_message)
    self.statusBar().showMessage("File loading failed.")
|
||
|
||
def cancel_task(self):
    """Stop the file-load worker and close the progress dialog.

    NOTE(review): this file defines ``cancel_task`` again further down in
    the same scope; a later definition replaces this one at runtime.
    """
    worker = self.worker_thread
    if worker and worker.isRunning():
        worker.stop()   # request the worker to stop
        worker.wait()   # block until it has exited

    dialog = self.progress_dialog
    if dialog:
        dialog.reject()

    self.statusBar().showMessage("File loading cancelled.")
|
||
|
||
# Add a dummy implementation for the processing function if it's not shown:
|
||
# def extract_frame_and_hands(self, video_path, frame_idx):
|
||
# # Placeholder implementation - replace with your actual function
|
||
# return None, None
|
||
|
||
def build_preview_grid(self, previews_data):
    """Build one preview group per observation with a row for each camera.

    For every observation in ``self.observations`` a group box is created
    with a participant-level dropdown. For each camera file that has an
    entry in ``previews_data`` and a decoded frame, a row is added holding
    the preview image (with hand overlay), a hand-selection dropdown and a
    "skip 1s" button. Widgets and per-camera state are recorded in
    ``self.selection_widgets`` and ``self.camera_state``.

    Args:
        previews_data: Mapping keyed by ``(obs_id, cam_id)``. Each value is
            read for the keys ``"frame_rgb"``, ``"results"``,
            ``"video_path"``, ``"fps"`` and ``"initial_wrists"``.
    """
    for obs_id, obs_data in self.observations.items():
        group = QGroupBox(f"Participant / Observation: {obs_id}")
        grouplayout = QVBoxLayout(group)

        # Participant-level dropdown; the default is to skip.
        participant_dropdown = QComboBox()
        participant_dropdown.addItem("Skip this participant", 0)
        participant_dropdown.addItem("Process this participant", 1)
        participant_dropdown.setCurrentIndex(0)
        self.selection_widgets.setdefault("participant_selection", {})[obs_id] = participant_dropdown
        grouplayout.addWidget(QLabel("Participant Option:"))
        grouplayout.addWidget(participant_dropdown)

        files_dict = obs_data.get("file", {})
        cam_dropdowns = self.selection_widgets.setdefault(obs_id, {})

        for cam_id in files_dict:
            state_key = (obs_id, cam_id)
            # Skip cameras the worker produced no preview entry for.
            if state_key not in previews_data:
                continue

            state_data = previews_data[state_key]
            frame_rgb = state_data["frame_rgb"]
            results = state_data["results"]

            # Without a decoded frame there is nothing to show for this camera.
            if frame_rgb is None:
                continue
            display_img = self.draw_hand_overlay(frame_rgb, results)

            # Convert the RGB array to a scaled pixmap (3 bytes per pixel).
            h, w, _ = display_img.shape
            qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888)
            pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio)

            # -------- UI Row --------
            row = QWidget()
            row_layout = QHBoxLayout(row)

            preview_label = QLabel()
            preview_label.setPixmap(pix)
            row_layout.addWidget(preview_label)

            # Hand-selection dropdown: one entry per detected hand.
            dropdown = QComboBox()
            dropdown.addItem("Skip this camera", -1)
            if results and results.multi_hand_landmarks:
                for idx in range(len(results.multi_hand_landmarks)):
                    dropdown.addItem(f"Use Hand {idx}", idx)
            row_layout.addWidget(dropdown)

            cam_dropdowns[cam_id] = dropdown

            # State used later when rescanning or processing this camera.
            self.camera_state[state_key] = {
                "path": state_data["video_path"],
                "frame_idx": 0,
                "fps": state_data["fps"],
                "preview_label": preview_label,
                "dropdown": dropdown,
                "initial_wrists": state_data["initial_wrists"],
            }

            # Bind the ids as defaults so each button keeps its own pair.
            skip_btn = QPushButton("Skip 1s → Rescan")
            skip_btn.clicked.connect(lambda _, o=obs_id, c=cam_id: self.skip_and_rescan(o, c))
            row_layout.addWidget(skip_btn)

            grouplayout.addWidget(row)

        self.bubble_layout.addWidget(group)
|
||
|
||
#self.bubble_layout.addStretch()
|
||
|
||
|
||
|
||
|
||
# ============================================
|
||
# FRAME EXTRACTION + MEDIA PIPE DETECTION
|
||
# ============================================
|
||
def extract_frame_and_hands(self, video_path, frame_idx):
    """Read one video frame and run hand detection on it.

    Args:
        video_path: Path of the video file to open.
        frame_idx: Index of the frame to seek to before reading.

    Returns:
        ``(rgb, results)`` where ``rgb`` is the frame converted from BGR to
        RGB and ``results`` is the value returned by
        ``self.mp_hands.process``; ``(None, None)`` when no frame could be
        read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
    finally:
        # Release the capture even if seeking or reading raises.
        cap.release()

    if not ret:
        return None, None

    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    results = self.mp_hands.process(rgb)
    return rgb, results
|
||
|
||
|
||
# ============================================
|
||
# DRAW OVERLAYED HANDS + BIG LABELS
|
||
# ============================================
|
||
def draw_hand_overlay(self, img, results, initial_wrists=None):
    """Return a copy of ``img`` with detected hands and index labels drawn.

    Args:
        img: Image array of shape ``(h, w, channels)``; it is not modified.
        results: Detection result exposing ``multi_hand_landmarks``; may be
            falsy, in which case an unannotated copy is returned.
        initial_wrists: Optional sequence of wrist positions, read as
            ``(x, y)`` via indices 0 and 1. When given, each one is matched
            greedily to the nearest not-yet-used detected wrist so labels
            follow the initial numbering.

    Returns:
        The annotated copy of ``img``.
    """
    canvas = img.copy()
    height, width, _ = canvas.shape

    palette = [
        (0, 255, 0),    # Hand 0 green
        (255, 0, 255),  # Hand 1 magenta
        (0, 255, 255),  # Hand 2 cyan
    ]

    if not (results and results.multi_hand_landmarks):
        return canvas

    detected = results.multi_hand_landmarks

    if initial_wrists:
        # initial index -> index into the current detections
        label_to_detection = {}
        taken = set()
        for label, ref in enumerate(initial_wrists):
            closest = None
            closest_dist = float('inf')
            for det_idx, hand in enumerate(detected):
                if det_idx in taken:
                    continue  # already assigned to an earlier label
                w0 = hand.landmark[0]
                sq_dist = (w0.x - ref[0])**2 + (w0.y - ref[1])**2
                if sq_dist < closest_dist:
                    closest_dist = sq_dist
                    closest = det_idx
            if closest is not None:
                label_to_detection[label] = closest
                taken.add(closest)
    else:
        label_to_detection = {k: k for k in range(len(detected))}

    font = cv2.FONT_HERSHEY_SIMPLEX
    for label, det_idx in label_to_detection.items():
        hand = detected[det_idx]
        mp.solutions.drawing_utils.draw_landmarks(
            canvas,
            hand,
            mp.solutions.hands.HAND_CONNECTIONS
        )
        wrist = hand.landmark[0]
        anchor = (int(wrist.x * width), int(wrist.y * height) - 40)
        text = str(label)

        # Thick black pass first, then the coloured label on top of it.
        cv2.putText(canvas, text, anchor, font, 3.2, (0, 0, 0), 10, cv2.LINE_AA)
        cv2.putText(canvas, text, anchor, font, 3.2,
                    palette[label % len(palette)], 6, cv2.LINE_AA)

    return canvas
|
||
|
||
|
||
# ============================================
|
||
# SKIP 1s AND RESCAN FRAME
|
||
# ============================================
|
||
def skip_and_rescan(self, obs_id, cam_id):
    """Advance a camera preview by about one second and re-detect hands.

    The frame index moves forward by ``int(fps)`` frames. It is only
    committed to the camera state when the new frame can actually be read,
    so the stored index always matches the preview that is displayed.

    Args:
        obs_id: Observation identifier of the camera row.
        cam_id: Camera identifier of the camera row.
    """
    state = self.camera_state[(obs_id, cam_id)]

    next_idx = state["frame_idx"] + int(state["fps"])
    rgb, results = self.extract_frame_and_hands(state["path"], next_idx)
    if rgb is None:
        # No frame there (e.g. past the end of the video): keep the old state.
        return
    state["frame_idx"] = next_idx

    display_img = self.draw_hand_overlay(rgb, results, state.get("initial_wrists"))
    h, w, _ = display_img.shape
    qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888)
    pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio)

    # Update preview
    state["preview_label"].setPixmap(pix)

    # Rebuild the dropdown from the hands found in the new frame
    dropdown = state["dropdown"]
    dropdown.clear()
    dropdown.addItem("Skip this camera", -1)
    if results and results.multi_hand_landmarks:
        for idx in range(len(results.multi_hand_landmarks)):
            dropdown.addItem(f"Use Hand {idx}", idx)
|
||
|
||
|
||
|
||
# def save_project(self, onCrash=False):
|
||
|
||
# if not onCrash:
|
||
# filename, _ = QFileDialog.getSaveFileName(
|
||
# self, "Save Project", "", "SPARK Project (*.spark)"
|
||
# )
|
||
# if not filename:
|
||
# return
|
||
# else:
|
||
# if PLATFORM_NAME == "darwin":
|
||
# filename = os.path.join(os.path.dirname(sys.executable), "../../../sparks_autosave.spark")
|
||
# else:
|
||
# filename = os.path.join(os.getcwd(), "sparks_autosave.spark")
|
||
|
||
# try:
|
||
# # Ensure the filename has the proper extension
|
||
# if not filename.endswith(".spark"):
|
||
# filename += ".spark"
|
||
|
||
# project_path = Path(filename).resolve()
|
||
# project_dir = project_path.parent
|
||
|
||
# file_list = [
|
||
# str(PurePosixPath(Path(bubble.file_path).resolve().relative_to(project_dir)))
|
||
# for bubble in self.bubble_widgets.values()
|
||
# ]
|
||
|
||
# progress_states = {
|
||
# str(PurePosixPath(Path(bubble.file_path).resolve().relative_to(project_dir))): bubble.current_step
|
||
# for bubble in self.bubble_widgets.values()
|
||
# }
|
||
|
||
# project_data = {
|
||
# "file_list": file_list,
|
||
# "progress_states": progress_states,
|
||
# "raw_haemo_dict": self.raw_haemo_dict,
|
||
# "epochs_dict": self.epochs_dict,
|
||
# "fig_bytes_dict": self.fig_bytes_dict,
|
||
# "cha_dict": self.cha_dict,
|
||
# "contrast_results_dict": self.contrast_results_dict,
|
||
# "df_ind_dict": self.df_ind_dict,
|
||
# "design_matrix_dict": self.design_matrix_dict,
|
||
# "age_dict": self.age_dict,
|
||
# "gender_dict": self.gender_dict,
|
||
# "group_dict": self.group_dict,
|
||
# "valid_dict": self.valid_dict,
|
||
# }
|
||
|
||
# def sanitize(obj):
|
||
# if isinstance(obj, Path):
|
||
# return str(PurePosixPath(obj))
|
||
# elif isinstance(obj, dict):
|
||
# return {sanitize(k): sanitize(v) for k, v in obj.items()}
|
||
# elif isinstance(obj, list):
|
||
# return [sanitize(i) for i in obj]
|
||
# return obj
|
||
|
||
# project_data = sanitize(project_data)
|
||
|
||
# with open(filename, "wb") as f:
|
||
# pickle.dump(project_data, f)
|
||
|
||
# QMessageBox.information(self, "Success", f"Project saved to:\n{filename}")
|
||
|
||
# except Exception as e:
|
||
# if not onCrash:
|
||
# QMessageBox.critical(self, "Error", f"Failed to save project:\n{e}")
|
||
|
||
|
||
|
||
# def load_project(self):
|
||
# filename, _ = QFileDialog.getOpenFileName(
|
||
# self, "Load Project", "", "SPARK Project (*.spark)"
|
||
# )
|
||
# if not filename:
|
||
# return
|
||
|
||
# try:
|
||
# with open(filename, "rb") as f:
|
||
# data = pickle.load(f)
|
||
|
||
# self.raw_haemo_dict = data.get("raw_haemo_dict", {})
|
||
# self.epochs_dict = data.get("epochs_dict", {})
|
||
# self.fig_bytes_dict = data.get("fig_bytes_dict", {})
|
||
# self.cha_dict = data.get("cha_dict", {})
|
||
# self.contrast_results_dict = data.get("contrast_results_dict", {})
|
||
# self.df_ind_dict = data.get("df_ind_dict", {})
|
||
# self.design_matrix_dict = data.get("design_matrix_dict", {})
|
||
# self.age_dict = data.get("age_dict", {})
|
||
# self.gender_dict = data.get("gender_dict", {})
|
||
# self.group_dict = data.get("group_dict", {})
|
||
# self.valid_dict = data.get("valid_dict", {})
|
||
|
||
# project_dir = Path(filename).parent
|
||
|
||
# # Convert saved relative paths to absolute paths
|
||
# file_list = [str((project_dir / Path(rel_path)).resolve()) for rel_path in data["file_list"]]
|
||
|
||
# # Also resolve progress_states with updated paths
|
||
# raw_progress = data.get("progress_states", {})
|
||
# progress_states = {
|
||
# str((project_dir / Path(rel_path)).resolve()): step
|
||
# for rel_path, step in raw_progress.items()
|
||
# }
|
||
|
||
# self.show_files_as_bubbles_from_list(file_list, progress_states, filename)
|
||
|
||
# # Re-enable buttons
|
||
# # self.button1.setVisible(True)
|
||
# self.button3.setVisible(True)
|
||
|
||
# QMessageBox.information(self, "Loaded", f"Project loaded from:\n{filename}")
|
||
|
||
# except Exception as e:
|
||
# QMessageBox.critical(self, "Error", f"Failed to load project:\n{e}")
|
||
|
||
|
||
def placeholder(self):
    """Tell the user that the selected feature is not available yet."""
    notice = "This feature is not implemented yet."
    QMessageBox.information(self, "Placeholder", notice)
|
||
|
||
def save_metadata(self, file_path):
    """Snapshot the metadata fields' text under ``file_path``.

    Args:
        file_path: Key to store the snapshot under in
            ``self.file_metadata``; nothing happens when it is falsy.
    """
    if file_path:
        snapshot = {}
        for key, field in self.meta_fields.items():
            snapshot[key] = field.text()
        self.file_metadata[file_path] = snapshot
|
||
|
||
def get_all_metadata(self):
    """Return all stored metadata after saving the current file's edits.

    Focus is cleared on every metadata field first, then the fields are
    saved for ``self.current_file`` when one is set.

    Returns:
        The ``self.file_metadata`` mapping.
    """
    for widget in self.meta_fields.values():
        widget.clearFocus()

    current = self.current_file
    if current:
        self.save_metadata(current)

    return self.file_metadata
|
||
|
||
|
||
def cancel_task(self):
    """Terminate the result process and restore the 'Process' button.

    Child processes of ``self.result_process`` are killed before the
    process itself is terminated and joined; an active ``result_timer`` is
    stopped.

    NOTE(review): this file defines ``cancel_task`` again further down in
    the same scope; that later definition replaces this one at runtime.
    """
    button = self.button1
    button.clicked.disconnect(self.cancel_task)
    button.setText("Stopping...")

    if hasattr(self, "result_process"):
        proc = self.result_process
        if proc.is_alive():
            # Kill descendants first so they are not left orphaned.
            descendants = psutil.Process(proc.pid).children(recursive=True)
            for descendant in descendants:
                try:
                    descendant.kill()
                except psutil.NoSuchProcess:
                    pass
            proc.terminate()
            proc.join()

    if hasattr(self, "result_timer"):
        timer = self.result_timer
        if timer.isActive():
            timer.stop()

    self.statusbar.showMessage("Processing cancelled.")
    button.clicked.connect(self.on_run_task)
    button.setText("Process")
|
||
|
||
|
||
'''MODULE FILE'''
|
||
def on_run_task(self):
    """Start one processing thread per selected camera and show progress rows.

    The main button is switched to "Cancel" first. The user is then asked
    for an output directory; every camera whose dropdown is not set to skip
    (-1) and that has an entry in ``self.camera_state`` gets a
    ``ParticipantProcessor`` thread plus a row with a label, a progress bar
    and a time indicator. The button is restored to "Process" when no
    output directory is chosen or no camera is selected.
    """

    def restore_process_button():
        # Undo the Cancel wiring performed at the top of this method.
        self.button1.clicked.disconnect(self.cancel_task)
        self.button1.setText("Process")
        self.button1.clicked.connect(self.on_run_task)

    self.button1.clicked.disconnect(self.on_run_task)
    self.button1.setText("Cancel")
    self.button1.clicked.connect(self.cancel_task)

    self.clear_bubble_widgets()

    self.processing_widgets = {}
    self.processing_threads = []

    self.output_dir = self.get_output_directory()
    if not self.output_dir:
        restore_process_button()
        self.statusBar().showMessage("Processing cancelled by user.")
        return

    # Collect (obs_id, cam_id, hand_idx, path) for every camera to process.
    jobs = []
    for obs_id, cam_dropdowns in self.selection_widgets.items():
        if obs_id == "participant_selection":
            continue  # participant-level dropdowns live under this key
        for cam_id, dropdown in cam_dropdowns.items():
            hand_idx = dropdown.currentData()
            if hand_idx == -1:
                continue  # camera marked as skipped
            state_data = self.camera_state.get((obs_id, cam_id))
            if not state_data:
                continue
            jobs.append((obs_id, cam_id, hand_idx, state_data.get("path", "Unknown")))

    if not jobs:
        self.statusBar().showMessage("No files selected for processing (use dropdowns).")
        restore_process_button()
        return

    self.statusBar().showMessage(f"Starting to process {len(jobs)} files...")
    QApplication.processEvents()  # Refresh GUI before starting threads

    for obs_id, cam_id, selected_hand_idx, video_path in jobs:
        # --- Progress row: label (5) | bar (4) | time (1) ---
        row = QWidget()
        row_layout = QHBoxLayout(row)
        row_layout.setContentsMargins(0, 0, 0, 0)

        filename = os.path.basename(video_path)
        filename_label = QLabel(f"P {obs_id}/{cam_id}: {filename}")
        filename_label.setWordWrap(True)
        row_layout.addWidget(filename_label, stretch=5)

        progress_bar = QProgressBar()
        progress_bar.setRange(0, 100)
        row_layout.addWidget(progress_bar, stretch=4)

        time_label = QLabel("00:00/00:00")
        time_label.setAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter)
        row_layout.addWidget(time_label, stretch=1)

        # Append the row below whatever the layout already holds.
        self.bubble_layout.addWidget(row, self.bubble_layout.rowCount(), 0, 1, 1)

        # Keep the widgets reachable for updates from the thread's signals.
        self.processing_widgets[(obs_id, cam_id)] = {
            "label": filename_label,
            "progress": progress_bar,
            "time_label": time_label
        }

        # --- Thread start ---
        processor = ParticipantProcessor(
            obs_id=obs_id,
            boris_json=self.boris,
            selected_cam_id=cam_id,
            selected_hand_idx=selected_hand_idx,
            observations_root=self.observations_root,
            hide_preview=False,
            output_csv=f"{obs_id}_{cam_id}_processed.csv",
            output_dir=self.output_dir,
            initial_wrists=self.camera_state[(obs_id, cam_id)].get("initial_wrists")
        )

        processor.progress_updated.connect(progress_bar.setValue)
        # Default arguments pin this iteration's ids to the callback.
        processor.finished_processing.connect(
            lambda obs_id=obs_id, cam_id=cam_id: self.on_processing_finished(obs_id, cam_id)
        )
        processor.time_updated.connect(time_label.setText)
        processor.start()
        self.processing_threads.append(processor)

    # A stretching row after the last one pushes the progress rows to the top.
    self.bubble_layout.setRowStretch(self.bubble_layout.rowCount(), 1)
|
||
|
||
|
||
def check_for_pipeline_results(self):
    """Placeholder hook for polling pipeline results; currently does nothing."""
    return None
|
||
|
||
def on_processing_finished(self, obs_id, cam_id):
    """Mark one camera's progress row as complete.

    When none of the threads in ``self.processing_threads`` is running any
    more, ``cancel_task`` is called to restore the main button and a final
    status message is shown.

    Args:
        obs_id: Observation identifier of the finished job.
        cam_id: Camera identifier of the finished job.
    """
    row_widgets = self.processing_widgets.get((obs_id, cam_id)) if (obs_id, cam_id) in self.processing_widgets else None
    if row_widgets is not None:
        row_widgets["progress"].setValue(100)
        row_widgets["label"].setText(f"P {obs_id}/{cam_id}: ✅ COMPLETE")

    still_running = any(worker.isRunning() for worker in self.processing_threads)
    if still_running:
        return

    self.cancel_task()  # restores the button and clears the thread list
    self.statusBar().showMessage("All processing complete!")
|
||
|
||
|
||
def cancel_task(self):
    """Stop every running processing thread and restore the 'Process' button.

    Does nothing when ``self.processing_threads`` is empty.
    """
    workers = self.processing_threads
    if not workers:
        return

    status_bar = self.statusBar()
    status_bar.showMessage("Cancelling processing... Please wait.")

    # Ask each live thread to stop and block until it has exited.
    for worker in workers:
        if not worker.isRunning():
            continue
        worker.stop()
        worker.wait()

    workers.clear()

    # Swap the button back from Cancel to Process.
    button = self.button1
    button.clicked.disconnect(self.cancel_task)
    button.setText("Process")
    button.clicked.connect(self.on_run_task)

    self.statusBar().showMessage("Processing cancelled. Ready to load or process.")
|
||
|
||
def show_error_popup(self, title, error_message, traceback_str=""):
    """Show a warning dialog describing a per-file processing error.

    Args:
        title: Name of the file that failed; shown inside the message.
        error_message: Short error description shown in the message body.
        traceback_str: Optional traceback placed in the dialog's
            "details" section.
    """
    import html

    msgbox = QMessageBox(self)
    msgbox.setIcon(QMessageBox.Warning)
    msgbox.setWindowTitle("Warning - SPARKS")

    # The body is rendered as rich text, so caller-supplied text is escaped;
    # otherwise something like "<class 'int'>" in an error would be parsed
    # as markup and disappear. Newlines become <br> to stay visible.
    safe_title = html.escape(str(title))
    safe_error = html.escape(str(error_message)).replace("\n", "<br>")

    message = (
        f"SPARKS has encountered an error processing the file {safe_title}.<br><br>"
        "This error was likely due to incorrect parameters on the right side of the screen and not an error with your data. "
        "Processing of the remaining files continues in the background and this participant will be ignored in the analysis. "
        "If you think the parameters on the right side are correct for your data, raise an issue <a href='https://git.research.dezeeuw.ca/tyler/sparks/issues'>here</a>.<br><br>"
        f"Error message: {safe_error}"
    )

    msgbox.setTextFormat(Qt.TextFormat.RichText)
    msgbox.setText(message)
    msgbox.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)

    # The traceback goes into the collapsible details section.
    if traceback_str:
        msgbox.setDetailedText(traceback_str)

    msgbox.setStandardButtons(QMessageBox.Ok)
    # exec() replaces the deprecated exec_() and matches the other dialogs
    # in this file.
    msgbox.exec()
|
||
|
||
|
||
def cleanup_after_process(self):
    """Release the result process, its queues and the manager, if present.

    Each resource is only touched when the corresponding attribute exists
    on the instance. Queues whose ``repr`` contains ``'AutoProxy'`` are
    left alone; other queues are closed and their feeder thread joined.
    """
    if hasattr(self, 'result_process'):
        proc = self.result_process
        proc.join(timeout=0)  # non-blocking join attempt
        if proc.is_alive():
            proc.terminate()
            proc.join()

    # Both queues are handled identically, result queue first.
    for attr_name in ('result_queue', 'progress_queue'):
        if not hasattr(self, attr_name):
            continue
        queue = getattr(self, attr_name)
        if 'AutoProxy' not in repr(queue):
            queue.close()
            queue.join_thread()

    # Shutting the manager down also stops its server process.
    if hasattr(self, 'manager'):
        self.manager.shutdown()
|
||
|
||
|
||
|
||
|
||
'''UPDATER'''
|
||
def manual_check_for_updates(self):
    """Start a background check for an already-downloaded pending update."""
    checker = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix)
    self.local_check_thread = checker
    checker.pending_update_found.connect(self.on_pending_update_found)
    checker.no_pending_update.connect(self.on_no_pending_update)
    checker.start()
|
||
|
||
def on_pending_update_found(self, version, folder_path):
    """Record a locally pending update and ask the user what to do with it.

    Args:
        version: Version string of the pending update.
        folder_path: Folder in which the pending update was found.
    """
    self.statusBar().showMessage(f"Pending update found: version {version}")
    self.pending_update_version, self.pending_update_path = version, folder_path
    self.show_pending_update_popup()
|
||
|
||
def on_no_pending_update(self):
    """Continue with the server check when no local update is pending."""
    status_bar = self.statusBar()
    status_bar.showMessage("No pending local update found. Checking server...")
    self.start_update_check_thread()
|
||
|
||
def show_pending_update_popup(self):
    """Ask whether the pending update should be installed now or later.

    "Install Now" installs from ``self.pending_update_path``. Otherwise a
    status message is shown and the server is still checked for newer
    updates.
    """
    prompt = QMessageBox(self)
    prompt.setWindowTitle("Pending Update Found")
    prompt.setText(f"A previously downloaded update (version {self.pending_update_version}) is available at:\n{self.pending_update_path}\nWould you like to install it now?")
    install_now_button = prompt.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole)
    prompt.addButton("Install Later", QMessageBox.ButtonRole.RejectRole)
    prompt.exec()

    if prompt.clickedButton() == install_now_button:
        self.install_update(self.pending_update_path)
        return

    self.statusBar().showMessage("Pending update available. Install later.")
    # The user postponed: still look for something newer on the server.
    self.start_update_check_thread()
|
||
|
||
def start_update_check_thread(self):
    """
    Start the server-side update check.

    Creates an UpdateCheckThread, connects its download_requested, no_update_available and
    error_occurred signals to the matching handlers on this window, and starts it.
    """
    checker = UpdateCheckThread()
    self.check_thread = checker

    checker.download_requested.connect(self.on_server_update_requested)
    checker.no_update_available.connect(self.on_server_no_update)
    checker.error_occurred.connect(self.on_error)
    checker.start()
|
||
|
||
def on_server_no_update(self):
    """
    Slot for a server check that found nothing newer; shows a status bar note with a 5000 ms timeout.
    """
    status_bar = self.statusBar()
    status_bar.showMessage("No new updates found on server.", 5000)
|
||
|
||
def on_server_update_requested(self, download_url, latest_version):
    """
    Slot for a server check that found a downloadable release.

    Decides between downloading the server release and keeping an update that was already
    downloaded locally:
      * no pending update          -> download the server release
      * server newer than pending  -> delete the pending folder, then download
      * server equal to pending    -> nothing to download
      * server older than pending  -> keep the pending update

    Args:
        download_url (str): URL of the release archive, passed on to download_update().
        latest_version (str): Version string of the release on the server.
    """
    # The pending_update_* attributes are only assigned in on_pending_update_found in the
    # code visible here, so they are read defensively in case no local update was found.
    pending_version = getattr(self, "pending_update_version", None)
    pending_path = getattr(self, "pending_update_path", None)

    if not pending_version:
        # No pending update, just download
        self.download_update(download_url, latest_version)
        return

    cmp = self.version_compare(latest_version, pending_version)

    if cmp > 0:
        # Server version is newer than pending update
        self.statusBar().showMessage(f"Newer version {latest_version} available on server. Removing old pending update...")
        try:
            shutil.rmtree(pending_path)
            self.statusBar().showMessage(f"Deleted old update folder: {pending_path}")
        except Exception as e:
            # Best effort: a leftover folder must not block the new download
            self.statusBar().showMessage(f"Failed to delete old update folder: {e}")

        # Clear pending update info so new download proceeds
        self.pending_update_version = None
        self.pending_update_path = None

        # Download the new update
        self.download_update(download_url, latest_version)
    elif cmp == 0:
        # Versions equal, no download needed
        self.statusBar().showMessage(f"Pending update version {pending_version} is already latest. No download needed.")
    else:
        # Server version older than pending? Unlikely but just keep pending update
        self.statusBar().showMessage(f"Pending update version {pending_version} is newer than server version. No action.")
|
||
|
||
def download_update(self, download_url, latest_version):
    """
    Download a release in the background.

    Shows a status bar note, then starts an UpdateDownloadThread whose update_ready and
    error_occurred signals are routed to on_update_ready / on_error.
    """
    self.statusBar().showMessage("Downloading update...")

    downloader = UpdateDownloadThread(download_url, latest_version)
    self.download_thread = downloader

    downloader.update_ready.connect(self.on_update_ready)
    downloader.error_occurred.connect(self.on_error)
    downloader.start()
|
||
|
||
def on_update_ready(self, latest_version, extract_folder):
    """
    Slot for a finished download: offer to install the freshly extracted update.

    "Install Now" passes extract_folder to install_update(); otherwise only the status
    bar is updated and the extracted folder is left where it is.
    """
    self.statusBar().showMessage("Update downloaded and extracted.")

    prompt = QMessageBox(self)
    prompt.setWindowTitle("Update Ready")
    prompt.setText(f"Version {latest_version} has been downloaded and extracted to:\n{extract_folder}\nWould you like to install it now?")
    accept_button = prompt.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole)
    prompt.addButton("Install Later", QMessageBox.ButtonRole.RejectRole)

    prompt.exec()

    if prompt.clickedButton() != accept_button:
        self.statusBar().showMessage("Update ready. Install later.")
        return

    self.install_update(extract_folder)
|
||
|
||
|
||
def install_update(self, extract_folder):
    """
    Launch the external updater for an extracted update and exit this application.

    The updater is looked up relative to the working directory (or, for a frozen macOS
    build, relative to sys.executable). It receives the extracted update folder and the
    path of the current program (sys.argv[0]) as arguments. If the updater is missing or
    cannot be launched, an error dialog is shown and the application keeps running.
    """
    # Locate the updater executable for this platform
    if PLATFORM_NAME == 'windows':
        updater_path = os.path.join(os.getcwd(), "sparks_updater.exe")
    elif PLATFORM_NAME == 'darwin':
        if getattr(sys, 'frozen', False):
            updater_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks_updater.app")
        else:
            updater_path = os.path.join(os.getcwd(), "../sparks_updater.app")
    elif PLATFORM_NAME == 'linux':
        updater_path = os.path.join(os.getcwd(), "sparks_updater")
    else:
        updater_path = os.getcwd()

    updater_missing = not os.path.exists(updater_path)
    if updater_missing:
        QMessageBox.critical(self, "Error", f"Updater not found at:\n{updater_path}. The absolute path was {os.path.abspath(updater_path)}")
        return

    try:
        # The updater gets this program's path as its second argument
        main_app_executable = os.path.abspath(sys.argv[0])

        print(f'Launching updater with: "{updater_path}" "{extract_folder}" "{main_app_executable}"')

        if PLATFORM_NAME != 'darwin':
            launch_dir = os.path.dirname(updater_path)
            subprocess.Popen([updater_path, f'{extract_folder}', f'{main_app_executable}'], cwd=launch_dir)
        else:
            # .app bundles are started through `open`; `--args` forwards the rest to the app
            subprocess.Popen(['open', updater_path, '--args', extract_folder, main_app_executable])

        # Close the current app so updater can replace files
        sys.exit(0)
    except Exception as e:
        QMessageBox.critical(self, "Error", f"[Updater Launch Failed]\n{str(e)}\n{traceback.format_exc()}")
|
||
|
||
def on_error(self, message):
    """
    Slot for failures reported by the update threads; shows the message in the status bar.
    """
    status_text = f"Error occurred during update process. {message}"
    self.statusBar().showMessage(status_text)
|
||
|
||
def version_compare(self, v1, v2):
    """
    Compare two dotted version strings numerically.

    Components are compared as integers, so "1.10.0" is newer than "1.2.0". Trailing zero
    components are ignored ("1.0" equals "1.0.0"), a leading "v"/"V" is accepted, and only
    the leading digits of each component are used ("0-beta" counts as 0; a component
    without leading digits counts as 0) so an unexpected tag format cannot raise.

    Args:
        v1 (str): First version.
        v2 (str): Second version.

    Returns:
        int: 1 if v1 is newer than v2, -1 if it is older, 0 if they are equal.
    """
    def normalize(v):
        parts = []
        for piece in str(v).strip().lstrip("vV").split("."):
            digits = re.match(r"\d+", piece)
            parts.append(int(digits.group()) if digits else 0)
        # Drop trailing zeros so versions of different length compare by value
        while parts and parts[-1] == 0:
            parts.pop()
        return parts

    left, right = normalize(v1), normalize(v2)
    return (left > right) - (left < right)
|
||
|
||
|
||
|
||
def closeEvent(self, event):
    """
    Clean up when the main window is closed.

    Shuts down the multiprocessing manager if this window has one, closes every child
    widget that is still visible, kills remaining child processes and accepts the event.
    """
    print("Window is closing. Cleaning up...")

    if hasattr(self, 'manager'):
        self.manager.shutdown()

    for widget in self.findChildren(QWidget):
        if widget is self or not widget.isVisible():
            continue
        widget.close()

    kill_child_processes()

    event.accept()
|
||
|
||
|
||
def wait_for_process_to_exit(process_name, timeout=10):
    """
    Waits for a process with the specified name to exit within a timeout period.

    The process table is polled every 0.5 seconds. The deadline is measured with a
    monotonic clock so that system clock adjustments cannot shorten or extend the wait.

    Args:
        process_name (str): Name (or part of the name) of the process to wait for.
            Matching is case-insensitive.
        timeout (int, optional): Maximum time to wait in seconds. Defaults to 10.

    Returns:
        bool: True if the process exited before the timeout, False otherwise.
    """

    print(f"Waiting for {process_name} to exit...")
    needle = process_name.lower()
    deadline = time.monotonic() + timeout

    while time.monotonic() < deadline:
        for proc in psutil.process_iter(['name']):
            try:
                name = proc.info['name']
                if name and needle in name.lower():
                    print(f"Still running: {name} (PID: {proc.pid})")
                    break
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
        else:
            # The scan finished without finding a match
            print(f"{process_name} has exited.")
            return True
        time.sleep(0.5)

    print(f"{process_name} did not exit in time.")
    return False
|
||
|
||
|
||
def finish_update_if_needed():
    """
    Completes a pending application update if '--finish-update' is present in the command-line arguments.

    Steps performed when the flag is present:
      1. Find the extracted update folder ("sparks-...-<platform>") in the application
         directory (a fixed temporary directory on macOS).
      2. Wait for a running sparks_updater to exit, killing it after 5 seconds.
      3. Replace the installed updater with the one shipped in the update.
      4. Delete the update folder, show a completion dialog and remove the flag from sys.argv.

    Failures in steps 2-4 are printed and do not abort startup. If no update folder can be
    found the function returns early without showing the dialog.
    """

    if "--finish-update" not in sys.argv:
        return

    print("Finishing update...")

    if PLATFORM_NAME == 'darwin':
        app_dir = '/tmp/sparkstempupdate'
    else:
        app_dir = os.getcwd()

    # 1. Find update folder
    try:
        entries = os.listdir(app_dir)
    except OSError as e:
        # A missing or unreadable folder is treated like "no update found" instead of crashing startup
        print(f"Could not list {app_dir}: {e}")
        entries = []

    update_folder = None
    for entry in entries:
        entry_path = os.path.join(app_dir, entry)
        if os.path.isdir(entry_path) and entry.startswith("sparks-") and entry.endswith("-" + PLATFORM_NAME):
            update_folder = entry_path
            break

    if update_folder is None:
        print("No update folder found. Skipping update steps.")
        return

    if PLATFORM_NAME == 'darwin':
        update_folder = os.path.join(update_folder, "sparks-darwin")

    # 2. Wait for sparks_updater to exit
    print("Waiting for sparks_updater to exit...")
    for proc in psutil.process_iter(['pid', 'name']):
        if proc.info['name'] and "sparks_updater" in proc.info['name'].lower():
            try:
                proc.wait(timeout=5)
            except psutil.TimeoutExpired:
                print("Force killing lingering sparks_updater")
                try:
                    proc.kill()
                except psutil.Error as e:
                    # The process may have exited on its own or may not be ours to kill
                    print(f"Could not kill sparks_updater: {e}")
            except psutil.Error as e:
                print(f"Could not wait for sparks_updater: {e}")

    # 3. Replace the updater
    if PLATFORM_NAME == 'windows':
        new_updater = os.path.join(update_folder, "sparks_updater.exe")
        dest_updater = os.path.join(app_dir, "sparks_updater.exe")

    elif PLATFORM_NAME == 'darwin':
        new_updater = os.path.join(update_folder, "sparks_updater.app")
        dest_updater = os.path.abspath(os.path.join(sys.executable, "../../../../sparks_updater.app"))

    elif PLATFORM_NAME == 'linux':
        new_updater = os.path.join(update_folder, "sparks_updater")
        dest_updater = os.path.join(app_dir, "sparks_updater")

    else:
        print("No platform??")
        new_updater = os.getcwd()
        dest_updater = os.getcwd()

    print(f"New updater is {new_updater}")
    print(f"Dest updater is {dest_updater}")

    print("Writable?", os.access(dest_updater, os.W_OK))
    print("Executable path:", sys.executable)
    print("Trying to copy:", new_updater, "->", dest_updater)

    if os.path.exists(new_updater):
        try:
            if os.path.exists(dest_updater):
                if PLATFORM_NAME == 'darwin':
                    # The macOS updater is an .app bundle (a directory); a failed delete is only logged
                    try:
                        if os.path.isdir(dest_updater):
                            shutil.rmtree(dest_updater)
                            print(f"Deleted directory: {dest_updater}")
                        else:
                            os.remove(dest_updater)
                            print(f"Deleted file: {dest_updater}")
                    except Exception as e:
                        print(f"Error deleting {dest_updater}: {e}")
                else:
                    os.remove(dest_updater)

            if PLATFORM_NAME == 'darwin':
                wait_for_process_to_exit("sparks_updater", timeout=10)
                subprocess.check_call(["ditto", new_updater, dest_updater])
            else:
                shutil.copy2(new_updater, dest_updater)

            if PLATFORM_NAME in ('linux', 'darwin'):
                os.chmod(dest_updater, 0o755)

            if PLATFORM_NAME == 'darwin':
                remove_quarantine(dest_updater)

            print("sparks_updater replaced.")
        except Exception as e:
            print(f"Failed to replace sparks_updater: {e}")

    # 4. Delete the update folder
    try:
        if PLATFORM_NAME == 'darwin':
            shutil.rmtree(app_dir)
        else:
            shutil.rmtree(update_folder)
    except Exception as e:
        print(f"Failed to delete update folder: {e}")

    QMessageBox.information(None, "Update Complete", "The application has been successfully updated.")
    sys.argv.remove("--finish-update")
|
||
|
||
|
||
def remove_quarantine(app_path):
    """
    Removes the macOS quarantine attribute from the specified application path.

    Runs `xattr -d -r com.apple.quarantine <app_path>` with administrator privileges through
    an AppleScript `do shell script` call, which makes macOS show a password prompt.
    Failures (non-zero exit status, or osascript not being runnable) are printed, not raised.

    Args:
        app_path (str): Path of the application bundle to clear.
    """

    # Quote for the shell first ...
    shell_command = f"xattr -d -r com.apple.quarantine {shlex.quote(app_path)}"
    # ... then escape for the AppleScript string literal, where backslash and double
    # quote are special. Without this a path containing either breaks the script.
    applescript_literal = shell_command.replace("\\", "\\\\").replace('"', '\\"')

    script = (
        f'do shell script "{applescript_literal}" '
        'with administrator privileges '
        'with prompt "SPARKS needs privileges to finish the update. (2/2)"'
    )

    try:
        subprocess.run(['osascript', '-e', script], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print("❌ Failed to remove quarantine attribute.")
        print(e)
    else:
        print("✅ Quarantine attribute removed.")
|
||
|
||
|
||
def resource_path(relative_path):
    """
    Resolve a bundled resource to a full path, both when run from source and from a PyInstaller build.

    When sys._MEIPASS exists (PyInstaller bundle path) it is used as the base directory;
    otherwise the directory containing this file is used.
    """

    try:
        root = sys._MEIPASS
    except AttributeError:
        # Not running from a PyInstaller bundle
        root = os.path.dirname(os.path.abspath(__file__))

    return os.path.join(root, relative_path)
|
||
|
||
|
||
def kill_child_processes():
    """
    Kill every descendant process of the current process and wait up to 5 seconds for them to end.

    Errors are printed rather than raised.
    """

    try:
        descendants = psutil.Process(os.getpid()).children(recursive=True)
        for proc in descendants:
            try:
                proc.kill()
            except psutil.NoSuchProcess:
                # Already gone
                continue
        psutil.wait_procs(descendants, timeout=5)
    except Exception as e:
        print(f"Error killing child processes: {e}")
|
||
|
||
|
||
def exception_hook(exc_type, exc_value, exc_traceback):
    """
    Handler for uncaught exceptions: print the traceback, kill child processes, show it in a popup, then exit.

    The arguments are the exception type, value and traceback, which are formatted with
    traceback.format_exception. The process exits with status 1 once the dialog is dismissed.
    """

    formatted = traceback.format_exception(exc_type, exc_value, exc_traceback)
    error_msg = "".join(formatted)
    print(error_msg)  # also print to console

    kill_child_processes()

    # Create a minimal QApplication if none exists yet, then show the error dialog
    app = QApplication.instance()
    if app is None:
        app = QApplication(sys.argv)

    show_critical_error(error_msg)

    # Exit the app after user acknowledges
    sys.exit(1)
|
||
|
||
def show_critical_error(error_msg):
    """
    Show the crash dialog for an unrecoverable error.

    Copies sparks.log to sparks_error.log, attempts an autosave of the current project and
    then shows a rich-text dialog with links to both files and the traceback. Every
    preparatory step is best-effort: this runs from the exception hook, so a failure here
    must not prevent the dialog from appearing.

    Args:
        error_msg (str): Formatted traceback text to display.
    """
    import html  # local import: not part of the module-level imports

    msg_box = QMessageBox()
    msg_box.setIcon(QMessageBox.Icon.Critical)
    msg_box.setWindowTitle("Something went wrong!")

    if PLATFORM_NAME == "darwin":
        base_dir = os.path.join(os.path.dirname(sys.executable), "../../..")
    else:
        base_dir = os.getcwd()

    log_path = os.path.join(base_dir, "sparks.log")
    log_path2 = os.path.join(base_dir, "sparks_error.log")
    save_path = os.path.join(base_dir, "sparks_autosave.spark")

    try:
        shutil.copy(log_path, log_path2)
    except OSError as e:
        # e.g. the log was never created because the crash happened very early
        print(f"Could not copy {log_path} to {log_path2}: {e}")

    error_log = Path(log_path2).absolute()
    autosave = Path(save_path).absolute()
    log_path2 = error_log.as_posix()
    autosave_path = autosave.as_posix()
    # as_uri() yields a well-formed, percent-encoded file:// URL on every platform
    log_link = error_log.as_uri()
    autosave_link = autosave.as_uri()

    # `window` is created in the __main__ block and does not exist if the crash happened before that
    main_window = globals().get("window")
    if main_window is not None:
        try:
            main_window.save_project(True)
        except Exception as e:
            print(f"Autosave failed: {e}")

    # The message is rendered as rich text, so escape everything that is not markup.
    # Tracebacks contain text such as "<module>" that would otherwise vanish as an unknown tag.
    message = (
        "SPARKS has encountered an unrecoverable error and needs to close.<br><br>"
        f"We are sorry for the inconvenience. An autosave was attempted to be saved to <a href='{autosave_link}'>{html.escape(autosave_path)}</a>, but it may not have been saved. "
        "If the file was saved, it still may not be intact, openable, or contain the correct data. Use the autosave at your discretion.<br><br>"
        "This unrecoverable error was likely due to an error with SPARKS and not your data.<br>"
        f"Please raise an issue <a href='https://git.research.dezeeuw.ca/tyler/sparks/issues'>here</a> and attach the error file located at <a href='{log_link}'>{html.escape(log_path2)}</a><br><br>"
        f"<pre>{html.escape(error_msg)}</pre>"
    )

    msg_box.setTextFormat(Qt.TextFormat.RichText)
    msg_box.setText(message)
    msg_box.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
    msg_box.setStandardButtons(QMessageBox.StandardButton.Ok)

    msg_box.exec()
|
||
|
||
if __name__ == "__main__":
    # Redirect exceptions to the popup window
    sys.excepthook = exception_hook

    # Set up application logging
    if PLATFORM_NAME == "darwin":
        log_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks.log")
    else:
        log_path = os.path.join(os.getcwd(), "sparks.log")

    # Start every run with a fresh log; a missing or locked file is not an error
    try:
        os.remove(log_path)
    except OSError:
        pass

    # Line-buffered so the log is current if the app crashes. UTF-8 is set explicitly
    # because the program prints non-ASCII characters, which a platform default
    # encoding may not be able to represent.
    sys.stdout = open(log_path, "a", buffering=1, encoding="utf-8")
    sys.stderr = sys.stdout
    print(f"\n=== App started at {datetime.now()} ===\n")

    # NOTE(review): the multiprocessing docs recommend calling freeze_support() directly
    # after the __main__ guard; confirm whether running the log setup first is intended.
    freeze_support()  # Required for PyInstaller + multiprocessing

    # Only run GUI in the main process
    if current_process().name == 'MainProcess':
        app = QApplication(sys.argv)
        finish_update_if_needed()
        window = MainApplication()

        if PLATFORM_NAME == "darwin":
            app.setWindowIcon(QIcon(resource_path("icons/main.icns")))
            window.setWindowIcon(QIcon(resource_path("icons/main.icns")))
        else:
            app.setWindowIcon(QIcon(resource_path("icons/main.ico")))
            window.setWindowIcon(QIcon(resource_path("icons/main.ico")))
        window.show()
        sys.exit(app.exec())
|
||
|
||
# Not 4000 lines yay! |