{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# `torchopt.stop_gradient` in Meta-Learning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Open In Colab](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/4_Stop_Gradient.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we illustrate the usage of `torchopt.stop_gradient` with a meta-learning example. We use `torchopt.visual` to visualize what happens during automatic differentiation. First, we define a simple network and the objective function for inner- and outer-loop optimization."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import torchopt\n",
"\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self, dim):\n",
" super().__init__()\n",
" self.fc = nn.Linear(dim, 1, bias=True)\n",
"\n",
" def forward(self, x):\n",
" return self.fc(x)\n",
"\n",
"\n",
"loss_fn = F.mse_loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define the input `x` and the target `y`. `y` serves as the regression target in the following code."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 64\n",
"dim = 16\n",
"\n",
"x = torch.randn((batch_size, dim))\n",
"y = torch.zeros((batch_size, 1))\n",
"net = Net(dim)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we define the meta-parameter. We use `MetaSGD` as the inner-loop optimizer and `Adam` as the outer-loop optimizer."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"meta_parameter = nn.Parameter(torch.tensor(1.0), requires_grad=True)\n",
"\n",
"optim = torchopt.MetaSGD(net, lr=1e-1)\n",
"meta_optim = torch.optim.Adam([meta_parameter], lr=1e-1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the inner-loop loss and visualize its forward computation graph."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"inner loss: 0.3472\n",
"\n"
]
}
],
"source": [
"init_net_state = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')\n",
"\n",
"# inner loss\n",
"inner_loss = loss_fn(net(x), y)\n",
"\n",
"print(f'inner loss: {inner_loss:.4f}')\n",
"display(torchopt.visual.make_dot(inner_loss, params=(init_net_state, {'inner_loss': inner_loss})))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conduct one inner-loop optimization step with `MetaSGD`. Here, the meta-parameter serves as a factor that scales the inner-loop loss."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# inner-step optimization\n",
"loss = inner_loss * meta_parameter\n",
"optim.step(loss)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We compute the outer loss and draw the full computation graph of the first bi-level process. The graph consists of three main parts:\n",
"\n",
"- Inner loop: forward process and inner-loss calculation\n",
"- Inner-loop optimization: the `MetaSGD` optimization step given the inner loss\n",
"- Outer loop: forward process and outer-loss calculation"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"outer loss: 0.2039\n",
"\n"
]
}
],
"source": [
"# Extract `state_dict` for the updated network\n",
"one_step_net_state = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')\n",
"one_step_optim_state = torchopt.extract_state_dict(optim)\n",
"\n",
"# Calculate outer loss\n",
"outer_loss = loss_fn(net(x), y)\n",
"print(f'outer loss: {outer_loss:.4f}')\n",
"display(\n",
" torchopt.visual.make_dot(\n",
" outer_loss,\n",
" params=(\n",
" init_net_state,\n",
" one_step_net_state,\n",
" {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n",
" ),\n",
" )\n",
")"
]
},
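{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three parts above can also be written out by hand. The next cell is a minimal plain-PyTorch sketch of one bi-level step, not the implementation of `MetaSGD`: it assumes a single linear layer stored as the functional tensors `weight` and `bias`, keeps one SGD step differentiable with `torch.autograd.grad(..., create_graph=True)`, and backpropagates the outer loss through that step to a scalar meta-parameter. The helper `manual_bilevel_step` and the `demo_*` variables are hypothetical names, chosen so they do not overwrite the tutorial's own variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def manual_bilevel_step(lr=1e-1, batch_size=64, dim=16, seed=0):\n",
"    # Hypothetical helper: a hand-written analogue of the three parts drawn above\n",
"    gen = torch.Generator().manual_seed(seed)  # local RNG, leaves the global seed untouched\n",
"    inputs = torch.randn(batch_size, dim, generator=gen)\n",
"    targets = torch.zeros(batch_size, 1)\n",
"\n",
"    # Functional parameters of a single linear layer, plus a scalar meta-parameter\n",
"    weight = (0.1 * torch.randn(1, dim, generator=gen)).requires_grad_()\n",
"    bias = torch.zeros(1, requires_grad=True)\n",
"    meta = torch.tensor(1.0, requires_grad=True)\n",
"\n",
"    # Part 1: inner-loop forward process; the inner loss is scaled by the meta-parameter\n",
"    inner_loss = F.mse_loss(F.linear(inputs, weight, bias), targets) * meta\n",
"\n",
"    # Part 2: one SGD step; create_graph=True records the update itself in the graph\n",
"    grad_w, grad_b = torch.autograd.grad(inner_loss, (weight, bias), create_graph=True)\n",
"    new_weight = weight - lr * grad_w\n",
"    new_bias = bias - lr * grad_b\n",
"\n",
"    # Part 3: outer-loop forward process on the updated parameters\n",
"    outer_loss = F.mse_loss(F.linear(inputs, new_weight, new_bias), targets)\n",
"    outer_loss.backward()  # the gradient flows back through the SGD step to `meta`\n",
"    return meta.grad, new_weight\n",
"\n",
"\n",
"demo_meta_grad, demo_new_weight = manual_bilevel_step()\n",
"print(f'meta-gradient (manual sketch): {demo_meta_grad.item():.4f}')\n",
"print(f'updated weight is a non-leaf tensor: {demo_new_weight.grad_fn is not None}')"
]
},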
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we backpropagate the outer loss to conduct the outer-loop meta-optimization step."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"meta_parameter.grad = tensor(-0.1205)\n",
"meta_parameter = Parameter containing:\n",
"tensor(1.1000, requires_grad=True)\n"
]
}
],
"source": [
"meta_optim.zero_grad()\n",
"outer_loss.backward()\n",
"print(f'meta_parameter.grad = {meta_parameter.grad!r}')\n",
"meta_optim.step()\n",
"print(f'meta_parameter = {meta_parameter!r}')"
]
},
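{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why does `meta_parameter` move by exactly `lr`? On its first step, Adam's bias-corrected moment estimates equal `g` and `g**2`, so the step subtracts `lr * g / (abs(g) + eps)`, which is about `lr * sign(g)` regardless of the gradient's magnitude. With `g = -0.1205` and `lr = 0.1`, the parameter moves from `1.0` to about `1.1`. The sketch below replays `torch.optim.Adam` with its default hyperparameters on a fresh scalar, injecting the gradient printed above and the gradient `-0.0635` printed at the end of this tutorial; the second value it prints should land close to the `1.1940` reported there. The helper `adam_replay` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"def adam_replay(grads=(-0.1205, -0.0635), lr=1e-1):\n",
"    # Hypothetical helper: replay torch.optim.Adam on a scalar with recorded gradients\n",
"    param = torch.nn.Parameter(torch.tensor(1.0))\n",
"    opt = torch.optim.Adam([param], lr=lr)  # default betas=(0.9, 0.999), eps=1e-8\n",
"    history = []\n",
"    for g in grads:\n",
"        param.grad = torch.tensor(g)  # inject the gradient instead of calling backward()\n",
"        opt.step()\n",
"        history.append(param.item())\n",
"    return history\n",
"\n",
"\n",
"demo_adam_history = adam_replay()\n",
"print([round(v, 4) for v in demo_adam_history])"
]
},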
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have now completed one bi-level optimization step and updated the meta-parameter. Before conducting the second bi-level optimization, consider carefully whether you need the `stop_gradient` function. For example, if the new inner-loop parameters directly inherit the previous inner-loop parameters (a common strategy in many meta-learning algorithms, such as Meta-Gradient Reinforcement Learning (MGRL) ([arXiv:1805.09801](https://arxiv.org/abs/1805.09801))), you will likely need `stop_gradient`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In general, backpropagation only frees the saved tensors (auxiliary data kept for computing gradients); the computation graph itself remains. Once the outer iteration is finished, if you want to use any intermediate network parameters produced by the inner loop in the next bi-level iteration, you should detach them from the computation graph.\n",
"\n",
"There are two main reasons:\n",
"\n",
"- The network parameters are still connected to the previous computation graph (`.grad_fn` is not `None`). If gradients later backpropagate to these parameters, the PyTorch autograd engine will try to backpropagate through the previous computation graph, whose saved tensors have already been freed. This raises a `RuntimeError`: Trying to backward through the graph a second time...\n",
"- If we do not detach them, the computation graph connected to these parameters cannot be freed by the garbage collector (GC) until the parameters themselves are collected."
]
},
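{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first point can be reproduced without torchopt. In the following sketch, the hypothetical helper `stale_graph_demo` uses a scalar `w` as the original parameter and `w_new = w - 0.1 * w ** 2` as a stand-in for an inner-loop update. A second backward pass through `w_new` raises the same `RuntimeError`, while `w_new.detach().requires_grad_()` yields a new leaf tensor that can be differentiated again."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"def stale_graph_demo():\n",
"    # Hypothetical helper reproducing the first point above with plain PyTorch\n",
"    w = torch.tensor(1.0, requires_grad=True)\n",
"    w_new = w - 0.1 * w ** 2  # stand-in inner update: non-leaf, attached to the graph of w\n",
"\n",
"    (w_new ** 2).backward()  # first backward frees the saved tensors of that graph\n",
"    try:\n",
"        (w_new ** 2).backward()  # the new graph reaches the freed part of the old one\n",
"        raised = False\n",
"    except RuntimeError:\n",
"        raised = True\n",
"\n",
"    # Fix: detach w_new from the old graph and make it a fresh leaf tensor\n",
"    w_leaf = w_new.detach().requires_grad_()\n",
"    (w_leaf ** 2).backward()  # only touches the new graph, so this succeeds\n",
"    return raised, w_new.grad_fn is not None, w_leaf.is_leaf, w_leaf.grad\n",
"\n",
"\n",
"demo_raised, demo_attached, demo_is_leaf, demo_leaf_grad = stale_graph_demo()\n",
"print(f'RuntimeError raised: {demo_raised}, still attached: {demo_attached}, '\n",
"      f'leaf after detach: {demo_is_leaf}, grad: {demo_leaf_grad.item():.4f}')"
]
},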
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let us see what happens if we do not call `stop_gradient` before conducting the second bi-level process."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"\u001b[31m╭─\u001b[0m\u001b[31m────────────────────────────────────── \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m ──────────────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2;33m/tmp/ipykernel_3962266/\u001b[0m\u001b[1;33m4178930003.py\u001b[0m:\u001b[94m21\u001b[0m in \u001b[92m<module>\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[3;31m[Errno 2] No such file or directory: '/tmp/ipykernel_3962266/4178930003.py'\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/\u001b[0m\u001b[1;33m_tensor.py\u001b[0m:\u001b[94m487\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 484 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mcreate_graph=create_graph, \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 485 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minputs=inputs, \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 486 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 487 \u001b[2m│ │ \u001b[0mtorch.autograd.backward( \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 488 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m, gradient, retain_graph, create_graph, inputs=inputs \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 489 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 490 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m╭─\u001b[0m\u001b[33m──────────────────────── locals ─────────────────────────\u001b[0m\u001b[33m─╮\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m create_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m gradient = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m inputs = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m retain_graph = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m self = \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m0.1203\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mMseLossBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m╰───────────────────────────────────────────────────────────╯\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2;33m/home/PanXuehai/Miniconda3/envs/torchopt/lib/python3.9/site-packages/torch/autograd/\u001b[0m\u001b[1;33m__init__.py\u001b[0m:\u001b[94m197\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m194 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# The reason we repeat same the comment below is that\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m195 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# some Python versions print out the first line of a multi-line function\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m196 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# calls in the traceback and some print out the last line\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m197 \u001b[2m│ \u001b[0mVariable._execution_engine.run_backward( \u001b[2m# Calls into the C++ engine to run the ba\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m198 \u001b[0m\u001b[2m│ │ \u001b[0mtensors, grad_tensors_, retain_graph, create_graph, inputs, \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m199 \u001b[0m\u001b[2m│ │ \u001b[0mallow_unreachable=\u001b[94mTrue\u001b[0m, accumulate_grad=\u001b[94mTrue\u001b[0m) \u001b[2m# Calls into the C++ engine to r\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m200 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m╭─\u001b[0m\u001b[33m─────────────────────────── locals ───────────────────────────\u001b[0m\u001b[33m─╮\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m create_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_tensors = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_tensors_ = \u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m1\u001b[0m.\u001b[1m)\u001b[0m,\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m grad_variables = \u001b[94mNone\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m inputs = \u001b[1m(\u001b[0m\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m retain_graph = \u001b[94mFalse\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m│\u001b[0m tensors = \u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[94m0.1203\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mMseLossBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m,\u001b[1m)\u001b[0m \u001b[33m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[33m╰────────────────────────────────────────────────────────────────╯\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
"\u001b[1;91mRuntimeError: \u001b[0mTrying to backward through the graph a second time \u001b[1m(\u001b[0mor directly access saved tensors after they have \n",
"already been freed\u001b[1m)\u001b[0m. Saved intermediate values of the graph are freed when you call \u001b[1;35m.backward\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m or \u001b[1;35mautograd.grad\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m.\n",
"Specify \u001b[33mretain_graph\u001b[0m=\u001b[3;92mTrue\u001b[0m if you need to backward through the graph a second time or if you need to access saved \n",
"tensors after calling backward.\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Inner update with attached computation graph\n",
"inner_loss = loss_fn(net(x), y)\n",
"loss = inner_loss * meta_parameter\n",
"optim.step(loss)\n",
"\n",
"# Outer forward process\n",
"outer_loss = loss_fn(net(x), y)\n",
"display(\n",
" torchopt.visual.make_dot(\n",
" outer_loss,\n",
" params=(\n",
" init_net_state,\n",
" one_step_net_state,\n",
" {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n",
" ),\n",
" )\n",
")\n",
"\n",
"# Outer update\n",
"meta_optim.zero_grad()\n",
"outer_loss.backward()\n",
"meta_optim.step()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the graph shows, directly conducting the second bi-level process links the graphs of the first and second bi-level processes together. We should manually stop the gradient with `torchopt.stop_gradient`, which detaches the nodes from the gradient graph and makes them leaf nodes. It accepts a network, an optimizer, or a state dictionary as input, and the operation is performed in place.\n",
"\n",
"Let's use `recover_state_dict` to restore the one-step updated states."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Reset to the previous one-step updated states\n",
"torchopt.recover_state_dict(net, one_step_net_state)\n",
"torchopt.recover_state_dict(optim, one_step_optim_state)"
]
},
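{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before calling `torchopt.stop_gradient`, it may help to see what detaching in place means. The sketch below uses a hypothetical helper `stop_gradient_inplace` that applies PyTorch's `Tensor.detach_()` and `Tensor.requires_grad_()` to every tensor in a dictionary; it is an assumption for illustration, not torchopt's implementation. Because the tensors are modified in place, any other reference to them (here `alias`) also becomes a leaf, and a new backward pass no longer reaches the original `w`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"def stop_gradient_inplace(state):\n",
"    # Hypothetical helper, not torchopt's implementation: detach every tensor in place\n",
"    for tensor in state.values():\n",
"        tensor.detach_()  # drop grad_fn in place; the same object becomes a leaf\n",
"        tensor.requires_grad_()  # let it track gradients in the next bi-level iteration\n",
"\n",
"\n",
"def inplace_demo():\n",
"    w = torch.tensor(1.0, requires_grad=True)\n",
"    state = {'weight': w - 0.1 * w ** 2}  # non-leaf, attached to the graph of w\n",
"    alias = state['weight']  # a second reference to the same tensor object\n",
"    was_attached = alias.grad_fn is not None\n",
"    stop_gradient_inplace(state)\n",
"    (state['weight'] ** 2).backward()\n",
"    # Because the detach happened in place, the alias sees the change too\n",
"    return was_attached, alias.is_leaf, alias.grad, w.grad\n",
"\n",
"\n",
"demo_was_attached, demo_alias_is_leaf, demo_alias_grad, demo_w_grad = inplace_demo()\n",
"print(demo_was_attached, demo_alias_is_leaf, demo_alias_grad, demo_w_grad)"
]
},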
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let's apply the stop-gradient operation before the second meta-optimization step."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"meta_parameter.grad = tensor(-0.0635)\n",
"meta_parameter = Parameter containing:\n",
"tensor(1.1940, requires_grad=True)\n",
"\n"
]
}
],
"source": [
"# Stop gradient so that the network parameters (and any optimizer states) become leaf nodes\n",
"torchopt.stop_gradient(net)\n",
"torchopt.stop_gradient(optim)\n",
"one_step_net_state_detached = torchopt.extract_state_dict(\n",
" net, enable_visual=True, visual_prefix='step1.detached.'\n",
")\n",
"\n",
"# Inner update\n",
"inner_loss = loss_fn(net(x), y)\n",
"loss = inner_loss * meta_parameter\n",
"optim.step(loss)\n",
"\n",
"# Outer update\n",
"outer_loss = loss_fn(net(x), y)\n",
"meta_optim.zero_grad()\n",
"outer_loss.backward()\n",
"print(f'meta_parameter.grad = {meta_parameter.grad!r}')\n",
"meta_optim.step()\n",
"print(f'meta_parameter = {meta_parameter!r}')\n",
"\n",
"display(\n",
" torchopt.visual.make_dot(\n",
" outer_loss,\n",
" params=(\n",
" one_step_net_state_detached,\n",
" {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n",
" ),\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gradient graph is now the same as the one from the first meta-optimization step, and the second bi-level process completes successfully."
]
}
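,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same pattern extends to many outer iterations in which each inner loop inherits the previous inner-loop parameters. The sketch below is a plain-PyTorch assumption, not torchopt code: the hypothetical `run_bilevel` replaces `MetaSGD` with a hand-written differentiable SGD step and replaces `torchopt.stop_gradient` with `detach().requires_grad_()`. With `detach=True` every outer step succeeds; with `detach=False` the second outer backward pass reaches the graph freed by the first one and raises a `RuntimeError`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def run_bilevel(num_outer_steps=3, detach=True, lr=1e-1, seed=0):\n",
"    # Hypothetical loop in which each inner loop inherits the previous parameters\n",
"    gen = torch.Generator().manual_seed(seed)\n",
"    inputs = torch.randn(64, 16, generator=gen)\n",
"    targets = torch.zeros(64, 1)\n",
"    weight = (0.1 * torch.randn(1, 16, generator=gen)).requires_grad_()\n",
"    bias = torch.zeros(1, requires_grad=True)\n",
"    meta = torch.nn.Parameter(torch.tensor(1.0))\n",
"    meta_optim = torch.optim.Adam([meta], lr=lr)\n",
"\n",
"    for _ in range(num_outer_steps):\n",
"        # Differentiable inner step (the role MetaSGD plays in this tutorial)\n",
"        inner_loss = F.mse_loss(F.linear(inputs, weight, bias), targets) * meta\n",
"        grad_w, grad_b = torch.autograd.grad(inner_loss, (weight, bias), create_graph=True)\n",
"        weight, bias = weight - lr * grad_w, bias - lr * grad_b\n",
"\n",
"        # Outer step on the meta-parameter\n",
"        outer_loss = F.mse_loss(F.linear(inputs, weight, bias), targets)\n",
"        meta_optim.zero_grad()\n",
"        outer_loss.backward()\n",
"        meta_optim.step()\n",
"\n",
"        if detach:\n",
"            # Stop gradient: keep the values, drop the link to the freed graph\n",
"            weight = weight.detach().requires_grad_()\n",
"            bias = bias.detach().requires_grad_()\n",
"    return meta.item()\n",
"\n",
"\n",
"demo_meta_value = run_bilevel(detach=True)\n",
"print(f'meta-parameter after 3 outer steps with detach: {demo_meta_value:.4f}')\n",
"\n",
"try:\n",
"    run_bilevel(detach=False)\n",
"    demo_no_detach_failed = False\n",
"except RuntimeError as err:\n",
"    demo_no_detach_failed = True\n",
"    print(f'without detach: RuntimeError: {str(err)[:50]}...')"
]
}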
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"vscode": {
"interpreter": {
"hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}