选自Github
作者:王小龙等
机器之心编译
参与:李泽南
最近,卡耐基梅隆大学(CMU)的王小龙等人发表的论文《A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection》引起了很多人的关注。该研究将对抗学习的思路应用在图像识别问题中,通过对抗网络生成遮挡和变形图片样本来训练检测网络,取得了不错的效果。该论文已被 CVPR2017 大会接收。
论文链接:http://www.cs.cmu.edu/~xiaolonw/papers/CVPR2017_Adversarial_Det.pdf
Github:https://github.com/xiaolonw/adversarial-frcnn
论文:A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection
摘要
如何确定物体探测器能够应对被遮蔽、不同角度或变形的图像?我们目前的解决方法是使用数据驱动的策略,收集一个巨大的数据集——覆盖所有条件下物体的样子,并希望通过模型训练能够让分类器学会把它们识别为同一个物体。但是数据集真的能够覆盖所有的情况吗?我们认为像分类、遮蔽与变形这样的特性也符合长尾理论。一些遮蔽与变形非常罕见,几乎永远不会发生,而我们希望训练出的模型是能够应付所有情况的。在本论文中,我们提出了一种新的解决方案。我们提出了一种对抗网络,可以自我生成遮蔽与变形例子。对抗的目标是生成物体探测器难以识别的例子。在我们的架构中,原识别器与它的对手共同进行学习。实验证明,我们的方法与 Fast-RCNN 相比,在 VOC07 上的 mAP 上的升幅为 2.3%,在 VOC2012 物体识别挑战中的 mAP 升幅为 2.6%。我们同时发布了本研究的代码。
图 1:在论文中,我们提出了使用对抗网络来生成带有遮挡和变形的例子,从而让物体探测器难以进行分类。随着探测器的性能逐渐提升,对抗网络产生的图片质量也在提升。通过这种对抗策略,神经网络识别物体的准确性得到了进一步提升。
图 2:该方法的 ASDN 网络架构以及如何与 Fast RCNN 结合的示意图。我们的 ASDN 网络使用输入图片加入 RoI 池化层中得到的补丁。ASDN 网络预测遮挡/极高光蒙版,然后将其用于丢弃特征值,并传递到 Fast-RCNN 分类塔。
图 3:(a)模型预训练——寻找难度最高的遮挡用于训练 ASDN 网络。(b)ASDN 网络生成的遮挡蒙版事例,黑色区域在通过 FRCN 管道时被遮挡。
图 4:ASDN 与 ASTN 网络组合架构示意。首先创建遮挡蒙版,随后旋转路径以产生用于训练的例子。
表格 1:VOC 识别测试的平均精度,FRCN 指使用我们训练方式的 FRCN 成绩。
该研究的 Caffe 实现:A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection
介绍
本实现是 Caffe 版本的 A-Fast-RCNN。尽管我们在论文中的初始实现是在 Torch 上进行的。但 Caffe 的版本更加简单、快速和易于使用。我们发布了用 Adversarial Spatial Dropout Network 训练 A-Fast-RCNN 的训练数据的代码。
许可
本代码是在 MIT License 之下发布的(请参阅 LICENSE 文件获取详细信息)。
引用
如果你认为本内容对你的研究有帮助,可以进行引用:
@inproceedings{WangCVPR17afrcnn,
Author = {Xiaolong Wang and Abhinav Shrivastava and Abhinav Gupta},
Title = {A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection},
Booktitle = {Conference on Computer Vision and Pattern Recognition ({CVPR})},
Year = {2017}
}
免责声明
本实现是建立在 OHEM 代码的一个 fork 上的,后者又建立在 Faster R-CNN Python 代码和 Fast R-CNN 之上。请在使用时选择相应的研究论文加以引用。
OHEM:https://github.com/abhi2610/ohem
Faster R-CNN Python:https://github.com/rbgirshick/py-faster-rcnn
Fast R-CNN:https://github.com/rbgirshick/fast-rcnn
结果
注意:研究中记录的结果基于 VGG16 网络。
安装
请遵循 VOC 数据下载和安装规范,这方面与 Faster R-CNN Python 一样。
使用
想要运行代码,请输入:
./train.sh
它包括三个阶段的训练:
./experiments/scripts/fast_rcnn_std.sh [GPU_ID] VGG16 pascal_voc
这曾被用来进行标准 Fast-RCNN 一万次迭代的训练,你或许需要下载模型和 log。
模型:http://suo.im/2cgwYG
Log:http://suo.im/39gkhf
./experiments/scripts/fast_rcnn_adv_pretrain.sh [GPU_ID] VGG16 pascal_voc
在对抗网络的预训练阶段,可能会需要下载模型和 log:
模型:http://suo.im/2cgwYG
Log:http://suo.im/1TSiRh
./copy_model.h
用于复制上述两个模型的权重,用于初始化联合模型。
./experiments/scripts/fast_rcnn_adv.sh [GPU_ID] VGG16 pascal_voc
用于 detector 联合训练对抗网络,在这一步中你可能会需要下载模型和 log:
模型:http://suo.im/25uFFX
Log:http://suo.im/2UTbnC
本文为机器之心编译,转载请联系本公众号获得授权。
✄------------------------------------------------
加入机器之心(全职记者/实习生):hr@jiqizhixin.com
投稿或寻求报道:editor@jiqizhixin.com
广告&商务合作:bd@jiqizhixin.com