# Source code for cxt.config

from __future__ import annotations
from dataclasses import dataclass, field, replace
from copy import deepcopy


@dataclass
class ModelConfig:
    """Unified configuration for all cxt model variants.

    Replaces the former separate ``NarrowModelConfig`` / ``BroadModelConfig``
    classes (kept as aliases elsewhere for old checkpoints). Obtain a named
    variant with ``get_preset("broad")``, which returns a private copy.
    Indexing ``PRESETS["broad"]`` directly hands out the shared registry
    instance, and mutating that object would silently change the preset for
    every later caller.
    """

    n_layer: int = 10
    n_embd: int = 400
    n_head: int = 4
    output_dim: int = 326  # 324 discretization bins + 2 special tokens
    num_samples: int = 50
    sample_scale_embd: int = 2
    combined_dim: int = 1001  # source (500) + target (501)
    window_size: int = 2000  # base SFS window in bp
    mask_singletons: bool = True
    use_kv_cache: bool = False  # enable only at inference time
    bias: bool = False
    dropout: float = 0.1
    device: str = "cpu"
    batch_size: int = 1  # only matters when use_kv_cache=True

    def for_inference(
        self, batch_size: int = 1, device: str = "cpu"
    ) -> ModelConfig:
        """Return a copy configured for autoregressive decoding.

        The copy has ``use_kv_cache=True`` and the given ``batch_size`` and
        ``device``. Every other field is carried over unchanged, and ``self``
        is not modified.
        """
        return replace(
            self, use_kv_cache=True, batch_size=batch_size, device=device
        )

    def for_training(
        self, batch_size: int = 128, device: str = "cuda"
    ) -> ModelConfig:
        """Return a copy configured for training (KV cache disabled).

        The copy has ``use_kv_cache=False`` and the given ``batch_size`` and
        ``device``. Every other field is carried over unchanged, and ``self``
        is not modified.
        """
        return replace(
            self, use_kv_cache=False, batch_size=batch_size, device=device
        )


# Registry of named model variants. Treat these instances as read-only
# templates and obtain working copies through get_preset().
PRESETS: dict[str, ModelConfig] = {
    "narrow": ModelConfig(n_layer=6),
    "broad": ModelConfig(n_layer=10),
    "broad_w200": ModelConfig(n_layer=10, window_size=200),
    # NOTE(review): identical to "broad" at the config level. Presumably the
    # residual variant differs in model code rather than here; confirm.
    "residual": ModelConfig(n_layer=10),
    "w200_wmissing": ModelConfig(
        n_layer=10, window_size=200, mask_singletons=False
    ),
}


def get_preset(name: str, **overrides) -> ModelConfig:
    """Return an independent copy of the preset *name*, with optional overrides.

    Args:
        name: Key into ``PRESETS``.
        **overrides: Field values to replace on the copy, e.g.
            ``get_preset("broad", dropout=0.0)``.

    Returns:
        A new ``ModelConfig``. Mutating it never affects ``PRESETS``.

    Raises:
        KeyError: If *name* is not a registered preset. The message lists
            the valid names.
        TypeError: If an override names a field ``ModelConfig`` does not have.
    """
    try:
        base = PRESETS[name]
    except KeyError:
        available = ", ".join(sorted(PRESETS))
        raise KeyError(
            f"unknown model preset {name!r}; available: {available}"
        ) from None
    # deepcopy guards the shared registry entry even if a future field holds
    # a mutable value; replace() then applies any per-call overrides.
    return replace(deepcopy(base), **overrides)
@dataclass
class AdapterConfig:
    """Settings for the sample-size adapter (IEAdapter).

    Plain value container; every field has a default, so ``AdapterConfig()``
    yields the baseline adapter setup. Field order is part of the positional
    constructor signature and must not change.
    """

    ie_in: int = 10
    ie_out: int = 50
    bottleneck: int = 32
    dropout: float = 0.0
    use_se: bool = True
    # Optional: may be None.
    new_mask_index: int | None = 0
    unfreeze_strategy: str = "ln_lastN"
    last_n: int = 2
@dataclass
class TrainingConfig:
    """Optimizer, schedule and data-loading hyperparameters for Lightning training.

    Every field has a default; override only what a run needs, e.g.
    ``TrainingConfig(max_lr=1e-4)``. Field order is part of the positional
    constructor signature and must not change.
    """

    # Learning-rate range and schedule lengths (in iterations).
    max_lr: float = 3e-4
    min_lr: float = 3e-5
    warmup_iters: int = 10
    lr_decay_iters: int = 150_000

    # Batching.
    batch_size: int = 128
    grad_accum_steps: int = 6

    # Optimizer. The tuple default is immutable, so sharing it is safe.
    weight_decay: float = 0.1
    betas: tuple[float, float] = (0.9, 0.95)

    # Data loading.
    num_workers: int = 16
    prefetch_factor: int = 2
# ---------------------------------------------------------------------------
# Legacy class names, kept so that old checkpoints can still be unpickled.
# Lightning stores the config class name, so ``torch.load`` looks these names
# up in this module; each one resolves to the unified ModelConfig.
# ---------------------------------------------------------------------------
NarrowModelConfig = BroadModelConfig = TokenFreeDecoderConfig = ModelConfig