API Reference

cxt (top-level)

Configuration

class cxt.config.AdapterConfig(ie_in: int = 10, ie_out: int = 50, bottleneck: int = 32, dropout: float = 0.0, use_se: bool = True, new_mask_index: int | None = 0, unfreeze_strategy: str = 'ln_lastN', last_n: int = 2)

Configuration for the sample-size adapter (IEAdapter).

bottleneck: int = 32
dropout: float = 0.0
ie_in: int = 10
ie_out: int = 50
last_n: int = 2
new_mask_index: int | None = 0
unfreeze_strategy: str = 'ln_lastN'
use_se: bool = True
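The default unfreeze_strategy 'ln_lastN' together with last_n suggests that, while the adapter trains, all LayerNorm parameters plus the last last_n decoder blocks of the otherwise frozen model stay trainable. The sketch below illustrates that reading only. It is an assumption, not the cxt implementation, and the parameter-name pattern it matches (transformer.h.<i>.…, ln_*) is a hypothetical nanoGPT-style convention.

```python
import re
from typing import Iterable, Set

# Hypothetical naming convention: block parameters look like "transformer.h.<layer>.<module>...".
_BLOCK_INDEX = re.compile(r"(?:^|\.)h\.(\d+)\.")


def select_trainable(
    param_names: Iterable[str],
    n_layer: int,
    unfreeze_strategy: str = "ln_lastN",
    last_n: int = 2,
) -> Set[str]:
    """Return parameter names left trainable under an assumed 'ln_lastN' rule."""
    if unfreeze_strategy != "ln_lastN":
        raise ValueError(f"sketch only covers 'ln_lastN', got {unfreeze_strategy!r}")
    first_unfrozen = n_layer - last_n  # blocks with index >= this remain trainable
    trainable = set()
    for name in param_names:
        # Assumption: any module whose name starts with "ln" is a LayerNorm.
        is_layernorm = any(part.startswith("ln") for part in name.split("."))
        match = _BLOCK_INDEX.search(name)
        in_last_blocks = match is not None and int(match.group(1)) >= first_unfrozen
        if is_layernorm or in_last_blocks:
            trainable.add(name)
    return trainable
```

In PyTorch the resulting set could be applied with `for name, p in model.named_parameters(): p.requires_grad = name in trainable`.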
cxt.config.BroadModelConfig

alias of ModelConfig

class cxt.config.ModelConfig(n_layer: int = 10, n_embd: int = 400, n_head: int = 4, output_dim: int = 326, num_samples: int = 50, sample_scale_embd: int = 2, combined_dim: int = 1001, window_size: int = 2000, mask_singletons: bool = True, use_kv_cache: bool = False, bias: bool = False, dropout: float = 0.1, device: str = 'cpu', batch_size: int = 1)

Unified configuration for all cxt model variants.

This class replaces the former separate NarrowModelConfig / BroadModelConfig classes, which relied on commented-out blocks; select a preset via PRESETS["broad"], etc. The old class names remain available as aliases of ModelConfig.

batch_size: int = 1
bias: bool = False
combined_dim: int = 1001
device: str = 'cpu'
dropout: float = 0.1
for_inference(batch_size: int = 1, device: str = 'cpu') → ModelConfig

Return a copy configured for autoregressive decoding.

for_training(batch_size: int = 128, device: str = 'cuda') → ModelConfig

Return a copy configured for training (no KV cache).

mask_singletons: bool = True
n_embd: int = 400
n_head: int = 4
n_layer: int = 10
num_samples: int = 50
output_dim: int = 326
sample_scale_embd: int = 2
use_kv_cache: bool = False
window_size: int = 2000
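for_inference() and for_training() return copies instead of mutating the config in place. A minimal sketch of that pattern with dataclasses.replace follows. It assumes ModelConfig behaves like a dataclass (its generated signature suggests this, but it is not confirmed) and that for_inference() switches use_kv_cache on; ModelConfigSketch is a hypothetical stand-in carrying only a subset of the documented defaults.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ModelConfigSketch:
    """Hypothetical stand-in with a subset of the documented ModelConfig defaults."""

    n_layer: int = 10
    n_embd: int = 400
    n_head: int = 4
    num_samples: int = 50
    window_size: int = 2000
    use_kv_cache: bool = False
    dropout: float = 0.1
    device: str = "cpu"
    batch_size: int = 1

    @property
    def head_dim(self) -> int:
        # Multi-head attention splits n_embd evenly across heads.
        if self.n_embd % self.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        return self.n_embd // self.n_head

    def for_inference(self, batch_size: int = 1, device: str = "cpu") -> "ModelConfigSketch":
        # Assumption: autoregressive decoding enables the KV cache.
        return replace(self, use_kv_cache=True, batch_size=batch_size, device=device)

    def for_training(self, batch_size: int = 128, device: str = "cuda") -> "ModelConfigSketch":
        # Documented behaviour: training copies run without a KV cache.
        return replace(self, use_kv_cache=False, batch_size=batch_size, device=device)
```

With the defaults n_embd=400 and n_head=4, each attention head works on 100 dimensions.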
cxt.config.NarrowModelConfig

alias of ModelConfig

cxt.config.TokenFreeDecoderConfig

alias of ModelConfig

class cxt.config.TrainingConfig(max_lr: float = 0.0003, min_lr: float = 3e-05, warmup_iters: int = 10, lr_decay_iters: int = 150000, batch_size: int = 128, grad_accum_steps: int = 6, weight_decay: float = 0.1, betas: tuple[float, float] = (0.9, 0.95), num_workers: int = 16, prefetch_factor: int = 2)

Hyperparameters for Lightning training.

batch_size: int = 128
betas: tuple[float, float] = (0.9, 0.95)
grad_accum_steps: int = 6
lr_decay_iters: int = 150000
max_lr: float = 0.0003
min_lr: float = 3e-05
num_workers: int = 16
prefetch_factor: int = 2
warmup_iters: int = 10
weight_decay: float = 0.1
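The fields max_lr, min_lr, warmup_iters and lr_decay_iters match the common linear-warmup plus cosine-decay learning-rate schedule (popularised by nanoGPT). The sketch below assumes cxt follows that schedule and that grad_accum_steps multiplies the per-step batch; neither is confirmed by this reference. betas and weight_decay are presumably passed to an AdamW-style optimizer such as torch.optim.AdamW.

```python
import math


def get_lr(
    it: int,
    max_lr: float = 3e-4,
    min_lr: float = 3e-5,
    warmup_iters: int = 10,
    lr_decay_iters: int = 150_000,
) -> float:
    """Linear warmup, then cosine decay from max_lr down to min_lr (assumed schedule)."""
    if it < warmup_iters:
        # Linear ramp that reaches max_lr at iteration warmup_iters - 1.
        return max_lr * (it + 1) / warmup_iters
    if it > lr_decay_iters:
        # Past the decay horizon, hold the floor.
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # falls from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)


def effective_batch_size(batch_size: int = 128, grad_accum_steps: int = 6) -> int:
    # Samples contributing to one optimizer step on a single device.
    return batch_size * grad_accum_steps
```

Under these assumptions the defaults give 768 samples per optimizer step per device.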

Model Loading

Checkpoint download, caching, and model loading.

Consolidates the two setup_cxt_model() implementations from utils.py into a single load_model() factory that:

  • resolves the correct ModelConfig preset up front (no patching of the config after loading),

  • downloads checkpoints from GitHub LFS on first use, and

  • handles adapter wrapping internally.

cxt.checkpoint.get_checkpoint_path(model_type: str, force_local: str | None = None) → Path

Return local path to checkpoint, downloading if necessary.
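The download-if-missing behaviour can be sketched as a cache lookup with a pluggable downloader. Everything here beyond the documented model_type and force_local parameters is hypothetical: the cache_dir argument, the "<model_type>.ckpt" file name, and the download callable (which stands in for the GitHub LFS fetch) are illustrative assumptions.

```python
from pathlib import Path
from typing import Callable, Optional


def get_checkpoint_path_sketch(
    model_type: str,
    cache_dir: Path,
    download: Callable[[str, Path], None],
    force_local: Optional[str] = None,
) -> Path:
    """Hypothetical cache-then-download lookup for a checkpoint file."""
    if force_local is not None:
        # An explicit local path bypasses the cache and any download.
        path = Path(force_local)
        if not path.exists():
            raise FileNotFoundError(path)
        return path
    path = cache_dir / f"{model_type}.ckpt"  # assumed naming scheme
    if not path.exists():
        # First use: fetch once, later calls hit the cached file.
        path.parent.mkdir(parents=True, exist_ok=True)
        download(model_type, path)
    return path
```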

cxt.checkpoint.load_model(model_type: str = 'broad', device: str = 'cpu', force_local: str | None = None)

Load a pretrained cxt model ready for inference.

Parameters:
  • model_type (str) – One of: broad, broad+adapter, narrow, residual, broad_w200, w200_wmissing, w200_wmissing_adapter.

  • device (str) – Target device (model is loaded on CPU first, then moved).

  • force_local (str, optional) – Use this checkpoint path instead of downloading.

Returns:

model – Model in eval mode with KV caches allocated.

Return type:

TokenFreeDecoder or FrozenDecoderWithAdapter
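A typical call is model = load_model("broad+adapter", device="cuda"). Judging by the documented names, the adapter variants appear to be a base model plus an adapter suffix, which would explain why the return type is either TokenFreeDecoder or FrozenDecoderWithAdapter. The helper below is a hypothetical illustration of that mapping, not part of cxt; the suffix convention is an assumption read off the model_type list.

```python
from typing import Tuple

KNOWN_MODEL_TYPES = (
    "broad",
    "broad+adapter",
    "narrow",
    "residual",
    "broad_w200",
    "w200_wmissing",
    "w200_wmissing_adapter",
)


def resolve_model_type(model_type: str) -> Tuple[str, bool]:
    """Hypothetical: return (base_model_type, needs_adapter) for a documented model_type."""
    if model_type not in KNOWN_MODEL_TYPES:
        raise ValueError(f"unknown model_type {model_type!r}; expected one of {KNOWN_MODEL_TYPES}")
    for suffix in ("+adapter", "_adapter"):
        if model_type.endswith(suffix):
            # Assumed: an adapter variant wraps the base checkpoint named by the prefix.
            return model_type[: -len(suffix)], True
    return model_type, False
```

Under this reading, needs_adapter=True corresponds to a FrozenDecoderWithAdapter and False to a plain TokenFreeDecoder.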

Inference

Note

translate() auto-detects the input type (tree sequence, VCF path, or genotype matrix tuple) and dispatches to the appropriate backend. When mutation_rate is provided, a per-block stochastic bias correction is applied using cxt.correction.stochastic_diversity_bias_correction_v2().
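The input-type detection in translate() can be pictured as a small classifier over the three accepted inputs. This is a hypothetical sketch of the idea, not the cxt code: it treats strings and os.PathLike objects as VCF paths, any tuple as a genotype-matrix tuple (its exact contents are not documented here), and anything exposing tskit.TreeSequence-style trees() and num_samples as a tree sequence.

```python
import os
from typing import Any


def detect_input_kind(data: Any) -> str:
    """Hypothetical classifier mirroring the three input types translate() accepts."""
    if isinstance(data, (str, os.PathLike)):
        # File paths (str or pathlib.Path) are treated as VCFs.
        return "vcf"
    if isinstance(data, tuple):
        # Assumed: a tuple carries a genotype matrix plus associated metadata.
        return "genotype_matrix"
    if hasattr(data, "trees") and hasattr(data, "num_samples"):
        # Duck-typed check for a tskit.TreeSequence-like object.
        return "tree_sequence"
    raise TypeError(f"unsupported input type: {type(data).__name__}")
```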

SFS Computation

Bias Correction

Model Architecture

Training

Dataset

Simulation

Simulation functions live in two modules:

  • cxt.utils contains simulate_parameterized_tree_sequence, create_sawtooth_demography, sample_demography, sample_population_size, and DemographyStorage.

  • cxt.simulation_ts_only is the CLI entry point used by run_fresh.sh (invoked as python cxt/simulation_ts_only.py).

Preprocessing

Utilities