©PaperWeekly 原创 · 作者 | 孙裕道
单位 | 北京邮电大学博士生
研究方向 | GAN图像生成、情绪对抗样本生成
引言
Transformer 在 NLP 中大获成功,Vision Transformer 则将 Transformer 模型架构扩展到计算机视觉的领域中,并且它可以很好的地取代卷积操作,在不依赖卷积的情况下,依然可以在图像分类任务上达到很好的效果。卷积操作只能考虑到局部的特征信息,而 Transformer 中的注意力机制可以综合考量全局的特征信息。
SAGAN 在 GAN 的框架中利用自注意力机制来捕获图像特征的长距离依赖关系,使得合成的图像中考量了所有的图像特征信息。SAGAN 中自注意力机制的操作原理如上图所示。
AttnGAN 通过利用注意力机制来实现多阶段细颗粒度的文本到图像的生成,它可以通过关注自然语言中的一些重要单词来对图像的不同子区域进行合成。比如通过文本“一只鸟有黄色的羽毛和黑色的眼睛”来生成图像时,会对关键词“鸟”,“羽毛”,“眼睛”,“黄色”,“黑色”给予不同的生成权重,并根据这些关键词的引导在图像的不同的子区域中进行细节的丰富。AttnGAN 中注意力机制的操作原理如上图所示。
Vision Transformer
本节主要详细介绍 Vision Transformer 的工作原理,3.1 节是关于 Vision Transformer 的整体框架,3.2 节是关于 Transformer Encoder 的内部操作细节。对于 Transformer Encoder 中 Multi-Head Attention 的原理本文不会赘述,具体想了解的可以参考上一篇文章《矩阵视角下的 Transformer 详解(附代码)》中相关原理的介绍。
如果下图所示为 Vision Transformer 的整体框架以及相应的训练流程。
如下图所示分别为 Vision Transformer Encoder 模型结构图和原始 Transformer Encoder 的模型结构图。可以直观的发现 Vision Transformer Encoder 和 Transformer Encoder 都有层归一化,多头注意力机制,残差连接和线性变换这四个操作,只是在操作顺序有所不同。在以下的 Transformer 代码实例中,将以下两种 Encoder 网络结构都进行了实现,可以发现两种网络结构都可以进行很好的训练。
Vision Transformer 的代码示例如下所示。该代码是由上一篇《矩阵视角下的Transformer详解(附代码)》的代码的基础上改编而来。Vision Transformer 的作者的本意就是想让在 NLP 中的 Transformer 模型架构做尽可能少的修改可以直接迁移到 CV 中,所以以下程序尽可能保持作者的愿意,并在代码实现了两种Encoder 的网络结构,即 3.2 节图片所示的两个网络结构,一种是最原始的Encoder 网络结构,一种是 Vision Transformer。论文里的 Encoder 的网络结构。
import torch
import torch.nn as nn
import os
from einops import rearrange
from einops import repeat
from einops.layers.torch import Rearrange
def inputs_deal(inputs):
return inputs if isinstance(inputs, tuple) else(inputs, inputs)
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query):
N =query.shape[0]
value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]
# split embedding into self.heads pieces
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(queries)
energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
# queries shape: (N, query_len, heads, heads_dim)
# keys shape : (N, key_len, heads, heads_dim)
# energy shape: (N, heads, query_len, key_len)
attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)
out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
# attention shape: (N, heads, query_len, key_len)
# values shape: (N, value_len, heads, heads_dim)
# (N, query_len, heads, head_dim)
out = self.fc_out(out)
return out
class TransformerBlock(nn.Module):
def __init__(self, embed_size, heads, dropout, forward_expansion):
super(TransformerBlock, self).__init__()
self.attention = SelfAttention(embed_size, heads)
self.norm = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion*embed_size),
nn.ReLU(),
nn.Linear(forward_expansion*embed_size, embed_size)
)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, x, type_mode):
if type_mode == 'original':
attention = self.attention(value, key, query)
x = self.dropout(self.norm(attention + x))
forward = self.feed_forward(x)
out = self.dropout(self.norm(forward + x))
return out
else:
attention = self.attention(self.norm(value), self.norm(key), self.norm(query))
x =self.dropout(attention + x)
forward = self.feed_forward(self.norm(x))
out = self.dropout(forward + x)
return out
class TransformerEncoder(nn.Module):
def __init__(
self,
embed_size,
num_layers,
heads,
forward_expansion,
dropout = 0,
type_mode = 'original'
):
super(TransformerEncoder, self).__init__()
self.embed_size = embed_size
self.type_mode = type_mode
self.Query_Key_Value = nn.Linear(embed_size, embed_size * 3, bias = False)
self.layers = nn.ModuleList(
[
TransformerBlock(
embed_size,
heads,
dropout=dropout,
forward_expansion=forward_expansion,
)
for _ in range(num_layers)]
)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
for layer in self.layers:
QKV_list = self.Query_Key_Value(x).chunk(3, dim = -1)
x = layer(QKV_list[0], QKV_list[1], QKV_list[2], x, self.type_mode)
return x
class VisionTransformer(nn.Module):
def __init__(self,
image_size,
patch_size,
num_classes,
embed_size,
num_layers,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dropout = 0,
emb_dropout = 0.1,
type_mode = 'vit'):
super(VisionTransformer, self).__init__()
img_h, img_w = inputs_deal(image_size)
patch_h, patch_w = inputs_deal(patch_size)
assert img_h % patch_h == 0 and img_w % patch_w == 0, 'Img dimensions can be divisible by the patch dimensions'
num_patches = (img_h // patch_h) * (img_w // patch_w)
patch_size = channels * patch_h * patch_w
self.patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_h, p2=patch_w),
nn.Linear(patch_size, embed_size, bias=False)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_size))
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = TransformerEncoder(embed_size,
num_layers,
heads,
mlp_dim,
dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(embed_size),
nn.Linear(embed_size, num_classes)
)
def forward(self, img):
x = self.patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d ->b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
if __name__ == '__main__':
vit = VisionTransformer(
image_size = 256,
patch_size = 16,
num_classes = 10,
embed_size = 256,
num_layers = 6,
heads = 8,
mlp_dim = 512,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(3, 3, 256, 256)
pred = vit(img)
print(pred)
以下代码是利用 Vision Transformer 网络结构训练一个分类 mnist 数据集的主程序代码。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import VIT
import os
def train():
batch_size = 4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epoches = 20
mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= batch_size, shuffle=True)
mnist_model = VIT.VisionTransformer(
image_size = 28,
patch_size = 7,
num_classes = 10,
channels = 1,
embed_size = 512,
num_layers = 1,
heads = 2,
mlp_dim =1024,
dropout = 0,
emb_dropout = 0)
loss_fn = nn.CrossEntropyLoss()
mnist_model = mnist_model.to(device)
opitimizer = optim.Adam(mnist_model.parameters(), lr=0.00001)
mnist_model.train()
for epoch in range(epoches):
total_loss = 0
corrects = 0
num = 0
for batch_X, batch_Y in train_loader:
batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
opitimizer.zero_grad()
outputs = mnist_model(batch_X)
_, pred = torch.max(outputs.data, 1)
loss = loss_fn(outputs, batch_Y)
loss.backward()
opitimizer.step()
total_loss += loss.item()
corrects = torch.sum(pred == batch_Y.data)
num += batch_size
print(epoch, total_loss/float(num), corrects.item()/float(batch_size))
if __name__ == '__main__':
train()
训练的过程如下所示,可以发现损失函数可以稳定下降。但是训练一个 Vision Transformer 模型真的是很烧硬件,跟训练一个普通的 CNN 模型相比,训练一个 Vision Transformer 模型更加耗时耗力。
特别鸣谢
感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧