撩一发深度文本分类之RNN via Attention

2019 年 1 月 27 日 AINLP

本文系作者何从庆投稿,作者公众号:AI算法之心(id:AIHeartForYou),欢迎关注,点击文末"阅读原文"可直达原文链接,也欢迎大家投稿,AI、NLP相关即可。




[导读] 

文本分类是自然语言处理中一个很经典也很重要的问题,它的应用很广泛,在很多领域发挥着重要作用,例如情感分析、舆情分析以及新闻分类等。在前几篇文章中介绍了几种传统的文本分类方法,由于文本表征都是高度稀疏的,因此特征表达能力较差。此外,传统文本分类方法需要人工特征工程,这个过程比较耗时。随着深度学习的发展,基于CNN、RNN的方法可以有效的缓解词语特征表征能力弱序列捕捉能力差等问题。本文将介绍一种深度文本分类方法——RNN via Attention,该方法常常作为文本分类重要的baseline


注意:想快速学习RNN via Attention的代码,小伙伴们赶紧滑到最后,这里有你想要的代码哦



RNN via Attention结构

传统的文本分类方法,基本都是利用TFIDF提取词频以及词语间的N-gram信息作为特征,然后通过机器学习方法如逻辑回归、支持向量等作为分类器。前几篇介绍的TFIDF-LR、TFIDF-NBSVM都是传统文本分类方法。这些方法特征表达能力差,序列捕捉能力弱,很难深层次的表征文本信息

随着深度学习的发展,常用CNN、RNN等模型端到端的解决文本分类问题。本文介绍的RNN via Attention 是最经典的深度文本分类方法之一。下面我来以通俗易懂的方法一一道来该模型的优点。




RNN(s)

对于文本数据,最重要的是如何捕捉到上下文信息。RNN主要解决序列数据的处理,比如文本、语音、视频等等。简单的来说,RNN主要是通过上一时刻的信息以及当前时刻的输入,确定当前时刻的信息。因此,RNN可以捕捉到序列信息,这与捕捉文本的上下文信息相得益彰

传统的RNN也会存在许多问题,无法处理随着递归,权重指数级爆炸或消失的问题难以捕捉长期时间关联等等。基于RNN的变体,如LSTM和GRU很好的缓解这个问题。但是呢,LSTM和GRU这些网络,在长序列信息记忆方面效果也不是很好,Colin Raffel等人基于RNN以及RNN的变体,提出了一种适用于RNN(s)的简化注意力模型很好的解决序列长度的合成加法和乘法的长期记忆问题

在本文中,我们使用了一种RNN的变体——LSTM捕捉文本的上下文信息。更多地,人类在阅读文本时会考虑到前后上下文信息,在本文中,我们使用了双向的LSTM来捕捉前后上下文信息,充分的捕捉文本的前后信息。


Attention机制


基于Attention机制的论文有很多,比如transformer的self-attention、Hiearchical Attention、Feed-Forward Attention等。Attention的原理是什么呢?简单地说,Attention机制最初出现在图像领域,由于人在观察物体时会更注重重要的部分,所以机器也当和人一样,注意物体或者文本更重要的部分。

本文使用了一种Feed-Forward Attention (下文简称Attention机制)来对lstm捕捉的上下文信息使用注意力机制。一般来说,对于序列数据模型(RNN、GRU、LSTM),使用最大池或者平均池化来对提取的上下文信息进行操作,很容易丢失掉重要的信息,最大池化提取的不是我们想要的信息,平均池化弱化了我们想要的信息,种种原因,Attention机制成为了最优秀的池化操作之一




如上图,Attention机制首先将RNN(s)每个时刻的隐藏层输入到一个全连接层,然后产生一个概率向量(其实就是类似于softmax函数),然后用这个概率向量对每个隐藏层加权,最后相加得到最终的向量c。看起来可能很复杂,其实只要理解里面的原理,大家就会发现也就那么回事。大家如果对Attention机制的原理想深入研究的话,可以看一下参考里面的论文。


Attention机制代码实现


上面我也讲解了Attention机制的原理,俗话说,光说不练假把式。这里我也会提供一个keras版本的Attention代码,有兴趣的大家赶紧拿去试试吧。


class Attention(Layer):
    def __init__(self, step_dim,
                 W_regularizer=None, b_regularizer=None,
                 W_constraint=None, b_constraint=None,
                 bias=True, **kwargs)
:

        self.supports_masking = True
        self.init = initializers.get('glorot_uniform')
        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)
        self.bias = bias
        self.step_dim = step_dim
        self.features_dim = 0
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 3
        self.W = self.add_weight((input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        self.features_dim = input_shape[-1]
        if self.bias:
            self.b = self.add_weight((input_shape[1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None
        self.built = True

    def compute_mask(self, input, input_mask=None):
        return None

    def call(self, x, mask=None):
        features_dim = self.features_dim
        step_dim = self.step_dim

        eij = K.reshape(K.dot(K.reshape(x, (-1, features_dim)),
                        K.reshape(self.W, (features_dim, 1))), (-1, step_dim))

        if self.bias:
            eij += self.b
        eij = K.tanh(eij)
        a = K.exp(eij)
        if mask is not None:
            a *= K.cast(mask, K.floatx())
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        a = K.expand_dims(a)
        weighted_input = x * a
        return K.sum(weighted_input, axis=1)

    def compute_output_shape(self, input_shape):
        return input_shape[0],  self.features_dim


RNN via Attention 实战

上面我也介绍了Attention机制的原理和代码,是不是蓄势待发,想要练练手!这里我也为大家提供了一个案件,来实战下我们RNN via Attention的模型。

下面我们就可以实现我们的RNN via Attention 模型了,我们用该模型来解决kaggle上面的 Toxic Comment Classification Challenge

模型代码如下:

def RNN_Attention(maxlen):
    inp = Input(shape=(maxlen,))
    x = Embedding(max_features, embed_size, weights=[embedding_matrix])(inp)
    x = Bidirectional(LSTM(50, return_sequences=True, dropout=0.1, recurrent_dropout=0.1))(x)
    x = Attention(step_dim=100)(x)
    x = Dense(50, activation="relu")(x)
    x = Dropout(0.1)(x)
    x = Dense(6, activation="sigmoid")(x)
    model = Model(inputs=inp, outputs=x)
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.summary()
    return model


如果您需要有毒评论文本分类数据,可以关注"AI算法之心",后台回复 "ToxicComment"(建议复制)获取。


完整可运行代码可以在我的github找到,后期我会更新其他的深度文本分类算法代码:

https://github.com/hecongqing/TextClassification


参考资料:

[1]https://www.kaggle.com/takuok/bidirectional-lstm-and-attention-lb-0-043

[2]https://www.kaggle.com/jhoward/improved-lstm-baseline-glove-dropout

[3]Raffel C, Ellis D P W. Feed-forward networks with attention can solve some long-term memory problems[J]. arXiv preprint arXiv:1512.08756, 2015.


作者:何从庆,湖南大学计算机硕士,主要研究方向: 机器学习与法律智能。 

Github 主页:https://github.com/hecongqing

微信公众号: AI算法之心 , 欢迎关注:







登录查看更多
7

相关内容

RNN:循环神经网络,是深度学习的一种模型。
【ICML2020-西电】用于语言生成的递归层次主题引导RNN
专知会员服务
21+阅读 · 2020年6月30日
最新《深度多模态数据分析》综述论文,26页pdf
专知会员服务
298+阅读 · 2020年6月16日
基于多头注意力胶囊网络的文本分类模型
专知会员服务
76+阅读 · 2020年5月24日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
【MIT深度学习课程】深度序列建模,Deep Sequence Modeling
专知会员服务
77+阅读 · 2020年2月3日
NLP基础任务:文本分类近年发展汇总,68页超详细解析
专知会员服务
57+阅读 · 2020年1月3日
注意力机制介绍,Attention Mechanism
专知会员服务
168+阅读 · 2019年10月13日
深度学习的下一步:Transformer和注意力机制
云头条
56+阅读 · 2019年9月14日
误差反向传播——RNN
统计学习与视觉计算组
18+阅读 · 2018年9月6日
关于序列建模,是时候抛弃RNN和LSTM了
数盟
7+阅读 · 2018年4月20日
深度学习在文本分类中的应用
AI研习社
13+阅读 · 2018年1月7日
RNN在自然语言处理中的应用及其PyTorch实现
机器学习研究会
4+阅读 · 2017年12月3日
RNN在自然语言处理中的应用及其PyTorch实现 | 赠书
人工智能头条
6+阅读 · 2017年11月28日
完全图解RNN、RNN变体、Seq2Seq、Attention机制
AI研习社
12+阅读 · 2017年9月5日
干货|完全图解RNN、RNN变体、Seq2Seq、Attention机制
机器学习研究会
11+阅读 · 2017年8月5日
RNN | RNN实践指南(2)
KingsGarden
19+阅读 · 2017年5月4日
RNN | RNN实践指南(1)
KingsGarden
21+阅读 · 2017年4月4日
Arxiv
21+阅读 · 2019年8月21日
Arxiv
8+阅读 · 2018年1月25日
VIP会员
相关VIP内容
【ICML2020-西电】用于语言生成的递归层次主题引导RNN
专知会员服务
21+阅读 · 2020年6月30日
最新《深度多模态数据分析》综述论文,26页pdf
专知会员服务
298+阅读 · 2020年6月16日
基于多头注意力胶囊网络的文本分类模型
专知会员服务
76+阅读 · 2020年5月24日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
【MIT深度学习课程】深度序列建模,Deep Sequence Modeling
专知会员服务
77+阅读 · 2020年2月3日
NLP基础任务:文本分类近年发展汇总,68页超详细解析
专知会员服务
57+阅读 · 2020年1月3日
注意力机制介绍,Attention Mechanism
专知会员服务
168+阅读 · 2019年10月13日
相关资讯
深度学习的下一步:Transformer和注意力机制
云头条
56+阅读 · 2019年9月14日
误差反向传播——RNN
统计学习与视觉计算组
18+阅读 · 2018年9月6日
关于序列建模,是时候抛弃RNN和LSTM了
数盟
7+阅读 · 2018年4月20日
深度学习在文本分类中的应用
AI研习社
13+阅读 · 2018年1月7日
RNN在自然语言处理中的应用及其PyTorch实现
机器学习研究会
4+阅读 · 2017年12月3日
RNN在自然语言处理中的应用及其PyTorch实现 | 赠书
人工智能头条
6+阅读 · 2017年11月28日
完全图解RNN、RNN变体、Seq2Seq、Attention机制
AI研习社
12+阅读 · 2017年9月5日
干货|完全图解RNN、RNN变体、Seq2Seq、Attention机制
机器学习研究会
11+阅读 · 2017年8月5日
RNN | RNN实践指南(2)
KingsGarden
19+阅读 · 2017年5月4日
RNN | RNN实践指南(1)
KingsGarden
21+阅读 · 2017年4月4日
Top
微信扫码咨询专知VIP会员