JaxGCRL

Installation | Quick Start | Environments | Baselines | Citation


Accelerating Goal-Conditioned RL Algorithms and Research

arXiv link: https://arxiv.org/abs/2408.11052

We provide blazingly fast goal-conditioned environments based on MJX and BRAX for quick experimentation with goal-conditioned self-supervised reinforcement learning.

  • Blazingly Fast Training - Train 10 million environment steps in 10 minutes on a single GPU, up to $22\times$ faster than prior implementations.
  • Comprehensive Benchmarking - Includes 10+ diverse environments and multiple pre-implemented baselines for out-of-the-box evaluation.
  • Modular Implementation - Designed for clarity and scalability, allowing for easy modification of algorithms.
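To put the training-speed claim in concrete terms, the short calculation below converts the quoted figures into a throughput. It uses only the numbers stated above. Because the speed-up is quoted as "up to" 22×, the derived baseline time is an upper bound rather than a measured value.

```python
# Figures quoted in the feature list above.
total_env_steps = 10_000_000   # environment steps
wall_clock_minutes = 10        # on a single GPU
max_speedup = 22               # "up to 22x" over prior implementations

steps_per_second = total_env_steps / (wall_clock_minutes * 60)
# Time a prior implementation would need if the full 22x speed-up applied.
baseline_minutes = wall_clock_minutes * max_speedup

print(f"{steps_per_second:,.0f} env steps per second")
print(f"about {baseline_minutes / 60:.1f} hours at the maximum quoted speed-up")
```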

Installation 📂

Editable Install (Recommended)

After cloning the repository, run one of the following commands.

With GPU on Linux:

pip install -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

[!NOTE]
Make sure you have the correct CUDA version installed, i.e. CUDA >= 12.3. You can check your CUDA version with the nvcc --version command. If you have an older version, you can create a new conda environment with the correct CUDA version and the JaxGCRL package using the following command:

conda env create -f environment.yml

With CPU on Mac:

export SDKROOT="$(xcrun --show-sdk-path)" # may be needed to build brax dependencies
pip install -e .

PyPI

The package is also available on PyPI:

pip install jaxgcrl -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Quick Start 🚀

To verify the installation, run a test experiment:

jaxgcrl crl --env ant

The jaxgcrl command is equivalent to invoking python run.py with the same arguments.

[!NOTE]
If you haven't yet configured wandb, you may be prompted to log in.

See scripts/train.sh for an example config. A description of the available agents can be generated with jaxgcrl --help. Available configs can be listed with jaxgcrl {crl,ppo,sac,td3} --help. Common flags you may want to change include:

  • --env ...: replace "ant" with any environment name. See jaxgcrl/utils/env.py for a list of available environments.
  • Removing --log_wandb: omits logging, if you don't want to use a wandb account.
  • --total_env_steps: shorter or longer runs.
  • --num_envs: based on how many environments your GPU memory allows.
  • --contrastive_loss_fn, --energy_fn, --h_dim, --n_hidden, etc.: algorithmic and architectural changes.
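When launching many runs, it can help to assemble the command line programmatically. The sketch below is a minimal illustration using only the flags listed above. `build_command` is a hypothetical helper, not part of the package; the numeric values are placeholders, and the `--flag value` form is assumed from the `--env ant` example.

```python
import shlex

def build_command(agent, env, log_wandb=False, **flags):
    """Assemble a jaxgcrl command line (hypothetical helper, not part of the package)."""
    args = ["jaxgcrl", agent, "--env", env]
    if log_wandb:
        args.append("--log_wandb")
    for name, value in flags.items():
        # Each keyword becomes "--name value", in the order it was passed.
        args += [f"--{name}", str(value)]
    return shlex.join(args)

# The numeric values are placeholders, not recommended settings.
cmd = build_command("crl", "ant", total_env_steps=10_000_000, num_envs=256)
print(cmd)
```

The resulting string can be pasted into a shell or passed to a job scheduler; check `jaxgcrl crl --help` for the flags your installed version actually accepts.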

[!NOTE]
We recommend using the calculator by @riiswa to check the correctness of hyperparameters.

Environment Interaction

Environments can be controlled with the reset and step functions. These methods return a state object, which is a dataclass containing the following fields:

state.pipeline_state: current, internal state of the environment
state.obs: current observation
state.done: flag indicating if the agent reached the goal
state.metrics: agent performance metrics
state.info: additional info

The following code demonstrates how to interact with the environment:

import jax
from utils.env import create_env

key = jax.random.PRNGKey(0)

# Initialize the environment
env = create_env('ant')

# Use JIT compilation to make environment's reset and step functions execute faster
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

NUM_STEPS = 1000

# Reset the environment and obtain the initial state
state = jit_env_reset(key)

# Simulate the environment for a fixed number of steps
for _ in range(NUM_STEPS):
    # Generate a random action
    key, key_act = jax.random.split(key, 2)
    random_action = jax.random.uniform(key_act, shape=(8,), minval=-1, maxval=1)

    # Perform an environment step with the generated action
    state = jit_env_step(state, random_action)
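The reset/step contract can also be explored without a GPU or the package installed. The sketch below is a stand-in written in plain Python: `State` and `ToyPointEnv` are hypothetical and are not Brax or JaxGCRL classes. It mirrors the five state fields listed above on a 1-D point that moves toward a goal, and shows how `done`, `metrics`, and `info` might be read after each step.

```python
from dataclasses import dataclass, field, replace

@dataclass(frozen=True)
class State:
    """Stand-in with the same five fields as the state object described above."""
    pipeline_state: float            # internal simulator state (here: a 1-D position)
    obs: tuple                       # observation: (position, goal)
    done: bool                       # True once the goal is reached
    metrics: dict = field(default_factory=dict)
    info: dict = field(default_factory=dict)

class ToyPointEnv:
    """Hypothetical 1-D environment; NOT part of JaxGCRL or Brax."""
    goal = 1.0
    tolerance = 0.05

    def reset(self):
        return State(0.0, (0.0, self.goal), False, {"dist": self.goal}, {"steps": 0})

    def step(self, state, action):
        # Step returns a new state instead of mutating the old one.
        pos = state.pipeline_state + max(-0.1, min(0.1, action))
        dist = abs(self.goal - pos)
        return replace(state, pipeline_state=pos, obs=(pos, self.goal),
                       done=dist <= self.tolerance, metrics={"dist": dist},
                       info={"steps": state.info["steps"] + 1})

env = ToyPointEnv()
state = env.reset()
while not state.done:
    state = env.step(state, action=0.1)   # always move toward the goal
print(state.info["steps"], round(state.metrics["dist"], 3))
```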

Wandb support 📈

We strongly recommend using Wandb for tracking and visualizing results. Enable Wandb logging with the --log_wandb flag. The following flags are also available to organize experiments:

  • --project_name
  • --group_name
  • --exp_name

The --log_wandb flag logs metrics to Wandb. By default, metrics are logged to a CSV.

  1. Run an example sweep:
wandb sweep --project example_sweep ./scripts/sweep.yml
  2. Then run the wandb agent with the sweep ID printed by the previous command:
wandb agent <previous_command_output>

We also render videos of the learned policies as wandb artifacts.

Environments 🌎

We currently support a variety of continuous control environments:

  • Locomotion: Half-Cheetah, Ant, Humanoid
  • Locomotion + task: AntMaze, AntBall (AntSoccer), AntPush, HumanoidMaze
  • Simple arm: Reacher, Pusher, Pusher 2-object
  • Manipulation: Reach, Grasp, Push (easy/hard), Binpick (easy/hard)

| Environment | Env name | Code |
| --- | --- | --- |
| Reacher | reacher | link |
| Half Cheetah | cheetah | link |
| Pusher | pusher_easy, pusher_hard | link |
| Ant | ant | link |
| Ant Maze | ant_u_maze, ant_big_maze, ant_hardest_maze | link |
| Ant Soccer | ant_ball | link |
| Ant Push | ant_push | link |
| Humanoid | humanoid | link |
| Humanoid Maze | humanoid_u_maze, humanoid_big_maze, humanoid_hardest_maze | link |
| Arm Reach | arm_reach | link |
| Arm Grasp | arm_grasp | link |
| Arm Push | arm_push_easy, arm_push_hard | link |
| Arm Binpick | arm_binpick_easy, arm_binpick_hard | link |

To add a new environment: add an XML file to envs/assets, add a Python environment file in envs, and register the environment name in utils/env.py.
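The "register the environment name" step can be pictured as a name-to-constructor lookup. The sketch below shows one common way to structure such a lookup. Everything in it (`ENV_REGISTRY`, `register_env`, `make_env`, `MyNewEnv`, and the asset path) is hypothetical, and the repository's actual registration code may look different.

```python
from typing import Callable, Dict

# Hypothetical registry: maps an env name (as passed to --env) to a constructor.
ENV_REGISTRY: Dict[str, Callable[[], object]] = {}

def register_env(name: str):
    """Decorator that records a constructor under the given name."""
    def decorator(ctor):
        if name in ENV_REGISTRY:
            raise ValueError(f"environment {name!r} is already registered")
        ENV_REGISTRY[name] = ctor
        return ctor
    return decorator

def make_env(name: str):
    """Look up and build an environment; unknown names fail with a readable error."""
    try:
        return ENV_REGISTRY[name]()
    except KeyError:
        known = sorted(ENV_REGISTRY)
        raise ValueError(f"unknown environment {name!r}; known: {known}") from None

@register_env("my_new_env")                    # placeholder name
class MyNewEnv:
    xml_path = "envs/assets/my_new_env.xml"    # placeholder asset path

env = make_env("my_new_env")
print(type(env).__name__, env.xml_path)
```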

Baselines 🤖

We currently support the following algorithms:

| Algorithm | How to run | Code |
| --- | --- | --- |
| CRL | python run.py crl ... | link |
| PPO | python run.py ppo ... | link |
| SAC | python run.py sac ... | link |
| SAC + HER | python run.py sac ... --use_her | link |
| TD3 | python run.py td3 ... | link |
| TD3 + HER | python run.py td3 ... --use_her | link |
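The `--use_her` flag adds hindsight experience replay (HER) to SAC and TD3. As a rough illustration of the idea only, and not the repository's implementation, the sketch below relabels each transition of a trajectory with a state reached at the same or a later step (the "future" strategy). The `Transition` type, the `relabel_future` function, and the sparse reward are assumptions made for this example.

```python
import random
from typing import List, NamedTuple

class Transition(NamedTuple):
    obs: float        # state before the action (1-D for simplicity)
    action: float
    next_obs: float   # state reached after the action
    goal: float       # goal the agent was commanded to reach
    reward: float

def relabel_future(trajectory: List[Transition], rng: random.Random,
                   tol: float = 1e-6) -> List[Transition]:
    """Replace each goal with a state actually reached at the same or a later step."""
    relabeled = []
    for t, tr in enumerate(trajectory):
        future = rng.randrange(t, len(trajectory))   # index in [t, T-1]
        new_goal = trajectory[future].next_obs
        # Sparse reward: 1 if this transition itself lands on the new goal.
        reward = 1.0 if abs(tr.next_obs - new_goal) <= tol else 0.0
        relabeled.append(tr._replace(goal=new_goal, reward=reward))
    return relabeled

# A failed trajectory: the commanded goal 5.0 is never reached.
traj = [Transition(float(i), 1.0, float(i + 1), 5.0, 0.0) for i in range(3)]
relabeled = relabel_future(traj, random.Random(0))
for tr in relabeled:
    print(tr)
```

After relabeling, every transition carries a goal that the trajectory really reached, so even a failed rollout yields some successful examples for the off-policy learner.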

Code Structure 📝

The core structure of the codebase is as follows:


run.py: Takes the name of an agent and runs with the specified configs.
agents/
├── agents/
│   ├── crl/ 
│   │   ├── crl.py CRL algorithm 
│   │   ├── losses.py contrastive losses and energy functions
│   │   └── networks.py CRL network architectures
│   ├── ppo/ 
│   │   └── ppo.py PPO algorithm 
│   ├── sac/ 
│   │   ├── sac.py SAC algorithm
│   │   └── networks.py SAC network architectures
│   └── td3/ 
│       ├── td3.py TD3 algorithm
│       ├── losses.py TD3 loss functions
│       └── networks.py TD3 network architectures
├── utils/
│   ├── config.py Base run configs
│   ├── env.py Logic for rendering and environment initialization
│   ├── replay_buffer.py: Contains replay buffer, including logic for state, action, and goal sampling for training.
│   └── evaluator.py: Runs evaluation and collects metrics.
├── envs/
│   ├── ant.py, humanoid.py, ...: Most environments are here.
│   ├── assets: Contains XMLs for environments.
│   └── manipulation: Contains all manipulation environments.
└── scripts/train.sh: Modify to choose environment and hyperparameters.

The architecture can be adjusted in networks.py.
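For orientation on what "contrastive losses and energy functions" in losses.py refers to, here is a generic InfoNCE-style loss paired with a negative L2 distance energy. It is a textbook formulation written in plain Python for readability, under assumed names (`l2_energy`, `infonce_loss`). It is not code from the repository, and it makes no claim about which options `--contrastive_loss_fn` or `--energy_fn` accept.

```python
import math

def l2_energy(phi, psi):
    """Energy = negative Euclidean distance between two embedding vectors."""
    return -math.sqrt(sum((a - b) ** 2 for a, b in zip(phi, psi)))

def infonce_loss(sa_embeddings, goal_embeddings, energy_fn=l2_energy):
    """Mean cross-entropy where row i's positive is goal i; other goals are negatives."""
    losses = []
    for i, phi in enumerate(sa_embeddings):
        logits = [energy_fn(phi, psi) for psi in goal_embeddings]
        # Numerically stable log-sum-exp over the row.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        losses.append(log_z - logits[i])   # equals -log softmax(logits)[i]
    return sum(losses) / len(losses)

# Aligned pairs should give a lower loss than mismatched pairs.
sa = [[0.0, 0.0], [3.0, 0.0]]
aligned = infonce_loss(sa, [[0.0, 0.1], [3.0, 0.1]])
shuffled = infonce_loss(sa, [[3.0, 0.1], [0.0, 0.1]])
print(aligned < shuffled)
```

Swapping `energy_fn` for a dot product or cosine similarity changes the geometry of the learned representation without touching the loss itself, which is why the two are typically configured separately.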

Contributing 🏗️

Help us build JaxGCRL into the best possible tool for the GCRL community. Reach out and start contributing or just add an Issue/PR!

  • Add Franka robot arm environments. [Done by SimpleGeometry]
  • Get around 70% success rate on Ant Big Maze task. [Done by RajGhugare19]
  • Add more complex versions of Ant Sokoban.
  • Integrate environments:
    • Overcooked
    • Hanabi
    • Rubik's cube
    • Sokoban

To run tests (make sure you have access to a GPU):

python -m pytest

Citing JaxGCRL 📜

If you use JaxGCRL in your work, please cite us as follows:

@inproceedings{bortkiewicz2025accelerating,
    author    = {Bortkiewicz, Micha\l{} and Pa\l{}ucki, W\l{}adek and Myers, Vivek and Dziarmaga, Tadeusz and Arczewski, Tomasz and Kuci\'{n}ski, \L{}ukasz and Eysenbach, Benjamin},
    booktitle = {{International Conference} on {Learning Representations}},
    title     = {{Accelerating Goal-Conditioned RL Algorithms} and {Research}},
    url       = {https://arxiv.org/pdf/2408.11052},
    year      = {2025},
}

Questions ❓

If you have any questions, comments, or suggestions, please reach out to Michał Bortkiewicz (michalbortkiewicz8@gmail.com).

See Also 🙌

A number of other libraries inspired this work; we encourage you to take a look!

JAX-native algorithms:

  • Mava: JAX implementations of IPPO and MAPPO, two popular MARL algorithms.
  • PureJaxRL: JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.
  • Minimax: JAX implementations of autocurricula baselines for RL.
  • JaxIRL: JAX implementation of algorithms for inverse reinforcement learning.

JAX-native environments:

  • Gymnax: Implementations of classic RL tasks including classic control, bsuite and MinAtar.
  • Jumanji: A diverse set of environments ranging from simple games to NP-hard combinatorial problems.
  • Pgx: JAX implementations of classic board games, such as Chess, Go and Shogi.
  • Brax: A fully differentiable physics engine written in JAX, features continuous control tasks.
  • XLand-MiniGrid: Meta-RL gridworld environments inspired by XLand and MiniGrid.
  • Craftax: (Crafter + NetHack) in JAX.
  • JaxMARL: Multi-agent RL in Jax.

About

Online Goal-Conditioned Reinforcement Learning in JAX. ICLR 2025 Spotlight.