#!/usr/bin/python2
# -*- coding: utf-8 -*-
'''
By kyubyong park. kbpark.linguist@gmail.com.
www.github.com/kyubyong/neural_japanese_transliterator
'''
from __future__ import print_function

import tensorflow as tf
from tqdm import tqdm

from data_load import get_batch, load_vocab
from hyperparams import Hyperparams as hp
from modules import embed
from networks import encode, decode
from utils import shift_by_one

roma2idx, idx2roma, surf2idx, idx2surf = load_vocab()

class Graph:
    def __init__(self, is_training=True):
        self.graph = tf.Graph()
        with self.graph.as_default():
            if is_training:
                self.x, self.y, self.num_batch = get_batch()
            else:  # Evaluation
                self.x = tf.placeholder(tf.int32, shape=(None, hp.max_len))
                self.y = tf.placeholder(tf.int32, shape=(None, hp.max_len))

            # Character embedding for the romaji inputs x
            self.enc = embed(self.x, len(roma2idx), hp.embed_size, scope="emb_x")

            # Encoder
            self.memory = encode(self.enc, is_training=is_training)

            # Decoder inputs: the target sequence shifted one step to the right
            self.decoder_inputs = shift_by_one(self.y)
            self.dec = embed(self.decoder_inputs, len(surf2idx), hp.embed_size,
                             scope="emb_decoder_inputs")

            # Decoder: logits over the surface-form vocabulary
            self.outputs = decode(self.dec, self.memory, len(surf2idx),
                                  is_training=is_training)  # (N, T, len(surf2idx))
            self.logprobs = tf.log(tf.nn.softmax(self.outputs) + 1e-10)
            self.preds = tf.argmax(self.outputs, axis=-1)

            if is_training:
                self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=self.y, logits=self.outputs)
                # Mask out padding (index 0) so it does not contribute to the loss
                self.istarget = tf.to_float(tf.not_equal(self.y, tf.zeros_like(self.y)))
                self.mean_loss = tf.reduce_sum(self.loss * self.istarget) / tf.reduce_sum(self.istarget)

                # Training scheme
                self.global_step = tf.Variable(0, name='global_step', trainable=False)
                self.optimizer = tf.train.AdamOptimizer(learning_rate=hp.lr)
                self.train_op = self.optimizer.minimize(self.mean_loss,
                                                        global_step=self.global_step)

                # Summary
                tf.summary.scalar('mean_loss', self.mean_loss)
                self.merged = tf.summary.merge_all()

def main():
    g = Graph()
    print("Training Graph loaded")

    with g.graph.as_default():
        # Training
        sv = tf.train.Supervisor(logdir=hp.logdir, save_model_secs=0)
        with sv.managed_session() as sess:
            for epoch in range(1, hp.num_epochs + 1):
                if sv.should_stop():
                    break
                for step in tqdm(range(g.num_batch), total=g.num_batch,
                                 ncols=70, leave=False, unit='b'):
                    sess.run(g.train_op)

                # Write a checkpoint at every epoch
                gs = sess.run(g.global_step)
                sv.saver.save(sess, hp.logdir + '/model_epoch_%02d_gs_%d' % (epoch, gs))

if __name__ == '__main__':
    main()
    print("Done")
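
# A minimal greedy-decoding sketch for using a trained checkpoint, assuming an
# evaluation flow like this repo's eval.py. `x_batch` is a hypothetical
# (N, hp.max_len) int32 array of romaji indices; the function name and its
# placement here are illustrative, not part of the original script. Because
# the decoder is conditioned on shift_by_one(self.y), predictions must be fed
# back one position at a time rather than read out in a single pass.
def greedy_decode_sketch(x_batch):
    import numpy as np
    g = Graph(is_training=False)
    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session() as sess:
            # Restore the latest checkpoint written by main()
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            preds = np.zeros_like(x_batch, np.int32)
            for j in range(x_batch.shape[1]):
                # Position j of g.preds depends only on preds[:, :j] via shift_by_one
                _preds = sess.run(g.preds, {g.x: x_batch, g.y: preds})
                preds[:, j] = _preds[:, j]
            return preds  # (N, hp.max_len) indices into idx2surf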