{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "authorship_tag": "ABX9TyP0isUG44zCWb69T8obf8aU" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "# 手撕SGLang KV cache逻辑:理解Radix Attention原理\n", "# Build Radix KV Cache Manager from Scratch\n", "\n", "Author: kaiyuan\n", "\n", "Email: kyxie@zju.edu.cn\n" ], "metadata": { "id": "JrEmpnkRvZyw" } }, { "cell_type": "markdown", "source": [ "# 1 物理内存的分配与使用\n", "\n", "函数定义:\n", "* MHA page的计算\n", "* pages数量的计算\n", "* page table的创建 格式:{请求索引,物理显存位置list}" ], "metadata": { "id": "1hpT5GDVxSxj" } }, { "cell_type": "code", "source": [ "import torch\n", "from dataclasses import dataclass\n", "\n", "def get_mha_cache_per_page(head_dim, num_kv_heads, num_layers, page_size=1, tp_size=1, dtype_size=2):\n", " \"\"\"\n", " :param head_dim: dim大小\n", " :param num_kv_heads: 头数\n", " :param num_layers: 模型层数\n", " :param page_size: 页大小,默认1\n", " :param tp_size: 并行策略的TP大小, 默认1为没有TP切分\n", " :param dtype_size: 数据大小,torch.float32 4字节,torch.float16/torch.bfloat16: 2字节\n", " :return:\n", " \"\"\"\n", " size = 2 * num_kv_heads * head_dim * num_layers * page_size / tp_size * dtype_size\n", " return size\n", "\n", "\n", "def get_num_pages_for_kv_cache(available_memory, cache_per_page):\n", " num_pages = int(available_memory // cache_per_page)\n", " kv_size = num_pages * cache_per_page / (1024**3)\n", " print(f\"Allocating {num_pages} pages for KV cache, K + V = {kv_size:.3f}GB\")\n", " return num_pages\n", "\n", "\n", "def create_page_table(max_running_req, max_seq_len):\n", " page_table = torch.zeros((max_running_req, max_seq_len), dtype=torch.int32)\n", " return page_table" ], "metadata": { "id": "nI-BFcTzvjHY" }, "execution_count": 1, "outputs": [] }, { "cell_type": "markdown", "source": [ "函数的调用测试:" ], "metadata": { "id": "CC1ZVHxVps8j" } }, { "cell_type": "code", "source": [ "# 模型参数定义:\n", "head_dim = 256\n", "num_kv_heads = 8\n", "num_layers = 1\n", "\n", "# 可用显存大小\n", "available_memory = 4 # 4G\n", "print(f\"可用显存大小{available_memory} GB\")\n", "\n", "cache_per_page = get_mha_cache_per_page(head_dim, num_kv_heads, num_layers)\n", "num_pages = get_num_pages_for_kv_cache(available_memory * (1024**3), cache_per_page)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9_KRMvT0HLl6", "outputId": "efa210e5-9b68-4635-bddc-a250485d42e9" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "可用显存大小4 GB\n", "Allocating 524288 pages for KV cache, K + V = 4.000GB\n" ] } ] }, { "cell_type": "code", "source": [ "# 构建测试\n", "\n", "# 定义一个简单请求格式:\n", "@dataclass\n", "class SimpleRequest:\n", " uid: int\n", " table_idx: int\n", " len: int\n", "\n", "\n", "\n", "def demo():\n", " # 设定一个 num_pages大小用于演示\n", " num_pages = 50\n", "\n", " # 物理显存模拟:\n", " free_slots = torch.arange(num_pages, dtype=torch.int32)\n", "\n", " # 构建page_table:\n", " page_table = create_page_table(5, 10)\n", "\n", " print(\"当前空余slots:\")\n", " print(free_slots)\n", "\n", " # 定义显存的申请与释放函数:\n", " def allocate(req):\n", " nonlocal free_slots\n", " page_table[req.table_idx][:req.len] = free_slots[:req.len]\n", " free_slots = free_slots[req.len:]\n", "\n", " def free(req):\n", " nonlocal free_slots\n", " free_slots = torch.cat([free_slots, page_table[req.table_idx][:req.len]])\n", "\n", " req_0 = SimpleRequest(0, 0, 7)\n", " allocate(req_0)\n", " print(\"请求0 slots使用情况:\")\n", " 
{ "cell_type": "code", "source": [
"# Build a test\n",
"\n",
"# a simple request format:\n",
"@dataclass\n",
"class SimpleRequest:\n",
"    uid: int\n",
"    table_idx: int\n",
"    len: int\n",
"\n",
"\n",
"def demo():\n",
"    # pick a small num_pages for the demo\n",
"    num_pages = 50\n",
"\n",
"    # simulate the physical memory slots:\n",
"    free_slots = torch.arange(num_pages, dtype=torch.int32)\n",
"\n",
"    # build the page_table:\n",
"    page_table = create_page_table(5, 10)\n",
"\n",
"    print(\"Current free slots:\")\n",
"    print(free_slots)\n",
"\n",
"    # allocation and free routines:\n",
"    def allocate(req):\n",
"        nonlocal free_slots\n",
"        page_table[req.table_idx][:req.len] = free_slots[:req.len]\n",
"        free_slots = free_slots[req.len:]\n",
"\n",
"    def free(req):\n",
"        nonlocal free_slots\n",
"        free_slots = torch.cat([free_slots, page_table[req.table_idx][:req.len]])\n",
"\n",
"    req_0 = SimpleRequest(0, 0, 7)\n",
"    allocate(req_0)\n",
"    print(\"Slots used by request 0:\")\n",
"    print(page_table[req_0.table_idx][:req_0.len])\n",
"\n",
"    req_1 = SimpleRequest(1, 1, 7)\n",
"    allocate(req_1)\n",
"    print(\"Slots used by request 1:\")\n",
"    print(page_table[req_1.table_idx][:req_1.len])\n",
"    print(\"=\" * 80)\n",
"\n",
"    print(\"Current free slots:\")\n",
"    print(free_slots)\n",
"    print(\"=\" * 80)\n",
"    free(req_0)\n",
"    print(\"Free slots after freeing request 0:\")\n",
"    print(free_slots)\n",
"\n",
"demo()"
], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "EpWTxemzT5oS", "outputId": "ce423f4e-93a6-4bdd-f813-f9e0bbb5ebe9" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
"Current free slots:\n",
"tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
" 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
" 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49],\n",
" dtype=torch.int32)\n",
"Slots used by request 0:\n",
"tensor([0, 1, 2, 3, 4, 5, 6], dtype=torch.int32)\n",
"Slots used by request 1:\n",
"tensor([ 7, 8, 9, 10, 11, 12, 13], dtype=torch.int32)\n",
"================================================================================\n",
"Current free slots:\n",
"tensor([14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,\n",
" 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49],\n",
" dtype=torch.int32)\n",
"================================================================================\n",
"Free slots after freeing request 0:\n",
"tensor([14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,\n",
" 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,\n",
" 0, 1, 2, 3, 4, 5, 6], dtype=torch.int32)\n"
] } ] },
{ "cell_type": "markdown", "source": [
"# 2 Radix Tree\n",
"\n",
"## 2.1 Data Structure Definition"
], "metadata": { "id": "nk8EY_x3ypdg" } },
{ "cell_type": "code", "source": [
"from __future__ import annotations  # allows forward references to RadixTreeNode\n",
"import time\n",
"from typing import Dict\n",
"\n",
"# a helper function:\n",
"def find_first_diff_pos(tensor1: torch.Tensor, tensor2: torch.Tensor) -> int:\n",
"    \"\"\"\n",
"    Compare two 1-D PyTorch tensors element-wise and return the position of the\n",
"    first differing value, i.e. the length of their common prefix.\n",
"    Returns:\n",
"        int: index of the first differing value; min(len1, len2) when one tensor\n",
"        is a prefix of the other (including when both are identical).\n",
"    \"\"\"\n",
"    if tensor1.dim() != 1 or tensor2.dim() != 1:\n",
"        raise ValueError(\"both tensors must be 1-D\")\n",
"\n",
"    len1, len2 = len(tensor1), len(tensor2)\n",
"\n",
"    # walk both tensors until the shorter one ends\n",
"    for i in range(min(len1, len2)):\n",
"        if tensor1[i] != tensor2[i]:\n",
"            return i\n",
"\n",
"    return min(len1, len2)\n",
"\n",
"\n",
"class RadixTreeNode:\n",
"    counter: int = 0\n",
"\n",
"    def __init__(self, tic: int | None = None) -> None:\n",
"        self.children: Dict[int, RadixTreeNode] = {}\n",
"        self._parent: RadixTreeNode | None = None\n",
"        self.ref_count: int = 0\n",
"        self.uuid = RadixTreeNode.counter\n",
"        RadixTreeNode.counter += 1\n",
"        self.timestamp = tic or time.monotonic_ns()\n",
"\n",
"        self._token_ids: torch.Tensor\n",
"        self._slots: torch.Tensor\n",
"        self._length: int\n",
"\n",
"    def set_ids_slots(self, token_ids: torch.Tensor, slots: torch.Tensor) -> None:\n",
"        assert len(token_ids) == len(slots)\n",
"        self._token_ids = token_ids\n",
"        self._slots = slots\n",
"        self._length = len(token_ids)\n",
"\n",
"    def set_parent(self, parent: RadixTreeNode) -> None:\n",
"        self._parent = parent\n",
"        parent.children[int(self._token_ids[0].item())] = self\n",
"\n",
"    @property\n",
"    def length(self) -> int:\n",
"        return self._length\n",
"\n",
"    @property\n",
"    def parent(self) -> RadixTreeNode:\n",
"        assert self._parent is not None\n",
"        return self._parent\n",
"\n",
"    @property\n",
"    def slots(self) -> torch.Tensor:\n",
"        return self._slots\n",
"\n",
"    def is_root(self) -> bool:\n",
"        return self._parent is None\n",
"\n",
"    def is_leaf(self) -> bool:\n",
"        return len(self.children) == 0\n",
"\n",
"    def get_match_len(self, input_ids: torch.Tensor) -> int:\n",
"        return find_first_diff_pos(self._token_ids, input_ids)\n",
"\n",
"    def _split_at(self, pos: int) -> RadixTreeNode:\n",
"        assert 0 < pos < self.length\n",
"        parent = self.parent\n",
"\n",
"        new_node = RadixTreeNode(self.timestamp)\n",
"        new_node.set_ids_slots(self._token_ids[:pos], self._slots[:pos])\n",
"        new_node.set_parent(parent)\n",
"        new_node.ref_count = self.ref_count\n",
"\n",
"        self.set_ids_slots(self._token_ids[pos:], self._slots[pos:])\n",
"        self.set_parent(new_node)\n",
"\n",
"        return new_node\n",
"\n",
"    def __lt__(self, other: RadixTreeNode) -> bool:\n",
"        return self.timestamp < other.timestamp\n",
"\n",
"    def __repr__(self) -> str:\n",
"        return f\"RadixTreeNode(uuid={self.uuid}, tokens={self._token_ids.tolist()}, slots={self._slots.tolist()})\""
], "metadata": { "id": "rZcEigOfVLwa" }, "execution_count": 4, "outputs": [] },
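{ "cell_type": "markdown", "source": [
"A quick sanity check of `find_first_diff_pos`'s edge cases (an illustrative cell, not part of the original run; expected values are in the comments):"
], "metadata": {} },
{ "cell_type": "code", "source": [
"# first mismatch at index 2:\n",
"print(find_first_diff_pos(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 4])))  # expected: 2\n",
"# one tensor is a prefix of the other -> min length, not -1:\n",
"print(find_first_diff_pos(torch.tensor([1, 2]), torch.tensor([1, 2, 4])))     # expected: 2"
], "metadata": {}, "execution_count": null, "outputs": [] },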
self._slots\n", "\n", " def is_root(self) -> bool:\n", " return self._parent is None\n", "\n", " def is_leaf(self) -> bool:\n", " return len(self.children) == 0\n", "\n", " def get_match_len(self, input_ids: torch.Tensor) -> int:\n", " return find_first_diff_pos(self._token_ids, input_ids)\n", "\n", " def _split_at(self, pos: int) -> RadixTreeNode:\n", " assert 0 < pos < self.length\n", " parent = self.parent\n", "\n", " new_node = RadixTreeNode(self.timestamp)\n", " new_node.set_ids_slots(self._token_ids[:pos], self._slots[:pos])\n", " new_node.set_parent(parent)\n", " new_node.ref_count = self.ref_count\n", "\n", " self.set_ids_slots(self._token_ids[pos:], self._slots[pos:])\n", " self.set_parent(new_node)\n", "\n", " return new_node\n", "\n", " def __lt__(self, other: RadixTreeNode) -> bool:\n", " return self.timestamp < other.timestamp\n", "\n", " def __repr__(self) -> str:\n", " return f\"RadixTreeNode(uuid={self.uuid}, tokens={self._token_ids.tolist()}, slots={self._slots.tolist()})\"" ], "metadata": { "id": "rZcEigOfVLwa" }, "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 2.2 定义可视化打印函数" ], "metadata": { "id": "lsJKNfHOy1MD" } }, { "cell_type": "code", "source": [ "def print_radix_tree(root: RadixTreeNode, max_depth: int = 10,\n", " show_ref_count: bool = True, show_timestamp: bool = False) -> None:\n", " \"\"\"\n", " 以树形结构显示RadixTreeNode的形状\n", "\n", " Args:\n", " root: 根节点\n", " max_depth: 最大显示深度\n", " show_ref_count: 是否显示引用计数\n", " show_timestamp: 是否显示时间戳\n", " \"\"\"\n", "\n", " def _print_node(node: RadixTreeNode, depth: int, prefix: str, is_last: bool = True) -> None:\n", " \"\"\"递归打印节点及其子节点\"\"\"\n", " if depth > max_depth:\n", " return\n", "\n", " connector = \"└── \" if is_last else \"├── \"\n", " node_info = f\"uuid={node.uuid}\"\n", " if hasattr(node, '_token_ids') and node._token_ids is not None:\n", " token_str = str(node._token_ids.tolist())[:30] # 限制长度\n", " node_info += f\", tokens={token_str}\"\n", " if hasattr(node, '_slots') and node._slots is not None:\n", " slot_str = str(node._slots.tolist())[:30] # 限制长度\n", " node_info += f\", slots={slot_str}\"\n", " if show_ref_count:\n", " node_info += f\", ref={node.ref_count}\"\n", "\n", " if show_timestamp:\n", " node_info += f\", ts={node.timestamp}\"\n", "\n", " if node.is_leaf():\n", " node_info += \" [L]\"\n", " elif node.is_root():\n", " node_info += \" [R]\"\n", "\n", " print(f\"{prefix}{connector}{node_info}\")\n", " new_prefix = prefix + (\" \" if is_last else \"│ \")\n", "\n", " # 递归打印子节点\n", " child_count = len(node.children)\n", " for i, (key, child_node) in enumerate(sorted(node.children.items())):\n", " is_last_child = (i == child_count - 1)\n", " _print_node(child_node, depth + 1, new_prefix, is_last_child)\n", "\n", " print(\"\\n\" + \"=\"*80)\n", " print(\"RADIX TREE STRUCTURE\")\n", " print(\"=\"*80)\n", "\n", " if root.is_root():\n", " print(\"Root Node:\")\n", "\n", " _print_node(root, 0, \"\")\n", "\n" ], "metadata": { "id": "Wz8bSpN-OalE" }, "execution_count": 5, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 2.3 数据的增、删、查操作\n", "\n", "### 2.3.1 查找与插入操作" ], "metadata": { "id": "_HSGPeElzZnR" } }, { "cell_type": "code", "source": [ "from typing import Tuple\n", "\n", "# 查找匹配前缀,有相同前缀,则分裂节点。\n", "def walk(input_ids: torch.Tensor) -> Tuple[RadixTreeNode, int]:\n", " \"\"\"\n", " 返回值 node:匹配到的前缀node。\n", " prefix_len: 匹配长度\n", " \"\"\"\n", " node = root_node\n", " prefix_len = 0\n", " indice_len = len(input_ids)\n", " while prefix_len < indice_len:\n", " 
{ "cell_type": "markdown", "source": [
"### 2.3.2 Increasing and Decreasing Reference Counts"
], "metadata": { "id": "R-N-wHWazlvW" } },
{ "cell_type": "code", "source": [
"# The prefix node is now referenced: bump the ref count of the node and all\n",
"# of its ancestors (up to, but excluding, the root).\n",
"node = cache_handle\n",
"while not node.is_root():\n",
"    node.ref_count += 1\n",
"    node = node.parent\n",
"\n",
"print_radix_tree(root_node)\n",
"\n",
"# When the request is released, remove the ref counts it added:\n",
"node = cache_handle\n",
"while not node.is_root():\n",
"    node.ref_count -= 1\n",
"    node = node.parent\n",
"print_radix_tree(root_node)\n"
], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DO7ZiUZzOlcv", "outputId": "29aee943-acf8-494a-a576-f7fb1343fcd7" }, "execution_count": 7, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
"\n",
"================================================================================\n",
"RADIX TREE STRUCTURE\n",
"================================================================================\n",
"Root Node:\n",
"└── uuid=0, ref=1 [R]\n",
"    └── uuid=2, tokens=[1, 3, 6, 7], slots=[0, 1, 2, 3], ref=1\n",
"        ├── uuid=1, tokens=[9, 77], slots=[4, 7], ref=0 [L]\n",
"        └── uuid=3, tokens=[87, 66], slots=[5, 6], ref=0 [L]\n",
"\n",
"================================================================================\n",
"RADIX TREE STRUCTURE\n",
"================================================================================\n",
"Root Node:\n",
"└── uuid=0, ref=1 [R]\n",
"    └── uuid=2, tokens=[1, 3, 6, 7], slots=[0, 1, 2, 3], ref=0\n",
"        ├── uuid=1, tokens=[9, 77], slots=[4, 7], ref=0 [L]\n",
"        └── uuid=3, tokens=[87, 66], slots=[5, 6], ref=0 [L]\n"
] } ] },
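{ "cell_type": "markdown", "source": [
"The two loops above differ only in the sign of the update, so they can be factored into one helper. This is a sketch (the name `update_ref` is ours, not from the original); `lock_handle` in section 3 performs the same walk plus evictable/protected size bookkeeping:"
], "metadata": {} },
{ "cell_type": "code", "source": [
"def update_ref(node: RadixTreeNode, delta: int) -> None:\n",
"    # walk from the matched node up to (but excluding) the root\n",
"    while not node.is_root():\n",
"        node.ref_count += delta\n",
"        assert node.ref_count >= 0\n",
"        node = node.parent\n",
"\n",
"update_ref(cache_handle, +1)  # lock the prefix\n",
"update_ref(cache_handle, -1)  # unlock it again"
], "metadata": {}, "execution_count": null, "outputs": [] },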
"================================================================================\n", "Root Node:\n", "└── uuid=0, ref=1 [R]\n", " └── uuid=2, tokens=[1, 3, 6, 7], slots=[0, 1, 2, 3], ref=0\n", " ├── uuid=1, tokens=[9, 77], slots=[4, 7], ref=0 [L]\n", " └── uuid=3, tokens=[87, 66], slots=[5, 6], ref=0 [L]\n" ] } ] }, { "cell_type": "markdown", "source": [ "### 2.3.3 淘汰数据&释放空间\n", "\n", "当有新请求需要KV cache,从radix tree中释放空闲空间\n", "\n", "LRU策略:node里面定义了比较函数__lt__,采用heapq进行元素排序,堆顶元素是最小的。\n", "\n", "相当于leave_nodes.sort(key=lambda x: x.timestamp, reverse=True)\n", "\n", "**注意:** 下面代码可以运行两次,超过两次需要重新构造root_node数据" ], "metadata": { "id": "ZthPI3pmdK6e" } }, { "cell_type": "code", "source": [ "import heapq\n", "\n", "# 需要3个tokens空间\n", "target_size = 3\n", "\n", "# 查找叶子结点:\n", "nodes = [root_node]\n", "leave_nodes = []\n", "\n", "while len(nodes) > 0:\n", " node = nodes.pop()\n", " if node.is_leaf():\n", " if node.ref_count == 0:\n", " leave_nodes.append(node)\n", " else:\n", " for child in node.children.values():\n", " nodes.append(child)\n", "\n", "leave_nodes\n", "heapq.heapify(leave_nodes)\n", "evicted_indices = []\n", "evicted_size = 0\n", "\n", "print_radix_tree(root_node)\n", "\n", "# 删除空闲叶子结点,直到满足taget_size\n", "while evicted_size < target_size:\n", " assert (\n", " leave_nodes\n", " ), f\"Cannot evict enough cache, need {target_size}, only {evicted_size} evicted\"\n", " node = heapq.heappop(leave_nodes)\n", " assert node.ref_count == 0 and node.is_leaf() and not node.is_root()\n", " evicted_size += node.length\n", " evicted_indices.append(node.slots)\n", " parent = node.parent\n", " del parent.children[int(node._token_ids[0].item())]\n", " print_radix_tree(root_node)\n", " print(f\"Node: {node.uuid} is evicted\")\n", " if parent.is_leaf() and parent.ref_count == 0:\n", " heapq.heappush(leave_nodes, parent)\n", "\n", "free_slots = torch.cat(evicted_indices)\n", "\n", "print()\n", "print(f\"free slots: {free_slots}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wrX9V5TeT7pX", "outputId": "8d7368bd-30f7-4c32-b327-4f0b1a8d8602" }, "execution_count": 8, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "================================================================================\n", "RADIX TREE STRUCTURE\n", "================================================================================\n", "Root Node:\n", "└── uuid=0, ref=1 [R]\n", " └── uuid=2, tokens=[1, 3, 6, 7], slots=[0, 1, 2, 3], ref=0\n", " ├── uuid=1, tokens=[9, 77], slots=[4, 7], ref=0 [L]\n", " └── uuid=3, tokens=[87, 66], slots=[5, 6], ref=0 [L]\n", "\n", "================================================================================\n", "RADIX TREE STRUCTURE\n", "================================================================================\n", "Root Node:\n", "└── uuid=0, ref=1 [R]\n", " └── uuid=2, tokens=[1, 3, 6, 7], slots=[0, 1, 2, 3], ref=0\n", " └── uuid=3, tokens=[87, 66], slots=[5, 6], ref=0 [L]\n", "Node: 1 is evicted\n", "\n", "================================================================================\n", "RADIX TREE STRUCTURE\n", "================================================================================\n", "Root Node:\n", "└── uuid=0, ref=1 [R]\n", " └── uuid=2, tokens=[1, 3, 6, 7], slots=[0, 1, 2, 3], ref=0 [L]\n", "Node: 3 is evicted\n", "\n", "free slots: tensor([4, 7, 5, 6])\n" ] } ] }, { "cell_type": "markdown", "source": [ "# 3 KV manager的实现\n", "\n", "## 3.1 辅助函数定义:" ], "metadata": { "id": "s8aoqpG8FbLI" } }, { "cell_type": "code", 
"source": [ "from abc import ABC, abstractmethod\n", "from dataclasses import dataclass\n", "from typing import List, NamedTuple, Tuple, Dict\n", "import time\n", "import heapq\n", "\n", "\n", "# 请求体定义:\n", "@dataclass(eq=False)\n", "class Request:\n", " uid: int\n", " input_ids: torch.Tensor # cpu tensor\n", " table_idx: int\n", " cached_len: int\n", " output_len: int\n", " cache_handle: BaseCacheHandle = None\n", " max_tokens: int = 1024\n", "\n", " @property\n", " def input_len(self) -> int:\n", " return len(self.input_ids)\n", "\n", "# 计算整体尺寸大小:\n", "class SizeInfo(NamedTuple):\n", " evictable_size: int\n", " protected_size: int\n", " @property\n", " def total_size(self) -> int:\n", " return self.evictable_size + self.protected_size" ], "metadata": { "id": "xpRYQOOgs-6R" }, "execution_count": 9, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 3.2 定义基础类\n", "\n", "CacheHandle 的作用:保存cache的长度,以及记录用于prefix的子节点" ], "metadata": { "id": "B8UAEMGDFvaT" } }, { "cell_type": "code", "source": [ "\n", "@dataclass(frozen=True)\n", "class BaseCacheHandle(ABC):\n", " cached_len: int\n", "\n", "@dataclass(frozen=True)\n", "class RadixCacheHandle(BaseCacheHandle):\n", " node: RadixTreeNode\n", "\n", "class BaseCacheManager(ABC):\n", " @abstractmethod\n", " def match_prefix(self, input_ids: torch.Tensor) -> Tuple[BaseCacheHandle, torch.Tensor]:\n", " \"\"\"\n", " Match prefix and return the indices of the matched prefix in the cache.\n", " This operation will not modify the cache.\n", " The returned indices is only safe to use when the handle is locked.\n", "\n", " Args:\n", " input_ids (torch.Tensor): The input ids to match. Shape: (seq_len,)\n", " Returns:\n", " handle (BaseCacheHandle): The handle to the matched prefix.\n", " indices (torch.Tensor): The indices of the longest-matched prefix in the cache.\n", " \"\"\"\n", "\n", " @abstractmethod\n", " def lock_handle(self, handle: BaseCacheHandle, unlock: bool = False) -> None:\n", " \"\"\"\n", " Lock or unlock a cache handle.\n", " This operation will not modify the cache, but change the size info only.\n", " When a handle is locked, it cannot be evicted.\n", " Handles must be locked before the previously-returned tensor of `match_prefix` is used.\n", " Otherwise it may be evicted by calling evict.\n", "\n", " Args:\n", " handle (BaseCacheHandle): The cache handle to lock or unlock.\n", " unlock (bool): Whether to unlock the handle. Defaults to False.\n", " \"\"\"\n", "\n", " @abstractmethod\n", " def insert_prefix(self, input_ids: torch.Tensor, indices: torch.Tensor) -> int:\n", " \"\"\"\n", " Insert a new prefix into the cache.\n", " This operation will modify the cache.\n", " Args:\n", " input_ids (torch.Tensor): The input ids to insert. Shape: (seq_len,)\n", " indices (torch.Tensor): The indices to store the new prefix. Shape: (seq_len,)\n", "\n", " Returns:\n", " int: The length of prefix that is already in the cache. This part is not\n", " inserted, so the caller should free these indices.\n", " \"\"\"\n", "\n", " @abstractmethod\n", " def evict(self, size: int) -> torch.Tensor:\n", " \"\"\"\n", " Evict some prefixes from the cache to free up space.\n", " This operation will modify the cache.\n", " Note that evict 0 is always safe and does nothing.\n", " Note that the actual evict size may be larger than the requested size.\n", " Args:\n", " size (int): The size to evict.\n", "\n", " Returns:\n", " torch.Tensor: The indices evicted. 
{ "cell_type": "markdown", "source": [
"## 3.3 CacheManager Implementation\n",
"\n",
"* RadixCacheManager implements the radix cache management\n",
"* CacheManager is the outer interface; its free slots track all available physical memory"
], "metadata": { "id": "fYQbYK7tFwRY" } },
{ "cell_type": "code", "source": [
"class RadixCacheManager(BaseCacheManager):\n",
"    def __init__(self, device: torch.device):\n",
"        self.device = device\n",
"        self.empty_tensor = torch.empty(0, dtype=torch.int32, device=device)\n",
"        super().__init__()\n",
"        self.root_node = RadixTreeNode()\n",
"        self.root_node.ref_count = 1  # root is always protected\n",
"        self.evictable_size = 0\n",
"        self.protected_size = 0\n",
"\n",
"    def lock_handle(self, handle: BaseCacheHandle, unlock: bool = False) -> None:\n",
"        if handle is None:\n",
"            return\n",
"\n",
"        node = handle.node\n",
"        if unlock:\n",
"            while not node.is_root():\n",
"                node.ref_count -= 1\n",
"                assert node.ref_count >= 0\n",
"                if node.ref_count == 0:\n",
"                    self.evictable_size += node.length\n",
"                    self.protected_size -= node.length\n",
"                node = node.parent\n",
"        else:\n",
"            while not node.is_root():\n",
"                if node.ref_count == 0:\n",
"                    self.evictable_size -= node.length\n",
"                    self.protected_size += node.length\n",
"                node.ref_count += 1\n",
"                node = node.parent\n",
"\n",
"    def match_prefix(self, input_ids: torch.Tensor) -> Tuple[RadixCacheHandle, torch.Tensor]:\n",
"        node, prefix_len = self._walk(input_ids)\n",
"        if prefix_len == 0:\n",
"            assert node.is_root() and node is self.root_node\n",
"            return RadixCacheHandle(prefix_len, node), self.empty_tensor\n",
"        # collect the slots from the matched node back up to the root\n",
"        slots_list: List[torch.Tensor] = []\n",
"        matched_node = node\n",
"        while not node.is_root():\n",
"            slots_list.append(node.slots)\n",
"            node = node.parent\n",
"        slots_list.reverse()\n",
"        return RadixCacheHandle(prefix_len, matched_node), torch.cat(slots_list)\n",
"\n",
"    def insert_prefix(self, input_ids: torch.Tensor, indices: torch.Tensor) -> int:\n",
"        node, prefix_len = self._walk(input_ids)\n",
"        assert prefix_len <= len(input_ids)\n",
"        if prefix_len < len(input_ids):\n",
"            new_node = RadixTreeNode()\n",
"            new_node.set_ids_slots(input_ids[prefix_len:], indices[prefix_len:].clone())\n",
"            new_node.set_parent(node)\n",
"            self.evictable_size += new_node.length\n",
"        return prefix_len\n",
"\n",
"    def _walk(self, input_ids: torch.Tensor) -> Tuple[RadixTreeNode, int]:\n",
"        prefix_len = 0\n",
"        indice_len = len(input_ids)\n",
"        node = self.root_node\n",
"        tic = time.monotonic_ns()\n",
"\n",
"        while prefix_len < indice_len:\n",
"            this_id = int(input_ids[prefix_len].item())\n",
"            if this_id not in node.children:\n",
"                return node, prefix_len\n",
"\n",
"            node = node.children[this_id]\n",
"\n",
"            # NOTE: at least 1 token is matched here, so match_len >= 1\n",
"            match_len = node.get_match_len(input_ids[prefix_len:])\n",
"            prefix_len += match_len\n",
"\n",
"            # split the node if it is not fully matched\n",
"            if match_len != node.length:\n",
"                node = node._split_at(match_len)\n",
"                return node, prefix_len\n",
"\n",
"            # update the timestamp of the accessed node (for LRU)\n",
"            node.timestamp = tic\n",
"\n",
"        return node, prefix_len\n",
"\n",
"    def evict(self, size: int) -> torch.Tensor:\n",
"        if size == 0:\n",
"            return self.empty_tensor\n",
"        assert (\n",
"            size <= self.evictable_size\n",
"        ), f\"Cannot evict {size}, only {self.evictable_size} is evictable\"\n",
"\n",
"        leave_nodes = self._collect_leave_nodes_for_evict()\n",
"\n",
"        heapq.heapify(leave_nodes)\n",
"        evicted_indices: List[torch.Tensor] = []\n",
"        evicted_size = 0\n",
"\n",
"        while evicted_size < size:\n",
"            assert (\n",
"                leave_nodes\n",
"            ), f\"Cannot evict enough cache, need {size}, only {evicted_size} evicted\"\n",
"            node = heapq.heappop(leave_nodes)\n",
"            assert node.ref_count == 0 and node.is_leaf() and not node.is_root()\n",
"            evicted_size += node.length\n",
"            evicted_indices.append(node.slots)\n",
"            self.evictable_size -= node.length\n",
"            parent = node.parent\n",
"            del parent.children[int(node._token_ids[0].item())]\n",
"            # NOTE: the root is always protected, so it won't be evicted\n",
"            if parent.is_leaf() and parent.ref_count == 0:\n",
"                heapq.heappush(leave_nodes, parent)\n",
"\n",
"        return torch.cat(evicted_indices)\n",
"\n",
"    def _collect_leave_nodes_for_evict(self) -> List[RadixTreeNode]:\n",
"        nodes: List[RadixTreeNode] = [self.root_node]\n",
"        leave_nodes: List[RadixTreeNode] = []\n",
"\n",
"        while len(nodes) > 0:\n",
"            node = nodes.pop()\n",
"            if node.is_leaf():\n",
"                if node.ref_count == 0:\n",
"                    leave_nodes.append(node)\n",
"            else:\n",
"                for child in node.children.values():\n",
"                    nodes.append(child)\n",
"\n",
"        return leave_nodes\n",
"\n",
"    def reset(self) -> None:\n",
"        raise NotImplementedError(\"RadixCacheManager.reset is not implemented\")\n",
"\n",
"    @property\n",
"    def size_info(self) -> SizeInfo:\n",
"        return SizeInfo(\n",
"            evictable_size=self.evictable_size,\n",
"            protected_size=self.protected_size,\n",
"        )\n",
"\n",
"\n",
"class CacheManager:\n",
"    def __init__(self, device: torch.device, num_pages: int):\n",
"        self._free_slots = torch.arange(num_pages, dtype=torch.int32, device=device)\n",
"        self.device = device\n",
"        self.manager = RadixCacheManager(device=device)\n",
"        self.num_pages = num_pages\n",
"\n",
"    def _free(self, indices: torch.Tensor) -> None:\n",
"        if len(indices) > 0:\n",
"            self._free_slots = torch.cat([self._free_slots, indices])\n",
"\n",
"    def match_req(self, req: Request):\n",
"        input_len = req.input_len\n",
"        assert input_len > 0, \"Input length must be greater than 0.\"\n",
"        # match all but the last token, so prefill always has at least one\n",
"        # token left to compute\n",
"        return self.manager.match_prefix(req.input_ids[: input_len - 1])\n",
"\n",
"    @property\n",
"    def available_size(self) -> int:\n",
"        return self.manager.size_info.evictable_size + len(self._free_slots)\n",
"\n",
"    def lock(self, handle: BaseCacheHandle) -> None:\n",
"        self.manager.lock_handle(handle, unlock=False)\n",
"\n",
"    def unlock(self, handle: BaseCacheHandle) -> None:\n",
"        self.manager.lock_handle(handle, unlock=True)\n",
"\n",
"    def allocate(self, needed_len: int) -> torch.Tensor:\n",
"        if needed_len <= (free_len := len(self._free_slots)):\n",
"            allocated = self._free_slots[:needed_len]\n",
"            self._free_slots = self._free_slots[needed_len:]\n",
"            return allocated\n",
"\n",
"        # NOTE: len(evicted) + free_len >= needed_len\n",
"        evicted = self.manager.evict(needed_len - free_len)\n",
"        merged = torch.cat([self._free_slots, evicted])\n",
"        assert len(merged) >= needed_len, \"Eviction did not free enough space.\"\n",
"\n",
"        allocated = merged[:needed_len]\n",
"        self._free_slots = merged[needed_len:]\n",
"        return allocated\n",
"\n",
"    def free_and_cache_finished_req(\n",
"        self,\n",
"        old_handle: BaseCacheHandle,\n",
"        input_ids: torch.Tensor,\n",
"        indices: torch.Tensor,\n",
"    ) -> None:\n",
"        # insert the finished sequence; the slots between the old handle's\n",
"        # cached length and the newly-matched length duplicate slots already\n",
"        # in the cache, so return them to the free pool\n",
"        in_cache_len = self.manager.insert_prefix(input_ids, indices)\n",
"        self._free(indices[old_handle.cached_len : in_cache_len])\n",
"        self.unlock(old_handle)\n"
], "metadata": { "id": "EVhywASxn6jN" }, "execution_count": 11, "outputs": [] },
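{ "cell_type": "markdown", "source": [
"The test in 3.4 below never exhausts the free pool, so `allocate`'s eviction fallback is not exercised there. A tiny sketch (illustrative, not part of the original run) that forces it; expected values are in the comments:"
], "metadata": {} },
{ "cell_type": "code", "source": [
"# a 4-slot pool: cache a 3-token prefix, then request 3 more slots\n",
"tiny = CacheManager(torch.device(\"cpu\"), num_pages=4)\n",
"slots = tiny.allocate(3)                               # free pool: 1 slot left\n",
"tiny.manager.insert_prefix(torch.tensor([11, 12, 13]), slots)\n",
"print(tiny.available_size)                             # expected: 4 (1 free + 3 evictable)\n",
"allocated = tiny.allocate(3)                           # evicts the unlocked prefix\n",
"print(allocated.tolist())                              # expected: [3, 0, 1]"
], "metadata": {}, "execution_count": null, "outputs": [] },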
prefix_len\n", "\n", " def evict(self, size: int) -> torch.Tensor:\n", " if size == 0:\n", " return self.empty_tensor\n", " assert (\n", " size <= self.evictable_size\n", " ), f\"Cannot evict {size}, only {self.evictable_size} is evictable\"\n", "\n", " leave_nodes = self._collect_leave_nodes_for_evict()\n", "\n", " heapq.heapify(leave_nodes)\n", " evicted_indices: List[torch.Tensor] = []\n", " evicted_size = 0\n", "\n", " while evicted_size < size:\n", " assert (\n", " leave_nodes\n", " ), f\"Cannot evict enough cache, need {size}, only {evicted_size} evicted\"\n", " node = heapq.heappop(leave_nodes)\n", " assert node.ref_count == 0 and node.is_leaf() and not node.is_root()\n", " evicted_size += node.length\n", " evicted_indices.append(node.slots)\n", " self.evictable_size -= node.length\n", " parent = node.parent\n", " del parent.children[int(node._token_ids[0].item())]\n", " # NOTE: root is always protected, so won't be evicted\n", " if parent.is_leaf() and parent.ref_count == 0:\n", " heapq.heappush(leave_nodes, parent)\n", "\n", " return torch.cat(evicted_indices)\n", "\n", " def _collect_leave_nodes_for_evict(self) -> List[RadixTreeNode]:\n", " nodes: List[RadixTreeNode] = [self.root_node]\n", " leave_nodes: List[RadixTreeNode] = []\n", "\n", " while len(nodes) > 0:\n", " node = nodes.pop()\n", " if node.is_leaf():\n", " if node.ref_count == 0:\n", " leave_nodes.append(node)\n", " else:\n", " for child in node.children.values():\n", " nodes.append(child)\n", "\n", " return leave_nodes\n", "\n", " def reset(self) -> None:\n", " raise NotImplementedError(\"RadixManager.reset is not implemented\")\n", "\n", " @property\n", " def size_info(self) -> SizeInfo:\n", " return SizeInfo(\n", " evictable_size=self.evictable_size,\n", " protected_size=self.protected_size,\n", " )\n", "\n", "\n", "\n", "class CacheManager:\n", " def __init__(self, device: torch.device, num_pages: int):\n", " self._free_slots = torch.arange(num_pages, dtype=torch.int32, device=device)\n", " self.device = device\n", " self.manager = RadixCacheManager(device=device)\n", " self.num_pages = num_pages\n", "\n", " def _free(self, indices: torch.Tensor) -> None:\n", " if len(indices) > 0:\n", " self._free_slots = torch.cat([self._free_slots, indices])\n", "\n", " def match_req(self, req: Request):\n", " input_len = req.input_len\n", " assert input_len > 0, \"Input length must be greater than 0.\"\n", " return self.manager.match_prefix(req.input_ids[: input_len - 1])\n", "\n", " @property\n", " def available_size(self) -> int:\n", " return self.manager.size_info.evictable_size + len(self._free_slots)\n", "\n", " def lock(self, handle: BaseCacheHandle) -> None:\n", " self.manager.lock_handle(handle, unlock=False)\n", "\n", " def unlock(self, handle: BaseCacheHandle) -> None:\n", " self.manager.lock_handle(handle, unlock=True)\n", "\n", " def allocate(self, needed_len: int) -> torch.Tensor:\n", " if needed_len <= (free_len := len(self._free_slots)):\n", " allocated = self._free_slots[:needed_len]\n", " self._free_slots = self._free_slots[needed_len:]\n", " return allocated\n", "\n", " # NOTE: len(evicted) + free_len >= needed_len\n", " evicted = self.manager.evict(needed_len - free_len)\n", " merged = torch.cat([self._free_slots, evicted])\n", " assert len(merged) >= needed_len, \"Eviction did not free enough space.\"\n", "\n", " allocated = merged[:needed_len]\n", " self._free_slots = merged[needed_len:]\n", " return allocated\n", "\n", " def free_and_cache_finished_req(\n", " self,\n", " old_handle: BaseCacheHandle,\n", " 
{ "cell_type": "code", "source": [
"available_memory = 2000 * 1024  # in bytes\n",
"cache_per_page = get_mha_cache_per_page(head_dim, num_kv_heads, num_layers)\n",
"num_pages = get_num_pages_for_kv_cache(available_memory, cache_per_page)\n",
"\n",
"# create the page table and the cache manager\n",
"page_table = create_page_table(5, 50)\n",
"\n",
"RadixTreeNode.counter = 0  # restart node numbering at 0\n",
"cache_manager = CacheManager(torch.device(\"cpu\"), num_pages)\n",
"print(f\"Init cache manager. Total cache available_size: {cache_manager.available_size} slots\")\n",
"\n",
"# create request 0\n",
"req_0 = Request(uid=0, input_ids=torch.tensor([1, 3, 6, 7, 9, 77]), table_idx=0, cached_len=0, output_len=20)\n",
"\n",
"\n",
"# allocate space for request 0\n",
"allocated_slots = cache_manager.allocate(req_0.input_len)\n",
"print(f\"After request0 allocated. Total cache available_size: {cache_manager.available_size} slots\")\n",
"page_table[req_0.table_idx][:req_0.input_len] = allocated_slots[:req_0.input_len]\n",
"\n",
"# release request 0 (it matched nothing, so its handle is the root with cached_len 0)\n",
"cache_manager.free_and_cache_finished_req(RadixCacheHandle(0, cache_manager.manager.root_node), req_0.input_ids, page_table[req_0.table_idx][:req_0.input_len])\n",
"print(f\"After request0 free. Total cache available_size: {cache_manager.available_size} slots\")\n",
"\n",
"print_radix_tree(cache_manager.manager.root_node)\n",
"\n",
"# create request 1\n",
"req_1 = Request(uid=1, input_ids=torch.tensor([1, 3, 6, 7, 87, 66]), table_idx=1, cached_len=0, output_len=20)\n",
"\n",
"# prefix-match request 1\n",
"handle, match_indices = cache_manager.match_req(req_1)\n",
"cached_len = handle.cached_len\n",
"cache_manager.lock(handle)\n",
"\n",
"# request 1 reuses the matched slots\n",
"page_table[req_1.table_idx][:cached_len].copy_(match_indices)\n",
"\n",
"extend_len = req_1.input_len - cached_len\n",
"allocated_slots = cache_manager.allocate(extend_len)\n",
"page_table[req_1.table_idx][cached_len:req_1.input_len] = allocated_slots[:extend_len]\n",
"print_radix_tree(cache_manager.manager.root_node)\n",
"\n",
"# release request 1\n",
"cache_manager.free_and_cache_finished_req(handle, req_1.input_ids, page_table[req_1.table_idx][:req_1.input_len])\n",
"\n",
"print_radix_tree(cache_manager.manager.root_node)"
], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fd4fy0EvrKjO", "outputId": "57b6c964-8f51-482e-98e4-7baff7b36bf7" }, "execution_count": 12, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
"Allocating 250 pages for KV cache, K + V = 0.002GB\n",
"Init cache manager. Total cache available_size: 250 slots\n",
"After request0 allocated. Total cache available_size: 244 slots\n",
"After request0 free. Total cache available_size: 250 slots\n",
"\n",
"================================================================================\n",
"RADIX TREE STRUCTURE\n",
"================================================================================\n",
"Root Node:\n",
"└── uuid=0, ref=1 [R]\n",
"    └── uuid=1, tokens=[1, 3, 6, 7, 9, 77], slots=[0, 1, 2, 3, 4, 5], ref=0 [L]\n",
"\n",
"================================================================================\n",
"RADIX TREE STRUCTURE\n",
"================================================================================\n",
"Root Node:\n",
"└── uuid=0, ref=1 [R]\n",
"    └── uuid=2, tokens=[1, 3, 6, 7], slots=[0, 1, 2, 3], ref=1\n",
"        └── uuid=1, tokens=[9, 77], slots=[4, 5], ref=0 [L]\n",
"\n",
"================================================================================\n",
"RADIX TREE STRUCTURE\n",
"================================================================================\n",
"Root Node:\n",
"└── uuid=0, ref=1 [R]\n",
"    └── uuid=2, tokens=[1, 3, 6, 7], slots=[0, 1, 2, 3], ref=0\n",
"        ├── uuid=1, tokens=[9, 77], slots=[4, 5], ref=0 [L]\n",
"        └── uuid=3, tokens=[87, 66], slots=[6, 7], ref=0 [L]\n"
] } ] } ] }