
🔺 Gated DeltaNet-2: Decoupling Erase and Write in Linear Attention

Official PyTorch implementation of Gated DeltaNet-2: Decoupling Erase and Write in Linear Attention.


Ali Hatamizadeh, Yejin Choi, and Jan Kautz.

🌟 Why Gated DeltaNet-2?

Linear attention compresses an unbounded KV cache into a fixed-size recurrent state. The hard part is not just what to forget, but how to edit this compressed memory without scrambling existing associations. Prior delta-rule models (Gated DeltaNet, Kimi Delta Attention) tie erasing and writing to a single scalar gate — even though they act on different axes of the state.

Gated DeltaNet-2 decouples these two roles:

  • ✂️ Channel-wise Erase Gate b_t — selects which key-side coordinates of the decayed state are read and removed
  • ✍️ Channel-wise Write Gate w_t — selects which value-side coordinates of the new content are committed
  • 🌀 Channel-wise Decay — inherited from KDA for fine-grained global forgetting
  • 🔁 Strict Generalization — recovers KDA when both gates collapse to the same scalar, and Gated DeltaNet when the decay also collapses
  • Hardware-efficient Training — fast-weight WY chunkwise algorithm with gate-aware backward, fused in Triton

📐 The Gated Delta Rule-2

Given an erase gate b_t ∈ [0,1]^{d_k}, a write gate w_t ∈ [0,1]^{d_v}, and channel-wise decay D_t = Diag(α_t) with α_t ∈ [0,1]^{d_k}, the recurrent state S_t ∈ ℝ^{d_k × d_v} evolves as:

S_t = (I − k_t (b_t ⊙ k_t)ᵀ) D_t S_{t−1}  +  k_t (w_t ⊙ v_t)ᵀ
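Substituting scalar gates makes the reduction claimed under Strict Generalization explicit. The key identities are k_t (β_t 𝟏 ⊙ k_t)ᵀ = β_t k_t k_tᵀ and β_t 𝟏 ⊙ v_t = β_t v_t. Here a_t is notation introduced only for this derivation, standing for a scalar decay:

```math
\begin{aligned}
b_t = \beta_t\mathbf{1},\; w_t = \beta_t\mathbf{1}:\quad
S_t &= \bigl(I - \beta_t k_t k_t^{\top}\bigr)\,\mathrm{Diag}(\alpha_t)\,S_{t-1} + \beta_t k_t v_t^{\top} && \text{(KDA)}\\
\text{additionally } \alpha_t = a_t\mathbf{1}:\quad
S_t &= a_t\bigl(I - \beta_t k_t k_t^{\top}\bigr)\,S_{t-1} + \beta_t k_t v_t^{\top} && \text{(Gated DeltaNet)}
\end{aligned}
```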

Compared with KDA, the right factor of the rank-one erase becomes channel-selective on the key axis, and the write term becomes channel-selective on the value axis. The two decisions no longer share a single scalar.
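The erase term can be expanded into a "read, then remove and write" form. First read the decayed state with the gated key, r_t = (b_t ⊙ k_t)ᵀ D_t S_{t−1}. The update then becomes S_t = D_t S_{t−1} + k_t (w_t ⊙ v_t − r_t)ᵀ, which avoids building the d_k × d_k erase matrix. The sketch below is a naive per-step reference under stated assumptions, not the repository's chunkwise WY Triton kernel. The function names `gdn2_step_literal` and `gdn2_step_read_write` are hypothetical, and L2-normalizing the keys is an assumption borrowed from common delta-rule practice. It checks numerically that the two forms agree over a short random sequence:

```python
import torch
import torch.nn.functional as F

def gdn2_step_literal(S, k, v, alpha, b, w):
    """S_t = (I - k (b ⊙ k)^T) Diag(alpha) S_{t-1} + k (w ⊙ v)^T, written term by term."""
    eye = torch.eye(k.shape[0], dtype=S.dtype)
    return (eye - torch.outer(k, b * k)) @ torch.diag(alpha) @ S + torch.outer(k, w * v)

def gdn2_step_read_write(S, k, v, alpha, b, w):
    """Equivalent form costing O(d_k * d_v) per step (no d_k x d_k matrix)."""
    S_dec = alpha[:, None] * S                    # D_t S_{t-1}: decay each key-side row
    read = (b * k) @ S_dec                        # gated read (b_t ⊙ k_t)^T D_t S_{t-1}, shape (d_v,)
    return S_dec + torch.outer(k, w * v - read)   # remove the read along k_t, write the gated value

torch.manual_seed(0)
d_k, d_v, T = 8, 6, 16
dt = torch.float64
S_lit = torch.zeros(d_k, d_v, dtype=dt)
S_rw = torch.zeros(d_k, d_v, dtype=dt)
max_err = 0.0
for t in range(T):
    k = F.normalize(torch.randn(d_k, dtype=dt), dim=0)  # L2-normalized key (assumption)
    v = torch.randn(d_v, dtype=dt)
    alpha = torch.rand(d_k, dtype=dt)   # channel-wise decay in [0, 1)
    b = torch.rand(d_k, dtype=dt)       # erase gate on the key axis
    w = torch.rand(d_v, dtype=dt)       # write gate on the value axis
    S_lit = gdn2_step_literal(S_lit, k, v, alpha, b, w)
    S_rw = gdn2_step_read_write(S_rw, k, v, alpha, b, w)
    max_err = max(max_err, (S_lit - S_rw).abs().max().item())

print(f"max |literal - read/write| = {max_err:.2e}")
```

The read/write form matches the wording of the erase-gate bullet. b_t decides which key-side coordinates take part in the read r_t, and w_t decides which value-side coordinates of the new content are committed.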

📊 Results

We train all models at 1.3B parameters on 100B tokens of FineWeb-Edu, matched in recurrent state size, and compare against Mamba-2, Gated DeltaNet, KDA, and Mamba-3 (SISO and MIMO).

Language Modeling and Commonsense Reasoning

Gated DeltaNet-2 achieves the best average across both recurrent-only and hybrid settings:

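| Model | Wiki ppl ↓ | LMB ppl ↓ | LMB acc ↑ | Avg. acc ↑ |
|---|---|---|---|---|
| *Recurrent* | | | | |
| Mamba-2 | 16.79 | 12.38 | 45.24 | 51.82 |
| Gated DeltaNet | 16.40 | 11.89 | 49.62 | 52.07 |
| KDA | 16.81 | 11.68 | 48.13 | 52.28 |
| Mamba-3 (MIMO) | 16.45 | 11.66 | 47.82 | 52.39 |
| Gated DeltaNet-2 | 15.90 | 11.41 | 48.09 | 53.11 |
| *Hybrid (+ SWA)* | | | | |
| Transformer | 19.22 | 13.72 | 48.32 | 50.86 |
| Gated DeltaNet | 16.00 | 10.82 | 48.71 | 52.25 |
| KDA | 16.01 | 10.66 | 49.21 | 52.68 |
| Mamba-3 (MIMO) | 15.81 | 10.92 | 49.82 | 52.72 |
| Gated DeltaNet-2 | 15.62 | 10.43 | 50.90 | 53.97 |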

Long-context Retrieval (RULER)

Gated DeltaNet-2 is strongest where memory editing matters most — particularly the interference-heavy multi-key needle-in-a-haystack settings:

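| Model | S-NIAH-2 @4K | S-NIAH-3 @2K | MK-NIAH-1 @4K |
|---|---|---|---|
| *Recurrent* | | | |
| Gated DeltaNet | 87.2 | 54.2 | 27.8 |
| KDA | 89.0 | 63.2 | 28.0 |
| Mamba-3 (MIMO) | 64.2 | 72.4 | 18.0 |
| Gated DeltaNet-2 | 93.0 | 89.8 | 37.8 |
| *Hybrid* | | | |
| Gated DeltaNet | 57.3 | 91.2 | 44.8 |
| KDA | 56.0 | 93.4 | 40.4 |
| Mamba-3 (MIMO) | 53.0 | 98.4 | 46.6 |
| Gated DeltaNet-2 | 57.9 | 99.0 | 48.0 |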

Real-world Retrieval

Across SWDE, SQuAD, FDA, TriviaQA, NQ, and DROP, Gated DeltaNet-2 achieves the highest average score in both the recurrent and hybrid settings:

| Setting | Mamba-2 | GDN | KDA | Mamba-3 (MIMO) | GDN-2 |
|---|---|---|---|---|---|
| Recurrent avg. | 26.84 | 28.09 | 28.67 | 28.35 | 29.88 |
| Hybrid avg. | 39.74 | 39.11 | 40.14 | 40.11 | 42.28 |

Throughput

Training throughput of Gated DeltaNet-2 (hybrid 1.3B, single H100) stays nearly flat as sequence length grows. The added channel-wise gates cost only a small constant overhead relative to KDA.

🔧 What's New in the Update Rule

| Method | Decay | Erase | Write |
|---|---|---|---|
| Mamba-2 | scalar | none | scalar |
| Gated DeltaNet | scalar | scalar β_t | scalar β_t |
| KDA | channel-wise | scalar β_t | scalar β_t |
| Gated DeltaNet-2 | channel-wise | channel-wise b_t | channel-wise w_t |

Ablations show that both gates contribute and that the erase gate b_t accounts for most of the gain. This is consistent with its role in selectively protecting or revising key-side associations in the recurrent state.
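Each row of the table can be read as a gate configuration of a single update, S_t = (I − k_t (erase ⊙ k_t)ᵀ) Diag(decay) S_{t−1} + k_t (write ⊙ v_t)ᵀ, where each gate is either a full vector (channel-wise) or a scalar. The sketch below is illustrative only. `general_update` and the random inputs are hypothetical and not part of the repository. Mamba-2 is modeled with the erase term switched off and a scalar write step, which simplifies its actual parameterization. The sketch checks the KDA and Gated DeltaNet rows against their standard update rules:

```python
import torch
import torch.nn.functional as F

def general_update(S, k, v, decay, erase, write):
    """(I - k (erase ⊙ k)^T) Diag(decay) S + k (write ⊙ v)^T with broadcastable gates.

    decay, erase: shape (d_k,) for channel-wise or () for scalar; write: (d_v,) or ().
    """
    eye = torch.eye(k.shape[0], dtype=S.dtype)
    S_dec = decay.reshape(-1, 1) * S   # (d_k, 1) or (1, 1): decay acts on the key axis
    return (eye - torch.outer(k, erase * k)) @ S_dec + torch.outer(k, write * v)

torch.manual_seed(0)
d_k, d_v = 4, 3
dt = torch.float64
S = torch.randn(d_k, d_v, dtype=dt)
k = F.normalize(torch.randn(d_k, dtype=dt), dim=0)
v = torch.randn(d_v, dtype=dt)
alpha_vec = torch.rand(d_k, dtype=dt)   # channel-wise decay
a = torch.rand((), dtype=dt)            # scalar decay
beta = torch.rand((), dtype=dt)         # one scalar shared by erase and write
b = torch.rand(d_k, dtype=dt)           # channel-wise erase gate
w = torch.rand(d_v, dtype=dt)           # channel-wise write gate
zero = torch.zeros((), dtype=dt)

rows = {
    "Mamba-2":          (a,         zero, beta),  # no erase term (simplified model)
    "Gated DeltaNet":   (a,         beta, beta),
    "KDA":              (alpha_vec, beta, beta),
    "Gated DeltaNet-2": (alpha_vec, b,    w),
}
out = {name: general_update(S, k, v, *gates) for name, gates in rows.items()}

# Reference updates written in their usual closed forms.
I = torch.eye(d_k, dtype=dt)
kda_ref = (I - beta * torch.outer(k, k)) @ torch.diag(alpha_vec) @ S + beta * torch.outer(k, v)
gdn_ref = a * (I - beta * torch.outer(k, k)) @ S + beta * torch.outer(k, v)
mamba2_ref = a * S + beta * torch.outer(k, v)

for name, ref in [("KDA", kda_ref), ("Gated DeltaNet", gdn_ref), ("Mamba-2", mamba2_ref)]:
    print(name, torch.allclose(out[name], ref))
```

Only the Gated DeltaNet-2 row uses independent vectors for erase and write. Collapsing them to one scalar turns the same code path into KDA, and collapsing the decay as well turns it into Gated DeltaNet.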

📢 Latest Updates

  • 05/21/2026: 🔥 Code Release: Train your own Gated DeltaNet-2 on FineWeb-Edu
  • Watch this space for more exciting updates!

🚀 Getting Started

Training Your Model

Launch your training with our streamlined command:

```bash
python ../pretrain.py \
  --train_data_dir ${TRAIN_DATA} \
  --val_data_dir ${VALIDATION_DATA} \
  --output_root ${SAVE_DIR} \
  --exp_name ${NAME} \
  --model_name ${MODEL} \
  --train_config ${CONFIG} \
  --eval_iters ${EVAL_ITERS} \
  --learning_rate ${LR} \
  --micro_batch_size ${MICRO_BATCH_SIZE}
```

💡 Pro Tip: Add --interactive_job --debug for interactive debugging sessions!

Default Recipe

We train 1.3B-parameter models on 100B tokens of FineWeb-Edu with:

  • AdamW, peak LR 4e-4, weight decay 0.1, gradient clip 1.0
  • Cosine schedule with 1B-token warmup
  • Global batch size 0.5M tokens, sequence length 4K
  • Hybrid models use a 2K sliding-window attention size
  • 16 heads, d_k = d_v = 128, matched recurrent state size against Mamba-2/3 baselines
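The recipe implies 100B / 0.5M = 200,000 optimizer steps, 2,000 of which are warmup. Assuming d_k and d_v are per-head dimensions, it also implies 16 × 128 × 128 = 262,144 state entries per recurrent layer. The sketch below spells out that arithmetic and shows one way to wire AdamW, a warmup-then-cosine schedule, and gradient clipping in plain PyTorch. It is not the repository's `pretrain.py`. The linear warmup shape, the final learning rate `MIN_LR`, and the stand-in `torch.nn.Linear` model are assumptions:

```python
import math
import torch

# Values from the recipe; MIN_LR and the linear warmup shape are assumptions.
TOTAL_TOKENS = 100e9
WARMUP_TOKENS = 1e9
TOKENS_PER_STEP = 0.5e6        # global batch of ~0.5M tokens (sequence length 4K)
PEAK_LR = 4e-4
MIN_LR = 4e-5                  # hypothetical floor, not stated in the recipe

TOTAL_STEPS = int(TOTAL_TOKENS / TOKENS_PER_STEP)    # 200,000 optimizer steps
WARMUP_STEPS = int(WARMUP_TOKENS / TOKENS_PER_STEP)  # 2,000 warmup steps
STATE_PER_LAYER = 16 * 128 * 128                     # heads x d_k x d_v (per-head dims assumed)

def lr_at(step: int) -> float:
    """Linear warmup to PEAK_LR, then cosine decay to MIN_LR (assumed shape)."""
    if step < WARMUP_STEPS:
        return PEAK_LR * (step + 1) / WARMUP_STEPS
    progress = min(1.0, (step - WARMUP_STEPS) / max(1, TOTAL_STEPS - WARMUP_STEPS))
    return MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1.0 + math.cos(math.pi * progress))

# Tiny stand-in model so the optimizer wiring runs; the real model is the 1.3B network.
model = torch.nn.Linear(16, 16)
opt = torch.optim.AdamW(model.parameters(), lr=PEAK_LR, weight_decay=0.1)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda s: lr_at(s) / PEAK_LR)

for _ in range(3):
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clip 1.0
    opt.step()
    sched.step()
    opt.zero_grad()

print(TOTAL_STEPS, WARMUP_STEPS, STATE_PER_LAYER, opt.param_groups[0]["lr"])
```

`LambdaLR` scales the optimizer's base learning rate by the returned factor, so dividing `lr_at` by `PEAK_LR` keeps the schedule expressed in absolute learning rates.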

📜 License

Copyright © 2026, NVIDIA Corporation. All rights reserved.

Licensed under the NVIDIA Source Code License-NC. See LICENSE for details.

🙏 Acknowledgements

Built on the shoulders of giants:

📖 Citation

If you find this work useful, please consider citing:

```bibtex
@article{hatamizadeh2026gated,
  title={Gated DeltaNet-2: Decoupling Erase and Write in Linear Attention},
  author={Hatamizadeh, Ali and Choi, Yejin and Kautz, Jan},
  journal={arXiv preprint arXiv:2605.22791},
  year={2026}
}
```

⭐ Support Us

If you find this work useful, please consider:

  • Starring the repository
  • Citing our paper
  • Contributing to the codebase

Join us in pushing the boundaries of linear attention! 🚀

Star History

[Stargazers roster and star history chart for @NVlabs/GatedDeltaNet-2]


Languages

Python 99.3% · Shell 0.4% · Dockerfile 0.3%

Commit Activity

6 total commits (heatmap of the past 52 weeks; peak: 4 commits per week).