timm库Vision Transformer实战:从模型加载到特征提取的完整流程

# 从零上手timm库Vision Transformer:模型加载、特征提取与迁移学习实战 如果你最近在计算机视觉项目中尝试过Transformer架构,大概率会听说过`timm`这个宝藏库。作为一个长期在图像分类、目标检测任务中摸爬滚打的开发者,我第一次接触`timm`时,最直接的感受是:它把那些复杂的模型调用变得像搭积木一样简单。特别是对于Vision Transformer(ViT)这类相对“年轻”但效果惊人的模型,`timm`提供了一个极其友好的入口,让你不用从零开始复现论文,就能快速验证想法、部署原型。 这篇文章不是一篇泛泛而谈的概述,而是聚焦于**实际操作**。我会带你走完从安装环境、加载预训练ViT模型,到提取中间特征、进行迁移学习的完整流程。过程中会穿插我实际项目中踩过的坑、总结的技巧,以及一些容易被忽略但至关重要的细节。无论你是想快速在自己的数据集上微调ViT,还是需要提取图像特征用于下游任务,这里的内容都能给你直接的参考。 ## 1. 环境搭建与timm库初探 在开始任何代码之前,确保你的环境是正确配置的。我推荐使用Python 3.8或更高版本,以及PyTorch 1.7以上。`timm`库的安装非常简单: ```bash pip install timm # 或者,如果你想安装最新开发版 pip install git+https://github.com/rwightman/pytorch-image-models ``` 安装完成后,第一件事不是急着写代码,而是先看看`timm`这个“武器库”里到底有哪些“武器”。`timm`支持数百个预训练模型,从经典的ResNet到最新的Swin Transformer、ConvNeXt,应有尽有。查看所有可用模型列表: ```python import timm # 列出所有可用的模型架构 model_names = timm.list_models() print(f"模型总数: {len(model_names)}") # 只看ViT相关模型 vit_models = timm.list_models('*vit*') print(f"ViT系列模型: {vit_models[:10]}") # 只打印前10个 ``` 你会看到一个很长的列表。对于ViT,常见的命名模式如 `vit_tiny_patch16_224`、`vit_base_patch16_224`、`vit_large_patch16_224`等。这些名字其实包含了关键信息: - `tiny/base/large`: 模型规模 - `patch16`: 图像被分割成的块大小(patch size) - `224`: 模型训练时输入的图像分辨率 > 注意:`timm`中的模型权重默认从Hugging Face Hub或作者指定的URL下载。第一次加载某个模型时,如果本地没有缓存,会自动下载。下载速度取决于网络,模型文件从几十MB到几百MB不等,请确保有足够的磁盘空间和稳定的网络连接。 ## 2. 加载预训练ViT模型:理解关键参数 加载一个预训练的ViT模型,核心函数是 `timm.create_model()`。这个函数看似简单,但几个参数的选择会直接影响后续所有操作。让我们从一个最基本的例子开始: ```python import torch import timm # 加载一个基础的ViT模型 model = timm.create_model('vit_base_patch16_224', pretrained=True) model.eval() # 切换到评估模式 # 打印模型结构概览 print(model) ``` 执行这段代码,如果这是你第一次使用`vit_base_patch16_224`,会看到下载进度条。下载完成后,模型就被加载到内存中了。但这里有个问题:这个模型是在ImageNet-1k上预训练的,输出是1000个类别的概率。如果你的任务不是ImageNet分类呢?这就引出了最重要的参数之一:`num_classes`。 ### 2.1 `num_classes`参数:迁移学习的钥匙 `num_classes`参数决定了模型分类头(head)的输出维度。默认情况下,对于在ImageNet上预训练的模型,`num_classes=1000`。但当你进行迁移学习时,比如你的任务只有10个类别,就需要修改它。 ```python # 场景1:直接用于ImageNet分类(或类别数相同的任务) model_imagenet = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=1000) # 场景2:用于10分类任务(迁移学习) model_10cls = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10) # 检查分类头的差异 print(f"原始模型分类头: {model_imagenet.head}") print(f"10分类模型分类头: {model_10cls.head}") ``` 当你设置 `num_classes` 为一个非默认值时,`timm` 实际上做了两件事: 1. 加载预训练的主干网络(backbone)权重 2. **重新初始化**分类头(最后一个线性层),使其输出维度匹配你指定的类别数 这意味着,在 `model_10cls` 中,除了分类头之外的所有层都保留了ImageNet上学到的知识,而分类头是随机初始化的,需要你在自己的数据上重新训练。 ### 2.2 `pretrained` 与 `pretrained_cfg` 的微妙关系 有时候你会发现,即使设置了 `pretrained=True`,模型也没有加载预训练权重。这通常是因为模型名称与预训练配置不匹配。`timm` 的模型系统其实相当灵活: ```python # 明确指定预训练配置 model = timm.create_model( 'vit_base_patch16_224', pretrained=True, pretrained_cfg='vit_base_patch16_224' # 明确指定配置 ) # 查看可用的预训练配置 model.pretrained_cfg ``` 在实际项目中,我习惯先检查模型是否有对应的预训练权重: ```python import timm model_name = 'vit_base_patch16_224' # 检查该模型是否有默认的预训练权重 has_pretrained = timm.models.is_model_pretrained(model_name) print(f"{model_name} 是否有预训练权重: {has_pretrained}") # 获取该模型的默认预训练配置 default_cfg = timm.models.get_pretrained_cfg(model_name) print(f"默认配置: {default_cfg}") ``` 这个检查步骤能避免很多“为什么我的模型效果这么差”的困惑——有时候你以为加载了预训练权重,实际上并没有。 ## 3. 深入ViT前向传播:特征提取的两种方式 理解ViT的前向传播过程,对于正确提取特征至关重要。与CNN不同,ViT的输出结构有其特殊性。让我们先看看标准的分类前向传播: ```python import torch import timm # 加载模型 model = timm.create_model('vit_base_patch16_224', pretrained=True) model.eval() # 创建一个模拟输入(batch_size=2, 3通道, 224x224分辨率) dummy_input = torch.randn(2, 3, 224, 224) # 标准前向传播(得到分类结果) with torch.no_grad(): output = model(dummy_input) print(f"分类输出形状: {output.shape}") # torch.Size([2, 1000]) ``` 这得到了最终的分类概率。但很多时候,我们需要的是**分类之前的特征**,比如: - 用于特征可视化 - 作为其他任务的输入(如检索、聚类) - 进行特征相似度计算 ### 3.1 使用 `forward_features` 方法 `timm` 为大多数视觉Transformer模型提供了 `forward_features` 方法,专门用于提取分类前的特征: ```python # 提取分类前的特征 with torch.no_grad(): features = model.forward_features(dummy_input) print(f"特征形状: {features.shape}") # 对于ViT,通常是 torch.Size([2, 197, 768]) ``` 这里有个关键点:对于标准的ViT,`forward_features` 返回的是 **所有token的特征**,包括: - 1个class token(CLS token) - 196个图像patch token(对于224x224输入,patch size为16时,有(224/16)²=196个patch) 所以输出形状是 `[batch_size, 197, hidden_dim]`。其中 `hidden_dim` 对于 `vit_base` 是768。 ### 3.2 提取CLS token特征 在ViT中,通常用CLS token的特征作为整个图像的表示。你可以这样提取: ```python # 方法1:从forward_features的输出中提取CLS token with torch.no_grad(): all_features = model.forward_features(dummy_input) cls_features = all_features[:, 0, :] # 取第一个token(CLS token) print(f"CLS特征形状: {cls_features.shape}") # torch.Size([2, 768]) # 方法2:直接使用模型输出(但需要理解模型内部处理) # 对于分类任务,ViT默认使用CLS token进行分类 # 所以标准forward的输出就是基于CLS token的 ``` ### 3.3 一个常见的陷阱:`num_classes=0` 的妙用 如果你只需要特征,完全不需要分类头,可以在创建模型时设置 `num_classes=0`: ```python # 创建没有分类头的ViT(只用于特征提取) feature_extractor = timm.create_model( 'vit_base_patch16_224', pretrained=True, num_classes=0 # 关键参数! ) # 现在模型没有分类头,forward直接返回特征 with torch.no_grad(): features = feature_extractor(dummy_input) print(f"无分类头时的输出形状: {features.shape}") # torch.Size([2, 197, 768]) ``` 这种方式创建的模型,其 `forward` 方法实际上等同于 `forward_features`。这在纯特征提取场景下非常方便,因为: 1. 模型更轻量(少了分类头的参数) 2. 前向传播稍微快一点 3. 代码更简洁,不需要显式调用 `forward_features` > 提示:当你设置 `num_classes=0` 时,`timm` 会移除模型的分类头,但保留全局平均池化(如果存在)或CLS token提取的逻辑。对于ViT,它仍然会返回CLS token的特征。 ## 4. 实战:构建完整的ViT特征提取流水线 现在我们把所有知识整合起来,构建一个实用的特征提取流水线。这个流水线需要处理: 1. 图像预处理(与模型训练时一致) 2. 批量处理 3. 特征提取与保存 4. 错误处理与日志 ### 4.1 图像预处理标准化 ViT模型对输入图像的预处理有特定要求。幸运的是,`timm` 提供了便捷的数据配置: ```python import torch import timm from PIL import Image import torchvision.transforms as T # 获取模型的默认预处理配置 model_name = 'vit_base_patch16_224' model = timm.create_model(model_name, pretrained=True, num_classes=0) model.eval() # 获取数据配置 data_config = timm.data.resolve_model_data_config(model) print(f"数据配置: {data_config}") # 创建预处理管道 transform = timm.data.create_transform( **data_config, is_training=False # 推理时的预处理 ) # 或者手动创建(了解每一步在做什么) manual_transform = T.Compose([ T.Resize((256, 256)), # 先缩放到稍大尺寸 T.CenterCrop(224), # 中心裁剪到模型输入尺寸 T.ToTensor(), # 转为Tensor T.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet统计量 std=[0.229, 0.224, 0.225]) ]) # 测试预处理 img = Image.open('example.jpg').convert('RGB') input_tensor = transform(img).unsqueeze(0) # 增加batch维度 print(f"预处理后形状: {input_tensor.shape}") ``` ### 4.2 批量特征提取与性能优化 处理大量图像时,我们需要考虑效率和内存使用: ```python import torch from torch.utils.data import DataLoader, Dataset from tqdm import tqdm import numpy as np class ImageDataset(Dataset): """自定义图像数据集""" def __init__(self, image_paths, transform): self.image_paths = image_paths self.transform = transform def __len__(self): return len(self.image_paths) def __getitem__(self, idx): img = Image.open(self.image_paths[idx]).convert('RGB') return self.transform(img), self.image_paths[idx] def extract_features_batch(model, dataloader, device='cuda', use_cls_token=True): """ 批量提取特征 参数: model: 特征提取模型 dataloader: 数据加载器 device: 计算设备 use_cls_token: 是否只使用CLS token特征 """ model.to(device) model.eval() all_features = [] all_paths = [] with torch.no_grad(): for batch_imgs, batch_paths in tqdm(dataloader, desc="提取特征"): batch_imgs = batch_imgs.to(device) # 提取特征 features = model(batch_imgs) # 假设model是num_classes=0的 if use_cls_token: # 只取CLS token features = features[:, 0, :] else: # 使用所有token的平均(替代方案) features = features.mean(dim=1) # 移到CPU并转为numpy features = features.cpu().numpy() all_features.append(features) all_paths.extend(batch_paths) # 合并所有批次的特征 all_features = np.vstack(all_features) return all_features, all_paths # 使用示例 image_paths = ['img1.jpg', 'img2.jpg', 'img3.jpg'] # 你的图像路径列表 dataset = ImageDataset(image_paths, transform) dataloader = DataLoader(dataset, batch_size=32, shuffle=False) # 创建特征提取模型 feature_model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0) # 提取特征 features, paths = extract_features_batch(feature_model, dataloader, device='cuda') print(f"提取了 {len(features)} 张图像的特征") print(f"特征维度: {features.shape}") ``` ### 4.3 特征保存与加载 提取的特征通常需要保存供后续使用: ```python import h5py import pickle from pathlib import Path def save_features_h5(features, paths, output_path): """使用HDF5格式保存特征(支持大型数据集)""" with h5py.File(output_path, 'w') as f: # 保存特征数组 f.create_dataset('features', data=features) # 保存路径(需要特殊处理字符串) dt = h5py.special_dtype(vlen=str) path_ds = f.create_dataset('paths', (len(paths),), dtype=dt) for i, path in enumerate(paths): path_ds[i] = path print(f"特征已保存到 {output_path}") def save_features_pickle(features, paths, output_path): """使用pickle保存特征和元数据""" data = { 'features': features, 'paths': paths, 'model_name': 'vit_base_patch16_224', 'extraction_time': '2024-01-01' } with open(output_path, 'wb') as f: pickle.dump(data, f) print(f"特征已保存到 {output_path}") # 保存示例 output_dir = Path('./extracted_features') output_dir.mkdir(exist_ok=True) # 使用HDF5格式(推荐用于大型数据集) save_features_h5(features, paths, output_dir / 'features.h5') # 或者使用pickle(适合小型数据集) save_features_pickle(features, paths, output_dir / 'features.pkl') ``` ## 5. 高级技巧与常见问题排查 在实际使用中,你可能会遇到各种问题。这里分享一些我积累的经验和解决方案。 ### 5.1 处理不同尺寸的输入 ViT模型通常要求固定尺寸的输入(如224x224),但实际图像可能尺寸各异。有几种处理策略: ```python import torch import timm from torchvision import transforms # 方法1:直接resize到固定尺寸(可能变形) transform_fixed = transforms.Compose([ transforms.Resize((224, 224)), # 强制变形 transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 方法2:保持宽高比,填充到正方形 def pad_to_square(img): """将图像填充为正方形""" w, h = img.size max_side = max(w, h) # 创建新图像(白色背景) new_img = Image.new('RGB', (max_side, max_side), (255, 255, 255)) # 粘贴原图到中心 left = (max_side - w) // 2 top = (max_side - h) // 2 new_img.paste(img, (left, top)) return new_img # 方法3:使用timm的灵活预处理 model = timm.create_model('vit_base_patch16_224', pretrained=True) # 有些ViT变体支持动态输入尺寸,但需要检查模型是否支持 ``` ### 5.2 内存优化与计算效率 处理高分辨率图像或大批量数据时,内存可能成为瓶颈: ```python # 技巧1:使用混合精度推理 from torch.cuda.amp import autocast model = timm.create_model('vit_base_patch16_224', pretrained=True).cuda() model.eval() def extract_features_amp(model, images): """使用自动混合精度进行特征提取""" with torch.no_grad(): with autocast(): features = model.forward_features(images) return features # 技巧2:梯度检查点(用于非常大的模型) model_with_checkpoint = timm.create_model( 'vit_large_patch16_224', pretrained=True, num_classes=0, features_only=False ) # 启用梯度检查点(训练时更有用,推理时影响不大) # 注意:这会增加计算时间,但减少内存使用 for block in model_with_checkpoint.blocks: block.grad_checkpointing = True # 技巧3:分块处理超大图像 def extract_features_chunked(model, large_image, chunk_size=512): """ 分块处理超大图像,然后合并特征 适用于ViT但需要小心位置编码的影响 """ # 这种方法比较复杂,需要根据具体任务设计 # 通常更好的做法是直接resize到模型支持的尺寸 pass ``` ### 5.3 特征可视化与解释 理解ViT学到了什么,特征可视化是很好的工具: ```python import numpy as np import matplotlib.pyplot as plt from sklearn.decomposition import PCA from sklearn.manifold import TSNE def visualize_features_2d(features, labels, method='pca', title="特征可视化"): """ 将高维特征降维到2D进行可视化 参数: features: 特征矩阵 [n_samples, n_features] labels: 样本标签 method: 'pca' 或 'tsne' """ if method == 'pca': reducer = PCA(n_components=2) elif method == 'tsne': reducer = TSNE(n_components=2, random_state=42) else: raise ValueError(f"不支持的降维方法: {method}") # 降维 features_2d = reducer.fit_transform(features) # 可视化 plt.figure(figsize=(10, 8)) scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1], c=labels, cmap='tab10', alpha=0.6) plt.colorbar(scatter) plt.title(f"{title} - {method.upper()}") plt.xlabel("Component 1") plt.ylabel("Component 2") plt.tight_layout() plt.show() # 使用示例 # 假设我们有一些特征和对应的类别标签 # features = ... # 形状 [n_samples, 768] # labels = ... # 形状 [n_samples] # visualize_features_2d(features, labels, method='tsne') ``` ### 5.4 常见错误与解决方案 我在使用timm和ViT时遇到过的一些典型问题: ```python # 错误1:形状不匹配 try: # 错误的输入形状 wrong_input = torch.randn(1, 224, 224, 3) # HWC格式,但PyTorch需要CHW output = model(wrong_input) except Exception as e: print(f"错误: {e}") # 解决方案:确保输入是 [batch, channels, height, width] correct_input = wrong_input.permute(0, 3, 1, 2) # 错误2:模型不在eval模式 model.train() # 训练模式 # 这会启用dropout和batch norm的训练行为 # 解决方案:推理前切换到eval模式 model.eval() # 错误3:忘记禁用梯度 # 在特征提取时,我们不需要计算梯度 torch.set_grad_enabled(False) # 全局禁用 # 或者使用torch.no_grad()上下文管理器 # 错误4:内存不足 # 解决方案:减小batch size或使用梯度累积 batch_size = 4 # 根据你的GPU内存调整 ``` ## 6. 迁移学习实战:在自定义数据集上微调ViT 最后,我们来看一个完整的迁移学习示例。假设我们有一个10分类的数据集: ```python import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset from torchvision import transforms import timm from pathlib import Path from PIL import Image import numpy as np class CustomDataset(Dataset): def __init__(self, data_dir, transform=None): self.data_dir = Path(data_dir) self.transform = transform # 假设数据组织为:data_dir/class_name/*.jpg self.classes = sorted([d.name for d in self.data_dir.iterdir() if d.is_dir()]) self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} self.samples = [] for class_name in self.classes: class_dir = self.data_dir / class_name for img_path in class_dir.glob('*.jpg'): self.samples.append((img_path, self.class_to_idx[class_name])) def __len__(self): return len(self.samples) def __getitem__(self, idx): img_path, label = self.samples[idx] img = Image.open(img_path).convert('RGB') if self.transform: img = self.transform(img) return img, label def train_epoch(model, dataloader, criterion, optimizer, device): """训练一个epoch""" model.train() total_loss = 0 correct = 0 total = 0 for batch_idx, (inputs, targets) in enumerate(dataloader): inputs, targets = inputs.to(device), targets.to(device) # 前向传播 outputs = model(inputs) loss = criterion(outputs, targets) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() # 统计 total_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() avg_loss = total_loss / len(dataloader) accuracy = 100. * correct / total return avg_loss, accuracy def validate(model, dataloader, criterion, device): """验证""" model.eval() total_loss = 0 correct = 0 total = 0 with torch.no_grad(): for inputs, targets in dataloader: inputs, targets = inputs.to(device), targets.to(device) outputs = model(inputs) loss = criterion(outputs, targets) total_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() avg_loss = total_loss / len(dataloader) accuracy = 100. * correct / total return avg_loss, accuracy def main(): # 超参数 num_classes = 10 batch_size = 32 num_epochs = 10 learning_rate = 1e-4 # 设备 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 数据预处理 transform_train = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) transform_val = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 数据集和数据加载器 train_dataset = CustomDataset('./data/train', transform=transform_train) val_dataset = CustomDataset('./data/val', transform=transform_val) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) # 创建模型 - 关键步骤! model = timm.create_model( 'vit_base_patch16_224', pretrained=True, num_classes=num_classes # 修改分类头 ).to(device) # 只训练分类头(可选:先冻结主干,只训练分类头) # for param in model.parameters(): # param.requires_grad = False # for param in model.head.parameters(): # param.requires_grad = True # 损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.AdamW(model.parameters(), lr=learning_rate) # 训练循环 for epoch in range(num_epochs): train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device) val_loss, val_acc = validate(model, val_loader, criterion, device) print(f'Epoch {epoch+1}/{num_epochs}:') print(f' Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%') print(f' Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%') # 保存模型 torch.save({ 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': num_epochs, 'val_accuracy': val_acc }, 'vit_finetuned.pth') if __name__ == '__main__': main() ``` 这个示例展示了完整的迁移学习流程。有几个关键点值得注意: 1. **数据预处理**:训练和验证使用不同的预处理,训练时增加数据增强 2. **模型创建**:通过 `num_classes` 参数自动修改分类头 3. **训练策略**:可以先冻结主干网络,只训练分类头,然后再解冻全部微调 4. **优化器选择**:对于Transformer,AdamW通常比SGD效果更好 在实际项目中,你可能还需要添加学习率调度、早停、模型检查点保存等功能。但上面的代码已经提供了一个坚实的起点。 ## 7. 性能对比与模型选择 `timm` 提供了多种ViT变体,如何选择适合你任务的模型?这里有一个简单的对比: | 模型名称 | 参数量 | 输入尺寸 | ImageNet Top-1 Acc | 特征维度 | 适用场景 | |---------|--------|----------|-------------------|----------|----------| | `vit_tiny_patch16_224` | 5.7M | 224x224 | 75.5% | 192 | 移动端、快速原型 | | `vit_small_patch16_224` | 22.1M | 224x224 | 81.4% | 384 | 平衡型,通用任务 | | `vit_base_patch16_224` | 86.6M | 224x224 | 84.5% | 768 | 大多数研究项目 | | `vit_large_patch16_224` | 304.3M | 224x224 | 85.8% | 1024 | 高性能需求 | | `vit_base_patch16_384` | 86.6M | 384x384 | 86.0% | 768 | 高分辨率图像 | 选择模型时需要考虑: 1. **计算资源**:大模型需要更多GPU内存和计算时间 2. **数据量**:小数据集可能更适合小模型,避免过拟合 3. **输入尺寸**:高分辨率输入通常能提升性能,但计算成本更高 4. **部署环境**:移动端或边缘设备需要轻量级模型 我个人的经验是,对于大多数任务,`vit_base_patch16_224` 是一个不错的起点。如果效果不够好,可以尝试更大的模型或更高分辨率。如果速度是首要考虑,`vit_small` 或 `vit_tiny` 可能更合适。 ## 8. 实际项目中的注意事项 在真实的生产环境中使用ViT和timm时,还有一些额外的考虑: **多GPU训练**:如果你的数据或模型很大,可能需要多GPU训练。`timm` 与PyTorch的 `DistributedDataParallel` 兼容良好: ```python import torch.distributed as dist import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP def setup(rank, world_size): """设置分布式训练""" dist.init_process_group("nccl", rank=rank, world_size=world_size) def cleanup(): dist.destroy_process_group() def train_distributed(rank, world_size): """分布式训练函数""" setup(rank, world_size) # 每个进程创建自己的模型 model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10).to(rank) # 包装为DDP模型 model = DDP(model, device_ids=[rank]) # ... 训练代码 ... cleanup() ``` **模型导出**:如果需要将模型部署到生产环境,可能需要导出为其他格式: ```python # 导出为TorchScript model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10) model.eval() # 跟踪模式(对于动态控制流有限的模型) traced_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224)) traced_model.save("vit_traced.pt") # 或者脚本模式(更通用) scripted_model = torch.jit.script(model) scripted_model.save("vit_scripted.pt") ``` **特征缓存**:如果多次使用相同的特征,考虑缓存结果: ```python import hashlib import pickle from pathlib import Path class CachedFeatureExtractor: """带缓存的特征提取器""" def __init__(self, model, cache_dir="./feature_cache"): self.model = model self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(exist_ok=True) def _get_cache_key(self, image_path, params): """生成缓存键""" # 基于图像路径和提取参数生成唯一键 content = f"{image_path}_{params}" return hashlib.md5(content.encode()).hexdigest() def extract(self, image_path, use_cls_token=True): """提取特征,使用缓存""" cache_key = self._get_cache_key(image_path, use_cls_token) cache_file = self.cache_dir / f"{cache_key}.pkl" # 检查缓存 if cache_file.exists(): with open(cache_file, 'rb') as f: return pickle.load(f) # 提取特征 img = Image.open(image_path).convert('RGB') # ... 预处理和特征提取 ... # 保存到缓存 with open(cache_file, 'wb') as f: pickle.dump(features, f) return features ``` 这些实战经验来自我在多个计算机视觉项目中的积累。每个项目都有其特殊性,但掌握这些核心概念和技巧,能让你在面对新任务时更加从容。ViT和timm的组合确实强大,但真正的价值在于如何将它们灵活地应用到你的具体问题中。

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

Python内容推荐

Swin Transformer实战:timm中的 Swin Transformer实现图像分类(多GPU)。

Swin Transformer实战:timm中的 Swin Transformer实现图像分类(多GPU)。

本例提取了植物幼苗数据集中的部分数据做数据集,数据集共有12种类别,演示如何使用timm版本的Swin Transformer图像分类模型实现分类任务已经对验证集得分的统计,本文实现了多个GPU并行训练。 通过本文你和学到: ...

PyTorch中利用timm库实现28个视觉Transformer模型并进行代码解析

PyTorch中利用timm库实现28个视觉Transformer模型并进行代码解析

文中首先概述了timm库的特点和优势,接着深入剖析了vision_transformer.py文件中的关键代码片段,包括模型配置、嵌入层、注意力机制、残差连接、位置编码等重要组件的具体实现方法。此外,还探讨了如何自定义新的...

《DeepSeek原理与项目实战:大模型部署、微调与应用开发(752页)》.pdf

《DeepSeek原理与项目实战:大模型部署、微调与应用开发(752页)》.pdf

内容概要:《DeepSeek原理与项目实战:大模型部署、微调与应用开发》系统介绍了基于Transformer架构的DeepSeek大模型核心技术及其在实际开发中的应用。全书分为三大部分,共12章。第一部分深入解析了Transformer与...

Swin Transformer v2实战:使用Swin Transformer v2实现图像分类

Swin Transformer v2实战:使用Swin Transformer v2实现图像分类

Swin Transformer v2解决了大型视觉模型训练和应用中的三个主要问题,包括训练不稳定性、预训练和微调之间的分辨率差距以及对标记数据的渴望。 最新更改: 重新适配了timm,并将更换了huggingface的国内链接。 链接...

大模型监测 这段代码使用了 Hugging Face 的 Vision Transformer (ViT) 模型,完成一个图像分类任务 以下是主要功能:

模型加载:加载预训练的 ViT 模型 goo

大模型监测 这段代码使用了 Hugging Face 的 Vision Transformer (ViT) 模型,完成一个图像分类任务 以下是主要功能: 模型加载:加载预训练的 ViT 模型 goo

之后,代码加载了预训练的 Vision Transformer 模型,并将模型转移到计算设备上。定义了优化器和损失函数,这是构建训练循环的基础。在训练循环中,代码遍历每个批次的数据,执行前向传播、计算损失、反向传播以及...

全流程实战:神经网络入门到私有AI平台落地完整版-1.8G课程网盘链接提取码下载 .txt

全流程实战:神经网络入门到私有AI平台落地完整版-1.8G课程网盘链接提取码下载 .txt

20 小工人建造 Transformer 分类器:环境配置到代码讲解全流程拆解.mp4 21 小工人精讲超参数:模型变强的秘密调优手册.mp4 22 小工人讲运行逻辑:数据加载、词表、模型初始化一步步走通.mp4 23 小工人开工!...

【从0到1搞懂大模型】transformer详解:架构及代码实践-transformer完整代码(7)

【从0到1搞懂大模型】transformer详解:架构及代码实践-transformer完整代码(7)

为了帮助理解和应用,有的资料会提供Transformer的完整代码实现,从基础数据处理到模型训练和预测的整个流程。 在实践中,Transformer模型已经演变成很多变种,如BERT、GPT系列、Transformer-XL等,它们在各种NLP...

timm-0.6.7.tar.gz

timm-0.6.7.tar.gz

timm库的模型集覆盖了从基础的卷积神经网络到复杂的自注意力机制的网络架构。它的模型包括但不限于ResNet、ResNeXt、DenseNet、EfficientNet、Vision Transformer等经典架构。这些模型都能够解决各种视觉任务,而且...

Transformer与大模型实战

Transformer与大模型实战

本书《Transformer与大模型实战》深入探讨了当前自然语言处理(NLP)领域的核心技术和实战应用,特别聚焦于Transformer架构、BERT以及GPT系列模型的原理和实践。本书不仅详细解读了Transformer模型的基础架构,而且...

vision-transformer实战总结:非常简单的VIT入门教程,一定不要错过

vision-transformer实战总结:非常简单的VIT入门教程,一定不要错过

本例提取了植物幼苗数据集中的部分数据做数据集,数据集共有12种类别,演示如何使用pytorch版本的VIT图像分类模型实现分类任务。 通过本文你和学到: 1、如何构建VIT模型? 2、如何生成数据集? 3、如何使用Cutout...

医疗图像分割-基于Pyramid-Vision-Transformer算法实现医疗息肉分割-优质项目实战.zip

医疗图像分割-基于Pyramid-Vision-Transformer算法实现医疗息肉分割-优质项目实战.zip

本项目通过实战案例详细介绍了基于Pyramid-Vision-Transformer算法的医疗息肉图像分割技术,不仅深入探讨了算法本身,而且涉及了从数据预处理到模型训练、评估等整个流程,旨在为医疗图像处理领域的研究者和工程师...

ViT-基于MNIST手写数字识别数据集训练Vision-Transformer模型-简单易上手-优质项目实战.zip

ViT-基于MNIST手写数字识别数据集训练Vision-Transformer模型-简单易上手-优质项目实战.zip

Vision-Transformer(ViT)模型是一种基于Transformer架构的深度学习模型,最初被设计用于处理自然语言处理(NLP)任务,但其设计理念同样适用于图像处理领域。本项目的核心在于使用ViT模型在MNIST手写数字识别数据...

vision transformer预训练

vision transformer预训练

总的来说,"vision transformer预训练"通过各种自监督策略,如对比学习、像素级别的重建任务和掩码自编码,使得Transformer模型能在有限的数据下学习到丰富的视觉特征,并在计算机视觉任务中展现出强大的性能。...

深度学习基于Vision Transformer与Star Block的图像分类模型设计:增强特征提取与分类性能

深度学习基于Vision Transformer与Star Block的图像分类模型设计:增强特征提取与分类性能

内容概要:本文介绍了一种改进的视觉Transformer模型(ViT),通过引入自定义的Star_Block模块增强其性能。Star_Block模块由中心分支和多个并行分支组成,采用卷积神经网络(CNN)技术处理图像特征。具体来说,中心...

BERT基础教程:Transformer大模型实战.pdf

BERT基础教程:Transformer大模型实战.pdf

关注有更多资源,私免费的得

基于TCN-Transformer结构的时间序列预测模型:共享特征提取与多场景应用优化模型,基于TCN-Transformer实现时间序列预测 
模型采用共享TCN结构,用于提取Encoder Emb

基于TCN-Transformer结构的时间序列预测模型:共享特征提取与多场景应用优化模型,基于TCN-Transformer实现时间序列预测 模型采用共享TCN结构,用于提取Encoder Emb

基于TCN-Transformer结构的时间序列预测模型:共享特征提取与多场景应用优化模型,基于TCN-Transformer实现时间序列预测。 模型采用共享TCN结构,用于提取Encoder Embedding和Decoder Embedding 的因果特征,在尽可能...

万得多模态题目万得多模态题目万得多模态题目

万得多模态题目万得多模态题目万得多模态题目

本资源摘要信息将详细介绍多模态学习和视觉 Transformer 模型,包括 Vision Transformer(ViT)模型和 Contrastive Language-Image Pre-training(CLIP)模型。 一、多模态学习 多模态学习是一种机器学习技术,...

深度学习-Transformer实战系列课程

深度学习-Transformer实战系列课程

Transformer模型是深度学习领域中的一个重大突破,由Google在2017年提出的《Attention is All You Need》论文中首次介绍。Transformer模型以其创新性的注意力机制(Attention Mechanism)取代了传统的序列依赖模型,...

timm-1.0.15.tar.gz

timm-1.0.15.tar.gz

它之所以受到欢迎,主要是因为其涵盖了从简单的卷积神经网络(CNN)到复杂的视觉变换器(Vision Transformer)在内的多种网络结构。这些模型被广泛应用于计算机视觉领域,如图像分类、目标检测、语义分割等任务。 ...

Vision Transformer详解[可运行源码]

Vision Transformer详解[可运行源码]

Vision Transformer(ViT)是一种新兴的深度学习架构,它将Transformer模型引入到图像处理领域,并取得了引人注目的效果。ViT的基本思路是将图像切分成多个块,每个块都可以看作是一个序列中的token,然后通过...

最新推荐最新推荐

recommend-type

虚化高斯模糊-下载即用.zip

打开链接下载源码: https://pan.quark.cn/s/4397e18c5cb7 ShapeBlurView 库是一个高斯模糊(毛玻璃效果)蒙层库。 简书地址 https://www.jianshu.com/p/442759a3ccf1 不知大家做需求的时候是否有这样的效果要求: 需求示例 大家熟悉的Android常用图片加载库,比如Glide 可以对图片进行毛玻璃效果的加载(实现不展开说了) 但是都是对整个要加载的图片进行高斯模糊效果,对应局部这种比较难处理,这个库就能实现这样的效果。 当然,你对整个图片盖一层,也能达到Glide高斯模糊加载的效果。 先看看效果: 效果示例 效果示例 [comment]: <> (效果示例效果示例) 网上有其他大神开源的库,但都有些美中不足。 此库支持矩形、圆形、椭圆;边框、边框自定义颜色、自定义边框粗细;矩形时支持切圆角 并且可以支持对4个角分别切圆角。 ----- 使用步骤 1、在添加maven地址的地方添加: 2、在需要使用的gradle文件添加依赖: appcompat:*根据你自己的版本添加 使用说明 (1)Xml布局文件中引用 默认效果代码如上,当然width、height根据需求而定 (2)可用属性 ``RealtimeBlurView`库,感谢:RealtimeBlurView 项目库如有不足和错误的地方,欢迎大家讨论指正! 觉得不错的话,感谢Star下!
recommend-type

移除 Windows PE

源码直接下载地址: https://pan.quark.cn/s/1dbc338528b4 Uninstall_Statistics ================= 统计 应用 自身被 卸载 Android statistics application is uninstalled 参考自这篇blog http://www.cnblogs.com/zealotrouge/p/3157126.html http://www.cnblogs.com/zealotrouge/p/3159772.html
recommend-type

YOLO算法道路场景扫描车与汽车目标检测数据集-6655张-标注类别为汽车-扫描车.zip

1. YOLO目标检测数据集, 适用于YOLOV5、yolov7,yolov8, yolov11, yolov13, yolo26等系列算法,含标签,已标注好,可以直接用来训练; 2. 内置data.yaml数据集配置文件,已经划分好了训练集、验证集等; 3. 数据集和模型具体情况可参考https://blog.csdn.net/zhiqingAI/article/details/161091291?spm=1011.2415.3001.5331 , 和 https://blog.csdn.net/zhiqingAI/article/details/124230743?spm=1001.2014.3001.5502
recommend-type

12306火车站三字码表

下载代码方式:https://pan.quark.cn/s/48abaf2fae86 12306火车站对应的三字代码表,在2020年4月28日获取,编码格式为utf-8。该资料是用于达成python爬取票务系统余票时url生成的必要条件。
recommend-type

UPS维护记录-下载即用.zip

打开链接下载源码: https://pan.quark.cn/s/5e13cc87aca0 【不间断电源(Uninterruptible Power Supply,简称UPS)的维护档案】是数据中心管理过程中的核心环节,其根本目的在于保障供电设备的持续稳定运作,避免因电力供应波动所引发的系统停运或信息遗失。以下列举了关于UPS维护的核心要点:1. **不间断电源的功能**:- UPS是一种能持续供应稳定电能的装置,当外部电源中断时能够即时切换至电池供电模式,确保关键设备不受干扰,尤其对于电力要求较高的IT基础设施,例如服务器及网络设备等,其作用尤为关键。2. **维护作业**:- **清洁除尘**:UPS内部积聚的灰尘可能阻碍散热系统,进而降低运行效能,甚至诱发电路短路,因此周期性清理内部尘埃具有必要性。 - **静电防护**:静电可能对电子部件造成损害,在执行维护任务时,必须采取防静电措施,如佩戴防静电腕带,并维持适宜的空气湿度。 - **电池放电检测**:通过实施放电操作可以评估电池性能,明确其能否在必要时提供充足的能量支持。放电持续时长以及放电前后电池组的电压水平是判定电池健康状态的重要依据。 - **电池充电检测**:充电环节同样关键,通过监测充电时长和电压变化情况,能够评估充电效能和电池的充电状况。3. **维护成效与建议**:- 记录维护后的设备运行状况,如电池组的电压稳定性、充放电效率,以及是否存在异常温度升高等问题,是结果部分应详细记载的内容。 - 基于维护成果提出改进措施,如更换老旧电池单元,优化充电方案,改善通风条件,或增加维护检查的频率。4. **维护频率**:UPS的维护通常按照季度、半年或年度执行,具体频率需依据设备的使用条件和负载情况确定...
recommend-type

学生成绩管理系统C++课程设计与实践

资源摘要信息:"学生成绩信息管理系统-C++(1).doc" 1. 系统需求分析与设计 在进行学生成绩信息管理系统开发前,首先需要进行系统需求分析,这是确定系统开发目标与范围的过程。需求分析应包括数据需求和功能需求两个方面。 - 数据需求分析: - 学生成绩信息:需要收集学生的姓名、学号、课程成绩等数据。 - 数据类型和长度:明确每个数据项的数据类型(如字符串、整型等)和长度,例如学号可能是字符串类型且长度为一定值。 - 描述:详细描述每个数据项的意义,以确保系统能够准确处理。 - 功能需求分析: - 列出功能列表:用户界面应提供清晰的操作指引,列出所有可用功能。 - 查询学生成绩:系统应能通过学号或姓名查询学生的成绩信息。 - 增加学生成绩信息:允许用户添加未保存的学生成绩信息。 - 删除学生成绩信息:能够通过学号或姓名删除已经保存的成绩信息。 - 修改学生成绩信息:通过学号或姓名修改已有的成绩记录。 - 退出程序:提供安全退出程序的选项,并确保所有修改都已保存。 2. 系统设计 系统设计阶段主要完成内存数据结构设计、数据文件设计、代码设计、输入输出设计、用户界面设计和处理过程设计。 - 内存数据结构设计: - 使用链表结构组织内存中的数据,便于动态增删查改操作。 - 数据文件设计: - 选择文本文件存储数据,便于查看和编辑。 - 代码设计: - 根据功能需求,编写相应的函数和模块。 - 输入输出设计: - 设计简洁明了的输入输出提示信息和操作流程。 - 用户界面设计: - 用户界面应为字符界面,方便在命令行环境下使用。 - 处理过程设计: - 设计数据处理流程,确保每个操作都有明确的处理逻辑。 3. 系统实现与测试 实现阶段需要根据设计阶段的成果编写程序代码,并进行系统测试。 - 程序编写: - 完成系统设计中所有功能的程序代码编写。 - 系统测试: - 设计测试用例,通过测试用例上机测试系统。 - 记录测试方法和测试结果,确保系统稳定可靠。 4. 设计报告撰写 最后,根据系统开发的各个阶段,撰写详细的设计报告。 - 系统描述:包括问题说明、数据需求和功能需求。 - 系统设计:详细记录内存数据结构设计、数据文件设计、代码设计、输入/输出设计、用户界面设计、处理过程设计。 - 系统测试:包括测试用例描述、测试方法和测试结果。 - 设计特点、不足、收获和体会:反思整个开发过程,总结经验和教训。 时间安排: - 第19周(7月12日至7月16日)完成项目。 - 7月9日8:00到计算机学院实验中心(三楼)提交程序和课程设计报告。 指导教师和系主任(或责任教师)需要在文档上签名确认。 系统需求分析: - 使用表格记录系统需求分析的结果,包括数据项、数据类型、数据长度和描述。 - 分析数据项如学生成绩信息、状态器、链表节点等,确定其属性和行为。 以上就是文档中提到的学生成绩信息管理系统开发的关键知识点。开发此类系统需要熟练掌握C++编程基础,了解面向对象的程序设计思想,以及熟悉文件操作和链表等数据结构的应用。此外,良好的软件开发流程意识、测试意识和文档撰写能力也是必不可少的。
recommend-type

别再手动拖拽了!用Lumerical脚本批量创建FDTD仿真结构(附完整代码)

# 告别低效建模:Lumerical脚本自动化实战指南 在光子学仿真领域,时间就是科研生命线。当同行还在GUI界面里反复点击菜单时,你已经用脚本批量生成了20组参数化结构——这不是未来场景,而是每位FDTD用户都应该掌握的基础生产力革命。本文将彻底改变你与Lumerical的交互方式,从手动拖拽的农耕时代,跃迁到自动化建模的工业文明。 ## 1. 为何脚本建模是必然选择 2019年Nature Photonics的一项研究显示,科研工作者在仿真工具上平均浪费37%的时间在重复性操作上。对于需要参数扫描的纳米光学结构设计,这个数字可能更高。手动创建10个不同尺寸的纳米柱阵列意味着: -
recommend-type

Java邮件解析任务中,如何安全高效地提取HTML邮件内容并避免硬编码、资源泄漏和类型转换异常?

<think>我们被要求优化一段Java代码,该代码用于处理邮件(特别是来自特定发件人的构建通知邮件)。代码的主要问题包括: 1. 重复获取邮件内容:在检查MIME类型后,多次调用`msg[i].getContent()`,这可能导致性能问题或流关闭异常。 2. 类型转换问题:直接将邮件内容转换为`Multipart`而不进行类型检查,可能引发`ClassCastException`。 3. 代码结构问题:逻辑嵌套过深,可读性差,且存在重复代码(如插入邮件详情的操作在两个地方都有)。 4. 硬编码和魔法值:例如在解析HTML表格时使用了硬编码的索引(如list3.get(10)),这容易因邮件
recommend-type

RH公司应收账款管理优化策略研究

资源摘要信息:"本文针对RH公司的应收账款管理问题进行了深入研究,并提出了改进策略。文章首先分析了应收账款在企业管理中的重要性,指出其对于提高企业竞争力、扩大销售和充分利用生产能力的作用。然后,以RH公司为例,探讨了公司应收账款管理的现状,并识别出合同管理、客户信用调查等方面的不足。在此基础上,文章提出了一系列改善措施,包括完善信用政策、改进业务流程、加强信用调查和提高账款回收力度。特别强调了建立专门的应收账款回收部门和流程的重要性,并建议在实际应用过程中进行持续优化。同时,文章也意识到企业面临复杂多变的内外部环境,因此提出的策略需要根据具体情况调整和优化。 针对财务管理领域的专业学生和从业者,本文提供了一个关于应收账款管理问题的案例研究,具有实际指导意义。文章还探讨了信用管理和征信体系在应收账款管理中的作用,强调了它们对于提升企业信用风险控制和市场竞争能力的重要性。通过对比国内外企业在应收账款管理上的差异,文章总结了适合中国企业实际环境的应收账款管理方法和策略。" 根据提供的文件内容,以下是详细的知识点: 1. 应收账款管理的重要性:应收账款作为企业的一项重要资产,其有效管理关系到企业的现金流、财务健康以及市场竞争力。不良的应收账款管理会导致资金链断裂、坏账损失增加等问题,严重影响企业的正常运营和长远发展。 2. 应收账款的信用风险:在信用交易日益频繁的商业环境中,企业必须对客户信用进行评估,以便采取合理的信用政策,降低信用风险。 3. 合同管理的薄弱环节:合同是应收账款管理的法律基础,严格的合同管理能够保障企业权益,减少因合同问题导致的应收账款风险。 4. 客户信用调查:了解客户的信用状况对于预测和控制应收账款风险至关重要。企业需要建立有效的客户信用调查机制,识别和筛选信用良好的客户。 5. 应收账款回收策略:企业应建立有效的账款回收机制,包括定期的账款跟进、逾期账款的催收等。同时,建立专门的应收账款回收部门可以提升回收效率。 6. 应收账款管理流程优化:通过改进企业内部管理流程,如简化审批流程、提高工作效率等措施,能够提升应收账款的管理效率。 7. 应收账款管理策略的调整和优化:由于企业的内外部环境复杂多变,因此制定的管理策略需要根据实际情况进行动态调整和持续优化。 8. 信用管理和征信体系的作用:建立和完善企业内部信用管理体系和征信体系,有助于企业更好地控制信用风险,并在市场竞争中占据有利地位。 9. 对比国内外应收账款管理实践:通过研究国内外企业在应收账款管理上的不同做法和经验,可以借鉴先进的管理理念和方法,提升国内企业的应收账款管理水平。 综上所述,本文深入探讨了应收账款管理的多个方面,为RH公司乃至其他同类型企业提供了应收账款管理的改进方向和策略,对于财务管理专业的教育和实践都具有重要的参考价值。
recommend-type

新手别慌!用BingPi-M2开发板带你5分钟搞懂Tina Linux SDK目录结构

# 新手别慌!用BingPi-M2开发板带你5分钟搞懂Tina Linux SDK目录结构 第一次拿到BingPi-M2开发板时,面对Tina Linux SDK里密密麻麻的文件夹,我完全不知道从哪下手。就像走进一个陌生的大仓库,每个货架上都堆满了工具和零件,却找不到操作手册。这种困惑持续了整整两天,直到我意识到——理解目录结构比死记硬背每个文件更重要。 ## 1. 为什么SDK目录结构如此重要 想象你正在组装一台复杂的模型飞机。如果所有零件都混在一个箱子里,你需要花大量时间寻找每个螺丝和面板。但如果有分门别类的隔层,标注着"机身部件"、"电子设备"、"紧固件",组装效率会成倍提升。Ti