from __future__ import division from builtins import bytes import os import argparse import math import codecs import torch from table.Translator import Translator import table.IO import opts from itertools import takewhile, count try: from itertools import zip_longest except ImportError: from itertools import izip_longest as zip_longest from path import Path import glob import json from tqdm import tqdm from lib.dbengine import DBEngine from lib.query import Query import pdb parser = argparse.ArgumentParser(description='evaluate.py') opts.translate_opts(parser) opt = parser.parse_args() torch.cuda.set_device(opt.gpu) opt.anno = os.path.join( opt.anno_data_path, '{}.jsonl'.format(opt.split)) opt.source_file = os.path.join( opt.data_path, '{}.jsonl'.format(opt.split)) opt.db_file = os.path.join(opt.data_path, '{}.db'.format(opt.split)) opt.pre_word_vecs = os.path.join(opt.data_path, 'embedding') print('Evaluating model on the {} set.'.format(opt.split)) def main(): dummy_parser = argparse.ArgumentParser(description='train.py') opts.model_opts(dummy_parser) opts.train_opts(dummy_parser) dummy_opt = dummy_parser.parse_known_args([])[0] engine = DBEngine(opt.db_file) with codecs.open(opt.source_file, "r", "utf-8") as corpus_file: sql_list = [json.loads(line)['sql'] for line in corpus_file] js_list = table.IO.read_anno_json(opt.anno) prev_best = (None, None) for fn_model in glob.glob(opt.model_path): opt.model = fn_model translator = Translator(opt, dummy_opt.__dict__) data = table.IO.TableDataset(js_list, translator.fields, None, False) test_data = table.IO.OrderedIterator( dataset=data, device=opt.gpu, batch_size=opt.batch_size, train=False, sort=True, sort_within_batch=False) # inference if opt.beam_search: print('Using execution guidance for inference.') r_list = [] for batch in test_data: r_list += translator.translate(batch, js_list, sql_list) r_list.sort(key=lambda x: x.idx) assert len(r_list) == len(js_list), 'len(r_list) != len(js_list): {} != {}'.format( len(r_list), len(js_list)) # evaluation for pred, gold, sql_gold in zip(r_list, js_list, sql_list): pred.eval(gold, sql_gold, engine) print('Results:') for metric_name in ('all', 'exe'): c_correct = sum((x.correct[metric_name] for x in r_list)) print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct, len(r_list), c_correct / len(r_list))) if metric_name == 'all' and (prev_best[0] is None or c_correct > prev_best[1]): prev_best = (fn_model, c_correct) if (opt.split == 'dev') and (prev_best[0] is not None): with codecs.open(os.path.join(opt.data_path, 'dev_best.txt'), 'w', encoding='utf-8') as f_out: f_out.write('{}\n'.format(prev_best[0])) if __name__ == "__main__": main()