"""
Filename: batch_processing.py
Description: BLAZES batch processor

Author: Tyler de Zeeuw
License: GPL-3.0
"""

# Built-in imports
import os
import csv
from pathlib import Path

# External library imports
import cv2
import numpy as np
from ultralytics import YOLO
from PySide6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QLineEdit,
                               QPushButton, QComboBox, QSpinBox, QLabel,
                               QFileDialog, QTextEdit)
from PySide6.QtCore import Qt, QObject, Signal, QRunnable, QThreadPool, Slot


# Keypoint labels used to build the CSV header (three columns per joint).
# NOTE(review): presumably matches the keypoint order emitted by the pose
# model -- confirm against the model documentation.
JOINT_NAMES = [
    "Nose", "Left Eye", "Right Eye", "Left Ear", "Right Ear",
    "Left Shoulder", "Right Shoulder", "Left Elbow", "Right Elbow",
    "Left Wrist", "Right Wrist", "Left Hip", "Right Hip",
    "Left Knee", "Right Knee", "Left Ankle", "Right Ankle"
]


class WorkerSignals(QObject):
    """Signals to communicate back to the UI from the thread pool."""

    # Human-readable status line for the log.
    progress = Signal(str)
    # Emitted exactly once per worker, whether it succeeded or failed.
    finished = Signal()


class VideoWorker(QRunnable):
    """A worker task for processing a single video file.

    The worker runs pose tracking on every frame of the video and writes one
    CSV row per frame next to the video as ``<stem>_pose_raw.csv``. If that
    file already exists, inference is skipped.
    """

    def __init__(self, video_path, model_name):
        """Store the job parameters.

        Args:
            video_path: Path of the video file to process.
            model_name: Model weights name/path handed to ``YOLO(...)``.
        """
        super().__init__()
        self.video_path = video_path
        self.model_name = model_name
        self.signals = WorkerSignals()

    @Slot()
    def run(self):
        """Process the video and report status through ``self.signals``.

        Any exception is reported as an ``ERROR on ...`` progress message
        rather than propagated; ``finished`` is always emitted last.
        """
        filename = os.path.basename(self.video_path)
        self.signals.progress.emit(f"Starting: {filename}")

        try:
            source = Path(self.video_path)
            pose_cache = str(source.with_name(source.stem + "_pose_raw.csv"))

            # Check the cache before touching the video so a cached file
            # never pays for opening a capture.
            if os.path.exists(pose_cache):
                self.signals.progress.emit(
                    f" - Cache found for {filename}. Skipping inference.")
            else:
                frames = self._run_inference(filename)
                self._save_pose_cache(pose_cache, frames)

            self.signals.progress.emit(f"COMPLETED: {filename}")
        except Exception as e:
            # Top-level boundary of the worker: surface the failure in the
            # UI log instead of letting it escape the pool thread.
            self.signals.progress.emit(f"ERROR on {filename}: {e}")
        finally:
            self.signals.finished.emit()

    def _run_inference(self, filename):
        """Run pose tracking over the video and return per-frame arrays.

        Args:
            filename: Base name of the video, used only in progress messages.

        Returns:
            A non-empty list with one ``(len(JOINT_NAMES), 3)`` array per
            frame holding x, y and confidence per joint; all zeros for
            frames where no subject was selected.

        Raises:
            RuntimeError: If the video cannot be opened or yields no frames.
                Raising here prevents an empty cache file from being written,
                which would make later runs skip this video.
        """
        cap = cv2.VideoCapture(self.video_path)
        try:
            if not cap.isOpened():
                raise RuntimeError("could not open video")

            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Some containers report no frame count; in that case read until
            # the capture stops returning frames instead of reading nothing.
            frame_limit = total_frames if total_frames > 0 else None

            # Instantiate model INSIDE the thread for safety
            model = YOLO(self.model_name)
            prev_track_id = None
            frames = []

            while frame_limit is None or len(frames) < frame_limit:
                ret, frame = cap.read()
                if not ret:
                    break

                index = len(frames)
                results = model.track(frame, persist=True, verbose=False)
                track_id, kp, confs, _ = self.get_best_infant_match(
                    results, width, height, prev_track_id)

                if kp is not None:
                    prev_track_id = track_id
                    frames.append(np.column_stack((kp, confs)))
                else:
                    # Keep one row per frame so row index == frame index.
                    frames.append(np.zeros((len(JOINT_NAMES), 3)))

                if index % 100 == 0:
                    if frame_limit is None:
                        self.signals.progress.emit(
                            f" - {filename}: {index} frames")
                    else:
                        p = int((index / frame_limit) * 100)
                        self.signals.progress.emit(f" - {filename}: {p}%")
        finally:
            # Release the capture on every path, including exceptions.
            cap.release()

        if not frames:
            raise RuntimeError("no frames could be read")
        return frames

    def get_best_infant_match(self, results, w, h, prev_track_id):
        """Pick the tracked detection that best matches the target subject.

        Each detection is scored as ``10 * (keypoints with conf > 0.5)``
        minus ``0.1 * distance`` of those keypoints' mean from the frame
        centre, plus 50 if its track id equals ``prev_track_id``.

        Args:
            results: Result list returned by ``model.track``; only
                ``results[0]`` is inspected.
            w: Frame width in pixels.
            h: Frame height in pixels.
            prev_track_id: Track id chosen on the previous frame, or None.

        Returns:
            ``(track_id, keypoints_xy, keypoint_confs, index)`` for the best
            detection, or ``(None, None, None, None)`` when there are no
            tracked boxes or no detection scores above -1.
        """
        if not results[0].boxes or results[0].boxes.id is None:
            return None, None, None, None

        ids = results[0].boxes.id.int().cpu().tolist()
        kpts = results[0].keypoints.xy.cpu().numpy()
        confs = results[0].keypoints.conf.cpu().numpy()

        # NOTE(review): the -1 starting score also acts as a floor, so a
        # detection scoring <= -1 is never selected even if it is the only
        # one. Confirm this rejection is intended.
        best_idx, best_score = -1, -1
        for i, k in enumerate(kpts):
            confident = confs[i] > 0.5
            vis = np.sum(confident)
            valid = k[confident]
            if len(valid) > 0:
                dist = np.linalg.norm(np.mean(valid, axis=0) - [w/2, h/2])
            else:
                # No confident keypoints: apply a fixed distance penalty.
                dist = 1000
            score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0)
            if score > best_score:
                best_score, best_idx = score, i

        if best_idx == -1:
            return None, None, None, None
        return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx

    def _save_pose_cache(self, path, data):
        """Write the per-frame keypoint arrays to a CSV file.

        Args:
            path: Destination CSV path (overwritten if present).
            data: Iterable of per-frame arrays; each is flattened into one
                row under ``<joint>_x, <joint>_y, <joint>_conf`` columns.
        """
        header = []
        for joint in JOINT_NAMES:
            header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"])

        with open(path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(frame_data.flatten() for frame_data in data)


class BatchProcessorDialog(QDialog):
    """Dialog that queues one ``VideoWorker`` per video in a chosen folder."""

    def __init__(self, parent=None):
        """Create the dialog, its thread pool and its widgets."""
        super().__init__(parent)
        self.setWindowTitle("Concurrent Batch Video Processor")
        self.setMinimumSize(600, 500)

        self.threadpool = QThreadPool()
        # Number of queued workers that have not yet emitted ``finished``.
        self.active_workers = 0

        self.setup_ui()

    def setup_ui(self):
        """Build the folder picker, model/concurrency inputs, log and button."""
        layout = QVBoxLayout(self)

        # Folder selection
        f_lay = QHBoxLayout()
        self.folder_edit = QLineEdit()
        btn_browse = QPushButton("Browse")
        btn_browse.clicked.connect(self.get_folder)
        f_lay.addWidget(QLabel("Folder:"))
        f_lay.addWidget(self.folder_edit)
        f_lay.addWidget(btn_browse)
        layout.addLayout(f_lay)

        # Dropdown for YOLO Model
        self.model_combo = QComboBox()
        # self.model_combo.addItems([
        #     "yolov8n-pose.pt", "yolov8s-pose.pt",
        #     "yolov8m-pose.pt", "yolov8l-pose.pt", "yolov8x-pose.pt"
        # ])
        self.model_combo.addItems(["yolov8n-pose.pt"])
        layout.addWidget(QLabel("Select YOLO Model:"))
        layout.addWidget(self.model_combo)

        # Int input for Concurrency
        self.concurrency_spin = QSpinBox()
        self.concurrency_spin.setRange(1, 16)
        self.concurrency_spin.setValue(4)
        layout.addWidget(QLabel("Concurrent Workers (Threads):"))
        layout.addWidget(self.concurrency_spin)

        # Output Log
        self.log = QTextEdit()
        self.log.setReadOnly(True)
        self.log.setStyleSheet("background-color: #1e1e1e; color: #d4d4d4; font-family: Consolas;")
        layout.addWidget(self.log)

        # Action Buttons
        self.btn_run = QPushButton("Run Concurrent Batch")
        self.btn_run.setFixedHeight(40)
        self.btn_run.clicked.connect(self.run_batch)
        layout.addWidget(self.btn_run)

    def get_folder(self):
        """Let the user pick a folder and copy it into the folder field."""
        path = QFileDialog.getExistingDirectory(self, "Select Folder")
        if path:
            self.folder_edit.setText(path)

    def log_msg(self, msg):
        """Append ``msg`` to the log view and scroll to the bottom."""
        self.log.append(msg)
        scroll_bar = self.log.verticalScrollBar()
        scroll_bar.setValue(scroll_bar.maximum())

    def worker_finished(self):
        """Count down finished workers; re-enable the run button at zero."""
        self.active_workers -= 1
        if self.active_workers <= 0:
            self.btn_run.setEnabled(True)
            self.log_msg("--- All concurrent tasks finished ---")

    def run_batch(self):
        """Queue a worker for every video file found in the selected folder.

        Logs an error and returns if the folder is invalid or contains no
        files with a recognised video extension. The run button stays
        disabled until every queued worker has reported ``finished``.
        """
        folder = self.folder_edit.text()
        if not os.path.isdir(folder):
            self.log_msg("!!! Error: Invalid directory.")
            return

        video_extensions = ('.mp4', '.avi', '.mov', '.mkv')
        # Only regular files: a directory named like "clip.mp4" is not a
        # video. Sorted so the queue order is stable between runs.
        files = sorted(
            str(f) for f in Path(folder).iterdir()
            if f.is_file() and f.suffix.lower() in video_extensions
        )

        if not files:
            self.log_msg("No videos found.")
            return

        self.btn_run.setEnabled(False)
        self.active_workers = len(files)
        self.threadpool.setMaxThreadCount(self.concurrency_spin.value())

        self.log_msg(f"Queueing {len(files)} videos...")

        model_name = self.model_combo.currentText()
        for f_path in files:
            worker = VideoWorker(f_path, model_name)
            worker.signals.progress.connect(self.log_msg)
            worker.signals.finished.connect(self.worker_finished)
            self.threadpool.start(worker)