improvements

This commit is contained in:
2026-03-13 21:41:47 -07:00
parent 95f68e0c12
commit 60735b0d38
2 changed files with 239 additions and 15 deletions

1
.gitignore vendored
View File

@@ -175,3 +175,4 @@ cython_debug/
.pypirc .pypirc
sparks_*/ sparks_*/
*.pth

247
main.py
View File

@@ -45,12 +45,15 @@ import matplotlib.pyplot as plt
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox, QGridLayout, QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox, QGridLayout,
QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView, QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView,
QMenu, QProgressBar, QCheckBox, QSlider, QTabWidget, QTreeWidget, QTreeWidgetItem, QHeaderView QMenu, QProgressBar, QCheckBox, QSlider, QTabWidget, QTreeWidget, QTreeWidgetItem, QHeaderView, QInputDialog
) )
from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize, QPoint from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize, QPoint
from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel, QStandardItem, QImage from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel, QStandardItem, QImage
from PySide6.QtSvgWidgets import QSvgWidget # needed to show svgs when app is not frozen from PySide6.QtSvgWidgets import QSvgWidget # needed to show svgs when app is not frozen
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import pyqtgraph as pg
CURRENT_VERSION = "1.0.0" CURRENT_VERSION = "1.0.0"
@@ -89,13 +92,21 @@ class TrainModelThread(QThread):
# Load CSVs # Load CSVs
X_list, y_list, feature_cols = [], [], None X_list, y_list, feature_cols = [], [], None
# for path in self.csv_paths:
# df = pd.read_csv(path)
# if feature_cols is None:
# feature_cols = [c for c in df.columns if c.startswith("lm")]
# df[feature_cols] = df[feature_cols].ffill().bfill()
# X_list.append(df[feature_cols].values)
# y_list.append(df[["reach_active", "reach_before_contact"]].values)
for path in self.csv_paths: for path in self.csv_paths:
df = pd.read_csv(path) df = pd.read_csv(path)
if feature_cols is None: if feature_cols is None:
feature_cols = [c for c in df.columns if c.startswith("lm")] feature_cols = [c for c in df.columns if c.startswith("h0")]
df[feature_cols] = df[feature_cols].ffill().bfill() df[feature_cols] = df[feature_cols].ffill().bfill()
X_list.append(df[feature_cols].values) X_list.append(df[feature_cols].values)
y_list.append(df[["reach_active", "reach_before_contact"]].values) y_list.append(df[["current_event"]].values)
X = np.concatenate(X_list, axis=0) X = np.concatenate(X_list, axis=0)
y = np.concatenate(y_list, axis=0) y = np.concatenate(y_list, axis=0)
@@ -117,7 +128,7 @@ class TrainModelThread(QThread):
y_windows = np.array(y_windows) y_windows = np.array(y_windows)
class_weights = [] class_weights = []
for i in range(2): for i in range(1):
cw = compute_class_weight('balanced', classes=np.array([0,1]), y=y_windows[:, i]) cw = compute_class_weight('balanced', classes=np.array([0,1]), y=y_windows[:, i])
class_weights.append(cw[1]/cw[0]) class_weights.append(cw[1]/cw[0])
pos_weight = torch.tensor(class_weights, dtype=torch.float32) pos_weight = torch.tensor(class_weights, dtype=torch.float32)
@@ -126,7 +137,7 @@ class TrainModelThread(QThread):
# LSTM # LSTM
class WindowLSTM(nn.Module): class WindowLSTM(nn.Module):
def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2): def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=1):
super().__init__() super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True,
bidirectional=bidirectional) bidirectional=bidirectional)
@@ -1460,7 +1471,7 @@ class ParticipantProcessor2(QThread):
finished_processing = Signal(str, str) # obs_id, cam_id finished_processing = Signal(str, str) # obs_id, cam_id
def __init__(self, obs_id, selected_cam_id, selected_hand_idx, def __init__(self, obs_id, selected_cam_id, selected_hand_idx,
video_path, output_csv, output_dir, initial_wrists, **kwargs): video_path, output_csv, output_dir, initial_wrists, json_path=None, boris_obs_key=None, **kwargs):
super().__init__() super().__init__()
self.obs_id = obs_id self.obs_id = obs_id
self.cam_id = selected_cam_id self.cam_id = selected_cam_id
@@ -1468,6 +1479,8 @@ class ParticipantProcessor2(QThread):
self.video_path = video_path self.video_path = video_path
self.output_dir = output_dir self.output_dir = output_dir
self.output_csv = output_csv self.output_csv = output_csv
self.json_path = json_path
self.boris_obs_key = boris_obs_key
self.is_running = True self.is_running = True
# Convert initial_wrists (list) to initial_centroids (dict) for the tracker # Convert initial_wrists (list) to initial_centroids (dict) for the tracker
@@ -1532,7 +1545,29 @@ class ParticipantProcessor2(QThread):
save_path = os.path.join(self.output_dir, self.output_csv) save_path = os.path.join(self.output_dir, self.output_csv)
header = ['frame'] event_list = []
camera_offset = 0.0
if self.json_path and self.boris_obs_key:
try:
with open(self.json_path, 'r') as f:
full_data = json.load(f)
# Use the specific key selected from the dropdown
obs_data = full_data.get("observations", {}).get(self.boris_obs_key, {})
# Get events: [time, ?, label, description, ...]
event_list = obs_data.get("events", [])
# Sort events by time to ensure lookup works
event_list.sort(key=lambda x: x[0])
# Get camera offset
media_info = obs_data.get("media_info", {})
offsets = media_info.get("offset", {})
camera_offset = float(offsets.get(self.cam_id, 0.0))
except Exception as e:
print(f"Error loading JSON events: {e}")
header = ['frame', 'time_sec', 'current_event']
if self.selected_hand_idx == 99: if self.selected_hand_idx == 99:
# Dual hand columns: h0_x0, h0_y0 ... h1_x20, h1_y20 # Dual hand columns: h0_x0, h0_y0 ... h1_x20, h1_y20
for h_prefix in ['h0', 'h1']: for h_prefix in ['h0', 'h1']:
@@ -1543,6 +1578,11 @@ class ParticipantProcessor2(QThread):
for i in range(21): for i in range(21):
header.extend([f'x{i}', f'y{i}']) header.extend([f'x{i}', f'y{i}'])
current_ev_idx = 0
reach_active = 0 # This will be written as 1 or 0
in_reach_phase = False
with open(save_path, 'w', newline='') as f: with open(save_path, 'w', newline='') as f:
writer = csv.writer(f) writer = csv.writer(f)
writer.writerow(header) writer.writerow(header)
@@ -1559,6 +1599,30 @@ class ParticipantProcessor2(QThread):
results = detector.detect(mp_image) results = detector.detect(mp_image)
frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES)) frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
current_time = frame_idx / fps
current_event_label = "None"
adjusted_time = current_time + camera_offset
while (current_ev_idx < len(event_list) and
event_list[current_ev_idx][0] <= adjusted_time):
event_label = str(event_list[current_ev_idx][2]).strip()
# 1. Start on the FIRST "Reach"
if event_label == "Reach" and not in_reach_phase:
in_reach_phase = True
reach_active = 1
# 2. Stop on "End"
elif event_label == "End":
in_reach_phase = False
reach_active = 0
# Note: If event_label is the second "Reach", we do nothing.
# It hits this 'else' (implicitly) and we just move the index forward.
current_ev_idx += 1
# current_mapping stores {hand_id: landmark_list} for this frame # current_mapping stores {hand_id: landmark_list} for this frame
current_mapping = {} current_mapping = {}
@@ -1603,7 +1667,7 @@ class ParticipantProcessor2(QThread):
last_known_centroids[hand_id] = self.get_centroid(landmarks) last_known_centroids[hand_id] = self.get_centroid(landmarks)
# --- STEP 4: WRITE TO CSV --- # --- STEP 4: WRITE TO CSV ---
row = [frame_idx] row = [frame_idx, round(current_time, 4), reach_active]
if self.selected_hand_idx == 99: if self.selected_hand_idx == 99:
# DUAL MODE: We expect Hand 0 and Hand 1 # DUAL MODE: We expect Hand 0 and Hand 1
@@ -1708,6 +1772,7 @@ class HandValidationWindow(QWidget):
self.inspector = HandDataInspector() self.inspector = HandDataInspector()
self.inspector.show() self.inspector.show()
self.frame_counter = 0
# Logic state # Logic state
self.timer = QTimer() self.timer = QTimer()
@@ -1721,11 +1786,18 @@ class HandValidationWindow(QWidget):
'h0': deque(maxlen=30), 'h0': deque(maxlen=30),
'h1': deque(maxlen=30) 'h1': deque(maxlen=30)
} }
self.csv_graph_window = None
def load_files(self): def load_files(self):
video_path, _ = QFileDialog.getOpenFileName(self, "Select Video", "", "Videos (*.mp4 *.avi)") video_path, _ = QFileDialog.getOpenFileName(self, "Select Video", "", "Videos (*.mp4 *.avi)")
csv_path, _ = QFileDialog.getOpenFileName(self, "Select CSV", "", "CSV Files (*.csv)") csv_path, _ = QFileDialog.getOpenFileName(self, "Select CSV", "", "CSV Files (*.csv)")
# --- fNIRS CSV (52 channels, synced plot) ---
csv_52_path, _ = QFileDialog.getOpenFileName(self, "Select 52-Channel CSV", "", "CSV Files (*.csv)")
# --- JSON (alignment info / time_shift) ---
json_path, _ = QFileDialog.getOpenFileName(self, "Select JSON", "", "JSON Files (*.json)")
if video_path and csv_path: if video_path and csv_path:
self.cap = cv2.VideoCapture(video_path) self.cap = cv2.VideoCapture(video_path)
self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -1733,8 +1805,30 @@ class HandValidationWindow(QWidget):
# Load the dual-hand data # Load the dual-hand data
self.hand_data = self.parse_dual_hand_csv(csv_path) self.hand_data = self.parse_dual_hand_csv(csv_path)
self.csv_52 = pd.read_csv(csv_52_path)
self.csv_52_time_col = self.csv_52.columns[0]
with open(json_path, "r") as f:
self.json_data = json.load(f)
# Extract values dynamically but hardcode the '2' key
boris_anchor = self.json_data.get("boris_anchor", {}).get("time", 0.0) # Accessing video 2 delay specifically
v2_delay = self.json_data.get("videos", {}).get("2", {}).get("delay", 0.0)
print(f"Syncing with Video 2: Shift {boris_anchor}, Delay {v2_delay}")
# Initialize the window with these values
self.csv_graph_window = CSVGraphWindow(
csv_52_path,
boris_anchor=boris_anchor,
video_delay=v2_delay
)
self.csv_graph_window.show()
self.timer.start(16) self.timer.start(16)
def update_speed(self): def update_speed(self):
speed_map = {"0.25x": 64, "0.5x": 32, "1.0x": 16, "2.0x": 8} speed_map = {"0.25x": 64, "0.5x": 32, "1.0x": 16, "2.0x": 8}
ms = speed_map.get(self.speed_combo.currentText(), 16) ms = speed_map.get(self.speed_combo.currentText(), 16)
@@ -1792,7 +1886,7 @@ class HandValidationWindow(QWidget):
ret, frame = self.cap.read() ret, frame = self.cap.read()
if not ret: return if not ret: return
self.frame_counter += 1
f_idx = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)) f_idx = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
self.slider.setValue(f_idx) self.slider.setValue(f_idx)
self.lbl_frame.setText(f"Frame: {f_idx}") self.lbl_frame.setText(f"Frame: {f_idx}")
@@ -1822,13 +1916,16 @@ class HandValidationWindow(QWidget):
qimg = QImage(rgb.data, w, h, 3 * w, QImage.Format_RGB888) qimg = QImage(rgb.data, w, h, 3 * w, QImage.Format_RGB888)
display_size = self.video_label.contentsRect().size() display_size = self.video_label.contentsRect().size()
mode = Qt.TransformationMode.SmoothTransformation if self.paused else Qt.TransformationMode.FastTransformation
self.video_label.setPixmap( self.video_label.setPixmap(
QPixmap.fromImage(qimg).scaled( QPixmap.fromImage(qimg).scaled(
display_size, display_size,
Qt.AspectRatioMode.KeepAspectRatio, Qt.AspectRatioMode.KeepAspectRatio,
Qt.TransformationMode.SmoothTransformation # Added for better quality mode
) )
) )
if self.csv_graph_window and self.frame_counter % 2 == 0:
self.csv_graph_window.update_plot(f_idx / fps)
def draw_skeleton(self, img, landmarks, w, h, color, h_key, prev_landmarks=None, fps=30): def draw_skeleton(self, img, landmarks, w, h, color, h_key, prev_landmarks=None, fps=30):
# --- 1. Draw Skeleton Connections --- # --- 1. Draw Skeleton Connections ---
@@ -1886,6 +1983,77 @@ class HandValidationWindow(QWidget):
cv2.arrowedLine(img, com_px, smooth_end_px, (255, 255, 255), 5, tipLength=0.3) cv2.arrowedLine(img, com_px, smooth_end_px, (255, 255, 255), 5, tipLength=0.3)
class CSVGraphWindow(QWidget):
    """Paged multi-channel CSV viewer with a playhead driven by video time.

    The CSV layout is taken positionally: column 0 is the time axis,
    column 1 is an annotation label, and every remaining column is plotted
    as one signal channel.  Annotated stretches are shaded behind the curves.

    Args:
        csv_path: Path of the CSV file to display.
        boris_anchor: Value subtracted from the video time when aligning the
            playhead (see ``update_plot``).
        video_delay: Value added to the video time when aligning the playhead.
        window_sec: Width of one visible "page" of the x axis, in the units
            of the time column (presumably seconds -- confirm with the data).
        downsample_step: Keep every N-th row when drawing the channel curves.
            Only the curves are decimated; annotations use every row.
        sync_offset_sec: Constant added to the aligned time.  The default of
            5.0 reproduces the previously hard-coded ``+ 5``.
            NOTE(review): the origin of this offset is not visible here.

    Raises:
        ValueError: If ``downsample_step`` is smaller than 1.
    """

    # Label strings that mean "no annotation".  "0.0" is included because a
    # numeric label column containing NaN is read as float, so zero is
    # stringified as "0.0" rather than "0".
    _EMPTY_LABELS = frozenset({"None", "0", "0.0", "nan"})

    def __init__(self, csv_path, boris_anchor=0, video_delay=0, window_sec=15.0,
                 downsample_step=4, sync_offset_sec=5.0):
        super().__init__()
        if downsample_step < 1:
            raise ValueError("downsample_step must be a positive integer")

        self.setWindowTitle("High-Speed Data Viewer")
        self.resize(1200, 400)
        self.main_layout = QVBoxLayout(self)

        self.boris_anchor = boris_anchor
        self.video_delay = video_delay
        self.window_sec = window_sec
        self.sync_offset_sec = sync_offset_sec
        self.current_page = -1  # forces the first update_plot() to set the range

        # 1. Load the data and resolve the positional column roles.
        full_df = pd.read_csv(csv_path)
        self.time_col = full_df.columns[0]
        self.anno_col = full_df.columns[1]
        self.channels = list(full_df.columns[2:])

        # Keep the full-resolution time/label columns for the shaded regions:
        # decimating them would drop annotation runs shorter than
        # downsample_step and shift the region boundaries.
        self._anno_times = full_df[self.time_col].reset_index(drop=True)
        self._anno_labels = full_df[self.anno_col].reset_index(drop=True)

        # Only the plotted curves are decimated (far fewer points to draw).
        self.df = full_df.iloc[::downsample_step, :].reset_index(drop=True)

        # 2. Create the plot widget.
        self.plot_widget = pg.PlotWidget()
        self.plot_widget.setBackground('w')  # white background
        self.main_layout.addWidget(self.plot_widget)

        # 3. Shade annotated stretches first so the curves draw on top.
        self._add_pg_annotations()

        # 4. One curve per channel, colours cycling through 10 hues.
        time_values = self.df[self.time_col].to_numpy()
        for i, ch in enumerate(self.channels):
            color = pg.intColor(i, hues=10, values=1, alpha=150)
            self.plot_widget.plot(
                time_values,
                self.df[ch].to_numpy(),
                pen=pg.mkPen(color, width=1),
            )

        # 5. Vertical playhead line, moved by update_plot().
        self.playhead = pg.InfiniteLine(pos=0, angle=90, pen=pg.mkPen('r', width=2))
        self.plot_widget.addItem(self.playhead)

        self.plot_widget.setXRange(0, self.window_sec)

    def _add_pg_annotations(self):
        """Shade every contiguous run of a non-empty annotation label."""
        labels = self._anno_labels.fillna("None").astype(str)
        times = self._anno_times

        # A run starts wherever the label differs from the previous row
        # (the first row always differs from the shifted-in NaN).
        run_starts = labels.index[labels != labels.shift()].tolist()
        bounds = run_starts + [len(labels)]

        for start, stop in zip(bounds, bounds[1:]):
            if labels.iloc[start] in self._EMPTY_LABELS:
                continue
            region = pg.LinearRegionItem(
                values=(times.iloc[start], times.iloc[stop - 1]),
                brush=pg.mkBrush(0, 0, 255, 30),
                movable=False,
            )
            self.plot_widget.addItem(region)

    def update_plot(self, video_time):
        """Move the playhead to ``video_time`` and flip the page if needed.

        Args:
            video_time: Current video position (callers pass frame / fps).
        """
        aligned_time = (video_time + self.video_delay - self.boris_anchor) + self.sync_offset_sec

        # 1. Move the playhead.
        self.playhead.setPos(aligned_time)

        # 2. Change the visible x range only when the playhead enters a new
        #    page, so the axis does not scroll on every frame.
        page_number = int(aligned_time // self.window_sec)
        if page_number != self.current_page:
            new_start = page_number * self.window_sec
            self.plot_widget.setXRange(new_start, new_start + self.window_sec, padding=0)
            self.current_page = page_number
class HandDataInspector(QWidget): class HandDataInspector(QWidget):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -2717,6 +2885,57 @@ class MainApplication(QMainWindow):
self.extract_frame_and_hands # Assuming this is a method available to MainApplication self.extract_frame_and_hands # Assuming this is a method available to MainApplication
) )
msg_box = QMessageBox()
msg_box.setWindowTitle("Action Required")
msg_box.setText("BORIS?")
msg_box.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
msg_box.setDefaultButton(QMessageBox.Yes)
response = msg_box.exec()
if response == QMessageBox.Yes:
# 3. Open File Dialog
self.selected_boris_json_path, _ = QFileDialog.getOpenFileName(
None,
"Select File",
"",
"All Files (*);;Text Files (*.txt)"
)
try:
with open(self.selected_boris_json_path, 'r') as f:
data = json.load(f)
# 3. Extract keys from "observation"
observations = data.get("observations", {})
if not observations:
QMessageBox.warning(None, "Error", "No 'observation' key found or it is empty.")
return
keys = list(observations.keys())
# 4. Open Dropdown Dialog (QInputDialog)
selected_key, ok = QInputDialog.getItem(
None, "Select Key", "Pick an observation key:", keys, 0, False
)
# 5. Store the result
if ok and selected_key:
self.selected_boris_key = selected_key
print(f"Successfully stored: {self.selected_boris_key}")
QMessageBox.information(None, "Success", f"Stored: {self.selected_boris_key}")
except json.JSONDecodeError:
QMessageBox.critical(None, "Error", "Invalid JSON format.")
except Exception as e:
QMessageBox.critical(None, "Error", f"An error occurred: {e}")
else:
print("File selection cancelled.")
else:
# If 'No' or closed, do nothing
print("User declined.")
# 4. Connect Signals to Main Thread Slots # 4. Connect Signals to Main Thread Slots
self.worker_thread.observations_loaded.connect(self.on_individual_files_loaded) self.worker_thread.observations_loaded.connect(self.on_individual_files_loaded)
@@ -3528,16 +3747,20 @@ class MainApplication(QMainWindow):
# --- THREAD START --- # --- THREAD START ---
output_csv = f"{obs_id}_{cam_id}_processed.csv" output_csv = f"{obs_id}_{cam_id}_processed.csv"
full_video_path = file_info["path"] full_video_path = file_info["path"]
json_path = self.selected_boris_json_path
selected_obs_key = self.selected_boris_key
# ⚠️ Ensure ParticipantProcessor is defined to accept self.observations_root # ⚠️ Ensure ParticipantProcessor is defined to accept self.observations_root
processor = ParticipantProcessor2( processor = ParticipantProcessor2(
obs_id=obs_id, obs_id=obs_id,
selected_cam_id=cam_id, selected_cam_id=2,
selected_hand_idx=selected_hand_idx, selected_hand_idx=selected_hand_idx,
video_path=full_video_path, # CRITICAL PATH FIX video_path=full_video_path, # CRITICAL PATH FIX
hide_preview=False, # Assuming default for now hide_preview=False, # Assuming default for now
output_csv=output_csv, output_csv=output_csv,
output_dir=self.output_dir, output_dir=self.output_dir,
initial_wrists=self.camera_state[(obs_id, cam_id)].get("initial_wrists") initial_wrists=self.camera_state[(obs_id, cam_id)].get("initial_wrists"),
json_path=json_path,
boris_obs_key=selected_obs_key
) )
# Connect signals to the UI elements # Connect signals to the UI elements