import math
import tensorflow as tf
import tensorflow.contrib.slim as slim

from ops import snconv2d, snlinear
from models.model_base import Model

_EPS = 1e-5

class DiscriminatorBuilder(Model):
    def __init__(self, config):
        self.config = config
        self.ndf_base = self.config["model_params"]["ndf_base"]
        self.num_extra_layers = self.config["model_params"]["d_extra_layers"]
        self.macro_patch_size = self.config["data_params"]["macro_patch_size"]

        self.update_collection = "D_update_collection"

    def _d_residual_block(self, x, out_ch, idx, is_training, resize=True, is_head=False):
        """Spectral-normalized residual block: (ReLU ->) conv -> ReLU -> conv,
        with a 1x1-convolved shortcut. `resize` applies 2x2 average pooling to
        both paths; the head block skips the leading ReLU."""
        update_collection = self._get_update_collection(is_training)
        with tf.variable_scope("d_resblock_"+str(idx), reuse=tf.AUTO_REUSE):
            h = x
            if not is_head:
                h = tf.nn.relu(h)
            h = snconv2d(h, out_ch, name='d_resblock_conv_1', update_collection=update_collection)
            h = tf.nn.relu(h)
            h = snconv2d(h, out_ch, name='d_resblock_conv_2', update_collection=update_collection)
            if resize:
                h = slim.avg_pool2d(h, [2, 2])

            # Short cut
            s = x
            if resize:
                s = slim.avg_pool2d(s, [2, 2])
            s = snconv2d(s, out_ch, k_h=1, k_w=1, name='d_resblock_conv_sc', update_collection=update_collection)
            return h + s

    def forward(self, x, y=None, is_training=True):
        """Builds the discriminator graph and returns the logits (adversarial
        output plus the optional projection term) and the pooled feature map."""
        valid_sizes = {8, 16, 32, 64, 128, 256, 512}
        assert (self.macro_patch_size[0] in valid_sizes and self.macro_patch_size[1] in valid_sizes), \
            "Untested macro patch size: {}".format(self.macro_patch_size)

        update_collection = self._get_update_collection(is_training)
        print(" [Build] Discriminator ; is_training: {}".format(is_training))
        
        with tf.variable_scope("D_discriminator", reuse=tf.AUTO_REUSE):

            num_resize_layers = int(math.log(min(self.macro_patch_size), 2) - 1)
            num_total_layers  = num_resize_layers + self.num_extra_layers
            basic_layers = [2, 4, 8, 8]
            if num_total_layers > len(basic_layers):
                num_replicate_layers = num_total_layers - len(basic_layers)
                ndf_mult_list = [1, ] * num_replicate_layers + basic_layers
            else:
                ndf_mult_list = basic_layers[-num_total_layers:]
                ndf_mult_list[0] = 1
            print("\t ndf_mult_list = {}".format(ndf_mult_list))

            # Head block first, then the extra (no-resize) blocks, then the standard downsampling blocks
            h = x
            for idx, ndf_mult in enumerate(ndf_mult_list):
                n_ch = self.ndf_base * ndf_mult
                # Head is fixed and goes first
                if idx == 0:
                    is_head, resize, is_extra = True, True, False
                # Extra layers before standard layers
                elif idx <= self.num_extra_layers:
                    is_head, resize, is_extra = False, False, True
                # Last standard layer has no resize
                elif idx == len(ndf_mult_list) - 1:
                    is_head, resize, is_extra = False, False, False
                # Standard layers
                else:
                    is_head, resize, is_extra = False, True, False
                
                h = self._d_residual_block(h, n_ch, idx=idx, is_training=is_training, resize=resize, is_head=is_head)
                print("\t DResBlock: id={}, out_shape={}, resize={}, is_extra={}"
                    .format(idx, h.shape.as_list(), resize, is_extra))

            h = tf.nn.relu(h)
            h = tf.reduce_sum(h, axis=[1, 2])  # Global sum pooling
            last_feature_map = h
            adv_out = snlinear(h, 1, 'main_steam_out', update_collection=update_collection)

            # Projection Discriminator
            if y is not None:
                h_num_ch = self.ndf_base * ndf_mult_list[-1]
                y_emb = snlinear(y, h_num_ch, 'y_emb', update_collection=update_collection)
                proj_out = tf.reduce_sum(y_emb * h, axis=1, keepdims=True)
            else:
                proj_out = 0
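            # The projection term y^T V phi(x) follows the projection
            # discriminator (Miyato & Koyama, 2018); it is added to the
            # unconditional logit below.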

            out = adv_out + proj_out
            
            return out, last_feature_map
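

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline):
# assumes TF 1.x graph mode and that the `ops` / `models.model_base` modules
# are importable. The config values below are made-up placeholders that mirror
# the keys read in __init__; the label dimension of 10 is hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = {
        "model_params": {"ndf_base": 64, "d_extra_layers": 0},
        "data_params": {"macro_patch_size": [64, 64]},
    }
    discriminator = DiscriminatorBuilder(example_config)
    images = tf.placeholder(tf.float32, [None, 64, 64, 3], name="images")
    labels = tf.placeholder(tf.float32, [None, 10], name="labels")  # one-hot class labels
    adv_logits, features = discriminator.forward(images, y=labels, is_training=True)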