Haliax

Though you don’t seem to be much for listening, it’s best to be careful. If you managed to catch hold of even just a piece of my name, you’d have all manner of power over me.
— Patrick Rothfuss, The Name of the Wind

Haliax is a JAX library for building neural networks with named tensors, in the tradition of Alexander Rush's "Tensor Considered Harmful". Named tensors improve the legibility and compositionality of tensor programs by referring to axes by name rather than by the positional indices typical of NumPy, PyTorch, and similar libraries.
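To make the contrast concrete, here is a minimal sketch in plain NumPy. It is not Haliax's API: the helper `named_mean` and the axis names are hypothetical, and they exist only to show why looking an axis up by name is more robust than hard-coding its position.

```python
import numpy as np

# A batch of 2 sequences, 3 positions each, 4 features per position.
x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)

# Positional style: the reader must remember that axis 1 means "position".
pooled_positional = x.mean(axis=1)


# Named style (hypothetical helper, not part of Haliax or NumPy):
# carry the axis names alongside the array and look the index up by name.
def named_mean(array, axis_names, axis):
    """Average over the axis called `axis`; return the result and the remaining names."""
    index = axis_names.index(axis)
    remaining = tuple(name for name in axis_names if name != axis)
    return array.mean(axis=index), remaining


names = ("batch", "position", "feature")
pooled_named, pooled_names = named_mean(x, names, "position")

# The same call keeps working after the layout changes,
# because the axis is found by name rather than by index.
x_t = x.transpose(1, 0, 2)  # layout is now (position, batch, feature)
pooled_t, _ = named_mean(x_t, ("position", "batch", "feature"), "position")
```

In the positional version, transposing `x` would silently change what `axis=1` means; in the named version the call site does not change at all.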

Despite the focus on legibility, Haliax is also fast, typically about as fast as "pure" JAX code. Haliax is also built to be scalable: it can support Fully-Sharded Data Parallelism (FSDP) and Tensor Parallelism with just a few lines of code. Haliax powers Levanter, our companion library for training large language models and other foundation models, with scale proven up to 70B parameters and up to TPU v4-2048.

Example: Attention

Here's a minimal attention module implementation in Haliax. For a more detailed introduction, please see the Haliax tutorial. (We use the excellent Equinox library for its module system and tree transformations.)

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import haliax as hax
import haliax.nn as hnn

Pos = hax.Axis("position", 1024)  # sequence length
KPos = Pos.alias("key_position")
Head = hax.Axis("head", 8)  # number of attention heads
Key = hax.Axis("key", 64)  # key size
Embed = hax.Axis("embed", 512)  # embedding size

# alternatively:
# Pos, KPos, Head, Key, Embed = hax.make_axes(pos=1024, key_pos=1024, head=8, key=64, embed=512)


def attention_scores(Key, KPos, query, key, mask):
    # how similar is each query to each key
    scores = hax.dot(query, key, axis=Key) / jnp.sqrt(Key.size)

    if mask is not None:
        scores -= 1E9 * (1.0 - mask)

    # convert to probabilities
    scores = hnn.softmax(scores, axis=KPos)
    return scores


def attention(Key, KPos, query, key, value, mask):
    scores = attention_scores(Key, KPos, query, key, mask)
    answers = hax.dot(scores, value, axis=KPos)

    return answers


# Causal Mask means that if pos >= key_pos, then pos can attend to key_pos
causal_mask = hax.arange(Pos).broadcast_axis(KPos) >= hax.arange(KPos)


class Attention(eqx.Module):
    proj_q: hnn.Linear  # [Embed] -> [Head, Key]
    proj_k: hnn.Linear  # [Embed] -> [Head, Key]
    proj_v: hnn.Linear  # [Embed] -> [Head, Key]
    proj_answer: hnn.Linear  # output projection from [Head, Key] -> [Embed]

    @staticmethod
    def init(Embed, Head, Key, *, key):
        k_q, k_k, k_v, k_ans = jax.random.split(key, 4)
        proj_q = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_q)
        proj_k = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_k)
        proj_v = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_v)
        proj_answer = hnn.Linear.init(In=(Head, Key), Out=Embed, key=k_ans)
        return Attention(proj_q, proj_k, proj_v, proj_answer)

    def __call__(self, x, mask=None):
        q = self.proj_q(x)
        # Rename "position" to "key_position" for self attention
        k = self.proj_k(x).rename({"position": "key_position"})
        v = self.proj_v(x).rename({"position": "key_position"})

        # Note: this minimal example always applies the module-level causal mask;
        # the `mask` argument is not used.
        answers = attention(Key, KPos, q, k, v, causal_mask)

        x = self.proj_answer(answers)
        return x
```
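For comparison, the masked score computation in `attention_scores` above can be sketched with positional axes in plain NumPy. This is an illustrative single-head version under assumed shapes, not part of Haliax; the function name `causal_attention_scores` and the toy sizes are hypothetical.

```python
import numpy as np


def causal_attention_scores(query, key):
    """query: (position, key_size); key: (key_position, key_size)."""
    key_size = query.shape[-1]
    # Contract over the key-size axis (the role of axis=Key in the named version).
    scores = query @ key.T / np.sqrt(key_size)

    # mask[pos, key_pos] is True when pos >= key_pos, so pos may attend to key_pos.
    pos = np.arange(query.shape[0])[:, None]
    key_pos = np.arange(key.shape[0])[None, :]
    mask = pos >= key_pos
    scores = scores - 1e9 * (1.0 - mask)

    # Softmax over the key-position axis, which is the last axis in this layout.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))  # 4 positions, key size 8
k = rng.normal(size=(4, 8))  # 4 key positions, key size 8
weights = causal_attention_scores(q, k)
```

Every step here depends on remembering which positional axis plays which role (`key.T`, `axis=-1`, the `[:, None]` broadcasts). The named version states those roles directly through `Key` and `KPos`.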

Haliax was created by the research engineering team at Stanford's Center for Research on Foundation Models (CRFM). You can find us in the #levanter channel on the unofficial Jax LLM Discord.

Documentation

Tutorials

These are some tutorials to get you started with Haliax. They are available as Colab notebooks:

API Reference

Haliax's API documentation is available at haliax.readthedocs.io.

Contributing

We welcome contributions! Please see CONTRIBUTING.md for more information. We also have a list of good first issues to help you get started. (If those don't appeal, don't hesitate to reach out to us on Discord!)

License

Haliax is licensed under the Apache License, Version 2.0. See LICENSE for the full license text.

About

Named Tensors for Legible Deep Learning in JAX
