Named Outputs ============= AlphaGenome predicts thousands of genomic tracks — chromatin accessibility, transcription, histone modifications, TF binding, and more. The named outputs API lets you work with these tracks by biological meaning (tissue, assay, ontology) rather than raw channel indices. .. contents:: On this page :local: :depth: 1 Setup ----- .. code-block:: python from alphagenome_pytorch import AlphaGenome from alphagenome_pytorch.named_outputs import TrackMetadataCatalog model = AlphaGenome.from_pretrained("weights.pth", device="cuda") catalog = TrackMetadataCatalog.load_builtin() model.set_track_metadata_catalog(catalog) out = model.predict(dna_onehot, organism_index=0, named_outputs=True) ``out`` is a ``NamedOutputs`` object. Each output head (``out.atac``, ``out.rna_seq``, etc.) is a ``NamedOutputHead`` that holds predictions at one or more resolutions. Index by resolution to get a ``NamedTrackTensor`` — a tensor bundled with per-channel metadata: .. code-block:: python out.atac # NamedOutputHead (all resolutions) out.atac[128] # NamedTrackTensor at 128bp resolution out.atac[128].tensor # the raw torch.Tensor out.atac[128].tracks # tuple of TrackMetadata (one per channel) Output types and track counts ----------------------------- .. list-table:: :header-rows: 1 :widths: 18 10 12 12 12 * - Output - Resolutions - Human tracks - Mouse tracks - Raw dimension * - ``atac`` - 1bp, 128bp - 167 - 18 - 256 * - ``dnase`` - 1bp, 128bp - 305 - 67 - 384 * - ``procap`` - 1bp, 128bp - 12 - --- - 128 * - ``cage`` - 1bp, 128bp - 546 - 188 - 640 * - ``rna_seq`` - 1bp, 128bp - 667 - 173 - 768 * - ``chip_tf`` - 128bp - 1617 - 127 - 1664 * - ``chip_histone`` - 128bp - 1116 - 183 - 1152 * - ``contact_maps`` - 128bp - 28 - 8 - 28 Human/Mouse columns show the number of real (non-padding) tracks. The "Raw dimension" is the full tensor channel count — both organisms share the same dimensions, with padding filling the gap. Named outputs strip padding by default (see :ref:`padding-tracks`). Track metadata fields --------------------- Each track carries metadata that you can filter on: .. list-table:: :header-rows: 1 :widths: 25 75 * - Field - Description * - ``track_name`` - Human-readable identifier (e.g. ``"CL:0000084 ATAC-seq"``) * - ``strand`` - ``"+"``, ``"-"``, or ``"."`` (unstranded) * - ``ontology_curie`` - Cell/tissue ontology term (e.g. ``"UBERON:0002107"``) * - ``biosample_name`` - Sample name (e.g. ``"liver"``, ``"HeLa"``) * - ``biosample_type`` - ``"tissue"``, ``"cell_line"``, ``"primary_cell"``, etc. * - ``biosample_life_stage`` - ``"adult"``, ``"embryo"``, etc. * - ``assay_title`` - Assay description (e.g. ``"total RNA-seq"``) * - ``gtex_tissue`` - GTEx tissue name (e.g. ``"Artery_Aorta"``) * - ``histone_mark`` - Histone modification (e.g. ``"H3K27ac"``) * - ``transcription_factor`` - TF name (e.g. ``"CTCF"``) * - ``data_source`` - Data origin (e.g. ``"encode"``, ``"gtex"``) * - ``genetically_modified`` - Whether the sample was genetically modified * - ``nonzero_mean`` - Track mean (used for normalization) Access fields directly on ``TrackMetadata`` objects: .. code-block:: python track = out.atac[128].tracks[0] track.ontology_curie # Direct attribute access track.get('biosample_type') # Safe access (returns None if missing) track.has('genetically_modified') # True if field exists and is not None track.to_dict() # Serialize to plain dict Filtering tracks ---------------- The core method is ``.select()``, available on ``NamedTrackTensor``, ``NamedOutputHead``, and ``NamedOutputs``. It returns a new object with only the matching tracks — both the tensor and metadata are sliced together, so they stay in sync. By metadata field ^^^^^^^^^^^^^^^^^ .. code-block:: python # By assay total_rna = out.rna_seq[1].select(assay_title='total RNA-seq') total_rna.tensor # already sliced to matching channels total_rna.tracks # metadata in sync # By histone mark h3k27ac = out.chip_histone[128].select(histone_mark='H3K27ac') # By tissue aorta = out.splice_junctions[1].select(gtex_tissue='Artery_Aorta') # By ontology hepg2_rna = out.rna_seq[1].select(ontology_curie='EFO:0001187') liver_rna = out.rna_seq[1].select(ontology_curie='UBERON:0002107') Multiple conditions ^^^^^^^^^^^^^^^^^^^ .. code-block:: python # All kwargs are AND-ed together ctcf_unmodified = out.chip_tf[128].select( transcription_factor='CTCF', genetically_modified=None, # field=None matches missing/null values ) "Any of" matching ^^^^^^^^^^^^^^^^^ .. code-block:: python # Pass a list for OR logic within one field ctcf_or_foxa1 = out.chip_tf[128].select( transcription_factor=['CTCF', 'FOXA1'] ) Custom predicate ^^^^^^^^^^^^^^^^ .. code-block:: python # When kwargs aren't expressive enough, use a predicate function liver_related = out.rna_seq[128].select( predicate=lambda t: 'liver' in (t.get('biosample_name') or '').lower() ) Strand filtering ^^^^^^^^^^^^^^^^ .. code-block:: python out.rna_seq[1].select(strand='+') # positive out.rna_seq[1].select(strand='-') # negative out.rna_seq[1].select(strand='.') # unstranded out.rna_seq[1].select(strand=['+', '-']) # stranded (either) out.rna_seq[1].select(strand=['-', '.']) # non-positive out.rna_seq[1].select(strand=['+', '.']) # non-negative Masks and indices ----------------- For loss computation or manual tensor slicing, you can get boolean masks or integer indices without creating a new ``NamedTrackTensor``: .. code-block:: python # Boolean mask — useful for element-wise loss masking mask = out.rna_seq[128].mask(ontology_curie='UBERON:0002107') loss = ((preds - targets) ** 2 * mask).mean() # Integer indices — useful for gather/index_select indices = out.chip_tf[128].indices(transcription_factor='CTCF') selected = preds[..., indices] Resolution-independent queries ------------------------------ Track metadata (names, ontology, biosample, strand, etc.) is the same at all resolutions — only the tensor's sequence dimension differs. You can query metadata on the ``NamedOutputHead`` without choosing a resolution: .. code-block:: python head = out.rna_seq head.num_tracks # same at 1bp and 128bp head.to_dataframe() # pandas DataFrame, no resolution needed head.indices(strand='+') head.mask(strand='+') Filter at the head level to apply to all resolutions at once: .. code-block:: python plus_strand = out.rna_seq.select(strand='+') plus_strand[1].tensor # 1bp, already filtered plus_strand[128].tensor # 128bp, already filtered Both orderings produce identical results: .. code-block:: python # These are equivalent out.rna_seq.select(strand='+')[128].tensor out.rna_seq[128].select(strand='+').tensor Cross-head filtering -------------------- Filter across all heads and resolutions at once: .. code-block:: python tissue_tracks = out.select(biosample_type='tissue') for (output_name, resolution), ntt in tissue_tracks.items(): print(f"{output_name}@{resolution}bp: {ntt.num_tracks} tracks") Variant effect scoring ---------------------- Arithmetic operators (``+``, ``-``, ``*``, ``/``, ``abs``, negation) preserve metadata from the left operand: .. code-block:: python ref = model.predict(ref_onehot, organism_index=0, named_outputs=True) alt = model.predict(alt_onehot, organism_index=0, named_outputs=True) # Filter and diff in one expression chip_diff = ( alt.chip_histone[128].select(strand=['-', '.']) - ref.chip_histone[128].select(strand=['-', '.']) ) chip_diff.tensor # the difference tensor chip_diff.tracks # metadata preserved from alt .. _padding-tracks: Padding tracks -------------- The raw model tensors include padding channels so that both organisms share the same tensor dimensions. For example, the ATAC head always outputs 256 channels, but only 167 correspond to real human experiments — the remaining 89 are padding placeholders. **Named outputs strip padding by default:** .. code-block:: python out = model.predict(dna_onehot, organism_index=0, named_outputs=True) out.atac[128].num_tracks # 167 (real tracks only) This matches the behavior of the official JAX AlphaGenome API, which also strips padding before exposing metadata to users. Keeping padding (for training) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ During training you often need the full tensor shape for loss computation with a mask rather than slicing out channels: .. code-block:: python out = model.predict( dna_onehot, organism_index=0, named_outputs=True, include_padding=True, ) out.atac[128].num_tracks # 256 (includes padding) # Boolean mask: True = real track, False = padding mask = out.atac[128].padding_mask() loss = ((preds - targets) ** 2 * mask).mean() # Resolution-independent mask mask = out.atac.padding_mask() Stripping padding after the fact ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. code-block:: python out.strip_padding() # All heads → new NamedOutputs out.atac.strip_padding() # One head → new NamedOutputHead out.atac[128].strip_padding() # One tensor → new NamedTrackTensor Checking individual tracks ^^^^^^^^^^^^^^^^^^^^^^^^^^ .. code-block:: python for track in out.atac[128].tracks: print(track.track_name, track.is_padding) Loading and building metadata catalogs -------------------------------------- The built-in catalog ships with the package and contains metadata extracted from the official JAX AlphaGenome checkpoint: .. code-block:: python from alphagenome_pytorch.named_outputs import TrackMetadataCatalog catalog = TrackMetadataCatalog.load_builtin('human') # or 'mouse' # Inspect what's available catalog.outputs(organism=0) # ['atac', 'cage', 'chip_histone', ...] catalog.organisms # [0] You can also load from files or build programmatically: .. code-block:: python # From parquet / CSV / TSV catalog = TrackMetadataCatalog.from_file("my_metadata.parquet") # From a pandas DataFrame import pandas as pd df = pd.DataFrame({ 'track_name': ['sample_A', 'sample_B'], 'output_type': ['atac', 'atac'], 'organism': [0, 0], 'biosample_type': ['tissue', 'cell_line'], }) catalog = TrackMetadataCatalog.from_dataframe(df) # Programmatically from alphagenome_pytorch.named_outputs import TrackMetadata catalog = TrackMetadataCatalog() catalog.add_tracks( "rna_seq", [ TrackMetadata(0, "rna_seq", 0, "UBERON:0000948 total RNA-seq", extras={"strand": "+", "assay_title": "total RNA-seq"}), TrackMetadata(1, "rna_seq", 0, "UBERON:0000948 total RNA-seq", extras={"strand": "-", "assay_title": "total RNA-seq"}), ], organism=0, ) Exporting to pandas ^^^^^^^^^^^^^^^^^^^ .. code-block:: python df = out.rna_seq[128].to_dataframe() # One head, one resolution df = out.rna_seq.to_dataframe() # One head (resolution-independent) Allow empty results ^^^^^^^^^^^^^^^^^^^ By default, ``.select()`` raises ``ValueError`` if no tracks match. Pass ``allow_empty=True`` to get an empty result instead: .. code-block:: python result = out.atac[128].select(biosample_name='nonexistent', allow_empty=True) result.num_tracks # 0 .. _jax-comparison: Comparison with JAX AlphaGenome ------------------------------- This section is for users migrating from the JAX ``alphagenome`` / ``alphagenome_research`` packages. Loading metadata ^^^^^^^^^^^^^^^^ .. list-table:: :header-rows: 1 :widths: 50 50 * - JAX - PyTorch * - .. code-block:: python model = dna_model.create_from_kaggle('all_folds') metadata = model.output_metadata( dna_model.Organism.HOMO_SAPIENS ) - .. code-block:: python catalog = TrackMetadataCatalog.load_builtin('human') model.set_track_metadata_catalog(catalog) Making predictions ^^^^^^^^^^^^^^^^^^ .. list-table:: :header-rows: 1 :widths: 50 50 * - JAX - PyTorch * - .. code-block:: python predictions = model.predict_interval( interval, requested_outputs={ dna_model.OutputType.RNA_SEQ, }, ontology_terms=['EFO:0001187'], ) predictions.rna_seq.metadata - .. code-block:: python out = model.predict( dna_onehot, organism_index=0, named_outputs=True, ) out.rna_seq[1].to_dataframe() Filtering ^^^^^^^^^ .. list-table:: :header-rows: 1 :widths: 50 50 * - JAX - PyTorch * - .. code-block:: python # Boolean mask with pandas predictions.rna_seq.filter_tracks( (predictions.rna_seq.metadata[ 'Assay title' ] == 'total RNA-seq').values ) # Multiple conditions predictions.chip_tf.filter_tracks( ( (metadata['transcription_factor'] == 'CTCF') & (metadata[ 'genetically_modified' ].isnull()) ).values ) - .. code-block:: python # Keyword arguments out.rna_seq[1].select( assay_title='total RNA-seq' ) # Multiple conditions out.chip_tf[128].select( transcription_factor='CTCF', genetically_modified=None, ) Strand filtering ^^^^^^^^^^^^^^^^ .. list-table:: :header-rows: 1 :widths: 50 50 * - JAX - PyTorch * - .. code-block:: python # 6 dedicated methods predictions.rna_seq \ .filter_to_positive_strand() predictions.rna_seq \ .filter_to_nonpositive_strand() predictions.splice_junctions \ .filter_to_strand('+') - .. code-block:: python # All via select() out.rna_seq[1].select(strand='+') out.rna_seq[1].select( strand=['-', '.'] ) Padding ^^^^^^^ .. list-table:: :header-rows: 1 :widths: 50 50 * - JAX - PyTorch * - .. code-block:: python # Manual boolean mask padding = metadata.padding mask = ~padding[OutputType.ATAC] # Or via create_track_masks() masks = metadata_lib \ .create_track_masks( metadata, requested_outputs={...}, requested_ontologies=None, ) - .. code-block:: python # Stripped by default out.atac[128].num_tracks # 167 # Or keep + mask out = model.predict( dna, 0, named_outputs=True, include_padding=True, ) mask = out.atac.padding_mask() track.is_padding # per track Feature comparison table ^^^^^^^^^^^^^^^^^^^^^^^^ .. list-table:: :header-rows: 1 :widths: 30 30 30 * - Feature - JAX - PyTorch * - Load metadata - ``model.output_metadata()`` - ``TrackMetadataCatalog.load_builtin()`` * - Filter by field - ``.filter_tracks(bool_mask)`` - ``.select(**criteria)`` * - Filter null/missing - ``metadata['col'].isnull()`` - ``.select(field=None)`` * - Access metadata field - ``metadata['field']`` - ``track.field`` * - Safe field access - --- - ``track.get('field', default)`` * - Check field exists - --- - ``track.has('field')`` * - Strand filtering - ``.filter_to_strand('+')`` - ``.select(strand='+')`` * - Tissue filtering - ``.filter_by_tissue(...)`` - ``.select(gtex_tissue=...)`` * - Get indices - manual numpy - ``.indices()`` * - Get boolean mask - manual numpy - ``.mask()`` * - Padding detection - ``name.str.lower() == 'padding'`` - ``track.is_padding`` * - Strip padding - ``metadata[~metadata['padding']]`` - auto / ``.strip_padding()`` * - Padding mask - ``~metadata.padding[type]`` - ``.padding_mask()`` * - To DataFrame - ``.metadata`` property - ``.to_dataframe()`` * - Arithmetic - direct on objects - direct on objects * - Cross-head filtering - --- - ``.select(**criteria)`` * - Allow empty results - --- - ``.select(allow_empty=True)`` * - Load from DataFrame - --- - ``.from_dataframe(df)`` Design notes ^^^^^^^^^^^^ **Why metadata and tensor are bundled** (``NamedTrackTensor``): After any ``.select()`` call, the returned tensor and metadata are guaranteed to be aligned — no manual index tracking needed. **Why metadata is resolution-independent** (``NamedOutputHead``): Track metadata doesn't change between 1bp and 128bp — only the sequence dimension differs. ``NamedOutputHead`` lets you query metadata and filter without choosing a resolution first.