improvements
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -174,4 +174,5 @@ cython_debug/
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
sparks_*/
|
||||
sparks_*/
|
||||
*.pth
|
||||
251
main.py
251
main.py
@@ -45,12 +45,15 @@ import matplotlib.pyplot as plt
|
||||
from PySide6.QtWidgets import (
|
||||
QApplication, QWidget, QMessageBox, QVBoxLayout, QHBoxLayout, QTextEdit, QScrollArea, QComboBox, QGridLayout,
|
||||
QPushButton, QMainWindow, QFileDialog, QLabel, QLineEdit, QFrame, QSizePolicy, QGroupBox, QDialog, QListView,
|
||||
QMenu, QProgressBar, QCheckBox, QSlider, QTabWidget, QTreeWidget, QTreeWidgetItem, QHeaderView
|
||||
QMenu, QProgressBar, QCheckBox, QSlider, QTabWidget, QTreeWidget, QTreeWidgetItem, QHeaderView, QInputDialog
|
||||
)
|
||||
from PySide6.QtCore import QThread, Signal, Qt, QTimer, QEvent, QSize, QPoint
|
||||
from PySide6.QtGui import QAction, QKeySequence, QIcon, QIntValidator, QDoubleValidator, QPixmap, QStandardItemModel, QStandardItem, QImage
|
||||
from PySide6.QtSvgWidgets import QSvgWidget # needed to show svgs when app is not frozen
|
||||
|
||||
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
||||
from matplotlib.figure import Figure
|
||||
import pyqtgraph as pg
|
||||
|
||||
CURRENT_VERSION = "1.0.0"
|
||||
|
||||
@@ -89,13 +92,21 @@ class TrainModelThread(QThread):
|
||||
|
||||
# Load CSVs
|
||||
X_list, y_list, feature_cols = [], [], None
|
||||
# for path in self.csv_paths:
|
||||
# df = pd.read_csv(path)
|
||||
# if feature_cols is None:
|
||||
# feature_cols = [c for c in df.columns if c.startswith("lm")]
|
||||
# df[feature_cols] = df[feature_cols].ffill().bfill()
|
||||
# X_list.append(df[feature_cols].values)
|
||||
# y_list.append(df[["reach_active", "reach_before_contact"]].values)
|
||||
|
||||
for path in self.csv_paths:
|
||||
df = pd.read_csv(path)
|
||||
if feature_cols is None:
|
||||
feature_cols = [c for c in df.columns if c.startswith("lm")]
|
||||
feature_cols = [c for c in df.columns if c.startswith("h0")]
|
||||
df[feature_cols] = df[feature_cols].ffill().bfill()
|
||||
X_list.append(df[feature_cols].values)
|
||||
y_list.append(df[["reach_active", "reach_before_contact"]].values)
|
||||
y_list.append(df[["current_event"]].values)
|
||||
|
||||
X = np.concatenate(X_list, axis=0)
|
||||
y = np.concatenate(y_list, axis=0)
|
||||
@@ -117,7 +128,7 @@ class TrainModelThread(QThread):
|
||||
y_windows = np.array(y_windows)
|
||||
|
||||
class_weights = []
|
||||
for i in range(2):
|
||||
for i in range(1):
|
||||
cw = compute_class_weight('balanced', classes=np.array([0,1]), y=y_windows[:, i])
|
||||
class_weights.append(cw[1]/cw[0])
|
||||
pos_weight = torch.tensor(class_weights, dtype=torch.float32)
|
||||
@@ -126,7 +137,7 @@ class TrainModelThread(QThread):
|
||||
|
||||
# LSTM
|
||||
class WindowLSTM(nn.Module):
|
||||
def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=2):
|
||||
def __init__(self, input_size, hidden_size=64, bidirectional=False, output_size=1):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True,
|
||||
bidirectional=bidirectional)
|
||||
@@ -1460,7 +1471,7 @@ class ParticipantProcessor2(QThread):
|
||||
finished_processing = Signal(str, str) # obs_id, cam_id
|
||||
|
||||
def __init__(self, obs_id, selected_cam_id, selected_hand_idx,
|
||||
video_path, output_csv, output_dir, initial_wrists, **kwargs):
|
||||
video_path, output_csv, output_dir, initial_wrists, json_path=None, boris_obs_key=None, **kwargs):
|
||||
super().__init__()
|
||||
self.obs_id = obs_id
|
||||
self.cam_id = selected_cam_id
|
||||
@@ -1468,6 +1479,8 @@ class ParticipantProcessor2(QThread):
|
||||
self.video_path = video_path
|
||||
self.output_dir = output_dir
|
||||
self.output_csv = output_csv
|
||||
self.json_path = json_path
|
||||
self.boris_obs_key = boris_obs_key
|
||||
self.is_running = True
|
||||
|
||||
# Convert initial_wrists (list) to initial_centroids (dict) for the tracker
|
||||
@@ -1532,7 +1545,29 @@ class ParticipantProcessor2(QThread):
|
||||
|
||||
save_path = os.path.join(self.output_dir, self.output_csv)
|
||||
|
||||
header = ['frame']
|
||||
event_list = []
|
||||
camera_offset = 0.0
|
||||
|
||||
if self.json_path and self.boris_obs_key:
|
||||
try:
|
||||
with open(self.json_path, 'r') as f:
|
||||
full_data = json.load(f)
|
||||
# Use the specific key selected from the dropdown
|
||||
obs_data = full_data.get("observations", {}).get(self.boris_obs_key, {})
|
||||
|
||||
# Get events: [time, ?, label, description, ...]
|
||||
event_list = obs_data.get("events", [])
|
||||
# Sort events by time to ensure lookup works
|
||||
event_list.sort(key=lambda x: x[0])
|
||||
|
||||
# Get camera offset
|
||||
media_info = obs_data.get("media_info", {})
|
||||
offsets = media_info.get("offset", {})
|
||||
camera_offset = float(offsets.get(self.cam_id, 0.0))
|
||||
except Exception as e:
|
||||
print(f"Error loading JSON events: {e}")
|
||||
|
||||
header = ['frame', 'time_sec', 'current_event']
|
||||
if self.selected_hand_idx == 99:
|
||||
# Dual hand columns: h0_x0, h0_y0 ... h1_x20, h1_y20
|
||||
for h_prefix in ['h0', 'h1']:
|
||||
@@ -1543,6 +1578,11 @@ class ParticipantProcessor2(QThread):
|
||||
for i in range(21):
|
||||
header.extend([f'x{i}', f'y{i}'])
|
||||
|
||||
|
||||
current_ev_idx = 0
|
||||
reach_active = 0 # This will be written as 1 or 0
|
||||
in_reach_phase = False
|
||||
|
||||
with open(save_path, 'w', newline='') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(header)
|
||||
@@ -1559,6 +1599,30 @@ class ParticipantProcessor2(QThread):
|
||||
results = detector.detect(mp_image)
|
||||
frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
|
||||
|
||||
current_time = frame_idx / fps
|
||||
current_event_label = "None"
|
||||
adjusted_time = current_time + camera_offset
|
||||
|
||||
while (current_ev_idx < len(event_list) and
|
||||
event_list[current_ev_idx][0] <= adjusted_time):
|
||||
|
||||
event_label = str(event_list[current_ev_idx][2]).strip()
|
||||
|
||||
# 1. Start on the FIRST "Reach"
|
||||
if event_label == "Reach" and not in_reach_phase:
|
||||
in_reach_phase = True
|
||||
reach_active = 1
|
||||
|
||||
# 2. Stop on "End"
|
||||
elif event_label == "End":
|
||||
in_reach_phase = False
|
||||
reach_active = 0
|
||||
|
||||
# Note: If event_label is the second "Reach", we do nothing.
|
||||
# It hits this 'else' (implicitly) and we just move the index forward.
|
||||
|
||||
current_ev_idx += 1
|
||||
|
||||
# current_mapping stores {hand_id: landmark_list} for this frame
|
||||
current_mapping = {}
|
||||
|
||||
@@ -1603,7 +1667,7 @@ class ParticipantProcessor2(QThread):
|
||||
last_known_centroids[hand_id] = self.get_centroid(landmarks)
|
||||
|
||||
# --- STEP 4: WRITE TO CSV ---
|
||||
row = [frame_idx]
|
||||
row = [frame_idx, round(current_time, 4), reach_active]
|
||||
|
||||
if self.selected_hand_idx == 99:
|
||||
# DUAL MODE: We expect Hand 0 and Hand 1
|
||||
@@ -1708,7 +1772,8 @@ class HandValidationWindow(QWidget):
|
||||
|
||||
self.inspector = HandDataInspector()
|
||||
self.inspector.show()
|
||||
|
||||
self.frame_counter = 0
|
||||
|
||||
# Logic state
|
||||
self.timer = QTimer()
|
||||
self.timer.timeout.connect(self.update_frame)
|
||||
@@ -1721,11 +1786,18 @@ class HandValidationWindow(QWidget):
|
||||
'h0': deque(maxlen=30),
|
||||
'h1': deque(maxlen=30)
|
||||
}
|
||||
self.csv_graph_window = None
|
||||
|
||||
def load_files(self):
|
||||
video_path, _ = QFileDialog.getOpenFileName(self, "Select Video", "", "Videos (*.mp4 *.avi)")
|
||||
csv_path, _ = QFileDialog.getOpenFileName(self, "Select CSV", "", "CSV Files (*.csv)")
|
||||
|
||||
# --- fNIRS CSV (52 channels, synced plot) ---
|
||||
csv_52_path, _ = QFileDialog.getOpenFileName(self, "Select 52-Channel CSV", "", "CSV Files (*.csv)")
|
||||
|
||||
# --- JSON (alignment info / time_shift) ---
|
||||
json_path, _ = QFileDialog.getOpenFileName(self, "Select JSON", "", "JSON Files (*.json)")
|
||||
|
||||
if video_path and csv_path:
|
||||
self.cap = cv2.VideoCapture(video_path)
|
||||
self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
@@ -1733,8 +1805,30 @@ class HandValidationWindow(QWidget):
|
||||
|
||||
# Load the dual-hand data
|
||||
self.hand_data = self.parse_dual_hand_csv(csv_path)
|
||||
|
||||
self.csv_52 = pd.read_csv(csv_52_path)
|
||||
self.csv_52_time_col = self.csv_52.columns[0]
|
||||
|
||||
with open(json_path, "r") as f:
|
||||
self.json_data = json.load(f)
|
||||
|
||||
# Extract values dynamically but hardcode the '2' key
|
||||
boris_anchor = self.json_data.get("boris_anchor", {}).get("time", 0.0) # Accessing video 2 delay specifically
|
||||
v2_delay = self.json_data.get("videos", {}).get("2", {}).get("delay", 0.0)
|
||||
|
||||
print(f"Syncing with Video 2: Shift {boris_anchor}, Delay {v2_delay}")
|
||||
|
||||
# Initialize the window with these values
|
||||
self.csv_graph_window = CSVGraphWindow(
|
||||
csv_52_path,
|
||||
boris_anchor=boris_anchor,
|
||||
video_delay=v2_delay
|
||||
)
|
||||
self.csv_graph_window.show()
|
||||
|
||||
self.timer.start(16)
|
||||
|
||||
|
||||
def update_speed(self):
|
||||
speed_map = {"0.25x": 64, "0.5x": 32, "1.0x": 16, "2.0x": 8}
|
||||
ms = speed_map.get(self.speed_combo.currentText(), 16)
|
||||
@@ -1792,7 +1886,7 @@ class HandValidationWindow(QWidget):
|
||||
ret, frame = self.cap.read()
|
||||
|
||||
if not ret: return
|
||||
|
||||
self.frame_counter += 1
|
||||
f_idx = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
|
||||
self.slider.setValue(f_idx)
|
||||
self.lbl_frame.setText(f"Frame: {f_idx}")
|
||||
@@ -1821,14 +1915,17 @@ class HandValidationWindow(QWidget):
|
||||
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
qimg = QImage(rgb.data, w, h, 3 * w, QImage.Format_RGB888)
|
||||
display_size = self.video_label.contentsRect().size()
|
||||
|
||||
|
||||
mode = Qt.TransformationMode.SmoothTransformation if self.paused else Qt.TransformationMode.FastTransformation
|
||||
self.video_label.setPixmap(
|
||||
QPixmap.fromImage(qimg).scaled(
|
||||
display_size,
|
||||
Qt.AspectRatioMode.KeepAspectRatio,
|
||||
Qt.TransformationMode.SmoothTransformation # Added for better quality
|
||||
mode
|
||||
)
|
||||
)
|
||||
if self.csv_graph_window and self.frame_counter % 2 == 0:
|
||||
self.csv_graph_window.update_plot(f_idx / fps)
|
||||
|
||||
def draw_skeleton(self, img, landmarks, w, h, color, h_key, prev_landmarks=None, fps=30):
|
||||
# --- 1. Draw Skeleton Connections ---
|
||||
@@ -1886,6 +1983,77 @@ class HandValidationWindow(QWidget):
|
||||
cv2.arrowedLine(img, com_px, smooth_end_px, (255, 255, 255), 5, tipLength=0.3)
|
||||
|
||||
|
||||
class CSVGraphWindow(QWidget):
|
||||
def __init__(self, csv_path, boris_anchor=0, video_delay=0, window_sec=15.0):
|
||||
super().__init__()
|
||||
self.setWindowTitle("High-Speed Data Viewer")
|
||||
self.resize(1200, 400)
|
||||
|
||||
self.main_layout = QVBoxLayout(self)
|
||||
self.boris_anchor = boris_anchor
|
||||
self.video_delay = video_delay
|
||||
self.window_sec = window_sec
|
||||
self.current_page = -1
|
||||
|
||||
# 1. Load and downsample for extreme speed
|
||||
self.df = pd.read_csv(csv_path)
|
||||
# Plotting every 4th point is visually identical but 4x faster
|
||||
self.df = self.df.iloc[::4, :].reset_index(drop=True)
|
||||
|
||||
self.time_col = self.df.columns[0]
|
||||
self.anno_col = self.df.columns[1]
|
||||
self.channels = list(self.df.columns[2:])
|
||||
|
||||
# 2. Create PyQtGraph Plot
|
||||
self.plot_widget = pg.PlotWidget()
|
||||
self.plot_widget.setBackground('w') # White background
|
||||
self.main_layout.addWidget(self.plot_widget)
|
||||
|
||||
# 3. Add Shaded Annotations (Boxes)
|
||||
self._add_pg_annotations()
|
||||
|
||||
# 4. Plot Channels (PyQtGraph is very fast at this)
|
||||
for i, ch in enumerate(self.channels):
|
||||
# Using a single color or rotating colors
|
||||
color = pg.intColor(i, hues=10, values=1, alpha=150)
|
||||
self.plot_widget.plot(self.df[self.time_col], self.df[ch], pen=pg.mkPen(color, width=1))
|
||||
|
||||
# 5. Add Playhead (InfiniteLine is much faster than axvline)
|
||||
self.playhead = pg.InfiniteLine(pos=0, angle=90, pen=pg.mkPen('r', width=2))
|
||||
self.plot_widget.addItem(self.playhead)
|
||||
|
||||
self.plot_widget.setXRange(0, self.window_sec)
|
||||
|
||||
def _add_pg_annotations(self):
|
||||
labels = self.df[self.anno_col].fillna("None").astype(str)
|
||||
changes = labels != labels.shift()
|
||||
indices = labels.index[changes].tolist() + [len(labels)]
|
||||
|
||||
for i in range(len(indices) - 1):
|
||||
idx = indices[i]
|
||||
label = labels.iloc[idx]
|
||||
if label not in ["None", "0", "nan"]:
|
||||
t1 = self.df[self.time_col].iloc[idx]
|
||||
t2 = self.df[self.time_col].iloc[indices[i+1]-1]
|
||||
# LinearRegionItem is great for shading
|
||||
region = pg.LinearRegionItem(values=(t1, t2), brush=pg.mkBrush(0, 0, 255, 30), movable=False)
|
||||
self.plot_widget.addItem(region)
|
||||
|
||||
def update_plot(self, video_time):
|
||||
aligned_time = (video_time + self.video_delay - self.boris_anchor) + 5
|
||||
|
||||
# 1. Update playhead position (Instantly fast in PyQtGraph)
|
||||
self.playhead.setPos(aligned_time)
|
||||
|
||||
# 2. Page flipping logic to avoid constant axis movement
|
||||
page_number = int(aligned_time // self.window_sec)
|
||||
if page_number != self.current_page:
|
||||
new_start = page_number * self.window_sec
|
||||
self.plot_widget.setXRange(new_start, new_start + self.window_sec, padding=0)
|
||||
self.current_page = page_number
|
||||
|
||||
|
||||
|
||||
class HandDataInspector(QWidget):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -2716,6 +2884,57 @@ class MainApplication(QMainWindow):
|
||||
file_path,
|
||||
self.extract_frame_and_hands # Assuming this is a method available to MainApplication
|
||||
)
|
||||
|
||||
msg_box = QMessageBox()
|
||||
msg_box.setWindowTitle("Action Required")
|
||||
msg_box.setText("BORIS?")
|
||||
msg_box.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
|
||||
msg_box.setDefaultButton(QMessageBox.Yes)
|
||||
|
||||
response = msg_box.exec()
|
||||
|
||||
if response == QMessageBox.Yes:
|
||||
# 3. Open File Dialog
|
||||
self.selected_boris_json_path, _ = QFileDialog.getOpenFileName(
|
||||
None,
|
||||
"Select File",
|
||||
"",
|
||||
"All Files (*);;Text Files (*.txt)"
|
||||
)
|
||||
|
||||
try:
|
||||
with open(self.selected_boris_json_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 3. Extract keys from "observation"
|
||||
observations = data.get("observations", {})
|
||||
if not observations:
|
||||
QMessageBox.warning(None, "Error", "No 'observation' key found or it is empty.")
|
||||
return
|
||||
|
||||
keys = list(observations.keys())
|
||||
|
||||
# 4. Open Dropdown Dialog (QInputDialog)
|
||||
selected_key, ok = QInputDialog.getItem(
|
||||
None, "Select Key", "Pick an observation key:", keys, 0, False
|
||||
)
|
||||
|
||||
# 5. Store the result
|
||||
if ok and selected_key:
|
||||
self.selected_boris_key = selected_key
|
||||
print(f"Successfully stored: {self.selected_boris_key}")
|
||||
QMessageBox.information(None, "Success", f"Stored: {self.selected_boris_key}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
QMessageBox.critical(None, "Error", "Invalid JSON format.")
|
||||
except Exception as e:
|
||||
QMessageBox.critical(None, "Error", f"An error occurred: {e}")
|
||||
|
||||
else:
|
||||
print("File selection cancelled.")
|
||||
else:
|
||||
# If 'No' or closed, do nothing
|
||||
print("User declined.")
|
||||
|
||||
# 4. Connect Signals to Main Thread Slots
|
||||
self.worker_thread.observations_loaded.connect(self.on_individual_files_loaded)
|
||||
@@ -3528,16 +3747,20 @@ class MainApplication(QMainWindow):
|
||||
# --- THREAD START ---
|
||||
output_csv = f"{obs_id}_{cam_id}_processed.csv"
|
||||
full_video_path = file_info["path"]
|
||||
json_path = self.selected_boris_json_path
|
||||
selected_obs_key = self.selected_boris_key
|
||||
# ⚠️ Ensure ParticipantProcessor is defined to accept self.observations_root
|
||||
processor = ParticipantProcessor2(
|
||||
obs_id=obs_id,
|
||||
selected_cam_id=cam_id,
|
||||
selected_cam_id=2,
|
||||
selected_hand_idx=selected_hand_idx,
|
||||
video_path=full_video_path, # CRITICAL PATH FIX
|
||||
hide_preview=False, # Assuming default for now
|
||||
output_csv=output_csv,
|
||||
output_dir=self.output_dir,
|
||||
initial_wrists=self.camera_state[(obs_id, cam_id)].get("initial_wrists")
|
||||
initial_wrists=self.camera_state[(obs_id, cam_id)].get("initial_wrists"),
|
||||
json_path=json_path,
|
||||
boris_obs_key=selected_obs_key
|
||||
)
|
||||
|
||||
# Connect signals to the UI elements
|
||||
|
||||
Reference in New Issue
Block a user