# 图解Transformer中的Mask机制:为什么BERT和GPT用的不一样?
如果你刚开始接触Transformer,可能会被各种Mask搞得晕头转向。Padding Mask、Sequence Mask、Causal Mask、Cross-Attention Mask……这些名词听起来相似,但在BERT、GPT这些明星模型中,它们扮演的角色和实现方式却截然不同。更让人困惑的是,为什么同样是基于Transformer架构,BERT和GPT对Mask的处理方式会有如此大的差异?这背后其实隐藏着模型设计哲学的根本不同。
理解Mask机制,是深入理解现代大语言模型工作原理的关键一步。它不仅仅是代码里几行简单的逻辑判断,而是决定了模型如何“看见”和“理解”输入信息,是模型实现不同任务能力(如理解上下文或生成文本)的核心设计。今天,我们就通过图解和代码,把这些Mask的来龙去脉、设计逻辑和应用场景彻底讲清楚。
## 1. Mask的本质:控制信息的可见性
在深入具体Mask类型之前,我们首先要明白Mask在Transformer中到底在做什么。简单来说,**Mask是一种在注意力计算中,用于控制哪些位置的信息可以被“看到”或“关注”的机制**。它通过在注意力分数(Attention Scores)上施加一个非常大的负值(如 `-1e9`),使得这些位置在经过Softmax函数后权重趋近于零。
```python
import torch
import torch.nn.functional as F
# 模拟注意力分数计算
def attention_with_mask(Q, K, mask=None):
"""
Q: [batch_size, num_heads, seq_len, d_k]
K: [batch_size, num_heads, seq_len, d_k]
mask: [batch_size, 1, seq_len, seq_len] 或 [batch_size, seq_len, seq_len]
"""
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5) # 缩放点积
if mask is not None:
# 将mask中为0的位置(需要屏蔽)的分数设置为一个极小的负数
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1) # 在最后一个维度做softmax
return attn_weights
# 示例:一个3x3的注意力分数矩阵
scores = torch.tensor([[1.0, 0.5, 0.2],
[0.3, 1.2, 0.8],
[0.1, 0.4, 0.9]])
# 创建一个下三角mask(用于防止看到未来信息)
causal_mask = torch.tril(torch.ones(3, 3)).bool()
print("Causal Mask (下三角):")
print(causal_mask)
# tensor([[ True, False, False],
# [ True, True, False],
# [ True, True, True]])
# 应用mask后的注意力权重
masked_weights = attention_with_mask(scores.unsqueeze(0).unsqueeze(0),
scores.unsqueeze(0).unsqueeze(0),
mask=~causal_mask.unsqueeze(0).unsqueeze(0))
print("\n应用Causal Mask后的注意力权重:")
print(masked_weights.squeeze())
```
> **提示**:在PyTorch的实现中,通常使用 `masked_fill(mask == 0, -1e9)` 这样的操作。这里的 `-1e9` 是一个经验值,只要足够小,经过softmax后对应的概率就会趋近于0,从而实现“屏蔽”效果。
从功能上看,Mask主要解决两个核心问题:
1. **处理变长序列**:实际训练中,一个批次(batch)内的序列长度往往不一致,我们需要用特殊符号(如`[PAD]`)填充到相同长度。但在计算注意力时,这些填充位置不应该参与计算。
2. **防止信息泄露**:在自回归生成任务中(如GPT),模型在预测当前位置时,不应该“看到”未来的信息,否则就相当于作弊。
理解了Mask的基本作用,我们再来看看Transformer家族中不同架构是如何使用Mask的。
## 2. 纯编码器架构:BERT的Padding Mask
BERT(Bidirectional Encoder Representations from Transformers)采用的是**纯编码器(Encoder-Only)**架构。它的设计目标是**双向理解**整个输入序列的上下文信息。在预训练阶段,BERT主要使用**掩码语言模型(Masked Language Model, MLM)**任务,即随机遮盖输入序列中的一些token,让模型根据上下文来预测这些被遮盖的token。
对于BERT这样的编码器模型,其核心的Mask需求是**Padding Mask**。
### 2.1 Padding Mask的设计逻辑
在批量训练时,为了将不同长度的句子处理成相同的长度,我们会在较短的句子末尾添加特殊的填充token(如`[PAD]`)。这些填充token本身没有语义信息,在计算注意力时不应该让其他token关注它们,也不应该让它们关注其他token。
Padding Mask的生成逻辑非常直观:**识别并屏蔽所有填充token所在的位置**。
```python
def get_padding_mask(seq, pad_idx=0):
"""
生成Padding Mask。
seq: 输入序列的token id张量,形状为 [batch_size, seq_len]
pad_idx: 填充token的id(通常为0)
返回: mask张量,形状为 [batch_size, 1, 1, seq_len](便于广播)
"""
# seq != pad_idx 会生成一个布尔张量,True表示是真实token,False表示是填充token
# unsqueeze是为了增加维度,方便后续与注意力分数矩阵广播
mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
# 我们需要的是在注意力分数矩阵中,让填充位置对应的行和列都被屏蔽
# 因此返回的mask形状是 [batch_size, 1, 1, seq_len]
# 在计算注意力时,这个mask会广播到 [batch_size, num_heads, seq_len, seq_len]
return mask
# 示例:一个batch包含两个句子,最大长度设为5
batch_tokens = torch.tensor([
[101, 2023, 2003, 102, 0], # 句子1: [CLS] hello world [SEP] [PAD]
[101, 1045, 102, 0, 0] # 句子2: [CLS] hi [SEP] [PAD] [PAD]
])
# 假设pad_idx=0
padding_mask = get_padding_mask(batch_tokens, pad_idx=0)
print("Padding Mask (形状: {}):".format(padding_mask.shape))
print(padding_mask.squeeze()) # 去掉多余的维度方便查看
# 输出:
# tensor([[[[ True, True, True, True, False]]],
# [[[ True, True, True, False, False]]]])
```
### 2.2 Padding Mask在注意力中的效果
当我们将这个Mask应用到注意力计算时,效果如下图所示:
```
假设序列长度为5,其中最后两个位置是[PAD]
原始注意力分数矩阵(5x5):
[ [s11, s12, s13, s14, s15],
[s21, s22, s23, s24, s25],
[s31, s32, s33, s34, s35],
[s41, s42, s43, s44, s45],
[s51, s52, s53, s54, s55] ]
应用Padding Mask后(屏蔽第4、5行和列):
[ [s11, s12, s13, -inf, -inf],
[s21, s22, s23, -inf, -inf],
[s31, s32, s33, -inf, -inf],
[-inf, -inf, -inf, -inf, -inf],
[-inf, -inf, -inf, -inf, -inf] ]
```
这样,所有token(包括其他真实token和填充token本身)都不会关注填充位置,填充位置也不会关注任何其他位置。**BERT只需要Padding Mask,因为它需要同时看到整个序列的所有真实token,进行双向编码**。
在实际的BERT实现中,这个Mask会应用到每一个Transformer Encoder层的自注意力计算中。值得注意的是,在MLM任务中,被随机选择遮盖的token(如`[MASK]`)**并不是通过Mask机制实现的**,而是直接替换了输入embedding,模型仍然可以看到这个位置有一个token(只是不知道它是什么),并需要预测它。
## 3. 纯解码器架构:GPT的Sequence Mask(Causal Mask)
GPT(Generative Pre-trained Transformer)系列模型采用的是**纯解码器(Decoder-Only)**架构。与BERT的双向理解不同,GPT的设计目标是**自回归生成**,即根据已生成的内容预测下一个token。因此,GPT的核心限制是:**在生成当前位置时,只能看到当前位置及之前的信息,不能看到未来的信息**。
这种限制是通过**Sequence Mask(也称为Causal Mask或Look-ahead Mask)**来实现的。
### 3.1 Causal Mask的生成与原理
Causal Mask是一个**下三角矩阵**,矩阵的主对角线及以下位置为1(允许关注),以上位置为0(禁止关注)。
```python
def get_causal_mask(seq):
"""
生成Causal Mask(下三角矩阵)。
seq: 输入序列,形状为 [batch_size, seq_len]
返回: mask张量,形状为 [1, 1, seq_len, seq_len]
"""
seq_len = seq.size(1)
# 创建一个seq_len x seq_len的下三角矩阵(包含对角线)
# torch.tril: 返回矩阵的下三角部分,其余为0
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
# 增加batch和head维度,便于广播
return causal_mask.unsqueeze(0).unsqueeze(0)
# 示例:序列长度为4
causal_mask = get_causal_mask(torch.zeros(1, 4))
print("Causal Mask (下三角矩阵,4x4):")
print(causal_mask.squeeze())
# 输出:
# tensor([[ True, False, False, False],
# [ True, True, False, False],
# [ True, True, True, False],
# [ True, True, True, True]])
```
### 3.2 Causal Mask的工作方式
为了更直观地理解,我们来看一个序列 `[A, B, C, D]` 在GPT中是如何被处理的:
| 当前预测位置 | 可访问的上下文 | 不可见的未来信息 |
|--------------|----------------|------------------|
| A (位置0) | [A] | B, C, D |
| B (位置1) | [A, B] | C, D |
| C (位置2) | [A, B, C] | D |
| D (位置3) | [A, B, C, D] | 无
这种Mask确保了模型在训练时的行为与推理时完全一致:在推理时,模型确实是逐个token生成的,没有未来的信息可用。
在实际的GPT实现中,还需要将Causal Mask与Padding Mask结合使用:
```python
def get_gpt_style_mask(seq, pad_idx=0):
"""
GPT风格的Mask:Padding Mask与Causal Mask的结合。
"""
batch_size, seq_len = seq.shape
# 1. 生成Padding Mask: [batch_size, 1, 1, seq_len]
pad_mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
# 2. 生成Causal Mask: [1, 1, seq_len, seq_len]
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(0).unsqueeze(0)
# 3. 结合两种Mask:只有同时满足Padding Mask和Causal Mask的位置才不被屏蔽
# 即:必须是真实token(非PAD)且不能是未来位置
combined_mask = pad_mask & causal_mask
return combined_mask
# 示例
tokens = torch.tensor([[1, 2, 3, 0, 0], # 实际长度3,填充2个
[1, 2, 0, 0, 0]]) # 实际长度2,填充3个
gpt_mask = get_gpt_style_mask(tokens, pad_idx=0)
print("GPT风格Mask的形状:", gpt_mask.shape) # [2, 1, 5, 5]
# 查看第一个句子的mask矩阵(5x5)
print("\n第一个句子的Mask矩阵:")
print(gpt_mask[0, 0])
# 输出(True表示允许关注):
# tensor([[ True, False, False, False, False], # 位置0只能看自己
# [ True, True, False, False, False], # 位置1可以看0,1
# [ True, True, True, False, False], # 位置2可以看0,1,2
# [False, False, False, False, False], # 位置3是PAD,不能看任何位置
# [False, False, False, False, False]]) # 位置4是PAD,不能看任何位置
```
### 3.3 为什么GPT不需要Encoder?
这是一个常见的疑问。原始的Transformer论文是为机器翻译设计的,包含Encoder和Decoder。但GPT作为纯Decoder模型,为什么能工作得这么好?
关键在于**任务目标的差异**:
- **翻译任务**:需要先理解源语言句子(Encoder),再生成目标语言句子(Decoder)。
- **语言建模任务**:只需要根据上文生成下文,是纯粹的自回归生成任务。
GPT把语言建模任务形式化为:给定前n个token,预测第n+1个token。这完全符合Decoder的自回归特性。在预训练时,GPT使用大量文本,通过Causal Mask让每个位置只能看到前面的token,然后预测下一个token。这种设计让GPT成为了强大的文本生成模型。
## 4. 编码器-解码器架构:T5/BART的混合Mask策略
编码器-解码器(Encoder-Decoder)架构,也称为Seq2Seq架构,是原始Transformer论文的完整形态。这类模型(如T5、BART、原始Transformer)同时包含编码器和解码器两部分,适用于需要先理解输入再生成输出的任务,如翻译、摘要、问答等。
在这种架构中,**Mask的使用最为复杂**,涉及三种不同的Mask:
### 4.1 编码器端的Padding Mask
与BERT类似,编码器接收源序列,需要处理变长输入,因此使用Padding Mask。编码器的自注意力是**双向的**,每个位置都可以看到序列中的所有其他位置(除了填充位置)。
```python
# 编码器端的Padding Mask生成(与BERT相同)
def get_encoder_mask(src_seq, src_pad_idx):
"""为编码器生成Padding Mask"""
return (src_seq != src_pad_idx).unsqueeze(1).unsqueeze(2)
```
### 4.2 解码器端的自注意力Mask
解码器在自注意力层需要防止信息泄露,因此需要**Causal Mask**。但同时,解码器的输入也可能包含填充token,所以还需要**Padding Mask**。两者需要结合:
```python
def get_decoder_self_mask(trg_seq, trg_pad_idx):
"""为解码器自注意力生成Mask:Padding Mask + Causal Mask"""
batch_size, seq_len = trg_seq.shape
# Padding Mask
pad_mask = (trg_seq != trg_pad_idx).unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, seq_len]
# Causal Mask (下三角)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]
# 结合:只有非填充且不是未来位置才允许关注
return pad_mask & causal_mask
```
### 4.3 解码器端的交叉注意力Mask
这是编码器-解码器架构特有的部分。在解码器的第二个注意力层(交叉注意力),Query来自解码器,Key和Value来自编码器的输出。此时,解码器的每个位置都可以看到**编码器输出的所有位置**(当然,要排除编码器端的填充位置)。
```python
def get_cross_attention_mask(encoder_mask, decoder_mask):
"""
为解码器的交叉注意力生成Mask。
这里只需要考虑编码器端的Padding Mask,
因为解码器端的每个位置都可以关注编码器输出的所有有效位置。
"""
# encoder_mask: [batch_size, 1, 1, src_len]
# 我们需要将其扩展为 [batch_size, 1, tgt_len, src_len]
# 这样解码器的每个位置(tgt_len)都知道编码器哪些位置是有效的
batch_size = encoder_mask.size(0)
tgt_len = decoder_mask.size(-1) # 解码器序列长度
# 将编码器mask从 [batch_size, 1, 1, src_len] 扩展为 [batch_size, 1, tgt_len, src_len]
# 这样每个解码器位置都使用相同的编码器mask
cross_mask = encoder_mask.expand(-1, -1, tgt_len, -1)
return cross_mask
```
### 4.4 完整示例:T5/BART的Mask流程
让我们通过一个具体的翻译例子来理解整个过程。假设我们要将英文“I love you”翻译成法文“Je t'aime”。
**编码器端**:
- 输入:`["I", "love", "you", "[PAD]", "[PAD]"]`(填充到长度5)
- Padding Mask:`[1, 1, 1, 0, 0]`(1表示真实token,0表示填充)
- 编码器自注意力:每个token可以看到所有真实token(I, love, you)
**解码器端(训练时)**:
- 输入:`["[BOS]", "Je", "t'", "aime", "[PAD]"]`("[BOS]"是开始符号)
- 自注意力Mask:Causal + Padding
- 位置0([BOS]):只能看到自己
- 位置1("Je"):可以看到[BOS], Je
- 位置2("t'"):可以看到[BOS], Je, t'
- 位置3("aime"):可以看到[BOS], Je, t', aime
- 位置4([PAD]):不能看任何位置
- 交叉注意力:解码器的每个位置都可以看到编码器所有真实token(I, love, you)
这种设计确保了:
1. 解码器在生成每个法文单词时,只能看到已经生成的法文单词(自注意力)
2. 但可以看到整个英文句子的信息(交叉注意力)
3. 填充位置被正确屏蔽
## 5. 可视化对比:三种架构的Mask差异
为了更直观地理解,我们用一个3x3的注意力矩阵来可视化不同架构下的Mask模式。假设序列长度为3,其中最后一个位置是填充(PAD)。
### 5.1 BERT(纯编码器)的Mask模式
```
注意力矩阵(3x3):
行表示"关注者",列表示"被关注者"
真实token1 → [ 1, 1, 0 ] # 可以关注token1、token2,不能关注PAD
真实token2 → [ 1, 1, 0 ] # 可以关注token1、token2,不能关注PAD
填充token3 → [ 0, 0, 0 ] # 不能关注任何位置(包括自己)
```
**特点**:对称的Mask,填充位置完全被隔离。
### 5.2 GPT(纯解码器)的Mask模式
```
注意力矩阵(3x3):
假设所有位置都是真实token(无填充)
token1 → [ 1, 0, 0 ] # 只能关注自己(位置1)
token2 → [ 1, 1, 0 ] # 可以关注位置1和2
token3 → [ 1, 1, 1 ] # 可以关注所有位置(位置1、2、3)
```
如果考虑填充(假设位置3是PAD):
```
token1 → [ 1, 0, 0 ] # 只能关注自己
token2 → [ 1, 1, 0 ] # 可以关注位置1和2
PAD3 → [ 0, 0, 0 ] # 不能关注任何位置
```
**特点**:下三角模式,确保只能看到过去的信息。
### 5.3 T5/BART(编码器-解码器)的Mask模式
这里需要分开看两个注意力矩阵:
**解码器自注意力**(与GPT相同,下三角):
```
token1 → [ 1, 0, 0 ]
token2 → [ 1, 1, 0 ]
token3 → [ 1, 1, 1 ]
```
**解码器交叉注意力**(关注编码器输出):
假设编码器有2个真实token + 1个PAD
```
解码器token1 → [ 1, 1, 0 ] # 可以关注编码器的所有真实token
解码器token2 → [ 1, 1, 0 ] # 可以关注编码器的所有真实token
解码器token3 → [ 1, 1, 0 ] # 可以关注编码器的所有真实token
```
**特点**:自注意力是下三角的,交叉注意力是全连接的(仅排除编码器填充)。
### 5.4 三种架构的对比表格
| 特性 | BERT (Encoder-Only) | GPT (Decoder-Only) | T5/BART (Encoder-Decoder) |
|------|---------------------|-------------------|---------------------------|
| **主要任务** | 理解、分类、标注 | 文本生成 | 序列到序列(翻译、摘要等) |
| **注意力方向** | 双向 | 单向(仅左侧) | 解码器自注意力:单向;交叉注意力:双向 |
| **核心Mask** | Padding Mask | Padding Mask + Causal Mask | 编码器:Padding Mask;解码器:Padding+Causal Mask;交叉注意力:编码器Padding Mask |
| **信息流** | 全连接(除填充) | 仅左侧连接 | 自注意力:仅左侧;交叉注意力:全连接 |
| **典型应用** | 文本分类、NER、问答 | 对话、续写、代码生成 | 翻译、摘要、文本改写 |
## 6. 实际代码中的Mask实现细节
理解了理论后,我们来看看在实际的Transformer实现中,这些Mask是如何被创建和应用的。以下是一个简化的Transformer实现,展示了三种Mask的具体使用:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
"""简化版多头注意力,展示Mask的应用"""
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 线性变换层
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 线性变换并分割成多头
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
# 应用Mask(如果提供)
if mask is not None:
# mask形状应为 [batch_size, 1, 1, seq_len] 或 [batch_size, 1, seq_len, seq_len]
# 需要扩展到多头维度
mask = mask.unsqueeze(1) # 增加head维度
scores = scores.masked_fill(mask == 0, -1e9)
# Softmax得到注意力权重
attn_weights = F.softmax(scores, dim=-1)
# 应用注意力权重到Value
context = torch.matmul(attn_weights, V)
# 合并多头输出
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 最终线性变换
output = self.W_o(context)
return output, attn_weights
class TransformerEncoderLayer(nn.Module):
"""Transformer编码器层"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 自注意力(带残差连接和层归一化)
attn_output, _ = self.self_attn(x, x, x, mask)
x = x + self.dropout(attn_output)
x = self.norm1(x)
# 前馈网络(带残差连接和层归一化)
ff_output = self.feed_forward(x)
x = x + self.dropout(ff_output)
x = self.norm2(x)
return x
class TransformerDecoderLayer(nn.Module):
"""Transformer解码器层"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
# 自注意力(需要causal mask)
self.self_attn = MultiHeadAttention(d_model, num_heads)
# 交叉注意力(关注编码器输出)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, encoder_output, self_mask=None, cross_mask=None):
# 自注意力(带causal mask)
self_attn_output, _ = self.self_attn(x, x, x, self_mask)
x = x + self.dropout(self_attn_output)
x = self.norm1(x)
# 交叉注意力(关注编码器输出)
cross_attn_output, _ = self.cross_attn(x, encoder_output, encoder_output, cross_mask)
x = x + self.dropout(cross_attn_output)
x = self.norm2(x)
# 前馈网络
ff_output = self.feed_forward(x)
x = x + self.dropout(ff_output)
x = self.norm3(x)
return x
def create_masks(src, trg, src_pad_idx, trg_pad_idx):
"""
为编码器-解码器模型创建所有必要的mask
src: 源序列 [batch_size, src_len]
trg: 目标序列 [batch_size, trg_len]
返回: src_mask, trg_mask, memory_mask
"""
# 编码器mask:仅Padding Mask
src_mask = (src != src_pad_idx).unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, src_len]
# 解码器自注意力mask:Padding Mask + Causal Mask
trg_pad_mask = (trg != trg_pad_idx).unsqueeze(1).unsqueeze(2) # Padding部分
trg_len = trg.size(1)
causal_mask = torch.tril(torch.ones(trg_len, trg_len)).bool().unsqueeze(0).unsqueeze(0) # Causal部分
trg_mask = trg_pad_mask & causal_mask # 结合
# 交叉注意力mask:使用编码器的Padding Mask
# 扩展为 [batch_size, 1, trg_len, src_len]
memory_mask = src_mask.expand(-1, -1, trg_len, -1)
return src_mask, trg_mask, memory_mask
# 使用示例
batch_size = 2
src_len = 5
trg_len = 6
d_model = 512
num_heads = 8
# 模拟输入
src = torch.randint(1, 100, (batch_size, src_len))
trg = torch.randint(1, 100, (batch_size, trg_len))
# 创建mask
src_mask, trg_mask, memory_mask = create_masks(src, trg, src_pad_idx=0, trg_pad_idx=0)
print("源序列形状:", src.shape)
print("目标序列形状:", trg.shape)
print("\n编码器Mask形状:", src_mask.shape)
print("解码器自注意力Mask形状:", trg_mask.shape)
print("交叉注意力Mask形状:", memory_mask.shape)
# 验证Causal Mask的下三角特性
print("\n解码器Mask的第一个样本(展示Causal特性):")
print(trg_mask[0, 0, :, :].int())
```
这段代码展示了完整的Mask创建流程。在实际的Transformer实现中,这些Mask会被传递到对应的注意力层中,控制信息流动。
## 7. 现代大模型中的Mask变体与优化
随着模型规模的增长,Mask机制也在不断演进。以下是一些重要的变体和优化:
### 7.1 滑动窗口注意力(Sliding Window Attention)
为了处理超长序列,一些模型采用了滑动窗口注意力,每个token只能关注其前后固定窗口内的token。这可以看作是一种稀疏的Causal Mask。
```python
def create_sliding_window_mask(seq_len, window_size):
"""
创建滑动窗口Mask
seq_len: 序列长度
window_size: 窗口大小(每侧)
返回: [1, 1, seq_len, seq_len]的布尔张量
"""
mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
for i in range(seq_len):
start = max(0, i - window_size)
end = min(seq_len, i + window_size + 1)
mask[i, start:end] = True
return mask.unsqueeze(0).unsqueeze(0)
# 示例:序列长度8,窗口大小2
window_mask = create_sliding_window_mask(8, 2)
print("滑动窗口Mask (窗口大小=2):")
print(window_mask.squeeze().int())
```
### 7.2 前缀注意力(Prefix-LM)的Mask
一些模型(如UniLM)使用前缀注意力,其中部分序列是双向的(前缀),部分序列是单向的(生成部分)。这种Mask是Causal Mask的变体:
```
假设序列前3个token是前缀(双向),后2个是生成部分(单向)
Mask矩阵:
[1, 1, 1, 0, 0] # 前缀位置0:可以看到所有前缀
[1, 1, 1, 0, 0] # 前缀位置1:可以看到所有前缀
[1, 1, 1, 0, 0] # 前缀位置2:可以看到所有前缀
[1, 1, 1, 1, 0] # 生成位置3:可以看到所有前缀+已生成部分
[1, 1, 1, 1, 1] # 生成位置4:可以看到所有token
```
### 7.3 内存高效的Mask实现
对于极长序列,存储完整的注意力矩阵(seq_len × seq_len)可能内存不足。现代实现使用了一些优化技巧:
```python
def causal_attention_forward(Q, K, V):
"""
内存高效的Causal Attention实现
使用迭代方式,避免存储完整的注意力矩阵
"""
batch_size, num_heads, seq_len, d_k = Q.shape
output = torch.zeros_like(V)
# 逐位置计算
for i in range(seq_len):
# 只计算当前位置可以访问的部分
Q_i = Q[:, :, i:i+1, :] # [batch, heads, 1, d_k]
K_i = K[:, :, :i+1, :] # [batch, heads, i+1, d_k]
V_i = V[:, :, :i+1, :] # [batch, heads, i+1, d_v]
# 计算注意力分数(只关注前i+1个位置)
scores = torch.matmul(Q_i, K_i.transpose(-2, -1)) / (d_k ** 0.5)
attn_weights = F.softmax(scores, dim=-1)
# 应用注意力权重
output[:, :, i:i+1, :] = torch.matmul(attn_weights, V_i)
return output
```
### 7.4 FlashAttention优化
FlashAttention是一种IO感知的注意力算法,通过分块计算来减少GPU内存访问,显著加速注意力计算并减少内存使用。它特别适合处理长序列,但实现较为复杂,这里只展示其基本思想:
```python
# FlashAttention的伪代码示意
def flash_attention(Q, K, V, block_size=256):
"""
简化的FlashAttention思想展示
实际实现要复杂得多,涉及精细的GPU内存管理
"""
batch_size, num_heads, seq_len, d_k = Q.shape
output = torch.zeros_like(V)
# 将Q、K、V分块
num_blocks = (seq_len + block_size - 1) // block_size
for block_i in range(num_blocks):
# 处理一个块
start_i = block_i * block_size
end_i = min((block_i + 1) * block_size, seq_len)
# 加载当前块到快速内存(SRAM)
Q_block = Q[:, :, start_i:end_i, :]
# 与所有K块计算注意力
for block_j in range(num_blocks):
start_j = block_j * block_size
end_j = min((block_j + 1) * block_size, seq_len)
K_block = K[:, :, start_j:end_j, :]
V_block = V[:, :, start_j:end_j, :]
# 计算块间注意力
# ... 实际实现涉及复杂的重计算和内存管理
pass
return output
```
## 8. 选择正确的Mask:实践指南
在实际项目中,如何选择和使用正确的Mask?这里有一些实用建议:
### 8.1 根据任务选择架构
| 任务类型 | 推荐架构 | Mask类型 | 说明 |
|---------|---------|---------|------|
| **文本分类** | Encoder-Only (BERT) | Padding Mask | 需要理解整个句子的双向上下文 |
| **命名实体识别** | Encoder-Only (BERT) | Padding Mask | 每个token的分类需要全局上下文 |
| **文本生成** | Decoder-Only (GPT) | Padding + Causal Mask | 自回归生成,不能看到未来信息 |
| **机器翻译** | Encoder-Decoder (T5) | 三种Mask组合 | 需要理解源语言,生成目标语言 |
| **文本摘要** | Encoder-Decoder (BART) | 三种Mask组合 | 理解原文,生成摘要 |
| **问答系统** | Encoder-Only 或 Encoder-Decoder | 取决于具体设计 | 提取式问答用Encoder,生成式用Encoder-Decoder |
### 8.2 常见错误与调试技巧
**错误1:Mask形状不匹配**
```python
# 错误:Mask形状与注意力分数不匹配
scores = torch.matmul(Q, K.transpose(-2, -1)) # [batch, heads, seq_len, seq_len]
mask = (seq != pad_idx).unsqueeze(1) # [batch, 1, seq_len] - 错误!
scores = scores.masked_fill(mask == 0, -1e9) # 形状不匹配!
# 正确:需要unsqueeze两次
mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len]
```
**错误2:忘记结合Padding和Causal Mask**
```python
# 错误:只用了Causal Mask,忽略了Padding
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
# 如果序列有填充,模型会关注填充位置!
# 正确:结合两种Mask
pad_mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
combined_mask = pad_mask & causal_mask.unsqueeze(0).unsqueeze(0)
```
**调试技巧:可视化Mask矩阵**
```python
def visualize_mask(mask, title="Mask Matrix"):
"""可视化Mask矩阵"""
import matplotlib.pyplot as plt
mask_np = mask.squeeze().cpu().numpy()
plt.figure(figsize=(8, 6))
plt.imshow(mask_np, cmap='Blues', interpolation='nearest')
plt.colorbar()
plt.title(title)
plt.xlabel("Key Positions")
plt.ylabel("Query Positions")
# 添加文本标注
for i in range(mask_np.shape[0]):
for j in range(mask_np.shape[1]):
plt.text(j, i, '1' if mask_np[i, j] else '0',
ha='center', va='center', color='black' if mask_np[i, j] else 'white')
plt.tight_layout()
plt.show()
# 使用示例
seq = torch.tensor([[1, 2, 3, 0, 0]])
mask = get_gpt_style_mask(seq, pad_idx=0)
visualize_mask(mask[0], "GPT-style Mask")
```
### 8.3 性能优化建议
1. **预计算Mask**:对于固定长度的序列,可以预计算Mask并缓存,避免每次前向传播都重新计算。
2. **使用布尔张量**:PyTorch中布尔类型的Mask比浮点型更节省内存。
3. **注意广播机制**:合理设计Mask形状,利用广播机制减少内存使用。
4. **考虑稀疏注意力**:对于极长序列,考虑使用稀疏注意力模式(如滑动窗口、局部注意力)。
Mask机制是Transformer架构中看似简单实则精妙的设计。不同的Mask策略直接对应着不同的模型能力和应用场景。BERT的Padding Mask让它能够双向理解上下文,GPT的Causal Mask让它能够自回归生成文本,而Encoder-Decoder架构的混合Mask让它能够处理复杂的序列到序列任务。
理解这些差异不仅有助于我们正确使用现有的Transformer模型,也为设计新的模型架构提供了基础。在实际工作中,我经常发现Mask相关的bug是最难调试的之一,因为错误可能不会导致程序崩溃,而是表现为模型性能的微妙下降。掌握Mask的工作原理和实现细节,是成为Transformer专家的必经之路。