tensorflow LSTM + CTC实现端到端OCR

2017 年 11 月 16 日 机器学习研究会

最近在做OCR相关的东西,关于OCR真的是有悠久了历史了,最开始用tesseract,然而效果总是不理想,其中字符分割真的是个博大精深的问题,那么多年那么多算法,然而应用到实际总是有诸多问题。比如说非等间距字体的分割,汉字的分割,有光照阴影的图片的字体分割等等,针对特定的问题,特定的算法能有不错的效果,但也仅限于特定问题,很难有一些通用的结果。于是看了Xlvector的博客之后,发现可以端到端来实现OCR,他是基于mxnet的,于是我想把它转到tensorflow这个框架来,顺便还能熟悉一下这个框架。更加细节的实现方法见另一篇 http://ilovin.me/2017-04-23/tensorflow-lstm-ctc-input-output/


生成数据

利用captcha来生成验证码,具体生成验证码的代码,

在公众号 datadw 里 回复 OCR   即可获取。

共生成4-6位包含数字和英文大小写的训练图片128000张和测试图片400张。命名规则就是num_label.png,生成的图片如下图



关于生成数据,再多说一点,可以像Xlvector那样一边生成一边训练,这样样本是无穷的,效果更好。但是实际应用中有限样本的情况还是更多的。

载入数据

两种载入数据方式

pipeline

最开始想通过一个tf.train.string_input_producer来读入所有的文件名,然后以pipline的方式读入,但是由于标签的是不定长的,想通过正则来生成label,一开始是想用py_func来实现,后来发现传入string会有问题,所以最后还是选择生成tf.record文件,关于不定长问题,把比较短的标签在后面补零(0是blank的便签,就是说自己的类别中不能出现0这个类),然后读出每个batch后,再把0去掉。

一次性载入

我这里给一个目录,然后遍历里面所有的文件,等到训练的时候,每一个epoch循环把文件的index给手动shuffle一下,然后就可以每次截取出一个batch来用作输入了

class DataIterator:
   def __init__(self, data_dir):
       self.image_names = []
self.image = []
self.labels=[]
for root, sub_folder, file_list in os.walk(data_dir):
           for file_path in file_list:
               image_name = os.path.join(root,file_path)
self.image_names.append(image_name)
im = cv2.imread(image_name,0).astype(np.float32)/255.
               im = cv2.resize(im,(image_width,image_height))
# transpose to (160*60) and the step shall be 160
               # in this way, each row is a feature vector
               im = im.swapaxes(0,1)
self.image.append(np.array(im))
#image is named as ./<folder>/00000_abcd.png
               code = image_name.split('/')[2].split('_')[1].split('.')[0]
code = [SPACE_INDEX if code == SPACE_TOKEN else maps[c] for c in list(code)]
self.labels.append(code)
print(image_name,' ',code)
@property
   def size(self):
       return len(self.labels)
def input_index_generate_batch(self,index=None):
       if index:
           image_batch=[self.image[i] for i in index]
label_batch=[self.labels[i] for i in index]
else:
           # get the whole data as input
           image_batch=self.image
label_batch=self.labels
def get_input_lens(sequences):
           lengths = np.asarray([len(s) for s in sequences], dtype=np.int64)
return sequences,lengths
batch_inputs,batch_seq_len = get_input_lens(np.array(image_batch))
batch_labels = sparse_tuple_from_label(label_batch)
return batch_inputs,batch_seq_len,batch_labels


需要注意的是tensorflow lstm输入格式的问题,其label tensor应该是稀疏矩阵,所以读取图片和label之后,还要进行一些处理,具体可以看代码

在公众号 datadw 里 回复 OCR   即可获取。

关于载入图片,发现12.8w张图一次读进内存,内存也就涨了5G,如果训练数据加大,还是加一个pipeline来读比较好。

网络结构

然后是网络结构

graph = tf.Graph()
with graph.as_default():
   inputs = tf.placeholder(tf.float32, [None, None, num_features])
labels = tf.sparse_placeholder(tf.int32)
seq_len = tf.placeholder(tf.int32, [None])
# Stacking rnn cells
   stack = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.LSTMCell(FLAGS.num_hidden,state_is_tuple=True) for i in range(FLAGS.num_layers)] , state_is_tuple=True)
# The second output is the last state and we will no use that
   outputs, _ = tf.nn.dynamic_rnn(stack, inputs, seq_len, dtype=tf.float32)
shape = tf.shape(inputs)
batch_s, max_timesteps = shape[0], shape[1]
# Reshaping to apply the same weights over the timesteps
   outputs = tf.reshape(outputs, [-1, FLAGS.num_hidden])
# Truncated normal with mean 0 and stdev=0.1
   W = tf.Variable(tf.truncated_normal([FLAGS.num_hidden,
num_classes],
stddev=0.1),name='W')
b = tf.Variable(tf.constant(0., shape=[num_classes],name='b'))
# Doing the affine projection
   logits = tf.matmul(outputs, W) + b
# Reshaping back to the original shape
   logits = tf.reshape(logits, [batch_s, -1, num_classes])
# Time major
   logits = tf.transpose(logits, (1, 0, 2))
global_step = tf.Variable(0,trainable=False)
loss = tf.nn.ctc_loss(labels=labels,inputs=logits, sequence_length=seq_len)
cost = tf.reduce_mean(loss)
#optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
   #        momentum=FLAGS.momentum).minimize(cost,global_step=global_step)
   optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.initial_learning_rate,
beta1=FLAGS.beta1,beta2=FLAGS.beta2).minimize(loss,global_step=global_step)
# Option 2: tf.contrib.ctc.ctc_beam_search_decoder
   # (it's slower but you'll get better results)
   #decoded, log_prob = tf.nn.ctc_greedy_decoder(logits, seq_len,merge_repeated=False)
   decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len,merge_repeated=False)
# Inaccuracy: label error rate
   lerr = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), labels))


这里我参考了stackoverflow的一篇帖子写的,

https://stackoverflow.com/questions/38059247/using-tensorflows-connectionist-temporal-classification-ctc-implementation

根据tensorflow 1.0.1的版本做了微调,使用了Adam作为optimizer。

需要注意的是ctc_beam_search_decoder是非常耗时的.


转自: 大数据挖掘DT数据分析


完整内容请点击“阅读原文”

登录查看更多
26

相关内容

OCR (Optical Character Recognition,光学字符识别)是指电子设备(例如扫描仪或数码相机)检查纸上打印的字符,通过检测暗、亮的模式确定其形状,然后用字符识别方法将形状翻译成计算机文字的过程;即,针对印刷体字符,采用光学的方式将纸质文档中的文字转换成为黑白点阵的图像文件,并通过识别软件将图像中的文字转换成文本格式,供文字处理软件进一步编辑加工的技术。
一份循环神经网络RNNs简明教程,37页ppt
专知会员服务
173+阅读 · 2020年5月6日
Transformer文本分类代码
专知会员服务
117+阅读 · 2020年2月3日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
《动手学深度学习》(Dive into Deep Learning)PyTorch实现
专知会员服务
120+阅读 · 2019年12月31日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
165+阅读 · 2019年10月28日
CNN与RNN中文文本分类-基于TensorFlow 实现
七月在线实验室
13+阅读 · 2018年10月30日
[机器学习] 用KNN识别MNIST手写字符实战
机器学习和数学
4+阅读 · 2018年5月13日
CNN图像风格迁移的原理及TensorFlow实现
数据挖掘入门与实战
5+阅读 · 2018年4月18日
深度学习CTPN+CRNN模型实现图片内文字的定位与识别(OCR)
数据挖掘入门与实战
16+阅读 · 2017年11月25日
利用 TensorFlow 实现排序和搜索算法
机器学习研究会
5+阅读 · 2017年11月23日
tensorflow项目学习路径
数据挖掘入门与实战
22+阅读 · 2017年11月19日
TensorFlow实例: 手写汉字识别
机器学习研究会
8+阅读 · 2017年11月10日
推荐|caffe-orc主流ocr算法:CNN+BLSTM+CTC架构实现!
全球人工智能
19+阅读 · 2017年10月29日
用python和Tesseract实现光学字符识别(OCR)
Python程序员
7+阅读 · 2017年7月18日
A Sketch-Based System for Semantic Parsing
Arxiv
4+阅读 · 2019年9月12日
Arxiv
8+阅读 · 2018年5月1日
Arxiv
5+阅读 · 2018年5月1日
Arxiv
8+阅读 · 2018年1月19日
VIP会员
相关VIP内容
一份循环神经网络RNNs简明教程,37页ppt
专知会员服务
173+阅读 · 2020年5月6日
Transformer文本分类代码
专知会员服务
117+阅读 · 2020年2月3日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
《动手学深度学习》(Dive into Deep Learning)PyTorch实现
专知会员服务
120+阅读 · 2019年12月31日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
165+阅读 · 2019年10月28日
相关资讯
CNN与RNN中文文本分类-基于TensorFlow 实现
七月在线实验室
13+阅读 · 2018年10月30日
[机器学习] 用KNN识别MNIST手写字符实战
机器学习和数学
4+阅读 · 2018年5月13日
CNN图像风格迁移的原理及TensorFlow实现
数据挖掘入门与实战
5+阅读 · 2018年4月18日
深度学习CTPN+CRNN模型实现图片内文字的定位与识别(OCR)
数据挖掘入门与实战
16+阅读 · 2017年11月25日
利用 TensorFlow 实现排序和搜索算法
机器学习研究会
5+阅读 · 2017年11月23日
tensorflow项目学习路径
数据挖掘入门与实战
22+阅读 · 2017年11月19日
TensorFlow实例: 手写汉字识别
机器学习研究会
8+阅读 · 2017年11月10日
推荐|caffe-orc主流ocr算法:CNN+BLSTM+CTC架构实现!
全球人工智能
19+阅读 · 2017年10月29日
用python和Tesseract实现光学字符识别(OCR)
Python程序员
7+阅读 · 2017年7月18日
Top
微信扫码咨询专知VIP会员