{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2026-01-08T10:37:04.347617Z",
     "start_time": "2026-01-08T10:36:53.909625Z"
    }
   },
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import rioxarray as rxr\n",
    "import matplotlib.pyplot as plt\n",
    "from terratorch import FULL_MODEL_REGISTRY\n",
    "from plotting_utils import plot_modality\n",
    "\n",
    "# Select device: prefer CUDA, then Apple MPS, and fall back to the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = 'mps'\n",
    "else:\n",
    "    device = 'cpu'"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-01-06T16:12:30.608134Z",
     "start_time": "2026-01-06T16:12:30.571524Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Load input data\n",
    "examples = [\n",
    "    '38D_378R_2_3.tif',\n",
    "    '282D_485L_3_3.tif',\n",
    "    '433D_629L_3_1.tif',\n",
    "    '637U_59R_1_3.tif',\n",
    "    '609U_541L_3_0.tif',\n",
    "]\n",
    "\n",
    "# Select example between 0 and 4\n",
    "file = examples[0]\n",
    "\n",
    "# Define modalities\n",
    "modalities = ['S2L2A', 'S1RTC', 'DEM', 'LULC', 'NDVI']\n",
    "\n",
    "# Open one raster per modality from ../examples/<modality>/<file>\n",
    "rasters = {m: rxr.open_rasterio(f'../examples/{m}/{file}') for m in modalities}\n",
    "\n",
    "# Tensor with shape [B, C, 224, 224]\n",
    "# The pixel values are cast to float32 explicitly and wrapped with\n",
    "# torch.from_numpy instead of the legacy torch.Tensor(...) constructor;\n",
    "# unsqueeze(0) adds the leading batch axis.\n",
    "data = {\n",
    "    m: torch.from_numpy(raster.values.astype(np.float32)).unsqueeze(0)\n",
    "    for m, raster in rasters.items()\n",
    "}"
   ],
   "id": "43b2c3eb41e450d8",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Run any-to-any generation (this can take a while without a GPU, consider reducing timesteps for faster inference)\n",
    "TIMESTEPS = 10  # forwarded as the model's `timesteps` argument\n",
    "\n",
    "outputs = {}\n",
    "for m in modalities:\n",
    "    print(f'Processing {m}')\n",
    "    # Every modality except the current input is a generation target\n",
    "    out_modalities = [target for target in modalities if target != m]\n",
    "\n",
    "    # Init model\n",
    "    model = FULL_MODEL_REGISTRY.build(\n",
    "        'terramind_v1_base_generate',\n",
    "        modalities=[m],\n",
    "        output_modalities=out_modalities,\n",
    "        pretrained=True,\n",
    "        standardize=True,\n",
    "    )\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    # Named `model_input` so the Python builtin `input` is not shadowed\n",
    "    model_input = data[m].clone().to(device)\n",
    "    with torch.no_grad():\n",
    "        generated = model(model_input, verbose=True, timesteps=TIMESTEPS)\n",
    "    outputs[m] = generated"
   ],
   "id": "f399d4fb83a5adfa",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Plot any-to-any generations\n",
    "n_mod = len(modalities)\n",
    "fig, axes = plt.subplots(nrows=n_mod, ncols=n_mod + 1, figsize=[12, 10])\n",
    "axes[0][0].set_title('Input')\n",
    "for i, m in enumerate(modalities):\n",
    "    axes[0][i + 1].set_title(m)\n",
    "\n",
    "# First column shows each input modality; axes are hidden on every panel\n",
    "for (m, input_tensor), ax_row in zip(data.items(), axes):\n",
    "    plot_modality(m, input_tensor, ax=ax_row[0])\n",
    "    for ax in ax_row:\n",
    "        ax.axis('off')\n",
    "\n",
    "# Row k holds the generations produced from the k-th input modality;\n",
    "# each output is placed in the column that belongs to its own modality\n",
    "for k, m_output in enumerate(outputs.values()):\n",
    "    for m, out in m_output.items():\n",
    "        j = modalities.index(m) + 1\n",
    "        plot_modality(m, out, ax=axes[k][j])\n",
    "\n",
    "plt.savefig(f'any_to_any_{os.path.basename(file)}.pdf')\n",
    "plt.show()"
   ],
   "id": "9ced943a8b005719",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}