使用自定义标注的CT图像微调nnUNet,需要遵循特定的数据处理流程和训练配置。以下是完整的Python代码实现方法。
### **1. 数据准备与格式转换**
首先需要将自定义的CT数据转换为nnUNet要求的格式。假设您的CT图像为DICOM或NIfTI格式,标注为NIfTI格式的二进制掩码。
```python
import os
import nibabel as nib
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.paths import nnUNet_raw_data, nnUNet_preprocessed
# --- Task definition ---
task_id = 501      # custom task id; values > 500 avoid clashes with built-in tasks
task_name = "CustomLiverCT"
num_classes = 2    # background + liver

# --- Create the nnUNet raw-data directory layout for this task ---
task_folder = join(nnUNet_raw_data, f"Task{task_id:03d}_{task_name}")
maybe_mkdir_p(task_folder)
for subdir in ("imagesTr", "labelsTr", "imagesTs", "labelsTs"):
    maybe_mkdir_p(join(task_folder, subdir))
def convert_to_nnunet_format(source_img_path, source_label_path, case_id):
    """
    Convert a single image/label pair into nnUNet's file layout and naming.

    Parameters
    ----------
    source_img_path : str   — path to the NIfTI CT image
    source_label_path : str — path to the NIfTI segmentation mask
    case_id : str           — identifier used in the output file names

    Returns the case_id unchanged, for bookkeeping.
    """
    img = nib.load(source_img_path)
    label = nib.load(source_label_path)
    # Image and label must be voxel-aligned for nnUNet.
    assert img.shape == label.shape, f"图像和标签尺寸不匹配: {img.shape} vs {label.shape}"
    img_nifti = nib.Nifti1Image(img.get_fdata(), img.affine, img.header)
    # BUGFIX: nnUNet requires integer class ids in label maps, but
    # get_fdata() returns float64 — round and cast explicitly.
    label_data = np.rint(label.get_fdata()).astype(np.uint8)
    # NOTE(review): the image's affine/header are reused for the label to
    # enforce identical geometry — assumes the label truly shares the grid.
    label_nifti = nib.Nifti1Image(label_data, img.affine, img.header)
    # File naming: case_XXXX_0000.nii.gz — the trailing 0000 is the modality index.
    nib.save(img_nifti, join(task_folder, "imagesTr", f"{case_id}_0000.nii.gz"))
    nib.save(label_nifti, join(task_folder, "labelsTr", f"{case_id}.nii.gz"))
    return case_id
# Example: convert a batch of cases into the task folder.
cases = [
    {"img": "patient1_img.nii.gz", "label": "patient1_label.nii.gz", "id": "patient_001"},
    {"img": "patient2_img.nii.gz", "label": "patient2_label.nii.gz", "id": "patient_002"},
    # add more cases here...
]
for entry in cases:
    convert_to_nnunet_format(entry["img"], entry["label"], entry["id"])
```
### **2. 创建dataset.json配置文件**
nnUNet需要dataset.json文件来描述数据集属性,这是关键步骤[ref_6]。
```python
import json

# Build the list of training pairs referenced by dataset.json.
training_entries = [
    {
        "image": f"./imagesTr/{case['id']}_0000.nii.gz",
        "label": f"./labelsTr/{case['id']}.nii.gz",
    }
    for case in cases
]

# Dataset descriptor in the exact schema nnUNet v1 expects.
dataset_info = {
    "name": task_name,
    "description": "Custom liver CT segmentation dataset with manual annotations",
    "reference": "Internal hospital data",
    "licence": "CC BY-NC-SA 4.0",
    "release": "1.0",
    "tensorImageSize": "3D",                      # CT volumes are 3D
    "modality": {"0": "CT"},                      # channel 0000 -> CT
    "labels": {"0": "background", "1": "liver"},
    "numTraining": len(cases),
    "numTest": 0,                                 # no held-out test set
    "training": training_entries,
    "test": [],
}

with open(join(task_folder, "dataset.json"), "w") as f:
    json.dump(dataset_info, f, indent=4)
print(f"数据集配置已保存到: {join(task_folder, 'dataset.json')}")
```
### **3. 数据预处理**
nnUNet会自动执行预处理,但我们可以自定义预处理参数。
```python
import os

# IMPORTANT: nnunet.paths reads these environment variables at *import*
# time, so they must be set before any nnunet module is imported
# (ideally exported in the shell, e.g. `export nnUNet_raw_data_base=...`).
os.environ['nnUNet_raw_data_base'] = "/path/to/nnUNet_raw_data"
os.environ['nnUNet_preprocessed'] = "/path/to/nnUNet_preprocessed"
os.environ['RESULTS_FOLDER'] = "/path/to/nnUNet_results"

from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner3D_v21
from nnunet.paths import nnUNet_cropped_data

# BUGFIX: ExperimentPlanner3D_v21 takes (folder_with_cropped_data,
# preprocessed_output_folder); it does not accept a "3d_fullres" string and
# must not be pointed at the raw task folder.  Cropping must have run first
# (the nnUNet_plan_and_preprocess CLI performs both steps).
cropped_folder = join(nnUNet_cropped_data, f"Task{task_id:03d}_{task_name}")
preprocessed_folder = join(os.environ['nnUNet_preprocessed'], f"Task{task_id:03d}_{task_name}")
planner = ExperimentPlanner3D_v21(cropped_folder, preprocessed_folder)
planner.plan_experiment()
planner.run_preprocessing(num_threads=8)  # scale with available CPU cores
```
### **4. 微调训练代码**
使用预训练模型进行微调,这是减少标注工作量的关键[ref_5]。
```python
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
import torch
class CustomTrainer(nnUNetTrainerV2):
    """
    Fine-tuning trainer: nnUNetTrainerV2 plus pretrained-weight loading and
    optional encoder freezing.
    """

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None,
                 batch_dice=True, stage=None, unpack_data=True, deterministic=True,
                 fp16=False, freeze_encoder=False):
        super().__init__(plans_file, fold, output_folder, dataset_directory,
                         batch_dice, stage, unpack_data, deterministic, fp16)
        self.freeze_encoder = freeze_encoder  # freeze encoder weights during fine-tuning

    def initialize_network(self):
        """
        Build the network, overlay compatible pretrained weights, and
        optionally freeze the encoder path.
        """
        super().initialize_network()
        pretrained_model_path = "/path/to/pretrained_model/model_best.model"
        if os.path.exists(pretrained_model_path):
            print(f"加载预训练权重: {pretrained_model_path}")
            checkpoint = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
            model_dict = self.network.state_dict()
            # Transfer only layers whose name AND shape match the current
            # network (e.g. the segmentation head differs when num_classes
            # changed between pretraining and fine-tuning).
            pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items()
                               if k in model_dict and model_dict[k].shape == v.shape}
            skipped = len(checkpoint['state_dict']) - len(pretrained_dict)
            if skipped:
                print(f"Skipped {skipped} incompatible layers from the pretrained checkpoint")
            model_dict.update(pretrained_dict)
            self.network.load_state_dict(model_dict)
        else:
            # Robustness fix: silently training from random init defeats the
            # purpose of fine-tuning — make it visible.
            print(f"WARNING: pretrained checkpoint not found: {pretrained_model_path}; "
                  f"training from randomly initialized weights")
        if self.freeze_encoder:
            for name, param in self.network.named_parameters():
                # BUGFIX: matching plain 'conv_blocks' also froze the decoder
                # (Generic_UNet names it 'conv_blocks_localization'), which
                # contradicts "train the decoder only".  Only the encoder path
                # 'conv_blocks_context' should be frozen.
                if 'encoder' in name or 'conv_blocks_context' in name:
                    param.requires_grad = False
            print("编码器层已冻结,只训练解码器")
        return self.network
# --- Training configuration ---
def train_custom_model():
    """
    Instantiate CustomTrainer with fine-tuning hyper-parameters and run
    training on fold 0.  Returns the trainer for later evaluation.
    """
    task = convert_id_to_task_name(task_id)
    fold = 0  # train fold 0 of the 5-fold split
    trainer = CustomTrainer(
        plans_file=join(nnUNet_preprocessed, task, "nnUNetPlansv2.1_plans_3D.pkl"),
        fold=fold,
        output_folder=join(os.environ['RESULTS_FOLDER'], "nnUNet", "3d_fullres", task),
        dataset_directory=join(nnUNet_preprocessed, task),
        batch_dice=True,
        stage=None,
        unpack_data=True,
        deterministic=True,
        fp16=True,            # mixed precision to save GPU memory
        freeze_encoder=True   # usual choice when fine-tuning on small data
    )
    # Fine-tuning: fewer epochs and a smaller LR than training from scratch.
    trainer.num_epochs = 500
    trainer.initial_lr = 1e-4
    # BUGFIX: initialize() must run before training — it loads the plans and
    # builds the network and dataloaders (calling initialize_network()).
    trainer.initialize(True)
    # Override the batch size AFTER initialize(); process_plans() would
    # otherwise overwrite it with the value stored in the plans file.
    trainer.batch_size = 2
    # Deep supervision is on by default in nnUNetTrainerV2; kept explicit
    # because it helps on small datasets.
    trainer.deep_supervision = True
    # BUGFIX: nnUNet trainers are started via run_training(); nnUNetTrainerV2
    # has no train() method, so the original call raised AttributeError.
    trainer.run_training()
    return trainer

# Entry point
if __name__ == "__main__":
    trainer = train_custom_model()
```
### **5. 多厂商数据适配技巧**
如果您的CT数据来自不同厂商,需要特殊处理以减少厂商差异影响[ref_1]。
```python
import SimpleITK as sitk
from scipy import ndimage
class MultiVendorCTProcessor:
    """Preprocessing helpers to reduce inter-vendor variability in CT data."""

    @staticmethod
    def hu_value_calibration(image, target_hu_range=(-200, 300)):
        """
        Map intensities into a common range to reduce vendor differences.

        NOTE(review): despite the name this is a statistical rescaling
        (clip to mean±3σ, then min-max to the target range), not a true HU
        calibration against air/water reference values — confirm intent.
        """
        mean_val = np.mean(image)
        std_val = np.std(image)
        calibrated = np.clip(image, mean_val - 3 * std_val, mean_val + 3 * std_val)
        lo, hi = calibrated.min(), calibrated.max()
        # BUGFIX: a constant image made (max - min) zero -> division by zero.
        # Map constant inputs to the lower bound of the target range.
        if hi == lo:
            return np.full_like(np.asarray(image, dtype=float), target_hu_range[0])
        calibrated = (calibrated - lo) / (hi - lo)
        return calibrated * (target_hu_range[1] - target_hu_range[0]) + target_hu_range[0]

    @staticmethod
    def adaptive_resampling(image, target_spacing=(1.0, 1.0, 1.0)):
        """
        Resample a SimpleITK image to a uniform target spacing, preserving
        its physical placement (origin and direction).
        """
        original_spacing = image.GetSpacing()
        original_size = image.GetSize()
        # New size keeps the physical extent: size * spacing is invariant.
        new_size = [int(round(osz * osp / tsp))
                    for osz, osp, tsp in zip(original_size, original_spacing, target_spacing)]
        resampler = sitk.ResampleImageFilter()
        resampler.SetSize(new_size)
        resampler.SetOutputSpacing(target_spacing)
        # BUGFIX: without copying origin/direction the resampled volume is
        # placed at the identity pose and no longer overlays the original.
        resampler.SetOutputOrigin(image.GetOrigin())
        resampler.SetOutputDirection(image.GetDirection())
        resampler.SetInterpolator(sitk.sitkLinear)
        return resampler.Execute(image)

    @staticmethod
    def vendor_specific_augmentation(image, label, vendor_id):
        """
        Vendor-aware augmentation: rotation and scaling ranges chosen per
        vendor; unknown vendors fall back to vendor_A's parameters.
        """
        augmentation_params = {
            'vendor_A': {'rotation_range': (-15, 15), 'scale_range': (0.9, 1.1)},
            'vendor_B': {'rotation_range': (-10, 10), 'scale_range': (0.95, 1.05)},
            'vendor_C': {'rotation_range': (-20, 20), 'scale_range': (0.85, 1.15)}
        }
        params = augmentation_params.get(vendor_id, augmentation_params['vendor_A'])
        # In-plane rotation; nearest-neighbour (order=0) keeps label values integral.
        rotation_angle = np.random.uniform(*params['rotation_range'])
        image = ndimage.rotate(image, rotation_angle, axes=(1, 2), reshape=False, order=1)
        label = ndimage.rotate(label, rotation_angle, axes=(1, 2), reshape=False, order=0)
        # BUGFIX: the scale factor was computed but never applied (dead code
        # only derived an unused new_shape).  Zoom both volumes consistently.
        scale_factor = np.random.uniform(*params['scale_range'])
        image = ndimage.zoom(image, scale_factor, order=1)
        label = ndimage.zoom(label, scale_factor, order=0)
        return image, label
```
### **6. 推理与后处理**
训练完成后,使用模型进行预测并应用后处理。
```python
from nnunet.inference.predict import predict_from_folder
from nnunet.postprocessing.connected_components import determine_postprocessing
def run_inference(input_folder, output_folder, model_folder):
    """
    Run nnUNet prediction on a folder of images and apply the postprocessing
    determined during training, if available.

    Parameters
    ----------
    input_folder : str  — cases named like case_XXXX_0000.nii.gz
    output_folder : str — where predicted segmentations are written
    model_folder : str  — trained-model folder that CONTAINS fold_X subfolders
    """
    folds = [0]  # which folds to ensemble
    save_npz = False
    num_threads_preprocessing = 6
    num_threads_nifti_save = 2
    predict_from_folder(
        model=model_folder,
        input_folder=input_folder,
        output_folder=output_folder,
        folds=folds,
        save_npz=save_npz,
        num_threads_preprocessing=num_threads_preprocessing,
        num_threads_nifti_save=num_threads_nifti_save,
        lowres_segmentations=None,
        part_id=0,
        num_parts=1,
        tta=False,  # test-time augmentation: slower, slightly more accurate
        mixed_precision=True,
        overwrite_existing=True,
        mode='normal',
        overwrite_all_in_gpu=None,
        step_size=0.5
    )
    # BUGFIX: determine_postprocessing() *computes* postprocessing from
    # validation data with ground truth and was called with a wrong
    # signature.  At test time the stored postprocessing.json must be
    # *applied* instead (removes spurious small connected components).
    from nnunet.postprocessing.connected_components import (
        load_postprocessing, apply_postprocessing_to_folder)
    pp_file = join(model_folder, "postprocessing.json")
    if os.path.exists(pp_file):
        for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
        apply_postprocessing_to_folder(output_folder, output_folder + "_postprocessed",
                                       for_which_classes, min_valid_obj_size)
    print(f"推理完成,结果保存在: {output_folder}")
# Usage example.
# BUGFIX: predict_from_folder expects the model folder that CONTAINS the
# fold_X subdirectories; the fold itself is selected through the `folds`
# argument inside run_inference, so "fold_0" must not be part of this path.
model_folder = join(os.environ['RESULTS_FOLDER'], "nnUNet", "3d_fullres",
                    f"Task{task_id:03d}_{task_name}")
input_folder = "/path/to/test/images"
output_folder = "/path/to/predictions"
run_inference(input_folder, output_folder, model_folder)
```
### **7. 2D CT切片处理**
如果您的CT数据以2D切片形式存在,需要转换为伪3D数据[ref_2]。
```python
import cv2
from PIL import Image
class CT2DToPseudo3DConverter:
    """Stack 2D CT slices into pseudo-3D NIfTI volumes for nnUNet."""

    @staticmethod
    def convert_2d_slices_to_nifti(slice_folder, output_nifti_path, slice_order='filename'):
        """
        Stack all 2D slices of one folder into a single 3D NIfTI file.

        slice_order: only 'filename' (lexicographic) ordering is implemented;
        the parameter is kept for future ordering strategies.
        Returns the stacked numpy volume.
        """
        # BUGFIX: '.dcm' files were excluded by the extension filter, which
        # made the DICOM branch below unreachable dead code.
        exts = ('.png', '.jpg', '.tif', '.dcm')
        slice_files = sorted(f for f in os.listdir(slice_folder) if f.endswith(exts))
        slices = []
        for slice_file in slice_files:
            slice_path = join(slice_folder, slice_file)
            if slice_file.endswith('.dcm'):
                # DICOM slice: read the raw pixel array
                import pydicom
                ds = pydicom.dcmread(slice_path)
                slice_data = ds.pixel_array
            else:
                img = cv2.imread(slice_path, cv2.IMREAD_GRAYSCALE)
                # BUGFIX: cv2.imread returns None on failure instead of raising.
                if img is None:
                    raise IOError(f"Failed to read slice: {slice_path}")
                slice_data = np.array(img)
            slices.append(slice_data)
        # Robustness: np.stack on an empty list raises a cryptic error.
        if not slices:
            raise ValueError(f"No slices found in {slice_folder}")
        volume = np.stack(slices, axis=-1)  # slices stacked along z
        # NOTE(review): an identity affine discards spacing/orientation
        # metadata — supply a real affine if the slice spacing is known.
        affine = np.eye(4)
        nifti_img = nib.Nifti1Image(volume, affine)
        nib.save(nifti_img, output_nifti_path)
        print(f"已保存伪3D NIfTI文件: {output_nifti_path}, 形状: {volume.shape}")
        return volume

    @staticmethod
    def create_pseudo_3d_dataset(image_slice_folders, label_slice_folders, output_dir):
        """Batch-convert paired image/label slice folders into a dataset."""
        for i, (img_folder, label_folder) in enumerate(zip(image_slice_folders, label_slice_folders)):
            case_id = f"case_{i:04d}"
            # Convert the image stack
            img_nifti = join(output_dir, "imagesTr", f"{case_id}_0000.nii.gz")
            CT2DToPseudo3DConverter.convert_2d_slices_to_nifti(img_folder, img_nifti)
            # Convert the matching label stack
            label_nifti = join(output_dir, "labelsTr", f"{case_id}.nii.gz")
            CT2DToPseudo3DConverter.convert_2d_slices_to_nifti(label_folder, label_nifti)
```
### **8. 关键注意事项**
在实际微调过程中,需要注意以下几点:
| 注意事项 | 解决方案 | 代码示例 |
|---------|---------|---------|
| **数据量不足** | 使用数据增强和迁移学习 | 加载预训练权重并设置 `freeze_encoder=True` |
| **显存限制** | 使用混合精度训练和梯度累积 | `fp16=True`, 调整`batch_size` |
| **类别不平衡** | 使用Dice损失和样本加权 | `batch_dice=True` |
| **过拟合风险** | 早停、权重衰减、Dropout | `trainer.num_epochs`适当减少 |
| **多厂商数据** | HU值校准和自适应预处理 | `MultiVendorCTProcessor`类 |
### **9. 完整微调流程封装**
最后,将整个流程封装为可执行的脚本:
```python
def complete_finetuning_pipeline():
    """
    End-to-end fine-tuning pipeline (orchestration sketch).

    NOTE(review): prepare_data(), run_preprocessing() and evaluate_model()
    are user-supplied wrappers around sections 1-3 of this guide; they are
    not defined in this file.
    """
    # 1. Data preparation
    print("步骤1: 数据格式转换...")
    prepare_data()
    # 2. Preprocessing
    print("步骤2: 数据预处理...")
    run_preprocessing()
    # 3. Fine-tuning
    print("步骤3: 模型微调训练...")
    trainer = train_custom_model()
    # 4. Evaluation
    print("步骤4: 模型评估...")
    evaluate_model(trainer)
    # 5. Inference
    print("步骤5: 在新数据上推理...")
    # BUGFIX: run_inference (section 6) requires explicit input/output/model
    # folders; the original zero-argument call raised TypeError.
    run_inference("/path/to/test/images", "/path/to/predictions",
                  join(os.environ['RESULTS_FOLDER'], "nnUNet", "3d_fullres",
                       f"Task{task_id:03d}_{task_name}"))
    print("微调流程完成!")

if __name__ == "__main__":
    # Task parameters (module-level so the pipeline helpers can read them)
    task_id = 501
    task_name = "MyLiverCT"
    complete_finetuning_pipeline()
```
通过以上代码,您可以完成从数据准备到模型微调的全流程。关键点包括:正确格式化数据、配置`dataset.json`、加载预训练权重、调整训练参数以适应小数据集,以及处理多厂商CT数据的特殊需求。在实际应用中,建议根据具体数据特性和计算资源调整训练参数。