点击上方“专知”关注获取更多AI知识!
【导读】10月26日,深度学习元老Hinton的NIPS2017 Capsule论文《Dynamic Routing Between Capsules》终于在arxiv上发表。今天相关关于这篇论文的TensorFlow\Pytorch\Keras实现相继开源出来,让我们来看下。
论文地址:https://arxiv.org/pdf/1710.09829.pdf
摘要:Capsule 是一组神经元,其活动向量(activity vector)表示特定实体类型的实例化参数,如对象或对象部分。我们使用活动向量的长度表征实体存在的概率,向量方向表示实例化参数。同一水平的活跃 capsule 通过变换矩阵对更高级别的 capsule 的实例化参数进行预测。当多个预测相同时,更高级别的 capsule 变得活跃。我们展示了判别式训练的多层 capsule 系统在 MNIST 数据集上达到了最好的性能效果,比识别高度重叠数字的卷积网络的性能优越很多。为了达到这些结果,我们使用迭代的路由协议机制:较低级别的 capsule 偏向于将输出发送至高级别的 capsule,有了来自低级别 capsule 的预测,高级别 capsule 的活动向量具备较大的标量积。
Python 3
PyTorch
TorchVision
TorchNet
TQDM
Visdom
第一步 在capsule_network.py
文件中设置训练epochs,batch size等
BATCH_SIZE = 100NUM_CLASSES = 10NUM_EPOCHS = 30NUM_ROUTING_ITERATIONS = 3
Step 2 开始训练. 如果本地文件夹中没有MNIST数据集,将运行脚本自动下载到本地. 确保 PyTorch可视化工具Visdom正在运行。
$ sudo python3 -m visdom.server & python3 capsule_network.py
经过30个epoche的训练手写体数字的识别率达到99.48%. 从下图的训练进度和损失图的趋势来看,这一识别率可以被进一步的提高 。
采用了PyTorch中默认的Adam梯度优化参数并没有用到动态学习率的调整。 batch size 使用100个样本的时候,在雷蛇GTX 1050 GPU上每个Epochs 用时3分钟。
扩展到除MNIST以外的其他数据集。
主要借鉴了以下两个 TensorFlow 和 Keras 的实现:
Keras implementation by @XifengGuo
TensorFlow implementation by @naturomics
Many thanks to @InnerPeace-Wu for a discussion on the dynamic routing procedure outlined in the paper.
Python
NumPy
Tensorflow (I'm using 1.3.0, not yet tested for older version)
tqdm (for displaying training progress info)
scipy (for saving image)
*第一步 * 用git命令下载代码到本地.
$ git clone https://github.com/naturomics/CapsNet-Tensorflow.git
$ cd CapsNet-Tensorflow
第二部 下载MNIST数据集(http://yann.lecun.com/exdb/mnist/), 移动并解压到data/mnist
文件夹.(当你用复制wget
命令到你的终端是注意渠道花括号里的反斜杠)
$ mkdir -p data/mnist
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/{train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz}
$ gunzip data/mnist/*.gz
第三部 开始训练:
$ pip install tqdm # install it if you haven't installed yet
$ python train.py
tqdm包并不是必须的,只是为了可视化训练过程。如果你不想要在train.py
中将循环for in step ...
改成 ``for step in range(num_batch)就行了。
$ python eval.py --is_training False
错误的运行结果(Details in Issues #8):
training loss
test acc
Epoch | 49 | 51 |
---|---|---|
test acc | 94.69 | 94.71 |
Results after fixing Issues #8:
关于capsule的一点见解
一种新的神经单元(输入向量输出向量,而不是标量)
常规算法类似于Attention机制
总之是一项很有潜力的工作,有很多工作可以在之上开展
完成MNIST的实现Finish the MNIST version of capsNet (progress:90%)
在其他数据集上验证capsNet
调整模型结构
一篇新的投稿在ICLR2018上的后续论文(https://openreview.net/pdf?id=HJWLfGWRb) about capsules(submitted to ICLR 2018),
Keras
matplotlib
第一步 安装 Keras:
$ pip install keras
第二步 用 git
命令下载代码到本地.
$ git clone https://github.com/xifengguo/CapsNet-Keras.git
$ cd CapsNet-Keras
第三步 训练:
$ python capsulenet.py
一次迭代训练(default 3).
$ python capsulenet.py --num_routing 1
其他参数包括想 batch_size, epochs, lam_recon, shift_fraction, save_dir
可以以同样的方式使用。 具体可以参考 capsulenet.py
假设你已经有了用上面命令训练好的模型,训练模型将被保存在 result/trained_model.h5
. 现在只需要使用下面的命令来得到测试结果。
$ python capsulenet.py --is_training 0 --weights result/trained_model.h5
将会输出测试结果并显示出重构后的图片。测试数据使用的和验证集一样 ,同样也可以很方便的在新数据上验证,至于要按照你的需要修改下代码就行了。
如果你的电脑没有GPU来训练模型,你可以从https://pan.baidu.com/s/1hsF2bvY下载预先训练好的训练模型
主要结果
运行 python capsulenet.py
: epoch=1 代表训练一个epoch 后的结果 在保存的日志文件中,epoch从0开始。
Epoch | 1 | 5 | 10 | 15 | 20 |
---|---|---|---|---|---|
train_acc | 90.65 | 98.95 | 99.36 | 99.63 | 99.75 |
vali_acc | 98.51 | 99.30 | 99.34 | 99.49 | 99.59 |
损失和准确度:
一次常规迭代后的结果
运行 python CapsNet.py --num_routing 1
Epoch | 1 | 5 | 10 | 15 | 20 |
---|---|---|---|---|---|
train_acc | 89.64 | 99.02 | 99.42 | 99.66 | 99.73 |
vali_acc | 98.55 | 99.33 | 99.43 | 99.57 | 99.58 |
每个 epoch 在单卡GTX 1070 GPU上大概需要110s
注释: 训练任然是欠拟合的,欢迎在你自己的机器上验证。学习率decay还没有经过调试, 我只是试了一次,你可以接续微调。
测试结果
运行 python capsulenet.py --is_training 0 --weights result/trained_model.h5
模型结构:
Kaggle (this version as self-contained notebook):
MNIST Dataset running on the standard MNIST and predicting for test data
MNIST Fashion running on the more challenging Fashion images.
TensorFlow:
naturomics/CapsNet-Tensorflow
Very good implementation. I referred to this repository in my code.
InnerPeace-Wu/CapsNet-tensorflow
I referred to the use of tf.scan when optimizing my CapsuleLayer.
LaoDar/tf_CapsNet_simple
PyTorch:
nishnik/CapsNet-PyTorch
timomernick/pytorch-capsule
gram-ai/capsule-networks
andreaazzini/capsnet.pytorch
leftthomas/CapsNet
MXNet:
AaronLeong/CapsNet_Mxnet
Lasagne (Theano):
DeniskaMazur/CapsNet-Lasagne
Chainer:
soskek/dynamic_routing_between_capsules
https://github.com/gram-ai/capsule-networks
https://github.com/naturomics/CapsNet-Tensorflow
https://github.com/XifengGuo/CapsNet-Keras
特别提示:
请关注专知公众号(扫一扫最下面专知二维码,或者点击上方蓝色专知),后台回复“MLDL” 就可以获取机器学习&深度学习知识资料大全集的pdf下载链接~~
特别提示:
专知,为人工智能从业者提供专业可信的AI知识分发服务;请登录www.zhuanzhi.ai或者点击阅读原文,顶端搜索“机器学习” 主题,直接获取查看获得关于机器学习更多的知识资料,包括链路荟萃动态资讯精华文章等资料,帮助你更好获取机器学习知识!如下图所示。
更多专知荟萃知识资料全集获取,请查看:
【专知荟萃01】深度学习知识资料大全集(入门/进阶/论文/代码/数据/综述/领域专家等)(附pdf下载)
【专知荟萃02】自然语言处理NLP知识资料大全集(入门/进阶/论文/Toolkit/数据/综述/专家等)(附pdf下载)
【专知荟萃03】知识图谱KG知识资料全集(入门/进阶/论文/代码/数据/综述/专家等)(附pdf下载)
专知荟萃04】自动问答QA知识资料全集(入门/进阶/论文/代码/数据/综述/专家等)(附pdf下载)
【教程实战】Google DeepMind David Silver《深度强化学习》公开课教程学习笔记以及实战代码完整版
【GAN货】生成对抗网络知识资料全集(论文/代码/教程/视频/文章等)
【干货】Google GAN之父Ian Goodfellow ICCV2017演讲:解读生成对抗网络的原理与应用
【AlphaGoZero核心技术】深度强化学习知识资料全集(论文/代码/教程/视频/文章等)
欢迎转发到你的微信群和朋友圈,分享专业AI知识!
获取更多关于机器学习以及人工智能知识资料,请访问www.zhuanzhi.ai, 或者点击阅读原文,即可得到!
-END-
欢迎使用专知
专知,一个新的认知方式!目前聚焦在人工智能领域为AI从业者提供专业可信的知识分发服务, 包括主题定制、主题链路、搜索发现等服务,帮你又好又快找到所需知识。
使用方法>>访问www.zhuanzhi.ai, 或点击文章下方“阅读原文”即可访问专知
中国科学院自动化研究所专知团队
@2017 专知
专 · 知
关注我们的公众号,获取最新关于专知以及人工智能的资讯、技术、算法、深度干货等内容。扫一扫下方关注我们的微信公众号。
点击“阅读原文”,使用专知!