©Author | CW不要無聊的風格
Research interests | Object detection, large-scale pre-trained models
Aided by the rapid gains in hardware, models today can easily overfit one million images and begin to demand hundreds of millions of—often publicly inaccessible—labeled images.
However, progress of autoencoding methods in vision lags behind NLP.
We ask: what makes masked autoencoding different between vision and language?
Driven by this analysis, we present a simple, effective, and scalable form of a masked autoencoder (MAE) for visual representation learning.
With a vanilla ViT-Huge model, we achieve 87.8% accuracy when finetuned on ImageNet-1K. This outperforms all previous results that use only ImageNet-1K data.
Computing the loss only on masked patches differs from traditional denoising autoencoders, which compute the loss on all pixels. This choice is purely empirical:
computing the loss on all pixels leads to a slight drop in accuracy (about 0.5%).
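To make the distinction concrete, here is a minimal sketch (the function names and the boolean mask layout are illustrative assumptions, not part of the original code): the two variants differ only in which patches enter the MSE.

import torch
import torch.nn.functional as F

def masked_only_loss(pred, target, mask):
    # pred, target: (b, n_patches, n_pixels_per_patch); mask: (b, n_patches), True where the patch was masked
    return F.mse_loss(pred[mask], target[mask])

def all_pixel_loss(pred, target):
    # Traditional denoising-autoencoder style: every patch contributes to the loss
    return F.mse_loss(pred, target)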
import torch
import torch.nn as nn
import torch.nn.functional as F

# Note: the Transformer class used below is defined in the ViT implementation later in this post

class MAE(nn.Module):
def __init__(
self, encoder, decoder_dim,
mask_ratio=0.75, decoder_depth=1,
num_decoder_heads=8, decoder_dim_per_head=64
):
super().__init__()
assert 0. < mask_ratio < 1., f'mask ratio must be kept between 0 and 1, got: {mask_ratio}'
        # Encoder (implemented here with ViT)
self.encoder = encoder
self.patch_h, self.patch_w = encoder.patch_h, encoder.patch_w
        # The vanilla ViT has a cls_token, so the second-to-last dimension of its
        # position embedding is the number of patches plus 1 (for the cls_token)
num_patches_plus_cls_token, encoder_dim = encoder.pos_embed.shape[-2:]
# Input channels of encoder patch embedding: patch size**2 x 3
        # This is used as the output dimension of the prediction head, so that every pixel value in a patch can be predicted
num_pixels_per_patch = encoder.patch_embed.weight.size(1)
        # Encoder-Decoder: the encoder output dimension may differ from the decoder input dimension, so a projection is needed
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
# Mask token
        # The paper recommends masking 75% of the patches
self.mask_ratio = mask_ratio
        # The mask token is essentially a single learnable vector shared across all masked positions
self.mask_embed = nn.Parameter(torch.randn(decoder_dim))
        # Decoder: simply a stack of Transformer layers
self.decoder = Transformer(
decoder_dim,
decoder_dim * 4,
depth=decoder_depth,
num_heads=num_decoder_heads,
dim_per_head=decoder_dim_per_head,
)
        # Position embeddings for the mask tokens in the decoder
        # Note: the cls_token is excluded, hence the -1 on the first dimension
self.decoder_pos_embed = nn.Embedding(num_patches_plus_cls_token - 1, decoder_dim)
        # Prediction head: its output dimension equals the number of pixel values in one patch
self.head = nn.Linear(decoder_dim, num_pixels_per_patch)
    def forward(self, x):
        device = x.device
        b, c, h, w = x.shape

        num_patches = (h // self.patch_h) * (w // self.patch_w)
# (b, c=3, h, w)->(b, n_patches, patch_size**2 * c)
patches = x.view(
b, c,
h // self.patch_h, self.patch_h,
w // self.patch_w, self.patch_w
).permute(0, 2, 4, 3, 5, 1).reshape(b, num_patches, -1)
        # Compute the number of patches to mask according to the mask ratio
num_masked = int(self.mask_ratio * num_patches)
        # Shuffle: generate a random permutation of patch indices
        # torch.rand() samples from a uniform (not normal) distribution;
        # it only supplies random scores, and argsort() turns them into a random permutation
# (b, n_patches)
shuffle_indices = torch.rand(b, num_patches, device=device).argsort()
        # Indices of the masked and unmasked patches
mask_ind, unmask_ind = shuffle_indices[:, :num_masked], shuffle_indices[:, num_masked:]
        # Indices along the batch dimension: (b, 1)
batch_ind = torch.arange(b, device=device).unsqueeze(-1)
        # Use the generated indices to split the patches into masked and unmasked groups
mask_patches, unmask_patches = patches[batch_ind, mask_ind], patches[batch_ind, unmask_ind]
        # Convert the patches into tokens via the patch embedding
unmask_tokens = self.encoder.patch_embed(unmask_patches)
        # Add position embeddings to the tokens
        # Note: indices are shifted by +1 because index 0 corresponds to the ViT cls_token
unmask_tokens += self.encoder.pos_embed.repeat(b, 1, 1)[batch_ind, unmask_ind + 1]
        # The actual encoding step
encoded_tokens = self.encoder.transformer(unmask_tokens)
        # Project the encoded tokens to the input dimension expected by the decoder
enc_to_dec_tokens = self.enc_to_dec(encoded_tokens)
        # There is only a single mask token, so expand it to match the masked patches one-to-one
# (decoder_dim)->(b, n_masked, decoder_dim)
mask_tokens = self.mask_embed[None, None, :].repeat(b, num_masked, 1)
        # Add positional information to the mask tokens
mask_tokens += self.decoder_pos_embed(mask_ind)
        # Concatenate the mask tokens with the encoded tokens
# (b, n_patches, decoder_dim)
concat_tokens = torch.cat([mask_tokens, enc_to_dec_tokens], dim=1)
        # Un-shuffle: restore the original patch order
dec_input_tokens = torch.empty_like(concat_tokens, device=device)
dec_input_tokens[batch_ind, shuffle_indices] = concat_tokens
        # Feed the full set of tokens to the decoder
decoded_tokens = self.decoder(dec_input_tokens)
        # Take out the decoded mask tokens
dec_mask_tokens = decoded_tokens[batch_ind, mask_ind, :]
        # Predict the pixel values of the masked patches
# (b, n_masked, n_pixels_per_patch=patch_size**2 x c)
pred_mask_pixel_values = self.head(dec_mask_tokens)
        # Compute the reconstruction loss on the masked patches only
        loss = F.mse_loss(pred_mask_pixel_values, mask_patches)

        return loss
    @torch.no_grad()
def predict(self, x):
self.eval()
device = x.device
b, c, h, w = x.shape
'''i. Patch partition'''
num_patches = (h // self.patch_h) * (w // self.patch_w)
# (b, c=3, h, w)->(b, n_patches, patch_size**2*c)
patches = x.view(
b, c,
h // self.patch_h, self.patch_h,
w // self.patch_w, self.patch_w
).permute(0, 2, 4, 3, 5, 1).reshape(b, num_patches, -1)
'''ii. Divide into masked & un-masked groups'''
num_masked = int(self.mask_ratio * num_patches)
# Shuffle
# (b, n_patches)
shuffle_indices = torch.rand(b, num_patches, device=device).argsort()
mask_ind, unmask_ind = shuffle_indices[:, :num_masked], shuffle_indices[:, num_masked:]
# (b, 1)
batch_ind = torch.arange(b, device=device).unsqueeze(-1)
mask_patches, unmask_patches = patches[batch_ind, mask_ind], patches[batch_ind, unmask_ind]
'''iii. Encode'''
unmask_tokens = self.encoder.patch_embed(unmask_patches)
# Add position embeddings
unmask_tokens += self.encoder.pos_embed.repeat(b, 1, 1)[batch_ind, unmask_ind + 1]
encoded_tokens = self.encoder.transformer(unmask_tokens)
'''iv. Decode'''
enc_to_dec_tokens = self.enc_to_dec(encoded_tokens)
# (decoder_dim)->(b, n_masked, decoder_dim)
mask_tokens = self.mask_embed[None, None, :].repeat(b, num_masked, 1)
# Add position embeddings
mask_tokens += self.decoder_pos_embed(mask_ind)
# (b, n_patches, decoder_dim)
concat_tokens = torch.cat([mask_tokens, enc_to_dec_tokens], dim=1)
# dec_input_tokens = concat_tokens
dec_input_tokens = torch.empty_like(concat_tokens, device=device)
# Un-shuffle
dec_input_tokens[batch_ind, shuffle_indices] = concat_tokens
decoded_tokens = self.decoder(dec_input_tokens)
'''v. Mask pixel Prediction'''
dec_mask_tokens = decoded_tokens[batch_ind, mask_ind, :]
# (b, n_masked, n_pixels_per_patch=patch_size**2 x c)
pred_mask_pixel_values = self.head(dec_mask_tokens)
        # Compare the predicted pixel values with the ground truth
        mse_per_patch = (pred_mask_pixel_values - mask_patches).pow(2).mean(dim=-1)
        mse_all_patches = mse_per_patch.mean()
        print(f'mse per (masked) patch: {mse_per_patch}  mse over all (masked) patches: {mse_all_patches}  total masked patches: {num_masked}')
print(f'all close: {torch.allclose(pred_mask_pixel_values, mask_patches, rtol=1e-1, atol=1e-1)}')
'''vi. Reconstruction'''
        # Clone so that filling in the predictions below does not modify the original patches in place
        recons_patches = patches.detach().clone()
# Un-shuffle (b, n_patches, patch_size**2 * c)
recons_patches[batch_ind, mask_ind] = pred_mask_pixel_values
        # The model's reconstructed image
# Reshape back to image
# (b, n_patches, patch_size**2 * c)->(b, c, h, w)
recons_img = recons_patches.view(
b, h // self.patch_h, w // self.patch_w,
self.patch_h, self.patch_w, c
).permute(0, 5, 1, 3, 2, 4).reshape(b, c, h, w)
        # Masked-image visualization: replace the masked patches with random noise
        mask_patches = torch.randn_like(mask_patches, device=mask_patches.device)
patches[batch_ind, mask_ind] = mask_patches
patches_to_img = patches.view(
b, h // self.patch_h, w // self.patch_w,
self.patch_h, self.patch_w, c
).permute(0, 5, 1, 3, 2, 4).reshape(b, c, h, w)
return recons_img, patches_to_img
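For context, here is a minimal training-step sketch that exercises the forward pass above; the data loader, device choice, and optimizer hyperparameters are illustrative assumptions, not part of the original post:

# Hypothetical setup: `dataloader` yields image batches of shape (b, 3, 224, 224)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = ViT((224, 224), (16, 16), dim=512, mlp_dim=1024, dim_per_head=64)
mae = MAE(encoder, decoder_dim=512, decoder_depth=6).to(device)
optimizer = torch.optim.AdamW(mae.parameters(), lr=1.5e-4, weight_decay=0.05)

for imgs in dataloader:
    loss = mae(imgs.to(device))  # MSE computed on the masked patches only
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()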
import os

import torch
from PIL import Image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# BASE_DIR points to the directory holding the input image and the trained weights (adjust to your own setup)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Load the image and resize it to a size suitable for the model input
img_raw = Image.open(os.path.join(BASE_DIR, 'mountain.jpg'))
h, w = img_raw.height, img_raw.width
ratio = h / w
print(f"image hxw: {h} x {w} mode: {img_raw.mode}")
img_size, patch_size = (224, 224), (16, 16)
img = img_raw.resize(img_size)
rh, rw = img.height, img.width
print(f'resized image hxw: {rh} x {rw} mode: {img.mode}')
img.save(os.path.join(BASE_DIR, 'resized_mountain.jpg'))
# Convert the image to a tensor
from torchvision.transforms import ToTensor, ToPILImage
img_ts = ToTensor()(img).unsqueeze(0).to(device)
print(f"input tensor shape: {img_ts.shape} dtype: {img_ts.dtype} device: {img_ts.device}")
# Instantiate the model and load the trained weights
encoder = ViT(img_size, patch_size, dim=512, mlp_dim=1024, dim_per_head=64)
decoder_dim = 512
mae = MAE(encoder, decoder_dim, decoder_depth=6)
weight = torch.load(os.path.join(BASE_DIR, 'mae.pth'), map_location='cpu')
mae.load_state_dict(weight)
mae.to(device)
# Inference: get the reconstructed image and the masked-input visualization
recons_img_ts, masked_img_ts = mae.predict(img_ts)
recons_img_ts, masked_img_ts = recons_img_ts.cpu().squeeze(0), masked_img_ts.cpu().squeeze(0)
# Save the results so they can be compared with the original image
recons_img = ToPILImage()(recons_img_ts)
recons_img.save(os.path.join(BASE_DIR, 'recons_mountain.jpg'))
masked_img = ToPILImage()(masked_img_ts)
masked_img.save(os.path.join(BASE_DIR, 'masked_mountain.jpg'))
import torch
import torch.nn as nn
def to_pair(t):
return t if isinstance(t, tuple) else (t, t)
class PreNorm(nn.Module):
def __init__(self, dim, net):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.net = net
def forward(self, x, **kwargs):
return self.net(self.norm(x), **kwargs)
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads=8, dim_per_head=64, dropout=0.):
super().__init__()
self.num_heads = num_heads
self.scale = dim_per_head ** -0.5
inner_dim = dim_per_head * num_heads
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.attend = nn.Softmax(dim=-1)
project_out = not (num_heads == 1 and dim_per_head == dim)
self.out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
b, l, d = x.shape
'''i. QKV projection'''
# (b,l,dim_all_heads x 3)
qkv = self.to_qkv(x)
# (3,b,num_heads,l,dim_per_head)
qkv = qkv.view(b, l, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).contiguous()
# 3 x (1,b,num_heads,l,dim_per_head)
q, k, v = qkv.chunk(3)
q, k, v = q.squeeze(0), k.squeeze(0), v.squeeze(0)
'''ii. Attention computation'''
attn = self.attend(
torch.matmul(q, k.transpose(-1, -2)) * self.scale
)
'''iii. Put attention on Value & reshape'''
# (b,num_heads,l,dim_per_head)
z = torch.matmul(attn, v)
# (b,num_heads,l,dim_per_head)->(b,l,num_heads,dim_per_head)->(b,l,dim_all_heads)
z = z.transpose(1, 2).reshape(b, l, -1)
# assert z.size(-1) == q.size(-1) * self.num_heads
'''iv. Project out'''
# (b,l,dim_all_heads)->(b,l,dim)
out = self.out(z)
# assert out.size(-1) == d
return out
class FFN(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(p=dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(p=dropout)
)
def forward(self, x):
return self.net(x)
class Transformer(nn.Module):
def __init__(self, dim, mlp_dim, depth=6, num_heads=8, dim_per_head=64, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, SelfAttention(dim, num_heads=num_heads, dim_per_head=dim_per_head, dropout=dropout)),
PreNorm(dim, FFN(dim, mlp_dim, dropout=dropout))
]))
def forward(self, x):
for norm_attn, norm_ffn in self.layers:
x = x + norm_attn(x)
x = x + norm_ffn(x)
return x
class ViT(nn.Module):
def __init__(
self, image_size, patch_size,
num_classes=1000, dim=1024, depth=6, num_heads=8, mlp_dim=2048,
pool='cls', channels=3, dim_per_head=64, dropout=0., embed_dropout=0.
):
super().__init__()
img_h, img_w = to_pair(image_size)
self.patch_h, self.patch_w = to_pair(patch_size)
assert not img_h % self.patch_h and not img_w % self.patch_w, \
f'Image dimensions ({img_h},{img_w}) must be divisible by the patch size ({self.patch_h},{self.patch_w}).'
num_patches = (img_h // self.patch_h) * (img_w // self.patch_w)
assert pool in {'cls', 'mean'}, f'pool type must be either cls (cls token) or mean (mean pooling), got: {pool}'
patch_dim = channels * self.patch_h * self.patch_w
self.patch_embed = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# Add 1 for cls_token
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.dropout = nn.Dropout(p=embed_dropout)
self.transformer = Transformer(
dim, mlp_dim, depth=depth, num_heads=num_heads,
dim_per_head=dim_per_head, dropout=dropout
)
self.pool = pool
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, x):
b, c, img_h, img_w = x.shape
assert not img_h % self.patch_h and not img_w % self.patch_w, \
f'Input image dimensions ({img_h},{img_w}) must be divisible by the patch size ({self.patch_h},{self.patch_w}).'
'''i. Patch partition'''
num_patches = (img_h // self.patch_h) * (img_w // self.patch_w)
# (b,c,h,w)->(b,n_patches,patch_h*patch_w*c)
patches = x.view(
b, c,
img_h // self.patch_h, self.patch_h,
img_w // self.patch_w, self.patch_w
).permute(0, 2, 4, 3, 5, 1).reshape(b, num_patches, -1)
'''ii. Patch embedding'''
# (b,n_patches,dim)
tokens = self.patch_embed(patches)
# (b,n_patches+1,dim)
tokens = torch.cat([self.cls_token.repeat(b, 1, 1), tokens], dim=1)
tokens += self.pos_embed[:, :(num_patches + 1)]
tokens = self.dropout(tokens)
'''iii. Transformer Encoding'''
enc_tokens = self.transformer(tokens)
'''iv. Pooling'''
# (b,dim)
pooled = enc_tokens[:, 0] if self.pool == 'cls' else enc_tokens.mean(dim=1)
'''v. Classification'''
# (b,n_classes)
logits = self.mlp_head(pooled)
return logits
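A quick sanity check of the two building blocks above (the dimensions and batch sizes are arbitrary, for illustration only): the Transformer preserves the token shape, which is exactly what lets the MAE decoder process a mix of encoded and mask tokens, and the ViT maps an image batch to class logits.

# Uses torch and the Transformer / ViT classes defined above
blocks = Transformer(dim=512, mlp_dim=2048, depth=2, num_heads=8, dim_per_head=64)
tokens = torch.randn(2, 196, 512)            # (batch, num_tokens, dim)
assert blocks(tokens).shape == tokens.shape  # shape is preserved through every layer

vit = ViT((224, 224), (16, 16), num_classes=1000, dim=512, mlp_dim=1024, dim_per_head=64)
logits = vit(torch.randn(4, 3, 224, 224))
print(logits.shape)  # torch.Size([4, 1000])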