推理不再為 logits「搬磚」:FlashSampling 讓解碼提速 19%

導讀:大型語言模型每生成一個詞彙,都必須先計算數萬個詞彙的機率分佈(logits),從中進行採樣——這批中介資料耗費大量記憶體頻寬卻隨即丟棄。普林斯頓大學聯手德國慕尼黑大學(LMU Munich)提出 FlashSampling 技術,將採樣步驟直接「融入」矩陣乘法運算中,讓 logits 張量不再顯式出現,解碼速度最高提升19%,且在數學上與原始採樣完全等價。

問題癥結:每次採樣都在進行「無效搬磚」

LLM 解碼的每一步都包含三道工序:

  1. 透過矩陣乘法計算出全詞表的 logits(詞表大小常達 10 萬以上);
  2. 利用 Softmax 轉換為機率分佈;
  3. 進行採樣選取一個詞彙。

這三步中,logits 張量猶如「一次性過客」——計算完畢即刻失去用途,卻佔用了大量高頻寬記憶體(HBM)的讀寫資源。隨著模型詞表擴大、批量(batch size)增加,這塊開銷已變得越來越不可忽視。

FlashSampling:將採樣融入矩陣乘法,logits 永不落地

核心洞察:採樣的本質是獲取加權隨機最大值,這可以透過 Gumbel-Max 技巧改寫——為每個詞彙的分數加上一個 Gumbel 噪聲,取其 argmax 即等價於機率採樣。

FlashSampling 將此噪聲注入步驟直接嵌入矩陣乘法的 epilogue(尾處理階段),讓採樣與矩陣乘法在同一個 CUDA kernel 中完成:

  • logits 張量從未寫入 HBM,僅在晶片上快取(SRAM)中即取即用、用後即棄;
  • 僅需傳輸一個 token 索引,而非整個 logits 向量;
  • 支援 top-k、nucleus sampling 等常用策略,無需修改模型架構。

數據說話:H100 上最高提速 19%,1200B 超大模型同樣受益

模型規模加速幅度場景
1.7B+19%單 GPU 小批量
70B~+8%多 GPU 推理
1200B顯著提升超大規模集群

在多 GPU 場景中,FlashSampling 還減少了 AllReduce 的通訊量——原本需同步完整的 logits,現在只需同步降維後的資訊,通訊開銷大幅下降。

該技術已整合至 vLLM 框架,經數學題生成等實際任務驗證無品質損耗,原始碼已在 GitHub 開源(https://github.com/FlashSampling/FlashSampling)。

為何值得關注?

當前大型模型推理優化主要集中在KV Cache 壓縮(如 Google TurboQuant、NVIDIA KVTC)和投机解碼(Speculative Decoding)。FlashSampling 指出了一個此前被忽視的「隱形開銷」——採樣本身的記憶體代價。

這是一次無損耗的純系統優化:不改變模型權重、不犧牲輸出品質、不需要額外硬體,即插即用。對於已在生產環境運行數千張 GPU 的推理集群而言,哪怕只有 10% 的吞吐量提升,都意味著直接節省真實的運算成本。

採樣從來不是推理瓶頸的主角,但 FlashSampling 提醒我們:大型模型系統優化的空間,往往藏在那些「理所當然」的環節裡


來源:arXiv:2603.15854,普林斯頓大學 / 慕尼黑大學 | 2026-03-25


分享網址
AINews·AI 新聞聚合平台
© 2026 AINews. All rights reserved.