{ "cells": [ { "cell_type": "markdown", "id": "561ef1d6-e8fc-431f-9c58-4c862b2813ec", "metadata": {}, "source": [ "# PyTorch DataLoader State and Nested Iterations" ] }, { "cell_type": "code", "execution_count": 1, "id": "37d57e04-facc-4543-9939-c724a57ce9c6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'2.1.0'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "torch.__version__" ] }, { "cell_type": "markdown", "id": "9fb9f618-f274-4f76-951f-40508cb85c1d", "metadata": {}, "source": [ "Iterating over a dataloader in a separate function will not affect its state in the main training loop. In PyTorch, a DataLoader is typically an iterable that can be iterated over multiple times independently. Each iteration over the DataLoader starts from the beginning and goes through the dataset in a fresh sequence (if shuffle is true, the sequence will be different each time).\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "1813f041-0366-4be3-890f-99c1a9c9d831", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "main loop: 1\n", "nested loop: 1\n", "nested loop: 2\n", "nested loop: 3\n", "nested loop: 4\n", "nested loop: 5\n", "main loop: 2\n", "nested loop: 1\n", "nested loop: 2\n", "nested loop: 3\n", "nested loop: 4\n", "nested loop: 5\n", "main loop: 3\n", "nested loop: 1\n", "nested loop: 2\n", "nested loop: 3\n", "nested loop: 4\n", "nested loop: 5\n", "main loop: 4\n", "nested loop: 1\n", "nested loop: 2\n", "nested loop: 3\n", "nested loop: 4\n", "nested loop: 5\n", "main loop: 5\n", "nested loop: 1\n", "nested loop: 2\n", "nested loop: 3\n", "nested loop: 4\n", "nested loop: 5\n", "main loop: 6\n", "nested loop: 1\n", "nested loop: 2\n", "nested loop: 3\n", "nested loop: 4\n", "nested loop: 5\n", "main loop: 7\n", "nested loop: 1\n", "nested loop: 2\n", "nested loop: 3\n", "nested loop: 4\n", "nested loop: 5\n", "main loop: 8\n", "nested loop: 1\n", "nested loop: 2\n", "nested loop: 3\n", "nested loop: 4\n", "nested loop: 5\n", "main loop: 9\n", "nested loop: 1\n", "nested loop: 2\n", "nested loop: 3\n", "nested loop: 4\n", "nested loop: 5\n", "main loop: 10\n", "nested loop: 1\n", "nested loop: 2\n", "nested loop: 3\n", "nested loop: 4\n", "nested loop: 5\n" ] } ], "source": [ "from torch.utils.data import Dataset, DataLoader\n", "\n", "# Custom Dataset class\n", "class IntegerDataset(Dataset):\n", " def __init__(self, start, end):\n", " self.data = list(range(start, end + 1))\n", "\n", " def __len__(self):\n", " return len(self.data)\n", "\n", " def __getitem__(self, idx):\n", " return self.data[idx]\n", "\n", "# Create a Dataset for integers 1 to 10\n", "integer_dataset = IntegerDataset(1, 10)\n", "\n", "# Create a DataLoader\n", "integer_loader = DataLoader(integer_dataset, batch_size=1, shuffle=False)\n", "\n", "# A function to estimate the loss based on a subset of training examples\n", "def calc_loss(data_loader, iters):\n", " for j in integer_loader:\n", " print(\"nested loop:\", j.item())\n", " if j >= iters: \n", " break\n", "\n", "# Example: Iterate over the DataLoader\n", "for i in integer_loader:\n", " print(\"main loop:\", i.item())\n", " calc_loss(integer_loader, iters=5)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }