import jax
from brax import base, math
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from etils import epath
from jax import numpy as jnp

# Valid values for the ``kind`` argument of ``Pusher``.
_PUSHER_KINDS = ("easy", "hard")
# Target minimum object-to-goal distance used by the reset-time rescaling.
_MIN_OBJECT_GOAL_DIST = 0.17
# Distance threshold for the stricter ``success_hard`` metric (both envs).
_SUCCESS_HARD_THRESH = 0.05


def _load_pusher_system(backend: str):
    """Load Brax's bundled ``pusher.xml`` and pick a default ``n_frames`` for ``backend``.

    For the ``"spring"`` and ``"positional"`` backends the timestep is set to 0.001, every
    actuator gear is set to 20.0 and 50 frames per env step are used. All other backends
    keep the MJCF defaults and use 5 frames per env step.

    Returns:
        A ``(system, n_frames)`` tuple.
    """
    path = epath.resource_path("brax") / "envs/assets/pusher.xml"
    system = mjcf.load(path)
    if backend not in ["spring", "positional"]:
        return system, 5
    system = system.tree_replace({"opt.timestep": 0.001})
    system = system.replace(actuator=system.actuator.replace(gear=jnp.array([20.0] * system.act_size())))
    return system, 50


def _link_positions(system, pipeline_state: base.State) -> jax.Array:
    """Return the world position of every link's inertial frame, indexed by link index."""
    x_i = pipeline_state.x.vmap().do(base.Transform.create(pos=system.link.inertia.transform.pos))
    return x_i.pos


def _spread_object_from_goal(cylinder_pos: jax.Array, goal_pos: jax.Array) -> jax.Array:
    """Rescale the object's (x, y) when it starts closer than ``_MIN_OBJECT_GOAL_DIST`` to the goal.

    NOTE(review): this multiplies the object position itself (about the origin), not its offset
    from the goal. The resulting object-goal distance is therefore not guaranteed to reach the
    minimum. Kept as-is so the reset distribution matches the original Brax Pusher.
    """
    norm = math.safe_norm(cylinder_pos - goal_pos)
    scale = jnp.where(norm < _MIN_OBJECT_GOAL_DIST, _MIN_OBJECT_GOAL_DIST / norm, 1.0)
    return cylinder_pos * scale


class Pusher(PipelineEnv):
    """Goal-conditioned Pusher: push the cylinder object to a randomly placed goal.

    Based on the original Pusher environment from Brax:
    https://github.com/google/brax/blob/main/brax/envs/pusher.py

    Args:
        backend: Brax pipeline name. ``"spring"`` and ``"positional"`` get a smaller timestep,
            a larger actuator gear and more frames per step.
        kind: ``"easy"`` or ``"hard"``; selects the region the goal is sampled from on reset.
        dense_reward: if True, the reward is a weighted sum of object-goal distance, control
            cost and tip-object distance. Otherwise the reward is the 0/1 success indicator.
        **kwargs: forwarded to ``PipelineEnv``; ``n_frames`` defaults per backend if absent.

    Raises:
        ValueError: if ``kind`` is not ``"easy"`` or ``"hard"``.
    """

    def __init__(self, backend="generalized", kind="easy", dense_reward: bool = False, **kwargs):
        # Fail fast: an unknown kind would otherwise only surface inside reset().
        if kind not in _PUSHER_KINDS:
            raise ValueError(f"Unknown Pusher kind {kind!r}; expected one of {_PUSHER_KINDS}.")

        system, n_frames = _load_pusher_system(backend)
        kwargs["n_frames"] = kwargs.get("n_frames", n_frames)
        super().__init__(sys=system, backend=backend, **kwargs)

        # The tips_arm body gets fused with r_wrist_roll_link, so we use the parent
        # r_wrist_flex_link for tips_arm_idx.
        self._tips_arm_idx = self.sys.link_names.index("r_wrist_flex_link")
        self._object_idx = self.sys.link_names.index("object")
        self._goal_idx = self.sys.link_names.index("goal")

        self.kind = kind
        self.dense_reward = dense_reward
        # Length of the observation prefix that excludes the trailing goal position (3).
        self.state_dim = 20
        # Observation indices of the object position, i.e. the "achieved goal".
        self.goal_indices = jnp.array([10, 11, 12])
        # Object-goal distance below which the episode step counts as a success.
        self.goal_reach_thresh = 0.1

    def reset(self, rng: jax.Array) -> State:
        """Sample object and goal (x, y) positions and return the initial ``State``."""
        qpos = self.sys.init_q
        obj_x_key, obj_y_key, goal_x_key, goal_y_key, vel_key = jax.random.split(rng, 5)

        # Randomly place the object.
        cylinder_pos = jnp.concatenate(
            [
                jax.random.uniform(obj_x_key, (1,), minval=-0.3, maxval=-1e-6),
                jax.random.uniform(obj_y_key, (1,), minval=-0.2, maxval=0.2),
            ]
        )

        # Randomly place the goal; the sampling region depends on the env kind.
        if self.kind == "hard":
            goal_pos = jnp.concatenate(
                [
                    jax.random.uniform(goal_x_key, (1,), minval=-0.65, maxval=0.35),
                    jax.random.uniform(goal_y_key, (1,), minval=-0.55, maxval=0.45),
                ]
            )
        elif self.kind == "easy":
            goal_pos = jnp.concatenate(
                [
                    jax.random.uniform(goal_x_key, (1,), minval=-0.3, maxval=-1e-6) - 0.25,
                    jax.random.uniform(goal_y_key, (1,), minval=-0.2, maxval=0.2),
                ]
            )
        else:
            raise ValueError(f"Unknown Pusher kind {self.kind!r}; expected one of {_PUSHER_KINDS}.")

        cylinder_pos = _spread_object_from_goal(cylinder_pos, goal_pos)
        # The last four generalized coordinates are the object (x, y) followed by the goal (x, y).
        qpos = qpos.at[-4:].set(jnp.concatenate([cylinder_pos, goal_pos]))

        # Small random joint velocities; object and goal start at rest.
        qvel = jax.random.uniform(vel_key, (self.sys.qd_size(),), minval=-0.005, maxval=0.005)
        qvel = qvel.at[-4:].set(0.0)

        pipeline_state = self.pipeline_init(qpos, qvel)
        obs = self._get_obs(pipeline_state)
        reward, done, zero = jnp.zeros(3)
        metrics = {
            "reward_dist": zero,
            "reward_ctrl": zero,
            "reward_near": zero,
            "success": zero,
            "success_hard": zero,
        }
        return State(pipeline_state, obs, reward, done, metrics)

    def step(self, state: State, action: jax.Array) -> State:
        """Advance one env step.

        Rewards and success metrics are computed from link positions *before* the physics
        step is applied, as in the original Brax Pusher.
        """
        assert state.pipeline_state is not None
        pos = _link_positions(self.sys, state.pipeline_state)
        tip_to_obj_dist = math.safe_norm(pos[self._object_idx] - pos[self._tips_arm_idx])
        obj_to_goal_dist = math.safe_norm(pos[self._object_idx] - pos[self._goal_idx])

        reward_near = -tip_to_obj_dist
        reward_dist = -obj_to_goal_dist
        reward_ctrl = -jnp.square(action).sum()

        pipeline_state = self.pipeline_step(state.pipeline_state, action)
        obs = self._get_obs(pipeline_state)

        success = jnp.array(obj_to_goal_dist < self.goal_reach_thresh, dtype=float)
        if self.dense_reward:
            reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near
        else:
            reward = success

        state.metrics.update(
            reward_near=reward_near,
            reward_dist=reward_dist,
            reward_ctrl=reward_ctrl,
            success=success,
            success_hard=jnp.array(obj_to_goal_dist < _SUCCESS_HARD_THRESH, dtype=float),
        )
        return state.replace(pipeline_state=pipeline_state, obs=obs, reward=reward)

    def _get_obs(self, pipeline_state: base.State) -> jax.Array:
        """Build the observation: arm state, tip and object positions, then the goal position."""
        pos = _link_positions(self.sys, pipeline_state)
        return jnp.concatenate(
            [
                # state
                pipeline_state.q[:7],  # Rotations of arm joints [7, ]
                pos[self._tips_arm_idx],  # Arm tip position [3, ]
                pos[self._object_idx],  # Movable object position [3, ]
                pipeline_state.qd[:7],  # Rotational velocities of arm joints [7, ]
                # goal
                pos[self._goal_idx],  # Position the object should end up in [3, ]
            ]
        )


class PusherReacher(PipelineEnv):
    """Debug variant of Pusher with a Reacher-style task: move the arm tip to the goal.

    The object is placed at (1, 1) before the min-distance rescaling and plays no role in
    the reward.

    Args:
        backend: Brax pipeline name. ``"spring"`` and ``"positional"`` get a smaller timestep,
            a larger actuator gear and more frames per step.
        **kwargs: forwarded to ``PipelineEnv``; ``n_frames`` defaults per backend if absent.
    """

    def __init__(self, backend="generalized", **kwargs):
        system, n_frames = _load_pusher_system(backend)
        kwargs["n_frames"] = kwargs.get("n_frames", n_frames)
        super().__init__(sys=system, backend=backend, **kwargs)

        # The tips_arm body gets fused with r_wrist_roll_link, so we use the parent
        # r_wrist_flex_link for tips_arm_idx.
        self._tips_arm_idx = self.sys.link_names.index("r_wrist_flex_link")
        self._object_idx = self.sys.link_names.index("object")
        self._goal_idx = self.sys.link_names.index("goal")

        # Length of the observation prefix that excludes the trailing goal position (3).
        self.state_dim = 17
        # Observation indices of the arm tip position, i.e. the "achieved goal".
        self.goal_indices = jnp.array([14, 15, 16])
        # Tip-goal distance below which the step counts as a success (matches Pusher).
        self.goal_reach_thresh = 0.1

    def reset(self, rng: jax.Array) -> State:
        """Sample a goal (x, y) position and return the initial ``State``."""
        qpos = self.sys.init_q
        # The first two keys are unused; the split is kept so the goal and velocity keys
        # are unchanged.
        _, _, goal_x_key, goal_y_key, vel_key = jax.random.split(rng, 5)

        # The object is not part of this task; place it at a fixed position.
        cylinder_pos = jnp.array([1.0, 1.0])

        # Randomly place the goal.
        goal_pos = jnp.concatenate(
            [
                jax.random.uniform(goal_x_key, (1,), minval=-0.3, maxval=-1e-6),
                jax.random.uniform(goal_y_key, (1,), minval=-0.2, maxval=0.2),
            ]
        )

        cylinder_pos = _spread_object_from_goal(cylinder_pos, goal_pos)
        # The last four generalized coordinates are the object (x, y) followed by the goal (x, y).
        qpos = qpos.at[-4:].set(jnp.concatenate([cylinder_pos, goal_pos]))

        # Small random joint velocities; object and goal start at rest.
        qvel = jax.random.uniform(vel_key, (self.sys.qd_size(),), minval=-0.005, maxval=0.005)
        qvel = qvel.at[-4:].set(0.0)

        pipeline_state = self.pipeline_init(qpos, qvel)
        obs = self._get_obs(pipeline_state)
        reward, done, zero = jnp.zeros(3)
        metrics = {
            "reward_dist": zero,
            "reward_ctrl": zero,
            "reward_near": zero,
            "success": zero,
            "success_hard": zero,
        }
        return State(pipeline_state, obs, reward, done, metrics)

    def step(self, state: State, action: jax.Array) -> State:
        """Advance one env step; the reward is the negative tip-to-goal distance.

        The reward and metrics are computed from link positions *before* the physics step.
        """
        assert state.pipeline_state is not None
        pos = _link_positions(self.sys, state.pipeline_state)
        arm_to_goal_dist = math.safe_norm(pos[self._goal_idx] - pos[self._tips_arm_idx])
        reward_dist = -arm_to_goal_dist
        # JAX scalar zeros, so the metric types match those created in reset().
        zero = jnp.zeros_like(reward_dist)

        pipeline_state = self.pipeline_step(state.pipeline_state, action)
        obs = self._get_obs(pipeline_state)

        state.metrics.update(
            reward_near=zero,
            reward_dist=reward_dist,
            reward_ctrl=zero,
            success=jnp.array(arm_to_goal_dist < self.goal_reach_thresh, dtype=float),
            success_hard=jnp.array(arm_to_goal_dist < _SUCCESS_HARD_THRESH, dtype=float),
        )
        return state.replace(pipeline_state=pipeline_state, obs=obs, reward=reward_dist)

    def _get_obs(self, pipeline_state: base.State) -> jax.Array:
        """Build the observation: arm joint state, arm tip position, then the goal position."""
        pos = _link_positions(self.sys, pipeline_state)
        return jnp.concatenate(
            [
                # state
                pipeline_state.q[:7],  # Rotations of arm joints [7, ]
                pipeline_state.qd[:7],  # Rotational velocities of arm joints [7, ]
                pos[self._tips_arm_idx],  # Arm tip position [3, ]
                # goal
                pos[self._goal_idx],  # Position the arm tip should reach [3, ]
            ]
        )