# 图解Transformer并行计算:从矩阵乘法切分到Megatron架构设计
在构建千亿乃至万亿参数规模的大语言模型时,我们面临的核心矛盾是:单个GPU的显存容量与计算能力,与模型巨大的参数量及计算需求之间存在难以逾越的鸿沟。传统的单卡训练模式早已失效,分布式并行训练成为唯一可行的路径。然而,并行并非简单地将模型或数据“分而治之”,其背后是一套精密的数学切分策略与通信模式设计。今天,我们就来深入拆解NVIDIA Megatron-LM框架中张量并行的核心思想,用可视化的方式,理解如何将庞大的权重矩阵“庖丁解牛”,并分析不同通信原语(如AllReduce与All-Gather)的选择如何深刻影响训练效率。无论你是正在搭建自己的分布式训练框架,还是希望优化现有的大模型训练流水线,理解这些底层原理都至关重要。
## 1. 张量并行的基石:矩阵乘法的两种切分视角
要理解复杂的Transformer模型如何被拆分,我们必须回到最基础的运算单元:矩阵乘法。假设我们有一个前向传播操作 `Y = X @ W`,其中输入 `X` 的形状为 `[batch_size, sequence_length, hidden_size]`(简写为 `[b, s, h]`),权重 `W` 的形状为 `[h, h']`。当 `W` 过于庞大,无法存入单张GPU显存时,我们就需要对其进行切分。
### 1.1 按行切分(Row-wise Partitioning)
按行切分,即沿着权重矩阵 `W` 的行维度(对应输入特征维度 `h`)进行分割。假设我们有两张GPU(GPU0和GPU1),那么 `W` 被切分为 `W0` 和 `W1`,形状分别为 `[h/2, h']`。
**关键问题**:`X` 的形状是 `[b, s, h]`,其列维度 `h` 与 `W` 的行维度 `h` 需要对齐才能做矩阵乘法。现在 `W` 被按行切分了,`X` 该如何处理?
**解决方案**:将输入 `X` 按列切分(Column-wise Partitioning)。具体来说,将 `X` 也沿着最后一个维度 `h` 切成两半,得到 `X0` 和 `X1`,形状为 `[b, s, h/2]`。这样,每张GPU上的计算就变成了:
- GPU0: `Y0_part = X0 @ W0`
- GPU1: `Y1_part = X1 @ W1`
注意,此时 `Y0_part` 和 `Y1_part` 都是完整输出 `Y` 的一部分贡献。为了得到最终的 `Y`,我们需要进行一次跨GPU的求和操作。这个过程在通信上对应一次 **AllReduce** 操作(具体是Reduce-Scatter后接All-Gather,但整体效果是求和并同步结果)。
> **提示**:这里的“按行切分权重”导致了“按列切分输入”,并且需要一次**求和归并**通信。这是理解后续更复杂切分策略的基础模式。
### 1.2 按列切分(Column-wise Partitioning)
按列切分,即沿着权重矩阵 `W` 的列维度(输出特征维度 `h'`)进行分割。同样在两张GPU上,`W` 被切分为 `W0` 和 `W1`,形状分别为 `[h, h'/2]`。
此时,输入 `X` 的维度 `[b, s, h]` 与 `W0`、`W1` 的行维度 `h` 完全匹配,因此**无需对 `X` 进行切分**。每张GPU可以持有完整的 `X` 副本进行计算:
- GPU0: `Y0 = X @ W0`
- GPU1: `Y1 = X @ W1`
计算得到的 `Y0` 和 `Y1` 的形状是 `[b, s, h'/2]`。它们不再是部分和,而是最终输出 `Y` 在列维度上的两个连续切片。因此,要得到完整的 `Y`,我们需要将 `Y0` 和 `Y1` 在列维度上拼接(Concatenate)起来。这个操作在通信上对应一次 **All-Gather** 操作。
**两种切分方式的对比与通信开销**
| 切分方式 | 权重 `W` 切分维度 | 输入 `X` 处理方式 | 输出 `Y` 聚合方式 | 前向传播通信原语 |
| :--- | :--- | :--- | :--- | :--- |
| **按行切分** | 行 (h) | 按列切分 | 求和 (Sum) | **AllReduce** (求和) |
| **按列切分** | 列 (h') | 广播 (或保留完整) | 拼接 (Concat) | **All-Gather** |
从通信量角度分析,假设每个元素是4字节(float32),那么:
- **AllReduce** 的通信量约为 `2 * b * s * h'` 字节(两次通信阶段,每次传输 `b*s*h'` 个元素)。
- **All-Gather** 的通信量约为 `b * s * h'` 字节(每张GPU发送自己拥有的 `b*s*(h'/2)` 数据给所有其他GPU)。
在 `h'` 较大的情况下,All-Gather的通信量仅为AllReduce的一半。这为我们在设计并行策略时提供了重要的优化方向:**优先采用产生All-Gather通信的切分方式**。
## 2. Transformer MLP层的张量并行策略
Transformer中的多层感知机(MLP或FFN)通常由两个线性层和一个非线性激活函数(如GeLU)组成:`Z = Linear_B( GeLU( Linear_A(X) ) )`。其中,`Linear_A` 将维度从 `h` 投影到 `4h`(中间扩展层),`Linear_B` 再投影回 `h`。
如果我们简单粗暴地将整个MLP层视为一个黑盒进行并行,可能会引入不必要的通信。Megatron-LM的精妙之处在于,它**对两个线性层采用了不同的切分策略**。
### 2.1 组合切分策略
MLP层的标准计算流程如下:
1. `Y = X @ A` (A 形状: `[h, 4h]`)
2. `Y' = GeLU(Y)`
3. `Z = Y' @ B` (B 形状: `[4h, h]`)
Megatron-LM的策略是:
- **对第一个线性层 `A` 采用列切分**。即 `A = [A1, A2]`,每个部分形状为 `[h, 2h]`(假设2卡并行)。输入 `X` 被广播到所有GPU。
- GPU0: `Y1 = X @ A1`
- GPU1: `Y2 = X @ A2`
- 此时 `Y1` 和 `Y2` 是 `Y` 的列切片,形状为 `[b, s, 2h]`。
- **GeLU激活函数在各GPU上独立计算**。因为 `Y1` 和 `Y2` 已经是独立切片,GeLU(`Y1`) 和 GeLU(`Y2`) 可以并行计算,无需通信。
- **对第二个线性层 `B` 采用行切分**。由于 `B` 需要与 GeLU 的输出相乘,且 GeLU 的输出是按列切分的,为了匹配维度,`B` 必须按行切分:`B = [B1; B2]`,每个部分形状为 `[2h, h]`。
- GPU0: `Z1 = GeLU(Y1) @ B1`
- GPU1: `Z2 = GeLU(Y2) @ B2`
- **最终输出 `Z` 通过求和得到**:`Z = Z1 + Z2`。这需要一次 **AllReduce** 通信。
**为什么这样设计?**
核心在于**避免在非线性激活函数前后进行昂贵的通信**。如果我们对 `A` 采用行切分,那么计算 `Y1_part` 和 `Y2_part` 后,需要先进行一次AllReduce求和得到完整的 `Y`,然后才能进行GeLU计算。这引入了一次额外的、在激活函数之前的AllReduce通信。而采用列切分,GeLU可以立即在局部切片上计算,将通信推迟到了第二个线性层之后,且通信量不变。
### 2.2 MLP层的通信模式总结
让我们用流程图来清晰展示两卡并行下MLP层的前向(FWD)与反向传播(BWD)过程:
```python
# 前向传播 (FWD) 伪代码示意
def MLP_forward_TP(X):
# X 被广播到所有GPU
# 列切分 A
Y_local = X @ A_local # 本地计算,无通信
Y_act_local = GeLU(Y_local) # 本地激活,无通信
# 行切分 B
Z_local = Y_act_local @ B_local # 本地计算,无通信
# 聚合输出
Z = all_reduce_sum(Z_local) # 通信点:AllReduce
return Z
```
反向传播是前向传播的镜像。梯度 `∂L/∂Z` 被广播到各GPU,各卡独立计算对本地参数 `B_local` 和 `A_local` 的梯度。在计算需要传递给前一层的梯度 `∂L/∂X` 时,各GPU先计算出本地的 `∂L/∂X_local`,然后通过一次AllReduce求和得到完整的 `∂L/∂X`。
因此,**一个MLP层在前向和反向传播中,各需要一次AllReduce通信**。
## 3. Self-Attention层的并行化:多头注意力的天然优势
Self-Attention层比MLP层更复杂,但其**多头(Multi-Head)机制**恰好为张量并行提供了极其优雅的切分点。
### 3.1 标准多头注意力计算回顾
对于一个注意力头,计算如下:
`Attention(Q, K, V) = softmax( (Q @ K^T) / sqrt(d_k) ) @ V`
其中 `Q = X @ W_q`, `K = X @ W_k`, `V = X @ W_v`。
在多头注意力中,`W_q`、`W_k`、`W_v` 被沿着输出维度(头维度)切分成 `num_heads` 份。每个头独立计算注意力,最后将所有头的输出在特征维度拼接,再经过一个输出投影层 `W_o`。
### 3.2 张量并行下的注意力层切分
Megatron-LM利用了多头并行的天然特性:
1. **QKV投影的并行**:将 `W_q`、`W_k`、`W_v` 三个权重矩阵**按列切分**。每个GPU负责一部分注意力头对应的权重切片。输入 `X` 被广播到所有GPU。
- 例如,2卡并行,16个头:GPU0负责头0-7的 `W_q0`、`W_k0`、`W_v0`;GPU1负责头8-15的 `W_q1`、`W_k1`、`W_v1`。
- 每张GPU独立计算自己负责的那些头的 `Q`、`K`、`V` 和注意力输出。
2. **注意力输出拼接**:各GPU计算出的多头注意力输出是完整输出在头维度上的切片。这些切片需要拼接起来。
3. **输出投影层的并行**:拼接后的张量需要经过输出投影层 `W_o`。为了匹配输入(按头切分后的拼接结果),`W_o` 必须**按行切分**。每个GPU用本地的 `W_o_local` 对拼接后的完整中间结果进行计算(这里需要一次All-Gather来获得完整的中间结果吗?不,这里有更优方案)。
实际上,为了减少通信,Megatron-LM采用了一种融合策略:**它并不先拼接所有头的输出,而是让每个头的输出直接与 `W_o` 对应的行切片相乘,然后再对结果进行求和归并**。
具体流程如下(以2卡为例):
- **前向传播**:
1. 各GPU计算本地头的 `Q`, `K`, `V` 和注意力输出 `Attention_output_local`。
2. 计算 `Z_local = Attention_output_local @ W_o_local`(`W_o` 被行切分)。
3. 对所有的 `Z_local` 进行 **AllReduce 求和**,得到最终的输出 `Z`。
- **反向传播**:
1. 梯度 `∂L/∂Z` 被广播到各GPU。
2. 各GPU独立计算对本地 `W_o_local` 和本地注意力头参数 `W_qkv_local` 的梯度。
3. 计算需要传给前一层的梯度 `∂L/∂X` 时,各GPU先算本地梯度,然后通过一次 **AllReduce 求和** 得到完整的 `∂L/∂X`。
**通信分析**:可以看到,Self-Attention层的通信模式与MLP层惊人地一致:**前向一次AllReduce,反向一次AllReduce**。通信量同样为 `Φ = b * s * h`。
> **注意**:这种设计巧妙地将原本需要的“All-Gather(拼接注意力头输出)+ 矩阵乘”转换为了“本地矩阵乘 + AllReduce(求和)”。由于AllReduce的通信量在优化后与All-Gather同量级,但计算流程更融合,往往能获得更好的性能。
## 4. 词嵌入层与输出层的并行化挑战
Transformer的输入输出端涉及巨大的词表(Vocabulary),其嵌入矩阵 `E` 的维度为 `[vocab_size, hidden_size]`,其中 `vocab_size` 可达数万甚至数十万。将这个矩阵放在单卡上会消耗大量显存。
### 4.1 输入词嵌入(Input Embedding)的并行
策略是**沿词表维度(行)切分嵌入矩阵**。每张GPU只存储一部分词对应的嵌入向量。
- **前向传播(查找)**:给定一个输入token ID序列,每张GPU根据自己维护的词表切片,查找对应的嵌入向量。对于不属于自己词表范围的token,则返回零向量。
- **通信**:各GPU查找完成后,得到一个局部嵌入结果。这些结果是**互斥**的(每个token的嵌入向量只存在于一张GPU上)。为了获得完整的嵌入张量,需要执行一次 **All-Gather** 操作,将所有局部结果拼接起来。但更高效的做法是执行一次 **AllReduce(求和)**,因为零向量不影响求和结果。这样,每个token的嵌入向量会从持有它的GPU广播到所有GPU。
### 4.2 输出层(Output Embedding & Loss)的并行
输出层通常与输入层共享权重矩阵。在计算logits(即隐藏层与输出嵌入矩阵的点积)时,面临巨大挑战:`[b*s, h] @ [h, vocab_size]^T`,其中 `vocab_size` 极大。
- **朴素方法**:各GPU用本地的输出嵌入矩阵切片计算局部logits,然后通过一次 **All-Gather** 收集所有logits以进行全局的softmax和交叉熵计算。通信量为 `b * s * vocab_size`,开销巨大。
- **Megatron的优化方法**:
1. 各GPU计算本地logits `logits_local`。
2. 计算本地logits的指数和 `sum_exp_local = sum(exp(logits_local), dim=-1)`。
3. 对所有GPU的 `sum_exp_local` 进行 **AllReduce(求和)**,得到全局的指数分母 `sum_exp_global`。通信量仅为 `b * s`。
4. 在每张GPU上,计算本地词表范围内的概率:`p_local = exp(logits_local) / sum_exp_global`。
5. 计算本地词表范围内的交叉熵损失 `loss_local`。
6. 对所有GPU的 `loss_local` 进行 **AllReduce(求和)**,得到全局总损失。通信量仅为 `GPU数量`。
这种优化将通信量从 `O(vocab_size)` 降低到了 `O(1)`,对于大词表训练至关重要。
## 5. 混合并行实战:TP与DP的结合
在实际的超大规模训练中,纯张量并行(TP)受限于单台机器内的GPU数量(通常为8或16)。为了扩展到成千上万张GPU,必须结合数据并行(DP)。
### 5.1 典型的混合并行架构
业界常见的模式是:**机器内采用张量并行(TP),机器间采用数据并行(DP)**。
- **TP组**:一台机器内的所有GPU构成一个TP组,共同承载一个完整的模型副本。模型参数在组内被切分。
- **DP组**:不同机器上具有相同TP切分位置的GPU,构成一个DP组。每个DP组处理不同的数据批次,并定期同步梯度。
例如,有32台机器,每台8卡。可以设置每台机器为一个TP组(TP=8)。那么总共就有32个模型副本。所有机器上的第0号GPU构成一个DP组,第1号GPU构成另一个DP组,以此类推,共8个DP组,每个DP组包含32个GPU。
### 5.2 通信开销分析与设计抉择
为什么选择“机内TP,机间DP”?
1. **TP通信密集**:TP在每一层的前向和反向传播中都需要AllReduce通信(通信量正比于 `b*s*h`)。这种频繁的、层间的同步通信对延迟和带宽非常敏感。机器内GPU间通常通过NVLink或PCIe互联,带宽远高于机器间的网络(如InfiniBand或以太网)。将TP限制在机内,可以最大化利用高速互联,减少通信延迟。
2. **DP通信宽松**:DP的通信主要发生在每个训练步(step)结束时,对梯度进行AllReduce同步(通信量正比于参数量)。这是一次性的、步调一致的通信。虽然总数据量可能很大,但对延迟的敏感性低于TP的层间通信。此外,可以使用梯度压缩、异步更新等技术进一步优化跨机DP通信。
3. **计算与通信重叠**:在反向传播中,TP需要立即通信以计算传递给前一层的梯度,否则计算无法继续。而DP的梯度同步可以在计算图完成后进行,更容易与后续计算重叠。
下表对比了TP和DP的关键通信特性:
| 特性 | 张量并行 (TP) | 数据并行 (DP) |
| :--- | :--- | :--- |
| **通信发生点** | 每层的前向和反向传播中 | 每个训练步的末尾(梯度同步) |
| **通信模式** | AllReduce (求和) | AllReduce (梯度平均) |
| **通信量/层** | ~ `b * s * h` | ~ `参数数量` (每步一次) |
| **对延迟敏感性** | **高**(阻塞计算流) | 相对较低(可异步、可压缩) |
| **典型部署位置** | **机器内部**(高速互联) | 机器之间(网络互联) |
### 5.3 效率权衡与配置选择
在实际配置时,需要在TP并行度、DP并行度、批处理大小(batch size)和模型大小之间进行权衡:
- 增加TP并行度可以减少每张GPU的显存占用,但会增加机内通信开销,可能降低单卡效率。
- 增加DP并行度可以处理更大的全局批大小,加速训练,但会增加跨机通信开销和梯度同步时间。
- 全局批大小过大可能影响模型收敛性,需要配合学习率调整策略。
实践中,通常基于以下步骤进行配置:
1. **确定单卡内存上限**:根据模型参数量、激活值、优化器状态,估算承载模型所需的最小GPU数量(TP维度)。
2. **选择机内TP维度**:在满足步骤1的前提下,选择机器内GPU数量作为TP维度(如8),以充分利用高速互联。
3. **扩展DP维度**:使用更多的机器来增加DP组大小,直到达到目标的总GPU数量或全局批大小。
理解从最基础的矩阵乘法切分,到复杂的Transformer层并行策略,再到跨节点的混合并行设计,是构建和优化大规模AI训练系统的关键。Megatron-LM的设计向我们展示,高效的并行化并非简单的任务分发,而是对计算图、数据流和通信模式的深度协同设计。每一次通信操作的选择(AllReduce vs All-Gather),每一个切分点的确定,都直接影响到最终训练的速度和扩展效率。掌握这些原理,才能在实际工作中灵活应对不同的模型架构和硬件环境,设计出最适合的并行训练方案。