谷歌正式开源 Hinton 胶囊理论代码,即刻用 TensorFlow 实现吧

2018 年 2 月 1 日 AI研习社 思颖

雷锋网(公众号:雷锋网) AI 研习社消息,相信大家对于「深度学习教父」Geoffery Hinton 在去年年底发表的胶囊网络还记忆犹新,在论文 Dynamic Routing between Capsules 中,Hinton 团队提出了一种全新的网络结构。为了避免网络结构的杂乱无章,他们提出把关注同一个类别或者同一个属性的神经元打包集合在一起,好像胶囊一样。在神经网络工作时,这些胶囊间的通路形成稀疏激活的树状结构(整个树中只有部分路径上的胶囊被激活)。这样一来,Capsule 也就具有更好的解释性。

在实验结果上,CapsNet 在数字识别和健壮性上都取得了不错的效果。详情可以

日前,该论文的第一作者 Sara Sabour 在 GitHub 上公布了论文代码,大家可以马上动手实践起来。雷锋网 AI 研习社将教程编译整理如下:终于盼来了Hinton的Capsule新论文,它能开启深度神经网络的新时代吗?

  所需配置:

  • TensorFlow(点击 http://www.tensorflow.org 进行安装或升级)

  • NumPy (详情点击 http://www.numpy.org/ )

  • GPU

  执行 test 程序,来验证安装是否正确,诸如:

python layers_test.py

  快速 MNIST 测试:

下载并提取 MNIST tfrecord 到 $DATA_DIR/ 下:

https://storage.googleapis.com/capsule_toronto/mnist_data.tar.gz

下载并提取 MNIST 模型 checkpoint 到 $CKPT_DIR 下:

https://storage.googleapis.com/capsule_toronto/mnist_checkpoints.tar.gz

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --train=false \

--summary_dir=/tmp/ --checkpoint=$CKPT_DIR/mnist_checkpoint/model.ckpt-1

  快速 CIFAR10 ensemble 测试:

下载并提取 cifar10 二进制文件到 $DATA_DIR/ 下:

https://www.cs.toronto.edu/~kriz/cifar.html

下载并提取 cifar10 模型 checkpoint 到 $CKPT_DIR 下:

https://storage.googleapis.com/capsule_toronto/cifar_checkpoints.tar.gz

将目录($DATA_DIR)作为 data_dir 来传递:

python experiment.py --data_dir=$DATA_DIR --train=false --dataset=cifar10 \

--hparams_override=num_prime_capsules=64,padding=SAME,leaky=true,remake=false \

--summary_dir=/tmp/ --checkpoint=$CKPT_DIR/cifar/cifar{}/model.ckpt-600000 \

--num_trials=7

   CIFAR10 训练指令:

python experiment.py --data_dir=$DATA_DIR --dataset=cifar10 --max_steps=600000\

--hparams_override=num_prime_capsules=64,padding=SAME,leaky=true,remake=false \

--summary_dir=/tmp/

  MNIST full 训练指令:

  • 也可以执行--validate=true as well 在训练-测试集上训练

  • 执行 --num_gpus=NUM_GPUS 在多块GPU上训练

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\

--summary_dir=/tmp/attempt0/

   MNIST baseline 训练指令:

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\

--summary_dir=/tmp/attempt1/ --model=baseline

To test on validation during training of the above model:

训练如上模型时,在验证集上进行测试(记住,在训练过程中会持续执行指令):

  • 在训练时执行 --validate=true 也一样

  • 可能需要两块 GPU,一块用于训练集,一块用于验证集

  • 如果所有的测试都在一台机器上,你需要对训练集、验证集的测试中限制 RAM 消耗。如果不这样,TensorFlow 会在一开始占用所有的 RAM,这样就不能执行其他工作了

python experiment.py --data_dir=$DATA_DIR/mnist_data/ --max_steps=300000\

--summary_dir=/tmp/attempt0/ --train=false --validate=true

大家可以通过 --num_targets=2 和 --data_dir=$DATA_DIR/multitest_6shifted_mnist.tfrecords@10 在 MultiMNIST 上进行测试或训练,生成 multiMNIST/MNIST 记录的代码在 input_data/mnist/mnist_shift.py 目录下。

  multiMNIST 测试代码:

python mnist_shift.py --data_dir=$DATA_DIR/mnist_data/ --split=test --shift=6 

--pad=4 --num_pairs=1000 --max_shard=100000 --multi_targets=true

可以通过 --shift=6 --pad=6 来构造 affNIST expanded_mnist

论文地址:https://arxiv.org/pdf/1710.09829.pdf 

GitHub 地址:https://github.com/Sarasra/models/tree/master/research/capsules

雷锋网 AI 研习社编译整理。


附重塑AI的胶囊网络论文解读:


NLP 工程师入门实践班:基于深度学习的自然语言处理

三大模块,五大应用,手把手快速入门 NLP

海外博士讲师,丰富项目经验

算法 + 实践,搭配典型行业应用

随到随学,专业社群,讲师在线答疑

▼▼▼





新人福利





关注 AI 研习社(okweiwu),回复  1  领取

【超过 1000G 神经网络 / AI / 大数据,教程,论文】



如何看待 Hinton 那篇备受关注的Capsules论文?

登录查看更多
4

相关内容

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。
专知会员服务
109+阅读 · 2020年3月12日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
163+阅读 · 2019年10月28日
Github 项目推荐 | 用 Pytorch 实现的 Capsule Network
AI研习社
22+阅读 · 2018年3月7日
万众期待:Hinton团队开源CapsNet源码
专知
6+阅读 · 2018年2月1日
放弃深度学习 ,Hinton提出Capsule计划
德先生
3+阅读 · 2018年1月2日
tensorflow LSTM + CTC实现端到端OCR
机器学习研究会
26+阅读 · 2017年11月16日
重磅!Geoffrey Hinton提出capsule 概念,推翻反向传播!
人工智能学家
7+阅读 · 2017年9月17日
TResNet: High Performance GPU-Dedicated Architecture
Arxiv
8+阅读 · 2020年3月30日
Arxiv
6+阅读 · 2018年4月23日
Arxiv
3+阅读 · 2018年3月2日
Arxiv
10+阅读 · 2018年2月17日
Arxiv
5+阅读 · 2018年1月16日
VIP会员
相关资讯
Top
微信扫码咨询专知VIP会员