基于Tensorflow Estimators的文本分类系列之二

2018 年 4 月 19 日 全球人工智能

高薪招聘兼职AI讲师和AI助教!

来源:AI前线

作者:徐鹏


输入函数

Estimator框架使用输入函数从模型本身分割数据管道。无论您的数据是在.csv文件中,还是在pandas.DataFrame中,无论它是否进入内存,都可以使用多种帮助器方法来创建它们。在我们的例子中,我们可以为相关数据集和测试集使用Dataset.from_tensor_slices。

x_len_train = np.array([min(len(x),sentence_size)for x in x_train_variable])

x \ _len \ _test = np.array([min(len(x),sentence_size)for x in x_test_variable])

def parser(x,length,y):

   特征= {“x”:x,“len”:长度}

   返回特征,y

def train_input_fn():

   dataset = tf.data.Dataset.from_tensor_slices(\

        (x_train,x_len_train,y_train))

   dataset = dataset.shuffle(buffer_size = len(x_train_variable))

   dataset = dataset.batch(100)dataset = dataset.map(解析器)

   dataset = dataset.repeat()

   iterator = dataset.make_one_shot_iterator()

   返回iterator.get_next()

def eval_input_fn():

   dataset = tf.data.Dataset.from_tensor_slices(

        (x_test,x_len_test,y_test))

   dataset = dataset.batch(100)dataset = dataset.map(解析器)

   iterator = dataset.make_one_shot_iterator() 

   返回iterator.get_next()

我们对训练数据进行混洗,并没有预先定义我们想要训练的历元的数量,而我们只需要一个测试数据的信号出现时间进行评估。我们还添加了一个额外的“len”的键,该键可以捕获原始未加垫片序列的长度,我们将在稍后使用。

基线建立

对于机器学习来说,一在开始就尝试基本基线的项目的英文一个很好的做法。越简单越好,因为拥有简单而强大的基线是通过增加额外复杂性来准确理解我们在性能方面获得多少的关键。很简单的解决方案可以满足我们的要求

考虑到这一点,让我们先尝试一下最简单的文本分类模型。这将是一个稀疏的线性模型,它为每个令牌赋予权重,并将所有结果相加,而不管顺序如何。由于这个模型并不关心句子中单词的顺序,所以我们通常将其称为一袋词方法。我们来看看如何使用估计来实现这个模型。

我们首先定义用作分类器输入的特征列。正如我们在第2部分中看到的,categorical_column_with_identity是此预处理文本输入的正确选择。如果我们提供原始文本标记,其他feature_columns可以为我们做很多预处理。我们现在可以使用预制的LinearClassifier。

column = tf.feature_column.categorical_column_with_identity('x',vocab_size)

classifier = tf.estimator.LinearClassifier(feature_columns = [column],model_dir = os.path.join(model_dir,'bow_sparse'))

最后,我们创建一个简单的函数来训练分类器,并创建一个精确回忆曲线。由于我们的目的并不是为了最大限度地发挥其性能,我们只能训练我们的模型25,000步。

def train_and_evaluate(分类器): 

    classifier.train(input_fn = train_input_fn,steps = 25000) 

    eval_results = classifier.evaluate(input_fn = eval_input_fn) 

    predictions = np.array([p ['logistic'] [0] \

        for class in classifier.predict(input_fn = eval_input_fn)]) 

    tf.reset_default_graph() 

    #除了分类器写入的摘要之外,还要添加一个PR摘要

    pr = summary_lib.pr_curve(

            'precision_recall', 

             预测=预测,

             labels = y_test.astype(bool),num_thresholds = 21) 

    用tf.Session()作为sess: 

        writer = tf.summary.FileWriter(os.path.join(

                classifier.model_dir,'eval'),sess.graph) 

        writer.add_summary(sess.run(pr),global_step = 0) 

        writer.close() 

列车\ _and \ _evaluate(分类器)

选择简单模型的好处之一是它更易于解释。模型越复杂,检查就越困难,而且越容易像黑匣子一样工作。在这个例子中,我们可以从我们模型的上一个检查点加载权重,并查看哪些令牌与绝对值最大的权重对应。结果看起来像我们所期望的。

#用张量加载张量 

weights = classifier.get_variable_value(

            '线性/ linear_model / X /权重')。弄平() 

#找到绝对值最大的权重 

极端= np.concatenate(

        (sorted_indexes [-8:],sorted_indexes [:8])) 

#word_inverted_index是一个从索引映射回标记的字典 

extreme_weights =排序( 

        [(权重[i],word_inverted_index [i - index_offset])\

        因为我在极端]])

#创建绘图 

y_pos = np.arange(len(extreme_weights)) 

plt.bar(y_pos,[extreme_weights中的pair [pair [0]]), 

    align ='center',alpha = 0.5)plt.xticks(y_pos,[pair [1] 

    对extreme_weights],旋转= 45,哈='右') 

plt.ylabel( '体重') 

plt.title('最重要的令牌') 

plt.show()

一些重要的标志

正如我们所看到的,“清爽”这种最积极的标记显然与积极的情绪相关,而具有较大负面代价的标记却无法引起负面情绪。一个简单而强大的修改,可以做的改进这个模型是通过他们的TF-IDF分数来加权记号。

原文链接:http //ruder.io/text-classification-tensorflow-estimators/

- 加入人工智能学院系统学习 -

点击“ 阅读原文 ”查看详情

登录查看更多
1

相关内容

数据集,又称为资料集、数据集合或资料集合,是一种由数据所组成的集合。
Data set(或dataset)是一个数据的集合,通常以表格形式出现。每一列代表一个特定变量。每一行都对应于某一成员的数据集的问题。它列出的价值观为每一个变量,如身高和体重的一个物体或价值的随机数。每个数值被称为数据资料。对应于行数,该数据集的数据可能包括一个或多个成员。
Sklearn 与 TensorFlow 机器学习实用指南,385页pdf
专知会员服务
129+阅读 · 2020年3月15日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
TensorFlow Lite指南实战《TensorFlow Lite A primer》,附48页PPT
专知会员服务
69+阅读 · 2020年1月17日
必读的7篇IJCAI 2019【图神经网络(GNN)】相关论文-Part2
专知会员服务
60+阅读 · 2020年1月10日
TF Boys必看!一文搞懂TensorFlow 2.0新架构!
引力空间站
18+阅读 · 2019年1月16日
基于 Keras 用深度学习预测时间序列
R语言中文社区
23+阅读 · 2018年7月27日
教程 | 用TensorFlow Estimator实现文本分类
机器之心
4+阅读 · 2018年5月17日
【干货】基于Keras的注意力机制实战
专知
59+阅读 · 2018年5月4日
业界|基于Tensorflow Estimators的文本分类系列之一
全球人工智能
3+阅读 · 2018年4月19日
keras系列︱深度学习五款常用的已训练模型
数据挖掘入门与实战
10+阅读 · 2018年3月27日
发布TensorFlow 1.4
谷歌开发者
7+阅读 · 2017年11月23日
Tensorflow 文本分类-Python深度学习
Python程序员
12+阅读 · 2017年11月22日
深度学习实战(二)——基于Keras 的深度学习
乐享数据DataScientists
15+阅读 · 2017年7月13日
Factor Graph Attention
Arxiv
6+阅读 · 2019年4月11日
A Probe into Understanding GAN and VAE models
Arxiv
9+阅读 · 2018年12月13日
Arxiv
5+阅读 · 2018年4月13日
Arxiv
4+阅读 · 2016年12月29日
VIP会员
相关资讯
TF Boys必看!一文搞懂TensorFlow 2.0新架构!
引力空间站
18+阅读 · 2019年1月16日
基于 Keras 用深度学习预测时间序列
R语言中文社区
23+阅读 · 2018年7月27日
教程 | 用TensorFlow Estimator实现文本分类
机器之心
4+阅读 · 2018年5月17日
【干货】基于Keras的注意力机制实战
专知
59+阅读 · 2018年5月4日
业界|基于Tensorflow Estimators的文本分类系列之一
全球人工智能
3+阅读 · 2018年4月19日
keras系列︱深度学习五款常用的已训练模型
数据挖掘入门与实战
10+阅读 · 2018年3月27日
发布TensorFlow 1.4
谷歌开发者
7+阅读 · 2017年11月23日
Tensorflow 文本分类-Python深度学习
Python程序员
12+阅读 · 2017年11月22日
深度学习实战(二)——基于Keras 的深度学习
乐享数据DataScientists
15+阅读 · 2017年7月13日
Top
微信扫码咨询专知VIP会员