# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Layer definitions
"""

import copy
import functools
import json

import numpy as np

try:
    from tensorpack.utils import logger
except ImportError:
    # The logger is only used for one warning message, so fall back to the
    # standard library when tensorpack is not installed.
    import logging
    logger = logging.getLogger(__name__)


class LayerTypes(object):
    """Integer codes for layer operations, merge operations and related flags.

    The codes are plain ints (not an Enum) because they are stored inside
    LayerInfo dicts, serialized with json, and written into numeric sequences.
    """
    # Value regarding stop_gradients; not exactly layer types
    STOP_GRADIENT_NONE = 0
    STOP_GRADIENT_HARD = 1
    STOP_GRADIENT_SOFT = 2

    # NOT_EXIST and NOTHING is for layer info sequence.
    NOT_EXIST = 0
    IDENTITY = 1

    # vision
    RESIDUAL_LAYER = 2
    RESIDUAL_BOTTLENECK_LAYER = 3
    CONV_1 = 4
    CONV_3 = 5
    SEPARABLE_CONV_3 = 6
    SEPARABLE_CONV_5 = 7
    SEPARABLE_CONV_3_2 = 23
    SEPARABLE_CONV_5_2 = 24
    SEPARABLE_CONV_7_2 = 25
    DILATED_CONV_3 = 30
    DILATED_CONV_5 = 31
    MAXPOOL_3x3 = 8
    AVGPOOL_3x3 = 9

    # merge
    MERGE_WITH_CAT = 10
    MERGE_WITH_SUM = 11
    MERGE_WITH_AVG = 12
    MERGE_WITH_NOTHING = 13
    MERGE_WITH_MUL = 21
    MERGE_WITH_CAT_PROJ = 22
    MERGE_WITH_SOFTMAX = 27
    MERGE_WITH_WEIGHTED_SUM = 28

    # Gates on hallucinations and the original path
    GATED_LAYER = 14
    ANTI_GATED_LAYER = 17
    NO_FORWARD_LAYER = 26

    # MLP and general operations
    FullyConnected = 15
    MLP_RESIDUAL_LAYER = 16

    # RNN layers
    FC_TANH_MUL_GATE = 18
    FC_RELU_MUL_GATE = 19
    FC_SGMD_MUL_GATE = 20
    FC_IDEN_MUL_GATE = 29

    @staticmethod
    def num_layer_types():
        """Return one more than the largest integer code declared above.

        Derived from the class constants so it does not need a manual update
        whenever a new layer type is added (currently 32).
        """
        codes = [
            val for name, val in vars(LayerTypes).items()
            if not name.startswith('_') and isinstance(val, int)
        ]
        return max(codes) + 1

    @staticmethod
    def no_param_ops():
        """Return the list of op codes that this module treats as parameter-free."""
        LT = LayerTypes
        return [LT.NOT_EXIST, LT.IDENTITY, LT.GATED_LAYER,
                LT.ANTI_GATED_LAYER, LT.NO_FORWARD_LAYER]

    @staticmethod
    def do_drop_path(op):
        """Return True iff `op` is neither a multi-input merge nor a no-param op."""
        LT = LayerTypes
        return (not LT.has_multi_inputs(op)
                and op not in LT.no_param_ops())

    @staticmethod
    def sample_layer_type(valid_ops, prob=None):
        """Sample one element of `valid_ops`.

        Args:
            valid_ops: non-empty sequence of candidates.
            prob: optional per-candidate probabilities (same length as
                valid_ops). None means uniform.
        """
        n = len(valid_ops)
        if prob is None:
            prob = [1. / n] * n
        else:
            assert len(prob) == n
        # multinomial(1, prob) is a one-hot vector; its nonzero index is the draw.
        one_hot = np.random.multinomial(1, prob)
        return valid_ops[int(np.nonzero(one_hot)[0][0])]

    @staticmethod
    def has_multi_inputs(op):
        """Return True iff `op` is one of the multi-input merge operations."""
        LT = LayerTypes
        return op in [
            LT.MERGE_WITH_CAT, LT.MERGE_WITH_SUM, LT.MERGE_WITH_AVG,
            LT.MERGE_WITH_MUL, LT.MERGE_WITH_CAT_PROJ,
            LT.MERGE_WITH_SOFTMAX, LT.MERGE_WITH_WEIGHTED_SUM,
        ]


def _item_property(key):
    """Build a read/write property that proxies the dict item `key`."""
    def _get(self):
        return self[key]

    def _set(self, val):
        self[key] = val

    return property(_get, _set, doc="Proxy for self[{!r}].".format(key))


class LayerInfo(dict):
    """Description of one layer, stored as a json-serializable dict.

    Keys: 'id', 'inputs', 'operations', 'down_sampling', 'stop_gradient',
    'aux_weight', 'is_candidate' and, optionally, 'extra_dict'. Each key is
    also reachable as an attribute of the same name ('id' is passed to the
    constructor as `layer_id`).
    """

    def __init__(self, layer_id, inputs=None, operations=None,
                 down_sampling=0, stop_gradient=0, aux_weight=0,
                 is_candidate=0, extra_dict=None):
        """
        Args:
            layer_id: a unique non-negative int that is the id of the layer.
                When overflow happens, an assertion failure will be triggered.
                None is accepted and skips the non-negativity check.
        Kwargs:
            inputs: a list of layer_ids that are the inputs of the layer.
                Defaults to a fresh empty list. A list passed in is stored
                as-is (not copied).
            operations: a list of LayerTypes.<operations>, one per input,
                optionally followed by one merge operation. Defaults to a
                fresh empty list. A list passed in is stored as-is.
            down_sampling: int/bool. whether this layer will do down sampling.
            stop_gradient: int or list of int. Each int is 0 or 1, representing
                whether a stop_gradient call is done first to each of the input.
                If the value is just an int, then the same value is applied to
                all inputs.
            aux_weight: float auxiliary weight on the layer. This is used to
                mark where the prediction happens. E.g., the final layer always
                should have positive weight.
            is_candidate: int. Candidate id. This represents which candidate
                this layer is from. candidate==0 means the real network,
                >0 means it is from some hallucination.
            extra_dict: optional dict; only stored when it is a non-empty dict.
        """
        super(LayerInfo, self).__init__()
        # Fresh lists per instance: a shared default list would be mutated by
        # later `.inputs.append(...)` calls on any default-constructed layer.
        inputs = [] if inputs is None else inputs
        operations = [] if operations is None else operations

        assert isinstance(inputs, list), inputs
        assert isinstance(operations, list), operations
        assert isinstance(down_sampling, (int, bool))
        assert isinstance(aux_weight, (float, int))
        # operations is either one op per input, or that plus a merge op.
        if len(operations) != len(inputs) + 1:
            assert len(inputs) == len(operations)
        if layer_id is not None:
            assert layer_id >= 0, 'Hack failed, we have negative layer idx meow'

        self['id'] = layer_id
        self['inputs'] = inputs
        self['operations'] = operations
        self['down_sampling'] = down_sampling
        self['stop_gradient'] = stop_gradient
        self['aux_weight'] = aux_weight
        self['is_candidate'] = is_candidate
        if isinstance(extra_dict, dict) and len(extra_dict) > 0:
            self['extra_dict'] = extra_dict

    id = _item_property('id')
    inputs = _item_property('inputs')
    operations = _item_property('operations')
    down_sampling = _item_property('down_sampling')
    stop_gradient = _item_property('stop_gradient')
    aux_weight = _item_property('aux_weight')
    is_candidate = _item_property('is_candidate')

    @property
    def merge_op(self):
        """The trailing merge op, or MERGE_WITH_NOTHING if there is none."""
        if len(self['inputs']) == len(self['operations']):
            return LayerTypes.MERGE_WITH_NOTHING
        return self['operations'][-1]

    @merge_op.setter
    def merge_op(self, val):
        ops = self['operations']
        if len(self['inputs']) == len(ops):
            # No merge op stored yet: add one.
            ops.append(val)
        else:
            ops[-1] = val

    @property
    def input_ops(self):
        """The per-input operations (operations without the merge op)."""
        return self['operations'][:len(self['inputs'])]

    @input_ops.setter
    def input_ops(self, val):
        curr = self.get('operations', [])
        # NOTE(review): "no merge op" is detected by len <= 1 rather than by
        # comparing with len(inputs); kept as-is -- confirm this is intended.
        if len(curr) <= 1:
            # there is no merge_op before
            self['operations'] = val
        else:
            self['operations'][:-1] = val

    @property
    def extra_dict(self):
        """The stored extra dict, or None when the key is absent."""
        return self.get('extra_dict', None)

    @extra_dict.setter
    def extra_dict(self, val):
        self['extra_dict'] = val

    @staticmethod
    def from_str(ss):
        """Build a LayerInfo from the json string produced by to_str()."""
        d = json.loads(ss)
        return LayerInfo.from_json_loads(d)

    def to_str(self):
        """Serialize this info as a json string."""
        return json.dumps(self)

    @staticmethod
    def from_json_loads(d):
        """Build a LayerInfo from a dict keyed like a LayerInfo.

        The input dict is not modified; its values (e.g. the lists) are
        shared with the returned LayerInfo.
        """
        kwargs = dict(d)
        kwargs['layer_id'] = kwargs.pop('id', None)
        return LayerInfo(**kwargs)

    @staticmethod
    def is_input(info):
        """Return True iff `info` has no inputs."""
        return len(info.inputs) == 0


class LayerInfoList(list):
    """A list of LayerInfo describing a network, plus hallucination editing.

    A "hallucination" (hallu) is a run of consecutive infos that share the
    same non-zero `is_candidate` value.
    """
    DELIM = '=^_^='
    LOG_DELIM = '\n'
    # Number of trailing entries per layer in to_seq():
    # merge type, is_candidate, stop_grad, down_sampling.
    n_extra_dim = 4
    ORIG_IDX_IN_MERGE_HALLU = 0
    HALLU_IDX_IN_MERGE_HALLU = 1

    def __init__(self, *args, **kwargs):
        super(LayerInfoList, self).__init__(*args, **kwargs)

    def to_str(self):
        """Serialize the list as a json string."""
        return json.dumps(self)

    def num_inputs(self):
        """Count the leading infos that have no inputs."""
        n_inputs = 0
        for info in self:
            if not LayerInfo.is_input(info):
                break
            n_inputs += 1
        return n_inputs

    @property
    def master(self):
        return self

    def is_end_merge_sum(self):
        """Return True iff the last layer merges with SUM or AVG."""
        LT = LayerTypes
        return self[-1].merge_op in [LT.MERGE_WITH_SUM, LT.MERGE_WITH_AVG]

    def is_end_merge_cat(self):
        """Return True iff the last layer merges with CAT or CAT_PROJ."""
        LT = LayerTypes
        return self[-1].merge_op in [LT.MERGE_WITH_CAT, LT.MERGE_WITH_CAT_PROJ]

    def sample_cat_hallucinations(self, layer_ops, merge_ops,
                                  prob_at_layer=None, min_num_hallus=1,
                                  hallu_input_choice=None):
        """Sample hallucinations for a list whose last layer merges with CAT.

        Args:
            layer_ops: candidate operations for the hallucination inputs.
            merge_ops: candidate merge operations for the hallucination.
            prob_at_layer : probability of having input from a layer.
                None is translated to a default that weights the input layers
                by the number of inputs of the final cat layer (the last
                input layer gets 1.5x that weight) and every other layer by 1.
                If it is longer than len(self) - 1 it is cut and renormalized.
            min_num_hallus: number of hallucinations to sample.
            hallu_input_choice: unused.

        Returns:
            A list of LayerInfo, each with two inputs and id == self[-1].id.
        """
        assert self[-1].merge_op == LayerTypes.MERGE_WITH_CAT
        n_inputs = self.num_inputs()
        n_final_merge = len(self[-1].inputs)
        n_choices = len(self) - 1

        if prob_at_layer is None:
            prob_at_layer = np.ones(n_choices)
            prob_at_layer[:n_inputs-1] = n_final_merge
            prob_at_layer[n_inputs-1] = n_final_merge * 1.5
            prob_at_layer = prob_at_layer / np.sum(prob_at_layer)
        assert len(prob_at_layer) >= n_choices
        if len(prob_at_layer) > n_choices:
            logger.warning(
                "sample cell hallu cuts the prob_at_layer to len(info_list) - 1")
            # Renormalize after the cut; np.random.choice rejects
            # probabilities that do not sum to 1.
            prob_at_layer = np.asarray(prob_at_layer[:n_choices], dtype=float)
            prob_at_layer = prob_at_layer / np.sum(prob_at_layer)

        # choose inputs
        n_hallu_inputs = 2
        l_hallu = []
        for _ in range(min_num_hallus):
            # replace=False : the inputs of one hallucination are distinct layers
            in_idxs = np.random.choice(
                len(prob_at_layer), size=n_hallu_inputs,
                replace=False, p=prob_at_layer)
            in_ids = [self[idx].id for idx in in_idxs]
            main_ops = list(map(
                int, np.random.choice(layer_ops, size=n_hallu_inputs)))
            merge_op = int(np.random.choice(merge_ops))
            l_hallu.append(LayerInfo(
                layer_id=self[-1].id, inputs=in_ids,
                operations=main_ops + [merge_op]))
        return l_hallu

    def add_cat_hallucinations(self, l_h,
                               final_merge_op=LayerTypes.MERGE_WITH_SUM,
                               stop_gradient_val=1,
                               hallu_gate_layer=LayerTypes.NO_FORWARD_LAYER):
        """Insert hallucinations before the final cat layer, in place.

        Each hallu in `l_h` is deep-copied, given a fresh id and a candidate
        id (1, 2, ...), and `stop_gradient_val`. One extra layer sums all of
        them (each through `hallu_gate_layer`), and that sum is appended as a
        new IDENTITY input of the final cat layer.

        Returns:
            self
        """
        assert final_merge_op == LayerTypes.MERGE_WITH_SUM
        assert self[-1].merge_op == LayerTypes.MERGE_WITH_CAT
        l_info = self

        next_id = max([0] + [info.id for info in l_info])
        # NOTE(review): candidate ids always restart from 1 here, regardless
        # of candidates already present in the list (unlike
        # add_sum_hallucinations) -- kept as-is, confirm this is intended.
        max_candidate = 0

        l_insert = []
        final_merge_inputs = []
        for h in l_h:
            next_id += 1
            max_candidate += 1
            info = copy.deepcopy(h)
            info.id = next_id
            info.is_candidate = max_candidate
            info.stop_gradient = stop_gradient_val
            l_insert.append(info)
            final_merge_inputs.append(info.id)

        # The layer that sums all hallucinations.
        next_id += 1
        ops = ([hallu_gate_layer] * len(final_merge_inputs)
               + [LayerTypes.MERGE_WITH_SUM])
        info = LayerInfo(layer_id=next_id, inputs=final_merge_inputs,
                         operations=ops)
        l_insert.append(info)

        # Hook the sum into the final cat: new input goes last, its op goes
        # right before the merge op.
        l_info[-1].inputs.append(info.id)
        l_info[-1].operations[-1:-1] = [LayerTypes.IDENTITY]
        l_info[-1:-1] = l_insert

        merge_pt = l_info[-1]
        assert not isinstance(merge_pt.stop_gradient, list), \
            "Merge point has invalid stop_gradient: {}".format(merge_pt)
        assert not isinstance(merge_pt.down_sampling, list), \
            "Merge point has invalid down_sampling: {}".format(merge_pt)
        return l_info

    def select_cat_hallucination(self, selected, contained=None,
                                 hallu_indices=None, l_fs_ops=None,
                                 l_fs_omega=None):
        """Keep the selected hallucinations and drop the others, in place.

        Expects the layout produced by add_cat_hallucinations: self[-1] is
        the final cat layer and self[-2] is the layer summing the hallus.

        Args:
            selected: candidate id (int) or list of candidate ids to keep.
            contained: result of contained_hallucination(); computed if None.
            hallu_indices: candidate ids ordered by position; computed if None.
            l_fs_ops, l_fs_omega: unused.

        Returns:
            self
        """
        assert self[-1].merge_op == LayerTypes.MERGE_WITH_CAT, self[-1].merge_op
        l_info = self
        if isinstance(selected, int):
            selected = [selected]
        if contained is None:
            contained = l_info.contained_hallucination()
        if len(contained) == 0:
            return l_info
        if hallu_indices is None:
            hallu_indices = LayerInfoList.sorted_hallu_indices(contained)

        # Walk from the back so earlier (start, end) ranges stay valid
        # while later ranges are deleted.
        to_keep = []
        for h_idx in reversed(hallu_indices):
            start, end = contained[h_idx]
            if h_idx not in selected:
                l_info[start:end] = []
            else:
                to_keep.append(l_info[end-1].id)
                for x in l_info[start:end]:
                    x.is_candidate = 0
                    x.stop_gradient = 0

        n_to_keep = len(to_keep)
        old_id = l_info[-2].id
        assert old_id == l_info[-1].inputs[-1], \
            '{} != {}'.format(old_id, l_info[-1].inputs[-1])

        def del_merge_hallu():
            del l_info[-2]

        def del_at_final_cat():
            del l_info[-1].inputs[-1]
            del l_info[-1].operations[-2]
            if isinstance(l_info[-1].down_sampling, list):
                assert (len(l_info[-1].down_sampling)
                        == len(l_info[-1].inputs) + 1)
                del l_info[-1].down_sampling[-1]
            if isinstance(l_info[-1].stop_gradient, list):
                assert (len(l_info[-1].stop_gradient)
                        == len(l_info[-1].inputs) + 1)
                del l_info[-1].stop_gradient[-1]

        if n_to_keep > 1:
            # The sum layer stays and now merges only the kept hallus.
            for keep_id in to_keep:
                assert keep_id in l_info[-2].inputs, \
                    '{} not in {}'.format(keep_id, l_info[-2].inputs)
            l_info[-2].inputs = to_keep
            l_info[-2].operations = (
                [LayerTypes.GATED_LAYER] * len(to_keep)
                + [LayerTypes.MERGE_WITH_SUM])
        elif n_to_keep == 0:
            del_merge_hallu()
            del_at_final_cat()
        elif n_to_keep == 1:
            # A sum of one is pointless: wire the hallu to the cat directly.
            new_id = to_keep[0]
            assert new_id in l_info[-2].inputs, \
                '{} not in {}'.format(new_id, l_info[-2].inputs)
            del_merge_hallu()
            l_info[-1].inputs, l_info[-1].stop_gradient = \
                LayerInfoList._rewire_inputs_post_add(
                    {old_id: new_id},
                    l_info[-1].inputs,
                    l_info[-1].stop_gradient)
            l_info[-1].operations[-2] = LayerTypes.GATED_LAYER
        return l_info

    def _id_to_idx(self):
        """Return a dict mapping each info id to its index in the list."""
        return {info.id: idx for idx, info in enumerate(self)}

    def _distance_to_idx(self, t_idx, id_to_idx, max_dist=None):
        """Breadth-first distances from layer `t_idx` back along inputs.

        Args:
            t_idx: index of the target layer.
            id_to_idx: mapping from layer id to index (see _id_to_idx).
            max_dist: layers at this distance are not expanded further.
                Defaults to t_idx + 1.

        Returns:
            A list of length t_idx + 1; entry i is the number of input hops
            from layer t_idx to layer i, or max_dist + 1 if never reached.
        """
        max_dist = max_dist if max_dist is not None else t_idx + 1
        l_dists = [max_dist + 1] * (t_idx + 1)
        is_fixed = [False] * (t_idx + 1)
        l_dists[-1] = 0
        is_fixed[-1] = True

        queue = [t_idx]
        qidx = 0
        while qidx < len(queue):
            idx = queue[qidx]
            dist = l_dists[idx]
            for in_id in self[idx].inputs:
                in_idx = id_to_idx[in_id]
                if is_fixed[in_idx]:
                    continue
                is_fixed[in_idx] = True
                l_dists[in_idx] = dist + 1
                if l_dists[in_idx] < max_dist:
                    queue.append(in_idx)
            qidx += 1
        return l_dists

    def _is_end_of_cell(self, idx):
        """
        Given an index (idx) in the current layer info list, determine
        whether it is an end of a cell in the starting of a macro search.
        This method should only be used by a master l_info, i.e., a
        macro search master, or a cell search root/master.

        This method assumes that the original layers have id smaller than
        the last id, which is the last layer of the original network as well.
        Finally, the original layer ids also convey the depth of layers.

        (Known bug: we only use the connection pattern and id to check this.
        Ideally one should know this from detecting repeatable patterns of
        directly using the original construction function. This is
        not a concern for now as the seed cells are simple).
        """
        l_info = self
        info = l_info[idx]
        if info.id > l_info[-1].id:
            # the info is inserted after seed model
            return False

        # Check input connection patterns: (id offset, op) of every input
        # whose id is smaller than the layer's own id.
        def seeded_inputs(_info):
            info_id = _info.id
            return [
                (info_id - in_id, in_op)
                for in_id, in_op in zip(_info.inputs, _info.operations)
                if in_id < info_id
            ]

        return seeded_inputs(info) == seeded_inputs(l_info[-1])

    def _sample_output_locations(self, min_num_hallus, cell_based):
        """
        Args:
            min_num_hallus: number of output locations to produce.
            cell_based (bool): if True every location is the last layer;
                otherwise locations are spread over the end-of-cell layers.

        Returns:
            A sorted list of indices in the layer info list that will be x_out
        """
        l_info = self
        n_layers = len(l_info)
        if cell_based:
            return [n_layers - 1] * min_num_hallus

        # compute the eligible layers.
        indices = [
            idx for idx in range(l_info.num_inputs(), n_layers)
            if l_info._is_end_of_cell(idx)
        ]
        # NOTE(review): an empty `indices` raises ZeroDivisionError below.
        n_eligible = len(indices)
        repeats = min_num_hallus // n_eligible
        remainder = np.random.choice(
            indices, min_num_hallus % n_eligible, replace=False)
        # tolist() yields plain ints instead of numpy integers.
        l_x_out = indices * repeats + remainder.tolist()
        l_x_out.sort()
        return l_x_out

    def _sample_all_input_ids(self, l_x_out, cell_based):
        """
        1. For cell search, inputs are either in the same
        cell or have lower id. So all layers are possible inputs.
        2. For macro search, we mimic the behavior, and limit
        the inputs to be in the same cell or the output of
        the two previous cells.

        Args:
            l_x_out: a list of indices in the layer info list that are to be x_out
            cell_based (bool) : whether the NAS is for cell or macro search.
        Returns:
            A list of list of layer id that are to be input ids of the x_out in l_x_out.
        """
        l_info = self
        l_inputs = []
        for out_idx in l_x_out:
            if cell_based:
                in_id_pool = [info.id for info in l_info[:out_idx]]
            else:
                # This is the same as cell based search case above,
                # if the cell has two inputs (prev and prev-prev cell output).
                n_origs = 2
                in_same_cell = True
                in_id_pool = []
                for idx in reversed(range(out_idx)):
                    if l_info._is_end_of_cell(idx):
                        in_same_cell = False
                        n_origs -= 1
                        in_id_pool.append(l_info[idx].id)
                        if n_origs == 0:
                            break
                    if in_same_cell:
                        in_id_pool.append(l_info[idx].id)
            l_inputs.append(in_id_pool)
        return l_inputs

    # NOTE: a disabled (fully commented-out) legacy sampler `_sample_input_ids`
    # used to live here; it was dead code and has been removed.

    def sample_sum_hallucinations(self, layer_ops, merge_ops,
                                  prob_at_layer=None, min_num_hallus=1,
                                  hallu_input_choice=None, cell_based=False):
        """Sample hallucinations to be added with add_sum_hallucinations.

        Args:
            layer_ops: candidate operations for the hallucination inputs.
            merge_ops: candidate merge operations. If it is exactly
                [MERGE_WITH_WEIGHTED_SUM], feature selection mode is used:
                every (input, op) combination is included.
            prob_at_layer, hallu_input_choice: unused.
            min_num_hallus: number of output locations to sample.
            cell_based (bool): cell search (True) or macro search (False).

        Returns:
            A list of LayerInfo ordered by output location. Each has the id
            of the layer it attaches to. Locations with an empty input pool
            are skipped.
        """
        l_info = self
        LT = LayerTypes
        do_feat_sel = (
            len(merge_ops) == 1 and merge_ops[0] == LT.MERGE_WITH_WEIGHTED_SUM)

        l_x_out = self._sample_output_locations(min_num_hallus, cell_based)
        l_inputs = self._sample_all_input_ids(l_x_out, cell_based)

        hallus = []
        for out_idx, inputs in zip(l_x_out, l_inputs):
            if len(inputs) == 0:
                continue
            out_id = l_info[out_idx].id
            if not do_feat_sel:
                # sample operations
                main_ops = list(map(
                    int, np.random.choice(layer_ops, len(inputs))))
                merge_op = int(np.random.choice(merge_ops))
            else:
                # feature selection will form all ops
                # op1,..., opk, op1, ..., opk ,....
                # (cast to int, as in the sampling branch, so the info
                # stays json-serializable)
                ops = list(map(int, layer_ops))
                main_ops = ops * len(inputs)
                merge_op = LT.MERGE_WITH_WEIGHTED_SUM
                # in1,..., in1, in2, ..., in2, ...
                inputs = [in_id for in_id in inputs for _ in ops]
            hallus.append(LayerInfo(
                layer_id=out_id, inputs=inputs,
                operations=main_ops + [merge_op]))
        return hallus

    def select_sum_hallucination(self, selected, contained=None,
                                 hallu_indices=None, l_fs_ops=None,
                                 l_fs_omega=None):
        """Keep the selected hallucinations and drop the others, in place.

        Args:
            selected: candidate id (int) or list of candidate ids to keep.
            contained: result of contained_hallucination(); computed if None.
            hallu_indices: candidate ids ordered by position; computed if None.
            l_fs_ops: for feature-selection hallus (merge op
                MERGE_WITH_WEIGHTED_SUM), l_fs_ops[i] lists the positions of
                the inputs to keep for the i-th entry of hallu_indices. An
                empty list removes that hallu.
            l_fs_omega: l_fs_omega[i] are the weights matching l_fs_ops[i];
                stored in the kept hallu's extra_dict.

        Returns:
            self
        """
        l_info = self
        if isinstance(selected, int):
            selected = [selected]
        if contained is None:
            contained = l_info.contained_hallucination()
        if len(contained) == 0:
            return l_info
        if hallu_indices is None:
            hallu_indices = LayerInfoList.sorted_hallu_indices(contained)

        # Walk from the back so earlier (start, end) ranges stay valid.
        for h_idx_idx in reversed(range(len(hallu_indices))):
            h_idx = hallu_indices[h_idx_idx]
            start, end = contained[h_idx]
            h_layer_id = l_info[end-1].id

            # find the exact one user of this candidate.
            found = False
            for idx in range(end, len(l_info)):
                if h_layer_id in l_info[idx].inputs:
                    found = True
                    break
            assert found, "Did not find id {}".format(h_layer_id)

            to_remove = (
                (h_idx not in selected)
                or (l_info[end-1].merge_op == LayerTypes.MERGE_WITH_WEIGHTED_SUM
                    and len(l_fs_ops[h_idx_idx]) == 0)
            )
            if to_remove:
                l_info[idx] = LayerInfoList._remove_connection_from_id(
                    l_info[idx], h_layer_id)
                l_info[start:end] = []
                continue

            for x in l_info[start:end]:
                x.is_candidate = 0
                x.stop_gradient = 0

            # feature selection case:
            hallu_info = l_info[end-1]
            if hallu_info.merge_op == LayerTypes.MERGE_WITH_WEIGHTED_SUM:
                fs_indices = l_fs_ops[h_idx_idx]
                fs_omega = l_fs_omega[h_idx_idx]
                assert len(fs_indices) == len(fs_omega), \
                    'Invalid feat select info i={} omega={}'.format(
                        fs_indices, fs_omega)
                new_inputs = [
                    hallu_info.inputs[in_idx] for in_idx in fs_indices]
                new_operations = [
                    hallu_info.operations[in_idx] for in_idx in fs_indices]
                hallu_info.inputs = new_inputs
                # The weighted sum is replaced by a cat-proj merge.
                hallu_info.operations = (
                    new_operations + [LayerTypes.MERGE_WITH_CAT_PROJ])
                ed = hallu_info.extra_dict
                ed = dict() if ed is None else ed
                ed['ops_ids'] = list(map(int, fs_indices))
                ed['fs_omega'] = list(map(float, fs_omega))
                hallu_info.extra_dict = ed

            # fix future reference to hallu
            user = l_info[idx]
            for in_idx, in_id in enumerate(user.inputs):
                if in_id == h_layer_id:
                    user.operations[in_idx] = LayerTypes.GATED_LAYER
        return l_info

    def add_sum_hallucinations(self, l_h,
                               final_merge_op=LayerTypes.MERGE_WITH_SUM,
                               stop_gradient_val=1,
                               hallu_gate_layer=LayerTypes.NO_FORWARD_LAYER):
        """Insert hallucinations right before the layers they attach to.

        Each hallu in `l_h` is deep-copied, given a fresh id, the next
        candidate id and `stop_gradient_val`, inserted just before the layer
        whose id equals the hallu's id, and appended to that layer's inputs
        with operation `hallu_gate_layer`.

        The target layers are searched forward only, so `l_h` must be ordered
        by the position of the targets in the list (as returned by
        sample_sum_hallucinations).

        Args:
            l_h: list of LayerInfo; each id names the layer to attach to.
            final_merge_op: unused.
            stop_gradient_val: stop_gradient value for the inserted layers.
            hallu_gate_layer: op applied on the hallu at its attach point.

        Returns:
            self
        """
        l_info = self
        next_id = 0
        max_candidate = 0
        for info in l_info:
            next_id = max(next_id, info.id)
            max_candidate = max(max_candidate, info.is_candidate)

        target_idx = 0
        for hallu in l_h:
            o_id = hallu.id
            for idx in range(target_idx, len(l_info)):
                if l_info[idx].id == o_id:
                    target_idx = idx
                    break
            assert l_info[target_idx].id == o_id, \
                "info {} does not find id {}; start at {}".format(
                    l_info, o_id, target_idx)

            next_id += 1
            max_candidate += 1
            info = copy.deepcopy(hallu)
            info.id = next_id
            info.is_candidate = max_candidate
            info.stop_gradient = stop_gradient_val

            # New input goes last; its op goes right before the last entry
            # of operations (presumably the merge op -- the target is
            # expected to have one).
            merge_pt = l_info[target_idx]
            merge_pt.inputs.append(info.id)
            merge_pt.operations[-1:-1] = [hallu_gate_layer]
            assert not isinstance(merge_pt.stop_gradient, list), \
                "Merge point has invalid stop_gradient: {}".format(merge_pt)
            assert not isinstance(merge_pt.down_sampling, list), \
                "Merge point has invalid down_sampling: {}".format(merge_pt)

            l_info[target_idx:target_idx] = [info]
        return l_info

    def contained_hallucination(self):
        """
        Returns:
            hallu_idx_to_range : a dict mapping from candidate_id to
            (start, end) for candidate_id > 0. self[start:end] is a run of
            consecutive infos with info.is_candidate == candidate_id. If a
            candidate id appears in several separate runs, the last run wins.
        """
        l_info = self
        n_layers = len(l_info)
        prev_candidate = 0
        hallu_idx_to_range = {}
        for idx, info in enumerate(l_info):
            candidate = info.is_candidate
            if candidate == 0:
                prev_candidate = 0
                continue
            is_start = prev_candidate != candidate
            is_last = ((idx + 1 == n_layers)
                       or (l_info[idx+1].is_candidate != candidate))
            prev_candidate = candidate
            if is_start:
                start = idx
            if is_last:
                hallu_idx_to_range[candidate] = (start, idx + 1)
        return hallu_idx_to_range

    @staticmethod
    def sorted_hallu_indices(hallu_idx_to_range):
        """Return the candidate ids ordered by the start of their range."""
        return sorted(hallu_idx_to_range,
                      key=lambda i: hallu_idx_to_range[i][0])

    @staticmethod
    def _create_info_merge(next_id, h_id, o_id, aux_weight, is_candidate,
                           final_merge_op=LayerTypes.MERGE_WITH_SUM,
                           hallu_gate_layer=LayerTypes.NO_FORWARD_LAYER):
        """
        Form the LayerInfo for the merge operation between hallu of id h_id
        and the original tensor of id o_id (out_id). The new LayerInfo will
        have info.id == next_id. Return a list of layers used for merging.

        Note any change to this function need to be mirrored in
        _finalize_info_merge
        """
        inputs = [None] * 2
        inputs[LayerInfoList.ORIG_IDX_IN_MERGE_HALLU] = o_id
        inputs[LayerInfoList.HALLU_IDX_IN_MERGE_HALLU] = h_id
        operations = [LayerTypes.IDENTITY] * 2 + [final_merge_op]
        operations[LayerInfoList.HALLU_IDX_IN_MERGE_HALLU] = hallu_gate_layer
        info = LayerInfo(next_id, inputs=inputs, operations=operations,
                         aux_weight=aux_weight, is_candidate=is_candidate)
        return [info]

    @staticmethod
    def _finalize_info_merge(l_info, start, end):
        """
        Given an l_info, and start and end idx of a hallu, update the params
        of l_info[start:end], so that they will become part of the model
        permanently.

        Note that any change to this function need to be mirrored in
        _create_info_merge
        """
        for x in l_info[start:end]:
            x.is_candidate = 0
            x.stop_gradient = 0
        merge_ops = l_info[end-1].operations
        merge_ops[LayerInfoList.ORIG_IDX_IN_MERGE_HALLU] = \
            LayerTypes.ANTI_GATED_LAYER
        merge_ops[LayerInfoList.HALLU_IDX_IN_MERGE_HALLU] = \
            LayerTypes.GATED_LAYER
        return l_info

    @staticmethod
    def _rewire_inputs_post_add(prev_id_to_new_id, inputs, stop_gradient):
        """
        Given a dictionary mapping from old id to the new id, we
        swap the occurrence of old_id in inputs to new id. The stop_gradient
        of every swapped input is set to 0.

        `inputs` (and `stop_gradient`, when it is a list) are modified in
        place. With empty `inputs`, both arguments are returned unchanged.

        Return:
            updated inputs, and stop_gradient (a single value when all
            per-input values are equal, otherwise a list).
        """
        if not inputs:
            return inputs, stop_gradient
        if not isinstance(stop_gradient, list):
            stop_gradient = [stop_gradient] * len(inputs)
        for ini, inid in enumerate(inputs):
            if inid in prev_id_to_new_id:
                inputs[ini] = prev_id_to_new_id[inid]
                stop_gradient[ini] = 0
        # Collapse back to a scalar when every input agrees.
        sg_ref = stop_gradient[0]
        if all(sg == sg_ref for sg in stop_gradient):
            stop_gradient = sg_ref
        return inputs, stop_gradient

    @staticmethod
    def _remove_connection_from_id(info, id_to_remove):
        """Remove every occurrence of `id_to_remove` from info.inputs.

        The matching entries of operations (and of stop_gradient /
        down_sampling when those are lists) are removed too. `info` is
        modified in place and returned.
        """
        if id_to_remove not in info.inputs:
            return info
        if isinstance(info.stop_gradient, list):
            assert len(info.stop_gradient) == len(info.inputs), \
                "Invalid info {}".format(info)
        if isinstance(info.down_sampling, list):
            assert len(info.down_sampling) == len(info.inputs), \
                "Invalid info {}".format(info)
        assert len(info.operations) == len(info.inputs) + 1, \
            "Invalid info {}".format(info)
        # Delete back-to-front so the remaining indices stay valid and
        # adjacent duplicates are not skipped.
        for idx in reversed(range(len(info.inputs))):
            if info.inputs[idx] != id_to_remove:
                continue
            del info.inputs[idx]
            del info.operations[idx]
            if isinstance(info.stop_gradient, list):
                del info.stop_gradient[idx]
            if isinstance(info.down_sampling, list):
                del info.down_sampling[idx]
        return info

    @staticmethod
    def from_str(ss, delim=DELIM):
        """Build a LayerInfoList from a string.

        Accepts the json list written by to_str(), and the deprecated format
        of json objects joined by `delim`.
        """
        # deprecation protection
        if ss[0] == '{' and ss[-1] == '}' and delim in ss:
            return LayerInfoList(
                map(LayerInfo.from_str, ss.strip().split(delim)))
        l_dict = json.loads(ss)
        return LayerInfoList.from_json_loads(l_dict)

    @staticmethod
    def from_json_loads(l_dict):
        """Build a LayerInfoList from a list of LayerInfo-like dicts."""
        return LayerInfoList(map(LayerInfo.from_json_loads, l_dict))

    @staticmethod
    def to_seq(l_info):
        """
        Transform the layer info list into a sequence to be parsed
        by recurrent structures.

        The layer at index idx contributes idx + n_extra_dim values: the op
        applied to each earlier layer (NOT_EXIST when it is not an input),
        then the appended info described below.
        """
        id_to_idx = {info.id: idx for idx, info in enumerate(l_info)}
        seq = []
        # format of appended info
        # 0 : merge type
        # 1 : is_candidate
        # 2 : stop_grad (currently always just 0 or 1)
        # 3 : down_sampling
        n_extra_dim = LayerInfoList.n_extra_dim
        for idx, info in enumerate(l_info):
            input_seq = [LayerTypes.NOT_EXIST] * (idx + n_extra_dim)
            n_inputs = len(info.inputs)
            for input_id, op in zip(info.inputs, info.operations[:n_inputs]):
                input_seq[id_to_idx[input_id]] = op
            input_seq[idx] = info.merge_op
            input_seq[idx+1:] = [
                int(info.is_candidate > 0),
                int(info.stop_gradient),
                int(info.down_sampling),
            ]
            seq.extend(input_seq)
        return seq

    @staticmethod
    def seq_to_img_flag(seq, max_depth=128, make_batch=False):
        """Unpack a to_seq() sequence into two integer arrays.

        Returns:
            img: [max_depth, max_depth]; row i holds the ops on earlier
                layers followed by the merge type of layer i in its first
                i + 1 columns, NOT_EXIST elsewhere.
            flag: [max_depth, n_extra_dim - 1]; row i holds the remaining
                appended info of layer i.
            Both get a leading axis of size 1 when make_batch is True.
        """
        n_flags = LayerInfoList.n_extra_dim - 1
        img = np.ones([max_depth, max_depth], dtype=int) * LayerTypes.NOT_EXIST
        flag = np.zeros([max_depth, n_flags], dtype=int)
        line_len = LayerInfoList.n_extra_dim
        start = 0
        seq_len = len(seq)
        li = 0
        while start < seq_len:
            end_img = start + li + 1
            end_line = start + line_len
            img[li, :li+1] = seq[start:end_img]
            flag[li, :] = seq[end_img:end_line]
            start = end_line
            # each layer's line is one entry longer than the previous one
            line_len += 1
            li += 1
        if make_batch:
            img = img.reshape([1, max_depth, max_depth])
            flag = flag.reshape([1, max_depth, n_flags])
        return img, flag

    @staticmethod
    def seq_to_hstr(seq, not_exist_str='--'):
        """Format a to_seq() sequence as text, one line per layer.

        Values equal to NOT_EXIST are printed as `not_exist_str`, all others
        as two-digit numbers.
        """
        def _fmt(x):
            if x == LayerTypes.NOT_EXIST:
                return not_exist_str
            return '{:02d}'.format(x)

        line_len = LayerInfoList.n_extra_dim
        seq_len = len(seq)
        start = 0
        ss = []
        while start < seq_len:
            end = start + line_len
            ss.append(' '.join(map(_fmt, seq[start:end])))
            start = end
            line_len += 1
        return '\n'.join(ss)

    @staticmethod
    def str_to_seq(ss):
        """Parse a serialized LayerInfoList and convert it with to_seq()."""
        return LayerInfoList.to_seq(LayerInfoList.from_str(ss))