Inference No Longer Wastes Cycles on Logits: FlashSampling Accelerates Decoding by Up to 19%

Editor's Note: Every time a large language model generates a token, it first computes a score (logit) for each of the tens of thousands of tokens in its vocabulary, converts those scores into a probability distribution, and then samples from it. This intermediate data consumes substantial memory bandwidth only to be discarded immediately. Researchers from Princeton University and LMU Munich have proposed FlashSampling, a method that folds sampling directly into the matrix multiplication. Because the logits tensor is never written out to GPU memory, decoding is up to 19% faster, while the sampled tokens follow exactly the same distribution as standard sampling.

The Core Problem: Wasted Effort in Every Sampling Step

Each LLM decoding step ends with three distinct stages:

  1. Matrix multiplication to calculate logits for the entire vocabulary (often exceeding 100,000 tokens);
  2. Softmax conversion to derive probabilities;
  3. Sampling to select a single token.
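As a point of reference, the unfused three-stage pipeline can be sketched in a few lines of plain Python. This is a toy under stated assumptions, not production code: the function names are hypothetical, Python lists stand in for GPU tensors, and inverse-CDF sampling is just one common way to implement stage 3.

```python
import math
import random

def lm_head(h, W):
    """Stage 1: one logit per vocabulary entry (W holds one weight row per token)."""
    return [sum(x * w for x, w in zip(h, row)) for row in W]

def softmax(logits):
    """Stage 2: numerically stable softmax (subtract the max before exponentiating)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample(probs, rng):
    """Stage 3: inverse-CDF sampling of a single token id."""
    u = rng.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return i
    return len(probs) - 1  # guard against round-off leaving the total just below 1

def decode_step(h, W, rng):
    """Unfused baseline: two full-vocabulary lists are built to produce one integer."""
    logits = lm_head(h, W)   # full vocabulary-sized vector
    probs = softmax(logits)  # a second full vocabulary-sized vector
    return sample(probs, rng)

if __name__ == "__main__":
    rng = random.Random(0)
    W = [[0.5, -1.0], [2.0, 0.3], [-0.7, 1.1], [0.0, 0.0]]  # toy 4-token vocabulary
    h = [1.0, 0.5]                                          # toy hidden state
    print([round(p, 3) for p in softmax(lm_head(h, W))])
    print(decode_step(h, W, rng))
```

On a GPU, each of those intermediate lists corresponds to a [batch, vocabulary] tensor that is written to and read back from HBM.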

In this trio of operations, the logits tensor is a "one-time visitor": it is computed, read once, and discarded, yet writing it to and reading it back from High Bandwidth Memory (HBM) consumes a significant share of memory bandwidth. As model vocabularies expand and batch sizes grow, this overhead becomes ever harder to ignore.
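A back-of-the-envelope calculation shows the scale of the problem. The inputs are illustrative assumptions, not figures from the paper: a batch of 256 sequences, a 128,256-token vocabulary (the size used by Llama 3), fp32 logits, and int64 token ids.

```python
def logits_bytes(batch_size: int, vocab_size: int, bytes_per_elem: int) -> int:
    """Size of one [batch_size, vocab_size] logits tensor, in bytes."""
    return batch_size * vocab_size * bytes_per_elem

def token_ids_bytes(batch_size: int, bytes_per_id: int = 8) -> int:
    """Size of the output if only one int64 token id per sequence is written."""
    return batch_size * bytes_per_id

if __name__ == "__main__":
    batch, vocab = 256, 128_256           # hypothetical batch; Llama 3-sized vocabulary
    full = logits_bytes(batch, vocab, 4)  # fp32 logits (assumption)
    ids = token_ids_bytes(batch)
    print(f"logits per step:    {full / 1e6:.1f} MB")
    print(f"token ids per step: {ids} bytes")
    print(f"ratio:              {full // ids:,}x")
```

Under these assumptions, every pass over the logits moves about 131.3 MB per decoding step, while the useful result is 2,048 bytes of token ids, a ratio of 64,128x. An unfused pipeline makes several such passes (write logits, read them for softmax, read probabilities for sampling).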

FlashSampling: Integrating Sampling into MatMul, Logits Never Reach HBM

The core insight is that sampling from a probability distribution can be recast as taking an argmax over randomly perturbed scores. This is the Gumbel-Max trick: adding independent Gumbel noise to each token's logit and taking the argmax yields exactly the same distribution as sampling from the softmax of the logits.
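Formally, if g_i = -log(-log u_i) with each u_i drawn independently from Uniform(0, 1), then P(argmax_i (z_i + g_i) = k) = exp(z_k) / Σ_j exp(z_j). The stdlib-only Python sketch below checks this empirically by comparing Gumbel-Max sample frequencies against the softmax. The function names are hypothetical, and the temperature parameter is an illustrative addition (dividing the logits by T before adding noise samples from softmax(z / T)), not something taken from the paper.

```python
import math
import random
from collections import Counter

def softmax(logits):
    """Reference probabilities exp(z_i) / sum_j exp(z_j), computed stably."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def gumbel_noise(rng):
    """One Gumbel(0, 1) draw via the inverse CDF: g = -log(-log(u)), u ~ U(0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() can return exactly 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_max_sample(logits, rng, temperature=1.0):
    """Sample index i with probability softmax(logits / temperature)[i], using only an argmax."""
    perturbed = [z / temperature + gumbel_noise(rng) for z in logits]
    return max(range(len(logits)), key=perturbed.__getitem__)

if __name__ == "__main__":
    rng = random.Random(0)
    logits = [2.0, 1.0, 0.1, -1.0]
    n = 100_000
    counts = Counter(gumbel_max_sample(logits, rng) for _ in range(n))
    for i, p in enumerate(softmax(logits)):
        print(f"token {i}: softmax={p:.3f}  gumbel-max={counts[i] / n:.3f}")
```

No normalization step appears in gumbel_max_sample: the argmax only needs each perturbed score, which is what makes the trick a good fit for a matmul epilogue.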

FlashSampling embeds this noise injection directly into the epilogue of the matrix multiplication (the final stage of a GEMM kernel, where results are post-processed before being written out), so that matmul and sampling complete within a single CUDA kernel:

  • The logits tensor is never written to HBM; each tile of logits is consumed in on-chip SRAM and discarded immediately;
  • Only a single token index per sequence is written out, rather than the entire logits vector;
  • It supports common strategies like top-k and nucleus sampling without requiring any model modifications.
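The fused design can be mimicked in plain Python to show why the full logits never need to exist at once. Everything here is a hypothetical illustration rather than FlashSampling's kernel: the helper names are invented, the sequential loop over tiles stands in for parallel thread blocks whose partial results a real kernel would combine in a final reduction, and the per-position seeded random.Random is an assumed substitute for a counter-based GPU generator such as Philox.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def gumbel_at(seed, row, col):
    """Reproducible Gumbel(0, 1) noise for output position (row, col).

    Hypothetical stand-in for a counter-based GPU RNG: the same (seed, row, col)
    always gives the same draw, so no noise tensor ever has to be stored.
    """
    rng = random.Random(f"{seed}:{row}:{col}")
    u = rng.random()
    while u == 0.0:  # random() can return exactly 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))

def fused_sample(h, W, row, seed, tile_size=4, temperature=1.0):
    """Sample a token id for hidden state h without keeping the full logits vector.

    The vocabulary (rows of W) is visited tile by tile. Each tile's logits exist
    only inside the loop body, the analogue of on-chip SRAM, and only a running
    (best value, best index) pair is carried from tile to tile.
    """
    best_val, best_idx = -math.inf, -1
    for start in range(0, len(W), tile_size):
        tile_logits = [dot(h, w) for w in W[start:start + tile_size]]  # matmul, one tile
        for offset, z in enumerate(tile_logits):                       # "epilogue"
            col = start + offset
            v = z / temperature + gumbel_at(seed, row, col)
            if v > best_val:
                best_val, best_idx = v, col
    return best_idx  # a single integer is all that leaves the "kernel"

def unfused_sample(h, W, row, seed, temperature=1.0):
    """Baseline for comparison: build all logits, perturb them, then take the argmax."""
    logits = [dot(h, w) for w in W]
    perturbed = [z / temperature + gumbel_at(seed, row, c) for c, z in enumerate(logits)]
    return max(range(len(W)), key=perturbed.__getitem__)

if __name__ == "__main__":
    rng = random.Random(0)
    d, vocab = 8, 37
    W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(vocab)]
    h = [rng.gauss(0.0, 1.0) for _ in range(d)]
    for seed in range(5):
        print(seed, fused_sample(h, W, row=0, seed=seed), unfused_sample(h, W, row=0, seed=seed))
```

Because the noise for each (row, col) is reproducible, the fused and unfused versions return the same token for the same seed regardless of tile size; only the fused one avoids holding the full logits list.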

Data Speaks: Up to 19% Speedup on H100, Benefiting Even 1.2T Parameter Models

Model Scale      Speedup                   Scenario
1.7B             +19%                      Single GPU, small batch
70B              ~8%                       Multi-GPU inference
1.2T (1,200B)    Significant improvement   Ultra-large-scale clusters

In multi-GPU scenarios, FlashSampling also reduces AllReduce communication volume. Previously, the complete logits had to be synchronized across GPUs; now only a compact, dimensionality-reduced summary is exchanged, sharply cutting communication overhead.
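One natural way to exchange "dimensionality-reduced information" with Gumbel-Max is for each GPU to send only its best perturbed value and the matching token id, then keep the global maximum. The sketch below assumes a vocabulary-parallel split of the logits and simulates the GPUs as list slices; the helper names are hypothetical, and it illustrates the idea rather than the exact protocol used by FlashSampling or vLLM.

```python
import math
import random

def gumbel_at(seed, col):
    """Reproducible Gumbel(0, 1) draw for vocabulary position col (hypothetical helper)."""
    rng = random.Random(f"{seed}:{col}")
    u = rng.random()
    while u == 0.0:  # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

def shard_candidate(shard_logits, offset, seed):
    """Work done on one GPU: the best (perturbed value, global token id) in its shard."""
    best_val, best_idx = -math.inf, -1
    for j, z in enumerate(shard_logits):
        v = z + gumbel_at(seed, offset + j)
        if v > best_val:
            best_val, best_idx = v, offset + j
    return best_val, best_idx

def sharded_sample(logits, num_gpus, seed):
    """Vocabulary split across num_gpus shards; candidates are combined by a max-reduction.

    Each shard contributes 2 numbers, instead of roughly len(logits) / num_gpus
    logits if the whole vector had to be gathered.
    """
    shard_size = math.ceil(len(logits) / num_gpus)
    candidates = [
        shard_candidate(logits[s:s + shard_size], s, seed)
        for s in range(0, len(logits), shard_size)
    ]
    # Stand-in for the collective: keep the candidate with the largest value
    # (max() returns the first such candidate, matching the single-GPU tie-break).
    return max(candidates, key=lambda c: c[0])[1]

def single_gpu_sample(logits, seed):
    """Reference: perturb the full logits vector and take the argmax."""
    perturbed = [z + gumbel_at(seed, c) for c, z in enumerate(logits)]
    return max(range(len(logits)), key=perturbed.__getitem__)

if __name__ == "__main__":
    rng = random.Random(0)
    logits = [rng.gauss(0.0, 2.0) for _ in range(200)]
    for seed in range(3):
        print(seed, sharded_sample(logits, 8, seed), single_gpu_sample(logits, seed))
```

The argmax of a maximum is associative across shards, so the per-shard (value, id) pairs are enough to recover exactly the token a single GPU would have sampled.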

The method has been integrated into vLLM. Validation on practical tasks, such as generating solutions to math problems, shows no loss in quality. The code is open-sourced on GitHub at https://github.com/FlashSampling/FlashSampling.

Why This Matters

Current optimizations for large-model inference focus primarily on KV cache compression (e.g., Google's TurboQuant, NVIDIA's KVTC) and speculative decoding. FlashSampling highlights a previously overlooked "hidden cost": the memory traffic generated by the sampling process itself.

This is a purely system-level, lossless optimization: it does not alter model weights, sacrifice output quality, or require additional hardware, so it is plug-and-play. For inference clusters already running thousands of GPUs in production, even a 10% throughput increase translates directly into computational cost savings.
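The cost claim is easy to quantify under simple assumptions: if the workload stays fixed and throughput scales linearly with GPU count, a throughput gain of g lets the same work run on 1 / (1 + g) of the GPUs. The fleet size in this sketch is a hypothetical example.

```python
def gpus_needed(current_gpus: float, throughput_gain: float) -> float:
    """GPUs required for an unchanged workload after per-GPU throughput rises by throughput_gain."""
    return current_gpus / (1.0 + throughput_gain)

if __name__ == "__main__":
    before = 1000                      # hypothetical fleet size
    after = gpus_needed(before, 0.10)  # 10% throughput increase
    print(f"{after:.1f} GPUs needed, {1 - after / before:.1%} fewer")
```

A 10% throughput gain therefore cuts the required GPUs by about 9.1% (1 - 1/1.1), slightly less than 10%; on a 1,000-GPU fleet that is roughly 91 GPUs.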

Sampling has never been regarded as a major inference bottleneck, but FlashSampling is a reminder that optimization opportunities in large-model systems often hide in the steps we take for granted.


Source: arXiv:2603.15854, Princeton University / LMU Munich | 2026-03-25


AINews · AI News Aggregation Platform
© 2026 AINews. All rights reserved.