\n",
"
\n",
"
\n",
" \n",
"\n",
" \n",
" \n",
" \n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"
\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# 定义训练函数:\n",
"def train(train_loader, net, epochs=5, total_iterations_limit=None):\n",
" cross_el = nn.CrossEntropyLoss()\n",
" optimizer = torch.optim.Adam(net.parameters(), lr=0.001)\n",
"\n",
" total_iterations = 0\n",
"\n",
" for epoch in range(epochs):\n",
" net.train()\n",
"\n",
" loss_sum = 0\n",
" num_iterations = 0\n",
"\n",
" data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')\n",
" if total_iterations_limit is not None:\n",
" data_iterator.total = total_iterations_limit\n",
" for data in data_iterator:\n",
" num_iterations += 1\n",
" total_iterations += 1\n",
" x, y = data\n",
" x = x.to(device)\n",
" y = y.to(device)\n",
" optimizer.zero_grad()\n",
" output = net(x.view(-1, 28*28))\n",
" loss = cross_el(output, y)\n",
" loss_sum += loss.item()\n",
" avg_loss = loss_sum / num_iterations\n",
" data_iterator.set_postfix(loss=avg_loss)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if total_iterations_limit is not None and total_iterations >= total_iterations_limit:\n",
" return"
],
"metadata": {
"id": "QKLuOJe92iFF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 下载MNIST手写体数字识别的数据\n",
"transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n",
"\n",
"# 加载手写体数据:\n",
"mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) # 训练集\n",
"train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)\n",
"mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform) # 测试集\n",
"\n",
"# 去掉数字‘1'的数据,模型对‘1'的识别率存在问题\n",
"exclude_indices = torch.tensor([False if x == 1 else True for x in mnist_trainset.targets])\n",
"mnist_trainset.data = mnist_trainset.data[exclude_indices]\n",
"mnist_trainset.targets = mnist_trainset.targets[exclude_indices]\n",
"\n",
"# 训练模型:\n",
"train(train_loader, net, epochs=1, total_iterations_limit=2000)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WCgiS0a6-2E8",
"outputId": "4f2dc110-4e5b-4104-b166-558ec5c2fb3f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 9.91M/9.91M [00:00<00:00, 38.7MB/s]\n",
"100%|██████████| 28.9k/28.9k [00:00<00:00, 1.18MB/s]\n",
"100%|██████████| 1.65M/1.65M [00:00<00:00, 6.65MB/s]\n",
"100%|██████████| 4.54k/4.54k [00:00<00:00, 3.36MB/s]\n",
"Epoch 1: 100%|█████████▉| 1999/2000 [00:44<00:00, 45.14it/s, loss=0.327]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# 加载测试数据,观测结果:\n",
"test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)\n",
"def test(model=net):\n",
" total = 0\n",
" wrong_counts = [0 for i in range(10)]\n",
" correct_counts = [0 for i in range(10)]\n",
"\n",
" with torch.no_grad():\n",
" for data in tqdm(test_loader, desc='Testing'):\n",
" x, y = data\n",
" x = x.to(device)\n",
" y = y.to(device)\n",
" output = model(x.view(-1, 784))\n",
" for idx, i in enumerate(output):\n",
" if torch.argmax(i) == y[idx]:\n",
" correct_counts[y[idx]] +=1\n",
" else:\n",
" wrong_counts[y[idx]] +=1\n",
" total +=1\n",
" result_str = \"\"\n",
" for i in range(len(wrong_counts)):\n",
" result_str += f'The wrong counts of digit {i}: {wrong_counts[i]}\\n'\n",
" print(f'\\nAccuracy: {round(sum(correct_counts)/total, 3)}\\n{result_str}')\n",
" return [x / (x+y) for x, y in zip(correct_counts, wrong_counts)]\n",
"\n",
"test()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Zm31BDpU2nMG",
"outputId": "b9ad7bb2-2b69-4e7d-9d3c-a32dfb1f495f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Testing: 100%|██████████| 1000/1000 [00:03<00:00, 288.46it/s]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"Accuracy: 0.84\n",
"The wrong counts of digit 0: 12\n",
"The wrong counts of digit 1: 1135\n",
"The wrong counts of digit 2: 81\n",
"The wrong counts of digit 3: 35\n",
"The wrong counts of digit 4: 47\n",
"The wrong counts of digit 5: 61\n",
"The wrong counts of digit 6: 31\n",
"The wrong counts of digit 7: 73\n",
"The wrong counts of digit 8: 50\n",
"The wrong counts of digit 9: 75\n",
"\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[0.9877551020408163,\n",
" 0.0,\n",
" 0.9215116279069767,\n",
" 0.9653465346534653,\n",
" 0.9521384928716904,\n",
" 0.9316143497757847,\n",
" 0.9676409185803758,\n",
" 0.9289883268482491,\n",
" 0.9486652977412731,\n",
" 0.9256689791873142]"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "markdown",
"source": [
"## 2.2 LoRA微调"
],
"metadata": {
"id": "ZFKtrH8aXw9d"
}
},
{
"cell_type": "code",
"source": [
"# 定义LoRA对权重修改修改:\n",
"class LoRAParametrization(nn.Module):\n",
" def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):\n",
" super().__init__()\n",
" # 低秩矩阵的定义:\n",
" self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))\n",
" self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))\n",
" nn.init.normal_(self.lora_A, mean=0, std=1)\n",
"\n",
" # 参考论文:https://arxiv.org/pdf/2106.09685 4.1节 设置一个比例系数:\n",
" self.scale = alpha / rank\n",
" # LoRA开关:\n",
" self.enabled = True\n",
"\n",
" def forward(self, original_weights):\n",
" if self.enabled:\n",
" # Return W + (B*A)*scale\n",
" return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale\n",
" else:\n",
" return original_weights"
],
"metadata": {
"id": "_RdVOz-z9GTC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 可视化LoRA层\n",
"lora_param = LoRAParametrization(1000, 784)\n",
"example_input = torch.randn(1000, 784)\n",
"trace_model(\n",
" lora_param,\n",
" example_input,\n",
" collapse_modules_after_depth=3,\n",
" show_non_gradient_nodes=False,\n",
" forced_module_tracing_depth=None,\n",
" height=500\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 517
},
"id": "PGYoC3W6Lwsx",
"outputId": "bd269da2-3951-485b-e0eb-688cf106fd83"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"