多头注意力(MHA)和分组查询注意力(GQA)的伪代码实现:
1. 多头注意力(Multi-Head Attention, MHA)
def MultiHeadAttention(Q, K, V, num_heads, d_model):
    batch_size, seq_len = Q.size(0), Q.size(1)
    d_k = d_model // num_heads  # 每个头的维度
    
    # 1. 线性投影并分割头
    Q = linear(Q).view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    K = linear(K).view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    V = linear(V).view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    
    # 2. 计算缩放点积注意力
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    attn = softmax(scores, dim=-1)
    context = torch.matmul(attn, V)
    
    # 3. 合并所有头并线性变换
    context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
    output = linear(context)
    return output关键步骤说明:
- 将输入 
Q/K/V分别投影到num_heads个独立的头。 - 每个头计算独立的注意力权重。
 - 合并所有头的输出并通过线性层得到最终结果。
 
2. 分组查询注意力(Grouped-Query Attention, GQA)
def GroupedQueryAttention(Q, K, V, num_heads, num_groups, d_model):
    batch_size, seq_len = Q.size(0), Q.size(1)
    assert num_heads % num_groups == 0
    heads_per_group = num_heads // num_groups
    d_k = d_model // num_heads
    
    # 1. 投影 Q 到 h 个头,K/V 到 g 个头
    Q_proj = linear(Q).view(batch_size, seq_len, num_heads, d_k)  # [B, L, h, d_k]
    K_proj = linear(K).view(batch_size, seq_len, num_groups, d_k) # [B, L, g, d_k]
    V_proj = linear(V).view(batch_size, seq_len, num_groups, d_k) # [B, L, g, d_k]
    
    # 2. 扩展 K/V 以匹配每个组内的头数
    K_proj = K_proj.unsqueeze(2).expand(-1, -1, heads_per_group, -1, -1)  # [B, L, k, g, d_k]
    V_proj = V_proj.unsqueeze(2).expand(-1, -1, heads_per_group, -1, -1)  # [B, L, k, g, d_k]
    K_proj = K_proj.reshape(batch_size, seq_len, num_heads, d_k)          # [B, L, h, d_k]
    V_proj = V_proj.reshape(batch_size, seq_len, num_heads, d_k)          # [B, L, h, d_k]
    
    # 3. 调整维度并计算注意力
    Q_proj = Q_proj.transpose(1, 2)  # [B, h, L, d_k]
    K_proj = K_proj.transpose(1, 2)  # [B, h, L, d_k]
    V_proj = V_proj.transpose(1, 2)  # [B, h, L, d_k]
    
    scores = torch.matmul(Q_proj, K_proj.transpose(-2, -1)) / math.sqrt(d_k)
    attn = softmax(scores, dim=-1)
    context = torch.matmul(attn, V_proj)
    
    # 4. 合并输出
    context = context.transpose(1, 2).reshape(batch_size, seq_len, d_model)
    output = linear(context)
    return output关键步骤说明:
- 将 
Q投影到num_heads个查询头,K/V投影到num_groups个键值头(num_groups < num_heads)。 - 对 
K/V进行扩展,使每个查询组共享相同的键值头(例如,组内k个查询头共享 1 个键值头)。 - 计算注意力时,组内的多个查询头复用同一组键值头,减少计算量。
 - 最终合并结果并通过线性层。
 
核心区别:
| 特性 | MHA | GQA | 
|---|---|---|
| 键值头数 | 与查询头数相同(num_heads) | 
少于查询头数(num_groups) | 
| 计算复杂度 | 较高(每个头独立计算) | 较低(组内共享键值头) | 
| 应用场景 | 标准 Transformer | 大模型推理优化(如 LLaMA-2) | 
通过分组共享键值头,GQA 在保持表达能力的同时显著提升了推理效率。