diff --git a/robosuite/examples/third_party_controller/mink_controller.py b/robosuite/examples/third_party_controller/mink_controller.py index 6d796ea6ee..39bcdef1a6 100644 --- a/robosuite/examples/third_party_controller/mink_controller.py +++ b/robosuite/examples/third_party_controller/mink_controller.py @@ -141,7 +141,7 @@ def __init__( robot_model: mujoco.MjModel, robot_joint_names: Optional[List[str]] = None, verbose: bool = False, - input_type: Literal["absolute", "relative", "relative_pose"] = "absolute", + input_type: Literal["absolute", "delta", "delta_pose"] = "absolute", input_ref_frame: Literal["world", "base", "eef"] = "world", input_rotation_repr: Literal["quat_wxyz", "axis_angle"] = "axis_angle", posture_weights: Dict[str, float] = None, @@ -315,17 +315,32 @@ def solve(self, input_action: np.ndarray) -> np.ndarray: input_quat_wxyz = input_ori if self.input_ref_frame == "base": - input_poses = np.zeros((len(self.site_ids), 4, 4)) - for i in range(len(self.site_ids)): - base_pos = self.configuration.data.body("robot0_base").xpos - base_ori = self.configuration.data.body("robot0_base").xmat.reshape(3, 3) + base_pos = self.configuration.data.body("robot0_base").xpos + base_ori = self.configuration.data.body("robot0_base").xmat.reshape(3, 3) + + if self.input_type == "absolute": + # For absolute poses, transform both position and orientation from base to world frame base_pose = T.make_pose(base_pos, base_ori) - input_pose = T.make_pose(input_pos[i], T.quat2mat(np.roll(input_quat_wxyz[i], -1))) - input_poses[i] = np.dot(base_pose, input_pose) - input_pos = input_poses[:, :3, 3] - input_quat_wxyz = np.array( - [np.roll(T.mat2quat(input_poses[i, :3, :3]), shift=1) for i in range(len(self.site_ids))] - ) + + # Transform each input pose to world frame + input_poses = [] + for pos, quat in zip(input_pos, input_quat_wxyz): + pose_in_base = T.make_pose(pos, T.quat2mat(np.roll(quat, -1))) + world_pose = base_pose @ pose_in_base + input_poses.append(world_pose) + + input_poses = np.array(input_poses) + input_pos = input_poses[:, :3, 3] + input_quat_wxyz = np.array([np.roll(T.mat2quat(pose[:3, :3]), 1) for pose in input_poses]) + + elif self.input_type == "delta": + # For deltas, only rotate the vectors from base to world frame + input_pos = np.array([base_ori @ pos for pos in input_pos]) + + # Transform rotation deltas using rotation matrices + input_quat_wxyz = np.array( + [np.roll(T.mat2quat(base_ori @ T.quat2mat(np.roll(quat, -1))), 1) for quat in input_quat_wxyz] + ) elif self.input_ref_frame == "eef": raise NotImplementedError("Input reference frame 'eef' not yet implemented") @@ -337,7 +352,7 @@ def solve(self, input_action: np.ndarray) -> np.ndarray: target_pos = input_pos + cur_pos target_quat_xyzw = np.array( [ - T.quat_multiply(T.mat2quat(cur_ori[i].reshape(3, 3)), np.roll(input_quat_wxyz[i], -1)) + T.quat_multiply(np.roll(input_quat_wxyz[i], -1), T.mat2quat(cur_ori[i].reshape(3, 3))) for i in range(len(self.site_ids)) ] ) @@ -489,9 +504,11 @@ def _init_joint_action_policy(self): self.joint_action_policy = IKSolverMink( model=self.sim.model._model, data=self.sim.data._data, - site_names=self.composite_controller_specific_config["ref_name"] - if "ref_name" in self.composite_controller_specific_config - else default_site_names, + site_names=( + self.composite_controller_specific_config["ref_name"] + if "ref_name" in self.composite_controller_specific_config + else default_site_names + ), robot_model=self.robot_model.mujoco_model, robot_joint_names=joint_names, input_type=self.composite_controller_specific_config.get("ik_input_type", "absolute"),