UNet/UNet++实战:从零构建多类别分割数据管道与模型训练 1. 多类别分割任务入门指南第一次接触图像分割任务时我完全被那些专业术语搞晕了。简单来说多类别分割就是让计算机识别图片中不同类别的物体并用不同颜色标记出来。比如在医疗影像中我们可能需要同时识别肝脏、肾脏和脾脏在工业质检中可能要区分产品表面的划痕、凹陷和污渍。UNet和UNet是处理这类任务的明星模型。它们最大的优势在于能够很好地捕捉图像中的细节特征这对分割任务至关重要。我刚开始用UNet做细胞分割时发现它比传统方法准确率高出一大截从此就爱上了这个架构。要完成一个完整的分割项目我们需要走完这几个关键步骤准备数据→制作标签→构建模型→训练调优→测试验证。听起来简单但每个环节都有不少坑等着你。下面我就把自己踩过的坑和总结的经验分享给大家。2. 数据准备与标注实战2.1 数据收集与整理数据是模型的食物喂什么数据决定了模型能学到什么。我建议至少准备1000张以上的图片尺寸最好保持一致。如果是医疗影像256x256或512x512都是常用尺寸工业质检可能要求更高分辨率。文件目录建议这样组织dataset/ ├── images/ # 原始图像 ├── masks/ # 标注图像 ├── test/ # 测试集 └── checkpoints/ # 模型保存位置2.2 标注工具使用技巧Labelme是我最常用的标注工具它支持多边形标注特别适合不规则形状。安装很简单pip install labelme labelme # 启动图形界面标注时要注意几点每个类别使用不同的标签名尽量贴近物体边缘标注复杂物体可以用多个多边形组合保存为JSON格式它会记录所有标注点的坐标标注完成后你会得到一堆.json文件每个对应一张图片的标注信息。这些文件需要转换成模型能理解的mask图像。3. 标签制作与数据处理3.1 JSON转Mask实战这是最容易出错的一步。我们需要把JSON中的多边形信息转换成单通道的灰度图其中每个像素值代表类别索引。比如背景 0类别1 1类别2 2...import cv2 import numpy as np import json # 类别定义 categories [背景, 圆形, 矩形] # 加载原图获取尺寸 img cv2.imread(image.png) height, width img.shape[:2] # 创建空白mask mask np.zeros((height, width), dtypenp.uint8) # 处理每个标注区域 with open(image.json) as f: label_data json.load(f) for shape in label_data[shapes]: label shape[label] points np.array(shape[points], dtypenp.int32) cv2.fillPoly(mask, [points], categories.index(label)) # 保存为PNG格式 cv2.imwrite(mask.png, mask)3.2 数据增强技巧数据量不足时增强是救命稻草。我常用的增强包括随机旋转-30°到30°水平/垂直翻转亮度对比度调整高斯噪声弹性变形使用albumentations库可以轻松实现import albumentations as A transform A.Compose([ A.Rotate(limit30, p0.5), A.HorizontalFlip(p0.5), A.RandomBrightnessContrast(p0.2), A.GaussNoise(var_limit(10,50), p0.3) ]) augmented transform(imageimg, maskmask) aug_img augmented[image] aug_mask augmented[mask]4. UNet/UNet模型搭建4.1 基础UNet实现UNet的结构像是一个对称的沙漏先下采样提取特征再上采样恢复尺寸。核心是中间的跳跃连接能把浅层细节和深层语义信息结合起来。用PyTorch实现基础UNetimport torch import torch.nn as nn class DoubleConv(nn.Module): (卷积 BN ReLU) * 2 def __init__(self, in_ch, out_ch): super().__init__() self.double_conv nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding1), nn.BatchNorm2d(out_ch), nn.ReLU(inplaceTrue), nn.Conv2d(out_ch, out_ch, 3, padding1), nn.BatchNorm2d(out_ch), nn.ReLU(inplaceTrue) ) def forward(self, x): return self.double_conv(x) class UNet(nn.Module): def __init__(self, n_channels, n_classes): super().__init__() # 下采样路径 self.inc DoubleConv(n_channels, 64) self.down1 Down(64, 128) self.down2 Down(128, 256) self.down3 Down(256, 512) self.down4 Down(512, 1024) # 上采样路径 self.up1 Up(1024, 512) self.up2 Up(512, 256) self.up3 Up(256, 128) self.up4 Up(128, 64) self.outc nn.Conv2d(64, n_classes, 1) def forward(self, x): x1 self.inc(x) x2 self.down1(x1) x3 self.down2(x2) x4 self.down3(x3) x5 self.down4(x4) x self.up1(x5, x4) x self.up2(x, x3) x self.up3(x, x2) x self.up4(x, x1) logits self.outc(x) return logits4.2 UNet改进方案UNet在UNet基础上增加了密集跳跃连接让不同深度的特征能够更好地融合。这就像是在原有道路上修建了多条高架桥信息流通更顺畅了。关键改进点增加了子网络之间的密集连接采用深度监督机制支持模型剪枝实现UNet的核心模块class UNetPlusPlus(nn.Module): def __init__(self, n_channels, n_classes, deep_supervisionFalse): super().__init__() self.deep_supervision deep_supervision # 编码器部分 self.conv0_0 VGGBlock(n_channels, 64) self.conv1_0 VGGBlock(64, 128) self.conv2_0 VGGBlock(128, 256) self.conv3_0 VGGBlock(256, 512) self.conv4_0 VGGBlock(512, 1024) # 解码器部分 self.conv0_1 VGGBlock(64128, 64) self.conv1_1 VGGBlock(128256, 128) self.conv2_1 VGGBlock(256512, 256) self.conv3_1 VGGBlock(5121024, 512) self.conv0_2 VGGBlock(64*2128, 64) self.conv1_2 VGGBlock(128*2256, 128) self.conv2_2 VGGBlock(256*2512, 256) self.conv0_3 VGGBlock(64*3128, 64) self.conv1_3 VGGBlock(128*3256, 128) self.conv0_4 VGGBlock(64*4128, 64) # 输出层 self.final nn.Conv2d(64, n_classes, kernel_size1) if deep_supervision: self.ds_final1 nn.Conv2d(64, n_classes, kernel_size1) self.ds_final2 nn.Conv2d(64, n_classes, kernel_size1) self.ds_final3 nn.Conv2d(64, n_classes, kernel_size1)5. 模型训练与调优5.1 损失函数选择多类别分割常用的损失函数有交叉熵损失简单直接但对类别不平衡敏感Dice损失适合小目标分割Lovász-Softmax基于IOU的损失函数我推荐结合使用交叉熵和Dice损失class DiceBCELoss(nn.Module): def __init__(self, weightNone, size_averageTrue): super().__init__() def forward(self, inputs, targets, smooth1): # 交叉熵部分 bce F.binary_cross_entropy_with_logits(inputs, targets) # Dice系数部分 inputs torch.sigmoid(inputs) intersection (inputs * targets).sum() dice (2.*intersection smooth)/(inputs.sum() targets.sum() smooth) return bce (1 - dice)5.2 训练技巧学习率策略使用余弦退火配合热重启optimizer torch.optim.Adam(model.parameters(), lr1e-4) scheduler torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_010, T_mult1, eta_min1e-5)早停机制当验证集损失连续5个epoch不下降时停止训练混合精度训练可以节省显存并加速训练scaler torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): outputs model(inputs) loss criterion(outputs, masks) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()类别权重对样本少的类别给予更高权重class_weights torch.tensor([0.1, 1.0, 2.0]) # 背景、类别1、类别2 criterion nn.CrossEntropyLoss(weightclass_weights)6. 模型评估与推理6.1 评估指标常用的分割评估指标IOU交并比预测区域与真实区域的重叠度Dice系数类似IOU但对小目标更敏感像素准确率整体分类准确率计算IOU的代码def iou_score(output, target): output torch.sigmoid(output) 0.5 target target 0.5 intersection (output target).float().sum() union (output | target).float().sum() return (intersection 1e-6) / (union 1e-6)6.2 推理部署训练完成后可以使用以下代码进行单张图片推理def predict(model, image_path, save_path): # 加载图像 img cv2.imread(image_path) img cv2.cvtColor(img, cv2.COLOR_BGR2RGB) original_size img.shape[:2] # 预处理 img cv2.resize(img, (256, 256)) img img / 255.0 img torch.from_numpy(img).permute(2,0,1).float() img img.unsqueeze(0).to(device) # 推理 model.eval() with torch.no_grad(): output model(img) # 后处理 output F.softmax(output, dim1) pred torch.argmax(output, dim1).squeeze().cpu().numpy() pred cv2.resize(pred, (original_size[1], original_size[0]), interpolationcv2.INTER_NEAREST) # 可视化保存 colored_mask np.zeros((*pred.shape, 3), dtypenp.uint8) colored_mask[pred 1] [255, 0, 0] # 类别1红色 colored_mask[pred 2] [0, 0, 255] # 类别2蓝色 cv2.imwrite(save_path, colored_mask)7. 实际项目中的经验分享在医疗影像分割项目中我发现这几个技巧特别有用预处理很重要CT/MRI图像建议先做窗宽窗位调整再用CLAHE增强对比度处理类别不平衡对小目标使用OHEM在线难例挖掘策略模型集成训练3-5个不同初始化的模型取预测结果的平均值测试时增强对测试图像做多种变换旋转、翻转将预测结果平均工业质检项目中需要注意使用高分辨率图像时可以先裁剪再处理对于微小缺陷可以放大局部区域再输入网络后处理时使用形态学操作去除噪声最后给初学者的建议是先从简单的UNet开始跑通整个流程后再尝试UNet等复杂模型。记得保存每个实验的配置和结果方便后期分析比较。我在实际项目中遇到过模型突然性能下降的情况后来发现是数据增强过度导致的所以任何改动都要谨慎评估。