{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "authorship_tag": "ABX9TyP6tsWQlvoBbf7JGb5htXAM" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [
{ "cell_type": "markdown", "source": [
"# Chunked Prefill & Flash Decoding Computation Demo\n",
"\n",
"Author: kaiyuan\n",
"\n",
"Email: kyxie@zju.edu.cn"
], "metadata": { "id": "zK12wHz9C4cH" } },
{ "cell_type": "markdown", "source": [
"## 1 Chunked Prefill Computation\n",
"\n",
"We construct a chunked-prefill computation and compare it against a baseline attention computation."
], "metadata": { "id": "kp-2yDz2DGWa" } },
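{ "cell_type": "markdown", "source": [
"Before the full multi-head implementation below, here is a minimal single-head sketch of the same idea (toy sizes and ad-hoc names, for illustration only): each query chunk attends to all keys/values produced so far, and the result matches full causal attention."
], "metadata": {} },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Minimal single-head chunked causal attention (illustrative sketch only).\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"seq_len, head_dim, chunk = 6, 4, 3\n",
"q = torch.randn(seq_len, head_dim)\n",
"k = torch.randn(seq_len, head_dim)\n",
"v = torch.randn(seq_len, head_dim)\n",
"scale = head_dim ** -0.5\n",
"\n",
"# Reference: full causal attention.\n",
"scores = (q @ k.T) * scale\n",
"causal = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()\n",
"ref = F.softmax(scores.masked_fill(causal, float('-inf')), dim=-1) @ v\n",
"\n",
"# Chunked: each query chunk attends to every key/value seen so far.\n",
"outs = []\n",
"for s in range(0, seq_len, chunk):\n",
"    e = min(s + chunk, seq_len)\n",
"    s_blk = (q[s:e] @ k[:e].T) * scale          # keys 0..e-1 are visible\n",
"    qpos = torch.arange(s, e).unsqueeze(1)\n",
"    kpos = torch.arange(e).unsqueeze(0)\n",
"    s_blk = s_blk.masked_fill(qpos < kpos, float('-inf'))\n",
"    outs.append(F.softmax(s_blk, dim=-1) @ v[:e])\n",
"out = torch.cat(outs, dim=0)\n",
"print(\"max diff vs full attention:\", (out - ref).abs().max().item())"
] },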
{ "cell_type": "code", "execution_count": 1, "metadata": { "id": "feP3rGguBVpK" }, "outputs": [], "source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from typing import Optional, Tuple, List\n",
"\n",
"\n",
"class CausalChunkedPrefill(nn.Module):\n",
"    \"\"\"\n",
"    Streaming, causal chunked-prefill implementation,\n",
"    geared toward inference for autoregressive LLMs (e.g. GPT, LLaMA).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, d_model: int, n_heads: int, chunk_size: int = 512):\n",
"        super().__init__()\n",
"        self.d_model = d_model\n",
"        self.n_heads = n_heads\n",
"        self.chunk_size = chunk_size\n",
"        self.head_dim = d_model // n_heads\n",
"\n",
"        assert d_model % n_heads == 0, \"d_model must be divisible by n_heads\"\n",
"\n",
"        # QKV projection layers\n",
"        self.q_proj = nn.Linear(d_model, d_model)\n",
"        self.k_proj = nn.Linear(d_model, d_model)\n",
"        self.v_proj = nn.Linear(d_model, d_model)\n",
"        self.out_proj = nn.Linear(d_model, d_model)\n",
"\n",
"    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"Split a tensor into multiple heads\"\"\"\n",
"        batch_size, seq_len, _ = x.shape\n",
"        return x.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)\n",
"\n",
"    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"Merge the heads back together\"\"\"\n",
"        batch_size, n_heads, seq_len, head_dim = x.shape\n",
"        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)\n",
"\n",
"    def prefill_standard(self, x: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"\n",
"        Standard (un-chunked) attention - used to verify correctness.\n",
"        Causal attention: each position can only see earlier positions.\n",
"        \"\"\"\n",
"        batch_size, seq_len, _ = x.shape\n",
"\n",
"        # Compute QKV\n",
"        q = self.q_proj(x)\n",
"        k = self.k_proj(x)\n",
"        v = self.v_proj(x)\n",
"\n",
"        # Split into heads\n",
"        q = self._split_heads(q)  # [batch, n_heads, seq_len, head_dim]\n",
"        k = self._split_heads(k)\n",
"        v = self._split_heads(v)\n",
"\n",
"        # Attention scores\n",
"        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)\n",
"\n",
"        # Apply the causal mask (lower-triangular matrix)\n",
"        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))\n",
"        mask = mask.view(1, 1, seq_len, seq_len)\n",
"        scores = scores.masked_fill(mask == 0, float('-inf'))\n",
"\n",
"        # softmax\n",
"        attn_weights = F.softmax(scores, dim=-1)\n",
"\n",
"        # Attention output\n",
"        attn_output = torch.matmul(attn_weights, v)\n",
"\n",
"        # Merge heads\n",
"        output = self._merge_heads(attn_output)\n",
"        output = self.out_proj(output)\n",
"\n",
"        return output\n",
"\n",
"    def prefill_chunked(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[List[torch.Tensor], List[torch.Tensor]]]:\n",
"        \"\"\"\n",
"        Chunked prefill (streaming + causal).\n",
"\n",
"        Args:\n",
"            x: input sequence [batch, seq_len, d_model]\n",
"\n",
"        Returns:\n",
"            output: attention output [batch, seq_len, d_model]\n",
"            kv_cache: KV cache (list of K, list of V)\n",
"        \"\"\"\n",
"        batch_size, seq_len, _ = x.shape\n",
"\n",
"        # Total number of chunks\n",
"        n_chunks = (seq_len + self.chunk_size - 1) // self.chunk_size\n",
"\n",
"        # Initialize the KV cache (stores each chunk's K and V)\n",
"        k_cache = []  # each element: [batch, n_heads, chunk_size, head_dim]\n",
"        v_cache = []  # each element: [batch, n_heads, chunk_size, head_dim]\n",
"\n",
"        # Outputs of each chunk\n",
"        outputs = []\n",
"\n",
"        print(f\"Chunked prefill: seq_len={seq_len}, chunk_size={self.chunk_size}, n_chunks={n_chunks}\")\n",
"\n",
"        for chunk_idx in range(n_chunks):\n",
"            # Start/end positions of the current chunk\n",
"            start = chunk_idx * self.chunk_size\n",
"            end = min((chunk_idx + 1) * self.chunk_size, seq_len)\n",
"            chunk_len = end - start\n",
"\n",
"            # Slice out the current chunk\n",
"            chunk = x[:, start:end, :]\n",
"\n",
"            # Compute QKV for the current chunk\n",
"            q = self.q_proj(chunk)\n",
"            k = self.k_proj(chunk)\n",
"            v = self.v_proj(chunk)\n",
"\n",
"            # Split into heads\n",
"            q = self._split_heads(q)  # [batch, n_heads, chunk_len, head_dim]\n",
"            k = self._split_heads(k)\n",
"            v = self._split_heads(v)\n",
"\n",
"            # Append the current chunk's K and V to the cache\n",
"            k_cache.append(k)\n",
"            v_cache.append(v)\n",
"\n",
"            # Total KV length accumulated so far\n",
"            total_kv_len = sum(k.shape[2] for k in k_cache)\n",
"\n",
"            # Concatenate all K and V available so far (causal: only the current and earlier chunks are visible)\n",
"            k_all = torch.cat(k_cache, dim=2)  # [batch, n_heads, total_kv_len, head_dim]\n",
"            v_all = torch.cat(v_cache, dim=2)\n",
"\n",
"            # Attention scores\n",
"            scores = torch.matmul(q, k_all.transpose(-2, -1)) / (self.head_dim ** 0.5)\n",
"\n",
"            # Build the causal mask.\n",
"            # Note: queries inside the current chunk must not see future keys within the same chunk,\n",
"            # so we need a [chunk_len, total_kv_len] mask.\n",
"\n",
"            # Approach 1: build the full mask matrix\n",
"            q_positions = torch.arange(chunk_len, device=x.device).unsqueeze(1) + start\n",
"            kv_positions = []\n",
"            for i, k_chunk in enumerate(k_cache):\n",
"                kv_start = i * self.chunk_size\n",
"                kv_len = k_chunk.shape[2]\n",
"                kv_positions.extend(range(kv_start, kv_start + kv_len))\n",
"            kv_positions = torch.tensor(kv_positions, device=x.device).unsqueeze(0)\n",
"\n",
"            # A query position can only see KV positions <= its own\n",
"            mask = q_positions >= kv_positions  # [chunk_len, total_kv_len]\n",
"            mask = mask.view(1, 1, chunk_len, total_kv_len)\n",
"\n",
"            # Apply the mask\n",
"            scores = scores.masked_fill(~mask, float('-inf'))\n",
"\n",
"            # softmax\n",
"            attn_weights = F.softmax(scores, dim=-1)\n",
"\n",
"            # Attention output\n",
"            attn_output = torch.matmul(attn_weights, v_all)\n",
"\n",
"            # Merge heads\n",
"            output_chunk = self._merge_heads(attn_output)\n",
"            output_chunk = self.out_proj(output_chunk)\n",
"\n",
"            outputs.append(output_chunk)\n",
"\n",
"            print(f\"  Processing chunk {chunk_idx+1}/{n_chunks}: \"\n",
"                  f\"positions {start}:{end}, \"\n",
"                  f\"KV cache length={total_kv_len}\")\n",
"\n",
"        # Concatenate the outputs of all chunks\n",
"        output = torch.cat(outputs, dim=1)\n",
"\n",
"        return output, (k_cache, v_cache)\n",
"\n",
"    def decode_step(self,\n",
"                    x: torch.Tensor,\n",
"                    kv_cache: Tuple[List[torch.Tensor], List[torch.Tensor]]\n",
"                    ) -> Tuple[torch.Tensor, Tuple[List[torch.Tensor], List[torch.Tensor]]]:\n",
"        \"\"\"\n",
"        Decode step: process a single token (using the KV cache).\n",
"\n",
"        Args:\n",
"            x: current token [batch, 1, d_model]\n",
"            kv_cache: KV cache (list of K, list of V)\n",
"\n",
"        Returns:\n",
"            output: output for the current token [batch, 1, d_model]\n",
"            updated_kv_cache: the updated KV cache\n",
"        \"\"\"\n",
"        k_cache, v_cache = kv_cache\n",
"\n",
"        # Compute QKV for the current token\n",
"        q = self.q_proj(x)\n",
"        k = self.k_proj(x)\n",
"        v = self.v_proj(x)\n",
"\n",
"        # Split into heads\n",
"        q = self._split_heads(q)  # [batch, n_heads, 1, head_dim]\n",
"        k = self._split_heads(k)\n",
"        v = self._split_heads(v)\n",
"\n",
"        # Append to the cache\n",
"        k_cache.append(k)\n",
"        v_cache.append(v)\n",
"\n",
"        # Concatenate all K and V\n",
"        k_all = torch.cat(k_cache, dim=2)\n",
"        v_all = torch.cat(v_cache, dim=2)\n",
"\n",
"        # Attention (the causal mask is automatically satisfied: we only attend from the last position)\n",
"        scores = torch.matmul(q, k_all.transpose(-2, -1)) / (self.head_dim ** 0.5)\n",
"\n",
"        # softmax\n",
"        attn_weights = F.softmax(scores, dim=-1)\n",
"\n",
"        # Attention output\n",
"        attn_output = torch.matmul(attn_weights, v_all)\n",
"\n",
"        # Merge heads\n",
"        output = self._merge_heads(attn_output)\n",
"        output = self.out_proj(output)\n",
"\n",
"        return output, (k_cache, v_cache)"
] },
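{ "cell_type": "markdown", "source": [
"The cross-chunk causal mask built in `prefill_chunked` is easiest to see on a toy case. The cell below is a small sketch (made-up sizes) that prints the `[chunk_len, total_kv_len]` boolean mask for the second chunk of a 6-token sequence with `chunk_size=3`."
], "metadata": {} },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Toy visualization of the per-chunk causal mask from prefill_chunked.\n",
"import torch\n",
"\n",
"start, chunk_len, total_kv_len = 3, 3, 6       # second chunk of a 6-token sequence\n",
"q_positions = torch.arange(chunk_len).unsqueeze(1) + start  # global query positions 3,4,5\n",
"kv_positions = torch.arange(total_kv_len).unsqueeze(0)      # key positions 0..5\n",
"mask = q_positions >= kv_positions             # True = visible\n",
"print(mask.int())\n",
"# Row i is query position 3+i: it sees all of chunk 0, plus keys up to itself in chunk 1."
] },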
{ "cell_type": "markdown", "source": [
"Build a streaming wrapper class:"
], "metadata": { "id": "Q2xL994xMEAt" } },
{ "cell_type": "code", "source": [
"class StreamingLLMAttention:\n",
"    \"\"\"\n",
"    Streaming LLM attention layer\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, d_model: int, n_heads: int, chunk_size: int = 512):\n",
"        self.attn = CausalChunkedPrefill(d_model, n_heads, chunk_size)\n",
"        self.chunk_size = chunk_size\n",
"        self.kv_cache = None\n",
"\n",
"    def prefill(self, prompt: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"\n",
"        Prefill phase: process the user's prompt.\n",
"\n",
"        Args:\n",
"            prompt: user prompt [batch, prompt_len, d_model]\n",
"\n",
"        Returns:\n",
"            attention output\n",
"        \"\"\"\n",
"        output, self.kv_cache = self.attn.prefill_chunked(prompt)\n",
"        return output\n",
"\n",
"    def generate_token(self, token_emb: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"\n",
"        Generate one token.\n",
"\n",
"        Args:\n",
"            token_emb: embedding of the current token [batch, 1, d_model]\n",
"\n",
"        Returns:\n",
"            output for the current token\n",
"        \"\"\"\n",
"        if self.kv_cache is None:\n",
"            raise ValueError(\"Call prefill first to initialize the KV cache\")\n",
"\n",
"        output, self.kv_cache = self.attn.decode_step(token_emb, self.kv_cache)\n",
"        return output\n",
"\n",
"    def reset_cache(self):\n",
"        \"\"\"Reset the KV cache\"\"\"\n",
"        self.kv_cache = None"
], "metadata": { "id": "lAm2uhpkMDFk" }, "execution_count": 2, "outputs": [] },
{ "cell_type": "markdown", "source": [
"Design the test functions:"
], "metadata": { "id": "7MqqGIB_DYt8" } },
{ "cell_type": "code", "source": [
"def test_causal_chunked_prefill():\n",
"    \"\"\"Test streaming + causal chunked prefill\"\"\"\n",
"\n",
"    torch.manual_seed(42)\n",
"\n",
"    print(\"=\" * 70)\n",
"    print(\"Streaming + causal chunked prefill test\")\n",
"    print(\"=\" * 70)\n",
"\n",
"    # Test configuration\n",
"    batch_size = 2\n",
"    seq_len = 9   # short sequence for testing\n",
"    d_model = 64\n",
"    n_heads = 4\n",
"    chunk_size = 3\n",
"\n",
"    print(f\"Config:\")\n",
"    print(f\"  batch_size={batch_size}, seq_len={seq_len}\")\n",
"    print(f\"  d_model={d_model}, n_heads={n_heads}\")\n",
"    print(f\"  chunk_size={chunk_size}\")\n",
"\n",
"    # Create the model\n",
"    model = CausalChunkedPrefill(\n",
"        d_model=d_model,\n",
"        n_heads=n_heads,\n",
"        chunk_size=chunk_size\n",
"    )\n",
"\n",
"    # Random input\n",
"    x = torch.randn(batch_size, seq_len, d_model)\n",
"\n",
"    print(f\"\\nInput shape: {x.shape}\")\n",
"\n",
"    print(\"\\n1. Computing standard attention (no chunking)...\")\n",
"    with torch.no_grad():\n",
"        output_standard = model.prefill_standard(x)\n",
"    print(f\"   Standard attention output shape: {output_standard.shape}\")\n",
"\n",
"    print(\"\\n2. Computing chunked prefill...\")\n",
"    with torch.no_grad():\n",
"        output_chunked, kv_cache = model.prefill_chunked(x)\n",
"    print(f\"   Chunked prefill output shape: {output_chunked.shape}\")\n",
"    print(\"\\n3. Comparing the outputs of the two methods...\")\n",
"    diff = torch.abs(output_standard - output_chunked)\n",
"    max_diff = diff.max().item()\n",
"    mean_diff = diff.mean().item()\n",
"\n",
"    print(f\"   Max difference: {max_diff:.10f}\")\n",
"    print(f\"   Mean difference: {mean_diff:.10f}\")\n",
"\n",
"    tolerance = 1e-5\n",
"    if max_diff < tolerance:\n",
"        print(f\"   ✓ Test passed! Difference within tolerance (< {tolerance})\")\n",
"    else:\n",
"        print(f\"   ✗ Test failed! Difference exceeds tolerance\")\n",
"        print(f\"\\n   Debug info (first 3 positions of the first sample):\")\n",
"        for pos in range(min(3, seq_len)):\n",
"            print(f\"   Position {pos}:\")\n",
"            print(f\"     standard: {output_standard[0, pos, :5].detach().numpy().round(4)}\")\n",
"            print(f\"     chunked:  {output_chunked[0, pos, :5].detach().numpy().round(4)}\")\n",
"            print(f\"     diff:     {diff[0, pos, :5].detach().numpy().round(8)}\")\n",
"\n",
"    print(\"\\n4. Testing the decode step (generating a follow-up token)...\")\n",
"\n",
"    # Generate a test token\n",
"    test_token = torch.randn(batch_size, 1, d_model)\n",
"\n",
"    # Decode using the KV cache\n",
"    output_decode, updated_kv_cache = model.decode_step(test_token, kv_cache)\n",
"\n",
"    print(f\"   Decode output shape: {output_decode.shape}\")\n",
"    print(f\"   Updated KV cache length: {sum(k.shape[2] for k in updated_kv_cache[0])}\")\n",
"\n",
"    # Verify decode correctness\n",
"    x_with_new = torch.cat([x, test_token], dim=1)\n",
"\n",
"    # Full result via standard attention\n",
"    output_full_new = model.prefill_standard(x_with_new)\n",
"\n",
"    # Take the output at the last position (the new token)\n",
"    output_full_last = output_full_new[:, -1:, :]\n",
"\n",
"    # Compare\n",
"    diff_decode = torch.abs(output_decode - output_full_last).max().item()\n",
"    print(f\"   Decode vs standard difference: {diff_decode:.10f}\")\n",
"\n",
"    if diff_decode < tolerance:\n",
"        print(f\"   ✓ Decode step correct\")\n",
"    else:\n",
"        print(f\"   ✗ Decode step incorrect\")\n",
"\n",
"    return max_diff < tolerance and diff_decode < tolerance\n",
"\n",
"\n",
"def test_streaming_api():\n",
"    \"\"\"Test the streaming API\"\"\"\n",
"    torch.manual_seed(42)\n",
"\n",
"    print(\"\\n\" + \"=\" * 70)\n",
"    print(\"Streaming API test\")\n",
"    print(\"=\" * 70)\n",
"\n",
"    # Create the streaming LLM attention\n",
"    stream_attn = StreamingLLMAttention(\n",
"        d_model=64,\n",
"        n_heads=4,\n",
"        chunk_size=3\n",
"    )\n",
"\n",
"    # Simulate a prompt\n",
"    prompt_len = 9\n",
"    prompt = torch.randn(1, prompt_len, 64)\n",
"\n",
"    print(f\"1. Prefill phase: processing a {prompt_len}-token prompt\")\n",
"    output = stream_attn.prefill(prompt)\n",
"    print(f\"   Output shape: {output.shape}\")\n",
"\n",
"    print(f\"\\n2. Generation phase: generating 3 tokens\")\n",
"    for i in range(3):\n",
"        # Simulate a token embedding (in practice it comes from the embedding layer)\n",
"        token_emb = torch.randn(1, 1, 64)\n",
"\n",
"        output_token = stream_attn.generate_token(token_emb)\n",
"        print(f\"   Generated token {i+1}: output shape {output_token.shape}\")\n",
"\n",
"    print(f\"\\n3. Resetting the cache\")\n",
"    stream_attn.reset_cache()\n",
"    print(f\"   KV cache reset\")\n",
"\n",
"\n",
"def benchmark_performance():\n",
"    \"\"\"Performance benchmark\"\"\"\n",
"\n",
"    import time\n",
"\n",
"    torch.manual_seed(42)\n",
"\n",
"    print(\"\\n\" + \"=\" * 70)\n",
"    print(\"Performance benchmark\")\n",
"    print(\"=\" * 70)\n",
"\n",
"    # Long-sequence test\n",
"    d_model = 1024\n",
"    n_heads = 16\n",
"    chunk_size = 512\n",
"\n",
"    # Create the model\n",
"    model = CausalChunkedPrefill(\n",
"        d_model=d_model,\n",
"        n_heads=n_heads,\n",
"        chunk_size=chunk_size\n",
"    )\n",
"\n",
"    # Test different sequence lengths\n",
"    test_cases = [\n",
"        {\"seq_len\": 512, \"desc\": \"short sequence (one chunk)\"},\n",
"        {\"seq_len\": 2048, \"desc\": \"medium sequence (4 chunks)\"},\n",
"        {\"seq_len\": 8192, \"desc\": \"long sequence (16 chunks)\"},\n",
"    ]\n",
"\n",
"    for test_case in test_cases:\n",
"        seq_len = test_case[\"seq_len\"]\n",
"\n",
"        print(f\"\\nTest: {test_case['desc']} (seq_len={seq_len})\")\n",
"\n",
"        x = torch.randn(1, seq_len, d_model)\n",
"\n",
"        # Standard attention (may OOM)\n",
"        try:\n",
"            if torch.cuda.is_available():\n",
"                torch.cuda.empty_cache()\n",
"\n",
"            start = time.time()\n",
"            with torch.no_grad():\n",
"                output_std = model.prefill_standard(x)\n",
"            time_std = time.time() - start\n",
"\n",
"            mem_std = torch.cuda.max_memory_allocated() if torch.cuda.is_available() else 0\n",
"\n",
"            if torch.cuda.is_available():\n",
"                print(f\"  Standard attention: {time_std:.3f}s, memory: {mem_std/1024**2:.1f}MB\")\n",
"            else:\n",
"                print(f\"  Standard attention: {time_std:.3f}s\")\n",
"        except RuntimeError as e:\n",
"            print(f\"  Standard attention OOM: {e}\")\n",
"            time_std = float('inf')\n",
"            mem_std = float('inf')\n",
"\n",
"        # Chunked attention\n",
"        try:\n",
"            if torch.cuda.is_available():\n",
"                torch.cuda.empty_cache()\n",
"\n",
"            start = time.time()\n",
"            with torch.no_grad():\n",
"                output_chunk, _ = model.prefill_chunked(x)\n",
"            time_chunk = time.time() - start\n",
"\n",
"            mem_chunk = torch.cuda.max_memory_allocated() if torch.cuda.is_available() else 0\n",
"\n",
"            if torch.cuda.is_available():\n",
"                print(f\"  Chunked attention: {time_chunk:.3f}s, memory: {mem_chunk/1024**2:.1f}MB\")\n",
"            else:\n",
"                print(f\"  Chunked attention: {time_chunk:.3f}s\")\n",
"\n",
"            if time_std != float('inf'):\n",
"                speedup = time_std / time_chunk if time_chunk > 0 else 0\n",
"                mem_reduction = (mem_std - mem_chunk) / mem_std * 100 if mem_std > 0 else 0\n",
"                print(f\"  Speedup: {speedup:.1f}x, memory reduction: {mem_reduction:.1f}%\")\n",
"        except RuntimeError as e:\n",
"            print(f\"  Chunked attention OOM: {e}\")\n",
"\n",
"\n",
"print(\"Streaming + causal chunked prefill implementation\")\n",
"\n",
"# Run the basic test\n",
"test_passed = test_causal_chunked_prefill()\n",
"\n",
"if test_passed:\n",
"    # Test the streaming API\n",
"    test_streaming_api()\n",
"\n",
"    # Performance test\n",
"    if torch.cuda.is_available():\n",
"        benchmark_performance()\n",
"    else:\n",
"        print(\"\\nNote: CUDA unavailable, skipping the performance test\")"
], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "WxQEF9bIH-PD", "outputId": "45dcb53f-5edd-459e-cba1-66bf750f9d96" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
"Streaming + causal chunked prefill implementation\n",
"======================================================================\n",
"Streaming + causal chunked prefill test\n",
"======================================================================\n",
"Config:\n",
"  batch_size=2, seq_len=9\n",
"  d_model=64, n_heads=4\n",
"  chunk_size=3\n",
"\n",
"Input shape: torch.Size([2, 9, 64])\n",
"\n",
"1. Computing standard attention (no chunking)...\n",
"   Standard attention output shape: torch.Size([2, 9, 64])\n",
"\n",
"2. Computing chunked prefill...\n",
"Chunked prefill: seq_len=9, chunk_size=3, n_chunks=3\n",
"  Processing chunk 1/3: positions 0:3, KV cache length=3\n",
"  Processing chunk 2/3: positions 3:6, KV cache length=6\n",
"  Processing chunk 3/3: positions 6:9, KV cache length=9\n",
"   Chunked prefill output shape: torch.Size([2, 9, 64])\n",
"\n",
"3. Comparing the outputs of the two methods...\n",
"   Max difference: 0.0000001192\n",
"   Mean difference: 0.0000000094\n",
"   ✓ Test passed! Difference within tolerance (< 1e-05)\n",
"\n",
"4. Testing the decode step (generating a follow-up token)...\n",
"   Decode output shape: torch.Size([2, 1, 64])\n",
"   Updated KV cache length: 10\n",
"   Decode vs standard difference: 0.0000000745\n",
"   ✓ Decode step correct\n",
"\n",
"======================================================================\n",
"Streaming API test\n",
"======================================================================\n",
"1. Prefill phase: processing a 9-token prompt\n",
"Chunked prefill: seq_len=9, chunk_size=3, n_chunks=3\n",
"  Processing chunk 1/3: positions 0:3, KV cache length=3\n",
"  Processing chunk 2/3: positions 3:6, KV cache length=6\n",
"  Processing chunk 3/3: positions 6:9, KV cache length=9\n",
"   Output shape: torch.Size([1, 9, 64])\n",
"\n",
"2. Generation phase: generating 3 tokens\n",
"   Generated token 1: output shape torch.Size([1, 1, 64])\n",
"   Generated token 2: output shape torch.Size([1, 1, 64])\n",
"   Generated token 3: output shape torch.Size([1, 1, 64])\n",
"\n",
"3. Resetting the cache\n",
"   KV cache reset\n",
"\n",
"Note: CUDA unavailable, skipping the performance test\n"
] } ] },
{ "cell_type": "markdown", "source": [
"## 2 Flash Decoding Computation Demo"
], "metadata": { "id": "JpZiXMdIDeg4" } },
{ "cell_type": "markdown", "source": [
"### 2.1 Method 1: storing the max and block_sum_exp values"
], "metadata": { "id": "Reb6W3XcxBtP" } },
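{ "cell_type": "markdown", "source": [
"The core of this method is the online-softmax update: keep a running (max, sum of exponentials, weighted sum) per query, and rescale the accumulators whenever the max grows. The cell below is a toy 1-D sketch of that update (arbitrary values, ad-hoc names), mirroring the rule used in the class that follows."
], "metadata": {} },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Toy demo: merge two blocks' (max, sum_exp, weighted sum) statistics.\n",
"import torch\n",
"\n",
"scores = torch.tensor([1.0, 3.0, 0.5, 2.5])           # scores for one query\n",
"values = torch.tensor([[1.0], [2.0], [3.0], [4.0]])   # one value per key\n",
"\n",
"# Reference: ordinary softmax attention.\n",
"ref = torch.softmax(scores, dim=0) @ values\n",
"\n",
"# Block-wise: process scores[:2] then scores[2:], keeping running (m, l, num).\n",
"m = torch.tensor(float('-inf'))   # running max\n",
"l = torch.tensor(0.0)             # running sum of exponentials\n",
"num = torch.zeros(1)              # running weighted sum (numerator)\n",
"for s_blk, v_blk in ((scores[:2], values[:2]), (scores[2:], values[2:])):\n",
"    m_new = torch.maximum(m, s_blk.max())\n",
"    scale = torch.exp(m - m_new)              # rescale old accumulators (0 on the first block)\n",
"    num = num * scale + torch.exp(s_blk - m_new) @ v_blk\n",
"    l = l * scale + torch.exp(s_blk - m_new).sum()\n",
"    m = m_new\n",
"\n",
"print(ref.item(), (num / l).item())  # the two results agree"
] },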
{ "cell_type": "code", "source": [
"import torch\n",
"import torch.nn.functional as F\n",
"import math\n",
"import time\n",
"\n",
"\n",
"class FlashDecodingDemo:\n",
"    \"\"\"Flash-Decoding attention computation demo\"\"\"\n",
"\n",
"    def __init__(self, d_model: int = 64, num_heads: int = 8):\n",
"        self.d_model = d_model\n",
"        self.num_heads = num_heads\n",
"        self.head_dim = d_model // num_heads\n",
"\n",
"    def traditional_attention(self, q, k, v):\n",
"        \"\"\"Traditional contiguous attention computation\"\"\"\n",
"        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)\n",
"        attention_weights = F.softmax(scores, dim=-1)\n",
"        output = torch.matmul(attention_weights, v)\n",
"        return output, attention_weights\n",
"\n",
"    def flash_decoding_attention(self, q, k, v, block_size=32, tiling_mode='distributed'):\n",
"        if tiling_mode == 'distributed':\n",
"            return self.flash_decoding_distributed_tiling(q, k, v,\n",
"                                                          tile_size_kv=block_size)\n",
"        else:\n",
"            return self.flash_decoding_without_cp(q, k, v, block_size)\n",
"\n",
"    def flash_decoding_without_cp(self, q, k, v, block_size=32):\n",
"        \"\"\"\n",
"        Block-wise flash attention with a single running accumulator.\n",
"        \"\"\"\n",
"        batch_size, num_heads, seq_len_q, _ = q.shape\n",
"        seq_len_kv = k.shape[2]\n",
"        num_blocks = (seq_len_kv + block_size - 1) // block_size\n",
"\n",
"        # Initialize the running accumulators\n",
"        # Accumulated weighted sum\n",
"        numerator = torch.zeros(batch_size, num_heads, seq_len_q, self.head_dim,\n",
"                                device=q.device, dtype=q.dtype)\n",
"        # Accumulated normalization factor\n",
"        d_prime = torch.zeros(batch_size, num_heads, seq_len_q, 1,\n",
"                              device=q.device, dtype=q.dtype)\n",
"\n",
"        # Running max for numerical stability (initialized to -inf)\n",
"        global_max = torch.full((batch_size, num_heads, seq_len_q, 1),\n",
"                                -float('inf'),\n",
"                                device=q.device, dtype=q.dtype)\n",
"\n",
"        # Process block by block\n",
"        for block_idx in range(num_blocks):\n",
"            start_idx = block_idx * block_size\n",
"            end_idx = min(start_idx + block_size, seq_len_kv)\n",
"\n",
"            k_block = k[:, :, start_idx:end_idx, :]\n",
"            v_block = v[:, :, start_idx:end_idx, :]\n",
"\n",
"            # Attention scores for the current block\n",
"            scores_block = torch.matmul(q, k_block.transpose(-2, -1)) / math.sqrt(self.head_dim)\n",
"\n",
"            # Max of the current block\n",
"            block_max = scores_block.max(dim=-1, keepdim=True).values\n",
"\n",
"            # Update the running max\n",
"            # (per query position, across all blocks seen so far)\n",
"            new_global_max = torch.maximum(global_max, block_max)\n",
"\n",
"            # When the running max grows, previously accumulated quantities\n",
"            # must be rescaled onto the new scale\n",
"            if block_idx > 0:\n",
"                adjustment_factor = torch.exp(global_max - new_global_max)\n",
"                numerator = numerator * adjustment_factor\n",
"                d_prime = d_prime * adjustment_factor\n",
"\n",
"            # Commit the new running max\n",
"            global_max = new_global_max\n",
"\n",
"            # Exponentiated scores for the current block (subtract the running max for stability)\n",
"            exp_scores = torch.exp(scores_block - global_max)\n",
"            block_sum_exp = exp_scores.sum(dim=-1, keepdim=True)\n",
"\n",
"            # Accumulate the weighted sum and the normalizer\n",
"            numerator = numerator + torch.matmul(exp_scores, v_block)\n",
"            d_prime = d_prime + block_sum_exp\n",
"\n",
"        # Final normalization\n",
"        final_output = numerator / d_prime\n",
"\n",
"        # For verification, also compute the full attention weights\n",
"        full_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)\n",
"        full_attention_weights = F.softmax(full_scores, dim=-1)\n",
"\n",
"        return final_output, full_attention_weights\n",
"\n",
"    def flash_decoding_distributed_tiling(self, q, k, v,\n",
"                                          tile_size_kv: int = 256,\n",
"                                          num_streams: int = 5):\n",
"        \"\"\"\n",
"        Explicitly simulates several parallel compute streams with arrays;\n",
"        each stream keeps its own independent accumulators.\n",
"        \"\"\"\n",
"        batch_size, num_heads, seq_len_q, head_dim = q.shape\n",
"        seq_len_kv = k.shape[2]\n",
"        num_tiles = (seq_len_kv + tile_size_kv - 1) // tile_size_kv\n",
"\n",
"        print(f\"\\nDistributed-array implementation: {num_streams} compute streams\")\n",
"        print(f\"Each stream keeps its own O, M, L arrays\")\n",
"\n",
"        # Create per-stream arrays: each stream has its own (O, M, L)\n",
"        stream_O = []  # weighted-sum arrays\n",
"        stream_M = []  # max arrays\n",
"        stream_L = []  # exp-sum arrays\n",
"\n",
"        for stream_id in range(num_streams):\n",
"            # Each stream initializes its own accumulators\n",
"            O_stream = torch.zeros_like(q)\n",
"            M_stream = torch.full((batch_size, num_heads, seq_len_q, 1),\n",
"                                  -float('inf'), device=q.device, dtype=q.dtype)\n",
"            L_stream = torch.zeros_like(M_stream)\n",
"\n",
"            stream_O.append(O_stream)\n",
"            stream_M.append(M_stream)\n",
"            stream_L.append(L_stream)\n",
"\n",
"        # Simulate the streams processing tiles in parallel\n",
"        print(f\"Processing {num_tiles} tiles in parallel...\")\n",
"\n",
"        for tile_idx in range(num_tiles):\n",
"            # Pick the stream that handles this tile\n",
"            stream_id = tile_idx % num_streams\n",
"\n",
"            # Slice out the current tile\n",
"            start_idx = tile_idx * tile_size_kv\n",
"            end_idx = min(start_idx + tile_size_kv, seq_len_kv)\n",
"\n",
"            k_tile = k[:, :, start_idx:end_idx, :]\n",
"            v_tile = v[:, :, start_idx:end_idx, :]\n",
"\n",
"            # The current stream works only on its own arrays\n",
"            O_curr = stream_O[stream_id]\n",
"            M_curr = stream_M[stream_id]\n",
"            L_curr = stream_L[stream_id]\n",
"\n",
"            # Compute the current tile\n",
"            S_tile = torch.matmul(q, k_tile.transpose(-2, -1)) / math.sqrt(head_dim)\n",
"            m_tile = S_tile.max(dim=-1, keepdim=True).values\n",
"\n",
"            # Update this stream's statistics\n",
"            new_M = torch.maximum(M_curr, m_tile)\n",
"\n",
"            if not torch.allclose(M_curr, new_M):\n",
"                scale = torch.exp(M_curr - new_M)\n",
"                O_curr = O_curr * scale\n",
"                L_curr = L_curr * scale\n",
"\n",
"            exp_tile = torch.exp(S_tile - new_M)\n",
"            l_tile = exp_tile.sum(dim=-1, keepdim=True)\n",
"\n",
"            # Write back this stream's arrays\n",
"            stream_O[stream_id] = O_curr + torch.matmul(exp_tile, v_tile)\n",
"            stream_L[stream_id] = L_curr + l_tile\n",
"            stream_M[stream_id] = new_M\n",
"\n",
"        # Reduce the results of all streams\n",
"        final_output = self.reduce_stream_arrays(stream_O, stream_M, stream_L)\n",
"\n",
"        # For verification, also compute the full attention weights\n",
"        full_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)\n",
"        full_attention_weights = F.softmax(full_scores, dim=-1)\n",
"\n",
"        return final_output, full_attention_weights\n",
"\n",
"    def reduce_stream_arrays(self, stream_O, stream_M, stream_L):\n",
"        \"\"\"Reduce the array results from all streams\"\"\"\n",
"        num_streams = len(stream_O)\n",
"        if num_streams == 0:\n",
"            return None  # nothing to reduce\n",
"\n",
"        # Tree reduction: in each round, adjacent streams are merged pairwise\n",
"        current_O = stream_O.copy()\n",
"        current_M = stream_M.copy()\n",
"        current_L = stream_L.copy()\n",
"\n",
"        remaining = num_streams\n",
"\n",
"        while remaining > 1:\n",
"            next_O = []\n",
"            next_M = []\n",
"            next_L = []\n",
"\n",
"            # Merge pairs of streams\n",
"            for i in range(0, remaining, 2):\n",
"                if i + 1 < remaining:\n",
"                    # Merge stream i and stream i+1\n",
"                    O1, M1, L1 = current_O[i], current_M[i], current_L[i]\n",
"                    O2, M2, L2 = current_O[i+1], current_M[i+1], current_L[i+1]\n",
"\n",
"                    # Combined max\n",
"                    new_M = torch.maximum(M1, M2)\n",
"\n",
"                    # Rescale the first stream\n",
"                    if not torch.allclose(M1, new_M):\n",
"                        scale1 = torch.exp(M1 - new_M)\n",
"                        O1 = O1 * scale1\n",
"                        L1 = L1 * scale1\n",
"\n",
"                    # Rescale the second stream\n",
"                    # (guard: if both maxima are still -inf, the stream never received a\n",
"                    # tile and its contribution is zero; exp(-inf - -inf) would give NaN)\n",
"                    scale2 = torch.where(torch.isneginf(M2) & torch.isneginf(new_M),\n",
"                                         torch.zeros_like(M2),\n",
"                                         torch.exp(M2 - new_M))\n",
"                    O2 = O2 * scale2\n",
"                    L2 = L2 * scale2\n",
"\n",
"                    # Merged results\n",
"                    merged_O = O1 + O2\n",
"                    merged_L = L1 + L2\n",
"\n",
"                    next_O.append(merged_O)\n",
"                    next_M.append(new_M)\n",
"                    next_L.append(merged_L)\n",
"                else:\n",
"                    # Odd count: the last stream passes to the next round unchanged\n",
"                    next_O.append(current_O[i])\n",
"                    next_M.append(current_M[i])\n",
"                    next_L.append(current_L[i])\n",
"\n",
"            current_O = next_O\n",
"            current_M = next_M\n",
"            current_L = next_L\n",
"            remaining = len(current_O)\n",
"\n",
"        # Final normalization\n",
"        final_output = current_O[0] / current_L[0]\n",
"        print(f\"Reduction done, final output shape: {final_output.shape}\")\n",
"        return final_output\n",
"\n",
"    def flash_decoding_attention_simple(self, q, k, v, block_size=32):\n",
"        \"\"\"\n",
"        Simplified Flash-Decoding implementation with two passes.\n",
"        Stores each block's max and block_sum_exp values.\n",
"        Easy to follow.\n",
"        \"\"\"\n",
"        batch_size, num_heads, seq_len_q, _ = q.shape\n",
"        seq_len_kv = k.shape[2]\n",
"        num_blocks = (seq_len_kv + block_size - 1) // block_size\n",
"\n",
"        # Intermediate results per block\n",
"        block_outputs = []\n",
"        block_max_vals = []\n",
"        block_sum_exps = []\n",
"\n",
"        # Pass 1: compute each block's local result\n",
"        for block_idx in range(num_blocks):\n",
"            start_idx = block_idx * block_size\n",
"            end_idx = min(start_idx + block_size, seq_len_kv)\n",
"\n",
"            k_block = k[:, :, start_idx:end_idx, :]\n",
"            v_block = v[:, :, start_idx:end_idx, :]\n",
"\n",
"            # Attention scores for the current block\n",
"            scores_block = torch.matmul(q, k_block.transpose(-2, -1)) / math.sqrt(self.head_dim)\n",
"            block_max = scores_block.max(dim=-1, keepdim=True).values\n",
"            exp_scores = torch.exp(scores_block - block_max)\n",
"            block_sum_exp = exp_scores.sum(dim=-1, keepdim=True)\n",
"\n",
"            # Store the intermediate results\n",
"            block_outputs.append(torch.matmul(exp_scores, v_block))\n",
"            block_max_vals.append(block_max)\n",
"            block_sum_exps.append(block_sum_exp)\n",
"\n",
"        # Pass 2: merge all blocks' results\n",
"        # Find the global max\n",
"        all_max_vals = torch.stack(block_max_vals, dim=0)  # [num_blocks, ...]\n",
"        global_max = all_max_vals.max(dim=0).values  # max over blocks at each query position\n",
"\n",
"        # Merge the normalization factors\n",
"        total_sum_exp = torch.zeros_like(block_sum_exps[0])\n",
"        for i in range(num_blocks):\n",
"            total_sum_exp += block_sum_exps[i] * torch.exp(block_max_vals[i] - global_max)\n",
"\n",
"        # Merge the outputs\n",
"        final_output = torch.zeros_like(block_outputs[0])\n",
"        for i in range(num_blocks):\n",
"            # Rescale each block's contribution to the global scale\n",
"            weight = torch.exp(block_max_vals[i] - global_max)\n",
"            final_output += block_outputs[i] * weight\n",
"\n",
"        # Final normalization\n",
"        final_output = final_output / total_sum_exp\n",
"\n",
"        # Full attention weights for verification\n",
"        full_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)\n",
"        full_attention_weights = F.softmax(full_scores, dim=-1)\n",
"\n",
"        return final_output, full_attention_weights\n",
"\n",
"    def verify_with_tolerance(self, batch_size=2, seq_len_q=1, seq_len_kv=1024):\n",
"        \"\"\"Stricter verification with tolerance checks\"\"\"\n",
"\n",
"        # Generate random test data\n",
"        torch.manual_seed(42)  # fixed seed for reproducibility\n",
"        q = torch.randn(batch_size, self.num_heads, seq_len_q, self.head_dim)\n",
"        k = torch.randn(batch_size, self.num_heads, seq_len_kv, self.head_dim)\n",
"        v = torch.randn(batch_size, self.num_heads, seq_len_kv, self.head_dim)\n",
"\n",
"        print(\"=\" * 70)\n",
"        print(\"Flash-Decoding correctness verification\")\n",
"        print(\"=\" * 70)\n",
"\n",
"        # Traditional method\n",
"        traditional_output, traditional_weights = self.traditional_attention(q, k, v)\n",
"\n",
"        # Try different block sizes\n",
"        block_sizes = [16, 32, 64, 128, 256]\n",
"\n",
"        for block_size in block_sizes:\n",
"            print(f\"\\nUsing block_size={block_size}:\")\n",
"\n",
"            # General method\n",
"            flash_output1, flash_weights1 = self.flash_decoding_attention(\n",
"                q, k, v, block_size\n",
"            )\n",
"\n",
"            # Simplified method\n",
"            flash_output2, flash_weights2 = self.flash_decoding_attention_simple(\n",
"                q, k, v, block_size\n",
"            )\n",
"\n",
"            # Compare the results\n",
"            diff1 = torch.abs(traditional_output - flash_output1).max().item()\n",
"            diff2 = torch.abs(traditional_output - flash_output2).max().item()\n",
"\n",
"            # Relative errors\n",
"            rel_error1 = diff1 / (torch.abs(traditional_output).max().item() + 1e-10)\n",
"            rel_error2 = diff2 / (torch.abs(traditional_output).max().item() + 1e-10)\n",
"\n",
"            # Check against the tolerance\n",
"            tolerance = 1e-4\n",
"            is_correct1 = diff1 < tolerance\n",
"            is_correct2 = diff2 < tolerance\n",
"\n",
"            print(f\"  General method    - max abs error: {diff1:.2e}, rel error: {rel_error1:.2e}, correct: {is_correct1}\")\n",
"            print(f\"  Simplified method - max abs error: {diff2:.2e}, rel error: {rel_error2:.2e}, correct: {is_correct2}\")\n",
"\n",
"            # If both methods are correct, also compare them against each other\n",
"            if is_correct1 and is_correct2:\n",
"                method_diff = torch.abs(flash_output1 - flash_output2).max().item()\n",
"                print(f\"  Difference between the two methods: {method_diff:.2e}\")\n",
"\n",
"        return True\n",
"\n",
"    def analyze_numerical_stability(self):\n",
"        \"\"\"Numerical stability analysis\"\"\"\n",
"\n",
"        print(\"\\n\" + \"=\" * 70)\n",
"        print(\"Numerical stability analysis\")\n",
"        print(\"=\" * 70)\n",
"\n",
"        # Test different value ranges\n",
"        test_cases = [\n",
"            (\"Small value range\", (-1.0, 1.0)),\n",
"            (\"Medium value range\", (-10.0, 10.0)),\n",
"            (\"Large value range\", (-50.0, 50.0)),\n",
"        ]\n",
"\n",
"        for name, (min_val, max_val) in test_cases:\n",
"            print(f\"\\n{name} [{min_val}, {max_val}]:\")\n",
"\n",
"            # Generate test data in the given range\n",
"            q = torch.rand(1, self.num_heads, 1, self.head_dim) * (max_val - min_val) + min_val\n",
"            k = torch.rand(1, self.num_heads, 1024, self.head_dim) * (max_val - min_val) + min_val\n",
"            v = torch.rand(1, self.num_heads, 1024, self.head_dim) * (max_val - min_val) + min_val\n",
"\n",
"            # Traditional method\n",
"            traditional_output, _ = self.traditional_attention(q, k, v)\n",
"\n",
"            # Flash-Decoding method\n",
"            flash_output, _ = self.flash_decoding_attention(q, k, v, block_size=64)\n",
"\n",
"            # Error\n",
"            diff = torch.abs(traditional_output - flash_output).max().item()\n",
"\n",
"            # Check for NaN or Inf\n",
"            has_nan = torch.isnan(flash_output).any().item()\n",
"            has_inf = torch.isinf(flash_output).any().item()\n",
"\n",
"            print(f\"  Max abs error: {diff:.2e}\")\n",
"            print(f\"  Contains NaN: {has_nan}, contains Inf: {has_inf}\")\n",
"\n",
"        return True\n"
], "metadata": { "id": "VHRh-CuJ-RPy" }, "execution_count": 4, "outputs": [] },
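{ "cell_type": "markdown", "source": [
"Why all this bookkeeping around a running max? Without subtracting it, `exp` overflows float32 for even moderately large scores. A short illustration with toy values:"
], "metadata": {} },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Naive softmax overflows in float32; subtracting the max keeps it finite.\n",
"import torch\n",
"\n",
"scores = torch.tensor([100.0, 101.0, 102.0])\n",
"print(torch.exp(scores))                      # inf, inf, inf -> naive softmax would be NaN\n",
"\n",
"m = scores.max()\n",
"stable = torch.exp(scores - m) / torch.exp(scores - m).sum()\n",
"print(stable)                                 # well-defined probabilities"
] },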
{ "cell_type": "code", "source": [
"# Create the demo instance\n",
"demo = FlashDecodingDemo(d_model=512, num_heads=8)\n",
"\n",
"# 1. Verify correctness (stricter check)\n",
"demo.verify_with_tolerance(\n",
"    batch_size=2,\n",
"    seq_len_q=1,\n",
"    seq_len_kv=1024\n",
")\n",
"\n",
"# 2. Numerical stability analysis\n",
"demo.analyze_numerical_stability()\n",
"\n",
"# 3. Performance comparison demo\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"Performance comparison (small batch size, long sequences)\")\n",
"print(\"=\" * 70)\n",
"\n",
"# Simulate long-sequence inference\n",
"seq_lengths = [1024, 4096, 16384, 32768]\n",
"\n",
"for seq_len in seq_lengths:\n",
"    print(f\"\\nSequence length: {seq_len}\")\n",
"\n",
"    # Generate test data\n",
"    q = torch.randn(1, 8, 1, 64)  # batch=1, single-token query\n",
"    k = torch.randn(1, 8, seq_len, 64)\n",
"    v = torch.randn(1, 8, seq_len, 64)\n",
"\n",
"    # Traditional method timing\n",
"    start = time.time()\n",
"    traditional_output, _ = demo.traditional_attention(q, k, v)\n",
"    if torch.cuda.is_available():\n",
"        torch.cuda.synchronize()\n",
"    traditional_time = (time.time() - start) * 1000\n",
"\n",
"    # Flash-Decoding timing\n",
"    start = time.time()\n",
"    flash_output, _ = demo.flash_decoding_attention(q, k, v, block_size=256)\n",
"    if torch.cuda.is_available():\n",
"        torch.cuda.synchronize()\n",
"    flash_time = (time.time() - start) * 1000\n",
"\n",
"    # Check consistency\n",
"    diff = torch.abs(traditional_output - flash_output).max().item()\n",
"\n",
"    print(f\"  Traditional: {traditional_time:.2f}ms\")\n",
"    print(f\"  Flash-Decoding: {flash_time:.2f}ms\")\n",
"    print(f\"  Speedup: {traditional_time / flash_time:.2f}x\")\n",
"    print(f\"  Output difference: {diff:.2e}\")\n",
"    print(f\"  Results match: {diff < 1e-4}\")"
], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "vcjqTTvdMqX7", "outputId": "684e2a68-8213-4bac-f607-0ca6671ce13c" }, "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
"======================================================================\n",
"Flash-Decoding correctness verification\n",
"======================================================================\n",
"\n",
"Using block_size=16:\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 64 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([2, 8, 1, 64])\n",
"  General method    - max abs error: 1.27e-07, rel error: 8.22e-07, correct: True\n",
"  Simplified method - max abs error: 1.49e-07, rel error: 9.67e-07, correct: True\n",
"  Difference between the two methods: 6.71e-08\n",
"\n",
"Using block_size=32:\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 32 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([2, 8, 1, 64])\n",
"  General method    - max abs error: 1.42e-07, rel error: 9.19e-07, correct: True\n",
"  Simplified method - max abs error: 1.49e-07, rel error: 9.67e-07, correct: True\n",
"  Difference between the two methods: 4.47e-08\n",
"\n",
"Using block_size=64:\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 16 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([2, 8, 1, 64])\n",
"  General method    - max abs error: 1.27e-07, rel error: 8.22e-07, correct: True\n",
"  Simplified method - max abs error: 1.56e-07, rel error: 1.02e-06, correct: True\n",
"  Difference between the two methods: 4.47e-08\n",
"\n",
"Using block_size=128:\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 8 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([2, 8, 1, 64])\n",
"  General method    - max abs error: 1.27e-07, rel error: 8.22e-07, correct: True\n",
"  Simplified method - max abs error: 1.27e-07, rel error: 8.22e-07, correct: True\n",
"  Difference between the two methods: 4.47e-08\n",
"\n",
"Using block_size=256:\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 4 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([2, 8, 1, 64])\n",
"  General method    - max abs error: 1.56e-07, rel error: 1.02e-06, correct: True\n",
"  Simplified method - max abs error: 1.42e-07, rel error: 9.19e-07, correct: True\n",
"  Difference between the two methods: 2.24e-08\n",
"\n",
"======================================================================\n",
"Numerical stability analysis\n",
"======================================================================\n",
"\n",
"Small value range [-1.0, 1.0]:\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 16 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([1, 8, 1, 64])\n",
"  Max abs error: 5.22e-08\n",
"  Contains NaN: False, contains Inf: False\n",
"\n",
"Medium value range [-10.0, 10.0]:\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 16 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([1, 8, 1, 64])\n",
"  Max abs error: 1.91e-06\n",
"  Contains NaN: False, contains Inf: False\n",
"\n",
"Large value range [-50.0, 50.0]:\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 16 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([1, 8, 1, 64])\n",
"  Max abs error: 0.00e+00\n",
"  Contains NaN: False, contains Inf: False\n",
"\n",
"======================================================================\n",
"Performance comparison (small batch size, long sequences)\n",
"======================================================================\n",
"\n",
"Sequence length: 1024\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 4 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([1, 8, 1, 64])\n",
"  Traditional: 0.96ms\n",
"  Flash-Decoding: 4.01ms\n",
"  Speedup: 0.24x\n",
"  Output difference: 1.86e-07\n",
"  Results match: True\n",
"\n",
"Sequence length: 4096\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 16 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([1, 8, 1, 64])\n",
"  Traditional: 5.04ms\n",
"  Flash-Decoding: 22.43ms\n",
"  Speedup: 0.22x\n",
"  Output difference: 1.19e-07\n",
"  Results match: True\n",
"\n",
"Sequence length: 16384\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 64 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([1, 8, 1, 64])\n",
"  Traditional: 43.96ms\n",
"  Flash-Decoding: 97.64ms\n",
"  Speedup: 0.45x\n",
"  Output difference: 1.23e-07\n",
"  Results match: True\n",
"\n",
"Sequence length: 32768\n",
"\n",
"Distributed-array implementation: 5 compute streams\n",
"Each stream keeps its own O, M, L arrays\n",
"Processing 128 tiles in parallel...\n",
"Reduction done, final output shape: torch.Size([1, 8, 1, 64])\n",
"  Traditional: 57.21ms\n",
"  Flash-Decoding: 204.57ms\n",
"  Speedup: 0.28x\n",
"  Output difference: 1.31e-07\n",
"  Results match: True\n"
] } ] },
{ "cell_type": "markdown", "source": [
"### 2.2 Method 2: storing the log-sum-exp"
], "metadata": { "id": "jN-owb68xOC2" } },
{ "cell_type": "code", "source": [
"import torch\n",
"import torch.nn.functional as F\n",
"import math\n",
"\n",
"class FinalFlashDecodingTiling:\n",
"    \"\"\"\n",
"    Final Flash-Decoding tiling implementation.\n",
"    Stores only O and S, using a two-step merge algorithm.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, d_model: int = 512, num_heads: int = 8):\n",
"        self.d_model = d_model\n",
"        self.num_heads = num_heads\n",
"        self.head_dim = d_model // num_heads\n",
"\n",
"    def traditional_attention(self, q, k, v):\n",
"        \"\"\"Baseline: traditional contiguous attention computation\"\"\"\n",
"        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)\n",
"        attention_weights = F.softmax(scores, dim=-1)\n",
"        output = torch.matmul(attention_weights, v)\n",
"        return output\n",
"\n",
"    def compute_stream_output(self, q, k_tile, v_tile):\n",
"        \"\"\"\n",
"        Compute one tile's stream output.\n",
"        Returns: (O_i, S_i) where S_i = m_i + log(l_i)\n",
"        \"\"\"\n",
"        # Attention scores for the current tile\n",
"        S_tile = torch.matmul(q, k_tile.transpose(-2, -1)) / math.sqrt(self.head_dim)\n",
"\n",
"        # Compute m_i and l_i\n",
"        m_i = S_tile.max(dim=-1, keepdim=True).values\n",
"        exp_tile = torch.exp(S_tile - m_i)\n",
"        l_i = exp_tile.sum(dim=-1, keepdim=True)\n",
"\n",
"        # Weighted sum O_i (already normalized)\n",
"        O_i = torch.matmul(exp_tile, v_tile) / l_i\n",
"\n",
"        # S_i = m_i + log(l_i)\n",
"        log_l_i = torch.log(l_i + 1e-12)\n",
"        S_i = m_i + log_l_i\n",
"\n",
"        return O_i, S_i\n",
"\n",
"    def merge_streams_two_step(self, streams_data):\n",
"        \"\"\"\n",
"        Two-step merge algorithm:\n",
"        1. Iteratively compute the global S_global\n",
"        2. Use S_global to reweight each stream's output contribution\n",
"        \"\"\"\n",
"        if not streams_data:\n",
"            return None\n",
"\n",
"        # Collect every stream's S_i\n",
"        S_list = [S_i for _, S_i in streams_data]\n",
"\n",
"        # Step 1: iteratively compute the global S_global (the log-sum-exp of all S_i)\n",
"        S_global = S_list[0].clone()\n",
"\n",
"        for i in range(1, len(S_list)):\n",
"            S_i = S_list[i]\n",
"            S_max = torch.maximum(S_global, S_i)\n",
"            S_min = torch.minimum(S_global, S_i)\n",
"            # Stable evaluation of log(1 + exp(x)) via log1p\n",
"            log_term = torch.log1p(torch.exp(S_min - S_max))\n",
"            S_global = S_max + log_term\n",
"\n",
"        # Step 2: reweight each stream's output contribution\n",
"        O_global = torch.zeros_like(streams_data[0][0])\n",
"\n",
"        for O_i, S_i in streams_data:\n",
"            # This stream's weight in the global result\n",
"            weight = torch.exp(S_i - S_global)\n",
"            # Accumulate the weighted contribution\n",
"            O_global += O_i * weight\n",
"\n",
"        return O_global\n",
"\n",
"    def flash_decoding_with_lse(self, q, k, v,\n",
"                                tile_size_kv: int = 256,\n",
"                                num_streams: int = 4):\n",
"        \"\"\"\n",
"        Flash-Decoding that stores only O and S\n",
"        \"\"\"\n",
"        batch_size, num_heads, seq_len_q, head_dim = q.shape\n",
"        seq_len_kv = k.shape[2]\n",
"        num_tiles = (seq_len_kv + tile_size_kv - 1) // tile_size_kv\n",
"\n",
"        print(f\"Flash-Decoding with the two-step merge: {num_streams} streams\")\n",
"\n",
"        # Initialize the stream arrays\n",
"        streams_data = []\n",
"\n",
"        for stream_id in range(num_streams):\n",
"            # Each stream stores (O_i, S_i)\n",
"            O_stream = torch.zeros_like(q)\n",
"            S_stream = torch.full((batch_size, num_heads, seq_len_q, 1),\n",
"                                  -float('inf'), device=q.device, dtype=q.dtype)\n",
"            streams_data.append((O_stream, S_stream))\n",
"\n",
"        # Process each tile\n",
"        print(f\"Processing {num_tiles} tiles...\")\n",
"\n",
"        for tile_idx in range(num_tiles):\n",
"            stream_id = tile_idx % num_streams\n",
"\n",
"            start_idx = tile_idx * tile_size_kv\n",
"            end_idx = min(start_idx + tile_size_kv, seq_len_kv)\n",
"\n",
"            k_tile = k[:, :, start_idx:end_idx, :]\n",
"            v_tile = v[:, :, start_idx:end_idx, :]\n",
"\n",
"            # Compute the current tile's output\n",
"            O_i, S_i = self.compute_stream_output(q, k_tile, v_tile)\n",
"\n",
"            # Fetch this stream's accumulators\n",
"            O_acc, S_acc = streams_data[stream_id]\n",
"\n",
"            # Merge the tile result into the stream accumulator\n",
"            if torch.all(S_acc == -float('inf')):\n",
"                streams_data[stream_id] = (O_i, S_i)\n",
"            else:\n",
"                # Two-step merge of the tile into the accumulator:\n",
"                # first compute the merged S\n",
"                S_max = torch.maximum(S_acc, S_i)\n",
"                S_min = torch.minimum(S_acc, S_i)\n",
"                log_term = torch.log1p(torch.exp(S_min - S_max))\n",
"                S_merged = S_max + log_term\n",
"\n",
"                # then reweight both contributions\n",
"                weight_acc = torch.exp(S_acc - S_merged)\n",
"                weight_i = torch.exp(S_i - S_merged)\n",
"                O_merged = O_acc * weight_acc + O_i * weight_i\n",
"\n",
"                streams_data[stream_id] = (O_merged, S_merged)\n",
"\n",
"        print(f\"All tiles processed, reducing the streams...\")\n",
"        # Reduce the results of all streams\n",
"        O_final = self.merge_streams_two_step(streams_data)\n",
"\n",
"        return O_final\n",
"\n",
"    def verify_correctness(self, seq_len_kv: int = 2048):\n",
"        \"\"\"Verify the correctness of the implementation\"\"\"\n",
"\n",
"        torch.manual_seed(42)\n",
"        batch_size = 2\n",
"        seq_len_q = 1\n",
"\n",
"        q = torch.randn(batch_size, self.num_heads, seq_len_q, self.head_dim)\n",
"        k = torch.randn(batch_size, self.num_heads, seq_len_kv, self.head_dim)\n",
"        v = torch.randn(batch_size, self.num_heads, seq_len_kv, self.head_dim)\n",
"\n",
"        print(\"=\" * 80)\n",
"        print(\"LSE-based Flash-Decoding\")\n",
"        print(\"=\" * 80)\n",
"\n",
"        # Baseline: traditional method\n",
"        print(f\"\\n1. Traditional attention...\")\n",
"        baseline = self.traditional_attention(q, k, v)\n",
"\n",
"        # LSE-based Flash-Decoding\n",
"        print(f\"\\n2. LSE-based Flash-Decoding...\")\n",
"        output = self.flash_decoding_with_lse(q, k, v)\n",
"\n",
"        # Verify correctness\n",
"        diff = torch.abs(baseline - output).max().item()\n",
"        rel_error = diff / torch.abs(baseline).max().item()\n",
"\n",
"        print(f\"\\nVerification results:\")\n",
"        print(f\"  Max abs error: {diff:.2e}\")\n",
"        print(f\"  Relative error: {rel_error:.2e}\")\n",
"        print(f\"  Merge algorithm correct: {diff < 1e-4}\")\n",
"\n",
"        # Mathematical correctness check\n",
"        print(f\"\\n3. Mathematical correctness check (small-scale test)...\")\n",
"\n",
"        # Create a small test\n",
"        torch.manual_seed(123)\n",
"        q_test = torch.randn(1, 2, 1, 4)\n",
"        k_test = torch.randn(1, 2, 8, 4)\n",
"        v_test = torch.randn(1, 2, 8, 4)\n",
"\n",
"        baseline_test = self.traditional_attention(q_test, k_test, v_test)\n",
"        output_test = self.flash_decoding_with_lse(q_test, k_test, v_test,\n",
"                                                   tile_size_kv=4, num_streams=2)\n",
"\n",
"        diff_test = torch.abs(baseline_test - output_test).max().item()\n",
"        print(f\"  Small-scale test max abs error: {diff_test:.2e}\")\n",
"        print(f\"  Small-scale test correct: {diff_test < 1e-4}\")\n",
"\n",
"        return {\n",
"            'baseline': baseline,\n",
"            'output': output,\n",
"            'error': diff,\n",
"            'correct': diff < 1e-4\n",
"        }\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
"    demo = FinalFlashDecodingTiling(d_model=512, num_heads=8)\n",
"\n",
"    # Verify the correctness of the final implementation\n",
"    print(\"LSE-based Flash-Decoding verification\")\n",
"    results = demo.verify_correctness(seq_len_kv=2048)\n",
"\n",
"    if results['correct']:\n",
"        print(\"\\n✅ The algorithm now merges results correctly.\")\n",
"    else:\n",
"        print(f\"\\n❌ There is still a problem, error: {results['error']:.2e}\")\n",
"\n",
"    # Performance test\n",
"    print(\"\\n\" + \"=\" * 80)\n",
"    print(\"Performance test\")\n",
"    print(\"=\" * 80)\n",
"\n",
"    import time\n",
"    torch.manual_seed(42)\n",
"    batch_size = 2\n",
"    seq_len_q = 1\n",
"    seq_len_kv = 8192\n",
"\n",
"    q = torch.randn(batch_size, 8, seq_len_q, 64)\n",
"    k = torch.randn(batch_size, 8, seq_len_kv, 64)\n",
"    v = torch.randn(batch_size, 8, seq_len_kv, 64)\n",
"\n",
"    # Traditional method\n",
"    start = time.time()\n",
"    baseline = demo.traditional_attention(q, k, v)\n",
"    if torch.cuda.is_available():\n",
"        torch.cuda.synchronize()\n",
"    trad_time = time.time() - start\n",
"\n",
"    # Optimized method\n",
"    start = time.time()\n",
"    output = demo.flash_decoding_with_lse(q, k, v, tile_size_kv=256)\n",
"    if torch.cuda.is_available():\n",
"        torch.cuda.synchronize()\n",
"    opt_time = time.time() - start\n",
"\n",
"    diff = torch.abs(baseline - output).max().item()\n",
"\n",
"    print(f\"Sequence length: {seq_len_kv}\")\n",
"    print(f\"Traditional method time: {trad_time:.4f}s\")\n",
"    print(f\"Optimized method time: {opt_time:.4f}s\")\n",
"    print(f\"Speedup: {trad_time/opt_time:.2f}x\")\n",
"    print(f\"Error: {diff:.2e}\")\n",
"    print(f\"Match: {diff < 1e-4}\")"
], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3HNSKcq4PklO", "outputId": "b3fa12ce-ffe3-46aa-b0ab-4ac91fb69599" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
"LSE-based Flash-Decoding verification\n",
"================================================================================\n",
"LSE-based Flash-Decoding\n",
"================================================================================\n",
"\n",
"1. Traditional attention...\n",
"\n",
"2. LSE-based Flash-Decoding...\n",
"Flash-Decoding with the two-step merge: 4 streams\n",
"Processing 8 tiles...\n",
"All tiles processed, reducing the streams...\n",
"\n",
"Verification results:\n",
"  Max abs error: 2.53e-07\n",
"  Relative error: 1.68e-06\n",
"  Merge algorithm correct: True\n",
"\n",
"3. Mathematical correctness check (small-scale test)...\n",
"Flash-Decoding with the two-step merge: 2 streams\n",
"Processing 2 tiles...\n",
"All tiles processed, reducing the streams...\n",
"  Small-scale test max abs error: 1.19e-07\n",
"  Small-scale test correct: True\n",
"\n",
"✅ The algorithm now merges results correctly.\n",
"\n",
"================================================================================\n",
"Performance test\n",
"================================================================================\n",
"Flash-Decoding with the two-step merge: 4 streams\n",
"Processing 32 tiles...\n",
"All tiles processed, reducing the streams...\n",
"Sequence length: 8192\n",
"Traditional method time: 0.0163s\n",
"Optimized method time: 0.0453s\n",
"Speedup: 0.36x\n",
"Error: 1.60e-07\n",
"Match: True\n"
] } ] },
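{ "cell_type": "markdown", "source": [
"The manual `S_max + log1p(exp(S_min - S_max))` merge above is the standard numerically stable log-add-exp, which PyTorch ships as `torch.logaddexp` (and `torch.logsumexp` for many terms). A quick sanity check with arbitrary toy values:"
], "metadata": {} },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# The manual stable merge equals torch.logaddexp / torch.logsumexp.\n",
"import torch\n",
"\n",
"S1 = torch.tensor([-3.2, 10.0, 0.5])\n",
"S2 = torch.tensor([ 4.1, -2.0, 0.7])\n",
"\n",
"S_max = torch.maximum(S1, S2)\n",
"S_min = torch.minimum(S1, S2)\n",
"manual = S_max + torch.log1p(torch.exp(S_min - S_max))\n",
"\n",
"print(torch.allclose(manual, torch.logaddexp(S1, S2)))                          # True\n",
"print(torch.allclose(manual, torch.logsumexp(torch.stack([S1, S2]), dim=0)))    # True"
] },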
{ "cell_type": "markdown", "source": [
"### 2.3 Equivalence proof for Method 2"
], "metadata": { "id": "XTzvsRibVh74" } },
{ "cell_type": "code", "source": [
"def mathematical_equivalence_proof():\n",
"    \"\"\"Rigorously prove that the two-step merge algorithm is equivalent to the traditional method\"\"\"\n",
"\n",
"    print(\"\\n\" + \"=\" * 80)\n",
"    print(\"Mathematical equivalence proof\")\n",
"    print(\"=\" * 80)\n",
"\n",
"    print(\"Traditional attention computation:\")\n",
"    print(\"  Suppose there are N attention scores in total\")\n",
"    print(\"  Global max M = max(score_j), j=1..N\")\n",
"    print(\"  Global exp-sum L = Σ_j exp(score_j - M)\")\n",
"    print(\"  Attention output = Σ_j [exp(score_j - M) * v_j] / L\")\n",
"    print()\n",
"\n",
"    print(\"Flash-Decoding block-wise computation:\")\n",
"    print(\"  Split the N scores into K blocks; each block i has:\")\n",
"    print(\"  m_i = max within the block\")\n",
"    print(\"  l_i = Σ_{j∈block i} exp(score_j - m_i)\")\n",
"    print(\"  O_i = Σ_{j∈block i} [exp(score_j - m_i) * v_j] / l_i\")\n",
"    print(\"  S_i = m_i + log(l_i)\")\n",
"    print()\n",
"\n",
"    print(\"Traditional merge algorithm:\")\n",
"    print(\"  1. Find the global max M_global = max(m_i)\")\n",
"    print(\"  2. Rescale each block's contribution (undoing the per-block normalization by l_i):\")\n",
"    print(\"     rescaled l_i' = l_i × exp(m_i - M_global)\")\n",
"    print(\"     rescaled O_i' = O_i × l_i × exp(m_i - M_global)\")\n",
"    print(\"  3. Merge:\")\n",
"    print(\"     L_global = Σ_i l_i'\")\n",
"    print(\"     O_global = Σ_i O_i' / L_global\")\n",
"    print()\n",
"\n",
"    print(\"Two-step merge algorithm:\")\n",
"    print(\"  1. Compute S_global = log(Σ_i exp(S_i))\")\n",
"    print(\"  2. Compute O_global = Σ_i [O_i × exp(S_i - S_global)]\")\n",
"    print()\n",
"\n",
"    print(\"Proof of equivalence:\")\n",
"    print(\"  Step 1: show exp(S_i) = l_i × exp(m_i)\")\n",
"    print(\"    Since S_i = m_i + log(l_i),\")\n",
"    print(\"    exp(S_i) = exp(m_i + log(l_i)) = l_i × exp(m_i)\")\n",
"    print()\n",
"\n",
"    print(\"  Step 2: show exp(S_global) = Σ_i [l_i × exp(m_i)]\")\n",
"    print(\"    Since S_global = log(Σ_i exp(S_i)),\")\n",
"    print(\"    exp(S_global) = Σ_i exp(S_i) = Σ_i [l_i × exp(m_i)]\")\n",
"    print()\n",
"\n",
"    print(\"  Step 3: show exp(S_i - S_global) = [l_i × exp(m_i)] / Σ_j [l_j × exp(m_j)]\")\n",
"    print(\"    exp(S_i - S_global) = exp(S_i) / exp(S_global)\")\n",
"    print(\"                        = [l_i × exp(m_i)] / Σ_j [l_j × exp(m_j)]\")\n",
"    print()\n",
"\n",
"    print(\"  Step 4: show O_i × exp(S_i - S_global) = [O_i × l_i × exp(m_i)] / Σ_j [l_j × exp(m_j)]\")\n",
"    print(\"    This follows immediately from Step 3\")\n",
"    print()\n",
"\n",
"    print(\"  Step 5: show Σ_i [O_i × exp(S_i - S_global)] equals the traditional merge result\")\n",
"    print(\"    Traditional merge: O_global = Σ_i [O_i × l_i × exp(m_i - M_global)] / Σ_i [l_i × exp(m_i - M_global)]\")\n",
"    print(\"                               = Σ_i [O_i × l_i × exp(m_i)] / Σ_i [l_i × exp(m_i)]   (multiply through by exp(M_global))\")\n",
"    print(\"                               = Σ_i [O_i × exp(S_i)] / Σ_i [exp(S_i)]               (by Step 1)\")\n",
"    print(\"                               = Σ_i [O_i × exp(S_i - S_global)]                     (by Steps 2 and 3)\")\n",
"    print()\n",
"\n",
"mathematical_equivalence_proof()"
], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "c3EWO6gp7qJX", "outputId": "53b76aa8-9d21-4e72-c710-72b3ec333e90" }, "execution_count": 7, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
"\n",
"================================================================================\n",
"Mathematical equivalence proof\n",
"================================================================================\n",
"Traditional attention computation:\n",
"  Suppose there are N attention scores in total\n",
"  Global max M = max(score_j), j=1..N\n",
"  Global exp-sum L = Σ_j exp(score_j - M)\n",
"  Attention output = Σ_j [exp(score_j - M) * v_j] / L\n",
"\n",
"Flash-Decoding block-wise computation:\n",
"  Split the N scores into K blocks; each block i has:\n",
"  m_i = max within the block\n",
"  l_i = Σ_{j∈block i} exp(score_j - m_i)\n",
"  O_i = Σ_{j∈block i} [exp(score_j - m_i) * v_j] / l_i\n",
"  S_i = m_i + log(l_i)\n",
"\n",
"Traditional merge algorithm:\n",
"  1. Find the global max M_global = max(m_i)\n",
"  2. Rescale each block's contribution (undoing the per-block normalization by l_i):\n",
"     rescaled l_i' = l_i × exp(m_i - M_global)\n",
"     rescaled O_i' = O_i × l_i × exp(m_i - M_global)\n",
"  3. Merge:\n",
"     L_global = Σ_i l_i'\n",
"     O_global = Σ_i O_i' / L_global\n",
"\n",
"Two-step merge algorithm:\n",
"  1. Compute S_global = log(Σ_i exp(S_i))\n",
"  2. Compute O_global = Σ_i [O_i × exp(S_i - S_global)]\n",
"\n",
"Proof of equivalence:\n",
"  Step 1: show exp(S_i) = l_i × exp(m_i)\n",
"    Since S_i = m_i + log(l_i),\n",
"    exp(S_i) = exp(m_i + log(l_i)) = l_i × exp(m_i)\n",
"\n",
"  Step 2: show exp(S_global) = Σ_i [l_i × exp(m_i)]\n",
"    Since S_global = log(Σ_i exp(S_i)),\n",
"    exp(S_global) = Σ_i exp(S_i) = Σ_i [l_i × exp(m_i)]\n",
"\n",
"  Step 3: show exp(S_i - S_global) = [l_i × exp(m_i)] / Σ_j [l_j × exp(m_j)]\n",
"    exp(S_i - S_global) = exp(S_i) / exp(S_global)\n",
"                        = [l_i × exp(m_i)] / Σ_j [l_j × exp(m_j)]\n",
"\n",
"  Step 4: show O_i × exp(S_i - S_global) = [O_i × l_i × exp(m_i)] / Σ_j [l_j × exp(m_j)]\n",
"    This follows immediately from Step 3\n",
"\n",
"  Step 5: show Σ_i [O_i × exp(S_i - S_global)] equals the traditional merge result\n",
"    Traditional merge: O_global = Σ_i [O_i × l_i × exp(m_i - M_global)] / Σ_i [l_i × exp(m_i - M_global)]\n",
"                               = Σ_i [O_i × l_i × exp(m_i)] / Σ_i [l_i × exp(m_i)]   (multiply through by exp(M_global))\n",
"                               = Σ_i [O_i × exp(S_i)] / Σ_i [exp(S_i)]               (by Step 1)\n",
"                               = Σ_i [O_i × exp(S_i - S_global)]                     (by Steps 2 and 3)\n",
"\n"
] } ] }
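,
{ "cell_type": "markdown", "source": [
"As a numerical companion to the proof above, the cell below is a small sketch (arbitrary toy data, ad-hoc names): it builds per-block `(O_i, S_i)` pairs, merges them with the two-step rule, and checks the result against a direct softmax."
], "metadata": {} },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Numerical check of the two-step merge on toy 1-D data.\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"scores = torch.randn(12) * 5                   # one query, 12 keys\n",
"values = torch.randn(12, 4)                    # 4-dim values\n",
"ref = torch.softmax(scores, dim=0) @ values    # direct softmax attention\n",
"\n",
"# Per-block statistics (3 blocks of 4): O_i normalized, S_i = m_i + log(l_i).\n",
"O_list, S_list = [], []\n",
"for s_blk, v_blk in zip(scores.split(4), values.split(4)):\n",
"    m_i = s_blk.max()\n",
"    e = torch.exp(s_blk - m_i)\n",
"    l_i = e.sum()\n",
"    O_list.append((e @ v_blk) / l_i)\n",
"    S_list.append(m_i + torch.log(l_i))\n",
"\n",
"# Two-step merge: S_global = log-sum-exp of all S_i, then reweight each O_i.\n",
"S_global = torch.logsumexp(torch.stack(S_list), dim=0)\n",
"out = sum(O_i * torch.exp(S_i - S_global) for O_i, S_i in zip(O_list, S_list))\n",
"print(\"max diff vs direct softmax:\", (out - ref).abs().max().item())"
] }
] }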