{ "cells": [ { "cell_type": "markdown", "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", "metadata": {}, "source": [ "# TorchOpt for Implicit Differentiation" ] }, { "cell_type": "markdown", "id": "2b547376", "metadata": {}, "source": [ "[Open In Colab](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)" ] }, { "cell_type": "markdown", "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", "metadata": {}, "source": [ "By treating the solution $\\phi^{\\star}$ as an implicit function of $\\theta$, implicit differentiation obtains the analytical best-response derivative $\\partial \\phi^{\\star}(\\theta) / \\partial \\theta$ directly from the implicit function theorem, without backpropagating through the inner-loop iterations. This applies to algorithms whose inner-level solution satisfies an optimality condition. The condition can be a vanishing gradient of the inner objective $G$, ${\\left. \\frac{\\partial G (\\phi, \\theta)}{\\partial \\phi} \\right\\rvert}_{\\phi = \\phi^{\\star}} = 0$, or a more general stationary condition $F (\\phi^{\\star}, \\theta) = 0$. Examples include [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377)." ] }, { "cell_type": "markdown", "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", "metadata": {}, "source": [ "In this tutorial, we show how TorchOpt can be used to perform implicit differentiation." ] }, { "cell_type": "code", "execution_count": 1, "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", "metadata": {}, "outputs": [], "source": [ "import functorch\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "import torchopt" ] }, { "attachments": {}, "cell_type": "markdown", "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", "metadata": {}, "source": [ "## 1. Functional API\n", "\n", "The basic functional API is `torchopt.diff.implicit.custom_root`, a decorator that wraps the forward function (the inner-loop solver) so that implicit gradients can be computed through it. 
Users must implement the stationary condition of the inner-loop problem and pass it to the `custom_root` decorator. The pseudocode is shown below.\n", "\n", "```python\n", "# Functional API for implicit gradient\n", "def stationary(params, meta_params, data):\n", "    # Construct the stationary condition, e.g., the gradient of the inner objective w.r.t. params\n", "    return stationary_condition\n", "\n", "# Decorator that wraps the inner-loop solver\n", "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", "@torchopt.diff.implicit.custom_root(stationary, argnums=1, solve=linear_solver)\n", "def solve(params, meta_params, data):\n", "    # Forward optimization process for params\n", "    return optimal_params\n", "\n", "# Define params, meta_params and get data\n", "params, meta_params, data = ..., ..., ...\n", "optimal_params = solve(params, meta_params, data)\n", "loss = outer_loss(optimal_params)\n", "\n", "# argnums=1 above lets the gradient flow back to meta_params (argument index 1 of solve)\n", "meta_grads = torch.autograd.grad(loss, meta_params)\n", "```" ] }, { "cell_type": "markdown", "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", "metadata": {}, "source": [ "Here we use [iMAML](https://arxiv.org/abs/1909.04630) as a concrete example. In iMAML, the inner-loop objective is given by the following equation.\n", "\n", "$$\n", "{\\mathcal{Alg}}^{\\star} \\left( \\boldsymbol{\\theta}, \\mathcal{D}_{i}^{\\text{tr}} \\right) = \\underset{\\boldsymbol{\\phi}'}{\\operatorname{\\arg \\min}} ~ G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\triangleq \\mathcal{L} \\left( \\boldsymbol{\\phi}', \\mathcal{D}_{i}^{\\text{tr}} \\right) + \\frac{\\lambda}{2} {\\left\\| \\boldsymbol{\\phi}' - \\boldsymbol{\\theta} \\right\\|}^{2}\n", "$$\n", "\n", "Based on this objective, we define the forward function `inner_solver`, which approximately solves the minimization by running gradient descent for enough steps (100 SGD steps in the code below, where we also set $\\lambda = 1$). For this inner-loop problem, the optimality condition is that the gradient with respect to the inner-loop parameters is $0$:\n", "\n", "$$\n", "{\\left. 
\\nabla_{\\boldsymbol{\\phi}'} G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\right\\rvert}_{\\boldsymbol{\\phi}' = \\boldsymbol{\\phi}^{\\star}} = 0\n", "$$\n", "\n", "We therefore define the inner objective `imaml_objective`. Its gradient with respect to the inner-loop parameters, `functorch.grad(imaml_objective, argnums=0)`, serves as the optimality function and evaluates to $0$ at the inner-loop optimum. Finally, we decorate the forward function with `@torchopt.diff.implicit.custom_root`, passing it this optimality function." ] }, { "cell_type": "code", "execution_count": 2, "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", "metadata": {}, "outputs": [], "source": [ "# Inner-loop objective function\n", "# The optimality function: grad(imaml_objective)\n", "def imaml_objective(params, meta_params, data):\n", "    x, y, fmodel = data\n", "    y_pred = fmodel(params, x)\n", "    regularization_loss = 0.0\n", "    for p1, p2 in zip(params, meta_params):\n", "        regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", "    loss = F.mse_loss(y_pred, y) + regularization_loss\n", "    return loss\n", "\n", "\n", "# Optimality condition: the gradient w.r.t. the inner-loop optimal params is 0 (we achieve this\n", "# by specifying argnums=0 in functorch.grad). The argnums=1 passed to custom_root specifies which\n", "# meta-parameter we backpropagate to. Here we backpropagate to the initial parameters\n", "# (meta_params), so we set it to 1. You can also set argnums to a tuple such as (1, 2) if you\n", "# want to backpropagate to multiple meta-parameters.\n", "\n", "\n", "# Here we pass argnums=1 to custom_root. That means we compute the gradient of optimal_params\n", "# w.r.t. the argument at (0-based) index 1 of inner_solver, i.e., meta_params.\n", "# torchopt.linear_solve.solve_normal_cg specifies the conjugate gradient based linear solver\n", "# (CG applied to the normal equations).\n", "@torchopt.diff.implicit.custom_root(\n", "    functorch.grad(imaml_objective, argnums=0),  # optimality function\n", "    argnums=1,\n", "    solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", ")\n", "def inner_solver(params, meta_params, data):\n", "    # Initialize a functional optimizer from TorchOpt\n", "    x, y, fmodel = data\n", "    optimizer = torchopt.sgd(lr=2e-2)\n", "    opt_state = optimizer.init(params)\n", "    with torch.enable_grad():\n", "        # Temporarily enable gradient computation for conducting the optimization\n", "        for _ in range(100):\n", "            pred = fmodel(params, x)\n", "            loss = F.mse_loss(pred, y)  # compute loss\n", "\n", "            # Compute regularization loss\n", "            regularization_loss = 0.0\n", "            for p1, p2 in zip(params, meta_params):\n", "                regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", "            final_loss = loss + regularization_loss\n", "\n", "            grads = torch.autograd.grad(final_loss, params)  # compute gradients\n", "            updates, opt_state = optimizer.update(grads, opt_state, inplace=True)  # get updates\n", "            params = torchopt.apply_updates(params, updates, inplace=True)\n", "\n", "    optimal_params = params\n", "    return optimal_params\n", "\n", "\n", "# torchopt.linear_solve.solve_inv specifies the Neumann series inversion linear solver\n", "@torchopt.diff.implicit.custom_root(\n", "    functorch.grad(imaml_objective, argnums=0),  # optimality function\n", "    argnums=1,\n", "    solve=torchopt.linear_solve.solve_inv(ns=True, maxiter=100, alpha=0.1),\n", ")\n", "def inner_solver_inv_ns(params, meta_params, data):\n", "    # Initialize a functional optimizer from TorchOpt\n", "    x, y, fmodel = data\n", "    optimizer = torchopt.sgd(lr=2e-2)\n", "    opt_state = optimizer.init(params)\n", "    with torch.enable_grad():\n", "        # Temporarily enable gradient computation for conducting the optimization\n", "        for _ in range(100):\n", "            pred = fmodel(params, x)\n", "            loss = F.mse_loss(pred, y)  # compute loss\n", "\n", "            # Compute regularization loss\n", "            regularization_loss = 0.0\n", "            for p1, p2 in zip(params, meta_params):\n", "                regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", "            final_loss = loss + regularization_loss\n", "\n", "            grads = torch.autograd.grad(final_loss, params)  # compute gradients\n", "            updates, opt_state = optimizer.update(grads, opt_state, inplace=True)  # get updates\n", "            params = torchopt.apply_updates(params, updates, inplace=True)\n", "\n", "    optimal_params = params\n", "    return optimal_params" ] }, { "cell_type": "markdown", "id": "32a75c81-d479-4120-a73d-5b2b488358d0", "metadata": {}, "source": [ "Next, we consider a specific case: a one-layer neural network fitted to linear data." ] }, { "cell_type": "code", "execution_count": 3, "id": "fb95538b-1fd9-4ec8-9f57-6360bedc05b7", "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(0)\n", "x = torch.randn(20, 4)\n", "w = torch.randn(4, 1)\n", "b = torch.randn(1)\n", "y = x @ w + b + 0.5 * torch.randn(20, 1)" ] }, { "cell_type": "markdown", "id": "eeb1823a-2231-4471-bb68-cce7724f2578", "metadata": {}, "source": [ "We instantiate a one-layer neural network whose weights and bias are initialized to constants (all ones and all zeros, respectively)." 
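, "\n", "\n", "This network is linear in its parameters, so the inner objective `imaml_objective` (mean squared error plus $\\frac{1}{2} \\| \\phi - \\theta \\|^{2}$, i.e., $\\lambda = 1$) is quadratic in $\\phi = (w, b)$. Its Hessian $A = H + I$, where $H$ is the Hessian of the mean squared error, is therefore constant. The implicit function theorem then gives $\\partial \\phi^{\\star} / \\partial \\theta = A^{-1}$. The snippet below is a hypothetical sanity check in plain PyTorch and is not part of TorchOpt. It computes the hypergradient $A^{-1} a$ of the outer loss $a^{\\top} \\phi^{\\star}$ (the mean prediction used later) and compares it with autograd through an exact solve of the inner problem. In principle, it should also match the meta-gradients computed below, up to the accuracy of the linear solver; that comparison is not performed here.\n", "\n", "```python\n", "import torch\n", "\n", "# Hypothetical sanity check in plain PyTorch (not TorchOpt's implementation).\n", "# Same random data as the tutorial: same seed and same order of torch.randn calls.\n", "torch.manual_seed(0)\n", "x = torch.randn(20, 4)\n", "w = torch.randn(4, 1)\n", "b = torch.randn(1)\n", "y = x @ w + b + 0.5 * torch.randn(20, 1)\n", "\n", "n = x.shape[0]\n", "x_aug = torch.cat([x, torch.ones(n, 1)], dim=1)  # phi = (w, b) flattened to 5 entries\n", "hess_mse = 2.0 / n * x_aug.T @ x_aug  # Hessian of F.mse_loss w.r.t. phi\n", "A = hess_mse + torch.eye(5)  # Hessian of MSE + 0.5 * ||phi - theta||^2\n", "c = 2.0 / n * x_aug.T @ y  # the inner optimum solves A @ phi = theta + c\n", "a = x_aug.mean(dim=0)  # the outer loss mean(x @ w + b) equals a @ phi\n", "\n", "# Implicit function theorem: d(outer)/d(theta) = (d phi*/d theta)^T a = A^{-T} a = A^{-1} a\n", "implicit_grad = torch.linalg.solve(A, a)\n", "\n", "# Reference: differentiate through the exact inner solution phi* = A^{-1} (theta + c)\n", "theta = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0], requires_grad=True)  # ones weight, zero bias\n", "phi_star = torch.linalg.solve(A, theta.unsqueeze(1) + c).squeeze(1)\n", "(autograd_grad,) = torch.autograd.grad(a @ phi_star, theta)\n", "\n", "print(implicit_grad)\n", "print(torch.allclose(implicit_grad, autograd_grad, atol=1e-5))\n", "```"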
] }, { "cell_type": "code", "execution_count": 4, "id": "d50a7bfe-ac69-4089-8cf8-3cbd69d6d4e7", "metadata": { "tags": [] }, "outputs": [], "source": [ "class Net(nn.Module):\n", "    def __init__(self, dim):\n", "        super().__init__()\n", "        self.fc = nn.Linear(dim, 1, bias=True)\n", "        nn.init.ones_(self.fc.weight)\n", "        nn.init.zeros_(self.fc.bias)\n", "\n", "    def forward(self, x):\n", "        return self.fc(x)\n", "\n", "\n", "model = Net(4)\n", "fmodel, meta_params = functorch.make_functional(model)\n", "data = (x, y, fmodel)\n", "\n", "\n", "# Clone function for parameters: returns detached copies that require grad\n", "def clone(params):\n", "    cloned = []\n", "    for item in params:\n", "        if isinstance(item, torch.Tensor):\n", "            cloned.append(item.clone().detach_().requires_grad_(True))\n", "        else:\n", "            cloned.append(item)\n", "    return tuple(cloned)" ] }, { "cell_type": "markdown", "id": "065c36c4-89e2-4a63-8213-63db6ee3b08e", "metadata": {}, "source": [ "We run the forward process by calling the forward function, then pass the optimal params to the outer-loop loss function." ] }, { "cell_type": "code", "execution_count": 5, "id": "115e79c6-911f-4743-a2ed-e50a71c3a813", "metadata": { "tags": [] }, "outputs": [], "source": [ "optimal_params = inner_solver(clone(meta_params), meta_params, data)\n", "\n", "outer_loss = fmodel(optimal_params, x).mean()" ] }, { "cell_type": "markdown", "id": "e2812351-f635-496e-9732-c80831ac04a6", "metadata": {}, "source": [ "Finally, we can get the meta-gradient as shown below." ] }, { "cell_type": "code", "execution_count": 6, "id": "6bdcbe8d-2336-4f80-b124-eb43c5a2fc0a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" ] } ], "source": [ "torch.autograd.grad(outer_loss, meta_params)" ] }, { "cell_type": "markdown", "id": "926ae8bb", "metadata": {}, "source": [ "We can also switch to the Neumann series inversion linear solver, which gives the same meta-gradient here." 
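, "\n", "\n", "The Neumann series solver approximates the inverse of the linear system matrix with a truncated power series. Below is a minimal sketch in plain PyTorch. It assumes a symmetric positive definite matrix $A$ and a step size $\\alpha$ with $0 < \\alpha < 2 / \\lambda_{\\max}(A)$, so that the series $A^{-1} b = \\alpha \\sum_{k \\ge 0} (I - \\alpha A)^{k} b$ converges. The helper `neumann_solve` is hypothetical: it only mirrors the `alpha` and `maxiter` arguments passed to `solve_inv` above and is not TorchOpt's implementation.\n", "\n", "```python\n", "import torch\n", "\n", "\n", "def neumann_solve(A, b, alpha=0.1, maxiter=100):\n", "    # Hypothetical helper: approximate A^{-1} b by the truncated Neumann series\n", "    # alpha * sum_{k=0}^{maxiter} (I - alpha * A)^k b\n", "    term = b.clone()  # (I - alpha * A)^k b for the current k, starting at k = 0\n", "    total = b.clone()  # partial sum of the series\n", "    for _ in range(maxiter):\n", "        term = term - alpha * (A @ term)\n", "        total = total + term\n", "    return alpha * total\n", "\n", "\n", "# Symmetric positive definite test matrix: eigenvalues >= 1 and far below 2 / alpha\n", "torch.manual_seed(0)\n", "M = torch.randn(5, 5)\n", "A = M @ M.T / 25 + torch.eye(5)\n", "b = torch.randn(5)\n", "\n", "approx = neumann_solve(A, b, alpha=0.1, maxiter=500)\n", "exact = torch.linalg.solve(A, b)\n", "print(torch.allclose(approx, exact, atol=1e-4))\n", "```", "\n", "\n", "In this tutorial's linear example, $A = H + I$ with $H$ positive semidefinite, so every eigenvalue of $A$ is at least $1$. With $\\alpha = 0.1$ and $\\alpha \\lambda_{\\max}(A) \\le 1.9$, the truncation error shrinks at least as fast as $0.9^{k}$, so many terms may be needed for high accuracy."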
] }, { "cell_type": "code", "execution_count": 7, "id": "43df0374", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" ] } ], "source": [ "optimal_params = inner_solver_inv_ns(clone(meta_params), meta_params, data)\n", "outer_loss = fmodel(optimal_params, x).mean()\n", "torch.autograd.grad(outer_loss, meta_params)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c92e67ea-b220-4a14-a1ea-4eb3c5f52b6b", "metadata": {}, "source": [ "## 2. OOP API\n", "\n", "The basic OOP class is `ImplicitMetaGradientModule`. We build the network as an `nn.Module` in the classic PyTorch style. To enable implicit gradient computation, users define the stationary condition (or the objective function) and the inner-loop solve function. The pseudocode is shown below.\n", "\n", "```python\n", "from torchopt.nn import ImplicitMetaGradientModule\n", "\n", "# Inherit from the class ImplicitMetaGradientModule\n", "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", "class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):\n", "    def __init__(self, meta_module):\n", "        ...\n", "\n", "    def forward(self, batch):\n", "        # Forward process\n", "        ...\n", "\n", "    def optimality(self, batch, labels):\n", "        # Stationary condition construction for calculating implicit gradient\n", "        # NOTE: If this method is not implemented, it will be automatically derived from the\n", "        # gradient of the `objective` function.\n", "        ...\n", "\n", "    def objective(self, batch, labels):\n", "        # Define the inner-loop optimization objective\n", "        # NOTE: This method is optional if method `optimality` is implemented.\n", "        ...\n", "\n", "    def solve(self, batch, labels):\n", "        # Conduct the inner-loop optimization\n", "        ...\n", "        return self  # optimized module\n", "\n", "# Get the meta-module (holding the meta-parameters) and data\n", "meta_module = ...\n", "batch, labels = ..., ...\n", "inner_net = InnerNet(meta_module)\n", "\n", "# Solve the inner-loop problem for the given meta-parameters\n", "optimal_inner_net = inner_net.solve(batch, labels)\n", "\n", "# Get the outer loss and compute the meta-gradient\n", "loss = outer_loss(optimal_inner_net)\n", "meta_grads = torch.autograd.grad(loss, meta_module.parameters())\n", "```" ] }, { "cell_type": "markdown", "id": "62fbe520-11d0-41ff-9b0a-c6508b1d01cf", "metadata": {}, "source": [ "The class `ImplicitMetaGradientModule` enables gradient flow from `self.parameters()` to `self.meta_parameters()`. In the `__init__` method, users define the inner parameters and the meta-parameters. By default, `ImplicitMetaGradientModule` treats all tensors and modules passed in as inputs as `self.meta_parameters()`, and all tensors and modules defined in `__init__` as `self.parameters()`. Users can also register them explicitly by calling `self.register_parameter()` and `self.register_meta_parameter()`, respectively." ] }, { "cell_type": "code", "execution_count": 8, "id": "c3999684-f4d3-4bc0-86ab-a7e803b2fe80", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" ] } ], "source": [ "class InnerNet(\n", "    torchopt.nn.ImplicitMetaGradientModule,\n", "    linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", "):\n", "    def __init__(self, meta_net, n_inner_iter, reg_param):\n", "        super().__init__()\n", "        # Register the meta-parameters (the module passed as input)\n", "        self.meta_net = meta_net\n", "        # Deep-copy the meta-network to create and register the inner parameters\n", "        self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)\n", "        self.n_inner_iter = n_inner_iter\n", "        self.reg_param = reg_param\n", "\n", "    def forward(self, x):\n", "        return self.net(x)\n", "\n", "    def objective(self, x, y):\n", "        # We do not implement `optimality`, so the optimality condition is automatically derived\n", "        # from the gradient of this `objective` function.\n", "        y_pred = self(x)\n", "        loss = F.mse_loss(y_pred, y)\n", "        regularization_loss = 0\n", "        for p1, p2 in zip(\n", "            self.parameters(),  # parameters of `self.net`\n", "            self.meta_parameters(),  # parameters of `self.meta_net`\n", "        ):\n", "            regularization_loss += (\n", "                0.5 * self.reg_param * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", "            )\n", "        return loss + regularization_loss\n", "\n", "    def solve(self, x, y):\n", "        params = tuple(self.parameters())\n", "        inner_optim = torchopt.SGD(params, lr=2e-2)\n", "        with torch.enable_grad():\n", "            # Temporarily enable gradient computation for conducting the optimization\n", "            for _ in range(self.n_inner_iter):\n", "                loss = self.objective(x, y)\n", "                inner_optim.zero_grad()\n", "                # NOTE: The parameter inputs should be explicitly specified in `backward` function\n", "                # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into\n", "                # all the leaf Tensors (including the meta-parameters) that were used to compute the\n", "                # objective output. Alternatively, please use `torch.autograd.grad` instead.\n", "                loss.backward(inputs=params)  # backward pass in inner-loop\n", "                inner_optim.step()  # update inner parameters\n", "        return self\n", "\n", "\n", "# Initialize the meta-network\n", "meta_net = Net(4)\n", "inner_net = InnerNet(meta_net, 100, reg_param=1)\n", "\n", "# Solve for inner-loop\n", "optimal_inner_net = inner_net.solve(x, y)\n", "outer_loss = optimal_inner_net(x).mean()\n", "\n", "# Derive the meta-gradient\n", "torch.autograd.grad(outer_loss, meta_net.parameters())" ] }, { "cell_type": "markdown", "id": "2b69a5d6-b5e4-4f08-af0a-40afc2382b45", "metadata": {}, "source": [ "We also show how to use the OOP API to compute implicit gradients when the inner-level solution satisfies a general stationary condition $F (\\phi^{\\star}, \\theta) = 0$, as in [DEQ](https://arxiv.org/abs/1909.01377). In the example below, the inner variable $x^{\\star}$ is a fixed point of the meta-network $f_{\\theta}$, so the condition is $F (x^{\\star}, \\theta) = x^{\\star} - f_{\\theta}(x^{\\star}) = 0$. 
" ] }, { "cell_type": "code", "execution_count": 9, "id": "de87c308-d847-4491-9aa1-bc393e6dd1d8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "(\n", "│ tensor([[ 0.0272, 0.0031, -0.0156, -0.0238],\n", "│ │ [ 0.1004, 0.0113, -0.0573, -0.0878],\n", "│ │ [ 0.0666, 0.0075, -0.0380, -0.0583],\n", "│ │ [ 0.1446, 0.0163, -0.0826, -0.1265]]),\n", "│ tensor([0.0574, 0.2114, 0.1403, 0.3046])\n", ")\n" ] } ], "source": [ "class Net(nn.Module):\n", "    def __init__(self, dim):\n", "        super().__init__()\n", "        self.fc = nn.Linear(dim, dim)\n", "\n", "    def forward(self, x):\n", "        return self.fc(x)\n", "\n", "\n", "class InnerNet(\n", "    torchopt.nn.ImplicitMetaGradientModule,\n", "    linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", "):\n", "    def __init__(self, meta_net, x0):\n", "        super().__init__()\n", "        # Register the meta-parameters (the module passed as input)\n", "        self.meta_net = meta_net\n", "        # Declare and register the inner parameter\n", "        self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)\n", "\n", "    def forward(self, x):\n", "        return self.meta_net(x)\n", "\n", "    def optimality(self):\n", "        # Fixed-point condition: x - f(x) = 0\n", "        return (self.x - self(self.x),)\n", "\n", "    def solve(self):\n", "        # Solve the inner-loop problem by fixed-point iteration.\n", "        # This is only an illustrative example; more advanced fixed-point solvers,\n", "        # such as Anderson acceleration, can be used instead.\n", "        for _ in range(10):\n", "            self.x.copy_(self(self.x))\n", "        return self\n", "\n", "\n", "# Initialize meta-network\n", "torch.manual_seed(0)\n", "meta_net = Net(4)\n", "x0 = torch.randn(1, 4)\n", "inner_net = InnerNet(meta_net, x0)\n", "\n", "# Solve for inner-loop\n", "optimal_inner_net = inner_net.solve()\n", "outer_loss = optimal_inner_net.x.mean()\n", "\n", "# Derive the meta-gradient\n", "torch.autograd.grad(outer_loss, meta_net.parameters())" ] } ], "metadata": { "kernelspec": { 
"display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" }, "vscode": { "interpreter": { "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" } } }, "nbformat": 4, "nbformat_minor": 5 }