原文 · 未翻译
In this tutorial, we explore OpenMythos by building an advanced recurrent-depth transformer workflow that runs end-to-end in Google Colab. We create both MLA and GQA model variants, compare their parameter counts, and check the stability of the recurrent injection matrix through its spectral radius. We then move from simple forward and generation tests into a synthetic compositional reasoning task, where the model learns to predict the sum of digit chains modulo a fixed value. Through this setup, we study how recurrent loops enable a single model to reuse its parameters for deeper computation.
Copy CodeCopiedUse a different Browser
# Environment setup: install OpenMythos if missing, import dependencies,
# seed all RNGs, and pick the compute device used by the rest of the script.
import subprocess
import sys


def pip(*args):
    """Run ``pip install -q <args>`` with the current interpreter.

    Best-effort: ``check=False`` means a failed install does not raise here;
    the failure surfaces at the subsequent ``import`` instead.
    """
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", *args],
        check=False,
    )


# Try the already-installed package first, then PyPI, then the GitHub source.
try:
    import open_mythos  # noqa: F401
except Exception:
    pip("open-mythos")
    try:
        import open_mythos  # noqa: F401
    except Exception:
        pip("git+https://github.com/kyegomez/OpenMythos.git")

import math
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

from open_mythos.main import OpenMythos, MythosConfig

# Fixed seed so repeated runs produce the same initial weights and data.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to CPU. This name is
# read by the banner below (and was previously undefined, raising NameError).
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device} | Torch: {torch.__version__}")
We install OpenMythos and fall back to the GitHub source if installing from PyPI fails. We import the required Python, PyTorch, NumPy, and plotting libraries for model building, training, and visualization. We also set a fixed random seed and use CUDA when available, so the tutorial runs efficiently in Colab.
Copy CodeCopiedUse a different Browser
def build_model(attn_type: str = "mla", max_loop_iters: int = 8) -> tuple:
    """Build a small OpenMythos model. Two attention variants supported.

    MLA — Multi-Latent Attention (compressed KV cache, DeepSeek-V2 style)
    GQA — Grouped-Query Attention (fewer KV heads than Q heads)

    Args:
        attn_type: ``"gqa"`` selects the grouped-query configuration; any
            other value (default ``"mla"``) selects the latent-attention
            configuration.
        max_loop_iters: forwarded to ``MythosConfig(max_loop_iters=...)``.

    Returns:
        ``(model, cfg)`` — the ``OpenMythos`` instance moved to the
        module-level ``device``, and the ``MythosConfig`` it was built from.
    """
    base = dict(
        vocab_size=64,
        dim=128,
        n_heads=4,
        max_seq_len=32,
        max_loop_iters=max_loop_iters,
        prelude_layers=1,
        coda_layers=1,
        n_experts=4,
        n_shared_experts=1,
        n_experts_per_tok=2,
        expert_dim=64,
        lora_rank=8,
        attn_type=attn_type,
    )
    # The shared settings must be unpacked as keyword arguments; passing the
    # dict positionally (as the mangled original did) is not a valid config.
    if attn_type == "gqa":
        # NOTE(review): the GQA branch was lost in extraction. n_kv_heads=2 is
        # reconstructed from the docstring ("fewer KV heads than Q heads",
        # n_heads=4) — confirm against the upstream tutorial.
        cfg = MythosConfig(**base, n_kv_heads=2)
    else:
        # Latent-attention settings; these were wrongly attached to the GQA
        # branch, leaving `cfg` unbound for the default "mla" call.
        cfg = MythosConfig(
            **base,
            n_kv_heads=4,
            kv_lora_rank=32,
            q_lora_rank=32,
            qk_rope_head_dim=16,
            qk_nope_head_dim=16,
            v_head_dim=16,
        )
    model = OpenMythos(cfg).to(device)
    return model, cfg


model_mla, cfg_mla = build_model("mla")
model_gqa, cfg_gqa = build_model("gqa")


def n_params(m):
    """Return the total number of scalar elements over ``m.parameters()``."""
    return sum(p.numel() for p in m.parameters())


print(f"\n[MLA] params: {n_params(model_mla):>10,}")
print(f"[GQA] params: {n_params(model_gqa):>10,}")


def spectral_radius(model):
    """Return the largest eigenvalue magnitude of the recurrent injection matrix.

    Reads ``model.recurrent.injection.get_A()``. A 1-D result is treated as
    the diagonal of the matrix, so the radius is its largest absolute entry;
    otherwise the eigenvalues are computed explicitly.
    """
    A = model.recurrent.injection.get_A().detach().cpu()
    if A.dim() == 1:
        rho = A.abs().max().item()
    else:
        rho = torch.linalg.eigvals(A.float()).abs().max().item()
    return rho


# NOTE(review): the original message was cut off after "(must be "; the
# closing "< 1)" is reconstructed from the stability check described in the
# tutorial text — confirm against the upstream source.
print(f"\nρ(A) MLA: {spectral_radius(model_mla):.4f} (must be < 1)")

# ---------------------------------------------------------------------------
# NOTE(review): a large section of the tutorial is missing here. Everything
# between the "<" in the message above and a later ">=" comparison was
# stripped during HTML extraction. The lost code is expected to define the
# names used below — `model`, `digits`, `toks`, `true_tok`, `M`, `DIGIT_BASE`
# — along with the synthetic dataset, training loop and evaluation. Restore
# it from the upstream tutorial before running; it is deliberately not
# re-invented here.
#
# Surviving tail of the truncated statement that built `toks`:
#     ... >= DIGIT_BASE]
# ---------------------------------------------------------------------------

# Final demo: generate one token for a single prompt and compare to the target.
prompt = torch.tensor([toks], device=device)
with torch.no_grad():
    gen = model.generate(prompt, max_new_tokens=1, n_loops=8)
predicted = gen[0, -1].item()  # last position holds the newly generated token

print(f"\nDemo: digits={digits}, target=({'+'.join(map(str, digits))}) % {M} = {sum(digits)%M}")
print(f" true token={true_tok} (digit {true_tok-DIGIT_BASE}) | "
      f"predicted token={predicted} (digit {predicted-DIGIT_BASE if predicted>=DIGIT_BASE else '?'})")
print("\nDone. Key takeaway: at inference, increasing n_loops trades compute for")
print("reasoning depth on the same fixed-parameter model — that's the RDT premise.")