diff --git a/.gitignore b/.gitignore index 36b13f1..e3c9c22 100644 --- a/.gitignore +++ b/.gitignore @@ -174,3 +174,4 @@ cython_debug/ # PyPI configuration file .pypirc +sparks_*/ \ No newline at end of file diff --git a/hand_landmarker.task b/hand_landmarker.task new file mode 100644 index 0000000..0d53faf Binary files /dev/null and b/hand_landmarker.task differ diff --git a/main.py b/main.py index e717d46..102a1c3 100644 --- a/main.py +++ b/main.py @@ -36,13 +36,16 @@ import requests import torch import torch.nn as nn import mediapipe as mp +from mediapipe.tasks import python +from mediapipe.tasks.python import vision from torch.utils.data import TensorDataset, DataLoader from sklearn.utils.class_weight import compute_class_weight import matplotlib.pyplot as plt from PySide6.QtWidgets import ( QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox, QGridLayout, - QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView, QMenu, QProgressBar, QCheckBox + QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView, + QMenu, QProgressBar, QCheckBox, QSlider, QTabWidget, QTreeWidget, QTreeWidgetItem, QHeaderView ) from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize, QPoint from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel, QStandardItem, QImage @@ -907,6 +910,60 @@ class FileLoadWorker(QThread): + + +class IndividualFileLoadWorker(QThread): + observations_loaded = Signal() + loading_finished = Signal() + loading_failed = Signal(str) + + def __init__(self, file_path, extract_frame_and_hands_func): + super().__init__() + self.file_path = file_path + self.extract_frame_and_hands = extract_frame_and_hands_func + self.is_running = True + self.previews_data = {} # To store the results of the heavy work + + def run(self): + try: + # 2. Process the file data to gather initial previews (the heavy part) + self.process_previews() + + self.observations_loaded.emit() + + except Exception as e: + self.loading_failed.emit(f"Loading failed: {str(e)}") + finally: + self.loading_finished.emit() # Ensures dialog closes even on some non-critical path breaks + + def process_previews(self): + + if os.path.exists(self.file_path): + # HEAVY OPERATION + frame_rgb, results = self.extract_frame_and_hands(self.file_path, 0) + + initial_wrists = [] + if results and results.hand_landmarks: + for hand_landmarks in results.hand_landmarks: + # Change: hand_landmarks is a list, access index [0] directly + wrist = hand_landmarks[0] + initial_wrists.append((wrist.x, wrist.y)) + + # Store the results needed to BUILD the UI later + self.previews_data[(1, 1)] = { + "frame_rgb": frame_rgb, + "results": results, + "video_path": self.file_path, + "fps": 60, + "initial_wrists": initial_wrists + } + + + def stop(self): + self.is_running = False + + + class AboutWindow(QWidget): """ Simple About window displaying basic application information. 
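# Illustrative sketch only (not part of the diff): the first-frame wrist extraction that
# IndividualFileLoadWorker delegates to extract_frame_and_hands, written against the
# MediaPipe Tasks API imported above. Assumes hand_landmarker.task sits next to the
# script; 'some_video.mp4' is a placeholder path.
import cv2
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
options = vision.HandLandmarkerOptions(
    base_options=base_options,
    running_mode=vision.RunningMode.IMAGE,   # one still frame at a time
    num_hands=2,
)
detector = vision.HandLandmarker.create_from_options(options)

cap = cv2.VideoCapture('some_video.mp4')
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)          # frame_idx = 0, as the worker uses
ret, frame = cap.read()
cap.release()

wrists = []
if ret:
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
    result = detector.detect(mp_image)
    # result.hand_landmarks is a list of hands; each hand is a flat list of 21
    # landmarks, so index [0] is the wrist (no .landmark attribute as in the old API).
    for hand in (result.hand_landmarks or []):
        wrists.append((hand[0].x, hand[0].y))
detector.close()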
@@ -1396,6 +1453,546 @@ class ParticipantProcessor(QThread): self.is_running = False + +class ParticipantProcessor2(QThread): + progress_updated = Signal(int) + time_updated = Signal(str) + finished_processing = Signal(str, str) # obs_id, cam_id + + def __init__(self, obs_id, selected_cam_id, selected_hand_idx, + video_path, output_csv, output_dir, initial_wrists, **kwargs): + super().__init__() + self.obs_id = obs_id + self.cam_id = selected_cam_id + self.selected_hand_idx = selected_hand_idx + self.video_path = video_path + self.output_dir = output_dir + self.output_csv = output_csv + self.is_running = True + + # Convert initial_wrists (list) to initial_centroids (dict) for the tracker + self.current_centroids = {} + if initial_wrists: + for i, pos in enumerate(initial_wrists): + self.current_centroids[i] = pos + + def get_centroid(self, lm_list): + avg_x = sum(lm.x for lm in lm_list) / len(lm_list) + avg_y = sum(lm.y for lm in lm_list) / len(lm_list) + return (avg_x, avg_y) + + def update_tracking(self, results, last_known): + if not results or not results.hand_landmarks: + return last_known + + detected_hands = results.hand_landmarks + new_centroids = {} + used_indices = set() + + # Priority 1: Match existing IDs + for hand_id, last_pos in last_known.items(): + min_dist = float('inf') + best_idx = None + for j, lm_list in enumerate(detected_hands): + if j in used_indices: continue + curr_c = self.get_centroid(lm_list) + dist = (curr_c[0] - last_pos[0])**2 + (curr_c[1] - last_pos[1])**2 + if dist < 0.1 and dist < min_dist: + min_dist = dist + best_idx = j + + if best_idx is not None: + new_centroids[hand_id] = self.get_centroid(detected_hands[best_idx]) + used_indices.add(best_idx) + # Keep track of which detection index corresponds to our ID + if hand_id == self.selected_hand_idx: + self.current_detection_idx = best_idx + + # Carry over ghosts for lost hands + for hand_id, pos in last_known.items(): + if hand_id not in new_centroids: + new_centroids[hand_id] = pos + + return new_centroids + + def run(self): + # 1. Initialize Mediapipe INSIDE the thread + base_options = python.BaseOptions(model_asset_path='hand_landmarker.task') + options = vision.HandLandmarkerOptions( + base_options=base_options, + running_mode=vision.RunningMode.IMAGE, + num_hands=2, + min_hand_detection_confidence=0.3 + ) + detector = vision.HandLandmarker.create_from_options(options) + + cap = cv2.VideoCapture(self.video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + + save_path = os.path.join(self.output_dir, self.output_csv) + + header = ['frame'] + if self.selected_hand_idx == 99: + # Dual hand columns: h0_x0, h0_y0 ... h1_x20, h1_y20 + for h_prefix in ['h0', 'h1']: + for i in range(21): + header.extend([f'{h_prefix}_x{i}', f'{h_prefix}_y{i}']) + else: + # Single hand columns: x0, y0 ... x20, y20 + for i in range(21): + header.extend([f'x{i}', f'y{i}']) + + with open(save_path, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(header) + + + last_known_centroids = {} + + while cap.isOpened() and self.is_running: + ret, frame = cap.read() + if not ret: break + + # 1. 
Process with Tasks API + mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + results = detector.detect(mp_image) + frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES)) + + # current_mapping stores {hand_id: landmark_list} for this frame + current_mapping = {} + + if results and results.hand_landmarks: + hand_landmarks_list = results.hand_landmarks + temp_mapping = {} # {hand_id: index_in_landmarks_list} + used_indices = set() + + # --- STEP 1: TRACKING (Matching to Last Frame) --- + if last_known_centroids: + for hand_id, last_pos in last_known_centroids.items(): + min_dist = float('inf') + best_idx = None + + for j, lm_list in enumerate(hand_landmarks_list): + if j in used_indices: continue + + c = self.get_centroid(lm_list) # Assumes get_centroid is in this class + dist = (c[0] - last_pos[0])**2 + (c[1] - last_pos[1])**2 + + if dist < min_dist: + min_dist = dist + best_idx = j + + if best_idx is not None and min_dist < 0.1: + temp_mapping[hand_id] = best_idx + used_indices.add(best_idx) + + # --- STEP 2: DISCOVERY (New Hands) --- + for j in range(len(hand_landmarks_list)): + if j not in used_indices: + new_id = 0 + while new_id in temp_mapping or new_id in last_known_centroids: + new_id += 1 + temp_mapping[new_id] = j + used_indices.add(j) + + # --- STEP 3: UPDATE MEMORY & PREPARE ROW --- + for hand_id, idx in temp_mapping.items(): + landmarks = hand_landmarks_list[idx] + current_mapping[hand_id] = landmarks + last_known_centroids[hand_id] = self.get_centroid(landmarks) + + # --- STEP 4: WRITE TO CSV --- + row = [frame_idx] + + if self.selected_hand_idx == 99: + # DUAL MODE: We expect Hand 0 and Hand 1 + for hand_id in [0, 1]: + if hand_id in current_mapping: + for lm in current_mapping[hand_id]: + row.extend([lm.x, lm.y]) + else: + row.extend([0.0] * 42) # Zero-pad if this specific ID is missing + else: + # SINGLE MODE: Use the specific ID (0 or 1) selected in the UI + if self.selected_hand_idx in current_mapping: + for lm in current_mapping[self.selected_hand_idx]: + row.extend([lm.x, lm.y]) + else: + row.extend([0.0] * 42) + + writer.writerow(row) + + # Update UI + if frame_idx % 10 == 0: + progress = int((frame_idx / total_frames) * 100) + self.progress_updated.emit(progress) + + secs = int(frame_idx / fps) + total_secs = int(total_frames / fps) + self.time_updated.emit(f"{secs//60:02}:{secs%60:02}/{total_secs//60:02}:{total_secs%60:02}") + + cap.release() + print("Released") + detector.close() + print("Closed") + self.finished_processing.emit(str(self.obs_id), str(self.cam_id)) + + def cancel(self): + self.is_running = False + + +def load_hand_csv(filepath): + df = pd.read_csv(filepath) + # We create a dictionary: {frame_number: [list of 21 (x,y) tuples]} + data = {} + for _, row in df.iterrows(): + f_idx = int(row['frame']) + landmarks = [] + for i in range(21): + landmarks.append((row[f'x{i}'], row[f'y{i}'])) + data[f_idx] = landmarks + return data + + +class HandValidationWindow(QWidget): + def __init__(self): + super().__init__() + self.setWindowTitle("SPARKS - Dual Hand Data Validator") + self.resize(1000, 850) + + self.layout = QVBoxLayout(self) + + # --- Video Display --- + self.video_label = QLabel("Load a Video and CSV to begin") + self.video_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.video_label.setStyleSheet("background-color: black; border: 2px solid #333;") + self.video_label.setMinimumSize(1, 1) + self.layout.addWidget(self.video_label, stretch=1) + + # --- Playback Slider --- + self.slider = 
QSlider(Qt.Orientation.Horizontal) + self.slider.sliderMoved.connect(self.set_position) + self.layout.addWidget(self.slider) + + # --- Control Row --- + ctrl_layout = QHBoxLayout() + + self.btn_load = QPushButton("Load Pair") + self.btn_load.clicked.connect(self.load_files) + ctrl_layout.addWidget(self.btn_load) + + self.btn_play = QPushButton("Pause") # Toggles Play/Pause + self.btn_play.clicked.connect(self.toggle_play) + ctrl_layout.addWidget(self.btn_play) + + self.speed_combo = QComboBox() + self.speed_combo.addItems(["0.25x", "0.5x", "1.0x", "2.0x"]) + self.speed_combo.setCurrentText("1.0x") + self.speed_combo.currentIndexChanged.connect(self.update_speed) + + ctrl_layout.addWidget(QLabel("Speed:")) + ctrl_layout.addWidget(self.speed_combo) + # Toggles for Hands + self.chk_h0 = QCheckBox("Show Hand 0 (Cyan)") + self.chk_h0.setChecked(True) + self.chk_h1 = QCheckBox("Show Hand 1 (Magenta)") + self.chk_h1.setChecked(True) + ctrl_layout.addWidget(self.chk_h0) + ctrl_layout.addWidget(self.chk_h1) + + self.lbl_frame = QLabel("Frame: 0") + ctrl_layout.addWidget(self.lbl_frame) + + self.layout.addLayout(ctrl_layout) + + self.inspector = HandDataInspector() + self.inspector.show() + + # Logic state + self.timer = QTimer() + self.timer.timeout.connect(self.update_frame) + self.cap = None + self.hand_data = {} # Will now store {frame: {'h0': [...], 'h1': [...]}} + self.paused = False + + from collections import deque + self.com_history = { + 'h0': deque(maxlen=30), + 'h1': deque(maxlen=30) + } + + def load_files(self): + video_path, _ = QFileDialog.getOpenFileName(self, "Select Video", "", "Videos (*.mp4 *.avi)") + csv_path, _ = QFileDialog.getOpenFileName(self, "Select CSV", "", "CSV Files (*.csv)") + + if video_path and csv_path: + self.cap = cv2.VideoCapture(video_path) + self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) + self.slider.setRange(0, self.total_frames - 1) + + # Load the dual-hand data + self.hand_data = self.parse_dual_hand_csv(csv_path) + self.timer.start(16) + + def update_speed(self): + speed_map = {"0.25x": 64, "0.5x": 32, "1.0x": 16, "2.0x": 8} + ms = speed_map.get(self.speed_combo.currentText(), 16) + if self.timer.isActive(): + self.timer.stop() + self.timer.start(ms) + else: + # If paused, just store the intended speed for when they hit Play + self.current_interval = ms + + def parse_dual_hand_csv(self, path): + df = pd.read_csv(path) + data = {} + + # Check if this is a Dual Hand CSV or Single Hand CSV + is_dual = 'h0_x0' in df.columns + + for _, row in df.iterrows(): + f = int(row['frame']) + if is_dual: + h0 = [(row[f'h0_x{i}'], row[f'h0_y{i}']) for i in range(21)] + h1 = [(row[f'h1_x{i}'], row[f'h1_y{i}']) for i in range(21)] + data[f] = {'h0': h0, 'h1': h1} + else: + # Fallback for old single-hand files + h0 = [(row[f'x{i}'], row[f'y{i}']) for i in range(21)] + data[f] = {'h0': h0, 'h1': []} + + # Disable Hand 1 checkbox if it's a single hand file + self.chk_h1.setEnabled(is_dual) + return data + + def toggle_play(self): + if self.paused: + self.timer.start(16) + self.btn_play.setText("Pause") + else: + self.timer.stop() + self.btn_play.setText("Play") + self.paused = not self.paused + + def set_position(self, position): + if self.cap: + self.cap.set(cv2.CAP_PROP_POS_FRAMES, position) + # Update visual immediately if paused + if self.paused: self.update_frame(manual=True) + + def update_frame(self, manual=False): + if not manual: + ret, frame = self.cap.read() + else: + # Re-read the current frame for scrubbing + curr = 
int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)) + self.cap.set(cv2.CAP_PROP_POS_FRAMES, curr) + ret, frame = self.cap.read() + + if not ret: return + + f_idx = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)) + self.slider.setValue(f_idx) + self.lbl_frame.setText(f"Frame: {f_idx}") + + h, w, _ = frame.shape + prev_idx = f_idx - 1 + prev_data = self.hand_data.get(prev_idx, {}) + fps = self.cap.get(cv2.CAP_PROP_FPS) + if f_idx in self.hand_data: + for h_key in ['h0', 'h1']: + curr = self.hand_data[f_idx].get(h_key, []) + prev = self.hand_data.get(prev_idx, {}).get(h_key, []) + + # This handles the zero-padding (if hand is lost, don't calculate) + if curr and not (curr[0][0] == 0 and curr[0][1] == 0): + self.inspector.update_hand_data(h_key, curr, prev, fps) + if self.chk_h0.isChecked(): + curr_h0 = self.hand_data[f_idx]['h0'] + prev_h0 = prev_data.get('h0', []) + self.draw_skeleton(frame, curr_h0, w, h, (255, 255, 0), 'h0', prev_h0, fps) + if self.chk_h1.isChecked(): + curr_h1 = self.hand_data[f_idx]['h1'] + prev_h1 = prev_data.get('h1', []) + self.draw_skeleton(frame, curr_h1, w, h, (255, 0, 255), 'h1', prev_h1, fps) + + rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + qimg = QImage(rgb.data, w, h, 3 * w, QImage.Format_RGB888) + display_size = self.video_label.contentsRect().size() + + self.video_label.setPixmap( + QPixmap.fromImage(qimg).scaled( + display_size, + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation # Added for better quality + ) + ) + + def draw_skeleton(self, img, landmarks, w, h, color, h_key, prev_landmarks=None, fps=30): + # --- 1. Draw Skeleton Connections --- + conns = [(0,1), (1,2), (2,3), (3,4), (0,5), (5,6), (6,7), (7,8), + (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), + (0,17), (17,18), (18,19), (19,20), (5,9), (9,13), (13,17)] + + for s, e in conns: + p1 = (int(landmarks[s][0]*w), int(landmarks[s][1]*h)) + p2 = (int(landmarks[e][0]*w), int(landmarks[e][1]*h)) + cv2.line(img, p1, p2, color, 2) + + # Calculate current Center of Mass (COM) in pixels + avg_x = sum(p[0] for p in landmarks) / 21 + avg_y = sum(p[1] for p in landmarks) / 21 + com_px = (int(avg_x * w), int(avg_y * h)) + + # --- 2. 
Instant Vector (Thinner, Frame-to-Frame) --- + if prev_landmarks and len(prev_landmarks) == 21: + p_avg_x = sum(p[0] for p in prev_landmarks) / 21 + p_avg_y = sum(p[1] for p in prev_landmarks) / 21 + + # Instant displacement scaled for visibility + v_scale_inst = 35 + vx_inst = (avg_x - p_avg_x) * v_scale_inst + vy_inst = (avg_y - p_avg_y) * v_scale_inst + + inst_end_px = (int((avg_x + vx_inst) * w), int((avg_y + vy_inst) * h)) + + # Draw Instant Arrow (Colored to match the hand) + cv2.arrowedLine(img, com_px, inst_end_px, color, 2, tipLength=0.2) + + + self.com_history[h_key].append((avg_x, avg_y)) + + if len(self.com_history[h_key]) > 1: + # Calculate average per-frame displacement over the history + # (Current Position - Oldest Position) / Number of Frames + oldest = self.com_history[h_key][0] + count = len(self.com_history[h_key]) + + avg_dx = (avg_x - oldest[0]) / count + avg_dy = (avg_y - oldest[1]) / count + + # Scale for the Big Arrow (Adjust 60-100 to change length) + v_scale_smooth = 35 + + # Calculate end point + smooth_end_x = avg_x + (avg_dx * v_scale_smooth) + smooth_end_y = avg_y + (avg_dy * v_scale_smooth) + smooth_end_px = (int(smooth_end_x * w), int(smooth_end_y * h)) + + # Draw the Big Arrow (White with Black Outline for high visibility) + cv2.arrowedLine(img, com_px, smooth_end_px, (0, 0, 0), 11, tipLength=0.3) + cv2.arrowedLine(img, com_px, smooth_end_px, (255, 255, 255), 5, tipLength=0.3) + + +class HandDataInspector(QWidget): + def __init__(self): + super().__init__() + self.setWindowTitle("Hand Kinematics Inspector") + self.resize(700, 800) + layout = QVBoxLayout(self) + self.tabs = QTabWidget() + layout.addWidget(self.tabs) + + self.hand_trees = {} + self.joint_nodes = {'h0': {}, 'h1': {}} + self.group_nodes = {'h0': {}, 'h1': {}} # Track finger groups + self.hand_nodes = {'h0': None, 'h1': None} + + # Define finger groups once + self.finger_map = { + "Palm / Wrist": [0, 1, 5, 9, 13, 17], + "Thumb": [2, 3, 4], + "Index": [6, 7, 8], + "Middle": [10, 11, 12], + "Ring": [14, 15, 16], + "Pinky": [18, 19, 20] + } + + self.setup_hand_tab("h0", "Hand 0 (Cyan)") + self.setup_hand_tab("h1", "Hand 1 (Magenta)") + + def setup_hand_tab(self, key, label): + tree = QTreeWidget() + tree.setColumnCount(4) + tree.setHeaderLabels(["Feature", "Position (X,Y)", "Vector (Vx, Vy)", "Speed/Angle"]) + tree.header().setSectionResizeMode(QHeaderView.ResizeMode.Stretch) + self.tabs.addTab(tree, label) + self.hand_trees[key] = tree + + # Create Root + hand_root = QTreeWidgetItem(tree, ["Whole Hand", "", "", ""]) + self.hand_nodes[key] = hand_root + + # Create Groups (Fingers) + for group_name, indices in self.finger_map.items(): + parent = QTreeWidgetItem(hand_root, [group_name, "", "", ""]) + self.group_nodes[key][group_name] = parent # Store for updating + for idx in indices: + child = QTreeWidgetItem(parent, [f"Joint {idx}", "", "", ""]) + self.joint_nodes[key][idx] = child + + # START EXPANDED + tree.expandAll() + + def update_hand_data(self, hand_key, current_pts, last_pts, fps): + if not current_pts or len(current_pts) < 21: return + + # 1. 
Update Every Individual Joint + joint_velocities = {} + for idx, pt in enumerate(current_pts): + vx, vy, speed, angle = 0, 0, 0, 0 + if last_pts and len(last_pts) == 21: + vx = (pt[0] - last_pts[idx][0]) * fps + vy = (pt[1] - last_pts[idx][1]) * fps + speed = (vx**2 + vy**2)**0.5 + angle = math.degrees(math.atan2(vy, vx)) + + joint_velocities[idx] = (vx, vy, speed, angle) + node = self.joint_nodes[hand_key][idx] + node.setText(1, f"{pt[0]:.3f}, {pt[1]:.3f}") + node.setText(2, f"{vx:+.2f}, {vy:+.2f}") + node.setText(3, f"{speed:.2f} @ {angle:.0f}°") + + # 2. Update Finger Groups (Averaging their specific joints) + for group_name, indices in self.finger_map.items(): + g_pts = [current_pts[i] for i in indices] + g_vels = [joint_velocities[i] for i in indices] + + # Calculate Mean Position for this Finger + m_x = sum(p[0] for p in g_pts) / len(indices) + m_y = sum(p[1] for p in g_pts) / len(indices) + + # Calculate Mean Velocity for this Finger + m_vx = sum(v[0] for v in g_vels) / len(indices) + m_vy = sum(v[1] for v in g_vels) / len(indices) + m_speed = (m_vx**2 + m_vy**2)**0.5 + m_angle = math.degrees(math.atan2(m_vy, m_vx)) + + group_node = self.group_nodes[hand_key][group_name] + group_node.setText(1, f"{m_x:.3f}, {m_y:.3f}") # Numerical Mean Pos + group_node.setText(2, f"{m_vx:+.2f}, {m_vy:+.2f}") + group_node.setText(3, f"{m_speed:.2f} @ {m_angle:.0f}°") + + # 3. Update Whole Hand (Mean of ALL 21 points) + # Position Mean + whole_x = sum(p[0] for p in current_pts) / 21 + whole_y = sum(p[1] for p in current_pts) / 21 + + # Velocity Mean + whole_vx = sum(v[0] for v in joint_velocities.values()) / 21 + whole_vy = sum(v[1] for v in joint_velocities.values()) / 21 + whole_speed = (whole_vx**2 + whole_vy**2)**0.5 + whole_angle = math.degrees(math.atan2(whole_vy, whole_vx)) + + root = self.hand_nodes[hand_key] + root.setText(1, f"{whole_x:.3f}, {whole_y:.3f}") # Numerical Mean Pos + root.setText(2, f"{whole_vx:+.2f}, {whole_vy:+.2f}") + root.setText(3, f"{whole_speed:.2f} @ {whole_angle:.0f}°") + + + class MainApplication(QMainWindow): """ Main application window that creates and sets up the UI. 
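# Illustrative sketch of the CSV layout shared by ParticipantProcessor2 (writer) and
# parse_dual_hand_csv (reader) above; not part of the diff. The file name
# demo_dual_hand.csv and the coordinate values are made up; the point is the h0_/h1_
# column prefixes, 21 (x, y) pairs per hand, and the 42-value zero padding used when a
# hand ID is missing in a frame.
import csv
import pandas as pd

header = ['frame']
for h_prefix in ['h0', 'h1']:
    for i in range(21):
        header.extend([f'{h_prefix}_x{i}', f'{h_prefix}_y{i}'])

with open('demo_dual_hand.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(header)
    writer.writerow([1] + [0.50, 0.50] * 21 + [0.60, 0.40] * 21)   # both hands seen
    writer.writerow([2] + [0.51, 0.49] * 21 + [0.0] * 42)          # hand 1 lost -> zero-padded

df = pd.read_csv('demo_dual_hand.csv')
assert 'h0_x0' in df.columns                      # same dual-hand check the validator uses
row = df[df['frame'] == 2].iloc[0]
h1 = [(row[f'h1_x{i}'], row[f'h1_y{i}']) for i in range(21)]
assert all(x == 0.0 and y == 0.0 for x, y in h1)  # lost hand reads back as zeros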
@@ -1447,12 +2044,26 @@ class MainApplication(QMainWindow): self.progress_dialog = None # Mediapipe hands - self.mp_hands = mp.solutions.hands.Hands( - static_image_mode=True, - max_num_hands=2, - min_detection_confidence=0.5 + # --- Mediapipe Hands (Updated for Tasks API) --- + # Ensure 'hand_landmarker.task' is in your project directory + base_options = python.BaseOptions(model_asset_path='hand_landmarker.task') + + # Mapping old parameters to new HandLandmarkerOptions + # static_image_mode=True -> vision.RunningMode.IMAGE + # max_num_hands=2 -> num_hands=2 + # min_detection_confidence -> min_hand_detection_confidence + self.mp_hands_options = vision.HandLandmarkerOptions( + base_options=base_options, + running_mode=vision.RunningMode.IMAGE, + num_hands=2, + min_hand_detection_confidence=0.5, + min_hand_presence_confidence=0.5, + min_tracking_confidence=0.5 ) + # Initialize the detector + self.mp_hands = vision.HandLandmarker.create_from_options(self.mp_hands_options) + # Start local pending update check thread self.local_check_thread = LocalPendingUpdateCheckThread(CURRENT_VERSION, self.platform_suffix) self.local_check_thread.pending_update_found.connect(self.on_pending_update_found) @@ -1638,7 +2249,7 @@ class MainApplication(QMainWindow): file_menu = menu_bar.addMenu("File") file_actions = [ ("Open BORIS file...", "Ctrl+O", self.open_file_dialog, resource_path("icons/file_open_24dp_1F1F1F.svg")), - #("Open Folder...", "Ctrl+Alt+O", self.open_folder_dialog, resource_path("icons/folder_24dp_1F1F1F.svg")), + ("Open Video file...", "Ctrl+Alt+O", self.open_video_file_dialog, resource_path("icons/file_open_24dp_1F1F1F.svg")), #("Open Folders...", "Ctrl+Shift+O", self.open_folder_dialog, resource_path("icons/folder_copy_24dp_1F1F1F.svg")), #("Load Project...", "Ctrl+L", self.load_project, resource_path("icons/article_24dp_1F1F1F.svg")), #("Save Project...", "Ctrl+S", self.save_project, resource_path("icons/save_24dp_1F1F1F.svg")), @@ -1670,11 +2281,13 @@ class MainApplication(QMainWindow): ("Create model from a folder", "Ctrl+I", self.train_model_folder, resource_path("icons/content_copy_24dp_1F1F1F.svg")), ("Test model on a video", "Ctrl+O", self.test_model_video, resource_path("icons/content_paste_24dp_1F1F1F.svg")), ("Test model on a folder", "Ctrl+P", self.test_model_folder, resource_path("icons/content_paste_24dp_1F1F1F.svg")), - ("Test model on a CSV", "Ctrl+P", self.test_model_csv, resource_path("icons/content_paste_24dp_1F1F1F.svg")) + ("Test model on a CSV", "Ctrl+P", self.test_model_csv, resource_path("icons/content_paste_24dp_1F1F1F.svg")), + ("Test CSV on a Video", "Ctrl+P", self.open_validator, resource_path("icons/content_paste_24dp_1F1F1F.svg")) + ] for i, (name, shortcut, slot, icon) in enumerate(model_actions): model_menu.addAction(make_action(name, shortcut, slot, icon=icon)) - if i == 1: + if i == 1 or i == 4: model_menu.addSeparator() # View menu @@ -1972,6 +2585,10 @@ class MainApplication(QMainWindow): # 3. Start the thread self.test_thread.start() + + def open_validator(self): + self.validator_window = HandValidationWindow() + self.validator_window.show() def on_csv_analysis_finished(self, result): if "error" in result: @@ -2089,6 +2706,26 @@ class MainApplication(QMainWindow): + def open_video_file_dialog(self): + file_path, _ = QFileDialog.getOpenFileName( + self, "Open File", "", "Video Files (*.mp4 *.avi *.mov *.mkv *.wmv);;All Files (*)" + ) + if file_path: + # 3. 
Initialize and Start Worker Thread + self.worker_thread = IndividualFileLoadWorker( + file_path, + self.extract_frame_and_hands # Assuming this is a method available to MainApplication + ) + + # 4. Connect Signals to Main Thread Slots + self.worker_thread.observations_loaded.connect(self.on_individual_files_loaded) + + self.worker_thread.loading_finished.connect(self.on_loading_finished) + self.worker_thread.loading_failed.connect(self.on_loading_failed) + + self.worker_thread.start() + + def resolve_path_to_observations(self, boris_file_path): """ Attempts to find the 'Observations' directory based on the BORIS file location. @@ -2148,6 +2785,16 @@ class MainApplication(QMainWindow): self.statusBar().showMessage(f"{self.worker_thread.file_path} loaded.") self.button1.setVisible(True) + + def on_individual_files_loaded(self): + # 1. Update MainApplication state variables + + # 2. Build the UI grid using the data gathered by the worker + self.build_individual_preview_grid(self.worker_thread.previews_data) + + self.statusBar().showMessage(f"{self.worker_thread.file_path} loaded.") + self.button1.setVisible(True) + def on_loading_finished(self): if self.progress_dialog: self.progress_dialog.accept() # Close the dialog if finished successfully @@ -2233,9 +2880,17 @@ class MainApplication(QMainWindow): # Dropdown dropdown = QComboBox() dropdown.addItem("Skip this camera", -1) - if results and results.multi_hand_landmarks: - for idx in range(len(results.multi_hand_landmarks)): + if results and results.hand_landmarks: + num_hands = len(results.hand_landmarks) + + # Individual Hand Options + for idx in range(num_hands): dropdown.addItem(f"Use Hand {idx}", idx) + + # NEW: Add "Both" option if more than 1 hand is detected + if num_hands > 1: + dropdown.addItem("Use Both Hands", 99) # Use 99 as a special flag + row_layout.addWidget(dropdown) # Store dropdown @@ -2263,6 +2918,90 @@ class MainApplication(QMainWindow): + def build_individual_preview_grid(self, previews_data): + group = QGroupBox(f"Participant / Observation: {1}") + grouplayout = QVBoxLayout(group) + + # Participant-level skip dropdown + participant_dropdown = QComboBox() + participant_dropdown.addItem("Skip this participant", 0) + participant_dropdown.addItem("Process this participant", 1) + participant_dropdown.setCurrentIndex(0) # default: skip + if "participant_selection" not in self.selection_widgets: + self.selection_widgets["participant_selection"] = {} + self.selection_widgets["participant_selection"][1] = participant_dropdown + grouplayout.addWidget(QLabel("Participant Option:")) + grouplayout.addWidget(participant_dropdown) + + if 1 not in self.selection_widgets: + self.selection_widgets[1] = {} + + + + # Check if the worker successfully gathered the preview data for this file + state_key = (1, 1) + + # Retrieve the pre-calculated data + state_data = previews_data[state_key] + frame_rgb = state_data["frame_rgb"] + results = state_data["results"] + + # Only draw the overlay if the worker found data + if frame_rgb is not None: + display_img, updated_centroids = self.draw_hand_overlay(frame_rgb, results) + + h, w, _ = display_img.shape + qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888) + pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio) + + # -------- UI Row -------- + row = QWidget() + row_layout = QHBoxLayout(row) + + preview_label = QLabel() + preview_label.setPixmap(pix) + row_layout.addWidget(preview_label) + + # Dropdown + dropdown = QComboBox() + dropdown.addItem("Skip this camera", -1) + 
if results and results.hand_landmarks: + num_hands = len(results.hand_landmarks) + + # Individual Hand Options + for idx in range(num_hands): + dropdown.addItem(f"Use Hand {idx}", idx) + + # NEW: Add "Both" option if more than 1 hand is detected + if num_hands > 1: + dropdown.addItem("Use Both Hands", 99) # Use 99 as a special flag + + + row_layout.addWidget(dropdown) + + # Store dropdown + self.selection_widgets[1][1] = dropdown + + self.camera_state[(1, 1)] = { + "path": state_data["video_path"], + "frame_idx": 0, + "fps": state_data["fps"], + "preview_label": preview_label, + "dropdown": dropdown, + "initial_wrists": state_data["initial_wrists"] + } + + # Skip button + skip_btn = QPushButton("Skip 1s → Rescan") + skip_btn.clicked.connect(lambda _, o=1, c=1: self.skip_and_rescan(o, c)) + row_layout.addWidget(skip_btn) + + grouplayout.addWidget(row) + + self.bubble_layout.addWidget(group) + + #self.bubble_layout.addStretch() + # ============================================ # FRAME EXTRACTION + MEDIA PIPE DETECTION @@ -2277,67 +3016,111 @@ class MainApplication(QMainWindow): return None, None rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - results = self.mp_hands.process(rgb) + # 1. Convert BGR to RGB (OpenCV uses BGR by default) + rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # 2. Wrap the numpy array in a MediaPipe Image object + mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame) + + # 3. Use .detect() instead of .process() + # Note: Since your __init__ uses RunningMode.IMAGE, use .detect() + results = self.mp_hands.detect(mp_image) return rgb, results - + # ============================================ # DRAW OVERLAYED HANDS + BIG LABELS # ============================================ - def draw_hand_overlay(self, img, results, initial_wrists=None): + def draw_hand_overlay(self, img, results, last_known_centroids=None): draw_img = img.copy() h, w, _ = draw_img.shape + COLORS = [(0, 255, 0), (255, 0, 255), (0, 255, 255)] + + current_centroids = {} # We will return this - COLORS = [ - (0, 255, 0), # Hand 0 green - (255, 0, 255), # Hand 1 magenta - (0, 255, 255), # Hand 2 cyan - ] + if results and results.hand_landmarks: + hand_landmarks_list = results.hand_landmarks + hand_mapping = {} + used_indices = set() - if results and results.multi_hand_landmarks: - # Map detected hands to initial wrists if provided - if initial_wrists: - hand_mapping = {} # initial_idx -> current_idx - used_indices = set() - for i, iw in enumerate(initial_wrists): + # 1. 
TRACKING: Match detected hands to the LAST known centroids + if last_known_centroids: + # last_known_centroids should be a dict: {id: (x, y)} + for hand_id, last_pos in last_known_centroids.items(): min_dist = float('inf') best_idx = None - for j, lm in enumerate(results.multi_hand_landmarks): - if j in used_indices: - continue # skip already assigned - wrist = lm.landmark[0] - dist = (wrist.x - iw[0])**2 + (wrist.y - iw[1])**2 + + for j, lm_list in enumerate(hand_landmarks_list): + if j in used_indices: continue + + current_c = self.get_centroid(lm_list) + dist = (current_c[0] - last_pos[0])**2 + (current_c[1] - last_pos[1])**2 + if dist < min_dist: min_dist = dist best_idx = j - if best_idx is not None: - hand_mapping[i] = best_idx + + if best_idx is not None and min_dist < 0.1: # Threshold to prevent "teleporting" + hand_mapping[hand_id] = best_idx used_indices.add(best_idx) - else: - hand_mapping = {i: i for i in range(len(results.multi_hand_landmarks))} - for initial_idx, current_idx in hand_mapping.items(): - lm_obj = results.multi_hand_landmarks[current_idx] - mp.solutions.drawing_utils.draw_landmarks( - draw_img, - lm_obj, - mp.solutions.hands.HAND_CONNECTIONS - ) - wrist = lm_obj.landmark[0] + # 2. DISCOVERY: If a hand wasn't matched (or no last_known), assign it a new ID + for j in range(len(hand_landmarks_list)): + if j not in used_indices: + new_id = 0 + while new_id in hand_mapping or (last_known_centroids and new_id in last_known_centroids): + new_id += 1 + hand_mapping[new_id] = j + used_indices.add(j) + + + # Define connections manually (since mp.solutions.hands.HAND_CONNECTIONS is gone) + HAND_CONNECTIONS = [ + (0, 1), (1, 2), (2, 3), (3, 4), # Thumb + (0, 5), (5, 6), (6, 7), (7, 8), # Index + (9, 10), (10, 11), (11, 12), # Middle + (13, 14), (14, 15), (15, 16), # Ring + (0, 17), (17, 18), (18, 19), (19, 20), # Pinky + (5, 9), (9, 13), (13, 17) # Palm + ] + + for hand_id, current_idx in hand_mapping.items(): + lm_list = hand_landmarks_list[current_idx] + color = COLORS[hand_id % len(COLORS)] + + current_centroids[hand_id] = self.get_centroid(lm_list) + + # Change 3: Manual Drawing (since solutions.drawing_utils is removed) + # Draw Connections + for connection in HAND_CONNECTIONS: + start_lm = lm_list[connection[0]] + end_lm = lm_list[connection[1]] + cv2.line(draw_img, + (int(start_lm.x * w), int(start_lm.y * h)), + (int(end_lm.x * w), int(end_lm.y * h)), + color, 2) + + # Draw Landmarks + for lm in lm_list: + cv2.circle(draw_img, (int(lm.x * w), int(lm.y * h)), 5, (255, 255, 255), -1) + + # Draw Label + wrist = lm_list[0] wx, wy = int(wrist.x * w), int(wrist.y * h) - color = COLORS[initial_idx % len(COLORS)] - + # Big outline - cv2.putText(draw_img, str(initial_idx), (wx, wy - 40), + cv2.putText(draw_img, str(hand_id), (wx, wy - 40), cv2.FONT_HERSHEY_SIMPLEX, 3.2, (0, 0, 0), 10, cv2.LINE_AA) - # Big colored label - cv2.putText(draw_img, str(initial_idx), (wx, wy - 40), + cv2.putText(draw_img, str(hand_id), (wx, wy - 40), cv2.FONT_HERSHEY_SIMPLEX, 3.2, color, 6, cv2.LINE_AA) - return draw_img + combined_centroids = last_known_centroids.copy() if last_known_centroids else {} + combined_centroids.update(current_centroids) + + return draw_img, combined_centroids # ============================================ @@ -2347,31 +3130,125 @@ class MainApplication(QMainWindow): key = (obs_id, cam_id) state = self.camera_state[key] - fps = state["fps"] - state["frame_idx"] += int(fps) + video_path = state["path"] + start_frame = state["frame_idx"] + fps = int(state.get("fps", 60)) + 
end_frame = start_frame + fps - rgb, results = self.extract_frame_and_hands( - state["path"], state["frame_idx"] - ) - if rgb is None: + # 1. Initialize Video Capture + cap = cv2.VideoCapture(video_path) + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + + # Get the current tracking 'anchors' + # We ensure it's a dict for our 'sticky' logic + last_centroids = state.get("initial_centroids", {}) + if isinstance(last_centroids, list): + last_centroids = {i: pos for i, pos in enumerate(last_centroids)} + + # 2. BRIDGE THE GAP: Process intermediate frames at high speed + # We skip drawing to save processing time + for f_idx in range(start_frame, end_frame): + ret, frame = cap.read() + if not ret: + break + + # Convert and detect + rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame) + results = self.mp_hands.detect(mp_image) + + # Update the centroids based on this frame's movement + # This keeps the IDs 'locked' to the moving hands + last_centroids = self.update_tracking_only(results, last_centroids) + + # 3. FINAL FRAME: Read the destination frame and display it + # Note: cap.read() has already moved to end_frame after the loop + ret, final_frame = cap.read() + cap.release() + + if not ret: return - display_img = self.draw_hand_overlay(rgb, results, state.get("initial_wrists")) + rgb_final = cv2.cvtColor(final_frame, cv2.COLOR_BGR2RGB) + mp_final = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_final) + final_results = self.mp_hands.detect(mp_final) + + # Draw the final overlay using our 'bridged' centroids + display_img, final_centroids = self.draw_hand_overlay(rgb_final, final_results, last_centroids) + + # Update State for UI and future tracking + state["frame_idx"] = end_frame + 1 + state["initial_centroids"] = final_centroids + + # Update UI Pixmap h, w, _ = display_img.shape qimg = QImage(display_img.data, w, h, 3 * w, QImage.Format_RGB888) pix = QPixmap.fromImage(qimg).scaled(350, 350, Qt.KeepAspectRatio) - - # Update preview state["preview_label"].setPixmap(pix) - # Update dropdown + # Update Dropdown dropdown = state["dropdown"] dropdown.clear() dropdown.addItem("Skip this camera", -1) - if results and results.multi_hand_landmarks: - for idx in range(len(results.multi_hand_landmarks)): + if final_results and final_results.hand_landmarks: + for idx in range(len(final_results.hand_landmarks)): dropdown.addItem(f"Use Hand {idx}", idx) + # NEW: Add "Both" option if more than 1 hand is detected + if len(final_results.hand_landmarks) > 1: + dropdown.addItem("Use Both Hands", 99) # Use 99 as a special flag + def get_centroid(self, lm_list): + """Calculates the center of mass for all 21 hand landmarks.""" + avg_x = sum(lm.x for lm in lm_list) / len(lm_list) + avg_y = sum(lm.y for lm in lm_list) / len(lm_list) + return (avg_x, avg_y) + + def update_tracking_only(self, results, last_known): + """Updates hand positions without drawing. 
Perfect for high-speed bridging.""" + if not results or not results.hand_landmarks: + # If tracking is lost this frame, return last known positions + return last_known + + detected_hands = results.hand_landmarks + new_centroids = {} + used_indices = set() + + # Match current detections to last known IDs + for hand_id, last_pos in last_known.items(): + min_dist = float('inf') + best_idx = None + + for j, lm_list in enumerate(detected_hands): + if j in used_indices: + continue + + curr_c = self.get_centroid(lm_list) + # Distance check (squared) + dist = (curr_c[0] - last_pos[0])**2 + (curr_c[1] - last_pos[1])**2 + + # Since frames are 1/60th of a sec apart, hands shouldn't move more than 5% + if dist < 0.05 and dist < min_dist: + min_dist = dist + best_idx = j + + if best_idx is not None: + new_centroids[hand_id] = self.get_centroid(detected_hands[best_idx]) + used_indices.add(best_idx) + + # Carry over 'ghost' positions for hands that were not found this frame + for hand_id, pos in last_known.items(): + if hand_id not in new_centroids: + new_centroids[hand_id] = pos + + # If new hands appear that weren't tracked before, add them + for j in range(len(detected_hands)): + if j not in used_indices: + new_id = 0 + while new_id in new_centroids: + new_id += 1 + new_centroids[new_id] = self.get_centroid(detected_hands[j]) + + return new_centroids # def save_project(self, onCrash=False): @@ -2641,7 +3518,7 @@ class MainApplication(QMainWindow): self.bubble_layout.addWidget(row, row_index, 0, 1, 1) # Save widgets for thread updates - state_key = (obs_id, cam_id) + state_key = (str(obs_id), str(cam_id)) self.processing_widgets[state_key] = { "label": filename_label, "progress": progress_bar, @@ -2650,14 +3527,13 @@ class MainApplication(QMainWindow): } # --- THREAD START --- output_csv = f"{obs_id}_{cam_id}_processed.csv" - + full_video_path = file_info["path"] # ⚠️ Ensure ParticipantProcessor is defined to accept self.observations_root - processor = ParticipantProcessor( + processor = ParticipantProcessor2( obs_id=obs_id, - boris_json=self.boris, selected_cam_id=cam_id, selected_hand_idx=selected_hand_idx, - observations_root=self.observations_root, # CRITICAL PATH FIX + video_path=full_video_path, # CRITICAL PATH FIX hide_preview=False, # Assuming default for now output_csv=output_csv, output_dir=self.output_dir, @@ -2666,13 +3542,13 @@ class MainApplication(QMainWindow): # Connect signals to the UI elements processor.progress_updated.connect(progress_bar.setValue) - processor.finished_processing.connect( - lambda obs_id=obs_id, cam_id=cam_id: self.on_processing_finished(obs_id, cam_id) - ) + # Replace your current connection with this: + processor.finished_processing.connect(self.on_processing_finished) processor.time_updated.connect(time_label.setText) # Connects to THIS time_label - processor.start() self.processing_threads.append(processor) + processor.start() + # Set the vertical stretch to push all progress bars to the top num_rows = self.bubble_layout.rowCount() self.bubble_layout.setRowStretch(num_rows, 1)
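# Framework-free restatement (illustrative only, not part of the diff) of the greedy
# centroid matching that ParticipantProcessor2.run, update_tracking_only and
# draw_hand_overlay each implement inline: keep existing hand IDs by nearest-centroid
# matching under a squared-distance gate (0.1 in most places, 0.05 for the high-speed
# bridge), give unmatched detections the lowest free ID, and carry lost hands forward
# as "ghosts". Operates on plain (x, y) tuples so the ID-assignment rules can be
# unit-tested without MediaPipe or OpenCV.
def match_hands(detected_centroids, last_known, max_sq_dist=0.1):
    """Return {hand_id: (x, y)} for this frame given last frame's {hand_id: (x, y)}."""
    new_map, used = {}, set()
    # 1. Tracking: re-use existing IDs via nearest centroid within the gate.
    for hand_id, last_pos in last_known.items():
        best_idx, best_dist = None, float('inf')
        for j, c in enumerate(detected_centroids):
            if j in used:
                continue
            d = (c[0] - last_pos[0]) ** 2 + (c[1] - last_pos[1]) ** 2
            if d < max_sq_dist and d < best_dist:
                best_idx, best_dist = j, d
        if best_idx is not None:
            new_map[hand_id] = detected_centroids[best_idx]
            used.add(best_idx)
    # 2. Discovery: unmatched detections get the lowest ID not already taken.
    for j, c in enumerate(detected_centroids):
        if j not in used:
            new_id = 0
            while new_id in new_map or new_id in last_known:
                new_id += 1
            new_map[new_id] = c
    # 3. Ghosts: hands lost this frame keep their last known position.
    for hand_id, pos in last_known.items():
        new_map.setdefault(hand_id, pos)
    return new_map

# Example: one detection near hand 0, nothing near hand 1.
# match_hands([(0.31, 0.52)], {0: (0.30, 0.50), 1: (0.70, 0.40)})
# -> {0: (0.31, 0.52), 1: (0.70, 0.40)}   (hand 1 carried over as a ghost)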