{
"cells": [
{
"cell_type": "markdown",
"id": "0f0f79f6",
"metadata": {},
"source": [
"# AlphaGenome Variant Scoring Tutorial\n",
"\n",
"This notebook demonstrates how to score genetic variants using the AlphaGenome PyTorch model.\n",
"\n",
"## Prerequisites\n",
"\n",
"Before running this notebook, you need to:\n",
"\n",
"1. **Download required files** - See [variant_scoring/README.md](../../src/alphagenome_pytorch/variant_scoring/README.md) for detailed instructions:\n",
" - Model weights (`ag_pytorch_model.pth`)\n",
" - Track means (`track_means.pt`)\n",
" - Reference genome (hg38.fa with .fai index)\n",
" - Gene annotations (gencode.v46.annotation.parquet)\n",
" - PolyA annotations (gencode.v46.polyAs.linked.parquet) - optional\n",
" - Track metadata (track_metadata.parquet)\n",
"\n",
"2. **Set up annotation files** - Run the preprocessing scripts:\n",
" ```bash\n",
" # See README for download URLs and conversion commands\n",
" python scripts/convert_gtf_to_parquet.py --input gencode.v46.annotation.gtf --output gencode.v46.annotation.parquet\n",
" python scripts/preprocess_polya.py --metadata gencode.v46.metadata.PolyA_feature --gtf gencode.v46.annotation.parquet --output gencode.v46.polyAs.linked.parquet\n",
" ```\n",
"\n",
"3. **Update file paths** in the cells below to match your local setup."
]
},
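{
"cell_type": "markdown",
"id": "check-required-files-md",
"metadata": {},
"source": [
"Optionally, confirm that the files from step 1 are in place before loading anything. The cell below is a minimal sketch. `check_required_files` is a hypothetical helper, not part of `alphagenome_pytorch`. The paths are the ones used later in this notebook, so edit them if your layout differs. The sketch also looks for the `.fai` index next to the FASTA file."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-required-files",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"# Paths used later in this notebook; edit them to match your setup.\n",
"REQUIRED_FILES = [\n",
"    '../../ag_pytorch_model.pth',\n",
"    '../../data/annotations/hg38.fa',\n",
"    '../../data/annotations/hg38.fa.fai',  # FASTA index (e.g. from samtools faidx)\n",
"    '../../data/annotations/gencode.v46.annotation.parquet',\n",
"    '../../track_metadata.parquet',\n",
"]\n",
"# Only needed for PolyadenylationScorer.\n",
"OPTIONAL_FILES = ['../../data/annotations/gencode.v46.polyAs.linked.parquet']\n",
"\n",
"\n",
"def check_required_files(paths):\n",
"    \"\"\"Hypothetical helper: return the entries of `paths` that are not existing files.\"\"\"\n",
"    return [p for p in paths if not Path(p).is_file()]\n",
"\n",
"\n",
"missing = check_required_files(REQUIRED_FILES)\n",
"missing_optional = check_required_files(OPTIONAL_FILES)\n",
"for p in missing:\n",
"    print('Missing required file:', p)\n",
"for p in missing_optional:\n",
"    print('Missing optional file:', p)\n",
"if not missing:\n",
"    print('All required files found')"
]
},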
{
"cell_type": "markdown",
"id": "8de9a680",
"metadata": {},
"source": [
"## 1. Import Required Modules"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f155e5ab",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from alphagenome_pytorch.model import AlphaGenome\n",
"from alphagenome_pytorch.config import DtypePolicy\n",
"from alphagenome_pytorch.variant_scoring import (\n",
" VariantScoringModel, Variant, Interval,\n",
" CenterMaskScorer, OutputType, AggregationType,\n",
" get_recommended_scorers,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b61ff637",
"metadata": {},
"source": [
"## 2. Load AlphaGenome Model\n",
"\n",
"We load the model in **float32** precision for production use. Compared with bfloat16, this provides:\n",
"- Cleaner signals, with less rounding noise\n",
"- Standard precision for downstream analysis\n",
"\n",
"**Note**: Use `compute_dtype=torch.bfloat16` only if you need exact parity with the JAX/API reference for testing, or to fit the full model with the 1 Mb (1,048,576 bp) context on smaller hardware."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ad47f012",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Model loaded on GPU\n"
]
}
],
"source": [
"TORCH_WEIGHTS_PATH = '../../ag_pytorch_model.pth'\n",
"\n",
"# Initialize model in float32 (production use)\n",
"model = AlphaGenome(num_organisms=2, dtype_policy=DtypePolicy.full_float32())\n",
"checkpoint = torch.load(TORCH_WEIGHTS_PATH, map_location='cpu')\n",
"model.load_state_dict(checkpoint['state_dict'])\n",
"\n",
"model.eval()\n",
"\n",
"# Move to GPU if available\n",
"if torch.cuda.is_available():\n",
" model.cuda()\n",
" print(\"✓ Model loaded on GPU\")\n",
"else:\n",
" print(\"✓ Model loaded on CPU\")"
]
},
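{
"cell_type": "markdown",
"id": "precision-illustration-md",
"metadata": {},
"source": [
"The float32 choice above matters because bfloat16 keeps only 8 significant bits, versus 24 for float32. Rounding a value to bfloat16 can therefore change it by up to roughly 2^-8 ≈ 0.4%. The sketch below illustrates this on random numbers using plain PyTorch dtype casts. It measures storage precision only and is not a benchmark of AlphaGenome's predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "precision-illustration",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Random values standing in for model activations (illustrative only).\n",
"torch.manual_seed(0)\n",
"x = torch.randn(10_000, dtype=torch.float64)\n",
"\n",
"\n",
"def max_relative_error(dtype):\n",
"    \"\"\"Round x to `dtype`, cast back to float64 and return the worst relative error.\"\"\"\n",
"    roundtrip = x.to(dtype).to(torch.float64)\n",
"    return ((roundtrip - x).abs() / x.abs()).max().item()\n",
"\n",
"\n",
"err_fp32 = max_relative_error(torch.float32)\n",
"err_bf16 = max_relative_error(torch.bfloat16)\n",
"print(f'float32  max relative error: {err_fp32:.2e}')  # about 2**-24 at most\n",
"print(f'bfloat16 max relative error: {err_bf16:.2e}')  # about 2**-8 at most"
]
},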
{
"cell_type": "markdown",
"id": "cb688d74",
"metadata": {},
"source": [
"## 3. Create Variant Scoring Model\n",
"\n",
"The `VariantScoringModel` wraps the base model with variant scoring functionality.\n",
"It takes:\n",
"- **fasta_path**: Reference genome (with its `.fai` index) for sequence extraction\n",
"- **gtf_path**: Gene annotations for gene-based scorers (the Parquet file converted from the GENCODE GTF)\n",
"- **polya_path**: PolyA sites for `PolyadenylationScorer` (optional)\n",
"- **default_organism**: Organism to use by default (`human` here)\n",
"\n",
"Track names and metadata for tidy output are loaded afterwards with `load_all_metadata()`."
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "3521be18",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Scoring model initialized\n"
]
}
],
"source": [
"# Update these paths to match your local setup\n",
"# See README for instructions on downloading and converting annotations\n",
"scoring_model = VariantScoringModel(\n",
" model,\n",
" fasta_path=\"../../data/annotations/hg38.fa\",\n",
" gtf_path=\"../../data/annotations/gencode.v46.annotation.parquet\",\n",
" polya_path=\"../../data/annotations/gencode.v46.polyAs.linked.parquet\", # Optional\n",
" default_organism=\"human\",\n",
")\n",
"\n",
"# Load track metadata for annotated results\n",
"# These can be extracted from the model using scripts/extract_track_metadata.py\n",
"scoring_model.load_all_metadata('../../track_metadata.parquet')\n",
"print(\"✓ Scoring model initialized\")"
]
},
{
"cell_type": "markdown",
"id": "c88081db",
"metadata": {},
"source": [
"## 4. Define Variant and Interval\n",
"\n",
"- **Variant**: Uses VCF-style 1-based coordinates (`chrom:pos:ref>alt`)\n",
"- **Interval**: The model input window, here 131,072 bp (2^17 bp, requested with the `width=\"100KB\"` setting) centered on the variant"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3c034745",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variant: chr22:36201698:A>C\n",
"Interval: chr22:36136162-36267234\n",
"Interval width: 131,072bp\n"
]
}
],
"source": [
"# Define a test variant (chr22:36201698:A>C)\n",
"variant = Variant.from_str('chr22:36201698:A>C')\n",
"\n",
"# Create interval centered on variant\n",
"interval = Interval.centered_on('chr22', 36201698, width=\"100KB\")\n",
"\n",
"print(f\"Variant: {variant}\")\n",
"print(f\"Interval: {interval}\")\n",
"print(f\"Interval width: {interval.width:,}bp\")"
]
},
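{
"cell_type": "markdown",
"id": "interval-arithmetic-md",
"metadata": {},
"source": [
"The printed interval can be reproduced with plain arithmetic. The sketch below splits the `chrom:pos:ref>alt` string and places a window of `width` bases starting at `pos - width // 2`. With `width = 2**17`, this reproduces `chr22:36136162-36267234` from the cell above. `parse_variant` and `centered_window` are hypothetical helpers, not library functions. The sketch does not assert which coordinate convention `Interval` uses internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "interval-arithmetic",
"metadata": {},
"outputs": [],
"source": [
"def parse_variant(spec):\n",
"    \"\"\"Hypothetical helper: split 'chrom:pos:ref>alt' into (chrom, pos, ref, alt).\"\"\"\n",
"    chrom, pos, alleles = spec.split(':')\n",
"    ref, alt = alleles.split('>')\n",
"    return chrom, int(pos), ref, alt\n",
"\n",
"\n",
"def centered_window(pos, width):\n",
"    \"\"\"Hypothetical helper: (start, end) of a `width`-bp window with pos at offset width // 2.\"\"\"\n",
"    start = pos - width // 2\n",
"    return start, start + width\n",
"\n",
"\n",
"chrom, pos, ref, alt = parse_variant('chr22:36201698:A>C')\n",
"start, end = centered_window(pos, width=2**17)  # 2**17 = 131,072 bp\n",
"print(f'{chrom}:{start}-{end} ({end - start:,} bp)')  # chr22:36136162-36267234 (131,072 bp)"
]
},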
{
"cell_type": "markdown",
"id": "f104e3a8",
"metadata": {},
"source": [
"## 5. Score Variant with Recommended Scorers\n",
"\n",
"The `get_recommended_scorers()` function returns 19 pre-configured scorers that match the official AlphaGenome API:\n",
"\n",
"- **CenterMaskScorer** (12 configs): ATAC, DNase, CAGE, PRO-cap, ChIP-TF and ChIP-histone, each at two mask widths (501 bp and 2,001 bp)\n",
"- **GeneMaskLFCScorer** (1): RNA-seq log fold change\n",
"- **GeneMaskActiveScorer** (1): RNA-seq active allele\n",
"- **GeneMaskSplicingScorer** (2): Splice sites and usage\n",
"- **SpliceJunctionScorer** (1): Junction disruption\n",
"- **ContactMapScorer** (1): 3D chromatin contacts\n",
"- **PolyadenylationScorer** (1): Polyadenylation QTLs (human only)"
]
},
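{
"cell_type": "markdown",
"id": "scorer-intuition-md",
"metadata": {},
"source": [
"To build intuition for two of these scorer families, the sketch below implements simplified versions in NumPy:\n",
"\n",
"- **Center-mask score**: sum the ALT and REF predictions in a window centered on the variant and take the difference. Odd widths such as 501 give a window symmetric about the variant.\n",
"- **Gene-level log fold change**: the log2 ratio of summed ALT and REF predictions over a gene mask.\n",
"\n",
"The function names, the plain-sum aggregation and the pseudocount are illustrative assumptions. They are not necessarily the exact computations inside `CenterMaskScorer` or `GeneMaskLFCScorer`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "scorer-intuition",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def center_mask_score(ref, alt, variant_idx, width):\n",
"    \"\"\"ALT - REF, summed over a window of `width` positions centered on variant_idx.\n",
"\n",
"    ref, alt: arrays of shape (positions, tracks); returns one value per track.\n",
"    \"\"\"\n",
"    half = width // 2\n",
"    lo = max(variant_idx - half, 0)\n",
"    hi = min(variant_idx + half + 1, ref.shape[0])\n",
"    return alt[lo:hi].sum(axis=0) - ref[lo:hi].sum(axis=0)\n",
"\n",
"\n",
"def gene_log_fold_change(ref, alt, gene_mask, pseudocount=1e-3):\n",
"    \"\"\"log2(ALT / REF) of predictions summed over the positions in gene_mask, per track.\"\"\"\n",
"    ref_sum = ref[gene_mask].sum(axis=0)\n",
"    alt_sum = alt[gene_mask].sum(axis=0)\n",
"    return np.log2(alt_sum + pseudocount) - np.log2(ref_sum + pseudocount)\n",
"\n",
"\n",
"# Toy predictions: 11 positions x 2 tracks; ALT raises track 0 around the variant.\n",
"ref_pred = np.ones((11, 2))\n",
"alt_pred = ref_pred.copy()\n",
"alt_pred[4:7, 0] += 1.0\n",
"\n",
"print(center_mask_score(ref_pred, alt_pred, variant_idx=5, width=3))  # [3. 0.]\n",
"\n",
"gene_mask = np.zeros(11, dtype=bool)\n",
"gene_mask[3:8] = True  # hypothetical exon positions\n",
"print(gene_log_fold_change(ref_pred, alt_pred, gene_mask))  # track 0 > 0, track 1 == 0"
]
},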
{
"cell_type": "code",
"execution_count": 5,
"id": "adf34858",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Scored variant with 19 scorers\n"
]
}
],
"source": [
"# Score variant with all 19 recommended scorers\n",
"scores = scoring_model.score_variant(\n",
" interval=interval,\n",
" variant=variant,\n",
" scorers=get_recommended_scorers('human'),\n",
" organism='human',\n",
" to_cpu=True, # Move results to CPU for analysis\n",
")\n",
"\n",
"print(f\"✓ Scored variant with {len(scores)} scorers\")"
]
},
{
"cell_type": "markdown",
"id": "b3a72aa2",
"metadata": {},
"source": [
"## 6. Convert Results to Tidy DataFrame\n",
"\n",
"The `tidy_scores()` method converts results to a long-format DataFrame with:\n",
"- One row per track for each scorer (and for each gene or junction, for gene- and junction-based scorers)\n",
"- Variant and gene metadata\n",
"- Track names and metadata (biosample, assay type, etc.)\n",
"- Raw scores and quantile-normalized scores"
]
},
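{
"cell_type": "markdown",
"id": "tidy-format-illustration-md",
"metadata": {},
"source": [
"To see what the long format looks like, the sketch below builds one from a toy genes × tracks score matrix with plain pandas. It then joins per-track metadata on `track_index`, mirroring columns such as `gene_name`, `biosample_name` and `raw_score` in the real output. The gene names and scores are made up, and this is not how `tidy_scores()` is implemented internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tidy-format-illustration",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Made-up scores for 2 genes x 3 tracks.\n",
"gene_names = np.array(['GENE_A', 'GENE_B'])\n",
"scores = np.array([[0.10, -0.20, 0.05],\n",
"                   [0.00, 0.30, -0.10]])\n",
"track_meta = pd.DataFrame({\n",
"    'track_index': [0, 1, 2],\n",
"    'biosample_name': ['T-cell', 'B cell', 'natural killer cell'],\n",
"})\n",
"\n",
"# Wide (genes x tracks) -> long: one row per (gene, track) pair.\n",
"# scores.ravel() is row-major, so repeat gene names and tile track indices to match.\n",
"n_genes, n_tracks = scores.shape\n",
"tidy = pd.DataFrame({\n",
"    'gene_name': np.repeat(gene_names, n_tracks),\n",
"    'track_index': np.tile(track_meta['track_index'].to_numpy(), n_genes),\n",
"    'raw_score': scores.ravel(),\n",
"})\n",
"\n",
"# Attach per-track metadata by joining on track_index.\n",
"tidy = tidy.merge(track_meta, on='track_index', how='left')\n",
"print(tidy)\n",
"\n",
"# Example query: the track with the largest absolute score for each gene.\n",
"top = tidy.loc[tidy['raw_score'].abs().groupby(tidy['gene_name']).idxmax()]\n",
"print(top[['gene_name', 'biosample_name', 'raw_score']])"
]
},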
{
"cell_type": "code",
"execution_count": 6,
"id": "62f387d8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shape: (38780, 21)\n"
]
},
{
"data": {
"text/html": [
"
| | variant_id | scored_interval | gene_id | gene_name | gene_type | gene_strand | junction_Start | junction_End | output_type | variant_scorer | ... | track_strand | Assay title | ontology_curie | biosample_name | biosample_type | transcription_factor | histone_mark | gtex_tissue | raw_score | track_index |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | chr22:36201698:A>C | chr22:36136162-36267234 | None | None | None | None | NaN | NaN | atac | CenterMaskScorer(output=atac, width=501, agg=d... | ... | . | ATAC-seq | CL:0000084 | T-cell | primary_cell | None | None | None | -0.062569 | 0 |
| 1 | chr22:36201698:A>C | chr22:36136162-36267234 | None | None | None | None | NaN | NaN | atac | CenterMaskScorer(output=atac, width=501, agg=d... | ... | . | ATAC-seq | CL:0000100 | motor neuron | in_vitro_differentiated_cells | None | None | None | 0.004393 | 1 |
| 2 | chr22:36201698:A>C | chr22:36136162-36267234 | None | None | None | None | NaN | NaN | atac | CenterMaskScorer(output=atac, width=501, agg=d... | ... | . | ATAC-seq | CL:0000236 | B cell | primary_cell | None | None | None | 0.007725 | 2 |
| 3 | chr22:36201698:A>C | chr22:36136162-36267234 | None | None | None | None | NaN | NaN | atac | CenterMaskScorer(output=atac, width=501, agg=d... | ... | . | ATAC-seq | CL:0000623 | natural killer cell | primary_cell | None | None | None | -0.057389 | 3 |
| 4 | chr22:36201698:A>C | chr22:36136162-36267234 | None | None | None | None | NaN | NaN | atac | CenterMaskScorer(output=atac, width=501, agg=d... | ... | . | ATAC-seq | CL:0000624 | CD4-positive, alpha-beta T cell | primary_cell | None | None | None | -0.059652 | 4 |
5 rows × 21 columns
| | variant_id | scored_interval | gene_id | gene_name | gene_type | gene_strand | junction_Start | junction_End | output_type | variant_scorer | ... | track_strand | Assay title | ontology_curie | biosample_name | biosample_type | transcription_factor | histone_mark | gtex_tissue | raw_score | track_index |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | chr22:36201698:A>C | chr22:36136162-36267234 | None | None | None | None | None | None | atac | CenterMaskScorer(output=atac, width=2001, agg=... | ... | . | ATAC-seq | CL:0000084 | T-cell | primary_cell | None | None | None | -0.039216 | 0 |
| 1 | chr22:36201698:A>C | chr22:36136162-36267234 | None | None | None | None | None | None | atac | CenterMaskScorer(output=atac, width=2001, agg=... | ... | . | ATAC-seq | CL:0000100 | motor neuron | in_vitro_differentiated_cells | None | None | None | -0.001569 | 1 |
| 2 | chr22:36201698:A>C | chr22:36136162-36267234 | None | None | None | None | None | None | atac | CenterMaskScorer(output=atac, width=2001, agg=... | ... | . | ATAC-seq | CL:0000236 | B cell | primary_cell | None | None | None | 0.022250 | 2 |
| 3 | chr22:36201698:A>C | chr22:36136162-36267234 | None | None | None | None | None | None | atac | CenterMaskScorer(output=atac, width=2001, agg=... | ... | . | ATAC-seq | CL:0000623 | natural killer cell | primary_cell | None | None | None | -0.029534 | 3 |
| 4 | chr22:36201698:A>C | chr22:36136162-36267234 | None | None | None | None | None | None | atac | CenterMaskScorer(output=atac, width=2001, agg=... | ... | . | ATAC-seq | CL:0000624 | CD4-positive, alpha-beta T cell | primary_cell | None | None | None | -0.026243 | 4 |
5 rows × 21 columns
| | gene_name | gene_id | raw_score |
|---|---|---|---|
| 0 | MTCO1P20 | ENSG00000233764 | 0.005431 |
| 768 | MTCO2P20 | ENSG00000231576 | 0.004852 |
| 1536 | MTATP6P20 | ENSG00000237948 | 0.001218 |
| 2304 | MTCO3P20 | ENSG00000225557 | 0.002390 |
| 3072 | MTCYBP34 | ENSG00000237129 | 0.003161 |
| 3840 | MTND1P10 | ENSG00000229088 | -0.000389 |
| 4608 | ENSG00000288778 | ENSG00000288778 | 0.040432 |
| 5376 | APOL1 | ENSG00000100342 | 0.025417 |
| 6144 | ENSG00000279805 | ENSG00000279805 | 0.001093 |
| 6912 | APOL3 | ENSG00000128284 | 0.002691 |
| 7680 | APOL4 | ENSG00000100336 | 0.002197 |
| 8448 | APOL2 | ENSG00000128335 | 0.011930 |