2383 lines
92 KiB
Python
2383 lines
92 KiB
Python
"""
|
|
Filename: main.py
|
|
Description: BLAZES main executable
|
|
|
|
Author: Tyler de Zeeuw
|
|
License: GPL-3.0
|
|
"""
|
|
|
|
# Built-in imports
|
|
import os
|
|
import csv
|
|
import sys
|
|
import json
|
|
import glob
|
|
import shutil
|
|
import inspect
|
|
import platform
|
|
import traceback
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from multiprocessing import current_process, freeze_support
|
|
|
|
# External library imports
|
|
import numpy as np
|
|
import pandas as pd
|
|
import psutil
|
|
import joblib
|
|
import cv2
|
|
from ultralytics import YOLO
|
|
|
|
from updater import finish_update_if_needed, UpdateManager, LocalPendingUpdateCheckThread
|
|
from predictor import GeneralPredictor
|
|
from batch_processing import BatchProcessorDialog
|
|
|
|
import PySide6
|
|
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QGraphicsView, QGraphicsScene,
|
|
QHBoxLayout, QSplitter, QLabel, QPushButton, QComboBox, QInputDialog,
|
|
QFileDialog, QScrollArea, QMessageBox, QSlider, QTextEdit)
|
|
from PySide6.QtCore import Qt, QThread, Signal, QUrl, QRectF, QPointF, QRect, QSizeF
|
|
from PySide6.QtGui import QPainter, QColor, QFont, QPen, QBrush, QAction, QKeySequence, QIcon, QTextOption
|
|
from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput
|
|
from PySide6.QtMultimediaWidgets import QGraphicsVideoItem
|
|
|
|
|
|
VERBOSITY = 1  # non-zero enables debug_print() call tracing throughout the app

CURRENT_VERSION = "0.1.0"  # shown in the About window and used for update checks

APP_NAME = "blazes"

# Gitea release API endpoints (primary and secondary mirror).
# NOTE(review): presumably consumed by the UpdateManager imported from `updater` — confirm.
API_URL = f"https://git.research.dezeeuw.ca/api/v1/repos/tyler/{APP_NAME}/releases"

API_URL_SECONDARY = f"https://git.research2.dezeeuw.ca/api/v1/repos/tyler/{APP_NAME}/releases"

PLATFORM_NAME = platform.system().lower()  # e.g. "windows", "linux", "darwin"
|
|
|
|
|
|
|
|
def debug_print():
    """Print the qualified name of the calling function when VERBOSITY is non-zero.

    Acts as a lightweight call tracer: drop a bare ``debug_print()`` at the
    top of a function and its ``co_qualname`` is echoed to stdout.
    """
    if VERBOSITY:
        frame = inspect.currentframe()
        # currentframe() is documented to return None on Python
        # implementations without stack-frame support; guard instead of
        # crashing the tracer. f_back is None only at module level.
        caller = frame.f_back if frame is not None else None
        if caller is not None:
            print(caller.f_code.co_qualname)
|
|
|
|
|
|
# Ordered according to YOLO docs: https://docs.ultralytics.com/tasks/pose/
# Index 0-16 must match the YOLO/COCO keypoint order — code elsewhere indexes
# by position (e.g. hips at 11/12), so do not reorder this list.
JOINT_NAMES = [
    "Nose", "Left Eye", "Right Eye", "Left Ear", "Right Ear",
    "Left Shoulder", "Right Shoulder", "Left Elbow", "Right Elbow",
    "Left Wrist", "Right Wrist", "Left Hip", "Right Hip",
    "Left Knee", "Right Knee", "Left Ankle", "Right Ankle"
]
|
|
|
|
|
|
# Needs to be pointed to the FFmpeg bin folder containing avcodec-*.dll, etc.
pyside_dir = Path(PySide6.__file__).parent
if sys.platform == "win32":
    # Tell Python 3.13+ where to find the FFmpeg DLLs bundled with PySide
    os.add_dll_directory(str(pyside_dir))


# Timeline rows: two synthetic tracks followed by one row per pose joint.
# NOTE: these lists are mutated at runtime (OBS/AI tracks get appended).
TRACK_NAMES = ["Baseline", "Live Skeleton"] + JOINT_NAMES
NUM_TRACKS = len(TRACK_NAMES)

# TODO: Improve colors?
# Generate distinct colors for the tracks
BASE_COLORS = [QColor(180, 180, 180), QColor(0, 0, 0)]  # Grey for Baseline, Black for Live
# Evenly spaced hues around the HSV wheel for the per-joint tracks.
REMAINING_COLORS = [QColor.fromHsv(int((i / (NUM_TRACKS-2)) * 359), 200, 255) for i in range(NUM_TRACKS-2)]
TRACK_COLORS = BASE_COLORS + REMAINING_COLORS
|
|
|
|
|
|
|
|
|
|
class AboutWindow(QWidget):
    """
    Standalone About window showing the application name, tagline,
    license notice, and version string.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle(f"About {APP_NAME.upper()}")
        self.resize(250, 100)
        self.setStyleSheet("""
            QVBoxLayout, QWidget {
                background-color: #1e1e1e;
            }
            QLabel {
                color: #ffffff;
            }
        """)

        box = QVBoxLayout()

        # Parented labels, added in display order.
        for text in (
            f"About {APP_NAME.upper()}",
            "Behavioral Learning & Automated Zoned Events Suite",
            f"{APP_NAME.upper()} is licensed under the GPL-3.0 licence. For more information, visit https://www.gnu.org/licenses/gpl-3.0.en.html",
        ):
            box.addWidget(QLabel(text, self))

        # Version label intentionally created without an explicit parent.
        box.addWidget(QLabel(f"Version v{CURRENT_VERSION}"))

        self.setLayout(box)
|
|
|
|
|
class UserGuideWindow(QWidget):
    """
    Standalone User Guide window; currently a placeholder that links to
    the project's Git wiki.

    Args:
        parent (QWidget, optional): Parent widget of this window. Defaults to None.
    """

    def __init__(self, parent=None):
        super().__init__(parent, Qt.WindowType.Window)
        self.setWindowTitle(f"User Guide - {APP_NAME.upper()}")
        self.resize(250, 100)
        self.setStyleSheet("""
            QVBoxLayout, QWidget {
                background-color: #1e1e1e;
            }
            QLabel {
                color: #ffffff;
            }
        """)

        box = QVBoxLayout()
        heading = QLabel("Hmmm...", self)
        body = QLabel("Nothing to see here yet.", self)

        # Rich-text label so the wiki hyperlink is clickable and opens
        # in the system browser.
        wiki = QLabel(f"For more information, visit the Git wiki page <a href='https://git.research.dezeeuw.ca/tyler/{APP_NAME}/wiki'>here</a>.", self)
        wiki.setTextFormat(Qt.TextFormat.RichText)
        wiki.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
        wiki.setOpenExternalLinks(True)

        for widget in (heading, body, wiki):
            box.addWidget(widget)

        self.setLayout(box)
|
|
|
|
|
|
|
|
class PoseAnalyzerWorker(QThread):
    """Background worker that extracts pose data from a video and builds timeline events.

    Pipeline (see run()):
      1. YOLO pose tracking over the whole video, cached to a
         ``*_pose_raw.csv`` file next to the video.
      2. Pelvis-centred keypoints are z-scored against a whole-video baseline.
      3. Timeline blocks are generated from the kinematics, merged with manual
         JSON observations (if provided), and augmented with predictions from
         any trained ``*_rf.pkl`` models found in the working directory.

    Signals:
        progress (str): human-readable status updates for the UI.
        finished_data (dict): full analysis payload (see run() for keys).
    """

    progress = Signal(str)
    finished_data = Signal(dict)

    def __init__(self, video_path, obs_info=None, predictor=None):
        """
        Args:
            video_path (str): path to the video to analyze.
            obs_info (tuple, optional): (json_path, subkey) locating manual
                observations inside the JSON file, or None.
            predictor (GeneralPredictor, optional): used for on-the-fly ML
                feature extraction during inference.
        """
        debug_print()
        super().__init__()
        self.video_path = video_path
        self.obs_info = obs_info
        self.predictor = predictor
        # Per-frame keypoint table (one column per joint coord/conf);
        # populated by run() so the Inspector can read it from memory.
        self.pose_df = pd.DataFrame()

    def get_best_infant_match(self, results, w, h, prev_track_id):
        """Pick the most likely infant among tracked detections in a YOLO result.

        Each detection is scored by keypoint visibility, distance of its
        visible-keypoint centroid from the frame centre, and a bonus for
        continuing the previous frame's track ID.

        Args:
            results: YOLO tracking results; only results[0] is used.
            w, h: frame width/height in pixels.
            prev_track_id: track ID selected on the previous frame (or None).

        Returns:
            (track_id, keypoints_xy, keypoint_confs, index), or
            (None, None, None, None) when nothing usable was detected.
        """
        debug_print()
        if not results[0].boxes or results[0].boxes.id is None:
            return None, None, None, None
        ids = results[0].boxes.id.int().cpu().tolist()
        kpts = results[0].keypoints.xy.cpu().numpy()
        confs = results[0].keypoints.conf.cpu().numpy()
        best_idx, best_score = -1, -1
        for i, k in enumerate(kpts):
            vis = np.sum(confs[i] > 0.5)
            valid = k[confs[i] > 0.5]
            # 1000 acts as a "very far away" penalty when no keypoint is visible.
            dist = np.linalg.norm(np.mean(valid, axis=0) - [w/2, h/2]) if len(valid) > 0 else 1000
            # Visibility dominates; the +50 bonus keeps the tracker locked
            # onto the same subject between frames.
            score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0)
            if score > best_score:
                best_score, best_idx = score, i
        if best_idx == -1:
            return None, None, None, None
        return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx

    def _merge_json_observations(self, timeline_events, fps):
        """Restores the grouping and block-pairing logic from the observation files.

        Events are grouped into per-label "OBS:" tracks (split into
        Left/Right variants when the event's "special" field says so), then
        paired into (start, end) blocks. New tracks are registered in the
        global TRACK_NAMES / TRACK_COLORS lists.
        """
        debug_print()
        if not self.obs_info:
            return

        self.progress.emit("Merging JSON Observations...")
        json_path, subkey = self.obs_info

        try:
            with open(json_path, 'r') as f:
                full_json = json.load(f)

            # Extract events for the specific subkey (e.g., 'Participant_01')
            raw_obs_events = full_json["observations"][subkey]["events"]
            raw_obs_events.sort(key=lambda x: x[0])  # Sort by timestamp

            # track_name -> [list of frames]
            obs_groups = {}

            for ev in raw_obs_events:
                # ev structure: [time_sec, unknown, label, special]
                time_sec, label, special = ev[0], ev[2], ev[3]
                frame = int(time_sec * fps)

                # Determine which tracks this event belongs to
                target_tracks = []

                if special == "Left":
                    target_tracks.append(f"OBS: {label} (Left)")
                elif special == "Right":
                    target_tracks.append(f"OBS: {label} (Right)")
                elif special == "Both":
                    target_tracks.append(f"OBS: {label} (Left)")
                    target_tracks.append(f"OBS: {label} (Right)")
                else:
                    # No special or unrecognized value
                    target_tracks.append(f"OBS: {label}")

                # Add the frame to all applicable tracks
                for t_name in target_tracks:
                    if t_name not in obs_groups:
                        obs_groups[t_name] = []
                    obs_groups[t_name].append(frame)

            # Convert frame groups into (Start, End) blocks
            for track_name, frames in obs_groups.items():
                processed_blocks = []

                if "Sync" in track_name and len(frames) == 1:
                    # A lone sync marker gets a 1-frame width so it is
                    # visible on the timeline.
                    start_f = frames[0]
                    end_f = start_f + 1
                    processed_blocks.append((start_f, end_f, "Moderate", "Manual"))
                else:
                    # Step by 2 to create start/end pairs (ensures matching
                    # pairs per track); a trailing unpaired frame is dropped.
                    for i in range(0, len(frames) - 1, 2):
                        start_f = frames[i]
                        end_f = frames[i+1]
                        processed_blocks.append((start_f, end_f, "Moderate", "Manual"))

                # Register the track in global lists if not already there
                if track_name not in TRACK_NAMES:
                    TRACK_NAMES.append(track_name)
                    # Using Purple for Observations
                    TRACK_COLORS.append(QColor("#AA00FF"))

                timeline_events[track_name] = processed_blocks

        except Exception as e:
            print(f"Error parsing JSON Observations: {e}")

    def _run_existing_ml_models(self, z_kps, dirs, raw_kpts):
        """Scan for trained models and generate timeline tracks for each.

        Args:
            z_kps: per-frame, per-joint z-scores.
            dirs: per-frame, per-joint direction labels.
            raw_kpts: per-frame raw keypoint coordinates.

        Returns:
            dict mapping "AI: <target>" track names to lists of
            (start, end, severity, source) blocks.
        """
        debug_print()
        ai_events = {}

        # 1. Match the pattern from GeneralPredictor: {Target}_rf.pkl
        model_files = glob.glob("*_rf.pkl")
        print(f"DEBUG: Found model files: {model_files}")

        for model_path in model_files:
            try:
                # Extract Target (e.g., "Mouthing" from "Mouthing_rf.pkl")
                base_name = model_path.split("_rf.pkl")[0]
                target = base_name.replace("ml_", "", 1)
                track_name = f"AI: {target}"

                self.progress.emit(f"Loading AI Observations for {target}...")

                # 2. The companion scaler is expected at {base_name}_rf_scaler.pkl
                scaler_path = f"{base_name}_rf_scaler.pkl"

                if not os.path.exists(scaler_path):
                    print(f"DEBUG: Skipping {target}, scaler not found at {scaler_path}")
                    continue

                # Load assets
                model = joblib.load(model_path)
                scaler = joblib.load(scaler_path)

                # 3. Feature extraction (On-the-fly)
                all_features = []
                # We must set the predictor's target so format_features uses
                # the correct ACTIVITY_MAP
                self.predictor.current_target = target

                for f_idx in range(len(z_kps)):
                    feat = self.predictor.format_features(z_kps[f_idx], dirs[f_idx], raw_kpts[f_idx])
                    all_features.append(feat)

                # 4. Inference
                X = np.array(all_features)
                X_scaled = scaler.transform(X)
                predictions = model.predict(X_scaled)

                # 5. Convert binary 0/1 predictions into contiguous blocks
                processed_blocks = []
                start_f = None

                for f_idx, val in enumerate(predictions):
                    if val == 1 and start_f is None:
                        start_f = f_idx
                    elif val == 0 and start_f is not None:
                        # [start, end, severity, direction]
                        processed_blocks.append((start_f, f_idx - 1, "Large", "AI"))
                        start_f = None

                # Close a block that runs to the final frame.
                if start_f is not None:
                    processed_blocks.append((start_f, len(predictions)-1, "Large", "AI"))

                # 6. Global Registration
                if track_name not in TRACK_NAMES:
                    TRACK_NAMES.append(track_name)
                    # Ensure TRACK_COLORS has an entry for this new track
                    TRACK_COLORS.append(QColor("#00FF00"))

                ai_events[track_name] = processed_blocks
                print(f"DEBUG: Successfully generated {len(processed_blocks)} blocks for {track_name}")

            except Exception as e:
                print(f"Inference Error for {model_path}: {e}")

        return ai_events

    def classify_delta(self, z):
        """Bucket an absolute z-score into a movement severity label."""
        # debug_print()  # intentionally disabled: called once per frame per joint
        z_abs = abs(z)
        if z_abs < 1: return "Rest"
        elif z_abs < 2: return "Small"
        elif z_abs < 3: return "Moderate"
        else: return "Large"

    def _save_pose_cache(self, path, data):
        """
        Saves the raw YOLO keypoints and confidence scores to a CSV.
        Each row represents one frame, flattened from (17, 3) to (51,).
        """
        try:
            with open(path, 'w', newline='') as f:
                writer = csv.writer(f)

                # Create the descriptive header
                header = []
                for joint in JOINT_NAMES:
                    # Replace spaces with underscores for better compatibility
                    # with other tools
                    header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"])

                writer.writerow(header)

                # Write the frame data
                for frame_data in data:
                    # frame_data is (17, 3), flatten to (51,)
                    writer.writerow(frame_data.flatten())

            print(f"DEBUG: Pose cache saved with joint headers at {path}")
        except Exception as e:
            print(f"ERROR: Could not save pose cache: {e}")

    def run(self):
        """Thread entry point: full analysis pipeline for one video.

        Emits finished_data with keys: video_path, fps, total_frames,
        width, height, events, raw_kps, z_kps, directions, baseline_kp_mean.
        """
        debug_print()
        # --- PHASE 1: VIDEO SETUP & POSE EXTRACTION ---
        cap = cv2.VideoCapture(self.video_path)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back when metadata missing
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        raw_kps_per_frame = []   # (17, 2) xy per frame
        csv_storage_data = []    # (17, 3) xy+conf per frame
        valid_mask = []          # True where a subject was detected
        pose_cache = self.video_path.rsplit('.', 1)[0] + "_pose_raw.csv"

        if os.path.exists(pose_cache):
            # Reuse previously extracted poses instead of re-running YOLO.
            self.progress.emit("Loading cached kinematic data...")
            with open(pose_cache, 'r') as f:
                reader = csv.reader(f)
                next(reader)  # skip header row
                for row in reader:
                    full_data = np.array([float(x) for x in row]).reshape(17, 3)
                    kp = full_data[:, :2]
                    raw_kps_per_frame.append(kp)
                    csv_storage_data.append(full_data)
                    # All-zero rows were written for undetected frames.
                    valid_mask.append(np.any(kp))
        else:
            self.progress.emit("Detecting poses with YOLO...")
            model = YOLO("yolov8n-pose.pt")
            prev_track_id = None
            for i in range(total_frames):
                ret, frame = cap.read()
                if not ret: break
                results = model.track(frame, persist=True, verbose=False)
                track_id, kp, confs, _ = self.get_best_infant_match(results, width, height, prev_track_id)
                if kp is not None:
                    prev_track_id = track_id
                    raw_kps_per_frame.append(kp)
                    csv_storage_data.append(np.column_stack((kp, confs)))
                    valid_mask.append(True)
                else:
                    # Keep frame alignment with zero-filled placeholders.
                    raw_kps_per_frame.append(np.zeros((17, 2)))
                    csv_storage_data.append(np.zeros((17, 3)))
                    valid_mask.append(False)
                if i % 50 == 0: self.progress.emit(f"YOLO: {int((i/total_frames)*100)}%")
            self._save_pose_cache(pose_cache, csv_storage_data)

        cap.release()
        actual_len = len(raw_kps_per_frame)

        flattened_rows = []
        for frame_array in csv_storage_data:
            # frame_array is (17, 3) -> flatten to (51,)
            flattened_rows.append(frame_array.flatten())

        columns = []
        for name in JOINT_NAMES:
            columns.extend([f"{name}_x", f"{name}_y", f"{name}_conf"])

        # Store this so the Inspector can access it instantly in memory
        self.pose_df = pd.DataFrame(flattened_rows, columns=columns)

        # --- PHASE 2: KINEMATICS & Z-SCORES ---
        self.progress.emit("Calculating Kinematics...")
        analysis_kpts = []
        for kp in raw_kps_per_frame:
            # Centre on the pelvis (midpoint of hips, joints 11/12) to
            # remove whole-body translation.
            pelvis = (kp[11] + kp[12]) / 2
            analysis_kpts.append(kp - pelvis)

        valid_data = [analysis_kpts[i] for i, v in enumerate(valid_mask) if v]
        if valid_data:
            stacked = np.stack(valid_data)
            baseline_mean = np.mean(stacked, axis=0)
            # +1e-6 avoids division by zero for motionless joints.
            baseline_std = np.std(np.linalg.norm(stacked - baseline_mean, axis=2), axis=0) + 1e-6
        else:
            baseline_mean, baseline_std = np.zeros((17, 2)), np.ones(17)

        np_raw_kps = np.array(raw_kps_per_frame)
        np_z_kps = np.array([np.linalg.norm(kp - baseline_mean, axis=1) / baseline_std for kp in analysis_kpts])

        # Direction labels are not computed yet; placeholder empty strings
        # keep downstream track generation from erroring.
        np_dirs = np.full((actual_len, 17), "", dtype=object)

        # --- PHASE 3: TIMELINE GENERATION ---
        # Initialize dictionary with ALL global track names to prevent KeyErrors
        timeline_events = {name: [] for name in TRACK_NAMES}

        # 1. Kinematic Events (the per-joint tracks): merge consecutive
        # frames of the same severity into one block.
        for j_idx, joint_name in enumerate(JOINT_NAMES):
            current_block = None
            for f_idx in range(actual_len):
                severity = self.classify_delta(np_z_kps[f_idx, j_idx])
                if severity != "Rest":
                    if current_block and current_block[2] == severity:
                        current_block[1] = f_idx  # extend the open block
                    else:
                        current_block = [f_idx, f_idx, severity, ""]
                        timeline_events[joint_name].append(current_block)
                else:
                    current_block = None

        # 2. JSON Observations
        self._merge_json_observations(timeline_events, fps)

        # 3. AI Inferred Events
        ai_events = self._run_existing_ml_models(np_z_kps, np_dirs, np_raw_kps)
        timeline_events.update(ai_events)

        # --- PHASE 4: EMIT ---
        data = {
            "video_path": self.video_path,
            "fps": fps,
            "total_frames": actual_len,
            "width": width, "height": height,
            "events": timeline_events,
            "raw_kps": np_raw_kps,
            "z_kps": np_z_kps,
            "directions": np_dirs,
            "baseline_kp_mean": baseline_mean
        }
        self.progress.emit("Analysis Complete!")
        self.finished_data.emit(data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ==========================================
|
|
# TIMELINE WIDGET
|
|
# ==========================================
|
|
class TimelineWidget(QWidget):
    """Horizontal multi-track timeline with a sticky label sidebar.

    Draws one row per entry in TRACK_NAMES plus a time ruler and a red
    playhead. Supports Ctrl+wheel / Ctrl+(+,-,0) zooming, click-and-drag
    scrubbing (which emits seek_requested), and per-track visibility
    toggling by clicking a track's name in the sidebar.

    Signals:
        seek_requested (int): frame the user scrubbed/clicked to.
        visibility_changed (set): current set of hidden track names.
        track_selected (str): clicked track name ("" when the ruler is clicked).
    """

    seek_requested = Signal(int)
    visibility_changed = Signal(set)
    track_selected = Signal(str)

    def __init__(self):
        debug_print()
        super().__init__()
        self.data = None
        self.current_frame = 0
        self.zoom_factor = 1.0  # Pixels per frame
        self.label_width = 160  # Fixed gutter for track names
        self.track_height = 25
        self.ruler_height = 20
        self.scrollbar_buffer = 2  # Extra space for the horizontal scrollbar
        self.hidden_tracks = set()
        self.sync_offset = 0.0
        self.sync_fps = 30.0
        # FIX: these were previously only assigned inside mousePressEvent,
        # so a drag that skipped the left-button press path raised
        # AttributeError in mouseMoveEvent.
        self.is_scrubbing = False
        self.selected_track_idx = -1
        # Calculate total required height
        self.total_content_height = (NUM_TRACKS * self.track_height) + self.ruler_height
        self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer)

    def set_sync_params(self, offset_seconds, fps=None):
        """
        Updates the temporal shift parameters and refreshes the UI.
        """
        debug_print()
        self.sync_offset = float(offset_seconds)

        # Only update FPS if a valid value is provided,
        # otherwise keep the existing data/video FPS
        if fps and fps > 0:
            self.sync_fps = float(fps)
        elif self.data and "fps" in self.data:
            self.sync_fps = float(self.data["fps"])

        print(f"DEBUG: Timeline Sync Set - Offset: {self.sync_offset}s, FPS: {self.sync_fps}")

        # Trigger paintEvent to redraw the blocks in their new shifted positions
        self.update()

    def set_zoom(self, factor):
        """Clamp and apply a new zoom factor (pixels per frame)."""
        debug_print()
        if not self.data: return

        # Calculate MIN zoom: the zoom required to make the video fit the
        # available width exactly: (Available Width - Sidebar) / Total Frames
        available_w = self.parent().width() - self.label_width if self.parent() else 800
        # FIX: guard against a zero-frame video causing ZeroDivisionError.
        min_zoom = available_w / max(1, self.data["total_frames"])

        # Clamp: don't zoom out past the video end, don't zoom in to infinity
        self.zoom_factor = max(min_zoom, min(factor, 50.0))
        self.update_geometry()

    def get_all_binary_labels(self, offset_seconds=0.0, fps=30.0):
        """
        Extracts binary labels for ALL tracks in self.data["events"].
        Returns a dict: {'OBS: Mouthing': [0,1,0...], 'OBS: Kicking': [0,0,1...]}
        """
        debug_print()
        all_labels = {}

        if not self.data or "events" not in self.data:
            return all_labels

        total_frames = self.data.get("total_frames", 0)
        if total_frames == 0:
            return all_labels

        frame_shift = int(offset_seconds * fps)

        for track_name in self.data["events"]:
            sequence = np.zeros(total_frames, dtype=int)

            for event in self.data["events"][track_name]:
                start_f = int(event[0]) - frame_shift
                end_f = int(event[1]) - frame_shift

                # Clamp values into the valid frame range
                start_idx = max(0, min(start_f, total_frames - 1))
                end_idx = max(0, min(end_f, total_frames))

                if start_idx < end_idx:
                    sequence[start_idx:end_idx] = 1

            all_labels[track_name] = sequence.tolist()  # lists store/serialize easily

        return all_labels

    def update_geometry(self):
        """Recompute the widget's fixed width from frames x zoom and repaint."""
        debug_print()

        if self.data:
            # Width is sidebar + (frames * zoom)
            total_w = self.label_width + int(self.data["total_frames"] * self.zoom_factor)
            self.setFixedWidth(total_w)
        self.update()

    def wheelEvent(self, event):
        """Ctrl+wheel zooms; plain wheel falls through to the scroll area."""
        debug_print()

        if event.modifiers() == Qt.ControlModifier:
            delta = event.angleDelta().y()
            # Zoom by 10% per notch
            zoom_change = 1.1 if delta > 0 else 0.9
            self.set_zoom(self.zoom_factor * zoom_change)
        else:
            # Let the scroll area handle normal vertical scrolling
            super().wheelEvent(event)

    def keyPressEvent(self, event):
        """Ctrl+Plus / Ctrl+Minus / Ctrl+0 keyboard zoom shortcuts."""
        debug_print()

        if event.modifiers() == Qt.ControlModifier:
            if event.key() == Qt.Key_Plus or event.key() == Qt.Key_Equal:
                self.set_zoom(self.zoom_factor * 1.2)
            elif event.key() == Qt.Key_Minus:
                self.set_zoom(self.zoom_factor * 0.8)
            elif event.key() == Qt.Key_0:
                self.set_zoom(1.0)  # Reset zoom
        else:
            super().keyPressEvent(event)

    def set_data(self, data):
        """Install a new analysis payload and resize to fit all tracks."""
        debug_print()

        self.data = data
        # TRACK_NAMES may have grown (OBS/AI tracks), so recompute height.
        self.total_content_height = (len(TRACK_NAMES) * self.track_height) + self.ruler_height
        self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer)
        self.update_geometry()

    def set_playhead(self, frame):
        """Move the playhead to `frame`, repainting only the affected strips."""
        debug_print()
        old_x = self.label_width + (self.current_frame * self.zoom_factor)
        self.current_frame = frame
        new_x = self.label_width + (self.current_frame * self.zoom_factor)
        self.ensure_playhead_visible()
        # Repaint narrow strips around the old and new positions only.
        self.update(int(old_x - 5), 0, 10, self.height())
        self.update(int(new_x - 5), 0, 10, self.height())

    def ensure_playhead_visible(self):
        """Auto-scrolls the scroll area if the playhead leaves the viewport."""
        debug_print()

        # Find the QScrollArea parent (widget -> viewport -> scroll area)
        scroll_area = self.parent().parent()
        if not isinstance(scroll_area, QScrollArea): return

        scrollbar = scroll_area.horizontalScrollBar()
        view_width = scroll_area.viewport().width()

        # Playhead position in pixels
        px = self.label_width + int(self.current_frame * self.zoom_factor)

        # Current scroll position
        scroll_x = scrollbar.value()

        # If playhead is beyond the right edge of visible area
        if px > (scroll_x + view_width):
            # Shift scroll so playhead is at the left (plus sidebar)
            scrollbar.setValue(px - self.label_width)

        # If playhead is behind the left edge (e.g. user seeked backwards)
        elif px < (scroll_x + self.label_width):
            scrollbar.setValue(px - self.label_width)

    def mousePressEvent(self, event):
        """Left-click: toggle sidebar track visibility, or start scrubbing."""
        debug_print()

        if not self.data or event.button() != Qt.LeftButton:
            return

        pos_x = event.position().x()
        pos_y = event.position().y()
        scroll_area = self.parent().parent()
        scroll_x = scroll_area.horizontalScrollBar().value()

        # 1. CALCULATE FRAME
        relative_x = pos_x - self.label_width
        frame = int(relative_x / self.zoom_factor)
        frame = max(0, min(frame, self.data["total_frames"] - 1))

        # 2. IF CLICKED SIDEBAR: Toggle Visibility (No Scrubbing)
        if pos_x < scroll_x + self.label_width:
            relative_y = pos_y - self.ruler_height
            track_idx = int(relative_y // self.track_height)
            if 0 <= track_idx < len(TRACK_NAMES):
                name = TRACK_NAMES[track_idx]
                if name in self.hidden_tracks: self.hidden_tracks.remove(name)
                else: self.hidden_tracks.add(name)
                self.visibility_changed.emit(self.hidden_tracks)
                self.update()
            return  # Exit early; don't set is_scrubbing

        # 3. IF CLICKED RULER OR DATA AREA: Start Scrubbing
        self.is_scrubbing = True
        self.seek_requested.emit(frame)

        # Handle track selection if in the data area
        if pos_y >= self.ruler_height:
            track_idx = int((pos_y - self.ruler_height) // self.track_height)
            if 0 <= track_idx < len(TRACK_NAMES):
                self.track_selected.emit(TRACK_NAMES[track_idx])
                self.selected_track_idx = track_idx
                self.update()
        else:
            # Clicked ruler
            self.selected_track_idx = -1
            self.track_selected.emit("")
            self.update()

    def mouseMoveEvent(self, event):
        """Continue scrubbing while the left button is held."""
        debug_print()

        # This only fires while moving if a button is held down by default
        if self.is_scrubbing:
            self.update_frame_from_mouse(event.position().x())

    def mouseReleaseEvent(self, event):
        """Stop scrubbing on left-button release."""
        debug_print()

        if event.button() == Qt.LeftButton:
            self.is_scrubbing = False

    def update_frame_from_mouse(self, x_pos):
        """Helper to calculate frame and emit the seek signal."""
        debug_print()
        relative_x = x_pos - self.label_width
        frame = int(relative_x / self.zoom_factor)
        frame = max(0, min(frame, self.data["total_frames"] - 1))

        # We emit seek_requested so the Video Player and Premiere class sync up
        self.seek_requested.emit(frame)

    def paintEvent(self, event):
        """Render event blocks, playhead, sticky sidebar, and time ruler."""
        debug_print()
        if not self.data: return

        dirty_rect = event.rect()
        painter = QPainter(self)

        # 1. Determine current scroll position to keep labels sticky
        scroll_area = self.parent().parent()
        scroll_x = 0
        if isinstance(scroll_area, QScrollArea):
            scroll_x = scroll_area.horizontalScrollBar().value()

        w, h = self.width(), self.height()
        total_f = self.data["total_frames"]
        fps = self.data.get("fps", 30)
        offset_y = 20

        # 2. DRAW DATA AREA (muted patterns + events + playhead)
        sync_off = getattr(self, "sync_offset", 0.0)
        sync_fps = getattr(self, "sync_fps", fps)
        frame_shift = int(sync_off * sync_fps)

        for i, name in enumerate(TRACK_NAMES):
            y = offset_y + (i * self.track_height)
            is_hidden = name in self.hidden_tracks

            # Skip rows entirely outside the dirty region.
            if y + self.track_height < dirty_rect.top() or y > dirty_rect.bottom():
                continue

            # A. Draw Muted/Disabled Background Pattern
            if is_hidden:
                # Visible rectangle for this track to the right of the sidebar
                mute_rect = QRectF(self.label_width, y, w - self.label_width, self.track_height)
                # Fill with a dark "disabled" base
                painter.fillRect(mute_rect, QColor(40, 40, 40))
                # Add the Cross-Hatch Pattern
                pattern_brush = QBrush(QColor(60, 60, 60, 100), Qt.DiagCrossPattern)
                painter.fillRect(mute_rect, pattern_brush)

            # B. Draw Event Blocks
            base_color = TRACK_COLORS[i]
            for start_f, end_f, severity, direction in self.data["events"][name]:
                if "AI:" in name:
                    # AI predictions are already calculated in video-time, NO SHIFT
                    shifted_start, shifted_end = start_f, end_f
                else:
                    shifted_start = start_f - frame_shift
                    shifted_end = end_f - frame_shift

                x_start = self.label_width + (shifted_start * self.zoom_factor)
                x_end = self.label_width + (shifted_end * self.zoom_factor)

                # Performance optimization: skip drawing if off-screen
                if x_end < scroll_x or x_start > scroll_x + w:
                    continue

                if x_end < dirty_rect.left() or x_start > dirty_rect.right():
                    continue

                # If hidden, make the event block very faint/transparent
                if is_hidden:
                    color = QColor(120, 120, 120, 40)  # Muted Grey
                else:
                    alpha = 80 if severity == "Small" else 160 if severity == "Moderate" else 255
                    color = QColor(base_color)
                    color.setAlpha(alpha)

                painter.fillRect(QRectF(x_start, y + 2, max(1, x_end - x_start), self.track_height - 4), color)

        # Draw Playhead
        playhead_x = self.label_width + (self.current_frame * self.zoom_factor)
        painter.setPen(QPen(QColor(255, 0, 0), 2))
        # FIX: drawLine's int overload rejects float arguments in PySide6,
        # and playhead_x is a float whenever zoom_factor is fractional.
        painter.drawLine(int(playhead_x), 0, int(playhead_x), h)

        # 3. DRAW STICKY SIDEBAR (pinned to the left edge of the viewport).
        # Drawn AFTER the data so it covers the blocks as they scroll past.
        sidebar_rect = QRect(scroll_x, 0, self.label_width, h)
        painter.fillRect(sidebar_rect, QColor(30, 30, 30))  # Solid background

        # Ruler segment for the sidebar area
        painter.fillRect(scroll_x, 0, self.label_width, offset_y, QColor(45, 45, 45))

        for i, name in enumerate(TRACK_NAMES):
            y = offset_y + (i * self.track_height)
            is_hidden = name in self.hidden_tracks
            # Grid Line
            painter.setPen(QColor(60, 60, 60))
            painter.drawLine(scroll_x, y, scroll_x + w, y)

            # Sticky Label Text: dark grey = "OFF", bright = "ON"
            if is_hidden:
                text_color = QColor(70, 70, 70)
            else:
                text_color = QColor(220, 220, 220)

            painter.setPen(text_color)
            painter.setFont(QFont("Arial", 8, QFont.Bold))
            painter.drawText(scroll_x + 10, y + 17, name)

        # 4. DRAW TIME RULER TICKS (right of the sticky sidebar)
        target_spacing_px = 120

        # Tick units in frames: 1, 5, 15, 30 (1s), 150 (5s), 300 (10s), 1800 (1min)
        possible_units = [1, 5, 15, 30, 150, 300, 900, 1800]

        # Find the smallest unit that results in at least target_spacing_px
        tick_interval = possible_units[-1]
        for unit in possible_units:
            if (unit * self.zoom_factor) >= target_spacing_px:
                tick_interval = unit
                break

        # Ruler bar background
        painter.fillRect(0, 0, w, 20, QColor(45, 45, 45))

        # Ticks and time labels
        painter.setPen(QColor(180, 180, 180))
        painter.setFont(QFont("Segoe UI", 7))

        # Sub-ticks (draw 5 small lines for every 1 major interval)
        sub_interval = max(1, tick_interval // 5)

        for f in range(0, total_f + 1, sub_interval):
            x = self.label_width + int(f * self.zoom_factor)

            # Optimization: Don't draw if off-screen
            if x < scroll_x: continue
            if x > scroll_x + w: break

            if f % tick_interval == 0:
                # Major Tick
                painter.drawLine(x, 10, x, 20)

                # Format Label: MM:SS or SS:FF
                total_seconds = f / fps
                minutes = int(total_seconds // 60)
                seconds = int(total_seconds % 60)
                frames = int(f % fps)

                if tick_interval < fps:
                    time_str = f"{seconds:02d}:{frames:02d}f"
                elif minutes > 0:
                    time_str = f"{minutes:02d}m:{seconds:02d}s"
                else:
                    time_str = f"{seconds}s"

                painter.drawText(x + 4, 12, time_str)
            else:
                # Minor Tick
                painter.drawLine(x, 16, x, 20)

    def get_ai_extractions(self):
        """
        Processes timeline data for AI tracks and specific OBS sync events.

        Returns:
            dict with "metadata" (fps, frame count, per-track summaries),
            "obs" (the "OBS: Time Sync" blocks, if present), and "ai_tracks"
            (per-track lists of block dicts with frame/second timings).
        """
        debug_print()
        fps = self.data.get("fps", 30.0)

        extraction_data = {
            "metadata": {
                "fps": fps,
                "total_frames": self.data.get("total_frames", 0),
                "track_summaries": {}
            },
            "obs": {},  # top-level key for specific OBS events
            "ai_tracks": {}
        }

        if not self.data or "events" not in self.data:
            return extraction_data

        # 1. Extract the OBS: Time Sync event specifically
        sync_key = "OBS: Time Sync"
        if sync_key in self.data["events"]:
            sync_blocks = self.data["events"][sync_key]
            # Convert blocks to a list of dicts for the JSON
            extraction_data["obs"][sync_key] = [
                {
                    "start_frame": b[0],
                    "end_frame": b[1],
                    "start_time_sec": round(b[0] / fps, 3),
                    "end_time_sec": round(b[1] / fps, 3)
                } for b in sync_blocks
            ]

        # 2. Process AI Tracks
        for track_name, blocks in self.data["events"].items():
            if track_name.startswith("AI:"):
                track_results = []
                track_total = 0
                track_long = 0

                for block in blocks:
                    start_f, end_f = block[0], block[1]
                    severity = block[2] if len(block) > 2 else "Normal"

                    start_sec = round(start_f / fps, 3)
                    end_sec = round(end_f / fps, 3)
                    duration = round(end_sec - start_sec, 3)

                    track_total += 1
                    if duration > 2.0:
                        track_long += 1

                    track_results.append({
                        "start_frame": int(start_f),
                        "end_frame": int(end_f),
                        "start_time_sec": start_sec,
                        "end_time_sec": end_sec,
                        "duration_sec": duration,
                        "severity": severity
                    })

                extraction_data["ai_tracks"][track_name] = track_results
                extraction_data["metadata"]["track_summaries"][track_name] = {
                    "event_count": track_total,
                    "long_events_over_2s": track_long,
                    "total_duration_sec": round(sum(r["duration_sec"] for r in track_results), 3)
                }

        return extraction_data
|
|
|
|
|
|
class SkeletonOverlay(QWidget):
    """Transparent widget layered over the video viewport.

    Paints the live pose skeleton, an optional dashed baseline skeleton, and
    colour-coded highlights for timeline events active at the current frame.
    Mouse events pass through to the video underneath.
    """

    def __init__(self, parent=None):
        debug_print()
        super().__init__(parent)
        self.setAttribute(Qt.WA_TransparentForMouseEvents)  # Clicks pass through to video
        # Analysis payload (keypoints, events, video dims); None until loaded.
        self.data = None
        # Frame index currently shown by the player.
        self.current_frame = 0
        # Track names the user has toggled off in the timeline.
        self.hidden_tracks = set()
        # Use your saved SKELETON_CONNECTIONS logic
        # NOTE(review): these index pairs appear unused -- paintEvent iterates
        # self.CONNECTIONS below; confirm before removing.
        self.connections = [
            (5, 7), (7, 9), (6, 8), (8, 10), (5, 6), (5, 11),
            (6, 12), (11, 12), (11, 13), (13, 15), (12, 14), (14, 16)
        ]
        # COCO-17 style keypoint name -> array index mapping.
        self.KP_MAP = {
            "nose": 0, "LE": 1, "RE": 2, "Lear": 3, "Rear": 4,
            "Lshoulder": 5, "Rshoulder": 6, "Lelbow": 7, "Relbow": 8,
            "Lwrist": 9, "Rwrist": 10, "Lhip": 11, "Rhip": 12,
            "Lknee": 13, "Rknee": 14, "Lankle": 15, "Rankle": 16
        }
        # Skeleton bones expressed with the named keypoints above.
        self.CONNECTIONS = [
            ("nose", "LE"), ("nose", "RE"), ("LE", "Lear"), ("RE", "Rear"),
            ("Lshoulder", "Rshoulder"), ("Lshoulder", "Lelbow"), ("Lelbow", "Lwrist"),
            ("Rshoulder", "Relbow"), ("Relbow", "Rwrist"), ("Lshoulder", "Lhip"),
            ("Rshoulder", "Rhip"), ("Lhip", "Rhip"), ("Lhip", "Lknee"),
            ("Lknee", "Lankle"), ("Rhip", "Rknee"), ("Rknee", "Rankle")
        ]

    def set_frame(self, frame_idx):
        """Store the current frame index and schedule a repaint."""
        debug_print()
        self.current_frame = frame_idx
        self.update()

    def set_hidden_tracks(self, hidden_set):
        """Replace the set of hidden track names and schedule a repaint."""
        debug_print()
        self.hidden_tracks = hidden_set
        self.update()

    def set_data(self, data):
        """Attach a new analysis payload and schedule a repaint."""
        debug_print()
        self.data = data
        self.update()

    def paintEvent(self, event):
        """Draw baseline skeleton, live skeleton, and active-event highlights."""
        debug_print()
        if not self.data or 'raw_kps' not in self.data:
            return

        painter = QPainter(self)
        painter.setRenderHint(QPainter.Antialiasing)

        # Scale factors from source-video pixel space to this widget's size.
        v_w, v_h = self.data['width'], self.data['height']
        w, h = self.width(), self.height()
        scale_x, scale_y = w / v_w, h / v_h

        current_f = self.current_frame
        # NOTE(review): assumes current_frame is a valid index into raw_kps --
        # confirm the caller clamps it to the frame range.
        kp_live = self.data['raw_kps'][current_f]

        # --- 1. MODIFIED TRACK STATUS (Respects Visibility) ---
        def get_track_status(track_name):
            # If the user greyed out this track in the timeline, act as if it's inactive
            if track_name in self.hidden_tracks:
                return None
            if track_name not in self.data['events']:
                return None
            # NOTE(review): unpacks 4-tuples (start, end, severity, direction);
            # shorter event blocks would raise here -- verify the event format.
            for start, end, severity, direction in self.data['events'][track_name]:
                if start <= current_f <= end:
                    idx = TRACK_NAMES.index(track_name)
                    color = QColor(TRACK_COLORS[idx])
                    # Alpha encodes severity: Small=80, Moderate=160, else opaque.
                    alpha = 80 if severity == "Small" else 160 if severity == "Moderate" else 255
                    color.setAlpha(alpha)
                    return color
            return None

        # Joint-angle track -> the two bones that form that angle.
        ANGLE_SEGMENTS = {
            "L_sh": [("Lhip", "Lshoulder"), ("Lshoulder", "Lelbow")],
            "R_sh": [("Rhip", "Rshoulder"), ("Rshoulder", "Relbow")],
            "L_el": [("Lshoulder", "Lelbow"), ("Lelbow", "Lwrist")],
            "R_el": [("Rshoulder", "Relbow"), ("Relbow", "Rwrist")],
            "L_leg": [("Lhip", "Lknee"), ("Lknee", "Lankle")],
            "R_leg": [("Rhip", "Rknee"), ("Rknee", "Rankle")]
        }

        # --- 2. DRAW BASELINE (Only if not hidden) ---
        if "Baseline" not in self.hidden_tracks:
            idx_l_hip, idx_r_hip = self.KP_MAP["Lhip"], self.KP_MAP["Rhip"]
            # Anchor the (pelvis-relative) baseline pose to the live pelvis.
            pelvis_live = (kp_live[idx_l_hip] + kp_live[idx_r_hip]) / 2
            kp_baseline = self.data['baseline_kp_mean'] + pelvis_live

            painter.setPen(QPen(QColor(200, 200, 200, 200), 2, Qt.DashLine))
            for s_name, e_name in self.CONNECTIONS:
                p1 = QPointF(kp_baseline[self.KP_MAP[s_name]][0] * scale_x, kp_baseline[self.KP_MAP[s_name]][1] * scale_y)
                p2 = QPointF(kp_baseline[self.KP_MAP[e_name]][0] * scale_x, kp_baseline[self.KP_MAP[e_name]][1] * scale_y)
                painter.drawLine(p1, p2)

        # --- 3. DRAW LIVE SKELETON (Only if not hidden) ---

        # CONNECTIONS
        for s_name, e_name in self.CONNECTIONS:
            # Does any angle track that owns this bone have an active event?
            active_color = None
            for angle_track, segments in ANGLE_SEGMENTS.items():
                if (s_name, e_name) in segments or (e_name, s_name) in segments:
                    active_color = get_track_status(angle_track)
                    if active_color: break

            p1 = QPointF(kp_live[self.KP_MAP[s_name]][0] * scale_x, kp_live[self.KP_MAP[s_name]][1] * scale_y)
            p2 = QPointF(kp_live[self.KP_MAP[e_name]][0] * scale_x, kp_live[self.KP_MAP[e_name]][1] * scale_y)

            if active_color:
                # Active events ALWAYS draw
                painter.setPen(QPen(active_color, 8, Qt.SolidLine, Qt.RoundCap))
                painter.drawLine(p1, p2)
            elif "Live Skeleton" not in self.hidden_tracks:
                # Black lines ONLY draw if Live Skeleton is ON
                painter.setPen(QPen(Qt.black, 4, Qt.SolidLine, Qt.RoundCap))
                painter.drawLine(p1, p2)

        # DOTS
        # Joint-angle track -> the keypoint at the vertex of that angle.
        ANGLE_VERTEX_MAP = {
            "L_sh": "Lshoulder", "R_sh": "Rshoulder",
            "L_el": "Lelbow", "R_el": "Relbow",
            "L_leg": "Lknee", "R_leg": "Rknee"
        }

        for kp_name, kp_idx in self.KP_MAP.items():
            pt = QPointF(kp_live[kp_idx][0] * scale_x, kp_live[kp_idx][1] * scale_y)

            # Check for Point Event (Skip if hidden via get_track_status)
            point_color = get_track_status(kp_name)

            if point_color:
                painter.setBrush(point_color)
                painter.setPen(QPen(Qt.white, 0.7))
                painter.drawEllipse(pt, 5, 5)
                continue

            # Check for Angle Event
            angle_color = None
            for angle_track, vertex_name in ANGLE_VERTEX_MAP.items():
                if kp_name == vertex_name:
                    angle_color = get_track_status(angle_track)
                    if angle_color: break

            if angle_color:
                painter.setBrush(angle_color)
                painter.setPen(Qt.NoPen)
                painter.drawEllipse(pt, 4, 4)

            elif "Live Skeleton" not in self.hidden_tracks:
                painter.setBrush(Qt.black)
                painter.setPen(Qt.NoPen)
                painter.drawEllipse(pt, 4, 4)
|
|
|
|
|
|
class VideoView(QGraphicsView):
    """Black, frameless graphics view hosting the video scene.

    Emits :attr:`resized` whenever the viewport changes size so the owner can
    re-letterbox the video item and the skeleton overlay.
    """

    resized = Signal()

    def __init__(self, scene, parent=None):
        debug_print()
        super().__init__(scene, parent)
        self.setAlignment(Qt.AlignCenter)
        self.setFrameStyle(0)
        self.setStyleSheet("background: black; border: none;")
        # Scrollbars are never wanted: the video is scaled to fit instead.
        for apply_policy in (self.setHorizontalScrollBarPolicy,
                             self.setVerticalScrollBarPolicy):
            apply_policy(Qt.ScrollBarAlwaysOff)

    def resizeEvent(self, event):
        debug_print()
        super().resizeEvent(event)
        # Let listeners (the main window) recompute video geometry.
        self.resized.emit()
|
|
|
|
|
|
# ==========================================
|
|
# MAIN PREMIERE WINDOW
|
|
# ==========================================
|
|
class PremiereWindow(QMainWindow):
|
|
    def __init__(self):
        """Build the main window: player + overlay, ML controls, inspector,
        timeline, menu bar, and background update checks."""
        debug_print()
        super().__init__()
        self.setWindowTitle("Pose Analysis Timeline")
        self.resize(1200, 900)

        # Lazily-created child windows (About / User Guide dialogs).
        self.about = None
        self.help = None

        self.platform_suffix = "-" + PLATFORM_NAME

        # Self-update machinery (talks to the Gitea releases API).
        self.updater = UpdateManager(
            main_window=self,
            api_url=API_URL,
            api_url_sec=API_URL_SECONDARY,
            current_version=CURRENT_VERSION,
            platform_name=PLATFORM_NAME,
            platform_suffix=self.platform_suffix,
            app_name=APP_NAME
        )

        # self.setStyleSheet("background-color: #1e1e1e; color: #ffffff;")
        self.setStyleSheet("""
            QMainWindow, QWidget#centralWidget {
                background-color: #1e1e1e;
            }
            QLabel, QStatusBar, QMenuBar {
                color: #ffffff;
            }
            /* Target the Timeline specifically */
            TimelineWidget {
                background-color: #1e1e1e;
                border: 1px solid #333333;
            }
            /* Button styling with Grey borders */
            QDialog, QMessageBox, QFileDialog {
                background-color: #2b2b2b;
            }
            QDialog QLabel, QMessageBox QLabel {
                color: #ffffff;
            }
            QPushButton {
                background-color: #2b2b2b;
                color: #ffffff;
                border: 1px solid #555555; /* Subtle Grey border */
                border-radius: 3px;
                padding: 4px;
            }
            QPushButton:hover {
                background-color: #3d3d3d;
                border-color: #888888; /* Brightens border on hover */
            }
            QPushButton:pressed {
                background-color: #111111;
            }
            QPushButton:disabled {
                border-color: #333333;
                color: #444444;
            }
            /* Splitter/Divider styling */
            QSplitter::handle {
                background-color: #333333; /* Dark grey dividers */
            }
            QSplitter::handle:horizontal {
                width: 2px;
            }
            QSplitter::handle:vertical {
                height: 2px;
            }
            /* ScrollArea styling to keep it dark */
            QScrollArea, QScrollArea > QWidget > QWidget {
                background-color: #1e1e1e;
                border: none;
            }
        """)

        self.predictor = GeneralPredictor()

        # Per-session state: current video, optional observations JSON, and
        # the sub-session selected within it.
        self.file_path = None
        self.obs_file = None
        self.selected_obs_subkey = None
        self.current_video_offset = 0.0

        # Core Layout
        main_splitter = QSplitter(Qt.Vertical)
        top_splitter = QSplitter(Qt.Horizontal)

        # --- Top Left: Video Player ---
        video_container = QWidget()
        video_layout = QVBoxLayout(video_container)
        video_layout.setContentsMargins(0, 0, 0, 0)

        self.scene = QGraphicsScene()
        # Use our new subclass instead of standard QGraphicsView
        self.view = VideoView(self.scene)
        self.view.resized.connect(self.update_video_geometry)

        # Video item (NOT native)
        self.video_item = QGraphicsVideoItem()
        self.scene.addItem(self.video_item)

        # Overlay widget (normal QWidget)
        self.skeleton_overlay = SkeletonOverlay(self.view.viewport())
        self.skeleton_overlay.setAttribute(Qt.WA_TransparentForMouseEvents)
        self.skeleton_overlay.setAttribute(Qt.WA_TranslucentBackground)
        self.skeleton_overlay.show()

        # Media player
        self.player = QMediaPlayer()
        self.audio_output = QAudioOutput()
        self.player.setAudioOutput(self.audio_output)
        self.player.setVideoOutput(self.video_item)

        video_layout.addWidget(self.view)

        # --- Control Bar Container (Vertical Stack) ---
        controls_container = QWidget()
        stacked_controls = QVBoxLayout(controls_container)
        stacked_controls.setSpacing(5)  # Tight spacing between rows

        # --- ROW 1: ML & Training Controls ---
        ml_row = QHBoxLayout()
        ml_row.addStretch()

        ml_row.addWidget(QLabel("ML Model:"))
        self.ml_dropdown = QComboBox()
        self.ml_dropdown.addItems(["Random Forest", "LSTM", "XGBoost", "SVM", "1D-CNN"])
        ml_row.addWidget(self.ml_dropdown)

        ml_row.addWidget(QLabel("Target:"))
        self.target_dropdown = QComboBox()
        self.target_dropdown.addItems(["Mouthing", "Head Movement", "Kick (Left)", "Kick (Right)", "Reach (Left)", "Reach (Right)"])
        self.target_dropdown.currentTextChanged.connect(self.update_predictor_target)
        ml_row.addWidget(self.target_dropdown)

        self.btn_add_to_pool = QPushButton("Add to Pool")
        self.btn_add_to_pool.clicked.connect(self.add_current_to_ml_pool)
        self.btn_add_to_pool.setFixedWidth(120)
        ml_row.addWidget(self.btn_add_to_pool)

        self.btn_train_final = QPushButton("Train Global Model")
        self.btn_train_final.setStyleSheet("background-color: #2e7d32; font-weight: bold;")
        self.btn_train_final.clicked.connect(self.run_final_training)
        ml_row.addWidget(self.btn_train_final)

        self.lbl_pool_status = QLabel("Pool: 0 Participants")
        self.lbl_pool_status.setStyleSheet("color: #00FF00; font-weight: bold; margin-left: 10px;")
        self.lbl_pool_status.setFixedWidth(160)
        ml_row.addWidget(self.lbl_pool_status)

        self.btn_clear_pool = QPushButton("Clear Pool")
        self.btn_clear_pool.setFixedWidth(100)
        self.btn_clear_pool.setStyleSheet("color: #ff5555; border: 1px solid #ff5555;")
        self.btn_clear_pool.clicked.connect(self.clear_ml_pool)
        ml_row.addWidget(self.btn_clear_pool)

        self.btn_extract_ai = QPushButton("Extract AI Data")
        self.btn_extract_ai.clicked.connect(self.extract_ai_to_json)
        ml_row.addWidget(self.btn_extract_ai)

        ml_row.addStretch()

        # --- ROW 2: Playback & Transport ---
        playback_row = QHBoxLayout()
        playback_row.addStretch()

        # Transport Buttons
        self.btn_start = QPushButton("|<")
        self.btn_prev = QPushButton("<")
        self.btn_play = QPushButton("Play")
        self.btn_next = QPushButton(">")
        self.btn_end = QPushButton(">|")

        # Disabled until a video is loaded.
        self.transport_btns = [self.btn_start, self.btn_prev, self.btn_play, self.btn_next, self.btn_end]
        for btn in self.transport_btns:
            btn.setEnabled(False)
            btn.setFixedWidth(50)
            playback_row.addWidget(btn)

        self.btn_mute = QPushButton("Vol")
        self.btn_mute.setFixedWidth(40)
        self.btn_mute.setCheckable(True)
        self.btn_mute.clicked.connect(self.toggle_mute)

        self.sld_volume = QSlider(Qt.Horizontal)
        self.sld_volume.setRange(0, 100)
        self.sld_volume.setValue(100)  # Default volume
        self.sld_volume.setFixedWidth(100)
        self.sld_volume.valueChanged.connect(self.update_volume)

        # Initialize volume
        self.audio_output.setVolume(0.7)

        playback_row.addWidget(self.btn_mute)
        playback_row.addWidget(self.sld_volume)

        # Counters
        counter_style = "font-family: 'Consolas'; font-size: 10pt; margin-left: 5px; color: #00FF00;"
        self.lbl_time_counter = QLabel("Time: 00:00 / 00:00")
        self.lbl_frame_counter = QLabel("Frame: 0 / 0")
        self.lbl_time_counter.setFixedWidth(180)
        self.lbl_frame_counter.setFixedWidth(180)
        self.lbl_time_counter.setStyleSheet(counter_style)
        self.lbl_frame_counter.setStyleSheet(counter_style)

        playback_row.addWidget(self.lbl_time_counter)
        playback_row.addWidget(self.lbl_frame_counter)

        playback_row.addStretch()

        # --- Add Rows to Stack ---
        stacked_controls.addLayout(ml_row)
        stacked_controls.addLayout(playback_row)

        # Add the whole stacked container to the main video layout
        video_layout.addWidget(controls_container)

        # --- Button Connections ---
        self.btn_play.clicked.connect(self.toggle_playback)
        # Use lambda to pass the target frame to your existing seek_video method
        self.btn_start.clicked.connect(lambda: self.seek_video(0))
        self.btn_end.clicked.connect(lambda: self.seek_video(self.data['total_frames'] - 1))
        self.btn_prev.clicked.connect(lambda: self.step_frame(-1))
        self.btn_next.clicked.connect(lambda: self.step_frame(1))

        # --- Top Right: Media Info & Loader ---
        info_container = QWidget()
        info_layout = QVBoxLayout(info_container)

        # NEW: Wrap the info_label in a Scroll Area
        self.inspector_scroll = QScrollArea()
        self.inspector_scroll.setWidgetResizable(True)
        self.inspector_scroll.setStyleSheet("border: none; background-color: transparent;")

        # Create the label as the scroll area's content
        self.info_label = QTextEdit()
        self.info_label.setText("No video loaded.\nClick 'File' > 'Load Video' to begin.")
        self.info_label.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        self.info_label.setWordWrapMode(QTextOption.WordWrap)
        self.info_label.setReadOnly(True)

        # self.info_label.setWordWrap(True) # Ensure long text wraps instead o
        # f stretching horizontally
        self.info_label.setStyleSheet("padding: 5px; font-family: 'Segoe UI', Arial; color: #ffffff;")

        self.inspector_scroll.setWidget(self.info_label)

        # Add the scroll area to the layout instead of the naked label
        info_layout.addWidget(self.inspector_scroll)

        top_splitter.addWidget(video_container)
        top_splitter.addWidget(info_container)
        top_splitter.setSizes([800, 400])

        # --- Bottom: Timeline in a Scroll Area ---
        self.timeline = TimelineWidget()
        self.timeline.seek_requested.connect(self.seek_video)
        self.timeline.visibility_changed.connect(self.skeleton_overlay.set_hidden_tracks)
        self.timeline.track_selected.connect(self.on_track_selected)

        scroll_area = QScrollArea()
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(self.timeline)

        main_splitter.addWidget(top_splitter)
        main_splitter.addWidget(scroll_area)
        main_splitter.setSizes([500, 400])

        self.setCentralWidget(main_splitter)
        self.player.positionChanged.connect(self.update_timeline_playhead)
        self.player.positionChanged.connect(self.update_inspector)
        self.create_menu_bar()

        # Kick off the background check for a locally staged update.
        self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix, PLATFORM_NAME, APP_NAME)
        self.local_check_thread.pending_update_found.connect(self.updater.on_pending_update_found)
        self.local_check_thread.no_pending_update.connect(self.updater.on_no_pending_update)
        self.local_check_thread.start()
|
|
|
|
|
|
|
|
    def create_menu_bar(self):
        '''Menu Bar at the top of the screen'''

        menu_bar = self.menuBar()
        self.statusbar = self.statusBar()

        def make_action(name, shortcut=None, slot=None, checkable=False, checked=False, icon=None):
            # Small factory so every menu entry is declared as one tuple below.
            action = QAction(name, self)

            if shortcut:
                action.setShortcut(QKeySequence(shortcut))
            if slot:
                action.triggered.connect(slot)
            if checkable:
                action.setCheckable(True)
                action.setChecked(checked)
            if icon:
                action.setIcon(QIcon(icon))
            return action

        # File menu and actions
        file_menu = menu_bar.addMenu("File")
        file_actions = [
            ("Load Video...", "Ctrl+O", self.load_video, resource_path("icons/file_open_24dp_1F1F1F.svg")),
            # ("Open Folder...", "Ctrl+Alt+O", self.not_implemented, resource_path("icons/folder_24dp_1F1F1F.svg")),
            # ("Open Folders...", "Ctrl+Shift+O", self.open_folder_dialog, resource_path("icons/folder_copy_24dp_1F1F1F.svg")),
            # ("Load Project...", "Ctrl+L", self.not_implemented, resource_path("icons/article_24dp_1F1F1F.svg")),
            # ("Save Project...", "Ctrl+S", self.not_implemented, resource_path("icons/save_24dp_1F1F1F.svg")),
            # ("Save Project As...", "Ctrl+Shift+S", self.not_implemented, resource_path("icons/save_as_24dp_1F1F1F.svg")),
        ]

        for i, (name, shortcut, slot, icon) in enumerate(file_actions):
            file_menu.addAction(make_action(name, shortcut, slot, icon=icon))
            if i == 1:  # separator after the second entry (no-op while only one action is active)
                file_menu.addSeparator()

        file_menu.addSeparator()
        file_menu.addAction(make_action("Exit", "Ctrl+Q", QApplication.instance().quit, icon=resource_path("icons/exit_to_app_24dp_1F1F1F.svg")))

        # Edit menu
        edit_menu = menu_bar.addMenu("Edit")
        edit_actions = [
            ("Cut", "Ctrl+X", self.cut_text, resource_path("icons/content_cut_24dp_1F1F1F.svg")),
            ("Copy", "Ctrl+C", self.copy_text, resource_path("icons/content_copy_24dp_1F1F1F.svg")),
            ("Paste", "Ctrl+V", self.paste_text, resource_path("icons/content_paste_24dp_1F1F1F.svg"))
        ]
        for name, shortcut, slot, icon in edit_actions:
            edit_menu.addAction(make_action(name, shortcut, slot, icon=icon))

        # View menu
        view_menu = menu_bar.addMenu("View")
        toggle_statusbar_action = make_action("Toggle Status Bar", checkable=True, checked=True, slot=None)
        view_menu.addAction(toggle_statusbar_action)
        toggle_statusbar_action.toggled.connect(self.statusbar.setVisible)

        # Options menu (Help & About)
        options_menu = menu_bar.addMenu("Options")

        options_actions = [
            ("User Guide", "F1", self.user_guide, resource_path("icons/help_24dp_1F1F1F.svg")),
            ("Check for Updates", "F5", self.updater.manual_check_for_updates, resource_path("icons/update_24dp_1F1F1F.svg")),
            ("Batch YOLO processing...", "F6", self.open_batch_tool, resource_path("icons/upgrade_24dp_1F1F1F.svg")),
            ("About", "F12", self.about_window, resource_path("icons/info_24dp_1F1F1F.svg"))
        ]

        for i, (name, shortcut, slot, icon) in enumerate(options_actions):
            options_menu.addAction(make_action(name, shortcut, slot, icon=icon))
            if i == 1 or i == 3:  # separators after "Check for Updates" and "About"
                options_menu.addSeparator()

        preferences_menu = menu_bar.addMenu("Preferences")
        preferences_actions = [
            ("Not Implemented", "", self.not_implemented, resource_path("icons/info_24dp_1F1F1F.svg")),
        ]
        for name, shortcut, slot, icon in preferences_actions:
            preferences_menu.addAction(make_action(name, shortcut, slot, icon=icon, checkable=True, checked=False))

        terminal_menu = menu_bar.addMenu("Terminal")
        terminal_actions = [
            ("Not Implemented", "", self.not_implemented, resource_path("icons/terminal_24dp_1F1F1F.svg")),
        ]
        for name, shortcut, slot, icon in terminal_actions:
            terminal_menu.addAction(make_action(name, shortcut, slot, icon=icon))

        self.statusbar.showMessage("Ready")
|
|
|
|
|
|
def not_implemented(self):
|
|
self.statusbar.showMessage("Not Implemented.") # Show status message
|
|
|
|
def copy_text(self):
|
|
self.info_label.copy() # Trigger copy
|
|
self.statusbar.showMessage("Copied to clipboard") # Show status message
|
|
|
|
def cut_text(self):
|
|
self.info_label.cut() # Trigger cut
|
|
self.statusbar.showMessage("Cut to clipboard") # Show status message
|
|
|
|
def about_window(self):
|
|
if self.about is None or not self.about.isVisible():
|
|
self.about = AboutWindow(self)
|
|
self.about.show()
|
|
|
|
def user_guide(self):
|
|
if self.help is None or not self.help.isVisible():
|
|
self.help = UserGuideWindow(self)
|
|
self.help.show()
|
|
|
|
def paste_text(self):
|
|
self.info_label.paste() # Trigger paste
|
|
self.statusbar.showMessage("Pasted from clipboard") # Show status message
|
|
|
|
def open_batch_tool(self):
|
|
dialog = BatchProcessorDialog(self) # Pass 'self' to keep it centered
|
|
dialog.exec()
|
|
|
|
def toggle_mute(self):
|
|
is_muted = self.btn_mute.isChecked()
|
|
self.audio_output.setMuted(is_muted)
|
|
self.btn_mute.setText("Mute" if is_muted else "Vol")
|
|
# Optional: Dim the slider when muted
|
|
self.sld_volume.setEnabled(not is_muted)
|
|
|
|
def update_volume(self, value):
|
|
# QAudioOutput expects a float between 0.0 and 1.0
|
|
volume = value / 100.0
|
|
self.audio_output.setVolume(volume)
|
|
|
|
# Auto-unmute if user moves the slider
|
|
if self.btn_mute.isChecked() and value > 0:
|
|
self.btn_mute.setChecked(False)
|
|
self.toggle_mute()
|
|
|
|
def clear_ml_pool(self):
|
|
"""Removes all participants from the training buffer."""
|
|
debug_print()
|
|
# Confirm with the user first to prevent accidental deletions
|
|
reply = QMessageBox.question(self, 'Clear Pool?',
|
|
f"This will remove all {len(self.predictor.raw_participant_buffer)} "
|
|
"participants from the training memory. Continue?",
|
|
QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
|
|
|
|
if reply == QMessageBox.Yes:
|
|
# 1. Clear the actual list in the predictor
|
|
self.predictor.raw_participant_buffer = []
|
|
|
|
# 2. Update the UI label
|
|
self.lbl_pool_status.setText("Pool: 0 Participants")
|
|
|
|
# 3. Optional: Visual feedback
|
|
# self.statusBar().showMessage("ML Pool cleared.", 3000)
|
|
print("DEBUG: ML Pool manually cleared.")
|
|
|
|
|
|
def update_predictor_target(self):
|
|
debug_print()
|
|
# This physically changes the string from "Mouthing" to "Head Movement"
|
|
self.predictor.current_target = self.target_dropdown.currentText()
|
|
|
|
print(f"Predictor is now targeting: {self.predictor.current_target}")
|
|
|
|
|
|
def reprocess_current_video(self):
|
|
"""Restarts the analysis worker to pick up new models."""
|
|
debug_print()
|
|
|
|
# Start the worker (passing the predictor so it can run AI models)
|
|
self.worker = PoseAnalyzerWorker(
|
|
self.file_path,
|
|
obs_info=self.selected_obs_subkey,
|
|
predictor=self.predictor
|
|
)
|
|
|
|
self.worker.progress.connect(self.update_status)
|
|
self.worker.finished_data.connect(self.setup_workspace)
|
|
self.worker.start()
|
|
|
|
|
|
def update_video_geometry(self):
|
|
debug_print()
|
|
if not hasattr(self, "video_item") or not hasattr(self, "data"):
|
|
return
|
|
|
|
viewport_rect = self.view.viewport().rect()
|
|
v_w, v_h = viewport_rect.width(), viewport_rect.height()
|
|
if v_w <= 0 or v_h <= 0: return
|
|
|
|
video_w, video_h = self.data['width'], self.data['height']
|
|
aspect = video_w / video_h
|
|
|
|
if v_w / v_h > aspect:
|
|
target_h = v_h
|
|
target_w = int(v_h * aspect)
|
|
else:
|
|
target_w = v_w
|
|
target_h = int(v_w / aspect)
|
|
|
|
x_off = (v_w - target_w) / 2
|
|
y_off = (v_h - target_h) / 2
|
|
|
|
self.scene.setSceneRect(0, 0, v_w, v_h)
|
|
self.video_item.setPos(x_off, y_off)
|
|
self.video_item.setSize(QSizeF(target_w, target_h))
|
|
self.skeleton_overlay.setGeometry(int(x_off), int(y_off), target_w, target_h)
|
|
|
|
def resizeEvent(self, event):
|
|
debug_print()
|
|
|
|
super().resizeEvent(event)
|
|
self.update_video_geometry()
|
|
if hasattr(self, 'timeline'):
|
|
self.timeline.set_zoom(self.timeline.zoom_factor)
|
|
|
|
# def eventFilter(self, source, event):
|
|
# if source is self.video_widget and event.type() == QEvent.Resize:
|
|
# self.skeleton_overlay.resize(event.size())
|
|
# return super().eventFilter(source, event)
|
|
|
|
|
|
|
|
def add_current_to_ml_pool(self):
|
|
"""Adds raw kinematic data and current OBS labels to the buffer."""
|
|
debug_print()
|
|
if not hasattr(self, 'data') or 'raw_kps' not in self.data:
|
|
QMessageBox.warning(self, "No Data", "Load a video first.")
|
|
return
|
|
|
|
# 1. Grab everything the Worker produced
|
|
payload = {
|
|
"z_kps": self.data['z_kps'],
|
|
"directions": self.data['directions'],
|
|
"raw_kps": self.data['raw_kps']
|
|
}
|
|
|
|
all_labels = self.timeline.get_all_binary_labels(self.current_video_offset, self.data["fps"])
|
|
|
|
# 3. Hand off to predictor
|
|
msg = self.predictor.add_to_raw_buffer(payload, all_labels)
|
|
self.lbl_pool_status.setText(f"Pool: {len(self.predictor.raw_participant_buffer)} Participants")
|
|
print(f"DEBUG: Added to Predictor at {hex(id(self.predictor))}")
|
|
print(f"DEBUG: Buffer size is now: {len(self.predictor.raw_participant_buffer)}")
|
|
QMessageBox.information(self, "Success", msg)
|
|
|
|
|
|
def run_final_training(self):
|
|
"""
|
|
Triggers training
|
|
"""
|
|
debug_print()
|
|
# DEBUG: Check the buffer directly before the IF statement
|
|
actual_buffer = self.predictor.raw_participant_buffer
|
|
current_count = len(actual_buffer)
|
|
|
|
if current_count < 1:
|
|
# If this triggers, let's see WHY it's empty
|
|
QMessageBox.warning(self, "Empty Pool",
|
|
f"Buffer is empty (Size: {current_count}).\n"
|
|
f"Predictor ID: {hex(id(self.predictor))}")
|
|
return
|
|
|
|
model_type = self.ml_dropdown.currentText()
|
|
current_target = self.target_dropdown.currentText()
|
|
|
|
reply = QMessageBox.question(self, 'Confirm Training',
|
|
f"Train {model_type} for '{current_target}' using "
|
|
f"{current_count} participants?",
|
|
QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
|
|
|
|
if reply == QMessageBox.Yes:
|
|
self.btn_train_final.setEnabled(False)
|
|
self.btn_train_final.setText(f"Training...")
|
|
|
|
try:
|
|
# Force the target update right before training
|
|
self.predictor.current_target = current_target
|
|
report_html = self.predictor.calculate_and_train(model_type, current_target)
|
|
|
|
self.reprocess_current_video()
|
|
|
|
self.info_label.setText(report_html)
|
|
msg = QMessageBox(self)
|
|
msg.setWindowTitle("Results")
|
|
msg.setTextFormat(Qt.RichText)
|
|
msg.setText(report_html)
|
|
msg.exec()
|
|
|
|
|
|
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
QMessageBox.critical(self, "Error", f"{str(e)}")
|
|
|
|
finally:
|
|
self.btn_train_final.setEnabled(True)
|
|
self.btn_train_final.setText("Train Global Model")
|
|
|
|
|
|
|
|
# def import_json_observations(self):
|
|
# debug_print()
|
|
# file_path, _ = QFileDialog.getOpenFileName(self, "Select JSON Observations", "", "JSON Files (*.json)")
|
|
# if not file_path: return
|
|
|
|
# with open(file_path, 'r') as f:
|
|
# full_data = json.load(f)
|
|
|
|
# # Get the subkeys under "observations"
|
|
# subkeys = list(full_data.get("observations", {}).keys())
|
|
|
|
# if not subkeys:
|
|
# print("No observations found in JSON.")
|
|
# return
|
|
|
|
# item, ok = QInputDialog.getItem(self, "Select Session", "Pick an observation set:", subkeys, 0, False)
|
|
|
|
# if ok and item:
|
|
# new_obs_data = self.load_external_observations(file_path, item)
|
|
# self.append_new_tracks(new_obs_data)
|
|
|
|
def append_new_tracks(self, new_obs_data):
|
|
debug_print()
|
|
# 1. Update global TRACK_NAMES and TRACK_COLORS
|
|
for name in new_obs_data.keys():
|
|
if name not in TRACK_NAMES:
|
|
TRACK_NAMES.append(name)
|
|
# Assign a distinct color (e.g., a dark purple/magenta for observations)
|
|
TRACK_COLORS.append("#AA00FF")
|
|
|
|
# 2. Merge into existing data dictionary
|
|
self.data["events"].update(new_obs_data)
|
|
|
|
# 3. Refresh Timeline
|
|
global NUM_TRACKS
|
|
NUM_TRACKS = len(TRACK_NAMES)
|
|
self.timeline.set_data(self.data)
|
|
self.timeline.update_geometry()
|
|
|
|
# def load_external_observations(self, file_path, subkey):
|
|
# debug_print()
|
|
# with open(file_path, 'r') as f:
|
|
# data = json.load(f)
|
|
|
|
# raw_events = data["observations"][subkey]["events"]
|
|
# # We only care about: [time_seconds (0), _, label (2), _, _, _]
|
|
|
|
# new_tracks = {}
|
|
|
|
# # Sort events by time just in case the JSON is unsorted
|
|
# raw_events.sort(key=lambda x: x[0])
|
|
|
|
# # Group timestamps by their label (e.g., "Kick", "Baseline")
|
|
# temp_storage = {}
|
|
# for event in raw_events:
|
|
# time_sec = event[0]
|
|
# label = event[2]
|
|
# frame = int(time_sec * self.data["fps"])
|
|
|
|
# if label not in temp_storage:
|
|
# temp_storage[label] = []
|
|
# temp_storage[label].append(frame)
|
|
|
|
# # Convert pairs of frames into (start, end) blocks
|
|
# for label, frames in temp_storage.items():
|
|
# processed_blocks = []
|
|
# # Step through frames in pairs (start, end)
|
|
# for i in range(0, len(frames) - 1, 2):
|
|
# start = frames[i]
|
|
# end = frames[i+1]
|
|
# # Format: (start, end, severity, direction)
|
|
# processed_blocks.append((start, end, "External", "Manual Obs"))
|
|
|
|
# new_tracks[f"OBS: {label}"] = processed_blocks
|
|
|
|
# return new_tracks
|
|
|
|
|
|
def load_video(self):
    """
    Prompt for a video file, probe its FPS, optionally attach a JSON
    observation file, and start the pose-analysis worker.

    Side effects: sets self.file_path, self.current_video_fps and
    self.obs_file; may set self.selected_obs_subkey and
    self.current_video_offset; spawns self.worker (PoseAnalyzerWorker).
    """
    debug_print()
    # Ask for the video first; abort quietly if the user cancels the dialog.
    self.file_path, _ = QFileDialog.getOpenFileName(self, "Open Video", "", "Video Files (*.mp4 *.avi *.mkv)")
    if not self.file_path: return

    # Probe the container for its frame rate (needed for frame<->ms math).
    cap = cv2.VideoCapture(self.file_path)
    if cap.isOpened():
        # cap.get can return 0.0 on odd containers; `or 30.0` falls back then.
        self.current_video_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        #total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Optional: Initialize timeline with blank data so it can at least draw the ruler
        #self.timeline.data = {"total_frames": total_frames, "fps": self.current_video_fps, "events": {}}
        cap.release()
    else:
        self.current_video_fps = 30.0 # Fallback

    # --- NEW: JSON Observation Prompt ---
    # Optional companion file; an empty string means the user skipped it.
    self.obs_file, _ = QFileDialog.getOpenFileName(self, "Select JSON Observations (Optional)", "", "JSON Files (*.json)")

    if self.obs_file:
        try:
            with open(self.obs_file, 'r') as f:
                full_json = json.load(f)

            observations = full_json.get("observations", {})
            subkeys = list(observations.keys())

            # --- AUTO-MATCHING LOGIC ---
            # 1. Get the video filename without extension (e.g., 'T4_2T_WORD_F')
            video_name = os.path.splitext(os.path.basename(self.file_path))[0]
            v_parts = video_name.split('_')

            # Build the 'fingerprint' from the video (Blocks 1, 2, and the Last one)
            # This ignores the 'WORDHERE' block in the middle
            if len(v_parts) >= 3:
                video_fingerprint = f"{v_parts[0]}_{v_parts[1]}_{v_parts[-1]}"
            else:
                video_fingerprint = video_name # Fallback

            # Compare case-insensitively against each session subkey.
            match = None
            for sk in subkeys:
                s_parts = sk.split('_')
                # Subkeys are shorter: Block 1, 2, and 3
                if len(s_parts) == 3:
                    sk_fingerprint = f"{s_parts[0]}_{s_parts[1]}_{s_parts[2]}"
                    if sk_fingerprint.lower() == video_fingerprint.lower():
                        match = sk
                        break

            # 2. Decision: Use match or prompt user
            if match:
                self.selected_obs_subkey = (self.obs_file, match)
                self.statusBar().showMessage(f"Auto-matched JSON session: {match}", 5000)
            elif subkeys:
                # No match found, only then show the popup
                item, ok = QInputDialog.getItem(self, "Select Session",
                                                f"Could not auto-match '{video_name}'.\nPick manually:",
                                                subkeys, 0, False)
                if ok and item:
                    self.selected_obs_subkey = (self.obs_file, item)

            # --- NEW: Offset & File Matching Logic ---
            # NOTE(review): assumes self.selected_obs_subkey already exists
            # (presumably initialized in __init__); if the user cancelled the
            # session popup above it keeps its previous value — TODO confirm.
            if self.selected_obs_subkey:
                _, session_key = self.selected_obs_subkey
                session_data = observations.get(session_key, {})
                file_map = session_data.get("file", {})

                video_filename = os.path.basename(self.file_path)
                found_index = None

                # 1. Attempt Auto-Match by filename
                for idx_str, file_list in file_map.items():
                    # Check if our loaded video is in this list (e.g., "Videos\\T4_2T_WORD_F.mp4")
                    if any(video_filename in path for path in file_list):
                        found_index = idx_str
                        print(f"DEBUG: Auto-matched video to Camera Index {idx_str}")
                        break

                # 2. If Auto-Match fails, prompt user for Camera Index
                if not found_index:
                    available_indices = [k for k, v in file_map.items() if v] # Only indices with files
                    if available_indices:
                        item, ok = QInputDialog.getItem(self, "Identify Camera",
                                                        f"Could not find '{video_filename}' in JSON.\n"
                                                        "Which camera index is this video?",
                                                        available_indices, 0, False)
                        if ok:
                            found_index = item

                # 3. Retrieve and Print the Offset
                if found_index:
                    offsets = session_data.get("media_info", {}).get("offset", {})
                    search_key = str(found_index)
                    # Note: offsets dict might use integers or strings as keys
                    # We check both to be safe
                    actual_offset = offsets.get(search_key)

                    if actual_offset is not None:
                        print(f"MATCHED OFFSET: {actual_offset:.4f}")
                        # Store this if you need it for timeline syncing later
                        self.current_video_offset = float(actual_offset)
                        self.timeline.set_sync_params(
                            offset_seconds=self.current_video_offset,
                            fps=self.current_video_fps
                        )

                        print(f"Timeline synced with {actual_offset}s offset.")
                    else:
                        print(f"DEBUG: No offset found for index {found_index}")

        except Exception as e:
            # Any parse/matching failure is reported but does not block loading.
            QMessageBox.warning(self, "JSON Error", f"Could not parse JSON: {e}")

    # --- Cache Logic ---
    # cache_path = self.file_path.rsplit('.', 1)[0] + "_pose_cache.csv"
    # use_cache = None
    # if os.path.exists(cache_path):
    #     reply = QMessageBox.question(self, 'Cache Found',
    #                                  "Use existing pose cache?",
    #                                  QMessageBox.Yes | QMessageBox.No)
    #     use_cache = cache_path if reply == QMessageBox.Yes else None

    # Pass the observation info to the worker
    self.worker = PoseAnalyzerWorker(self.file_path, self.selected_obs_subkey, self.predictor)
    self.worker.progress.connect(self.update_status)
    self.worker.finished_data.connect(self.setup_workspace)
    self.worker.start()
|
|
|
|
def update_status(self, msg):
    """Show a worker progress message in the side info panel."""
    debug_print()

    text = f"Status:\n{msg}"
    self.info_label.setText(text)
|
|
|
|
def setup_workspace(self, data):
    """
    Receive the analyzer worker's results and wire up the whole workspace.

    Loads the video into the player, pushes the data to the timeline and
    skeleton overlay, enables the transport controls, resets the counters,
    and seeks to frame 0 so the first frame is displayed.

    Fix: timeline.set_data / skeleton_overlay.set_data were each called
    twice; the redundant second pair is removed.
    """
    debug_print()
    self.data = data

    # Load the media and immediately pause so the first frame is decoded.
    self.player.setSource(QUrl.fromLocalFile(data["video_path"]))
    self.player.play()
    self.player.pause()

    # Sync the dependent widgets (once is enough).
    self.timeline.set_data(data)
    self.skeleton_overlay.set_data(data)
    self.update_video_geometry()

    for btn in self.transport_btns:
        btn.setEnabled(True)

    total_f = data['total_frames']
    fps = data['fps']
    tot_s = int(total_f / fps)

    # Display 0 / Total
    self.lbl_time_counter.setText(f"00:00 / {tot_s//60:02d}:{tot_s%60:02d}")
    self.lbl_frame_counter.setText(f"0 / {total_f-1}")

    # Force a seek to frame 0 to initialize the video buffer
    self.seek_video(0)
    self.btn_load.setEnabled(True)

    info_text = (
        f"File: {os.path.basename(data['video_path'])}\n"
        f"Resolution: {data['width']}x{data['height']}\n"
        f"FPS: {data['fps']:.2f}\n"
        f"Total Frames: {data['total_frames']}\n\n"
        f"Timeline Legend (Opacity):\n"
        f"255 Alpha = Large Deviation\n"
        f"160 Alpha = Moderate Deviation\n"
        f"80 Alpha = Small Deviation\n"
        f"Empty = Rest (Baseline)"
    )
    self.info_label.setText(info_text)
|
|
|
|
|
|
def toggle_playback(self):
    """Toggle play/pause; when the playhead sits at the end, restart at 0."""
    debug_print()

    if not hasattr(self, 'data'):
        return

    # Round the millisecond position to the nearest whole frame.
    fps = self.data["fps"]
    frame_now = int((self.player.position() / 1000.0) * fps + 0.5)
    if frame_now >= self.data["total_frames"] - 1:
        # At (or past) the last frame: rewind before resuming playback.
        self.seek_video(0)

    if self.player.playbackState() == QMediaPlayer.PlayingState:
        self.player.pause()
        self.btn_play.setText("Play")
    else:
        self.player.play()
        self.btn_play.setText("Pause")
|
|
|
|
def update_timeline_playhead(self, position_ms):
    """Player position slot: clamp near the clip end and sync all widgets."""
    debug_print()
    if not (hasattr(self, 'data') and self.data["fps"] > 0):
        return

    fps = self.data["fps"]
    total_f = self.data["total_frames"]
    last_frame = total_f - 1

    current_f = int((position_ms / 1000.0) * fps)

    # Keep the image visible at the end of the clip: pause playback and pin
    # both the frame index and the player position to the last valid frame.
    if current_f >= last_frame:
        if self.player.playbackState() == QMediaPlayer.PlayingState:
            self.player.pause()
            self.btn_play.setText("Play")
        current_f = last_frame
        self.player.setPosition(int((last_frame / fps) * 1000))

    # Propagate the frame to every synced widget.
    self.timeline.set_playhead(current_f)
    self.skeleton_overlay.set_frame(current_f)
    self.update_counters(current_f)
|
|
|
|
|
|
def on_track_selected(self, track_name):
    """Remember the clicked timeline track and refresh the inspector panel."""
    debug_print()

    self.selected_track = track_name

    if track_name:
        # Green text signals an active selection; refresh right away.
        self.info_label.setStyleSheet("color: #00FF00; font-family: 'Segoe UI'; font-size: 10pt;")
        self.update_inspector()
    else:
        # Deselection: show the hint text in a muted grey.
        self.info_label.setText("No track selected.\nClick a data track to inspect.")
        self.info_label.setStyleSheet("color: #AAAAAA; font-family: 'Segoe UI'; font-size: 10pt;")
|
|
|
|
|
|
def update_inspector(self):
    """
    Refresh the inspector panel for the currently selected timeline track.

    Shows the track name and current frame, plus either the active/inactive
    state (AI/OBS behavior tracks) or the raw x/y/confidence values from the
    cached pose CSV (kinematic tracks). For AI tracks, the newest matching
    performance report is appended.

    Fixes: removed leftover debug print() calls and narrowed the bare
    `except:` around the report read so it no longer swallows SystemExit /
    KeyboardInterrupt.
    """
    debug_print()
    if not hasattr(self, 'selected_track') or not self.selected_track or not self.data:
        return

    # 1. Temporal Logic: derive the current frame from the player position,
    # clamped to the valid frame range.
    current_f = int((self.player.position() / 1000.0) * self.data["fps"])
    current_f = max(0, min(current_f, self.data["total_frames"] - 1))

    is_ai = "AI:" in self.selected_track
    is_obs = "OBS:" in self.selected_track

    # 2. Status/Raw Logic
    if is_ai or is_obs:
        # Behavior tracks: report whether any event spans the current frame.
        events = self.data["events"].get(self.selected_track, [])
        is_active = any(start <= current_f <= end for start, end, *rest in events)
        active_color = "#ff5555" if is_active else "#888888"

        status_line = f"<b>ACTIVE:</b> <span style='color:{active_color}; font-weight:bold;'>{'YES' if is_active else 'NO'}</span>"
        raw_line = ""  # Do not display raw for AI/OBS
    else:
        # Kinematic tracks: show raw values from the worker's pose dataframe.
        status_line = ""
        raw_info = "N/A"
        cache_path = self.file_path.rsplit('.', 1)[0] + "_pose_raw.csv"

        # NOTE(review): the CSV's presence on disk gates reading the
        # in-memory dataframe; assumes worker.pose_df mirrors that file.
        if os.path.exists(cache_path):
            try:
                # Row 2 in CSV is Frame 0. pandas.read_csv uses Row 1 as header.
                # So Frame 0 is df.iloc[0].
                if current_f < len(self.worker.pose_df):
                    row = self.worker.pose_df.iloc[current_f]
                    col_x, col_y, col_c = f"{self.selected_track}_x", f"{self.selected_track}_y", f"{self.selected_track}_conf"

                    if col_x in self.worker.pose_df.columns and col_y in self.worker.pose_df.columns:
                        rx, ry = row[col_x], row[col_y]
                        rc = row[col_c] if col_c in self.worker.pose_df.columns else 0.0
                        raw_info = f"X: {rx:.2f} | Y: {ry:.2f} | Conf: {rc:.2f}"
            except Exception as e:
                print(f"Inspector CSV Error: {e}")
                raw_info = "Index Error"

        raw_line = f"<b>RAW (CSV):</b> {raw_info}"

    # 3. Construct Display
    display_text = (
        f"<b style='color: #55aaff;'>TRACK:</b> {self.selected_track}<br>"
        f"<b>FRAME:</b> {current_f}<br>"
        f"{status_line}"
        f"{raw_line}"
    )

    # 4. Performance Report: append the newest matching report, if any.
    if is_ai:
        target_name = self.selected_track.replace("AI: ", "")
        pattern = f"ml_{target_name}_performance_*.txt"
        report_files = sorted(glob.glob(pattern))

        report_content = "<i>No report found.</i>"
        if report_files:
            try:
                with open(report_files[-1], 'r') as f:
                    report_content = f.read().replace('\n', '<br>')
            except Exception:
                # Best effort: keep the placeholder if the report is unreadable.
                pass

        display_text += f"<hr><b style='color: #00FF00;'>AI Performance:</b><br><small>{report_content}</small>"

    self.info_label.setText(display_text)
|
|
|
|
|
|
def step_frame(self, delta):
    """Jump the playhead by `delta` frames relative to the current position."""
    debug_print()

    if not hasattr(self, 'data'):
        return

    # Round the millisecond position to the nearest whole frame first.
    fps = self.data["fps"]
    frame_now = int((self.player.position() / 1000.0) * fps + 0.5)

    # seek_video clamps to [0, total_frames - 1] and refreshes the UI.
    self.seek_video(frame_now + delta)
|
|
|
|
def seek_video(self, frame):
    """Move the player to `frame` (clamped to the valid range) and sync the UI."""
    debug_print()
    if not (hasattr(self, 'data') and self.data["fps"] > 0):
        return

    last = self.data["total_frames"] - 1
    target_frame = min(max(frame, 0), last)

    # Convert the frame index to a millisecond position for the media player.
    self.player.setPosition(int((target_frame / self.data["fps"]) * 1000))

    self.video_item.update()

    # Update dependent widgets right away for snappier feedback.
    self.timeline.set_playhead(target_frame)
    self.update_counters(target_frame)
|
|
|
|
def update_counters(self, current_f):
    """Refresh the time and frame counter labels for the given frame index."""
    debug_print()

    fps = self.data["fps"]
    total_f = self.data["total_frames"]

    cur_s = int(current_f / fps)
    tot_s = int(total_f / fps)
    self.lbl_time_counter.setText(
        f"Time: {cur_s//60:02d}:{cur_s%60:02d} / {tot_s//60:02d}:{tot_s%60:02d}"
    )
    self.lbl_frame_counter.setText(f"Frame: {current_f} / {total_f-1}")
|
|
|
|
|
|
|
|
|
|
def extract_ai_to_json(self):
    """
    Automatically save the timeline's AI extractions next to the loaded
    video with the suffix '_events.blaze' (JSON content).

    Prints and returns early if no video is loaded; any extraction or write
    failure is reported on the console rather than raised.
    """

    # 1. Check if a video is loaded to get the base path
    video_path = getattr(self, "file_path", None)
    if not video_path or not os.path.exists(video_path):
        print("Error: No video loaded. Cannot determine save path.")
        return

    # 2. Construct the new filename (same directory and stem as the video)
    base_dir = os.path.dirname(video_path)
    file_name = os.path.splitext(os.path.basename(video_path))[0]
    save_path = os.path.join(base_dir, f"{file_name}_events.blaze")

    # 3. Call the timeline method to get the data
    try:
        extraction_data = self.timeline.get_ai_extractions()

        # Inject source video metadata
        extraction_data["metadata"]["source_video"] = video_path

        # 4. Save to disk
        with open(save_path, 'w') as f:
            json.dump(extraction_data, f, indent=4)

        print(f"Extraction automatically saved to: {save_path}")

    except Exception as e:
        # Best effort: log and carry on so a failed export never crashes the UI.
        print(f"Error during automatic AI extraction: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
def resource_path(relative_path):
    """
    Resolve `relative_path` against the application base directory, whether
    running from source or from a PyInstaller bundle (sys._MEIPASS).
    """
    base_path = getattr(sys, '_MEIPASS', None)
    if base_path is None:
        # Running from source: resolve relative to this file's directory.
        base_path = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(base_path, relative_path)
|
|
|
|
|
|
def kill_child_processes():
    """
    Best-effort cleanup: kill every child process of this process and wait
    briefly for them to exit. Any failure is logged, never raised.
    """
    try:
        me = psutil.Process(os.getpid())
        kids = me.children(recursive=True)
        for kid in kids:
            try:
                kid.kill()
            except psutil.NoSuchProcess:
                # Already gone between enumeration and kill — fine.
                pass
        psutil.wait_procs(kids, timeout=5)
    except Exception as e:
        print(f"Error killing child processes: {e}")
|
|
|
|
|
|
def exception_hook(exc_type, exc_value, exc_traceback):
    """
    Global sys.excepthook: log the crash, reap child processes, show the
    critical-error dialog, then terminate with exit code 1.
    """
    error_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
    print(error_msg)  # mirror the traceback to the console/log

    # Make sure no worker processes outlive the crashed GUI.
    kill_child_processes()

    # A QApplication must exist before the error dialog can be shown;
    # create a minimal one if none is running yet.
    app = QApplication.instance()
    if app is None:
        app = QApplication(sys.argv)

    show_critical_error(error_msg)

    # Exit once the user has acknowledged the dialog.
    sys.exit(1)
|
|
|
|
def show_critical_error(error_msg):
    """
    Show a rich-text crash dialog with links to a snapshot of the error log
    and to the autosave file, then block until the user acknowledges.

    Fixes: the log/autosave paths previously used hardcoded 'flares*' names
    (left over from a sibling project) while startup actually writes
    f"{APP_NAME}.log" — so shutil.copy could raise FileNotFoundError inside
    the crash handler. Paths now derive from APP_NAME and the copy is
    guarded against a missing log.
    """
    msg_box = QMessageBox()
    msg_box.setIcon(QMessageBox.Icon.Critical)
    msg_box.setWindowTitle("Something went wrong!")

    if PLATFORM_NAME == "darwin":
        # macOS .app bundle: files live next to the bundle, three levels up
        # from the executable inside Contents/MacOS.
        base_dir = os.path.join(os.path.dirname(sys.executable), "../../..")
    else:
        base_dir = os.getcwd()

    log_path = os.path.join(base_dir, f"{APP_NAME}.log")
    log_path2 = os.path.join(base_dir, f"{APP_NAME}_error.log")
    save_path = os.path.join(base_dir, f"{APP_NAME}_autosave.blaze")

    # Snapshot the live log so the linked copy stays stable after the dialog.
    if os.path.exists(log_path):
        shutil.copy(log_path, log_path2)

    log_path2 = Path(log_path2).absolute().as_posix()
    autosave_path = Path(save_path).absolute().as_posix()
    log_link = f"file:///{log_path2}"
    autosave_link = f"file:///{autosave_path}"

    message = (
        f"{APP_NAME.upper()} has encountered an unrecoverable error and needs to close.<br><br>"
        f"We are sorry for the inconvenience. An autosave was attempted to be saved to <a href='{autosave_link}'>{autosave_path}</a>, but it may not have been saved. "
        "If the file was saved, it still may not be intact, openable, or contain the correct data. Use the autosave at your discretion.<br><br>"
        f"This unrecoverable error was likely due to an error with {APP_NAME.upper()} and not your data.<br>"
        f"Please raise an issue <a href='https://git.research.dezeeuw.ca/tyler/{APP_NAME}/issues'>here</a> and attach the error file located at <a href='{log_link}'>{log_path2}</a><br><br>"
        f"<pre>{error_msg}</pre>"
    )

    msg_box.setTextFormat(Qt.TextFormat.RichText)
    msg_box.setText(message)
    # Allow the user to click the file:// and issue-tracker links.
    msg_box.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction)
    msg_box.setStandardButtons(QMessageBox.StandardButton.Ok)

    msg_box.exec()
|
|
|
|
if __name__ == "__main__":
    # Redirect uncaught exceptions to the crash popup
    sys.excepthook = exception_hook

    # Set up application logging: macOS .app bundles write next to the
    # bundle; everywhere else, the current working directory.
    if PLATFORM_NAME == "darwin":
        log_path = os.path.join(os.path.dirname(sys.executable), f"../../../{APP_NAME}.log")
    else:
        log_path = os.path.join(os.getcwd(), f"{APP_NAME}.log")

    # Start each run with a fresh log file (ignore "not there / can't remove").
    try:
        os.remove(log_path)
    except OSError:
        pass

    # Line-buffered so log lines appear immediately; explicit UTF-8 avoids
    # platform-default encoding surprises in the log.
    sys.stdout = open(log_path, "a", buffering=1, encoding="utf-8")
    sys.stderr = sys.stdout
    print(f"\n=== App started at {datetime.now()} ===\n")

    freeze_support()  # Required for PyInstaller + multiprocessing

    # Only run the GUI in the main process (workers re-import this module)
    if current_process().name == 'MainProcess':
        app = QApplication(sys.argv)
        finish_update_if_needed(PLATFORM_NAME, APP_NAME)
        window = PremiereWindow()

        # Pick the platform-appropriate icon format.
        if PLATFORM_NAME == "darwin":
            icon = QIcon(resource_path("icons/main.icns"))
        else:
            icon = QIcon(resource_path("icons/main.ico"))
        app.setWindowIcon(icon)
        window.setWindowIcon(icon)

        window.show()
        sys.exit(app.exec())
|
|
|
|
# Not 6000 lines yay! |