基于tf.keras,实现多标签分类CNN模型。
logs
文件夹,存放日志文件;训练完毕会出现models
文件夹,存放模型;configs.py
并进行修改,此为参数配置文件;utils/check_label_file.py
,确保标签文件中的图片真实可用;python run.py
,会根据配置文件configs.py
进行训练/测试/模型转换等。README.md
;1_learning_note
下的note;multi_label
下的trainer.py
里的__init__
函数,把整体模型串起来;run.py
文件,结合着看configs.py
。A_learning_notes
: README后,先查看本部分了解本项目大致结构;backbone
: 模型的骨干网络脚本;dataset
: 数据集构造脚本;
dataset_util.py
: 使用tf.image API进行图像数据增强,然后用tf.data进行数据集构建;file_util.py
: 以txt标签文件的形式,构造tf.data数据集用于训练;tfrecord_util.py
: 读取txt标签文件,写tfrecord,然后读取tfrecord为数据集用于训练;images
: 项目图片;logs
: 存放训练过程中的日志文件和tensorboard文件(当前可能不存在);models
: 存放训练好的模型文件(当前可能不存在);multi_label
: 多标签分类模型构建脚本;
classifier_loss.py
: 多标签分类的损失函数,包含多种损失函数:focal loss
、GHM
等;classifier_model.py
: 多标签分类模型,负责调用backbone
里的骨干网络和本脚本中的多标签head
组成整体模型;train.py
: 模型训练接口,集成模型构建/编译/训练/debug/预测、数据集构建等功能;utils
: 一些工具脚本;
generate_txt
: 扫描指定路径下的图片数据,生成训练、测试等label.txt(根据实际项目而定,当前可能不存在);check_label_file.py
: 在训练前检查训练集,确保标签文件中的图片真实可用;draw_tools.py
: 模型训练完进行测试时,绘制每个类别的混淆图;logger_callback.py
: 日志打印的keras回调函数;radam.py
: RAdam算法的tf.keras优化器实现;configs.py
: 配置文件;run.py
: 启动脚本;在多标签多分类模型基础上,添加功能:
loss函数改造:
label smoothing
: 标签平滑。focal loss
: 给每个样本的分类loss增加一个因子项,降低分类误差小的样本的影响,解决难易样本问题。
gradient harmonizing mechanism (GHM)
:
根据样本梯度密度曲线(这里的梯度是梯度范数,并且不是所有网络参数的梯度,而是最后一层的回传梯度),
取反得到梯度密度调和参数(和平衡多类别数据集一个意思,只不过这里不是按类别来平衡,而是按梯度区间来平衡),
再乘以梯度以调整梯度贡献曲线,从而降低高密度区域的梯度贡献比例,提升低密度区域的梯度贡献比例。
原论文insight: 对网络训练而言,梯度是最重要的东西,而网络训练不好,也是因为梯度没调节好。 focal loss认为前背景不平衡问题,本质为难易样本不平衡问题,从而调节样本的梯度贡献,一定程度上解决了背景问题。 作者认为,类别不平衡、难易样本不平衡,造成的本质驱动是梯度不平衡。 然后通过绘制训练好的模型在样本空间上的梯度分布曲线,发现小梯度和大梯度都是高密度区域, (作者认为小梯度对应易学习样本,大密度对应异常样本); 然后绘制正常loss和focal loss梯度贡献曲线,发现正常loss中,高密度区域的梯度贡献度很高, 而focal loss中,小梯度的高密度区域被因子项惩罚而降低梯度贡献度, 但大梯度的高密度区域的梯度贡献度依然很高。 作者认为focal loss平衡了一部分梯度贡献度,所以使得训练低密度的中间梯度的梯度贡献度影响提升, 提升了算法性能;同时,认为focal loss并没有从本质出发,所以还有残留问题(异常样本大梯度的高密度区域)。 然后提出了GHM,从梯度分布和梯度贡献角度出发,提升网络训练效果。
分离conv层的权重衰减项$\lambda{conv}$ 和 BN层gamma的权重衰减项$\lambda{gamma}$
TIPS:其他试过但基本无效的手段包括: 继续加大weight decay权重,BN层的gamma不加weight decay,BN层的beta加weight decay, 全连接层加dropout,focal loss,从Adam训练改为SGDM,加warmup。
[1] L2 Regularization versus Batch and Weight Normalization
[2] Towards Understanding Regularization in Batch Normalization
[3] Learning Efficient Convolutional Networks through Network Slimming
[4] Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour
[5] Large Batch Training of Convolutional Networks
[6] On the Variance of the Adaptive Learning Rate and Beyond
[7] Rethinking the inception architecture for computer vision
[8] Gradient Harmonized Single-stage Detector
[9] Data-Driven Sparse Structure Selection for Deep Neural Networks
[10] Rethinking the Value of Network Pruning