{ "cells": [ { "cell_type": "markdown", "id": "cell-00", "metadata": {}, "source": [ "# Fine-tuning AlphaGenome's encoder on MPRA data\n", "\n", "This notebook shows how to train an MLP regression head on a **frozen** AlphaGenome encoder for one LentiMPRA cell line (HepG2) from\n", "[Agarwal et al. 2025](https://www.nature.com/articles/s41586-024-08430-9).\n", "\n", "**Why encoder-only?**\n", "MPRA constructs are short (~200–300 bp). The CNN encoder alone accepts short input sequences and produces\n", "contextual 1536-dimensional embeddings at 128 bp resolution, so each construct is represented by only a few embedding positions. Training a small task-specific head on these frozen embeddings reuses the sequence representations AlphaGenome learned at scale and adapts them to MPRA activity prediction without updating the backbone.\n", "\n", "**Steps:**\n", "1. Install dependencies\n", "2. Download HepG2 LentiMPRA data\n", "3. Define dataset, head, and training utilities\n", "4. Configure hyperparameters\n", "5. Load pretrained AlphaGenome backbone and freeze it\n", "6. Create `MPRAHead` (LayerNorm → flatten → MLP → scalar output)\n", "7. Build datasets and DataLoaders\n", "8. Train (encoder-only forward pass; only head parameters are updated)\n", "9. Plot loss and Pearson correlation curves, then evaluate on the test set" ] }, { "cell_type": "markdown", "id": "cell-01", "metadata": {}, "source": [ "## 1. Install dependencies" ] }, { "cell_type": "code", "execution_count": 1, "id": "cell-03", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PyTorch: 2.10.0+cu128\n", "CUDA available: True\n", "GPU: NVIDIA H100 80GB HBM3\n" ] } ], "source": [ "import torch\n", "\n", "print(f\"PyTorch: {torch.__version__}\")\n", "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", "if torch.cuda.is_available():\n", "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")" ] }, { "cell_type": "markdown", "id": "cell-04", "metadata": {}, "source": [ "## 2. 
Download HepG2 LentiMPRA data\n", "\n", "The HepG2 LentiMPRA training data are downloaded from\n", "[human_legnet](https://github.com/autosome-ru/human_legnet) (`datasets/original/HepG2.tsv`).\n", "The TSV contains one row per MPRA insert and strand, with columns:\n", "- `seq_id`: identifier of the regulatory element\n", "- `seq`: DNA sequence of the regulatory element\n", "- `fold`: cross-validation fold (1–10)\n", "- `rev`: strand flag (0 = forward; only forward-strand rows are kept)\n", "- `mean_value`: mean log-scale activity score across replicates" ] }, { "cell_type": "code", "execution_count": 2, "id": "cell-05", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloaded to ./data/legnet_lentimpra/HepG2.tsv\n", "Total forward-strand samples : 122,926\n", "Training samples (folds 2-9) : 98,336\n" ] }, { "data": { "text/html": [ "
| \n", " | seq_id | \n", "seq | \n", "mean_value | \n", "fold | \n", "rev | \n", "
|---|---|---|---|---|---|
| 0 | \n", "DNasePeakNoPromoter1 | \n", "AGGACCGGATCAACTCCTAACCCTAACCCTAACCCTAACCCTAACC... | \n", "-0.675 | \n", "2 | \n", "0 | \n", "
| 1 | \n", "DNasePeakNoPromoter1_Reversed: | \n", "AGGACCGGATCAACTCGTTCTCCTCAGCACAGACCCGGAGAGCACC... | \n", "-0.274 | \n", "2 | \n", "0 | \n", "
| 2 | \n", "DNasePeakNoPromoter10 | \n", "AGGACCGGATCAACTTGTTTCTTAGGAAAGGCGGCCAACCCAGGGT... | \n", "0.908 | \n", "10 | \n", "0 | \n", "
| 3 | \n", "DNasePeakNoPromoter10_Reversed: | \n", "AGGACCGGATCAACTGGTTAGAGCTCAAAGGTCACTCCGATGACAC... | \n", "0.276 | \n", "10 | \n", "0 | \n", "
| 4 | \n", "DNasePeakNoPromoter100 | \n", "AGGACCGGATCAACTATTACTCACACAAGACACACATTGTCTGCCG... | \n", "1.295 | \n", "9 | \n", "0 | \n", "