4099 lines
157 KiB
Python
4099 lines
157 KiB
Python
"""
Filename: main.py
Description: SPARKS main executable

Author: Tyler de Zeeuw
License: GPL-3.0
"""
|
||
|
||
# Built-in imports
|
||
import os
|
||
import re
|
||
import csv
|
||
import sys
|
||
import math
|
||
import json
|
||
import time
|
||
import shlex
|
||
import pickle
|
||
import shutil
|
||
import zipfile
|
||
import platform
|
||
import traceback
|
||
import subprocess
|
||
from itertools import product
|
||
from pathlib import Path, PurePosixPath
|
||
from datetime import datetime
|
||
from multiprocessing import Process, current_process, freeze_support, Manager
|
||
import cv2
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
# External library imports
|
||
import psutil
|
||
import requests
|
||
import torch
|
||
import torch.nn as nn
|
||
import mediapipe as mp
|
||
from mediapipe.tasks import python
|
||
from mediapipe.tasks.python import vision
|
||
from torch.utils.data import TensorDataset, DataLoader
|
||
from sklearn.utils.class_weight import compute_class_weight
|
||
import matplotlib.pyplot as plt
|
||
|
||
from PySide6.QtWidgets import (
|
||
QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox, QGridLayout,
|
||
QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView,
|
||
QMenu, QProgressBar, QCheckBox, QSlider, QTabWidget, QTreeWidget, QTreeWidgetItem, QHeaderView
|
||
)
|
||
from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize, QPoint
|
||
from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel, QStandardItem, QImage
|
||
from PySide6.QtSvgWidgets import QSvgWidget # needed to show svgs when app is not frozen
|
||
|
||
|
||
# Version of this build; compared against release tags by the update checker.
CURRENT_VERSION = "1.0.0"

# Release listing endpoints: primary server first, mirror second.
API_URL = "https://git.research.dezeeuw.ca/api/v1/repos/tyler/sparks/releases"
API_URL_SECONDARY = "https://git.research2.dezeeuw.ca/api/v1/repos/tyler/sparks/releases"

# Lower-cased OS name; compared against 'darwin' and matched against release asset names.
PLATFORM_NAME = platform.system().lower()

# Selectable parameters shown on the right side of the window.
# Each section has a title and a list of parameter descriptors
# (name, default value, value type, help text).
SECTIONS = [
    dict(
        title="Parameter 1",
        params=[
            dict(name="Number 1", default=True, type=bool, help="N/A"),
            dict(name="Number 2", default=True, type=bool, help="N/A"),
        ],
    ),
]
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
class TrainModelThread(QThread):
    """Background thread that trains the windowed reach-detection LSTM.

    The thread reads landmark CSV files, slices each file into fixed-length
    windows, runs a (currently single-point) hyper-parameter sweep and saves
    the ``state_dict`` of the epoch with the best F1 score to
    ``model_save_path``.

    Each CSV must contain feature columns whose names start with ``"lm"`` and
    the label columns ``reach_active`` and ``reach_before_contact``.

    Signals:
        update(str): Progress messages for the GUI.
        finished(str): Final summary, or an error description if training failed.

    Args:
        csv_paths (list[str]): Training CSV files.
        model_save_path (str): Where the best model weights are written.
        parent (QObject, optional): Qt parent. Defaults to None.
    """

    update = Signal(str)  # progress messages for the GUI
    finished = Signal(str)  # final summary or error message

    def __init__(self, csv_paths, model_save_path, parent=None):
        super().__init__(parent)
        self.csv_paths = csv_paths
        self.model_save_path = model_save_path

    def run(self):
        """Qt entry point: train, and report any failure through ``finished``."""
        try:
            self._train()
        except Exception as e:
            # Without this an exception ends the thread silently and the GUI
            # never receives ``finished``.
            traceback.print_exc()
            self.finished.emit(f"Training failed: {e}")

    def _load_windows(self, window_length, step_length, label_fraction=0.3):
        """Load every CSV and cut it into overlapping fixed-length windows.

        Windows are built per file, so no window straddles the boundary between
        two recordings. Missing feature values are forward- then back-filled
        within each file.

        Args:
            window_length (int): Frames per window.
            step_length (int): Frames between consecutive window starts.
            label_fraction (float): A window's label is 1 when at least this
                fraction of its frames are labelled 1.

        Returns:
            tuple[np.ndarray, np.ndarray]: ``X_windows`` of shape
            ``(N, window_length, n_features)`` and ``y_windows`` of shape ``(N, 2)``.

        Raises:
            ValueError: if no file is long enough to produce a single window.
        """
        feature_cols = None
        X_windows, y_windows = [], []
        for path in self.csv_paths:
            df = pd.read_csv(path)
            if feature_cols is None:
                # Feature columns are taken from the first file.
                feature_cols = [c for c in df.columns if c.startswith("lm")]
            df[feature_cols] = df[feature_cols].ffill().bfill()
            X = df[feature_cols].values
            y = df[["reach_active", "reach_before_contact"]].values

            for start in range(0, len(X) - window_length + 1, step_length):
                end = start + window_length
                y_win = y[start:end]
                label_window = (y_win.sum(axis=0) / len(y_win) >= label_fraction).astype(int)  # shape: (2,)
                X_windows.append(X[start:end])
                y_windows.append(label_window)

        if not X_windows:
            raise ValueError(
                f"No training windows could be built: every CSV is shorter than {window_length} frames."
            )
        return np.array(X_windows), np.array(y_windows)

    @staticmethod
    def _pos_weights(y_windows):
        """Return the BCE ``pos_weight`` (negatives / positives) for each label column.

        A column containing a single class gets a neutral weight of 1.0, because
        ``compute_class_weight`` rejects classes that do not occur in ``y``.
        """
        weights = []
        for i in range(y_windows.shape[1]):
            labels = y_windows[:, i]
            if np.unique(labels).size < 2:
                weights.append(1.0)
                continue
            cw = compute_class_weight('balanced', classes=np.array([0, 1]), y=labels)
            weights.append(cw[1] / cw[0])
        return weights

    def _train(self):
        """Window the data, run the hyper-parameter sweep and save the best model."""
        FPS, WINDOW_SEC, STEP_SEC, THRESHOLD, EPOCHS = 60, 0.4, 0.2, 0.5, 150
        window_length = int(FPS * WINDOW_SEC)
        step_length = int(FPS * STEP_SEC)

        X_windows, y_windows = self._load_windows(window_length, step_length)
        pos_weight = torch.tensor(self._pos_weights(y_windows), dtype=torch.float32)
        X_tensor = torch.tensor(X_windows, dtype=torch.float32)  # (N, window_length, features)
        y_tensor = torch.tensor(y_windows, dtype=torch.float32)  # (N, 2)

        class WindowLSTM(nn.Module):
            """Single-layer LSTM whose final hidden state feeds a linear head."""

            def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
                super().__init__()
                self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True,
                                    bidirectional=bidirectional)
                self.fc = nn.Linear(hidden_size * (2 if bidirectional else 1), output_size)

            def forward(self, x):
                _, (h, _) = self.lstm(x)
                if self.lstm.bidirectional:
                    # Join the last forward and backward hidden states.
                    h = torch.cat([h[-2], h[-1]], dim=1)
                else:
                    h = h[-1]
                return self.fc(h)

        # Hyperparameter sweep
        HIDDEN_SIZES = [64]
        BATCH_SIZES = [32]
        LRs = [0.0005]
        OPTIMIZERS = ['adam']
        BIDIRECTIONAL = [False]

        # Start below any achievable F1 (>= 0) so at least one model is always saved.
        best_f1 = -1.0
        best_config = None

        dataset = TensorDataset(X_tensor, y_tensor)
        for hidden_size, batch_size, lr, opt_name, bi in product(HIDDEN_SIZES, BATCH_SIZES, LRs, OPTIMIZERS, BIDIRECTIONAL):
            self.update.emit(f"Training model: hidden={hidden_size}, batch={batch_size}, lr={lr}, opt={opt_name}, bidir={bi}")

            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
            model = WindowLSTM(input_size=X_windows.shape[2], hidden_size=hidden_size, bidirectional=bi)
            if opt_name == 'adam':
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                optimizer = torch.optim.SGD(model.parameters(), lr=lr)
            criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

            for epoch in range(EPOCHS):
                # Training
                model.train()
                epoch_loss = 0.0
                for xb, yb in loader:
                    optimizer.zero_grad()
                    loss = criterion(model(xb), yb)
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()

                # Evaluation.
                # NOTE: this scores the training windows themselves (there is no
                # held-out split), so the reported F1 is optimistic.
                model.eval()
                with torch.no_grad():
                    preds = (torch.sigmoid(model(X_tensor)) > THRESHOLD).int()
                y_true = y_tensor.int()

                # Micro-averaged over both labels.
                tp = ((preds == 1) & (y_true == 1)).sum().item()
                fp = ((preds == 1) & (y_true == 0)).sum().item()
                fn = ((preds == 0) & (y_true == 1)).sum().item()
                precision = tp / (tp + fp + 1e-8)
                recall = tp / (tp + fn + 1e-8)
                f1 = 2 * precision * recall / (precision + recall + 1e-8)

                self.update.emit(f"Epoch {epoch+1}/{EPOCHS} Loss={epoch_loss:.4f} "
                                 f"Precision={precision:.3f} Recall={recall:.3f} F1={f1:.3f}")

                if f1 > best_f1:
                    best_f1 = f1
                    best_config = (hidden_size, batch_size, lr, opt_name, bi)
                    torch.save(model.state_dict(), self.model_save_path)
                    self.update.emit(f"New best model saved to {os.path.basename(self.model_save_path)}!")

        self.finished.emit(f"Training complete!\nBest config: {best_config}\nF1={best_f1:.3f}")
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
class TestModelThread(QThread):
    """Play a video with live reach predictions from a trained model.

    The right hand is tracked with MediaPipe Hands; its 21 landmarks give a
    63-value (x, y, z) feature vector per frame (NaNs when no right hand is
    found). Missing values are interpolated within the rolling buffer, and once
    ``window_frames`` frames are buffered the LSTM predicts the two reach
    labels. Annotated frames are shown in an OpenCV window until the video ends
    or 'q' is pressed. Only the first entry of ``video_paths`` is played.

    Args:
        video_paths (list[str]): Videos to test on (only the first is used).
        model_path (str): Path to a ``state_dict`` saved by ``TrainModelThread``.
        fps (int): Frame rate used for the window length and the time overlay.
        window_sec (float): Window length in seconds.
        step_sec (float): Stored on the instance; ``run`` predicts on every frame.
        threshold (float): Probability above which a label is shown as "YES".
    """

    update_frame = Signal(np.ndarray)  # declared for overlay frames; not emitted by run()
    finished = Signal()

    def __init__(self, video_paths, model_path, fps=60, window_sec=0.4, step_sec=0.2, threshold=0.5):
        super().__init__()
        self.video_paths = video_paths
        self.fps = fps
        self.window_sec = window_sec
        self.step_sec = step_sec
        self.threshold = threshold
        self.model_path = model_path
        self.window_frames = int(fps * window_sec)

    @staticmethod
    def _load_model(model_path):
        """Build a ``WindowLSTM`` matching the saved weights and load them.

        Input size, hidden size, bidirectionality and output count are read
        from the tensor shapes in the ``state_dict``, so a model trained with
        any swept configuration can be loaded.
        """

        class WindowLSTM(torch.nn.Module):
            def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size, hidden_size, batch_first=True, bidirectional=bidirectional
                )
                self.fc = torch.nn.Linear(hidden_size * (2 if bidirectional else 1), output_size)

            def forward(self, x):
                _, (h, _) = self.lstm(x)
                if self.lstm.bidirectional:
                    h = torch.cat([h[-2], h[-1]], dim=1)
                else:
                    h = h[-1]
                return self.fc(h)

        # NOTE: torch.load unpickles the file; only load trusted model files.
        state = torch.load(model_path, map_location="cpu")
        model = WindowLSTM(
            input_size=state["lstm.weight_ih_l0"].shape[1],
            hidden_size=state["lstm.weight_hh_l0"].shape[1],
            bidirectional="lstm.weight_ih_l0_reverse" in state,
            output_size=state["fc.weight"].shape[0],
        )
        model.load_state_dict(state)
        model.eval()
        return model

    @staticmethod
    def _interpolate_missing(rows):
        """Linearly interpolate NaNs per feature column over the buffered rows.

        Columns that are entirely NaN are left unchanged. Returns a list of lists.
        """
        arr = np.array(rows, dtype=np.float32)
        for c in range(arr.shape[1]):
            col = arr[:, c]  # view into arr
            valid = ~np.isnan(col)
            if valid.any():
                col[~valid] = np.interp(
                    np.flatnonzero(~valid),
                    np.flatnonzero(valid),
                    col[valid],
                )
        return arr.tolist()

    def run(self):
        """Qt entry point: play the first video with live predictions.

        The capture, the MediaPipe graph and the OpenCV window are always
        released, and ``finished`` is always emitted, even on error.
        """
        cap = None
        hands = None
        window_name = "Hand Tracking Test - Press 'q' to quit"
        window_created = False
        try:
            if not self.video_paths:
                print("No video selected.")
                return

            model = self._load_model(self.model_path)
            if model.lstm.input_size != 63:
                raise ValueError(
                    f"Model expects {model.lstm.input_size} features per frame, "
                    "but live tracking produces 63 (21 landmarks x 3)."
                )
            print("Loaded model.")

            # ------------------------------ MediaPipe ------------------------------
            mp_hands = mp.solutions.hands
            mp_drawing = mp.solutions.drawing_utils
            mp_drawing_styles = mp.solutions.drawing_styles
            hands = mp_hands.Hands(
                static_image_mode=False,
                max_num_hands=2,
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5,
                model_complexity=0,
            )

            # ------------------------------ Open video ------------------------------
            cap = cv2.VideoCapture(self.video_paths[0])
            if not cap.isOpened():
                print(f"Could not open video: {self.video_paths[0]}")
                return

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            print(f"Original video dimensions: {original_width} x {original_height}")
            print(f"Total frames in video: {total_frames}")

            # One display size for both the window and the resized frames so the
            # aspect ratio is preserved (the height used to be computed for a
            # width of 800 while frames were resized to a width of 1000).
            display_width = 1000
            display_height = int(display_width * original_height / original_width)
            print(f"Display dimensions: {display_width} x {display_height}")

            cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
            window_created = True
            cv2.resizeWindow(window_name, display_width, display_height)

            window_buffer = []  # landmark rows of the most recent `window_frames` frames
            frame_index = 0

            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                h, w = frame.shape[:2]

                # Detect hands and use the first one labelled "Right".
                results = hands.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                features = []
                right_found = False
                if results.multi_hand_landmarks and results.multi_handedness:
                    for i, handedness in enumerate(results.multi_handedness):
                        if handedness.classification[0].label != "Right":
                            continue
                        lm = results.multi_hand_landmarks[i]
                        right_found = True
                        for p in lm.landmark:
                            features += [p.x, p.y, p.z]
                        mp_drawing.draw_landmarks(
                            frame,
                            lm,
                            mp_hands.HAND_CONNECTIONS,
                            mp_drawing_styles.get_default_hand_landmarks_style(),
                            mp_drawing_styles.get_default_hand_connections_style(),
                        )
                        break
                if not right_found:
                    features = [math.nan] * 63

                # Maintain the rolling window; gaps are interpolated in place.
                window_buffer.append(features)
                window_buffer = self._interpolate_missing(window_buffer)
                if len(window_buffer) > self.window_frames:
                    window_buffer.pop(0)

                pred_active = 0
                pred_before_contact = 0

                # Run inference once a full window is available.
                if len(window_buffer) == self.window_frames:
                    x = torch.tensor(window_buffer, dtype=torch.float32).unsqueeze(0)
                    with torch.no_grad():
                        probs = torch.sigmoid(model(x)).squeeze(0).numpy()
                    pred_active = int(probs[0] > self.threshold)
                    pred_before_contact = int(probs[1] > self.threshold)

                    color_active = (0, 255, 0) if pred_active else (0, 0, 255)
                    color_before_contact = (255, 255, 0) if pred_before_contact else (0, 0, 128)

                    cv2.rectangle(frame, (0, 0), (w // 2, 40), color_active, -1)
                    cv2.putText(frame, f"Reach Active: {'YES' if pred_active else 'NO'} ({probs[0]:.2f})",
                                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

                    cv2.rectangle(frame, (w // 2, 0), (w, 40), color_before_contact, -1)
                    cv2.putText(frame, f"Reach Before Contact: {'YES' if pred_before_contact else 'NO'} ({probs[1]:.2f})",
                                (w // 2 + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
                else:
                    cv2.rectangle(frame, (0, 0), (w, 40), (0, 128, 255), -1)
                    cv2.putText(
                        frame,
                        f"Collecting window: {len(window_buffer)}/{self.window_frames}",
                        (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1.0,
                        (255, 255, 255),
                        2,
                    )

                # Status lines under the banner.
                info_text = [
                    f"Frame: {frame_index}",
                    f"Time: {frame_index / self.fps:.2f}s",
                    f"Reach Active: {'YES' if pred_active else 'NO'}",
                    f"Reach Before Contact: {'YES' if pred_before_contact else 'NO'}",
                    f"Right Hand: {'DETECTED' if right_found else 'NOT DETECTED'}",
                    "Press 'q' to quit",
                ]
                for i, text in enumerate(info_text):
                    cv2.putText(frame, text, (10, 60 + i * 25),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                # Highlight the whole frame when either label is predicted.
                if pred_active or pred_before_contact:
                    cv2.rectangle(frame, (0, 0), (w, h), (0, 255, 0), 6)

                resized_frame = cv2.resize(frame, (display_width, display_height), interpolation=cv2.INTER_AREA)
                cv2.imshow(window_name, resized_frame)

                frame_index += 1
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
        finally:
            if cap is not None:
                cap.release()
            if hands is not None:
                hands.close()
            if window_created:
                cv2.destroyWindow(window_name)
            self.finished.emit()
|
||
|
||
|
||
|
||
|
||
|
||
class TestModelThread2(QThread):
    """Run a trained model offline over a landmark CSV and report per-row probabilities.

    On success ``finished`` is emitted with
    ``{"time": ..., "active": [...], "before": [...], "threshold": float}``.
    ``active``/``before`` hold one entry per CSV row: 0 for the first
    ``window_frames - 1`` rows (no full window yet), then the sigmoid
    probability for the window ending at that row. On failure ``finished`` is
    emitted with ``{"error": message}``.

    Args:
        csv_path (str): CSV with landmark columns prefixed ``"lm"`` and an
            optional ``time_sec`` column.
        model_path (str): Path to a ``state_dict`` saved by ``TrainModelThread``.
        fps (int): Used for the window length and, without ``time_sec``, the time axis.
        window_sec (float): Window length in seconds.
        threshold (float): Passed through to the result for the consumer.
    """

    finished = Signal(object)

    def __init__(self, csv_path, model_path, fps=60, window_sec=0.4, threshold=0.5):
        super().__init__()
        self.csv_path = csv_path
        self.model_path = model_path
        self.fps = fps
        self.window_sec = window_sec
        self.window_frames = int(fps * window_sec)
        self.threshold = threshold

    def run(self):
        """Qt entry point: any exception is reported as ``{"error": message}``."""
        try:
            self._run_internal()
        except Exception as e:
            traceback.print_exc()
            self.error = str(e)
            self.finished.emit({"error": str(e)})

    @staticmethod
    def _load_model(model_path):
        """Build a ``WindowLSTM`` matching the saved weights and load them.

        Input size, hidden size, bidirectionality and output count are read
        from the tensor shapes in the ``state_dict``, so a model trained with
        any swept configuration can be loaded.
        """

        class WindowLSTM(torch.nn.Module):
            def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size, hidden_size, batch_first=True, bidirectional=bidirectional
                )
                self.fc = torch.nn.Linear(hidden_size * (2 if bidirectional else 1), output_size)

            def forward(self, x):
                _, (h, _) = self.lstm(x)
                if self.lstm.bidirectional:
                    h = torch.cat([h[-2], h[-1]], dim=1)
                else:
                    h = h[-1]
                return self.fc(h)

        # NOTE: torch.load unpickles the file; only load trusted model files.
        state = torch.load(model_path, map_location="cpu")
        model = WindowLSTM(
            input_size=state["lstm.weight_ih_l0"].shape[1],
            hidden_size=state["lstm.weight_hh_l0"].shape[1],
            bidirectional="lstm.weight_ih_l0_reverse" in state,
            output_size=state["fc.weight"].shape[0],
        )
        model.load_state_dict(state)
        model.eval()
        return model

    @staticmethod
    def _interpolate_missing(rows):
        """Linearly interpolate NaNs per feature column over the buffered rows.

        Columns that are entirely NaN are left unchanged. Returns a list of lists.
        """
        arr = np.array(rows, dtype=np.float32)
        for c in range(arr.shape[1]):
            col = arr[:, c]  # view into arr
            valid = ~np.isnan(col)
            if valid.any():
                col[~valid] = np.interp(
                    np.flatnonzero(~valid),
                    np.flatnonzero(valid),
                    col[valid],
                )
        return arr.tolist()

    def _run_internal(self):
        """Load model and CSV, slide the window over every row, and emit results."""
        model = self._load_model(self.model_path)

        df = pd.read_csv(self.csv_path)
        feature_cols = [c for c in df.columns if c.startswith("lm")]
        if len(feature_cols) != model.lstm.input_size:
            raise ValueError(
                f"CSV has {len(feature_cols)} landmark columns but the model expects "
                f"{model.lstm.input_size}."
            )
        X_with_nans = df[feature_cols].values
        if "time_sec" in df.columns:
            time_axis = df["time_sec"].values
        else:
            time_axis = np.arange(len(X_with_nans)) / self.fps

        # Sliding-window inference, same buffering scheme as the live video test.
        window_buffer = []
        preds_active = []
        preds_before_contact = []

        for i, features in enumerate(X_with_nans):
            # Pad the output until the first full window exists.
            if i < self.window_frames - 1:
                preds_active.append(0)
                preds_before_contact.append(0)

            window_buffer.append(features)
            if len(window_buffer) > self.window_frames:
                window_buffer.pop(0)

            # Fill frames where the hand was lost; values are written back to the buffer.
            window_buffer = self._interpolate_missing(window_buffer)

            if len(window_buffer) == self.window_frames:
                x = torch.tensor(window_buffer, dtype=torch.float32).unsqueeze(0)
                with torch.no_grad():
                    probs = torch.sigmoid(model(x)).squeeze(0).numpy()
                preds_active.append(probs[0])
                preds_before_contact.append(probs[1])

        self.finished.emit({
            "time": time_axis,
            "active": preds_active,
            "before": preds_before_contact,
            "threshold": self.threshold,
        })
|
||
|
||
|
||
|
||
class TerminalWindow(QWidget):
    """Small in-app command console.

    Each submitted line is echoed as ``> text``. The first word selects a
    handler from ``self.commands`` and the remaining words are passed to it as
    positional arguments. Truthy return values are printed. Exceptions are
    shown as ``[Error] ...`` and unknown names as ``[Unknown command] ...``.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("Terminal - SPARKS")

        self.output_area = QTextEdit()
        self.output_area.setReadOnly(True)

        self.input_line = QLineEdit()
        self.input_line.returnPressed.connect(self.handle_command)

        layout = QVBoxLayout()
        for widget in (self.output_area, self.input_line):
            layout.addWidget(widget)
        self.setLayout(layout)

        # Command name -> handler; handlers receive the remaining words as *args.
        self.commands = {"hello": self.cmd_hello, "help": self.cmd_help}

    def handle_command(self):
        """Parse the input line, dispatch it and print the outcome."""
        raw = self.input_line.text()
        self.input_line.clear()
        self.output_area.append(f"> {raw}")

        tokens = raw.strip().split()
        if not tokens:
            return
        name, *arguments = tokens

        handler = self.commands.get(name)
        if not handler:
            self.output_area.append(f"[Unknown command] '{name}'")
            return

        try:
            reply = handler(*arguments)
            if reply:
                self.output_area.append(str(reply))
        except Exception as exc:
            self.output_area.append(f"[Error] {exc}")

    def cmd_hello(self, *args):
        """Reply with a fixed greeting."""
        return "Hello from the terminal!"

    def cmd_help(self, *args):
        """List the registered command names."""
        return f"Available commands: {', '.join(self.commands.keys())}"
|
||
|
||
|
||
class UpdateDownloadThread(QThread):
    """Download an update archive, unpack it and report the result.

    On macOS the archive is saved to and unpacked in ``/tmp/sparkstempupdate``
    with ``ditto``. On other platforms the current working directory and
    ``zipfile`` are used. The archive is deleted after a successful extraction.

    Signals:
        update_ready(str, str): (latest_version, path of the extracted folder).
        error_occurred(str): Description of any failure.

    Args:
        download_url (str): URL of the update zip file to download.
        latest_version (str): Version string of the latest update.
    """

    update_ready = Signal(str, str)
    error_occurred = Signal(str)

    def __init__(self, download_url, latest_version):
        super().__init__()
        self.download_url = download_url
        self.latest_version = latest_version

    def run(self):
        is_mac = PLATFORM_NAME == 'darwin'
        try:
            archive_name = os.path.basename(self.download_url)
            if is_mac:
                target_dir = '/tmp/sparkstempupdate'
                os.makedirs(target_dir, exist_ok=True)
            else:
                target_dir = os.getcwd()
            archive_path = os.path.join(target_dir, archive_name)

            # Stream the archive to disk in 8 KiB chunks.
            with requests.get(self.download_url, stream=True, timeout=15) as response:
                response.raise_for_status()
                with open(archive_path, 'wb') as sink:
                    for piece in response.iter_content(chunk_size=8192):
                        if piece:
                            sink.write(piece)

            # Unpack into a folder named after the archive, minus its extension.
            unpack_dir = os.path.join(target_dir, os.path.splitext(archive_name)[0])
            os.makedirs(unpack_dir, exist_ok=True)
            if is_mac:
                subprocess.run(['ditto', '-xk', archive_path, unpack_dir], check=True)
            else:
                with zipfile.ZipFile(archive_path, 'r') as archive:
                    archive.extractall(unpack_dir)

            # The archive is no longer needed once unpacked.
            os.remove(archive_path)
            self.update_ready.emit(self.latest_version, unpack_dir)
        except Exception as e:
            self.error_occurred.emit(str(e))
|
||
|
||
|
||
|
||
class UpdateCheckThread(QThread):
    """
    Thread that checks for updates by querying the API and emits a signal based on the result.

    Signals:
        download_requested(str, str): Emitted with (download_url, latest_version) when an update is available.
        no_update_available(): Emitted when no update is found or current version is up to date.
        error_occurred(str): Emitted with an error message if the update check fails.
    """

    download_requested = Signal(str, str)
    no_update_available = Signal()
    error_occurred = Signal(str)

    def run(self):
        if not getattr(sys, 'frozen', False):
            self.error_occurred.emit("Application is not frozen (Development mode).")
            return
        try:
            latest_version, download_url = self.get_latest_release_for_platform()
            if not latest_version:
                self.no_update_available.emit()
                return

            if not download_url:
                self.error_occurred.emit(f"No download available for platform '{PLATFORM_NAME}'")
                return

            if self.version_compare(latest_version, CURRENT_VERSION) > 0:
                self.download_requested.emit(download_url, latest_version)
            else:
                self.no_update_available.emit()

        except Exception as e:
            self.error_occurred.emit(f"Update check failed: {e}")

    def version_compare(self, v1, v2):
        """Return 1, 0 or -1 as dotted numeric version ``v1`` is >, == or < ``v2``."""
        def normalize(v): return [int(x) for x in v.split(".")]
        a, b = normalize(v1), normalize(v2)
        return (a > b) - (a < b)

    def get_latest_release_for_platform(self):
        """Query the release API, falling back to the mirror when the primary fails.

        Returns:
            tuple: ``(version, download_url)``. It is ``(None, None)`` when no
            release exists or every endpoint failed, and ``(version, None)``
            when the latest release has no asset for this platform.
        """
        for url in (API_URL, API_URL_SECONDARY):
            try:
                # Request *this* endpoint; the loop previously always hit API_URL.
                response = requests.get(url, timeout=5)
                response.raise_for_status()
                releases = response.json()
            except (requests.RequestException, ValueError):
                # Network/HTTP error or invalid JSON: try the next endpoint.
                continue

            if not releases:
                return None, None

            latest = releases[0]
            tag = latest["tag_name"].lstrip("v")
            for asset in latest.get("assets", []):
                if PLATFORM_NAME in asset["name"].lower():
                    return tag, asset["browser_download_url"]
            return tag, None

        return None, None
|
||
|
||
|
||
class LocalPendingUpdateCheckThread(QThread):
    """
    Thread that checks for locally pending updates by scanning the download directory and emits a signal accordingly.

    A pending update is a folder named ``<anything>-X.Y.Z<platform_suffix>``
    whose version is newer than ``current_version``. When several qualify, the
    newest one is reported.

    Signals:
        pending_update_found(str, str): (version, folder path) of the newest pending update.
        no_pending_update(): No newer folder exists, or the directory could not be read.

    Args:
        current_version (str): Current application version.
        platform_suffix (str): Platform-specific suffix to identify update folders.
    """

    pending_update_found = Signal(str, str)
    no_pending_update = Signal()

    def __init__(self, current_version, platform_suffix):
        super().__init__()
        self.current_version = current_version
        self.platform_suffix = platform_suffix

    def version_compare(self, v1, v2):
        """Return 1, 0 or -1 as dotted numeric version ``v1`` is >, == or < ``v2``."""
        def normalize(v): return [int(x) for x in v.split(".")]
        return (normalize(v1) > normalize(v2)) - (normalize(v1) < normalize(v2))

    def run(self):
        if PLATFORM_NAME == 'darwin':
            cwd = '/tmp/sparkstempupdate'
        else:
            cwd = os.getcwd()

        pattern = re.compile(r".*-(\d+\.\d+\.\d+)" + re.escape(self.platform_suffix) + r"$")

        try:
            entries = os.listdir(cwd)
        except OSError:
            # A missing or unreadable update directory just means "nothing pending".
            entries = []

        best_version = None
        best_path = None
        for item in entries:
            folder_path = os.path.join(cwd, item)
            if not (os.path.isdir(folder_path) and item.endswith(self.platform_suffix)):
                continue
            match = pattern.match(item)
            if not match:
                continue
            folder_version = match.group(1)
            if self.version_compare(folder_version, self.current_version) <= 0:
                continue
            # listdir order is arbitrary, so keep the newest candidate explicitly.
            if best_version is None or self.version_compare(folder_version, best_version) > 0:
                best_version, best_path = folder_version, folder_path

        if best_version is not None:
            self.pending_update_found.emit(best_version, best_path)
        else:
            self.no_pending_update.emit()
|
||
|
||
|
||
|
||
|
||
|
||
class ProgressDialog(QDialog):
    """Application-modal dialog showing file-loading progress with a Cancel button.

    The bar starts indeterminate (maximum 0). The first ``update_progress``
    call fixes its maximum to the reported total.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Processing Files...")
        self.setWindowModality(Qt.WindowModality.ApplicationModal)
        self.setMinimumWidth(300)

        self.message_label = QLabel("Loading and analyzing video files...")
        self.progress_bar = QProgressBar()
        self.progress_bar.setMaximum(0)  # a 0..0 range renders as a busy indicator
        self.cancel_button = QPushButton("Cancel")

        layout = QVBoxLayout(self)
        for widget in (self.message_label, self.progress_bar, self.cancel_button):
            layout.addWidget(widget)

    def update_progress(self, current, total):
        """Show ``current`` of ``total`` files processed."""
        still_indeterminate = self.progress_bar.maximum() == 0
        if still_indeterminate:
            self.progress_bar.setMaximum(total)
        self.progress_bar.setValue(current)
        self.message_label.setText(f"Processing file {current} of {total}...")
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
class FileLoadWorker(QThread):
    """Load a BORIS project file and pre-compute a first-frame preview per video.

    For every observation/camera entry, the referenced video is resolved
    relative to ``observations_root``. If the file exists, frame 0 is passed to
    ``extract_frame_and_hands`` and the result is stored in ``previews_data``
    keyed by ``(obs_id, cam_id)``.

    Signals:
        observations_loaded(dict, dict, str): (BORIS project, observations, root dir).
        progress_update(int, int): (entries processed, total entries).
        loading_finished(): Emitted whenever ``run`` ends, success or not.
        loading_failed(str): Error description.
    """

    observations_loaded = Signal(dict, dict, str)
    progress_update = Signal(int, int)
    loading_finished = Signal()
    loading_failed = Signal(str)

    def __init__(self, file_path, observations_root, extract_frame_and_hands_func):
        super().__init__()
        self.file_path = file_path
        self.observations_root = observations_root
        self.extract_frame_and_hands = extract_frame_and_hands_func
        self.is_running = True
        self.previews_data = {}  # (obs_id, cam_id) -> preview dict

    def resolve_video_path(self, relative_video_path):
        """Logic to handle flexible path resolution (basename or full join)."""
        # Prefer the bare file name inside the root; fall back to the full relative path.
        corrected_file_name = os.path.basename(relative_video_path)
        video_path = os.path.join(self.observations_root, corrected_file_name)
        if not os.path.exists(video_path):
            video_path = os.path.join(self.observations_root, relative_video_path)
        return video_path

    def run(self):
        try:
            # 1. Load the BORIS JSON (non-UI work). JSON is UTF-8; don't depend
            #    on the platform's default encoding.
            with open(self.file_path, "r", encoding="utf-8") as f:
                boris = json.load(f)

            observations = boris.get("observations", {})
            if not observations:
                self.loading_failed.emit("No observations found in BORIS JSON.")
                return

            # 2. Gather the initial previews (the heavy part).
            self.process_previews(observations)

            # 3. Emit the final results back.
            self.observations_loaded.emit(boris, observations, self.observations_root)

        except Exception as e:
            self.loading_failed.emit(f"Loading failed: {str(e)}")
        finally:
            self.loading_finished.emit()  # lets the progress dialog close on every path

    def process_previews(self, observations):
        """Extract a preview for every camera entry, emitting progress as it goes.

        Returns early (leaving partial results) when ``stop()`` has been called.
        """
        # A "file" key present with a null value is treated as empty.
        total_files = sum(len(obs_data.get("file") or {}) for obs_data in observations.values())
        current_index = 0

        for obs_id, obs_data in observations.items():
            files_dict = obs_data.get("file") or {}
            fps_dict = obs_data.get("media_info", {}).get("fps", {})
            # The first recorded FPS is used for the whole observation (30 if none).
            fps_default = float(next(iter(fps_dict.values()))) if fps_dict else 30.0

            for cam_id, paths in files_dict.items():
                if not self.is_running:
                    return  # cancelled via stop()
                if paths:
                    self._process_entry(obs_id, cam_id, paths[0], fps_default)
                else:
                    print(f"Skipping entry for {obs_id}/{cam_id}: paths list is empty (IndexError prevented).")
                current_index += 1
                self.progress_update.emit(current_index, total_files)

    def _process_entry(self, obs_id, cam_id, relative_video_path, fps):
        """Store the frame-0 preview for one camera entry if its video exists."""
        video_path = self.resolve_video_path(relative_video_path)
        if not os.path.exists(video_path):
            return

        # HEAVY OPERATION
        frame_rgb, results = self.extract_frame_and_hands(video_path, 0)

        initial_wrists = []
        if results and results.multi_hand_landmarks:
            # Landmark 0 of each detected hand is used as the wrist position.
            initial_wrists = [(lm.landmark[0].x, lm.landmark[0].y) for lm in results.multi_hand_landmarks]

        # Store what is needed to build the UI later.
        self.previews_data[(obs_id, cam_id)] = {
            "frame_rgb": frame_rgb,
            "results": results,
            "video_path": video_path,
            "fps": fps,
            "initial_wrists": initial_wrists,
        }

    def stop(self):
        """Request cancellation; checked before each camera entry."""
        self.is_running = False
|
||
|
||
|
||
|
||
|
||
|
||
|
||
class IndividualFileLoadWorker(QThread):
    """Compute the frame-0 preview for a single video file in the background.

    The preview is stored in ``previews_data`` under the fixed key ``(1, 1)``.

    Signals:
        observations_loaded(): The preview was computed.
        loading_finished(): Emitted whenever ``run`` ends, success or not.
        loading_failed(str): Error description (including a missing file).
    """

    observations_loaded = Signal()
    loading_finished = Signal()
    loading_failed = Signal(str)

    def __init__(self, file_path, extract_frame_and_hands_func):
        super().__init__()
        self.file_path = file_path
        self.extract_frame_and_hands = extract_frame_and_hands_func
        self.is_running = True
        self.previews_data = {}  # (1, 1) -> preview dict

    def run(self):
        try:
            # Gather the initial preview (the heavy part).
            self.process_previews()
            self.observations_loaded.emit()
        except Exception as e:
            self.loading_failed.emit(f"Loading failed: {str(e)}")
        finally:
            self.loading_finished.emit()  # lets the progress dialog close on every path

    def process_previews(self):
        """Extract frame 0 of ``file_path`` and store it under key ``(1, 1)``.

        Raises:
            FileNotFoundError: if ``file_path`` does not exist. ``run`` reports
                this through ``loading_failed`` instead of silently succeeding
                with no preview.
        """
        if not os.path.exists(self.file_path):
            raise FileNotFoundError(f"Video file not found: {self.file_path}")

        # HEAVY OPERATION
        frame_rgb, results = self.extract_frame_and_hands(self.file_path, 0)

        initial_wrists = []
        if results and results.hand_landmarks:
            # Each entry of ``hand_landmarks`` is a list; index 0 is used as the wrist.
            for hand_landmarks in results.hand_landmarks:
                wrist = hand_landmarks[0]
                initial_wrists.append((wrist.x, wrist.y))

        # Store what is needed to build the UI later.
        self.previews_data[(1, 1)] = {
            "frame_rgb": frame_rgb,
            "results": results,
            "video_path": self.file_path,
            "fps": 60,
            "initial_wrists": initial_wrists,
        }

    def stop(self):
        """Set ``is_running`` to False (not consulted by ``process_previews``)."""
        self.is_running = False
|
||
|
||
|
||
|
||
class AboutWindow(QWidget):
    """
    Simple About window displaying basic application information.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("About SPARKS")
        self.resize(250, 100)

        layout = QVBoxLayout()
        label = QLabel("About SPARKS", self)
        # "Spatial" was misspelled as "Spacial" in the expanded product name.
        label2 = QLabel("Spatial Patterns Analysis, Research, & Knowledge Suite", self)
        label3 = QLabel("SPARKS is licensed under the GPL-3.0 licence. For more information, visit https://www.gnu.org/licenses/gpl-3.0.en.html", self)
        # Without wrapping this long line forces a very wide window.
        label3.setWordWrap(True)
        label4 = QLabel(f"Version v{CURRENT_VERSION}")

        layout.addWidget(label)
        layout.addWidget(label2)
        layout.addWidget(label3)
        layout.addWidget(label4)

        self.setLayout(layout)
|
||
|
||
|
||
|
||
class UserGuideWindow(QWidget):
    """
    Top-level window reserved for usage instructions.

    It currently shows only the placeholder text "Not Applicable:" / "N/A".

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle("User Guide - SPARKS")
        self.resize(250, 100)

        column = QVBoxLayout()
        for text in ("Not Applicable:", "N/A"):
            column.addWidget(QLabel(text, self))

        self.setLayout(column)
|
||
|
||
|
||
|
||
class ParamSection(QWidget):
    """
    A widget section that dynamically creates labeled input fields from parameter metadata.

    Args:
        section_data (dict): Dictionary containing section title and list of parameter info.
            Expected format:
            {
                "title": str,
                "params": [
                    {
                        "name": str,
                        "type": type,
                        "default": any,
                        "help": str (optional)
                    },
                    ...
                ]
            }

    Input widgets are chosen by each parameter's declared type:
    - ``bool``: a True/False combo box.
    - ``int`` / ``float``: a validated line edit.
    - ``list``: a multi-select dropdown, currently created with no options.
    - Anything else: a plain line edit.

    Values are read back with get_param_values().
    """

    def __init__(self, section_data):
        super().__init__()
        layout = QVBoxLayout()
        self.setLayout(layout)
        # Parameter name -> {"widget": input widget, "type": declared type}.
        self.widgets = {}
        self.selected_path = None
        # Re-entrancy guard for the multi-select itemChanged handler: its own
        # setCheckState calls re-emit itemChanged.
        self._updating_checkstates = False

        # Title label
        title_label = QLabel(section_data["title"])
        title_label.setStyleSheet("font-weight: bold; font-size: 14px; margin-top: 10px; margin-bottom: 5px;")
        layout.addWidget(title_label)

        # Horizontal line
        line = QFrame()
        line.setFrameShape(QFrame.Shape.HLine)
        line.setFrameShadow(QFrame.Shadow.Sunken)
        layout.addWidget(line)

        for param in section_data["params"]:
            layout.addLayout(self._build_param_row(param))

    def _build_param_row(self, param):
        """Build the "[?] [label] [input]" row for one parameter and register its input widget."""
        help_text = param.get("help", "")
        h_layout = QHBoxLayout()

        label = QLabel(param["name"])
        label.setFixedWidth(180)
        label.setToolTip(help_text)

        help_btn = QPushButton("?")
        help_btn.setFixedWidth(25)
        help_btn.setToolTip(help_text)
        # Bind help_text as a default argument so each button keeps its own text.
        help_btn.clicked.connect(lambda _, text=help_text: self.show_help_popup(text))

        # Stretch ratio button : label : input = 1 : 3 : 5.
        h_layout.addWidget(help_btn)
        h_layout.setStretch(0, 1)
        h_layout.addWidget(label)
        h_layout.setStretch(1, 3)

        widget = self._create_input_widget(param)
        widget.setToolTip(help_text)
        h_layout.addWidget(widget)
        h_layout.setStretch(2, 5)

        self.widgets[param["name"]] = {
            "widget": widget,
            "type": param["type"]
        }
        return h_layout

    def _create_input_widget(self, param):
        """Return the input widget for one parameter, chosen by its declared type."""
        param_type = param["type"]
        if param_type is bool:
            widget = QComboBox()
            widget.addItems(["True", "False"])
            widget.setCurrentText(str(param["default"]))
        elif param_type is int:
            widget = QLineEdit()
            widget.setValidator(QIntValidator())
            widget.setText(str(param["default"]))
        elif param_type is float:
            widget = QLineEdit()
            widget.setValidator(QDoubleValidator())
            widget.setText(str(param["default"]))
        elif param_type is list:
            widget = self._create_multiselect_dropdown(None)
        else:
            widget = QLineEdit()
            widget.setText(str(param["default"]))
        return widget

    def _create_multiselect_dropdown(self, items):
        """Create a read-only combo box whose rows are check boxes.

        Row 0 is a non-checkable "<None Selected>" placeholder.
        Row 1 is "Toggle Select All".
        The remaining rows are the option strings in ``items`` (None means no options).
        """
        combo = FullClickComboBox()
        combo.setView(QListView())
        model = QStandardItemModel()
        combo.setModel(model)
        combo.setEditable(True)
        combo.lineEdit().setReadOnly(True)
        combo.lineEdit().setPlaceholderText("Select...")
        combo.setInsertPolicy(QComboBox.NoInsert)

        dummy_item = QStandardItem("<None Selected>")
        dummy_item.setFlags(Qt.ItemIsEnabled)
        model.appendRow(dummy_item)

        toggle_item = QStandardItem("Toggle Select All")
        toggle_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled)
        toggle_item.setData(Qt.Unchecked, Qt.CheckStateRole)
        model.appendRow(toggle_item)

        if items is not None:
            for item in items:
                standard_item = QStandardItem(item)
                standard_item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled)
                standard_item.setData(Qt.Unchecked, Qt.CheckStateRole)
                model.appendRow(standard_item)

        def on_view_clicked(index):
            # Flip the check state on press anywhere in the row, not only on the box.
            item = model.itemFromIndex(index)
            if item.isCheckable():
                new_state = Qt.Checked if item.checkState() == Qt.Unchecked else Qt.Unchecked
                item.setCheckState(new_state)

        combo.view().pressed.connect(on_view_clicked)

        self._updating_checkstates = False

        def on_item_changed(item):
            if self._updating_checkstates:
                return
            self._updating_checkstates = True
            try:
                normal_items = [model.item(row) for row in range(2, model.rowCount())]  # skip dummy and toggle

                if item == toggle_item:
                    # If everything is already checked, clear all; otherwise check all.
                    all_checked = all(i.checkState() == Qt.Checked for i in normal_items)
                    new_state = Qt.Unchecked if all_checked else Qt.Checked
                    for i in normal_items:
                        i.setCheckState(new_state)
                    toggle_item.setCheckState(new_state)
                elif item == dummy_item:
                    pass
                else:
                    # Keep the toggle row in sync with the individual options.
                    all_checked = all(i.checkState() == Qt.Checked for i in normal_items)
                    toggle_item.setCheckState(Qt.Checked if all_checked else Qt.Unchecked)
            finally:
                # Always release the guard, even if a Qt call above raises.
                self._updating_checkstates = False

            for param_name, info in self.widgets.items():
                if info["widget"] is combo:
                    self.update_dropdown_label(param_name)
                    break

        model.itemChanged.connect(on_item_changed)
        return combo

    def _checked_items(self, combo):
        """Return the texts of the checked option rows (rows 0 and 1 are excluded)."""
        model = combo.model()
        return [
            model.item(row).text()
            for row in range(2, model.rowCount())
            if model.item(row).checkState() == Qt.Checked
        ]

    def update_dropdown_label(self, param_name):
        """Show the checked options of a multi-select as comma-separated text.

        get_param_values() splits this text on commas, so an empty selection
        clears it (the "Select..." placeholder is shown and the value is []).
        Does nothing for unknown names or non-list parameters.
        """
        info = self.widgets.get(param_name)
        if info is None or info["type"] is not list:
            return
        combo = info["widget"]
        combo.lineEdit().setText(", ".join(self._checked_items(combo)))

    def show_help_popup(self, text):
        """Show ``text`` in a modal message box."""
        msg = QMessageBox(self)
        msg.setWindowTitle("Parameter Info - SPARKS")
        msg.setText(text)
        msg.exec()

    def get_param_values(self):
        """Return {parameter name: value converted to its declared type}.

        Raises:
            ValueError: if an int/float field's text cannot be converted.
        """
        values = {}
        for name, info in self.widgets.items():
            widget = info["widget"]
            expected_type = info["type"]

            if expected_type is bool:
                values[name] = widget.currentText() == "True"
            elif expected_type is list:
                values[name] = [x.strip() for x in widget.lineEdit().text().split(",") if x.strip()]
            else:
                raw_text = widget.text()
                try:
                    if expected_type is int:
                        values[name] = int(raw_text)
                    elif expected_type is float:
                        values[name] = float(raw_text)
                    else:
                        values[name] = raw_text  # str and any other type: keep the text
                except Exception as e:
                    raise ValueError(f"Invalid value for {name}: {raw_text}") from e

        return values
|
||
|
||
|
||
class FullClickLineEdit(QLineEdit):
    """Line edit that opens its parent combo box's popup on any mouse press.

    Used as the read-only editor of FullClickComboBox, so clicking the text
    area behaves like clicking the drop-down arrow.
    """

    def mousePressEvent(self, event):
        owner = self.parent()
        opens_popup = isinstance(owner, QComboBox)
        if opens_popup:
            owner.showPopup()
        # Let QLineEdit handle the press normally afterwards.
        super().mousePressEvent(event)
|
||
|
||
|
||
class FullClickComboBox(QComboBox):
    """Combo box with a read-only FullClickLineEdit as its editor.

    Clicking anywhere in the text area opens the popup.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        editor = FullClickLineEdit(self)
        self.setLineEdit(editor)
        editor.setReadOnly(True)
|
||
|
||
|
||
class ParticipantProcessor(QThread):
    """Extract per-frame hand landmarks for one BORIS observation and write them to CSV.

    For every frame of the selected camera's video, one row is written with:
    - the frame index and time;
    - whether the frame lies inside a trial window ("Trial Start" -> "End" + 1.0);
    - whether it lies inside the reach-before-contact window ("Trial Start" ->
      "Contact", or -> "End" + 1.0 when no contact was coded);
    - the camera offset;
    - x/y/z of the 21 landmarks of the tracked hand (NaN when it is not found).

    Signals:
        progress_updated(int): percentage of frames processed so far.
        frame_ready(QImage): annotated RGB preview frame (only when hide_preview is False).
        finished_processing(str): obs_id when the video has been processed, or an
            "Error: ..." message when the video file cannot be found. Not
            emitted when the run is cancelled via stop().
        time_updated(str): "MM:SS/MM:SS" position within the video.
    """

    progress_updated = Signal(int)
    frame_ready = Signal(QImage)
    finished_processing = Signal(str)
    time_updated = Signal(str)

    def __init__(self, obs_id, boris_json, selected_cam_id, selected_hand_idx,
                 hide_preview=False, output_csv=None, observations_root=None, output_dir=None, initial_wrists=None):
        super().__init__()
        self.obs_id = obs_id
        self.boris_json = boris_json
        self.selected_cam_id = selected_cam_id
        self.selected_hand_idx = selected_hand_idx
        self.hide_preview = hide_preview
        self.output_csv = output_csv
        self.observations_root = observations_root
        self.output_dir = output_dir
        # (x, y) wrist positions; hand i is the detection nearest initial_wrists[i].
        self.initial_wrists = initial_wrists

        # Legacy MediaPipe Solutions hand tracker (tracking mode, up to two hands).
        self.mp_hands = mp.solutions.hands.Hands(
            static_image_mode=False,
            max_num_hands=2,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
            model_complexity=0
        )
        self.is_running = True

    def run(self):
        try:
            outcome = self._process_video()
        finally:
            # Release the MediaPipe graph on every exit path: success, missing
            # video, cancellation or an unexpected exception.
            self.mp_hands.close()
        if outcome is not None:
            self.finished_processing.emit(outcome)

    def _process_video(self):
        """Process the selected video.

        Returns:
            The finished_processing payload, or None if the run was cancelled.
        """
        observations = self.boris_json.get("observations", {})
        obs_data = observations[self.obs_id]

        # ---------------- Camera / Video ----------------
        cameras_with_files = obs_data.get("file", {})
        media_info = obs_data.get("media_info", {})
        offset_dict = media_info.get("offset", {})

        relative_video_path = cameras_with_files[self.selected_cam_id][0]
        # Only the file name is kept: the video is looked up directly in observations_root.
        video_path = os.path.join(self.observations_root, os.path.basename(relative_video_path))
        if not os.path.exists(video_path):
            print(f"FATAL ERROR: Video file not found at: {video_path}")
            return f"Error: Video not found for {self.obs_id}"

        camera_offset = float(offset_dict.get(self.selected_cam_id, 0.0))
        fps_dict = media_info.get("fps", {})
        metadata_fps = float(next(iter(fps_dict.values()))) if fps_dict else 30.0

        # ---------------- Trial Segments ----------------
        segments, reach_before_contact_segments = self._parse_trial_segments(obs_data.get("events", []))
        # Shift the event windows into video time once instead of on every frame.
        shifted_segments = [(start - camera_offset, end - camera_offset) for start, end in segments]
        shifted_contact_segments = [
            (start - camera_offset, end - camera_offset) for start, end in reach_before_contact_segments
        ]

        # ---------------- CSV Header ----------------
        header = ["frame_index", "time_sec", "adjusted_time_sec", "reach_active", "reach_before_contact", "camera_offset"]
        for i in range(21):
            header += [f"lm{i}_x", f"lm{i}_y", f"lm{i}_z"]

        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:
                # Some containers report 0 FPS; fall back to BORIS metadata, then 30.
                fps = metadata_fps if metadata_fps > 0 else 30.0
            total_time_str = self._format_seconds(total_frames / fps)

            output_path = os.path.join(self.output_dir, self.output_csv)
            print(output_path)
            with open(output_path, "w", newline="") as csv_file:
                csv_writer = csv.writer(csv_file)
                csv_writer.writerow(header)

                # ---------------- Frame Processing ----------------
                frame_index = 0
                while True:
                    if not self.is_running:
                        print(f"Processor for {self.obs_id}/{self.selected_cam_id} was cancelled gracefully.")
                        return None

                    ret, frame = cap.read()
                    if not ret:
                        break

                    current_time = frame_index / fps
                    adjusted_time = current_time  # placeholder for any future time correction

                    is_reach_active = any(start <= current_time <= end for start, end in shifted_segments)
                    is_reach_before_contact = any(
                        start <= current_time <= end for start, end in shifted_contact_segments
                    )

                    row = [
                        frame_index, current_time, adjusted_time,
                        int(is_reach_active), int(is_reach_before_contact),
                        camera_offset
                    ]

                    # ---------------- MediaPipe Hands ----------------
                    results = self.mp_hands.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    hand_landmarks = self._select_hand(results)
                    if hand_landmarks is not None:
                        for lm in hand_landmarks.landmark:
                            row += [lm.x, lm.y, lm.z]
                    else:
                        row += [math.nan] * (21 * 3)

                    csv_writer.writerow(row)

                    # ---------------- Preview ----------------
                    if not self.hide_preview:
                        self._emit_preview(frame, hand_landmarks)

                    # ---------------- Progress ----------------
                    progress = int((frame_index / total_frames) * 100) if total_frames > 0 else 0
                    self.progress_updated.emit(progress)
                    self.time_updated.emit(f"{self._format_seconds(current_time)}/{total_time_str}")

                    frame_index += 1
        finally:
            cap.release()

        return self.obs_id

    @staticmethod
    def _parse_trial_segments(raw_events, extra_duration=1.0):
        """Turn BORIS events into (start, end) windows.

        Only events whose third field is "Trial Start", "Contact" or "End" are
        used; entries shorter than 3 fields are skipped.

        Returns:
            (segments, reach_before_contact_segments):
            - segments: one ("Trial Start", "End" + extra_duration) window per trial.
            - reach_before_contact_segments: ("Trial Start", "Contact") for every
              Contact, or the trial window itself when the trial had no Contact.
        """
        segments = []
        reach_before_contact_segments = []
        current_start = None
        current_contact = None

        for ev in raw_events:
            if len(ev) < 3:
                continue
            time_sec = ev[0]
            name = ev[2]
            if name == "Trial Start":
                current_start = time_sec
                current_contact = None
            elif name == "Contact" and current_start is not None:
                current_contact = time_sec
                reach_before_contact_segments.append((current_start, current_contact))
            elif name == "End" and current_start is not None:
                adjusted_end = time_sec + extra_duration
                segments.append((current_start, adjusted_end))
                if current_contact is None:
                    reach_before_contact_segments.append((current_start, adjusted_end))
                current_start = None

        return segments, reach_before_contact_segments

    def _select_hand(self, results):
        """Return the landmarks of the hand being tracked in this frame, or None.

        With initial_wrists, each initial wrist is greedily paired with the
        nearest still-unused detection by squared wrist distance. Without them,
        detection order is used. selected_hand_idx then picks the pair.
        """
        if not (results.multi_hand_landmarks and results.multi_handedness):
            return None

        detected = results.multi_hand_landmarks
        if self.initial_wrists:
            hand_mapping = {}
            used_indices = set()
            for i, iw in enumerate(self.initial_wrists):
                min_dist = float('inf')
                best_idx = None
                for j, lm in enumerate(detected):
                    if j in used_indices:
                        continue
                    wrist = lm.landmark[0]
                    dist = (wrist.x - iw[0]) ** 2 + (wrist.y - iw[1]) ** 2
                    if dist < min_dist:
                        min_dist = dist
                        best_idx = j
                if best_idx is not None:
                    hand_mapping[i] = best_idx
                    used_indices.add(best_idx)
        else:
            hand_mapping = {i: i for i in range(len(detected))}

        mapped_idx = hand_mapping.get(self.selected_hand_idx)
        return None if mapped_idx is None else detected[mapped_idx]

    def _emit_preview(self, frame, hand_landmarks):
        """Emit an RGB QImage of ``frame`` (BGR) with ``hand_landmarks`` drawn, if any."""
        display_frame = frame.copy()
        if hand_landmarks is not None:
            mp.solutions.drawing_utils.draw_landmarks(
                display_frame,
                hand_landmarks,
                mp.solutions.hands.HAND_CONNECTIONS
            )
        # OpenCV frames are BGR, while QImage.Format_RGB888 expects RGB.
        rgb = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
        h, w, _ = rgb.shape
        # copy(): a QImage built on a raw buffer does not own it, and this
        # numpy array is released before the queued signal reaches the GUI thread.
        qimg = QImage(rgb.data, w, h, 3 * w, QImage.Format_RGB888).copy()
        self.frame_ready.emit(qimg)

    def _format_seconds(self, seconds):
        """Format ``seconds`` as "MM:SS"; negative values give "00:00"."""
        if seconds < 0:
            return "00:00"
        minutes = math.floor(seconds / 60)
        secs = math.floor(seconds % 60)
        return f"{minutes:02d}:{secs:02d}"

    def stop(self):
        """Sets the flag to stop the thread's execution loop."""
        self.is_running = False
|
||
|
||
|
||
|
||
class ParticipantProcessor2(QThread):
    """Track hands through a video with MediaPipe's HandLandmarker and save 2-D landmarks to CSV.

    Each frame goes through the Tasks-API HandLandmarker (up to two hands).
    Detections receive stable integer ids by nearest-centroid matching against
    the last known position of each id. Detections that match no id get the
    lowest unused id. The tracker is seeded from ``initial_wrists``, so id ``i``
    starts as the hand nearest ``initial_wrists[i]``.

    Args:
        obs_id: Observation identifier, echoed in finished_processing.
        selected_cam_id: Camera identifier, echoed in finished_processing.
        selected_hand_idx: Hand id to export, or 99 to export ids 0 and 1 side by side.
        video_path: Video file to process.
        output_csv: CSV file name, joined onto output_dir.
        output_dir: Directory for the CSV.
        initial_wrists: Optional list of (x, y) starting positions, one per hand id.
        **kwargs: Accepted and ignored.

    Signals:
        progress_updated(int): percent of frames processed (every 10th frame).
        time_updated(str): "MM:SS/MM:SS" position (every 10th frame).
        finished_processing(str, str): (obs_id, cam_id) after the loop ends,
            whether the video finished or the run was cancelled.
    """

    progress_updated = Signal(int)
    time_updated = Signal(str)
    finished_processing = Signal(str, str)  # obs_id, cam_id

    # selected_hand_idx value that requests both hand 0 and hand 1 in each row.
    DUAL_HAND_MODE = 99
    # Maximum squared (normalized) centroid distance for a detection to keep an id.
    MATCH_THRESHOLD = 0.1

    def __init__(self, obs_id, selected_cam_id, selected_hand_idx,
                 video_path, output_csv, output_dir, initial_wrists, **kwargs):
        super().__init__()
        self.obs_id = obs_id
        self.cam_id = selected_cam_id
        self.selected_hand_idx = selected_hand_idx
        self.video_path = video_path
        self.output_dir = output_dir
        self.output_csv = output_csv
        self.is_running = True
        # Set by update_tracking() to the detection index of the selected hand.
        self.current_detection_idx = None

        # Convert initial_wrists (list) to {hand_id: (x, y)} to seed the tracker.
        self.current_centroids = {}
        if initial_wrists:
            for i, pos in enumerate(initial_wrists):
                self.current_centroids[i] = pos

    def get_centroid(self, lm_list):
        """Return the mean (x, y) of a list of landmarks."""
        avg_x = sum(lm.x for lm in lm_list) / len(lm_list)
        avg_y = sum(lm.y for lm in lm_list) / len(lm_list)
        return (avg_x, avg_y)

    def update_tracking(self, results, last_known):
        """Match detections in ``results`` to the ids in ``last_known``.

        Returns:
            A new {hand_id: centroid} dict. Ids with no match within
            MATCH_THRESHOLD keep their previous position.
        """
        if not results or not results.hand_landmarks:
            return last_known

        detected_hands = results.hand_landmarks
        new_centroids = {}
        used_indices = set()

        # Priority 1: Match existing IDs
        for hand_id, last_pos in last_known.items():
            min_dist = float('inf')
            best_idx = None
            for j, lm_list in enumerate(detected_hands):
                if j in used_indices:
                    continue
                curr_c = self.get_centroid(lm_list)
                dist = (curr_c[0] - last_pos[0]) ** 2 + (curr_c[1] - last_pos[1]) ** 2
                if dist < self.MATCH_THRESHOLD and dist < min_dist:
                    min_dist = dist
                    best_idx = j

            if best_idx is not None:
                new_centroids[hand_id] = self.get_centroid(detected_hands[best_idx])
                used_indices.add(best_idx)
                # Keep track of which detection index corresponds to our ID
                if hand_id == self.selected_hand_idx:
                    self.current_detection_idx = best_idx

        # Carry over ghosts for lost hands
        for hand_id, pos in last_known.items():
            if hand_id not in new_centroids:
                new_centroids[hand_id] = pos

        return new_centroids

    def run(self):
        # Initialize Mediapipe INSIDE the thread
        base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
        options = vision.HandLandmarkerOptions(
            base_options=base_options,
            running_mode=vision.RunningMode.IMAGE,
            num_hands=2,
            min_hand_detection_confidence=0.3
        )
        detector = vision.HandLandmarker.create_from_options(options)
        cap = cv2.VideoCapture(self.video_path)
        try:
            self._process_frames(detector, cap)
        finally:
            # Release native resources even if processing raises.
            cap.release()
            print("Released")
            detector.close()
            print("Closed")

        self.finished_processing.emit(str(self.obs_id), str(self.cam_id))

    def _process_frames(self, detector, cap):
        """Detect, track and write one CSV row per frame until the video ends or cancel() is called."""
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        dual = self.selected_hand_idx == self.DUAL_HAND_MODE
        save_path = os.path.join(self.output_dir, self.output_csv)

        with open(save_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(self._csv_header(dual))

            # Seed the tracker with the initial positions so that hand ids follow
            # initial_wrists rather than MediaPipe's detection order.
            # NOTE(review): the seeds are wrist points while tracking compares
            # centroids; the first match relies on MATCH_THRESHOLD absorbing that offset.
            last_known_centroids = dict(self.current_centroids)

            while cap.isOpened() and self.is_running:
                ret, frame = cap.read()
                if not ret:
                    break

                mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                results = detector.detect(mp_image)
                # POS_FRAMES is read after cap.read(), so the first row is numbered 1.
                frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))

                current_mapping = {}
                if results and results.hand_landmarks:
                    current_mapping = self._assign_hand_ids(results.hand_landmarks, last_known_centroids)

                writer.writerow(self._build_row(frame_idx, current_mapping, dual))

                if frame_idx % 10 == 0:
                    self._emit_progress(frame_idx, total_frames, fps)

    @staticmethod
    def _csv_header(dual):
        """Return the CSV header: frame + x/y per landmark (h0_/h1_ prefixed in dual mode)."""
        header = ['frame']
        if dual:
            # Dual hand columns: h0_x0, h0_y0 ... h1_x20, h1_y20
            for h_prefix in ('h0', 'h1'):
                for i in range(21):
                    header.extend([f'{h_prefix}_x{i}', f'{h_prefix}_y{i}'])
        else:
            # Single hand columns: x0, y0 ... x20, y20
            for i in range(21):
                header.extend([f'x{i}', f'y{i}'])
        return header

    def _assign_hand_ids(self, hand_landmarks_list, last_known_centroids):
        """Give every detection a hand id and update ``last_known_centroids`` in place.

        Returns:
            {hand_id: landmark list} for this frame.
        """
        temp_mapping = {}  # {hand_id: index_in_landmarks_list}
        used_indices = set()

        # --- STEP 1: TRACKING (match each known id to its nearest unused detection) ---
        for hand_id, last_pos in last_known_centroids.items():
            min_dist = float('inf')
            best_idx = None
            for j, lm_list in enumerate(hand_landmarks_list):
                if j in used_indices:
                    continue
                c = self.get_centroid(lm_list)
                dist = (c[0] - last_pos[0]) ** 2 + (c[1] - last_pos[1]) ** 2
                if dist < min_dist:
                    min_dist = dist
                    best_idx = j
            if best_idx is not None and min_dist < self.MATCH_THRESHOLD:
                temp_mapping[hand_id] = best_idx
                used_indices.add(best_idx)

        # --- STEP 2: DISCOVERY (unmatched detections get the lowest free id) ---
        for j in range(len(hand_landmarks_list)):
            if j in used_indices:
                continue
            new_id = 0
            while new_id in temp_mapping or new_id in last_known_centroids:
                new_id += 1
            temp_mapping[new_id] = j
            used_indices.add(j)

        # --- STEP 3: UPDATE MEMORY ---
        current_mapping = {}
        for hand_id, idx in temp_mapping.items():
            landmarks = hand_landmarks_list[idx]
            current_mapping[hand_id] = landmarks
            last_known_centroids[hand_id] = self.get_centroid(landmarks)
        return current_mapping

    def _build_row(self, frame_idx, current_mapping, dual):
        """Return [frame_idx, x0, y0, ...] for the exported hand id(s); a missing hand is 42 zeros."""
        row = [frame_idx]
        hand_ids = (0, 1) if dual else (self.selected_hand_idx,)
        for hand_id in hand_ids:
            landmarks = current_mapping.get(hand_id)
            if landmarks is None:
                row.extend([0.0] * 42)  # Zero-pad if this specific ID is missing
            else:
                for lm in landmarks:
                    row.extend([lm.x, lm.y])
        return row

    def _emit_progress(self, frame_idx, total_frames, fps):
        """Emit progress/time signals, skipping any whose denominator is not positive."""
        if total_frames > 0:
            self.progress_updated.emit(int((frame_idx / total_frames) * 100))
        if fps > 0:
            secs = int(frame_idx / fps)
            total_secs = int(total_frames / fps)
            self.time_updated.emit(f"{secs//60:02}:{secs%60:02}/{total_secs//60:02}:{total_secs%60:02}")

    def cancel(self):
        """Stop the frame loop after the current frame."""
        self.is_running = False

    def stop(self):
        """Alias of cancel(), matching the stop() method of the other workers."""
        self.cancel()
|
||
|
||
|
||
def load_hand_csv(filepath):
    """Read a single-hand landmark CSV into a per-frame lookup.

    The CSV must have a ``frame`` column and ``x0..x20`` / ``y0..y20``
    columns.

    Returns:
        {frame number: [21 (x, y) tuples]}. Later rows win if a frame repeats.
    """
    table = pd.read_csv(filepath)
    if table.empty:
        return {}

    x_values = table[[f"x{i}" for i in range(21)]].to_numpy()
    y_values = table[[f"y{i}" for i in range(21)]].to_numpy()

    return {
        int(frame): list(zip(x_row, y_row))
        for frame, x_row, y_row in zip(table["frame"], x_values, y_values)
    }
|
||
|
||
|
||
class HandValidationWindow(QWidget):
    """Play a video with its landmark CSV drawn on top, to check tracking visually.

    Two CSV layouts are accepted:
    - dual-hand (``h0_x0`` ... ``h1_y20``);
    - single-hand (``x0`` ... ``y20``).

    Each drawn hand shows its skeleton, a thin frame-to-frame motion arrow and
    a thick smoothed motion arrow. A separate HandDataInspector window shows
    per-joint kinematics for the current frame.
    """

    # Timer interval per playback-speed option (smaller interval = faster playback).
    SPEED_INTERVALS_MS = {"0.25x": 64, "0.5x": 32, "1.0x": 16, "2.0x": 8}
    DEFAULT_INTERVAL_MS = 16

    def __init__(self):
        from collections import deque

        super().__init__()
        self.setWindowTitle("SPARKS - Dual Hand Data Validator")
        self.resize(1000, 850)

        # NOTE(review): this attribute shadows QWidget.layout(); kept as-is in case
        # other code reads it.
        self.layout = QVBoxLayout(self)

        # --- Video Display ---
        self.video_label = QLabel("Load a Video and CSV to begin")
        self.video_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.video_label.setStyleSheet("background-color: black; border: 2px solid #333;")
        self.video_label.setMinimumSize(1, 1)
        self.layout.addWidget(self.video_label, stretch=1)

        # --- Playback Slider ---
        self.slider = QSlider(Qt.Orientation.Horizontal)
        self.slider.sliderMoved.connect(self.set_position)
        self.layout.addWidget(self.slider)

        # --- Control Row ---
        ctrl_layout = QHBoxLayout()

        self.btn_load = QPushButton("Load Pair")
        self.btn_load.clicked.connect(self.load_files)
        ctrl_layout.addWidget(self.btn_load)

        self.btn_play = QPushButton("Pause")  # Toggles Play/Pause
        self.btn_play.clicked.connect(self.toggle_play)
        ctrl_layout.addWidget(self.btn_play)

        self.speed_combo = QComboBox()
        self.speed_combo.addItems(["0.25x", "0.5x", "1.0x", "2.0x"])
        self.speed_combo.setCurrentText("1.0x")
        self.speed_combo.currentIndexChanged.connect(self.update_speed)
        ctrl_layout.addWidget(QLabel("Speed:"))
        ctrl_layout.addWidget(self.speed_combo)

        # Toggles for Hands
        self.chk_h0 = QCheckBox("Show Hand 0 (Cyan)")
        self.chk_h0.setChecked(True)
        self.chk_h1 = QCheckBox("Show Hand 1 (Magenta)")
        self.chk_h1.setChecked(True)
        ctrl_layout.addWidget(self.chk_h0)
        ctrl_layout.addWidget(self.chk_h1)

        self.lbl_frame = QLabel("Frame: 0")
        ctrl_layout.addWidget(self.lbl_frame)

        self.layout.addLayout(ctrl_layout)

        self.inspector = HandDataInspector()
        self.inspector.show()

        # Logic state
        self.timer = QTimer()
        self.timer.timeout.connect(self.update_frame)
        self.cap = None
        self.total_frames = 0
        self.hand_data = {}  # {frame: {'h0': [...], 'h1': [...]}}
        self.paused = False
        self.current_interval = self.DEFAULT_INTERVAL_MS

        # Recent per-hand centre-of-mass positions for the smoothed motion arrow.
        self.com_history = {
            'h0': deque(maxlen=30),
            'h1': deque(maxlen=30)
        }

    def _selected_interval_ms(self):
        """Timer interval in ms for the speed currently chosen in the combo box."""
        return self.SPEED_INTERVALS_MS.get(self.speed_combo.currentText(), self.DEFAULT_INTERVAL_MS)

    def load_files(self):
        """Ask for a video and its CSV, then start playing them from the beginning."""
        video_path, _ = QFileDialog.getOpenFileName(self, "Select Video", "", "Videos (*.mp4 *.avi)")
        csv_path, _ = QFileDialog.getOpenFileName(self, "Select CSV", "", "CSV Files (*.csv)")
        if not (video_path and csv_path):
            return

        # Parse first so a bad CSV leaves the current pair untouched.
        hand_data = self.parse_dual_hand_csv(csv_path)

        self.timer.stop()
        if self.cap is not None:
            self.cap.release()  # don't leak the previously loaded video
        self.cap = cv2.VideoCapture(video_path)
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.slider.setRange(0, max(self.total_frames - 1, 0))

        self.hand_data = hand_data
        for history in self.com_history.values():
            history.clear()

        self.paused = False
        self.btn_play.setText("Pause")
        self.timer.start(self._selected_interval_ms())

    def update_speed(self):
        """Apply the selected playback speed, immediately if currently playing."""
        ms = self._selected_interval_ms()
        self.current_interval = ms
        if self.timer.isActive():
            self.timer.start(ms)  # start() on an active QTimer restarts it with the new interval

    def parse_dual_hand_csv(self, path):
        """Load a landmark CSV into {frame: {'h0': [...], 'h1': [...]}}.

        Single-hand CSVs store their hand as 'h0' with an empty 'h1', and the
        Hand 1 checkbox is disabled.
        """
        df = pd.read_csv(path)
        data = {}

        # Check if this is a Dual Hand CSV or Single Hand CSV
        is_dual = 'h0_x0' in df.columns

        for _, row in df.iterrows():
            f = int(row['frame'])
            if is_dual:
                h0 = [(row[f'h0_x{i}'], row[f'h0_y{i}']) for i in range(21)]
                h1 = [(row[f'h1_x{i}'], row[f'h1_y{i}']) for i in range(21)]
                data[f] = {'h0': h0, 'h1': h1}
            else:
                # Fallback for old single-hand files
                h0 = [(row[f'x{i}'], row[f'y{i}']) for i in range(21)]
                data[f] = {'h0': h0, 'h1': []}

        # Disable Hand 1 checkbox if it's a single hand file
        self.chk_h1.setEnabled(is_dual)
        return data

    def toggle_play(self):
        """Switch between playing and paused; playback needs a loaded video."""
        if self.paused:
            if self.cap is not None:
                self.timer.start(self._selected_interval_ms())
            self.btn_play.setText("Pause")
        else:
            self.timer.stop()
            self.btn_play.setText("Play")
        self.paused = not self.paused

    def set_position(self, position):
        """Seek to ``position`` (a frame index), redrawing at once when paused."""
        if self.cap is None:
            return
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, position)
        if self.paused:
            self.update_frame(manual=True)

    @staticmethod
    def _has_landmarks(points):
        """True if ``points`` is a real 21-point detection.

        Returns False for the empty list, the all-zero padding written for a
        lost hand, and rows with NaN or inf values.
        """
        if len(points) != 21:
            return False
        wrist_x, wrist_y = points[0]
        if wrist_x == 0 and wrist_y == 0:
            return False
        return all(math.isfinite(v) for pt in points for v in pt)

    def update_frame(self, manual=False):
        """Read the next frame, overlay its hand data and show it.

        Args:
            manual: True when called for scrubbing.
        """
        if self.cap is None:
            return

        if manual:
            # Re-seek to the current position before reading (scrubbing).
            curr = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, curr)
        ret, frame = self.cap.read()

        if not ret:
            if not manual:
                # End of video (or read error): stop polling instead of failing every tick.
                self.timer.stop()
                self.paused = True
                self.btn_play.setText("Play")
            return

        f_idx = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
        self.slider.setValue(f_idx)
        self.lbl_frame.setText(f"Frame: {f_idx}")

        h, w, _ = frame.shape
        fps = self.cap.get(cv2.CAP_PROP_FPS)
        curr_data = self.hand_data.get(f_idx)
        prev_data = self.hand_data.get(f_idx - 1, {})

        if curr_data is not None:
            overlays = (
                (self.chk_h0, 'h0', (255, 255, 0)),  # BGR cyan
                (self.chk_h1, 'h1', (255, 0, 255)),  # BGR magenta
            )
            for checkbox, h_key, color in overlays:
                curr = curr_data.get(h_key, [])
                if not self._has_landmarks(curr):
                    # Missing or zero-padded hand: nothing to inspect or draw.
                    continue
                prev = prev_data.get(h_key, [])
                if not self._has_landmarks(prev):
                    prev = []  # avoid velocities measured from the padding at (0, 0)

                self.inspector.update_hand_data(h_key, curr, prev, fps)
                if checkbox.isChecked():
                    self.draw_skeleton(frame, curr, w, h, color, h_key, prev, fps)

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        qimg = QImage(rgb.data, w, h, 3 * w, QImage.Format_RGB888)
        display_size = self.video_label.contentsRect().size()

        self.video_label.setPixmap(
            QPixmap.fromImage(qimg).scaled(
                display_size,
                Qt.AspectRatioMode.KeepAspectRatio,
                Qt.TransformationMode.SmoothTransformation
            )
        )

    def draw_skeleton(self, img, landmarks, w, h, color, h_key, prev_landmarks=None, fps=30):
        """Draw one hand's skeleton and motion arrows onto ``img`` in place.

        Args:
            img: BGR frame to draw on.
            landmarks: 21 normalized (x, y) points; anything shorter is ignored.
            w, h: frame size in pixels used to scale the normalized points.
            color: BGR colour of the skeleton and the frame-to-frame arrow.
            h_key: 'h0' or 'h1'; selects the centre-of-mass history to update.
            prev_landmarks: previous frame's 21 points, or empty/None.
            fps: currently unused.
        """
        if not landmarks or len(landmarks) < 21:
            return

        # --- 1. Draw Skeleton Connections ---
        conns = [(0, 1), (1, 2), (2, 3), (3, 4), (0, 5), (5, 6), (6, 7), (7, 8),
                 (9, 10), (10, 11), (11, 12), (13, 14), (14, 15), (15, 16),
                 (0, 17), (17, 18), (18, 19), (19, 20), (5, 9), (9, 13), (13, 17)]
        for s, e in conns:
            p1 = (int(landmarks[s][0] * w), int(landmarks[s][1] * h))
            p2 = (int(landmarks[e][0] * w), int(landmarks[e][1] * h))
            cv2.line(img, p1, p2, color, 2)

        # Current centre of mass (mean of the 21 points), normalized and in pixels.
        avg_x = sum(p[0] for p in landmarks) / 21
        avg_y = sum(p[1] for p in landmarks) / 21
        com_px = (int(avg_x * w), int(avg_y * h))

        # --- 2. Instant Vector (thin, frame-to-frame) ---
        if prev_landmarks and len(prev_landmarks) == 21:
            p_avg_x = sum(p[0] for p in prev_landmarks) / 21
            p_avg_y = sum(p[1] for p in prev_landmarks) / 21

            # Displacement is scaled up purely for visibility.
            v_scale_inst = 35
            vx_inst = (avg_x - p_avg_x) * v_scale_inst
            vy_inst = (avg_y - p_avg_y) * v_scale_inst
            inst_end_px = (int((avg_x + vx_inst) * w), int((avg_y + vy_inst) * h))
            cv2.arrowedLine(img, com_px, inst_end_px, color, 2, tipLength=0.2)

        # --- 3. Smoothed Vector over the recent COM history ---
        history = self.com_history[h_key]
        history.append((avg_x, avg_y))
        if len(history) > 1:
            # (current - oldest) / number of stored samples.
            oldest = history[0]
            count = len(history)
            avg_dx = (avg_x - oldest[0]) / count
            avg_dy = (avg_y - oldest[1]) / count

            v_scale_smooth = 35
            smooth_end_px = (int((avg_x + avg_dx * v_scale_smooth) * w),
                             int((avg_y + avg_dy * v_scale_smooth) * h))

            # White arrow with a black outline for visibility on any background.
            cv2.arrowedLine(img, com_px, smooth_end_px, (0, 0, 0), 11, tipLength=0.3)
            cv2.arrowedLine(img, com_px, smooth_end_px, (255, 255, 255), 5, tipLength=0.3)

    def closeEvent(self, event):
        """Stop playback, release the video and close the inspector along with this window."""
        self.timer.stop()
        if self.cap is not None:
            self.cap.release()
            self.cap = None
        self.inspector.close()
        super().closeEvent(event)
|
||
|
||
|
||
class HandDataInspector(QWidget):
|
||
def __init__(self):
    """Build the inspector window with one tree tab per hand ('h0' and 'h1')."""
    super().__init__()
    self.setWindowTitle("Hand Kinematics Inspector")
    self.resize(700, 800)

    root_layout = QVBoxLayout(self)
    self.tabs = QTabWidget()
    root_layout.addWidget(self.tabs)

    hand_keys = ("h0", "h1")
    self.hand_trees = {}
    # Per hand: joint index -> tree item, and finger-group name -> tree item.
    self.joint_nodes = {key: {} for key in hand_keys}
    self.group_nodes = {key: {} for key in hand_keys}
    # Per hand: the "Whole Hand" root item (None until its tab is built).
    self.hand_nodes = dict.fromkeys(hand_keys)

    # Finger group -> landmark indices; insertion order is the display order.
    self.finger_map = {
        "Palm / Wrist": [0, 1, 5, 9, 13, 17],
        "Thumb": [2, 3, 4],
        "Index": [6, 7, 8],
        "Middle": [10, 11, 12],
        "Ring": [14, 15, 16],
        "Pinky": [18, 19, 20],
    }

    for key, title in (("h0", "Hand 0 (Cyan)"), ("h1", "Hand 1 (Magenta)")):
        self.setup_hand_tab(key, title)
|
||
|
||
def setup_hand_tab(self, key, label):
    """Create the four-column tree tab for one hand and register its items.

    The tree has:
    - a "Whole Hand" root;
    - one child per finger group in ``finger_map``;
    - one grandchild per joint index.

    These items are stored under ``key`` in ``hand_nodes``, ``group_nodes``
    and ``joint_nodes`` so their text can be filled in later.
    """
    tree = QTreeWidget()
    tree.setColumnCount(4)
    tree.setHeaderLabels(["Feature", "Position (X,Y)", "Vector (Vx, Vy)", "Speed/Angle"])
    tree.header().setSectionResizeMode(QHeaderView.ResizeMode.Stretch)
    self.tabs.addTab(tree, label)
    self.hand_trees[key] = tree

    empty_columns = ["", "", ""]
    whole_hand = QTreeWidgetItem(tree, ["Whole Hand", *empty_columns])
    self.hand_nodes[key] = whole_hand

    groups_for_hand = self.group_nodes[key]
    joints_for_hand = self.joint_nodes[key]
    for group_name, indices in self.finger_map.items():
        group_item = QTreeWidgetItem(whole_hand, [group_name, *empty_columns])
        groups_for_hand[group_name] = group_item
        for joint_idx in indices:
            joints_for_hand[joint_idx] = QTreeWidgetItem(group_item, [f"Joint {joint_idx}", *empty_columns])

    # Start with every level expanded.
    tree.expandAll()
|
||
|
||
def update_hand_data(self, hand_key, current_pts, last_pts, fps):
|
||
if not current_pts or len(current_pts) < 21: return
|
||
|
||
# 1. Update Every Individual Joint
|
||
joint_velocities = {}
|
||
for idx, pt in enumerate(current_pts):
|
||
vx, vy, speed, angle = 0, 0, 0, 0
|
||
if last_pts and len(last_pts) == 21:
|
||
vx = (pt[0] - last_pts[idx][0]) * fps
|
||
vy = (pt[1] - last_pts[idx][1]) * fps
|
||
speed = (vx**2 + vy**2)**0.5
|
||
angle = math.degrees(math.atan2(vy, vx))
|
||
|
||
joint_velocities[idx] = (vx, vy, speed, angle)
|
||
node = self.joint_nodes[hand_key][idx]
|
||
node.setText(1, f"{pt[0]:.3f}, {pt[1]:.3f}")
|
||
node.setText(2, f"{vx:+.2f}, {vy:+.2f}")
|
||
node.setText(3, f"{speed:.2f} @ {angle:.0f}°")
|
||
|
||
# 2. Update Finger Groups (Averaging their specific joints)
|
||
for group_name, indices in self.finger_map.items():
|
||
g_pts = [current_pts[i] for i in indices]
|
||
g_vels = [joint_velocities[i] for i in indices]
|
||
|
||
# Calculate Mean Position for this Finger
|
||
m_x = sum(p[0] for p in g_pts) / len(indices)
|
||
m_y = sum(p[1] for p in g_pts) / len(indices)
|
||
|
||
# Calculate Mean Velocity for this Finger
|
||
m_vx = sum(v[0] for v in g_vels) / len(indices)
|
||
m_vy = sum(v[1] for v in g_vels) / len(indices)
|
||
m_speed = (m_vx**2 + m_vy**2)**0.5
|
||
m_angle = math.degrees(math.atan2(m_vy, m_vx))
|
||
|
||
group_node = self.group_nodes[hand_key][group_name]
|
||
group_node.setText(1, f"{m_x:.3f}, {m_y:.3f}") # Numerical Mean Pos
|
||
group_node.setText(2, f"{m_vx:+.2f}, {m_vy:+.2f}")
|
||
group_node.setText(3, f"{m_speed:.2f} @ {m_angle:.0f}°")
|
||
|
||
# 3. Update Whole Hand (Mean of ALL 21 points)
|
||
# Position Mean
|
||
whole_x = sum(p[0] for p in current_pts) / 21
|
||
whole_y = sum(p[1] for p in current_pts) / 21
|
||
|
||
# Velocity Mean
|
||
whole_vx = sum(v[0] for v in joint_velocities.values()) / 21
|
||
whole_vy = sum(v[1] for v in joint_velocities.values()) / 21
|
||
whole_speed = (whole_vx**2 + whole_vy**2)**0.5
|
||
whole_angle = math.degrees(math.atan2(whole_vy, whole_vx))
|
||
|
||
root = self.hand_nodes[hand_key]
|
||
root.setText(1, f"{whole_x:.3f}, {whole_y:.3f}") # Numerical Mean Pos
|
||
root.setText(2, f"{whole_vx:+.2f}, {whole_vy:+.2f}")
|
||
root.setText(3, f"{whole_speed:.2f} @ {whole_angle:.0f}°")
|
||
|
||
|
||
|
||
class MainApplication(QMainWindow):
|
||
"""
|
||
Main application window that creates and sets up the UI.
|
||
"""
|
||
|
||
progress_update_signal = Signal(str, int)
|
||
|
||
def __init__(self):
    """Build the SPARKS main window and its menus, create the MediaPipe hand landmarker, and start the local pending-update check."""
    super().__init__()
    self.setWindowTitle("SPARKS")
    self.setGeometry(100, 100, 1280, 720)

    # Secondary windows, created on demand by their menu actions.
    self.about = None
    self.help = None
    self.optodes = None
    self.events = None
    self.terminal = None

    self.bubble_widgets = {}
    self.param_sections = []
    self.folder_paths = []
    self.section_widget = None
    self.first_run = True

    self.selection_widgets = {}
    self.camera_state = {}  # {(obs_id, cam_id): state info}
    self.processing_widgets = {}  # previously assigned twice

    self.files_total = 0          # total number of files to process
    self.files_done = set()       # set of file paths done (success or fail)
    self.files_failed = set()     # set of failed file paths
    self.files_results = {}       # dict for successful results (if needed)
    self.processing_threads = []  # List to hold all active ParticipantProcessor threads

    # Plain state is initialised before the UI is built, so slots wired in
    # init_ui()/create_menu_bar() (e.g. Clear -> clear_all -> cancel_task, which
    # reads worker_thread/progress_dialog) never see a missing attribute.
    self.platform_suffix = "-" + PLATFORM_NAME
    self.pending_update_version = None
    self.pending_update_path = None
    self.last_clicked_bubble = None
    self.file_metadata = {}
    self.current_file = None
    self.worker_thread = None
    self.progress_dialog = None

    self.init_ui()
    self.create_menu_bar()
    self.installEventFilter(self)

    # MediaPipe hand landmarker (Tasks API) in IMAGE running mode, up to two hands.
    # NOTE(review): the model path is relative, so 'hand_landmarker.task' must be
    # reachable from the current working directory at startup; confirm for frozen builds.
    base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
    self.mp_hands_options = vision.HandLandmarkerOptions(
        base_options=base_options,
        running_mode=vision.RunningMode.IMAGE,
        num_hands=2,
        min_hand_detection_confidence=0.5,
        min_hand_presence_confidence=0.5,
        min_tracking_confidence=0.5,
    )
    self.mp_hands = vision.HandLandmarker.create_from_options(self.mp_hands_options)

    # Local pending-update check; results arrive in on_pending_update_found / on_no_pending_update.
    self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix)
    self.local_check_thread.pending_update_found.connect(self.on_pending_update_found)
    self.local_check_thread.no_pending_update.connect(self.on_no_pending_update)
    self.local_check_thread.start()
|
||
|
||
|
||
def init_ui(self):
    """Create the central widget.

    Left side: the "File information" group (log view plus a hidden metadata
    column) above a scrollable bubble grid. Right side: the option selector and
    parameter sections in a scroll area, with Process/Clear buttons below.
    """
    # Central widget and main horizontal layout
    central = QWidget()
    self.setCentralWidget(central)
    main_layout = QHBoxLayout()
    central.setLayout(main_layout)

    # Left container: top-left info group + bottom-left bubble grid
    left_container = QWidget()
    left_layout = QVBoxLayout()
    left_container.setLayout(left_layout)
    left_container.setMinimumWidth(300)

    top_left_container = QGroupBox()
    top_left_container.setTitle("File information")
    top_left_container.setStyleSheet("QGroupBox { font-weight: bold; }")
    # Fully-qualified enum, matching the other size policies set in this method.
    top_left_container.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Expanding)
    top_left_layout = QHBoxLayout()
    top_left_container.setLayout(top_left_layout)

    # Read-only log view: stretch 4 against 1 for the metadata column (~80% width).
    self.top_left_widget = QTextEdit()
    self.top_left_widget.setReadOnly(True)
    self.top_left_widget.setPlaceholderText("Logging information will be available here.")
    top_left_layout.addWidget(self.top_left_widget, stretch=4)

    # Metadata column (~20% width), hidden by default
    self.right_column_widget = QWidget()
    right_column_layout = QVBoxLayout()
    self.right_column_widget.setLayout(right_column_layout)

    self.meta_fields = {
        "HAND": QLineEdit(),
        "GENDER": QLineEdit(),
        "GROUP": QLineEdit(),
    }
    for key, field in self.meta_fields.items():
        label = QLabel(key.capitalize())
        field.setPlaceholderText(f"Enter {key}")
        right_column_layout.addWidget(label)
        right_column_layout.addWidget(field)

    label_desc = QLabel('<a href="#">Why are these useful?</a>')
    label_desc.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
    label_desc.setOpenExternalLinks(False)

    def show_info_popup():
        # Parented to the main window so the box is centred on it and modal to it.
        # NOTE(review): the text describes "Age"/DPF, but the fields above are
        # HAND/GENDER/GROUP. The wording looks inherited; confirm and update.
        QMessageBox.information(self, "Parameter Info - SPARKS",
            "Age: Used to calculate the DPF factor.\nGender: Not currently used. "
            "Will be able to sort into groups by gender in the near future.\nGroup: Allows contrast "
            "images to be created comparing one group to another once the processing has completed.")

    label_desc.linkActivated.connect(show_info_popup)
    right_column_layout.addWidget(label_desc)
    right_column_layout.addStretch()  # push fields to the top

    self.right_column_widget.hide()
    top_left_layout.addWidget(self.right_column_widget, stretch=1)
    left_layout.addWidget(top_left_container, stretch=2)

    # Bottom left: bubbles inside a scroll area
    self.bubble_container = QWidget()
    self.bubble_layout = QGridLayout()
    self.bubble_layout.setAlignment(Qt.AlignmentFlag.AlignTop)
    self.bubble_container.setLayout(self.bubble_layout)

    self.scroll_area = QScrollArea()
    self.scroll_area.setWidgetResizable(True)
    self.scroll_area.setWidget(self.bubble_container)
    self.scroll_area.setMinimumHeight(300)
    left_layout.addWidget(self.scroll_area, stretch=8)

    # Right side (full height)
    self.right_container = QWidget()
    right_container_layout = QVBoxLayout()
    self.right_container.setLayout(right_container_layout)

    self.right_content_widget = QWidget()
    right_content_layout = QVBoxLayout()
    self.right_content_widget.setLayout(right_content_layout)

    self.option_selector = QComboBox()
    self.option_selector.addItems(["N/A"])
    right_content_layout.addWidget(self.option_selector)

    # Container for the ParamSection rows
    self.rows_container = QWidget()
    self.rows_layout = QVBoxLayout()
    self.rows_layout.setSpacing(10)
    self.rows_container.setLayout(self.rows_layout)
    right_content_layout.addWidget(self.rows_container)
    right_content_layout.addStretch()  # keep rows at the top

    self.right_scroll_area = QScrollArea()
    self.right_scroll_area.setWidgetResizable(True)
    self.right_scroll_area.setWidget(self.right_content_widget)

    # Buttons fixed below the scroll area
    buttons_widget = QWidget()
    buttons_layout = QHBoxLayout()
    buttons_widget.setLayout(buttons_layout)
    buttons_layout.addStretch()

    self.button1 = QPushButton("Process")
    self.button2 = QPushButton("Clear")
    buttons_layout.addWidget(self.button1)
    buttons_layout.addWidget(self.button2)
    self.button1.setMinimumSize(100, 40)
    self.button2.setMinimumSize(100, 40)
    self.button1.setVisible(False)  # shown once files are loaded
    self.button1.clicked.connect(self.on_run_task)
    self.button2.clicked.connect(self.clear_all)

    right_container_layout.addWidget(self.right_scroll_area)
    right_container_layout.addWidget(buttons_widget)

    # 70/30 split between left and right
    main_layout.addWidget(left_container, stretch=70)
    main_layout.addWidget(self.right_container, stretch=30)

    self.right_container.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
    self.right_scroll_area.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)

    self.option_selector.currentIndexChanged.connect(self.update_sections)
    self.update_sections(0)  # initial build
|
||
|
||
|
||
def create_menu_bar(self):
    """Build the File, Edit, Model, View, Options and Terminal menus and the status bar."""
    menu_bar = self.menuBar()
    # Created up front so the View menu toggle can be wired to it.
    self.statusbar = self.statusBar()

    def make_action(name, shortcut=None, slot=None, checkable=False, checked=False, icon=None):
        """Return a QAction parented to the window with an optional shortcut, slot, check state and icon."""
        action = QAction(name, self)
        if shortcut:
            action.setShortcut(QKeySequence(shortcut))
        if slot:
            action.triggered.connect(slot)
        if checkable:
            action.setCheckable(True)
            action.setChecked(checked)
        if icon:
            action.setIcon(QIcon(icon))
        return action

    def add_actions(menu, actions, separators_after=()):
        """Add (name, shortcut, slot, icon) entries to menu, with a separator after each listed index."""
        for i, (name, shortcut, slot, icon) in enumerate(actions):
            menu.addAction(make_action(name, shortcut, slot, icon=icon))
            if i in separators_after:
                menu.addSeparator()

    # File menu
    # Planned but disabled: Open Folders (Ctrl+Shift+O), Load Project (Ctrl+L),
    # Save Project (Ctrl+S), Save Project As (Ctrl+Shift+S).
    file_menu = menu_bar.addMenu("File")
    file_actions = [
        ("Open BORIS file...", "Ctrl+O", self.open_file_dialog, resource_path("icons/file_open_24dp_1F1F1F.svg")),
        ("Open Video file...", "Ctrl+Alt+O", self.open_video_file_dialog, resource_path("icons/file_open_24dp_1F1F1F.svg")),
    ]
    add_actions(file_menu, file_actions, separators_after=(1,))
    file_menu.addSeparator()
    file_menu.addAction(make_action("Exit", "Ctrl+Q", QApplication.instance().quit, icon=resource_path("icons/exit_to_app_24dp_1F1F1F.svg")))

    # Edit menu
    edit_menu = menu_bar.addMenu("Edit")
    edit_actions = [
        ("Cut", "Ctrl+X", self.cut_text, resource_path("icons/content_cut_24dp_1F1F1F.svg")),
        ("Copy", "Ctrl+C", self.copy_text, resource_path("icons/content_copy_24dp_1F1F1F.svg")),
        ("Paste", "Ctrl+V", self.paste_text, resource_path("icons/content_paste_24dp_1F1F1F.svg")),
    ]
    add_actions(edit_menu, edit_actions)

    # Model menu
    # Shortcuts must be unique within the window: Qt treats duplicates as ambiguous
    # and triggers none of them. "Test model on a video" reused Ctrl+O (which also
    # broke "Open BORIS file..."), and three entries shared Ctrl+P, so only the
    # first Ctrl+P is kept.
    model_menu = menu_bar.addMenu("Model")
    model_actions = [
        ("Create model from a CSV", "Ctrl+U", self.train_model_csv, resource_path("icons/content_cut_24dp_1F1F1F.svg")),
        ("Create model from a folder", "Ctrl+I", self.train_model_folder, resource_path("icons/content_copy_24dp_1F1F1F.svg")),
        ("Test model on a video", None, self.test_model_video, resource_path("icons/content_paste_24dp_1F1F1F.svg")),
        ("Test model on a folder", "Ctrl+P", self.test_model_folder, resource_path("icons/content_paste_24dp_1F1F1F.svg")),
        ("Test model on a CSV", None, self.test_model_csv, resource_path("icons/content_paste_24dp_1F1F1F.svg")),
        ("Test CSV on a Video", None, self.open_validator, resource_path("icons/content_paste_24dp_1F1F1F.svg")),
    ]
    add_actions(model_menu, model_actions, separators_after=(1, 4))

    # View menu: the toggle previously had no slot and did nothing.
    view_menu = menu_bar.addMenu("View")
    toggle_statusbar_action = make_action("Toggle Status Bar", checkable=True, checked=True, slot=None)
    toggle_statusbar_action.toggled.connect(self.statusbar.setVisible)
    view_menu.addAction(toggle_statusbar_action)

    # Options menu (Help & About): separator after the second entry
    options_menu = menu_bar.addMenu("Options")
    options_actions = [
        ("User Guide", "F1", self.user_guide, resource_path("icons/help_24dp_1F1F1F.svg")),
        ("Check for Updates", "F5", self.manual_check_for_updates, resource_path("icons/update_24dp_1F1F1F.svg")),
        ("About", "F12", self.about_window, resource_path("icons/info_24dp_1F1F1F.svg")),
    ]
    add_actions(options_menu, options_actions, separators_after=(1,))

    # Terminal menu
    terminal_menu = menu_bar.addMenu("Terminal")
    terminal_actions = [
        ("New Terminal", "Ctrl+Alt+T", self.terminal_gui, resource_path("icons/terminal_24dp_1F1F1F.svg")),
    ]
    add_actions(terminal_menu, terminal_actions)

    self.statusbar.showMessage("Ready")
|
||
|
||
|
||
def update_sections(self, index):
    """Rebuild the parameter panel with one ParamSection per entry in SECTIONS.

    Existing row widgets are scheduled for deletion first. ``index`` (the option
    selector's new index) is accepted for the signal but not used.
    """
    last = self.rows_layout.count() - 1
    for pos in range(last, -1, -1):
        old_widget = self.rows_layout.itemAt(pos).widget()
        if old_widget is not None:
            old_widget.deleteLater()
    self.param_sections.clear()

    for section in SECTIONS:
        widget = ParamSection(section)
        self.section_widget = widget  # ends up as the last section built
        self.rows_layout.addWidget(widget)
        self.param_sections.append(widget)
|
||
|
||
|
||
def clear_bubble_widgets(self):
    """Empty self.bubble_layout, scheduling every contained widget for deletion."""
    layout = self.bubble_layout
    while layout.count() > 0:
        item = layout.takeAt(0)
        widget = item.widget()
        if widget:
            widget.deleteLater()
        elif item.spacerItem():
            # takeAt() has already detached the spacer; this call is kept as-is.
            layout.removeItem(item)
|
||
|
||
def clear_all(self):
    """Cancel any running load and reset the window to its empty state.

    Hides the metadata column, removes all bubbles, forgets per-file selections,
    clears the status bar and log, resets result attributes to None and hides
    the Process button.
    """
    self.cancel_task()
    self.right_column_widget.hide()

    # Shared helper instead of a duplicated layout-clearing loop.
    self.clear_bubble_widgets()

    self.bubble_widgets.clear()
    # The dropdowns referenced here belong to the bubbles just deleted. Fresh
    # dicts ensure a later load never reuses stale (deleted) Qt objects.
    self.selection_widgets = {}
    self.camera_state = {}
    self.statusBar().clearMessage()

    self.raw_haemo_dict = None
    self.epochs_dict = None
    self.fig_bytes_dict = None
    self.cha_dict = None
    self.contrast_results_dict = None
    self.df_ind_dict = None
    self.design_matrix_dict = None
    self.age_dict = None
    self.gender_dict = None
    self.group_dict = None
    self.valid_dict = None

    # Reset visible UI elements
    self.button1.setVisible(False)
    self.top_left_widget.clear()
|
||
|
||
|
||
def get_output_directory(self):
    """Ask for a parent folder and create a timestamped ``sparks_YYYYMMDD_HHMMSS`` subfolder in it.

    Returns:
        The new folder's path; None if the dialog is cancelled or creation fails
        (a critical message box is shown on failure).
    """
    # The dialog opens in the current working directory.
    chosen_parent = QFileDialog.getExistingDirectory(
        self,
        "Select Parent Folder for Results",
        os.getcwd(),
    )
    if not chosen_parent:
        return None

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = os.path.join(chosen_parent, f"sparks_{stamp}")
    try:
        os.makedirs(target, exist_ok=True)
        print(f"Created output directory: {target}")
        return target
    except Exception as e:
        QMessageBox.critical(self, "Error", f"Failed to create output directory: {e}")
        return None
|
||
|
||
def copy_text(self):
    """Copy the log view's selection to the clipboard and say so in the status bar."""
    log_view = self.top_left_widget
    log_view.copy()
    self.statusbar.showMessage("Copied to clipboard")
|
||
|
||
def cut_text(self):
    """Cut the log view's selection to the clipboard and say so in the status bar."""
    log_view = self.top_left_widget
    log_view.cut()
    self.statusbar.showMessage("Cut to clipboard")
|
||
|
||
def paste_text(self):
    """Paste the clipboard into the log view and say so in the status bar."""
    log_view = self.top_left_widget
    log_view.paste()
    self.statusbar.showMessage("Pasted from clipboard")
|
||
|
||
|
||
def get_training_files(self):
    """Let the user pick one or more training CSV files; return the selected paths."""
    selected, _selected_filter = QFileDialog.getOpenFileNames(
        self, "Select Training CSV Files", os.getcwd(), "CSV Files (*.csv)"
    )
    return selected
|
||
|
||
def get_testing_files(self):
    """Ask for a trained model (.pth), then for a video to test it on.

    Returns:
        (model_path, video_path). Returns (None, None) when the model dialog is
        cancelled; otherwise video_path is whatever the second dialog returned.
    """
    model_path = QFileDialog.getOpenFileName(
        self, "Select Trained Model File (.pth)", os.getcwd(), "Model Files (*.pth)"
    )[0]
    if not model_path:
        return None, None

    video_path = QFileDialog.getOpenFileName(
        self, "Select Video File for Testing", os.getcwd(), "Video Files (*.mp4 *.avi *.mov)"
    )[0]
    return model_path, video_path
|
||
|
||
|
||
def get_testing_files2(self):
    """Ask for a trained model (.pth), then for a CSV file to test it on.

    Returns:
        (model_path, csv_path). Returns (None, None) when the model dialog is
        cancelled; otherwise csv_path is whatever the second dialog returned.
    """
    model_path, _ = QFileDialog.getOpenFileName(
        self,
        "Select Trained Model File (.pth)",
        os.getcwd(),
        "Model Files (*.pth)",
    )
    if not model_path:
        return None, None

    # This dialog selects a CSV. The title previously said "Video File" and the
    # filter misspelled "Separated".
    csv_path, _ = QFileDialog.getOpenFileName(
        self,
        "Select CSV File for Testing",
        os.getcwd(),
        "Comma Separated Files (*.csv)",
    )
    return model_path, csv_path
|
||
|
||
def train_model_csv(self):
    """Ask for training CSVs and a .pth save path, then train in a background TrainModelThread.

    Progress messages go to log_training_message() and completion to
    on_train_finished(). If a training run is already active, only a status-bar
    message is shown.
    """
    # Rebinding self.train_thread while it runs would drop the last Python
    # reference to a live QThread, which Qt aborts on.
    current = getattr(self, "train_thread", None)
    if current is not None and current.isRunning():
        self.statusBar().showMessage("Training is already in progress.")
        return

    csv_paths = self.get_training_files()
    if not csv_paths:
        self.statusBar().showMessage("Training cancelled. No CSV files selected.")
        return

    default_filename = "best_reach_lstm.pth"
    model_save_path, _ = QFileDialog.getSaveFileName(
        self,
        "Save Trained Model File",
        os.path.join(os.getcwd(), default_filename),  # start in CWD with a default name
        "PyTorch Model Files (*.pth);;All Files (*)",
    )
    if not model_save_path:
        self.statusBar().showMessage("Training cancelled. Model save location not specified.")
        return

    # Append the extension if the user typed a bare name.
    if not model_save_path.lower().endswith('.pth'):
        model_save_path += '.pth'

    self.statusBar().showMessage(f"Starting model training on {len(csv_paths)} files. Model will save to: {os.path.basename(model_save_path)}")

    self.train_thread = TrainModelThread(
        csv_paths=csv_paths,
        model_save_path=model_save_path,
    )
    self.train_thread.update.connect(self.log_training_message)
    self.train_thread.finished.connect(self.on_train_finished)
    self.train_thread.start()
|
||
|
||
|
||
def log_training_message(self, message):
    """Slot for TrainModelThread.update: show ``message`` in the status bar, prefixed with "Training: "."""
    status = self.statusBar()
    status.showMessage(f"Training: {message}")
|
||
|
||
def on_train_finished(self, final_message):
    """Slot for TrainModelThread.finished: report ``final_message`` and release the thread reference."""
    self.statusBar().showMessage(final_message)
    QMessageBox.information(self, "Training Complete", final_message)
    # Drop the reference so a new training run can be started.
    self.train_thread = None
|
||
|
||
def train_model_folder(self):
    """Placeholder for the "Create model from a folder" menu action. Not implemented; does nothing."""
    return None
|
||
|
||
def test_model_video(self):
    """Ask for a trained model and a video, then run TestModelThread on that one video.

    Completion is handled by on_test_finished(). If a model test is already
    running, or either dialog is cancelled, only a status-bar message is shown.
    """
    # Rebinding self.test_thread while it runs would drop the last reference to
    # a live QThread.
    current = getattr(self, "test_thread", None)
    if current is not None and current.isRunning():
        self.statusBar().showMessage("A model test is already running.")
        return

    model_path, video_path = self.get_testing_files()
    if not model_path or not video_path:
        self.statusBar().showMessage("Testing cancelled. Model or Video not selected.")
        return

    self.statusBar().showMessage(f"Starting model testing on {os.path.basename(video_path)}...")

    # TestModelThread takes a list of video paths.
    self.test_thread = TestModelThread(video_paths=[video_path], model_path=model_path)
    # Per the original author, the thread shows its own OpenCV window, so frame
    # updates are not routed to the main GUI.
    self.test_thread.finished.connect(self.on_test_finished)
    self.test_thread.start()
|
||
|
||
|
||
|
||
def test_model_csv(self):
    """Ask for a trained model and a CSV file, then run TestModelThread2 on the CSV.

    The result is delivered to on_csv_analysis_finished(). If a model test is
    already running, or either dialog is cancelled, only a status-bar message is shown.
    """
    # Rebinding self.test_thread while it runs would drop the last reference to
    # a live QThread.
    current = getattr(self, "test_thread", None)
    if current is not None and current.isRunning():
        self.statusBar().showMessage("A model test is already running.")
        return

    model_path, csv_path = self.get_testing_files2()
    if not model_path or not csv_path:
        # The second file here is a CSV, not a video.
        self.statusBar().showMessage("Testing cancelled. Model or CSV not selected.")
        return

    self.statusBar().showMessage(f"Starting model testing on {os.path.basename(csv_path)}...")

    self.test_thread = TestModelThread2(csv_path=csv_path, model_path=model_path)
    self.test_thread.finished.connect(self.on_csv_analysis_finished)
    self.test_thread.start()
|
||
|
||
|
||
def open_validator(self):
    """Show the hand validation window ("Test CSV on a Video").

    Like about_window/user_guide/terminal_gui, a window that is already visible
    is reused. Previously every call replaced it, and dropping the reference
    destroyed the open window along with any work in progress.
    """
    existing = getattr(self, "validator_window", None)
    if existing is None or not existing.isVisible():
        self.validator_window = HandValidationWindow()
    self.validator_window.show()
|
||
|
||
def on_csv_analysis_finished(self, result):
    """Slot for TestModelThread2.finished: plot the model's output or report an error.

    ``result`` is a dict. If it has an "error" key, that value is shown in a
    critical message box. Otherwise its "time", "active" and "before" sequences
    are drawn as two stacked probability plots sharing the time axis.
    """
    if "error" in result:
        QMessageBox.critical(self, "Error", result["error"])
        return

    # Named `t` rather than `time` so the local does not shadow the time module.
    t = result["time"]
    panels = [
        (result["active"], "green", "Reach Active"),
        (result["before"], "orange", "Reach Before Contact"),
    ]

    fig, axes = plt.subplots(2, 1, figsize=(14, 6), sharex=True)
    for ax, (series, colour, title) in zip(axes, panels):
        ax.plot(t, series, color=colour, label=title)
        ax.set_ylabel("Probability")
        ax.set_title(title)
        ax.set_ylim(0, 1)
        ax.grid(True)
    axes[1].set_xlabel("Time (s)")

    plt.tight_layout()
    plt.show()

    self.statusBar().showMessage("Complete.")
|
||
|
||
|
||
|
||
def on_test_finished(self):
    """Slot for TestModelThread.finished: update the status bar, close OpenCV windows, release the thread."""
    self.statusBar().showMessage("Model testing complete.")

    # destroyAllWindows raises cv2.error when OpenCV has no GUI support (e.g.
    # headless builds), and that is safe to ignore. The former bare `except:`
    # also swallowed unrelated errors and KeyboardInterrupt/SystemExit.
    try:
        cv2.destroyAllWindows()
    except cv2.error:
        pass

    # Drop the reference so another test can be started.
    self.test_thread = None
|
||
|
||
def test_model_folder(self):
    """Placeholder for the "Test model on a folder" menu action. Not implemented; does nothing."""
    return None
|
||
|
||
|
||
|
||
|
||
def about_window(self):
    """Show the About window, creating a new one unless one is already visible."""
    window = self.about
    if window is None or not window.isVisible():
        window = self.about = AboutWindow(self)
    window.show()
|
||
|
||
def user_guide(self):
    """Show the User Guide window, creating a new one unless one is already visible."""
    window = self.help
    if window is None or not window.isVisible():
        window = self.help = UserGuideWindow(self)
    window.show()
|
||
|
||
def terminal_gui(self):
    """Show the terminal window, creating a new one unless one is already visible."""
    window = self.terminal
    if window is None or not window.isVisible():
        window = self.terminal = TerminalWindow(self)
    window.show()
|
||
|
||
def update_optode_positions(self):
    """Disabled: does nothing.

    The window-opening code after an early ``return`` was unreachable and has
    been removed. To re-enable, open ``UpdateOptodesWindow(self)`` into
    ``self.optodes`` using the same reuse-if-visible pattern as about_window().
    """
    return None
|
||
|
||
|
||
def update_event_markers(self):
    """Disabled: does nothing.

    The window-opening code after an early ``return`` was unreachable and has
    been removed. To re-enable, open ``UpdateEventsWindow(self)`` into
    ``self.events`` using the same reuse-if-visible pattern as about_window().
    """
    return None
|
||
|
||
def open_file_dialog(self):
    """Ask for a BORIS project, resolve its Observations folder and load it in a FileLoadWorker.

    Progress is shown in a ProgressDialog whose Cancel button calls
    cancel_task(). Results arrive via on_files_loaded, on_loading_finished and
    on_loading_failed.
    """
    file_path, _ = QFileDialog.getOpenFileName(
        self, "Open File", "", "BORIS Files (*.boris);;All Files (*)"
    )
    if not file_path:
        return

    self.observations_root = self.resolve_path_to_observations(file_path)
    if not self.observations_root:
        self.statusBar().showMessage("File loading cancelled by user.")
        return  # user cancelled the folder selection

    # Stop any load still in flight first. Rebinding self.worker_thread while it
    # runs could destroy the old QThread mid-run. No-op when nothing is loading.
    self.cancel_task()

    self.progress_dialog = ProgressDialog(self)
    self.progress_dialog.cancel_button.clicked.connect(self.cancel_task)
    self.progress_dialog.show()

    self.worker_thread = FileLoadWorker(
        file_path,
        self.observations_root,
        self.extract_frame_and_hands,
    )
    self.worker_thread.progress_update.connect(self.progress_dialog.update_progress)
    self.worker_thread.observations_loaded.connect(self.on_files_loaded)
    self.worker_thread.loading_finished.connect(self.on_loading_finished)
    self.worker_thread.loading_failed.connect(self.on_loading_failed)
    self.worker_thread.start()
|
||
|
||
|
||
|
||
def open_video_file_dialog(self):
    """Ask for a single video and load it in an IndividualFileLoadWorker.

    No progress dialog is used on this path. Results arrive via
    on_individual_files_loaded, on_loading_finished and on_loading_failed.
    """
    file_path, _ = QFileDialog.getOpenFileName(
        self, "Open File", "", "Video Files (*.mp4 *.avi *.mov *.mkv *.wmv);;All Files (*)"
    )
    if not file_path:
        return

    # Stop any load still in flight before rebinding self.worker_thread
    # (no-op when nothing is loading).
    self.cancel_task()
    # Forget any dialog left over from an earlier BORIS load so the finish/fail
    # slots don't act on it.
    self.progress_dialog = None

    self.worker_thread = IndividualFileLoadWorker(
        file_path,
        self.extract_frame_and_hands,
    )
    self.worker_thread.observations_loaded.connect(self.on_individual_files_loaded)
    self.worker_thread.loading_finished.connect(self.on_loading_finished)
    self.worker_thread.loading_failed.connect(self.on_loading_failed)
    self.worker_thread.start()
|
||
|
||
|
||
def resolve_path_to_observations(self, boris_file_path):
    """Find the 'Observations' folder that holds the videos for a BORIS file.

    An 'Observations' folder next to the BORIS file is used directly. Otherwise
    the user is warned and asked to choose a folder:
    - if the chosen folder is itself named 'Observations', it is returned;
    - else, if it contains an 'Observations' subfolder, that subfolder is returned;
    - else the chosen folder is returned as-is.

    Returns:
        The folder path to use, or None if the user cancels the dialog.
    """
    boris_dir = os.path.dirname(boris_file_path)
    adjacent = os.path.join(boris_dir, "Observations")
    if os.path.isdir(adjacent):
        return adjacent

    QMessageBox.warning(
        self,
        "Observations Folder Missing",
        "The 'Observations' folder (containing the videos) was not found next to the BORIS file. Please select the correct root folder."
    )
    chosen = QFileDialog.getExistingDirectory(
        self,
        "Select Observations Root Folder",
        boris_dir,  # start next to the BORIS file
    )
    if not chosen:
        return None

    if os.path.basename(chosen) == "Observations":
        return chosen
    nested = os.path.join(chosen, "Observations")
    return nested if os.path.isdir(nested) else chosen
|
||
|
||
|
||
|
||
def on_files_loaded(self, boris, observations, root_path):
    """Slot for FileLoadWorker.observations_loaded: store the loaded data and build the preview grid."""
    self.boris, self.observations, self.observations_root = boris, observations, root_path

    worker = self.worker_thread
    self.build_preview_grid(worker.previews_data)
    self.statusBar().showMessage(f"{worker.file_path} loaded.")
    self.button1.setVisible(True)
|
||
|
||
|
||
def on_individual_files_loaded(self):
    """Slot for IndividualFileLoadWorker.observations_loaded: build the single-video preview grid."""
    worker = self.worker_thread
    self.build_individual_preview_grid(worker.previews_data)
    self.statusBar().showMessage(f"{worker.file_path} loaded.")
    self.button1.setVisible(True)
|
||
|
||
def on_loading_finished(self):
    """Slot for the workers' loading_finished: accept (close) the progress dialog if there is one."""
    dialog = self.progress_dialog
    if dialog:
        dialog.accept()
|
||
|
||
def on_loading_failed(self, error_message):
    """Slot for the workers' loading_failed: close any progress dialog and report ``error_message``."""
    dialog = self.progress_dialog
    if dialog:
        dialog.reject()
    QMessageBox.critical(self, "Error Loading Files", error_message)
    self.statusBar().showMessage("File loading failed.")
|
||
|
||
def cancel_task(self):
    """Abort an in-progress file load and close its progress dialog.

    NOTE(review): this class appears to define ``cancel_task`` more than once;
    Python keeps only the last definition, so this body may never run —
    confirm which behaviour callers expect.
    """
    worker = self.worker_thread
    if worker and worker.isRunning():
        worker.stop()  # ask the loader to stop processing
        worker.wait()  # block until it has exited

    dialog = self.progress_dialog
    if dialog:
        dialog.reject()
    self.statusBar().showMessage("File loading cancelled.")
|
||
|
||
# Add a dummy implementation for the processing function if it's not shown:
|
||
# def extract_frame_and_hands(self, video_path, frame_idx):
|
||
# # Placeholder implementation - replace with your actual function
|
||
# return None, None
|
||
|
||
def build_preview_grid(self, previews_data):
    """Add one preview group per BORIS observation to the bubble layout.

    Each group holds a participant-level "Skip / Process" dropdown. For every
    camera the loader worker produced a preview for, it also holds a row with:
    the annotated preview frame, a hand-selection dropdown and a
    "Skip 1s → Rescan" button.

    Widgets are registered in ``self.selection_widgets``: participant dropdowns
    under ``"participant_selection"``, camera dropdowns under ``[obs_id][cam_id]``.
    Per-camera state goes into ``self.camera_state[(obs_id, cam_id)]``.

    Args:
        previews_data: mapping ``(obs_id, cam_id) -> dict`` built by the loader
            worker. Each value is read for ``frame_rgb``, ``results``,
            ``video_path``, ``fps`` and ``initial_wrists``.
    """
    for obs_id, obs_data in self.observations.items():
        group = QGroupBox(f"Participant / Observation: {obs_id}")
        grouplayout = QVBoxLayout(group)

        # Participant-level dropdown; defaults to "skip".
        participant_dropdown = QComboBox()
        participant_dropdown.addItem("Skip this participant", 0)
        participant_dropdown.addItem("Process this participant", 1)
        participant_dropdown.setCurrentIndex(0)
        self.selection_widgets.setdefault("participant_selection", {})[obs_id] = participant_dropdown
        grouplayout.addWidget(QLabel("Participant Option:"))
        grouplayout.addWidget(participant_dropdown)

        camera_dropdowns = self.selection_widgets.setdefault(obs_id, {})

        for cam_id in obs_data.get("file", {}):
            state_data = previews_data.get((obs_id, cam_id))
            # Cameras the worker could not preview (e.g. path not found or an
            # unreadable frame) get no row.
            if state_data is None or state_data["frame_rgb"] is None:
                continue
            results = state_data["results"]

            # draw_hand_overlay returns (annotated_image, centroids). Keep the
            # centroids as tracking anchors so skip_and_rescan continues the
            # hand IDs shown in this first preview.
            display_img, centroids = self.draw_hand_overlay(state_data["frame_rgb"], results)

            # numpy RGB image -> scaled QPixmap
            h, w, _ = display_img.shape
            qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888)
            pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio)

            # -------- UI Row --------
            row = QWidget()
            row_layout = QHBoxLayout(row)

            preview_label = QLabel()
            preview_label.setPixmap(pix)
            row_layout.addWidget(preview_label)

            # Dropdown data: -1 = skip camera, 0..n-1 = detection index,
            # 99 = special flag for "both hands".
            dropdown = QComboBox()
            dropdown.addItem("Skip this camera", -1)
            if results and results.hand_landmarks:
                num_hands = len(results.hand_landmarks)
                for idx in range(num_hands):
                    dropdown.addItem(f"Use Hand {idx}", idx)
                if num_hands > 1:
                    dropdown.addItem("Use Both Hands", 99)
            row_layout.addWidget(dropdown)
            camera_dropdowns[cam_id] = dropdown

            self.camera_state[(obs_id, cam_id)] = {
                "path": state_data["video_path"],
                "frame_idx": 0,
                "fps": state_data["fps"],
                "preview_label": preview_label,
                "dropdown": dropdown,
                "initial_wrists": state_data["initial_wrists"],
                "initial_centroids": centroids,
            }

            skip_btn = QPushButton("Skip 1s → Rescan")
            # Default arguments bind this iteration's ids; `_` absorbs clicked(bool).
            skip_btn.clicked.connect(lambda _, o=obs_id, c=cam_id: self.skip_and_rescan(o, c))
            row_layout.addWidget(skip_btn)

            grouplayout.addWidget(row)

        self.bubble_layout.addWidget(group)
|
||
|
||
#self.bubble_layout.addStretch()
|
||
|
||
|
||
|
||
def build_individual_preview_grid(self, previews_data):
    """Add the preview group for a single individually-loaded video.

    Individual-file mode has no BORIS observations. The participant and the
    camera are both keyed as ``1``, in ``self.selection_widgets`` and in
    ``self.camera_state[(1, 1)]``.

    The group is always added with its participant dropdown. The camera row
    (annotated preview, hand dropdown, "Skip 1s → Rescan" button) is added only
    when the worker produced a readable preview frame. A missing entry no
    longer raises ``KeyError``, and a ``None`` frame no longer leaves the pixmap
    unbound.

    Args:
        previews_data: mapping built by the loader worker. Key ``(1, 1)``, if
            present, is read for ``frame_rgb``, ``results``, ``video_path``,
            ``fps`` and ``initial_wrists``.
    """
    obs_id, cam_id = 1, 1

    group = QGroupBox(f"Participant / Observation: {obs_id}")
    grouplayout = QVBoxLayout(group)

    # Participant-level dropdown; defaults to "skip".
    participant_dropdown = QComboBox()
    participant_dropdown.addItem("Skip this participant", 0)
    participant_dropdown.addItem("Process this participant", 1)
    participant_dropdown.setCurrentIndex(0)
    self.selection_widgets.setdefault("participant_selection", {})[obs_id] = participant_dropdown
    grouplayout.addWidget(QLabel("Participant Option:"))
    grouplayout.addWidget(participant_dropdown)

    camera_dropdowns = self.selection_widgets.setdefault(obs_id, {})

    state_data = previews_data.get((obs_id, cam_id))
    if state_data is not None and state_data["frame_rgb"] is not None:
        results = state_data["results"]

        # draw_hand_overlay returns (annotated_image, centroids); the centroids
        # become skip_and_rescan's tracking anchors.
        display_img, centroids = self.draw_hand_overlay(state_data["frame_rgb"], results)

        h, w, _ = display_img.shape
        qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888)
        pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio)

        # -------- UI Row --------
        row = QWidget()
        row_layout = QHBoxLayout(row)

        preview_label = QLabel()
        preview_label.setPixmap(pix)
        row_layout.addWidget(preview_label)

        # Dropdown data: -1 = skip camera, 0..n-1 = detection index,
        # 99 = special flag for "both hands".
        dropdown = QComboBox()
        dropdown.addItem("Skip this camera", -1)
        if results and results.hand_landmarks:
            num_hands = len(results.hand_landmarks)
            for idx in range(num_hands):
                dropdown.addItem(f"Use Hand {idx}", idx)
            if num_hands > 1:
                dropdown.addItem("Use Both Hands", 99)
        row_layout.addWidget(dropdown)
        camera_dropdowns[cam_id] = dropdown

        self.camera_state[(obs_id, cam_id)] = {
            "path": state_data["video_path"],
            "frame_idx": 0,
            "fps": state_data["fps"],
            "preview_label": preview_label,
            "dropdown": dropdown,
            "initial_wrists": state_data["initial_wrists"],
            "initial_centroids": centroids,
        }

        skip_btn = QPushButton("Skip 1s → Rescan")
        # `_` absorbs the clicked(bool) argument.
        skip_btn.clicked.connect(lambda _, o=obs_id, c=cam_id: self.skip_and_rescan(o, c))
        row_layout.addWidget(skip_btn)

        grouplayout.addWidget(row)

    self.bubble_layout.addWidget(group)
|
||
|
||
#self.bubble_layout.addStretch()
|
||
|
||
|
||
# ============================================
|
||
# FRAME EXTRACTION + MEDIA PIPE DETECTION
|
||
# ============================================
|
||
def extract_frame_and_hands(self, video_path, frame_idx):
    """Read one frame from a video and run MediaPipe hand detection on it.

    Args:
        video_path: path of the video opened with ``cv2.VideoCapture``.
        frame_idx: frame index to seek to (``CAP_PROP_POS_FRAMES``) before reading.

    Returns:
        ``(rgb_frame, results)``. ``rgb_frame`` is the frame converted from
        OpenCV's BGR order to RGB. ``results`` is the return value of
        ``self.mp_hands.detect`` for that frame. Returns ``(None, None)`` when
        the frame cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
    finally:
        # Release the capture even if seeking or reading raises.
        cap.release()

    if not ret:
        return None, None

    # OpenCV decodes to BGR. Convert once and use the same RGB array both for
    # detection and as the returned preview frame.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)

    # The hand landmarker is used through .detect() (single-image inference).
    results = self.mp_hands.detect(mp_image)
    return rgb_frame, results
|
||
|
||
|
||
# ============================================
|
||
# DRAW OVERLAYED HANDS + BIG LABELS
|
||
# ============================================
|
||
def draw_hand_overlay(self, img, results, last_known_centroids=None):
    """Draw hand skeletons and large hand-ID labels onto a copy of ``img``.

    IDs are "sticky". Each previously known ID in ``last_known_centroids``, in
    dict order, claims the nearest still-unclaimed detection whose centroid is
    within a squared normalized distance of 0.1. Every remaining detection
    gets the smallest integer ID not already assigned or previously known.

    Args:
        img: image array indexed as (height, width, channels); it is copied,
            never modified.
        results: MediaPipe hand-landmarker result (``.hand_landmarks`` is a
            list of per-hand landmark lists in normalized coordinates), or None.
        last_known_centroids: optional ``{hand_id: (x, y)}`` anchors.

    Returns:
        ``(annotated_copy, centroids)``. ``centroids`` is a copy of
        ``last_known_centroids``, or ``{}``, updated with the centroid of every
        hand drawn in this frame.
    """
    canvas = img.copy()
    height, width, _ = canvas.shape
    palette = ((0, 255, 0), (255, 0, 255), (0, 255, 255))
    drawn_centroids = {}

    if results and results.hand_landmarks:
        detections = results.hand_landmarks
        assignment = {}  # hand_id -> index into detections
        claimed = set()
        previous = last_known_centroids or {}

        # Tracking: every known ID claims its closest free detection.
        for hand_id, last_pos in previous.items():
            best_idx, best_dist = None, float('inf')
            for idx, landmarks in enumerate(detections):
                if idx in claimed:
                    continue
                cx, cy = self.get_centroid(landmarks)
                dist = (cx - last_pos[0]) ** 2 + (cy - last_pos[1]) ** 2
                if dist < best_dist:
                    best_idx, best_dist = idx, dist
            # Threshold keeps an ID from "teleporting" to a far-away hand.
            if best_idx is not None and best_dist < 0.1:
                assignment[hand_id] = best_idx
                claimed.add(best_idx)

        # Discovery: unmatched detections get the lowest unused ID.
        reserved = set(previous)
        for idx in range(len(detections)):
            if idx in claimed:
                continue
            new_id = 0
            while new_id in assignment or new_id in reserved:
                new_id += 1
            assignment[new_id] = idx
            claimed.add(idx)

        # The 21-point hand topology, written out by hand (same edges as
        # MediaPipe's HAND_CONNECTIONS).
        bones = (
            (0, 1), (1, 2), (2, 3), (3, 4),          # thumb
            (0, 5), (5, 6), (6, 7), (7, 8),          # index
            (9, 10), (10, 11), (11, 12),             # middle
            (13, 14), (14, 15), (15, 16),            # ring
            (0, 17), (17, 18), (18, 19), (19, 20),   # pinky
            (5, 9), (9, 13), (13, 17),               # palm
        )

        for hand_id, idx in assignment.items():
            landmarks = detections[idx]
            colour = palette[hand_id % len(palette)]
            drawn_centroids[hand_id] = self.get_centroid(landmarks)

            pixels = [(int(lm.x * width), int(lm.y * height)) for lm in landmarks]
            for a, b in bones:
                cv2.line(canvas, pixels[a], pixels[b], colour, 2)
            for point in pixels:
                cv2.circle(canvas, point, 5, (255, 255, 255), -1)

            # Large ID label above the wrist: black outline, then coloured fill.
            wrist_x, wrist_y = pixels[0]
            anchor = (wrist_x, wrist_y - 40)
            text = str(hand_id)
            cv2.putText(canvas, text, anchor, cv2.FONT_HERSHEY_SIMPLEX, 3.2,
                        (0, 0, 0), 10, cv2.LINE_AA)
            cv2.putText(canvas, text, anchor, cv2.FONT_HERSHEY_SIMPLEX, 3.2,
                        colour, 6, cv2.LINE_AA)

    merged = last_known_centroids.copy() if last_known_centroids else {}
    merged.update(drawn_centroids)
    return canvas, merged
|
||
|
||
|
||
# ============================================
|
||
# SKIP 1s AND RESCAN FRAME
|
||
# ============================================
|
||
def skip_and_rescan(self, obs_id, cam_id):
    """Advance one camera's preview by about one second and redraw it.

    Every frame in the skipped span is run through hand detection so that hand
    IDs can be carried across the gap with ``update_tracking_only``. The frame
    reached afterwards is annotated with ``draw_hand_overlay`` and shown in the
    camera's preview label, and the camera's hand dropdown is rebuilt.

    Reads and updates these keys of ``self.camera_state[(obs_id, cam_id)]``:

    - ``path``, ``fps``
    - ``frame_idx``: first frame of the next skip
    - ``initial_centroids``: tracking anchors ``{hand_id: (x, y)}``
    - ``preview_label``, ``dropdown``

    If the video ends before the destination frame, nothing is updated.
    """
    state = self.camera_state[(obs_id, cam_id)]
    start_frame = state["frame_idx"]

    # fps may be missing, None or 0 in the stored state. int(None) would raise
    # and 0 would turn the "skip" into a no-op, so fall back to 60 frames.
    try:
        fps = int(state.get("fps") or 60)
    except (TypeError, ValueError):
        fps = 60
    if fps <= 0:
        fps = 60
    end_frame = start_frame + fps

    # Tracking anchors from the previous preview/skip, normalised to a dict.
    last_centroids = state.get("initial_centroids") or {}
    if isinstance(last_centroids, list):
        last_centroids = dict(enumerate(last_centroids))

    cap = cv2.VideoCapture(state["path"])
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        # Bridge the gap: detect (without drawing) on every skipped frame so
        # the IDs stay locked to the moving hands.
        for _ in range(fps):
            ok, frame = cap.read()
            if not ok:
                break
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)
            last_centroids = self.update_tracking_only(self.mp_hands.detect(mp_image), last_centroids)
        # After the skipped frames have been read, the capture sits on end_frame.
        ret, final_frame = cap.read()
    finally:
        cap.release()

    if not ret:
        return

    rgb_final = cv2.cvtColor(final_frame, cv2.COLOR_BGR2RGB)
    mp_final = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_final)
    final_results = self.mp_hands.detect(mp_final)

    # Draw the destination frame using the bridged anchors.
    display_img, final_centroids = self.draw_hand_overlay(rgb_final, final_results, last_centroids)

    # end_frame has been consumed; the next skip starts after it.
    state["frame_idx"] = end_frame + 1
    state["initial_centroids"] = final_centroids

    h, w, _ = display_img.shape
    qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888)
    pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio)
    state["preview_label"].setPixmap(pix)

    # NOTE(review): the dropdown lists MediaPipe detection indices, while the
    # overlay labels show tracked hand IDs; after tracking these may differ.
    dropdown = state["dropdown"]
    dropdown.clear()
    dropdown.addItem("Skip this camera", -1)
    if final_results and final_results.hand_landmarks:
        num_hands = len(final_results.hand_landmarks)
        for idx in range(num_hands):
            dropdown.addItem(f"Use Hand {idx}", idx)
        if num_hands > 1:
            dropdown.addItem("Use Both Hands", 99)  # 99 = special "both" flag
|
||
|
||
def get_centroid(self, lm_list):
    """Return the mean ``(x, y)`` of a hand's landmarks.

    Args:
        lm_list: sized sequence of landmarks exposing ``.x`` and ``.y``
            (all 21 landmarks of one hand).

    Returns:
        ``(mean_x, mean_y)``. An empty sequence raises ``ZeroDivisionError``.
    """
    count = len(lm_list)
    mean_x = sum(point.x for point in lm_list) / count
    mean_y = sum(point.y for point in lm_list) / count
    return mean_x, mean_y
|
||
|
||
def update_tracking_only(self, results, last_known):
    """Advance hand-ID tracking by one frame without drawing anything.

    Used while bridging skipped frames. The update has three steps:

    1. In dict order, each known ID claims the nearest unclaimed detection
       whose centroid is within a *squared* normalized distance of 0.05
       (about 0.22 in straight-line distance).
    2. IDs without a match keep their previous ("ghost") position.
    3. Unmatched detections receive the smallest ID not yet in the result.

    Args:
        results: MediaPipe hand-landmarker result, or None.
        last_known: ``{hand_id: (x, y)}`` from the previous frame.

    Returns:
        The updated ``{hand_id: (x, y)}`` mapping. When nothing is detected,
        ``last_known`` itself is returned untouched.
    """
    if not (results and results.hand_landmarks):
        return last_known

    hands = results.hand_landmarks
    tracked = {}
    taken = set()

    for hand_id, prev in last_known.items():
        best = None
        best_dist = float('inf')
        for idx, landmarks in enumerate(hands):
            if idx in taken:
                continue
            centre = self.get_centroid(landmarks)
            dist = (centre[0] - prev[0]) ** 2 + (centre[1] - prev[1]) ** 2
            if dist < 0.05 and dist < best_dist:
                best, best_dist = idx, dist
        if best is not None:
            tracked[hand_id] = self.get_centroid(hands[best])
            taken.add(best)

    # Ghosts: IDs not seen this frame keep their last position.
    for hand_id, prev in last_known.items():
        tracked.setdefault(hand_id, prev)

    # Newcomers: detections nobody claimed get the lowest free ID.
    for idx, landmarks in enumerate(hands):
        if idx in taken:
            continue
        new_id = 0
        while new_id in tracked:
            new_id += 1
        tracked[new_id] = self.get_centroid(landmarks)

    return tracked
|
||
|
||
|
||
# def save_project(self, onCrash=False):
|
||
|
||
# if not onCrash:
|
||
# filename, _ = QFileDialog.getSaveFileName(
|
||
# self, "Save Project", "", "SPARK Project (*.spark)"
|
||
# )
|
||
# if not filename:
|
||
# return
|
||
# else:
|
||
# if PLATFORM_NAME == "darwin":
|
||
# filename = os.path.join(os.path.dirname(sys.executable), "../../../sparks_autosave.spark")
|
||
# else:
|
||
# filename = os.path.join(os.getcwd(), "sparks_autosave.spark")
|
||
|
||
# try:
|
||
# # Ensure the filename has the proper extension
|
||
# if not filename.endswith(".spark"):
|
||
# filename += ".spark"
|
||
|
||
# project_path = Path(filename).resolve()
|
||
# project_dir = project_path.parent
|
||
|
||
# file_list = [
|
||
# str(PurePosixPath(Path(bubble.file_path).resolve().relative_to(project_dir)))
|
||
# for bubble in self.bubble_widgets.values()
|
||
# ]
|
||
|
||
# progress_states = {
|
||
# str(PurePosixPath(Path(bubble.file_path).resolve().relative_to(project_dir))): bubble.current_step
|
||
# for bubble in self.bubble_widgets.values()
|
||
# }
|
||
|
||
# project_data = {
|
||
# "file_list": file_list,
|
||
# "progress_states": progress_states,
|
||
# "raw_haemo_dict": self.raw_haemo_dict,
|
||
# "epochs_dict": self.epochs_dict,
|
||
# "fig_bytes_dict": self.fig_bytes_dict,
|
||
# "cha_dict": self.cha_dict,
|
||
# "contrast_results_dict": self.contrast_results_dict,
|
||
# "df_ind_dict": self.df_ind_dict,
|
||
# "design_matrix_dict": self.design_matrix_dict,
|
||
# "age_dict": self.age_dict,
|
||
# "gender_dict": self.gender_dict,
|
||
# "group_dict": self.group_dict,
|
||
# "valid_dict": self.valid_dict,
|
||
# }
|
||
|
||
# def sanitize(obj):
|
||
# if isinstance(obj, Path):
|
||
# return str(PurePosixPath(obj))
|
||
# elif isinstance(obj, dict):
|
||
# return {sanitize(k): sanitize(v) for k, v in obj.items()}
|
||
# elif isinstance(obj, list):
|
||
# return [sanitize(i) for i in obj]
|
||
# return obj
|
||
|
||
# project_data = sanitize(project_data)
|
||
|
||
# with open(filename, "wb") as f:
|
||
# pickle.dump(project_data, f)
|
||
|
||
# QMessageBox.information(self, "Success", f"Project saved to:\n{filename}")
|
||
|
||
# except Exception as e:
|
||
# if not onCrash:
|
||
# QMessageBox.critical(self, "Error", f"Failed to save project:\n{e}")
|
||
|
||
|
||
|
||
# def load_project(self):
|
||
# filename, _ = QFileDialog.getOpenFileName(
|
||
# self, "Load Project", "", "SPARK Project (*.spark)"
|
||
# )
|
||
# if not filename:
|
||
# return
|
||
|
||
# try:
|
||
# with open(filename, "rb") as f:
|
||
# data = pickle.load(f)
|
||
|
||
# self.raw_haemo_dict = data.get("raw_haemo_dict", {})
|
||
# self.epochs_dict = data.get("epochs_dict", {})
|
||
# self.fig_bytes_dict = data.get("fig_bytes_dict", {})
|
||
# self.cha_dict = data.get("cha_dict", {})
|
||
# self.contrast_results_dict = data.get("contrast_results_dict", {})
|
||
# self.df_ind_dict = data.get("df_ind_dict", {})
|
||
# self.design_matrix_dict = data.get("design_matrix_dict", {})
|
||
# self.age_dict = data.get("age_dict", {})
|
||
# self.gender_dict = data.get("gender_dict", {})
|
||
# self.group_dict = data.get("group_dict", {})
|
||
# self.valid_dict = data.get("valid_dict", {})
|
||
|
||
# project_dir = Path(filename).parent
|
||
|
||
# # Convert saved relative paths to absolute paths
|
||
# file_list = [str((project_dir / Path(rel_path)).resolve()) for rel_path in data["file_list"]]
|
||
|
||
# # Also resolve progress_states with updated paths
|
||
# raw_progress = data.get("progress_states", {})
|
||
# progress_states = {
|
||
# str((project_dir / Path(rel_path)).resolve()): step
|
||
# for rel_path, step in raw_progress.items()
|
||
# }
|
||
|
||
# self.show_files_as_bubbles_from_list(file_list, progress_states, filename)
|
||
|
||
# # Re-enable buttons
|
||
# # self.button1.setVisible(True)
|
||
# self.button3.setVisible(True)
|
||
|
||
# QMessageBox.information(self, "Loaded", f"Project loaded from:\n{filename}")
|
||
|
||
# except Exception as e:
|
||
# QMessageBox.critical(self, "Error", f"Failed to load project:\n{e}")
|
||
|
||
|
||
def placeholder(self):
    """Stand-in slot for unfinished features: shows an informational notice."""
    title, text = "Placeholder", "This feature is not implemented yet."
    QMessageBox.information(self, title, text)
|
||
|
||
def save_metadata(self, file_path):
    """Snapshot the text of every metadata field under ``file_path``.

    Stores ``{field_key: field.text()}`` in ``self.file_metadata[file_path]``,
    replacing any earlier snapshot. A falsy ``file_path`` is ignored.
    """
    if file_path:
        snapshot = {}
        for name, widget in self.meta_fields.items():
            snapshot[name] = widget.text()
        self.file_metadata[file_path] = snapshot
|
||
|
||
def get_all_metadata(self):
    """Return ``self.file_metadata`` after saving the current file's fields.

    Focus is first cleared from every metadata field. Then, if a current file
    is set, its field values are stored through ``save_metadata``.
    """
    for widget in self.meta_fields.values():
        widget.clearFocus()

    current = self.current_file
    if current:
        self.save_metadata(current)
    return self.file_metadata
|
||
|
||
|
||
def cancel_task(self):
    """Cancel a pipeline running in ``self.result_process`` and restore button1.

    The whole process tree is killed, the result-polling timer is stopped,
    and button1 is rewired from "Cancel" back to ``on_run_task``.

    NOTE(review): this class appears to define ``cancel_task`` more than once;
    Python keeps only the last definition, so this body may never run —
    confirm which behaviour callers expect.
    """
    self.button1.clicked.disconnect(self.cancel_task)
    self.button1.setText("Stopping...")

    process = getattr(self, "result_process", None)
    if process is not None and process.is_alive():
        # Kill children first: the pipeline may have spawned its own workers.
        try:
            children = psutil.Process(process.pid).children(recursive=True)
        except psutil.NoSuchProcess:
            children = []  # exited between is_alive() and here
        for child in children:
            try:
                child.kill()
            except psutil.NoSuchProcess:
                pass
        process.terminate()
        process.join()

    timer = getattr(self, "result_timer", None)
    if timer is not None and timer.isActive():
        timer.stop()

    # QMainWindow exposes the status bar via statusBar(), as used elsewhere in
    # this class (the original referenced a `statusbar` attribute).
    self.statusBar().showMessage("Processing cancelled.")
    self.button1.clicked.connect(self.on_run_task)
    self.button1.setText("Process")
|
||
|
||
|
||
'''MODULE FILE'''
|
||
def on_run_task(self):
    """Start hand-tracking processing for every camera the user selected.

    Steps:

    1. Collect each (observation, camera) whose hand dropdown is not
       "Skip this camera".
    2. Ask for an output directory.
    3. Replace the preview grid with one progress row per selected video
       (label 50% | progress bar 40% | time 10%).
    4. Start one ``ParticipantProcessor2`` thread per selected video.

    Validation happens before anything is torn down. If nothing is selected or
    the directory dialog is cancelled, the previews and the "Process" button
    are left intact so the user can adjust and retry. Only once processing
    really starts is button1 switched to "Cancel" (wired to ``cancel_task``).
    """
    # Read the selections (and the per-camera state they need) first, while
    # the preview widgets still exist.
    # NOTE(review): the participant-level "Skip/Process" dropdowns are never
    # consulted here; only the per-camera dropdowns decide what runs.
    selected = []
    for obs_id, cam_dropdowns in self.selection_widgets.items():
        if obs_id == "participant_selection":
            continue  # participant-level dropdowns, not cameras
        for cam_id, dropdown in cam_dropdowns.items():
            hand_idx = dropdown.currentData()
            if hand_idx == -1:  # "Skip this camera"
                continue
            state_data = self.camera_state.get((obs_id, cam_id))
            if state_data:
                selected.append({
                    "obs_id": obs_id,
                    "cam_id": cam_id,
                    "hand_idx": hand_idx,
                    "path": state_data.get("path", "Unknown"),
                    "initial_wrists": state_data.get("initial_wrists"),
                })

    if not selected:
        self.statusBar().showMessage("No files selected for processing (use dropdowns).")
        return

    self.output_dir = self.get_output_directory()
    if not self.output_dir:
        self.statusBar().showMessage("Processing cancelled by user.")
        return

    # Committed: swap the button to "Cancel" and replace previews with progress rows.
    self.button1.clicked.disconnect(self.on_run_task)
    self.button1.setText("Cancel")
    self.button1.clicked.connect(self.cancel_task)

    self.clear_bubble_widgets()
    self.processing_widgets = {}
    self.processing_threads = []

    self.statusBar().showMessage(f"Starting to process {len(selected)} files...")
    QApplication.processEvents()  # refresh the GUI before starting threads

    for info in selected:
        obs_id = info["obs_id"]
        cam_id = info["cam_id"]

        row = QWidget()
        row_layout = QHBoxLayout(row)
        row_layout.setContentsMargins(0, 0, 0, 0)

        # Left (50%): which participant/camera/file this row tracks.
        filename = os.path.basename(info["path"])
        filename_label = QLabel(f"P {obs_id}/{cam_id}: {filename}")
        filename_label.setWordWrap(True)
        row_layout.addWidget(filename_label, stretch=5)

        # Middle (40%): percentage progress.
        progress_bar = QProgressBar()
        progress_bar.setRange(0, 100)
        row_layout.addWidget(progress_bar, stretch=4)

        # Right (10%): time text pushed by the processor.
        time_label = QLabel("00:00/00:00")
        time_label.setAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter)
        row_layout.addWidget(time_label, stretch=1)

        self.bubble_layout.addWidget(row, self.bubble_layout.rowCount(), 0, 1, 1)

        # Keys are stringified; on_processing_finished looks rows up by the
        # (obs_id, cam_id) the processor emits — presumably strings; confirm.
        self.processing_widgets[(str(obs_id), str(cam_id))] = {
            "label": filename_label,
            "progress": progress_bar,
            "time_label": time_label,
        }

        processor = ParticipantProcessor2(
            obs_id=obs_id,
            selected_cam_id=cam_id,
            selected_hand_idx=info["hand_idx"],
            video_path=info["path"],
            hide_preview=False,
            output_csv=f"{obs_id}_{cam_id}_processed.csv",
            output_dir=self.output_dir,
            initial_wrists=info["initial_wrists"],
        )
        processor.progress_updated.connect(progress_bar.setValue)
        processor.finished_processing.connect(self.on_processing_finished)
        processor.time_updated.connect(time_label.setText)
        self.processing_threads.append(processor)
        processor.start()

    # A stretching empty row after the last one keeps the progress rows at the top.
    self.bubble_layout.setRowStretch(self.bubble_layout.rowCount(), 1)
|
||
|
||
|
||
def check_for_pipeline_results(self):
    """Polling hook for pipeline results; currently a no-op returning None."""
    return None
|
||
|
||
def on_processing_finished(self, obs_id, cam_id):
    """Mark one processor's row complete; restore the UI once all are done.

    The row keyed ``(obs_id, cam_id)`` in ``self.processing_widgets`` gets a
    full progress bar, a "COMPLETE" label and a ``"done"`` flag.

    Completion is detected two ways:

    - every known row is flagged done, or
    - no processor thread is still running (the original check).

    When either holds, ``cancel_task`` reverts the button and the status bar
    announces completion. The row-flag check is needed because a thread that
    emits this signal from its own run loop may still report ``isRunning()``
    by the time this slot executes. Without it, the UI could stay in its
    "Cancel" state after the last video finishes.
    """
    widgets = self.processing_widgets.get((obs_id, cam_id))
    if widgets is not None:
        widgets["progress"].setValue(100)
        widgets["label"].setText(f"P {obs_id}/{cam_id}: ✅ COMPLETE")
        widgets["done"] = True

    rows = self.processing_widgets.values()
    all_rows_done = bool(rows) and all(w.get("done") for w in rows)
    if all_rows_done or not any(t.isRunning() for t in self.processing_threads):
        self.cancel_task()  # reverts the button and cleans up the threads
        self.statusBar().showMessage("All processing complete!")
|
||
|
||
|
||
def cancel_task(self):
    """Stop whatever background work is active and revert the UI.

    There are two cases:

    - A file-loading worker (``self.worker_thread``) is running: it is stopped,
      the progress dialog is rejected, and loading is reported as cancelled.
    - Otherwise, all running processing threads are stopped and waited for,
      and button1 is rewired from "Cancel" back to ``on_run_task``.

    Missing attributes are treated as "nothing to cancel".

    NOTE(review): this class appears to define ``cancel_task`` several times,
    and Python keeps only the last definition. This body therefore also
    covers the file-loading cancel path; confirm that no other definition
    follows later in the class.
    """
    loader = getattr(self, "worker_thread", None)
    if loader is not None and loader.isRunning():
        loader.stop()  # ask the loader to stop processing
        loader.wait()  # block until it has exited
        dialog = getattr(self, "progress_dialog", None)
        if dialog:
            dialog.reject()
        self.statusBar().showMessage("File loading cancelled.")
        return

    threads = getattr(self, "processing_threads", None)
    if not threads:
        return

    self.statusBar().showMessage("Cancelling processing... Please wait.")

    for thread in threads:
        if thread.isRunning():
            thread.stop()  # ask the processor to leave its loop
            thread.wait()  # block until it has exited
    threads.clear()

    self.button1.clicked.disconnect(self.cancel_task)
    self.button1.setText("Process")
    self.button1.clicked.connect(self.on_run_task)

    # Progress rows are intentionally left on screen.
    self.statusBar().showMessage("Processing cancelled. Ready to load or process.")
|
||
|
||
def show_error_popup(self, title, error_message, traceback_str=""):
    """Show a non-blocking-for-other-files warning about one failed file.

    Args:
        title: name of the file that failed (inserted into the message body).
        error_message: error text shown to the user.
        traceback_str: optional traceback, shown via the dialog's
            "Show Details..." area when non-empty.
    """
    import html  # local import: only needed to escape the rich-text body

    msgbox = QMessageBox(self)
    msgbox.setIcon(QMessageBox.Warning)
    msgbox.setWindowTitle("Warning - SPARKS")

    # The body is rendered as RichText, so escape the dynamic parts. Otherwise
    # text such as "<class 'ValueError'>" or "a & b" is swallowed or mangled.
    safe_title = html.escape(str(title))
    safe_error = html.escape(str(error_message))
    message = (
        f"SPARKS has encountered an error processing the file {safe_title}.<br><br>"
        "This error was likely due to incorrect parameters on the right side of the screen and not an error with your data. "
        "Processing of the remaining files continues in the background and this participant will be ignored in the analysis. "
        "If you think the parameters on the right side are correct for your data, raise an issue <a href='https://git.research.dezeeuw.ca/tyler/sparks/issues'>here</a>.<br><br>"
        f"Error message: {safe_error}"
    )

    msgbox.setTextFormat(Qt.TextFormat.RichText)
    msgbox.setText(message)
    msgbox.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)

    # Detailed text is plain text, so the traceback needs no escaping.
    if traceback_str:
        msgbox.setDetailedText(traceback_str)

    msgbox.setStandardButtons(QMessageBox.Ok)
    # exec() replaces the deprecated exec_() and matches the other dialogs here.
    msgbox.exec()
|
||
|
||
|
||
def cleanup_after_process(self):
    """Reap the result process and release the IPC objects it used.

    Each step below is skipped if the corresponding attribute does not exist.

    - ``result_process``: joined without blocking; if still alive, it is
      terminated and joined.
    - ``result_queue`` / ``progress_queue``: closed and their feeder threads
      joined. Objects whose repr contains "AutoProxy" (manager proxies) are
      left alone.
    - ``manager``: shut down, stopping its server process.
    """
    if hasattr(self, 'result_process'):
        proc = self.result_process
        proc.join(timeout=0)
        if proc.is_alive():
            proc.terminate()
            proc.join()

    for attr_name in ('result_queue', 'progress_queue'):
        if not hasattr(self, attr_name):
            continue
        queue = getattr(self, attr_name)
        if 'AutoProxy' not in repr(queue):
            queue.close()
            queue.join_thread()

    if hasattr(self, 'manager'):
        self.manager.shutdown()
|
||
|
||
|
||
|
||
|
||
'''UPDATER'''
|
||
def manual_check_for_updates(self):
    """Begin an update check, looking for an already-downloaded update first.

    A LocalPendingUpdateCheckThread (stored as ``self.local_check_thread``)
    reports through ``on_pending_update_found`` or ``on_no_pending_update``.
    """
    checker = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix)
    self.local_check_thread = checker
    checker.pending_update_found.connect(self.on_pending_update_found)
    checker.no_pending_update.connect(self.on_no_pending_update)
    checker.start()
|
||
|
||
def on_pending_update_found(self, version, folder_path):
    """Remember a locally pending update and ask the user whether to install it."""
    status = f"Pending update found: version {version}"
    self.statusBar().showMessage(status)
    self.pending_update_version, self.pending_update_path = version, folder_path
    self.show_pending_update_popup()
|
||
|
||
def on_no_pending_update(self):
    """No downloaded update is waiting locally, so query the server instead."""
    status = "No pending local update found. Checking server..."
    self.statusBar().showMessage(status)
    self.start_update_check_thread()
|
||
|
||
def show_pending_update_popup(self):
    """Offer to install the previously downloaded update.

    "Install Now" hands ``self.pending_update_path`` to ``install_update``.
    Any other choice leaves the update pending and still starts a server
    check for newer releases.
    """
    dialog = QMessageBox(self)
    dialog.setWindowTitle("Pending Update Found")
    dialog.setText(
        f"A previously downloaded update (version {self.pending_update_version}) is available at:\n"
        f"{self.pending_update_path}\nWould you like to install it now?"
    )
    install_now = dialog.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole)
    dialog.addButton("Install Later", QMessageBox.ButtonRole.RejectRole)
    dialog.exec()

    if dialog.clickedButton() == install_now:
        self.install_update(self.pending_update_path)
        return

    self.statusBar().showMessage("Pending update available. Install later.")
    # Even when postponing, check the server for something newer.
    self.start_update_check_thread()
|
||
|
||
def start_update_check_thread(self):
    """Query the release server for a newer version on a background thread.

    Outcomes are routed to ``on_server_update_requested``,
    ``on_server_no_update`` or ``on_error``. The thread is kept in
    ``self.check_thread``.
    """
    checker = UpdateCheckThread()
    self.check_thread = checker
    wiring = (
        (checker.download_requested, self.on_server_update_requested),
        (checker.no_update_available, self.on_server_no_update),
        (checker.error_occurred, self.on_error),
    )
    for signal, slot in wiring:
        signal.connect(slot)
    checker.start()
|
||
|
||
def on_server_no_update(self):
    """Tell the user, for 5 seconds, that the server has no newer release."""
    status_bar = self.statusBar()
    status_bar.showMessage("No new updates found on server.", 5000)
def on_server_update_requested(self, download_url, latest_version):
    """
    React to the server reporting a release (``UpdateCheckThread.download_requested``).

    If no update is pending (nothing downloaded yet), the release is
    downloaded right away. Otherwise the server version is compared with the
    pending one via ``self.version_compare``:

    * server newer: the old pending folder is deleted (a failure is only
      reported), the pending details are cleared, and the release is
      downloaded;
    * same version: nothing to download;
    * server older: the pending update is kept.

    Args:
        download_url (str): Where the release archive can be downloaded.
        latest_version (str): Version of the release on the server.
    """
    if not self.pending_update_version:
        # Nothing downloaded yet: fetch the server release.
        self.download_update(download_url, latest_version)
        return

    ordering = self.version_compare(latest_version, self.pending_update_version)

    if ordering == 0:
        self.statusBar().showMessage(f"Pending update version {self.pending_update_version} is already latest. No download needed.")
        return

    if ordering < 0:
        # Unusual, but the local copy is newer than the server's: keep it.
        self.statusBar().showMessage(f"Pending update version {self.pending_update_version} is newer than server version. No action.")
        return

    # The server has something newer: discard the stale local copy first.
    self.statusBar().showMessage(f"Newer version {latest_version} available on server. Removing old pending update...")
    stale_folder = self.pending_update_path
    try:
        shutil.rmtree(stale_folder)
        self.statusBar().showMessage(f"Deleted old update folder: {stale_folder}")
    except Exception as err:
        self.statusBar().showMessage(f"Failed to delete old update folder: {err}")

    # Forget the old pending update so the new download proceeds.
    self.pending_update_version = None
    self.pending_update_path = None
    self.download_update(download_url, latest_version)
def download_update(self, download_url, latest_version):
    """
    Download and extract an update in a background ``UpdateDownloadThread``.

    Completion is reported through ``update_ready`` (handled by
    ``on_update_ready``) and failures through ``error_occurred`` (handled by
    ``on_error``).

    A request made while an earlier download is still running is ignored.
    The thread is only referenced through ``self.download_thread``. Rebinding
    it while the thread runs can let the running ``QThread`` be deleted, which
    Qt documents as a program crash. This assumes a parentless QThread
    subclass.

    Args:
        download_url (str): URL of the release archive.
        latest_version (str): Version being downloaded.
    """
    active = getattr(self, "download_thread", None)
    if active is not None and active.isRunning():
        self.statusBar().showMessage("An update download is already in progress.", 5000)
        return

    self.statusBar().showMessage("Downloading update...")
    self.download_thread = UpdateDownloadThread(download_url, latest_version)
    self.download_thread.update_ready.connect(self.on_update_ready)
    self.download_thread.error_occurred.connect(self.on_error)
    self.download_thread.start()
def on_update_ready(self, latest_version, extract_folder):
    """
    Offer to install an update that has just been downloaded and extracted.

    Connected to ``UpdateDownloadThread.update_ready``. "Install Now" passes
    the folder to ``install_update``. Any other answer leaves a note on the
    status bar.

    Args:
        latest_version (str): Version that was downloaded.
        extract_folder (str): Folder the archive was extracted to.
    """
    self.statusBar().showMessage("Update downloaded and extracted.")

    prompt = QMessageBox(self)
    prompt.setWindowTitle("Update Ready")
    prompt.setText(
        f"Version {latest_version} has been downloaded and extracted to:\n"
        f"{extract_folder}\nWould you like to install it now?"
    )
    install_now = prompt.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole)
    prompt.addButton("Install Later", QMessageBox.ButtonRole.RejectRole)

    prompt.exec()

    wants_install = prompt.clickedButton() == install_now
    if wants_install:
        self.install_update(extract_folder)
    else:
        self.statusBar().showMessage("Update ready. Install later.")
def install_update(self, extract_folder):
    """
    Hand an extracted update to the external updater, then exit.

    The updater location depends on ``PLATFORM_NAME``:

    * windows: ``sparks_updater.exe`` in the current working directory;
    * linux: ``sparks_updater`` in the current working directory;
    * darwin: ``sparks_updater.app`` three levels above the executable's
      folder when frozen, otherwise one level above the working directory;
    * anything else: the working directory itself.

    The updater receives the extracted folder and this application's path
    (``sys.argv[0]``) so it can relaunch it. On a successful launch the
    process exits with status 0 so the updater can replace the files. A
    missing updater or a launch failure is shown in a critical message box.

    Args:
        extract_folder (str): Folder containing the extracted update.
    """
    def locate_updater():
        # Where the updater lives relative to this installation.
        if PLATFORM_NAME == 'windows':
            return os.path.join(os.getcwd(), "sparks_updater.exe")
        if PLATFORM_NAME == 'linux':
            return os.path.join(os.getcwd(), "sparks_updater")
        if PLATFORM_NAME != 'darwin':
            return os.getcwd()
        if getattr(sys, 'frozen', False):
            return os.path.join(os.path.dirname(sys.executable), "../../../sparks_updater.app")
        return os.path.join(os.getcwd(), "../sparks_updater.app")

    updater_path = locate_updater()

    if not os.path.exists(updater_path):
        QMessageBox.critical(self, "Error", f"Updater not found at:\n{updater_path}. The absolute path was {os.path.abspath(updater_path)}")
        return

    try:
        # The updater relaunches this executable once the files are replaced.
        main_app_executable = os.path.abspath(sys.argv[0])
        print(f'Launching updater with: "{updater_path}" "{extract_folder}" "{main_app_executable}"')

        if PLATFORM_NAME == 'darwin':
            # .app bundles are started via `open`; program arguments follow --args.
            launch_args = ['open', updater_path, '--args', extract_folder, main_app_executable]
            subprocess.Popen(launch_args)
        else:
            launch_args = [updater_path, f'{extract_folder}', f'{main_app_executable}']
            subprocess.Popen(launch_args, cwd=os.path.dirname(updater_path))

        # Close the current app so updater can replace files.
        # SystemExit is not an Exception subclass, so the handler below skips it.
        sys.exit(0)
    except Exception as e:
        QMessageBox.critical(self, "Error", f"[Updater Launch Failed]\n{str(e)}\n{traceback.format_exc()}")
def on_error(self, message):
    """
    Show an error reported by one of the update threads on the status bar.

    Connected to the ``error_occurred`` signal of the check and download threads.

    Args:
        message (str): Error description emitted by the thread.
    """
    text = f"Error occurred during update process. {message}"
    self.statusBar().showMessage(text)
def version_compare(self, v1, v2):
    """
    Compare two dotted version strings numerically.

    Components are compared as integers, so ``"1.10.0"`` is newer than
    ``"1.9.0"``. Surrounding whitespace and a leading ``v``/``V`` (as in a
    release tag like ``"v1.2.0"``) are ignored. The shorter version is padded
    with zeros, so ``"1.2"`` equals ``"1.2.0"``.

    Args:
        v1 (str): First version.
        v2 (str): Second version.

    Returns:
        int: 1 if v1 is newer than v2, -1 if it is older, 0 if they are equal.

    Raises:
        ValueError: If a component is not an integer (e.g. ``"1.0.0-beta"``).
    """
    def normalize(v):
        text = v.strip()
        if text[:1] in ("v", "V"):
            text = text[1:]
        return [int(part) for part in text.split(".")]

    a, b = normalize(v1), normalize(v2)
    # Pad so missing trailing components count as 0 ("1.2" == "1.2.0").
    width = max(len(a), len(b))
    a += [0] * (width - len(a))
    b += [0] * (width - len(b))
    return (a > b) - (a < b)
def closeEvent(self, event):
    """
    Clean up background resources when the main window closes.

    Steps, in order:

    1. Shut down ``self.manager`` if the window has one (presumably the
       multiprocessing ``Manager`` imported at the top of the file).
    2. Close every visible child widget.
    3. Kill all child processes of this process.
    4. Accept the close event.

    Args:
        event: The close event delivered by Qt.
    """
    print("Window is closing. Cleaning up...")

    # Gracefully shut down multiprocessing children
    if hasattr(self, 'manager'):
        self.manager.shutdown()

    # Lazy on purpose: visibility is re-checked right before each close,
    # because closing a widget can hide its own children.
    visible_children = (
        widget for widget in self.findChildren(QWidget)
        if widget is not self and widget.isVisible()
    )
    for widget in visible_children:
        widget.close()

    kill_child_processes()
    event.accept()
def wait_for_process_to_exit(process_name, timeout=10):
    """
    Wait until no running process has a name containing ``process_name``.

    The match is a case-insensitive substring test against each process name
    reported by ``psutil``. The process table is always checked at least once
    (even with ``timeout=0``). It is then polled every 0.5 seconds, never
    sleeping past the deadline. The deadline uses ``time.monotonic`` so
    wall-clock adjustments cannot shorten or extend the wait.

    Args:
        process_name (str): Name (or part of the name) of the process to wait for.
        timeout (int, optional): Maximum time to wait in seconds. Defaults to 10.

    Returns:
        bool: True if the process exited before the timeout, False otherwise.
    """
    print(f"Waiting for {process_name} to exit...")

    needle = process_name.lower()
    deadline = time.monotonic() + timeout

    while True:
        still_running = False
        for proc in psutil.process_iter(['name']):
            try:
                name = proc.info['name']
                if name and needle in name.lower():
                    still_running = True
                    print(f"Still running: {name} (PID: {proc.pid})")
                    break
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue

        if not still_running:
            print(f"{process_name} has exited.")
            return True

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(0.5, remaining))

    print(f"{process_name} did not exit in time.")
    return False
def finish_update_if_needed():
    """
    Complete a pending application update if '--finish-update' is present in the command-line arguments.

    The flag is presumably passed by the external updater when it relaunches
    the app. When present, this function:

    1. Finds the extracted update folder (``sparks-*-<platform>``). It looks in
       the current working directory, or in ``/tmp/sparkstempupdate`` on macOS.
    2. Waits for any ``sparks_updater`` process to exit, killing it after
       5 seconds.
    3. Replaces the installed updater with the copy shipped in the update.
    4. Deletes the update folder (the whole staging folder on macOS) and tells
       the user the update is complete.

    Problems are printed rather than raised (stdout is the log file), so a
    half-finished cleanup cannot stop the application from starting. If the
    staging folder is missing or holds no update, the steps are skipped.
    """
    if "--finish-update" not in sys.argv:
        return

    print("Finishing update...")

    if PLATFORM_NAME == 'darwin':
        app_dir = '/tmp/sparkstempupdate'
    else:
        app_dir = os.getcwd()

    # 1. Find update folder (the staging folder itself may not exist).
    try:
        entries = os.listdir(app_dir)
    except OSError as e:
        print(f"Could not read {app_dir}: {e}. Skipping update steps.")
        return

    update_folder = None
    for entry in entries:
        entry_path = os.path.join(app_dir, entry)
        if os.path.isdir(entry_path) and entry.startswith("sparks-") and entry.endswith("-" + PLATFORM_NAME):
            update_folder = entry_path
            break

    if update_folder is None:
        print("No update folder found. Skipping update steps.")
        return

    if PLATFORM_NAME == 'darwin':
        # The macOS update contains one more nested folder.
        update_folder = os.path.join(update_folder, "sparks-darwin")

    # 2. Wait for sparks_updater to exit
    print("Waiting for sparks_updater to exit...")
    for proc in psutil.process_iter(['pid', 'name']):
        name = proc.info['name']
        if not name or "sparks_updater" not in name.lower():
            continue
        try:
            proc.wait(timeout=5)
        except psutil.TimeoutExpired:
            print("Force killing lingering sparks_updater")
            try:
                proc.kill()
            except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
                print(f"Could not kill sparks_updater: {e}")
        except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
            # e.g. it exited after being listed, or it is not accessible.
            print(f"Could not wait for sparks_updater: {e}")

    # 3. Replace the updater
    if PLATFORM_NAME == 'windows':
        new_updater = os.path.join(update_folder, "sparks_updater.exe")
        dest_updater = os.path.join(app_dir, "sparks_updater.exe")
    elif PLATFORM_NAME == 'darwin':
        new_updater = os.path.join(update_folder, "sparks_updater.app")
        # Four levels up from the executable; presumably next to the main
        # .app bundle (executable inside <bundle>.app/Contents/MacOS) - confirm.
        dest_updater = os.path.abspath(os.path.join(sys.executable, "../../../../sparks_updater.app"))
    elif PLATFORM_NAME == 'linux':
        new_updater = os.path.join(update_folder, "sparks_updater")
        dest_updater = os.path.join(app_dir, "sparks_updater")
    else:
        print("No platform??")
        new_updater = os.getcwd()
        dest_updater = os.getcwd()

    print(f"New updater is {new_updater}")
    print(f"Dest updater is {dest_updater}")
    print("Writable?", os.access(dest_updater, os.W_OK))
    print("Executable path:", sys.executable)
    print("Trying to copy:", new_updater, "->", dest_updater)

    if os.path.exists(new_updater):
        try:
            if os.path.exists(dest_updater):
                if PLATFORM_NAME == 'darwin':
                    # On macOS the updater is an .app bundle (a directory).
                    try:
                        if os.path.isdir(dest_updater):
                            shutil.rmtree(dest_updater)
                            print(f"Deleted directory: {dest_updater}")
                        else:
                            os.remove(dest_updater)
                            print(f"Deleted file: {dest_updater}")
                    except Exception as e:
                        print(f"Error deleting {dest_updater}: {e}")
                else:
                    os.remove(dest_updater)

            if PLATFORM_NAME == 'darwin':
                wait_for_process_to_exit("sparks_updater", timeout=10)
                # ditto copies the whole bundle, preserving its metadata.
                subprocess.check_call(["ditto", new_updater, dest_updater])
            else:
                shutil.copy2(new_updater, dest_updater)

            if PLATFORM_NAME in ('linux', 'darwin'):
                os.chmod(dest_updater, 0o755)

            if PLATFORM_NAME == 'darwin':
                remove_quarantine(dest_updater)

            print("sparks_updater replaced.")
        except Exception as e:
            print(f"Failed to replace sparks_updater: {e}")

    # 4. Delete the update folder
    try:
        if PLATFORM_NAME == 'darwin':
            # Remove the whole temporary staging folder.
            shutil.rmtree(app_dir)
        else:
            shutil.rmtree(update_folder)
    except Exception as e:
        print(f"Failed to delete update folder: {e}")

    QMessageBox.information(None, "Update Complete", "The application has been successfully updated.")
    sys.argv.remove("--finish-update")
def remove_quarantine(app_path):
    """
    Removes the macOS quarantine attribute from the specified application path.

    Runs ``xattr -d -r com.apple.quarantine <app_path>`` through AppleScript's
    ``do shell script ... with administrator privileges``, so macOS shows an
    authentication prompt. This is a best-effort step. Failures are printed
    rather than raised, including:

    * the user cancelling the prompt;
    * xattr failing;
    * ``osascript`` not being available.

    The path is shell-quoted with ``shlex.quote``. The resulting command is
    then escaped for an AppleScript string literal (backslashes and double
    quotes), so unusual characters in the path cannot end the
    ``do shell script "..."`` string early.

    Args:
        app_path (str): Path of the application bundle to clean.
    """
    command = f"xattr -d -r com.apple.quarantine {shlex.quote(app_path)}"
    # Inside an AppleScript string literal, backslash and double quote must be escaped.
    applescript_command = command.replace("\\", "\\\\").replace('"', '\\"')
    script = (
        f'do shell script "{applescript_command}" with administrator privileges '
        'with prompt "SPARKS needs privileges to finish the update. (2/2)"'
    )

    try:
        subprocess.run(['osascript', '-e', script], check=True)
        print("✅ Quarantine attribute removed.")
    except (subprocess.CalledProcessError, OSError) as e:
        print("❌ Failed to remove quarantine attribute.")
        print(e)
def resource_path(relative_path):
    """
    Return the absolute path of a bundled resource.

    Works both when running from source and from a PyInstaller bundle:

    * from source, paths resolve next to this file;
    * from a bundle, paths resolve inside ``sys._MEIPASS``, the bundle's
      extraction directory.

    Args:
        relative_path (str): Resource path relative to the application root,
            e.g. ``"icons/main.ico"``.

    Returns:
        str: ``relative_path`` joined onto the base directory.
    """
    base_dir = sys._MEIPASS if hasattr(sys, '_MEIPASS') else os.path.dirname(os.path.abspath(__file__))
    return os.path.join(base_dir, relative_path)
def kill_child_processes():
    """
    Kill every descendant process of the current process.

    Descendants are killed outright (no graceful terminate). The function then
    waits up to 5 seconds for them to go away. A child that already exited is
    skipped. Any other error is printed rather than raised, so this is safe to
    call from shutdown and crash handlers.
    """
    try:
        descendants = psutil.Process(os.getpid()).children(recursive=True)
        for proc in descendants:
            try:
                proc.kill()
            except psutil.NoSuchProcess:
                continue  # already gone
        psutil.wait_procs(descendants, timeout=5)
    except Exception as err:
        print(f"Error killing child processes: {err}")
def exception_hook(exc_type, exc_value, exc_traceback):
    """
    Global ``sys.excepthook``: report an unhandled exception and exit.

    Steps:

    1. Print the formatted traceback (stdout is redirected to the log file in
       the ``__main__`` block).
    2. Kill child processes.
    3. Show the critical-error dialog.
    4. Exit with status 1.

    The dialog is best-effort. If showing it raises, that secondary error is
    printed too and the process still exits, instead of the failure escaping
    the hook and skipping ``sys.exit(1)``.
    """
    error_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
    print(error_msg)  # also print to console

    kill_child_processes()

    # A QApplication must exist before a message box can be shown.
    app = QApplication.instance()
    if app is None:
        app = QApplication(sys.argv)

    try:
        show_critical_error(error_msg)
    except Exception:
        print("Failed to show the critical error dialog:")
        print(traceback.format_exc())

    # Exit the app after user acknowledges
    sys.exit(1)
def show_critical_error(error_msg):
    """
    Show a modal dialog explaining that SPARKS hit an unrecoverable error.

    Two best-effort preparation steps run first:

    * the current log (``sparks.log``) is copied to ``sparks_error.log``,
      because the next launch deletes ``sparks.log``;
    * an autosave is attempted with ``window.save_project(True)``, where
      ``window`` is the module-level main window created in ``__main__``.

    On macOS these files are three directories above the executable's folder;
    elsewhere they are in the current working directory. A failure in either
    step is printed and the dialog is still shown. Such failures include a
    missing log, a crash before the main window existed, or a failing save.

    The dialog renders rich text, so two precautions apply:

    * The traceback and displayed paths are HTML-escaped. Otherwise text like
      ``<module>`` in a traceback would be parsed as a tag and vanish.
    * Link targets are built with ``Path.as_uri``, giving well-formed,
      percent-encoded ``file://`` URLs.

    Args:
        error_msg (str): Formatted traceback text to display.
    """
    import html  # only needed on this crash path

    msg_box = QMessageBox()
    msg_box.setIcon(QMessageBox.Icon.Critical)
    msg_box.setWindowTitle("Something went wrong!")

    if PLATFORM_NAME == "darwin":
        log_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks.log")
        log_path2 = os.path.join(os.path.dirname(sys.executable), "../../../sparks_error.log")
        save_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks_autosave.spark")
    else:
        log_path = os.path.join(os.getcwd(), "sparks.log")
        log_path2 = os.path.join(os.getcwd(), "sparks_error.log")
        save_path = os.path.join(os.getcwd(), "sparks_autosave.spark")

    try:
        shutil.copy(log_path, log_path2)
    except OSError as e:
        print(f"Could not copy log file to {log_path2}: {e}")

    # abspath also collapses the "../" segments used on macOS.
    error_log_file = Path(os.path.abspath(log_path2))
    autosave_file = Path(os.path.abspath(save_path))
    log_path2 = error_log_file.as_posix()
    autosave_path = autosave_file.as_posix()
    log_link = error_log_file.as_uri()
    autosave_link = autosave_file.as_uri()

    # `window` only exists once the main window has been created in __main__.
    main_window = globals().get("window")
    if main_window is None:
        print("Main window not available; skipping autosave.")
    else:
        try:
            main_window.save_project(True)
        except Exception as e:
            print(f"Autosave failed: {e}")

    message = (
        "SPARKS has encountered an unrecoverable error and needs to close.<br><br>"
        f"We are sorry for the inconvenience. An autosave was attempted to be saved to <a href='{autosave_link}'>{html.escape(autosave_path)}</a>, but it may not have been saved. "
        "If the file was saved, it still may not be intact, openable, or contain the correct data. Use the autosave at your discretion.<br><br>"
        "This unrecoverable error was likely due to an error with SPARKS and not your data.<br>"
        f"Please raise an issue <a href='https://git.research.dezeeuw.ca/tyler/sparks/issues'>here</a> and attach the error file located at <a href='{log_link}'>{html.escape(log_path2)}</a><br><br>"
        f"<pre>{html.escape(error_msg)}</pre>"
    )

    msg_box.setTextFormat(Qt.TextFormat.RichText)
    msg_box.setText(message)
    msg_box.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
    msg_box.setStandardButtons(QMessageBox.StandardButton.Ok)

    msg_box.exec()
if __name__ == "__main__":
    # Redirect exceptions to the popup window
    sys.excepthook = exception_hook

    # Set up application logging
    if PLATFORM_NAME == "darwin":
        log_path = os.path.join(os.path.dirname(sys.executable), "../../../sparks.log")
    else:
        log_path = os.path.join(os.getcwd(), "sparks.log")

    # Start each run with a fresh log. A missing file (first run) or one that
    # cannot be removed is not an error.
    try:
        os.remove(log_path)
    except OSError:
        pass

    # Line-buffered so the log stays current. UTF-8 is explicit because the
    # platform default (e.g. cp1252 on Windows) cannot encode every character,
    # and print() would then raise UnicodeEncodeError (e.g. for non-Latin paths).
    sys.stdout = open(log_path, "a", buffering=1, encoding="utf-8", errors="replace")
    sys.stderr = sys.stdout
    print(f"\n=== App started at {datetime.now()} ===\n")

    # NOTE(review): the multiprocessing docs say freeze_support() should be
    # called right after the `if __name__ == "__main__":` line. In a frozen
    # build, spawned children may execute the lines above before
    # freeze_support() hands control to the child bootstrap. They could then
    # remove or append to the parent's log. Confirm before moving this call.
    freeze_support()  # Required for PyInstaller + multiprocessing

    # Only run GUI in the main process
    if current_process().name == 'MainProcess':
        app = QApplication(sys.argv)
        finish_update_if_needed()
        window = MainApplication()

        # macOS uses the .icns icon; other platforms use the .ico icon.
        icon_file = "icons/main.icns" if PLATFORM_NAME == "darwin" else "icons/main.ico"
        app.setWindowIcon(QIcon(resource_path(icon_file)))
        window.setWindowIcon(QIcon(resource_path(icon_file)))

        window.show()
        sys.exit(app.exec())
# Not 4000 lines yay! |