{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TorchOpt as Meta-Optimizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Open In Colab](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/3_Meta_Optimizer.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we show how to use TorchOpt as a differentiable optimizer with the traditional PyTorch optimization API. We also provide several other APIs that make meta-learning algorithms easy to implement."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Basic API for Differentiable Optimizer\n",
"\n",
"`MetaOptimizer` is the main class for our differentiable optimizer. Combined with the functional optimizers `torchopt.sgd` and `torchopt.adam` introduced in Tutorial 1, it gives our high-level APIs `torchopt.MetaSGD` and `torchopt.MetaAdam`. We will discuss how this combination works with `torchopt.chain` in Section 3. Let us consider the problem below."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assume a tensor $x$ is a meta-parameter and $a$ is a regular parameter (such as a network parameter). We have the inner loss $\\mathcal{L}^{\\textrm{in}} = a_0 \\cdot x^2$, and we update $a$ using the gradient $\\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = x^2$, so that $a_1 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x^2$. Then we compute the outer loss $\\mathcal{L}^{\\textrm{out}} = a_1 \\cdot x^2$. The gradient of the outer loss with respect to $x$ is therefore:\n",
"\n",
"$$\n",
"\\begin{split}\n",
" \\frac{\\partial \\mathcal{L}^{\\textrm{out}}}{\\partial x}\n",
" & = \\frac{\\partial (a_1 \\cdot x^2)}{\\partial x} \\\\\n",
" & = \\frac{\\partial a_1}{\\partial x} \\cdot x^2 + a_1 \\cdot \\frac{\\partial (x^2)}{\\partial x} \\\\\n",
" & = \\frac{\\partial (a_0 - \\eta \\, x^2)}{\\partial x} \\cdot x^2 + (a_0 - \\eta \\, x^2) \\cdot 2 x \\\\\n",
" & = (- \\eta \\cdot 2 x) \\cdot x^2 + (a_0 - \\eta \\, x^2) \\cdot 2 x \\\\\n",
" & = - 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x\n",
"\\end{split}\n",
"$$"
]
},
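{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before turning to TorchOpt, the derivation can be sanity-checked numerically. The next cell is a minimal pure-Python sketch that does not use TorchOpt: the helper names `outer_loss_fn` and `central_diff` are hypothetical, the single inner SGD step $a_1 = a_0 - \\eta \\, x^2$ is written out by hand as a function of $x$, and the meta-gradient is approximated with a central finite difference at $x = 2$, $a_0 = 1$, $\\eta = 1$. Because $a_1$ stays a function of $x$, the finite difference also captures the $\\frac{\\partial a_1}{\\partial x}$ term, which is exactly the dependency a differentiable optimizer has to keep in the graph."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical finite-difference check of the analytical meta-gradient (no TorchOpt)\n",
"def outer_loss_fn(x, a0, lr):\n",
"    grad_a = x**2  # dL_in/da_0 for L_in = a_0 * x^2\n",
"    a1 = a0 - lr * grad_a  # one inner SGD step, still a function of x\n",
"    return a1 * x**2  # L_out = a_1 * x^2\n",
"\n",
"\n",
"def central_diff(f, x, h=1e-5):\n",
"    # Central finite-difference approximation of df/dx at x\n",
"    return (f(x + h) - f(x - h)) / (2 * h)\n",
"\n",
"\n",
"x0, a0, lr = 2.0, 1.0, 1.0\n",
"numeric = central_diff(lambda x: outer_loss_fn(x, a0, lr), x0)\n",
"analytic = -4 * lr * x0**3 + 2 * a0 * x0\n",
"print(f'numeric = {numeric:.4f}, analytic = {analytic:.4f}')"
]
},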
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the analytical solution above in hand, let's verify it with TorchOpt. First, define the network."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import torchopt\n",
"\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n",
"\n",
" def forward(self, x):\n",
" return self.a * (x**2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we declare the network (parameterized by `a`) and the meta-parameter `x`. Do not forget to set the flag `requires_grad=True` for `x`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"net = Net()\n",
"x = nn.Parameter(torch.tensor(2.0), requires_grad=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we declare the meta-optimizer. Here we show two equivalent ways of defining it."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Low-level API\n",
"optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))\n",
"\n",
"# High-level API\n",
"optim = torchopt.MetaSGD(net, lr=1.0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The meta-optimizer takes the network as input and uses its `step` method to update the network (parameterized by `a`). Finally, we show how a bi-level process works."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x.grad = tensor(-28.)\n"
]
}
],
"source": [
"inner_loss = net(x)\n",
"optim.step(inner_loss)\n",
"\n",
"outer_loss = net(x)\n",
"outer_loss.backward()\n",
"# x.grad = - 4 * lr * x^3 + 2 * a_0 * x\n",
"# = - 4 * 1 * 2^3 + 2 * 1 * 2\n",
"# = -32 + 4\n",
"# = -28\n",
"print(f'x.grad = {x.grad!r}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.1 Track the Gradient of Momentum\n",
"\n",
"Note that most modern optimizers involve a moment term in the update (essentially only SGD with `momentum=0` does not). We provide an option for users to choose whether to also track the meta-gradient through the moment term. The default option is `moment_requires_grad=True`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- When you do not track the meta-gradient through the moment (`moment_requires_grad=False`)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"image/svg+xml": "\n\n\n\n\n"
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"net = Net()\n",
"x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n",
"y = torch.tensor(1.0)\n",
"\n",
"optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=False)\n",
"\n",
"net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n",
"inner_loss = F.mse_loss(net(x), y)\n",
"optim.step(inner_loss)\n",
"net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n",
"\n",
"outer_loss = F.mse_loss(net(x), y)\n",
"display(\n",
" torchopt.visual.make_dot(\n",
" outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- When you track the meta-gradient through the moment (`moment_requires_grad=True`, default for `torchopt.MetaAdam`)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"image/svg+xml": "\n\n\n\n\n"
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"net = Net()\n",
"x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n",
"y = torch.tensor(1.0)\n",
"\n",
"optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True)\n",
"\n",
"net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n",
"inner_loss = F.mse_loss(net(x), y)\n",
"optim.step(inner_loss)\n",
"net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n",
"\n",
"outer_loss = F.mse_loss(net(x), y)\n",
"display(\n",
" torchopt.visual.make_dot(\n",
" outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the additional moment terms are added into the computational graph when we set `moment_requires_grad=True`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Extract and Recover"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.1 Basic API\n",
"\n",
"How the inner-loop parameters are reinitialized for a new bi-level process varies across meta-learning algorithms. For instance, in algorithms like Model-Agnostic Meta-Learning (MAML) ([arXiv:1703.03400](https://arxiv.org/abs/1703.03400)), every time a new task comes, we need to reset the parameters to the initial ones. In other cases, such as Meta-Gradient Reinforcement Learning (MGRL) ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801)), the inner-loop network parameters simply inherit the previously updated parameters to continue the new bi-level process.\n",
"\n",
"We provide the `torchopt.extract_state_dict` and `torchopt.recover_state_dict` functions to extract and restore the state of the network and the optimizer. By default, the extracted state dictionary holds references (this design allows gradients to accumulate in multi-task batch training, e.g., MAML). You can also set `by='copy'` to extract a copy of the state dictionary, or set `by='deepcopy'` to get a detached copy."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a = tensor(-1.0000, grad_fn=)\n",
"a = tensor(-1.0000, grad_fn=)\n"
]
}
],
"source": [
"net = Net()\n",
"x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n",
"\n",
"optim = torchopt.MetaAdam(net, lr=1.0)\n",
"\n",
"# Get the reference of state dictionary\n",
"init_net_state = torchopt.extract_state_dict(net, by='reference')\n",
"init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n",
"# If `detach_buffers=True`, the parameters are extracted as references while buffers are detached copies\n",
"init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True)\n",
"\n",
"# Set `copy` to get the copy of the state dictionary\n",
"init_net_state_copy = torchopt.extract_state_dict(net, by='copy')\n",
"init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy')\n",
"\n",
"# Set `deepcopy` to get the detached copy of state dictionary\n",
"init_net_state_deepcopy = torchopt.extract_state_dict(net, by='deepcopy')\n",
"init_optim_state_deepcopy = torchopt.extract_state_dict(optim, by='deepcopy')\n",
"\n",
"# Run 2 inner-loop optimization steps\n",
"for i in range(2):\n",
" inner_loss = net(x)\n",
" optim.step(inner_loss)\n",
"\n",
"print(f'a = {net.a!r}')\n",
"\n",
"# Recover the initial state and rerun 2 inner-loop optimization steps\n",
"torchopt.recover_state_dict(net, init_net_state)\n",
"torchopt.recover_state_dict(optim, init_optim_state)\n",
"\n",
"for i in range(2):\n",
" inner_loss = net(x)\n",
" optim.step(inner_loss)\n",
"\n",
"print(f'a = {net.a!r}') # the same result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.2 Multi-task Example with `extract_state_dict` and `recover_state_dict`\n",
"\n",
"Let's move to a more complex setting. Meta-learning algorithms often adapt the network to several different tasks and accumulate the outer loss of each task into the meta-gradient.\n",
"\n",
"Assume $x$ is a meta-parameter and $a$ is a regular parameter. We first update $a_0$ to $a_1$ using the inner loss $\\mathcal{L}_1^{\\textrm{in}} = a_0 \\cdot x^2$. Then we use $a_1$ to compute the outer loss $\\mathcal{L}_1^{\\textrm{out}} = a_1 \\cdot x^2$ and backpropagate it. Next, starting again from $a_0$, we compute the inner loss $\\mathcal{L}_2^{\\textrm{in}} = a_0 \\cdot x$ and update $a_0$ to $a_2 = a_0 - \\eta \\, \\frac{\\partial \\mathcal{L}_2^{\\textrm{in}}}{\\partial a_0} = a_0 - \\eta \\, x$. Then we compute the outer loss $\\mathcal{L}_2^{\\textrm{out}} = a_2 \\cdot x$ and backpropagate it. The accumulated meta-gradient is therefore:\n",
"\n",
"$$\n",
"\\begin{split}\n",
" \\frac{\\partial \\mathcal{L}_1^{\\textrm{out}}}{\\partial x} + \\frac{\\partial \\mathcal{L}_2^{\\textrm{out}}}{\\partial x}\n",
" & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + \\frac{\\partial (a_2 \\cdot x)}{\\partial x} \\\\\n",
" & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + (\\frac{\\partial a_2}{\\partial x} \\cdot x + a_2) \\\\\n",
" & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + [\\frac{\\partial (a_0 - \\eta \\, x)}{\\partial x} \\cdot x + (a_0 - \\eta \\, x)] \\\\\n",
" & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + [(- \\eta) \\cdot x + (a_0 - \\eta \\, x)] \\\\\n",
" & = (- 4 \\, \\eta \\, x^3 + 2 \\, a_0 \\, x) + (- 2 \\, \\eta \\, x + a_0)\n",
"\\end{split}\n",
"$$"
]
},
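{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accumulated meta-gradient can be checked the same way. The next cell is a pure-Python sketch with hypothetical helpers (it does not use TorchOpt): each task's inner step is written out by hand, task 2 restarts from $a_0$ as described above, and the finite-difference derivatives of the two outer losses are summed at $x = 2$, $a_0 = 1$, $\\eta = 1$, mimicking how `backward` accumulates into `x.grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical finite-difference check of the accumulated meta-gradient (no TorchOpt)\n",
"def task1_outer_loss(x, a0, lr):\n",
"    a1 = a0 - lr * x**2  # inner step on L1_in = a_0 * x^2 (gradient x^2)\n",
"    return a1 * x**2  # L1_out = a_1 * x^2\n",
"\n",
"\n",
"def task2_outer_loss(x, a0, lr):\n",
"    a2 = a0 - lr * x  # inner step on L2_in = a_0 * x (gradient x), restarting from a_0\n",
"    return a2 * x  # L2_out = a_2 * x\n",
"\n",
"\n",
"def central_diff(f, x, h=1e-5):\n",
"    # Central finite-difference approximation of df/dx at x\n",
"    return (f(x + h) - f(x - h)) / (2 * h)\n",
"\n",
"\n",
"x0, a0, lr = 2.0, 1.0, 1.0\n",
"grad_1 = central_diff(lambda x: task1_outer_loss(x, a0, lr), x0)\n",
"grad_2 = central_diff(lambda x: task2_outer_loss(x, a0, lr), x0)\n",
"total_numeric = grad_1 + grad_2  # gradients from both tasks accumulate\n",
"analytic_total = (-4 * lr * x0**3 + 2 * a0 * x0) + (-2 * lr * x0 + a0)\n",
"print(f'numeric = {total_numeric:.4f}, analytic = {analytic_total:.4f}')"
]
},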
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define the network and variables first."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class Net2Tasks(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n",
"\n",
" def task1(self, x):\n",
" return self.a * x**2\n",
"\n",
" def task2(self, x):\n",
" return self.a * x\n",
"\n",
"\n",
"net = Net2Tasks()\n",
"x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n",
"\n",
"optim = torchopt.MetaSGD(net, lr=1.0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once we call the `step` method of `MetaOptimizer`, the parameters of the network change. We should use `torchopt.extract_state_dict` to extract the state and `torchopt.recover_state_dict` to recover it. Note that if we use optimizers with momentum buffers, we should also extract and recover them. Vanilla SGD has no momentum buffers, so the calls `init_optim_state = torchopt.extract_state_dict(optim)` and `torchopt.recover_state_dict(optim, init_optim_state)` have no effect here."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"init_optim_state = ((EmptyState(),),)\n",
"Task 1: x.grad = tensor(-28.)\n",
"Accumulated: x.grad = tensor(-31.)\n"
]
}
],
"source": [
"# Get the reference of state dictionary\n",
"init_net_state = torchopt.extract_state_dict(net, by='reference')\n",
"init_optim_state = torchopt.extract_state_dict(optim, by='reference')\n",
"# The `state_dict` is empty for vanilla SGD optimizer\n",
"print(f'init_optim_state = {init_optim_state!r}')\n",
"\n",
"inner_loss_1 = net.task1(x)\n",
"optim.step(inner_loss_1)\n",
"outer_loss_1 = net.task1(x)\n",
"outer_loss_1.backward()\n",
"print(f'Task 1: x.grad = {x.grad!r}')\n",
"\n",
"torchopt.recover_state_dict(net, init_net_state)\n",
"torchopt.recover_state_dict(optim, init_optim_state)\n",
"inner_loss_2 = net.task2(x)\n",
"optim.step(inner_loss_2)\n",
"outer_loss_2 = net.task2(x)\n",
"outer_loss_2.backward()\n",
"\n",
"# `extract_state_dict` extracts references, so the gradient accumulates\n",
"# x.grad = (- 4 * lr * x^3 + 2 * a_0 * x) + (- 2 * lr * x + a_0)\n",
"# = (- 4 * 1 * 2^3 + 2 * 1 * 2) + (- 2 * 1 * 2 + 1)\n",
"# = -28 - 3\n",
"# = -31\n",
"print(f'Accumulated: x.grad = {x.grad!r}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Gradient Transformation in `MetaOptimizer`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also use gradient normalization tricks in our `MetaOptimizer`. In fact, `MetaOptimizer` descendants like `MetaSGD` are specializations of `MetaOptimizer`. Specifically, `MetaSGD(net, lr=1.)` is `MetaOptimizer(net, alias.sgd(lr=1., moment_requires_grad=True))`, where the flag `moment_requires_grad=True` means the momentum buffers are created with `requires_grad=True`, so they also become part of the computational graph.\n",
"\n",
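"In the design of TorchOpt, these optimizers are built as chains of gradient transformations with `combine.chain`. So we can build our own chain, like `combine.chain(clip.clip_grad_norm(max_norm=1.), sgd(lr=1., moment_requires_grad=True))`, to clip the gradient and then update the parameters using `sgd`. In the derivation below, clipping rescales the inner gradient by a factor $\\beta^{-1}$, and $\\beta > 0$ is treated as a constant when differentiating with respect to $x$.\n",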
"\n",
"$$\n",
"\\begin{aligned}\n",
" \\frac{\\partial \\mathcal{L}^{\\textrm{out}}}{\\partial x}\n",
" & = \\frac{\\partial (a_1 \\cdot x^2)}{\\partial x} \\\\\n",
" & = \\frac{\\partial a_1}{\\partial x} \\cdot x^2 + a_1 \\cdot \\frac{\\partial (x^2)}{\\partial x} \\\\\n",
" & = \\frac{\\partial (a_0 - \\eta \\, g)}{\\partial x} \\cdot x^2 + (a_0 - \\eta \\, g) \\cdot 2 x & \\qquad (g \\propto \\frac{\\partial \\mathcal{L}^{\\textrm{in}}}{\\partial a_0} = x^2, \\ {\\lVert g \\rVert}_2 \\le G_{\\max}) \\\\\n",
" & = \\frac{\\partial (a_0 - \\eta \\, \\beta^{-1} \\, x^2)}{\\partial x} \\cdot x^2 + (a_0 - \\eta \\, \\beta^{-1} \\, x^2) \\cdot 2 x & \\qquad (g = \\beta^{-1} \\, x^2, \\ \\beta > 0, \\ {\\lVert g \\rVert}_2 \\le G_{\\max}) \\\\\n",
" & = (- \\beta^{-1} \\, \\eta \\cdot 2 x) \\cdot x^2 + (a_0 - \\beta^{-1} \\, \\eta \\, x^2) \\cdot 2 x \\\\\n",
" & = - 4 \\, \\beta^{-1} \\, \\eta \\, x^3 + 2 \\, a_0 \\, x\n",
"\\end{aligned}\n",
"$$"
]
},
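{
"cell_type": "markdown",
"metadata": {},
"source": [
"The clipped case can be checked by hand as well. The next cell is a pure-Python sketch with hypothetical helpers (it does not use TorchOpt). It assumes, as the derivation above does, that the clipping scale $\\beta$ is computed once from the current gradient and then held constant while differentiating with respect to $x$, and that clipping only rescales when the gradient norm exceeds `max_norm`. With `max_norm=2` and inner gradient $x^2 = 4$, this gives $\\beta = 2$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical finite-difference check of the clipped meta-gradient (no TorchOpt)\n",
"def clipped_outer_loss(x, beta, a0, lr):\n",
"    g = x**2 / beta  # clipped inner gradient; beta is held constant\n",
"    a1 = a0 - lr * g  # one inner SGD step on the clipped gradient\n",
"    return a1 * x**2  # L_out = a_1 * x^2\n",
"\n",
"\n",
"def central_diff(f, x, h=1e-5):\n",
"    # Central finite-difference approximation of df/dx at x\n",
"    return (f(x + h) - f(x - h)) / (2 * h)\n",
"\n",
"\n",
"x0, a0, lr, max_norm = 2.0, 1.0, 1.0, 2.0\n",
"grad_norm = abs(x0**2)  # L2 norm of the scalar inner gradient x^2\n",
"beta = max(grad_norm / max_norm, 1.0)  # rescale only if the norm exceeds max_norm\n",
"numeric = central_diff(lambda x: clipped_outer_loss(x, beta, a0, lr), x0)\n",
"analytic = -4 * lr / beta * x0**3 + 2 * a0 * x0\n",
"print(f'beta = {beta}, numeric = {numeric:.4f}, analytic = {analytic:.4f}')"
]
},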
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x.grad = tensor(-12.0000)\n"
]
}
],
"source": [
"net = Net()\n",
"x = nn.Parameter(torch.tensor(2.0), requires_grad=True)\n",
"\n",
"optim_impl = torchopt.combine.chain(\n",
" torchopt.clip.clip_grad_norm(max_norm=2.0),\n",
" torchopt.sgd(lr=1.0, moment_requires_grad=True),\n",
")\n",
"optim = torchopt.MetaOptimizer(net, optim_impl)\n",
"\n",
"inner_loss = net(x)\n",
"optim.step(inner_loss)\n",
"\n",
"outer_loss = net(x)\n",
"outer_loss.backward()\n",
"# Since `max_norm` is 2 and the gradient is x^2 = 4, the scale is x^2 / max_norm = 2^2 / 2 = 2\n",
"# x.grad = - 4 * lr * x^3 / scale + 2 * a_0 * x\n",
"# = - 4 * 1 * 2^3 / 2 + 2 * 1 * 2\n",
"# = -16 + 4\n",
"# = -12\n",
"print(f'x.grad = {x.grad!r}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Learning Rate Scheduler\n",
"\n",
"TorchOpt also provides learning rate schedulers, which can be passed as the `lr` argument of the functional optimizer, the PyTorch-like optimizer, and the meta-optimizer:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"functional_adam = torchopt.adam(\n",
" lr=torchopt.schedule.linear_schedule(\n",
" init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n",
" )\n",
")\n",
"\n",
"adam = torchopt.Adam(\n",
" net.parameters(),\n",
" lr=torchopt.schedule.linear_schedule(\n",
" init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n",
" ),\n",
")\n",
"\n",
"meta_adam = torchopt.MetaAdam(\n",
" net,\n",
" lr=torchopt.schedule.linear_schedule(\n",
" init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n",
" ),\n",
")"
]
},
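{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the schedule arguments concrete, the next cell is a hypothetical pure-Python re-implementation of a linear schedule (`linear_schedule_sketch` is not a TorchOpt function). It assumes Optax-style semantics: the value stays at `init_value` for the first `transition_begin` steps, then moves linearly to `end_value` over `transition_steps` steps and stays there afterwards. Please check the TorchOpt documentation for the exact behavior of `torchopt.schedule.linear_schedule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical re-implementation of a linear schedule, assuming Optax-style semantics\n",
"def linear_schedule_sketch(count, init_value, end_value, transition_steps, transition_begin=0):\n",
"    # Steps elapsed since the transition began, clipped to [0, transition_steps]\n",
"    count = min(max(count - transition_begin, 0), transition_steps)\n",
"    frac = 1.0 - count / transition_steps  # 1 at the start, 0 at the end\n",
"    return (init_value - end_value) * frac + end_value\n",
"\n",
"\n",
"for step in (0, 2000, 7000, 12000, 20000):\n",
"    value = linear_schedule_sketch(\n",
"        step, init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n",
"    )\n",
"    print(f'step {step:>5}: lr = {value:.2e}')"
]
},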
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Accelerated Optimizer\n",
"\n",
"You can use the accelerated optimizer by setting `use_accelerated_op=True`. Currently, only the Adam optimizer is supported."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check whether the accelerated op is available on a given device:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"torchopt.accelerated_op_available(torch.device('cpu'))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"torchopt.accelerated_op_available(torch.device('cuda'))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"net = Net().to(device='cuda')\n",
"x = nn.Parameter(torch.tensor(2.0, device=torch.device('cuda')), requires_grad=True)\n",
"y = torch.tensor(1.0, device=torch.device('cuda'))\n",
"\n",
"optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, use_accelerated_op=True)\n",
"\n",
"net_state_0 = torchopt.extract_state_dict(\n",
" net, by='reference', enable_visual=True, visual_prefix='step0.'\n",
")\n",
"inner_loss = F.mse_loss(net(x), y)\n",
"optim.step(inner_loss)\n",
"net_state_1 = torchopt.extract_state_dict(\n",
" net, by='reference', enable_visual=True, visual_prefix='step1.'\n",
")\n",
"\n",
"outer_loss = F.mse_loss(net(x), y)\n",
"display(\n",
" torchopt.visual.make_dot(\n",
" outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Known Issues\n",
"\n",
"Here we record some common issues faced by users when using the meta-optimizer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**1. Get `NaN` error when using `MetaAdam` or other meta-optimizers.**\n",
"\n",
"The `NaN` error is caused by the numerical instability of `Adam` in meta-learning. `Adam`'s update involves a `sqrt` operation. Backpropagating through the `Adam` operator introduces the second derivative of the `sqrt` operation, which is not numerically stable: ${\\left. \\frac{d^2 \\sqrt{x}}{{dx}^2} \\right\\rvert}_{x = 0}$ is undefined and shows up as $\\texttt{NaN}$ in the computed gradients. You can also refer to issue [facebookresearch/higher#125](https://github.com/facebookresearch/higher/issues/125).\n",
"\n",
"For this problem, TorchOpt recommends two solutions.\n",
"\n",
"* Fuse the `sqrt` operation into the whole update equation and compute the derivative of the output with respect to the input manually, so that the second derivative of the `sqrt` operation is eliminated. You can achieve this by setting the flag `use_accelerated_op=True`; follow the instructions in the notebook [Functional Optimizer](1_Functional_Optimizer.ipynb) and in this Meta-Optimizer notebook."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"inner_optim = torchopt.MetaAdam(net, lr=1.0, use_accelerated_op=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* Register a hook on the first-order gradients. During backpropagation, `NaN` gradients are set to 0, which has a similar effect to the first solution but is much slower."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"impl = torchopt.chain(torchopt.hook.register_hook(torchopt.hook.zero_nan_hook), torchopt.adam(1e-1))\n",
"inner_optim = torchopt.MetaOptimizer(net, impl)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**2. Get `Trying to backward through the graph a second time` error when running multiple meta-optimization steps.**\n",
"\n",
"Please refer to the tutorial notebook [Stop Gradient](4_Stop_Gradient.ipynb) for more guidance."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"vscode": {
"interpreter": {
"hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}