# Training

All `cxt` model variants are trained using a single unified script
(`python -m cxt.train`) built on PyTorch Lightning. This page documents the
exact commands that reproduce every checkpoint shipped with the package.
## Checkpoint dependency graph

The six released checkpoints form a dependency tree. `narrow` and `broad` are trained from scratch; every other model starts from the checkpoint directly above it in the tree:

```text
narrow   (from scratch on "processed_narrow"; constant only)
broad    (from scratch on "processed")
├── broad+adapter   (adapter on "processed_n10", frozen broad backbone)
│   └── w200_wmissing_adapter   (--resume-adapter on "processed_small_window_missing_data_n10")
└── broad_w200   (fine-tuned on "processed_small_window")
    └── w200_wmissing   (fine-tuned on "processed_small_window_missing_data")
```
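Because each fine-tuned model needs its parent checkpoint on disk first, the graph fixes a valid training order. The sketch below transcribes the graph into a Python dict and uses the standard-library `graphlib` to derive an order in which every parent precedes its children. `PARENTS`, `training_order` and `lineage` are illustrative names, not part of `cxt`.

```python
from typing import Optional
from graphlib import TopologicalSorter

# Parent of each released checkpoint, transcribed from the dependency graph.
# None marks a model trained from scratch.
PARENTS: dict[str, Optional[str]] = {
    "narrow": None,
    "broad": None,
    "broad+adapter": "broad",
    "w200_wmissing_adapter": "broad+adapter",
    "broad_w200": "broad",
    "w200_wmissing": "broad_w200",
}


def training_order(parents: dict[str, Optional[str]]) -> list[str]:
    """Return the models in an order where every parent precedes its children."""
    # graphlib expects {node: predecessors}; a model's only predecessor is its parent.
    graph = {model: ([] if parent is None else [parent]) for model, parent in parents.items()}
    return list(TopologicalSorter(graph).static_order())


def lineage(model: str, parents: dict[str, Optional[str]]) -> list[str]:
    """Chain of checkpoints from the from-scratch root down to `model` (inclusive)."""
    chain = [model]
    while parents[chain[-1]] is not None:
        chain.append(parents[chain[-1]])
    return chain[::-1]


order = training_order(PARENTS)
print(order)
print(lineage("w200_wmissing_adapter", PARENTS))
# → ['broad', 'broad+adapter', 'w200_wmissing_adapter']
```

`static_order()` may place independent models such as `narrow` anywhere in the sequence; any order it returns is a valid training sequence.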
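## Training data summary

Which dataset each model trains on, and whether it starts from another checkpoint:

| Model | Fine-tuning | Dataset | Source scenarios | Window | Samples | Pairs |
|---|---|---|---|---|---|---|
| `narrow` | No (from scratch) | `processed_narrow` | constant only | w2000 | 50 | 200 |
| `broad` | No (from scratch) | `processed` | all | w2000 | 50 | 200 |
| `broad_w200` | Yes ← `broad` | `processed_small_window` | 13 high-Ne stdpopsim | w200 | 50 | 200 |
| `broad+adapter` | Yes ← `broad` | `processed_n10` | all | w2000 | 10 | 20 |
| `w200_wmissing` | Yes ← `broad_w200` | `processed_small_window_missing_data` | 13 high-Ne stdpopsim | w200 | 50 | 200 + bitmask |
| `w200_wmissing_adapter` | Yes ← `broad+adapter` | `processed_small_window_missing_data_n10` | 13 high-Ne stdpopsim | w200 | 10 | 20 + bitmask |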
## CLI reference

```text
python -m cxt.train \
  --model <preset_name> \
  --dataset-path <preprocessed_dir> \
  --gpus <gpu_ids> \
  --epochs <n> \
  [--lr <learning_rate>] \
  [--batch-size <bs>] \
  [--grad-accum <steps>] \
  [--checkpoint <path_or_model_type>] \
  [--adapter] \
  [--adapter-samples <n>] \
  [--adapter-bottleneck <dim>] \
  [--adapter-dropout <float>] \
  [--resume-adapter <ckpt_path>] \
  [--log-dir <directory>]
```

Key arguments:

- `--model`: preset name (`narrow`, `broad`, `broad_w200`, `w200_wmissing`)
- `--dataset-path`: path to a preprocessed dataset (must contain `train/` and `test/` subdirectories)
- `--gpus`: GPU indices (e.g. `0 1`)
- `--checkpoint`: warm-start from a checkpoint (for fine-tuning)
- `--adapter`: enable adapter training with a frozen backbone
- `--adapter-samples`: adapter input dimension (sample count)
- `--resume-adapter`: resume adapter training from a full `LitAdapterDecoder` checkpoint (loads both backbone and adapter weights, not just the backbone). Used for two-stage adapter training (e.g. training `w200_wmissing_adapter` from `broad+adapter`).
- `--log-dir`: root directory for `lightning_logs/` (default: current working directory). Use this to redirect checkpoints and TensorBoard logs to an isolated directory.
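When scripting several runs, it can help to assemble the argument list programmatically instead of editing shell lines by hand. This is a minimal sketch: `build_train_argv` is a hypothetical helper, not part of `cxt`, and it only knows the flag names from the synopsis above. It converts keyword names such as `resume_adapter` to `--resume-adapter` and treats `True` as a bare switch like `--adapter`.

```python
import shlex


def build_train_argv(model: str, dataset_path: str, gpus: list[int], epochs: int,
                     **optional: object) -> list[str]:
    """Hypothetical helper: argv for `python -m cxt.train`.

    Optional keywords map to dashed flags (resume_adapter -> --resume-adapter).
    True emits a bare switch (e.g. --adapter); None or False omits the flag.
    """
    argv = [
        "python", "-m", "cxt.train",
        "--model", model,
        "--dataset-path", dataset_path,
        "--gpus", *(str(g) for g in gpus),
        "--epochs", str(epochs),
    ]
    for key, value in optional.items():
        if value is None or value is False:
            continue
        flag = "--" + key.replace("_", "-")
        if value is True:
            argv.append(flag)
        else:
            argv.extend([flag, str(value)])
    return argv


# Step 6 (w200_wmissing_adapter) expressed with the helper.
argv = build_train_argv(
    "w200_wmissing",
    "/path/to/processed_small_window_missing_data_n10",
    gpus=[0, 1],
    epochs=10,
    adapter=True,
    adapter_samples=10,
    resume_adapter="/path/to/broad_adapter_epoch=2-step=792.ckpt",
    lr="3e-5",
)
print(shlex.join(argv))
```

The resulting list can be passed directly to `subprocess.run(argv, check=True)`, which avoids shell quoting issues with paths containing `=` or spaces.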
## Reproducing all checkpoints
Below are the exact commands for each checkpoint. Replace paths with your actual directories.
### 1. narrow

Trained from scratch on the constant-\(N_e\) dataset only. 6 layers, 6 epochs.

Checkpoint: `narrow_epoch=5-step=4692.ckpt`

```bash
python -m cxt.train \
  --model narrow \
  --dataset-path /path/to/base_dataset/processed_narrow \
  --gpus 0 1 \
  --epochs 6
```
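The checkpoint names encode how long each run trained. PyTorch Lightning's `ModelCheckpoint` writes `epoch=<e>-step=<s>` into filenames by default, and the released files add a model-name prefix. Assuming the usual Lightning conventions (a zero-based `epoch` counter, and `step` equal to the trainer's global optimizer-step count), a small parser recovers the number of completed epochs and the average optimizer steps per epoch. `parse_ckpt_name` and `CkptInfo` are illustrative names.

```python
import re
from typing import NamedTuple

# "<model>_epoch=<e>-step=<s>.ckpt", e.g. "narrow_epoch=5-step=4692.ckpt"
_CKPT_RE = re.compile(r"^(?P<name>.+?)_epoch=(?P<epoch>\d+)-step=(?P<step>\d+)\.ckpt$")


class CkptInfo(NamedTuple):
    name: str
    epochs_completed: int  # Lightning's epoch counter is zero-based, hence +1
    global_step: int
    steps_per_epoch: float


def parse_ckpt_name(filename: str) -> CkptInfo:
    m = _CKPT_RE.match(filename)
    if m is None:
        raise ValueError(f"unrecognized checkpoint name: {filename!r}")
    epochs = int(m["epoch"]) + 1
    step = int(m["step"])
    return CkptInfo(m["name"], epochs, step, step / epochs)


info = parse_ckpt_name("narrow_epoch=5-step=4692.ckpt")
print(info.epochs_completed, info.steps_per_epoch)  # → 6 782.0
```

The recovered epoch counts match the `--epochs` values in the commands on this page (6 for `narrow`, 2 for `broad`, 10 for `w200_wmissing_adapter`).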
### 2. broad

Trained from scratch. 10 layers, 2 epochs. This is the main model and serves as the backbone for all downstream fine-tuning.

Checkpoint: `broad_epoch=1-step=5280.ckpt`

```bash
python -m cxt.train \
  --model broad \
  --dataset-path /path/to/processed \
  --gpus 0 1 \
  --epochs 2
```
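The command above relies on the default batch size and gradient accumulation from `cxt.config.TrainingConfig` (listed under Training hyperparameters). Assuming `--batch-size` is the per-GPU batch, as it is for a DataLoader inside each Lightning DDP process, the number of samples behind each optimizer step follows from simple arithmetic. `effective_batch_size` is a hypothetical helper.

```python
def effective_batch_size(per_gpu_batch: int, grad_accum: int, num_gpus: int) -> int:
    """Samples contributing to one optimizer step under DDP with gradient accumulation.

    Assumes the batch size is per process (per GPU): each optimizer step combines
    grad_accum micro-batches from every GPU.
    """
    return per_gpu_batch * grad_accum * num_gpus


# TrainingConfig defaults (batch size 128, gradient accumulation 6) with --gpus 0 1
print(effective_batch_size(per_gpu_batch=128, grad_accum=6, num_gpus=2))  # → 1536
```

Under the same assumption, a single-GPU run would keep 1536 samples per optimizer step with `--grad-accum 12`.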
### 3. broad_w200

Fine-tuned from the `broad` checkpoint on 200 bp window data, with a
reduced learning rate (`--lr 3e-5` instead of the 3e-4 default).

Checkpoint: `broad_w200_epoch=1-step=944.ckpt`

```bash
python -m cxt.train \
  --model broad_w200 \
  --dataset-path /path/to/processed_small_window \
  --gpus 0 1 \
  --epochs 2 \
  --lr 3e-5 \
  --checkpoint /path/to/broad_epoch=1-step=5280.ckpt
```
### 4. w200_wmissing

Fine-tuned from the `broad_w200` checkpoint on data with encoded
missingness. The `w200_wmissing` preset automatically sets
`mask_singletons=False`.

Checkpoint: `w200_wmissing_epoch=1-step=944.ckpt`

```bash
python -m cxt.train \
  --model w200_wmissing \
  --dataset-path /path/to/processed_small_window_missing_data \
  --gpus 0 1 \
  --epochs 2 \
  --lr 3e-5 \
  --checkpoint /path/to/broad_w200_epoch=1-step=944.ckpt
```
### 5. broad+adapter

Lightweight adapter on top of a frozen `broad` backbone. Trained on
10-sample data. The adapter learns to map from 10 input samples to the
50-sample feature space expected by the backbone.

Checkpoint: `broad_adapter_epoch=2-step=792.ckpt`

```bash
python -m cxt.train \
  --model broad \
  --adapter \
  --adapter-samples 10 \
  --dataset-path /path/to/processed_n10 \
  --gpus 0 1 \
  --epochs 3 \
  --checkpoint /path/to/broad_epoch=1-step=5280.ckpt
```
### 6. w200_wmissing_adapter

Two-stage adapter training: the `broad+adapter` checkpoint has already
learned the 10→50 sample mapping on w2000 data. The second stage resumes
that adapter on w200 + bitmask data at a low learning rate (`--lr 3e-5`) to
adapt it to missingness. The `--resume-adapter` flag loads the full
`LitAdapterDecoder` checkpoint (backbone + adapter weights).

Checkpoint: `w200_wmissing_adapter_epoch=9-step=480.ckpt`

```bash
python -m cxt.train \
  --model w200_wmissing \
  --adapter \
  --adapter-samples 10 \
  --resume-adapter /path/to/broad_adapter_epoch=2-step=792.ckpt \
  --dataset-path /path/to/processed_small_window_missing_data_n10 \
  --gpus 0 1 \
  --epochs 10 \
  --lr 3e-5
```
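The difference between the `--checkpoint` warm start (used in steps 3 to 5) and `--resume-adapter` comes down to which weights are restored. A PyTorch Lightning `.ckpt` file stores the module weights under its `"state_dict"` key; the sketch below filters a plain dict standing in for that entry. The key prefixes `model.` and `adapter.` and the function `select_weights` are hypothetical, since the actual parameter names inside `cxt` checkpoints are not documented here.

```python
# Hypothetical key prefixes; the real parameter names inside cxt checkpoints may differ.
BACKBONE_PREFIX = "model."
ADAPTER_PREFIX = "adapter."


def select_weights(state_dict: dict, mode: str) -> dict:
    """Return the entries of a Lightning-style state dict that each flag would restore.

    mode="checkpoint":     warm start; backbone only, any adapter starts fresh.
    mode="resume-adapter": full LitAdapterDecoder state, backbone and adapter.
    """
    if mode == "checkpoint":
        prefixes = (BACKBONE_PREFIX,)
    elif mode == "resume-adapter":
        prefixes = (BACKBONE_PREFIX, ADAPTER_PREFIX)
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return {k: v for k, v in state_dict.items() if k.startswith(prefixes)}


# Stand-in for the "state_dict" entry of broad_adapter_epoch=2-step=792.ckpt
# (strings instead of tensors, so the sketch runs without PyTorch).
toy_state = {
    "model.blocks.0.attn.weight": "backbone tensor",
    "model.ln_f.weight": "backbone tensor",
    "adapter.down.weight": "adapter tensor",
}
print(sorted(select_weights(toy_state, "checkpoint")))
print(sorted(select_weights(toy_state, "resume-adapter")))
```

With real files, the dict would come from loading the `.ckpt` with `torch.load` and taking its `"state_dict"` entry. Passing the `broad+adapter` checkpoint through a backbone-only warm start would discard the learned 10→50 mapping, which is why the second stage uses `--resume-adapter`.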
## End-to-end summary

For reference, here is the full pipeline from simulation to checkpoints:

```text
┌─────────────────────────────────────────────────────────────────┐
│ 1. SIMULATE   (python cxt/simulation_ts_only.py)                │
│    → base_dataset, ssd, idd, llm/*, stdpopsim/*                 │
│                                                                 │
│ 2. PREPROCESS (python -m cxt.preprocess)                        │
│    → processed_narrow (w2000, 50 samples, constant only)        │
│    → processed (w2000, 50 samples, 200 pairs)                   │
│    → processed_n10 (w2000, 10 samples, 20 pairs)                │
│    → processed_small_window (w200, 50 samples)                  │
│    → processed_small_window_missing_data (+ bitmask)            │
│    → processed_small_window_missing_data_n10 (+ n10)            │
│                                                                 │
│ 3. TRAIN      (python -m cxt.train)                             │
│    narrow                ← processed_narrow (constant only)     │
│    broad                 ← processed                            │
│    broad_w200            ← processed_small_window + broad ckpt  │
│    broad+adapter         ← processed_n10 + broad ckpt           │
│    w200_wmissing         ← processed_sw_missing + broad_w200    │
│    w200_wmissing_adapter ← processed_sw_missing_n10             │
│                            + broad+adapter (--resume-adapter)   │
└─────────────────────────────────────────────────────────────────┘
```
## Training hyperparameters

Default training hyperparameters (from `cxt.config.TrainingConfig`):

| Parameter | Default |
|---|---|
| Learning rate | 3e-4 |
| Min learning rate | 3e-5 |
| Warmup iterations | 10 |
| LR decay iterations | 150,000 |
| Batch size | 128 |
| Gradient accumulation | 6 |
| Weight decay | 0.1 |
| Optimizer betas | (0.9, 0.95) |
| Precision | bf16-mixed |
| Strategy | DDP |
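The table lists a peak learning rate, a minimum learning rate, warmup iterations and decay iterations, but not the shape of the schedule. The same parameter set drives the linear-warmup plus cosine-decay schedule in nanoGPT's training script (which also uses betas (0.9, 0.95) and weight decay 0.1). The sketch below implements that shape as an assumption, not as the confirmed `cxt` schedule; `get_lr` is an illustrative name.

```python
import math

# Defaults from cxt.config.TrainingConfig, as listed in the table.
LEARNING_RATE = 3e-4
MIN_LR = 3e-5
WARMUP_ITERS = 10
LR_DECAY_ITERS = 150_000


def get_lr(it: int, lr: float = LEARNING_RATE, min_lr: float = MIN_LR,
           warmup_iters: int = WARMUP_ITERS, decay_iters: int = LR_DECAY_ITERS) -> float:
    """Assumed schedule: linear warmup to `lr`, then cosine decay down to `min_lr`."""
    if it < warmup_iters:
        return lr * (it + 1) / warmup_iters  # reaches lr at it = warmup_iters - 1
    if it >= decay_iters:
        return min_lr  # flat after the decay horizon
    ratio = (it - warmup_iters) / (decay_iters - warmup_iters)  # 0 -> 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))  # 1 -> 0
    return min_lr + coeff * (lr - min_lr)


for it in (0, 9, 10, 5_280, 75_005, 150_000):
    print(f"{it:>7}: {get_lr(it):.3e}")
```

If the iteration counter is the optimizer step, the runs on this page (at most 5,280 steps according to the checkpoint names) would end early in the cosine phase under this assumed schedule, so the learning rate stays close to its peak.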
## Practical notes

- Checkpoints are saved by PyTorch Lightning under `lightning_logs/` in the directory specified by `--log-dir` (or the current working directory if it is not set). The `--checkpoint` argument performs warm-start initialization for fine-tuning.
- **Checkpoint cache:** `cxt.load_model()` looks for pretrained checkpoints in `~/.cache/cxt/checkpoints/` by default. Set the `CXT_CHECKPOINT_CACHE` environment variable to redirect this (useful for isolated reproduction runs that should not touch the global cache).
- Singleton masking is controlled by the model preset: `w200_wmissing` sets `mask_singletons=False` automatically.
- The KV cache is disabled during training and enabled automatically at inference via `cxt.load_model()`.
- Multi-GPU training uses DDP; `--gpus 0 1` selects GPUs by index.
- Adapter training freezes most of the backbone: only the adapter module, selected layer-norm parameters, and the last N transformer blocks are updated (controlled by `unfreeze_strategy="ln_lastN"`).
- Two-stage adapter training (used for `w200_wmissing_adapter`): first train an adapter on w2000 data (`broad+adapter`), then resume it with `--resume-adapter` on w200 + bitmask data. This transfers the learned sample-size mapping while adapting to missingness and smaller windows.
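The `ln_lastN` rule can be pictured as a predicate over parameter names. The sketch below assumes a hypothetical naming scheme (`adapter.*`, `blocks.<i>.*`, layer norms named `ln*` or containing `norm`) and unfreezes every layer norm, whereas the note above says only selected ones are; the real module names, the layer-norm selection and the value of N in `cxt` are not documented here, and `last_n=2` is an arbitrary illustration. `is_trainable` is an illustrative name.

```python
import re

# Hypothetical parameter naming (the real cxt module names are not documented here):
#   adapter.*                                 adapter module
#   blocks.<i>.*                              transformer block i of the backbone
#   component starting "ln" or with "norm"    layer-norm parameters
_BLOCK_RE = re.compile(r"^blocks\.(\d+)\.")


def _is_layer_norm(name: str) -> bool:
    # Inspect module path components, ignoring the trailing "weight"/"bias".
    parts = name.split(".")[:-1]
    return any(p.startswith("ln") or "norm" in p for p in parts)


def is_trainable(name: str, n_layers: int, last_n: int) -> bool:
    """Sketch of an ln_lastN-style rule: train the adapter, layer norms and the
    last `last_n` transformer blocks; keep every other backbone parameter frozen."""
    if name.startswith("adapter."):
        return True
    if _is_layer_norm(name):
        return True
    match = _BLOCK_RE.match(name)
    return match is not None and int(match.group(1)) >= n_layers - last_n


# Toy example: a 10-layer backbone (like broad) with the last 2 blocks unfrozen.
for name in ["embed.weight", "blocks.7.attn.weight", "blocks.8.attn.weight",
             "blocks.3.ln_1.weight", "adapter.down.weight"]:
    print(f"{name:24} trainable={is_trainable(name, n_layers=10, last_n=2)}")
```

In PyTorch the predicate would be applied by setting `p.requires_grad = is_trainable(name, n_layers, last_n)` for each `(name, p)` pair yielded by `model.named_parameters()`.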