import os import pickle import argparse import shutil import math import sys import model import ggtnn_train import ggtnn_graph_parse from ggtnn_graph_parse import MetadataList, PreppedStory from util import * def helper_trim(bucketed, desired_total): """Trim bucketed fairly so that it has desired_total things total""" cur_total = sum(len(b) for b in bucketed) keep_frac = desired_total/cur_total if keep_frac > 1.0: print("WARNING: Asked to trim to {} items, but was already only {} items. Keeping original length.".format(desired_total, cur_total)) return bucketed keep_amts = [math.floor(len(b) * keep_frac) for b in bucketed] tmp_total = sum(keep_amts) addtl_to_add = desired_total - tmp_total assert addtl_to_add >= 0 keep_amts = [x + (1 if i < addtl_to_add else 0) for i,x in enumerate(keep_amts)] assert sum(keep_amts) == desired_total trimmed_bucketed = [b[:amt] for b,amt in zip(bucketed, keep_amts)] return trimmed_bucketed def main(task_dir, output_format_str, state_width, process_repr_size, dynamic_nodes, mutable_nodes, wipe_node_state, direct_reference, propagate_intermediate, sequence_aggregate_repr, old_aggregate, train_with_graph, train_with_query, outputdir, num_updates, batch_size, learning_rate, dropout_keep, resume, resume_auto, visualize, visualize_snap, visualization_test, validation, validation_interval, evaluate_accuracy, check_mode, stop_at_accuracy, stop_at_loss, stop_at_overfitting, restrict_dataset, train_save_params, batch_adjust, set_exit_status, just_compile, autopickle, pickle_model, unpickle_model, interrupt_file): output_format = model.ModelOutputFormat[output_format_str] with open(os.path.join(task_dir,'metadata.p'),'rb') as f: metadata = pickle.load(f) with open(os.path.join(task_dir,'file_list.p'),'rb') as f: bucketed = pickle.load(f) bucketed = [[os.path.join(task_dir,x) for x in b] for b in bucketed] if restrict_dataset is not None: bucketed = helper_trim(bucketed, restrict_dataset) sentence_length, new_nodes_per_iter, bucket_sizes, wordlist, anslist, graph_node_list, graph_edge_list = metadata eff_anslist = ggtnn_train.get_effective_answer_words(anslist, output_format) if validation is None: validation_buckets = None validation_bucket_sizes = None else: with open(os.path.join(validation,'metadata.p'),'rb') as f: validation_metadata = pickle.load(f) with open(os.path.join(validation,'file_list.p'),'rb') as f: validation_buckets = pickle.load(f) validation_buckets = [[os.path.join(validation,x) for x in b] for b in validation_buckets] validation_bucket_sizes = validation_metadata[2] if direct_reference: word_node_mapping = {wi:ni for wi,word in enumerate(wordlist) for ni,node in enumerate(graph_node_list) if word == node} else: word_node_mapping = {} model_kwargs = dict(num_input_words=len(wordlist), num_output_words=len(eff_anslist), num_node_ids=len(graph_node_list), node_state_size=state_width, num_edge_types=len(graph_edge_list), input_repr_size=100, output_repr_size=100, propose_repr_size=process_repr_size, propagate_repr_size=process_repr_size, new_nodes_per_iter=new_nodes_per_iter, output_format=output_format, final_propagate=5, word_node_mapping=word_node_mapping, dynamic_nodes=dynamic_nodes, nodes_mutable=mutable_nodes, wipe_node_state=wipe_node_state, intermediate_propagate=(5 if propagate_intermediate else 0), sequence_representation=sequence_aggregate_repr, dropout_keep=dropout_keep, use_old_aggregate=old_aggregate, best_node_match_only=True, train_with_graph=train_with_graph, train_with_query=train_with_query, setup=True, check_mode=check_mode) model_kwargs = get_compatible_kwargs(model.Model, model_kwargs) if autopickle is not None: if not os.path.exists(autopickle): os.makedirs(autopickle) model_hash = object_hash(model_kwargs) model_filename = os.path.join(autopickle, "model_{}.p".format(model_hash)) print("Looking for cached model at {}".format(model_filename)) if os.path.isfile(model_filename): print("Loading model from cache") m, stored_kwargs = pickle.load(open(model_filename, 'rb')) assert model_kwargs == stored_kwargs, "Hash collision between models!\nCurrent: {}\nStored: {}".format(model_kwargs,stored_kwargs) else: print("Building model from scratch") m = model.Model(**model_kwargs) print("Saving model to cache") sys.setrecursionlimit(100000) pickle.dump((m,model_kwargs), open(model_filename,'wb'), protocol=pickle.HIGHEST_PROTOCOL) elif unpickle_model is not None: print("Unpickling model...") m = pickle.load(open(unpickle_model, 'rb')) else: m = model.Model(**model_kwargs) if pickle_model is not None: sys.setrecursionlimit(100000) print("Pickling model...") pickle.dump(m, open(pickle_model,'wb'), protocol=pickle.HIGHEST_PROTOCOL) if just_compile: return if learning_rate is not None: m.set_learning_rate(learning_rate) if not os.path.exists(outputdir): os.makedirs(outputdir) if resume_auto: result = find_recent_params(outputdir) if result is not None: start_idx, paramfile = result print("Automatically resuming from {} after iteration {}.".format(paramfile, start_idx)) resume = result else: print("Didn't find anything to resume. Starting from the beginning...") if resume is not None: start_idx, paramfile = resume start_idx = int(start_idx) load_params(m.params, open(paramfile, "rb") ) else: start_idx = 0 if visualize is not False: if visualize is True: source = bucketed else: bucket, story = visualize source = [[bucketed[bucket][story]]] print("Starting to visualize...") ggtnn_train.visualize(m, source, wordlist, eff_anslist, output_format, outputdir, snap=visualize_snap) print("Wrote visualization files to {}.".format(outputdir)) elif evaluate_accuracy: print("Evaluating accuracy...") acc = ggtnn_train.test_accuracy(m, bucketed, bucket_sizes, len(eff_anslist), output_format, batch_size, batch_adjust, (not train_with_query)) print("Obtained accuracy of {}".format(acc)) elif visualization_test: print("Starting visualization test...") ggtnn_train.visualize(m, bucketed, wordlist, eff_anslist, output_format, outputdir, debugmode=True) print("Wrote visualization files to {}.".format(outputdir)) else: print("Starting to train...") status = ggtnn_train.train(m, bucketed, bucket_sizes, len(eff_anslist), output_format, num_updates, outputdir, start_idx, batch_size, validation_buckets, validation_bucket_sizes, stop_at_accuracy, stop_at_loss, stop_at_overfitting, train_save_params, validation_interval, batch_adjust, interrupt_file) if set_exit_status: sys.exit(status.value) parser = argparse.ArgumentParser(description='Train a graph memory network model.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('task_dir', help="Parsed directory for the task to load") parser.add_argument('output_format_str', choices=[x.name for x in model.ModelOutputFormat], help="Output format for the task") parser.add_argument('state_width', type=int, help="Width of node state") parser.add_argument('--process-repr-size', type=int, default=50, help="Width of intermediate representations") parser.add_argument('--mutable-nodes', action="store_true", help="Make nodes mutable") parser.add_argument('--wipe-node-state', action="store_true", help="Wipe node state before the query") parser.add_argument('--direct-reference', action="store_true", help="Use direct reference for input, based on node names") parser.add_argument('--dynamic-nodes', action="store_true", help="Create nodes after each sentence. (Otherwise, create unique nodes at the beginning)") parser.add_argument('--propagate-intermediate', action="store_true", help="Run a propagation step after each sentence") parser.add_argument('--sequence-aggregate-repr', action="store_true", help="Compute the query aggregate representation from the sequence of graphs instead of just the last one") parser.add_argument('--old-aggregate', action="store_true", help="Use the old, incorrect aggregate function") parser.add_argument('--no-graph', dest='train_with_graph', action="store_false", help="Don't train using graph supervision") parser.add_argument('--no-query', dest='train_with_query', action="store_false", help="Don't train using query supervision") parser.add_argument('--outputdir', default="output", help="Directory to save output in") parser.add_argument('--num-updates', default="10000", type=int, help="How many iterations to train") parser.add_argument('--batch-size', default="10", type=int, help="Batch size to use") parser.add_argument('--learning-rate', type=float, default=None, help="Use this learning rate") parser.add_argument('--dropout-keep', default=1, type=float, help="Use dropout, with this keep chance") parser.add_argument('--restrict-dataset', metavar="NUM_STORIES", type=int, default=None, help="Restrict size of dataset to this") parser.add_argument('--save-params-interval', type=int, default=1000, dest="train_save_params", help="Save parameters after this many iterations") parser.add_argument('--final-params-only', action="store_const", const=None, dest="train_save_params", help="Don't save parameters while training, only at the end.") parser.add_argument('--validation', metavar="VALIDATION_DIR", default=None, help="Parsed directory of validation tasks") parser.add_argument('--validation-interval', type=int, default=1000, help="Check validation after this many iterations") parser.add_argument('--check-nan', dest="check_mode", action="store_const", const="nan", help="Check for NaN. Slows execution") parser.add_argument('--check-debug', dest="check_mode", action="store_const", const="debug", help="Debug mode. Slows execution") parser.add_argument('--visualize', nargs="?", const=True, default=False, metavar="BUCKET,STORY", type=lambda s:[int(x) for x in s.split(',')], help="Visualise current state instead of training. Optional parameter selects a particular story to visualize, and should be of the form bucketnum,index") parser.add_argument('--visualize-snap', action="store_true", help="In visualization mode, snap to best option at each timestep") parser.add_argument('--visualization-test', action="store_true", help="Like visualize, but use the correct graph instead of the model's graph") parser.add_argument('--evaluate-accuracy', action="store_true", help="Evaluate accuracy of model") parser.add_argument('--stop-at-accuracy', type=float, default=None, help="Stop training once it reaches this accuracy on validation set") parser.add_argument('--stop-at-loss', type=float, default=None, help="Stop training once it reaches this loss on validation set") parser.add_argument('--stop-at-overfitting', type=float, default=None, help="Stop training once validation loss is this many times higher than train loss") parser.add_argument('--batch-adjust', type=int, default=None, help="If set, ensure that size of edge matrix does not exceed this") parser.add_argument('--set-exit-status', action="store_true", help="Give info about training status in the exit status") parser.add_argument('--just-compile', action="store_true", help="Don't run the model, just compile it") parser.add_argument('--autopickle', metavar="PICKLEDIR", default=None, help="Automatically cache model in this directory") parser.add_argument('--pickle-model', metavar="MODELFILE", default=None, help="Save the compiled model to a file") parser.add_argument('--unpickle-model', metavar="MODELFILE", default=None, help="Load the model from a file instead of compiling it from scratch") parser.add_argument('--interrupt-file', default=None, help="Interrupt training if this file appears") resume_group = parser.add_mutually_exclusive_group() resume_group.add_argument('--resume', nargs=2, metavar=('TIMESTEP', 'PARAMFILE'), default=None, help='Where to restore from: timestep, and file to load') resume_group.add_argument('--resume-auto', action='store_true', help='Automatically restore from a previous run using output directory') if __name__ == '__main__': np.set_printoptions(linewidth=shutil.get_terminal_size((80, 20)).columns) args = vars(parser.parse_args()) main(**args)