
spmd_types

A type system for distributed (SPMD) tensor computations in PyTorch. This package provides two type systems:

  • Local SPMD types, which let you use Megatron-style differentiable collectives safely by tracking whether your backward gradients are pending reduction.

  • Global SPMD types, a DTensor-like abstraction for writing code that has the same semantics whether run on a single device or in a distributed fashion, but with explicit communication operations so you are never guessing when a redistribute occurs.

In both cases, SPMD types let you check that your code computes correct gradients (local SPMD) or gives equivalent results across different parallelizations (global SPMD), without running a full end-to-end distributed training job to check for loss matching.
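To make the idea of tracking pending reductions concrete, here is a minimal pure-Python sketch. It is not the spmd_types implementation: the helper names `all_reduce_type` and `mul_type` are hypothetical, and the rule set is deliberately tiny. It covers only the two types used in the quick start (`R` for replicated, `P` for partial) and assumes, conservatively, that any multiplication involving a partial operand is rejected; the real library's rules may be richer.

```python
# Toy model of per-axis SPMD types. Illustrative only; not the spmd_types API.
R = "R"  # replicated: every rank on this mesh axis holds the same value
P = "P"  # partial: the true value is the sum of the per-rank values


def all_reduce_type(t, axis):
    """Type rule for a sum all-reduce: P on `axis` becomes R."""
    if t[axis] != P:
        raise TypeError(f"all_reduce on {axis!r} expects P, got {t[axis]}")
    return {**t, axis: R}


def mul_type(a, b):
    """Type rule for elementwise multiply: R * R -> R on every axis.

    Assumption of this sketch: any operand that is still partial is rejected,
    because sum_i(a_i) * sum_i(b_i) != sum_i(a_i * b_i) in general.
    """
    if a.keys() != b.keys():
        raise TypeError("operands are typed over different mesh axes")
    for axis in a:
        if a[axis] != R or b[axis] != R:
            raise TypeError(f"mul needs R on {axis!r}, got {a[axis]} and {b[axis]}")
    return dict(a)


x = {"dp": R, "tp": P}          # replicated on dp, pending reduction on tp
y = all_reduce_type(x, "tp")    # reduction done: replicated on both axes
z = mul_type(y, y)              # fine: R * R -> R

try:
    mul_type(x, x)              # multiplying before reducing is a type error
except TypeError as err:
    print("rejected:", err)
```

The point of the sketch is that the mistake (multiplying before reducing) is caught from the types alone, without running any distributed job.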

The goal of this package is to provide a flexible type system that can typecheck realistic training code. We have used local SPMD types to typecheck a realistic pretraining codebase; global SPMD types are under active construction.

Installation

pip install spmd_types

Quick start

import torch
import torch.distributed as dist
import spmd_types as spmd
import spmd_types.checker
from torch.distributed.device_mesh import init_device_mesh

# Set up a fake process group (no GPUs needed)
dist.init_process_group(backend="fake", rank=0, world_size=8)
mesh = init_device_mesh("cpu", (2, 4), mesh_dim_names=("dp", "tp"))
dp = mesh.get_group("dp")
tp = mesh.get_group("tp")

with spmd.set_current_mesh(mesh), spmd.checker.typecheck():
    x = torch.randn(4)
    spmd.assert_type(x, {dp: spmd.R, tp: spmd.P})       # R on dp, partial on tp
    y = spmd.all_reduce(x, tp, src=spmd.P, dst=spmd.R)  # sum across tp ranks
    spmd.assert_type(y, {dp: spmd.R, tp: spmd.R})       # now replicated everywhere
    z = torch.mul(y, y)                                 # type inference: R * R -> R
    spmd.assert_type(z, {dp: spmd.R, tp: spmd.R})

dist.destroy_process_group()
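For readers unfamiliar with the `(2, 4)` mesh above, the following pure-Python sketch simulates what "R on dp, partial on tp" means numerically and what a sum all-reduce over `tp` does. It does not use PyTorch or spmd_types. The helper names are hypothetical, and it assumes the default row-major rank layout, in which rank `r` sits at `dp` index `r // 4` and `tp` index `r % 4`.

```python
# Simulate 8 ranks on a (dp=2, tp=4) mesh in a single process.
WORLD_SIZE = 8
DP_SIZE, TP_SIZE = 2, 4


def mesh_coords(rank):
    """Row-major layout (assumed): rank = dp_index * TP_SIZE + tp_index."""
    return rank // TP_SIZE, rank % TP_SIZE


def tp_group(rank):
    """Ranks that share this rank's dp index, i.e. its tp process group."""
    dp_index, _ = mesh_coords(rank)
    return [dp_index * TP_SIZE + t for t in range(TP_SIZE)]


# "R on dp, P on tp": the local value depends only on the tp index, so ranks
# that differ only in dp hold identical data, while ranks within one tp group
# each hold one summand of the true value.
local = {rank: float(mesh_coords(rank)[1] + 1) for rank in range(WORLD_SIZE)}


def all_reduce_tp(values):
    """Sum all-reduce within each tp group; every member gets the total."""
    return {rank: sum(values[peer] for peer in tp_group(rank)) for rank in values}


reduced = all_reduce_tp(local)

print(tp_group(5))   # ranks sharing a tp group with rank 5
print(reduced)       # every rank now holds the same total
```

After the reduction every rank holds the same value on both axes, which is the situation the second `assert_type` call in the quick start describes.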

Documentation

See Local SPMD types for a hands-on guide on porting Megatron-derived training frameworks, including the Megatron-to-spmd_types function mapping table and advice on Invariant vs Replicate.

See Design for the full type system specification, including local vs global SPMD modes, collective signatures with diagrams, forward-backward pairs, expert mode, cross-mesh compatibility, and partition spec redistribute.

License

BSD 3-Clause License. See CONTRIBUTING.md for how to contribute.

About

This module defines a type system for distributed training code, based on JAX's sharding in types, but adapted for the PyTorch ecosystem.
