improvements

This commit is contained in:
2026-03-13 21:41:47 -07:00
parent 95f68e0c12
commit 60735b0d38
2 changed files with 239 additions and 15 deletions

3
.gitignore vendored
View File

@@ -174,4 +174,5 @@ cython_debug/
# PyPI configuration file
.pypirc
sparks_*/
sparks_*/
*.pth

251
main.py
View File

@@ -45,12 +45,15 @@ import matplotlib.pyplot as plt
from PySide6.QtWidgets import (
QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox, QGridLayout,
QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView,
QMenu, QProgressBar, QCheckBox, QSlider, QTabWidget, QTreeWidget, QTreeWidgetItem, QHeaderView
QMenu, QProgressBar, QCheckBox, QSlider, QTabWidget, QTreeWidget, QTreeWidgetItem, QHeaderView, QInputDialog
)
from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize, QPoint
from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel, QStandardItem, QImage
from PySide6.QtSvgWidgets import QSvgWidget # needed to show svgs when app is not frozen
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import pyqtgraph as pg
CURRENT_VERSION = "1.0.0"
@@ -89,13 +92,21 @@ class TrainModelThread(QThread):
# Load CSVs
X_list, y_list, feature_cols = [], [], None
# for path in self.csv_paths:
# df = pd.read_csv(path)
# if feature_cols is None:
# feature_cols = [c for c in df.columns if c.startswith("lm")]
# df[feature_cols] = df[feature_cols].ffill().bfill()
# X_list.append(df[feature_cols].values)
# y_list.append(df[["reach_active", "reach_before_contact"]].values)
for path in self.csv_paths:
df = pd.read_csv(path)
if feature_cols is None:
feature_cols = [c for c in df.columns if c.startswith("lm")]
feature_cols = [c for c in df.columns if c.startswith("h0")]
df[feature_cols] = df[feature_cols].ffill().bfill()
X_list.append(df[feature_cols].values)
y_list.append(df[["reach_active", "reach_before_contact"]].values)
y_list.append(df[["current_event"]].values)
X = np.concatenate(X_list, axis=0)
y = np.concatenate(y_list, axis=0)
@@ -117,7 +128,7 @@ class TrainModelThread(QThread):
y_windows = np.array(y_windows)
class_weights = []
for i in range(2):
for i in range(1):
cw = compute_class_weight('balanced', classes=np.array([0,1]), y=y_windows[:, i])
class_weights.append(cw[1]/cw[0])
pos_weight = torch.tensor(class_weights, dtype=torch.float32)
@@ -126,7 +137,7 @@ class TrainModelThread(QThread):
# LSTM
class WindowLSTM(nn.Module):
def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=1):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True,
bidirectional=bidirectional)
@@ -1460,7 +1471,7 @@ class ParticipantProcessor2(QThread):
finished_processing = Signal(str, str) # obs_id, cam_id
def __init__(self, obs_id, selected_cam_id, selected_hand_idx,
video_path, output_csv, output_dir, initial_wrists, **kwargs):
video_path, output_csv, output_dir, initial_wrists, json_path=None, boris_obs_key=None, **kwargs):
super().__init__()
self.obs_id = obs_id
self.cam_id = selected_cam_id
@@ -1468,6 +1479,8 @@ class ParticipantProcessor2(QThread):
self.video_path = video_path
self.output_dir = output_dir
self.output_csv = output_csv
self.json_path = json_path
self.boris_obs_key = boris_obs_key
self.is_running = True
# Convert initial_wrists (list) to initial_centroids (dict) for the tracker
@@ -1532,7 +1545,29 @@ class ParticipantProcessor2(QThread):
save_path = os.path.join(self.output_dir, self.output_csv)
header = ['frame']
event_list = []
camera_offset = 0.0
if self.json_path and self.boris_obs_key:
try:
with open(self.json_path, 'r') as f:
full_data = json.load(f)
# Use the specific key selected from the dropdown
obs_data = full_data.get("observations", {}).get(self.boris_obs_key, {})
# Get events: [time, ?, label, description, ...]
event_list = obs_data.get("events", [])
# Sort events by time to ensure lookup works
event_list.sort(key=lambda x: x[0])
# Get camera offset
media_info = obs_data.get("media_info", {})
offsets = media_info.get("offset", {})
camera_offset = float(offsets.get(self.cam_id, 0.0))
except Exception as e:
print(f"Error loading JSON events: {e}")
header = ['frame', 'time_sec', 'current_event']
if self.selected_hand_idx == 99:
# Dual hand columns: h0_x0, h0_y0 ... h1_x20, h1_y20
for h_prefix in ['h0', 'h1']:
@@ -1543,6 +1578,11 @@ class ParticipantProcessor2(QThread):
for i in range(21):
header.extend([f'x{i}', f'y{i}'])
current_ev_idx = 0
reach_active = 0 # This will be written as 1 or 0
in_reach_phase = False
with open(save_path, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(header)
@@ -1559,6 +1599,30 @@ class ParticipantProcessor2(QThread):
results = detector.detect(mp_image)
frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
current_time = frame_idx / fps
current_event_label = "None"
adjusted_time = current_time + camera_offset
while (current_ev_idx < len(event_list) and
event_list[current_ev_idx][0] <= adjusted_time):
event_label = str(event_list[current_ev_idx][2]).strip()
# 1. Start on the FIRST "Reach"
if event_label == "Reach" and not in_reach_phase:
in_reach_phase = True
reach_active = 1
# 2. Stop on "End"
elif event_label == "End":
in_reach_phase = False
reach_active = 0
# Note: If event_label is the second "Reach", we do nothing.
# It hits this 'else' (implicitly) and we just move the index forward.
current_ev_idx += 1
# current_mapping stores {hand_id: landmark_list} for this frame
current_mapping = {}
@@ -1603,7 +1667,7 @@ class ParticipantProcessor2(QThread):
last_known_centroids[hand_id] = self.get_centroid(landmarks)
# --- STEP 4: WRITE TO CSV ---
row = [frame_idx]
row = [frame_idx, round(current_time, 4), reach_active]
if self.selected_hand_idx == 99:
# DUAL MODE: We expect Hand 0 and Hand 1
@@ -1708,7 +1772,8 @@ class HandValidationWindow(QWidget):
self.inspector = HandDataInspector()
self.inspector.show()
self.frame_counter = 0
# Logic state
self.timer = QTimer()
self.timer.timeout.connect(self.update_frame)
@@ -1721,11 +1786,18 @@ class HandValidationWindow(QWidget):
'h0': deque(maxlen=30),
'h1': deque(maxlen=30)
}
self.csv_graph_window = None
def load_files(self):
video_path, _ = QFileDialog.getOpenFileName(self, "Select Video", "", "Videos (*.mp4 *.avi)")
csv_path, _ = QFileDialog.getOpenFileName(self, "Select CSV", "", "CSV Files (*.csv)")
# --- fNIRS CSV (52 channels, synced plot) ---
csv_52_path, _ = QFileDialog.getOpenFileName(self, "Select 52-Channel CSV", "", "CSV Files (*.csv)")
# --- JSON (alignment info / time_shift) ---
json_path, _ = QFileDialog.getOpenFileName(self, "Select JSON", "", "JSON Files (*.json)")
if video_path and csv_path:
self.cap = cv2.VideoCapture(video_path)
self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -1733,8 +1805,30 @@ class HandValidationWindow(QWidget):
# Load the dual-hand data
self.hand_data = self.parse_dual_hand_csv(csv_path)
self.csv_52 = pd.read_csv(csv_52_path)
self.csv_52_time_col = self.csv_52.columns[0]
with open(json_path, "r") as f:
self.json_data = json.load(f)
# Extract values dynamically but hardcode the '2' key
boris_anchor = self.json_data.get("boris_anchor", {}).get("time", 0.0) # Accessing video 2 delay specifically
v2_delay = self.json_data.get("videos", {}).get("2", {}).get("delay", 0.0)
print(f"Syncing with Video 2: Shift {boris_anchor}, Delay {v2_delay}")
# Initialize the window with these values
self.csv_graph_window = CSVGraphWindow(
csv_52_path,
boris_anchor=boris_anchor,
video_delay=v2_delay
)
self.csv_graph_window.show()
self.timer.start(16)
def update_speed(self):
speed_map = {"0.25x": 64, "0.5x": 32, "1.0x": 16, "2.0x": 8}
ms = speed_map.get(self.speed_combo.currentText(), 16)
@@ -1792,7 +1886,7 @@ class HandValidationWindow(QWidget):
ret, frame = self.cap.read()
if not ret: return
self.frame_counter += 1
f_idx = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
self.slider.setValue(f_idx)
self.lbl_frame.setText(f"Frame: {f_idx}")
@@ -1821,14 +1915,17 @@ class HandValidationWindow(QWidget):
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
qimg = QImage(rgb.data, w, h, 3 * w, QImage.Format_RGB888)
display_size = self.video_label.contentsRect().size()
mode = Qt.TransformationMode.SmoothTransformation if self.paused else Qt.TransformationMode.FastTransformation
self.video_label.setPixmap(
QPixmap.fromImage(qimg).scaled(
display_size,
Qt.AspectRatioMode.KeepAspectRatio,
Qt.TransformationMode.SmoothTransformation # Added for better quality
mode
)
)
if self.csv_graph_window and self.frame_counter % 2 == 0:
self.csv_graph_window.update_plot(f_idx / fps)
def draw_skeleton(self, img, landmarks, w, h, color, h_key, prev_landmarks=None, fps=30):
# --- 1. Draw Skeleton Connections ---
@@ -1886,6 +1983,77 @@ class HandValidationWindow(QWidget):
cv2.arrowedLine(img, com_px, smooth_end_px, (255, 255, 255), 5, tipLength=0.3)
class CSVGraphWindow(QWidget):
    """Fast scrolling viewer for a 52-channel CSV, synced to video playback.

    Expected CSV layout: column 0 = time (seconds), column 1 = annotation
    labels, remaining columns = data channels.
    """

    def __init__(self, csv_path, boris_anchor=0, video_delay=0, window_sec=15.0):
        super().__init__()
        self.setWindowTitle("High-Speed Data Viewer")
        self.resize(1200, 400)
        self.main_layout = QVBoxLayout(self)

        # Alignment parameters consumed by update_plot().
        self.boris_anchor = boris_anchor
        self.video_delay = video_delay
        self.window_sec = window_sec
        self.current_page = -1  # sentinel: forces an x-range set on first update

        # Load once, then keep every 4th sample — visually identical but
        # roughly 4x cheaper to plot.
        frame = pd.read_csv(csv_path)
        self.df = frame.iloc[::4, :].reset_index(drop=True)
        self.time_col = self.df.columns[0]
        self.anno_col = self.df.columns[1]
        self.channels = list(self.df.columns[2:])

        # White-background pyqtgraph canvas.
        self.plot_widget = pg.PlotWidget()
        self.plot_widget.setBackground('w')
        self.main_layout.addWidget(self.plot_widget)

        # Shaded annotation regions are added first so the traces draw on top.
        self._add_pg_annotations()

        # One polyline per channel, rotating through ten hues.
        for ch_idx, ch_name in enumerate(self.channels):
            hue = pg.intColor(ch_idx, hues=10, values=1, alpha=150)
            self.plot_widget.plot(
                self.df[self.time_col],
                self.df[ch_name],
                pen=pg.mkPen(hue, width=1),
            )

        # Red vertical playhead; repositioning an InfiniteLine is cheap,
        # unlike redrawing an axvline.
        self.playhead = pg.InfiniteLine(pos=0, angle=90, pen=pg.mkPen('r', width=2))
        self.plot_widget.addItem(self.playhead)
        self.plot_widget.setXRange(0, self.window_sec)

    def _add_pg_annotations(self):
        """Shade each contiguous run of a meaningful annotation label."""
        anno = self.df[self.anno_col].fillna("None").astype(str)
        # Positions where the label differs from the previous row mark the
        # start of a new run (row 0 always qualifies via the NaN shift).
        run_starts = anno.index[anno != anno.shift()].tolist()
        bounds = run_starts + [len(anno)]
        times = self.df[self.time_col]
        for begin, nxt in zip(bounds, bounds[1:]):
            if anno.iloc[begin] in ["None", "0", "nan"]:
                continue  # placeholder labels get no shading
            span = (times.iloc[begin], times.iloc[nxt - 1])
            # LinearRegionItem gives cheap translucent shading.
            self.plot_widget.addItem(
                pg.LinearRegionItem(
                    values=span,
                    brush=pg.mkBrush(0, 0, 255, 30),
                    movable=False,
                )
            )

    def update_plot(self, video_time):
        """Move the playhead to *video_time* (seconds) and page the x-axis.

        The axis only jumps when the playhead crosses into a new fixed-width
        page, avoiding a re-range on every frame.
        """
        # NOTE(review): the trailing +5 is an unexplained offset — confirm it
        # against the synchronization source before relying on alignment.
        aligned = (video_time + self.video_delay - self.boris_anchor) + 5
        self.playhead.setPos(aligned)
        page = int(aligned // self.window_sec)
        if page != self.current_page:
            left = page * self.window_sec
            self.plot_widget.setXRange(left, left + self.window_sec, padding=0)
            self.current_page = page
class HandDataInspector(QWidget):
def __init__(self):
super().__init__()
@@ -2716,6 +2884,57 @@ class MainApplication(QMainWindow):
file_path,
self.extract_frame_and_hands # Assuming this is a method available to MainApplication
)
msg_box = QMessageBox()
msg_box.setWindowTitle("Action Required")
msg_box.setText("BORIS?")
msg_box.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
msg_box.setDefaultButton(QMessageBox.Yes)
response = msg_box.exec()
if response == QMessageBox.Yes:
# 3. Open File Dialog
self.selected_boris_json_path, _ = QFileDialog.getOpenFileName(
None,
"Select File",
"",
"All Files (*);;Text Files (*.txt)"
)
try:
with open(self.selected_boris_json_path, 'r') as f:
data = json.load(f)
# 3. Extract keys from "observation"
observations = data.get("observations", {})
if not observations:
QMessageBox.warning(None, "Error", "No 'observation' key found or it is empty.")
return
keys = list(observations.keys())
# 4. Open Dropdown Dialog (QInputDialog)
selected_key, ok = QInputDialog.getItem(
None, "Select Key", "Pick an observation key:", keys, 0, False
)
# 5. Store the result
if ok and selected_key:
self.selected_boris_key = selected_key
print(f"Successfully stored: {self.selected_boris_key}")
QMessageBox.information(None, "Success", f"Stored: {self.selected_boris_key}")
except json.JSONDecodeError:
QMessageBox.critical(None, "Error", "Invalid JSON format.")
except Exception as e:
QMessageBox.critical(None, "Error", f"An error occurred: {e}")
else:
print("File selection cancelled.")
else:
# If 'No' or closed, do nothing
print("User declined.")
# 4. Connect Signals to Main Thread Slots
self.worker_thread.observations_loaded.connect(self.on_individual_files_loaded)
@@ -3528,16 +3747,20 @@ class MainApplication(QMainWindow):
# --- THREAD START ---
output_csv = f"{obs_id}_{cam_id}_processed.csv"
full_video_path = file_info["path"]
json_path = self.selected_boris_json_path
selected_obs_key = self.selected_boris_key
# ⚠️ Ensure ParticipantProcessor is defined to accept self.observations_root
processor = ParticipantProcessor2(
obs_id=obs_id,
selected_cam_id=cam_id,
selected_cam_id=2,
selected_hand_idx=selected_hand_idx,
video_path=full_video_path, # CRITICAL PATH FIX
hide_preview=False, # Assuming default for now
output_csv=output_csv,
output_dir=self.output_dir,
initial_wrists=self.camera_state[(obs_id, cam_id)].get("initial_wrists")
initial_wrists=self.camera_state[(obs_id, cam_id)].get("initial_wrists"),
json_path=json_path,
boris_obs_key=selected_obs_key
)
# Connect signals to the UI elements