einsum is all you needed!

2022 年 7 月 27 日 极市平台
↑ 点击 蓝字  关注极市平台

作者丨吃货本货@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/542625230
编辑丨极市平台

极市导读

 

本文将带你感受einsum的“万能”,作者通过提供从基础到高级的einsum使用范例,展示了它是怎么做到既简洁又优雅地实现多种张量操作,并轻易解决维度匹配问题。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

引入

如果问pytorch中最强大的一个数学函数是什么?

我会说是torch.einsum:爱因斯坦求和函数。它几乎是一个"万能函数":能实现超过一万种功能的函数。
不仅如此,和其它pytorch中的函数一样,torch.einsum是支持求导和反向传播的,并且计算效率非常高。
einsum 提供了一套既简洁又优雅的规则,可实现包括但不限于:内积,外积,矩阵乘法,转置和张量收缩(tensor contraction)等张量操作,熟练掌握 einsum 可以很方便的实现复杂的张量操作,而且不容易出错。
尤其是在一些包括batch维度的高阶张量的相关计算中,若使用普通的矩阵乘法、求和、转置等算子来实现很容易出现维度匹配等问题,但换成einsum则会特别简单。
套用一句深度学习paper标题当中非常时髦的话术, einsum is all you needed !
本文源码路径:
https://github.com/lyhue1991/eat_pytorch_in_20_days/blob/master/4-2,%E5%BC%A0%E9%87%8F%E7%9A%84%E6%95%B0%E5%AD%A6%E8%BF%90%E7%AE%97.md

一,einsum规则原理

顾名思义,einsum这个函数的思想起源于家喻户晓的小爱同学:爱因斯坦。
很久很久以前,小爱同学在捣鼓广义相对论。广义相对论表述各种物理量用的都是张量。比如描述时空有一个四维时空度规张量,描述电磁场有一个电磁张量,描述运动的有能量动量张量。
在理论物理学家中,小爱同学的数学基础不算特别好,在捣鼓这些张量的时候,他遇到了一个比较头疼的问题:公式太长太复杂了。有没有什么办法让这些张量运算公式稍微显得对人类友好一些呢,能不能减少一些那种扭曲的求和符号呢?
小爱发现,求和导致维度收缩,因此求和符号操作的指标总是只出现在公式的一边。 例如在我们熟悉的矩阵乘法中
k这个下标被求和了,求和导致了这个维度的消失,所以它只出现在右边而不出现在左边。 这种只出现在张量公式的一边的下标被称之为哑指标,反之为自由指标。
小爱同学脑瓜子滴溜一转,反正这种只出现在一边的哑指标一定是被求和求掉的,干脆把对应的求和符号省略得了。
这就是爱因斯坦求和约定:
只出现在公式一边的指标叫做哑指标,针对哑指标的求和符号可以省略。
公式立刻清爽了很多。
这个公式表达的含义如下:
C这个张量的第i行第 列由 这个张量的第 行第 列和 这个张量的第 行第j列相乘,这样得到的是 一个三维张量 ,其元素为 ,然后对 在维度 上求和得到。
公式展现形式中除了省去了求和符号,还省去了乘法符号(代数通识)。
借鉴爱因斯坦求和约定表达张量运算的清爽整洁,numpy、tensorflow和 torch等库中都引入了 einsum这个函数。
上述矩阵乘法可以被einsum这个函数表述成
  
  
    
C = torch.einsum( "ik,kj->ij",A,B)
这个函数的规则原理非常简洁,3句话说明白。
  • 1,用元素计算公式来表达张量运算。
  • 2,只出现在元素计算公式箭头左边的指标叫做哑指标。
  • 3,省略元素计算公式中对哑指标的求和符号。
  
  
    
import torch 

A = torch.tensor([[ 1, 2],[ 3, 4.0]])
B = torch.tensor([[ 5, 6],[ 7, 8.0]])

C1 = A@B
print(C1)

C2 = torch.einsum( "ik,kj->ij",[A,B])
print(C2)

tensor([[ 19.22.],
        [ 43.50.]])
tensor([[ 19.22.],
        [ 43.50.]])

二,einsum基础范例

einsum这个函数的精髓实际上是第一条: 用元素计算公式来表达张量运算。
而绝大部分张量运算都可以用元素计算公式很方便地来表达,这也是它为什么会那么神通广大。

例1,张量转置

  
  
    
#例1,张量转置
A = torch.randn( 3, 4, 5)

#B = torch.permute(A,[0,2,1])
B = torch.einsum( "ijk->ikj",A) 

print( "before:",A.shape)
print( "after:",B.shape)

before: torch.Size([ 345])
after: torch.Size([ 354])

例2,取对角元

  
  
    
#例2,取对角元
A = torch.randn( 5, 5)
#B = torch.diagonal(A)
B = torch.einsum( "ii->i",A)
print( "before:",A.shape)
print( "after:",B.shape)

before: torch.Size([ 55])
after: torch.Size([ 5])

例3,求和降维

  
  
    
#例3,求和降维
A = torch.randn( 4, 5)
#B = torch.sum(A,1)
B = torch.einsum( "ij->i",A)
print( "before:",A.shape)
print( "after:",B.shape)

before: torch.Size([ 45])
after: torch.Size([ 4])

例4,哈达玛积

  
  
    
#例4,哈达玛积
A = torch.randn( 5, 5)
B = torch.randn( 5, 5)
#C=A*B
C = torch.einsum( "ij,ij->ij",A,B)
print( "before:",A.shape, B.shape)
print( "after:",C.shape)

before: torch.Size([ 55]) torch.Size([ 55])
after: torch.Size([ 55])

例5,向量内积

  
  
    
#例5,向量内积
A = torch.randn( 10)
B = torch.randn( 10)
#C=torch.dot(A,B)
C = torch.einsum( "i,i->",A,B)
print( "before:",A.shape, B.shape)
print( "after:",C.shape)

before: torch.Size([ 10]) torch.Size([ 10])
after: torch.Size([])

例6,向量外积

  
  
    
#例6,向量外积
A = torch.randn( 10)
B = torch.randn( 5)
#C = torch.outer(A,B)
C = torch.einsum( "i,j->ij",A,B)
print( "before:",A.shape, B.shape)
print( "after:",C.shape)

before: torch.Size([ 10]) torch.Size([ 5])
after: torch.Size([ 105])

例7,矩阵乘法

  
  
    
#例7,矩阵乘法
A = torch.randn( 5, 4)
B = torch.randn( 4, 6)
#C = torch.matmul(A,B)
C = torch.einsum( "ik,kj->ij",A,B)
print( "before:",A.shape, B.shape)
print( "after:",C.shape)

before: torch.Size([ 54]) torch.Size([ 46])
after: torch.Size([ 56])

例8,张量缩并

  
  
    
#例8,张量缩并
A = torch.randn( 3, 4, 5)
B = torch.randn( 4, 3, 6)
#C = torch.tensordot(A,B,dims=[(0,1),(1,0)])
C = torch.einsum( "ijk,jih->kh",A,B)
print( "before:",A.shape, B.shape)
print( "after:",C.shape)

before: torch.Size([ 345]) torch.Size([ 436])
after: torch.Size([ 56])

三,einsum高级范例

einsum可用于超过两个张量的计算。

例9,bilinear注意力机制

例如:双线性变换。这是向量内积的一种扩展,一种常用的注意力机制实现方式
不考虑batch维度时,双线性变换的公式如下:
考虑batch维度时,无法用矩阵乘法表示,可以用元素计算公式表达如下:
  
  
    
#例9,bilinear注意力机制

#====不考虑batch维度====
q = torch.randn( 10#query_features
k = torch.randn( 10#key_features
W = torch.randn( 5, 10, 10#out_features,query_features,key_features
b = torch.randn( 5#out_features

#a = q@W@k.t()+b  
a = torch.bilinear(q,k,W,b)
print( "a.shape:",a.shape)


#=====考虑batch维度====
Q = torch.randn( 8, 10)     #batch_size,query_features
K = torch.randn( 8, 10)     #batch_size,key_features
W = torch.randn( 5, 10, 10#out_features,query_features,key_features
b = torch.randn( 5)        #out_features

#A = torch.bilinear(Q,K,W,b)
A = torch.einsum( 'bq,oqk,bk->bo',Q,W,K) + b
print( "A.shape:",A.shape)


a.shape: torch.Size([ 5])
A.shape: torch.Size([ 85])

例10,scaled-dot-product注意力机制

我们也可以用einsum来实现更常见的scaled-dot-product 形式的 Attention.
不考虑batch维度时, scaled-dot-product形式的Attention用矩阵乘法公式表示如下:
考虑batch维度时,无法用矩阵乘法表示,可以用元素计算公式表达如下:
  
  
    
#例10,scaled-dot-product注意力机制

#====不考虑batch维度====
q = torch.randn( 10)   #query_features
k = torch.randn( 6, 10#key_size, key_features

d_k = k.shape[ -1]
a = torch.softmax(q@k.t()/d_k, -1

print( "a.shape=",a.shape )

#====考虑batch维度====
Q = torch.randn( 8, 10)   #batch_size,query_features
K = torch.randn( 8, 6, 10#batch_size,key_size,key_features

d_k = K.shape[ -1]
A = torch.softmax(torch.einsum( "in,ijn->ij",Q,K)/d_k, -1

print( "A.shape=",A.shape )


a.shape= torch.Size([ 6])
A.shape= torch.Size([ 86])


公众号后台回复“ECCV2022”获取论文分类资源下载~

△点击卡片关注极市平台,获取 最新CV干货

极市干货
算法项目: CV工业项目落地实战 目标检测算法上新!(年均分成5万)
实操教程 Pytorch - 弹性训练原理分析《CUDA C 编程指南》导读
极视角动态: 极视角作为重点项目入选「2022青岛十大资本青睐企业」榜单! 极视角发布EQP激励计划,招募优质算法团队展开多维度生态合作! 极市AI校园大使招募


点击阅读原文进入CV社区

收获更多技术干货

登录查看更多
1

相关内容

基于Lua语言的深度学习框架 github.com/torch
【干货书】Python科学编程,451页pdf
专知会员服务
127+阅读 · 2021年6月27日
【知识图谱@EMNLP2020】Knowledge Graphs in NLP @ EMNLP 2020
专知会员服务
42+阅读 · 2020年11月22日
【经典书】算法C语言实现,Algorithms in C. 672页pdf
专知会员服务
81+阅读 · 2020年8月13日
【干货书】R语言书: 编程和统计的第一课程,
专知会员服务
111+阅读 · 2020年5月9日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
163+阅读 · 2019年10月28日
半精度(FP16)调试血泪总结
CVer
4+阅读 · 2022年5月31日
一文读懂 Pytorch 中的 Tensor View 机制
极市平台
0+阅读 · 2022年1月30日
谈谈自动微分(Automatic Differentiation)
PaperWeekly
1+阅读 · 2022年1月3日
Tensorrt踩坑日记 | python、pytorch 转 onnx 推理加速
极市平台
15+阅读 · 2021年12月24日
PyTorch 对类别张量进行 one-hot 编码
极市平台
0+阅读 · 2021年12月18日
嘿,同学!田厂对你很好奇!
微软招聘
0+阅读 · 2021年11月5日
YOLOv3:An Incremental Improvement 全文翻译
极市平台
12+阅读 · 2018年3月28日
一文读懂「Attention is All You Need」| 附代码实现
PaperWeekly
37+阅读 · 2018年1月10日
论文共读 | Attention is All You Need
黑龙江大学自然语言处理实验室
14+阅读 · 2017年9月7日
国家自然科学基金
0+阅读 · 2017年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2014年12月31日
国家自然科学基金
5+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
1+阅读 · 2013年12月31日
国家自然科学基金
3+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Arxiv
0+阅读 · 2022年10月2日
Arxiv
0+阅读 · 2022年9月30日
Arxiv
0+阅读 · 2022年9月30日
Arxiv
0+阅读 · 2022年9月28日
Arxiv
27+阅读 · 2017年12月6日
VIP会员
相关资讯
半精度(FP16)调试血泪总结
CVer
4+阅读 · 2022年5月31日
一文读懂 Pytorch 中的 Tensor View 机制
极市平台
0+阅读 · 2022年1月30日
谈谈自动微分(Automatic Differentiation)
PaperWeekly
1+阅读 · 2022年1月3日
Tensorrt踩坑日记 | python、pytorch 转 onnx 推理加速
极市平台
15+阅读 · 2021年12月24日
PyTorch 对类别张量进行 one-hot 编码
极市平台
0+阅读 · 2021年12月18日
嘿,同学!田厂对你很好奇!
微软招聘
0+阅读 · 2021年11月5日
YOLOv3:An Incremental Improvement 全文翻译
极市平台
12+阅读 · 2018年3月28日
一文读懂「Attention is All You Need」| 附代码实现
PaperWeekly
37+阅读 · 2018年1月10日
论文共读 | Attention is All You Need
黑龙江大学自然语言处理实验室
14+阅读 · 2017年9月7日
相关基金
国家自然科学基金
0+阅读 · 2017年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2014年12月31日
国家自然科学基金
5+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
1+阅读 · 2013年12月31日
国家自然科学基金
3+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
相关论文
Arxiv
0+阅读 · 2022年10月2日
Arxiv
0+阅读 · 2022年9月30日
Arxiv
0+阅读 · 2022年9月30日
Arxiv
0+阅读 · 2022年9月28日
Arxiv
27+阅读 · 2017年12月6日
Top
微信扫码咨询专知VIP会员