Jishi Editor's Note
This article gives you a feel for how "all-purpose" einsum is. Through usage examples ranging from basic to advanced, the author shows how einsum implements a wide variety of tensor operations both concisely and elegantly, and how it makes dimension matching effortless.
C = torch.einsum("ik,kj->ij", A, B)
import torch

A = torch.tensor([[1, 2], [3, 4.0]])
B = torch.tensor([[5, 6], [7, 8.0]])

C1 = A @ B
print(C1)

C2 = torch.einsum("ik,kj->ij", A, B)
print(C2)
tensor([[19., 22.],
        [43., 50.]])
tensor([[19., 22.],
        [43., 50.]])
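To make the subscript rule concrete, here is a minimal loop-based sketch (my own illustration, not part of the original code) of what "ik,kj->ij" means: indices kept on the right of "->" become output axes, while an index that appears only on the left (here k) is summed over.

import torch

A = torch.tensor([[1, 2], [3, 4.0]])
B = torch.tensor([[5, 6], [7, 8.0]])

# Explicit-loop equivalent of einsum("ik,kj->ij", A, B):
# i and j survive as output axes; k disappears, so it is summed over.
C = torch.zeros(2, 2)
for i in range(2):
    for j in range(2):
        for k in range(2):
            C[i, j] += A[i, k] * B[k, j]
print(C)  # matches the einsum / matmul result above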
# Example 1: tensor transpose (permute axes)
A = torch.randn(3, 4, 5)
# B = torch.permute(A, [0, 2, 1])
B = torch.einsum("ijk->ikj", A)
print("before:", A.shape)
print("after:", B.shape)

before: torch.Size([3, 4, 5])
after: torch.Size([3, 5, 4])
# Example 2: extract the diagonal
A = torch.randn(5, 5)
# B = torch.diagonal(A)
B = torch.einsum("ii->i", A)
print("before:", A.shape)
print("after:", B.shape)

before: torch.Size([5, 5])
after: torch.Size([5])
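A closely related pattern worth noting (my addition as an illustration): if the repeated index is dropped from the output entirely, einsum sums the diagonal, which gives the trace.

import torch

A = torch.randn(5, 5)

trace_einsum = torch.einsum("ii->", A)  # sum of the diagonal elements
print(torch.allclose(trace_einsum, torch.trace(A)))  # expected: True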
# Example 3: sum over a dimension (reduction)
A = torch.randn(4, 5)
# B = torch.sum(A, 1)
B = torch.einsum("ij->i", A)
print("before:", A.shape)
print("after:", B.shape)

before: torch.Size([4, 5])
after: torch.Size([4])
# Example 4: Hadamard (element-wise) product
A = torch.randn(5, 5)
B = torch.randn(5, 5)
# C = A * B
C = torch.einsum("ij,ij->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

before: torch.Size([5, 5]) torch.Size([5, 5])
after: torch.Size([5, 5])
# Example 5: vector inner product (dot product)
A = torch.randn(10)
B = torch.randn(10)
# C = torch.dot(A, B)
C = torch.einsum("i,i->", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

before: torch.Size([10]) torch.Size([10])
after: torch.Size([])
# Example 6: vector outer product
A = torch.randn(10)
B = torch.randn(5)
# C = torch.outer(A, B)
C = torch.einsum("i,j->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

before: torch.Size([10]) torch.Size([5])
after: torch.Size([10, 5])
# Example 7: matrix multiplication
A = torch.randn(5, 4)
B = torch.randn(4, 6)
# C = torch.matmul(A, B)
C = torch.einsum("ik,kj->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

before: torch.Size([5, 4]) torch.Size([4, 6])
after: torch.Size([5, 6])
# Example 8: tensor contraction
A = torch.randn(3, 4, 5)
B = torch.randn(4, 3, 6)
# C = torch.tensordot(A, B, dims=[(0, 1), (1, 0)])
C = torch.einsum("ijk,jih->kh", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

before: torch.Size([3, 4, 5]) torch.Size([4, 3, 6])
after: torch.Size([5, 6])
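As a quick sanity check (my addition, using the same shapes as above), the einsum contraction can be compared numerically against the torch.tensordot call shown in the comment:

import torch

A = torch.randn(3, 4, 5)
B = torch.randn(4, 3, 6)

C_einsum = torch.einsum("ijk,jih->kh", A, B)
C_tensordot = torch.tensordot(A, B, dims=[(0, 1), (1, 0)])
print(torch.allclose(C_einsum, C_tensordot, atol=1e-6))  # expected: True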
# Example 9: bilinear attention
# ==== without a batch dimension ====
q = torch.randn(10)         # query_features
k = torch.randn(10)         # key_features
W = torch.randn(5, 10, 10)  # out_features, query_features, key_features
b = torch.randn(5)          # out_features

# a = q @ W @ k.t() + b
a = torch.bilinear(q, k, W, b)
print("a.shape:", a.shape)

# ==== with a batch dimension ====
Q = torch.randn(8, 10)      # batch_size, query_features
K = torch.randn(8, 10)      # batch_size, key_features
W = torch.randn(5, 10, 10)  # out_features, query_features, key_features
b = torch.randn(5)          # out_features

# A = torch.bilinear(Q, K, W, b)
A = torch.einsum('bq,oqk,bk->bo', Q, W, K) + b
print("A.shape:", A.shape)

a.shape: torch.Size([5])
A.shape: torch.Size([8, 5])
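As another sanity check (my addition, reusing the shapes above), the batched einsum formulation can be compared against torch.nn.functional.bilinear, which computes the same x1ᵀ·W·x2 + b form:

import torch
import torch.nn.functional as F

Q = torch.randn(8, 10)
K = torch.randn(8, 10)
W = torch.randn(5, 10, 10)
b = torch.randn(5)

A_einsum = torch.einsum('bq,oqk,bk->bo', Q, W, K) + b
A_builtin = F.bilinear(Q, K, W, b)
print(torch.allclose(A_einsum, A_builtin, atol=1e-5))  # expected: True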
# Example 10: scaled dot-product attention
# ==== without a batch dimension ====
q = torch.randn(10)     # query_features
k = torch.randn(6, 10)  # key_size, key_features
d_k = k.shape[-1]
a = torch.softmax(q @ k.t() / d_k**0.5, -1)  # scale by sqrt(d_k)
print("a.shape=", a.shape)

# ==== with a batch dimension ====
Q = torch.randn(8, 10)     # batch_size, query_features
K = torch.randn(8, 6, 10)  # batch_size, key_size, key_features
d_k = K.shape[-1]
A = torch.softmax(torch.einsum("in,ijn->ij", Q, K) / d_k**0.5, -1)  # scale by sqrt(d_k)
print("A.shape=", A.shape)

a.shape= torch.Size([6])
A.shape= torch.Size([8, 6])
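The example above only produces the attention weights. As a small extension sketch (the value tensor V and its feature size 16 are assumptions of mine, not part of the original), one more einsum applies those weights to the values:

import torch

Q = torch.randn(8, 10)     # batch_size, query_features
K = torch.randn(8, 6, 10)  # batch_size, key_size, key_features
V = torch.randn(8, 6, 16)  # batch_size, key_size, value_features (assumed)
d_k = K.shape[-1]

A = torch.softmax(torch.einsum("in,ijn->ij", Q, K) / d_k**0.5, -1)  # (8, 6)
out = torch.einsum("ij,ijv->iv", A, V)  # weighted sum over the key axis
print(out.shape)  # torch.Size([8, 16])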