初学者的 Keras:实现卷积神经网络

2019 年 9 月 8 日 Python程序员


Keras是一个简单易用但功能强大的 Python 深度学习库。在这篇文章中,我们将用 Keras 构建一个简单的卷积神经网络(CNN),并训练它来解实际问题。


这篇文章适用于完全初学 Keras 的人,但假设有 CNNs 的基本背景知识。我对卷积神经网络的介绍涵盖了你在这篇文章中需要知道的一切(以及更多内容),如果需要,请先阅读。

 

我们现在就开始!

想要代码吗?完整的源代码在末尾。


问题:MNIST数字分类


我们将处理一个经典的计算机视觉入门问题:MNIST 手写数字分类。很简单:给定一个图像,将其分类为一个数字。

 MNIST 数据集中的样本图像

 

MNIST 数据集中的每个图像都是 28x28,包含一个居中的灰度数字。我们的 CNN 将获取一个图像并输出 10 个可能的类中的一个(每个数字一个)。


1. 安装


我假设你已经有了一个基本的 Python 安装(可能是这样)。让我们先下载一些我们需要的包:

注意:我们需要安装 tensorflow ,因为我们要在 TensorFlow 后端上运行 Keras(即 TensorFlow 将装备 Keras )。

现在你应该能够导入这些包并浏览 MNIST 数据集:

2. 准备数据


在开始之前,我们将把图像像素值从 [0,255] 规范化为 [-0.5,0.5] 以使网络更容易训练(使用较小的中心值通常会得到更好的结果)。我们还将把每个图像从(28,28)改为(28,28,1),因为 Keras 需要第三维度。

我们准备好开始构建我们的 CNN 了!


3. 构建模型


每个 Keras 模型要么使用表示层的线性堆栈的 Sequential 类构建,要么使用更可定制的功能 Model 类。我们将使用更简单的Sequential 模型,因为我们的 CNN 将是一个层的线性堆栈。


我们首先实例化一个 Sequential 模型:

Sequential 构造函数接受一个 Keras Layers 数组。我们将为 CNN 使用三种类型的层:卷积层、最大池层和 Softmax 层。

这是我们在我的 CNN 简介中使用的 CNN 设置。如果你对这三种层的任何一种都不满意的话,请阅读这篇文章。

  • num_filters, filter_size 和 pool_size 是设置 CNN 超参数的自解释变量。

  • 任何 Sequential 模型中的第一层都必须指定输入 input_shape,因此我们在 Conv2D上执行此操作。一旦指定了此输入形状,Keras 将自动推断后续层的输入形状。

  • Softmax 输出层有 10 个节点,每个类一个。


4. 编译模型


在开始培训之前,我们需要配置训练过程。我们在编译过程中确定了3个关键因素:

  • 优化器。我们将坚持用一个非常好的默认设置:Adam 基于梯度的优化器。Keras 还有许多其他优化器,你也可以查看。

  • 损失函数。因为我们使用的是 SoftMax 输出层,所以我们将使用交叉熵损失。Keras 区分 binary_crossentropy (2类)和 categorical_crossentropy(>2 类),因此我们将使用后者。查看所有的 Keras 损失函数.

  • 度量列表。因为这是一个分类问题,所以我们只会有关于准确度度量的 Keras 报告。


下面是编译的样子:

走起!

5. 训练模型


在 Keras 中训练模型实际上只包括调用 fit() 和指定一些参数。有很多可能的参数,但我们只提供这些:

  • 训练数据(图像和标签),通常分别称为 X 和 Y。

  • 训练的 epoch 数(整个数据集的迭代次数)。

  • 验证数据(或测试数据),在训练期间用于根据以前从未见过的数据定期测量网络性能。


有一件事我们必须小心:Keras 期望训练目标是 10 维向量,因为我们的 Softmax 输出层中有 10 个节点。现在,我们的 train_labels 和 test_labels 数组包含表示每个图像的类的单个整数:

很方便,Keras 有一个实用的方法来解决这个确切的问题:to_categorical。它将整数类数组转换为一个独热向量数组。例如,2 将变为[0, 0, 1, 0, 0, 0, 0, 0, 0, 0](它是从零索引)。


这就是它的样子:

我们现在可以把所有的东西放在一起训练我们的网络:

在完整 MNIST 数据集上运行该代码可以得到如下结果:

我们用这个简单的 CNN 达到了 97.4% 的测试精度!


6. 使用模型


既然我们有了一个有效的、经过训练的模型,让我们来使用它。我们要做的第一件事是将它保存到磁盘上,这样我们就可以随时加载它:

通过重建模型并加载保存的权重,我们现在可以在任何需要的时候重新加载经过训练的模型:

使用经过训练的模型进行预测很容易:我们将输入数组传递给 predict(),它返回一个输出数组。请记住,我们网络的输出是 10 个概率(因为 softmax),所以我们将使用 np.argmax() 将这些转换为实际数字。


8. 扩展


我们还可以做更多的实验并改进我们的网络 - 在这个官方的 Keras MNIST CNN 例子中,他们在 12 个 epochs 后达到了99.25% 的测试精度。你可以对我们的 CNN 进行修改的一些例子包括:


网络深度


如果我们添加或删除 Convolutional 层会发生什么?这将如何影响训练和/或模型的最终性能?


Dropout


如果我们尝试添加通常用来防止过拟合的 Dropout 层会发生什么?


全连接层


如果我们在 Convolutional 输出和最终的 Softmax 层之间添加全连接层会发生什么?这是 CNNs 中用于计算机视觉的常见操作。


卷积参数


如果我们使用 Conv2D 参数会发生什么?例如:


结语


你已经用 Keras 实现了你的第一个 CNN!我们通过首个简单网络获得了 97.4% 的测试精度。我在下面再加一次完整的源代码供你参考。


你可能感兴趣的扩展阅读包括:

  • 官方的 Keras 入门指南。

  • 我关于导出训练 CNNs 的反向传播算法的文章。

  • Keras 示例集锦.

  • 更多关于神经网络的文章.


谢谢你的阅读!完整的源代码如下。


完整代码

           

英文原文:https://victorzhou.com/blog/keras-cnn-tutorial/ 
译者:青书
登录查看更多
24

相关内容

一份简明有趣的Python学习教程,42页pdf
专知会员服务
76+阅读 · 2020年6月22日
Sklearn 与 TensorFlow 机器学习实用指南,385页pdf
专知会员服务
129+阅读 · 2020年3月15日
《深度学习》圣经花书的数学推导、原理与Python代码实现
《强化学习—使用 Open AI、TensorFlow和Keras实现》174页pdf
专知会员服务
136+阅读 · 2020年3月1日
【模型泛化教程】标签平滑与Keras, TensorFlow,和深度学习
专知会员服务
20+阅读 · 2019年12月31日
基于TensorFlow和Keras的图像识别
Python程序员
16+阅读 · 2019年6月24日
直白介绍卷积神经网络(CNN)
算法与数学之美
13+阅读 · 2019年1月23日
干货 | 用 Keras 实现图书推荐系统
AI科技评论
11+阅读 · 2018年12月15日
卷积神经网络概述及Python实现
云栖社区
4+阅读 · 2018年9月1日
【干货】使用Pytorch实现卷积神经网络
专知
13+阅读 · 2018年5月12日
基于Keras进行迁移学习
论智
12+阅读 · 2018年5月6日
一个小例子带你轻松Keras图像分类入门
云栖社区
4+阅读 · 2018年1月24日
TensorFlow实现神经网络入门篇
机器学习研究会
10+阅读 · 2017年11月19日
深度学习实战(二)——基于Keras 的深度学习
乐享数据DataScientists
15+阅读 · 2017年7月13日
Single-frame Regularization for Temporally Stable CNNs
Adaptive Neural Trees
Arxiv
4+阅读 · 2018年12月10日
Arxiv
12+阅读 · 2018年1月28日
Arxiv
6+阅读 · 2018年1月11日
VIP会员
相关VIP内容
相关资讯
基于TensorFlow和Keras的图像识别
Python程序员
16+阅读 · 2019年6月24日
直白介绍卷积神经网络(CNN)
算法与数学之美
13+阅读 · 2019年1月23日
干货 | 用 Keras 实现图书推荐系统
AI科技评论
11+阅读 · 2018年12月15日
卷积神经网络概述及Python实现
云栖社区
4+阅读 · 2018年9月1日
【干货】使用Pytorch实现卷积神经网络
专知
13+阅读 · 2018年5月12日
基于Keras进行迁移学习
论智
12+阅读 · 2018年5月6日
一个小例子带你轻松Keras图像分类入门
云栖社区
4+阅读 · 2018年1月24日
TensorFlow实现神经网络入门篇
机器学习研究会
10+阅读 · 2017年11月19日
深度学习实战(二)——基于Keras 的深度学习
乐享数据DataScientists
15+阅读 · 2017年7月13日
Top
微信扫码咨询专知VIP会员