极市 Editor's Note
This article will give you a feel for just how "all-purpose" einsum is. Through usage examples ranging from basic to advanced, the author shows how it implements a wide range of tensor operations both concisely and elegantly, and how it takes the headache out of matching dimensions.
For example, matrix multiplication (C[i, j] is the sum over k of A[i, k] * B[k, j]) becomes a single call:
C = torch.einsum("ik,kj->ij", A, B)
import torch
A = torch.tensor([[1, 2], [3, 4.0]])
B = torch.tensor([[5, 6], [7, 8.0]])
C1 = A @ B
print(C1)
C2 = torch.einsum("ik,kj->ij", [A, B])
print(C2)
tensor([[19., 22.],
        [43., 50.]])
tensor([[19., 22.],
        [43., 50.]])
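The rule behind the string "ik,kj->ij": an index that appears on the input side but not after the arrow (here k) is summed over, while the remaining indices label the output. A small sketch of my own, spelling the same computation out with explicit loops (C3 is a name introduced only for this check), reproduces the result:
# Explicit-loop version of "ik,kj->ij": k is summed out, i and j index the output
C3 = torch.zeros(2, 2)
for i in range(2):
    for j in range(2):
        for k in range(2):
            C3[i, j] += A[i, k] * B[k, j]
print(torch.allclose(C2, C3))  # expected: True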
# Example 1: tensor transpose (permute dimensions)
A = torch.randn(3, 4, 5)
# B = torch.permute(A, [0, 2, 1])
B = torch.einsum("ijk->ikj", A)
print("before:", A.shape)
print("after:", B.shape)
before: torch.Size([3, 4, 5])
after: torch.Size([3, 5, 4])
# Example 2: extract the diagonal
A = torch.randn(5, 5)
# B = torch.diagonal(A)
B = torch.einsum("ii->i", A)
print("before:", A.shape)
print("after:", B.shape)
before: torch.Size([5, 5])
after: torch.Size([5])
# Example 3: sum-reduction along a dimension
A = torch.randn(4, 5)
# B = torch.sum(A, 1)
B = torch.einsum("ij->i", A)
print("before:", A.shape)
print("after:", B.shape)
before: torch.Size([4, 5])
after: torch.Size([4])
# Example 4: Hadamard (element-wise) product
A = torch.randn(5, 5)
B = torch.randn(5, 5)
# C = A * B
C = torch.einsum("ij,ij->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)
before: torch.Size([5, 5]) torch.Size([5, 5])
after: torch.Size([5, 5])
# Example 5: vector inner (dot) product
A = torch.randn(10)
B = torch.randn(10)
# C = torch.dot(A, B)
C = torch.einsum("i,i->", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)
before: torch.Size([10]) torch.Size([10])
after: torch.Size([])
# Example 6: vector outer product
A = torch.randn(10)
B = torch.randn(5)
# C = torch.outer(A, B)
C = torch.einsum("i,j->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)
before: torch.Size([10]) torch.Size([5])
after: torch.Size([10, 5])
# Example 7: matrix multiplication
A = torch.randn(5, 4)
B = torch.randn(4, 6)
# C = torch.matmul(A, B)
C = torch.einsum("ik,kj->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)
before: torch.Size([5, 4]) torch.Size([4, 6])
after: torch.Size([5, 6])
# Example 8: tensor contraction
A = torch.randn(3, 4, 5)
B = torch.randn(4, 3, 6)
# C = torch.tensordot(A, B, dims=[(0, 1), (1, 0)])
C = torch.einsum("ijk,jih->kh", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)
before: torch.Size([3, 4, 5]) torch.Size([4, 3, 6])
after: torch.Size([5, 6])
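As a quick sanity check (my addition, not in the original), the einsum contraction can be compared against the commented-out torch.tensordot call; C_ref is just an illustrative name:
C_ref = torch.tensordot(A, B, dims=[(0, 1), (1, 0)])  # pair A's dims (0, 1) with B's dims (1, 0)
print(torch.allclose(C, C_ref))  # expected: True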
# Example 9: bilinear attention
# ==== without the batch dimension ====
q = torch.randn(10)          # query_features
k = torch.randn(10)          # key_features
W = torch.randn(5, 10, 10)   # out_features, query_features, key_features
b = torch.randn(5)           # out_features
# a = q @ W @ k.t() + b
a = torch.bilinear(q, k, W, b)
print("a.shape:", a.shape)
# ==== with the batch dimension ====
Q = torch.randn(8, 10)       # batch_size, query_features
K = torch.randn(8, 10)       # batch_size, key_features
W = torch.randn(5, 10, 10)   # out_features, query_features, key_features
b = torch.randn(5)           # out_features
# A = torch.bilinear(Q, K, W, b)
A = torch.einsum("bq,oqk,bk->bo", Q, W, K) + b
print("A.shape:", A.shape)
a.shape: torch.Size([5])
A.shape: torch.Size([8, 5])
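The commented-out torch.bilinear call handles batched inputs as well, so a quick check (my addition; A_ref is only an illustrative name) confirms that the einsum version matches it:
A_ref = torch.bilinear(Q, K, W, b)           # reference result from the built-in op
print(torch.allclose(A, A_ref, atol=1e-5))   # expected: True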
# Example 10: scaled dot-product attention
# ==== without the batch dimension ====
q = torch.randn(10)      # query_features
k = torch.randn(6, 10)   # key_size, key_features
d_k = k.shape[-1]
a = torch.softmax(q @ k.t() / d_k**0.5, -1)   # scale by sqrt(d_k), per the standard definition
print("a.shape=", a.shape)
# ==== with the batch dimension ====
Q = torch.randn(8, 10)      # batch_size, query_features
K = torch.randn(8, 6, 10)   # batch_size, key_size, key_features
d_k = K.shape[-1]
A = torch.softmax(torch.einsum("in,ijn->ij", Q, K) / d_k**0.5, -1)   # scale by sqrt(d_k)
print("A.shape=", A.shape)
a.shape= torch.Size([6])
A.shape= torch.Size([8, 6])
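For comparison (a sketch of my own, not from the original), the batched attention scores can also be computed with torch.bmm; A_ref is again only an illustrative name:
# K: (batch, key_size, features) @ Q.unsqueeze(-1): (batch, features, 1) -> scores: (batch, key_size)
A_ref = torch.softmax(torch.bmm(K, Q.unsqueeze(-1)).squeeze(-1) / d_k**0.5, -1)
print(torch.allclose(A, A_ref))  # expected: True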