# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import numpy as np
import json
import copy
import functools

from tensorpack.utils import logger
# Layer definitions

class LayerTypes(object):
    """Integer codes for the operations a layer can apply to its inputs.

    NOTE(review): this class appears truncated. Code elsewhere in this file
    references constants that are not defined here (MERGE_WITH_NOTHING,
    MERGE_WITH_CAT, MERGE_WITH_SUM, MERGE_WITH_AVG, MERGE_WITH_MUL,
    MERGE_WITH_CAT_PROJ, MERGE_WITH_WEIGHTED_SUM, ANTI_GATED_LAYER). The list
    literals in no_param_ops / has_multi_inputs are unterminated. The helper
    functions take no ``self``, so they were presumably ``@staticmethod``s
    whose decorators were lost. Restore from version control before relying
    on this class.
    """
    # Value regarding stop_gradients not exactly layer types
    # NOTE(review): no constants follow this comment; they appear to be lost.

    # NOT_EXIST and NOTHING is for layer info sequence.
    NOT_EXIST = 0
    IDENTITY = 1

    # vision
    CONV_1 = 4
    CONV_3 = 5
    SEPARABLE_CONV_3_2 = 23
    SEPARABLE_CONV_5_2 = 24
    SEPARABLE_CONV_7_2 = 25
    DILATED_CONV_3 = 30
    DILATED_CONV_5 = 31

    MAXPOOL_3x3 = 8
    AVGPOOL_3x3 = 9

    # merge
    # NOTE(review): the MERGE_WITH_* constants used throughout this file are
    # missing here.

    # Gates on hallucinations and the original path
    GATED_LAYER = 14

    # MLP and general operations
    FullyConnected = 15

    # RNN layers

    # Presumably the exclusive upper bound on layer-type codes; the largest
    # code defined in this class is DILATED_CONV_5 = 31.
    def num_layer_types():
        return 32 # Please update me whenever a new layer is made

    # Ops treated as parameter-free (do_drop_path excludes them).
    # NOTE(review): the list literal below is unterminated; the remaining
    # entries and the closing bracket were lost.
    def no_param_ops():
        LT = LayerTypes
        return [LT.NOT_EXIST, LT.IDENTITY,

    # True when ``op`` is neither a multi-input merge op nor a
    # parameter-free op.
    def do_drop_path(op):
        LT = LayerTypes
        return (not LT.has_multi_inputs(op) and not op in LT.no_param_ops())

    # Draw one element of ``valid_ops``; ``prob`` defaults to uniform. A single
    # multinomial draw gives a one-hot vector whose nonzero index is the pick.
    # NOTE(review): the length assert sits inside the ``prob is None`` branch,
    # so it only checks the uniform default, never a caller-supplied ``prob``.
    def sample_layer_type(valid_ops, prob=None):
        n = len(valid_ops)
        if prob is None:
            prob = [1. / n] * n
            assert len(prob) == n
        return valid_ops[int(np.nonzero(np.random.multinomial(1, prob))[0][0])]

    # True for merge ops that combine several inputs.
    # NOTE(review): the list literal below is unterminated (entries and the
    # closing bracket were lost).
    def has_multi_inputs(op):
        return op in [LayerTypes.MERGE_WITH_CAT,
            LayerTypes.MERGE_WITH_SUM, LayerTypes.MERGE_WITH_AVG,
            LayerTypes.MERGE_WITH_MUL, LayerTypes.MERGE_WITH_CAT_PROJ,

class LayerInfo(dict):
    """Description of one layer, stored as a plain dict so it serializes as JSON.

    Keys: 'id', 'inputs', 'operations', 'down_sampling', 'stop_gradient',
    'aux_weight', 'is_candidate', and 'extra_dict' only when it is a non-empty
    dict.

    NOTE(review): this class appears to have lost lines. Each accessor is a
    getter/setter pair with the same name and no decorator. Code in this file
    uses attribute syntax (``info.id = next_id``, ``self[-1].merge_op == ...``),
    so ``@property`` / ``@<name>.setter`` decorators were presumably stripped;
    as written, every second ``def`` silently replaces the first.
    ``from_str``, ``from_json_loads`` and ``is_input`` take no ``self`` and are
    called through the class, so they were presumably ``@staticmethod``s.
    Restore from version control.
    """

    def __init__(self, layer_id, inputs=[], operations=[],
            down_sampling=0, stop_gradient=0, aux_weight=0,
            is_candidate=0, extra_dict=None):
        # NOTE(review): the docstring delimiters (and presumably an "Args:"
        # heading) around the parameter description below are missing, so
        # these lines do not parse as written.
        # NOTE(review): ``inputs=[]`` / ``operations=[]`` are mutable defaults
        # stored directly in the dict. Every LayerInfo built without explicit
        # lists shares the same list objects, so in-place edits made elsewhere
        # (e.g. ``del info.inputs[idx]``) would leak between them. Consider
        # ``None`` defaults.
            layer_id: a unqiue non-negative int that is the id of the layer.
                When overflows happens, an assertion failure will be triggered

            inputs: a list of layer_ids that are the inpu of the layer
            operations: a list of LayerType.<operations> on each of the input
            down_sampling: int/bool. whether this layer will dod down sampling.
            stop_gradient: int or list of int. Each int is 0 or 1, representing whether
                a stop_gradient call is done first to each of the input. If the value is
                just an int, then the same value is applied to all inputs.
            aux_weight: float auxiliary weight on the layer. This is used to mark where the
                prediction happens. E.g., the final layer always should have positive weight.
            is_candidate: int. Candidate id. This represents which candidate this layer is from.
                candidate==0 means the real network, >0 means it is from some hallucination.
        def _assert_is_list(v):
            assert isinstance(v, list), v
        def _assert_is_int_or_bool(v):
            assert isinstance(v, int) or isinstance(v, bool)
        # NOTE(review): the two helpers above are not called in the visible code.
        assert isinstance(aux_weight, float) or isinstance(aux_weight, int)
        # ``operations`` holds one op per input, optionally followed by a
        # trailing merge op.
        if len(operations) != len(inputs) + 1:
            assert len(inputs) == len(operations)
        if layer_id is not None:
            assert layer_id >=0, 'Hack failed, we have negative layer idx meow'

        self['id'] = layer_id
        self['inputs'] = inputs
        self['operations'] = operations
        self['down_sampling'] = down_sampling
        self['stop_gradient'] = stop_gradient
        self['aux_weight'] = aux_weight
        self['is_candidate'] = is_candidate
        # extra_dict is only stored when it is a non-empty dict.
        if isinstance(extra_dict, dict) and len(extra_dict) > 0:
            self['extra_dict'] = extra_dict

    # The accessor pairs below read and write the dict key of the same name.
    def id(self):
        return self['id']

    def id(self, val):
        self['id'] = val

    def inputs(self):
        return self['inputs']

    def inputs(self, val):
        self['inputs'] = val

    def operations(self):
        return self['operations']

    def operations(self, val):
        self['operations'] = val

    def down_sampling(self):
        return self['down_sampling']

    def down_sampling(self, val):
        self['down_sampling'] = val

    def stop_gradient(self):
        return self['stop_gradient']

    def stop_gradient(self, val):
        self['stop_gradient'] = val

    def aux_weight(self):
        return self['aux_weight']

    def aux_weight(self, val):
        self['aux_weight'] = val

    def is_candidate(self):
        return self['is_candidate']

    def is_candidate(self, val):
        self['is_candidate'] = val

    # Trailing merge op, or MERGE_WITH_NOTHING when ``operations`` has only
    # one op per input.
    def merge_op(self):
        if len(self['inputs']) == len(self['operations']):
            return LayerTypes.MERGE_WITH_NOTHING
        return self['operations'][-1]

    # NOTE(review): as visible this setter does not parse, because the ``if``
    # has no body. Presumably it appended ``val`` when there was no merge op
    # yet, and overwrote the last op otherwise.
    def merge_op(self, val):
        if len(self['inputs']) == len(self['operations']):
        self['operations'][-1] = val

    # Per-input ops, i.e. ``operations`` without any trailing merge op.
    def input_ops(self):
        return self['operations'][:len(self['inputs'])]

    # NOTE(review): an ``else:`` appears to be missing before the last line.
    # As indented, both assignments run when ``len(curr) <= 1``. Because
    # ``self['operations']`` is then ``val`` itself, the slice assignment
    # duplicates val's last element (e.g. [1, 2] -> [1, 2, 2]). When
    # ``len(curr) > 1`` nothing is assigned at all.
    def input_ops(self, val):
        curr = self.get('operations', [])
        if len(curr) <= 1:
            # there is no merge_op before
            self['operations'] = val
            self['operations'][:-1] = val

    def extra_dict(self):
        return self.get('extra_dict', None)

    def extra_dict(self, val):
        self['extra_dict'] = val

    # Parse a JSON string (e.g. from to_str) back into a LayerInfo.
    def from_str(ss):
        d = json.loads(ss)
        return LayerInfo.from_json_loads(d)

    # Serialize the underlying dict to a JSON string.
    def to_str(self):
        return json.dumps(self)

    # Build a LayerInfo from a decoded JSON dict. The 'id' key is renamed to
    # 'layer_id' inside ``d`` itself, so the caller's dict is modified.
    def from_json_loads(d):
        d['layer_id'] = d.pop('id', None)
        return LayerInfo(**d)

    # True for layers that have no inputs (the network's input layers).
    def is_input(info):
        return len(info.inputs) == 0

class LayerInfoList(list):
    # Separator between per-layer JSON strings in the deprecated string
    # format accepted by from_str.
    DELIM = '=^_^='
    # NOTE(review): LOG_DELIM is not referenced in the visible code.
    LOG_DELIM = '\n'
    # Number of trailing per-layer entries in a to_seq row: the merge op
    # followed by n_extra_dim - 1 flags.
    n_extra_dim = 4
    # NOTE(review): _create_info_merge / _finalize_info_merge read
    # LayerInfoList.ORIG_IDX_IN_MERGE_HALLU and HALLU_IDX_IN_MERGE_HALLU,
    # which are not defined here; confirm they were not lost.


    def __init__(self, *args, **kwargs):
        super(LayerInfoList, self).__init__(*args, **kwargs)

    def to_str(self):
        """Serialize the whole list (each layer is a dict) as one JSON string."""
        encoded = json.dumps(self)
        return encoded

    def num_inputs(self):
        """Count the layers that take no inputs (the network's input layers)."""
        return sum(1 for info in self if LayerInfo.is_input(info))

    def master(self):
        """Return this list itself.

        NOTE(review): no caller is visible in this file; confirm whether callers
        expect a method or a property here.
        """
        return self

    def is_end_merge_sum(self):
        """Return True when the final layer merges its inputs by sum or average."""
        final_merge = self[-1].merge_op
        return final_merge in (LayerTypes.MERGE_WITH_SUM, LayerTypes.MERGE_WITH_AVG)

    def is_end_merge_cat(self):
        """Return True when the final layer merges by concatenation (plain or projected)."""
        final_merge = self[-1].merge_op
        return final_merge in (LayerTypes.MERGE_WITH_CAT, LayerTypes.MERGE_WITH_CAT_PROJ)

    def sample_cat_hallucinations(self, layer_ops, merge_ops,
        prob_at_layer=None, min_num_hallus=1, hallu_input_choice=None):
        """Sample candidate layers (hallucinations) for a network ending in a concat.

        Args:
            layer_ops: op codes from which each candidate's per-input ops are drawn.
            merge_ops: op codes from which each candidate's merge op is drawn.
            prob_at_layer: probability of having input from a layer. None is
                translated to a default, which samples a layer proportional to
                its ch_dim. The ch_dim is computed using self, as we assume the
                last op is cat, and the cat determines the ch_dim. Only the
                first ``len(self) - 1`` entries are used; a longer vector is cut
                and renormalized.
            min_num_hallus: number of candidates to sample.
            hallu_input_choice: not used by this method.

        Returns:
            A list of ``min_num_hallus`` LayerInfo. Each has two distinct inputs
            drawn from ``self[:-1]`` and ``layer_id`` equal to the final layer's id.
        """
        assert self[-1].merge_op == LayerTypes.MERGE_WITH_CAT
        n_inputs = self.num_inputs()
        n_final_merge = len(self[-1].inputs)

        if prob_at_layer is None:
            prob_at_layer = np.ones(len(self) - 1)
            prob_at_layer[:n_inputs-1] = n_final_merge
            prob_at_layer[n_inputs-1] = n_final_merge * 1.5
            prob_at_layer = prob_at_layer / np.sum(prob_at_layer)
        prob_at_layer = np.asarray(prob_at_layer, dtype=float)
        assert len(prob_at_layer) >= len(self) - 1
        if len(prob_at_layer) > len(self) - 1:
            logger.warn("sample cell hallu cuts the prob_at_layer to len(info_list) - 1")
            prob_at_layer = prob_at_layer[:len(self)-1]
            # The cut vector no longer sums to 1, which np.random.choice rejects.
            prob_at_layer = prob_at_layer / np.sum(prob_at_layer)

        # choose inputs
        n_hallu_inputs = 2
        l_hallu = []
        for _ in range(min_num_hallus):
            # replace=False: the inputs of one candidate are distinct layers.
            in_idxs = np.random.choice(list(range(len(prob_at_layer))),
                size=n_hallu_inputs, replace=False, p=prob_at_layer)
            in_ids = [self[idx].id for idx in in_idxs]
            main_ops = list(map(int, np.random.choice(layer_ops, size=n_hallu_inputs)))
            merge_op = int(np.random.choice(merge_ops))
            hallu = LayerInfo(layer_id=self[-1].id, inputs=in_ids,
                operations=main_ops + [merge_op])
            # Previously each sampled candidate was built and then dropped, so
            # the method always returned an empty list.
            l_hallu.append(hallu)
        return l_hallu

    # Adds the candidate layers of ``l_h`` in front of the final (concat)
    # layer. Each candidate is deep-copied and given a fresh id (one past the
    # current maximum) and a fresh candidate number. A MERGE_WITH_SUM layer
    # with ``hallu_gate_layer`` on each of ``final_merge_inputs`` is built, and
    # an IDENTITY op is spliced in before the final layer's merge op.
    # NOTE(review): this method appears truncated and does not parse as shown:
    #   * the signature stops after ``l_h,``; ``final_merge_op``,
    #     ``stop_gradient_val`` and ``hallu_gate_layer`` are used below but
    #     never bound;
    #   * the loop over ``l_h`` never appends to ``l_insert`` or
    #     ``final_merge_inputs``, so both stay empty as written;
    #   * the ``LayerInfo(...)`` call for the merge layer is cut off, and
    #     nothing visibly adds the merge layer's id to the final layer's inputs.
    # Restore from version control rather than reconstructing by hand.
    def add_cat_hallucinations(self, l_h,
        assert final_merge_op == LayerTypes.MERGE_WITH_SUM
        assert self[-1].merge_op == LayerTypes.MERGE_WITH_CAT
        l_info = self
        next_id = 0
        max_candidate = 0
        for info in l_info:
            next_id = max(next_id, info.id)
            max_candidate = max(max_candidate, info.is_candidate)
        # NOTE(review): this discards the max_candidate computed just above, so
        # candidate numbers restart at 1 (add_sum_hallucinations keeps the
        # maximum instead). Confirm this is intended.
        max_candidate = 0
        l_insert = []
        final_merge_inputs = []
        for h in l_h:
            next_id += 1
            max_candidate += 1
            info = copy.deepcopy(h)
            info.id = next_id
            info.is_candidate = max_candidate
            info.stop_gradient = stop_gradient_val
        next_id += 1
        ops = [hallu_gate_layer] * len(final_merge_inputs) + [LayerTypes.MERGE_WITH_SUM]
        info = LayerInfo(layer_id=next_id, inputs=final_merge_inputs,
        l_info[-1].operations[-1:-1] = [LayerTypes.IDENTITY]
        l_info[-1:-1] = l_insert
        merge_pt = l_info[-1]
        # The final layer must carry scalar (not per-input) flags.
        assert not isinstance(merge_pt.stop_gradient, list), \
            "Merge point has invalid stop_gradient: {}".format(merge_pt)
        assert not isinstance(merge_pt.down_sampling, list), \
            "Merge point has invalid down_sampling: {}".format(merge_pt)
        return l_info

    def select_cat_hallucination(self,
            selected, contained=None, hallu_indices=None,
            l_fs_ops=None, l_fs_omega=None):
        """Finalize concat-style candidates: drop unselected ones, keep the rest.

        Args:
            selected: int or list of candidate ids to keep.
            contained: mapping candidate id -> (start, end) as returned by
                contained_hallucination(); computed when None.
            hallu_indices: candidate ids in layer order; computed with
                sorted_hallu_indices when None.
            l_fs_ops, l_fs_omega: not referenced in the visible body.

        Returns:
            self, modified in place.

        NOTE(review): the body appears truncated and does not parse as shown.
        * The keep branch of the loop (which presumably fills ``to_keep`` and
          clears candidate flags) has lost its ``else:``. As written, the flag
          reset runs right after ``l_info[start:end] = []`` on whatever layers
          shifted into that slice.
        * The ``n_to_keep > 1`` operations expression is cut off.
        * The ``n_to_keep == 0`` branch is empty.
        * The ``n_to_keep == 1`` rewiring line is missing its callee, presumably
          ``LayerInfoList._rewire_inputs_post_add`` (matching arguments and
          return shape).
        * ``del_merge_hallu`` and ``del_at_final_cat`` are never called in the
          visible code.
        """
        assert self[-1].merge_op == LayerTypes.MERGE_WITH_CAT, self[-1].merge_op
        l_info = self
        if isinstance(selected, int):
            selected = [selected]
        if contained is None:
            contained = l_info.contained_hallucination()
        n_contained = len(contained)
        if n_contained == 0:
            return l_info
        if hallu_indices is None:
            hallu_indices = LayerInfoList.sorted_hallu_indices(contained)
        to_keep = []
        # Walk candidates from last to first so earlier (start, end) ranges
        # stay valid while later ones are deleted.
        for h_idx in reversed(hallu_indices):
            start, end = contained[h_idx]
            if not h_idx in selected:
                l_info[start:end] = []
                for x in l_info[start:end]:
                    x.is_candidate = 0
                    x.stop_gradient = 0

        n_to_keep = len(to_keep)
        # The layer just before the final concat is expected to be the merge of
        # the candidates, feeding the concat as its last input.
        old_id = l_info[-2].id
        # NOTE(review): the message prints old_id twice; the right-hand value
        # was probably meant to be l_info[-1].inputs[-1].
        assert old_id == l_info[-1].inputs[-1], '{} != {}'.format(old_id, l_info[-2].id)

        def del_merge_hallu():
            del l_info[-2]

        # Remove the last input of the final concat, its per-input op and any
        # per-input flags.
        def del_at_final_cat():
            del l_info[-1].inputs[-1]
            del l_info[-1].operations[-2]
            if isinstance(l_info[-1].down_sampling, list):
                assert len(l_info[-1].down_sampling) == len(l_info[-1].inputs) + 1
                del l_info[-1].down_sampling[-1]
            if isinstance(l_info[-1].stop_gradient, list):
                assert len(l_info[-1].stop_gradient) == len(l_info[-1].inputs) + 1
                del l_info[-1].stop_gradient[-1]

        if n_to_keep > 1:
            for keep_id in to_keep:
                assert keep_id in l_info[-2].inputs, \
                    '{} not in {}'.format(keep_id, l_info[-2].inputs)
            l_info[-2].inputs = to_keep
            l_info[-2].operations = ([LayerTypes.GATED_LAYER] * len(to_keep) +

        elif n_to_keep == 0:

        elif n_to_keep == 1:
            new_id = to_keep[0]
            assert new_id in l_info[-2].inputs, \
                '{} not in {}'.format(new_id, l_info[-2].inputs)
            l_info[-1].inputs, l_info[-1].stop_gradient = \
                    { old_id : new_id }, l_info[-1].inputs, l_info[-1].stop_gradient)
            l_info[-1].operations[-2] = LayerTypes.GATED_LAYER
        return l_info

    def _id_to_idx(self):
        id_to_idx = dict()
        for idx, info in enumerate(self):
            id_to_idx[info.id] = idx
        return id_to_idx

    def _distance_to_idx(self, t_idx, id_to_idx, max_dist=None):
        max_dist = max_dist if max_dist is not None else t_idx + 1
        l_dists = [max_dist + 1] * (t_idx + 1)
        is_fixed = [False] * (t_idx + 1)
        l_dists[-1] = 0
        is_fixed[-1] = True
        queue = [t_idx]
        qidx = 0
        while qidx < len(queue):
            idx = queue[qidx]
            dist = l_dists[idx]
            for in_id in self[idx].inputs:
                in_idx = id_to_idx[in_id]
                if not is_fixed[in_idx]:
                    is_fixed[in_idx] = True
                    l_dists[in_idx] = dist + 1
                    if l_dists[in_idx] < max_dist:
            qidx += 1
        return l_dists

    def _is_end_of_cell(self, idx):
        Given an index (idx) in the current layer info list,
        determine whether it is an end of a cell in the
        starting of a macro search.

        This method should only be used by a master l_info,
        i.e., a macro search master, or a cell search root/master.

        This method assumes that the original layers have id smaller than
        the last id, which is the last layer of the original
        network as well. Finally, the original layer ids also
        convey the depth of layers.

        (Known bug: we only use the connection pattern and id
        to check this. Ideally one should know this from
        detecting repeatable patterns of directly using
        the original construction function. This is not a
        concern for now as the seed cells are simple).
        l_info = self
        info = l_info[idx]
        if info.id > l_info[-1].id:
            # the info is inserted after seed model
            return False

        # Check input connection patterns.

        def seeded_inputs(_info):
            info_id = _info.id
            return [
                (info_id - in_id, in_op) \
                    for in_id, in_op in zip(_info.inputs, _info.operations) \
                    if in_id < info_id
        tar_id_ops = seeded_inputs(l_info[-1])
        id_ops = seeded_inputs(info)
        return id_ops == tar_id_ops

    def _sample_output_locations(
            self, min_num_hallus, cell_based):
        """Choose the layers (``x_out``) that sampled candidates will feed.

        Intended contract, per the original doc line below: return a list of
        ``min_num_hallus`` indices into this list. For cell search every
        candidate targets the last layer. Otherwise the indices are spread
        over the layers accepted by ``_is_end_of_cell``: each eligible layer is
        used ``min_num_hallus // n_eligible`` times, plus a random distinct
        subset for the remainder.

        NOTE(review): the body appears truncated and does not parse as shown.
        The original docstring delimiters are gone. The ``else:`` before
        ``# compute the eligible layers.`` is missing; as indented, the
        cell-based result would be overwritten. The closing ``]`` and ``)`` of
        the ``indices`` and ``remainder`` expressions are missing. If no layer
        is eligible, ``min_num_hallus // n_eligible`` divides by zero. ``LT``
        is unused.
        """
        A list of indices in the layer info list that will be x_out
        l_info = self
        n_layers = len(l_info)
        n_inputs = l_info.num_inputs()
        LT = LayerTypes

        if cell_based:
            l_x_out = [n_layers - 1] * min_num_hallus
            # compute the eligible layers.
            indices = [
                idx for idx in range(n_inputs, n_layers) \
                    if l_info._is_end_of_cell(idx)
            #logger.info("The eligible ones are {}".format(indices))
            n_eligible = len(indices)
            repeats = min_num_hallus // n_eligible
            remainder = np.random.choice(
                indices, min_num_hallus % n_eligible, replace=False
            l_x_out = indices * repeats + list(remainder)
        return l_x_out

    def _sample_all_input_ids(self, l_x_out, cell_based):
        # NOTE(review): this method appears truncated and does not parse as
        # shown.
        #   * The docstring delimiters (and presumably its Args:/Returns:
        #     headings) are missing.
        #   * An ``else:`` seems lost before the macro-search part
        #     (``# This is the same as ...``); as indented, ``in_id_pool`` is
        #     reset right after the cell-search assignment.
        #   * The bodies of ``if n_origs == 0:`` and ``if in_same_cell:`` are gone.
        #   * Nothing appends to ``l_inputs``, so [] would always be returned.
        # The commented-out ``_sample_input_ids`` below has a similar loop
        # (append the id, ``break`` after two cell ends) that may help restore it.
        1. For cell search, inputs are either in the same
        cell or have lower id. So all layers are possible inputs.
        2. For macro search, we mimic the behavior, and limit
        the inputs to be in the same cell or the output of
        the two previous cells.

        l_x_out: a list of indices in the layer info list that are to be x_out
        cell_based (bool) : whether the NAS is for cell or macro search.

        A list of list of layer id that are to be input ids of the x_out in l_x_out.
        l_info = self
        l_inputs = []
        for out_idx in l_x_out:
            if cell_based:
                # Cell search: every earlier layer is a possible input.
                in_id_pool = [info.id for info in l_info[:out_idx]]
                # This is the same as cell based search case above,
                # if the cell has two inputs (prev and prev-prev cell output).
                n_origs = 2
                in_same_cell = True
                in_id_pool = []
                # Scan backwards until two cell ends have been passed.
                for idx in reversed(range(out_idx)):
                    if l_info._is_end_of_cell(idx):
                        in_same_cell = False
                        n_origs -= 1
                        if n_origs == 0:
                    if in_same_cell:
        return l_inputs

    # def _sample_input_ids(self, l_x_out, cell_based):
    #     """
    #     Args:
    #     l_x_out: a list of indices in the layer info list that are to be x_out
    #     cell_based (bool) : whether the NAS is for cell or macro search.

    #     Returns:
    #     A list of list of layer id that are to be input ids of the x_out in l_x_out.
    #     """
    #     l_info = self
    #     n_inputs = l_info.num_inputs()
    #     LT = LayerTypes

    #     # compute the current intensive operations.
    #     n_intensive_ops = 0
    #     for info in l_info[n_inputs:]:
    #         for op in info.operations:
    #             n_intensive_ops += int(
    #                 op in [
    #                     LT.SEPARABLE_CONV_3_2, LT.SEPARABLE_CONV_3_2,
    #                     LT.SEPARABLE_CONV_7_2, LT.CONV_1, LT.CONV_3,
    #                 ]
    #             )
    #     # n_ins_per_out is proportional to number of existing intensive
    #     # operations, and is at least 2 (unless not possible)
    #     n_ins_per_out = max(
    #         2, int(0.5 + float(n_intensive_ops) / len(l_x_out)))
    #     l_inputs = []

    #     # 1. For cell search, inputs are either in the same
    #     # cell or have lower id. So all layers are possible inputs.
    #     # 2. For macro search, we mimic the behavior, and limit
    #     # the inputs to be in the same cell or the output of
    #     # the two previous cells.
    #     max_orig_id = l_info[-1].id
    #     for out_idx in l_x_out:
    #         # Find the most recent two orig idx.
    #         # These are the output of cells.
    #         n_origs = 2
    #         in_same_cell = True
    #         in_id_pool = []
    #         for in_idx in reversed(range(out_idx)):
    #             infoid = l_info[in_idx].id
    #             if infoid <= max_orig_id:
    #                 in_same_cell = False
    #                 n_origs -= 1
    #                 in_id_pool.append(infoid)
    #                 if n_origs == 0:
    #                     break
    #             if in_same_cell:
    #                 in_id_pool.append(infoid)
    #         n_x_ins = min(n_ins_per_out, len(in_id_pool))
    #         inputs = [
    #             int(_x) for _x in np.random.choice(
    #                 in_id_pool, n_x_ins, replace=False)
    #         ]
    #         l_inputs.append(inputs)
    #     return l_inputs

    # Samples one candidate layer per output location chosen by
    # ``_sample_output_locations``, with inputs from ``_sample_all_input_ids``.
    # When ``merge_ops`` is exactly [MERGE_WITH_WEIGHTED_SUM] (feature
    # selection), every input is paired with every op in ``layer_ops``.
    # NOTE(review): this method appears truncated and does not parse as shown:
    #   * the signature has lost its tail (``cell_based`` is used but is not a
    #     visible parameter), and the "Sample hallus" docstring has lost its
    #     quotes;
    #   * ``if len(inputs) == 0:`` has no body (presumably ``continue``);
    #   * an ``else:`` seems missing before the feature-selection lines, so as
    #     indented the sampled ops would always be overwritten;
    #   * ``hallu`` is never appended to ``hallus``, so [] is returned.
    def sample_sum_hallucinations(
            self, layer_ops, merge_ops, prob_at_layer=None,
            min_num_hallus=1, hallu_input_choice=None,
        Sample hallus
        l_info = self
        LT = LayerTypes
        do_feat_sel = (
            len(merge_ops) == 1 and
            merge_ops[0] == LT.MERGE_WITH_WEIGHTED_SUM)
        l_x_out = self._sample_output_locations(
            min_num_hallus, cell_based)
        l_inputs = self._sample_all_input_ids(l_x_out, cell_based)
        hallus = []
        for out_idx, inputs in zip(l_x_out, l_inputs):
            out_id = l_info[out_idx].id
            if len(inputs) == 0:
            if not do_feat_sel:
                # sample operations
                main_ops = list(map(int, np.random.choice(layer_ops, len(inputs))))
                merge_op = int(np.random.choice(merge_ops))
                # feature selection will form all ops
                # op1,..., opk, op1, ..., opk ,....
                main_ops = list(layer_ops) * len(inputs)
                merge_op = LT.MERGE_WITH_WEIGHTED_SUM
                # in1,..., in1, in2, ..., in2, ...
                inputs = [x for l_x in zip(*[inputs] * len(layer_ops)) for x in l_x]
            hallu = LayerInfo(layer_id=out_id, inputs=inputs,
                operations=main_ops + [ merge_op ])
        return hallus

    def select_sum_hallucination(self,
            selected, contained=None, hallu_indices=None,
            l_fs_ops=None, l_fs_omega=None):
        """Finalize sum-style candidates: drop unselected ones, keep the rest.

        Candidates are processed from the last one backwards. For each, the
        layer that consumes the candidate's final layer is located. A
        candidate is removed, together with that edge, when it is not in
        ``selected``, or when it is a weighted-sum (feature-selection)
        candidate whose ``l_fs_ops`` entry is empty.

        Intended handling of kept candidates (see the note about the lost
        ``else:``):
        * their candidate and stop-gradient flags are cleared;
        * for feature-selection candidates, only the inputs listed in
          ``l_fs_ops`` are kept, the merge op becomes MERGE_WITH_CAT_PROJ, and
          the selection and ``l_fs_omega`` weights are stored in ``extra_dict``;
        * the consumer's op on the candidate becomes GATED_LAYER.

        Args:
            selected: int or list of candidate ids to keep.
            contained: candidate id -> (start, end); computed when None.
            hallu_indices: candidate ids in layer order; computed when None.
            l_fs_ops, l_fs_omega: per-candidate lists, aligned with
                ``hallu_indices``, of selected input positions and their weights.

        Returns:
            self, modified in place.

        NOTE(review): the body appears truncated and does not parse as shown.
        * The consumer search has no ``break``, so ``idx`` ends at the last
          layer regardless of where the consumer was found, despite the
          "exact one user" comment.
        * The ``to_remove = (`` expression is unclosed.
        * An ``else:`` appears lost before the keep branch; as written, the
          flag reset and feature-selection code run right after removal, on
          shifted layers.
        * The two list comprehensions and the ``operations = (`` expression
          are missing their closing brackets.
        """
        l_info = self
        if isinstance(selected, int):
            selected = [selected]
        if contained is None:
            contained = l_info.contained_hallucination()
        n_contained = len(contained)
        if n_contained == 0:
            return l_info
        if hallu_indices is None:
            hallu_indices = LayerInfoList.sorted_hallu_indices(contained)

        for h_idx_idx in reversed(range(len(hallu_indices))):
            h_idx = hallu_indices[h_idx_idx]
            start, end = contained[h_idx]
            h_layer_id = l_info[end-1].id
            found = False
            # find the exact one user of this candidate.
            for idx in range(end, len(l_info)):
                if h_layer_id in l_info[idx].inputs:
                    found = True
            assert found, "Did not find id {}".format(h_layer_id)
            to_remove = (
                (not h_idx in selected) or
                (l_info[end-1].merge_op == LayerTypes.MERGE_WITH_WEIGHTED_SUM and
                 len(l_fs_ops[h_idx_idx]) == 0)
            if to_remove:
                l_info[idx] = LayerInfoList._remove_connection_from_id(
                    l_info[idx], h_layer_id)
                l_info[start:end] = []
                for x in l_info[start:end]:
                    x.is_candidate = 0
                    x.stop_gradient = 0
                # feature selection case:
                hallu_info = l_info[end-1]
                if hallu_info.merge_op == LayerTypes.MERGE_WITH_WEIGHTED_SUM:
                    fs_indices = l_fs_ops[h_idx_idx]
                    fs_omega = l_fs_omega[h_idx_idx]
                    assert len(fs_indices) == len(fs_omega), \
                        'Invalid feat select info i={} omega={}'.format(
                            fs_indices, fs_omega)
                    new_inputs = [
                        hallu_info.inputs[in_idx] for in_idx in fs_indices
                    new_operations = [
                        hallu_info.operations[in_idx] for in_idx in fs_indices
                    l_info[end-1].inputs = new_inputs
                    l_info[end-1].operations = (
                        #new_operations + [LayerTypes.MERGE_WITH_WEIGHTED_SUM]
                        new_operations + [LayerTypes.MERGE_WITH_CAT_PROJ]
                    ed = hallu_info.extra_dict
                    ed = dict() if ed is None else ed
                    ed['ops_ids'] = list(map(int, fs_indices))
                    ed['fs_omega'] = list(map(float, fs_omega))
                    l_info[end-1].extra_dict = ed
                # fix future reference to hallu
                for in_idx, in_id in enumerate(l_info[idx].inputs):
                    if in_id == h_layer_id:
                        l_info[idx].operations[in_idx] = LayerTypes.GATED_LAYER
        return l_info

    # Inserts each candidate layer of ``l_h`` right before the layer it feeds;
    # ``hallu.id`` names that target layer. The copy gets a fresh id (one past
    # the current maximum) and the next candidate number, and a
    # ``hallu_gate_layer`` op is spliced in before the target's merge op.
    # NOTE(review): this method appears truncated and does not parse as shown:
    #   * the signature stops after ``l_h,``; ``stop_gradient_val`` and
    #     ``hallu_gate_layer`` are used below but never bound, and the
    #     "Add hallus" docstring has lost its quotes;
    #   * nothing visibly adds the new id to the target's ``inputs``, so the
    #     target gains an op without a matching input.
    def add_sum_hallucinations(
            self, l_h,
        Add hallus
        l_info = self
        next_id = 0
        max_candidate = 0
        for info in l_info:
            next_id = max(next_id, info.id)
            max_candidate = max(max_candidate, info.is_candidate)

        # Targets are searched from the previous target onwards, so ``l_h`` is
        # presumably ordered by target position.
        last_o_l_info_idx = 0
        for hallu in l_h:
            o_id = hallu.id
            n_layers = len(l_info)
            for idx in range(last_o_l_info_idx, n_layers):
                if l_info[idx].id == o_id:
                    last_o_l_info_idx = idx
            assert l_info[last_o_l_info_idx].id == o_id, \
                "info {} does not find id {}; start at {}".format(
                    l_info, o_id, last_o_l_info_idx)

            next_id += 1
            max_candidate += 1

            info = copy.deepcopy(hallu)
            info.id = next_id
            info.is_candidate = max_candidate
            info.stop_gradient = stop_gradient_val

            l_info[last_o_l_info_idx].operations[-1:-1] = [hallu_gate_layer]
            merge_pt = l_info[last_o_l_info_idx]
            assert not isinstance(merge_pt.stop_gradient, list), \
                "Merge point has invalid stop_gradient: {}".format(merge_pt)
            assert not isinstance(merge_pt.down_sampling, list), \
                "Merge point has invalid down_sampling: {}".format(merge_pt)
            # Insert the candidate immediately before its target.
            l_info[last_o_l_info_idx:last_o_l_info_idx] = [info]
        return l_info

    def contained_hallucination(self):
        """Locate the contiguous layer ranges that belong to candidates (hallucinations).

        Layers of the real network (``is_candidate == 0``) are skipped. Each
        candidate's layers are assumed to be contiguous; if a candidate id
        appears in several separate runs, the last run wins.

        Returns:
            hallu_idx_to_range : a dict mapping from candidate_id to (start, end)
                for candidate_id > 0. self[start:end] contains exactly the
                layers whose is_candidate == candidate_id.
        """
        l_info = self
        n_layers = len(l_info)
        prev_candidate = 0
        hallu_idx_to_range = {}
        for idx, info in enumerate(l_info):
            candidate = info.is_candidate
            if candidate == 0:
                prev_candidate = 0
                # Real-network layer, not part of any candidate. Without this
                # skip, candidate 0 could be recorded with an unbound or stale
                # ``start``.
                continue

            is_start = prev_candidate != candidate
            is_last = ((idx + 1 == n_layers) or (l_info[idx+1].is_candidate != candidate))
            prev_candidate = candidate

            if is_start:
                start = idx
            if is_last:
                hallu_idx_to_range[candidate] = (start, idx+1)
        return hallu_idx_to_range

    def sorted_hallu_indices(hallu_idx_to_range):
        """Return the candidate ids of ``hallu_idx_to_range`` ordered by range start."""
        start_of = {h: hallu_idx_to_range[h][0] for h in hallu_idx_to_range}
        return sorted(start_of, key=start_of.get)

    # Builds the merge layer (id ``next_id``) that combines the original tensor
    # ``o_id`` and the candidate ``h_id``. The original enters via IDENTITY and
    # the candidate via ``hallu_gate_layer``; both are merged with
    # ``final_merge_op``. Returns a one-element list.
    # NOTE(review): the signature is cut off after ``is_candidate,``, so
    # ``final_merge_op`` and ``hallu_gate_layer`` are never bound, and the
    # docstring has lost its quotes. It also relies on
    # LayerInfoList.ORIG_IDX_IN_MERGE_HALLU / HALLU_IDX_IN_MERGE_HALLU, which
    # are not defined in the visible class body.
    def _create_info_merge(next_id, h_id, o_id, aux_weight, is_candidate,
        Form the LayerInfo for the merge operation between hallu of id h_id and the original
        tensor of id o_id (out_id). The new LayerInfo will have info.id == next_id.
        Return a list of layers used for merging
        Note any change to this function need to be mirrored in _finalize_info_merge
        inputs = [None] * 2
        inputs[LayerInfoList.ORIG_IDX_IN_MERGE_HALLU] = o_id
        inputs[LayerInfoList.HALLU_IDX_IN_MERGE_HALLU] = h_id
        operations = [LayerTypes.IDENTITY] * 2 + [final_merge_op]
        operations[LayerInfoList.HALLU_IDX_IN_MERGE_HALLU] = hallu_gate_layer
        info = LayerInfo(next_id, inputs=inputs, operations=operations,
            aux_weight=aux_weight, is_candidate=is_candidate)
        return [info]

    def _finalize_info_merge(l_info, start, end):
        """Make the candidate layers ``l_info[start:end]`` a permanent part of the model.

        Given an l_info and the start and end index of a hallu, clear the
        candidate flag and stop-gradient of every layer in the range. Then make
        the last layer of the range a gated merge: the original input gets
        ANTI_GATED_LAYER and the candidate input gets GATED_LAYER.

        Note that any change to this function needs to be mirrored in
        _create_info_merge.

        Args:
            l_info: the LayerInfoList to update in place.
            start, end: half-open index range of the candidate, e.g. a value
                returned by ``contained_hallucination()``.

        Returns:
            ``l_info`` (modified in place).
        """
        for x in l_info[start:end]:
            x.is_candidate = 0
            x.stop_gradient = 0
        merge_ops = l_info[end-1].operations
        merge_ops[LayerInfoList.ORIG_IDX_IN_MERGE_HALLU] = LayerTypes.ANTI_GATED_LAYER
        merge_ops[LayerInfoList.HALLU_IDX_IN_MERGE_HALLU] = LayerTypes.GATED_LAYER
        return l_info

    def _rewire_inputs_post_add(prev_id_to_new_id, inputs, stop_gradient):
            Given a dictionary mapping from old id to the new id,
            we swap the occurance of old_id in inputs to new id.
            Return: updated inputs, and stop_gradient
        if not isinstance(stop_gradient, list):
            stop_gradient = [stop_gradient] * len(inputs)
        for ini, inid in enumerate(inputs):
            if inid in prev_id_to_new_id:
                inputs[ini] = prev_id_to_new_id[inid]
                stop_gradient[ini] = 0
        sg_ref = stop_gradient[0]
        for sg in stop_gradient:
            if sg != sg_ref:
        if sg == sg_ref:
            stop_gradient = sg_ref
        return inputs, stop_gradient

    def _remove_connection_from_id(info, id_to_remove):
        if not id_to_remove in info.inputs:
            return info
        if isinstance(info.stop_gradient, list):
            assert len(info.stop_gradient) == len(info.inputs), \
                "Invalid info {}".format(info)
        if isinstance(info.down_sampling, list):
            assert len(info.down_sampling) == len(info.inputs), \
                "Invalid info {}".format(info)
        assert len(info.operations) == len(info.inputs) + 1, \
            "Invalid info {}".format(info)

        idx = 0
        while idx < len(info.inputs):
            if info.inputs[idx] == id_to_remove:
                del info.inputs[idx]
                del info.operations[idx]
                if isinstance(info.stop_gradient, list):
                    del info.stop_gradient[idx]
                if isinstance(info.down_sampling, list):
                    del info.down_sampling[idx]
            idx += 1
        return info

    def from_str(ss, delim=DELIM):
        """Parse a LayerInfoList from its string form.

        Two formats are accepted:
        * the current one, a JSON list of layer dicts as produced by ``to_str``;
        * a deprecated one, in which each layer is a JSON object string and
          layers are joined by ``delim``.

        Surrounding whitespace (e.g. a trailing newline read from a file) is
        ignored for both formats. Previously the format check ran on the
        unstripped string, so a legacy string with a trailing newline was sent
        to ``json.loads`` and failed, and an empty string raised IndexError.
        """
        ss = ss.strip()
        # deprecation protection: legacy delimiter-joined format
        if ss.startswith('{') and ss.endswith('}') and delim in ss:
            return LayerInfoList(map(LayerInfo.from_str, ss.split(delim)))
        l_dict = json.loads(ss)
        return LayerInfoList.from_json_loads(l_dict)

    def from_json_loads(l_dict):
        """Build a LayerInfoList from already-decoded JSON (a list of layer dicts).

        Each dict is converted with LayerInfo.from_json_loads, which renames its
        'id' key to 'layer_id' in place.
        """
        return LayerInfoList([LayerInfo.from_json_loads(d) for d in l_dict])

    def to_seq(l_info):
        # NOTE(review): this method appears truncated and does not parse as
        # shown. The docstring below has lost its quotes. The
        # ``input_seq[idx+1:]`` list is cut off after the candidate flag
        # (presumably the stop_grad and down_sampling entries from the format
        # comment follow). Nothing adds ``input_seq`` to ``seq``, so as written
        # this would return [].
        # Row layout for layer ``idx`` (idx + n_extra_dim entries): slot j < idx
        # holds the op applied to the layer at position j when that layer is an
        # input (NOT_EXIST otherwise); slot idx holds the merge op; the flags
        # listed in the format comment follow.
        Transform the layer info list into a sequence to be parsed by recurrent structures.
        n_layers = len(l_info)
        idx_to_id = list(map(lambda info : info.id, l_info))
        id_to_idx = dict(map(lambda idx : (idx_to_id[idx], idx), range(n_layers)))
        seq = []
        # format of appended info
        # 0 : merge type
        # 1 : is_candidate
        # 2 : stop_grad (currently always just 0 or 1)
        # 3 : down_sampling
        n_extra_dim = LayerInfoList.n_extra_dim
        for idx, info in enumerate(l_info):
            input_seq = [LayerTypes.NOT_EXIST] * (idx + n_extra_dim)
            n_inputs = len(info.inputs)
            for input_id, op in zip(info.inputs, info.operations[:n_inputs]):
                input_idx = id_to_idx[input_id]
                input_seq[input_idx] = op
            input_seq[idx] = info.merge_op
            input_seq[idx+1:] = [int(info.is_candidate > 0),
        return seq

    def seq_to_img_flag(seq, max_depth=128, make_batch=False):
        """Unpack a flat ``to_seq`` sequence into a square op image and a flag matrix.

        Row ``r`` of the sequence has ``r + n_extra_dim`` entries. The first
        ``r + 1`` (input ops and the merge op) go to ``img[r, :r + 1]``; the
        remaining ``n_extra_dim - 1`` go to ``flag[r, :]``.

        Args:
            seq: flat sequence as produced by ``to_seq``.
            max_depth: number of rows/columns allocated; a sequence describing
                more than ``max_depth`` layers does not fit (indexing fails).
            make_batch: if True, prepend a batch axis of size 1 to both outputs.

        Returns:
            (img, flag): int arrays of shape ``[max_depth, max_depth]`` and
            ``[max_depth, n_extra_dim - 1]`` (with a leading 1 if make_batch).
            Unused cells of ``img`` hold LayerTypes.NOT_EXIST; those of
            ``flag`` hold 0.
        """
        n_extra = LayerInfoList.n_extra_dim
        img = np.full([max_depth, max_depth], LayerTypes.NOT_EXIST, dtype=int)
        flag = np.zeros([max_depth, n_extra - 1], dtype=int)
        total = len(seq)
        row, pos = 0, 0
        while pos < total:
            merge_end = pos + row + 1        # input ops and the merge op
            row_end = pos + row + n_extra    # ... followed by the flags
            img[row, :row + 1] = seq[pos:merge_end]
            flag[row, :] = seq[merge_end:row_end]
            pos = row_end
            row += 1
        if make_batch:
            img = img[np.newaxis, ...]
            flag = flag[np.newaxis, ...]
        return img, flag

    def seq_to_hstr(seq, not_exist_str='--'):
        """Render a flat ``to_seq`` sequence as readable text, one line per layer.

        Each entry is printed as a zero-padded two-digit number, except
        LayerTypes.NOT_EXIST, which is printed as ``not_exist_str``. Line ``r``
        holds ``r + n_extra_dim`` entries, matching the row layout of
        ``to_seq``. (The ``map`` call had lost its iterable argument and
        closing parentheses.)
        """
        n_extra_dim = LayerInfoList.n_extra_dim

        def _fmt(x):
            # One entry: placeholder for "no connection", else a 2-digit code.
            return not_exist_str if x == LayerTypes.NOT_EXIST else '{:02d}'.format(x)

        line_len = n_extra_dim
        seq_len = len(seq)
        start = 0
        ss = []
        while start < seq_len:
            end = start + line_len
            ss.append(' '.join(_fmt(x) for x in seq[start:end]))
            start = end
            line_len += 1
        return '\n'.join(ss)

    def str_to_seq(ss):
        """Parse a serialized LayerInfoList and convert it to its flat ``to_seq`` form."""
        l_info = LayerInfoList.from_str(ss)
        return LayerInfoList.to_seq(l_info)