From 60735b0d381dab11d430630d5ebe25b4d5b97874 Mon Sep 17 00:00:00 2001 From: tyler Date: Fri, 13 Mar 2026 21:41:47 -0700 Subject: [PATCH] improvements --- .gitignore | 3 +- main.py | 251 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 239 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index e3c9c22..de44da5 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,5 @@ cython_debug/ # PyPI configuration file .pypirc -sparks_*/ \ No newline at end of file +sparks_*/ +*.pth \ No newline at end of file diff --git a/main.py b/main.py index 102a1c3..116d63f 100644 --- a/main.py +++ b/main.py @@ -45,12 +45,15 @@ import matplotlib.pyplot as plt from PySide6.QtWidgets import ( QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox, QGridLayout, QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView, - QMenu, QProgressBar, QCheckBox, QSlider, QTabWidget, QTreeWidget, QTreeWidgetItem, QHeaderView + QMenu, QProgressBar, QCheckBox, QSlider, QTabWidget, QTreeWidget, QTreeWidgetItem, QHeaderView, QInputDialog ) from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize, QPoint from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel, QStandardItem, QImage from PySide6.QtSvgWidgets import QSvgWidget # needed to show svgs when app is not frozen +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.figure import Figure +import pyqtgraph as pg CURRENT_VERSION = "1.0.0" @@ -89,13 +92,21 @@ class TrainModelThread(QThread): # Load CSVs X_list, y_list, feature_cols = [], [], None + # for path in self.csv_paths: + # df = pd.read_csv(path) + # if feature_cols is None: + # feature_cols = [c for c in df.columns if c.startswith("lm")] + # df[feature_cols] = df[feature_cols].ffill().bfill() + # X_list.append(df[feature_cols].values) + # y_list.append(df[["reach_active", "reach_before_contact"]].values) + for path in self.csv_paths: df = pd.read_csv(path) if feature_cols is None: - feature_cols = [c for c in df.columns if c.startswith("lm")] + feature_cols = [c for c in df.columns if c.startswith("h0")] df[feature_cols] = df[feature_cols].ffill().bfill() X_list.append(df[feature_cols].values) - y_list.append(df[["reach_active", "reach_before_contact"]].values) + y_list.append(df[["current_event"]].values) X = np.concatenate(X_list, axis=0) y = np.concatenate(y_list, axis=0) @@ -117,7 +128,7 @@ class TrainModelThread(QThread): y_windows = np.array(y_windows) class_weights = [] - for i in range(2): + for i in range(1): cw = compute_class_weight('balanced', classes=np.array([0,1]), y=y_windows[:, i]) class_weights.append(cw[1]/cw[0]) pos_weight = torch.tensor(class_weights, dtype=torch.float32) @@ -126,7 +137,7 @@ class TrainModelThread(QThread): # LSTM class WindowLSTM(nn.Module): - def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2): + def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=1): super().__init__() self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=bidirectional) @@ -1460,7 +1471,7 @@ class ParticipantProcessor2(QThread): finished_processing = Signal(str, str) # obs_id, cam_id def __init__(self, obs_id, selected_cam_id, selected_hand_idx, - video_path, output_csv, output_dir, initial_wrists, **kwargs): + video_path, output_csv, output_dir, initial_wrists, json_path=None, boris_obs_key=None, **kwargs): super().__init__() self.obs_id = obs_id self.cam_id = selected_cam_id @@ -1468,6 +1479,8 @@ class ParticipantProcessor2(QThread): self.video_path = video_path self.output_dir = output_dir self.output_csv = output_csv + self.json_path = json_path + self.boris_obs_key = boris_obs_key self.is_running = True # Convert initial_wrists (list) to initial_centroids (dict) for the tracker @@ -1532,7 +1545,29 @@ class ParticipantProcessor2(QThread): save_path = os.path.join(self.output_dir, self.output_csv) - header = ['frame'] + event_list = [] + camera_offset = 0.0 + + if self.json_path and self.boris_obs_key: + try: + with open(self.json_path, 'r') as f: + full_data = json.load(f) + # Use the specific key selected from the dropdown + obs_data = full_data.get("observations", {}).get(self.boris_obs_key, {}) + + # Get events: [time, ?, label, description, ...] + event_list = obs_data.get("events", []) + # Sort events by time to ensure lookup works + event_list.sort(key=lambda x: x[0]) + + # Get camera offset + media_info = obs_data.get("media_info", {}) + offsets = media_info.get("offset", {}) + camera_offset = float(offsets.get(self.cam_id, 0.0)) + except Exception as e: + print(f"Error loading JSON events: {e}") + + header = ['frame', 'time_sec', 'current_event'] if self.selected_hand_idx == 99: # Dual hand columns: h0_x0, h0_y0 ... h1_x20, h1_y20 for h_prefix in ['h0', 'h1']: @@ -1543,6 +1578,11 @@ class ParticipantProcessor2(QThread): for i in range(21): header.extend([f'x{i}', f'y{i}']) + + current_ev_idx = 0 + reach_active = 0 # This will be written as 1 or 0 + in_reach_phase = False + with open(save_path, 'w', newline='') as f: writer = csv.writer(f) writer.writerow(header) @@ -1559,6 +1599,30 @@ class ParticipantProcessor2(QThread): results = detector.detect(mp_image) frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES)) + current_time = frame_idx / fps + current_event_label = "None" + adjusted_time = current_time + camera_offset + + while (current_ev_idx < len(event_list) and + event_list[current_ev_idx][0] <= adjusted_time): + + event_label = str(event_list[current_ev_idx][2]).strip() + + # 1. Start on the FIRST "Reach" + if event_label == "Reach" and not in_reach_phase: + in_reach_phase = True + reach_active = 1 + + # 2. Stop on "End" + elif event_label == "End": + in_reach_phase = False + reach_active = 0 + + # Note: If event_label is the second "Reach", we do nothing. + # It hits this 'else' (implicitly) and we just move the index forward. + + current_ev_idx += 1 + # current_mapping stores {hand_id: landmark_list} for this frame current_mapping = {} @@ -1603,7 +1667,7 @@ class ParticipantProcessor2(QThread): last_known_centroids[hand_id] = self.get_centroid(landmarks) # --- STEP 4: WRITE TO CSV --- - row = [frame_idx] + row = [frame_idx, round(current_time, 4), reach_active] if self.selected_hand_idx == 99: # DUAL MODE: We expect Hand 0 and Hand 1 @@ -1708,7 +1772,8 @@ class HandValidationWindow(QWidget): self.inspector = HandDataInspector() self.inspector.show() - + self.frame_counter = 0 + # Logic state self.timer = QTimer() self.timer.timeout.connect(self.update_frame) @@ -1721,11 +1786,18 @@ class HandValidationWindow(QWidget): 'h0': deque(maxlen=30), 'h1': deque(maxlen=30) } + self.csv_graph_window = None def load_files(self): video_path, _ = QFileDialog.getOpenFileName(self, "Select Video", "", "Videos (*.mp4 *.avi)") csv_path, _ = QFileDialog.getOpenFileName(self, "Select CSV", "", "CSV Files (*.csv)") + + # --- fNIRS CSV (52 channels, synced plot) --- + csv_52_path, _ = QFileDialog.getOpenFileName(self, "Select 52-Channel CSV", "", "CSV Files (*.csv)") + # --- JSON (alignment info / time_shift) --- + json_path, _ = QFileDialog.getOpenFileName(self, "Select JSON", "", "JSON Files (*.json)") + if video_path and csv_path: self.cap = cv2.VideoCapture(video_path) self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) @@ -1733,8 +1805,30 @@ class HandValidationWindow(QWidget): # Load the dual-hand data self.hand_data = self.parse_dual_hand_csv(csv_path) + + self.csv_52 = pd.read_csv(csv_52_path) + self.csv_52_time_col = self.csv_52.columns[0] + + with open(json_path, "r") as f: + self.json_data = json.load(f) + + # Extract values dynamically but hardcode the '2' key + boris_anchor = self.json_data.get("boris_anchor", {}).get("time", 0.0) # Accessing video 2 delay specifically + v2_delay = self.json_data.get("videos", {}).get("2", {}).get("delay", 0.0) + + print(f"Syncing with Video 2: Shift {boris_anchor}, Delay {v2_delay}") + + # Initialize the window with these values + self.csv_graph_window = CSVGraphWindow( + csv_52_path, + boris_anchor=boris_anchor, + video_delay=v2_delay + ) + self.csv_graph_window.show() + self.timer.start(16) + def update_speed(self): speed_map = {"0.25x": 64, "0.5x": 32, "1.0x": 16, "2.0x": 8} ms = speed_map.get(self.speed_combo.currentText(), 16) @@ -1792,7 +1886,7 @@ class HandValidationWindow(QWidget): ret, frame = self.cap.read() if not ret: return - + self.frame_counter += 1 f_idx = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)) self.slider.setValue(f_idx) self.lbl_frame.setText(f"Frame: {f_idx}") @@ -1821,14 +1915,17 @@ class HandValidationWindow(QWidget): rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) qimg = QImage(rgb.data, w, h, 3 * w, QImage.Format_RGB888) display_size = self.video_label.contentsRect().size() - + + mode = Qt.TransformationMode.SmoothTransformation if self.paused else Qt.TransformationMode.FastTransformation self.video_label.setPixmap( QPixmap.fromImage(qimg).scaled( display_size, Qt.AspectRatioMode.KeepAspectRatio, - Qt.TransformationMode.SmoothTransformation # Added for better quality + mode ) ) + if self.csv_graph_window and self.frame_counter % 2 == 0: + self.csv_graph_window.update_plot(f_idx / fps) def draw_skeleton(self, img, landmarks, w, h, color, h_key, prev_landmarks=None, fps=30): # --- 1. Draw Skeleton Connections --- @@ -1886,6 +1983,77 @@ class HandValidationWindow(QWidget): cv2.arrowedLine(img, com_px, smooth_end_px, (255, 255, 255), 5, tipLength=0.3) +class CSVGraphWindow(QWidget): + def __init__(self, csv_path, boris_anchor=0, video_delay=0, window_sec=15.0): + super().__init__() + self.setWindowTitle("High-Speed Data Viewer") + self.resize(1200, 400) + + self.main_layout = QVBoxLayout(self) + self.boris_anchor = boris_anchor + self.video_delay = video_delay + self.window_sec = window_sec + self.current_page = -1 + + # 1. Load and downsample for extreme speed + self.df = pd.read_csv(csv_path) + # Plotting every 4th point is visually identical but 4x faster + self.df = self.df.iloc[::4, :].reset_index(drop=True) + + self.time_col = self.df.columns[0] + self.anno_col = self.df.columns[1] + self.channels = list(self.df.columns[2:]) + + # 2. Create PyQtGraph Plot + self.plot_widget = pg.PlotWidget() + self.plot_widget.setBackground('w') # White background + self.main_layout.addWidget(self.plot_widget) + + # 3. Add Shaded Annotations (Boxes) + self._add_pg_annotations() + + # 4. Plot Channels (PyQtGraph is very fast at this) + for i, ch in enumerate(self.channels): + # Using a single color or rotating colors + color = pg.intColor(i, hues=10, values=1, alpha=150) + self.plot_widget.plot(self.df[self.time_col], self.df[ch], pen=pg.mkPen(color, width=1)) + + # 5. Add Playhead (InfiniteLine is much faster than axvline) + self.playhead = pg.InfiniteLine(pos=0, angle=90, pen=pg.mkPen('r', width=2)) + self.plot_widget.addItem(self.playhead) + + self.plot_widget.setXRange(0, self.window_sec) + + def _add_pg_annotations(self): + labels = self.df[self.anno_col].fillna("None").astype(str) + changes = labels != labels.shift() + indices = labels.index[changes].tolist() + [len(labels)] + + for i in range(len(indices) - 1): + idx = indices[i] + label = labels.iloc[idx] + if label not in ["None", "0", "nan"]: + t1 = self.df[self.time_col].iloc[idx] + t2 = self.df[self.time_col].iloc[indices[i+1]-1] + # LinearRegionItem is great for shading + region = pg.LinearRegionItem(values=(t1, t2), brush=pg.mkBrush(0, 0, 255, 30), movable=False) + self.plot_widget.addItem(region) + + def update_plot(self, video_time): + aligned_time = (video_time + self.video_delay - self.boris_anchor) + 5 + + # 1. Update playhead position (Instantly fast in PyQtGraph) + self.playhead.setPos(aligned_time) + + # 2. Page flipping logic to avoid constant axis movement + page_number = int(aligned_time // self.window_sec) + if page_number != self.current_page: + new_start = page_number * self.window_sec + self.plot_widget.setXRange(new_start, new_start + self.window_sec, padding=0) + self.current_page = page_number + + + class HandDataInspector(QWidget): def __init__(self): super().__init__() @@ -2716,6 +2884,57 @@ class MainApplication(QMainWindow): file_path, self.extract_frame_and_hands # Assuming this is a method available to MainApplication ) + + msg_box = QMessageBox() + msg_box.setWindowTitle("Action Required") + msg_box.setText("BORIS?") + msg_box.setStandardButtons(QMessageBox.Yes | QMessageBox.No) + msg_box.setDefaultButton(QMessageBox.Yes) + + response = msg_box.exec() + + if response == QMessageBox.Yes: + # 3. Open File Dialog + self.selected_boris_json_path, _ = QFileDialog.getOpenFileName( + None, + "Select File", + "", + "All Files (*);;Text Files (*.txt)" + ) + + try: + with open(self.selected_boris_json_path, 'r') as f: + data = json.load(f) + + # 3. Extract keys from "observation" + observations = data.get("observations", {}) + if not observations: + QMessageBox.warning(None, "Error", "No 'observation' key found or it is empty.") + return + + keys = list(observations.keys()) + + # 4. Open Dropdown Dialog (QInputDialog) + selected_key, ok = QInputDialog.getItem( + None, "Select Key", "Pick an observation key:", keys, 0, False + ) + + # 5. Store the result + if ok and selected_key: + self.selected_boris_key = selected_key + print(f"Successfully stored: {self.selected_boris_key}") + QMessageBox.information(None, "Success", f"Stored: {self.selected_boris_key}") + + except json.JSONDecodeError: + QMessageBox.critical(None, "Error", "Invalid JSON format.") + except Exception as e: + QMessageBox.critical(None, "Error", f"An error occurred: {e}") + + else: + print("File selection cancelled.") + else: + # If 'No' or closed, do nothing + print("User declined.") # 4. Connect Signals to Main Thread Slots self.worker_thread.observations_loaded.connect(self.on_individual_files_loaded) @@ -3528,16 +3747,20 @@ class MainApplication(QMainWindow): # --- THREAD START --- output_csv = f"{obs_id}_{cam_id}_processed.csv" full_video_path = file_info["path"] + json_path = self.selected_boris_json_path + selected_obs_key = self.selected_boris_key # ⚠️ Ensure ParticipantProcessor is defined to accept self.observations_root processor = ParticipantProcessor2( obs_id=obs_id, - selected_cam_id=cam_id, + selected_cam_id=2, selected_hand_idx=selected_hand_idx, video_path=full_video_path, # CRITICAL PATH FIX hide_preview=False, # Assuming default for now output_csv=output_csv, output_dir=self.output_dir, - initial_wrists=self.camera_state[(obs_id, cam_id)].get("initial_wrists") + initial_wrists=self.camera_state[(obs_id, cam_id)].get("initial_wrists"), + json_path=json_path, + boris_obs_key=selected_obs_key ) # Connect signals to the UI elements