{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "authorship_tag": "ABX9TyPZS2W5wUHH6xSFJCGwAPPy" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "# CUDA Graph的原理演示\n", "\n", "通过PyTorch提供的CUDA API来创建图编译与运算,通过了解其基本操作流程来理解其在推理中应用场景。\n", "\n", "相关文章:\n", "\n", "- [vLLM为什么没在prefill阶段支持cuda graph?](https://www.zhihu.com/question/7987565201/answer/2012589977544991690)\n", "\n", "\n", "Author: kaiyuan\n", "\n", "Email: kyxie@zju.edu.cn" ], "metadata": { "id": "W44GOFaqlxje" } }, { "cell_type": "markdown", "source": [ "镜像拉取:\n", "```\n", "docker pull nvcr.io/nvidia/sglang:26.01-py3\n", "```\n", "版本介绍:[Link](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/sglang?version=26.01-py3)\n", "\n", "容器创建示例:\n", "```\n", "docker run -itd --rm --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \\\n", "-v /data/nfs/kaiyuan:/data/nfs/kaiyuan \\\n", "--name sglang-dev nvcr.io/nvidia/sglang:26.01-py3 bash\n", "```\n", "\n", "本例测试机器信息:\n", "- NVIDIA A100-SXM4-80GB\n", "- NVIDIA-SMI 570.172.08\n", "- Driver Version: 570.172.08\n", "- CUDA Version: 13.1\n" ], "metadata": { "id": "VlaP3eComuQl" } }, { "cell_type": "markdown", "source": [ "## 1 图捕获与重放\n", "\n", "构建一个简单model,完成图捕获与重放。\n", "\n", "- 图捕获:每个 batch size 对应一个独立的 CUDAGraph,因为图固定了张量的形状和内存地址。\n", "- 捕获时使用固定的 static_input和static_output,它们的存储地址被记录在图内;\n", "\n", "- 重放前通过 .copy_() 将新数据写入static_input,保证内存地址不变;\n", "\n", "- 重放后结果直接出现在static_output中。" ], "metadata": { "id": "ayUtxmPynCz7" } }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import random\n", "\n", "# 检查 CUDA 是否可用\n", "assert torch.cuda.is_available(), \"CUDA不可用!本示例需要GPU\"\n", "device = torch.device('cuda')\n", "\n", "# -------------------- 定义模型 --------------------\n", "class SimpleModel(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.fc1 = nn.Linear(32, 
64)  # input features 32 -> 64\n", "        self.fc2 = nn.Linear(64, 32)  # 64 -> 32\n", "        self.relu = nn.ReLU()\n", "\n", "    def forward(self, x):\n", "        # x shape: [bs, 8, 32]\n", "        x = self.fc1(x)   # [bs, 8, 64]\n", "        x = self.relu(x)\n", "        x = self.fc2(x)   # [bs, 8, 32]\n", "        return x\n", "\n", "model = SimpleModel().to(device)\n", "model.eval()  # inference mode: disables dropout/batch-norm randomness\n", "\n", "# -------------------- Capture a graph for each batch size --------------------\n", "batch_sizes = [1, 2, 4, 8, 16]\n", "graph_pool = {}  # dict: bs -> (graph, static_input, static_output)\n", "\n", "# Warm up the CUDA context so capture does not include autotuning overhead\n", "warmup_input = torch.randn(8, 8, 32, device=device)\n", "with torch.no_grad():\n", "    for _ in range(3):\n", "        _ = model(warmup_input)\n", "torch.cuda.synchronize()\n", "\n", "for bs in batch_sizes:\n", "    # Create a fixed input placeholder (the graph records its address)\n", "    static_input = torch.randn(bs, 8, 32, device=device)\n", "\n", "    # Start capture; the output tensor allocated during capture keeps a fixed address\n", "    graph = torch.cuda.CUDAGraph()\n", "    with torch.no_grad(), torch.cuda.graph(graph):\n", "        # Every CUDA op executed in this context is recorded into the graph\n", "        static_output = model(static_input)\n", "\n", "    # Save the graph and its static tensors in the pool\n", "    graph_pool[bs] = (graph, static_input, static_output)\n", "\n", "print(f\"Graphs captured for batch sizes {batch_sizes}.\")\n", "\n", "# -------------------- Simulate repeated inference with random batch sizes --------------------\n", "num_iterations = 10\n", "for i in range(num_iterations):\n", "    # Pick a batch size at random\n", "    bs = random.choice(batch_sizes)\n", "    graph, static_input, static_output = graph_pool[bs]\n", "\n", "    # Generate fresh random input data\n", "    new_input = torch.randn(bs, 8, 32, device=device)\n", "\n", "    # Copy the new data into the graph's static input tensor (in-place, address unchanged)\n", "    static_input.copy_(new_input)\n", "\n", "    # Replay the graph\n", "    graph.replay()\n", "\n", "    # static_output now holds the result for the new input;\n", "    # compare it against an eager-mode forward pass to validate\n", "    with torch.no_grad():\n", "        expected_output = model(new_input)\n", "\n", "    # Verify the results match (within a small tolerance)\n", "    if torch.allclose(static_output, expected_output, atol=1e-5):\n", "        print(f\"iteration {i+1}: bs={bs} results match\")\n", "    else:\n", "        
print(f\"iteration {i+1}: bs={bs} results DO NOT match!\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "jb0IPPGMoFxk", "outputId": "5be94b53-11e2-4341-9faf-94ae5bed35c9" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Graphs captured for batch sizes [1, 2, 4, 8, 16].\n", "iteration 1: bs=4 results match\n", "iteration 2: bs=2 results match\n", "iteration 3: bs=4 results match\n", "iteration 4: bs=1 results match\n", "iteration 5: bs=4 results match\n", "iteration 6: bs=1 results match\n", "iteration 7: bs=2 results match\n", "iteration 8: bs=4 results match\n", "iteration 9: bs=4 results match\n", "iteration 10: bs=8 results match\n", "\n" ] } ] }, { "cell_type": "markdown", "source": [ "## 2 Partial Graph Capture\n", "\n", "This example shows how to apply a CUDA Graph to only part of a model: a frequently executed subgraph is accelerated while the rest of the model stays flexible.\n", "\n", "The model ThreePartModel contains three linear layers (fc_a, fc_b, fc_c), each followed by a ReLU except the last. For simplicity, fc_b is captured on its own and the ReLU is applied outside the graph. (The ReLU could be included in the graph as well, provided it does not change tensor addresses; in-place ops can usually be captured too, but this example keeps the activation outside for clarity.)\n", "\n", "Partial capture:\n", "\n", "- Capture covers only model.fc_b(static_input_b); the modules before and after it are excluded.\n", "\n", "- The static input static_input_b has the same shape as module A's output: [bs, 8, 64].\n", "\n", "- The static output static_output_b has the same shape as module B's output: [bs, 8, 128].\n" ], "metadata": { "id": "0x7b9reLo-qD" } }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import random\n", "\n", "# Check CUDA availability\n", "assert torch.cuda.is_available(), \"CUDA is not available\"\n", "device = torch.device('cuda')\n", "\n", "# -------------------- Define a model with three submodules --------------------\n", "class ThreePartModel(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        # Module A: linear 32 -> 64\n", "        self.fc_a = nn.Linear(32, 64)\n", "        # Module B: linear 64 -> 128\n", "        self.fc_b = nn.Linear(64, 128)\n", "        # Module C: linear 128 -> 32\n", "        self.fc_c = nn.Linear(128, 32)\n", "        self.relu = nn.ReLU()\n", "\n", "    def forward(self, x):\n", "        # x shape: [bs, seq_len=8, 32]\n", "        x = self.relu(self.fc_a(x))  # [bs, 8, 64]\n", "        x = self.relu(self.fc_b(x))  # [bs, 8, 128]\n", "        x = self.fc_c(x)             # [bs, 8, 32]\n", "        return x\n", "\n", "model = ThreePartModel().to(device)\n", "model.eval()\n", "\n", "# -------------------- Build a graph pool for module B 
--------------------\n", "batch_sizes = [1, 2, 4, 8, 16]\n", "graph_pool = {}  # bs -> (graph, static_input, static_output)\n", "\n", "# Warm up CUDA kernels so capture does not include autotuning\n", "with torch.no_grad():\n", "    dummy = torch.randn(8, 8, 32, device=device)\n", "    for _ in range(3):\n", "        _ = model(dummy)\n", "torch.cuda.synchronize()\n", "\n", "# Capture a separate graph of module B for each batch size\n", "for bs in batch_sizes:\n", "    # Static input for module B, matching module A's output shape\n", "    static_input_b = torch.randn(bs, 8, 64, device=device)\n", "\n", "    # Capture the graph\n", "    graph = torch.cuda.CUDAGraph()\n", "    with torch.no_grad(), torch.cuda.graph(graph):\n", "        # Only module B runs in this context; note there is no relu after fc_b,\n", "        # it is applied separately outside the graph\n", "        static_output_b = model.fc_b(static_input_b)  # output shape [bs, 8, 128]\n", "        # the addresses of static_input_b and static_output_b are now fixed in the graph\n", "    # Save the graph and its static tensors in the pool\n", "    graph_pool[bs] = (graph, static_input_b, static_output_b)\n", "\n", "print(f\"Graphs captured for module B (batch sizes {batch_sizes}).\")\n", "\n", "# -------------------- Simulate inference; only module B uses graph replay --------------------\n", "num_iterations = 10\n", "for i in range(num_iterations):\n", "    # Pick a batch size at random\n", "    bs = random.choice(batch_sizes)\n", "    graph, static_in_b, static_out_b = graph_pool[bs]\n", "\n", "    # Generate fresh random input data (the input to the whole model)\n", "    new_input = torch.randn(bs, 8, 32, device=device)\n", "\n", "    # ----- Run module A eagerly -----\n", "    with torch.no_grad():\n", "        out_a = model.relu(model.fc_a(new_input))  # [bs, 8, 64]\n", "\n", "    # ----- Replay the graph for module B -----\n", "    # Copy out_a into module B's static input tensor (in-place keeps the address fixed)\n", "    static_in_b.copy_(out_a)\n", "    # Replay; static_out_b now holds module B's output (the graph contains no relu)\n", "    graph.replay()\n", "    # Apply the activation to module B's output (it was not captured)\n", "    out_b = model.relu(static_out_b)  # [bs, 8, 128]\n", "\n", "    # ----- Run module C eagerly -----\n", "    with torch.no_grad():\n", "        final_out = model.fc_c(out_b)  # [bs, 8, 32]\n", "\n", "    # Verify against an eager forward pass of the full model\n", "    with torch.no_grad():\n", "        expected = model(new_input)\n", "\n", "    if torch.allclose(final_out, expected, atol=1e-5):\n", "        print(f\"iteration {i+1}: bs={bs} partial-graph replay matches the full model\")\n", "    else:\n", "        print(f\"iteration {i+1}: bs={bs} results DO NOT match!\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TiFbEOaPofYy", "outputId": "2162ea27-0b3c-4e69-b49d-e13112f03091" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Graphs captured for module B (batch sizes [1, 2, 4, 8, 16]).\n", "iteration 1: bs=2 partial-graph replay matches the full model\n", "iteration 2: bs=16 partial-graph replay matches the full model\n", "iteration 3: bs=8 partial-graph replay matches the full model\n", "iteration 4: bs=1 partial-graph replay matches the full model\n", "iteration 5: bs=8 partial-graph replay matches the full model\n", "iteration 6: bs=2 partial-graph replay matches the full model\n", "iteration 7: bs=1 partial-graph replay matches the full model\n", "iteration 8: bs=16 partial-graph replay matches the full model\n", "iteration 9: bs=4 partial-graph replay matches the full model\n", "iteration 10: bs=2 partial-graph replay matches the full model\n", "\n" ] } ] } ] }