import jax from brax import base from jax import numpy as jnp from .arm_envs import ArmEnvs """ Reach: Move end of arm to random goal. - Observation space: 13-dim obs + 3-dim goal. - Action space: 4-dim, each element in [-1, 1], corresponding to target angles for joints 1, 2, 4, 6. See _get_obs() and ArmEnvs._convert_action() for details. """ class ArmReach(ArmEnvs): def _get_xml_path(self): return "envs/assets/panda_reach.xml" @property def action_size(self) -> int: return 4 # Override default (actuator count) # See ArmEnvs._set_environment_attributes for descriptions of attributes def _set_environment_attributes(self): self.env_name = "arm_reach" self.episode_length = 100 self.goal_indices = jnp.array([7, 8, 9]) # End-effector position self.completion_goal_indices = jnp.array([7, 8, 9]) # Identical self.state_dim = 13 self.goal_reach_thresh = 0.1 self.arm_noise_scale = 0 self.goal_noise_scale = 0.2 def _get_initial_state(self, rng): target_q = self.sys.init_q[:7] arm_q_default = jnp.array( [1.571, 0.742, 0, -1.571, 0, 3.054, 1.449] ) # Start closer to the relevant area arm_q = arm_q_default + self.arm_noise_scale * jax.random.uniform( rng, [self.sys.q_size() - 7], minval=-1 ) q = jnp.concatenate([target_q] + [arm_q]) qd = jnp.zeros([self.sys.qd_size()]) return q, qd def _get_initial_goal(self, pipeline_state: base.State, rng): """ Generate goals in a box. x: [-0.2, 0.2], y: [0.3, 0.7], z: [0.1, 0.5] """ goal = jnp.array([0, 0.5, 0.3]) + self.goal_noise_scale * jax.random.uniform(rng, [3], minval=-1) return goal def _compute_goal_completion(self, obs, goal): # Goal occupancy: is the end of the arm close enough to the goal? eef_pos = obs[self.completion_goal_indices] goal_eef_pos = goal[:3] dist = jnp.linalg.norm(eef_pos - goal_eef_pos) success = jnp.array(dist < self.goal_reach_thresh, dtype=float) success_easy = jnp.array(dist < 0.3, dtype=float) success_hard = jnp.array(dist < 0.03, dtype=float) return success, success_easy, success_hard def _update_goal_visualization(self, pipeline_state: base.State, goal: jax.Array) -> base.State: updated_q = pipeline_state.q.at[:3].set(goal) # Only set the position, not orientation updated_pipeline_state = pipeline_state.replace(qpos=updated_q) return updated_pipeline_state def _get_obs(self, pipeline_state: base.State, goal: jax.Array, timestep) -> jax.Array: """ Observation space (13-dim) - q_subset (7-dim): joint angles - End of arm (6-dim): position and velocity Note q is 14-dim: 7-dim cube position/angle, 7-dim joint angles Goal space (3-dim): position of end of arm """ q_subset = pipeline_state.q[7:14] eef_index = 7 # Cube is 0, then links 1-7 are indices 1-7. The end-effector (eef) base is merged with link 7, so we say link 7 index = eef index. eef_x_pos = pipeline_state.x.pos[eef_index] eef_xd_vel = pipeline_state.xd.vel[eef_index] return jnp.concatenate([q_subset] + [eef_x_pos] + [eef_xd_vel] + [goal]) def _get_arm_angles(self, pipeline_state: base.State) -> jax.Array: q_indices = jnp.arange(7, 14) return pipeline_state.q[q_indices]