Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
data
alignnet_model.pth

.venv/
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.10.14
8 changes: 6 additions & 2 deletions alignit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ModelConfig:
metadata={"help": "Path to save/load trained model"},
)
use_depth_input: bool = field(
default=True, metadata={"help": "Whether to use depth input for the model"}
default=False, metadata={"help": "Whether to use depth input for the model"}
)
depth_hidden_dim: int = field(
default=128, metadata={"help": "Output dimension of depth CNN"}
Expand Down Expand Up @@ -92,6 +92,10 @@ class RecordConfig:

dataset: DatasetConfig = field(default_factory=DatasetConfig)
trajectory: TrajectoryConfig = field(default_factory=TrajectoryConfig)
robot_type: str = field(
default="sim",
metadata={"help": "Robot type: 'sim' for simulation or 'real' for real xArm robot"},
)
episodes: int = field(default=10, metadata={"help": "Number of episodes to record"})
lin_tol_alignment: float = field(
default=0.015, metadata={"help": "Linear tolerance for alignment servo"}
Expand Down Expand Up @@ -164,7 +168,7 @@ class InferConfig:
metadata={"help": "Number of iterations within tolerance before stopping"},
)
rotation_matrix_multiplier: int = field(
default=3,
default=2.0,
metadata={
"help": "Number of times to multiply the rotation matrix of relative action in order to speed up convergence"
},
Expand Down
225 changes: 138 additions & 87 deletions alignit/infere.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time

import torch
import transforms3d as t3d
import numpy as np
Expand All @@ -10,15 +9,20 @@
from alignit.utils.zhou import sixd_se3
from alignit.utils.tfs import print_pose, are_tfs_close
from alignit.robots.xarmsim import XarmSim
from alignit.robots.xarm import Xarm

Xarm = None
try:
from alignit.robots.xarm import Xarm
except ImportError:
pass

@draccus.wrap()
def main(cfg: InferConfig):
"""Run inference/alignment using configuration parameters."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize Model
net = AlignNet(
backbone_name=cfg.model.backbone,
backbone_weights=cfg.model.backbone_weights,
Expand All @@ -33,99 +37,146 @@ def main(cfg: InferConfig):
net.to(device)
net.eval()

# Initialize Robot
robot = XarmSim()

start_pose = t3d.affines.compose(
[0.23, 0, 0.25], t3d.euler.euler2mat(np.pi, 0, 0), [1, 1, 1]
)
robot.servo_to_pose(start_pose, lin_tol=1e-2)
iteration = 0
iterations_within_tolerance = 0

num_alignments = getattr(cfg, 'num_alignments', 5)
ang_tol_rad = np.deg2rad(cfg.ang_tolerance)
try:
while True:
observation = robot.get_observation()
rgb_image = observation["rgb"].astype(np.float32) / 255.0
depth_image = observation["depth"].astype(np.float32)
print(
"Min/Max depth,mean (raw):",
observation["depth"].min(),
observation["depth"].max(),
observation["depth"].mean(),
)
print(
"Min/Max depth,mean (scaled):",
depth_image.min(),
depth_image.max(),
depth_image.mean(),
)
rgb_image_tensor = (
torch.from_numpy(np.array(rgb_image))
.permute(2, 0, 1) # (H, W, C) -> (C, H, W)
.unsqueeze(0)
.to(device)
)

depth_image_tensor = (
torch.from_numpy(np.array(depth_image))
.unsqueeze(0) # Add channel dimension: (1, H, W)
.unsqueeze(0) # Add batch dimension: (1, 1, H, W)
.to(device)
)
rgb_images_batch = rgb_image_tensor.unsqueeze(1)
depth_images_batch = depth_image_tensor.unsqueeze(1)

with torch.no_grad():
relative_action = net(rgb_images_batch, depth_images=depth_images_batch)
relative_action = relative_action.squeeze(0).cpu().numpy()
relative_action = sixd_se3(relative_action)

if cfg.debug_output:
print_pose(relative_action)

relative_action[:3, :3] = np.linalg.matrix_power(
relative_action[:3, :3], cfg.rotation_matrix_multiplier
)
if are_tfs_close(
relative_action, lin_tol=cfg.lin_tolerance, ang_tol=ang_tol_rad
):
iterations_within_tolerance += 1
else:
iterations_within_tolerance = 0
alignment_results = []

# Safety limit: Total attempts allowed before declaring a trial "Failed"
MAX_TOTAL_STEPS = 1000

print(f"\nRunning {num_alignments} alignment trials...\n")

for alignment_trial in range(num_alignments):
print(f"\n{'='*60}")
print(f"Alignment Trial {alignment_trial + 1}/{num_alignments}")
print(f"{'='*60}")

# 1. Randomize Start Pose
start_pose = t3d.affines.compose(
[np.random.uniform(0.15, 0.30),
np.random.uniform(-0.15, 0.15),
np.random.uniform(0.20, 0.35)],
t3d.euler.euler2mat(np.pi + np.random.uniform(-0.3, 0.3),
np.random.uniform(-0.3, 0.3),
np.random.uniform(-np.pi, np.pi)),
[1, 1, 1]
)
robot.servo_to_pose(start_pose, lin_tol=1e-2, ang_tol=0.1)

iteration = 0
iterations_within_tolerance = 0
trial_data = []

try:
while True:
# 2. Get Observation and Preprocess
observation = robot.get_observation()
rgb_np = observation["rgb"].astype(np.float32) / 255.0

if rgb_np.ndim == 2:
rgb_np = np.expand_dims(rgb_np, axis=-1)
if rgb_np.shape[-1] == 1:
rgb_np = np.repeat(rgb_np, 3, axis=-1)

rgb_images_batch = (
torch.from_numpy(rgb_np)
.permute(2, 0, 1)
.unsqueeze(0).unsqueeze(0)
.to(device)
)

print(relative_action)
target_pose = robot.pose() @ relative_action
iteration += 1
action = {
"pose": target_pose,
"gripper.pos": 1.0,
}
robot.send_action(action)
if iterations_within_tolerance >= cfg.max_iterations:
print(f"Reached maximum iterations ({cfg.max_iterations}) - stopping.")
print("Moving robot to final pose.")
# 3. Model Inference
with torch.no_grad():
relative_action = net(rgb_images_batch)

relative_action = relative_action.squeeze(0).cpu().numpy()
relative_action = sixd_se3(relative_action)

# 4. Calculate error magnitude from ORIGINAL unscaled action
# This represents the actual residual error from the network
error_magnitude = np.linalg.norm(relative_action[:3, 3])

# 5. Check alignment based on ORIGINAL unscaled action
# This is the true convergence check - not affected by scaling
if are_tfs_close(
relative_action, lin_tol=cfg.lin_tolerance, ang_tol=ang_tol_rad
):
iterations_within_tolerance += 1
print(f"Step {iteration}: Within Tol ({iterations_within_tolerance}/{cfg.debouncing_count}) [error: {error_magnitude:.6f}]")
else:
# Reset if we move outside the tolerance zone
iterations_within_tolerance = 0
print(f"Step {iteration}: Adjusting... [error: {error_magnitude:.6f}]")

# 6. Scale action for robot movement (separate from convergence check)
# Error-magnitude based scaling: bigger errors → bigger movements, small errors → small movements
min_scale = 1.0
error_scale = max(error_magnitude, min_scale)

scaled_action = relative_action.copy()
scaled_action[:3, 3] *= error_scale # Scale translation
scaled_action[:3, :3] = np.linalg.matrix_power(
scaled_action[:3, :3], int(cfg.rotation_matrix_multiplier)
) # Apply rotation power scaling

# 7. Move Robot
current_pose = robot.pose()
gripper_z_offset = np.array(
[
target_pose = current_pose @ scaled_action
iteration += 1

robot.send_action({"pose": target_pose, "gripper.pos": 1.0})

# 8. Exit Conditions

# SUCCESS: Remained close for the required number of iterations
if iterations_within_tolerance >= cfg.max_iterations:
print(f"✓ Converged after {iteration} total steps.")

# Finalize: Move to height offset and close gripper
gripper_z_offset = np.array([
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, cfg.manual_height],
[0, 0, 0, 1],
]
)
offset_pose = current_pose @ gripper_z_offset
robot.servo_to_pose(pose=offset_pose)
robot.close_gripper()
robot.gripper_off()

break

time.sleep(10.0)
except KeyboardInterrupt:
print("\nExiting...")

])
robot.servo_to_pose(pose=robot.pose() @ gripper_z_offset)
#robot.close_gripper()
#robot.gripper_off()

alignment_results.append({
"trial": alignment_trial + 1,
"success": True,
"iterations": iteration,
})
break

# FAILURE: Safety timeout reached
if iteration >= MAX_TOTAL_STEPS:
print(f"✗ Failed: Timeout reached ({MAX_TOTAL_STEPS} steps).")
alignment_results.append({
"trial": alignment_trial + 1,
"success": False,
"iterations": iteration,
})
break

except KeyboardInterrupt:
print("\nTrial interrupted by user.")
break

# Summary Statistics
print(f"\n{'='*60}")
print(f"INFERENCE SUMMARY")
print(f"{'='*60}")
successful = sum(1 for r in alignment_results if r["success"])
print(f"Success Rate: {successful}/{len(alignment_results)} ({successful*100//max(1, len(alignment_results))}%)")
if alignment_results:
print(f"Avg Steps to Converge: {np.mean([r['iterations'] for r in alignment_results]):.1f}")

robot.disconnect()


if __name__ == "__main__":
main()
main()
21 changes: 21 additions & 0 deletions alignit/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn

class InversePredictionWeightedLoss(nn.Module):
    """MSE-style loss whose translation term is weighted by the inverse
    magnitude of the prediction.

    The first three channels of ``pred``/``target`` are treated as a
    translation; their squared error is scaled by ``1 / (|pred| + epsilon)``,
    so components with small predicted magnitude are penalized more heavily.
    The remaining channels (rotation representation) use a plain
    mean-squared error. The two terms are summed.
    """

    def __init__(self, epsilon: float = 1e-6):
        """``epsilon`` guards the weight denominator against division by zero."""
        super().__init__()
        self.epsilon = epsilon

    def forward(self, pred, target):
        """Return the weighted translation MSE plus the plain rotation MSE.

        Both inputs are ``(batch, 3 + R)`` tensors: three translation
        channels followed by ``R`` rotation channels.
        """
        # Split translation (first three channels) from rotation (the rest).
        trans_pred, rot_pred = pred[:, :3], pred[:, 3:]
        trans_tgt, rot_tgt = target[:, :3], target[:, 3:]

        # Inverse-magnitude weights: smaller predicted offsets -> larger weight.
        inv_weights = (trans_pred.abs() + self.epsilon).reciprocal()
        translation_term = torch.mean(inv_weights * (trans_pred - trans_tgt) ** 2)

        rotation_term = ((rot_pred - rot_tgt) ** 2).mean()

        return translation_term + rotation_term
17 changes: 6 additions & 11 deletions alignit/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from alignit.config import RecordConfig




def generate_spiral_trajectory(start_pose, cfg):
"""Generate spiral trajectory using configuration parameters."""
trajectory = []
Expand Down Expand Up @@ -76,23 +78,21 @@ def generate_spiral_trajectory(start_pose, cfg):

return trajectory


@draccus.wrap()
def main(cfg: RecordConfig):
"""Record alignment dataset using configuration parameters."""
robot = XarmSim()

features = Features(
{
"images": Sequence(Image()),
"action": Sequence(Value("float32")),
"depth": Sequence(Image()),
}
)

for episode in range(cfg.episodes):
pose_start, pose_alignment_target = robot.reset()
trajectory = generate_spiral_trajectory(pose_start, cfg.trajectory)
pose = robot.pose()
frames = []
for pose in trajectory:
robot.servo_to_pose(
Expand All @@ -104,36 +104,31 @@ def main(cfg: RecordConfig):
action_sixd = se3_sixd(action_pose)

observation = robot.get_observation()
print(observation.keys())
frame = {
"images": [observation["rgb"].copy()],
"action": action_sixd,
"depth": [observation["depth"].copy()],
}
frames.append(frame)
print(f"Episode {episode+1} completed with {len(frames)} frames.")

episode_dataset = Dataset.from_list(frames, features=features)

# 2. Load existing dataset if available
if os.path.exists(cfg.dataset.path):
existing_dataset = load_from_disk(cfg.dataset.path)
existing_dataset = existing_dataset.cast(features)
combined_dataset = concatenate_datasets([existing_dataset, episode_dataset])
else:
combined_dataset = episode_dataset

# 3. Save to TEMPORARY location first (avoid self-overwrite)
temp_path = f"{cfg.dataset.path}_temp"
combined_dataset.save_to_disk(temp_path)

# 4. Atomic replacement (only after successful save)
if os.path.exists(cfg.dataset.path):
shutil.rmtree(cfg.dataset.path) # Remove old version
shutil.move(temp_path, cfg.dataset.path) # Move new version into place
shutil.rmtree(cfg.dataset.path)
shutil.move(temp_path, cfg.dataset.path)

robot.disconnect()


if __name__ == "__main__":
main()
main()
Loading
Loading