"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"metrics = pd.read_csv(f\"{trainer.logger.log_dir}/metrics.csv\")\n",
"\n",
"aggreg_metrics = []\n",
"agg_col = \"epoch\"\n",
"for i, dfg in metrics.groupby(agg_col):\n",
" agg = dict(dfg.mean())\n",
" agg[agg_col] = i\n",
" aggreg_metrics.append(agg)\n",
"\n",
"df_metrics = pd.DataFrame(aggreg_metrics)\n",
"df_metrics[[\"train_loss\", \"valid_loss\"]].plot(\n",
" grid=True, legend=True, xlabel='Epoch', ylabel='Loss')\n",
"df_metrics[[\"train_acc\", \"valid_acc\"]].plot(\n",
" grid=True, legend=True, xlabel='Epoch', ylabel='ACC')"
]
},
{
"cell_type": "markdown",
"id": "2cbe0151",
"metadata": {},
"source": [
"- The `trainer` automatically saves the model with the best validation accuracy automatically for us, we which we can load from the checkpoint via the `ckpt_path='best'` argument; below we use the `trainer` instance to evaluate the best model on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b726c1a0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Restoring states from the checkpoint path at logs/my-model/version_22/checkpoints/epoch=10-step=2353.ckpt\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"Loaded model weights from checkpoint at logs/my-model/version_22/checkpoints/epoch=10-step=2353.ckpt\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "02006ec6d039406f8c71dcda15169e82",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Testing: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃ Test metric ┃ DataLoader 0 ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ test_acc │ 0.9873999953269958 │\n",
"└───────────────────────────┴───────────────────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│\u001b[36m \u001b[0m\u001b[36m test_acc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9873999953269958 \u001b[0m\u001b[35m \u001b[0m│\n",
"└───────────────────────────┴───────────────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[{'test_acc': 0.9873999953269958}]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.test(model=lightning_model, datamodule=data_module, ckpt_path='best')"
]
},
{
"cell_type": "markdown",
"id": "461b364a",
"metadata": {},
"source": [
"## Predicting labels of new data"
]
},
{
"cell_type": "markdown",
"id": "f32de2de",
"metadata": {},
"source": [
"- You can use the `trainer.predict` method on a new `DataLoader` or `DataModule` to apply the model to new data.\n",
"- Alternatively, you can also manually load the best model from a checkpoint as shown below:"
]
},
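{
"cell_type": "markdown",
"id": "a1f3c9e2",
"metadata": {},
"source": [
"- Below is a minimal, illustrative sketch of the `trainer.predict` route; it is not part of the training run above. It assumes the `LightningModel` class, the trained `lightning_model`, the `trainer`, and the `data_module` from the previous cells, attaches a hypothetical `predict_step` so that `trainer.predict` knows how to turn a batch into class labels, and simply reuses the test dataloader as a stand-in for new data."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b52d7e40",
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: `predict_step` is a hypothetical addition for illustration and\n",
"# is not part of the original LightningModel definition; it returns the\n",
"# predicted class label for each image in a batch.\n",
"def predict_step(self, batch, batch_idx):\n",
"    features, _ = batch\n",
"    return torch.argmax(self(features), dim=1)\n",
"\n",
"\n",
"LightningModel.predict_step = predict_step  # attach it for this sketch\n",
"\n",
"# Reusing the test dataloader as a stand-in for genuinely new data;\n",
"# `trainer.predict` returns one tensor of predictions per batch.\n",
"predicted_batches = trainer.predict(\n",
"    model=lightning_model, dataloaders=data_module.test_dataloader())\n",
"predicted_labels = torch.cat(predicted_batches)"
]
},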
{
"cell_type": "code",
"execution_count": 15,
"id": "ba4570eb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"logs/my-model/version_22/checkpoints/epoch=10-step=2353.ckpt\n"
]
}
],
"source": [
"path = trainer.checkpoint_callback.best_model_path\n",
"print(path)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "8c1f7796",
"metadata": {},
"outputs": [],
"source": [
"lightning_model = LightningModel.load_from_checkpoint(\n",
" path, model=pytorch_model)\n",
"lightning_model.eval();"
]
},
{
"cell_type": "markdown",
"id": "4389e871",
"metadata": {},
"source": [
"- Note that our PyTorch model, which is passed to the Lightning model requires input arguments. However, this is automatically being taken care of since we used `self.save_hyperparameters()` in our PyTorch model's `__init__` method.\n",
"- Now, below is an example applying the model manually. Here, pretend that the `test_dataloader` is a new data loader."
]
},
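{
"cell_type": "markdown",
"id": "5c8d1f27",
"metadata": {},
"source": [
"- To make the note above concrete, here is a hypothetical sketch (not this notebook's actual class definition) of one common way `self.save_hyperparameters()` interacts with `load_from_checkpoint`: constructor arguments recorded by `save_hyperparameters` are restored from the checkpoint automatically, while objects excluded via `ignore` (such as a wrapped model) have to be passed in again at load time."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e4b6a13",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch: the class name, constructor arguments, and the `ignore`\n",
"# list are illustrative and may differ from the definitions used earlier in\n",
"# this notebook. In this common variant, `save_hyperparameters` is called in\n",
"# the LightningModule that wraps the plain PyTorch model. Newer installs may\n",
"# import the base class from `lightning.pytorch` instead.\n",
"import pytorch_lightning as pl\n",
"\n",
"\n",
"class SketchLightningModel(pl.LightningModule):\n",
"    def __init__(self, model, learning_rate=0.05):\n",
"        super().__init__()\n",
"        self.model = model\n",
"        self.learning_rate = learning_rate\n",
"        # Records `learning_rate` in the checkpoint; the wrapped `model` is\n",
"        # excluded here and must be supplied again when loading.\n",
"        self.save_hyperparameters(ignore=[\"model\"])\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n",
"\n",
"\n",
"# Loading then only needs the objects that were not saved as hyperparameters:\n",
"# restored = SketchLightningModel.load_from_checkpoint(path, model=pytorch_model)"
]
},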
{
"cell_type": "code",
"execution_count": 17,
"id": "3ff7cd66",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([7, 2, 1, 0, 4])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_dataloader = data_module.test_dataloader()\n",
"\n",
"all_true_labels = []\n",
"all_predicted_labels = []\n",
"for batch in test_dataloader:\n",
" features, labels = batch\n",
" \n",
" with torch.no_grad(): # since we don't need to backprop\n",
" logits = lightning_model(features)\n",
" predicted_labels = torch.argmax(logits, dim=1)\n",
" all_predicted_labels.append(predicted_labels)\n",
" all_true_labels.append(labels)\n",
" \n",
"all_predicted_labels = torch.cat(all_predicted_labels)\n",
"all_true_labels = torch.cat(all_true_labels)\n",
"all_predicted_labels[:5]"
]
},
{
"cell_type": "markdown",
"id": "ff0a65db",
"metadata": {},
"source": [
"Just as an internal check, if the model was loaded correctly, the test accuracy below should be identical to the test accuracy we saw earlier in the previous section."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "6c06ef45",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9874 (98.74%)\n"
]
}
],
"source": [
"test_acc = torch.mean((all_predicted_labels == all_true_labels).float())\n",
"print(f'Test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)')"
]
},
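{
"cell_type": "markdown",
"id": "1b7f3e58",
"metadata": {},
"source": [
"- As a side note, the same sanity check can be computed with `torchmetrics` instead of the manual mean. The sketch below assumes a recent `torchmetrics` version in which the `task` argument is required, and it reuses `all_predicted_labels` and `all_true_labels` from above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a9c2d74",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: torchmetrics is installed alongside PyTorch Lightning, but older\n",
"# versions construct the metric as `torchmetrics.Accuracy()` without `task`.\n",
"import torchmetrics\n",
"\n",
"\n",
"test_acc_metric = torchmetrics.Accuracy(task=\"multiclass\", num_classes=10)\n",
"print(test_acc_metric(all_predicted_labels, all_true_labels))"
]
},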
{
"cell_type": "markdown",
"id": "4e80fe10",
"metadata": {},
"source": [
"## Single-image usage"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "335380c9",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "45998e5f",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"id": "ab0bdec5",
"metadata": {},
"source": [
"- Assume we have a single image as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "9fd24838",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAN8UlEQVR4nO3dUYxUdZbH8d+RHaKREUFa7Tgg7GiMZBNhrOAmbibquBPxQSRmzGAygunImEAcDA9L2BjQJ7JhZpiHzSSM4MBmFjIRjDzg7BiCUTQBCwMtLrqiYQdGAk0woSdRseHsQ182DXb9q6l7b93qPt9P0qmqe+rW/1jy61tV/9v1N3cXgLHvqqobANAehB0IgrADQRB2IAjCDgTxd+0cbMqUKT59+vR2DgmEcvToUZ0+fdqGq+UKu5k9JOk3ksZJesnd16TuP336dNXr9TxDAkio1WoNay2/jDezcZL+XdJcSTMlLTCzma0+HoBy5XnPPkfSEXf/zN3PSdoqaV4xbQEoWp6w3yLp2JDbx7NtlzCzxWZWN7N6X19fjuEA5JEn7MN9CPCtc2/dfb2719y91tXVlWM4AHnkCftxSVOH3P6epM/ztQOgLHnC/p6k281shpmNl/RTSTuKaQtA0VqeenP3ATNbKum/NDj1ttHdPyysMwCFyjXP7u47Je0sqBcAJeJ0WSAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQuZZsNrOjkvolnZc04O61IpoCULxcYc/c7+6nC3gcACXiZTwQRN6wu6Q/m9l+M1s83B3MbLGZ1c2s3tfXl3M4AK3KG/Z73f0HkuZKWmJmP7z8Du6+3t1r7l7r6urKORyAVuUKu7t/nl2ekvSqpDlFNAWgeC2H3cyuNbPvXrwu6ceSDhXVGIBi5fk0/iZJr5rZxcf5T3f/UyFdAShcy2F3988k3VVgLwBKxNQbEARhB4Ig7EAQhB0IgrADQRTxhzDI6d13303We3t7k/Wenp6Gtf7+/uS+b775ZrJ+6FB1p07cf//9yfqdd97Z8mPfcMMNyXo2pTymcGQHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSCYZ2+DtWvXJuurVq1K1r/66qtkffv27Q1rb7/9dnLfc+fOJetVeuGFF0p77McffzxZbzbP/uKLLybrt9122xX3VDaO7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQhLl72war1Wper9fbNl6RUs/Tnj17kvvOnTs3Wf/yyy9b6qkI48aNS9bvvvvuZP2jjz5K1pudI5AyMDCQrF+4cKHlx85r4sSJyfqZM2fa1MmlarWa6vX6sCcJcGQHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSD4e/bM119/nayvW7euYW3lypUFd3Opa665JllfsGBBw9r8+fOT+95zzz3JerPvVy/TkSNHkvV58+Yl683OAchj/PjxpT12WZoe2c1so5mdMrNDQ7ZNNrM3zOyT7HJSuW0CyGskL+N/L+mhy7atkLTL3W+XtCu7DaCDNQ27u78l6fJz/+ZJ2pRd3yTp0WLbAlC0Vj+gu8ndT0hSdnljozua2WIzq5tZva+vr8XhAORV+qfx7r7e3WvuXuvq6ip7OAANtBr2k2bWLUnZ5aniWgJQhlbDvkPSwuz6QkmvFdMOgLI0nWc3sy2S7pM0xcyOS1olaY2kP5pZj6S/SPpJmU0W4eOPP07WFy1alKzv27evwG4utXTp0mR9+fLlyfq0adOKbKdjHDx4MFkvcx69u7s7WX/nnXdKG7ssTcPu7o3O2PhRwb0AKBGnywJBEHYgCMIOBEHYgSAIOxBEmD9xfe6555L1PFNrzZb3ffbZZ5P1Zsv/Tpgw4Yp7Gg2OHTuWrD/zzDOljd3s/9m2bduS9VtvvbXIdtqCIzsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBBFmnv31119P1q+6Kv177+abb25YW716dXLfp59+Olkfy86dO9ew9sgjjyT3LXPZ47Vr1ybrzb5iezTiyA4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQYSZZ3/ppZeS9bNnzybrTz31VMPa9ddf30pLY8LAwECynvqa7N7e3qLbucTMmTMb1pp9x8BYxJEdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4IIM8/e09NTdQuj0vnz55P1/fv3J+sbNmwosp1L3HXXXcn67t27G9aafX/BWNT0v9jMNprZKTM7NGTbajP7q5kdyH4eLrdNAHmN5Nfb7yU9NMz2X7v7rOxnZ7FtASha07C7+1uSyvt+IABtkeeNy1Iz681e5k9qdCczW2xmdTOr9/X15RgOQB6thv23kr4vaZakE5J+2eiO7r7e3WvuXuvq6mpxOAB5tRR2dz/p7ufd/YKk30maU2xbAIrWUtjNrHvIzfmSDjW6L4DO0HSe3cy2SLpP0hQzOy5plaT7zGyWJJd0VNLPy2sRVdq6dWuy/uSTT5Y29owZM5L1LVu2JOsTJ04ssp1Rr2nY3X3BMJvLO1MCQCninUYEBEXYgSAIOxAEYQeCIOxAEGH+xBXD27t3b7K+ZMmSNnXybc2m/e644442dTI2cGQHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSCYZx/jvvnmm2R92bJlyXp/f3+B3Vzq+eefT9Znz55d2tgRcWQHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSCYZx/jXnnllWR93759pY7/2GOPNaytWLEiue+4ceOKbic0juxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EATz7GPc8uXLS338KVOmJOubN29uWLv66quLbgcJTY/sZjbVzHab2WEz+9DMfpFtn2xmb5jZJ9nlpPLbBdCqkbyMH5C03N3vlPSPkpaY2UxJKyTtcvfbJe3KbgPoUE3D7u4n3P397Hq/pMOSbpE0T9Km7G6bJD1aUo8ACnBFH9CZ2XRJsyXtlXSTu5+QBn8hSLqxwT6LzaxuZvW+vr6c7QJo1YjDbmYTJG2TtMzdz450P3df7+41d691dXW10iOAAowo7Gb2HQ0G/Q/uvj3bfNLMurN6t6RT5bQIoAhNp97MzCRtkHTY3X81pLRD0kJJa7LL10rpEE2tWbOmYe3kyZOljv3EE08k60yvdY6RzLPfK+lnkj4wswPZtpUaDPkfzaxH0l8k/aSUDgEUomnY3X2PJGtQ/lGx7QAoC6fLAkEQdiAIwg4EQdiBIAg7EAR/4joKfPHFF8n6unXrSht7/vz5yfratWtLGxvF4sgOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0Ewzz4KHDx4MFkv8+u+pk2blqyzrPLowZEdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Jgnn0U2LlzZ2VjP/jgg5WNjWJxZAeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIEayPvtUSZsl3
SzpgqT17v4bM1st6WlJF/+YeqW7VzchPIYtWrQoWX/55Zcb1s6ePZvcd+nSpcn6Aw88kKxj9BjJSTUDkpa7+/tm9l1J+83sjaz2a3dnlQBgFBjJ+uwnJJ3Irveb2WFJt5TdGIBiXdF7djObLmm2pL3ZpqVm1mtmG81sUoN9FptZ3czqZX59EoC0EYfdzCZI2iZpmbuflfRbSd+XNEuDR/5fDrefu69395q717q6uvJ3DKAlIwq7mX1Hg0H/g7tvlyR3P+nu5939gqTfSZpTXpsA8moadjMzSRskHXb3Xw3Z3j3kbvMlHSq+PQBFGcmn8fdK+pmkD8zsQLZtpaQFZjZLkks6KunnJfQHSTNnzkzWP/3004a1gYGB5L6TJ09uqSeMPiP5NH6PJBumxJw6MIpwBh0QBGEHgiDsQBCEHQiCsANBEHYgCL5Kegy47rrrqm4BowBHdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0Iwty9fYOZ9Un63yGbpkg63bYGrkyn9tapfUn01qoie7vV3Yf9/re2hv1bg5vV3b1WWQMJndpbp/Yl0Vur2tUbL+OBIAg7EETVYV9f8fgpndpbp/Yl0Vur2tJbpe/ZAbRP1Ud2AG1C2IEgKgm7mT1kZh+b2REzW1FFD42Y2VEz+8DMDphZveJeNprZKTM7NGTbZDN7w8w+yS6HXWOvot5Wm9lfs+fugJk9XFFvU81st5kdNrMPzewX2fZKn7tEX2153tr+nt3Mxkn6H0n/LOm4pPckLXD3/25rIw2Y2VFJNXev/AQMM/uhpL9J2uzu/5Bt+zdJZ9x9TfaLcpK7/0uH9LZa0t+qXsY7W62oe+gy45IelbRIFT53ib4eVxuetyqO7HMkHXH3z9z9nKStkuZV0EfHc/e3JJ25bPM8SZuy65s0+I+l7Rr01hHc/YS7v59d75d0cZnxSp+7RF9tUUXYb5F0bMjt4+qs9d5d0p/NbL+ZLa66mWHc5O4npMF/PJJurLifyzVdxrudLltmvGOeu1aWP8+rirAPt5RUJ83/3evuP5A0V9KS7OUqRmZEy3i3yzDLjHeEVpc/z6uKsB+XNHXI7e9J+ryCPobl7p9nl6ckvarOW4r65MUVdLPLUxX38/86aRnv4ZYZVwc8d1Uuf15F2N+TdLuZzTCz8ZJ+KmlHBX18i5ldm31wIjO7VtKP1XlLUe+QtDC7vlDSaxX2colOWca70TLjqvi5q3z5c3dv+4+khzX4ifynkv61ih4a9PX3kg5mPx9W3ZukLRp8WfeNBl8R9Ui6QdIuSZ9kl5M7qLf/kPSBpF4NBqu7ot7+SYNvDXslHch+Hq76uUv01ZbnjdNlgSA4gw4IgrADQRB2IAjCDgRB2IEgCDsQBGEHgvg/av0i7OGhSbEAAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from PIL import Image\n",
"\n",
"\n",
"image = Image.open('data/mnist_pngs/613.png')\n",
"plt.imshow(image, cmap='Greys')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "039753a3",
"metadata": {},
"source": [
"- Note that we used a resize-transformation in the `DataModule` that rescaled the 28x28 MNIST images to size 32x32. We also have to apply the same transformation to any new image that we feed to the model:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "a8944f63",
"metadata": {},
"outputs": [],
"source": [
"resize_transform = transforms.Compose(\n",
" [transforms.Resize((32, 32)),\n",
" transforms.ToTensor()])\n",
"\n",
"image_chw = resize_transform(image)"
]
},
{
"cell_type": "markdown",
"id": "0937a26c",
"metadata": {},
"source": [
"- Note that `ToTensor` returns the image in the CHW format. CHW refers to the dimensions and stands for channel, height, and width."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "e4b2b30e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 32, 32])\n"
]
}
],
"source": [
"print(image_chw.shape)"
]
},
{
"cell_type": "markdown",
"id": "6a36612a",
"metadata": {},
"source": [
"- However, the PyTorch / PyTorch Lightning model expectes images in NCHW format, where N stands for the number of images (e.g., in a batch).\n",
"- We can add the additional channel dimension via `unsqueeze` as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "740d681f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 1, 32, 32])\n"
]
}
],
"source": [
"image_nchw = image_chw.unsqueeze(0)\n",
"print(image_nchw.shape)"
]
},
{
"cell_type": "markdown",
"id": "774b7498",
"metadata": {},
"source": [
"- Now that we have the image in the right format, we can feed it to our classifier:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "2be68b34",
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad(): # since we don't need to backprop\n",
" logits = lightning_model(image_nchw)\n",
" probas = torch.softmax(logits, axis=1)\n",
" predicted_label = torch.argmax(probas)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "10c37394",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted label: 7\n",
"Class-membership probability 99.99%\n"
]
}
],
"source": [
"print(f'Predicted label: {predicted_label}')\n",
"print(f'Class-membership probability {probas[0][predicted_label]*100:.2f}%')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}