{ "cells": [ { "metadata": {}, "cell_type": "markdown", "source": [ "# Large-tile generation\n", "\n", "TerraMind was pre-trained on small patches of 224x224 pixels. Passing larger inputs to the model works as long as the height and width are multiples of 16 pixels. However, such inputs are outside the training scope and can lead to worse generation results or out-of-memory (OOM) errors.\n", "This example generates a larger tile using the `tiled_inference` function provided by TerraTorch." ], "id": "7b6872e1a26474b6" }, { "metadata": {}, "cell_type": "markdown", "source": [ "### Colab Setup (not required locally)\n", "1. Go to \"Runtime\" -> \"Change runtime type\" -> Select \"T4 GPU\"\n", "2. Install TerraTorch" ], "id": "a8bc734f1c8d70a5" }, { "metadata": {}, "cell_type": "code", "source": [ "# Install required packages in Colab\n", "!pip install git+https://github.com/IBM/terratorch.git\n", "!pip install rioxarray matplotlib\n", "!wget https://raw.githubusercontent.com/IBM/terramind/refs/heads/main/notebooks/plotting_utils.py" ], "id": "397d74d24c98fbc0", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "import os\n", "import torch\n", "import numpy as np\n", "import rioxarray as rxr\n", "import matplotlib.pyplot as plt\n", "from huggingface_hub import hf_hub_download\n", "from plotting_utils import plot_s2, plot_modality\n", "from terratorch.registry import FULL_MODEL_REGISTRY\n", "from terratorch.tasks.tiled_inference import tiled_inference\n", "\n", "# Select device\n", "if torch.cuda.is_available():\n", "    device = 'cuda'\n", "elif torch.backends.mps.is_available():\n", "    device = 'mps'\n", "else:\n", "    device = 'cpu'" ], "id": "initial_id", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Download Santiago large-scale example from Hugging Face (2000x2000 pixels)\n", "if not os.path.isfile('examples/S2L2A/Santiago.tif'):\n", "    hf_hub_download(repo_id='ibm-esa-geospatial/Examples', 
filename='S2L2A/Santiago.tif', repo_type='dataset', local_dir='examples/')" ], "id": "82a6e86fa59dfbc8", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Download Singapore large-scale example from Hugging Face (2000x2000 pixels)\n", "if not os.path.isfile('examples/S2L2A/Singapore_2025-01-09.tif'):\n", "    hf_hub_download(repo_id='ibm-esa-geospatial/Examples', filename='S2L2A/Singapore_2025-01-09.tif', repo_type='dataset', local_dir='examples/')" ], "id": "db30e699361d4b76", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Load Singapore tile\n", "# data = rxr.open_rasterio('examples/S2L2A/Singapore_2025-01-09.tif').values\n", "\n", "# Load Santiago tile (alternative input)\n", "data = rxr.open_rasterio('examples/S2L2A/Santiago.tif').values\n", "\n", "# Optionally crop the tile to rows 500-1500 to speed up inference\n", "data = data[:, 500:1500]\n", "\n", "# Build input tensor and add batch dimension\n", "input = torch.tensor(data, dtype=torch.float, device=device).unsqueeze(0)\n", "\n", "# Display the input\n", "fig, ax = plt.subplots(1, 1, figsize=(12, 12))\n", "plot_s2(data, ax=ax)\n", "plt.show()" ], "id": "bf5a1497474ca1a1", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Select the output modalities from S2L2A, S1GRD, S1RTC, DEM, LULC, and NDVI\n", "output_modalities = ['S1GRD', 'LULC']\n", "\n", "# Build model\n", "model = FULL_MODEL_REGISTRY.build(\n", "    'terramind_v1_base_generate',\n", "    modalities=['S2L2A'],\n", "    output_modalities=output_modalities,\n", "    pretrained=True,\n", "    standardize=True,\n", "    timesteps=10,  # Number of diffusion steps\n", ")\n", "\n", "model = model.to(device)\n", "model.eval()" ], "id": "1753c1ef7ce847ae", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "# Run inference\n", "\n", "`tiled_inference` can be used with any model as long as the output of the passed 
`forward` function is a tensor. We therefore wrap the model and stack the generated modalities along the channel dimension." ], "id": "c059321429f3bd41" }, { "metadata": {}, "cell_type": "code", "source": [ "# Define model forward for tiled inference\n", "def model_forward(x):\n", "    # Run chained generation for all output modalities\n", "    generated = model(x)\n", "    # tiled_inference expects a tensor from the forward function, so we concatenate all generations along the channel dimension\n", "    out = torch.concat([generated[m] for m in output_modalities], dim=1)\n", "    return out\n", "\n", "pred = tiled_inference(model_forward, input, crop=256, stride=192, batch_size=16, verbose=True)\n", "pred = pred.squeeze(0)  # Remove batch dim\n", "\n", "print(pred.shape)" ], "id": "b3a48ca53500b0d1", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Unstack output modalities after tiled_inference\n", "\n", "# Number of channels per selected modality and the start index of each modality in the stacked output\n", "num_channels = {'S2L2A': 12, 'S1GRD': 2, 'S1RTC': 2, 'DEM': 1, 'LULC': 10, 'NDVI': 1}\n", "num_channels = {m: num_channels[m] for m in output_modalities}\n", "start_idx = np.cumsum([0] + list(num_channels.values()))\n", "\n", "# Split up the stacked bands into each modality\n", "generated = {m: pred[i:i+c].cpu() for m, i, c in zip(output_modalities, start_idx, num_channels.values())}\n", "\n", "if 'LULC' in generated:\n", "    # Convert the 10 LULC class channels into a class index per pixel\n", "    generated['LULC'] = generated['LULC'].argmax(dim=0)\n", "\n", "print(generated.keys())" ], "id": "ce8d977622b0d5d", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "# Analyze the generations\n", "\n", "Let's have a look at the generations. All modalities capture the general features quite well, but the values are often wrongly scaled. While S1 and NDVI values are closer to the mean values than the ground truth, DEM generations clearly show the smaller patches of the tiled inference. 
The model fails to estimate the overall elevation of each individual patch. Because of the tiled inference, some generations can look a bit patchy, e.g., S1 below clouds, where the model has to estimate the shoreline." ], "id": "20156eb28dd7da2a" }, { "metadata": {}, "cell_type": "code", "source": [ "# Plot generations\n", "n_plots = len(generated) + 1\n", "fig, ax = plt.subplots(1, n_plots, figsize=(5 * n_plots, 5))\n", "\n", "plot_s2(input, ax=ax[0])\n", "ax[0].set_title('Input')\n", "\n", "for i, (mod, value) in enumerate(generated.items()):\n", "    plot_modality(mod, value, ax=ax[i + 1])\n", "    ax[i + 1].set_title('Generated ' + mod)\n", "\n", "plt.show()" ], "id": "c189d05839f52c15", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "# Compare generations in split view\n", "\n", "We use leafmap for an interactive split-view comparison of the input and one generated modality." ], "id": "f5fd2e2a2f1b787f" }, { "metadata": {}, "cell_type": "code", "source": [ "# Install leafmap\n", "!pip install leafmap" ], "id": "4fafb26038e21eb9", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "import leafmap\n", "from plotting_utils import s2_to_rgb, s1_to_rgb, dem_to_rgb, ndvi_to_rgb, lulc_to_rgb" ], "id": "1783546fe1feb55c", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Plot comparison using leafmap (uncomment one of the alternatives to compare a different modality)\n", "leafmap.image_comparison(\n", "    s2_to_rgb(input),\n", "    s1_to_rgb(generated['S1GRD']),\n", "    # dem_to_rgb(generated['DEM']),\n", "    # ndvi_to_rgb(generated['NDVI']),\n", "    # lulc_to_rgb(generated['LULC']),\n", "    label1=\"S-2 L2A\",\n", "    label2=\"Generated data\",\n", "    starting_position=50,\n", "    out_html=\"terramind_generation.html\",\n", ")" ], "id": "8bf7d7c4131a0058", "outputs": [], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", 
"version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }