import tensorflow as tf
import numpy as np


class SegNet(object):
    """DeepLab model."""

    def __init__(self, batch_size=1,
                 num_classes=47,
                 lrn_rate=0.0001,
                 lr_decay_step=70000,
                 lrn_rate_end=0.00001,
                 weight_decay_rate=0.0001,
                 optimizer='adam',  # 'sgd', 'mom' or 'adam'
                 images=None,
                 labels=None,
                 ignore_class_bg=True,
                 mode='test',
                 is_intermediate=False):
        """SegNet constructor.

    Args:
      : Hyperparameters.
      images: Batches of images. [batch_size, image_size, image_size, 3]
      labels: Batches of labels. [batch_size, image_size, image_size]
    """
        self.images = images
        self.labels = labels
        self.H = tf.shape(self.images)[1]
        self.W = tf.shape(self.images)[2]
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.lrn_rate = lrn_rate
        self.lr_decay_step = lr_decay_step
        self.lrn_rate_end = lrn_rate_end
        self.weight_decay_rate = weight_decay_rate
        self.optimizer = optimizer
        self.ignore_class_bg = ignore_class_bg
        self.mode = mode
        self.is_intermediate = is_intermediate
        self._extra_train_ops = []

        with tf.variable_scope("SegNet"):
            self.build_graph()

    def build_graph(self):
        """Build a whole graph for the model."""
        self._build_model()
        if self.mode == 'train':
            self._build_train_op()

    def _build_model(self):
        x = self.images

        # encoders
        with tf.variable_scope('enc_1'):
            x = self.conv_bn_relu('conv1', x, 3, 64)
            x = self.conv_bn_relu('conv2', x, 3, 64)
            x, ind_1 = tf.nn.max_pool_with_argmax(x, [1, 2, 2, 1], [1, 2, 2, 1], "SAME")  # (N, 375, 375, 64)

        with tf.variable_scope('enc_2'):
            x = self.conv_bn_relu('conv1', x, 3, 128)
            x = self.conv_bn_relu('conv2', x, 3, 128)
            x, ind_2 = tf.nn.max_pool_with_argmax(x, [1, 2, 2, 1], [1, 2, 2, 1], "SAME")  # (N, 188, 188, 128)

        with tf.variable_scope('enc_3'):
            x = self.conv_bn_relu('conv1', x, 3, 256)
            x = self.conv_bn_relu('conv2', x, 3, 256)
            x = self.conv_bn_relu('conv3', x, 3, 256)
            x, ind_3 = tf.nn.max_pool_with_argmax(x, [1, 2, 2, 1], [1, 2, 2, 1], "SAME")  # (N, 94, 94, 256)

        with tf.variable_scope('enc_4'):
            x = self.conv_bn_relu('conv1', x, 3, 512)
            x = self.conv_bn_relu('conv2', x, 3, 512)
            x = self.conv_bn_relu('conv3', x, 3, 512)
            x, ind_4 = tf.nn.max_pool_with_argmax(x, [1, 2, 2, 1], [1, 2, 2, 1], "SAME")  # (N, 47, 47, 512)

        with tf.variable_scope('enc_5'):
            x = self.conv_bn_relu('conv1', x, 3, 512)
            x = self.conv_bn_relu('conv2', x, 3, 512)
            x = self.conv_bn_relu('conv3', x, 3, 512)
            x, ind_5 = tf.nn.max_pool_with_argmax(x, [1, 2, 2, 1], [1, 2, 2, 1], "SAME")  # (N, 24, 24, 512)
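        # With SAME padding, each pooling halves the spatial size with ceil
        # rounding: 750 -> 375 -> 188 -> 94 -> 47 -> 24 for the default input.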

        # decoders
        with tf.variable_scope('dec_5'):
            x = self._unpool_2d(x, ind_5, out_size=[47, 47])  # enc_5 pre-pool size for 750x750 inputs
            x = self.conv_bn_relu('conv1', x, 3, 512)
            x = self.conv_bn_relu('conv2', x, 3, 512)
            x = self.conv_bn_relu('conv3', x, 3, 512)

        with tf.variable_scope('dec_4'):
            x = self._unpool_2d(x, ind_4, out_size=[94, 94])  # enc_4 pre-pool size for 750x750 inputs
            x = self.conv_bn_relu('conv1', x, 3, 512)
            x = self.conv_bn_relu('conv2', x, 3, 512)
            # x = self.conv_bn_relu('conv3', x, 3, 256)

        if self.is_intermediate:
            self.intermediate_feat = x
            return

        with tf.variable_scope('dec_3'):
            x = self._unpool_2d(x, ind_3, out_size=[188, 188])
            x = self.conv_bn_relu('conv1', x, 3, 256)
            x = self.conv_bn_relu('conv2', x, 3, 256)
            x = self.conv_bn_relu('conv3', x, 3, 128)

        with tf.variable_scope('dec_2'):
            x = self._unpool_2d(x, ind_2, out_size=[375, 375])
            x = self.conv_bn_relu('conv1', x, 3, 128)
            x = self.conv_bn_relu('conv2', x, 3, 64)

        with tf.variable_scope('dec_1'):
            x = self._unpool_2d(x, ind_1, out_size=[750, 750])
            x = self.conv_bn_relu('conv1', x, 3, 64)
            x = self.conv_bn_relu('conv2', x, 3, self.num_classes)

        # The prediction head below follows DeepLab-v2.
        self.logits_up = x  # (N, H, W, num_classes)
        logits_flat = tf.reshape(self.logits_up, [-1, self.num_classes])
        pred = tf.nn.softmax(logits_flat)
        self.pred = tf.reshape(pred, tf.shape(self.logits_up))  # (N, H, W, num_classes)

        pred_label = tf.argmax(self.pred, 3)  # (N, H, W)
        pred_label = tf.expand_dims(pred_label, axis=3)
        self.pred_label = pred_label  # (N, H, W, 1), values in [0, num_classes)

    def conv_bn_relu(self, name, x, ksize, out_size, stride=1):
        """Convolution + batch norm + ReLU."""
        in_size = x.shape[3]
        rst = self._conv(name, x, ksize, in_size, out_size, self._stride_arr(stride))
        # Respect the train/test mode and force the moving-average updates to
        # run in place; with the defaults they are collected in UPDATE_OPS,
        # which this graph never executes.
        rst = tf.contrib.layers.batch_norm(rst,
                                           is_training=(self.mode == 'train'),
                                           updates_collections=None)
        rst = tf.nn.relu(rst)
        return rst

    def _stride_arr(self, stride):
        """Map a stride scalar to the stride array for tf.nn.conv2d."""
        return [1, stride, stride, 1]

    def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
        """Convolution."""
        with tf.variable_scope(name):
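            # He-style weight initialization for ReLU layers: zero-mean
            # Gaussian with stddev sqrt(2 / n), where n = k * k * out_filters
            # is the filter fan-out.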
            n = filter_size * filter_size * out_filters
            w = tf.get_variable('DW', [filter_size, filter_size, in_filters, out_filters], tf.float32,
                                initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / n)))
            conv = tf.nn.conv2d(x, w, strides, padding='SAME')
            b = tf.get_variable('biases', [out_filters], initializer=tf.constant_initializer())
            return conv + b

    def _unpool_2d(self,
                   pool,
                   ind,
                   out_size,
                   scope='unpool_2d'):
        """Adds a 2D unpooling op.
        https://arxiv.org/abs/1505.04366
        Unpooling layer after max_pool_with_argmax.
           Args:
               pool:        max pooled output tensor
               ind:         argmax indices
               stride:      stride is the same as for the pool
           Return:
               unpool:    unpooling tensor
        """
        with tf.variable_scope(scope):
            input_shape = tf.shape(pool)
            output_shape = [input_shape[0], out_size[0], out_size[1], input_shape[3]]

            flat_input_size = tf.reduce_prod(input_shape)
            flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]

            pool_ = tf.reshape(pool, [flat_input_size])
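            # With the default include_batch_in_index=False, the argmax
            # indices are flattened within each image as
            # ((y * width) + x) * channels + c, so an explicit batch
            # coordinate is prepended below before scattering into the flat
            # [batch, height * width * channels] output.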
            batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
                                     shape=[input_shape[0], 1, 1, 1])
            b = tf.ones_like(ind) * batch_range
            b1 = tf.reshape(b, [flat_input_size, 1])
            ind_ = tf.reshape(ind, [flat_input_size, 1])
            ind_ = tf.concat([b1, ind_], 1)

            ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
            ret = tf.reshape(ret, output_shape)

            set_input_shape = pool.get_shape()
            set_output_shape = [set_input_shape[0], out_size[0], out_size[1],
                                set_input_shape[3]]
            ret.set_shape(set_output_shape)
            return ret

    def _build_train_op(self):
        """Build training specific ops for the graph."""

        logits_flatten = tf.reshape(self.logits_up, [-1, self.num_classes])
        pred_flatten = tf.reshape(self.pred, [-1, self.num_classes])

        labels_gt = self.labels

        if self.ignore_class_bg:
            # drop pixels whose label falls outside [0, num_classes), e.g.
            # the background/void label 255, so they contribute no loss
            gt_labels_flatten = tf.reshape(labels_gt, [-1])
            indices = tf.squeeze(tf.where(tf.less_equal(gt_labels_flatten, self.num_classes - 1)), 1)
            remain_logits = tf.gather(logits_flatten, indices)
            remain_pred = tf.gather(pred_flatten, indices)
            remain_labels = tf.gather(gt_labels_flatten, indices)
            xent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=remain_logits, labels=remain_labels)
        else:
            xent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits_up, labels=labels_gt)

        self.cls_loss = tf.reduce_mean(xent, name='xent')  # xent has one entry per non-ignored pixel
        self.cost = self.cls_loss + self._decay()
        tf.summary.scalar('cost', self.cost)

        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.learning_rate = tf.train.polynomial_decay(self.lrn_rate,
                                                       self.global_step,
                                                       self.lr_decay_step,
                                                       end_learning_rate=self.lrn_rate_end,
                                                       power=0.9)
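        # i.e. lr(step) = (lrn_rate - lrn_rate_end)
        #                 * (1 - step / lr_decay_step)^0.9 + lrn_rate_end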
        tf.summary.scalar('learning_rate', self.learning_rate)

        tvars = tf.trainable_variables()

        if self.optimizer == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        elif self.optimizer == 'mom':
            optimizer = tf.train.MomentumOptimizer(self.learning_rate, 0.9)
        elif self.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
        else:
            raise NameError("Unknown optimizer type %s!" % self.optimizer)

        grads_and_vars = optimizer.compute_gradients(self.cost, var_list=tvars)
        var_lr_mult = {}
        for var in tvars:
            # Use larger learning rates for the final classifier layer (the
            # 'fc_final_sketch46' scope, if present in the graph), following
            # the common DeepLab practice of 10x/20x multipliers. Membership
            # tests are used instead of str.find() > 0, which would also
            # reject a match at position 0.
            if 'fc_final_sketch46' in var.op.name and 'biases' in var.op.name:
                var_lr_mult[var] = 20.
            elif 'fc_final_sketch46' in var.op.name:
                var_lr_mult[var] = 10.
            else:
                var_lr_mult[var] = 1.
        grads_and_vars = [((g if var_lr_mult[v] == 1 else tf.multiply(var_lr_mult[v], g)), v)
                          for g, v in grads_and_vars]

        ## summary grads
        # for grad, grad_var in grads_and_vars:
        #     print('>>>', grad_var.op.name)
        #     if grad is None:
        #         print('None grad')
        #     # if grad is not None:
        #     #     tf.summary.histogram(grad_var.op.name + "/gradient", grad)

        apply_op = optimizer.apply_gradients(grads_and_vars,
                                             global_step=self.global_step, name='train_step')

        train_ops = [apply_op] + self._extra_train_ops
        self.train_step = tf.group(*train_ops)

    def _decay(self):
        """L2 weight decay loss."""
        costs = []
        for var in tf.trainable_variables():
            if var.op.name.find(r'DW') > 0:
                costs.append(tf.nn.l2_loss(var))
                # tf.histogram_summary(var.op.name, var)

        return tf.multiply(self.weight_decay_rate, tf.add_n(costs))
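

if __name__ == '__main__':
    # Minimal smoke-test sketch (an illustrative addition, not the original
    # training pipeline): build the graph in test mode and run one forward
    # pass on random data, using the constructor defaults (750x750 inputs,
    # 47 classes).
    model = SegNet(mode='test')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        dummy_images = np.random.rand(1, 750, 750, 3).astype(np.float32)
        pred_label = sess.run(model.pred_label,
                              feed_dict={model.images: dummy_images})
        print('pred_label shape:', pred_label.shape)  # (1, 750, 750, 1)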