from typing import Tuple

import jax
from brax import base, math
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from etils import epath
from jax import numpy as jnp

# This is based on original Reacher environment from Brax
# https://github.com/google/brax/blob/main/brax/envs/reacher.py


class Reacher(PipelineEnv):
    """Two-link planar arm whose fingertip must reach a randomly placed target.

    Observation layout produced by ``_get_obs`` (13 values):
        [0:2]   cos of the two arm joint angles
        [2:4]   sin of the two arm joint angles
        [4:7]   fingertip position
        [7:10]  fingertip linear velocity
        [10:13] target position

    The first ``state_dim`` entries describe the arm; ``goal_indices`` picks the
    fingertip position out of them (the quantity compared against the target).
    Reward is ``success`` (1.0 when the fingertip is within ``goal_reach_thresh``
    of the target, else 0.0) unless ``dense_reward`` is set, in which case it is
    the negative fingertip-to-target distance.
    """

    # Offset of the fingertip from the origin of arm body 1, expressed in that body's frame.
    _TIP_OFFSET = (0.11, 0.0, 0.0)

    def __init__(self, backend="generalized", dense_reward: bool = False, **kwargs):
        path = epath.resource_path("brax") / "envs/assets/reacher.xml"
        sys = mjcf.load(path)

        n_frames = 2
        if backend in ["spring", "positional"]:
            # These backends need a smaller timestep and stronger actuators.
            sys = sys.tree_replace({"opt.timestep": 0.005})
            sys = sys.replace(actuator=sys.actuator.replace(gear=jnp.array([25.0, 25.0])))
            n_frames = 4
        kwargs["n_frames"] = kwargs.get("n_frames", n_frames)

        super().__init__(sys=sys, backend=backend, **kwargs)

        self.dense_reward = dense_reward
        # Length of the arm-only prefix of the observation (everything before the target).
        self.state_dim = 10
        # Observation indices of the fingertip position.
        self.goal_indices = jnp.array([4, 5, 6])
        # Fingertip-to-target distance below which a step counts as a success.
        self.goal_reach_thresh = 0.05

    def reset(self, rng: jax.Array) -> State:
        """Samples a slightly perturbed initial arm state and a random target.

        Args:
            rng: PRNG key.

        Returns:
            The initial ``State`` with zero reward/done and zeroed metrics.
        """
        rng, rng1, rng2 = jax.random.split(rng, 3)

        q = self.sys.init_q + jax.random.uniform(rng1, (self.sys.q_size(),), minval=-0.1, maxval=0.1)
        qd = jax.random.uniform(rng2, (self.sys.qd_size(),), minval=-0.005, maxval=0.005)

        # Entries after the two arm joints hold the target's coordinates; place it and keep it still.
        _, target = self._random_target(rng)
        q = q.at[2:].set(target)
        qd = qd.at[2:].set(0)

        pipeline_state = self.pipeline_init(q, qd)
        obs = self._get_obs(pipeline_state)

        reward, done, zero = jnp.zeros(3)
        metrics = {
            "reward_dist": zero,
            "reward_ctrl": zero,
            "success": zero,
            "dist": zero,
        }
        return State(pipeline_state, obs, reward, done, metrics)

    def step(self, state: State, action: jax.Array) -> State:
        """Advances the physics by one environment step and scores the result.

        Args:
            state: Current environment state.
            action: Joint actuation for the two arm joints.

        Returns:
            A new ``State``. ``state`` itself (including its metrics dict) is not modified.
        """
        pipeline_state = self.pipeline_step(state.pipeline_state, action)
        obs = self._get_obs(pipeline_state)

        target_pos = pipeline_state.x.pos[2]
        tip_to_target = target_pos - self._tip_pos(pipeline_state)

        dist = jnp.linalg.norm(tip_to_target)
        # safe_norm keeps the gradient finite when the tip sits exactly on the target.
        reward_dist = -math.safe_norm(tip_to_target)
        success = jnp.array(dist < self.goal_reach_thresh, dtype=float)
        reward = reward_dist if self.dense_reward else success

        # Build a fresh dict instead of calling state.metrics.update(): State.replace is a
        # shallow copy, so in-place updates would alias the metrics of every earlier state.
        metrics = {
            **state.metrics,
            "reward_dist": reward_dist,
            "success": success,
            "dist": dist,
        }
        return state.replace(pipeline_state=pipeline_state, obs=obs, reward=reward, metrics=metrics)

    def _tip_pos(self, pipeline_state: base.State) -> jax.Array:
        """Returns the world-frame position of the fingertip on arm body 1."""
        tip = base.Transform.create(pos=jnp.array(self._TIP_OFFSET))
        return pipeline_state.x.take(1).do(tip).pos

    def _tip_vel(self, pipeline_state: base.State) -> jax.Array:
        """Returns the linear velocity of the fingertip on arm body 1."""
        tip = base.Transform.create(pos=jnp.array(self._TIP_OFFSET))
        return tip.do(pipeline_state.xd.take(1)).vel

    def _get_obs(self, pipeline_state: base.State) -> jax.Array:
        """Returns the observation: arm state followed by the target position."""
        theta = pipeline_state.q[:2]
        target_pos = pipeline_state.x.pos[2]
        return jnp.concatenate(
            [
                # state
                jnp.cos(theta),
                jnp.sin(theta),
                self._tip_pos(pipeline_state),
                self._tip_vel(pipeline_state),
                # target/goal
                target_pos,
            ]
        )

    def _random_target(self, rng: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """Samples a target (x, y) within radius 0.2 of the origin.

        The radius and angle are drawn uniformly and independently, so samples are
        denser near the centre than a uniform-over-area draw would be.

        Returns:
            A fresh PRNG key and the ``(x, y)`` target as a length-2 array.
        """
        rng, rng1, rng2 = jax.random.split(rng, 3)
        dist = 0.2 * jax.random.uniform(rng1)
        ang = jnp.pi * 2.0 * jax.random.uniform(rng2)
        target_x = dist * jnp.cos(ang)
        target_y = dist * jnp.sin(ang)
        return rng, jnp.array([target_x, target_y])