2080 lines
80 KiB
Python
2080 lines
80 KiB
Python
"""
|
|
Filename: main.py
|
|
Description: BLAZES main executable
|
|
|
|
Author: Tyler de Zeeuw
|
|
License: GPL-3.0
|
|
"""
|
|
|
|
# Built-in imports
|
|
import os
|
|
import csv
|
|
import sys
|
|
import json
|
|
import glob
|
|
import shutil
|
|
import inspect
|
|
import platform
|
|
import traceback
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from multiprocessing import current_process, freeze_support
|
|
|
|
# External library imports
|
|
import numpy as np
|
|
import pandas as pd
|
|
import psutil
|
|
import joblib
|
|
import cv2
|
|
from ultralytics import YOLO
|
|
|
|
from updater import finish_update_if_needed, UpdateManager, LocalPendingUpdateCheckThread
|
|
from predictor import GeneralPredictor
|
|
from batch_processing import BatchProcessorDialog
|
|
|
|
import PySide6
|
|
from PySide6.QtWidgets import (QApplication, QMainWindow, QTabBar, QWidget, QVBoxLayout, QGraphicsView, QGraphicsScene,
|
|
QHBoxLayout, QSplitter, QLabel, QPushButton, QComboBox, QInputDialog,
|
|
QFileDialog, QScrollArea, QMessageBox, QSlider, QTextEdit, QGroupBox, QGridLayout, QCheckBox, QTabWidget)
|
|
from PySide6.QtCore import Qt, QThread, Signal, QUrl, QRectF, QPointF, QRect, QSizeF
|
|
from PySide6.QtGui import QPainter, QColor, QFont, QPen, QBrush, QAction, QKeySequence, QIcon, QTextOption, QImage, QPixmap
|
|
from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput
|
|
from PySide6.QtMultimediaWidgets import QGraphicsVideoItem
|
|
|
|
|
|
# Debug tracing switch read by debug_print(); any truthy value enables it.
VERBOSITY = 1

# Application identity.
CURRENT_VERSION = "0.1.0"
APP_NAME = "blazes"

# Release API endpoints on the primary host and on the secondary host
# (presumably consumed by the updater; confirm against updater.py).
API_URL = f"https://git.research.dezeeuw.ca/api/v1/repos/tyler/{APP_NAME}/releases"
API_URL_SECONDARY = f"https://git.research2.dezeeuw.ca/api/v1/repos/tyler/{APP_NAME}/releases"

# Lower-cased operating-system name as reported by platform.system().
PLATFORM_NAME = platform.system().lower()
|
|
|
|
|
|
|
|
def debug_print():
    """Print which function called this helper, when VERBOSITY is truthy.

    Output has the form ``<source file>: <function name>``. The function name
    is what distinguishes callers: every caller in this module shares the same
    source file.
    """
    if not VERBOSITY:
        return

    # inspect.currentframe() may return None on interpreters without Python
    # stack-frame support; in that case there is nothing to report.
    current = inspect.currentframe()
    caller = current.f_back if current is not None else None
    try:
        if caller is None:
            return
        code = caller.f_code
        # co_qualname (Python 3.11+) includes the class, e.g. "TimelineWidget.set_zoom".
        name = getattr(code, "co_qualname", code.co_name)
        print(f"{code.co_filename}: {name}")
    finally:
        # Release frame references promptly to avoid reference cycles.
        del current, caller
|
|
|
|
|
|
# Keypoint names in the order YOLO pose models output them
# (https://docs.ultralytics.com/tasks/pose/). A name's position in this list
# is its keypoint index, so the order must not change (e.g. the hips, 11 and
# 12, are used as the pelvis reference elsewhere).
JOINT_NAMES = [
    "Nose",            # 0
    "Left Eye",        # 1
    "Right Eye",       # 2
    "Left Ear",        # 3
    "Right Ear",       # 4
    "Left Shoulder",   # 5
    "Right Shoulder",  # 6
    "Left Elbow",      # 7
    "Right Elbow",     # 8
    "Left Wrist",      # 9
    "Right Wrist",     # 10
    "Left Hip",        # 11
    "Right Hip",       # 12
    "Left Knee",       # 13
    "Right Knee",      # 14
    "Left Ankle",      # 15
    "Right Ankle",     # 16
]
|
|
|
|
|
|
# The FFmpeg DLLs (avcodec-*.dll, etc.) bundled with PySide6 live in the
# PySide6 package directory.
pyside_dir = Path(PySide6.__file__).parent

# On Windows, Python (3.8+) does not search PATH for the DLL dependencies of
# extension modules, so the PySide6 folder is registered explicitly.
if sys.platform == "win32":
    os.add_dll_directory(str(pyside_dir))
|
|
|
|
|
|
# Timeline rows: two summary tracks followed by one row per joint.
# NOTE: observation/AI tracks are appended to TRACK_NAMES and TRACK_COLORS at
# runtime, so both must stay plain (mutable) lists. NUM_TRACKS is the count
# at import time only.
TRACK_NAMES = ["Baseline", "Live Skeleton"] + JOINT_NAMES
NUM_TRACKS = len(TRACK_NAMES)

# TODO: Improve colors?
# Distinct colors per track: fixed colors for the two summary rows, then hues
# spread evenly over 0-359 (saturation 200, value 255) for the joint rows.
BASE_COLORS = [
    QColor(180, 180, 180),  # Baseline: grey
    QColor(0, 0, 0),        # Live Skeleton: black
]
REMAINING_COLORS = [
    QColor.fromHsv(int((track_idx / (NUM_TRACKS - 2)) * 359), 200, 255)
    for track_idx in range(NUM_TRACKS - 2)
]
TRACK_COLORS = BASE_COLORS + REMAINING_COLORS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
import json
|
|
import cv2
|
|
from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
|
|
QFileDialog, QCheckBox, QComboBox, QLabel,
|
|
QGridLayout, QGroupBox, QStackedWidget, QInputDialog, QMessageBox)
|
|
from PySide6.QtGui import QPixmap, QImage
|
|
from PySide6.QtCore import Qt
|
|
|
|
|
|
|
|
|
|
import os
|
|
import json
|
|
import cv2
|
|
from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
|
|
QFileDialog, QCheckBox, QComboBox, QLabel,
|
|
QGridLayout, QGroupBox, QStackedWidget, QInputDialog, QMessageBox)
|
|
from PySide6.QtGui import QPixmap, QImage
|
|
from PySide6.QtCore import Qt
|
|
|
|
class OpenFileWindow(QWidget):
    """
    Window for selecting a video and the inputs used to analyse it.

    Sections:
      * Primary video: file picker, a metadata line (resolution, FPS, length
        and, once known, the BORIS offset) and a preview of the first frame.
      * Behavioral analysis mode: a BORIS project (session key + video slot),
        a trained ``.pkl`` model, or bypass / manual annotation.
      * Inference settings: pose-cache search, pose model choice, or bypass.

    ``btn_confirm`` is not connected here; presumably the owner of this window
    connects it and reads the state attributes initialised in ``__init__``.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle(f"Load Video - {APP_NAME.upper()}")
        self.setMinimumWidth(650)

        # State
        self.video_path = None            # full path of the selected video
        self.obs_file = None              # full path of the loaded BORIS/JSON project
        self.full_json_data = None        # parsed contents of obs_file
        self.pkl_path = None              # full path of the selected .pkl model
        self.current_video_fps = 30.0     # FPS reported by OpenCV for video_path
        self.current_video_offset = 0.0   # offset (s) of the selected BORIS video slot

        self.setup_ui()

    def setup_ui(self):
        """Create the three option sections and the Cancel/Initialize buttons, then connect signals."""
        self.setStyleSheet("""
            QWidget { background-color: #1e1e1e; color: #ffffff; font-family: 'Segoe UI'; }
            QGroupBox {
                border: 1px solid #3d3d3d; border-radius: 8px; margin-top: 15px;
                padding-top: 15px; font-weight: bold; color: #00aaff; text-transform: uppercase;
            }
            QLabel { color: #ffffff; font-weight: 500; }
            QLabel:disabled { color: #444444; }
            QLabel#Metadata { color: #00ffaa; font-family: 'Consolas'; font-size: 11px; }
            QLabel#Preview { background-color: #000000; border: 2px solid #3d3d3d; }
            QLabel#Warning { color: #ff5555; font-size: 11px; font-style: italic; font-weight: bold; }

            QPushButton { background-color: #3d3d3d; border: 1px solid #555; padding: 6px; border-radius: 4px; }
            QPushButton:hover { background-color: #00aaff; color: #000; }
            QPushButton:disabled { color: #444; background-color: #252525; }

            QComboBox { background-color: #2d2d2d; border: 1px solid #555; padding: 4px; border-radius: 4px; }
            QComboBox:disabled { background-color: #222; color: #444; border: 1px solid #2a2a2a; }
        """)

        main_layout = QVBoxLayout(self)

        # --- Section 1: Video ---
        video_group = QGroupBox("Primary Video Source")
        v_grid = QGridLayout(video_group)
        self.btn_pick_video = QPushButton("Select Video")
        self.lbl_video_path = QLabel("No video selected...")
        self.lbl_video_metadata = QLabel("Metadata: N/A")
        self.lbl_video_metadata.setObjectName("Metadata")
        self.video_preview = QLabel("NO PREVIEW")
        self.video_preview.setFixedSize(160, 90)
        self.video_preview.setObjectName("Preview")

        v_grid.addWidget(QLabel("Target Video:"), 0, 0)
        v_grid.addWidget(self.btn_pick_video, 0, 1)
        v_grid.addWidget(self.video_preview, 0, 2, 3, 1)
        v_grid.addWidget(QLabel("Path:"), 1, 0)
        v_grid.addWidget(self.lbl_video_path, 1, 1)
        v_grid.addWidget(self.lbl_video_metadata, 2, 0, 1, 2)
        main_layout.addWidget(video_group)

        # --- Section 2: Analysis Modes ---
        self.adv_group = QGroupBox("Behavioral Analysis Mode")
        adv_layout = QVBoxLayout(self.adv_group)
        self.combo_mode = QComboBox()
        self.combo_mode.addItems(["BORIS Project (JSON)", "Trained ML Model (.pkl)", "Bypass / Manual"])
        adv_layout.addWidget(self.combo_mode)

        # One stacked page per entry of combo_mode (same order).
        self.mode_stack = QStackedWidget()

        # Mode 1: BORIS
        self.page_boris = QWidget()
        m1_grid = QGridLayout(self.page_boris)
        self.btn_boris_file = QPushButton("Load .boris File")
        self.lbl_session = QLabel("Session Key:")
        self.combo_boris_keys = QComboBox()
        self.lbl_slot = QLabel("Video Slot:")
        self.combo_video_slot = QComboBox()

        # Initial Disabled States (enabled once a project / session is loaded)
        for w in [self.lbl_session, self.combo_boris_keys, self.lbl_slot, self.combo_video_slot]:
            w.setEnabled(False)

        m1_grid.addWidget(QLabel("BORIS File:"), 0, 0)
        m1_grid.addWidget(self.btn_boris_file, 0, 1)
        m1_grid.addWidget(self.lbl_session, 1, 0)
        m1_grid.addWidget(self.combo_boris_keys, 1, 1)
        m1_grid.addWidget(self.lbl_slot, 2, 0)
        m1_grid.addWidget(self.combo_video_slot, 2, 1)

        # Mode 2: PKL Model
        self.page_pkl = QWidget()
        m2_grid = QGridLayout(self.page_pkl)
        self.btn_pkl_file = QPushButton("Load .pkl Model")
        self.lbl_pkl_path = QLabel("No model selected...")
        m2_grid.addWidget(QLabel("Model File:"), 0, 0)
        m2_grid.addWidget(self.btn_pkl_file, 0, 1)
        m2_grid.addWidget(QLabel("Selected:"), 1, 0)
        m2_grid.addWidget(self.lbl_pkl_path, 1, 1)

        # Mode 3: Bypass
        self.page_bypass = QWidget()
        m3_layout = QVBoxLayout(self.page_bypass)
        self.lbl_bypass_info = QLabel("Bypass Mode: No behavioral data will be loaded.\nManual annotation mode enabled.")
        self.lbl_bypass_info.setStyleSheet("color: #888; font-style: italic;")
        m3_layout.addWidget(self.lbl_bypass_info)

        self.mode_stack.addWidget(self.page_boris)
        self.mode_stack.addWidget(self.page_pkl)
        self.mode_stack.addWidget(self.page_bypass)
        adv_layout.addWidget(self.mode_stack)
        main_layout.addWidget(self.adv_group)

        # --- Section 3: Inference ---
        self.cfg_group = QGroupBox("Inference Settings")
        c_grid = QGridLayout(self.cfg_group)
        self.check_use_cache = QCheckBox("Auto-search pose cache (.npy)")
        self.check_use_cache.setChecked(True)

        self.lbl_model_prompt = QLabel("Pose Model:")
        self.combo_inference_model = QComboBox()
        self.combo_inference_model.addItems(["YOLO11n-Pose", "YOLO11m-Pose", "Mediapipe BlazePose"])

        self.check_bypass_inference = QCheckBox("Bypass Pose Inference")
        self.lbl_inf_warning = QLabel("⚠ WARNING: Nothing fancy. Raw video playback only.")
        self.lbl_inf_warning.setObjectName("Warning")
        self.lbl_inf_warning.setVisible(False)

        c_grid.addWidget(self.check_use_cache, 0, 0, 1, 2)
        c_grid.addWidget(self.lbl_model_prompt, 1, 0)
        c_grid.addWidget(self.combo_inference_model, 1, 1)
        c_grid.addWidget(self.check_bypass_inference, 2, 0)
        c_grid.addWidget(self.lbl_inf_warning, 2, 1)
        main_layout.addWidget(self.cfg_group)

        # --- Bottom Buttons ---
        btn_layout = QHBoxLayout()
        self.btn_cancel = QPushButton("Cancel")
        self.btn_confirm = QPushButton("Initialize BLAZE Engine")
        self.btn_confirm.setStyleSheet("background-color: #00aaff; color: #1e1e1e; font-weight: bold;")
        btn_layout.addWidget(self.btn_cancel)
        btn_layout.addWidget(self.btn_confirm)
        main_layout.addLayout(btn_layout)

        # Connections
        self.btn_pick_video.clicked.connect(self.handle_video_selection)
        self.combo_mode.currentIndexChanged.connect(self.mode_stack.setCurrentIndex)
        self.btn_boris_file.clicked.connect(self.handle_boris_load)
        self.combo_boris_keys.currentIndexChanged.connect(self.handle_session_change)
        self.combo_video_slot.currentIndexChanged.connect(self.handle_slot_change)
        self.btn_pkl_file.clicked.connect(self.handle_pkl_selection)
        self.check_bypass_inference.toggled.connect(self.handle_inference_toggle)
        self.btn_cancel.clicked.connect(self.close)

    def format_time(self, seconds):
        """
        Format a duration in seconds as ``HH:MM:SS``.

        Negative, NaN or infinite input (OpenCV can report a frame count of -1
        or 0 for some streams) is shown as ``00:00:00`` instead of raising.
        """
        seconds = float(seconds)
        if not (0.0 <= seconds < float("inf")):  # also False for NaN
            seconds = 0.0
        h, m, s = int(seconds // 3600), int((seconds % 3600) // 60), int(seconds % 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    def handle_video_selection(self):
        """
        Ask for a video, then show its metadata and first-frame preview.

        On success ``video_path`` and ``current_video_fps`` are updated and, if a
        BORIS project is loaded, a matching session is auto-selected. A file that
        OpenCV cannot open is reported and leaves the previous selection intact.
        """
        path, _ = QFileDialog.getOpenFileName(self, "Open Video", "", "Video Files (*.mp4 *.avi *.mkv)")
        if not path:
            return

        cap = cv2.VideoCapture(path)
        try:
            if not cap.isOpened():
                QMessageBox.warning(self, "Video Error", f"Could not open video:\n{path}")
                return
            # Fall back to 30 FPS when the container does not report a rate.
            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            cap.release()

        self.video_path = path
        self.current_video_fps = fps
        self.lbl_video_path.setText(os.path.basename(path))
        time_str = self.format_time(frame_count / fps)
        self.lbl_video_metadata.setText(f"RES: {width}x{height} | FPS: {fps:.2f} | LEN: {time_str}")
        self.render_preview(path)
        if self.full_json_data:
            self.attempt_auto_match()

    def handle_boris_load(self):
        """
        Ask for a BORIS/JSON project, parse it and list its observation sessions.

        Parse errors are shown in a warning dialog.
        """
        path, _ = QFileDialog.getOpenFileName(self, "Select JSON", "", "JSON Files (*.json *.boris)")
        if not path:
            return
        self.obs_file = path
        self.btn_boris_file.setText(os.path.basename(path))
        try:
            # JSON is UTF-8; the platform default encoding may not be (e.g. cp1252 on Windows).
            with open(path, 'r', encoding='utf-8') as f:
                self.full_json_data = json.load(f)
            obs = self.full_json_data.get("observations", {})
            self.combo_boris_keys.setEnabled(True)
            self.lbl_session.setEnabled(True)
            self.combo_boris_keys.clear()
            # Adding items fires currentIndexChanged -> handle_session_change.
            self.combo_boris_keys.addItems(list(obs.keys()))
            if self.video_path:
                self.attempt_auto_match()
        except Exception as e:
            QMessageBox.warning(self, "Parse Error", str(e))

    def handle_session_change(self):
        """
        Fill the video-slot combo for the selected BORIS session.

        If a video is selected, the first slot whose file list contains the
        video's file name is chosen. The slot's offset is then applied by an
        explicit ``handle_slot_change()`` call. The combo's signals are blocked
        while it is rebuilt, and a slot that stays at index 0 never emits
        ``currentIndexChanged``.
        """
        session_key = self.combo_boris_keys.currentText()
        if not self.full_json_data or not session_key:
            return

        session_data = self.full_json_data.get("observations", {}).get(session_key, {})
        file_map = session_data.get("file", {})
        slots = list(file_map.keys())

        self.combo_video_slot.blockSignals(True)
        try:
            self.combo_video_slot.clear()
            self.combo_video_slot.addItems(slots)
            if self.video_path:
                video_filename = os.path.basename(self.video_path)
                for i, slot in enumerate(slots):
                    entries = file_map[slot] or []
                    if isinstance(entries, str):  # tolerate a single path instead of a list
                        entries = [entries]
                    if any(video_filename in str(f) for f in entries):
                        self.combo_video_slot.setCurrentIndex(i)
                        break
        finally:
            self.combo_video_slot.blockSignals(False)

        self.combo_video_slot.setEnabled(True)
        self.lbl_slot.setEnabled(True)
        self.handle_slot_change()

    def handle_slot_change(self):
        """
        Apply the selected slot's offset from the session's ``media_info``.

        The value is stored in ``current_video_offset`` and appended to the
        metadata line. A slot with no offset, or an unparsable one, gets 0.0
        instead of inheriting the previously selected slot's value.
        """
        session_key = self.combo_boris_keys.currentText()
        slot_key = self.combo_video_slot.currentText()
        if not self.full_json_data or not session_key or not slot_key:
            return

        session_data = self.full_json_data.get("observations", {}).get(session_key, {})
        val = session_data.get("media_info", {}).get("offset", {}).get(str(slot_key))
        try:
            self.current_video_offset = float(val) if val is not None else 0.0
        except (TypeError, ValueError):
            self.current_video_offset = 0.0

        txt = self.lbl_video_metadata.text().split(" | Offset:")[0]
        self.lbl_video_metadata.setText(f"{txt} | Offset: {self.current_video_offset}s")

    def attempt_auto_match(self):
        """
        Select the BORIS session whose key matches the video file name.

        The video's "fingerprint" is ``<part0>_<part1>_<last part>`` of its
        underscore-separated base name (or the whole base name when it has
        fewer than three parts). It is compared case-insensitively with session
        keys that consist of exactly three parts.
        """
        if not self.full_json_data or not self.video_path:
            return
        obs = self.full_json_data.get("observations", {})
        v_name = os.path.splitext(os.path.basename(self.video_path))[0]
        v_parts = v_name.split('_')
        v_fp = f"{v_parts[0]}_{v_parts[1]}_{v_parts[-1]}" if len(v_parts) >= 3 else v_name
        for i, sk in enumerate(obs.keys()):
            if len(sk.split('_')) == 3 and sk.lower() == v_fp.lower():
                if self.combo_boris_keys.currentIndex() == i:
                    # Same index emits no signal: refresh so the slot is re-matched
                    # against the newly selected video.
                    self.handle_session_change()
                else:
                    self.combo_boris_keys.setCurrentIndex(i)
                return

    def handle_pkl_selection(self):
        """Ask for a trained ``.pkl`` model; keep its full path in ``pkl_path`` and show its name."""
        path, _ = QFileDialog.getOpenFileName(self, "Select Model", "", "Pickle Files (*.pkl)")
        if path:
            self.pkl_path = path
            self.lbl_pkl_path.setText(os.path.basename(path))

    def handle_inference_toggle(self, checked):
        """Grey out the inference options and show the warning while pose inference is bypassed."""
        # Target the model label and checkbox explicitly for greying out
        for w in [self.check_use_cache, self.combo_inference_model, self.lbl_model_prompt]:
            w.setEnabled(not checked)
        self.lbl_inf_warning.setVisible(checked)

    def render_preview(self, path):
        """Show the first frame of ``path``, scaled into the preview label (keeps aspect ratio)."""
        cap = cv2.VideoCapture(path)
        try:
            ret, frame = cap.read()
        finally:
            cap.release()
        if not ret or frame is None:
            return

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        height, width = rgb.shape[:2]
        # QImage only wraps the numpy buffer; QPixmap.fromImage copies it immediately.
        image = QImage(rgb.data, width, height, rgb.strides[0], QImage.Format.Format_RGB888)
        pixmap = QPixmap.fromImage(image)
        self.video_preview.setPixmap(
            pixmap.scaled(
                self.video_preview.size(),
                Qt.AspectRatioMode.KeepAspectRatio,
                Qt.TransformationMode.SmoothTransformation,
            )
        )
|
|
|
|
|
|
|
|
|
|
class AboutWindow(QWidget):
    """
    Simple About window displaying basic application information.

    Shows the application name, its full title, the licence notice and the
    current version.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle(f"About {APP_NAME.upper()}")
        self.resize(250, 100)
        # Only widgets can be styled, so a selector naming a layout class matches nothing.
        self.setStyleSheet("""
            QWidget {
                background-color: #1e1e1e;
            }
            QLabel {
                color: #ffffff;
            }
        """)

        layout = QVBoxLayout(self)

        title = QLabel(f"About {APP_NAME.upper()}", self)
        subtitle = QLabel("Behavioral Learning & Automated Zoned Events Suite", self)
        licence_notice = QLabel(f"{APP_NAME.upper()} is licensed under the GPL-3.0 licence. For more information, visit https://www.gnu.org/licenses/gpl-3.0.en.html", self)
        # Without wrapping, this one long sentence forces the window far wider
        # than its 250 px default size.
        licence_notice.setWordWrap(True)
        version = QLabel(f"Version v{CURRENT_VERSION}", self)

        for label in (title, subtitle, licence_notice, version):
            layout.addWidget(label)
|
|
|
|
|
|
|
|
class UserGuideWindow(QWidget):
    """
    Placeholder user-guide window.

    Until a real guide exists, it shows a short notice and a link to the
    project's Git wiki, which opens in the system browser.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle(f"User Guide - {APP_NAME.upper()}")
        self.resize(250, 100)
        self.setStyleSheet("""
            QVBoxLayout, QWidget {
                background-color: #1e1e1e;
            }
            QLabel {
                color: #ffffff;
            }
        """)

        column = QVBoxLayout()

        heading = QLabel("Hmmm...", self)
        notice = QLabel("Nothing to see here yet.", self)

        wiki_url = f"https://git.research.dezeeuw.ca/tyler/{APP_NAME}/wiki"
        wiki_link = QLabel(
            f"For more information, visit the Git wiki page <a href='{wiki_url}'>here</a>.",
            self,
        )
        # Render the anchor as a clickable link that opens externally.
        wiki_link.setTextFormat(Qt.TextFormat.RichText)
        wiki_link.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
        wiki_link.setOpenExternalLinks(True)

        for widget in (heading, notice, wiki_link):
            column.addWidget(widget)

        self.setLayout(column)
|
|
|
|
|
|
|
|
class PoseAnalyzerWorker(QThread):
    """
    Background thread that turns a video into timeline data.

    Pipeline (see ``run``):
      1. Pose extraction: load ``<video>_pose_raw.csv`` if it exists, otherwise
         run YOLO pose tracking on every frame and write that cache.
      2. Kinematics: pelvis-centred keypoints and per-joint z-scores relative
         to the mean pose of all frames with a detection.
      3. Timeline events: per-joint movement blocks, BORIS observations (when
         ``obs_info`` is given) and predictions of any ``*_rf.pkl`` models in
         the current working directory.
      4. Emit the result through ``finished_data``.

    Signals:
        progress (str): human-readable status messages.
        finished_data (dict): the analysis result (keys listed in ``run``).

    Args:
        video_path (str): path of the video to analyse.
        obs_info (tuple, optional): ``(json_path, session_key)`` of a BORIS project.
        predictor (optional): object with a writable ``current_target`` and a
            ``format_features(z_kps, dirs, raw_kps)`` method; needed for the ML step.
    """

    progress = Signal(str)
    finished_data = Signal(dict)

    # Keypoints per pose (length of JOINT_NAMES).
    _NUM_KPTS = 17

    def __init__(self, video_path, obs_info=None, predictor=None):
        debug_print()
        super().__init__()
        self.video_path = video_path
        self.obs_info = obs_info
        self.predictor = predictor
        # Flattened per-frame pose table (x, y, conf per joint); filled by run().
        self.pose_df = pd.DataFrame()

    def get_best_infant_match(self, results, w, h, prev_track_id):
        """
        Pick the tracked person most likely to be the infant in one frame.

        Score per candidate = 10 * (# keypoints with conf > 0.5)
                              - 0.1 * distance (px) from the mean of those
                                keypoints to the frame centre (1000 if none)
                              + 50 if its track id equals ``prev_track_id``.

        Args:
            results: output of ``YOLO.track`` for a single frame.
            w, h: frame width and height in pixels.
            prev_track_id: track id chosen for the previous frame, or None.

        Returns:
            tuple: ``(track_id, keypoints_xy, keypoint_confs, index)`` of the
            best candidate, or ``(None, None, None, None)`` if none qualifies.
        """
        debug_print()
        result = results[0]
        if not result.boxes or result.boxes.id is None or result.keypoints is None:
            return None, None, None, None

        ids = result.boxes.id.int().cpu().tolist()
        kpts = result.keypoints.xy.cpu().numpy()
        conf_tensor = result.keypoints.conf
        # Keypoint confidences may be absent; then treat every keypoint as visible.
        confs = conf_tensor.cpu().numpy() if conf_tensor is not None else np.ones(kpts.shape[:2])
        centre = np.array([w / 2, h / 2])

        # NOTE(review): starting at -1 means candidates scoring <= -1 (e.g. no
        # confident keypoints) are rejected and the frame counts as "no pose".
        # Kept as-is; confirm this implicit quality filter is intended.
        best_idx, best_score = -1, -1
        for i, (k, c) in enumerate(zip(kpts, confs)):
            visible = c > 0.5
            valid = k[visible]
            dist = np.linalg.norm(np.mean(valid, axis=0) - centre) if len(valid) > 0 else 1000
            score = (np.sum(visible) * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0)
            if score > best_score:
                best_score, best_idx = score, i

        if best_idx == -1:
            return None, None, None, None
        return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx

    def _merge_json_observations(self, timeline_events, fps):
        """
        Add BORIS observation tracks to ``timeline_events`` (in place).

        Events of the selected session are sorted by time. Each event's label
        and optional ``special`` field (4th element) choose its track(s):
        "Left"/"Right" give ``OBS: <label> (Left|Right)``, "Both" gives both,
        and anything else gives ``OBS: <label>``. Consecutive timestamps of a
        track are paired into (start, stop) frame blocks. An unmatched trailing
        start is dropped, and a "Sync" track with one timestamp becomes a
        one-frame marker. New track names are appended to the global
        TRACK_NAMES/TRACK_COLORS (purple).

        Errors are printed and otherwise ignored, so a malformed observation
        file never aborts the analysis.
        """
        debug_print()
        if not self.obs_info:
            return

        self.progress.emit("Merging JSON Observations...")
        json_path, subkey = self.obs_info

        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                full_json = json.load(f)

            raw_obs_events = sorted(full_json["observations"][subkey]["events"], key=lambda ev: ev[0])

            # track name -> frame indices, in time order
            obs_groups = {}
            for ev in raw_obs_events:
                # ev structure as used here: [time_sec, <unused>, label, special, ...]
                time_sec, label = ev[0], ev[2]
                special = ev[3] if len(ev) > 3 else None
                frame = int(time_sec * fps)

                if special in ("Left", "Right"):
                    target_tracks = [f"OBS: {label} ({special})"]
                elif special == "Both":
                    target_tracks = [f"OBS: {label} (Left)", f"OBS: {label} (Right)"]
                else:
                    # No special or unrecognized value
                    target_tracks = [f"OBS: {label}"]

                for t_name in target_tracks:
                    obs_groups.setdefault(t_name, []).append(frame)

            for track_name, frames in obs_groups.items():
                if "Sync" in track_name and len(frames) == 1:
                    # Give a lone sync marker a visible width on the timeline.
                    processed_blocks = [(frames[0], frames[0] + 1, "Moderate", "Manual")]
                else:
                    processed_blocks = [
                        (start_f, end_f, "Moderate", "Manual")
                        for start_f, end_f in zip(frames[0::2], frames[1::2])
                    ]

                # NOTE: mutates module-level lists from this worker thread.
                if track_name not in TRACK_NAMES:
                    TRACK_NAMES.append(track_name)
                    TRACK_COLORS.append(QColor("#AA00FF"))  # purple for observations

                timeline_events[track_name] = processed_blocks

        except Exception as e:
            print(f"Error parsing JSON Observations: {e}")

    @staticmethod
    def _binary_to_blocks(predictions, severity, direction):
        """Merge runs of 1s into ``(start, end_inclusive, severity, direction)`` blocks."""
        blocks, start_f = [], None
        for f_idx, val in enumerate(predictions):
            if val == 1 and start_f is None:
                start_f = f_idx
            elif val == 0 and start_f is not None:
                blocks.append((start_f, f_idx - 1, severity, direction))
                start_f = None
        if start_f is not None:
            blocks.append((start_f, len(predictions) - 1, severity, direction))
        return blocks

    def _run_existing_ml_models(self, z_kps, dirs, raw_kpts):
        """
        Run every ``*_rf.pkl`` model in the current working directory.

        For ``<base>_rf.pkl`` the scaler ``<base>_rf_scaler.pkl`` must exist
        next to it. The target is ``<base>`` without a leading ``ml_``. Features
        come from ``self.predictor.format_features`` per frame, after setting
        ``self.predictor.current_target``. Frames predicted as 1 become
        ``(start, end_inclusive, "Large", "AI")`` blocks on track ``AI: <target>``.

        SECURITY: ``joblib.load`` unpickles these files, which can execute
        arbitrary code; only trusted model files belong in the working directory.

        Returns:
            dict: track name -> list of blocks. A failing model is skipped and
            its error printed.
        """
        debug_print()
        ai_events = {}

        # Models are expected to be named {Target}_rf.pkl. Sorted for a stable track order.
        model_files = sorted(glob.glob("*_rf.pkl"))
        print(f"DEBUG: Found model files: {model_files}")

        if self.predictor is None:
            if model_files:
                print("DEBUG: No predictor supplied; skipping AI models")
            return ai_events

        for model_path in model_files:
            try:
                # Extract Target (e.g., "Mouthing" from "Mouthing_rf.pkl" or "ml_Mouthing_rf.pkl")
                base_name = model_path[:-len("_rf.pkl")]
                target = base_name[len("ml_"):] if base_name.startswith("ml_") else base_name
                track_name = f"AI: {target}"

                self.progress.emit(f"Loading AI Observations for {target}...")

                scaler_path = f"{base_name}_rf_scaler.pkl"
                if not os.path.exists(scaler_path):
                    print(f"DEBUG: Skipping {target}, scaler not found at {scaler_path}")
                    continue

                model = joblib.load(model_path)
                scaler = joblib.load(scaler_path)

                # The predictor's target selects which features format_features builds.
                self.predictor.current_target = target
                X = np.array([
                    self.predictor.format_features(z, d, r)
                    for z, d, r in zip(z_kps, dirs, raw_kpts)
                ])
                predictions = model.predict(scaler.transform(X))

                processed_blocks = self._binary_to_blocks(predictions, "Large", "AI")

                # NOTE: mutates module-level lists from this worker thread.
                if track_name not in TRACK_NAMES:
                    TRACK_NAMES.append(track_name)
                    TRACK_COLORS.append(QColor("#00FF00"))

                ai_events[track_name] = processed_blocks
                print(f"DEBUG: Successfully generated {len(processed_blocks)} blocks for {track_name}")

            except Exception as e:
                print(f"Inference Error for {model_path}: {e}")

        return ai_events

    def classify_delta(self, z):
        """Map a z-score to "Rest" (|z| < 1), "Small" (< 2), "Moderate" (< 3) or "Large"."""
        # Not traced with debug_print(): called once per joint per frame.
        z_abs = abs(z)
        if z_abs < 1:
            return "Rest"
        elif z_abs < 2:
            return "Small"
        elif z_abs < 3:
            return "Moderate"
        else:
            return "Large"

    @staticmethod
    def _pose_columns():
        """Column names for a flattened pose row: ``<joint>_x, <joint>_y, <joint>_conf`` per joint."""
        return [f"{joint}_{axis}" for joint in JOINT_NAMES for axis in ("x", "y", "conf")]

    def _save_pose_cache(self, path, data):
        """
        Save per-frame keypoints and confidences to a CSV pose cache.

        Each row is one frame: its (17, 3) ``[x, y, conf]`` array flattened to
        51 values. The header uses JOINT_NAMES verbatim (spaces included), the
        same column names as ``pose_df``. Errors are printed, not raised, because
        the cache is optional. A partially written file is removed so a later
        run does not load a truncated cache.
        """
        try:
            with open(path, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(self._pose_columns())
                writer.writerows(np.asarray(frame_data).ravel() for frame_data in data)
            print(f"DEBUG: Pose cache saved with joint headers at {path}")
        except Exception as e:
            print(f"ERROR: Could not save pose cache: {e}")
            try:
                os.remove(path)
            except OSError:
                pass

    def _load_pose_cache(self, path):
        """Read a pose cache; return (keypoints_xy, full_rows, valid_mask) lists, one entry per frame."""
        raw_kps, full_rows, valid_mask = [], [], []
        with open(path, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # header
            for row in reader:
                if not row:
                    continue
                full_data = np.array([float(x) for x in row]).reshape(self._NUM_KPTS, 3)
                kp = full_data[:, :2]
                raw_kps.append(kp)
                full_rows.append(full_data)
                # Frames without a detection were cached as all zeros.
                valid_mask.append(bool(np.any(kp)))
        return raw_kps, full_rows, valid_mask

    def _detect_poses(self, cap, width, height, total_frames):
        """
        Run YOLO pose tracking on every frame of ``cap`` until the stream ends.

        ``total_frames`` (container metadata, possibly inaccurate or 0) is used
        only for progress messages. Frames without a match get zero
        placeholders and ``valid_mask`` False.

        Returns:
            tuple: (keypoints_xy, full_rows with [x, y, conf], valid_mask) lists.
        """
        model = YOLO("yolov8n-pose.pt")
        raw_kps, full_rows, valid_mask = [], [], []
        prev_track_id = None
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            results = model.track(frame, persist=True, verbose=False)
            track_id, kp, confs, _ = self.get_best_infant_match(results, width, height, prev_track_id)
            if kp is not None:
                prev_track_id = track_id
                raw_kps.append(kp)
                full_rows.append(np.column_stack((kp, confs)))
                valid_mask.append(True)
            else:
                raw_kps.append(np.zeros((self._NUM_KPTS, 2)))
                full_rows.append(np.zeros((self._NUM_KPTS, 3)))
                valid_mask.append(False)
            if frame_idx % 50 == 0 and total_frames > 0:
                self.progress.emit(f"YOLO: {min(100, int((frame_idx / total_frames) * 100))}%")
            frame_idx += 1
        return raw_kps, full_rows, valid_mask

    def _add_kinematic_events(self, timeline_events, z_kps, valid_mask):
        """
        Append per-joint movement blocks ``[start, end_inclusive, severity, ""]``.

        Consecutive frames of the same non-"Rest" severity are merged into one
        block. Frames without a detection count as "Rest": their zero
        placeholders are not real motion.
        """
        for j_idx, joint_name in enumerate(JOINT_NAMES):
            blocks = timeline_events.setdefault(joint_name, [])
            current_block = None
            for f_idx in range(len(z_kps)):
                severity = self.classify_delta(z_kps[f_idx, j_idx]) if valid_mask[f_idx] else "Rest"
                if severity == "Rest":
                    current_block = None
                elif current_block is not None and current_block[2] == severity:
                    current_block[1] = f_idx
                else:
                    current_block = [f_idx, f_idx, severity, ""]
                    blocks.append(current_block)

    def _analyze(self):
        """Run the pose, kinematics and timeline phases; return the dict emitted by ``run``."""
        # --- PHASE 1: VIDEO SETUP & POSE EXTRACTION ---
        cap = cv2.VideoCapture(self.video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # splitext, unlike rsplit('.'), ignores dots in directory names.
            pose_cache = os.path.splitext(self.video_path)[0] + "_pose_raw.csv"
            if os.path.exists(pose_cache):
                self.progress.emit("Loading cached kinematic data...")
                raw_kps_per_frame, csv_storage_data, valid_mask = self._load_pose_cache(pose_cache)
            else:
                self.progress.emit("Detecting poses with YOLO...")
                raw_kps_per_frame, csv_storage_data, valid_mask = self._detect_poses(
                    cap, width, height, total_frames)
                self._save_pose_cache(pose_cache, csv_storage_data)
        finally:
            cap.release()

        actual_len = len(raw_kps_per_frame)

        # Store this so the Inspector can access it instantly in memory
        self.pose_df = pd.DataFrame(
            [np.asarray(frame_array).ravel() for frame_array in csv_storage_data],
            columns=self._pose_columns(),
        )

        # --- PHASE 2: KINEMATICS & Z-SCORES ---
        self.progress.emit("Calculating Kinematics...")
        np_raw_kps = np.asarray(raw_kps_per_frame).reshape(-1, self._NUM_KPTS, 2)
        # Centre every pose on the pelvis: midpoint of Left Hip (11) and Right Hip (12).
        pelvis = (np_raw_kps[:, 11] + np_raw_kps[:, 12]) / 2
        analysis_kpts = np_raw_kps - pelvis[:, np.newaxis, :]

        valid = np.asarray(valid_mask, dtype=bool)
        if valid.any():
            stacked = analysis_kpts[valid]
            baseline_mean = np.mean(stacked, axis=0)
            # Per-joint spread of the distance to the mean pose; +1e-6 avoids division by zero.
            baseline_std = np.std(np.linalg.norm(stacked - baseline_mean, axis=2), axis=0) + 1e-6
        else:
            baseline_mean, baseline_std = np.zeros((self._NUM_KPTS, 2)), np.ones(self._NUM_KPTS)

        # (frames, joints): distance from the mean pose in units of baseline_std.
        np_z_kps = np.linalg.norm(analysis_kpts - baseline_mean, axis=2) / baseline_std

        # Directions are not computed yet: empty-string placeholders keep the
        # (frames, joints) shape expected by feature extraction.
        np_dirs = np.full((actual_len, self._NUM_KPTS), "", dtype=object)

        # --- PHASE 3: TIMELINE GENERATION ---
        # Initialize dictionary with ALL global track names to prevent KeyErrors
        timeline_events = {name: [] for name in TRACK_NAMES}
        self._add_kinematic_events(timeline_events, np_z_kps, valid)
        self._merge_json_observations(timeline_events, fps)
        timeline_events.update(self._run_existing_ml_models(np_z_kps, np_dirs, np_raw_kps))

        return {
            "video_path": self.video_path,
            "fps": fps,
            "total_frames": actual_len,
            "width": width, "height": height,
            "events": timeline_events,
            "raw_kps": np_raw_kps,
            "z_kps": np_z_kps,
            "directions": np_dirs,
            "baseline_kp_mean": baseline_mean,
        }

    def run(self):
        """
        QThread entry point: analyse the video and emit ``finished_data``.

        Emitted keys:
          * video_path, fps, width, height
          * total_frames: number of frames actually analysed
          * events: track -> blocks
          * raw_kps: (frames, 17, 2)
          * z_kps: (frames, 17)
          * directions: (frames, 17), placeholders
          * baseline_kp_mean: (17, 2)

        On failure the traceback is printed, an "Analysis failed: ..." message
        is sent through ``progress``, and ``finished_data`` is not emitted.
        """
        debug_print()
        try:
            data = self._analyze()
        except Exception as e:
            # Thread boundary: report the failure instead of dying without a status.
            traceback.print_exc()
            self.progress.emit(f"Analysis failed: {e}")
            return
        self.progress.emit("Analysis Complete!")
        self.finished_data.emit(data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ==========================================
|
|
# TIMELINE WIDGET
|
|
# ==========================================
|
|
class TimelineWidget(QWidget):
|
|
seek_requested = Signal(int)
|
|
visibility_changed = Signal(set)
|
|
track_selected = Signal(str)
|
|
|
|
def __init__(self):
    debug_print()
    super().__init__()

    # Loaded analysis payload; presumably the dict emitted by the pose worker
    # (this widget reads its "total_frames", "events" and "fps" keys).
    self.data = None
    self.current_frame = 0

    # Layout metrics, in pixels.
    self.zoom_factor = 1.0      # pixels per frame
    self.label_width = 160      # fixed gutter for track names
    self.track_height = 25
    self.ruler_height = 20
    self.scrollbar_buffer = 2   # extra space for the horizontal scrollbar

    # Presumably the names of tracks hidden by the user.
    self.hidden_tracks = set()

    # Temporal shift applied when drawing event blocks (see set_sync_params).
    self.sync_offset = 0.0
    self.sync_fps = 30.0

    # Ruler plus one row per track known at import time.
    self.total_content_height = self.ruler_height + NUM_TRACKS * self.track_height
    self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer)
|
|
|
|
|
|
def set_sync_params(self, offset_seconds, fps=None):
    """
    Set the time shift used to align events, then repaint.

    Args:
        offset_seconds: shift in seconds (converted with ``float``).
        fps: frame rate for the shift. Used only when truthy and positive;
            otherwise the loaded data's ``"fps"`` is used if present, and if
            not, the current ``sync_fps`` is kept.
    """
    debug_print()
    self.sync_offset = float(offset_seconds)

    explicit_rate_usable = bool(fps) and fps > 0
    if explicit_rate_usable:
        self.sync_fps = float(fps)
    elif self.data and "fps" in self.data:
        self.sync_fps = float(self.data["fps"])

    print(f"DEBUG: Timeline Sync Set - Offset: {self.sync_offset}s, FPS: {self.sync_fps}")

    # Schedules paintEvent so blocks are redrawn at their shifted positions.
    self.update()
|
|
|
|
|
|
def set_zoom(self, factor):
    """
    Set the horizontal zoom (pixels per frame), clamped to a usable range.

    Lower bound: the zoom at which all frames fit in the parent's width minus
    the label gutter (800 px when there is no parent). Upper bound: 50 px per
    frame. Does nothing until data with at least one frame is loaded.
    """
    debug_print()
    if not self.data:
        return

    total_frames = self.data.get("total_frames", 0)
    if total_frames <= 0:
        # Nothing to lay out, and the minimum zoom below would divide by zero.
        return

    parent = self.parent()
    available_w = parent.width() - self.label_width if parent else 800
    # A parent narrower than the gutter would otherwise give a zero or negative minimum.
    min_zoom = max(available_w, 1) / total_frames

    # Clamp: don't zoom out past the video end, don't zoom in beyond 50 px/frame.
    self.zoom_factor = max(min_zoom, min(factor, 50.0))
    self.update_geometry()
|
|
|
|
|
|
def get_all_binary_labels(self, offset_seconds=0.0, fps=30.0):
|
|
"""
|
|
Extracts binary labels for ALL tracks in self.data["events"].
|
|
Returns a dict: {'OBS: Mouthing': [0,1,0...], 'OBS: Kicking': [0,0,1...]}
|
|
"""
|
|
debug_print()
|
|
all_labels = {}
|
|
|
|
if not self.data or "events" not in self.data:
|
|
return all_labels
|
|
|
|
total_frames = self.data.get("total_frames", 0)
|
|
if total_frames == 0:
|
|
return all_labels
|
|
|
|
frame_shift = int(offset_seconds * fps)
|
|
|
|
for track_name in self.data["events"]:
|
|
sequence = np.zeros(total_frames, dtype=int)
|
|
|
|
for event in self.data["events"][track_name]:
|
|
start_f = int(event[0]) - frame_shift
|
|
end_f = int(event[1]) - frame_shift
|
|
|
|
# Clamp values
|
|
start_idx = max(0, min(start_f, total_frames - 1))
|
|
end_idx = max(0, min(end_f, total_frames))
|
|
|
|
if start_idx < end_idx:
|
|
sequence[start_idx:end_idx] = 1
|
|
|
|
all_labels[track_name] = sequence.tolist() # Convert to list for easier storage
|
|
|
|
return all_labels
|
|
|
|
|
|
def update_geometry(self):
|
|
debug_print()
|
|
|
|
if self.data:
|
|
# Width is sidebar + (frames * zoom)
|
|
total_w = self.label_width + int(self.data["total_frames"] * self.zoom_factor)
|
|
self.setFixedWidth(total_w)
|
|
self.update()
|
|
|
|
def wheelEvent(self, event):
|
|
debug_print()
|
|
|
|
if event.modifiers() == Qt.ControlModifier:
|
|
delta = event.angleDelta().y()
|
|
# Zoom by 10% per notch
|
|
zoom_change = 1.1 if delta > 0 else 0.9
|
|
self.set_zoom(self.zoom_factor * zoom_change)
|
|
else:
|
|
# Let the scroll area handle normal vertical scrolling
|
|
super().wheelEvent(event)
|
|
|
|
# --- NEW: CTRL + Plus / Minus / Zero ---
|
|
def keyPressEvent(self, event):
|
|
debug_print()
|
|
|
|
if event.modifiers() == Qt.ControlModifier:
|
|
if event.key() == Qt.Key_Plus or event.key() == Qt.Key_Equal:
|
|
self.set_zoom(self.zoom_factor * 1.2)
|
|
elif event.key() == Qt.Key_Minus:
|
|
self.set_zoom(self.zoom_factor * 0.8)
|
|
elif event.key() == Qt.Key_0:
|
|
self.set_zoom(1.0) # Reset zoom
|
|
else:
|
|
super().keyPressEvent(event)
|
|
|
|
def set_data(self, data):
|
|
debug_print()
|
|
|
|
self.data = data
|
|
self.total_content_height = (len(TRACK_NAMES) * self.track_height) + self.ruler_height
|
|
self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer)
|
|
self.update_geometry()
|
|
|
|
def set_playhead(self, frame):
|
|
debug_print()
|
|
old_x = self.label_width + (self.current_frame * self.zoom_factor)
|
|
self.current_frame = frame
|
|
new_x = self.label_width + (self.current_frame * self.zoom_factor)
|
|
self.ensure_playhead_visible()
|
|
self.update(int(old_x - 5), 0, 10, self.height())
|
|
self.update(int(new_x - 5), 0, 10, self.height())
|
|
|
|
def ensure_playhead_visible(self):
|
|
debug_print()
|
|
|
|
"""Auto-scrolls the scroll area if the playhead leaves the viewport."""
|
|
# Find the QScrollArea parent
|
|
scroll_area = self.parent().parent()
|
|
if not isinstance(scroll_area, QScrollArea): return
|
|
|
|
scrollbar = scroll_area.horizontalScrollBar()
|
|
view_width = scroll_area.viewport().width()
|
|
|
|
# Playhead position in pixels
|
|
px = self.label_width + int(self.current_frame * self.zoom_factor)
|
|
|
|
# Current scroll position
|
|
scroll_x = scrollbar.value()
|
|
|
|
# If playhead is beyond the right edge of visible area
|
|
if px > (scroll_x + view_width):
|
|
# Shift scroll so playhead is at the left (plus sidebar)
|
|
scrollbar.setValue(px - self.label_width)
|
|
|
|
# If playhead is behind the left edge (e.g. user seeked backwards)
|
|
elif px < (scroll_x + self.label_width):
|
|
scrollbar.setValue(px - self.label_width)
|
|
|
|
def mousePressEvent(self, event):
|
|
debug_print()
|
|
|
|
if not self.data or event.button() != Qt.LeftButton:
|
|
return
|
|
|
|
pos_x = event.position().x()
|
|
pos_y = event.position().y()
|
|
scroll_area = self.parent().parent()
|
|
scroll_x = scroll_area.horizontalScrollBar().value()
|
|
|
|
# 1. CALCULATE FRAME
|
|
relative_x = pos_x - self.label_width
|
|
frame = int(relative_x / self.zoom_factor)
|
|
frame = max(0, min(frame, self.data["total_frames"] - 1))
|
|
|
|
# 2. IF CLICKED SIDEBAR: Toggle Visibility (No Scrubbing)
|
|
if pos_x < scroll_x + self.label_width:
|
|
relative_y = pos_y - self.ruler_height
|
|
track_idx = int(relative_y // self.track_height)
|
|
if 0 <= track_idx < len(TRACK_NAMES):
|
|
name = TRACK_NAMES[track_idx]
|
|
if name in self.hidden_tracks: self.hidden_tracks.remove(name)
|
|
else: self.hidden_tracks.add(name)
|
|
self.visibility_changed.emit(self.hidden_tracks)
|
|
self.update()
|
|
return # Exit early; don't set is_scrubbing
|
|
|
|
# 3. IF CLICKED RULER OR DATA AREA: Start Scrubbing
|
|
self.is_scrubbing = True
|
|
self.seek_requested.emit(frame)
|
|
|
|
# Handle track selection if in the data area
|
|
if pos_y >= self.ruler_height:
|
|
track_idx = int((pos_y - self.ruler_height) // self.track_height)
|
|
if 0 <= track_idx < len(TRACK_NAMES):
|
|
self.track_selected.emit(TRACK_NAMES[track_idx])
|
|
self.selected_track_idx = track_idx
|
|
self.update()
|
|
else:
|
|
# Clicked ruler
|
|
self.selected_track_idx = -1
|
|
self.track_selected.emit("")
|
|
self.update()
|
|
|
|
def mouseMoveEvent(self, event):
|
|
debug_print()
|
|
|
|
# This only fires while moving if a button is held down by default
|
|
if self.is_scrubbing:
|
|
self.update_frame_from_mouse(event.position().x())
|
|
|
|
def mouseReleaseEvent(self, event):
|
|
debug_print()
|
|
|
|
if event.button() == Qt.LeftButton:
|
|
self.is_scrubbing = False
|
|
|
|
def update_frame_from_mouse(self, x_pos):
|
|
"""Helper to calculate frame and emit the seek signal."""
|
|
debug_print()
|
|
relative_x = x_pos - self.label_width
|
|
frame = int(relative_x / self.zoom_factor)
|
|
frame = max(0, min(frame, self.data["total_frames"] - 1))
|
|
|
|
# We emit seek_requested so the Video Player and Premiere class sync up
|
|
self.seek_requested.emit(frame)
|
|
|
|
|
|
def paintEvent(self, event):
|
|
debug_print()
|
|
if not self.data: return
|
|
|
|
dirty_rect = event.rect()
|
|
painter = QPainter(self)
|
|
|
|
# 1. Determine current scroll position to keep labels sticky
|
|
scroll_area = self.parent().parent()
|
|
scroll_x = 0
|
|
if isinstance(scroll_area, QScrollArea):
|
|
scroll_x = scroll_area.horizontalScrollBar().value()
|
|
|
|
w, h = self.width(), self.height()
|
|
total_f = self.data["total_frames"]
|
|
fps = self.data.get("fps", 30)
|
|
offset_y = 20
|
|
|
|
# 2. DRAW DATA AREA (Events and Playhead)
|
|
# --- 2. DRAW DATA AREA (Muted Patterns + Events + Playhead) ---
|
|
sync_off = getattr(self, "sync_offset", 0.0)
|
|
sync_fps = getattr(self, "sync_fps", fps)
|
|
frame_shift = int(sync_off * sync_fps)
|
|
|
|
for i, name in enumerate(TRACK_NAMES):
|
|
y = offset_y + (i * self.track_height)
|
|
is_hidden = name in self.hidden_tracks
|
|
|
|
if y + self.track_height < dirty_rect.top() or y > dirty_rect.bottom():
|
|
continue
|
|
|
|
# A. Draw Muted/Disabled Background Pattern
|
|
|
|
if is_hidden:
|
|
# Calculate the visible rectangle for this track to the right of the sidebar
|
|
mute_rect = QRectF(self.label_width, y, w - self.label_width, self.track_height)
|
|
# Fill with a dark "disabled" base
|
|
painter.fillRect(mute_rect, QColor(40, 40, 40))
|
|
# Add the Cross-Hatch Pattern
|
|
pattern_brush = QBrush(QColor(60, 60, 60, 100), Qt.DiagCrossPattern)
|
|
painter.fillRect(mute_rect, pattern_brush)
|
|
|
|
# B. Draw Event Blocks
|
|
base_color = TRACK_COLORS[i]
|
|
for start_f, end_f, severity, direction in self.data["events"][name]:
|
|
# x_start = self.label_width + (start_f * self.zoom_factor)
|
|
# x_end = self.label_width + (end_f * self.zoom_factor)
|
|
|
|
if "AI:" in name:
|
|
# AI predictions are already calculated in video-time, NO SHIFT
|
|
shifted_start, shifted_end = start_f, end_f
|
|
|
|
else:
|
|
shifted_start = start_f - frame_shift
|
|
shifted_end = end_f - frame_shift
|
|
|
|
x_start = self.label_width + (shifted_start * self.zoom_factor)
|
|
x_end = self.label_width + (shifted_end * self.zoom_factor)
|
|
|
|
# Performance optimization: skip drawing if off-screen
|
|
if x_end < scroll_x or x_start > scroll_x + w:
|
|
continue
|
|
|
|
if x_end < dirty_rect.left() or x_start > dirty_rect.right():
|
|
continue
|
|
|
|
# If hidden, make the event block very faint/transparent
|
|
if is_hidden:
|
|
color = QColor(120, 120, 120, 40) # Muted Grey
|
|
else:
|
|
alpha = 80 if severity == "Small" else 160 if severity == "Moderate" else 255
|
|
color = QColor(base_color)
|
|
color.setAlpha(alpha)
|
|
|
|
painter.fillRect(QRectF(x_start, y + 2, max(1, x_end - x_start), self.track_height - 4), color)
|
|
# Draw Playhead
|
|
playhead_x = self.label_width + (self.current_frame * self.zoom_factor)
|
|
painter.setPen(QPen(QColor(255, 0, 0), 2))
|
|
painter.drawLine(playhead_x, 0, playhead_x, h)
|
|
|
|
# 3. DRAW STICKY SIDEBAR (Pinned to the left edge of the viewport)
|
|
# We draw this AFTER the data so it covers the blocks as they scroll past
|
|
sidebar_rect = QRect(scroll_x, 0, self.label_width, h)
|
|
painter.fillRect(sidebar_rect, QColor(30, 30, 30)) # Solid background
|
|
|
|
# Ruler segment for the sidebar area
|
|
painter.fillRect(scroll_x, 0, self.label_width, offset_y, QColor(45, 45, 45))
|
|
|
|
for i, name in enumerate(TRACK_NAMES):
|
|
y = offset_y + (i * self.track_height)
|
|
is_hidden = name in self.hidden_tracks
|
|
# Grid Line
|
|
painter.setPen(QColor(60, 60, 60))
|
|
painter.drawLine(scroll_x, y, scroll_x + w, y)
|
|
|
|
# Sticky Label Text
|
|
if is_hidden:
|
|
# Very dark grey to show it's "OFF"
|
|
text_color = QColor(70, 70, 70)
|
|
else:
|
|
# Bright white/grey to show it's "ON"
|
|
text_color = QColor(220, 220, 220)
|
|
|
|
painter.setPen(text_color)
|
|
painter.setFont(QFont("Arial", 8, QFont.Bold))
|
|
painter.drawText(scroll_x + 10, y + 17, name)
|
|
|
|
# 4. DRAW TIME RULER TICKS (Right of the sticky sidebar)
|
|
target_spacing_px = 120
|
|
|
|
# Available units in frames: 1, 5, 15, 30 (1s), 150 (5s), 300 (10s), 1800 (1min)
|
|
possible_units = [1, 5, 15, 30, 150, 300, 900, 1800]
|
|
|
|
# Find the smallest unit that results in at least target_spacing_px
|
|
tick_interval = possible_units[-1]
|
|
for unit in possible_units:
|
|
if (unit * self.zoom_factor) >= target_spacing_px:
|
|
tick_interval = unit
|
|
break
|
|
|
|
# 2. DRAW BACKGROUNDS
|
|
painter.fillRect(0, 0, w, 20, QColor(45, 45, 45)) # Ruler Bar
|
|
|
|
# 3. DRAW TICKS AND TIME LABELS
|
|
painter.setPen(QColor(180, 180, 180))
|
|
painter.setFont(QFont("Segoe UI", 7))
|
|
|
|
# Sub-ticks (draw 5 small lines for every 1 major interval)
|
|
sub_interval = max(1, tick_interval // 5)
|
|
|
|
# Start loop from 0 to total frames
|
|
for f in range(0, total_f + 1, sub_interval):
|
|
x = self.label_width + int(f * self.zoom_factor)
|
|
|
|
# Optimization: Don't draw if off-screen
|
|
if x < scroll_x: continue
|
|
if x > scroll_x + w: break
|
|
|
|
if f % tick_interval == 0:
|
|
# Major Tick
|
|
painter.drawLine(x, 10, x, 20)
|
|
|
|
# Format Label: MM:SS or SS:FF
|
|
total_seconds = f / fps
|
|
minutes = int(total_seconds // 60)
|
|
seconds = int(total_seconds % 60)
|
|
frames = int(f % fps)
|
|
|
|
if tick_interval < fps:
|
|
time_str = f"{seconds:02d}:{frames:02d}f"
|
|
elif minutes > 0:
|
|
time_str = f"{minutes:02d}m:{seconds:02d}s"
|
|
else:
|
|
time_str = f"{seconds}s"
|
|
|
|
painter.drawText(x + 4, 12, time_str)
|
|
else:
|
|
# Minor Tick
|
|
painter.drawLine(x, 16, x, 20)
|
|
|
|
|
|
def get_ai_extractions(self):
|
|
"""
|
|
Processes timeline data for AI tracks and specific OBS sync events.
|
|
"""
|
|
debug_print()
|
|
fps = self.data.get("fps", 30.0)
|
|
|
|
extraction_data = {
|
|
"metadata": {
|
|
"fps": fps,
|
|
"total_frames": self.data.get("total_frames", 0),
|
|
"track_summaries": {}
|
|
},
|
|
"obs": {}, # New top-level key for specific OBS events
|
|
"ai_tracks": {}
|
|
}
|
|
|
|
if not self.data or "events" not in self.data:
|
|
return extraction_data
|
|
|
|
# 1. Extract the OBS: Time Sync event specifically
|
|
sync_key = "OBS: Time Sync"
|
|
if sync_key in self.data["events"]:
|
|
sync_blocks = self.data["events"][sync_key]
|
|
# Convert blocks to a list of dicts for the JSON
|
|
extraction_data["obs"][sync_key] = [
|
|
{
|
|
"start_frame": b[0],
|
|
"end_frame": b[1],
|
|
"start_time_sec": round(b[0] / fps, 3),
|
|
"end_time_sec": round(b[1] / fps, 3)
|
|
} for b in sync_blocks
|
|
]
|
|
|
|
# 2. Process AI Tracks
|
|
for track_name, blocks in self.data["events"].items():
|
|
if track_name.startswith("AI:"):
|
|
track_results = []
|
|
track_total = 0
|
|
track_long = 0
|
|
|
|
for block in blocks:
|
|
start_f, end_f = block[0], block[1]
|
|
severity = block[2] if len(block) > 2 else "Normal"
|
|
|
|
start_sec = round(start_f / fps, 3)
|
|
end_sec = round(end_f / fps, 3)
|
|
duration = round(end_sec - start_sec, 3)
|
|
|
|
track_total += 1
|
|
if duration > 2.0:
|
|
track_long += 1
|
|
|
|
track_results.append({
|
|
"start_frame": int(start_f),
|
|
"end_frame": int(end_f),
|
|
"start_time_sec": start_sec,
|
|
"end_time_sec": end_sec,
|
|
"duration_sec": duration,
|
|
"severity": severity
|
|
})
|
|
|
|
extraction_data["ai_tracks"][track_name] = track_results
|
|
extraction_data["metadata"]["track_summaries"][track_name] = {
|
|
"event_count": track_total,
|
|
"long_events_over_2s": track_long,
|
|
"total_duration_sec": round(sum(r["duration_sec"] for r in track_results), 3)
|
|
}
|
|
|
|
return extraction_data
|
|
|
|
|
|
class SkeletonOverlay(QWidget):
    """Transparent, click-through overlay that draws the pose for the current frame.

    Draws an optional dashed baseline pose (offset by the live pelvis midpoint)
    and the live skeleton. Limbs and joints are recoloured while an event on a
    matching angle or keypoint track covers the current frame; tracks listed in
    ``hidden_tracks`` are treated as inactive.
    """

    # Angle track -> limb segments forming that angle (checked in this order).
    _ANGLE_SEGMENTS = {
        "L_sh": [("Lhip", "Lshoulder"), ("Lshoulder", "Lelbow")],
        "R_sh": [("Rhip", "Rshoulder"), ("Rshoulder", "Relbow")],
        "L_el": [("Lshoulder", "Lelbow"), ("Lelbow", "Lwrist")],
        "R_el": [("Rshoulder", "Relbow"), ("Relbow", "Rwrist")],
        "L_leg": [("Lhip", "Lknee"), ("Lknee", "Lankle")],
        "R_leg": [("Rhip", "Rknee"), ("Rknee", "Rankle")],
    }

    # Keypoint at an angle's vertex -> that angle track.
    _VERTEX_ANGLE_TRACK = {
        "Lshoulder": "L_sh", "Rshoulder": "R_sh",
        "Lelbow": "L_el", "Relbow": "R_el",
        "Lknee": "L_leg", "Rknee": "R_leg",
    }

    def __init__(self, parent=None):
        debug_print()
        super().__init__(parent)
        self.setAttribute(Qt.WA_TransparentForMouseEvents)  # Clicks pass through to video
        self.data = None
        self.current_frame = 0
        self.hidden_tracks = set()

        # Index-based connection pairs (saved SKELETON_CONNECTIONS logic).
        # paintEvent draws from the name-based CONNECTIONS below instead.
        self.connections = [
            (5, 7), (7, 9), (6, 8), (8, 10), (5, 6), (5, 11),
            (6, 12), (11, 12), (11, 13), (13, 15), (12, 14), (14, 16)
        ]

        # Keypoint name -> index into a frame's keypoint array.
        self.KP_MAP = {
            "nose": 0, "LE": 1, "RE": 2, "Lear": 3, "Rear": 4,
            "Lshoulder": 5, "Rshoulder": 6, "Lelbow": 7, "Relbow": 8,
            "Lwrist": 9, "Rwrist": 10, "Lhip": 11, "Rhip": 12,
            "Lknee": 13, "Rknee": 14, "Lankle": 15, "Rankle": 16
        }

        # Limb segments drawn for both the baseline and the live skeleton.
        self.CONNECTIONS = [
            ("nose", "LE"), ("nose", "RE"), ("LE", "Lear"), ("RE", "Rear"),
            ("Lshoulder", "Rshoulder"), ("Lshoulder", "Lelbow"), ("Lelbow", "Lwrist"),
            ("Rshoulder", "Relbow"), ("Relbow", "Rwrist"), ("Lshoulder", "Lhip"),
            ("Rshoulder", "Rhip"), ("Lhip", "Rhip"), ("Lhip", "Lknee"),
            ("Lknee", "Lankle"), ("Rhip", "Rknee"), ("Rknee", "Rankle")
        ]

    def set_frame(self, frame_idx):
        """Show the pose for ``frame_idx`` and repaint."""
        debug_print()
        self.current_frame = frame_idx
        self.update()

    def set_hidden_tracks(self, hidden_set):
        """Store the set of hidden track names (by reference) and repaint."""
        debug_print()
        self.hidden_tracks = hidden_set
        self.update()

    def set_data(self, data):
        """Attach analysis data and repaint."""
        debug_print()
        self.data = data
        self.update()

    def paintEvent(self, event):
        """Draw the baseline and live skeleton for ``current_frame``.

        Reads ``raw_kps``, ``width``, ``height``, and optionally
        ``baseline_kp_mean`` and ``events`` from ``self.data``. Draws nothing
        when the frame is outside ``raw_kps`` or the video size is zero.
        """
        debug_print()
        if not self.data or 'raw_kps' not in self.data:
            return

        v_w, v_h = self.data['width'], self.data['height']
        raw_kps = self.data['raw_kps']
        current_f = self.current_frame
        # The player can be on a frame the pose data does not cover, and a zero
        # video size would divide by zero: draw nothing rather than raise.
        if not v_w or not v_h or not (0 <= current_f < len(raw_kps)):
            return

        painter = QPainter(self)
        painter.setRenderHint(QPainter.Antialiasing)

        scale_x, scale_y = self.width() / v_w, self.height() / v_h
        kp_live = raw_kps[current_f]
        events = self.data.get('events', {})

        def to_point(kps, kp_name):
            """Scale keypoint ``kp_name`` of ``kps`` from video to widget coordinates."""
            kp = kps[self.KP_MAP[kp_name]]
            return QPointF(kp[0] * scale_x, kp[1] * scale_y)

        def get_track_status(track_name):
            """Colour of the event on ``track_name`` covering this frame, or None.

            Hidden tracks and tracks not in TRACK_NAMES count as inactive.
            Block end frames are inclusive here.
            """
            if track_name in self.hidden_tracks or track_name not in TRACK_NAMES:
                return None
            for block in events.get(track_name, ()):
                start, end = block[0], block[1]
                severity = block[2] if len(block) > 2 else None
                if start <= current_f <= end:
                    color = QColor(TRACK_COLORS[TRACK_NAMES.index(track_name)])
                    color.setAlpha(80 if severity == "Small" else 160 if severity == "Moderate" else 255)
                    return color
            return None

        # --- 1. BASELINE (dashed; only if not hidden and available) ---
        if "Baseline" not in self.hidden_tracks and 'baseline_kp_mean' in self.data:
            # Baseline keypoints are offset by the live pelvis (hip midpoint);
            # presumably baseline_kp_mean is stored pelvis-relative — verify upstream.
            pelvis_live = (kp_live[self.KP_MAP["Lhip"]] + kp_live[self.KP_MAP["Rhip"]]) / 2
            kp_baseline = self.data['baseline_kp_mean'] + pelvis_live

            painter.setPen(QPen(QColor(200, 200, 200, 200), 2, Qt.DashLine))
            for s_name, e_name in self.CONNECTIONS:
                painter.drawLine(to_point(kp_baseline, s_name), to_point(kp_baseline, e_name))

        skeleton_visible = "Live Skeleton" not in self.hidden_tracks

        # --- 2. LIVE SKELETON: limbs ---
        for s_name, e_name in self.CONNECTIONS:
            active_color = None
            for angle_track, segments in self._ANGLE_SEGMENTS.items():
                if (s_name, e_name) in segments or (e_name, s_name) in segments:
                    active_color = get_track_status(angle_track)
                    if active_color is not None:
                        break

            p1, p2 = to_point(kp_live, s_name), to_point(kp_live, e_name)
            if active_color is not None:
                # Active events ALWAYS draw
                painter.setPen(QPen(active_color, 8, Qt.SolidLine, Qt.RoundCap))
                painter.drawLine(p1, p2)
            elif skeleton_visible:
                # Black lines ONLY draw if Live Skeleton is ON
                painter.setPen(QPen(Qt.black, 4, Qt.SolidLine, Qt.RoundCap))
                painter.drawLine(p1, p2)

        # --- 3. LIVE SKELETON: joints ---
        for kp_name in self.KP_MAP:
            pt = to_point(kp_live, kp_name)

            # Point event on this keypoint's own track
            point_color = get_track_status(kp_name)
            if point_color is not None:
                painter.setBrush(point_color)
                painter.setPen(QPen(Qt.white, 0.7))
                painter.drawEllipse(pt, 5, 5)
                continue

            # Angle event whose vertex is this keypoint
            angle_track = self._VERTEX_ANGLE_TRACK.get(kp_name)
            angle_color = get_track_status(angle_track) if angle_track else None

            if angle_color is not None:
                painter.setBrush(angle_color)
                painter.setPen(Qt.NoPen)
                painter.drawEllipse(pt, 4, 4)
            elif skeleton_visible:
                painter.setBrush(Qt.black)
                painter.setPen(Qt.NoPen)
                painter.drawEllipse(pt, 4, 4)

        painter.end()
|
|
|
|
|
|
class VideoView(QGraphicsView):
    """Borderless, black, scrollbar-free graphics view for the video scene.

    ``resized`` is emitted after every resize (once the base class has handled
    it), so owners can re-fit the video item and overlays.
    """

    resized = Signal()

    def __init__(self, scene, parent=None):
        debug_print()
        super().__init__(scene, parent)
        for apply_policy in (self.setHorizontalScrollBarPolicy, self.setVerticalScrollBarPolicy):
            apply_policy(Qt.ScrollBarAlwaysOff)
        self.setFrameStyle(0)
        self.setStyleSheet("background: black; border: none;")
        self.setAlignment(Qt.AlignCenter)

    def resizeEvent(self, event):
        debug_print()
        super().resizeEvent(event)
        self.resized.emit()
|
|
|
|
|
|
# ==========================================
|
|
# MAIN PREMIERE WINDOW
|
|
# ==========================================
|
|
|
|
class PremiereWindow(QMainWindow):
    """Main application window: a tabbed workspace of video analysis sessions.

    Tab 0 is a welcome page without a close button; each loaded video opens in
    its own VideoAnalysisTab. Edit-menu actions are routed to the current tab.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle(f"Pose Analysis Timeline - {APP_NAME}")
        self.resize(1200, 900)

        # Secondary-window handles, set before any UI or signal wiring so no
        # slot can ever observe them missing.
        self.load_window = None
        self.about = None
        self.help = None

        self.platform_suffix = "-" + PLATFORM_NAME

        # Application-wide Updaters
        self.updater = UpdateManager(
            main_window=self,
            api_url=API_URL,
            api_url_sec=API_URL_SECONDARY,
            current_version=CURRENT_VERSION,
            platform_name=PLATFORM_NAME,
            platform_suffix=self.platform_suffix,
            app_name=APP_NAME
        )

        self.setStyleSheet("""
            QMainWindow, QWidget#centralWidget { background-color: #1e1e1e; }
            QLabel, QStatusBar, QMenuBar { color: #ffffff; }
            QTabWidget::pane { border: 1px solid #333333; background: #1e1e1e; }
            QTabBar::tab { background: #2b2b2b; color: #aaa; padding: 8px 15px; border: 1px solid #333; border-bottom: none; border-top-left-radius: 4px; border-top-right-radius: 4px; }
            QTabBar::tab:selected { background: #3d3d3d; color: #fff; font-weight: bold; }
            QTabBar::tab:hover { background: #444; }
        """)

        # --- Tab System ---
        self.tabs = QTabWidget()
        self.tabs.setTabsClosable(True)
        self.tabs.tabCloseRequested.connect(self.close_tab)
        self.setCentralWidget(self.tabs)

        self.create_welcome_tab()
        self.create_menu_bar()

        # Update checks: a background thread looks for an already-downloaded
        # update and reports back to the updater.
        self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix, PLATFORM_NAME, APP_NAME)
        self.local_check_thread.pending_update_found.connect(self.updater.on_pending_update_found)
        self.local_check_thread.no_pending_update.connect(self.updater.on_no_pending_update)
        self.local_check_thread.start()

    def create_welcome_tab(self):
        """Add the welcome page as tab 0 and remove its close button."""
        welcome_widget = QWidget()
        layout = QVBoxLayout(welcome_widget)

        title = QLabel(f"Welcome to {APP_NAME}")
        title.setStyleSheet("font-size: 24px; font-weight: bold; color: #00aaff;")
        title.setAlignment(Qt.AlignCenter)

        subtitle = QLabel("Click 'File' > 'Load Video...' to begin a new analysis session.")
        subtitle.setStyleSheet("font-size: 14px; color: #aaaaaa;")
        subtitle.setAlignment(Qt.AlignCenter)

        layout.addStretch()
        layout.addWidget(title)
        layout.addWidget(subtitle)
        layout.addStretch()

        self.tabs.addTab(welcome_widget, "Welcome")
        # Disable the close button on the welcome tab
        self.tabs.tabBar().setTabButton(0, QTabBar.ButtonPosition.RightSide, None)

    def create_menu_bar(self):
        """Build the File/Edit/View menus and initialise the status bar."""
        menu_bar = self.menuBar()
        self.statusbar = self.statusBar()

        def make_action(name, shortcut=None, slot=None, checkable=False, checked=False, icon=None):
            """Create a QAction owned by this window with optional shortcut/slot/check/icon."""
            action = QAction(name, self)
            if shortcut:
                action.setShortcut(QKeySequence(shortcut))
            if slot:
                action.triggered.connect(slot)
            if checkable:
                action.setCheckable(True)
                action.setChecked(checked)
            if icon:
                action.setIcon(QIcon(icon))
            return action

        # File Menu
        file_menu = menu_bar.addMenu("File")
        file_menu.addAction(make_action("Load Video...", "Ctrl+O", self.open_load_video_dialog))
        file_menu.addSeparator()
        file_menu.addAction(make_action("Exit", "Ctrl+Q", QApplication.instance().quit))

        # Edit Menu (Routes to current tab)
        edit_menu = menu_bar.addMenu("Edit")
        edit_menu.addAction(make_action("Cut", "Ctrl+X", self.route_cut))
        edit_menu.addAction(make_action("Copy", "Ctrl+C", self.route_copy))
        edit_menu.addAction(make_action("Paste", "Ctrl+V", self.route_paste))

        # View Menu
        view_menu = menu_bar.addMenu("View")
        toggle_sb = make_action("Toggle Status Bar", checkable=True, checked=True)
        toggle_sb.toggled.connect(self.statusbar.setVisible)
        view_menu.addAction(toggle_sb)

        self.statusbar.showMessage("Ready")

    # --- Tab & Loading Logic ---

    def open_load_video_dialog(self):
        """Show the load-video window, creating a fresh one if none is visible."""
        if self.load_window is None or not self.load_window.isVisible():
            self.load_window = OpenFileWindow(self)
            # Connect the initialization button from OpenFileWindow to our tab creator
            self.load_window.btn_confirm.clicked.connect(self.handle_new_video_session)
        self.load_window.show()

    def handle_new_video_session(self):
        """Open a new VideoAnalysisTab from the load window's selections.

        Does nothing except show a status message when no video path was chosen,
        leaving the load window open so the user can pick one.
        """
        # Extract properties from the OpenFileWindow before closing it
        video_path = self.load_window.video_path
        if not video_path:
            self.statusbar.showMessage("No video selected")
            return
        obs_file = self.load_window.obs_file
        offset = self.load_window.current_video_offset

        self.load_window.close()

        # Create a new, independent tab, add it and switch to it
        new_tab = VideoAnalysisTab(video_path, obs_file, offset)
        index = self.tabs.addTab(new_tab, os.path.basename(video_path))
        self.tabs.setCurrentIndex(index)

    def close_tab(self, index):
        """Close the tab at ``index``, running its ``cleanup()`` first if it has one.

        The Welcome tab is never closed when it is the only tab left.
        """
        if index == 0 and self.tabs.count() == 1:
            return

        widget = self.tabs.widget(index)
        if widget:
            # Give the tab a chance to release resources (e.g. stop video)
            if hasattr(widget, 'cleanup'):
                widget.cleanup()
            widget.deleteLater()
        self.tabs.removeTab(index)

    # --- Routing Menu Actions to Current Tab ---

    def get_current_tab(self):
        """Return the widget of the currently selected tab."""
        return self.tabs.currentWidget()

    def _route_text_action(self, method_name, message):
        """Call ``method_name`` on the current tab's ``info_label`` (if any) and report it."""
        tab = self.get_current_tab()
        if hasattr(tab, 'info_label'):
            getattr(tab.info_label, method_name)()
            self.statusbar.showMessage(message)

    def route_copy(self):
        """Copy from the current tab's inspector text."""
        self._route_text_action("copy", "Copied to clipboard")

    def route_cut(self):
        """Cut from the current tab's inspector text."""
        self._route_text_action("cut", "Cut to clipboard")

    def route_paste(self):
        """Paste into the current tab's inspector text."""
        self._route_text_action("paste", "Pasted from clipboard")
|
|
|
|
|
|
|
|
from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QSplitter,
|
|
QScrollArea, QPushButton, QComboBox, QLabel,
|
|
QSlider, QTextEdit, QMessageBox)
|
|
from PySide6.QtCore import Qt, QSizeF, QUrl, Slot
|
|
from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput
|
|
from PySide6.QtMultimediaWidgets import QGraphicsVideoItem
|
|
|
|
class VideoAnalysisTab(QWidget):
|
|
def __init__(self, video_path, obs_file=None, offset=0.0, parent=None):
    """Create a self-contained analysis session tab for one video.

    video_path: media file loaded into this tab's player.
    obs_file:   optional observation file for the session (stored as given).
    offset:     stored as ``current_video_offset``.
    parent:     optional Qt parent widget.
    """
    super().__init__(parent)

    # --- Session state ---
    self.file_path, self.obs_file, self.current_video_offset = video_path, obs_file, offset
    self.predictor = GeneralPredictor()
    self.data = None  # filled in later by setup_workspace()

    self.setup_ui()
    # Automatic reprocessing on construction is currently disabled:
    # self.reprocess_current_video()
|
|
|
|
def setup_ui(self):
    """Build the tab: video view + skeleton overlay, ML and playback control rows,
    the inspector text pane and the event timeline, then wire the transport."""
    main_layout = QVBoxLayout(self)
    main_layout.setContentsMargins(0, 0, 0, 0)

    self.main_splitter = QSplitter(Qt.Vertical)
    top_splitter = QSplitter(Qt.Horizontal)

    # --- Video Area ---
    video_container = QWidget()
    video_layout = QVBoxLayout(video_container)

    self.scene = QGraphicsScene()
    self.view = VideoView(self.scene)
    self.view.resized.connect(self.update_video_geometry)

    self.video_item = QGraphicsVideoItem()
    self.scene.addItem(self.video_item)

    # Overlay is a child of the view's viewport so it sits on top of the video
    self.skeleton_overlay = SkeletonOverlay(self.view.viewport())

    self.player = QMediaPlayer()
    self.audio_output = QAudioOutput()
    self.player.setAudioOutput(self.audio_output)
    self.player.setVideoOutput(self.video_item)

    video_layout.addWidget(self.view)

    # --- Controls Area ---
    controls_container = QWidget()
    stacked_controls = QVBoxLayout(controls_container)
    stacked_controls.setSpacing(5)  # Tight spacing between rows

    # --- ROW 1: ML model / training (click handlers not connected yet) ---
    ml_row = QHBoxLayout()
    ml_row.addStretch()

    ml_row.addWidget(QLabel("ML Model:"))
    self.ml_dropdown = QComboBox()
    self.ml_dropdown.addItems(["Random Forest", "LSTM", "XGBoost", "SVM", "1D-CNN"])
    ml_row.addWidget(self.ml_dropdown)

    ml_row.addWidget(QLabel("Target:"))
    self.target_dropdown = QComboBox()
    self.target_dropdown.addItems(["Mouthing", "Head Movement", "Kick (Left)", "Kick (Right)", "Reach (Left)", "Reach (Right)"])
    #self.target_dropdown.currentTextChanged.connect(self.update_predictor_target)
    ml_row.addWidget(self.target_dropdown)

    self.btn_add_to_pool = QPushButton("Add to Pool")
    #self.btn_add_to_pool.clicked.connect(self.add_current_to_ml_pool)
    self.btn_add_to_pool.setFixedWidth(120)
    ml_row.addWidget(self.btn_add_to_pool)

    self.btn_train_final = QPushButton("Train Global Model")
    self.btn_train_final.setStyleSheet("background-color: #2e7d32; font-weight: bold;")
    #self.btn_train_final.clicked.connect(self.run_final_training)
    ml_row.addWidget(self.btn_train_final)

    self.lbl_pool_status = QLabel("Pool: 0 Participants")
    self.lbl_pool_status.setStyleSheet("color: #00FF00; font-weight: bold; margin-left: 10px;")
    self.lbl_pool_status.setFixedWidth(160)
    ml_row.addWidget(self.lbl_pool_status)

    self.btn_clear_pool = QPushButton("Clear Pool")
    self.btn_clear_pool.setFixedWidth(100)
    self.btn_clear_pool.setStyleSheet("color: #ff5555; border: 1px solid #ff5555;")
    #self.btn_clear_pool.clicked.connect(self.clear_ml_pool)
    ml_row.addWidget(self.btn_clear_pool)

    self.btn_extract_ai = QPushButton("Extract AI Data")
    #self.btn_extract_ai.clicked.connect(self.extract_ai_to_json)
    ml_row.addWidget(self.btn_extract_ai)

    ml_row.addStretch()

    # --- ROW 2: Playback & Transport ---
    playback_row = QHBoxLayout()
    playback_row.addStretch()

    # Transport buttons (disabled until setup_transport enables them)
    self.btn_start = QPushButton("|<")
    self.btn_prev = QPushButton("<")
    self.btn_play = QPushButton("Play")
    self.btn_next = QPushButton(">")
    self.btn_end = QPushButton(">|")

    self.transport_btns = [self.btn_start, self.btn_prev, self.btn_play, self.btn_next, self.btn_end]
    for btn in self.transport_btns:
        btn.setEnabled(False)
        btn.setFixedWidth(50)
        playback_row.addWidget(btn)

    self.btn_mute = QPushButton("Vol")
    self.btn_mute.setFixedWidth(40)
    self.btn_mute.setCheckable(True)
    self.btn_mute.clicked.connect(self.toggle_mute)

    self.sld_volume = QSlider(Qt.Horizontal)
    self.sld_volume.setRange(0, 100)
    self.sld_volume.setValue(70)  # Default volume (70%), matching the audio output below
    self.sld_volume.setFixedWidth(100)
    self.sld_volume.valueChanged.connect(self.update_volume)

    # Initialize the output from the slider so the two start in agreement
    self.audio_output.setVolume(self.sld_volume.value() / 100.0)

    playback_row.addWidget(self.btn_mute)
    playback_row.addWidget(self.sld_volume)

    # Counters
    counter_style = "font-family: 'Consolas'; font-size: 10pt; margin-left: 5px; color: #00FF00;"
    self.lbl_time_counter = QLabel("Time: 00:00 / 00:00")
    self.lbl_frame_counter = QLabel("Frame: 0 / 0")
    for counter in (self.lbl_time_counter, self.lbl_frame_counter):
        counter.setFixedWidth(180)
        counter.setStyleSheet(counter_style)
        playback_row.addWidget(counter)

    playback_row.addStretch()

    # --- Add Rows to Stack ---
    stacked_controls.addLayout(ml_row)
    stacked_controls.addLayout(playback_row)

    video_layout.addWidget(controls_container)

    # --- Inspector & Timeline ---
    self.info_label = QTextEdit()
    self.info_label.setReadOnly(True)

    self.timeline = TimelineWidget()
    # The timeline emits frame indices, while QMediaPlayer positions are in
    # milliseconds, so route through the frame -> ms converter.
    self.timeline.seek_requested.connect(self._seek_to_frame)

    top_splitter.addWidget(video_container)
    top_splitter.addWidget(self.info_label)

    self.main_splitter.addWidget(top_splitter)
    self.main_splitter.addWidget(self.timeline)
    main_layout.addWidget(self.main_splitter)

    self.setup_transport()  # Start with empty workspace until worker finishes

def _seek_to_frame(self, frame):
    """Seek the player to timeline ``frame`` (uses the data fps, else 30)."""
    fps = (self.data.get("fps", 30.0) if self.data else 30.0) or 30.0
    self.seek_video(int(round(frame * 1000.0 / fps)))
|
|
|
|
def setup_transport(self):
|
|
"""Sets up player controls that don't depend on skeleton analysis."""
|
|
# Enable buttons immediately
|
|
for btn in [self.btn_play, self.btn_prev, self.btn_next, self.btn_start, self.btn_end]:
|
|
btn.setEnabled(True)
|
|
|
|
# Connections (Use disconnect first to avoid double-firing if re-called)
|
|
try: self.btn_play.clicked.disconnect()
|
|
except: pass
|
|
|
|
self.btn_play.clicked.connect(self.toggle_playback)
|
|
self.btn_start.clicked.connect(lambda: self.player.setPosition(0))
|
|
# Note: 'End' and 'Step' need FPS/Duration, handled in the methods themselves
|
|
self.btn_end.clicked.connect(lambda: self.player.setPosition(self.player.duration()))
|
|
self.btn_prev.clicked.connect(lambda: self.step_frame(-1))
|
|
self.btn_next.clicked.connect(lambda: self.step_frame(1))
|
|
|
|
self.player.setSource(QUrl.fromLocalFile(self.file_path))
|
|
|
|
|
|
@Slot(dict)
|
|
def setup_workspace(self, analyzed_data):
|
|
"""Only handles skeleton/analysis-specific data."""
|
|
self.data = analyzed_data
|
|
|
|
# Update timeline and overlay now that we have data
|
|
if hasattr(self, 'timeline'):
|
|
self.timeline.set_data(self.data)
|
|
|
|
if hasattr(self, 'skeleton_overlay'):
|
|
self.skeleton_overlay.set_data(self.data)
|
|
|
|
# Trigger a geometry refresh now that we have accurate video dims from data
|
|
self.update_video_geometry()
|
|
print(f"Analysis complete for {self.file_path}")
|
|
|
|
def update_status(self, message):
|
|
"""Updates the inspector or a status bar with worker progress."""
|
|
self.info_label.append(message)
|
|
|
|
def toggle_playback(self):
|
|
if self.player.playbackState() == QMediaPlayer.PlayingState:
|
|
self.player.pause()
|
|
self.btn_play.setText("Play")
|
|
else:
|
|
self.player.play()
|
|
self.btn_play.setText("Pause")
|
|
|
|
def seek_video(self, ms):
|
|
self.player.setPosition(ms)
|
|
|
|
def step_frame(self, delta):
|
|
# Fallback to 30 FPS if worker data isn't ready
|
|
fps = self.data["fps"] if (hasattr(self, 'data') and self.data) else 30.0
|
|
|
|
current_ms = self.player.position()
|
|
# One frame in ms = 1000 / fps
|
|
frame_ms = 1000.0 / fps
|
|
target_ms = int(current_ms + (delta * frame_ms))
|
|
|
|
# Ensure we don't seek past duration
|
|
target_ms = max(0, min(target_ms, self.player.duration()))
|
|
self.player.setPosition(target_ms)
|
|
|
|
def update_volume(self, value):
|
|
# QAudioOutput expects a float between 0.0 and 1.0
|
|
volume = value / 100.0
|
|
self.audio_output.setVolume(volume)
|
|
|
|
# Auto-unmute if user moves the slider
|
|
if self.btn_mute.isChecked() and value > 0:
|
|
self.btn_mute.setChecked(False)
|
|
self.toggle_mute()
|
|
|
|
def toggle_mute(self):
|
|
is_muted = self.btn_mute.isChecked()
|
|
self.audio_output.setMuted(is_muted)
|
|
self.btn_mute.setText("Mute" if is_muted else "Vol")
|
|
# Optional: Dim the slider when muted
|
|
self.sld_volume.setEnabled(not is_muted)
|
|
|
|
def update_video_geometry(self):
|
|
if not hasattr(self, "video_item"):
|
|
return
|
|
|
|
# 1. Get viewport dimensions
|
|
viewport_rect = self.view.viewport().rect()
|
|
v_w, v_h = viewport_rect.width(), viewport_rect.height()
|
|
if v_w <= 0 or v_h <= 0:
|
|
return
|
|
|
|
# 2. Get Video Dimensions (Fall back to native size if worker data is missing)
|
|
if hasattr(self, "data") and self.data:
|
|
video_w, video_h = self.data['width'], self.data['height']
|
|
else:
|
|
native_size = self.video_item.nativeSize()
|
|
video_w, video_h = native_size.width(), native_size.height()
|
|
|
|
# If the video hasn't loaded metadata yet, it will be -1 or 0
|
|
if video_w <= 0 or video_h <= 0:
|
|
return
|
|
|
|
# 3. Calculate Aspect Ratio Scaling
|
|
aspect = video_w / video_h
|
|
if v_w / v_h > aspect:
|
|
target_h = v_h
|
|
target_w = int(v_h * aspect)
|
|
else:
|
|
target_w = v_w
|
|
target_h = int(v_w / aspect)
|
|
|
|
x_off = (v_w - target_w) / 2
|
|
y_off = (v_h - target_h) / 2
|
|
|
|
# 4. Apply transformations
|
|
self.scene.setSceneRect(0, 0, v_w, v_h)
|
|
self.video_item.setPos(x_off, y_off)
|
|
self.video_item.setSize(QSizeF(target_w, target_h))
|
|
|
|
# Only update overlay if it exists and we have data
|
|
if hasattr(self, "skeleton_overlay"):
|
|
self.skeleton_overlay.setGeometry(int(x_off), int(y_off), target_w, target_h)
|
|
|
|
def resizeEvent(self, event):
|
|
# debug_print()
|
|
|
|
super().resizeEvent(event)
|
|
self.update_video_geometry()
|
|
if hasattr(self, 'timeline'):
|
|
self.timeline.set_zoom(self.timeline.zoom_factor)
|
|
|
|
def cleanup(self):
|
|
if self.player:
|
|
self.player.stop()
|
|
if hasattr(self, 'worker') and self.worker.isRunning():
|
|
self.worker.terminate()
|
|
|
|
|
|
def resource_path(relative_path):
    """
    Return the absolute path of a bundled resource.

    Inside a PyInstaller bundle the base directory is ``sys._MEIPASS``;
    when running from source it is the directory containing this file.
    """
    base_path = (
        sys._MEIPASS
        if hasattr(sys, '_MEIPASS')
        else os.path.dirname(os.path.abspath(__file__))
    )
    return os.path.join(base_path, relative_path)
|
|
|
|
|
|
def kill_child_processes():
    """
    Forcefully kill every descendant process of this process (best effort).

    Children that have already exited or cannot be signalled are skipped so
    the remaining ones are still killed. Any other failure is printed rather
    than raised, because this is called from the crash handler.
    """
    try:
        children = psutil.Process(os.getpid()).children(recursive=True)
        for child in children:
            try:
                child.kill()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                # Already gone, or not killable by us: continue with the rest.
                pass
        # Wait up to 5 seconds for the killed processes to exit.
        psutil.wait_procs(children, timeout=5)
    except Exception as e:
        print(f"Error killing child processes: {e}")
|
|
|
|
|
|
def exception_hook(exc_type, exc_value, exc_traceback):
    """
    Handle an uncaught exception: print it, kill child processes, show a crash
    dialog, then exit with status 1.

    Installed as ``sys.excepthook`` in the ``__main__`` block. The process
    exits even if creating or showing the dialog fails, so a broken error path
    cannot leave the application running with its child processes killed.
    """
    error_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
    print(error_msg)  # also print to console / log

    kill_child_processes()

    try:
        # The dialog needs a QApplication; create a minimal one if none exists.
        app = QApplication.instance()
        if app is None:
            app = QApplication(sys.argv)

        show_critical_error(error_msg)
    except Exception:
        print("Failed to display the crash dialog:")
        traceback.print_exc()
    finally:
        # Exit once the user has acknowledged the dialog (or it failed).
        sys.exit(1)
|
|
|
|
def show_critical_error(error_msg):
    """
    Show a modal "Something went wrong!" dialog for an unrecoverable error.

    The application log is copied to ``<APP_NAME>_error.log`` so the user can
    attach it to an issue. The dialog links to that copy, to the autosave
    location and to the issue tracker, and includes the traceback text.
    Blocks until the user dismisses the dialog.

    Args:
        error_msg: Formatted traceback text to display.
    """
    import html

    msg_box = QMessageBox()
    msg_box.setIcon(QMessageBox.Icon.Critical)
    msg_box.setWindowTitle("Something went wrong!")

    # Same directory the __main__ block writes the log to: three levels above
    # the executable on macOS, the working directory elsewhere.
    if PLATFORM_NAME == "darwin":
        base_dir = os.path.join(os.path.dirname(sys.executable), "../../..")
    else:
        base_dir = os.getcwd()

    # Must match the log file name opened in __main__ ({APP_NAME}.log).
    log_path = os.path.join(base_dir, f"{APP_NAME}.log")
    error_log_path = os.path.join(base_dir, f"{APP_NAME}_error.log")
    # NOTE(review): the autosave name/extension were carried over from FLARES;
    # confirm what this application actually writes.
    save_path = os.path.join(base_dir, "flares_autosave.flare")

    try:
        shutil.copy(log_path, error_log_path)
    except OSError as exc:
        # A missing or locked log must not stop the dialog from appearing;
        # link to the live log instead.
        print(f"Could not copy {log_path} to {error_log_path}: {exc}")
        error_log_path = log_path

    error_log = Path(os.path.abspath(error_log_path))
    autosave = Path(os.path.abspath(save_path))
    log_path2 = error_log.as_posix()
    autosave_path = autosave.as_posix()
    # as_uri() yields well-formed, percent-encoded file:// URLs on every platform.
    log_link = html.escape(error_log.as_uri())
    autosave_link = html.escape(autosave.as_uri())

    # Everything interpolated into the rich-text message is HTML-escaped;
    # otherwise traceback text such as "<module>" is swallowed as a tag.
    message = (
        f"{APP_NAME.upper()} has encountered an unrecoverable error and needs to close.<br><br>"
        f"We are sorry for the inconvenience. An autosave was attempted to be saved to <a href='{autosave_link}'>{html.escape(autosave_path)}</a>, but it may not have been saved. "
        "If the file was saved, it still may not be intact, openable, or contain the correct data. Use the autosave at your discretion.<br><br>"
        f"This unrecoverable error was likely due to an error with {APP_NAME.upper()} and not your data.<br>"
        f"Please raise an issue <a href='https://git.research.dezeeuw.ca/tyler/{APP_NAME}/issues'>here</a> and attach the error file located at <a href='{log_link}'>{html.escape(log_path2)}</a><br><br>"
        f"<pre>{html.escape(error_msg)}</pre>"
    )

    msg_box.setTextFormat(Qt.TextFormat.RichText)
    msg_box.setText(message)
    msg_box.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
    msg_box.setStandardButtons(QMessageBox.StandardButton.Ok)

    msg_box.exec()
|
|
|
|
if __name__ == "__main__":
    # Route uncaught exceptions to the crash dialog.
    sys.excepthook = exception_hook

    # Application log: three levels above the executable on macOS, the
    # working directory elsewhere.
    if PLATFORM_NAME == "darwin":
        log_path = os.path.join(os.path.dirname(sys.executable), f"../../../{APP_NAME}.log")
    else:
        log_path = os.path.join(os.getcwd(), f"{APP_NAME}.log")

    # Start each run with a fresh log. A missing file is expected; if removal
    # fails for another OS reason, logging simply appends to the existing file.
    try:
        os.remove(log_path)
    except OSError:
        pass

    # Explicit UTF-8 so non-ASCII text (paths, messages, tracebacks) can always
    # be written regardless of the platform's default encoding; line-buffered
    # so the log stays current if the process dies.
    sys.stdout = open(log_path, "a", buffering=1, encoding="utf-8", errors="replace")
    sys.stderr = sys.stdout
    print(f"\n=== App started at {datetime.now()} ===\n")

    # Required for PyInstaller + multiprocessing.
    # NOTE(review): the multiprocessing docs ask for this call straight after
    # the `if __name__ == "__main__":` line. In a frozen build, child processes
    # presumably re-run this block, so the log setup above (including the
    # os.remove) runs in them first — confirm that is intended.
    freeze_support()

    # Only run the GUI in the main process.
    if current_process().name == 'MainProcess':
        app = QApplication(sys.argv)
        finish_update_if_needed(PLATFORM_NAME, APP_NAME)
        window = PremiereWindow()

        icon_file = "icons/main.icns" if PLATFORM_NAME == "darwin" else "icons/main.ico"
        icon = QIcon(resource_path(icon_file))
        app.setWindowIcon(icon)
        window.setWindowIcon(icon)

        window.show()
        sys.exit(app.exec())
|
|
|
|
# Not 6000 lines yay! |