jax-mps
A JAX backend for Apple Silicon using MLX, enabling GPU-accelerated JAX computations on Mac.
[!NOTE] Our CI currently only validates that the project compiles because GitHub's hosted runners don't have access to Apple GPUs. If you have a Mac (e.g., a Mac Mini) that could serve as a self-hosted GitHub Actions runner for this project, please open an issue — it would let us run the full test suite on every PR and help us move much faster.
Example
jax-mps achieves a ~3.7x speed-up over the CPU backend when training a simple ResNet18 model on CIFAR-10 using an M4 MacBook Air.
$ JAX_PLATFORMS=cpu uv run examples/resnet/main.py --steps=30
loss = 0.029: 100%|██████████| 30/30 [01:41<00:00, 3.37s/it]
Final training loss: 0.029
Time per step (second half): 3.437
$ JAX_PLATFORMS=mps uv run examples/resnet/main.py --steps=30
WARNING:...:jax._src.xla_bridge:905: Platform 'mps' is experimental and not all JAX functionality may be correctly supported!
loss = 0.029: 100%|██████████| 30/30 [00:27<00:00, 1.07it/s]
Final training loss: 0.029
Time per step (second half): 0.928
Installation
jax-mps requires macOS on Apple Silicon and Python 3.13. Install it with pip:
pip install jax-mps
The plugin registers itself with JAX automatically and is enabled by default. Set JAX_PLATFORMS=mps to select the MPS backend explicitly.
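The same selection can be made from inside a Python script. The snippet below is a minimal sketch, assuming the environment variable is set before JAX is first imported; the error handling is only there so the script degrades gracefully on machines without the plugin.

```python
import os

# JAX reads JAX_PLATFORMS when it initializes, so set it before importing jax.
os.environ["JAX_PLATFORMS"] = "mps"

try:
    import jax
    import jax.numpy as jnp

    print(jax.devices())       # should list the MPS device when the plugin loads
    x = jnp.arange(8.0)
    print(jnp.sum(x * 2.0))    # runs on the selected backend
except ImportError:
    print("jax is not installed in this environment")
except RuntimeError as err:
    # JAX raises RuntimeError when the requested platform cannot be initialized.
    print(f"could not initialize the mps backend: {err}")
```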
jax-mps is built against the StableHLO bytecode format matching jaxlib 0.10.x. Using a different jaxlib version will likely cause deserialization failures at JIT compile time. See Version Pinning for details.
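A quick way to catch a mismatch before it surfaces as a deserialization failure is to check the installed jaxlib version up front. This is a sketch only: the helper names `series` and `is_compatible` are hypothetical, and the required series is taken from the note above.

```python
import re
from importlib import metadata

REQUIRED_SERIES = (0, 10)  # jaxlib 0.10.x, per the note above


def series(version: str) -> tuple[int, int]:
    """Return the (major, minor) pair of a version string such as '0.10.1'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def is_compatible(version: str) -> bool:
    """True when the version belongs to the required jaxlib series."""
    return series(version) == REQUIRED_SERIES


try:
    installed = metadata.version("jaxlib")
    status = "ok" if is_compatible(installed) else "mismatch"
    print(f"jaxlib {installed}: {status}")
except metadata.PackageNotFoundError:
    print("jaxlib is not installed")
```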
Architecture
This project implements a PJRT plugin that uses MLX to execute JAX programs on Apple Silicon GPUs. The evaluation proceeds in several stages:
- The JAX program is lowered to StableHLO, a set of high-level operations for machine learning applications.
- The plugin parses the StableHLO representation and maps operations to MLX equivalents. Compiled programs are cached to avoid re-parsing on repeated invocations.
- The MLX operations are executed on the GPU and results are returned to the caller.
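To see what the first stage produces, JAX can print the StableHLO module it would hand to a backend. The sketch below uses JAX's public ahead-of-time API (`jax.jit(...).lower(...).as_text()`); the function `scaled_sum` is a made-up example, and the import guard only keeps the script usable where JAX is missing.

```python
try:
    import jax
    import jax.numpy as jnp
except ImportError:
    jax = None


def scaled_sum(x, y):
    # A tiny function whose lowering contains an add and a multiply.
    return (x + y) * 2.0


def stablehlo_text():
    """Return the StableHLO module text for scaled_sum, or None without JAX."""
    if jax is None:
        return None
    # Abstract argument: shape and dtype only, no data is allocated.
    spec = jax.ShapeDtypeStruct((4,), jnp.float32)
    lowered = jax.jit(scaled_sum).lower(spec, spec)  # trace and lower, no execution
    return lowered.as_text()                         # MLIR text in the StableHLO dialect


text = stablehlo_text()
print(text if text is not None else "jax is not installed")
```

The printed module is the kind of input the plugin parses in the second stage.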
Building
- Install build tools, then build and install LLVM/MLIR and StableHLO. This is a one-time setup and takes about 30 minutes. See the setup_deps.sh script for further options, such as forced re-installation and installation location. The script pins LLVM and StableHLO to specific commits matching jaxlib 0.10.0 for bytecode compatibility (see the section on Version Pinning for details).
$ brew install cmake ninja
$ ./scripts/setup_deps.sh
- Build the plugin and install it as a Python package. This step should be fast, and MUST be repeated for all changes to C++ files.
$ uv pip install -e .
Version Pinning
The setup script pins LLVM and StableHLO to specific commits matching jaxlib 0.10.0 for bytecode compatibility. To update these versions for a different jaxlib release, trace the dependency chain:
# 1. Find XLA commit used by jaxlib
curl -s https://raw.githubusercontent.com/jax-ml/jax/jax-v0.10.0/third_party/xla/revision.bzl
# → XLA_COMMIT = "b6f37ab7..."
# 2. Find LLVM commit used by that XLA version
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/llvm/workspace.bzl
# → LLVM_COMMIT = "815edc3f..."
# 3. Find StableHLO commit used by that XLA version
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/stablehlo/workspace.bzl
# → STABLEHLO_COMMIT = "3a8886de..."
Then update the XLA_COMMIT, LLVM_COMMIT, and STABLEHLO_COMMIT variables at the top of scripts/setup_deps_llvm.sh.
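The three lookups can also be scripted. The following is a rough sketch rather than part of the project: the helpers `extract`, `fetch`, and `trace_commits` are hypothetical, the URLs are the ones listed in the steps, and the sample value in the demo line is a placeholder, not a real commit hash.

```python
import re
import urllib.request

JAX_TAG = "jax-v0.10.0"
JAX_REVISION_URL = (
    f"https://raw.githubusercontent.com/jax-ml/jax/{JAX_TAG}/third_party/xla/revision.bzl"
)
XLA_RAW = "https://raw.githubusercontent.com/openxla/xla/{commit}/third_party/{dep}/workspace.bzl"


def extract(text: str, variable: str) -> str:
    """Return the quoted value assigned to `variable` in a .bzl file."""
    pattern = rf'^\s*{re.escape(variable)}\s*=\s*"([^"]+)"'
    match = re.search(pattern, text, re.MULTILINE)
    if match is None:
        raise ValueError(f"{variable} not found")
    return match.group(1)


def fetch(url: str) -> str:
    """Download a text file. Requires network access."""
    with urllib.request.urlopen(url, timeout=30) as response:
        return response.read().decode("utf-8")


def trace_commits() -> dict[str, str]:
    """Follow jaxlib -> XLA -> LLVM and StableHLO, as in steps 1 to 3."""
    xla = extract(fetch(JAX_REVISION_URL), "XLA_COMMIT")
    llvm = extract(fetch(XLA_RAW.format(commit=xla, dep="llvm")), "LLVM_COMMIT")
    stablehlo = extract(fetch(XLA_RAW.format(commit=xla, dep="stablehlo")), "STABLEHLO_COMMIT")
    return {"XLA_COMMIT": xla, "LLVM_COMMIT": llvm, "STABLEHLO_COMMIT": stablehlo}


# Offline demo of the parsing step with a placeholder value.
print(extract('    LLVM_COMMIT = "placeholder123"\n', "LLVM_COMMIT"))
```

Calling `trace_commits()` on a machine with network access returns the three values to copy into the script; change `JAX_TAG` to target a different jaxlib release.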
Project Structure
jax-mps/
├── CMakeLists.txt
├── src/
│ ├── jax_plugins/mps/ # Python JAX plugin
│ ├── pjrt_plugin/ # C++ PJRT implementation
│ │ ├── pjrt_api.cc # PJRT C API entry point
│ │ ├── mps_client.h/mm # Metal client management
│ │ ├── mlx_executable.h/mm # StableHLO compilation & MLX execution
│ │ └── ops/ # Operation registry
│ └── proto/ # Protobuf definitions
└── tests/
How It Works
PJRT Plugin
PJRT (short for "Pretty much Just another RunTime") is the runtime interface through which JAX talks to hardware backends. The plugin implements:
- PJRT_Client_Create - Initialize Metal device
- PJRT_Client_Compile - Parse StableHLO and build MLX operation graph
- PJRT_Client_BufferFromHostBuffer - Transfer data to GPU
- PJRT_LoadedExecutable_Execute - Run computation on GPU
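From the Python side, these entry points are never called directly; they are reached through ordinary JAX calls. The sketch below shows one plausible correspondence, assuming the usual JAX behavior (device lookup initializes the backend, `device_put` transfers a host array, ahead-of-time `compile()` invokes the backend compiler, and calling the result executes it). The function `double` is a made-up example.

```python
try:
    import jax
    import numpy as np
except ImportError:
    jax = None


def double(x):
    return x * 2.0


def run_once():
    """Walk through device lookup, transfer, compile, and execute; None without JAX."""
    if jax is None:
        return None
    device = jax.devices()[0]                      # backend client is created on first use
    host = np.arange(4, dtype=np.float32)
    on_device = jax.device_put(host, device)       # host -> device transfer
    compiled = jax.jit(double).lower(on_device).compile()  # ahead-of-time compile
    return np.asarray(compiled(on_device))         # execute, then copy the result back


result = run_once()
print(result if result is not None else "jax is not installed")
```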
MLX Execution
StableHLO operations are mapped to MLX equivalents, e.g.:
- stablehlo.add → mlx::core::add()
- stablehlo.dot_general → mlx::core::matmul()
- stablehlo.convolution → mlx::core::conv_general()
- stablehlo.reduce → mlx::core::sum/max/min/prod()
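The idea of an operation registry can be illustrated with a small name-to-handler table. This is a Python stand-in, not the plugin's C++ code: `OP_REGISTRY`, `register`, `dispatch`, and the `kind` selector are all hypothetical, and the handlers work on plain lists instead of MLX arrays.

```python
import operator
from functools import reduce
from typing import Callable

# Hypothetical registry: StableHLO op name -> handler.
OP_REGISTRY: dict[str, Callable] = {}


def register(op_name: str):
    """Decorator that records a handler under a StableHLO op name."""
    def wrap(handler: Callable) -> Callable:
        OP_REGISTRY[op_name] = handler
        return handler
    return wrap


@register("stablehlo.add")
def _add(lhs, rhs):
    # Stand-in for mlx::core::add(): elementwise add on plain lists.
    return [a + b for a, b in zip(lhs, rhs)]


@register("stablehlo.reduce")
def _reduce(values, kind="sum"):
    # The list above maps stablehlo.reduce to sum/max/min/prod;
    # `kind` is a hypothetical selector for which one applies.
    reducers = {
        "sum": sum,
        "max": max,
        "min": min,
        "prod": lambda v: reduce(operator.mul, v, 1),
    }
    return reducers[kind](values)


def dispatch(op_name: str, *args, **kwargs):
    """Look up and run the handler, failing loudly on unmapped ops."""
    if op_name not in OP_REGISTRY:
        raise NotImplementedError(f"no mapping for {op_name}")
    return OP_REGISTRY[op_name](*args, **kwargs)


print(dispatch("stablehlo.add", [1, 2], [3, 4]))
print(dispatch("stablehlo.reduce", [1, 2, 3, 4], kind="prod"))
```

Failing loudly on an unmapped op is a deliberate choice in this sketch: an unsupported operation is reported by name instead of producing a silent wrong result.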
Footnotes
- Measured against JAX's upstream test suite. Excluded are tests exercising capabilities the MPS backend fundamentally cannot provide on a single Apple-Silicon device: float64 (not supported on Metal) and dtypes with no MLX element type at all (the sub-byte/8-bit-float family: int4/uint4 and float8/float4), tests requiring a backend other than mps, and collective/multi-device/sharding ops (MPS is single-device). Kernel-authoring suites that target CUDA/TPU (Pallas, Mosaic), multiprocess tests, and JAX's own documentation-coverage check are also excluded. Run with uv run python scripts/run_jax_tests.py.