MiniMax 发布 MSA 稀疏注意力方法,开源推理内核并推出 MiniMax-M3 模型
MiniMax 把长上下文注意力从 O(N) 压到固定每查询 2048 token,还同时开源高效内核与生产模型,对做长上下文 agent 的团队是即时可用的方法,遗憾是只限 SM100 GPU。
MiniMax 发布 MSA(MiniMax Sparse Attention),一种构建在 Grouped Query Attention 上的稀疏注意力方法。它将注意力分解为索引分支与主分支:索引分支以块粒度(默认 128 token)为每个 GQA 组选择 16 个 token 块(固定预算 2048 个键值 token),主分支仅在这些块上执行精确 softmax 注意力。MSA 在 109B 参数 MoE 模型上训练,开源了面向 NVIDIA SM100 GPU 的推理内核 `fmha_sm100`(MIT 许可,支持 BF16/FP8/NVFP4/FP4),并发布生产模型 MiniMax-M3。MSA-PT 在 MMLU、GSM8K、HumanEval、RULER-8K、RULER-32K 上分别达 67.2、77.7、64.0、84.2、77.5,与全注意力基线持平。128K 上下文下,其 exp-free Top-k 选择比 `torch.topk` 快 5.1 倍。
MiniMax 发布了 MSA(MiniMax 稀疏注意力),这是一种直接构建在分组查询注意力(GQA)之上的稀疏注意力方法。它针对一个瓶颈:在长上下文中,softmax 注意力的二次方成本。MiniMax 研究团队在一个拥有 1090 亿参数的混合专家模型中对其进行了测试,该模型使用原生多模态数据训练。他们还开源了一个推理内核,并发布了一个生产模型 MiniMax-M3。
什么是 MSA(MiniMax 稀疏注意力)
MSA(MiniMax 稀疏注意力)将注意力分解为两个阶段:索引分支(Index Branch)和主分支(Main Branch)。索引分支决定每个查询应该读取哪些键值块。主分支随后仅对这些块执行精确的 softmax 注意力。
选择是在块粒度上进行的,而不是按词元。默认块大小为 Bk = 128 个词元。每个查询和 GQA 组保留 k = 16 个块。这将每个查询的预算固定为 kBk = 2,048 个键值词元。
这两种成本结构不同。密集 GQA 注意力的每个查询缩放为 O(N),即完整上下文。MSA 缩放为 O(kBk),随着 N 增长而保持不变。因此,随着上下文长度的增加,计算差距会扩大。
选择在每个 GQA 组内部共享,但在组之间独立。一个键值头服务于多个查询头,它们共享一个块集合。不同的组可以关注不同的长距离区域。
两个分支如何工作
索引分支仅向标准 GQA 层添加两个投影矩阵。它为每个 GQA 组定义一个索引查询头和一个共享的索引键头。它对可见的键词元进行评分,然后将这些分数通过最大池化聚合到块级别。
然后,一个 Top-k 算子为每个查询和组选择得分最高的块。包含查询的本地块始终被包括在内。这可以防止选择器丢弃查询的邻近区域。
主分支从选定的块中收集因果可见的词元。它仅对这些词元应用缩放点积 softmax 注意力。每个查询头保留自己的查询投影,但共享该组的块集合。
报告中的一张可视化图表展示了学习型索引器(learned indexer)的选择结果。注意力头集中于局部对角线和第一个块,其余预算留给少数长距离带状区域。


MSA 是如何训练的
Top-k 选择是不可微的,因此语言建模损失无法训练索引投影(index projections)。MSA 通过 KL 对齐损失解决了这个问题。该损失使索引分支(Index Branch)分布与主分支(Main Branch)注意力模式相匹配。教师(teacher)是主分支对所选 token 的组平均(group-averaged)分布。
三种机制稳定了稀疏训练。梯度分离(Gradient Detach)对索引分支输入应用 stop-gradient,将 KL 损失限制在索引投影上,而不影响骨干网络(backbone)。如果没有它,较大的 KL 系数会导致梯度尖峰和损失发散。
索引器预热(Indexer Warmup)在最初若干迭代中让两个分支都运行全注意力。索引器在控制路由之前通过 KL 损失进行学习。强制局部块(Forced Local Block)为邻近上下文保留一个槽位。
消融实验塑造了最终方案。早期变体为索引分支添加了一个带有自身输出的值头(value head)。一旦使用预热,该值头就不再必要。最终设计出于效率考虑将其移除。
MSA 支持两条训练路径。MSA-PT 在 400 亿 token 的索引器预热后从头开始训练。MSA-CPT 则转换一个已在 2.6 万亿 token 上训练好的密集 GQA 检查点,然后继续训练 4000 亿 token,其中包括 400 亿 token 的预热。
内核协同设计
理论上的稀疏性如果没有匹配的 GPU 路径,就无法转化为速度提升。MSA 将算法与两种内核设计思路相结合。
第一种是无指数(exp-free)的 Top-k 选择。Softmax 保留了顺序,因此对原始分数进行排序即可得出相同的索引。该内核在选择之前跳过了 max、exp 和 sum 步骤。在 128K 上下文长度且 k = 16 的条件下,其速度比 torch.topk 快 5.1 倍,同时比 TileLang 的 radix-select 内核快 3.7 倍。
第二种是带查询收集的 KV-外部稀疏注意力。与遍历查询相比,遍历 KV 块会提高算术强度。该内核将 ⌈128/G⌉ 个查询位置打包成一个 128×128 的分数 MMA。一个两阶段的前向传播将注意力计算和合并步骤分散到各个 CTA 上。
这个开源内核 `fmha_sm100` 面向 NVIDIA SM100 GPU。它提供密集 FlashAttention 和稀疏 Top-k 内核,采用 MIT 许可证发布。支持 BF16、FP8、NVFP4 和 FP4 精度。
MSA 与其他稀疏方法的对比
研究团队将 MSA 与四种原生训练的稀疏设计进行了对比。
下表总结了它所描述的差异。
| 方法 | 骨干架构 | 选择粒度 | 索引器/选择信号 |
|---|---|---|---|
| MSA | GQA | 块级(B_k = 128),按 GQA 组进行 Top-k 选择 | KL 对齐损失 |
| NSA | MQA / MHA | 压缩块 + 选中块 + 滑动窗口 | 原生(端到端)训练 |
| InfLLM-V2 | 密集↔稀疏可切换 | 无参数块选择 + 滑动窗口 | 无参数(无训练的索引器) |
| MoBA | GQA | 非常大的 KV 块(块平均键) | 仅 LM 梯度 |
| DSA | MLA(MQA 模式) | Token 级;所有头共享单个 Top-k | ReLU 闪电索引器 |
MSA 的独特之处在于将按 GQA 组共享的 Top-k 与块级选择结合起来。这保持了 KV 读操作的连续性,同时让每个组拥有自己的检索。
质量方面也表现良好。两种稀疏模型总体上与 Full-Attention 基线保持了相当的竞争力。
下表显示了 3T token 预算下的代表性结果。
| 基准 | Full(全注意力) | MSA-PT | MSA-CPT |
|---|---|---|---|
| MMLU | 67.0 | 67.2 | 66.8 |
| GSM8K | 76.2 | 77.7 | 73.7 |
| HumanEval | 61.0 | 64.0 | 57.9 |
| RULER-8K | 79.8 | 84.2 | 77.2 |
| RULER-32K | 75.0 | 77.5 | 75.7 |
| VideoMME | 41.11 | 45.48 | 39.65 |
在长上下文扩展后,MSA-CPT 在 HELMET-128K 和 RULER-128K 上仍接近 Full(全注意力)的水平。每个查询仍只关注 2,048 个键值 token。
解释器演示
包含示例的使用场景
MSA 针对的是上下文长度成为部署瓶颈的工作负载。
- 长时推理智能体:一个跨越数百个推理和动作步骤的智能体会累积大量记录。密集注意力对该历史进行二次方增长的计算。无论长度如何,MSA 都将每个查询的预算保持在 2,048 个 token。
- 仓库级代码推理:一个加载完整仓库的编码智能体可能超过数十万 token。索引器将每个查询路由到少数相关代码块。无关文件保持在所选集合之外。
- 持久记忆:长期运行的助手会持续积累对话状态。MSA 每次读取固定大小的最相关块切片。随着记忆增长,解码成本大致保持平稳。
- 长视频理解:该模型原生支持多模态,并在图像和视频数据上训练。MSA-PT 在多个视频基准(包括 VideoMME 和 TemporalBench)上取得了三次运行中的最高分。稀疏选择可扩展到长视觉 token 序列。
运行内核
最快的方式是使用 Hugging Face kernels 库。
# pip install -U kernels
from kernels import get_kernel
kernel_module = get_kernel("MiniMaxAI/msa", version=0)
sparse_atten_func = kernel_module.sparse_atten_func
sparse_atten_func(...)该仓库还直接展示了规划器、索引器和注意力调用。
import torch
from fmha_sm100 import fmha_sm100, fmha_sm100_plan, sparse_topk_select
page_size, topk = 128, 16
# Dense proxy pass: per-block max score from a cheap Q slice.
proxy_plan = fmha_sm100_plan(
qo_lens, kv_lens, proxy_q.shape[1],
num_kv_heads=1, page_size=page_size, output_maxscore=True,
)
_, max_score = fmha_sm100(
proxy_q, proxy_k_pages, proxy_v_pages, proxy_plan,
kv_indices=kv_indices, output_o=False, output_maxscore=True,
)
# Block scores -> selected KV block indexes.
kv_block_indexes = sparse_topk_select(
max_score.contiguous(), topk, num_valid_pages=num_pages,
)
# Sparse attention over the selected blocks.
sparse_plan = fmha_sm100_plan(
qo_lens, kv_lens, q.shape[1],
num_kv_heads=k_pages.shape[1], page_size=page_size, kv_block_num=topk,
)
out, _ = fmha_sm100(
q, k_pages, v_pages, sparse_plan,
kv_indices=kv_indices, kv_block_indexes=kv_block_indexes,
)这些是该仓库的官方使用示例。输入是调用方准备好的分页键值张量。首次运行时会 JIT 编译索引器,这可能需要几分钟。要求是 SM100 GPU、CUDA Toolkit 和 Python 3.10 或更高版本。
优势与不足
优势
- 在报告的设置下,1M 上下文时每个 token 的注意力计算量下降 28.4 倍。
- 在 H800 上,1M 上下文时实测的端到端加速比达到预填充 14.2 倍、解码 7.6 倍。
- 该设计仅向标准 GQA 层添加了两个投影矩阵。
- 它同时支持从头训练和从稠密 checkpoint 转换。
- 推理内核已通过 MIT 许可证发布。
不足与待解决问题
- 发布的内核针对 NVIDIA SM100;其他架构需要单独处理。
- 在某些子任务上,与全注意力相比仍存在残余的长上下文检索差距。
- 报告的加速比假设特定的头配置和 H800 环境。
- KL 损失相对于普通稠密层增加了训练时的复杂度。
- 结果来自 MiniMax 自身的评估套件,而非第三方复现。