点击上方“专知”关注获取更多AI知识!
基于DL4J的CNN、AutoEncoder、RNN、Word2Vec等模型的实现
MNIST由手写数字图片组成,包含0-9十种数字,常被用作测试机器学习算法性能的基准数据集。MNIST包含了一个有60000张图片的训练集和一个有10000张图片的测试集。深度学习在MNIST上可以达到99.7%的准确率。
Deeplearning4j中直接集成了MNIST数据集,例如可以直接用下面的代码加载训练集和测试集:
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed); DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
本教程使用具有1个隐藏层的MLP作为网络的结构,使用RELU作为隐藏层的激活函数,使用SOFTMAX作为输出层的激活函数。
从图中可以看出,网络具有输入层、隐藏层和输出层一共3层,但在代码编写时,会将该网络看作由2个层组成(2次变换):
Layer 0: 一个Dense Layer(全连接层),由输入层进行线性变换变为隐藏层,并使用RELU对变换结果进行激活。用公式表达形式为H = relu(XW_0 + b_0)
,其中:
X: 输入层,是形状为[batch_size, input_dim]的矩阵,矩阵的每行对应一个样本,每列对应一个特征(一个像素)
H: 隐藏层的输出,是形状为[batch_size, hidden_dim]的矩阵,矩阵的每行对应一个样本隐藏层的输出
relu: 使用RELU激活函数进行激活
W_0: 形状为[input_dim, hidden_dim]的矩阵,是全连接层线性变换的参数
b_0: 形状为[hidden_dim]的矩阵,是全连接层线性变换的参数(偏置)
Layer 1: 一个Dense Layer(全连接层),由隐藏层进行线性变换为输出层,并使用SOFTMAX对变换结果进行激活。用公式表达形式为:OUTPUT = softmax(HW_1 + b_1)
,其中:
OUTPUT: 输出层,是形状为[batch_size, output_dim]的矩阵,矩阵的每行对应一个样本,每列对应样本属于某类的概率。例如该例子中第0列表示输入手写数字为1的概率。
softmax: 使用SOFTMAX激活函数进行激活
W_1: 形状为[hidden_dim, output_dim]的矩阵,是全连接层线性变换的参数
b_1: 形状为[output_dim]的矩阵,是全连接层线性变换的参数(偏置)
神经网络的训练过程,即神经网络参数的调整过程。待参数能够很好地预测测试集中样本的类别(label),神经网络就训练成功了。
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;import org.slf4j.LoggerFactory;/** * 本示例使用Deeplearning4j构建了一个多层感知器(MLP)来进行手写数字(MNIST)的识别 * 该示例中的神经网络只有1个隐藏层 * * 输入层的维度是numRows*numColumns(图像像素行数*图像像素列数),即每个手写数字图像的像素数量(28*28) * 隐藏层的大小为1000,使用RELU作为激活函数 * 输出层为SOFTMAX层,用于表示输入图像属于每个分类的概率(概率总和为1) * */public class MLPMnistSingleLayerExample {
private static Logger log = LoggerFactory.getLogger(MLPMnistSingleLayerExample.class);
public static void main(String[] args) throws Exception {
//number of rows and columns in the input pictures final int numRows = 28;
final int numColumns = 28;
int outputNum = 10; // 手写字符类别的数量 int batchSize = 128; // batch大小,一个batch中的输入使用相同的神经网络参数 int rngSeed = 123; // 设置一个随机种子,使得每次跑程序获得的随机值相同 int numEpochs = 15; // 训练时每扫描一遍数据集算一个Epoch //Deeplearning4j内置的MNIST数据集 DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed); DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed); log.info("Build model...."); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(rngSeed) // 为模型设置随机种子 // 使用随机梯度下降作为优化算法 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .iterations(1) .learningRate(0.006) // 设置学习速率 .updater(Updater.NESTEROVS) .regularization(true).l2(1e-4) //设置L2正则系数,设置L2正则可以降低过拟合的程度 .list() //开始构建MLP网络(多层感知器) .layer(0, new DenseLayer.Builder() //设置第一个Dense层 .nIn(numRows * numColumns) //输入为28*28 .nOut(1000) //输出为1000 .activation(Activation.RELU) //使用RELU激活 .weightInit(WeightInit.XAVIER) //设置初始化方法 .build()) .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //设置第二个Dense层,OutputLayer也是Dense层 .nIn(1000) //输入为1000 .nOut(outputNum) //输出为10,即手写数字的类别数量 .activation(Activation.SOFTMAX) //使用SOFTMAX激活 .weightInit(WeightInit.XAVIER) .build()) .pretrain(false).backprop(true) //进行反向传播,不进行预训练 .build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); //每隔1个iteration就输出一次score model.setListeners(new ScoreIterationListener(1)); log.info("Train model....");
for( int i=0; i<numEpochs; i++ ){ model.fit(mnistTrain); } log.info("Evaluate model...."); Evaluation eval = new Evaluation(outputNum); //创建一个评价器 while(mnistTest.hasNext()){ DataSet next = mnistTest.next(); INDArray output = model.output(next.getFeatureMatrix()); //模型的预测结果 eval.eval(next.getLabels(), output); //根据真实的结果和模型的预测结果对模型进行评价 } log.info(eval.stats()); log.info("****************Example finished********************"); } }
运行代码,输出如下:
明天请继续关注“DeepLearning4j”教程。
完整系列搜索查看,请PC登录
www.zhuanzhi.ai, 搜索“DeepLearning4j”即可得。
对DeepLearning4j教程感兴趣的同学,欢迎进入我们的专知DeepLearning4j主题群一起交流、学习、讨论,扫一扫如下群二维码即可进入:
群满,请扫描小助手,加入进群~
了解使用专知-获取更多AI知识!
阅读更多专知干货:
【干货】RL-GAN For NLP: 强化学习在生成对抗网络文本生成中扮演的角色
欢迎转发分享到微信群和朋友圈!
获取更多关于机器学习以及人工智能知识资料,请访问www.zhuanzhi.ai, 或者点击阅读原文,即可得到!
-END-
欢迎使用专知
专知,一个新的认知方式!目前聚焦在人工智能领域为AI从业者提供专业可信的知识分发服务, 包括主题定制、主题链路、搜索发现等服务,帮你又好又快找到所需知识。
使用方法>>访问www.zhuanzhi.ai, 或点击文章下方“阅读原文”即可访问专知
中国科学院自动化研究所专知团队
@2017 专知
专 · 知
关注我们的公众号,获取最新关于专知以及人工智能的资讯、技术、算法、深度干货等内容。扫一扫下方关注我们的微信公众号。
点击“阅读原文”,使用专知!