NVIDIA 发布 Nemotron-Labs-TwoTower 开放权重扩散语言模型
阅读原文· marktechpost.comNVIDIA这个TwoTower把扩散解码接在已有的AR骨干上,几乎无损质量却让吞吐翻倍,并且开源可商用,对批量文本生成的团队是实在的加速工具。
NVIDIA 发布 Nemotron-Labs-TwoTower,基于冻结的自回归骨干 Nemotron-3-Nano-30B-A3B 的扩散语言模型。采用双塔架构:上下文塔冻结,降噪器塔训练,通过层对齐交叉注意力和状态播种协作。在 2×H100 上 BF16 评估,保留 98.7% 的 AR 基线质量,生成吞吐量提升 2.42 倍(γ=0.8,块大小 S=16)。降噪器在约 2.1T token 上训练,骨干使用 25T token 预训练。总参数约 60B,每 token 活跃参数约 3B/塔。支持扩散、模拟 AR 和 AR 三种解码模式。
NVIDIA 发布了 Nemotron-Labs-TwoTower,这是一个基于预训练自回归骨干网络构建的扩散语言模型。该模型以开放权重形式发布,遵循 NVIDIA Nemotron 开放模型许可协议。此次发布旨在解决文本生成中的吞吐量瓶颈问题。
自回归(AR)模型每次解码一个 token。这种串行过程限制了生成吞吐量。离散扩散语言模型则另辟蹊径。它们并行生成 token,并逐步迭代优化。
大多数扩散语言模型使用一个网络承担两项任务。它在每一步既表示干净 token,又对加噪 token 进行去噪。TwoTower 将这些任务分离到两个塔中。它保持了 AR 基线模型综合基准质量的 98.7%。同时,其实际生成吞吐量提升了 2.42 倍。
太长不看版
- TwoTower 将扩散过程拆分为一个冻结的 AR 上下文塔和一个经过训练的去噪器塔。
- 它以 2.42 倍的吞吐量(γ=0.8,S=16,2×H100)保留了 AR 质量的 98.7%。
- 去噪器在约 2.1T token 上训练;骨干网络使用了 25T token。
- 一个检查点支持扩散、模拟 AR 和 AR 解码三种模式。
Nemotron-Labs-TwoTower
TwoTower 是一种逐块自回归扩散模型。它基于 Nemotron-3-Nano-30B-A3B 实现,这是一个开放权重的混合骨干网络。该骨干网络交错使用了 Mamba-2、自注意力机制和混合专家(MoE)层。
每个塔有 52 层:23 层 Mamba-2、6 层自注意力、23 层 MoE。发布的检查点包含两个塔,总参数量约 600 亿。每个塔每 token 的活跃参数约 30 亿。MoE 使用 128 个可路由专家,其中 6 个激活,外加 2 个共享专家。
两个塔初始时均为同一骨干网络检查点的副本。仅训练去噪器塔,AR 上下文塔保持冻结。去噪器在约 2.1T token 上训练,而骨干网络的预训练数据量为 25T token。
两个塔如何工作
AR 上下文塔对提示词和已确定的 token 进行因果计算。它生成每层的 KV cache 和最终的 Mamba-2 状态,从而保留骨干网络的自回归能力。
扩散去噪器塔对噪声块进行优化。在块内部,它使用块内双向注意力。对于过去的干净块,它保持因果性。
这些塔逐层连接。去噪器层 i 对上下文塔层 i 进行交叉注意力。这种逐层对齐的交叉注意力使得能够多尺度地访问主干网络的表示。此前的方法仅广播最后一个隐藏状态。
另外两项去噪器的修改也很重要。Mamba-2 层从其初始状态源自上下文塔的 Mamba 状态。扩散时间步通过 adaLN-single 时间调节来调制每一层。这个 adaLN 模块只增加了约 150 万个参数。
生成是逐块进行的。每个块从 S 个 [MASK] 模型 token 开始。去噪器经过 T 步对其进行细化,然后确定该块。之后,上下文塔处理已确定的块以更新其缓存。
这解释了为什么多次去噪步骤仍然可以胜过单模型 token 解码。自回归解码每一步仅确定一个模型 token。TwoTower 在细化的早期阶段每步确定多个模型 token。
基准测试
评估在 2×H100 GPU 上使用 BF16 精度进行。默认运行点为置信度解掩码,阈值 γ=0.8,块大小 S=16。该表格将 AR 基线方法与 TwoTower 扩散解码进行了比较。
| 任务 | Nemotron-3-Nano-30B-A3B(自回归) | Nemotron-Labs-TwoTower(扩散) |
|---|---|---|
| MMLU(5-shot, 准确率) | 78.56 | 78.24 |
| MMLU-Pro(5-shot, 思维链精确匹配) | 62.59 | 60.93 |
| ARC-Challenge(25-shot, 标准化准确率) | 91.72 | 92.66 |
| WinoGrande(5-shot, 准确率) | 76.09 | 76.09 |
| RACE(0-shot, 准确率) | 88.90 | 88.90 |
| HumanEval(0-shot) | 79.27 | 75.58 |
| MBPP-Sanitized(3-shot) | 74.71 | 74.28 |
| GSM8K(8-shot, 准确率) | 92.49 | 90.14 |
| MATH-500(4-shot) | 84.40 | 80.60 |
| MMLU Global Lite(5-shot) | 73.97 | 73.94 |
| MGSM(8-shot, 平均准确率) | 80.80 | 80.40 |
| 质量保持 | 100% | 98.7% |
| 生成吞吐量(× 自回归) | 1.0× | 2.42× |
通用知识得分与自回归基线相差约一个点以内。代码和数学能力有适度下降。常识和多语言得分恢复或略有提升。降低 γ 值会使每步确定的模型 token 更多,从而提高吞吐量,但会降低质量。
运行方式:三种生成模式
该检查点提供了三种推理路径。完整的双塔扩散使用 2 个 GPU,在 BF16 精度下每个 GPU 约 59GB。仅自回归模式可在单个 80GB 的 GPU 上运行。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "nvidia/Nemotron-Labs-TwoTower-30B-A3B-Base-BF16"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.bfloat16, trust_remote_code=True,
)
# context tower -> GPU 0, denoiser tower -> GPU 1
model.place_towers_on_devices("cuda:0", "cuda:1")
model.eval()
prompt = "France is a country "
inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
outputs = model.generate_mask_diffusion(
inputs["input_ids"], max_new_tokens=128,
block_size=16, steps_per_block=16, mask_token_id=3,
temperature=0.1, confidence_threshold=0.8,
eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))三个模式分别是 generate_mask_diffusion()、generate_mock_ar() 和 generate_ar()。掩码扩散每步最多提交 block_size 个 token。Mock-AR 和 AR 每步提交一个 token。
适用场景
最直接的用途是加速批量生成。数据团队在生成合成文本时,可以用少量质量下降换取吞吐量提升。在 γ=0.8 时,这种交换是 1.3% 的质量损失换来 2.42 倍速度。
第二种用途是调整质量与吞吐量的权衡。根据 NVIDIA 的论文,提高 γ 可以保留更多质量。降低 γ 则每步提交更多 token 以换取速度。
第三种用途是即插即用式适配。上下文塔保留其语言模型头,用于推测解码、验证或 AR 评分。团队可以基于同一个检查点运行 AR 和扩散。
优势与不足
优势:
- 采用 NVIDIA Nemotron 开放模型许可协议开源权重,可用于商业用途
- 在默认工作点下,以 2.42 倍吞吐量保留了 98.7% 的 AR 质量
- 单个检查点同时支持扩散、mock-AR 和 AR 解码
- 去噪器使用约 2.1T 个 token 训练,无需完整重新预训练
- 序列长度缓存内存规模与 AR 基线相当
不足:
- 完整的双塔扩散需要 2 张 GPU,每张约 59GB BF16 显存
- 代码和数学任务退化幅度大于通用知识(HumanEval 从 79.27 降至 75.58)
- 同时保留两个塔会增加固定的模型权重内存占用
- 发布的检查点是基础模型,尚未经过指令微调或对齐
- 吞吐量超过 3 倍时质量损失更大