【干货】基于Keras的注意力机制实战

【导读】近几年,注意力机制(Attention)大量地出现在自动翻译、信息检索等模型中。可以把Attention看成模型中的一个特征选择组件,特征选择一方面可以增强模型的效果,另一方面,我们可以通过计算出的特征的权重来计算结果与特征之间的某种关联。例如在自动翻译模型中,Attention可以计算出不同语种词之间的关系。本文一个简单的例子,来展示Attention是怎么在模型中起到特征选择作用的。


代码




导入相关库

#coding=utf-8
import numpy as np
from keras.models import *
from keras.layers import Input, Dense, merge
import matplotlib.pyplot as plt
import pandas as pd


数据生成函数

# 输入维度
input_dim = 32


# 生成数据,数据的的第attention_column个特征由label决定,
# 即
label只与数据的第attention_column个特征相关
def get_data(n, input_dim, attention_column=1):
   x = np.random.standard_normal(size=(n, input_dim))
   y = np.random.randint(low=0, high=2, size=(n, 1))
   x[:, attention_column] = y[:, 0]
   return x, y


模型定义函数

将输入进行一次变换后,计算出Attention权重,将输入乘上Attention权重,获得新的特征。


# Attention模型
def build_model():
   inputs = Input(shape=(input_dim,))

   # 计算Attention权重
   
attention_probs = Dense(input_dim, activation='softmax',
name='attention_vec')(inputs)
   # 根据Attention权重更新特征
   
attention_mul = merge([inputs, attention_probs],
output_shape=32,
name='attention_mul', mode='mul')

   # 预测标签
   
attention_mul = Dense(64)(attention_mul)
   output = Dense(1, activation='sigmoid')(attention_mul)
   model = Model(input=[inputs], output=output)
   attention_vec_model = Model(input=[inputs],
output=attention_probs)
   return model, attention_vec_model


主函数

if __name__ == '__main__':
   # 生成训练数据
   
N = 10000
   
inputs_1, outputs = get_data(N, input_dim)

   # 获取模型,以及用于计算Attention权重的子模型
   
m, attention_vec_model = build_model()
   m.compile(optimizer='adam', loss='binary_crossentropy',
metrics=['accuracy'])
   print(m.summary())

   # 训练
   
m.fit([inputs_1], outputs, epochs=20, batch_size=64,
validation_split=0.5)

   # 生成测试数据
   
testing_inputs_1, testing_outputs = get_data(1, input_dim)

   # 根据测试数据计算Attention权重
   
attention_vector = attention_vec_model.
   predict([testing_inputs_1])[0].flatten()
   print('attention =', attention_vector)

   # 绘图
pd.DataFrame(attention_vector, columns=['attention (%)'])
.plot(kind='bar', title='Attention Mechanism as a function of
input dimensions.'
)
    plt.show()


运行结果

代码中,attention_column为1,也就是说,label只与数据的第1个特征相关。从运行结果中可以看出,Attention权重成功地获取了这个信息。


参考链接

https://github.com/philipperemy/keras-attention-mechanism


更多教程资料请访问:人工智能知识资料全集

-END-

专 · 知

人工智能领域主题知识资料查看与加入专知人工智能服务群

【专知AI服务计划】专知AI知识技术服务会员群加入人工智能领域26个主题知识资料全集获取

[点击上面图片加入会员]

请PC登录www.zhuanzhi.ai或者点击阅读原文,注册登录专知,获取更多AI知识资料

请加专知小助手微信(扫一扫如下二维码添加),加入专知主题群(请备注主题类型:AI、NLP、CV、 KG等)交流~

关注专知公众号,获取人工智能的专业知识!

点击“阅读原文”,使用专知

展开全文
Top
微信扫码咨询专知VIP会员