{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "A100", "machine_shape": "hm" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "source": [ "## Multimodal RAG Example using Qwen3-VL Models\n", "\n", "This example demonstrates a complete multimodal RAG pipeline that can:\n", "\n", "- Convert PDF documents into page images\n", "- Embed text queries and document page images, and use the embeddings to search for and retrieve relevant pages\n", "- Rerank the retrieved pages for better accuracy\n", "- Generate answers with Qwen3-VL\n", "\n", "**Note:** These models are large, so make sure you have sufficient GPU memory. The embedding and reranker wrappers are Python scripts in the `scripts` folder of the model repositories on the Hugging Face Hub, so you need to download them first. If your GPU supports it, you can install FlashAttention 2 for faster inference, and you can quantize the models with bitsandbytes to reduce memory usage." ], "metadata": { "id": "1Doz8nKeWU4Q" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VE3jgbfnPQjs" }, "outputs": [], "source": [ "# Quote the version specifier so the shell does not treat `>` as an output redirect\n", "!pip install -q \"pdf2image>=1.16.0\" qwen-vl-utils flash-attn bitsandbytes\n" ] }, { "cell_type": "code", "source": [ "!sudo apt-get install -y poppler-utils" ], "metadata": { "id": "UgV8BebePbl-", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "35795eae-9b7b-4288-bab7-8b68a18263dc" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Reading package lists... Done\n", "Building dependency tree... Done\n", "Reading state information... 
Done\n", "The following NEW packages will be installed:\n", " poppler-utils\n", "0 upgraded, 1 newly installed, 0 to remove and 41 not upgraded.\n", "Need to get 186 kB of archives.\n", "After this operation, 697 kB of additional disk space will be used.\n", "Get:1 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 poppler-utils amd64 22.02.0-2ubuntu0.12 [186 kB]\n", "Fetched 186 kB in 0s (372 kB/s)\n", "debconf: unable to initialize frontend: Dialog\n", "debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 78, <> line 1.)\n", "debconf: falling back to frontend: Readline\n", "debconf: unable to initialize frontend: Readline\n", "debconf: (This frontend requires a controlling tty.)\n", "debconf: falling back to frontend: Teletype\n", "dpkg-preconfigure: unable to re-open stdin: \n", "Selecting previously unselected package poppler-utils.\n", "(Reading database ... 121689 files and directories currently installed.)\n", "Preparing to unpack .../poppler-utils_22.02.0-2ubuntu0.12_amd64.deb ...\n", "Unpacking poppler-utils (22.02.0-2ubuntu0.12) ...\n", "Setting up poppler-utils (22.02.0-2ubuntu0.12) ...\n", "Processing triggers for man-db (2.10.2-1) ...\n" ] } ] }, { "cell_type": "markdown", "source": [ "Download the helper scripts that wrap the Qwen3-VL embedding and reranker models." 
], "metadata": { "id": "WQbwhYg8WNnb" } }, { "cell_type": "code", "source": [ "!wget https://huggingface.co/Qwen/Qwen3-VL-Embedding-8B/resolve/main/scripts/qwen3_vl_embedding.py\n", "!wget https://huggingface.co/Qwen/Qwen3-VL-Reranker-8B/resolve/main/scripts/qwen3_vl_reranker.py\n", "!mkdir -p scripts\n", "!mv qwen3_vl_embedding.py scripts/\n", "!mv qwen3_vl_reranker.py scripts/" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_nuKol8KU4j8", "outputId": "33e2c24b-597d-4c16-af32-b2708d3651cd" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "--2026-01-08 16:18:02-- https://huggingface.co/Qwen/Qwen3-VL-Embedding-8B/resolve/main/scripts/qwen3_vl_embedding.py\n", "Resolving huggingface.co (huggingface.co)... 18.239.50.80, 18.239.50.49, 18.239.50.16, ...\n", "Connecting to huggingface.co (huggingface.co)|18.239.50.80|:443... connected.\n", "HTTP request sent, awaiting response... 307 Temporary Redirect\n", "Location: /api/resolve-cache/models/Qwen/Qwen3-VL-Embedding-8B/93dc85c0aaf1836e8ed35366f06e70625138817b/scripts%2Fqwen3_vl_embedding.py?%2FQwen%2FQwen3-VL-Embedding-8B%2Fresolve%2Fmain%2Fscripts%2Fqwen3_vl_embedding.py=&etag=%2236d45865735be96a1278a21c132ff640e2ae68ca%22 [following]\n", "--2026-01-08 16:18:02-- https://huggingface.co/api/resolve-cache/models/Qwen/Qwen3-VL-Embedding-8B/93dc85c0aaf1836e8ed35366f06e70625138817b/scripts%2Fqwen3_vl_embedding.py?%2FQwen%2FQwen3-VL-Embedding-8B%2Fresolve%2Fmain%2Fscripts%2Fqwen3_vl_embedding.py=&etag=%2236d45865735be96a1278a21c132ff640e2ae68ca%22\n", "Reusing existing connection to huggingface.co:443.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 13258 (13K) [text/plain]\n", "Saving to: ‘qwen3_vl_embedding.py’\n", "\n", "\rqwen3_vl_embedding. 0%[ ] 0 --.-KB/s \rqwen3_vl_embedding. 
100%[===================>] 12.95K --.-KB/s in 0s \n", "\n", "2026-01-08 16:18:02 (382 MB/s) - ‘qwen3_vl_embedding.py’ saved [13258/13258]\n", "\n", "--2026-01-08 16:18:02-- https://huggingface.co/Qwen/Qwen3-VL-Reranker-8B/resolve/main/scripts/qwen3_vl_reranker.py\n", "Resolving huggingface.co (huggingface.co)... 18.239.50.80, 18.239.50.49, 18.239.50.103, ...\n", "Connecting to huggingface.co (huggingface.co)|18.239.50.80|:443... connected.\n", "HTTP request sent, awaiting response... 307 Temporary Redirect\n", "Location: /api/resolve-cache/models/Qwen/Qwen3-VL-Reranker-8B/fb0130df40a5cb3a458acb86ceeb2ef30b05900d/scripts%2Fqwen3_vl_reranker.py?%2FQwen%2FQwen3-VL-Reranker-8B%2Fresolve%2Fmain%2Fscripts%2Fqwen3_vl_reranker.py=&etag=%22db7c53f9ad0971a41065053be5ebc432a978a993%22 [following]\n", "--2026-01-08 16:18:03-- https://huggingface.co/api/resolve-cache/models/Qwen/Qwen3-VL-Reranker-8B/fb0130df40a5cb3a458acb86ceeb2ef30b05900d/scripts%2Fqwen3_vl_reranker.py?%2FQwen%2FQwen3-VL-Reranker-8B%2Fresolve%2Fmain%2Fscripts%2Fqwen3_vl_reranker.py=&etag=%22db7c53f9ad0971a41065053be5ebc432a978a993%22\n", "Reusing existing connection to huggingface.co:443.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 10873 (11K) [text/plain]\n", "Saving to: ‘qwen3_vl_reranker.py’\n", "\n", "qwen3_vl_reranker.p 100%[===================>] 10.62K --.-KB/s in 0s \n", "\n", "2026-01-08 16:18:03 (193 MB/s) - ‘qwen3_vl_reranker.py’ saved [10873/10873]\n", "\n" ] } ] }, { "cell_type": "markdown", "source": [ "Some utilities to download a PDF and convert its pages to images." 
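, "\n",
"\n",
"If you only need a few pages, pdf2image can render just that range instead of converting the whole PDF and slicing the list afterwards (the notebook later keeps `document_images[4:10]`). The helpers below are a hypothetical sketch, not part of the original notebook. `convert_from_path` accepts 1-based, inclusive `first_page` and `last_page` arguments, so the only subtlety is translating a 0-based Python slice into that range.\n",
"\n",
"```python\n",
"# Hypothetical helpers (not part of the original notebook).\n",
"def slice_to_page_range(start, stop):\n",
"    # A Python slice [start:stop] is 0-based and excludes stop;\n",
"    # pdf2image's first_page/last_page are 1-based and inclusive.\n",
"    return start + 1, stop\n",
"\n",
"\n",
"def pdf_page_range_to_images(pdf_path, start, stop, dpi=200):\n",
"    # Imported lazily; pdf2image needs poppler-utils installed\n",
"    from pdf2image import convert_from_path\n",
"    first_page, last_page = slice_to_page_range(start, stop)\n",
"    # dpi=200 matches pdf2image's default resolution\n",
"    return convert_from_path(pdf_path, dpi=dpi, first_page=first_page, last_page=last_page)\n",
"```\n",
"\n",
"For example, `pdf_page_range_to_images(pdf_path, 4, 10)` should return the same six pages as `pdf_to_images(pdf_path)[4:10]` while rendering only those pages. A lower `dpi` reduces memory use further, at the cost of fine detail in the page images."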
], "metadata": { "id": "T6SmEpvZcemw" } }, { "cell_type": "code", "source": [ "from pdf2image import convert_from_path\n", "import requests\n", "\n", "def download_pdf(url, save_path=\"document.pdf\"):\n", "    # Download the PDF and raise an error if the HTTP request fails\n", "    response = requests.get(url, timeout=60)\n", "    response.raise_for_status()\n", "    with open(save_path, 'wb') as f:\n", "        f.write(response.content)\n", "    print(f\"PDF saved to {save_path}\")\n", "    return save_path\n", "\n", "def pdf_to_images(pdf_path):\n", "    # Render every page as a PIL image (requires poppler-utils)\n", "    images = convert_from_path(pdf_path)\n", "    return images" ], "metadata": { "id": "Xhu1f65TUfx4" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We'll use a climate change PDF from the European Commission's climate website and write queries about its content." ], "metadata": { "id": "Xgtah5CQWgeV" } }, { "cell_type": "code", "source": [ "pdf_url = \"https://climate.ec.europa.eu/system/files/2018-06/youth_magazine_en.pdf\"\n", "pdf_path = download_pdf(pdf_url)\n", "document_images = pdf_to_images(pdf_path)\n", "\n", "# Keep only pages 5-10 of the PDF (0-based indices 4-9) to keep the example small.\n", "# From here on, \"Page 1\" means the first kept page, i.e. page 5 of the PDF.\n", "document_images = document_images[4:10]\n", "\n", "queries = [\n", "    {\"text\": \"How much did the world temperature change so far?\"},\n", "    {\"text\": \"What are the main causes of climate change?\"},\n", "]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "y8cbXMhRUliA", "outputId": "8c7ce0f5-80b5-41ca-cd7d-dba7c49775cd" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "PDF saved to document.pdf\n" ] } ] }, { "cell_type": "markdown", "source": [ "We can now embed the page images and the queries. Each page is saved as a PNG file and passed to the embedder by its path." 
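, "\n",
"\n",
"With many pages, embedding everything in a single `embedder.process` call may exhaust GPU memory. Below is a minimal batching sketch. `embed_in_batches` is a hypothetical helper, not part of the Qwen3-VL scripts, and it assumes (as the cell below suggests) that `process` takes a list of inputs and returns one embedding row per input, either as a torch tensor or as an array.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"# Hypothetical helper (not part of the Qwen3-VL scripts).\n",
"def embed_in_batches(embed_fn, items, batch_size=4):\n",
"    if not items:\n",
"        raise ValueError('items must not be empty')\n",
"    chunks = []\n",
"    for start in range(0, len(items), batch_size):\n",
"        out = embed_fn(items[start:start + batch_size])\n",
"        # Move torch tensors (possibly bfloat16 on GPU) to float32 NumPy arrays\n",
"        if hasattr(out, 'detach'):\n",
"            out = out.detach().cpu().float().numpy()\n",
"        chunks.append(np.asarray(out))\n",
"    # Stack the per-batch results row-wise, preserving input order\n",
"    return np.concatenate(chunks, axis=0)\n",
"```\n",
"\n",
"With the embedder loaded, `document_embeddings = embed_in_batches(embedder.process, document_inputs, batch_size=2)` would replace the single call. The result is a NumPy array, which `retrieve_top_k` accepts as-is because it only converts torch tensors."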
], "metadata": { "id": "DilxeuU8hITd" } }, { "cell_type": "code", "source": [ "import numpy as np\n", "import torch\n", "\n", "from scripts.qwen3_vl_embedding import Qwen3VLEmbedder\n", "from transformers import Qwen3VLForConditionalGeneration, AutoProcessor\n", "\n", "embedder = Qwen3VLEmbedder(\"Qwen/Qwen3-VL-Embedding-2B\")\n", "\n", "# Save each page as a PNG and reference it by path in the embedder input\n", "document_inputs = []\n", "for idx, img in enumerate(document_images):\n", "    img_path = f\"temp_page_{idx}.png\"\n", "    img.save(img_path)\n", "    document_inputs.append({\"image\": img_path})\n", "\n", "# One embedding vector per page and per query\n", "document_embeddings = embedder.process(document_inputs)\n", "query_embeddings = embedder.process(queries)\n", "print(f\"Query embeddings shape: {query_embeddings.shape}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FRsaxlTZUOGs", "outputId": "b501bfa7-9319-4ce3-a2ae-5025faa9e74e" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Query embeddings shape: torch.Size([2, 2048])\n" ] } ] }, { "cell_type": "markdown", "source": [ "Next, we score every page against each query and retrieve the top-k pages. The score is the dot product of the query and page embeddings, which equals cosine similarity when the embeddings are L2-normalized." 
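, "\n",
"\n",
"The cell below scores pages with a plain dot product, which equals cosine similarity only if the embeddings are L2-normalized. If you are not sure they are, a minimal NumPy sketch such as the hypothetical `cosine_top_k` below normalizes both sides first and never asks for more pages than exist. It expects NumPy-compatible inputs, so convert GPU or bfloat16 tensors with `.cpu().float().numpy()` first, as `retrieve_top_k` does.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"# Hypothetical helper (not part of the original notebook).\n",
"def cosine_top_k(query_emb, doc_embs, k=3):\n",
"    q = np.asarray(query_emb, dtype=np.float32)\n",
"    d = np.asarray(doc_embs, dtype=np.float32)\n",
"    # L2-normalize so that the dot product equals cosine similarity\n",
"    q = q / max(float(np.linalg.norm(q)), 1e-12)\n",
"    d = d / np.maximum(np.linalg.norm(d, axis=1, keepdims=True), 1e-12)\n",
"    scores = d @ q\n",
"    # Never request more pages than there are\n",
"    k = min(k, scores.shape[0])\n",
"    # Indices of the k highest scores, best first\n",
"    top = np.argsort(-scores)[:k]\n",
"    return top, scores[top]\n",
"```\n",
"\n",
"For instance, `cosine_top_k(query_embeddings[0].cpu().float().numpy(), document_embeddings.cpu().float().numpy(), k=3)` returns page indices and scores in descending order. If the embedder already normalizes its outputs, the ranking should match `retrieve_top_k`."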
], "metadata": { "id": "P852Uf7ehMmF" } }, { "cell_type": "code", "source": [ "def retrieve_top_k(query_embedding, document_embeddings, k=3):\n", "    \"\"\"Retrieve the indices and scores of the top-k most similar documents for a query.\"\"\"\n", "    # Convert torch tensors (possibly bfloat16 on GPU) to float32 NumPy arrays\n", "    if torch.is_tensor(query_embedding):\n", "        query_embedding = query_embedding.cpu().float().numpy()\n", "    if torch.is_tensor(document_embeddings):\n", "        document_embeddings = document_embeddings.cpu().float().numpy()\n", "    # Dot-product score between the query and every document\n", "    similarity_scores = query_embedding @ document_embeddings.T\n", "    # argsort is ascending: take the last k indices and reverse them for descending order\n", "    top_k_indices = np.argsort(similarity_scores)[-k:][::-1]\n", "    top_k_scores = similarity_scores[top_k_indices]\n", "    return top_k_indices, top_k_scores\n", "\n", "\n", "top_k = 3\n", "for query_idx, query in enumerate(queries):\n", "    print(f\"\\nQuery {query_idx + 1}: {query['text']}\")\n", "    print(\"-\" * 60)\n", "\n", "    top_indices, top_scores = retrieve_top_k(\n", "        query_embeddings[query_idx],\n", "        document_embeddings,\n", "        k=top_k\n", "    )\n", "\n", "    print(f\"Top {top_k} pages (by similarity):\")\n", "    for rank, (page_idx, score) in enumerate(zip(top_indices, top_scores), 1):\n", "        print(f\"  {rank}. Page {page_idx + 1} (score: {score:.4f})\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "JgqYF3mWU1G4", "outputId": "a7783806-0044-4120-c471-e6cf32ae5a9e" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Query 1: How much did the world temperature change so far?\n", "------------------------------------------------------------\n", "Top 3 pages (by similarity):\n", " 1. Page 1 (score: 0.4605)\n", " 2. Page 5 (score: 0.4408)\n", " 3. Page 2 (score: 0.4360)\n", "\n", "Query 2: What are the main causes of climate change?\n", "------------------------------------------------------------\n", "Top 3 pages (by similarity):\n", " 1. Page 1 (score: 0.4903)\n", " 2. Page 2 (score: 0.4770)\n", " 3. Page 5 (score: 0.4555)\n" ] } ] }, { "cell_type": "markdown", "source": [ "We no longer need the embedder, so we delete it to free GPU memory." 
], "metadata": { "id": "-ukF0iGUhTuF" } }, { "cell_type": "code", "source": [ "import gc\n", "\n", "# Deleting the name alone does not return memory to the GPU: also run the\n", "# garbage collector and release PyTorch's cached CUDA memory\n", "del embedder\n", "gc.collect()\n", "torch.cuda.empty_cache()" ], "metadata": { "id": "18H6qH5HhewE" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Let's rerank the retrieved pages with the Qwen3-VL reranker, which assigns a relevance score to each (query, page) pair." ], "metadata": { "id": "3PPYyAzshWsg" } }, { "cell_type": "code", "source": [ "from scripts.qwen3_vl_reranker import Qwen3VLReranker\n", "\n", "reranker = Qwen3VLReranker(\"Qwen/Qwen3-VL-Reranker-2B\")\n", "\n", "# To use FlashAttention 2 (if your GPU supports it), load the reranker like this instead:\n", "# reranker = Qwen3VLReranker(\n", "#     model_name_or_path=\"Qwen/Qwen3-VL-Reranker-2B\",\n", "#     torch_dtype=torch.bfloat16,\n", "#     attn_implementation=\"flash_attention_2\"\n", "# )\n", "\n", "# Rerank the top-3 embedding matches for the first query\n", "query_for_reranking = queries[0]\n", "top_indices, _ = retrieve_top_k(query_embeddings[0], document_embeddings, k=3)\n", "\n", "print(f\"\\nReranking results for: {query_for_reranking['text']}\")\n", "\n", "reranker_inputs = {\n", "    \"instruction\": \"Retrieve pages relevant to the user's query about climate change.\",\n", "    \"query\": query_for_reranking,\n", "    \"documents\": [{\"image\": f\"temp_page_{idx}.png\"} for idx in top_indices],\n", "    \"fps\": 1.0\n", "}\n", "\n", "# One relevance score per document, in the same order as \"documents\"\n", "reranker_scores = reranker.process(reranker_inputs)\n", "\n", "# Pages are listed in embedding-retrieval order; the best page is the one with the highest reranker score\n", "for rank, (page_idx, score) in enumerate(zip(top_indices, reranker_scores), 1):\n", "    print(f\"  {rank}. Page {page_idx + 1} (reranker score: {score:.4f})\")\n", "\n", "best_page_idx = top_indices[np.argmax(reranker_scores)]\n", "print(f\"\\nBest page after reranking: Page {best_page_idx + 1}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kcZwFxBXaOpi", "outputId": "2bb9db6b-722a-4fdd-d5c4-2b964f641989" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "You're using a Qwen2TokenizerFast tokenizer. 
Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", "Reranking results for: How much did the world temperature change so far?\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.12/dist-packages/transformers/tokenization_utils_base.py:2914: UserWarning: `max_length` is ignored when `padding`=`True` and there is no truncation strategy. To pad to max length, use `padding='max_length'`.\n", " warnings.warn(\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ " 1. Page 1 (reranker score: 0.5188)\n", " 2. Page 5 (reranker score: 0.4743)\n", " 3. Page 2 (reranker score: 0.6626)\n", "\n", "Best page after reranking: Page 2\n" ] } ] }, { "cell_type": "markdown", "source": [ "We are done with the reranker, so we delete it to free GPU memory." ], "metadata": { "id": "5dNf6C_uhErk" } }, { "cell_type": "code", "source": [ "import gc\n", "\n", "del reranker\n", "gc.collect()\n", "torch.cuda.empty_cache()" ], "metadata": { "id": "o86242ZmhDBV" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We now load the Qwen3-VL instruct model and pass it the best-ranked page together with the question, so that it can generate an answer based on that page." 
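, "\n",
"\n",
"The cell below passes a single page. If you want the model to see more context, a hypothetical helper like `build_rag_messages` can order the candidate pages by reranker score and put the best `n_pages` of them into the same message format. This is a sketch that assumes `reranker_scores` is a sequence of floats, as the printed output suggests, and that the chat template accepts several image entries in one user turn. More pages also mean more image tokens and more GPU memory.\n",
"\n",
"```python\n",
"# Hypothetical helper (not part of the original notebook).\n",
"def build_rag_messages(page_indices, reranker_scores, question, n_pages=1,\n",
"                       path_template='temp_page_{}.png'):\n",
"    # Pair each page with its reranker score and sort, highest score first\n",
"    ranked = sorted(zip(page_indices, reranker_scores),\n",
"                    key=lambda pair: pair[1], reverse=True)\n",
"    # One image entry per selected page, best page first\n",
"    content = [{'type': 'image', 'image': path_template.format(page_idx)}\n",
"               for page_idx, _ in ranked[:n_pages]]\n",
"    subject = 'this page' if len(content) == 1 else 'these pages'\n",
"    content.append({\n",
"        'type': 'text',\n",
"        'text': f'Based on {subject} from a climate change document, '\n",
"                f'please answer the following question: {question}',\n",
"    })\n",
"    return [{'role': 'user', 'content': content}]\n",
"```\n",
"\n",
"With the default `n_pages=1`, `build_rag_messages(top_indices, reranker_scores, query_for_reranking['text'])` should produce the same `messages` list as the cell below; `n_pages=2` would add the runner-up page after the best one."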
], "metadata": { "id": "XKHsnSj1hhaQ" } }, { "cell_type": "code", "source": [ "# device_map=\"auto\" already places the model on the available GPU(s), so no .to(\"cuda\") call is needed\n", "vlm_model = Qwen3VLForConditionalGeneration.from_pretrained(\n", "    \"Qwen/Qwen3-VL-2B-Instruct\",\n", "    dtype=\"auto\",\n", "    device_map=\"auto\"\n", ")\n", "# To use FlashAttention 2 (if your GPU supports it), load the model like this instead:\n", "# vlm_model = Qwen3VLForConditionalGeneration.from_pretrained(\n", "#     \"Qwen/Qwen3-VL-2B-Instruct\",\n", "#     dtype=torch.bfloat16,\n", "#     attn_implementation=\"flash_attention_2\",\n", "#     device_map=\"auto\",\n", "# )\n", "\n", "processor = AutoProcessor.from_pretrained(\"Qwen/Qwen3-VL-2B-Instruct\")\n", "\n", "print(f\"Generating answer for: {query_for_reranking['text']}\")\n", "print(f\"Using retrieved page: {best_page_idx + 1}\")\n", "\n", "# One user turn: the best-ranked page image followed by the question\n", "messages = [\n", "    {\n", "        \"role\": \"user\",\n", "        \"content\": [\n", "            {\n", "                \"type\": \"image\",\n", "                \"image\": f\"temp_page_{best_page_idx}.png\",\n", "            },\n", "            {\n", "                \"type\": \"text\",\n", "                \"text\": f\"Based on this page from a climate change document, please answer the following question: {query_for_reranking['text']}\"\n", "            },\n", "        ],\n", "    }\n", "]\n", "\n", "# Apply the chat template and preprocess the text and image in one call\n", "inputs = processor.apply_chat_template(\n", "    messages,\n", "    tokenize=True,\n", "    add_generation_prompt=True,\n", "    return_dict=True,\n", "    return_tensors=\"pt\"\n", ")\n", "inputs = inputs.to(vlm_model.device)\n", "\n", "generated_ids = vlm_model.generate(**inputs, max_new_tokens=256)\n", "# generate() returns the prompt tokens followed by the new tokens; keep only the new ones\n", "generated_ids_trimmed = [\n", "    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n", "]\n", "output_text = processor.batch_decode(\n", "    generated_ids_trimmed,\n", "    skip_special_tokens=True,\n", "    clean_up_tokenization_spaces=False\n", ")\n", "\n", "print(f\"\\nQuery: {query_for_reranking['text']}\")\n", "print(f\"\\nAnswer:\\n{output_text[0]}\")\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "WYPnnSgRVlya", "outputId": "f715dfb7-be64-4bfc-b80a-ea735dbde855" }, "execution_count": null, "outputs": [ { "output_type": 
"stream", "name": "stdout", "text": [ "Generating answer for: How much did the world temperature change so far?\n", "Using retrieved page: 2\n", "\n", "Query: How much did the world temperature change so far?\n", "\n", "Answer:\n", "Based on the information provided in the document, the world temperature has changed by approximately 1.1°C since the late 19th century.\n" ] } ] }, { "cell_type": "markdown", "source": [ "The answer is correct!" ], "metadata": { "id": "EMk8V_-veJPY" } } ] }