{ "cells": [ { "metadata": {}, "cell_type": "markdown", "source": [ "# HLS Burn Scars\n", "\n", "## Task: Adapt the notebook yourself to perform fine-tuning with TerraMind on HLS Burn Scars.\n", "\n", "You will find several TODOs in this notebook.\n", "\n", "Use the dataset description (https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars) and the TerraMind docs (https://terrastackai.github.io/terratorch/stable/guide/terramind/) to solve the TODOs." ], "id": "7b00130721ffbf74" }, { "metadata": {}, "cell_type": "markdown", "source": [ "# Setup\n", "\n", "In Colab: \n", "1. Go to \"Runtime\" -> \"Change runtime type\" -> Select \"T4 GPU\"\n", "2. Install TerraTorch" ], "id": "b4bacc318390456b" }, { "metadata": {}, "cell_type": "code", "source": [ "!pip install \"terratorch>=1.2.4\"\n", "!pip install gdown tensorboard \"setuptools<81\"" ], "id": "W_4z81Fn9RET", "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "2e8c1961-311b-49e0-b6ea-2867b44cb47a", "metadata": { "id": "2e8c1961-311b-49e0-b6ea-2867b44cb47a" }, "source": [ "import os\n", "import torch\n", "import gdown\n", "import terratorch\n", "import albumentations\n", "import albumentations.pytorch # Needed for albumentations.pytorch.transforms.ToTensorV2 below\n", "import numpy as np\n", "import lightning.pytorch as pl\n", "import matplotlib.pyplot as plt\n", "from pathlib import Path\n", "from terratorch.datamodules import GenericNonGeoSegmentationDataModule\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": "3. Download the dataset from Google Drive", "id": "917b65b8e7cd7d65" }, { "cell_type": "code", "source": [ "# This version is an adaptation from https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars (same data, updated folder structure) with splits from https://github.com/IBM/peft-geofm/tree/main/datasets_splits/burn_scars. 
You can also download the original data from Hugging Face.\n", "\n", "if not os.path.isfile('hls_burn_scars.tar.gz'):\n", " gdown.download(\"https://drive.google.com/uc?id=1yFDNlGqGPxkc9lh9l1O70TuejXAQYYtC\")\n", "\n", "if not os.path.isdir('hls_burn_scars/'):\n", " !tar -xzf hls_burn_scars.tar.gz" ], "metadata": { "id": "dw5-9A4A4OmI", "collapsed": true }, "id": "dw5-9A4A4OmI", "outputs": [], "execution_count": null }, { "cell_type": "markdown", "id": "35ba4d58-8ff6-4f9c-bfb1-a70376f80494", "metadata": { "id": "35ba4d58-8ff6-4f9c-bfb1-a70376f80494" }, "source": [ "## HLS Burn Scars Dataset\n", "\n", "Let's start by analysing the dataset" ] }, { "cell_type": "code", "id": "e3854bdb-17a4-43c8-bfa8-822b44fd59c3", "metadata": { "id": "e3854bdb-17a4-43c8-bfa8-822b44fd59c3" }, "source": [ "dataset_path = Path(\"hls_burn_scars\")\n", "!ls \"hls_burn_scars\"" ], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "!ls \"hls_burn_scars/splits/\" | head", "id": "2b52842a5b6ae51f", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "!ls \"hls_burn_scars/data/\" | head", "id": "a894497e94b8c649", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "!head \"hls_burn_scars/splits/train.txt\"", "id": "beb9471f65722f7f", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "import rioxarray as rxr\n", "sample = rxr.open_rasterio('hls_burn_scars/data/subsetted_512x512_HLS.S30.T10SDH.2020248.v1.4_merged.tif')\n", "sample" ], "id": "886a65a53367015f", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "TerraTorch provides generic data modules that work directly with PyTorch Lightning. 
For single-modal datasets we recommend the `GenericNonGeoSegmentationDataModule`.\n", "Check the docs for the parameters.\n", "\n", "Data module: https://terrastackai.github.io/terratorch/stable/package/generic_datamodules/#terratorch.datamodules.generic_pixel_wise_data_module.GenericNonGeoSegmentationDataModule\n", "\n", "Dataset: https://terrastackai.github.io/terratorch/stable/package/generic_datasets/#terratorch.datasets.generic_pixel_wise_dataset.GenericNonGeoSegmentationDataset" ], "id": "a2f22dc984ead544" }, { "metadata": { "id": "735803b1-a4bf-427f-a1e6-5ac755af33fc" }, "cell_type": "code", "source": [ "datamodule = terratorch.datamodules.GenericNonGeoSegmentationDataModule(\n", " batch_size=8,\n", " num_workers=2,\n", " num_classes=2,\n", " rgb_indices=[2, 1, 0],\n", "\n", " # TODO: Define the data and label paths\n", " train_data_root=None,\n", " train_label_data_root=None,\n", " val_data_root=None,\n", " val_label_data_root=None,\n", " test_data_root=None,\n", " test_label_data_root=None,\n", "\n", " # TODO: Define the split files\n", " train_split=None,\n", " val_split=None,\n", " test_split=None,\n", "\n", " # TODO: Define suffixes\n", " image_grep=None,\n", " label_grep=None,\n", "\n", " # TODO: Update the standardization values. 
They need to be the same length as the data.\n", " # You can define a constant_scale that applies a multiplier in case the data does not align with the standardization values.\n", " constant_scale=None,\n", " means=None,\n", " stds=None,\n", " # FYI: TerraMind pretraining values (assuming data in range 0-10000)\n", " # S2L2A means: [1390.458, 1503.317, 1718.197, 1853.910, 2199.100, 2779.975, 2987.011, 3083.234, 3132.220, 3162.988, 2424.884, 1857.648]\n", " # S2L2A stds: [2106.761, 2141.107, 2038.973, 2134.138, 2085.321, 1889.926, 1820.257, 1871.918, 1753.829, 1797.379, 1434.261, 1334.311]\n", "\n", " # albumentations supports shared transformations and can handle multimodal inputs.\n", " train_transform=[\n", " albumentations.D4(), # Random flips and rotation\n", " albumentations.pytorch.transforms.ToTensorV2(),\n", " ],\n", " val_transform=None, # Fallback to ToTensor\n", " test_transform=None,\n", "\n", " no_label_replace=-1, # Replace NaN labels. Defaults to -1, which is ignored in the loss and metrics.\n", " no_data_replace=0, # Replace NaN data\n", " check_stackability=True, # Ideally leave it True if you check a dataset the first time. 
Afterward, set it to False to be more efficient\n", ")\n", "\n", "# Setup train and val datasets\n", "datamodule.setup(\"fit\")" ], "id": "735803b1-a4bf-427f-a1e6-5ac755af33fc", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# checking datasets validation split size\n", "val_dataset = datamodule.val_dataset\n", "len(val_dataset)" ], "id": "b7062ddc-a3b7-4378-898c-41abcdf2ee3b", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Helper function for plotting a sample (and optionally a prediction)\n", "def plot_sample(sample):\n", " data = sample['image'].cpu().numpy()\n", " mask = sample['mask'].cpu().numpy()\n", " pred = sample.get('prediction')\n", "\n", " # Scaling data.\n", " if data.mean() < 1:\n", " data = data * 10000\n", " data = (data.clip(0, 2000) / 2000) * 255\n", " rgb = data[[2, 1, 0]].astype(np.uint8).transpose(1,2,0)\n", "\n", " n_cols = 3 if pred is None else 4\n", " fig, ax = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4))\n", " ax[0].imshow(rgb)\n", " ax[0].set_title('Image')\n", " ax[0].axis('off')\n", " ax[1].imshow(mask, vmin=-1, vmax=1, interpolation='nearest')\n", " ax[1].set_title('Mask')\n", " ax[1].axis('off')\n", " ax[2].imshow(rgb)\n", " ax[2].imshow(mask, vmin=-1, vmax=1, interpolation='nearest', alpha=0.5)\n", " ax[2].set_title('Mask on Image')\n", " ax[2].axis('off')\n", " if pred is not None:\n", " ax[3].imshow(pred, vmin=-1, vmax=1, interpolation='nearest')\n", " ax[3].set_title('Prediction')\n", " ax[3].axis('off')\n", " fig.tight_layout()\n", " plt.show()" ], "id": "6b1262546b1a8e68", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Plotting a few samples\n", "plot_sample(val_dataset[0])\n", "plot_sample(val_dataset[1])\n", "plot_sample(val_dataset[2])\n", "\n", "# The GenericNonGeoSegmentationDataset has an integrated plotting function.\n", "# We use a custom one because the 0-2000 scaling looks a bit nicer than the min-max scaling of the integrated function.\n", "\n", "# val_dataset.plot(val_dataset[0])\n", "# plt.show()" ], "id": "3a1da2ad-a797-4f4a-ad1a-cd10f9addb01", "outputs": [], "execution_count": null }, { "cell_type": "code", "id": 
"ede1c1c6-9f60-4510-a2da-572c55d03f79", "metadata": { "id": "ede1c1c6-9f60-4510-a2da-572c55d03f79" }, "source": [ "# checking datasets testing split size\n", "datamodule.setup(\"test\")\n", "test_dataset = datamodule.test_dataset\n", "len(test_dataset)" ], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "# Fine-tune TerraMind via PyTorch Lightning\n", "\n", "With TerraTorch, we can use standard Lightning components for the fine-tuning.\n", "These include callbacks and the trainer class.\n", "TerraTorch provides EO-specific tasks that define the training and validation steps.\n", "In this case, we are using the `SemanticSegmentationTask`.\n", "We refer to the [TerraTorch paper](https://arxiv.org/abs/2503.20563) for a detailed explanation of the TerraTorch tasks." ], "id": "654a30ddef8ed5a" }, { "cell_type": "code", "id": "ae69d39a-857a-4392-b058-0f4b518edf6e", "metadata": { "scrolled": true, "id": "ae69d39a-857a-4392-b058-0f4b518edf6e" }, "source": [ "pl.seed_everything(0)\n", "\n", "# By default, TerraTorch saves the model with the best validation loss. You can overwrite this by defining a custom ModelCheckpoint, e.g., saving the model with the highest validation mIoU. 
\n", "# TODO Optionally adjust the checkpoint\n", "checkpoint_callback = pl.callbacks.ModelCheckpoint(\n", " dirpath=None,\n", " mode=\"max\",\n", " monitor=\"val/mIoU\", # Variable to monitor\n", " filename=\"best-mIoU\",\n", " save_weights_only=True,\n", ")\n", "\n", "# Lightning Trainer\n", "trainer = pl.Trainer(\n", " accelerator=\"auto\",\n", " strategy=\"auto\",\n", " devices=1, # Deactivate multi-gpu because it often fails in notebooks\n", " precision=\"16-mixed\", # Speed up training with half precision, delete for full precision training.\n", " num_nodes=1,\n", " logger=True, # Uses TensorBoard by default\n", " max_epochs=3, # For demos\n", " log_every_n_steps=1,\n", " callbacks=[checkpoint_callback],\n", " # TODO Define output dir\n", " default_root_dir=None,\n", ")\n", "\n", "# Segmentation task that builds the model and handles the training and validation steps. \n", "model = terratorch.tasks.SemanticSegmentationTask(\n", " model_factory=\"EncoderDecoderFactory\", # Combines a backbone with necks, the decoder, and a head\n", " model_args={\n", " # TerraMind backbone\n", " \"backbone\": \"terramind_v1_small\",\n", " \"backbone_pretrained\": True,\n", " # TODO Select the modality. Check the docs for details https://terrastackai.github.io/terratorch/stable/guide/terramind/#subset-of-input-bands\n", " \"backbone_modalities\": [\"TODO\"],\n", " # TODO define the input bands. 
This is only needed because you need to select a subset of the pre-training bands for Burn Scars.\n", " # Check the names in the \"List of pre-trained bands\" in the docs.\n", " \"backbone_bands\": {\"TODO\": [\"TODO\"]},\n", " \n", " # Necks \n", " \"necks\": [\n", " {\n", " \"name\": \"SelectIndices\",\n", " \"indices\": [2, 5, 8, 11] # indices for terramind_v1_small and terramind_v1_base\n", " # \"indices\": [5, 11, 17, 23] # indices for terramind_v1_large\n", " },\n", " {\"name\": \"ReshapeTokensToImage\",\n", " \"remove_cls_token\": False}, # TerraMind is trained without a CLS token, which needs to be specified.\n", " {\"name\": \"LearnedInterpolateToPyramidal\"} # Some decoders like UNet or UperNet expect hierarchical features. Therefore, we need to learn an upsampling for the intermediate embedding layers when using a ViT like TerraMind.\n", " ],\n", " \n", " # Decoder\n", " \"decoder\": \"UNetDecoder\",\n", " \"decoder_channels\": [256, 128, 64, 32],\n", " # Warning for Mac users: The UNetDecoder can lead to some failures because of the batch norms when training on MPS.\n", " # Use an FCN decoder instead or install TerraTorch from main/v1.2.5 for a fix.\n", " \n", " # Head\n", " \"head_dropout\": 0.1,\n", " \"num_classes\": 2,\n", " },\n", " \n", " loss=\"dice\", # We recommend dice for binary tasks and ce for tasks with multiple classes. \n", " optimizer=\"AdamW\",\n", " lr=2e-5, # The optimal learning rate varies between datasets, we recommend testing different ones between 1e-5 and 1e-4. You can perform hyperparameter optimization using terratorch-iterate. 
\n", " scheduler='ReduceLROnPlateau', # optionally define a learning rate scheduler and pass hparams\n", " scheduler_hparams={\n", " 'factor': 0.5, # This \"reduce LR on plateau\" scheduler multiplies the lr by when the val loss did not improve for epochs\n", " 'patience': 5\n", " },\n", " ignore_index=-1,\n", " freeze_backbone=True, # Only used to speed up fine-tuning in this demo, we highly recommend fine-tuning the backbone for the best performance. \n", " freeze_decoder=False, # Should be false in most cases as the decoder is randomly initialized.\n", " plot_on_val=True, # Plot predictions during validation steps \n", " class_names=[\"Others\", \"Burned\"], # optionally define class names\n", " class_weights=[0.5, 0.5], # optionally define class weights for in-balanced datasets\n", ")" ], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Before starting the fine-tuning, you can start the tensorboard with:\n", "%load_ext tensorboard\n", "%tensorboard --logdir output" ], "id": "ca03ce8977006bb0", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Training\n", "trainer.fit(model, datamodule=datamodule)" ], "id": "ff284062edfce308", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": "After fine-tuning, we can evaluate the model on the test set:", "id": "3c1bebdb7370a174" }, { "cell_type": "code", "id": "35a77263-5308-4781-a17f-a35e62ca1875", "metadata": { "scrolled": true, "id": "35a77263-5308-4781-a17f-a35e62ca1875" }, "source": [ "# Let's test the fine-tuned model\n", "best_ckpt_path = \"output/terramind_small_burnscars/checkpoints/best-mIoU.ckpt\"\n", "trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)\n", "\n", "# Note: This demo only trains for 3 epochs by default, which does not result in good test metrics." 
], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "1e015fe0-88ee-46cf-b972-f8cb9d361536", "metadata": { "id": "1e015fe0-88ee-46cf-b972-f8cb9d361536", "outputId": "c7c06228-e634-4608-cb0e-617d003fea46", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 } }, "source": [ "# Now we can use the model for predictions and plotting\n", "model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(\n", " best_ckpt_path,\n", " model_factory=model.hparams.model_factory,\n", " model_args=model.hparams.model_args,\n", ")\n", "model.eval()\n", "\n", "test_loader = datamodule.test_dataloader()\n", "with torch.no_grad():\n", " batch = next(iter(test_loader))\n", " image = batch[\"image\"].clone() # Keep an unstandardized copy for plotting\n", " batch = datamodule.aug(batch)\n", " input = batch[\"image\"].to(model.device)\n", " masks = batch[\"mask\"].numpy()\n", "\n", " outputs = model(input)\n", " \n", " preds = torch.argmax(outputs.output, dim=1).cpu().numpy()\n", "\n", "for i in range(5):\n", " sample = {\n", " \"image\": image[i].cpu(),\n", " \"mask\": batch[\"mask\"][i],\n", " \"prediction\": preds[i],\n", " }\n", " plot_sample(sample)\n", " # test_dataset.plot(sample)\n", " # plt.show()\n", " \n", "# Note: This demo only trains for 3 epochs by default, which does not result in good predictions." ], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "# Burnscars config\n", "\n", "If you are struggling with this task, you can check this burn scars config for guidance.\n", "\n", "Please note that this config uses the generic segmentation dataset instead of the generic multimodal dataset. The idea is similar, but the details are a bit different (e.g. dicts of strings)."
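, "\n", "\n", "Configs like this one are typically run through the TerraTorch CLI rather than a notebook (a sketch, assuming the YAML file sits in your working directory and its data paths are adjusted to your setup):\n", "\n", "```shell\n", "terratorch fit --config terramind_v1_base_burnscars.yaml\n", "```\n", "\n", "The CLI mirrors the Lightning Trainer used above, so the same `fit`/`test` workflow applies."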
], "id": "510f3843c7fc348a" }, { "metadata": {}, "cell_type": "code", "source": [ "# Download config\n", "!wget https://raw.githubusercontent.com/IBM/terramind/refs/heads/main/configs/terramind_v1_base_burnscars.yaml" ], "id": "672761a04336db8b", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Check the config\n", "!cat terramind_v1_base_burnscars.yaml" ], "id": "e1cd8ede424eff55", "outputs": [], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "name": "python3", "language": "python" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" }, "colab": { "provenance": [] } }, "nbformat": 4, "nbformat_minor": 5 }