Visualizing how FlashInfer accelerates LLM inference
FlashInfer introduces the concept of attention states, which fully characterize the attention between a query and a set of key/value pairs. Each attention state consists of two components:
The log-sum-exp (LSE) of pre-softmax attention scores
The weighted sum of value vectors using the softmax of scores
The key insight of FlashInfer is that attention states can be merged efficiently. Given two attention states (s1, v1) and (s2, v2) computed over disjoint subsets of KV pairs, the attention state (s, v) for their union is:

s = log(exp(s1) + exp(s2))
v = exp(s1 - s) * v1 + exp(s2 - s) * v2
This merge operator (⊕) is commutative and associative, allowing flexible computation strategies.
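As a concrete illustration, here is a minimal NumPy sketch of the merge rule (not FlashInfer's actual CUDA kernels; the names `attention_state` and `merge` are illustrative). It checks that merging the states of two halves of the KV set reproduces full attention, and that the operator is commutative.

```python
import numpy as np

def attention_state(q, K, V):
    # Attention state of query q over KV pairs (K, V):
    # the LSE of pre-softmax scores plus the softmax-weighted value sum.
    scores = K @ q
    lse = np.log(np.exp(scores).sum())
    return lse, np.exp(scores - lse) @ V

def merge(a, b):
    # The merge operator: combine states computed over disjoint KV subsets.
    (s_a, v_a), (s_b, v_b) = a, b
    s = np.logaddexp(s_a, s_b)                          # LSE over the union
    v = np.exp(s_a - s) * v_a + np.exp(s_b - s) * v_b
    return s, v

rng = np.random.default_rng(0)
q = rng.normal(size=8)
K, V = rng.normal(size=(4, 8)), rng.normal(size=(4, 8))

full = attention_state(q, K, V)              # all 4 KV pairs at once
left = attention_state(q, K[:2], V[:2])      # first half
right = attention_state(q, K[2:], V[2:])     # second half

merged = merge(left, right)
assert np.isclose(merged[0], full[0]) and np.allclose(merged[1], full[1])
assert np.allclose(merge(right, left)[1], merged[1])  # commutative
```

The merged LSE rescales each partial softmax so the combined weights sum to one over the union, which is why the result matches attention computed over all pairs at once.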
This animation shows how FlashInfer computes attention for a query over 4 KV pairs by partitioning the work and merging results.
When multiple sequences share a common prefix (e.g., the same prompt), compute the attention state over the shared prefix once, then merge it with the state over each sequence's unique suffix.
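A minimal NumPy sketch of this idea (an illustration, not FlashInfer's cascade kernels; all names here are hypothetical): each sequence's query attends to the shared prefix KV cache and to its own suffix separately, and merging the two states matches attention over the concatenated cache.

```python
import numpy as np

def attention_state(q, K, V):
    # (LSE of scores, softmax-weighted value sum) for one query over (K, V)
    scores = K @ q
    lse = np.log(np.exp(scores).sum())
    return lse, np.exp(scores - lse) @ V

def merge(a, b):
    # Merge states computed over disjoint KV subsets
    (s_a, v_a), (s_b, v_b) = a, b
    s = np.logaddexp(s_a, s_b)
    return s, np.exp(s_a - s) * v_a + np.exp(s_b - s) * v_b

rng = np.random.default_rng(0)
d = 8
K_pre, V_pre = rng.normal(size=(16, d)), rng.normal(size=(16, d))  # shared prompt KV
K_a, V_a = rng.normal(size=(3, d)), rng.normal(size=(3, d))        # suffix of sequence A
K_b, V_b = rng.normal(size=(5, d)), rng.normal(size=(5, d))        # suffix of sequence B

for K_suf, V_suf in [(K_a, V_a), (K_b, V_b)]:
    q = rng.normal(size=d)  # this sequence's query
    # Prefix pass reads the shared KV cache (in FlashInfer, one batched
    # kernel serves all queries); the suffix pass is per-sequence.
    cascaded = merge(attention_state(q, K_pre, V_pre),
                     attention_state(q, K_suf, V_suf))
    # Reference: attention over the concatenated prefix + suffix cache.
    full = attention_state(q, np.vstack([K_pre, K_suf]),
                           np.vstack([V_pre, V_suf]))
    assert np.isclose(cascaded[0], full[0]) and np.allclose(cascaded[1], full[1])
```

The practical win is that the shared prefix KV cache is stored and read once for all sequences, rather than duplicated per sequence.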
Up to 30x speedup in long-context scenarios
Partition long KV sequences across multiple processing units, compute partial attention states in parallel, then merge the results.
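This split-and-merge strategy can be sketched in NumPy (an illustration of the idea, not FlashInfer's implementation): partial states over chunks of a long KV sequence are reduced with the merge operator, and associativity means any reduction order, such as a parallel tree, gives the same result.

```python
import numpy as np
from functools import reduce

def attention_state(q, K, V):
    # (LSE of scores, softmax-weighted value sum) for one query over (K, V)
    scores = K @ q
    lse = np.log(np.exp(scores).sum())
    return lse, np.exp(scores - lse) @ V

def merge(a, b):
    # Merge states computed over disjoint KV subsets
    (s_a, v_a), (s_b, v_b) = a, b
    s = np.logaddexp(s_a, s_b)
    return s, np.exp(s_a - s) * v_a + np.exp(s_b - s) * v_b

rng = np.random.default_rng(0)
d, n, chunk = 8, 1024, 256
q = rng.normal(size=d)
K, V = rng.normal(size=(n, d)), rng.normal(size=(n, d))

# Each partial state below could be computed by a different SM or GPU.
parts = [attention_state(q, K[i:i + chunk], V[i:i + chunk])
         for i in range(0, n, chunk)]

seq = reduce(merge, parts)                       # left-to-right merge
tree = merge(merge(parts[0], parts[1]),
             merge(parts[2], parts[3]))          # tree-shaped merge
full = attention_state(q, K, V)                  # monolithic reference

assert np.isclose(seq[0], full[0]) and np.allclose(seq[1], full[1])
assert np.allclose(tree[1], seq[1])  # associativity: merge order is irrelevant
```

Because partial states are small (one scalar LSE plus one value vector per query), the final merge step is cheap relative to the chunked attention passes it combines.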
Improves GPU utilization for memory-constrained scenarios