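FlashInfer has open-sourced nearly 1,400 high-performance TRT-LLM-Gen GPU kernels optimized for LLM inference. Take the W4A16 quantized GEMM as an example: it uses INT4 weights and BF16 activations, and it improves parallel efficiency with a 3-stage pipeline and warp specialization (load, dequantize, MMA, epilogue). Because INT4 dequantization has to run on CUDA cores, which operate on registers rather than TMEM, the MMA is forced to use TS mode, which leads to an SMEM bandwidth bottleneck. The approach borrows from a design by Cursor: the pipeline hides the gap between CUDA-core and Tensor Core compute, mitigating the throughput loss.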
Curious what's in the PR with almost 1,400 kernels?
Here we walk through a simple batched GEMM kernel:
🟠 Tile size: M128, N16, K256
🟠 W4A16: matrix A is INT4 with a BF16 scaling factor for every 32 elements; matrix B is BF16
🟠 3 pipeline stages
🟠 1-CTA MMA
🟠 Static scheduler
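To make the W4A16 format concrete, here is a minimal NumPy reference for one tile. It is a sketch under stated assumptions, not the kernel's actual layout: INT4 values are assumed to be signed two's complement in [-8, 7], packed two per byte with the even-indexed element in the low nibble; float32 stands in for BF16 because NumPy has no native bfloat16; and `pack_int4`, `unpack_int4`, `dequantize`, and `tile_grid` are hypothetical helper names.

```python
import numpy as np

GROUP = 32  # INT4 elements sharing one scaling factor (from the thread)


def pack_int4(q):
    """Pack signed INT4 values (int8 in [-8, 7]) two per byte along the last axis.
    Assumed layout: element 2i in the low nibble, element 2i+1 in the high nibble."""
    u = (q & 0xF).astype(np.uint8)  # 4-bit two's-complement bit pattern
    return u[..., 0::2] | (u[..., 1::2] << 4)


def unpack_int4(packed):
    """Inverse of pack_int4: returns int8 values in [-8, 7]."""
    nib = np.stack([packed & 0xF, packed >> 4], axis=-1)  # (..., K/2, 2)
    nib = nib.reshape(packed.shape[:-1] + (-1,)).astype(np.int8)
    return np.where(nib >= 8, nib - 16, nib).astype(np.int8)  # sign-extend


def dequantize(packed, scales):
    """packed: (M, K//2) uint8, scales: (M, K//GROUP) -> (M, K) float32.
    float32 stands in for BF16 (assumption: NumPy has no bfloat16)."""
    q = unpack_int4(packed).astype(np.float32)
    return q * np.repeat(scales.astype(np.float32), GROUP, axis=-1)


def cdiv(a, b):
    return -(-a // b)  # ceiling division


def tile_grid(batch, m, n, k, tile_m=128, tile_n=16, tile_k=256):
    """(output tiles, K-blocks per tile) for the thread's M128 x N16 x K256 tile."""
    return batch * cdiv(m, tile_m) * cdiv(n, tile_n), cdiv(k, tile_k)


rng = np.random.default_rng(0)
M, N, K = 128, 16, 256  # exactly one output tile and one K-block
q = rng.integers(-8, 8, size=(M, K), dtype=np.int8)
scales = rng.uniform(0.01, 0.1, size=(M, K // GROUP)).astype(np.float32)
b = rng.standard_normal((K, N)).astype(np.float32)

packed = pack_int4(q)               # (128, 128) bytes: half the size of an int8 A
c = dequantize(packed, scales) @ b  # reference W4A16 result, shape (128, 16)
print(tile_grid(2, 256, 64, 1024))  # (16, 4)
```

Each output tile needs `K / 256` K-blocks, and those K-blocks are what flow through the 3 pipeline stages.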
This warp-specialized kernel has the following warp roles:
🟠 Load A
🟠 Load A scaling factor (SF)
🟠 Load B
🟠 Cast A: dequantizes INT4 to BF16. Waits on Load A and Load A SF
🟠 MMA: performs the matmul. Waits on Cast A and Load B
🟠 Epilogue: performs the activation computation. Waits on MMA
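The dependency graph between these roles can be sketched on the CPU with one Python thread per warp role and bounded queues of depth 3 standing in for the pipeline's buffers. This is an illustrative model only: the queue-based hand-off, the depth of the Cast A output buffer, the ReLU activation (the thread does not say which activation is used), and all function names are assumptions. The static scheduler and batching are not modeled.

```python
import queue
import threading

import numpy as np

STAGES = 3    # pipeline depth (from the thread)
GROUP = 32    # INT4 elements per BF16 scaling factor (from the thread)
TILE_K = 256  # K-block size (from the thread)
SF_K = TILE_K // GROUP  # scaling factors per row per K-block


def relu(x):
    return np.maximum(x, 0.0)  # hypothetical activation choice


def run_tile(q_a, sf_a, b, activation=relu):
    """activation(dequant(A) @ B) for one output tile, one thread per warp role.
    q_a: (M, K) int8 INT4 values, sf_a: (M, K // GROUP), b: (K, N).
    Bounded queues of depth STAGES stand in for the pipeline buffers."""
    k_blocks = q_a.shape[1] // TILE_K
    a_q, sf_q, b_q, cast_q = (queue.Queue(maxsize=STAGES) for _ in range(4))
    acc_q = queue.Queue(maxsize=1)  # finished accumulator, MMA -> Epilogue
    out = {}

    def load(dst, block):  # Load A / Load A SF / Load B
        for kb in range(k_blocks):
            dst.put(block(kb))  # blocks while all STAGES slots are full

    def cast_a():  # waits on Load A and Load A SF
        for _ in range(k_blocks):
            qa, sf = a_q.get(), sf_q.get()
            cast_q.put(qa.astype(np.float32) * np.repeat(sf, GROUP, axis=1))

    def mma():  # waits on Cast A and Load B
        acc = np.zeros((q_a.shape[0], b.shape[1]), dtype=np.float32)
        for _ in range(k_blocks):
            acc += cast_q.get() @ b_q.get()
        acc_q.put(acc)

    def epilogue():  # waits on MMA
        out["c"] = activation(acc_q.get())

    roles = [
        threading.Thread(target=load, args=(a_q, lambda kb: q_a[:, kb * TILE_K:(kb + 1) * TILE_K])),
        threading.Thread(target=load, args=(sf_q, lambda kb: sf_a[:, kb * SF_K:(kb + 1) * SF_K])),
        threading.Thread(target=load, args=(b_q, lambda kb: b[kb * TILE_K:(kb + 1) * TILE_K, :])),
        threading.Thread(target=cast_a),
        threading.Thread(target=mma),
        threading.Thread(target=epilogue),
    ]
    for t in roles:
        t.start()
    for t in roles:
        t.join()
    return out["c"]


rng = np.random.default_rng(1)
M, N, K = 128, 16, 1024  # one M128 x N16 tile, four K-blocks
q = rng.integers(-8, 8, size=(M, K), dtype=np.int8)
sf = rng.uniform(0.01, 0.1, size=(M, K // GROUP)).astype(np.float32)
b = rng.standard_normal((K, N)).astype(np.float32)

c = run_tile(q, sf, b)
ref = relu((q.astype(np.float32) * np.repeat(sf, GROUP, axis=1)) @ b)
print(np.allclose(c, ref, rtol=1e-4, atol=1e-3))  # True
```

Because each queue holds at most 3 entries, the loaders can run at most 3 K-blocks ahead of their consumers, which is the role the 3 pipeline stages play. A real kernel synchronizes on on-chip buffers instead of passing arrays between threads.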
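An interesting thing about this kernel is that its MMA uses TS mode (operand A from TMEM, operand B from SMEM). This is because dequantizing matrix A requires CUDA cores, which work on registers instead of TMEM: the Cast A warps produce BF16 A in registers and store it into TMEM, where the MMA reads it.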
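Since Cast A runs on CUDA cores and MMA runs on Tensor Cores, the two roles can take different amounts of time per K-block. The back-of-the-envelope model below shows why overlapping them in a pipeline helps. It assumes a linear chain Load → Cast A → MMA with fixed, made-up per-K-block costs, and `stages` buffer slots between neighbouring roles that are freed only when the consumer finishes a block. The costs, the chain shape, and the function names are illustrative assumptions, not measurements of the kernel.

```python
def pipeline_time(stage_times, n_blocks, stages):
    """Finish time of the last role when n_blocks K-blocks flow through a linear
    chain of roles with `stages` buffer slots between neighbours. Assumption: a
    slot is freed only once the consumer has finished that block."""
    m = len(stage_times)
    finish = [[0.0] * n_blocks for _ in range(m)]
    for k in range(n_blocks):
        for i, t in enumerate(stage_times):
            start = 0.0
            if i > 0:
                start = max(start, finish[i - 1][k])  # input block is ready
            if k > 0:
                start = max(start, finish[i][k - 1])  # this role is free again
            if i < m - 1 and k >= stages:
                start = max(start, finish[i + 1][k - stages])  # an output slot is free
            finish[i][k] = start + t
    return finish[-1][-1]


def serial_time(stage_times, n_blocks):
    """Every role runs back to back for each K-block, with no overlap."""
    return n_blocks * sum(stage_times)


# Hypothetical per-K-block costs (arbitrary units): load, Cast A (CUDA cores), MMA (Tensor Cores).
times = [1.0, 3.0, 2.0]
print(serial_time(times, 8))       # 48.0
print(pipeline_time(times, 8, 3))  # 27.0
```

With these numbers the pipelined run settles at one K-block per 3 units, the cost of the slowest role (here Cast A): overlap hides the load and MMA time behind the dequantization, but it cannot make the pipeline faster than its slowest stage.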