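Standard attention computes three matrices (queries, keys, and values) and materializes the full N×N attention matrix in GPU high-bandwidth memory (HBM). For a sequence of length N, that is O(N²) memory reads and writes. On a GPU, HBM is large but slow. On-chip SRAM is fast but tiny. The FlashAttention paper gives the A100 as an example: 40–80 GB of HBM at roughly 1.5–2.0 TB/s, against about 192 KB of SRAM per streaming multiprocessor at an estimated 19 TB/s.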
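To make the quadratic cost concrete, here is a minimal NumPy sketch of standard attention alongside a size estimate for the intermediate matrix. The function names are illustrative, and the 2-byte element size assumes fp16 storage for a single head; on a GPU the two N×N intermediates would live in HBM.

```python
import numpy as np

def standard_attention(q, k, v):
    """Naive attention: builds the full (N, N) score matrix explicitly."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                  # (N, N) intermediate
    scores -= scores.max(axis=-1, keepdims=True)   # stabilize the softmax
    probs = np.exp(scores)                         # second (N, N) intermediate
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v                               # (N, d_v) output

def score_matrix_bytes(n, bytes_per_element=2):
    """Size of one N x N matrix for one head (fp16 assumed by default)."""
    return n * n * bytes_per_element

for n in (1024, 8192, 65536):
    print(f"N={n}: {score_matrix_bytes(n) / 2**20:.0f} MiB per head")
```

The output of attention is only N×d, so almost all of that traffic is spent on a matrix nobody asked for.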
The bottleneck is not arithmetic. GPUs have enormous compute throughput. The bottleneck is data movement between HBM and SRAM.
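A rough roofline-style check illustrates this. The sketch below assumes A100 40 GB figures (312 TFLOPs/s dense FP16 tensor-core peak, 1.555 TB/s HBM bandwidth) and a hypothetical cost of about 5 FLOPs per element for softmax. It is a back-of-the-envelope estimate, not a profile of any real kernel.

```python
# Assumed hardware figures (A100 40 GB); swap in your own GPU's numbers.
PEAK_FLOPS = 312e12        # dense FP16 tensor-core peak, FLOPs per second
HBM_BANDWIDTH = 1.555e12   # HBM bandwidth, bytes per second

def machine_balance():
    """FLOPs the GPU can do in the time it takes to move one byte from HBM."""
    return PEAK_FLOPS / HBM_BANDWIDTH

def softmax_intensity(bytes_per_element=2, flops_per_element=5):
    """FLOPs per byte for a softmax that reads and writes every score in HBM.

    flops_per_element is a rough, assumed cost (max, subtract, exp, sum, divide).
    """
    bytes_moved = 2 * bytes_per_element  # one read plus one write per element
    return flops_per_element / bytes_moved

def qk_matmul_intensity(n, d, bytes_per_element=2):
    """FLOPs per byte for S = Q @ K.T when S is written out to HBM."""
    flops = 2 * n * n * d
    bytes_moved = bytes_per_element * (2 * n * d + n * n)  # read Q, K; write S
    return flops / bytes_moved

print(f"machine balance:      {machine_balance():.0f} FLOPs/byte")
print(f"softmax:              {softmax_intensity():.2f} FLOPs/byte")
print(f"QK^T (N=4096, d=64):  {qk_matmul_intensity(4096, 64):.0f} FLOPs/byte")
```

Under these assumptions both steps fall well below the machine balance, so the GPU spends most of its time waiting on HBM rather than computing.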
Tri Dao et al. named this precisely: attention is IO-bound, not compute-bound.
Flash Attention (2022) does not approximate attention. It computes the exact same result. What it changes is the order of operations.
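Instead of materializing the full N×N attention matrix in HBM, Flash Attention tiles the computation into blocks that fit in SRAM. It loads a block of queries, keys, and values, computes on it, and keeps a running row maximum and a running softmax denominator so that earlier partial results can be rescaled as later blocks arrive. This is what keeps the tiled softmax exact. Only the final output, plus the per-row softmax statistics needed for the backward pass, is written to HBM. The intermediate attention matrix never is.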
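The tiling idea can be sketched on the CPU in NumPy. This is a simplified illustration, not the actual CUDA kernel: the block sizes and function names are assumptions, each slice stands in for a block loaded into SRAM, and the running maximum `m` and denominator `l` are the statistics that let partial results be combined exactly.

```python
import numpy as np

def reference_attention(q, k, v):
    """Standard attention, used here only to check the tiled version."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block_q=64, block_k=64):
    """Exact attention computed tile by tile with an online softmax."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[1]))
    for i in range(0, n, block_q):
        q_blk = q[i:i + block_q]                      # "load" one query block
        rows = q_blk.shape[0]
        m = np.full(rows, -np.inf)                    # running row maximum
        l = np.zeros(rows)                            # running softmax denominator
        acc = np.zeros((rows, v.shape[1]))            # unnormalized output
        for j in range(0, k.shape[0], block_k):
            s = (q_blk @ k[j:j + block_k].T) * scale  # one small tile of scores
            m_new = np.maximum(m, s.max(axis=-1))
            correction = np.exp(m - m_new)            # rescale earlier partial sums
            p = np.exp(s - m_new[:, None])
            l = l * correction + p.sum(axis=-1)
            acc = acc * correction[:, None] + p @ v[j:j + block_k]
            m = m_new
        out[i:i + block_q] = acc / l[:, None]         # single write of the final block
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((200, 32)) for _ in range(3))
print(np.allclose(tiled_attention(q, k, v), reference_attention(q, k, v)))
```

The largest score array ever held is `block_q × block_k`, regardless of N, and the result matches the reference up to floating-point rounding.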
The result: no approximation and no architectural change. Extra memory grows linearly in N rather than quadratically, and the speedup comes from fewer memory round-trips.
Flash Attention-2 (2023) kept the same IO insight and improved parallelism: it reduces non-matmul FLOPs and repartitions work across GPU thread blocks and warps. The reported result is 1.7–3.0× faster than Flash Attention-1 and 3–10× faster than a standard attention implementation on A100 GPUs.
Flash Attention-3 (2024) targeted H100 hardware specifically: it exploits asynchrony between compute and memory operations and adds FP8 support with 2.6× lower numerical error than a baseline FP8 attention. Reported peak throughput in BF16 is 840 TFLOPs/s, about 85% of the H100's theoretical maximum.
Each version is the same idea applied harder: stop treating memory traffic as an afterthought.
Every inference optimization stack — continuous batching, prefix caching, speculative decoding — assumes Flash Attention is already there. It is the floor, not a feature.
The broader lesson: when a system is slow, identify which resource is saturated before reaching for algorithmic complexity. In this case, the resource was HBM bandwidth. The fix was a better memory access pattern for the same computation, not a cheaper approximation of it.
If you build on GPUs, understand your memory hierarchy. The gap between theoretical compute and realized throughput is almost always a memory traffic problem.
1. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. https://arxiv.org/abs/2205.14135
2. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR 2024. https://arxiv.org/abs/2307.08691
3. Shah, J., Bikshandi, G., Thakkar, V., Ramani, P., Dao, T., & Zhang, Y. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. NeurIPS 2024. https://tridao.me/publications/flash3/flash3.pdf
4. Dao-AILab. (2024). flash-attention GitHub repository. https://github.com/Dao-AILab/flash-attention