FA1
首先把老生常谈的 scale dot-product attention formula 拿出来:
$$ Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V $$
| 存储类型 | 全称 | 能否读写 | 典型延迟 | 典型用途 | 是否进入 GPU 性能模型 |
|---|---|---|---|---|---|
| SRAM | Static RAM | ✅ 读写 | 1–30 cycles | 寄存器、LDS、Shared、L1 | ✅ 核心 |
| DRAM | Dynamic RAM | ✅ 读写 | 300–1000 cycles | 外部显存 | ✅ 核心 |
| HBM | High Bandwidth DRAM | ✅ 读写 | 300–800 cycles | GPU 显存 | ✅ 核心(DRAM 的一种) |
| SROM / ROM | Static Read-Only Memory | ❌ 只读 | 快 | 存微码、固件 | ❌ 不讨论 |
这个经典的 FA 图,左边就是多级显存的访存效率,右边的图说明了以 GPT-2为例,pyTorch 原生实现的 Attention 和 FlashAttention 进行 Kernel Fusion后的时间对比。
而中间的图说明了 FA 的计算架构。最关键的一点是,对于 Softmax 后生成的注意力分数进行 Kernel fuse,不去显式地实现注意力分数。这样在显存中就不会出现完整的分数的传输。
而在反向传播中,需要注意力分数的时候就进行重算,来避免当 sequence length 很大时的 $O(n^2)$ 的数据传输。
首先说明Naive 实现:
最朴素的 GPU 实现通常是三步、多个 CUDA kernel 或 GEMM:
- Kernel/GEMM1: 计算
S = Q Kᵀ
- 用
cublasGemmEx做一个(N×d) × (d×N) → (N×N) - 把
S存到 HBM(显存)里
- Kernel2: 对 S 行做 softmax(含减 max + exp + 归一化)
- 从 HBM 读
S - 做 softmax
- 把
P写回 HBM
- Kernel/GEMM3: 计算
O = P V
- 再用 GEMM:
(N×N) × (N×d) → (N×d)
对于多个 Kernel 可以想到的第一个思路就是 fuse,把 “S 计算 + softmax” 融合成一个大 kernel:块状加载 Q、K 到 shared memory,直接算 logits,立刻 softmax,减少一次写 S 的开销。
而 FA1 实现思路:完全不要显式存整个 S/P 矩阵到 HBM,改成“分块算 attention + 在线 softmax + 累积输出”。
最主要的是 Softmax 的实现,他每次要看整行的数据来进行归一化,所以要用递推公式保证数值稳定和正确:
3.2 在线 softmax(Online Softmax):支持“分块 / 多次看到同一行”的 softmax
普通 softmax 对一行 ( s_i ) 的公式:
$$ p_{i,j} = \frac{e^{s_{i,j}}}{\sum_j e^{s_{i,j}}} $$
但在 FA1 中,一行 ( s_i ) 的全部元素 ( s_{i,*} ) 是通过多个 K block 才“分段看到”的,
所以要用一个递推公式来保持 数值稳定性与结果正确性。
对每一行 ( i ),维护三个量:
- ( m_i ):当前见过的所有 logits 的最大值
- ( l_i ):当前见过的所有 logits 的 $\sum_j e^{s_{i,j} - m_i}$ 。
- ( acc_i ):加权和 $\sum_j p_{i,j} \cdot v_j$ 。的累积
当看到一个新的 block 的 logits $s_{i,j}^{(\text{block})}$ 时:
1)计算新的最大值:
$$ m_i^{new} = \max \left( m_i^{old}, \max_j s_{i,j}^{(\text{block})} \right) $$
2)对旧的累积进行重标定(rescale):
$$ l_i^{new}= e^{m_i^{old} - m_i^{new}} \cdot l_i^{old}+ \sum_j e^{ s_{i,j}^{(\text{block})} - m_i^{new} } $$
3)输出加权和同样进行重标定并加上新块贡献:
$$ acc_i^{new}= e^{m_i^{old} - m_i^{new}} \cdot acc_i^{old}+ \sum_j e^{ s_{i,j}^{(\text{block})} - m_i^{new} } \cdot v_j $$
循环完所有 block 之后:
$$ O_i = \frac{acc_i}{l_i} $$
等价于下面的公式:
上述即 FA1 的原理实现,下面用 Triton 进行实现:
下面的代码是在串行合并,所以只有一个 pre_max -> cur_max 的更新,如果一旦变成并行分块,就会出现多个分块来进行分割的结构。
# 计算公式
# S = Q * (K.transpose(-1, -2))
# P = softmax(S)
# O = P * V
import numpy as np
N = 4
d = 2
# 这里的 N 就是 seq_len,以一行作为示例
S = np.random.random(size=(1, N))
V = np.random.random(size=(N, d))
def tiled_softmax_then_matmul(S, V):
# 分数
acc = np.zeros(shape=(1, d)) # 到目前为止输出的加权和
pre_max = float("-inf") # 到目前为止的最大值
pre_sum = 0 # 到目前为止 logits 的指数和
for i in range(N): # 每个token,KV的列维度,为了简洁,这里把Q的行维度设为了1,因此没有了内循环
s_i = S[:,i] # 每列S
cur_max = max(pre_max, s_i) # 当前分块和之前分块一起的最大值
# 将之间的 调整因子调整为减去 cur_sum
pre_sum *= (np.exp(pre_max - cur_max)) # L10
# 当前分块和之前分块一起的指数和
cur_sum = pre_sum + np.exp(s_i - cur_max)
# 当前分块的softmax结果
score = np.exp(s_i - cur_max) / cur_sum # 到目前为止,当前块的 score 已经计算完毕,但是之前的 logits 还是以 pre_sum 作为底来计算的,所以还需要调整
scale = pre_sum / cur_sum # 因为上一个分块的结果是基于当时的softmax中间sum组成的分母(presum),现在这个分块又得到了新的中间sum(cursum),所以需要更新:对上一个分块的结果acc做一个scale,保证结果的正确性
acc *= scale # 进行更新之前的 logits
acc += score * V[i,] # scale后的中间结果加上当前分块的P * V = O
# 更新
pre_max = cur_max
pre_sum = cur_sum
return acc关于 fa1,有一些优化可以提出来:
- 比如Q的循环可以放在外面,消除每次KV外循环都要去访问Q的开销;
- 每次外循环都需要去用presum/cursum去rescale中间结果以得到最终结果,这些计算可以通过算法消除;--->>> 这里我的理解就是可以将这个缩放调整因子放到GEMM 中来增加 tensorcore 的flops;
补充,PCIe 和 NVlink 带宽:
| 互联方式 | 单向理论带宽 | 现实可达 |
|---|---|---|
| PCIe 3.0 x16 | ~16 GB/s | ~12–13 GB/s |
| PCIe 4.0 x16 | ~32 GB/s | ~25–30 GB/s |
| PCIe 5.0 x16 | ~64 GB/s | ~50–60 GB/s |
| NVLink V2(V100) | ~50 GB/s/链路 | ~45 GB/s |
| NVLink V3(A100) | ~100 GB/s/链路 | ~90 GB/s |
| NVLink V4(H100) | ~150 GB/s/链路 | ~140 GB/s |
| GPU HBM 本地带宽 | 1–3 TB/s | ✅ 实测可达 |
FA2
之前我们理解到,在 fa1 中,kv 的列维度为外循环,qo 的行维度为内循环;
以A100为例,其FP16/BF16 tensorcore的理论峰值吞吐量为312 TFLOPS,但FP32非矩阵乘法在CUDA core上仅有19.5 TFLOPS.
这里引入 diag,左乘就是对矩阵每行都进行相同的处理(矩阵的常规处理居然忘了)
上图中FA1 其实隐含了 Tile,但是没有显式提出来。
fa2 的话,为了尽可能减少非矩阵乘的计算量,做了改进:省略了每次的分母的更新,在上图中的第二点,才是真正分块实现时候的表示,FA1 加入了调整因子;FA2 不需要对分母调整,只需要对分子调整即可,注意这里的 diag 也没有了 ^{-1},而 FA2 在最后再进行全局的 $diag(l)^{-1}$ .
上边说到,将Q 和 KV 的循环顺序进行修改能够得到性能提升,我有点不解,在这里解释:其实 Q 和 KV 循环调换我没理解透,但是 /l 来避免全局的同步点会有很大提升。
对seqlen维度充分并行:这一点主要考虑到batchsize*numheads小于SM个数的情况下,无法打满SM算力,此时seqlen一般都很大,需要对seqlen维度充分并行。主要的实现就是在于FlashAttention-2将Q移到了外循环,KV移到了内循环,由于改进了算法使得warps之间不再需要相互通信去处理Q,所以外循环可以放在不同的block上。这个交换的优化方法是由Triton的作者提出并实现的。如下图,左图需要各个warp之间做reduce才能算出一行的结果,右图则不需要,它每个warp都可以独立计算出一整块softmax和一整块O结果。