
jax-mps

Badges: GitHub Actions · PyPI · JAX tests (see footnote 1)

A JAX backend for Apple Silicon using MLX, enabling GPU-accelerated JAX computations on Mac.

Note: Our CI currently only checks that the project compiles, because GitHub's hosted runners have no access to Apple GPUs. If you have a Mac (e.g., a Mac mini) that could serve as a self-hosted GitHub Actions runner for this project, please open an issue. It would let us run the full test suite on every PR and help us move much faster.

Example

jax-mps achieves a ~3.7x speedup over the CPU backend (0.928 s vs. 3.437 s per training step) when training a simple ResNet18 model on CIFAR-10 on an M4 MacBook Air.

$ JAX_PLATFORMS=cpu uv run examples/resnet/main.py --steps=30
loss = 0.029: 100%|██████████| 30/30 [01:41<00:00,  3.37s/it]
Final training loss: 0.029
Time per step (second half): 3.437

$ JAX_PLATFORMS=mps uv run examples/resnet/main.py --steps=30
WARNING:...:jax._src.xla_bridge:905: Platform 'mps' is experimental and not all JAX functionality may be correctly supported!
loss = 0.029: 100%|██████████| 30/30 [00:27<00:00,  1.07it/s]
Final training loss: 0.029
Time per step (second half): 0.928
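The ~3.7x figure is the ratio of the two "Time per step (second half)" values printed above. The following minimal sketch recomputes it from those log lines. The helper name step_time is hypothetical, and the regex assumes the exact line format shown in the two runs.

```python
import re

# Timing lines copied verbatim from the CPU and MPS runs shown above.
cpu_log = "Time per step (second half): 3.437"
mps_log = "Time per step (second half): 0.928"


def step_time(log: str) -> float:
    """Extract the per-step time in seconds from a 'Time per step' line."""
    match = re.search(r"Time per step \(second half\): ([0-9.]+)", log)
    if match is None:
        raise ValueError("no timing line found")
    return float(match.group(1))


# Speedup = CPU seconds per step / MPS seconds per step.
speedup = step_time(cpu_log) / step_time(mps_log)
print(f"speedup: {speedup:.2f}x")  # prints "speedup: 3.70x"
```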

Installation

jax-mps requires macOS on Apple Silicon and Python 3.13. Install it with pip:

pip install jax-mps

The plugin registers itself with JAX automatically and is enabled by default. Set JAX_PLATFORMS=mps to select the MPS backend explicitly.

jax-mps is built against the StableHLO bytecode format matching jaxlib 0.10.x. Using a different jaxlib version will likely cause deserialization failures at JIT compile time. See Version Pinning for details.
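Because a jaxlib mismatch only surfaces as a deserialization failure at JIT compile time, you may want to check the installed jaxlib up front. Below is a minimal sketch of that check. It assumes "compatible" means the 0.10.x series, as stated above. The function is_supported_jaxlib is a hypothetical helper and does not ship with jax-mps.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumption: jax-mps targets the jaxlib 0.10.x series (see the note above).
SUPPORTED_SERIES = (0, 10)


def is_supported_jaxlib(ver: str) -> bool:
    """Return True if a jaxlib version string belongs to the 0.10.x series."""
    parts = ver.split(".")
    try:
        major, minor = int(parts[0]), int(parts[1])
    except (IndexError, ValueError):
        # Unparseable version strings are treated as unsupported.
        return False
    return (major, minor) == SUPPORTED_SERIES


if __name__ == "__main__":
    try:
        installed = version("jaxlib")
    except PackageNotFoundError:
        print("jaxlib is not installed")
    else:
        status = "OK" if is_supported_jaxlib(installed) else "likely incompatible"
        print(f"jaxlib {installed}: {status}")
```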

Architecture

This project implements a PJRT plugin that uses MLX to execute JAX programs on Apple Silicon GPUs. Execution proceeds in three stages:

  1. The JAX program is lowered to StableHLO, a set of high-level operations for machine learning applications.
  2. The plugin parses the StableHLO representation and maps operations to MLX equivalents. Compiled programs are cached to avoid re-parsing on repeated invocations.
  3. The MLX operations are executed on the GPU and results are returned to the caller.
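The caching in stage 2 can be illustrated with a toy Python sketch. The names translate and compile_program are hypothetical, and the "translation" is only a placeholder that collects StableHLO op lines. Keying the cache on a SHA-256 hash of the module text is one plausible design. It is not a description of the plugin's C++ implementation.

```python
import hashlib

# Cache from a hash of the StableHLO module text to its translated program.
_cache: dict[str, list[str]] = {}
# Counts how often the expensive parse/translate step actually runs.
parse_count = 0


def translate(module_text: str) -> list[str]:
    """Placeholder for parsing StableHLO and mapping ops to MLX (hypothetical)."""
    global parse_count
    parse_count += 1
    # Treat every line that mentions a stablehlo op as one operation.
    return [line.strip() for line in module_text.splitlines() if "stablehlo." in line]


def compile_program(module_text: str) -> list[str]:
    """Return the cached translation, translating only on a cache miss."""
    key = hashlib.sha256(module_text.encode("utf-8")).hexdigest()
    if key not in _cache:
        _cache[key] = translate(module_text)
    return _cache[key]


program = """
%0 = stablehlo.multiply %arg0, %arg1 : tensor<4xf32>
%1 = stablehlo.add %0, %arg1 : tensor<4xf32>
"""
compile_program(program)
compile_program(program)  # second call is served from the cache
print(parse_count)  # 1
```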

Building

  1. Install the build tools, then build and install LLVM/MLIR and StableHLO. This is a one-time setup that takes about 30 minutes. See the setup_deps.sh script for further options, such as forced re-installation and the installation location. The script pins LLVM and StableHLO to specific commits matching jaxlib 0.10.0 for bytecode compatibility (see the section on Version Pinning for details).
$ brew install cmake ninja
$ ./scripts/setup_deps.sh
  2. Build the plugin and install it as a Python package. This step is fast and MUST be repeated after every change to the C++ sources.
$ uv pip install -e .
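Before starting the 30-minute dependency setup, you can confirm the prerequisites. This minimal Python sketch makes two assumptions: cmake and ninja must be on PATH, and the host must be arm64 macOS. Neither missing_build_tools nor on_apple_silicon_mac is part of the repository's scripts. Both are hypothetical helpers.

```python
import platform
import shutil
from typing import Callable, Optional


def missing_build_tools(
    tools: tuple[str, ...] = ("cmake", "ninja"),
    which: Callable[[str], Optional[str]] = shutil.which,
) -> list[str]:
    """Return the tools from `tools` that cannot be found on PATH."""
    return [tool for tool in tools if which(tool) is None]


def on_apple_silicon_mac() -> bool:
    # macOS reports "Darwin"; native Python on Apple Silicon reports "arm64".
    return platform.system() == "Darwin" and platform.machine() == "arm64"


if __name__ == "__main__":
    missing = missing_build_tools()
    if missing:
        print("Missing tools (try: brew install cmake ninja):", ", ".join(missing))
    if not on_apple_silicon_mac():
        print("Warning: jax-mps requires macOS on Apple Silicon.")
```

The `which` parameter is injectable, so the check can be exercised without touching the real PATH.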

Version Pinning

The dependency setup script pins LLVM and StableHLO to specific commits matching jaxlib 0.10.0 for bytecode compatibility. To update these versions for a different jaxlib release, trace the dependency chain:

# 1. Find XLA commit used by jaxlib
curl -s https://raw.githubusercontent.com/jax-ml/jax/jax-v0.10.0/third_party/xla/revision.bzl
# → XLA_COMMIT = "b6f37ab7..."

# 2. Find LLVM commit used by that XLA version
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/llvm/workspace.bzl
# → LLVM_COMMIT = "815edc3f..."

# 3. Find StableHLO commit used by that XLA version
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/stablehlo/workspace.bzl
# → STABLEHLO_COMMIT = "3a8886de..."

Then update the XLA_COMMIT, LLVM_COMMIT, and STABLEHLO_COMMIT variables at the top of scripts/setup_deps_llvm.sh.
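The three curl steps above can also be scripted. This is a sketch, assuming each .bzl file assigns the commit on a single line in the form NAME = "hexdigits", as the outputs above suggest. The URL templates are copied from those commands. The function names extract_commit, fetch, and resolve_pins are hypothetical.

```python
import re
import urllib.request

# URL templates taken from the curl commands above.
JAX_REVISION_URL = (
    "https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax}/third_party/xla/revision.bzl"
)
XLA_DEP_URL = (
    "https://raw.githubusercontent.com/openxla/xla/{xla}/third_party/{dep}/workspace.bzl"
)


def extract_commit(bzl_text: str, name: str) -> str:
    """Return the hex commit assigned to `name` (e.g. XLA_COMMIT) in a .bzl file."""
    pattern = rf'^\s*{re.escape(name)}\s*=\s*"([0-9a-f]+)"'
    match = re.search(pattern, bzl_text, re.MULTILINE)
    if match is None:
        raise ValueError(f"{name} not found")
    return match.group(1)


def fetch(url: str) -> str:
    with urllib.request.urlopen(url) as response:
        return response.read().decode("utf-8")


def resolve_pins(jax_version: str) -> dict[str, str]:
    """Follow jaxlib -> XLA -> LLVM/StableHLO (requires network access)."""
    xla = extract_commit(fetch(JAX_REVISION_URL.format(jax=jax_version)), "XLA_COMMIT")
    llvm = extract_commit(fetch(XLA_DEP_URL.format(xla=xla, dep="llvm")), "LLVM_COMMIT")
    stablehlo = extract_commit(
        fetch(XLA_DEP_URL.format(xla=xla, dep="stablehlo")), "STABLEHLO_COMMIT"
    )
    return {"XLA_COMMIT": xla, "LLVM_COMMIT": llvm, "STABLEHLO_COMMIT": stablehlo}
```

For example, resolve_pins("0.10.0") would return the three values to paste into scripts/setup_deps_llvm.sh.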

Project Structure

jax-mps/
├── CMakeLists.txt
├── src/
│   ├── jax_plugins/mps/         # Python JAX plugin
│   ├── pjrt_plugin/             # C++ PJRT implementation
│   │   ├── pjrt_api.cc          # PJRT C API entry point
│   │   ├── mps_client.h/mm      # Metal client management
│   │   ├── mlx_executable.h/mm  # StableHLO compilation & MLX execution
│   │   └── ops/                 # Operation registry
│   └── proto/                   # Protobuf definitions
└── tests/

How It Works

PJRT Plugin

PJRT (Portable JAX Runtime) is JAX's abstraction for hardware backends. The plugin implements:

  • PJRT_Client_Create - Initialize Metal device
  • PJRT_Client_Compile - Parse StableHLO and build MLX operation graph
  • PJRT_Client_BufferFromHostBuffer - Transfer data to GPU
  • PJRT_LoadedExecutable_Execute - Run computation on GPU
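The order in which JAX drives these entry points can be modeled with a toy, pure-Python sketch. The classes and methods below are hypothetical stand-ins, not the PJRT C API. A Python callable plays the role of the compiled MLX graph, and a list plays the role of device memory.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyBuffer:
    data: list[float]  # stands in for a device buffer


@dataclass
class ToyExecutable:
    fn: Callable[..., list[float]]  # stands in for the compiled MLX graph

    def execute(self, *buffers: ToyBuffer) -> ToyBuffer:
        # ~ PJRT_LoadedExecutable_Execute: run the program on device buffers.
        return ToyBuffer(self.fn(*(b.data for b in buffers)))


@dataclass
class ToyClient:
    device: str = "toy-gpu"  # ~ PJRT_Client_Create selects the device

    def buffer_from_host(self, host: list[float]) -> ToyBuffer:
        # ~ PJRT_Client_BufferFromHostBuffer: copy host data to the device.
        return ToyBuffer(list(host))

    def compile(self, program: Callable[..., list[float]]) -> ToyExecutable:
        # ~ PJRT_Client_Compile: the real plugin parses StableHLO here.
        return ToyExecutable(program)


client = ToyClient()
x = client.buffer_from_host([1.0, 2.0, 3.0])
exe = client.compile(lambda a: [v * 2.0 for v in a])
print(exe.execute(x).data)  # [2.0, 4.0, 6.0]
```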

MLX Execution

StableHLO operations are mapped to MLX equivalents, e.g.:

  • stablehlo.add → mlx::core::add()
  • stablehlo.dot_general → mlx::core::matmul()
  • stablehlo.convolution → mlx::core::conv_general()
  • stablehlo.reduce → mlx::core::sum/max/min/prod()
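This kind of table-driven mapping can be sketched in Python. The dictionaries and function names are hypothetical, and Python built-ins stand in for the MLX C++ functions named in the comments. The reduce dispatch assumes the reducer body is a single binary op and ignores the init value. That mirrors the sum/max/min/prod mapping listed above but is not taken from the plugin's ops/ registry.

```python
import math
import operator

# Hypothetical registry: StableHLO elementwise op -> stand-in for the MLX call.
ELEMENTWISE = {
    "stablehlo.add": operator.add,       # ~ mlx::core::add
    "stablehlo.multiply": operator.mul,  # ~ mlx::core::multiply
    "stablehlo.maximum": max,            # ~ mlx::core::maximum
    "stablehlo.minimum": min,            # ~ mlx::core::minimum
}

# stablehlo.reduce carries a reducer body; when that body is a single binary
# op, it selects the matching whole-array reduction.
REDUCERS = {
    "stablehlo.add": sum,             # ~ mlx::core::sum
    "stablehlo.maximum": max,         # ~ mlx::core::max
    "stablehlo.minimum": min,         # ~ mlx::core::min
    "stablehlo.multiply": math.prod,  # ~ mlx::core::prod
}


def apply_elementwise(op: str, a: list[float], b: list[float]) -> list[float]:
    """Apply a registered binary op element by element."""
    if op not in ELEMENTWISE:
        raise NotImplementedError(f"unsupported op: {op}")
    fn = ELEMENTWISE[op]
    return [fn(x, y) for x, y in zip(a, b)]


def apply_reduce(body_op: str, values: list[float]) -> float:
    """Reduce `values` with the reduction selected by the reducer body op."""
    if body_op not in REDUCERS:
        raise NotImplementedError(f"unsupported reducer body: {body_op}")
    return REDUCERS[body_op](values)
```

Raising NotImplementedError for unknown ops makes unsupported StableHLO surface at compile time rather than producing wrong results.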

Footnotes

  1. Measured against JAX's upstream test suite. The following are excluded:
    • tests exercising capabilities the MPS backend fundamentally cannot provide on a single Apple Silicon device, namely float64 (not supported on Metal) and dtypes with no MLX element type at all (the sub-byte and 8-bit float types: int4/uint4, float8, and float4);
    • tests requiring a backend other than mps;
    • collective, multi-device, and sharding ops (MPS is single-device);
    • kernel-authoring suites that target CUDA/TPU (Pallas, Mosaic);
    • multiprocess tests;
    • JAX's own documentation-coverage check.
    Run the suite with uv run python scripts/run_jax_tests.py.

About

A JAX backend for Apple Metal Performance Shaders (MPS), enabling GPU-accelerated JAX computations on Apple Silicon.
Topics: apple-silicon, jax, metal-performance-shaders

Languages

C++ 54.2% · Python 41.4% · Shell 2.7% · CMake 1.4% · C 0.3%

Commit Activity

Over the past 52 weeks: 407 total commits, with a peak of 119 commits per week.
