【导读】本文翻译自TensorFlow官网新出的教程《Text generation using a RNN with eager execution》,该教程介绍如何使用TensorFlow Eager(动态图)和RNN来学习生成莎士比亚的作品。模型可以根据已有的字符序列来预测序列的下一个字符,以达到文本生成的效果。
简介
教程包含了用Tensorflow Eager(动态图)和tf.keras实现的可执行代码,下面是代码运行的示例结果:
QUEENE:
I had thought thou hadst a Roman; for the oracle,
Thus by All bids the man against the word,
Which are so weak of care, by old care done;
Your children were in your holy love,
And the precipitation through the bleeding throne.
BISHOP OF ELY:
Marry, and will, my lord, to weep in such a one were prettiest;
Yet now I was adopted heir
Of the world's lamentable day,
To watch the next way with his father with his face?
ESCALUS:
The cause why then we are all resolved more sons.
VOLUMNIA:
O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead,
And love and pale as any will to that word.
QUEEN ELIZABETH:
But how long have I heard the soul for this world,
And show his hands of life be proved to stand.
PETRUCHIO:
I say he look'd on, if I must be content
To stay him from the fatal of our country's bliss.
His lordship pluck'd from this sentence then for prey,
And then let us twain, being the moon,
were she such a case as fills m
虽然生成的句子中有一部分看起来比较符合语法,大多数生成的句子还是没有什么意义的。这个模型并没有考虑词的意义,而是考虑:
该模型是基于字符的,模型并不知道如何用字符拼写单词,甚至不知道词是文本的组成单元。
文本的结构很像戏剧,和训练集类似,文本块往往以说话者的大写名字开始。
模型中序列长度为100的序列上进行训练,但它有能力去生成更长的序列。
简介
导入相关库:
import tensorflow as tf
tf.enable_eager_execution()
import numpy as np
import os
import time
下载数据集:
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
读取数据:
text = open(path_to_file).read()
print ('Length of text: {} characters'.format(len(text)))
查看数据:
print(text[:1000])
First Citizen:
Before we proceed any further, hear me speak.
All:
Speak, speak.
First Citizen:
You are all resolved rather to die than to famish?
All:
Resolved. resolved.
First Citizen:
First, you know Caius Marcius is chief enemy to the people.
All:
We know't, we know't.
First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?
All:
No more talking on't; let it be done: away, away!
Second Citizen:
One word, good citizens.
First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.
构建字符集合:
# The unique characters in the file
vocab = sorted(set(text))
print ('{} unique characters'.format(len(vocab)))
65 unique characters
文本处理
构建字符索引:
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)
text_as_int = np.array([char2idx[c] for c in text])
for char,_ in zip(char2idx, range(20)):
print('{:6s} ---> {:4d}'.format(repr(char), char2idx[char]))
'j' ---> 48
'f' ---> 44
'R' ---> 30
':' ---> 10
'W' ---> 35
';' ---> 11
'o' ---> 53
'b' ---> 40
'K' ---> 23
'L' ---> 24
'O' ---> 27
'h' ---> 46
'm' ---> 51
'u' ---> 59
'H' ---> 20
'z' ---> 64
'!' ---> 2
'S' ---> 31
'N' ---> 26
'Z' ---> 38
预测任务:
给定一个字符或一个字符序列,希望预测下一个最可能出现的字符。在训练时,我们的模型输入seq_length个字符,输出seq_length个字符。例如seq_length为4,我们的文本你是Hello,那么输入是Hell,输出是ello。
首先获得多个长度为seq_length的文本:
seq_length = 100
chunks = tf.data.Dataset.from_tensor_slices(text_as_int).batch(seq_length+1, drop_remainder=True)
for item in chunks.take(5):
print(repr(''.join(idx2char[item.numpy()])))
'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k'
"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki"
"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d"
'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'
从文本中提取输入和目标:
def split_input_target(chunk):
input_text = chunk[:-1]
target_text = chunk[1:]
return input_text, target_text
dataset = chunks.map(split_input_target)
for input_example, target_example in dataset.take(1):
print ('Input data: ', repr(''.join(idx2char[input_example.numpy()])))
print ('Target data:', repr(''.join(idx2char[target_example.numpy()])))
Input data: 'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'
Target data: 'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
利用tf.data来划分batch,并进行shuffle:
# batch大小
BATCH_SIZE = 64
# shuffle缓存大小
BUFFER_SIZE = 10000
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
模型
模型类,使用了tf.keras的Embedding和GRU层:
class Model(tf.keras.Model):
def __init__(self, vocab_size, embedding_dim, units):
super(Model, self).__init__()
self.units = units
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
if tf.test.is_gpu_available():
self.gru = tf.keras.layers.CuDNNGRU(self.units,
return_sequences=True,
recurrent_initializer='glorot_uniform',
stateful=True)
else:
self.gru = tf.keras.layers.GRU(self.units,
return_sequences=True,
recurrent_activation='sigmoid',
recurrent_initializer='glorot_uniform',
stateful=True)
self.fc = tf.keras.layers.Dense(vocab_size)
def call(self, x):
embedding = self.embedding(x)
# output at every time step
# output shape == (batch_size, seq_length, hidden_size)
output = self.gru(embedding)
# The dense layer will output predictions for every time_steps(seq_length)
# output shape after the dense layer == (seq_length * batch_size, vocab_size)
prediction = self.fc(output)
# states will be used to pass at every step to the model while training
return prediction
实例化模型、优化器和损失函数:
# 字典大小
vocab_size = len(vocab)
# 字符向量维度
embedding_dim = 256
# RNN单元维度
units = 1024
model = Model(vocab_size, embedding_dim, units)
optimizer = tf.train.AdamOptimizer()
def loss_function(real, preds):
return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)
训练模型:
model.build(tf.TensorShape([BATCH_SIZE, seq_length]))
model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) multiple 16640
_________________________________________________________________
gru (GRU) multiple 3935232
_________________________________________________________________
dense (Dense) multiple 66625
=================================================================
Total params: 4,018,497
Trainable params: 4,018,497
Non-trainable params: 0
_________________________________________________________________
# 保存checkpoints的目录
checkpoint_dir = './training_checkpoints'
# Checkpoint文件名
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
简单训练几次:
EPOCHS = 5
# 循环训练
for epoch in range(EPOCHS):
start = time.time()
# 在每轮epoch开始时,初始化状态
hidden = model.reset_states()
for (batch, (inp, target)) in enumerate(dataset):
with tf.GradientTape() as tape:
predictions = model(inp)
loss = loss_function(target, predictions)
grads = tape.gradient(loss, model.variables)
optimizer.apply_gradients(zip(grads, model.variables))
if batch % 100 == 0:
print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
batch,
loss))
# 每5个epoch保存一下模型
if (epoch + 1) % 5 == 0:
model.save_weights(checkpoint_prefix)
print('Epoch {} Loss {:.4f}'.format(epoch + 1, loss))
print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
Epoch 3 Batch 0 Loss 1.7645
Epoch 3 Batch 100 Loss 1.6853
Epoch 3 Loss 1.6164
Time taken for 1 epoch 610.0756878852844 sec
Epoch 4 Batch 0 Loss 1.6491
Epoch 4 Batch 100 Loss 1.5350
Epoch 4 Loss 1.5071
Time taken for 1 epoch 609.8330454826355 sec
Epoch 5 Batch 0 Loss 1.4715
Epoch 5 Batch 100 Loss 1.4685
Epoch 5 Loss 1.4042
Time taken for 1 epoch 608.6753587722778 sec
保存模型:
model.save_weights(checkpoint_prefix)
模型载入:
如果需要载入保存的模型,可以使用下面的代码:
model = Model(vocab_size, embedding_dim, units)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))
使用训练的模型来进行文本生成:
# 生成文本
# 生成字符的数量
num_generate = 1000
# 起始字符
start_string = 'Q'
input_eval = [char2idx[s] for s in start_string]
input_eval = tf.expand_dims(input_eval, 0)
# 用于存储结果的空字符串
text_generated = []
# 较低的temperature值会产出一些预料之中的文本.
# 较高的temperature值会产出一些出乎预料的文本.
temperature = 1.0
# 这里batch大小为1
model.reset_states()
for i in range(num_generate):
predictions = model(input_eval)
# 移除batch维
predictions = tf.squeeze(predictions, 0)
# 用多项式分布来预测生成的词
predictions = predictions / temperature
predicted_id = tf.multinomial(predictions, num_samples=1)[-1, 0].numpy()
# 将生成的词和上一次的隐藏状态作为模型的下一个输入
input_eval = tf.expand_dims([predicted_id], 0)
text_generated.append(idx2char[predicted_id])
print(start_string + ''.join(text_generated))
运行结果:
QULERBY:
If a body.
But I would me your lood.
Steak ungrace and as this only in the ploaduse,
his they, much you amed on't.
RSCALIO:
Hearn' thousand as your well, and obepional.
ANTONIO:
Can wathach this wam a discure that braichal heep itspose,
Teparmate confoim it: never knor sheep, so litter
Plarence? He,
But thou sunds a parmon servection:
Occh Rom o'ld him sir;
madish yim,
I'll surm let as hand upherity
Shepherd:
Why do I sering their stumble; the thank emo'st yied
Baunted unpluction; the main, sir, What's a meanulainst
Even worship tebomn slatued of his name,
Manisholed shorks you go?
BUCKINGHAM:
We look thus then impare'd least itsiby drumes,
That I, what!
Nurset, fell beshee that which I will
to the near-Volshing upon this aguin against fless
Is done untlein with is the neck,
Thands he shall fear'ds; let me love at officed:
Where else to her awticions, as you hall, my lord.
KING RICHARD II:
I will been another one our accuser less
Tiold, methought to the presench of consiar
参考资料:
https://www.tensorflow.org/tutorials/sequences/text_generation
https://github.com/tensorflow/docs/blob/master/site/en/tutorials/sequences/text_generation.ipynb
-END-
专 · 知
人工智能领域26个主题知识资料全集获取与加入专知人工智能服务群: 欢迎微信扫一扫加入专知人工智能知识星球群,获取专业知识教程视频资料和与专家交流咨询!
请PC登录www.zhuanzhi.ai或者点击阅读原文,注册登录专知,获取更多AI知识资料!
请加专知小助手微信(扫一扫如下二维码添加),加入专知主题群(请备注主题类型:AI、NLP、CV、 KG等)交流~
请关注专知公众号,获取人工智能的专业知识!
点击“阅读原文”,使用专知