{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a7f3c2d91b4e6058",
   "metadata": {},
   "source": [
    "# TerraMind tokenizer reconstruction\n",
    "\n",
    "This notebook builds a pretrained TerraMind v1 tokenizer from the TerraTorch model registry, ",
    "runs one S-2 L2A example through it (encode and decode), and plots the input next to the reconstruction.\n",
    "\n",
    "Run the cells from top to bottom; later cells rely on the variables defined in earlier ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import rioxarray as rxr\n",
    "import matplotlib.pyplot as plt\n",
    "from terratorch import FULL_MODEL_REGISTRY\n",
    "from terratorch.models.backbones.terramind.model.terramind_register import v1_pretraining_mean, v1_pretraining_std\n",
    "\n",
    "# Select device: prefer CUDA, then Apple MPS, otherwise fall back to the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = 'mps'\n",
    "else:\n",
    "    device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72ebcbdd51a967ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build model\n",
    "model = FULL_MODEL_REGISTRY.build('terramind_v1_tokenizer_s2l2a', pretrained=True)\n",
    "\n",
    "# For other modalities:\n",
    "# model = FULL_MODEL_REGISTRY.build('terramind_v1_tokenizer_s1rtc', pretrained=True)\n",
    "# model = FULL_MODEL_REGISTRY.build('terramind_v1_tokenizer_dem', pretrained=True)\n",
    "# model = FULL_MODEL_REGISTRY.build('terramind_v1_tokenizer_lulc', pretrained=True)\n",
    "# model = FULL_MODEL_REGISTRY.build('terramind_v1_tokenizer_ndvi', pretrained=True)\n",
    "\n",
    "# Move to the selected device and switch to inference mode.\n",
    "# Assigning the result keeps the cell from printing the full module repr.\n",
    "model = model.to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f399d4fb83a5adfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load an example (replace S2L2A in the file paths for other modalities)\n",
    "examples = [\n",
    "    '../examples/S2L2A/38D_378R_2_3.tif',\n",
    "    '../examples/S2L2A/282D_485L_3_3.tif',\n",
    "    '../examples/S2L2A/433D_629L_3_1.tif',\n",
    "    '../examples/S2L2A/637U_59R_1_3.tif',\n",
    "    '../examples/S2L2A/609U_541L_3_0.tif',\n",
    "]\n",
    "\n",
    "# Select example between 0 and 4\n",
    "EXAMPLE_INDEX = 1\n",
    "raster = rxr.open_rasterio(examples[EXAMPLE_INDEX])\n",
    "\n",
    "# Convert to a float32 CPU tensor of shape [B, C, 224, 224] (batch size 1).\n",
    "# Casting in NumPy first gives torch a supported, native dtype to wrap.\n",
    "data = torch.from_numpy(raster.values.astype(np.float32)).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d2e9c14f7a3b806",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Channel indices displayed as red, green and blue, and the raw value mapped to full brightness\n",
    "RGB_BANDS = (3, 2, 1)\n",
    "RGB_MAX_VALUE = 2000\n",
    "\n",
    "\n",
    "def to_rgb(image, bands=RGB_BANDS, max_value=RGB_MAX_VALUE):\n",
    "    \"\"\"Convert a [C, H, W] tensor into an [H, W, 3] uint8 array for plt.imshow.\n",
    "\n",
    "    Args:\n",
    "        image: Tensor with channels first (not normalized).\n",
    "        bands: Channel indices to use as red, green and blue, in that order.\n",
    "        max_value: Raw value that is mapped to 255; larger values are clipped.\n",
    "\n",
    "    Returns:\n",
    "        NumPy array of dtype uint8 with values in [0, 255].\n",
    "    \"\"\"\n",
    "    # Indexing with a list copies the selected channels, so the source tensor is not modified\n",
    "    rgb = image[list(bands)].permute(1, 2, 0)\n",
    "    rgb = (rgb / max_value).clip(0, 1) * 255\n",
    "    return rgb.detach().cpu().numpy().round().astype(np.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93f674b55d92ab9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize S-2 L2A input as RGB\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(to_rgb(data[0]))\n",
    "ax.axis('off')\n",
    "ax.set_title('Input')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7ac20d22ac706d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize input\n",
    "# The per-channel statistics are reshaped once to [1, C, 1, 1] so they broadcast over [B, C, H, W]\n",
    "mean = torch.as_tensor(v1_pretraining_mean['untok_sen2l2a@224'], dtype=torch.float32)[None, :, None, None]\n",
    "std = torch.as_tensor(v1_pretraining_std['untok_sen2l2a@224'], dtype=torch.float32)[None, :, None, None]\n",
    "model_input = (data - mean) / std\n",
    "\n",
    "# See keys for other modalities:\n",
    "# v1_pretraining_mean.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d9754c331e226cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run model with diffusion steps\n",
    "DIFFUSION_STEPS = 10\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Encode & decode image\n",
    "    reconstruction_norm = model(model_input.to(device), timesteps=DIFFUSION_STEPS)\n",
    "\n",
    "    # Alternatively split the encoding and decoding process to analyze tokens\n",
    "    # Encode image\n",
    "    # _, _, tokens = model.encode(model_input.to(device))\n",
    "    # Decode tokens\n",
    "    # reconstruction_norm = model.decode_tokens(tokens, verbose=True, timesteps=DIFFUSION_STEPS)\n",
    "\n",
    "# Denormalize on the CPU, where mean and std live\n",
    "reconstruction = reconstruction_norm.cpu() * std + mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7b50dd37a5ae9a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize S-2 L2A input and reconstruction side by side as RGB\n",
    "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n",
    "\n",
    "panels = [('Input', data), ('Reconstruction', reconstruction)]\n",
    "for axis, (title, batch) in zip(ax, panels):\n",
    "    axis.imshow(to_rgb(batch[0]))\n",
    "    axis.axis('off')\n",
    "    axis.set_title(title)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}