{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TorchOpt as Functional Optimizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[
](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/1_Functional_Optimizer.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we will introduce how TorchOpt can be treated as functional optimizer to conduct normal optimization with functional programming style. We will also illustrate how to conduct differentiable optimization with functional programming in PyTorch."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Basic API\n",
"\n",
"In this first part, we will illustrate how TorchOpt can be used as a functional optimizer. We compare it with different API in [JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org) to help understand the similarity and dissimilarity. We use simple network, Adam optimizer and MSE loss objective."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from collections import OrderedDict\n",
"\n",
"import functorch\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"import torch\n",
"import torch.autograd\n",
"import torch.nn as nn\n",
"\n",
"import torchopt\n",
"\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self, dim):\n",
" super().__init__()\n",
" self.fc = nn.Linear(dim, 1, bias=True)\n",
" nn.init.ones_(self.fc.weight)\n",
" nn.init.zeros_(self.fc.bias)\n",
"\n",
" def forward(self, x):\n",
" return self.fc(x)\n",
"\n",
"\n",
"def mse(inputs, targets):\n",
" return ((inputs - targets) ** 2).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.1 Original JAX implementation\n",
"\n",
"The first example is JAX implementation coupled with [Optax](https://github.com/deepmind/optax), which belongs to functional programming style."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def origin_jax():\n",
" batch_size = 1\n",
" dim = 1\n",
" params = OrderedDict([('weight', jnp.ones((dim, 1))), ('bias', jnp.zeros((1,)))])\n",
"\n",
" def model(params, x):\n",
" return jnp.matmul(x, params['weight']) + params['bias']\n",
"\n",
" # Obtain the `opt_state` that contains statistics for the optimizer\n",
" learning_rate = 1.0\n",
" optimizer = optax.adam(learning_rate)\n",
" opt_state = optimizer.init(params)\n",
"\n",
" def compute_loss(params, x, y):\n",
" pred = model(params, x)\n",
" return mse(pred, y)\n",
"\n",
" xs = 2 * jnp.ones((batch_size, dim))\n",
" ys = jnp.ones((batch_size, 1))\n",
"\n",
" grads = jax.grad(compute_loss)(params, xs, ys)\n",
" updates, opt_state = optimizer.update(grads, opt_state)\n",
"\n",
" print('Parameters before update:', params)\n",
" params = optax.apply_updates(params, updates)\n",
" print('Parameters after update:', params)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameters before update:\n",
"OrderedDict([\n",
" ('weight', DeviceArray([[1.]], dtype=float32)),\n",
" ('bias', DeviceArray([0.], dtype=float32))\n",
"])\n",
"Parameters after update:\n",
"OrderedDict([\n",
" ('weight', DeviceArray([[6.735325e-06]], dtype=float32)),\n",
" ('bias', DeviceArray([-0.99999326], dtype=float32))\n",
"])\n"
]
}
],
"source": [
"origin_jax()"
]
},
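{
"cell_type": "markdown",
"metadata": {},
"source": [
"The printed values can be checked by hand. The following is a sketch that assumes the standard Adam update with bias correction, where $g$ is the gradient, $\\alpha$ the learning rate, and $\\epsilon$ a small constant; these symbols are introduced here for illustration and do not appear in the code. On the first step the bias-corrected moments reduce to $\\hat{m}_1 = g$ and $\\hat{v}_1 = g^2$, so the update has magnitude close to $\\alpha$ regardless of the size of $g$:\n",
"\n",
"$$\\Delta\\theta = -\\alpha \\, \\frac{\\hat{m}_1}{\\sqrt{\\hat{v}_1} + \\epsilon} = -\\alpha \\, \\frac{g}{|g| + \\epsilon} \\approx -\\alpha \\, \\operatorname{sign}(g)$$\n",
"\n",
"Here $x = 2$, $y = 1$, $w = 1$, and $b = 0$, so the prediction is $wx + b = 2$ and the gradients of the loss $(wx + b - y)^2$ are\n",
"\n",
"$$g_w = 2(wx + b - y)\\,x = 4, \\qquad g_b = 2(wx + b - y) = 2$$\n",
"\n",
"Both gradients are positive, so with $\\alpha = 1$ each parameter moves by about $-1$: the weight goes from $1$ to approximately $0$ and the bias from $0$ to approximately $-1$, which agrees with the output above up to small numerical error."
]
},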
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.2 `functorch` with TorchOpt\n",
"\n",
"The second example is [`functorch`](https://pytorch.org/functorch) coupled with TorchOpt. It basically follows the same structure with the JAX example."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def interact_with_functorch():\n",
" batch_size = 1\n",
" dim = 1\n",
" net = Net(dim)\n",
" model, params = functorch.make_functional(net) # get the functional version of the model\n",
"\n",
" # Obtain the `opt_state` that contains statistics for the optimizer\n",
" learning_rate = 1.0\n",
" optimizer = torchopt.adam(learning_rate)\n",
" opt_state = optimizer.init(params)\n",
"\n",
" xs = 2 * torch.ones((batch_size, dim))\n",
" ys = torch.ones((batch_size, 1))\n",
"\n",
" pred = model(params, xs)\n",
" loss = mse(pred, ys)\n",
"\n",
" grads = torch.autograd.grad(loss, params)\n",
" updates, opt_state = optimizer.update(grads, opt_state)\n",
"\n",
" print('Parameters before update:', params)\n",
" params = torchopt.apply_updates(params, updates)\n",
" print('Parameters after update:', params)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameters before update:\n",
"(\n",
" Parameter containing: tensor([[1.]], requires_grad=True),\n",
" Parameter containing: tensor([0.], requires_grad=True)\n",
")\n",
"Parameters after update:\n",
"(\n",
" Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n",
" Parameter containing: tensor([-1.0000], requires_grad=True)\n",
")\n"
]
}
],
"source": [
"interact_with_functorch()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TorchOpt also offers a wrapper `torchopt.FuncOptimizer` to make it easier to maintain the optimizer states."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def interact_with_functorch_with_wrapper():\n",
" batch_size = 1\n",
" dim = 1\n",
" net = Net(dim)\n",
" model, params = functorch.make_functional(net) # get the functional version of the model\n",
"\n",
" learning_rate = 1.0\n",
" optimizer = torchopt.FuncOptimizer(torchopt.adam(learning_rate))\n",
"\n",
" xs = 2 * torch.ones((batch_size, dim))\n",
" ys = torch.ones((batch_size, 1))\n",
"\n",
" pred = model(params, xs)\n",
" loss = mse(pred, ys)\n",
"\n",
" print('Parameters before update:', params)\n",
" params = optimizer.step(loss, params)\n",
" print('Parameters after update:', params)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameters before update:\n",
"(\n",
" Parameter containing: tensor([[1.]], requires_grad=True),\n",
" Parameter containing: tensor([0.], requires_grad=True)\n",
")\n",
"Parameters after update:\n",
"(\n",
" tensor([[6.6757e-06]], grad_fn=),\n",
" tensor([-1.0000], grad_fn=)\n",
")\n"
]
}
],
"source": [
"interact_with_functorch_with_wrapper()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.3 Full TorchOpt\n",
"\n",
"`torchopt.Optimizer` is the base class for our PyTorch-like optimizer. Combined with the functional optimizer `torchopt.sgd` and `torchopt.adam`, we can define our high-level API `torchopt.SGD` and `torchopt.Adam`. The third example is to illustrate that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference happens between `torchopt.adam()` and `torchopt.Adam()`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def full_torchopt():\n",
" batch_size = 1\n",
" dim = 1\n",
" net = Net(dim)\n",
"\n",
" learning_rate = 1.0\n",
" # High-level API\n",
" optim = torchopt.Adam(net.parameters(), lr=learning_rate)\n",
" # Low-level API\n",
" optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))\n",
"\n",
" xs = 2 * torch.ones((batch_size, dim))\n",
" ys = torch.ones((batch_size, 1))\n",
"\n",
" pred = net(xs)\n",
" loss = mse(pred, ys)\n",
"\n",
" print('Parameters before update:', dict(net.named_parameters()))\n",
" optim.zero_grad()\n",
" loss.backward()\n",
" optim.step()\n",
" print('Parameters after update:', dict(net.named_parameters()))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameters before update:\n",
"{\n",
" 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n",
" 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n",
"}\n",
"Parameters after update:\n",
"{\n",
" 'fc.weight': Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n",
" 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n",
"}\n"
]
}
],
"source": [
"full_torchopt()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.4 Original PyTorch\n",
"\n",
"The final example is to original PyTorch example with `torch.optim`."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def origin_torch():\n",
" batch_size = 1\n",
" dim = 1\n",
" net = Net(dim)\n",
"\n",
" learning_rate = 1.0\n",
" optim = torch.optim.Adam(net.parameters(), lr=learning_rate)\n",
"\n",
" xs = 2 * torch.ones((batch_size, dim))\n",
" ys = torch.ones((batch_size, 1))\n",
"\n",
" pred = net(xs)\n",
" loss = mse(pred, ys)\n",
"\n",
" print('Parameters before update:', dict(net.named_parameters()))\n",
" optim.zero_grad()\n",
" loss.backward()\n",
" optim.step()\n",
" print('Parameters after update:', dict(net.named_parameters()))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameters before update:\n",
"{\n",
" 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n",
" 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n",
"}\n",
"Parameters after update:\n",
"{\n",
" 'fc.weight': Parameter containing: tensor([[1.1921e-07]], requires_grad=True),\n",
" 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n",
"}\n"
]
}
],
"source": [
"origin_torch()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Differentiable Optimization with Functional Optimizer\n",
"\n",
"Coupled with functional optimizer, you can conduct differentiable optimization by setting the `inplace` flag as `False` in update and `apply_updates` function. (which might be helpful for meta-learning algorithm implementation with functional programming style). \n",
"\n",
"Note that `torchopt.SGD` and `torchopt.Adam` do not support differentiable optimization. Refer to the Meta-Optimizer notebook for PyTorch-like differentiable optimizers."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def differentiable():\n",
" batch_size = 1\n",
" dim = 1\n",
" net = Net(dim)\n",
" model, params = functorch.make_functional(net) # get the functional version of the model\n",
"\n",
" # Meta-parameter\n",
" meta_param = nn.Parameter(torch.ones(1))\n",
"\n",
" # SGD example\n",
" learning_rate = 1.0\n",
" optimizer = torchopt.sgd(learning_rate)\n",
" opt_state = optimizer.init(params)\n",
"\n",
" xs = torch.ones((batch_size, dim))\n",
" ys = torch.ones((batch_size, 1))\n",
"\n",
" pred = model(params, xs)\n",
" # Where meta_param is used\n",
" pred = pred + meta_param\n",
" loss = mse(pred, ys)\n",
"\n",
" grads = torch.autograd.grad(loss, params, create_graph=True)\n",
" updates, opt_state = optimizer.update(grads, opt_state, inplace=False)\n",
" # Update parameters with single step SGD update\n",
" params = torchopt.apply_updates(params, updates, inplace=False)\n",
"\n",
" pred = model(params, xs)\n",
" loss = mse(pred, ys)\n",
" loss.backward()\n",
"\n",
" print('Gradient for the meta-parameter:', meta_param.grad)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Gradient for the meta-parameter: tensor([32.])\n"
]
}
],
"source": [
"differentiable()"
]
},
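{
"cell_type": "markdown",
"metadata": {},
"source": [
"The value of the meta-gradient can be verified by hand. The derivation below is a sketch in which $w$, $b$, and $\\phi$ stand for the weight, the bias, and `meta_param`, and $\\alpha$ for the learning rate; these symbols are introduced here for illustration. With $x = y = 1$, $w = 1$, $b = 0$, $\\phi = 1$, and $\\alpha = 1$, the inner residual is $r = wx + b + \\phi - y = 1$, and one SGD step on the inner loss $r^2$ gives\n",
"\n",
"$$w' = w - 2\\alpha r x = -1, \\qquad b' = b - 2\\alpha r = -2$$\n",
"\n",
"The outer loss is $(r')^2$ with $r' = w'x + b' - y = -4$. Since $\\partial r / \\partial \\phi = 1$, the updated parameters depend on $\\phi$ through\n",
"\n",
"$$\\frac{\\partial w'}{\\partial \\phi} = -2\\alpha x = -2, \\qquad \\frac{\\partial b'}{\\partial \\phi} = -2\\alpha = -2$$\n",
"\n",
"and the chain rule yields\n",
"\n",
"$$\\frac{\\partial (r')^2}{\\partial \\phi} = 2r' \\left( x \\, \\frac{\\partial w'}{\\partial \\phi} + \\frac{\\partial b'}{\\partial \\phi} \\right) = 2(-4)(-2 - 2) = 32$$\n",
"\n",
"This matches the printed gradient. In this example, the dependence of $w'$ and $b'$ on $\\phi$ is available to autograd because `create_graph=True` and `inplace=False` keep the update step in the computation graph."
]
},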
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.1 Track the Gradient of Momentum\n",
"\n",
"Note that most modern optimizers involve momentum term in the gradient update (basically only SGD with `momentum = 0` does not involve). We provide an option for user to choose whether to also track the meta-gradient through momentum term. The default option is `moment_requires_grad=True`."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"optim = torchopt.adam(lr=1.0, moment_requires_grad=False)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"optim = torchopt.adam(lr=1.0, moment_requires_grad=True)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"optim = torchopt.sgd(lr=1.0, momentum=0.8, moment_requires_grad=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Accelerated Optimizer\n",
"\n",
"Users can use accelerated optimizer by setting the `use_accelerated_op` as `True`. Currently we only support the Adam optimizer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check whether the `accelerated_op` is available:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"torchopt.accelerated_op_available(torch.device('cpu'))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"torchopt.accelerated_op_available(torch.device('cuda'))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"net = Net(1).cuda()\n",
"optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"optim = torchopt.adam(lr=1.0, use_accelerated_op=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"vscode": {
"interpreter": {
"hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}