{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "authorship_tag": "ABX9TyN3ys/a92DbgzGMLLd+GOHO" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "# Nano vLLM框架配套练习\n", "\n", "\n", "相关文章:\n", "\n", "* [推理框架极简入门:用Nano-vLLM搭建知识体系](https://zhuanlan.zhihu.com/p/2008285806222132143)\n", "\n", "* [Nano vLLM架构介绍(英文)](https://github.com/CalvinXKY/nano-vllm/blob/main/docs/structures.md)\n", "\n", "\n", "Author: kaiyuan\n", "\n", "Email: kaiyuanxie@yeah.net" ], "metadata": { "id": "cehUvCUoyZDE" } }, { "cell_type": "markdown", "source": [ "# 1 请求处理主要流程\n", "\n", "\n", "## 1.1 请求序列的编码" ], "metadata": { "id": "9701xFrMGuZz" } }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "hDguFhMxor-p" }, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen3-0.6B\", use_fast=True)\n" ] }, { "cell_type": "code", "source": [ "token_ids = tokenizer.encode(\"hello\")\n", "print(token_ids)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qtWEGA46IxWD", "outputId": "4205188f-5380-47fe-fc50-12d905e89308" }, "execution_count": 15, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[14990]\n" ] } ] }, { "cell_type": "code", "source": [ "token_ids = tokenizer.encode(\"InfraTech\")\n", "print(token_ids)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uU_zJrOnFVrt", "outputId": "c556e32d-ba9b-413c-f94b-70ad048ce263" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[19433, 956, 34097]\n" ] } ] }, { "cell_type": "code", "source": [ "token_ids = tokenizer.encode(\"Hi, I'm kaiyuan\")\n", "positions = list(range(len(token_ids)))\n", "print(token_ids)\n", "print(positions)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "G7PQ-ROXG4zF", "outputId": "3a5a3a30-efd7-4594-e2e8-18527408c5fe" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[13048, 11, 358, 2776, 595, 2143, 88, 10386]\n", "[0, 1, 2, 3, 4, 5, 6, 7]\n" ] } ] }, { "cell_type": "code", "source": [ "token_ids = tokenizer.encode(\"Do you subscribe InfraTech?\")\n", "positions = list(range(len(token_ids)))\n", "print(token_ids)\n", "print(positions)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "x1nfFeyAEqNX", "outputId": "a1b6c61d-82f6-4da3-d9ac-ec2a9553a0b8" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[5404, 498, 17963, 14921, 956, 34097, 30]\n", "[0, 1, 2, 3, 4, 5, 6]\n" ] } ] }, { "cell_type": "markdown", "source": [ "## 1.2 物理显存slot位置计算" ], "metadata": { "id": "rn7b1E4OpLv5" } }, { "cell_type": "code", "source": [ "def get_slots(block_ids, block_size, seq_len):\n", " slots = []\n", " for block_id in block_ids[:-1]:\n", " start = block_id * block_size\n", " end = start + block_size\n", " slots.extend(list(range(start, end)))\n", " # 最后一个block\n", " start = block_ids[-1] * block_size\n", " end = start + (seq_len - (len(block_ids)-1) * block_size)\n", " slots.extend(list(range(start, end)))\n", " return slots" ], "metadata": { "id": "991RT_IHmDQD" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "token_ids = tokenizer.encode(\"Hi, I'm kaiyuan\")\n", "positions = list(range(len(token_ids)))\n", "# 分配block 12、13给该请求, block size:4\n", "slots = get_slots([12, 13], 4 , len(token_ids))\n", "print(f\"token_ids {token_ids}\")\n", "print(f\"positions {positions}\")\n", "print(f\"slots {slots}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DFr1sImUocU1", "outputId": "dd105bb2-279e-4529-d494-a84932867ddd" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "token_ids [13048, 11, 358, 2776, 595, 2143, 88, 10386]\n", "positions [0, 1, 2, 3, 4, 5, 6, 7]\n", "slots [48, 49, 50, 51, 52, 53, 54, 55]\n" ] } ] }, { "cell_type": "code", "source": [ "token_ids = tokenizer.encode(\"Do you subscribe InfraTech?\")\n", "positions = list(range(len(token_ids)))\n", "# 分配block 20、25给该请求,block size:4\n", "slots = get_slots([20, 25], 4 , len(token_ids))\n", "print(f\"token_ids {token_ids}\")\n", "print(f\"positions {positions}\")\n", "print(f\"slots {slots}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "w5X0BqJZo6j0", "outputId": "add37fb6-e21c-4b49-a019-b9bf8101942b" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "token_ids [5404, 498, 17963, 14921, 956, 34097, 30]\n", "positions [0, 1, 2, 3, 4, 5, 6]\n", "slots [80, 81, 82, 83, 100, 101, 102]\n" ] } ] }, { "cell_type": "code", "source": [ "token_ids = tokenizer.encode(\"Hello, what can I do for you.\")\n", "print(f\"token_ids {token_ids}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UJIEi2sO6kAH", "outputId": "137c35c8-d7cf-4d8b-a404-2ef862545705" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "token_ids [9707, 11, 1128, 646, 358, 653, 369, 498, 13]\n" ] } ] }, { "cell_type": "markdown", "source": [ "## 1.3 Token ids解码" ], "metadata": { "id": "xZczGy5Gzkh1" } }, { "cell_type": "code", "source": [ "ans = tokenizer.decode([9707, 11, 1128, 646, 358, 653, 369, 498, 13])\n", "print(f\"Answer: {ans}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8XGF0JXb17Ep", "outputId": "363f11a3-6692-403c-f83f-477c1d186081" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Answer: Hello, what can I do for you.\n" ] } ] }, { "cell_type": "code", "source": [ "ans = tokenizer.decode([9454, 11, 315, 3308, 13])\n", "print(f\"Answer: {ans}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "s5WRis-BzkzT", "outputId": "209db8f2-3971-495c-9e53-2d9d3c6093d4" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Answer: Yes, of course.\n" ] } ] }, { "cell_type": "markdown", "source": [ "# 2 关键类与函数\n", "\n", "\n", "\n" ], "metadata": { "id": "TRRenS2u2H_0" } }, { "cell_type": "markdown", "source": [ "## 2.1 原子类与函数\n", "\n", "- config: 引擎配置文件\n", "- SamplingParams: 采样参数\n", "- Sequence:承载单个请求的相关信息" ], "metadata": { "id": "i-FSHaMc6-3n" } }, { "cell_type": "code", "source": [ "from dataclasses import dataclass\n", "from copy import copy\n", "from enum import Enum, auto\n", "from itertools import count\n", "\n", "@dataclass\n", "class Config:\n", " model: str = \"dummy\"\n", " max_num_batched_tokens: int = 16384\n", " max_num_seqs: int = 512\n", " max_model_len: int = 4096\n", " gpu_memory_utilization: float = 0.9\n", " tensor_parallel_size: int = 1\n", " enforce_eager: bool = False\n", " eos: int = -1\n", " kvcache_block_size: int = 256\n", " num_kvcache_blocks: int = -1\n", "\n", " def __post_init__(self):\n", " assert self.kvcache_block_size % 256 == 0\n", " assert 1 <= self.tensor_parallel_size <= 8\n", " assert self.max_num_batched_tokens >= self.max_model_len\n", "\n", "@dataclass\n", "class SamplingParams:\n", " temperature: float = 1.0\n", " max_tokens: int = 64\n", " ignore_eos: bool = False\n", "\n", "class SequenceStatus(Enum):\n", " WAITING = auto()\n", " RUNNING = auto()\n", " FINISHED = auto()\n", "\n", "\n", "class Sequence:\n", " block_size = 4 # 方便演示,从默认256修改4\n", " counter = count()\n", "\n", " def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):\n", " self.seq_id = next(Sequence.counter)\n", " self.status = SequenceStatus.WAITING\n", " self.token_ids = copy(token_ids)\n", " self.last_token = token_ids[-1]\n", " self.num_tokens = len(self.token_ids)\n", " self.num_prompt_tokens = len(token_ids)\n", " self.num_cached_tokens = 0\n", " self.block_table = []\n", " self.temperature = sampling_params.temperature\n", " self.max_tokens = sampling_params.max_tokens\n", " self.ignore_eos = sampling_params.ignore_eos\n", "\n", " def __len__(self):\n", " return self.num_tokens\n", "\n", " def __getitem__(self, key):\n", " return self.token_ids[key]\n", "\n", " @property\n", " def is_finished(self):\n", " return self.status == SequenceStatus.FINISHED\n", "\n", " @property\n", " def num_completion_tokens(self):\n", " return self.num_tokens - self.num_prompt_tokens\n", "\n", " @property\n", " def prompt_token_ids(self):\n", " return self.token_ids[:self.num_prompt_tokens]\n", "\n", " @property\n", " def completion_token_ids(self):\n", " return self.token_ids[self.num_prompt_tokens:]\n", "\n", " @property\n", " def num_cached_blocks(self):\n", " return self.num_cached_tokens // self.block_size\n", "\n", " @property\n", " def num_blocks(self):\n", " return (self.num_tokens + self.block_size - 1) // self.block_size\n", "\n", " @property\n", " def last_block_num_tokens(self):\n", " return self.num_tokens - (self.num_blocks - 1) * self.block_size\n", "\n", " def block(self, i):\n", " assert 0 <= i < self.num_blocks\n", " return self.token_ids[i*self.block_size: (i+1)*self.block_size]\n", "\n", " def append_token(self, token_id: int):\n", " self.token_ids.append(token_id)\n", " self.last_token = token_id\n", " self.num_tokens += 1\n", "\n", " def __getstate__(self):\n", " return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,\n", " self.token_ids if self.num_completion_tokens == 0 else self.last_token)\n", "\n", " def __setstate__(self, state):\n", " self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]\n", " if self.num_completion_tokens == 0:\n", " self.token_ids = state[-1]\n", " else:\n", " self.last_token = state[-1]\n" ], "metadata": { "id": "yc8jn4SJ7Sh7" }, "execution_count": 16, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 2.2 KV cache blocks数量计算\n", "\n", "KV cache blocks数量计算的代码位于:[nanovllm/engine/model_runner.py](https://github.com/CalvinXKY/nano-vllm/blob/main/nanovllm/engine/model_runner.py)的allocate_kv_cache()函数中,通过该函数确认blocks数量上限\n" ], "metadata": { "id": "UjET1t_z634e" } }, { "cell_type": "code", "source": [ "import torch\n", "from transformers import AutoConfig\n", "\n", "model_id = \"Qwen/Qwen3-0.6B\"\n", "config = AutoConfig.from_pretrained(model_id)\n", "# 打印出整个配置内容\n", "print(config)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "LwCdlWSS_r_i", "outputId": "e086a816-942e-46a2-859c-dce47a9e753c" }, "execution_count": 17, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Qwen3Config {\n", " \"architectures\": [\n", " \"Qwen3ForCausalLM\"\n", " ],\n", " \"attention_bias\": false,\n", " \"attention_dropout\": 0.0,\n", " \"bos_token_id\": 151643,\n", " \"dtype\": \"bfloat16\",\n", " \"eos_token_id\": 151645,\n", " \"head_dim\": 128,\n", " \"hidden_act\": \"silu\",\n", " \"hidden_size\": 1024,\n", " \"initializer_range\": 0.02,\n", " \"intermediate_size\": 3072,\n", " \"layer_types\": [\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\",\n", " \"full_attention\"\n", " ],\n", " \"max_position_embeddings\": 40960,\n", " \"max_window_layers\": 28,\n", " \"model_type\": \"qwen3\",\n", " \"num_attention_heads\": 16,\n", " \"num_hidden_layers\": 28,\n", " \"num_key_value_heads\": 8,\n", " \"pad_token_id\": null,\n", " \"rms_norm_eps\": 1e-06,\n", " \"rope_parameters\": {\n", " \"rope_theta\": 1000000,\n", " \"rope_type\": \"default\"\n", " },\n", " \"sliding_window\": null,\n", " \"tie_word_embeddings\": true,\n", " \"transformers_version\": \"5.0.0\",\n", " \"use_cache\": true,\n", " \"use_sliding_window\": false,\n", " \"vocab_size\": 151936\n", "}\n", "\n" ] } ] }, { "cell_type": "markdown", "source": [ "blocks数量计算演示:" ], "metadata": { "id": "JXRgD68K8Yq5" } }, { "cell_type": "code", "source": [ "# 数据类型,bf16 大小为2\n", "fp_16_size = torch.empty(1, dtype=torch.bfloat16).itemsize\n", "\n", "# 用户参数配置:\n", "# block_size,定义block大小。 显存使用系数:gpu_memory_utilization\n", "block_size = 4\n", "gpu_memory_utilization = 0.95\n", "\n", "# 从Qwen模型config里面获取数据\n", "num_hidden_layers = config.num_hidden_layers\n", "head_dim = config.head_dim\n", "num_kv_heads = config.num_key_value_heads\n", "\n", "# GPU当前状态\n", "total = 80 * (1024 ** 3)\n", "used = 10 * (1024 ** 3)\n", "peak = 40 * (1024 ** 3)\n", "current = 5 * (1024 ** 3)\n", "\n", "# 一个block需要的数据大小\n", "block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * fp_16_size\n", "\n", "# blocks总数量\n", "num_kvcache_blocks = int(total * gpu_memory_utilization - used - peak + current) // block_bytes\n", "\n", "# KV cache数量:\n", "kv_cache = 2 * num_hidden_layers * num_kvcache_blocks * block_size * num_kv_heads * head_dim\n", "\n", "print(f\"num_kvcache_blocks: {num_kvcache_blocks}\")\n", "print(f\"kv cache size: {kv_cache}\")\n", "print(f\"kv cache mem: {kv_cache * fp_16_size / (1024 ** 3):.4} GB\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "diHKTVn2-3at", "outputId": "ca2f437b-f372-4e60-e92c-c4ad18bd4497" }, "execution_count": 18, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "num_kvcache_blocks: 72557\n", "kv cache size: 16642834432\n", "kv cache mem: 31.0 GB\n" ] } ] }, { "cell_type": "markdown", "source": [ "## 2.3 BlockManager的运行示例\n", "\n", "### BlockManager定义" ], "metadata": { "id": "Ltoiya-28isj" } }, { "cell_type": "code", "source": [ "import xxhash\n", "from collections import deque\n", "import numpy as np\n", "\n", "class Block:\n", "\n", " def __init__(self, block_id):\n", " self.block_id = block_id\n", " self.ref_count = 0\n", " self.hash = -1\n", " self.token_ids = []\n", "\n", " def update(self, hash: int, token_ids: list[int]):\n", " self.hash = hash\n", " self.token_ids = token_ids\n", "\n", " def reset(self):\n", " self.ref_count = 1\n", " self.hash = -1\n", " self.token_ids = []\n", "\n", "\n", "class BlockManager:\n", "\n", " def __init__(self, num_blocks: int, block_size: int):\n", " self.block_size = block_size\n", " self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]\n", " self.hash_to_block_id: dict[int, int] = dict()\n", " self.free_block_ids: deque[int] = deque(range(num_blocks))\n", " self.used_block_ids: set[int] = set()\n", "\n", " @classmethod\n", " def compute_hash(cls, token_ids: list[int], prefix: int = -1):\n", " h = xxhash.xxh64()\n", " if prefix != -1:\n", " h.update(prefix.to_bytes(8, \"little\"))\n", " h.update(np.array(token_ids).tobytes())\n", " return h.intdigest()\n", "\n", " def _allocate_block(self, block_id: int) -> Block:\n", " block = self.blocks[block_id]\n", " assert block.ref_count == 0\n", " block.reset()\n", " self.free_block_ids.remove(block_id)\n", " self.used_block_ids.add(block_id)\n", " return self.blocks[block_id]\n", "\n", " def _deallocate_block(self, block_id: int) -> Block:\n", " assert self.blocks[block_id].ref_count == 0\n", " self.used_block_ids.remove(block_id)\n", " self.free_block_ids.append(block_id)\n", "\n", " def can_allocate(self, seq: Sequence) -> bool:\n", " return len(self.free_block_ids) >= seq.num_blocks\n", "\n", " def allocate(self, seq: Sequence):\n", " assert not seq.block_table\n", " h = -1\n", " cache_miss = False\n", " for i in range(seq.num_blocks):\n", " token_ids = seq.block(i)\n", " h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1\n", " block_id = self.hash_to_block_id.get(h, -1)\n", " if block_id == -1 or self.blocks[block_id].token_ids != token_ids:\n", " cache_miss = True\n", " if cache_miss:\n", " block_id = self.free_block_ids[0]\n", " block = self._allocate_block(block_id)\n", " else:\n", " seq.num_cached_tokens += self.block_size\n", " if block_id in self.used_block_ids:\n", " block = self.blocks[block_id]\n", " block.ref_count += 1\n", " else:\n", " block = self._allocate_block(block_id)\n", " if h != -1:\n", " block.update(h, token_ids)\n", " self.hash_to_block_id[h] = block_id\n", " seq.block_table.append(block_id)\n", "\n", " def deallocate(self, seq: Sequence):\n", " for block_id in reversed(seq.block_table):\n", " block = self.blocks[block_id]\n", " block.ref_count -= 1\n", " if block.ref_count == 0:\n", " self._deallocate_block(block_id)\n", " seq.num_cached_tokens = 0\n", " seq.block_table.clear()\n", "\n", " def can_append(self, seq: Sequence) -> bool:\n", " return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)\n", "\n", " def may_append(self, seq: Sequence):\n", " block_table = seq.block_table\n", " last_block = self.blocks[block_table[-1]]\n", " if len(seq) % self.block_size == 1:\n", " assert last_block.hash != -1\n", " block_id = self.free_block_ids[0]\n", " self._allocate_block(block_id)\n", " block_table.append(block_id)\n", " elif len(seq) % self.block_size == 0:\n", " assert last_block.hash == -1\n", " token_ids = seq.block(seq.num_blocks-1)\n", " prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1\n", " h = self.compute_hash(token_ids, prefix)\n", " last_block.update(h, token_ids)\n", " self.hash_to_block_id[h] = last_block.block_id\n", " else:\n", " assert last_block.hash == -1" ], "metadata": { "id": "_txkOALm8h9C" }, "execution_count": 19, "outputs": [] }, { "cell_type": "code", "source": [ "# 定义一个block_manager状态打印函数\n", "def print_block_manager_info(block_manager):\n", " print(f\"block_manager.blocks size: {len(block_manager.blocks)}\")\n", " print(f\"blocks list: {[block.block_id for block in block_manager.blocks]}\")\n", " print(f\"free_block_ids: {block_manager.free_block_ids}\")\n", " print(f\"used_block_ids: {block_manager.used_block_ids}\")" ], "metadata": { "id": "k-MdKnUOBIV3" }, "execution_count": 20, "outputs": [] }, { "cell_type": "markdown", "source": [ "### BlockManager的blocks管理逻辑演示\n", "\n", "注意:以下代码单步仅能运行一次,重复运行得从BlockManager构建开始。" ], "metadata": { "id": "CU47Lqsl9yBr" } }, { "cell_type": "code", "source": [ "# 创建BlockManager\n", "num_kvcache_blocks = 10\n", "kvcache_block_size = 4\n", "block_manager= BlockManager(num_kvcache_blocks, kvcache_block_size)\n", "print_block_manager_info(block_manager)\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wBmixZvt9yM1", "outputId": "32f6a432-3902-4c8d-b79f-41d9de891d11" }, "execution_count": 21, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "block_manager.blocks size: 10\n", "blocks list: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n", "free_block_ids: deque([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])\n", "used_block_ids: set()\n" ] } ] }, { "cell_type": "code", "source": [ "sampling_params = SamplingParams(temperature=0.6, max_tokens=256)\n", "# 增加请求\n", "token_ids = tokenizer.encode(\"hi, I'm kaiyuan\")\n", "seq_0 = Sequence(token_ids, sampling_params)\n", "\n", "token_ids = tokenizer.encode(\"Do you subscribe InfraTech?\")\n", "seq_1 = Sequence(token_ids, sampling_params)\n", "\n", "# 为请求申请block:\n", "block_manager.allocate(seq_0)\n", "block_manager.allocate(seq_1)\n", "print_block_manager_info(block_manager)\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nVv4QvBkAFEb", "outputId": "0b7c94b5-5269-4c39-d714-8b41134e0702" }, "execution_count": 22, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "block_manager.blocks size: 10\n", "blocks list: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n", "free_block_ids: deque([4, 5, 6, 7, 8, 9])\n", "used_block_ids: {0, 1, 2, 3}\n" ] } ] }, { "cell_type": "code", "source": [ "# 删除请求,能够看到释放的blocks2和3到了队尾\n", "block_manager.deallocate(seq_0)\n", "print_block_manager_info(block_manager)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "eL_CBmgAMYVI", "outputId": "fc00a17c-23f0-4638-a498-86c24c6088ff" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "block_manager.blocks size: 10\n", "blocks list: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n", "free_block_ids: deque([4, 5, 6, 7, 8, 9, 1, 0])\n", "used_block_ids: {2, 3}\n" ] } ] }, { "cell_type": "code", "source": [ "# 再增加一个请求,会使用队首的blocks:\n", "\n", "token_ids = tokenizer.encode(\"to add new blocks.\")\n", "seq_2 = Sequence(token_ids, sampling_params)\n", "block_manager.allocate(seq_2)\n", "print_block_manager_info(block_manager)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "sM9O-FWtMyy0", "outputId": "6df3d12c-c60b-4266-8bf4-b8d174b32e73" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "block_manager.blocks size: 10\n", "blocks list: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n", "free_block_ids: deque([6, 7, 8, 9, 1, 0])\n", "used_block_ids: {2, 3, 4, 5}\n" ] } ] }, { "cell_type": "markdown", "source": [ "prefix cache匹配示例" ], "metadata": { "id": "-s1e92oyOPnX" } }, { "cell_type": "code", "source": [ "# 重复字符串输入,有相同prefix数据,复用0、1 blocks数据:\n", "token_ids = tokenizer.encode(\"hi, I'm kaiyuan\")\n", "seq_3 = Sequence(token_ids, sampling_params)\n", "block_manager.allocate(seq_3)\n", "print_block_manager_info(block_manager)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "D1RFTVXtNwoZ", "outputId": "016c2fc1-b847-4002-8849-b77cfad5d8af" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "block_manager.blocks size: 10\n", "blocks list: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n", "free_block_ids: deque([6, 7, 8, 9])\n", "used_block_ids: {0, 1, 2, 3, 4, 5}\n" ] } ] }, { "cell_type": "markdown", "source": [ "## 3 GPU运算的关键内容\n", "\n" ], "metadata": { "id": "yPKLIqTC1Sc0" } }, { "cell_type": "markdown", "source": [ "## 3.1 包含KV cache的Attention计算\n", "\n", "Attention计算使用flash_attn库的flash_attn_varlen_func和flash_attn_with_kvcache函数\n", "\n", "在[flash-attention/flash_attn/flash_attn_interface.py](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py)中找到FA计算所需的数据要求:\n", "\n", "- Q/K/V数据格式:(total_q, nheads, headdim),即计算decode时Q的格式为(batch_size, seqlen, nheads, headdim)\n", "- KV cache数据格式:(num_blocks, page_block_size, nheads_k, headdim)\n", "- cu_seqlens_q/cu_seqlens_k数据格式:(batch_size + 1,),表示batch中各序列长度的累积值\n", "- max_seqlen_q/k:Q/K 在 batch 中最长的序列长度\n", "- cache_seqlens:cache 的序列长度\n", "- block_tables数据格式:(batch_size, max_num_blocks_per_seq),用于索引KV cache的block数据\n", "\n", "\n", "计算步骤:\n", "\n", "- 创建请求转换数据\n", "- 分配逻辑存储空间\n", "- 分配物理存储空间\n", "- 构建FA计算数据格式\n", "- 完成FA计算\n", "\n" ], "metadata": { "id": "DN3lUaBR5Qpi" } }, { "cell_type": "code", "source": [ "sampling_params = SamplingParams(temperature=0.6, max_tokens=256)\n", "# 增加请求\n", "token_ids = tokenizer.encode(\"hi, I'm kaiyuan\")\n", "seq_0 = Sequence(token_ids, sampling_params)\n", "\n", "token_ids = tokenizer.encode(\"Do you subscribe InfraTech?\")\n", "seq_1 = Sequence(token_ids, sampling_params)\n", "seqs = [seq_0, seq_1]\n", "\n", "num_kvcache_blocks = 10\n", "kvcache_block_size = 4\n", "block_manager= BlockManager(num_kvcache_blocks, kvcache_block_size)\n", "\n", "# 计算seq的block的最大长度,并为seq分配逻辑存储空间:\n", "max_len = 0\n", "for seq in seqs:\n", " max_len = max(len(seq.block_table), max_len)\n", " block_manager.allocate(seq)\n", "print_block_manager_info(block_manager)\n", "\n", "# 创建物理映射表:\n", "\n", "block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]\n", "\n", "print(f\"物理映射表:\\nblock_tables: {block_tables}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dV4mCyuJ1rCB", "outputId": "9cf78ec0-a5b9-4e42-86c2-5d4ccf525084" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "block_manager.blocks size: 10\n", "blocks list: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n", "free_block_ids: deque([4, 5, 6, 7, 8, 9])\n", "used_block_ids: {0, 1, 2, 3}\n", "物理映射表:\n", "block_tables: [[0, 1], [2, 3]]\n" ] } ] }, { "cell_type": "code", "source": [ "# 补充模型配置参数:\n", "num_hidden_layers = 1\n", "num_kv_heads = 8\n", "head_dim = 32\n", "\n", "# 创建KV cache的物理存储空间:\n", "kv_cache = torch.empty(2, num_hidden_layers, num_kvcache_blocks, kvcache_block_size, num_kv_heads, head_dim)\n", "print(f\"KV cache shape: {kv_cache.shape}\")\n", "\n", "# K cache与V cache数据分离:\n", "layer_id = 0 # 假设为第0层\n", "k_cache = kv_cache[0, layer_id]\n", "v_cache = kv_cache[1, layer_id]\n", "\n", "print(f\"K cache shape:{k_cache.shape}\")\n", "print(f\"V cache shape:{v_cache.shape}\")\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NaZ7087N4snp", "outputId": "724d6d60-b39f-4243-84a5-52351cd5fb65" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "KV cache shape: torch.Size([2, 1, 10, 4, 8, 32])\n", "K cache shape:torch.Size([10, 4, 8, 32])\n", "V cache shape:torch.Size([10, 4, 8, 32])\n" ] } ] }, { "cell_type": "code", "source": [ "# 构建Attention计算所需要的数据:\n", "cu_seqlens_q = [0]\n", "cu_seqlens_k = [0]\n", "max_seqlen_q = 0\n", "max_seqlen_k = 0\n", "total_seqlen = 0\n", "for seq in seqs:\n", " seqlen = len(seq)\n", " seqlen_q = seqlen - seq.num_cached_tokens\n", " seqlen_k = seqlen\n", " total_seqlen += seqlen\n", " cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)\n", " cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)\n", " max_seqlen_q = max(seqlen_q, max_seqlen_q)\n", " max_seqlen_k = max(seqlen_k, max_seqlen_k)\n", "\n", "print(f\"total_seqlen: {total_seqlen}\")\n", "print(f\"cu_seqlens_q: {cu_seqlens_q}\")\n", "print(f\"cu_seqlens_k: {cu_seqlens_k}\")\n", "print(f\"max_seqlen_q: {max_seqlen_q}\")\n", "print(f\"max_seqlen_k: {max_seqlen_k}\")\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "rpQg8gl066GL", "outputId": "80257448-233b-4628-97ef-097b974d2627" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "total_seqlen: 15\n", "cu_seqlens_q: [0, 8, 15]\n", "cu_seqlens_k: [0, 8, 15]\n", "max_seqlen_q: 8\n", "max_seqlen_k: 8\n" ] } ] }, { "cell_type": "markdown", "source": [ "注: 如下计算使用了flash_atten,需要用到GPU! 且环境支持flash_attn运行。\n", "\n", "可以使用docker镜像,如:\n", "\n", "```\n", "docker pull nvcr.io/nvidia/sglang:26.01-py3\n", "```\n" ], "metadata": { "id": "FCWaJxEfJ7Dp" } }, { "cell_type": "markdown", "source": [ "\n", "### prefill计算演示" ], "metadata": { "id": "_EOHeVt3-ipc" } }, { "cell_type": "code", "source": [ "# 检查 CUDA 是否可用\n", "assert torch.cuda.is_available(), \"CUDA 不可用,本示例需要 GPU\"" ], "metadata": { "id": "ZNQBPw7OIVOz" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache\n", "\n", "# q k v数据构造\n", "num_q_heads = num_kv_heads\n", "q = torch.randn(total_seqlen, num_q_heads, head_dim, device='cuda', dtype=torch.float16)\n", "k = torch.randn(total_seqlen, num_kv_heads, head_dim, device='cuda', dtype=torch.float16)\n", "v = torch.randn(total_seqlen, num_kv_heads, head_dim, device='cuda', dtype=torch.float16)\n", "\n", "\n", "cu_seqlens_q_cuda = torch.tensor(cu_seqlens_q, dtype=torch.int32, device='cuda')\n", "cu_seqlens_k_cuda = torch.tensor(cu_seqlens_k, dtype=torch.int32, device='cuda')\n", "\n", "\n", "# FA计算:\n", "o = flash_attn_varlen_func(q, k, v,\n", " max_seqlen_q=max_seqlen_q, cu_seqlens_q=cu_seqlens_q_cuda,\n", " max_seqlen_k=max_seqlen_k, cu_seqlens_k=cu_seqlens_k_cuda,\n", " softmax_scale=head_dim ** -0.5, causal=True, block_table=block_tables)" ], "metadata": { "id": "Hhlf6SzF_1Hx" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### decode计算演示\n" ], "metadata": { "id": "CJp4c8IJBLG6" } }, { "cell_type": "code", "source": [ "q = torch.randn(len(seqs), num_q_heads, head_dim, device='cuda', dtype=torch.float16)\n", "\n", "# 创建GPU版本的 KV cache的物理存储空间:\n", "kv_cache = torch.empty(2, num_hidden_layers, num_kvcache_blocks, kvcache_block_size, num_kv_heads, head_dim, device='cuda', dtype=torch.float16)\n", "print(f\"KV cache shape: {kv_cache.shape()}\")\n", "\n", "# K cache与V cache数据分离:\n", "layer_id = 0 # 假设为第0层\n", "k_cache = kv_cache[0, layer_id]\n", "v_cache = kv_cache[1, layer_id]\n", "\n", "# context的长度= 历史kv cache长度 + 新增seq长度\n", "context_lens = []\n", "for seq in seqs:\n", " context_lens.append(len(seq))\n", "\n", "# FA 计算:\n", "# q需要转为:(batch_size, seqlen, nheads, headdim)格式,seqlen = 1\n", "o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,\n", " cache_seqlens=context_lens,\n", " block_table=block_tables,\n", " softmax_scale=head_dim ** -0.5, causal=True)" ], "metadata": { "id": "EYKCsN-3BltK" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 3.2 CUDA Graph的使用" ], "metadata": { "id": "d4f6v3jELsm7" } }, { "cell_type": "code", "source": [ "# nano-vLLM中graph batchsize的支持最大值为512\n", "max_num_seqs = 1024\n", "max_bs = min(max_num_seqs, 512)\n", "graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))\n", "print(graph_bs)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "i0L_b2vANf8a", "outputId": "0af904b0-3c47-40a8-cac4-44429d0127f6" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496, 512]\n" ] } ] }, { "cell_type": "markdown", "source": [ "构建一个cuda graph捕获演示示例:\n" ], "metadata": { "id": "795JK3CgJWOL" } }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import random\n", "\n", "\n", "# -------------------- 定义模型 --------------------\n", "class SimpleModel(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.fc1 = nn.Linear(32, 64) # 输入特征 32 -> 64\n", " self.fc2 = nn.Linear(64, 32) # 64 -> 32\n", " self.relu = nn.ReLU()\n", "\n", " def forward(self, x):\n", " # x 形状: [bs, 8, 32]\n", " x = self.fc1(x) # [bs, 8, 64]\n", " x = self.relu(x)\n", " x = self.fc2(x) # [bs, 8, 32]\n", " return x\n", "\n", "model = SimpleModel().to(device)\n", "model.eval() # 推理模式,关闭 dropout/batchnorm 等随机行为\n", "\n", "# -------------------- 为不同 batch size 捕获图 --------------------\n", "batch_sizes = [1, 2, 4, 8, 16]\n", "graph_pool = {} # 字典:bs -> (graph, static_input, static_output)\n", "\n", "# 预热 CUDA 上下文,避免捕获时包含自动调优开销\n", "warmup_input = torch.randn(8, 8, 32, device=device)\n", "for _ in range(3):\n", " _ = model(warmup_input)\n", "torch.cuda.synchronize()\n", "\n", "for bs in batch_sizes:\n", " # 创建固定的输入、输出占位张量(图会记住它们的地址)\n", " static_input = torch.randn(bs, 8, 32, device=device)\n", " static_output = torch.empty_like(static_input) # 形状与输入相同\n", "\n", " # 开始捕获\n", " graph = torch.cuda.CUDAGraph()\n", " with torch.cuda.graph(graph):\n", " # 在此上下文中执行的所有 CUDA 操作都会被捕获\n", " static_output = model(static_input)\n", "\n", " # 将图及相关张量保存到池中\n", " graph_pool[bs] = (graph, static_input, static_output)\n", "\n", "print(f\"已为 batch sizes {batch_sizes} 捕获图完成。\")\n", "\n", "# -------------------- 模拟多次推理,随机选择 batch size --------------------\n", "num_iterations = 10\n", "for i in range(num_iterations):\n", " # 随机选择一个 batch size\n", " bs = random.choice(batch_sizes)\n", " graph, static_input, static_output = graph_pool[bs]\n", "\n", " # 生成新的随机输入数据\n", " new_input = torch.randn(bs, 8, 32, device=device)\n", "\n", " # 将新数据复制到图使用的静态输入张量中(in-place 操作,不改变地址)\n", " static_input.copy_(new_input)\n", "\n", " # 重放图\n", " graph.replay()\n", "\n", " # 此时 static_output 已经更新为对应新输入的计算结果\n", " # 可以取出结果用于后续处理,例如与普通前向结果对比验证\n", " with torch.no_grad():\n", " expected_output = model(new_input)\n", "\n", " # 验证结果是否一致(允许微小误差)\n", " if torch.allclose(static_output, expected_output, atol=1e-5):\n", " print(f\"迭代 {i+1}: bs={bs} 图重放结果与普通前向一致\")\n", " else:\n", " print(f\"迭代 {i+1}: bs={bs} 结果不一致!\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kFRJNPNvJq_G", "outputId": "85ae1124-e4d5-459f-8ac8-19ce31a26d7d" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "已为 batch sizes [1, 2, 4, 8, 16] 捕获图完成。\n", "迭代 1: bs=4 图重放结果与普通前向一致\n", "迭代 2: bs=2 图重放结果与普通前向一致\n", "迭代 3: bs=4 图重放结果与普通前向一致\n", "迭代 4: bs=1 图重放结果与普通前向一致\n", "迭代 5: bs=4 图重放结果与普通前向一致\n", "迭代 6: bs=1 图重放结果与普通前向一致\n", "迭代 7: bs=2 图重放结果与普通前向一致\n", "迭代 8: bs=4 图重放结果与普通前向一致\n", "迭代 9: bs=4 图重放结果与普通前向一致\n", "迭代 10: bs=8 图重放结果与普通前向一致\n", "\n" ] } ] } ] }