"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"metrics = pd.read_csv(f\"{trainer.logger.log_dir}/metrics.csv\")\n",
"\n",
"# The CSV log presumably stores train and validation entries on separate rows\n",
"# (NaN where a metric was not logged); mean() skips NaNs, so grouping by epoch\n",
"# collapses the log to one row per epoch. The resulting index holds the actual\n",
"# epoch numbers, which become the x-axis values of the plots below.\n",
"df_metrics = metrics.groupby(\"epoch\").mean()\n",
"\n",
"df_metrics[[\"train_loss\", \"valid_loss\"]].plot(\n",
"    grid=True, legend=True, xlabel='Epoch', ylabel='Loss')\n",
"df_metrics[[\"train_acc\", \"valid_acc\"]].plot(\n",
"    grid=True, legend=True, xlabel='Epoch', ylabel='Accuracy')"
]
},
{
"cell_type": "markdown",
"id": "00811a53",
"metadata": {},
"source": [
"- The `trainer` automatically saves the model with the best validation accuracy for us, which we can load from the checkpoint via the `ckpt_path='best'` argument; below, we use the `trainer` instance to evaluate the best model on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "2f972950",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Restoring states from the checkpoint path at logs/my-model/version_29/checkpoints/epoch=7-step=2807.ckpt\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"Loaded model weights from checkpoint at logs/my-model/version_29/checkpoints/epoch=7-step=2807.ckpt\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4070de9968ea4d72b73fbcaec6a3a1c8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Testing: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃ Test metric ┃ DataLoader 0 ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ test_acc │ 0.6643999814987183 │\n",
"└───────────────────────────┴───────────────────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│\u001b[36m \u001b[0m\u001b[36m test_acc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6643999814987183 \u001b[0m\u001b[35m \u001b[0m│\n",
"└───────────────────────────┴───────────────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[{'test_acc': 0.6643999814987183}]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluate the best checkpoint saved during training rather than the final weights\n",
"trainer.test(\n",
"    model=lightning_model,\n",
"    datamodule=data_module,\n",
"    ckpt_path='best',\n",
")"
]
},
{
"cell_type": "markdown",
"id": "77f0f212",
"metadata": {},
"source": [
"## Predicting labels of new data"
]
},
{
"cell_type": "markdown",
"id": "9f11137f",
"metadata": {},
"source": [
"- You can use the `trainer.predict` method on a new `DataLoader` or `DataModule` to apply the model to new data.\n",
"- Alternatively, you can also manually load the best model from a checkpoint as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "473d9139",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"logs/my-model/version_29/checkpoints/epoch=7-step=2807.ckpt\n"
]
}
],
"source": [
"# Bind and show the file path of the best checkpoint tracked by the trainer's checkpoint callback\n",
"print(path := trainer.checkpoint_callback.best_model_path)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "7ca9ba4c",
"metadata": {},
"outputs": [],
"source": [
"# Restore the LightningModel from the best checkpoint; the wrapped PyTorch model\n",
"# instance is supplied explicitly via `model=`. Module.eval() returns the module\n",
"# itself, so evaluation mode is switched on in the same expression.\n",
"lightning_model = LightningModel.load_from_checkpoint(\n",
"    path, model=pytorch_model\n",
").eval()"
]
},
{
"cell_type": "markdown",
"id": "7c1db4af",
"metadata": {},
"source": [
"- Note that the Lightning model requires our PyTorch model as an input argument, which is why we pass `model=pytorch_model` to `load_from_checkpoint` above. The remaining constructor arguments are restored from the checkpoint automatically, since we used `self.save_hyperparameters()` in the `LightningModel`'s `__init__` method.\n",
"- Below is an example of applying the model manually. Here, pretend that the `test_dataloader` is a new data loader."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "e1c35e59",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([3, 1, 0, 0, 4])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_dataloader = data_module.test_dataloader()\n",
"\n",
"all_true_labels = []\n",
"all_predicted_labels = []\n",
"with torch.no_grad():  # inference only, so no autograd graph is needed\n",
"    for features, labels in test_dataloader:\n",
"        logits = lightning_model(features)\n",
"        all_predicted_labels.append(logits.argmax(dim=1))\n",
"        all_true_labels.append(labels)\n",
"\n",
"# Stitch the per-batch results into single 1-D tensors\n",
"all_predicted_labels = torch.cat(all_predicted_labels)\n",
"all_true_labels = torch.cat(all_true_labels)\n",
"all_predicted_labels[:5]"
]
},
{
"cell_type": "markdown",
"id": "715bf727",
"metadata": {},
"source": [
"As an internal check: if the model was loaded correctly, the test accuracy below should be identical to the one reported by `trainer.test` in the previous section."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "03d5a220",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.6644 (66.44%)\n"
]
}
],
"source": [
"# Fraction of test examples whose predicted label equals the true label\n",
"test_acc = (all_predicted_labels == all_true_labels).float().mean()\n",
"print(f'Test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)')"
]
},
{
"cell_type": "markdown",
"id": "97a69f77",
"metadata": {},
"source": [
"## Single-image usage"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "f15ef928",
"metadata": {},
"outputs": [],
"source": [
"# Render matplotlib figures inline, below the cell that creates them\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "e1029db9",
"metadata": {},
"outputs": [],
"source": [
"# Plotting interface used to display the sample image below\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "markdown",
"id": "5a293e6a",
"metadata": {},
"source": [
"- Assume we have a single image as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "36032a1a",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAaT0lEQVR4nO2dbYykVZXH/6fe+r2nZ6Znppt5BWQFFhHYFo0YV0AMa0yQDxpN1vCBOGYjyZq4H1g2WdnsF3ezavywIQ7CihsXJassZENUlrhLVJi1wZEZBGEYB2aYpnvep9+7qp6zH+phM+A9p7qf6nqq9f5/Saer7ql7n/PcqlNP1f3XOVdUFYSQP3wKnXaAEJIPDHZCIoHBTkgkMNgJiQQGOyGRwGAnJBJKrXQWkZsBfB1AEcA3VfXL3uN7+tfpuo0jrRySREpBHKPYRsuiyCY5O4dCxiHtYzm2xDjWmZNvYG76bLBr5mAXkSKAfwZwE4CjAH4hIo+q6q+tPus2juDP//qeoC2L3p/jvJMgq/vB0Avoni77WMWCbRPDliRVx5O6aSmViqYtsSIQQNGxiWEqFO3zml9Kgu33/v1fmH1aebauBXBQVQ+p6hKA7wK4pYXxCCFtpJVg3wrgyHn3j6ZthJA1SCvBHvrQ9TsfSERkt4iMi8j43MyZFg5HCGmFVoL9KIDt593fBuDY2x+kqntUdUxVx3r7h1o4HCGkFVoJ9l8AuERELhSRCoBPAXh0ddwihKw2mVfjVbUmIncA+BEa0tv9qvq82weKJAmvIiLDanzi9BFXIyFrkoL9fBaSJdOmdeM1BUAK4dXz7or90jdfowDEWVWvL9k+NsSrMOVi2CZeTNQMNcHp05LOrqqPAXislTEIIfnAX9AREgkMdkIigcFOSCQw2AmJBAY7IZHQ0mr8ShF4WUgZxnPkNXe8VS6ymVXma0exz9WWHD0fxcrgaFiN8ewe9bptnJlbtDs6gxaM7BqdsaUw7xroJbt4mVl1sZNr1Ei88TPswsa6NxfOcISQPyAY7IREAoOdkEhgsBMSCQx2QiIh19V4VXs1000+WLEhOytfX/Z993Ddz7iqntUX2w3bjySxV5its8taSixBZcXHAuyST54jqvYcqlM7K+u5WfXwPOVCzKPZXvDKTkgkMNgJiQQGOyGRwGAnJBIY7IREAoOdkEjIVXoDgHrdkGsSL6nFknFqTh8bdd7jjPwCd1RxjrZmdq3xZBxPhso2JMxEGLdPVrKMavfxtobK9fl0BlQrsYaJMIQQBjshkcBgJyQSGOyERAKDnZBIYLATEgktSW8ichjANBq719dUdcztoAlQN2qJqV0TTIz3JIWTndTMD/NgngRojJpRumoHpiuejLPqB8v/vPNizZyX4Yjn32ro7Ner6olVGIcQ0kb4MZ6QSGg12BXAj0XkGRHZvRoOEULaQ6sf469T1WMishnA4yLyoqo+ef4D0jeB3QAwsH5Ti4cjhGSlpSu7qh5L/08BeBjAtYHH7FHVMVUd6+kbbOVwhJAWyBzsItInIgNv3gbwEQAHVssxQsjq0srH+C0AHk4LEpYA/Juq/tDrUBBBbyV8SBX7fWd+ISzXJWq7X3C2JqqUshVR1ELYx6qzbZFmLBzpZdJ5mpeZ9ORu45Qtl0udFEE/e3DluB7muJ1XO7bsyovMwa6qhwC8exV9IYS0EUpvhEQCg52QSGCwExIJDHZCIoHBTkgk5FpwslqtYeLoyaCta6Df7FfpDbspTvKao66hz5D/AKCeVE1bYsiDtbpX+NJ+P80qr2WhLXKS42IbtuGzj+Xu25afH9nJ4uTKZ5hXdkIigcFOSCQw2AmJBAY7IZHAYCckEnJdjT97dgY//NHPg7aBkc1mv6v/5LJg+5ahdWafcslequ+t2O9xSWLXwpuvhVdN1UmesVbwgWbbRq3yMrIzXNF5yy84i75VZ6k7cZSSbGRd+reMXoKP97xk0xm8MU1Lxi27LH
hlJyQSGOyERAKDnZBIYLATEgkMdkIigcFOSCTkKr0VpIieSjjhpbph2Oy378DrwfarL7Xfq3ZssXWLEiqmreDoHTMLC8H2uWk7eaYq9rE8zajibG3laTJVo/hbwdny6optto/DA2XTtvfQvGk7u2RImE5tQE+eKqgjYbqZMEYf5zLn1+TLRpZkI18CDFOv2zIwr+yERAKDnZBIYLATEgkMdkIigcFOSCQw2AmJhKbSm4jcD+BjAKZU9Yq0bQOA7wHYBeAwgE+q6ulmY5W7SrhgZ1him90+avY7/rN9wfaDh5bMPts2v8O0dastNVUK9pgna2Hb0qItvc1VbVvdkVYqXrac8xZtSX2enFRbmjNt5Zr9Elmct+W8hXnj3BJbrisXbdmo4tiKXtqegaeuJYldU7DmSVsZJEDP6I1nvnQc35czS98CcPPb2u4E8ISqXgLgifQ+IWQN0zTY0/3WT72t+RYAD6S3HwDw8dV1ixCy2mT9zr5FVScAIP1vV54ghKwJ2r5AJyK7RWRcRMari7PtPhwhxCBrsE+KyCgApP+nrAeq6h5VHVPVsXJXX8bDEUJaJWuwPwrgtvT2bQAeWR13CCHtYjnS24MAPgRgWESOAvgSgC8DeEhEbgfwGoBPLOdg5bJgy9busHGn/bV/4OBgsP3FQxNmnwlHytt1kWnC/PykaUuMTK6l+UWzz0ZnW6u+ii1dzUzbcthi1e5XM5QXLdpy49ZN201bX2L7UZuz56o6F5YcNw/aL7mNg7aPvRUni7GYoZijk1WYOAVE604lzaSercpmwavquUK6Sl62ZxNU9dOG6casDhFC8oe/oCMkEhjshEQCg52QSGCwExIJDHZCIiHXgpP9PV247spdQdvBbkOSA/CuD18TbN//tYfNPs8+86JpG9tlZ8RtHrLlk6KG95ZbcuSYy3fY0ttW51j7Xzxr2iamzti2k29PY2gwddb+9eLiJbZ0ODdzzrTJ/LRp2zm8Mdi+bXOX2acidoagGvvsNcUsOOnIdU6GnZe9lniynCPniZHGWCrZ4WklxBUdFY9XdkIigcFOSCQw2AmJBAY7IZHAYCckEhjshERCrtJbuQCMlMPyxK/P2RLPdTdfHWyvLcyYfR778S9N295nD5u2W2+wM8CSmZPB9st3DJl9dvTZktfJqVdN22tHjpi2har9tB078ptge6k3nDkIAL/cv8+0FcSWRC/a9kemrb+/Jzyel4fmKl7edcnWm8TIKFMn602d8TwTHDmvULD3zDP1vIKxXx6A7p5e4zj2PPHKTkgkMNgJiQQGOyGRwGAnJBIY7IREQq6r8dPnTuN/nviPoO2AXmD2e//YO4Pt733vFWafAy/Z9en+66lx01Yv2skY1enwSvemPrt22tTwJtP20sHweAAwldgrsYnaK+QnToYTaIZL4dVxAFhCeGUXAHbutJOGeo0VYQBIauF5rItzXs61pw5n2yW1l8iLRpKJOKvWiZPs4m2jpc52Xl6dOWvIWs0+51rdUhns4/DKTkgkMNgJiQQGOyGRwGAnJBIY7IREAoOdkEhYzvZP9wP4GIApVb0ibbsbwGcBHE8fdpeqPtZsrHK5jG2jW4K2sydsmeGnj/wk2F49YW9NdOLIcdO2fcvFpu3Y67Z8smUg3E8Kb5h9zi3ZddrWbdhm2haXjH2cAMzN23O1Y9uFwfZSl52IMTJi+zGyJfx8AYCIPVdVS3qzlU2UKraEWSvY52yLeUBSDdfXE+c6p+rZ7HOem3O2ylqy6/xZFIv2mS3MhpPAvFp3y7myfwvAzYH2r6nqVelf00AnhHSWpsGuqk8CCJcsJYT83tDKd/Y7ROQ5EblfRNavmkeEkLaQNdjvAXAxgKsATAD4ivVAEdktIuMiMu59pyGEtJdMwa6qk6paV9UEwL0ArnUeu0dVx1R1rLfX/i01IaS9ZAp2ERk97+6tAA6sjjuEkHaxHOntQQAfAjAsIkcBfAnAh0TkKjSKZx0G8LnlHGxgcB
AfuOnGoG3HkdNmv2OnwjLDuovDWwwBwC0fu8y0bd+2w7R1OdlhPZVwRlEFC2afilNzTZxsrRmvIJtTz0zqRm21gl1zrdvZesvamggAnMQxND70/S6Jka0FAKfP2jLl6VlHs1Nbptw0PBBs73FkPkdRRLVq+3H2rL1l1+SkLc8eMeoNnj5tx4Qle3rSYNNgV9VPB5rva9aPELK24C/oCIkEBjshkcBgJyQSGOyERAKDnZBIyLXg5PFTJ/CNB/8laKvN2FLT8cnwtkvV2pLZ58brrzdtN133RdPW199n2qzCgAVnGpcW7PNKnMqGw2V7TK9YIgzpxZLCGjbbDy/zCs4WSta5FZyCk8eNYpkAMHnMlqG6nYy+CzaFC5n2dneZfaD266qnx5Zmt22zswevfNe7TNszzzwbbH/66b1mn6HhoWB7uWxLiryyExIJDHZCIoHBTkgkMNgJiQQGOyGRwGAnJBJyld6KIhiqhGWS4no71329IXfUnQyqotqn9vRTYakDAEZHRk3b4GA4g2pmftbsM7RhnWnr6w2PBwAldbLePBnNaPcKLC4u2ll7ExPHTFu3I1/19feH+3TZ0tWO7fZ+fyMXbDZtpYIt55WMl0jizCEcedCTKU+fseXBroo9V8PDw8H2G264wexzzHheyiX7dc8rOyGRwGAnJBIY7IREAoOdkEhgsBMSCbmuxvf39OD9V747aEucemxLxhY+VbXrgc3Ozpu2n/3sEdPWXRk0bevWjQTbe/qHzD4XXvQO07Z5s11Dr9epC+clp1im/n77vJ5+yk64+OZ931jxsQBgaCisQoyM2ivu7xl7j2kbGQnPPQCsGxwybYMDxnl7l7mCU4TOYXrarqGnziZVVtLQ9Pw5s8/xE+HV+Jqx7RbAKzsh0cBgJyQSGOyERAKDnZBIYLATEgkMdkIiYTnbP20H8G0AI2gUHdujql8XkQ0AvgdgFxpbQH1SVe1MAADVxVlMvvxU0ObkF5hJC+rs01Mu23XJ1lnF5ADU5o+btlPzYblDKkNmn9ePvmjaikV7+r0tmfwNMsNzMjsbli8B4MirdrLLmdNTps1KxgAAMZJJrrzSrsXm1WmrOUlPG4a3OH6E2/cZdd8A4PSJcM1DAFBnO6963d6Gqlp1ZDlj+6qlJbsWnrXlVVJvTXqrAfiiql4G4H0APi8ilwO4E8ATqnoJgCfS+4SQNUrTYFfVCVV9Nr09DeAFAFsB3ALggfRhDwD4eJt8JISsAiv6zi4iuwBcDWAvgC2qOgE03hAA2AnHhJCOs+xgF5F+AN8H8AVVtX/H97v9dovIuIiMz87Z3xsJIe1lWcEuImU0Av07qvqDtHlSREZT+yiA4EqOqu5R1TFVHevrdQrzE0LaStNgFxFBYz/2F1T1q+eZHgVwW3r7NgB2dgkhpOMsJ+vtOgCfAbBfRPalbXcB+DKAh0TkdgCvAfhEs4E0qaE2fypoSxKnJhjC+kmxaG91kyzZUo01HgAU7CFRLoQlr1lni6f1W4ZMmzq10xbm7Lp2p06eMG1zc2GJ58Rxu8+rjvTW1W1fD3btsLc7mp0OZx2emrL9+PX+/abNk9fKFftlvLAQrq93+LfPmX1ee8W2ea+d/j5767CiIy13GVs2FYv23Fs7h6kj/zUNdlX9KewzvLFZf0LI2oC/oCMkEhjshEQCg52QSGCwExIJDHZCIiHXgpNLSQFHZsLyVbHkbOFjZIeVYetkRUfW8go2YtHONNIk/N44t2jLMaq2hCbOOS86GU/1ui31iRhz5ZzzwZdeMm3ObkLYOLTBtF32zkuD7V7W2Cu/ed60zcycNW2HXjpg2gpG2tupyTfMPoM9dnHO4yfsrMiS2tfODeucrb6MApcF51KcpSYmr+yERAKDnZBIYLATEgkMdkIigcFOSCQw2AmJhFylt1q9gKnpsPRWcHSG7q6wxFZ2dKGSI2t50luhaGcuFQphGaen35YA6xrOugKAgiPVVCq2nCdi91MNz8nsjF
2A848vvcy01et2wZGBXvu8+3vDPg4O2hJUd0+P7Udiz+ORl+1suR6jOGevc6xSt+3jUN3WvPr77ddO/4Btq1TC89hltAMwK2l6hVZ5ZSckEhjshEQCg52QSGCwExIJDHZCIiHX1fh6XXHmXHhV1Vshn5kLb2lTLjmr2cYKPgB0Vewqt7399rZLZSP7YL5urxSXnTpiIxvsumqjo3YZ/npir5BX58JVvrdvtOfj8l3Dpq2x41eYYslOyKkn4Vpo4tRwE2dbrkLG65LAmCuZM/skTh23+kK4hiIAVAY2mbba3Ixpq86Gz3veUH8AoGAketVrdgIVr+yERAKDnZBIYLATEgkMdkIigcFOSCQw2AmJhKbSm4hsB/BtACNo6DB7VPXrInI3gM8CeLMo112q+pg3VrVWw/ETp4M2LxHGSkBxysyh5Eh5JS+BxhmzXA776I9n2yYnbfnn0GG7RhoKtrxi1TPrcvwQb+4dOaxcCEuiAFAwJDvruWyGOBKgh/W6KjmSqDgJSlJYZ9rOnrV9LBRXvoOxJ0Vaplrd9mE5OnsNwBdV9VkRGQDwjIg8ntq+pqr/tIwxCCEdZjl7vU0AmEhvT4vICwC2ttsxQsjqsqLv7CKyC8DVAPamTXeIyHMicr+IrF9t5wghq8eyg11E+gF8H8AXVPUcgHsAXAzgKjSu/F8x+u0WkXERGa/X7O94hJD2sqxgF5EyGoH+HVX9AQCo6qSq1lU1AXAvgGtDfVV1j6qOqepYsWRX0SCEtJemwS6NJcH7ALygql89r330vIfdCsDeloMQ0nGWsxp/HYDPANgvIvvStrsAfFpErgKgAA4D+FyzgQSAJGFpqO5kGlWTcHZV40NFGNUM++MAUCcry8Tp4mV5QSecQe1zk4Kz/VMh/JRKwa655mUcevJPV9GReYzLSME9lmlC2asp6Eif1rmVivbBPPnVlSkdm1cv0fLRymwDgJIhAy9UW5DeVPWnCL+cXU2dELK24C/oCIkEBjshkcBgJyQSGOyERAKDnZBIyLXg5ODgAD5y058GbfW6LSctLoWzvGpV+xd5tZo9Xq3myHxVr1/YlhjSIADUnSykuiFDNvp5Y9rnXTfOre7MVdWZj8Txo6a2NLRoDFlftMdLEkdKxbxtc2TWxLJ50qzjh+tjRrnXwpM9LUV3ZtbZbqxFfwghvycw2AmJBAY7IZHAYCckEhjshEQCg52QSMhVeuvuruDSS3euuJ+V3KZqSxOKbLKWOpKdJcm44zlyjDpZUklin5sn/1gyYK1uF6msOkVFEk86tBU7U3L0ZM+6I2HWvP3XnH6WdOh0QVKzn7Na1ZFts86j4YzXp2qc1xvHfmv24ZWdkEhgsBMSCQx2QiKBwU5IJDDYCYkEBjshkZCr9KZJgurc7Mr7Wdk/7l5YjiznZJuJI9mJ8dYo4mRdOcdS8Qo2eoUvnXOT8FNqtTeMGbLGYBcPBdwanJn8cFTWJlKkseec2Bl7ns3DmytPJrb6qXdeRp/xZ35u9uGVnZBIYLATEgkMdkIigcFOSCQw2AmJhKar8SLSDeBJAF3p4/9dVb8kIhsAfA/ALjS2f/qkqp72xioUiujvHwzavIQRqwadt1rpJcJ4K/Ve2a/EzMhx3jOdlXrxVmgznptm2CrLW0X28FatrTn2juTqD95z5tTCKxoSSpb6bk1MKHpjGttypc4YzStXmwpFZ3sq24P/ZxHADar6bjS2Z75ZRN4H4E4AT6jqJQCeSO8TQtYoTYNdG8ykd8vpnwK4BcADafsDAD7eDgcJIavDcvdnL6Y7uE4BeFxV9wLYotrYhjT9v7ltXhJCWmZZwa6qdVW9CsA2ANeKyBXLPYCI7BaRcREZn56ZzugmIaRVVrQar6pnAPw3gJsBTIrIKACk/6eMPntUdUxVxwb6B1rzlhCSmabBLiKbRGQovd0D4MMAXgTwKIDb0ofdBuCRNvlICFkFlpMIMwrgAREpovHm8JCq/qeIPAXgIRG5HcBrAD
7RbCDVBIu1RcPqJQoYBcO8BA5HuvJqxrkYx3OTRdwBHaunUblSWdjm5dUU3KSbbFgjerJhlvEaRmdMw+SN58mU7tS7uq1du84a0pfejOu042DTYFfV5wBcHWg/CeDGZv0JIWsD/oKOkEhgsBMSCQx2QiKBwU5IJDDYCYkE8bLNVv1gIscBvJreHQZwIreD29CPt0I/3srvmx87VXVTyJBrsL/lwCLjqjrWkYPTD/oRoR/8GE9IJDDYCYmETgb7ng4e+3zox1uhH2/lD8aPjn1nJ4TkCz/GExIJHQl2EblZRH4jIgdFpGO160TksIjsF5F9IjKe43HvF5EpETlwXtsGEXlcRF5O/6/vkB93i8jr6ZzsE5GP5uDHdhH5iYi8ICLPi8hfpu25zonjR65zIiLdIvK/IvKr1I+/S9tbmw9VzfUPQBHAKwAuAlAB8CsAl+ftR+rLYQDDHTjuBwFcA+DAeW3/CODO9PadAP6hQ37cDeCvcp6PUQDXpLcHALwE4PK858TxI9c5QSMDtz+9XQawF8D7Wp2PTlzZrwVwUFUPqeoSgO+iUbwyGlT1SQCn3tacewFPw4/cUdUJVX02vT0N4AUAW5HznDh+5Io2WPUir50I9q0Ajpx3/yg6MKEpCuDHIvKMiOzukA9vspYKeN4hIs+lH/Pb/nXifERkFxr1Ezpa1PRtfgA5z0k7irx2IthD5Tc6JQlcp6rXAPgzAJ8XkQ92yI+1xD0ALkZjj4AJAF/J68Ai0g/g+wC+oKrn8jruMvzIfU60hSKvFp0I9qMAtp93fxuAYx3wA6p6LP0/BeBhNL5idIplFfBsN6o6mb7QEgD3Iqc5EZEyGgH2HVX9Qdqc+5yE/OjUnKTHPoMVFnm16ESw/wLAJSJyoYhUAHwKjeKVuSIifSIy8OZtAB8BcMDv1VbWRAHPN19MKbcihzmRRrG1+wC8oKpfPc+U65xYfuQ9J20r8prXCuPbVhs/isZK5ysA/qZDPlyEhhLwKwDP5+kHgAfR+DhYReOTzu0ANqKxjdbL6f8NHfLjXwHsB/Bc+uIazcGPD6DxVe45APvSv4/mPSeOH7nOCYArAfwyPd4BAH+btrc0H/wFHSGRwF/QERIJDHZCIoHBTkgkMNgJiQQGOyGRwGAnJBIY7IREAoOdkEj4P0w6IIvRV/IIAAAAAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from PIL import Image\n",
"\n",
"\n",
"# convert('RGB') guarantees 3 channels even if a PNG is stored in RGBA or\n",
"# palette mode, matching the 3-channel tensors fed to the model below.\n",
"image = Image.open('data/cifar10_pngs/90_airplane.png').convert('RGB')\n",
"# No colormap argument: imshow ignores `cmap` for RGB data.\n",
"plt.imshow(image)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "df4829b7",
"metadata": {},
"source": [
"- Note that we have to use the same image transformation that we used earlier in the `DataModule`.\n",
"- Since we didn't apply any image augmentation, we could simply use torchvision's `to_tensor` function; however, as a general template that provides flexibility for more complex transformation chains, let's use the `Compose` class for this:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "d990431b",
"metadata": {},
"outputs": [],
"source": [
"# ToTensor converts the PIL image to a CHW float tensor scaled to [0, 1];\n",
"# NOTE(review): intended to mirror the DataModule's preprocessing - keep in sync.\n",
"transform = transforms.Compose([transforms.ToTensor()])\n",
"\n",
"image_chw = transform(image)"
]
},
{
"cell_type": "markdown",
"id": "811de2fd",
"metadata": {},
"source": [
"- Note that `ToTensor` returns the image in the CHW format. CHW refers to the dimensions and stands for channel, height, and width."
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "14785180",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([3, 32, 32])\n"
]
}
],
"source": [
"# Expected layout: (channels, height, width)\n",
"print(image_chw.size())"
]
},
{
"cell_type": "markdown",
"id": "d17591e9",
"metadata": {},
"source": [
"- However, the PyTorch / PyTorch Lightning model expects images in NCHW format, where N stands for the number of images (e.g., in a batch).\n",
"- We can add the additional batch dimension via `unsqueeze` as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "b4422205",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 3, 32, 32])\n"
]
}
],
"source": [
"# Prepend a batch axis of size 1: CHW -> NCHW\n",
"image_nchw = image_chw.unsqueeze(dim=0)\n",
"print(image_nchw.size())"
]
},
{
"cell_type": "markdown",
"id": "cdd09c4c",
"metadata": {},
"source": [
"- Now that we have the image in the right format, we can feed it to our classifier:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "6e76be24",
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad():  # since we don't need to backprop\n",
"    logits = lightning_model(image_nchw)\n",
"\n",
"# Normalize over the class dimension; with a batch of one image, the flat\n",
"# argmax over probas is the predicted class index (a 0-d tensor).\n",
"probas = torch.softmax(logits, dim=1)\n",
"predicted_label = probas.argmax()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "88b1576d",
"metadata": {},
"outputs": [],
"source": [
"# CIFAR-10 class names, listed in class-index order (0-9)\n",
"int_to_str = dict(enumerate([\n",
"    'airplane', 'automobile', 'bird', 'cat', 'deer',\n",
"    'dog', 'frog', 'horse', 'ship', 'truck',\n",
"]))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "ceabf906",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted label: airplane\n",
"Class-membership probability 98.94%\n"
]
}
],
"source": [
"# predicted_label is a 0-d tensor; .item() turns it into a Python int key.\n",
"# probas has a single row (a batch of one image), hence probas[0].\n",
"print(f'Predicted label: {int_to_str[predicted_label.item()]}')\n",
"print(f'Class-membership probability {probas[0][predicted_label]*100:.2f}%')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}