#!/usr/bin/env python from __future__ import print_function from past.builtins import range import os import sys import numpy as np import math import random from tqdm import tqdm sys.path.append('%s/../prog_common' % os.path.dirname(os.path.realpath(__file__))) from prog_util import DECISION_DIM from prog_tree import AnnotatedTree2ProgTree from cmd_args import cmd_args sys.path.append('%s/../prog_decoder' % os.path.dirname(os.path.realpath(__file__))) from prog_tree_decoder import ProgTreeDecoder, batch_make_att_masks from tree_walker import ProgramOnehotBuilder sys.path.append('%s/../cfg_parser' % os.path.dirname(os.path.realpath(__file__))) import cfg_parser as parser from joblib import Parallel, delayed import h5py import argparse cmd_opt = argparse.ArgumentParser(description='Argparser for data dump') cmd_opt.add_argument('-min_len', type=int, help='min # of statements') cmd_opt.add_argument('-max_len', type=int, help='max # of statements') cmd_opt.add_argument('-phase', type=str, help='train / test') args, _ = cmd_opt.parse_known_args() def process_chunk(program_list): grammar = parser.Grammar(cmd_args.grammar_file) cfg_tree_list = [] for program in program_list: ts = parser.parse(program, grammar) assert isinstance(ts, list) and len(ts) == 1 n = AnnotatedTree2ProgTree(ts[0]) cfg_tree_list.append(n) walker = ProgramOnehotBuilder() tree_decoder = ProgTreeDecoder() onehot, masks = batch_make_att_masks(cfg_tree_list, tree_decoder, walker, dtype=np.byte) return (onehot, masks) def run_job(L): chunk_size = 5000 list_binary = Parallel(n_jobs=cmd_args.data_gen_threads, verbose=50)( delayed(process_chunk)(L[start: start + chunk_size]) for start in range(0, len(L), chunk_size) ) all_onehot = np.zeros((len(L), cmd_args.max_decode_steps, DECISION_DIM), dtype=np.byte) all_masks = np.zeros((len(L), cmd_args.max_decode_steps, DECISION_DIM), dtype=np.byte) for start, b_pair in zip( range(0, len(L), chunk_size), list_binary ): all_onehot[start: start + chunk_size, :, :] = b_pair[0] all_masks[start: start + chunk_size, :, :] = b_pair[1] return all_onehot, all_masks if __name__ == '__main__': onehot_list = [] mask_list = [] for l in range(args.min_len, args.max_len + 1): if args.phase == 'train': fname = '%s/free_var_id-check_data-number-50000-nbstat-%d.txt' % (cmd_args.save_dir, l) else: fname = '%s/free_var_id-check_data-number-50000-nbstat-%d.test.txt' % (cmd_args.save_dir, l) program_list = [] with open(fname, 'r') as f: for row in tqdm(f): program = row.strip() program_list.append(program) onehot, mask = run_job(program_list) onehot_list.append(onehot) mask_list.append(mask) all_onehot = np.vstack(onehot_list) all_mask = np.vstack(mask_list) # shuffle training set idxes = range(all_onehot.shape[0]) random.shuffle(idxes) idxes = np.array(idxes, dtype=np.int32) all_onehot = all_onehot[idxes, :, :] all_mask = all_mask[idxes, :, :] print('num samples: ', len(idxes)) out_file = '%s/nbstate-%d-to-%d-%s.h5' % (cmd_args.save_dir, args.min_len, args.max_len, args.phase) h5f = h5py.File(out_file, 'w') h5f.create_dataset('x_%s' % args.phase, data=all_onehot) h5f.create_dataset('masks_%s' % args.phase, data=all_mask) h5f.close()