SparDA: Sparse Decoupled Attention for Efficient Long-Context LLM Inference

2026-06-03

Computation and Language; Machine Learning
AI summary

The authors present SparDA, a method that makes sparse attention in large language models more efficient over long inputs. They add a fourth per-layer projection, called Forecast, that predicts which blocks of the key-value (KV) cache the next layer will need. Those blocks can then be copied from CPU to GPU memory while the current layer is still computing, which hides transfer delays. Using one Forecast head per group of attention heads also lowers the cost of choosing which blocks to attend to, and the whole approach adds less than 0.5% to model size. On two 8-billion-parameter models, SparDA matches or slightly improves accuracy and speeds up both prompt processing (prefill) and token generation (decode). Because the KV cache can live in CPU memory, larger batches fit on a single GPU, raising decode throughput by up to 5.3× over a sparse baseline that does not offload.
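The lookahead prefetch described above can be sketched as a small two-stage pipeline. This is a minimal sketch under toy assumptions, not SparDA's code: a one-worker `ThreadPoolExecutor` stands in for a GPU copy stream, each KV block is reduced to a single integer, and `forecast`, `fetch_blocks`, and `run_layer` are hypothetical stand-ins. What it shows is the ordering: the Forecast is computed from a layer's input, so the copy for layer l+1 can be issued before layer l runs and overlap with its compute.

```python
from concurrent.futures import ThreadPoolExecutor


def fetch_blocks(cpu_kv, layer, block_ids):
    """Simulated CPU-to-GPU copy of the selected KV blocks of one layer."""
    return {b: cpu_kv[layer][b] for b in block_ids}


def forecast(hidden, num_blocks, top_k):
    """Hypothetical Forecast head: block ids the NEXT layer is predicted to need.

    A toy deterministic rule stands in for scoring blocks with the Forecast
    projection of this layer's input.
    """
    return sorted((hidden * 7 + i * 3) % num_blocks for i in range(top_k))


def run_layer(hidden, gpu_blocks):
    """Stand-in for sparse attention over the fetched blocks plus the MLP."""
    return hidden + sum(gpu_blocks.values())


def decode_step(hidden, cpu_kv, num_blocks, top_k, first_blocks):
    """Run all layers, issuing layer l+1's prefetch before layer l computes."""
    num_layers = len(cpu_kv)
    with ThreadPoolExecutor(max_workers=1) as copy_stream:  # plays the copy engine
        # Nothing precedes layer 0, so its blocks are requested up front.
        pending = copy_stream.submit(fetch_blocks, cpu_kv, 0, first_blocks)
        for layer in range(num_layers):
            gpu_blocks = pending.result()  # wait until this layer's KV has arrived
            if layer + 1 < num_layers:
                # Forecast uses the current layer's input, so the next layer's
                # copy can start before this layer's attention runs.
                next_ids = forecast(hidden, num_blocks, top_k)
                pending = copy_stream.submit(fetch_blocks, cpu_kv, layer + 1, next_ids)
            # The next layer's copy may still be in flight while this runs.
            hidden = run_layer(hidden, gpu_blocks)
    return hidden


def decode_step_sequential(hidden, cpu_kv, num_blocks, top_k, first_blocks):
    """Reference without overlap: fetch, then compute, one layer at a time."""
    block_ids = first_blocks
    for layer in range(len(cpu_kv)):
        gpu_blocks = fetch_blocks(cpu_kv, layer, block_ids)
        block_ids = forecast(hidden, num_blocks, top_k)
        hidden = run_layer(hidden, gpu_blocks)
    return hidden


if __name__ == "__main__":
    # Toy KV store: 3 layers x 8 blocks, each block reduced to one integer.
    cpu_kv = [[layer * 100 + b for b in range(8)] for layer in range(3)]
    print(decode_step(1, cpu_kv, num_blocks=8, top_k=2, first_blocks=[0, 1]))
    print(decode_step_sequential(1, cpu_kv, num_blocks=8, top_k=2, first_blocks=[0, 1]))
```

Both calls return 618 for this toy input: the pipelined schedule changes only when copies are issued, not which blocks each layer sees.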

sparse attention, large language models, key-value cache, sequence length, GPU prefetching, query-key-value projections, attention mechanism, compute bottleneck, batch size, decode throughput
Authors
Yaosheng Fu, Guangxuan Xiao, Xin Dong, Song Han, Oreste Villa
Abstract
Sparse attention reduces compute and memory bandwidth for long-context LLM inference. However, two key challenges remain: (1) KV cache capacity still grows with sequence length, and offloading to CPU memory introduces a PCIe transfer bottleneck; (2) the sparse selection step itself retains $O(T^2)$ complexity in the sequence length $T$ and can dominate attention cost at long contexts. We propose SparDA, a decoupled sparse attention architecture that introduces a fourth per-layer projection, the Forecast, alongside Query, Key, and Value. The Forecast predicts the KV blocks needed by the next layer, enabling lookahead selection that overlaps CPU-to-GPU prefetch with current-layer execution. Because Forecast is decoupled from the attention query, our GQA implementation uses one Forecast head per GQA group, reducing selection overhead versus the original multi-head selector. SparDA adds less than 0.5% additional parameters and trains only the Forecast projections by matching the original selector's attention distribution. On two sparse-pretrained 8B models, SparDA matches or slightly improves accuracy and delivers up to $1.25\times$ prefill speedup and $1.7\times$ decode speedup over the sparse-attention offload baseline. By enabling larger feasible batch sizes on a single GPU, SparDA further reaches up to $5.3\times$ higher decode throughput than the non-offload sparse baseline. Our source code is available at https://github.com/NVlabs/SparDA.
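Per-group block selection and the distillation-style training target can be illustrated with a small pure-Python sketch. It rests on assumptions the abstract does not confirm: each KV block is summarized by one key vector, the target is the average of the original selector heads' softmax over blocks at the next layer, and "matching" is measured with a KL divergence. All function names are hypothetical. Since the query heads of a GQA group share the same K/V heads, a single Forecast head per group produces exactly the one block set that group needs fetched.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def block_scores(head, block_keys):
    """Scaled dot-product score of one head against each block's summary key."""
    scale = 1.0 / math.sqrt(len(head))
    return [dot(head, key) * scale for key in block_keys]


def select_blocks(forecast_head, block_keys, top_k):
    """One Forecast head per GQA group chooses the top_k blocks for the whole group."""
    scores = block_scores(forecast_head, block_keys)
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]


def selector_distribution(selector_heads, block_keys):
    """Teacher target: the original selector's per-head block distributions, averaged."""
    per_head = [softmax(block_scores(h, block_keys)) for h in selector_heads]
    return [sum(p[i] for p in per_head) / len(per_head) for i in range(len(block_keys))]


def forecast_kl_loss(forecast_head, selector_heads, block_keys):
    """KL(teacher || forecast) over blocks.

    In training, only the Forecast projection would be updated to lower this
    value; this sketch just evaluates the loss.
    """
    p = selector_distribution(selector_heads, block_keys)
    q = softmax(block_scores(forecast_head, block_keys))
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


if __name__ == "__main__":
    block_keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
    selector_heads = [[2.0, 0.5], [1.5, 1.0]]  # next-layer selector heads of one group
    forecast_head = [1.8, 0.7]
    print(select_blocks(forecast_head, block_keys, top_k=2))  # [0, 1]
    print(forecast_kl_loss(forecast_head, selector_heads, block_keys))
```

A Forecast head that scores blocks the way the selector heads do gets a loss near zero, while one favoring different blocks is penalized; the selection itself then costs one scoring pass per group instead of one per selector head.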