# CNN-LSTM-CTC-OCR
# Copyright (C) 2017,2018 Jerod Weinman, Abyaya Lamsal, Benjamin Gafford
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# mjsynth.py -- Suite of routines for processing MJSynth data stored
#   as Examples in TFRecord files.

import os
import tensorflow as tf
import numpy as np

import pipeline


def get_dataset( args ):
    """Get a Dataset of serialized Examples from TFRecord files.

    The matching file names are shuffled (and reshuffled on every pass)
    before being read. The records are NOT parsed here; map preprocess_fn
    over the result to decode them.

    Parameters:
      args : sequence whose first four items are
        base_dir      : Directory containing the TFRecord files
        file_patterns : List of wildcard patterns for TFRecord files to
                        read, each joined onto base_dir
        num_threads   : Number of files to read in parallel
                        (TFRecordDataset num_parallel_reads)
        buffer_sz     : Size of the read buffer in BYTES
                        (TFRecordDataset buffer_size) -- not a count of
                        Examples
      Any further items of args are ignored.
    Returns:
      dataset : tf.data.TFRecordDataset whose elements are scalar
                tf.string tensors, each a serialized tf.train.Example
    Raises:
      ValueError : if no file under base_dir matches any of file_patterns
    """
    # Extract args
    base_dir, file_patterns, num_threads, buffer_sz = args[0:4]

    # Python list of path strings matching any of the patterns
    filenames = _get_filenames( base_dir, file_patterns )

    # An empty list would otherwise fail later inside TensorFlow with a far
    # less helpful error (empty non-string tensor, zero-sized shuffle buffer)
    if not filenames:
        raise ValueError( 'No TFRecord files in %s match %s'
                          % ( base_dir, file_patterns ) )

    # Get filenames into a dataset format
    ds_filenames = tf.data.Dataset.from_tensor_slices( filenames )

    # Shuffle the whole file list for some stochasticity between epochs
    ds_filenames = ds_filenames.shuffle( buffer_size=len( filenames ),
                                         reshuffle_each_iteration=True )

    dataset = tf.data.TFRecordDataset( ds_filenames,
                                       num_parallel_reads=num_threads,
                                       buffer_size=buffer_sz )
    return dataset


def preprocess_fn( data ):
    """Parse one serialized Example and preprocess its image.

    Parameters:
      data : scalar tf.string tensor holding a serialized tf.train.Example
             (one element of the get_dataset output)
    Returns:
      image  : 'image/encoded' decoded as a single-channel JPEG and passed
               through preprocess_image (rescaled, one row prepended)
      width  : 'image/width' cast to tf.int32, shape [1]
      label  : 'image/labels' (parsed as a tf.int64 SparseTensor) packed
               with tf.serialize_sparse so it can be batched; postbatch_fn
               unpacks it again
      length : 'text/length' as parsed, tf.int64, shape [1]
      text   : 'text/string', scalar tf.string
    """
    # 'image/filename' is parsed but not returned.
    feature_map = {
        'image/encoded'  : tf.FixedLenFeature( [], dtype=tf.string,
                                               default_value='' ),
        'image/labels'   : tf.VarLenFeature( dtype=tf.int64 ),
        'image/width'    : tf.FixedLenFeature( [1], dtype=tf.int64,
                                               default_value=1 ),
        'image/filename' : tf.FixedLenFeature( [], dtype=tf.string,
                                               default_value='' ),
        'text/string'    : tf.FixedLenFeature( [], dtype=tf.string,
                                               default_value='' ),
        'text/length'    : tf.FixedLenFeature( [1], dtype=tf.int64,
                                               default_value=1 )
    }

    features = tf.parse_single_example( data, feature_map )

    # Initialize fields according to feature map

    # Convert to grayscale
    image = tf.image.decode_jpeg( features['image/encoded'], channels=1 )

    width = tf.cast( features['image/width'], tf.int32 ) # for ctc_loss
    label = tf.serialize_sparse( features['image/labels'] ) # for batching
    length = features['text/length']
    text = features['text/string']

    image = preprocess_image( image )

    return image, width, label, length, text


def element_length_fn( image, image_width, label, label_seq_length, text ):
    """Return the image width as the "length" of a parsed element.

    Accepts the five components produced by preprocess_fn and ignores all
    but image_width. NOTE(review): presumably used as the element length
    function for length-based bucketing -- confirm against the caller.
    """
    return image_width


def postbatch_fn( image, width, label, length, text ):
    """Post-batching, postprocessing: packs raw tensors into a dictionary
    for Dataset's iterator output.

    Parameters:
      Batched versions of the preprocess_fn outputs; label holds one
      serialized sparse label sequence per batch element.
    Returns:
      features : dict with keys "image", "width", "length", "text"
      label    : tf.SparseTensor of tf.int32 label indices whose first
                 dimension indexes the batch element, for ctc_loss
    """
    # Batching is complete, so now we can re-sparsify our labels for ctc_loss
    label = tf.cast( tf.deserialize_many_sparse( label, tf.int64 ),
                     tf.int32 )

    # Format relevant features for estimator ingestion
    features = {
        "image"  : image,
        "width"  : width,
        "length" : length,
        "text"   : text
    }

    return features, label


def _get_filenames( base_dir, file_patterns=( '*.tfrecord', ) ):
    """Get a list of record files.

    Parameters:
      base_dir      : Directory to search
      file_patterns : Iterable of wildcard patterns, each joined onto
                      base_dir and expanded with tf.gfile.Glob. A single
                      pattern string is also accepted.
    Returns:
      Flat list of matching paths, grouped by pattern in the order the
      patterns are given. A file matching several patterns appears once
      per matching pattern.
    """
    # Treat a lone pattern string as one pattern rather than iterating
    # over its characters
    if isinstance( file_patterns, str ):
        file_patterns = [file_patterns]

    return [ data_file
             for file_pattern in file_patterns
             for data_file in tf.gfile.Glob( os.path.join( base_dir,
                                                           file_pattern ) ) ]


def preprocess_image( image ):
    """Preprocess image: Rescale and fix image height.

    Parameters:
      image : rank-3 image tensor [height, width, channels]; preprocess_fn
              passes a single-channel JPEG decode
    Returns:
      The rescaled image with a copy of its first row prepended, i.e.
      exactly one pixel taller than the input.
    """
    # Rescale from uint8([0,255]) to float([-0.5,0.5])
    image = pipeline.rescale_image( image )

    # Pad with a copy of the first row to expand to 32 pixels height.
    # Exactly one row is added, so the result is 32 pixels tall only when
    # the input is 31 pixels tall (NOTE(review): assumed true for MJSynth).
    first_row = tf.slice( image, [0, 0, 0], [1, -1, -1] )
    image = tf.concat( [first_row, image], 0 )

    return image