{ "cells": [ { "cell_type": "markdown", "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", "metadata": {}, "source": [ "# TorchOpt for Zero-Order Differentiation" ] }, { "cell_type": "markdown", "id": "2b547376", "metadata": {}, "source": [ "[Open in Colab](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/6_Zero_Order_Differentiation.ipynb)" ] }, { "cell_type": "markdown", "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", "metadata": {}, "source": [ "When the inner-loop process is non-differentiable, or when one wants to avoid the heavy computational burden of the previous two modes (caused by the Hessian), one can choose zero-order differentiation (ZD). ZD typically estimates gradients with zero-order methods, such as finite differences or Evolution Strategies (ES).\n", "\n", "TorchOpt offers an API for ES-based differentiation. Instead of optimizing the objective $f (\\boldsymbol{\\theta}): \\mathbb{R}^n \\to \\mathbb{R}$ directly, ES optimizes a Gaussian-smoothed objective defined as $\\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I}_n )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) ]$, where $\\sigma$ denotes the precision. The gradient of this objective is $\\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I}_n )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ]$. Refer to [ES-MAML](https://arxiv.org/pdf/1910.01215.pdf) for more details." ] }, { "cell_type": "markdown", "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", "metadata": {}, "source": [ "In this tutorial, we will introduce how TorchOpt can be used for ES-based differentiation."
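, "\n", "\n", "As background, the gradient formula above can be recovered in a few steps. The following is only a sketch: it assumes $f$ is regular enough to differentiate under the integral sign, and the symbols $\\phi$ (the standard Gaussian density on $\\mathbb{R}^n$) and $\\boldsymbol{u}$ (a substitution variable) are introduced here for illustration only. Substituting $\\boldsymbol{u} = \\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}$ moves the dependence on $\\boldsymbol{\\theta}$ from $f$ into the density:\n", "\n", "$$\n", "\\begin{align*}\n", "\\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) & = \\int f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\, \\phi (\\boldsymbol{z}) \\, d \\boldsymbol{z} = \\sigma^{-n} \\int f (\\boldsymbol{u}) \\, \\phi \\left( \\frac{\\boldsymbol{u} - \\boldsymbol{\\theta}}{\\sigma} \\right) d \\boldsymbol{u} \\\\\n", "\\nabla_{\\boldsymbol{\\theta}} \\, \\phi \\left( \\frac{\\boldsymbol{u} - \\boldsymbol{\\theta}}{\\sigma} \\right) & = \\frac{\\boldsymbol{u} - \\boldsymbol{\\theta}}{\\sigma^2} \\, \\phi \\left( \\frac{\\boldsymbol{u} - \\boldsymbol{\\theta}}{\\sigma} \\right) \\\\\n", "\\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) & = \\sigma^{-n} \\int f (\\boldsymbol{u}) \\, \\frac{\\boldsymbol{u} - \\boldsymbol{\\theta}}{\\sigma^2} \\, \\phi \\left( \\frac{\\boldsymbol{u} - \\boldsymbol{\\theta}}{\\sigma} \\right) d \\boldsymbol{u} = \\frac{1}{\\sigma} \\int f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\, \\boldsymbol{z} \\, \\phi (\\boldsymbol{z}) \\, d \\boldsymbol{z}\n", "\\end{align*}\n", "$$\n", "\n", "The last integral equals $\\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z}} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ]$, so the gradient can be estimated from evaluations of $f$ alone, without differentiating through $f$."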
] }, { "cell_type": "code", "execution_count": 1, "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", "metadata": {}, "outputs": [], "source": [ "import functorch\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "import torchopt" ] }, { "cell_type": "markdown", "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", "metadata": {}, "source": [ "## 1. Functional API\n", "\n", "The basic functional API is `torchopt.diff.zero_order.zero_order`, which is used as a decorator for the forward process whose gradient is estimated with a zero-order method. Users are required to implement the noise sampling function, which is passed to the `zero_order` decorator. The parameters of the decorator have the following meanings.\n", "\n", "- `distribution` is the noise sampling distribution. The distribution $\\lambda$ should be spherically symmetric with a constant variance of $1$ for each element, i.e.:\n", "\n", "  - Spherically symmetric: $\\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ \\boldsymbol{z} ] = \\boldsymbol{0}$.\n", "  - Constant variance of $1$ for each element: $\\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ {\\lvert z_i \\rvert}^2 ] = 1$.\n", "  - For example, the standard multi-dimensional normal distribution $\\mathcal{N} (\\boldsymbol{0}, \\boldsymbol{1})$.\n", "\n", "- `method` selects the estimation algorithm. We support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://arxiv.org/abs/1803.07055)).\n", "\n", "  $$\n", "  \\begin{align*}\n", "  \\text{naive} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ] \\\\\n", "  \\text{forward} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} 
(\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ ( f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f (\\boldsymbol{\\theta}) ) \\cdot \\boldsymbol{z} ] \\\\\n", "  \\text{antithetic} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{2 \\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ (f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f (\\boldsymbol{\\theta} - \\sigma \\, \\boldsymbol{z}) ) \\cdot \\boldsymbol{z} ]\n", "  \\end{align*}\n", "  $$\n", "\n", "- `argnums` specifies which argument(s) of the decorated function the meta-gradient is computed with respect to.\n", "- `num_samples` specifies how many noise samples are drawn.\n", "- `sigma` is the precision, i.e., the scaling factor applied to the sampled noise.\n", "\n", "The following pseudocode shows the usage.\n", "\n", "```python\n", "# Functional API for zero-order differentiation\n", "# 1. Customize the noise distribution via a distribution class\n", "class Distribution:\n", "    def sample(self, sample_shape=torch.Size()):\n", "        # Sampling function for noise\n", "        # NOTE: The distribution should be spherically symmetric with a constant variance of 1.\n", "        ...\n", "        return noise_batch\n", "\n", "distribution = Distribution()\n", "\n", "# 2. Customize the noise distribution via a sampling function\n", "def distribution(sample_shape=torch.Size()):\n", "    # Sampling function for noise\n", "    # NOTE: The distribution should be spherically symmetric with a constant variance of 1.\n", "    ...\n", "    return noise_batch\n", "\n", "# 3. 
Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`\n", "distribution = torch.distributions.Normal(loc=0, scale=1)\n", "\n", "# Decorator that wraps the function\n", "@torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, num_samples=100, sigma=0.01)\n", "def forward(params, data):\n", "    # Forward optimization process for params\n", "    ...\n", "    return objective  # the returned tensor should be a scalar tensor\n", "\n", "# Define params and get data\n", "params, data = ..., ...\n", "\n", "# Forward pass\n", "loss = forward(params, data)\n", "# Backward pass using zero-order differentiation\n", "grads = torch.autograd.grad(loss, params)\n", "```" ] }, { "cell_type": "markdown", "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", "metadata": {}, "source": [ "Here we use a linear layer as an example. Note that it only serves to show that a linear layer can be trained with ES." ] }, { "cell_type": "code", "execution_count": 2, "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "001: tensor(0.0265, grad_fn=)\n", "002: tensor(0.0243, grad_fn=)\n", "003: tensor(0.0222, grad_fn=)\n", "004: tensor(0.0202, grad_fn=)\n", "005: tensor(0.0184, grad_fn=)\n", "006: tensor(0.0170, grad_fn=)\n", "007: tensor(0.0157, grad_fn=)\n", "008: tensor(0.0146, grad_fn=)\n", "009: tensor(0.0137, grad_fn=)\n", "010: tensor(0.0130, grad_fn=)\n", "011: tensor(0.0123, grad_fn=)\n", "012: tensor(0.0118, grad_fn=)\n", "013: tensor(0.0114, grad_fn=)\n", "014: tensor(0.0111, grad_fn=)\n", "015: tensor(0.0111, grad_fn=)\n", "016: tensor(0.0111, grad_fn=)\n", "017: tensor(0.0113, grad_fn=)\n", "018: tensor(0.0115, grad_fn=)\n", "019: tensor(0.0118, grad_fn=)\n", "020: tensor(0.0120, grad_fn=)\n", "021: tensor(0.0121, grad_fn=)\n", "022: tensor(0.0121, grad_fn=)\n", "023: tensor(0.0122, grad_fn=)\n", "024: tensor(0.0122, grad_fn=)\n", 
"025: tensor(0.0122, grad_fn=)\n" ] } ], "source": [ "torch.random.manual_seed(0)\n", "\n", "fmodel, params = functorch.make_functional(nn.Linear(32, 1))\n", "x = torch.randn(64, 32) * 0.1\n", "y = torch.randn(64, 1) * 0.1\n", "distribution = torch.distributions.Normal(loc=0, scale=1)\n", "\n", "\n", "@torchopt.diff.zero_order(\n", " distribution=distribution, method='forward', argnums=0, num_samples=100, sigma=0.01\n", ")\n", "def forward_process(params, fn, x, y):\n", " y_pred = fn(params, x)\n", " loss = F.mse_loss(y_pred, y)\n", " return loss\n", "\n", "\n", "optimizer = torchopt.adam(lr=0.01)\n", "opt_state = optimizer.init(params) # init optimizer\n", "\n", "for i in range(25):\n", " loss = forward_process(params, fmodel, x, y) # compute loss\n", "\n", " grads = torch.autograd.grad(loss, params) # compute gradients\n", " updates, opt_state = optimizer.update(grads, opt_state) # get updates\n", " params = torchopt.apply_updates(params, updates) # update network parameters\n", "\n", " print(f'{i + 1:03d}: {loss!r}')" ] }, { "attachments": {}, "cell_type": "markdown", "id": "db723f6b", "metadata": {}, "source": [ "## 2. OOP API\n", "\n", "The basic OOP API is the class `ZeroOrderGradientModule`. We make the network as an `nn.Module` following a classical PyTorch style. Users need to define the forward process zero-order gradient procedures `forward()` and a noise sampling function `sample()`. 
The parameters of the class have the following meanings.\n", "\n", "- `method` selects the estimation algorithm. We support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://d1wqtxts1xzle7.cloudfront.net/75609515/coredp2011_1web-with-cover-page-v2.pdf?Expires=1670215467&Signature=RfP~mQhhhI7aGknwXbRBgSggFrKuNTPYdyUSdMmfTxOa62QoOJAm-Xhr3F1PLyjUQc2JVxmKIKGGuyYvyfCTpB31dfmMtuVQxZMWVF-SfErTN05SliC93yjA1x1g2kjhn8bkBFdQqGl~1RQSKnhj88BakgSeDNzyCxwbD5VgR89BXRs4YIK5RBIKYtgLhoyz5jar7wHS3TJhRzs3WNeTIAjAmLqJ068oGFZ0Jr7maGquTe3w~8LEEIprJ6cyCMc6b1UUJkmwjNq0RLTVbxgFjfi4Z9kyxyJB9IOS1J25OOON4jfwh5JlXS7MVskuONUyHJim1TQ8OwCraKlBsQLPQw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA)).\n", "- `num_samples` specifies how many noise samples are drawn.\n", "- `sigma` is the precision, i.e., the scaling factor applied to the sampled noise.\n", "\n", "The following pseudocode shows the usage.\n", "\n", "```python\n", "from torchopt.nn import ZeroOrderGradientModule\n", "\n", "# Inherit from the class ZeroOrderGradientModule\n", "# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling\n", "class Net(ZeroOrderGradientModule, method='naive', num_samples=100, sigma=0.01):\n", "    def __init__(self, ...):\n", "        ...\n", "\n", "    def forward(self, batch):\n", "        # Forward process\n", "        ...\n", "        return objective  # the returned tensor should be a scalar tensor\n", "\n", "    def sample(self, sample_shape=torch.Size()):\n", "        # Generate a batch of noise samples\n", "        # NOTE: The distribution should be spherically symmetric with a constant variance of 1.\n", "        ...\n", "        return noise_batch\n", "\n", "# Get model and data\n", "net = Net(...)\n", "data = ...\n", "\n", "# Forward pass (call the module instance, not the class)\n", "loss = net(data)\n", "# Backward pass using zero-order differentiation\n", "grads = torch.autograd.grad(loss, 
net.parameters())\n", "```" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b53524f5", "metadata": {}, "source": [ "Here we reimplement the functional API example above with the OOP API." ] }, { "cell_type": "code", "execution_count": 3, "id": "ecc5730c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "001: tensor(0.0201, grad_fn=)\n", "002: tensor(0.0181, grad_fn=)\n", "003: tensor(0.0167, grad_fn=)\n", "004: tensor(0.0153, grad_fn=)\n", "005: tensor(0.0142, grad_fn=)\n", "006: tensor(0.0133, grad_fn=)\n", "007: tensor(0.0125, grad_fn=)\n", "008: tensor(0.0119, grad_fn=)\n", "009: tensor(0.0116, grad_fn=)\n", "010: tensor(0.0114, grad_fn=)\n", "011: tensor(0.0112, grad_fn=)\n", "012: tensor(0.0112, grad_fn=)\n", "013: tensor(0.0113, grad_fn=)\n", "014: tensor(0.0116, grad_fn=)\n", "015: tensor(0.0118, grad_fn=)\n", "016: tensor(0.0121, grad_fn=)\n", "017: tensor(0.0123, grad_fn=)\n", "018: tensor(0.0125, grad_fn=)\n", "019: tensor(0.0127, grad_fn=)\n", "020: tensor(0.0127, grad_fn=)\n", "021: tensor(0.0125, grad_fn=)\n", "022: tensor(0.0123, grad_fn=)\n", "023: tensor(0.0120, grad_fn=)\n", "024: tensor(0.0118, grad_fn=)\n", "025: tensor(0.0117, grad_fn=)\n" ] } ], "source": [ "torch.random.manual_seed(0)\n", "\n", "\n", "class Net(torchopt.nn.ZeroOrderGradientModule, method='forward', num_samples=100, sigma=0.01):\n", "    def __init__(self, dim):\n", "        super().__init__()\n", "        self.fc = nn.Linear(dim, 1)\n", "        self.distribution = torch.distributions.Normal(loc=0, scale=1)\n", "\n", "    def forward(self, x, y):\n", "        y_pred = self.fc(x)\n", "        loss = F.mse_loss(y_pred, y)\n", "        return loss\n", "\n", "    def sample(self, sample_shape=torch.Size()):\n", "        return self.distribution.sample(sample_shape)\n", "\n", "\n", "x = torch.randn(64, 32) * 0.1\n", "y = torch.randn(64, 1) * 0.1\n", "net = Net(dim=32)\n", "\n", "\n", "optimizer = torchopt.Adam(net.parameters(), lr=0.01)\n", "\n", "for i in range(25):\n", " loss = net(x, y) # 
compute loss\n", "\n", " optimizer.zero_grad()\n", " loss.backward() # backward pass\n", " optimizer.step() # update network parameters\n", "\n", " print(f'{i + 1:03d}: {loss!r}')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.15 ('torchopt')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" }, "vscode": { "interpreter": { "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" } } }, "nbformat": 4, "nbformat_minor": 5 }