基于Gist Token的简化稀疏注意力
阅读原文· arxiv.org简化稀疏注意力(SSA)无需改变架构,通过在序列中插入gist token并施加注意力掩码进行继续预训练,使模型将各分块关键信息压缩至gist token。推理时,查询仅与少量gist token打分,选择性展开top-k分块的原始token,避免全KV缓存带宽开销。在LongBench上,SSA在相同压缩比下优于压缩和推理时稀疏注意力基线;在检索增强生成中,经继续预训练后超过全注意力5.7个百分点,归因于选择性展开能集中关注相关分块并过滤噪声。分层变体H-SSA在对数线性解码复杂度下,在32倍压缩比时仍维持或提升精度。代码已开源。
Sparse attention can reduce the cost of long-context inference, but most variants introduce new architectural components. We introduce Simplified Sparse Attention (SSA), a simpler approach to sparse attention that requires no architectural changes. Concretely, we first perform continued pretraining on sequences interleaved with gist tokens. We optimize the standard next-token loss as usual, but the gist tokens use an attention mask to restrict what parts of the context the language model can attend to; this teaches the model to pack each chunk's important information into the gist tokens. At inference time, SSA scores chunks via attention between the current query and the small set of gist tokens, selectively unfolding the top-k chunks by reintroducing their corresponding raw tokens. Since the query is scored only against the gist tokens, we avoid the memory-bandwidth cost associated with naive scoring against the full KV cache, without requiring the auxiliary KV cache approach used by sparse attention methods. On LongBench, SSA consistently outperforms compression and inference-time sparse-attention baselines under the same compression ratio. More strikingly, in retrieval-augmented generation, SSA can even outperform full attention after continued pretraining by over 5.7 points. We attribute this to the ability of SSA's selective unfolding, which concentrates attention on the query-relevant chunks and effectively filters out noise. SSA further extends to a hierarchical gist-of-gist variant (H-SSA) that achieves log-linear decoding complexity while maintaining or improving accuracy at high compression ratios up to 32x. The code is available at https://github.com/yuzhenmao/simplified-sparse-attention/.