本文系作者何从庆投稿,作者公众号:AI算法之心(id:AIHeartForYou),欢迎关注,点击文末"阅读原文"可直达原文链接,也欢迎大家投稿,AI、NLP相关即可。
[导读]
文本分类是自然语言处理中一个很经典也很重要的问题,它的应用很广泛,在很多领域发挥着重要作用,例如情感分析、舆情分析以及新闻分类等。在前几篇文章中介绍了几种传统的文本分类方法,由于文本表征都是高度稀疏的,因此特征表达能力较差。此外,传统文本分类方法需要人工特征工程,这个过程比较耗时。随着深度学习的发展,基于CNN、RNN的方法可以有效的缓解词语特征表征能力弱,序列捕捉能力差等问题。本文将介绍一种深度文本分类方法——RNN via Attention,该方法常常作为文本分类重要的baseline。
注意:想快速学习RNN via Attention的代码,小伙伴们赶紧滑到最后,这里有你想要的代码哦
RNN via Attention结构
传统的文本分类方法,基本都是利用TFIDF提取词频以及词语间的N-gram信息作为特征,然后通过机器学习方法如逻辑回归、支持向量等作为分类器。前几篇介绍的TFIDF-LR、TFIDF-NBSVM都是传统文本分类方法。这些方法特征表达能力差,序列捕捉能力弱,很难深层次的表征文本信息。
随着深度学习的发展,常用CNN、RNN等模型端到端的解决文本分类问题。本文介绍的RNN via Attention 是最经典的深度文本分类方法之一。下面我来以通俗易懂的方法一一道来该模型的优点。
RNN(s)
对于文本数据,最重要的是如何捕捉到上下文信息。RNN主要解决序列数据的处理,比如文本、语音、视频等等。简单的来说,RNN主要是通过上一时刻的信息以及当前时刻的输入,确定当前时刻的信息。因此,RNN可以捕捉到序列信息,这与捕捉文本的上下文信息相得益彰。
传统的RNN也会存在许多问题,无法处理随着递归,权重指数级爆炸或消失的问题,难以捕捉长期时间关联等等。基于RNN的变体,如LSTM和GRU很好的缓解这个问题。但是呢,LSTM和GRU这些网络,在长序列信息记忆方面效果也不是很好,Colin Raffel等人基于RNN以及RNN的变体,提出了一种适用于RNN(s)的简化注意力模型,很好的解决序列长度的合成加法和乘法的长期记忆问题。
在本文中,我们使用了一种RNN的变体——LSTM捕捉文本的上下文信息。更多地,人类在阅读文本时会考虑到前后上下文信息,在本文中,我们使用了双向的LSTM来捕捉前后上下文信息,充分的捕捉文本的前后信息。
Attention机制
基于Attention机制的论文有很多,比如transformer的self-attention、Hiearchical Attention、Feed-Forward Attention等。Attention的原理是什么呢?简单地说,Attention机制最初出现在图像领域,由于人在观察物体时会更注重重要的部分,所以机器也当和人一样,注意物体或者文本更重要的部分。
本文使用了一种Feed-Forward Attention (下文简称Attention机制)来对lstm捕捉的上下文信息使用注意力机制。一般来说,对于序列数据模型(RNN、GRU、LSTM),使用最大池或者平均池化来对提取的上下文信息进行操作,很容易丢失掉重要的信息,最大池化提取的不是我们想要的信息,平均池化弱化了我们想要的信息,种种原因,Attention机制成为了最优秀的池化操作之一。
如上图,Attention机制首先将RNN(s)每个时刻的隐藏层输入到一个全连接层,然后产生一个概率向量(其实就是类似于softmax函数),然后用这个概率向量对每个隐藏层加权,最后相加得到最终的向量c。看起来可能很复杂,其实只要理解里面的原理,大家就会发现也就那么回事。大家如果对Attention机制的原理想深入研究的话,可以看一下参考里面的论文。
Attention机制代码实现
上面我也讲解了Attention机制的原理,俗话说,光说不练假把式。这里我也会提供一个keras版本的Attention代码,有兴趣的大家赶紧拿去试试吧。
class Attention(Layer):
def __init__(self, step_dim,
W_regularizer=None, b_regularizer=None,
W_constraint=None, b_constraint=None,
bias=True, **kwargs):
self.supports_masking = True
self.init = initializers.get('glorot_uniform')
self.W_regularizer = regularizers.get(W_regularizer)
self.b_regularizer = regularizers.get(b_regularizer)
self.W_constraint = constraints.get(W_constraint)
self.b_constraint = constraints.get(b_constraint)
self.bias = bias
self.step_dim = step_dim
self.features_dim = 0
super(Attention, self).__init__(**kwargs)
def build(self, input_shape):
assert len(input_shape) == 3
self.W = self.add_weight((input_shape[-1],),
initializer=self.init,
name='{}_W'.format(self.name),
regularizer=self.W_regularizer,
constraint=self.W_constraint)
self.features_dim = input_shape[-1]
if self.bias:
self.b = self.add_weight((input_shape[1],),
initializer='zero',
name='{}_b'.format(self.name),
regularizer=self.b_regularizer,
constraint=self.b_constraint)
else:
self.b = None
self.built = True
def compute_mask(self, input, input_mask=None):
return None
def call(self, x, mask=None):
features_dim = self.features_dim
step_dim = self.step_dim
eij = K.reshape(K.dot(K.reshape(x, (-1, features_dim)),
K.reshape(self.W, (features_dim, 1))), (-1, step_dim))
if self.bias:
eij += self.b
eij = K.tanh(eij)
a = K.exp(eij)
if mask is not None:
a *= K.cast(mask, K.floatx())
a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())
a = K.expand_dims(a)
weighted_input = x * a
return K.sum(weighted_input, axis=1)
def compute_output_shape(self, input_shape):
return input_shape[0], self.features_dim
RNN via Attention 实战
上面我也介绍了Attention机制的原理和代码,是不是蓄势待发,想要练练手!这里我也为大家提供了一个案件,来实战下我们RNN via Attention的模型。
下面我们就可以实现我们的RNN via Attention 模型了,我们用该模型来解决kaggle上面的 Toxic Comment Classification Challenge。
模型代码如下:
def RNN_Attention(maxlen):
inp = Input(shape=(maxlen,))
x = Embedding(max_features, embed_size, weights=[embedding_matrix])(inp)
x = Bidirectional(LSTM(50, return_sequences=True, dropout=0.1, recurrent_dropout=0.1))(x)
x = Attention(step_dim=100)(x)
x = Dense(50, activation="relu")(x)
x = Dropout(0.1)(x)
x = Dense(6, activation="sigmoid")(x)
model = Model(inputs=inp, outputs=x)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
return model
如果您需要有毒评论文本分类数据,可以关注"AI算法之心",后台回复 "ToxicComment"(建议复制)获取。
完整可运行代码可以在我的github找到,后期我会更新其他的深度文本分类算法代码:
https://github.com/hecongqing/TextClassification
参考资料:
[1]https://www.kaggle.com/takuok/bidirectional-lstm-and-attention-lb-0-043
[2]https://www.kaggle.com/jhoward/improved-lstm-baseline-glove-dropout
[3]Raffel C, Ellis D P W. Feed-forward networks with attention can solve some long-term memory problems[J]. arXiv preprint arXiv:1512.08756, 2015.