Introduction
Flash Attention is a fast, memory-efficient implementation of exact attention for transformer models. Developed by Tri Dao and collaborators, it restructures the attention computation to minimize data movement between GPU high-bandwidth memory (HBM) and on-chip SRAM, achieving significant speedups without any approximation.
What Flash Attention Does
- Computes exact scaled dot-product attention 2-4x faster than PyTorch native attention
- Reduces memory usage from quadratic to linear in sequence length via tiling
- Supports causal masking, variable-length sequences, and multi-query/grouped-query attention
- Provides fused kernels for the forward and backward passes during training
- Enables training with much longer context windows on the same hardware
Architecture Overview
Flash Attention tiles the Q, K, V matrices into blocks that fit in GPU SRAM and computes attention incrementally using the online softmax trick. By never materializing the full N x N attention matrix in HBM, it reduces memory IO by an order of magnitude. Flash Attention 2 further optimizes parallelism across sequence length and attention heads, and Flash Attention 3 adds asynchronous pipelining on Hopper GPUs.
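To make the online softmax trick concrete, here is a minimal plain-PyTorch sketch, not the library's fused CUDA kernels: scores are computed one key/value block at a time while a running row max and normalizer keep the result exact, so the full N x N score matrix never exists at once. The function name and block size are illustrative.

```python
import torch

def blockwise_attention(q, k, v, block_size=128):
    """Illustrative online-softmax attention over K/V blocks.

    q, k, v: (seq_len, head_dim) tensors. Keeps a running row max m
    and softmax denominator l, rescaling the output accumulator as
    each new block arrives -- the rescaling Flash Attention performs
    in SRAM so the full score matrix never touches HBM.
    """
    scale = q.shape[-1] ** -0.5
    n = q.shape[0]
    out = torch.zeros_like(q)
    m = torch.full((n, 1), float("-inf"))   # running row max
    l = torch.zeros(n, 1)                    # running softmax denominator
    for start in range(0, k.shape[0], block_size):
        kb = k[start:start + block_size]     # current key block
        vb = v[start:start + block_size]     # current value block
        s = (q @ kb.T) * scale               # scores for this block only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)             # block softmax numerator
        correction = torch.exp(m - m_new)    # rescale previous accumulator
        l = l * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        m = m_new
    return out / l

# Matches naive attention up to floating-point rounding:
q, k, v = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * q.shape[-1] ** -0.5, dim=-1) @ v
assert torch.allclose(blockwise_attention(q, k, v), ref, atol=1e-5)
```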
Installation & Configuration
- Install via pip with `pip install flash-attn --no-build-isolation` (requires CUDA 11.6+ and a compatible GPU)
- Supported on NVIDIA Ampere (A100), Ada (RTX 4090), and Hopper (H100) architectures
- Drop-in replacement for `torch.nn.functional.scaled_dot_product_attention`
- Hugging Face Transformers integrates Flash Attention via `attn_implementation="flash_attention_2"`
- Build from source for custom CUDA architectures or development
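Once installed, the core kernel is exposed through functions such as `flash_attn_func`, which takes `(batch, seqlen, nheads, headdim)` tensors in FP16/BF16 on a CUDA device. A minimal sketch; the shapes are arbitrary examples:

```python
import torch
from flash_attn import flash_attn_func

# Flash Attention kernels require half precision on a supported GPU.
q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")

# Exact attention with causal masking; output has the same shape as q.
out = flash_attn_func(q, k, v, causal=True)
```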
Key Features
- Exact computation with no approximation or accuracy loss
- IO-aware tiling eliminates the quadratic memory bottleneck
- Fused backward pass kernels for efficient training
- Supports head dimensions up to 256 and FP16/BF16 datatypes
- Widely adopted as default attention in major LLM training frameworks
Comparison with Similar Tools
- PyTorch SDPA — built-in scaled dot-product attention; uses Flash Attention as one backend (see the sketch after this list)
- xFormers — Meta library with memory-efficient attention; Flash Attention often faster for standard cases
- FlashInfer — optimized for inference serving with PagedAttention; complementary to Flash Attention
- Triton kernels — custom attention in Triton language; more flexible but typically slower
- Ring Attention — distributes attention across devices for very long sequences; orthogonal optimization
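To make the overlap with PyTorch SDPA concrete, the following sketch pins SDPA to its Flash Attention backend and checks it against the reference math backend. It assumes PyTorch 2.3+ (where `torch.nn.attention.sdpa_kernel` is available) and a supported GPU:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seqlen, headdim) layout expected by PyTorch SDPA;
# the Flash backend needs FP16/BF16 tensors on a supported GPU.
q, k, v = (torch.randn(2, 8, 1024, 64, dtype=torch.float16, device="cuda")
           for _ in range(3))

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    fast = F.scaled_dot_product_attention(q, k, v, is_causal=True)

with sdpa_kernel(SDPBackend.MATH):
    ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Exact attention: results agree up to half-precision rounding.
print(torch.allclose(fast, ref, atol=1e-3))
```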
FAQ
Q: Does Flash Attention change model outputs? A: No. It computes exact attention. Outputs match standard attention up to floating-point rounding.
Q: Which GPUs are supported? A: NVIDIA Ampere (SM 80), Ada Lovelace (SM 89), and Hopper (SM 90) architectures. Older GPUs like V100 are not supported.
Q: Can I use it for inference only? A: Yes. Flash Attention speeds up both training and inference. Many serving frameworks use it by default.
Q: How do I enable it in Hugging Face Transformers? A: Pass `attn_implementation="flash_attention_2"` when loading a model with `AutoModelForCausalLM.from_pretrained()`, as in the sketch below.
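A minimal loading sketch; the model id is a placeholder, and the flash-attn package must be installed in the environment:

```python
import torch
from transformers import AutoModelForCausalLM

# Model id is a placeholder; any checkpoint whose architecture
# supports Flash Attention 2 in Transformers will work.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,               # kernels require FP16/BF16
    attn_implementation="flash_attention_2",
).to("cuda")
```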