4088 lines
154 KiB
Python
4088 lines
154 KiB
Python
"""
|
|
Filename: main.py
|
|
Description: BLAZES main executable
|
|
|
|
Author: Tyler de Zeeuw
|
|
License: GPL-3.0
|
|
"""
|
|
|
|
# Built-in imports
|
|
import os
|
|
import csv
|
|
import sys
|
|
import json
|
|
import glob
|
|
import shutil
|
|
import inspect
|
|
import platform
|
|
import traceback
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from multiprocessing import current_process, freeze_support, Process, Queue
|
|
|
|
# External library imports
|
|
import numpy as np
|
|
import pandas as pd
|
|
import psutil
|
|
import joblib
|
|
import cv2
|
|
from ultralytics import YOLO
|
|
|
|
from updater import finish_update_if_needed, UpdateManager, LocalPendingUpdateCheckThread
|
|
from predictor import GeneralPredictor
|
|
from pose_worker import run_pose_analysis
|
|
from batch_processing import BatchProcessorDialog
|
|
|
|
import PySide6
|
|
from PySide6.QtWidgets import (QApplication, QDoubleSpinBox, QFormLayout, QGraphicsItem, QGraphicsProxyWidget, QGraphicsTextItem, QLineEdit, QListWidget, QListWidgetItem, QMainWindow, QProgressDialog, QSizePolicy, QStyleOptionGraphicsItem, QTabBar, QWidget, QVBoxLayout, QGraphicsView, QGraphicsScene,
|
|
QHBoxLayout, QSplitter, QLabel, QPushButton, QComboBox, QInputDialog, QGraphicsRectItem,
|
|
QFileDialog, QScrollArea, QMessageBox, QSlider, QTextEdit, QGroupBox, QGridLayout, QCheckBox, QTabWidget, QProgressBar)
|
|
from PySide6.QtCore import QEvent, Qt, QThread, Signal, QUrl, QRectF, QPointF, QRect, QSizeF, QTimer
|
|
from PySide6.QtGui import QCursor, QDoubleValidator, QGuiApplication, QPainter, QColor, QFont, QPen, QBrush, QAction, QKeySequence, QIcon, QTextOption, QImage, QPixmap, QTransform
|
|
from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput
|
|
from PySide6.QtMultimediaWidgets import QGraphicsVideoItem
|
|
|
|
|
|
# --- Application-wide constants ---
VERBOSITY = 1  # truthy -> debug_print() emits caller names
CURRENT_VERSION = "0.1.0"
APP_NAME = "blazes"

# Release-listing API endpoints on the primary and secondary git hosts.
API_URL = "https://git.research.dezeeuw.ca/api/v1/repos/tyler/{}/releases".format(APP_NAME)
API_URL_SECONDARY = "https://git.research2.dezeeuw.ca/api/v1/repos/tyler/{}/releases".format(APP_NAME)

# Lower-cased OS name as reported by platform.system() (e.g. "windows").
PLATFORM_NAME = platform.system().lower()
|
|
|
|
|
|
|
|
def debug_print():
|
|
if VERBOSITY:
|
|
frame = inspect.currentframe().f_back
|
|
qualname = frame.f_code.co_qualname
|
|
print(qualname)
|
|
|
|
|
|
# Pose keypoint names, indexed in the order listed in the YOLO pose docs:
# https://docs.ultralytics.com/tasks/pose/
JOINT_NAMES = [
    "Nose",            # 0
    "Left Eye",        # 1
    "Right Eye",       # 2
    "Left Ear",        # 3
    "Right Ear",       # 4
    "Left Shoulder",   # 5
    "Right Shoulder",  # 6
    "Left Elbow",      # 7
    "Right Elbow",     # 8
    "Left Wrist",      # 9
    "Right Wrist",     # 10
    "Left Hip",        # 11
    "Right Hip",       # 12
    "Left Knee",       # 13
    "Right Knee",      # 14
    "Left Ankle",      # 15
    "Right Ankle",     # 16
]
|
|
|
|
|
|
# Directory of the installed PySide6 package. On Windows it is registered as
# an extra DLL search directory so native libraries shipped inside the package
# can be resolved: since Python 3.8, PATH is no longer searched for the DLL
# dependencies of extension modules.
# NOTE(review): the intent is to expose the FFmpeg DLLs (avcodec-*.dll, ...)
# used by QtMultimedia; confirm they actually live in the PySide6 root folder.
pyside_dir = Path(PySide6.__file__).parent
if sys.platform == "win32":
    os.add_dll_directory(os.fspath(pyside_dir))
|
|
|
|
|
|
# Timeline tracks: two fixed tracks followed by one track per pose joint.
TRACK_NAMES = ["Baseline", "Live Skeleton", *JOINT_NAMES]
NUM_TRACKS = len(TRACK_NAMES)

# TODO: Improve colors?
# Track colours: fixed grey / black for the two leading tracks, then the joint
# tracks spread evenly over the HSV hue range (saturation 200, value 255).
BASE_COLORS = [
    QColor(180, 180, 180),  # Baseline (grey)
    QColor(0, 0, 0),        # Live Skeleton (black)
]
REMAINING_COLORS = [
    QColor.fromHsv(int((track_idx / (NUM_TRACKS - 2)) * 359), 200, 255)
    for track_idx in range(NUM_TRACKS - 2)
]
TRACK_COLORS = [*BASE_COLORS, *REMAINING_COLORS]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
import json
|
|
import cv2
|
|
from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
|
|
QFileDialog, QCheckBox, QComboBox, QLabel,
|
|
QGridLayout, QGroupBox, QStackedWidget, QInputDialog, QMessageBox)
|
|
from PySide6.QtGui import QPixmap, QImage
|
|
from PySide6.QtCore import Qt
|
|
|
|
|
|
|
|
|
|
import os
|
|
import json
|
|
import cv2
|
|
from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
|
|
QFileDialog, QCheckBox, QComboBox, QLabel,
|
|
QGridLayout, QGroupBox, QStackedWidget, QInputDialog, QMessageBox)
|
|
from PySide6.QtGui import QPixmap, QImage
|
|
from PySide6.QtCore import Qt
|
|
|
|
|
|
class OpenFileWindow(QWidget):
    """Top-level window for choosing a video and the analysis inputs for it.

    The user selects a primary video and may enable optional sources: a BORIS
    observation file (session, media slot and that slot's offset), a trained
    ``.pkl`` model, velocity/deviation "calculated events" thresholds, plus
    pose-inference settings. ``get_config()`` returns every selection as a dict.

    NOTE(review): ``btn_confirm`` is not connected inside this class; presumably
    the owner of the window connects it and then calls ``get_config()``.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle(f"Load Video - {APP_NAME.upper()}")
        self.setMinimumWidth(650)

        # --- Selection state ---
        self.video_path = None       # path returned by the video file dialog
        self.obs_file = None         # path of the last successfully parsed BORIS file
        self.pkl_path = None         # path of the chosen .pkl model
        self.full_json_data = None   # parsed BORIS JSON (top-level dict)
        self.current_video_fps = 30.0
        self.current_video_offset = 0.0  # shown as seconds; from media_info.offset[slot]
        self.fps = 0
        # Video properties, filled in by handle_video_selection(). Initialised
        # here so the attributes exist before any video has been chosen.
        self.total_frames = 0
        self.video_w = 0
        self.video_h = 0

        self.setup_ui()
        self.center_on_screen()

    def center_on_screen(self):
        """Centers the window on the screen under the mouse (or the primary screen)."""
        screen = QGuiApplication.screenAt(QCursor.pos())
        if not screen:
            screen = QGuiApplication.primaryScreen()

        screen_geometry = screen.availableGeometry()
        size = self.sizeHint()  # the window is not shown yet, so use the hint

        x = (screen_geometry.width() - size.width()) // 2
        y = (screen_geometry.height() - size.height()) // 2

        # Coordinates are relative to the screen's top-left corner.
        self.move(screen_geometry.left() + x, screen_geometry.top() + y)

    def setup_ui(self):
        """Build all widgets, apply the dark stylesheet and connect the signals."""
        self.setStyleSheet("""
            QWidget { background-color: #1e1e1e; color: #ffffff; font-family: 'Segoe UI'; }
            QGroupBox {
                border: 1px solid #3d3d3d; border-radius: 8px; margin-top: 15px;
                padding-top: 15px; font-weight: bold; color: #00aaff; text-transform: uppercase;
            }
            QLabel { color: #ffffff; font-weight: 500; }
            QLabel:disabled { color: #444444; }
            QLabel#Metadata { color: #00ffaa; font-family: 'Consolas'; font-size: 11px; }
            QLabel#Preview { background-color: #000000; border: 2px solid #3d3d3d; }
            QLabel#Warning { color: #ff5555; font-size: 11px; font-style: italic; font-weight: bold; }

            QPushButton { background-color: #3d3d3d; border: 1px solid #555; padding: 6px; border-radius: 4px; }
            QPushButton:hover { background-color: #00aaff; color: #000; }
            QPushButton:disabled { color: #444; background-color: #252525; }

            QComboBox { background-color: #2d2d2d; border: 1px solid #555; padding: 4px; border-radius: 4px; }
            QComboBox:disabled { background-color: #222; color: #444; border: 1px solid #2a2a2a; }
        """)

        main_layout = QVBoxLayout(self)

        # --- Section 1: Video ---
        video_group = QGroupBox("Primary Video Source")
        v_grid = QGridLayout(video_group)
        self.btn_pick_video = QPushButton("Select Video")
        self.lbl_video_path = QLabel("No video selected...")
        self.lbl_video_metadata = QLabel("Metadata: N/A")
        self.lbl_video_metadata.setObjectName("Metadata")
        self.video_preview = QLabel("NO PREVIEW")
        self.video_preview.setFixedSize(160, 90)
        self.video_preview.setObjectName("Preview")

        v_grid.addWidget(QLabel("Target Video:"), 0, 0)
        v_grid.addWidget(self.btn_pick_video, 0, 1)
        v_grid.addWidget(self.video_preview, 0, 2, 3, 1)
        v_grid.addWidget(QLabel("Path:"), 1, 0)
        v_grid.addWidget(self.lbl_video_path, 1, 1)
        v_grid.addWidget(self.lbl_video_metadata, 2, 0, 1, 2)
        main_layout.addWidget(video_group)

        # --- Section 2: Human coding (BORIS) ---
        self.boris_group = QGroupBox("Human Coding (BORIS)")
        self.boris_group.setCheckable(True)  # User can toggle this section off
        self.boris_group.setChecked(False)
        boris_layout = QGridLayout(self.boris_group)

        self.btn_boris_file = QPushButton("Load .boris File")
        self.combo_boris_keys = QComboBox()
        self.combo_video_slot = QComboBox()

        boris_layout.addWidget(QLabel("BORIS File:"), 0, 0)
        boris_layout.addWidget(self.btn_boris_file, 0, 1)
        boris_layout.addWidget(QLabel("Session:"), 1, 0)
        boris_layout.addWidget(self.combo_boris_keys, 1, 1)
        boris_layout.addWidget(QLabel("Slot:"), 2, 0)
        boris_layout.addWidget(self.combo_video_slot, 2, 1)
        main_layout.addWidget(self.boris_group)

        # --- Section 3: Trained ML model ---
        self.pkl_group = QGroupBox("Automated Prediction (.pkl)")
        self.pkl_group.setCheckable(True)  # User can toggle this section off
        self.pkl_group.setChecked(False)
        pkl_layout = QGridLayout(self.pkl_group)

        self.btn_pkl_file = QPushButton("Load .pkl Model")
        self.lbl_pkl_path = QLabel("No model selected...")

        pkl_layout.addWidget(QLabel("Model File:"), 0, 0)
        pkl_layout.addWidget(self.btn_pkl_file, 0, 1)
        pkl_layout.addWidget(QLabel("Path:"), 1, 0)
        pkl_layout.addWidget(self.lbl_pkl_path, 1, 1)
        main_layout.addWidget(self.pkl_group)

        # --- Section 3.5: Velocities and deviations ---
        self.calc_group = QGroupBox("Calculated Events")
        self.calc_group.setCheckable(True)
        self.calc_group.setChecked(False)
        calc_layout = QGridLayout(self.calc_group)

        self.cb_velocity = QCheckBox("Enable Velocities")
        self.cb_velocity.setChecked(True)
        self.spin_vel_threshold = QDoubleSpinBox()
        self.spin_vel_threshold.setRange(0.0, 999.99)
        self.spin_vel_threshold.setValue(15)
        self.spin_vel_threshold.setSuffix(" px/s")

        self.cb_deviation = QCheckBox("Enable Deviations")
        self.cb_deviation.setChecked(True)
        self.spin_dev_threshold = QDoubleSpinBox()
        self.spin_dev_threshold.setRange(0.0, 999.99)
        self.spin_dev_threshold.setValue(80)
        self.spin_dev_threshold.setSuffix(" px")

        calc_layout.addWidget(self.cb_velocity, 0, 0)
        calc_layout.addWidget(QLabel("Vel. Threshold:"), 0, 1)
        calc_layout.addWidget(self.spin_vel_threshold, 0, 2)
        calc_layout.addWidget(self.cb_deviation, 1, 0)
        calc_layout.addWidget(QLabel("Dev. Threshold:"), 1, 1)
        calc_layout.addWidget(self.spin_dev_threshold, 1, 2)
        main_layout.addWidget(self.calc_group)

        # --- Section 4: Inference ---
        self.cfg_group = QGroupBox("Inference Settings")
        c_grid = QGridLayout(self.cfg_group)
        self.check_use_cache = QCheckBox("Auto-search pose cache (.npy)")
        self.check_use_cache.setChecked(True)

        self.lbl_model_prompt = QLabel("Pose Model:")
        self.combo_inference_model = QComboBox()
        self.combo_inference_model.addItems(["YOLO8n-Pose", "YOLO8m-Pose", "Mediapipe BlazePose"])

        self.check_bypass_inference = QCheckBox("Bypass Pose Inference")
        self.lbl_inf_warning = QLabel("⚠ WARNING: Nothing fancy. Raw video playback only.")
        self.lbl_inf_warning.setObjectName("Warning")
        self.lbl_inf_warning.setVisible(False)

        c_grid.addWidget(self.check_use_cache, 0, 0, 1, 2)
        c_grid.addWidget(self.lbl_model_prompt, 1, 0)
        c_grid.addWidget(self.combo_inference_model, 1, 1)
        c_grid.addWidget(self.check_bypass_inference, 2, 0)
        c_grid.addWidget(self.lbl_inf_warning, 2, 1)
        main_layout.addWidget(self.cfg_group)

        # --- Bottom buttons ---
        btn_layout = QHBoxLayout()
        self.btn_cancel = QPushButton("Cancel")
        self.btn_confirm = QPushButton("Initialize BLAZE Engine")
        self.btn_confirm.setStyleSheet("background-color: #00aaff; color: #1e1e1e; font-weight: bold;")
        btn_layout.addWidget(self.btn_cancel)
        btn_layout.addWidget(self.btn_confirm)
        main_layout.addLayout(btn_layout)

        # --- Connections ---
        self.btn_pick_video.clicked.connect(self.handle_video_selection)
        self.btn_boris_file.clicked.connect(self.handle_boris_load)
        self.btn_pkl_file.clicked.connect(self.handle_pkl_selection)
        self.combo_boris_keys.currentIndexChanged.connect(self.handle_session_change)
        self.combo_video_slot.currentIndexChanged.connect(self.handle_slot_change)
        self.check_bypass_inference.toggled.connect(self.handle_inference_toggle)
        self.btn_cancel.clicked.connect(self.close)
        self.boris_group.toggled.connect(self.update_metadata_display)

    def format_time(self, seconds):
        """Format a duration in seconds as ``HH:MM:SS``.

        Negative or NaN durations (from bogus container metadata) clamp to 0.
        """
        seconds = max(0, seconds)  # max(0, nan) -> 0
        h, rem = divmod(int(seconds), 3600)
        m, s = divmod(rem, 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    def handle_video_selection(self):
        """Ask for a video, read its basic properties and refresh the UI."""
        path, _ = QFileDialog.getOpenFileName(self, "Open Video", "", "Video Files (*.mp4 *.avi *.mkv)")
        if not path:
            return

        self.video_path = path
        self.lbl_video_path.setText(os.path.basename(path))

        cap = cv2.VideoCapture(path)
        try:
            # Containers can report 0/NaN FPS; fall back to 30 so durations stay finite.
            fps = cap.get(cv2.CAP_PROP_FPS)
            self.fps = fps if fps and fps > 0 else 30.0
            # Frame count may be reported as <= 0 (or NaN) for unindexed streams.
            frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
            self.total_frames = int(frame_count) if frame_count > 0 else 0
            self.video_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            self.video_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            cap.release()

        self.render_preview(path)

        # If a BORIS file is already loaded, try to find this video in its slots.
        if self.full_json_data:
            self.attempt_auto_match()

        self.update_metadata_display()

    def handle_boris_load(self):
        """Load a BORIS project file and populate the session list.

        State (``obs_file``, ``full_json_data``, button text) is only replaced
        once the file has been read and validated, so a bad file leaves the
        previously loaded one intact.
        """
        path, _ = QFileDialog.getOpenFileName(self, "Select JSON", "", "JSON Files (*.json *.boris)")
        if not path:
            return

        try:
            # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("Top-level JSON value is not an object.")
            obs = data.get("observations", {})
            if not isinstance(obs, dict):
                raise ValueError("'observations' is not an object.")
        except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            QMessageBox.warning(self, "Parse Error", str(e))
            return

        self.obs_file = path
        self.btn_boris_file.setText(os.path.basename(path))
        self.full_json_data = data
        print(f"\n[DEBUG] BORIS File Loaded. Found {len(obs)} sessions.")

        self.combo_boris_keys.setEnabled(True)
        # clear()/addItems() emit currentIndexChanged -> handle_session_change().
        self.combo_boris_keys.clear()
        self.combo_boris_keys.addItems(list(obs.keys()))
        if self.video_path:
            self.attempt_auto_match()

    @staticmethod
    def _iter_slot_files(session_data):
        """Yield ``(slot, paths)`` for each media slot of one BORIS observation.

        ``paths`` contains only the non-blank string entries of the slot's
        file list. Malformed entries (non-dict session, non-dict ``file`` map,
        non-list slot value) are skipped instead of raising.
        """
        if not isinstance(session_data, dict):
            return
        file_map = session_data.get("file", {})
        if not isinstance(file_map, dict):
            return
        for slot, files in file_map.items():
            if not isinstance(files, (list, tuple)):
                continue
            yield slot, [f for f in files if isinstance(f, str) and f.strip()]

    def handle_session_change(self):
        """Rebuild the slot combo for the selected session and refresh the offset.

        The slot list is always cleared first, so switching to an empty or
        missing session never leaves slots from a previous session behind.
        """
        session_key = self.combo_boris_keys.currentText()

        # Populate silently; handle_slot_change() is called once at the end.
        self.combo_video_slot.blockSignals(True)
        try:
            self.combo_video_slot.clear()
            if self.full_json_data and session_key:
                session_data = self.full_json_data.get("observations", {}).get(session_key, {})
                print(f"[DEBUG] Filtering slots for session: {session_key}")
                for slot, valid_files in self._iter_slot_files(session_data):
                    if not valid_files:
                        continue
                    display_name = os.path.basename(valid_files[0].replace('\\', '/'))
                    print(f" > Valid slot found: {slot} ({display_name})")
                    self.combo_video_slot.addItem(f"Slot {slot}: {display_name}", slot)
                self.combo_video_slot.setEnabled(True)
        finally:
            self.combo_video_slot.blockSignals(False)

        self.handle_slot_change()

    def attempt_auto_match(self):
        """Select the BORIS session/slot whose media file name equals the video's."""
        if not self.video_path or not self.full_json_data:
            return

        target_name = os.path.basename(self.video_path)
        print(f"\n[DEBUG] ATTEMPTING AUTO-MATCH FOR: {target_name}")

        obs = self.full_json_data.get("observations", {})
        for session_key, content in obs.items():
            for slot, paths in self._iter_slot_files(content):
                for f_path in paths:
                    # BORIS may store Windows paths; normalise separators first.
                    if os.path.basename(f_path.replace('\\', '/')) != target_name:
                        continue

                    print("[DEBUG] !!! MATCH FOUND !!!")
                    print(f" Session: {session_key}")
                    print(f" Slot: {slot}")

                    session_idx = self.combo_boris_keys.findText(session_key)
                    if session_idx >= 0:
                        # Emits currentIndexChanged synchronously, so the slot
                        # combo is repopulated before the loop below runs.
                        self.combo_boris_keys.setCurrentIndex(session_idx)
                    for i in range(self.combo_video_slot.count()):
                        if self.combo_video_slot.itemData(i) == slot:
                            self.combo_video_slot.setCurrentIndex(i)
                            break
                    return

        print(f"[DEBUG] No match found for {target_name} in the JSON file mapping.")

    def handle_slot_change(self):
        """Read the selected slot's offset (0.0 if none) and refresh the metadata label."""
        session_key = self.combo_boris_keys.currentText()
        slot_id = self.combo_video_slot.currentData()  # slot key stored by handle_session_change

        # Reset whenever nothing valid is selected so a stale offset never leaks
        # into get_config().
        offset = 0.0
        if self.full_json_data and session_key and slot_id is not None:
            session_data = self.full_json_data.get("observations", {}).get(session_key, {})
            # Navigate: media_info -> offset -> {slot_id}; JSON keys are strings.
            offsets = session_data.get("media_info", {}).get("offset", {})
            val = offsets.get(str(slot_id))
            if val is not None:
                try:
                    offset = float(val)
                except (TypeError, ValueError):
                    print(f"[DEBUG] Ignoring non-numeric offset {val!r} for slot {slot_id}")

        self.current_video_offset = offset
        self.update_metadata_display()

    def update_metadata_display(self):
        """Show resolution, FPS, length and (in BORIS mode) the slot offset."""
        if not self.video_path:
            self.lbl_video_metadata.setText("Metadata: N/A")
            return

        # The offset is only relevant while the BORIS section is enabled.
        if self.boris_group.isChecked():
            offset_str = f" | Offset: {self.current_video_offset}s"
        else:
            offset_str = ""

        time_str = self.format_time(self.total_frames / self.fps)
        base_text = f"RES: {self.video_w}x{self.video_h} | FPS: {self.fps:.2f} | LEN: {time_str}"
        self.lbl_video_metadata.setText(f"{base_text}{offset_str}")

    def handle_pkl_selection(self):
        """Ask for a trained ``.pkl`` model and remember its path."""
        path, _ = QFileDialog.getOpenFileName(self, "Select Model", "", "Pickle Files (*.pkl)")
        if path:
            self.pkl_path = path
            self.lbl_pkl_path.setText(os.path.basename(path))

    def handle_inference_toggle(self, checked):
        """Grey out pose-inference options while inference is bypassed."""
        for w in [self.check_use_cache, self.combo_inference_model, self.lbl_model_prompt]:
            w.setEnabled(not checked)
        self.lbl_inf_warning.setVisible(checked)

    def render_preview(self, path):
        """Show the first frame of *path* as a thumbnail, or "NO PREVIEW" if unreadable."""
        cap = cv2.VideoCapture(path)
        try:
            ok, frame = cap.read()
        finally:
            cap.release()

        if not ok or frame is None:
            # setText() also clears the previous video's pixmap.
            self.video_preview.setText("NO PREVIEW")
            return

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        height, width = rgb.shape[:2]
        image = QImage(rgb.data, width, height, rgb.strides[0], QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(image)
        self.video_preview.setPixmap(pixmap.scaled(self.video_preview.size(), Qt.KeepAspectRatio))

    def get_config(self):
        """Returns a dictionary of all user-selected settings."""
        return {
            "video_path": self.video_path,
            "total_frames": getattr(self, 'total_frames', 0),
            "fps": self.fps,

            # BORIS data
            "use_boris": self.boris_group.isChecked(),
            "obs_file": self.obs_file if self.boris_group.isChecked() else None,
            "session_key": self.combo_boris_keys.currentText(),
            "slot": self.combo_video_slot.currentData(),
            "offset": self.current_video_offset,

            # ML model data
            "use_pkl": self.pkl_group.isChecked(),
            "pkl_path": self.pkl_path if self.pkl_group.isChecked() else None,

            "use_calculations": self.calc_group.isChecked(),
            "velocity_enabled": self.cb_velocity.isChecked(),
            "velocity_threshold": self.spin_vel_threshold.value(),
            "deviation_enabled": self.cb_deviation.isChecked(),
            "deviation_threshold": self.spin_dev_threshold.value(),

            # Inference settings
            "use_pose": not self.check_bypass_inference.isChecked(),
            "pose_model": self.combo_inference_model.currentText(),
            "use_cache": self.check_use_cache.isChecked(),
        }
|
|
|
|
|
|
import os
|
|
from PySide6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton,
|
|
QLabel, QFileDialog, QFrame, QComboBox)
|
|
from PySide6.QtCore import Qt
|
|
|
|
|
|
class TrainModelWindow(QDialog):
    """Dialog for selecting training data, target behavior(s) and an ML architecture.

    The chosen folder is scanned for ``<base>_metrics.json`` files that have a
    sibling ``<base>_pose_raw.csv``. Behaviors recorded in those JSON files are
    listed as checkable items with event counts and total duration. When more
    than one behavior is checked, a name for the combined variable is asked
    for. ``get_selection()`` returns the user's choices.
    """

    _METRICS_SUFFIX = "_metrics.json"
    _POSE_SUFFIX = "_pose_raw.csv"
    _DEFAULT_FPS = 30.0  # used when a metrics file has no usable metadata.fps

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle(f"Train Model - {APP_NAME.upper()}")
        self.setFixedSize(500, 550)  # Slightly taller to fit stats
        self.selected_folder = None
        self.valid_pairs = []  # list of (json_path, csv_path)

        self.setup_ui()

    def setup_ui(self):
        """Build the folder picker, behavior list, stats label and action buttons."""
        layout = QVBoxLayout(self)
        layout.setSpacing(12)

        # --- Section 1: Folder selection ---
        self.path_display = QLabel("No folder selected...")
        self.path_display.setStyleSheet("background: #1e1e1e; padding: 8px; border-radius: 3px;")
        btn_browse = QPushButton("Select Training Folder")
        btn_browse.clicked.connect(self.browse_folder)

        layout.addWidget(QLabel("Data Source:"))
        layout.addWidget(self.path_display)
        layout.addWidget(btn_browse)

        # --- Section 2: Behavior selection (multi-select) ---
        layout.addWidget(QLabel("Select Target Behavior(s):"))
        self.behavior_list = QListWidget()
        self.behavior_list.setMinimumHeight(150)
        self.behavior_list.itemChanged.connect(self.handle_selection_change)
        layout.addWidget(self.behavior_list)

        # --- Section 3: Group name (only shown for multi-selection) ---
        self.group_name_container = QWidget()
        group_layout = QVBoxLayout(self.group_name_container)
        group_layout.setContentsMargins(0, 0, 0, 0)
        group_layout.addWidget(QLabel("Combined Variable Name:"))
        self.edit_group_name = QLineEdit()
        self.edit_group_name.setPlaceholderText("e.g., Total_Movement")
        group_layout.addWidget(self.edit_group_name)
        self.group_name_container.hide()  # Hidden by default
        layout.addWidget(self.group_name_container)

        # --- Section 4: Folder statistics ---
        self.stats_display = QLabel("Valid Pairs Found: 0")
        self.stats_display.setStyleSheet("color: #00ffaa; font-family: 'Consolas'; background: #111; padding: 10px;")
        layout.addWidget(self.stats_display)

        # --- Section 5: ML architecture ---
        self.method_dropdown = QComboBox()
        self.method_dropdown.addItems(["Random Forest", "1D-CNN", "LSTM", "XGBoost"])
        layout.addWidget(QLabel("ML Architecture:"))
        layout.addWidget(self.method_dropdown)

        layout.addStretch()

        # --- Final actions ---
        button_box = QHBoxLayout()
        self.btn_train = QPushButton("Start Training")
        self.btn_train.setEnabled(False)
        self.btn_train.setStyleSheet("background-color: #2e7d32; font-weight: bold; padding: 8px;")
        self.btn_train.clicked.connect(self.accept)

        btn_cancel = QPushButton("Cancel")
        btn_cancel.clicked.connect(self.reject)

        button_box.addWidget(btn_cancel)
        button_box.addWidget(self.btn_train)
        layout.addLayout(button_box)

    def browse_folder(self):
        """Ask for a training folder and scan it."""
        folder = QFileDialog.getExistingDirectory(self, "Select Training Data Folder")
        if folder:
            self.selected_folder = folder
            self.path_display.setText(folder)
            self.scan_and_parse_folder(folder)

    def _valid_fps(self, value):
        """Return *value* as a finite positive float, else the default FPS."""
        try:
            fps = float(value)
        except (TypeError, ValueError):
            return self._DEFAULT_FPS
        # The chained comparison is False for NaN, 0, negatives and inf.
        return fps if 0 < fps < float("inf") else self._DEFAULT_FPS

    def _read_behavior_stats(self, json_path):
        """Return ``{behavior: (event_count, seconds)}`` for one metrics JSON file.

        Seconds are computed with the file's own ``metadata.fps``. Raises
        OSError/ValueError/TypeError/AttributeError on unreadable or malformed files.
        """
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        fps = self._valid_fps(data.get("metadata", {}).get("fps"))

        stats = {}
        for b_name, instances in data.get("behaviors", {}).items():
            frames = sum(inst.get("duration_frames", 0) for inst in instances)
            stats[b_name] = (len(instances), frames / fps)
        return stats

    def scan_and_parse_folder(self, folder):
        """Collect JSON/CSV training pairs in *folder* and list their behaviors.

        A ``(json_path, csv_path)`` pair is kept in ``self.valid_pairs`` only if
        the CSV exists and the JSON could be read. The summary label and the
        behavior list are rebuilt, and the dependent controls are refreshed.
        """
        self.valid_pairs = []
        behavior_stats = {}  # behavior name -> {"count": events, "seconds": duration}
        total_events = 0
        total_seconds = 0.0

        try:
            # Sorted so pair order does not depend on the filesystem.
            entries = sorted(os.listdir(folder))
        except OSError as e:
            QMessageBox.warning(self, "Folder Error", f"Could not read folder:\n{e}")
            entries = []

        for j_file in entries:
            if not j_file.endswith(self._METRICS_SUFFIX):
                continue
            # Strip only the trailing suffix (str.replace would also hit earlier matches).
            base_name = j_file[: -len(self._METRICS_SUFFIX)]
            json_path = os.path.join(folder, j_file)
            csv_path = os.path.join(folder, base_name + self._POSE_SUFFIX)
            if not os.path.exists(csv_path):
                continue

            try:
                file_stats = self._read_behavior_stats(json_path)
            except (OSError, ValueError, TypeError, AttributeError) as e:
                # A pair whose metrics can't be read is not usable for training.
                print(f"Error parsing {j_file}: {e}")
                continue

            self.valid_pairs.append((json_path, csv_path))
            for b_name, (count, seconds) in file_stats.items():
                entry = behavior_stats.setdefault(b_name, {"count": 0, "seconds": 0.0})
                entry["count"] += count
                entry["seconds"] += seconds
                total_events += count
                total_seconds += seconds

        # --- Update dataset summary label ---
        self.stats_display.setText(
            f"Valid Pairs Found: {len(self.valid_pairs)}\n"
            f"Total Event Instances: {total_events}\n"
            f"Total Behavior Time: {total_seconds:.2f}s"
        )

        # --- Populate list with detailed labels ---
        self.behavior_list.clear()
        for b_name in sorted(behavior_stats):
            stats = behavior_stats[b_name]
            label = f"{b_name} ({stats['count']} events, {stats['seconds']:.1f}s)"

            item = QListWidgetItem(label)
            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
            item.setCheckState(Qt.Unchecked)
            item.setData(Qt.UserRole, b_name)  # clean name without the stats suffix
            self.behavior_list.addItem(item)

        # Rebuilding the list is not an item data change, so itemChanged does
        # not fire; refresh explicitly so "Start Training" and the group-name
        # field don't keep the state of the previous folder's selection.
        self.handle_selection_change()

    def handle_selection_change(self):
        """Shows/Hides the name input based on how many boxes are checked."""
        selected_items = [self.behavior_list.item(i) for i in range(self.behavior_list.count())
                          if self.behavior_list.item(i).checkState() == Qt.Checked]

        count = len(selected_items)
        self.group_name_container.setVisible(count > 1)
        self.btn_train.setEnabled(count > 0)

    def get_selection(self):
        """Return the behaviors to combine and the final variable name.

        Returns:
            dict with ``folder``, ``pairs`` (list of (json, csv) paths),
            ``selected_behaviors`` (clean names), ``target_name`` (the group
            name, ``"combined_variable"`` if left blank, the single behavior's
            name, or None) and ``model_type``.
        """
        selected_names = [self.behavior_list.item(i).data(Qt.UserRole)
                          for i in range(self.behavior_list.count())
                          if self.behavior_list.item(i).checkState() == Qt.Checked]

        if len(selected_names) > 1:
            final_name = self.edit_group_name.text().strip() or "combined_variable"
        else:
            final_name = selected_names[0] if selected_names else None

        return {
            "folder": self.selected_folder,
            "pairs": self.valid_pairs,
            "selected_behaviors": selected_names,
            "target_name": final_name,
            "model_type": self.method_dropdown.currentText()
        }
|
|
|
|
|
|
|
|
import json
|
|
from PySide6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QLabel,
|
|
QPushButton, QListWidget, QListWidgetItem,
|
|
QFileDialog, QCheckBox, QMessageBox)
|
|
from PySide6.QtCore import Qt
|
|
|
|
|
|
class ExportTimelineJsonWindow(QDialog):
    """Dialog exporting selected timeline tracks as a flat, time-ordered JSON file.

    Args:
        timeline_data: mapping of track name -> events; each event is indexed
            as ``ev[0]`` (start frame) and ``ev[1]`` (end frame).
        fps: frames per second used to convert frames to seconds.
        parent: optional parent widget.
    """

    # Prefixes of tracks treated as calculated (velocity/deviation) events.
    _CALC_PREFIXES = ("Dev_", "Vel_")

    def __init__(self, timeline_data, fps=30.0, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Export Timeline Data")
        self.setFixedSize(500, 550)

        self.timeline_data = timeline_data
        self.fps = fps
        self.output_path = None

        self.setup_ui()

    def setup_ui(self):
        """Build the destination picker, track list, filter option and buttons."""
        layout = QVBoxLayout(self)
        layout.setSpacing(12)

        # --- Section 1: Output location ---
        self.path_display = QLabel("No output file selected...")
        self.path_display.setStyleSheet("background: #1e1e1e; padding: 8px; border-radius: 3px;")
        btn_browse = QPushButton("Select Output Location")
        btn_browse.clicked.connect(self.browse_file)

        layout.addWidget(QLabel("Export Destination:"))
        layout.addWidget(self.path_display)
        layout.addWidget(btn_browse)

        # --- Section 2: Track selection ---
        layout.addWidget(QLabel("Select Tracks to Include:"))
        self.track_list = QListWidget()
        self.populate_track_list()
        layout.addWidget(self.track_list)

        # --- Section 3: 'Fancy' calculations filter ---
        self.cb_fancy = QCheckBox("Apply Fancy Filtering")
        self.cb_fancy.setToolTip("Drops any Dev_ or Vel_ track events that overlap with an active BORIS event.")
        layout.addWidget(self.cb_fancy)

        layout.addStretch()

        # --- Final actions ---
        button_box = QHBoxLayout()
        self.btn_export = QPushButton("Export JSON")
        self.btn_export.setEnabled(False)  # enabled once a destination is chosen
        self.btn_export.setStyleSheet("background-color: #2e7d32; font-weight: bold; padding: 8px;")
        self.btn_export.clicked.connect(self.perform_export)

        btn_cancel = QPushButton("Cancel")
        btn_cancel.clicked.connect(self.reject)

        button_box.addWidget(btn_cancel)
        button_box.addWidget(self.btn_export)
        layout.addLayout(button_box)

    def populate_track_list(self):
        """Populates the list widget with all available tracks, defaulting to checked."""
        for track_name in sorted(self.timeline_data.keys()):
            item = QListWidgetItem(track_name)
            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
            item.setCheckState(Qt.Checked)
            self.track_list.addItem(item)

    def browse_file(self):
        """Ask for the output path, appending ``.json`` if no such extension is present."""
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Save Timeline JSON", "", "JSON Files (*.json)"
        )
        if not file_path:
            return

        # Case-insensitive so "out.JSON" is not turned into "out.JSON.json".
        if not file_path.lower().endswith('.json'):
            file_path += '.json'

        self.output_path = file_path
        self.path_display.setText(file_path)
        self.btn_export.setEnabled(True)

    def _checked_track_names(self):
        """Return the names of the tracks the user left checked, in list order."""
        return [
            self.track_list.item(i).text()
            for i in range(self.track_list.count())
            if self.track_list.item(i).checkState() == Qt.Checked
        ]

    def _collect_events(self, selected_tracks, do_fancy):
        """Flatten the events of *selected_tracks* into dicts sorted by start frame.

        With *do_fancy*, events of calculated tracks (``Dev_``/``Vel_``) that
        overlap (inclusively) any event of a selected non-calculated track --
        assumed to be the BORIS tracks -- are dropped.
        """
        boris_intervals = []
        if do_fancy:
            for track_name in selected_tracks:
                if not track_name.startswith(self._CALC_PREFIXES):
                    boris_intervals.extend(
                        (ev[0], ev[1]) for ev in self.timeline_data.get(track_name, [])
                    )

        flat_events = []
        for track_name in selected_tracks:
            filter_track = do_fancy and track_name.startswith(self._CALC_PREFIXES)
            for ev in self.timeline_data.get(track_name, []):
                start_f, end_f = ev[0], ev[1]
                # Two closed intervals overlap iff max(starts) <= min(ends).
                if filter_track and any(
                    max(start_f, b_start) <= min(end_f, b_end)
                    for b_start, b_end in boris_intervals
                ):
                    continue  # drop completely

                flat_events.append({
                    "track_name": track_name,
                    "start_frame": int(start_f),
                    "start_sec": round(start_f / self.fps, 3),
                    "end_frame": int(end_f),
                    "end_sec": round(end_f / self.fps, 3)
                })

        # Stable sort: ties keep the (sorted) track order.
        flat_events.sort(key=lambda x: x["start_frame"])
        return flat_events

    def perform_export(self):
        """Write the checked tracks' events to ``self.output_path``; close on success."""
        if not self.output_path:
            return

        selected_tracks = self._checked_track_names()
        do_fancy = self.cb_fancy.isChecked()
        flat_events = self._collect_events(selected_tracks, do_fancy)

        export_payload = {
            "metadata": {
                "fps": self.fps,
                "total_events_exported": len(flat_events),
                "fancy_filtering_applied": do_fancy,
                "all_possible_tracks": list(self.timeline_data.keys())
            },
            "events": flat_events
        }

        try:
            # Serialise first: json.dump streams to the file, so a failure
            # part-way through would otherwise leave a truncated file behind.
            text = json.dumps(export_payload, indent=4)
            with open(self.output_path, 'w', encoding="utf-8") as f:
                f.write(text)
        except (OSError, TypeError, ValueError) as e:
            QMessageBox.critical(self, "Export Error", f"Failed to write JSON:\n{str(e)}")
            return

        self.accept()  # Close the dialog successfully
|
|
|
|
|
|
|
|
|
|
class AboutWindow(QWidget):
    """
    Standalone window listing the application's name, tagline, licence and version.

    Args:
        parent (QWidget, optional): Owning widget. The window is still a separate
            top-level window because of ``Qt.WindowType.Window``. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        display_name = APP_NAME.upper()

        self.setWindowTitle(f"About {display_name}")
        self.resize(250, 100)
        self.setStyleSheet("""
            QVBoxLayout, QWidget {
                background-color: #1e1e1e;
            }
            QLabel {
                color: #ffffff;
            }
        """)

        column = QVBoxLayout()
        info_lines = (
            f"About {display_name}",
            "Behavioral Learning & Automated Zoned Events Suite",
            f"{display_name} is licensed under the GPL-3.0 licence. For more information, visit https://www.gnu.org/licenses/gpl-3.0.en.html",
        )
        for line in info_lines:
            column.addWidget(QLabel(line, self))
        # The version label is created without a parent; setLayout() adopts it.
        column.addWidget(QLabel(f"Version v{CURRENT_VERSION}"))

        self.setLayout(column)
|
|
|
|
|
|
|
|
class UserGuideWindow(QWidget):
    """
    Minimal User Guide window pointing users to the project wiki.

    Currently shows a placeholder message followed by a clickable link to
    the Git wiki page for APP_NAME.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle(f"User Guide - {APP_NAME.upper()}")
        self.resize(250, 100)
        self.setStyleSheet("""
            QVBoxLayout, QWidget {
                background-color: #1e1e1e;
            }
            QLabel {
                color: #ffffff;
            }
        """)

        heading = QLabel("Hmmm...", self)
        placeholder = QLabel("Nothing to see here yet.", self)
        wiki_link = QLabel(f"For more information, visit the Git wiki page <a href='https://git.research.dezeeuw.ca/tyler/{APP_NAME}/wiki'>here</a>.", self)

        # Render the <a> tag as a real hyperlink and let Qt hand clicks
        # to the system browser.
        wiki_link.setTextFormat(Qt.TextFormat.RichText)
        wiki_link.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
        wiki_link.setOpenExternalLinks(True)

        column = QVBoxLayout()
        for widget in (heading, placeholder, wiki_link):
            column.addWidget(widget)
        self.setLayout(column)
|
|
|
|
|
|
|
|
# class PoseAnalyzerWorker(QThread):
|
|
# progress = Signal(str)
|
|
# finished_data = Signal(dict)
|
|
|
|
# def __init__(self, video_path, obs_info=None, predictor=None):
|
|
# debug_print()
|
|
# super().__init__()
|
|
# self.video_path = video_path
|
|
# self.obs_info = obs_info
|
|
# self.predictor = predictor
|
|
# self.pose_df = pd.DataFrame()
|
|
|
|
|
|
# def get_best_infant_match(self, results, w, h, prev_track_id):
|
|
# debug_print()
|
|
# if not results[0].boxes or results[0].boxes.id is None:
|
|
# return None, None, None, None
|
|
# ids = results[0].boxes.id.int().cpu().tolist()
|
|
# kpts = results[0].keypoints.xy.cpu().numpy()
|
|
# confs = results[0].keypoints.conf.cpu().numpy()
|
|
# best_idx, best_score = -1, -1
|
|
# for i, k in enumerate(kpts):
|
|
# vis = np.sum(confs[i] > 0.5)
|
|
# valid = k[confs[i] > 0.5]
|
|
# dist = np.linalg.norm(np.mean(valid, axis=0) - [w/2, h/2]) if len(valid) > 0 else 1000
|
|
# score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0)
|
|
# if score > best_score:
|
|
# best_score, best_idx = score, i
|
|
# if best_idx == -1:
|
|
# return None, None, None, None
|
|
# return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx
|
|
|
|
|
|
# def _merge_json_observations(self, timeline_events, fps):
|
|
# """Restores the grouping and block-pairing logic from the observation files."""
|
|
# debug_print()
|
|
# if not self.obs_info:
|
|
# return
|
|
|
|
# self.progress.emit("Merging JSON Observations...")
|
|
# json_path, subkey = self.obs_info
|
|
|
|
# # try:
|
|
# # with open(json_path, 'r') as f:
|
|
# # full_json = json.load(f)
|
|
|
|
# # # Extract events for the specific subkey (e.g., 'Participant_01')
|
|
# # raw_obs_events = full_json["observations"][subkey]["events"]
|
|
# # raw_obs_events.sort(key=lambda x: x[0]) # Sort by timestamp
|
|
|
|
# # # Group frames by label
|
|
# # obs_groups = {}
|
|
# # for ev in raw_obs_events:
|
|
# # time_sec, _, label, special = ev[0], ev[1], ev[2], ev[3]
|
|
# # frame = int(time_sec * fps)
|
|
# # if label not in obs_groups:
|
|
# # obs_groups[label] = []
|
|
# # obs_groups[label].append(frame)
|
|
|
|
# # # Convert groups of frames into (Start, End) blocks
|
|
# # for label, frames in obs_groups.items():
|
|
# # track_name = f"OBS: {label}"
|
|
# # processed_blocks = []
|
|
|
|
# # # Step by 2 to create start/end pairs
|
|
# # for i in range(0, len(frames) - 1, 2):
|
|
# # start_f = frames[i]
|
|
# # end_f = frames[i+1]
|
|
# # processed_blocks.append((start_f, end_f, "Moderate", "Manual"))
|
|
|
|
# # # Register the track globally if it's new
|
|
# # if track_name not in TRACK_NAMES:
|
|
# # TRACK_NAMES.append(track_name)
|
|
# # TRACK_COLORS.append(QColor("#AA00FF")) # Purple for Observations
|
|
|
|
# # timeline_events[track_name] = processed_blocks
|
|
|
|
# # except Exception as e:
|
|
# # print(f"Error parsing JSON Observations: {e}")
|
|
|
|
|
|
# try:
|
|
# with open(json_path, 'r') as f:
|
|
# full_json = json.load(f)
|
|
|
|
# raw_obs_events = full_json["observations"][subkey]["events"]
|
|
# raw_obs_events.sort(key=lambda x: x[0])
|
|
|
|
# # NEW LOGIC: Use a dictionary to store frames for specific track names
|
|
# # track_name -> [list of frames]
|
|
# obs_groups = {}
|
|
|
|
# for ev in raw_obs_events:
|
|
# # ev structure: [time_sec, unknown, label, special]
|
|
# time_sec, label, special = ev[0], ev[2], ev[3]
|
|
# frame = int(time_sec * fps)
|
|
|
|
# # Determine which tracks this event belongs to
|
|
# target_tracks = []
|
|
|
|
# if special == "Left":
|
|
# target_tracks.append(f"OBS: {label} (Left)")
|
|
# elif special == "Right":
|
|
# target_tracks.append(f"OBS: {label} (Right)")
|
|
# elif special == "Both":
|
|
# target_tracks.append(f"OBS: {label} (Left)")
|
|
# target_tracks.append(f"OBS: {label} (Right)")
|
|
# else:
|
|
# # No special or unrecognized value
|
|
# target_tracks.append(f"OBS: {label}")
|
|
|
|
# # Add the frame to all applicable tracks
|
|
# for t_name in target_tracks:
|
|
# if t_name not in obs_groups:
|
|
# obs_groups[t_name] = []
|
|
# obs_groups[t_name].append(frame)
|
|
|
|
# # Convert frame groups into (Start, End) blocks
|
|
# for track_name, frames in obs_groups.items():
|
|
# processed_blocks = []
|
|
|
|
# # Step by 2 to create start/end pairs (ensures matching pairs per track)
|
|
|
|
# if "Sync" in track_name and len(frames) == 1:
|
|
# start_f = frames[0]
|
|
# end_f = start_f + 1 # Give it a visible width on the timeline
|
|
# processed_blocks.append((start_f, end_f, "Moderate", "Manual"))
|
|
|
|
# else:
|
|
# for i in range(0, len(frames) - 1, 2):
|
|
# start_f = frames[i]
|
|
# end_f = frames[i+1]
|
|
# processed_blocks.append((start_f, end_f, "Moderate", "Manual"))
|
|
|
|
# # Register the track in global lists if not already there
|
|
# if track_name not in TRACK_NAMES:
|
|
# TRACK_NAMES.append(track_name)
|
|
# # Using Purple for Observations
|
|
# TRACK_COLORS.append(QColor("#AA00FF"))
|
|
|
|
# timeline_events[track_name] = processed_blocks
|
|
|
|
# except Exception as e:
|
|
# print(f"Error parsing JSON Observations: {e}")
|
|
|
|
|
|
# def _run_existing_ml_models(self, z_kps, dirs, raw_kpts):
|
|
# debug_print()
|
|
# """
|
|
# Scans for trained models and generates timeline tracks for each.
|
|
# """
|
|
# ai_events = {}
|
|
|
|
# # 1. Match the pattern from your GeneralPredictor: {Target}_rf.pkl
|
|
# model_files = glob.glob("*_rf.pkl")
|
|
# print(f"DEBUG: Found model files: {model_files}")
|
|
|
|
# for model_path in model_files:
|
|
# try:
|
|
# # Extract Target (e.g., "Mouthing" from "Mouthing_rf.pkl")
|
|
# base_name = model_path.split("_rf.pkl")[0]
|
|
# target = base_name.replace("ml_", "", 1)
|
|
# track_name = f"AI: {target}"
|
|
|
|
# self.progress.emit(f"Loading AI Observations for {target}...")
|
|
|
|
|
|
# # 2. Match the Scaler naming from calculate_and_train:
|
|
# # {target}_random_forest_scaler.pkl
|
|
# scaler_path = f"{base_name}_rf_scaler.pkl"
|
|
|
|
# if not os.path.exists(scaler_path):
|
|
# print(f"DEBUG: Skipping {target}, scaler not found at {scaler_path}")
|
|
# continue
|
|
|
|
# # Load assets
|
|
# model = joblib.load(model_path)
|
|
# scaler = joblib.load(scaler_path)
|
|
|
|
# # 3. Feature extraction (On-the-fly)
|
|
# all_features = []
|
|
# # We must set the predictor's target so format_features uses the correct ACTIVITY_MAP
|
|
# self.predictor.current_target = target
|
|
|
|
# for f_idx in range(len(z_kps)):
|
|
# feat = self.predictor.format_features(z_kps[f_idx], dirs[f_idx], raw_kpts[f_idx])
|
|
# all_features.append(feat)
|
|
|
|
# # 4. Inference
|
|
# X = np.array(all_features)
|
|
# X_scaled = scaler.transform(X)
|
|
# predictions = model.predict(X_scaled)
|
|
|
|
# # 5. Convert binary 0/1 to blocks
|
|
# processed_blocks = []
|
|
# start_f = None
|
|
|
|
# for f_idx, val in enumerate(predictions):
|
|
# if val == 1 and start_f is None:
|
|
# start_f = f_idx
|
|
# elif val == 0 and start_f is not None:
|
|
# # [start, end, severity, direction]
|
|
# processed_blocks.append((start_f, f_idx - 1, "Large", "AI"))
|
|
# start_f = None
|
|
|
|
# if start_f is not None:
|
|
# processed_blocks.append((start_f, len(predictions)-1, "Large", "AI"))
|
|
|
|
# # 6. Global Registration
|
|
# if track_name not in TRACK_NAMES:
|
|
# TRACK_NAMES.append(track_name)
|
|
# # Ensure TRACK_COLORS has an entry for this new track
|
|
# TRACK_COLORS.append(QColor("#00FF00"))
|
|
|
|
# ai_events[track_name] = processed_blocks
|
|
# print(f"DEBUG: Successfully generated {len(processed_blocks)} blocks for {track_name}")
|
|
|
|
# except Exception as e:
|
|
# print(f"Inference Error for {model_path}: {e}")
|
|
|
|
# return ai_events
|
|
|
|
|
|
# def classify_delta(self, z):
|
|
# # debug_print()
|
|
# z_abs = abs(z)
|
|
# if z_abs < 1: return "Rest"
|
|
# elif z_abs < 2: return "Small"
|
|
# elif z_abs < 3: return "Moderate"
|
|
# else: return "Large"
|
|
|
|
|
|
# def _save_pose_cache(self, path, data):
|
|
# """
|
|
# Saves the raw YOLO keypoints and confidence scores to a CSV.
|
|
# Each row represents one frame, flattened from (17, 3) to (51,).
|
|
# """
|
|
# try:
|
|
# with open(path, 'w', newline='') as f:
|
|
# writer = csv.writer(f)
|
|
|
|
# # Create the descriptive header
|
|
# header = []
|
|
# for joint in JOINT_NAMES:
|
|
# # Replace spaces with underscores for better compatibility with other tools
|
|
# header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"])
|
|
|
|
# writer.writerow(header)
|
|
|
|
# # Write the frame data
|
|
# for frame_data in data:
|
|
# # frame_data is (17, 3), flatten to (51,)
|
|
# writer.writerow(frame_data.flatten())
|
|
|
|
# print(f"DEBUG: Pose cache saved with joint headers at {path}")
|
|
# except Exception as e:
|
|
# print(f"ERROR: Could not save pose cache: {e}")
|
|
|
|
|
|
# def run(self):
|
|
# debug_print()
|
|
# # --- PHASE 1: VIDEO SETUP & POSE EXTRACTION ---
|
|
# cap = cv2.VideoCapture(self.video_path)
|
|
# fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
|
|
# width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
# height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
# total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
|
|
# raw_kps_per_frame = []
|
|
# csv_storage_data = []
|
|
# valid_mask = []
|
|
# pose_cache = self.video_path.rsplit('.', 1)[0] + "_pose_raw.csv"
|
|
|
|
# if os.path.exists(pose_cache):
|
|
# self.progress.emit("Loading cached kinematic data...")
|
|
# with open(pose_cache, 'r') as f:
|
|
# reader = csv.reader(f)
|
|
# next(reader)
|
|
# for row in reader:
|
|
# full_data = np.array([float(x) for x in row]).reshape(17, 3)
|
|
# kp = full_data[:, :2]
|
|
# raw_kps_per_frame.append(kp)
|
|
# csv_storage_data.append(full_data)
|
|
# valid_mask.append(np.any(kp))
|
|
# else:
|
|
# self.progress.emit("Detecting poses with YOLO...")
|
|
# model = YOLO("yolov8n-pose.pt")
|
|
# prev_track_id = None
|
|
# for i in range(total_frames):
|
|
# ret, frame = cap.read()
|
|
# if not ret: break
|
|
# results = model.track(frame, persist=True, verbose=False)
|
|
# track_id, kp, confs, _ = self.get_best_infant_match(results, width, height, prev_track_id)
|
|
# if kp is not None:
|
|
# prev_track_id = track_id
|
|
# raw_kps_per_frame.append(kp)
|
|
# csv_storage_data.append(np.column_stack((kp, confs)))
|
|
# valid_mask.append(True)
|
|
# else:
|
|
# raw_kps_per_frame.append(np.zeros((17, 2)))
|
|
# csv_storage_data.append(np.zeros((17, 3)))
|
|
# valid_mask.append(False)
|
|
# if i % 50 == 0: self.progress.emit(f"YOLO: {int((i/total_frames)*100)}%")
|
|
# self._save_pose_cache(pose_cache, csv_storage_data)
|
|
|
|
# cap.release()
|
|
# actual_len = len(raw_kps_per_frame)
|
|
|
|
# flattened_rows = []
|
|
# for frame_array in csv_storage_data:
|
|
# # frame_array is (17, 3) -> flatten to (51,)
|
|
# flattened_rows.append(frame_array.flatten())
|
|
|
|
# columns = []
|
|
# for name in JOINT_NAMES:
|
|
# columns.extend([f"{name}_x", f"{name}_y", f"{name}_conf"])
|
|
|
|
# # Store this so the Inspector can access it instantly in memory
|
|
# self.pose_df = pd.DataFrame(flattened_rows, columns=columns)
|
|
|
|
# # --- PHASE 2: KINEMATICS & Z-SCORES ---
|
|
# self.progress.emit("Calculating Kinematics...")
|
|
# analysis_kpts = []
|
|
# for kp in raw_kps_per_frame:
|
|
# pelvis = (kp[11] + kp[12]) / 2
|
|
# analysis_kpts.append(kp - pelvis)
|
|
|
|
# valid_data = [analysis_kpts[i] for i, v in enumerate(valid_mask) if v]
|
|
# if valid_data:
|
|
# stacked = np.stack(valid_data)
|
|
# baseline_mean = np.mean(stacked, axis=0)
|
|
# baseline_std = np.std(np.linalg.norm(stacked - baseline_mean, axis=2), axis=0) + 1e-6
|
|
# else:
|
|
# baseline_mean, baseline_std = np.zeros((17, 2)), np.ones(17)
|
|
|
|
# np_raw_kps = np.array(raw_kps_per_frame)
|
|
# np_z_kps = np.array([np.linalg.norm(kp - baseline_mean, axis=1) / baseline_std for kp in analysis_kpts])
|
|
|
|
# # Calculate directions (Assume you have a method for this or use a dummy for now)
|
|
# # Using placeholder empty strings to prevent errors in track generation
|
|
# np_dirs = np.full((actual_len, 17), "", dtype=object)
|
|
|
|
# # --- PHASE 3: TIMELINE GENERATION ---
|
|
# # Initialize dictionary with ALL global track names to prevent KeyErrors
|
|
# timeline_events = {name: [] for name in TRACK_NAMES}
|
|
|
|
# # 1. Kinematic Events (The joint tracks)
|
|
# for j_idx, joint_name in enumerate(JOINT_NAMES):
|
|
# current_block = None
|
|
# for f_idx in range(actual_len):
|
|
# severity = self.classify_delta(np_z_kps[f_idx, j_idx])
|
|
# if severity != "Rest":
|
|
# if current_block and current_block[2] == severity:
|
|
# current_block[1] = f_idx
|
|
# else:
|
|
# current_block = [f_idx, f_idx, severity, ""]
|
|
# timeline_events[joint_name].append(current_block)
|
|
# else:
|
|
# current_block = None
|
|
|
|
# # 2. JSON Observations
|
|
# self._merge_json_observations(timeline_events, fps)
|
|
|
|
# # 3. AI Inferred Events
|
|
# ai_events = self._run_existing_ml_models(np_z_kps, np_dirs, np_raw_kps)
|
|
# timeline_events.update(ai_events)
|
|
|
|
# # --- PHASE 4: EMIT ---
|
|
# data = {
|
|
# "video_path": self.video_path,
|
|
# "fps": fps,
|
|
# "total_frames": actual_len,
|
|
# "width": width, "height": height,
|
|
# "events": timeline_events,
|
|
# "raw_kps": np_raw_kps,
|
|
# "z_kps": np_z_kps,
|
|
# "directions": np_dirs,
|
|
# "baseline_kp_mean": baseline_mean
|
|
# }
|
|
# self.progress.emit("Analysis Complete!")
|
|
# self.finished_data.emit(data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ==========================================
|
|
# TIMELINE WIDGET
|
|
# ==========================================
|
|
import numpy as np
|
|
from PySide6.QtWidgets import QWidget, QScrollArea
|
|
from PySide6.QtCore import Qt, Signal, QRect, QRectF
|
|
from PySide6.QtGui import QPainter, QPen, QColor, QFont, QBrush
|
|
|
|
|
|
|
|
class TimelineWidget(QWidget):
    """Zoomable event timeline with a time ruler and one row per track.

    Layout, in widget coordinates:
      * a ruler ``ruler_height`` px tall across the top;
      * one row ``track_height`` px tall per entry in ``track_names``;
      * a label sidebar ``label_width`` px wide, repainted at the current
        horizontal scroll offset so it stays pinned to the left edge when
        the widget is hosted inside a ``QScrollArea``.

    Frame ``f`` is drawn at ``x = label_width + f * zoom_factor``.

    Signals:
        seek_requested(int): frame index clicked/dragged in the ruler or
            data area.
        visibility_changed(set): the set of hidden track names, emitted
            after a sidebar label is clicked.
        track_selected(str): name of the clicked track row, or ``""`` when
            the ruler was clicked.
    """

    seek_requested = Signal(int)
    visibility_changed = Signal(set)
    track_selected = Signal(str)

    # Upper bound on zoom_factor (pixels per frame).
    MAX_ZOOM = 50.0

    def __init__(self, parent=None):
        super().__init__(parent)
        self.data = None
        self.track_names = []
        self.track_colors = []
        self.current_frame = 0
        self.zoom_factor = 1.0
        self.label_width = 160
        self.track_height = 25
        self.ruler_height = 20
        self.scrollbar_buffer = 2
        self.hidden_tracks = set()
        self.sync_offset = 0.0
        self.sync_fps = 30.0
        self.is_scrubbing = False
        # Row index of the last clicked track; -1 = none / ruler clicked.
        self.selected_track_idx = -1

    # ------------------------------------------------------------------
    # Data / geometry
    # ------------------------------------------------------------------

    def set_data(self, events_dict, total_frames, fps):
        """Load grouped events and rebuild the track rows.

        Args:
            events_dict: track name -> list of events. Each event's first
                two items are read as the start and end frame.
            total_frames: number of frames covered by the timeline.
            fps: frames per second, used for the ruler labels.
        """
        self.track_names = sorted(events_dict)
        self.data = {
            "events": events_dict,
            "total_frames": total_frames,
            "fps": fps,
        }
        self.sync_fps = fps

        self.total_content_height = (len(self.track_names) * self.track_height) + self.ruler_height
        self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer)

        # Tracks are not known ahead of time, so spread hues evenly around
        # the colour wheel, one per track.
        n_tracks = max(1, len(self.track_names))
        self.track_colors = [QColor.fromHsl(i * 360 // n_tracks, 160, 140)
                             for i in range(len(self.track_names))]
        self.update_geometry()
        self.update()

    def update_geometry(self):
        """Resize the widget to sidebar + ``total_frames * zoom_factor`` px."""
        if self.data:
            total_w = self.label_width + int(self.data["total_frames"] * self.zoom_factor)
            self.setFixedWidth(total_w)
            self.update()

    def _scroll_area(self):
        """Return the hosting QScrollArea, or None when not inside one.

        A scroll area's content widget lives inside its viewport, so the
        scroll area is our grandparent.
        """
        parent = self.parent()
        if parent is None:
            return None
        grandparent = parent.parent()
        return grandparent if isinstance(grandparent, QScrollArea) else None

    def _scroll_x(self):
        """Current horizontal scroll offset in px (0 when not scrollable)."""
        scroll_area = self._scroll_area()
        if scroll_area is None:
            return 0
        return scroll_area.horizontalScrollBar().value()

    def _frame_at(self, x):
        """Map widget x-coordinate *x* to a frame index clamped to the data."""
        frame = int((x - self.label_width) / self.zoom_factor)
        return max(0, min(frame, self.data["total_frames"] - 1))

    # ------------------------------------------------------------------
    # Zoom
    # ------------------------------------------------------------------

    def wheelEvent(self, event):
        """Ctrl+wheel zooms by 10% per notch; other wheel events pass through."""
        if event.modifiers() == Qt.ControlModifier:
            zoom_change = 1.1 if event.angleDelta().y() > 0 else 0.9
            self.set_zoom(self.zoom_factor * zoom_change)
        else:
            # Let the scroll area handle normal scrolling.
            super().wheelEvent(event)

    def keyPressEvent(self, event):
        """Ctrl +/= zoom in, Ctrl - zoom out, Ctrl 0 reset zoom to 1.0."""
        if event.modifiers() == Qt.ControlModifier:
            key = event.key()
            if key in (Qt.Key_Plus, Qt.Key_Equal):
                self.set_zoom(self.zoom_factor * 1.2)
            elif key == Qt.Key_Minus:
                self.set_zoom(self.zoom_factor * 0.8)
            elif key == Qt.Key_0:
                self.set_zoom(1.0)
        else:
            super().keyPressEvent(event)

    def set_zoom(self, factor):
        """Set the zoom (px per frame), clamped between fit-to-width and MAX_ZOOM.

        Does nothing until data with at least one frame has been loaded
        (the fit-to-width bound divides by the frame count).
        """
        if not self.data or self.data["total_frames"] <= 0:
            return
        parent = self.parent()
        # Minimum zoom makes the whole video exactly fill the space beside the sidebar.
        available_w = parent.width() - self.label_width if parent else 800
        min_zoom = available_w / self.data["total_frames"]
        self.zoom_factor = max(min_zoom, min(factor, self.MAX_ZOOM))
        self.update_geometry()

    # ------------------------------------------------------------------
    # Playhead
    # ------------------------------------------------------------------

    def set_playhead(self, frame):
        """Move the playhead to *frame*, repainting only the affected strips."""
        old_x = self.label_width + (self.current_frame * self.zoom_factor)
        self.current_frame = frame
        new_x = self.label_width + (self.current_frame * self.zoom_factor)

        self.update(int(old_x - 5), 0, 10, self.height())
        self.update(int(new_x - 5), 0, 10, self.height())
        self.ensure_playhead_visible()

    def ensure_playhead_visible(self):
        """Scroll so the playhead is visible beside the pinned sidebar."""
        scroll_area = self._scroll_area()
        if scroll_area is None:
            return

        scrollbar = scroll_area.horizontalScrollBar()
        view_width = scroll_area.viewport().width()
        px = self.label_width + int(self.current_frame * self.zoom_factor)
        scroll_x = scrollbar.value()

        if px > (scroll_x + view_width) or px < (scroll_x + self.label_width):
            # Land the playhead a quarter-viewport to the right of the sidebar.
            scrollbar.setValue(px - self.label_width - (view_width // 4))

    # ------------------------------------------------------------------
    # Mouse
    # ------------------------------------------------------------------

    def mousePressEvent(self, event):
        """Sidebar click toggles a track; ruler/data click starts scrubbing."""
        debug_print()

        if not self.data or event.button() != Qt.LeftButton:
            return

        pos_x = event.position().x()
        pos_y = event.position().y()
        scroll_x = self._scroll_x()

        # Sidebar: toggle visibility of the clicked track. Never scrub from
        # here, even when the click misses every track row, because the
        # frames "under" the sidebar are hidden from the user.
        if pos_x < scroll_x + self.label_width:
            track_idx = int((pos_y - self.ruler_height) // self.track_height)
            if 0 <= track_idx < len(self.track_names):
                name = self.track_names[track_idx]
                if name in self.hidden_tracks:
                    self.hidden_tracks.remove(name)
                else:
                    self.hidden_tracks.add(name)
                self.visibility_changed.emit(self.hidden_tracks)
                self.update()
            return

        # Ruler or data area: start scrubbing.
        self.is_scrubbing = True
        self.seek_requested.emit(self._frame_at(pos_x))

        if pos_y >= self.ruler_height:
            track_idx = int((pos_y - self.ruler_height) // self.track_height)
            if 0 <= track_idx < len(self.track_names):
                self.track_selected.emit(self.track_names[track_idx])
                self.selected_track_idx = track_idx
                self.update()
        else:
            # Clicked the ruler: clear the selection.
            self.selected_track_idx = -1
            self.track_selected.emit("")
            self.update()

    def mouseMoveEvent(self, event):
        if self.is_scrubbing:
            self.update_frame_from_mouse(event.position().x())

    def mouseReleaseEvent(self, event):
        if event.button() == Qt.LeftButton:
            self.is_scrubbing = False

    def update_frame_from_mouse(self, x):
        """Emit ``seek_requested`` for the frame under widget x-coordinate *x*."""
        self.seek_requested.emit(self._frame_at(x))

    # ------------------------------------------------------------------
    # Painting
    # ------------------------------------------------------------------

    def paintEvent(self, event):
        """Paint events, playhead, pinned sidebar and ruler (in that order).

        Geometry is updated only where data/zoom change. Calling
        ``update_geometry()`` here would call ``update()`` from inside the
        paint and schedule another paint on every frame.
        """
        if not self.data:
            return

        painter = QPainter(self)
        try:
            scroll_x = self._scroll_x()
            self._paint_events(painter, event.rect())
            self._paint_playhead(painter)
            self._paint_sidebar(painter, scroll_x)
            self._paint_ruler(painter, scroll_x)
        finally:
            painter.end()

    def _paint_events(self, painter, dirty_rect):
        """Draw the event blocks of every track row intersecting *dirty_rect*."""
        events = self.data["events"]
        hidden_color = QColor(120, 120, 120, 40)

        for i, name in enumerate(self.track_names):
            y = self.ruler_height + (i * self.track_height)
            if y + self.track_height < dirty_rect.top() or y > dirty_rect.bottom():
                continue
            if name not in events:
                continue

            color = hidden_color if name in self.hidden_tracks else self.track_colors[i]
            for event_item in events[name]:
                start_f, end_f = event_item[0], event_item[1]
                x_start = self.label_width + (start_f * self.zoom_factor)
                x_end = self.label_width + (end_f * self.zoom_factor)

                # Painting is clipped to the dirty region, so skip blocks
                # entirely outside it (1px slack for the min block width).
                if x_end < dirty_rect.left() - 1 or x_start > dirty_rect.right() + 1:
                    continue

                painter.fillRect(
                    QRectF(x_start, y + 2, max(1, x_end - x_start), self.track_height - 4),
                    color,
                )

    def _paint_playhead(self, painter):
        """Draw the red playhead line at the current frame."""
        playhead_x = int(self.label_width + (self.current_frame * self.zoom_factor))
        painter.setPen(QPen(QColor(255, 0, 0), 2))
        painter.drawLine(playhead_x, 0, playhead_x, self.height())

    def _paint_sidebar(self, painter, scroll_x):
        """Draw the label sidebar at *scroll_x* so it masks scrolled content."""
        w, h = self.width(), self.height()

        painter.fillRect(QRect(scroll_x, 0, self.label_width, h), QColor(30, 30, 30))
        painter.fillRect(scroll_x, 0, self.label_width, self.ruler_height, QColor(45, 45, 45))

        grid_color = QColor(60, 60, 60)
        hidden_text = QColor(70, 70, 70)
        shown_text = QColor(220, 220, 220)
        painter.setFont(QFont("Arial", 8, QFont.Bold))

        for i, name in enumerate(self.track_names):
            y = self.ruler_height + (i * self.track_height)

            painter.setPen(grid_color)
            painter.drawLine(scroll_x, y, scroll_x + w, y)

            painter.setPen(hidden_text if name in self.hidden_tracks else shown_text)
            painter.drawText(scroll_x + 10, y + 17, name)

    def _paint_ruler(self, painter, scroll_x):
        """Draw the time ruler with major (labelled) and minor ticks."""
        w = self.width()
        rh = self.ruler_height
        total_f = self.data["total_frames"]
        # Guard against a missing/zero fps, which the labels divide by.
        fps = self.data.get("fps", 30) or 30

        # Pick the smallest frame interval whose major ticks are >= 120 px apart.
        target_spacing_px = 120
        possible_units = [1, 5, 15, 30, 150, 300, 900, 1800]
        tick_interval = next(
            (u for u in possible_units if (u * self.zoom_factor) >= target_spacing_px), 1800
        )
        sub_interval = max(1, tick_interval // 5)

        painter.fillRect(0, 0, w, rh, QColor(45, 45, 45))
        painter.setPen(QColor(180, 180, 180))
        painter.setFont(QFont("Segoe UI", 7))

        # Start at the first tick that can be at or right of scroll_x instead
        # of walking every tick from frame 0 on each repaint.
        first_visible = max(0, int((scroll_x - self.label_width) / self.zoom_factor))
        start = (first_visible // sub_interval) * sub_interval

        for f in range(start, total_f + 1, sub_interval):
            x = self.label_width + int(f * self.zoom_factor)
            if x < scroll_x:
                continue
            if x > scroll_x + w:
                break

            if f % tick_interval == 0:
                painter.drawLine(x, rh - 10, x, rh)

                total_seconds = f / fps
                minutes = int(total_seconds // 60)
                seconds = int(total_seconds % 60)
                frames = int(f % fps)

                if tick_interval < fps:
                    time_str = f"{seconds:02d}:{frames:02d}f"
                elif minutes > 0:
                    time_str = f"{minutes:02d}m:{seconds:02d}s"
                else:
                    time_str = f"{seconds}s"
                painter.drawText(x + 4, 12, time_str)
            else:
                painter.drawLine(x, rh - 4, x, rh)
|
|
|
|
|
|
|
|
|
|
|
|
class TrainingWorker(QThread):
    """Run ``GeneralPredictor.calculate_and_train`` off the GUI thread.

    Args:
        params: Passed unchanged to ``calculate_and_train``.

    Signals:
        finished(str): the report returned by ``calculate_and_train``.
        error(str): a description of the exception if training fails.
    """

    # Signals to communicate back to the UI.
    # NOTE(review): QThread has its own built-in ``finished`` signal; this
    # class attribute shadows it with a str-carrying one. Callers appear to
    # rely on the report payload, so it is kept as-is.
    finished = Signal(str)  # Sends the HTML report back
    error = Signal(str)     # Sends error messages

    def __init__(self, params):
        super().__init__()
        self.params = params

    def run(self):
        try:
            from predictor import GeneralPredictor
            predictor = GeneralPredictor()
            # The heavy calculation and training happen here.
            report = predictor.calculate_and_train(self.params)
            self.finished.emit(report)
        except Exception as e:
            # Exceptions raised without a message stringify to "", which
            # would surface as a blank error; fall back to the type name.
            self.error.emit(str(e) or type(e).__name__)
|
|
|
|
|
|
|
|
|
|
from PySide6.QtCore import QThread, Signal
|
|
|
|
class MLInferenceWorker(QThread):
    """Run a trained classifier over per-frame keypoints off the GUI thread.

    Pipeline:
      1. ``GeneralPredictor.format_features`` on every frame of ``raw_kpts``.
      2. ``ml_scaler.transform`` when a scaler was supplied.
      3. ``ml_model.predict``.
      4. ``GeneralPredictor.convert_to_events`` using the track name
         ``"AI: <behavior_name>"``.

    Signals:
        finished(dict): the events produced by ``convert_to_events``.
        error(str): a description of the exception if any step fails.
    """

    # NOTE(review): shadows QThread's built-in ``finished`` signal with a
    # dict-carrying one; kept because callers connect to this payload.
    finished = Signal(dict)  # Emits the timeline_events dictionary
    error = Signal(str)

    def __init__(self, raw_kpts, ml_model, ml_scaler, active_features, behavior_name):
        super().__init__()
        self.raw_kpts = raw_kpts
        self.ml_model = ml_model
        self.ml_scaler = ml_scaler  # optional; None skips scaling
        self.active_features = active_features
        self.behavior_name = f"AI: {behavior_name}"

    def run(self):
        try:
            # Import predictor logic inside the thread.
            from predictor import GeneralPredictor
            engine = GeneralPredictor()
            engine.active_feature_keys = self.active_features

            # 1. Feature extraction (the slow part).
            X = np.array([engine.format_features(frame) for frame in self.raw_kpts])

            # 2. Scaling & prediction. Test for "no scaler" explicitly rather
            # than relying on the estimator object's truthiness.
            if self.ml_scaler is not None:
                X = self.ml_scaler.transform(X)
            preds = self.ml_model.predict(X)

            # 3. Convert per-frame predictions into timeline blocks.
            events = engine.convert_to_events(preds, track_name=self.behavior_name)
            self.finished.emit(events)
        except Exception as e:
            # Message-less exceptions stringify to ""; report the type instead.
            self.error.emit(str(e) or type(e).__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SkeletonOverlay(QWidget):
    """Transparent overlay that paints the pose skeleton for the current frame.

    Reads ``raw_kps``, ``width``, ``height``, ``events`` and, optionally,
    ``baseline_kp_mean`` from the dict given to :meth:`set_data`. Keypoints
    are scaled from video pixels (``width`` x ``height``) to this widget's
    size. A joint or segment is tinted with a track's colour while one of
    that track's events spans the current frame. Tracks in
    ``hidden_tracks`` are treated as inactive, and the pseudo-tracks
    ``"Baseline"`` / ``"Live Skeleton"`` switch the baseline and the plain
    black skeleton off.
    """

    # Angle track -> the two segments that form that angle.
    ANGLE_SEGMENTS = {
        "L_sh": [("Lhip", "Lshoulder"), ("Lshoulder", "Lelbow")],
        "R_sh": [("Rhip", "Rshoulder"), ("Rshoulder", "Relbow")],
        "L_el": [("Lshoulder", "Lelbow"), ("Lelbow", "Lwrist")],
        "R_el": [("Rshoulder", "Relbow"), ("Relbow", "Rwrist")],
        "L_leg": [("Lhip", "Lknee"), ("Lknee", "Lankle")],
        "R_leg": [("Rhip", "Rknee"), ("Rknee", "Rankle")],
    }

    # Angle track -> joint at the vertex of that angle.
    ANGLE_VERTEX_MAP = {
        "L_sh": "Lshoulder", "R_sh": "Rshoulder",
        "L_el": "Lelbow", "R_el": "Relbow",
        "L_leg": "Lknee", "R_leg": "Rknee",
    }
    # Inverse of ANGLE_VERTEX_MAP (vertices are unique): joint -> angle track.
    _VERTEX_TO_ANGLE = {vertex: track for track, vertex in ANGLE_VERTEX_MAP.items()}

    # Colour used for an active track that has no entry in TRACK_NAMES,
    # instead of letting list.index() raise inside paintEvent.
    FALLBACK_TRACK_RGB = (255, 255, 255)

    def __init__(self, parent=None):
        debug_print()
        super().__init__(parent)
        self.setAttribute(Qt.WA_TransparentForMouseEvents)  # Clicks pass through to video
        self.data = None
        self.current_frame = 0
        self.hidden_tracks = set()
        # Index-based skeleton edges (not used by paintEvent; kept for
        # external callers).
        self.connections = [
            (5, 7), (7, 9), (6, 8), (8, 10), (5, 6), (5, 11),
            (6, 12), (11, 12), (11, 13), (13, 15), (12, 14), (14, 16)
        ]
        # Joint name -> row index in a per-frame keypoint array.
        self.KP_MAP = {
            "nose": 0, "LE": 1, "RE": 2, "Lear": 3, "Rear": 4,
            "Lshoulder": 5, "Rshoulder": 6, "Lelbow": 7, "Relbow": 8,
            "Lwrist": 9, "Rwrist": 10, "Lhip": 11, "Rhip": 12,
            "Lknee": 13, "Rknee": 14, "Lankle": 15, "Rankle": 16
        }
        # Named skeleton edges drawn by paintEvent.
        self.CONNECTIONS = [
            ("nose", "LE"), ("nose", "RE"), ("LE", "Lear"), ("RE", "Rear"),
            ("Lshoulder", "Rshoulder"), ("Lshoulder", "Lelbow"), ("Lelbow", "Lwrist"),
            ("Rshoulder", "Relbow"), ("Relbow", "Rwrist"), ("Lshoulder", "Lhip"),
            ("Rshoulder", "Rhip"), ("Lhip", "Rhip"), ("Lhip", "Lknee"),
            ("Lknee", "Lankle"), ("Rhip", "Rknee"), ("Rknee", "Rankle")
        ]

    def set_frame(self, frame_idx):
        """Show *frame_idx* and repaint."""
        self.current_frame = frame_idx
        self.update()

    def set_hidden_tracks(self, hidden_set):
        """Replace the set of hidden track names and repaint."""
        debug_print()
        self.hidden_tracks = hidden_set
        self.update()

    def set_data(self, data):
        """Replace the analysis data dict and repaint."""
        debug_print()
        self.data = data
        self.update()

    def paintEvent(self, event):
        if not self.data or 'raw_kps' not in self.data:
            return

        raw_kps = self.data['raw_kps']
        current_f = self.current_frame
        # Draw nothing (rather than raise, or wrap around for negative
        # values) when the playhead is outside the analysed frame range.
        if not 0 <= current_f < len(raw_kps):
            return

        v_w, v_h = self.data['width'], self.data['height']
        if not v_w or not v_h:
            return  # cannot scale from a zero-sized video

        scale_x, scale_y = self.width() / v_w, self.height() / v_h
        kp_live = raw_kps[current_f]
        events = self.data.get('events', {})

        def to_point(kps, name):
            """Scale joint *name* of keypoint array *kps* into widget coords."""
            idx = self.KP_MAP[name]
            return QPointF(kps[idx][0] * scale_x, kps[idx][1] * scale_y)

        def get_track_status(track_name):
            """Colour for *track_name* if one of its events spans the current
            frame and the track is not hidden; otherwise None. Alpha encodes
            severity: Small=80, Moderate=160, anything else=255."""
            if track_name in self.hidden_tracks or track_name not in events:
                return None
            for start, end, severity, _direction in events[track_name]:
                if start <= current_f <= end:
                    if track_name in TRACK_NAMES:
                        color = QColor(TRACK_COLORS[TRACK_NAMES.index(track_name)])
                    else:
                        color = QColor(*self.FALLBACK_TRACK_RGB)
                    color.setAlpha(80 if severity == "Small" else 160 if severity == "Moderate" else 255)
                    return color
            return None

        painter = QPainter(self)
        painter.setRenderHint(QPainter.Antialiasing)

        # --- Baseline (dashed), anchored on the live pelvis ---
        base_raw = self.data.get('baseline_kp_mean')
        if "Baseline" not in self.hidden_tracks and base_raw is not None:
            idx_l_hip, idx_r_hip = self.KP_MAP["Lhip"], self.KP_MAP["Rhip"]
            pelvis_live = (kp_live[idx_l_hip][:2] + kp_live[idx_r_hip][:2]) / 2
            # Re-centre the template on its own pelvis before anchoring it to
            # the live pelvis, so the template is not offset twice.
            pelvis_base = (base_raw[idx_l_hip] + base_raw[idx_r_hip]) / 2
            kp_baseline = (base_raw - pelvis_base) + pelvis_live

            painter.setPen(QPen(QColor(200, 200, 200, 200), 2, Qt.DashLine))
            for s_name, e_name in self.CONNECTIONS:
                painter.drawLine(to_point(kp_baseline, s_name), to_point(kp_baseline, e_name))

        # --- Live skeleton segments ---
        for s_name, e_name in self.CONNECTIONS:
            active_color = None
            for angle_track, segments in self.ANGLE_SEGMENTS.items():
                if (s_name, e_name) in segments or (e_name, s_name) in segments:
                    active_color = get_track_status(angle_track)
                    if active_color is not None:
                        break

            p1, p2 = to_point(kp_live, s_name), to_point(kp_live, e_name)
            if active_color is not None:
                # Active angle events always draw.
                painter.setPen(QPen(active_color, 8, Qt.SolidLine, Qt.RoundCap))
                painter.drawLine(p1, p2)
            elif "Live Skeleton" not in self.hidden_tracks:
                # Plain black segments only when Live Skeleton is shown.
                painter.setPen(QPen(Qt.black, 4, Qt.SolidLine, Qt.RoundCap))
                painter.drawLine(p1, p2)

        # --- Live skeleton joints ---
        for kp_name in self.KP_MAP:
            pt = to_point(kp_live, kp_name)

            # Point events (track named after the joint) take priority.
            point_color = get_track_status(kp_name)
            if point_color is not None:
                painter.setBrush(point_color)
                painter.setPen(QPen(Qt.white, 0.7))
                painter.drawEllipse(pt, 5, 5)
                continue

            # Otherwise tint the joint if it is the vertex of an active angle.
            angle_track = self._VERTEX_TO_ANGLE.get(kp_name)
            angle_color = get_track_status(angle_track) if angle_track else None

            if angle_color is not None:
                painter.setBrush(angle_color)
                painter.setPen(Qt.NoPen)
                painter.drawEllipse(pt, 4, 4)
            elif "Live Skeleton" not in self.hidden_tracks:
                painter.setBrush(Qt.black)
                painter.setPen(Qt.NoPen)
                painter.drawEllipse(pt, 4, 4)

        painter.end()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VideoView(QGraphicsView):
    """Chrome-less graphics view that hosts the video item.

    Scrollbars and the frame are turned off and the background is black so the
    view looks like a plain video surface. ``resized`` is emitted after Qt has
    processed every resize so owners can refit their scene items.
    """

    resized = Signal()

    def __init__(self, scene, parent=None):
        debug_print()
        super().__init__(scene, parent)

        # Hide both scrollbars; the video is always fitted to the viewport.
        for apply_policy in (self.setHorizontalScrollBarPolicy,
                             self.setVerticalScrollBarPolicy):
            apply_policy(Qt.ScrollBarAlwaysOff)

        self.setFrameStyle(0)
        self.setStyleSheet("background: black; border: none;")
        self.setAlignment(Qt.AlignCenter)

    def resizeEvent(self, event):
        debug_print()
        # Let QGraphicsView update its geometry first, then notify listeners.
        super().resizeEvent(event)
        self.resized.emit()
|
|
|
|
|
|
# ==========================================
|
|
# MAIN PREMIERE WINDOW
|
|
# ==========================================
|
|
|
|
class PremiereWindow(QMainWindow):
    """Main application window: a tabbed workspace plus menu bar and status bar.

    The last tab is always a "+" pseudo-tab. Mouse presses on it are intercepted
    by :meth:`eventFilter` and open a menu for creating new video-analysis,
    model-training or configuration tabs instead of selecting the tab.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle(f"Pose Analysis Timeline - {APP_NAME}")
        self.resize(1200, 900)

        self.platform_suffix = "-" + PLATFORM_NAME

        # Application-wide Updaters
        self.updater = UpdateManager(
            main_window=self,
            api_url=API_URL,
            api_url_sec=API_URL_SECONDARY,
            current_version=CURRENT_VERSION,
            platform_name=PLATFORM_NAME,
            platform_suffix=self.platform_suffix,
            app_name=APP_NAME
        )

        self.setStyleSheet("""
            QMainWindow, QWidget#centralWidget { background-color: #1e1e1e; }
            QLabel, QStatusBar, QMenuBar { color: #ffffff; }
            QTabWidget::pane { border: 1px solid #333333; background: #1e1e1e; }
            QTabBar::tab {
                background: #2b2b2b;
                color: #aaa;
                padding: 8px 15px;
                border: 1px solid #333;
                border-bottom: none;
                border-top-left-radius: 4px;
                border-top-right-radius: 4px;
            }

            /* 2. Content Tabs: Selected & Hover States */
            /* We add :!last to ensure these NEVER apply to the + tab */
            QTabBar::tab:selected:!last {
                background: #3d3d3d;
                color: #fff;
                font-weight: bold;
            }

            QTabBar::tab:hover:!last {
                background: #444;
            }

            /* 3. THE PLUS TAB: Constant State */
            /* We define all states (normal, selected, hover) to be identical */
            QTabBar::tab:last,
            QTabBar::tab:last:hover,
            QTabBar::tab:last:selected {
                background: #1e1e1e;
                color: #00aaff;
                font-weight: bold;
                margin-left: 2px;
                border: 1px solid #333;
                padding: 4px 12px;
            }""")

        # --- Tab System ---
        self.tabs = QTabWidget()
        self.tabs.setTabsClosable(True)
        self.tabs.tabCloseRequested.connect(self.close_tab)
        self.setCentralWidget(self.tabs)

        self.create_welcome_tab()

        # The "+" pseudo-tab always occupies the last index.
        self.tabs.addTab(QWidget(), "+")
        plus_idx = self.tabs.count() - 1
        # Strip the close button from the "+" tab. Which side the button sits on
        # is style dependent (e.g. left on macOS), so clear both sides.
        for side in (QTabBar.ButtonPosition.LeftSide, QTabBar.ButtonPosition.RightSide):
            self.tabs.tabBar().setTabButton(plus_idx, side, None)

        # Intercept presses on the tab bar so "+" opens a menu (see eventFilter).
        self.tabs.tabBar().installEventFilter(self)

        self.create_menu_bar()

        # Update checks
        self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix, PLATFORM_NAME, APP_NAME)
        self.local_check_thread.pending_update_found.connect(self.updater.on_pending_update_found)
        self.local_check_thread.no_pending_update.connect(self.updater.on_no_pending_update)
        self.local_check_thread.start()

        # Window instances
        self.load_window = None
        self.train_window = None
        self.export_window = None
        self.about = None
        self.help = None
        # Progress dialog shown while a TrainingWorker runs (see handle_start_training).
        self.loading_dialog = None

    def eventFilter(self, obj, event):
        """Turn a mouse press on the "+" tab into the new-tab menu.

        Returns True (consuming the event) for presses on the "+" tab so
        QTabWidget never switches to its empty placeholder page; everything
        else is forwarded to the default implementation.
        """
        tab_bar = self.tabs.tabBar()
        if obj == tab_bar and event.type() == QEvent.MouseButtonPress:
            # QMouseEvent.pos() is deprecated in Qt 6; position() is a QPointF.
            index = tab_bar.tabAt(event.position().toPoint())

            # Identify "+" by position rather than text so a tab whose label
            # happens to be "+" cannot trigger the menu.
            if index != -1 and index == self.tabs.count() - 1:
                self.show_new_tab_menu()
                return True

        return super().eventFilter(obj, event)

    def show_new_tab_menu(self):
        """Pop up the "new tab" menu at the current cursor position."""
        from PySide6.QtWidgets import QMenu  # QAction/QCursor are imported at module level

        menu = QMenu(self)

        load_act = QAction("Load Video", self)
        load_act.triggered.connect(self.open_load_video_dialog)

        train_act = QAction("Train Model", self)
        train_act.triggered.connect(self.open_train_model_dialog)

        config_act = QAction("Configuration (Ind.)", self)
        config_act.triggered.connect(lambda: self.open_model_configuration_tab("individual"))

        config_act2 = QAction("Configuration (Group)", self)
        config_act2.triggered.connect(lambda: self.open_model_configuration_tab("group"))

        menu.addAction(load_act)
        menu.addAction(train_act)
        menu.addAction(config_act)
        menu.addAction(config_act2)

        # Show the menu right under the mouse cursor
        menu.exec(QCursor.pos())

    def create_welcome_tab(self):
        """Append the static "Welcome" landing tab."""
        welcome_widget = QWidget()
        layout = QVBoxLayout(welcome_widget)

        title = QLabel(f"Welcome to {APP_NAME}")
        title.setStyleSheet("font-size: 24px; font-weight: bold; color: #00aaff;")
        title.setAlignment(Qt.AlignCenter)

        subtitle = QLabel("Click 'File' > 'Load Video...' to begin a new analysis session.")
        subtitle.setStyleSheet("font-size: 14px; color: #aaaaaa;")
        subtitle.setAlignment(Qt.AlignCenter)

        layout.addStretch()
        layout.addWidget(title)
        layout.addWidget(subtitle)
        layout.addStretch()

        self.tabs.addTab(welcome_widget, "Welcome")

    def create_menu_bar(self):
        """Build the File / Edit / View menus and initialise the status bar."""
        menu_bar = self.menuBar()
        self.statusbar = self.statusBar()

        def make_action(name, shortcut=None, slot=None, checkable=False, checked=False, icon=None):
            action = QAction(name, self)
            if shortcut: action.setShortcut(QKeySequence(shortcut))
            if slot: action.triggered.connect(slot)
            if checkable:
                action.setCheckable(True)
                action.setChecked(checked)
            if icon: action.setIcon(QIcon(icon))
            return action

        # File Menu
        file_menu = menu_bar.addMenu("File")
        file_menu.addAction(make_action("Load Video...", "Ctrl+O", self.open_load_video_dialog))
        file_menu.addSeparator()
        file_menu.addAction(make_action("Exit", "Ctrl+Q", QApplication.instance().quit))

        # Edit Menu (Routes to current tab)
        edit_menu = menu_bar.addMenu("Edit")
        edit_menu.addAction(make_action("Cut", "Ctrl+X", self.route_cut))
        edit_menu.addAction(make_action("Copy", "Ctrl+C", self.route_copy))
        edit_menu.addAction(make_action("Paste", "Ctrl+V", self.route_paste))

        # View Menu
        view_menu = menu_bar.addMenu("View")
        toggle_sb = make_action("Toggle Status Bar", checkable=True, checked=True)
        toggle_sb.toggled.connect(self.statusbar.setVisible)
        view_menu.addAction(toggle_sb)

        self.statusbar.showMessage("Ready")

    # --- Tab & Loading Logic ---

    def open_load_video_dialog(self):
        """Show the load-video dialog, creating it if none is currently visible."""
        if self.load_window is None or not self.load_window.isVisible():
            self.load_window = OpenFileWindow(self)
            # The dialog's confirm button creates the new analysis tab.
            self.load_window.btn_confirm.clicked.connect(self.handle_new_video_session)
        self.load_window.show()

    def open_train_model_dialog(self):
        """Show the train-model dialog, creating it if none is currently visible."""
        if self.train_window is None or not self.train_window.isVisible():
            self.train_window = TrainModelWindow(self)
            # The dialog's train button starts the background training job.
            self.train_window.btn_train.clicked.connect(self.handle_start_training)
        self.train_window.show()

    def open_model_configuration_tab(self, stringy):
        """Open a configuration-editor tab; ``stringy`` is "individual" or "group" from the menu."""
        self.handle_model_configuration_session(stringy)

    def open_export_data_dialog(self):
        """Show the timeline-export dialog, creating it if none is currently visible."""
        if self.export_window is None or not self.export_window.isVisible():
            # NOTE(review): VideoAnalysisTab constructs this class with keyword
            # arguments (timeline_data=, fps=, parent=). Passing ``self``
            # positionally here may bind it to timeline_data — confirm the signature.
            self.export_window = ExportTimelineJsonWindow(self)
        self.export_window.show()

    def handle_start_training(self):
        """Start a TrainingWorker with the train dialog's selection, behind a busy dialog."""
        params = self.train_window.get_selection()

        # A range of (0, 0) makes QProgressDialog an indeterminate "busy" indicator.
        self.loading_dialog = QProgressDialog("Processing data and training Random Forest...", None, 0, 0, self)
        self.loading_dialog.setWindowTitle("Training Model")
        self.loading_dialog.setWindowModality(Qt.WindowModal)
        self.loading_dialog.setCancelButton(None)  # Remove cancel button to prevent interruption
        self.loading_dialog.setMinimumDuration(0)  # Show immediately
        self.loading_dialog.show()

        self.training_thread = TrainingWorker(params)
        self.training_thread.finished.connect(self.on_training_finished)
        self.training_thread.error.connect(self.on_training_error)
        # NOTE(review): if TrainingWorker re-declares ``finished`` as a custom
        # signal, deleteLater is not run on the error path — confirm.
        self.training_thread.finished.connect(self.training_thread.deleteLater)

        self.training_thread.start()

        # Close the selection window
        self.train_window.close()

    def on_training_finished(self, report_html):
        """Close the busy dialog and show the training report."""
        # reset() on QProgressDialog closes it.
        if self.loading_dialog:
            self.loading_dialog.reset()
            self.loading_dialog = None

        self.display_ml_results(report_html)

    def on_training_error(self, error_msg):
        """Close the busy dialog and report a training failure."""
        if self.loading_dialog:
            self.loading_dialog.reset()
            self.loading_dialog = None

        QMessageBox.critical(self, "Training Error", f"An error occurred: {error_msg}")

    def display_ml_results(self, report):
        """Displays the RF performance report in a simple popup."""
        msg = QMessageBox(self)
        msg.setWindowTitle("Training Results")
        msg.setTextFormat(Qt.RichText)
        msg.setText(report)
        msg.exec()

    def handle_new_video_session(self):
        """Create a VideoAnalysisTab from the load dialog's config, just left of "+"."""
        config = self.load_window.get_config()

        self.load_window.close()

        # The config tells the tab whether to run pose analysis or just play video.
        new_tab = VideoAnalysisTab(config)

        # Insert before the "+" tab so it stays last.
        tab_name = os.path.basename(config['video_path'])
        plus_idx = self.tabs.count() - 1
        new_idx = self.tabs.insertTab(plus_idx, new_tab, tab_name)

        self.tabs.setCurrentIndex(new_idx)

    def handle_model_configuration_session(self, stringy):
        """Create a ModelParameterConfigurationTab just left of the "+" tab."""
        new_tab = ModelParameterConfigurationTab(stringy)

        # Insert before the "+" tab so it stays last.
        tab_name = f"Configuration Editor - {stringy}"
        plus_idx = self.tabs.count() - 1
        new_idx = self.tabs.insertTab(plus_idx, new_tab, tab_name)

        self.tabs.setCurrentIndex(new_idx)

    def close_tab(self, index):
        """Close the tab at ``index``, running its ``cleanup()`` hook if it has one.

        The "+" pseudo-tab is never closed, and neither is the last real tab:
        with only "+" left the window would show an empty placeholder page.
        """
        plus_idx = self.tabs.count() - 1
        if index < 0 or index >= plus_idx or plus_idx <= 1:
            return

        widget = self.tabs.widget(index)
        if widget:
            # Give the tab a chance to stop playback / background work.
            if hasattr(widget, 'cleanup'):
                widget.cleanup()
            widget.deleteLater()
        self.tabs.removeTab(index)

        # Removing the tab right before "+" moves the selection onto the empty
        # "+" page; step back onto the nearest real tab instead.
        plus_idx = self.tabs.count() - 1
        if self.tabs.currentIndex() == plus_idx and plus_idx > 0:
            self.tabs.setCurrentIndex(plus_idx - 1)

    # --- Routing Menu Actions to Current Tab ---

    def get_current_tab(self):
        """Return the widget of the currently selected tab."""
        return self.tabs.currentWidget()

    def route_copy(self):
        """Forward Edit > Copy to the current tab's ``info_label``, if it has one."""
        tab = self.get_current_tab()
        if hasattr(tab, 'info_label'):
            tab.info_label.copy()
            self.statusbar.showMessage("Copied to clipboard")

    def route_cut(self):
        """Forward Edit > Cut to the current tab's ``info_label``, if it has one."""
        tab = self.get_current_tab()
        if hasattr(tab, 'info_label'):
            tab.info_label.cut()
            self.statusbar.showMessage("Cut to clipboard")

    def route_paste(self):
        """Forward Edit > Paste to the current tab's ``info_label``, if it has one."""
        tab = self.get_current_tab()
        if hasattr(tab, 'info_label'):
            tab.info_label.paste()
            self.statusbar.showMessage("Pasted from clipboard")
|
|
|
|
|
|
|
|
class VideoAnalysisTab(QWidget):
|
|
def __init__(self, config):
|
|
super().__init__()
|
|
|
|
# State
|
|
self.config = config
|
|
|
|
self.setStyleSheet("""
|
|
QMainWindow, QWidget#centralWidget {
|
|
background-color: #1e1e1e;
|
|
}
|
|
QLabel, QStatusBar, QMenuBar {
|
|
color: #ffffff;
|
|
}
|
|
/* Target the Timeline specifically */
|
|
TimelineWidget {
|
|
background-color: #1e1e1e;
|
|
border: 1px solid #333333;
|
|
}
|
|
/* Button styling with Grey borders */
|
|
QDialog, QMessageBox, QFileDialog {
|
|
background-color: #2b2b2b;
|
|
}
|
|
QDialog QLabel, QMessageBox QLabel {
|
|
color: #ffffff;
|
|
}
|
|
QPushButton {
|
|
background-color: #2b2b2b;
|
|
color: #ffffff;
|
|
border: 1px solid #555555; /* Subtle Grey border */
|
|
border-radius: 3px;
|
|
padding: 4px;
|
|
}
|
|
QPushButton:hover {
|
|
background-color: #3d3d3d;
|
|
border-color: #888888; /* Brightens border on hover */
|
|
}
|
|
QPushButton:pressed {
|
|
background-color: #111111;
|
|
}
|
|
QPushButton:disabled {
|
|
border-color: #333333;
|
|
color: #444444;
|
|
}
|
|
/* Splitter/Divider styling */
|
|
QSplitter::handle {
|
|
background-color: #333333; /* Dark grey dividers */
|
|
}
|
|
QSplitter::handle:horizontal {
|
|
width: 2px;
|
|
}
|
|
QSplitter::handle:vertical {
|
|
height: 2px;
|
|
}
|
|
/* ScrollArea styling to keep it dark */
|
|
QScrollArea, QScrollArea > QWidget > QWidget {
|
|
background-color: #1e1e1e;
|
|
border: none;
|
|
}
|
|
""")
|
|
|
|
|
|
self.setup_ui()
|
|
|
|
self.initialize_session()
|
|
|
|
|
|
def setup_ui(self):
|
|
main_layout = QVBoxLayout(self)
|
|
main_layout.setContentsMargins(0, 0, 0, 0)
|
|
|
|
self.main_splitter = QSplitter(Qt.Vertical)
|
|
top_splitter = QSplitter(Qt.Horizontal)
|
|
|
|
# --- Video Area ---
|
|
video_container = QWidget()
|
|
video_layout = QVBoxLayout(video_container)
|
|
|
|
self.scene = QGraphicsScene()
|
|
self.view = VideoView(self.scene)
|
|
self.view.resized.connect(self.update_video_geometry)
|
|
|
|
self.video_item = QGraphicsVideoItem()
|
|
self.scene.addItem(self.video_item)
|
|
|
|
# Overlay initialization
|
|
#self.skeleton_overlay = SkeletonOverlay(self.view.viewport())
|
|
|
|
self.player = QMediaPlayer()
|
|
self.audio_output = QAudioOutput()
|
|
self.player.setAudioOutput(self.audio_output)
|
|
self.player.setVideoOutput(self.video_item)
|
|
|
|
video_layout.addWidget(self.view)
|
|
|
|
# --- Controls Area ---
|
|
|
|
controls_container = QWidget()
|
|
stacked_controls = QVBoxLayout(controls_container)
|
|
stacked_controls.setSpacing(5) # Tight spacing between rows
|
|
|
|
# --- ROW 2: Playback & Transport ---
|
|
playback_row = QHBoxLayout()
|
|
playback_row.addStretch()
|
|
|
|
# Transport Buttons
|
|
self.btn_start = QPushButton("|<")
|
|
self.btn_prev = QPushButton("<")
|
|
self.btn_play = QPushButton("Play")
|
|
self.btn_next = QPushButton(">")
|
|
self.btn_end = QPushButton(">|")
|
|
|
|
self.transport_btns = [self.btn_start, self.btn_prev, self.btn_play, self.btn_next, self.btn_end]
|
|
for btn in self.transport_btns:
|
|
btn.setEnabled(False)
|
|
btn.setFixedWidth(50)
|
|
playback_row.addWidget(btn)
|
|
|
|
self.btn_mute = QPushButton("Vol")
|
|
self.btn_mute.setFixedWidth(40)
|
|
self.btn_mute.setCheckable(True)
|
|
self.btn_mute.clicked.connect(self.toggle_mute)
|
|
|
|
self.sld_volume = QSlider(Qt.Horizontal)
|
|
self.sld_volume.setRange(0, 100)
|
|
self.sld_volume.setValue(100) # Default volume
|
|
self.sld_volume.setFixedWidth(100)
|
|
self.sld_volume.valueChanged.connect(self.update_volume)
|
|
|
|
# Initialize volume
|
|
self.audio_output.setVolume(0.7)
|
|
|
|
playback_row.addWidget(self.btn_mute)
|
|
playback_row.addWidget(self.sld_volume)
|
|
|
|
# Counters
|
|
counter_style = "font-family: 'Consolas'; font-size: 10pt; margin-left: 5px; color: #00FF00;"
|
|
self.lbl_time_counter = QLabel("Time: 00:00 / 00:00")
|
|
self.lbl_frame_counter = QLabel("Frame: 0 / 0")
|
|
self.lbl_time_counter.setFixedWidth(180)
|
|
self.lbl_frame_counter.setFixedWidth(180)
|
|
self.lbl_time_counter.setStyleSheet(counter_style)
|
|
self.lbl_frame_counter.setStyleSheet(counter_style)
|
|
|
|
playback_row.addWidget(self.lbl_time_counter)
|
|
playback_row.addWidget(self.lbl_frame_counter)
|
|
|
|
playback_row.addStretch()
|
|
|
|
# --- Add Rows to Stack ---
|
|
stacked_controls.addLayout(playback_row)
|
|
|
|
video_layout.addWidget(controls_container)
|
|
|
|
info_container = QWidget()
|
|
info_layout = QVBoxLayout(info_container)
|
|
|
|
self.progress_container = QWidget()
|
|
progress_layout = QVBoxLayout(self.progress_container)
|
|
|
|
self.lbl_analysis_status = QLabel("Pose Analysis: Idle")
|
|
self.analysis_bar = QProgressBar()
|
|
self.analysis_bar.setRange(0, 100)
|
|
self.analysis_bar.setValue(0)
|
|
self.analysis_bar.setStyleSheet("""
|
|
QProgressBar { border: 1px solid #555; border-radius: 2px; text-align: center; height: 15px; }
|
|
QProgressBar::chunk { background-color: #00aaff; }
|
|
""")
|
|
|
|
progress_layout.addWidget(self.lbl_analysis_status)
|
|
progress_layout.addWidget(self.analysis_bar)
|
|
|
|
# Insert into info_layout (above the inspector scroll area)
|
|
info_layout.insertWidget(0, self.progress_container)
|
|
|
|
# NEW: Wrap the info_label in a Scroll Area
|
|
self.inspector_scroll = QScrollArea()
|
|
self.inspector_scroll.setWidgetResizable(True)
|
|
|
|
self.info_label = QTextEdit()
|
|
self.info_label.setReadOnly(True)
|
|
self.info_label.setStyleSheet("padding: 8px; font-family: 'Consolas', 'Segoe UI'; color: #ffffff;")
|
|
self.inspector_scroll.setWidget(self.info_label)
|
|
|
|
# NEW: Export Button for Metrics
|
|
self.btn_export_metrics = QPushButton("Export Data for Machine Learning...")
|
|
self.btn_export_metrics.clicked.connect(self.export_behavior_metrics)
|
|
self.btn_export_metrics.setEnabled(False) # Enable only after load
|
|
|
|
self.btn_export_flares = QPushButton("Export Timeline Events for FLARES...")
|
|
self.btn_export_flares.clicked.connect(self.export_timeline_flares)
|
|
self.btn_export_flares.setEnabled(False) # Enable only after load
|
|
|
|
info_layout.addWidget(self.inspector_scroll)
|
|
info_layout.addWidget(self.btn_export_metrics)
|
|
info_layout.addWidget(self.btn_export_flares)
|
|
|
|
top_splitter.addWidget(video_container)
|
|
top_splitter.addWidget(info_container)
|
|
top_splitter.setSizes([800, 400])
|
|
|
|
self.timeline = TimelineWidget()
|
|
self.timeline.seek_requested.connect(self.seek_video)
|
|
|
|
scroll_area = QScrollArea()
|
|
scroll_area.setWidgetResizable(True)
|
|
scroll_area.setWidget(self.timeline)
|
|
|
|
self.main_splitter.addWidget(top_splitter)
|
|
self.main_splitter.addWidget(scroll_area)
|
|
self.main_splitter.setSizes([500, 400])
|
|
main_layout.addWidget(self.main_splitter)
|
|
|
|
|
|
self.skeleton_overlay = SkeletonOverlay(self.view)
|
|
self.skeleton_overlay.resize(self.view.size())
|
|
self.skeleton_overlay.hide()
|
|
|
|
# 2. FIX: Watch the view for resizes
|
|
self.view.installEventFilter(self)
|
|
|
|
self.player.positionChanged.connect(self.update_timeline_playhead)
|
|
|
|
self.setup_transport() # Start with empty workspace until worker finishes
|
|
# self.load_boris_to_timeline()
|
|
self.video_item.nativeSizeChanged.connect(self.update_video_geometry)
|
|
self.start_analysis()
|
|
|
|
|
|
|
|
def start_analysis(self):
|
|
if not self.config.get("use_pose", True):
|
|
self.lbl_analysis_status.setText("Pose Analysis: Bypassed")
|
|
self.analysis_bar.setValue(100)
|
|
self.skeleton_overlay.hide()
|
|
return
|
|
|
|
# 1. Setup Queues
|
|
self.prog_q = Queue()
|
|
self.res_q = Queue()
|
|
|
|
# 2. Create the Process
|
|
self.analysis_proc = Process(
|
|
target=run_pose_analysis,
|
|
args=(self.config['video_path'], self.prog_q, self.res_q, self.config),
|
|
name="PoseWorkerProcess"
|
|
)
|
|
|
|
# 3. UI Updates
|
|
self.lbl_analysis_status.setText("Process Started...")
|
|
self.analysis_bar.setValue(0)
|
|
self.analysis_bar.show()
|
|
|
|
# 4. Start
|
|
self.analysis_proc.start()
|
|
|
|
# 5. Timer to check the Queue
|
|
self.poll_timer = QTimer()
|
|
self.poll_timer.timeout.connect(self.check_queues)
|
|
self.poll_timer.start(100)
|
|
|
|
|
|
def check_queues(self):
|
|
# Drain progress queue
|
|
while not self.prog_q.empty():
|
|
val = self.prog_q.get()
|
|
self.analysis_bar.setValue(val)
|
|
self.lbl_analysis_status.setText(f"Extracting Poses: {val}%")
|
|
|
|
# Check result queue
|
|
if not self.res_q.empty():
|
|
data = self.res_q.get()
|
|
self.poll_timer.stop()
|
|
self.handle_finished_data(data)
|
|
|
|
|
|
def handle_finished_data(self, data):
|
|
"""
|
|
Runs on the main thread. Loads BORIS instantly, shows the skeleton,
|
|
and kicks off AI in the background.
|
|
"""
|
|
# 1. Unpack basic data
|
|
self.raw_kpts = data["raw_kpts"]
|
|
self.fps = data.get("fps", 30.0)
|
|
self.total_frames = data["total_frames"]
|
|
v_w, v_h = data["dims"]
|
|
|
|
# 2. Setup master dictionary
|
|
self.processed_data = {}
|
|
|
|
# 3. IMMEDIATE: Load BORIS (if available)
|
|
if self.config.get('use_boris'):
|
|
self.load_boris_to_timeline()
|
|
|
|
# 4. IMMEDIATE: Show Skeleton Overlay
|
|
# We calculate the baseline mean here because it's fast (Numpy)
|
|
raw_kps_per_frame = [frame[:, :2] for frame in self.raw_kpts]
|
|
valid_mask = [np.any(kp) for kp in raw_kps_per_frame]
|
|
valid_data = [raw_kps_per_frame[i] for i, v in enumerate(valid_mask) if v]
|
|
|
|
baseline_mean = np.mean(valid_data, axis=0) if valid_data else np.zeros((17, 2))
|
|
|
|
if self.config.get('use_calculations'):
|
|
dists, velocities = self.generate_automated_tracks(baseline_mean)
|
|
|
|
if self.config.get('velocity_enabled'):
|
|
for joint_name, idx in self.skeleton_overlay.KP_MAP.items():
|
|
vel_events = self.create_event_blocks(velocities[:, idx], threshold=float(self.config.get('velocity_threshold', 15)))
|
|
if vel_events:
|
|
self.processed_data[f"Vel_{joint_name}"] = vel_events
|
|
|
|
if self.config.get('deviation_enabled'):
|
|
# 2. Convert to timeline events (using a threshold of e.g. 50 pixels)
|
|
for joint_name, idx in self.skeleton_overlay.KP_MAP.items():
|
|
dev_events = self.create_event_blocks(dists[:, idx], threshold=float(self.config.get('deviation_threshold', 50)))
|
|
if dev_events:
|
|
self.processed_data[f"Dev_{joint_name}"] = dev_events
|
|
|
|
|
|
overlay_payload = {
|
|
"raw_kps": self.raw_kpts,
|
|
"width": v_w,
|
|
"height": v_h,
|
|
"events": self.processed_data, # Initially just BORIS
|
|
"baseline_kp_mean": baseline_mean
|
|
}
|
|
self.skeleton_overlay.set_data(overlay_payload)
|
|
self.skeleton_overlay.show()
|
|
|
|
# 5. SYNC UI (Show the manual timeline immediately)
|
|
self.sync_timeline_to_ui()
|
|
|
|
# 6. BACKGROUND: Start AI Inference Thread
|
|
if self.config.get('use_pkl') and hasattr(self, 'ml_model'):
|
|
self.lbl_analysis_status.setText("Running AI Inference...")
|
|
|
|
# Start the worker thread
|
|
self.ml_worker = MLInferenceWorker(
|
|
self.raw_kpts,
|
|
self.ml_model,
|
|
self.ml_scaler,
|
|
self.active_features,
|
|
self.ml_metadata.get('target_behavior', 'Reach')
|
|
)
|
|
self.ml_worker.finished.connect(self.on_ai_inference_complete)
|
|
self.ml_worker.error.connect(lambda e: print(f"AI ERROR: {e}"))
|
|
self.ml_worker.start()
|
|
else:
|
|
self.lbl_analysis_status.setText("Analysis Complete")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_automated_tracks(self, baseline_mean):
|
|
# Convert list of frames to a numpy stack: (frames, joints, xy)
|
|
raw_stack = np.stack([f[:, :2] for f in self.raw_kpts])
|
|
|
|
# 1. Get Live Pelvis (Anchor)
|
|
l_hip_idx, r_hip_idx = 11, 12 # Adjust based on your KP_MAP
|
|
pelvis_live = (raw_stack[:, l_hip_idx] + raw_stack[:, r_hip_idx]) / 2
|
|
|
|
# 2. Center the Baseline Template
|
|
pelvis_base = (baseline_mean[l_hip_idx] + baseline_mean[r_hip_idx]) / 2
|
|
base_template = baseline_mean - pelvis_base
|
|
|
|
# 3. Broadcast template across all frames
|
|
# (frames, 1, 2) + (17, 2) -> (frames, 17, 2)
|
|
target_pos = pelvis_live[:, np.newaxis, :] + base_template
|
|
|
|
# 4. Euclidean Distance
|
|
dists = np.linalg.norm(raw_stack - target_pos, axis=2)
|
|
|
|
# 5. Velocity (Difference between frames)
|
|
velocities = np.zeros_like(dists)
|
|
velocities[1:] = np.linalg.norm(np.diff(raw_stack, axis=0), axis=2)
|
|
|
|
return dists, velocities
|
|
|
|
|
|
|
|
def create_event_blocks(self, data_array, threshold):
|
|
events = []
|
|
active = False
|
|
start_f = 0
|
|
|
|
mask = data_array > threshold
|
|
for f, is_high in enumerate(mask):
|
|
if is_high and not active:
|
|
active = True
|
|
start_f = f
|
|
elif not is_high and active:
|
|
active = False
|
|
if (f - start_f) > 5: # Filter out noise shorter than 5 frames
|
|
events.append([start_f, f, "Moderate", "N/A"])
|
|
return events
|
|
|
|
|
|
|
|
def on_ai_inference_complete(self, ai_events):
|
|
"""Runs when the thread finishes. Merges AI into the existing UI."""
|
|
# 1. Merge AI tracks into the dictionary that already has BORIS
|
|
for track_name, blocks in ai_events.items():
|
|
self.processed_data[track_name] = blocks
|
|
print(f"AI Thread complete: Injected {track_name}")
|
|
|
|
# 2. Update the Timeline Widget visually
|
|
self.sync_timeline_to_ui()
|
|
|
|
# 3. Update the Skeleton Overlay specifically
|
|
# This ensures the AI blocks show up in the video player's seek bar/overlay
|
|
if hasattr(self, 'skeleton_overlay'):
|
|
# Fetch the old data dict, update the 'events' key, and push it back
|
|
updated_payload = self.skeleton_overlay.data.copy()
|
|
updated_payload['events'] = self.processed_data
|
|
self.skeleton_overlay.set_data(updated_payload)
|
|
|
|
self.lbl_analysis_status.setText("All Tracks (BORIS + AI) Loaded")
|
|
|
|
|
|
|
|
def load_boris_to_timeline(self):
|
|
"""
|
|
Parses the JSON to identify unique behaviors and prepare the data.
|
|
"""
|
|
if not self.config.get('use_boris', False) or not self.config.get('obs_file'):
|
|
return
|
|
|
|
try:
|
|
with open(self.config['obs_file'], 'r') as f:
|
|
data = json.load(f)
|
|
|
|
behav_conf = data.get("behaviors_conf", {})
|
|
type_lookup = {v['code']: v['type'] for v in behav_conf.values()}
|
|
|
|
session_key = self.config.get('session_key')
|
|
session = data.get("observations", {}).get(session_key, {})
|
|
events = session.get("events", [])
|
|
|
|
unique_behaviors = sorted(list(set(e[2] for e in events)))
|
|
new_boris_data = {name: [] for name in unique_behaviors}
|
|
|
|
fps = self.config.get('fps', 30.0)
|
|
# Use the offset from the config
|
|
offset_frames = int(self.config.get('offset', 0.0) * fps)
|
|
|
|
state_tracker = {}
|
|
|
|
for e in events:
|
|
timestamp = float(e[0])
|
|
behavior_name = e[2]
|
|
event_type = type_lookup.get(behavior_name, "Point event")
|
|
current_frame = int(timestamp * fps) - offset_frames
|
|
|
|
if event_type == "Point event":
|
|
new_boris_data[behavior_name].append([current_frame, current_frame + 1, "Normal", "N/A"])
|
|
elif event_type == "State event":
|
|
if behavior_name in state_tracker:
|
|
start_f = state_tracker.pop(behavior_name)
|
|
new_boris_data[behavior_name].append([start_f, current_frame, "Normal", "N/A"])
|
|
else:
|
|
state_tracker[behavior_name] = current_frame
|
|
|
|
if not hasattr(self, 'processed_data') or self.processed_data is None:
|
|
self.processed_data = {}
|
|
|
|
for behavior, blocks in new_boris_data.items():
|
|
self.processed_data[behavior] = blocks
|
|
|
|
print(f"[INFO] Merged {len(new_boris_data)} BORIS tracks into timeline.")
|
|
|
|
except Exception as e:
|
|
print(f"Error loading BORIS data: {e}")
|
|
|
|
|
|
def sync_timeline_to_ui(self):
|
|
"""Final push of all merged data (BORIS + AI) to the UI components."""
|
|
if not hasattr(self, 'processed_data'):
|
|
return
|
|
|
|
total_f = self.config.get('total_frames', self.total_frames)
|
|
fps = self.config.get('fps', self.fps)
|
|
|
|
# 1. Update the Timeline Widget
|
|
self.timeline.set_data(self.processed_data, total_f, fps)
|
|
|
|
# 2. Update Stats & Export button
|
|
self.update_inspector_stats(self.processed_data, fps)
|
|
self.btn_export_metrics.setEnabled(True)
|
|
self.btn_export_flares.setEnabled(True)
|
|
|
|
print(f"DEBUG: UI Synced with {len(self.processed_data)} total tracks.")
|
|
|
|
|
|
def update_inspector_stats(self, data, fps):
|
|
"""Calculates and displays behavior summary on the right panel."""
|
|
stats_text = "<b><font color='#00FF00'>SESSION SUMMARY</font></b><br><br>"
|
|
stats_text += f"File: {self.config['session_key']}<br>"
|
|
stats_text += "-----------------------------------<br>"
|
|
|
|
for behavior, instances in data.items():
|
|
durations = [(end - start) / fps for start, end, _, _ in instances]
|
|
count = len(instances)
|
|
avg_dur = np.mean(durations) if durations else 0
|
|
total_dur = sum(durations)
|
|
|
|
stats_text += f"<b>{behavior}</b>:<br>"
|
|
stats_text += f" - Occurrences: {count}<br>"
|
|
stats_text += f" - Avg Duration: {avg_dur:.3f}s<br>"
|
|
stats_text += f" - Total Time: {total_dur:.2f}s<br><br>"
|
|
|
|
self.info_label.setHtml(stats_text)
|
|
|
|
|
|
def export_timeline_flares(self):
|
|
export_dialog = ExportTimelineJsonWindow(
|
|
timeline_data=self.processed_data,
|
|
fps=self.fps,
|
|
parent=self
|
|
)
|
|
export_dialog.exec()
|
|
|
|
|
|
def export_behavior_metrics(self):
|
|
"""Exports the processed behavior metrics to a new JSON file."""
|
|
if not self.processed_data:
|
|
return
|
|
|
|
export_payload = {
|
|
"metadata": {
|
|
"source_video": os.path.basename(self.config.get('video_path', 'unknown')),
|
|
"session": self.config.get('session_key', 'unknown'),
|
|
"pose_model": self.config.get('pose_model', 'unknown'),
|
|
"export_timestamp": datetime.now().isoformat(),
|
|
"fps": self.config.get('fps', 30.0)
|
|
},
|
|
"behaviors": {}
|
|
}
|
|
|
|
for behavior, instances in self.processed_data.items():
|
|
behavior_events = []
|
|
|
|
for start_f, end_f, _, _ in instances:
|
|
behavior_events.append({
|
|
"start_frame": int(start_f),
|
|
"duration_frames": int(end_f - start_f)
|
|
})
|
|
|
|
export_payload["behaviors"][behavior] = behavior_events
|
|
|
|
video_path = self.config.get('video_path')
|
|
if not video_path:
|
|
return
|
|
|
|
base_path = video_path.rsplit('.', 1)[0]
|
|
export_path = f"{base_path}_metrics.json"
|
|
|
|
# 3. SILENT SAVE
|
|
try:
|
|
with open(export_path, 'w') as f:
|
|
json.dump(export_payload, f, indent=4)
|
|
|
|
# Log the success so the researcher knows where it went
|
|
msg = f"<br><b><font color='#00ffaa'>[EXPORT COMPLETE]</font></b>: {os.path.basename(export_path)}"
|
|
self.info_label.append(msg)
|
|
print(f"Metrics saved to: {export_path}")
|
|
|
|
except Exception as e:
|
|
self.info_label.append(f"<br><font color='#ff4444'>[EXPORT FAILED]</font>: {e}")
|
|
|
|
|
|
|
|
def setup_transport(self):
|
|
"""Sets up player controls that don't depend on skeleton analysis."""
|
|
# Enable buttons immediately
|
|
for btn in [self.btn_play, self.btn_prev, self.btn_next, self.btn_start, self.btn_end]:
|
|
btn.setEnabled(True)
|
|
|
|
# Connections (Use disconnect first to avoid double-firing if re-called)
|
|
try: self.btn_play.clicked.disconnect()
|
|
except: pass
|
|
|
|
self.btn_play.clicked.connect(self.toggle_playback)
|
|
self.btn_start.clicked.connect(lambda: self.player.setPosition(0))
|
|
# Note: 'End' and 'Step' need FPS/Duration, handled in the methods themselves
|
|
self.btn_end.clicked.connect(lambda: self.player.setPosition(self.player.duration()))
|
|
self.btn_prev.clicked.connect(lambda: self.step_frame(-1))
|
|
self.btn_next.clicked.connect(lambda: self.step_frame(1))
|
|
|
|
self.player.setSource(QUrl.fromLocalFile(self.config['video_path']))
|
|
self.player.pause()
|
|
# self.player.mediaStatusChanged.connect(self.initial_resize_hack)
|
|
|
|
|
|
# def initial_resize_hack(self, status):
|
|
# # Once the media is loaded, refresh the layout
|
|
# if status >= QMediaPlayer.MediaStatus.LoadedMedia:
|
|
# self.update_video_geometry()
|
|
# # Seek to 0 just to be absolutely sure the buffer updates
|
|
# self.player.setPosition(0)
|
|
|
|
|
|
def initialize_session(self):
    """Load the optional session components enabled in ``self.config``.

    * ``use_boris``: loads the BORIS annotation track (``load_boris_to_timeline``).
    * ``use_pkl``: loads the ML classifier (``load_pretrained_classifier``).

    Pose inference (``use_pose``) is not started here; per the original note
    it is handled by ``start_analysis()``.
    """
    print("--- Initializing Session Components ---")

    # Component B: BORIS annotation track
    if self.config.get('use_boris', False):
        print("[INFO] Loading BORIS annotation track...")
        self.load_boris_to_timeline()

    # Component C: ML prediction track (.pkl)
    if self.config.get('use_pkl', False):
        # .get(): load_pretrained_classifier tolerates a missing path, so the
        # log line must not raise KeyError before it gets the chance.
        pkl_path = self.config.get('pkl_path') or ''
        print(f"[INFO] Loading ML Model: {os.path.basename(pkl_path)}")
        self.load_pretrained_classifier()
|
|
|
|
def load_pretrained_classifier(self):
    """Load the configured model, its feature map and (if present) a scaler.

    Does nothing unless ``config['use_pkl']`` and ``config['pkl_path']`` are set.
    Reads the sidecar ``<model>_metadata.json``; its ``feature_keys`` list
    becomes ``self.active_features`` (``self.ml_metadata`` keeps the whole
    document). The model is loaded into ``self.ml_model``. A scaler is looked
    for as ``<model>_scaler.pkl`` and then ``scaler.pkl`` in the model's
    directory, and stored in ``self.ml_scaler`` (``None`` if neither exists).

    Raises:
        FileNotFoundError: if the metadata sidecar does not exist.

    Errors while loading the model or scaler are printed and reported via
    ``update_status`` rather than raised.
    """
    if not self.config.get('use_pkl') or not self.config.get('pkl_path'):
        return

    model_path = self.config['pkl_path']
    # Strip only the real extension. str.replace(".pkl", ...) would also
    # rewrite ".pkl" inside directory names, and for any other extension it
    # left the path unchanged, so the model file itself was parsed as JSON.
    base_name = os.path.splitext(model_path)[0]
    metadata_path = f"{base_name}_metadata.json"

    if not os.path.exists(metadata_path):
        raise FileNotFoundError(
            f"Model metadata not found: {metadata_path} "
            f"(needed for the feature map of {os.path.basename(model_path)})"
        )
    with open(metadata_path, 'r', encoding='utf-8') as f:
        self.ml_metadata = json.load(f)
    self.active_features = self.ml_metadata.get("feature_keys", [])
    print(f"[INFO] Feature map loaded: {len(self.active_features)} features.")

    try:
        # 1. Load the primary model.
        # SECURITY: joblib/pickle can execute arbitrary code while loading;
        # only open model files from trusted sources.
        self.ml_model = joblib.load(model_path)
        msg = f"[INFO] ML Model loaded: {os.path.basename(model_path)}"
        print(msg)
        self.update_status(f"<font color='#00aaff'>{msg}</font>")

        # 2. Auto-discover the scaler next to the model.
        possible_scaler_paths = [
            f"{base_name}_scaler.pkl",
            os.path.join(os.path.dirname(model_path), "scaler.pkl"),
        ]
        self.ml_scaler = None
        for spath in possible_scaler_paths:
            if os.path.exists(spath):
                self.ml_scaler = joblib.load(spath)
                s_msg = f"[INFO] Associated scaler auto-loaded: {os.path.basename(spath)}"
                print(s_msg)
                self.update_status(f"<font color='#00aaff'>{s_msg}</font>")
                break

        if self.ml_scaler is None:
            print("[WARNING] No associated scaler found. Proceeding without scaling.")

    except Exception as e:
        err = f"[ERROR] Failed to load ML Model or Scaler: {e}"
        print(err)
        self.update_status(f"<font color='#ff4444'>{err}</font>")
|
|
|
|
|
|
def run_ml_inference(self, raw_kpts):
    """Classify every frame and group the predictions into timeline blocks.

    Args:
        raw_kpts: per-frame keypoint data (indexable, one entry per frame),
            each entry passed to ``GeneralPredictor.format_features``.

    Returns:
        dict: track name -> list of ``[start, end, label, notes]`` blocks, as
        built by ``_convert_predictions_to_tracks``; ``{}`` when no model or
        feature map is loaded, or when there are no frames.
    """
    model = getattr(self, 'ml_model', None)
    features = getattr(self, 'active_features', None)
    # With zero frames np.array([]) is 1-D, which scalers/models reject.
    if model is None or not features or len(raw_kpts) == 0:
        return {}

    # 1. Feature extractor restricted to the features the model was trained on.
    engine = GeneralPredictor()
    engine.active_feature_keys = features

    # 2. One feature vector per frame -> shape (frames, len(features)).
    X = np.array([engine.format_features(frame_kpts) for frame_kpts in raw_kpts])

    # 3. Scale (if a scaler was found) and predict.
    scaler = getattr(self, 'ml_scaler', None)
    if scaler is not None:
        X = scaler.transform(X)

    preds = model.predict(X)
    unique, counts = np.unique(preds, return_counts=True)
    print(f"DEBUG ML Results: {dict(zip(unique, counts))}")
    return self._convert_predictions_to_tracks(preds)
|
|
|
|
|
|
|
|
def _convert_predictions_to_tracks(self, predictions):
    """Collapse per-frame class labels into contiguous timeline blocks.

    Each maximal run of one non-background label becomes
    ``[first_frame, end_frame_exclusive, "Normal", "ML Prediction"]`` under
    the track ``"🤖 AI: <label>"``. Background labels (0, "0", "Idle",
    "None", None, "Background") produce no blocks.
    """
    idle_labels = (0, "0", "Idle", "None", None, "Background")
    tracks = {}

    def flush(label, first, stop):
        # Record the run [first, stop) unless it is background.
        if label in idle_labels:
            return
        tracks.setdefault(f"🤖 AI: {label}", []).append(
            [first, stop, "Normal", "ML Prediction"]
        )

    run_label, run_start = None, 0
    for frame, label in enumerate(predictions):
        if label != run_label:
            flush(run_label, run_start, frame)
            run_label, run_start = label, frame

    # The last run is still open when the predictions end.
    flush(run_label, run_start, len(predictions))
    return tracks
|
|
|
|
|
|
def load_boris_annotations(self):
    """Read the BORIS events recorded for the configured session.

    Opens ``config['obs_file']`` (JSON) and takes
    ``observations[config['session_key']]['events']``.

    Returns:
        list: the session's events (``[]`` if the file or session cannot be
        read). Previously the events were loaded and then discarded.

    Note: events are not yet filtered by ``config['slot']``.
    """
    try:
        with open(self.config['obs_file'], 'r', encoding='utf-8') as f:
            data = json.load(f)

        session = data.get("observations", {}).get(self.config['session_key'], {})
        events = session.get("events", [])
        # TODO: filter events where the slot matches config['slot'] once the
        # slot field of a BORIS event is confirmed.
        print(f"Loaded {len(events)} events from BORIS.")
        return events
    except Exception as e:
        # Best-effort loader: report and carry on with no annotations.
        print(f"Failed to load BORIS data: {e}")
        return []
|
|
|
|
|
|
def update_status(self, message):
    """Append ``message`` (plain text or HTML) to the inspector text panel."""
    panel = self.info_label
    panel.append(message)
|
|
|
|
|
|
def toggle_playback(self):
    """Pause if the player is playing, otherwise start it; relabel the button."""
    if self.player.playbackState() != QMediaPlayer.PlayingState:
        self.player.play()
        self.btn_play.setText("Pause")
        return

    self.player.pause()
    self.btn_play.setText("Play")
|
|
|
|
|
|
def update_timeline_playhead(self, position_ms):
    """Sync timeline, counters and skeleton overlay to a player position.

    Args:
        position_ms: media position in milliseconds.

    On the last frame or beyond, playback is paused and the player is pulled
    back to the middle of the last frame, so the final image stays visible
    instead of a black frame. This clamp only applies when
    ``config['total_frames']`` is known (> 0). With the default of 0 the old
    check matched every position and seeked to a negative timestamp.
    """
    # Fall back to 30 FPS when the rate is missing or zero.
    fps = self.config.get('fps', 30.0) or 30.0
    total_f = self.config.get('total_frames', 0)

    current_f = int((position_ms / 1000.0) * fps)

    # --- PREVENT BLACK FRAME AT END ---
    if total_f > 0 and current_f >= total_f - 1:
        if self.player.playbackState() == QMediaPlayer.PlayingState:
            self.player.pause()
            self.btn_play.setText("Play")
        current_f = total_f - 1
        # Middle of the last frame: start-of-frame timestamps can truncate
        # back into the previous frame.
        last_valid_ms = int(round((total_f - 0.5) * 1000.0 / fps))
        # Only pull back when past that point; re-seeking to where we already
        # are could re-trigger this handler (presumably wired to the player's
        # position updates).
        if position_ms > last_valid_ms:
            self.player.setPosition(last_valid_ms)

    # After clamping, so the overlay never gets a frame beyond the clip.
    if hasattr(self, 'skeleton_overlay') and self.skeleton_overlay.isVisible():
        self.skeleton_overlay.set_frame(current_f)

    # Sync UI
    self.timeline.set_playhead(current_f)
    self.update_counters(current_f)
|
|
|
|
def seek_video(self, frame):
    """Seek the player to ``frame`` and update the timeline and counters.

    The frame is clamped to ``>= 0``, and to ``total_frames - 1`` when the
    frame count is known (> 0). Previously an unknown count (0) forced every
    seek to frame 0. The player is sent to the middle of the frame, because
    the playhead maps positions back with ``int(ms / 1000 * fps)``. A
    start-of-frame timestamp truncated in ms (e.g. frame 1 @ 30 fps -> 33 ms)
    maps back to the previous frame.
    """
    # Fall back to 30 FPS when the rate is missing or zero.
    fps = self.config.get('fps', 30.0) or 30.0
    total_f = self.config.get('total_frames', 0)

    target_frame = max(0, frame)
    if total_f > 0:
        target_frame = min(target_frame, total_f - 1)

    # Frame -> milliseconds for QMediaPlayer (middle of the frame).
    ms = int(round((target_frame + 0.5) * 1000.0 / fps))
    self.player.setPosition(ms)

    # Sync the UI
    self.timeline.set_playhead(target_frame)
    self.update_counters(target_frame)
|
|
|
|
|
|
def update_counters(self, current_f):
    """Refresh the "Time: mm:ss / mm:ss" and "Frame: n / last" labels."""
    fps = self.config.get('fps', 30.0)
    total_f = self.config.get('total_frames', 0)

    cur_min, cur_sec = divmod(int(current_f / fps), 60)
    tot_min, tot_sec = divmod(int(total_f / fps), 60)

    time_text = f"Time: {cur_min:02d}:{cur_sec:02d} / {tot_min:02d}:{tot_sec:02d}"
    frame_text = f"Frame: {current_f} / {total_f-1}"
    self.lbl_time_counter.setText(time_text)
    self.lbl_frame_counter.setText(frame_text)
|
|
|
|
|
|
def step_frame(self, delta):
    """Seek ``delta`` whole frames forward (positive) or backward (negative).

    The step is done in frame units and the player is sent to the middle of
    the destination frame. The playhead derives frames with
    ``int(ms / 1000 * fps)``. The old ``int(current_ms + delta * 1000 / fps)``
    truncated onto frame boundaries; e.g. one step forward from 0 ms at
    30 fps landed on 33 ms, which still maps to frame 0.
    """
    # Fall back to 30 FPS when the rate is missing or zero.
    fps = self.config.get('fps', 30.0) or 30.0

    current_frame = int((self.player.position() / 1000.0) * fps)
    target_frame = max(0, current_frame + delta)
    target_ms = int(round((target_frame + 0.5) * 1000.0 / fps))

    # Never seek past the end of the media (when its duration is known).
    duration = self.player.duration()
    if duration > 0:
        target_ms = min(target_ms, duration)
    self.player.setPosition(target_ms)
|
|
|
|
|
|
def update_volume(self, value):
    """Apply a 0-100 slider value to the audio output.

    Raising the slider above 0 while muted also unmutes: the mute button is
    unchecked and ``toggle_mute`` re-syncs the audio state.
    """
    # QAudioOutput takes a float volume in [0.0, 1.0].
    self.audio_output.setVolume(value / 100.0)

    if not self.btn_mute.isChecked() or value <= 0:
        return
    self.btn_mute.setChecked(False)
    self.toggle_mute()
|
|
|
|
|
|
def toggle_mute(self):
    """Mirror the mute button's checked state onto the audio output and UI."""
    muted = self.btn_mute.isChecked()
    self.audio_output.setMuted(muted)

    if muted:
        caption = "Mute"
    else:
        caption = "Vol"
    self.btn_mute.setText(caption)

    # The volume slider is disabled (dimmed) while muted.
    self.sld_volume.setEnabled(not muted)
|
|
|
|
|
|
def update_video_geometry(self):
    """Letterbox the video item inside the view and align the skeleton overlay.

    The frame size comes from ``self.data['width'/'height']`` when that is set,
    otherwise from the video item's native size. Nothing happens until both
    the viewport and the video report a positive size.
    """
    if not hasattr(self, "video_item"):
        return

    port = self.view.viewport().rect()
    port_w, port_h = port.width(), port.height()
    if port_w <= 0 or port_h <= 0:
        return

    if hasattr(self, "data") and self.data:
        src_w, src_h = self.data['width'], self.data['height']
    else:
        native = self.video_item.nativeSize()
        src_w, src_h = native.width(), native.height()
    # The native size is not positive until the media metadata has loaded.
    if src_w <= 0 or src_h <= 0:
        return

    ratio = src_w / src_h
    # Viewport wider than the video -> fit to height (bars left/right).
    fit_to_height = port_w / port_h > ratio
    fit_w = int(port_h * ratio) if fit_to_height else port_w
    fit_h = port_h if fit_to_height else int(port_w / ratio)
    left = (port_w - fit_w) / 2
    top = (port_h - fit_h) / 2

    self.scene.setSceneRect(0, 0, port_w, port_h)
    self.video_item.setPos(left, top)
    self.video_item.setSize(QSizeF(fit_w, fit_h))

    if hasattr(self, "skeleton_overlay"):
        self.skeleton_overlay.setGeometry(int(left), int(top), fit_w, fit_h)
|
|
|
|
|
|
def resizeEvent(self, event):
    """After the base resize, refit the video/overlay and the timeline."""
    super().resizeEvent(event)
    self.update_video_geometry()
    if not hasattr(self, 'timeline'):
        return
    self.timeline.update_geometry()
|
|
|
|
def eventFilter(self, source, event):
    """Resize the skeleton overlay to match the preview label on its resize events.

    All events are then forwarded to the base implementation.
    """
    preview = getattr(self, 'video_preview_label', None)
    preview_resized = source == preview and event.type() == QEvent.Resize
    if preview_resized:
        self.skeleton_overlay.resize(event.size())
    return super().eventFilter(source, event)
|
|
|
|
|
|
def cleanup(self):
    """Stop playback and terminate the background worker if it is still running.

    ``QThread.terminate()`` does not block until the thread stops; ``wait()``
    is called afterwards so the thread is really finished before teardown
    continues.
    """
    player = getattr(self, 'player', None)
    if player is not None:
        player.stop()

    worker = getattr(self, 'worker', None)
    if worker is not None and worker.isRunning():
        worker.terminate()
        worker.wait()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
------------------------------ It all breaks here ------------------------------
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from PySide6.QtWidgets import (QWidget, QVBoxLayout, QSplitter, QListWidget,
|
|
QGraphicsView, QGraphicsScene, QListWidgetItem)
|
|
from PySide6.QtCore import Qt, QMimeData
|
|
from PySide6.QtGui import QDrag
|
|
|
|
from PySide6.QtGui import QPainterPath, QColor, QPen, QBrush
|
|
from PySide6.QtWidgets import QGraphicsPathItem, QGraphicsSimpleTextItem
|
|
|
|
from PySide6.QtWidgets import QGraphicsItem, QGraphicsSimpleTextItem, QInputDialog
|
|
|
|
|
|
|
|
class PuzzleBlock(QGraphicsPathItem):
    """A draggable jigsaw-style block placed on the topology canvas.

    ``b_type`` picks the outline and colour:

    * "begin": flat top, socket on the bottom;
    * "middle": nub on top, socket on the bottom;
    * "end": nub on top, flat bottom;
    * "logic": compact 130x20 chip.

    Vertical stacking is tracked through ``parent_block`` / ``child_block``.
    ``fields`` is an optional list of dicts (from the block library JSON).
    Entries with ``type`` "number"/"string" become labelled line edits, and
    "slot" entries become dashed drop-target labels. The widgets are stored
    in ``self.inputs`` keyed by the field's ``label``.
    """

    _FILL_COLOURS = {"begin": "#2e7d32", "middle": "#1565c0", "end": "#c62828", "logic": "#ef6c00"}

    _SLOT_STYLE = """
        color: #FFA500;
        font-size: 10px;
        border: 1px dashed #FFA500;
        padding: 2px;
        border-radius: 2px;
        min-height: 20px; /* CRITICAL: Gives the mouse something to hit */
        background: rgba(255, 165, 0, 0.05); /* Slight tint helps visual hit-testing */
    """

    def __init__(self, b_type, label, parent_item=None, fields=None):
        super().__init__(parent_item)

        self.b_type = b_type  # "begin", "middle", "end" or "logic"
        self.label_text = label
        self.width = 160
        self.height = 50
        self.tab_size = 15

        # Stacking links, maintained by TopologyCanvas.
        self.parent_block = None
        self.child_block = None

        flag = QGraphicsItem.GraphicsItemFlag
        self.setFlags(flag.ItemIsMovable | flag.ItemIsSelectable | flag.ItemSendsGeometryChanges)

        # Tag the canvas uses to recognise blocks among scene items.
        self.setData(0, "p_block")

        self.inputs = {}
        self.original_inputs = {}

        print(f"[DEBUG] - {fields}")

        if fields:
            self._build_field_widgets(fields)

        self.update_path()
        self.setBrush(QBrush(QColor(self._FILL_COLOURS.get(b_type, "#555"))))
        self.setPen(QPen(QColor("#ffffff"), 1))

        self.label_item = QGraphicsSimpleTextItem(self.label_text, self)
        self.label_item.setBrush(QColor("white"))
        title_x, title_y = (2, 1) if self.b_type == "logic" else (25, 15)
        self.label_item.setPos(title_x, title_y)

    def _build_field_widgets(self, fields):
        """Embed a form (one row per field) below the title and grow the block to fit."""
        container = QWidget()
        self.container = container
        container.setObjectName("blockContainer")
        container.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
        if self.b_type == "logic":
            container.setFixedSize(130, 20)
            container.setStyleSheet("background: transparent; border: none;")
        else:
            container.setFixedWidth(self.width)
            container.setStyleSheet("background: transparent; color: white;")

        # Form layout: caption on the left, editor on the right.
        form = QFormLayout(container)
        form.setContentsMargins(10, 5, 10, 5)
        form.setSpacing(5)
        form.setLabelAlignment(Qt.AlignmentFlag.AlignLeft)
        form.setFieldGrowthPolicy(QFormLayout.FieldGrowthPolicy.AllNonFixedFieldsGrow)

        for spec in fields:
            kind = spec['type']
            if kind in ("number", "string"):
                caption = QLabel(spec.get('label', 'Field:'))
                caption.setStyleSheet("font-size: 9px; font-weight: bold;")

                editor = QLineEdit()
                editor.setFixedHeight(18)  # keep the row slim
                editor.setStyleSheet("background: rgba(255, 255, 255, 0.2); border: 1px solid #aaa; color: white;")
                if kind == "number":
                    editor.setValidator(QDoubleValidator())

                form.addRow(caption, editor)
                self.inputs[spec.get('label')] = editor
            elif kind == "slot":
                drop_target = QLabel("")
                drop_target.setStyleSheet(self._SLOT_STYLE)
                # Expand horizontally so the slot is not a tiny dot in a corner.
                drop_target.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
                form.addRow(drop_target)  # full-width row

                # Registered under its key so snapping code can find it.
                key = spec.get('label', 'slot')
                self.inputs[key] = drop_target
                self.original_inputs[key] = drop_target

        self.proxy = QGraphicsProxyWidget(self)
        self.proxy.setZValue(10)
        self.proxy.setWidget(container)
        self.width = 160
        self.proxy.setPos(0, 35)

        # Title area (50px) plus whatever the form asks for.
        self.height = 50 + form.sizeHint().height()

    def update_path(self):
        """Rebuild the outline: nub on top for middle/end, socket on the bottom for begin/middle."""
        self.prepareGeometryChange()

        if self.b_type == "logic":
            w, h = 130, 20  # slimmer, shorter chip
        else:
            w, h = self.width, self.height
        tab = self.tab_size

        outline = QPainterPath()
        outline.moveTo(0, 0)

        # Top edge.
        if self.b_type in ("middle", "end"):
            outline.lineTo(30, 0)
            # Male nub: a negative sweep bulges above the edge.
            outline.arcTo(30, -tab / 2, tab, tab, 180, -180)
        outline.lineTo(w, 0)

        # Right edge.
        outline.lineTo(w, h)

        # Bottom edge.
        if self.b_type in ("begin", "middle"):
            outline.lineTo(30 + tab, h)
            # Female socket: a positive sweep cuts into the block.
            outline.arcTo(30, h - tab / 2, tab, tab, 0, 180)
        outline.lineTo(0, h)

        # Left edge.
        outline.lineTo(0, 0)
        outline.closeSubpath()
        self.setPath(outline)

    def mouseDoubleClickEvent(self, event):
        """Prompt for a new title when a "begin" block is double-clicked."""
        if self.b_type == "begin":
            new_name, accepted = QInputDialog.getText(None, "Rename Step", "Enter name:", text=self.label_text)
            if accepted and new_name:
                self.label_text = new_name
                self.label_item.setText(new_name)
        super().mouseDoubleClickEvent(event)

    def mousePressEvent(self, event):
        """Raise the block above others while it is being handled (links untouched)."""
        self.setZValue(100)
        super().mousePressEvent(event)

    def mouseReleaseEvent(self, event):
        """Return the block to the standard stacking level."""
        self.setZValue(0)
        super().mouseReleaseEvent(event)

    def itemChange(self, change, value):
        """Break stacking links when the user drags this block away.

        Only reacts to position changes while this block is the scene's mouse
        grabber. The parent link is always cut; the child link is cut only
        if the child is not also selected (i.e. not moving along).
        """
        if change != QGraphicsItem.GraphicsItemChange.ItemPositionHasChanged:
            return super().itemChange(change, value)

        scene = self.scene()
        if scene and scene.mouseGrabberItem() is self:
            above = self.parent_block
            if above:
                print(f"DEBUG: Tearing away from parent: {above.label_text}")
                above.child_block = None
                self.parent_block = None

            below = self.child_block
            if below and not below.isSelected():
                print(f"DEBUG: Tearing away from stationary child: {below.label_text}")
                below.parent_block = None
                self.child_block = None

        return super().itemChange(change, value)
|
|
|
|
|
|
|
|
class BlockLibrary(QListWidget):
    """Sidebar palette of block templates that can be dragged onto the canvas.

    Templates come from the ``stringy`` section of ``blocks.json``. Each entry
    needs ``name`` and ``type`` and may carry a ``fields`` list. Drags carry
    ``"type|label|fields_json"`` as plain-text MIME data.
    """

    def __init__(self, parent_tab, stringy):
        super().__init__()
        self.setDragEnabled(True)
        self.mode = stringy
        self.parent_tab = parent_tab  # reference to the owning tab
        self.load_blocks_from_json("blocks.json")

    def load_blocks_from_json(self, file_path):
        """Populate the list from the ``self.mode`` section of a JSON file.

        A relative path is first tried against the current working directory
        (the previous behaviour). If not found there, it is tried against the
        application/bundle directory via ``resource_path``, so the library
        also loads when the app is started elsewhere or run frozen. Problems
        are printed, not raised.
        """
        if not os.path.isabs(file_path) and not os.path.exists(file_path):
            bundled = resource_path(file_path)
            if os.path.exists(bundled):
                file_path = bundled

        if not os.path.exists(file_path):
            print(f"Error: {file_path} not found.")
            return

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # List registered for the current mode (e.g. "individual").
            for block in data.get(self.mode, []):
                self.add_item(block['name'], block['type'], block.get('fields', []))
        except Exception as e:
            print(f"Failed to load block library: {e}")

    def add_item(self, label, block_type, fields):
        """Add one template; type in UserRole, field specs in UserRole + 1."""
        item = QListWidgetItem(label)
        item.setData(Qt.UserRole, block_type)
        item.setData(Qt.UserRole + 1, fields)
        self.addItem(item)

    def startDrag(self, supportedActions):
        """Start a drag carrying ``"type|label|fields_json"`` as text."""
        item = self.currentItem()
        if not item:
            return

        b_type = item.data(Qt.UserRole)
        label = item.text()
        fields_str = json.dumps(item.data(Qt.UserRole + 1))

        mime_data = QMimeData()
        mime_data.setText(f"{b_type}|{label}|{fields_str}")

        drag = QDrag(self)
        drag.setMimeData(mime_data)
        result = drag.exec(supportedActions)
        print(f"DEBUG: Drag finished with result: {result}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from PySide6.QtWidgets import QGraphicsLineItem, QGraphicsRectItem, QGraphicsSimpleTextItem
|
|
from PySide6.QtGui import QPen, QColor, QBrush
|
|
|
|
|
|
class TopologyCanvas(QGraphicsView):
    """View on which PuzzleBlocks are dropped, dragged and snapped into stacks.

    Blocks form vertical stacks via their ``parent_block`` / ``child_block``
    links. If the owner sets ``library_ref`` to a widget, releasing dragged
    blocks over that widget removes (trashes) the selected blocks.
    """

    def __init__(self):
        self.scene = QGraphicsScene()
        super().__init__(self.scene)
        self.setAcceptDrops(True)
        self.scene.setSceneRect(0, 0, 2000, 2000)

        # Created up front (previously only inside mousePressEvent) so a
        # release without a preceding press cannot raise AttributeError.
        self.chain = []
        self._processing_release = False

    # ------------------------------------------------------------------ mouse

    def mousePressEvent(self, event):
        """Rubber-band select on empty space; plain item interaction otherwise."""
        item = self.itemAt(event.position().toPoint())
        mode = QGraphicsView.DragMode.NoDrag if item else QGraphicsView.DragMode.RubberBandDrag
        self.setDragMode(mode)
        super().mousePressEvent(event)

    def mouseReleaseEvent(self, event):
        """Finish the native release, then process trash/snap on the next loop turn."""
        super().mouseReleaseEvent(event)

        if self._processing_release:
            return

        global_pos = event.globalPosition().toPoint()
        scene_pos = self.mapToScene(event.position().toPoint())

        # Deferred to the next event-loop cycle (avoids the native crash seen
        # when mutating the scene inside the release handler).
        QTimer.singleShot(0, lambda: self.safe_process_release(global_pos, scene_pos))

    def safe_process_release(self, global_pos, scene_pos):
        """Trash selected blocks if released over the library, else try to snap."""
        self._processing_release = True
        try:
            library = getattr(self, 'library_ref', None)
            if library is not None and library.rect().contains(library.mapFromGlobal(global_pos)):
                print("Action: Trashing items")
                for item in self.scene.selectedItems():
                    if isinstance(item, PuzzleBlock):
                        # Neighbours must not keep pointing at a removed block.
                        self._detach_block(item)
                    self.scene.removeItem(item)
                return

            items_at_pos = self.scene.items(scene_pos)
            driver = next((i for i in items_at_pos if isinstance(i, PuzzleBlock) and i.isSelected()), None)
            target = next((i for i in items_at_pos if isinstance(i, PuzzleBlock) and not i.isSelected()), None)

            if driver and target:
                self.perform_snap(driver, scene_pos)
        finally:
            self._processing_release = False

    # ------------------------------------------------------------ drag & drop

    def dragEnterEvent(self, event):
        if event.mimeData().hasText():
            event.acceptProposedAction()

    def dragMoveEvent(self, event):
        if event.mimeData().hasText():
            event.acceptProposedAction()

    def dropEvent(self, event):
        """Create a block from library drag data ``"type|label|fields_json"``."""
        raw_data = event.mimeData().text()
        # maxsplit=2: the fields JSON may itself contain '|'.
        parts = raw_data.split("|", 2)
        if len(parts) < 2:
            return  # not library drag data

        block_type, label = parts[0], parts[1]
        fields = []
        if len(parts) > 2:
            try:
                fields = json.loads(parts[2])
            except ValueError as e:
                print(f"DEBUG: Failed to parse: {e}")

        if block_type == "value":
            # ValueBlock is not implemented yet.
            print("Spawned ValueBlock (Placeholder)")
            return

        new_block = PuzzleBlock(block_type, label, fields=fields)
        self.scene.addItem(new_block)

        # Snap onto a stack under the cursor, otherwise place at the drop point.
        drop_pos = self.mapToScene(event.position().toPoint())
        if not self.perform_snap(new_block, drop_pos):
            new_block.setPos(drop_pos)

        event.acceptProposedAction()

    # --------------------------------------------------------- stack helpers

    def get_stack_extremity(self, block, direction):
        """Follow parent ("up") or child (otherwise) links to the end of the stack."""
        curr = block
        visited = {curr}
        for _ in range(50):
            next_node = curr.parent_block if direction == "up" else curr.child_block
            if next_node and next_node not in visited:
                curr = next_node
                visited.add(curr)
            else:
                break
        return curr

    @staticmethod
    def _owning_block(item):
        """Return the block ``item`` is, or directly belongs to; ``None`` otherwise."""
        if item.data(0) == "p_block":
            return item
        owner = item.parentItem()
        if owner and owner.data(0) == "p_block":
            return owner
        return None

    @staticmethod
    def _stack_members(root):
        """Return every block reachable from ``root`` through child links (cycle-safe)."""
        members = set()
        curr = root
        while curr is not None and curr not in members:
            members.add(curr)
            curr = curr.child_block
        return members

    @staticmethod
    def _detach_block(block):
        """Clear both sides of ``block``'s stacking links."""
        above, below = block.parent_block, block.child_block
        if above is not None and above.child_block is block:
            above.child_block = None
        if below is not None and below.parent_block is block:
            below.parent_block = None
        block.parent_block = None
        block.child_block = None

    # ------------------------------------------------------------ diagnostics

    def keyPressEvent(self, event):
        if event.key() == Qt.Key.Key_F7:
            self.diagnostic_dump()
        super().keyPressEvent(event)

    def diagnostic_dump(self):
        """Print every block's links and flag one-sided (asymmetric) links."""
        print(f"\n{'-'*20} LOGICAL STATE DUMP {'-'*20}")
        blocks = [item for item in self.scene.items() if hasattr(item, 'data') and item.data(0) == "p_block"]

        if not blocks:
            print("No blocks found in scene.")
            return

        for i, block in enumerate(blocks):
            parent_name = block.parent_block.label_text if block.parent_block else "None"
            child_name = block.child_block.label_text if block.child_block else "None"

            # Report both asymmetries when both are present.
            warnings = []
            if block.parent_block and block.parent_block.child_block is not block:
                warnings.append(" [!] ASYMMETRY: Parent doesn't recognize this child")
            if block.child_block and block.child_block.parent_block is not block:
                warnings.append(" [!] ASYMMETRY: Child doesn't recognize this parent")
            warning = "".join(warnings)

            print(f"[{i}] BLOCK: {block.label_text}")
            print(f"  - Parent: {parent_name}")
            print(f"  - Child: {child_name}{warning}")
            print(f"  - Selected: {block.isSelected()}")
        print(f"{'-'*60}\n")

    # ---------------------------------------------------------------- snapping

    def perform_snap(self, dragged_item, mouse_pos):
        """Attach the stack containing ``dragged_item`` to another stack under ``mouse_pos``.

        The dragged stack is placed above the target stack's head when
        possible (its tail must not be an "end" block; the head must not be a
        "begin" block). Otherwise it is placed below the target's tail (its
        root must not be a "begin" block; the tail must not be an "end"
        block). Returns True when a connection was made.
        """
        print(f"\n{'='*60}\n[DEBUG] STARTING SNAP: {dragged_item.label_text}")
        print(f"[DEBUG] Mouse Release Pos: {mouse_pos}")

        # 1. The whole stack being moved (the dragged block may carry
        #    selected children with it).
        dragged_root = self.get_stack_extremity(dragged_item, "up")
        dragged_tail = self.get_stack_extremity(dragged_item, "down")
        dragged_family = self._stack_members(dragged_root)

        # 2. First block under the cursor that is not part of that stack.
        target_block = None
        for item in self.scene.items(mouse_pos):
            candidate = self._owning_block(item)
            if candidate is not None and candidate not in dragged_family:
                target_block = candidate
                break

        if not target_block:
            return False

        # 3. Guard against linking a stack to itself.
        head = self.get_stack_extremity(target_block, "up")
        if dragged_root == head:
            print(f"DEBUG: Still shares root {dragged_root.label_text}")
            return False

        print(f"[DEBUG] Processing stack-to-stack connection: {dragged_item.label_text} -> {target_block.label_text}")
        tail = self.get_stack_extremity(target_block, "down")
        print(f"[DEBUG] Stack Extremities -> Head: {head.label_text} | Tail: {tail.label_text}")

        # Link through the dragged stack's own ends; linking dragged_item
        # itself would overwrite an existing child link and leave the old
        # child pointing at it.
        if dragged_tail.b_type != "end" and head.parent_block is None:
            if head.b_type != "begin":
                print(f"[DEBUG] Snap Trigger: {dragged_tail.label_text} ABOVE {head.label_text}")
                return self.connect_blocks(dragged_tail, head, mode="upward")
            print("[DEBUG] SNAP REJECTED: Cannot snap above a 'Begin' block.")

        if dragged_root.b_type != "begin" and tail.child_block is None:
            if tail.b_type != "end":
                print(f"[DEBUG] Snap Trigger: {dragged_root.label_text} BELOW {tail.label_text}")
                return self.connect_blocks(tail, dragged_root, mode="downward")
            print("[DEBUG] SNAP REJECTED: Cannot snap below a 'Finish' block.")

        print(f"[DEBUG] SNAP FAILED: {dragged_item.label_text} found no valid connection point on {target_block.label_text}")
        return False

    def connect_blocks(self, parent, child, mode="downward"):
        """Link ``parent`` -> ``child`` and align them physically.

        "upward": ``child``'s stack stays put; ``parent`` and everything
        stacked above it are moved to sit on top of ``child``.
        Otherwise ``parent`` stays put and ``child`` (and its descendants)
        are moved below it. Returns True on success.
        """
        # Checked before any attribute access on the arguments.
        if parent is None or child is None:
            print("DEBUG: Connection Aborted -> Missing reference to parent or child.")
            return False
        print(f"DEBUG: Entering connect_blocks | Parent: {parent.label_text} | Child: {child.label_text} | Mode: {mode}")

        # No QPointF.isNull() check: (0, 0) is a valid scene position, and
        # rejecting it made blocks at the scene origin impossible to snap to.
        if mode == "upward":
            new_parent_pos = child.scenePos() - QPointF(0, parent.height)
            parent.setPos(new_parent_pos)
            print(f"DEBUG: Physical Move -> {parent.label_text} set to {new_parent_pos} (Above {child.label_text})")
            self._restack_upward(parent)
        else:
            new_child_pos = parent.scenePos() + QPointF(0, parent.height)
            child.setPos(new_child_pos)
            print(f"DEBUG: Physical Move -> {child.label_text} set to {new_child_pos} (Below {parent.label_text})")

        parent.child_block = child
        child.parent_block = parent
        print(f"DEBUG: Logical Link Established -> {parent.label_text}.child = {child.label_text}")

        self.ripple_move(child)
        print(f"SUCCESS: {parent.label_text} and {child.label_text} are now connected.")
        return True

    @staticmethod
    def _restack_upward(block):
        """Re-seat every ancestor of ``block`` directly above its own child."""
        curr = block
        visited = {curr}
        while curr.parent_block is not None and curr.parent_block not in visited:
            above = curr.parent_block
            above.setPos(curr.scenePos() - QPointF(0, above.height))
            visited.add(above)
            curr = above

    def ripple_move(self, start_block):
        """Re-seat every descendant of ``start_block`` directly below its parent."""
        curr = start_block
        visited = set()  # prevents infinite loops on circular links
        print(f"DEBUG: Starting Ripple Move from {start_block.label_text}")

        while curr and curr.child_block:
            if curr in visited:
                print("CRITICAL: Circular link detected in ripple! Breaking.")
                break
            visited.add(curr)

            next_block = curr.child_block
            expected_pos = curr.scenePos() + QPointF(0, curr.height)
            if next_block.scenePos() != expected_pos:
                print(f"DEBUG: Ripple Sync -> Moving {next_block.label_text} to {expected_pos}")
                next_block.setPos(expected_pos)

            curr = next_block
        print("DEBUG: Ripple Move Complete.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelParameterConfigurationTab(QWidget):
    """Tab combining the block canvas with a block library and a properties panel.

    ``stringy`` selects which section of ``blocks.json`` the library shows.
    """

    _STYLE_SHEET = """
        QMainWindow, QWidget#centralWidget {
            background-color: #1e1e1e;
        }
        QLabel, QStatusBar, QMenuBar {
            color: #ffffff;
        }
        /* Target the Timeline specifically */
        TimelineWidget {
            background-color: #1e1e1e;
            border: 1px solid #333333;
        }
        /* Dialog styling */
        QDialog, QMessageBox, QFileDialog {
            background-color: #2b2b2b;
        }
        QDialog QLabel, QMessageBox QLabel {
            color: #ffffff;
        }
        /* Button styling with Grey borders */
        QPushButton {
            background-color: #2b2b2b;
            color: #ffffff;
            border: 1px solid #555555; /* Subtle Grey border */
            border-radius: 3px;
            padding: 4px;
        }
        QPushButton:hover {
            background-color: #3d3d3d;
            border-color: #888888; /* Brightens border on hover */
        }
        QPushButton:pressed {
            background-color: #111111;
        }
        QPushButton:disabled {
            border-color: #333333;
            color: #444444;
        }
        /* Splitter/Divider styling */
        QSplitter::handle {
            background-color: #333333; /* Dark grey dividers */
        }
        QSplitter::handle:horizontal {
            width: 2px;
        }
        QSplitter::handle:vertical {
            height: 2px;
        }
        /* ScrollArea styling to keep it dark */
        QScrollArea, QScrollArea > QWidget > QWidget {
            background-color: #1e1e1e;
            border: none;
        }
    """

    def __init__(self, stringy):
        super().__init__()
        self.mode = stringy
        self.setStyleSheet(self._STYLE_SHEET)
        self.setup_ui()

    def setup_ui(self):
        """Build the layout: canvas on the left, library + properties on the right."""
        main_layout = QVBoxLayout(self)
        main_layout.setContentsMargins(0, 0, 0, 0)

        # Main horizontal split: [ Canvas | Library/Inspector ]
        self.horizontal_splitter = QSplitter(Qt.Horizontal)

        # --- LEFT: Topology Canvas ---
        self.canvas = TopologyCanvas()

        # --- RIGHT: Sidebar (Library + Inspector) ---
        sidebar_container = QWidget()
        sidebar_layout = QVBoxLayout(sidebar_container)

        self.block_library = BlockLibrary(parent_tab=self, stringy=self.mode)
        # The canvas trashes blocks released over its `library_ref` widget;
        # without this reference its trash branch could never run.
        self.canvas.library_ref = self.block_library

        self.inspector_scroll = QScrollArea()
        self.inspector_scroll.setWidgetResizable(True)
        self.info_label = QTextEdit()
        self.info_label.setReadOnly(True)
        self.info_label.setStyleSheet("background-color: #2b2b2b; color: #ffffff;")
        self.inspector_scroll.setWidget(self.info_label)

        sidebar_layout.addWidget(QLabel("Block Library"))
        sidebar_layout.addWidget(self.block_library, 2)  # library gets more space
        sidebar_layout.addWidget(QLabel("Properties"))
        sidebar_layout.addWidget(self.inspector_scroll, 1)

        self.horizontal_splitter.addWidget(self.canvas)
        self.horizontal_splitter.addWidget(sidebar_container)
        self.horizontal_splitter.setSizes([800, 300])

        main_layout.addWidget(self.horizontal_splitter)
|
|
|
|
|
|
|
|
|
|
|
|
def resource_path(relative_path):
    """
    Return the absolute location of a bundled resource.

    When running from a PyInstaller bundle (``sys._MEIPASS`` is set) the path
    is resolved against the bundle's extraction directory; otherwise it is
    resolved against the directory containing this source file.

    Args:
        relative_path: Resource path relative to the application root,
            e.g. ``"icons/main.ico"``.

    Returns:
        ``os.path.join(<root>, relative_path)``.
    """
    root_dir = (
        sys._MEIPASS  # PyInstaller extraction directory
        if hasattr(sys, '_MEIPASS')
        else os.path.dirname(os.path.abspath(__file__))
    )
    return os.path.join(root_dir, relative_path)
|
|
|
|
|
|
def kill_child_processes():
    """
    Forcefully kill every descendant process of the current process.

    Called from ``exception_hook`` before the crash dialog is shown. This is
    best effort: failures are printed, never raised. A child that cannot be
    killed (already gone, or access denied) does not stop the remaining
    children from being killed.
    """
    try:
        children = psutil.Process(os.getpid()).children(recursive=True)

        for child in children:
            try:
                child.kill()
            except psutil.NoSuchProcess:
                # Exited between being listed and being killed; nothing to do
                pass
            except psutil.AccessDenied as e:
                # Keep going so one protected child doesn't spare the others
                print(f"Could not kill child process {child.pid}: {e}")

        # Reap the killed processes; report any that are still alive after 5 s
        _, still_alive = psutil.wait_procs(children, timeout=5)
        for proc in still_alive:
            print(f"Child process {proc.pid} is still running after kill")

    except Exception as e:
        print(f"Error killing child processes: {e}")
|
|
|
|
|
|
def exception_hook(exc_type, exc_value, exc_traceback):
    """
    Report an uncaught exception in a popup and terminate the application.

    Intended to be installed as ``sys.excepthook``. The formatted traceback is
    printed to stdout (which the ``__main__`` block redirects to the log file).
    Child processes are then killed and a critical-error dialog is shown. The
    process always exits with status 1, even if showing the dialog fails.

    Args:
        exc_type: Exception class.
        exc_value: Exception instance.
        exc_traceback: Traceback object.
    """
    error_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
    print(error_msg)  # also lands in the log file when stdout is redirected

    kill_child_processes()

    try:
        # A QApplication must exist before any QMessageBox can be created.
        # Keep the reference bound so a newly created instance isn't collected.
        app = QApplication.instance()
        if app is None:
            app = QApplication(sys.argv)

        show_critical_error(error_msg)
    except Exception:
        # The dialog could not be shown; record why, but still exit below
        print("Failed to display the crash dialog:")
        traceback.print_exc()

    # Exit after the user acknowledges the dialog (or after it failed)
    sys.exit(1)
|
|
|
|
def show_critical_error(error_msg):
    """
    Show a modal rich-text dialog describing an unrecoverable error.

    The application log (``{APP_NAME}.log``, the same file the ``__main__``
    block writes) is copied to ``{APP_NAME}_error.log`` so the user can attach
    it to an issue. The dialog links to that copy, to the autosave location
    and to the issue tracker, and includes the full traceback.

    If copying the log fails, the dialog is still shown, linking to the
    original log instead.

    Args:
        error_msg: Formatted traceback text to display.
    """
    import html  # only needed on this crash path

    if PLATFORM_NAME == "darwin":
        # Same location the __main__ block logs to: three levels above the
        # executable (presumably the folder containing the .app bundle)
        base_dir = os.path.join(os.path.dirname(sys.executable), "../../..")
    else:
        base_dir = os.getcwd()

    log_path = os.path.join(base_dir, f"{APP_NAME}.log")
    error_log_path = os.path.join(base_dir, f"{APP_NAME}_error.log")
    # NOTE(review): this autosave filename looks inherited from FLARES and does
    # not use APP_NAME — confirm the path BLAZES actually autosaves to.
    save_path = os.path.join(base_dir, "flares_autosave.flare")

    # Snapshot the live log for the bug report; a missing/locked log must not
    # prevent the dialog from appearing.
    try:
        shutil.copy(log_path, error_log_path)
    except OSError as e:
        print(f"Could not copy log file for the crash report: {e}")
        error_log_path = log_path

    # Normalise "../" segments for display and build properly encoded file URIs
    error_log_file = Path(os.path.abspath(error_log_path))
    autosave_file = Path(os.path.abspath(save_path))
    log_link = error_log_file.as_uri()
    autosave_link = autosave_file.as_uri()
    log_display = html.escape(error_log_file.as_posix())
    autosave_display = html.escape(autosave_file.as_posix())

    message = (
        f"{APP_NAME.upper()} has encountered an unrecoverable error and needs to close.<br><br>"
        f"We are sorry for the inconvenience. An autosave was attempted to be saved to <a href='{autosave_link}'>{autosave_display}</a>, but it may not have been saved. "
        "If the file was saved, it still may not be intact, openable, or contain the correct data. Use the autosave at your discretion.<br><br>"
        f"This unrecoverable error was likely due to an error with {APP_NAME.upper()} and not your data.<br>"
        f"Please raise an issue <a href='https://git.research.dezeeuw.ca/tyler/{APP_NAME}/issues'>here</a> and attach the error file located at <a href='{log_link}'>{log_display}</a><br><br>"
        # Escape the traceback: RichText would otherwise swallow "<module>" etc. as tags
        f"<pre>{html.escape(error_msg)}</pre>"
    )

    msg_box = QMessageBox()
    msg_box.setIcon(QMessageBox.Icon.Critical)
    msg_box.setWindowTitle("Something went wrong!")
    msg_box.setTextFormat(Qt.TextFormat.RichText)
    msg_box.setText(message)
    msg_box.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
    msg_box.setStandardButtons(QMessageBox.StandardButton.Ok)

    msg_box.exec()
|
|
|
|
if __name__ == "__main__":
    # Must come first (per the multiprocessing docs). In a frozen build,
    # multiprocessing children re-run this script, and freeze_support() hands
    # them off before they reach the log setup below. Otherwise a child would
    # delete and reopen the parent's log file.
    freeze_support()

    # Redirect exceptions to the popup window
    sys.excepthook = exception_hook

    # Set up application logging
    if PLATFORM_NAME == "darwin":
        log_path = os.path.join(os.path.dirname(sys.executable), f"../../../{APP_NAME}.log")
    else:
        log_path = os.path.join(os.getcwd(), f"{APP_NAME}.log")

    # Start each run with a fresh log; a missing or locked old log is fine
    try:
        os.remove(log_path)
    except OSError:
        pass

    # Line-buffered so the log stays current if the app crashes. Explicit
    # UTF-8 so printing non-ASCII text cannot fail under a legacy default
    # encoding.
    sys.stdout = open(log_path, "a", buffering=1, encoding="utf-8")
    sys.stderr = sys.stdout
    print(f"\n=== App started at {datetime.now()} ===\n")

    # Only run GUI in the main process
    if current_process().name == 'MainProcess':
        app = QApplication(sys.argv)

        # NOTE(review): the QPushButton and QLabel rules near the end override
        # the earlier ones with the same selectors; later rules win.
        style = """
            /* 1. General App Backgrounds */
            QMainWindow, QWidget#centralWidget, QDialog, QMessageBox, QFileDialog {
                background-color: #1e1e1e;
                color: #ffffff;
            }

            QLabel, QStatusBar, QMenuBar {
                color: #ffffff;
            }

            /* 2. THE BIG FIX: Removing white backgrounds from inputs */
            QListWidget, QComboBox, QLineEdit, QSpinBox, QTextEdit {
                background-color: #2b2b2b;
                color: #ffffff;
                border: 1px solid #555555;
                border-radius: 3px;
                padding: 2px;
            }

            /* Fix for the dropdown list of a QComboBox */
            QComboBox QAbstractItemView {
                background-color: #2b2b2b;
                color: #ffffff;
                selection-background-color: #00aaff;
                selection-color: #ffffff;
                outline: none;
                border: 1px solid #555555;
            }

            /* 3. Tab Navigation Styling */
            QTabWidget::pane {
                border: 1px solid #333333;
                background: #1e1e1e;
            }

            QTabBar::tab {
                background: #2b2b2b;
                color: #aaa;
                padding: 8px 15px;
                border: 1px solid #333;
                border-bottom: none;
                border-top-left-radius: 4px;
                border-top-right-radius: 4px;
            }

            QTabBar::tab:selected:!last {
                background: #3d3d3d;
                color: #fff;
                font-weight: bold;
            }

            QTabBar::tab:hover:!last {
                background: #444;
            }

            /* THE PLUS TAB: Constant State */
            QTabBar::tab:last, QTabBar::tab:last:hover, QTabBar::tab:last:selected {
                background: #1e1e1e;
                color: #00aaff;
                font-weight: bold;
                margin-left: 2px;
                border: 1px solid #333;
                padding: 4px 12px;
            }

            /* 4. Timeline & Custom Widgets */
            TimelineWidget {
                background-color: #1e1e1e;
                border: 1px solid #333333;
            }

            /* 5. Buttons with Grey Borders */
            QPushButton {
                background-color: #2b2b2b;
                color: #ffffff;
                border: 1px solid #555555;
                border-radius: 3px;
                padding: 4px;
            }
            QPushButton:hover {
                background-color: #3d3d3d;
                border-color: #888888;
            }
            QPushButton:pressed {
                background-color: #111111;
            }
            QPushButton:disabled {
                border-color: #333333;
                color: #444444;
            }

            /* 6. Layout Dividers */
            QSplitter::handle {
                background-color: #333333;
            }
            QSplitter::handle:horizontal { width: 2px; }
            QSplitter::handle:vertical { height: 2px; }

            /* 7. Scroll Areas */
            QScrollArea, QScrollArea > QWidget > QWidget {
                background-color: #1e1e1e;
                border: none;
            }
            QWidget { background-color: #1e1e1e; color: #ffffff; font-family: 'Segoe UI'; }
            QGroupBox {
                border: 1px solid #3d3d3d; border-radius: 8px; margin-top: 15px;
                padding-top: 15px; font-weight: bold; color: #00aaff; text-transform: uppercase;
            }
            QLabel { color: #ffffff; font-weight: 500; }
            QLabel:disabled { color: #444444; }
            QLabel#Metadata { color: #00ffaa; font-family: 'Consolas'; font-size: 11px; }
            QLabel#Preview { background-color: #000000; border: 2px solid #3d3d3d; }
            QLabel#Warning { color: #ff5555; font-size: 11px; font-style: italic; font-weight: bold; }

            QPushButton { background-color: #3d3d3d; border: 1px solid #555; padding: 6px; border-radius: 4px; }
            QPushButton:hover { background-color: #00aaff; color: #000; }
            QPushButton:disabled { color: #444; background-color: #252525; }

            QComboBox { background-color: #2d2d2d; border: 1px solid #555; padding: 4px; border-radius: 4px; }
            QComboBox:disabled { background-color: #222; color: #444; border: 1px solid #2a2a2a; }
        """

        app.setStyleSheet(style)

        finish_update_if_needed(PLATFORM_NAME, APP_NAME)

        window = PremiereWindow()

        # .icns for macOS, .ico everywhere else; same icon for app and window
        icon_file = "icons/main.icns" if PLATFORM_NAME == "darwin" else "icons/main.ico"
        app_icon = QIcon(resource_path(icon_file))
        app.setWindowIcon(app_icon)
        window.setWindowIcon(app_icon)

        window.show()

        sys.exit(app.exec())
|
|
|
|
# Not 6000 lines yay! |