{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# TorchOpt as Meta-Optimizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Open in Colab](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/3_Meta_Optimizer.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we show how to use TorchOpt as a differentiable optimizer through the traditional PyTorch optimization API. We also provide several other APIs that make meta-learning algorithms easier to implement." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Basic API for Differentiable Optimizer\n", "\n", "`MetaOptimizer` is the main class for our differentiable optimizer. Combined with the functional optimizers `torchopt.sgd` and `torchopt.adam` introduced in Tutorial 1, it gives us the high-level APIs `torchopt.MetaSGD` and `torchopt.MetaAdam`. We will discuss how this combination works with `torchopt.chain` in Section 3. Let us consider the problem below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Assume a tensor $x$ is a meta-parameter and $a$ is an ordinary parameter (such as a network parameter). We have the inner loss $\\mathcal{L}^{\\textrm{in}} = a_0 \\cdot x^2$, and we update $a$ using the gradient $\\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = x^2$, which gives $a_1 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x^2$. Then we compute the outer loss $\\mathcal{L}^{\\textrm{out}} = a_1 \\cdot x^2$. 
So the gradient of the outer loss with respect to $x$ is:\n", "\n", "$$\n", "\\begin{split}\n", " \\frac{\\partial \\mathcal{L}^{\\textrm{out}}}{\\partial x}\n", " & = \\frac{\\partial (a_1 \\cdot x^2)}{\\partial x} \\\\\n", " & = \\frac{\\partial a_1}{\\partial x} \\cdot x^2 + a_1 \\cdot \\frac{\\partial (x^2)}{\\partial x} \\\\\n", " & = \\frac{\\partial (a_0 - \\eta \\, x^2)}{\\partial x} \\cdot x^2 + (a_0 - \\eta \\, x^2) \\cdot 2 x \\\\\n", " & = (- \\eta \\cdot 2 x) \\cdot x^2 + (a_0 - \\eta \\, x^2) \\cdot 2 x \\\\\n", " & = - 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x\n", "\\end{split}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given the analytical solution above, let's verify it with TorchOpt. We define the network first." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from IPython.display import display\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "import torchopt\n", "\n", "\n", "class Net(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", "\n", "    def forward(self, x):\n", "        return self.a * (x**2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we declare the network (parameterized by `a`) and the meta-parameter `x`. Do not forget to set the flag `requires_grad=True` for `x`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "net = Net()\n", "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we declare the meta-optimizer. Here we show two equivalent ways of defining it."
 ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Low-level API\n", "optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))\n", "\n", "# High-level API\n", "optim = torchopt.MetaSGD(net, lr=1.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The meta-optimizer takes the network as input and uses the method `step` to update the network (parameterized by `a`). Finally, we show how a bi-level process works." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x.grad = tensor(-28.)\n" ] } ], "source": [ "inner_loss = net(x)\n", "optim.step(inner_loss)\n", "\n", "outer_loss = net(x)\n", "outer_loss.backward()\n", "# x.grad = - 4 * lr * x^3 + 2 * a_0 * x\n", "#        = - 4 * 1 * 2^3 + 2 * 1 * 2\n", "#        = -32 + 4\n", "#        = -28\n", "print(f'x.grad = {x.grad!r}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.1 Track the Gradient of Momentum\n", "\n", "Note that most modern optimizers use moment terms in the gradient update (essentially only SGD with `momentum=0` does not). We provide an option that lets the user choose whether to also track the meta-gradient through the moment terms. The default is `moment_requires_grad=True`."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- When you do not track the meta-gradient through the moment terms (`moment_requires_grad=False`)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/plain": [ "[Figure: computation graph of outer_loss rendered by torchopt.visual.make_dot, with labeled nodes step0.a, step1.a, x, and outer_loss]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net = Net()\n", "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", "y = torch.tensor(1.0)\n", "\n", "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=False)\n", "\n", "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", "inner_loss = F.mse_loss(net(x), y)\n", "optim.step(inner_loss)\n", "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", "\n", "outer_loss = F.mse_loss(net(x), y)\n", "display(\n", "    torchopt.visual.make_dot(\n", "        outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", "    )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- When you track the meta-gradient through the moment terms (`moment_requires_grad=True`, default for `torchopt.MetaAdam`)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/plain": [ "[Figure: computation graph of outer_loss rendered by torchopt.visual.make_dot, with labeled nodes step0.a, step1.a, x, and outer_loss]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net = Net()\n", "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", "y = torch.tensor(1.0)\n", "\n", "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True)\n", "\n", "net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n", "inner_loss = F.mse_loss(net(x), y)\n", "optim.step(inner_loss)\n", "net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n", "\n", "outer_loss = F.mse_loss(net(x), y)\n", "display(\n", "    torchopt.visual.make_dot(\n", "        outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", "    )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that additional moment terms are added to the computation graph when we set `moment_requires_grad=True`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Extract and Recover" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Basic API\n", "\n", "How the inner-loop parameters are reinitialized at the start of a new bi-level process varies across meta-learning algorithms. For instance, in algorithms like Model-Agnostic Meta-Learning (MAML) ([arXiv:1703.03400](https://arxiv.org/abs/1703.03400)), every time a new task arrives, we need to reset the parameters to their initial values. 
In other cases, such as Meta-Gradient Reinforcement Learning (MGRL) ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801)), the inner-loop network parameters simply inherit the previously updated parameters and continue into the new bi-level process.\n", "\n", "We provide the `torchopt.extract_state_dict` and `torchopt.recover_state_dict` functions to extract and restore the state of the network and the optimizer. By default, the extracted state dictionary is a reference (this design supports accumulating gradients in multi-task batch training, as in MAML). You can also set `by='copy'` to extract a copy of the state dictionary, or set `by='deepcopy'` to get a detached copy." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a = tensor(-1.0000, grad_fn=)\n", "a = tensor(-1.0000, grad_fn=)\n" ] } ], "source": [ "net = Net()\n", "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", "\n", "optim = torchopt.MetaAdam(net, lr=1.0)\n", "\n", "# Get a reference to the state dictionary\n", "init_net_state = torchopt.extract_state_dict(net, by='reference')\n", "init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n", "# With `detach_buffers=True`, the parameters are kept as references while buffers are detached copies\n", "init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True)\n", "\n", "# Set `by='copy'` to get a copy of the state dictionary\n", "init_net_state_copy = torchopt.extract_state_dict(net, by='copy')\n", "init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy')\n", "\n", "# Set `by='deepcopy'` to get a detached copy of the state dictionary\n", "init_net_state_deepcopy = torchopt.extract_state_dict(net, by='deepcopy')\n", "init_optim_state_deepcopy = torchopt.extract_state_dict(optim, by='deepcopy')\n", "\n", "# Run 2 inner-loop optimization steps\n", "for i in range(2):\n", "    inner_loss = net(x)\n", "    optim.step(inner_loss)\n", "\n", 
"print(f'a = {net.a!r}')\n", "\n", "# Recover the state and rerun 2 inner-loop optimization steps\n", "torchopt.recover_state_dict(net, init_net_state)\n", "torchopt.recover_state_dict(optim, init_optim_state)\n", "\n", "for i in range(2):\n", "    inner_loss = net(x)\n", "    optim.step(inner_loss)\n", "\n", "print(f'a = {net.a!r}')  # the same result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 Multi-task Example with `extract_state_dict` and `recover_state_dict`\n", "\n", "Let's move on to a more complex setting. Meta-learning algorithms typically adapt the network to several different tasks and accumulate the outer loss of each task into the meta-gradient.\n", "\n", "Assume $x$ is a meta-parameter and $a$ is an ordinary parameter. We first update $a$ to $a_1$ using the inner loss $\\mathcal{L}_1^{\\textrm{in}} = a_0 \\cdot x^2$. Then we use $a_1$ to compute the outer loss $\\mathcal{L}_1^{\\textrm{out}} = a_1 \\cdot x^2$ and backpropagate it. Next, we use $a_0$ to compute the inner loss $\\mathcal{L}_2^{\\textrm{in}} = a_0 \\cdot x$ and update $a_0$ to $a_2 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}_2^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x$. Then we compute the outer loss $\\mathcal{L}_2^{\\textrm{out}} = a_2 \\cdot x$ and backpropagate it. 
So the accumulated meta-gradient is:\n", "\n", "$$\n", "\\begin{split}\n", " \\frac{\\partial \\mathcal{L}_1^{\\textrm{out}}}{\\partial x} + \\frac{\\partial \\mathcal{L}_2^{\\textrm{out}}}{\\partial x}\n", " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + \\frac{\\partial (a_2 \\cdot x)}{\\partial x} \\\\\n", " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + (\\frac{\\partial a_2}{\\partial x} \\cdot x + a_2) \\\\\n", " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + [\\frac{\\partial (a_0 - \\eta \\, x)}{\\partial x} \\cdot x + (a_0 - \\eta \\, x)] \\\\\n", " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + [(- \\eta) \\cdot x + (a_0 - \\eta \\, x)] \\\\\n", " & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + (- 2 \\, \\eta \\, x + a_0)\n", "\\end{split}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's define the network and variables first." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class Net2Tasks(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n", "\n", "    def task1(self, x):\n", "        return self.a * x**2\n", "\n", "    def task2(self, x):\n", "        return self.a * x\n", "\n", "\n", "net = Net2Tasks()\n", "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", "\n", "optim = torchopt.MetaSGD(net, lr=1.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once we call the `step` method of `MetaOptimizer`, the parameters of the network are changed. We should use `torchopt.extract_state_dict` to extract the state and `torchopt.recover_state_dict` to recover it. Note that if we use an optimizer that has momentum buffers, we should also extract and recover them. Vanilla SGD does not have momentum buffers, so the calls `init_optim_state = torchopt.extract_state_dict(optim)` and `torchopt.recover_state_dict(optim, init_optim_state)` have no effect here."
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "init_optim_state = ((EmptyState(),),)\n", "Task 1: x.grad = tensor(-28.)\n", "Accumulated: x.grad = tensor(-31.)\n" ] } ], "source": [ "# Get a reference to the state dictionary\n", "init_net_state = torchopt.extract_state_dict(net, by='reference')\n", "init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n", "# The `state_dict` is empty for the vanilla SGD optimizer\n", "print(f'init_optim_state = {init_optim_state!r}')\n", "\n", "inner_loss_1 = net.task1(x)\n", "optim.step(inner_loss_1)\n", "outer_loss_1 = net.task1(x)\n", "outer_loss_1.backward()\n", "print(f'Task 1: x.grad = {x.grad!r}')\n", "\n", "torchopt.recover_state_dict(net, init_net_state)\n", "torchopt.recover_state_dict(optim, init_optim_state)\n", "inner_loss_2 = net.task2(x)\n", "optim.step(inner_loss_2)\n", "outer_loss_2 = net.task2(x)\n", "outer_loss_2.backward()\n", "\n", "# `extract_state_dict` extracts a reference, so the gradient accumulates\n", "# x.grad = (- 4 * lr * x^3 + 2 * a_0 * x) + (- 2 * lr * x + a_0)\n", "#        = (- 4 * 1 * 2^3 + 2 * 1 * 2) + (- 2 * 1 * 2 + 1)\n", "#        = -28 - 3\n", "#        = -31\n", "print(f'Accumulated: x.grad = {x.grad!r}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Gradient Transformation in `MetaOptimizer`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also use gradient normalization tricks in our `MetaOptimizer`. In fact, `MetaOptimizer` descendants like `MetaSGD` are specializations of `MetaOptimizer`. Specifically, `MetaSGD(net, lr=1.)` is `MetaOptimizer(net, alias.sgd(lr=1., moment_requires_grad=True))`, where the flag `moment_requires_grad=True` means the momentum terms are created with `requires_grad=True`, so they also become part of the computation graph.\n", "\n", "In TorchOpt's design, these functions are derived from `combine.chain`. 
So we can build our own chain, such as `combine.chain(clip.clip_grad_norm(max_norm=1.), sgd(lr=1., moment_requires_grad=True))`, to clip the gradient and then update the parameters using `sgd`. Writing the clipped gradient as $g = \\beta^{-1} \\, x^2$, where $\\beta > 0$ is the clipping scale, the meta-gradient becomes:\n", "\n", "$$\n", "\\begin{aligned}\n", " \\frac{\\partial \\mathcal{L}^{\\textrm{out}}}{\\partial x}\n", " & = \\frac{\\partial (a_1 \\cdot x^2)}{\\partial x} \\\\\n", " & = \\frac{\\partial a_1}{\\partial x} \\cdot x^2 + a_1 \\cdot \\frac{\\partial (x^2)}{\\partial x} \\\\\n", " & = \\frac{\\partial (a_0 - \\eta \\, g)}{\\partial x} \\cdot x^2 + (a_0 - \\eta \\, g) \\cdot 2 x & \\qquad (g \\propto \\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = x^2, \\ {\\lVert g \\rVert}_2 \\le G_{\\max}) \\\\\n", " & = \\frac{\\partial (a_0 - \\eta \\, \\beta^{-1} \\, x^2)}{\\partial x} \\cdot x^2 + (a_0 - \\eta \\, \\beta^{-1} \\, x^2) \\cdot 2 x & \\qquad (g = \\beta^{-1} \\, x^2, \\ \\beta > 0, \\ {\\lVert g \\rVert}_2 \\le G_{\\max}) \\\\\n", " & = (- \\beta^{-1} \\, \\eta \\cdot 2 x) \\cdot x^2 + (a_0 - \\beta^{-1} \\, \\eta \\, x^2) \\cdot 2 x \\\\\n", " & = - 4 \\, \\beta^{-1} \\, \\eta \\, x^3 + 2 \\, a_0 \\, x\n", "\\end{aligned}\n", "$$" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x.grad = tensor(-12.0000)\n" ] } ], "source": [ "net = Net()\n", "x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n", "\n", "optim_impl = torchopt.combine.chain(\n", "    torchopt.clip.clip_grad_norm(max_norm=2.0),\n", "    torchopt.sgd(lr=1.0, moment_requires_grad=True),\n", ")\n", "optim = torchopt.MetaOptimizer(net, optim_impl)\n", "\n", "inner_loss = net(x)\n", "optim.step(inner_loss)\n", "\n", "outer_loss = net(x)\n", "outer_loss.backward()\n", "# Since `max_norm` is 2 and the gradient is x^2, the scale is x^2 / 2 = 2^2 / 2 = 2\n", "# x.grad = - 4 * lr * x^3 / scale + 2 * a_0 * x\n", "#        = - 4 * 1 * 2^3 / 2 + 2 * 1 * 2\n", "#        = -16 + 4\n", "#        = -12\n", "print(f'x.grad = {x.grad!r}')" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "## 4. Learning Rate Scheduler\n", "\n", "TorchOpt also provides learning rate schedulers, which can be used as follows:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "functional_adam = torchopt.adam(\n", "    lr=torchopt.schedule.linear_schedule(\n", "        init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", "    )\n", ")\n", "\n", "adam = torchopt.Adam(\n", "    net.parameters(),\n", "    lr=torchopt.schedule.linear_schedule(\n", "        init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", "    ),\n", ")\n", "\n", "meta_adam = torchopt.MetaAdam(\n", "    net,\n", "    lr=torchopt.schedule.linear_schedule(\n", "        init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", "    ),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Accelerated Optimizer\n", "\n", "Users can enable the accelerated optimizer by setting `use_accelerated_op=True`. Currently, only the Adam optimizer is supported."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check whether the `accelerated_op` is available:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "torchopt.accelerated_op_available(torch.device('cpu'))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "torchopt.accelerated_op_available(torch.device('cuda'))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/plain": [ "[Figure: computation graph of outer_loss rendered by torchopt.visual.make_dot, with labeled nodes step0.a, step1.a, x, and outer_loss, plus the accelerated-op nodes MuOpBackward, NuOpBackward, and UpdatesOpBackward]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net = Net().to(device='cuda')\n", "x = nn.Parameter(torch.tensor(2.0, device=torch.device('cuda')), requires_grad=True)\n", "y = torch.tensor(1.0, device=torch.device('cuda'))\n", "\n", "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, use_accelerated_op=True)\n", "\n", "net_state_0 = torchopt.extract_state_dict(\n", "    net, by='reference', enable_visual=True, visual_prefix='step0.'\n", ")\n", "inner_loss = F.mse_loss(net(x), y)\n", "optim.step(inner_loss)\n", "net_state_1 = torchopt.extract_state_dict(\n", "    net, by='reference', enable_visual=True, visual_prefix='step1.'\n", ")\n", "\n", "outer_loss = F.mse_loss(net(x), y)\n", "display(\n", "    torchopt.visual.make_dot(\n", "        outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", "    )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. 
Known Issues\n", "\n", "Here we record some common issues users face when using the meta-optimizer." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**1. Get `NaN` error when using `MetaAdam` or other meta-optimizers.**\n", "\n", "The `NaN` error comes from the numerical instability of `Adam` in meta-learning. `Adam`'s update involves a `sqrt` operation. Backpropagating through the `Adam` operator introduces the second derivative of the `sqrt` operation, which is not numerically stable, i.e. ${\\left. \\frac{d^2 \\sqrt{x}}{{dx}^2} \\right\\rvert}_{x = 0} = \\texttt{NaN}$. You can also refer to issue [facebookresearch/higher#125](https://github.com/facebookresearch/higher/issues/125).\n", "\n", "For this problem, TorchOpt has two recommended solutions.\n", "\n", "* Fold the `sqrt` operation into the full update equation and compute the derivative of the output with respect to the input manually, which eliminates the second derivative of the `sqrt` operation. You can achieve this by setting the flag `use_accelerated_op=True`; see the instructions in the notebook [Functional Optimizer](1_Functional_Optimizer.ipynb) and in Section 5 of this notebook." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "inner_optim = torchopt.MetaAdam(net, lr=1.0, use_accelerated_op=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Register a hook on the first-order gradients. During backpropagation, the `NaN` gradients are set to 0, which has a similar effect to the first solution but is much slower." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "impl = torchopt.chain(torchopt.hook.register_hook(torchopt.hook.zero_nan_hook), torchopt.adam(1e-1))\n", "inner_optim = torchopt.MetaOptimizer(net, impl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**2. 
Get `Trying to backward through the graph a second time` error when conducting multiple meta-optimization.**\n", "\n", "Please refer to the tutorial notebook [Stop Gradient](4_Stop_Gradient.ipynb) for more guidance." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" }, "vscode": { "interpreter": { "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" } } }, "nbformat": 4, "nbformat_minor": 4 }