Implementing Flash Attention
I previously built a large language model from scratch as well as a trainable tokenizer to more deeply understand modern LLM architecture. Naturally, the next step was to pretrain a small version of this model on my local computer. I was shocked at how quickly I ran out of memory, even for small models with short context lengths. After a little investigation, I determined that our multi-head attention module was responsible for significantly more memory usage than its modern counterparts. This eventually led me to flash attention, which I had certainly read about but had never implemented. I had previously thought that this was just a boring engineering tweak on top of attention. Although the algorithm is motivated by hardware constraints, I learned that the ideas and mathematical underpinnings were surprisingly interesting.
Following the ethos of my local large language model project, my only option was to rederive the algorithm to obtain a deep understanding before implementing it from scratch. Before we can derive the flash attention algorithm, we first need to understand the issues that it aims to address in the standard scaled dot product attention.
Motivation
In our naive implementation, within each head we calculate
$$ \text{Attention}(Q, K, V) = \text{softmax} \Bigg(\frac{Q K^T}{\sqrt{d}}\Bigg) V $$When we calculate $S = Q K^T$, we are calculating an $N \times N$ matrix, where $N$ is the sequence length. After applying softmax row-wise on $S$, we still have an $N \times N$ matrix, and then multiplying by $V$ produces our $N \times d$ output. So the memory usage is clearly $\mathcal{O}(N^2)$ in the standard formulation, which produces the infamous quadratic scaling of sequence length in both compute and memory usage that has inspired many attention alternatives.
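To make the quadratic intermediates concrete, here is a minimal sketch of the naive calculation for a single head, written in plain Python with nested lists rather than tensors. The function names `softmax` and `naive_attention` are hypothetical and are not taken from the original implementation; the point is only that `S` and `P` are both $N \times N$ before we collapse back to $N \times d$.

```python
import math


def softmax(row):
    """Row-wise softmax (max-subtracted so the exponentials cannot overflow)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def naive_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V for one head, using nested lists.

    Q, K, V each hold N rows of floats. The intermediates S and P are
    both N x N, which is the quadratic memory cost discussed above.
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    # S[i][j] = (q_i . k_j) / sqrt(d): an N x N matrix
    S = [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
    # P is also N x N: one softmax per query row
    P = [softmax(row) for row in S]
    # O[i] = sum_j P[i][j] * V[j]: back down to N x d
    O = [[sum(p * v[c] for p, v in zip(p_row, V)) for c in range(len(V[0]))]
         for p_row in P]
    return S, P, O


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
S, P, O = naive_attention(Q, K, V)
print(len(S), len(S[0]))  # 3 3 -> N x N intermediate
print(len(O), len(O[0]))  # 3 2 -> N x d output
```

Doubling $N$ quadruples the size of `S` and `P`, while the output `O` only doubles.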
Now to add in the specific hardware considerations, the other component in understanding performance is memory bandwidth. When we calculate $S$, it will be written to and then read from high-bandwidth memory (HBM), which acts like the global GPU memory reserve. This means we also have $\mathcal{O}(N^2)$ in memory bandwidth traffic. Now modern GPUs are generally compute-rich but bandwidth-starved, especially at the scale of operations for a language model small enough to fit in a consumer GPU. So there is opportunity to speed up the attention calculation if we can trade down memory bandwidth operations for additional compute operations. Specifically, we want to maximize the reuse of data in SRAM (shared memory), which behaves much like a manually managed L1 cache if you are more familiar with CPU hardware design. Therefore, the challenge is to implement attention in a way that the tensors remain small enough to reside in SRAM.
With this in mind, the goal of flash attention is to avoid materializing the full $N \times N$ matrix and instead decompose the calculation into smaller sub-problems. By computing smaller pieces of the full problem, we can avoid storing the full matrix, which will reduce the memory requirements of pretraining. There are also opportunities to cleverly fuse operations together to avoid storing intermediate results, reducing memory bandwidth. So in short, the two gains we hope to achieve with flash attention are:
- Reducing memory bandwidth by better utilizing SRAM to avoid repeated global memory reads and writes.
- Reducing memory usage by avoiding the storage of full $N \times N$ matrices when we know the final output is of dimension $N \times d$.
Finally, we should note that flash attention is not approximating attention. Rather, it is a hardware-aware algorithm aimed at improving the memory and memory bandwidth scaling without impacting the underlying mathematical result. So adding flash attention offers pure upside for training, minus some additional implementation complexity.
Tile-Based Computation
To help motivate the math derivations required to implement flash attention, let’s start with the big picture behind the algorithm. In the naive attention that we already implemented, the algorithm for calculating attention looks something like
$$ S = Q K^T $$ $$ P = s(S) $$ $$ O = P V $$up to a scaling factor, omitted for clarity, where we have used the shorthand $s = \text{softmax}$. As we already discussed, this leads to $\mathcal{O}(N^2)$ memory usage and memory bandwidth traffic. The idea of flash attention is to break down the computation into blocks, which are eventually fused into the full attention result. Decomposing the calculation into blocks lets each block's working set stay in SRAM, which minimizes the number of HBM reads and writes and eases the memory bandwidth bottleneck. Additionally, the total memory usage drops from $\mathcal{O}(N^2)$ to $\mathcal{O}(N \cdot d)$ because the intermediate results are fused rather than stored. At a high level, the algorithm looks something like
for block_idx in block_indices:
    Q_block, K_block, V_block = form_blocks(Q, K, V, block_idx)
    S_block = calculate_block_product(Q_block, K_block)
    partial_softmax = calculate_partial_softmax(S_block)
    softmax_stats = update_softmax_stats(partial_softmax, softmax_stats)
    attention_contribution = calculate_contribution(
        partial_softmax, softmax_stats, V_block)
The main challenge here comes from the softmax calculation, which requires global information not found in a single tile. We encounter a similar challenge when we try to assemble the attention contributions into a single attention score computed across all of the tiles. We can frame both of these as mathematical problems and derive the appropriate update rules that allow us to properly weigh the contribution from each tile to build up to the final answer even without all of the global information available.
Streaming Softmax
Let’s consider a single row of a matrix $x \in \mathbb{R}^N$. We can compute the standard softmax according to
$$ s(x)_j = \frac{e^{x_j}}{\sum_{k = 1}^{N} e^{x_k}} $$or
$$ s(x)_j = \frac{e^{x_j}}{Z} $$after defining a partition function $Z \equiv \sum_k e^{x_k}$. However, this version of softmax is prone to numerical instability. So instead, we want to define $m \equiv \max_{k} x_k$ and calculate
$$ s(x)_j = \frac{e^{x_j - m}}{\sum_{k = 1}^{N} e^{x_k - m}}, $$which is numerically stable and robust to overflows. However, this does make the update rule more complicated, since $m$ needs to be calculated globally across an entire row. Let’s see how we can do this. But first, let’s introduce one final definition, $\ell \equiv \sum_k e^{x_k - m}$, which allows us to write our partition function as
$$ Z = \sum_k e^{x_k} $$ $$ Z = e^m \sum_k e^{x_k - m} $$ $$ Z = e^m \ell $$Let’s now think about how we would perform this calculation when $x$ is split into two chunks. That is, imagine we have
$$ x = \left[\begin{array}{c|c} x_A & x_B \end{array}\right] $$and that we have already computed the statistics over the sub-vector $x_A$: $m_A = \max_{k \in A} x_k$ and $\ell_A = \sum_{k \in A}e^{x_k - m_A}$. So we can already write
$$ Z_A = e^{m_A} \ell_A $$The full partition function is simply $Z = Z_A + Z_B$, but we cannot just add $\ell_A$ and $\ell_B$, since they are measured relative to different maxima unless $m_A = m_B$. Instead, both terms must be expressed relative to the same global maximum $m = \max\{m_A, m_B\}$. Using this definition, we can write
$$ e^{m_A} \ell_A = e^{m_A + m - m} \ell_A $$ $$ e^{m_A} \ell_A = e^m e^{m_A - m} \ell_A $$and analogously for $Z_B = e^{m_B} \ell_B$. Thus, the correct partition function expressed in terms of $m_A$, $m_B$, and $m$ is
$$ Z = e^m \Big(e^{m_A - m} \ell_A + e^{m_B - m} \ell_B\Big) $$This provides an update rule if we have previously calculated statistics over $x_A$ and now want to update the softmax for the contribution from $x_B$:
$$ m_{\text{new}} = \max(m_{\text{old}}, m_{\text{tile}}) $$ $$ \ell_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}} \ell_{\text{old}} + e^{m_{\text{tile}} - m_{\text{new}}} \ell_{\text{tile}} $$
Streaming Attention Contribution
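Before moving on, the update rule for $m$ and $\ell$ can be checked numerically. The following is a minimal plain-Python sketch under my own naming (`tile_stats`, `merge_stats`, and `streaming_softmax` are hypothetical helpers, not functions from the repo). It starts from $m = -\infty$ and $\ell = 0$, which act as an empty running state, and folds in one tile at a time.

```python
import math


def tile_stats(tile):
    """Return (m, ell) for one tile: its max and the sum of exp(x - m)."""
    m = max(tile)
    ell = sum(math.exp(x - m) for x in tile)
    return m, ell


def merge_stats(m_old, ell_old, m_tile, ell_tile):
    """Apply the update rule for the running max and normalizer."""
    m_new = max(m_old, m_tile)
    ell_new = (math.exp(m_old - m_new) * ell_old
               + math.exp(m_tile - m_new) * ell_tile)
    return m_new, ell_new


def streaming_softmax(x, tile_size):
    """Softmax of x whose normalizer is built from per-tile statistics."""
    m, ell = -math.inf, 0.0  # empty running state: contributes nothing
    for start in range(0, len(x), tile_size):
        m, ell = merge_stats(m, ell, *tile_stats(x[start:start + tile_size]))
    return [math.exp(v - m) / ell for v in x]


# Large inputs: math.exp(1000.0) on its own would raise OverflowError.
x = [1000.0, 1001.0, 999.0, 1002.5, 998.0]
probs = streaming_softmax(x, tile_size=2)
print(probs)
```

The result should agree with a one-shot stable softmax for any tile size, since the rescaling factors only change the reference point of each partial sum.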
We just derived the rules required to calculate a streaming version of softmax across multiple tiles. But the full attention equation is given by
$$ A(Q, K, V) = \text{softmax}\Bigg(\frac{Q K^T}{\sqrt{d}}\Bigg) V $$To avoid storing any intermediate results, we also need to derive how to apply a tile-level calculation of softmax to a tile subset of the full $V$ tensor. Deriving an update rule for
$$ O = s(S) V $$in terms of the partial softmax will allow us to keep the entire calculation in SRAM. To begin the derivation, we will start with a very similar setup to before. We can write
$$ O = \sum_k s(x)_k V_k $$for some vector $x \in \mathbb{R}^N$. Returning to a partition function, we can write this as
$$ O = \frac{1}{Z} \sum_k e^{x_k} V_k $$Now, we will split up the vector $x$ into two subsets, $A$ and $B$, again. This allows us to write the numerator in the above equation as
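$$ \mathcal{N} = \sum_{k \in A} e^{x_k} V_k + \sum_{k \in B} e^{x_k} V_k $$where we write $\mathcal{N}$ for the numerator to avoid confusion with the sequence length $N$. If we define $O_A \equiv \sum_{k \in A} e^{x_k - m_A} V_k$, and $O_B$ analogously, we can instead write
$$ \mathcal{N} = e^{m_A} O_A + e^{m_B} O_B $$ $$ \mathcal{N} = e^m \Big( e^{m_A - m} O_A + e^{m_B - m} O_B \Big) $$So if we define our updated (unnormalized) output as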
$$ O = e^{m_A - m} O_A + e^{m_B - m} O_B, $$we arrive at a similar update rule for $O$ based on the contribution from each tile
$$ O_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}} O_{\text{old}} + e^{m_{\text{tile}} - m_{\text{new}}} O_{\text{tile}} $$This now provides the final update equation that allows us to calculate attention entirely from individual blocks, noting that we will require a final normalization by $\ell$ once all tiles have contributed to $O$.
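Putting the two update rules together, a single query row can be processed tile by tile while only ever holding the running max, the running normalizer, and a length-$d$ accumulator. This is a plain-Python sketch under stated assumptions: `streaming_attention_row` is a hypothetical name, the scores are taken as already computed and scaled, and the real implementation operates on batched tensors rather than lists.

```python
import math


def streaming_attention_row(scores, values, tile_size):
    """Compute sum_k softmax(scores)_k * values[k] one tile at a time."""
    d = len(values[0])
    m = -math.inf    # running max
    ell = 0.0        # running normalizer
    out = [0.0] * d  # running unnormalized output O
    for start in range(0, len(scores), tile_size):
        x_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        # Tile-local statistics, measured relative to the tile's own max.
        m_tile = max(x_tile)
        w = [math.exp(x - m_tile) for x in x_tile]
        ell_tile = sum(w)
        o_tile = [sum(wk * vk[c] for wk, vk in zip(w, v_tile)) for c in range(d)]
        # Rescale old and new contributions to the shared max m_new.
        m_new = max(m, m_tile)
        a_old = math.exp(m - m_new)
        a_tile = math.exp(m_tile - m_new)
        ell = a_old * ell + a_tile * ell_tile
        out = [a_old * o + a_tile * t for o, t in zip(out, o_tile)]
        m = m_new
    # Final normalization by ell once every tile has contributed.
    return [o / ell for o in out]


scores = [2.0, -1.0, 0.5, 3.0, 1.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.5, -0.5]]
print(streaming_attention_row(scores, values, tile_size=2))
```

Note that the division by `ell` happens exactly once at the end; dividing inside the loop would require undoing the normalization on every update.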
Implementation
My implementation of flash attention using the two update rules that we derived can be found in my repo. The implementation should look quite familiar if you have followed the math up until this point, since it is mostly a direct translation of math into PyTorch code. Just as before, I was very explicit about statically typing tensor dimensions as well as documenting any changes to these dimensions to make the algorithm easy to read. The implementation also adds dropout and causal masking, since these were features of the original implementation that we are supposed to be improving.
If you bought into the argument of the merits of flash attention at the start of the article, you might think that we are capturing the full benefits in this implementation. However, even without a deep understanding of GPU hardware, you will find some obvious red flags. For both the outer loop, where we create a single block of the $Q$ tensor that can be reused and stored in SRAM, and the inner loop, where we also block $K$ and $V$, we are using simple Python for loops. This is always a bad idea when dealing with linear algebra if the goal is performance, and even more so in deep learning where we interrupt GPU calculations to go back to the CPU. Since I was interested in understanding and illustrating the algorithm, my focus was on carefully tracking the logic of the algorithm through improved readability at the expense of performance.
I did want to understand how successful my implementation was in achieving the original goals. Remember that I started down this rabbit hole because I was alarmed at the amount of VRAM required to train even small networks. So maybe we are comfortable punting on the actual performance, but we did too much work to not improve upon the training memory requirements, right? At least, that was my hope. I was also curious to quantify exactly how slow our implementation performed relative to the original multi-headed attention. So I created a simple profile script that looks at latency in both the forward and forward + backward passes as well as the peak memory used during each. The results actually tell a fairly compelling story: while our flash attention is slower than molasses, it does dramatically reduce the memory requirements:
Hybrid Flash Attention
While this side quest has become a bigger undertaking than I was prepared to commit to, we have not quite achieved our true goal. I pursued this path to make training a (small) large language model possible on my meager consumer hardware. The constant OOM errors should be remedied based on the profiler results, but I don’t think I have made much meaningful progress on improving training feasibility if I have increased the latency by an order of magnitude along the way. Training a single model for a few days is much more palatable than training for an entire month. So let’s return to the implementation to try to address the root cause of the slowdown: those pesky for loops.
Initially, I briefly considered an extreme thought: if we are interested in performance, I could just write the kernel in either CUDA or Triton. I do have some experience programming in CUDA, although that was many years ago. However, as I learned more about what would be required to use CUDA or take the simpler path with Triton, I quickly realized that the return on my significant time investment would be poor. If we were going to make this algorithm work, we would need to continue to rely on pure PyTorch to perform our calculation, which can then benefit from the standard autograd calculation.
If we take a close look at the loops that we performed, we notice that neither of them looks strictly necessary. While they are convenient from the perspective of a developer, they carry significant overhead, especially when we are dealing with matrices small enough that the time spent on memory traffic can be comparable to the computation time. The outer loop seems pretty challenging to avoid in our implementation, but the inner loop can be re-expressed in terms of linear algebra if we are willing to make some concessions.
After significant experimentation, I found that a hybrid approach to adding flash attention struck the best balance between algorithm fidelity, inference time, and memory usage. I was able to eliminate the inner loop by skipping our fancy online softmax normalization trick. In the hybrid implementation, we bypass the streaming softmax accumulation entirely, instead computing the softmax only at the level of individual query blocks. We are effectively trading a bit of the memory savings for faster computation, since eliminating the second for loop means that we are launching fewer, larger kernels. The details of the implementation can be found in the repo.
While it saddened me to throw away our careful derivations that closely followed the original paper to fall back on something resembling the original attention calculation, the hybrid approach that I developed brings serious upgrades when we turn to profiling.
Algorithmic Performance Understanding
It is worth looking into the algorithm design a bit to understand the trade-offs. The main difference is that the inner loop has been eliminated. Instead, for each query block, all of the causally visible keys and values are sliced and then scored in a single einsum call. This replaces the online softmax normalization trick that we developed in favor of computing the softmax statistics in a single pass. This is actually much simpler, since there is no longer incremental accumulation.
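The slicing and masking logic of the hybrid approach can be sketched in plain Python. This is an illustration of the indexing only, not the repo's code: `hybrid_causal_attention` is a hypothetical name, the list comprehensions stand in for the single einsum call, and dropout is omitted. The names `K_causal` and `V_causal` are used here simply to denote the sliced keys and values.

```python
import math


def hybrid_causal_attention(Q, K, V, block_q):
    """Causal attention with one loop over query blocks and no key loop.

    Q, K, V are lists of N rows; block_q is the query block size.
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    d_v = len(V[0])
    out = []
    for q_start in range(0, len(Q), block_q):
        q_end = min(q_start + block_q, len(Q))
        # Slice every key/value visible to at least one query in this block.
        K_causal, V_causal = K[:q_end], V[:q_end]
        for offset, q in enumerate(Q[q_start:q_end]):
            i = q_start + offset  # absolute position of this query
            # Score against all sliced keys; keys after position i are masked.
            s = [scale * sum(a * b for a, b in zip(q, k)) if j <= i else -math.inf
                 for j, k in enumerate(K_causal)]
            # One-pass softmax over the whole row: no running statistics.
            m = max(s)
            w = [math.exp(x - m) for x in s]
            z = sum(w)
            out.append([sum(wj * v[c] for wj, v in zip(w, V_causal)) / z
                        for c in range(d_v)])
    return out


Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [-0.2, 0.4], [0.6, 0.1]]
K = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.0], [0.1, -0.2], [0.0, 0.3]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.5, -0.5]]
print(hybrid_causal_attention(Q, K, V, block_q=2))
```

The largest intermediate per block is the `block_q` by `q_end` set of scores, which is why this variant's footprint grows with the sequence length while the fully tiled version's does not.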
Speed
Let’s look more deeply into our algorithm and profile results to understand how exactly the algorithm runs on our hardware. The first interesting result comes from looking at the forward pass speed-up. Recall that the standard multi-head attention will write the full attention score matrix to HBM and then re-read it multiple times. Tracing through the algorithm, we see:
- This matrix is written once by $\texttt{matmul}(Q, K^T)$.
- It is read + written by calling $\texttt{masked_fill_}$.
- It is read + written by $\texttt{softmax}$.
- It is finally read again to calculate $\texttt{matmul}(\texttt{attn_weights}, V)$.
The full attention matrix has dimensions $B \times N_H \times T \times T$, so using the profiler default variables ($B$ = 4, $N_H$ = 12, $T$ = 2048), we are dealing with about 800 MB with 32-bit floats. So we are streaming many gigabytes through HBM, which becomes much more of the bottleneck than the FLOPs.
Now, we can compare this to the hybrid implementation we developed. We immediately face the familiar downside of the overhead associated with our for loop, and this dominates the latency for small sequence lengths. However, this implementation greatly improves upon the memory bandwidth by avoiding materializing the full attention matrix. The largest tensor per block instead has shape $B \cdot N_H \times B_q \times q_{\text{end}}$, which never grows beyond $48 \times 64 \times 2048$, or about 25 MB. This is small enough to reside in the L2 cache of our GPU. As a result of this, we have roughly half the total reads and writes from HBM compared to the standard multi-head attention implementation. So the fixed cost of the Python for loop dominates for small sequences, but as the sequences grow in size, the improvements to the memory bandwidth become significant.
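The sizes quoted above are easy to reproduce. The short script below is a back-of-the-envelope sketch, assuming 32-bit floats and counting one full pass over the attention matrix for every read or write listed earlier; `tensor_bytes` is a hypothetical helper, and the pass count ignores caching and any fusion the framework may perform.

```python
import math

BYTES_PER_FLOAT32 = 4


def tensor_bytes(*dims, bytes_per_element=BYTES_PER_FLOAT32):
    """Size in bytes of a dense tensor with the given dimensions."""
    return math.prod(dims) * bytes_per_element


B, N_H, T, B_q = 4, 12, 2048, 64

# Standard MHA: the full B x N_H x T x T attention matrix.
full = tensor_bytes(B, N_H, T, T)
# matmul write, masked_fill_ read + write, softmax read + write, final matmul read
passes = 1 + 2 + 2 + 1
# Hybrid: the largest per-block score tensor, (B * N_H) x B_q x T.
block = tensor_bytes(B * N_H, B_q, T)

print(f"full attention matrix: {full / 1e6:.0f} MB")           # 805 MB
print(f"HBM traffic over {passes} passes: {passes * full / 1e9:.1f} GB")  # 4.8 GB
print(f"largest hybrid block: {block / 1e6:.1f} MB")            # 25.2 MB
```

The ratio `full / block` is $T / B_q = 32$ for these settings, which is the headroom that the blocked formulation buys on this one tensor.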
We observe that the speed up in the forward + backward follows a similar pattern, although calling .backward() now adds Python graph-traversal overhead on top of the inference overhead. For this reason, we do not see that the speed ever favors our hybrid algorithm, although we are paying a much smaller speed cost than before, reaching over 90% of the original MHA speed.
Memory
At a high level, we see the expected big wins in memory usage. In the standard MHA implementation, we need to store the attention weights of size $B \times N_H \times T \times T$ as already mentioned. In either version of flash attention, we realize significant savings by never materializing this full object. We are able to preserve nearly all of the memory savings from the original flash attention implementation with tensors in our size range of interest. Digging more deeply, though, we find some interesting changes in comparing the memory reduction factors produced from the vanilla flash attention and our hybrid approach.
First, the forward pass memory reduction factor drops from ~11 to ~9 in adopting our hybrid approach. In the vanilla flash implementation, the query and key dimensions are both tiled, so that the largest tensor stored has shape $B \cdot N_H \times B_q \times B_k$ independent of the sequence length. However, the hybrid flash attention slices across all causally-visible keys at once, so the shape of the largest object grows with the sequence length to a maximum size of $B \cdot N_H \times B_q \times T$. Therefore, for this largest tensor alone, we would expect the vanilla implementation to beat our hybrid implementation by a factor of up to $T / B_k$.
The situation gets more interesting when we also consider the backward pass. In the joint forward + backward pass evaluation, the hybrid implementation actually improves upon the memory reduction factor of ~2 to reach a factor of ~2.75. Remember that autograd needs to save the full set of tensors needed to recompute the gradients. Both implementations produce the same final attention weight shape, but they introduce differing amounts of overhead related to backpropagation specific to their implementations.
In the vanilla flash attention, we require gradients of the transposed key copies K_jT as well as the chain of intermediate tensors required to calculate the streaming statistics. The hybrid implementation actually avoids both of these sources of overhead. Both K_causal and V_causal are simple slices of the original tensors, so there is no reshaping or allocation that adds an additional autograd node. And obviously there is no streaming dependency chain without the streaming statistics at all. In this way, the hybrid approach actually further reduces the memory demand while training, which was the specific motivation behind this entire exploration.