tf.keras为我们提供了易用的TF API,其中keras.Model是最重要的API之一,它封装了模型的参数、结构等信息及训练、测试等过程。为了让用户能够更好地定制训练的过程,TF 2.2为该API引入了新的可扩展接口。
在TensorFlow开发者峰会2020(TF Dev Summit '20)中,相关人员介绍了TF 2.2为keras.Model引入的自定义训练过程接口train_step:
在之前的版本中,虽然tf.keras的keras.Model模型封装了模型的训练过程,但由于这种封装过于黑盒,使得许多开发者并不愿意使用keras.Model自带的训练功能,而选择显式地调用tf.GradientTape等来进行反向传播和参数更新。一般,开发者会定义如下的训练过程:``` def train_step(images, labels): with tf.GradientTape() as tape: logits = mnist_model(images, training=True)
tf.debugging.assert_equal(logits.shape, (32, 10))
loss_value = loss_object(labels, logits)
loss_history.append(loss_value.numpy().mean()) grads = tape.gradient(loss_value, mnist_model.trainable_variables) optimizer.apply_gradients(zip(grads, mnist_model.trainable_variables))
然后通过循环来手动调度训练过程:```
def train(epochs):
for epoch in range(epochs):
for(batch, (images, labels)) in enumerate(dataset):
train_step(images, labels)
print('Epoch {} finished'.format(epoch))
keras.Model自带了许多非常好用的功能,例如进度显示、基于回调的TensorBoard日志、基于回调的Early Stop等。一般需要使用keras.Model自带的训练机制才可以享受到这些便捷的功能,上面这种手动调用的方法虽然能够让开发者对训练过程有着完全的掌控,但也使得他们不能享受部分keras.Model自带的便捷功能。
TF 2.2在keras.Model类中直接引入了train_step方法,这样开发者只需要在继承keras.Model模型时用自定义的方法覆盖父类中train_step的方法,就可以自定义可控的训练过程,并使用keras.Model自带的调度机制来进行训练:
参考链接: