# -*- coding: utf-8 -*- from __future__ import print_function import codecs import os from train import Graph from params import Params as pm from data_loader import load_data, load_vocab import tensorflow as tf import numpy as np from nltk.translate.bleu_score import corpus_bleu def eval(): g = Graph(is_training = False) print("MSG : Graph loaded!") X, Sources, Targets = load_data('test') en2idx, idx2en = load_vocab('en.vocab.tsv') de2idx, idx2de = load_vocab('de.vocab.tsv') with g.graph.as_default(): sv = tf.train.Supervisor() with sv.managed_session(config = tf.ConfigProto(allow_soft_placement = True)) as sess: # load pre-train model sv.saver.restore(sess, tf.train.latest_checkpoint(pm.checkpoint)) print("MSG : Restore Model!") mname = open(pm.checkpoint + '/checkpoint', 'r').read().split('"')[1] if not os.path.exists('Results'): os.mkdir('Results') with codecs.open("Results/" + mname, 'w', 'utf-8') as f: list_of_refs, predict = [], [] # Get a batch for i in range(len(X) // pm.batch_size): x = X[i * pm.batch_size: (i + 1) * pm.batch_size] sources = Sources[i * pm.batch_size: (i + 1) * pm.batch_size] targets = Targets[i * pm.batch_size: (i + 1) * pm.batch_size] # Autoregressive inference preds = np.zeros((pm.batch_size, pm.maxlen), dtype = np.int32) for j in range(pm.maxlen): _preds = sess.run(g.preds, feed_dict = {g.inpt: x, g.outpt: preds}) preds[:, j] = _preds[:, j] for source, target, pred in zip(sources, targets, preds): got = " ".join(idx2de[idx] for idx in pred).split("<EOS>")[0].strip() f.write("- Source: {}\n".format(source)) f.write("- Ground Truth: {}\n".format(target)) f.write("- Predict: {}\n\n".format(got)) f.flush() # Bleu Score ref = target.split() prediction = got.split() if len(ref) > pm.word_limit_lower and len(prediction) > pm.word_limit_lower: list_of_refs.append([ref]) predict.append(prediction) score = corpus_bleu(list_of_refs, predict) f.write("Bleu Score = " + str(100 * score)) if __name__ == '__main__': eval() print("MSG : Done!")