import tensorflow as tf
from utils import logger
import ops


class Generator(object):
    """Encoder-decoder generator with skip connections and a latent input.

    Calling an instance builds the network inside a variable scope named
    after ``name``. The first call creates the variables; later calls reuse
    them. After each call, ``var_list`` holds the trainable variables that
    live under this generator's variable scope.
    """

    def __init__(self, name, is_train, norm='batch', image_size=128):
        """Store the configuration; no graph ops are created here.

        Args:
            name: Variable-scope name under which the layers are built.
            is_train: Forwarded unchanged to every ``ops`` block
                (presumably a training-mode flag for normalization --
                confirm against ``ops``).
            norm: Normalization identifier forwarded to the ``ops`` blocks.
                The first encoder layer and the output layer always get
                ``None`` instead.
            image_size: Height/width the latent code is tiled to. A value of
                256 adds one extra 512-filter encoder level.
        """
        logger.info('Init Generator %s', name)
        self.name = name
        self._is_train = is_train
        self._norm = norm
        self._reuse = False
        self._image_size = image_size

    def __call__(self, input, z):
        """Build the generator graph for ``input`` conditioned on ``z``.

        Args:
            input: 4-D image tensor. It is concatenated with the tiled
                latent code on axis 3, so its height and width must equal
                ``image_size``.
            z: Latent tensor whose last dimension is statically known;
                it is reshaped to one code per batch element.

        Returns:
            The output of the final 3-filter deconv block with ``tanh``
            activation.
        """
        # Capture the scope object so the variable lookup below can use the
        # fully qualified scope name (including any enclosing scopes).
        with tf.variable_scope(self.name, reuse=self._reuse) as scope:
            latent_dim = int(z.get_shape()[-1])
            num_filters = [64, 128, 256, 512, 512, 512, 512]
            if self._image_size == 256:
                num_filters.append(512)

            # Let reshape infer the batch dimension (-1) instead of reading
            # the static batch size, which raises when it is unknown.
            # NOTE(review): ops.deconv_block may still need a static batch
            # size -- confirm before relying on dynamic batches.
            z = tf.reshape(z, [-1, 1, 1, latent_dim])
            # Broadcast the code over every spatial position of the image.
            z = tf.tile(z, [1, self._image_size, self._image_size, 1])
            G = tf.concat([input, z], axis=3)

            # Encoder: keep every activation for the skip connections.
            # The positional 4, 2 are presumably kernel size and stride.
            layers = []
            for i, n in enumerate(num_filters):
                # The first encoder layer (i == 0) is built without norm.
                G = ops.conv_block(G, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
                                   self._reuse, norm=self._norm if i else None,
                                   activation='leaky')
                layers.append(G)

            # The innermost activation is the decoder's input, not a skip,
            # and the decoder needs one level fewer than the encoder.
            layers.pop()
            num_filters.pop()
            num_filters.reverse()

            # Decoder: each deconv output is concatenated with the encoder
            # activations, consumed innermost-first.
            for i, n in enumerate(num_filters):
                G = ops.deconv_block(G, n, 'CD{}_{}'.format(n, i), 4, 2, self._is_train,
                                     self._reuse, norm=self._norm, activation='relu')
                G = tf.concat([G, layers.pop()], axis=3)
            G = ops.deconv_block(G, 3, 'last_layer', 4, 2, self._is_train,
                                 self._reuse, norm=None, activation='tanh')

            self._reuse = True
            # tf.get_collection filters with re.match, i.e. by prefix. The
            # bare name 'G' would also match 'G2/...' and would miss
            # 'outer/G/...', so filter on the real scope name plus '/'.
            self.var_list = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope.name + '/')
            return G