042、训练技巧大揭秘:学习率调度、损失函数组合与梯度裁剪的调参心法
上周五晚上,我盯着训练日志里那条震荡得像心电图一样的PSNR曲线,差点把咖啡泼到键盘上。模型跑了72小时,损失值在0.023到0.047之间反复横跳,验证集上的视觉效果时好时坏——某些patch锐利得像刀锋,另一些却糊成一团。这是典型的训练不稳定症状,而罪魁祸首往往藏在三个地方:学习率在某个阶段开始失控,损失函数里某个分量在暗中捣鬼,或者梯度在某次反向传播中爆炸了。
今天这篇笔记,我就把这几年在超分训练中踩过的坑、试过的药方,原原本本摊开来讲。
学习率调度:别让模型在高原上原地踏步
超分模型的学习率调度,和分类任务完全是两码事。分类模型通常喜欢余弦退火或者阶梯式下降,但超分任务有个特点:前期需要快速收敛到合理的特征空间,后期需要极其精细的微调来恢复高频细节。
我最早犯的错误是直接用PyTorch的StepLR,每30个epoch把学习率乘以0.1。结果呢?模型在某个阶段突然“失忆”,之前学到的纹理细节全部丢失,PSNR直接跳水2个dB。后来才明白,超分模型的参数空间里,局部极小值非常密集,粗暴的阶梯下降会让模型直接跳出已经找到的好区域。
现在我的标配是余弦退火+预热(Cosine Annealing with Warm Restarts)。具体做法:前5个epoch用线性预热,从1e-6慢慢升到1e-4,让模型先“热身”而不是一上来就猛冲。然后使用余弦周期,周期长度设为总epoch数的1/3,每次重启时学习率不重置到初始值,而是乘以0.8的衰减因子。这样模型每次“重启”时,其实是在上一个周期找到的最优点附近继续搜索,既避免了陷入局部最优,又不会丢失已经学到的细节。
这里有个小技巧:预热阶段的初始学习率不要设成0。我试过从0开始线性增长,结果前两个epoch的损失几乎不下降,白白浪费计算资源。改成1e-6后,预热阶段就能看到损失稳步下降。
损失函数组合:L1、感知损失、GAN损失的配比实验
超分的损失函数设计,本质上是在“保真度”和“感知质量”之间找平衡。L1损失保证像素级准确,感知损失让特征空间对齐,GAN损失则负责生成逼真的纹理。但三者的配比,我调了整整两个月才找到门道。
先说结论:L1损失是地基,感知损失是框架,GAN损失是装修。地基不稳,后面全是空中楼阁。
我踩过最大的坑是过早引入GAN损失。当时为了追求视觉效果,在训练初期就把GAN损失的权重设为0.1,结果模型直接崩了——生成器被判别器压着打,输出全是噪声。后来改成两阶段训练:前50个epoch只用L1+感知损失(权重比10:1),等PSNR稳定在28dB以上再引入GAN损失,权重从0.01开始慢慢增加到0.05。
感知损失的选择也有讲究。VGG16的relu5_1层对高频细节敏感,但容易引入伪影;relu3_3层更关注整体结构。我的经验是:对于4倍超分,用relu4_2层效果最好;对于2倍超分,relu3_3更合适。别问我为什么,这是跑了20组对比实验试出来的。
还有一个容易被忽略的点:损失函数的数值范围。L1损失通常在0.01量级,感知损失在0.1量级,GAN损失可能在1量级。如果不做归一化,梯度更新会被某个分量主导。我习惯在每次迭代后打印各个损失分量的数值,确保它们在同一数量级。如果某个分量突然变大,说明训练出了问题,需要立即暂停检查。
梯度裁剪:不是越大越好,也不是越小越好
梯度裁剪这个操作,在超分任务里比分类任务敏感得多。分类模型梯度爆炸通常是因为标签错误或者网络太深,但超分模型的梯度爆炸往往发生在纹理密集的区域——比如草地、头发、布料这些高频细节丰富的地方。
我刚开始用torch.nn.utils.clip_grad_norm_,把max_norm设成1.0,结果训练速度慢得像蜗牛,因为大部分梯度都被裁剪了,模型几乎学不到东西。后来改成10.0,又发现某些batch的梯度值达到100以上,裁剪后梯度方向被严重扭曲,模型开始震荡。
最终找到的平衡点是5.0。但这个值不是固定的,需要根据模型规模和batch size动态调整。一个经验公式:max_norm = sqrt(模型参数量/1e6) * 2.0。对于EDSR这种40M参数的模型,max_norm大概在8.9左右;对于RCAN这种15M参数的模型,max_norm在5.5左右。
裁剪方式也有讲究。clip_grad_norm_按范数裁剪,适合大部分场景;clip_grad_value_按值裁剪,适合处理极端异常值。我的习惯是先用norm裁剪,如果训练日志里频繁出现“gradient norm exceeded”的警告,再换成value裁剪,把每个梯度值限制在[-1.0, 1.0]之间。
实战中的调参心法
说了这么多理论,最后分享几个我在实际项目中验证过的调参流程:
第一步:先跑一个mini版。用1/10的数据量、1/4的batch size,快速验证学习率和损失函数配比是否合理。如果mini版在50个epoch内PSNR能超过25dB,说明方向对了。
第二步:监控梯度统计。每100个iteration打印一次梯度的均值、方差和最大范数。如果梯度均值长期为负,说明学习率太大,模型在震荡;如果梯度方差突然增大,说明某个batch包含异常样本,需要检查数据增强策略。
第三步:可视化中间特征。别只看最终的PSNR和SSIM。把模型中间层的特征图可视化出来,看看哪些通道在激活,哪些通道是死的。如果某个通道全是0,说明该通道的梯度消失了,需要调整初始化或者增加残差连接。
第四步:保存检查点时要保存优化器状态。我吃过一次大亏:模型跑了200个epoch,服务器宕机,只保存了模型权重,优化器的动量状态全丢了。重新加载后,学习率调度器从0开始,模型直接崩了。现在我的保存逻辑是:每10个epoch保存一个完整检查点,包含模型权重、优化器状态、学习率调度器状态、当前epoch数。
最后一条个人经验:不要迷信论文里的超参数。那些在DIV2K上表现完美的参数,换到你的数据集上可能完全失效。我的做法是:把论文里的参数作为起点,然后做一组网格搜索——学习率从1e-5到1e-3,损失权重从0.01到1.0,梯度裁剪阈值从1.0到10.0。虽然耗时,但这是最可靠的方法。
调参这件事,本质上是在和模型的“脾气”打交道。每个模型都有自己的性格,有的激进(学习率稍大就震荡),有的保守(梯度裁剪太狠就学不动)。你需要做的,不是找到一个万能公式,而是学会读懂训练日志里那些数字背后的故事。当你能从PSNR曲线的抖动中判断出是学习率问题还是梯度问题,从损失分量的变化中看出是感知损失过强还是GAN损失太弱,你就真正入门了。