import metric
import tfutil as tfu

import numpy as np
import tensorflow as tf


class RCAN:

    def __init__(self,
                 sess,                                     # tensorflow session
                 batch_size=16,                            # batch size
                 n_channel=3,                              # number of image channel, 3 for RGB, 1 for gray-scale
                 img_scaling_factor=4,                     # image scale factor to up
                 lr_img_size=(48, 48),                     # input patch image size for LR
                 hr_img_size=(192, 192),                   # input patch image size for HR
                 n_res_blocks=20,                          # number of residual block
                 n_res_groups=10,                          # number of residual group
                 res_scale=1,                              # scaling factor of res block
                 n_filters=64,                             # number of conv2d filter size
                 kernel_size=3,                            # number of conv2d kernel size
                 activation='relu',                        # activation function
                 use_bn=False,                             # using batch_norm or not
                 reduction=16,                             # reduction rate at CA layer
                 # rgb_mean=(114.2430, 111.4502, 103.0450),  # RGB mean, for DIV2K DataSet
                 # rgb_std=(69.6606, 66.0210, 72.1786),      # RGB std, for DIV2K DataSet
                 rgb_mean=(0.4480, 0.4371, 0.4041),        # RGB mean, for DIV2K DataSet
                 rgb_std=(0.2732, 0.2589, 0.2831),         # RGB std, for DIV2K DataSet
                 optimizer='adam',                         # name of optimizer
                 lr=1e-4,                                  # learning rate
                 lr_decay=.5,                              # learning rate decay factor
                 lr_decay_step=2e5,                        # learning rate decay step
                 momentum=.9,                              # SGD momentum value
                 beta1=.9,                                 # Adam beta1 value
                 beta2=.999,                               # Adam beta2 value
                 opt_eps=1e-8,                             # Adam epsilon value
                 eps=1.1e-5,                               # epsilon
                 tf_log="./model/",                        # path saved tensor summary / model
                 n_gpu=1,                                  # number of GPU
                 ):
        self.sess = sess
        self.batch_size = batch_size
        self.n_channel = n_channel
        self.img_scale = img_scaling_factor
        self.lr_img_size = lr_img_size + (self.n_channel,)
        self.hr_img_size = hr_img_size + (self.n_channel,)

        self.n_res_blocks = n_res_blocks
        self.n_res_groups = n_res_groups
        self.res_scale = res_scale

        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.activation = activation
        self.use_bn = use_bn
        self.reduction = reduction

        self.rgb_mean = tf.constant(rgb_mean, dtype=tf.float32)
        self.rgb_std = tf.constant(rgb_std, dtype=tf.float32)

        self.optimizer = optimizer
        self.lr = lr
        self.lr_decay = lr_decay
        self.lr_decay_step = lr_decay_step
        self.momentum = momentum
        self.beta1 = beta1
        self.beta2 = beta2
        self.opt_eps = opt_eps

        self._eps = eps

        self.tf_log = tf_log

        self.n_gpu = n_gpu

        self.act = None
        self.opt = None
        self.train_op = None
        self.loss = None
        self.output = None

        self.psnr = None
        self.ssim = None

        self.saver = None
        self.best_saver = None
        self.merged = None
        self.writer = None

        self.global_step = tf.Variable(0, trainable=False, name='global_step')

        # tensor placeholder for input
        self.x_lr = tf.placeholder(tf.float32, shape=(None,) + self.lr_img_size, name='x-lr-img')
        self.x_hr = tf.placeholder(tf.float32, shape=(None,) + self.hr_img_size, name='x-hr-img')

        self.lr = tf.placeholder(tf.float32, name='learning_rate')
        # self.is_train = tf.placeholder(tf.bool, name='is_train')

        # setting stuffs
        self.setup()

        # build a network
        self.build_model()

    def setup(self):
        # Activation Function Setting
        if self.activation == 'relu':
            self.act = tf.nn.relu
        elif self.activation == 'leaky_relu':
            self.act = tf.nn.leaky_relu
        elif self.activation == 'elu':
            self.act = tf.nn.elu
        else:
            raise NotImplementedError("[-] Not supported activation function {}".format(self.activation))

        # Optimizer
        if self.optimizer == 'adam':
            self.opt = tf.train.AdamOptimizer(learning_rate=self.lr,
                                              beta1=self.beta1, beta2=self.beta2, epsilon=self.opt_eps)
        elif self.optimizer == 'sgd':  # sgd + m with nestrov
            self.opt = tf.train.MomentumOptimizer(learning_rate=self.lr, momentum=self.momentum, use_nesterov=True)
        else:
            raise NotImplementedError("[-] Not supported optimizer {}".format(self.optimizer))

    def image_processing(self, x, sign, name):
        with tf.variable_scope(name):
            r, g, b = tf.split(x, num_or_size_splits=3, axis=-1)

            # normalize pixel with pre-calculated value based on DIV2K DataSet
            rgb = tf.concat([(r + sign * self.rgb_mean[0]),
                             (g + sign * self.rgb_mean[1]),
                             (b + sign * self.rgb_mean[2])], axis=-1)
            return rgb

    def channel_attention(self, x, f, reduction, name):
        """
        Channel Attention (CA) Layer
        :param x: input layer
        :param f: conv2d filter size
        :param reduction: conv2d filter reduction rate
        :param name: scope name
        :return: output layer
        """
        with tf.variable_scope("CA-%s" % name):
            skip_conn = tf.identity(x, name='identity')

            x = tfu.adaptive_global_average_pool_2d(x)

            x = tfu.conv2d(x, f=f // reduction, k=1, name="conv2d-1")
            x = self.act(x)

            x = tfu.conv2d(x, f=f, k=1, name="conv2d-2")
            x = tf.nn.sigmoid(x)
            return tf.multiply(skip_conn, x)

    def residual_channel_attention_block(self, x, f, kernel_size, reduction, use_bn, name):
        with tf.variable_scope("RCAB-%s" % name):
            skip_conn = tf.identity(x, name='identity')

            x = tfu.conv2d(x, f=f, k=kernel_size, name="conv2d-1")
            x = tf.layers.BatchNormalization(epsilon=self._eps, name="bn-1")(x) if use_bn else x
            x = self.act(x)

            x = tfu.conv2d(x, f=f, k=kernel_size, name="conv2d-2")
            x = tf.layers.BatchNormalization(epsilon=self._eps, name="bn-2")(x) if use_bn else x

            x = self.channel_attention(x, f, reduction, name="RCAB-%s" % name)
            return self.res_scale * x + skip_conn  # tf.math.add(self.res_scale * x, skip_conn)

    def residual_group(self, x, f, kernel_size, reduction, use_bn, name):
        with tf.variable_scope("RG-%s" % name):
            skip_conn = tf.identity(x, name='identity')

            for i in range(self.n_res_blocks):
                x = self.residual_channel_attention_block(x, f, kernel_size, reduction, use_bn, name=str(i))

            x = tfu.conv2d(x, f=f, k=kernel_size, name='rg-conv-1')
            return x + skip_conn  # tf.math.add(x, skip_conn)

    def up_scaling(self, x, f, scale_factor, name):
        """
        :param x: image
        :param f: conv2d filter
        :param scale_factor: scale factor
        :param name: scope name
        :return:
        """
        with tf.variable_scope(name):
            if scale_factor == 3:
                x = tfu.conv2d(x, f * 9, k=1, name='conv2d-image_scaling-0')
                x = tfu.pixel_shuffle(x, 3)
            elif scale_factor & (scale_factor - 1) == 0:  # is it 2^n?
                log_scale_factor = int(np.log2(scale_factor))
                for i in range(log_scale_factor):
                    x = tfu.conv2d(x, f * 4, k=1, name='conv2d-image_scaling-%d' % i)
                    x = tfu.pixel_shuffle(x, 2)
            else:
                raise NotImplementedError("[-] Not supported scaling factor (%d)" % scale_factor)
            return x

    def residual_channel_attention_network(self, x, f, kernel_size, reduction, use_bn, scale):
        with tf.variable_scope("Residual_Channel_Attention_Network"):
            x = self.image_processing(x, sign=-1, name='pre-processing')

            # 1. head
            head = tfu.conv2d(x, f=f, k=kernel_size, name="conv2d-head")

            # 2. body
            x = head
            for i in range(self.n_res_groups):
                x = self.residual_group(x, f, kernel_size, reduction, use_bn, name=str(i))

            body = tfu.conv2d(x, f=f, k=kernel_size, name="conv2d-body")
            body += head  # tf.math.add(body, head)

            # 3. tail
            x = self.up_scaling(body, f, scale, name='up-scaling')
            tail = tfu.conv2d(x, f=self.n_channel, k=kernel_size, name="conv2d-tail")  # (-1, 384, 384, 3)

            x = self.image_processing(tail, sign=1, name='post-processing')
            return x

    def build_model(self):
        # RCAN model
        self.output = self.residual_channel_attention_network(x=self.x_lr,
                                                              f=self.n_filters,
                                                              kernel_size=self.kernel_size,
                                                              reduction=self.reduction,
                                                              use_bn=self.use_bn,
                                                              scale=self.img_scale,
                                                              )
        self.output = tf.clip_by_value(self.output * 255., 0., 255.)

        # l1 loss
        self.loss = tf.reduce_mean(tf.abs(self.output - self.x_hr))

        self.train_op = self.opt.minimize(self.loss, global_step=self.global_step)

        # metrics
        self.psnr = tf.reduce_mean(metric.psnr(self.output, self.x_hr, m_val=1))
        self.ssim = tf.reduce_mean(metric.ssim(self.output, self.x_hr, m_val=1))

        # summaries
        tf.summary.image('lr', self.x_lr, max_outputs=self.batch_size)
        tf.summary.image('hr', self.x_hr, max_outputs=self.batch_size)
        tf.summary.image('generated-hr', self.output, max_outputs=self.batch_size)

        tf.summary.scalar("loss/l1_loss", self.loss)
        tf.summary.scalar("metric/psnr", self.psnr)
        tf.summary.scalar("metric/ssim", self.ssim)
        tf.summary.scalar("misc/lr", self.lr)

        # merge summary
        self.merged = tf.summary.merge_all()

        # model saver
        self.saver = tf.train.Saver(max_to_keep=1)
        self.best_saver = tf.train.Saver(max_to_keep=1)
        self.writer = tf.summary.FileWriter(self.tf_log, self.sess.graph)