import numpy as np
import os
import pickle
import model
import random
import ggtnn_graph_parse
import convert_story
import gzip
from enum import Enum
from ggtnn_graph_parse import MetadataList, PreppedStory
from graceful_interrupt import GracefulInterruptHandler
from pprint import pformat
import util
from train_exit_status import TrainExitStatus
from functools import reduce

# Default batch size used by train().
BATCH_SIZE = 10


def convert_answer(answer, num_words, format_spec, maxlen):
    """
    Convert an answer into an appropriate answer matrix given a ModelOutputFormat.

    num_words should be after processing with get_effective_answer_words,
    so that the last word is the "stop" word.

    Returns a float32 array:
      - subset:   shape (1, num_words), 1.0 at every index listed in `answer`
      - category: shape (1, num_words), 1.0 at index answer[0]
      - sequence: shape (maxlen+1, num_words), one one-hot row per answer
                  word, remaining rows filled with the stop index num_words-1

    Raises ValueError for an unsupported format_spec (the assert below only
    fires when assertions are enabled).
    """
    assert format_spec in model.ModelOutputFormat
    if format_spec == model.ModelOutputFormat.subset:
        ans_mat = np.zeros((1, num_words), np.float32)
        for word in answer:
            ans_mat[0, word] = 1.0
    elif format_spec == model.ModelOutputFormat.category:
        ans_mat = np.zeros((1, num_words), np.float32)
        ans_mat[0, answer[0]] = 1.0
    elif format_spec == model.ModelOutputFormat.sequence:
        ans_mat = np.zeros((maxlen + 1, num_words), np.float32)
        # Pad the answer out to maxlen+1 entries with the stop word index.
        padded = answer + [num_words - 1] * (maxlen + 1 - len(answer))
        for i, word in enumerate(padded):
            ans_mat[i, word] = 1.0
    else:
        # Without this branch, running under `python -O` (asserts stripped)
        # would fail later with an UnboundLocalError on ans_mat.
        raise ValueError("Unsupported output format: {}".format(format_spec))
    return ans_mat


def get_effective_answer_words(answer_words, format_spec):
    """
    If needed, modify answer_words using format spec to add padding chars.

    For sequence output a "<stop>" word is appended (so it becomes the last
    index, as convert_answer expects); other formats are returned unchanged.
    """
    if format_spec == model.ModelOutputFormat.sequence:
        return answer_words + ["<stop>"]
    else:
        return answer_words


def sample_batch(matching_stories, batch_size, num_answer_words, format_spec):
    """
    Draw batch_size story files uniformly at random (with replacement) from
    matching_stories and assemble them into a batch via assemble_batch.
    """
    chosen_stories = [random.choice(matching_stories) for _ in range(batch_size)]
    return assemble_batch(chosen_stories, num_answer_words, format_spec)


def assemble_batch(story_fns, num_answer_words, format_spec):
    """
    Load gzipped pickled stories from story_fns and stack them into a batch.

    Each file is expected to unpickle to a 4-tuple whose first element is the
    converted story (sents, graphs, queries, answers); graphs in turn unpacks
    into (num_new_nodes, new_node_strengths, new_node_ids, next_edges).

    Returns (sents, queries, answers, num_new_nodes, new_node_strengths,
    new_node_ids, next_edges) as numpy arrays; sents and queries are int32,
    answers are built with convert_answer, padded to the longest answer.

    Raises ValueError if story_fns is empty.
    """
    if len(story_fns) == 0:
        raise ValueError("assemble_batch requires at least one story file")
    stories = []
    for sfn in story_fns:
        # NOTE: pickle.load can execute arbitrary code from a crafted file;
        # only point this at trusted story files.
        with gzip.open(sfn, 'rb') as f:
            cvtd_story, _, _, _ = pickle.load(f)
        stories.append(cvtd_story)
    sents, graphs, queries, answers = zip(*stories)
    cvtd_sents = np.array(sents, np.int32)
    cvtd_queries = np.array(queries, np.int32)
    max_ans_len = max(len(a) for a in answers)
    cvtd_answers = np.stack([convert_answer(answer, num_answer_words, format_spec, max_ans_len) for answer in answers])
    num_new_nodes, new_node_strengths, new_node_ids, next_edges = zip(*graphs)
    num_new_nodes = np.stack(num_new_nodes)
    new_node_strengths = np.stack(new_node_strengths)
    new_node_ids = np.stack(new_node_ids)
    next_edges = np.stack(next_edges)
    return cvtd_sents, cvtd_queries, cvtd_answers, num_new_nodes, new_node_strengths, new_node_ids, next_edges


def assemble_correct_graphs(story_fns):
    """
    Load the stories in story_fns and build their reference graphs with
    convert_story.convert.

    Returns (strengths, ids, edges), each concatenated along axis 0 across
    all stories. The third value returned by convert_story.convert is unused.
    """
    correct_strengths, correct_ids, correct_edges = [], [], []
    for sfn in story_fns:
        # NOTE: pickle.load is only safe on trusted files.
        with gzip.open(sfn, 'rb') as f:
            cvtd_story, _, _, _ = pickle.load(f)
        strengths, ids, _, edges = convert_story.convert(cvtd_story)
        correct_strengths.append(strengths)
        correct_ids.append(ids)
        correct_edges.append(edges)
    return tuple(np.concatenate(l, 0) for l in (correct_strengths, correct_ids, correct_edges))


def visualize(m, story_buckets, wordlist, answerlist, output_format, outputdir, batch_size=1, seq_len=5, debugmode=False, snap=False):
    """
    Sample one batch from a random bucket, run the model on it, and dump
    everything to outputdir.

    Writes:
      - stories.txt:      the sampled sentences/queries/answers, rendered by
                          ggtnn_graph_parse.print_batch
      - answer_list.txt:  answerlist, one entry per line
      - result_{i}.npy:   each output of the model test function

    With debugmode, m.debug_test_fn is called on the full batch. Otherwise
    m.snap_test_fn (if snap) or m.fuzzy_test_fn is called on the sentences
    and queries, plus seq_len for sequence output.
    """
    cur_bucket = random.choice(story_buckets)
    sampled_batch = sample_batch(cur_bucket, batch_size, len(answerlist), output_format)
    # First three entries of a batch: sentences, queries, answers.
    part_sampled_batch = sampled_batch[:3]
    with open(os.path.join(outputdir, 'stories.txt'), 'w') as f:
        ggtnn_graph_parse.print_batch(part_sampled_batch, wordlist, answerlist, file=f)
    with open(os.path.join(outputdir, 'answer_list.txt'), 'w') as f:
        f.write('\n'.join(answerlist) + '\n')
    if debugmode:
        args = sampled_batch
        fn = m.debug_test_fn
    else:
        args = part_sampled_batch[:2] + ((seq_len,) if output_format == model.ModelOutputFormat.sequence else ())
        fn = m.snap_test_fn if snap else m.fuzzy_test_fn
    results = fn(*args)
    for i, result in enumerate(results):
        np.save(os.path.join(outputdir, 'result_{}.npy'.format(i)), result)


def test_accuracy(m, story_buckets, bucket_sizes, num_answer_words, format_spec, batch_size, batch_auto_adjust=None, test_graph=False):
    """
    Return the fraction of stories across story_buckets judged correct.

    Every story in every bucket is evaluated exactly once, in batches whose
    size comes from adj_size. With test_graph, per-story correctness comes
    from m.eval(..., with_accuracy=True). Otherwise m.snap_test_fn is run
    and a story counts as correct only if every entry of its answer matrix
    is np.isclose to the target.

    Returns 0.0 when the buckets contain no stories at all.
    """
    correct = 0
    out_of = 0
    for bucket, bucket_size in zip(story_buckets, bucket_sizes):
        cur_batch_size = adj_size(m, bucket_size, batch_size, batch_auto_adjust)
        for start_idx in range(0, len(bucket), cur_batch_size):
            stories = bucket[start_idx:start_idx + cur_batch_size]
            batch = assemble_batch(stories, num_answer_words, format_spec)
            if test_graph:
                _, batch_close, _ = m.eval(*batch, with_accuracy=True)
            else:
                answers = batch[2]
                # Sequence models also need the number of output steps.
                args = batch[:2] + ((answers.shape[1],) if format_spec == model.ModelOutputFormat.sequence else ())
                out_answers, out_strengths, out_ids, out_states, out_edges = m.snap_test_fn(*args)
                close = np.isclose(out_answers, answers)
                batch_close = np.all(close, (1, 2))
            print(batch_close)
            correct += np.sum(batch_close).tolist()
            out_of += len(stories)
    if out_of == 0:
        # Nothing to evaluate; avoid ZeroDivisionError.
        return 0.0
    return correct / out_of


def adj_size(m, cur_bucket_size, batch_size, batch_auto_adjust):
    """
    Choose the batch size to use for a bucket.

    If batch_auto_adjust is None, batch_size is returned unchanged. Otherwise
    batch_auto_adjust is treated as a budget and the batch is capped at
    batch_auto_adjust // edge_size, where
        edge_size = cur_bucket_size**3 * m.new_nodes_per_iter**2 * m.num_edge_types
    (multiplied by 4 when m.sequence_representation is set). The
    auto-adjusted size is never below 1, so callers always get a non-empty
    batch even when a single story exceeds the budget.
    """
    if batch_auto_adjust is None:
        return batch_size
    edge_size = (cur_bucket_size**3) * (m.new_nodes_per_iter**2) * m.num_edge_types
    if m.sequence_representation:
        # In sequence representation mode, we are doing stuff with all objects at the same time
        # so add a multiple of the edge size to get a nice bound
        edge_size = edge_size * 4
    max_batch_size = batch_auto_adjust // edge_size
    # Floor division gives 0 for very large buckets; a batch size of 0 makes
    # range() raise in test_accuracy and yields an empty batch in train, so
    # fall back to a single story.
    return min(batch_size, max(1, max_batch_size))


def train(m, story_buckets, bucket_sizes, len_answers, output_format, num_updates, outputdir, start=0, batch_size=BATCH_SIZE, validation_buckets=None, validation_bucket_sizes=None, stop_at_accuracy=None, stop_at_loss=None, stop_at_overfitting=None, save_params=1000, validation_interval=1000, batch_auto_adjust=None, interrupt_file=None):
    """
    Run training updates start+1 .. num_updates, logging into outputdir.

    Each update samples a random bucket, sizes the batch with adj_size,
    samples a batch, and calls m.train. Loss and info are appended to
    data.csv (truncated and given a header on update 1).

    Every validation_interval updates, if validation buckets are given:
      - one validation batch is evaluated with m.eval and logged to valid.csv;
      - full validation accuracy is computed with test_accuracy and logged to
        valid_acc.csv;
      - training stops early if stop_at_accuracy, stop_at_loss, or
        stop_at_overfitting (valid_loss / loss ratio) is reached.

    Training also stops when interrupted (GracefulInterruptHandler) or when
    interrupt_file exists. Parameters are saved to params{i}.p every
    save_params updates, on the last update, and whenever training stops
    early. A NaN loss returns immediately without saving.

    Returns a TrainExitStatus.
    """
    with GracefulInterruptHandler() as interrupt_h:
        for i in range(start + 1, num_updates + 1):
            exit_with = None
            cur_bucket, cur_bucket_size = random.choice(list(zip(story_buckets, bucket_sizes)))
            cur_batch_size = adj_size(m, cur_bucket_size, batch_size, batch_auto_adjust)
            sampled_batch = sample_batch(cur_bucket, cur_batch_size, len_answers, output_format)
            loss, info = m.train(*sampled_batch)
            if np.any(np.isnan(loss)):
                print("Loss at timestep {} was nan! Aborting".format(i))
                return TrainExitStatus.nan_loss  # Don't bother saving
            with open(os.path.join(outputdir, 'data.csv'), 'a') as f:
                if i == 1:
                    # Fresh run: discard any previous log and write the header.
                    f.seek(0)
                    f.truncate()
                    keylist = "iter, loss, " + ", ".join(k for k, v in sorted(info.items())) + "\n"
                    f.write(keylist)
                    if validation_buckets is not None:
                        with open(os.path.join(outputdir, 'valid.csv'), 'w') as f2:
                            f2.write(keylist)
                f.write("{}, {},".format(i, loss) + ", ".join(str(v) for k, v in sorted(info.items())) + "\n")
            print("update {}: {}\n{}".format(i, loss, pformat(info)))
            if i % validation_interval == 0:
                if validation_buckets is not None:
                    cur_bucket, cur_bucket_size = random.choice(list(zip(validation_buckets, validation_bucket_sizes)))
                    cur_batch_size = adj_size(m, cur_bucket_size, batch_size, batch_auto_adjust)
                    sampled_batch = sample_batch(cur_bucket, cur_batch_size, len_answers, output_format)
                    valid_loss, valid_info = m.eval(*sampled_batch)
                    print("validation at {}: {}\n{}".format(i, valid_loss, pformat(valid_info)))
                    with open(os.path.join(outputdir, 'valid.csv'), 'a') as f:
                        f.write("{}, {}, ".format(i, valid_loss) + ", ".join(str(v) for k, v in sorted(valid_info.items())) + "\n")
                    valid_accuracy = test_accuracy(m, validation_buckets, validation_bucket_sizes, len_answers, output_format, batch_size, batch_auto_adjust, (not m.train_with_query))
                    print("Best-choice accuracy at {}: {}".format(i, valid_accuracy))
                    with open(os.path.join(outputdir, 'valid_acc.csv'), 'a') as f:
                        f.write("{}, {}\n".format(i, valid_accuracy))
                    if stop_at_accuracy is not None and valid_accuracy >= stop_at_accuracy:
                        print("Accuracy reached threshold! Stopping training")
                        exit_with = TrainExitStatus.success
                    if stop_at_loss is not None and valid_loss <= stop_at_loss:
                        print("Loss reached threshold! Stopping training")
                        exit_with = TrainExitStatus.success
                    if stop_at_overfitting is not None and valid_loss / loss > stop_at_overfitting:
                        print("Model appears to be overfitting! Stopping training")
                        exit_with = TrainExitStatus.overfitting
            if exit_with is None and (interrupt_h.interrupted or (interrupt_file is not None and os.path.isfile(interrupt_file))):
                exit_with = TrainExitStatus.interrupted
            if (save_params is not None and i % save_params == 0) or (exit_with is not None) or (i == num_updates):
                # Close the params file deterministically instead of leaking
                # the handle to the garbage collector.
                with open(os.path.join(outputdir, 'params{}.p'.format(i)), 'wb') as params_file:
                    util.save_params(m.params, params_file)
            if exit_with is not None:
                return exit_with
    return TrainExitStatus.reached_update_limit