FlashInfer: Attention States & Recursive Merge

Visualizing how FlashInfer accelerates LLM inference

Key Innovation: Attention States

FlashInfer introduces the concept of attention states, which fully characterize the attention between a query and a set of key/value pairs. Each attention state consists of two components:

Generalized Score (s)

s(I) = log( ∑_{i∈I} exp(s_i) )

The log-sum-exp (LSE) of pre-softmax attention scores

Generalized Value (v)

v(I) = ∑_{i∈I} softmax(s)_i · v_i

The sum of value vectors weighted by the softmax of the scores over I (equivalently, weight exp(s_i − s(I)) for each i)
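As a concrete sketch of these two definitions (the function name `attention_state` is ours, not FlashInfer's API), both components can be computed with NumPy, using the usual max-subtraction trick for a numerically stable log-sum-exp:

```python
import numpy as np

def attention_state(scores, values):
    """Attention state (v, s) for a query over one set of KV pairs.

    scores: (n,) pre-softmax attention scores s_i
    values: (n, d) value vectors v_i
    """
    # Generalized score: stable log-sum-exp of the scores.
    m = scores.max()
    s = m + np.log(np.exp(scores - m).sum())
    # Generalized value: softmax-weighted sum of the value vectors.
    # exp(s_i - s(I)) is exactly softmax(scores)_i.
    weights = np.exp(scores - s)
    v = weights @ values
    return v, s
```

Note that once s(I) is known, the softmax weights fall out for free as exp(s_i − s(I)); this is why carrying the log-sum-exp alongside the value makes the state mergeable.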

Recursive Merge Operator

The key insight of FlashInfer is that attention states can be merged efficiently. Given two attention states computed over disjoint subsets of KV pairs, we can compute the attention state for their union:

[v(I∪J), s(I∪J)] = [v(I), s(I)] ⊕ [v(J), s(J)]

This merge operator (⊕) is commutative and associative, so the KV pairs can be partitioned arbitrarily, partial states computed in any order (including in parallel), and the results merged into the exact full attention output.
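A minimal NumPy sketch of the ⊕ operator (the function name `merge` is ours): the merged score is the log-sum-exp of the two partial scores, and the merged value re-weights each partial value by its share of the combined softmax mass:

```python
import numpy as np

def merge(state_a, state_b):
    """[v_I, s_I] ⊕ [v_J, s_J] -> [v_{I∪J}, s_{I∪J}] for disjoint I, J."""
    v_a, s_a = state_a
    v_b, s_b = state_b
    # s(I∪J) = log(exp(s_I) + exp(s_J)), computed stably.
    s = np.logaddexp(s_a, s_b)
    # Each side contributes exp(s_X - s) of the total softmax mass.
    v = np.exp(s_a - s) * v_a + np.exp(s_b - s) * v_b
    return v, s
```

Merging the states of two halves of a KV set reproduces the state of the whole set, and swapping the arguments gives the same result, which is the commutativity the text describes.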


Applications

Shared-Prefix Batch Decoding

When multiple sequences share a common prefix (e.g., same prompt), compute the attention state for the shared part once, then merge with each sequence's unique suffix.

Up to 30x speedup in long-context scenarios
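A toy sketch of this pattern (function names and shapes are illustrative, not FlashInfer's actual API): the shared prefix KV cache is read once for the whole batch, each query's prefix state is merged with its own suffix state, and the result matches attention over the concatenated sequence. In the real kernels the speedup comes from reusing the shared KV cache in fast memory; this sketch only shows that the math composes exactly:

```python
import numpy as np

def state(scores, values):
    # Attention state (v, s) for one set of KV pairs.
    s = np.log(np.exp(scores).sum())
    return np.exp(scores - s) @ values, s

def merge(a, b):
    # [v_I, s_I] ⊕ [v_J, s_J] for disjoint index sets I and J.
    (v_a, s_a), (v_b, s_b) = a, b
    s = np.logaddexp(s_a, s_b)
    return np.exp(s_a - s) * v_a + np.exp(s_b - s) * v_b, s

def shared_prefix_decode(queries, prefix_k, prefix_v, suffix_kv):
    """queries: (b, d); prefix_k/prefix_v: (n, d); suffix_kv: list of (k_i, v_i)."""
    # One pass over the shared prefix KV cache serves the whole batch.
    prefix_scores = queries @ prefix_k.T          # (b, n)
    outputs = []
    for q, ps, (k, v) in zip(queries, prefix_scores, suffix_kv):
        prefix_state = state(ps, prefix_v)        # query's state over the prefix
        suffix_state = state(q @ k.T, v)          # state over its unique suffix
        outputs.append(merge(prefix_state, suffix_state)[0])
    return np.stack(outputs)
```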

KV Sequence Parallelism

Partition long KV sequences across multiple processing units, compute partial attention states in parallel, then merge the results.

Improves GPU utilization for memory-constrained scenarios
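A sketch of the same idea in NumPy (a serial stand-in for what would run across processing units in parallel; names are illustrative): split the KV sequence into chunks, compute each chunk's partial state independently, and fold the partials with ⊕. Associativity guarantees the reduction order, whether a left fold or a tree, does not change the result:

```python
from functools import reduce
import numpy as np

def state(scores, values):
    # Attention state (v, s) for one chunk of KV pairs.
    s = np.log(np.exp(scores).sum())
    return np.exp(scores - s) @ values, s

def merge(a, b):
    # [v_I, s_I] ⊕ [v_J, s_J] for disjoint index sets I and J.
    (v_a, s_a), (v_b, s_b) = a, b
    s = np.logaddexp(s_a, s_b)
    return np.exp(s_a - s) * v_a + np.exp(s_b - s) * v_b, s

def chunked_attention(scores, values, num_chunks):
    # Each chunk's partial state is independent of the others, so in a real
    # kernel these would be computed on different units in parallel.
    partials = [state(sc, va)
                for sc, va in zip(np.array_split(scores, num_chunks),
                                  np.array_split(values, num_chunks))]
    return reduce(merge, partials)
```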