{ "cells": [ { "cell_type": "markdown", "id": "b4bacc318390456b", "metadata": {}, "source": [ "# Setup\n", "1. In colab: Go to \"Runtime\" -> \"Change runtime type\" -> Select \"T4 GPU\"\n", "2. Install TerraTorch" ] }, { "cell_type": "code", "id": "W_4z81Fn9RET", "metadata": { "scrolled": true }, "source": [ "!pip install \"terratorch>=1.2.4\"\n", "!pip install gdown tensorboard \"setuptools<81\"" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "2e8c1961-311b-49e0-b6ea-2867b44cb47a", "metadata": { "id": "2e8c1961-311b-49e0-b6ea-2867b44cb47a" }, "source": [ "import os\n", "import sys\n", "import torch\n", "import gdown\n", "import terratorch\n", "import albumentations\n", "import lightning.pytorch as pl\n", "import matplotlib.pyplot as plt\n", "from terratorch.datamodules import MultiTemporalCropClassificationDataModule\n", "import warnings\n", "warnings.filterwarnings('ignore')" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "id": "917b65b8e7cd7d65", "metadata": {}, "source": [ "3. 
Download the dataset from Google Drive" ] }, { "cell_type": "code", "id": "3t-YKKUztjXn", "metadata": { "id": "3t-YKKUztjXn", "scrolled": true }, "source": [ "# Download a random subset for demos (~1 GB)\n", "\n", "if not os.path.isdir('multi-temporal-crop-classification-subset/'):\n", "    if not os.path.isfile('multi-temporal-crop-classification-subset.tar.gz'):\n", "        gdown.download(\"https://drive.google.com/uc?id=1SycflNslu47yfMg2i_z8FqYkhZQv7JQM\")\n", "    !tar -xzf multi-temporal-crop-classification-subset.tar.gz\n", "\n", "dataset_path = \"multi-temporal-crop-classification-subset\"" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "id": "35ba4d58-8ff6-4f9c-bfb1-a70376f80494", "metadata": { "id": "35ba4d58-8ff6-4f9c-bfb1-a70376f80494" }, "source": [ "## Multi-temporal Crop Dataset\n", "\n", "Let's start by analyzing the dataset\n" ] }, { "cell_type": "code", "id": "e3854bdb-17a4-43c8-bfa8-822b44fd59c3", "metadata": { "id": "e3854bdb-17a4-43c8-bfa8-822b44fd59c3" }, "source": [ "!ls \"{dataset_path}\"" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "ddd7d83440895e87", "metadata": {}, "source": [ "# Each merged sample includes the stacked bands of three time steps\n", "!ls \"{dataset_path}/training_chips\" | head" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "735803b1-a4bf-427f-a1e6-5ac755af33fc", "metadata": { "id": "735803b1-a4bf-427f-a1e6-5ac755af33fc" }, "source": [ "# Adjusted dataset class for this dataset (general dataset could be used as well)\n", "datamodule = MultiTemporalCropClassificationDataModule(\n", "    batch_size=8,\n", "    num_workers=2,\n", "    data_root=dataset_path,\n", "    train_transform=[\n", "        terratorch.datasets.transforms.FlattenTemporalIntoChannels(), # Required for temporal data\n", "        albumentations.D4(), # Random flips and rotation\n", "        albumentations.pytorch.transforms.ToTensorV2(),\n", "        
terratorch.datasets.transforms.UnflattenTemporalFromChannels(n_timesteps=3),\n", " ],\n", " val_transform=None, # Using ToTensor() by default\n", " test_transform=None,\n", " expand_temporal_dimension=True,\n", " use_metadata=False, # The crop dataset has metadata for location and time\n", " reduce_zero_label=True,\n", ")\n", "\n", "# Setup train and val datasets\n", "datamodule.setup(\"fit\")" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "a87ed3b7-f7dc-486d-ac59-cd781a070925", "metadata": { "id": "a87ed3b7-f7dc-486d-ac59-cd781a070925" }, "source": [ "# checking for the dataset means and stds\n", "datamodule.means, datamodule.stds" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "08644e71-d82f-426c-b0c1-79026fccb578", "metadata": { "id": "08644e71-d82f-426c-b0c1-79026fccb578" }, "source": [ "# checking datasets train split size\n", "train_dataset = datamodule.train_dataset\n", "len(train_dataset)" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "88b86821-3481-4d92-bdba-246568c66c48", "metadata": { "id": "88b86821-3481-4d92-bdba-246568c66c48" }, "source": [ "# checking datasets available bands\n", "train_dataset.all_band_names" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "9264de41-ab16-43cc-9ea2-ee51b0969624", "metadata": { "id": "9264de41-ab16-43cc-9ea2-ee51b0969624" }, "source": [ "# checking datasets classes\n", "train_dataset.class_names" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "3a1da2ad-a797-4f4a-ad1a-cd10f9addb01", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "3a1da2ad-a797-4f4a-ad1a-cd10f9addb01", "outputId": "9c948b7c-e02b-4980-a142-b36bcb51a8e4" }, "source": [ "# plotting a few samples\n", "for i in range(5):\n", " train_dataset.plot(train_dataset[i])" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "b7062ddc-a3b7-4378-898c-41abcdf2ee3b", 
"metadata": { "id": "b7062ddc-a3b7-4378-898c-41abcdf2ee3b" }, "source": [ "# checking datasets validation split size\n", "val_dataset = datamodule.val_dataset\n", "len(val_dataset)" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "ede1c1c6-9f60-4510-a2da-572c55d03f79", "metadata": { "id": "ede1c1c6-9f60-4510-a2da-572c55d03f79" }, "source": [ "# checking datasets testing split size\n", "datamodule.setup(\"test\")\n", "test_dataset = datamodule.test_dataset\n", "len(test_dataset)" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "id": "4072e2f849c0df2d", "metadata": {}, "source": [ "# Fine-tune TerraMind via PyTorch Lightning\n", "\n", "With TerraTorch, we can use standard Lightning components for the fine-tuning.\n", "These include callbacks and the trainer class.\n", "TerraTorch provides EO-specific tasks that define the training and validation steps.\n", "In this case, we are using the `SemanticSegmentationTask`.\n", "We refer to the [TerraTorch paper](https://arxiv.org/abs/2503.20563) for a detailed explanation of the TerraTorch tasks.\n", "\n", "## Temporal Wrapper\n", "\n", "TerraMind does not support multi-temporal inputs natively. Therefore, we use the temporal wrapper that applies the encoder on each image and merges the latents before the decoder in a mid-fusion fashion. 
More details: https://terrastackai.github.io/terratorch/stable/guide/temporal_wrapper/" ] }, { "cell_type": "code", "id": "ae69d39a-857a-4392-b058-0f4b518edf6e", "metadata": { "id": "ae69d39a-857a-4392-b058-0f4b518edf6e", "scrolled": true }, "source": [ "pl.seed_everything(0)\n", "\n", "checkpoint_callback = pl.callbacks.ModelCheckpoint(\n", " dirpath=\"output/terramind_base_multicrop/checkpoints/\",\n", " mode=\"min\",\n", " monitor=\"val/loss\",\n", " filename=\"best-loss\",\n", " save_weights_only=True,\n", ")\n", "\n", "# Lightning Trainer\n", "trainer = pl.Trainer(\n", " accelerator=\"auto\",\n", " strategy=\"auto\",\n", " devices=1, # Lightning multi-gpu often fails in notebooks\n", " precision='16-mixed', # Speed up training\n", " num_nodes=1,\n", " logger=True, # Uses TensorBoard by default\n", " max_epochs=3, # For demos\n", " log_every_n_steps=1,\n", " callbacks=[checkpoint_callback],\n", " default_root_dir=\"output/terramind_base_multicrop\",\n", ")\n", "\n", "# Model\n", "model = terratorch.tasks.SemanticSegmentationTask(\n", " model_factory=\"EncoderDecoderFactory\",\n", " model_args={\n", " # TerraMind backbone\n", " \"backbone\": \"terramind_v1_small\",\n", " \"backbone_pretrained\": True,\n", " \"backbone_modalities\": [\"S2L2A\"],\n", " \"backbone_bands\": {\"S2L2A\": [\"BLUE\", \"GREEN\", \"RED\", \"NIR_NARROW\", \"SWIR_1\", \"SWIR_2\"]},\n", "\n", " # Apply temporal wrapper (params are passed with prefix backbone_temporal)\n", " \"backbone_use_temporal\": True,\n", " \"backbone_temporal_pooling\": \"concat\", # Defaults to \"mean\" which also supports flexible input lengths\n", " \"backbone_temporal_n_timestamps\": 3, # Required for pooling = concat\n", " \n", " # Necks \n", " \"necks\": [\n", " {\n", " \"name\": \"SelectIndices\",\n", " \"indices\": [2, 5, 8, 11] # indices for terramind_v1_tiny, small, and base\n", " # \"indices\": [5, 11, 17, 23] # indices for terramind_v1_large\n", " },\n", " {\n", " \"name\": \"ReshapeTokensToImage\",\n", " 
\"remove_cls_token\": False,\n", "            },\n", "            {\"name\": \"LearnedInterpolateToPyramidal\"},\n", "        ],\n", "\n", "        # Decoder\n", "        \"decoder\": \"UNetDecoder\",\n", "        \"decoder_channels\": [512, 256, 128, 64],\n", "        # Warning for Mac users: The UNetDecoder can lead to some failures because of the batch norms when training on MPS.\n", "        # Use a FCN decoder instead or install TerraTorch from main/v1.2.5 for a fix.\n", "\n", "        # Head\n", "        \"head_dropout\": 0.1,\n", "        \"num_classes\": 13,\n", "    },\n", "\n", "    loss=\"ce\",\n", "    lr=1e-4, # The optimal learning rate varies between datasets, we recommend testing different ones between 1e-5 and 1e-4. You can perform hyperparameter optimization using terratorch-iterate.\n", "    optimizer=\"AdamW\",\n", "    ignore_index=-1,\n", "    freeze_backbone=True, # Speeds up fine-tuning\n", "    freeze_decoder=False,\n", "    plot_on_val=True,\n", "    class_names=[\"Natural Vegetation\", \"Forest\", \"Corn\", \"Soybeans\", \"Wetlands\", \"Developed / Barren\", \"Open Water\", \"Winter Wheat\", \"Alfalfa\", \"Fallow / Idle Cropland\", \"Cotton\", \"Sorghum\", \"Other\"],\n", ")" ], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Before starting the fine-tuning, you can start the tensorboard with:\n", "%load_ext tensorboard\n", "%tensorboard --logdir output" ], "id": "b88b33e20b52c23b", "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "27fee1e72be7349", "metadata": {}, "source": [ "# Training\n", "trainer.fit(model, datamodule=datamodule)" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "388aa3db0dc07460", "metadata": {}, "source": [ "# Let's test the fine-tuned model\n", "best_ckpt_path = \"output/terramind_base_multicrop/checkpoints/best-loss.ckpt\"\n", "trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)\n", "\n", "# Note: This demo only trains for a few epochs by default, which does not result in good test metrics."
], "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Now we can use the model for predictions and plotting\n", "model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(\n", "    best_ckpt_path,\n", "    model_factory=model.hparams.model_factory,\n", "    model_args=model.hparams.model_args,\n", ")\n", "\n", "test_loader = datamodule.test_dataloader()\n", "with torch.no_grad():\n", "    batch = next(iter(test_loader))\n", "    images = batch[\"image\"]\n", "    images = images.to(model.device)\n", "    masks = batch[\"mask\"].numpy()\n", "\n", "    with torch.no_grad():\n", "        outputs = model(images)\n", "\n", "    preds = torch.argmax(outputs.output, dim=1).cpu().numpy()\n", "\n", "for i in range(5):\n", "    sample = {\n", "        \"image\": batch[\"image\"][i].cpu(),\n", "        \"mask\": batch[\"mask\"][i],\n", "        \"prediction\": preds[i],\n", "    }\n", "    test_dataset.plot(sample)\n", "    plt.show()\n", "\n", "# Note: This demo only trains for a few epochs by default, which does not result in good predictions." ], "id": "e303087bc658b83f", "outputs": [], "execution_count": null }, { "cell_type": "markdown", "id": "a0c88e2d5ab78020", "metadata": {}, "source": [ "# Fine-tuning via CLI\n", "\n", "Locally, run the fine-tuning command in your terminal rather than in this notebook.\n", "\n", "In Colab, you want to restart the session to free up GPU memory and set `freeze_backbone: true` to avoid OOM errors."
] }, { "cell_type": "code", "id": "bdbf05ebc81b9998", "metadata": {}, "source": [ "# Download config\n", "!wget https://raw.githubusercontent.com/IBM/terramind/refs/heads/main/configs/terramind_v1_base_multitemporal_crop.yaml" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "id": "12d19f74-e35a-4087-b7ce-c798b60d5173", "metadata": {}, "source": [ "# Run fine-tuning\n", "!terratorch fit -c terramind_v1_base_multitemporal_crop.yaml" ], "outputs": [], "execution_count": null } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3.11", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }