{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "b1739f6e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch : 1.12.1\n", "lightning : 2022.9.8\n", "torchmetrics: 0.9.3\n", "matplotlib : 3.5.2\n", "\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -p torch,lightning,torchmetrics,matplotlib" ] }, { "cell_type": "markdown", "id": "d198b6a3", "metadata": {}, "source": [ "# Class distance weighted cross-entropy loss for ordinal regression and deep learning -- cement strength dataset" ] }, { "cell_type": "markdown", "id": "e0ffe6ee", "metadata": {}, "source": [ "Implementation of a method for ordinal regression by Polat et al. (2022) [1].\n", "\n", "**Paper reference:**\n", "\n", "- [1] G. Polat, I. Ergenc, H. T. Kani, Y. O. Alahdab, O. Atug, A. Temizel. \"[Class Distance Weighted Cross-Entropy Loss for Ulcerative Colitis Severity Estimation](https://arxiv.org/abs/2202.05167).\" arXiv preprint arXiv:2202.05167 (2022)." ] }, { "cell_type": "markdown", "id": "654568ec", "metadata": {}, "source": [ "**Note:**\n", "\n", "To keep the notation lean and minimal, this notebook only contains \"Class Distance Weighted Cross-Entropy Loss\"-specific comments. For more comments on the PyTorch Lightning usage, please see the cross-entropy baseline notebook [baseline-light_cement.ipynb](./baseline-light_cement.ipynb)."
] }, { "cell_type": "markdown", "id": "13449160", "metadata": {}, "source": [ "## General settings and hyperparameters" ] }, { "cell_type": "code", "execution_count": 2, "id": "1acc2499", "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 128\n", "NUM_EPOCHS = 200\n", "LEARNING_RATE = 0.005\n", "NUM_WORKERS = 0\n", "\n", "DATA_BASEPATH = \".\"" ] }, { "cell_type": "markdown", "id": "fa5d9470-1ffa-499e-8ef4-ba4c97bb8f3c", "metadata": {}, "source": [ "## Implementing the class distance weighted cross-entropy loss" ] }, { "cell_type": "markdown", "id": "f6179955-6513-4be7-9384-87a8b056b858", "metadata": {}, "source": [ "According to the paper, the loss is defined as follows:\n", "\n", "$$\mathrm{CDW\text{-}CE}=-\sum_{i=0}^{N-1} \log \left(1-\hat{y}_i\right) \times|i-c|^{\text{power}}$$\n", "\n", "where\n", "\n", "- $N$: the number of class labels\n", "- $\hat{y}_i$: the predicted probability (softmax score) for class $i$\n", "- $c$: the ground-truth class\n", "- power: a hyperparameter that determines the strength of the cost coefficient" ] }, { "cell_type": "code", "execution_count": 3, "id": "55306c8a-d3e1-4555-bd48-47f85fd384c2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.3792, 0.3104, 0.3104],\n", " [0.3072, 0.4147, 0.2780],\n", " [0.4263, 0.2248, 0.3490],\n", " [0.2668, 0.2978, 0.4354]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn.functional as F\n", "\n", "targets = torch.tensor([0, 2, 1, 2])\n", "\n", "logits = torch.tensor([[-0.3, -0.5, -0.5], # each row is 1 training example\n", " [-0.4, -0.1, -0.5],\n", " [-0.3, -0.94, -0.5],\n", " [-0.99, -0.88, -0.5]])\n", "\n", "probas = F.softmax(logits, dim=1)\n", "probas" ] }, { "cell_type": "code", "execution_count": 4, "id": "a8628712-7e04-46b9-a06e-64221e7fe313", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([12.2654, 12.2824, 0.9848, 10.2828])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }
], "source": [ "def cdw_ce_loss_naive1(probas, targets, power=5):\n", " \n", " loss = torch.zeros(probas.shape[0])\n", " for example in range(probas.shape[0]):\n", " for i in range(probas.shape[1]):\n", " loss[example] += -torch.log(1-probas[example, i]) * torch.abs(i - targets[example])**power\n", " \n", " return loss\n", " \n", "cdw_ce_loss_naive1(probas, targets)" ] }, { "cell_type": "code", "execution_count": 5, "id": "4dfa0d27-972d-4239-8a88-3864858131f0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "147 µs ± 2.64 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit cdw_ce_loss_naive1(probas, targets)" ] }, { "cell_type": "code", "execution_count": 6, "id": "2ed813da-4787-4d8e-898a-e12e832a25f0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([12.2654, 12.2824, 0.9848, 10.2828])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def cdw_ce_loss_naive2(probas, targets, power=5):\n", " \n", " loss = 0.\n", " for i in range(probas.shape[1]):\n", " loss += (-torch.log(1-probas[:, i]) * torch.abs(i - targets)**power)\n", " \n", " return loss\n", " \n", "cdw_ce_loss_naive2(probas, targets)" ] }, { "cell_type": "code", "execution_count": 7, "id": "893da285-61ed-4be9-9abd-f78abb37cfcc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "33.6 µs ± 82.6 ns per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%timeit cdw_ce_loss_naive2(probas, targets)" ] }, { "cell_type": "code", "execution_count": 8, "id": "217c1285-8535-402c-aa75-1cc91abe758c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([12.2654, 12.2824, 0.9848, 10.2828])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def cdw_ce_loss_naive3(probas, targets, power=5):\n", " \n", " labels = torch.arange(probas.shape[1]).repeat(probas.shape[0], 1)\n", " loss = (-torch.log(1-probas) * torch.abs(labels - targets.reshape(probas.shape[0], 1))**power).sum(dim=1)\n", " \n", " return loss\n", " \n", "cdw_ce_loss_naive3(probas, targets)" ] }, { "cell_type": "code", "execution_count": 9, "id": "218b4093-e98d-4b2a-990d-992a9af5a9ef", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "14.6 µs ± 61.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit cdw_ce_loss_naive3(probas, targets)" ] }, { "cell_type": "code", "execution_count": 10, "id": "3da4f855-1463-4953-9a7c-e77ee6cc78df", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(8.9538)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def cdw_ce_loss(logits, targets, power=5, reduction=\"mean\"):\n", " \n", " probas = torch.softmax(logits, dim=1)\n", " labels = torch.arange(probas.shape[1]).repeat(probas.shape[0], 1).type_as(logits)\n", " loss = (-torch.log(1-probas) * torch.abs(labels - targets.type_as(labels).\\\n", " reshape(probas.shape[0], 1))**power).sum(dim=1)\n", " \n", " if reduction == \"none\":\n", " return loss\n", " elif reduction == \"sum\":\n", " return loss.sum()\n", " elif reduction == \"mean\":\n", " return loss.mean() \n", " else:\n", " raise ValueError(\"reduction must be 'none', 'sum', or 'mean'\") \n", "\n", "cdw_ce_loss(logits, targets)" ] }, { "cell_type": "markdown", "id": "debfdfd1", "metadata": {}, 
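"source": [ "As a quick sanity check (an illustrative example, not taken from the paper): with the default `power=5`, a confident prediction for a class far away from the true label should incur a much larger CDW-CE loss than an equally confident prediction for an adjacent class. The cell below reuses the `cdw_ce_loss` function defined above." ] }, { "cell_type": "code", "execution_count": null, "id": "cdwce-sanity-check", "metadata": {}, "outputs": [], "source": [ "losses = {}\n", "# true class is 0; probability mass sits on class 1 (adjacent) vs. class 2 (distant)\n", "for name, logits in [(\"adjacent\", torch.tensor([[0., 5., 0.]])),\n", "                     (\"distant\", torch.tensor([[0., 0., 5.]]))]:\n", "    losses[name] = cdw_ce_loss(logits, torch.tensor([0]))\n", "\n", "losses" ] }, { "cell_type": "markdown", "id": "cdwce-mlp-heading", "metadata": {},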
"source": [ "## Implementing a `MultiLayerPerceptron` using PyTorch Lightning's `LightningModule`" ] }, { "cell_type": "code", "execution_count": 11, "id": "1922e731", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "\n", "class MultiLayerPerceptron(torch.nn.Module):\n", " def __init__(self, input_size, hidden_units, num_classes):\n", " super().__init__()\n", "\n", " self.num_classes = num_classes\n", " \n", " all_layers = []\n", " for hidden_unit in hidden_units:\n", " layer = torch.nn.Linear(input_size, hidden_unit)\n", " all_layers.append(layer)\n", " all_layers.append(torch.nn.ReLU())\n", " input_size = hidden_unit\n", "\n", " output_layer = torch.nn.Linear(hidden_units[-1], num_classes)\n", " \n", " all_layers.append(output_layer)\n", " self.model = torch.nn.Sequential(*all_layers) \n", " \n", " def forward(self, x):\n", " x = self.model(x)\n", " return x" ] }, { "cell_type": "markdown", "id": "6562578a", "metadata": {}, "source": [ "Next, we will use the `MultiLayerPerceptron` model and the `cdw_ce_loss` function in the `LightningModule` below. 
They are only used in the `_shared_step` method as indicated:" ] }, { "cell_type": "code", "execution_count": 12, "id": "1381f777", "metadata": {}, "outputs": [], "source": [ "import lightning as L\n", "import torchmetrics\n", "\n", "\n", "class LightningMLP(L.LightningModule):\n", " def __init__(self, model, learning_rate):\n", " super().__init__()\n", "\n", " self.learning_rate = learning_rate\n", " self.model = model\n", "\n", " self.save_hyperparameters(ignore=['model'])\n", "\n", " self.train_mae = torchmetrics.MeanAbsoluteError()\n", " self.valid_mae = torchmetrics.MeanAbsoluteError()\n", " self.test_mae = torchmetrics.MeanAbsoluteError()\n", " \n", " def forward(self, x):\n", " return self.model(x)\n", "\n", " def _shared_step(self, batch):\n", " features, true_labels = batch\n", " logits = self(features)\n", "\n", " loss = cdw_ce_loss(logits, true_labels)\n", " predicted_labels = torch.argmax(logits, dim=1)\n", " \n", " return loss, true_labels, predicted_labels\n", "\n", " def training_step(self, batch, batch_idx):\n", " loss, true_labels, predicted_labels = self._shared_step(batch)\n", " self.log(\"train_loss\", loss)\n", " self.train_mae(predicted_labels.float(), true_labels.float())\n", " self.log(\"train_mae\", self.train_mae, on_epoch=True, on_step=False)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " loss, true_labels, predicted_labels = self._shared_step(batch)\n", " self.log(\"valid_loss\", loss)\n", " self.valid_mae(predicted_labels.float(), true_labels.float())\n", " self.log(\"valid_mae\", self.valid_mae,\n", " on_epoch=True, on_step=False, prog_bar=True)\n", "\n", " def test_step(self, batch, batch_idx):\n", " loss, true_labels, predicted_labels = self._shared_step(batch)\n", " self.test_mae(predicted_labels.float(), true_labels.float())\n", " self.log(\"test_mae\", self.test_mae, on_epoch=True, on_step=False)\n", "\n", " def configure_optimizers(self):\n", " optimizer = torch.optim.Adam(self.parameters(), 
lr=self.learning_rate)\n", " return optimizer" ] }, { "cell_type": "markdown", "id": "0ae78a8e", "metadata": {}, "source": [ "---\n", "\n", "# Note: There Are No Changes Compared To The Baseline Below\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "39ac693d", "metadata": {}, "source": [ "## Setting up the dataset" ] }, { "cell_type": "markdown", "id": "b070e95b", "metadata": {}, "source": [ "### Inspecting the dataset" ] }, { "cell_type": "code", "execution_count": 13, "id": "a26a6486", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | response | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 |\n",
"|---|---|---|---|---|---|---|---|---|---|\n",
"| 0 | 4 | 540.0 | 0.0 | 0.0 | 162.0 | 2.5 | 1040.0 | 676.0 | 28 |\n",
"| 1 | 4 | 540.0 | 0.0 | 0.0 | 162.0 | 2.5 | 1055.0 | 676.0 | 28 |\n",
"| 2 | 2 | 332.5 | 142.5 | 0.0 | 228.0 | 0.0 | 932.0 | 594.0 | 270 |\n",
"| 3 | 2 | 332.5 | 142.5 | 0.0 | 228.0 | 0.0 | 932.0 | 594.0 | 365 |\n",
"| 4 | 2 | 198.6 | 132.4 | 0.0 | 192.0 | 0.0 | 978.4 | 825.5 | 360 |\n",
"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃ Test metric ┃ DataLoader 0 ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ test_mae │ 0.5199999809265137 │\n",
"└───────────────────────────┴───────────────────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│\u001b[36m \u001b[0m\u001b[36m test_mae \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5199999809265137 \u001b[0m\u001b[35m \u001b[0m│\n",
"└───────────────────────────┴───────────────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[{'test_mae': 0.5199999809265137}]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.test(model=lightning_model, datamodule=data_module, ckpt_path='best')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}