Files
blazes/batch_processing.py
2026-03-13 20:53:20 -07:00

211 lines
7.7 KiB
Python

"""
Filename: batch_processing.py
Description: BLAZES batch processor
Author: Tyler de Zeeuw
License: GPL-3.0
"""
# Built-in imports
import os
import csv
from pathlib import Path
# External library imports
import cv2
import numpy as np
from ultralytics import YOLO
from PySide6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QLineEdit,
QPushButton, QComboBox, QSpinBox, QLabel,
QFileDialog, QTextEdit)
from PySide6.QtCore import Qt, QObject, Signal, QRunnable, QThreadPool, Slot
# Names of the 17 body joints, in the order their columns are written to the
# pose cache CSV: VideoWorker._save_pose_cache emits "<name>_x", "<name>_y",
# "<name>_conf" for each entry, in list order.
# NOTE(review): presumably this matches the keypoint order produced by the
# YOLO pose model — confirm against the model documentation.
JOINT_NAMES = [
"Nose", "Left Eye", "Right Eye", "Left Ear", "Right Ear",
"Left Shoulder", "Right Shoulder", "Left Elbow", "Right Elbow",
"Left Wrist", "Right Wrist", "Left Hip", "Right Hip",
"Left Knee", "Right Knee", "Left Ankle", "Right Ankle"
]
class WorkerSignals(QObject):
    """Signal holder that lets a thread-pool worker report back to the UI.

    VideoWorker creates one instance per task and emits on it from ``run``.
    """

    # Carries a single text line (status, progress or error message).
    progress = Signal(str)
    # Carries no payload; VideoWorker.run emits it when the task ends.
    finished = Signal()
class VideoWorker(QRunnable):
    """Thread-pool task that extracts per-frame pose keypoints from one video.

    The result is written next to the video as ``<stem>_pose_raw.csv``. When
    that file already exists the video is not opened or processed again.

    Status and error text is reported through ``signals.progress``;
    ``signals.finished`` is emitted when ``run`` ends, on success or failure.
    """

    def __init__(self, video_path, model_name):
        """Store the job parameters.

        Args:
            video_path: Path of the video file to process.
            model_name: Weights name or path handed to ``YOLO``.
        """
        super().__init__()
        self.video_path = video_path
        self.model_name = model_name
        self.signals = WorkerSignals()

    @Slot()
    def run(self):
        """Process the video and write its pose cache, unless one exists.

        Never raises: any exception is reported as an ``ERROR on ...`` line on
        ``signals.progress`` so that one bad file does not stop the batch.
        """
        filename = os.path.basename(self.video_path)
        self.signals.progress.emit(f"Starting: {filename}")
        cap = None
        try:
            video = Path(self.video_path)
            pose_cache = str(video.with_name(video.stem + "_pose_raw.csv"))
            if os.path.exists(pose_cache):
                # Check the cache before touching the video so a cache hit
                # costs no decoder resources.
                self.signals.progress.emit(f" - Cache found for {filename}. Skipping inference.")
            else:
                cap = cv2.VideoCapture(self.video_path)
                if not cap.isOpened():
                    raise OSError("could not open video file")
                frames = self._extract_poses(cap, filename)
                if not frames:
                    # Writing a header-only cache would make every later run
                    # report "Cache found" and skip this video for good.
                    raise ValueError("no frames could be read; cache not written")
                self._save_pose_cache(pose_cache, frames)
            self.signals.progress.emit(f"COMPLETED: {filename}")
        except Exception as e:
            # Top-level boundary of the worker thread: report and carry on.
            self.signals.progress.emit(f"ERROR on {filename}: {e}")
        finally:
            try:
                if cap is not None:
                    cap.release()
            finally:
                # The dialog counts these to know when the batch is done, so
                # emit even if releasing the capture fails.
                self.signals.finished.emit()

    def _extract_poses(self, cap, filename):
        """Run pose tracking over an opened capture, one entry per frame.

        Args:
            cap: An opened ``cv2.VideoCapture``.
            filename: Display name used in progress messages.

        Returns:
            A list with one array per decoded frame: keypoint x/y stacked with
            confidence for the chosen person, or an all-zero
            ``(len(JOINT_NAMES), 3)`` array when nobody was selected.
        """
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Instantiate the model inside the worker thread, one per task.
        model = YOLO(self.model_name)
        prev_track_id = None
        frames = []
        for i in range(total_frames):
            ret, frame = cap.read()
            if not ret:
                # The reported frame count can exceed what is decodable.
                break
            results = model.track(frame, persist=True, verbose=False)
            track_id, kp, confs, _ = self.get_best_infant_match(
                results, width, height, prev_track_id
            )
            if kp is not None:
                prev_track_id = track_id
                frames.append(np.column_stack((kp, confs)))
            else:
                # Placeholder row keeps the CSV aligned with frame indices.
                frames.append(np.zeros((len(JOINT_NAMES), 3)))
            if i % 100 == 0:
                p = int((i / total_frames) * 100)
                self.signals.progress.emit(f" - {filename}: {p}%")
        return frames

    def get_best_infant_match(self, results, w, h, prev_track_id):
        """Pick one tracked person from a frame's results.

        Each candidate is scored as
        ``10 * visible_keypoints - 0.1 * distance_to_frame_centre``, plus 50
        when its track id equals ``prev_track_id``. A keypoint counts as
        visible when its confidence is above 0.5; a candidate with no visible
        keypoints is given a distance of 1000.

        Args:
            results: Return value of ``model.track`` for one frame; only
                ``results[0]`` is read.
            w: Frame width in pixels.
            h: Frame height in pixels.
            prev_track_id: Track id chosen on an earlier frame, or ``None``.

        Returns:
            ``(track_id, keypoints_xy, keypoint_confs, index)`` for the best
            candidate, or ``(None, None, None, None)`` when there are no
            tracked boxes or no candidate scores above -1.
        """
        if not results[0].boxes or results[0].boxes.id is None:
            return None, None, None, None
        ids = results[0].boxes.id.int().cpu().tolist()
        kpts = results[0].keypoints.xy.cpu().numpy()
        confs = results[0].keypoints.conf.cpu().numpy()
        centre = [w / 2, h / 2]
        # Starting at -1 (not -inf) means candidates scoring -1 or lower are
        # never selected; this is kept as-is from the original behaviour.
        best_idx, best_score = -1, -1
        for i, k in enumerate(kpts):
            visible = confs[i] > 0.5
            vis = np.sum(visible)
            valid = k[visible]
            dist = np.linalg.norm(np.mean(valid, axis=0) - centre) if len(valid) > 0 else 1000
            score = (vis * 10) - (dist * 0.1) + (50 if ids[i] == prev_track_id else 0)
            if score > best_score:
                best_score, best_idx = score, i
        if best_idx == -1:
            return None, None, None, None
        return ids[best_idx], kpts[best_idx], confs[best_idx], best_idx

    def _save_pose_cache(self, path, data):
        """Write the per-frame pose arrays to ``path`` as CSV.

        The header has ``<joint>_x``, ``<joint>_y``, ``<joint>_conf`` for each
        entry of ``JOINT_NAMES``; each following row is one flattened frame.

        The file is written under a temporary name and then moved into place,
        because ``run`` treats the mere existence of ``path`` as a valid cache
        and a half-written file would be trusted on the next run.

        Args:
            path: Destination CSV path.
            data: Iterable of numpy arrays, one per frame.

        Raises:
            OSError: If the file cannot be written or moved into place.
        """
        header = [
            f"{joint}_{suffix}"
            for joint in JOINT_NAMES
            for suffix in ("x", "y", "conf")
        ]
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(header)
                writer.writerows(frame_data.flatten() for frame_data in data)
            os.replace(tmp_path, path)
        except Exception:
            # Do not leave a partial temp file behind on failure.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
class BatchProcessorDialog(QDialog):
    """Dialog that runs a VideoWorker for every video in a chosen folder.

    The user picks a folder, a YOLO model and a worker count; each video is
    queued on a ``QThreadPool`` and worker messages are appended to a log
    pane. The run button stays disabled until every queued worker has
    reported ``finished``.
    """

    # File suffixes (compared lower-cased) that are treated as videos.
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')

    def __init__(self, parent=None):
        """Create the dialog, its thread pool and its widgets.

        Args:
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.setWindowTitle("Concurrent Batch Video Processor")
        self.setMinimumSize(600, 500)
        self.threadpool = QThreadPool()
        # Number of queued workers that have not yet emitted ``finished``.
        self.active_workers = 0
        self.setup_ui()

    def setup_ui(self):
        """Build the widgets and connect the button handlers."""
        layout = QVBoxLayout(self)

        # Folder selection row: label, path field, browse button.
        self.folder_edit = QLineEdit()
        btn_browse = QPushButton("Browse")
        btn_browse.clicked.connect(self.get_folder)
        folder_row = QHBoxLayout()
        for widget in (QLabel("Folder:"), self.folder_edit, btn_browse):
            folder_row.addWidget(widget)
        layout.addLayout(folder_row)

        # YOLO model choice. Only the nano pose model is offered here; the
        # larger yolov8{s,m,l,x}-pose.pt variants are not listed.
        self.model_combo = QComboBox()
        self.model_combo.addItems(["yolov8n-pose.pt"])
        layout.addWidget(QLabel("Select YOLO Model:"))
        layout.addWidget(self.model_combo)

        # Number of videos processed at the same time.
        self.concurrency_spin = QSpinBox()
        self.concurrency_spin.setRange(1, 16)
        self.concurrency_spin.setValue(4)
        layout.addWidget(QLabel("Concurrent Workers (Threads):"))
        layout.addWidget(self.concurrency_spin)

        # Read-only output log.
        self.log = QTextEdit()
        self.log.setReadOnly(True)
        self.log.setStyleSheet("background-color: #1e1e1e; color: #d4d4d4; font-family: Consolas;")
        layout.addWidget(self.log)

        # Action button.
        self.btn_run = QPushButton("Run Concurrent Batch")
        self.btn_run.setFixedHeight(40)
        self.btn_run.clicked.connect(self.run_batch)
        layout.addWidget(self.btn_run)

    def get_folder(self):
        """Ask for a directory and copy it into the folder field if chosen."""
        path = QFileDialog.getExistingDirectory(self, "Select Folder")
        if path:
            self.folder_edit.setText(path)

    def log_msg(self, msg):
        """Append ``msg`` to the log pane and scroll to the bottom.

        Args:
            msg: Text line to append.
        """
        self.log.append(msg)
        scrollbar = self.log.verticalScrollBar()
        scrollbar.setValue(scrollbar.maximum())

    def worker_finished(self):
        """Count one worker as done; re-enable the run button after the last."""
        self.active_workers -= 1
        if self.active_workers <= 0:
            self.btn_run.setEnabled(True)
            self.log_msg("--- All concurrent tasks finished ---")

    def run_batch(self):
        """Queue one VideoWorker per video file found in the selected folder.

        Logs an error and returns when the folder is not a directory, and
        logs a notice and returns when it contains no video files. Only
        direct children of the folder are considered.
        """
        folder = self.folder_edit.text()
        if not os.path.isdir(folder):
            self.log_msg("!!! Error: Invalid directory.")
            return

        # Require regular files: a directory named e.g. "clips.mp4" must not
        # be queued. Sorting makes the queue order independent of the
        # filesystem's listing order.
        files = sorted(
            str(entry)
            for entry in Path(folder).iterdir()
            if entry.is_file() and entry.suffix.lower() in self.VIDEO_EXTENSIONS
        )
        if not files:
            self.log_msg("No videos found.")
            return

        self.btn_run.setEnabled(False)
        # Set the counter before starting any worker so an early ``finished``
        # cannot drop it to zero while workers are still being queued.
        self.active_workers = len(files)
        self.threadpool.setMaxThreadCount(self.concurrency_spin.value())
        self.log_msg(f"Queueing {len(files)} videos...")

        model_name = self.model_combo.currentText()
        for f_path in files:
            worker = VideoWorker(f_path, model_name)
            worker.signals.progress.connect(self.log_msg)
            worker.signals.finished.connect(self.worker_finished)
            self.threadpool.start(worker)