211 lines
7.7 KiB
Python
211 lines
7.7 KiB
Python
"""
|
|
Filename: batch_processing.py
|
|
Description: BLAZES batch processor
|
|
|
|
Author: Tyler de Zeeuw
|
|
License: GPL-3.0
|
|
"""
|
|
|
|
# Built-in imports
|
|
import os
|
|
import csv
|
|
from pathlib import Path
|
|
|
|
# External library imports
|
|
import cv2
|
|
import numpy as np
|
|
from ultralytics import YOLO
|
|
|
|
from PySide6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QLineEdit,
|
|
QPushButton, QComboBox, QSpinBox, QLabel,
|
|
QFileDialog, QTextEdit)
|
|
from PySide6.QtCore import Qt, QObject, Signal, QRunnable, QThreadPool, Slot
|
|
|
|
|
|
# Names of the 17 pose keypoints, assumed to match the pose model's keypoint
# index order (COCO-style layout: nose, then left/right pairs of each body
# part, head to feet) -- confirm if a different model family is used.
# Each name expands to three CSV columns (<name>_x, <name>_y, <name>_conf)
# in the pose cache written by the batch workers.
JOINT_NAMES = ["Nose"] + [
    f"{side} {part}"
    for part in ("Eye", "Ear", "Shoulder", "Elbow", "Wrist", "Hip", "Knee", "Ankle")
    for side in ("Left", "Right")
]
|
|
|
|
class WorkerSignals(QObject):
    """Qt signal carrier for a pool task.

    QRunnable is not a QObject and cannot declare signals itself, so each
    worker owns one of these and emits through it; the dialog connects its
    slots to these signals before the task is started.
    """

    # Human-readable status line destined for the dialog's log view.
    progress = Signal(str)

    # Emitted exactly once when the task ends, whether it succeeded or failed.
    finished = Signal()
|
|
|
|
class VideoWorker(QRunnable):
    """Thread-pool task that runs pose tracking on a single video file.

    For a video ``<dir>/<stem>.<ext>`` the per-frame keypoints of the selected
    subject are written to ``<dir>/<stem>_pose_raw.csv``: a header row, then
    one row per decoded frame holding ``(x, y, conf)`` for each joint in
    ``JOINT_NAMES``. Frames with no selected subject are written as zeros.
    If that CSV already exists, inference is skipped entirely.

    Status text and completion are reported through ``self.signals``.
    """

    def __init__(self, video_path, model_name):
        """Store the task inputs.

        Args:
            video_path: Path of the video file to process.
            model_name: Model name/path passed to ``YOLO(...)``.
        """
        super().__init__()
        self.video_path = video_path
        self.model_name = model_name
        # Created up front so the caller can connect slots before start().
        self.signals = WorkerSignals()

    @Slot()
    def run(self):
        """Process the video, emitting progress text.

        Any exception is reported as an ``ERROR on ...`` progress message
        rather than propagated. ``signals.finished`` is always emitted last.
        """
        filename = os.path.basename(self.video_path)
        self.signals.progress.emit(f"Starting: {filename}")

        try:
            pose_cache = self._pose_cache_path()
            if os.path.exists(pose_cache):
                self.signals.progress.emit(f" - Cache found for {filename}. Skipping inference.")
            else:
                rows = self._run_inference(filename)
                self._save_pose_cache(pose_cache, rows)
            self.signals.progress.emit(f"COMPLETED: {filename}")
        except Exception as e:
            # Top-level boundary of a pool thread: report and keep the batch going.
            self.signals.progress.emit(f"ERROR on {filename}: {e}")
        finally:
            self.signals.finished.emit()

    def _pose_cache_path(self):
        """Return the cache CSV path: ``<stem>_pose_raw.csv`` next to the video."""
        source = Path(self.video_path)
        return str(source.with_name(source.stem + "_pose_raw.csv"))

    def _run_inference(self, filename):
        """Track the best-matching subject through every frame.

        Returns:
            A non-empty list with one ``(len(JOINT_NAMES), 3)`` array per
            decoded frame.

        Raises:
            OSError: If the video cannot be opened or yields no frames. This
                prevents an empty cache that later runs would trust and skip.
        """
        cap = cv2.VideoCapture(self.video_path)
        try:
            if not cap.isOpened():
                raise OSError(f"could not open video {self.video_path}")

            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # Container metadata; it may be approximate or 0, so it only drives
            # progress reporting. Frames are read until the decoder runs out.
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Instantiate the model inside the pool thread so each worker owns
            # its own model and tracker state (track(persist=True)).
            model = YOLO(self.model_name)

            rows = []
            prev_track_id = None
            frame_idx = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                results = model.track(frame, persist=True, verbose=False)
                track_id, kp, confs, _ = self.get_best_infant_match(results, width, height, prev_track_id)

                if kp is not None:
                    # Remember the chosen track so it is favoured next frame.
                    prev_track_id = track_id
                    rows.append(np.column_stack((kp, confs)))
                else:
                    rows.append(np.zeros((len(JOINT_NAMES), 3)))

                if frame_idx % 100 == 0 and total_frames > 0:
                    p = min(int((frame_idx / total_frames) * 100), 100)
                    self.signals.progress.emit(f" - {filename}: {p}%")
                frame_idx += 1
        finally:
            cap.release()

        if not rows:
            raise OSError(f"no frames could be decoded from {self.video_path}")
        return rows

    def get_best_infant_match(self, results, w, h, prev_track_id, min_score=-1, conf_threshold=0.5):
        """Choose the tracked detection most likely to be the subject.

        Each detection ``i`` is scored as::

            10 * (number of keypoints with conf > conf_threshold)
            - 0.1 * (distance from the frame centre to the mean of those
                     keypoints, or 1000 when there are none)
            + 50 if its track id equals ``prev_track_id``

        Only a score strictly greater than ``min_score`` can be selected.
        NOTE(review): with the default of -1, a detection with few confident
        joints far from the centre is rejected and the frame gets zeros.
        Confirm that this threshold is intended; pass ``float("-inf")`` to
        always take the best detection.

        Args:
            results: Output of ``model.track(...)``; only ``results[0]`` is used.
            w, h: Frame width and height in pixels.
            prev_track_id: Track id chosen on the previous frame, or None.
            min_score: Score a detection must exceed to be selected.
            conf_threshold: Keypoint confidence above which a joint counts as visible.

        Returns:
            ``(track_id, keypoints_xy, keypoint_confs, index)`` for the chosen
            detection, or ``(None, None, None, None)`` if there is none.
        """
        res = results[0]
        if not res.boxes or res.boxes.id is None:
            return None, None, None, None

        ids = res.boxes.id.int().cpu().tolist()
        kpts = res.keypoints.xy.cpu().numpy()
        confs = res.keypoints.conf.cpu().numpy()
        centre = np.array([w / 2, h / 2])

        best_idx, best_score = -1, min_score
        for i, (k, c) in enumerate(zip(kpts, confs)):
            visible = c > conf_threshold
            vis = np.sum(visible)
            valid = k[visible]
            dist = np.linalg.norm(np.mean(valid, axis=0) - centre) if len(valid) > 0 else 1000
            score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0)

            if score > best_score:
                best_score, best_idx = score, i

        if best_idx == -1:
            return None, None, None, None
        return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx

    def _save_pose_cache(self, path, data):
        """Write per-frame pose rows to ``path`` as CSV.

        The header holds ``<joint>_x, <joint>_y, <joint>_conf`` for each name in
        ``JOINT_NAMES``, followed by one flattened row per frame. Data goes to a
        uniquely named temporary sibling first and is then moved into place.
        An interrupted write therefore never leaves a truncated file that later
        runs would mistake for a complete cache.
        """
        import uuid

        header = [f"{joint}_{field}" for joint in JOINT_NAMES for field in ("x", "y", "conf")]
        tmp_path = f"{path}.{uuid.uuid4().hex}.tmp"
        try:
            with open(tmp_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow(header)
                writer.writerows(frame_data.flatten() for frame_data in data)
            os.replace(tmp_path, path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
|
|
|
|
class BatchProcessorDialog(QDialog):
    """Dialog that runs pose processing over every video in a folder.

    Each video is wrapped in a ``VideoWorker`` and queued on a private
    ``QThreadPool``. The pool size comes from the "Concurrent Workers" spin
    box. Worker progress is appended to the log view. The Run button stays
    disabled until every queued worker reports ``finished``.
    """

    # File suffixes (compared case-insensitively) treated as videos.
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')

    # Models offered in the dropdown. Larger variants (yolov8s/m/l/x-pose.pt)
    # were previously listed here but are currently disabled.
    MODEL_CHOICES = ("yolov8n-pose.pt",)

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Concurrent Batch Video Processor")
        self.setMinimumSize(600, 500)
        self.threadpool = QThreadPool()
        # Number of queued workers that have not yet emitted `finished`.
        self.active_workers = 0
        self.setup_ui()

    def setup_ui(self):
        """Build the folder picker, model/concurrency inputs, log view and Run button."""
        layout = QVBoxLayout(self)

        # Folder selection
        f_lay = QHBoxLayout()
        self.folder_edit = QLineEdit()
        btn_browse = QPushButton("Browse")
        btn_browse.clicked.connect(self.get_folder)
        f_lay.addWidget(QLabel("Folder:"))
        f_lay.addWidget(self.folder_edit)
        f_lay.addWidget(btn_browse)
        layout.addLayout(f_lay)

        # Dropdown for YOLO Model
        self.model_combo = QComboBox()
        self.model_combo.addItems(list(self.MODEL_CHOICES))
        layout.addWidget(QLabel("Select YOLO Model:"))
        layout.addWidget(self.model_combo)

        # Int input for Concurrency
        self.concurrency_spin = QSpinBox()
        self.concurrency_spin.setRange(1, 16)
        self.concurrency_spin.setValue(4)
        layout.addWidget(QLabel("Concurrent Workers (Threads):"))
        layout.addWidget(self.concurrency_spin)

        # Output Log
        self.log = QTextEdit()
        self.log.setReadOnly(True)
        self.log.setStyleSheet("background-color: #1e1e1e; color: #d4d4d4; font-family: Consolas;")
        layout.addWidget(self.log)

        # Action Buttons
        self.btn_run = QPushButton("Run Concurrent Batch")
        self.btn_run.setFixedHeight(40)
        self.btn_run.clicked.connect(self.run_batch)
        layout.addWidget(self.btn_run)

    def get_folder(self):
        """Let the user pick a directory and put it in the folder field."""
        path = QFileDialog.getExistingDirectory(self, "Select Folder")
        if path:
            self.folder_edit.setText(path)

    def log_msg(self, msg):
        """Append ``msg`` to the log view and keep it scrolled to the bottom."""
        self.log.append(msg)
        scrollbar = self.log.verticalScrollBar()
        scrollbar.setValue(scrollbar.maximum())

    def worker_finished(self):
        """Count down finished workers; re-enable Run once all are done."""
        self.active_workers -= 1
        if self.active_workers <= 0:
            self.btn_run.setEnabled(True)
            self.log_msg("--- All concurrent tasks finished ---")

    def _find_videos(self, folder):
        """Return sorted paths of regular files in ``folder`` with a video suffix.

        Only the first file per stem is kept. Workers derive their cache file
        from the stem (``<stem>_pose_raw.csv``), so e.g. ``a.mp4`` and ``a.avi``
        would otherwise write the same file concurrently. Skipped files are
        logged.
        """
        videos = []
        seen_stems = {}
        for entry in sorted(Path(folder).iterdir()):
            # Directories can carry a video-like suffix too; only take files.
            if not entry.is_file() or entry.suffix.lower() not in self.VIDEO_EXTENSIONS:
                continue
            if entry.stem in seen_stems:
                self.log_msg(f"Skipping {entry.name}: its pose cache would collide with {seen_stems[entry.stem]}.")
                continue
            seen_stems[entry.stem] = entry.name
            videos.append(str(entry))
        return videos

    def run_batch(self):
        """Queue one ``VideoWorker`` per video in the chosen folder."""
        folder = self.folder_edit.text()
        if not os.path.isdir(folder):
            self.log_msg("!!! Error: Invalid directory.")
            return

        files = self._find_videos(folder)
        if not files:
            self.log_msg("No videos found.")
            return

        self.btn_run.setEnabled(False)
        # Set the full count before starting any worker so an early `finished`
        # cannot drive the counter to zero prematurely.
        self.active_workers = len(files)
        self.threadpool.setMaxThreadCount(self.concurrency_spin.value())
        model_name = self.model_combo.currentText()

        self.log_msg(f"Queueing {len(files)} videos...")
        for f_path in files:
            worker = VideoWorker(f_path, model_name)
            worker.signals.progress.connect(self.log_msg)
            worker.signals.finished.connect(self.worker_finished)
            self.threadpool.start(worker)