{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualization in TorchOpt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Open In Colab](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/2_Visualization.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In [PyTorch](https://pytorch.org), if the attribute `requires_grad` of a tensor is `True`, PyTorch records every operation that uses the tensor in a computation graph. The graph is implemented like a linked list: each result `Tensor` refers to the backward function that created it through its `grad_fn` attribute, and each backward function links to the backward functions of its inputs. [PyTorchViz](https://github.com/szagoruyko/pytorchviz) is a Python package that uses [Graphviz](https://graphviz.org) as a backend for plotting computation graphs. TorchOpt uses PyTorchViz as a blueprint and provides easier-to-use visualization functions while supporting all of PyTorchViz's features."
]
},
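{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the linked-list structure concrete, the following minimal sketch walks the autograd graph of `y = 2 * x` in plain PyTorch, without TorchOpt. It starts at `y.grad_fn` and follows each node's `next_functions` back to the leaf tensor, which is roughly the traversal that PyTorchViz-style plotting tools perform. The helper `collect_graph` is hypothetical and is not part of TorchOpt's API.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def collect_graph(fn, depth=0, lines=None):\n",
"    # Hypothetical helper: depth-first walk over autograd backward nodes\n",
"    if lines is None:\n",
"        lines = []\n",
"    if fn is None:  # inputs that do not require grad have no backward node\n",
"        return lines\n",
"    lines.append('  ' * depth + type(fn).__name__)\n",
"    # next_functions holds (node, input_index) pairs for the node's inputs\n",
"    for next_fn, _ in fn.next_functions:\n",
"        collect_graph(next_fn, depth + 1, lines)\n",
"    return lines\n",
"\n",
"\n",
"x = torch.tensor(1.0, requires_grad=True)\n",
"y = 2 * x\n",
"for line in collect_graph(y.grad_fn):\n",
"    print(line)\n",
"```\n",
"\n",
"For this graph the walk should list a `MulBackward0` node followed by an indented `AccumulateGrad` node. `AccumulateGrad` is the node that accumulates gradients into the leaf `x`; it keeps a reference to `x` in its `variable` attribute."
]
},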
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start with a simple multiplication. We declare the variable `x` with `requires_grad=True`, compute `y = 2 * x`, and then visualize the computation graph of `y`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import torchopt\n",
"\n",
"\n",
"x = torch.tensor(1.0, requires_grad=True)\n",
"y = 2 * x\n",
"display(torchopt.visual.make_dot(y, params={'x': x, 'y': y}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The figure shows that `y` is connected to `x` through a multiplication node. During the backward pass, the gradient of `y` flows through the multiplication backward function and is then accumulated on `x`. Note that we pass a dictionary to `params` to label the nodes."
]
},
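{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gradient flow described above can be checked numerically with a minimal sketch in plain PyTorch: since `y = 2 * x`, calling `y.backward()` should store dy/dx = 2 in `x.grad`, and a second backward pass adds to that value instead of overwriting it.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"x = torch.tensor(1.0, requires_grad=True)\n",
"y = 2 * x\n",
"\n",
"# Backpropagate from y: the gradient passes through the multiplication\n",
"# backward function and is accumulated into the .grad field of the leaf x\n",
"y.backward()\n",
"print(x.grad)  # tensor(2.)\n",
"\n",
"# A second backward pass on a freshly built graph adds to x.grad\n",
"(2 * x).backward()\n",
"print(x.grad)  # tensor(4.)\n",
"```"
]
},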
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's plot the computation graph of a small neural network. Besides a dictionary, `params` also accepts the generator returned by the `named_parameters` method, and several label sources can be combined in a tuple."
]
},
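{
"cell_type": "markdown",
"metadata": {},
"source": [
"The node labels come from the parameter names. As a minimal sketch, the module below is a stand-in with a single `nn.Linear` submodule named `fc`, mirroring the `Net` class in this tutorial; `named_parameters` yields `(name, parameter)` pairs with qualified names.\n",
"\n",
"```python\n",
"import torch.nn as nn\n",
"\n",
"# Stand-in module: a bare nn.Module with one nn.Linear submodule named fc\n",
"net = nn.Module()\n",
"net.fc = nn.Linear(5, 1, bias=True)\n",
"\n",
"params = net.named_parameters()  # a generator of (name, Parameter) pairs\n",
"for name, param in params:\n",
"    print(name, tuple(param.shape))\n",
"# fc.weight (1, 5)\n",
"# fc.bias (1,)\n",
"\n",
"# The generator is now exhausted, so create a fresh one for each make_dot call\n",
"print(list(params))  # []\n",
"```"
]
},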
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class Net(nn.Module):\n",
"    def __init__(self, dim):\n",
"        super().__init__()\n",
"        self.fc = nn.Linear(dim, 1, bias=True)\n",
"\n",
"    def forward(self, x):\n",
"        return self.fc(x)\n",
"\n",
"\n",
"dim = 5\n",
"batch_size = 2\n",
"net = Net(dim)\n",
"xs = torch.ones((batch_size, dim))\n",
"ys = torch.ones((batch_size, 1))\n",
"pred = net(xs)\n",
"loss = F.mse_loss(pred, ys)\n",
"\n",
"# Label the nodes with the parameter names and the loss\n",
"display(torchopt.visual.make_dot(loss, params=(net.named_parameters(), {'loss': loss})))"
]
},
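{
"cell_type": "markdown",
"metadata": {},
"source": [
"The computation graph of a meta-learning algorithm is much more complex, because the inner-loop parameter updates are themselves part of the graph. Our visualization tool therefore also accepts network states extracted with `torchopt.extract_state_dict(..., enable_visual=True)` as labels, so parameters from different update steps can be told apart in the plot."
]
},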
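{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see where the extra structure comes from, here is a minimal sketch in plain PyTorch; it is not TorchOpt's implementation. It assumes that a `MetaSGD` step acts like a differentiable SGD update, `new_param = param - lr * grad`, where the gradient is computed with `create_graph=True` so that the update itself stays in the graph. The functional stand-in for `MetaNet` (tensors `weight`, `bias`, and `meta_param`, and the local `forward` function) is hypothetical.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"dim, batch_size, lr = 5, 2, 1e-3\n",
"xs = torch.ones(batch_size, dim)\n",
"ys = torch.ones(batch_size, 1)\n",
"\n",
"# Hypothetical functional stand-in for MetaNet: fc weight, fc bias, meta_param\n",
"weight = torch.randn(1, dim, requires_grad=True)\n",
"bias = torch.zeros(1, requires_grad=True)\n",
"meta_param = torch.tensor(1.0, requires_grad=True)\n",
"\n",
"\n",
"def forward(w, b):\n",
"    return F.linear(xs, w, b) + meta_param\n",
"\n",
"\n",
"# Inner step: create_graph=True keeps the gradient computation in the graph\n",
"inner_loss = F.mse_loss(forward(weight, bias), ys)\n",
"grad_w, grad_b = torch.autograd.grad(inner_loss, (weight, bias), create_graph=True)\n",
"new_weight = weight - lr * grad_w\n",
"new_bias = bias - lr * grad_b\n",
"\n",
"# Outer loss uses the updated parameters, so its graph contains the inner step\n",
"outer_loss = F.mse_loss(forward(new_weight, new_bias), torch.ones(batch_size, 1))\n",
"outer_loss.backward()\n",
"\n",
"# meta_param receives gradient directly and through the inner update\n",
"print(meta_param.grad)\n",
"```\n",
"\n",
"Because the outer loss depends on `meta_param` both directly and through `new_weight` and `new_bias`, its graph contains the whole inner update as a subgraph. Labeling the parameters of each step separately, as `extract_state_dict` with `visual_prefix` does, keeps such plots readable."
]
},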
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class MetaNet(nn.Module):\n",
"    def __init__(self, dim):\n",
"        super().__init__()\n",
"        self.fc = nn.Linear(dim, 1, bias=True)\n",
"\n",
"    def forward(self, x, meta_param):\n",
"        return self.fc(x) + meta_param\n",
"\n",
"\n",
"dim = 5\n",
"batch_size = 2\n",
"net = MetaNet(dim)\n",
"\n",
"xs = torch.ones((batch_size, dim))\n",
"ys = torch.ones((batch_size, 1))\n",
"\n",
"optimizer = torchopt.MetaSGD(net, lr=1e-3)\n",
"meta_param = torch.tensor(1.0, requires_grad=True)\n",
"\n",
"# Snapshot the initial parameters; enable_visual=True lets make_dot label them with the prefix 'step0.'\n",
"net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n",
"\n",
"pred = net(xs, meta_param)\n",
"loss = F.mse_loss(pred, ys)\n",
"optimizer.step(loss)  # differentiable inner update of the network parameters\n",
"\n",
"# Snapshot the updated parameters with the prefix 'step1.'\n",
"net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n",
"\n",
"pred = net(xs, meta_param)\n",
"loss = F.mse_loss(pred, torch.ones_like(pred))\n",
"\n",
"# Draw the computation graph, labeling both parameter snapshots, meta_param, and the loss\n",
"display(\n",
"    torchopt.visual.make_dot(\n",
"        loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]\n",
"    )\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"vscode": {
"interpreter": {
"hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}