{ "cells": [ { "metadata": { "id": "ijPhcVzPXQDU" }, "cell_type": "markdown", "source": [ "Demonstrate selective quantization capabilities of AI Edge Quantizer.\n" ] }, { "metadata": { "id": "M5RmrWpQYQwS" }, "cell_type": "code", "source": [ "# Copyright 2024 The AI Edge Quantizer Authors.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "# ==============================================================================" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "hMsLPH02YR2R" }, "cell_type": "markdown", "source": [ "\n", " \n", " \n", "
\n", " Run in Google Colab\n", " \n", " View source on GitHub\n", "
" ] }, { "metadata": { "id": "0RqZd4zYZdbS" }, "cell_type": "code", "source": [ "# When running in google colab the pre-installed versions of some packages\n", "# might be incompatible with AI edge libraries.\n", "!pip uninstall -y tensorflow jax jaxlib torch torchvision torchao\n", "!pip install ai-edge-litert-nightly\n", "!pip install ai-edge-model-explorer\n", "!pip install ai-edge-quantizer-nightly\n", "!pip install litert-torch-nightly torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu\n", "!pip install pillow requests matplotlib" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "L16meLkvXgZk" }, "cell_type": "code", "source": [ "import os\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from PIL import Image\n", "import skimage\n", "import tensorflow as tf\n", "import ai_edge_quantizer\n", "import model_explorer\n", "\n", "from ai_edge_litert.interpreter import Interpreter" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "Q_i0HmTC_UMe" }, "cell_type": "code", "source": [ "# @title Preprocess/postprocess utilities (unrelated to quantization) { display-mode: \"form\" }\n", "\n", "MODEL_INPUT_HW = (1024, 1024)\n", "\n", "def make_channels_first(image):\n", " image = tf.transpose(image, [2, 0, 1])\n", " image = np.expand_dims(image, axis=0)\n", " return image\n", "\n", "def preprocess_image(file_path):\n", " image = skimage.io.imread(file_path)\n", " image = tf.image.resize(image, MODEL_INPUT_HW).numpy().astype(np.float32)\n", " image = image / 255.0\n", " return make_channels_first(image)\n", "\n", "def preprocess_image_litert_torch(test_image_path):\n", " image = Image.open(test_image_path)\n", " test_image = np.array(image.resize(MODEL_INPUT_HW, Image.Resampling.BILINEAR))\n", " test_image = np.expand_dims(test_image, axis=0).astype(np.float32)\n", " return test_image\n", "\n", "def run_segmentation(image, model_path):\n", " \"\"\"Get segmentation mask of the image.\"\"\"\n", " 
interpreter = Interpreter(model_path=model_path)\n", " interpreter.allocate_tensors()\n", "\n", " input_details = interpreter.get_input_details()[0]\n", " interpreter.set_tensor(input_details['index'], image)\n", " interpreter.invoke()\n", "\n", " output_details = interpreter.get_output_details()\n", " output_index = 0\n", " outputs = []\n", " for detail in output_details:\n", " outputs.append(interpreter.get_tensor(detail['index']))\n", " mask = tf.squeeze(outputs[output_index])\n", " # Min-max normalization.\n", " tf_min = np.min(mask)\n", " tf_max = np.max(mask)\n", " mask = (mask - tf_min) / (tf_max - tf_min)\n", " # Scale [0, 1] -> [0, 255].\n", " mask = (mask * 255)\n", " return mask\n", "\n", "\n", "def draw_segmentation(image, float_mask, quant_mask, info):\n", " _, ax = plt.subplots(1, 3, figsize=(15, 10))\n", "\n", " ax[0].imshow(np.array(image))\n", " ax[1].imshow(np.array(float_mask), cmap='gray')\n", " ax[2].imshow(np.array(quant_mask), cmap='gray')\n", "\n", " ax[0].set_title('Image')\n", " ax[1].set_title('Float Mask')\n", " ax[2].set_title('Quant Mask: {}'.format(info))\n", "\n", " plt.show()\n", "\n", "def save_model(model_content, save_path):\n", " with open(save_path, 'wb') as f:\n", " f.write(model_content)\n", "\n" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "J7PH4Um5EFkF" }, "cell_type": "code", "source": [ "!curl -H 'Accept: application/vnd.github.v3.raw' -O -L https://api.github.com/repos/google-ai-edge/ai-edge-quantizer/contents/colabs/test_data/input_image.jpg\n", "\n", "IMAGE_PATH = 'input_image.jpg'\n", "\n", "image = Image.open(IMAGE_PATH)\n", "test_image = preprocess_image_litert_torch(IMAGE_PATH)" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "GHsfVmAPu_8H" }, "cell_type": "markdown", "source": [ "# Getting LiteRT model from PyTorch.\n", "\n", "Our first step is to convert a PyTorch model to a float LiteRT model (which will be the input to AI Edge Quantizer)." 
] }, { "metadata": { "id": "KJmCL_h63AMO" }, "cell_type": "code", "source": [ "%cd /content\n", "!rm -rf DIS sample_data\n", "\n", "!git clone https://github.com/xuebinqin/DIS.git\n", "%cd DIS/IS-Net/\n", "\n", "!curl -o ./model.tar.gz -L https://www.kaggle.com/api/v1/models/paulruiz/dis/pyTorch/8-17-22/1/download\n", "!tar -xvf 'model.tar.gz'" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "0hZkC1aB3TiZ" }, "cell_type": "code", "source": [ "import torch\n", "from models import ISNetDIS\n", "\n", "pytorch_model_filename = 'isnet-general-use.pth'\n", "pt_model = ISNetDIS()\n", "pt_model.load_state_dict(\n", " torch.load(pytorch_model_filename, map_location=torch.device('cpu'))\n", ")\n", "\n", "import torch\n", "from torch import nn\n", "from torchvision.transforms.functional import normalize\n", "\n", "\n", "class ImageSegmentationModelWrapper(nn.Module):\n", "\n", " RESCALING_FACTOR = 255.0\n", " MEAN = 0.5\n", " STD = 1.0\n", "\n", " def __init__(self, pt_model):\n", " super().__init__()\n", " self.model = pt_model\n", "\n", " def forward(self, image: torch.Tensor):\n", " # BHWC -> BCHW.\n", " image = image.permute(0, 3, 1, 2)\n", "\n", " # Rescale [0, 255] -> [0, 1].\n", " image = image / self.RESCALING_FACTOR\n", "\n", " # Normalize.\n", " image = (image - self.MEAN) / self.STD\n", "\n", " # Get result.\n", " result = self.model(image)[0][0]\n", "\n", " # BCHW -> BHWC.\n", " result = result.permute(0, 2, 3, 1)\n", "\n", " return result\n", "\n", "\n", "wrapped_pt_model = ImageSegmentationModelWrapper(pt_model).eval()" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "ipvXrW9NWsvy" }, "cell_type": "code", "source": [ "# @title Convert torch model to LiteRT using AI Edge Torch\n", "\n", "import litert_torch\n", "\n", "sample_args = (torch.rand((1, *MODEL_INPUT_HW, 3)),)\n", "edge_model = litert_torch.convert(wrapped_pt_model, sample_args)\n", "edge_model.export('model/isnet_float.tflite')" ], "outputs": [], "execution_count": 
null }, { "metadata": { "id": "QcdAXUn_DeaO" }, "cell_type": "markdown", "source": [ "# AI Edge Quantizer" ] }, { "metadata": { "id": "3S7VLgswL4ig" }, "cell_type": "markdown", "source": [ "To use the `Quantizer`, we need to provide:\n", "* the float .tflite model.\n", "* quantization recipe (i.e., apply quantization algorithm X on Operator Y with configuration Z).\n", "\n", "\n", "\n", "\n" ] }, { "metadata": { "id": "bcLgjIPExInv" }, "cell_type": "markdown", "source": [ "### Quantizing model with dynamic quantization\n", "\n", "\n", "The following example will showcase how to get a model with dynamic quantization with AI Edge Quantizer." ] }, { "metadata": { "id": "97HBnymXCaCA" }, "cell_type": "code", "source": [ "from ai_edge_quantizer import recipe\n", "\n", "quantizer = ai_edge_quantizer.Quantizer(float_model='model/isnet_float.tflite')\n", "quantizer.load_quantization_recipe(recipe=recipe.dynamic_wi8_afp32())\n", "\n", "\n", "quantization_result = quantizer.quantize()\n", "quantization_result.export_model('model/isnet_dynamic_wi8_afp32.tflite')" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "sCc1sgc6uzma" }, "cell_type": "markdown", "source": [ "`quantization_result` has two components:\n", "\n", "* quantized LiteRT model (in bytearray) and\n", "* the corresponding quantization recipe\n", "\n", "Let's take a look at what is in this recipe." ] }, { "metadata": { "id": "3sMvpH5lFYV8" }, "cell_type": "code", "source": [ "quantization_result.recipe" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "3h8-JbIjNUeU" }, "cell_type": "markdown", "source": [ "Here the recipe means: apply the naive min/max uniform algorithm (`min_max_uniform_quantize`) for all ops supported by the AI Edge Quantizer (indicated by `*`) under layers satisfying regex `.*` (i.e., all layers). 
We want the weights of these ops to be quantized as int8, symmetric, channel_wise, and we want to execute the ops in `Integer` mode.\n" ] }, { "metadata": { "id": "RUufFeuaN_oN" }, "cell_type": "markdown", "source": [ "Now let's try running both the float model and the newly quantized model and see how they compare." ] }, { "metadata": { "id": "o_m7OdW0OFXe" }, "cell_type": "code", "source": [ "quantized_mask = run_segmentation(test_image, 'model/isnet_dynamic_wi8_afp32.tflite')\n", "float_mask = run_segmentation(test_image, 'model/isnet_float.tflite')\n", "draw_segmentation(image, float_mask, quantized_mask, 'Dynamic_wi8_afp32')" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "OwexImxzm4CJ" }, "cell_type": "markdown", "source": [ "# Debug through Model Explorer (visualization)\n", "\n", "Now we know that float execution gives us better-quality results, with a larger model size. Dynamic quantization gives a smaller model size but the quality can be worse.\n", "\n", "Let's try to understand where dynamic quantization is introducing precision loss to see if we can do better.\n", "\n", "The following code will generate a tensor-by-tensor comparison result between the dynamically quantized model and the original float model.\n", "\n" ] }, { "metadata": { "editable": true, "id": "tyWj44-cfPkj", "tags": [ "parameters" ], "trusted": true }, "cell_type": "code", "source": [ "#@title Parameter to visualize LiteRT model\n", "visualize_model = True" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "8yyKy0PIHhrJ" }, "cell_type": "markdown", "source": [ "*Note*: `quantizer.validate` is a memory-intensive operation. This might result in out-of-memory errors in the cell below." 
] }, { "metadata": { "id": "bNe85EY7_t6V" }, "cell_type": "code", "source": [ "comparison_result = quantizer.validate(\n", " test_data={'serving_default': [{'args_0': test_image}]},\n", " error_metrics='median_diff_ratio',\n", " use_xnnpack=False,\n", " num_threads=1,\n", ").save('', 'dynamic')" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "9lN3dE7qqIBo" }, "cell_type": "code", "source": [ "if visualize_model:\n", " model_explorer.visualize_from_config(\n", " model_explorer.config()\n", " .add_model_from_path('model/isnet_dynamic_wi8_afp32.tflite')\n", " .add_node_data_from_path('dynamic_comparison_result_me_input.json')\n", " )" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "CnB9x6b4RWJ_" }, "cell_type": "markdown", "source": [ "Using Model Explorer, we find that the errors come from the last few layers ('RSU6_stage2d', 'RSU7_stage1d', 'Conv2d_side1'). Let's try not quantizing them." ] }, { "metadata": { "id": "u79g4Dxqewmm" }, "cell_type": "markdown", "source": [ "# Selective Dynamic Quantization\n", "\n", "Here we'll override the original `dynamic_wi8_afp32` recipe to skip the three scopes that produce inaccurate results. Notice that for each scope, the newly added rule always takes precedence." 
] }, { "metadata": { "id": "ZID1qD7i2Tn7" }, "cell_type": "code", "source": [ "scopes = ['RSU6', 'RSU7', 'Conv2d_side1']\n", "for scope in scopes:\n", " quantizer.update_quantization_recipe(\n", " regex=scope,\n", " operation_name='CONV_2D',\n", " algorithm_key='no_quantize',\n", " )\n", "quantizer.get_quantization_recipe()" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "Ng9euL6jLqfJ" }, "cell_type": "code", "source": [ "quantizer.quantize().export_model('model/isnet_selective_dynamic_wi8_afp32.tflite')\n", "quantized_mask = run_segmentation(\n", " test_image, 'model/isnet_selective_dynamic_wi8_afp32.tflite'\n", ")\n", "draw_segmentation(image, float_mask, quantized_mask, 'Selective Dynamic')" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "qt_4vjw9P9qO" }, "cell_type": "code", "source": [ "!ls -lh model/*.tflite" ], "outputs": [], "execution_count": null } ], "metadata": { "colab": { "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }