GitHub 7.5k star量,各种视觉Transformer的PyTorch实现合集整理好了

2021 年 12 月 31 日 机器之心
机器之心报道
编辑:杜伟

这个项目登上了今天的GitHub Trending。


近一两年,Transformer 跨界 CV 任务不再是什么新鲜事了。

自 2020 年 10 月谷歌提出 Vision Transformer (ViT) 以来,各式各样视觉 Transformer 开始在图像合成、点云处理、视觉 - 语言建模等领域大显身手。

之后,在 PyTorch 中实现 Vision Transformer 成为了研究热点。GitHub 中也出现了很多优秀的项目,今天要介绍的就是其中之一。

该项目名为「vit-pytorch」, 它是一个 Vision Transformer 实现,展示了一种在 PyTorch 中仅使用单个 transformer 编码器来实现视觉分类 SOTA 结果的简单方法。

项目当前的 star 量已经达到了 7.5k,创建者为 Phil Wang,ta 在 GitHub 上有 147 个资源库。


项目地址:https://github.com/lucidrains/vit-pytorch

项目作者还提供了一段动图展示:



项目介绍

首先来看 Vision Transformer-PyTorch 的安装、使用、参数、蒸馏等步骤。

第一步是安装:

$ pip install vit-pytorch

第二步是使用:

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

第三步是所需参数,包括如下:

  • image_size:图像大小

  • patch_size:patch 数量

  • num_classes:分类类别的数量

  • dim:线性变换 nn.Linear(..., dim) 后输出张量的最后维

  • depth:Transformer 块的数量

  • heads:多头注意力层中头的数量

  • mlp_dim:MLP(前馈)层的维数

  • channels:图像通道的数量

  • dropout:Dropout rate

  • emb_dropout:嵌入 dropout rate

  • ……


最后是蒸馏,采用的流程出自 Facebook AI 和索邦大学的论文《Training data-efficient image transformers & distillation through attention》。

论文地址:https://arxiv.org/pdf/2012.12877.pdf

从 ResNet50(或任何教师网络)蒸馏到 vision transformer 的代码如下:

import torchfrom torchvision.models import resnet50from vit_pytorch.distill import DistillableViT, DistillWrapperteacher = resnet50(pretrained = True)
v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,           # temperature of distillationalpha = 0.5,               # trade between main loss and distillation losshard = False               # whether to use soft or hard distillation
)
img = torch.randn(2, 3, 256, 256)labels = torch.randint(0, 1000, (2,))
loss = distiller(img, labels)loss.backward()
# after lots of training above ...pred = v(img) # (2, 1000)

除了 Vision Transformer 之外,该项目还提供了 Deep ViT、CaiT、Token-to-Token ViT、PiT 等其他 ViT 变体模型的 PyTorch 实现。


对 ViT 模型 PyTorch 实现感兴趣的读者可以参阅原项目。


© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:content@jiqizhixin.com

登录查看更多
0

相关内容

Transformer是谷歌发表的论文《Attention Is All You Need》提出一种完全基于Attention的翻译架构

知识荟萃

精品入门和进阶教程、论文和代码整理等

更多

查看相关VIP内容、论文、资讯等
TPAMI 2022|华为诺亚最新视觉Transformer综述
专知会员服务
55+阅读 · 2022年2月24日
【ICLR2022】序列生成的目标侧数据增强
专知会员服务
22+阅读 · 2022年2月14日
专知会员服务
63+阅读 · 2021年4月11日
Transformer替代CNN?8篇论文概述最新进展!
专知会员服务
75+阅读 · 2021年1月19日
专知会员服务
109+阅读 · 2020年3月12日
近期必读的7篇 CVPR 2019【视觉问答】相关论文和代码
专知会员服务
35+阅读 · 2020年1月10日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
【GitHub实战】Pytorch实现的小样本逼真的视频到视频转换
专知会员服务
35+阅读 · 2019年12月15日
用PyTorch做物体检测和追踪
AI研习社
12+阅读 · 2019年1月6日
一份超全的PyTorch资源列表(Github 2.2K星)
黑龙江大学自然语言处理实验室
25+阅读 · 2018年10月26日
Github 项目推荐 | YOLOv3 的最小化 PyTorch 实现
AI研习社
25+阅读 · 2018年5月31日
tensorflow项目学习路径
北京思腾合力科技有限公司
10+阅读 · 2017年11月23日
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
1+阅读 · 2012年12月31日
国家自然科学基金
1+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
1+阅读 · 2011年12月31日
Arxiv
3+阅读 · 2022年4月19日
Arxiv
27+阅读 · 2021年11月11日
Arxiv
17+阅读 · 2021年3月29日
Meta-Transfer Learning for Zero-Shot Super-Resolution
Arxiv
43+阅读 · 2020年2月27日
Arxiv
12+阅读 · 2019年4月9日
Arxiv
27+阅读 · 2017年12月6日
VIP会员
相关VIP内容
TPAMI 2022|华为诺亚最新视觉Transformer综述
专知会员服务
55+阅读 · 2022年2月24日
【ICLR2022】序列生成的目标侧数据增强
专知会员服务
22+阅读 · 2022年2月14日
专知会员服务
63+阅读 · 2021年4月11日
Transformer替代CNN?8篇论文概述最新进展!
专知会员服务
75+阅读 · 2021年1月19日
专知会员服务
109+阅读 · 2020年3月12日
近期必读的7篇 CVPR 2019【视觉问答】相关论文和代码
专知会员服务
35+阅读 · 2020年1月10日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
【GitHub实战】Pytorch实现的小样本逼真的视频到视频转换
专知会员服务
35+阅读 · 2019年12月15日
相关资讯
用PyTorch做物体检测和追踪
AI研习社
12+阅读 · 2019年1月6日
一份超全的PyTorch资源列表(Github 2.2K星)
黑龙江大学自然语言处理实验室
25+阅读 · 2018年10月26日
Github 项目推荐 | YOLOv3 的最小化 PyTorch 实现
AI研习社
25+阅读 · 2018年5月31日
tensorflow项目学习路径
北京思腾合力科技有限公司
10+阅读 · 2017年11月23日
相关基金
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
1+阅读 · 2012年12月31日
国家自然科学基金
1+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
1+阅读 · 2011年12月31日
相关论文
Top
微信扫码咨询专知VIP会员