1. COVID-19 CT图像分类项目概述
作为一名刚接触深度学习的开发者,我一直在寻找合适的练手项目来提升自己的实战能力。医学影像分类这个方向引起了我的兴趣,特别是COVID-19检测这个具有实际应用价值的场景。这个项目使用PyTorch框架和DenseNet模型,对COVID-19 CT图像进行二分类(阳性/阴性),非常适合想要入门医学图像分析的开发者。
项目最大的特点是:
- 使用真实的医学影像数据集(COVID-CT)
- 采用迁移学习技术,基于预训练的DenseNet模型
- 完整的项目流程:从数据加载到模型评估
- 包含了我在复现过程中遇到的各种坑和解决方案
2. 项目环境与技术栈
2.1 开发环境配置
这个项目需要以下环境配置:
# 基础环境 Python 3.8+ CUDA 11.7 (如需GPU加速) cuDNN 8.5+ # 核心依赖 torch==2.10.0+cu126 torchvision==0.25.0+cu126 torchxrayvision==1.4.0注意:安装PyTorch时,建议直接使用官方命令获取CUDA版本。我在实践中发现,使用某些镜像源可能会导致安装成CPU-only版本。
2.2 关键技术组件
| 技术领域 | 具体实现 |
|---|---|
| 深度学习框架 | PyTorch |
| 模型架构 | DenseNet-121 (预训练) |
| 数据处理 | TorchVision Transforms |
| 可视化工具 | TensorBoard, Matplotlib |
| 评估指标 | 准确率、混淆矩阵、分类报告 |
3. 数据集处理与分析
3.1 数据集介绍
我们使用的COVID-CT数据集包含741张CT扫描图像:
- COVID-19阳性:347张
- COVID-19阴性:394张
数据集按照7:2:1的比例划分:
- 训练集:423张
- 验证集:116张
- 测试集:202张
3.2 数据预处理流程
医学影像数据预处理是项目成功的关键。我们设计了两种不同的transform管道:
# 训练集数据增强 train_transformer = transforms.Compose([ transforms.Resize(256), transforms.RandomResizedCrop(240, scale=(0.5, 1.0)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # 验证/测试集处理 val_transformer = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(240), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])这样设计的原因是:
- 训练时通过随机裁剪和翻转增加数据多样性
- 评估时使用确定性变换保证结果可复现
- 标准化使用ImageNet的均值和方差(因为使用预训练模型)
3.3 自定义数据集类
我们实现了CovidCTDataset类来加载数据:
class CovidCTDataset(Dataset): def __init__(self, root_dir, txt_COVID, txt_NonCOVID, transform=None): self.img_list = [] # 加载COVID阳性样本 covid_list = [[os.path.join(root_dir, 'CT_COVID', item), 0] for item in read_txt(txt_COVID)] # 加载COVID阴性样本 noncovid_list = [[os.path.join(root_dir, 'CT_NonCOVID', item), 1] for item in read_txt(txt_NonCOVID)] self.img_list = covid_list + noncovid_list self.transform = transform def __getitem__(self, idx): img_path, label = self.img_list[idx] image = Image.open(img_path).convert('RGB') if self.transform: image = self.transform(image) return {'img': image, 'label': label}这个设计允许我们:
- 灵活加载不同划分的数据集
- 保持图像和标签的对应关系
- 支持各种transform操作
4. 模型构建与训练
4.1 DenseNet模型选择
我们选择DenseNet-121作为基础模型,原因如下:
- 密集连接结构适合医学图像分析
- 预训练权重(在ImageNet上)提供良好的特征提取能力
- 模型深度适中,适合我们的数据规模
import torchxrayvision as xrv model = xrv.models.DenseNet(num_classes=2, in_channels=3).to(device)注意:这里使用torchxrayvision提供的医学预训练模型,比普通DenseNet更适合医疗图像分析。
4.2 训练流程实现
训练过程的关键组件:
# 损失函数 criterion = nn.CrossEntropyLoss() # 优化器 optimizer = optim.Adam(model.parameters(), lr=0.0003) # 训练循环 for epoch in range(100): model.train() for batch in train_loader: data, target = batch['img'].to(device), batch['label'].to(device) optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() # 验证 model.eval() with torch.no_grad(): # 验证代码...训练中的关键点:
- 使用Adam优化器,学习率设为0.0003(经过实验确定)
- 每个epoch后在验证集上评估
- 使用TensorBoard记录训练过程
4.3 训练可视化
通过TensorBoard可以监控训练过程:
tensorboard --logdir=runs --port=6008从曲线可以看出:
- 训练损失稳步下降
- 验证准确率逐渐提升
- 没有出现过拟合现象
5. 模型评估与结果分析
5.1 测试集性能
在测试集上的评估结果:
测试集准确率: 0.8762 precision recall f1-score support COVID 0.86 0.89 0.87 95 NonCOVID 0.89 0.86 0.88 107 accuracy 0.88 202 macro avg 0.88 0.88 0.88 202 weighted avg 0.88 0.88 0.88 2025.2 混淆矩阵分析
从混淆矩阵可以看出:
- 模型对COVID阳性的识别率(召回率)为89%
- 对阴性的识别率为86%
- 没有明显的类别偏向性
5.3 性能优化建议
根据评估结果,可以考虑以下优化方向:
- 尝试更复杂的数据增强(如随机旋转、颜色抖动)
- 调整类别权重,处理数据不平衡问题
- 使用更大型的预训练模型(如DenseNet-169)
- 尝试不同的学习率调度策略
6. 常见问题与解决方案
6.1 PyTorch安装问题
问题描述:使用清华镜像源安装PyTorch时,可能会错误安装CPU版本。
解决方案:
# 推荐使用官方命令安装 pip install torch torchvision --index-url https://download.pytorch.org/whl/cu1176.2 Matplotlib兼容性问题
问题描述:PyCharm中Matplotlib图像无法正常显示。
解决方案:
- 打开PyCharm设置
- 进入"Tools" → "Python Scientific"
- 取消勾选"Show plots in tool window"
6.3 CUDA内存不足
问题描述:训练时出现CUDA out of memory错误。
解决方案:
- 减小batch size(本项目使用16)
- 使用梯度累积技巧
- 尝试混合精度训练
7. 关键代码细节解析
7.1 数据类型处理
在计算损失函数时,需要注意数据类型:
# target.long()的作用 loss = criterion(output, target.long()).long()确保标签是整数类型,虽然PyTorch通常会自动转换,但显式转换更安全。
7.2 设备转移问题
模型和数据需要转移到相同设备:
# 模型自动输出与输入相同的设备 output = model(data) # 自动在GPU上计算 loss = criterion(output, target) # target也需要在GPU上7.3 梯度与数值转换
训练过程中需要注意:
# .detach()与.item()的区别 loss_value = loss.detach().cpu().item().detach():切断梯度计算,但仍保持Tensor类型.item():转换为Python标量数值
8. 项目总结与扩展方向
通过这个项目,我深入学习了:
- PyTorch完整训练流程的实现
- 医学图像处理的特有方法
- 迁移学习在实际问题中的应用
- 模型评估与结果分析方法
项目后续可以扩展的方向:
- 部署为Web应用,提供在线检测服务
- 尝试3D CNN处理CT序列图像
- 加入临床数据(如患者年龄、症状)进行多模态分析
这个项目已经开源在GitHub: covid-ct-classification ,包含完整代码和预训练模型,欢迎大家一起改进和完善。