{
"cells": [
{
"cell_type": "markdown",
"id": "8850c832-3b54-4971-8ee0-2cd64b585ea8",
"metadata": {},
"source": [
"# TorchOpt for Zero-Order Differentiation"
]
},
{
"cell_type": "markdown",
"id": "2b547376",
"metadata": {},
"source": [
"[
](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/6_Zero_Order_Differentiation.ipynb)"
]
},
{
"cell_type": "markdown",
"id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd",
"metadata": {},
"source": [
"When the inner-loop process is non-differentiable or one wants to eliminate the heavy computation burdens in the previous two modes (brought by Hessian), one can choose ZD. ZD typically gets gradients based on zero-order estimation, such as finite-difference, or Evolutionary Strategy.\n",
"\n",
"TorchOpt offers API for ES-based differentiation. Instead of optimizing the objective $f (\\boldsymbol{\\theta}): \\mathbb{R}^n \\to \\mathbb{R}$, ES optimizes a Gaussion smoothing objective defined as $\\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I}_d )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) ]$, where $\\sigma$ denotes precision. The gradient of such objective is $\\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I}_d )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ]$. Refer to [ES-MAML](https://arxiv.org/pdf/1910.01215.pdf) for more details."
]
},
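{
"cell_type": "markdown",
"id": "es-gradient-derivation",
"metadata": {},
"source": [
"For reference, the gradient formula above can be derived with a change of variables. This derivation is our addition and assumes $f$ is regular enough (e.g., bounded and measurable) to differentiate under the integral sign. Let $p_{\\sigma}$ be the density of $\\sigma \\, \\boldsymbol{z}$ and substitute $\\boldsymbol{u} = \\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}$:\n",
"\n",
"$$\n",
"\\begin{align*}\n",
"\\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) & = \\int_{\\mathbb{R}^n} f (\\boldsymbol{u}) \\, p_{\\sigma} (\\boldsymbol{u} - \\boldsymbol{\\theta}) \\, \\mathrm{d} \\boldsymbol{u}, \\qquad p_{\\sigma} (\\boldsymbol{v}) = (2 \\pi \\sigma^2)^{-n / 2} \\exp \\left( - \\frac{\\lVert \\boldsymbol{v} \\rVert^2}{2 \\sigma^2} \\right) \\\\\n",
"\\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) & = \\int_{\\mathbb{R}^n} f (\\boldsymbol{u}) \\, \\frac{\\boldsymbol{u} - \\boldsymbol{\\theta}}{\\sigma^2} \\, p_{\\sigma} (\\boldsymbol{u} - \\boldsymbol{\\theta}) \\, \\mathrm{d} \\boldsymbol{u} = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\mathcal{N}( 0, {I} )} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ]\n",
"\\end{align*}\n",
"$$\n",
"\n",
"The second line uses $\\nabla_{\\boldsymbol{\\theta}} \\, p_{\\sigma} (\\boldsymbol{u} - \\boldsymbol{\\theta}) = p_{\\sigma} (\\boldsymbol{u} - \\boldsymbol{\\theta}) \\, (\\boldsymbol{u} - \\boldsymbol{\\theta}) / \\sigma^2$ and then substitutes back $\\boldsymbol{u} - \\boldsymbol{\\theta} = \\sigma \\, \\boldsymbol{z}$. No derivative of $f$ appears, which is why $f$ itself may be non-differentiable."
]
},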
{
"cell_type": "markdown",
"id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd",
"metadata": {},
"source": [
"In this tutorial, we will introduce how TorchOpt can be used to ES-based differentiation."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8f13ae67-e328-409f-84a8-1fc425c03a66",
"metadata": {},
"outputs": [],
"source": [
"import functorch\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import torchopt"
]
},
{
"cell_type": "markdown",
"id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21",
"metadata": {},
"source": [
"## 1. Functional API\n",
"\n",
"The basic functional API is `torchopt.diff.zero_order.zero_order`, which is used as the decorator for the forward process zero-order gradient procedures. Users are required to implement the noise sampling function, which will be used as the input of zero_order decorator. Here we show the specific meaning for each parameter used in the decorator.\n",
"\n",
"- `distribution` for noise sampling distribution. The distribution $\\lambda$ should be spherical symmetric and with a constant variance of $1$ for each element. I.e.:\n",
"\n",
" - Spherical symmetric: $\\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ \\boldsymbol{z} ] = \\boldsymbol{0}$.\n",
" - Constant variance of $1$ for each element: $\\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ {\\lvert z_i \\rvert}^2 ] = 1$.\n",
" - For example, the standard multi-dimensional normal distribution $\\mathcal{N} (\\boldsymbol{0}, \\boldsymbol{1})$.\n",
"\n",
"- `method` for different kind of algorithms, we support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://arxiv.org/abs/1803.07055)).\n",
"\n",
" $$\n",
" \\begin{align*}\n",
" \\text{naive} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) \\cdot \\boldsymbol{z} ] \\\\\n",
" \\text{forward} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{\\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ ( f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f (\\boldsymbol{\\theta}) ) \\cdot \\boldsymbol{z} ] \\\\\n",
" \\text{antithetic} \\qquad & \\nabla_{\\boldsymbol{\\theta}} \\tilde{f}_{\\sigma} (\\boldsymbol{\\theta}) = \\frac{1}{2 \\sigma} \\mathbb{E}_{\\boldsymbol{z} \\sim \\lambda} [ (f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) - f (\\boldsymbol{\\theta} + \\sigma \\, \\boldsymbol{z}) ) \\cdot \\boldsymbol{z} ]\n",
" \\end{align*}\n",
" $$\n",
"\n",
"- `argnums` specifies which parameter we want to trace the meta-gradient.\n",
"- `num_samples` specifies how many times we want to conduct the sampling.\n",
"- `sigma` is for precision. This is the scaling factor for the sampling distribution.\n",
"\n",
"We show the pseudo code in the following part.\n",
"\n",
"```python\n",
"# Functional API for zero-order differentiation\n",
"# 1. Customize the noise distribution via a distribution class\n",
"class Distribution:\n",
" def sample(self, sample_shape=torch.Size()):\n",
" # Sampling function for noise\n",
" # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n",
" ...\n",
" return noise_batch\n",
"\n",
"distribution = Distribution()\n",
"\n",
"# 2. Customize the noise distribution via a sampling function\n",
"def distribution(sample_shape=torch.Size()):\n",
" # Sampling function for noise\n",
" # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n",
" ...\n",
" return noise_batch\n",
"\n",
"# 3. Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`\n",
"distribution = torch.distributions.Normal(loc=0, scale=1)\n",
"\n",
"# Decorator that wraps the function\n",
"@torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, num_samples=100, sigma=0.01)\n",
"def forward(params, data):\n",
" # Forward optimization process for params\n",
" ...\n",
" return objective # the returned tensor should be a scalar tensor\n",
"\n",
"# Define params and get data\n",
"params, data = ..., ...\n",
"\n",
"# Forward pass\n",
"loss = forward(params, data)\n",
"# Backward pass using zero-order differentiation\n",
"grads = torch.autograd.grad(loss, params)\n",
"```"
]
},
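{
"cell_type": "markdown",
"id": "es-decorator-sketch",
"metadata": {},
"source": [
"To illustrate how a decorator can make `torch.autograd.grad` return a zero-order estimate, here is a hypothetical, simplified sketch built on `torch.autograd.Function`. It is not TorchOpt's implementation: `es_zero_order` is an illustrative name, it supports a single tensor argument only, and it always uses the antithetic estimator with standard normal noise. The forward pass simply evaluates `fn`. The backward pass never differentiates through `fn`, so `fn` may contain non-differentiable operations such as `torch.round`.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def es_zero_order(num_samples=100, sigma=0.01):\n",
"    # Hypothetical decorator (not TorchOpt's API): fn maps one tensor to a scalar tensor.\n",
"    def decorator(fn):\n",
"        class ESFunction(torch.autograd.Function):\n",
"            @staticmethod\n",
"            def forward(ctx, theta):\n",
"                ctx.save_for_backward(theta)\n",
"                return fn(theta)  # autograd does not record operations inside forward\n",
"\n",
"            @staticmethod\n",
"            def backward(ctx, grad_output):\n",
"                (theta,) = ctx.saved_tensors\n",
"                grad = torch.zeros_like(theta)\n",
"                for _ in range(num_samples):\n",
"                    z = torch.randn_like(theta)  # z ~ N(0, I)\n",
"                    grad += (fn(theta + sigma * z) - fn(theta - sigma * z)) * z\n",
"                # Antithetic estimate, scaled by the incoming gradient (chain rule).\n",
"                return grad_output * grad / (2 * sigma * num_samples)\n",
"\n",
"        return ESFunction.apply\n",
"\n",
"    return decorator\n",
"\n",
"\n",
"@es_zero_order(num_samples=2000, sigma=0.1)\n",
"def objective(theta):\n",
"    # Piecewise constant: backpropagating through torch.round gives zero gradients.\n",
"    return torch.round(theta * 10).sum() / 10\n",
"\n",
"\n",
"torch.manual_seed(0)\n",
"theta = torch.tensor([0.33, -1.27, 2.05], requires_grad=True)\n",
"loss = objective(theta)\n",
"(grad,) = torch.autograd.grad(loss, theta)\n",
"print(loss, grad)  # grad estimates the slope of the smoothed objective\n",
"```\n",
"\n",
"This sketch omits multiple arguments, the `argnums` selection, and the choice of `method` and `distribution` described above."
]
},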
{
"cell_type": "markdown",
"id": "dbef87df-2164-4f1d-8919-37a6fbdc5011",
"metadata": {},
"source": [
"Here we use the example of a linear layer as an example, note that this is just an example to show linear layer can work with ES."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"001: tensor(0.0265, grad_fn=)\n",
"002: tensor(0.0243, grad_fn=)\n",
"003: tensor(0.0222, grad_fn=)\n",
"004: tensor(0.0202, grad_fn=)\n",
"005: tensor(0.0184, grad_fn=)\n",
"006: tensor(0.0170, grad_fn=)\n",
"007: tensor(0.0157, grad_fn=)\n",
"008: tensor(0.0146, grad_fn=)\n",
"009: tensor(0.0137, grad_fn=)\n",
"010: tensor(0.0130, grad_fn=)\n",
"011: tensor(0.0123, grad_fn=)\n",
"012: tensor(0.0118, grad_fn=)\n",
"013: tensor(0.0114, grad_fn=)\n",
"014: tensor(0.0111, grad_fn=)\n",
"015: tensor(0.0111, grad_fn=)\n",
"016: tensor(0.0111, grad_fn=)\n",
"017: tensor(0.0113, grad_fn=)\n",
"018: tensor(0.0115, grad_fn=)\n",
"019: tensor(0.0118, grad_fn=)\n",
"020: tensor(0.0120, grad_fn=)\n",
"021: tensor(0.0121, grad_fn=)\n",
"022: tensor(0.0121, grad_fn=)\n",
"023: tensor(0.0122, grad_fn=)\n",
"024: tensor(0.0122, grad_fn=)\n",
"025: tensor(0.0122, grad_fn=)\n"
]
}
],
"source": [
"torch.random.manual_seed(0)\n",
"\n",
"fmodel, params = functorch.make_functional(nn.Linear(32, 1))\n",
"x = torch.randn(64, 32) * 0.1\n",
"y = torch.randn(64, 1) * 0.1\n",
"distribution = torch.distributions.Normal(loc=0, scale=1)\n",
"\n",
"\n",
"@torchopt.diff.zero_order(\n",
" distribution=distribution, method='forward', argnums=0, num_samples=100, sigma=0.01\n",
")\n",
"def forward_process(params, fn, x, y):\n",
" y_pred = fn(params, x)\n",
" loss = F.mse_loss(y_pred, y)\n",
" return loss\n",
"\n",
"\n",
"optimizer = torchopt.adam(lr=0.01)\n",
"opt_state = optimizer.init(params) # init optimizer\n",
"\n",
"for i in range(25):\n",
" loss = forward_process(params, fmodel, x, y) # compute loss\n",
"\n",
" grads = torch.autograd.grad(loss, params) # compute gradients\n",
" updates, opt_state = optimizer.update(grads, opt_state) # get updates\n",
" params = torchopt.apply_updates(params, updates) # update network parameters\n",
"\n",
" print(f'{i + 1:03d}: {loss!r}')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "db723f6b",
"metadata": {},
"source": [
"## 2. OOP API\n",
"\n",
"The basic OOP API is the class `ZeroOrderGradientModule`. We make the network as an `nn.Module` following a classical PyTorch style. Users need to define the forward process zero-order gradient procedures `forward()` and a noise sampling function `sample()`. Here we show the specific meaning for each parameter used in the class.\n",
"\n",
"- `method` for different kind of algorithms, we support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://d1wqtxts1xzle7.cloudfront.net/75609515/coredp2011_1web-with-cover-page-v2.pdf?Expires=1670215467&Signature=RfP~mQhhhI7aGknwXbRBgSggFrKuNTPYdyUSdMmfTxOa62QoOJAm-Xhr3F1PLyjUQc2JVxmKIKGGuyYvyfCTpB31dfmMtuVQxZMWVF-SfErTN05SliC93yjA1x1g2kjhn8bkBFdQqGl~1RQSKnhj88BakgSeDNzyCxwbD5VgR89BXRs4YIK5RBIKYtgLhoyz5jar7wHS3TJhRzs3WNeTIAjAmLqJ068oGFZ0Jr7maGquTe3w~8LEEIprJ6cyCMc6b1UUJkmwjNq0RLTVbxgFjfi4Z9kyxyJB9IOS1J25OOON4jfwh5JlXS7MVskuONUyHJim1TQ8OwCraKlBsQLPQw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA)).\n",
"- `num_samples` specifies how many times we want to conduct the sampling.\n",
"- `sigma` is for precision. This is the scaling factor for the sampling distribution.\n",
"\n",
"We show the pseudo code in the following part.\n",
"\n",
"```python\n",
"from torchopt.nn import ZeroOrderGradientModule\n",
"\n",
"# Inherited from the class ZeroOrderGradientModule\n",
"# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling\n",
"class Net(ZeroOrderGradientModule, method='naive', num_samples=100, sigma=0.01):\n",
" def __init__(self, ...):\n",
" ...\n",
"\n",
" def forward(self, batch):\n",
" # Forward process\n",
" ...\n",
" return objective # the returned tensor should be a scalar tensor\n",
"\n",
" def sample(self, sample_shape=torch.Size()):\n",
" # Generate a batch of noise samples\n",
" # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.\n",
" ...\n",
" return noise_batch\n",
"\n",
"# Get model and data\n",
"net = Net(...)\n",
"data = ...\n",
"\n",
"# Forward pass\n",
"loss = Net(data)\n",
"# Backward pass using zero-order differentiation\n",
"grads = torch.autograd.grad(loss, net.parameters())\n",
"```"
]
},
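{
"cell_type": "markdown",
"id": "es-module-grad-sketch",
"metadata": {},
"source": [
"For intuition about how the OOP example can call `loss.backward()` followed by a standard optimizer step, the sketch below shows one way to write a zero-order estimate into each parameter's `.grad` by hand. It is a hypothetical helper, not necessarily how `ZeroOrderGradientModule` is implemented. `es_backward` is an illustrative name; it uses the `'forward'` estimator with standard normal noise and perturbs all module parameters through the flat-vector utilities `torch.nn.utils.parameters_to_vector` and `torch.nn.utils.vector_to_parameters`.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.nn.utils import parameters_to_vector, vector_to_parameters\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def es_backward(module, loss_fn, num_samples=100, sigma=0.01):\n",
"    # Hypothetical helper: 'forward' ES estimate of d loss_fn() / d parameters,\n",
"    # stored in each parameter's .grad. loss_fn() must evaluate the module.\n",
"    params = list(module.parameters())\n",
"    theta = parameters_to_vector(params)  # flat copy of all parameters\n",
"    base = loss_fn()  # f(theta)\n",
"    grad = torch.zeros_like(theta)\n",
"    for _ in range(num_samples):\n",
"        z = torch.randn_like(theta)  # z ~ N(0, I)\n",
"        vector_to_parameters(theta + sigma * z, params)  # load theta + sigma z\n",
"        grad += (loss_fn() - base) * z\n",
"    vector_to_parameters(theta, params)  # restore the unperturbed parameters\n",
"    grad /= num_samples * sigma\n",
"    offset = 0\n",
"    for p in params:  # scatter the flat estimate back into per-parameter .grad\n",
"        p.grad = grad[offset : offset + p.numel()].view_as(p).clone()\n",
"        offset += p.numel()\n",
"    return base\n",
"\n",
"\n",
"torch.manual_seed(0)\n",
"model = nn.Linear(32, 1)\n",
"x = torch.randn(64, 32) * 0.1\n",
"y = torch.randn(64, 1) * 0.1\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
"\n",
"for i in range(25):\n",
"    optimizer.zero_grad()\n",
"    loss = es_backward(model, lambda: F.mse_loss(model(x), y))\n",
"    optimizer.step()\n",
"    print(f'{i + 1:03d}: {loss.item():.4f}')\n",
"```\n",
"\n",
"Because the estimate is written directly into `.grad`, any standard optimizer can consume it. Here `torch.optim.Adam` stands in for `torchopt.Adam` so that the sketch depends on PyTorch only."
]
},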
{
"attachments": {},
"cell_type": "markdown",
"id": "b53524f5",
"metadata": {},
"source": [
"Here we reimplement the functional API example above with the OOP API."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ecc5730c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"001: tensor(0.0201, grad_fn=)\n",
"002: tensor(0.0181, grad_fn=)\n",
"003: tensor(0.0167, grad_fn=)\n",
"004: tensor(0.0153, grad_fn=)\n",
"005: tensor(0.0142, grad_fn=)\n",
"006: tensor(0.0133, grad_fn=)\n",
"007: tensor(0.0125, grad_fn=)\n",
"008: tensor(0.0119, grad_fn=)\n",
"009: tensor(0.0116, grad_fn=)\n",
"010: tensor(0.0114, grad_fn=)\n",
"011: tensor(0.0112, grad_fn=)\n",
"012: tensor(0.0112, grad_fn=)\n",
"013: tensor(0.0113, grad_fn=)\n",
"014: tensor(0.0116, grad_fn=)\n",
"015: tensor(0.0118, grad_fn=)\n",
"016: tensor(0.0121, grad_fn=)\n",
"017: tensor(0.0123, grad_fn=)\n",
"018: tensor(0.0125, grad_fn=)\n",
"019: tensor(0.0127, grad_fn=)\n",
"020: tensor(0.0127, grad_fn=)\n",
"021: tensor(0.0125, grad_fn=)\n",
"022: tensor(0.0123, grad_fn=)\n",
"023: tensor(0.0120, grad_fn=)\n",
"024: tensor(0.0118, grad_fn=)\n",
"025: tensor(0.0117, grad_fn=)\n"
]
}
],
"source": [
"torch.random.manual_seed(0)\n",
"\n",
"\n",
"class Net(torchopt.nn.ZeroOrderGradientModule, method='forward', num_samples=100, sigma=0.01):\n",
" def __init__(self, dim):\n",
" super().__init__()\n",
" self.fc = nn.Linear(dim, 1)\n",
" self.distribution = torch.distributions.Normal(loc=0, scale=1)\n",
"\n",
" def forward(self, x, y):\n",
" y_pred = self.fc(x)\n",
" loss = F.mse_loss(y_pred, y)\n",
" return loss\n",
"\n",
" def sample(self, sample_shape=torch.Size()):\n",
" return self.distribution.sample(sample_shape)\n",
"\n",
"\n",
"x = torch.randn(64, 32) * 0.1\n",
"y = torch.randn(64, 1) * 0.1\n",
"net = Net(dim=32)\n",
"\n",
"\n",
"optimizer = torchopt.Adam(net.parameters(), lr=0.01)\n",
"\n",
"for i in range(25):\n",
" loss = net(x, y) # compute loss\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward() # backward pass\n",
" optimizer.step() # update network parameters\n",
"\n",
" print(f'{i + 1:03d}: {loss!r}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.15 ('torchopt')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"vscode": {
"interpreter": {
"hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}