## 1. 打破常规:为什么图像处理需要Transformer?
大家好,我是老张,在AI和硬件领域摸爬滚打了十几年。今天想和大家聊聊一个彻底改变了计算机视觉格局的模型——Vision Transformer,也就是我们常说的ViT。如果你对深度学习有点了解,肯定知道卷积神经网络(CNN)在过去十年里几乎是图像处理的“标配”,从人脸识别到自动驾驶,无处不在。但2020年,谷歌的一篇论文像一颗重磅炸弹,提出了一个大胆的想法:**能不能完全不用卷积,只用Transformer来处理图像?**
我刚开始接触这个想法时,第一反应也是“这能行吗?”。毕竟Transformer是为自然语言处理(NLP)设计的,它处理的是像句子那样一维的序列数据。而图像是二维的,有空间结构,像素之间还有复杂的局部关系。CNN的卷积核天生就是为了捕捉这种局部特征(比如边缘、纹理)而生的,它的“局部感受野”设计非常符合我们对图像处理的直觉。
但Transformer的核心——**自注意力机制**——有一个CNN难以比拟的绝活:**全局建模能力**。一个卷积核每次只能看到图像的一小块区域(比如3x3),要想知道图像左上角和右下角的关系,信息需要经过很多层卷积才能传递过去。而自注意力机制在理论上,从第一层开始,任何一个“元素”就能直接和图像上所有其他“元素”进行交互和计算关联性。
这就好比你要理解一整篇文章的意思。CNN的方式是,先让你认单词(局部特征),再组合成短语,再看句子,最后理解段落和全文。而Transformer的方式是,直接把所有单词摆在你面前,让你同时去分析每个单词和文章中所有其他单词的关系,立刻把握全文的核心思想和长距离的指代关系。后者在理解整体上下文和复杂依赖关系上,潜力巨大。
ViT的诞生,就是要把Transformer这种强大的全局理解能力,直接“嫁接”到图像上。它要回答的问题是:如果我们放弃CNN那些为图像量身定做的“先验知识”(比如平移不变性、局部性),完全让数据驱动,让模型自己从海量数据中学到图像的一切规律,结果会怎样?事实证明,当数据量足够大时,这个“暴力”但直接的方法,效果惊人。
## 2. ViT的核心设计:把图像变成“句子”
那么,具体怎么把一张图片塞进为文本设计的Transformer里呢?这是ViT最精妙也最核心的一步。它的思路非常直观,可以概括为:**分块、拉平、视为词**。
### 2.1 图像分块:从像素网格到视觉单词
想象一下,你有一张标准的224x224像素的彩色图片。如果直接把每个像素当作一个“词”,那序列长度就是224*224=50176,这对于计算注意力来说是不可承受之重(计算量随序列长度平方增长)。
ViT的做法很聪明:它不关心单个像素,而是把图像划分成一个个大小相等的**图像块**。论文里常用的设置是把图片切成16x16像素的小块。对于224x224的图,横竖各切14刀,就得到了14x14=196个图像块。
每个图像块的大小是16x16x3=768(16像素高,16像素宽,3个颜色通道)。现在,这196个图像块,就成了我们处理的基本单元。你可以把它们想象成196个“视觉单词”,每个“单词”包含了16x16区域内的所有视觉信息。
这一步在代码里通常用一个步长等于核大小的卷积层高效实现:
```python
# 使用卷积层一步完成分块和线性投影
self.proj = nn.Conv2d(in_channels=3, out_channels=embed_dim, kernel_size=16, stride=16)
# 输入 [B, 3, 224, 224] -> 输出 [B, 768, 14, 14]
```
这个卷积核大小为16,步长也为16,意味着它不重叠地扫描整张图,每个16x16区域输出一个值(实际上是768维的向量),完美实现了分块操作。
### 2.2 线性投影与位置编码:为Transformer准备输入
分块之后,每个图像块实际上是一个三维的小数组(16, 16, 3)。我们需要把它转换成一维的向量,才能输入Transformer。这个过程叫**线性投影**,其实就是用一个全连接层把768维的像素值映射到一个新的维度D(论文中D=768)。
现在,我们有了196个长度为768的向量。但Transformer本身没有位置概念,打乱这些向量的顺序,它计算出的结果是一样的。这对于图像来说显然是灾难性的——猫的鼻子长在眼睛上面还是下面,意义完全不同。因此,我们必须加入**位置编码**。
ViT采用了一种简单直接的可学习位置编码。我们初始化一个可训练的参数矩阵,形状是 `[197, 768]`(为什么是197?稍后解释),其中每一行代表一个位置(从0到196)的编码向量。在输入Transformer之前,把这个位置编码向量直接加到对应的图像块向量上:
```python
# x 的形状是 [batch_size, 196, 768]
# pos_embed 的形状是 [1, 197, 768],在batch维度广播
x = x + self.pos_embed[:, 1:, :] # 注意,位置0预留给了一个特殊token
```
这个可学习的位置编码会在训练过程中,自己学会如何表征“上下左右”、“中心边缘”这些空间关系。我实测下来发现,学到的位置编码可视化后,确实能呈现出明显的二维空间结构,相邻位置编码相似,同行或同列的位置编码也有规律。
### 2.3 Class Token:图像的“句子主旨”
在NLP的BERT模型里,有一个 `[CLS]` token,用于汇聚整个句子的信息,做分类任务。ViT借鉴了这个天才的设计,引入了一个**可学习的分类token**。
我们在序列的最前面,额外添加一个特殊的向量,称为 `cls_token`。现在,输入序列的长度从196变成了197。这个 `cls_token` 会和其他196个图像块token一起,经过所有Transformer层的处理。在最后一层,我们只取出这个 `cls_token` 对应的输出向量,把它送入一个轻量的分类头(通常是MLP),得到最终的图像分类结果。
你可以把这个 `cls_token` 理解为一个“提问者”或“汇总者”。它随着数据流经整个网络,通过自注意力机制不断地从所有图像块那里收集信息,最终它自己身上就凝聚了整张图片的全局语义。这个设计避免了我们需要从196个输出中选哪一个来做分类的尴尬,非常优雅。
## 3. Transformer编码器:全局注意力的引擎
准备好输入序列后,就进入了ViT的主干——**Transformer编码器**。这部分和原始Transformer的编码器几乎一模一样,由多个相同的层堆叠而成(ViT-Base是12层)。每一层都包含两个核心子层:多头自注意力层和前馈神经网络层,并且每个子层都包裹着残差连接和层归一化。
### 3.1 多头自注意力机制:让每个像素块“纵观全局”
这是Transformer的灵魂,也是ViT获得全局建模能力的来源。它的工作原理可以用一个“信息检索”的类比来理解。
对于序列中的每一个token(比如某个图像块),自注意力层会帮它做三件事:
1. **生成查询**:这个token想问:“我关心什么样的信息?”
2. **生成键**:序列里所有token(包括自己)都亮出自己的“身份标签”,说:“我这里有这样的信息。”
3. **生成值**:每个token准备好自己实际要提供的“内容信息”。
然后,计算这个token的“查询”与所有token的“键”的相似度(通过点积),得到一个注意力权重分布。这个权重决定了在汇总信息时,应该从每个token的“值”那里取多少。最后,用这个权重对所有token的“值”进行加权求和,就得到了该token新的表示。
**多头**机制,就是并行地做多组这样的操作(比如12个头),每组使用不同的投影矩阵,相当于从12个不同的角度或“语义子空间”去计算注意力。最后把12个结果拼接起来,再投影回原来的维度。这大大增强了模型的表达能力。
在ViT中,这意味着**从第一层开始,图像左上角的一个小块,就能直接关注到右下角的小块**,并建立联系。这对于识别“一只鸟的喙和它的脚属于同一个物体”这样的长距离依赖至关重要。而CNN需要很多层卷积,这种信息才能慢慢传递过去。
代码层面,核心计算非常紧凑:
```python
# q, k, v 形状均为 [batch_size, num_heads, seq_len, head_dim]
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim) # 计算相似度
attn_weights = F.softmax(attn_weights, dim=-1) # 归一化为权重
attn_output = torch.matmul(attn_weights, v) # 加权求和
```
### 3.2 前馈网络与层归一化:非线性变换与稳定训练
自注意力层输出后,会经过一个**前馈网络**。它通常是一个简单的两层MLP,中间有一个非线性激活函数(如GELU)。这个子层的作用是给每个token的表示增加非线性变换能力,进行特征混合和升维再降维。
这里有一个关键点:自注意力层是进行 **token与token之间** 的交互(混合空间信息),而前馈网络是进行 **每个token内部特征之间** 的交互(混合通道信息)。两者分工明确。
**残差连接**和**层归一化**是训练深层Transformer的关键技术。每个子层的输入和输出会相加(残差连接),这能有效缓解梯度消失,让网络可以堆得很深。层归一化则被应用在子层之前(Pre-Norm,这是ViT采用的方式,不同于原始Transformer的Post-Norm),它对每个样本的所有特征维度进行归一化,稳定了训练过程,让学习率可以设得更大,收敛更快。
我自己的经验是,在搭建ViT时,这些细节(Pre-Norm vs Post-Norm,GELU vs ReLU,初始化方式)对最终效果的影响非常显著,有时候调好一个细节,准确率能提升一两个点。
## 4. ViT vs CNN:理念之争与实战差异
聊了这么多原理,大家最关心的肯定是:ViT和传统的CNN到底谁更强?在实际项目中该怎么选?我结合自己的实战经验,给大家做个对比分析。
| 特性 | 卷积神经网络 | Vision Transformer |
| :--- | :--- | :--- |
| **核心操作** | 卷积(局部滑动窗口) | 自注意力(全局交互) |
| **归纳偏置** | 强:平移不变性、局部性 | 弱:几乎没有图像-specific的先验 |
| **感受野** | 局部开始,随层数扩大 | **从第一层起就是全局** |
| **数据需求** | 相对较少,在小数据集上也能学好 | **极度依赖大规模数据** |
| **计算效率** | 高,局部计算,参数共享 | 序列长度平方级复杂度,对长序列慢 |
| **可解释性** | 特征图可视化,相对直观 | 注意力图可视化,显示全局关联 |
| **擅长任务** | 纹理、局部模式识别、实时检测 | 长距离依赖、全局上下文理解、大规模分类 |
**CNN的优势在于它的“高效”和“数据友好”**。卷积的局部性和权重共享,让它用较少的参数和计算量就能提取有效的层次化特征。而且,它的归纳偏置非常符合图像的物理规律,所以即使在ImageNet这种“中等”规模的数据集上(120万张图),也能取得非常好的效果。在计算资源有限、数据量不大的场景下,CNN依然是首选。
**ViT的优势在于它的“强大”和“可扩展性”**。它几乎没有对图像结构做任何假设,所有空间关系都需要从数据中学。这既是缺点也是优点。缺点是需要海量数据(比如谷歌用的JFT-3亿数据集)来“喂饱”它,否则很容易过拟合,效果不如CNN。但一旦用超大数据集预训练好,它的表现往往能超越CNN,尤其是在需要理解图像全局结构的任务上。而且,模型越大,数据越多,ViT的性能提升似乎没有明显的天花板,显示出极好的可扩展性。
在实际项目中,我的选择策略是:
- **如果任务类似ImageNet分类,数据量在百万级**:可以尝试ViT,但需要仔细调参和可能的数据增强。使用预训练模型进行微调是更稳妥的选择。
- **如果数据量只有几万甚至几千**:老老实实用CNN(如ResNet、EfficientNet)或者使用在ImageNet上预训练好的ViT进行微调。直接从头训练ViT大概率会翻车。
- **如果任务对全局上下文要求极高**:比如医学图像中分析整个器官的病变关联、卫星图像中分析地理要素的分布,即使数据量不大,也值得尝试用预训练的ViT,它的全局注意力机制可能带来惊喜。
- **如果对推理速度要求苛刻**:CNN或更轻量的混合模型(如MobileViT)仍是主流。
## 5. 动手实践:从零理解ViT的代码实现
光说不练假把式。下面我带大家走一遍ViT关键模块的PyTorch实现代码,我会加上详细的注释,确保你能看懂每一行在做什么。
首先,是**图像分块嵌入模块**,它负责把图片变成序列:
```python
import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
""" 将2D图像分割为Patches并做线性投影 """
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.img_size = (img_size, img_size)
self.patch_size = (patch_size, patch_size)
# 计算patch数量: (224/16) * (224/16) = 14*14 = 196
self.num_patches = (img_size // patch_size) * (img_size // patch_size)
# 核心:用一个大卷积核、大步长的卷积层同时完成分块和线性投影
# kernel_size=16, stride=16 意味着不重叠地取每个16x16区域
# 输入通道3,输出通道embed_dim(768)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape # 输入形状 [Batch, 3, 224, 224]
# 确保输入尺寸正确
assert H == self.img_size[0] and W == self.img_size[1], f"输入尺寸({H},{W})与模型设定{self.img_size}不符"
# 分块投影: [B, 3, 224, 224] -> [B, 768, 14, 14]
# flatten(2): 将高和宽维度展平 -> [B, 768, 196]
# transpose(1,2): 交换维度 -> [B, 196, 768] (序列长度196,特征维度768)
x = self.proj(x).flatten(2).transpose(1, 2)
return x
```
接下来是ViT的**核心构建块**,即一个Transformer编码层:
```python
class TransformerBlock(nn.Module):
""" 一个完整的Transformer编码器层 """
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.):
super().__init__()
# 第一层:层归一化 + 多头自注意力 + 残差
self.norm1 = nn.LayerNorm(dim)
self.attn = MultiHeadAttention(dim, num_heads, qkv_bias, attn_drop_rate, drop_rate)
# 第二层:层归一化 + 前馈网络(MLP) + 残差
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop_rate)
# 随机深度丢弃(Stochastic Depth),一种正则化技术
self.drop_path = DropPath(drop_rate) if drop_rate > 0. else nn.Identity()
def forward(self, x):
# 注意Pre-Norm结构:先归一化,再进入子层
# 残差连接:x = x + 子层输出
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
```
其中用到的**多头自注意力**实现如下:
```python
class MultiHeadAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads # 每个头的维度,如768/12=64
self.scale = self.head_dim ** -0.5 # 缩放因子,稳定softmax梯度
# 用一个线性层同时生成Q, K, V的投影矩阵
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
# 将多个头的输出合并回原维度
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape # [batch_size, 序列长度197, 特征维度768]
# 生成QKV: [B, 197, 768] -> [B, 197, 768*3] -> 重塑并分头
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # 每个形状: [B, num_heads, N, head_dim]
# 计算注意力分数: Q * K^T / sqrt(d_k)
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, heads, N, N]
attn = attn.softmax(dim=-1) # 在最后一个维度(N)上做softmax
attn = self.attn_drop(attn)
# 应用注意力权重到V上
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # 合并多头
x = self.proj(x)
x = self.proj_drop(x)
return x
```
最后,我们把所有部分组装成完整的**Vision Transformer模型**:
```python
class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12):
super().__init__()
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
# 可学习的分类token和位置编码
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) # +1 for cls_token
# 堆叠Transformer层
self.blocks = nn.ModuleList([
TransformerBlock(dim=embed_dim, num_heads=num_heads)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
# 分类头:通常就是一个线性层
self.head = nn.Linear(embed_dim, num_classes)
# 初始化参数
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
B = x.shape[0]
# 1. 分块嵌入
x = self.patch_embed(x) # [B, 196, 768]
# 2. 添加cls token
cls_tokens = self.cls_token.expand(B, -1, -1) # 从[1,1,768]扩展到[B,1,768]
x = torch.cat((cls_tokens, x), dim=1) # [B, 197, 768]
# 3. 添加位置编码
x = x + self.pos_embed
# 4. 通过Transformer编码器
for blk in self.blocks:
x = blk(x)
# 5. 取cls token的输出做分类
x = self.norm(x)
cls_output = x[:, 0] # 只取第一个token(cls_token)的输出
out = self.head(cls_output)
return out
```
你可以用几行代码实例化并测试这个模型:
```python
model = VisionTransformer(img_size=224, patch_size=16, num_classes=10)
dummy_input = torch.randn(4, 3, 224, 224) # 4张224x224的RGB图
output = model(dummy_input)
print(output.shape) # 应该输出: torch.Size([4, 10])
```
## 6. 超越分类:ViT的进化与未来
ViT最初是为图像分类设计的,但它的影响力远不止于此。研究者们很快将这种“分块+Transformer”的思想拓展到了计算机视觉的各个角落,催生了一系列变体和改进。
**目标检测**:DETR是第一个用Transformer做端到端目标检测的框架,它去掉了传统的锚框和非极大值抑制,将检测视为一个集合预测问题。后续的Deformable DETR、Swin Transformer for Detection等,都在速度和精度上做了进一步优化。
**图像分割**:SETR、Segmenter等模型将ViT作为编码器,配合各种解码器结构,在语义分割任务上取得了媲美甚至超越CNN的成绩。ViT的全局上下文信息对于理解像素所属的物体类别非常有帮助。
**底层视觉任务**:甚至在图像超分辨率、去噪、风格迁移等任务中,也出现了基于Transformer的模型,它们能更好地建模图像的长程依赖,恢复出更连贯的纹理和结构。
当然,ViT也有其明显的**挑战**。最大的问题就是**计算复杂度**。自注意力的计算量与序列长度的平方成正比。当图像分辨率很高时(比如1024x1024),序列长度会爆炸式增长。为了解决这个问题,社区提出了许多聪明的方案:
- **局部窗口注意力**:像Swin Transformer那样,将注意力计算限制在局部窗口内,再通过窗口移动来传递信息,将计算复杂度从平方降到了线性。
- **分层下采样结构**:同样在Swin中引入,像CNN一样构建特征金字塔,逐步减少序列长度,同时增加特征维度,兼顾效率和表达能力。
- **稀疏注意力**:只计算最重要的token对之间的注意力,比如根据键的相似度来筛选。
- **线性注意力**:通过核函数近似,将softmax注意力转化为线性计算。
在我个人看来,ViT的出现并不是要彻底取代CNN,而是为我们提供了另一种强大的工具。未来的趋势很可能是**融合**:吸收CNN在局部特征提取、平移等变性上的高效性,结合Transformer在全局建模上的强大能力。很多优秀的混合模型已经证明了这一点。
对于初学者,我的建议是,先扎实理解CNN,再深入学习Transformer和ViT。当你手里既有锤子又有螺丝刀时,面对不同的任务,你才能游刃有余地选择最合适的工具。ViT的代码实现虽然看起来比CNN复杂,但它的模块化程度很高,一旦理解了自注意力和分块嵌入这两个核心,整个架构就豁然开朗了。多动手写几遍代码,在小的数据集(比如CIFAR-10)上跑一跑,调整参数观察效果,是掌握它的最好方式。