{ "cells": [ { "cell_type": "markdown", "id": "de738fb5", "metadata": {}, "source": [ "# Segmentation Tracking with DINOv3\n", "\n", "This notebook demonstrates video segmentation tracking with DINOv3,\n", "using a non-parametric method similar to\n", "[\"Space-time correspondence as a contrastive random walk\" (Jabri et al. 2020)](https://arxiv.org/abs/2006.14613).\n", "\n", "Given:\n", "- RGB video frames\n", "- Instance segmentation masks for the first frame\n", "\n", "We will extract patch features from each frame and use patch similarity\n", "to propagate the ground-truth labels to all frames." ] }, { "cell_type": "markdown", "id": "6ab28340", "metadata": {}, "source": [ "## Setup\n", "\n", "Let's start by loading some prerequisites, setting up the environment, and checking the DINOv3 repository location:\n", "- `local` if the `DINOV3_LOCATION` environment variable was set to work with a local copy of the DINOv3 repository;\n", "- `github` if the code should be loaded via torch hub."
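, "\n", "As an aside, the core propagation step described in the introduction can be sketched as follows: each patch of a new frame takes a similarity-weighted vote over the labels of its top-k most similar reference patches. The helper below is only an illustrative sketch; its name and the `k` and `temperature` defaults are assumptions, not the settings used later in this notebook.\n", "\n", "```python\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "def propagate_labels(ref_feats, ref_probs, tgt_feats, k=5, temperature=0.1):\n", "    # ref_feats: [N, D] L2-normalized reference patch features\n", "    # ref_probs: [N, C] per-patch label probabilities (one-hot on the first frame)\n", "    # tgt_feats: [M, D] L2-normalized target patch features\n", "    sim = tgt_feats @ ref_feats.T  # [M, N] cosine similarities\n", "    topk = sim.topk(k, dim=1)  # keep only the k nearest reference patches\n", "    weights = F.softmax(topk.values / temperature, dim=1)  # [M, k] soft votes\n", "    return (weights.unsqueeze(-1) * ref_probs[topk.indices]).sum(dim=1)  # [M, C]\n", "```"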
] }, { "cell_type": "code", "execution_count": 1, "id": "1d4d54c3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DINOv3 location set to /private/home/vkhalidov/fairvit_dinov3_release/app/dinov3_release/\n" ] } ], "source": [ "import datetime\n", "import functools\n", "import io\n", "import logging\n", "import math\n", "import os\n", "from pathlib import Path\n", "import tarfile\n", "import time\n", "import urllib.request\n", "\n", "import lovely_tensors\n", "import matplotlib.pyplot as plt\n", "import mediapy as mp\n", "import numpy as np\n", "from PIL import Image\n", "import torch\n", "import torch.nn.functional as F\n", "import torchvision.transforms as TVT\n", "import torchvision.transforms.functional as TVTF\n", "from torch import Tensor, nn\n", "from tqdm import tqdm\n", "\n", "DISPLAY_HEIGHT = 200\n", "lovely_tensors.monkey_patch()\n", "torch.set_grad_enabled(False)\n", "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n", "\n", "DINOV3_GITHUB_LOCATION = \"facebookresearch/dinov3\"\n", "\n", "# fall back to torch hub when DINOV3_LOCATION is not set\n", "DINOV3_LOCATION = os.getenv(\"DINOV3_LOCATION\", DINOV3_GITHUB_LOCATION)\n", "\n", "print(f\"DINOv3 location set to {DINOV3_LOCATION}\")" ] }, { "cell_type": "markdown", "id": "1a28e044", "metadata": {}, "source": [ "## Model\n", "\n", "We load the DINOv3 ViT-L model and get some attributes. Feel free to try other DINOv3 models as well!"
] }, { "cell_type": "code", "execution_count": 2, "id": "75cd313f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-08-13 14:13:40,010 - INFO - using base=100 for rope new\n", "2025-08-13 14:13:40,011 - INFO - using min_period=None for rope new\n", "2025-08-13 14:13:40,011 - INFO - using max_period=None for rope new\n", "2025-08-13 14:13:40,012 - INFO - using normalize_coords=separate for rope new\n", "2025-08-13 14:13:40,012 - INFO - using shift_coords=None for rope new\n", "2025-08-13 14:13:40,012 - INFO - using rescale_coords=2 for rope new\n", "2025-08-13 14:13:40,012 - INFO - using jitter_coords=None for rope new\n", "2025-08-13 14:13:40,013 - INFO - using dtype=fp32 for rope new\n", "2025-08-13 14:13:40,014 - INFO - using mlp layer as FFN\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Patch size: 16\n", "Embedding dimension: 1024\n", "Peak GPU memory: 1.1 GB\n" ] } ], "source": [ "# examples of available DINOv3 models:\n", "MODEL_DINOV3_VITS = \"dinov3_vits16\"\n", "MODEL_DINOV3_VITSP = \"dinov3_vits16plus\"\n", "MODEL_DINOV3_VITB = \"dinov3_vitb16\"\n", "MODEL_DINOV3_VITL = \"dinov3_vitl16\"\n", "MODEL_DINOV3_VITHP = \"dinov3_vith16plus\"\n", "MODEL_DINOV3_VIT7B = \"dinov3_vit7b16\"\n", "\n", "# we take DINOv3 ViT-L\n", "MODEL_NAME = MODEL_DINOV3_VITL\n", "\n", "model = torch.hub.load(\n", " repo_or_dir=DINOV3_LOCATION,\n", " model=MODEL_NAME,\n", " source=\"local\" if DINOV3_LOCATION != DINOV3_GITHUB_LOCATION else \"github\",\n", ")\n", "model.to(\"cuda\")\n", "model.eval()\n", "\n", "patch_size = model.patch_size\n", "embed_dim = model.embed_dim\n", "print(f\"Patch size: {patch_size}\")\n", "print(f\"Embedding dimension: {embed_dim}\")\n", "print(f\"Peak GPU memory: {torch.cuda.max_memory_allocated() / 2**30:.1f} GB\")" ] }, { "cell_type": "markdown", "id": "6e7881a8", "metadata": {}, "source": [ "We want to process one image at a time and get L2-normalized features.\n", "Here is a wrapper to do 
just that." ] }, { "cell_type": "code", "execution_count": 3, "id": "c85e1dc1", "metadata": {}, "outputs": [], "source": [ "@torch.compile(disable=True)\n", "def forward(\n", " model: nn.Module,\n", " img: Tensor, # [3, H, W] already normalized for the model\n", ") -> Tensor:\n", " feats = model.get_intermediate_layers(img.unsqueeze(0), n=1, reshape=True)[0] # [1, D, h, w]\n", " feats = feats.movedim(-3, -1) # [1, h, w, D]\n", " feats = F.normalize(feats, dim=-1, p=2)\n", " return feats.squeeze(0) # [h, w, D]" ] }, { "cell_type": "markdown", "id": "8a66e30f", "metadata": {}, "source": [ "## Data\n", "\n", "Here we load the video frames and the instance segmentation masks for the first frame." ] }, { "cell_type": "markdown", "id": "a27a2bbe", "metadata": {}, "source": [ "\n", "This notebook assumes that the video has already been processed to extract individual frames as `.jpg` images in the current directory.\n", "```txt\n", "000001.jpg\n", "000002.jpg\n", "...\n", "```\n", "\n", "To run this notebook on your own `.mp4` video, use `ffmpeg` to extract frames at 24 FPS:\n", "```bash\n", "INPUT_VIDEO=\"video.mp4\"\n", "OUTPUT_DIR=\".\"\n", "ffmpeg -hide_banner -i \"${INPUT_VIDEO}\" -qscale:v 2 -vf fps=24 -y \"${OUTPUT_DIR}/%06d.jpg\"\n", "```" ] }, { "cell_type": "code", "execution_count": 4, "id": "ac498bc7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of frames: 174\n", "Original size: width=1920, height=1440\n" ] } ], "source": [ "VIDEO_FRAMES_URI = \"https://dl.fbaipublicfiles.com/dinov3/notebooks/segmentation_tracking/video_frames.tar.gz\"\n", "\n", "def load_video_frames_from_remote_tar(tar_uri: str) -> list[Image.Image]:\n", " images = []\n", " indices = []\n", " with urllib.request.urlopen(tar_uri) as f:\n", " tar = tarfile.open(fileobj=io.BytesIO(f.read()))\n", " for member in tar.getmembers():\n", " index_str, _ = os.path.splitext(member.name)\n", " image_data = tar.extractfile(member)\n", " image = 
Image.open(image_data).convert(\"RGB\")\n", " images.append(image)\n", " indices.append(int(index_str))\n", " order = np.argsort(indices)\n", " return [images[i] for i in order]\n", "\n", "frames = load_video_frames_from_remote_tar(VIDEO_FRAMES_URI)\n", "num_frames = len(frames)\n", "print(f\"Number of frames: {num_frames}\")\n", "\n", "original_width, original_height = frames[0].size\n", "print(f\"Original size: width={original_width}, height={original_height}\")" ] }, { "cell_type": "markdown", "id": "06ac1ba9-9737-4e8f-901b-99d8baf50f52", "metadata": {}, "source": [ "Let's show four sample frames from the video:" ] }, { "cell_type": "code", "execution_count": 5, "id": "567b31fb-ec61-40a3-89d3-b4457a09c72d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[HTML table outputs stripped during extraction; surviving cell captions, in order: Frame 0 | Frame 57 | Frame 115 | Frame 173; Frame | Mask; (3, 14) | (20, 25); First frame | Second frame; Mask 0 | Mask 1 | Mask 2 | Mask 3 | Mask 4 | Mask 5; Frame 0 | Frame 57 | Frame 115 | Frame 173; Input | Pred; Prob 0 | Prob 1 | Prob 2 | Prob 3 | Prob 4 | Prob 5]\n",
" Prob 5 |