{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "66951a8c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch : 1.10.1\n",
"pytorch_lightning: 1.6.0.dev0\n",
"torchmetrics : 0.6.2\n",
"matplotlib : 3.3.4\n",
"coral_pytorch : 1.2.0\n",
"\n"
]
}
],
"source": [
"%load_ext watermark\n",
"%watermark -p torch,pytorch_lightning,torchmetrics,matplotlib,coral_pytorch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "60a1d929",
"metadata": {},
"outputs": [],
"source": [
"%load_ext pycodestyle_magic\n",
"%flake8_on --ignore W291,W293,E703"
]
},
{
"cell_type": "markdown",
"id": "6dc0efd6",
"metadata": {},
"source": [
"
\n",
"\n",
"# Cross entropy baseline model for ordinal regression and deep learning -- cement strength dataset"
]
},
{
"cell_type": "markdown",
"id": "47f22067",
"metadata": {},
"source": [
"This is a regular cross entropy classifier as a baseline for comparison with ordinal regression methods."
]
},
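{
"cell_type": "markdown",
"id": "1a2b3c4d",
"metadata": {},
"source": [
"- As a quick illustration (not part of the original analysis): plain cross entropy treats the class labels as nominal, so a prediction that is one rank away from the true label and one that is several ranks away can receive exactly the same loss. This is why the mean absolute error (MAE), which respects the label ordering, is tracked separately below. A minimal sketch of this property:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b3c4d5e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"# Hypothetical example, not from the original notebook: two wrong\n",
"# predictions that put the same probability mass on the true class (0),\n",
"# once on the neighboring rank and once on a distant rank.\n",
"true_label = torch.tensor([0])\n",
"off_by_one = torch.log(torch.tensor([[0.1, 0.7, 0.1, 0.1]]))\n",
"off_by_three = torch.log(torch.tensor([[0.1, 0.1, 0.1, 0.7]]))\n",
"\n",
"# Both calls print the same loss, even though the second prediction\n",
"# is much further away from the true label in the ordinal sense.\n",
"print(F.cross_entropy(off_by_one, true_label))\n",
"print(F.cross_entropy(off_by_three, true_label))"
]
},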
{
"cell_type": "markdown",
"id": "6c528557",
"metadata": {},
"source": [
"## General settings and hyperparameters"
]
},
{
"cell_type": "markdown",
"id": "2527be84",
"metadata": {},
"source": [
"- Here, we specify some general hyperparameter values and general settings\n",
"- Note that for small datatsets, it is not necessary and better not to use multiple workers as it can sometimes cause issues with too many open files in PyTorch"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "37d461ef",
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 128\n",
"NUM_EPOCHS = 200\n",
"LEARNING_RATE = 0.005\n",
"NUM_WORKERS = 0\n",
"\n",
"DATA_BASEPATH = \".\""
]
},
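{
"cell_type": "markdown",
"id": "3c4d5e6f",
"metadata": {},
"source": [
"- The data loading code is not part of this excerpt; the following is only a minimal sketch (with a made-up `TensorDataset`) of where `BATCH_SIZE` and `NUM_WORKERS` would typically enter a PyTorch `DataLoader`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d5e6f7a",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"\n",
"# Dummy data for illustration only: 100 examples, 8 features, 5 classes.\n",
"dummy_dataset = TensorDataset(torch.randn(100, 8),\n",
"                              torch.randint(0, 5, (100,)))\n",
"\n",
"# With num_workers=0, batches are prepared in the main process, which\n",
"# avoids the \"too many open files\" issue on small datasets.\n",
"dummy_loader = DataLoader(dummy_dataset, batch_size=BATCH_SIZE,\n",
"                          shuffle=True, num_workers=NUM_WORKERS)"
]
},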
{
"cell_type": "markdown",
"id": "ddfe5704",
"metadata": {},
"source": [
"## Implementing a `MultiLayerPerceptron` using PyTorch Lightning's `LightningModule`"
]
},
{
"cell_type": "markdown",
"id": "829a6664",
"metadata": {},
"source": [
"- In this section, we set up the main model architecture using the `LightningModule` from PyTorch Lightning.\n",
"- We start with defining our `MultiLayerPerceptron` model in pure PyTorch, and then we use it in the `LightningModule` to get all the extra benefits that PyTorch Lightning provides."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c03faadc",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"# Regular PyTorch Module\n",
"class MultiLayerPerceptron(torch.nn.Module):\n",
" def __init__(self, input_size, hidden_units, num_classes):\n",
" super().__init__()\n",
"\n",
" # num_classes is used by the corn loss function\n",
" self.num_classes = num_classes\n",
" \n",
" # Initialize MLP layers\n",
" all_layers = []\n",
" for hidden_unit in hidden_units:\n",
" layer = torch.nn.Linear(input_size, hidden_unit)\n",
" all_layers.append(layer)\n",
" all_layers.append(torch.nn.ReLU())\n",
" input_size = hidden_unit\n",
"\n",
" output_layer = torch.nn.Linear(hidden_units[-1], num_classes)\n",
" \n",
" all_layers.append(output_layer)\n",
" self.model = torch.nn.Sequential(*all_layers)\n",
" \n",
" def forward(self, x):\n",
" x = self.model(x)\n",
" return x"
]
},
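{
"cell_type": "markdown",
"id": "5e6f7a8b",
"metadata": {},
"source": [
"- As a quick sanity check (not part of the original notebook), we can instantiate the `MultiLayerPerceptron` with arbitrary sizes and confirm that it maps a batch of feature vectors to one logit per class:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f7a8b9c",
"metadata": {},
"outputs": [],
"source": [
"# The sizes below are made up for this check; the real input size and\n",
"# number of classes depend on the cement dataset used later.\n",
"torch.manual_seed(1)\n",
"tmp_model = MultiLayerPerceptron(input_size=8,\n",
"                                 hidden_units=(25, 20),\n",
"                                 num_classes=5)\n",
"tmp_logits = tmp_model(torch.randn(4, 8))\n",
"print(tmp_logits.shape)  # expected: torch.Size([4, 5])"
]
},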
{
"cell_type": "markdown",
"id": "8ff4cacb",
"metadata": {},
"source": [
"- In our `LightningModule` we use loggers to track mean absolute errors for both the training and validation set during training; this allows us to select the best model based on validation set performance later."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0c14a021",
"metadata": {},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"import torchmetrics\n",
"\n",
"\n",
"# LightningModule that receives a PyTorch model as input\n",
"class LightningMLP(pl.LightningModule):\n",
" def __init__(self, model, learning_rate):\n",
" super().__init__()\n",
"\n",
" self.learning_rate = learning_rate\n",
" # The inherited PyTorch module\n",
" self.model = model\n",
"\n",
" # Save settings and hyperparameters to the log directory\n",
" # but skip the model parameters\n",
" self.save_hyperparameters(ignore=['model'])\n",
"\n",
" # Set up attributes for computing the MAE\n",
" self.train_mae = torchmetrics.MeanAbsoluteError()\n",
" self.valid_mae = torchmetrics.MeanAbsoluteError()\n",
" self.test_mae = torchmetrics.MeanAbsoluteError()\n",
" \n",
" # Defining the forward method is only necessary \n",
" # if you want to use a Trainer's .predict() method (optional)\n",
" def forward(self, x):\n",
" return self.model(x)\n",
" \n",
" # A common forward step to compute the loss and labels\n",
" # this is used for training, validation, and testing below\n",
" def _shared_step(self, batch):\n",
" features, true_labels = batch\n",
" logits = self(features)\n",
" loss = torch.nn.functional.cross_entropy(logits, true_labels)\n",
" predicted_labels = torch.argmax(logits, dim=1)\n",
"\n",
" return loss, true_labels, predicted_labels\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" loss, true_labels, predicted_labels = self._shared_step(batch)\n",
" self.log(\"train_loss\", loss)\n",
" self.train_mae(predicted_labels, true_labels)\n",
" self.log(\"train_mae\", self.train_mae, on_epoch=True, on_step=False)\n",
" return loss # this is passed to the optimzer for training\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" loss, true_labels, predicted_labels = self._shared_step(batch)\n",
" self.log(\"valid_loss\", loss)\n",
" self.valid_mae(predicted_labels, true_labels)\n",
" self.log(\"valid_mae\", self.valid_mae,\n",
" on_epoch=True, on_step=False, prog_bar=True)\n",
"\n",
" def test_step(self, batch, batch_idx):\n",
" loss, true_labels, predicted_labels = self._shared_step(batch)\n",
" self.test_mae(predicted_labels, true_labels)\n",
" self.log(\"test_mae\", self.test_mae, on_epoch=True, on_step=False)\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" return optimizer"
]
},
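{
"cell_type": "markdown",
"id": "7a8b9c0d",
"metadata": {},
"source": [
"- The training cell itself is not included in this excerpt; below is only a minimal sketch (assuming the classes defined above and a data module created elsewhere) of how the logged `valid_mae` can drive best-model selection via PyTorch Lightning's `ModelCheckpoint` callback:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b9c0d1e",
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
"\n",
"\n",
"# Keep the checkpoint with the lowest validation MAE; \"valid_mae\" is\n",
"# the metric key logged in validation_step above.\n",
"checkpoint_callback = ModelCheckpoint(save_top_k=1,\n",
"                                      mode=\"min\",\n",
"                                      monitor=\"valid_mae\")\n",
"\n",
"# Sketch only -- `lightning_model` and `data_module` are created in\n",
"# parts of the notebook that are not shown here:\n",
"# trainer = pl.Trainer(max_epochs=NUM_EPOCHS,\n",
"#                      callbacks=[checkpoint_callback])\n",
"# trainer.fit(model=lightning_model, datamodule=data_module)"
]
},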
{
"cell_type": "markdown",
"id": "6ec9d753",
"metadata": {},
"source": [
"## Setting up the dataset"
]
},
{
"cell_type": "markdown",
"id": "5621120c",
"metadata": {},
"source": [
"- In this section, we are going to set up our dataset.\n",
"- We start by downloading and taking a look at the Cement dataset:"
]
},
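{
"cell_type": "markdown",
"id": "9c0d1e2f",
"metadata": {},
"source": [
"- The download/loading cell is not included in this excerpt; a minimal sketch, assuming a local CSV file named `cement_strength.csv` under `DATA_BASEPATH` (the file name is an assumption), might look as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d1e2f3a",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import pandas as pd\n",
"\n",
"\n",
"# Hypothetical file location; the dataset has eight feature columns\n",
"# (V1-V8) and an ordinal \"response\" column, as previewed below.\n",
"csv_path = os.path.join(DATA_BASEPATH, \"cement_strength.csv\")\n",
"data_df = pd.read_csv(csv_path)\n",
"data_df.head()"
]
},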
{
"cell_type": "markdown",
"id": "5311a22d",
"metadata": {},
"source": [
"### Inspecting the dataset"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "27c9ba77",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
| \n", " | response | \n", "V1 | \n", "V2 | \n", "V3 | \n", "V4 | \n", "V5 | \n", "V6 | \n", "V7 | \n", "V8 | \n", "
|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "4 | \n", "540.0 | \n", "0.0 | \n", "0.0 | \n", "162.0 | \n", "2.5 | \n", "1040.0 | \n", "676.0 | \n", "28 | \n", "
| 1 | \n", "4 | \n", "540.0 | \n", "0.0 | \n", "0.0 | \n", "162.0 | \n", "2.5 | \n", "1055.0 | \n", "676.0 | \n", "28 | \n", "
| 2 | \n", "2 | \n", "332.5 | \n", "142.5 | \n", "0.0 | \n", "228.0 | \n", "0.0 | \n", "932.0 | \n", "594.0 | \n", "270 | \n", "
| 3 | \n", "2 | \n", "332.5 | \n", "142.5 | \n", "0.0 | \n", "228.0 | \n", "0.0 | \n", "932.0 | \n", "594.0 | \n", "365 | \n", "
| 4 | \n", "2 | \n", "198.6 | \n", "132.4 | \n", "0.0 | \n", "192.0 | \n", "0.0 | \n", "978.4 | \n", "825.5 | \n", "360 | \n", "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃ Test metric ┃ DataLoader 0 ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ test_mae │ 0.28999999165534973 │\n",
"└───────────────────────────┴───────────────────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│\u001b[36m \u001b[0m\u001b[36m test_mae \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.28999999165534973 \u001b[0m\u001b[35m \u001b[0m│\n",
"└───────────────────────────┴───────────────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[{'test_mae': 0.28999999165534973}]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.test(model=lightning_model, datamodule=data_module, ckpt_path='best')"
]
},
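{
"cell_type": "markdown",
"id": "1e2f3a4b",
"metadata": {},
"source": [
"- The 1.03 MAE baseline referred to below was computed in a part of the notebook that is not shown in this excerpt. As a rough, hypothetical sketch, one common way to obtain such a baseline is to always predict the most frequent training label and measure the resulting MAE:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f3a4b5c",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper, not the original baseline cell: train_labels\n",
"# and test_labels would be 1D integer label tensors from the split\n",
"# created in the full notebook.\n",
"def majority_baseline_mae(train_labels, test_labels):\n",
"    majority_class = torch.bincount(train_labels).argmax().item()\n",
"    baseline_preds = torch.full_like(test_labels, majority_class)\n",
"    return torch.mean(\n",
"        torch.abs(baseline_preds.float() - test_labels.float()))\n",
"\n",
"\n",
"# Example with made-up labels:\n",
"print(majority_baseline_mae(torch.tensor([0, 1, 1, 2]),\n",
"                            torch.tensor([0, 2, 1])))"
]
},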
{
"cell_type": "markdown",
"id": "ec7d93a1",
"metadata": {},
"source": [
"- The MAE of our model is quite good, especially compared to the 1.03 MAE baseline earlier."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}