import tensorflow as tf

# Model hyper-parameters
DECAY = .9
EPSILON = 1e-8


class Generator:
    """Generative model with the architectural specifications suited for artistic style transfer."""

    def __init__(self, is_training=True):
        self.training = is_training

    def build(self, img):
        """Constructs the generative network's layers. Normally called after initialization.
        
        Args:
            img: 4D tensor representation of image batch
        """

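        # Reflection-pad 40 pixels per side: the two stride-2 convolutions reduce this
        # to 10 pixels per side, which the five residual blocks (each trimming 2 pixels
        # per side with their unpadded convolutions) consume exactly, so the final
        # output matches the input's spatial size.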
        self.padded = self._pad(img, 40)

        self.conv1 = self._conv_block(self.padded, maps_shape=[9, 9, 3, 32], stride=1, name='conv1')
        self.conv2 = self._conv_block(self.conv1, maps_shape=[2, 2, 32, 64], stride=2, name='conv2')
        self.conv3 = self._conv_block(self.conv2, maps_shape=[2, 2, 64, 128], stride=2, name='conv3')

        self.resid1 = self._residual_block(self.conv3, maps_shape=[3, 3, 128, 128], stride=1, name='resid1')
        self.resid2 = self._residual_block(self.resid1, maps_shape=[3, 3, 128, 128], stride=1, name='resid2')
        self.resid3 = self._residual_block(self.resid2, maps_shape=[3, 3, 128, 128], stride=1, name='resid3')
        self.resid4 = self._residual_block(self.resid3, maps_shape=[3, 3, 128, 128], stride=1, name='resid4')
        self.resid5 = self._residual_block(self.resid4, maps_shape=[3, 3, 128, 128], stride=1, name='resid5')

        self.conv4 = self._upsample_block(self.resid5, maps_shape=[2, 2, 64, 128], stride=2, name='conv4')
        self.conv5 = self._upsample_block(self.conv4, maps_shape=[2, 2, 32, 64], stride=2, name='conv5')
        self.conv6 = self._conv_block(self.conv5, maps_shape=[9, 9, 32, 3], stride=1, name='conv6', activation=None)

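        # Squash the final convolution's maps into [0, 1] pixel intensities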
        self.output = tf.nn.sigmoid(self.conv6)

    @staticmethod
    def _get_weights(shape):
        """Returns a variable for weights with a specified filters shape.
        
        Args:
            shape: a list specifying the initialized weights shape
            
        Returns:
            weights: tf.Variable representing a set of weights with a normal distribution
        """

        init = tf.truncated_normal(shape, mean=0., stddev=.1)
        weights = tf.Variable(init, dtype=tf.float32)
        return weights

    @staticmethod
    def _instance_normalize(inputs):
        """Instance normalize inputs to reduce covariate shift and reduce dependency on input contrast to improve results.
        
        Args:
            inputs: 4D tensor representing image layer encodings
            
        Returns:
            maps: 4D tensor of batch normalized inputs
        """

        with tf.variable_scope('instance_normalization'):
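            # Normalizes each (image, channel) pair over its spatial axes:
            # y = scale * (x - mean) / sqrt(var + EPSILON) + shift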
            channels = inputs.get_shape().as_list()[3]
            mu, sigma_sq = tf.nn.moments(inputs, [1, 2], keep_dims=True)

            shift = tf.Variable(tf.constant(.1, shape=[channels]))
            scale = tf.Variable(tf.ones([channels]))
            normalized = (inputs - mu) / (sigma_sq + EPSILON) ** .5
            maps = scale * normalized + shift
            return maps

    @staticmethod
    def _pad(inputs, size):
        """Pads input of the image so the output is the same dimensions even after strided convolution.
        
        Args:
            inputs: 4D tensor representing image layer encodings
            size: int specifying the pad size
            
        Returns:
            padded_inputs: 4D tensor of padded inputs
        """

        padded_inputs = tf.pad(inputs, [[0, 0], [size, size], [size, size], [0, 0]], "REFLECT")
        return padded_inputs

    @staticmethod
    def _batch_normalize(inputs, num_maps, is_training):
        """Batch normalize inputs to reduce covariate shift and improve the efficiency of training.
        
        Args:
            inputs: 4D tensor representing image layer encodings
            num_maps: int representing the number of input feature maps
            is_training: bool representing whether or not the model is 
                         being trained rather than being used for inference
            
        Returns:
            bn_inputs: 4D tensor of batch normalized inputs
        """

        with tf.variable_scope("batch_normalization"):
            # Trainable variables for scaling and offsetting our inputs
            scale = tf.Variable(tf.ones([num_maps], dtype=tf.float32))
            offset = tf.Variable(tf.zeros([num_maps], dtype=tf.float32))

            # Mean and variances related to our current batch
            batch_mean, batch_var = tf.nn.moments(inputs, [0, 1, 2])

            # Maintain exponential moving averages of the batch statistics for use at inference time
            ema = tf.train.ExponentialMovingAverage(decay=DECAY)

            def ema_retrieve():
                return ema.average(batch_mean), ema.average(batch_var)

            # If the net is being trained, update the average every training step
            def ema_update():
                ema_apply = ema.apply([batch_mean, batch_var])

                # Make sure to compute the new means and variances prior to returning their values
                with tf.control_dependencies([ema_apply]):
                    return tf.identity(batch_mean), tf.identity(batch_var)

            # Retrieve the means and variances and apply the BN transformation
            mean, var = tf.cond(tf.equal(is_training, True), ema_update, ema_retrieve)
            bn_inputs = tf.nn.batch_normalization(inputs, mean, var, offset, scale, EPSILON)

        return bn_inputs

    def _conv_block(self, inputs, maps_shape, stride, name,
                    norm=True, padding='SAME', activation=tf.nn.relu):
        """Convolve inputs and return their batch normalized tensor.
        
        Args:
            inputs: 4D tensor representing image layer encodings
            maps_shape: list representing the shape of the layer weights
            stride: int representing stride length
            name: string used as the layer's variable scope name
            norm: bool representing whether or not to normalize layer inputs
            padding: string representing padding type
            activation: tf.nn activation function
            
        Returns:
            maps: 4D tensor representing convolved feature maps
        """

        with tf.variable_scope(name):
            # A block explicitly named 'output' always uses a sigmoid activation
            if name == 'output':
                activation = tf.nn.sigmoid

            filters = self._get_weights(maps_shape)
            filter_maps = tf.nn.conv2d(inputs, filters, [1, stride, stride, 1], padding=padding)
            num_out_maps = maps_shape[3]
            bias = tf.Variable(tf.constant(.1, shape=[num_out_maps]))
            filter_maps = tf.nn.bias_add(filter_maps, bias)

            if norm:
                filter_maps = self._instance_normalize(filter_maps)

            maps = activation(filter_maps) if activation else filter_maps
            return maps

    def _upsample_block(self, inputs, maps_shape, stride, name):
        """Upsamples inputs using transposed convolution.
        
        Args:
            inputs: 4D tensor representing image layer encodings
            maps_shape: list representing the shape of the layer weights
            stride: int representing stride length
            name: string used as the layer's variable scope name
            
        Returns:
            maps: 4D tensor representing upsampled feature maps
        """

        with tf.variable_scope(name):
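            # conv2d_transpose filters are laid out as
            # [filter_height, filter_width, output_channels, input_channels],
            # so maps_shape[2] gives the number of output feature maps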
            filters = self._get_weights(maps_shape)

            # Compute the output shape, using the dynamic batch size so the op also
            # works when the batch dimension is unknown at graph-construction time
            batch = tf.shape(inputs)[0]
            height, width = inputs.get_shape().as_list()[1:3]
            out_height = height * stride
            out_width = width * stride
            out_size = maps_shape[2]
            out_shape = tf.stack([batch, out_height, out_width, out_size])
            strides = [1, stride, stride, 1]

            # Upsample and normalize the biased outputs
            upsample = tf.nn.conv2d_transpose(inputs, filters, output_shape=out_shape, strides=strides)
            bias = tf.Variable(tf.constant(.1, shape=[out_size]))
            upsample = tf.nn.bias_add(upsample, bias)
            bn_maps = self._instance_normalize(upsample)
            maps = tf.nn.relu(bn_maps)

            return maps

    def _residual_block(self, inputs, maps_shape, stride, name):
        """Residual block comprised of two conv layers and aims to add long short-term memory to the network.
        
        Args:
            inputs: 4D tensor representing image layer encodings
            maps_shape: list representing the shape of the layer weights
            stride: int representing stride length
            name: string used as the layer's variable scope name
            
        Returns:
            maps: 4D tensor representing feature maps
        """

        with tf.variable_scope(name):
            conv1 = self._conv_block(inputs, maps_shape, stride=stride, padding='VALID', name='c1')
            conv2 = self._conv_block(conv1, maps_shape, stride=stride, padding='VALID', name='c2', activation=None)

            # The two unpadded convolutions trim 2 pixels from every spatial border, so
            # center-crop the input to match before adding it to the block's output
            batch = tf.shape(inputs)[0]
            patch_height, patch_width, num_filters = conv2.get_shape().as_list()[1:]
            out_shape = tf.stack([batch, patch_height, patch_width, num_filters])
            cropped_inputs = tf.slice(inputs, [0, 2, 2, 0], out_shape)
            maps = conv2 + cropped_inputs

            return maps
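

# Minimal usage sketch (illustrative only; the placeholder shape, dtype, and random
# input below are assumptions and not part of any training pipeline).
if __name__ == '__main__':
    import numpy as np

    images = tf.placeholder(tf.float32, shape=[1, 256, 256, 3], name='input_images')
    generator = Generator(is_training=False)
    generator.build(images)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        stylized = sess.run(generator.output,
                            feed_dict={images: np.random.rand(1, 256, 256, 3).astype(np.float32)})
        print(stylized.shape)  # (1, 256, 256, 3)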