#!/usr/bin/env python
"""Grammar-driven decoding of program trees and construction of decision masks.

``ProgTreeDecoder`` expands a tree of ``Node`` objects symbol by symbol,
delegating every choice (production rule, variable id) to a "walker" object.
``batch_make_att_masks`` runs that decoder over a list of trees and collects
the walker's recorded decisions into dense arrays.
"""

from __future__ import print_function

import os
import sys
import csv
import numpy as np
import math
import random
from collections import defaultdict

sys.path.append('%s/../prog_common' % os.path.dirname(os.path.realpath(__file__)))
from cmd_args import cmd_args
from prog_util import prod, MAX_VARS, rule_ranges, MAX_NUM_STATEMENTS, DECISION_DIM
from prog_tree import Node
from tree_walker import ProgramOnehotBuilder


class ProgTreeDecoder(object):
    """Expand a program tree top-down, asking a walker to make each decision."""

    # Symbols that need no special handling: a rule is chosen for them and the
    # inherited attributes are passed through unchanged to their children.
    _GENERIC_SYMBOLS = (
        'program', 'lhs', 'rhs', 'expr', 'unary_expr', 'binary_expr',
        'unary_op', 'unary_func', 'binary_op', 'operand',
        'immediate_number', 'digit',
    )

    def __init__(self):
        self.full_var_set = set(range(MAX_VARS))
        self.reset_state()

    def reset_state(self):
        """Reset the per-tree bookkeeping; variable 0 starts out as defined."""
        self.num_statements = 0
        self.return_used = False
        self.defined_vars = {0}

    def get_node(self, node, new_sym, pos):
        """Return the child of ``node`` for symbol ``new_sym`` at ``pos``.

        When ``node.is_created()`` is true the existing child at ``pos`` is
        returned (after asserting that its symbol matches); otherwise a fresh
        ``Node(new_sym, node)`` is built. The new node is NOT attached here;
        callers attach it with ``node.add_child``.
        """
        if node.is_created():
            assert pos < len(node.children)
            ans = node.children[pos]
            assert ans.symbol == new_sym
            return ans
        return Node(new_sym, node)

    def rand_rule(self, node, sub_ranges=None):
        """Choose the production rule for ``node`` and record it on the node.

        Args:
            node: node whose ``symbol`` selects the range in ``rule_ranges``.
            sub_ranges: optional list of rule offsets (relative to the start
                of the symbol's range) that are currently allowed.

        Returns:
            The rule offset stored in ``node.rule_used``. If the node already
            carried a rule, the newly chosen one is asserted to be equal.
        """
        g_range = rule_ranges[node.symbol]
        idxes = np.arange(g_range[0], g_range[1])
        if sub_ranges is not None:
            idxes = idxes[sub_ranges]

        assert len(idxes)
        if len(idxes) == 1:
            # Only one option: no decision is requested from the walker.
            result = 0
        else:
            # NOTE(review): the result is used as a position within ``idxes``
            # (see the ``sub_ranges[result]`` lookup below) - confirm against
            # the walker implementation.
            result = self.walker.sample_index_with_mask(node, idxes)

        # Map the position within the restricted candidates back to an offset
        # within the symbol's full rule range.
        new_idx = sub_ranges[result] if sub_ranges is not None else result

        if node.rule_used is not None:
            assert node.rule_used == new_idx
        else:
            node.rule_used = new_idx

        return node.rule_used

    def rand_att(self, node, candidates):
        """Choose a variable id from ``candidates`` and record it on ``node``.

        With a single candidate the walker is not consulted. If the node
        already has a ``var_id`` attribute it must equal the chosen id.
        """
        if len(candidates) == 1:
            att_idx = candidates[0]
        else:
            att_idx = self.walker.sample_att(node, candidates)

        if not hasattr(node, 'var_id'):
            node.var_id = att_idx
        else:
            assert node.var_id == att_idx
        return att_idx

    def get_inh_attr(self, attr_dict, key):
        """Return ``attr_dict[key]``, asserting the inherited attribute exists."""
        assert attr_dict is not None
        assert key in attr_dict
        return attr_dict[key]

    def tree_generator(self, node, inherit_atts=None):
        """Recursively expand ``node`` according to its grammar symbol.

        Args:
            node: the node to expand; children are attached via
                ``node.add_child``.
            inherit_atts: optional dict of attributes inherited from the
                parent. The keys read here are ``'is_last'`` (for ``stat``
                and, optionally, ``stat_list``), ``'is_reuse'`` (for ``var``)
                and ``'id'`` (for ``var_id``).
        """
        symbol = node.symbol

        if symbol == 'stat_list':
            # Rule 1 (recursive list) is only offered while the statement
            # budget has not been exhausted.
            candidates = [0]
            self.num_statements += 1
            if self.num_statements < MAX_NUM_STATEMENTS:
                candidates.append(1)
            rule = self.rand_rule(node, candidates)

            if rule == 1:  # stat_list -> stat_list stat_sep stat
                s_list = self.get_node(node, 'stat_list', 0)
                node.add_child(s_list)
                self.tree_generator(s_list, inherit_atts={'is_last': False})

                sep = self.get_node(node, '\';\'', 1)
                node.add_child(sep)

            # The statement is always the last child, whichever rule was used.
            s = self.get_node(node, 'stat', -1)
            node.add_child(s)
            if inherit_atts is not None:
                is_last = self.get_inh_attr(inherit_atts, 'is_last')
            else:
                # No inherited attributes: this list's own statement is
                # treated as the final one.
                is_last = True
            self.tree_generator(s, inherit_atts={'is_last': is_last})

        elif symbol == 'stat':
            # The rule is fully determined by the inherited flag, and only the
            # first symbol of that production is expanded.
            rule = int(self.get_inh_attr(inherit_atts, 'is_last'))
            p = prod[symbol][rule]
            s = self.get_node(node, p[0], 0)
            node.add_child(s)
            self.tree_generator(s)

        elif symbol == 'assign_stat':
            # The right-hand side is generated first (reusing defined
            # variables) and the left-hand side afterwards (introducing a new
            # variable); children are still attached in source order.
            rhs = self.get_node(node, 'rhs', 2)
            self.tree_generator(rhs, {'is_reuse': True})

            lhs = self.get_node(node, 'lhs', 0)
            self.tree_generator(lhs, {'is_reuse': False})
            node.add_child(lhs)

            e = self.get_node(node, '\'=\'', 1)
            node.add_child(e)
            node.add_child(rhs)

        elif symbol == 'return_stat':
            self.return_used = True
            lhs = self.get_node(node, 'lhs', 2)
            self.tree_generator(lhs, {'is_reuse': True})

            r = self.get_node(node, '\'return\'', 0)
            node.add_child(r)
            q = self.get_node(node, '\':\'', 1)
            node.add_child(q)
            node.add_child(lhs)

        elif symbol == 'var':
            is_reuse = self.get_inh_attr(inherit_atts, 'is_reuse')
            if is_reuse:
                candidates = list(self.defined_vars)
            else:
                candidates = list(self.full_var_set - self.defined_vars)
            assert len(candidates)
            var_id = self.rand_att(node, candidates)

            c = self.get_node(node, '\'v\'', 0)
            node.add_child(c)
            i = self.get_node(node, 'var_id', 1)
            node.add_child(i)
            self.tree_generator(i, {'id': var_id})

            if not is_reuse:  # a new variable was introduced
                self.defined_vars.add(var_id)

        elif symbol == 'var_id':
            idx = self.get_inh_attr(inherit_atts, 'id')
            i = self.get_node(node, '\'%d\'' % idx, 0)
            node.add_child(i)

        else:
            assert symbol in self._GENERIC_SYMBOLS
            rule = self.rand_rule(node)
            p = prod[symbol][rule]

            for pos, child_sym in enumerate(p):
                c = self.get_node(node, child_sym, pos)
                if child_sym[0] != '\'':  # non-terminal: expand recursively
                    self.tree_generator(c, inherit_atts=inherit_atts)
                node.add_child(c)

    def decode(self, node, walker):
        """Expand ``node`` using ``walker`` after resetting walker and state."""
        self.walker = walker
        self.walker.reset()
        self.reset_state()
        self.tree_generator(node)


def batch_make_att_masks(node_list, tree_decoder=None, walker=None, dtype=np.byte):
    """Decode every tree in ``node_list`` and collect the walker's decisions.

    Args:
        node_list: sequence of root nodes to decode.
        tree_decoder: object exposing ``decode(node, walker)``; defaults to a
            new ``ProgTreeDecoder``.
        walker: object that records the decisions (``num_steps``,
            ``global_rule_used``, ``mask_list``); defaults to a new
            ``ProgramOnehotBuilder``.
        dtype: dtype of the returned arrays.

    Returns:
        A tuple ``(true_binary, rule_masks)``. Both arrays have shape
        ``(len(node_list), cmd_args.max_decode_steps, DECISION_DIM)``.
        For each decoding step ``true_binary`` marks the decision taken and
        ``rule_masks`` marks the entries listed in the walker's mask for that
        step; for every step past ``walker.num_steps`` the last entry is set
        in both arrays.
    """
    # The original defaults referenced an undefined ``OnehotBuilder`` and
    # assigned a walker class to ``tree_decoder``; the walker is the one-hot
    # builder and the decoder is the class that provides ``decode``.
    if walker is None:
        walker = ProgramOnehotBuilder()
    if tree_decoder is None:
        tree_decoder = ProgTreeDecoder()

    shape = (len(node_list), cmd_args.max_decode_steps, DECISION_DIM)
    true_binary = np.zeros(shape, dtype=dtype)
    rule_masks = np.zeros(shape, dtype=dtype)

    for i, node in enumerate(node_list):
        tree_decoder.decode(node, walker)

        num_steps = walker.num_steps
        used_steps = np.arange(num_steps)
        padding_steps = np.arange(num_steps, cmd_args.max_decode_steps)

        true_binary[i, used_steps, walker.global_rule_used[:num_steps]] = 1
        true_binary[i, padding_steps, -1] = 1

        for j in range(num_steps):
            rule_masks[i, j, walker.mask_list[j]] = 1
        rule_masks[i, padding_steps, -1] = 1

    return true_binary, rule_masks


if __name__ == '__main__':
    dec = ProgTreeDecoder()