{ "cells": [ { "metadata": {}, "cell_type": "markdown", "source": [ "# Setup\n", "\n", "In Colab:\n", "1. Go to \"Runtime\" -> \"Change runtime type\" -> Select \"T4 GPU\"\n", "2. Install TerraTorch" ], "id": "b4bacc318390456b" }, { "metadata": {}, "cell_type": "code", "source": [ "!pip install \"terratorch>=1.2.4\"\n", "!pip install gdown tensorboard \"setuptools<81\"" ], "id": "W_4z81Fn9RET", "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "2e8c1961-311b-49e0-b6ea-2867b44cb47a", "metadata": { "id": "2e8c1961-311b-49e0-b6ea-2867b44cb47a" }, "source": [ "import os\n", "import torch\n", "import gdown\n", "import terratorch\n", "import albumentations\n", "import numpy as np\n", "import lightning.pytorch as pl\n", "import matplotlib.pyplot as plt\n", "from pathlib import Path\n", "from terratorch.datamodules import GenericNonGeoSegmentationDataModule\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": "3. 
Download the dataset from Google Drive", "id": "917b65b8e7cd7d65" }, { "cell_type": "code", "source": [ "if not os.path.isfile(\"sen1floods11_v1.1.tar.gz\"):\n", " gdown.download(\"https://drive.google.com/uc?id=1lRw3X7oFNq_WyzBO6uyUJijyTuYm23VS\")\n", " !tar -xzf sen1floods11_v1.1.tar.gz" ], "metadata": { "id": "dw5-9A4A4OmI", "collapsed": true }, "id": "dw5-9A4A4OmI", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "# Fine-tuning via CLI\n", "\n", "Fine-tune TerraMind with only a config file!\n", "\n", "Locally, run the fine-tuning command in your terminal rather than in this notebook.\n", "(For MacBook users: there is a known error with MPS on some macOS versions; see the fix in https://github.com/terrastackai/terratorch/issues/859)" ], "id": "187cbd6d24d36731" }, { "metadata": {}, "cell_type": "code", "source": [ "# Download config\n", "!wget https://raw.githubusercontent.com/IBM/terramind/refs/heads/main/configs/terramind_v1_tiny_sen1floods11.yaml" ], "id": "2e5c0d2f1ae45631", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Check the config\n", "!cat terramind_v1_tiny_sen1floods11.yaml" ], "id": "763bfc03a97fbbe3", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Run fine-tuning\n", "!terratorch fit -c terramind_v1_tiny_sen1floods11.yaml" ], "id": "7a1f41e5385e658e", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": "Let's explore the different settings of the config in a bit more detail ...", "id": "8481672d20df131" }, { "cell_type": "markdown", "id": "35ba4d58-8ff6-4f9c-bfb1-a70376f80494", "metadata": { "id": "35ba4d58-8ff6-4f9c-bfb1-a70376f80494" }, "source": [ "## Sen1Floods11 Dataset\n", "\n", "Let's start by analysing the dataset" ] }, { "cell_type": "code", "id": "e3854bdb-17a4-43c8-bfa8-822b44fd59c3", "metadata": { "id": "e3854bdb-17a4-43c8-bfa8-822b44fd59c3" }, "source": [ "dataset_path = 
Path(\"sen1floods11_v1.1\")\n", "!ls \"sen1floods11_v1.1/data\"" ], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "!ls \"sen1floods11_v1.1/data/S2L1CHand/\" | head", "id": "87d91245594c607d", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "TerraTorch provides generic data modules that work directly with PyTorch Lightning.\n", "\n", "Sen1Floods11 is a multimodal dataset that provides Sentinel-2 L1C and Sentinel-1 GRD data. \n", "Therefore, we are using the `GenericMultiModalDataModule`. \n", "This module is similar to the `GenericNonGeoSegmentationDataModule`, which is used for standard segmentation tasks.\n", "However, the data roots, `img_grep`, and other settings are provided as dicts to account for the multimodal inputs. You can find all settings in the [documentation](https://terrastackai.github.io/terratorch/stable/generic_datamodules/).\n", "In a Lightning config, the data module is defined with the `data` key." ], "id": "a2f22dc984ead544" }, { "cell_type": "code", "id": "735803b1-a4bf-427f-a1e6-5ac755af33fc", "metadata": { "id": "735803b1-a4bf-427f-a1e6-5ac755af33fc" }, "source": [ "datamodule = terratorch.datamodules.GenericMultiModalDataModule(\n", " task=\"segmentation\",\n", " batch_size=8,\n", " num_workers=2,\n", " num_classes=2,\n", "\n", " # Define your input modalities. 
The names must match the keys in the following dicts.\n", " modalities=[\"S2L1C\", \"S1GRD\"],\n", " rgb_indices={\n", " \"S2L1C\": [3, 2, 1],\n", " \"S1GRD\": [0, 1, 0],\n", " },\n", "\n", " # Define data paths as dicts using the modality names as keys.\n", " train_data_root={\n", " \"S2L1C\": dataset_path / \"data/S2L1CHand\",\n", " \"S1GRD\": dataset_path / \"data/S1GRDHand\",\n", " },\n", " train_label_data_root=dataset_path / \"data/LabelHand\",\n", " val_data_root={\n", " \"S2L1C\": dataset_path / \"data/S2L1CHand\",\n", " \"S1GRD\": dataset_path / \"data/S1GRDHand\",\n", " },\n", " val_label_data_root=dataset_path / \"data/LabelHand\",\n", " test_data_root={\n", " \"S2L1C\": dataset_path / \"data/S2L1CHand\",\n", " \"S1GRD\": dataset_path / \"data/S1GRDHand\",\n", " },\n", " test_label_data_root=dataset_path / \"data/LabelHand\",\n", "\n", " # Define split files because all samples are saved in the same folder.\n", " train_split=dataset_path / \"splits/flood_train_data.txt\",\n", " val_split=dataset_path / \"splits/flood_valid_data.txt\",\n", " test_split=dataset_path / \"splits/flood_test_data.txt\",\n", " \n", " # Define file suffixes, again using dicts.\n", " image_grep={\n", " \"S2L1C\": \"*_S2Hand.tif\",\n", " \"S1GRD\": \"*_S1Hand.tif\",\n", " },\n", " label_grep=\"*_LabelHand.tif\",\n", " \n", " # With TerraTorch, you can select a subset of the dataset bands as model inputs by providing dataset_bands (all bands in the data) and output_bands (selected bands). This setting is optional for all modalities and needs to be provided as dicts.\n", " # Here is an example with S-1 GRD. You could change the output to [\"VV\"] to train only on the first band. Note that means and stds must be aligned with the output_bands (equal number of values). \n", " dataset_bands={\n", " \"S1GRD\": [\"VV\", \"VH\"]\n", " },\n", " output_bands={\n", " \"S1GRD\": [\"VV\", \"VH\"]\n", " },\n", "\n", " # Define standardization values. 
We use the pre-training values here; providing the additional modalities is not a problem, which makes it simple to experiment with different modality combinations. Alternatively, use the dataset statistics that you can generate using `terratorch compute_statistics -c config.yaml` (requires concat_bands: true for this multimodal datamodule).\n", " means={\n", " \"S2L1C\": [2357.089, 2137.385, 2018.788, 2082.986, 2295.651, 2854.537, 3122.849, 3040.560, 3306.481, 1473.847, 506.070, 2472.825, 1838.929],\n", " \"S2L2A\": [1390.458, 1503.317, 1718.197, 1853.910, 2199.100, 2779.975, 2987.011, 3083.234, 3132.220, 3162.988, 2424.884, 1857.648],\n", " \"S1GRD\": [-12.599, -20.293],\n", " \"S1RTC\": [-10.93, -17.329],\n", " \"RGB\": [87.271, 80.931, 66.667],\n", " \"DEM\": [670.665]\n", " },\n", " stds={\n", " \"S2L1C\": [1624.683, 1675.806, 1557.708, 1833.702, 1823.738, 1733.977, 1732.131, 1679.732, 1727.26, 1024.687, 442.165, 1331.411, 1160.419],\n", " \"S2L2A\": [2106.761, 2141.107, 2038.973, 2134.138, 2085.321, 1889.926, 1820.257, 1871.918, 1753.829, 1797.379, 1434.261, 1334.311],\n", " \"S1GRD\": [5.195, 5.890],\n", " \"S1RTC\": [4.391, 4.459],\n", " \"RGB\": [58.767, 47.663, 42.631],\n", " \"DEM\": [951.272],\n", " },\n", " \n", " # albumentations supports shared transformations and can handle multimodal inputs. \n", " train_transform=[\n", " albumentations.D4(), # Random flips and rotation\n", " albumentations.pytorch.transforms.ToTensorV2(),\n", " ],\n", " val_transform=None, # Uses ToTensorV2() by default if not provided\n", " test_transform=None,\n", " \n", " no_label_replace=-1, # Replace NaN labels. 
Defaults to -1, which is ignored in the loss and metrics.\n", " no_data_replace=0, # Replace NaN data\n", " check_stackability=False, # Set to True if you are uncertain if all samples have the same input shape\n", ")\n", "\n", "# Setup train and val datasets\n", "datamodule.setup(\"fit\")" ], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Check the size of the validation split\n", "val_dataset = datamodule.val_dataset\n", "len(val_dataset)" ], "id": "b7062ddc-a3b7-4378-898c-41abcdf2ee3b", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# The GenericMultiModalDataModule has an integrated plotting function with min/max scaling:\n", "val_dataset.plot(val_dataset[2])\n", "val_dataset.plot(val_dataset[8])\n", "val_dataset.plot(val_dataset[11])\n", "plt.show()" ], "id": "3a1da2ad-a797-4f4a-ad1a-cd10f9addb01", "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "ede1c1c6-9f60-4510-a2da-572c55d03f79", "metadata": { "id": "ede1c1c6-9f60-4510-a2da-572c55d03f79" }, "source": [ "# Check the size of the test split\n", "datamodule.setup(\"test\")\n", "test_dataset = datamodule.test_dataset\n", "len(test_dataset)" ], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "# TerraTorch model registry\n", "\n", "TerraTorch includes its own backbone registry with many EO foundation models (FMs). It also includes meta registries for all model components that pull in other sources, like timm image models or SMP decoders." ], "id": "cf0453502fb0bf62" }, { "metadata": {}, "cell_type": "code", "source": "from terratorch.registry import BACKBONE_REGISTRY, TERRATORCH_BACKBONE_REGISTRY, TERRATORCH_DECODER_REGISTRY", "id": "d970183baaea88cd", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Print all TerraMind v1 backbones. 
\n", "[backbone \n", " for backbone in TERRATORCH_BACKBONE_REGISTRY\n", " if 'terramind_v1' in backbone]\n", "# TiM models use the Thinking-in-Modalities approach; see our paper for details." ], "id": "f4109f8f262cc5f6", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Available decoders. We use the UNetDecoder in this example.\n", "list(TERRATORCH_DECODER_REGISTRY)" ], "id": "9a51fdde4e1e5d1a", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Use the backbone registry to load a PyTorch model for a custom pipeline. The pre-trained weights are automatically downloaded with pretrained=True.\n", "model = BACKBONE_REGISTRY.build(\n", " \"terramind_v1_small\",\n", " modalities=[\"S2L1C\", \"S1GRD\"],\n", " pretrained=True,\n", ")\n", "\n", "# You can find more information about the settings in the TerraMind docs: https://terrastackai.github.io/terratorch/stable/guide/terramind/" ], "id": "56bc7fa971e02793", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "model", "id": "9fcb50e133f20cd7", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "# Fine-tune TerraMind via PyTorch Lightning\n", "\n", "With TerraTorch, we can use standard Lightning components for the fine-tuning.\n", "These include callbacks and the trainer class.\n", "TerraTorch provides EO-specific tasks that define the training and validation steps.\n", "In this case, we are using the `SemanticSegmentationTask`.\n", "We refer to the [TerraTorch paper](https://arxiv.org/abs/2503.20563) for a detailed explanation of the TerraTorch tasks." ], "id": "654a30ddef8ed5a" }, { "cell_type": "code", "id": "ae69d39a-857a-4392-b058-0f4b518edf6e", "metadata": { "scrolled": true, "id": "ae69d39a-857a-4392-b058-0f4b518edf6e" }, "source": [ "pl.seed_everything(0)\n", "\n", "# By default, TerraTorch saves the model with the best validation loss. 
You can override this by defining a custom ModelCheckpoint, e.g., saving the model with the highest validation mIoU. \n", "checkpoint_callback = pl.callbacks.ModelCheckpoint(\n", " dirpath=\"output/terramind_small_sen1floods11/checkpoints/\",\n", " mode=\"max\",\n", " monitor=\"val/mIoU\", # Variable to monitor\n", " filename=\"best-mIoU\",\n", " save_weights_only=True,\n", ")\n", "\n", "# Lightning Trainer\n", "trainer = pl.Trainer(\n", " accelerator=\"auto\",\n", " strategy=\"auto\",\n", " devices=1, # Deactivate multi-gpu because it often fails in notebooks\n", " precision=\"16-mixed\", # Speed up training with half precision; remove for full-precision training.\n", " num_nodes=1,\n", " logger=True, # Uses TensorBoard by default\n", " max_epochs=3, # For demos\n", " log_every_n_steps=1,\n", " callbacks=[checkpoint_callback],\n", " default_root_dir=\"output/terramind_small_sen1floods11/\",\n", ")\n", "\n", "# Segmentation task that builds the model and handles the training and validation steps. \n", "model = terratorch.tasks.SemanticSegmentationTask(\n", " model_factory=\"EncoderDecoderFactory\", # Combines a backbone with necks, the decoder, and a head\n", " model_args={\n", " # TerraMind backbone\n", " \"backbone\": \"terramind_v1_small\",\n", " \"backbone_pretrained\": True,\n", " \"backbone_modalities\": [\"S2L1C\", \"S1GRD\"],\n", " # Optionally, define the input bands. 
This is only needed if you select a subset of the pre-training bands, as explained above.\n", " # \"backbone_bands\": {\"S1GRD\": [\"VV\"]},\n", " \n", " # Necks \n", " \"necks\": [\n", " {\n", " \"name\": \"SelectIndices\",\n", " \"indices\": [2, 5, 8, 11] # indices for terramind_v1_small/base\n", " # \"indices\": [5, 11, 17, 23] # indices for terramind_v1_large\n", " },\n", " {\"name\": \"ReshapeTokensToImage\",\n", " \"remove_cls_token\": False}, # TerraMind is trained without a CLS token, which needs to be specified.\n", " {\"name\": \"LearnedInterpolateToPyramidal\"} # Some decoders like UNet or UperNet expect hierarchical features. Therefore, we need to learn an upsampling for the intermediate embedding layers when using a ViT like TerraMind.\n", " ],\n", " \n", " # Decoder\n", " \"decoder\": \"UNetDecoder\",\n", " \"decoder_channels\": [256, 128, 64, 32],\n", " # Warning for Mac users: The UNetDecoder can lead to some failures because of the batch norms when training on MPS.\n", " # Use an FCN decoder instead or install TerraTorch from main/v1.2.5 for a fix.\n", " \n", " # Head\n", " \"head_dropout\": 0.1,\n", " \"num_classes\": 2,\n", " },\n", " \n", " loss=\"dice\", # We recommend dice for binary tasks and ce for tasks with multiple classes. \n", " optimizer=\"AdamW\",\n", " lr=2e-5, # The optimal learning rate varies between datasets; we recommend testing different ones between 1e-5 and 1e-4. You can perform hyperparameter optimization using terratorch-iterate. 
\n", " scheduler='ReduceLROnPlateau', # optionally define a learning rate scheduler and pass hparams\n", " scheduler_hparams={\n", " 'factor': 0.5, # This scheduler multiplies the lr by `factor` when the val loss has not improved for `patience` epochs\n", " 'patience': 5\n", " },\n", " ignore_index=-1,\n", " freeze_backbone=True, # Only used to speed up fine-tuning in this demo; we highly recommend fine-tuning the backbone for the best performance.\n", " freeze_decoder=False, # Should be False in most cases as the decoder is randomly initialized.\n", " plot_on_val=True, # Plot predictions during validation steps\n", " class_names=[\"Others\", \"Water\"], # optionally define class names\n", " class_weights=[0.3, 0.7], # optionally define class weights for imbalanced datasets\n", ")" ], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Before starting the fine-tuning, you can start TensorBoard with:\n", "%load_ext tensorboard\n", "%tensorboard --logdir output" ], "id": "ca03ce8977006bb0", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Training\n", "trainer.fit(model, datamodule=datamodule)" ], "id": "ff284062edfce308", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": "After fine-tuning, we can evaluate the model on the test set:", "id": "3c1bebdb7370a174" }, { "cell_type": "code", "id": "35a77263-5308-4781-a17f-a35e62ca1875", "metadata": { "scrolled": true, "id": "35a77263-5308-4781-a17f-a35e62ca1875" }, "source": [ "# Let's test the fine-tuned model\n", "best_ckpt_path = \"output/terramind_small_sen1floods11/checkpoints/best-mIoU.ckpt\"\n", "trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)\n", "\n", "# Note: This demo only trains for 3 epochs by default, which does not result in good test metrics." 
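], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": "Optional sanity check (a minimal sketch, not part of the original tutorial): since we set `freeze_backbone=True` above, you can confirm that the encoder is frozen by counting trainable parameters of the task's model. This only assumes the `model` object defined above.", "id": "a1f3c9d2e4b5a601" }, { "metadata": {}, "cell_type": "code", "id": "a1f3c9d2e4b5a602", "source": [ "# Sketch: with freeze_backbone=True, the backbone parameters should have\n", "# requires_grad=False, while the decoder and head remain trainable.\n", "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "total = sum(p.numel() for p in model.parameters())\n", "print(f\"Trainable parameters: {trainable:,} / {total:,}\")"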
], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Now we can use the model for predictions and plotting\n", "model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(\n", " best_ckpt_path,\n", " model_factory=model.hparams.model_factory,\n", " model_args=model.hparams.model_args,\n", ")\n", "model.eval()\n", "\n", "test_loader = datamodule.test_dataloader()\n", "with torch.no_grad():\n", " batch = next(iter(test_loader))\n", " image = batch[\"image\"].copy()\n", " batch = datamodule.aug(batch)\n", " inputs = batch[\"image\"]\n", " for mod, value in inputs.items():\n", " inputs[mod] = value.to(model.device)\n", "\n", " outputs = model(inputs)\n", "\n", " preds = torch.argmax(outputs.output, dim=1).cpu().numpy()\n", "\n", "for i in range(5):\n", " sample = {\n", " \"S2L1C\": image[\"S2L1C\"][i].cpu(),\n", " \"S1GRD\": image[\"S1GRD\"][i].cpu(),\n", " \"mask\": batch[\"mask\"][i],\n", " \"prediction\": preds[i],\n", " }\n", " test_dataset.plot(sample)\n", " plt.show()\n", "\n", "# Note: This demo only trains for 3 epochs by default, which does not result in good predictions." ], "id": "abef6a74e8130cc", "outputs": [], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "name": "python3", "language": "python" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" }, "colab": { "provenance": [] } }, "nbformat": 4, "nbformat_minor": 5 }