diff --git a/.gitignore b/.gitignore
index 36b13f1..0d9aef9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,3 +174,5 @@ cython_debug/
# PyPI configuration file
.pypirc
+*.boris
+*.json
\ No newline at end of file
diff --git a/main.py b/main.py
index a2367cf..dc7fc0e 100644
--- a/main.py
+++ b/main.py
@@ -18,7 +18,7 @@ import platform
import traceback
from pathlib import Path
from datetime import datetime
-from multiprocessing import current_process, freeze_support
+from multiprocessing import current_process, freeze_support, Process, Queue
# External library imports
import numpy as np
@@ -30,14 +30,15 @@ from ultralytics import YOLO
from updater import finish_update_if_needed, UpdateManager, LocalPendingUpdateCheckThread
from predictor import GeneralPredictor
+from pose_worker import run_pose_analysis
from batch_processing import BatchProcessorDialog
import PySide6
-from PySide6.QtWidgets import (QApplication, QMainWindow, QTabBar, QWidget, QVBoxLayout, QGraphicsView, QGraphicsScene,
+from PySide6.QtWidgets import (QApplication, QLineEdit, QListWidget, QListWidgetItem, QMainWindow, QProgressDialog, QTabBar, QWidget, QVBoxLayout, QGraphicsView, QGraphicsScene,
QHBoxLayout, QSplitter, QLabel, QPushButton, QComboBox, QInputDialog,
- QFileDialog, QScrollArea, QMessageBox, QSlider, QTextEdit, QGroupBox, QGridLayout, QCheckBox, QTabWidget)
-from PySide6.QtCore import Qt, QThread, Signal, QUrl, QRectF, QPointF, QRect, QSizeF
-from PySide6.QtGui import QPainter, QColor, QFont, QPen, QBrush, QAction, QKeySequence, QIcon, QTextOption, QImage, QPixmap
+ QFileDialog, QScrollArea, QMessageBox, QSlider, QTextEdit, QGroupBox, QGridLayout, QCheckBox, QTabWidget, QProgressBar)
+from PySide6.QtCore import QEvent, Qt, QThread, Signal, QUrl, QRectF, QPointF, QRect, QSizeF, QTimer
+from PySide6.QtGui import QCursor, QGuiApplication, QPainter, QColor, QFont, QPen, QBrush, QAction, QKeySequence, QIcon, QTextOption, QImage, QPixmap
from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput
from PySide6.QtMultimediaWidgets import QGraphicsVideoItem
@@ -54,7 +55,7 @@ PLATFORM_NAME = platform.system().lower()
def debug_print():
if VERBOSITY:
frame = inspect.currentframe().f_back
- qualname = frame.f_code.co_filename
+ qualname = frame.f_code.co_qualname
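+        # NOTE: co_qualname requires Python 3.11+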
print(qualname)
@@ -109,6 +110,7 @@ from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
from PySide6.QtGui import QPixmap, QImage
from PySide6.QtCore import Qt
+
class OpenFileWindow(QWidget):
def __init__(self, parent=None):
super().__init__(parent, Qt.WindowType.Window)
@@ -118,11 +120,30 @@ class OpenFileWindow(QWidget):
# State
self.video_path = None
self.obs_file = None
+ self.pkl_path = None
self.full_json_data = None
self.current_video_fps = 30.0
self.current_video_offset = 0.0
+        self.fps = 0.0
+        self.total_frames = 0
self.setup_ui()
+ self.center_on_screen()
+
+ def center_on_screen(self):
+ """Centers the window on the current screen."""
+ # Get the geometry of the screen where the mouse currently is
+ screen = QGuiApplication.screenAt(QCursor.pos())
+ if not screen:
+ screen = QGuiApplication.primaryScreen()
+
+ screen_geometry = screen.availableGeometry()
+ size = self.sizeHint() # Or self.geometry() if already sized
+
+ x = (screen_geometry.width() - size.width()) // 2
+ y = (screen_geometry.height() - size.height()) // 2
+
+ # Apply the coordinates (relative to the screen's top-left)
+ self.move(screen_geometry.left() + x, screen_geometry.top() + y)
def setup_ui(self):
self.setStyleSheet("""
@@ -167,58 +188,39 @@ class OpenFileWindow(QWidget):
main_layout.addWidget(video_group)
# --- Section 2: Analysis Modes ---
- self.adv_group = QGroupBox("Behavioral Analysis Mode")
- adv_layout = QVBoxLayout(self.adv_group)
- self.combo_mode = QComboBox()
- self.combo_mode.addItems(["BORIS Project (JSON)", "Trained ML Model (.pkl)", "Bypass / Manual"])
- adv_layout.addWidget(self.combo_mode)
-
- self.mode_stack = QStackedWidget()
+ self.boris_group = QGroupBox("Human Coding (BORIS)")
+ self.boris_group.setCheckable(True) # User can toggle this section off
+ self.boris_group.setChecked(False)
+ boris_layout = QGridLayout(self.boris_group)
- # Mode 1: BORIS
- self.page_boris = QWidget()
- m1_grid = QGridLayout(self.page_boris)
self.btn_boris_file = QPushButton("Load .boris File")
- self.lbl_session = QLabel("Session Key:")
self.combo_boris_keys = QComboBox()
- self.lbl_slot = QLabel("Video Slot:")
self.combo_video_slot = QComboBox()
- # Initial Disabled States
- for w in [self.lbl_session, self.combo_boris_keys, self.lbl_slot, self.combo_video_slot]:
- w.setEnabled(False)
+ boris_layout.addWidget(QLabel("BORIS File:"), 0, 0)
+ boris_layout.addWidget(self.btn_boris_file, 0, 1)
+ boris_layout.addWidget(QLabel("Session:"), 1, 0)
+ boris_layout.addWidget(self.combo_boris_keys, 1, 1)
+ boris_layout.addWidget(QLabel("Slot:"), 2, 0)
+ boris_layout.addWidget(self.combo_video_slot, 2, 1)
+ main_layout.addWidget(self.boris_group)
- m1_grid.addWidget(QLabel("BORIS File:"), 0, 0)
- m1_grid.addWidget(self.btn_boris_file, 0, 1)
- m1_grid.addWidget(self.lbl_session, 1, 0)
- m1_grid.addWidget(self.combo_boris_keys, 1, 1)
- m1_grid.addWidget(self.lbl_slot, 2, 0)
- m1_grid.addWidget(self.combo_video_slot, 2, 1)
+ # --- Section 3: Trained ML Model ---
+ self.pkl_group = QGroupBox("Automated Prediction (.pkl)")
+ self.pkl_group.setCheckable(True) # User can toggle this section off
+ self.pkl_group.setChecked(False)
+ pkl_layout = QGridLayout(self.pkl_group)
- # Mode 2: PKL Model
- self.page_pkl = QWidget()
- m2_grid = QGridLayout(self.page_pkl)
self.btn_pkl_file = QPushButton("Load .pkl Model")
self.lbl_pkl_path = QLabel("No model selected...")
- m2_grid.addWidget(QLabel("Model File:"), 0, 0)
- m2_grid.addWidget(self.btn_pkl_file, 0, 1)
- m2_grid.addWidget(QLabel("Selected:"), 1, 0)
- m2_grid.addWidget(self.lbl_pkl_path, 1, 1)
+
+ pkl_layout.addWidget(QLabel("Model File:"), 0, 0)
+ pkl_layout.addWidget(self.btn_pkl_file, 0, 1)
+ pkl_layout.addWidget(QLabel("Path:"), 1, 0)
+ pkl_layout.addWidget(self.lbl_pkl_path, 1, 1)
+ main_layout.addWidget(self.pkl_group)
- # Mode 3: Bypass
- self.page_bypass = QWidget()
- m3_layout = QVBoxLayout(self.page_bypass)
- self.lbl_bypass_info = QLabel("Bypass Mode: No behavioral data will be loaded.\nManual annotation mode enabled.")
- self.lbl_bypass_info.setStyleSheet("color: #888; font-style: italic;")
- m3_layout.addWidget(self.lbl_bypass_info)
-
- self.mode_stack.addWidget(self.page_boris)
- self.mode_stack.addWidget(self.page_pkl)
- self.mode_stack.addWidget(self.page_bypass)
- adv_layout.addWidget(self.mode_stack)
- main_layout.addWidget(self.adv_group)
-
- # --- Section 3: Inference ---
+ # --- Section 4: Inference ---
self.cfg_group = QGroupBox("Inference Settings")
c_grid = QGridLayout(self.cfg_group)
self.check_use_cache = QCheckBox("Auto-search pose cache (.npy)")
@@ -226,7 +228,7 @@ class OpenFileWindow(QWidget):
self.lbl_model_prompt = QLabel("Pose Model:")
self.combo_inference_model = QComboBox()
- self.combo_inference_model.addItems(["YOLO11n-Pose", "YOLO11m-Pose", "Mediapipe BlazePose"])
+        self.combo_inference_model.addItems(["YOLOv8n-Pose", "YOLOv8m-Pose", "Mediapipe BlazePose"])
self.check_bypass_inference = QCheckBox("Bypass Pose Inference")
self.lbl_inf_warning = QLabel("⚠ WARNING: Nothing fancy. Raw video playback only.")
@@ -251,30 +253,42 @@ class OpenFileWindow(QWidget):
# Connections
self.btn_pick_video.clicked.connect(self.handle_video_selection)
- self.combo_mode.currentIndexChanged.connect(self.mode_stack.setCurrentIndex)
self.btn_boris_file.clicked.connect(self.handle_boris_load)
+ self.btn_pkl_file.clicked.connect(self.handle_pkl_selection)
self.combo_boris_keys.currentIndexChanged.connect(self.handle_session_change)
self.combo_video_slot.currentIndexChanged.connect(self.handle_slot_change)
- self.btn_pkl_file.clicked.connect(self.handle_pkl_selection)
self.check_bypass_inference.toggled.connect(self.handle_inference_toggle)
self.btn_cancel.clicked.connect(self.close)
+ self.boris_group.toggled.connect(self.update_metadata_display)
+
def format_time(self, seconds):
h, m, s = int(seconds // 3600), int((seconds % 3600) // 60), int(seconds % 60)
return f"{h:02d}:{m:02d}:{s:02d}"
+
def handle_video_selection(self):
path, _ = QFileDialog.getOpenFileName(self, "Open Video", "", "Video Files (*.mp4 *.avi *.mkv)")
if path:
self.video_path = path
self.lbl_video_path.setText(os.path.basename(path))
+
+ # Open video to extract properties
cap = cv2.VideoCapture(path)
- fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
- time_str = self.format_time(cap.get(cv2.CAP_PROP_FRAME_COUNT) / fps)
- self.lbl_video_metadata.setText(f"RES: {int(cap.get(3))}x{int(cap.get(4))} | FPS: {fps:.2f} | LEN: {time_str}")
+ self.fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+ self.total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ self.video_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ self.video_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
+
self.render_preview(path)
- if self.full_json_data: self.attempt_auto_match()
+
+ # If BORIS JSON is already loaded, try to find this video in the slots
+ if self.full_json_data:
+ self.attempt_auto_match()
+
+ # Refresh the label with all 4 fields (Res, FPS, Len, and Offset)
+ self.update_metadata_display()
def handle_boris_load(self):
path, _ = QFileDialog.getOpenFileName(self, "Select JSON", "", "JSON Files (*.json *.boris)")
@@ -285,8 +299,8 @@ class OpenFileWindow(QWidget):
with open(path, 'r') as f:
self.full_json_data = json.load(f)
obs = self.full_json_data.get("observations", {})
+ print(f"\n[DEBUG] BORIS File Loaded. Found {len(obs)} sessions.")
self.combo_boris_keys.setEnabled(True)
- self.lbl_session.setEnabled(True)
self.combo_boris_keys.clear()
self.combo_boris_keys.addItems(list(obs.keys()))
if self.video_path: self.attempt_auto_match()
@@ -301,44 +315,100 @@ class OpenFileWindow(QWidget):
self.combo_video_slot.blockSignals(True)
self.combo_video_slot.clear()
- slots = list(file_map.keys())
- self.combo_video_slot.addItems(slots)
- self.combo_video_slot.setEnabled(True)
- self.lbl_slot.setEnabled(True)
- self.combo_video_slot.blockSignals(False)
- if self.video_path:
- video_filename = os.path.basename(self.video_path)
- for i, slot in enumerate(slots):
- if any(video_filename in f for f in file_map[slot]):
- self.combo_video_slot.setCurrentIndex(i)
- break
+ print(f"[DEBUG] Filtering slots for session: {session_key}")
+ for slot, files in file_map.items():
+ # Check if there is at least one non-empty string in the list
+ valid_files = [f for f in files if isinstance(f, str) and f.strip()]
+ if valid_files:
+ display_name = os.path.basename(valid_files[0].replace('\\', '/'))
+ print(f" > Valid slot found: {slot} ({display_name})")
+ self.combo_video_slot.addItem(f"Slot {slot}: {display_name}", slot)
+
+ self.combo_video_slot.setEnabled(True)
+ self.combo_video_slot.blockSignals(False)
+ self.handle_slot_change()
+
+ def attempt_auto_match(self):
+ """Debugged auto-match: Scans all slots in all sessions for the filename."""
+ if not self.video_path or not self.full_json_data:
+ return
+
+ target_name = os.path.basename(self.video_path)
+ print(f"\n[DEBUG] ATTEMPTING AUTO-MATCH FOR: {target_name}")
+
+ obs = self.full_json_data.get("observations", {})
+
+ for s_idx, (session_key, content) in enumerate(obs.items()):
+ file_map = content.get("file", {})
+ for slot, files in file_map.items():
+ for f_path in files:
+ # Normalize path for comparison
+ clean_f_path = f_path.replace('\\', '/')
+ json_filename = os.path.basename(clean_f_path)
+
+ if json_filename == target_name:
+ print(f"[DEBUG] !!! MATCH FOUND !!!")
+ print(f" Session: {session_key}")
+ print(f" Slot: {slot}")
+
+ # Update UI
+ self.combo_boris_keys.setCurrentIndex(s_idx)
+                        # setCurrentIndex fires handle_session_change synchronously,
+                        # so the slot combo is already repopulated before we pick below
+ for i in range(self.combo_video_slot.count()):
+ if self.combo_video_slot.itemData(i) == slot:
+ self.combo_video_slot.setCurrentIndex(i)
+ break
+ return
+
+ print(f"[DEBUG] No match found for {target_name} in the JSON file mapping.")
def handle_slot_change(self):
session_key = self.combo_boris_keys.currentText()
- slot_key = self.combo_video_slot.currentText()
- if not session_key or not slot_key: return
+ # Pull the slot ID (e.g., "1") we stored in handle_session_change
+ slot_id = self.combo_video_slot.currentData()
+
+ if not session_key or slot_id is None:
+ return
+
session_data = self.full_json_data.get("observations", {}).get(session_key, {})
- val = session_data.get("media_info", {}).get("offset", {}).get(str(slot_key))
+
+ # Navigate: media_info -> offset -> {slot_id}
+ offsets = session_data.get("media_info", {}).get("offset", {})
+ val = offsets.get(str(slot_id)) # Ensure it's a string key
+
if val is not None:
self.current_video_offset = float(val)
- txt = self.lbl_video_metadata.text().split(" | Offset:")[0]
- self.lbl_video_metadata.setText(f"{txt} | Offset: {self.current_video_offset}s")
+ else:
+ self.current_video_offset = 0.0
+
+ self.update_metadata_display()
+
+ def update_metadata_display(self):
+ # Only update if a video has been selected
+ if not self.video_path:
+ self.lbl_video_metadata.setText("Metadata: N/A")
+ return
+
+ # Check if we are in BORIS mode and have a valid offset
+ if self.boris_group.isChecked():
+ offset_str = f" | Offset: {self.current_video_offset}s"
+ else:
+ offset_str = ""
+
+ # Assemble the final string
+ # Assuming self.fps and self.total_frames were set in handle_video_selection
+ time_str = self.format_time(self.total_frames / self.fps)
+
+        # Rebuild the metadata line, appending the offset only when relevant
+ base_text = f"RES: {self.video_w}x{self.video_h} | FPS: {self.fps:.2f} | LEN: {time_str}"
+ self.lbl_video_metadata.setText(f"{base_text}{offset_str}")
- def attempt_auto_match(self):
- obs = self.full_json_data.get("observations", {})
- v_name = os.path.splitext(os.path.basename(self.video_path))[0]
- v_parts = v_name.split('_')
- v_fp = f"{v_parts[0]}_{v_parts[1]}_{v_parts[-1]}" if len(v_parts) >= 3 else v_name
- for i, sk in enumerate(obs.keys()):
- s_parts = sk.split('_')
- if len(s_parts) == 3 and f"{s_parts[0]}_{s_parts[1]}_{s_parts[2]}".lower() == v_fp.lower():
- self.combo_boris_keys.setCurrentIndex(i)
- return
def handle_pkl_selection(self):
path, _ = QFileDialog.getOpenFileName(self, "Select Model", "", "Pickle Files (*.pkl)")
if path:
+ self.pkl_path = path
self.lbl_pkl_path.setText(os.path.basename(path))
def handle_inference_toggle(self, checked):
@@ -357,7 +427,216 @@ class OpenFileWindow(QWidget):
cap.release()
+ def get_config(self):
+ """Returns a dictionary of all user-selected settings."""
+ return {
+ "video_path": self.video_path,
+ "total_frames": getattr(self, 'total_frames', 0),
+ "fps": self.fps,
+
+ # BORIS Data
+ "use_boris": self.boris_group.isChecked(),
+ "obs_file": self.obs_file if self.boris_group.isChecked() else None,
+ "session_key": self.combo_boris_keys.currentText(),
+ "slot": self.combo_video_slot.currentData(),
+ "offset": self.current_video_offset,
+
+ # ML Model Data
+ "use_pkl": self.pkl_group.isChecked(),
+ "pkl_path": self.pkl_path if self.pkl_group.isChecked() else None,
+
+ # Inference Settings
+ "use_pose": not self.check_bypass_inference.isChecked(),
+ "pose_model": self.combo_inference_model.currentText(),
+ "use_cache": self.check_use_cache.isChecked(),
+ }
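+
+    # Hypothetical consumer sketch; run_pose_analysis's real signature lives in
+    # pose_worker.py, so the call below is illustrative only:
+    #   cfg = window.get_config()
+    #   q = Queue()
+    #   Process(target=run_pose_analysis, args=(cfg, q), daemon=True).start()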
+
+
+import os
+from PySide6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton,
+ QLabel, QFileDialog, QFrame, QComboBox)
+from PySide6.QtCore import Qt
+
+class TrainModelWindow(QDialog):
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.setWindowTitle(f"Train Model - {APP_NAME.upper()}")
+ self.setFixedSize(500, 550) # Slightly taller to fit stats
+ self.selected_folder = None
+ self.valid_pairs = [] # Stores (json_path, csv_path)
+
+ self.setup_ui()
+
+ def setup_ui(self):
+ layout = QVBoxLayout(self)
+ layout.setSpacing(12)
+
+ # --- Section 1: Folder Selection ---
+ self.path_display = QLabel("No folder selected...")
+ self.path_display.setStyleSheet("background: #1e1e1e; padding: 8px; border-radius: 3px;")
+ btn_browse = QPushButton("Select Training Folder")
+ btn_browse.clicked.connect(self.browse_folder)
+
+ layout.addWidget(QLabel("Data Source:"))
+ layout.addWidget(self.path_display)
+ layout.addWidget(btn_browse)
+
+ # --- Section 2: Behavior Selection (Multi-Select) ---
+ layout.addWidget(QLabel("Select Target Behavior(s):"))
+ self.behavior_list = QListWidget()
+ self.behavior_list.setMinimumHeight(150)
+ self.behavior_list.itemChanged.connect(self.handle_selection_change)
+ layout.addWidget(self.behavior_list)
+
+ # --- Section 3: Group Name (Conditional) ---
+ self.group_name_container = QWidget()
+ group_layout = QVBoxLayout(self.group_name_container)
+ group_layout.setContentsMargins(0, 0, 0, 0)
+
+ group_layout.addWidget(QLabel("Combined Variable Name:"))
+ self.edit_group_name = QLineEdit()
+ self.edit_group_name.setPlaceholderText("e.g., Total_Movement")
+ group_layout.addWidget(self.edit_group_name)
+
+ self.group_name_container.hide() # Hidden by default
+ layout.addWidget(self.group_name_container)
+
+ # --- Section 4: Folder Statistics ---
+ self.stats_display = QLabel("Valid Pairs Found: 0")
+ self.stats_display.setStyleSheet("color: #00ffaa; font-family: 'Consolas'; background: #111; padding: 10px;")
+ layout.addWidget(self.stats_display)
+
+ # --- Section 5: ML Architecture ---
+ self.method_dropdown = QComboBox()
+ self.method_dropdown.addItems(["Random Forest", "1D-CNN", "LSTM", "XGBoost"])
+ layout.addWidget(QLabel("ML Architecture:"))
+ layout.addWidget(self.method_dropdown)
+
+ layout.addStretch()
+
+ # --- Final Actions ---
+ button_box = QHBoxLayout()
+ self.btn_train = QPushButton("Start Training")
+ self.btn_train.setEnabled(False)
+ self.btn_train.setStyleSheet("background-color: #2e7d32; font-weight: bold; padding: 8px;")
+ self.btn_train.clicked.connect(self.accept)
+
+ btn_cancel = QPushButton("Cancel")
+ btn_cancel.clicked.connect(self.reject)
+
+ button_box.addWidget(btn_cancel)
+ button_box.addWidget(self.btn_train)
+ layout.addLayout(button_box)
+
+
+ def browse_folder(self):
+ folder = QFileDialog.getExistingDirectory(self, "Select Training Data Folder")
+ if folder:
+ self.selected_folder = folder
+ self.path_display.setText(folder)
+ self.scan_and_parse_folder(folder)
+
+ def scan_and_parse_folder(self, folder):
+ """Scans for pairs and tracks per-behavior statistics."""
+ self.valid_pairs = []
+
+ # Structure: { "Mouthing": {"count": 0, "frames": 0}, ... }
+ behavior_stats = {}
+
+ total_global_events = 0
+ total_global_frames = 0
+
+ files = os.listdir(folder)
+ json_files = [f for f in files if f.endswith("_metrics.json")]
+
+ for j_file in json_files:
+ base_name = j_file.replace("_metrics.json", "")
+ csv_file = base_name + "_pose_raw.csv"
+
+ json_path = os.path.join(folder, j_file)
+ csv_path = os.path.join(folder, csv_file)
+
+ if os.path.exists(csv_path):
+ self.valid_pairs.append((json_path, csv_path))
+
+ try:
+ with open(json_path, 'r') as f:
+ data = json.load(f)
+ behaviors = data.get("behaviors", {})
+                    fps = data.get("metadata", {}).get("fps", 30.0) or 30.0
+
+ for b_name, instances in behaviors.items():
+                        if b_name not in behavior_stats:
+                            behavior_stats[b_name] = {"count": 0, "seconds": 0.0}
+
+                        count = len(instances)
+                        frames = sum(inst.get("duration_frames", 0) for inst in instances)
+                        seconds = frames / fps  # use this file's fps rather than a fixed 30
+
+                        behavior_stats[b_name]["count"] += count
+                        behavior_stats[b_name]["seconds"] += seconds
+
+                        total_global_events += count
+                        total_global_seconds += seconds
+ except Exception as e:
+ print(f"Error parsing {j_file}: {e}")
+
+ # --- Update Dataset Summary Label ---
+ pair_count = len(self.valid_pairs)
+        total_sec = total_global_seconds
+
+ stats_text = (
+ f"Valid Pairs Found: {pair_count}\n"
+ f"Total Event Instances: {total_global_events}\n"
+ f"Total Behavior Time: {total_sec:.2f}s"
+ )
+ self.stats_display.setText(stats_text)
+
+ # --- Populate Dropdown with Detailed Labels ---
+ self.behavior_list.clear()
+ for b_name in sorted(behavior_stats.keys()):
+ stats = behavior_stats[b_name]
+ sec = stats["frames"] / 30.0
+ label = f"{b_name} ({stats['count']} events, {sec:.1f}s)"
+
+ item = QListWidgetItem(label)
+ item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
+ item.setCheckState(Qt.Unchecked)
+ item.setData(Qt.UserRole, b_name) # Store clean name
+ self.behavior_list.addItem(item)
+
+
+ def handle_selection_change(self):
+ """Shows/Hides the name input based on how many boxes are checked."""
+ selected_items = [self.behavior_list.item(i) for i in range(self.behavior_list.count())
+ if self.behavior_list.item(i).checkState() == Qt.Checked]
+
+ count = len(selected_items)
+ self.group_name_container.setVisible(count > 1)
+ self.btn_train.setEnabled(count > 0)
+
+
+ def get_selection(self):
+ """Returns the specific behaviors to combine and the final variable name."""
+ selected_names = [self.behavior_list.item(i).data(Qt.UserRole)
+ for i in range(self.behavior_list.count())
+ if self.behavior_list.item(i).checkState() == Qt.Checked]
+
+ # If multiple are selected, use the text field name; otherwise use the single name
+ if len(selected_names) > 1:
+ final_name = self.edit_group_name.text().strip() or "combined_variable"
+ else:
+ final_name = selected_names[0] if selected_names else None
+
+ return {
+ "folder": self.selected_folder,
+ "pairs": self.valid_pairs,
+ "selected_behaviors": selected_names,
+ "target_name": final_name,
+ "model_type": self.method_dropdown.currentText()
+ }
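+
+    # Hypothetical caller sketch (QDialog.exec()/Accepted are standard Qt;
+    # start_training is illustrative, not defined in this file):
+    #   dlg = TrainModelWindow(parent)
+    #   if dlg.exec() == QDialog.Accepted:
+    #       start_training(dlg.get_selection())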
+
+
class AboutWindow(QWidget):
"""
@@ -432,387 +711,387 @@ class UserGuideWindow(QWidget):
-class PoseAnalyzerWorker(QThread):
- progress = Signal(str)
- finished_data = Signal(dict)
+# class PoseAnalyzerWorker(QThread):
+# progress = Signal(str)
+# finished_data = Signal(dict)
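+# NOTE: This QThread worker is retired; pose extraction now presumably runs via
+# pose_worker.run_pose_analysis in a separate multiprocessing Process, reporting
+# progress through a Queue (see the new imports at the top of this file).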
- def __init__(self, video_path, obs_info=None, predictor=None):
- debug_print()
- super().__init__()
- self.video_path = video_path
- self.obs_info = obs_info
- self.predictor = predictor
- self.pose_df = pd.DataFrame()
+# def __init__(self, video_path, obs_info=None, predictor=None):
+# debug_print()
+# super().__init__()
+# self.video_path = video_path
+# self.obs_info = obs_info
+# self.predictor = predictor
+# self.pose_df = pd.DataFrame()
- def get_best_infant_match(self, results, w, h, prev_track_id):
- debug_print()
- if not results[0].boxes or results[0].boxes.id is None:
- return None, None, None, None
- ids = results[0].boxes.id.int().cpu().tolist()
- kpts = results[0].keypoints.xy.cpu().numpy()
- confs = results[0].keypoints.conf.cpu().numpy()
- best_idx, best_score = -1, -1
- for i, k in enumerate(kpts):
- vis = np.sum(confs[i] > 0.5)
- valid = k[confs[i] > 0.5]
- dist = np.linalg.norm(np.mean(valid, axis=0) - [w/2, h/2]) if len(valid) > 0 else 1000
- score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0)
- if score > best_score:
- best_score, best_idx = score, i
- if best_idx == -1:
- return None, None, None, None
- return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx
+# def get_best_infant_match(self, results, w, h, prev_track_id):
+# debug_print()
+# if not results[0].boxes or results[0].boxes.id is None:
+# return None, None, None, None
+# ids = results[0].boxes.id.int().cpu().tolist()
+# kpts = results[0].keypoints.xy.cpu().numpy()
+# confs = results[0].keypoints.conf.cpu().numpy()
+# best_idx, best_score = -1, -1
+# for i, k in enumerate(kpts):
+# vis = np.sum(confs[i] > 0.5)
+# valid = k[confs[i] > 0.5]
+# dist = np.linalg.norm(np.mean(valid, axis=0) - [w/2, h/2]) if len(valid) > 0 else 1000
+# score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0)
+# if score > best_score:
+# best_score, best_idx = score, i
+# if best_idx == -1:
+# return None, None, None, None
+# return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx
- def _merge_json_observations(self, timeline_events, fps):
- """Restores the grouping and block-pairing logic from the observation files."""
- debug_print()
- if not self.obs_info:
- return
+# def _merge_json_observations(self, timeline_events, fps):
+# """Restores the grouping and block-pairing logic from the observation files."""
+# debug_print()
+# if not self.obs_info:
+# return
- self.progress.emit("Merging JSON Observations...")
- json_path, subkey = self.obs_info
+# self.progress.emit("Merging JSON Observations...")
+# json_path, subkey = self.obs_info
- # try:
- # with open(json_path, 'r') as f:
- # full_json = json.load(f)
+# # try:
+# # with open(json_path, 'r') as f:
+# # full_json = json.load(f)
- # # Extract events for the specific subkey (e.g., 'Participant_01')
- # raw_obs_events = full_json["observations"][subkey]["events"]
- # raw_obs_events.sort(key=lambda x: x[0]) # Sort by timestamp
+# # # Extract events for the specific subkey (e.g., 'Participant_01')
+# # raw_obs_events = full_json["observations"][subkey]["events"]
+# # raw_obs_events.sort(key=lambda x: x[0]) # Sort by timestamp
- # # Group frames by label
- # obs_groups = {}
- # for ev in raw_obs_events:
- # time_sec, _, label, special = ev[0], ev[1], ev[2], ev[3]
- # frame = int(time_sec * fps)
- # if label not in obs_groups:
- # obs_groups[label] = []
- # obs_groups[label].append(frame)
+# # # Group frames by label
+# # obs_groups = {}
+# # for ev in raw_obs_events:
+# # time_sec, _, label, special = ev[0], ev[1], ev[2], ev[3]
+# # frame = int(time_sec * fps)
+# # if label not in obs_groups:
+# # obs_groups[label] = []
+# # obs_groups[label].append(frame)
- # # Convert groups of frames into (Start, End) blocks
- # for label, frames in obs_groups.items():
- # track_name = f"OBS: {label}"
- # processed_blocks = []
+# # # Convert groups of frames into (Start, End) blocks
+# # for label, frames in obs_groups.items():
+# # track_name = f"OBS: {label}"
+# # processed_blocks = []
- # # Step by 2 to create start/end pairs
- # for i in range(0, len(frames) - 1, 2):
- # start_f = frames[i]
- # end_f = frames[i+1]
- # processed_blocks.append((start_f, end_f, "Moderate", "Manual"))
+# # # Step by 2 to create start/end pairs
+# # for i in range(0, len(frames) - 1, 2):
+# # start_f = frames[i]
+# # end_f = frames[i+1]
+# # processed_blocks.append((start_f, end_f, "Moderate", "Manual"))
- # # Register the track globally if it's new
- # if track_name not in TRACK_NAMES:
- # TRACK_NAMES.append(track_name)
- # TRACK_COLORS.append(QColor("#AA00FF")) # Purple for Observations
+# # # Register the track globally if it's new
+# # if track_name not in TRACK_NAMES:
+# # TRACK_NAMES.append(track_name)
+# # TRACK_COLORS.append(QColor("#AA00FF")) # Purple for Observations
- # timeline_events[track_name] = processed_blocks
+# # timeline_events[track_name] = processed_blocks
- # except Exception as e:
- # print(f"Error parsing JSON Observations: {e}")
+# # except Exception as e:
+# # print(f"Error parsing JSON Observations: {e}")
- try:
- with open(json_path, 'r') as f:
- full_json = json.load(f)
+# try:
+# with open(json_path, 'r') as f:
+# full_json = json.load(f)
- raw_obs_events = full_json["observations"][subkey]["events"]
- raw_obs_events.sort(key=lambda x: x[0])
+# raw_obs_events = full_json["observations"][subkey]["events"]
+# raw_obs_events.sort(key=lambda x: x[0])
- # NEW LOGIC: Use a dictionary to store frames for specific track names
- # track_name -> [list of frames]
- obs_groups = {}
+# # NEW LOGIC: Use a dictionary to store frames for specific track names
+# # track_name -> [list of frames]
+# obs_groups = {}
- for ev in raw_obs_events:
- # ev structure: [time_sec, unknown, label, special]
- time_sec, label, special = ev[0], ev[2], ev[3]
- frame = int(time_sec * fps)
+# for ev in raw_obs_events:
+# # ev structure: [time_sec, unknown, label, special]
+# time_sec, label, special = ev[0], ev[2], ev[3]
+# frame = int(time_sec * fps)
- # Determine which tracks this event belongs to
- target_tracks = []
+# # Determine which tracks this event belongs to
+# target_tracks = []
- if special == "Left":
- target_tracks.append(f"OBS: {label} (Left)")
- elif special == "Right":
- target_tracks.append(f"OBS: {label} (Right)")
- elif special == "Both":
- target_tracks.append(f"OBS: {label} (Left)")
- target_tracks.append(f"OBS: {label} (Right)")
- else:
- # No special or unrecognized value
- target_tracks.append(f"OBS: {label}")
+# if special == "Left":
+# target_tracks.append(f"OBS: {label} (Left)")
+# elif special == "Right":
+# target_tracks.append(f"OBS: {label} (Right)")
+# elif special == "Both":
+# target_tracks.append(f"OBS: {label} (Left)")
+# target_tracks.append(f"OBS: {label} (Right)")
+# else:
+# # No special or unrecognized value
+# target_tracks.append(f"OBS: {label}")
- # Add the frame to all applicable tracks
- for t_name in target_tracks:
- if t_name not in obs_groups:
- obs_groups[t_name] = []
- obs_groups[t_name].append(frame)
+# # Add the frame to all applicable tracks
+# for t_name in target_tracks:
+# if t_name not in obs_groups:
+# obs_groups[t_name] = []
+# obs_groups[t_name].append(frame)
- # Convert frame groups into (Start, End) blocks
- for track_name, frames in obs_groups.items():
- processed_blocks = []
+# # Convert frame groups into (Start, End) blocks
+# for track_name, frames in obs_groups.items():
+# processed_blocks = []
- # Step by 2 to create start/end pairs (ensures matching pairs per track)
+# # Step by 2 to create start/end pairs (ensures matching pairs per track)
- if "Sync" in track_name and len(frames) == 1:
- start_f = frames[0]
- end_f = start_f + 1 # Give it a visible width on the timeline
- processed_blocks.append((start_f, end_f, "Moderate", "Manual"))
+# if "Sync" in track_name and len(frames) == 1:
+# start_f = frames[0]
+# end_f = start_f + 1 # Give it a visible width on the timeline
+# processed_blocks.append((start_f, end_f, "Moderate", "Manual"))
- else:
- for i in range(0, len(frames) - 1, 2):
- start_f = frames[i]
- end_f = frames[i+1]
- processed_blocks.append((start_f, end_f, "Moderate", "Manual"))
+# else:
+# for i in range(0, len(frames) - 1, 2):
+# start_f = frames[i]
+# end_f = frames[i+1]
+# processed_blocks.append((start_f, end_f, "Moderate", "Manual"))
- # Register the track in global lists if not already there
- if track_name not in TRACK_NAMES:
- TRACK_NAMES.append(track_name)
- # Using Purple for Observations
- TRACK_COLORS.append(QColor("#AA00FF"))
+# # Register the track in global lists if not already there
+# if track_name not in TRACK_NAMES:
+# TRACK_NAMES.append(track_name)
+# # Using Purple for Observations
+# TRACK_COLORS.append(QColor("#AA00FF"))
- timeline_events[track_name] = processed_blocks
+# timeline_events[track_name] = processed_blocks
- except Exception as e:
- print(f"Error parsing JSON Observations: {e}")
+# except Exception as e:
+# print(f"Error parsing JSON Observations: {e}")
- def _run_existing_ml_models(self, z_kps, dirs, raw_kpts):
- debug_print()
- """
- Scans for trained models and generates timeline tracks for each.
- """
- ai_events = {}
+# def _run_existing_ml_models(self, z_kps, dirs, raw_kpts):
+# debug_print()
+# """
+# Scans for trained models and generates timeline tracks for each.
+# """
+# ai_events = {}
- # 1. Match the pattern from your GeneralPredictor: {Target}_rf.pkl
- model_files = glob.glob("*_rf.pkl")
- print(f"DEBUG: Found model files: {model_files}")
+# # 1. Match the pattern from your GeneralPredictor: {Target}_rf.pkl
+# model_files = glob.glob("*_rf.pkl")
+# print(f"DEBUG: Found model files: {model_files}")
- for model_path in model_files:
- try:
- # Extract Target (e.g., "Mouthing" from "Mouthing_rf.pkl")
- base_name = model_path.split("_rf.pkl")[0]
- target = base_name.replace("ml_", "", 1)
- track_name = f"AI: {target}"
+# for model_path in model_files:
+# try:
+# # Extract Target (e.g., "Mouthing" from "Mouthing_rf.pkl")
+# base_name = model_path.split("_rf.pkl")[0]
+# target = base_name.replace("ml_", "", 1)
+# track_name = f"AI: {target}"
- self.progress.emit(f"Loading AI Observations for {target}...")
+# self.progress.emit(f"Loading AI Observations for {target}...")
- # 2. Match the Scaler naming from calculate_and_train:
- # {target}_random_forest_scaler.pkl
- scaler_path = f"{base_name}_rf_scaler.pkl"
+# # 2. Match the Scaler naming from calculate_and_train:
+# # {target}_random_forest_scaler.pkl
+# scaler_path = f"{base_name}_rf_scaler.pkl"
- if not os.path.exists(scaler_path):
- print(f"DEBUG: Skipping {target}, scaler not found at {scaler_path}")
- continue
+# if not os.path.exists(scaler_path):
+# print(f"DEBUG: Skipping {target}, scaler not found at {scaler_path}")
+# continue
- # Load assets
- model = joblib.load(model_path)
- scaler = joblib.load(scaler_path)
+# # Load assets
+# model = joblib.load(model_path)
+# scaler = joblib.load(scaler_path)
- # 3. Feature extraction (On-the-fly)
- all_features = []
- # We must set the predictor's target so format_features uses the correct ACTIVITY_MAP
- self.predictor.current_target = target
+# # 3. Feature extraction (On-the-fly)
+# all_features = []
+# # We must set the predictor's target so format_features uses the correct ACTIVITY_MAP
+# self.predictor.current_target = target
- for f_idx in range(len(z_kps)):
- feat = self.predictor.format_features(z_kps[f_idx], dirs[f_idx], raw_kpts[f_idx])
- all_features.append(feat)
+# for f_idx in range(len(z_kps)):
+# feat = self.predictor.format_features(z_kps[f_idx], dirs[f_idx], raw_kpts[f_idx])
+# all_features.append(feat)
- # 4. Inference
- X = np.array(all_features)
- X_scaled = scaler.transform(X)
- predictions = model.predict(X_scaled)
+# # 4. Inference
+# X = np.array(all_features)
+# X_scaled = scaler.transform(X)
+# predictions = model.predict(X_scaled)
- # 5. Convert binary 0/1 to blocks
- processed_blocks = []
- start_f = None
+# # 5. Convert binary 0/1 to blocks
+# processed_blocks = []
+# start_f = None
- for f_idx, val in enumerate(predictions):
- if val == 1 and start_f is None:
- start_f = f_idx
- elif val == 0 and start_f is not None:
- # [start, end, severity, direction]
- processed_blocks.append((start_f, f_idx - 1, "Large", "AI"))
- start_f = None
+# for f_idx, val in enumerate(predictions):
+# if val == 1 and start_f is None:
+# start_f = f_idx
+# elif val == 0 and start_f is not None:
+# # [start, end, severity, direction]
+# processed_blocks.append((start_f, f_idx - 1, "Large", "AI"))
+# start_f = None
- if start_f is not None:
- processed_blocks.append((start_f, len(predictions)-1, "Large", "AI"))
+# if start_f is not None:
+# processed_blocks.append((start_f, len(predictions)-1, "Large", "AI"))
- # 6. Global Registration
- if track_name not in TRACK_NAMES:
- TRACK_NAMES.append(track_name)
- # Ensure TRACK_COLORS has an entry for this new track
- TRACK_COLORS.append(QColor("#00FF00"))
+# # 6. Global Registration
+# if track_name not in TRACK_NAMES:
+# TRACK_NAMES.append(track_name)
+# # Ensure TRACK_COLORS has an entry for this new track
+# TRACK_COLORS.append(QColor("#00FF00"))
- ai_events[track_name] = processed_blocks
- print(f"DEBUG: Successfully generated {len(processed_blocks)} blocks for {track_name}")
+# ai_events[track_name] = processed_blocks
+# print(f"DEBUG: Successfully generated {len(processed_blocks)} blocks for {track_name}")
- except Exception as e:
- print(f"Inference Error for {model_path}: {e}")
+# except Exception as e:
+# print(f"Inference Error for {model_path}: {e}")
- return ai_events
+# return ai_events
- def classify_delta(self, z):
- # debug_print()
- z_abs = abs(z)
- if z_abs < 1: return "Rest"
- elif z_abs < 2: return "Small"
- elif z_abs < 3: return "Moderate"
- else: return "Large"
+# def classify_delta(self, z):
+# # debug_print()
+# z_abs = abs(z)
+# if z_abs < 1: return "Rest"
+# elif z_abs < 2: return "Small"
+# elif z_abs < 3: return "Moderate"
+# else: return "Large"
- def _save_pose_cache(self, path, data):
- """
- Saves the raw YOLO keypoints and confidence scores to a CSV.
- Each row represents one frame, flattened from (17, 3) to (51,).
- """
- try:
- with open(path, 'w', newline='') as f:
- writer = csv.writer(f)
+# def _save_pose_cache(self, path, data):
+# """
+# Saves the raw YOLO keypoints and confidence scores to a CSV.
+# Each row represents one frame, flattened from (17, 3) to (51,).
+# """
+# try:
+# with open(path, 'w', newline='') as f:
+# writer = csv.writer(f)
- # Create the descriptive header
- header = []
- for joint in JOINT_NAMES:
- # Replace spaces with underscores for better compatibility with other tools
- header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"])
+# # Create the descriptive header
+# header = []
+# for joint in JOINT_NAMES:
+# # Replace spaces with underscores for better compatibility with other tools
+# header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"])
- writer.writerow(header)
+# writer.writerow(header)
- # Write the frame data
- for frame_data in data:
- # frame_data is (17, 3), flatten to (51,)
- writer.writerow(frame_data.flatten())
+# # Write the frame data
+# for frame_data in data:
+# # frame_data is (17, 3), flatten to (51,)
+# writer.writerow(frame_data.flatten())
- print(f"DEBUG: Pose cache saved with joint headers at {path}")
- except Exception as e:
- print(f"ERROR: Could not save pose cache: {e}")
+# print(f"DEBUG: Pose cache saved with joint headers at {path}")
+# except Exception as e:
+# print(f"ERROR: Could not save pose cache: {e}")
- def run(self):
- debug_print()
- # --- PHASE 1: VIDEO SETUP & POSE EXTRACTION ---
- cap = cv2.VideoCapture(self.video_path)
- fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+# def run(self):
+# debug_print()
+# # --- PHASE 1: VIDEO SETUP & POSE EXTRACTION ---
+# cap = cv2.VideoCapture(self.video_path)
+# fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+# width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+# height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+# total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
- raw_kps_per_frame = []
- csv_storage_data = []
- valid_mask = []
- pose_cache = self.video_path.rsplit('.', 1)[0] + "_pose_raw.csv"
+# raw_kps_per_frame = []
+# csv_storage_data = []
+# valid_mask = []
+# pose_cache = self.video_path.rsplit('.', 1)[0] + "_pose_raw.csv"
- if os.path.exists(pose_cache):
- self.progress.emit("Loading cached kinematic data...")
- with open(pose_cache, 'r') as f:
- reader = csv.reader(f)
- next(reader)
- for row in reader:
- full_data = np.array([float(x) for x in row]).reshape(17, 3)
- kp = full_data[:, :2]
- raw_kps_per_frame.append(kp)
- csv_storage_data.append(full_data)
- valid_mask.append(np.any(kp))
- else:
- self.progress.emit("Detecting poses with YOLO...")
- model = YOLO("yolov8n-pose.pt")
- prev_track_id = None
- for i in range(total_frames):
- ret, frame = cap.read()
- if not ret: break
- results = model.track(frame, persist=True, verbose=False)
- track_id, kp, confs, _ = self.get_best_infant_match(results, width, height, prev_track_id)
- if kp is not None:
- prev_track_id = track_id
- raw_kps_per_frame.append(kp)
- csv_storage_data.append(np.column_stack((kp, confs)))
- valid_mask.append(True)
- else:
- raw_kps_per_frame.append(np.zeros((17, 2)))
- csv_storage_data.append(np.zeros((17, 3)))
- valid_mask.append(False)
- if i % 50 == 0: self.progress.emit(f"YOLO: {int((i/total_frames)*100)}%")
- self._save_pose_cache(pose_cache, csv_storage_data)
+# if os.path.exists(pose_cache):
+# self.progress.emit("Loading cached kinematic data...")
+# with open(pose_cache, 'r') as f:
+# reader = csv.reader(f)
+# next(reader)
+# for row in reader:
+# full_data = np.array([float(x) for x in row]).reshape(17, 3)
+# kp = full_data[:, :2]
+# raw_kps_per_frame.append(kp)
+# csv_storage_data.append(full_data)
+# valid_mask.append(np.any(kp))
+# else:
+# self.progress.emit("Detecting poses with YOLO...")
+# model = YOLO("yolov8n-pose.pt")
+# prev_track_id = None
+# for i in range(total_frames):
+# ret, frame = cap.read()
+# if not ret: break
+# results = model.track(frame, persist=True, verbose=False)
+# track_id, kp, confs, _ = self.get_best_infant_match(results, width, height, prev_track_id)
+# if kp is not None:
+# prev_track_id = track_id
+# raw_kps_per_frame.append(kp)
+# csv_storage_data.append(np.column_stack((kp, confs)))
+# valid_mask.append(True)
+# else:
+# raw_kps_per_frame.append(np.zeros((17, 2)))
+# csv_storage_data.append(np.zeros((17, 3)))
+# valid_mask.append(False)
+# if i % 50 == 0: self.progress.emit(f"YOLO: {int((i/total_frames)*100)}%")
+# self._save_pose_cache(pose_cache, csv_storage_data)
- cap.release()
- actual_len = len(raw_kps_per_frame)
+# cap.release()
+# actual_len = len(raw_kps_per_frame)
- flattened_rows = []
- for frame_array in csv_storage_data:
- # frame_array is (17, 3) -> flatten to (51,)
- flattened_rows.append(frame_array.flatten())
+# flattened_rows = []
+# for frame_array in csv_storage_data:
+# # frame_array is (17, 3) -> flatten to (51,)
+# flattened_rows.append(frame_array.flatten())
- columns = []
- for name in JOINT_NAMES:
- columns.extend([f"{name}_x", f"{name}_y", f"{name}_conf"])
+# columns = []
+# for name in JOINT_NAMES:
+# columns.extend([f"{name}_x", f"{name}_y", f"{name}_conf"])
- # Store this so the Inspector can access it instantly in memory
- self.pose_df = pd.DataFrame(flattened_rows, columns=columns)
+# # Store this so the Inspector can access it instantly in memory
+# self.pose_df = pd.DataFrame(flattened_rows, columns=columns)
- # --- PHASE 2: KINEMATICS & Z-SCORES ---
- self.progress.emit("Calculating Kinematics...")
- analysis_kpts = []
- for kp in raw_kps_per_frame:
- pelvis = (kp[11] + kp[12]) / 2
- analysis_kpts.append(kp - pelvis)
+# # --- PHASE 2: KINEMATICS & Z-SCORES ---
+# self.progress.emit("Calculating Kinematics...")
+# analysis_kpts = []
+# for kp in raw_kps_per_frame:
+# pelvis = (kp[11] + kp[12]) / 2
+# analysis_kpts.append(kp - pelvis)
- valid_data = [analysis_kpts[i] for i, v in enumerate(valid_mask) if v]
- if valid_data:
- stacked = np.stack(valid_data)
- baseline_mean = np.mean(stacked, axis=0)
- baseline_std = np.std(np.linalg.norm(stacked - baseline_mean, axis=2), axis=0) + 1e-6
- else:
- baseline_mean, baseline_std = np.zeros((17, 2)), np.ones(17)
+# valid_data = [analysis_kpts[i] for i, v in enumerate(valid_mask) if v]
+# if valid_data:
+# stacked = np.stack(valid_data)
+# baseline_mean = np.mean(stacked, axis=0)
+# baseline_std = np.std(np.linalg.norm(stacked - baseline_mean, axis=2), axis=0) + 1e-6
+# else:
+# baseline_mean, baseline_std = np.zeros((17, 2)), np.ones(17)
- np_raw_kps = np.array(raw_kps_per_frame)
- np_z_kps = np.array([np.linalg.norm(kp - baseline_mean, axis=1) / baseline_std for kp in analysis_kpts])
+# np_raw_kps = np.array(raw_kps_per_frame)
+# np_z_kps = np.array([np.linalg.norm(kp - baseline_mean, axis=1) / baseline_std for kp in analysis_kpts])
- # Calculate directions (Assume you have a method for this or use a dummy for now)
- # Using placeholder empty strings to prevent errors in track generation
- np_dirs = np.full((actual_len, 17), "", dtype=object)
+# # Calculate directions (Assume you have a method for this or use a dummy for now)
+# # Using placeholder empty strings to prevent errors in track generation
+# np_dirs = np.full((actual_len, 17), "", dtype=object)
- # --- PHASE 3: TIMELINE GENERATION ---
- # Initialize dictionary with ALL global track names to prevent KeyErrors
- timeline_events = {name: [] for name in TRACK_NAMES}
+# # --- PHASE 3: TIMELINE GENERATION ---
+# # Initialize dictionary with ALL global track names to prevent KeyErrors
+# timeline_events = {name: [] for name in TRACK_NAMES}
- # 1. Kinematic Events (The joint tracks)
- for j_idx, joint_name in enumerate(JOINT_NAMES):
- current_block = None
- for f_idx in range(actual_len):
- severity = self.classify_delta(np_z_kps[f_idx, j_idx])
- if severity != "Rest":
- if current_block and current_block[2] == severity:
- current_block[1] = f_idx
- else:
- current_block = [f_idx, f_idx, severity, ""]
- timeline_events[joint_name].append(current_block)
- else:
- current_block = None
+# # 1. Kinematic Events (The joint tracks)
+# for j_idx, joint_name in enumerate(JOINT_NAMES):
+# current_block = None
+# for f_idx in range(actual_len):
+# severity = self.classify_delta(np_z_kps[f_idx, j_idx])
+# if severity != "Rest":
+# if current_block and current_block[2] == severity:
+# current_block[1] = f_idx
+# else:
+# current_block = [f_idx, f_idx, severity, ""]
+# timeline_events[joint_name].append(current_block)
+# else:
+# current_block = None
- # 2. JSON Observations
- self._merge_json_observations(timeline_events, fps)
+# # 2. JSON Observations
+# self._merge_json_observations(timeline_events, fps)
- # 3. AI Inferred Events
- ai_events = self._run_existing_ml_models(np_z_kps, np_dirs, np_raw_kps)
- timeline_events.update(ai_events)
+# # 3. AI Inferred Events
+# ai_events = self._run_existing_ml_models(np_z_kps, np_dirs, np_raw_kps)
+# timeline_events.update(ai_events)
- # --- PHASE 4: EMIT ---
- data = {
- "video_path": self.video_path,
- "fps": fps,
- "total_frames": actual_len,
- "width": width, "height": height,
- "events": timeline_events,
- "raw_kps": np_raw_kps,
- "z_kps": np_z_kps,
- "directions": np_dirs,
- "baseline_kp_mean": baseline_mean
- }
- self.progress.emit("Analysis Complete!")
- self.finished_data.emit(data)
+# # --- PHASE 4: EMIT ---
+# data = {
+# "video_path": self.video_path,
+# "fps": fps,
+# "total_frames": actual_len,
+# "width": width, "height": height,
+# "events": timeline_events,
+# "raw_kps": np_raw_kps,
+# "z_kps": np_z_kps,
+# "directions": np_dirs,
+# "baseline_kp_mean": baseline_mean
+# }
+# self.progress.emit("Analysis Complete!")
+# self.finished_data.emit(data)
@@ -826,110 +1105,68 @@ class PoseAnalyzerWorker(QThread):
# ==========================================
# TIMELINE WIDGET
# ==========================================
+import numpy as np
+from PySide6.QtWidgets import QWidget, QScrollArea
+from PySide6.QtCore import Qt, Signal, QRect, QRectF
+from PySide6.QtGui import QPainter, QPen, QColor, QFont, QBrush
+
+
class TimelineWidget(QWidget):
seek_requested = Signal(int)
visibility_changed = Signal(set)
track_selected = Signal(str)
- def __init__(self):
- debug_print()
- super().__init__()
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
self.data = None
+ self.track_names = []
+ self.track_colors = []
self.current_frame = 0
- self.zoom_factor = 1.0 # Pixels per frame
- self.label_width = 160 # Fixed gutter for track names
+ self.zoom_factor = 1.0
+ self.label_width = 160
self.track_height = 25
self.ruler_height = 20
- self.scrollbar_buffer = 2 # Extra space for the horizontal scrollbar
+ self.scrollbar_buffer = 2
self.hidden_tracks = set()
- self.sync_offset = 0.0
+ self.sync_offset = 0.0
self.sync_fps = 30.0
- # Calculate total required height
- self.total_content_height = (NUM_TRACKS * self.track_height) + self.ruler_height
+ self.is_scrubbing = False
+
+        # Placeholder extent (~15 tracks); set_data() recomputes the real height
+        self.total_content_height = (15 * self.track_height) + self.ruler_height
self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer)
- def set_sync_params(self, offset_seconds, fps=None):
- """
- Updates the temporal shift parameters and refreshes the UI.
- """
- debug_print()
- self.sync_offset = float(offset_seconds)
+
+ def set_data(self, events_dict, total_frames, fps):
+ """Expects grouped events from your BORIS loader."""
+ self.track_names = sorted(list(events_dict.keys()))
+ self.data = {
+ "events": events_dict,
+ "total_frames": total_frames,
+ "fps": fps
+ }
+ self.sync_fps = fps
- # Only update FPS if a valid value is provided,
- # otherwise keep the existing data/video FPS
- if fps and fps > 0:
- self.sync_fps = float(fps)
- elif self.data and "fps" in self.data:
- self.sync_fps = float(self.data["fps"])
-
- print(f"DEBUG: Timeline Sync Set - Offset: {self.sync_offset}s, FPS: {self.sync_fps}")
-
- # Trigger paintEvent to redraw the blocks in their new shifted positions
+ # Generate colors dynamically since we don't know the tracks ahead of time
+ self.track_colors = [QColor.fromHsl((i * 360 // max(1, len(self.track_names))), 160, 140)
+ for i in range(len(self.track_names))]
+        # Recompute the vertical extent now that the real track count is known
+        self.total_content_height = (len(self.track_names) * self.track_height) + self.ruler_height
+        self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer)
+        self.update_geometry()
self.update()
- def set_zoom(self, factor):
- debug_print()
- if not self.data: return
-
- # Calculate MIN zoom: The zoom required to make the video fit the width exactly
- # (Available Width - Sidebar) / Total Frames
- available_w = self.parent().width() - self.label_width if self.parent() else 800
- min_zoom = available_w / self.data["total_frames"]
-
- # Clamp: Don't zoom out past the video end, don't zoom in to infinity
- self.zoom_factor = max(min_zoom, min(factor, 50.0))
- self.update_geometry()
-
- def get_all_binary_labels(self, offset_seconds=0.0, fps=30.0):
- """
- Extracts binary labels for ALL tracks in self.data["events"].
- Returns a dict: {'OBS: Mouthing': [0,1,0...], 'OBS: Kicking': [0,0,1...]}
- """
- debug_print()
- all_labels = {}
-
- if not self.data or "events" not in self.data:
- return all_labels
-
- total_frames = self.data.get("total_frames", 0)
- if total_frames == 0:
- return all_labels
-
- frame_shift = int(offset_seconds * fps)
-
- for track_name in self.data["events"]:
- sequence = np.zeros(total_frames, dtype=int)
-
- for event in self.data["events"][track_name]:
- start_f = int(event[0]) - frame_shift
- end_f = int(event[1]) - frame_shift
-
- # Clamp values
- start_idx = max(0, min(start_f, total_frames - 1))
- end_idx = max(0, min(end_f, total_frames))
-
- if start_idx < end_idx:
- sequence[start_idx:end_idx] = 1
-
- all_labels[track_name] = sequence.tolist() # Convert to list for easier storage
-
- return all_labels
-
-
def update_geometry(self):
- debug_print()
-
if self.data:
# Width is sidebar + (frames * zoom)
total_w = self.label_width + int(self.data["total_frames"] * self.zoom_factor)
self.setFixedWidth(total_w)
self.update()
+
def wheelEvent(self, event):
- debug_print()
+ # debug_print()
if event.modifiers() == Qt.ControlModifier:
delta = event.angleDelta().y()
@@ -940,9 +1177,9 @@ class TimelineWidget(QWidget):
# Let the scroll area handle normal vertical scrolling
super().wheelEvent(event)
- # --- NEW: CTRL + Plus / Minus / Zero ---
+
def keyPressEvent(self, event):
- debug_print()
+        # debug_print()
if event.modifiers() == Qt.ControlModifier:
if event.key() == Qt.Key_Plus or event.key() == Qt.Key_Equal:
@@ -954,48 +1191,42 @@ class TimelineWidget(QWidget):
else:
super().keyPressEvent(event)
- def set_data(self, data):
- debug_print()
- self.data = data
- self.total_content_height = (len(TRACK_NAMES) * self.track_height) + self.ruler_height
- self.setMinimumHeight(self.total_content_height + self.scrollbar_buffer)
+    def set_zoom(self, factor):
+        if not self.data:
+            return
+ # Calculate MIN zoom: The zoom required to make the video fit the width exactly
+ # (Available Width - Sidebar) / Total Frames
+ available_w = self.parent().width() - self.label_width if self.parent() else 800
+ min_zoom = available_w / self.data["total_frames"]
+
+ # Clamp: Don't zoom out past the video end, don't zoom in to infinity
+ self.zoom_factor = max(min_zoom, min(factor, 50.0))
self.update_geometry()
+
def set_playhead(self, frame):
- debug_print()
old_x = self.label_width + (self.current_frame * self.zoom_factor)
self.current_frame = frame
new_x = self.label_width + (self.current_frame * self.zoom_factor)
- self.ensure_playhead_visible()
- self.update(int(old_x - 5), 0, 10, self.height())
+
+ # Repaint only the playhead areas for performance
+ self.update(int(old_x - 5), 0, 10, self.height())
self.update(int(new_x - 5), 0, 10, self.height())
+ self.ensure_playhead_visible()
+
def ensure_playhead_visible(self):
- debug_print()
-
- """Auto-scrolls the scroll area if the playhead leaves the viewport."""
- # Find the QScrollArea parent
scroll_area = self.parent().parent()
if not isinstance(scroll_area, QScrollArea): return
-
+
scrollbar = scroll_area.horizontalScrollBar()
view_width = scroll_area.viewport().width()
-
- # Playhead position in pixels
px = self.label_width + int(self.current_frame * self.zoom_factor)
-
- # Current scroll position
scroll_x = scrollbar.value()
-
- # If playhead is beyond the right edge of visible area
- if px > (scroll_x + view_width):
- # Shift scroll so playhead is at the left (plus sidebar)
- scrollbar.setValue(px - self.label_width)
-
- # If playhead is behind the left edge (e.g. user seeked backwards)
- elif px < (scroll_x + self.label_width):
- scrollbar.setValue(px - self.label_width)
+
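+        # If the playhead leaves the viewport on either side, scroll so it sits
+        # a quarter-viewport in from the sticky sidebar, keeping context visible.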
+ if px > (scroll_x + view_width) or px < (scroll_x + self.label_width):
+ scrollbar.setValue(px - self.label_width - (view_width // 4))
+
+
def mousePressEvent(self, event):
debug_print()
@@ -1017,8 +1248,8 @@ class TimelineWidget(QWidget):
if pos_x < scroll_x + self.label_width:
relative_y = pos_y - self.ruler_height
track_idx = int(relative_y // self.track_height)
- if 0 <= track_idx < len(TRACK_NAMES):
- name = TRACK_NAMES[track_idx]
+ if 0 <= track_idx < len(self.track_names):
+ name = self.track_names[track_idx]
if name in self.hidden_tracks: self.hidden_tracks.remove(name)
else: self.hidden_tracks.add(name)
self.visibility_changed.emit(self.hidden_tracks)
@@ -1032,8 +1263,8 @@ class TimelineWidget(QWidget):
# Handle track selection if in the data area
if pos_y >= self.ruler_height:
track_idx = int((pos_y - self.ruler_height) // self.track_height)
- if 0 <= track_idx < len(TRACK_NAMES):
- self.track_selected.emit(TRACK_NAMES[track_idx])
+ if 0 <= track_idx < len(self.track_names):
+ self.track_selected.emit(self.track_names[track_idx])
self.selected_track_idx = track_idx
self.update()
else:
@@ -1043,164 +1274,115 @@ class TimelineWidget(QWidget):
self.update()
def mouseMoveEvent(self, event):
- debug_print()
-
- # This only fires while moving if a button is held down by default
if self.is_scrubbing:
self.update_frame_from_mouse(event.position().x())
def mouseReleaseEvent(self, event):
- debug_print()
-
if event.button() == Qt.LeftButton:
self.is_scrubbing = False
- def update_frame_from_mouse(self, x_pos):
- """Helper to calculate frame and emit the seek signal."""
- debug_print()
- relative_x = x_pos - self.label_width
- frame = int(relative_x / self.zoom_factor)
+ def update_frame_from_mouse(self, x):
+ rel_x = x - self.label_width
+ frame = int(rel_x / self.zoom_factor)
frame = max(0, min(frame, self.data["total_frames"] - 1))
-
- # We emit seek_requested so the Video Player and Premiere class sync up
self.seek_requested.emit(frame)
-
+
+
def paintEvent(self, event):
- debug_print()
if not self.data: return
-
- dirty_rect = event.rect()
+
painter = QPainter(self)
+ dirty_rect = event.rect()
- # 1. Determine current scroll position to keep labels sticky
+ # 1. Coordinate Setup
scroll_area = self.parent().parent()
scroll_x = 0
if isinstance(scroll_area, QScrollArea):
scroll_x = scroll_area.horizontalScrollBar().value()
-
w, h = self.width(), self.height()
total_f = self.data["total_frames"]
fps = self.data.get("fps", 30)
offset_y = 20
- # 2. DRAW DATA AREA (Events and Playhead)
- # --- 2. DRAW DATA AREA (Muted Patterns + Events + Playhead) ---
- sync_off = getattr(self, "sync_offset", 0.0)
- sync_fps = getattr(self, "sync_fps", fps)
- frame_shift = int(sync_off * sync_fps)
- for i, name in enumerate(TRACK_NAMES):
+ # --- 2. DRAW DATA AREA (Events and Playhead) ---
+ for i, name in enumerate(self.track_names):
y = offset_y + (i * self.track_height)
is_hidden = name in self.hidden_tracks
-
+
if y + self.track_height < dirty_rect.top() or y > dirty_rect.bottom():
continue
- # A. Draw Muted/Disabled Background Pattern
-
- if is_hidden:
- # Calculate the visible rectangle for this track to the right of the sidebar
- mute_rect = QRectF(self.label_width, y, w - self.label_width, self.track_height)
- # Fill with a dark "disabled" base
- painter.fillRect(mute_rect, QColor(40, 40, 40))
- # Add the Cross-Hatch Pattern
- pattern_brush = QBrush(QColor(60, 60, 60, 100), Qt.DiagCrossPattern)
- painter.fillRect(mute_rect, pattern_brush)
+ # Event Blocks
+ if name in self.data["events"]:
+ base_color = self.track_colors[i]
+ for event_item in self.data["events"][name]:
+ # Map the new data format: [start_f, end_f, type, value]
+ start_f, end_f = event_item[0], event_item[1]
+
+ if "AI:" in name:
+ s_start, s_end = start_f, end_f
+ else:
+ s_start = start_f - 0
+ s_end = end_f - 0
+
+ x_start = self.label_width + (s_start * self.zoom_factor)
+ x_end = self.label_width + (s_end * self.zoom_factor)
- # B. Draw Event Blocks
- base_color = TRACK_COLORS[i]
- for start_f, end_f, severity, direction in self.data["events"][name]:
- # x_start = self.label_width + (start_f * self.zoom_factor)
- # x_end = self.label_width + (end_f * self.zoom_factor)
+ # Performance optimization: skip drawing if off-screen
+ if x_end < scroll_x or x_start > scroll_x + w:
+ continue
- if "AI:" in name:
- # AI predictions are already calculated in video-time, NO SHIFT
- shifted_start, shifted_end = start_f, end_f
-
- else:
- shifted_start = start_f - frame_shift
- shifted_end = end_f - frame_shift
-
- x_start = self.label_width + (shifted_start * self.zoom_factor)
- x_end = self.label_width + (shifted_end * self.zoom_factor)
-
- # Performance optimization: skip drawing if off-screen
- if x_end < scroll_x or x_start > scroll_x + w:
- continue
-
- if x_end < dirty_rect.left() or x_start > dirty_rect.right():
- continue
-
- # If hidden, make the event block very faint/transparent
- if is_hidden:
- color = QColor(120, 120, 120, 40) # Muted Grey
- else:
- alpha = 80 if severity == "Small" else 160 if severity == "Moderate" else 255
+ # Draw block
color = QColor(base_color)
- color.setAlpha(alpha)
+ if is_hidden:
+ color = QColor(120, 120, 120, 40)
+
+ painter.fillRect(QRectF(x_start, y + 2, max(1, x_end - x_start), self.track_height - 4), color)
- painter.fillRect(QRectF(x_start, y + 2, max(1, x_end - x_start), self.track_height - 4), color)
# Draw Playhead
playhead_x = self.label_width + (self.current_frame * self.zoom_factor)
painter.setPen(QPen(QColor(255, 0, 0), 2))
- painter.drawLine(playhead_x, 0, playhead_x, h)
+ painter.drawLine(int(playhead_x), 0, int(playhead_x), h)
- # 3. DRAW STICKY SIDEBAR (Pinned to the left edge of the viewport)
- # We draw this AFTER the data so it covers the blocks as they scroll past
+        # --- 3. DRAW STICKY SIDEBAR (Pinned to the left edge) ---
+        # Draw this AFTER the data so it masks anything scrolling under it
sidebar_rect = QRect(scroll_x, 0, self.label_width, h)
- painter.fillRect(sidebar_rect, QColor(30, 30, 30)) # Solid background
+ painter.fillRect(sidebar_rect, QColor(30, 30, 30))
# Ruler segment for the sidebar area
painter.fillRect(scroll_x, 0, self.label_width, offset_y, QColor(45, 45, 45))
- for i, name in enumerate(TRACK_NAMES):
+ for i, name in enumerate(self.track_names):
y = offset_y + (i * self.track_height)
is_hidden = name in self.hidden_tracks
- # Grid Line
+
+ # Pinned Grid Line
painter.setPen(QColor(60, 60, 60))
painter.drawLine(scroll_x, y, scroll_x + w, y)
- # Sticky Label Text
- if is_hidden:
- # Very dark grey to show it's "OFF"
- text_color = QColor(70, 70, 70)
- else:
- # Bright white/grey to show it's "ON"
- text_color = QColor(220, 220, 220)
-
+            # Pinned Label Text (Anchored to scroll_x)
+ text_color = QColor(70, 70, 70) if is_hidden else QColor(220, 220, 220)
painter.setPen(text_color)
painter.setFont(QFont("Arial", 8, QFont.Bold))
painter.drawText(scroll_x + 10, y + 17, name)
- # 4. DRAW TIME RULER TICKS (Right of the sticky sidebar)
+ # --- 4. DRAW TIME RULER TICKS ---
target_spacing_px = 120
-
- # Available units in frames: 1, 5, 15, 30 (1s), 150 (5s), 300 (10s), 1800 (1min)
possible_units = [1, 5, 15, 30, 150, 300, 900, 1800]
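+        # At 30 fps: 1 f, 5 f, 0.5 s, 1 s, 5 s, 10 s, 30 s, 1 min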
+ tick_interval = next((u for u in possible_units if (u * self.zoom_factor) >= target_spacing_px), 1800)
- # Find the smallest unit that results in at least target_spacing_px
- tick_interval = possible_units[-1]
- for unit in possible_units:
- if (unit * self.zoom_factor) >= target_spacing_px:
- tick_interval = unit
- break
-
- # 2. DRAW BACKGROUNDS
- painter.fillRect(0, 0, w, 20, QColor(45, 45, 45)) # Ruler Bar
+ # Draw Ruler Background
+ painter.fillRect(0, 0, w, 20, QColor(45, 45, 45))
- # 3. DRAW TICKS AND TIME LABELS
painter.setPen(QColor(180, 180, 180))
painter.setFont(QFont("Segoe UI", 7))
-
- # Sub-ticks (draw 5 small lines for every 1 major interval)
sub_interval = max(1, tick_interval // 5)
- # Start loop from 0 to total frames
for f in range(0, total_f + 1, sub_interval):
x = self.label_width + int(f * self.zoom_factor)
- # Optimization: Don't draw if off-screen
if x < scroll_x: continue
if x > scroll_x + w: break
@@ -1220,83 +1402,83 @@ class TimelineWidget(QWidget):
time_str = f"{minutes:02d}m:{seconds:02d}s"
else:
time_str = f"{seconds}s"
-
painter.drawText(x + 4, 12, time_str)
else:
- # Minor Tick
painter.drawLine(x, 16, x, 20)
- def get_ai_extractions(self):
- """
- Processes timeline data for AI tracks and specific OBS sync events.
- """
- debug_print()
- fps = self.data.get("fps", 30.0)
-
- extraction_data = {
- "metadata": {
- "fps": fps,
- "total_frames": self.data.get("total_frames", 0),
- "track_summaries": {}
- },
- "obs": {}, # New top-level key for specific OBS events
- "ai_tracks": {}
- }
+ self.update_geometry()
+
+
+class TrainingWorker(QThread):
+ # Signals to communicate back to the UI
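+    # NOTE: this custom 'finished' shadows the built-in QThread.finished signal;
+    # it carries the report payload, and deleteLater is wired to it further down.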
+ finished = Signal(str) # Sends the HTML report back
+ error = Signal(str) # Sends error messages
+
+ def __init__(self, params):
+ super().__init__()
+ self.params = params
+
+ def run(self):
+ try:
+ from predictor import GeneralPredictor
+ predictor = GeneralPredictor()
+ # This is the heavy calculation and training
+ report = predictor.calculate_and_train(self.params)
+ self.finished.emit(report)
+ except Exception as e:
+ self.error.emit(str(e))
+
+
+class MLInferenceWorker(QThread):
+ finished = Signal(dict) # Emits the timeline_events dictionary
+ error = Signal(str)
+
+ def __init__(self, raw_kpts, ml_model, ml_scaler, active_features, behavior_name):
+ super().__init__()
+ self.raw_kpts = raw_kpts
+ self.ml_model = ml_model
+ self.ml_scaler = ml_scaler
+ self.active_features = active_features
+ self.behavior_name = f"AI: {behavior_name}"
+
+ def run(self):
+ try:
+ # Import predictor logic inside the thread
+ from predictor import GeneralPredictor
+ engine = GeneralPredictor()
+ engine.active_feature_keys = self.active_features
+
+ # 1. Feature Extraction (The slow part)
+ X_raw = []
+ for frame in self.raw_kpts:
+ X_raw.append(engine.format_features(frame))
+ X = np.array(X_raw)
+
+ # 2. Scaling & Prediction
+ if self.ml_scaler:
+ X = self.ml_scaler.transform(X)
+
+ preds = self.ml_model.predict(X)
+
+            # 3. Convert frame-level predictions into contiguous timeline blocks.
+            # NOTE: assumes GeneralPredictor exposes convert_to_events(); the
+            # equivalent logic lives in VideoAnalysisTab._convert_predictions_to_tracks.
+            events = engine.convert_to_events(preds, track_name=self.behavior_name)
+
+ self.finished.emit(events)
+ except Exception as e:
+ self.error.emit(str(e))
- if not self.data or "events" not in self.data:
- return extraction_data
- # 1. Extract the OBS: Time Sync event specifically
- sync_key = "OBS: Time Sync"
- if sync_key in self.data["events"]:
- sync_blocks = self.data["events"][sync_key]
- # Convert blocks to a list of dicts for the JSON
- extraction_data["obs"][sync_key] = [
- {
- "start_frame": b[0],
- "end_frame": b[1],
- "start_time_sec": round(b[0] / fps, 3),
- "end_time_sec": round(b[1] / fps, 3)
- } for b in sync_blocks
- ]
- # 2. Process AI Tracks
- for track_name, blocks in self.data["events"].items():
- if track_name.startswith("AI:"):
- track_results = []
- track_total = 0
- track_long = 0
-
- for block in blocks:
- start_f, end_f = block[0], block[1]
- severity = block[2] if len(block) > 2 else "Normal"
-
- start_sec = round(start_f / fps, 3)
- end_sec = round(end_f / fps, 3)
- duration = round(end_sec - start_sec, 3)
-
- track_total += 1
- if duration > 2.0:
- track_long += 1
-
- track_results.append({
- "start_frame": int(start_f),
- "end_frame": int(end_f),
- "start_time_sec": start_sec,
- "end_time_sec": end_sec,
- "duration_sec": duration,
- "severity": severity
- })
-
- extraction_data["ai_tracks"][track_name] = track_results
- extraction_data["metadata"]["track_summaries"][track_name] = {
- "event_count": track_total,
- "long_events_over_2s": track_long,
- "total_duration_sec": round(sum(r["duration_sec"] for r in track_results), 3)
- }
- return extraction_data
class SkeletonOverlay(QWidget):
@@ -1388,7 +1570,9 @@ class SkeletonOverlay(QWidget):
# --- 2. DRAW BASELINE (Only if not hidden) ---
if "Baseline" not in self.hidden_tracks:
idx_l_hip, idx_r_hip = self.KP_MAP["Lhip"], self.KP_MAP["Rhip"]
- pelvis_live = (kp_live[idx_l_hip] + kp_live[idx_r_hip]) / 2
+ p_left = kp_live[idx_l_hip][:2]
+ p_right = kp_live[idx_r_hip][:2]
+ pelvis_live = (p_left + p_right) / 2
kp_baseline = self.data['baseline_kp_mean'] + pelvis_live
painter.setPen(QPen(QColor(200, 200, 200, 200), 2, Qt.DashLine))
@@ -1456,6 +1640,19 @@ class SkeletonOverlay(QWidget):
painter.drawEllipse(pt, 4, 4)
+
+
class VideoView(QGraphicsView):
resized = Signal()
@@ -1501,18 +1698,60 @@ class PremiereWindow(QMainWindow):
QMainWindow, QWidget#centralWidget { background-color: #1e1e1e; }
QLabel, QStatusBar, QMenuBar { color: #ffffff; }
QTabWidget::pane { border: 1px solid #333333; background: #1e1e1e; }
- QTabBar::tab { background: #2b2b2b; color: #aaa; padding: 8px 15px; border: 1px solid #333; border-bottom: none; border-top-left-radius: 4px; border-top-right-radius: 4px; }
- QTabBar::tab:selected { background: #3d3d3d; color: #fff; font-weight: bold; }
- QTabBar::tab:hover { background: #444; }
- """)
+ QTabBar::tab {
+ background: #2b2b2b;
+ color: #aaa;
+ padding: 8px 15px;
+ border: 1px solid #333;
+ border-bottom: none;
+ border-top-left-radius: 4px;
+ border-top-right-radius: 4px;
+ }
+
+ /* 2. Content Tabs: Selected & Hover States */
+ /* We add :!last to ensure these NEVER apply to the + tab */
+ QTabBar::tab:selected:!last {
+ background: #3d3d3d;
+ color: #fff;
+ font-weight: bold;
+ }
+
+ QTabBar::tab:hover:!last {
+ background: #444;
+ }
+
+ /* 3. THE PLUS TAB: Constant State */
+ /* We define all states (normal, selected, hover) to be identical */
+ QTabBar::tab:last,
+ QTabBar::tab:last:hover,
+ QTabBar::tab:last:selected {
+ background: #1e1e1e;
+ color: #00aaff;
+ font-weight: bold;
+ margin-left: 2px;
+ border: 1px solid #333;
+ padding: 4px 12px;
+ }""")
# --- Tab System ---
self.tabs = QTabWidget()
self.tabs.setTabsClosable(True)
self.tabs.tabCloseRequested.connect(self.close_tab)
+
self.setCentralWidget(self.tabs)
self.create_welcome_tab()
+
+ self.tabs.addTab(QWidget(), "+")
+
+ # 2. Disable the close button on the "+" tab specifically
+ # (Assuming index 0 was your first tab, the + is now at the last index)
+ plus_idx = self.tabs.count() - 1
+ self.tabs.tabBar().setTabButton(plus_idx, QTabBar.ButtonPosition.RightSide, None)
+
+ # 3. Connect to the click event
+ self.tabs.tabBar().installEventFilter(self)
+
self.create_menu_bar()
# Update checks
@@ -1523,9 +1762,46 @@ class PremiereWindow(QMainWindow):
# Window instances
self.load_window = None
+ self.train_window = None
self.about = None
self.help = None
+
+ def eventFilter(self, obj, event):
+ # Check if the event is a mouse press on the TabBar
+ if obj == self.tabs.tabBar() and event.type() == QEvent.MouseButtonPress:
+ # Map the click position to which tab index was hit
+ index = self.tabs.tabBar().tabAt(event.pos())
+
+ # If they clicked the "+" tab
+ if self.tabs.tabText(index) == "+":
+ # Show the menu
+ self.show_new_tab_menu()
+ # Return True to CONSUME the event.
+ # This prevents QTabWidget from ever seeing the click and switching tabs.
+ return True
+
+ return super().eventFilter(obj, event)
+
+
+ def show_new_tab_menu(self):
+ from PySide6.QtWidgets import QMenu
+ from PySide6.QtGui import QAction, QCursor
+ menu = QMenu(self)
+
+ load_act = QAction("Load Video", self)
+ load_act.triggered.connect(self.open_load_video_dialog)
+
+ train_act = QAction("Train Model", self)
+ train_act.triggered.connect(self.open_train_model_dialog)
+
+ menu.addAction(load_act)
+ menu.addAction(train_act)
+
+ # Show the menu right under the mouse cursor
+ menu.exec(QCursor.pos())
+
+
def create_welcome_tab(self):
welcome_widget = QWidget()
layout = QVBoxLayout(welcome_widget)
@@ -1544,8 +1820,7 @@ class PremiereWindow(QMainWindow):
layout.addStretch()
self.tabs.addTab(welcome_widget, "Welcome")
- # Disable the close button on the welcome tab
- self.tabs.tabBar().setTabButton(0, QTabBar.ButtonPosition.RightSide, None)
+
def create_menu_bar(self):
menu_bar = self.menuBar()
@@ -1590,22 +1865,80 @@ class PremiereWindow(QMainWindow):
self.load_window.btn_confirm.clicked.connect(self.handle_new_video_session)
self.load_window.show()
- def handle_new_video_session(self):
- # Extract properties from the OpenFileWindow before closing it
- video_path = self.load_window.video_path
- obs_file = self.load_window.obs_file
- offset = self.load_window.current_video_offset
- # ... grab any other needed parameters (like selected ML model, etc.)
+ def open_train_model_dialog(self):
+ if self.train_window is None or not self.train_window.isVisible():
+ self.train_window = TrainModelWindow(self)
+ # Connect the initialization button from OpenFileWindow to our tab creator
+ self.train_window.btn_train.clicked.connect(self.handle_start_training)
+ self.train_window.show()
+
+
+ def handle_start_training(self):
+ params = self.train_window.get_selection()
+ # 1. Use QProgressDialog instead of QMessageBox
+ # It's designed to stay on top and handle background tasks
+ self.loading_dialog = QProgressDialog("Processing data and training Random Forest...", None, 0, 0, self)
+ self.loading_dialog.setWindowTitle("Training Model")
+ self.loading_dialog.setWindowModality(Qt.WindowModal)
+ self.loading_dialog.setCancelButton(None) # Remove cancel button to prevent interruption
+ self.loading_dialog.setMinimumDuration(0) # Show immediately
+ self.loading_dialog.show()
+
+ # 2. Setup the Worker Thread
+ self.training_thread = TrainingWorker(params)
+ self.training_thread.finished.connect(self.on_training_finished)
+ self.training_thread.error.connect(self.on_training_error)
+
+ # Clean up the thread object when it's done to prevent memory leaks
+ self.training_thread.finished.connect(self.training_thread.deleteLater)
+
+ # 3. Start thread
+ self.training_thread.start()
+
+ # Close the selection window
+ self.train_window.close()
+
+ def on_training_finished(self, report_html):
+ # Using reset() on QProgressDialog automatically closes it and cleans up
+ if self.loading_dialog:
+ self.loading_dialog.reset()
+ self.loading_dialog = None
+
+ self.display_ml_results(report_html)
+
+ def on_training_error(self, error_msg):
+ if self.loading_dialog:
+ self.loading_dialog.reset()
+ self.loading_dialog = None
+
+ QMessageBox.critical(self, "Training Error", f"An error occurred: {error_msg}")
+
+ def display_ml_results(self, report):
+ """Displays the RF performance report in a simple popup."""
+ msg = QMessageBox(self)
+ msg.setWindowTitle("Training Results")
+ msg.setTextFormat(Qt.RichText)
+ msg.setText(report)
+ msg.exec()
+
+ def handle_new_video_session(self):
+ config = self.load_window.get_config()
+
+ # 2. Close the selection window
self.load_window.close()
- # Create a new, independent tab
- new_tab = VideoAnalysisTab(video_path, obs_file, offset)
+ # 3. Create the new tab with the config dictionary
+ # We pass the config so the tab knows whether to run AI or just play video
+ new_tab = VideoAnalysisTab(config)
- # Add to TabWidget and switch to it
- tab_name = os.path.basename(video_path)
- index = self.tabs.addTab(new_tab, tab_name)
- self.tabs.setCurrentIndex(index)
+ # 4. Handle Tab Placement (Keep '+' at the end)
+ tab_name = os.path.basename(config['video_path'])
+ plus_idx = self.tabs.count() - 1
+ new_idx = self.tabs.insertTab(plus_idx, new_tab, tab_name)
+
+ # 5. Switch to it
+ self.tabs.setCurrentIndex(new_idx)
def close_tab(self, index):
# Prevent closing the Welcome tab if it's the only one left
@@ -1645,26 +1978,72 @@ class PremiereWindow(QMainWindow):
-from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QSplitter,
- QScrollArea, QPushButton, QComboBox, QLabel,
- QSlider, QTextEdit, QMessageBox)
-from PySide6.QtCore import Qt, QSizeF, QUrl, Slot
-from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput
-from PySide6.QtMultimediaWidgets import QGraphicsVideoItem
-
class VideoAnalysisTab(QWidget):
- def __init__(self, video_path, obs_file=None, offset=0.0, parent=None):
- super().__init__(parent)
+ def __init__(self, config):
+ super().__init__()
# State
- self.file_path = video_path
- self.obs_file = obs_file
- self.current_video_offset = offset
- self.predictor = GeneralPredictor()
- self.data = None # Will be populated by worker
+ self.config = config
+
+ self.setStyleSheet("""
+ QMainWindow, QWidget#centralWidget {
+ background-color: #1e1e1e;
+ }
+ QLabel, QStatusBar, QMenuBar {
+ color: #ffffff;
+ }
+ /* Target the Timeline specifically */
+ TimelineWidget {
+ background-color: #1e1e1e;
+ border: 1px solid #333333;
+ }
+ /* Button styling with Grey borders */
+ QDialog, QMessageBox, QFileDialog {
+ background-color: #2b2b2b;
+ }
+ QDialog QLabel, QMessageBox QLabel {
+ color: #ffffff;
+ }
+ QPushButton {
+ background-color: #2b2b2b;
+ color: #ffffff;
+ border: 1px solid #555555; /* Subtle Grey border */
+ border-radius: 3px;
+ padding: 4px;
+ }
+ QPushButton:hover {
+ background-color: #3d3d3d;
+ border-color: #888888; /* Brightens border on hover */
+ }
+ QPushButton:pressed {
+ background-color: #111111;
+ }
+ QPushButton:disabled {
+ border-color: #333333;
+ color: #444444;
+ }
+ /* Splitter/Divider styling */
+ QSplitter::handle {
+ background-color: #333333; /* Dark grey dividers */
+ }
+ QSplitter::handle:horizontal {
+ width: 2px;
+ }
+ QSplitter::handle:vertical {
+ height: 2px;
+ }
+ /* ScrollArea styling to keep it dark */
+ QScrollArea, QScrollArea > QWidget > QWidget {
+ background-color: #1e1e1e;
+ border: none;
+ }
+ """)
+
self.setup_ui()
- #self.reprocess_current_video()
+
+ self.initialize_session()
+
def setup_ui(self):
main_layout = QVBoxLayout(self)
@@ -1685,7 +2064,7 @@ class VideoAnalysisTab(QWidget):
self.scene.addItem(self.video_item)
# Overlay initialization
- self.skeleton_overlay = SkeletonOverlay(self.view.viewport())
+ #self.skeleton_overlay = SkeletonOverlay(self.view.viewport())
self.player = QMediaPlayer()
self.audio_output = QAudioOutput()
@@ -1699,48 +2078,6 @@ class VideoAnalysisTab(QWidget):
controls_container = QWidget()
stacked_controls = QVBoxLayout(controls_container)
stacked_controls.setSpacing(5) # Tight spacing between rows
-
- ml_row = QHBoxLayout()
- ml_row.addStretch()
-
- ml_row.addWidget(QLabel("ML Model:"))
- self.ml_dropdown = QComboBox()
- self.ml_dropdown.addItems(["Random Forest", "LSTM", "XGBoost", "SVM", "1D-CNN"])
- ml_row.addWidget(self.ml_dropdown)
-
- ml_row.addWidget(QLabel("Target:"))
- self.target_dropdown = QComboBox()
- self.target_dropdown.addItems(["Mouthing", "Head Movement", "Kick (Left)", "Kick (Right)", "Reach (Left)", "Reach (Right)"])
- #self.target_dropdown.currentTextChanged.connect(self.update_predictor_target)
- ml_row.addWidget(self.target_dropdown)
-
- self.btn_add_to_pool = QPushButton("Add to Pool")
- #self.btn_add_to_pool.clicked.connect(self.add_current_to_ml_pool)
- self.btn_add_to_pool.setFixedWidth(120)
- ml_row.addWidget(self.btn_add_to_pool)
-
- self.btn_train_final = QPushButton("Train Global Model")
- self.btn_train_final.setStyleSheet("background-color: #2e7d32; font-weight: bold;")
- #self.btn_train_final.clicked.connect(self.run_final_training)
- ml_row.addWidget(self.btn_train_final)
-
- self.lbl_pool_status = QLabel("Pool: 0 Participants")
- self.lbl_pool_status.setStyleSheet("color: #00FF00; font-weight: bold; margin-left: 10px;")
- self.lbl_pool_status.setFixedWidth(160)
- ml_row.addWidget(self.lbl_pool_status)
-
- self.btn_clear_pool = QPushButton("Clear Pool")
- self.btn_clear_pool.setFixedWidth(100)
- self.btn_clear_pool.setStyleSheet("color: #ff5555; border: 1px solid #ff5555;")
- #self.btn_clear_pool.clicked.connect(self.clear_ml_pool)
- ml_row.addWidget(self.btn_clear_pool)
-
- self.btn_extract_ai = QPushButton("Extract AI Data")
- #self.btn_extract_ai.clicked.connect(self.extract_ai_to_json)
- ml_row.addWidget(self.btn_extract_ai)
-
-
- ml_row.addStretch()
# --- ROW 2: Playback & Transport ---
playback_row = QHBoxLayout()
@@ -1791,26 +2128,347 @@ class VideoAnalysisTab(QWidget):
playback_row.addStretch()
# --- Add Rows to Stack ---
- stacked_controls.addLayout(ml_row)
stacked_controls.addLayout(playback_row)
video_layout.addWidget(controls_container)
+
+ info_container = QWidget()
+ info_layout = QVBoxLayout(info_container)
+
+ self.progress_container = QWidget()
+ progress_layout = QVBoxLayout(self.progress_container)
+
+ self.lbl_analysis_status = QLabel("Pose Analysis: Idle")
+ self.analysis_bar = QProgressBar()
+ self.analysis_bar.setRange(0, 100)
+ self.analysis_bar.setValue(0)
+ self.analysis_bar.setStyleSheet("""
+ QProgressBar { border: 1px solid #555; border-radius: 2px; text-align: center; height: 15px; }
+ QProgressBar::chunk { background-color: #00aaff; }
+ """)
+
+ progress_layout.addWidget(self.lbl_analysis_status)
+ progress_layout.addWidget(self.analysis_bar)
+
+ # Insert into info_layout (above the inspector scroll area)
+ info_layout.insertWidget(0, self.progress_container)
+
+ # NEW: Wrap the info_label in a Scroll Area
+ self.inspector_scroll = QScrollArea()
+ self.inspector_scroll.setWidgetResizable(True)
- # --- Inspector & Timeline ---
self.info_label = QTextEdit()
self.info_label.setReadOnly(True)
+ self.info_label.setStyleSheet("padding: 8px; font-family: 'Consolas', 'Segoe UI'; color: #ffffff;")
+ self.inspector_scroll.setWidget(self.info_label)
+
+ # NEW: Export Button for Metrics
+ self.btn_export_metrics = QPushButton("Export Metrics to JSON")
+ self.btn_export_metrics.clicked.connect(self.export_behavior_metrics)
+ self.btn_export_metrics.setEnabled(False) # Enable only after load
+
+ info_layout.addWidget(self.inspector_scroll)
+ info_layout.addWidget(self.btn_export_metrics)
+
+ top_splitter.addWidget(video_container)
+ top_splitter.addWidget(info_container)
+ top_splitter.setSizes([800, 400])
self.timeline = TimelineWidget()
self.timeline.seek_requested.connect(self.seek_video)
-
- top_splitter.addWidget(video_container)
- top_splitter.addWidget(self.info_label)
+
+ scroll_area = QScrollArea()
+ scroll_area.setWidgetResizable(True)
+ scroll_area.setWidget(self.timeline)
self.main_splitter.addWidget(top_splitter)
- self.main_splitter.addWidget(self.timeline)
+ self.main_splitter.addWidget(scroll_area)
+ self.main_splitter.setSizes([500, 400])
main_layout.addWidget(self.main_splitter)
+
+
+ self.skeleton_overlay = SkeletonOverlay(self.view)
+ self.skeleton_overlay.resize(self.view.size())
+ self.skeleton_overlay.hide()
+
+ # 2. FIX: Watch the view for resizes
+ self.view.installEventFilter(self)
+
+ self.player.positionChanged.connect(self.update_timeline_playhead)
self.setup_transport() # Start with empty workspace until worker finishes
+ # self.load_boris_to_timeline()
+ self.video_item.nativeSizeChanged.connect(self.update_video_geometry)
+ self.start_analysis()
+
+
+
+ def start_analysis(self):
+ if not self.config.get("use_pose", True):
+ self.lbl_analysis_status.setText("Pose Analysis: Bypassed")
+ self.analysis_bar.setValue(100)
+ self.skeleton_overlay.hide()
+ return
+
+ # 1. Setup Queues
+ self.prog_q = Queue()
+ self.res_q = Queue()
+
+ # 2. Create the Process
+ self.analysis_proc = Process(
+ target=run_pose_analysis,
+ args=(self.config['video_path'], self.prog_q, self.res_q, self.config),
+ name="PoseWorkerProcess"
+ )
+
+ # 3. UI Updates
+ self.lbl_analysis_status.setText("Process Started...")
+ self.analysis_bar.setValue(0)
+ self.analysis_bar.show()
+
+ # 4. Start
+ self.analysis_proc.start()
+
+ # 5. Timer to check the Queue
+ self.poll_timer = QTimer()
+ self.poll_timer.timeout.connect(self.check_queues)
+ self.poll_timer.start(100)
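+        # Design note: a separate *process* (not a QThread) keeps the YOLO/
+        # OpenCV workload from stalling the GUI event loop, while the 100 ms
+        # QTimer poll keeps all Qt widget updates on the main thread.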
+
+
+ def check_queues(self):
+ # Drain progress queue
+ while not self.prog_q.empty():
+ val = self.prog_q.get()
+ self.analysis_bar.setValue(val)
+ self.lbl_analysis_status.setText(f"Extracting Poses: {val}%")
+
+ # Check result queue
+ if not self.res_q.empty():
+ data = self.res_q.get()
+ self.poll_timer.stop()
+ self.handle_finished_data(data)
+
+
+ def handle_finished_data(self, data):
+ """
+ Runs on the main thread. Loads BORIS instantly, shows the skeleton,
+ and kicks off AI in the background.
+ """
+ # 1. Unpack basic data
+ self.raw_kpts = data["raw_kpts"]
+ self.fps = data.get("fps", 30.0)
+ self.total_frames = data["total_frames"]
+ v_w, v_h = data["dims"]
+
+ # 2. Setup master dictionary
+ self.processed_data = {}
+
+ # 3. IMMEDIATE: Load BORIS (if available)
+ if self.config.get('use_boris'):
+ self.load_boris_to_timeline()
+
+ # 4. IMMEDIATE: Show Skeleton Overlay
+ # We calculate the baseline mean here because it's fast (Numpy)
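+        # Failed detections are stored as all-zero keypoint arrays by the
+        # pose worker, so np.any() drops them from the baseline here.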
+ raw_kps_per_frame = [frame[:, :2] for frame in self.raw_kpts]
+ valid_mask = [np.any(kp) for kp in raw_kps_per_frame]
+ valid_data = [raw_kps_per_frame[i] for i, v in enumerate(valid_mask) if v]
+
+ baseline_mean = np.mean(valid_data, axis=0) if valid_data else np.zeros((17, 2))
+
+ overlay_payload = {
+ "raw_kps": self.raw_kpts,
+ "width": v_w,
+ "height": v_h,
+ "events": self.processed_data, # Initially just BORIS
+ "baseline_kp_mean": baseline_mean
+ }
+ self.skeleton_overlay.set_data(overlay_payload)
+ self.skeleton_overlay.show()
+
+ # 5. SYNC UI (Show the manual timeline immediately)
+ self.sync_timeline_to_ui()
+
+ # 6. BACKGROUND: Start AI Inference Thread
+ if self.config.get('use_pkl') and hasattr(self, 'ml_model'):
+ self.lbl_analysis_status.setText("Running AI Inference...")
+
+ # Start the worker thread
+ self.ml_worker = MLInferenceWorker(
+ self.raw_kpts,
+ self.ml_model,
+ self.ml_scaler,
+ self.active_features,
+ self.ml_metadata.get('target_behavior', 'Reach')
+ )
+ self.ml_worker.finished.connect(self.on_ai_inference_complete)
+ self.ml_worker.error.connect(lambda e: print(f"AI ERROR: {e}"))
+ self.ml_worker.start()
+ else:
+ self.lbl_analysis_status.setText("Analysis Complete")
+
+
+
+ def on_ai_inference_complete(self, ai_events):
+ """Runs when the thread finishes. Merges AI into the existing UI."""
+ # 1. Merge AI tracks into the dictionary that already has BORIS
+ for track_name, blocks in ai_events.items():
+ self.processed_data[track_name] = blocks
+ print(f"AI Thread complete: Injected {track_name}")
+
+ # 2. Update the Timeline Widget visually
+ self.sync_timeline_to_ui()
+
+ # 3. Update the Skeleton Overlay specifically
+ # This ensures the AI blocks show up in the video player's seek bar/overlay
+ if hasattr(self, 'skeleton_overlay'):
+ # Fetch the old data dict, update the 'events' key, and push it back
+ updated_payload = self.skeleton_overlay.data.copy()
+ updated_payload['events'] = self.processed_data
+ self.skeleton_overlay.set_data(updated_payload)
+
+ self.lbl_analysis_status.setText("All Tracks (BORIS + AI) Loaded")
+
+
+
+ def load_boris_to_timeline(self):
+ """
+ Parses the JSON to identify unique behaviors and prepare the data.
+ """
+ if not self.config.get('use_boris', False) or not self.config.get('obs_file'):
+ return
+
+ try:
+ with open(self.config['obs_file'], 'r') as f:
+ data = json.load(f)
+
+ behav_conf = data.get("behaviors_conf", {})
+ type_lookup = {v['code']: v['type'] for v in behav_conf.values()}
+
+ session_key = self.config.get('session_key')
+ session = data.get("observations", {}).get(session_key, {})
+ events = session.get("events", [])
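+        # BORIS event rows are indexed positionally below:
+        # e[0] = timestamp in seconds, e[2] = behavior code
+        # (e.g. [12.48, "", "Kick (Left)"] -- illustrative values)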
+
+ unique_behaviors = sorted(list(set(e[2] for e in events)))
+ new_boris_data = {name: [] for name in unique_behaviors}
+
+ fps = self.config.get('fps', 30.0)
+ # Use the offset from the config
+ offset_frames = int(self.config.get('offset', 0.0) * fps)
+
+ state_tracker = {}
+
+ for e in events:
+ timestamp = float(e[0])
+ behavior_name = e[2]
+ event_type = type_lookup.get(behavior_name, "Point event")
+ current_frame = int(timestamp * fps) - offset_frames
+
+ if event_type == "Point event":
+ new_boris_data[behavior_name].append([current_frame, current_frame + 1, "Normal", "N/A"])
+ elif event_type == "State event":
+ if behavior_name in state_tracker:
+ start_f = state_tracker.pop(behavior_name)
+ new_boris_data[behavior_name].append([start_f, current_frame, "Normal", "N/A"])
+ else:
+ state_tracker[behavior_name] = current_frame
+
+ if not hasattr(self, 'processed_data') or self.processed_data is None:
+ self.processed_data = {}
+
+ for behavior, blocks in new_boris_data.items():
+ self.processed_data[behavior] = blocks
+
+ print(f"[INFO] Merged {len(new_boris_data)} BORIS tracks into timeline.")
+
+ except Exception as e:
+ print(f"Error loading BORIS data: {e}")
+
+
+ def sync_timeline_to_ui(self):
+ """Final push of all merged data (BORIS + AI) to the UI components."""
+ if not hasattr(self, 'processed_data'):
+ return
+
+ total_f = self.config.get('total_frames', self.total_frames)
+ fps = self.config.get('fps', self.fps)
+
+ # 1. Update the Timeline Widget
+ self.timeline.set_data(self.processed_data, total_f, fps)
+
+ # 2. Update Stats & Export button
+ self.update_inspector_stats(self.processed_data, fps)
+ self.btn_export_metrics.setEnabled(True)
+
+ print(f"DEBUG: UI Synced with {len(self.processed_data)} total tracks.")
+
+
+ def update_inspector_stats(self, data, fps):
+ """Calculates and displays behavior summary on the right panel."""
+        stats_text = "SESSION SUMMARY<br>"
+        stats_text += f"File: {self.config.get('session_key', 'N/A')}<br>"
+        stats_text += "-----------------------------------<br>"
+
+        for behavior, instances in data.items():
+            durations = [(end - start) / fps for start, end, _, _ in instances]
+            count = len(instances)
+            avg_dur = np.mean(durations) if durations else 0
+            total_dur = sum(durations)
+
+            stats_text += f"{behavior}:<br>"
+            stats_text += f"  - Occurrences: {count}<br>"
+            stats_text += f"  - Avg Duration: {avg_dur:.3f}s<br>"
+            stats_text += f"  - Total Time: {total_dur:.2f}s<br>"
+
+ self.info_label.setHtml(stats_text)
+
+
+ def export_behavior_metrics(self):
+ """Exports the processed behavior metrics to a new JSON file."""
+ if not self.processed_data:
+ return
+
+ export_payload = {
+ "metadata": {
+ "source_video": os.path.basename(self.config.get('video_path', 'unknown')),
+ "session": self.config.get('session_key', 'unknown'),
+ "export_timestamp": datetime.now().isoformat(),
+ "fps": self.config.get('fps', 30.0)
+ },
+ "behaviors": {}
+ }
+
+ for behavior, instances in self.processed_data.items():
+ behavior_events = []
+
+ for start_f, end_f, _, _ in instances:
+ behavior_events.append({
+ "start_frame": int(start_f),
+ "duration_frames": int(end_f - start_f)
+ })
+
+ export_payload["behaviors"][behavior] = behavior_events
+
+ video_path = self.config.get('video_path')
+ if not video_path:
+ return
+
+ base_path = video_path.rsplit('.', 1)[0]
+ export_path = f"{base_path}_metrics.json"
+
+ # 3. SILENT SAVE
+ try:
+ with open(export_path, 'w') as f:
+ json.dump(export_payload, f, indent=4)
+
+ # Log the success so the researcher knows where it went
+            msg = f"<br>[EXPORT COMPLETE]: {os.path.basename(export_path)}"
+ self.info_label.append(msg)
+ print(f"Metrics saved to: {export_path}")
+
+ except Exception as e:
+            self.info_label.append(f"<br>[EXPORT FAILED]: {e}")
+
+
def setup_transport(self):
"""Sets up player controls that don't depend on skeleton analysis."""
@@ -1829,29 +2487,170 @@ class VideoAnalysisTab(QWidget):
self.btn_prev.clicked.connect(lambda: self.step_frame(-1))
self.btn_next.clicked.connect(lambda: self.step_frame(1))
- self.player.setSource(QUrl.fromLocalFile(self.file_path))
+ self.player.setSource(QUrl.fromLocalFile(self.config['video_path']))
+ self.player.pause()
+ # self.player.mediaStatusChanged.connect(self.initial_resize_hack)
+
+ # def initial_resize_hack(self, status):
+ # # Once the media is loaded, refresh the layout
+ # if status >= QMediaPlayer.MediaStatus.LoadedMedia:
+ # self.update_video_geometry()
+ # # Seek to 0 just to be absolutely sure the buffer updates
+ # self.player.setPosition(0)
+
+
+ def initialize_session(self):
+        print("--- Initializing Session Components ---")
- @Slot(dict)
- def setup_workspace(self, analyzed_data):
- """Only handles skeleton/analysis-specific data."""
- self.data = analyzed_data
+ # Component A: Pose Inference (Skeleton extraction)
+        if self.config.get('use_pose', False):
+            # Already handled by self.start_analysis() in __init__
+            pass
+
+ # Component B: BORIS Annotation Track
+ if self.config.get('use_boris', False):
+ print("[INFO] Loading BORIS annotation track...")
+ self.load_boris_to_timeline()
- # Update timeline and overlay now that we have data
- if hasattr(self, 'timeline'):
- self.timeline.set_data(self.data)
+ # Component C: ML Prediction Track (.pkl)
+ if self.config.get('use_pkl', False):
+ print(f"[INFO] Loading ML Model: {os.path.basename(self.config['pkl_path'])}")
+ self.load_pretrained_classifier()
+
+ def load_pretrained_classifier(self):
+ """Loads the .pkl model and automatically hunts for its scaler."""
+ if not self.config.get('use_pkl') or not self.config.get('pkl_path'):
+ return
+
+ model_path = self.config['pkl_path']
+
+ metadata_path = model_path.replace(".pkl", "_metadata.json")
+ if os.path.exists(metadata_path):
+ with open(metadata_path, 'r') as f:
+ self.ml_metadata = json.load(f)
+ self.active_features = self.ml_metadata.get("feature_keys", [])
+ print(f"[INFO] Feature map loaded: {len(self.active_features)} features.")
+ else:
+            raise FileNotFoundError(f"Required metadata file not found: {metadata_path}")
- if hasattr(self, 'skeleton_overlay'):
- self.skeleton_overlay.set_data(self.data)
+ try:
+ # 1. Load the primary model
+ self.ml_model = joblib.load(model_path)
+ msg = f"[INFO] ML Model loaded: {os.path.basename(model_path)}"
+ print(msg)
+ self.update_status(f"{msg}")
+
+ # 2. Auto-discover the Scaler
+ base_name = os.path.splitext(model_path)[0]
+ possible_scaler_paths = [
+ f"{base_name}_scaler.pkl",
+ os.path.join(os.path.dirname(model_path), "scaler.pkl")
+ ]
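+            # e.g. for "rf.pkl" this tries "rf_scaler.pkl" first, then a
+            # shared "scaler.pkl" sitting next to the model file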
+
+ self.ml_scaler = None
+ for spath in possible_scaler_paths:
+ if os.path.exists(spath):
+ self.ml_scaler = joblib.load(spath)
+ s_msg = f"[INFO] Associated scaler auto-loaded: {os.path.basename(spath)}"
+ print(s_msg)
+ self.update_status(f"{s_msg}")
+ break
+
+ if not self.ml_scaler:
+ print("[WARNING] No associated scaler found. Proceeding without scaling.")
+
+ except Exception as e:
+ err = f"[ERROR] Failed to load ML Model or Scaler: {e}"
+ print(err)
+ self.update_status(f"{err}")
- # Trigger a geometry refresh now that we have accurate video dims from data
- self.update_video_geometry()
- print(f"Analysis complete for {self.file_path}")
+
+ def run_ml_inference(self, raw_kpts):
+ """
+ Applies scaler, runs inference, and converts frame-by-frame
+ predictions into contiguous timeline blocks.
+ """
+ if not hasattr(self, 'ml_model') or not self.active_features:
+ return {}
+
+ # 1. Create the engine and tell it WHICH features to care about
+ engine = GeneralPredictor()
+ engine.active_feature_keys = self.active_features
+
+ # 2. Extract 13 features for every frame
+ X_raw = []
+ for frame_idx in range(len(raw_kpts)):
+ # format_features now returns only the 13 needed values
+ feat_vector = engine.format_features(raw_kpts[frame_idx])
+ X_raw.append(feat_vector)
+
+ X = np.array(X_raw) # Resulting shape: (Frames, 13)
+
+ # 3. Predict
+ if self.ml_scaler:
+ X = self.ml_scaler.transform(X)
+
+ preds = self.ml_model.predict(X)
+ unique, counts = np.unique(preds, return_counts=True)
+ print(f"DEBUG ML Results: {dict(zip(unique, counts))}")
+ return self._convert_predictions_to_tracks(preds)
+
+
+
+ def _convert_predictions_to_tracks(self, predictions):
+ """Converts an array of class labels into start/stop timeline blocks."""
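+        # Example: predictions [0, 0, "Kick", "Kick", 0] yield
+        # {"🤖 AI: Kick": [[2, 4, "Normal", "ML Prediction"]]}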
+ events = {}
+ current_class = None
+ start_frame = 0
+
+ # Define labels that mean "nothing is happening"
+ background_labels = [0, "0", "Idle", "None", None, "Background"]
+
+ for i, pred_class in enumerate(predictions):
+ if pred_class != current_class:
+ # Close the previous active block
+ if current_class not in background_labels:
+ track_name = f"🤖 AI: {current_class}"
+ if track_name not in events:
+ events[track_name] = []
+ # Format: [start_frame, end_frame, label, notes]
+ events[track_name].append([start_frame, i, "Normal", "ML Prediction"])
+
+ # Start new block
+ current_class = pred_class
+ start_frame = i
+
+ # Close the final block if the video ends while an action is active
+ if current_class not in background_labels:
+ track_name = f"🤖 AI: {current_class}"
+ if track_name not in events:
+ events[track_name] = []
+ events[track_name].append([start_frame, len(predictions), "Normal", "ML Prediction"])
+
+ return events
+
+
+ def load_boris_annotations(self):
+ """Logic to parse the JSON for the specific session/slot."""
+ try:
+ with open(self.config['obs_file'], 'r') as f:
+ data = json.load(f)
+
+ session = data.get("observations", {}).get(self.config['session_key'], {})
+ # Extract events for the specific slot
+ events = session.get("events", [])
+ # ... Filter events where slot matches config['slot'] ...
+ print(f"Loaded {len(events)} events from BORIS.")
+ except Exception as e:
+ print(f"Failed to load BORIS data: {e}")
+
def update_status(self, message):
"""Updates the inspector or a status bar with worker progress."""
self.info_label.append(message)
+
def toggle_playback(self):
if self.player.playbackState() == QMediaPlayer.PlayingState:
self.player.pause()
@@ -1860,12 +2659,64 @@ class VideoAnalysisTab(QWidget):
self.player.play()
self.btn_play.setText("Pause")
- def seek_video(self, ms):
+
+ def update_timeline_playhead(self, position_ms):
+ #debug_print()
+ fps = self.config.get('fps', 30.0)
+ total_f = self.config.get('total_frames', 0)
+
+ # Current frame calculation
+ current_f = int((position_ms / 1000.0) * fps)
+
+ # --- PREVENT BLACK FRAME AT END ---
+ # If we are within 1 frame of the end, stop and lock to the last valid frame
+ if hasattr(self, 'skeleton_overlay') and self.skeleton_overlay.isVisible():
+ self.skeleton_overlay.set_frame(current_f)
+
+ if current_f >= total_f - 1:
+ if self.player.playbackState() == QMediaPlayer.PlayingState:
+ self.player.pause()
+ self.btn_play.setText("Play")
+ current_f = total_f - 1
+ # Seek slightly back from total duration to keep the image visible
+ last_valid_ms = int(((total_f - 1) / fps) * 1000)
+ self.player.setPosition(last_valid_ms)
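+                # e.g. 900 frames @ 30 fps -> hold at frame 899 = 29,966 ms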
+
+ # Sync UI
+ self.timeline.set_playhead(current_f)
+ self.update_counters(current_f)
+
+ def seek_video(self, frame):
+ # Use the config or timeline data instead of self.data
+ fps = self.config.get('fps', 30.0)
+ total_f = self.config.get('total_frames', 0)
+
+ target_frame = max(0, min(frame, total_f - 1))
+
+ # Convert frame to milliseconds for QMediaPlayer
+ ms = int((target_frame / fps) * 1000)
self.player.setPosition(ms)
+ # Sync the UI
+ self.timeline.set_playhead(target_frame)
+ self.update_counters(target_frame)
+
+
+ def update_counters(self, current_f):
+ #debug_print()
+
+ # Dedicated method to refresh the labels
+ fps = self.config.get('fps', 30.0)
+ total_f = self.config.get('total_frames', 0)
+
+ cur_s, tot_s = int(current_f / fps), int(total_f / fps)
+ self.lbl_time_counter.setText(f"Time: {cur_s//60:02d}:{cur_s%60:02d} / {tot_s//60:02d}:{tot_s%60:02d}")
+ self.lbl_frame_counter.setText(f"Frame: {current_f} / {total_f-1}")
+
+
def step_frame(self, delta):
# Fallback to 30 FPS if worker data isn't ready
- fps = self.data["fps"] if (hasattr(self, 'data') and self.data) else 30.0
+ fps = self.config.get('fps', 30.0)
current_ms = self.player.position()
# One frame in ms = 1000 / fps
@@ -1876,6 +2727,7 @@ class VideoAnalysisTab(QWidget):
target_ms = max(0, min(target_ms, self.player.duration()))
self.player.setPosition(target_ms)
+
def update_volume(self, value):
# QAudioOutput expects a float between 0.0 and 1.0
volume = value / 100.0
@@ -1886,6 +2738,7 @@ class VideoAnalysisTab(QWidget):
self.btn_mute.setChecked(False)
self.toggle_mute()
+
def toggle_mute(self):
is_muted = self.btn_mute.isChecked()
self.audio_output.setMuted(is_muted)
@@ -1893,6 +2746,7 @@ class VideoAnalysisTab(QWidget):
# Optional: Dim the slider when muted
self.sld_volume.setEnabled(not is_muted)
+
def update_video_geometry(self):
if not hasattr(self, "video_item"):
return
@@ -1935,13 +2789,21 @@ class VideoAnalysisTab(QWidget):
if hasattr(self, "skeleton_overlay"):
self.skeleton_overlay.setGeometry(int(x_off), int(y_off), target_w, target_h)
+
def resizeEvent(self, event):
# debug_print()
super().resizeEvent(event)
self.update_video_geometry()
if hasattr(self, 'timeline'):
- self.timeline.set_zoom(self.timeline.zoom_factor)
+ self.timeline.update_geometry()
+
+    def eventFilter(self, source, event):
+        """Keeps the skeleton aligned with the video frame size."""
+        # The filter is installed on self.view in setup_ui, so match that source
+        if source is self.view and event.type() == QEvent.Resize:
+            self.skeleton_overlay.resize(event.size())
+        return super().eventFilter(source, event)
+
def cleanup(self):
if self.player:
@@ -2065,6 +2927,132 @@ if __name__ == "__main__":
# Only run GUI in the main process
if current_process().name == 'MainProcess':
app = QApplication(sys.argv)
+
+ style = """
+
+ /* 1. General App Backgrounds */
+ QMainWindow, QWidget#centralWidget, QDialog, QMessageBox, QFileDialog {
+ background-color: #1e1e1e;
+ color: #ffffff;
+ }
+
+ QLabel, QStatusBar, QMenuBar {
+ color: #ffffff;
+ }
+
+ /* 2. THE BIG FIX: Removing white backgrounds from inputs */
+ QListWidget, QComboBox, QLineEdit, QSpinBox, QTextEdit {
+ background-color: #2b2b2b;
+ color: #ffffff;
+ border: 1px solid #555555;
+ border-radius: 3px;
+ padding: 2px;
+ }
+
+ /* Fix for the dropdown list of a QComboBox */
+ QComboBox QAbstractItemView {
+ background-color: #2b2b2b;
+ color: #ffffff;
+ selection-background-color: #00aaff;
+ selection-color: #ffffff;
+ outline: none;
+ border: 1px solid #555555;
+ }
+
+ /* 3. Tab Navigation Styling */
+ QTabWidget::pane {
+ border: 1px solid #333333;
+ background: #1e1e1e;
+ }
+
+ QTabBar::tab {
+ background: #2b2b2b;
+ color: #aaa;
+ padding: 8px 15px;
+ border: 1px solid #333;
+ border-bottom: none;
+ border-top-left-radius: 4px;
+ border-top-right-radius: 4px;
+ }
+
+ QTabBar::tab:selected:!last {
+ background: #3d3d3d;
+ color: #fff;
+ font-weight: bold;
+ }
+
+ QTabBar::tab:hover:!last {
+ background: #444;
+ }
+
+ /* THE PLUS TAB: Constant State */
+ QTabBar::tab:last, QTabBar::tab:last:hover, QTabBar::tab:last:selected {
+ background: #1e1e1e;
+ color: #00aaff;
+ font-weight: bold;
+ margin-left: 2px;
+ border: 1px solid #333;
+ padding: 4px 12px;
+ }
+
+ /* 4. Timeline & Custom Widgets */
+ TimelineWidget {
+ background-color: #1e1e1e;
+ border: 1px solid #333333;
+ }
+
+ /* 5. Buttons with Grey Borders */
+ QPushButton {
+ background-color: #2b2b2b;
+ color: #ffffff;
+ border: 1px solid #555555;
+ border-radius: 3px;
+ padding: 4px;
+ }
+ QPushButton:hover {
+ background-color: #3d3d3d;
+ border-color: #888888;
+ }
+ QPushButton:pressed {
+ background-color: #111111;
+ }
+ QPushButton:disabled {
+ border-color: #333333;
+ color: #444444;
+ }
+
+ /* 6. Layout Dividers */
+ QSplitter::handle {
+ background-color: #333333;
+ }
+ QSplitter::handle:horizontal { width: 2px; }
+ QSplitter::handle:vertical { height: 2px; }
+
+ /* 7. Scroll Areas */
+ QScrollArea, QScrollArea > QWidget > QWidget {
+ background-color: #1e1e1e;
+ border: none;
+ }
+ QWidget { background-color: #1e1e1e; color: #ffffff; font-family: 'Segoe UI'; }
+ QGroupBox {
+ border: 1px solid #3d3d3d; border-radius: 8px; margin-top: 15px;
+ padding-top: 15px; font-weight: bold; color: #00aaff; text-transform: uppercase;
+ }
+ QLabel { color: #ffffff; font-weight: 500; }
+ QLabel:disabled { color: #444444; }
+ QLabel#Metadata { color: #00ffaa; font-family: 'Consolas'; font-size: 11px; }
+ QLabel#Preview { background-color: #000000; border: 2px solid #3d3d3d; }
+ QLabel#Warning { color: #ff5555; font-size: 11px; font-style: italic; font-weight: bold; }
+
+ QPushButton { background-color: #3d3d3d; border: 1px solid #555; padding: 6px; border-radius: 4px; }
+ QPushButton:hover { background-color: #00aaff; color: #000; }
+ QPushButton:disabled { color: #444; background-color: #252525; }
+
+ QComboBox { background-color: #2d2d2d; border: 1px solid #555; padding: 4px; border-radius: 4px; }
+ QComboBox:disabled { background-color: #222; color: #444; border: 1px solid #2a2a2a; }
+ """
+
+ app.setStyleSheet(style)
finish_update_if_needed(PLATFORM_NAME, APP_NAME)
window = PremiereWindow()
diff --git a/pose_worker.py b/pose_worker.py
new file mode 100644
index 0000000..6129ad0
--- /dev/null
+++ b/pose_worker.py
@@ -0,0 +1,154 @@
+import cv2
+import os
+import csv
+import numpy as np
+from ultralytics import YOLO
+from multiprocessing import current_process
+
+JOINT_NAMES = [
+ "nose", "l_eye", "r_eye", "l_ear", "r_ear", "l_shld", "r_shld",
+ "l_elbw", "r_elbw", "l_wri", "r_wri", "l_hip", "r_hip",
+ "l_knee", "r_knee", "l_ankl", "r_ankl"
+]
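+
+# Cached CSV layout: one row per frame, 51 columns
+# (nose_x, nose_y, nose_conf, l_eye_x, ..., r_ankl_conf)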
+
+def get_best_infant_match(results, w, h, prev_track_id):
+ """
+ Identifies the most likely infant based on visibility,
+ centrality, and tracking ID consistency.
+ """
+ if not results[0].boxes or results[0].boxes.id is None:
+ return None, None, None, None
+
+ ids = results[0].boxes.id.int().cpu().tolist()
+ kpts = results[0].keypoints.xy.cpu().numpy()
+ confs = results[0].keypoints.conf.cpu().numpy()
+
+ best_idx, best_score = -1, -1
+
+ for i, k in enumerate(kpts):
+ # visibility score
+ vis = np.sum(confs[i] > 0.5)
+ valid = k[confs[i] > 0.5]
+
+ # distance from center score
+ dist = np.linalg.norm(np.mean(valid, axis=0) - [w/2, h/2]) if len(valid) > 0 else 1000
+
+ # calculate total score
+ score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0)
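+        # e.g. 12 joints visible, 200 px from centre, same ID as last frame:
+        # score = 12*10 - 200*0.1 + 50 = 150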
+
+ if score > best_score:
+ best_score, best_idx = score, i
+
+ if best_idx == -1:
+ return None, None, None, None
+
+ return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx
+
+
+
+def run_pose_analysis(video_path, progress_queue, result_queue, config):
+ """Worker task executed in a separate Process."""
+ p_name = current_process().name
+ pose_cache = video_path.rsplit('.', 1)[0] + "_pose_raw.csv"
+ print(f"[{p_name}] Starting analysis on: {video_path}")
+ csv_storage_data = []
+ inference_performed = False
+
+ cap = cv2.VideoCapture(video_path)
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ cap.release()
+
+
+ if config.get("use_cache") and os.path.exists(pose_cache):
+        print(f"[{p_name}] Cache enabled; loading: {pose_cache}")
+ try:
+ with open(pose_cache, 'r') as f:
+ reader = csv.reader(f)
+ next(reader) # Skip header
+ for row in reader:
+ # Flattened (51,) back to (17, 3)
+ full_frame = np.array([float(x) for x in row]).reshape(17, 3)
+ csv_storage_data.append(full_frame)
+
+ progress_queue.put(100)
+ result_queue.put({
+ "raw_kpts": np.array(csv_storage_data),
+ "fps": fps,
+ "total_frames": len(csv_storage_data),
+ "dims": (width, height),
+ "status": "loaded_from_cache"
+ })
+ return # Exit early, no inference or saving needed
+ except Exception as e:
+ print(f"[{p_name}] Cache read failed, falling back to inference: {e}")
+ csv_storage_data = []
+
+
+
+ inference_performed = True
+ cap = cv2.VideoCapture(video_path)
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+ model_map = {
+ "YOLO8n-Pose": "yolov8n-pose.pt",
+ "YOLO8m-Pose": "yolov8m-pose.pt",
+ "Mediapipe BlazePose": "mediapipe"
+ }
+ model_file = model_map.get(config.get("pose_model"), "yolov8n-pose.pt")
+
+ print(f"[{p_name}] Running inference with model: {model_file}")
+ model = YOLO(model_file)
+
+ new_csv_storage_data = []
+ prev_track_id = None
+
+ for i in range(total_frames):
+ ret, frame = cap.read()
+ if not ret:
+ break
+
+ results = model.track(frame, persist=True, verbose=False)
+
+ track_id, kp, confs, _ = get_best_infant_match(results, width, height, prev_track_id)
+
+ if kp is not None:
+ prev_track_id = track_id
+ # Store as (17, 3) including confidence
+ new_csv_storage_data.append(np.column_stack((kp, confs)))
+ else:
+ new_csv_storage_data.append(np.zeros((17, 3)))
+
+ if i % 10 == 0:
+ progress_queue.put(int((i / total_frames) * 100))
+
+ cap.release()
+
+ if inference_performed:
+ print(f"[{p_name}] Saving new pose cache to {pose_cache}")
+ try:
+ with open(pose_cache, 'w', newline='') as f:
+ writer = csv.writer(f)
+ header = []
+ for joint in JOINT_NAMES:
+ header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"])
+ writer.writerow(header)
+ for frame_array in new_csv_storage_data:
+ writer.writerow(frame_array.flatten())
+ except Exception as e:
+ print(f"[{p_name}] Error saving cache: {e}")
+
+ # Return results through the queue
+ result_queue.put({
+ "raw_kpts": np.array(new_csv_storage_data),
+ "fps": fps,
+ "total_frames": len(new_csv_storage_data),
+ "dims": (width, height),
+ "status": "inference_complete"
+ })
+
+ print(f"[{p_name}] Analysis complete.")
\ No newline at end of file
diff --git a/predictor.py b/predictor.py
index bae5ced..216cd53 100644
--- a/predictor.py
+++ b/predictor.py
@@ -1,16 +1,9 @@
-"""
-Filename: predictor.py
-Description: BLAZES machine learning
-
-Author: Tyler de Zeeuw
-License: GPL-3.0
-"""
-
-# Built-in imports
import inspect
+import csv
+import os
+import json
from datetime import datetime
-# External library imports
import numpy as np
import joblib
import seaborn as sns
@@ -21,249 +14,202 @@ from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, precision_score, recall_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
-# To be used once multiple models are supported and functioning:
-# import torch
-# import torch.nn as nn
-# import torch.optim as optim
-# import xgboost as xgb
-# from sklearn.svm import SVC
-# import os
-
VERBOSITY = 1
-GEOMETRY_LIBRARY = {
- # --- Distances (Point A, Point B) ---
- "dist_l_wrist_nose": ("dist", [9, 0], True),
- "dist_r_wrist_nose": ("dist", [10, 0], True),
- "dist_l_ear_r_shld": ("dist", [3, 6], True),
- "dist_r_ear_l_shld": ("dist", [4, 5], True),
-
- "dist_l_wrist_pelvis": ("dist", [9, [11, 12]], True),
- "dist_r_wrist_pelvis": ("dist", [10, [11, 12]], True),
- "dist_l_ankl_pelvis": ("dist", [15, [11, 12]], True),
- "dist_r_ankl_pelvis": ("dist", [16, [11, 12]], True),
- "dist_nose_pelvis": ("dist", [0, [11, 12]], True),
- "dist_ankl_ankl": ("dist", [15, 16], True),
+def load_analysis_config(path="analysis_config.json"):
+ with open(path, 'r') as f:
+ config = json.load(f)
+ return config['geometry_library'], config['activity_map']
- # NEW: Cross-Body and Pure Extension Distances
- "dist_l_wri_r_shld": ("dist", [9, 6], True), # Reach across body
- "dist_r_wri_l_shld": ("dist", [10, 5], True), # Reach across body
- "dist_l_wri_l_shld": ("dist", [9, 5], True), # Pure arm extension
- "dist_r_wri_r_shld": ("dist", [10, 6], True), # Pure arm extension
-
- # --- Angles (Point A, Center B, Point C) ---
- "angle_l_elbow": ("angle", [5, 7, 9]),
- "angle_r_elbow": ("angle", [6, 8, 10]),
- "angle_l_shoulder": ("angle", [11, 5, 7]),
- "angle_r_shoulder": ("angle", [12, 6, 8]),
- "angle_l_knee": ("angle", [11, 13, 15]),
- "angle_r_knee": ("angle", [12, 14, 16]),
- "angle_l_hip": ("angle", [5, 11, 13]),
- "angle_r_hip": ("angle", [6, 12, 14]),
-
- # --- Custom/Derived ---
- "asym_wrist": ("z_diff", [9, 10]),
- "asym_ankl": ("z_diff", [15, 16]),
- "offset_head": ("head_offset", [0, 5, 6]),
- "diff_ear_shld": ("subtraction", ["dist_l_ear_r_shld", "dist_r_ear_l_shld"]),
- "abs_diff_ear_shld": ("abs_subtraction", ["dist_l_ear_r_shld", "dist_r_ear_l_shld"]),
-
- # NEW: Verticality and Contralateral Contrast
- "height_l_ankl": ("y_diff", [15, 11]), # Foot height relative to hip
- "height_r_ankl": ("y_diff", [16, 12]), # Foot height relative to hip
- "diff_knee_angle": ("subtraction", ["angle_l_knee", "angle_r_knee"]),
- "asym_wri_shld": ("subtraction", ["dist_l_wri_l_shld", "dist_r_wri_r_shld"])
-}
-
-
-# The Target Activity Map
-ACTIVITY_MAP = {
- "Mouthing": [
- "dist_l_wrist_nose", "dist_r_wrist_nose", "angle_l_elbow",
- "angle_r_elbow", "angle_l_shoulder", "angle_r_shoulder",
- "asym_wrist", "offset_head"
- ],
- "Head Movement": [
- "dist_l_wrist_nose", "dist_r_wrist_nose", "angle_l_elbow",
- "angle_r_elbow", "angle_l_shoulder", "angle_r_shoulder",
- "asym_wrist", "offset_head", "dist_l_ear_r_shld",
- "dist_r_ear_l_shld", "diff_ear_shld", "abs_diff_ear_shld"
- ],
- "Reach (Left)": [
- "dist_l_wrist_pelvis", "dist_l_wrist_nose", "dist_l_wri_l_shld",
- "dist_l_wri_r_shld", "angle_l_elbow", "angle_l_shoulder",
- "asym_wri_shld"
- ],
- "Reach (Right)": [
- "dist_r_wrist_pelvis", "dist_r_wrist_nose", "dist_r_wri_r_shld",
- "dist_r_wri_l_shld", "angle_r_elbow", "angle_r_shoulder",
- "asym_wri_shld"
- ],
- "Kick (Left)": [
- "dist_l_ankl_pelvis", "angle_l_knee", "angle_l_hip",
- "height_l_ankl", "dist_ankl_ankl", "asym_ankl",
- "diff_knee_angle", "dist_nose_pelvis"
- ],
- "Kick (Right)": [
- "dist_r_ankl_pelvis", "angle_r_knee", "angle_r_hip",
- "height_r_ankl", "dist_ankl_ankl", "asym_ankl",
- "diff_knee_angle", "dist_nose_pelvis"
- ]
-}
+try:
+ GEOMETRY_LIBRARY, ACTIVITY_MAP = load_analysis_config()
+except FileNotFoundError:
+ GEOMETRY_LIBRARY, ACTIVITY_MAP = {}, {}
+ print("Warning: analysis_config.json not found. ML functions will fail.")
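+
+# A minimal sketch of the expected analysis_config.json layout, reconstructed
+# from the former in-code tables (illustrative entries only):
+#
+# {
+#     "geometry_library": {
+#         "dist_l_wrist_nose": ["dist", [9, 0], true],
+#         "angle_l_elbow": ["angle", [5, 7, 9]]
+#     },
+#     "activity_map": {
+#         "Mouthing": ["dist_l_wrist_nose", "angle_l_elbow", "offset_head"]
+#     }
+# }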
def debug_print():
if VERBOSITY:
frame = inspect.currentframe().f_back
-        qualname = frame.f_code.co_filename
-        print(qualname)
+        qualname = frame.f_code.co_qualname
+        print(f"DEBUG_PRINT: {qualname}")
class GeneralPredictor:
def __init__(self):
debug_print()
self.base_paths = {
- "Random Forest": "rf.pkl",
- "XGBoost": "xgb.json",
- "SVM": "svm.pkl",
- "LSTM": "lstm.pth",
- "1D-CNN": "cnn.pth"
+ "Random Forest": "rf.pkl"
}
- self.raw_participant_buffer = []
self.current_target = ""
- self.scaler_cache = {}
+ self.active_feature_keys = []
-
- def add_to_raw_buffer(self, raw_payload, y_labels):
+ def calculate_and_train(self, training_params):
"""
- Adds a participant's raw kinematic components to the pool.
- raw_payload should contain: 'z_kps', 'directions', 'raw_kps'
+ Takes the dict from get_selection() in TrainModelWindow.
+ Loads CSV/JSON pairs, extracts combined features, and trains Random Forest.
"""
debug_print()
- entry = {
- "raw_data": raw_payload,
- "labels": y_labels
+
+ folder = training_params.get("folder")
+ pairs = training_params.get("pairs", [])
+ selected_behaviors = training_params.get("selected_behaviors", [])
+ self.current_target = training_params.get("target_name", "combined_model")
+ model_type = training_params.get("model_type", "Random Forest")
+
+ if not pairs or not selected_behaviors:
+ return "Error: Missing data pairs or target behaviors."
+
+ # 1. Determine the union of ALL needed geometric features across selected behaviors
+ needed_features = set()
+ for b_name in selected_behaviors:
+ req_feats = ACTIVITY_MAP.get(b_name, [])
+ needed_features.update(req_feats)
+
+ self.active_feature_keys = sorted(list(needed_features))
+ print(self.active_feature_keys)
+
+ model_metadata = {
+ "target_behavior": self.current_target,
+ "feature_keys": self.active_feature_keys,
+ "model_type": model_type,
+ "timestamp": datetime.now().isoformat()
}
- self.raw_participant_buffer.append(entry)
- return f"Added participant to pool. Total participants: {len(self.raw_participant_buffer)}"
+
+ if not self.active_feature_keys:
+ return "Error: No geometric features mapped to the selected behavior(s) in analysis_config.json."
-
- def clear_buffer(self):
- """Clears the raw pool."""
- debug_print()
- self.raw_participant_buffer = []
-
-
- def calculate_and_train(self, model_type, target_name):
- """
- The 'On-the-Fly' engine. Loops through the raw buffer,
- calculates features for the SELECTED target, and trains.
- """
- debug_print()
- self.current_target = target_name
all_X = []
all_y = []
- # 1. Process every participant in the pool
- for participant in self.raw_participant_buffer:
- raw = participant["raw_data"]
- all_tracks = participant["labels"]
+ # 2. Process each Pair (JSON labels + CSV raw pose)
+ for json_path, csv_path in pairs:
+ # --- Load JSON Labels ---
+ try:
+ with open(json_path, 'r') as f:
+ label_data = json.load(f)
+ except Exception as e:
+ print(f"Error loading {json_path}: {e}")
+ continue
- # Pull the specific track that was requested
- track_key = f"OBS: {target_name}"
- if track_key not in all_tracks:
- print(f"Warning: Track {track_key} not found for a participant. Skipping.")
+ behaviors = label_data.get("behaviors", {})
+
+ # --- Load CSV Pose Data ---
+ try:
+ raw_kpts = []
+ with open(csv_path, 'r') as f:
+ reader = csv.reader(f)
+ next(reader) # skip header
+ for row in reader:
+ raw_kpts.append(np.array([float(x) for x in row]).reshape(17, 3))
+ raw_kpts = np.array(raw_kpts)
+ except Exception as e:
+ print(f"Error loading {csv_path}: {e}")
continue
- y = all_tracks[track_key]
-
- # Extract lists from the payload
- z_scores = raw["z_kps"]
- dirs = raw["directions"]
- kpts = raw["raw_kps"]
+ total_frames = len(raw_kpts)
+ if total_frames == 0:
+ continue
- # Calculate geometric features for every frame
+ # Create binary target array (0 = Rest, 1 = Active)
+ y_vector = np.zeros(total_frames, dtype=int)
+
+ # If the frame falls inside ANY of the selected behaviors, mark it 1
+ for b_name in selected_behaviors:
+ instances = behaviors.get(b_name, [])
+ for inst in instances:
+ start = inst.get("start_frame", 0)
+ duration = inst.get("duration_frames", 0)
+ end = min(start + duration, total_frames)
+ y_vector[start:end] = 1
+
+            # --- Calculate Features per Frame ---
+            # The new flow needs only raw_kpts: the z-scores previously passed in were
+            # derived from the raw keypoints anyway. If normalized z-scores are ever
+            # required for the RF model, recalculate them here with the same baseline
+            # logic used in the main window. For now we extract raw geometric features.
+
participant_features = []
- for i in range(len(y)):
- feat = self.format_features(z_scores[i], dirs[i], kpts[i])
+ for i in range(total_frames):
+ kpts = raw_kpts[i] # Shape (17, 3)
+ feat = self.format_features(kpts)
participant_features.append(feat)
all_X.append(np.array(participant_features))
- all_y.append(y)
+ all_y.append(y_vector)
+
+ # 3. Prepare for Training
+ if not all_X:
+ return "Error: No valid data extracted from files."
- # 2. Prepare for Training
X_combined = np.vstack(all_X)
y_combined = np.concatenate(all_y)
- # 3. Scale the data specifically for this target/model combo
+ # Check for class imbalance edge case (e.g. 0 instances of behavior found)
+ if len(np.unique(y_combined)) < 2:
+ return "Error: Training data only contains one class (usually 0/Rest). Model cannot train."
+
+ metadata_path = self.get_path(model_type).replace(".pkl", "_metadata.json")
+ with open(metadata_path, 'w') as f:
+ json.dump(model_metadata, f, indent=4)
+
+ print(f"[INFO] Metadata saved to: {metadata_path}")
+
+ # 4. Scale Data
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_combined)
scaler_path = self.get_path(model_type, is_scaler=True)
joblib.dump(scaler, scaler_path)
- # 4. Train/Test Split
+ # 5. Train/Test Split
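+        # stratify=y_combined keeps the Rest/Active ratio consistent across train and test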
X_train, X_test, y_train, y_test = train_test_split(
X_scaled, y_combined, test_size=0.2, stratify=y_combined, random_state=42
)
- # 5. Process with corresponding Model
+ # 6. Train Random Forest (Placeholders exist for others)
if model_type == "Random Forest":
model = RandomForestClassifier(max_depth=15, n_estimators=100, class_weight="balanced")
model.fit(X_train, y_train)
- # Save the model
save_path = self.get_path(model_type)
joblib.dump(model, save_path)
y_pred = model.predict(X_test)
- # Feature Importance for the UI
- labels_names = self.get_feature_labels()
+ # Feature Importance
importances = model.feature_importances_
- feature_data = sorted(zip(labels_names, importances), key=lambda x: x[1], reverse=True)
- ui_extras = "Top Predictors:
" + "
".join([f"{n}: {v:.3f}" for n, v in feature_data])
+ feature_data = sorted(zip(self.active_feature_keys, importances), key=lambda x: x[1], reverse=True)
+
+ ui_extras = "Top Predictors:
" + "
".join([f"{n}: {v:.3f}" for n, v in feature_data[:10]])
file_extras = "Top Predictors:\n" + "\n".join([f"- {n}: {v:.3f}" for n, v in feature_data])
- return self._evaluate_and_report(model_type, y_test, y_pred, ui_extras=ui_extras, file_extras=file_extras, target_name=target_name)
+ return self._evaluate_and_report(model_type, y_test, y_pred, ui_extras=ui_extras, file_extras=file_extras)
- # TODO: More than random forest
+ elif model_type == "1D-CNN":
+ return "1D-CNN training placeholder reached. Not yet implemented."
+ elif model_type == "LSTM":
+ return "LSTM training placeholder reached. Not yet implemented."
+ elif model_type == "XGBoost":
+ return "XGBoost training placeholder reached. Not yet implemented."
else:
- return "Model type not yet implemented in calculate_and_train."
-
+ return f"Model type {model_type} not supported."
def get_path(self, model_type, is_scaler=False):
- """Returns the specific file path for the target/model or its scaler."""
debug_print()
- suffix = self.base_paths[model_type]
-
+ suffix = self.base_paths.get(model_type, "model.pkl")
if is_scaler:
suffix = suffix.split('.')[0] + "_scaler.pkl"
-
return f"ml_{self.current_target}_{suffix}"
-
- def get_feature_labels(self):
- """Returns labels only for features active in the current target."""
- debug_print()
- active_keys = ACTIVITY_MAP.get(self.current_target, [])
- return active_keys
-
-
- def format_features(self, z_scores, directions, kpts):
- """The 'Universal Parser' for geometric features."""
- # debug_print()
- # Internal Math Helpers
- if self.current_target == "ALL_FEATURES":
- active_list = list(GEOMETRY_LIBRARY.keys())
- else:
- active_list = ACTIVITY_MAP.get(self.current_target, ACTIVITY_MAP["Mouthing"])
-
+ def format_features(self, kpts):
+ """
+ Calculates only the geometric features required by self.active_feature_keys.
+ """
def resolve_pt(idx):
if isinstance(idx, list):
- # Calculate midpoint of all indices in the list
- pts = [kpts[i] for i in idx]
+ pts = [kpts[i][:2] for i in idx] # Ensure X/Y only
return np.mean(pts, axis=0)
- return kpts[idx]
+ return kpts[idx][:2]
def get_dist(p1, p2): return np.linalg.norm(p1 - p2)
def get_angle(a, b, c):
@@ -278,128 +224,122 @@ class GeneralPredictor:
try:
if kpts is None or len(kpts) < 13: raise ValueError()
- # Reference scale (Shoulders)
- scale = get_dist(kpts[5], kpts[6]) + 1e-6
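+            # COCO indices 5/6 are the left/right shoulders; shoulder width normalizes all distances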
+ scale = get_dist(kpts[5][:2], kpts[6][:2]) + 1e-6
+
+ # First Pass: Direct Geometries (Only calculate what is needed or what is a dependency)
+ for name, config_data in GEOMETRY_LIBRARY.items():
+ f_type = config_data[0]
+ indices = config_data[1]
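+                # GEOMETRY_LIBRARY entries are (feature_type, indices, *optional_meta)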
- # First Pass: Direct Geometries
- for name, (f_type, indices, *meta) in GEOMETRY_LIBRARY.items():
if f_type == "dist":
- # Use resolve_pt for both indices
p1 = resolve_pt(indices[0])
p2 = resolve_pt(indices[1])
calculated_pool[name] = get_dist(p1, p2) / scale
elif f_type == "angle":
- # Use resolve_pt for all three indices
p1 = resolve_pt(indices[0])
p2 = resolve_pt(indices[1])
p3 = resolve_pt(indices[2])
calculated_pool[name] = get_angle(p1, p2, p3)
- elif f_type == "z_diff":
- # Z-scores are usually single indices, but we handle lists just in case
- z1 = np.mean([z_scores[i] for i in indices[0]]) if isinstance(indices[0], list) else z_scores[indices[0]]
- z2 = np.mean([z_scores[i] for i in indices[1]]) if isinstance(indices[1], list) else z_scores[indices[1]]
- calculated_pool[name] = abs(z1 - z2)
-
elif f_type == "head_offset":
p_target = resolve_pt(indices[0])
- p_mid = resolve_pt([indices[1], indices[2]]) # Midpoint of shoulders
+ p_mid = resolve_pt([indices[1], indices[2]])
calculated_pool[name] = abs(p_target[0] - p_mid[0]) / scale
+
+ elif f_type == "y_diff": # NEW from JSON
+ p1 = resolve_pt(indices[0])
+ p2 = resolve_pt(indices[1])
+ calculated_pool[name] = abs(p1[1] - p2[1]) / scale
- # Second Pass: Composite Geometries (Subtractions/Symmetry)
- # We do this after so 'dist_l_ear_r_shld' is already calculated
- for name, (f_type, indices, *meta) in GEOMETRY_LIBRARY.items():
+ # Second Pass: Subtractions (Requires first pass to be complete)
+ for name, config_data in GEOMETRY_LIBRARY.items():
+ f_type = config_data[0]
+ indices = config_data[1]
+
if f_type == "subtraction":
- calculated_pool[name] = calculated_pool[indices[0]] - calculated_pool[indices[1]]
+ val1 = calculated_pool.get(indices[0], 0)
+ val2 = calculated_pool.get(indices[1], 0)
+ calculated_pool[name] = val1 - val2
elif f_type == "abs_subtraction":
- calculated_pool[name] = abs(calculated_pool[indices[0]] - calculated_pool[indices[1]])
+ val1 = calculated_pool.get(indices[0], 0)
+ val2 = calculated_pool.get(indices[1], 0)
+ calculated_pool[name] = abs(val1 - val2)
except Exception:
- # If a frame fails, fill the pool with zeros to prevent crashes
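+            # On any per-frame failure, zero-fill every feature so one bad frame cannot crash training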
calculated_pool = {name: 0.0 for name in GEOMETRY_LIBRARY.keys()}
- # Final Extraction based on current_target
-
- active_list = ACTIVITY_MAP.get(self.current_target, ACTIVITY_MAP["Mouthing"])
- feature_vector = [calculated_pool[feat] for feat in active_list]
+ # Final Extraction based on the set of needed features
+ feature_vector = [calculated_pool.get(feat, 0.0) for feat in self.active_feature_keys]
return np.array(feature_vector, dtype=np.float32)
- def _prepare_pool_data(self):
- """Merges buffer and fits scaler."""
- debug_print()
- if not self.X_buffer:
- return None, None, None
-
- X_total = np.vstack(self.X_buffer)
- y_total = np.concatenate(self.y_buffer)
-
- # We always fit a fresh scaler on the current pool
- scaler_file = f"{self.current_target}_scaler.pkl"
- scaler = StandardScaler()
- X_scaled = scaler.fit_transform(X_total)
- joblib.dump(scaler, scaler_file)
-
- return X_scaled, y_total, scaler
-
-
- def _evaluate_and_report(self, model_name, y_test, y_pred, extra_text="", ui_extras="", file_extras="", target_name=""):
- """Generates unified metrics, confusion matrix, and reports for ANY model"""
+ def _evaluate_and_report(self, model_name, y_test, y_pred, ui_extras="", file_extras=""):
debug_print()
prec = precision_score(y_test, y_pred, zero_division=0)
rec = recall_score(y_test, y_pred, zero_division=0)
f1 = f1_score(y_test, y_pred, zero_division=0)
- target = getattr(self, 'current_target', 'Activity')
- display_labels = ['Rest', target]
- # Plot Confusion Matrix
+ display_labels = ['Rest', self.current_target]
cm = confusion_matrix(y_test, y_pred)
+
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
- xticklabels=display_labels,
- yticklabels=display_labels)
- plt.title(f'{model_name} Detection: Predicted vs Actual')
+ xticklabels=display_labels, yticklabels=display_labels)
+ plt.title(f'{model_name} Detection: {self.current_target}')
plt.ylabel('Actual State')
plt.xlabel('Predicted State')
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
- plt.savefig(f"ml_{target_name}_confusion_matrix_rf_{timestamp}.png")
+ plt.savefig(f"ml_{self.current_target}_cm_{timestamp}.png")
plt.close()
- # Classification Report String
- report_str = classification_report(y_test, y_pred,
- target_names=display_labels,
- zero_division=0)
+ report_str = classification_report(y_test, y_pred, target_names=display_labels, zero_division=0)
- # Build TXT File Content
report_text = f"MODEL PERFORMANCE REPORT: {model_name}\nGenerated: {timestamp}\n"
report_text += "="*40 + "\n"
report_text += report_str + "\n"
report_text += f"Precision: {prec:.4f}\nRecall: {rec:.4f}\nF1-Score: {f1:.4f}\n"
- report_text += "="*40 + "\n" + extra_text
report_text += "="*40 + "\n" + file_extras
- with open(f"ml_{target_name}_performance_rf_{timestamp}.txt", "w") as f:
+ with open(f"ml_{self.current_target}_performance_{timestamp}.txt", "w") as f:
f.write(report_text)
- # Build UI String
ui_report = f"""
- {model_name} Performance:
+ {model_name} Model for '{self.current_target}'
Precision: {prec:.2f} | Recall: {rec:.2f} | F1: {f1:.2f}