{
  "cells": [
    {
      "metadata": {
        "id": "_Z6db789D_hv"
      },
      "cell_type": "code",
      "source": [
        "# Copyright 2024 The AI Edge Quantizer Authors.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "CDiNmYJqVwYN"
      },
      "cell_type": "markdown",
      "source": [
        "\n",
        " \n",
        " \n",
        "\n",
        " Run in Google Colab\n",
        " \n",
        " View source on GitHub\n",
        "\n"
      ]
    },
    {
      "metadata": {
        "id": "HsWBIuE3EBCf"
      },
      "cell_type": "code",
      "source": [
        "# When running in google colab the pre-installed versions of some packages\n",
        "# might be incompatible with AI edge libraries.\n",
        "!pip uninstall -y jax jaxlib\n",
        "!pip install ai-edge-quantizer-nightly\n",
        "!pip install ai-edge-model-explorer\n",
        "!pip install ai-edge-litert-nightly"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "bD4ZDD9EEE9P"
      },
      "cell_type": "markdown",
      "source": [
        "## Install"
      ]
    },
    {
      "metadata": {
        "id": "unHj1x85EHrA"
      },
      "cell_type": "code",
      "source": [
        "import json\n",
        "import logging\n",
        "import pathlib\n",
        "import random\n",
        "\n",
        "import matplotlib.pylab as plt\n",
        "import model_explorer\n",
        "import numpy as np\n",
        "\n",
        "from ai_edge_litert.interpreter import Interpreter\n",
        "import tensorflow as tf\n",
        "\n",
        "from ai_edge_quantizer import quantizer\n",
        "from ai_edge_quantizer import recipe\n",
        "from ai_edge_quantizer import qtyping\n",
        "from ai_edge_quantizer.utils import tfl_flatbuffer_utils"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "AAj-iHUtEQno"
      },
      "cell_type": "markdown",
      "source": [
        "## Create and train MNIST in Keras"
      ]
    },
    {
      "metadata": {
        "id": "wpMcV_4-EN3A"
      },
      "cell_type": "code",
      "source": [
        "# Load MNIST dataset\n",
        "mnist = tf.keras.datasets.mnist\n",
        "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
        "\n",
        "# Normalize the input image so that each pixel value is between 0 to 1.\n",
        "train_images = train_images.astype(np.float32) / 255.0\n",
        "train_images = train_images.reshape([-1, 28, 28, 1])\n",
        "test_images = test_images.astype(np.float32) / 255.0\n",
        "test_images = test_images.reshape([-1, 28, 28, 1])\n",
        "\n",
        "num_classes = 10\n",
        "hidden_dim = 32\n",
        "model = tf.keras.Sequential()\n",
        "\n",
        "model.add(\n",
        "    tf.keras.layers.Conv2D(\n",
        "        hidden_dim//4,\n",
        "        3,\n",
        "        activation=\"relu\",\n",
        "        padding=\"same\",\n",
        "        input_shape=(28, 28, 1),\n",
        "        use_bias=True,\n",
        "    )\n",
        ")\n",
        "model.add(tf.keras.layers.AveragePooling2D(pool_size=2))\n",
        "model.add(tf.keras.layers.Flatten())\n",
        "model.add(tf.keras.layers.Dense(hidden_dim, activation=\"relu\", use_bias=True))\n",
        "model.add(\n",
        "    tf.keras.layers.Dense(num_classes, use_bias=False, activation=\"softmax\")\n",
        ")\n",
        "\n",
        "# Train the digit classification model.\n",
        "# The last layer already applies softmax, so the loss has to treat the model\n",
        "# outputs as probabilities (from_logits=False) rather than as raw logits.\n",
        "model.compile(\n",
        "    optimizer='adam',\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),\n",
        "    metrics=['accuracy'],\n",
        ")\n",
        "model.fit(\n",
        "    train_images,\n",
        "    train_labels,\n",
        "    epochs=5,\n",
        "    validation_data=(test_images, test_labels)\n",
        ")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "K8eYTpZYEdIq"
      },
      "cell_type": "markdown",
      "source": [
        "## Helper functions"
      ]
    },
    {
      "metadata": {
        "id": "kl3l6Y77Eeqo"
      },
      "cell_type": "code",
      "source": [
        "def run_model(model_path, test_image_indices):\n",
        "  \"\"\"Runs a LiteRT model on selected images of the global `test_images`.\n",
        "\n",
        "  Args:\n",
        "    model_path: Path of the `.tflite` flatbuffer to load.\n",
        "    test_image_indices: Sized sequence of indices into `test_images`.\n",
        "\n",
        "  Returns:\n",
        "    A 1-D integer array with the argmax of the model output for each index.\n",
        "  \"\"\"\n",
        "  # Initialize the interpreter.\n",
        "  interpreter = Interpreter(model_path=str(model_path))\n",
        "  interpreter.allocate_tensors()\n",
        "\n",
        "  input_details = interpreter.get_input_details()[0]\n",
        "  print(f\"input details: {input_details}\")\n",
        "  output_details = interpreter.get_output_details()[0]\n",
        "\n",
        "  predictions = np.zeros((len(test_image_indices),), dtype=int)\n",
        "  for i, test_image_index in enumerate(test_image_indices):\n",
        "    test_image = test_images[test_image_index]\n",
        "\n",
        "    # If the model input is int8, quantize the float image with the input\n",
        "    # tensor's (scale, zero_point). Round and clip before casting so the cast\n",
        "    # neither truncates toward zero nor wraps around outside the int8 range.\n",
        "    if input_details[\"dtype\"] == np.int8:\n",
        "      input_scale, input_zero_point = input_details[\"quantization\"]\n",
        "      int8_range = np.iinfo(np.int8)\n",
        "      test_image = np.clip(\n",
        "          np.round(test_image / input_scale + input_zero_point),\n",
        "          int8_range.min,\n",
        "          int8_range.max,\n",
        "      )\n",
        "\n",
        "    # Add the batch dimension and match the dtype expected by the input tensor.\n",
        "    test_image = np.expand_dims(test_image, axis=0).astype(input_details[\"dtype\"])\n",
        "    interpreter.set_tensor(input_details[\"index\"], test_image)\n",
        "    interpreter.invoke()\n",
        "    output = interpreter.get_tensor(output_details[\"index\"])[0]\n",
        "\n",
        "    predictions[i] = output.argmax()\n",
        "\n",
        "  return predictions\n",
        "\n",
        "\n",
        "def test_model(model_path, test_image_index, model_type):\n",
        "  \"\"\"Plots one test image titled with its true and predicted labels.\n",
        "\n",
        "  Args:\n",
        "    model_path: Path of the `.tflite` flatbuffer to run.\n",
        "    test_image_index: Index into the global `test_images` / `test_labels`.\n",
        "    model_type: Label for the model, used as the prefix of the plot title.\n",
        "  \"\"\"\n",
        "  predictions = run_model(model_path, [test_image_index])\n",
        "\n",
        "  # test_images was reshaped to (N, 28, 28, 1); drop the channel axis so the\n",
        "  # image is plotted as a plain 2-D array.\n",
        "  plt.imshow(np.squeeze(test_images[test_image_index], axis=-1))\n",
        "  template = model_type + \" Model \\n True:{true}, Predicted:{predict}\"\n",
        "  _ = plt.title(\n",
        "      template.format(\n",
        "          true=str(test_labels[test_image_index]),\n",
        "          predict=str(predictions[0]),\n",
        "      )\n",
        "  )\n",
        "  plt.grid(False)\n",
        "\n",
        "\n",
        "def evaluate_model(model_path, model_type):\n",
        "  \"\"\"Prints the accuracy of a LiteRT model over all of `test_images`.\n",
        "\n",
        "  Args:\n",
        "    model_path: Path of the `.tflite` flatbuffer to run.\n",
        "    model_type: Label for the model, used in the printed message.\n",
        "  \"\"\"\n",
        "  test_image_indices = range(test_images.shape[0])\n",
        "  predictions = run_model(model_path, test_image_indices)\n",
        "\n",
        "  accuracy = (np.sum(test_labels == predictions) * 100) / len(test_images)\n",
        "\n",
        "  print('%s model accuracy is %.4f%% (Number of test samples=%d)' % (\n",
        "      model_type, accuracy, len(test_images)))"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "GBm6JKJaEpjL"
      },
      "cell_type": "markdown",
      "source": [
        "## Convert to flatbuffer and visualize float model"
      ]
    },
    {
      "metadata": {
        "id": "gm4q-mAKJ8nq",
        "trusted": true,
        "tags": [
          "parameters"
        ],
        "editable": true
      },
      "cell_type": "code",
      "source": [
        "#@title Parameter to visualize LiteRT model\n",
        "visualize_model = True"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "ulLoMC5TE1Gk"
      },
      "cell_type": "code",
      "source": [
        "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
        "litert_model = converter.convert()\n",
        "\n",
        "model_path = \"mnist_model.tflite\"\n",
        "with open(model_path, \"wb\") as f:\n",
        "  f.write(litert_model)\n",
        "\n",
        "if visualize_model:\n",
        "  model_explorer.visualize(model_path)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "95sgQ9iX5pUL"
      },
      "cell_type": "markdown",
      "source": [
        "## Create a LiteRT model with dynamic quantization with AI Edge Quantizer"
      ]
    },
    {
      "metadata": {
        "id": "ADFOmtTbFKcm"
      },
      "cell_type": "code",
      "source": [
        "dynamic_quant_mnist_model_path = \"mnist_model_quantized.tflite\"\n",
        "\n",
        "qt = quantizer.Quantizer(model_path, recipe.dynamic_wi8_afp32())\n",
        "# Keep the object returned by quantize() in `quant_result` (rather than the\n",
        "# return value of export_model()) so it stays available for inspection.\n",
        "quant_result = qt.quantize()\n",
        "quant_result.export_model(dynamic_quant_mnist_model_path)\n",
        "\n",
        "if visualize_model:\n",
        "  model_explorer.visualize(dynamic_quant_mnist_model_path)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "m4dRQjwq6ZOn"
      },
      "cell_type": "markdown",
      "source": [
        "## Sanity check of float model on one image"
      ]
    },
    {
      "metadata": {
        "id": "dByC87y8FTzn"
      },
      "cell_type": "code",
      "source": [
        "# Change this to test a different image.\n",
        "test_image_index = 1\n",
        "\n",
        "# Test the float model\n",
        "test_model(model_path, test_image_index, model_type=\"Float\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "oAc22QNM6gjo"
      },
      "cell_type": "markdown",
      "source": [
        "## Sanity check of LiteRT model with dynamic quantization"
      ]
    },
    {
      "metadata": {
        "id": "3yiBEdrQGCx0"
      },
      "cell_type": "code",
      "source": [
        "test_model(dynamic_quant_mnist_model_path, test_image_index, model_type=\"Dynamic_wi8_afp32\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "BsSOEtN9GZn2"
      },
      "cell_type": "markdown",
      "source": [
        "## Evaluate the models on all images"
      ]
    },
    {
      "metadata": {
        "id": "1iqELElvGclH"
      },
      "cell_type": "code",
      "source": [
        "# Evaluate the float model\n",
        "evaluate_model(model_path, model_type=\"Float\")\n",
        "\n",
        "# Evaluate the LiteRT model with dynamic quantization\n",
        "evaluate_model(dynamic_quant_mnist_model_path, model_type=\"Dynamic_wi8_afp32\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "ISKGYiAc7uUy"
      },
      "cell_type": "markdown",
      "source": [
        "## Compare size of flatbuffers"
      ]
    },
    {
      "metadata": {
        "id": "RCg4CaxJ70Z7"
      },
      "cell_type": "code",
      "source": [
        "!ls -lh *.tflite"
      ],
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}