Self-Attention GAN 中的 self-attention 机制

2019 年 3 月 6 日 PaperWeekly


作者丨尹相楠

学校丨里昂中央理工博士在读

研究方向丨人脸识别、对抗生成网络




Self Attention GAN 用到了很多新的技术。最大的亮点当然是 self-attention 机制,该机制是 Non-local Neural Networks [1] 这篇文章提出的。其作用是能够更好地学习到全局特征之间的依赖关系。因为传统的 GAN 模型很容易学习到纹理特征:如皮毛,天空,草地等,不容易学习到特定的结构和几何特征,例如狗有四条腿,既不能多也不能少。 


除此之外,文章还用到了 Spectral Normalization for GANs [2] 提出的谱归一化。谱归一化的解释见本人这篇文章:详解GAN的谱归一化(Spectral Normalization)


但是,该文代码中的谱归一化和原始的谱归一化运用方式略有差别: 


1. 原始的谱归一化基于 W-GAN 的理论,只用在 Discriminator 中,用以约束 Discriminator 函数为 1-Lipschitz 连续。而在 Self-Attention GAN 中,Spectral Normalization 同时出现在了 Discriminator 和 Generator 中,用于使梯度更稳定。除了生成器和判别器的最后一层外,每个卷积/反卷积单元都会上一个 SpectralNorm。 


2. 当把谱归一化用在 Generator 上时,同时还保留了 BatchNorm。Discriminator 上则没有 BatchNorm,只有 SpectralNorm。 


3. 谱归一化用在 Discriminator 上时最后一层不加 Spectral Norm。 


最后,self-attention GAN 还用到了 cGANs With Projection Discriminator 提出的 conditional normalizationprojection in the discriminator。这两个技术我还没有来得及看,而且 PyTorch 版本的 self-attention GAN 代码中也没有实现,就先不管它们了。


本文主要说的是 self-attention 这部分内容。


 图1. Self-Attention


Self-Attention


在卷积神经网络中,每个卷积核的尺寸都是很有限的(基本上不会大于 5),因此每次卷积操作只能覆盖像素点周围很小一块邻域。


对于距离较远的特征,例如狗有四条腿这类特征,就不容易捕获到了(也不是完全捕获不到,因为多层的卷积、池化操作会把 feature map 的高和宽变得越来越小,越靠后的层,其卷积核覆盖的区域映射回原图对应的面积越大。但总而言之,毕竟还得需要经过多层映射,不够直接)。


Self-Attention 通过直接计算图像中任意两个像素点之间的关系,一步到位地获取图像的全局几何特征。 


论文中的公式不够直观,我们直接看文章的 PyTorch 的代码,核心部分为 sagan_models.py:


class Self_Attn(nn.Module):
    """ Self attention Layer"""
    def __init__(self,in_dim,activation):
        super(Self_Attn,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax  = nn.Softmax(dim=-1#
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """

        m_batchsize,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1# B X CX(N)
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        attention = self.softmax(energy) # BX (N) X (N) 
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N

        out = torch.bmm(proj_value,attention.permute(0,2,1) )
        out = out.view(m_batchsize,C,width,height)

        out = self.gamma*out + x
        return out,attention


构造函数中定义了三个 1 × 1 的卷积核,分别被命名为 query_conv key_conv value_conv


为啥命名为这三个名字呢?这和作者给它们赋予的含义有关。query 意为查询,我们希望输入一个像素点,查询(计算)到 feature map 上所有像素点对这一点的影响。而 key 代表字典中的键,相当于所查询的数据库。query 和 key 都是输入的 feature map,可以看成把 feature map 复制了两份,一份作为 query 一份作为 key。 


需要用一个什么样的函数,才能针对 query 的 feature map 中的某一个位置,计算出 key 的 feature map 中所有位置对它的影响呢?作者认为这个函数应该是可以通过“学习”得到的。那么,自然而然就想到要对这两个 feature map 分别做卷积核为 1 × 1 的卷积了,因为卷积核的权重是可以学习得到的。 


至于 value_conv ,可以看成对原 feature map 多加了一层卷积映射,这样可以学习到的参数就更多了,否则 query_conv key_conv 的参数太少,按代码中只有 in_dims × in_dims//8 个。 


接下来逐行研究 forward 函数:


proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1)


这行代码先对输入的 feature map 卷积了一次,相当于对 query feature map 做了一次投影,所以叫做 proj_query。由于是 1 × 1 的卷积,所以不改变 feature map 的长和宽。feature map 的每个通道为如 (1) 所示的矩阵,矩阵共有 N 个元素(像素)。



然后重新改变了输出的维度,变成:


 (m_batchsize,-1,width*height) 


batch size 保持不变,width 和 height 融合到一起,把如 (1) 所示二维的 feature map 每个 channel 拉成一个长度为 N 的向量。


因此,如果 m_batchsize 取 1,即单独观察一个样本,该操作的结果是得到一个矩阵,矩阵的的行数为 query_conv 卷积输出的 channel 的数目 C( in_dim//8 ),列数为 feature map 像素数 N。


然后作者又通过 .permute(0, 2, 1) 转置了矩阵,矩阵的行数变成了 feature map 的像素数 N,列数变成了通道数 C。因此矩阵维度为 N × C 。该矩阵每行代表一个像素位置上所有通道的值,每列代表某个通道中所有的像素值。


 图2. proj_query 的维度


proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height)


这行代码和上一行类似,只不过取消了转置操作。得到的矩阵行数为通道数 C,列数为像素数 N,即矩阵维度为 C × N。该矩阵每行代表一个通道中所有的像素值,每列代表一个像素位置上所有通道的值。


 图3. proj_key的维度


energy =  torch.bmm(proj_query,proj_key)


这行代码中, torch.bmm 的意思是 batch matrix multiplication。就是说把相同 batch size 的两组 matrix 一一对应地做矩阵乘法,最后得到同样 batchsize 的新矩阵。


若 batch size=1,就是普通的矩阵乘法。已知 proj_query 维度是 N × C, proj_key 的维度是 C × N,因此 energy 的维度是 N × N:


 图4. energy的维度


energy 是 attention 的核心,其中第 i 行 j 列的元素,是由 proj_query 第 i 行,和 proj_key 第 j 列通过向量点乘得到的。而 proj_query 第 i 行表示的是 feature map 上第 i 个像素位置上所有通道的值,也就是第 i 个像素位置的所有信息,而 proj_key 第 j 列表示的是 feature map 上第 j 个像素位置上的所有通道值,也就是第 j 个像素位置的所有信息。


这俩相乘,可以看成是第 j 个像素对第 i 个像素的影响。即,energy 中第 i 行 j 列的元素值,表示第 j 个像素点对第 i 个像素点的影响。


attention = self.softmax(energy)


这里 sofmax 是构造函数中定义的,为按“行”归一化。这个操作之后的矩阵,各行元素之和为 1。这也比较好理解,因为 energy 中第 i 行元素,代表 feature map 中所有位置的像素对第 i 个像素的影响,而这个影响被解释为权重,故加起来应该是 1,故应对其按行归一化。attention 的维度也是 N × N。


proj_value = self.value_conv(x).view(m_batchsize,-1,width*height)


上面的代码中,先对原 feature map 作一次卷积映射,然后把得到的新 feature map 改变形状,维度变为 C × N ,其中 C 为通道数(注意和上面计算 proj_query   proj_key 的 C 不同,上面的 C 为 feature map 通道数的 1/8,这里的 C 与 feature map 通道数相同),N 为 feature map 的像素数。


 图5. proj_value的维度


out = torch.bmm(proj_value,attention.permute(0,2,1) )
out = out.view(m_batchsize,C,width,height)


然后,再把 proj_value (C × N)矩阵同  attention 矩阵的转置(N × N)相乘,得到 out (C × N)。之所以转置,是因为 attention 中每行的和为 1,其意义是权重,需要转置后变为每列的和为 1,施加于 proj_value 的行上,作为该行的加权平均。 proj_value 第 i 行代表第 i 个通道所有的像素值, attention 第 j 列,代表所有像素施加到第 j 个像素的影响。


因此, out 中第 i 行包含了输出的第 i 个通道中的所有像素,第 j 列表示所有像素中的第 j 个像素,合起来也就是: out 中的第 i 行第 j 列的元素,表示被 attention 加权之后的 feature map 的第 i 个通道的第 j 个像素的像素值。再改变一下形状, out 就恢复了 channel×width×height 的结构。


 图6. out的维度


out = self.gamma*out + x


最后一行代码,借鉴了残差神经网络(residual neural networks)的操作, gamma 是一个参数,表示整体施加了 attention 之后的 feature map 的权重,需要通过反向传播更新。而 x 就是输入的 feature map。


在初始阶段, gamma 为 0,该 attention 模块直接返回输入的 feature map,之后随着学习,该 attention 模块逐渐学习到了将 attention 加权过的 feature map 加在原始的 feature map 上,从而强调了需要施加注意力的部分 feature map。


总结


可以把 self attention 看成是 feature map 和它自身的转置相乘,让任意两个位置的像素直接发生关系,这样就可以学习到任意两个像素之间的依赖关系,从而得到全局特征了。看论文时会被它复杂的符号迷惑,但是一看代码就发现其实是很 naive 的操作。


参考文献


[1] Xiaolong Wang, Ross Girshick, Abhinav Gupta, Kaiming He, Non-local Neural Networks, CVPR 2018.

[2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida, Spectral Normalization for Generative Adversarial Networks, ICLR 2018.




点击以下标题查看更多往期内容: 





#投 稿 通 道#

 让你的论文被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢? 答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。


📝 来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志


📬 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



关于PaperWeekly


PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。


▽ 点击 | 阅读原文 | 获取最新论文推荐

登录查看更多
12

相关内容

最新《生成式对抗网络》简介,25页ppt
专知会员服务
168+阅读 · 2020年6月28日
【ICML2020】持续图神经网络,Continuous Graph Neural Networks
专知会员服务
146+阅读 · 2020年6月28日
专知会员服务
107+阅读 · 2020年5月21日
【CVPR2020】MSG-GAN:用于稳定图像合成的多尺度梯度GAN
专知会员服务
26+阅读 · 2020年4月6日
近期必读的5篇AI顶会CVPR 2020 GNN (图神经网络) 相关论文
专知会员服务
78+阅读 · 2020年3月3日
必读的10篇 CVPR 2019【生成对抗网络】相关论文和代码
专知会员服务
31+阅读 · 2020年1月10日
生成对抗网络也需要注意力机制
机器之心
5+阅读 · 2019年3月15日
详解GAN的谱归一化(Spectral Normalization)
PaperWeekly
11+阅读 · 2019年2月13日
旧照片着色修复神器!自注意力GAN效果惊艳
GAN生成式对抗网络
7+阅读 · 2018年11月3日
自适应注意力机制在Image Caption中的应用
PaperWeekly
10+阅读 · 2018年5月10日
干货 | NLP中的self-attention【自-注意力】机制
机器学习算法与Python学习
12+阅读 · 2018年4月11日
TensorFlow seq2seq中的Attention机制(续)
深度学习每日摘要
15+阅读 · 2017年11月16日
GAN猫的脸
机械鸡
11+阅读 · 2017年7月8日
Arxiv
4+阅读 · 2019年12月2日
Self-Attention Graph Pooling
Arxiv
5+阅读 · 2019年4月17日
Arxiv
6+阅读 · 2019年4月8日
Arxiv
8+阅读 · 2018年5月21日
VIP会员
相关VIP内容
相关资讯
生成对抗网络也需要注意力机制
机器之心
5+阅读 · 2019年3月15日
详解GAN的谱归一化(Spectral Normalization)
PaperWeekly
11+阅读 · 2019年2月13日
旧照片着色修复神器!自注意力GAN效果惊艳
GAN生成式对抗网络
7+阅读 · 2018年11月3日
自适应注意力机制在Image Caption中的应用
PaperWeekly
10+阅读 · 2018年5月10日
干货 | NLP中的self-attention【自-注意力】机制
机器学习算法与Python学习
12+阅读 · 2018年4月11日
TensorFlow seq2seq中的Attention机制(续)
深度学习每日摘要
15+阅读 · 2017年11月16日
GAN猫的脸
机械鸡
11+阅读 · 2017年7月8日
Top
微信扫码咨询专知VIP会员