{ "cells": [ { "cell_type": "markdown", "id": "de738fb5", "metadata": {}, "source": [ "# Segmentation Tracking with DINOv3\n", "\n", "This notebook demonstrates using DINOv3 for video segmentation tracking\n", "using a non-parametric method similar to\n", "[\"Space-time correspondence as a contrastive random walk\" (Jabri et al. 2020)](https://arxiv.org/abs/2006.14613).\n", "\n", "Given:\n", "- RGB video frames\n", "- Instance segmentation masks for the first frame\n", "\n", "We will extract patch features from each frame and use patch similarity\n", "to propagate the ground-truth labels to all frames." ] }, { "cell_type": "markdown", "id": "6ab28340", "metadata": {}, "source": [ "## Setup\n", "\n", "Let's start by loading some pre-requisites, setting up the environment and checking the DINOv3 repository location:\n", "- `local` if `DINOV3_LOCATION` environment variable was set to work with a local version of DINOv3 repository;\n", "- `github` if the code should be loaded via torch hub." 
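,
    "\n",
    "\n",
    "As a hedged illustration (not one of the notebook's own cells), the source selection resolves like this; the local path below is hypothetical:\n",
    "```python\n",
    "import os\n",
    "\n",
    "DINOV3_GITHUB_LOCATION = \"facebookresearch/dinov3\"\n",
    "# Hypothetical local checkout; export DINOV3_LOCATION before launching Jupyter.\n",
    "os.environ.setdefault(\"DINOV3_LOCATION\", \"/path/to/dinov3\")\n",
    "\n",
    "DINOV3_LOCATION = os.getenv(\"DINOV3_LOCATION\", DINOV3_GITHUB_LOCATION)\n",
    "source = \"local\" if DINOV3_LOCATION != DINOV3_GITHUB_LOCATION else \"github\"\n",
    "print(source)\n",
    "```"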
] }, { "cell_type": "code", "execution_count": 1, "id": "1d4d54c3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DINOv3 location set to /private/home/vkhalidov/fairvit_dinov3_release/app/dinov3_release/\n" ] } ], "source": [ "import datetime\n", "import functools\n", "import io\n", "import logging\n", "import math\n", "import os\n", "from pathlib import Path\n", "import tarfile\n", "import time\n", "import urllib\n", "\n", "import lovely_tensors\n", "import matplotlib.pyplot as plt\n", "import mediapy as mp\n", "import numpy as np\n", "from PIL import Image\n", "import torch\n", "import torch.nn.functional as F\n", "import torchvision.transforms as TVT\n", "import torchvision.transforms.functional as TVTF\n", "from torch import Tensor, nn\n", "from tqdm import tqdm\n", "\n", "DISPLAY_HEIGHT = 200\n", "lovely_tensors.monkey_patch()\n", "torch.set_grad_enabled(False)\n", "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n", "\n", "DINOV3_GITHUB_LOCATION = \"facebookresearch/dinov3\"\n", "\n", "if os.getenv(\"DINOV3_LOCATION\") is not None:\n", " DINOV3_LOCATION = os.getenv(\"DINOV3_LOCATION\")\n", "else:\n", " DINOV3_LOCATION = DINOV3_GITHUB_LOCATION\n", "\n", "print(f\"DINOv3 location set to {DINOV3_LOCATION}\")" ] }, { "cell_type": "markdown", "id": "1a28e044", "metadata": {}, "source": [ "## Model\n", "\n", "We load the DINOv3 ViT-L model and get some attributes. Feel free to try other DINOv3 models as well!" 
] }, { "cell_type": "code", "execution_count": 2, "id": "75cd313f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-08-13 14:13:40,010 - INFO - using base=100 for rope new\n", "2025-08-13 14:13:40,011 - INFO - using min_period=None for rope new\n", "2025-08-13 14:13:40,011 - INFO - using max_period=None for rope new\n", "2025-08-13 14:13:40,012 - INFO - using normalize_coords=separate for rope new\n", "2025-08-13 14:13:40,012 - INFO - using shift_coords=None for rope new\n", "2025-08-13 14:13:40,012 - INFO - using rescale_coords=2 for rope new\n", "2025-08-13 14:13:40,012 - INFO - using jitter_coords=None for rope new\n", "2025-08-13 14:13:40,013 - INFO - using dtype=fp32 for rope new\n", "2025-08-13 14:13:40,014 - INFO - using mlp layer as FFN\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Patch size: 16\n", "Embedding dimension: 1024\n", "Peak GPU memory: 1.1 GB\n" ] } ], "source": [ "# examples of available DINOv3 models:\n", "MODEL_DINOV3_VITS = \"dinov3_vits16\"\n", "MODEL_DINOV3_VITSP = \"dinov3_vits16plus\"\n", "MODEL_DINOV3_VITB = \"dinov3_vitb16\"\n", "MODEL_DINOV3_VITL = \"dinov3_vitl16\"\n", "MODEL_DINOV3_VITHP = \"dinov3_vith16plus\"\n", "MODEL_DINOV3_VIT7B = \"dinov3_vit7b16\"\n", "\n", "# we take DINOv3 ViT-L\n", "MODEL_NAME = MODEL_DINOV3_VITL\n", "\n", "model = torch.hub.load(\n", " repo_or_dir=DINOV3_LOCATION,\n", " model=MODEL_NAME,\n", " source=\"local\" if DINOV3_LOCATION != DINOV3_GITHUB_LOCATION else \"github\",\n", ")\n", "model.to(\"cuda\")\n", "model.eval()\n", "\n", "patch_size = model.patch_size\n", "embed_dim = model.embed_dim\n", "print(f\"Patch size: {patch_size}\")\n", "print(f\"Embedding dimension: {embed_dim}\")\n", "print(f\"Peak GPU memory: {torch.cuda.max_memory_allocated() / 2**30:.1f} GB\")" ] }, { "cell_type": "markdown", "id": "6e7881a8", "metadata": {}, "source": [ "We want to process one image at the time and get L2-normalized features.\n", "Here is a wrapper to do 
just that." ] }, { "cell_type": "code", "execution_count": 3, "id": "c85e1dc1", "metadata": {}, "outputs": [], "source": [ "@torch.compile(disable=True)\n", "def forward(\n", " model: nn.Module,\n", " img: Tensor, # [3, H, W] already normalized for the model\n", ") -> Tensor:\n", " feats = model.get_intermediate_layers(img.unsqueeze(0), n=1, reshape=True)[0] # [1, D, h, w]\n", " feats = feats.movedim(-3, -1) # [1, h, w, D]\n", " feats = F.normalize(feats, dim=-1, p=2)\n", " return feats.squeeze(0) # [h, w, D]" ] }, { "cell_type": "markdown", "id": "8a66e30f", "metadata": {}, "source": [ "## Data\n", "\n", "Here we load the video frames and the instance segmentation masks for the first frame." ] }, { "cell_type": "markdown", "id": "a27a2bbe", "metadata": {}, "source": [ "\n", "This notebook assumes that the video has already been processed to extract individual frames as `.jpg` images in the current directory.\n", "```txt\n", "000001.jpg\n", "000002.jpg\n", "...\n", "```\n", "\n", "To run this notebook on your own `.mp4` video, use `ffmpeg` to extract frames at 24 FPS:\n", "```bash\n", "INPUT_VIDEO=\"video.mp4\"\n", "OUTPUT_DIR=\".\"\n", "ffmpeg -hide_banner -i \"${INPUT_VIDEO}\" -qscale:v 2 -vf fps=24 -y \"${OUTPUT_DIR}/%06d.jpg\"\n", "```" ] }, { "cell_type": "code", "execution_count": 4, "id": "ac498bc7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of frames: 174\n", "Original size: width=1920, height=1440\n" ] } ], "source": [ "VIDEO_FRAMES_URI = \"https://dl.fbaipublicfiles.com/dinov3/notebooks/segmentation_tracking/video_frames.tar.gz\"\n", "\n", "def load_video_frames_from_remote_tar(tar_uri: str) -> list[Image.Image]:\n", " images = []\n", " indices = []\n", " with urllib.request.urlopen(tar_uri) as f:\n", " tar = tarfile.open(fileobj=io.BytesIO(f.read()))\n", " for member in tar.getmembers():\n", " index_str, _ = os.path.splitext(member.name)\n", " image_data = tar.extractfile(member)\n", " image = 
Image.open(image_data).convert(\"RGB\")\n", " images.append(image)\n", " indices.append(int(index_str))\n", " order = np.argsort(indices)\n", " return [images[i] for i in order]\n", "\n", "frames = load_video_frames_from_remote_tar(VIDEO_FRAMES_URI)\n", "num_frames = len(frames)\n", "print(f\"Number of frames: {num_frames}\")\n", "\n", "original_width, original_height = frames[0].size\n", "print(f\"Original size: width={original_width}, height={original_height}\")" ] }, { "cell_type": "markdown", "id": "06ac1ba9-9737-4e8f-901b-99d8baf50f52", "metadata": {}, "source": [ "Let's show four sample frames from the video:" ] }, { "cell_type": "code", "execution_count": 5, "id": "567b31fb-ec61-40a3-89d3-b4457a09c72d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Frames 0, 57, 115, 173
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "num_selected_frames = 4\n", "selected_frames = np.linspace(0, num_frames - 1, num_selected_frames, dtype=int)\n", "\n", "mp.show_images(\n", " [frames[int(i)] for i in selected_frames],\n", " titles=[f\"Frame {i}\" for i in selected_frames],\n", " height=DISPLAY_HEIGHT,\n", ")" ] }, { "cell_type": "markdown", "id": "a8de8294", "metadata": {}, "source": [ "This notebook assumes that instance segmentation masks for the first frame are stored in a `.png` file:\n", "- A value of `0` indicates the background.\n", "- Object instances are represented by progressive contiguous `uint8` indices starting from `1`.\n", "\n", "We created the example mask using [SAM 2](), but Paint would work just as well.\n", "\n", "A function to visualize the masks as RGB is also provided." ] }, { "cell_type": "code", "execution_count": 6, "id": "740a8fc6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mask size: [1440, 1920]\n", "Number of masks: 6\n" ] }, { "data": { "text/html": [ "
Frame, Mask
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def mask_to_rgb(mask: np.ndarray | Tensor, num_masks: int) -> np.ndarray:\n", " if isinstance(mask, Tensor):\n", " mask = mask.cpu().numpy()\n", "\n", " # Exclude background\n", " background = mask == 0\n", " mask = mask - 1\n", " num_masks = num_masks - 1\n", "\n", " # Choose palette\n", " if num_masks <= 10:\n", " mask_rgb = plt.get_cmap(\"tab10\")(mask)[..., :3]\n", " elif num_masks <= 20:\n", " mask_rgb = plt.get_cmap(\"tab20\")(mask)[..., :3]\n", " else:\n", " mask_rgb = plt.get_cmap(\"gist_rainbow\")(mask / (num_masks - 1))[..., :3]\n", "\n", " mask_rgb = (mask_rgb * 255).astype(np.uint8)\n", " mask_rgb[background, :] = 0\n", " return mask_rgb\n", "\n", "\n", "def load_image_from_url(url: str) -> Image:\n", " with urllib.request.urlopen(url) as f:\n", " return Image.open(f)\n", "\n", "\n", "first_mask_np = np.array(\n", " load_image_from_url(\n", " \"https://dl.fbaipublicfiles.com/dinov3/notebooks/segmentation_tracking/first_video_frame_mask.png\"\n", " )\n", ")\n", "\n", "mask_height, mask_width = first_mask_np.shape # Abbreviated at [H', W']\n", "print(f\"Mask size: {[mask_height, mask_width]}\")\n", "\n", "num_masks = int(first_mask_np.max() + 1) # Abbreviated as M\n", "print(f\"Number of masks: {num_masks}\")\n", "\n", "mp.show_images(\n", " [frames[0], mask_to_rgb(first_mask_np, num_masks)],\n", " titles=[\"Frame\", \"Mask\"],\n", " height=DISPLAY_HEIGHT,\n", ")" ] }, { "cell_type": "markdown", "id": "8b8cc9d6", "metadata": {}, "source": [ "Time for some math! 
Input frames need to be resized to match the desired forward resolution and the model patch size.\n", "\n", "The desired forward resolution refers to the _short side_ of the input.\n", "If the desired resolution is not a multiple of the patch size, we simply round it up.\n", "Then, we determine the _long side_ by maintaining the original aspect ratio as much as possible, but rounding up to the patch size as well.\n", "\n", "With the occasion, we'll also setup the torchvision transforms and test them out on the first frame." ] }, { "cell_type": "code", "execution_count": 7, "id": "3b2c3fe2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "First frame: tensor[3, 960, 1280] n=3686400 (14Mb) x∈[-2.118, 2.640] μ=0.186 σ=0.915 cuda:0\n" ] } ], "source": [ "class ResizeToMultiple(nn.Module):\n", " def __init__(self, short_side: int, multiple: int):\n", " super().__init__()\n", " self.short_side = short_side\n", " self.multiple = multiple\n", "\n", " def _round_up(self, side: float) -> int:\n", " return math.ceil(side / self.multiple) * self.multiple\n", "\n", " def forward(self, img):\n", " old_width, old_height = TVTF.get_image_size(img)\n", " if old_width > old_height:\n", " new_height = self._round_up(self.short_side)\n", " new_width = self._round_up(old_width * new_height / old_height)\n", " else:\n", " new_width = self._round_up(self.short_side)\n", " new_height = self._round_up(old_height * new_width / old_width)\n", " return TVTF.resize(img, [new_height, new_width], interpolation=TVT.InterpolationMode.BICUBIC)\n", "\n", "\n", "SHORT_SIDE = 960\n", "\n", "transform = TVT.Compose(\n", " [\n", " ResizeToMultiple(short_side=SHORT_SIDE, multiple=patch_size),\n", " TVT.ToTensor(),\n", " TVT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", " ]\n", ")\n", "first_frame = transform(frames[0]).to(\"cuda\")\n", "print(f\"First frame: {first_frame}\")\n", "\n", "_, frame_height, frame_width = first_frame.shape # 
Abbreviated as [H, W]\n", "feats_height, feats_width = frame_height // patch_size, frame_width // patch_size # Abbreviated as [h, w]" ] }, { "cell_type": "markdown", "id": "b123acb4", "metadata": {}, "source": [ "Label propagation happens at the output resolution of the model,\n", "so we downsample the ground-truth masks of the first frame and turn them into a one-hot probability map." ] }, { "cell_type": "code", "execution_count": 8, "id": "783a97c7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "First mask: tensor[60, 80] i64 n=4800 (38Kb) x∈[0, 5] μ=0.229 σ=0.855 cuda:0\n", "First probs: tensor[60, 80, 6] n=28800 (0.1Mb) x∈[0., 1.000] μ=0.167 σ=0.373 cuda:0\n" ] } ], "source": [ "first_mask = torch.from_numpy(first_mask_np).to(\"cuda\", dtype=torch.long) # [H', W']\n", "first_mask = F.interpolate(\n", " first_mask[None, None, :, :].float(), # [1, 1, H', W']\n", " (feats_height, feats_width),\n", " mode=\"nearest-exact\",\n", ")[0, 0].long() # [h, w]\n", "print(f\"First mask: {first_mask}\")\n", "\n", "first_probs = F.one_hot(first_mask, num_masks).float() # [h, w, M]\n", "print(f\"First probs: {first_probs}\")" ] }, { "cell_type": "markdown", "id": "29a8f33a", "metadata": {}, "source": [ "## How it works\n", "\n", "And now the fun part!\n", "\n", "Label propagation takes as input:\n", "- The features of the current frame, with shape `[h, w, D]`\n", "- The features of the `t` context frames, with shape `[t, h, w, D]`\n", "- The mask probabilities of `t` context frames, with shape `[t, h, w, M]`\n", "\n", "For each patch of the current frame:\n", "- We compute the cosine similarity with all context patches.\n", "- We restrict the focus to a local neighborhood and select the top-k most similar context patches.\n", "- We compute a weighted average of the mask probabilities of the selected patches to obtain a prediction for the mask probabilies of the current patch." 
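,
    "\n",
    "\n",
    "A minimal NumPy sketch of this per-patch update, on toy numbers (an illustration only; the shapes and values are made up, and the full batched implementation is in the next cell):\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Toy setup: 1 query patch, 4 context patches, D=2 features, M=2 masks.\n",
    "query = np.array([1.0, 0.0])                        # L2-normalized current-patch feature\n",
    "context = np.array([[1.0, 0.0], [0.8, 0.6],\n",
    "                    [0.0, 1.0], [0.6, 0.8]])        # L2-normalized context features\n",
    "context_probs = np.array([[1.0, 0.0], [1.0, 0.0],\n",
    "                          [0.0, 1.0], [0.0, 1.0]])  # one-hot mask probabilities\n",
    "\n",
    "sim = context @ query               # cosine similarity with every context patch\n",
    "keep = np.argsort(sim)[-2:]         # top-k (k=2) most similar patches\n",
    "w = np.exp(sim[keep] / 0.2)         # temperature-scaled softmax weights\n",
    "w /= w.sum()\n",
    "probs = w @ context_probs[keep]     # weighted average of their mask probs\n",
    "print(probs)                        # a valid distribution over the M masks\n",
    "```"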
] }, { "cell_type": "code", "execution_count": 9, "id": "4f78d2e7", "metadata": {}, "outputs": [], "source": [ "@torch.compile(disable=True)\n", "def propagate(\n", " current_features: Tensor, # [h\", w\", D], where h=h\", w=w\", and \" stands for current\n", " context_features: Tensor, # [t, h, w, D]\n", " context_probs: Tensor, # [t, h, w, M]\n", " neighborhood_mask: Tensor, # [h\", w\", h, w]\n", " topk: int,\n", " temperature: float,\n", ") -> Tensor:\n", " t, h, w, M = context_probs.shape\n", "\n", " # Compute similarity current -> context\n", " dot = torch.einsum(\n", " \"ijd, tuvd -> ijtuv\",\n", " current_features, # [h\", w\", D]\n", " context_features, # [t, h, w, D]\n", " ) # [h\", w\", t, h, w]\n", "\n", " # Restrict focus to local neighborhood\n", " dot = torch.where(\n", " neighborhood_mask[:, :, None, :, :], # [h\", w\", 1, h, w]\n", " dot, # [h\", w\", t, h, w]\n", " -torch.inf,\n", " )\n", "\n", " # Select top-k patches inside the neighborhood\n", " dot = dot.flatten(2, -1).flatten(0, 1) # [h\"w\", thw]\n", " k_th_largest = torch.topk(dot, dim=1, k=topk).values # [h\"w\", k]\n", " dot = torch.where(\n", " dot >= k_th_largest[:, -1:], # [h\"w\", thw]\n", " dot, # [h\"w\", thw]\n", " -torch.inf,\n", " )\n", "\n", " # Propagate probabilities from context to current frame\n", " weights = F.softmax(dot / temperature, dim=1) # [h\"w\", thw]\n", " current_probs = torch.mm(\n", " weights, # [h\"w\", thw]\n", " context_probs.flatten(0, 2), # [thw, M]\n", " ) # [h\"w\", M]\n", "\n", " # Propagated probs should already sum to 1, but just in case\n", " current_probs = current_probs / current_probs.sum(dim=1, keepdim=True) # [h\"w\", M]\n", "\n", " return current_probs.unflatten(0, (h, w)) # [h\", w\", M]\n", "\n", "\n", "@functools.lru_cache()\n", "def make_neighborhood_mask(h: int, w: int, size: float, shape: str) -> Tensor:\n", " ij = torch.stack(\n", " torch.meshgrid(\n", " torch.arange(h, dtype=torch.float32, device=\"cuda\"),\n", " torch.arange(w, 
dtype=torch.float32, device=\"cuda\"),\n", " indexing=\"ij\",\n", " ),\n", " dim=-1,\n", " ) # [h, w, 2]\n", " if shape == \"circle\":\n", " ord = 2\n", " elif shape == \"square\":\n", " ord = torch.inf\n", " else:\n", " raise ValueError(f\"Invalid {shape=}\")\n", " norm = torch.linalg.vector_norm(\n", " ij[:, :, None, None, :] - ij[None, None, :, :, :], # [h\", w\", h, w, 2]\n", " ord=ord,\n", " dim=-1,\n", " ) # [h\", w\", h, w]\n", " mask = norm <= size # [h\", w\", h, w] bool, True inside, False outside\n", " return mask" ] }, { "cell_type": "markdown", "id": "d5e119ef", "metadata": {}, "source": [ "How does the neighborhood mask look like?" ] }, { "cell_type": "code", "execution_count": 10, "id": "ca3500c9", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
(3, 14), (20, 25)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "neighborhood_mask = make_neighborhood_mask(feats_height, feats_width, size=12, shape=\"circle\")\n", "\n", "mp.show_images(\n", " {f\"{(i, j)}\": neighborhood_mask[i, j].cpu().numpy() for i, j in [[3, 14], [20, 25]]},\n", " height=DISPLAY_HEIGHT,\n", ")" ] }, { "cell_type": "markdown", "id": "63987fe5", "metadata": {}, "source": [ "To understand how it works, let's do it for one frame only.\n", "The \"context\" contains only the first frame and the \"current frame\" is the second one." ] }, { "cell_type": "code", "execution_count": 11, "id": "d1301ba3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "First feats: torch.Size([60, 80, 1024])\n", "Current feats: torch.Size([60, 80, 1024])\n", "Current probs: tensor[60, 80, 6] n=28800 (0.1Mb) x∈[0., 1.000] μ=0.167 σ=0.371 cuda:0\n" ] } ], "source": [ "torch._dynamo.maybe_mark_dynamic(first_frame, (1, 2))\n", "first_feats = forward(model, first_frame) # [h, w, D]\n", "print(f\"First feats: {first_feats.shape}\")\n", "\n", "frame_idx = 1\n", "current_frame_pil = frames[frame_idx]\n", "current_frame = transform(current_frame_pil).to(\"cuda\") # [3, H, W]\n", "torch._dynamo.maybe_mark_dynamic(current_frame, (1, 2))\n", "current_feats = forward(model, current_frame) # [h\", w\", D]\n", "print(f\"Current feats: {current_feats.shape}\")\n", "\n", "current_probs = propagate(\n", " current_feats, # [h\", w\", D]\n", " context_features=first_feats.unsqueeze(0), # [1, h, w, D]\n", " context_probs=first_probs.unsqueeze(0), # [1, h, w, M]\n", " neighborhood_mask=neighborhood_mask, # [h\", w\", h, w]\n", " topk=5,\n", " temperature=0.2,\n", ") # [h\", w\", M]\n", "print(f\"Current probs: {current_probs}\")" ] }, { "cell_type": "markdown", "id": "25c47c78", "metadata": {}, "source": [ "Then, we upsample the predicted probabilities and postprocess them.\n", "\n", "Finally, we visualize:\n", "- The first 
frame with its ground truth next to the second frame with the predicted masks.\n", "- The per-mask probabilties predicted for the second frame." ] }, { "cell_type": "code", "execution_count": 12, "id": "4818851d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
First frame, Second frame, and predicted masks
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Mask 0, Mask 1, Mask 2, Mask 3, Mask 4, Mask 5
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def postprocess_probs(\n", " probs: Tensor, # [B, M, H', W']\n", ") -> Tensor:\n", " vmin = probs.flatten(2, 3).min(dim=2).values # [B, M]\n", " vmax = probs.flatten(2, 3).max(dim=2).values # [B, M]\n", " probs = (probs - vmin[:, :, None, None]) / (vmax[:, :, None, None] - vmin[:, :, None, None])\n", " probs = torch.nan_to_num(probs, nan=0)\n", " return probs # [B, M, H', W']\n", "\n", "\n", "p = current_probs.movedim(-1, -3).unsqueeze(0) # [1, M, h\", w\"]\n", "p = F.interpolate(p, size=(mask_height, mask_width), mode=\"nearest\") # [1, M, H', W']\n", "p = postprocess_probs(p).squeeze(0) # [M, H', W']\n", "current_pred_np = p.argmax(0).cpu().numpy() # [H', W']\n", "current_probs_np = p.cpu().numpy() # [M, H', W']\n", "del p\n", "\n", "mp.show_images(\n", " [\n", " frames[0],\n", " current_frame_pil,\n", " mask_to_rgb(first_mask_np, num_masks),\n", " mask_to_rgb(current_pred_np, num_masks),\n", " ],\n", " titles=[\"First frame\", \"Second frame\", \"\", \"\"],\n", " columns=2,\n", " height=DISPLAY_HEIGHT,\n", ")\n", "\n", "mp.show_images(current_probs_np, titles=[f\"Mask {i}\" for i in range(num_masks)], height=DISPLAY_HEIGHT)" ] }, { "cell_type": "markdown", "id": "90364134", "metadata": {}, "source": [ "## Process video" ] }, { "cell_type": "markdown", "id": "af808223", "metadata": {}, "source": [ "All clear? Time to do it for real!\n", "\n", "This time we will process all frames, using a queue of context frames and a queue context mask probabilities.\n", "The queues will contain a limited number of the most recent frames, determined by `max_context_length`.\n", "The first frame is always included in the context and doesn't need to go in the queue." ] }, { "cell_type": "markdown", "id": "6efef2c9", "metadata": {}, "source": [ "Let's define all hyperparameters in one place." 
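,
    "\n",
    "\n",
    "The bounded context queue described above can be sketched with `collections.deque` (a hedged illustration; the notebook itself uses plain lists and `pop(0)`):\n",
    "```python\n",
    "from collections import deque\n",
    "\n",
    "MAX_CONTEXT_LENGTH = 7\n",
    "\n",
    "# The first frame is kept separately and always prepended to the context.\n",
    "features_queue = deque(maxlen=MAX_CONTEXT_LENGTH)  # drops the oldest entry when full\n",
    "for frame_idx in range(1, 20):        # stand-in for the real frame loop\n",
    "    features_queue.append(frame_idx)  # stand-in for that frame's features\n",
    "\n",
    "context = [\"first_frame\", *features_queue]\n",
    "print(len(context))                   # 1 + at most MAX_CONTEXT_LENGTH entries\n",
    "```"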
] }, { "cell_type": "code", "execution_count": 13, "id": "132ee2c0", "metadata": {}, "outputs": [], "source": [ "MAX_CONTEXT_LENGTH = 7\n", "NEIGHBORHOOD_SIZE = 12\n", "NEIGHBORHOOD_SHAPE = \"circle\"\n", "TOPK = 5\n", "TEMPERATURE = 0.2" ] }, { "cell_type": "markdown", "id": "5485cd7e", "metadata": {}, "source": [ "Let's go!\n", "\n", "The predicted mask probabilities and the masks, at the original mask resolution, will be stored in `mask_predictions` and `mask_probabilities`." ] }, { "cell_type": "code", "execution_count": 14, "id": "a587723e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 173/173 [03:09<00:00, 1.09s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Processing time: 0:03:09\n", "Mask probabilities: tensor[174, 6, 1440, 1920] n=2886451200 (11Gb) x∈[0., 1.000] μ=0.167 σ=0.368\n", "Mask predictions: tensor[174, 1440, 1920] u8 n=481075200 (0.4Gb) x∈[0, 5] μ=0.268 σ=0.903\n" ] } ], "source": [ "mask_predictions = torch.zeros([num_frames, mask_height, mask_width], dtype=torch.uint8) # [T, H', W']\n", "mask_predictions[0, :, :] = torch.from_numpy(first_mask_np)\n", "\n", "mask_probabilities = torch.zeros([num_frames, num_masks, mask_height, mask_width]) # [T, M, H', W']\n", "mask_probabilities[0, :, :, :] = F.one_hot(torch.from_numpy(first_mask_np).long(), num_masks).movedim(-1, -3)\n", "\n", "features_queue: list[Tensor] = []\n", "probs_queue: list[Tensor] = []\n", "\n", "neighborhood_mask = make_neighborhood_mask(\n", " feats_height,\n", " feats_width,\n", " size=NEIGHBORHOOD_SIZE,\n", " shape=NEIGHBORHOOD_SHAPE,\n", ") # [h\", w\", h, w]\n", "\n", "start = time.perf_counter()\n", "for frame_idx in tqdm(range(1, num_frames), desc=\"Processing\"):\n", " # Extract features for 
the current frame\n", " current_frame_pil = frames[frame_idx]\n", " current_frame = transform(current_frame_pil).to(\"cuda\") # [3, H, W]\n", " torch._dynamo.maybe_mark_dynamic(current_frame, (1, 2))\n", " current_feats = forward(model, current_frame) # [h\", w\", D]\n", "\n", " # Prepare the context, marking the time and mask dimensions as dynamic for torch compile\n", " context_feats = torch.stack([first_feats, *features_queue], dim=0) # [1+len(queue), h, w, D]\n", " context_probs = torch.stack([first_probs, *probs_queue], dim=0) # [1+len(queue), h, w, M]\n", " torch._dynamo.maybe_mark_dynamic(context_feats, 0)\n", " torch._dynamo.maybe_mark_dynamic(context_probs, (0, 3))\n", "\n", " # Propagate segmentation probs from context frames\n", " current_probs = propagate(\n", " current_feats,\n", " context_feats,\n", " context_probs,\n", " neighborhood_mask,\n", " TOPK,\n", " TEMPERATURE,\n", " ) # [h\", w\", M]\n", "\n", " # Update queues with current features and probs\n", " features_queue.append(current_feats)\n", " probs_queue.append(current_probs)\n", " if len(features_queue) > MAX_CONTEXT_LENGTH:\n", " features_queue.pop(0)\n", " if len(probs_queue) > MAX_CONTEXT_LENGTH:\n", " probs_queue.pop(0)\n", "\n", " # Upsample and postprocess segmentation probs, argmax to obtain a prediction\n", " current_probs = F.interpolate(\n", " current_probs.movedim(-1, -3)[None, :, :, :],\n", " size=(mask_height, mask_width),\n", " mode=\"nearest\",\n", " ) # [1, M, H', W']\n", " current_probs = postprocess_probs(current_probs) # [1, M, H', W']\n", " current_probs = current_probs.squeeze(0)\n", " mask_probabilities[frame_idx, :, :, :] = current_probs\n", " pred = torch.argmax(current_probs, dim=0).to(dtype=torch.uint8) # [H', W']\n", " mask_predictions[frame_idx, :, :] = pred # [H', W']\n", "\n", "torch.cuda.synchronize()\n", "end = time.perf_counter()\n", "print(f\"Processing time: {datetime.timedelta(seconds=round(end - start))}\")\n", "print(f\"Mask probabilities: 
{mask_probabilities}\")\n", "print(f\"Mask predictions: {mask_predictions}\")" ] }, { "cell_type": "markdown", "id": "020a82cc", "metadata": {}, "source": [ "Let's visualize a few frames and a video of the result." ] }, { "cell_type": "code", "execution_count": 15, "id": "3c4432ff", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Frames 0, 57, 115, 173 and predicted masks
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Input, Pred
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Prob 0, Prob 1, Prob 2, Prob 3, Prob 4, Prob 5
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import mediapy as mp\n", "\n", "mp.show_images(\n", " [frames[i].convert(\"RGB\") for i in selected_frames]\n", " + [mask_to_rgb(mask_predictions[i], num_masks) for i in selected_frames],\n", " titles=[f\"Frame {i}\" for i in selected_frames] + [\"\"] * len(selected_frames),\n", " columns=len(selected_frames),\n", " height=DISPLAY_HEIGHT,\n", ")\n", "\n", "mp.show_videos(\n", " {\n", " \"Input\": [np.array(frame) for frame in frames],\n", " \"Pred\": mask_to_rgb(mask_predictions, num_masks),\n", " },\n", " height=DISPLAY_HEIGHT,\n", " fps=24,\n", ")\n", "mp.show_videos(\n", " {f\"Prob {i}\": mask_probabilities[:, i].numpy() for i in range(num_masks)},\n", " height=DISPLAY_HEIGHT,\n", " fps=24,\n", ")" ] }, { "cell_type": "markdown", "id": "4a5bc771", "metadata": {}, "source": [ "# Conclusion\n", "\n", "This notebook showed how to use DINOv3 for video segmentation tracking.\n", "It should be fairly straightforward to run it to your own video and masks.\n", "The notebook hyperparameters can also be adjusted to see the effect on the results." ] }, { "cell_type": "markdown", "id": "5f3adf2c", "metadata": {}, "source": [ "Let's discuss GPU memory usage:\n", "- The ViT-L model takes approximately 1.1 GB to load.\n", "- The similarity matrix `dot` inside `propagate()` can get pretty big,\n", " especially at higher resolution and longer context length.\n", "\n", "To reduce memory usage and increase speed, the notebook is already set up to work with `torch.compile()`.\n", "In particular, note the use of `torch._dynamo.maybe_mark_dynamic()` to mark a few dimensions as dynamic\n", "to avoid too many recompilations. 
To enable compilation, just set `disable` to `False` in the\n", "`@torch.compile()` decorators for the `propagate()` and `forward()` functions.\n", "\n", "If you are low on memory, it also possible to use a smaller model and\n", "reduce the forward resolution, but tracking results will look worse." ] }, { "cell_type": "code", "execution_count": 16, "id": "94d339b0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Peak GPU memory: 3.6 GB\n" ] } ], "source": [ "print(f\"Peak GPU memory: {torch.cuda.max_memory_allocated() / 2**30:.1f} GB\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }