一文详解Vision Transformer(附代码)

2022 年 1 月 19 日 PaperWeekly


©PaperWeekly 原创 · 作者 | 孙裕道

单位 | 北京邮电大学博士生

研究方向 | GAN图像生成、情绪对抗样本生成



引言

Transformer 在 NLP 中大获成功,Vision Transformer 则将 Transformer 模型架构扩展到计算机视觉的领域中,并且它可以很好的地取代卷积操作,在不依赖卷积的情况下,依然可以在图像分类任务上达到很好的效果。卷积操作只能考虑到局部的特征信息,而 Transformer 中的注意力机制可以综合考量全局的特征信息。

Vision Transformer 尽力做到在不改变 Transformer 中 Encoder 架构的前提下,直接将其从 NLP 领域迁移到计算机视觉领域中,目的是让原始的 Transformer 模型开箱即用。如果想要了解 Transformer 原理详细的介绍可以看我的上一篇文章 《矩阵视角下的 Transformer 详解(附代码)》


注意力机制应用
在正式详细介绍 Vision Transformer 之前,先介绍两个注意力机制在计算机视觉中应用的例子。Vision Transformer 并不是第一个将注意力机制应用到计算机视觉的领域中去的,其中 SAGAN 和 AttnGAN 就早已经在 GAN 的框架中引入了注意力机制,并且它们大大提高了图像生成的质量。
2.1 Self-Attention GAN

SAGAN 在 GAN 的框架中利用自注意力机制来捕获图像特征的长距离依赖关系,使得合成的图像中考量了所有的图像特征信息。SAGAN 中自注意力机制的操作原理如上图所示。

给定一个 3 通道的输入特征图 ,其中 。将 分别输入到三个不同的 的卷积层中,并生成 query 特征图 ,key 特征图 和 value 特征图 。生成 具体的计算过程为,给定三个卷积核 ,并用这三个卷积核分别与 做卷积运算得到 ,即:

其中 表示卷积运算符号。同理生成 的计算过程与 的计算过程类似。然后再利用 进行注意力分数的计算得到矩阵 ,其中矩阵 的元素 的计算公式为:

再对矩阵 利用 softmax 函数进行注意力分布的计算得到注意力分布矩阵 ,其中矩阵 的元素 的计算公式为:

最后利用注意力分布矩阵 和value特征图 得到最后的输出 ,即:

2.2 AttnGAN

AttnGAN 通过利用注意力机制来实现多阶段细颗粒度的文本到图像的生成,它可以通过关注自然语言中的一些重要单词来对图像的不同子区域进行合成。比如通过文本“一只鸟有黄色的羽毛和黑色的眼睛”来生成图像时,会对关键词“鸟”,“羽毛”,“眼睛”,“黄色”,“黑色”给予不同的生成权重,并根据这些关键词的引导在图像的不同的子区域中进行细节的丰富。AttnGAN 中注意力机制的操作原理如上图所示。

给定输入图像特征向量 和词特征向量 ,其中 。首先利用矩阵 进行线性变换将词特征空间 的向量转换成图像特征空间 的向量,则有:

然后再利用转换后的词特征 与图像特征 进行注意力分数的计算得到注意力分数矩阵 ,其中的分量 的计算公式为:

再对矩阵 利用 函数进行注意力分布的计算得到注意力分布矩阵 ,其中矩阵 的元素 的计算公式为:

最后利用注意力分布矩阵 和图像特征 得到最后的输出 ,即:



Vision Transformer

本节主要详细介绍 Vision Transformer 的工作原理,3.1 节是关于 Vision Transformer 的整体框架,3.2 节是关于 Transformer Encoder 的内部操作细节。对于 Transformer Encoder 中 Multi-Head Attention 的原理本文不会赘述,具体想了解的可以参考上一篇文章《矩阵视角下的 Transformer 详解(附代码)》中相关原理的介绍。

不难发现,不管是自然语言处理中的 Transformer,还是计算机视觉中图像生成的 SAGAN,以及文本生成图像的 AttnGAN,它们核心模块中注意力机制的主要目的就是求出注意力分布。
3.1 Vision Transformer 整体框架

如果下图所示为 Vision Transformer 的整体框架以及相应的训练流程。

  • 给定一张图片 ,并将它分割成 9 个 patch 分别为 。然后再将这个 9 个 patch 拉平,则有
  • 利用矩阵 将拉平后的向量 经过线性变换得到图像编码向量 ,具体的计算公式为:
  • 然后将图像编码向量 和类编码向量 分别与对应的位置编进行加和得到输入编码向量,则有:

  • 接着将输入编码向量输入到 Vision Transformer Encoder 中得到对应的输出
  • 最后将类编码向量 输入全连接神经网络中 MLP 得到类别预测向量 ,并与真实类别向量 计算交叉熵损失得到损失值 loss,利用优化算法更新模型的权重参数。
注意事项:
看到这里可能会有一个疑问为什么预测类别的时候只用到了类别编码向量 ,Vision Transformer Encoder 其它的输出为什么没有输入到 MLP 中?为了回答这个问题,我们令函数 为 Vision Transformer Encoder},则类编码向量 可以表示为:

由上公式可以发现,类编码向量 是属于高层特征,其实它综合了所有的图像编码信息,所以可以用它来进行分类,这个可以类比在卷积神经网络中最后的类别输出向量其实就是一层层卷积得到的高层特征。


3.2 Transformer Encoder操作原理

如下图所示分别为 Vision Transformer Encoder 模型结构图和原始 Transformer Encoder 的模型结构图。可以直观的发现 Vision Transformer Encoder 和 Transformer Encoder 都有层归一化,多头注意力机制,残差连接和线性变换这四个操作,只是在操作顺序有所不同。在以下的 Transformer 代码实例中,将以下两种 Encoder 网络结构都进行了实现,可以发现两种网络结构都可以进行很好的训练。

下图左半部分 Vision Transformer Encoder 具体的操作流程为:
  • 给定输入编码矩阵 ,首先将其进行层归一化得到
  • 利用矩阵 进行线性变换得到矩阵 。具体的计算过程为:

    再将这三个矩阵输入到 Multi-Head Attention( 该原理参考 《矩阵视角下的 Transformer 详解(附代码)》 )中得到矩阵 ,将最原始的输入矩阵 进行残差计算得到
  • 进行第二次层归一化得到 ,然后再将 输入到全连接神经网络中进行线性变换得到 。最后将 进行残差操作得到该 Block 的输出; 。一个 Encoder 可以将 个 Block 进行堆叠,最后得到的输出为




程序代码

Vision Transformer 的代码示例如下所示。该代码是由上一篇《矩阵视角下的Transformer详解(附代码)》的代码的基础上改编而来。Vision Transformer 的作者的本意就是想让在 NLP 中的 Transformer 模型架构做尽可能少的修改可以直接迁移到 CV 中,所以以下程序尽可能保持作者的愿意,并在代码实现了两种Encoder 的网络结构,即 3.2 节图片所示的两个网络结构,一种是最原始的Encoder 网络结构,一种是 Vision Transformer。论文里的 Encoder 的网络结构。

这里需要注意的是,Vision Transformer 里并能没有 Decoder 模块,所以不需要计算 Encoder 和 Decoder 的交叉注意力分布,这就进一步给 Vision Transformer 的编程带来了简便。Vision Transformer的开源代码的网址为:
https://github.com/lucidrains/vit-pytorch/tree/main/vit_pytorch


import torch
import torch.nn as nn
import os
from einops import rearrange
from einops import repeat
from einops.layers.torch import Rearrange

def inputs_deal(inputs):
    return inputs if isinstance(inputs, tuple) else(inputs, inputs)

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N =query.shape[0]
        value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]

        # split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        # queries shape: (N, query_len, heads, heads_dim)
        # keys shape : (N, key_len, heads, heads_dim)
        # energy shape: (N, heads, query_len, key_len)

        attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)

        out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # (N, query_len, heads, head_dim)

        out = self.fc_out(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)


    def forward(self, value, key, query, x, type_mode):
        if type_mode == 'original':
            attention = self.attention(value, key, query)
            x = self.dropout(self.norm(attention + x))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm(forward + x))
            return out
        else:
            attention = self.attention(self.norm(value), self.norm(key), self.norm(query))
            x =self.dropout(attention + x)
            forward = self.feed_forward(self.norm(x))
            out = self.dropout(forward + x)
            return out

class TransformerEncoder(nn.Module):
    def __init__(
            self,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout = 0,
            type_mode = 'original'
        )
:

        super(TransformerEncoder, self).__init__()
        self.embed_size = embed_size
        self.type_mode = type_mode
        self.Query_Key_Value = nn.Linear(embed_size, embed_size * 3, bias = False)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                    )
                for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        for layer in self.layers:
            QKV_list = self.Query_Key_Value(x).chunk(3, dim = -1)
            x = layer(QKV_list[0], QKV_list[1], QKV_list[2], x, self.type_mode)
        return x

class VisionTransformer(nn.Module):
    def __init__(self,
                image_size,
                patch_size,
                num_classes,
                embed_size,
                num_layers,
                heads,
                mlp_dim,
                pool = 'cls',
                channels = 3,
                dropout = 0,
                emb_dropout = 0.1,
                type_mode = 'vit')
:

        super(VisionTransformer, self).__init__()
        img_h, img_w = inputs_deal(image_size)
        patch_h, patch_w = inputs_deal(patch_size)

        assert img_h % patch_h == 0 and img_w % patch_w == 0'Img dimensions can be divisible by the patch dimensions'

        num_patches = (img_h // patch_h) * (img_w // patch_w)

        patch_size = channels * patch_h * patch_w

        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_h, p2=patch_w),
            nn.Linear(patch_size, embed_size, bias=False)
        )


        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_size))
        self.cls_token = nn.Parameter(torch.randn(11, embed_size))
        self.dropout = nn.Dropout(emb_dropout)



        self.transformer = TransformerEncoder(embed_size,
                                    num_layers,
                                    heads,
                                    mlp_dim,
                                    dropout)
        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, num_classes)
        )

    def forward(self, img):
        x = self.patch_embedding(img)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '() n d ->b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim = 1if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)


if __name__ == '__main__':
    vit = VisionTransformer(
            image_size = 256,
            patch_size = 16,
            num_classes = 10,
            embed_size = 256,
            num_layers = 6,
            heads = 8,
            mlp_dim = 512,
            dropout = 0.1,
            emb_dropout = 0.1
        )
    img = torch.randn(33256256)
    pred = vit(img)
    print(pred)

以下代码是利用 Vision Transformer 网络结构训练一个分类 mnist 数据集的主程序代码。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import VIT
import os


def train():
    batch_size = 4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    epoches = 20
    mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= batch_size, shuffle=True)
    mnist_model = VIT.VisionTransformer(
        image_size = 28,
        patch_size = 7,
        num_classes = 10,
        channels = 1,
        embed_size = 512,
        num_layers = 1,
        heads = 2,
        mlp_dim =1024,
        dropout = 0,
        emb_dropout = 0)
    loss_fn = nn.CrossEntropyLoss()
    mnist_model = mnist_model.to(device)
    opitimizer = optim.Adam(mnist_model.parameters(), lr=0.00001)
    mnist_model.train()
    for epoch in range(epoches):
        total_loss = 0
        corrects = 0
        num = 0
        for batch_X, batch_Y in train_loader:
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
            opitimizer.zero_grad()
            outputs = mnist_model(batch_X)
            _, pred = torch.max(outputs.data, 1)
            loss = loss_fn(outputs, batch_Y)
            loss.backward()
            opitimizer.step()
            total_loss += loss.item()
            corrects = torch.sum(pred == batch_Y.data)
            num += batch_size
            print(epoch, total_loss/float(num), corrects.item()/float(batch_size))

if __name__ == '__main__':
    train()

训练的过程如下所示,可以发现损失函数可以稳定下降。但是训练一个 Vision Transformer 模型真的是很烧硬件,跟训练一个普通的 CNN 模型相比,训练一个 Vision Transformer 模型更加耗时耗力。



特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。



更多阅读




#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



·

登录查看更多
4

相关内容

Transformer是谷歌发表的论文《Attention Is All You Need》提出一种完全基于Attention的翻译架构

知识荟萃

精品入门和进阶教程、论文和代码整理等

更多

查看相关VIP内容、论文、资讯等
专知会员服务
22+阅读 · 2021年9月20日
专知会员服务
29+阅读 · 2021年7月30日
专知会员服务
60+阅读 · 2021年2月16日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
Transformer性能优化:运算和显存
PaperWeekly
1+阅读 · 2022年3月29日
当可变形注意力机制引入Vision Transformer
极市平台
1+阅读 · 2022年1月23日
矩阵视角下的Transformer详解(附代码)
PaperWeekly
3+阅读 · 2022年1月4日
用Attention玩转CV,一文总览自注意力语义分割进展
入门 | 什么是自注意力机制?
机器之心
17+阅读 · 2018年8月19日
Attention is All You Need | 每周一起读
PaperWeekly
10+阅读 · 2017年6月28日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
8+阅读 · 2013年12月31日
国家自然科学基金
3+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Arxiv
2+阅读 · 2022年4月19日
Arxiv
3+阅读 · 2022年4月19日
Arxiv
1+阅读 · 2022年4月18日
Arxiv
1+阅读 · 2022年4月15日
Arxiv
17+阅读 · 2021年3月29日
Heterogeneous Graph Transformer
Arxiv
27+阅读 · 2020年3月3日
Arxiv
27+阅读 · 2017年12月6日
VIP会员
相关VIP内容
专知会员服务
22+阅读 · 2021年9月20日
专知会员服务
29+阅读 · 2021年7月30日
专知会员服务
60+阅读 · 2021年2月16日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
相关资讯
Transformer性能优化:运算和显存
PaperWeekly
1+阅读 · 2022年3月29日
当可变形注意力机制引入Vision Transformer
极市平台
1+阅读 · 2022年1月23日
矩阵视角下的Transformer详解(附代码)
PaperWeekly
3+阅读 · 2022年1月4日
用Attention玩转CV,一文总览自注意力语义分割进展
入门 | 什么是自注意力机制?
机器之心
17+阅读 · 2018年8月19日
Attention is All You Need | 每周一起读
PaperWeekly
10+阅读 · 2017年6月28日
相关基金
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
8+阅读 · 2013年12月31日
国家自然科学基金
3+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
相关论文
Arxiv
2+阅读 · 2022年4月19日
Arxiv
3+阅读 · 2022年4月19日
Arxiv
1+阅读 · 2022年4月18日
Arxiv
1+阅读 · 2022年4月15日
Arxiv
17+阅读 · 2021年3月29日
Heterogeneous Graph Transformer
Arxiv
27+阅读 · 2020年3月3日
Arxiv
27+阅读 · 2017年12月6日
Top
微信扫码咨询专知VIP会员