FlashAttention-2 is a new method for computing attention in Transformers that is faster than existing methods, especially for longer sequences, while still computing exact attention rather than an approximation. It makes the training of GPT-style models up to 2x faster and the inference of attention-based models up to 3x faster, reaching a peak speed of 230 TFLOPS on an A100 GPU.
FlashAttention-2 is an improved version of FlashAttention, which came out in May 2022. It has better parallelism, more efficient work partitioning, and fewer non-matmul FLOPs (floating-point operations other than matrix multiplications).
FlashAttention-2 allows us to train larger and more powerful models in less time.

The problem
Language models such as GPT-4 (32k tokens), MPT (65k), and Claude (100k) have long context windows, which requires Transformers that can handle long input sequences efficiently.
However, the self-attention mechanism (which lets a Transformer focus on the most relevant parts of the input sequence) is the main bottleneck in handling longer sequences: the time and memory cost of standard self-attention grows quadratically with the sequence length.
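To make the quadratic growth concrete, here is a back-of-the-envelope calculation of the memory a standard implementation needs just to store the seq_len × seq_len score matrix. It is a minimal sketch that assumes FP16/BF16 storage (2 bytes per element) and, by default, one head and a batch size of 1; the function name is illustrative.

```python
def attention_matrix_bytes(seq_len: int, num_heads: int = 1, batch_size: int = 1,
                           bytes_per_element: int = 2) -> int:
    """Bytes needed to materialize the full seq_len x seq_len attention matrix.

    bytes_per_element=2 assumes FP16/BF16 storage.
    """
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


for n in (2_048, 8_192, 32_768, 100_000):
    gib = attention_matrix_bytes(n) / 2**30
    print(f"seq_len={n:>7}: {gib:8.2f} GiB per head and sequence")
```

At a 32k context the score matrix alone takes 2 GiB per head and per sequence, and doubling the sequence length quadruples that figure.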
Approximate methods such as sparse attention, low-rank attention, and kernel approximations reduce the compute complexity and are effective in some cases, but they trade off model quality and often do not deliver a real wall-clock speedup, so they do not fully solve the problem.
FlashAttention
FlashAttention computes exact attention faster and with less memory than previous implementations, which improves the scalability of Transformers to long sequences.
It uses tiling (splitting Q, K, and V into smaller blocks that are processed one at a time in fast on-chip memory) and recomputation (storing only small softmax statistics in the forward pass and recomputing the attention matrix on-chip during the backward pass instead of reading a stored copy) to reduce the IO complexity of attention. This yields a 2-4x speedup and lowers the memory usage from quadratic to linear in sequence length.
The IO complexity of attention is the number of memory reads and writes between the GPU's large but slow high-bandwidth memory (HBM) and its small but fast on-chip SRAM.
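For reference, the original FlashAttention paper bounds the number of HBM accesses as follows, where N is the sequence length, d the head dimension, and M the size of on-chip SRAM (assumed to satisfy d ≤ M ≤ Nd):

```latex
\underbrace{\Theta\!\left(Nd + N^2\right)}_{\text{standard attention}}
\qquad \text{vs.} \qquad
\underbrace{\Theta\!\left(\frac{N^2 d^2}{M}\right)}_{\text{FlashAttention}}
```

For typical head dimensions (64-128) and SRAM sizes (around 100 KB), d² is smaller than M, so FlashAttention needs far fewer HBM accesses than the N² term of standard attention.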
FlashAttention has also been extended to block-sparse attention, which can further speed up the attention computation (5-9 times).
It has been widely used by many organizations and research labs.
The figure below shows how FlashAttention works to make attention faster and use less memory.
The left part shows how FlashAttention breaks the data into smaller parts, called tiles, and calculates the attention for each tile separately.
- The data is represented by three matrices: Q (query), K (key), and V (value).
- FlashAttention does not compute the attention scores for the whole sequence at once, because that would require materializing the full attention matrix, which is slow and uses a lot of memory. Instead, it splits Q, K, and V into tiles and computes the attention for each tile separately.
- The data is stored in GPU memory, which has several levels: HBM (High-Bandwidth Memory), which is larger but slower, and SRAM (Static Random-Access Memory), which is smaller but faster. FlashAttention tries to use SRAM as much as possible and HBM as little as possible.
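The original FlashAttention paper (Algorithm 1) derives tile sizes from the SRAM size M, counted in elements, and the head dimension d: B_c = ⌈M/(4d)⌉ rows per K/V tile and B_r = min(⌈M/(4d)⌉, d) rows per Q tile. The sketch below only evaluates that rule; the function names are hypothetical, and real kernels typically use fixed, hand-tuned tile sizes instead. The example assumes the A100's 192 KB of SRAM per streaming multiprocessor, filled with FP16 values.

```python
import math
from typing import Tuple


def flashattention_block_sizes(sram_elements: int, head_dim: int) -> Tuple[int, int]:
    """Return (B_r, B_c): rows per Q tile and rows per K/V tile.

    Follows the block-size rule of Algorithm 1 in the FlashAttention paper;
    sram_elements is the SRAM size M in elements, not bytes.
    """
    b_c = math.ceil(sram_elements / (4 * head_dim))
    b_r = min(b_c, head_dim)
    return b_r, b_c


def num_tiles(seq_len: int, block: int) -> int:
    """Number of tiles of `block` rows needed to cover seq_len rows."""
    return math.ceil(seq_len / block)


# Assumption: 192 KB of SRAM per SM, holding 2-byte (FP16) values.
sram_elements = 192 * 1024 // 2
b_r, b_c = flashattention_block_sizes(sram_elements, head_dim=64)
print(f"B_r={b_r}, B_c={b_c}")
print("Q tiles:", num_tiles(4096, b_r), "K/V tiles:", num_tiles(4096, b_c))
```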
The right part shows how much faster FlashAttention is than the PyTorch attention implementation: it speeds up the attention computation of GPT-2 by 7.6x.

Below we see a diagram of the FlashAttention forward pass. It divides the attention computation into blocks, so the full attention matrix never has to be written to or read back from HBM, which is slow.
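The block-wise idea can be sketched in plain Python for a single head. This is an illustrative re-implementation, not the FlashAttention kernel: the tile sizes are arbitrary and Python lists stand in for SRAM tiles. Each query row keeps a running maximum (row_max), a running softmax denominator (row_sum), and an unnormalized output accumulator; when a new K/V tile arrives, the earlier partial results are rescaled by exp(old max - new max), so the result equals exact softmax attention without ever building the full N × N matrix. The sketch puts the Q tiles in the outer loop and divides by row_sum only once at the end, which are the choices FlashAttention-2 makes; the original FlashAttention puts the K/V loop on the outside, but both orders give the same result.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def naive_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Reference softmax(Q K^T / sqrt(d)) V that builds every score row in full."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        m = max(scores)
        p = [math.exp(s - m) for s in scores]
        total = sum(p)
        out.append([sum(p[j] * v[j][c] for j in range(len(v))) / total
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q: Matrix, k: Matrix, v: Matrix,
                    block_q: int = 2, block_kv: int = 3) -> Matrix:
    """Exact attention computed one (block_q x block_kv) tile of scores at a time."""
    n, dv = len(q), len(v[0])
    scale = 1.0 / math.sqrt(len(q[0]))
    out: Matrix = [[0.0] * dv for _ in range(n)]
    for q_start in range(0, n, block_q):                  # outer loop: Q tiles
        rows = range(q_start, min(q_start + block_q, n))
        row_max = {i: -math.inf for i in rows}             # running max of scores
        row_sum = {i: 0.0 for i in rows}                   # running softmax denominator
        acc = {i: [0.0] * dv for i in rows}                # unnormalized output
        for kv_start in range(0, len(k), block_kv):        # inner loop: K/V tiles
            cols = range(kv_start, min(kv_start + block_kv, len(k)))
            for i in rows:
                s = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in cols]
                m_new = max(row_max[i], max(s))
                corr = math.exp(row_max[i] - m_new)        # rescales earlier tiles
                p = [math.exp(x - m_new) for x in s]
                row_sum[i] = row_sum[i] * corr + sum(p)
                acc[i] = [a * corr + sum(pj * v[j][c] for pj, j in zip(p, cols))
                          for c, a in enumerate(acc[i])]
                row_max[i] = m_new
        for i in rows:
            out[i] = [a / row_sum[i] for a in acc[i]]      # normalize once per row
    return out


def random_matrix(rows: int, cols: int) -> Matrix:
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
N, D = 7, 4
Q, K, V = random_matrix(N, D), random_matrix(N, D), random_matrix(N, D)
ref = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V)
max_err = max(abs(a - b) for r1, r2 in zip(ref, tiled) for a, b in zip(r1, r2))
print(f"max abs difference vs. naive attention: {max_err:.2e}")
```

The difference is at the level of floating-point rounding, which is the point: tiling changes where the computation happens, not its result.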

Some of the inefficiencies of FlashAttention are:
- It reaches only 25-40% of the A100 GPU's theoretical maximum throughput (up to 124 TFLOPS).
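- It parallelizes only over the batch size and the number of heads (one thread block per head). For long sequences, which usually come with small batch sizes or few heads, this yields too few thread blocks to keep all of the GPU's streaming multiprocessors busy, limiting occupancy and efficiency.
- It partitions the work suboptimally between the warps (groups of 32 threads that execute the same instruction in parallel) within each thread block, which causes unnecessary shared memory reads and writes.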
FlashAttention-2
The improved version, FlashAttention-2, fixes some of these inefficiencies by using:
Better parallelism: in addition to the batch size and number of heads, FlashAttention-2 parallelizes the attention computation over the sequence length dimension, so even a single head is spread across several thread blocks on the GPU, increasing occupancy and utilization.
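A quick way to see why this matters is to count thread blocks. Following the description in the FlashAttention-2 paper, the original FlashAttention launches one thread block per (batch, head) pair, while FlashAttention-2 also splits each sequence into blocks of query rows. The sketch compares both counts with the 108 streaming multiprocessors (SMs) of an A100; the 128-row block size and the function names are illustrative assumptions.

```python
import math

A100_NUM_SMS = 108  # streaming multiprocessors on an A100 GPU


def thread_blocks_fa1(batch: int, heads: int) -> int:
    """Original FlashAttention: one thread block per (batch, head) pair."""
    return batch * heads


def thread_blocks_fa2(batch: int, heads: int, seq_len: int, block_rows: int = 128) -> int:
    """FlashAttention-2: additionally one thread block per block of query rows.

    block_rows=128 is an illustrative assumption, not a fixed kernel parameter.
    """
    return batch * heads * math.ceil(seq_len / block_rows)


for batch, heads, seq_len in [(16, 16, 2_048), (1, 16, 16_384)]:
    fa1 = thread_blocks_fa1(batch, heads)
    fa2 = thread_blocks_fa2(batch, heads, seq_len)
    # An SM can only have work if at least one thread block is assigned to it.
    print(f"batch={batch:>2}, heads={heads}, seq_len={seq_len:>6}: "
          f"FA1 {fa1:>5} blocks ({min(fa1, A100_NUM_SMS)} SMs with work), "
          f"FA2 {fa2:>5} blocks ({min(fa2, A100_NUM_SMS)} SMs with work)")
```

With a long sequence and a batch size of 1, the original scheme yields only 16 thread blocks, leaving most of the 108 SMs idle, while splitting along the sequence gives 2048.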
Better work partitioning: FlashAttention-2 uses a new partitioning scheme that distributes the work between warps (a group of 32 threads working together) within each thread block to reduce communication through shared memory.
Together with the other changes, this enables FlashAttention-2 to reach 50-73% of the A100 GPU's theoretical peak throughput, a big improvement over FlashAttention, which reached only 25-40%.
In the picture below we can see the work partitioning between different warps in the forward pass for FlashAttention and FlashAttention-2.
- For each block, FlashAttention uses the “sliced-K” scheme. It splits K and V across 4 warps while keeping Q accessible by all warps.
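- FlashAttention-2 divides Q across 4 warps while keeping K and V accessible by all warps. With sliced-K, every warp has to write its intermediate results to shared memory, synchronize, and add them up; with the FlashAttention-2 split, each warp computes its own slice of the output without communicating with the others. This reduces synchronization and communication between warps and results in fewer shared memory reads and writes.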
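The difference between the two schemes can be imitated with a toy Python model in which a "warp" is just a loop index. This is a conceptual sketch, not GPU code: with sliced-K, each warp produces a partial softmax state for its slice of K and V, and these partial states must be exchanged (through shared memory, on a real GPU) and merged; with split-Q, each warp owns complete query rows and finishes them on its own. The exchange counter is a hypothetical proxy for shared-memory traffic, and all names are illustrative.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]
State = Tuple[float, float, List[float]]  # (row max, row sum, unnormalized output)


def partial_state(q_row: List[float], k_rows: Matrix, v_rows: Matrix) -> State:
    """Softmax statistics and unnormalized output of one query over a K/V slice."""
    scale = 1.0 / math.sqrt(len(q_row))
    s = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_rows]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    acc = [sum(pj * v_row[c] for pj, v_row in zip(p, v_rows)) for c in range(len(v_rows[0]))]
    return m, sum(p), acc


def merge(states: List[State]) -> List[float]:
    """Cross-warp reduction: rescale every partial state to a common max, then add."""
    m = max(st[0] for st in states)
    weights = [math.exp(st[0] - m) for st in states]
    total = sum(w * st[1] for w, st in zip(weights, states))
    dv = len(states[0][2])
    return [sum(w * st[2][c] for w, st in zip(weights, states)) / total for c in range(dv)]


def sliced_k(q: Matrix, k: Matrix, v: Matrix, num_warps: int = 4) -> Tuple[Matrix, int]:
    """Each warp sees all of Q but only its slice of K and V."""
    chunk = math.ceil(len(k) / num_warps)
    out: Matrix = []
    exchanges = 0
    for q_row in q:
        states = [partial_state(q_row, k[w * chunk:(w + 1) * chunk],
                                v[w * chunk:(w + 1) * chunk])
                  for w in range(num_warps) if w * chunk < len(k)]
        exchanges += len(states)  # every partial state is written out and read back
        out.append(merge(states))
    return out, exchanges


def split_q(q: Matrix, k: Matrix, v: Matrix, num_warps: int = 4) -> Tuple[Matrix, int]:
    """Each warp owns whole query rows and sees all of K and V."""
    chunk = math.ceil(len(q) / num_warps)
    out: Matrix = []
    for w in range(num_warps):
        for q_row in q[w * chunk:(w + 1) * chunk]:
            _, row_sum, acc = partial_state(q_row, k, v)
            out.append([a / row_sum for a in acc])  # finished without other warps
    return out, 0


random.seed(1)
N, D = 8, 4
Q = [[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(N)]
K = [[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(N)]
V = [[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(N)]
out_k, exchanges_k = sliced_k(Q, K, V)
out_q, exchanges_q = split_q(Q, K, V)
diff = max(abs(a - b) for r1, r2 in zip(out_k, out_q) for a, b in zip(r1, r2))
print(f"max difference: {diff:.2e}, exchanges: sliced-K={exchanges_k}, split-Q={exchanges_q}")
```

Both partitionings produce the same output; only sliced-K needs the merge step.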

Fewer non-matmul FLOPs: FlashAttention-2 performs fewer operations that are not matmuls (matrix multiplications). On an A100 GPU, the theoretical peak throughput for FP16/BF16 matmul is 312 TFLOPS, while for non-matmul operations it is only 19.5 TFLOPS, so each non-matmul FLOP can be up to 16 times more expensive than a matmul FLOP.
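To get there, FlashAttention-2 tweaks the algorithm to avoid unnecessary computations: for example, it keeps the output un-scaled during the loop and divides by the softmax normalization factor only once at the end, and it stores only the logsumexp of each row for the backward pass instead of both the row maximum and the row sum.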
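The effect of the 16x throughput gap can be estimated with a simple lower-bound model in which each kind of FLOP runs at its own peak rate. The sketch uses the A100 figures quoted above; the FLOP counts in the example are made up for illustration and do not come from a real kernel.

```python
from typing import Tuple

A100_MATMUL_TFLOPS = 312.0      # FP16/BF16 tensor-core matmul peak
A100_NON_MATMUL_TFLOPS = 19.5   # FP32 peak for non-matmul operations


def ideal_time_us(matmul_flops: float, non_matmul_flops: float) -> Tuple[float, float]:
    """Lower-bound time (microseconds) spent on matmul and on non-matmul FLOPs."""
    matmul_us = matmul_flops / (A100_MATMUL_TFLOPS * 1e12) * 1e6
    non_matmul_us = non_matmul_flops / (A100_NON_MATMUL_TFLOPS * 1e12) * 1e6
    return matmul_us, non_matmul_us


print("cost ratio per FLOP:", A100_MATMUL_TFLOPS / A100_NON_MATMUL_TFLOPS)
# Hypothetical workload: non-matmul FLOPs are only 10% of the matmul FLOPs.
mm_us, other_us = ideal_time_us(matmul_flops=1e9, non_matmul_flops=1e8)
print(f"matmul: {mm_us:.2f} us, non-matmul: {other_us:.2f} us")
```

Even with ten times fewer FLOPs, the non-matmul part takes longer in this model (about 5.1 us vs. 3.2 us), which is why trimming it pays off.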
Empirical validation
The empirical results show that FlashAttention-2 is significantly faster than FlashAttention. The picture below shows the speed of different attention methods on an A100 80GB SXM4 GPU for different settings.
We can see that FlashAttention-2 is about 2x faster than FlashAttention, measured in FLOPS (floating-point operations per second, computed from the FLOP count of attention and the measured runtime). Compared to a standard attention implementation in PyTorch, FlashAttention-2 can be up to 9x faster.
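Such FLOPS figures come from dividing an analytic FLOP count by the measured runtime. The FlashAttention-2 paper counts 4 · seq_len² · head_dim · num_heads FLOPs for the forward pass (two matmuls, QKᵀ and PV), halves that with a causal mask, and counts the backward pass as 2.5 times the forward pass. The sketch applies that convention; the function names and the example runtime are hypothetical.

```python
def attention_fwd_flops(batch: int, seq_len: int, num_heads: int, head_dim: int,
                        causal: bool = False) -> float:
    """Forward-pass FLOPs: QK^T and PV each cost 2 * seq_len^2 * head_dim per head."""
    flops = 4.0 * batch * seq_len ** 2 * head_dim * num_heads
    return flops / 2 if causal else flops  # a causal mask skips about half the scores


def attention_fwd_bwd_flops(batch: int, seq_len: int, num_heads: int, head_dim: int,
                            causal: bool = False) -> float:
    """Forward plus backward, with the backward pass counted as 2.5x the forward."""
    return 3.5 * attention_fwd_flops(batch, seq_len, num_heads, head_dim, causal)


def tflops_per_s(flops: float, runtime_s: float) -> float:
    """Achieved throughput in TFLOPS for a measured runtime in seconds."""
    return flops / runtime_s / 1e12


flops = attention_fwd_flops(batch=8, seq_len=4096, num_heads=16, head_dim=128)
# Hypothetical measured runtime of 10 ms for this forward pass.
print(f"{flops:.3e} FLOPs -> {tflops_per_s(flops, 0.010):.1f} TFLOPS")
```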
However, these speedups may vary depending on the hardware configuration, data type, and sequence length.

FlashAttention-2 reaches up to 230 TFLOPS on A100 GPUs and up to 335 TFLOPS on H100 GPUs for the FP16/BF16 data types. These are measured speeds, not theoretical peaks: the theoretical FP16/BF16 matmul peak of an A100 is 312 TFLOPS, so 230 TFLOPS corresponds to about 73% of the maximum.
The end-to-end training speed of GPT-style models using FlashAttention-2 reaches up to 225 TFLOPS per A100 GPU, with 72% model FLOPs utilization for FP16/BF16 data type.
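The utilization percentages in this article are achieved throughput divided by the A100's 312 TFLOPS FP16/BF16 matmul peak, and a few lines of Python reproduce them. The labels are illustrative; the TFLOPS values are the ones quoted above.

```python
A100_PEAK_TFLOPS = 312.0  # theoretical FP16/BF16 matmul peak of an A100


def utilization(achieved_tflops: float, peak_tflops: float = A100_PEAK_TFLOPS) -> float:
    """Fraction of the theoretical peak actually reached."""
    return achieved_tflops / peak_tflops


for label, tflops in [("FlashAttention (best case)", 124.0),
                      ("FlashAttention-2 attention kernel", 230.0),
                      ("GPT-style training with FlashAttention-2", 225.0)]:
    print(f"{label}: {tflops:.0f} TFLOPS = {utilization(tflops):.1%} of peak")
```

This gives roughly 40%, 74%, and 72%, consistent with the figures above up to rounding.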

Conclusion
FlashAttention-2 is a new algorithm that accelerates attention and reduces its memory usage.
It uses tiling to load blocks of the inputs from HBM into on-chip SRAM and compute attention on one block at a time. It also uses better parallelism and work partitioning to make the computation more efficient.
With FlashAttention-2, we can double the speed of FlashAttention and train models with a 16k context length at the same cost as training an 8k-context model before.
Learn more:
- Research paper: “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning” (on arXiv)
- Research paper: “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” (on arXiv)
- Project page