This project is on today's GitHub Trending.
$ pip install vit-pytorch
import torch
from vit_pytorch import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
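To make the roles of the constructor arguments concrete, here is a minimal sketch of what a ViT forward pass does. It splits the image into patches, projects each flattened patch to `dim`, prepends a class token, adds positional embeddings, runs a Transformer encoder and classifies from the class token. `MiniViT` is a hypothetical class built on PyTorch's `nn.TransformerEncoder`, not the vit-pytorch implementation, so details such as normalization placement and patch flattening order may differ.

```python
import torch
from torch import nn


class MiniViT(nn.Module):
    """Hypothetical minimal ViT sketch; not vit-pytorch's actual implementation."""

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth,
                 heads, mlp_dim, channels=3, dropout=0.0, emb_dropout=0.0):
        super().__init__()
        assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        self.patch_size = patch_size
        # Linear projection of each flattened patch to `dim`
        self.to_patch_embedding = nn.Linear(patch_dim, dim)
        # Learnable class token and positional embeddings (one extra slot for the class token)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        # `depth` encoder layers, each with `heads` attention heads and an MLP of width `mlp_dim`
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=mlp_dim, dropout=dropout,
            activation="gelu", batch_first=True, norm_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=depth)
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, img):
        b, c, h, w = img.shape
        p = self.patch_size
        # (b, c, h, w) -> (b, c, h/p, w/p, p, p): cut non-overlapping p x p patches
        patches = img.unfold(2, p, p).unfold(3, p, p)
        # -> (b, num_patches, c * p * p): flatten each patch into one vector
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(b, -1, c * p * p)
        x = self.to_patch_embedding(patches)
        cls = self.cls_token.expand(b, -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embedding
        x = self.dropout(x)
        x = self.transformer(x)
        return self.mlp_head(x[:, 0])  # classify from the class token


model = MiniViT(image_size=64, patch_size=16, num_classes=10,
                dim=64, depth=2, heads=4, mlp_dim=128)
print(model(torch.randn(2, 3, 64, 64)).shape)  # torch.Size([2, 10])
```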
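image_size: image size (side length in pixels)
patch_size: side length of each patch (not the number of patches); image_size must be divisible by patch_size
num_classes: number of classes to classify
dim: last dimension of the output tensor after the linear transformation nn.Linear(..., dim)
depth: number of Transformer blocks
heads: number of heads in each multi-head attention layer
mlp_dim: dimension of the MLP (feed-forward) layer
channels: number of image channels (default 3)
dropout: dropout rate
emb_dropout: embedding dropout rate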
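Because patch_size is the side length of a patch rather than a patch count, the number of tokens the Transformer sees follows from simple arithmetic. The helper below is hypothetical (not part of vit-pytorch) and assumes a square image with a single class token prepended, as in the original ViT; `patch_dim` is the input width of the nn.Linear(..., dim) projection mentioned above.

```python
def vit_token_stats(image_size, patch_size, channels=3):
    """Return (num_patches, seq_len, patch_dim) for a square image (hypothetical helper)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    num_patches = (image_size // patch_size) ** 2  # patches per image
    seq_len = num_patches + 1                      # +1 for the class token
    patch_dim = channels * patch_size ** 2         # length of one flattened patch
    return num_patches, seq_len, patch_dim


# Settings from the example above: 256 x 256 RGB image, 32 x 32 patches
print(vit_token_stats(256, 32))  # (64, 65, 3072)
```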
……
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

# On torchvision >= 0.13, `pretrained` is deprecated in favor of the `weights` argument
teacher = resnet50(pretrained = True)
v = DistillableViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
distiller = DistillWrapper(
student = v,
teacher = teacher,
temperature = 3,  # temperature of distillation
alpha = 0.5,      # trade-off between main loss and distillation loss
hard = False      # whether to use soft or hard distillation
)
img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...
pred = v(img) # (2, 1000)
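In the distillation example, temperature, alpha and hard control how the teacher's predictions enter the loss. The function below is a sketch of a DeiT-style objective matching those arguments, assuming (as in DeiT) that the student emits one set of logits from its class token and a separate set from a distillation token. `distillation_loss` is a hypothetical helper; DistillWrapper's internal implementation may differ in details.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, distill_logits, teacher_logits, labels,
                      temperature=3.0, alpha=0.5, hard=False):
    """Hypothetical DeiT-style distillation objective (sketch)."""
    # Ordinary supervised loss: class-token logits vs. ground-truth labels
    main_loss = F.cross_entropy(student_logits, labels)
    if hard:
        # Hard distillation: the teacher's argmax prediction serves as a label
        teacher_labels = teacher_logits.argmax(dim=-1)
        distill = F.cross_entropy(distill_logits, teacher_labels)
    else:
        # Soft distillation: KL divergence between temperature-softened distributions,
        # scaled by T^2 so gradient magnitudes stay comparable across temperatures
        T = temperature
        distill = F.kl_div(
            F.log_softmax(distill_logits / T, dim=-1),
            F.softmax(teacher_logits / T, dim=-1).detach(),
            reduction="batchmean",
        ) * (T ** 2)
    # alpha trades off the main loss against the distillation loss
    return (1 - alpha) * main_loss + alpha * distill


# Example with random logits for a batch of 2 and 1000 classes
logits = torch.randn(2, 1000)
loss = distillation_loss(logits, torch.randn(2, 1000), torch.randn(2, 1000),
                         torch.randint(0, 1000, (2,)))
```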
To reprint this article, please contact this official account for authorization.
For submissions or coverage requests: content@jiqizhixin.com