FlashAttention: A Novel IO-Aware Attention Algorithm That Is Fast and Memory-Efficient
At the heart of the Transformer is the self-attention mechanism, whose time and memory complexity are both $O(N^2)$ in the sequence length $N$. As large language models (LLMs) keep growing, giving them longer context windows poses a major engineering challenge.
Researchers from the Department of Computer Science at Stanford University and the University at Buffalo (SUNY) have proposed a new attention algorithm, FlashAttention, which runs 2-4x faster than standard attention in PyTorch while using 5-20x less memory. Follow-up work, including FlashAttention-2, Flash-Decoding, and FlashDecoding++, delivers even larger speedups.
High-Flyer's fundamental R&D team has fully adopted FlashAttention in its in-house large-model training framework HAI-LLM, substantially improving GPU utilization and training performance. In this series of articles, we discuss the techniques behind FlashAttention and our hands-on experience with it.
Paper: https://arxiv.org/abs/2205.14135
Source code: https://github.com/Dao-AILab/flash-attention
Background
Standard attention requires $O(N^2)$ memory in the sequence length. Past efforts to optimize attention have relied on approximations such as sparse approximation, low-rank approximation, and combinations of the two. Although these methods can reduce the computation to linear or near-linear in $N$ ($O(N)$), they focus on reducing the number of floating-point operations (FLOPs) and tend to ignore the overhead of memory access (IO), so they often fail to deliver real wall-clock speedups.
For years, GPU compute throughput (FLOPS) has grown faster than memory bandwidth (TB/s). In our own work optimizing model training on A100 and comparable GPUs, we have found that memory bandwidth is the bottleneck limiting further gains in training efficiency: compute and memory bandwidth must be well matched to fully improve efficiency, and this demands more careful design at the software level.
As shown in the figure below:
The figure above shows the bandwidth and capacity of the different levels of the CPU and GPU memory hierarchy. Memory is not a single component but a hierarchy, and the general rule is that the faster the memory, the more expensive it is and the smaller its capacity.
Take the A100 as an example: it has 40-80 GB of High Bandwidth Memory (HBM) with a bandwidth of 1.5-2.0 TB/s, while each of its 108 streaming multiprocessors (SMs) has 192 KB of on-chip SRAM with an estimated bandwidth of about 19 TB/s. SRAM is far smaller than HBM but roughly 10x faster, so using SRAM efficiently is the key to speeding up attention.
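To make these numbers concrete, the short Python sketch below totals the on-chip SRAM and compares it with the size of a single N×N score matrix. The assumptions (fp16 scores at 2 bytes per element, one head, batch size 1) are ours for illustration, not the article's.

```python
# Back-of-the-envelope numbers for a single A100, using the figures quoted above.
# Assumptions (not from the article): fp16 scores at 2 bytes per element,
# one attention head, batch size 1, and 1 KB = 1024 bytes.

NUM_SMS = 108                    # streaming multiprocessors on an A100
SRAM_PER_SM_BYTES = 192 * 1024   # on-chip SRAM per SM
HBM_TBPS = (1.5, 2.0)            # quoted HBM bandwidth range
SRAM_TBPS = 19.0                 # estimated SRAM bandwidth


def total_sram_bytes() -> int:
    """On-chip SRAM summed over all SMs (each SM only sees its own share)."""
    return NUM_SMS * SRAM_PER_SM_BYTES


def score_matrix_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    """Size of one N x N score matrix S (or P) for a single head."""
    return seq_len * seq_len * bytes_per_elem


if __name__ == "__main__":
    sram = total_sram_bytes()
    print(f"total SRAM: {sram / 2**20:.2f} MiB")
    low, high = (SRAM_TBPS / bw for bw in reversed(HBM_TBPS))
    print(f"SRAM vs HBM bandwidth: {low:.1f}x to {high:.1f}x")
    for n in (1024, 2048, 4096, 8192):
        s = score_matrix_bytes(n)
        print(f"N={n}: S takes {s / 2**20:.0f} MiB; larger than all SRAM combined: {s > sram}")
```

Under these assumptions all 108 SMs together hold about 20 MiB of SRAM, so at N = 4096 a single head's score matrix (32 MiB) already cannot stay on chip, and in practice each SM can only use its own 192 KB.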
Standard Attention Algorithm
Let's first look at the computational logic behind the standard attention algorithm:
In other words, the standard attention algorithm effectively treats HBM loads and stores as free: it is not IO-aware.
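A minimal pure-Python sketch of this standard three-step procedure (compute S = QKᵀ and write it out, read it back to compute P = softmax(S) and write P out, then read P and V to compute O = PV) is shown below. It tallies the elements that would move to and from HBM so the quadratic traffic on S and P becomes visible. The tally, the helper names, and the omission of the usual 1/√d scaling, masking, and dropout are simplifications for illustration only.

```python
import math

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive matrix product a @ b."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def row_softmax(s: Matrix) -> Matrix:
    """Numerically stable softmax of each row."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        total = sum(e)
        out.append([x / total for x in e])
    return out


def standard_attention(q: Matrix, k: Matrix, v: Matrix):
    """Return O and a tally of HBM elements moved, following the three steps."""
    n, d = len(q), len(q[0])
    traffic = {"scores": 0, "qkvo": 0}  # N x N intermediates vs N x d tensors
    # Step 1: load Q and K, compute S = Q K^T, write S (N x N) to HBM.
    traffic["qkvo"] += 2 * n * d
    s = matmul(q, transpose(k))
    traffic["scores"] += n * n
    # Step 2: read S back, compute P = softmax(S), write P (N x N) to HBM.
    p = row_softmax(s)
    traffic["scores"] += 2 * n * n
    # Step 3: load P and V, compute O = P V, write O to HBM.
    o = matmul(p, v)
    traffic["scores"] += n * n
    traffic["qkvo"] += 2 * n * d
    return o, traffic
```

The `scores` entry grows as 4N² (S written, S read, P written, P read), while the `qkvo` entry grows only as 4Nd, so for long sequences the intermediate matrices dominate HBM traffic.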
The figure below breaks down the runtime of a single attention operation in GPT-2:
Masking, softmax, and dropout account for most of the time, while matrix multiplication (Matmul), the part that actually exercises the GPU's FLOPS, takes only a small fraction. This motivates FlashAttention, a hardware-IO-aware algorithm designed to drastically reduce redundant HBM traffic and use SRAM to accelerate the computation.
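One way to see why the element-wise steps dominate is to compare arithmetic intensity, i.e. FLOPs performed per byte of HBM traffic. The sketch below does this for the QKᵀ product and for softmax in standard attention. It assumes fp16 storage, roughly 5 FLOPs per softmax element, and an A100 peak of 312 TFLOPS (dense fp16 Tensor Core) against 1.5 TB/s HBM; all of these are our assumptions, and real kernels differ in detail.

```python
# Rough arithmetic intensity (FLOPs per byte of HBM traffic) of two steps of
# standard attention. Assumptions (not from the article): fp16 (2 bytes per
# element), and about 5 FLOPs per element for softmax (max, subtract, exp,
# sum, divide); real kernels differ in the details.

BYTES_PER_ELEM = 2


def qk_matmul_intensity(n: int, d: int) -> float:
    """S = Q K^T: N*N outputs, each a length-d dot product (2d FLOPs)."""
    flops = 2 * n * n * d
    bytes_moved = (2 * n * d + n * n) * BYTES_PER_ELEM  # read Q, K; write S
    return flops / bytes_moved


def softmax_intensity(n: int, flops_per_elem: int = 5) -> float:
    """P = softmax(S) row-wise: touches every element of S once."""
    flops = flops_per_elem * n * n
    bytes_moved = 2 * n * n * BYTES_PER_ELEM  # read S, write P
    return flops / bytes_moved


if __name__ == "__main__":
    # Assumed A100 peak: 312 TFLOPS dense fp16 Tensor Core, 1.5 TB/s HBM.
    ridge = 312e12 / 1.5e12
    print(f"ridge point: {ridge:.0f} FLOPs/byte")
    print(f"QK^T (N=2048, d=64): {qk_matmul_intensity(2048, 64):.1f} FLOPs/byte")
    print(f"softmax (N=2048):    {softmax_intensity(2048):.2f} FLOPs/byte")
```

Under these assumptions softmax performs only about 1.25 FLOPs per byte moved, roughly 50x less than the QKᵀ product at N = 2048, d = 64, and far below the ~208 FLOPs/byte ridge point, so the GPU spends that step waiting on HBM rather than computing.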
FlashAttention
The idea behind FlashAttention is as follows. The standard algorithm writes S back to HBM only so that it can be read again to compute softmax. FlashAttention instead keeps the intermediate results in SRAM and writes only the final result back to HBM once all the intermediate steps are done, as shown in the figure below:
FlashAttention fuses multiple operations into one: it loads the inputs from HBM once, performs the fused computation, and writes the result back to HBM. The fusion relies on two main techniques:
- Tiling: split the matrices into blocks and compute the softmax reduction incrementally, without access to the whole input; used in both the forward and backward passes;
- Recomputation: trade compute for memory by recomputing the intermediate attention matrices instead of storing them; used only in the backward pass.
The complete pseudo-code is below:
1. Tiling (block-wise computation)
Because SRAM capacity is limited, the $O(N^2)$ memory footprint of the score matrix would cap the sequence length $N$, so the computation has to be split into blocks. Tiling the matrix multiplications and the pointwise operations (scaling, masking, dropout) is relatively easy; the main obstacle is softmax, which couples all the columns of a row of scores, i.e. every key in $K$. To get around this, the researchers introduce two extra statistics, $m(x)$ and $l(x)$, which decouple the blocks and make block-wise computation possible. The details are as follows:
For a vector $x \in \mathbb{R}^B$, define:

$$m(x) := \max_i x_i, \qquad f(x) := \left[e^{x_1 - m(x)}, \ldots, e^{x_B - m(x)}\right], \qquad l(x) := \sum_i f(x)_i, \qquad \mathrm{softmax}(x) := \frac{f(x)}{l(x)}$$

For two vectors $x^{(1)}, x^{(2)} \in \mathbb{R}^B$, the softmax of the concatenated vector $x = [x^{(1)}, x^{(2)}] \in \mathbb{R}^{2B}$ can then be computed block by block:

$$m(x) = m([x^{(1)}, x^{(2)}]) = \max\left(m(x^{(1)}), m(x^{(2)})\right), \qquad f(x) = \left[e^{m(x^{(1)}) - m(x)} f(x^{(1)}),\; e^{m(x^{(2)}) - m(x)} f(x^{(2)})\right]$$

$$l(x) = l([x^{(1)}, x^{(2)}]) = e^{m(x^{(1)}) - m(x)}\, l(x^{(1)}) + e^{m(x^{(2)}) - m(x)}\, l(x^{(2)}), \qquad \mathrm{softmax}(x) = \frac{f(x)}{l(x)}$$
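The decomposition can be checked numerically. The plain-Python sketch below (helper names are ours, chosen for illustration) computes $m$, $f$, $l$ for each block, merges them with the rescaling factors $e^{m(x^{(i)}) - m(x)}$, and lets you compare the result with a direct softmax of the concatenated vector.

```python
import math


def block_stats(x: list[float]) -> tuple[float, list[float], float]:
    """(m(x), f(x), l(x)) for one block, as defined above."""
    m = max(x)
    f = [math.exp(xi - m) for xi in x]
    return m, f, sum(f)


def merge_stats(s1, s2):
    """Combine two blocks' statistics into those of the concatenated vector."""
    m1, f1, l1 = s1
    m2, f2, l2 = s2
    m = max(m1, m2)
    a1, a2 = math.exp(m1 - m), math.exp(m2 - m)  # rescaling factors e^{m(x^(i)) - m(x)}
    f = [a1 * val for val in f1] + [a2 * val for val in f2]
    l = a1 * l1 + a2 * l2
    return m, f, l


def softmax_from_stats(stats) -> list[float]:
    m, f, l = stats
    return [val / l for val in f]


if __name__ == "__main__":
    x1, x2 = [1.0, 3.0, -2.0], [5.0, 0.5]
    merged = softmax_from_stats(merge_stats(block_stats(x1), block_stats(x2)))
    direct = softmax_from_stats(block_stats(x1 + x2))
    # The two agree up to floating-point rounding.
    print(max(abs(a - b) for a, b in zip(merged, direct)))
```

Because every exponent is taken relative to a maximum, the same code stays finite for inputs such as `[1000.0]` and `[999.0]`, where a naive `math.exp(1000.0)` would raise `OverflowError`.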
Note that the merge rule above can combine block statistics in any order, so the GPU can process many blocks in parallel across its threads rather than strictly one after another, making full use of the hardware.
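Putting tiling and the merge rule together gives a fused forward pass that never stores an N×N matrix. The pure-Python sketch below is an illustration under simplifying assumptions, not the kernel from the paper: "SRAM" exists only as comments, the loop order is queries outer and K/V blocks inner (the paper's Algorithm 1 puts the K/V loop outside), each output row is normalized once at the end instead of after every block, and scaling, masking, and dropout are omitted. The function name and block-size parameters are hypothetical.

```python
import math

Matrix = list[list[float]]


def flash_attention_forward(q: Matrix, k: Matrix, v: Matrix, block_q: int = 2, block_k: int = 2):
    """Tiled attention forward pass with an online (running) softmax.

    Returns O plus the per-row statistics m and l, which is all the
    backward pass needs to rebuild P later. No N x N matrix is ever stored.
    """
    n, d = len(q), len(v[0])
    out: Matrix = [[0.0] * d for _ in range(n)]
    m_saved, l_saved = [0.0] * n, [0.0] * n
    for i0 in range(0, n, block_q):                  # one query block Q_i ("in SRAM")
        rows = range(i0, min(i0 + block_q, n))
        m = {r: -math.inf for r in rows}             # running row max
        l = {r: 0.0 for r in rows}                   # running row sum of exp
        acc = {r: [0.0] * d for r in rows}           # unnormalized output rows
        for j0 in range(0, len(k), block_k):         # stream K_j, V_j blocks
            cols = range(j0, min(j0 + block_k, len(k)))
            for r in rows:
                s = [sum(a * b for a, b in zip(q[r], k[c])) for c in cols]  # one row of S_ij
                m_new = max(m[r], max(s))
                scale = math.exp(m[r] - m_new)       # rescale earlier partial results
                p = [math.exp(x - m_new) for x in s]
                l[r] = scale * l[r] + sum(p)
                acc[r] = [scale * a + sum(pc * v[c][t] for pc, c in zip(p, cols))
                          for t, a in enumerate(acc[r])]
                m[r] = m_new
        for r in rows:                               # normalize once, write O_i back
            out[r] = [a / l[r] for a in acc[r]]
            m_saved[r], l_saved[r] = m[r], l[r]
    return out, m_saved, l_saved
```

The result matches ordinary softmax attention for any block sizes; only the order in which the partial sums are accumulated changes, which is exactly what the merge rule allows.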
2. Recomputation
To avoid extra HBM reads and writes, FlashAttention does not keep the large intermediate matrices for the backward pass.
In a standard attention implementation, computing the gradients of Q, K, and V in the backward pass requires the N×N intermediate matrices S and P, which are stored in HBM during the forward pass. FlashAttention does not store them. Instead, it saves only the two statistics $m(x)$ and $l(x)$ and, in the backward pass, recomputes S and P block by block in fast on-chip SRAM. Although this adds FLOPs, it is much faster than the standard approach because it avoids reading the large matrices back from HBM.
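As an illustration of recomputation, the sketch below rebuilds tiles of P from Q, K and the saved row statistics, and uses each tile immediately to accumulate one gradient, dV = PᵀdO. It is a simplified example under our own assumptions: `row_stats` stands in for the m and l that the forward pass would have saved, the function names are hypothetical, and dQ and dK (which follow the same tile-by-tile pattern but need more terms) are omitted.

```python
import math

Matrix = list[list[float]]


def row_stats(q: Matrix, k: Matrix):
    """Stand-in for the m and l that the forward pass saves (one value per row)."""
    m, l = [], []
    for qr in q:
        s = [sum(a * b for a, b in zip(qr, kr)) for kr in k]
        m.append(max(s))
        l.append(sum(math.exp(x - m[-1]) for x in s))
    return m, l


def recompute_p_tile(q, k, m, l, rows, cols):
    """Rebuild one tile of P = softmax(Q K^T) from Q, K and the saved statistics."""
    return [[math.exp(sum(a * b for a, b in zip(q[r], k[c])) - m[r]) / l[r] for c in cols]
            for r in rows]


def grad_v(q: Matrix, k: Matrix, m, l, d_out: Matrix, block: int = 2) -> Matrix:
    """dV = P^T dO, accumulated tile by tile; the full P is never materialized."""
    n_q, n_k, d = len(q), len(k), len(d_out[0])
    dv = [[0.0] * d for _ in range(n_k)]
    for i0 in range(0, n_q, block):
        rows = range(i0, min(i0 + block, n_q))
        for j0 in range(0, n_k, block):
            cols = range(j0, min(j0 + block, n_k))
            p = recompute_p_tile(q, k, m, l, rows, cols)   # recomputed, not loaded
            for pi, r in enumerate(rows):
                for pj, c in enumerate(cols):
                    for t in range(d):
                        dv[c][t] += p[pi][pj] * d_out[r][t]
    return dv
```

Only O(N) numbers (m and l) survive between the passes; each P tile exists briefly and is discarded after use.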
Experiments
Compared with standard attention, FlashAttention substantially reduces HBM IO and therefore runtime, even though its GFLOPs are higher because the backward pass recomputes the attention matrix, as shown on the left of the figure below:
The right side of the figure shows that as the block size grows, the number of HBM accesses drops and so does the runtime. Beyond a block size of 256, however, HBM accesses keep falling but runtime no longer improves, because performance is now limited by other factors, such as arithmetic throughput. Note also that a larger block size can make the memory needed by the fused operation exceed the available SRAM.
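The block-size trend can be reproduced qualitatively with a simple counting model. The sketch below is our own simplification, not the paper's exact accounting: it follows the loop order of the paper's Algorithm 1 (outer loop over K/V blocks, each outer iteration re-reading all of Q and reading and writing all of O, l, and m), uses the same block size for rows and columns, and estimates the tile footprint as if every tile sat in SRAM at once. Real kernels keep parts of these tiles in registers and use different row and column block sizes, so the absolute numbers should not be read as the thresholds in the figure.

```python
import math


def hbm_elements(n: int, d: int, block: int) -> int:
    """Approximate number of elements moved between HBM and SRAM (simplified model)."""
    t_c = math.ceil(n / block)             # number of K/V blocks = outer iterations
    kv_once = 2 * n * d                    # every K_j and V_j is loaded exactly once
    per_outer = n * d + 2 * n * d + 4 * n  # read Q; read+write O; read+write l and m
    return kv_once + t_c * per_outer


def sram_tile_bytes(d: int, block: int, bytes_per_elem: int = 2) -> int:
    """Rough on-chip footprint: Q_i, O_i, K_j, V_j tiles plus one S_ij tile (fp16 assumed)."""
    return (4 * block * d + block * block) * bytes_per_elem


if __name__ == "__main__":
    n, d = 4096, 64
    for block in (64, 128, 256, 512):
        print(f"block={block:4d}  HBM elements={hbm_elements(n, d, block):>11,}"
              f"  tile footprint={sram_tile_bytes(d, block) / 1024:.0f} KiB")
```

In this model HBM traffic falls roughly in proportion to 1/block while the tile footprint grows quadratically, which is the trade-off described above; the model does not capture the compute-bound plateau.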
The speedup of FlashAttention measured on an A100 GPU is shown below:
The memory savings are shown below:
The speedup varies with sequence length and with whether dropout and masking are enabled; as the sequence length grows, FlashAttention's memory savings keep increasing.
Summary
The maximum input/output sequence length of most large language models is only 2K or 4K, essentially because the time and memory complexity of self-attention, the core component of the Transformer, is $O(N^2)$ in the sequence length. The success of FlashAttention shows that deep learning models can be optimized and accelerated through tiling, operator fusion, and recomputation, lessons of real value as AI engineering practice pushes into harder territory.
- Author: suopu
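You may reproduce, excerpt, and quote the content of this technical blog without distorting its original meaning, subject to the following terms: Attribution: you must credit the original author, but must not in any way imply that High-Flyer endorses you, nor harm High-Flyer's rights in any way. NonCommercial: you may not use the content of this technical blog for commercial purposes. NoDerivatives: if you adapt, transform, or build upon this content, you may not publicly share or distribute the modified content; it is for personal use only.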







