# SwinNet实战:如何用Swin Transformer提升RGB-D图像边缘检测精度(附代码)
在计算机视觉的诸多任务中,显著目标检测(Salient Object Detection, SOD)一直扮演着关键角色,它旨在从复杂背景中精准地识别并分割出最吸引人注意的前景目标。当我们将单一的RGB图像扩展到包含深度信息的RGB-D图像,甚至是热红外信息的RGB-T图像时,问题变得更加有趣,也更具挑战性。多模态数据带来了信息互补的巨大潜力——RGB提供丰富的纹理和颜色,深度图勾勒出清晰的空间结构,而热红外则揭示了温度差异。然而,如何有效地融合这些异构信息,并让模型“看见”并强化目标的精细边缘,是决定最终检测效果的核心。
近年来,Transformer架构以其强大的全局建模能力席卷了视觉领域,而Swin Transformer通过引入层级设计和窗口注意力机制,巧妙地平衡了计算效率与建模能力。SwinNet正是这一趋势下的杰出产物,它将Swin Transformer作为骨干网络,并创新性地融入了边缘感知机制,专门为RGB-D和RGB-T的显著目标检测任务量身打造。与那些只关注“是什么”的理论论文不同,本文将带你从零开始,深入代码层面,亲手搭建并优化一个SwinNet模型,解决实际项目中遇到的边缘模糊、模态对齐不准等痛点。无论你是希望在自己的数据集上复现SOTA结果,还是想借鉴其思想改进自己的多模态融合模型,这里都有你需要的实战细节。
## 1. 环境搭建与核心依赖解析
动手之前,一个稳定、兼容的环境是成功的基石。SwinNet的实现依赖于PyTorch生态和一些特定的视觉库,版本匹配至关重要。我曾在早期因为torchvision版本不兼容,导致预训练模型加载失败,白白浪费了半天时间排查。
首先,我们创建一个独立的Python虚拟环境。我强烈推荐使用`conda`,它能很好地管理复杂的依赖关系。
```bash
conda create -n swinnet python=3.8 -y
conda activate swinnet
```
接下来安装PyTorch。请根据你的CUDA版本前往[PyTorch官网](https://pytorch.org/get-started/locally/)获取准确的安装命令。例如,对于CUDA 11.3,可以这样安装:
```bash
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
```
> **注意**:Swin Transformer的官方实现通常对PyTorch版本有一定要求,1.8以上版本一般兼容性较好。务必确保你的CUDA驱动版本支持所选的PyTorch CUDA版本。
核心的模型实现和工具库,我们通过pip安装:
```bash
pip install opencv-python pillow matplotlib scikit-image timm
pip install einops
```
这里有几个关键包需要特别说明:
* **`timm` (PyTorch Image Models)**:这是一个宝藏库,提供了Swin Transformer在内的数百种预训练模型及其加载接口。我们将直接用它来加载Swin骨干网络。
* **`einops`**:它提供了极其优雅的张量操作语法(如`rearrange`, `reduce`),能让你的注意力机制、特征重塑等代码变得清晰易读,强烈推荐。
* **`scikit-image`**:用于数据增强和评估指标(如MAE、F-measure)的计算。
最后,我们需要获取SwinNet的源代码。通常论文作者会在GitHub上开源代码。我们可以克隆或下载其仓库:
```bash
git clone <SwinNet官方仓库地址>
cd SwinNet
```
如果官方仓库结构清晰,其`requirements.txt`文件会列出所有依赖。但根据我的经验,直接按照上述步骤安装核心库,再根据运行时的报错信息补充缺失的包,是更高效的方法。
## 2. 数据准备与预处理管道构建
模型再好,也离不开高质量的数据。RGB-D SOD领域有几个公认的基准数据集,如**NJUD**、**NLPR**、**STEREO**和**RGBD135**。我们的实战将以NJUD数据集为例。
### 2.1 数据集结构与解读
下载并解压NJUD数据集后,你通常会看到如下目录结构:
```
NJUD/
├── RGB/
│ ├── 0001.jpg
│ ├── 0002.jpg
│ └── ...
├── depth/
│ ├── 0001.png
│ ├── 0002.png
│ └── ...
└── GT/
├── 0001.png
├── 0002.png
└── ...
```
* **RGB/**:存放原始的彩色JPG图像。
* **depth/**:存放对应的深度图(通常为16位PNG)。深度图的像素值代表距离,值越大通常表示物体越远。**关键点在于,不同数据集的深度图存储方式和值范围可能不同**,有的可能是反转的(近处值大),有的可能经过了归一化。
* **GT/** (Ground Truth):存放二值化的显著性标注图,前景(显著目标)为白色(255),背景为黑色(0)。
> **提示**:在加载深度图时,务必使用`cv2.IMREAD_UNCHANGED`模式来保留原始的16位信息,然后根据数据集说明进行适当的归一化(例如,归一化到[0, 1]区间)。盲目地以8位方式读取会导致信息丢失。
### 2.2 自定义Dataset类与数据增强
我们将创建一个PyTorch的`Dataset`类来封装数据加载逻辑。这里面的技巧在于如何同步地对RGB图像、深度图和GT进行完全一致的空间变换。
```python
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import os
from PIL import Image
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as TF
class RGBDSaliencyDataset(Dataset):
def __init__(self, rgb_dir, depth_dir, gt_dir, transform=None, img_size=224):
self.rgb_paths = sorted([os.path.join(rgb_dir, f) for f in os.listdir(rgb_dir) if f.endswith('.jpg')])
self.depth_paths = sorted([os.path.join(depth_dir, f) for f in os.listdir(depth_dir) if f.endswith('.png')])
self.gt_paths = sorted([os.path.join(gt_dir, f) for f in os.listdir(gt_dir) if f.endswith('.png')])
self.transform = transform
self.img_size = img_size
# 基础转换:调整大小并转为Tensor
self.to_tensor = T.Compose([
T.Resize((img_size, img_size)),
T.ToTensor(),
])
def __len__(self):
return len(self.rgb_paths)
def __getitem__(self, idx):
# 读取RGB图像
rgb_img = Image.open(self.rgb_paths[idx]).convert('RGB')
# 读取深度图(16位)
depth_img = cv2.imread(self.depth_paths[idx], cv2.IMREAD_UNCHANGED).astype(np.float32)
# 处理可能的无效值并归一化
if depth_img.max() > 0:
depth_img = depth_img / depth_img.max()
depth_img = Image.fromarray((depth_img * 255).astype(np.uint8)) # 转为PIL Image
# 读取GT
gt_img = Image.open(self.gt_paths[idx]).convert('L')
# 应用同步的数据增强(例如随机水平翻转)
if self.transform:
# 这里需要确保对三个图像应用完全相同的随机参数
seed = torch.randint(0, 2**32, (1,)).item()
torch.manual_seed(seed)
rgb_img = self.transform(rgb_img)
torch.manual_seed(seed)
depth_img = self.transform(depth_img)
torch.manual_seed(seed)
gt_img = self.transform(gt_img)
else:
# 如果没有随机增强,只进行确定性的Resize和ToTensor
rgb_img = self.to_tensor(rgb_img)
depth_img = self.to_tensor(depth_img)
gt_img = self.to_tensor(gt_img)
# 将GT二值化(阈值0.5)
gt_img = (gt_img > 0.5).float()
# 深度图在转为Tensor后是单通道,我们复制成三通道以适配Swin Transformer的输入(如果需要)
if depth_img.shape[0] == 1:
depth_img = depth_img.repeat(3, 1, 1)
return rgb_img, depth_img, gt_img
# 定义训练和验证的数据增强
train_transform = T.Compose([
T.RandomHorizontalFlip(p=0.5),
T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
T.RandomResizedCrop(224, scale=(0.8, 1.0)),
])
val_transform = None # 验证集只做Resize和ToTensor
# 创建数据加载器
train_dataset = RGBDSaliencyDataset('path/to/NJUD/RGB', 'path/to/NJUD/depth', 'path/to/NJUD/GT', train_transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
```
这个`Dataset`类的核心在于**同步随机种子**,确保对RGB、深度、GT的随机操作(如翻转、裁剪)是完全一致的,否则数据就“对不齐”了。颜色抖动`ColorJitter`通常只应用于RGB图像,因为深度图代表几何信息,不应进行颜色扰动。
## 3. SwinNet模型架构深度拆解与实现
现在,我们进入最核心的部分——构建SwinNet模型。我们将遵循论文的模块化思想,一步步实现。理解每个模块的设计动机,比单纯复制代码更重要。
### 3.1 双流Swin Transformer骨干网络
SwinNet使用两个独立的Swin Transformer(通常是Swin-B或Swin-L)作为编码器,分别处理RGB和深度模态。使用`timm`库可以轻松加载预训练权重。
```python
import torch.nn as nn
import timm
class TwoStreamSwinBackbone(nn.Module):
def __init__(self, model_name='swin_base_patch4_window7_224', pretrained=True):
super().__init__()
# 创建RGB流骨干
self.rgb_backbone = timm.create_model(model_name, pretrained=pretrained, features_only=True)
# 创建深度流骨干(通常结构与RGB流相同,但输入通道可能不同,这里我们复制权重)
self.depth_backbone = timm.create_model(model_name, pretrained=pretrained, features_only=True)
# 注意:深度图输入是单通道,但Swin Transformer预训练模型期望3通道输入。
# 一种常见做法是将深度图复制到3个通道,或者修改第一个卷积层的输入通道数并小心地初始化。
# 为简单起见,我们采用复制通道的方法。
self.depth_input_adapter = nn.Conv2d(1, 3, kernel_size=1) # 1x1卷积将1通道映射到3通道
# timm的features_only会返回一个列表,包含不同阶段的特征图
# 对于Swin-B,通常是4个阶段输出的特征,尺寸逐步下采样
def forward(self, rgb, depth):
# 适配深度图输入
depth = self.depth_input_adapter(depth)
# 提取多尺度特征
rgb_features = self.rgb_backbone(rgb) # 列表,例如 [B, C1, H/4, W/4], [B, C2, H/8, W/8], ...
depth_features = self.depth_backbone(depth)
return rgb_features, depth_features
```
这里有个**工程细节**:预训练的Swin Transformer是在ImageNet(RGB三通道)上训练的。直接将单通道深度图输入,即使复制成三通道,其分布也与自然图像相差甚远,可能影响特征提取。论文中有时会提及对深度流骨干进行**部分微调**或使用**特定的预处理**。更精细的做法是,只加载RGB流骨干的权重,深度流骨干随机初始化,并在训练中从头学习。
### 3.2 空间对齐与通道重校准模块(SACR)
这是SwinNet实现跨模态有效融合的第一个关键模块。其设计思想非常直观:
1. **空间对齐**:RGB和深度图中的显著物体应该在相同位置。该模块学习一个公共的空间注意力图,用来加权两种模态的特征,使它们在空间上对齐。
2. **通道重校准**:RGB和深度模态提供的信息重要性不同。该模块通过通道注意力,让网络自适应地强调每个模态中更有用的通道信息。
```python
import torch.nn.functional as F
class SpatialAlignmentChannelRecalibration(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
# 用于生成公共空间注意力图的卷积层
self.spatial_attention = nn.Sequential(
nn.Conv2d(channels * 2, channels // 2, kernel_size=3, padding=1),
nn.BatchNorm2d(channels // 2),
nn.ReLU(inplace=True),
nn.Conv2d(channels // 2, 1, kernel_size=3, padding=1),
nn.Sigmoid() # 输出0-1的注意力权重
)
# 用于RGB和深度模态各自的通道注意力(使用SE模块思想)
self.rgb_channel_attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(channels // 4, channels, kernel_size=1),
nn.Sigmoid()
)
self.depth_channel_attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels // 4, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(channels // 4, channels, kernel_size=1),
nn.Sigmoid()
)
def forward(self, rgb_feat, depth_feat):
# 输入特征形状: [B, C, H, W]
# 1. 生成公共空间注意力图
concat_feat = torch.cat([rgb_feat, depth_feat], dim=1)
spatial_att = self.spatial_attention(concat_feat) # [B, 1, H, W]
# 2. 空间对齐:用同一个注意力图加权两个模态的特征
rgb_aligned = rgb_feat * spatial_att
depth_aligned = depth_feat * spatial_att
# 3. 通道重校准:各自计算通道注意力并加权
rgb_channel_att = self.rgb_channel_attention(rgb_aligned)
depth_channel_att = self.depth_channel_attention(depth_aligned)
rgb_recalibrated = rgb_aligned * rgb_channel_att
depth_recalibrated = depth_aligned * depth_channel_att
# 4. 融合对齐和重校准后的特征(简单相加)
fused_feat = rgb_recalibrated + depth_recalibrated
return fused_feat
```
这个模块会被应用到编码器输出的**每一层特征**上,实现层内的跨模态融合。在实际代码中,你可能需要为不同通道数的特征层实例化多个SACR模块。
### 3.3 边缘感知模块
边缘模糊是显著目标检测的常见问题。SwinNet巧妙地利用深度图的浅层特征(富含边缘细节)来生成清晰的边缘线索。深度图对物体边界通常比RGB图更敏感。
```python
class EdgeAwareModule(nn.Module):
def __init__(self, in_channels, out_channels=64):
super().__init__()
# 使用深度流骨干最浅层的特征(例如第一阶段输出)作为输入
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
# 用于细化边缘特征的卷积块
self.refine = nn.Sequential(
nn.Conv2d(out_channels * 3, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
# 边缘特征的通道注意力
self.edge_channel_attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(out_channels, out_channels // 4, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels // 4, out_channels, kernel_size=1),
nn.Sigmoid()
)
def forward(self, shallow_depth_feat):
# shallow_depth_feat: [B, C, H, W]
# 通过1x1卷积和上采样生成多尺度边缘线索(假设输入是下采样后的,需要上采样回原图大小?)
# 实际上,我们通常在同一分辨率下操作。这里演示的是论文中提到的连接多个上采样结果的思想。
# 简化版:直接处理当前分辨率特征
feat1 = self.conv1(shallow_depth_feat)
feat2 = F.interpolate(self.conv2(shallow_depth_feat), scale_factor=2, mode='bilinear', align_corners=False)
feat2 = F.interpolate(feat2, size=feat1.shape[2:], mode='bilinear', align_corners=False) # 调整回原尺寸示例
feat3 = F.interpolate(self.conv3(shallow_depth_feat), scale_factor=0.5, mode='bilinear', align_corners=False)
feat3 = F.interpolate(feat3, size=feat1.shape[2:], mode='bilinear', align_corners=False)
# 连接多尺度特征
edge_feat = torch.cat([feat1, feat2, feat3], dim=1)
edge_feat = self.refine(edge_feat)
# 应用通道注意力,增强重要的边缘通道
edge_att = self.edge_channel_attention(edge_feat)
refined_edge_feat = edge_feat * edge_att + edge_feat # 残差连接
return refined_edge_feat
```
生成的`refined_edge_feat`将作为宝贵的先验信息,被送入解码器,指导最终显著性图的生成,确保目标轮廓锐利。
### 3.4 边缘引导解码器
解码器的任务是将编码器提取并融合的多层特征,逐步上采样并聚合,最终输出与输入同分辨率的显著性概率图。SwinNet的解码器是边缘引导的,意味着在每一步融合中,都融入了边缘特征。
```python
class EdgeGuidedDecoder(nn.Module):
def __init__(self, channel_list, edge_channels):
"""
channel_list: 列表,包含从深到浅各层融合特征的通道数,例如 [1024, 512, 256, 128]
edge_channels: 边缘特征的通道数
"""
super().__init__()
self.up_blocks = nn.ModuleList()
self.edge_fusion_blocks = nn.ModuleList()
# 构建从深层到浅层的上采样融合块
for i in range(len(channel_list)-1):
in_ch = channel_list[i] + channel_list[i+1] # 当前层特征 + 上一层上采样特征
out_ch = channel_list[i+1]
self.up_blocks.append(
nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
)
)
# 在每个融合步骤,将边缘特征也融合进来
self.edge_fusion_blocks.append(
nn.Conv2d(out_ch + edge_channels, out_ch, kernel_size=3, padding=1)
)
# 最终输出层,生成单通道显著性图
self.final_conv = nn.Sequential(
nn.Conv2d(channel_list[-1], 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 1, kernel_size=1),
nn.Sigmoid()
)
def forward(self, fused_features, edge_feat):
"""
fused_features: 列表,从深到浅的融合特征(来自SACR模块的输出)
edge_feat: 边缘特征
"""
x = fused_features[0] # 从最深层特征开始
for i, (up_block, edge_fusion) in enumerate(zip(self.up_blocks, self.edge_fusion_blocks)):
# 上采样当前特征
x_up = up_block(x)
# 与同分辨率的浅层融合特征拼接
x = torch.cat([x_up, fused_features[i+1]], dim=1)
# 将边缘特征也拼接到当前融合特征中(边缘特征需要调整到相同分辨率)
edge_resized = F.interpolate(edge_feat, size=x.shape[2:], mode='bilinear', align_corners=False)
x = torch.cat([x, edge_resized], dim=1)
# 通过卷积融合所有信息
x = edge_fusion(x)
# 最终上采样到输入图像大小并输出
saliency_map = self.final_conv(x)
saliency_map = F.interpolate(saliency_map, scale_factor=4, mode='bilinear', align_corners=False) # 假设最终上采样4倍
return saliency_map
```
至此,我们已经将SwinNet的核心模块拆解并实现了。将它们组装起来,就构成了完整的SwinNet模型。在训练时,我们使用混合损失函数,通常包括**二元交叉熵损失(BCE)**和**交并比损失(IoU Loss)**,来同时优化像素级分类和区域级重叠度。
## 4. 训练策略、调参技巧与结果优化
有了模型和数据,如何高效地训练并得到最优结果?这部分充满了实践智慧。
### 4.1 损失函数设计与权衡
单一的BCE损失容易导致预测图模糊。结合IoU或Dice损失可以更好地优化整体区域的一致性。
```python
class HybridLoss(nn.Module):
def __init__(self, bce_weight=1.0, iou_weight=0.5):
super().__init__()
self.bce_weight = bce_weight
self.iou_weight = iou_weight
self.bce_loss = nn.BCELoss()
def iou_loss(self, pred, target):
# pred, target: [B, 1, H, W], 值在0-1之间
intersection = (pred * target).sum(dim=(2,3))
union = pred.sum(dim=(2,3)) + target.sum(dim=(2,3)) - intersection
iou = (intersection + 1e-6) / (union + 1e-6)
return 1 - iou.mean()
def forward(self, pred, target):
bce = self.bce_loss(pred, target)
iou = self.iou_loss(pred, target)
total_loss = self.bce_weight * bce + self.iou_weight * iou
return total_loss, {'bce': bce.item(), 'iou': iou.item()}
```
你可以调整`bce_weight`和`iou_weight`来平衡两者。在我的实验中,`(1.0, 0.5)`是一个不错的起点。更高级的损失函数如**结构相似性损失(SSIM Loss)**或**边缘感知损失**也可以尝试加入,以进一步提升边缘质量。
### 4.2 优化器与学习率调度
对于Swin Transformer这类大模型,**AdamW优化器**配合**余弦退火学习率调度**是当前的主流选择。AdamW相比Adam加入了权重衰减的正则化,能更好地防止过拟合。
```python
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
model = SwinNet(...).cuda()
criterion = HybridLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6) # T_max为总epoch数
# 训练循环中
for epoch in range(epochs):
model.train()
for rgb, depth, gt in train_loader:
rgb, depth, gt = rgb.cuda(), depth.cuda(), gt.cuda()
optimizer.zero_grad()
pred = model(rgb, depth)
loss, loss_dict = criterion(pred, gt)
loss.backward()
optimizer.step()
scheduler.step() # 每个epoch后更新学习率
```
> **注意**:对于双流骨干,我们可能希望采用**差异化的学习率**。例如,对预训练的RGB骨干设置较低的学习率(如`1e-5`),对深度骨干和解码器部分设置较高的学习率(如`1e-4`)。这可以通过`optimizer`的参数分组来实现。
### 4.3 关键超参数调优经验
训练深度模型就像烹饪,火候和配料至关重要。以下是我在调参过程中总结的一些经验,以表格形式呈现,方便你快速参考:
| 超参数 | 推荐范围/值 | 调参影响与建议 |
| :--- | :--- | :--- |
| **初始学习率 (lr)** | 1e-4 到 5e-4 | 太大易震荡不收敛,太小收敛慢。AdamW下1e-4较稳健。可对骨干网络设置更低lr(如1e-5)。 |
| **批大小 (batch_size)** | 8, 16, 32 | 受GPU内存限制。较大的batch_size(如32)可能使训练更稳定,但可能降低模型泛化性。在内存允许下尝试。 |
| **输入图像尺寸** | 224x224, 384x384 | Swin Transformer预训练尺寸多为224。增大尺寸(如384)可能提升细节,但显著增加计算量,需调整窗口大小。 |
| **权重衰减 (weight_decay)** | 1e-4, 5e-4 | 防止过拟合的关键正则化。1e-4是常用起点。如果训练集小,可尝试增大。 |
| **损失函数权重** | BCE: 1.0, IoU: 0.5~1.0 | IoU权重越高,模型越关注区域整体性,可能牺牲一些细节。根据验证集指标(如F-measure, MAE)调整。 |
| **数据增强强度** | 适中 | 随机翻转、裁剪、颜色抖动有效。过度增强(如大尺度裁剪、强颜色扰动)可能破坏RGB-D对齐,反而有害。 |
| **训练周期数** | 40~100 | 使用早停(Early Stopping)策略,监控验证集损失或F-measure,连续多个epoch不提升则停止。 |
### 4.4 推理、可视化与性能评估
模型训练完成后,在测试集上进行推理并评估是检验成果的最后一步。我们不仅需要看数字指标,更要**可视化**结果,直观地分析模型在哪里做得好,哪里还有问题。
常用的评估指标包括:
* **MAE (Mean Absolute Error)**:预测图与GT逐像素绝对误差的平均值,值越小越好。
* **最大F-measure (max Fβ)**:准确率和召回率的加权调和平均,是综合性能的重要指标。
* **S-measure (Structure Measure)**:评估预测图的结构相似性。
* **E-measure (Enhanced Alignment Measure)**:综合考虑局部像素匹配和全局图像级匹配。
```python
def evaluate_model(model, test_loader, save_dir='./results'):
model.eval()
mae_total, f_score_total = 0.0, 0.0
os.makedirs(save_dir, exist_ok=True)
with torch.no_grad():
for idx, (rgb, depth, gt) in enumerate(test_loader):
rgb, depth = rgb.cuda(), depth.cuda()
pred = model(rgb, depth)
pred_np = pred.squeeze().cpu().numpy()
gt_np = gt.squeeze().cpu().numpy()
# 计算MAE
mae = np.mean(np.abs(pred_np - gt_np))
mae_total += mae
# 计算F-measure (需要二值化预测图,使用自适应阈值)
pred_bin = (pred_np > 0.5).astype(np.float32)
# ... 这里省略计算precision, recall, F-beta的代码 ...
# f_score = ...
# 可视化保存
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
axes[0].imshow(rgb[0].permute(1,2,0).cpu().numpy())
axes[0].set_title('RGB')
axes[1].imshow(depth[0,0].cpu().numpy(), cmap='gray')
axes[1].set_title('Depth')
axes[2].imshow(pred_np, cmap='jet')
axes[2].set_title('Prediction')
axes[3].imshow(gt_np, cmap='gray')
axes[3].set_title('Ground Truth')
plt.savefig(os.path.join(save_dir, f'result_{idx:04d}.png'))
plt.close()
avg_mae = mae_total / len(test_loader)
avg_f = f_score_total / len(test_loader)
print(f'Test Results - MAE: {avg_mae:.4f}, maxF: {avg_f:.4f}')
return avg_mae, avg_f
```
通过可视化,你可能会发现一些典型问题:**深度图质量差导致融合失效**、**复杂背景干扰**、**细小物体漏检**、**边缘毛刺**等。针对这些问题,可以回到数据预处理(如深度图增强)、模型结构(如加强边缘模块)或损失函数(如增加边缘损失)上进行针对性改进。
最后,如果你想将模型部署到实际应用中,还需要考虑**模型轻量化**(如使用更小的Swin-Tiny骨干)、**推理速度优化**(使用TensorRT或ONNX转换)以及**处理任意尺寸输入**(将模型改为全卷积,支持动态输入)等工程问题。这又是另一个充满挑战和乐趣的领域了。