{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Dlv8N4uWtXcN" }, "source": [ "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", "
Chinese translation repository: https://github.com/GoatCsu/CN-LLMs-from-scratch.git\n", "
\n", "
\n", "\n", "
\n" ] }, { "cell_type": "markdown", "metadata": { "id": "V6BXGeEJ_s-8" }, "source": [ "# 理解PyTorch缓冲区作用" ] }, { "cell_type": "markdown", "metadata": { "id": "aQt9Ob1Y_8EH" }, "source": [ "本质上,PyTorch缓冲区是与PyTorch模块或模型相关联的张量属性,与参数类似.\n", "\n", "但不同于参数的是,缓冲区在训练过程中不会被更新。\n", "\n", "在处理GPU计算时,PyTorch缓冲区尤其重要,因为它们需要与模型的参数一起在设备间传输(如从CPU到GPU)。与参数不同,缓冲区不需要计算梯度,但仍需位于正确的设备上,以确保计算的准确性。\n", "\n", "在第三章中,我们通过`self.register_buffer`使用了PyTorch缓冲区,书中对此仅做了简要介绍。由于其概念和作用并不十分直观,本代码笔记本提供了更为详细的解释和实操示例。" ] }, { "cell_type": "markdown", "metadata": { "id": "dAwGo_gYLY45" }, "source": [ "## 无缓存区" ] }, { "cell_type": "markdown", "metadata": { "id": "0qBQC9IPAJVZ" }, "source": [ "假设我们有以下代码,基于第三章的代码,并已修改以排除缓冲区。该代码实现了LLM中使用的因果自注意力机制:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "7wx-_rokAN04" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "class CausalAttentionWithoutBuffers(nn.Module):\n", "\n", " def __init__(self, d_in, d_out, context_length,\n", " dropout, qkv_bias=False):\n", " super().__init__()\n", " self.d_out = d_out\n", " self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n", " self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n", " self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n", " self.dropout = nn.Dropout(dropout)\n", " self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n", "\n", " def forward(self, x):\n", " b, num_tokens, d_in = x.shape\n", " keys = self.W_key(x)\n", " queries = self.W_query(x)\n", " values = self.W_value(x)\n", "\n", " attn_scores = queries @ keys.transpose(1, 2)\n", " attn_scores.masked_fill_(\n", " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n", " attn_weights = torch.softmax(\n", " attn_scores / keys.shape[-1]**0.5, dim=-1\n", " )\n", " attn_weights = self.dropout(attn_weights)\n", "\n", " context_vec = attn_weights @ values\n", " return context_vec" ] }, { "cell_type": "markdown", "metadata": { "id": "nNrK-wLaNSi7" }, "source": [ 
"我们可以按照如下形式初始化模型并在在测试样例上运行" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "e1MZiIsPA0Py", "outputId": "ce1407c6-c082-4755-b8ad-d9adcc9f153a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[-0.4519, 0.2216],\n", " [-0.5874, 0.0058],\n", " [-0.6300, -0.0632],\n", " [-0.5675, -0.0843],\n", " [-0.5526, -0.0981],\n", " [-0.5299, -0.1081]],\n", "\n", " [[-0.4519, 0.2216],\n", " [-0.5874, 0.0058],\n", " [-0.6300, -0.0632],\n", " [-0.5675, -0.0843],\n", " [-0.5526, -0.0981],\n", " [-0.5299, -0.1081]]])\n" ] } ], "source": [ "torch.manual_seed(123)\n", "\n", "inputs = torch.tensor(\n", " [[0.43, 0.15, 0.89], # Your (x^1)\n", " [0.55, 0.87, 0.66], # journey (x^2)\n", " [0.57, 0.85, 0.64], # starts (x^3)\n", " [0.22, 0.58, 0.33], # with (x^4)\n", " [0.77, 0.25, 0.10], # one (x^5)\n", " [0.05, 0.80, 0.55]] # step (x^6)\n", ")\n", "\n", "batch = torch.stack((inputs, inputs), dim=0)\n", "context_length = batch.shape[1]\n", "d_in = inputs.shape[1]\n", "d_out = 2\n", "\n", "ca_without_buffer = CausalAttentionWithoutBuffers(d_in, d_out, context_length, 0.0)\n", "\n", "with torch.no_grad():\n", " context_vecs = ca_without_buffer(batch)\n", "\n", "print(context_vecs)" ] }, { "cell_type": "markdown", "metadata": { "id": "7_hqz6AgCCc1" }, "source": [ "到目前为止,一切都运行良好。\n", "\n", "然而,在训练LLM时,我们通常使用GPU来加速这一过程。因此,我们将把`CausalAttentionWithoutBuffers`模块转移到GPU设备上。\n", "\n", "这需要在配备GPU的环境中运行代码。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PYwn44HWCPJS", "outputId": "d7236e0c-2a43-4770-ccc1-03c9d5d11421" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Machine has GPU: True\n" ] } ], "source": [ "print(\"Machine has GPU:\", torch.cuda.is_available())\n", "\n", "batch = batch.to(\"cuda\")\n", "ca_without_buffer.to(\"cuda\");" ] }, { "cell_type": "markdown", "metadata": { "id": 
"4_lMki2_CoIR" }, "source": [ "再一次运行代码" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 338 }, "id": "KE9iLcjGC1V1", "outputId": "ab6921c7-d7dd-44ea-9b92-1911037e3dcc" }, "outputs": [ { "ename": "RuntimeError", "evalue": "expected self and mask to be on the same device, but got mask on cpu and self on cuda:0", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mcontext_vecs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mca_without_buffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext_vecs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1531\u001b[0m 
\u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1532\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1533\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1534\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1539\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1540\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1542\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1543\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m 
\u001b[0mattn_scores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mqueries\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mkeys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m attn_scores.masked_fill_(\n\u001b[0m\u001b[1;32m 24\u001b[0m self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n\u001b[1;32m 25\u001b[0m attn_weights = torch.softmax(\n", "\u001b[0;31mRuntimeError\u001b[0m: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0" ] } ], "source": [ "with torch.no_grad():\n", " context_vecs = ca_without_buffer(batch)\n", "\n", "print(context_vecs)" ] }, { "cell_type": "markdown", "metadata": { "id": "I7V26PLrC2gk" }, "source": [ "We get an error when running the code. What happened?\n", "It looks like we tried a matrix multiplication between a tensor on the GPU and a tensor on the CPU.\n", "But we already moved the module to the GPU!\n", "\n", "Let's double-check the device locations of some of the tensors:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "vvYDPBRIDHfU", "outputId": "4b9703a8-7035-4a2d-8643-c64d37b7abd2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "W_query.device: cuda:0\n", "mask.device: cpu\n" ] } ], "source": [ "print(\"W_query.device:\", ca_without_buffer.W_query.weight.device)\n", "print(\"mask.device:\", ca_without_buffer.mask.device)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "d11nX-FFOJ3C", "outputId": "1e92b0e8-dbc6-41f9-e88f-5d06e0726050" }, "outputs": [ { "data": { "text/plain": [ "torch.Tensor" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(ca_without_buffer.mask)" ] }, { "cell_type": "markdown", "metadata": { "id": "Ojay-KY-DL5M" }, "source": [ "As we can see, the `mask` was not moved onto the GPU. That's because it is not a PyTorch parameter like the weights (e.g., `W_query.weight`).\n", "\n", 
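"The underlying reason is that `Module.to()` only transfers *registered* tensors, i.e., parameters and buffers; plain tensor attributes are skipped. A minimal self-contained sketch (the `Toy` class and `stray_tensors` helper are made up for illustration, and rely on the detail that plain attributes live in the instance `__dict__`):\n", "\n", "```python\n", "import torch\n", "import torch.nn as nn\n", "\n", "class Toy(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.lin = nn.Linear(2, 2)    # registered: moved by .to()\n", "        self.mask = torch.ones(2, 2)  # plain attribute: NOT moved by .to()\n", "\n", "def stray_tensors(module):\n", "    # Parameters and buffers are tracked in separate dicts by nn.Module,\n", "    # so any tensor left in the instance __dict__ is unregistered.\n", "    return [name for name, attr in vars(module).items()\n", "            if isinstance(attr, torch.Tensor)]\n", "\n", "print(stray_tensors(Toy()))  # ['mask']\n", "```\n", "\n",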
"因此,我们需要通过`.to(\"cuda\")`手动将其移到GPU上:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QYirQ63zDYsW", "outputId": "304628ac-bc4c-49c2-a0e1-ecf9385ddcd9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mask.device: cuda:0\n" ] } ], "source": [ "ca_without_buffer.mask = ca_without_buffer.mask.to(\"cuda\")\n", "print(\"mask.device:\", ca_without_buffer.mask.device)" ] }, { "cell_type": "markdown", "metadata": { "id": "4OoTqzkpDfAm" }, "source": [ "再一次运行代码" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "WfF0yBZODdAZ", "outputId": "291cfb54-86e6-45f9-99d1-fa145319f379" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[-0.4519, 0.2216],\n", " [-0.5874, 0.0058],\n", " [-0.6300, -0.0632],\n", " [-0.5675, -0.0843],\n", " [-0.5526, -0.0981],\n", " [-0.5299, -0.1081]],\n", "\n", " [[-0.4519, 0.2216],\n", " [-0.5874, 0.0058],\n", " [-0.6300, -0.0632],\n", " [-0.5675, -0.0843],\n", " [-0.5526, -0.0981],\n", " [-0.5299, -0.1081]]], device='cuda:0')\n" ] } ], "source": [ "with torch.no_grad():\n", " context_vecs = ca_without_buffer(batch)\n", "\n", "print(context_vecs)" ] }, { "cell_type": "markdown", "metadata": { "id": "oUrVgWuuD7UE" }, "source": [ "这次,它成功了!\n", "\n", "然而,记得将每个张量手动移到GPU可能会很繁琐。正如我们将在下一节中看到的,使用`register_buffer`将`mask`注册为缓冲区会更为简便。" ] }, { "cell_type": "markdown", "metadata": { "id": "StS2wUrBLeuW" }, "source": [ "## 有了缓冲区的运行" ] }, { "cell_type": "markdown", "metadata": { "id": "nEqD2NFzPO6l" }, "source": [ "现在,让我们修改因果注意力类,将因果`mask`注册为缓冲区:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "ndsYj3Zf6N8U" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "class CausalAttentionWithBuffer(nn.Module):\n", "\n", " def __init__(self, d_in, d_out, context_length,\n", " dropout, qkv_bias=False):\n", " 
super().__init__()\n", " self.d_out = d_out\n", " self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n", " self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n", " self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n", " self.dropout = nn.Dropout(dropout)\n", " # Old:\n", " # self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n", "\n", " # New:\n", " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n", "\n", " def forward(self, x):\n", " b, num_tokens, d_in = x.shape\n", " keys = self.W_key(x)\n", " queries = self.W_query(x)\n", " values = self.W_value(x)\n", "\n", " attn_scores = queries @ keys.transpose(1, 2)\n", " attn_scores.masked_fill_(\n", " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n", " attn_weights = torch.softmax(\n", " attn_scores / keys.shape[-1]**0.5, dim=-1\n", " )\n", " attn_weights = self.dropout(attn_weights)\n", "\n", " context_vec = attn_weights @ values\n", " return context_vec" ] }, { "cell_type": "markdown", "metadata": { "id": "_AL1X6y3Eb7S" }, "source": [ "Conveniently, if we move the module to the GPU, the `mask` will be moved to the GPU, too:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8_VCxEa76j00", "outputId": "4d1af501-5a9e-46aa-b1ac-63bf0c68e02a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "W_query.device: cuda:0\n", "mask.device: cuda:0\n" ] } ], "source": [ "ca_with_buffer = CausalAttentionWithBuffer(d_in, d_out, context_length, 0.0)\n", "ca_with_buffer.to(\"cuda\")\n", "\n", "print(\"W_query.device:\", ca_with_buffer.W_query.weight.device)\n", "print(\"mask.device:\", ca_with_buffer.mask.device)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TBWvKlMe7bbB", "outputId": "e43bf8ab-3fb9-417e-d087-560858332d86" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[0.4772, 
0.1063],\n", " [0.5891, 0.3257],\n", " [0.6202, 0.3860],\n", " [0.5478, 0.3589],\n", " [0.5321, 0.3428],\n", " [0.5077, 0.3493]],\n", "\n", " [[0.4772, 0.1063],\n", " [0.5891, 0.3257],\n", " [0.6202, 0.3860],\n", " [0.5478, 0.3589],\n", " [0.5321, 0.3428],\n", " [0.5077, 0.3493]]], device='cuda:0')\n" ] } ], "source": [ "with torch.no_grad():\n", " context_vecs = ca_with_buffer(batch)\n", "\n", "print(context_vecs)" ] }, { "cell_type": "markdown", "metadata": { "id": "xvOTh4NNPjef" }, "source": [ "As we can see above, registering a tensor as a buffer can make our lives a lot easier: We don't have to remember to move tensors to a target device like a GPU manually." ] }, { "cell_type": "markdown", "metadata": { "id": "Q-5YYKmJte3h" }, "source": [ "## Buffers and `state_dict`" ] }, { "cell_type": "markdown", "metadata": { "id": "YIHHawPbtjfp" }, "source": [ "- Another advantage of PyTorch buffers over regular tensors is that they are included in a model's `state_dict`.\n", "- For example, consider the `state_dict` of the causal attention object without buffers:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "c217juzqtxsS", "outputId": "dbae3c3d-f4f8-4c70-a64f-90906561d8d9" }, "outputs": [ { "data": { "text/plain": [ "OrderedDict([('W_query.weight',\n", " tensor([[-0.2354, 0.0191, -0.2867],\n", " [ 0.2177, -0.4919, 0.4232]], device='cuda:0')),\n", " ('W_key.weight',\n", " tensor([[-0.4196, -0.4590, -0.3648],\n", " [ 0.2615, -0.2133, 0.2161]], device='cuda:0')),\n", " ('W_value.weight',\n", " tensor([[-0.4900, -0.3503, -0.2120],\n", " [-0.1135, -0.4404, 0.3780]], device='cuda:0'))])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ca_without_buffer.state_dict()" ] }, { "cell_type": "markdown", "metadata": { "id": "NdmZuPaqt6aO" }, "source": [ "- The `mask` is not included in the `state_dict` above.\n", "- However, the `mask` is included in the `state_dict` below, thanks to registering it as a buffer." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uGIGQAwPt1Pl", "outputId": 
"00f9bc44-63f9-4ebc-87ea-d4b8cafd81c1" }, "outputs": [ { "data": { "text/plain": [ "OrderedDict([('mask',\n", " tensor([[0., 1., 1., 1., 1., 1.],\n", " [0., 0., 1., 1., 1., 1.],\n", " [0., 0., 0., 1., 1., 1.],\n", " [0., 0., 0., 0., 1., 1.],\n", " [0., 0., 0., 0., 0., 1.],\n", " [0., 0., 0., 0., 0., 0.]], device='cuda:0')),\n", " ('W_query.weight',\n", " tensor([[-0.1362, 0.1853, 0.4083],\n", " [ 0.1076, 0.1579, 0.5573]], device='cuda:0')),\n", " ('W_key.weight',\n", " tensor([[-0.2604, 0.1829, -0.2569],\n", " [ 0.4126, 0.4611, -0.5323]], device='cuda:0')),\n", " ('W_value.weight',\n", " tensor([[ 0.4929, 0.2757, 0.2516],\n", " [ 0.2377, 0.4800, -0.0762]], device='cuda:0'))])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ca_with_buffer.state_dict()" ] }, { "cell_type": "markdown", "metadata": { "id": "ACC-a1Hnt4Zv" }, "source": [ "- `state_dict`在保存和加载训练好的PyTorch模型时非常有用,例如。\n", "- 在这个特定的情况下,保存和加载`mask`可能并不是特别有用,因为它在训练过程中保持不变;因此,出于演示目的,我们假设它被修改了,将所有的`1`改为`2`:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "RLm1Sw0cuhvy", "outputId": "4b2cc70f-1709-44e4-aa17-4e01353b86f8" }, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 2., 2., 2., 2., 2.],\n", " [0., 0., 2., 2., 2., 2.],\n", " [0., 0., 0., 2., 2., 2.],\n", " [0., 0., 0., 0., 2., 2.],\n", " [0., 0., 0., 0., 0., 2.],\n", " [0., 0., 0., 0., 0., 0.]], device='cuda:0')" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ca_with_buffer.mask[ca_with_buffer.mask == 1.] 
= 2.\n", "ca_with_buffer.mask" ] }, { "cell_type": "markdown", "metadata": { "id": "BIkGgGqqvp4S" }, "source": [ "- Then, if we save and load the model, we can see that the `mask` is restored with the modified value." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "e8g0QHUhuVBw", "outputId": "cc7ee348-7f94-4117-e5cc-e0e01a94e906" }, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 2., 2., 2., 2., 2.],\n", " [0., 0., 2., 2., 2., 2.],\n", " [0., 0., 0., 2., 2., 2.],\n", " [0., 0., 0., 0., 2., 2.],\n", " [0., 0., 0., 0., 0., 2.],\n", " [0., 0., 0., 0., 0., 0.]])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.save(ca_with_buffer.state_dict(), \"model.pth\")\n", "\n", "new_ca_with_buffer = CausalAttentionWithBuffer(d_in, d_out, context_length, 0.0)\n", "new_ca_with_buffer.load_state_dict(torch.load(\"model.pth\"))\n", "\n", "new_ca_with_buffer.mask" ] }, { "cell_type": "markdown", "metadata": { "id": "0pPaJk7bvBD7" }, "source": [ "- This is not the case if we don't use buffers:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "D03w8vDyvBRS", "outputId": "28071601-120c-42da-b327-bb293793839f" }, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 1., 1., 1., 1., 1.],\n", " [0., 0., 1., 1., 1., 1.],\n", " [0., 0., 0., 1., 1., 1.],\n", " [0., 0., 0., 0., 1., 1.],\n", " [0., 0., 0., 0., 0., 1.],\n", " [0., 0., 0., 0., 0., 0.]])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ca_without_buffer.mask[ca_without_buffer.mask == 1.] 
= 2.\n", "\n", "torch.save(ca_without_buffer.state_dict(), \"model.pth\")\n", "\n", "new_ca_without_buffer = CausalAttentionWithoutBuffers(d_in, d_out, context_length, 0.0)\n", "new_ca_without_buffer.load_state_dict(torch.load(\"model.pth\"))\n", "\n", "new_ca_without_buffer.mask" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "L4", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.6" } }, "nbformat": 4, "nbformat_minor": 4 }