"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"metrics = pd.read_csv(f\"{trainer.logger.log_dir}/metrics.csv\")\n",
"\n",
"aggreg_metrics = []\n",
"agg_col = \"epoch\"\n",
"for i, dfg in metrics.groupby(agg_col):\n",
" agg = dict(dfg.mean())\n",
" agg[agg_col] = i\n",
" aggreg_metrics.append(agg)\n",
"\n",
"df_metrics = pd.DataFrame(aggreg_metrics)\n",
"df_metrics[[\"train_loss\", \"valid_loss\"]].plot(\n",
" grid=True, legend=True, xlabel='Epoch', ylabel='Loss')\n",
"df_metrics[[\"train_acc\", \"valid_acc\"]].plot(\n",
" grid=True, legend=True, xlabel='Epoch', ylabel='ACC')"
]
},
{
"cell_type": "markdown",
"id": "18304f53",
"metadata": {},
"source": [
"- The `trainer` automatically saves the model with the best validation accuracy automatically for us, we which we can load from the checkpoint via the `ckpt_path='best'` argument; below we use the `trainer` instance to evaluate the best model on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "8bff53a0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Restoring states from the checkpoint path at logs/my-model/version_36/checkpoints/epoch=8-step=74600.ckpt\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"Loaded model weights from checkpoint at logs/my-model/version_36/checkpoints/epoch=8-step=74600.ckpt\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "80900a8ff1054ce0abdc8a2cbcc7b646",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Testing: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃ Test metric ┃ DataLoader 0 ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ test_acc │ 0.9217909574508667 │\n",
"└───────────────────────────┴───────────────────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│\u001b[36m \u001b[0m\u001b[36m test_acc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9217909574508667 \u001b[0m\u001b[35m \u001b[0m│\n",
"└───────────────────────────┴───────────────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[{'test_acc': 0.9217909574508667}]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.test(model=lightning_model, datamodule=data_module, ckpt_path='best')"
]
},
{
"cell_type": "markdown",
"id": "ebe513ab",
"metadata": {},
"source": [
"## Predicting labels of new data"
]
},
{
"cell_type": "markdown",
"id": "f0674611",
"metadata": {},
"source": [
"- You can use the `trainer.predict` method on a new `DataLoader` or `DataModule` to apply the model to new data.\n",
"- Alternatively, you can also manually load the best model from a checkpoint as shown below:"
]
},
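{
"cell_type": "markdown",
"id": "7a1f2b3c",
"metadata": {},
"source": [
"- Below is a minimal sketch of the `trainer.predict` route (not executed here). It assumes the `LightningModel` implements a `predict_step` method; the default `predict_step` forwards the whole batch to the model, which would not work for our `(features, labels)` batches:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e5d6c7f",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch, assuming LightningModel defines a predict_step such as\n",
"#\n",
"#     def predict_step(self, batch, batch_idx):\n",
"#         features, _ = batch  # discard the labels\n",
"#         return torch.argmax(self(features), dim=1)\n",
"\n",
"predictions = trainer.predict(\n",
"    model=lightning_model,\n",
"    dataloaders=data_module.test_dataloader(),\n",
"    ckpt_path='best')\n",
"\n",
"# trainer.predict returns a list with one entry per batch\n",
"all_predictions = torch.cat(predictions)"
]
},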
{
"cell_type": "code",
"execution_count": 21,
"id": "99fb98a9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"logs/my-model/version_36/checkpoints/epoch=8-step=74600.ckpt\n"
]
}
],
"source": [
"path = trainer.checkpoint_callback.best_model_path\n",
"print(path)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "ab60d544",
"metadata": {},
"outputs": [],
"source": [
"lightning_model = LightningModel.load_from_checkpoint(\n",
" path, model=pytorch_model)\n",
"lightning_model.eval();"
]
},
{
"cell_type": "markdown",
"id": "eb52e61c",
"metadata": {},
"source": [
"- Note that our PyTorch model, which is passed to the Lightning model, requires input arguments. However, this is automatically being taken care of since we used `self.save_hyperparameters()` in our PyTorch model's `__init__` method.\n",
"- Now, below is an example applying the model manually. Here, pretend that the `test_dataloader` is a new data loader."
]
},
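{
"cell_type": "markdown",
"id": "9b8a7c6d",
"metadata": {},
"source": [
"- For reference, the relevant pattern typically looks like the sketch below (assumed here; the class itself is defined earlier in the notebook). `save_hyperparameters(ignore=['model'])` stores all `__init__` arguments except the raw `model` object in the checkpoint, which is why we still pass `model=pytorch_model` to `load_from_checkpoint`:\n",
"\n",
"```python\n",
"class LightningModel(pl.LightningModule):\n",
"    def __init__(self, model, learning_rate):\n",
"        super().__init__()\n",
"        self.learning_rate = learning_rate\n",
"        self.model = model\n",
"        # stores learning_rate (but not the model object) in the checkpoint\n",
"        self.save_hyperparameters(ignore=['model'])\n",
"```"
]
},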
{
"cell_type": "code",
"execution_count": 23,
"id": "b544b139",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([7, 1, 3, 4, 9])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_dataloader = data_module.test_dataloader()\n",
"\n",
"all_true_labels = []\n",
"all_predicted_labels = []\n",
"for batch in test_dataloader:\n",
" features, labels = batch\n",
" \n",
" with torch.no_grad(): # since we don't need to backprop\n",
" logits = lightning_model(features)\n",
"\n",
" predicted_labels = torch.argmax(logits, dim=1)\n",
" all_predicted_labels.append(predicted_labels)\n",
" all_true_labels.append(labels)\n",
" \n",
"all_predicted_labels = torch.cat(all_predicted_labels)\n",
"all_true_labels = torch.cat(all_true_labels)\n",
"all_predicted_labels[:5]"
]
},
{
"cell_type": "markdown",
"id": "a9afe2e3",
"metadata": {},
"source": [
"Just as an internal check, if the model was loaded correctly, the test accuracy below should be identical to the test accuracy we saw earlier in the previous section."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "89cd4554",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9218 (92.18%)\n"
]
}
],
"source": [
"test_acc = torch.mean((all_predicted_labels == all_true_labels).float())\n",
"print(f'Test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)')"
]
},
{
"cell_type": "markdown",
"id": "74a2323e",
"metadata": {},
"source": [
"## Single-image usage"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "331fa78f",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "694f1e4c",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"id": "62f39922",
"metadata": {},
"source": [
"- Assume we have a single image as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "0882c122",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAPzklEQVR4nO3de4wVZZrH8d8jN7mMXKTBltFtmeAta9YZW1lhI24mi0pIZGIk+IdhExP8QxJGNK4OCRhjlGy84B/emIUMa2acjGGI93UMGWNGkomtQcAlCEvaGaRDN4IOKqhNP/tHl5MWu95qquqcKnm/n6Rzuuvpt+vx2D/q9Hmr6jV3F4BT32lVNwCgOQg7EAnCDkSCsAORIOxAJIY3c2eTJ0/2tra2Zu4SiEpnZ6cOHjxog9UKhd3MrpX0mKRhkv7L3VeHvr+trU0dHR1FdgkgoL29PbWW+2W8mQ2T9Lik6yRdLOkmM7s4788D0FhF/ma/QtIed9/r7l9J+q2k68tpC0DZioR9mqS/Dvh6X7LtW8xsiZl1mFlHT09Pgd0BKKJI2Ad7E+A75966+1p3b3f39paWlgK7A1BEkbDvk3TOgK9/KGl/sXYANEqRsL8taYaZnWdmIyUtkvRCOW0BKFvuqTd37zWzpZJeU//U23p3f7+0zgCUqtA8u7u/IumVknoB0ECcLgtEgrADkSDsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5EoqlLNhfV3d2dWpswYUJw7PHjx4P1Y8eO5WlJkmQ26Aq5f3fGGWcE66edFue/uV1dXcF6a2trkzqJQ5y/ZUCECDsQCcIORIKwA5Eg7EAkCDsQCcIORKJW8+xff/11sH722Wen1rLmunt7e3P1VIYpU6YE61dddVWwvnz58mD9yiuvPOmemmXLli2ptdmzZzd032PGjEmtzZ8/Pzj2wQcfDNanT5+eq6cqFQq7mXVKOiLpuKRed28voykA5SvjyP6v7n6whJ8DoIH4mx2IRNGwu6Q/mNk7ZrZksG8wsyVm1mFmHT09PQV3ByCvomGf7e4/kXSdpNvM7DvvNLn7Wndvd/f2lpaWgrsDkFehsLv7/uSxW9ImSVeU0RSA8uUOu5mNNbMffPO5pLmSdpTVGIByFXk3fqqkTcn89nBJv3H3/ynSTF9fX7AeuiZ92bJlwbFz584N1seNGxesjxgxIrWWdX7Aa6+9Fqw/8cQTwXrW87Jx48ZgvUovvvhi7rFPP/10sD5q1Khg/eOPP06tPfTQQ8Gxq1atCtafeeaZYL2Ocofd3fdK+qcSewHQQEy9AZEg7EAkCDsQCcIORIKwA5Go1SWuRW6pnHWZ57x583L/7KKyLmE9/fTTg/U1a9YE6+6eWsu69LfRPvnkk9TaWWedFRy7ZMmgZ2A3xYoVK4L1devWBesjR44ss51ScGQHIkHYgUgQdiAShB2IBGEHIkHYgUgQdiASp8w8e9aSzHU2Z86cYH3lypXB+sGD6ff7rPruQG+88UZqbebMmc1r5AQLFiwI1u+4445gfdu2bcF6e3v9brTMkR2IBGEHIkHYgUgQdiAShB2IBGEHIkHYgUjUap69yLXXWbdzrrNp06YVGn/48OHUWqPn2Y8ePRqs79q1K7WWdZ1/I5133nnBemh5cEl67rnngnXm2QFUhrADkSDsQCQIOxAJwg5EgrADkSDsQCRqNc+edT17aB6+t7e37HaaZvjwYv8bqryWf8+ePcF66J72WUsuN1LWOR2LFi0K1rOWbF69enWh/TdC5pHdzNabWbeZ7RiwbZKZvW5mu5PHiY1tE0BRQ3kZ/ytJ156w7W5Jm919hqTNydcAaiwz7O7+pqRDJ2y+XtKG5PMNkhaU2xaAsuV9g26qu3dJUvI4Je0bzWyJmXWYWUdPT0/O3QEoquHvxrv7Wndvd/f2qm9+CMQsb9gPmFmrJCWP3eW1BKAR8ob9BUmLk88XS3q+nHYANErmBK+ZPSvpakmTzWyfpFWSVkv6nZndIukvkm5sZJPfCM3Df5/vG1/kfvlStdfyb9++PffYOq5h/o2FCxcG64888kiwnnX+wYwZM066p6Iyw+7uN6WUflpyLwAaiNNlgUgQdiAShB2IBGEHIkHYgUjU6hLX0OWQWfWil4lWqWjva9asSa2NHz8+ODbrFOaurq5gfevWrcF6yOjRo3OPbbTLLrssWD/33HOD9fvvvz9Y37BhQ7DeCBzZgUgQdiAShB2IBGEHIkHYgUgQdiAShB2IRK0mp7Mu1ezr60utZc0n11nRefZNmzal1rKWJs5aLrqtrS1Yz7qUc+LE9BsPL126NDi2Sln/T7Lm0RcvXhysh2413draGhybF0d2IBKEHYgEYQciQdiBSBB2IBKEHYgEYQci8b2aZw/5Ps+zDxs2rND4l19+ObU2a9asQj87y5w5c4L1Cy64ILU2derUsttpmqxbTd9+++3Bemie/bHHHsvVUxaO7EAkCDsQCcIORIKwA5Eg7EAkCDsQCcIORKJW8+xffvll7rFjxowpsZPmKjrP3tvbW1InJ2/Xrl3B+ty5c5vUSXONGjUqWF+5cmWwvnz58tTaqlWrgmMnTZoUrKfJPLKb2Xoz6zazHQO23WtmH5nZ1uRjXq69A2iaobyM/5WkawfZ/qi7X5p8vFJuWwDKlhl2d39T0qEm9AKggYq8QbfUzLYlL/NTbzRmZkvMrMPMOrLWFQPQOHnD/qSkH0m6VFKXpIfTvtHd17p7u7u3t7S05NwdgKJyhd3dD7j7cXfvk/RLSVeU2xaAsuUKu5kNvNftzyTtSPteAPWQOc9uZs9KulrSZDPbJ2mVpKvN7FJJLqlT0q1lNHPs2LHcY7/P8+ynnVbs3KZGzrNn3WOgu7s7WM+6b/2pKuu+8cuWLUutbdmyJTh2/vz5uXrKDLu73zTI5nW59gagMpwuC0SCsAORIOxAJAg7EAnCDkSiVpe4Hj16NPfY0aNHB+t79+4N1g8cOJB731mXqF5yySXBetGptyK34M7y6aefBuvuHqyfqlNvWdPEDzzwQO6f3agzTTmyA5Eg7EAkCDsQCcIORIKwA5Eg7EAkCDsQiVrNs3/xxRe5x15zzTXB+ocffpj7Zxc1fHj4ab7hhhsK/fyvvvqq0PiQorcSa21tzf6mCmSd07F+/fpgPetW0UeOHAnWH3/88dTa5ZdfHhybF0d2IBKEHYgEYQciQdiBSBB2IBKEHYgEYQciccrMs+/fvz9Yv/HGG4P1iy66KFgPXWOcdS39e++9F6xnzelm+eCDDwqND/noo48KjT/zzDNL6uS7sn5fnnrqqdTafffdFxz72WefBetLly4N1lesWBGsV7E6Ekd2IBKEHYgEYQciQdiBSBB2IBKEHYgEYQciUat59gsvvDBYnzVrVmqtr68vOHb79u3B+quvvhqsZ827FjFu3LhgfebMmcH6woULy2znW3bv3h2sjxw5MlgP/bcdOnQoOHbNmjXB+
sMPPxysh+6nH1oyWZLuueeeYH3SpEnBeh1lHtnN7Bwz+6OZ7TSz981sWbJ9kpm9bma7k8eJjW8XQF5DeRnfK+kOd79I0j9Lus3MLpZ0t6TN7j5D0ubkawA1lRl2d+9y93eTz49I2ilpmqTrJW1Ivm2DpAUN6hFACU7qDToza5P0Y0l/ljTV3buk/n8QJE1JGbPEzDrMrKPo/cwA5DfksJvZOEkbJf3c3f821HHuvtbd2929vYqT/wH0G1LYzWyE+oP+a3f/fbL5gJm1JvVWSd2NaRFAGTKn3szMJK2TtNPdHxlQekHSYkmrk8fnizYzfvz4YP2tt94quovcQtM4n3/+eXBs1rTd1KlTg/URI0YE643U2dkZrGdNG956662ptXXr1gXHjhkzJli/6667gvXQ9NqECROCY09FQ5lnny3pZknbzWxrsu0X6g/578zsFkl/kRS+YBxApTLD7u5/kmQp5Z+W2w6ARuF0WSAShB2IBGEHIkHYgUgQdiAStbrEtc5Cc91Zc7bf5zndw4cPB+tZl6m+9NJLqbUnn3wyOPbmm28O1rNu4Y1v48gORIKwA5Eg7EAkCDsQCcIORIKwA5Eg7EAkmGdH0KOPPhqs33nnncH69OnTU2vDh/Pr10wc2YFIEHYgEoQdiARhByJB2IFIEHYgEoQdiAQTnQgaO3ZssH7++ec3qRMUxZEdiARhByJB2IFIEHYgEoQdiARhByJB2IFIZIbdzM4xsz+a2U4ze9/MliXb7zWzj8xsa/Ixr/HtAshrKCfV9Eq6w93fNbMfSHrHzF5Pao+6+0ONaw9AWYayPnuXpK7k8yNmtlPStEY3BqBcJ/U3u5m1SfqxpD8nm5aa2TYzW29mE1PGLDGzDjPr6OnpKdYtgNyGHHYzGydpo6Sfu/vfJD0p6UeSLlX/kf/hwca5+1p3b3f39paWluIdA8hlSGE3sxHqD/qv3f33kuTuB9z9uLv3SfqlpCsa1yaAoobybrxJWidpp7s/MmB764Bv+5mkHeW3B6AsQ3k3frakmyVtN7OtybZfSLrJzC6V5JI6Jd3agP4AlGQo78b/SZINUnql/HYANApn0AGRIOxAJAg7EAnCDkSCsAORIOxAJAg7EAnCDkSCsAORIOxAJAg7EAnCDkSCsAORIOxAJMzdm7czsx5JHw7YNFnSwaY1cHLq2ltd+5LoLa8ye/sHdx/0/m9NDft3dm7W4e7tlTUQUNfe6tqXRG95Nas3XsYDkSDsQCSqDvvaivcfUtfe6tqXRG95NaW3Sv9mB9A8VR/ZATQJYQciUUnYzexaM9tlZnvM7O4qekhjZp1mtj1Zhrqj4l7Wm1m3me0YsG2Smb1uZruTx0HX2Kuot1os4x1YZrzS567q5c+b/je7mQ2T9IGkf5O0T9Lbkm5y9/9taiMpzKxTUru7V34ChpldJekzSf/t7v+YbPtPSYfcfXXyD+VEd/+PmvR2r6TPql7GO1mtqHXgMuOSFkj6d1X43AX6WqgmPG9VHNmvkLTH3fe6+1eSfivp+gr6qD13f1PSoRM2Xy9pQ/L5BvX/sjRdSm+14O5d7v5u8vkRSd8sM17pcxfoqymqCPs0SX8d8PU+1Wu9d5f0BzN7x8yWVN3MIKa6e5fU/8sjaUrF/ZwocxnvZjphmfHaPHd5lj8vqoqwD7aUVJ3m/2a7+08kXSfptuTlKoZmSMt4N8sgy4zXQt7lz4uqIuz7JJ0z4OsfStpfQR+Dcvf9yWO3pE2q31LUB75ZQTd57K64n7+r0zLegy0zrho8d1Uuf15F2N+WNMPMzjOzkZIWSXqhgj6+w8zGJm+cyMzGSpqr+i1F/YKkxcnniyU9X2Ev31KXZbzTlhlXxc9d5cufu3vTPyTNU/878v8naUUVPaT0NV3Se8nH+1X3JulZ9b+s+1r9r4hukXSmpM2SdiePk2rU2zOStkvapv5gtVbU27+o/0/DbZK2Jh/zqn7uAn015XnjdFkgEpxBB0SCsAORIOxAJAg7EAnCDkSCsAORIOxAJP4fUwOrijzVq+YAAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from PIL import Image\n",
"\n",
"\n",
"image = Image.open('./quickdraw-png_set1/binoculars/binoculars_093136.png')\n",
"plt.imshow(image, cmap='Greys')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "2625ac92",
"metadata": {},
"source": [
"- Note that we used a resize-transformation in the `DataModule` that rescaled the 28x28 images to size 32x32. We also have to apply the same transformation to any new image that we feed to the model:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "64dafca9",
"metadata": {},
"outputs": [],
"source": [
"resize_transform = transforms.Compose(\n",
" [transforms.Resize((32, 32)),\n",
" transforms.ToTensor()])\n",
"\n",
"image_chw = resize_transform(image)"
]
},
{
"cell_type": "markdown",
"id": "003c8d20",
"metadata": {},
"source": [
"- Note that `ToTensor` returns the image in the CHW format. CHW refers to the dimensions and stands for channel, height, and width."
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "02845a79",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 32, 32])\n"
]
}
],
"source": [
"print(image_chw.shape)"
]
},
{
"cell_type": "markdown",
"id": "0ecfcca3",
"metadata": {},
"source": [
"- However, the PyTorch / PyTorch Lightning model expectes images in NCHW format, where N stands for the number of images (e.g., in a batch).\n",
"- We can add the additional channel dimension via `unsqueeze` as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "b442de8b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 1, 32, 32])\n"
]
}
],
"source": [
"image_nchw = image_chw.unsqueeze(0)\n",
"print(image_nchw.shape)"
]
},
{
"cell_type": "markdown",
"id": "25690363",
"metadata": {},
"source": [
"- Now that we have the image in the right format, we can feed it to our classifier:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "db887578",
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad(): # since we don't need to backprop\n",
" logits = lightning_model(image_nchw)\n",
" probas = torch.softmax(logits, axis=1)\n",
" predicted_label = torch.argmax(probas)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "c8fdd29f",
"metadata": {},
"outputs": [],
"source": [
"label_dict = {\n",
" 0: \"lollipop\",\n",
" 1: \"binoculars\",\n",
" 2: \"mouse\",\n",
" 3: \"basket\",\n",
" 4: \"penguin\",\n",
" 5: \"washing machine\",\n",
" 6: \"canoe\",\n",
" 7: \"eyeglasses\",\n",
" 8: \"beach\",\n",
" 9: \"screwdriver\",\n",
"}"
]
},
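{
"cell_type": "markdown",
"id": "5f4e3d2c",
"metadata": {},
"source": [
"- Optionally, we can look beyond the single most probable class; below is a minimal sketch using `torch.topk` (reusing `probas` from above) to print the three most probable classes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d",
"metadata": {},
"outputs": [],
"source": [
"# Show the three most probable classes for this image\n",
"top_probas, top_indices = torch.topk(probas, k=3, dim=1)\n",
"for p, idx in zip(top_probas[0], top_indices[0]):\n",
"    print(f'{label_dict[idx.item()]}: {p.item()*100:.2f}%')"
]
},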
{
"cell_type": "code",
"execution_count": 35,
"id": "be912947",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted label: binoculars\n",
"Class-membership probability 99.95%\n"
]
}
],
"source": [
"print(f'Predicted label: {label_dict[predicted_label.item()]}')\n",
"print(f'Class-membership probability {probas[0][predicted_label]*100:.2f}%')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}