# -*- coding: utf-8 -*-
import codecs
import os
from collections import Counter

import numpy as np
import tensorflow as tf
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm

from Transformer.transformer import Transformer
from Transformer.corpora.data_loader import Data_helper
from Transformer.config.hyperparams import Hyperparams as pm


class Transformer_interface(object):

    def __init__(self):
        # Build the vocabulary files on first run, then always load them so
        # that evaluate() can map indices back to tokens.
        if not os.path.exists(pm.DECODER_VOCAB):
            self.build_vocabulary(pm.source_train, pm.DECODER_VOCAB)
            self.build_vocabulary(pm.target_train, pm.ENCODER_VOCAB)
        data_helper = Data_helper()
        self.de2idx, self.idx2de = data_helper.load_vocab(pm.DECODER_VOCAB)
        self.en2idx, self.idx2en = data_helper.load_vocab(pm.ENCODER_VOCAB)

    def train(self):
        # Construct the training graph
        model = Transformer()
        print("Graph loaded")

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        # Start training. Supervisor is deprecated in TF 1.x in favor of
        # MonitoredTrainingSession, but it still works and is kept here.
        sv = tf.train.Supervisor(logdir=pm.logdir, save_model_secs=0, init_op=init)
        saver = sv.saver
        with sv.managed_session(config=config) as sess:
            for epoch in range(1, pm.num_epochs + 1):
                if sv.should_stop():
                    break
                for _ in tqdm(range(model.num_batch), total=model.num_batch,
                              ncols=70, leave=False, unit='b'):
                    sess.run(model.optimizer)

                # Checkpoint once per epoch
                gs = sess.run(model.global_step)
                saver.save(sess, pm.logdir + '/model_epoch_{}_global_step_{}'.format(epoch, gs))

        print("MSG : Done for training!")

    def evaluate(self):
        # Load the inference graph
        model = Transformer(trainable=False)
        print("Graph loaded")

        # Load test data
        X, Sources, Targets = model.data_helper.load_test_datasets()

        # Start testing
        sv = tf.train.Supervisor()
        saver = sv.saver
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            saver.restore(sess, tf.train.latest_checkpoint(pm.logdir))
            print("Restored!")

            # Read the latest model name from the checkpoint protocol file
            with codecs.open(pm.logdir + '/checkpoint', 'r', encoding='utf-8') as ckpt:
                mname = ckpt.read().split('"')[1]

            # Inference
            if not os.path.exists('results'):
                os.mkdir('results')
            with codecs.open("results/" + mname, "w", encoding="utf-8") as f:
                list_of_refs, hypothesis = [], []
                num_batch = len(X) // pm.batch_size
                for i in range(num_batch):
                    # Get mini-batches
                    x = X[i * pm.batch_size: (i + 1) * pm.batch_size]
                    sources = Sources[i * pm.batch_size: (i + 1) * pm.batch_size]
                    targets = Targets[i * pm.batch_size: (i + 1) * pm.batch_size]

                    # Auto-regressive (greedy) inference: at step j, feed the
                    # tokens decoded so far and keep only the j-th prediction.
                    preds = np.zeros((pm.batch_size, pm.maxlen), dtype=np.int32)
                    for j in range(pm.maxlen):
                        pred = sess.run(model.predicts, {model.x: x, model.y: preds})
                        preds[:, j] = pred[:, j]

                    for source, target, pred in zip(sources, targets, preds):
                        # Cut the decoded sentence at the first <EOS> token
                        res = " ".join(self.idx2en[idx] for idx in pred).split("<EOS>")[0].strip()
                        f.write("- source: {}\n".format(source))
                        f.write("- ground truth: {}\n".format(target))
                        f.write("- predict: {}\n\n".format(res))
                        f.flush()

                        # Collect reference/hypothesis pairs for BLEU
                        ref = target.split()
                        predicts = res.split()
                        if len(ref) > pm.min_word_count and len(predicts) > pm.min_word_count:
                            list_of_refs.append([ref])
                            hypothesis.append(predicts)

                # Corpus-level BLEU over all retained sentence pairs
                if list_of_refs:
                    score = corpus_bleu(list_of_refs, hypothesis)
                    f.write("Bleu Score = {}\n".format(100 * score))

        print("MSG : Done for testing!")

    def build_vocabulary(self, path, fname):
        # Count word frequencies in the training corpus
        with codecs.open(path, 'r', encoding='utf-8') as corpus:
            words = corpus.read().split()
        wordcount = Counter(words)

        if not os.path.exists('vocabulary'):
            os.mkdir('vocabulary')
        with codecs.open(fname, 'w', encoding='utf-8') as f:
            # Reserve the first four indices for the special tokens
            for token in ("<PAD>", "<UNK>", "<SOS>", "<EOS>"):
                f.write("{}\t1000000000\n".format(token))
            for word, count in wordcount.most_common(len(wordcount)):
                f.write("{}\t{}\n".format(word, count))


if __name__ == '__main__':
    interface = Transformer_interface()
    interface.train()
    # interface.evaluate()
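
# Rough usage sketch (the module path below is an assumption; run from the
# project root so the Transformer.* imports resolve):
#
#   python -m Transformer.interface   # trains and checkpoints into pm.logdir
#
# To score a trained model instead, swap the calls in the __main__ block so
# that interface.evaluate() runs; it writes the translations and a corpus
# BLEU score to results/<latest checkpoint name>.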