import os

import numpy as np
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.python.layers.core import Dense

from data_generator_att import *

START_TOKEN = 0
END_TOKEN = 1
UNK_TOKEN = 2
VOCAB = {'<S>': 0, '</S>': 1, '<UNK>': 2,
         '0': 3, '1': 4, '2': 5, '3': 6, '4': 7,
         '5': 8, '6': 9, '7': 10, '8': 11, '9': 12}
VOCAB_SIZE = len(VOCAB)
BATCH_SIZE = 32
RNN_UNITS = 256
TRAIN_STEP = 1000000
IMAGE_HEIGHT = 32
MAXIMUM_DECODE_ITERATIONS = 20
DISPLAY_STEPS = 100
LOGS_PATH = 'logs_path'
CKPT_DIR = 'save_model'

# Placeholders shared by the training and prediction graphs.
image = tf.placeholder(tf.float32, shape=(None, IMAGE_HEIGHT, None, 1), name='img_data')
train_output = tf.placeholder(tf.int64, shape=[None, None], name='train_output')
train_length = tf.placeholder(tf.int32, shape=[None], name='train_length')
target_output = tf.placeholder(tf.int64, shape=[None, None], name='target_output')


def encoder_net(_image, scope, reuse=None):
    """CNN + bidirectional GRU encoder: image -> sequence of feature vectors."""
    with tf.variable_scope(scope, reuse=reuse):
        # Convolutional feature extractor; pooling collapses the height while
        # keeping the width as the time axis for the RNN.
        convolution1 = layers.conv2d(inputs=_image, num_outputs=64, kernel_size=[3, 3],
                                     padding='SAME', activation_fn=tf.nn.relu)
        pool1 = layers.max_pool2d(inputs=convolution1, kernel_size=[2, 2], stride=[2, 2])
        convolution2 = layers.conv2d(inputs=pool1, num_outputs=128, kernel_size=[3, 3],
                                     padding='SAME', activation_fn=tf.nn.relu)
        pool2 = layers.max_pool2d(inputs=convolution2, kernel_size=[2, 2], stride=[2, 2])
        convolution3 = layers.conv2d(inputs=pool2, num_outputs=256, kernel_size=[3, 3],
                                     padding='SAME', activation_fn=tf.nn.relu)
        convolution4 = layers.conv2d(inputs=convolution3, num_outputs=256, kernel_size=[3, 3],
                                     padding='SAME', activation_fn=tf.nn.relu)
        pool3 = layers.max_pool2d(inputs=convolution4, kernel_size=[2, 1], stride=[2, 1])
        convolution5 = layers.conv2d(inputs=pool3, num_outputs=512, kernel_size=[3, 3],
                                     padding='SAME', activation_fn=tf.nn.relu)
        n1 = layers.batch_norm(convolution5)
        convolution6 = layers.conv2d(inputs=n1, num_outputs=512, kernel_size=[3, 3],
                                     padding='SAME', activation_fn=tf.nn.relu)
        n2 = layers.batch_norm(convolution6)
        pool4 = layers.max_pool2d(inputs=n2, kernel_size=[2, 1], stride=[2, 1])
        convolution7 = layers.conv2d(inputs=pool4, num_outputs=512, kernel_size=[2, 2],
                                     padding='VALID', activation_fn=tf.nn.relu)

        # Drop the height dimension (now 1) so the tensor becomes
        # [batch, width, channels] and can be fed to the RNN.
        cnn_out = tf.squeeze(convolution7, axis=1)

        cell = tf.contrib.rnn.GRUCell(num_units=RNN_UNITS)
        enc_outputs, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell, cell_bw=cell,
                                                                 inputs=cnn_out, dtype=tf.float32)
        # Concatenate forward and backward outputs: [batch, width, 2 * RNN_UNITS].
        encoder_outputs = tf.concat(enc_outputs, -1)
        return encoder_outputs


def decode(helper, memory, scope, reuse=None):
    """Attention decoder over the encoder outputs (`memory`)."""
    with tf.variable_scope(scope, reuse=reuse):
        attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=RNN_UNITS, memory=memory)
        cell = tf.contrib.rnn.GRUCell(num_units=RNN_UNITS)
        attn_cell = tf.contrib.seq2seq.AttentionWrapper(cell, attention_mechanism,
                                                        attention_layer_size=RNN_UNITS,
                                                        output_attention=True)
        output_layer = Dense(units=VOCAB_SIZE)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=attn_cell, helper=helper,
            initial_state=attn_cell.zero_state(dtype=tf.float32, batch_size=BATCH_SIZE),
            output_layer=output_layer)
        outputs = tf.contrib.seq2seq.dynamic_decode(
            decoder=decoder, output_time_major=False, impute_finished=True,
            maximum_iterations=MAXIMUM_DECODE_ITERATIONS)
        return outputs


def build_compute_graph():
    train_output_embed = encoder_net(image, scope='encode_features')
    pred_output_embed = encoder_net(image, scope='encode_features', reuse=True)

    # Embed the ground-truth input sequence for teacher forcing, and create a
    # separate embedding matrix for greedy decoding at prediction time.
    output_embed = layers.embed_sequence(train_output, vocab_size=VOCAB_SIZE,
                                         embed_dim=VOCAB_SIZE, scope='embed')
    embeddings = tf.Variable(tf.truncated_normal(shape=[VOCAB_SIZE, VOCAB_SIZE], stddev=0.1),
                             name='decoder_embedding')
    start_tokens = tf.zeros([BATCH_SIZE], dtype=tf.int64)

    train_helper = tf.contrib.seq2seq.TrainingHelper(output_embed, train_length)
    pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embeddings, start_tokens=tf.to_int32(start_tokens), end_token=END_TOKEN)

    train_outputs = decode(train_helper, train_output_embed, 'decode')
    # pred_outputs = decode(pred_helper, pred_output_embed, 'decode', reuse=True)
    pred_outputs = decode(pred_helper, train_output_embed, 'decode', reuse=True)

    # Logits of the first example in the batch, used only for logging.
    train_decode_result = train_outputs[0].rnn_output[0, :-1, :]
    pred_decode_result = pred_outputs[0].rnn_output[0, :, :]

    mask = tf.cast(tf.sequence_mask(BATCH_SIZE * [train_length[0] - 1], train_length[0]),
                   tf.float32)
    att_loss = tf.contrib.seq2seq.sequence_loss(train_outputs[0].rnn_output, target_output,
                                                weights=mask)
    loss = tf.reduce_mean(att_loss)
    train_one_step = tf.train.AdadeltaOptimizer().minimize(loss)
    return loss, train_one_step, train_decode_result, pred_decode_result


def train_network(loss, train_one_step, train_decode_result, pred_decode_result):
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    # # To resume training, load the meta graph and restore the weights instead:
    # saver = tf.train.import_meta_graph('./save_model/attention_digit_ocr.model-100.meta')
    # saver.restore(sess, tf.train.latest_checkpoint('./save_model'))

    # TensorBoard visualization.
    with tf.name_scope('summaries'):
        tf.summary.scalar("cost", loss)
    summary_op = tf.summary.merge_all()
    writer = tf.summary.FileWriter(LOGS_PATH)

    # Create the saver once, outside the training loop, so that new save ops
    # are not added to the graph on every checkpoint.
    saver = tf.train.Saver()
    model_name = "attention_digit_ocr.model"
    if not os.path.exists(CKPT_DIR):
        os.makedirs(CKPT_DIR)

    # Training loop.
    with sess.as_default():
        data_gen = name_training_data_generator(BATCH_SIZE)
        for i in range(TRAIN_STEP):
            input_data = next(data_gen)
            feed_dict = {image: input_data['input'],
                         train_output: input_data['train_output'],
                         target_output: input_data['target_output'],
                         train_length: input_data['train_length']}
            train_one_step.run(feed_dict=feed_dict)

            if i % DISPLAY_STEPS == 0:
                summary_loss, loss_result = sess.run([summary_op, loss], feed_dict=feed_dict)
                writer.add_summary(summary_loss, i)
                train_outputs_result = sess.run([train_decode_result], feed_dict=feed_dict)
                # The greedy prediction path only depends on the input image.
                pred_outputs_result = sess.run([pred_decode_result],
                                               feed_dict={image: input_data['input'],
                                                          train_output: input_data['train_output'],
                                                          target_output: input_data['target_output']})
                print("Step:{}, loss:{}, train_decode:{}, predict_decode:{}, ground_truth:{}".
                      format(i, loss_result,
                             np.argmax(train_outputs_result[0], axis=1),
                             np.argmax(pred_outputs_result[0], axis=1),
                             input_data['target_output'][0]))

                # Save a checkpoint.
                saver.save(sess, os.path.join(CKPT_DIR, model_name), global_step=i)


def main():
    loss, train_one_step, train_decode_result, pred_decode_result = build_compute_graph()
    train_network(loss, train_one_step, train_decode_result, pred_decode_result)


if __name__ == '__main__':
    main()
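
# ---------------------------------------------------------------------------
# Minimal inference sketch (not part of the original training flow). It assumes
# a checkpoint saved by train_network() already exists under CKPT_DIR and that
# `batch_images` is a numpy array of shape [BATCH_SIZE, IMAGE_HEIGHT, width, 1],
# since the decoder's initial state is built with a fixed BATCH_SIZE. The greedy
# decode path only needs the `image` placeholder; `pred_decode_result` holds the
# logits for the first example in the batch.
#
# def run_inference(batch_images):
#     _, _, _, pred_decode_result = build_compute_graph()
#     saver = tf.train.Saver()
#     with tf.Session() as sess:
#         saver.restore(sess, tf.train.latest_checkpoint(CKPT_DIR))
#         logits = sess.run(pred_decode_result, feed_dict={image: batch_images})
#         return np.argmax(logits, axis=1)  # predicted token IDs (see VOCAB)
# ---------------------------------------------------------------------------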