TensorFlow实例: 手写汉字识别

2017 年 11 月 10 日 机器学习研究会

MNIST手写数字数据集通常做为深度学习的练习数据集,这个数据集恐怕早已经被大家玩坏了。识别手写汉字要把识别英文、数字难上很多。首先,英文字符的分类少,总共10+26*2;而中文总共50,000多汉字,常用的就有3000多。其次,汉字有书法,每个人书写风格多样。



本文目标是利用TensorFlow做一个简单的图像分类器,在比较大的数据集上,尽可能高效地做图像相关处理,从Train,Validation到Inference,是一个比较基本的Example, 从一个基本的任务学习如果在TensorFlow下做高效地图像读取,基本的图像处理,整个项目很简单,但其中有一些trick,在实际项目当中有很大的好处, 比如绝对不要一次读入所有的 的数据到内存(尽管在Mnist这类级别的例子上经常出现)…




最开始看到是这篇blog里面的TensorFlow练习22: 手写汉字识别

http://link.zhihu.com/?target=http%3A//blog.topspeedsnail.com/archives/10897


但是这篇文章只用了140训练与测试,试了下代码 很快,但是当扩展到所有的时,发现32g的内存都不够用,这才注意到原文中都是用numpy,会先把所有的数据放入到内存,但这个不必须的,无论在MXNet还是TensorFlow中都是不必 须的,MXNet使用的是DataIter,会在程序运行的过程中异步读取数据,TensorFlow也是这样的,TensorFlow封装了高级的api,用来做数据的读取,比如TFRecord,还有就是从filenames中读取, 来异步读取文件,然后做shuffle batch,再feed到模型的Graph中来做模型参数的更新。具体在tf如何做数据的读取可以看看reading data in tensorflow


http://link.zhihu.com/?target=https%3A//www.tensorflow.org/how_tos/reading_data/





我会拿到所有的数据集来做训练与测试,算作是对斗大的熊猫上面那篇文章的一个扩展。


Batch Generate

数据集来自于中科院自动化研究所,感谢分享精神!!!具体下载:

wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip

wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip


解压后发现是一些gnt文件,然后用了斗大的熊猫里面的代码,将所有文件都转化为对应label目录下的所有png的图片。(注意在HWDB1.1trn_gnt.zip解压后是alz文件,需要再次解压 我在mac没有找到合适的工具,windows上有alz的解压工具)。


import os

import numpy as np

import struct

from PIL import Image



data_dir = '../data'

train_data_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt')

test_data_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt')



def read_from_gnt_dir(gnt_dir=train_data_dir):

    def one_file(f):

        header_size = 10

        while True:

            header = np.fromfile(f, dtype='uint8', count=header_size)

            if not header.size: break

            sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)

            tagcode = header[5] + (header[4]<<8)

            width = header[6] + (header[7]<<8)

            height = header[8] + (header[9]<<8)

            if header_size + width*height != sample_size:

                break

            image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))

            yield image, tagcode

    for file_name in os.listdir(gnt_dir):

        if file_name.endswith('.gnt'):

            file_path = os.path.join(gnt_dir, file_name)

            with open(file_path, 'rb') as f:

                for image, tagcode in one_file(f):

                    yield image, tagcode

char_set = set()

for _, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):

    tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')

    char_set.add(tagcode_unicode)

char_list = list(char_set)

char_dict = dict(zip(sorted(char_list), range(len(char_list))))

print len(char_dict)

import pickle

f = open('char_dict', 'wb')

pickle.dump(char_dict, f)

f.close()

train_counter = 0

test_counter = 0

for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):

    tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')

    im = Image.fromarray(image)

    dir_name = '../data/train/' + '%0.5d'%char_dict[tagcode_unicode]

    if not os.path.exists(dir_name):

        os.mkdir(dir_name)

    im.convert('RGB').save(dir_name+'/' + str(train_counter) + '.png')

    train_counter += 1

for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):

    tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')

    im = Image.fromarray(image)

    dir_name = '../data/test/' + '%0.5d'%char_dict[tagcode_unicode]

    if not os.path.exists(dir_name):

        os.mkdir(dir_name)

    im.convert('RGB').save(dir_name+'/' + str(test_counter) + '.png')

    test_counter += 1


处理好的数据,放到了云盘,大家可以直接在我的云盘来下载处理好的数据集HWDB1. 这里说明下,char_dict是汉字和对应的数字label的记录。


http://link.zhihu.com/?target=https%3A//pan.baidu.com/s/1o84jIrg


得到数据集后,就要考虑如何读取了,一次用numpy读入内存在很多小数据集上是可以行的,但是在稍微大点的数据集上内存就成了瓶颈,但是不要害怕,TensorFlow有自己的方法:


def batch_data(file_labels,sess, batch_size=128):

    image_list = [file_label[0] for file_label in file_labels]

    label_list = [int(file_label[1]) for file_label in file_labels]

    print 'tag2 {0}'.format(len(image_list))

    images_tensor = tf.convert_to_tensor(image_list, dtype=tf.string)

    labels_tensor = tf.convert_to_tensor(label_list, dtype=tf.int64)

    input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor])


    labels = input_queue[1]

    images_content = tf.read_file(input_queue[0])

    # images = tf.image.decode_png(images_content, channels=1)

    images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)

    # images = images / 256

    images =  pre_process(images)

    # print images.get_shape()

    # one hot

    labels = tf.one_hot(labels, 3755)

    image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,min_after_dequeue=10000)

    # print 'image_batch', image_batch.get_shape()


    coord = tf.train.Coordinator()

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    return image_batch, label_batch, coord, threads


简单介绍下,首先你需要得到所有的图像的path和对应的label的列表,利用tf.convert_to_tensor转换为对应的tensor, 利用tf.train.slice_input_producer将image_list ,label_list做一个slice处理,然后做图像的读取、预处理,以及label的one_hot表示,然后就是传到tf.train.shuffle_batch产生一个个shuffle batch,这些就可以feed到你的 模型。 slice_input_producer和shuffle_batch这类操作内部都是基于queue,是一种异步的处理方式,会在设备中开辟一段空间用作cache,不同的进程会分别一直往cache中塞数据 和取数据,保证内存或显存的占用以及每一个mini-batch不需要等待,直接可以从cache中获取。





Data Augmentation

由于图像场景不复杂,只是做了一些基本的处理,包括图像翻转,改变下亮度等等,这些在TensorFlow里面有现成的api,所以尽量使用TensorFlow来做相关的处理:


def pre_process(images):

    if FLAGS.random_flip_up_down:

        images = tf.image.random_flip_up_down(images)

    if FLAGS.random_flip_left_right:

        images = tf.image.random_flip_left_right(images)

    if FLAGS.random_brightness:

        images = tf.image.random_brightness(images, max_delta=0.3)

    if FLAGS.random_contrast:

        images = tf.image.random_contrast(images, 0.8, 1.2)

    new_size = tf.constant([FLAGS.image_size,FLAGS.image_size], dtype=tf.int32)

    images = tf.image.resize_images(images, new_size)

    return images



Build Graph

这里很简单的构造了一个两个卷积+一个全连接层的网络,没有做什么更深的设计,感觉意义不大,设计了一个dict,用来返回后面要用的所有op,还有就是为了方便再训练中查看loss和accuracy, 没有什么特别的,很容易理解, labels 为None时 方便做inference。


def network(images, labels=None):

    endpoints = {}

    conv_1 = slim.conv2d(images, 32, [3,3],1, padding='SAME')

    max_pool_1 = slim.max_pool2d(conv_1, [2,2],[2,2], padding='SAME')

    conv_2 = slim.conv2d(max_pool_1, 64, [3,3],padding='SAME')

    max_pool_2 = slim.max_pool2d(conv_2, [2,2],[2,2], padding='SAME')

    flatten = slim.flatten(max_pool_2)

    out = slim.fully_connected(flatten,3755, activation_fn=None)

    global_step = tf.Variable(initial_value=0)

    if labels is not None:

        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(out, labels))

        train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss, global_step=global_step)

        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(labels, 1)), tf.float32))

        tf.summary.scalar('loss', loss)

        tf.summary.scalar('accuracy', accuracy)

        merged_summary_op = tf.summary.merge_all()

    output_score = tf.nn.softmax(out)

    predict_val_top3, predict_index_top3 = tf.nn.top_k(output_score, k=3)


    endpoints['global_step'] = global_step

    if labels is not None:

        endpoints['labels'] = labels

        endpoints['train_op'] = train_op

        endpoints['loss'] = loss

        endpoints['accuracy'] = accuracy

        endpoints['merged_summary_op'] = merged_summary_op

    endpoints['output_score'] = output_score

    endpoints['predict_val_top3'] = predict_val_top3

    endpoints['predict_index_top3'] = predict_index_top3

    return endpoints


Train

train函数包括从已有checkpoint中restore,得到step,快速恢复训练过程,训练主要是每一次得到mini-batch,更新参数,每隔eval_steps后做一次train batch的eval,每隔save_steps 后保存一次checkpoint。


转自:大数据挖掘DT数据分析


完整内容请点击“阅读原文”

登录查看更多
8

相关内容

汉字识别指通过扫描图像识别汉字的技术。 单个汉字识读或辨识请至: 生僻字
【2020新书】实战R语言4,323页pdf
专知会员服务
98+阅读 · 2020年7月1日
干净的数据:数据清洗入门与实践,204页pdf
专知会员服务
160+阅读 · 2020年5月14日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
【课程】伯克利2019全栈深度学习课程(附下载)
专知会员服务
54+阅读 · 2019年10月29日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
160+阅读 · 2019年10月28日
CNN与RNN中文文本分类-基于TensorFlow 实现
七月在线实验室
13+阅读 · 2018年10月30日
干货——图像分类(下)
计算机视觉战队
14+阅读 · 2018年8月28日
[机器学习] 用KNN识别MNIST手写字符实战
机器学习和数学
4+阅读 · 2018年5月13日
手把手教你搭建caffe及手写数字识别
七月在线实验室
12+阅读 · 2017年11月22日
tensorflow项目学习路径
数据挖掘入门与实战
22+阅读 · 2017年11月19日
tensorflow LSTM + CTC实现端到端OCR
机器学习研究会
26+阅读 · 2017年11月16日
深度学习入门篇--手把手教你用 TensorFlow 训练模型
全球人工智能
4+阅读 · 2017年10月21日
Bidirectional Attention for SQL Generation
Arxiv
4+阅读 · 2018年6月21日
Arxiv
6+阅读 · 2018年2月6日
Arxiv
25+阅读 · 2017年12月6日
VIP会员
相关VIP内容
【2020新书】实战R语言4,323页pdf
专知会员服务
98+阅读 · 2020年7月1日
干净的数据:数据清洗入门与实践,204页pdf
专知会员服务
160+阅读 · 2020年5月14日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
【课程】伯克利2019全栈深度学习课程(附下载)
专知会员服务
54+阅读 · 2019年10月29日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
160+阅读 · 2019年10月28日
相关资讯
CNN与RNN中文文本分类-基于TensorFlow 实现
七月在线实验室
13+阅读 · 2018年10月30日
干货——图像分类(下)
计算机视觉战队
14+阅读 · 2018年8月28日
[机器学习] 用KNN识别MNIST手写字符实战
机器学习和数学
4+阅读 · 2018年5月13日
手把手教你搭建caffe及手写数字识别
七月在线实验室
12+阅读 · 2017年11月22日
tensorflow项目学习路径
数据挖掘入门与实战
22+阅读 · 2017年11月19日
tensorflow LSTM + CTC实现端到端OCR
机器学习研究会
26+阅读 · 2017年11月16日
深度学习入门篇--手把手教你用 TensorFlow 训练模型
全球人工智能
4+阅读 · 2017年10月21日
相关论文
Bidirectional Attention for SQL Generation
Arxiv
4+阅读 · 2018年6月21日
Arxiv
6+阅读 · 2018年2月6日
Arxiv
25+阅读 · 2017年12月6日
Top
微信扫码咨询专知VIP会员