搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(二十一)

2022 年 1 月 14 日 极市平台
↑ 点击 蓝字  关注极市平台

作者丨科技猛兽
编辑丨极市平台

极市导读

 

本文详解了来自华为诺亚方舟实验室的工作-Pyramid TNT,本文通过引入两种方法来改进TNT(Transformer-in-Transformer)基线:1)金字塔架构,和 2)卷积stem,以创建新的PyramidTNT。>>加入极市CV技术交流群,走在计算机视觉的最前沿

本文目录

41 Pyramid TNT:使用金字塔结构改进的 TNT Baseline
(来自北京华为诺亚方舟实验室)
40.1 TNT 回顾
41.2 Pyramid TNT 原理分析
41.3 Pyramid TNT 代码解读

Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。

Transformer in Transformer 针对 ViT 处理图片的方式:将输入图片划分成一个个块 (patch) ,然后针对将这些 patch 看成一个序列 (Sequence) 的不完美之处,提出了一种 TNT 架构,它不仅考虑 patch 之间的信息,还考虑每个 patch 的内部信息,使得 Transformer 模型分别对整体和局部信息进行建模,提升性能。

TNT 架构没有使用 PVT 提出的 Transformer 模型金字塔结构,而金字塔结构在大多数 Vision Transformer 和 MLP 模型上都被证明了有很好的建模性能,所以 Pyramid TNT 作为 TNT 的 Extended Version,进一步验证了金字塔结构对于 TNT Backbone 的作用。

41 Pyramid TNT:使用金字塔结构改进的 TNT Baseline

论文名称:PyramidTNT: Improved Transformer-in-Transformer Baselines with Pyramid Architecture

TNT 论文地址:

https://arxiv.org/pdf/2103.00112.pdf

Pyramid TNT 论文地址:

https://arxiv.org/pdf/2201.00978.pdf

40.1 TNT 回顾:

Vision Transformer 超详细解读 (原理分析+代码解读) (四)

Transformer 需要的是序列 (Sequence)的输入信号,而我们有的是 image 这种 2D 的输入信号,那直接把图片分块以后进行 Flatten 操作是一种很直觉的处理方式。但是,这种intuitive的方法能不能够完美地建模图像,因为我们缺少了一部分非常重要的信息,即:每个patch的内部信息

TNT 认为,每个输入的内部信息,即每个 patch 的内部信息,没有被 Transformer 所建模。是一个欠考虑的因素。所以 TNT 使得 Transformer 模型既建模那些不同 patch 之间的关系,也要建模每个 patch 内部的关系。

图1:Transformer in Transformer

第1步还是将输入图片划分成  个块 (patch):

式中  是每个块的大小。ViT,DeiT,IPT,SETR,ViT-FRCNN 到这里就把它们输入 Transformer了,TNT 为了更好地学习图片中 global 和 local 信息的关系,还要再进行一步。在 TNT 中,作者将 patch 视为表示图像的视觉 "sentence"。每个 patch 进一步分成  个子块,即一个 "sentence" 由一系列视觉 "words" 组成。

式中,  代表第  个视觉 "sentence" 的第  个视觉 "words",这一步其实是把每个 patch 通过PyTorch 的 unfold 操作划分成更小的 patch,之后把这些小 patch 通过线性投影展平,就得到了:

其中,  是第  个视觉 "words" 的 Embedding,  代表 Embedding dimension。

如下图1所示,输入是一个大 patch,输出的黄色大长条是这个 patch 展平以后的 sentence embedding,输出的彩色小长条是这个 patch 划分成更小的 patch 之后再展平以后的 word embedding。

图2:对patch进行unfold操作后得到更小的patch

图2的操作进行完之后就得到了大 patch 的 sentence embedding 以及小 patch 的 word embedding。接下来把它们送入Transformer的Block里面建模特征,如下图2所示。Transformer 是由很多 TNT Blocks 组成的,每个 TNT Block 包含2个 Transformer Block,分别是:

图3:Transformer in Transformer
  • Outer block 建模 sentence embedding 之间的 global relationship。
  • Inner block 建模 word embedding 之间的 local structure information。

这两种 Block 对应2个不同的数据流,其中 Outer block 的数据流在不同 patch 之间运行,而 Inner block 的数据流在每个 patch 内部运行。

Inner Transformer:

定义  ,我们把这个值传入 Inner Transformer  ,则有:

注意正常的 Transformer 的输入应该是  的张量,这里  代表 batch size,  代表序列长度,  代表hidden dimension。不考虑 batch size 这一维,就是一个  的矩阵,也可以看做是  个  维向量,那么对于 Inner Transformer  来讲,这里的  。也就是说,Inner Transformer 的输入是  个  维的向量。注意这里的  就是这  个向量的其中一个。所以 Inner Transformer 的第  个 layer 的输出就可以写为:

Inner Transformer  建模的是更细小的像素级别的 relationship,例如,在一张人脸中,属于眼睛的像素与眼睛的其他像素更相关,而与前额像素的 relationship 较少。

Outer Transformer:

Outer Transformer  就相当于是 ViT 中的 Transformer,它建模的是更答大的 patch 级别的 relationship,输入的 patch embedding 使用 ViT 类似的做法,添加  ,它们初始化为0。

定义  为第  个layer的第  个向量,则 Outer Transformer 的表达式为:

那么现在既有 Outer Transformer 的第  个 layer 的输出向量:

也有 Inner Transformer 的第  个 layer 的输出向量:

下面的问题是:要如何把它们结合起来,以融合 global 和 local 的信息呢?

作者采用的方式是:

式中,  代表 Flatten 操作,  代表权重。

通过这种方式,把第  个 layer 的第  个 sentence embedding 向量和第  个 word embedding 向量融合起来,即对应图2的结构。

总的来说,TNT Block 第  个 layer 的输入和输出可以表示为:

在 TNT Block 中,Inner Transformer 建模 word embedding 之间的 local structure information 之间的关系,而 Outer block 建模 sentence embedding 之间的 global relationship。通过将 TNT Block 堆叠  次,作者构建了 Transformer in Transformer。最后,使用一个分类头对图像进行分类。

位置编码:

位置编码的作用是让像素间保持空间位置关系,对于图像就是保持二维信息,它对于图像识别任务来讲很重要。具体来说,就需要对 sentence embedding 和 word embedding 分别设计一种位置编码。

  • sentence positional encoding:

作者这里使用的是可学习的1D位置编码:

式中,  是给 sentence embedding 使用的位置编码,它用来编码全局空间信息 (global spatial information)。

  • word positional encoding:

作者这里使用的是可学习的1D位置编码:

式中,  是给 word embedding 使用的位置编码,它们用来编码局部相对信息 (local relative information)。

40.2 Pyramid TNT 原理分析:

TNT 作为一种通用的视觉任务 Backbone,取得了优异的性能。Pyramid TNT 受到 Transformer 模型两种主流改进方法:金字塔架构 (PVT,Swin Transformer,CycleMLP 等等)卷积 stem (Convolutional Stem) 的启发,改进了 TNT 架构。

Pyramid TNT 将它们融入 TNT 中,金字塔架构 (Pyramid Structure) 用于提取多尺度信息,卷积 stem (Convolutional Stem) 用于改善图片分块的方法和使得训练过程更加稳定。此外,Pyramid TNT 还包括其他一些 trick 比如相对位置编码等。

图4:Pyramid TNT

Convolutional Stem

给定输入图片  ,ViT 的做法是通过一个  的卷积进行图片的分块操作。Early convolutions help transformers see better (NeurIPS 2021) 这篇论文发现:将这个卷积操作替换成几个连续的卷积操作能够使得 Transformer 模型获得更好的性能,且对优化器更加鲁棒,增加了优化稳定性。基于这个发现,作者也对 TNT 模型应用了 Convolutional Stem。

具体而言 Pyramid TNT 的 Convolutional Stem 是5个 3×3 卷积。对于 Outer Transformer,Convolutional Stem 将输入图片变成  ,式中  是 sentence embedding 维度。对于 Inner Transformer,Convolutional Stem 将输入图片变成  ,式中  是 word embedding 的维度。对于位置编码,sentence positional encoding 和 word positional encoding 被分别添加在了 sentence embedding 和 word embedding 上。

Pyramid Architecture

[原始的 TNT 网络]:

在原始 TNT 中,在每个 Block 中保持相同数量的 tokens,遵循 ViT 的设计方式。视觉 "sentence" 和视觉 "words" 的数量自下而上一直保持不变。

视觉 "sentence" 的特征图分辨率自下而上一直是 

视觉 "words" 的特征图分辨率自下而上一直是 

[Pyramid TNT 网络]:

在 Pyramid TNT 中,网络在每个 stage 中保持不同数量的 tokens,遵循 PVT 的设计方式。视觉 "sentence" 和视觉 "words" 的数量自下而上分阶段变化。

视觉 "words" 的特征图分辨率在4个 stage 中分别是: 

视觉 "sentence" 的特征图分辨率在4个 stage 中分别是: 

通过 Convolution Stem,把 224×224 的输入图片分成 8×8 的大 patch,一共是 28×28 个。所以 Outer Transformer 特征图的分辨率是:

通过 Convolution Stem,把 8×8 的大 patch 分成 2×2 的小 patch,一共是 4×4×28×28 个。所以 Inner Transformer 特征图的分辨率是:

不同 stage 之间通过一个  的卷积操作降低特征分辨率。注意不同的 stage 的 Outer Transformer,即视觉 "sentence" 的特征图分辨率  是大小一直变化的,而不同的 stage 的 Inner Transformer视觉 "words" 的特征图分辨率一直是 

实验结果

分类任务实验结果

数据集:ImageNet-1k (1,280,000 Training data, 50,000 validation data,1000 classes)

超参数设置:

图5:Pyramid TNT ImageNet 实验超参数设置

实验结果如下图5所示。与原始 TNT 相比,Pyramid TNT 获得了更好的性能。Pyramid TNT-S比 TNT-S 少 1.9B 计算量,精度提高了0.5%。作者还将 Pyramid TNT 与其他有代表性的 CNN、MLP 和基于 Transformer 的模型进行了比较。从结果中,我们可以看到 Pyramid TNT 是最先进的视觉 Backbone。

图6:ImageNet 分类任务实验结果

目标检测实验结果

数据集: COCO 2017 (118,000 Training data, 50,000 validation data)

对比的框架: RetinaNet,Mask R-CNN

超参数: Batch size=2,AdamW Optimizer,initial lr=1e-4,在第8和第11个 Epoch 分别乘以0.1,weight decay=0.05,"1x" schedule (12 epochs),输入图片 resize 成 (1333, 800)。

金字塔的四个阶段的空间分辨率被设置为:  。作者使用了  的转置卷积和 BN 和 GeLU 激活函数加上  的卷积和 BN 和 GeLU 激活函数,以产生  的分辨率的特征图。

图7:目标检测实验结果

在具有相似计算成本的 one-stage 和 two-stage 的检测器上,Pyramid-S 明显优于其他 Backbone。例如,基于 Pyramid-S 的 RetinaNet 达到了42.0  和 57.7  。这些结果表明,金字塔结构有助于捕获更好的全局信息。

实例分割实验结果

数据集: COCO 2017 (118,000 Training data, 50,000 validation data)

对比的框架: Mask R-CNN,Cascade Mask R-CNN

超参数: Batch size=16,AdamW Optimizer,initial lr=1e-4,在第27和第33个 Epoch 分别乘以0.1,weight decay=0.05,"3x" schedule,输入图片 resize 成 (1333, 800)。

图7:实例分割实验结果

Pyramid-S 在 Mask R-CNN 和Cascade Mask R-CNN 上可以获得比其他 Backbone 好得多的  和  ,显示出其更好的特征表示能力。例如,Pyramid-S 在 Mask R-CNN 上 Wave-MLP 高出0.9 的 

40.3 Pyramid TNT 代码解读:

代码来自:

https://github.com/huawei-noah/CV-Backbones/tree/master/tnt_pytorch

一些张量的维度的大小已经在代码中以注释的形式进行标注。

Convolutional Stem:

class Stem(nn.Module):
""" Image to Visual Word Embedding
Overlap: https://arxiv.org/pdf/2106.13797.pdf
"""
def __init__(self, img_size=224, in_chans=3, outer_dim=768, inner_dim=24):
super().__init__()
img_size = to_2tuple(img_size)
self.img_size = img_size
self.inner_dim = inner_dim
self.num_patches = img_size[0] // 8 * img_size[1] // 8
self.num_words = 16

self.common_conv = nn.Sequential(
nn.Conv2d(in_chans, inner_dim*2, 3, stride=2, padding=1),
nn.BatchNorm2d(inner_dim*2),
nn.ReLU(inplace=True),
)
self.inner_convs = nn.Sequential(
nn.Conv2d(inner_dim*2, inner_dim, 3, stride=1, padding=1),
nn.BatchNorm2d(inner_dim),
nn.ReLU(inplace=False),
)
self.outer_convs = nn.Sequential(
nn.Conv2d(inner_dim*2, inner_dim*4, 3, stride=2, padding=1),
nn.BatchNorm2d(inner_dim*4),
nn.ReLU(inplace=True),
nn.Conv2d(inner_dim*4, inner_dim*8, 3, stride=2, padding=1),
nn.BatchNorm2d(inner_dim*8),
nn.ReLU(inplace=True),
nn.Conv2d(inner_dim*8, outer_dim, 3, stride=1, padding=1),
nn.BatchNorm2d(outer_dim),
nn.ReLU(inplace=False),
)

self.unfold = nn.Unfold(kernel_size=4, padding=0, stride=4)

def forward(self, x):
B, C, H, W = x.shape
H_out, W_out = H // 8, W // 8
H_in, W_in = 4, 4
x = self.common_conv(x)
# inner_tokens
# inner_tokens: (B, inner_dim, H/2, W/2)
inner_tokens = self.inner_convs(x) # B, C, H, W
# inner_tokens: (B, H/8, W/8, inner_dim*4*4)
inner_tokens = self.unfold(inner_tokens).transpose(1, 2) # B, N, Ck2
# inner_tokens: (B, inner_dim, H/8*W/8, 4*4)
inner_tokens = inner_tokens.reshape(B * H_out * W_out, self.inner_dim, H_in*W_in).transpose(1, 2) # B*N, C, 4*4
# outer_tokens
# outer_tokens: (B, outer_dim, H/8, W/8)
outer_tokens = self.outer_convs(x) # B, C, H_out, W_out
# outer_tokens: (B, H/8*W/8, outer_dim)
outer_tokens = outer_tokens.permute(0, 2, 3, 1).reshape(B, H_out * W_out, -1)
return inner_tokens, outer_tokens, (H_out, W_out), (H_in, W_in)

注意 Convolution Stem 返回的 inner_tokens 和 outer_tokens 张量的维度:
outer_tokens: (B, H/8*W/8, outer_dim)
inner_tokens: (B, inner_dim, H/8*W/8, 4*4)
Convolution Stem 返回的 inner_tokens 和 outer_tokens 分别通过后面 Block 类的 Inner Attention 和 Outer Attention,二者输出的张量维度分别是:(B*H/8*W/8, 4*4, inner_dim) 和 (B, H/8*W/8, outer_dim)。之后,这两个张量再按照上式10中的方式融合在一起。
其实 Pyramid Transformer in Transformer 代码的核心是通过这个 Convolution Stem 分别得到两个不同维度的张量,一个输入 Outer Transformer Block,一个输入 Inner Transformer Block。这两个 Transformer Block 的输出再按照上式10中的方式融合在一起。

MLP:

class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)

def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x

Attention 类 (这里作者用了 PVT V2 的轻量 attention 类的实现):

class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5

self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.pool = nn.AvgPool2d(sr_ratio, stride=sr_ratio)
self.linear = nn.Linear(dim, dim)
self.norm = nn.LayerNorm(dim)

def forward(self, x, H, W, relative_pos=None):
B, N, C = x.shape
# q: (B, nH, N, C/nH)
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

if self.sr_ratio > 1:
# x_: (B, C, H, W)
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
# x_: (B, N/4, C)
x_ = self.pool(x_).reshape(B, C, -1).permute(0, 2, 1)
# x_: (B, N/4, C)
x_ = self.norm(self.linear(x_))
# x_: (2, B, nH, N/4, C/nH)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# k,v: (B, nH, N/4, C/nH)
k, v = kv[0], kv[1]

# attn: (B, nH, N, N/4)
attn = (q @ k.transpose(-2, -1)) * self.scale
if relative_pos is not None:
attn += relative_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

# X: (B, N, C)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x

一个 Pyramid TNT Block 的实现:

class Block(nn.Module):
""" TNT Block
"""
def __init__(self, outer_dim, inner_dim, outer_head, inner_head, num_words, mlp_ratio=4.,
qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
norm_layer=nn.LayerNorm, se=0, sr_ratio=1):
super().__init__()
self.has_inner = inner_dim > 0
if self.has_inner:
# Inner
self.inner_norm1 = norm_layer(num_words * inner_dim)
self.inner_attn = Attention(
inner_dim, num_heads=inner_head, qkv_bias=qkv_bias,
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.inner_norm2 = norm_layer(num_words * inner_dim)
self.inner_mlp = Mlp(in_features=inner_dim, hidden_features=int(inner_dim * mlp_ratio),
out_features=inner_dim, act_layer=act_layer, drop=drop)

self.proj_norm1 = norm_layer(num_words * inner_dim)
self.proj = nn.Linear(num_words * inner_dim, outer_dim, bias=False)
self.proj_norm2 = norm_layer(outer_dim)
# Outer
self.outer_norm1 = norm_layer(outer_dim)
self.outer_attn = Attention(
outer_dim, num_heads=outer_head, qkv_bias=qkv_bias,
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.outer_norm2 = norm_layer(outer_dim)
self.outer_mlp = Mlp(in_features=outer_dim, hidden_features=int(outer_dim * mlp_ratio),
out_features=outer_dim, act_layer=act_layer, drop=drop)
# SE
self.se = se
self.se_layer = None
if self.se > 0:
self.se_layer = SE(outer_dim, 0.25)

def forward(self, x, outer_tokens, H_out, W_out, H_in, W_in, relative_pos):
# outer_tokens: (B, H/8*W/8, outer_dim)
B, N, C = outer_tokens.size()
if self.has_inner:
# x: (B*H/8*W/8, 4*4, inner_dim)
x = x + self.drop_path(self.inner_attn(self.inner_norm1(x.reshape(B, N, -1)).reshape(B*N, H_in*W_in, -1), H_in, W_in)) # B*N, k*k, c
# x: (B*H/8*W/8, 4*4, inner_dim)
x = x + self.drop_path(self.inner_mlp(self.inner_norm2(x.reshape(B, N, -1)).reshape(B*N, H_in*W_in, -1))) # B*N, k*k, c
# outer_tokens: (B, H/8*W/8, outer_dim)
outer_tokens = outer_tokens + self.proj_norm2(self.proj(self.proj_norm1(x.reshape(B, N, -1)))) # B, N, C
if self.se > 0:
outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens), H_out, W_out, relative_pos))
tmp_ = self.outer_mlp(self.outer_norm2(outer_tokens))
outer_tokens = outer_tokens + self.drop_path(tmp_ + self.se_layer(tmp_))
else:
# outer_tokens: (B, H/8*W/8, outer_dim)
outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens), H_out, W_out, relative_pos))
# outer_tokens: (B, H/8*W/8, outer_dim)
outer_tokens = outer_tokens + self.drop_path(self.outer_mlp(self.outer_norm2(outer_tokens)))
# x: (B*H/8*W/8, 4*4, inner_dim)
# outer_tokens: (B, H/8*W/8, outer_dim)
return x, outer_tokens

和 TNT 基本一致,不同之处是前向函数中还需要传入 H_out, W_out, H_in, W_in, relative_pos 这些参数,它们分别代表大 patch 和小 patch 的特征分辨率大小。

一个 Pyramid TNT Stage 的实现:

class Stage(nn.Module):
""" PyramidTNT stage
"""
def __init__(self, num_blocks, outer_dim, inner_dim, outer_head, inner_head, num_patches, num_words, mlp_ratio=4.,
qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
norm_layer=nn.LayerNorm, se=0, sr_ratio=1):
super().__init__()
blocks = []
drop_path = drop_path if isinstance(drop_path, list) else [drop_path] * num_blocks

for j in range(num_blocks):
if j == 0:
_inner_dim = inner_dim
elif j == 1 and num_blocks > 6:
_inner_dim = inner_dim
else:
_inner_dim = -1
blocks.append(Block(
outer_dim, _inner_dim, outer_head=outer_head, inner_head=inner_head,
num_words=num_words, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop,
attn_drop=attn_drop, drop_path=drop_path[j], act_layer=act_layer, norm_layer=norm_layer,
se=se, sr_ratio=sr_ratio))

self.blocks = nn.ModuleList(blocks)
self.relative_pos = nn.Parameter(torch.randn(
1, outer_head, num_patches, num_patches // sr_ratio // sr_ratio))

def forward(self, inner_tokens, outer_tokens, H_out, W_out, H_in, W_in):
for blk in self.blocks:
inner_tokens, outer_tokens = blk(inner_tokens, outer_tokens, H_out, W_out, H_in, W_in, self.relative_pos)
return inner_tokens, outer_tokens

不同的 stage 之间应有下采样的操作。"sentence" level 和 "word" level 的下采样分别通过下面的 SentenceAggregation 类和 WordAggregation 类来解决:

class SentenceAggregation(nn.Module):
""" Sentence Aggregation
"""
def __init__(self, dim_in, dim_out, stride=2, act_layer=nn.GELU):
super().__init__()
self.stride = stride
self.norm = nn.LayerNorm(dim_in)
self.conv = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size=2*stride-1, padding=stride-1, stride=stride),
)

def forward(self, x, H, W):
B, N, C = x.shape # B, N, C
x = self.norm(x)
x = x.transpose(1, 2).reshape(B, C, H, W)
x = self.conv(x)
H, W = math.ceil(H / self.stride), math.ceil(W / self.stride)
x = x.reshape(B, -1, H * W).transpose(1, 2)
return x, H, W

class WordAggregation(nn.Module):
""" Word Aggregation
"""
def __init__(self, dim_in, dim_out, stride=2, act_layer=nn.GELU):
super().__init__()
self.stride = stride
self.dim_out = dim_out
self.norm = nn.LayerNorm(dim_in)
self.conv = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size=2*stride-1, padding=stride-1, stride=stride),
)

def forward(self, x, H_out, W_out, H_in, W_in):
B_N, M, C = x.shape # B*N, M, C
x = self.norm(x)
x = x.reshape(-1, H_out, W_out, H_in, W_in, C)

# padding to fit (1333, 800) in detection.
pad_input = (H_out % 2 == 1) or (W_out % 2 == 1)
if pad_input:
x = F.pad(x.permute(0, 3, 4, 5, 1, 2), (0, W_out % 2, 0, H_out % 2))
x = x.permute(0, 4, 5, 1, 2, 3)
# patch merge
x1 = x[:, 0::2, 0::2, :, :, :] # B, H/2, W/2, H_in, W_in, C
x2 = x[:, 1::2, 0::2, :, :, :]
x3 = x[:, 0::2, 1::2, :, :, :]
x4 = x[:, 1::2, 1::2, :, :, :]
x = torch.cat([torch.cat([x1, x2], 3), torch.cat([x3, x4], 3)], 4) # B, H/2, W/2, 2*H_in, 2*W_in, C
x = x.reshape(-1, 2*H_in, 2*W_in, C).permute(0, 3, 1, 2) # B_N/4, C, 2*H_in, 2*W_in
x = self.conv(x) # B_N/4, C, H_in, W_in
x = x.reshape(-1, self.dim_out, M).transpose(1, 2)
return x

我们可以发现 "sentence" level 和 "word" level 的下采样都是通过一个卷积操作完成。

第1个 stage 结束后的下采样:
SentenceAggregation 输入维度:  输出维度: 
WordAggregation 输入维度: 输出维度: 

第2个 stage 结束后的下采样:
SentenceAggregation 输入维度:  输出维度: 
WordAggregation 输入维度: 输出维度: 

第3个 stage 结束后的下采样:
SentenceAggregation 输入维度:  输出维度: 
WordAggregation 输入维度: 输出维度: 

Pyramid TNT 整体模型架构:

class PyramidTNT(nn.Module):
""" PyramidTNT (Transformer in Transformer) for computer vision
"""
def __init__(self, configs=None, img_size=224, in_chans=3, num_classes=1000, mlp_ratio=4., qkv_bias=False,
qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, se=0):
super().__init__()
self.num_classes = num_classes
depths = configs['depths']
outer_dims = configs['outer_dims']
inner_dims = configs['inner_dims']
outer_heads = configs['outer_heads']
inner_heads = configs['inner_heads']
sr_ratios = [4, 2, 1, 1]
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
self.num_features = outer_dims[-1] # num_features for consistency with other models

self.patch_embed = Stem(
img_size=img_size, in_chans=in_chans, outer_dim=outer_dims[0], inner_dim=inner_dims[0])
num_patches = self.patch_embed.num_patches
num_words = self.patch_embed.num_words

self.outer_pos = nn.Parameter(torch.zeros(1, num_patches, outer_dims[0]))
self.inner_pos = nn.Parameter(torch.zeros(1, num_words, inner_dims[0]))
self.pos_drop = nn.Dropout(p=drop_rate)

depth = 0
self.word_merges = nn.ModuleList([])
self.sentence_merges = nn.ModuleList([])
self.stages = nn.ModuleList([])
for i in range(4):
if i > 0:
self.word_merges.append(WordAggregation(inner_dims[i-1], inner_dims[i], stride=2))
self.sentence_merges.append(SentenceAggregation(outer_dims[i-1], outer_dims[i], stride=2))
self.stages.append(Stage(depths[i], outer_dim=outer_dims[i], inner_dim=inner_dims[i],
outer_head=outer_heads[i], inner_head=inner_heads[i],
num_patches=num_patches // (2 ** i) // (2 ** i), num_words=num_words, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[depth:depth+depths[i]], norm_layer=norm_layer, se=se, sr_ratio=sr_ratios[i])
)
depth += depths[i]

self.norm = norm_layer(outer_dims[-1])

# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
# self.repr = nn.Linear(outer_dim, representation_size)
# self.repr_act = nn.Tanh()

# Classifier head
self.head = nn.Linear(outer_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

trunc_normal_(self.outer_pos, std=.02)
trunc_normal_(self.inner_pos, std=.02)
self.apply(self._init_weights)

def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

@torch.jit.ignore
def no_weight_decay(self):
return {'outer_pos', 'inner_pos'}

def get_classifier(self):
return self.head

def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

def forward_features(self, x):
inner_tokens, outer_tokens, (H_out, W_out), (H_in, W_in) = self.patch_embed(x)
inner_tokens += self.inner_pos # B*N, 8*8, C
outer_tokens += self.pos_drop(self.outer_pos) # B, N, D

for i in range(4):
if i > 0:
inner_tokens = self.word_merges[i-1](inner_tokens, H_out, W_out, H_in, W_in)
outer_tokens, H_out, W_out = self.sentence_merges[i-1](outer_tokens, H_out, W_out)
inner_tokens, outer_tokens = self.stages[i](inner_tokens, outer_tokens, H_out, W_out, H_in, W_in)

outer_tokens = self.norm(outer_tokens)
return outer_tokens.mean(dim=1)

def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x

不同大小的 Pyramid TNT 配置信息:

@register_model
def ptnt_ti_patch16_192(pretrained=False, **kwargs):
outer_dim = 80
inner_dim = 5
outer_head = 2
inner_head = 1
configs = {
'depths': [2, 6, 3, 2],
'outer_dims': [outer_dim, outer_dim*2, outer_dim*4, outer_dim*4],
'inner_dims': [inner_dim, inner_dim*2, inner_dim*4, inner_dim*4],
'outer_heads': [outer_head, outer_head*2, outer_head*4, outer_head*4],
'inner_heads': [inner_head, inner_head*2, inner_head*4, inner_head*4],
}

model = PyramidTNT(configs=configs, img_size=192, qkv_bias=False, **kwargs)
model.default_cfg = default_cfgs['tnt_s_patch16_192']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
return model

小结

本文介绍了 Pyramid TNT 架构的原理和 PyTorch 代码实现。TNT 作为一种通用的视觉任务 Backbone,取得了优异的性能。Pyramid TNT 受到 Transformer 模型两种主流改进方法:金字塔架构 (PVT,Swin Transformer,CycleMLP 等等) 和卷积 stem (Convolutional Stem) 的启发,改进了 TNT 架构。Pyramid TNT 将它们融入 TNT 中,金字塔架构 (Pyramid Structure) 用于提取多尺度信息,卷积 stem (Convolutional Stem) 用于改善图片分块的方法和使得训练过程更加稳定。此外,Pyramid TNT 还包括其他一些 trick 比如相对位置编码等。

如果觉得有用,就请分享到朋友圈吧!

△点击卡片关注极市平台,获取 最新CV干货

公众号后台回复“transformer”获取最新Transformer综述论文下载~


极市干货
课程/比赛: 珠港澳人工智能算法大赛 保姆级零基础人工智能教程
算法trick 目标检测比赛中的tricks集锦 从39个kaggle竞赛中总结出来的图像分割的Tips和Tricks
技术综述: 一文弄懂各种loss function 工业图像异常检测最新研究总结(2019-2020)


极市平台签约作者#


科技猛兽

知乎:科技猛兽


清华大学自动化系19级硕士

研究领域:AI边缘计算 (Efficient AI with Tiny Resource):专注模型压缩,搜索,量化,加速,加法网络,以及它们与其他任务的结合,更好地服务于端侧设备。


作品精选

搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了
用Pytorch轻松实现28个视觉Transformer,开源库 timm 了解一下!(附代码解读)
轻量高效!清华智能计算实验室开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur



投稿方式:
添加小编微信Fengcall(微信号:fengcall19),备注:姓名-投稿
△长按添加极市平台小编

觉得有用麻烦给个在看啦~   
登录查看更多
0

相关内容

Pyramid is a small, fast, down-to-earth Python web application development framework.
ICLR 2022 | BEIT论文解读:将MLM无监督预训练应用到CV领域
专知会员服务
32+阅读 · 2022年3月24日
【ICLR2022】Vision Transformer 模型工作机制的最新理论
专知会员服务
42+阅读 · 2022年2月19日
Transformer如何用于视频?最新「视频Transformer」2022综述
专知会员服务
75+阅读 · 2022年1月20日
专知会员服务
29+阅读 · 2021年7月30日
专知会员服务
65+阅读 · 2021年7月21日
【文本分类大综述:从浅层到深度学习,35页pdf】
专知会员服务
184+阅读 · 2020年8月6日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
当可变形注意力机制引入Vision Transformer
极市平台
1+阅读 · 2022年1月23日
计算机视觉中的transformer模型创新思路总结
极市平台
0+阅读 · 2021年12月4日
国家自然科学基金
6+阅读 · 2017年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
1+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
国家自然科学基金
1+阅读 · 2009年12月31日
国家自然科学基金
1+阅读 · 2008年12月31日
Arxiv
3+阅读 · 2022年4月19日
Arxiv
0+阅读 · 2022年4月15日
Arxiv
32+阅读 · 2022年2月15日
Arxiv
102+阅读 · 2021年6月8日
Arxiv
19+阅读 · 2021年4月8日
Arxiv
17+阅读 · 2021年3月29日
Arxiv
19+阅读 · 2020年12月23日
VIP会员
相关VIP内容
ICLR 2022 | BEIT论文解读:将MLM无监督预训练应用到CV领域
专知会员服务
32+阅读 · 2022年3月24日
【ICLR2022】Vision Transformer 模型工作机制的最新理论
专知会员服务
42+阅读 · 2022年2月19日
Transformer如何用于视频?最新「视频Transformer」2022综述
专知会员服务
75+阅读 · 2022年1月20日
专知会员服务
29+阅读 · 2021年7月30日
专知会员服务
65+阅读 · 2021年7月21日
【文本分类大综述:从浅层到深度学习,35页pdf】
专知会员服务
184+阅读 · 2020年8月6日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
相关基金
国家自然科学基金
6+阅读 · 2017年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
1+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
国家自然科学基金
1+阅读 · 2009年12月31日
国家自然科学基金
1+阅读 · 2008年12月31日
相关论文
Arxiv
3+阅读 · 2022年4月19日
Arxiv
0+阅读 · 2022年4月15日
Arxiv
32+阅读 · 2022年2月15日
Arxiv
102+阅读 · 2021年6月8日
Arxiv
19+阅读 · 2021年4月8日
Arxiv
17+阅读 · 2021年3月29日
Arxiv
19+阅读 · 2020年12月23日
Top
微信扫码咨询专知VIP会员