#!/usr/bin/env python3 """ OpenMythos pretraining on FineWeb-Edu with FSDP + AdamW. Single GPU: python training/3b_fine_web_edu.py Multi-GPU: torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py """ import os import math import time import torch import torch.nn as nn import torch.distributed as dist from loguru import logger from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, FullStateDictConfig, StateDictType, ) from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.utils.data import IterableDataset, DataLoader, get_worker_info from contextlib import nullcontext from datasets import load_dataset from open_mythos import OpenMythos from open_mythos.main import TransformerBlock, RecurrentBlock from open_mythos.variants import mythos_3b from open_mythos.tokenizer import MythosTokenizer # --------------------------------------------------------------------------- # Dataset # --------------------------------------------------------------------------- class FineWebEduDataset(IterableDataset): """ Streaming FineWeb-Edu loader yielding fixed-length (input, target) pairs. FineWeb-Edu is trillions of tokens, so `streaming=True` pulls shards on demand instead of materializing to disk. Sharding is two-dimensional — `world_size` ranks × `num_workers` DataLoader workers per rank — and each `(rank, worker_id)` deterministically owns one shard of the global stream. That gives disjoint coverage without any cross-process coordination. Streaming datasets are not seekable, so a resumed run re-enters its shard from the beginning. Acceptable at pretraining scale: the chance of re-playing the same tokens before the run ends is negligible versus the cost of a true resumable loader. """ def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int): """ Args: encoding -- tokenizer exposing `.encode(str) -> list[int]` seq_len -- context length; every yielded pair has this many tokens subset -- FineWeb-Edu config name (e.g. "sample-10BT", "default") rank -- global rank of this process within the distributed job world_size -- total number of distributed processes """ self.encoding = encoding self.seq_len = seq_len self.subset = subset self.rank = rank self.world_size = world_size def __iter__(self): """ Yield `(input_ids, target_ids)` tensors of length `seq_len` forever. Inputs and targets are shifted by one for next-token prediction — `target[i] == input[i + 1]`. Documents are concatenated into a rolling buffer and sliced into fixed-length chunks, packing short docs together and splitting long ones. This keeps every step at the same shape, which under FSDP avoids recompute from variable-length inputs and removes the need for a pad-aware attention mask. """ worker = get_worker_info() num_workers = worker.num_workers if worker else 1 worker_id = worker.id if worker else 0 total_shards = self.world_size * num_workers shard_index = self.rank * num_workers + worker_id ds = load_dataset( "HuggingFaceFW/fineweb-edu", name=self.subset, split="train", streaming=True, ).shard(num_shards=total_shards, index=shard_index) buf = [] for sample in ds: buf.extend(self.encoding.encode(sample["text"])) while len(buf) >= self.seq_len + 1: chunk = buf[: self.seq_len + 1] buf = buf[self.seq_len + 1 :] yield ( torch.tensor(chunk[:-1], dtype=torch.long), torch.tensor(chunk[1:], dtype=torch.long), ) # --------------------------------------------------------------------------- # LR schedule: linear warmup → cosine decay # --------------------------------------------------------------------------- def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float: """ Linear warmup → half-cosine decay to `min_lr`. Standard language-model pretraining schedule. The warmup phase prevents Adam's second-moment estimate from collapsing to a huge LR in the first few steps when gradients are noisy. The cosine tail lets the model make small, increasingly conservative updates near the end of training rather than crashing to `min_lr` at a fixed step. Behavior by region: step < warmup → linear ramp 0 → max_lr warmup ≤ step < total → cosine decay max_lr → min_lr step ≥ total → clamped at min_lr (safety for off-by-one step counters at the end of training) Args: step -- current global optimizer step (0-indexed) warmup -- number of warmup steps before cosine decay begins total -- step at which the cosine reaches `min_lr` max_lr -- peak learning rate reached at the end of warmup min_lr -- floor learning rate at and after `total` steps Returns: Scalar learning rate for this step. """ if step < warmup: return max_lr * step / warmup if step >= total: return min_lr decay = (step - warmup) / (total - warmup) return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay)) # --------------------------------------------------------------------------- # Checkpointing # --------------------------------------------------------------------------- def _list_ckpts(ckpt_dir: str) -> list[str]: """ Return checkpoint paths in `ckpt_dir` sorted oldest → newest. Relies on the zero-padded `step_{0000000}.pt` filename convention so lexicographic sort matches chronological order. Changing the filename format elsewhere without updating the pad width would silently break both `keep_last` pruning and resume-latest on startup, since both pick the last element of this list. Args: ckpt_dir -- directory to scan; missing directory returns [] Returns: Sorted list of absolute paths to matching checkpoint files. """ if not os.path.isdir(ckpt_dir): return [] return sorted( os.path.join(ckpt_dir, f) for f in os.listdir(ckpt_dir) if f.startswith("step_") and f.endswith(".pt") ) def save_checkpoint( model, optimizer, step: int, cfg, vocab_size: int, ckpt_dir: str, ddp: bool, master: bool, keep_last: int = 3, ) -> None: """ Gather full model + optimizer state, write atomically, prune old files. Under FSDP both states are collected inside a single FULL_STATE_DICT context so the optim-state tensors bind to fully-unsharded parameters; mixing contexts between model and optimizer has caused silent divergence on resume in past torch versions. The temp-file + os.replace write means a kill mid-save leaves the previous checkpoint intact instead of a truncated .pt file. Non-master ranks participate in the FSDP gather (otherwise the collective would hang) but exit before touching disk. Args: model -- FSDP-wrapped (ddp=True) or raw (ddp=False) model optimizer -- the optimizer whose state should round-trip with the model step -- global step number; encoded zero-padded into the filename cfg -- model config object; saved so downstream eval can reconstruct the model without re-importing the variant vocab_size -- tokenizer vocab size at train time; saved for sanity-check on load against a (possibly updated) tokenizer ckpt_dir -- directory to write into; created if missing ddp -- True if FSDP path; False for single-GPU / CPU master -- whether this rank writes to disk (rank 0 only) keep_last -- number of most-recent checkpoints to retain; older ones are unlinked after a successful write Returns: None. Writes to disk as a side effect on master rank. """ if ddp: with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): model_state = model.state_dict() optim_state = FSDP.optim_state_dict(model, optimizer) else: model_state = model.state_dict() optim_state = optimizer.state_dict() if not master: return os.makedirs(ckpt_dir, exist_ok=True) final_path = os.path.join(ckpt_dir, f"step_{step:07d}.pt") tmp_path = final_path + ".tmp" torch.save( { "step": step, "model": model_state, "optimizer": optim_state, "cfg": cfg, "vocab_size": vocab_size, }, tmp_path, ) os.replace(tmp_path, final_path) for old in _list_ckpts(ckpt_dir)[:-keep_last]: try: os.remove(old) except OSError as exc: logger.warning(f"Failed to prune old checkpoint {old}: {exc}") logger.success(f"Checkpoint saved → {final_path}") def load_checkpoint(model, optimizer, path: str, ddp: bool) -> int: """ Restore model + optimizer from disk, returning the step to resume at. Every rank reads the file (`rank0_only=False` on load) so FSDP has access to the full state on each rank — the complement to the `rank0_only=True` save path. Must mirror save's single-context pattern; splitting the model and optimizer loads across two `state_dict_type` blocks has historically produced optimizer state bound to the wrong shard shapes. `weights_only=False` is required because the checkpoint contains the pickled `cfg` dataclass — flip to `weights_only=True` only if you separate config out. Args: model -- same FSDP-wrapped or raw model used during save optimizer -- freshly constructed optimizer to be filled in-place path -- absolute path to a `step_{N:07d}.pt` file produced by `save_checkpoint` ddp -- whether the model is FSDP-wrapped; must match the save run Returns: The step number the checkpoint was taken at; the caller advances the training loop from this value. """ ckpt = torch.load(path, map_location="cpu", weights_only=False) if ddp: with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=False), ): model.load_state_dict(ckpt["model"]) optim_state = FSDP.optim_state_dict_to_load( model=model, optim=optimizer, optim_state_dict=ckpt["optimizer"], ) optimizer.load_state_dict(optim_state) else: model.load_state_dict(ckpt["model"]) optimizer.load_state_dict(ckpt["optimizer"]) return int(ckpt["step"]) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): """ End-to-end pretraining entry point. Order matters: distributed init must run before any CUDA allocation, the tokenizer must exist before the model is built (vocab_size flows into cfg), and FSDP must wrap the model before the optimizer is constructed (FSDP re-flattens parameters, so an optimizer built on the unwrapped model would track stale param objects). Resume then loads state into the already-constructed optimizer in-place. Lifecycle: 1. Initialize torch.distributed (NCCL) if launched under torchrun. 2. Build tokenizer → derive vocab_size. 3. Construct OpenMythos with the 3B variant config. 4. Wrap in FSDP with FULL_SHARD + bf16/fp16 mixed precision (multi-GPU) or move to device + autocast (single-GPU). 5. Build fused AdamW on (possibly sharded) parameters. 6. Resume from the latest checkpoint in `ckpt_dir` if one exists. 7. Stream FineWeb-Edu through grad-accumulation microbatches with cosine LR schedule, per-step logging, and periodic checkpoints. 8. Write a final checkpoint if the last save wasn't aligned to `ckpt_every`, then barrier + tear down the process group. All hyperparameters are literal constants in this function by design — pretraining runs are long-lived and each run pins exact settings; a CLI/config layer is deliberately avoided to keep the file self-auditable. """ # ------------------------------------------------------------------ # Distributed init # ------------------------------------------------------------------ ddp = int(os.environ.get("RANK", -1)) != -1 if ddp: dist.init_process_group("nccl") rank = int(os.environ["RANK"]) local_rank = int(os.environ["LOCAL_RANK"]) world_size = int(os.environ["WORLD_SIZE"]) device = f"cuda:{local_rank}" torch.cuda.set_device(device) else: rank = local_rank = 0 world_size = 1 device = "cuda" if torch.cuda.is_available() else "cpu" master = rank == 0 if master: logger.info( f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}" ) # ------------------------------------------------------------------ # Tokenizer # ------------------------------------------------------------------ encoding = MythosTokenizer() vocab_size = encoding.vocab_size if master: logger.info(f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,}") # ------------------------------------------------------------------ # Hyperparameters # ------------------------------------------------------------------ seq_len = 2048 micro_batch = 4 target_tokens = 30_000_000_000 grad_accum = max(1, 256 // (world_size * micro_batch)) global_batch_tok = world_size * micro_batch * grad_accum * seq_len total_steps = target_tokens // global_batch_tok warmup_steps = 2000 lr = 3e-4 wd = 0.1 log_every = 10 ckpt_every = 1000 ckpt_dir = "checkpoints" dataset_subset = "sample-10BT" # → sample-100BT or "default" for full run if master: logger.info( f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | " f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}" ) # ------------------------------------------------------------------ # Model # ------------------------------------------------------------------ cfg = mythos_3b() cfg.vocab_size = vocab_size cfg.max_seq_len = seq_len bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported() amp_dtype = torch.bfloat16 if bf16_ok else torch.float16 model = OpenMythos(cfg) if ddp: mp_policy = MixedPrecision( param_dtype=amp_dtype, reduce_dtype=amp_dtype, buffer_dtype=amp_dtype, ) wrap_policy = ModuleWrapPolicy({TransformerBlock, RecurrentBlock}) model = FSDP( model, sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision=mp_policy, auto_wrap_policy=wrap_policy, device_id=local_rank, ) else: model = model.to(device) amp_ctx = ( torch.amp.autocast(device_type="cuda", dtype=amp_dtype) if "cuda" in device else nullcontext() ) # FSDP handles its own mixed precision; only need autocast for single-GPU amp_ctx = nullcontext() if ddp else amp_ctx # type: ignore[possibly-undefined] if master: n_params = sum(p.numel() for p in model.parameters()) logger.info(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}") # ------------------------------------------------------------------ # Optimizer # ------------------------------------------------------------------ optimizer = torch.optim.AdamW( model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True ) # ------------------------------------------------------------------ # Resume from latest checkpoint (if any) # ------------------------------------------------------------------ # Streaming datasets are not resumable by position, so re-iterating from # the beginning is accepted — at pretraining scale the loss of dataset # position is negligible vs. the cost of discarded training steps. start_step = 0 existing_ckpts = _list_ckpts(ckpt_dir) if existing_ckpts: latest = existing_ckpts[-1] if master: logger.info(f"Resuming from checkpoint: {latest}") start_step = load_checkpoint(model, optimizer, latest, ddp) if master: logger.success(f"Resumed at step {start_step}") # ------------------------------------------------------------------ # Dataset + DataLoader # ------------------------------------------------------------------ dataset = FineWebEduDataset(encoding, seq_len, dataset_subset, rank, world_size) loader = DataLoader(dataset, batch_size=micro_batch, num_workers=4, pin_memory=True) # ------------------------------------------------------------------ # Training loop # ------------------------------------------------------------------ if master: os.makedirs(ckpt_dir, exist_ok=True) model.train() data_iter = iter(loader) t0 = time.perf_counter() step = start_step while step < total_steps: cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1) for g in optimizer.param_groups: g["lr"] = cur_lr optimizer.zero_grad() loss_accum = 0.0 for micro_step in range(grad_accum): try: x, y = next(data_iter) except StopIteration: data_iter = iter(loader) x, y = next(data_iter) x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True) y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True) sync = ( nullcontext() if (not ddp or micro_step == grad_accum - 1) else model.no_sync() ) with sync, amp_ctx: logits = model(x) loss = nn.functional.cross_entropy( logits.view(-1, vocab_size), y.view(-1) ) loss = loss / grad_accum loss.backward() loss_accum += loss.item() # FSDP shards parameters, so `nn.utils.clip_grad_norm_` would clip # against each rank's local norm and miss the cross-shard gather. # FSDP.clip_grad_norm_ computes the true global norm and returns it. if ddp: grad_norm = model.clip_grad_norm_(1.0) else: grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() step += 1 if master and step % log_every == 0: dt = time.perf_counter() - t0 tok_per_sec = global_batch_tok * log_every / dt tokens_seen = step * global_batch_tok logger.info( f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} " f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} " f"| {tok_per_sec / 1e6:.2f}M tok/s " f"| {tokens_seen / 1e9:.1f}B tokens seen" ) t0 = time.perf_counter() if step % ckpt_every == 0: save_checkpoint( model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master ) # Final checkpoint — total_steps may not be divisible by ckpt_every, so # without this the tail of the run is lost if the schedule doesn't align. if step > start_step and step % ckpt_every != 0: save_checkpoint(model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master) if ddp: # Barrier so no rank exits while another is still finishing its # checkpoint gather — avoids NCCL "process group destroyed" noise. dist.barrier() dist.destroy_process_group() if master: logger.success("Training complete.") if __name__ == "__main__": main()