GroupSoftmax cross entropy loss function is implemented for training with multiple different benchmark datasets. We trained a 83 classes detection model by using COCO and CCTSDB.
在工业界的实际生产环境中,经常会面临如下四个问题:
为了解决上述四个问题,我们提出了 GroupSoftmax 交叉熵损失函数,如下图所示,与 Softmax 交叉熵损失函数相比, GroupSoftmax 交叉熵损失函数允许类别 K 和类别 j 发生合并,形成一个新的组合类别 g ,当训练样本 y 的真实标签为组合类别 g 时,也能够计算出类别 K 和类别 j 的对应梯度,完成网络权重更新。理论上, GroupSoftmax 交叉熵损失函数能够兼容任意数量、任意标注标准的多个数据集联合训练。
从公式也可以看出, GroupSoftmax 损失函数是 Softmax 损失函数的一种推广,一种更复杂也更加灵活的表达,可以自由的发生类别合并。当群组类别 g 中只包含 m 单独一个类别时, GroupSoftmax 损失函数退化为 Softmax 损失函数。为了验证 GroupSoftmax 交叉熵损失函数的有效性,我们利用 GroupSoftmax 交叉熵损失函数在COCO和CCTSDB数据集上进行了联合训练,得到了一个83类检测器。有趣的是,模型不仅有83类检测效果,在coco_minival2014测试集上的表现比原来80类检测器反而会好一些。也就是说我们利用了一个与COCO无关的CCTSDB数据集,在相同参数下,Faster RCNN算法的检测效果由原来的38.6提高到了39.3,提高了0.7个点。我们同时训练了一个83类Trident*模型,6个epoch训练周期在coco_minival2014测试集上mAP指标为44.0,进一步验证了 GroupSoftmax 交叉熵损失函数的有效性。从理论上而言, GroupSoftmax 交叉熵损失函数能够支持任意标注标准的数据集进行联合训练。
Model | Backbone | Head | GroupSoftmax | Num Classes | Train Schedule | FP16 | AP | AP50 | AP75 | APs | APm | APl | Link |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Faster-SyncBN | R101v2-C4 | C5-256ROI | no | 80 | 1X | no | 38.6 | - | - | - | - | - | model |
Faster-SyncBN | R101v2-C4 | C5-256ROI | yes | 83 | 1X | yes | 39.3 | 59.9 | 42.3 | 21.0 | 44.1 | 53.3 | - |
Trident* | R101v2-C4 | C5-128ROI | yes | 83 | 1X | yes | 44.0 | 64.9 | 48.4 | 29.0 | 47.8 | 57.6 | - |
GroupSoftmax 用法参考groupsoftmax_faster_r101v2c4_c5_256roi_syncbn_1x.py中的GroupParam设置,下面举例说明用法。
假设有3个数据集:数据集A、数据集B、数据集C,它们的标注细节情况如下:
考虑上述的3个数据集,我们能够定义一个最精细的6类检测任务,类别分别为 {行人、公交车、电动车、自行车、牛、狗} ,想要通过上述的3个数据集联合训练得到一个6分类的检测模型,应该修改RPN网络和Head网络中的分类数量,以及对应的group_id信息。对于某个数据集中未标注的类别,group_id为0,意味着等同于背景类,本质是某个类别未标注可以理解为将某个类别标注为背景。网络修改细节如下:
rpnvx = np.array([0, 1, 2, 3], dtype=np.float32) # rpn 4 classes
rpnva = np.array([0, 1, 2, 0], dtype=np.float32) # rpn group_id of DATASET A
rpnvb = np.array([0, 1, 0, 3], dtype=np.float32) # rpn group_id of DATASET B
rpnvc = np.array([0, 1, 1, 1], dtype=np.float32) # rpn group_id of DATASET C
boxvx = np.array([0, 1, 2, 3, 4, 5, 6], dtype=np.float32) # box 7 classes
boxva = np.array([0, 1, 2, 3, 3, 5, 0], dtype=np.float32) # box group_id of DATASET A
boxvb = np.array([0, 1, 2, 3, 4, 0, 6], dtype=np.float32) # box group_id of DATASET B
boxvc = np.array([0, 1, 1, 1, 1, 1, 1], dtype=np.float32) # box group_id of DATASET C
def gtclass2rpn(gtclass):
gtclass[gtclass < 5] = 1 #前景1
gtclass[gtclass == 5] = 2 #前景2
gtclass[gtclass == 6] = 3 #前景3
return gtclass
如果CCTSDB中同时标注了person类别和3类交通标注,也即CCTSDB和COCO中都标注了person类别,则应该做出改动的地方有三个:
RPN的分类任务应该由3分类修改为4分类,因为此时有4种情况,分别为: { 背景、前景1、前景2、前景3 } ,其中 前景1 为COCO和CCTSDB中都标注了的person, 前景2 为COCO中的其他79个类别, 前景3 为CSTSDB中的3类交通标志。所以应该将GroupParam中的rpnvx信息,由原来的:
rpnv0 = np.array([0, 1, 2], dtype=np.float32) # rpn 3 classes
rpnv1 = np.array([0, 1, 0], dtype=np.float32) # COCO benchmark
rpnv2 = np.array([0, 0, 2], dtype=np.float32) # CCTSDB benchmark
修改为:
rpnv0 = np.array([0, 1, 2, 3], dtype=np.float32) # rpn 4 classes
rpnv1 = np.array([0, 1, 2, 0], dtype=np.float32) # COCO benchmark
rpnv2 = np.array([0, 1, 0, 3], dtype=np.float32) # CCTSDB benchmark
前景1在COCO和CCTSDB中都进行标注了。而前景3在COCO中未标注,所以前景3在COCO的样本中会与背景类组成一个group组类别。而前景2在CCTSDB中未标注,所以前景2在CCTSDB的样本中会与背景类组成一个group组类别。本质是某个类别未标注可以理解为将某个类别标注为背景。
修改RPN网络的类别映射方法gtclass2rpn,原来是3分类网络,类别映射为: {0 ==> 0, [1,2,...,80] ==> 1, [81,82,83] ==> 2} 。现在是4分类网络,类别映射应该修改为: {0 ==> 0, 1 ==> 1, [2,3,...,80] ==> 2, [81,82,83] ==> 3} 。对应的gtclass2rpn方法应该由原来的:
def gtclass2rpn(gtclass):
class_gap = 80
gtclass[gtclass > class_gap] = -1
gtclass[gtclass > 0] = 1
gtclass[gtclass < 0] = 2
return gtclass
修改为:
def gtclass2rpn(gtclass):
class_gap = 80
gtclass[gtclass > class_gap] = -1
gtclass[gtclass == 1] = 1
gtclass[gtclass > 1] = 2
gtclass[gtclass < 0] = 3
return gtclass
在Head网络中,跟原来唯一的区别是CCTSDB中标注了person类别,因为person类别在COCO中类别id为1,所以对应的,在CCTSDB的boxv2中,id为1的类别也由原来的0修改为1,因为已经标注了,不再与背景类0组成一个group进行训练。修改之后如下:
# box 83 classes
boxv0 = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, \
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, \
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, \
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, \
61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, \
76, 77, 78, 79, 80, 81, 82, 83], dtype=np.float32)
#COCO benchmark
boxv1 = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, \
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, \
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, \
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, \
61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, \
76, 77, 78, 79, 80, 0, 0, 0 ], dtype=np.float32)
#CCTSDB benchmark
boxv2 = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \
0, 0, 0, 0, 0, 81, 82, 83], dtype=np.float32)
SimpleDet contains a lot of C++ operators not in MXNet offical repo, so one has to build MXNet from scratch. Please refer to INSTALL.md more details
SimpleDet requires groundtruth annotation organized as following format
[
{
"gt_class": (nBox, ),
"gt_bbox": (nBox, 4),
"flipped": bool,
"h": int,
"w": int,
"image_url": str,
"im_id": int,
# this fields are generated on the fly during test
"rec_id": int,
"resize_h": int,
"resize_w": int,
...
},
...
]
Especially, for experimenting on coco datatet, one can organize coco data in
data/
coco/
annotations/
instances_train2014.json
instances_valminusminival2014.json
instances_minival2014.json
image_info_test-dev2017.json
images/
train2014
val2014
test2017
and run the helper script to generate roidb
python3 utils/generate_roidb.py --dataset coco --dataset-split train2014
python3 utils/generate_roidb.py --dataset coco --dataset-split valminusminival2014
python3 utils/generate_roidb.py --dataset coco --dataset-split minival2014
python3 utils/generate_roidb.py --dataset coco --dataset-split test-dev2017
cd $SIMPLEDET_DIR
git clone https://github.com/RogerChern/mxnext
# train
python3 detection_train.py --config config/detection_config.py
# test
python3 detection_test.py --config config/detection_config.py
Please refer to MODEL_ZOO.md for available models
detection_train.py
detection_test.py
config/
detection_config.py
core/
detection_input.py
detection_metric.py
detection_module.py
models/
FPN/
tridentnet/
maskrcnn/
cascade_rcnn/
retinanet/
mxnext/
symbol/
builder.py
Everything is configurable from the config file, all the changes should be out of source.
One experiment is a directory in experiments folder with the same name as the config file.
E.g. r50_fixbn_1x.py is the name of a config file
config/
r50_fixbn_1x.py
experiments/
r50_fixbn_1x/
checkpoint.params
log.txt
coco_minival2014_result.json
The models
directory contains SOTA models implemented in SimpletDet.
Simpledet supports many popular detection methods and here we take Faster-RCNN as a typical example to show how a detector is built.
DetectionAugmentation
.
Norm2DImage
and Resize2DImageBbox
.AnchorTarget2D
, which generates anchors and corresponding anchor targets for training RPN.FasterRcnn
. The key components are listed as follow:
Backbone
provides interfaces to build backbone networks, e.g. ResNet and ResNext.Neck
provides interfaces to build complementary feature extraction layers for backbone networks, e.g. FPNNeck
builds Top-down pathway for Feature Pyramid Network.RpnHead
aims to build classification and regression layers to generate proposal outputs for RPN. Meanwhile, it also provides interplace to generate sampled proposals for the subsequent R-CNN.RoiExtractor
extracts features for each roi (proposal) based on the R-CNN features generated by Backbone
and Neck
.BboxHead
builds the R-CNN layers for proposal refinement.The flexibility of simpledet framework makes it easy to build different detectors. We take TridentNet as an example to demonstrate how to build a custom detector simply based on the Faster-RCNN framework.
DetectionAugmentation
.
TridentAnchorTarget2D
is implemented to generate anchors for multiple branches and filter anchors for scale-aware training scheme.Backbone
according to the descriptions in the paper. We also provide a TridentRpnHead
to generate filtered proposals in RPN to implement the scale-aware scheme. Other components are shared the same with original Faster-RCNN.Please refer to DISTRIBUTED.md
Yuntao Chen, Chenxia Han, Yanghao Li, Zehao Huang, Yi Jiang, Naiyan Wang
This project is release under the Apache 2.0 license for non-commercial usage. For commercial usage, please contact us for another license.
If you find our project helpful, please consider cite our tech report.
@article{chen2019simpledet,
title={SimpleDet: A Simple and Versatile Distributed Framework for Object Detection and Instance Recognition},
author={Chen, Yuntao and and Han, Chenxia and Li, Yanghao and Huang, Zehao and Jiang, Yi and Wang, Naiyan and Zhang, Zhaoxiang},
journal={arXiv preprint arXiv:1903.05831},
year={2019}
}