{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "authorship_tag": "ABX9TyPclsvowuYysU2IoxV9gyoD" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "# 集合通信实践\n", "\n", "介绍:在分布式训练或推理中,集合通信(Collective Communication)是一项必备操作。理解常见集合通信的操作原理,是实践或优化分布式任务的基础。本练习将通过PyTorch的集合通信库,带大家了解常见操作的基本用法与原理。\n", "\n", "\n", "相关文章:[分布式训练/推理基础:集合通信原理与实践](https://zhuanlan.zhihu.com/p/2006011081177457311)\n", "\n", "Author: kaiyuan\n", "\n", "Email: kaiyuanxie@yeah.net" ], "metadata": { "id": "bU3jAMiS66lD" } }, { "cell_type": "markdown", "source": [ "# 1 用例说明\n", "\n", "我们采用pytorch的通信库API来进行实践,API的相关介绍参考:[ Distributed communication](https://docs.pytorch.org/docs/stable/distributed.html)\n", "\n", "建议用PyTorch的镜像运行用例,常用的镜像: nvcr.io/nvidia/pytorch:xxxx\n", "\n", "```\n", "docker pull nvcr.io/nvidia/pytorch:26.01-py3\n", "```\n", "\n", "启动示例:\n", "\n", "```\n", "docker run -itd --rm --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \\\n", "-v /data/nfs_87:/data/nfs_87 \\\n", "--name pytorch-dev nvcr.io/nvidia/pytorch:26.01-py3 bash\n", "```\n", "\n", "进入容器:\n", "\n", "```\n", "docker exec -it pytorch-dev bash\n", "```\n", "\n", "测试机器信息:\n", "- NVIDIA A100-SXM4-80GB x 8\n", "- NVIDIA-SMI 570.172.08\n", "- Driver Version: 570.172.08\n", "- CUDA Version: 13.1\n", "\n", "在示例中,通过设置 world_size 参数来控制使用的GPU数量,此处设定world_size=4。运行该用例时,可能会出现如下告警信息,这是资源释放过程中产生的问题,不影响演示效果。\n", "\n", "```\n", "[rank0]:[W214 09:22:26.727076495 ProcessGroupNCCL.cpp:1565] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())\n", "```" ], "metadata": { "id": "60Rudfjt8Tr9" } }, { "cell_type": "markdown", "source": [ "# 2 聚合(Gather)\n", "\n", "\n", "## 2.1 Row Gather" ], "metadata": { "id": "aWgMBs5A_Vf1" } }, { "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", " # 初始化进程组\n", " dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", " torch.cuda.set_device(rank)\n", "\n", " # 本地张量:形状[1, 5],数值为 rank\n", " tensor = torch.ones(1, 5, device=rank) * rank\n", "\n", " # 只在目标进程(rank 0)创建接收列表\n", " if rank == 0:\n", " gather_list = [torch.empty_like(tensor) for _ in range(world_size)]\n", " else:\n", " gather_list = None # 非目标进程不需要接收列表\n", "\n", " # 执行 gather 操作,所有进程将 tensor 发送给 rank 0\n", " dist.gather(tensor, gather_list=gather_list, dst=0)\n", "\n", " # 在rank 0上处理收集到的数据\n", " if rank == 0:\n", " # 将列表沿dim=0拼接,得到[world_size, 5]\n", " gathered_tensor = torch.cat(gather_list, dim=0)\n", " print(f\"Rank {rank} gathered tensor shape: {gathered_tensor.shape}\")\n", " print(\"Gathered tensor:\\n\", gathered_tensor)\n", " else:\n", " print(f\"Rank {rank} has sent its data, no local copy of gathered tensor.\")\n", "\n", "def main():\n", " world_size = 4\n", " mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", " main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-c8aQSoj8TIw", "outputId": "0f7eec81-0b1b-40f1-e300-f101e25d8c7b" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", 
"Rank 3 has sent its data, no local copy of gathered tensor.\n", "Rank 1 has sent its data, no local copy of gathered tensor.\n", "Rank 2 has sent its data, no local copy of gathered tensor.\n", "Rank 0 gathered tensor shape: torch.Size([4, 5])\n", "Gathered tensor:\n", " tensor([[0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1.],\n", " [2., 2., 2., 2., 2.],\n", " [3., 3., 3., 3., 3.]], device='cuda:0')\n", "\n" ] } ] }, { "cell_type": "markdown", "source": [ "## 2.2 Column Gather" ], "metadata": { "id": "cYecYOEMOrKr" } }, { "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", " # 初始化进程组,使用NCCL后端\n", " dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", " torch.cuda.set_device(rank)\n", "\n", " # 每个进程生成自己的本地张量:形状[5, 2],值 = rank*10 + 列索引\n", " cols_per_rank = 2\n", " local_tensor = torch.zeros(5, cols_per_rank, device=rank, dtype=torch.float)\n", " for col in range(cols_per_rank):\n", " local_tensor[:, col] = rank * 10 + col # 第0列全是 rank*10,第1列全是 rank*10+1\n", "\n", " print(f\"Rank {rank} local tensor (shape {local_tensor.shape}):\\n{local_tensor.cpu().numpy()}\")\n", "\n", " # 只在目标进程(rank 0)上准备接收列表\n", " if rank == 0:\n", " # 接收列表包含 world_size 个空张量,每个形状与 local_tensor 相同,位于 rank 0 的设备上\n", " gather_list = [torch.empty_like(local_tensor, device=0) for _ in range(world_size)]\n", " else:\n", " gather_list = None # 非目标进程不需要接收列表\n", "\n", " # 执行 gather 操作:所有进程将local_tensor发送给rank 0\n", " dist.gather(local_tensor, gather_list=gather_list, dst=0)\n", "\n", " # 在rank 0上处理收集到的数据\n", " if rank == 0:\n", " # 沿列维度(dim=1)拼接,得到形状[5, world_size * 2]的大张量\n", " gathered = torch.cat(gather_list, dim=1)\n", " print(f\"\\nRank {0} final gathered tensor (shape {gathered.shape}):\\n{gathered.cpu().numpy()}\")\n", " else:\n", " print(f\"Rank {rank} has sent its data, no local copy of gathered tensor.\\n\")\n", "\n", "def main():\n", " world_size = 4\n", " mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", " main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aVAePK9tOrdG", "outputId": "4293ff21-21c3-40b4-80f0-5934115ac328" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 1 local tensor (shape torch.Size([5, 2])):\n", "[[10. 11.]\n", " [10. 11.]\n", " [10. 11.]\n", " [10. 11.]\n", " [10. 11.]]\n", "Rank 3 local tensor (shape torch.Size([5, 2])):\n", "[[30. 31.]\n", " [30. 31.]\n", " [30. 31.]\n", " [30. 31.]\n", " [30. 31.]]\n", "Rank 2 local tensor (shape torch.Size([5, 2])):\n", "[[20. 21.]\n", " [20. 21.]\n", " [20. 21.]\n", " [20. 21.]\n", " [20. 21.]]\n", "Rank 0 local tensor (shape torch.Size([5, 2])):\n", "[[0. 1.]\n", " [0. 1.]\n", " [0. 1.]\n", " [0. 1.]\n", " [0. 1.]]\n", "Rank 1 has sent its data, no local copy of gathered tensor.\n", "\n", "Rank 2 has sent its data, no local copy of gathered tensor.\n", "\n", "Rank 3 has sent its data, no local copy of gathered tensor.\n", "\n", "\n", "Rank 0 final gathered tensor (shape torch.Size([5, 8])):\n", "[[ 0. 1. 10. 11. 20. 21. 30. 31.]\n", " [ 0. 1. 10. 11. 20. 21. 30. 31.]\n", " [ 0. 1. 10. 11. 20. 21. 30. 31.]\n", " [ 0. 1. 10. 11. 20. 21. 30. 31.]\n", " [ 0. 1. 10. 11. 20. 21. 30. 
{ "cell_type": "markdown", "source": [ "# 3 All Gather\n", "\n", "## 3.1 Row allgather" ], "metadata": { "id": "C879tQ-G_dM_" } },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CNnELG9X3kpI", "outputId": "00511ac5-89fa-4ab3-b870-082c4843bbce" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "rank 1 after cat: torch.Size([4, 5])\n", "rank 0 after cat: torch.Size([4, 5])\n", "rank 3 after cat: torch.Size([4, 5])\n", "rank 2 after cat: torch.Size([4, 5])\n", "Rank 1 gathered tensor:\n", "Rank 0 gathered tensor:\n", "Rank 3 gathered tensor:\n", "Rank 2 gathered tensor:\n", " tensor([[0., 0., 0., 0., 0.],\n", "        [1., 1., 1., 1., 1.],\n", "        [2., 2., 2., 2., 2.],\n", "        [3., 3., 3., 3., 3.]], device='cuda:1')\n", " tensor([[0., 0., 0., 0., 0.],\n", "        [1., 1., 1., 1., 1.],\n", "        [2., 2., 2., 2., 2.],\n", "        [3., 3., 3., 3., 3.]], device='cuda:0')\n", " tensor([[0., 0., 0., 0., 0.],\n", "        [1., 1., 1., 1., 1.],\n", "        [2., 2., 2., 2., 2.],\n", "        [3., 3., 3., 3., 3.]], device='cuda:2')\n", " tensor([[0., 0., 0., 0., 0.],\n", "        [1., 1., 1., 1., 1.],\n", "        [2., 2., 2., 2., 2.],\n", "        [3., 3., 3., 3., 3.]], device='cuda:3')\n", "\n" ] } ], "source": [ "import os\n", "import time\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # Local tensor of shape [1, 5] on the current GPU\n", "    tensor_shard = torch.ones(1, 5, device=rank) * rank\n", "\n", "    # Receive list: one [1, 5] tensor per rank\n", "    tensor_gather_list = [torch.empty_like(tensor_shard) for _ in range(world_size)]\n", "\n", "    # Run all_gather\n", "    dist.all_gather(tensor_gather_list, tensor_shard)\n", "\n", "    # Option 1: torch.cat along dim=0 gives [world_size, 5]\n", "    gathered_tensor = torch.cat(tensor_gather_list, dim=0)  # shape (world_size, 5)\n", "    print(f\"rank {rank} after cat: {gathered_tensor.shape}\")\n", "\n", "    # Option 2: torch.stack would give [world_size, 1, 5] instead\n", "    # stacked = torch.stack(tensor_gather_list, dim=0)  # (world_size, 1, 5)\n", "    # gathered_tensor = stacked.squeeze(1)  # (world_size, 5)\n", "\n", "    time.sleep(1)\n", "    # Verify: every rank now holds the full gathered tensor\n", "    print(f\"Rank {rank} gathered tensor:\\n\", gathered_tensor)\n", "\n", "def main():\n", "    world_size = 4  # assume 4 processes\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ] },
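{ "cell_type": "markdown", "source": [ "Recent PyTorch versions also provide a tensor-based variant, dist.all_gather_into_tensor, which writes all shards into one pre-allocated buffer stacked along dim 0, so no receive list or torch.cat is needed. A minimal sketch under the same setup as above:\n", "\n", "```python\n", "# One flat [world_size, 5] output buffer replaces the list of [1, 5] tensors.\n", "output = torch.empty(world_size, 5, device=rank)\n", "dist.all_gather_into_tensor(output, tensor_shard)\n", "# `output` now equals the torch.cat result from the list version.\n", "```" ], "metadata": {} },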
{ "cell_type": "markdown", "source": [ "## 3.2 Column allgather" ], "metadata": { "id": "sNoIZxHTNhY0" } },
{ "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # Each process builds a [5, 2] tensor with value rank*10 + column index (easy to eyeball)\n", "    cols_per_rank = 2\n", "    tensor = torch.zeros(5, cols_per_rank, device=rank, dtype=torch.float)\n", "    for j in range(cols_per_rank):\n", "        tensor[:, j] = rank * 10 + j\n", "    print(f\"Rank {rank} local tensor shape {tensor.shape}:\\n{tensor.cpu().numpy()}\")\n", "\n", "    # Receive list: each element has the same shape as the local tensor\n", "    gather_list = [torch.empty_like(tensor) for _ in range(world_size)]\n", "\n", "    # Run all_gather (collect into the list)\n", "    dist.all_gather(gather_list, tensor)\n", "\n", "    # Concatenate all tensors along the column dimension (dim=1)\n", "    gathered_tensor = torch.cat(gather_list, dim=1)  # shape [5, world_size * cols_per_rank] = [5, 8]\n", "    print(f\"Rank {rank} after column-wise all_gather, shape {gathered_tensor.shape}:\\n{gathered_tensor.cpu().numpy()}\")\n", "\n", "def main():\n", "    world_size = 4\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "oGSFN_neNkq2", "outputId": "2975b720-f79c-4dc7-abe3-2f855b12fdf9" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 1 local tensor shape torch.Size([5, 2]):\n", "[[10. 11.]\n", " [10. 11.]\n", " [10. 11.]\n", " [10. 11.]\n", " [10. 11.]]\n", "Rank 3 local tensor shape torch.Size([5, 2]):\n", "[[30. 31.]\n", " [30. 31.]\n", " [30. 31.]\n", " [30. 31.]\n", " [30. 31.]]\n", "Rank 2 local tensor shape torch.Size([5, 2]):\n", "[[20. 21.]\n", " [20. 21.]\n", " [20. 21.]\n", " [20. 21.]\n", " [20. 21.]]\n", "Rank 0 local tensor shape torch.Size([5, 2]):\n", "[[0. 1.]\n", " [0. 1.]\n", " [0. 1.]\n", " [0. 1.]\n", " [0. 1.]]\n", "Rank 3 after column-wise all_gather, shape torch.Size([5, 8]):\n", "[[ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]]\n", "Rank 0 after column-wise all_gather, shape torch.Size([5, 8]):\n", "[[ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]]\n", "Rank 1 after column-wise all_gather, shape torch.Size([5, 8]):\n", "[[ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]]\n", "Rank 2 after column-wise all_gather, shape torch.Size([5, 8]):\n", "[[ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]\n", " [ 0.  1. 10. 11. 20. 21. 30. 31.]]\n", "\n" ] } ] },
{ "cell_type": "markdown", "source": [ "# 4 Reduce" ], "metadata": { "id": "VzXDWep8_ng_" } },
{ "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    # Initialize the process group\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # Local tensor: shape [1, 5], filled with the rank value\n", "    tensor = torch.ones(1, 5, device=rank) * rank\n", "\n", "    print(f\"Rank {rank} before reduce: {tensor.cpu().tolist()}\")\n", "\n", "    # Run reduce: sum the tensors of all processes into rank 0's tensor\n", "    # Note: after reduce, the tensor on non-destination processes may no longer hold meaningful data\n", "    dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)\n", "\n", "    # Print the reduction result on rank 0\n", "    if rank == 0:\n", "        print(f\"Rank {rank} after reduce (sum): {tensor.cpu().tolist()}\")\n", "    else:\n", "        # Undefined on non-destination processes; printed here only for demonstration\n", "        print(f\"Rank {rank} after reduce (tensor content undefined): {tensor.cpu().tolist()}\")\n", "\n", "def main():\n", "    world_size = 4\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "M5nK9Hwg_0JB", "outputId": "4c630841-82ad-4f10-b3a2-b9eb53471f8e" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 2 before reduce: [[2.0, 2.0, 2.0, 2.0, 2.0]]\n", "Rank 3 before reduce: [[3.0, 3.0, 3.0, 3.0, 3.0]]\n", "Rank 1 before reduce: [[1.0, 1.0, 1.0, 1.0, 1.0]]\n", "Rank 0 before reduce: [[0.0, 0.0, 0.0, 0.0, 0.0]]\n", "Rank 1 after reduce (tensor content undefined): [[1.0, 1.0, 1.0, 1.0, 1.0]]\n", "Rank 2 after reduce (tensor content undefined): [[2.0, 2.0, 2.0, 2.0, 2.0]]\n", "Rank 0 after reduce (sum): [[6.0, 6.0, 6.0, 6.0, 6.0]]\n", "Rank 3 after reduce (tensor content undefined): [[3.0, 3.0, 3.0, 3.0, 3.0]]\n", "\n" ] } ] },
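{ "cell_type": "markdown", "source": [ "SUM is just one of the supported reduction operators. A minimal sketch of swapping in another one, assuming the same setup as above:\n", "\n", "```python\n", "# Element-wise maximum instead of sum: with ranks holding 0/1/2/3,\n", "# rank 0 ends up with a tensor of all 3s.\n", "dist.reduce(tensor, dst=0, op=dist.ReduceOp.MAX)\n", "# dist.ReduceOp also provides MIN, PROD and, with NCCL, AVG.\n", "```" ], "metadata": {} },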
{ "cell_type": "markdown", "source": [ "# 5 All Reduce" ], "metadata": { "id": "6DbKnHxl_0bv" } },
{ "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    # Initialize the process group\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # Local tensor: shape [1, 5], filled with the rank value\n", "    tensor = torch.ones(1, 5, device=rank) * rank\n", "\n", "    print(f\"Rank {rank} before all_reduce: {tensor.cpu().tolist()}\")\n", "\n", "    # Run all_reduce: sum and distribute the result to every process\n", "    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)\n", "\n", "    print(f\"Rank {rank} after all_reduce (sum): {tensor.cpu().tolist()}\")\n", "\n", "def main():\n", "    world_size = 4\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "pEmiEug2_504", "outputId": "2ce1e53b-87b5-4bc0-fa3c-d1a4c20e22c1" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 2 before all_reduce: [[2.0, 2.0, 2.0, 2.0, 2.0]]\n", "Rank 3 before all_reduce: [[3.0, 3.0, 3.0, 3.0, 3.0]]\n", "Rank 1 before all_reduce: [[1.0, 1.0, 1.0, 1.0, 1.0]]\n", "Rank 0 before all_reduce: [[0.0, 0.0, 0.0, 0.0, 0.0]]\n", "Rank 1 after all_reduce (sum): [[6.0, 6.0, 6.0, 6.0, 6.0]]\n", "Rank 2 after all_reduce (sum): [[6.0, 6.0, 6.0, 6.0, 6.0]]\n", "Rank 3 after all_reduce (sum): [[6.0, 6.0, 6.0, 6.0, 6.0]]\n", "Rank 0 after all_reduce (sum): [[6.0, 6.0, 6.0, 6.0, 6.0]]\n", "\n" ] } ] },
{ "cell_type": "code", "source": [ "# Observe the result with an extra dimension:\n", "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    # Initialize the process group\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # Local tensor: shape [2, 5], all elements equal to rank\n", "    tensor = torch.full((2, 5), rank, dtype=torch.float, device=rank)\n", "\n", "    print(f\"Rank {rank} before all_reduce:\\n{tensor.cpu().numpy()}\")\n", "\n", "    # Run all_reduce with sum\n", "    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)\n", "\n", "    print(f\"Rank {rank} after all_reduce (sum):\\n{tensor.cpu().numpy()}\")\n", "\n", "def main():\n", "    world_size = 4\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "BF8gv_VrFs1V", "outputId": "901cebeb-3be3-45f8-cbc5-7e179b942c7b" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 2 before all_reduce:\n", "[[2. 2. 2. 2. 2.]\n", " [2. 2. 2. 2. 2.]]\n", "Rank 1 before all_reduce:\n", "[[1. 1. 1. 1. 1.]\n", " [1. 1. 1. 1. 1.]]\n", "Rank 3 before all_reduce:\n", "[[3. 3. 3. 3. 3.]\n", " [3. 3. 3. 3. 3.]]\n", "Rank 0 before all_reduce:\n", "[[0. 0. 0. 0. 0.]\n", " [0. 0. 0. 0. 0.]]\n", "Rank 0 after all_reduce (sum):\n", "[[6. 6. 6. 6. 6.]\n", " [6. 6. 6. 6. 6.]]\n", "Rank 2 after all_reduce (sum):\n", "[[6. 6. 6. 6. 6.]\n", " [6. 6. 6. 6. 6.]]\n", "Rank 1 after all_reduce (sum):\n", "[[6. 6. 6. 6. 6.]\n", " [6. 6. 6. 6. 6.]]\n", "Rank 3 after all_reduce (sum):\n", "[[6. 6. 6. 6. 6.]\n", " [6. 6. 6. 6. 6.]]\n", "\n" ] } ] },
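{ "cell_type": "markdown", "source": [ "The classic application of all_reduce is gradient averaging in data parallelism: after backward(), every replica sums the gradients across ranks and divides by world_size, so all replicas take the same optimizer step. A minimal hand-rolled sketch (DDP fuses and overlaps this internally; `model` is assumed to exist on each rank):\n", "\n", "```python\n", "# Average gradients across all data-parallel ranks after backward().\n", "for param in model.parameters():\n", "    if param.grad is not None:\n", "        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)\n", "        param.grad /= world_size\n", "```" ], "metadata": {} },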
{ "cell_type": "markdown", "source": [ "# 6 Scatter" ], "metadata": { "id": "lF8zcIzo_6GF" } },
{ "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    # Initialize the process group\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # Receive buffer: each process gets one row, shape [5]\n", "    recv_tensor = torch.empty(5, device=rank, dtype=torch.float)\n", "\n", "    if rank == 0:\n", "        # Source process: build a [world_size, 5] tensor whose i-th row is filled with i\n", "        data_to_scatter = torch.arange(world_size, device=0).float().unsqueeze(1).repeat(1, 5)\n", "        print(f\"Rank {rank}: Original data to scatter:\\n{data_to_scatter.cpu().numpy()}\")\n", "\n", "        # Split the big tensor into a list; each element is one row of shape [5]\n", "        scatter_list = [data_to_scatter[i] for i in range(world_size)]  # each element has shape [5]\n", "        print(f\"Rank {rank}: Scatter list shapes: {[t.shape for t in scatter_list]}\")\n", "    else:\n", "        scatter_list = None\n", "\n", "    # Run scatter\n", "    dist.scatter(recv_tensor, scatter_list=scatter_list, src=0)\n", "\n", "    print(f\"Rank {rank} received tensor of shape {recv_tensor.shape}:\\n{recv_tensor.cpu().numpy()}\")\n", "\n", "def main():\n", "    world_size = 4\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "m2vL7569_6O-", "outputId": "547f5a42-c575-49fc-bad4-2dd8cb1f71e0" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 0: Original data to scatter:\n", "[[0. 0. 0. 0. 0.]\n", " [1. 1. 1. 1. 1.]\n", " [2. 2. 2. 2. 2.]\n", " [3. 3. 3. 3. 3.]]\n", "Rank 0: Scatter list shapes: [torch.Size([5]), torch.Size([5]), torch.Size([5]), torch.Size([5])]\n", "Rank 0 received tensor of shape torch.Size([5]):\n", "[0. 0. 0. 0. 0.]\n", "Rank 1 received tensor of shape torch.Size([5]):\n", "[1. 1. 1. 1. 1.]\n", "Rank 3 received tensor of shape torch.Size([5]):\n", "[3. 3. 3. 3. 3.]\n", "Rank 2 received tensor of shape torch.Size([5]):\n", "[2. 2. 2. 2. 2.]\n", "\n" ] } ] },
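{ "cell_type": "markdown", "source": [ "For small non-tensor payloads (configs, metadata), PyTorch also provides object collectives that pickle Python objects under the hood. A minimal sketch with dist.scatter_object_list, assuming the same setup as above (with NCCL the current CUDA device must already be set; this path is for small objects, not bulk data):\n", "\n", "```python\n", "# Each rank receives exactly one picklable object into a 1-element list.\n", "if rank == 0:\n", "    objects = [{\"rank\": i, \"msg\": f\"hello {i}\"} for i in range(world_size)]\n", "else:\n", "    objects = None\n", "received = [None]\n", "dist.scatter_object_list(received, objects, src=0)\n", "print(f\"Rank {rank} got {received[0]}\")\n", "```" ], "metadata": {} },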
{ "cell_type": "markdown", "source": [ "# 7 Reduce Scatter\n", "\n", "The input has shape [world_size, 5] and is split along dim 0; each rank ends up with a [1, 5] slice.\n", "\n", "\n", "## 7.1 Row reduce scatter (using the dist.reduce_scatter API)" ], "metadata": { "id": "mRuKlAymAClJ" } },
{ "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # Original data: shape [world_size, 5], row i holds rank*10 + i\n", "    data = torch.zeros(world_size, 5, device=rank, dtype=torch.float)\n", "    for i in range(world_size):\n", "        data[i] = rank * 10 + i\n", "\n", "    print(f\"Rank {rank} original data (shape {data.shape}):\\n{data.cpu().numpy()}\")\n", "\n", "    # Split data along dim 0 into a list; each element keeps the 2-D shape [1, 5] (one row, dim preserved)\n", "    input_list = [data[i].unsqueeze(0) for i in range(world_size)]  # length world_size, each element [1, 5]\n", "\n", "    # The output tensor is also 2-D [1, 5]; it receives the reduced shard that belongs to this rank\n", "    output = torch.empty(1, 5, device=rank, dtype=torch.float)\n", "\n", "    # Run reduce_scatter: the shard for this rank from every process's input_list is reduced (summed) into output\n", "    dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)\n", "\n", "    print(f\"Rank {rank} after reduce_scatter (sum), output shape {output.shape}:\\n{output.cpu().numpy()}\")\n", "\n", "def main():\n", "    world_size = 4\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DlQ3CaNnANN8", "outputId": "75ba085f-3006-4f5b-9827-58c6e1c68beb" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 2 original data (shape torch.Size([4, 5])):\n", "[[20. 20. 20. 20. 20.]\n", " [21. 21. 21. 21. 21.]\n", " [22. 22. 22. 22. 22.]\n", " [23. 23. 23. 23. 23.]]\n", "Rank 3 original data (shape torch.Size([4, 5])):\n", "[[30. 30. 30. 30. 30.]\n", " [31. 31. 31. 31. 31.]\n", " [32. 32. 32. 32. 32.]\n", " [33. 33. 33. 33. 33.]]\n", "Rank 1 original data (shape torch.Size([4, 5])):\n", "[[10. 10. 10. 10. 10.]\n", " [11. 11. 11. 11. 11.]\n", " [12. 12. 12. 12. 12.]\n", " [13. 13. 13. 13. 13.]]\n", "Rank 0 original data (shape torch.Size([4, 5])):\n", "[[0. 0. 0. 0. 0.]\n", " [1. 1. 1. 1. 1.]\n", " [2. 2. 2. 2. 2.]\n", " [3. 3. 3. 3. 3.]]\n", "Rank 0 after reduce_scatter (sum), output shape torch.Size([1, 5]):\n", "[[60. 60. 60. 60. 60.]]\n", "Rank 1 after reduce_scatter (sum), output shape torch.Size([1, 5]):\n", "[[64. 64. 64. 64. 64.]]\n", "Rank 2 after reduce_scatter (sum), output shape torch.Size([1, 5]):\n", "[[68. 68. 68. 68. 68.]]\n", "Rank 3 after reduce_scatter (sum), output shape torch.Size([1, 5]):\n", "[[72. 72. 72. 72. 72.]]\n", "\n" ] } ] },
{ "cell_type": "markdown", "source": [ "## 7.2 Row reduce scatter (using the reduce_scatter_tensor API)" ], "metadata": { "id": "yl8JO7AjK2dd" } },
{ "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # --- Build the input tensor ---\n", "    # Each process holds one big tensor of shape [world_size, 5],\n", "    # equivalent to the concatenation of the previous example's input_list\n", "    input_tensor = torch.zeros(world_size, 5, device=rank, dtype=torch.float)\n", "    for i in range(world_size):\n", "        input_tensor[i] = rank * 10 + i\n", "    print(f\"Rank {rank} input_tensor (shape {input_tensor.shape}):\\n{input_tensor.cpu().numpy()}\")\n", "\n", "    # --- Prepare the output tensor ---\n", "    # The output has shape [1, 5]: the block each process ends up with.\n", "    # Since the [world_size, 5] input is split along dim 0, each block is [1, 5].\n", "    output = torch.empty(1, 5, device=rank, dtype=torch.float)\n", "\n", "    # --- Run reduce_scatter_tensor ---\n", "    # The reduce_scatter_tensor API splits input_tensor into world_size blocks along dim=0\n", "    # and reduces (sums) the block for this rank into output\n", "    dist.reduce_scatter_tensor(output, input_tensor, op=dist.ReduceOp.SUM)\n", "\n", "    print(f\"Rank {rank} after reduce_scatter_tensor (sum), output shape {output.shape}:\\n{output.cpu().numpy()}\")\n", "\n", "def main():\n", "    world_size = 4\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gWYs-10ILD2o", "outputId": "2d485e43-7d5a-4356-9567-71ed71fea981" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 1 input_tensor (shape torch.Size([4, 5])):\n", "[[10. 10. 10. 10. 10.]\n", " [11. 11. 11. 11. 11.]\n", " [12. 12. 12. 12. 12.]\n", " [13. 13. 13. 13. 13.]]\n", "Rank 2 input_tensor (shape torch.Size([4, 5])):\n", "[[20. 20. 20. 20. 20.]\n", " [21. 21. 21. 21. 21.]\n", " [22. 22. 22. 22. 22.]\n", " [23. 23. 23. 23. 23.]]\n", "Rank 0 input_tensor (shape torch.Size([4, 5])):\n", "[[0. 0. 0. 0. 0.]\n", " [1. 1. 1. 1. 1.]\n", " [2. 2. 2. 2. 2.]\n", " [3. 3. 3. 3. 3.]]\n", "Rank 3 input_tensor (shape torch.Size([4, 5])):\n", "[[30. 30. 30. 30. 30.]\n", " [31. 31. 31. 31. 31.]\n", " [32. 32. 32. 32. 32.]\n", " [33. 33. 33. 33. 33.]]\n", "Rank 3 after reduce_scatter_tensor (sum), output shape torch.Size([1, 5]):\n", "[[72. 72. 72. 72. 72.]]\n", "Rank 0 after reduce_scatter_tensor (sum), output shape torch.Size([1, 5]):\n", "[[60. 60. 60. 60. 60.]]\n", "Rank 1 after reduce_scatter_tensor (sum), output shape torch.Size([1, 5]):\n", "[[64. 64. 64. 64. 64.]]\n", "Rank 2 after reduce_scatter_tensor (sum), output shape torch.Size([1, 5]):\n", "[[68. 68. 68. 68. 68.]]\n", "\n", "\n" ] } ] },
{ "cell_type": "markdown", "source": [ "## 7.3 Column reduce scatter" ], "metadata": { "id": "L05YEs8OLYJI" } },
{ "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # Original data: shape [5, world_size], column j holds rank*10 + j\n", "    data = torch.zeros(5, world_size, device=rank, dtype=torch.float)\n", "    for j in range(world_size):\n", "        data[:, j] = rank * 10 + j\n", "    print(f\"Rank {rank} original data (shape {data.shape}):\\n{data.cpu().numpy()}\")\n", "\n", "    # Split along dim 1 (columns): torch.split yields world_size shards of shape [5, 1]\n", "    input_list = list(torch.split(data, 1, dim=1))  # list of length world_size\n", "\n", "    # Output tensor: receives the reduced column that belongs to this rank, also [5, 1]\n", "    output = torch.empty(5, 1, device=rank, dtype=torch.float)\n", "\n", "    # Run reduce_scatter: the shard for this rank (column `rank`) from every process is summed into output\n", "    dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)\n", "\n", "    print(f\"Rank {rank} after reduce_scatter (sum), output shape {output.shape}:\\n{output.cpu().numpy()}\")\n", "\n", "def main():\n", "    world_size = 4\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "pQFnDRM-K2pS", "outputId": "2b3e7688-af9b-4ef8-849a-150c855e91d6" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 3 original data (shape torch.Size([5, 4])):\n", "[[30. 31. 32. 33.]\n", " [30. 31. 32. 33.]\n", " [30. 31. 32. 33.]\n", " [30. 31. 32. 33.]\n", " [30. 31. 32. 33.]]\n", "Rank 1 original data (shape torch.Size([5, 4])):\n", "[[10. 11. 12. 13.]\n", " [10. 11. 12. 13.]\n", " [10. 11. 12. 13.]\n", " [10. 11. 12. 13.]\n", " [10. 11. 12. 13.]]\n", "Rank 2 original data (shape torch.Size([5, 4])):\n", "[[20. 21. 22. 23.]\n", " [20. 21. 22. 23.]\n", " [20. 21. 22. 23.]\n", " [20. 21. 22. 23.]\n", " [20. 21. 22. 23.]]\n", "Rank 0 original data (shape torch.Size([5, 4])):\n", "[[0. 1. 2. 3.]\n", " [0. 1. 2. 3.]\n", " [0. 1. 2. 3.]\n", " [0. 1. 2. 3.]\n", " [0. 1. 2. 3.]]\n", "Rank 2 after reduce_scatter (sum), output shape torch.Size([5, 1]):\n", "[[68.]\n", " [68.]\n", " [68.]\n", " [68.]\n", " [68.]]\n", "Rank 1 after reduce_scatter (sum), output shape torch.Size([5, 1]):\n", "[[64.]\n", " [64.]\n", " [64.]\n", " [64.]\n", " [64.]]Rank 0 after reduce_scatter (sum), output shape torch.Size([5, 1]):\n", "[[60.]\n", " [60.]\n", " [60.]\n", " [60.]\n", " [60.]]\n", "\n", "Rank 3 after reduce_scatter (sum), output shape torch.Size([5, 1]):\n", "[[72.]\n", " [72.]\n", " [72.]\n", " [72.]\n", " [72.]]\n", "\n" ] } ] },
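{ "cell_type": "markdown", "source": [ "Conceptually, reduce_scatter is an all_reduce in which each rank keeps only its own chunk; the fused collective just avoids materializing the full reduced tensor everywhere. A minimal sketch of that equivalence, using the [world_size, 5] row layout of `data` from 7.1:\n", "\n", "```python\n", "# reduce_scatter expressed as all_reduce + local slicing (for intuition only;\n", "# the fused op moves less data and allocates less memory).\n", "summed = data.clone()\n", "dist.all_reduce(summed, op=dist.ReduceOp.SUM)  # full [world_size, 5] sum on every rank\n", "my_chunk = summed[rank:rank+1]                 # this rank keeps its [1, 5] row\n", "```" ], "metadata": {} },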
{ "cell_type": "markdown", "source": [ "# 8 All-to-All\n", "\n", "## 8.1 Row alltoall" ], "metadata": { "id": "zSo_EPKeAMB4" } },
{ "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    # Initialize the process group\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # Each block sent to one peer is a [2] tensor\n", "    element_size = 2  # number of elements per block\n", "\n", "    # --- Build the send data ---\n", "    # Send buffer shape: [world_size, element_size]\n", "    send_tensor = torch.zeros(world_size, element_size, device=rank, dtype=torch.float)\n", "    for dst in range(world_size):\n", "        # Fill the block destined for process dst with rank*10 + dst\n", "        send_tensor[dst] = rank * 10 + dst\n", "    print(f\"Rank {rank} send_tensor (before alltoall):\\n{send_tensor.cpu().numpy()}\")\n", "\n", "    # --- Prepare the receive buffer ---\n", "    # The receive buffer has the same shape [world_size, element_size]\n", "    recv_tensor = torch.empty(world_size, element_size, device=rank, dtype=torch.float)\n", "\n", "    # --- Run all-to-all ---\n", "    # all_to_all_single splits along dim 0 automatically (a single tensor is passed, not a list)\n", "    # Arguments:\n", "    #   output: receive buffer\n", "    #   input: send buffer\n", "    #   output_split_sizes / input_split_sizes may be omitted for an even split\n", "    dist.all_to_all_single(recv_tensor, send_tensor)\n", "\n", "    print(f\"\\nRank {rank} recv_tensor (after alltoall):\\n{recv_tensor.cpu().numpy()}\")\n", "\n", "def main():\n", "    world_size = 4\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4uYLySZlACuw", "outputId": "271f66c9-6c30-484e-ecc4-ea77008ccdb2" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 3 send_tensor (before alltoall):\n", "[[30. 30.]\n", " [31. 31.]\n", " [32. 32.]\n", " [33. 33.]]\n", "Rank 1 send_tensor (before alltoall):\n", "[[10. 10.]\n", " [11. 11.]\n", " [12. 12.]\n", " [13. 13.]]\n", "Rank 0 send_tensor (before alltoall):\n", "[[0. 0.]\n", " [1. 1.]\n", " [2. 2.]\n", " [3. 3.]]\n", "Rank 2 send_tensor (before alltoall):\n", "[[20. 20.]\n", " [21. 21.]\n", " [22. 22.]\n", " [23. 23.]]\n", "\n", "Rank 0 recv_tensor (after alltoall):\n", "[[ 0.  0.]\n", " [10. 10.]\n", " [20. 20.]\n", " [30. 30.]]\n", "Rank 1 recv_tensor (after alltoall):\n", "[[ 1.  1.]\n", " [11. 11.]\n", " [21. 21.]\n", " [31. 31.]]\n", "\n", "\n", "Rank 2 recv_tensor (after alltoall):\n", "[[ 2.  2.]\n", " [12. 12.]\n", " [22. 22.]\n", " [32. 32.]]\n", "\n", "Rank 3 recv_tensor (after alltoall):\n", "[[ 3.  3.]\n", " [13. 13.]\n", " [23. 23.]\n", " [33. 33.]]\n", "\n" ] } ] },
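{ "cell_type": "markdown", "source": [ "all_to_all_single also supports uneven splits through its split-size arguments. A minimal sketch (the split sizes are invented for illustration; every rank must know how much it will receive from each peer in order to size its buffer):\n", "\n", "```python\n", "# Rank r sends input_splits[d] rows to rank d. With every rank using the\n", "# same input_splits, rank r receives (r + 1) rows from each peer.\n", "input_splits = [1, 2, 3, 4]                # rows sent to ranks 0..3\n", "output_splits = [rank + 1] * world_size    # rows received from each peer\n", "send = torch.arange(sum(input_splits) * 2, dtype=torch.float, device=rank).reshape(-1, 2)\n", "recv = torch.empty(sum(output_splits), 2, device=rank)\n", "dist.all_to_all_single(recv, send,\n", "                       output_split_sizes=output_splits,\n", "                       input_split_sizes=input_splits)\n", "```" ], "metadata": {} },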
{ "cell_type": "markdown", "source": [ "## 8.2 Column alltoall" ], "metadata": { "id": "dp-WiqN8QBRU" } },
{ "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    # Initialize the process group\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # Local tensor: shape [5, world_size], column col holds rank*10 + col\n", "    local_tensor = torch.zeros(5, world_size, device=rank, dtype=torch.float)\n", "    for col in range(world_size):\n", "        local_tensor[:, col] = rank * 10 + col\n", "    print(f\"Rank {rank} local tensor (shape {local_tensor.shape}):\\n{local_tensor.cpu().numpy()}\")\n", "\n", "    # Split into world_size column blocks of shape [5, 1], making each block contiguous\n", "    # Option 1: slicing + contiguous\n", "    input_list = [local_tensor[:, col:col+1].contiguous() for col in range(world_size)]\n", "    # Option 2: torch.chunk + contiguous\n", "    # chunks = torch.chunk(local_tensor, world_size, dim=1)\n", "    # input_list = [chunk.contiguous() for chunk in chunks]\n", "\n", "    # Output list: each element [5, 1] (freshly allocated tensors are contiguous)\n", "    output_list = [torch.empty(5, 1, device=rank, dtype=torch.float) for _ in range(world_size)]\n", "\n", "    # Run all_to_all (list version)\n", "    dist.all_to_all(output_list, input_list)\n", "\n", "    # Concatenate the output list along the column dim into the final [5, world_size] tensor\n", "    result = torch.cat(output_list, dim=1)\n", "\n", "    print(f\"Rank {rank} after column-wise all_to_all (shape {result.shape}):\\n{result.cpu().numpy()}\")\n", "\n", "def main():\n", "    world_size = 4\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9tDTB-eYQX6U", "outputId": "281d10c9-9827-4d54-a276-611031439e29" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 3 local tensor (shape torch.Size([5, 4])):\n", "[[30. 31. 32. 33.]\n", " [30. 31. 32. 33.]\n", " [30. 31. 32. 33.]\n", " [30. 31. 32. 33.]\n", " [30. 31. 32. 33.]]\n", "Rank 2 local tensor (shape torch.Size([5, 4])):\n", "[[20. 21. 22. 23.]\n", " [20. 21. 22. 23.]\n", " [20. 21. 22. 23.]\n", " [20. 21. 22. 23.]\n", " [20. 21. 22. 23.]]\n", "Rank 0 local tensor (shape torch.Size([5, 4])):\n", "[[0. 1. 2. 3.]\n", " [0. 1. 2. 3.]\n", " [0. 1. 2. 3.]\n", " [0. 1. 2. 3.]\n", " [0. 1. 2. 3.]]\n", "Rank 1 local tensor (shape torch.Size([5, 4])):\n", "[[10. 11. 12. 13.]\n", " [10. 11. 12. 13.]\n", " [10. 11. 12. 13.]\n", " [10. 11. 12. 13.]\n", " [10. 11. 12. 13.]]\n", "Rank 0 after column-wise all_to_all (shape torch.Size([5, 4])):\n", "[[ 0. 10. 20. 30.]\n", " [ 0. 10. 20. 30.]\n", " [ 0. 10. 20. 30.]\n", " [ 0. 10. 20. 30.]\n", " [ 0. 10. 20. 30.]]\n", "Rank 3 after column-wise all_to_all (shape torch.Size([5, 4])):\n", "[[ 3. 13. 23. 33.]\n", " [ 3. 13. 23. 33.]\n", " [ 3. 13. 23. 33.]\n", " [ 3. 13. 23. 33.]\n", " [ 3. 13. 23. 33.]]\n", "Rank 1 after column-wise all_to_all (shape torch.Size([5, 4])):\n", "[[ 1. 11. 21. 31.]\n", " [ 1. 11. 21. 31.]\n", " [ 1. 11. 21. 31.]\n", " [ 1. 11. 21. 31.]\n", " [ 1. 11. 21. 31.]]\n", "Rank 2 after column-wise all_to_all (shape torch.Size([5, 4])):\n", "[[ 2. 12. 22. 32.]\n", " [ 2. 12. 22. 32.]\n", " [ 2. 12. 22. 32.]\n", " [ 2. 12. 22. 32.]\n", " [ 2. 12. 22. 32.]]\n", "\n" ] } ] },
{ "cell_type": "markdown", "source": [ "# 9 Broadcast" ], "metadata": { "id": "-QAQc2okyX4o" } },
{ "cell_type": "code", "source": [ "import os\n", "import torch\n", "import torch.distributed as dist\n", "import torch.multiprocessing as mp\n", "\n", "os.environ['MASTER_ADDR'] = 'localhost'\n", "os.environ['MASTER_PORT'] = '62115'\n", "\n", "def example(rank, world_size):\n", "    # Initialize the process group\n", "    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n", "    torch.cuda.set_device(rank)\n", "\n", "    # Tensor shape\n", "    shape = (5, 2)\n", "\n", "    if rank == 0:\n", "        # Source process: a tensor of all ones\n", "        tensor = torch.ones(shape, device=rank, dtype=torch.float)\n", "        print(f\"Rank {rank} before broadcast:\\n{tensor.cpu().numpy()}\")\n", "    else:\n", "        # Non-source processes: a zero tensor with the same shape as the source\n", "        tensor = torch.zeros(shape, device=rank, dtype=torch.float)\n", "        print(f\"Rank {rank} before broadcast (initial zeros):\\n{tensor.cpu().numpy()}\")\n", "\n", "    # Run broadcast: every process receives the data from rank 0\n", "    dist.broadcast(tensor, src=0)\n", "\n", "    # Print the result after broadcast\n", "    print(f\"Rank {rank} after broadcast:\\n{tensor.cpu().numpy()}\")\n", "\n", "def main():\n", "    world_size = 4\n", "    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GLnlkSURyYCt", "outputId": "7d974f46-fa18-496d-9d40-c8712731682e" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Rank 1 before broadcast (initial zeros):\n", "[[0. 0.]\n", " [0. 0.]\n", " [0. 0.]\n", " [0. 0.]\n", " [0. 0.]]\n", "Rank 2 before broadcast (initial zeros):\n", "[[0. 0.]\n", " [0. 0.]\n", " [0. 0.]\n", " [0. 0.]\n", " [0. 0.]]\n", "Rank 0 before broadcast:\n", "[[1. 1.]\n", " [1. 1.]\n", " [1. 1.]\n", " [1. 1.]\n", " [1. 1.]]\n", "Rank 3 before broadcast (initial zeros):\n", "[[0. 0.]\n", " [0. 0.]\n", " [0. 0.]\n", " [0. 0.]\n", " [0. 0.]]\n", "Rank 0 after broadcast:\n", "[[1. 1.]\n", " [1. 1.]\n", " [1. 1.]\n", " [1. 1.]\n", " [1. 1.]]\n", "Rank 1 after broadcast:\n", "[[1. 1.]\n", " [1. 1.]\n", " [1. 1.]\n", " [1. 1.]\n", " [1. 1.]]\n", "Rank 3 after broadcast:\n", "[[1. 1.]\n", " [1. 1.]\n", " [1. 1.]\n", " [1. 1.]\n", " [1. 1.]]\n", "Rank 2 after broadcast:\n", "[[1. 1.]\n", " [1. 1.]\n", " [1. 1.]\n", " [1. 1.]\n", " [1. 1.]]\n", "\n" ] } ] } ] }