# lightning.pytorch==2.1.1
# TerraTorch semantic-segmentation config: TerraMind v1 (base) backbone with a
# temporal wrapper over 3 Sentinel-2 L2A timesteps, UNet decoder, 13 crop classes.
---
seed_everything: 42

trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  logger: true
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
  max_epochs: 100
  log_every_n_steps: 5
  default_root_dir: output/terramind_base_multicrop/

data:
  class_path: terratorch.datamodules.MultiTemporalCropClassificationDataModule
  init_args:
    batch_size: 8
    num_workers: 4
    data_root: multi-temporal-crop-classification-subset
    expand_temporal_dimension: true
    use_metadata: false
    reduce_zero_label: true
    # Albumentations operates on (C, H, W); temporal frames are flattened into
    # channels for the spatial transforms, then unflattened back to 3 timesteps.
    train_transform:
      - class_path: terratorch.datasets.transforms.FlattenTemporalIntoChannels
      - class_path: albumentations.D4
      - class_path: albumentations.pytorch.transforms.ToTensorV2
      - class_path: terratorch.datasets.transforms.UnflattenTemporalFromChannels
        init_args:
          n_timesteps: 3

model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_factory: EncoderDecoderFactory
    model_args:
      backbone: terramind_v1_base
      backbone_pretrained: true
      backbone_modalities:
        - S2L2A
      backbone_bands:
        S2L2A: ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
      # Apply temporal wrapper (docs: https://terrastackai.github.io/terratorch/stable/guide/temporal_wrapper/)
      backbone_use_temporal: true
      backbone_temporal_pooling: concat  # Defaults to "mean" which also supports flexible input lengths
      backbone_temporal_n_timestamps: 3  # Required for pooling = concat
      necks:
        - name: SelectIndices
          indices: [2, 5, 8, 11]  # indices for terramind_v1_tiny, small, and base
          # indices: [5, 11, 17, 23]  # large version
        - name: ReshapeTokensToImage
          remove_cls_token: false
        - name: LearnedInterpolateToPyramidal
      decoder: UNetDecoder
      decoder_channels: [512, 256, 128, 64]
      head_dropout: 0.1
      num_classes: 13
    loss: ce
    ignore_index: -1
    freeze_backbone: false
    freeze_decoder: false
    class_names:
      - Natural Vegetation
      - Forest
      - Corn
      - Soybeans
      - Wetlands
      - Developed / Barren
      - Open Water
      - Winter Wheat
      - Alfalfa
      - Fallow / Idle Cropland
      - Cotton
      - Sorghum
      - Other

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    # 2.0e-5 (not 2.e-5): PyYAML's 1.1 resolver rejects a bare "." before the
    # exponent and would load the value as a string instead of a float.
    lr: 2.0e-5

lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
    factor: 0.5
    patience: 5