# Please change the dataset `self.ds` to yours in the __init__ functions of both datasets.

class ZeroShotSegmentationDataset(torch.utils.data.Dataset):
    """Base dataset yielding ``(image_tensor, label_mask)`` pairs for zero-shot segmentation.

    Subclasses define ``CLASS_NAMES`` and ``IGNORE_ZERO_LABEL`` and must assign
    ``self.ds`` in their ``__init__``. ``self.ds`` has to support ``len()`` and
    integer indexing, with each item unpacking into ``(image, mask)``.
    """

    CLASS_NAMES: tuple[str, ...]
    IGNORE_ZERO_LABEL: bool  # If True, map label 0 to 255 so it's ignored, and shift all other labels by -1
    transform: Callable[[PIL.Image.Image], Tensor]

    def __init__(self, transform: Callable[[PIL.Image.Image], Tensor]) -> None:
        """Store the transform that is applied to every image in ``__getitem__``."""
        self.transform = transform

    def _mask_to_tensor(self, mask_pil: PIL.Image.Image) -> Tensor:
        """Convert a label mask to an int64 tensor.

        When ``IGNORE_ZERO_LABEL`` is set, labels 0 and 255 both become 255 and
        every other label is shifted down by one.
        """
        mask = torch.from_numpy(np.array(mask_pil)).long()
        if self.IGNORE_ZERO_LABEL:
            mask = torch.where((mask == 0) | (mask == 255), 255, mask - 1)
        return mask

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """Return the transformed image and its label mask for item ``idx`` of ``self.ds``."""
        img, target = self.ds[idx]
        img = self.transform(img)
        target = self._mask_to_tensor(target)
        return img, target

    def __len__(self) -> int:
        """Number of items in the wrapped ``self.ds``."""
        return len(self.ds)


class Cityscapes(ZeroShotSegmentationDataset):
    """Cityscapes class names; labels are used as-is (``IGNORE_ZERO_LABEL = False``)."""

    CLASS_NAMES = (
        "road",
        "sidewalk",
        "building",
        "wall",
        "fence",
        "pole",
        "traffic light",
        "traffic sign",
        "vegetation",
        "terrain",
        "sky",
        "person",
        "rider",
        "car",
        "truck",
        "bus",
        "train",
        "motorcycle",
        "bicycle",
    )
    IGNORE_ZERO_LABEL = False

    def __init__(self, transform: Callable[[PIL.Image.Image], Tensor]) -> None:
        super().__init__(transform)
        self.ds = None  # Put here "Cityscapes:split=VAL" dataset


class Ade20k(ZeroShotSegmentationDataset):
    """ADE20K class names; label 0 is ignored and the rest are shifted by -1 (``IGNORE_ZERO_LABEL = True``)."""

    CLASS_NAMES = (
        "wall",
        "building",
        "sky",
        "floor",
        "tree",
        "ceiling",
        "road",
        "bed",  # Was "bed " (trailing space), which leaked into every text prompt
        "windowpane",
        "grass",
        "cabinet",
        "sidewalk",
        "person",
        "earth",
        "door",
        "table",
        "mountain",
        "plant",
        "curtain",
        "chair",
        "car",
        "water",
        "painting",
        "sofa",
        "shelf",
        "house",
        "sea",
        "mirror",
        "rug",
        "field",
        "armchair",
        "seat",
        "fence",
        "desk",
        "rock",
        "wardrobe",
        "lamp",
        "bathtub",
        "railing",
        "cushion",
        "base",
        "box",
        "column",
        "signboard",
        "chest of drawers",
        "counter",
        "sand",
        "sink",
        "skyscraper",
        "fireplace",
        "refrigerator",
        "grandstand",
        "path",
        "stairs",
        "runway",
        "case",
        "pool table",
        "pillow",
        "screen door",
        "stairway",
        "river",
        "bridge",
        "bookcase",
        "blind",
        "coffee table",
        "toilet",
        "flower",
        "book",
        "hill",
        "bench",
        "countertop",
        "stove",
        "palm",
        "kitchen island",
        "computer",
        "swivel chair",
        "boat",
        "bar",
        "arcade machine",
        "hovel",
        "bus",
        "towel",
        "light",
        "truck",
        "tower",
        "chandelier",
        "awning",
        "streetlight",
        "booth",
        "television receiver",
        "airplane",
        "dirt track",
        "apparel",
        "pole",
        "land",
        "bannister",
        "escalator",
        "ottoman",
        "bottle",
        "buffet",
        "poster",
        "stage",
        "van",
        "ship",
        "fountain",
        "conveyer belt",
        "canopy",
        "washer",
        "plaything",
        "swimming pool",
        "stool",
        "barrel",
        "basket",
        "waterfall",
        "tent",
        "bag",
        "minibike",
        "cradle",
        "oven",
        "ball",
        "food",
        "step",
        "tank",
        "trade name",
        "microwave",
        "pot",
        "animal",
        "bicycle",
        "lake",
        "dishwasher",
        "screen",
        "blanket",
        "sculpture",
        "hood",
        "sconce",
        "vase",
        "traffic light",
        "tray",
        "ashcan",
        "fan",
        "pier",
        "crt screen",
        "plate",
        "monitor",
        "bulletin board",
        "shower",
        "radiator",
        "glass",
        "clock",
        "flag",
    )
    IGNORE_ZERO_LABEL = True

    def __init__(self, transform: Callable[[PIL.Image.Image], Tensor]) -> None:
        super().__init__(transform)
        self.ds = None  # Put here "ADE20KChallengeData2016:split=VAL" dataset


# Maps the `Configuration.dataset` name to the dataset class to instantiate.
DATASETS: dict[str, type[ZeroShotSegmentationDataset]] = {
    "cityscapes": Cityscapes,
    "ade20k": Ade20k,
}
NORMALIZE_IMAGENET = TVT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def encode_image(model, img: Tensor) -> "tuple[Tensor | None, Tensor]":
    """Extract dense patch features for a batch of images.

    The batch is stretched (bicubic) so that both spatial sides are multiples
    of the backbone patch size, then passed through
    ``model.visual_model.get_class_and_patch_tokens``.

    Args:
        model: Object exposing ``visual_model.backbone.patch_size`` and
            ``visual_model.get_class_and_patch_tokens(img)``.
        img: Image batch of shape [B, C, H, W].

    Returns:
        ``(backbone_patches, blocks_patches)``. ``backbone_patches`` is always
        ``None`` in this implementation. ``blocks_patches`` has shape
        [B, H' // P, W' // P, D], where (H', W') is the padded-up resolution
        and P the patch size.
    """
    B, _, H, W = img.shape
    P = model.visual_model.backbone.patch_size  # In the case of our DINOv3
    new_H = math.ceil(H / P) * P
    new_W = math.ceil(W / P) * P

    # Stretch image to a multiple of patch size
    if (H, W) != (new_H, new_W):
        img = F.interpolate(img, size=(new_H, new_W), mode="bicubic", align_corners=False)  # [B, 3, H', W']

    # Only the patch tokens are needed; the other two returned values are ignored.
    _, _, patch_tokens = model.visual_model.get_class_and_patch_tokens(img)
    # Assumes the patch tokens are a row-major flattening of the patch grid -- TODO confirm
    blocks_patches = patch_tokens.reshape(B, new_H // P, new_W // P, -1).contiguous()  # [B, h, w, D]

    backbone_patches = None
    return backbone_patches, blocks_patches


class ShortSideResize(nn.Module):
    """Resize an image so that its shorter side equals ``size``, keeping the aspect ratio."""

    def __init__(self, size: int, interpolation: "TVT.InterpolationMode") -> None:
        super().__init__()
        self.size = size
        self.interpolation = interpolation

    def forward(self, img: "PIL.Image.Image | Tensor") -> "PIL.Image.Image | Tensor":
        """Return ``img`` resized so ``min(h, w) == self.size`` (returned untouched if it already is)."""
        _, h, w = TVTF.get_dimensions(img)
        if min(h, w) == self.size:
            return img
        # The long side is scaled by the same factor and truncated to an integer.
        if w < h:
            new_h, new_w = int(self.size * h / w), self.size
        else:
            new_h, new_w = self.size, int(self.size * w / h)
        return TVTF.resize(img, [new_h, new_w], self.interpolation)


def predict_whole(model, img: Tensor, text_features: Tensor) -> Tensor:
    """Cosine similarity between every image patch and every class text embedding.

    Args:
        model: Model accepted by ``encode_image``.
        img: Single image of shape [3, H, W].
        text_features: Class embeddings of shape [num_classes, D], expected to
            be L2-normalized already.

    Returns:
        Low-resolution similarity map of shape [num_classes, h, w]; the caller
        is responsible for upsampling it.
    """
    # Extract image features from the additional blocks, ignore the backbone features
    _, blocks_feats = encode_image(model, img.unsqueeze(0))  # [1, h, w, D]

    # Cosine similarity between patch features and text features (already normalized)
    blocks_feats = F.normalize(blocks_feats.squeeze(0), p=2, dim=-1)  # [h, w, D]
    cos = torch.einsum("cd,hwd->chw", text_features, blocks_feats)  # [num_classes, h, w]

    # Return low-res cosine similarities, they will be upsampled to the target resolution later
    return cos


def predict_slide(model, img: Tensor, text_features: Tensor, side: int, stride: int) -> Tensor:
    """Sliding-window prediction averaged over overlapping windows.

    Windows of at most ``side`` x ``side`` pixels are visited every ``stride``
    pixels; a window that would cross the image border is shifted back inside.
    Each window's similarities are upsampled to the window size, turned into
    per-pixel class scores with a softmax, and averaged where windows overlap.

    Args:
        model: Model accepted by ``encode_image``.
        img: Single image of shape [3, H, W].
        text_features: Class embeddings of shape [num_classes, D].
        side: Window side length in pixels.
        stride: Step between consecutive windows in pixels.

    Returns:
        Tensor of shape [num_classes, H, W] on the same device as ``img``.
    """
    # Iterate over overlapping windows, accumulate predictions at the image resolution
    _, H, W = img.shape
    num_classes, _ = text_features.shape
    # Allocate the accumulators next to the input instead of assuming CUDA.
    probs = torch.zeros([num_classes, H, W], device=img.device)
    counts = torch.zeros([H, W], device=img.device)
    h_grids = max(H - side + stride - 1, 0) // stride + 1
    w_grids = max(W - side + stride - 1, 0) // stride + 1
    for i in range(h_grids):
        for j in range(w_grids):
            # Clamp the window to the image, then pull its origin back so it keeps full size when possible
            y2 = min(i * stride + side, H)
            x2 = min(j * stride + side, W)
            y1 = max(y2 - side, 0)
            x1 = max(x2 - side, 0)

            # Compute cosine similarities for this window, same logic as predict_whole
            img_window = img[:, y1:y2, x1:x2]  # [3, H_win, W_win]
            cos = predict_whole(model, img_window, text_features)  # [num_classes, h, w]

            # Upsample to the window resolution and accumulate "probabilities"
            # NOTE: they aren't real probabilities, just the result of applying softmax to cosine similarities
            cos = F.interpolate(
                cos.unsqueeze(0),
                size=img_window.shape[1:],
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)  # [num_classes, H_win, W_win]
            probs[:, y1:y2, x1:x2] += cos.softmax(dim=0)  # [num_classes, H_win, W_win]
            counts[y1:y2, x1:x2] += 1
    probs /= counts

    # Return "probabilities" at the img resolution, they will be upsampled to the target resolution later
    return probs  # [num_classes, H, W]
# Reference: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
# Each template is filled with a class name through `template.format(class_name)`.
PROMPT_TEMPLATES = (
    "a bad photo of a {0}.",
    "a photo of many {0}.",
    "a sculpture of a {0}.",
    "a photo of the hard to see {0}.",
    "a low resolution photo of the {0}.",
    "a rendering of a {0}.",
    "graffiti of a {0}.",
    "a bad photo of the {0}.",
    "a cropped photo of the {0}.",
    "a tattoo of a {0}.",
    "the embroidered {0}.",
    "a photo of a hard to see {0}.",
    "a bright photo of a {0}.",
    "a photo of a clean {0}.",
    "a photo of a dirty {0}.",
    "a dark photo of the {0}.",
    "a drawing of a {0}.",
    "a photo of my {0}.",
    "the plastic {0}.",
    "a photo of the cool {0}.",
    "a close-up photo of a {0}.",
    "a black and white photo of the {0}.",
    "a painting of the {0}.",
    "a painting of a {0}.",
    "a pixelated photo of the {0}.",
    "a sculpture of the {0}.",
    "a bright photo of the {0}.",
    "a cropped photo of a {0}.",
    "a plastic {0}.",
    "a photo of the dirty {0}.",
    "a jpeg corrupted photo of a {0}.",
    "a blurry photo of the {0}.",
    "a photo of the {0}.",
    "a good photo of the {0}.",
    "a rendering of the {0}.",
    "a {0} in a video game.",
    "a photo of one {0}.",
    "a doodle of a {0}.",
    "a close-up photo of the {0}.",
    "a photo of a {0}.",
    "the origami {0}.",
    "the {0} in a video game.",
    "a sketch of a {0}.",
    "a doodle of the {0}.",
    "a origami {0}.",
    "a low resolution photo of a {0}.",
    "the toy {0}.",
    "a rendition of the {0}.",
    "a photo of the clean {0}.",
    "a photo of a large {0}.",
    "a rendition of a {0}.",
    "a photo of a nice {0}.",
    "a photo of a weird {0}.",
    "a blurry photo of a {0}.",
    "a cartoon {0}.",
    "art of a {0}.",
    "a sketch of the {0}.",
    "a embroidered {0}.",
    "a pixelated photo of a {0}.",
    "itap of the {0}.",
    "a jpeg corrupted photo of the {0}.",
    "a good photo of a {0}.",
    "a plushie {0}.",
    "a photo of the nice {0}.",
    "a photo of the small {0}.",
    "a photo of the weird {0}.",
    "the cartoon {0}.",
    "art of the {0}.",
    "a drawing of the {0}.",
    "a photo of the large {0}.",
    "a black and white photo of a {0}.",
    "the plushie {0}.",
    "a dark photo of a {0}.",
    "itap of a {0}.",
    "graffiti of the {0}.",
    "a toy {0}.",
    "itap of my {0}.",
    "a photo of a cool {0}.",
    "a photo of a small {0}.",
    "a tattoo of the {0}.",
)

# Load the model
import sys

# Make the DINOv3 checkout importable; the membership test keeps re-runs of this cell
# from appending the same entry again.
if DINOv3_REPO_DIR not in sys.path:
    sys.path.append(DINOv3_REPO_DIR)

from dinov3.hub.dinotxt import dinov3_vitl16_dinotxt_tet1280d20h24l

model, tokenizer = dinov3_vitl16_dinotxt_tet1280d20h24l()
model.to("cuda", non_blocking=True)
model.eval()
# From here on `tokenizer` is the bound `tokenize` method, called directly on a list of prompts.
tokenizer = tokenizer.tokenize


@dataclasses.dataclass
class Configuration:
    """Settings of the zero-shot segmentation evaluation."""

    dataset: str = "cityscapes"  # cityscapes, ade20k

    mode: str = "slide"  # whole (whole image), slide (sliding window inference)
    resize: int = 512  # Short side of the input images

    # Only used for mode=slide
    side: int = 384
    stride: int = 192


# Local setup
lovely_tensors.monkey_patch()
warnings.filterwarnings("ignore", message="xFormers")
cfg: Configuration = OmegaConf.to_object(
    OmegaConf.structured(Configuration),
)
print(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")

# Load dataset
transform = TVT.Compose(
    [
        ShortSideResize(cfg.resize, TVT.InterpolationMode.BICUBIC),
        TVT.ToTensor(),
        NORMALIZE_IMAGENET,
    ]
)
dataset = DATASETS[cfg.dataset](transform)
class_names = dataset.CLASS_NAMES
print(f"Dataset: {len(dataset)} images, {len(class_names)} classes")
# batch_size=None disables automatic batching, so the loop below receives one (img, target) pair at a time.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=None,  # TODO Adapt
    num_workers=1,
    shuffle=False,
    pin_memory=True,
    multiprocessing_context="spawn",
)

# Prepare text features: prompts x class names
# Gradients are never needed here; no_grad keeps autograd from retaining a graph per prompt batch.
text_feats = []
with torch.no_grad():
    for class_name in tqdm.tqdm(class_names, desc="Class names", unit="name", ncols=0):
        text = [template.format(class_name) for template in PROMPT_TEMPLATES]
        tokens = tokenizer(text).to("cuda", non_blocking=True)
        feats = model.encode_text(tokens)  # [num_prompts, 2D]
        feats = feats[:, feats.shape[1] // 2 :]  # The 1st half of the features corresponds to the CLS token, drop it
        feats = F.normalize(feats, p=2, dim=-1)  # Normalize each text embedding
        feats = feats.mean(dim=0)  # Average over all prompt embeddings per class
        feats = F.normalize(feats, p=2, dim=-1)  # Normalize again
        text_feats.append(feats)
text_feats = torch.stack(text_feats)  # [num_classes, D]
print(f"Text features: {text_feats}")

# Loop over dataset, perform segmentation and compute metrics
miou = MulticlassJaccardIndex(len(class_names), average="macro", ignore_index=255).to("cuda")
with torch.no_grad():  # Pure evaluation: do not track gradients through the model or the accumulators
    for idx, (img, target) in enumerate(tqdm.tqdm(dataloader, desc="Segmentation", unit="img", ncols=0)):
        H_target, W_target = target.shape
        img = img.to("cuda", non_blocking=True)  # [3, H, W]
        target = target.to("cuda", non_blocking=True)  # [H_target, W_target]
        if idx == 0:
            tqdm.tqdm.write(f"Image: {img}")
            tqdm.tqdm.write(f"Target: {target}")

        if cfg.mode == "whole":
            pred = predict_whole(model, img, text_feats)  # [num_classes, h, w] (patch resolution)
        elif cfg.mode == "slide":
            pred = predict_slide(model, img, text_feats, cfg.side, cfg.stride)  # [num_classes, H, W]
        else:
            raise ValueError(f"Unknown mode {cfg.mode}")
        if idx == 0:
            tqdm.tqdm.write(f"Pred: {pred}")

        # Interpolate to the target resolution and take argmax
        pred = F.interpolate(pred.unsqueeze(0), size=(H_target, W_target), mode="bilinear", align_corners=False)
        pred = pred.squeeze(0).argmax(dim=0)  # [H_target, W_target]
        miou.update(pred.unsqueeze(0), target.unsqueeze(0))

# Compute metrics (once; the value is reused for the printout and the summary table)
miou_percent = 100 * miou.compute().item()
print(f"Configuration {cfg}")
print(f"Segmentation mIoU: {miou_percent}")

data = [['mIoU'], [miou_percent]]