diff --git a/.gitignore b/.gitignore
index 36b13f1..0d9aef9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,3 +174,5 @@ cython_debug/
 
 # PyPI configuration file
 .pypirc
+*.boris
+*.json
\ No newline at end of file
diff --git a/main.py b/main.py
index a2367cf..dc7fc0e 100644
--- a/main.py
+++ b/main.py
@@ -18,7 +18,7 @@ import platform
 import traceback
 from pathlib import Path
 from datetime import datetime
-from multiprocessing import current_process, freeze_support
+from multiprocessing import current_process, freeze_support, Process, Queue
 
 # External library imports
 import numpy as np
@@ -30,14 +30,15 @@ from ultralytics import YOLO
 
 from updater import finish_update_if_needed, UpdateManager, LocalPendingUpdateCheckThread
 from predictor import GeneralPredictor
+from pose_worker import run_pose_analysis
 from batch_processing import BatchProcessorDialog
 
 import PySide6
-from PySide6.QtWidgets import (QApplication, QMainWindow, QTabBar, QWidget, QVBoxLayout, QGraphicsView, QGraphicsScene,
+from PySide6.QtWidgets import (QApplication, QLineEdit, QListWidget, QListWidgetItem, QMainWindow, QProgressDialog, QTabBar, QWidget, QVBoxLayout, QGraphicsView, QGraphicsScene,
                                QHBoxLayout, QSplitter, QLabel, QPushButton, QComboBox, QInputDialog,
-                               QFileDialog, QScrollArea, QMessageBox, QSlider, QTextEdit, QGroupBox, QGridLayout, QCheckBox, QTabWidget)
-from PySide6.QtCore import Qt, QThread, Signal, QUrl, QRectF, QPointF, QRect, QSizeF
-from PySide6.QtGui import QPainter, QColor, QFont, QPen, QBrush, QAction, QKeySequence, QIcon, QTextOption, QImage, QPixmap
+                               QFileDialog, QScrollArea, QMessageBox, QSlider, QTextEdit, QGroupBox, QGridLayout, QCheckBox, QTabWidget, QProgressBar)
+from PySide6.QtCore import QEvent, Qt, QThread, Signal, QUrl, QRectF, QPointF, QRect, QSizeF, QTimer
+from PySide6.QtGui import QCursor, QGuiApplication, QPainter, QColor, QFont, QPen, QBrush, QAction, QKeySequence, QIcon, QTextOption, QImage, QPixmap
 from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput
 from PySide6.QtMultimediaWidgets import QGraphicsVideoItem
@@ -54,7 +55,7 @@ PLATFORM_NAME = platform.system().lower()
 def debug_print():
     if VERBOSITY:
         frame = inspect.currentframe().f_back
-        qualname = frame.f_code.co_filename
+        qualname = frame.f_code.co_qualname
         print(qualname)
@@ -109,6 +110,7 @@ from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
 from PySide6.QtGui import QPixmap, QImage
 from PySide6.QtCore import Qt
 
+
 class OpenFileWindow(QWidget):
     def __init__(self, parent=None):
         super().__init__(parent, Qt.WindowType.Window)
@@ -118,11 +120,30 @@ class OpenFileWindow(QWidget):
         # State
         self.video_path = None
         self.obs_file = None
+        self.pkl_path = None
         self.full_json_data = None
         self.current_video_fps = 30.0
         self.current_video_offset = 0.0
+        self.fps = 0
 
         self.setup_ui()
+        self.center_on_screen()
+
+    def center_on_screen(self):
+        """Centers the window on the current screen."""
+        # Get the geometry of the screen where the mouse currently is
+        screen = QGuiApplication.screenAt(QCursor.pos())
+        if not screen:
+            screen = QGuiApplication.primaryScreen()
+
+        screen_geometry = screen.availableGeometry()
+        size = self.sizeHint()  # Or self.geometry() if already sized
+
+        x = (screen_geometry.width() - size.width()) // 2
+        y = (screen_geometry.height() - size.height()) // 2
+
+        # Apply the coordinates (relative to the screen's top-left)
+        self.move(screen_geometry.left() + x, screen_geometry.top() + y)
 
     def setup_ui(self):
         self.setStyleSheet("""
@@ -167,58 +188,39 @@ class OpenFileWindow(QWidget):
main_layout.addWidget(video_group) # --- Section 2: Analysis Modes --- - self.adv_group = QGroupBox("Behavioral Analysis Mode") - adv_layout = QVBoxLayout(self.adv_group) - self.combo_mode = QComboBox() - self.combo_mode.addItems(["BORIS Project (JSON)", "Trained ML Model (.pkl)", "Bypass / Manual"]) - adv_layout.addWidget(self.combo_mode) - - self.mode_stack = QStackedWidget() + self.boris_group = QGroupBox("Human Coding (BORIS)") + self.boris_group.setCheckable(True) # User can toggle this section off + self.boris_group.setChecked(False) + boris_layout = QGridLayout(self.boris_group) - # Mode 1: BORIS - self.page_boris = QWidget() - m1_grid = QGridLayout(self.page_boris) self.btn_boris_file = QPushButton("Load .boris File") - self.lbl_session = QLabel("Session Key:") self.combo_boris_keys = QComboBox() - self.lbl_slot = QLabel("Video Slot:") self.combo_video_slot = QComboBox() - # Initial Disabled States - for w in [self.lbl_session, self.combo_boris_keys, self.lbl_slot, self.combo_video_slot]: - w.setEnabled(False) + boris_layout.addWidget(QLabel("BORIS File:"), 0, 0) + boris_layout.addWidget(self.btn_boris_file, 0, 1) + boris_layout.addWidget(QLabel("Session:"), 1, 0) + boris_layout.addWidget(self.combo_boris_keys, 1, 1) + boris_layout.addWidget(QLabel("Slot:"), 2, 0) + boris_layout.addWidget(self.combo_video_slot, 2, 1) + main_layout.addWidget(self.boris_group) - m1_grid.addWidget(QLabel("BORIS File:"), 0, 0) - m1_grid.addWidget(self.btn_boris_file, 0, 1) - m1_grid.addWidget(self.lbl_session, 1, 0) - m1_grid.addWidget(self.combo_boris_keys, 1, 1) - m1_grid.addWidget(self.lbl_slot, 2, 0) - m1_grid.addWidget(self.combo_video_slot, 2, 1) + # --- Section 3: Trained ML Model --- + self.pkl_group = QGroupBox("Automated Prediction (.pkl)") + self.pkl_group.setCheckable(True) # User can toggle this section off + self.pkl_group.setChecked(False) + pkl_layout = QGridLayout(self.pkl_group) - # Mode 2: PKL Model - self.page_pkl = QWidget() - m2_grid = QGridLayout(self.page_pkl) self.btn_pkl_file = QPushButton("Load .pkl Model") self.lbl_pkl_path = QLabel("No model selected...") - m2_grid.addWidget(QLabel("Model File:"), 0, 0) - m2_grid.addWidget(self.btn_pkl_file, 0, 1) - m2_grid.addWidget(QLabel("Selected:"), 1, 0) - m2_grid.addWidget(self.lbl_pkl_path, 1, 1) + + pkl_layout.addWidget(QLabel("Model File:"), 0, 0) + pkl_layout.addWidget(self.btn_pkl_file, 0, 1) + pkl_layout.addWidget(QLabel("Path:"), 1, 0) + pkl_layout.addWidget(self.lbl_pkl_path, 1, 1) + main_layout.addWidget(self.pkl_group) - # Mode 3: Bypass - self.page_bypass = QWidget() - m3_layout = QVBoxLayout(self.page_bypass) - self.lbl_bypass_info = QLabel("Bypass Mode: No behavioral data will be loaded.\nManual annotation mode enabled.") - self.lbl_bypass_info.setStyleSheet("color: #888; font-style: italic;") - m3_layout.addWidget(self.lbl_bypass_info) - - self.mode_stack.addWidget(self.page_boris) - self.mode_stack.addWidget(self.page_pkl) - self.mode_stack.addWidget(self.page_bypass) - adv_layout.addWidget(self.mode_stack) - main_layout.addWidget(self.adv_group) - - # --- Section 3: Inference --- + # --- Section 4: Inference --- self.cfg_group = QGroupBox("Inference Settings") c_grid = QGridLayout(self.cfg_group) self.check_use_cache = QCheckBox("Auto-search pose cache (.npy)") @@ -226,7 +228,7 @@ class OpenFileWindow(QWidget): self.lbl_model_prompt = QLabel("Pose Model:") self.combo_inference_model = QComboBox() - self.combo_inference_model.addItems(["YOLO11n-Pose", "YOLO11m-Pose", "Mediapipe BlazePose"]) + 
self.combo_inference_model.addItems(["YOLO8n-Pose", "YOLO8m-Pose", "Mediapipe BlazePose"]) self.check_bypass_inference = QCheckBox("Bypass Pose Inference") self.lbl_inf_warning = QLabel("⚠ WARNING: Nothing fancy. Raw video playback only.") @@ -251,30 +253,42 @@ class OpenFileWindow(QWidget): # Connections self.btn_pick_video.clicked.connect(self.handle_video_selection) - self.combo_mode.currentIndexChanged.connect(self.mode_stack.setCurrentIndex) self.btn_boris_file.clicked.connect(self.handle_boris_load) + self.btn_pkl_file.clicked.connect(self.handle_pkl_selection) self.combo_boris_keys.currentIndexChanged.connect(self.handle_session_change) self.combo_video_slot.currentIndexChanged.connect(self.handle_slot_change) - self.btn_pkl_file.clicked.connect(self.handle_pkl_selection) self.check_bypass_inference.toggled.connect(self.handle_inference_toggle) self.btn_cancel.clicked.connect(self.close) + self.boris_group.toggled.connect(self.update_metadata_display) + def format_time(self, seconds): h, m, s = int(seconds // 3600), int((seconds % 3600) // 60), int(seconds % 60) return f"{h:02d}:{m:02d}:{s:02d}" + def handle_video_selection(self): path, _ = QFileDialog.getOpenFileName(self, "Open Video", "", "Video Files (*.mp4 *.avi *.mkv)") if path: self.video_path = path self.lbl_video_path.setText(os.path.basename(path)) + + # Open video to extract properties cap = cv2.VideoCapture(path) - fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 - time_str = self.format_time(cap.get(cv2.CAP_PROP_FRAME_COUNT) / fps) - self.lbl_video_metadata.setText(f"RES: {int(cap.get(3))}x{int(cap.get(4))} | FPS: {fps:.2f} | LEN: {time_str}") + self.fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 + self.total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + self.video_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.video_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) cap.release() + self.render_preview(path) - if self.full_json_data: self.attempt_auto_match() + + # If BORIS JSON is already loaded, try to find this video in the slots + if self.full_json_data: + self.attempt_auto_match() + + # Refresh the label with all 4 fields (Res, FPS, Len, and Offset) + self.update_metadata_display() def handle_boris_load(self): path, _ = QFileDialog.getOpenFileName(self, "Select JSON", "", "JSON Files (*.json *.boris)") @@ -285,8 +299,8 @@ class OpenFileWindow(QWidget): with open(path, 'r') as f: self.full_json_data = json.load(f) obs = self.full_json_data.get("observations", {}) + print(f"\n[DEBUG] BORIS File Loaded. 
Found {len(obs)} sessions.") self.combo_boris_keys.setEnabled(True) - self.lbl_session.setEnabled(True) self.combo_boris_keys.clear() self.combo_boris_keys.addItems(list(obs.keys())) if self.video_path: self.attempt_auto_match() @@ -301,44 +315,100 @@ class OpenFileWindow(QWidget): self.combo_video_slot.blockSignals(True) self.combo_video_slot.clear() - slots = list(file_map.keys()) - self.combo_video_slot.addItems(slots) - self.combo_video_slot.setEnabled(True) - self.lbl_slot.setEnabled(True) - self.combo_video_slot.blockSignals(False) - if self.video_path: - video_filename = os.path.basename(self.video_path) - for i, slot in enumerate(slots): - if any(video_filename in f for f in file_map[slot]): - self.combo_video_slot.setCurrentIndex(i) - break + print(f"[DEBUG] Filtering slots for session: {session_key}") + for slot, files in file_map.items(): + # Check if there is at least one non-empty string in the list + valid_files = [f for f in files if isinstance(f, str) and f.strip()] + if valid_files: + display_name = os.path.basename(valid_files[0].replace('\\', '/')) + print(f" > Valid slot found: {slot} ({display_name})") + self.combo_video_slot.addItem(f"Slot {slot}: {display_name}", slot) + + self.combo_video_slot.setEnabled(True) + self.combo_video_slot.blockSignals(False) + self.handle_slot_change() + + def attempt_auto_match(self): + """Debugged auto-match: Scans all slots in all sessions for the filename.""" + if not self.video_path or not self.full_json_data: + return + + target_name = os.path.basename(self.video_path) + print(f"\n[DEBUG] ATTEMPTING AUTO-MATCH FOR: {target_name}") + + obs = self.full_json_data.get("observations", {}) + + for s_idx, (session_key, content) in enumerate(obs.items()): + file_map = content.get("file", {}) + for slot, files in file_map.items(): + for f_path in files: + # Normalize path for comparison + clean_f_path = f_path.replace('\\', '/') + json_filename = os.path.basename(clean_f_path) + + if json_filename == target_name: + print(f"[DEBUG] !!! 
MATCH FOUND !!!") + print(f" Session: {session_key}") + print(f" Slot: {slot}") + + # Update UI + self.combo_boris_keys.setCurrentIndex(s_idx) + # We must allow handle_session_change to finish before setting slot + for i in range(self.combo_video_slot.count()): + if self.combo_video_slot.itemData(i) == slot: + self.combo_video_slot.setCurrentIndex(i) + break + return + + print(f"[DEBUG] No match found for {target_name} in the JSON file mapping.") def handle_slot_change(self): session_key = self.combo_boris_keys.currentText() - slot_key = self.combo_video_slot.currentText() - if not session_key or not slot_key: return + # Pull the slot ID (e.g., "1") we stored in handle_session_change + slot_id = self.combo_video_slot.currentData() + + if not session_key or slot_id is None: + return + session_data = self.full_json_data.get("observations", {}).get(session_key, {}) - val = session_data.get("media_info", {}).get("offset", {}).get(str(slot_key)) + + # Navigate: media_info -> offset -> {slot_id} + offsets = session_data.get("media_info", {}).get("offset", {}) + val = offsets.get(str(slot_id)) # Ensure it's a string key + if val is not None: self.current_video_offset = float(val) - txt = self.lbl_video_metadata.text().split(" | Offset:")[0] - self.lbl_video_metadata.setText(f"{txt} | Offset: {self.current_video_offset}s") + else: + self.current_video_offset = 0.0 + + self.update_metadata_display() + + def update_metadata_display(self): + # Only update if a video has been selected + if not self.video_path: + self.lbl_video_metadata.setText("Metadata: N/A") + return + + # Check if we are in BORIS mode and have a valid offset + if self.boris_group.isChecked(): + offset_str = f" | Offset: {self.current_video_offset}s" + else: + offset_str = "" + + # Assemble the final string + # Assuming self.fps and self.total_frames were set in handle_video_selection + time_str = self.format_time(self.total_frames / self.fps) + + # Get original metadata text but update the offset part + base_text = f"RES: {self.video_w}x{self.video_h} | FPS: {self.fps:.2f} | LEN: {time_str}" + self.lbl_video_metadata.setText(f"{base_text}{offset_str}") - def attempt_auto_match(self): - obs = self.full_json_data.get("observations", {}) - v_name = os.path.splitext(os.path.basename(self.video_path))[0] - v_parts = v_name.split('_') - v_fp = f"{v_parts[0]}_{v_parts[1]}_{v_parts[-1]}" if len(v_parts) >= 3 else v_name - for i, sk in enumerate(obs.keys()): - s_parts = sk.split('_') - if len(s_parts) == 3 and f"{s_parts[0]}_{s_parts[1]}_{s_parts[2]}".lower() == v_fp.lower(): - self.combo_boris_keys.setCurrentIndex(i) - return def handle_pkl_selection(self): path, _ = QFileDialog.getOpenFileName(self, "Select Model", "", "Pickle Files (*.pkl)") if path: + self.pkl_path = path self.lbl_pkl_path.setText(os.path.basename(path)) def handle_inference_toggle(self, checked): @@ -357,7 +427,216 @@ class OpenFileWindow(QWidget): cap.release() + def get_config(self): + """Returns a dictionary of all user-selected settings.""" + return { + "video_path": self.video_path, + "total_frames": getattr(self, 'total_frames', 0), + "fps": self.fps, + + # BORIS Data + "use_boris": self.boris_group.isChecked(), + "obs_file": self.obs_file if self.boris_group.isChecked() else None, + "session_key": self.combo_boris_keys.currentText(), + "slot": self.combo_video_slot.currentData(), + "offset": self.current_video_offset, + + # ML Model Data + "use_pkl": self.pkl_group.isChecked(), + "pkl_path": self.pkl_path if self.pkl_group.isChecked() else None, + + # Inference 
Settings + "use_pose": not self.check_bypass_inference.isChecked(), + "pose_model": self.combo_inference_model.currentText(), + "use_cache": self.check_use_cache.isChecked(), + } + + +import os +from PySide6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, + QLabel, QFileDialog, QFrame, QComboBox) +from PySide6.QtCore import Qt + +class TrainModelWindow(QDialog): + def __init__(self, parent=None): + super().__init__(parent) + self.setWindowTitle(f"Train Model - {APP_NAME.upper()}") + self.setFixedSize(500, 550) # Slightly taller to fit stats + self.selected_folder = None + self.valid_pairs = [] # Stores (json_path, csv_path) + + self.setup_ui() + + def setup_ui(self): + layout = QVBoxLayout(self) + layout.setSpacing(12) + + # --- Section 1: Folder Selection --- + self.path_display = QLabel("No folder selected...") + self.path_display.setStyleSheet("background: #1e1e1e; padding: 8px; border-radius: 3px;") + btn_browse = QPushButton("Select Training Folder") + btn_browse.clicked.connect(self.browse_folder) + + layout.addWidget(QLabel("Data Source:")) + layout.addWidget(self.path_display) + layout.addWidget(btn_browse) + + # --- Section 2: Behavior Selection (Multi-Select) --- + layout.addWidget(QLabel("Select Target Behavior(s):")) + self.behavior_list = QListWidget() + self.behavior_list.setMinimumHeight(150) + self.behavior_list.itemChanged.connect(self.handle_selection_change) + layout.addWidget(self.behavior_list) + + # --- Section 3: Group Name (Conditional) --- + self.group_name_container = QWidget() + group_layout = QVBoxLayout(self.group_name_container) + group_layout.setContentsMargins(0, 0, 0, 0) + + group_layout.addWidget(QLabel("Combined Variable Name:")) + self.edit_group_name = QLineEdit() + self.edit_group_name.setPlaceholderText("e.g., Total_Movement") + group_layout.addWidget(self.edit_group_name) + + self.group_name_container.hide() # Hidden by default + layout.addWidget(self.group_name_container) + + # --- Section 4: Folder Statistics --- + self.stats_display = QLabel("Valid Pairs Found: 0") + self.stats_display.setStyleSheet("color: #00ffaa; font-family: 'Consolas'; background: #111; padding: 10px;") + layout.addWidget(self.stats_display) + + # --- Section 5: ML Architecture --- + self.method_dropdown = QComboBox() + self.method_dropdown.addItems(["Random Forest", "1D-CNN", "LSTM", "XGBoost"]) + layout.addWidget(QLabel("ML Architecture:")) + layout.addWidget(self.method_dropdown) + + layout.addStretch() + + # --- Final Actions --- + button_box = QHBoxLayout() + self.btn_train = QPushButton("Start Training") + self.btn_train.setEnabled(False) + self.btn_train.setStyleSheet("background-color: #2e7d32; font-weight: bold; padding: 8px;") + self.btn_train.clicked.connect(self.accept) + + btn_cancel = QPushButton("Cancel") + btn_cancel.clicked.connect(self.reject) + + button_box.addWidget(btn_cancel) + button_box.addWidget(self.btn_train) + layout.addLayout(button_box) + + + def browse_folder(self): + folder = QFileDialog.getExistingDirectory(self, "Select Training Data Folder") + if folder: + self.selected_folder = folder + self.path_display.setText(folder) + self.scan_and_parse_folder(folder) + + def scan_and_parse_folder(self, folder): + """Scans for pairs and tracks per-behavior statistics.""" + self.valid_pairs = [] + + # Structure: { "Mouthing": {"count": 0, "frames": 0}, ... 
} + behavior_stats = {} + + total_global_events = 0 + total_global_frames = 0 + + files = os.listdir(folder) + json_files = [f for f in files if f.endswith("_metrics.json")] + + for j_file in json_files: + base_name = j_file.replace("_metrics.json", "") + csv_file = base_name + "_pose_raw.csv" + + json_path = os.path.join(folder, j_file) + csv_path = os.path.join(folder, csv_file) + + if os.path.exists(csv_path): + self.valid_pairs.append((json_path, csv_path)) + + try: + with open(json_path, 'r') as f: + data = json.load(f) + behaviors = data.get("behaviors", {}) + fps = data.get("metadata", {}).get("fps", 30.0) + + for b_name, instances in behaviors.items(): + if b_name not in behavior_stats: + behavior_stats[b_name] = {"count": 0, "frames": 0} + + count = len(instances) + frames = sum(inst.get("duration_frames", 0) for inst in instances) + + behavior_stats[b_name]["count"] += count + behavior_stats[b_name]["frames"] += frames + + total_global_events += count + total_global_frames += frames + except Exception as e: + print(f"Error parsing {j_file}: {e}") + + # --- Update Dataset Summary Label --- + pair_count = len(self.valid_pairs) + total_sec = total_global_frames / 30.0 # Standardized estimate + + stats_text = ( + f"Valid Pairs Found: {pair_count}\n" + f"Total Event Instances: {total_global_events}\n" + f"Total Behavior Time: {total_sec:.2f}s" + ) + self.stats_display.setText(stats_text) + + # --- Populate Dropdown with Detailed Labels --- + self.behavior_list.clear() + for b_name in sorted(behavior_stats.keys()): + stats = behavior_stats[b_name] + sec = stats["frames"] / 30.0 + label = f"{b_name} ({stats['count']} events, {sec:.1f}s)" + + item = QListWidgetItem(label) + item.setFlags(item.flags() | Qt.ItemIsUserCheckable) + item.setCheckState(Qt.Unchecked) + item.setData(Qt.UserRole, b_name) # Store clean name + self.behavior_list.addItem(item) + + + def handle_selection_change(self): + """Shows/Hides the name input based on how many boxes are checked.""" + selected_items = [self.behavior_list.item(i) for i in range(self.behavior_list.count()) + if self.behavior_list.item(i).checkState() == Qt.Checked] + + count = len(selected_items) + self.group_name_container.setVisible(count > 1) + self.btn_train.setEnabled(count > 0) + + + def get_selection(self): + """Returns the specific behaviors to combine and the final variable name.""" + selected_names = [self.behavior_list.item(i).data(Qt.UserRole) + for i in range(self.behavior_list.count()) + if self.behavior_list.item(i).checkState() == Qt.Checked] + + # If multiple are selected, use the text field name; otherwise use the single name + if len(selected_names) > 1: + final_name = self.edit_group_name.text().strip() or "combined_variable" + else: + final_name = selected_names[0] if selected_names else None + + return { + "folder": self.selected_folder, + "pairs": self.valid_pairs, + "selected_behaviors": selected_names, + "target_name": final_name, + "model_type": self.method_dropdown.currentText() + } + + class AboutWindow(QWidget): """ @@ -432,387 +711,387 @@ class UserGuideWindow(QWidget): -class PoseAnalyzerWorker(QThread): - progress = Signal(str) - finished_data = Signal(dict) +# class PoseAnalyzerWorker(QThread): +# progress = Signal(str) +# finished_data = Signal(dict) - def __init__(self, video_path, obs_info=None, predictor=None): - debug_print() - super().__init__() - self.video_path = video_path - self.obs_info = obs_info - self.predictor = predictor - self.pose_df = pd.DataFrame() +# def __init__(self, video_path, 
obs_info=None, predictor=None): +# debug_print() +# super().__init__() +# self.video_path = video_path +# self.obs_info = obs_info +# self.predictor = predictor +# self.pose_df = pd.DataFrame() - def get_best_infant_match(self, results, w, h, prev_track_id): - debug_print() - if not results[0].boxes or results[0].boxes.id is None: - return None, None, None, None - ids = results[0].boxes.id.int().cpu().tolist() - kpts = results[0].keypoints.xy.cpu().numpy() - confs = results[0].keypoints.conf.cpu().numpy() - best_idx, best_score = -1, -1 - for i, k in enumerate(kpts): - vis = np.sum(confs[i] > 0.5) - valid = k[confs[i] > 0.5] - dist = np.linalg.norm(np.mean(valid, axis=0) - [w/2, h/2]) if len(valid) > 0 else 1000 - score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0) - if score > best_score: - best_score, best_idx = score, i - if best_idx == -1: - return None, None, None, None - return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx +# def get_best_infant_match(self, results, w, h, prev_track_id): +# debug_print() +# if not results[0].boxes or results[0].boxes.id is None: +# return None, None, None, None +# ids = results[0].boxes.id.int().cpu().tolist() +# kpts = results[0].keypoints.xy.cpu().numpy() +# confs = results[0].keypoints.conf.cpu().numpy() +# best_idx, best_score = -1, -1 +# for i, k in enumerate(kpts): +# vis = np.sum(confs[i] > 0.5) +# valid = k[confs[i] > 0.5] +# dist = np.linalg.norm(np.mean(valid, axis=0) - [w/2, h/2]) if len(valid) > 0 else 1000 +# score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0) +# if score > best_score: +# best_score, best_idx = score, i +# if best_idx == -1: +# return None, None, None, None +# return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx - def _merge_json_observations(self, timeline_events, fps): - """Restores the grouping and block-pairing logic from the observation files.""" - debug_print() - if not self.obs_info: - return +# def _merge_json_observations(self, timeline_events, fps): +# """Restores the grouping and block-pairing logic from the observation files.""" +# debug_print() +# if not self.obs_info: +# return - self.progress.emit("Merging JSON Observations...") - json_path, subkey = self.obs_info +# self.progress.emit("Merging JSON Observations...") +# json_path, subkey = self.obs_info - # try: - # with open(json_path, 'r') as f: - # full_json = json.load(f) +# # try: +# # with open(json_path, 'r') as f: +# # full_json = json.load(f) - # # Extract events for the specific subkey (e.g., 'Participant_01') - # raw_obs_events = full_json["observations"][subkey]["events"] - # raw_obs_events.sort(key=lambda x: x[0]) # Sort by timestamp +# # # Extract events for the specific subkey (e.g., 'Participant_01') +# # raw_obs_events = full_json["observations"][subkey]["events"] +# # raw_obs_events.sort(key=lambda x: x[0]) # Sort by timestamp - # # Group frames by label - # obs_groups = {} - # for ev in raw_obs_events: - # time_sec, _, label, special = ev[0], ev[1], ev[2], ev[3] - # frame = int(time_sec * fps) - # if label not in obs_groups: - # obs_groups[label] = [] - # obs_groups[label].append(frame) +# # # Group frames by label +# # obs_groups = {} +# # for ev in raw_obs_events: +# # time_sec, _, label, special = ev[0], ev[1], ev[2], ev[3] +# # frame = int(time_sec * fps) +# # if label not in obs_groups: +# # obs_groups[label] = [] +# # obs_groups[label].append(frame) - # # Convert groups of frames into (Start, End) blocks - # for label, frames in obs_groups.items(): - # track_name = 
f"OBS: {label}" - # processed_blocks = [] +# # # Convert groups of frames into (Start, End) blocks +# # for label, frames in obs_groups.items(): +# # track_name = f"OBS: {label}" +# # processed_blocks = [] - # # Step by 2 to create start/end pairs - # for i in range(0, len(frames) - 1, 2): - # start_f = frames[i] - # end_f = frames[i+1] - # processed_blocks.append((start_f, end_f, "Moderate", "Manual")) +# # # Step by 2 to create start/end pairs +# # for i in range(0, len(frames) - 1, 2): +# # start_f = frames[i] +# # end_f = frames[i+1] +# # processed_blocks.append((start_f, end_f, "Moderate", "Manual")) - # # Register the track globally if it's new - # if track_name not in TRACK_NAMES: - # TRACK_NAMES.append(track_name) - # TRACK_COLORS.append(QColor("#AA00FF")) # Purple for Observations +# # # Register the track globally if it's new +# # if track_name not in TRACK_NAMES: +# # TRACK_NAMES.append(track_name) +# # TRACK_COLORS.append(QColor("#AA00FF")) # Purple for Observations - # timeline_events[track_name] = processed_blocks +# # timeline_events[track_name] = processed_blocks - # except Exception as e: - # print(f"Error parsing JSON Observations: {e}") +# # except Exception as e: +# # print(f"Error parsing JSON Observations: {e}") - try: - with open(json_path, 'r') as f: - full_json = json.load(f) +# try: +# with open(json_path, 'r') as f: +# full_json = json.load(f) - raw_obs_events = full_json["observations"][subkey]["events"] - raw_obs_events.sort(key=lambda x: x[0]) +# raw_obs_events = full_json["observations"][subkey]["events"] +# raw_obs_events.sort(key=lambda x: x[0]) - # NEW LOGIC: Use a dictionary to store frames for specific track names - # track_name -> [list of frames] - obs_groups = {} +# # NEW LOGIC: Use a dictionary to store frames for specific track names +# # track_name -> [list of frames] +# obs_groups = {} - for ev in raw_obs_events: - # ev structure: [time_sec, unknown, label, special] - time_sec, label, special = ev[0], ev[2], ev[3] - frame = int(time_sec * fps) +# for ev in raw_obs_events: +# # ev structure: [time_sec, unknown, label, special] +# time_sec, label, special = ev[0], ev[2], ev[3] +# frame = int(time_sec * fps) - # Determine which tracks this event belongs to - target_tracks = [] +# # Determine which tracks this event belongs to +# target_tracks = [] - if special == "Left": - target_tracks.append(f"OBS: {label} (Left)") - elif special == "Right": - target_tracks.append(f"OBS: {label} (Right)") - elif special == "Both": - target_tracks.append(f"OBS: {label} (Left)") - target_tracks.append(f"OBS: {label} (Right)") - else: - # No special or unrecognized value - target_tracks.append(f"OBS: {label}") +# if special == "Left": +# target_tracks.append(f"OBS: {label} (Left)") +# elif special == "Right": +# target_tracks.append(f"OBS: {label} (Right)") +# elif special == "Both": +# target_tracks.append(f"OBS: {label} (Left)") +# target_tracks.append(f"OBS: {label} (Right)") +# else: +# # No special or unrecognized value +# target_tracks.append(f"OBS: {label}") - # Add the frame to all applicable tracks - for t_name in target_tracks: - if t_name not in obs_groups: - obs_groups[t_name] = [] - obs_groups[t_name].append(frame) +# # Add the frame to all applicable tracks +# for t_name in target_tracks: +# if t_name not in obs_groups: +# obs_groups[t_name] = [] +# obs_groups[t_name].append(frame) - # Convert frame groups into (Start, End) blocks - for track_name, frames in obs_groups.items(): - processed_blocks = [] +# # Convert frame groups into (Start, End) blocks +# 
for track_name, frames in obs_groups.items(): +# processed_blocks = [] - # Step by 2 to create start/end pairs (ensures matching pairs per track) +# # Step by 2 to create start/end pairs (ensures matching pairs per track) - if "Sync" in track_name and len(frames) == 1: - start_f = frames[0] - end_f = start_f + 1 # Give it a visible width on the timeline - processed_blocks.append((start_f, end_f, "Moderate", "Manual")) +# if "Sync" in track_name and len(frames) == 1: +# start_f = frames[0] +# end_f = start_f + 1 # Give it a visible width on the timeline +# processed_blocks.append((start_f, end_f, "Moderate", "Manual")) - else: - for i in range(0, len(frames) - 1, 2): - start_f = frames[i] - end_f = frames[i+1] - processed_blocks.append((start_f, end_f, "Moderate", "Manual")) +# else: +# for i in range(0, len(frames) - 1, 2): +# start_f = frames[i] +# end_f = frames[i+1] +# processed_blocks.append((start_f, end_f, "Moderate", "Manual")) - # Register the track in global lists if not already there - if track_name not in TRACK_NAMES: - TRACK_NAMES.append(track_name) - # Using Purple for Observations - TRACK_COLORS.append(QColor("#AA00FF")) +# # Register the track in global lists if not already there +# if track_name not in TRACK_NAMES: +# TRACK_NAMES.append(track_name) +# # Using Purple for Observations +# TRACK_COLORS.append(QColor("#AA00FF")) - timeline_events[track_name] = processed_blocks +# timeline_events[track_name] = processed_blocks - except Exception as e: - print(f"Error parsing JSON Observations: {e}") +# except Exception as e: +# print(f"Error parsing JSON Observations: {e}") - def _run_existing_ml_models(self, z_kps, dirs, raw_kpts): - debug_print() - """ - Scans for trained models and generates timeline tracks for each. - """ - ai_events = {} +# def _run_existing_ml_models(self, z_kps, dirs, raw_kpts): +# debug_print() +# """ +# Scans for trained models and generates timeline tracks for each. +# """ +# ai_events = {} - # 1. Match the pattern from your GeneralPredictor: {Target}_rf.pkl - model_files = glob.glob("*_rf.pkl") - print(f"DEBUG: Found model files: {model_files}") +# # 1. Match the pattern from your GeneralPredictor: {Target}_rf.pkl +# model_files = glob.glob("*_rf.pkl") +# print(f"DEBUG: Found model files: {model_files}") - for model_path in model_files: - try: - # Extract Target (e.g., "Mouthing" from "Mouthing_rf.pkl") - base_name = model_path.split("_rf.pkl")[0] - target = base_name.replace("ml_", "", 1) - track_name = f"AI: {target}" +# for model_path in model_files: +# try: +# # Extract Target (e.g., "Mouthing" from "Mouthing_rf.pkl") +# base_name = model_path.split("_rf.pkl")[0] +# target = base_name.replace("ml_", "", 1) +# track_name = f"AI: {target}" - self.progress.emit(f"Loading AI Observations for {target}...") +# self.progress.emit(f"Loading AI Observations for {target}...") - # 2. Match the Scaler naming from calculate_and_train: - # {target}_random_forest_scaler.pkl - scaler_path = f"{base_name}_rf_scaler.pkl" +# # 2. 
Match the Scaler naming from calculate_and_train: +# # {target}_random_forest_scaler.pkl +# scaler_path = f"{base_name}_rf_scaler.pkl" - if not os.path.exists(scaler_path): - print(f"DEBUG: Skipping {target}, scaler not found at {scaler_path}") - continue +# if not os.path.exists(scaler_path): +# print(f"DEBUG: Skipping {target}, scaler not found at {scaler_path}") +# continue - # Load assets - model = joblib.load(model_path) - scaler = joblib.load(scaler_path) +# # Load assets +# model = joblib.load(model_path) +# scaler = joblib.load(scaler_path) - # 3. Feature extraction (On-the-fly) - all_features = [] - # We must set the predictor's target so format_features uses the correct ACTIVITY_MAP - self.predictor.current_target = target +# # 3. Feature extraction (On-the-fly) +# all_features = [] +# # We must set the predictor's target so format_features uses the correct ACTIVITY_MAP +# self.predictor.current_target = target - for f_idx in range(len(z_kps)): - feat = self.predictor.format_features(z_kps[f_idx], dirs[f_idx], raw_kpts[f_idx]) - all_features.append(feat) +# for f_idx in range(len(z_kps)): +# feat = self.predictor.format_features(z_kps[f_idx], dirs[f_idx], raw_kpts[f_idx]) +# all_features.append(feat) - # 4. Inference - X = np.array(all_features) - X_scaled = scaler.transform(X) - predictions = model.predict(X_scaled) +# # 4. Inference +# X = np.array(all_features) +# X_scaled = scaler.transform(X) +# predictions = model.predict(X_scaled) - # 5. Convert binary 0/1 to blocks - processed_blocks = [] - start_f = None +# # 5. Convert binary 0/1 to blocks +# processed_blocks = [] +# start_f = None - for f_idx, val in enumerate(predictions): - if val == 1 and start_f is None: - start_f = f_idx - elif val == 0 and start_f is not None: - # [start, end, severity, direction] - processed_blocks.append((start_f, f_idx - 1, "Large", "AI")) - start_f = None +# for f_idx, val in enumerate(predictions): +# if val == 1 and start_f is None: +# start_f = f_idx +# elif val == 0 and start_f is not None: +# # [start, end, severity, direction] +# processed_blocks.append((start_f, f_idx - 1, "Large", "AI")) +# start_f = None - if start_f is not None: - processed_blocks.append((start_f, len(predictions)-1, "Large", "AI")) +# if start_f is not None: +# processed_blocks.append((start_f, len(predictions)-1, "Large", "AI")) - # 6. Global Registration - if track_name not in TRACK_NAMES: - TRACK_NAMES.append(track_name) - # Ensure TRACK_COLORS has an entry for this new track - TRACK_COLORS.append(QColor("#00FF00")) +# # 6. 
Global Registration +# if track_name not in TRACK_NAMES: +# TRACK_NAMES.append(track_name) +# # Ensure TRACK_COLORS has an entry for this new track +# TRACK_COLORS.append(QColor("#00FF00")) - ai_events[track_name] = processed_blocks - print(f"DEBUG: Successfully generated {len(processed_blocks)} blocks for {track_name}") +# ai_events[track_name] = processed_blocks +# print(f"DEBUG: Successfully generated {len(processed_blocks)} blocks for {track_name}") - except Exception as e: - print(f"Inference Error for {model_path}: {e}") +# except Exception as e: +# print(f"Inference Error for {model_path}: {e}") - return ai_events +# return ai_events - def classify_delta(self, z): - # debug_print() - z_abs = abs(z) - if z_abs < 1: return "Rest" - elif z_abs < 2: return "Small" - elif z_abs < 3: return "Moderate" - else: return "Large" +# def classify_delta(self, z): +# # debug_print() +# z_abs = abs(z) +# if z_abs < 1: return "Rest" +# elif z_abs < 2: return "Small" +# elif z_abs < 3: return "Moderate" +# else: return "Large" - def _save_pose_cache(self, path, data): - """ - Saves the raw YOLO keypoints and confidence scores to a CSV. - Each row represents one frame, flattened from (17, 3) to (51,). - """ - try: - with open(path, 'w', newline='') as f: - writer = csv.writer(f) +# def _save_pose_cache(self, path, data): +# """ +# Saves the raw YOLO keypoints and confidence scores to a CSV. +# Each row represents one frame, flattened from (17, 3) to (51,). +# """ +# try: +# with open(path, 'w', newline='') as f: +# writer = csv.writer(f) - # Create the descriptive header - header = [] - for joint in JOINT_NAMES: - # Replace spaces with underscores for better compatibility with other tools - header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"]) +# # Create the descriptive header +# header = [] +# for joint in JOINT_NAMES: +# # Replace spaces with underscores for better compatibility with other tools +# header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"]) - writer.writerow(header) +# writer.writerow(header) - # Write the frame data - for frame_data in data: - # frame_data is (17, 3), flatten to (51,) - writer.writerow(frame_data.flatten()) +# # Write the frame data +# for frame_data in data: +# # frame_data is (17, 3), flatten to (51,) +# writer.writerow(frame_data.flatten()) - print(f"DEBUG: Pose cache saved with joint headers at {path}") - except Exception as e: - print(f"ERROR: Could not save pose cache: {e}") +# print(f"DEBUG: Pose cache saved with joint headers at {path}") +# except Exception as e: +# print(f"ERROR: Could not save pose cache: {e}") - def run(self): - debug_print() - # --- PHASE 1: VIDEO SETUP & POSE EXTRACTION --- - cap = cv2.VideoCapture(self.video_path) - fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 - width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) +# def run(self): +# debug_print() +# # --- PHASE 1: VIDEO SETUP & POSE EXTRACTION --- +# cap = cv2.VideoCapture(self.video_path) +# fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 +# width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) +# height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) +# total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - raw_kps_per_frame = [] - csv_storage_data = [] - valid_mask = [] - pose_cache = self.video_path.rsplit('.', 1)[0] + "_pose_raw.csv" +# raw_kps_per_frame = [] +# csv_storage_data = [] +# valid_mask = [] +# pose_cache = self.video_path.rsplit('.', 1)[0] + "_pose_raw.csv" - if 
os.path.exists(pose_cache): - self.progress.emit("Loading cached kinematic data...") - with open(pose_cache, 'r') as f: - reader = csv.reader(f) - next(reader) - for row in reader: - full_data = np.array([float(x) for x in row]).reshape(17, 3) - kp = full_data[:, :2] - raw_kps_per_frame.append(kp) - csv_storage_data.append(full_data) - valid_mask.append(np.any(kp)) - else: - self.progress.emit("Detecting poses with YOLO...") - model = YOLO("yolov8n-pose.pt") - prev_track_id = None - for i in range(total_frames): - ret, frame = cap.read() - if not ret: break - results = model.track(frame, persist=True, verbose=False) - track_id, kp, confs, _ = self.get_best_infant_match(results, width, height, prev_track_id) - if kp is not None: - prev_track_id = track_id - raw_kps_per_frame.append(kp) - csv_storage_data.append(np.column_stack((kp, confs))) - valid_mask.append(True) - else: - raw_kps_per_frame.append(np.zeros((17, 2))) - csv_storage_data.append(np.zeros((17, 3))) - valid_mask.append(False) - if i % 50 == 0: self.progress.emit(f"YOLO: {int((i/total_frames)*100)}%") - self._save_pose_cache(pose_cache, csv_storage_data) +# if os.path.exists(pose_cache): +# self.progress.emit("Loading cached kinematic data...") +# with open(pose_cache, 'r') as f: +# reader = csv.reader(f) +# next(reader) +# for row in reader: +# full_data = np.array([float(x) for x in row]).reshape(17, 3) +# kp = full_data[:, :2] +# raw_kps_per_frame.append(kp) +# csv_storage_data.append(full_data) +# valid_mask.append(np.any(kp)) +# else: +# self.progress.emit("Detecting poses with YOLO...") +# model = YOLO("yolov8n-pose.pt") +# prev_track_id = None +# for i in range(total_frames): +# ret, frame = cap.read() +# if not ret: break +# results = model.track(frame, persist=True, verbose=False) +# track_id, kp, confs, _ = self.get_best_infant_match(results, width, height, prev_track_id) +# if kp is not None: +# prev_track_id = track_id +# raw_kps_per_frame.append(kp) +# csv_storage_data.append(np.column_stack((kp, confs))) +# valid_mask.append(True) +# else: +# raw_kps_per_frame.append(np.zeros((17, 2))) +# csv_storage_data.append(np.zeros((17, 3))) +# valid_mask.append(False) +# if i % 50 == 0: self.progress.emit(f"YOLO: {int((i/total_frames)*100)}%") +# self._save_pose_cache(pose_cache, csv_storage_data) - cap.release() - actual_len = len(raw_kps_per_frame) +# cap.release() +# actual_len = len(raw_kps_per_frame) - flattened_rows = [] - for frame_array in csv_storage_data: - # frame_array is (17, 3) -> flatten to (51,) - flattened_rows.append(frame_array.flatten()) +# flattened_rows = [] +# for frame_array in csv_storage_data: +# # frame_array is (17, 3) -> flatten to (51,) +# flattened_rows.append(frame_array.flatten()) - columns = [] - for name in JOINT_NAMES: - columns.extend([f"{name}_x", f"{name}_y", f"{name}_conf"]) +# columns = [] +# for name in JOINT_NAMES: +# columns.extend([f"{name}_x", f"{name}_y", f"{name}_conf"]) - # Store this so the Inspector can access it instantly in memory - self.pose_df = pd.DataFrame(flattened_rows, columns=columns) +# # Store this so the Inspector can access it instantly in memory +# self.pose_df = pd.DataFrame(flattened_rows, columns=columns) - # --- PHASE 2: KINEMATICS & Z-SCORES --- - self.progress.emit("Calculating Kinematics...") - analysis_kpts = [] - for kp in raw_kps_per_frame: - pelvis = (kp[11] + kp[12]) / 2 - analysis_kpts.append(kp - pelvis) +# # --- PHASE 2: KINEMATICS & Z-SCORES --- +# self.progress.emit("Calculating Kinematics...") +# analysis_kpts = [] +# for kp in 
raw_kps_per_frame: +# pelvis = (kp[11] + kp[12]) / 2 +# analysis_kpts.append(kp - pelvis) - valid_data = [analysis_kpts[i] for i, v in enumerate(valid_mask) if v] - if valid_data: - stacked = np.stack(valid_data) - baseline_mean = np.mean(stacked, axis=0) - baseline_std = np.std(np.linalg.norm(stacked - baseline_mean, axis=2), axis=0) + 1e-6 - else: - baseline_mean, baseline_std = np.zeros((17, 2)), np.ones(17) +# valid_data = [analysis_kpts[i] for i, v in enumerate(valid_mask) if v] +# if valid_data: +# stacked = np.stack(valid_data) +# baseline_mean = np.mean(stacked, axis=0) +# baseline_std = np.std(np.linalg.norm(stacked - baseline_mean, axis=2), axis=0) + 1e-6 +# else: +# baseline_mean, baseline_std = np.zeros((17, 2)), np.ones(17) - np_raw_kps = np.array(raw_kps_per_frame) - np_z_kps = np.array([np.linalg.norm(kp - baseline_mean, axis=1) / baseline_std for kp in analysis_kpts]) +# np_raw_kps = np.array(raw_kps_per_frame) +# np_z_kps = np.array([np.linalg.norm(kp - baseline_mean, axis=1) / baseline_std for kp in analysis_kpts]) - # Calculate directions (Assume you have a method for this or use a dummy for now) - # Using placeholder empty strings to prevent errors in track generation - np_dirs = np.full((actual_len, 17), "", dtype=object) +# # Calculate directions (Assume you have a method for this or use a dummy for now) +# # Using placeholder empty strings to prevent errors in track generation +# np_dirs = np.full((actual_len, 17), "", dtype=object) - # --- PHASE 3: TIMELINE GENERATION --- - # Initialize dictionary with ALL global track names to prevent KeyErrors - timeline_events = {name: [] for name in TRACK_NAMES} +# # --- PHASE 3: TIMELINE GENERATION --- +# # Initialize dictionary with ALL global track names to prevent KeyErrors +# timeline_events = {name: [] for name in TRACK_NAMES} - # 1. Kinematic Events (The joint tracks) - for j_idx, joint_name in enumerate(JOINT_NAMES): - current_block = None - for f_idx in range(actual_len): - severity = self.classify_delta(np_z_kps[f_idx, j_idx]) - if severity != "Rest": - if current_block and current_block[2] == severity: - current_block[1] = f_idx - else: - current_block = [f_idx, f_idx, severity, ""] - timeline_events[joint_name].append(current_block) - else: - current_block = None +# # 1. Kinematic Events (The joint tracks) +# for j_idx, joint_name in enumerate(JOINT_NAMES): +# current_block = None +# for f_idx in range(actual_len): +# severity = self.classify_delta(np_z_kps[f_idx, j_idx]) +# if severity != "Rest": +# if current_block and current_block[2] == severity: +# current_block[1] = f_idx +# else: +# current_block = [f_idx, f_idx, severity, ""] +# timeline_events[joint_name].append(current_block) +# else: +# current_block = None - # 2. JSON Observations - self._merge_json_observations(timeline_events, fps) +# # 2. JSON Observations +# self._merge_json_observations(timeline_events, fps) - # 3. AI Inferred Events - ai_events = self._run_existing_ml_models(np_z_kps, np_dirs, np_raw_kps) - timeline_events.update(ai_events) +# # 3. 
AI Inferred Events +# ai_events = self._run_existing_ml_models(np_z_kps, np_dirs, np_raw_kps) +# timeline_events.update(ai_events) - # --- PHASE 4: EMIT --- - data = { - "video_path": self.video_path, - "fps": fps, - "total_frames": actual_len, - "width": width, "height": height, - "events": timeline_events, - "raw_kps": np_raw_kps, - "z_kps": np_z_kps, - "directions": np_dirs, - "baseline_kp_mean": baseline_mean - } - self.progress.emit("Analysis Complete!") - self.finished_data.emit(data) +# # --- PHASE 4: EMIT --- +# data = { +# "video_path": self.video_path, +# "fps": fps, +# "total_frames": actual_len, +# "width": width, "height": height, +# "events": timeline_events, +# "raw_kps": np_raw_kps, +# "z_kps": np_z_kps, +# "directions": np_dirs, +# "baseline_kp_mean": baseline_mean +# } +# self.progress.emit("Analysis Complete!") +# self.finished_data.emit(data) @@ -826,110 +1105,68 @@ class PoseAnalyzerWorker(QThread): # ========================================== # TIMELINE WIDGET # ========================================== +import numpy as np +from PySide6.QtWidgets import QWidget, QScrollArea +from PySide6.QtCore import Qt, Signal, QRect, QRectF +from PySide6.QtGui import QPainter, QPen, QColor, QFont, QBrush + + + class TimelineWidget(QWidget): seek_requested = Signal(int) visibility_changed = Signal(set) track_selected = Signal(str) - def __init__(self): - debug_print() - super().__init__() + + def __init__(self, parent=None): + super().__init__(parent) self.data = None + self.track_names = [] + self.track_colors = [] self.current_frame = 0 - self.zoom_factor = 1.0 # Pixels per frame - self.label_width = 160 # Fixed gutter for track names + self.zoom_factor = 1.0 + self.label_width = 160 self.track_height = 25 self.ruler_height = 20 - self.scrollbar_buffer = 2 # Extra space for the horizontal scrollbar + self.scrollbar_buffer = 2 self.hidden_tracks = set() - self.sync_offset = 0.0 + self.sync_offset = 0.0 self.sync_fps = 30.0 - # Calculate total required height - self.total_content_height = (NUM_TRACKS * self.track_height) + self.ruler_height + self.is_scrubbing = False + + self.total_content_height = (15 * self.track_height) + self.ruler_height self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer) - def set_sync_params(self, offset_seconds, fps=None): - """ - Updates the temporal shift parameters and refreshes the UI. 
- """ - debug_print() - self.sync_offset = float(offset_seconds) + + def set_data(self, events_dict, total_frames, fps): + """Expects grouped events from your BORIS loader.""" + self.track_names = sorted(list(events_dict.keys())) + self.data = { + "events": events_dict, + "total_frames": total_frames, + "fps": fps + } + self.sync_fps = fps - # Only update FPS if a valid value is provided, - # otherwise keep the existing data/video FPS - if fps and fps > 0: - self.sync_fps = float(fps) - elif self.data and "fps" in self.data: - self.sync_fps = float(self.data["fps"]) - - print(f"DEBUG: Timeline Sync Set - Offset: {self.sync_offset}s, FPS: {self.sync_fps}") - - # Trigger paintEvent to redraw the blocks in their new shifted positions + # Generate colors dynamically since we don't know the tracks ahead of time + self.track_colors = [QColor.fromHsl((i * 360 // max(1, len(self.track_names))), 160, 140) + for i in range(len(self.track_names))] + self.update_geometry() self.update() - def set_zoom(self, factor): - debug_print() - if not self.data: return - - # Calculate MIN zoom: The zoom required to make the video fit the width exactly - # (Available Width - Sidebar) / Total Frames - available_w = self.parent().width() - self.label_width if self.parent() else 800 - min_zoom = available_w / self.data["total_frames"] - - # Clamp: Don't zoom out past the video end, don't zoom in to infinity - self.zoom_factor = max(min_zoom, min(factor, 50.0)) - self.update_geometry() - - def get_all_binary_labels(self, offset_seconds=0.0, fps=30.0): - """ - Extracts binary labels for ALL tracks in self.data["events"]. - Returns a dict: {'OBS: Mouthing': [0,1,0...], 'OBS: Kicking': [0,0,1...]} - """ - debug_print() - all_labels = {} - - if not self.data or "events" not in self.data: - return all_labels - - total_frames = self.data.get("total_frames", 0) - if total_frames == 0: - return all_labels - - frame_shift = int(offset_seconds * fps) - - for track_name in self.data["events"]: - sequence = np.zeros(total_frames, dtype=int) - - for event in self.data["events"][track_name]: - start_f = int(event[0]) - frame_shift - end_f = int(event[1]) - frame_shift - - # Clamp values - start_idx = max(0, min(start_f, total_frames - 1)) - end_idx = max(0, min(end_f, total_frames)) - - if start_idx < end_idx: - sequence[start_idx:end_idx] = 1 - - all_labels[track_name] = sequence.tolist() # Convert to list for easier storage - - return all_labels - - def update_geometry(self): - debug_print() - if self.data: # Width is sidebar + (frames * zoom) total_w = self.label_width + int(self.data["total_frames"] * self.zoom_factor) self.setFixedWidth(total_w) self.update() + def wheelEvent(self, event): - debug_print() + # debug_print() if event.modifiers() == Qt.ControlModifier: delta = event.angleDelta().y() @@ -940,9 +1177,9 @@ class TimelineWidget(QWidget): # Let the scroll area handle normal vertical scrolling super().wheelEvent(event) - # --- NEW: CTRL + Plus / Minus / Zero --- + def keyPressEvent(self, event): - debug_print() + #debug_print() if event.modifiers() == Qt.ControlModifier: if event.key() == Qt.Key_Plus or event.key() == Qt.Key_Equal: @@ -954,48 +1191,42 @@ class TimelineWidget(QWidget): else: super().keyPressEvent(event) - def set_data(self, data): - debug_print() - self.data = data - self.total_content_height = (len(TRACK_NAMES) * self.track_height) + self.ruler_height - self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer) + def set_zoom(self, factor): + # Calculate MIN zoom: The zoom required to 
make the video fit the width exactly
+        # (Available Width - Sidebar) / Total Frames
+        if not self.data:
+            return
+        available_w = self.parent().width() - self.label_width if self.parent() else 800
+        min_zoom = available_w / self.data["total_frames"]
+
+        # Clamp: Don't zoom out past the video end, don't zoom in to infinity
+        self.zoom_factor = max(min_zoom, min(factor, 50.0))
         self.update_geometry()
+
     def set_playhead(self, frame):
-        debug_print()
         old_x = self.label_width + (self.current_frame * self.zoom_factor)
         self.current_frame = frame
         new_x = self.label_width + (self.current_frame * self.zoom_factor)
-        self.ensure_playhead_visible()
-        self.update(int(old_x - 5), 0, 10, self.height())
+
+        # Repaint only the playhead areas for performance
+        self.update(int(old_x - 5), 0, 10, self.height())
         self.update(int(new_x - 5), 0, 10, self.height())
+        self.ensure_playhead_visible()
+
     def ensure_playhead_visible(self):
-        debug_print()
-
-        """Auto-scrolls the scroll area if the playhead leaves the viewport."""
-        # Find the QScrollArea parent
         scroll_area = self.parent().parent()
         if not isinstance(scroll_area, QScrollArea):
             return
-
+
         scrollbar = scroll_area.horizontalScrollBar()
         view_width = scroll_area.viewport().width()
-
-        # Playhead position in pixels
         px = self.label_width + int(self.current_frame * self.zoom_factor)
-
-        # Current scroll position
         scroll_x = scrollbar.value()
-
-        # If playhead is beyond the right edge of visible area
-        if px > (scroll_x + view_width):
-            # Shift scroll so playhead is at the left (plus sidebar)
-            scrollbar.setValue(px - self.label_width)
-
-        # If playhead is behind the left edge (e.g. user seeked backwards)
-        elif px < (scroll_x + self.label_width):
-            scrollbar.setValue(px - self.label_width)
+
+        if px > (scroll_x + view_width) or px < (scroll_x + self.label_width):
+            scrollbar.setValue(px - self.label_width - (view_width // 4))
+
+
     def mousePressEvent(self, event):
         debug_print()
@@ -1017,8 +1248,8 @@ class TimelineWidget(QWidget):
         if pos_x < scroll_x + self.label_width:
             relative_y = pos_y - self.ruler_height
             track_idx = int(relative_y // self.track_height)
-            if 0 <= track_idx < len(TRACK_NAMES):
-                name = TRACK_NAMES[track_idx]
+            if 0 <= track_idx < len(self.track_names):
+                name = self.track_names[track_idx]
                 if name in self.hidden_tracks: self.hidden_tracks.remove(name)
                 else: self.hidden_tracks.add(name)
                 self.visibility_changed.emit(self.hidden_tracks)
@@ -1032,8 +1263,8 @@ class TimelineWidget(QWidget):
         # Handle track selection if in the data area
         if pos_y >= self.ruler_height:
             track_idx = int((pos_y - self.ruler_height) // self.track_height)
-            if 0 <= track_idx < len(TRACK_NAMES):
-                self.track_selected.emit(TRACK_NAMES[track_idx])
+            if 0 <= track_idx < len(self.track_names):
+                self.track_selected.emit(self.track_names[track_idx])
                 self.selected_track_idx = track_idx
                 self.update()
         else:
@@ -1043,164 +1274,115 @@
             self.update()
 
     def mouseMoveEvent(self, event):
-        debug_print()
-
-        # This only fires while moving if a button is held down by default
         if self.is_scrubbing:
             self.update_frame_from_mouse(event.position().x())
 
     def mouseReleaseEvent(self, event):
-        debug_print()
-
         if event.button() == Qt.LeftButton:
             self.is_scrubbing = False
 
-    def update_frame_from_mouse(self, x_pos):
-        """Helper to calculate frame and emit the seek signal."""
-        debug_print()
-        relative_x = x_pos - self.label_width
-        frame = int(relative_x / self.zoom_factor)
+    def update_frame_from_mouse(self, x):
+        rel_x = x - self.label_width
+        frame = int(rel_x / self.zoom_factor)
         frame = max(0, min(frame, 
self.data["total_frames"] - 1)) - - # We emit seek_requested so the Video Player and Premiere class sync up self.seek_requested.emit(frame) - + + def paintEvent(self, event): - debug_print() if not self.data: return - - dirty_rect = event.rect() + painter = QPainter(self) + dirty_rect = event.rect() - # 1. Determine current scroll position to keep labels sticky + # 1. Coordinate Setup scroll_area = self.parent().parent() scroll_x = 0 if isinstance(scroll_area, QScrollArea): scroll_x = scroll_area.horizontalScrollBar().value() - w, h = self.width(), self.height() total_f = self.data["total_frames"] fps = self.data.get("fps", 30) offset_y = 20 - # 2. DRAW DATA AREA (Events and Playhead) - # --- 2. DRAW DATA AREA (Muted Patterns + Events + Playhead) --- - sync_off = getattr(self, "sync_offset", 0.0) - sync_fps = getattr(self, "sync_fps", fps) - frame_shift = int(sync_off * sync_fps) - for i, name in enumerate(TRACK_NAMES): + # --- 2. DRAW DATA AREA (Events and Playhead) --- + for i, name in enumerate(self.track_names): y = offset_y + (i * self.track_height) is_hidden = name in self.hidden_tracks - + if y + self.track_height < dirty_rect.top() or y > dirty_rect.bottom(): continue - # A. Draw Muted/Disabled Background Pattern - - if is_hidden: - # Calculate the visible rectangle for this track to the right of the sidebar - mute_rect = QRectF(self.label_width, y, w - self.label_width, self.track_height) - # Fill with a dark "disabled" base - painter.fillRect(mute_rect, QColor(40, 40, 40)) - # Add the Cross-Hatch Pattern - pattern_brush = QBrush(QColor(60, 60, 60, 100), Qt.DiagCrossPattern) - painter.fillRect(mute_rect, pattern_brush) + # Event Blocks + if name in self.data["events"]: + base_color = self.track_colors[i] + for event_item in self.data["events"][name]: + # Map the new data format: [start_f, end_f, type, value] + start_f, end_f = event_item[0], event_item[1] + + if "AI:" in name: + s_start, s_end = start_f, end_f + else: + s_start = start_f - 0 + s_end = end_f - 0 + + x_start = self.label_width + (s_start * self.zoom_factor) + x_end = self.label_width + (s_end * self.zoom_factor) - # B. 
Draw Event Blocks - base_color = TRACK_COLORS[i] - for start_f, end_f, severity, direction in self.data["events"][name]: - # x_start = self.label_width + (start_f * self.zoom_factor) - # x_end = self.label_width + (end_f * self.zoom_factor) + # Performance optimization: skip drawing if off-screen + if x_end < scroll_x or x_start > scroll_x + w: + continue - if "AI:" in name: - # AI predictions are already calculated in video-time, NO SHIFT - shifted_start, shifted_end = start_f, end_f - - else: - shifted_start = start_f - frame_shift - shifted_end = end_f - frame_shift - - x_start = self.label_width + (shifted_start * self.zoom_factor) - x_end = self.label_width + (shifted_end * self.zoom_factor) - - # Performance optimization: skip drawing if off-screen - if x_end < scroll_x or x_start > scroll_x + w: - continue - - if x_end < dirty_rect.left() or x_start > dirty_rect.right(): - continue - - # If hidden, make the event block very faint/transparent - if is_hidden: - color = QColor(120, 120, 120, 40) # Muted Grey - else: - alpha = 80 if severity == "Small" else 160 if severity == "Moderate" else 255 + # Draw block color = QColor(base_color) - color.setAlpha(alpha) + if is_hidden: + color = QColor(120, 120, 120, 40) + + painter.fillRect(QRectF(x_start, y + 2, max(1, x_end - x_start), self.track_height - 4), color) - painter.fillRect(QRectF(x_start, y + 2, max(1, x_end - x_start), self.track_height - 4), color) # Draw Playhead playhead_x = self.label_width + (self.current_frame * self.zoom_factor) painter.setPen(QPen(QColor(255, 0, 0), 2)) - painter.drawLine(playhead_x, 0, playhead_x, h) + painter.drawLine(int(playhead_x), 0, int(playhead_x), h) - # 3. DRAW STICKY SIDEBAR (Pinned to the left edge of the viewport) - # We draw this AFTER the data so it covers the blocks as they scroll past + # # --- 3. DRAW STICKY SIDEBAR (Pinned to the left edge) --- + # # Draw this AFTER the data so it masks anything scrolling under it sidebar_rect = QRect(scroll_x, 0, self.label_width, h) - painter.fillRect(sidebar_rect, QColor(30, 30, 30)) # Solid background + painter.fillRect(sidebar_rect, QColor(30, 30, 30)) # Ruler segment for the sidebar area painter.fillRect(scroll_x, 0, self.label_width, offset_y, QColor(45, 45, 45)) - for i, name in enumerate(TRACK_NAMES): + for i, name in enumerate(self.track_names): y = offset_y + (i * self.track_height) is_hidden = name in self.hidden_tracks - # Grid Line + + # Pinned Grid Line painter.setPen(QColor(60, 60, 60)) painter.drawLine(scroll_x, y, scroll_x + w, y) - # Sticky Label Text - if is_hidden: - # Very dark grey to show it's "OFF" - text_color = QColor(70, 70, 70) - else: - # Bright white/grey to show it's "ON" - text_color = QColor(220, 220, 220) - + # # Pinned Label Text (Anchored to scroll_x) + text_color = QColor(70, 70, 70) if is_hidden else QColor(220, 220, 220) painter.setPen(text_color) painter.setFont(QFont("Arial", 8, QFont.Bold)) painter.drawText(scroll_x + 10, y + 17, name) - # 4. DRAW TIME RULER TICKS (Right of the sticky sidebar) + # --- 4. 
DRAW TIME RULER TICKS --- target_spacing_px = 120 - - # Available units in frames: 1, 5, 15, 30 (1s), 150 (5s), 300 (10s), 1800 (1min) possible_units = [1, 5, 15, 30, 150, 300, 900, 1800] + tick_interval = next((u for u in possible_units if (u * self.zoom_factor) >= target_spacing_px), 1800) - # Find the smallest unit that results in at least target_spacing_px - tick_interval = possible_units[-1] - for unit in possible_units: - if (unit * self.zoom_factor) >= target_spacing_px: - tick_interval = unit - break - - # 2. DRAW BACKGROUNDS - painter.fillRect(0, 0, w, 20, QColor(45, 45, 45)) # Ruler Bar + # Draw Ruler Background + painter.fillRect(0, 0, w, 20, QColor(45, 45, 45)) - # 3. DRAW TICKS AND TIME LABELS painter.setPen(QColor(180, 180, 180)) painter.setFont(QFont("Segoe UI", 7)) - - # Sub-ticks (draw 5 small lines for every 1 major interval) sub_interval = max(1, tick_interval // 5) - # Start loop from 0 to total frames for f in range(0, total_f + 1, sub_interval): x = self.label_width + int(f * self.zoom_factor) - # Optimization: Don't draw if off-screen if x < scroll_x: continue if x > scroll_x + w: break @@ -1220,83 +1402,83 @@ class TimelineWidget(QWidget): time_str = f"{minutes:02d}m:{seconds:02d}s" else: time_str = f"{seconds}s" - painter.drawText(x + 4, 12, time_str) else: - # Minor Tick painter.drawLine(x, 16, x, 20) - def get_ai_extractions(self): - """ - Processes timeline data for AI tracks and specific OBS sync events. - """ - debug_print() - fps = self.data.get("fps", 30.0) - - extraction_data = { - "metadata": { - "fps": fps, - "total_frames": self.data.get("total_frames", 0), - "track_summaries": {} - }, - "obs": {}, # New top-level key for specific OBS events - "ai_tracks": {} - } + self.update_geometry() + + + + + +class TrainingWorker(QThread): + # Signals to communicate back to the UI + finished = Signal(str) # Sends the HTML report back + error = Signal(str) # Sends error messages + + def __init__(self, params): + super().__init__() + self.params = params + + def run(self): + try: + from predictor import GeneralPredictor + predictor = GeneralPredictor() + # This is the heavy calculation and training + report = predictor.calculate_and_train(self.params) + self.finished.emit(report) + except Exception as e: + self.error.emit(str(e)) + + + + +from PySide6.QtCore import QThread, Signal + +class MLInferenceWorker(QThread): + finished = Signal(dict) # Emits the timeline_events dictionary + error = Signal(str) + + def __init__(self, raw_kpts, ml_model, ml_scaler, active_features, behavior_name): + super().__init__() + self.raw_kpts = raw_kpts + self.ml_model = ml_model + self.ml_scaler = ml_scaler + self.active_features = active_features + self.behavior_name = f"AI: {behavior_name}" + + def run(self): + try: + # Import predictor logic inside the thread + from predictor import GeneralPredictor + engine = GeneralPredictor() + engine.active_feature_keys = self.active_features + + # 1. Feature Extraction (The slow part) + X_raw = [] + for frame in self.raw_kpts: + X_raw.append(engine.format_features(frame)) + X = np.array(X_raw) + + # 2. Scaling & Prediction + if self.ml_scaler: + X = self.ml_scaler.transform(X) + + preds = self.ml_model.predict(X) + + # 3. 
Convert to timeline blocks (using your existing converter logic) + # You can either move the converter into GeneralPredictor or call it here + events = engine.convert_to_events(preds, track_name=self.behavior_name) # Ensure engine has this method + + self.finished.emit(events) + except Exception as e: + self.error.emit(str(e)) - if not self.data or "events" not in self.data: - return extraction_data - # 1. Extract the OBS: Time Sync event specifically - sync_key = "OBS: Time Sync" - if sync_key in self.data["events"]: - sync_blocks = self.data["events"][sync_key] - # Convert blocks to a list of dicts for the JSON - extraction_data["obs"][sync_key] = [ - { - "start_frame": b[0], - "end_frame": b[1], - "start_time_sec": round(b[0] / fps, 3), - "end_time_sec": round(b[1] / fps, 3) - } for b in sync_blocks - ] - # 2. Process AI Tracks - for track_name, blocks in self.data["events"].items(): - if track_name.startswith("AI:"): - track_results = [] - track_total = 0 - track_long = 0 - - for block in blocks: - start_f, end_f = block[0], block[1] - severity = block[2] if len(block) > 2 else "Normal" - - start_sec = round(start_f / fps, 3) - end_sec = round(end_f / fps, 3) - duration = round(end_sec - start_sec, 3) - - track_total += 1 - if duration > 2.0: - track_long += 1 - - track_results.append({ - "start_frame": int(start_f), - "end_frame": int(end_f), - "start_time_sec": start_sec, - "end_time_sec": end_sec, - "duration_sec": duration, - "severity": severity - }) - - extraction_data["ai_tracks"][track_name] = track_results - extraction_data["metadata"]["track_summaries"][track_name] = { - "event_count": track_total, - "long_events_over_2s": track_long, - "total_duration_sec": round(sum(r["duration_sec"] for r in track_results), 3) - } - return extraction_data class SkeletonOverlay(QWidget): @@ -1388,7 +1570,9 @@ class SkeletonOverlay(QWidget): # --- 2. DRAW BASELINE (Only if not hidden) --- if "Baseline" not in self.hidden_tracks: idx_l_hip, idx_r_hip = self.KP_MAP["Lhip"], self.KP_MAP["Rhip"] - pelvis_live = (kp_live[idx_l_hip] + kp_live[idx_r_hip]) / 2 + p_left = kp_live[idx_l_hip][:2] + p_right = kp_live[idx_r_hip][:2] + pelvis_live = (p_left + p_right) / 2 kp_baseline = self.data['baseline_kp_mean'] + pelvis_live painter.setPen(QPen(QColor(200, 200, 200, 200), 2, Qt.DashLine)) @@ -1456,6 +1640,19 @@ class SkeletonOverlay(QWidget): painter.drawEllipse(pt, 4, 4) + + + + + + + + + + + + + class VideoView(QGraphicsView): resized = Signal() @@ -1501,18 +1698,60 @@ class PremiereWindow(QMainWindow): QMainWindow, QWidget#centralWidget { background-color: #1e1e1e; } QLabel, QStatusBar, QMenuBar { color: #ffffff; } QTabWidget::pane { border: 1px solid #333333; background: #1e1e1e; } - QTabBar::tab { background: #2b2b2b; color: #aaa; padding: 8px 15px; border: 1px solid #333; border-bottom: none; border-top-left-radius: 4px; border-top-right-radius: 4px; } - QTabBar::tab:selected { background: #3d3d3d; color: #fff; font-weight: bold; } - QTabBar::tab:hover { background: #444; } - """) + QTabBar::tab { + background: #2b2b2b; + color: #aaa; + padding: 8px 15px; + border: 1px solid #333; + border-bottom: none; + border-top-left-radius: 4px; + border-top-right-radius: 4px; + } + + /* 2. Content Tabs: Selected & Hover States */ + /* We add :!last to ensure these NEVER apply to the + tab */ + QTabBar::tab:selected:!last { + background: #3d3d3d; + color: #fff; + font-weight: bold; + } + + QTabBar::tab:hover:!last { + background: #444; + } + + /* 3. 
THE PLUS TAB: Constant State */ + /* We define all states (normal, selected, hover) to be identical */ + QTabBar::tab:last, + QTabBar::tab:last:hover, + QTabBar::tab:last:selected { + background: #1e1e1e; + color: #00aaff; + font-weight: bold; + margin-left: 2px; + border: 1px solid #333; + padding: 4px 12px; + }""") # --- Tab System --- self.tabs = QTabWidget() self.tabs.setTabsClosable(True) self.tabs.tabCloseRequested.connect(self.close_tab) + self.setCentralWidget(self.tabs) self.create_welcome_tab() + + self.tabs.addTab(QWidget(), "+") + + # 2. Disable the close button on the "+" tab specifically + # (Assuming index 0 was your first tab, the + is now at the last index) + plus_idx = self.tabs.count() - 1 + self.tabs.tabBar().setTabButton(plus_idx, QTabBar.ButtonPosition.RightSide, None) + + # 3. Connect to the click event + self.tabs.tabBar().installEventFilter(self) + self.create_menu_bar() # Update checks @@ -1523,9 +1762,46 @@ class PremiereWindow(QMainWindow): # Window instances self.load_window = None + self.train_window = None self.about = None self.help = None + + def eventFilter(self, obj, event): + # Check if the event is a mouse press on the TabBar + if obj == self.tabs.tabBar() and event.type() == QEvent.MouseButtonPress: + # Map the click position to which tab index was hit + index = self.tabs.tabBar().tabAt(event.pos()) + + # If they clicked the "+" tab + if self.tabs.tabText(index) == "+": + # Show the menu + self.show_new_tab_menu() + # Return True to CONSUME the event. + # This prevents QTabWidget from ever seeing the click and switching tabs. + return True + + return super().eventFilter(obj, event) + + + def show_new_tab_menu(self): + from PySide6.QtWidgets import QMenu + from PySide6.QtGui import QAction, QCursor + menu = QMenu(self) + + load_act = QAction("Load Video", self) + load_act.triggered.connect(self.open_load_video_dialog) + + train_act = QAction("Train Model", self) + train_act.triggered.connect(self.open_train_model_dialog) + + menu.addAction(load_act) + menu.addAction(train_act) + + # Show the menu right under the mouse cursor + menu.exec(QCursor.pos()) + + def create_welcome_tab(self): welcome_widget = QWidget() layout = QVBoxLayout(welcome_widget) @@ -1544,8 +1820,7 @@ class PremiereWindow(QMainWindow): layout.addStretch() self.tabs.addTab(welcome_widget, "Welcome") - # Disable the close button on the welcome tab - self.tabs.tabBar().setTabButton(0, QTabBar.ButtonPosition.RightSide, None) + def create_menu_bar(self): menu_bar = self.menuBar() @@ -1590,22 +1865,80 @@ class PremiereWindow(QMainWindow): self.load_window.btn_confirm.clicked.connect(self.handle_new_video_session) self.load_window.show() - def handle_new_video_session(self): - # Extract properties from the OpenFileWindow before closing it - video_path = self.load_window.video_path - obs_file = self.load_window.obs_file - offset = self.load_window.current_video_offset - # ... grab any other needed parameters (like selected ML model, etc.) + def open_train_model_dialog(self): + if self.train_window is None or not self.train_window.isVisible(): + self.train_window = TrainModelWindow(self) + # Connect the initialization button from OpenFileWindow to our tab creator + self.train_window.btn_train.clicked.connect(self.handle_start_training) + self.train_window.show() + + + def handle_start_training(self): + params = self.train_window.get_selection() + # 1. 
Use QProgressDialog instead of QMessageBox + # It's designed to stay on top and handle background tasks + self.loading_dialog = QProgressDialog("Processing data and training Random Forest...", None, 0, 0, self) + self.loading_dialog.setWindowTitle("Training Model") + self.loading_dialog.setWindowModality(Qt.WindowModal) + self.loading_dialog.setCancelButton(None) # Remove cancel button to prevent interruption + self.loading_dialog.setMinimumDuration(0) # Show immediately + self.loading_dialog.show() + + # 2. Setup the Worker Thread + self.training_thread = TrainingWorker(params) + self.training_thread.finished.connect(self.on_training_finished) + self.training_thread.error.connect(self.on_training_error) + + # Clean up the thread object when it's done to prevent memory leaks + self.training_thread.finished.connect(self.training_thread.deleteLater) + + # 3. Start thread + self.training_thread.start() + + # Close the selection window + self.train_window.close() + + def on_training_finished(self, report_html): + # Using reset() on QProgressDialog automatically closes it and cleans up + if self.loading_dialog: + self.loading_dialog.reset() + self.loading_dialog = None + + self.display_ml_results(report_html) + + def on_training_error(self, error_msg): + if self.loading_dialog: + self.loading_dialog.reset() + self.loading_dialog = None + + QMessageBox.critical(self, "Training Error", f"An error occurred: {error_msg}") + + def display_ml_results(self, report): + """Displays the RF performance report in a simple popup.""" + msg = QMessageBox(self) + msg.setWindowTitle("Training Results") + msg.setTextFormat(Qt.RichText) + msg.setText(report) + msg.exec() + + def handle_new_video_session(self): + config = self.load_window.get_config() + + # 2. Close the selection window self.load_window.close() - # Create a new, independent tab - new_tab = VideoAnalysisTab(video_path, obs_file, offset) + # 3. Create the new tab with the config dictionary + # We pass the config so the tab knows whether to run AI or just play video + new_tab = VideoAnalysisTab(config) - # Add to TabWidget and switch to it - tab_name = os.path.basename(video_path) - index = self.tabs.addTab(new_tab, tab_name) - self.tabs.setCurrentIndex(index) + # 4. Handle Tab Placement (Keep '+' at the end) + tab_name = os.path.basename(config['video_path']) + plus_idx = self.tabs.count() - 1 + new_idx = self.tabs.insertTab(plus_idx, new_tab, tab_name) + + # 5. 
Switch to it + self.tabs.setCurrentIndex(new_idx) def close_tab(self, index): # Prevent closing the Welcome tab if it's the only one left @@ -1645,26 +1978,72 @@ class PremiereWindow(QMainWindow): -from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QSplitter, - QScrollArea, QPushButton, QComboBox, QLabel, - QSlider, QTextEdit, QMessageBox) -from PySide6.QtCore import Qt, QSizeF, QUrl, Slot -from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput -from PySide6.QtMultimediaWidgets import QGraphicsVideoItem - class VideoAnalysisTab(QWidget): - def __init__(self, video_path, obs_file=None, offset=0.0, parent=None): - super().__init__(parent) + def __init__(self, config): + super().__init__() # State - self.file_path = video_path - self.obs_file = obs_file - self.current_video_offset = offset - self.predictor = GeneralPredictor() - self.data = None # Will be populated by worker + self.config = config + + self.setStyleSheet(""" + QMainWindow, QWidget#centralWidget { + background-color: #1e1e1e; + } + QLabel, QStatusBar, QMenuBar { + color: #ffffff; + } + /* Target the Timeline specifically */ + TimelineWidget { + background-color: #1e1e1e; + border: 1px solid #333333; + } + /* Button styling with Grey borders */ + QDialog, QMessageBox, QFileDialog { + background-color: #2b2b2b; + } + QDialog QLabel, QMessageBox QLabel { + color: #ffffff; + } + QPushButton { + background-color: #2b2b2b; + color: #ffffff; + border: 1px solid #555555; /* Subtle Grey border */ + border-radius: 3px; + padding: 4px; + } + QPushButton:hover { + background-color: #3d3d3d; + border-color: #888888; /* Brightens border on hover */ + } + QPushButton:pressed { + background-color: #111111; + } + QPushButton:disabled { + border-color: #333333; + color: #444444; + } + /* Splitter/Divider styling */ + QSplitter::handle { + background-color: #333333; /* Dark grey dividers */ + } + QSplitter::handle:horizontal { + width: 2px; + } + QSplitter::handle:vertical { + height: 2px; + } + /* ScrollArea styling to keep it dark */ + QScrollArea, QScrollArea > QWidget > QWidget { + background-color: #1e1e1e; + border: none; + } + """) + self.setup_ui() - #self.reprocess_current_video() + + self.initialize_session() + def setup_ui(self): main_layout = QVBoxLayout(self) @@ -1685,7 +2064,7 @@ class VideoAnalysisTab(QWidget): self.scene.addItem(self.video_item) # Overlay initialization - self.skeleton_overlay = SkeletonOverlay(self.view.viewport()) + #self.skeleton_overlay = SkeletonOverlay(self.view.viewport()) self.player = QMediaPlayer() self.audio_output = QAudioOutput() @@ -1699,48 +2078,6 @@ class VideoAnalysisTab(QWidget): controls_container = QWidget() stacked_controls = QVBoxLayout(controls_container) stacked_controls.setSpacing(5) # Tight spacing between rows - - ml_row = QHBoxLayout() - ml_row.addStretch() - - ml_row.addWidget(QLabel("ML Model:")) - self.ml_dropdown = QComboBox() - self.ml_dropdown.addItems(["Random Forest", "LSTM", "XGBoost", "SVM", "1D-CNN"]) - ml_row.addWidget(self.ml_dropdown) - - ml_row.addWidget(QLabel("Target:")) - self.target_dropdown = QComboBox() - self.target_dropdown.addItems(["Mouthing", "Head Movement", "Kick (Left)", "Kick (Right)", "Reach (Left)", "Reach (Right)"]) - #self.target_dropdown.currentTextChanged.connect(self.update_predictor_target) - ml_row.addWidget(self.target_dropdown) - - self.btn_add_to_pool = QPushButton("Add to Pool") - #self.btn_add_to_pool.clicked.connect(self.add_current_to_ml_pool) - self.btn_add_to_pool.setFixedWidth(120) - 
ml_row.addWidget(self.btn_add_to_pool) - - self.btn_train_final = QPushButton("Train Global Model") - self.btn_train_final.setStyleSheet("background-color: #2e7d32; font-weight: bold;") - #self.btn_train_final.clicked.connect(self.run_final_training) - ml_row.addWidget(self.btn_train_final) - - self.lbl_pool_status = QLabel("Pool: 0 Participants") - self.lbl_pool_status.setStyleSheet("color: #00FF00; font-weight: bold; margin-left: 10px;") - self.lbl_pool_status.setFixedWidth(160) - ml_row.addWidget(self.lbl_pool_status) - - self.btn_clear_pool = QPushButton("Clear Pool") - self.btn_clear_pool.setFixedWidth(100) - self.btn_clear_pool.setStyleSheet("color: #ff5555; border: 1px solid #ff5555;") - #self.btn_clear_pool.clicked.connect(self.clear_ml_pool) - ml_row.addWidget(self.btn_clear_pool) - - self.btn_extract_ai = QPushButton("Extract AI Data") - #self.btn_extract_ai.clicked.connect(self.extract_ai_to_json) - ml_row.addWidget(self.btn_extract_ai) - - - ml_row.addStretch() # --- ROW 2: Playback & Transport --- playback_row = QHBoxLayout() @@ -1791,26 +2128,347 @@ class VideoAnalysisTab(QWidget): playback_row.addStretch() # --- Add Rows to Stack --- - stacked_controls.addLayout(ml_row) stacked_controls.addLayout(playback_row) video_layout.addWidget(controls_container) + + info_container = QWidget() + info_layout = QVBoxLayout(info_container) + + self.progress_container = QWidget() + progress_layout = QVBoxLayout(self.progress_container) + + self.lbl_analysis_status = QLabel("Pose Analysis: Idle") + self.analysis_bar = QProgressBar() + self.analysis_bar.setRange(0, 100) + self.analysis_bar.setValue(0) + self.analysis_bar.setStyleSheet(""" + QProgressBar { border: 1px solid #555; border-radius: 2px; text-align: center; height: 15px; } + QProgressBar::chunk { background-color: #00aaff; } + """) + + progress_layout.addWidget(self.lbl_analysis_status) + progress_layout.addWidget(self.analysis_bar) + + # Insert into info_layout (above the inspector scroll area) + info_layout.insertWidget(0, self.progress_container) + + # NEW: Wrap the info_label in a Scroll Area + self.inspector_scroll = QScrollArea() + self.inspector_scroll.setWidgetResizable(True) - # --- Inspector & Timeline --- self.info_label = QTextEdit() self.info_label.setReadOnly(True) + self.info_label.setStyleSheet("padding: 8px; font-family: 'Consolas', 'Segoe UI'; color: #ffffff;") + self.inspector_scroll.setWidget(self.info_label) + + # NEW: Export Button for Metrics + self.btn_export_metrics = QPushButton("Export Metrics to JSON") + self.btn_export_metrics.clicked.connect(self.export_behavior_metrics) + self.btn_export_metrics.setEnabled(False) # Enable only after load + + info_layout.addWidget(self.inspector_scroll) + info_layout.addWidget(self.btn_export_metrics) + + top_splitter.addWidget(video_container) + top_splitter.addWidget(info_container) + top_splitter.setSizes([800, 400]) self.timeline = TimelineWidget() self.timeline.seek_requested.connect(self.seek_video) - - top_splitter.addWidget(video_container) - top_splitter.addWidget(self.info_label) + + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setWidget(self.timeline) self.main_splitter.addWidget(top_splitter) - self.main_splitter.addWidget(self.timeline) + self.main_splitter.addWidget(scroll_area) + self.main_splitter.setSizes([500, 400]) main_layout.addWidget(self.main_splitter) + + + self.skeleton_overlay = SkeletonOverlay(self.view) + self.skeleton_overlay.resize(self.view.size()) + self.skeleton_overlay.hide() + + # 2. 
FIX: Watch the view for resizes + self.view.installEventFilter(self) + + self.player.positionChanged.connect(self.update_timeline_playhead) self.setup_transport() # Start with empty workspace until worker finishes + # self.load_boris_to_timeline() + self.video_item.nativeSizeChanged.connect(self.update_video_geometry) + self.start_analysis() + + + + def start_analysis(self): + if not self.config.get("use_pose", True): + self.lbl_analysis_status.setText("Pose Analysis: Bypassed") + self.analysis_bar.setValue(100) + self.skeleton_overlay.hide() + return + + # 1. Setup Queues + self.prog_q = Queue() + self.res_q = Queue() + + # 2. Create the Process + self.analysis_proc = Process( + target=run_pose_analysis, + args=(self.config['video_path'], self.prog_q, self.res_q, self.config), + name="PoseWorkerProcess" + ) + + # 3. UI Updates + self.lbl_analysis_status.setText("Process Started...") + self.analysis_bar.setValue(0) + self.analysis_bar.show() + + # 4. Start + self.analysis_proc.start() + + # 5. Timer to check the Queue + self.poll_timer = QTimer() + self.poll_timer.timeout.connect(self.check_queues) + self.poll_timer.start(100) + + + def check_queues(self): + # Drain progress queue + while not self.prog_q.empty(): + val = self.prog_q.get() + self.analysis_bar.setValue(val) + self.lbl_analysis_status.setText(f"Extracting Poses: {val}%") + + # Check result queue + if not self.res_q.empty(): + data = self.res_q.get() + self.poll_timer.stop() + self.handle_finished_data(data) + + + def handle_finished_data(self, data): + """ + Runs on the main thread. Loads BORIS instantly, shows the skeleton, + and kicks off AI in the background. + """ + # 1. Unpack basic data + self.raw_kpts = data["raw_kpts"] + self.fps = data.get("fps", 30.0) + self.total_frames = data["total_frames"] + v_w, v_h = data["dims"] + + # 2. Setup master dictionary + self.processed_data = {} + + # 3. IMMEDIATE: Load BORIS (if available) + if self.config.get('use_boris'): + self.load_boris_to_timeline() + + # 4. IMMEDIATE: Show Skeleton Overlay + # We calculate the baseline mean here because it's fast (Numpy) + raw_kps_per_frame = [frame[:, :2] for frame in self.raw_kpts] + valid_mask = [np.any(kp) for kp in raw_kps_per_frame] + valid_data = [raw_kps_per_frame[i] for i, v in enumerate(valid_mask) if v] + + baseline_mean = np.mean(valid_data, axis=0) if valid_data else np.zeros((17, 2)) + + overlay_payload = { + "raw_kps": self.raw_kpts, + "width": v_w, + "height": v_h, + "events": self.processed_data, # Initially just BORIS + "baseline_kp_mean": baseline_mean + } + self.skeleton_overlay.set_data(overlay_payload) + self.skeleton_overlay.show() + + # 5. SYNC UI (Show the manual timeline immediately) + self.sync_timeline_to_ui() + + # 6. BACKGROUND: Start AI Inference Thread + if self.config.get('use_pkl') and hasattr(self, 'ml_model'): + self.lbl_analysis_status.setText("Running AI Inference...") + + # Start the worker thread + self.ml_worker = MLInferenceWorker( + self.raw_kpts, + self.ml_model, + self.ml_scaler, + self.active_features, + self.ml_metadata.get('target_behavior', 'Reach') + ) + self.ml_worker.finished.connect(self.on_ai_inference_complete) + self.ml_worker.error.connect(lambda e: print(f"AI ERROR: {e}")) + self.ml_worker.start() + else: + self.lbl_analysis_status.setText("Analysis Complete") + + + + def on_ai_inference_complete(self, ai_events): + """Runs when the thread finishes. Merges AI into the existing UI.""" + # 1. 
Merge AI tracks into the dictionary that already has BORIS + for track_name, blocks in ai_events.items(): + self.processed_data[track_name] = blocks + print(f"AI Thread complete: Injected {track_name}") + + # 2. Update the Timeline Widget visually + self.sync_timeline_to_ui() + + # 3. Update the Skeleton Overlay specifically + # This ensures the AI blocks show up in the video player's seek bar/overlay + if hasattr(self, 'skeleton_overlay'): + # Fetch the old data dict, update the 'events' key, and push it back + updated_payload = self.skeleton_overlay.data.copy() + updated_payload['events'] = self.processed_data + self.skeleton_overlay.set_data(updated_payload) + + self.lbl_analysis_status.setText("All Tracks (BORIS + AI) Loaded") + + + + def load_boris_to_timeline(self): + """ + Parses the JSON to identify unique behaviors and prepare the data. + """ + if not self.config.get('use_boris', False) or not self.config.get('obs_file'): + return + + try: + with open(self.config['obs_file'], 'r') as f: + data = json.load(f) + + behav_conf = data.get("behaviors_conf", {}) + type_lookup = {v['code']: v['type'] for v in behav_conf.values()} + + session_key = self.config.get('session_key') + session = data.get("observations", {}).get(session_key, {}) + events = session.get("events", []) + + unique_behaviors = sorted(list(set(e[2] for e in events))) + new_boris_data = {name: [] for name in unique_behaviors} + + fps = self.config.get('fps', 30.0) + # Use the offset from the config + offset_frames = int(self.config.get('offset', 0.0) * fps) + + state_tracker = {} + + for e in events: + timestamp = float(e[0]) + behavior_name = e[2] + event_type = type_lookup.get(behavior_name, "Point event") + current_frame = int(timestamp * fps) - offset_frames + + if event_type == "Point event": + new_boris_data[behavior_name].append([current_frame, current_frame + 1, "Normal", "N/A"]) + elif event_type == "State event": + if behavior_name in state_tracker: + start_f = state_tracker.pop(behavior_name) + new_boris_data[behavior_name].append([start_f, current_frame, "Normal", "N/A"]) + else: + state_tracker[behavior_name] = current_frame + + if not hasattr(self, 'processed_data') or self.processed_data is None: + self.processed_data = {} + + for behavior, blocks in new_boris_data.items(): + self.processed_data[behavior] = blocks + + print(f"[INFO] Merged {len(new_boris_data)} BORIS tracks into timeline.") + + except Exception as e: + print(f"Error loading BORIS data: {e}") + + + def sync_timeline_to_ui(self): + """Final push of all merged data (BORIS + AI) to the UI components.""" + if not hasattr(self, 'processed_data'): + return + + total_f = self.config.get('total_frames', self.total_frames) + fps = self.config.get('fps', self.fps) + + # 1. Update the Timeline Widget + self.timeline.set_data(self.processed_data, total_f, fps) + + # 2. Update Stats & Export button + self.update_inspector_stats(self.processed_data, fps) + self.btn_export_metrics.setEnabled(True) + + print(f"DEBUG: UI Synced with {len(self.processed_data)} total tracks.") + + + def update_inspector_stats(self, data, fps): + """Calculates and displays behavior summary on the right panel.""" + stats_text = "SESSION SUMMARY
<br><br>" + stats_text += f"File: {self.config['session_key']}<br>" + stats_text += "-----------------------------------<br>" + + for behavior, instances in data.items(): + durations = [(end - start) / fps for start, end, _, _ in instances] + count = len(instances) + avg_dur = np.mean(durations) if durations else 0 + total_dur = sum(durations) + + stats_text += f"{behavior}:<br>" + stats_text += f" - Occurrences: {count}<br>" + stats_text += f" - Avg Duration: {avg_dur:.3f}s<br>" + stats_text += f" - Total Time: {total_dur:.2f}s<br><br>
" + + self.info_label.setHtml(stats_text) + + + def export_behavior_metrics(self): + """Exports the processed behavior metrics to a new JSON file.""" + if not self.processed_data: + return + + export_payload = { + "metadata": { + "source_video": os.path.basename(self.config.get('video_path', 'unknown')), + "session": self.config.get('session_key', 'unknown'), + "export_timestamp": datetime.now().isoformat(), + "fps": self.config.get('fps', 30.0) + }, + "behaviors": {} + } + + for behavior, instances in self.processed_data.items(): + behavior_events = [] + + for start_f, end_f, _, _ in instances: + behavior_events.append({ + "start_frame": int(start_f), + "duration_frames": int(end_f - start_f) + }) + + export_payload["behaviors"][behavior] = behavior_events + + video_path = self.config.get('video_path') + if not video_path: + return + + base_path = video_path.rsplit('.', 1)[0] + export_path = f"{base_path}_metrics.json" + + # 3. SILENT SAVE + try: + with open(export_path, 'w') as f: + json.dump(export_payload, f, indent=4) + + # Log the success so the researcher knows where it went + msg = f"
<br>[EXPORT COMPLETE]: {os.path.basename(export_path)}" + self.info_label.append(msg) + print(f"Metrics saved to: {export_path}") + + except Exception as e: + self.info_label.append(f"<br>
[EXPORT FAILED]: {e}") + + def setup_transport(self): """Sets up player controls that don't depend on skeleton analysis.""" @@ -1829,29 +2487,170 @@ class VideoAnalysisTab(QWidget): self.btn_prev.clicked.connect(lambda: self.step_frame(-1)) self.btn_next.clicked.connect(lambda: self.step_frame(1)) - self.player.setSource(QUrl.fromLocalFile(self.file_path)) + self.player.setSource(QUrl.fromLocalFile(self.config['video_path'])) + self.player.pause() + # self.player.mediaStatusChanged.connect(self.initial_resize_hack) + + # def initial_resize_hack(self, status): + # # Once the media is loaded, refresh the layout + # if status >= QMediaPlayer.MediaStatus.LoadedMedia: + # self.update_video_geometry() + # # Seek to 0 just to be absolutely sure the buffer updates + # self.player.setPosition(0) + + + def initialize_session(self): + print(f"--- Initializing Session Components ---") - @Slot(dict) - def setup_workspace(self, analyzed_data): - """Only handles skeleton/analysis-specific data.""" - self.data = analyzed_data + # Component A: Pose Inference (Skeleton extraction) + if self.config.get('use_pose', False): + # This is already handled by self.start_analysis() in your __init__ + pass + + # Component B: BORIS Annotation Track + if self.config.get('use_boris', False): + print("[INFO] Loading BORIS annotation track...") + self.load_boris_to_timeline() - # Update timeline and overlay now that we have data - if hasattr(self, 'timeline'): - self.timeline.set_data(self.data) + # Component C: ML Prediction Track (.pkl) + if self.config.get('use_pkl', False): + print(f"[INFO] Loading ML Model: {os.path.basename(self.config['pkl_path'])}") + self.load_pretrained_classifier() + + def load_pretrained_classifier(self): + """Loads the .pkl model and automatically hunts for its scaler.""" + if not self.config.get('use_pkl') or not self.config.get('pkl_path'): + return + + model_path = self.config['pkl_path'] + + metadata_path = model_path.replace(".pkl", "_metadata.json") + if os.path.exists(metadata_path): + with open(metadata_path, 'r') as f: + self.ml_metadata = json.load(f) + self.active_features = self.ml_metadata.get("feature_keys", []) + print(f"[INFO] Feature map loaded: {len(self.active_features)} features.") + else: + raise Exception - if hasattr(self, 'skeleton_overlay'): - self.skeleton_overlay.set_data(self.data) + try: + # 1. Load the primary model + self.ml_model = joblib.load(model_path) + msg = f"[INFO] ML Model loaded: {os.path.basename(model_path)}" + print(msg) + self.update_status(f"{msg}") + + # 2. Auto-discover the Scaler + base_name = os.path.splitext(model_path)[0] + possible_scaler_paths = [ + f"{base_name}_scaler.pkl", + os.path.join(os.path.dirname(model_path), "scaler.pkl") + ] + + self.ml_scaler = None + for spath in possible_scaler_paths: + if os.path.exists(spath): + self.ml_scaler = joblib.load(spath) + s_msg = f"[INFO] Associated scaler auto-loaded: {os.path.basename(spath)}" + print(s_msg) + self.update_status(f"{s_msg}") + break + + if not self.ml_scaler: + print("[WARNING] No associated scaler found. Proceeding without scaling.") + + except Exception as e: + err = f"[ERROR] Failed to load ML Model or Scaler: {e}" + print(err) + self.update_status(f"{err}") - # Trigger a geometry refresh now that we have accurate video dims from data - self.update_video_geometry() - print(f"Analysis complete for {self.file_path}") + + def run_ml_inference(self, raw_kpts): + """ + Applies scaler, runs inference, and converts frame-by-frame + predictions into contiguous timeline blocks. 
+ """ + if not hasattr(self, 'ml_model') or not self.active_features: + return {} + + # 1. Create the engine and tell it WHICH features to care about + engine = GeneralPredictor() + engine.active_feature_keys = self.active_features + + # 2. Extract 13 features for every frame + X_raw = [] + for frame_idx in range(len(raw_kpts)): + # format_features now returns only the 13 needed values + feat_vector = engine.format_features(raw_kpts[frame_idx]) + X_raw.append(feat_vector) + + X = np.array(X_raw) # Resulting shape: (Frames, 13) + + # 3. Predict + if self.ml_scaler: + X = self.ml_scaler.transform(X) + + preds = self.ml_model.predict(X) + unique, counts = np.unique(preds, return_counts=True) + print(f"DEBUG ML Results: {dict(zip(unique, counts))}") + return self._convert_predictions_to_tracks(preds) + + + + def _convert_predictions_to_tracks(self, predictions): + """Converts an array of class labels into start/stop timeline blocks.""" + events = {} + current_class = None + start_frame = 0 + + # Define labels that mean "nothing is happening" + background_labels = [0, "0", "Idle", "None", None, "Background"] + + for i, pred_class in enumerate(predictions): + if pred_class != current_class: + # Close the previous active block + if current_class not in background_labels: + track_name = f"🤖 AI: {current_class}" + if track_name not in events: + events[track_name] = [] + # Format: [start_frame, end_frame, label, notes] + events[track_name].append([start_frame, i, "Normal", "ML Prediction"]) + + # Start new block + current_class = pred_class + start_frame = i + + # Close the final block if the video ends while an action is active + if current_class not in background_labels: + track_name = f"🤖 AI: {current_class}" + if track_name not in events: + events[track_name] = [] + events[track_name].append([start_frame, len(predictions), "Normal", "ML Prediction"]) + + return events + + + def load_boris_annotations(self): + """Logic to parse the JSON for the specific session/slot.""" + try: + with open(self.config['obs_file'], 'r') as f: + data = json.load(f) + + session = data.get("observations", {}).get(self.config['session_key'], {}) + # Extract events for the specific slot + events = session.get("events", []) + # ... Filter events where slot matches config['slot'] ... 
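+ # Note: BORIS event rows are positional lists; as in load_boris_to_timeline above, e[0] holds the timestamp in seconds and e[2] the behavior code.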
+ print(f"Loaded {len(events)} events from BORIS.") + except Exception as e: + print(f"Failed to load BORIS data: {e}") + def update_status(self, message): """Updates the inspector or a status bar with worker progress.""" self.info_label.append(message) + def toggle_playback(self): if self.player.playbackState() == QMediaPlayer.PlayingState: self.player.pause() @@ -1860,12 +2659,64 @@ class VideoAnalysisTab(QWidget): self.player.play() self.btn_play.setText("Pause") - def seek_video(self, ms): + + def update_timeline_playhead(self, position_ms): + #debug_print() + fps = self.config.get('fps', 30.0) + total_f = self.config.get('total_frames', 0) + + # Current frame calculation + current_f = int((position_ms / 1000.0) * fps) + + # --- PREVENT BLACK FRAME AT END --- + # If we are within 1 frame of the end, stop and lock to the last valid frame + if hasattr(self, 'skeleton_overlay') and self.skeleton_overlay.isVisible(): + self.skeleton_overlay.set_frame(current_f) + + if current_f >= total_f - 1: + if self.player.playbackState() == QMediaPlayer.PlayingState: + self.player.pause() + self.btn_play.setText("Play") + current_f = total_f - 1 + # Seek slightly back from total duration to keep the image visible + last_valid_ms = int(((total_f - 1) / fps) * 1000) + self.player.setPosition(last_valid_ms) + + # Sync UI + self.timeline.set_playhead(current_f) + self.update_counters(current_f) + + def seek_video(self, frame): + # Use the config or timeline data instead of self.data + fps = self.config.get('fps', 30.0) + total_f = self.config.get('total_frames', 0) + + target_frame = max(0, min(frame, total_f - 1)) + + # Convert frame to milliseconds for QMediaPlayer + ms = int((target_frame / fps) * 1000) self.player.setPosition(ms) + # Sync the UI + self.timeline.set_playhead(target_frame) + self.update_counters(target_frame) + + + def update_counters(self, current_f): + #debug_print() + + # Dedicated method to refresh the labels + fps = self.config.get('fps', 30.0) + total_f = self.config.get('total_frames', 0) + + cur_s, tot_s = int(current_f / fps), int(total_f / fps) + self.lbl_time_counter.setText(f"Time: {cur_s//60:02d}:{cur_s%60:02d} / {tot_s//60:02d}:{tot_s%60:02d}") + self.lbl_frame_counter.setText(f"Frame: {current_f} / {total_f-1}") + + def step_frame(self, delta): # Fallback to 30 FPS if worker data isn't ready - fps = self.data["fps"] if (hasattr(self, 'data') and self.data) else 30.0 + fps = self.config.get('fps', 30.0) current_ms = self.player.position() # One frame in ms = 1000 / fps @@ -1876,6 +2727,7 @@ class VideoAnalysisTab(QWidget): target_ms = max(0, min(target_ms, self.player.duration())) self.player.setPosition(target_ms) + def update_volume(self, value): # QAudioOutput expects a float between 0.0 and 1.0 volume = value / 100.0 @@ -1886,6 +2738,7 @@ class VideoAnalysisTab(QWidget): self.btn_mute.setChecked(False) self.toggle_mute() + def toggle_mute(self): is_muted = self.btn_mute.isChecked() self.audio_output.setMuted(is_muted) @@ -1893,6 +2746,7 @@ class VideoAnalysisTab(QWidget): # Optional: Dim the slider when muted self.sld_volume.setEnabled(not is_muted) + def update_video_geometry(self): if not hasattr(self, "video_item"): return @@ -1935,13 +2789,21 @@ class VideoAnalysisTab(QWidget): if hasattr(self, "skeleton_overlay"): self.skeleton_overlay.setGeometry(int(x_off), int(y_off), target_w, target_h) + def resizeEvent(self, event): # debug_print() super().resizeEvent(event) self.update_video_geometry() if hasattr(self, 'timeline'): - 
self.timeline.set_zoom(self.timeline.zoom_factor) + self.timeline.update_geometry() + + def eventFilter(self, source, event): + """Keeps the skeleton aligned with the video frame size.""" + if source == getattr(self, 'video_preview_label', None) and event.type() == QEvent.Resize: + self.skeleton_overlay.resize(event.size()) + return super().eventFilter(source, event) + def cleanup(self): if self.player: @@ -2065,6 +2927,132 @@ if __name__ == "__main__": # Only run GUI in the main process if current_process().name == 'MainProcess': app = QApplication(sys.argv) + + style = """ + + /* 1. General App Backgrounds */ + QMainWindow, QWidget#centralWidget, QDialog, QMessageBox, QFileDialog { + background-color: #1e1e1e; + color: #ffffff; + } + + QLabel, QStatusBar, QMenuBar { + color: #ffffff; + } + + /* 2. THE BIG FIX: Removing white backgrounds from inputs */ + QListWidget, QComboBox, QLineEdit, QSpinBox, QTextEdit { + background-color: #2b2b2b; + color: #ffffff; + border: 1px solid #555555; + border-radius: 3px; + padding: 2px; + } + + /* Fix for the dropdown list of a QComboBox */ + QComboBox QAbstractItemView { + background-color: #2b2b2b; + color: #ffffff; + selection-background-color: #00aaff; + selection-color: #ffffff; + outline: none; + border: 1px solid #555555; + } + + /* 3. Tab Navigation Styling */ + QTabWidget::pane { + border: 1px solid #333333; + background: #1e1e1e; + } + + QTabBar::tab { + background: #2b2b2b; + color: #aaa; + padding: 8px 15px; + border: 1px solid #333; + border-bottom: none; + border-top-left-radius: 4px; + border-top-right-radius: 4px; + } + + QTabBar::tab:selected:!last { + background: #3d3d3d; + color: #fff; + font-weight: bold; + } + + QTabBar::tab:hover:!last { + background: #444; + } + + /* THE PLUS TAB: Constant State */ + QTabBar::tab:last, QTabBar::tab:last:hover, QTabBar::tab:last:selected { + background: #1e1e1e; + color: #00aaff; + font-weight: bold; + margin-left: 2px; + border: 1px solid #333; + padding: 4px 12px; + } + + /* 4. Timeline & Custom Widgets */ + TimelineWidget { + background-color: #1e1e1e; + border: 1px solid #333333; + } + + /* 5. Buttons with Grey Borders */ + QPushButton { + background-color: #2b2b2b; + color: #ffffff; + border: 1px solid #555555; + border-radius: 3px; + padding: 4px; + } + QPushButton:hover { + background-color: #3d3d3d; + border-color: #888888; + } + QPushButton:pressed { + background-color: #111111; + } + QPushButton:disabled { + border-color: #333333; + color: #444444; + } + + /* 6. Layout Dividers */ + QSplitter::handle { + background-color: #333333; + } + QSplitter::handle:horizontal { width: 2px; } + QSplitter::handle:vertical { height: 2px; } + + /* 7. 
Scroll Areas */ + QScrollArea, QScrollArea > QWidget > QWidget { + background-color: #1e1e1e; + border: none; + } + QWidget { background-color: #1e1e1e; color: #ffffff; font-family: 'Segoe UI'; } + QGroupBox { + border: 1px solid #3d3d3d; border-radius: 8px; margin-top: 15px; + padding-top: 15px; font-weight: bold; color: #00aaff; text-transform: uppercase; + } + QLabel { color: #ffffff; font-weight: 500; } + QLabel:disabled { color: #444444; } + QLabel#Metadata { color: #00ffaa; font-family: 'Consolas'; font-size: 11px; } + QLabel#Preview { background-color: #000000; border: 2px solid #3d3d3d; } + QLabel#Warning { color: #ff5555; font-size: 11px; font-style: italic; font-weight: bold; } + + QPushButton { background-color: #3d3d3d; border: 1px solid #555; padding: 6px; border-radius: 4px; } + QPushButton:hover { background-color: #00aaff; color: #000; } + QPushButton:disabled { color: #444; background-color: #252525; } + + QComboBox { background-color: #2d2d2d; border: 1px solid #555; padding: 4px; border-radius: 4px; } + QComboBox:disabled { background-color: #222; color: #444; border: 1px solid #2a2a2a; } + """ + + app.setStyleSheet(style) finish_update_if_needed(PLATFORM_NAME, APP_NAME) window = PremiereWindow() diff --git a/pose_worker.py b/pose_worker.py new file mode 100644 index 0000000..6129ad0 --- /dev/null +++ b/pose_worker.py @@ -0,0 +1,154 @@ +import cv2 +import os +import csv +import numpy as np +from ultralytics import YOLO +from multiprocessing import current_process + +JOINT_NAMES = [ + "nose", "l_eye", "r_eye", "l_ear", "r_ear", "l_shld", "r_shld", + "l_elbw", "r_elbw", "l_wri", "r_wri", "l_hip", "r_hip", + "l_knee", "r_knee", "l_ankl", "r_ankl" +] + +def get_best_infant_match(results, w, h, prev_track_id): + """ + Identifies the most likely infant based on visibility, + centrality, and tracking ID consistency. + """ + if not results[0].boxes or results[0].boxes.id is None: + return None, None, None, None + + ids = results[0].boxes.id.int().cpu().tolist() + kpts = results[0].keypoints.xy.cpu().numpy() + confs = results[0].keypoints.conf.cpu().numpy() + + best_idx, best_score = -1, -1 + + for i, k in enumerate(kpts): + # visibility score + vis = np.sum(confs[i] > 0.5) + valid = k[confs[i] > 0.5] + + # distance from center score + dist = np.linalg.norm(np.mean(valid, axis=0) - [w/2, h/2]) if len(valid) > 0 else 1000 + + # calculate total score + score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0) + + if score > best_score: + best_score, best_idx = score, i + + if best_idx == -1: + return None, None, None, None + + return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx + + + +def run_pose_analysis(video_path, progress_queue, result_queue, config): + """Worker task executed in a separate Process.""" + p_name = current_process().name + pose_cache = video_path.rsplit('.', 1)[0] + "_pose_raw.csv" + print(f"[{p_name}] Starting analysis on: {video_path}") + csv_storage_data = [] + inference_performed = False + + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + + + if config.get("use_cache") and os.path.exists(pose_cache): + print(f"[{p_name}] Cache checkmark active. 
Loading: {pose_cache}") + try: + with open(pose_cache, 'r') as f: + reader = csv.reader(f) + next(reader) # Skip header + for row in reader: + # Flattened (51,) back to (17, 3) + full_frame = np.array([float(x) for x in row]).reshape(17, 3) + csv_storage_data.append(full_frame) + + progress_queue.put(100) + result_queue.put({ + "raw_kpts": np.array(csv_storage_data), + "fps": fps, + "total_frames": len(csv_storage_data), + "dims": (width, height), + "status": "loaded_from_cache" + }) + return # Exit early, no inference or saving needed + except Exception as e: + print(f"[{p_name}] Cache read failed, falling back to inference: {e}") + csv_storage_data = [] + + + + inference_performed = True + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + model_map = { + "YOLO8n-Pose": "yolov8n-pose.pt", + "YOLO8m-Pose": "yolov8m-pose.pt", + "Mediapipe BlazePose": "mediapipe" + } + model_file = model_map.get(config.get("pose_model"), "yolov8n-pose.pt") + + print(f"[{p_name}] Running inference with model: {model_file}") + model = YOLO(model_file) + + new_csv_storage_data = [] + prev_track_id = None + + for i in range(total_frames): + ret, frame = cap.read() + if not ret: + break + + results = model.track(frame, persist=True, verbose=False) + + track_id, kp, confs, _ = get_best_infant_match(results, width, height, prev_track_id) + + if kp is not None: + prev_track_id = track_id + # Store as (17, 3) including confidence + new_csv_storage_data.append(np.column_stack((kp, confs))) + else: + new_csv_storage_data.append(np.zeros((17, 3))) + + if i % 10 == 0: + progress_queue.put(int((i / total_frames) * 100)) + + cap.release() + + if inference_performed: + print(f"[{p_name}] Saving new pose cache to {pose_cache}") + try: + with open(pose_cache, 'w', newline='') as f: + writer = csv.writer(f) + header = [] + for joint in JOINT_NAMES: + header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"]) + writer.writerow(header) + for frame_array in new_csv_storage_data: + writer.writerow(frame_array.flatten()) + except Exception as e: + print(f"[{p_name}] Error saving cache: {e}") + + # Return results through the queue + result_queue.put({ + "raw_kpts": np.array(new_csv_storage_data), + "fps": fps, + "total_frames": len(new_csv_storage_data), + "dims": (width, height), + "status": "inference_complete" + }) + + print(f"[{p_name}] Analysis complete.") \ No newline at end of file diff --git a/predictor.py b/predictor.py index bae5ced..216cd53 100644 --- a/predictor.py +++ b/predictor.py @@ -1,16 +1,9 @@ -""" -Filename: predictor.py -Description: BLAZES machine learning - -Author: Tyler de Zeeuw -License: GPL-3.0 -""" - -# Built-in imports import inspect +import csv +import os +import json from datetime import datetime -# External library imports import numpy as np import joblib import seaborn as sns @@ -21,249 +14,202 @@ from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report, f1_score, precision_score, recall_score, confusion_matrix from sklearn.preprocessing import StandardScaler -# To be used once multiple models are supported and functioning: -# import torch -# import torch.nn as nn -# import torch.optim as optim -# import xgboost as xgb -# from sklearn.svm import SVC -# import os - VERBOSITY = 1 -GEOMETRY_LIBRARY = { - # --- Distances (Point A, Point B) --- - "dist_l_wrist_nose": 
("dist", [9, 0], True), - "dist_r_wrist_nose": ("dist", [10, 0], True), - "dist_l_ear_r_shld": ("dist", [3, 6], True), - "dist_r_ear_l_shld": ("dist", [4, 5], True), - - "dist_l_wrist_pelvis": ("dist", [9, [11, 12]], True), - "dist_r_wrist_pelvis": ("dist", [10, [11, 12]], True), - "dist_l_ankl_pelvis": ("dist", [15, [11, 12]], True), - "dist_r_ankl_pelvis": ("dist", [16, [11, 12]], True), - "dist_nose_pelvis": ("dist", [0, [11, 12]], True), - "dist_ankl_ankl": ("dist", [15, 16], True), +def load_analysis_config(path="analysis_config.json"): + with open(path, 'r') as f: + config = json.load(f) + return config['geometry_library'], config['activity_map'] - # NEW: Cross-Body and Pure Extension Distances - "dist_l_wri_r_shld": ("dist", [9, 6], True), # Reach across body - "dist_r_wri_l_shld": ("dist", [10, 5], True), # Reach across body - "dist_l_wri_l_shld": ("dist", [9, 5], True), # Pure arm extension - "dist_r_wri_r_shld": ("dist", [10, 6], True), # Pure arm extension - - # --- Angles (Point A, Center B, Point C) --- - "angle_l_elbow": ("angle", [5, 7, 9]), - "angle_r_elbow": ("angle", [6, 8, 10]), - "angle_l_shoulder": ("angle", [11, 5, 7]), - "angle_r_shoulder": ("angle", [12, 6, 8]), - "angle_l_knee": ("angle", [11, 13, 15]), - "angle_r_knee": ("angle", [12, 14, 16]), - "angle_l_hip": ("angle", [5, 11, 13]), - "angle_r_hip": ("angle", [6, 12, 14]), - - # --- Custom/Derived --- - "asym_wrist": ("z_diff", [9, 10]), - "asym_ankl": ("z_diff", [15, 16]), - "offset_head": ("head_offset", [0, 5, 6]), - "diff_ear_shld": ("subtraction", ["dist_l_ear_r_shld", "dist_r_ear_l_shld"]), - "abs_diff_ear_shld": ("abs_subtraction", ["dist_l_ear_r_shld", "dist_r_ear_l_shld"]), - - # NEW: Verticality and Contralateral Contrast - "height_l_ankl": ("y_diff", [15, 11]), # Foot height relative to hip - "height_r_ankl": ("y_diff", [16, 12]), # Foot height relative to hip - "diff_knee_angle": ("subtraction", ["angle_l_knee", "angle_r_knee"]), - "asym_wri_shld": ("subtraction", ["dist_l_wri_l_shld", "dist_r_wri_r_shld"]) -} - - -# The Target Activity Map -ACTIVITY_MAP = { - "Mouthing": [ - "dist_l_wrist_nose", "dist_r_wrist_nose", "angle_l_elbow", - "angle_r_elbow", "angle_l_shoulder", "angle_r_shoulder", - "asym_wrist", "offset_head" - ], - "Head Movement": [ - "dist_l_wrist_nose", "dist_r_wrist_nose", "angle_l_elbow", - "angle_r_elbow", "angle_l_shoulder", "angle_r_shoulder", - "asym_wrist", "offset_head", "dist_l_ear_r_shld", - "dist_r_ear_l_shld", "diff_ear_shld", "abs_diff_ear_shld" - ], - "Reach (Left)": [ - "dist_l_wrist_pelvis", "dist_l_wrist_nose", "dist_l_wri_l_shld", - "dist_l_wri_r_shld", "angle_l_elbow", "angle_l_shoulder", - "asym_wri_shld" - ], - "Reach (Right)": [ - "dist_r_wrist_pelvis", "dist_r_wrist_nose", "dist_r_wri_r_shld", - "dist_r_wri_l_shld", "angle_r_elbow", "angle_r_shoulder", - "asym_wri_shld" - ], - "Kick (Left)": [ - "dist_l_ankl_pelvis", "angle_l_knee", "angle_l_hip", - "height_l_ankl", "dist_ankl_ankl", "asym_ankl", - "diff_knee_angle", "dist_nose_pelvis" - ], - "Kick (Right)": [ - "dist_r_ankl_pelvis", "angle_r_knee", "angle_r_hip", - "height_r_ankl", "dist_ankl_ankl", "asym_ankl", - "diff_knee_angle", "dist_nose_pelvis" - ] -} +try: + GEOMETRY_LIBRARY, ACTIVITY_MAP = load_analysis_config() +except FileNotFoundError: + GEOMETRY_LIBRARY, ACTIVITY_MAP = {}, {} + print("Warning: analysis_config.json not found. 
ML functions will fail.") def debug_print(): if VERBOSITY: frame = inspect.currentframe().f_back qualname = frame.f_code.co_filename - print(qualname) + print(f"DEBUG_PRINT: {qualname}") class GeneralPredictor: def __init__(self): debug_print() self.base_paths = { - "Random Forest": "rf.pkl", - "XGBoost": "xgb.json", - "SVM": "svm.pkl", - "LSTM": "lstm.pth", - "1D-CNN": "cnn.pth" + "Random Forest": "rf.pkl" } - self.raw_participant_buffer = [] self.current_target = "" - self.scaler_cache = {} + self.active_feature_keys = [] - - def add_to_raw_buffer(self, raw_payload, y_labels): + def calculate_and_train(self, training_params): """ - Adds a participant's raw kinematic components to the pool. - raw_payload should contain: 'z_kps', 'directions', 'raw_kps' + Takes the dict from get_selection() in TrainModelWindow. + Loads CSV/JSON pairs, extracts combined features, and trains Random Forest. """ debug_print() - entry = { - "raw_data": raw_payload, - "labels": y_labels + + folder = training_params.get("folder") + pairs = training_params.get("pairs", []) + selected_behaviors = training_params.get("selected_behaviors", []) + self.current_target = training_params.get("target_name", "combined_model") + model_type = training_params.get("model_type", "Random Forest") + + if not pairs or not selected_behaviors: + return "Error: Missing data pairs or target behaviors." + + # 1. Determine the union of ALL needed geometric features across selected behaviors + needed_features = set() + for b_name in selected_behaviors: + req_feats = ACTIVITY_MAP.get(b_name, []) + needed_features.update(req_feats) + + self.active_feature_keys = sorted(list(needed_features)) + print(self.active_feature_keys) + + model_metadata = { + "target_behavior": self.current_target, + "feature_keys": self.active_feature_keys, + "model_type": model_type, + "timestamp": datetime.now().isoformat() } - self.raw_participant_buffer.append(entry) - return f"Added participant to pool. Total participants: {len(self.raw_participant_buffer)}" + + if not self.active_feature_keys: + return "Error: No geometric features mapped to the selected behavior(s) in analysis_config.json." - - def clear_buffer(self): - """Clears the raw pool.""" - debug_print() - self.raw_participant_buffer = [] - - - def calculate_and_train(self, model_type, target_name): - """ - The 'On-the-Fly' engine. Loops through the raw buffer, - calculates features for the SELECTED target, and trains. - """ - debug_print() - self.current_target = target_name all_X = [] all_y = [] - # 1. Process every participant in the pool - for participant in self.raw_participant_buffer: - raw = participant["raw_data"] - all_tracks = participant["labels"] + # 2. Process each Pair (JSON labels + CSV raw pose) + for json_path, csv_path in pairs: + # --- Load JSON Labels --- + try: + with open(json_path, 'r') as f: + label_data = json.load(f) + except Exception as e: + print(f"Error loading {json_path}: {e}") + continue - # Pull the specific track that was requested - track_key = f"OBS: {target_name}" - if track_key not in all_tracks: - print(f"Warning: Track {track_key} not found for a participant. 
Skipping.") + behaviors = label_data.get("behaviors", {}) + + # --- Load CSV Pose Data --- + try: + raw_kpts = [] + with open(csv_path, 'r') as f: + reader = csv.reader(f) + next(reader) # skip header + for row in reader: + raw_kpts.append(np.array([float(x) for x in row]).reshape(17, 3)) + raw_kpts = np.array(raw_kpts) + except Exception as e: + print(f"Error loading {csv_path}: {e}") continue - y = all_tracks[track_key] - - # Extract lists from the payload - z_scores = raw["z_kps"] - dirs = raw["directions"] - kpts = raw["raw_kps"] + total_frames = len(raw_kpts) + if total_frames == 0: + continue - # Calculate geometric features for every frame + # Create binary target array (0 = Rest, 1 = Active) + y_vector = np.zeros(total_frames, dtype=int) + + # If the frame falls inside ANY of the selected behaviors, mark it 1 + for b_name in selected_behaviors: + instances = behaviors.get(b_name, []) + for inst in instances: + start = inst.get("start_frame", 0) + duration = inst.get("duration_frames", 0) + end = min(start + duration, total_frames) + y_vector[start:end] = 1 + + # --- Calculate Features per Frame --- + # To match the new flow, we just need raw_kpts. + # (Z-scores were previously passed, but those were derived from raw anyway. + # If you require normalized z-scores for RF, you must recalculate them here + # using the same baseline logic from the main window. For now, we extract raw geom.) + participant_features = [] - for i in range(len(y)): - feat = self.format_features(z_scores[i], dirs[i], kpts[i]) + for i in range(total_frames): + kpts = raw_kpts[i] # Shape (17, 3) + feat = self.format_features(kpts) participant_features.append(feat) all_X.append(np.array(participant_features)) - all_y.append(y) + all_y.append(y_vector) + + # 3. Prepare for Training + if not all_X: + return "Error: No valid data extracted from files." - # 2. Prepare for Training X_combined = np.vstack(all_X) y_combined = np.concatenate(all_y) - # 3. Scale the data specifically for this target/model combo + # Check for class imbalance edge case (e.g. 0 instances of behavior found) + if len(np.unique(y_combined)) < 2: + return "Error: Training data only contains one class (usually 0/Rest). Model cannot train." + + metadata_path = self.get_path(model_type).replace(".pkl", "_metadata.json") + with open(metadata_path, 'w') as f: + json.dump(model_metadata, f, indent=4) + + print(f"[INFO] Metadata saved to: {metadata_path}") + + # 4. Scale Data scaler = StandardScaler() X_scaled = scaler.fit_transform(X_combined) scaler_path = self.get_path(model_type, is_scaler=True) joblib.dump(scaler, scaler_path) - # 4. Train/Test Split + # 5. Train/Test Split X_train, X_test, y_train, y_test = train_test_split( X_scaled, y_combined, test_size=0.2, stratify=y_combined, random_state=42 ) - # 5. Process with corresponding Model + # 6. Train Random Forest (Placeholders exist for others) if model_type == "Random Forest": model = RandomForestClassifier(max_depth=15, n_estimators=100, class_weight="balanced") model.fit(X_train, y_train) - # Save the model save_path = self.get_path(model_type) joblib.dump(model, save_path) y_pred = model.predict(X_test) - # Feature Importance for the UI - labels_names = self.get_feature_labels() + # Feature Importance importances = model.feature_importances_ - feature_data = sorted(zip(labels_names, importances), key=lambda x: x[1], reverse=True) - ui_extras = "Top Predictors:
" + "
".join([f"{n}: {v:.3f}" for n, v in feature_data]) + feature_data = sorted(zip(self.active_feature_keys, importances), key=lambda x: x[1], reverse=True) + + ui_extras = "Top Predictors:
" + "
".join([f"{n}: {v:.3f}" for n, v in feature_data[:10]]) file_extras = "Top Predictors:\n" + "\n".join([f"- {n}: {v:.3f}" for n, v in feature_data]) - return self._evaluate_and_report(model_type, y_test, y_pred, ui_extras=ui_extras, file_extras=file_extras, target_name=target_name) + return self._evaluate_and_report(model_type, y_test, y_pred, ui_extras=ui_extras, file_extras=file_extras) - # TODO: More than random forest + elif model_type == "1D-CNN": + return "1D-CNN training placeholder reached. Not yet implemented." + elif model_type == "LSTM": + return "LSTM training placeholder reached. Not yet implemented." + elif model_type == "XGBoost": + return "XGBoost training placeholder reached. Not yet implemented." else: - return "Model type not yet implemented in calculate_and_train." - + return f"Model type {model_type} not supported." def get_path(self, model_type, is_scaler=False): - """Returns the specific file path for the target/model or its scaler.""" debug_print() - suffix = self.base_paths[model_type] - + suffix = self.base_paths.get(model_type, "model.pkl") if is_scaler: suffix = suffix.split('.')[0] + "_scaler.pkl" - return f"ml_{self.current_target}_{suffix}" - - def get_feature_labels(self): - """Returns labels only for features active in the current target.""" - debug_print() - active_keys = ACTIVITY_MAP.get(self.current_target, []) - return active_keys - - - def format_features(self, z_scores, directions, kpts): - """The 'Universal Parser' for geometric features.""" - # debug_print() - # Internal Math Helpers - if self.current_target == "ALL_FEATURES": - active_list = list(GEOMETRY_LIBRARY.keys()) - else: - active_list = ACTIVITY_MAP.get(self.current_target, ACTIVITY_MAP["Mouthing"]) - + def format_features(self, kpts): + """ + Calculates only the geometric features required by self.active_feature_keys. 
+ """ def resolve_pt(idx): if isinstance(idx, list): - # Calculate midpoint of all indices in the list - pts = [kpts[i] for i in idx] + pts = [kpts[i][:2] for i in idx] # Ensure X/Y only return np.mean(pts, axis=0) - return kpts[idx] + return kpts[idx][:2] def get_dist(p1, p2): return np.linalg.norm(p1 - p2) def get_angle(a, b, c): @@ -278,128 +224,122 @@ class GeneralPredictor: try: if kpts is None or len(kpts) < 13: raise ValueError() - # Reference scale (Shoulders) - scale = get_dist(kpts[5], kpts[6]) + 1e-6 + scale = get_dist(kpts[5][:2], kpts[6][:2]) + 1e-6 + + # First Pass: Direct Geometries (Only calculate what is needed or what is a dependency) + for name, config_data in GEOMETRY_LIBRARY.items(): + f_type = config_data[0] + indices = config_data[1] - # First Pass: Direct Geometries - for name, (f_type, indices, *meta) in GEOMETRY_LIBRARY.items(): if f_type == "dist": - # Use resolve_pt for both indices p1 = resolve_pt(indices[0]) p2 = resolve_pt(indices[1]) calculated_pool[name] = get_dist(p1, p2) / scale elif f_type == "angle": - # Use resolve_pt for all three indices p1 = resolve_pt(indices[0]) p2 = resolve_pt(indices[1]) p3 = resolve_pt(indices[2]) calculated_pool[name] = get_angle(p1, p2, p3) - elif f_type == "z_diff": - # Z-scores are usually single indices, but we handle lists just in case - z1 = np.mean([z_scores[i] for i in indices[0]]) if isinstance(indices[0], list) else z_scores[indices[0]] - z2 = np.mean([z_scores[i] for i in indices[1]]) if isinstance(indices[1], list) else z_scores[indices[1]] - calculated_pool[name] = abs(z1 - z2) - elif f_type == "head_offset": p_target = resolve_pt(indices[0]) - p_mid = resolve_pt([indices[1], indices[2]]) # Midpoint of shoulders + p_mid = resolve_pt([indices[1], indices[2]]) calculated_pool[name] = abs(p_target[0] - p_mid[0]) / scale + + elif f_type == "y_diff": # NEW from JSON + p1 = resolve_pt(indices[0]) + p2 = resolve_pt(indices[1]) + calculated_pool[name] = abs(p1[1] - p2[1]) / scale - # Second Pass: Composite Geometries (Subtractions/Symmetry) - # We do this after so 'dist_l_ear_r_shld' is already calculated - for name, (f_type, indices, *meta) in GEOMETRY_LIBRARY.items(): + # Second Pass: Subtractions (Requires first pass to be complete) + for name, config_data in GEOMETRY_LIBRARY.items(): + f_type = config_data[0] + indices = config_data[1] + if f_type == "subtraction": - calculated_pool[name] = calculated_pool[indices[0]] - calculated_pool[indices[1]] + val1 = calculated_pool.get(indices[0], 0) + val2 = calculated_pool.get(indices[1], 0) + calculated_pool[name] = val1 - val2 elif f_type == "abs_subtraction": - calculated_pool[name] = abs(calculated_pool[indices[0]] - calculated_pool[indices[1]]) + val1 = calculated_pool.get(indices[0], 0) + val2 = calculated_pool.get(indices[1], 0) + calculated_pool[name] = abs(val1 - val2) except Exception: - # If a frame fails, fill the pool with zeros to prevent crashes calculated_pool = {name: 0.0 for name in GEOMETRY_LIBRARY.keys()} - # Final Extraction based on current_target - - active_list = ACTIVITY_MAP.get(self.current_target, ACTIVITY_MAP["Mouthing"]) - feature_vector = [calculated_pool[feat] for feat in active_list] + # Final Extraction based on the set of needed features + feature_vector = [calculated_pool.get(feat, 0.0) for feat in self.active_feature_keys] return np.array(feature_vector, dtype=np.float32) - def _prepare_pool_data(self): - """Merges buffer and fits scaler.""" - debug_print() - if not self.X_buffer: - return None, None, None - - X_total = 
np.vstack(self.X_buffer) - y_total = np.concatenate(self.y_buffer) - - # We always fit a fresh scaler on the current pool - scaler_file = f"{self.current_target}_scaler.pkl" - scaler = StandardScaler() - X_scaled = scaler.fit_transform(X_total) - joblib.dump(scaler, scaler_file) - - return X_scaled, y_total, scaler - - - def _evaluate_and_report(self, model_name, y_test, y_pred, extra_text="", ui_extras="", file_extras="", target_name=""): - """Generates unified metrics, confusion matrix, and reports for ANY model""" + def _evaluate_and_report(self, model_name, y_test, y_pred, ui_extras="", file_extras=""): debug_print() prec = precision_score(y_test, y_pred, zero_division=0) rec = recall_score(y_test, y_pred, zero_division=0) f1 = f1_score(y_test, y_pred, zero_division=0) - target = getattr(self, 'current_target', 'Activity') - display_labels = ['Rest', target] - # Plot Confusion Matrix + display_labels = ['Rest', self.current_target] cm = confusion_matrix(y_test, y_pred) + plt.figure(figsize=(8, 6)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', - xticklabels=display_labels, - yticklabels=display_labels) - plt.title(f'{model_name} Detection: Predicted vs Actual') + xticklabels=display_labels, yticklabels=display_labels) + plt.title(f'{model_name} Detection: {self.current_target}') plt.ylabel('Actual State') plt.xlabel('Predicted State') timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - plt.savefig(f"ml_{target_name}_confusion_matrix_rf_{timestamp}.png") + plt.savefig(f"ml_{self.current_target}_cm_{timestamp}.png") plt.close() - # Classification Report String - report_str = classification_report(y_test, y_pred, - target_names=display_labels, - zero_division=0) + report_str = classification_report(y_test, y_pred, target_names=display_labels, zero_division=0) - # Build TXT File Content report_text = f"MODEL PERFORMANCE REPORT: {model_name}\nGenerated: {timestamp}\n" report_text += "="*40 + "\n" report_text += report_str + "\n" report_text += f"Precision: {prec:.4f}\nRecall: {rec:.4f}\nF1-Score: {f1:.4f}\n" - report_text += "="*40 + "\n" + extra_text report_text += "="*40 + "\n" + file_extras - with open(f"ml_{target_name}_performance_rf_{timestamp}.txt", "w") as f: + with open(f"ml_{self.current_target}_performance_{timestamp}.txt", "w") as f: f.write(report_text) - # Build UI String ui_report = f""" - {model_name} Performance:
+ {model_name} Model for '{self.current_target}'
Precision: {prec:.2f} | Recall: {rec:.2f} | F1: {f1:.2f}

{ui_extras} """ return ui_report - - def calculate_directions(self, analysis_kps): - debug_print() - all_dirs = np.zeros((len(analysis_kps), 17)) - - for f in range(1, len(analysis_kps)): - deltas = analysis_kps[f] - analysis_kps[f-1] # Shape (17, 2) + + + # Inside predictor.py -> GeneralPredictor class + def convert_to_events(self, predictions, track_name="🤖 AI: Predicted"): + """ + Converts a 1D array of class labels into a dictionary of timeline blocks. + predictions: np.array of 0s and 1s + track_name: The name for the resulting timeline row + """ + events = {track_name: []} + current_class = None + start_frame = 0 + + for i, pred in enumerate(predictions): + # We only care about the transition into or out of class 1 + if pred != current_class: + # If we were in an active block (1), close it + if current_class == 1: + events[track_name].append([start_frame, i, "Normal", "ML Prediction"]) + + # If we are starting a new active block (1), mark the start + if pred == 1: + start_frame = i + + current_class = pred + + # Close the final block if the video ends while the behavior is active + if current_class == 1: + events[track_name].append([start_frame, len(predictions), "Normal", "ML Prediction"]) - angles = np.arctan2(-deltas[:, 1], deltas[:, 0]) - all_dirs[f] = angles - - return all_dirs \ No newline at end of file + return events \ No newline at end of file
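
Illustrative note (not part of the patch): calculate_and_train() consumes (json_path, csv_path) pairs whose contents can be pictured as the minimal Python sketch below. The structure is inferred from the loading loop above; the behavior name, frame numbers, and file names are hypothetical examples, and any fields beyond "behaviors", "start_frame", and "duration_frames" are assumptions.

# Sketch of the inputs the pair-loading loop expects (assumed shape, not a spec).
# Label file (JSON): what json.load() should return for one participant.
example_labels = {
    "behaviors": {
        "Mouthing": [                                   # behavior name, looked up in ACTIVITY_MAP
            {"start_frame": 120, "duration_frames": 45},  # hypothetical instance
            {"start_frame": 300, "duration_frames": 12},
        ]
    }
}
# Pose file (CSV): one header row, then one row per frame holding 51 floats
# (17 keypoints x 3 values), which the reader reshapes to an array of shape (17, 3).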
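
Usage sketch (also illustrative, not part of the patch): how a caller might drive the retooled GeneralPredictor end to end. The training_params keys mirror what calculate_and_train() reads; the file paths, the toy prediction array, and the assumption that a TrainModelWindow.get_selection() normally builds this dict are all hypothetical.

import joblib
import numpy as np
from predictor import GeneralPredictor

predictor = GeneralPredictor()

# Dict shape consumed by calculate_and_train(); values below are placeholders.
training_params = {
    "folder": "data/participant_01",                            # hypothetical folder
    "pairs": [("data/p01_labels.json", "data/p01_pose.csv")],   # (JSON labels, CSV pose) pairs
    "selected_behaviors": ["Mouthing"],                         # keys into ACTIVITY_MAP
    "target_name": "mouthing_model",
    "model_type": "Random Forest",
}
print(predictor.calculate_and_train(training_params))  # returns a UI report (or error) string

# Turning per-frame predictions into timeline blocks with convert_to_events().
model = joblib.load(predictor.get_path("Random Forest"))                 # saved by training above
scaler = joblib.load(predictor.get_path("Random Forest", is_scaler=True))
# predictions = model.predict(scaler.transform(X_new))  # X_new: per-frame features via format_features()
predictions = np.array([0, 0, 1, 1, 1, 0, 1])           # toy stand-in for model output
events = predictor.convert_to_events(predictions)
# -> {"🤖 AI: Predicted": [[2, 5, "Normal", "ML Prediction"], [6, 7, "Normal", "ML Prediction"]]}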