"""Pose-analysis worker: YOLO-based infant pose tracking with CSV caching."""
import cv2
|
|
import os
|
|
import csv
|
|
import numpy as np
|
|
from ultralytics import YOLO
|
|
from multiprocessing import current_process
|
|
|
|
# COCO-17 keypoint names in YOLO pose output order; flattened to
# <name>_x, <name>_y, <name>_conf columns in the cache CSV.
JOINT_NAMES = [
    "nose",
    "l_eye", "r_eye",
    "l_ear", "r_ear",
    "l_shld", "r_shld",
    "l_elbw", "r_elbw",
    "l_wri", "r_wri",
    "l_hip", "r_hip",
    "l_knee", "r_knee",
    "l_ankl", "r_ankl",
]
def get_best_infant_match(results, w, h, prev_track_id):
    """
    Identifies the most likely infant based on visibility,
    centrality, and tracking ID consistency.

    Returns ``(track_id, keypoints, confidences, index)`` for the best
    candidate, or ``(None, None, None, None)`` when nothing qualifies.
    """
    detection = results[0]
    if not detection.boxes or detection.boxes.id is None:
        return None, None, None, None

    track_ids = detection.boxes.id.int().cpu().tolist()
    keypoints = detection.keypoints.xy.cpu().numpy()
    conf_scores = detection.keypoints.conf.cpu().numpy()

    frame_center = np.array([w / 2, h / 2])
    winner, winner_score = -1, -1

    for idx, (joints, joint_conf) in enumerate(zip(keypoints, conf_scores)):
        visible_mask = joint_conf > 0.5
        visible_count = np.sum(visible_mask)
        visible_joints = joints[visible_mask]

        # Distance of the visible-joint centroid from the frame center;
        # a large constant penalty when no joint is confidently visible.
        if len(visible_joints) > 0:
            center_dist = np.linalg.norm(np.mean(visible_joints, axis=0) - frame_center)
        else:
            center_dist = 1000

        # Reward visibility, penalize distance, and favor ID continuity.
        tracking_bonus = 50 if track_ids[idx] == prev_track_id else 0
        candidate_score = (visible_count * 10) - (center_dist * 0.1) + tracking_bonus

        if candidate_score > winner_score:
            winner_score, winner = candidate_score, idx

    if winner == -1:
        return None, None, None, None

    return track_ids[winner], keypoints[winner], conf_scores[winner], winner
def _read_video_metadata(video_path):
    """Probe *video_path* once and return (fps, total_frames, width, height)."""
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # FPS reads as 0 on some containers; fall back to 30
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()
    return fps, total_frames, width, height


def _load_pose_cache(pose_cache):
    """Load the cached pose CSV back into a list of (17, 3) arrays.

    Raises on any malformed row so the caller can fall back to inference.
    """
    frames = []
    with open(pose_cache, 'r') as f:
        reader = csv.reader(f)
        next(reader)  # Skip header
        for row in reader:
            # Flattened (51,) back to (17, 3): x, y, conf per joint.
            frames.append(np.array([float(x) for x in row]).reshape(17, 3))
    return frames


def _save_pose_cache(pose_cache, frames, p_name):
    """Write per-frame (17, 3) keypoint arrays to *pose_cache* as flat CSV rows."""
    try:
        with open(pose_cache, 'w', newline='') as f:
            writer = csv.writer(f)
            header = []
            for joint in JOINT_NAMES:
                header.extend([f"{joint}_x", f"{joint}_y", f"{joint}_conf"])
            writer.writerow(header)
            for frame_array in frames:
                writer.writerow(frame_array.flatten())
    except Exception as e:
        # Cache write failure is non-fatal: results still go out via the queue.
        print(f"[{p_name}] Error saving cache: {e}")


def run_pose_analysis(video_path, progress_queue, result_queue, config):
    """Worker task executed in a separate Process.

    Runs YOLO pose tracking over *video_path* (or loads a previous run from
    the ``*_pose_raw.csv`` sidecar cache when ``config["use_cache"]`` is set),
    pushes integer percentages to *progress_queue*, and puts one result dict
    on *result_queue* with keys: raw_kpts, fps, total_frames, dims, status.
    """
    p_name = current_process().name
    pose_cache = video_path.rsplit('.', 1)[0] + "_pose_raw.csv"
    print(f"[{p_name}] Starting analysis on: {video_path}")

    # Probe the video once for metadata shared by both the cache and the
    # inference paths (previously the capture was opened twice for this).
    fps, total_frames, width, height = _read_video_metadata(video_path)

    if config.get("use_cache") and os.path.exists(pose_cache):
        print(f"[{p_name}] Cache checkmark active. Loading: {pose_cache}")
        try:
            csv_storage_data = _load_pose_cache(pose_cache)
            progress_queue.put(100)
            result_queue.put({
                "raw_kpts": np.array(csv_storage_data),
                "fps": fps,
                "total_frames": len(csv_storage_data),
                "dims": (width, height),
                "status": "loaded_from_cache"
            })
            return  # Exit early, no inference or saving needed
        except Exception as e:
            print(f"[{p_name}] Cache read failed, falling back to inference: {e}")

    # --- Inference path ---
    model_map = {
        "YOLO8n-Pose": "yolov8n-pose.pt",
        "YOLO8m-Pose": "yolov8m-pose.pt",
        "Mediapipe BlazePose": "mediapipe"
    }
    model_file = model_map.get(config.get("pose_model"), "yolov8n-pose.pt")

    print(f"[{p_name}] Running inference with model: {model_file}")
    model = YOLO(model_file)

    cap = cv2.VideoCapture(video_path)
    new_csv_storage_data = []
    prev_track_id = None

    for i in range(total_frames):
        ret, frame = cap.read()
        if not ret:
            break  # container frame count can overestimate; stop at real EOF

        results = model.track(frame, persist=True, verbose=False)
        track_id, kp, confs, _ = get_best_infant_match(results, width, height, prev_track_id)

        if kp is not None:
            prev_track_id = track_id
            # Store as (17, 3) including confidence
            new_csv_storage_data.append(np.column_stack((kp, confs)))
        else:
            # No usable detection: keep the frame timeline aligned with zeros.
            new_csv_storage_data.append(np.zeros((17, 3)))

        if i % 10 == 0:
            progress_queue.put(int((i / total_frames) * 100))

    cap.release()
    # Signal completion explicitly — the original inference path stalled
    # below 100% while the cache path reported 100.
    progress_queue.put(100)

    print(f"[{p_name}] Saving new pose cache to {pose_cache}")
    _save_pose_cache(pose_cache, new_csv_storage_data, p_name)

    # Return results through the queue
    result_queue.put({
        "raw_kpts": np.array(new_csv_storage_data),
        "fps": fps,
        "total_frames": len(new_csv_storage_data),
        "dims": (width, height),
        "status": "inference_complete"
    })

    print(f"[{p_name}] Analysis complete.")