Flash-GMM:面向可扩展软聚类的内存高效内核
阅读原文· arxiv.orgFlash-GMM 是一个基于 Triton 的融合内核,可在单次 GPU pass 中高效计算大规模高斯混合模型(GMM)。它无需在 GPU 内存中实例化完整责任矩阵,相比现有实现实现 20 倍加速,并支持在单设备上训练比之前大 100 倍以上的数据集。将 Flash-GMM 集成到 IVF 粗量化器中用于近似最近邻搜索(ANN)后,软 GMM 聚类可替代 k-means,利用 GMM 责任矩阵将边界向量分配到多个簇。该方法达到固定召回目标时所需距离计算减少 1.7 倍,或在同等计算成本下召回@10 提升 2–12。该内核已作为开源项目发布。
We present Flash-GMM, a fused Triton kernel for efficient computation of Gaussian Mixture Models (GMMs) over large-scale data in a single GPU pass. By eliminating the need to materialize the full responsibility matrix in GPU memory, Flash-GMM achieves a 20times speedup over existing implementations and enables training on datasets more than 100times larger than previously feasible on one device. To demonstrate its impact, we integrate Flash-GMM into the IVF coarse quantizer for approximate nearest-neighbor (ANN) search. We show that soft GMM clustering is now a viable drop-in replacement for k-means, and that GMM responsibilities can be leveraged to assign border vectors to multiple clusters. Our approach reaches fixed recall targets with up to 1.7times fewer distance computations, or equivalently, yields +2--12 recall@10 at matched computational cost. We release the kernel as an open-source project.