{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": { "provenance": [], "authorship_tag": "ABX9TyNfhX/ftUmKowz5v0shV9Eh" },
    "kernelspec": { "name": "python3", "display_name": "Python 3" },
    "language_info": { "name": "python" }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Quantized Computation Basics\n",
        "\n",
        "Introduction: how quantization is applied in large-model inference and what it buys: lower GPU-memory usage, less compute, and reduced data-transfer overhead. The notebook covers two fundamentals: how quantization error arises, and the basic flow of quantized computation.\n",
        "\n",
        "Related article: [大模型推理量化(Quantization)基础速览](https://zhuanlan.zhihu.com/p/2005335401469083798)\n",
        "\n",
        "Author: kaiyuan\n",
        "\n",
        "Email: kyxie@zju.edu.cn"
      ],
      "metadata": { "id": "Jy_VJEWOtE6B" }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1 How Quantization Error Arises\n",
        "\n",
        "This section demonstrates and compares the quantization error of two formats: INT8 and FP8 (E4M3). It defines quantize/dequantize functions and uses a set of test values to compute and display, for each scheme, the original value, the quantized value, the dequantized value, and the quantization error.\n",
        "\n",
        "INT8 quantization error:\n",
        "\n",
        "- For small values (e.g. 0.001, 0.123, 1.234), INT8 preserves precision well and the error is relatively small.\n",
        "- Once the original value approaches or exceeds the INT8 representable range (-128 to 127 times the scale factor 0.1, i.e. -12.8 to 12.7), the quantization error grows sharply. For example, 127.9, 255.5, -300.0, 448.0, and -448.0 are clipped to the INT8 maximum or minimum (12.7 or -12.8), so the dequantized values deviate hugely from the originals and the error is very large.\n",
        "\n",
        "FP8 (E4M3) quantization error:\n",
        "- For small values, FP8 (E4M3) shows similarly small errors.\n",
        "- FP8 (E4M3) has a wider representable range than INT8 (maximum 448 times the scale factor 0.1, i.e. 44.8). A value such as 127.9 is therefore still clipped, but at a larger bound, so its dequantization error (83.1) is smaller than under INT8 (115.2).\n",
        "- Values beyond the FP8 (E4M3) range (e.g. 255.5, against a dequantized ceiling of 44.8) still overflow and are clipped, again producing large errors.\n",
        "\n",
        "Both schemes deliver reasonable precision for values inside their representable range, but both suffer severe quantization error once values fall outside it, especially when clipped to the maximum or minimum. Thanks to its exponent bits, FP8 (E4M3) usually covers a larger dynamic range and can reduce overflow error relative to fixed-point INT8, though its precision steps are less uniform than INT8's."
      ],
      "metadata": { "id": "5K9o57pYJ3Sd" }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": { "base_uri": "https://localhost:8080/" },
        "id": "BBxUiV3CtAOU",
        "outputId": "d47fdb89-5e61-434d-bbd5-992263741420"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "original\tINT8 quantized\tINT8 dequantized\tINT8 error\n",
            "0.00100\t0\t0.00000\t0.00100\n",
            "0.12300\t1\t0.10000\t0.02300\n",
            "1.23400\t12\t1.20000\t0.03400\n",
            "127.90000\t127\t12.70000\t115.20000\n",
            "255.50000\t127\t12.70000\t242.80000\n",
            "-300.00000\t-128\t-12.80000\t287.20000\n",
            "448.00000\t127\t12.70000\t435.30000\n",
            "-448.00000\t-128\t-12.80000\t435.20000\n",
            "\n",
            "original\tFP8 (E4M3) quantized\tFP8 (E4M3) dequantized\tFP8 (E4M3) error\n",
            "0.00100\t0.0\t0.00000\t0.00100\n",
            "0.12300\t1.0\t0.10000\t0.02300\n",
            "1.23400\t12.0\t1.20000\t0.03400\n",
            "127.90000\t448.0\t44.80000\t83.10000\n",
            "255.50000\t448.0\t44.80000\t210.70000\n",
            "-300.00000\t-448.0\t-44.80000\t255.20000\n",
            "448.00000\t448.0\t44.80000\t403.20000\n",
            "-448.00000\t-448.0\t-44.80000\t403.20000\n"
          ]
        }
      ],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Quantize / dequantize helpers\n",
        "def quantize_int8(value, scale_factor):\n",
        "    \"\"\"INT8 quantization: map a float to the INT8 range (-128 to 127).\"\"\"\n",
        "    quantized_value = int(np.round(value / scale_factor))\n",
        "    quantized_value = np.clip(quantized_value, -128, 127)  # clip to the INT8 range\n",
        "    return quantized_value\n",
        "\n",
        "def dequantize_int8(quantized_value, scale_factor):\n",
        "    \"\"\"INT8 dequantization: map the quantized integer back to a float.\"\"\"\n",
        "    return quantized_value * scale_factor\n",
        "\n",
        "def quantize_fp8_e4m3(value, scale_factor):\n",
        "    \"\"\"FP8 (E4M3) quantization: map a float to the FP8 (E4M3) format.\n",
        "\n",
        "    Simplified simulation: only the +/-448 saturation of E4M3 is modeled;\n",
        "    real E4M3 mantissa rounding is approximated here by np.round().\n",
        "    \"\"\"\n",
        "    unscaled_value = value / scale_factor  # apply the scale factor\n",
        "    # Simulate the FP8 (E4M3) representable range\n",
        "    if abs(unscaled_value) > 448:\n",
        "        return np.sign(unscaled_value) * 448  # overflow: saturate\n",
        "    return np.round(unscaled_value)\n",
        "\n",
        "def dequantize_fp8_e4m3(quantized_value, scale_factor):\n",
        "    \"\"\"FP8 (E4M3) dequantization: map the quantized value back to a float.\"\"\"\n",
        "    return quantized_value * scale_factor\n",
        "\n",
        "# Test data\n",
        "values = [\n",
        "    0.001,   # tiny value close to zero\n",
        "    0.123,   # small value\n",
        "    1.234,   # small value\n",
        "    127.9,   # near the INT8 upper bound\n",
        "    255.5,   # beyond the INT8 range\n",
        "    -300.0,  # negative, beyond the INT8 range\n",
        "    448.0,   # FP8 maximum\n",
        "    -448.0   # FP8 minimum\n",
        "]\n",
        "\n",
        "# Scale factor\n",
        "scale_factor = 0.1\n",
        "\n",
        "# Print header\n",
        "print(\"original\\tINT8 quantized\\tINT8 dequantized\\tINT8 error\")\n",
        "for value in values:\n",
        "    # INT8 quantize and dequantize\n",
        "    int8_quantized = quantize_int8(value, scale_factor)\n",
        "    int8_dequantized = dequantize_int8(int8_quantized, scale_factor)\n",
        "    int8_error = abs(value - int8_dequantized)\n",
        "    print(f\"{value:.5f}\\t{int8_quantized}\\t{int8_dequantized:.5f}\\t{int8_error:.5f}\")\n",
        "\n",
        "print(\"\\noriginal\\tFP8 (E4M3) quantized\\tFP8 (E4M3) dequantized\\tFP8 (E4M3) error\")\n",
        "for value in values:\n",
        "    # FP8 (E4M3) quantize and dequantize\n",
        "    fp8_quantized = quantize_fp8_e4m3(value, scale_factor)\n",
        "    fp8_dequantized = dequantize_fp8_e4m3(fp8_quantized, scale_factor)\n",
        "    fp8_error = abs(value - fp8_dequantized)\n",
        "    print(f\"{value:.5f}\\t{fp8_quantized}\\t{fp8_dequantized:.5f}\\t{fp8_error:.5f}\")"
      ]
    },
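    {
      "cell_type": "markdown",
      "source": [
        "The `quantize_fp8_e4m3` helper above only models E4M3's +/-448 saturation; its rounding step is a plain round(). As a cross-check, the sketch below (our addition, not part of the original demo) casts the same test values through PyTorch's native `torch.float8_e4m3fn` dtype, which applies real E4M3 mantissa rounding. It assumes a PyTorch build (>= 2.1) that ships this dtype."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch: FP8 (E4M3) quantization via PyTorch's native float8 dtype.\n",
        "# Assumption: torch.float8_e4m3fn is available (PyTorch >= 2.1).\n",
        "import torch\n",
        "\n",
        "values = torch.tensor([0.001, 0.123, 1.234, 127.9, 255.5, -300.0, 448.0, -448.0])\n",
        "scale_factor = 0.1\n",
        "\n",
        "scaled = (values / scale_factor).clamp(-448, 448)  # saturate to the E4M3 range first\n",
        "fp8 = scaled.to(torch.float8_e4m3fn)               # real E4M3 rounding happens here\n",
        "fp8_as_f32 = fp8.to(torch.float32)                 # readable view of the stored FP8 values\n",
        "dequantized = fp8_as_f32 * scale_factor            # dequantize back to FP32\n",
        "\n",
        "print(\"original\\tFP8 stored\\tdequantized\\terror\")\n",
        "for v, q, d in zip(values.tolist(), fp8_as_f32.tolist(), dequantized.tolist()):\n",
        "    print(f\"{v:.5f}\\t{q}\\t{d:.5f}\\t{abs(v - d):.5f}\")"
      ]
    },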
    {
      "cell_type": "markdown",
      "source": [
        "## 2 Quantized Computation Walkthrough\n",
        "\n",
        "Matrix computation: c = a * b, where a and b are INT8 and the output c is FP32.\n",
        "\n",
        "The computation is implemented with the Triton library. This example requires a GPU; running inside a docker container is recommended.\n",
        "```\n",
        "docker pull nvcr.io/nvidia/sglang:26.01-py3\n",
        "```\n",
        "\n",
        "Test machine:\n",
        "- NVIDIA A100-SXM4-80GB\n",
        "- NVIDIA-SMI 570.172.08\n",
        "- Driver Version: 570.172.08\n",
        "- CUDA Version: 13.1\n",
        "\n",
        "\n",
        "Notes:\n",
        "\n",
        "1. **Dtype flow**:\n",
        "   * **Input**: `torch.int8`.\n",
        "   * **Compute**: inside `tl.dot(a, b)` the math is `INT8 * INT8 -> INT32`. The result is immediately cast with `.to(tl.float32)` and accumulated into the FP32 accumulator `acc`.\n",
        "   * **Output**: FP32, multiplied by the scale factors and stored directly.\n",
        "\n",
        "2. **Why go through INT32 instead of multiplying directly in FP32?**\n",
        "   * Loading the INT8 data as FP32 and multiplying would work, but it wastes the Tensor Cores' INT8 throughput. The standard approach is to let `tl.dot` produce INT32 and then convert to FP32, balancing precision and speed.\n",
        "\n",
        "3. **Quantization scaling**:\n",
        "   * This is the standard INT8 inference flow: A_int8 * W_int8 = Y_int32, then `Y_fp32 = Y_int32 * scale_a * scale_b`.\n",
        "   * In the code, `scale_a_ptr` and `scale_b_ptr` are scalar pointers demonstrating per-tensor scaling. Extending to per-token/per-channel only requires switching to vector loads; see the NumPy sketch right after this cell."
      ],
      "metadata": { "id": "hEOHCN_SIxvN" }
    },
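    {
      "cell_type": "markdown",
      "source": [
        "Before the Triton kernel, a minimal NumPy sketch of the scaling formula from note 3 (our addition; matrix sizes and scale values are illustrative): accumulate the INT8 product in INT32, as `tl.dot` does, then dequantize in FP32, first with per-tensor scalars and then with a per-channel scale vector."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch of the dequantization formula Y_fp32 = Y_int32 * scale_a * scale_b.\n",
        "# Illustrative only: sizes and scales are arbitrary.\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "M, K, N = 4, 8, 3\n",
        "A_int8 = rng.integers(-128, 128, size=(M, K), dtype=np.int8)\n",
        "W_int8 = rng.integers(-128, 128, size=(K, N), dtype=np.int8)\n",
        "\n",
        "# Accumulate in INT32 (what tl.dot does internally), then dequantize in FP32.\n",
        "Y_int32 = A_int8.astype(np.int32) @ W_int8.astype(np.int32)\n",
        "\n",
        "# Per-tensor scaling: one scalar per matrix.\n",
        "scale_a, scale_b = 0.05, 0.02\n",
        "Y_fp32 = Y_int32.astype(np.float32) * (scale_a * scale_b)\n",
        "\n",
        "# Per-channel scaling: one scale per output column of W; broadcasting handles the rest.\n",
        "scale_b_per_channel = rng.uniform(0.01, 0.1, size=N).astype(np.float32)\n",
        "Y_fp32_pc = Y_int32.astype(np.float32) * scale_a * scale_b_per_channel  # shape (M, N)\n",
        "\n",
        "print(Y_fp32.shape, Y_fp32_pc.shape)"
      ]
    },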
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import triton\n",
        "import triton.language as tl\n",
        "import numpy as np\n",
        "\n",
        "# ------------------------------------------------------------\n",
        "# Tiled high-performance version: INT8 inputs, FP32 output (with quantization scaling)\n",
        "# Each program computes one [BLOCK_M, BLOCK_N] tile of the output\n",
        "# ------------------------------------------------------------\n",
        "@triton.jit\n",
        "def int8_gemm_tiled_kernel(\n",
        "    # pointers\n",
        "    a_ptr, b_ptr, c_ptr,\n",
        "    scale_a_ptr, scale_b_ptr,  # per-tensor scale factors (FP32)\n",
        "    M, N, K,\n",
        "    stride_am, stride_ak,\n",
        "    stride_bk, stride_bn,\n",
        "    stride_cm, stride_cn,\n",
        "    # meta-parameters: tile sizes\n",
        "    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n",
        "):\n",
        "    # position of this program in the output matrix\n",
        "    pid_m = tl.program_id(0)\n",
        "    pid_n = tl.program_id(1)\n",
        "\n",
        "    # row/column offsets covered by this tile\n",
        "    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n",
        "    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n",
        "    offs_k = tl.arange(0, BLOCK_K)\n",
        "\n",
        "    # FP32 accumulator: accumulate directly in FP32 to avoid a later conversion pass\n",
        "    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n",
        "\n",
        "    # walk the K dimension, BLOCK_K elements at a time\n",
        "    for k in range(0, K, BLOCK_K):\n",
        "        # ---- 1. build the tile pointers for A (INT8) ----\n",
        "        a_ptrs = a_ptr + (offs_m[:, None] * stride_am + (k + offs_k[None, :]) * stride_ak)\n",
        "        # ---- 2. build the tile pointers for B (INT8) ----\n",
        "        b_ptrs = b_ptr + ((k + offs_k[:, None]) * stride_bk + offs_n[None, :] * stride_bn)\n",
        "\n",
        "        # load the INT8 tiles; out-of-bounds lanes are masked to 0\n",
        "        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & ((k + offs_k[None, :]) < K), other=0)\n",
        "        b = tl.load(b_ptrs, mask=((k + offs_k[:, None]) < K) & (offs_n[None, :] < N), other=0)\n",
        "\n",
        "        # ---- 3. core matmul: INT8 * INT8 -> INT32 ----\n",
        "        # tl.dot multiplies the INT8 tiles and accumulates the products in INT32\n",
        "        acc += tl.dot(a, b).to(tl.float32)  # key step: convert INT32 to FP32, then accumulate\n",
        "\n",
        "    # ---- 4. dequantization: apply the scale factors to get the final FP32 output ----\n",
        "    scale_a = tl.load(scale_a_ptr)  # per-tensor activation scale\n",
        "    scale_b = tl.load(scale_b_ptr)  # per-tensor weight scale\n",
        "    c = acc * (scale_a * scale_b)   # formula: C_fp32 = (A_int8 * B_int8) * (scale_a * scale_b)\n",
        "\n",
        "    # ---- 5. store the FP32 result ----\n",
        "    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn\n",
        "    tl.store(c_ptrs, c, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))\n",
        "\n",
        "\n",
        "def int8_gemm_fp32_output(a_int8, b_int8, scale_a=1.0, scale_b=1.0):\n",
        "    \"\"\"\n",
        "    INT8 matmul with FP32 output\n",
        "    - a_int8, b_int8: shapes (M, K) and (K, N), torch.int8, on CUDA\n",
        "    - scale_a, scale_b: per-tensor FP32 scale factors used for dequantization\n",
        "    - returns: torch.float32 matrix equal to (a_int8 @ b_int8) * (scale_a * scale_b)\n",
        "    \"\"\"\n",
        "    assert a_int8.is_cuda and b_int8.is_cuda\n",
        "    assert a_int8.dtype == torch.int8 and b_int8.dtype == torch.int8\n",
        "    M, K = a_int8.shape\n",
        "    K_, N = b_int8.shape\n",
        "    assert K == K_\n",
        "\n",
        "    # FP32 output\n",
        "    c_fp32 = torch.empty((M, N), device=a_int8.device, dtype=torch.float32)\n",
        "\n",
        "    # pass the scale factors to the kernel as scalar tensors\n",
        "    scale_a_t = torch.tensor(scale_a, device=a_int8.device, dtype=torch.float32)\n",
        "    scale_b_t = torch.tensor(scale_b, device=a_int8.device, dtype=torch.float32)\n",
        "\n",
        "    # tile sizes (tunable)\n",
        "    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32\n",
        "    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))\n",
        "\n",
        "    int8_gemm_tiled_kernel[grid](\n",
        "        a_int8, b_int8, c_fp32,\n",
        "        scale_a_t, scale_b_t,\n",
        "        M, N, K,\n",
        "        a_int8.stride(0), a_int8.stride(1),\n",
        "        b_int8.stride(0), b_int8.stride(1),\n",
        "        c_fp32.stride(0), c_fp32.stride(1),\n",
        "        BLOCK_M, BLOCK_N, BLOCK_K,\n",
        "    )\n",
        "    return c_fp32\n",
        "\n",
        "\n",
        "# ------------------------------------------------------------\n",
        "# Test / verification\n",
        "# ------------------------------------------------------------\n",
        "def test_int8_gemm_fp32():\n",
        "    torch.manual_seed(42)\n",
        "    M, N, K = 128, 128, 64  # slightly larger so the tiling matters\n",
        "\n",
        "    # random INT8 matrices (range -128..127)\n",
        "    a = torch.randint(-128, 128, (M, K), device='cuda', dtype=torch.int8)\n",
        "    b = torch.randint(-128, 128, (K, N), device='cuda', dtype=torch.int8)\n",
        "\n",
        "    # random scale factors (simulating quantization/dequantization)\n",
        "    scale_a = np.random.uniform(0.01, 0.1)\n",
        "    scale_b = np.random.uniform(0.01, 0.1)\n",
        "\n",
        "    # Triton FP32 output\n",
        "    c_triton = int8_gemm_fp32_output(a, b, scale_a, scale_b)\n",
        "\n",
        "    # NumPy reference: INT8 -> FP32 -> apply scales\n",
        "    a_np = a.cpu().numpy().astype(np.float32)\n",
        "    b_np = b.cpu().numpy().astype(np.float32)\n",
        "    c_np = (a_np @ b_np) * (scale_a * scale_b)\n",
        "\n",
        "    # error tolerance (INT8 quantization itself introduces rounding error)\n",
        "    torch.testing.assert_close(c_triton.cpu(), torch.from_numpy(c_np), rtol=1e-2, atol=1e-2)\n",
        "    print(\"Test passed: INT8 inputs, FP32 output, with quantization scaling\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    test_int8_gemm_fp32()"
      ],
      "metadata": {
        "colab": { "base_uri": "https://localhost:8080/" },
        "id": "2i5NxjeiJuCl",
        "outputId": "5d88027e-5135-4145-9e10-215b73d9493f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Test passed: INT8 inputs, FP32 output, with quantization scaling\n"
          ]
        }
      ]
    },
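    {
      "cell_type": "markdown",
      "source": [
        "To close the loop, a sketch of the full pipeline (our addition; it assumes the cell above has been run and a CUDA GPU is available): quantize two FP32 matrices with symmetric per-tensor scales (scale = max|x| / 127, a common convention), run them through `int8_gemm_fp32_output`, and compare against a plain FP32 matmul."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# End-to-end sketch: FP32 -> INT8 quantization -> Triton INT8 GEMM -> FP32, vs. a plain matmul.\n",
        "# Assumes int8_gemm_fp32_output from the cell above and a CUDA device.\n",
        "import torch\n",
        "\n",
        "def quantize_per_tensor(x_fp32):\n",
        "    \"\"\"Symmetric per-tensor quantization: scale = max|x| / 127, x_int8 = round(x / scale).\"\"\"\n",
        "    scale = x_fp32.abs().max().item() / 127.0\n",
        "    x_int8 = torch.clamp(torch.round(x_fp32 / scale), -128, 127).to(torch.int8)\n",
        "    return x_int8, scale\n",
        "\n",
        "torch.manual_seed(0)\n",
        "a_fp32 = torch.randn(128, 64, device='cuda')\n",
        "b_fp32 = torch.randn(64, 128, device='cuda')\n",
        "\n",
        "a_int8, scale_a = quantize_per_tensor(a_fp32)\n",
        "b_int8, scale_b = quantize_per_tensor(b_fp32)\n",
        "\n",
        "c_quant = int8_gemm_fp32_output(a_int8, b_int8, scale_a, scale_b)  # kernel defined above\n",
        "c_ref = a_fp32 @ b_fp32\n",
        "\n",
        "rel_err = (c_quant - c_ref).abs().mean() / c_ref.abs().mean()\n",
        "print(f\"mean relative error of the quantized GEMM: {rel_err.item():.4f}\")"
      ]
    }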
  ]
}