{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Visualization in TorchOpt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Open In Colab](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/2_Visualization.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In [PyTorch](https://pytorch.org), if a tensor's `requires_grad` attribute is `True`, every operation that uses the tensor is recorded in a computation graph. The graph is stored much like a linked list: each `Tensor` produced by such an operation refers to the backward function that created it through its `grad_fn` attribute, and that function links in turn to the backward functions of its inputs. [PyTorchViz](https://github.com/szagoruyko/pytorchviz) is a Python package that uses [Graphviz](https://graphviz.org) as a backend for plotting computation graphs. TorchOpt takes PyTorchViz as its blueprint and provides easier-to-use visualization functions while supporting all of PyTorchViz's features." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start with a simple multiplication computation graph. We declare the variable `x` with `requires_grad=True` and compute `y = 2 * x`. Then we visualize the computation graph of `y`."
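] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before plotting, the linked structure described above can be walked by hand. The following is a minimal plain-PyTorch sketch, independent of TorchOpt: starting from `y.grad_fn`, it follows each node's `next_functions` pairs until it reaches the leaf `x`. The helper `collect_nodes` is a hypothetical name used only for this illustration. `AccumulateGrad` nodes expose their leaf tensor through the `variable` attribute, which is the same attribute PyTorchViz reads to locate leaf tensors." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"x = torch.tensor(1.0, requires_grad=True)\n",
"y = 2 * x\n",
"\n",
"\n",
"def collect_nodes(tensor):\n",
"    # Hypothetical helper: breadth-first walk over the backward graph from tensor.grad_fn\n",
"    order, seen, queue = [], set(), [tensor.grad_fn]\n",
"    while queue:\n",
"        fn = queue.pop(0)\n",
"        if fn is None or fn in seen:\n",
"            continue  # None marks inputs that do not require grad\n",
"        seen.add(fn)\n",
"        order.append(fn)\n",
"        # next_functions holds (node, input_index) pairs pointing toward the inputs\n",
"        queue.extend(next_fn for next_fn, _ in fn.next_functions)\n",
"    return order\n",
"\n",
"\n",
"nodes = collect_nodes(y)\n",
"names = [type(fn).__name__ for fn in nodes]\n",
"# AccumulateGrad nodes keep a reference to their leaf tensor in `variable`\n",
"leaves = [fn.variable for fn in nodes if hasattr(fn, 'variable')]\n",
"print(names)\n",
"print(leaves[0] is x)"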
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": "[Figure: computation graph of y: x () -> AccumulateGrad -> MulBackward0 -> y ()]" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from IPython.display import display\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "import torchopt\n", "\n", "\n", "x = torch.tensor(1.0, requires_grad=True)\n", "y = 2 * x\n", "display(torchopt.visual.make_dot(y, params={'x': x, 'y': y}))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The figure shows that `y` is connected to `x` through the multiplication node. During backpropagation, the gradient of `y` flows through the multiplication backward function (`MulBackward0`) and is then accumulated on `x` (`AccumulateGrad`). Note that we pass a dictionary to add node labels." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's plot a neural network. Note that we can pass the generator returned by the `named_parameters` method to add node labels."
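] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick check of what those labels will be, here is a minimal sketch that prints the `(name, parameter)` pairs yielded by `named_parameters` for a module with the same layout as the `Net` class used below (`TinyNet` is a hypothetical stand-in). The dotted names, such as `fc.weight`, are the strings that end up as node labels. Since `named_parameters` returns a generator, it is exhausted after one pass; wrapping it in `dict(...)` keeps the pairs reusable." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class TinyNet(nn.Module):\n",
"    # Hypothetical stand-in with the same layout as Net below\n",
"    def __init__(self, dim):\n",
"        super().__init__()\n",
"        self.fc = nn.Linear(dim, 1, bias=True)\n",
"\n",
"\n",
"# Materialize the generator so the (name, tensor) pairs can be reused\n",
"named = dict(TinyNet(5).named_parameters())\n",
"print({name: tuple(p.shape) for name, p in named.items()})"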
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": "[Figure: computation graph of loss: fc.weight (1, 5) -> AccumulateGrad -> TBackward0 -> AddmmBackward0; fc.bias (1) -> AccumulateGrad -> AddmmBackward0; AddmmBackward0 -> MseLossBackward0 -> loss ()]" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class Net(nn.Module):\n", "    def __init__(self, dim):\n", "        super().__init__()\n", "        self.fc = nn.Linear(dim, 1, bias=True)\n", "\n", "    def forward(self, x):\n", "        return self.fc(x)\n", "\n", "\n", "dim = 5\n", "batch_size = 2\n", "net = Net(dim)\n", "xs = torch.ones((batch_size, dim))\n", "ys = torch.ones((batch_size, 1))\n", "pred = net(xs)\n", "loss = F.mse_loss(pred, ys)\n", "\n", "display(torchopt.visual.make_dot(loss, params=(net.named_parameters(), {'loss': loss})))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The computation graph of a meta-learning algorithm is much more complex. Our visualization tool lets users pass the extracted network states as input, so that parameters from different optimization steps receive distinct labels."
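] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see where the extra complexity comes from, here is a minimal plain-PyTorch sketch of a single differentiable SGD step, assuming the update `new_param = param - lr * grad`. It illustrates the idea only and is not TorchOpt's `MetaSGD` implementation; `weight` and `bias` are hypothetical stand-ins for `fc.weight` and `fc.bias`. Because the inner gradient is computed with `create_graph=True`, the updated parameters are non-leaf tensors linked to the old ones, so the outer loss carries the inner gradient computation (including second-order backward nodes) back to both the initial parameters and `meta_param`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"dim, batch_size, lr = 5, 2, 1e-3\n",
"xs = torch.ones((batch_size, dim))\n",
"ys = torch.ones((batch_size, 1))\n",
"\n",
"# Hypothetical stand-ins for fc.weight and fc.bias\n",
"weight = torch.randn(1, dim, requires_grad=True)\n",
"bias = torch.zeros(1, requires_grad=True)\n",
"meta_param = torch.tensor(1.0, requires_grad=True)\n",
"\n",
"# Inner loss with the initial (step 0) parameters\n",
"inner_loss = F.mse_loss(F.linear(xs, weight, bias) + meta_param, ys)\n",
"# create_graph=True makes the gradient itself differentiable\n",
"grad_w, grad_b = torch.autograd.grad(inner_loss, (weight, bias), create_graph=True)\n",
"# One SGD step: the step-1 parameters are non-leaf tensors tied to step 0\n",
"weight_1 = weight - lr * grad_w\n",
"bias_1 = bias - lr * grad_b\n",
"\n",
"# Outer loss with the updated parameters, differentiated w.r.t. meta_param\n",
"outer_loss = F.mse_loss(F.linear(xs, weight_1, bias_1) + meta_param, ys)\n",
"(meta_grad,) = torch.autograd.grad(outer_loss, meta_param)\n",
"print(type(weight_1.grad_fn).__name__, meta_grad)"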
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": "[Figure: computation graph of loss across one MetaSGD step. step0.fc.weight (1, 5), step0.fc.bias (1) and meta_param () are leaves (AccumulateGrad). The step-0 forward pass and its loss gradient (MseLossBackwardBackward0) produce the updates that turn them into step1.fc.weight (1, 5) and step1.fc.bias (1), both AddBackward0 nodes. The step-1 forward pass (AddmmBackward0, then AddBackward0 with meta_param) ends in MseLossBackward0 -> loss ().]" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class MetaNet(nn.Module):\n", "    def __init__(self, dim):\n", "        super().__init__()\n", "        self.fc = nn.Linear(dim, 1, bias=True)\n", "\n", "    def forward(self, x, meta_param):\n", "        return self.fc(x) + meta_param\n", "\n", "\n", "dim = 5\n", "batch_size = 2\n", "net = MetaNet(dim)\n", "\n", "xs = torch.ones((batch_size, dim))\n", "ys = torch.ones((batch_size, 1))\n", "\n", "optimizer = torchopt.MetaSGD(net, lr=1e-3)\n", "meta_param = torch.tensor(1.0, requires_grad=True)\n", "\n", "# Extract the initial state; enable_visual labels its nodes with the 'step0.' prefix\n", "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", "\n", "pred = net(xs, meta_param)\n", "loss = F.mse_loss(pred, ys)\n", "optimizer.step(loss)\n", "\n", "# Extract the state after one differentiable MetaSGD step, labeled with 'step1.'\n", "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", "\n", "pred = net(xs, meta_param)\n", "loss = F.mse_loss(pred, torch.ones_like(pred))\n", "\n", "# Draw the computation graph, labeling nodes from both extracted states\n", "display(\n", "    torchopt.visual.make_dot(\n", "        loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]\n", "    )\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" }, "vscode": { "interpreter": { "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" } } }, "nbformat": 4, "nbformat_minor": 4 }