Flash Attention

Much progress in AI over the past few years has been fueled by the transformer architecture. Transformers are the closest thing we have right now to machine learnable programs. They can be trained to generate images, text, videos, audio, video games, or even raw byte sequences, you name it.

Behind the transformer, powering many of these applications, are two key operations that account for roughly 99% of the FLOPs: attention and feed-forward layers. Both are conceptually simple, yet very compute-intensive. In this post, my goal is to explain the attention operation and how to implement it efficiently using an algorithm called Flash Attention.

Attention layers are responsible for mixing information between words (or tokens). Without attention, each token wouldn’t know about any of the other tokens in the sequence. The input to an attention layer is a set of queries $Q \in \mathbb{R}^{N \times d}$, keys $K \in \mathbb{R}^{N \times d}$, and values $V \in \mathbb{R}^{N \times d}$, where $N$ is the number of tokens in the sequence and $d$ is the dimension of each token. The output is a set of attended values $O \in \mathbb{R}^{N \times d}$.

The attention operation is defined as: $O = \text{softmax}(\frac{QK^T}{\sqrt{d}})V$. Each output row is a convex combination of the value rows, where the weights are given by the softmax of the scaled dot products between queries and keys, also called the attention scores.

While a naive implementation of this formula is straightforward, it is very memory inefficient: it materializes the full $N \times N$ attention matrix. That’s where Flash Attention comes in.
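To see why, here is a minimal sketch of the naive approach in numpy (essentially the same as the reference implementation used for testing at the end of this post). The intermediate matrix $S$ has shape $N \times N$, which grows quadratically with the sequence length.

import numpy as np

def naive_attention(Q, K, V):
    # Q, K, V: [batch, sequence, head_dim]
    # S is a full [batch, sequence, sequence] matrix that has to be
    # materialized before the softmax can be normalized.
    S = Q @ np.swapaxes(K, -1, -2) / np.sqrt(Q.shape[-1])
    S = S - np.max(S, axis=-1, keepdims=True)  # for numerical stability
    A = np.exp(S) / np.sum(np.exp(S), axis=-1, keepdims=True)
    return A @ V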

The trick of Flash Attention is to fuse the three operations, query-key matmul, softmax, and value matmul, into a single kernel. That way the intermediate $N \times N$ score matrix never has to be written out to high-bandwidth memory (HBM) and read back in between the three steps (slow); instead, each block of $Q$, $K$, and $V$ is loaded into fast on-chip SRAM once and only the final output is written back, which speeds up the computation significantly on GPUs.
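Some rough back-of-the-envelope arithmetic shows why this matters. The numbers below are assumptions for illustration: a sequence length of 4096 (matching the test at the end of the post), float32 scores, and on the order of 200 KB of SRAM per streaming multiprocessor (roughly an A100).

N = 4096                                         # sequence length (assumed)
score_matrix_bytes = N * N * 4                   # float32 scores: ~64 MiB per head
sram_per_sm_bytes = 192 * 1024                   # ~192 KB of SRAM per SM (assumed)
print(score_matrix_bytes / 2**20)                # 64.0 (MiB)
print(score_matrix_bytes // sram_per_sm_bytes)   # 341: the score matrix is ~340x larger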

Here’s a simplified implementation of Flash Attention in Python using numpy.

import numpy as np


def flash_attention(Q, K, V):
    # K, V should have shape [batch, sequence_in, head_dim]
    # Q should have shape [batch, sequence_out, head_dim]
    # For simplicity, sequence lengths are assumed to be multiples of the block sizes.
    batch_size, query_seq_len, query_dim = Q.shape
    O = np.zeros((batch_size, query_seq_len, V.shape[-1]))
    softmax_scale = np.sqrt(query_dim)
    blk_q = 64
    blk_kv = 64
    for i in range(0, query_seq_len, blk_q):
        # Allocate intermediate results:
        O_i = np.zeros((batch_size, blk_q, V.shape[-1]))
        l = np.zeros((batch_size, blk_q, 1))
        m = np.full((batch_size, blk_q, 1), -np.inf)
        # Load Q_i from HBM memory; the outputs for this query block
        # are computed below. Scale by 1/sqrt(d) up front.
        Q_i = Q[:, i : i + blk_q, :] / softmax_scale
        for j in range(0, K.shape[-2], blk_kv):
            # Load K_j and V_j from HBM memory
            # [batch, blk_kv, head_dim]
            K_j = K[:, j : j + blk_kv, :]
            V_j = V[:, j : j + blk_kv, :]
            # Attention scores for block Q_i,K_j
            # [batch, blk_q, blk_kv]
            S_ij = Q_i @ np.swapaxes(K_j, -1, -2)
            # Softmax over the current block, but don't normalize yet.
            # m is the running row-wise max over all blocks seen so far.
            m_new = np.maximum(m, np.max(S_ij, axis=2, keepdims=True))
            P_ij = np.exp(S_ij - m_new)
            # Carry-forward factor w_ij rescales the previous partial
            # results whenever the running max changes
            # [batch, blk_q, 1]
            w_ij = np.exp(m - m_new)
            m = m_new
            # Output normalization factor l
            # [batch, blk_q, 1]
            l = w_ij * l + np.sum(P_ij, axis=2, keepdims=True)
            # Outputs for block Q_i,K_j
            # [batch, blk_q, head_dim]
            O_i = w_ij * O_i + P_ij @ V_j
        O[:, i : i + blk_q, :] = O_i / l

    return O

Since the query, key, and value matrices don’t fit into fast SRAM as a whole, we need to split the computation into multiple blocks. The query-key matrix multiplication is easy to split: we compute $S_{ij} = Q_i K_j^T / \sqrt{d}$, where $i$ and $j$ denote chunks of $Q$ and $K$ along the sequence dimension. That gives us the unnormalized attention scores $S_{ij}$ for each block.
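As a quick sanity check (using a hypothetical block size of 64 and small random inputs), the blocked scores tile the full score matrix exactly:

import numpy as np

Q = np.random.uniform(size=(1, 256, 32))
K = np.random.uniform(size=(1, 256, 32))
d = Q.shape[-1]
blk = 64

# Full score matrix, computed in one shot.
S_full = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d)

# The same matrix, assembled from [blk, blk] tiles S_ij = Q_i K_j^T / sqrt(d).
S_blocked = np.zeros_like(S_full)
for i in range(0, Q.shape[1], blk):
    for j in range(0, K.shape[1], blk):
        Q_i = Q[:, i : i + blk, :]
        K_j = K[:, j : j + blk, :]
        S_blocked[:, i : i + blk, j : j + blk] = Q_i @ np.swapaxes(K_j, -1, -2) / np.sqrt(d)

np.testing.assert_allclose(S_full, S_blocked)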

Now the softmax. As a reminder, the softmax of a vector of scores $s$ is defined as $\text{softmax}(s)_k = \frac{\exp(s_k)}{\sum_{k'} \exp(s_{k'})}$. For numerical stability, we usually compute it as $\frac{\exp(s_k - \max(s))}{\sum_{k'} \exp(s_{k'} - \max(s))}$, which gives the same result. The problem is that both the max and the sum depend on the full row of $S$, so we can’t compute them independently for each block. Flash Attention deals with this by computing only $\exp(S_{ij})$ for each block while keeping track of a running sum and max, rescaling the partial results whenever the max changes. Only after all blocks have been processed do we divide by the overall sum $l = \sum_j \exp(S_{ij})$.

$$ O_i = \frac{\sum_j\exp(S_{ij})V_{j}}{\sum_j \exp(S_{ij})} $$

See how both sums run over $j$? That means the numerator and the denominator can be accumulated block by block, without ever needing the full matrix $S$.
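To make the rescaling trick concrete, here is a tiny, self-contained sketch of the same bookkeeping on a single row of scores (the variable names mirror `m` and `l` in the `flash_attention` code above):

import numpy as np

s = np.random.uniform(size=(256,))   # one row of attention scores
blk = 64

m = -np.inf   # running max
l = 0.0       # running sum of exp(s - m)
for j in range(0, s.shape[0], blk):
    s_j = s[j : j + blk]
    m_new = max(m, s_j.max())
    # Rescale the old sum so it is relative to the new max, then add the block.
    l = np.exp(m - m_new) * l + np.sum(np.exp(s_j - m_new))
    m = m_new

# The streamed result matches the direct computation.
np.testing.assert_allclose(l, np.sum(np.exp(s - s.max())))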

Let’s check that the results match other implementations.

import torch

def ref_attention_1(Q, K, V):
    # Naive reference: materializes the full score matrix.
    # (No max subtraction here; the scores are small enough not to overflow.)
    S = (Q @ np.swapaxes(K, -1, -2)) / np.sqrt(Q.shape[-1])
    S = np.exp(S)
    S = S / np.sum(S, -1)[..., None]
    return S @ V

def ref_attention_2(Q, K, V):
    # PyTorch's built-in fused attention as a second reference.
    return torch.nn.functional.scaled_dot_product_attention(
        query=torch.from_numpy(Q),
        key=torch.from_numpy(K),
        value=torch.from_numpy(V)
    ).numpy()


Q = np.random.uniform(size=(4, 4096, 32))
K = np.random.uniform(size=(4, 4096, 32))
V = np.random.uniform(size=(4, 4096, 32))

ours = flash_attention(Q, K, V)
np.testing.assert_allclose(ref_attention_1(Q, K, V), ours)
np.testing.assert_allclose(ref_attention_2(Q, K, V), ours)

I hope that this short deep dive into Flash Attention was helpful! If you have any questions, feel free to reach out to me on Twitter.