diff --git a/.gitignore b/.gitignore index d1e521b..2a532e8 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ data alignnet_model.pth +.venv/ diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..1445aee --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.10.14 diff --git a/alignit/config.py b/alignit/config.py index 15b8ff2..c564460 100644 --- a/alignit/config.py +++ b/alignit/config.py @@ -48,7 +48,7 @@ class ModelConfig: metadata={"help": "Path to save/load trained model"}, ) use_depth_input: bool = field( - default=True, metadata={"help": "Whether to use depth input for the model"} + default=False, metadata={"help": "Whether to use depth input for the model"} ) depth_hidden_dim: int = field( default=128, metadata={"help": "Output dimension of depth CNN"} @@ -92,6 +92,10 @@ class RecordConfig: dataset: DatasetConfig = field(default_factory=DatasetConfig) trajectory: TrajectoryConfig = field(default_factory=TrajectoryConfig) + robot_type: str = field( + default="sim", + metadata={"help": "Robot type: 'sim' for simulation or 'real' for real xArm robot"}, + ) episodes: int = field(default=10, metadata={"help": "Number of episodes to record"}) lin_tol_alignment: float = field( default=0.015, metadata={"help": "Linear tolerance for alignment servo"} @@ -164,7 +168,7 @@ class InferConfig: metadata={"help": "Number of iterations within tolerance before stopping"}, ) rotation_matrix_multiplier: int = field( - default=3, + default=2, metadata={ "help": "Number of times to multiply the rotation matrix of relative action in order to speed up convergence" }, diff --git a/alignit/infere.py b/alignit/infere.py index e533162..369ed06 100644 --- a/alignit/infere.py +++ b/alignit/infere.py @@ -1,5 +1,4 @@ import time - import torch import transforms3d as t3d import numpy as np @@ -10,8 +9,12 @@ from alignit.utils.zhou import sixd_se3 from alignit.utils.tfs import print_pose, are_tfs_close from alignit.robots.xarmsim import XarmSim -from 
alignit.robots.xarm import Xarm +Xarm = None +try: + from alignit.robots.xarm import Xarm +except ImportError: + pass @draccus.wrap() def main(cfg: InferConfig): @@ -19,6 +22,7 @@ def main(cfg: InferConfig): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") + # Initialize Model net = AlignNet( backbone_name=cfg.model.backbone, backbone_weights=cfg.model.backbone_weights, @@ -33,99 +37,146 @@ def main(cfg: InferConfig): net.to(device) net.eval() + # Initialize Robot robot = XarmSim() - - start_pose = t3d.affines.compose( - [0.23, 0, 0.25], t3d.euler.euler2mat(np.pi, 0, 0), [1, 1, 1] - ) - robot.servo_to_pose(start_pose, lin_tol=1e-2) - iteration = 0 - iterations_within_tolerance = 0 + + num_alignments = getattr(cfg, 'num_alignments', 5) ang_tol_rad = np.deg2rad(cfg.ang_tolerance) - try: - while True: - observation = robot.get_observation() - rgb_image = observation["rgb"].astype(np.float32) / 255.0 - depth_image = observation["depth"].astype(np.float32) - print( - "Min/Max depth,mean (raw):", - observation["depth"].min(), - observation["depth"].max(), - observation["depth"].mean(), - ) - print( - "Min/Max depth,mean (scaled):", - depth_image.min(), - depth_image.max(), - depth_image.mean(), - ) - rgb_image_tensor = ( - torch.from_numpy(np.array(rgb_image)) - .permute(2, 0, 1) # (H, W, C) -> (C, H, W) - .unsqueeze(0) - .to(device) - ) - - depth_image_tensor = ( - torch.from_numpy(np.array(depth_image)) - .unsqueeze(0) # Add channel dimension: (1, H, W) - .unsqueeze(0) # Add batch dimension: (1, 1, H, W) - .to(device) - ) - rgb_images_batch = rgb_image_tensor.unsqueeze(1) - depth_images_batch = depth_image_tensor.unsqueeze(1) - - with torch.no_grad(): - relative_action = net(rgb_images_batch, depth_images=depth_images_batch) - relative_action = relative_action.squeeze(0).cpu().numpy() - relative_action = sixd_se3(relative_action) - - if cfg.debug_output: - print_pose(relative_action) - - relative_action[:3, :3] = 
np.linalg.matrix_power( - relative_action[:3, :3], cfg.rotation_matrix_multiplier - ) - if are_tfs_close( - relative_action, lin_tol=cfg.lin_tolerance, ang_tol=ang_tol_rad - ): - iterations_within_tolerance += 1 - else: - iterations_within_tolerance = 0 + alignment_results = [] + + # Safety limit: Total attempts allowed before declaring a trial "Failed" + MAX_TOTAL_STEPS = 1000 + + print(f"\nRunning {num_alignments} alignment trials...\n") + + for alignment_trial in range(num_alignments): + print(f"\n{'='*60}") + print(f"Alignment Trial {alignment_trial + 1}/{num_alignments}") + print(f"{'='*60}") + + # 1. Randomize Start Pose + start_pose = t3d.affines.compose( + [np.random.uniform(0.15, 0.30), + np.random.uniform(-0.15, 0.15), + np.random.uniform(0.20, 0.35)], + t3d.euler.euler2mat(np.pi + np.random.uniform(-0.3, 0.3), + np.random.uniform(-0.3, 0.3), + np.random.uniform(-np.pi, np.pi)), + [1, 1, 1] + ) + robot.servo_to_pose(start_pose, lin_tol=1e-2, ang_tol=0.1) + + iteration = 0 + iterations_within_tolerance = 0 + trial_data = [] + + try: + while True: + # 2. Get Observation and Preprocess + observation = robot.get_observation() + rgb_np = observation["rgb"].astype(np.float32) / 255.0 + + if rgb_np.ndim == 2: + rgb_np = np.expand_dims(rgb_np, axis=-1) + if rgb_np.shape[-1] == 1: + rgb_np = np.repeat(rgb_np, 3, axis=-1) + + rgb_images_batch = ( + torch.from_numpy(rgb_np) + .permute(2, 0, 1) + .unsqueeze(0).unsqueeze(0) + .to(device) + ) - print(relative_action) - target_pose = robot.pose() @ relative_action - iteration += 1 - action = { - "pose": target_pose, - "gripper.pos": 1.0, - } - robot.send_action(action) - if iterations_within_tolerance >= cfg.max_iterations: - print(f"Reached maximum iterations ({cfg.max_iterations}) - stopping.") - print("Moving robot to final pose.") + # 3. 
Model Inference + with torch.no_grad(): + relative_action = net(rgb_images_batch) + + relative_action = relative_action.squeeze(0).cpu().numpy() + relative_action = sixd_se3(relative_action) + + # 4. Calculate error magnitude from ORIGINAL unscaled action + # This represents the actual residual error from the network + error_magnitude = np.linalg.norm(relative_action[:3, 3]) + + # 5. Check alignment based on ORIGINAL unscaled action + # This is the true convergence check - not affected by scaling + if are_tfs_close( + relative_action, lin_tol=cfg.lin_tolerance, ang_tol=ang_tol_rad + ): + iterations_within_tolerance += 1 + print(f"Step {iteration}: Within Tol ({iterations_within_tolerance}/{cfg.debouncing_count}) [error: {error_magnitude:.6f}]") + else: + # Reset if we move outside the tolerance zone + iterations_within_tolerance = 0 + print(f"Step {iteration}: Adjusting... [error: {error_magnitude:.6f}]") + + # 6. Scale action for robot movement (separate from convergence check) + # Error-magnitude based scaling: bigger errors → bigger movements, small errors → small movements + min_scale = 1.0 + error_scale = max(error_magnitude, min_scale) + + scaled_action = relative_action.copy() + scaled_action[:3, 3] *= error_scale # Scale translation + scaled_action[:3, :3] = np.linalg.matrix_power( + scaled_action[:3, :3], int(cfg.rotation_matrix_multiplier) + ) # Apply rotation power scaling + + # 7. Move Robot current_pose = robot.pose() - gripper_z_offset = np.array( - [ + target_pose = current_pose @ scaled_action + iteration += 1 + + robot.send_action({"pose": target_pose, "gripper.pos": 1.0}) + + # 8. 
Exit Conditions + + # SUCCESS: Remained close for the required number of iterations + if iterations_within_tolerance >= cfg.max_iterations: + print(f"✓ Converged after {iteration} total steps.") + + # Finalize: Move to height offset and close gripper + gripper_z_offset = np.array([ [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, cfg.manual_height], [0, 0, 0, 1], - ] - ) - offset_pose = current_pose @ gripper_z_offset - robot.servo_to_pose(pose=offset_pose) - robot.close_gripper() - robot.gripper_off() - - break - - time.sleep(10.0) - except KeyboardInterrupt: - print("\nExiting...") - + ]) + robot.servo_to_pose(pose=robot.pose() @ gripper_z_offset) + #robot.close_gripper() + #robot.gripper_off() + + alignment_results.append({ + "trial": alignment_trial + 1, + "success": True, + "iterations": iteration, + }) + break + + # FAILURE: Safety timeout reached + if iteration >= MAX_TOTAL_STEPS: + print(f"✗ Failed: Timeout reached ({MAX_TOTAL_STEPS} steps).") + alignment_results.append({ + "trial": alignment_trial + 1, + "success": False, + "iterations": iteration, + }) + break + + except KeyboardInterrupt: + print("\nTrial interrupted by user.") + break + + # Summary Statistics + print(f"\n{'='*60}") + print(f"INFERENCE SUMMARY") + print(f"{'='*60}") + successful = sum(1 for r in alignment_results if r["success"]) + print(f"Success Rate: {successful}/{len(alignment_results)} ({successful*100//max(1, len(alignment_results))}%)") + if alignment_results: + print(f"Avg Steps to Converge: {np.mean([r['iterations'] for r in alignment_results]):.1f}") + robot.disconnect() - if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/alignit/losses.py b/alignit/losses.py new file mode 100644 index 0000000..0264667 --- /dev/null +++ b/alignit/losses.py @@ -0,0 +1,21 @@ +import torch +import torch.nn as nn + +class InversePredictionWeightedLoss(nn.Module): + def __init__(self, epsilon: float = 1e-6): + super().__init__() + self.epsilon = epsilon + + def forward(self, 
pred, target): + pred_pos = pred[:, :3] + pred_rot = pred[:, 3:] + target_pos = target[:, :3] + target_rot = target[:, 3:] + + weights_pos = 1.0 / (torch.abs(pred_pos) + self.epsilon) + pos_loss = (weights_pos * (pred_pos - target_pos) ** 2).mean() + + rot_loss = torch.mean((pred_rot - target_rot) ** 2) + + loss = pos_loss + rot_loss + return loss \ No newline at end of file diff --git a/alignit/record.py b/alignit/record.py index 00e897b..d1ae283 100644 --- a/alignit/record.py +++ b/alignit/record.py @@ -21,6 +21,8 @@ from alignit.config import RecordConfig + + def generate_spiral_trajectory(start_pose, cfg): """Generate spiral trajectory using configuration parameters.""" trajectory = [] @@ -76,23 +78,21 @@ def generate_spiral_trajectory(start_pose, cfg): return trajectory - @draccus.wrap() def main(cfg: RecordConfig): """Record alignment dataset using configuration parameters.""" robot = XarmSim() + features = Features( { "images": Sequence(Image()), "action": Sequence(Value("float32")), - "depth": Sequence(Image()), } ) for episode in range(cfg.episodes): pose_start, pose_alignment_target = robot.reset() trajectory = generate_spiral_trajectory(pose_start, cfg.trajectory) - pose = robot.pose() frames = [] for pose in trajectory: robot.servo_to_pose( @@ -104,18 +104,15 @@ def main(cfg: RecordConfig): action_sixd = se3_sixd(action_pose) observation = robot.get_observation() - print(observation.keys()) frame = { "images": [observation["rgb"].copy()], "action": action_sixd, - "depth": [observation["depth"].copy()], } frames.append(frame) print(f"Episode {episode+1} completed with {len(frames)} frames.") episode_dataset = Dataset.from_list(frames, features=features) - # 2. Load existing dataset if available if os.path.exists(cfg.dataset.path): existing_dataset = load_from_disk(cfg.dataset.path) existing_dataset = existing_dataset.cast(features) @@ -123,17 +120,15 @@ def main(cfg: RecordConfig): else: combined_dataset = episode_dataset - # 3. 
Save to TEMPORARY location first (avoid self-overwrite) temp_path = f"{cfg.dataset.path}_temp" combined_dataset.save_to_disk(temp_path) - # 4. Atomic replacement (only after successful save) if os.path.exists(cfg.dataset.path): - shutil.rmtree(cfg.dataset.path) # Remove old version - shutil.move(temp_path, cfg.dataset.path) # Move new version into place + shutil.rmtree(cfg.dataset.path) + shutil.move(temp_path, cfg.dataset.path) robot.disconnect() if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/alignit/robots/xarm.py b/alignit/robots/xarm.py index 39c69fd..80f078c 100644 --- a/alignit/robots/xarm.py +++ b/alignit/robots/xarm.py @@ -2,9 +2,10 @@ import numpy as np import transforms3d as t3d -from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig -from lerobot_xarm.xarm import Xarm as LeXarm -from lerobot_xarm.config import XarmConfig +from lerobot.cameras.realsense.camera_realsense import RealSenseCamera +from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig +from lerobot_robot_xarm.xarm import Xarm as LeXarm +from lerobot_robot_xarm.config_xarm import XarmConfig from alignit.robots.robot import Robot from alignit.utils.tfs import are_tfs_close @@ -37,13 +38,10 @@ def get_intrinsics(self): return self.camera.get_intrinsics() def get_observation(self): - rgb_image, depth_image, acquisition_time = self.camera.async_read() - depth_array_clipped = np.clip(np.array(depth_image), a_min=0, a_max=1000) - depth_image = np.array(depth_array_clipped) / 1000.0 + rgb_image = self.camera.async_read() return { "rgb": rgb_image, - "depth": depth_image, } def disconnect(self): diff --git a/alignit/robots/xarmsim/__init__.py b/alignit/robots/xarmsim/__init__.py index 7b58a59..dbe72d8 100644 --- a/alignit/robots/xarmsim/__init__.py +++ b/alignit/robots/xarmsim/__init__.py @@ -98,6 +98,14 @@ def reset(self): random_pos, t3d.euler.euler2mat(roll, pitch, yaw), [1, 1, 1] ) 
self._set_object_pose("pickup_object", pose) + # Lock the object joint to prevent it from falling/moving + self._lock_object_joint("pickup_object") + # Disable collisions for the pickup object so it doesn't get pushed + # by the robot during alignment trials. + try: + self._disable_object_collisions("pickup_object") + except Exception: + pass pose1 = self._get_object_pose() pose_start = pose1 @ t3d.affines.compose( [0, 0, -0.1], t3d.euler.euler2mat(0, 0, 0), [1, 1, 1] @@ -125,6 +133,60 @@ def _set_object_pose(self, object_name, pose_matrix): self.data.qvel[qvel_adr : qvel_adr + 6] = 0 mj.mj_forward(self.model, self.data) + def _lock_object_joint(self, object_name: str): + """Lock object joint to prevent any movement (translation or rotation).""" + try: + body_id = self.model.body(object_name).id + joint_id = self.model.body_jntadr[body_id] + + if joint_id >= 0: + qvel_adr = self.model.jnt_dofadr[joint_id] + # Lock all 6 DOFs (3 translation + 3 rotation for free joint) + self.data.qvel[qvel_adr : qvel_adr + 6] = 0 + + # Set very high damping on the joint to resist any motion + try: + dof_adr = self.model.jnt_dofadr[joint_id] + for i in range(6): + self.model.dof_damping[dof_adr + i] = 100.0 # Very high damping + except Exception: + pass + except Exception: + pass + + def _disable_object_collisions(self, object_name: str): + """Disable collisions AND gravity for object to keep it stationary. + + This prevents the object from being pushed or reacting to contacts + when the robot touches it during benchmarking. + """ + try: + body_id = self.model.body(object_name).id + except Exception: + return + + # 1. Disable collisions: set contact type and affinity to 0 + start = int(self.model.body_geomadr[body_id]) + count = int(self.model.body_geomnum[body_id]) + for i in range(start, start + count): + try: + self.model.geom_contype[i] = 0 + self.model.geom_conaffinity[i] = 0 + except Exception: + pass + + # 2. 
Disable gravity on object by setting mass to near-zero + try: + self.model.body_mass[body_id] = 0.001 + except Exception: + pass + + # 3. Set inertia to near-zero to prevent rotation from contacts + try: + self.model.body_inertia[body_id] = [0.0001, 0.0001, 0.0001] + except Exception: + pass + def close_gripper(self): self._set_gripper_position(self.gripper_close_pos) @@ -159,6 +221,9 @@ def send_action(self, action): self.data.ctrl[self.mujoco_actuator_ids] = target_joint_qpos_for_mujoco mj.mj_step(self.model, self.data) + + # Lock object joint every frame to prevent any motion + self._lock_object_joint("pickup_object") self.viewer.sync() return True @@ -191,15 +256,9 @@ def get_observation(self): name = mujoco.mj_id2name(self.model, mujoco.mjtObj.mjOBJ_CAMERA, i) self.renderer.update_scene(self.data, camera=name) image = self.renderer.render() - self.renderer.enable_depth_rendering() - self.renderer.update_scene(self.data, camera=name) - image_depth = self.renderer.render() - self.renderer.disable_depth_rendering() # TODO: Handle multiple cameras obs["rgb"] = image[:, :, ::-1] - obs["depth"] = image_depth - obs["depth"] = np.clip(obs["depth"], 0, 1) return obs diff --git a/alignit/train.py b/alignit/train.py index f1c2266..b59e817 100644 --- a/alignit/train.py +++ b/alignit/train.py @@ -1,7 +1,6 @@ import torch from torch.utils.data import DataLoader from torch.optim import Adam -from torch.nn import MSELoss from tqdm import tqdm from datasets import load_from_disk from torchvision import transforms @@ -10,15 +9,14 @@ from alignit.config import TrainConfig from alignit.models.alignnet import AlignNet +from alignit.losses import InversePredictionWeightedLoss def collate_fn(batch): images = [item["images"] for item in batch] - depth_images = [item.get("depth", None) for item in batch] actions = [item["action"] for item in batch] return { "images": images, - "depth_images": depth_images, "action": torch.tensor(actions, dtype=torch.float32), } @@ -50,14 +48,13 @@ def 
main(cfg: TrainConfig): ) optimizer = Adam(net.parameters(), lr=cfg.learning_rate) - criterion = MSELoss() + criterion = InversePredictionWeightedLoss(epsilon=0.01) net.train() for epoch in range(cfg.epochs): total_loss = 0 for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"): images = batch["images"] - depth_images_pil = batch["depth_images"] actions = batch["action"].to(device) batch_rgb_tensors = [] @@ -76,32 +73,8 @@ def main(cfg: TrainConfig): batch_rgb_tensors = torch.stack(batch_rgb_tensors, dim=0).to(device) - batch_depth_tensors = None - if cfg.model.use_depth_input: - batch_depth_tensors = [] - for depth_sequence in depth_images_pil: - if depth_sequence is None: - raise ValueError( - "Depth images expected but not found when use_depth_input=True" - ) - - depth_sequence_processed = [] - for d_img in depth_sequence: - depth_array = np.array(d_img) - depth_tensor = torch.from_numpy(depth_array).float() - depth_tensor = depth_tensor.unsqueeze(0) - depth_sequence_processed.append(depth_tensor) - - stacked_depth = torch.stack(depth_sequence_processed, dim=0) - batch_depth_tensors.append(stacked_depth) - - batch_depth_tensors = torch.stack(batch_depth_tensors, dim=0).to(device) - optimizer.zero_grad() - if cfg.model.use_depth_input: - outputs = net(batch_rgb_tensors, depth_images=batch_depth_tensors) - else: - outputs = net(batch_rgb_tensors) + outputs = net(batch_rgb_tensors) loss = criterion(outputs, actions) loss.backward() diff --git a/alignit/visualize.py b/alignit/visualize.py index 60d83b8..70b23f8 100644 --- a/alignit/visualize.py +++ b/alignit/visualize.py @@ -13,18 +13,16 @@ def visualize(cfg: VisualizeConfig): def get_data(index): item = dataset[index] image = item["images"][0] - depth = item["depth"][0] action_sixd = item["action"] action = sixd_se3(action_sixd) label = get_pose_str(action, degrees=True) - return image, depth, label + return image, label gr.Interface( fn=get_data, inputs=gr.Slider(0, len(dataset) - 1, step=1, label="Index", 
interactive=True), outputs=[ gr.Image(type="pil", label="Image"), - gr.Image(type="pil", label="Depth Image"), gr.Text(label="Label"), ], title="Dataset Image Viewer", diff --git a/benchmark/__init__.py b/benchmark/__init__.py new file mode 100644 index 0000000..46dd880 --- /dev/null +++ b/benchmark/__init__.py @@ -0,0 +1 @@ +"""Benchmark module for alignment evaluation.""" diff --git a/benchmark/metrics.py b/benchmark/metrics.py new file mode 100644 index 0000000..998ceee --- /dev/null +++ b/benchmark/metrics.py @@ -0,0 +1,116 @@ +"""Metrics utilities for benchmark evaluation.""" + +import numpy as np + + +def compute_translation_error(pose_actual, pose_target): + """ + Compute translation error in mm. + + Args: + pose_actual: 4x4 transformation matrix (actual pose) + pose_target: 4x4 transformation matrix (target pose) + + Returns: + Translation error in mm + """ + t_actual = pose_actual[:3, 3] + t_target = pose_target[:3, 3] + error_m = np.linalg.norm(t_actual - t_target) + return error_m * 1000 # Convert to mm + + +def compute_rotation_error(pose_actual, pose_target): + """ + Compute rotation error in degrees using angle-axis representation. + + Args: + pose_actual: 4x4 transformation matrix (actual pose) + pose_target: 4x4 transformation matrix (target pose) + + Returns: + Rotation error in degrees + """ + R_actual = pose_actual[:3, :3] + R_target = pose_target[:3, :3] + + # Compute relative rotation + R_rel = R_actual @ R_target.T + + # Angle from the trace identity: tr(R) = 1 + 2*cos(theta) + angle_rad = np.arccos(np.clip((np.trace(R_rel) - 1) / 2, -1, 1)) + angle_deg = np.degrees(angle_rad) + + return angle_deg + + +def compute_pose_error(pose_actual, pose_target, trans_weight=1.0, rot_weight=1.0): + """ + Compute combined pose error. 
+ + Args: + pose_actual: 4x4 transformation matrix + pose_target: 4x4 transformation matrix + trans_weight: Weight for translation error (mm) + rot_weight: Weight for rotation error (degrees) + + Returns: + dict with 'translation', 'rotation', and 'combined' errors + """ + trans_err = compute_translation_error(pose_actual, pose_target) + rot_err = compute_rotation_error(pose_actual, pose_target) + + # Normalize and combine + combined = trans_weight * trans_err + rot_weight * rot_err + + return { + "translation_mm": trans_err, + "rotation_deg": rot_err, + "combined": combined, + } + + +def check_convergence(error, trans_tol_mm=5.0, rot_tol_deg=5.0): + """ + Check if pose error is within acceptable tolerance. + + Args: + error: dict from compute_pose_error() + trans_tol_mm: Translation tolerance in mm + rot_tol_deg: Rotation tolerance in degrees + + Returns: + bool: True if converged + """ + return (error["translation_mm"] <= trans_tol_mm and + error["rotation_deg"] <= rot_tol_deg) + + +def compute_statistics(errors): + """ + Compute statistics over a list of errors. 
+ + Args: + errors: list of error dicts from compute_pose_error() + + Returns: + dict with mean, std, min, max for each error type + """ + trans_errors = [e["translation_mm"] for e in errors] + rot_errors = [e["rotation_deg"] for e in errors] + + return { + "translation": { + "mean_mm": np.mean(trans_errors), + "std_mm": np.std(trans_errors), + "min_mm": np.min(trans_errors), + "max_mm": np.max(trans_errors), + }, + "rotation": { + "mean_deg": np.mean(rot_errors), + "std_deg": np.std(rot_errors), + "min_deg": np.min(rot_errors), + "max_deg": np.max(rot_errors), + }, + } diff --git a/benchmark/run_benchmark.py b/benchmark/run_benchmark.py new file mode 100644 index 0000000..4bffb1d --- /dev/null +++ b/benchmark/run_benchmark.py @@ -0,0 +1,261 @@ +"""Benchmark runner for alignment in simulation.""" + +import json +import time +from datetime import datetime +from pathlib import Path + +import numpy as np +import torch +import transforms3d as t3d +import draccus + +from alignit.config import InferConfig +from alignit.models.alignnet import AlignNet +from alignit.utils.zhou import sixd_se3 +from alignit.utils.tfs import are_tfs_close +from alignit.robots.xarmsim import XarmSim +from benchmark.metrics import ( + compute_pose_error, + check_convergence, + compute_statistics, +) + + +@draccus.wrap() +def run_benchmark(cfg: InferConfig, num_trials: int = 5, max_iterations: int = 50): + """ + Run alignment benchmark in simulation. 
+ + Args: + cfg: InferConfig with model settings + num_trials: Number of random alignment trials to run + max_iterations: Max iterations per trial + + Returns: + dict with benchmark results + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Load model + net = AlignNet( + backbone_name=cfg.model.backbone, + backbone_weights=cfg.model.backbone_weights, + use_vector_input=cfg.model.use_vector_input, + fc_layers=cfg.model.fc_layers, + vector_hidden_dim=cfg.model.vector_hidden_dim, + output_dim=cfg.model.output_dim, + feature_agg=cfg.model.feature_agg, + use_depth_input=cfg.model.use_depth_input, + ) + net.load_state_dict(torch.load(cfg.model.path, map_location=device)) + net.to(device) + net.eval() + + robot = XarmSim() + + # Initialize results + results = { + "timestamp": datetime.now().isoformat(), + "model_path": cfg.model.path, + "num_trials": num_trials, + "max_iterations": max_iterations, + "trials": [], + "summary": {}, + } + + # Run benchmark trials + print(f"\n{'='*60}") + print(f"Running {num_trials} alignment trials...") + print(f"{'='*60}\n") + + all_final_errors = [] + convergence_count = 0 + + for trial_idx in range(num_trials): + print(f"\n[Trial {trial_idx + 1}/{num_trials}]") + + # Random start pose + start_pose = t3d.affines.compose( + [np.random.uniform(0.15, 0.30), + np.random.uniform(-0.15, 0.15), + np.random.uniform(0.20, 0.35)], + t3d.euler.euler2mat(np.pi + np.random.uniform(-0.3, 0.3), + np.random.uniform(-0.3, 0.3), + np.random.uniform(-np.pi, np.pi)), + [1, 1, 1] + ) + + # Random target pose (within reasonable bounds) + target_pose = t3d.affines.compose( + [np.random.uniform(0.15, 0.30), + np.random.uniform(-0.15, 0.15), + np.random.uniform(0.20, 0.35)], + t3d.euler.euler2mat(np.pi + np.random.uniform(-0.3, 0.3), + np.random.uniform(-0.3, 0.3), + np.random.uniform(-np.pi, np.pi)), + [1, 1, 1] + ) + + robot.servo_to_pose(start_pose, lin_tol=1e-2, ang_tol=0.1) + + trial_data = 
{ + "start_pose": start_pose.tolist(), + "target_pose": target_pose.tolist(), + "iterations": [], + "converged": False, + } + + iteration = 0 + converged = False + + try: + while iteration < max_iterations: + observation = robot.get_observation() + rgb_np = observation["rgb"].astype(np.float32) / 255.0 + + # Ensure 3 channels + if rgb_np.ndim == 2: + rgb_np = np.expand_dims(rgb_np, axis=-1) + if rgb_np.shape[-1] == 1: + rgb_np = np.repeat(rgb_np, 3, axis=-1) + + rgb_images_batch = ( + torch.from_numpy(rgb_np) + .permute(2, 0, 1) + .unsqueeze(0) + .unsqueeze(0) + .to(device) + ) + + with torch.no_grad(): + relative_action = net(rgb_images_batch) + + relative_action = relative_action.squeeze(0).cpu().numpy() + relative_action = sixd_se3(relative_action) + + # Apply rotation multiplier for more stable convergence + relative_action[:3, :3] = np.linalg.matrix_power( + relative_action[:3, :3], int(cfg.rotation_matrix_multiplier) + ) + + # Compute error before this step + current_pose = robot.pose() + error = compute_pose_error(current_pose, target_pose) + + # Check convergence (ang_tolerance is already in degrees) + is_converged = check_convergence( + error, + trans_tol_mm=cfg.lin_tolerance * 1000, + rot_tol_deg=cfg.ang_tolerance + ) + + trial_data["iterations"].append({ + "iteration": iteration, + "translation_error_mm": error["translation_mm"], + "rotation_error_deg": error["rotation_deg"], + "converged": is_converged, + }) + + print(f" Iter {iteration + 1:2d}: " + f"Trans={error['translation_mm']:6.2f}mm, " + f"Rot={error['rotation_deg']:6.2f}°, " + f"Converged={is_converged}") + + if is_converged: + converged = True + trial_data["converged"] = True + convergence_count += 1 + break + + # Execute action + target_pose_iter = robot.pose() @ relative_action + robot.servo_to_pose(pose=target_pose_iter, lin_tol=1e-3, ang_tol=1e-2) + iteration += 1 + + except Exception as e: + print(f" Error during trial: {e}") + + # Get final error + final_pose = robot.pose() + final_error = 
compute_pose_error(final_pose, target_pose) + trial_data["final_error"] = { + "translation_mm": final_error["translation_mm"], + "rotation_deg": final_error["rotation_deg"], + } + trial_data["num_iterations"] = iteration + 1 + + all_final_errors.append(final_error) + results["trials"].append(trial_data) + + print(f" Final: Trans={final_error['translation_mm']:.2f}mm, " + f"Rot={final_error['rotation_deg']:.2f}°, " + f"Iters={trial_data['num_iterations']}") + + # Compute summary statistics + stats = compute_statistics(all_final_errors) + results["summary"] = { + "convergence_rate": convergence_count / num_trials, + "convergence_count": convergence_count, + "translation": stats["translation"], + "rotation": stats["rotation"], + "avg_iterations": np.mean([t["num_iterations"] for t in results["trials"]]), + } + + robot.disconnect() + + return results + + +def save_results(results, output_dir: str = "./benchmark/results"): + """Save benchmark results to JSON.""" + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + result_file = output_path / f"benchmark_{timestamp}.json" + + with open(result_file, "w") as f: + json.dump(results, f, indent=2) + + return result_file + + +def print_summary(results): + """Print benchmark summary to console.""" + print(f"\n{'='*60}") + print(f"BENCHMARK SUMMARY") + print(f"{'='*60}\n") + + summary = results["summary"] + + print(f"Convergence Rate: {summary['convergence_rate']*100:.1f}% " + f"({summary['convergence_count']}/{results['num_trials']})") + print(f"Average Iterations: {summary['avg_iterations']:.1f}") + + print(f"\nTranslation Error (mm):") + print(f" Mean: {summary['translation']['mean_mm']:.2f} ± {summary['translation']['std_mm']:.2f}") + print(f" Range: [{summary['translation']['min_mm']:.2f}, {summary['translation']['max_mm']:.2f}]") + + print(f"\nRotation Error (degrees):") + print(f" Mean: {summary['rotation']['mean_deg']:.2f} ± 
{summary['rotation']['std_deg']:.2f}") + print(f" Range: [{summary['rotation']['min_deg']:.2f}, {summary['rotation']['max_deg']:.2f}]") + + print(f"\n{'='*60}\n") + + +if __name__ == "__main__": + import sys + + # Parse arguments + num_trials = int(sys.argv[1]) if len(sys.argv) > 1 else 5 + + # Run benchmark + results = run_benchmark(num_trials=num_trials, max_iterations=50) + + # Save and print results + result_file = save_results(results) + print(f"Results saved to: {result_file}") + + print_summary(results)