Recipe: Fine-tuning for Missingness
This recipe walks through fine-tuning a cxt model to handle missing or inaccessible genomic sites. The result is a model that encodes per-window missingness directly into its source tensor, rather than treating unobserved sites as if they had been sequenced and found to carry no mutations.
When to fine-tune for missingness
The pretrained broad and broad_w200 models assume every base in each
window is observable. When applied to empirical data with substantial
inaccessible regions (low-coverage masking, repetitive-element filters,
accessibility masks), they can produce biased TMRCA estimates in windows
where a large fraction of sites is missing.
Fine-tuning for missingness teaches the model what “no data” looks like at each window scale, so that it can distinguish genuine absence of mutations from unobserved sequence.
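A back-of-envelope calculation shows where the bias comes from. Under the infinite-sites model, a pair with TMRCA T generations is expected to differ at about 2μTL sites in a window of L observed bases, where μ is the per-base, per-generation mutation rate. If masked bases are silently counted as observed, L is overstated and T is underestimated. The sketch below illustrates only this arithmetic. It is not how cxt estimates TMRCA, and the TMRCA value is made up for illustration.

```python
def expected_differences(tmrca: float, observed_len: float, mu: float) -> float:
    """Expected pairwise differences under infinite sites: 2 * mu * T * L."""
    return 2.0 * mu * tmrca * observed_len


def tmrca_estimate(n_diffs: float, assumed_len: float, mu: float) -> float:
    """Solve n_diffs = 2 * mu * T * L for T, given an assumed number of observed bases L."""
    return n_diffs / (2.0 * mu * assumed_len)


mu = 3.5e-9              # per-base, per-generation rate (the mutation_rate used in Step 5)
window_len = 200         # bases per window
true_tmrca = 1_000_000   # illustrative value, in generations

for missing_frac in (0.0, 0.25, 0.5):
    observed_len = window_len * (1.0 - missing_frac)
    # mutations can only be seen on bases that were actually sequenced
    d = expected_differences(true_tmrca, observed_len, mu)
    ignoring_mask = tmrca_estimate(d, window_len, mu)   # treats missing bases as observed
    using_mask = tmrca_estimate(d, observed_len, mu)    # counts only callable bases
    print(f"missing={missing_frac:.2f}  ignoring mask: {ignoring_mask:,.0f}  using mask: {using_mask:,.0f}")
```

Ignoring the mask scales the estimate by (1 − missing fraction), so windows with heavy masking look spuriously young.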
Typical use cases:
- Whole-genome sequencing with accessibility masks (e.g. Ag1000G, 1000 Genomes strict-mask regions).
- Reduced-representation or capture data where only a subset of bases is sequenced.
- Any dataset where the callable fraction varies substantially across the genome.
Overview
The pipeline has four stages:

1. Obtain a bitmask – a per-base boolean array that marks each position as accessible or inaccessible.
2. Simulate training data – or reuse tree sequences you have already simulated from the stdpopsim catalog.
3. Preprocess with the bitmask – the --bitmask flag injects random missingness crops from your mask into each training example.
4. Fine-tune – warm-start from the broad_w200 checkpoint and train on the missingness-augmented dataset.

After fine-tuning, the model can be used for inference with the
missingness_bitmask argument to cxt.translate().
Step 1: Prepare a bitmask
A bitmask is a 1-D boolean NumPy array where True means accessible
and False means inaccessible. It should be at least as long as
the chromosome you plan to analyse. During preprocessing, a random crop of
length sequence_length is drawn from the bitmask for each training
example.
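The crop step can be pictured with a few lines of NumPy. This is a minimal sketch of drawing one random window from a chromosome-length bitmask, written for illustration only: draw_crop is a hypothetical helper, not part of cxt, and the real preprocessing code may choose crop offsets differently.

```python
import numpy as np


def draw_crop(bitmask: np.ndarray, sequence_length: int, rng: np.random.Generator) -> np.ndarray:
    """Return a random contiguous slice of length sequence_length (hypothetical helper)."""
    if bitmask.ndim != 1 or bitmask.dtype != bool:
        raise ValueError("bitmask must be a 1-D boolean array")
    if len(bitmask) < sequence_length:
        raise ValueError("bitmask is shorter than sequence_length")
    # uniform start position such that the crop always fits inside the bitmask
    start = int(rng.integers(0, len(bitmask) - sequence_length + 1))
    return bitmask[start:start + sequence_length]


rng = np.random.default_rng(12345)
# toy chromosome: True = accessible, with one large inaccessible block
toy = np.ones(1_000_000, dtype=bool)
toy[200_000:450_000] = False
crop = draw_crop(toy, 100_000, rng)
print(crop.shape, f"accessible fraction in crop: {crop.mean():.3f}")
```

Because each training example gets a different crop, the model sees the full range of missingness patterns present in your mask rather than a single fixed one.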
From an Ag1000G-style accessibility file:
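```python
import numpy as np

acc = np.load("accessibility.npz")["access_2L"]  # True = accessible
# The Step 1 convention is True = accessible, so the array is saved unchanged;
# invert it (~acc) only where True = inaccessible is required, as in Step 5
np.savez_compressed("my_bitmask.npz", access_2L=acc)
```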
From a BED file of callable regions:
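```python
import numpy as np

chrom = "2L"  # must match the chromosome names used in the BED file
chrom_length = 49_364_325  # e.g. Ag chr2L
mask = np.zeros(chrom_length, dtype=bool)  # False = inaccessible unless listed as callable
with open("callable_regions.bed") as fh:
    for line in fh:
        if not line.strip() or line.startswith(("#", "track", "browser")):
            continue  # skip blank, comment and header lines
        name, start, end = line.split()[:3]
        if name != chrom:
            continue  # keep only intervals on this chromosome
        # BED intervals are 0-based and half-open, matching Python slicing
        mask[int(start):int(end)] = True
np.savez_compressed("my_bitmask.npz", access_2L=mask)
```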
The array's key in the .npz file should match the key that preprocessing
reads (access_2L by default).
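One way to confirm the key before preprocessing is to reopen the file and list its arrays via NumPy's NpzFile.files. The sketch below writes a small toy mask to a temporary directory so it runs on its own; describe_bitmask is a hypothetical helper, and for real use you would point it at your own bitmask file.

```python
import os
import tempfile

import numpy as np


def describe_bitmask(path: str, key: str = "access_2L") -> dict:
    """Open an .npz bitmask and report its keys, length, dtype and accessible fraction."""
    with np.load(path) as npz:
        if key not in npz.files:
            raise KeyError(f"{key!r} not found; available keys: {npz.files}")
        mask = npz[key]
        return {
            "keys": list(npz.files),
            "length": int(mask.shape[0]),
            "dtype": str(mask.dtype),
            "accessible_fraction": float(mask.mean()),
        }


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "my_bitmask.npz")
    toy = np.zeros(1_000, dtype=bool)
    toy[100:600] = True  # True = accessible, as in Step 1
    np.savez_compressed(path, access_2L=toy)
    info = describe_bitmask(path)
    print(info)
```

A dtype other than bool, or an accessible fraction near 0 or 1, is a quick hint that the mask was inverted or loaded from the wrong key.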
Step 2: Simulate or gather training data
Fine-tuning for missingness uses the same simulated tree sequences as the
broad_w200 model – typically the 13 high-\(N_e\) stdpopsim species.
If you have already simulated these (see Simulation), point
DATA_DIR_LP at the directory containing them.
If starting from scratch, generate a minimal set:
```bash
DATA_DIR=/path/to/training_data
SIM="python cxt/simulation_ts_only.py"

# A. gambiae (100 simulations)
${SIM} --num_processes 50 --num_samples 100 \
    --data_dir ${DATA_DIR}/stdpopsim/v0.2/stdpopsim_anogam \
    --scenario stdpopsim_anogam

# A. aegypti (300 simulations)
${SIM} --num_processes 50 --num_samples 300 \
    --data_dir ${DATA_DIR}/stdpopsim/v0.2/stdpopsim_aedaeg \
    --scenario stdpopsim_aedaeg
```
Add as many species as needed. The more diversity in the training set, the more robust the resulting model. See Simulation for the full list.
Step 3: Preprocess with the bitmask
Use python -m cxt.preprocess with the --bitmask flag. This does two
things for each training example:

1. Draws a random crop from the bitmask and deletes the inaccessible intervals from the tree sequence before computing the SFS.
2. Encodes per-window missingness fractions at each scale into the source tensor (channels X[0,:,:,0] and X[1,:,:,0]).
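Both operations can be illustrated on a single crop. The sketch below assumes True = accessible, as in Step 1. It converts a crop into half-open [start, end) intervals of inaccessible sites (the form accepted by tskit's TreeSequence.delete_intervals) and computes the fraction of missing bases in each window. The function names are hypothetical and this is not the cxt implementation; in particular, the window sizes behind the two scales in X[0] and X[1] are not given here, so only one scale is shown.

```python
import numpy as np


def inaccessible_intervals(crop: np.ndarray) -> np.ndarray:
    """Half-open [start, end) runs where crop is False (inaccessible)."""
    missing = (~crop).astype(np.int8)
    # pad with zeros so runs touching either edge still produce a start and an end
    edges = np.diff(np.concatenate(([0], missing, [0])))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)
    return np.column_stack((starts, ends))


def window_missing_fraction(crop: np.ndarray, window_size: int) -> np.ndarray:
    """Fraction of inaccessible bases in each non-overlapping window."""
    if len(crop) % window_size != 0:
        raise ValueError("crop length must be a multiple of window_size")
    return (~crop).reshape(-1, window_size).mean(axis=1)


crop = np.ones(1_000, dtype=bool)
crop[0:50] = False
crop[350:500] = False
print(inaccessible_intervals(crop))
print(window_missing_fraction(crop, 200))
```

With --window_size 200 and --sequence_length 100000, the same reshape would give 500 windows per crop.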
```bash
python -m cxt.preprocess \
    --base_dir ${DATA_DIR_LP} \
    --out_subdir processed_small_window_missing_data \
    --window_size 200 \
    --sequence_length 100000 \
    --num_pairs 200 \
    --train_ratio 0.9 \
    --global_seed 12345 \
    --num_workers 75 \
    --skip_existing \
    --bitmask /path/to/my_bitmask.npz
```
This produces a dataset with the same structure as any other preprocessed
directory (train/ and test/ subdirectories with X.npy,
y.npy, pairs.npy, and meta.json), except that X now
contains the missingness channels.
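Before training, it can be worth confirming that preprocessing produced every expected file. The sketch below checks the layout listed above using only the standard library and NumPy. It does not interpret the contents of X.npy beyond reporting its shape, and check_processed_dir and x_shape are hypothetical helpers; the example builds a toy directory so it runs on its own.

```python
import json
import os
import tempfile

import numpy as np

EXPECTED_FILES = ("X.npy", "y.npy", "pairs.npy", "meta.json")


def check_processed_dir(root: str) -> dict:
    """Return {split: [missing files]} for the train/ and test/ subdirectories."""
    report = {}
    for split in ("train", "test"):
        split_dir = os.path.join(root, split)
        report[split] = [f for f in EXPECTED_FILES if not os.path.isfile(os.path.join(split_dir, f))]
    return report


def x_shape(root: str, split: str = "train") -> tuple:
    """Shape of X.npy, memory-mapped so a large array is not read into RAM."""
    return np.load(os.path.join(root, split, "X.npy"), mmap_mode="r").shape


with tempfile.TemporaryDirectory() as tmp:
    # toy layout: train/ is complete, test/ is missing meta.json
    for split in ("train", "test"):
        os.makedirs(os.path.join(tmp, split))
        for name in ("X.npy", "y.npy", "pairs.npy"):
            np.save(os.path.join(tmp, split, name), np.zeros((2, 3)))
    with open(os.path.join(tmp, "train", "meta.json"), "w") as fh:
        json.dump({}, fh)
    report = check_processed_dir(tmp)
    shape = x_shape(tmp)
    print(report, shape)
```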
Step 4: Fine-tune from broad_w200
Fine-tuning warm-starts from the broad_w200 checkpoint and continues
training on the missingness-augmented data at a reduced learning rate. The
w200_wmissing preset automatically sets mask_singletons=False, which
is required for missingness-aware models.
```bash
python -m cxt.train \
    --model w200_wmissing \
    --dataset-path /path/to/processed_small_window_missing_data \
    --gpus 0 1 \
    --epochs 2 \
    --lr 3e-5 \
    --checkpoint broad_w200
```
The --checkpoint argument accepts either a local path to a .ckpt
file or the name of a pretrained model type (e.g. broad_w200), in which
case the checkpoint is downloaded automatically.
Training produces a checkpoint under lightning_logs/ (or the directory
set by --log-dir).
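Step 5 needs the path of the checkpoint written by training, which follows the pattern lightning_logs/version_X/checkpoints/*.ckpt. Assuming that layout, a small helper can pick the most recently modified checkpoint. latest_checkpoint is a hypothetical helper, not part of cxt, and the checkpoint file names in the toy example are made up.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def latest_checkpoint(log_dir: str = "lightning_logs") -> Optional[Path]:
    """Most recently modified .ckpt under <log_dir>/version_*/checkpoints/, or None."""
    ckpts = list(Path(log_dir).glob("version_*/checkpoints/*.ckpt"))
    if not ckpts:
        return None
    return max(ckpts, key=lambda p: p.stat().st_mtime)


with tempfile.TemporaryDirectory() as tmp:
    runs = ((1_000, "version_0", "epoch=0-step=472.ckpt"),
            (2_000, "version_1", "epoch=1-step=944.ckpt"))
    for mtime, version, name in runs:
        ckpt_dir = Path(tmp, version, "checkpoints")
        ckpt_dir.mkdir(parents=True)
        ckpt = ckpt_dir / name
        ckpt.touch()
        os.utime(ckpt, (mtime, mtime))  # explicit timestamps keep the example deterministic
    found = latest_checkpoint(tmp)
    rel = found.relative_to(tmp).as_posix()
    print(rel)
```

Selecting by modification time rather than by version number also picks up a checkpoint re-saved into an older version directory; sort on the parsed version number instead if that is what you want.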
Step 5: Run inference with the fine-tuned model
Load the checkpoint and pass missingness_bitmask to
cxt.translate():
```python
import numpy as np
import cxt

# replace version_X / epoch / step with the checkpoint written by your training run
model = cxt.load_model(
    "/path/to/lightning_logs/version_X/checkpoints/epoch=1-step=944.ckpt",
    device="cuda",
)

bitmask = np.load("my_bitmask.npz")["access_2L"]  # True = accessible (Step 1)
unaccessible = ~bitmask  # missingness_bitmask expects True = inaccessible

tmrca, index_map = cxt.translate(
    "genotypes.vcf",
    model,
    pivot_pairs=[(0, 1), (2, 3)],
    blocks=[(0, 100_000)],
    devices=["cuda:0"],
    missingness_bitmask=unaccessible,
    mutation_rate=3.5e-9,
    B=128,
    build_workers=8,
)
```
The missingness_bitmask should be a boolean array of shape
(chromosome_length,) where True marks inaccessible positions.
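Because the bitmask saved in Step 1 uses True for accessible sites, it has to be inverted before it is passed as missingness_bitmask. The sketch below wraps that inversion with a few sanity checks (one dimension, boolean dtype, long enough to cover every block). to_missingness_bitmask is a hypothetical helper, and the checks are suggestions rather than requirements documented by cxt.

```python
import numpy as np


def to_missingness_bitmask(accessible: np.ndarray, blocks) -> np.ndarray:
    """Invert an accessibility mask (True = accessible) into True = inaccessible."""
    accessible = np.asarray(accessible)
    if accessible.ndim != 1:
        raise ValueError("expected a 1-D array of shape (chromosome_length,)")
    if accessible.dtype != bool:
        raise TypeError(f"expected a boolean array, got {accessible.dtype}")
    furthest_end = max(end for _, end in blocks)
    if furthest_end > accessible.shape[0]:
        raise ValueError(
            f"blocks extend to {furthest_end} but the mask has length {accessible.shape[0]}"
        )
    return ~accessible


acc = np.ones(200_000, dtype=bool)
acc[10_000:20_000] = False
unaccessible = to_missingness_bitmask(acc, blocks=[(0, 100_000)])
print(unaccessible.dtype, unaccessible.shape, int(unaccessible.sum()))
```

Rejecting non-boolean input matters here: ~ on an integer 0/1 array is a bitwise NOT (giving -1 and -2), not a logical inversion.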
Optional: Adapter for small sample sizes
If your empirical dataset has fewer samples than the 50 haplotypes the model was trained on, you can add an adapter on top of the missingness model.
Preprocess adapter data (10 samples, 20 pairs):
```bash
python -m cxt.preprocess \
    --base_dir ${DATA_DIR_LP} \
    --out_subdir processed_small_window_missing_data_n10 \
    --window_size 200 \
    --sequence_length 100000 \
    --num_pairs 20 \
    --simplify_first_n_samples 10 \
    --train_ratio 0.9 \
    --global_seed 12345 \
    --num_workers 75 \
    --skip_existing \
    --bitmask /path/to/my_bitmask.npz
```
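With 10 samples there are only C(10, 2) = 45 distinct sample pairs, so 20 pairs per example covers a bit under half of them. The sketch below shows that arithmetic and one way to draw 20 distinct pairs. It illustrates the combinatorics only: draw_pairs is a hypothetical helper, and cxt's preprocessing may choose pairs differently.

```python
import itertools
import math

import numpy as np

n_samples = 10
num_pairs = 20

all_pairs = list(itertools.combinations(range(n_samples), 2))
n_possible = math.comb(n_samples, 2)  # number of unordered pairs of distinct samples


def draw_pairs(n: int, k: int, seed: int) -> list:
    """Draw k distinct unordered sample pairs out of n samples (hypothetical helper)."""
    pairs = list(itertools.combinations(range(n), 2))
    if k > len(pairs):
        raise ValueError(f"only {len(pairs)} distinct pairs exist for {n} samples")
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(pairs), size=k, replace=False)  # sample without replacement
    return [pairs[i] for i in sorted(idx)]


chosen = draw_pairs(n_samples, num_pairs, seed=12345)
print(n_possible, len(chosen), chosen[:5])
```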
Train the adapter (two-stage: resume from a broad+adapter checkpoint
that already learned the 10 → 50 sample mapping):
```bash
python -m cxt.train \
    --model w200_wmissing \
    --adapter \
    --adapter-samples 10 \
    --resume-adapter /path/to/broad_adapter_epoch=2-step=792.ckpt \
    --dataset-path /path/to/processed_small_window_missing_data_n10 \
    --gpus 0 1 \
    --epochs 10 \
    --lr 3e-5
```
Inference with the adapter:
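```python
import cxt

adapter_model = cxt.load_model(
    "/path/to/w200_wmissing_adapter.ckpt",
    device="cuda",
)

# ts_small: your input data for the 10 sampled haplotypes (not defined in this recipe)
# unaccessible: the inverted bitmask (True = inaccessible) built in Step 5
tmrca, index_map = cxt.translate(
    ts_small,
    adapter_model.backbone,
    pivot_pairs=[(0, 1)],
    blocks=[(0, 100_000)],
    devices=["cuda:0"],
    missingness_bitmask=unaccessible,
    mutation_rate=3.5e-9,
    adapter=adapter_model.adapter,
)
```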
See Mosquito: 2L Inversion and RDL Sweep for a complete worked example using both the missingness model and its adapter on the Ag1000G dataset.