import os
import pickle
from collections import defaultdict, deque
from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from scipy.special import expit, logit
from torch import nn

from models.cnn_layers import CNN_LAYER_CREATION_FUNCTIONS, initialize_layers_weights, get_cnn_layer_with_names
from models.shared_base import *
from utils import get_logger, get_variable, keydefaultdict

logger = get_logger()


def node_to_key(node):
    idx, jdx, _type = node
    if isinstance(_type, str):
        return f'{idx}-{jdx}-{_type}'
    else:
        return f'{idx}-{jdx}-{_type.__name__}'


def dag_to_keys(dag):
    return [node_to_key(node) for node in dag]
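

# Illustrative sketch (hypothetical layer names): a dag node is an
# (idx, jdx, _type) triple, where _type is either a string layer name or a
# layer class, e.g.
#   node_to_key((0, 2, 'conv3x3'))                       -> '0-2-conv3x3'
#   dag_to_keys([(0, 2, 'conv3x3'), (1, 2, 'maxpool')])  -> ['0-2-conv3x3', '1-2-maxpool']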


class Architecture:
    """Represents some hyperparameters of the architecture requested.
    final_filter_size is the number of filters of the cell before the output layer.
    Each reduction filter doubles the number of filters (as it halves the width and height)
    There are num_modules modules stacked together.
    Each module except for the final one is made up of num_repeat_normal normal Cells followed by a reduction cell.
    The final layer doesn't have the reduction cell.
    """

    def __init__(self, final_filter_size, num_repeat_normal, num_modules):
        self.final_filter_size = final_filter_size
        self.num_repeat_normal = num_repeat_normal
        self.num_modules = num_modules
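

# A worked sketch of how CNN.__init__ expands an Architecture into a cell
# schedule (hypothetical values): Architecture(final_filter_size=128,
# num_repeat_normal=2, num_modules=2) yields, in forward order,
#   [('normal', 64), ('normal', 64), ('reducing', 128), ('normal', 128), ('normal', 128)]
# i.e. filter counts are derived backward from final_filter_size, halving at
# each reducing cell, and the list is then reversed.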


class CNN(SharedModel):
    """Represents a Meta-Convolutional network made up of Meta-Convolutional Cells.
       Paths through the cells can be selected and moved to the gpu for training and evaluation.

       Adapted from online code. need intense modification.

    """

    def __init__(self, args, corpus):
        """
                 # input_channels, height, width, output_classes, gpu, num_cell_blocks,
                 # architecture=Architecture(final_filter_size=768 // 2, num_repeat_normal=6, num_modules=3)):

        :param args: arguments
        :param corpus: dataset
        """
        super(CNN, self).__init__(args)
        self.args = args
        self.corpus = corpus

        architecture = Architecture(final_filter_size=args.cnn_final_filter_size,
                                    num_repeat_normal=args.cnn_num_repeat_normal,
                                    num_modules=args.cnn_num_modules)
        
        input_channels = args.cnn_input_channels
        self.height = args.cnn_height
        self.width = args.cnn_width
        self.output_classes = args.output_classes
        self.architecture = architecture

        self.output_height = self.height
        self.output_width = self.width
        self.num_cell_blocks = args.num_blocks

        self.cells = nn.Sequential()
        self.reduce_cells = nn.Sequential()
        self.normal_cells = nn.Sequential()

        self.gpu = torch.device("cuda:0") if args.num_gpu > 0 else torch.device('cpu')
        self.cpu_device = torch.device("cpu")

        self.dag_variables_dict = {}
        self.reducing_dag_variables_dict = {}

        last_input_info = _CNNCell.InputInfo(input_channels=input_channels, input_width=self.width)
        current_input_info = _CNNCell.InputInfo(input_channels=input_channels, input_width=self.width)

        # count connections
        temp_cell = _CNNCell(input_infos=[last_input_info, current_input_info],
                             output_channels=architecture.final_filter_size,
                             output_width=self.output_width, reducing=False, dag_vars=None,
                             num_cell_blocks=self.num_cell_blocks)

        self.all_connections = list(temp_cell.connections.keys())  # all possible connections

        self.dag_variables = torch.ones(len(self.all_connections), requires_grad=True, device=self.gpu)
        self.reducing_dag_variables = torch.ones(len(self.all_connections), requires_grad=True, device=self.gpu)

        for i, key in enumerate(self.all_connections):
            self.dag_variables_dict[key] = self.dag_variables[i]
            self.reducing_dag_variables_dict[key] = self.reducing_dag_variables[i]

        cells = [('normal', architecture.final_filter_size)] * architecture.num_repeat_normal
        current_filter_size = architecture.final_filter_size
        for module in range(architecture.num_modules - 1):
            cells.append(('reducing', current_filter_size))
            current_filter_size //= 2
            cells.extend([('normal', current_filter_size)] * architecture.num_repeat_normal)

        cells.reverse()

        for i, (cell_type, num_filters) in enumerate(cells):
            if cell_type == 'reducing':
                # Integer division keeps the spatial dims (and hence
                # conv_output_size) integral for the final view().
                self.output_height //= 2
                self.output_width //= 2
                reducing = True
            else:
                reducing = False
                assert cell_type == 'normal'

            dag_vars = self.reducing_dag_variables_dict if reducing else self.dag_variables_dict
            name = f'{i}-{cell_type}-{num_filters}'
            a_cell = _CNNCell(input_infos=[last_input_info, current_input_info],
                              output_channels=num_filters, output_width=self.output_width,
                              reducing=reducing, dag_vars=dag_vars,
                              num_cell_blocks=self.num_cell_blocks, args=self.args)
            self.cells.add_module(name, a_cell)

            # Registering for the WPL later.
            if reducing:
                self.reduce_cells.add_module(name, a_cell)
            else:
                self.normal_cells.add_module(name, a_cell)

            last_input_info, current_input_info = current_input_info, _CNNCell.InputInfo(input_channels=num_filters,
                                                                                         input_width=self.output_width)

        if self.output_classes:
            self.conv_output_size = self.output_height * self.output_width * self.architecture.final_filter_size
            self.out_layer = nn.Linear(self.conv_output_size, self.output_classes)
            torch.nn.init.kaiming_normal_(self.out_layer.weight, mode='fan_out', nonlinearity='relu')
            torch.nn.init.constant_(self.out_layer.bias, 0)
            self.out_layer.to(self.gpu)

        # Count how many candidate connections feed each node, then give each
        # connection a prior probability of 2 / fan_in so that, in
        # expectation, every node starts with about two active inputs.
        parent_counts = [0] * (2 + self.num_cell_blocks)

        for idx, jdx, _type in self.all_connections:
            parent_counts[jdx] += 1

        probs = np.array([2 / parent_counts[jdx] for idx, jdx, _type in self.all_connections])
        self.dags_logits = (logit(probs), logit(probs))

        self.target_ave_prob = np.mean(probs)
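
        # Hedged numeric sketch (hypothetical sizes): with 2 inputs, node 2 is
        # reachable from nodes {0, 1}, so with, say, T == 5 layer types,
        # parent_counts[2] == 2 * 5 == 10 and each of those connections starts
        # at probability 2 / 10 == 0.2; target_ave_prob records the mean of
        # these priors for the weight-decay pull in update_dag_logits.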
        self.cell_dags = ([], [])

        self.ignore_module_keys = ['cell', 'out_layer']
        self.wpl_monitored_modules = self.cells._modules
        self.init_wpl_weights()

    def forward(self, inputs,
                dag,
                is_train=True,
                hidden=None
                ):
        """
        :param cell_dags: (normal_cell_dag, reduction_cell_dag)
        :param inputs: [last_input, current_input]
        :param hidden: don't care. legacy for RNN.
        """

        cell_dag, reducing_cell_dag = dag or self.cell_dags  # supports dynamic dags

        # Hook for behavior that differs between training and evaluation.
        is_train = is_train and self.args.mode in ['train']

        last_input, current_input = inputs, inputs

        for cell in self.cells:
            # Each cell runs either the normal or the reducing dag.
            active_dag = reducing_cell_dag if cell.reducing else cell_dag
            output, extra_out = cell(active_dag, last_input, current_input)
            last_input, current_input = current_input, output

        x = output.view(-1, self.conv_output_size)
        x = self.out_layer(x)
        return x, extra_out
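
    # Minimal usage sketch (hypothetical names; dag entries are
    # (source, target, type_name) connections known to the cells):
    #   model = CNN(args, corpus)
    #   dags = (normal_dag, reducing_dag)
    #   logits, extra_out = model(images, dags)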

    def get_f(self, name):
        """Look up a cell structure by name. Not implemented for the CNN."""
        name = name.lower()
        raise NotImplementedError

    def get_num_cell_parameters(self, dag):
        """
        Intended to return the number of parameters on the path through the
        meta-network given by the dag. Not implemented yet.

        :param dag: (normal_dag, reducing_dag)
        """
        dag, reducing_dag = dag
        params = []
        for cell in self.cells:
            if cell.reducing:
                d = reducing_dag
            else:
                d = dag
            params.extend(cell.get_parameters(d))
        raise NotImplementedError

    def get_parameters(self, dags):
        """ return the parameter of given dags """
        dag, reducing_dag = dags
        params = []
        for cell in self.cells:
            if cell.reducing:
                d = reducing_dag
            else:
                d = dag
            params.extend(cell.get_parameters(d))
        return params

    def reset_parameters(self):
        """Reset all parameters. Not implemented yet."""
        raise NotImplementedError('reset not implemented')

    def update_dag_logits(self, gradient_dicts, weight_decay, max_grad=0.1):
        """
        Updates the logit (and hence the selection probability) of each
        connection using the given gradients, with a weight-decay pull toward
        the target average probability.
        """
        dag_probs = tuple(expit(logits) for logits in self.dags_logits)
        current_average_dag_probs = tuple(np.mean(prob) for prob in dag_probs)

        for i, key in enumerate(self.all_connections):
            for grad_dict, current_average_dag_prob, dag_logits in zip(gradient_dicts, current_average_dag_probs,
                                                                       self.dags_logits):
                if key in grad_dict:
                    # Weight decay pulls the average selection probability
                    # toward the target average.
                    grad = grad_dict[key] - weight_decay * (current_average_dag_prob - self.target_ave_prob)
                    # Chain rule: convert the probability-space gradient into
                    # a logit-space gradient before the clipped update.
                    logit_grad = grad * sigmoid_derivative(dag_logits[i])
                    dag_logits[i] += np.clip(logit_grad, -max_grad, max_grad)
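
    # Worked example of the update (exact arithmetic, hypothetical inputs):
    # at a logit of 0.0 the selection probability is expit(0.0) == 0.5 and the
    # sigmoid derivative is 0.25, so a probability-space gradient of 0.2 with
    # zero decay becomes a logit step of 0.2 * 0.25 == 0.05, well inside the
    # default +/-0.1 clip.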

    def get_dags_probs(self):
        """Returns the current probability of each path being selected.
        Each index corresponds to the connection in self.all_connections
        """
        return tuple(expit(logits) for logits in self.dags_logits)

    def __to_device(self, device, cell_dags):
        cell_dag, reducing_cell_dag = cell_dags
        for cell in self.cells:
            if cell.reducing:
                cell.to_device(device, reducing_cell_dag)
            else:
                cell.to_device(device, cell_dag)

    def set_dags(self, new_cell_dags=([], [])):
        """
        Sets the current active path. Moves other variables to the cpu to save gpu memory.

        :param new_cell_dags: (normal_cell_dag, reduction_cell_dag)
        """
        new_cell_dags = tuple(sorted(cell_dag) for cell_dag in new_cell_dags)

        set_cell_dags = [set(cell_dag) for cell_dag in new_cell_dags]
        last_set_cell_dags = [set(cell_dag) for cell_dag in self.cell_dags]

        cell_dags_to_cpu = [last_set_cell_dag - set_cell_dag
                            for last_set_cell_dag, set_cell_dag in zip(last_set_cell_dags, set_cell_dags)]
        cell_dags_to_gpu = [set_cell_dag - last_set_cell_dag
                            for last_set_cell_dag, set_cell_dag in zip(last_set_cell_dags, set_cell_dags)]

        self.__to_device(self.cpu_device, cell_dags_to_cpu)
        self.__to_device(self.gpu, cell_dags_to_gpu)
        self.cell_dags = new_cell_dags
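
    # Hedged example (hypothetical connection keys): if the active normal dag
    # was {a, b} and set_dags is called with {b, c}, the set differences above
    # send connection a to the CPU and connection c to the GPU; b stays put.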

    # Grouping all the cells here matters: it lets the WPL bookkeeping be
    # handled uniformly. This could perhaps be moved into the outer cells.
    # def init_wpl_weights(self):
    #     """
    #     Init for WPL operations.
    #
    #     NOTE: only take care of all the weights in self._modules, and others.
    #     for self parameters and operations, please override later.
    #
    #     :return:
    #     """
    #     for cell in self.cells:
    #         if isinstance(cell, WPLModule):
    #             cell.init_wpl_weights()
    #
    # def set_fisher_zero(self):
    #     for cell in self.cells:
    #         if isinstance(cell, WPLModule):
    #             cell.set_fisher_zero()
    #
    # def update_optimal_weights(self):
    #     """ Update the weights with optimal """
    #     for cell in self.cells:
    #         if isinstance(cell, WPLModule):
    #             cell.update_optimal_weights()

    def update_fisher(self, dags):
        """ logic is different here, for dags, update all the cells registered. """
        normal, reduce = dags
        for cell in self.cells:
            if cell.reducing:
                d = reduce
            else:
                d = normal
            cell.update_fisher(d)

    def compute_weight_plastic_loss_with_update_fisher(self, dags):
        loss = 0
        normal, reduce = dags
        for cell in self.cells:
            if cell.reducing:
                d = reduce
            else:
                d = normal
            loss += cell.compute_weight_plastic_loss_with_update_fisher(d)
        return loss


# A meta-convolutional cell. It creates one candidate forward connection of
# every layer type in CNN_LAYER_CREATION_FUNCTIONS between each ordered pair
# of nodes, except between the input nodes themselves. Any path through these
# connections can then be chosen to run and train with.
class _CNNCell(WPLModule):

    class InputInfo:
        def __init__(self, input_channels, input_width):
            self.input_channels = input_channels
            self.input_width = input_width

    def __init__(self, input_infos: List[InputInfo],
                 output_channels, output_width,
                 reducing, dag_vars, num_cell_blocks,
                 args=None):
        super().__init__(args)

        self.input_infos = input_infos
        self.num_inputs = len(self.input_infos)
        self.num_cell_blocks = num_cell_blocks
        num_outputs = self.num_inputs + num_cell_blocks
        self.output_channels = output_channels
        self.output_width = output_width
        self.reducing = reducing
        self.dag_vars = dag_vars

        self.connections = dict()
        # self._connections = nn.ModuleList()

        for idx in range(num_outputs - 1):
            for jdx in range(max(idx + 1, self.num_inputs), num_outputs):
                for _type, type_name in get_cnn_layer_with_names():
                    if idx < self.num_inputs:
                        input_info = self.input_infos[idx]
                        if input_info.input_width != output_width:
                            assert (input_info.input_width / 2 == output_width)
                            stride = 2
                        else:
                            stride = 1
                        in_planes = input_info.input_channels

                    else:
                        stride = 1
                        in_planes = output_channels

                    out_planes = output_channels
                    try:
                        self.connections[(idx, jdx, type_name)] = _type(in_planes=in_planes, out_planes=out_planes,
                                                                        stride=stride)
                    except RuntimeError as e:
                        # Skip connections that cannot be constructed (e.g. an
                        # identity layer whose input/output shapes cannot match),
                        # so the lookups below do not raise a KeyError.
                        logger.error(f'Identity matching error: {e}')
                        continue

                    initialize_layers_weights(self.connections[(idx, jdx, type_name)])
                    self.add_module(node_to_key((idx, jdx, type_name)), self.connections[(idx, jdx, type_name)])

        self.init_wpl_weights()
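
    # Hedged enumeration sketch (hypothetical sizes): with num_inputs == 2 and
    # num_cell_blocks == 2, num_outputs == 4, so the loop above visits the
    # (idx, jdx) pairs (0, 2), (0, 3), (1, 2), (1, 3), (2, 3) and builds one
    # connection per pair per layer type; inputs never connect to each other
    # because jdx starts at max(idx + 1, num_inputs).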

    def forward(self, dag, *inputs):
        """
        Run one cell of the CNN along the given dag.

        :param dag: list of (source, target, type_name) connections to activate
        :param inputs: the cell inputs (last_input, current_input)
        :return:
            output: the weighted mean of everything flowing into the final node
            extra_out: dict of additional variables/Tensors for regularization
        """
        assert (len(inputs) == self.num_inputs)
        inputs = list(inputs)
        inputs = inputs + self.num_cell_blocks * [None]
        outputs = [0] * (self.num_inputs + self.num_cell_blocks)
        num_inputs = [0] * (self.num_inputs + self.num_cell_blocks)
        inputs_relu = [None] * (self.num_inputs + self.num_cell_blocks)

        for source, target, _type in dag:
            key = (source, target, _type)
            conn = self.connections[key]

            # Lazily finalize the source node: the first time it is read, its
            # accumulated sum is divided by the total incoming weight, turning
            # it into a weighted mean.
            if inputs[source] is None:
                outputs[source] /= num_inputs[source]
                inputs[source] = outputs[source]
            layer_input = inputs[source]
            # Some layer types expect a ReLU-ed input; compute it at most once
            # per source node.
            if hasattr(conn, 'input_relu') and conn.input_relu:
                if inputs_relu[source] is None:
                    inputs_relu[source] = torch.nn.functional.relu(layer_input)
                layer_input = inputs_relu[source]

            # Scale each connection's contribution by its differentiable dag variable.
            val = conn(layer_input) * self.dag_vars[key]
            outputs[target] += val
            num_inputs[target] += self.dag_vars[key]

        # The cell output is the weighted mean of everything entering the final node.
        outputs[-1] /= num_inputs[-1]
        output = outputs[-1]
        raw_output = output

        extra_out = {'dropped': None,
                     'hiddens': None,
                     'raw': raw_output}

        return output, extra_out
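
    # Hedged numeric sketch of the aggregation (hypothetical values): if node 3
    # receives contributions v1 and v2 through connections with dag variables
    # w1 and w2, the value read from it (or returned) is
    # (w1 * v1 + w2 * v2) / (w1 + w2), i.e. a weighted mean.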

    def to_device(self, device, dag):
        """Moves the parameters on the specified path to the device"""
        for source, target, type_name in dag:
            self.connections[(source, target, type_name)].to(device)

    def get_parameters(self, dag):
        """Returns the parameters of the path through the Cell given by the dag."""
        params = []
        for key in dag:
            params.extend(self.connections[key].parameters())
        return params

    def update_fisher(self, dag):
        """ a single dag"""
        super(_CNNCell, self).update_fisher(dag_to_keys(dag))

    def compute_weight_plastic_loss_with_update_fisher(self, dag):
        return super(_CNNCell, self).compute_weight_plastic_loss_with_update_fisher(dag_to_keys(dag))


def sigmoid_derivative(x):
    """Returns the derivative of the sigmoid function at x."""
    return expit(x) * (1.0 - expit(x))


# Backward-compatible alias for the old misspelled name, in case other
# modules still import it.
sigmoid_derivitive = sigmoid_derivative
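

# Sanity check of the derivative (exact values): expit(0.0) == 0.5, so
# sigmoid_derivative(0.0) == 0.5 * (1.0 - 0.5) == 0.25, the sigmoid's
# maximum slope.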