diff --git a/README.md b/README.md index 9650194..ae53a61 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,44 @@ -# blazes +BLAZES (Behavioral Learning & Automated Zoned Events Suite) +================================================================= -Behavioral Learning & Automated Zoned Events Suite \ No newline at end of file +BLAZES is a standalone application to predict behavioural events from video files. + +BLAZES is free and open-source software that runs on Windows, MacOS, and Linux. Please read the information regarding each operating system below. + +Visit the official [BLAZES web site](https://research.dezeeuw.ca/blazes). + +[![Python web site](https://img.shields.io/badge/Made%20with-Python-1f425f.svg)](https://www.python.org) + +# For MacOS Users + +Due to the cost of an Apple Developer account, the application is not certified by Apple. Once the application is extracted and attempted to be launched for the first time you will get a popup stating: + +"Apple could not verify blazes.app is free of malware that may harm your Mac or compromise your privacy.", with the options of "Done" or "Move to Trash". + +The solution around this is to use finder and navigate to the blazes-darwin folder. Once the folder has been located, right click the folder and click the option "New Terminal at Folder". Once the terminal opens, run the following command (you can copy + paste): + +```xattr -dr com.apple.quarantine blazes.app & pid1=$!; xattr -dr com.apple.quarantine blazes_updater.app & pid2=$!; wait $pid1 $pid2; exit``` + +Once the command has been executed and the text "[Process completed]" appears, you may close the terminal window and attempt to open the application again. If you choose to unrestrict the app through Settings > Privacy & Security, the app may not be able to update correctly in the future. + +This only applies for the first time you attempt to run BLAZES. Subsequent times, including after updates, will function correctly as-is. + +# For Windows Users + +Due to the cost of a code signing certificate, the application is not digitally signed. Once the application is extracted and attempted to be launched for the first time you will get a popup stating: + +"Windows protected your PC - Microsoft Defender SmartScreen prevented an unrecognized app from starting. Running this app might put your PC at risk.", with the options of "More info" or "Don't run". + +The solution around this is to click "More info" and then select "Run anyway". + +This only applies for the first time you attempt to run BLAZES. Subsequent times, including after updates, will function correctly as-is. + +# For Linux Users + +There are no conditions for Linux users at this time. + +# Licence + +BLAZES is distributed under the GPL-3.0 license. + +Copyright (C) 2025-2026 Tyler de Zeeuw \ No newline at end of file diff --git a/batch_processing.py b/batch_processing.py new file mode 100644 index 0000000..fa1040b --- /dev/null +++ b/batch_processing.py @@ -0,0 +1,211 @@ +""" +Filename: batch_processing.py +Description: BLAZES batch processor + +Author: Tyler de Zeeuw +License: GPL-3.0 +""" + +# Built-in imports +import os +import csv +from pathlib import Path + +# External library imports +import cv2 +import numpy as np +from ultralytics import YOLO + +from PySide6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QLineEdit, + QPushButton, QComboBox, QSpinBox, QLabel, + QFileDialog, QTextEdit) +from PySide6.QtCore import Qt, QObject, Signal, QRunnable, QThreadPool, Slot + + +JOINT_NAMES = [ + "Nose", "Left Eye", "Right Eye", "Left Ear", "Right Ear", + "Left Shoulder", "Right Shoulder", "Left Elbow", "Right Elbow", + "Left Wrist", "Right Wrist", "Left Hip", "Right Hip", + "Left Knee", "Right Knee", "Left Ankle", "Right Ankle" +] + +class WorkerSignals(QObject): + """Signals to communicate back to the UI from the thread pool.""" + progress = Signal(str) + finished = Signal() + +class VideoWorker(QRunnable): + """A worker task for processing a single video file.""" + def __init__(self, video_path, model_name): + super().__init__() + self.video_path = video_path + self.model_name = model_name + self.signals = WorkerSignals() + + @Slot() + def run(self): + filename = os.path.basename(self.video_path) + self.signals.progress.emit(f"Starting: {filename}") + + try: + cap = cv2.VideoCapture(self.video_path) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + csv_storage_data = [] + pose_cache = str(Path(self.video_path).with_name(Path(self.video_path).stem + "_pose_raw.csv")) + + if os.path.exists(pose_cache): + self.signals.progress.emit(f" - Cache found for {filename}. Skipping inference.") + else: + # Instantiate model INSIDE the thread for safety + model = YOLO(self.model_name) + prev_track_id = None + + for i in range(total_frames): + ret, frame = cap.read() + if not ret: break + + results = model.track(frame, persist=True, verbose=False) + track_id, kp, confs, _ = self.get_best_infant_match(results, width, height, prev_track_id) + + if kp is not None: + prev_track_id = track_id + csv_storage_data.append(np.column_stack((kp, confs))) + else: + csv_storage_data.append(np.zeros((17, 3))) + + if i % 100 == 0: + p = int((i / total_frames) * 100) + self.signals.progress.emit(f" - {filename}: {p}%") + + self._save_pose_cache(pose_cache, csv_storage_data) + + cap.release() + self.signals.progress.emit(f"COMPLETED: {filename}") + except Exception as e: + self.signals.progress.emit(f"ERROR on {filename}: {str(e)}") + finally: + self.signals.finished.emit() + + def get_best_infant_match(self, results, w, h, prev_track_id): + if not results[0].boxes or results[0].boxes.id is None: + return None, None, None, None + + ids = results[0].boxes.id.int().cpu().tolist() + kpts = results[0].keypoints.xy.cpu().numpy() + confs = results[0].keypoints.conf.cpu().numpy() + + best_idx, best_score = -1, -1 + for i, k in enumerate(kpts): + vis = np.sum(confs[i] > 0.5) + valid = k[confs[i] > 0.5] + dist = np.linalg.norm(np.mean(valid, axis=0) - [w/2, h/2]) if len(valid) > 0 else 1000 + score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0) + + if score > best_score: + best_score, best_idx = score, i + + if best_idx == -1: + return None, None, None, None + return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx + + def _save_pose_cache(self, path, data): + with open(path, 'w', newline='') as f: + writer = csv.writer(f) + header = [] + for joint in JOINT_NAMES: + header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"]) + writer.writerow(header) + for frame_data in data: + writer.writerow(frame_data.flatten()) + +class BatchProcessorDialog(QDialog): + def __init__(self, parent=None): + super().__init__(parent) + self.setWindowTitle("Concurrent Batch Video Processor") + self.setMinimumSize(600, 500) + self.threadpool = QThreadPool() + self.active_workers = 0 + self.setup_ui() + + def setup_ui(self): + layout = QVBoxLayout(self) + + # Folder selection + f_lay = QHBoxLayout() + self.folder_edit = QLineEdit() + btn_browse = QPushButton("Browse") + btn_browse.clicked.connect(self.get_folder) + f_lay.addWidget(QLabel("Folder:")) + f_lay.addWidget(self.folder_edit) + f_lay.addWidget(btn_browse) + layout.addLayout(f_lay) + + # Dropdown for YOLO Model + self.model_combo = QComboBox() + # self.model_combo.addItems([ + # "yolov8n-pose.pt", "yolov8s-pose.pt", + # "yolov8m-pose.pt", "yolov8l-pose.pt", "yolov8x-pose.pt" + # ]) + self.model_combo.addItems(["yolov8n-pose.pt"]) + layout.addWidget(QLabel("Select YOLO Model:")) + layout.addWidget(self.model_combo) + + # Int input for Concurrency + self.concurrency_spin = QSpinBox() + self.concurrency_spin.setRange(1, 16) + self.concurrency_spin.setValue(4) + layout.addWidget(QLabel("Concurrent Workers (Threads):")) + layout.addWidget(self.concurrency_spin) + + # Output Log + self.log = QTextEdit() + self.log.setReadOnly(True) + self.log.setStyleSheet("background-color: #1e1e1e; color: #d4d4d4; font-family: Consolas;") + layout.addWidget(self.log) + + # Action Buttons + self.btn_run = QPushButton("Run Concurrent Batch") + self.btn_run.setFixedHeight(40) + self.btn_run.clicked.connect(self.run_batch) + layout.addWidget(self.btn_run) + + def get_folder(self): + path = QFileDialog.getExistingDirectory(self, "Select Folder") + if path: self.folder_edit.setText(path) + + def log_msg(self, msg): + self.log.append(msg) + self.log.verticalScrollBar().setValue(self.log.verticalScrollBar().maximum()) + + def worker_finished(self): + self.active_workers -= 1 + if self.active_workers <= 0: + self.btn_run.setEnabled(True) + self.log_msg("--- All concurrent tasks finished ---") + + def run_batch(self): + folder = self.folder_edit.text() + if not os.path.isdir(folder): + self.log_msg("!!! Error: Invalid directory.") + return + + video_extensions = ('.mp4', '.avi', '.mov', '.mkv') + files = [str(f) for f in Path(folder).iterdir() if f.suffix.lower() in video_extensions] + + if not files: + self.log_msg("No videos found.") + return + + self.btn_run.setEnabled(False) + self.active_workers = len(files) + self.threadpool.setMaxThreadCount(self.concurrency_spin.value()) + + self.log_msg(f"Queueing {len(files)} videos...") + for f_path in files: + worker = VideoWorker(f_path, self.model_combo.currentText()) + worker.signals.progress.connect(self.log_msg) + worker.signals.finished.connect(self.worker_finished) + self.threadpool.start(worker) \ No newline at end of file diff --git a/blazes_updater.py b/blazes_updater.py new file mode 100644 index 0000000..ede519c --- /dev/null +++ b/blazes_updater.py @@ -0,0 +1,257 @@ +""" +Filename: blazes_updater.py +Description: BLAZES updater executable + +Author: Tyler de Zeeuw +License: GPL-3.0 +""" + +# Built-in imports +import os +import sys +import time +import shlex +import psutil +import shutil +import platform +import subprocess +from datetime import datetime + + +PLATFORM_NAME = platform.system().lower() +APP_NAME = "blazes" + +if PLATFORM_NAME == 'darwin': + LOG_FILE = os.path.join(os.path.dirname(sys.executable), f"../../../{APP_NAME}_updater.log") +else: + LOG_FILE = os.path.join(os.getcwd(), f"{APP_NAME}_updater.log") + + +def log(msg): + with open(LOG_FILE, "a", encoding="utf-8") as f: + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + f.write(f"{timestamp} - {msg}\n") + + +def kill_all_processes_by_executable(exe_path): + terminated_any = False + exe_path = os.path.realpath(exe_path) + + if PLATFORM_NAME == 'windows': + for proc in psutil.process_iter(['pid', 'exe']): + try: + proc_exe = proc.info.get('exe') + if proc_exe and os.path.samefile(os.path.realpath(proc_exe), exe_path): + log(f"Terminating process: PID {proc.pid}") + _terminate_process(proc) + terminated_any = True + except Exception as e: + log(f"Error terminating process (Windows): {e}") + elif PLATFORM_NAME == 'linux': + for proc in psutil.process_iter(['pid', 'cmdline']): + try: + cmdline = proc.info.get('cmdline', []) + if cmdline: + proc_cmd = os.path.realpath(cmdline[0]) + if os.path.samefile(proc_cmd, exe_path): + log(f"Terminating process: PID {proc.pid}") + _terminate_process(proc) + terminated_any = True + except Exception as e: + log(f"Error terminating process (Linux): {e}") + + if not terminated_any: + log(f"No running processes found for {exe_path}") + return terminated_any + + +def _terminate_process(proc): + try: + proc.terminate() + proc.wait(timeout=10) + log(f"Process {proc.pid} terminated gracefully.") + except psutil.TimeoutExpired: + log(f"Process {proc.pid} did not terminate in time. Killing forcefully.") + proc.kill() + proc.wait(timeout=5) + log(f"Process {proc.pid} killed.") + + +def wait_for_unlock(path, timeout=100): + start_time = time.time() + while time.time() - start_time < timeout: + try: + if os.path.isdir(path): + shutil.rmtree(path) + else: + os.remove(path) + log(f"Deleted (after wait): {path}") + return + except Exception as e: + log(f"Still locked: {path} - {e}") + time.sleep(1) + log(f"Failed to delete after wait: {path}") + + +def delete_path(path): + if os.path.exists(path): + try: + if os.path.isdir(path): + shutil.rmtree(path) + log(f"Deleted directory: {path}") + else: + os.remove(path) + log(f"Deleted file: {path}") + except Exception as e: + log(f"Error deleting {path}: {e}") + + +def copy_update_files(src_folder, dest_folder, updater_name): + for item in os.listdir(src_folder): + if item.lower() == updater_name.lower(): + log(f"Skipping updater executable: {item}") + continue + s = os.path.join(src_folder, item) + d = os.path.join(dest_folder, item) + delete_path(d) + try: + if os.path.isdir(s): + shutil.copytree(s, d) + log(f"Copied folder: {s} -> {d}") + else: + shutil.copy2(s, d) + log(f"Copied file: {s} -> {d}") + except Exception as e: + log(f"Error copying {s} -> {d}: {e}") + + +def copy_update_files_darwin(src_folder, dest_folder, updater_name): + + updater_name = updater_name + ".app" + + for item in os.listdir(src_folder): + if item.lower() == updater_name.lower(): + log(f"Skipping updater executable: {item}") + continue + s = os.path.join(src_folder, item) + d = os.path.join(dest_folder, item) + delete_path(d) + try: + if os.path.isdir(s): + subprocess.check_call(["ditto", s, d]) + log(f"Copied folder with ditto: {s} -> {d}") + else: + shutil.copy2(s, d) + log(f"Copied file: {s} -> {d}") + except Exception as e: + log(f"Error copying {s} -> {d}: {e}") + + +def remove_quarantine(app_path): + script = f''' + do shell script "xattr -d -r com.apple.quarantine {shlex.quote(app_path)}" with administrator privileges with prompt "{APP_NAME} needs privileges to finish the update. (1/2)" + ''' + try: + subprocess.run(['osascript', '-e', script], check=True) + print("✅ Quarantine attribute removed.") + except subprocess.CalledProcessError as e: + print("❌ Failed to remove quarantine attribute.") + print(e) + + +def main(): + try: + log(f"[Updater] sys.argv: {sys.argv}") + + if len(sys.argv) != 3: + log(f"Invalid arguments. Usage: {APP_NAME}_updater ") + sys.exit(1) + + update_folder = sys.argv[1] + main_exe = sys.argv[2] + + # Interesting naming convention + parent_dir = os.path.dirname(os.path.abspath(main_exe)) + pparent_dir = os.path.dirname(parent_dir) + ppparent_dir = os.path.dirname(pparent_dir) + pppparent_dir = os.path.dirname(ppparent_dir) + + updater_name = os.path.basename(sys.argv[0]) + + log("Updater started.") + log(f"Update folder: {update_folder}") + log(f"Main EXE: {main_exe}") + log(f"Updater EXE: {updater_name}") + if PLATFORM_NAME == 'darwin': + log(f"Main App Folder: {ppparent_dir}") + + # Kill all instances of main app + kill_all_processes_by_executable(main_exe) + + # Wait until main_exe process is fully gone (polling) + for _ in range(20): # wait max 10 seconds + running = False + for proc in psutil.process_iter(['exe', 'cmdline']): + try: + if PLATFORM_NAME == 'windows': + proc_exe = proc.info.get('exe') + if proc_exe and os.path.samefile(os.path.realpath(proc_exe), os.path.realpath(main_exe)): + running = True + break + elif PLATFORM_NAME == 'linux': + cmdline = proc.info.get('cmdline', []) + if cmdline: + proc_cmd = os.path.realpath(cmdline[0]) + if os.path.samefile(proc_cmd, os.path.realpath(main_exe)): + running = True + break + except Exception as e: + log(f"Polling error: {e}") + if not running: + break + time.sleep(0.5) + else: + log("Warning: main executable still running after wait timeout.") + + # Delete old version files + if PLATFORM_NAME == 'darwin': + log(f'Attempting to delete {ppparent_dir}') + delete_path(ppparent_dir) + update_folder = os.path.join(sys.argv[1], f"{APP_NAME}-darwin") + copy_update_files_darwin(update_folder, pppparent_dir, updater_name) + + else: + delete_path(main_exe) + wait_for_unlock(os.path.join(parent_dir, "_internal")) + + # Copy new files excluding the updater itself + copy_update_files(update_folder, parent_dir, updater_name) + + except Exception as e: + log(f"Something went wrong: {e}") + + # Relaunch main app + try: + if PLATFORM_NAME == 'linux': + os.chmod(main_exe, 0o755) + log("Added executable bit") + + if PLATFORM_NAME == 'darwin': + os.chmod(ppparent_dir, 0o755) + log("Added executable bit") + remove_quarantine(ppparent_dir) + log(f"Removed the quarantine flag on {ppparent_dir}") + subprocess.Popen(['open', ppparent_dir, "--args", "--finish-update"]) + else: + subprocess.Popen([main_exe, "--finish-update"], cwd=parent_dir) + + log("Relaunched main app.") + except Exception as e: + log(f"Failed to relaunch main app: {e}") + + log("Updater completed. Exiting.") + sys.exit(0) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..c6a61ca --- /dev/null +++ b/main.py @@ -0,0 +1,2384 @@ +""" +Filename: main.py +Description: BLAZES main executable + +Author: Tyler de Zeeuw +License: GPL-3.0 +""" + +# Built-in imports +import os +import csv +import sys +import json +import glob +import shutil +import inspect +import platform +import traceback +from pathlib import Path +from datetime import datetime +from multiprocessing import current_process, freeze_support + +# External library imports +import numpy as np +import pandas as pd +import psutil +import joblib +import cv2 +from ultralytics import YOLO + +from updater import finish_update_if_needed, UpdateManager, LocalPendingUpdateCheckThread +from predictor import GeneralPredictor +from batch_processing import BatchProcessorDialog + +import PySide6 +from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QGraphicsView, QGraphicsScene, + QHBoxLayout, QSplitter, QLabel, QPushButton, QComboBox, QInputDialog, + QFileDialog, QScrollArea, QMessageBox, QSlider, QTextEdit) +from PySide6.QtCore import Qt, QThread, Signal, QUrl, QRectF, QPointF, QRect, QSizeF +from PySide6.QtGui import QPainter, QColor, QFont, QPen, QBrush, QAction, QKeySequence, QIcon, QTextOption +from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput +from PySide6.QtMultimediaWidgets import QGraphicsVideoItem + + +VERBOSITY = 1 +CURRENT_VERSION = "0.1.0" +APP_NAME = "blazes" +API_URL = f"https://git.research.dezeeuw.ca/api/v1/repos/tyler/{APP_NAME}/releases" +API_URL_SECONDARY = f"https://git.research2.dezeeuw.ca/api/v1/repos/tyler/{APP_NAME}/releases" +PLATFORM_NAME = platform.system().lower() + + + +def debug_print(): + if VERBOSITY: + frame = inspect.currentframe().f_back + qualname = frame.f_code.co_qualname + print(qualname) + + +# Ordered according to YOLO docs: https://docs.ultralytics.com/tasks/pose/ +JOINT_NAMES = [ + "Nose", "Left Eye", "Right Eye", "Left Ear", "Right Ear", + "Left Shoulder", "Right Shoulder", "Left Elbow", "Right Elbow", + "Left Wrist", "Right Wrist", "Left Hip", "Right Hip", + "Left Knee", "Right Knee", "Left Ankle", "Right Ankle" + ] + + +# Needs to be pointed to the FFmpeg bin folder containing avcodec-*.dll, etc. +pyside_dir = Path(PySide6.__file__).parent +if sys.platform == "win32": + # Tell Python 3.13+ where to find the FFmpeg DLLs bundled with PySide + os.add_dll_directory(str(pyside_dir)) + + +TRACK_NAMES = ["Baseline", "Live Skeleton"] + JOINT_NAMES +NUM_TRACKS = len(TRACK_NAMES) + +# TODO: Improve colors? +# Generate distinct colors for the tracks +BASE_COLORS = [QColor(180, 180, 180), QColor(0, 0, 0)] # Grey for Baseline, Black for Live +REMAINING_COLORS = [QColor.fromHsv(int((i / (NUM_TRACKS-2)) * 359), 200, 255) for i in range(NUM_TRACKS-2)] +TRACK_COLORS = BASE_COLORS + REMAINING_COLORS + + + + +class AboutWindow(QWidget): + """ + Simple About window displaying basic application information. + + Args: + parent (QWidget, optional): Parent widget of this window. Defaults to None. + """ + + def __init__(self, parent=None): + super().__init__(parent, Qt.WindowType.Window) + self.setWindowTitle(f"About {APP_NAME.upper()}") + self.resize(250, 100) + self.setStyleSheet(""" + QVBoxLayout, QWidget { + background-color: #1e1e1e; + } + QLabel { + color: #ffffff; + } + """) + + layout = QVBoxLayout() + label = QLabel(f"About {APP_NAME.upper()}", self) + label2 = QLabel("Behavioral Learning & Automated Zoned Events Suite", self) + label3 = QLabel(f"{APP_NAME.upper()} is licensed under the GPL-3.0 licence. For more information, visit https://www.gnu.org/licenses/gpl-3.0.en.html", self) + label4 = QLabel(f"Version v{CURRENT_VERSION}") + + layout.addWidget(label) + layout.addWidget(label2) + layout.addWidget(label3) + layout.addWidget(label4) + + self.setLayout(layout) + + + +class UserGuideWindow(QWidget): + """ + Simple User Guide window displaying basic information on how to use the software. + + Args: + parent (QWidget, optional): Parent widget of this window. Defaults to None. + """ + + def __init__(self, parent=None): + super().__init__(parent, Qt.WindowType.Window) + self.setWindowTitle(f"User Guide - {APP_NAME.upper()}") + self.resize(250, 100) + self.setStyleSheet(""" + QVBoxLayout, QWidget { + background-color: #1e1e1e; + } + QLabel { + color: #ffffff; + } + """) + + layout = QVBoxLayout() + label = QLabel("Hmmm...", self) + label2 = QLabel("Nothing to see here yet.", self) + + label3 = QLabel(f"For more information, visit the Git wiki page here.", self) + label3.setTextFormat(Qt.TextFormat.RichText) + label3.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction) + label3.setOpenExternalLinks(True) + layout.addWidget(label) + layout.addWidget(label2) + layout.addWidget(label3) + + self.setLayout(layout) + + + +class PoseAnalyzerWorker(QThread): + progress = Signal(str) + finished_data = Signal(dict) + + def __init__(self, video_path, obs_info=None, predictor=None): + debug_print() + super().__init__() + self.video_path = video_path + self.obs_info = obs_info + self.predictor = predictor + self.pose_df = pd.DataFrame() + + + def get_best_infant_match(self, results, w, h, prev_track_id): + debug_print() + if not results[0].boxes or results[0].boxes.id is None: + return None, None, None, None + ids = results[0].boxes.id.int().cpu().tolist() + kpts = results[0].keypoints.xy.cpu().numpy() + confs = results[0].keypoints.conf.cpu().numpy() + best_idx, best_score = -1, -1 + for i, k in enumerate(kpts): + vis = np.sum(confs[i] > 0.5) + valid = k[confs[i] > 0.5] + dist = np.linalg.norm(np.mean(valid, axis=0) - [w/2, h/2]) if len(valid) > 0 else 1000 + score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0) + if score > best_score: + best_score, best_idx = score, i + if best_idx == -1: + return None, None, None, None + return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx + + + def _merge_json_observations(self, timeline_events, fps): + """Restores the grouping and block-pairing logic from the observation files.""" + debug_print() + if not self.obs_info: + return + + self.progress.emit("Merging JSON Observations...") + json_path, subkey = self.obs_info + + # try: + # with open(json_path, 'r') as f: + # full_json = json.load(f) + + # # Extract events for the specific subkey (e.g., 'Participant_01') + # raw_obs_events = full_json["observations"][subkey]["events"] + # raw_obs_events.sort(key=lambda x: x[0]) # Sort by timestamp + + # # Group frames by label + # obs_groups = {} + # for ev in raw_obs_events: + # time_sec, _, label, special = ev[0], ev[1], ev[2], ev[3] + # frame = int(time_sec * fps) + # if label not in obs_groups: + # obs_groups[label] = [] + # obs_groups[label].append(frame) + + # # Convert groups of frames into (Start, End) blocks + # for label, frames in obs_groups.items(): + # track_name = f"OBS: {label}" + # processed_blocks = [] + + # # Step by 2 to create start/end pairs + # for i in range(0, len(frames) - 1, 2): + # start_f = frames[i] + # end_f = frames[i+1] + # processed_blocks.append((start_f, end_f, "Moderate", "Manual")) + + # # Register the track globally if it's new + # if track_name not in TRACK_NAMES: + # TRACK_NAMES.append(track_name) + # TRACK_COLORS.append(QColor("#AA00FF")) # Purple for Observations + + # timeline_events[track_name] = processed_blocks + + # except Exception as e: + # print(f"Error parsing JSON Observations: {e}") + + + try: + with open(json_path, 'r') as f: + full_json = json.load(f) + + raw_obs_events = full_json["observations"][subkey]["events"] + raw_obs_events.sort(key=lambda x: x[0]) + + # NEW LOGIC: Use a dictionary to store frames for specific track names + # track_name -> [list of frames] + obs_groups = {} + + for ev in raw_obs_events: + # ev structure: [time_sec, unknown, label, special] + time_sec, label, special = ev[0], ev[2], ev[3] + frame = int(time_sec * fps) + + # Determine which tracks this event belongs to + target_tracks = [] + + if special == "Left": + target_tracks.append(f"OBS: {label} (Left)") + elif special == "Right": + target_tracks.append(f"OBS: {label} (Right)") + elif special == "Both": + target_tracks.append(f"OBS: {label} (Left)") + target_tracks.append(f"OBS: {label} (Right)") + else: + # No special or unrecognized value + target_tracks.append(f"OBS: {label}") + + # Add the frame to all applicable tracks + for t_name in target_tracks: + if t_name not in obs_groups: + obs_groups[t_name] = [] + obs_groups[t_name].append(frame) + + # Convert frame groups into (Start, End) blocks + for track_name, frames in obs_groups.items(): + processed_blocks = [] + + # Step by 2 to create start/end pairs (ensures matching pairs per track) + + if "Sync" in track_name and len(frames) == 1: + start_f = frames[0] + end_f = start_f + 1 # Give it a visible width on the timeline + processed_blocks.append((start_f, end_f, "Moderate", "Manual")) + + else: + for i in range(0, len(frames) - 1, 2): + start_f = frames[i] + end_f = frames[i+1] + processed_blocks.append((start_f, end_f, "Moderate", "Manual")) + + # Register the track in global lists if not already there + if track_name not in TRACK_NAMES: + TRACK_NAMES.append(track_name) + # Using Purple for Observations + TRACK_COLORS.append(QColor("#AA00FF")) + + timeline_events[track_name] = processed_blocks + + except Exception as e: + print(f"Error parsing JSON Observations: {e}") + + + def _run_existing_ml_models(self, z_kps, dirs, raw_kpts): + debug_print() + """ + Scans for trained models and generates timeline tracks for each. + """ + ai_events = {} + + # 1. Match the pattern from your GeneralPredictor: {Target}_rf.pkl + model_files = glob.glob("*_rf.pkl") + print(f"DEBUG: Found model files: {model_files}") + + for model_path in model_files: + try: + # Extract Target (e.g., "Mouthing" from "Mouthing_rf.pkl") + base_name = model_path.split("_rf.pkl")[0] + target = base_name.replace("ml_", "", 1) + track_name = f"AI: {target}" + + self.progress.emit(f"Loading AI Observations for {target}...") + + + # 2. Match the Scaler naming from calculate_and_train: + # {target}_random_forest_scaler.pkl + scaler_path = f"{base_name}_rf_scaler.pkl" + + if not os.path.exists(scaler_path): + print(f"DEBUG: Skipping {target}, scaler not found at {scaler_path}") + continue + + # Load assets + model = joblib.load(model_path) + scaler = joblib.load(scaler_path) + + # 3. Feature extraction (On-the-fly) + all_features = [] + # We must set the predictor's target so format_features uses the correct ACTIVITY_MAP + self.predictor.current_target = target + + for f_idx in range(len(z_kps)): + feat = self.predictor.format_features(z_kps[f_idx], dirs[f_idx], raw_kpts[f_idx]) + all_features.append(feat) + + # 4. Inference + X = np.array(all_features) + X_scaled = scaler.transform(X) + predictions = model.predict(X_scaled) + + # 5. Convert binary 0/1 to blocks + processed_blocks = [] + start_f = None + + for f_idx, val in enumerate(predictions): + if val == 1 and start_f is None: + start_f = f_idx + elif val == 0 and start_f is not None: + # [start, end, severity, direction] + processed_blocks.append((start_f, f_idx - 1, "Large", "AI")) + start_f = None + + if start_f is not None: + processed_blocks.append((start_f, len(predictions)-1, "Large", "AI")) + + # 6. Global Registration + if track_name not in TRACK_NAMES: + TRACK_NAMES.append(track_name) + # Ensure TRACK_COLORS has an entry for this new track + TRACK_COLORS.append(QColor("#00FF00")) + + ai_events[track_name] = processed_blocks + print(f"DEBUG: Successfully generated {len(processed_blocks)} blocks for {track_name}") + + except Exception as e: + print(f"Inference Error for {model_path}: {e}") + + return ai_events + + + def classify_delta(self, z): + # debug_print() + z_abs = abs(z) + if z_abs < 1: return "Rest" + elif z_abs < 2: return "Small" + elif z_abs < 3: return "Moderate" + else: return "Large" + + + def _save_pose_cache(self, path, data): + """ + Saves the raw YOLO keypoints and confidence scores to a CSV. + Each row represents one frame, flattened from (17, 3) to (51,). + """ + try: + with open(path, 'w', newline='') as f: + writer = csv.writer(f) + + # Create the descriptive header + header = [] + for joint in JOINT_NAMES: + # Replace spaces with underscores for better compatibility with other tools + header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"]) + + writer.writerow(header) + + # Write the frame data + for frame_data in data: + # frame_data is (17, 3), flatten to (51,) + writer.writerow(frame_data.flatten()) + + print(f"DEBUG: Pose cache saved with joint headers at {path}") + except Exception as e: + print(f"ERROR: Could not save pose cache: {e}") + + + def run(self): + debug_print() + # --- PHASE 1: VIDEO SETUP & POSE EXTRACTION --- + cap = cv2.VideoCapture(self.video_path) + fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + raw_kps_per_frame = [] + csv_storage_data = [] + valid_mask = [] + pose_cache = self.video_path.rsplit('.', 1)[0] + "_pose_raw.csv" + + if os.path.exists(pose_cache): + self.progress.emit("Loading cached kinematic data...") + with open(pose_cache, 'r') as f: + reader = csv.reader(f) + next(reader) + for row in reader: + full_data = np.array([float(x) for x in row]).reshape(17, 3) + kp = full_data[:, :2] + raw_kps_per_frame.append(kp) + csv_storage_data.append(full_data) + valid_mask.append(np.any(kp)) + else: + self.progress.emit("Detecting poses with YOLO...") + model = YOLO("yolov8n-pose.pt") + prev_track_id = None + for i in range(total_frames): + ret, frame = cap.read() + if not ret: break + results = model.track(frame, persist=True, verbose=False) + track_id, kp, confs, _ = self.get_best_infant_match(results, width, height, prev_track_id) + if kp is not None: + prev_track_id = track_id + raw_kps_per_frame.append(kp) + csv_storage_data.append(np.column_stack((kp, confs))) + valid_mask.append(True) + else: + raw_kps_per_frame.append(np.zeros((17, 2))) + csv_storage_data.append(np.zeros((17, 3))) + valid_mask.append(False) + if i % 50 == 0: self.progress.emit(f"YOLO: {int((i/total_frames)*100)}%") + self._save_pose_cache(pose_cache, csv_storage_data) + + cap.release() + actual_len = len(raw_kps_per_frame) + + flattened_rows = [] + for frame_array in csv_storage_data: + # frame_array is (17, 3) -> flatten to (51,) + flattened_rows.append(frame_array.flatten()) + + columns = [] + for name in JOINT_NAMES: + columns.extend([f"{name}_x", f"{name}_y", f"{name}_conf"]) + + # Store this so the Inspector can access it instantly in memory + self.pose_df = pd.DataFrame(flattened_rows, columns=columns) + + # --- PHASE 2: KINEMATICS & Z-SCORES --- + self.progress.emit("Calculating Kinematics...") + analysis_kpts = [] + for kp in raw_kps_per_frame: + pelvis = (kp[11] + kp[12]) / 2 + analysis_kpts.append(kp - pelvis) + + valid_data = [analysis_kpts[i] for i, v in enumerate(valid_mask) if v] + if valid_data: + stacked = np.stack(valid_data) + baseline_mean = np.mean(stacked, axis=0) + baseline_std = np.std(np.linalg.norm(stacked - baseline_mean, axis=2), axis=0) + 1e-6 + else: + baseline_mean, baseline_std = np.zeros((17, 2)), np.ones(17) + + np_raw_kps = np.array(raw_kps_per_frame) + np_z_kps = np.array([np.linalg.norm(kp - baseline_mean, axis=1) / baseline_std for kp in analysis_kpts]) + + # Calculate directions (Assume you have a method for this or use a dummy for now) + # Using placeholder empty strings to prevent errors in track generation + np_dirs = np.full((actual_len, 17), "", dtype=object) + + # --- PHASE 3: TIMELINE GENERATION --- + # Initialize dictionary with ALL global track names to prevent KeyErrors + timeline_events = {name: [] for name in TRACK_NAMES} + + # 1. Kinematic Events (The joint tracks) + for j_idx, joint_name in enumerate(JOINT_NAMES): + current_block = None + for f_idx in range(actual_len): + severity = self.classify_delta(np_z_kps[f_idx, j_idx]) + if severity != "Rest": + if current_block and current_block[2] == severity: + current_block[1] = f_idx + else: + current_block = [f_idx, f_idx, severity, ""] + timeline_events[joint_name].append(current_block) + else: + current_block = None + + # 2. JSON Observations + self._merge_json_observations(timeline_events, fps) + + # 3. AI Inferred Events + ai_events = self._run_existing_ml_models(np_z_kps, np_dirs, np_raw_kps) + timeline_events.update(ai_events) + + # --- PHASE 4: EMIT --- + data = { + "video_path": self.video_path, + "fps": fps, + "total_frames": actual_len, + "width": width, "height": height, + "events": timeline_events, + "raw_kps": np_raw_kps, + "z_kps": np_z_kps, + "directions": np_dirs, + "baseline_kp_mean": baseline_mean + } + self.progress.emit("Analysis Complete!") + self.finished_data.emit(data) + + + + + + + + + + +# ========================================== +# TIMELINE WIDGET +# ========================================== +class TimelineWidget(QWidget): + seek_requested = Signal(int) + visibility_changed = Signal(set) + track_selected = Signal(str) + + def __init__(self): + debug_print() + super().__init__() + self.data = None + self.current_frame = 0 + self.zoom_factor = 1.0 # Pixels per frame + self.label_width = 160 # Fixed gutter for track names + self.track_height = 25 + self.ruler_height = 20 + self.scrollbar_buffer = 2 # Extra space for the horizontal scrollbar + self.hidden_tracks = set() + self.sync_offset = 0.0 + self.sync_fps = 30.0 + # Calculate total required height + self.total_content_height = (NUM_TRACKS * self.track_height) + self.ruler_height + self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer) + + + def set_sync_params(self, offset_seconds, fps=None): + """ + Updates the temporal shift parameters and refreshes the UI. + """ + debug_print() + self.sync_offset = float(offset_seconds) + + # Only update FPS if a valid value is provided, + # otherwise keep the existing data/video FPS + if fps and fps > 0: + self.sync_fps = float(fps) + elif self.data and "fps" in self.data: + self.sync_fps = float(self.data["fps"]) + + print(f"DEBUG: Timeline Sync Set - Offset: {self.sync_offset}s, FPS: {self.sync_fps}") + + # Trigger paintEvent to redraw the blocks in their new shifted positions + self.update() + + + def set_zoom(self, factor): + debug_print() + if not self.data: return + + # Calculate MIN zoom: The zoom required to make the video fit the width exactly + # (Available Width - Sidebar) / Total Frames + available_w = self.parent().width() - self.label_width if self.parent() else 800 + min_zoom = available_w / self.data["total_frames"] + + # Clamp: Don't zoom out past the video end, don't zoom in to infinity + self.zoom_factor = max(min_zoom, min(factor, 50.0)) + self.update_geometry() + + + def get_all_binary_labels(self, offset_seconds=0.0, fps=30.0): + """ + Extracts binary labels for ALL tracks in self.data["events"]. + Returns a dict: {'OBS: Mouthing': [0,1,0...], 'OBS: Kicking': [0,0,1...]} + """ + debug_print() + all_labels = {} + + if not self.data or "events" not in self.data: + return all_labels + + total_frames = self.data.get("total_frames", 0) + if total_frames == 0: + return all_labels + + frame_shift = int(offset_seconds * fps) + + for track_name in self.data["events"]: + sequence = np.zeros(total_frames, dtype=int) + + for event in self.data["events"][track_name]: + start_f = int(event[0]) - frame_shift + end_f = int(event[1]) - frame_shift + + # Clamp values + start_idx = max(0, min(start_f, total_frames - 1)) + end_idx = max(0, min(end_f, total_frames)) + + if start_idx < end_idx: + sequence[start_idx:end_idx] = 1 + + all_labels[track_name] = sequence.tolist() # Convert to list for easier storage + + return all_labels + + + def update_geometry(self): + debug_print() + + if self.data: + # Width is sidebar + (frames * zoom) + total_w = self.label_width + int(self.data["total_frames"] * self.zoom_factor) + self.setFixedWidth(total_w) + self.update() + + def wheelEvent(self, event): + debug_print() + + if event.modifiers() == Qt.ControlModifier: + delta = event.angleDelta().y() + # Zoom by 10% per notch + zoom_change = 1.1 if delta > 0 else 0.9 + self.set_zoom(self.zoom_factor * zoom_change) + else: + # Let the scroll area handle normal vertical scrolling + super().wheelEvent(event) + + # --- NEW: CTRL + Plus / Minus / Zero --- + def keyPressEvent(self, event): + debug_print() + + if event.modifiers() == Qt.ControlModifier: + if event.key() == Qt.Key_Plus or event.key() == Qt.Key_Equal: + self.set_zoom(self.zoom_factor * 1.2) + elif event.key() == Qt.Key_Minus: + self.set_zoom(self.zoom_factor * 0.8) + elif event.key() == Qt.Key_0: + self.set_zoom(1.0) # Reset zoom + else: + super().keyPressEvent(event) + + def set_data(self, data): + debug_print() + + self.data = data + self.total_content_height = (len(TRACK_NAMES) * self.track_height) + self.ruler_height + self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer) + self.update_geometry() + + def set_playhead(self, frame): + debug_print() + old_x = self.label_width + (self.current_frame * self.zoom_factor) + self.current_frame = frame + new_x = self.label_width + (self.current_frame * self.zoom_factor) + self.ensure_playhead_visible() + self.update(int(old_x - 5), 0, 10, self.height()) + self.update(int(new_x - 5), 0, 10, self.height()) + + def ensure_playhead_visible(self): + debug_print() + + """Auto-scrolls the scroll area if the playhead leaves the viewport.""" + # Find the QScrollArea parent + scroll_area = self.parent().parent() + if not isinstance(scroll_area, QScrollArea): return + + scrollbar = scroll_area.horizontalScrollBar() + view_width = scroll_area.viewport().width() + + # Playhead position in pixels + px = self.label_width + int(self.current_frame * self.zoom_factor) + + # Current scroll position + scroll_x = scrollbar.value() + + # If playhead is beyond the right edge of visible area + if px > (scroll_x + view_width): + # Shift scroll so playhead is at the left (plus sidebar) + scrollbar.setValue(px - self.label_width) + + # If playhead is behind the left edge (e.g. user seeked backwards) + elif px < (scroll_x + self.label_width): + scrollbar.setValue(px - self.label_width) + + def mousePressEvent(self, event): + debug_print() + + if not self.data or event.button() != Qt.LeftButton: + return + + pos_x = event.position().x() + pos_y = event.position().y() + scroll_area = self.parent().parent() + scroll_x = scroll_area.horizontalScrollBar().value() + + # 1. CALCULATE FRAME + relative_x = pos_x - self.label_width + frame = int(relative_x / self.zoom_factor) + frame = max(0, min(frame, self.data["total_frames"] - 1)) + + # 2. IF CLICKED SIDEBAR: Toggle Visibility (No Scrubbing) + if pos_x < scroll_x + self.label_width: + relative_y = pos_y - self.ruler_height + track_idx = int(relative_y // self.track_height) + if 0 <= track_idx < len(TRACK_NAMES): + name = TRACK_NAMES[track_idx] + if name in self.hidden_tracks: self.hidden_tracks.remove(name) + else: self.hidden_tracks.add(name) + self.visibility_changed.emit(self.hidden_tracks) + self.update() + return # Exit early; don't set is_scrubbing + + # 3. IF CLICKED RULER OR DATA AREA: Start Scrubbing + self.is_scrubbing = True + self.seek_requested.emit(frame) + + # Handle track selection if in the data area + if pos_y >= self.ruler_height: + track_idx = int((pos_y - self.ruler_height) // self.track_height) + if 0 <= track_idx < len(TRACK_NAMES): + self.track_selected.emit(TRACK_NAMES[track_idx]) + self.selected_track_idx = track_idx + self.update() + else: + # Clicked ruler + self.selected_track_idx = -1 + self.track_selected.emit("") + self.update() + + def mouseMoveEvent(self, event): + debug_print() + + # This only fires while moving if a button is held down by default + if self.is_scrubbing: + self.update_frame_from_mouse(event.position().x()) + + def mouseReleaseEvent(self, event): + debug_print() + + if event.button() == Qt.LeftButton: + self.is_scrubbing = False + + def update_frame_from_mouse(self, x_pos): + """Helper to calculate frame and emit the seek signal.""" + debug_print() + relative_x = x_pos - self.label_width + frame = int(relative_x / self.zoom_factor) + frame = max(0, min(frame, self.data["total_frames"] - 1)) + + # We emit seek_requested so the Video Player and Premiere class sync up + self.seek_requested.emit(frame) + + + def paintEvent(self, event): + debug_print() + if not self.data: return + + dirty_rect = event.rect() + painter = QPainter(self) + + # 1. Determine current scroll position to keep labels sticky + scroll_area = self.parent().parent() + scroll_x = 0 + if isinstance(scroll_area, QScrollArea): + scroll_x = scroll_area.horizontalScrollBar().value() + + w, h = self.width(), self.height() + total_f = self.data["total_frames"] + fps = self.data.get("fps", 30) + offset_y = 20 + + # 2. DRAW DATA AREA (Events and Playhead) + # --- 2. DRAW DATA AREA (Muted Patterns + Events + Playhead) --- + sync_off = getattr(self, "sync_offset", 0.0) + sync_fps = getattr(self, "sync_fps", fps) + frame_shift = int(sync_off * sync_fps) + + for i, name in enumerate(TRACK_NAMES): + y = offset_y + (i * self.track_height) + is_hidden = name in self.hidden_tracks + + if y + self.track_height < dirty_rect.top() or y > dirty_rect.bottom(): + continue + + # A. Draw Muted/Disabled Background Pattern + + if is_hidden: + # Calculate the visible rectangle for this track to the right of the sidebar + mute_rect = QRectF(self.label_width, y, w - self.label_width, self.track_height) + # Fill with a dark "disabled" base + painter.fillRect(mute_rect, QColor(40, 40, 40)) + # Add the Cross-Hatch Pattern + pattern_brush = QBrush(QColor(60, 60, 60, 100), Qt.DiagCrossPattern) + painter.fillRect(mute_rect, pattern_brush) + + # B. Draw Event Blocks + base_color = TRACK_COLORS[i] + for start_f, end_f, severity, direction in self.data["events"][name]: + # x_start = self.label_width + (start_f * self.zoom_factor) + # x_end = self.label_width + (end_f * self.zoom_factor) + + if "AI:" in name: + # AI predictions are already calculated in video-time, NO SHIFT + shifted_start, shifted_end = start_f, end_f + + else: + shifted_start = start_f - frame_shift + shifted_end = end_f - frame_shift + + x_start = self.label_width + (shifted_start * self.zoom_factor) + x_end = self.label_width + (shifted_end * self.zoom_factor) + + # Performance optimization: skip drawing if off-screen + if x_end < scroll_x or x_start > scroll_x + w: + continue + + if x_end < dirty_rect.left() or x_start > dirty_rect.right(): + continue + + # If hidden, make the event block very faint/transparent + if is_hidden: + color = QColor(120, 120, 120, 40) # Muted Grey + else: + alpha = 80 if severity == "Small" else 160 if severity == "Moderate" else 255 + color = QColor(base_color) + color.setAlpha(alpha) + + painter.fillRect(QRectF(x_start, y + 2, max(1, x_end - x_start), self.track_height - 4), color) + # Draw Playhead + playhead_x = self.label_width + (self.current_frame * self.zoom_factor) + painter.setPen(QPen(QColor(255, 0, 0), 2)) + painter.drawLine(playhead_x, 0, playhead_x, h) + + # 3. DRAW STICKY SIDEBAR (Pinned to the left edge of the viewport) + # We draw this AFTER the data so it covers the blocks as they scroll past + sidebar_rect = QRect(scroll_x, 0, self.label_width, h) + painter.fillRect(sidebar_rect, QColor(30, 30, 30)) # Solid background + + # Ruler segment for the sidebar area + painter.fillRect(scroll_x, 0, self.label_width, offset_y, QColor(45, 45, 45)) + + for i, name in enumerate(TRACK_NAMES): + y = offset_y + (i * self.track_height) + is_hidden = name in self.hidden_tracks + # Grid Line + painter.setPen(QColor(60, 60, 60)) + painter.drawLine(scroll_x, y, scroll_x + w, y) + + # Sticky Label Text + if is_hidden: + # Very dark grey to show it's "OFF" + text_color = QColor(70, 70, 70) + else: + # Bright white/grey to show it's "ON" + text_color = QColor(220, 220, 220) + + painter.setPen(text_color) + painter.setFont(QFont("Arial", 8, QFont.Bold)) + painter.drawText(scroll_x + 10, y + 17, name) + + # 4. DRAW TIME RULER TICKS (Right of the sticky sidebar) + target_spacing_px = 120 + + # Available units in frames: 1, 5, 15, 30 (1s), 150 (5s), 300 (10s), 1800 (1min) + possible_units = [1, 5, 15, 30, 150, 300, 900, 1800] + + # Find the smallest unit that results in at least target_spacing_px + tick_interval = possible_units[-1] + for unit in possible_units: + if (unit * self.zoom_factor) >= target_spacing_px: + tick_interval = unit + break + + # 2. DRAW BACKGROUNDS + painter.fillRect(0, 0, w, 20, QColor(45, 45, 45)) # Ruler Bar + + # 3. DRAW TICKS AND TIME LABELS + painter.setPen(QColor(180, 180, 180)) + painter.setFont(QFont("Segoe UI", 7)) + + # Sub-ticks (draw 5 small lines for every 1 major interval) + sub_interval = max(1, tick_interval // 5) + + # Start loop from 0 to total frames + for f in range(0, total_f + 1, sub_interval): + x = self.label_width + int(f * self.zoom_factor) + + # Optimization: Don't draw if off-screen + if x < scroll_x: continue + if x > scroll_x + w: break + + if f % tick_interval == 0: + # Major Tick + painter.drawLine(x, 10, x, 20) + + # Format Label: MM:SS or SS:FF + total_seconds = f / fps + minutes = int(total_seconds // 60) + seconds = int(total_seconds % 60) + frames = int(f % fps) + + if tick_interval < fps: + time_str = f"{seconds:02d}:{frames:02d}f" + elif minutes > 0: + time_str = f"{minutes:02d}m:{seconds:02d}s" + else: + time_str = f"{seconds}s" + + painter.drawText(x + 4, 12, time_str) + else: + # Minor Tick + painter.drawLine(x, 16, x, 20) + + + def get_ai_extractions(self): + """ + Processes timeline data for AI tracks and specific OBS sync events. + """ + debug_print() + fps = self.data.get("fps", 30.0) + + extraction_data = { + "metadata": { + "fps": fps, + "total_frames": self.data.get("total_frames", 0), + "track_summaries": {} + }, + "obs": {}, # New top-level key for specific OBS events + "ai_tracks": {} + } + + if not self.data or "events" not in self.data: + return extraction_data + + # 1. Extract the OBS: Time Sync event specifically + sync_key = "OBS: Time Sync" + if sync_key in self.data["events"]: + sync_blocks = self.data["events"][sync_key] + # Convert blocks to a list of dicts for the JSON + extraction_data["obs"][sync_key] = [ + { + "start_frame": b[0], + "end_frame": b[1], + "start_time_sec": round(b[0] / fps, 3), + "end_time_sec": round(b[1] / fps, 3) + } for b in sync_blocks + ] + + # 2. Process AI Tracks + for track_name, blocks in self.data["events"].items(): + if track_name.startswith("AI:"): + track_results = [] + track_total = 0 + track_long = 0 + + for block in blocks: + start_f, end_f = block[0], block[1] + severity = block[2] if len(block) > 2 else "Normal" + + start_sec = round(start_f / fps, 3) + end_sec = round(end_f / fps, 3) + duration = round(end_sec - start_sec, 3) + + track_total += 1 + if duration > 2.0: + track_long += 1 + + track_results.append({ + "start_frame": int(start_f), + "end_frame": int(end_f), + "start_time_sec": start_sec, + "end_time_sec": end_sec, + "duration_sec": duration, + "severity": severity + }) + + extraction_data["ai_tracks"][track_name] = track_results + extraction_data["metadata"]["track_summaries"][track_name] = { + "event_count": track_total, + "long_events_over_2s": track_long, + "total_duration_sec": round(sum(r["duration_sec"] for r in track_results), 3) + } + + return extraction_data + + +class SkeletonOverlay(QWidget): + def __init__(self, parent=None): + debug_print() + super().__init__(parent) + self.setAttribute(Qt.WA_TransparentForMouseEvents) # Clicks pass through to video + self.data = None + self.current_frame = 0 + self.hidden_tracks = set() + # Use your saved SKELETON_CONNECTIONS logic + self.connections = [ + (5, 7), (7, 9), (6, 8), (8, 10), (5, 6), (5, 11), + (6, 12), (11, 12), (11, 13), (13, 15), (12, 14), (14, 16) + ] + self.KP_MAP = { + "nose": 0, "LE": 1, "RE": 2, "Lear": 3, "Rear": 4, + "Lshoulder": 5, "Rshoulder": 6, "Lelbow": 7, "Relbow": 8, + "Lwrist": 9, "Rwrist": 10, "Lhip": 11, "Rhip": 12, + "Lknee": 13, "Rknee": 14, "Lankle": 15, "Rankle": 16 + } + self.CONNECTIONS = [ + ("nose", "LE"), ("nose", "RE"), ("LE", "Lear"), ("RE", "Rear"), + ("Lshoulder", "Rshoulder"), ("Lshoulder", "Lelbow"), ("Lelbow", "Lwrist"), + ("Rshoulder", "Relbow"), ("Relbow", "Rwrist"), ("Lshoulder", "Lhip"), + ("Rshoulder", "Rhip"), ("Lhip", "Rhip"), ("Lhip", "Lknee"), + ("Lknee", "Lankle"), ("Rhip", "Rknee"), ("Rknee", "Rankle") + ] + + + def set_frame(self, frame_idx): + debug_print() + self.current_frame = frame_idx + self.update() + + + def set_hidden_tracks(self, hidden_set): + debug_print() + self.hidden_tracks = hidden_set + self.update() + + + def set_data(self, data): + debug_print() + self.data = data + self.update() + + + def paintEvent(self, event): + debug_print() + if not self.data or 'raw_kps' not in self.data: + return + + painter = QPainter(self) + painter.setRenderHint(QPainter.Antialiasing) + + v_w, v_h = self.data['width'], self.data['height'] + w, h = self.width(), self.height() + scale_x, scale_y = w / v_w, h / v_h + + current_f = self.current_frame + kp_live = self.data['raw_kps'][current_f] + + # --- 1. MODIFIED TRACK STATUS (Respects Visibility) --- + def get_track_status(track_name): + # If the user greyed out this track in the timeline, act as if it's inactive + if track_name in self.hidden_tracks: + return None + if track_name not in self.data['events']: + return None + for start, end, severity, direction in self.data['events'][track_name]: + if start <= current_f <= end: + idx = TRACK_NAMES.index(track_name) + color = QColor(TRACK_COLORS[idx]) + alpha = 80 if severity == "Small" else 160 if severity == "Moderate" else 255 + color.setAlpha(alpha) + return color + return None + + ANGLE_SEGMENTS = { + "L_sh": [("Lhip", "Lshoulder"), ("Lshoulder", "Lelbow")], + "R_sh": [("Rhip", "Rshoulder"), ("Rshoulder", "Relbow")], + "L_el": [("Lshoulder", "Lelbow"), ("Lelbow", "Lwrist")], + "R_el": [("Rshoulder", "Relbow"), ("Relbow", "Rwrist")], + "L_leg": [("Lhip", "Lknee"), ("Lknee", "Lankle")], + "R_leg": [("Rhip", "Rknee"), ("Rknee", "Rankle")] + } + + # --- 2. DRAW BASELINE (Only if not hidden) --- + if "Baseline" not in self.hidden_tracks: + idx_l_hip, idx_r_hip = self.KP_MAP["Lhip"], self.KP_MAP["Rhip"] + pelvis_live = (kp_live[idx_l_hip] + kp_live[idx_r_hip]) / 2 + kp_baseline = self.data['baseline_kp_mean'] + pelvis_live + + painter.setPen(QPen(QColor(200, 200, 200, 200), 2, Qt.DashLine)) + for s_name, e_name in self.CONNECTIONS: + p1 = QPointF(kp_baseline[self.KP_MAP[s_name]][0] * scale_x, kp_baseline[self.KP_MAP[s_name]][1] * scale_y) + p2 = QPointF(kp_baseline[self.KP_MAP[e_name]][0] * scale_x, kp_baseline[self.KP_MAP[e_name]][1] * scale_y) + painter.drawLine(p1, p2) + + # --- 3. DRAW LIVE SKELETON (Only if not hidden) --- + + # CONNECTIONS + for s_name, e_name in self.CONNECTIONS: + active_color = None + for angle_track, segments in ANGLE_SEGMENTS.items(): + if (s_name, e_name) in segments or (e_name, s_name) in segments: + active_color = get_track_status(angle_track) + if active_color: break + + p1 = QPointF(kp_live[self.KP_MAP[s_name]][0] * scale_x, kp_live[self.KP_MAP[s_name]][1] * scale_y) + p2 = QPointF(kp_live[self.KP_MAP[e_name]][0] * scale_x, kp_live[self.KP_MAP[e_name]][1] * scale_y) + + if active_color: + # Active events ALWAYS draw + painter.setPen(QPen(active_color, 8, Qt.SolidLine, Qt.RoundCap)) + painter.drawLine(p1, p2) + elif "Live Skeleton" not in self.hidden_tracks: + # Black lines ONLY draw if Live Skeleton is ON + painter.setPen(QPen(Qt.black, 4, Qt.SolidLine, Qt.RoundCap)) + painter.drawLine(p1, p2) + + # DOTS + ANGLE_VERTEX_MAP = { + "L_sh": "Lshoulder", "R_sh": "Rshoulder", + "L_el": "Lelbow", "R_el": "Relbow", + "L_leg": "Lknee", "R_leg": "Rknee" + } + + for kp_name, kp_idx in self.KP_MAP.items(): + pt = QPointF(kp_live[kp_idx][0] * scale_x, kp_live[kp_idx][1] * scale_y) + + # Check for Point Event (Skip if hidden via get_track_status) + point_color = get_track_status(kp_name) + + if point_color: + painter.setBrush(point_color) + painter.setPen(QPen(Qt.white, 0.7)) + painter.drawEllipse(pt, 5, 5) + continue + + # Check for Angle Event + angle_color = None + for angle_track, vertex_name in ANGLE_VERTEX_MAP.items(): + if kp_name == vertex_name: + angle_color = get_track_status(angle_track) + if angle_color: break + + if angle_color: + painter.setBrush(angle_color) + painter.setPen(Qt.NoPen) + painter.drawEllipse(pt, 4, 4) + + elif "Live Skeleton" not in self.hidden_tracks: + painter.setBrush(Qt.black) + painter.setPen(Qt.NoPen) + painter.drawEllipse(pt, 4, 4) + + +class VideoView(QGraphicsView): + resized = Signal() + + def __init__(self, scene, parent=None): + debug_print() + super().__init__(scene, parent) + self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.setFrameStyle(0) + self.setStyleSheet("background: black; border: none;") + self.setAlignment(Qt.AlignCenter) + + def resizeEvent(self, event): + debug_print() + super().resizeEvent(event) + self.resized.emit() + + +# ========================================== +# MAIN PREMIERE WINDOW +# ========================================== +class PremiereWindow(QMainWindow): + def __init__(self): + debug_print() + super().__init__() + self.setWindowTitle("Pose Analysis Timeline") + self.resize(1200, 900) + + self.about = None + self.help = None + + self.platform_suffix = "-" + PLATFORM_NAME + + self.updater = UpdateManager( + main_window=self, + api_url=API_URL, + api_url_sec=API_URL_SECONDARY, + current_version=CURRENT_VERSION, + platform_name=PLATFORM_NAME, + platform_suffix=self.platform_suffix, + app_name=APP_NAME + ) + + # self.setStyleSheet("background-color: #1e1e1e; color: #ffffff;") + self.setStyleSheet(""" + QMainWindow, QWidget#centralWidget { + background-color: #1e1e1e; + } + QLabel, QStatusBar, QMenuBar { + color: #ffffff; + } + /* Target the Timeline specifically */ + TimelineWidget { + background-color: #1e1e1e; + border: 1px solid #333333; + } + /* Button styling with Grey borders */ + QDialog, QMessageBox, QFileDialog { + background-color: #2b2b2b; + } + QDialog QLabel, QMessageBox QLabel { + color: #ffffff; + } + QPushButton { + background-color: #2b2b2b; + color: #ffffff; + border: 1px solid #555555; /* Subtle Grey border */ + border-radius: 3px; + padding: 4px; + } + QPushButton:hover { + background-color: #3d3d3d; + border-color: #888888; /* Brightens border on hover */ + } + QPushButton:pressed { + background-color: #111111; + } + QPushButton:disabled { + border-color: #333333; + color: #444444; + } + /* Splitter/Divider styling */ + QSplitter::handle { + background-color: #333333; /* Dark grey dividers */ + } + QSplitter::handle:horizontal { + width: 2px; + } + QSplitter::handle:vertical { + height: 2px; + } + /* ScrollArea styling to keep it dark */ + QScrollArea, QScrollArea > QWidget > QWidget { + background-color: #1e1e1e; + border: none; + } + """) + + self.predictor = GeneralPredictor() + + self.file_path = None + self.obs_file = None + self.selected_obs_subkey = None + self.current_video_offset = 0.0 + + # Core Layout + main_splitter = QSplitter(Qt.Vertical) + top_splitter = QSplitter(Qt.Horizontal) + + # --- Top Left: Video Player --- + video_container = QWidget() + video_layout = QVBoxLayout(video_container) + video_layout.setContentsMargins(0, 0, 0, 0) + + self.scene = QGraphicsScene() + # Use our new subclass instead of standard QGraphicsView + self.view = VideoView(self.scene) + self.view.resized.connect(self.update_video_geometry) + + # Video item (NOT native) + self.video_item = QGraphicsVideoItem() + self.scene.addItem(self.video_item) + + # Overlay widget (normal QWidget) + self.skeleton_overlay = SkeletonOverlay(self.view.viewport()) + self.skeleton_overlay.setAttribute(Qt.WA_TransparentForMouseEvents) + self.skeleton_overlay.setAttribute(Qt.WA_TranslucentBackground) + self.skeleton_overlay.show() + + # Media player + self.player = QMediaPlayer() + self.audio_output = QAudioOutput() + self.player.setAudioOutput(self.audio_output) + self.player.setVideoOutput(self.video_item) + + video_layout.addWidget(self.view) + + # --- Control Bar Container (Vertical Stack) --- + controls_container = QWidget() + stacked_controls = QVBoxLayout(controls_container) + stacked_controls.setSpacing(5) # Tight spacing between rows + + # --- ROW 1: ML & Training Controls --- + ml_row = QHBoxLayout() + ml_row.addStretch() + + ml_row.addWidget(QLabel("ML Model:")) + self.ml_dropdown = QComboBox() + self.ml_dropdown.addItems(["Random Forest", "LSTM", "XGBoost", "SVM", "1D-CNN"]) + ml_row.addWidget(self.ml_dropdown) + + ml_row.addWidget(QLabel("Target:")) + self.target_dropdown = QComboBox() + self.target_dropdown.addItems(["Mouthing", "Head Movement", "Kick (Left)", "Kick (Right)", "Reach (Left)", "Reach (Right)"]) + self.target_dropdown.currentTextChanged.connect(self.update_predictor_target) + ml_row.addWidget(self.target_dropdown) + + self.btn_add_to_pool = QPushButton("Add to Pool") + self.btn_add_to_pool.clicked.connect(self.add_current_to_ml_pool) + self.btn_add_to_pool.setFixedWidth(120) + ml_row.addWidget(self.btn_add_to_pool) + + self.btn_train_final = QPushButton("Train Global Model") + self.btn_train_final.setStyleSheet("background-color: #2e7d32; font-weight: bold;") + self.btn_train_final.clicked.connect(self.run_final_training) + ml_row.addWidget(self.btn_train_final) + + self.lbl_pool_status = QLabel("Pool: 0 Participants") + self.lbl_pool_status.setStyleSheet("color: #00FF00; font-weight: bold; margin-left: 10px;") + self.lbl_pool_status.setFixedWidth(160) + ml_row.addWidget(self.lbl_pool_status) + + self.btn_clear_pool = QPushButton("Clear Pool") + self.btn_clear_pool.setFixedWidth(100) + self.btn_clear_pool.setStyleSheet("color: #ff5555; border: 1px solid #ff5555;") + self.btn_clear_pool.clicked.connect(self.clear_ml_pool) + ml_row.addWidget(self.btn_clear_pool) + + self.btn_extract_ai = QPushButton("Extract AI Data") + self.btn_extract_ai.clicked.connect(self.extract_ai_to_json) + ml_row.addWidget(self.btn_extract_ai) + + + ml_row.addStretch() + + # --- ROW 2: Playback & Transport --- + playback_row = QHBoxLayout() + playback_row.addStretch() + + # Transport Buttons + self.btn_start = QPushButton("|<") + self.btn_prev = QPushButton("<") + self.btn_play = QPushButton("Play") + self.btn_next = QPushButton(">") + self.btn_end = QPushButton(">|") + + self.transport_btns = [self.btn_start, self.btn_prev, self.btn_play, self.btn_next, self.btn_end] + for btn in self.transport_btns: + btn.setEnabled(False) + btn.setFixedWidth(50) + playback_row.addWidget(btn) + + self.btn_mute = QPushButton("Vol") + self.btn_mute.setFixedWidth(40) + self.btn_mute.setCheckable(True) + self.btn_mute.clicked.connect(self.toggle_mute) + + self.sld_volume = QSlider(Qt.Horizontal) + self.sld_volume.setRange(0, 100) + self.sld_volume.setValue(100) # Default volume + self.sld_volume.setFixedWidth(100) + self.sld_volume.valueChanged.connect(self.update_volume) + + # Initialize volume + self.audio_output.setVolume(0.7) + + playback_row.addWidget(self.btn_mute) + playback_row.addWidget(self.sld_volume) + + # Counters + counter_style = "font-family: 'Consolas'; font-size: 10pt; margin-left: 5px; color: #00FF00;" + self.lbl_time_counter = QLabel("Time: 00:00 / 00:00") + self.lbl_frame_counter = QLabel("Frame: 0 / 0") + self.lbl_time_counter.setFixedWidth(180) + self.lbl_frame_counter.setFixedWidth(180) + self.lbl_time_counter.setStyleSheet(counter_style) + self.lbl_frame_counter.setStyleSheet(counter_style) + + playback_row.addWidget(self.lbl_time_counter) + playback_row.addWidget(self.lbl_frame_counter) + + playback_row.addStretch() + + # --- Add Rows to Stack --- + stacked_controls.addLayout(ml_row) + stacked_controls.addLayout(playback_row) + + # Add the whole stacked container to the main video layout + video_layout.addWidget(controls_container) + + # --- Button Connections --- + self.btn_play.clicked.connect(self.toggle_playback) + # Use lambda to pass the target frame to your existing seek_video method + self.btn_start.clicked.connect(lambda: self.seek_video(0)) + self.btn_end.clicked.connect(lambda: self.seek_video(self.data['total_frames'] - 1)) + self.btn_prev.clicked.connect(lambda: self.step_frame(-1)) + self.btn_next.clicked.connect(lambda: self.step_frame(1)) + + # --- Top Right: Media Info & Loader --- + info_container = QWidget() + info_layout = QVBoxLayout(info_container) + + # NEW: Wrap the info_label in a Scroll Area + self.inspector_scroll = QScrollArea() + self.inspector_scroll.setWidgetResizable(True) + self.inspector_scroll.setStyleSheet("border: none; background-color: transparent;") + + # Create the label as the scroll area's content + self.info_label = QTextEdit() + self.info_label.setText("No video loaded.\nClick 'File' > 'Load Video' to begin.") + self.info_label.setAlignment(Qt.AlignTop | Qt.AlignLeft) + self.info_label.setWordWrapMode(QTextOption.WordWrap) + self.info_label.setReadOnly(True) + + # self.info_label.setWordWrap(True) # Ensure long text wraps instead o + # f stretching horizontally + self.info_label.setStyleSheet("padding: 5px; font-family: 'Segoe UI', Arial; color: #ffffff;") + + self.inspector_scroll.setWidget(self.info_label) + + # Add the scroll area to the layout instead of the naked label + info_layout.addWidget(self.inspector_scroll) + + top_splitter.addWidget(video_container) + top_splitter.addWidget(info_container) + top_splitter.setSizes([800, 400]) + + # --- Bottom: Timeline in a Scroll Area --- + self.timeline = TimelineWidget() + self.timeline.seek_requested.connect(self.seek_video) + self.timeline.visibility_changed.connect(self.skeleton_overlay.set_hidden_tracks) + self.timeline.track_selected.connect(self.on_track_selected) + + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setWidget(self.timeline) + + main_splitter.addWidget(top_splitter) + main_splitter.addWidget(scroll_area) + main_splitter.setSizes([500, 400]) + + self.setCentralWidget(main_splitter) + self.player.positionChanged.connect(self.update_timeline_playhead) + self.player.positionChanged.connect(self.update_inspector) + self.create_menu_bar() + + self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix, PLATFORM_NAME, APP_NAME) + self.local_check_thread.pending_update_found.connect(self.updater.on_pending_update_found) + self.local_check_thread.no_pending_update.connect(self.updater.on_no_pending_update) + self.local_check_thread.start() + + + + def create_menu_bar(self): + '''Menu Bar at the top of the screen''' + + menu_bar = self.menuBar() + self.statusbar = self.statusBar() + + def make_action(name, shortcut=None, slot=None, checkable=False, checked=False, icon=None): + action = QAction(name, self) + + if shortcut: + action.setShortcut(QKeySequence(shortcut)) + if slot: + action.triggered.connect(slot) + if checkable: + action.setCheckable(True) + action.setChecked(checked) + if icon: + action.setIcon(QIcon(icon)) + return action + + # File menu and actions + file_menu = menu_bar.addMenu("File") + file_actions = [ + ("Load Video...", "Ctrl+O", self.load_video, resource_path("icons/file_open_24dp_1F1F1F.svg")), + # ("Open Folder...", "Ctrl+Alt+O", self.not_implemented, resource_path("icons/folder_24dp_1F1F1F.svg")), + # ("Open Folders...", "Ctrl+Shift+O", self.open_folder_dialog, resource_path("icons/folder_copy_24dp_1F1F1F.svg")), + # ("Load Project...", "Ctrl+L", self.not_implemented, resource_path("icons/article_24dp_1F1F1F.svg")), + # ("Save Project...", "Ctrl+S", self.not_implemented, resource_path("icons/save_24dp_1F1F1F.svg")), + # ("Save Project As...", "Ctrl+Shift+S", self.not_implemented, resource_path("icons/save_as_24dp_1F1F1F.svg")), + ] + + for i, (name, shortcut, slot, icon) in enumerate(file_actions): + file_menu.addAction(make_action(name, shortcut, slot, icon=icon)) + if i == 1: # after the first 3 actions (0,1,2) + file_menu.addSeparator() + + file_menu.addSeparator() + file_menu.addAction(make_action("Exit", "Ctrl+Q", QApplication.instance().quit, icon=resource_path("icons/exit_to_app_24dp_1F1F1F.svg"))) + + # Edit menu + edit_menu = menu_bar.addMenu("Edit") + edit_actions = [ + ("Cut", "Ctrl+X", self.cut_text, resource_path("icons/content_cut_24dp_1F1F1F.svg")), + ("Copy", "Ctrl+C", self.copy_text, resource_path("icons/content_copy_24dp_1F1F1F.svg")), + ("Paste", "Ctrl+V", self.paste_text, resource_path("icons/content_paste_24dp_1F1F1F.svg")) + ] + for name, shortcut, slot, icon in edit_actions: + edit_menu.addAction(make_action(name, shortcut, slot, icon=icon)) + + # View menu + view_menu = menu_bar.addMenu("View") + toggle_statusbar_action = make_action("Toggle Status Bar", checkable=True, checked=True, slot=None) + view_menu.addAction(toggle_statusbar_action) + toggle_statusbar_action.toggled.connect(self.statusbar.setVisible) + + # Options menu (Help & About) + options_menu = menu_bar.addMenu("Options") + + options_actions = [ + ("User Guide", "F1", self.user_guide, resource_path("icons/help_24dp_1F1F1F.svg")), + ("Check for Updates", "F5", self.updater.manual_check_for_updates, resource_path("icons/update_24dp_1F1F1F.svg")), + ("Batch YOLO processing...", "F6", self.open_batch_tool, resource_path("icons/upgrade_24dp_1F1F1F.svg")), + ("About", "F12", self.about_window, resource_path("icons/info_24dp_1F1F1F.svg")) + ] + + for i, (name, shortcut, slot, icon) in enumerate(options_actions): + options_menu.addAction(make_action(name, shortcut, slot, icon=icon)) + if i == 1 or i == 3: # after the first 2 actions (0,1) + options_menu.addSeparator() + + preferences_menu = menu_bar.addMenu("Preferences") + preferences_actions = [ + ("Not Implemented", "", self.not_implemented, resource_path("icons/info_24dp_1F1F1F.svg")), + ] + for name, shortcut, slot, icon in preferences_actions: + preferences_menu.addAction(make_action(name, shortcut, slot, icon=icon, checkable=True, checked=False)) + + terminal_menu = menu_bar.addMenu("Terminal") + terminal_actions = [ + ("Not Implemented", "", self.not_implemented, resource_path("icons/terminal_24dp_1F1F1F.svg")), + ] + for name, shortcut, slot, icon in terminal_actions: + terminal_menu.addAction(make_action(name, shortcut, slot, icon=icon)) + + self.statusbar.showMessage("Ready") + + + def not_implemented(self): + self.statusbar.showMessage("Not Implemented.") # Show status message + + def copy_text(self): + self.info_label.copy() # Trigger copy + self.statusbar.showMessage("Copied to clipboard") # Show status message + + def cut_text(self): + self.info_label.cut() # Trigger cut + self.statusbar.showMessage("Cut to clipboard") # Show status message + + def about_window(self): + if self.about is None or not self.about.isVisible(): + self.about = AboutWindow(self) + self.about.show() + + def user_guide(self): + if self.help is None or not self.help.isVisible(): + self.help = UserGuideWindow(self) + self.help.show() + + def paste_text(self): + self.info_label.paste() # Trigger paste + self.statusbar.showMessage("Pasted from clipboard") # Show status message + + def open_batch_tool(self): + dialog = BatchProcessorDialog(self) # Pass 'self' to keep it centered + dialog.exec() + + def toggle_mute(self): + is_muted = self.btn_mute.isChecked() + self.audio_output.setMuted(is_muted) + self.btn_mute.setText("Mute" if is_muted else "Vol") + # Optional: Dim the slider when muted + self.sld_volume.setEnabled(not is_muted) + + def update_volume(self, value): + # QAudioOutput expects a float between 0.0 and 1.0 + volume = value / 100.0 + self.audio_output.setVolume(volume) + + # Auto-unmute if user moves the slider + if self.btn_mute.isChecked() and value > 0: + self.btn_mute.setChecked(False) + self.toggle_mute() + + def clear_ml_pool(self): + """Removes all participants from the training buffer.""" + debug_print() + # Confirm with the user first to prevent accidental deletions + reply = QMessageBox.question(self, 'Clear Pool?', + f"This will remove all {len(self.predictor.raw_participant_buffer)} " + "participants from the training memory. Continue?", + QMessageBox.Yes | QMessageBox.No, QMessageBox.No) + + if reply == QMessageBox.Yes: + # 1. Clear the actual list in the predictor + self.predictor.raw_participant_buffer = [] + + # 2. Update the UI label + self.lbl_pool_status.setText("Pool: 0 Participants") + + # 3. Optional: Visual feedback + # self.statusBar().showMessage("ML Pool cleared.", 3000) + print("DEBUG: ML Pool manually cleared.") + + + def update_predictor_target(self): + debug_print() + # This physically changes the string from "Mouthing" to "Head Movement" + self.predictor.current_target = self.target_dropdown.currentText() + + print(f"Predictor is now targeting: {self.predictor.current_target}") + + + def reprocess_current_video(self): + """Restarts the analysis worker to pick up new models.""" + debug_print() + + # Start the worker (passing the predictor so it can run AI models) + self.worker = PoseAnalyzerWorker( + self.file_path, + obs_info=self.selected_obs_subkey, + predictor=self.predictor + ) + + self.worker.progress.connect(self.update_status) + self.worker.finished_data.connect(self.setup_workspace) + self.worker.start() + + + def update_video_geometry(self): + debug_print() + if not hasattr(self, "video_item") or not hasattr(self, "data"): + return + + viewport_rect = self.view.viewport().rect() + v_w, v_h = viewport_rect.width(), viewport_rect.height() + if v_w <= 0 or v_h <= 0: return + + video_w, video_h = self.data['width'], self.data['height'] + aspect = video_w / video_h + + if v_w / v_h > aspect: + target_h = v_h + target_w = int(v_h * aspect) + else: + target_w = v_w + target_h = int(v_w / aspect) + + x_off = (v_w - target_w) / 2 + y_off = (v_h - target_h) / 2 + + self.scene.setSceneRect(0, 0, v_w, v_h) + self.video_item.setPos(x_off, y_off) + self.video_item.setSize(QSizeF(target_w, target_h)) + self.skeleton_overlay.setGeometry(int(x_off), int(y_off), target_w, target_h) + + def resizeEvent(self, event): + debug_print() + + super().resizeEvent(event) + self.update_video_geometry() + if hasattr(self, 'timeline'): + self.timeline.set_zoom(self.timeline.zoom_factor) + + # def eventFilter(self, source, event): + # if source is self.video_widget and event.type() == QEvent.Resize: + # self.skeleton_overlay.resize(event.size()) + # return super().eventFilter(source, event) + + + + def add_current_to_ml_pool(self): + """Adds raw kinematic data and current OBS labels to the buffer.""" + debug_print() + if not hasattr(self, 'data') or 'raw_kps' not in self.data: + QMessageBox.warning(self, "No Data", "Load a video first.") + return + + # 1. Grab everything the Worker produced + payload = { + "z_kps": self.data['z_kps'], + "directions": self.data['directions'], + "raw_kps": self.data['raw_kps'] + } + + all_labels = self.timeline.get_all_binary_labels(self.current_video_offset, self.data["fps"]) + + # 3. Hand off to predictor + msg = self.predictor.add_to_raw_buffer(payload, all_labels) + self.lbl_pool_status.setText(f"Pool: {len(self.predictor.raw_participant_buffer)} Participants") + print(f"DEBUG: Added to Predictor at {hex(id(self.predictor))}") + print(f"DEBUG: Buffer size is now: {len(self.predictor.raw_participant_buffer)}") + QMessageBox.information(self, "Success", msg) + + + def run_final_training(self): + """ + Triggers training + """ + debug_print() + # DEBUG: Check the buffer directly before the IF statement + actual_buffer = self.predictor.raw_participant_buffer + current_count = len(actual_buffer) + + if current_count < 1: + # If this triggers, let's see WHY it's empty + QMessageBox.warning(self, "Empty Pool", + f"Buffer is empty (Size: {current_count}).\n" + f"Predictor ID: {hex(id(self.predictor))}") + return + + model_type = self.ml_dropdown.currentText() + current_target = self.target_dropdown.currentText() + + reply = QMessageBox.question(self, 'Confirm Training', + f"Train {model_type} for '{current_target}' using " + f"{current_count} participants?", + QMessageBox.Yes | QMessageBox.No, QMessageBox.No) + + if reply == QMessageBox.Yes: + self.btn_train_final.setEnabled(False) + self.btn_train_final.setText(f"Training...") + + try: + # Force the target update right before training + self.predictor.current_target = current_target + report_html = self.predictor.calculate_and_train(model_type, current_target) + + self.reprocess_current_video() + + self.info_label.setText(report_html) + msg = QMessageBox(self) + msg.setWindowTitle("Results") + msg.setTextFormat(Qt.RichText) + msg.setText(report_html) + msg.exec() + + + + except Exception as e: + traceback.print_exc() + QMessageBox.critical(self, "Error", f"{str(e)}") + + finally: + self.btn_train_final.setEnabled(True) + self.btn_train_final.setText("Train Global Model") + + + + # def import_json_observations(self): + # debug_print() + # file_path, _ = QFileDialog.getOpenFileName(self, "Select JSON Observations", "", "JSON Files (*.json)") + # if not file_path: return + + # with open(file_path, 'r') as f: + # full_data = json.load(f) + + # # Get the subkeys under "observations" + # subkeys = list(full_data.get("observations", {}).keys()) + + # if not subkeys: + # print("No observations found in JSON.") + # return + + # item, ok = QInputDialog.getItem(self, "Select Session", "Pick an observation set:", subkeys, 0, False) + + # if ok and item: + # new_obs_data = self.load_external_observations(file_path, item) + # self.append_new_tracks(new_obs_data) + + def append_new_tracks(self, new_obs_data): + debug_print() + # 1. Update global TRACK_NAMES and TRACK_COLORS + for name in new_obs_data.keys(): + if name not in TRACK_NAMES: + TRACK_NAMES.append(name) + # Assign a distinct color (e.g., a dark purple/magenta for observations) + TRACK_COLORS.append("#AA00FF") + + # 2. Merge into existing data dictionary + self.data["events"].update(new_obs_data) + + # 3. Refresh Timeline + global NUM_TRACKS + NUM_TRACKS = len(TRACK_NAMES) + self.timeline.set_data(self.data) + self.timeline.update_geometry() + + # def load_external_observations(self, file_path, subkey): + # debug_print() + # with open(file_path, 'r') as f: + # data = json.load(f) + + # raw_events = data["observations"][subkey]["events"] + # # We only care about: [time_seconds (0), _, label (2), _, _, _] + + # new_tracks = {} + + # # Sort events by time just in case the JSON is unsorted + # raw_events.sort(key=lambda x: x[0]) + + # # Group timestamps by their label (e.g., "Kick", "Baseline") + # temp_storage = {} + # for event in raw_events: + # time_sec = event[0] + # label = event[2] + # frame = int(time_sec * self.data["fps"]) + + # if label not in temp_storage: + # temp_storage[label] = [] + # temp_storage[label].append(frame) + + # # Convert pairs of frames into (start, end) blocks + # for label, frames in temp_storage.items(): + # processed_blocks = [] + # # Step through frames in pairs (start, end) + # for i in range(0, len(frames) - 1, 2): + # start = frames[i] + # end = frames[i+1] + # # Format: (start, end, severity, direction) + # processed_blocks.append((start, end, "External", "Manual Obs")) + + # new_tracks[f"OBS: {label}"] = processed_blocks + + # return new_tracks + + + def load_video(self): + debug_print() + self.file_path, _ = QFileDialog.getOpenFileName(self, "Open Video", "", "Video Files (*.mp4 *.avi *.mkv)") + if not self.file_path: return + + cap = cv2.VideoCapture(self.file_path) + if cap.isOpened(): + self.current_video_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 + #total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + # Optional: Initialize timeline with blank data so it can at least draw the ruler + #self.timeline.data = {"total_frames": total_frames, "fps": self.current_video_fps, "events": {}} + cap.release() + else: + self.current_video_fps = 30.0 # Fallback + + # --- NEW: JSON Observation Prompt --- + self.obs_file, _ = QFileDialog.getOpenFileName(self, "Select JSON Observations (Optional)", "", "JSON Files (*.json)") + + + if self.obs_file: + try: + with open(self.obs_file, 'r') as f: + full_json = json.load(f) + + observations = full_json.get("observations", {}) + subkeys = list(observations.keys()) + + # --- AUTO-MATCHING LOGIC --- + # 1. Get the video filename without extension (e.g., 'T4_2T_WORD_F') + video_name = os.path.splitext(os.path.basename(self.file_path))[0] + v_parts = video_name.split('_') + + # Build the 'fingerprint' from the video (Blocks 1, 2, and the Last one) + # This ignores the 'WORDHERE' block in the middle + if len(v_parts) >= 3: + video_fingerprint = f"{v_parts[0]}_{v_parts[1]}_{v_parts[-1]}" + else: + video_fingerprint = video_name # Fallback + + match = None + for sk in subkeys: + s_parts = sk.split('_') + # Subkeys are shorter: Block 1, 2, and 3 + if len(s_parts) == 3: + sk_fingerprint = f"{s_parts[0]}_{s_parts[1]}_{s_parts[2]}" + if sk_fingerprint.lower() == video_fingerprint.lower(): + match = sk + break + + # 2. Decision: Use match or prompt user + if match: + self.selected_obs_subkey = (self.obs_file, match) + self.statusBar().showMessage(f"Auto-matched JSON session: {match}", 5000) + elif subkeys: + # No match found, only then show the popup + item, ok = QInputDialog.getItem(self, "Select Session", + f"Could not auto-match '{video_name}'.\nPick manually:", + subkeys, 0, False) + if ok and item: + self.selected_obs_subkey = (self.obs_file, item) + + # --- NEW: Offset & File Matching Logic --- + if self.selected_obs_subkey: + _, session_key = self.selected_obs_subkey + session_data = observations.get(session_key, {}) + file_map = session_data.get("file", {}) + + video_filename = os.path.basename(self.file_path) + found_index = None + + # 1. Attempt Auto-Match by filename + for idx_str, file_list in file_map.items(): + # Check if our loaded video is in this list (e.g., "Videos\\T4_2T_WORD_F.mp4") + if any(video_filename in path for path in file_list): + found_index = idx_str + print(f"DEBUG: Auto-matched video to Camera Index {idx_str}") + break + + # 2. If Auto-Match fails, prompt user for Camera Index + if not found_index: + available_indices = [k for k, v in file_map.items() if v] # Only indices with files + if available_indices: + item, ok = QInputDialog.getItem(self, "Identify Camera", + f"Could not find '{video_filename}' in JSON.\n" + "Which camera index is this video?", + available_indices, 0, False) + if ok: + found_index = item + + # 3. Retrieve and Print the Offset + if found_index: + offsets = session_data.get("media_info", {}).get("offset", {}) + search_key = str(found_index) + # Note: offsets dict might use integers or strings as keys + # We check both to be safe + actual_offset = offsets.get(search_key) + + if actual_offset is not None: + print(f"MATCHED OFFSET: {actual_offset:.4f}") + # Store this if you need it for timeline syncing later + self.current_video_offset = float(actual_offset) + self.timeline.set_sync_params( + offset_seconds=self.current_video_offset, + fps=self.current_video_fps + ) + + print(f"✅ Timeline synced with {actual_offset}s offset.") + else: + print(f"DEBUG: No offset found for index {found_index}") + + except Exception as e: + QMessageBox.warning(self, "JSON Error", f"Could not parse JSON: {e}") + + # --- Cache Logic --- + # cache_path = self.file_path.rsplit('.', 1)[0] + "_pose_cache.csv" + # use_cache = None + # if os.path.exists(cache_path): + # reply = QMessageBox.question(self, 'Cache Found', + # "Use existing pose cache?", + # QMessageBox.Yes | QMessageBox.No) + # use_cache = cache_path if reply == QMessageBox.Yes else None + + self.btn_load.setEnabled(False) + + # Pass the observation info to the worker + self.worker = PoseAnalyzerWorker(self.file_path, self.selected_obs_subkey, self.predictor) + self.worker.progress.connect(self.update_status) + self.worker.finished_data.connect(self.setup_workspace) + self.worker.start() + + def update_status(self, msg): + debug_print() + + self.info_label.setText(f"Status:\n{msg}") + + def setup_workspace(self, data): + debug_print() + self.data = data + self.player.setSource(QUrl.fromLocalFile(data["video_path"])) + self.player.play() + self.player.pause() + self.timeline.set_data(data) + self.skeleton_overlay.set_data(data) + self.update_video_geometry() + for btn in self.transport_btns: + btn.setEnabled(True) + total_f = data['total_frames'] + fps = data['fps'] + tot_s = int(total_f / fps) + + # Display 0 / Total + self.lbl_time_counter.setText(f"00:00 / {tot_s//60:02d}:{tot_s%60:02d}") + self.lbl_frame_counter.setText(f"0 / {total_f-1}") + + # Sync widgets + self.timeline.set_data(data) + self.skeleton_overlay.set_data(data) + + # Force a seek to frame 0 to initialize the video buffer + self.seek_video(0) + self.btn_load.setEnabled(True) + + info_text = ( + f"File: {os.path.basename(data['video_path'])}\n" + f"Resolution: {data['width']}x{data['height']}\n" + f"FPS: {data['fps']:.2f}\n" + f"Total Frames: {data['total_frames']}\n\n" + f"Timeline Legend (Opacity):\n" + f"255 Alpha = Large Deviation\n" + f"160 Alpha = Moderate Deviation\n" + f"80 Alpha = Small Deviation\n" + f"Empty = Rest (Baseline)" + ) + self.info_label.setText(info_text) + + + def toggle_playback(self): + debug_print() + + if not hasattr(self, 'data'): return + + # If we are at the end, jump to the start first + fps = self.data["fps"] + current_frame = int((self.player.position() / 1000.0) * fps + 0.5) + if current_frame >= self.data["total_frames"] - 1: + self.seek_video(0) + + if self.player.playbackState() == QMediaPlayer.PlayingState: + self.player.pause() + self.btn_play.setText("Play") + else: + self.player.play() + self.btn_play.setText("Pause") + + def update_timeline_playhead(self, position_ms): + debug_print() + if hasattr(self, 'data') and self.data["fps"] > 0: + fps = self.data["fps"] + total_f = self.data["total_frames"] + + # Current frame calculation + current_f = int((position_ms / 1000.0) * fps) + + # --- PREVENT BLACK FRAME AT END --- + # If we are within 1 frame of the end, stop and lock to the last valid frame + if current_f >= total_f - 1: + if self.player.playbackState() == QMediaPlayer.PlayingState: + self.player.pause() + self.btn_play.setText("Play") + current_f = total_f - 1 + # Seek slightly back from total duration to keep the image visible + last_valid_ms = int(((total_f - 1) / fps) * 1000) + self.player.setPosition(last_valid_ms) + + # Sync UI + self.timeline.set_playhead(current_f) + self.skeleton_overlay.set_frame(current_f) + self.update_counters(current_f) + + + def on_track_selected(self, track_name): + debug_print() + + self.selected_track = track_name + + if not track_name: + self.info_label.setText("No track selected.\nClick a data track to inspect.") + self.info_label.setStyleSheet("color: #AAAAAA; font-family: 'Segoe UI'; font-size: 10pt;") + else: + self.info_label.setStyleSheet("color: #00FF00; font-family: 'Segoe UI'; font-size: 10pt;") + self.update_inspector() # Refresh immediately on click + + + def update_inspector(self): + debug_print() + if not hasattr(self, 'selected_track') or not self.selected_track or not self.data: + return + + # 1. Temporal Logic + current_f = int((self.player.position() / 1000.0) * self.data["fps"]) + current_f = max(0, min(current_f, self.data["total_frames"] - 1)) + + is_ai = "AI:" in self.selected_track + is_obs = "OBS:" in self.selected_track + + # 2. Status/Raw Logic + if is_ai or is_obs: + # Check Activity for Behavior Tracks + events = self.data["events"].get(self.selected_track, []) + is_active = any(start <= current_f <= end for start, end, *rest in events) + active_color = "#ff5555" if is_active else "#888888" + + status_line = f"ACTIVE: {'YES' if is_active else 'NO'}" + raw_line = "" # Do not display raw for AI/OBS + else: + # Kinematics Logic (No Active status) + status_line = "" + raw_info = "N/A" + cache_path = self.file_path.rsplit('.', 1)[0] + "_pose_raw.csv" + print(cache_path) + + if os.path.exists(cache_path): + try: + + # Row 2 in CSV is Frame 0. pandas.read_csv uses Row 1 as header. + # So Frame 0 is df.iloc[0]. + print(current_f) + print(len(self.worker.pose_df)) + if current_f < len(self.worker.pose_df): + row = self.worker.pose_df.iloc[current_f] + print(self.selected_track) + col_x, col_y, col_c = f"{self.selected_track}_x", f"{self.selected_track}_y", f"{self.selected_track}_conf" + print(self.worker.pose_df.columns) + + if col_x in self.worker.pose_df.columns and col_y in self.worker.pose_df.columns: + print("me") + rx, ry = row[col_x], row[col_y] + rc = row[col_c] if col_c in self.worker.pose_df.columns else 0.0 + raw_info = f"X: {rx:.2f} | Y: {ry:.2f} | Conf: {rc:.2f}" + except Exception as e: + print(f"Inspector CSV Error: {e}") + raw_info = "Index Error" + + raw_line = f"RAW (CSV): {raw_info}" + + # 3. Construct Display + display_text = ( + f"TRACK: {self.selected_track}
" + f"FRAME: {current_f}
" + f"{status_line}" + f"{raw_line}" + ) + + # 4. Performance Report + if is_ai: + target_name = self.selected_track.replace("AI: ", "") + pattern = f"ml_{target_name}_performance_*.txt" + report_files = sorted(glob.glob(pattern)) + + report_content = "No report found." + if report_files: + try: + with open(report_files[-1], 'r') as f: + report_content = f.read().replace('\n', '
') + except: pass + + display_text += f"
AI Performance:
{report_content}" + + self.info_label.setText(display_text) + + + def step_frame(self, delta): + debug_print() + + if not hasattr(self, 'data'): return + + fps = self.data["fps"] + # Calculate current frame based on ms position + current_f = int((self.player.position() / 1000.0) * fps + 0.5) + target_f = current_f + delta + + # Use your existing seek_video to handle bounds and UI updates + self.seek_video(target_f) + + def seek_video(self, frame): + debug_print() + if hasattr(self, 'data') and self.data["fps"] > 0: + total_f = self.data["total_frames"] + target_frame = max(0, min(frame, total_f - 1)) + + # Calculate MS with a tiny offset (0.1) to ensure the player + # lands ON the frame, not slightly before it. + ms = int((target_frame / self.data["fps"]) * 1000) + self.player.setPosition(ms) + + self.video_item.update() + + # Update UI immediately for snappier feedback + self.timeline.set_playhead(target_frame) + self.update_counters(target_frame) + + def update_counters(self, current_f): + debug_print() + + # Dedicated method to refresh the labels + fps = self.data["fps"] + total_f = self.data["total_frames"] + + cur_s, tot_s = int(current_f / fps), int(total_f / fps) + self.lbl_time_counter.setText(f"Time: {cur_s//60:02d}:{cur_s%60:02d} / {tot_s//60:02d}:{tot_s%60:02d}") + self.lbl_frame_counter.setText(f"Frame: {current_f} / {total_f-1}") + + + + + def extract_ai_to_json(self): + """ + Automatically saves AI extractions to the video directory + with the suffix '_events.json'. + """ + + # 1. Check if a video is loaded to get the base path + video_path = getattr(self, "file_path", None) + if not video_path or not os.path.exists(video_path): + print("Error: No video loaded. Cannot determine save path.") + return + + # 2. Construct the new filename + base_dir = os.path.dirname(video_path) + file_name = os.path.splitext(os.path.basename(video_path))[0] + save_path = os.path.join(base_dir, f"{file_name}_events.blaze") + + # 3. Call the timeline method to get the data + try: + extraction_data = self.timeline.get_ai_extractions() + + # Inject source video metadata + extraction_data["metadata"]["source_video"] = video_path + + # 4. Save to disk + with open(save_path, 'w') as f: + json.dump(extraction_data, f, indent=4) + + print(f"Extraction automatically saved to: {save_path}") + + except Exception as e: + print(f"Error during automatic AI extraction: {e}") + + + + + +def resource_path(relative_path): + """ + Get absolute path to resource regardless of running directly or packaged using PyInstaller + """ + + if hasattr(sys, '_MEIPASS'): + # PyInstaller bundle path + base_path = sys._MEIPASS + else: + base_path = os.path.dirname(os.path.abspath(__file__)) + + return os.path.join(base_path, relative_path) + + +def kill_child_processes(): + """ + Goodbye children + """ + + try: + parent = psutil.Process(os.getpid()) + children = parent.children(recursive=True) + for child in children: + try: + child.kill() + except psutil.NoSuchProcess: + pass + psutil.wait_procs(children, timeout=5) + except Exception as e: + print(f"Error killing child processes: {e}") + + +def exception_hook(exc_type, exc_value, exc_traceback): + """ + Method that will display a popup when the program hard crashes containg what went wrong + """ + + error_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) + print(error_msg) # also print to console + + kill_child_processes() + + # Show error message box + # Make sure QApplication exists (or create a minimal one) + app = QApplication.instance() + if app is None: + app = QApplication(sys.argv) + + show_critical_error(error_msg) + + # Exit the app after user acknowledges + sys.exit(1) + +def show_critical_error(error_msg): + msg_box = QMessageBox() + msg_box.setIcon(QMessageBox.Icon.Critical) + msg_box.setWindowTitle("Something went wrong!") + + if PLATFORM_NAME == "darwin": + log_path = os.path.join(os.path.dirname(sys.executable), "../../../flares.log") + log_path2 = os.path.join(os.path.dirname(sys.executable), "../../../flares_error.log") + save_path = os.path.join(os.path.dirname(sys.executable), "../../../flares_autosave.flare") + + else: + log_path = os.path.join(os.getcwd(), "flares.log") + log_path2 = os.path.join(os.getcwd(), "flares_error.log") + save_path = os.path.join(os.getcwd(), "flares_autosave.flare") + + + shutil.copy(log_path, log_path2) + log_path2 = Path(log_path2).absolute().as_posix() + autosave_path = Path(save_path).absolute().as_posix() + log_link = f"file:///{log_path2}" + autosave_link = f"file:///{autosave_path}" + + message = ( + f"{APP_NAME.upper()} has encountered an unrecoverable error and needs to close.

" + f"We are sorry for the inconvenience. An autosave was attempted to be saved to {autosave_path}, but it may not have been saved. " + "If the file was saved, it still may not be intact, openable, or contain the correct data. Use the autosave at your discretion.

" + f"This unrecoverable error was likely due to an error with {APP_NAME.upper()} and not your data.
" + f"Please raise an issue here and attach the error file located at {log_path2}

" + f"
{error_msg}
" + ) + + msg_box.setTextFormat(Qt.TextFormat.RichText) + msg_box.setText(message) + msg_box.setTextInteractionFlags(Qt.TextInteractionFlag.TextBrowserInteraction) + msg_box.setStandardButtons(QMessageBox.StandardButton.Ok) + + msg_box.exec() + +if __name__ == "__main__": + # Redirect exceptions to the popup window + sys.excepthook = exception_hook + + # Set up application logging + if PLATFORM_NAME == "darwin": + log_path = os.path.join(os.path.dirname(sys.executable), f"../../../{APP_NAME}.log") + else: + log_path = os.path.join(os.getcwd(), f"{APP_NAME}.log") + + try: + os.remove(log_path) + except: + pass + + sys.stdout = open(log_path, "a", buffering=1) + sys.stderr = sys.stdout + print(f"\n=== App started at {datetime.now()} ===\n") + + freeze_support() # Required for PyInstaller + multiprocessing + + # Only run GUI in the main process + if current_process().name == 'MainProcess': + app = QApplication(sys.argv) + finish_update_if_needed(PLATFORM_NAME, APP_NAME) + window = PremiereWindow() + + if PLATFORM_NAME == "darwin": + app.setWindowIcon(QIcon(resource_path("icons/main.icns"))) + window.setWindowIcon(QIcon(resource_path("icons/main.icns"))) + else: + app.setWindowIcon(QIcon(resource_path("icons/main.ico"))) + window.setWindowIcon(QIcon(resource_path("icons/main.ico"))) + window.show() + sys.exit(app.exec()) + +# Not 6000 lines yay! \ No newline at end of file diff --git a/predictor.py b/predictor.py new file mode 100644 index 0000000..01e0f25 --- /dev/null +++ b/predictor.py @@ -0,0 +1,405 @@ +""" +Filename: predictor.py +Description: BLAZES machine learning + +Author: Tyler de Zeeuw +License: GPL-3.0 +""" + +# Built-in imports +import inspect +from datetime import datetime + +# External library imports +import numpy as np +import joblib +import seaborn as sns +import matplotlib.pyplot as plt + +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split +from sklearn.metrics import classification_report, f1_score, precision_score, recall_score, confusion_matrix +from sklearn.preprocessing import StandardScaler + +# To be used once multiple models are supported and functioning: +# import torch +# import torch.nn as nn +# import torch.optim as optim +# import xgboost as xgb +# from sklearn.svm import SVC +# import os + +VERBOSITY = 1 + +GEOMETRY_LIBRARY = { + # --- Distances (Point A, Point B) --- + "dist_l_wrist_nose": ("dist", [9, 0], True), + "dist_r_wrist_nose": ("dist", [10, 0], True), + "dist_l_ear_r_shld": ("dist", [3, 6], True), + "dist_r_ear_l_shld": ("dist", [4, 5], True), + + "dist_l_wrist_pelvis": ("dist", [9, [11, 12]], True), + "dist_r_wrist_pelvis": ("dist", [10, [11, 12]], True), + "dist_l_ankl_pelvis": ("dist", [15, [11, 12]], True), + "dist_r_ankl_pelvis": ("dist", [16, [11, 12]], True), + "dist_nose_pelvis": ("dist", [0, [11, 12]], True), + "dist_ankl_ankl": ("dist", [15, 16], True), + + # NEW: Cross-Body and Pure Extension Distances + "dist_l_wri_r_shld": ("dist", [9, 6], True), # Reach across body + "dist_r_wri_l_shld": ("dist", [10, 5], True), # Reach across body + "dist_l_wri_l_shld": ("dist", [9, 5], True), # Pure arm extension + "dist_r_wri_r_shld": ("dist", [10, 6], True), # Pure arm extension + + # --- Angles (Point A, Center B, Point C) --- + "angle_l_elbow": ("angle", [5, 7, 9]), + "angle_r_elbow": ("angle", [6, 8, 10]), + "angle_l_shoulder": ("angle", [11, 5, 7]), + "angle_r_shoulder": ("angle", [12, 6, 8]), + "angle_l_knee": ("angle", [11, 13, 15]), + "angle_r_knee": ("angle", [12, 14, 16]), + "angle_l_hip": ("angle", [5, 11, 13]), + "angle_r_hip": ("angle", [6, 12, 14]), + + # --- Custom/Derived --- + "asym_wrist": ("z_diff", [9, 10]), + "asym_ankl": ("z_diff", [15, 16]), + "offset_head": ("head_offset", [0, 5, 6]), + "diff_ear_shld": ("subtraction", ["dist_l_ear_r_shld", "dist_r_ear_l_shld"]), + "abs_diff_ear_shld": ("abs_subtraction", ["dist_l_ear_r_shld", "dist_r_ear_l_shld"]), + + # NEW: Verticality and Contralateral Contrast + "height_l_ankl": ("y_diff", [15, 11]), # Foot height relative to hip + "height_r_ankl": ("y_diff", [16, 12]), # Foot height relative to hip + "diff_knee_angle": ("subtraction", ["angle_l_knee", "angle_r_knee"]), + "asym_wri_shld": ("subtraction", ["dist_l_wri_l_shld", "dist_r_wri_r_shld"]) +} + + +# The Target Activity Map +ACTIVITY_MAP = { + "Mouthing": [ + "dist_l_wrist_nose", "dist_r_wrist_nose", "angle_l_elbow", + "angle_r_elbow", "angle_l_shoulder", "angle_r_shoulder", + "asym_wrist", "offset_head" + ], + "Head Movement": [ + "dist_l_wrist_nose", "dist_r_wrist_nose", "angle_l_elbow", + "angle_r_elbow", "angle_l_shoulder", "angle_r_shoulder", + "asym_wrist", "offset_head", "dist_l_ear_r_shld", + "dist_r_ear_l_shld", "diff_ear_shld", "abs_diff_ear_shld" + ], + "Reach (Left)": [ + "dist_l_wrist_pelvis", "dist_l_wrist_nose", "dist_l_wri_l_shld", + "dist_l_wri_r_shld", "angle_l_elbow", "angle_l_shoulder", + "asym_wri_shld" + ], + "Reach (Right)": [ + "dist_r_wrist_pelvis", "dist_r_wrist_nose", "dist_r_wri_r_shld", + "dist_r_wri_l_shld", "angle_r_elbow", "angle_r_shoulder", + "asym_wri_shld" + ], + "Kick (Left)": [ + "dist_l_ankl_pelvis", "angle_l_knee", "angle_l_hip", + "height_l_ankl", "dist_ankl_ankl", "asym_ankl", + "diff_knee_angle", "dist_nose_pelvis" + ], + "Kick (Right)": [ + "dist_r_ankl_pelvis", "angle_r_knee", "angle_r_hip", + "height_r_ankl", "dist_ankl_ankl", "asym_ankl", + "diff_knee_angle", "dist_nose_pelvis" + ] +} + + +def debug_print(): + if VERBOSITY: + frame = inspect.currentframe().f_back + qualname = frame.f_code.co_qualname + print(qualname) + + +class GeneralPredictor: + def __init__(self): + debug_print() + self.base_paths = { + "Random Forest": "rf.pkl", + "XGBoost": "xgb.json", + "SVM": "svm.pkl", + "LSTM": "lstm.pth", + "1D-CNN": "cnn.pth" + } + self.raw_participant_buffer = [] + self.current_target = "" + self.scaler_cache = {} + + + def add_to_raw_buffer(self, raw_payload, y_labels): + """ + Adds a participant's raw kinematic components to the pool. + raw_payload should contain: 'z_kps', 'directions', 'raw_kps' + """ + debug_print() + entry = { + "raw_data": raw_payload, + "labels": y_labels + } + self.raw_participant_buffer.append(entry) + return f"Added participant to pool. Total participants: {len(self.raw_participant_buffer)}" + + + def clear_buffer(self): + """Clears the raw pool.""" + debug_print() + self.raw_participant_buffer = [] + + + def calculate_and_train(self, model_type, target_name): + """ + The 'On-the-Fly' engine. Loops through the raw buffer, + calculates features for the SELECTED target, and trains. + """ + debug_print() + self.current_target = target_name + all_X = [] + all_y = [] + + # 1. Process every participant in the pool + for participant in self.raw_participant_buffer: + raw = participant["raw_data"] + all_tracks = participant["labels"] + + # Pull the specific track that was requested + track_key = f"OBS: {target_name}" + if track_key not in all_tracks: + print(f"Warning: Track {track_key} not found for a participant. Skipping.") + continue + + y = all_tracks[track_key] + + # Extract lists from the payload + z_scores = raw["z_kps"] + dirs = raw["directions"] + kpts = raw["raw_kps"] + + # Calculate geometric features for every frame + participant_features = [] + for i in range(len(y)): + feat = self.format_features(z_scores[i], dirs[i], kpts[i]) + participant_features.append(feat) + + all_X.append(np.array(participant_features)) + all_y.append(y) + + # 2. Prepare for Training + X_combined = np.vstack(all_X) + y_combined = np.concatenate(all_y) + + # 3. Scale the data specifically for this target/model combo + scaler = StandardScaler() + X_scaled = scaler.fit_transform(X_combined) + scaler_path = self.get_path(model_type, is_scaler=True) + joblib.dump(scaler, scaler_path) + + # 4. Train/Test Split + X_train, X_test, y_train, y_test = train_test_split( + X_scaled, y_combined, test_size=0.2, stratify=y_combined, random_state=42 + ) + + # 5. Process with corresponding Model + if model_type == "Random Forest": + model = RandomForestClassifier(max_depth=15, n_estimators=100, class_weight="balanced") + model.fit(X_train, y_train) + + # Save the model + save_path = self.get_path(model_type) + joblib.dump(model, save_path) + + y_pred = model.predict(X_test) + + # Feature Importance for the UI + labels_names = self.get_feature_labels() + importances = model.feature_importances_ + feature_data = sorted(zip(labels_names, importances), key=lambda x: x[1], reverse=True) + ui_extras = "Top Predictors:
" + "
".join([f"{n}: {v:.3f}" for n, v in feature_data]) + file_extras = "Top Predictors:\n" + "\n".join([f"- {n}: {v:.3f}" for n, v in feature_data]) + + return self._evaluate_and_report(model_type, y_test, y_pred, ui_extras=ui_extras, file_extras=file_extras, target_name=target_name) + + # TODO: More than random forest + else: + return "Model type not yet implemented in calculate_and_train." + + + def get_path(self, model_type, is_scaler=False): + """Returns the specific file path for the target/model or its scaler.""" + debug_print() + suffix = self.base_paths[model_type] + + if is_scaler: + suffix = suffix.split('.')[0] + "_scaler.pkl" + + return f"ml_{self.current_target}_{suffix}" + + + def get_feature_labels(self): + """Returns labels only for features active in the current target.""" + debug_print() + active_keys = ACTIVITY_MAP.get(self.current_target, []) + return active_keys + + + def format_features(self, z_scores, directions, kpts): + """The 'Universal Parser' for geometric features.""" + # debug_print() + # Internal Math Helpers + if self.current_target == "ALL_FEATURES": + active_list = list(GEOMETRY_LIBRARY.keys()) + else: + active_list = ACTIVITY_MAP.get(self.current_target, ACTIVITY_MAP["Mouthing"]) + + def resolve_pt(idx): + if isinstance(idx, list): + # Calculate midpoint of all indices in the list + pts = [kpts[i] for i in idx] + return np.mean(pts, axis=0) + return kpts[idx] + + def get_dist(p1, p2): return np.linalg.norm(p1 - p2) + def get_angle(a, b, c): + try: + ba, bc = a - b, c - b + denom = (np.linalg.norm(ba) * np.linalg.norm(bc) + 1e-6) + cos = np.dot(ba, bc) / denom + return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))) / 180.0 + except: return 0.0 + + calculated_pool = {} + + try: + if kpts is None or len(kpts) < 13: raise ValueError() + # Reference scale (Shoulders) + scale = get_dist(kpts[5], kpts[6]) + 1e-6 + + # First Pass: Direct Geometries + for name, (f_type, indices, *meta) in GEOMETRY_LIBRARY.items(): + if f_type == "dist": + # Use resolve_pt for both indices + p1 = resolve_pt(indices[0]) + p2 = resolve_pt(indices[1]) + calculated_pool[name] = get_dist(p1, p2) / scale + + elif f_type == "angle": + # Use resolve_pt for all three indices + p1 = resolve_pt(indices[0]) + p2 = resolve_pt(indices[1]) + p3 = resolve_pt(indices[2]) + calculated_pool[name] = get_angle(p1, p2, p3) + + elif f_type == "z_diff": + # Z-scores are usually single indices, but we handle lists just in case + z1 = np.mean([z_scores[i] for i in indices[0]]) if isinstance(indices[0], list) else z_scores[indices[0]] + z2 = np.mean([z_scores[i] for i in indices[1]]) if isinstance(indices[1], list) else z_scores[indices[1]] + calculated_pool[name] = abs(z1 - z2) + + elif f_type == "head_offset": + p_target = resolve_pt(indices[0]) + p_mid = resolve_pt([indices[1], indices[2]]) # Midpoint of shoulders + calculated_pool[name] = abs(p_target[0] - p_mid[0]) / scale + + # Second Pass: Composite Geometries (Subtractions/Symmetry) + # We do this after so 'dist_l_ear_r_shld' is already calculated + for name, (f_type, indices, *meta) in GEOMETRY_LIBRARY.items(): + if f_type == "subtraction": + calculated_pool[name] = calculated_pool[indices[0]] - calculated_pool[indices[1]] + elif f_type == "abs_subtraction": + calculated_pool[name] = abs(calculated_pool[indices[0]] - calculated_pool[indices[1]]) + + except Exception: + # If a frame fails, fill the pool with zeros to prevent crashes + calculated_pool = {name: 0.0 for name in GEOMETRY_LIBRARY.keys()} + + # Final Extraction based on current_target + + active_list = ACTIVITY_MAP.get(self.current_target, ACTIVITY_MAP["Mouthing"]) + feature_vector = [calculated_pool[feat] for feat in active_list] + + return np.array(feature_vector, dtype=np.float32) + + def _prepare_pool_data(self): + """Merges buffer and fits scaler.""" + debug_print() + if not self.X_buffer: + return None, None, None + + X_total = np.vstack(self.X_buffer) + y_total = np.concatenate(self.y_buffer) + + # We always fit a fresh scaler on the current pool + scaler_file = f"{self.current_target}_scaler.pkl" + scaler = StandardScaler() + X_scaled = scaler.fit_transform(X_total) + joblib.dump(scaler, scaler_file) + + return X_scaled, y_total, scaler + + + def _evaluate_and_report(self, model_name, y_test, y_pred, extra_text="", ui_extras="", file_extras="", target_name=""): + """Generates unified metrics, confusion matrix, and reports for ANY model""" + debug_print() + prec = precision_score(y_test, y_pred, zero_division=0) + rec = recall_score(y_test, y_pred, zero_division=0) + f1 = f1_score(y_test, y_pred, zero_division=0) + + target = getattr(self, 'current_target', 'Activity') + display_labels = ['Rest', target] + # Plot Confusion Matrix + cm = confusion_matrix(y_test, y_pred) + plt.figure(figsize=(8, 6)) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=display_labels, + yticklabels=display_labels) + plt.title(f'{model_name} Detection: Predicted vs Actual') + plt.ylabel('Actual State') + plt.xlabel('Predicted State') + + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + plt.savefig(f"ml_{target_name}_confusion_matrix_rf_{timestamp}.png") + plt.close() + + # Classification Report String + report_str = classification_report(y_test, y_pred, + target_names=display_labels, + zero_division=0) + + # Build TXT File Content + report_text = f"MODEL PERFORMANCE REPORT: {model_name}\nGenerated: {timestamp}\n" + report_text += "="*40 + "\n" + report_text += report_str + "\n" + report_text += f"Precision: {prec:.4f}\nRecall: {rec:.4f}\nF1-Score: {f1:.4f}\n" + report_text += "="*40 + "\n" + extra_text + report_text += "="*40 + "\n" + file_extras + + with open(f"ml_{target_name}_performance_rf_{timestamp}.txt", "w") as f: + f.write(report_text) + + # Build UI String + ui_report = f""" + {model_name} Performance:
+ Precision: {prec:.2f} | Recall: {rec:.2f} | F1: {f1:.2f}
+
+ {ui_extras} + """ + return ui_report + + def calculate_directions(self, analysis_kps): + debug_print() + all_dirs = np.zeros((len(analysis_kps), 17)) + + for f in range(1, len(analysis_kps)): + deltas = analysis_kps[f] - analysis_kps[f-1] # Shape (17, 2) + + angles = np.arctan2(-deltas[:, 1], deltas[:, 0]) + all_dirs[f] = angles + + return all_dirs \ No newline at end of file diff --git a/updater.py b/updater.py new file mode 100644 index 0000000..be1272c --- /dev/null +++ b/updater.py @@ -0,0 +1,539 @@ +""" +Filename: updater.py +Description: Generic updater file + +Author: Tyler de Zeeuw +License: GPL-3.0 +""" + +# Built-in imports +import os +import re +import sys +import time +import shlex +import shutil +import zipfile +import traceback +import subprocess + +# External library imports +import psutil +import requests + +from PySide6.QtWidgets import QMessageBox +from PySide6.QtCore import QThread, Signal, QObject + + +class UpdateDownloadThread(QThread): + """ + Thread that downloads and extracts an update package and emits a signal on completion or error. + + Args: + download_url (str): URL of the update zip file to download. + latest_version (str): Version string of the latest update. + """ + + update_ready = Signal(str, str) + error_occurred = Signal(str) + + def __init__(self, download_url, latest_version, platform_name, app_name): + super().__init__() + self.download_url = download_url + self.latest_version = latest_version + self.platform_name = platform_name + self.app_name = app_name + + def run(self): + try: + local_filename = os.path.basename(self.download_url) + + if self.platform_name == 'darwin': + tmp_dir = f'/tmp/{self.app_name}tempupdate' + os.makedirs(tmp_dir, exist_ok=True) + local_path = os.path.join(tmp_dir, local_filename) + else: + local_path = os.path.join(os.getcwd(), local_filename) + + # Download the file + with requests.get(self.download_url, stream=True, timeout=15) as r: + r.raise_for_status() + with open(local_path, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + # Extract folder name (remove .zip) + if self.platform_name == 'darwin': + extract_folder = os.path.splitext(local_filename)[0] + extract_path = os.path.join(tmp_dir, extract_folder) + + else: + extract_folder = os.path.splitext(local_filename)[0] + extract_path = os.path.join(os.getcwd(), extract_folder) + + # Create the folder if not exists + os.makedirs(extract_path, exist_ok=True) + + # Extract the zip file contents + if self.platform_name == 'darwin': + subprocess.run(['ditto', '-xk', local_path, extract_path], check=True) + else: + with zipfile.ZipFile(local_path, 'r') as zip_ref: + zip_ref.extractall(extract_path) + + # Remove the zip once extracted and emit a signal + os.remove(local_path) + self.update_ready.emit(self.latest_version, extract_path) + + except Exception as e: + # Emit a signal signifying failure + self.error_occurred.emit(str(e)) + + + +class UpdateCheckThread(QThread): + """ + Thread that checks for updates by querying the API and emits a signal based on the result. + + Signals: + download_requested(str, str): Emitted with (download_url, latest_version) when an update is available. + no_update_available(): Emitted when no update is found or current version is up to date. + error_occurred(str): Emitted with an error message if the update check fails. + """ + + download_requested = Signal(str, str) + no_update_available = Signal() + error_occurred = Signal(str) + + def __init__(self, api_url, api_url_sec, current_version, platform_name, app_name): + super().__init__() + self.api_url = api_url + self.api_url_sec = api_url_sec + self.current_version = current_version + self.platform_name = platform_name + self.app_name = app_name + + def run(self): + # if not getattr(sys, 'frozen', False): + # self.error_occurred.emit("Application is not frozen (Development mode).") + # return + try: + latest_version, download_url = self.get_latest_release_for_platform() + if not latest_version: + self.no_update_available.emit() + return + + if not download_url: + self.error_occurred.emit(f"No download available for platform '{self.platform_name}'") + return + + if self.version_compare(latest_version, self.current_version) > 0: + self.download_requested.emit(download_url, latest_version) + else: + self.no_update_available.emit() + + except Exception as e: + self.error_occurred.emit(f"Update check failed: {e}") + + def version_compare(self, v1, v2): + def normalize(v): return [int(x) for x in v.split(".")] + return (normalize(v1) > normalize(v2)) - (normalize(v1) < normalize(v2)) + + def get_latest_release_for_platform(self): + urls = [self.api_url, self.api_url_sec] + for url in urls: + try: + + response = requests.get(url, timeout=5) + response.raise_for_status() + releases = response.json() + + if not releases: + continue + + latest = next((r for r in releases if not r.get("prerelease") and not r.get("draft")), None) + + if not latest: + continue + + tag = latest["tag_name"].lstrip("v") + + for asset in latest.get("assets", []): + if self.platform_name in asset["name"].lower(): + return tag, asset["browser_download_url"] + + return tag, None + except (requests.RequestException, ValueError) as e: + continue + return None, None + + +class LocalPendingUpdateCheckThread(QThread): + """ + Thread that checks for locally pending updates by scanning the download directory and emits a signal accordingly. + + Args: + current_version (str): Current application version. + platform_suffix (str): Platform-specific suffix to identify update folders. + """ + + pending_update_found = Signal(str, str) + no_pending_update = Signal() + + def __init__(self, current_version, platform_suffix, platform_name, app_name): + super().__init__() + self.current_version = current_version + self.platform_suffix = platform_suffix + self.platform_name = platform_name + self.app_name = app_name + + def version_compare(self, v1, v2): + def normalize(v): return [int(x) for x in v.split(".")] + return (normalize(v1) > normalize(v2)) - (normalize(v1) < normalize(v2)) + + def run(self): + if self.platform_name == 'darwin': + cwd = f'/tmp/{self.app_name}tempupdate' + else: + cwd = os.getcwd() + + pattern = re.compile(r".*-(\d+\.\d+\.\d+)" + re.escape(self.platform_suffix) + r"$") + found = False + + try: + for item in os.listdir(cwd): + folder_path = os.path.join(cwd, item) + if os.path.isdir(folder_path) and item.endswith(self.platform_suffix): + match = pattern.match(item) + if match: + folder_version = match.group(1) + if self.version_compare(folder_version, self.current_version) > 0: + self.pending_update_found.emit(folder_version, folder_path) + found = True + break + except: + pass + + if not found: + self.no_pending_update.emit() + + + + + +class UpdateManager(QObject): + """ + Orchestrates the update process. + Main apps should instantiate this and call check_for_updates(). + """ + + def __init__(self, main_window, api_url, api_url_sec, current_version, platform_name, platform_suffix, app_name): + super().__init__() + self.parent = main_window + self.api_url = api_url + self.api_url_sec = api_url_sec + self.current_version = current_version + self.platform_name = platform_name + self.platform_suffix = platform_suffix + self.app_name = app_name + + self.pending_update_version = None + self.pending_update_path = None + + + def manual_check_for_updates(self): + self.local_check_thread = LocalPendingUpdateCheckThread(self.current_version, self.platform_suffix, self.platform_name, self.app_name) + self.local_check_thread.pending_update_found.connect(self.on_pending_update_found) + self.local_check_thread.no_pending_update.connect(self.on_no_pending_update) + self.local_check_thread.start() + + def on_pending_update_found(self, version, folder_path): + self.parent.statusBar().showMessage(f"Pending update found: version {version}") + self.pending_update_version = version + self.pending_update_path = folder_path + self.show_pending_update_popup() + + def on_no_pending_update(self): + # No pending update found locally, start server check directly + self.parent.statusBar().showMessage("No pending local update found. Checking server...") + self.start_update_check_thread() + + def show_pending_update_popup(self): + msg_box = QMessageBox(self.parent) + msg_box.setWindowTitle("Pending Update Found") + msg_box.setText(f"A previously downloaded update for {self.app_name.upper()} (version {self.pending_update_version}) is available at:\n{self.pending_update_path}\nWould you like to install it now?") + install_now_button = msg_box.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole) + install_later_button = msg_box.addButton("Install Later", QMessageBox.ButtonRole.RejectRole) + msg_box.exec() + + if msg_box.clickedButton() == install_now_button: + self.install_update(self.pending_update_path) + else: + self.parent.statusBar().showMessage("Pending update available. Install later.") + # After user dismisses, still check the server for new updates + self.start_update_check_thread() + + def start_update_check_thread(self): + self.check_thread = UpdateCheckThread(self.api_url, self.api_url_sec, self.current_version, self.platform_name, self.app_name) + self.check_thread.download_requested.connect(self.on_server_update_requested) + self.check_thread.no_update_available.connect(self.on_server_no_update) + self.check_thread.error_occurred.connect(self.on_error) + self.check_thread.start() + + def on_server_no_update(self): + self.parent.statusBar().showMessage("No new updates found on server.", 5000) + + def on_server_update_requested(self, download_url, latest_version): + if self.pending_update_version: + cmp = self.version_compare(latest_version, self.pending_update_version) + if cmp > 0: + # Server version is newer than pending update + self.parent.statusBar().showMessage(f"Newer version {latest_version} available on server. Removing old pending update...") + try: + shutil.rmtree(self.pending_update_path) + self.parent.statusBar().showMessage(f"Deleted old update folder: {self.pending_update_path}") + except Exception as e: + self.parent.statusBar().showMessage(f"Failed to delete old update folder: {e}") + + # Clear pending update info so new download proceeds + self.pending_update_version = None + self.pending_update_path = None + + # Download the new update + self.download_update(download_url, latest_version) + elif cmp == 0: + # Versions equal, no download needed + self.parent.statusBar().showMessage(f"Pending update version {self.pending_update_version} is already latest. No download needed.") + else: + # Server version older than pending? Unlikely but just keep pending update + self.parent.statusBar().showMessage(f"Pending update version {self.pending_update_version} is newer than server version. No action.") + else: + # No pending update, just download + self.download_update(download_url, latest_version) + + def download_update(self, download_url, latest_version): + self.parent.statusBar().showMessage("Downloading update...") + self.download_thread = UpdateDownloadThread(download_url, latest_version, self.platform_name, self.app_name) + self.download_thread.update_ready.connect(self.on_update_ready) + self.download_thread.error_occurred.connect(self.on_error) + self.download_thread.start() + + def on_update_ready(self, latest_version, extract_folder): + self.parent.statusBar().showMessage("Update downloaded and extracted.") + + msg_box = QMessageBox(self.parent) + msg_box.setWindowTitle("Update Ready") + msg_box.setText(f"Version {latest_version} has been downloaded and extracted to:\n{extract_folder}\nWould you like to install it now?") + install_now_button = msg_box.addButton("Install Now", QMessageBox.ButtonRole.AcceptRole) + install_later_button = msg_box.addButton("Install Later", QMessageBox.ButtonRole.RejectRole) + + msg_box.exec() + + if msg_box.clickedButton() == install_now_button: + self.install_update(extract_folder) + else: + self.parent.statusBar().showMessage("Update ready. Install later.") + + + def install_update(self, extract_folder): + # Path to updater executable + + if self.platform_name == 'windows': + updater_path = os.path.join(os.getcwd(), f"{self.app_name}_updater.exe") + elif self.platform_name == 'darwin': + if getattr(sys, 'frozen', False): + updater_path = os.path.join(os.path.dirname(sys.executable), f"../../../{self.app_name}_updater.app") + else: + updater_path = os.path.join(os.getcwd(), f"../{self.app_name}_updater.app") + + elif self.platform_name == 'linux': + updater_path = os.path.join(os.getcwd(), f"{self.app_name}_updater") + else: + updater_path = os.getcwd() + + if not os.path.exists(updater_path): + QMessageBox.critical(self.parent, "Error", f"Updater not found at:\n{updater_path}. The absolute path was {os.path.abspath(updater_path)}") + return + + # Launch updater with extracted folder path as argument + try: + # Pass current app's executable path for updater to relaunch + main_app_executable = os.path.abspath(sys.argv[0]) + + print(f'Launching updater with: "{updater_path}" "{extract_folder}" "{main_app_executable}"') + + if self.platform_name == 'darwin': + subprocess.Popen(['open', updater_path, '--args', extract_folder, main_app_executable]) + else: + subprocess.Popen([updater_path, f'{extract_folder}', f'{main_app_executable}'], cwd=os.path.dirname(updater_path)) + + # Close the current app so updater can replace files + sys.exit(0) + + except Exception as e: + QMessageBox.critical(self.parent, "Error", f"[Updater Launch Failed]\n{str(e)}\n{traceback.format_exc()}") + + def on_error(self, message): + # print(f"Error: {message}") + self.parent.statusBar().showMessage(f"Error occurred during update process. {message}") + + def version_compare(self, v1, v2): + def normalize(v): return [int(x) for x in v.split(".")] + return (normalize(v1) > normalize(v2)) - (normalize(v1) < normalize(v2)) + + +def wait_for_process_to_exit(process_name, timeout=10): + """ + Waits for a process with the specified name to exit within a timeout period. + + Args: + process_name (str): Name (or part of the name) of the process to wait for. + timeout (int, optional): Maximum time to wait in seconds. Defaults to 10. + + Returns: + bool: True if the process exited before the timeout, False otherwise. + """ + + print(f"Waiting for {process_name} to exit...") + deadline = time.time() + timeout + while time.time() < deadline: + still_running = False + for proc in psutil.process_iter(['name']): + try: + if proc.info['name'] and process_name.lower() in proc.info['name'].lower(): + still_running = True + print(f"Still running: {proc.info['name']} (PID: {proc.pid})") + break + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + if not still_running: + print(f"{process_name} has exited.") + return True + time.sleep(0.5) + print(f"{process_name} did not exit in time.") + return False + + +def finish_update_if_needed(platform_name, app_name): + """ + Completes a pending application update if '--finish-update' is present in the command-line arguments. + """ + + if "--finish-update" in sys.argv: + print("Finishing update...") + + if platform_name == 'darwin': + app_dir = f'/tmp/{app_name}tempupdate' + else: + app_dir = os.getcwd() + + # 1. Find update folder + update_folder = None + for entry in os.listdir(app_dir): + entry_path = os.path.join(app_dir, entry) + if os.path.isdir(entry_path) and entry.startswith(f"{app_name}-") and entry.endswith("-" + platform_name): + update_folder = os.path.join(app_dir, entry) + break + + if update_folder is None: + print("No update folder found. Skipping update steps.") + return + + if platform_name == 'darwin': + update_folder = os.path.join(update_folder, f"{app_name}-darwin") + + # 2. Wait for updater to exit + print(f"Waiting for {app_name}_updater to exit...") + for proc in psutil.process_iter(['pid', 'name']): + if proc.info['name'] and f"{app_name}_updater" in proc.info['name'].lower(): + try: + proc.wait(timeout=5) + except psutil.TimeoutExpired: + print(f"Force killing lingering {app_name}_updater") + proc.kill() + + # 3. Replace the updater + if platform_name == 'windows': + new_updater = os.path.join(update_folder, f"{app_name}_updater.exe") + dest_updater = os.path.join(app_dir, f"{app_name}_updater.exe") + + elif platform_name == 'darwin': + new_updater = os.path.join(update_folder, f"{app_name}_updater.app") + dest_updater = os.path.abspath(os.path.join(sys.executable, f"../../../../{app_name}_updater.app")) + + elif platform_name == 'linux': + new_updater = os.path.join(update_folder, f"{app_name}_updater") + dest_updater = os.path.join(app_dir, f"{app_name}_updater") + + else: + print("Unknown Platform") + new_updater = os.getcwd() + dest_updater = os.getcwd() + + print(f"New updater is {new_updater}") + print(f"Dest updater is {dest_updater}") + + print("Writable?", os.access(dest_updater, os.W_OK)) + print("Executable path:", sys.executable) + print("Trying to copy:", new_updater, "->", dest_updater) + + if os.path.exists(new_updater): + try: + if os.path.exists(dest_updater): + if platform_name == 'darwin': + try: + if os.path.isdir(dest_updater): + shutil.rmtree(dest_updater) + print(f"Deleted directory: {dest_updater}") + else: + os.remove(dest_updater) + print(f"Deleted file: {dest_updater}") + except Exception as e: + print(f"Error deleting {dest_updater}: {e}") + else: + os.remove(dest_updater) + + if platform_name == 'darwin': + wait_for_process_to_exit(f"{app_name}_updater", timeout=10) + subprocess.check_call(["ditto", new_updater, dest_updater]) + else: + shutil.copy2(new_updater, dest_updater) + + if platform_name in ('linux', 'darwin'): + os.chmod(dest_updater, 0o755) + + if platform_name == 'darwin': + remove_quarantine(dest_updater, app_name) + + print(f"{app_name}_updater replaced.") + except Exception as e: + print(f"Failed to replace {app_name}_updater: {e}") + + # 4. Delete the update folder + try: + if platform_name == 'darwin': + shutil.rmtree(app_dir) + else: + shutil.rmtree(update_folder) + except Exception as e: + print(f"Failed to delete update folder: {e}") + + QMessageBox.information(None, "Update Complete", "The application has been successfully updated.") + sys.argv.remove("--finish-update") + + +def remove_quarantine(app_path, app_name): + """ + Removes the macOS quarantine attribute from the specified application path. + """ + + script = f''' + do shell script "xattr -d -r com.apple.quarantine {shlex.quote(app_path)}" with administrator privileges with prompt "{app_name.upper()} needs privileges to finish the update. (2/2)" + ''' + try: + subprocess.run(['osascript', '-e', script], check=True) + print("✅ Quarantine attribute removed.") + except subprocess.CalledProcessError as e: + print("❌ Failed to remove quarantine attribute.") + print(e) \ No newline at end of file diff --git a/version_main.txt b/version_main.txt new file mode 100644 index 0000000..aee1c74 --- /dev/null +++ b/version_main.txt @@ -0,0 +1,29 @@ +VSVersionInfo( + ffi=FixedFileInfo( + filevers=(1, 0, 0, 0), + prodvers=(1, 0, 0, 0), + mask=0x3f, + flags=0x0, + OS=0x4, + fileType=0x1, + subtype=0x0, + date=(0, 0) + ), + kids=[ + StringFileInfo( + [ + StringTable( + '040904B0', + [StringStruct('CompanyName', 'Tyler de Zeeuw'), + StringStruct('FileDescription', 'BLAZES main application'), + StringStruct('FileVersion', '1.0.0.0'), + StringStruct('InternalName', 'blazes.exe'), + StringStruct('LegalCopyright', '© 2025-2026 Tyler de Zeeuw'), + StringStruct('OriginalFilename', 'blazes.exe'), + StringStruct('ProductName', 'BLAZES'), + StringStruct('ProductVersion', '1.0.0.0')]) + ] + ), + VarFileInfo([VarStruct('Translation', [1033, 1200])]) + ] +) diff --git a/version_updater.txt b/version_updater.txt new file mode 100644 index 0000000..cf06f0f --- /dev/null +++ b/version_updater.txt @@ -0,0 +1,29 @@ +VSVersionInfo( + ffi=FixedFileInfo( + filevers=(1, 0, 0, 0), + prodvers=(1, 0, 0, 0), + mask=0x3f, + flags=0x0, + OS=0x4, + fileType=0x1, + subtype=0x0, + date=(0, 0) + ), + kids=[ + StringFileInfo( + [ + StringTable( + '040904B0', + [StringStruct('CompanyName', 'Tyler de Zeeuw'), + StringStruct('FileDescription', 'BLAZES updater application'), + StringStruct('FileVersion', '1.0.0.0'), + StringStruct('InternalName', 'main.exe'), + StringStruct('LegalCopyright', '© 2025-2026 Tyler de Zeeuw'), + StringStruct('OriginalFilename', 'blazes_updater.exe'), + StringStruct('ProductName', 'BLAZES Updater'), + StringStruct('ProductVersion', '1.0.0.0')]) + ] + ), + VarFileInfo([VarStruct('Translation', [1033, 1200])]) + ] +) diff --git a/yolov8n-pose.pt b/yolov8n-pose.pt new file mode 100644 index 0000000..f41b11f Binary files /dev/null and b/yolov8n-pose.pt differ