from abc import ABCMeta, abstractmethod
import tensorflow as tf
import prettytensor as pt
import zutils.tf_math_funcs as tmf
import zutils.pt_utils as ptu
import net_modules.auto_struct.utils as asu
from net_modules import keypoints_2d
import math
from zutils.py_utils import *
import zutils.tf_graph_utils as tgu

import collections


class Factory:
    __metaclass__ = ABCMeta

    def __init__(self, output_channels, options):
        """
        :param output_channels: output_channels for the encoding net
        """
        self.output_channels = output_channels
        self.options = options

        self.structure_as_final_class = True
        self.target_input_size = None

        self.stop_gradient_at_latent_for_structure = False

    def __call__(self, input_tensor, condition_tensor=None, extra_inputs=None):
        _, _, structure_latent, mos = self.structure_encode(
            input_tensor, condition_tensor=condition_tensor, extra_inputs=extra_inputs)
        latent_tensor = structure_latent
        assert self.output_channels == tmf.get_shape(latent_tensor)[1], \
            "wrong output_channels"
        return structure_latent, mos.extra_outputs

    def input_to_heatmap_overall(self, input_tensor, mos):
        # compute shared features
        overall_feature = mos(self.image2sharedfeature(input_tensor))
        # compute raw heatmap
        raw_heatmap = mos(self.image2heatmap(overall_feature))
        if "heatmap_extra" in mos.extra_outputs:
            heatmap_extra = mos.extra_outputs["heatmap_extra"]
        else:
            heatmap_extra = None
        raw_heatmap = mos(call_func_with_ignored_args(
            self.heatmap_postprocess, raw_heatmap, image_tensor=input_tensor, heatmap_extra=heatmap_extra))
        if "heatmap_extra" in mos.extra_outputs:
            heatmap_extra = mos.extra_outputs["heatmap_extra"]
        else:
            heatmap_extra = None
        # normalize heatmap
        heatmap = tf.nn.softmax(raw_heatmap)
        heatmap = mos(call_func_with_ignored_args(
            self.heatmap_postpostprocess, heatmap, image_tensor=input_tensor, heatmap_extra=heatmap_extra
        ))
        return heatmap, overall_feature

    def input_to_heatmap(self, input_tensor, mos, **kwargs):
        heatmap, _ = self.input_to_heatmap_overall(input_tensor, mos, **kwargs)
        return heatmap

    def structure_encode(self, input_tensor, condition_tensor=None, extra_inputs=None):
        if "freeze_encoded_structure" in self.options and rbool(self.options["freeze_encoded_structure"]):
            with pt.defaults_scope(phase=pt.Phase.test):
                return self.structure_encode_(input_tensor, condition_tensor, extra_inputs)
        else:
            return self.structure_encode_(input_tensor, condition_tensor, extra_inputs)

    def structure_encode_(self, input_tensor, condition_tensor=None, extra_inputs=None):
        """Create encoder network.
        """

        input_tensor = self.pad_input_tensor(input_tensor)

        # module output strip

        mos = asu.ModuleOutputStrip()
        mos.extra_outputs["discriminator_remark"] = dict(
            generator_aux_loss=[]
        )

        deterministic_collection = tgu.SubgraphCollectionSnapshots()
        deterministic_collection.sub_snapshot("_old")
        with tf.variable_scope("deterministic"), tf.variable_scope("structure"):

            # augment images (if needed)
            main_batch_size = tmf.get_shape(input_tensor)[0]

            input_tensor_x, aug_cache = mos(self.augment_images(input_tensor))
            network_predefined = ("network_predefined" in aug_cache) and aug_cache["network_predefined"]

        aug_cache["main_batch_size"] = main_batch_size
        mos.extra_outputs["aug_cache"] = aug_cache

        with tf.variable_scope("deterministic"):
            with tf.variable_scope("structure", reuse=True if network_predefined else None):
                heatmap, overall_feature = self.input_to_heatmap_overall(input_tensor_x, mos)
                structure_pack = mos(self.heatmap2structure(heatmap))

            with tf.variable_scope("structure"):
                # postprocess structure
                structure_param_x = mos(call_func_with_ignored_args(
                    self.heatmap2structure_poststep, structure_pack, image_tensor=input_tensor_x
                ))
                # clean up augmented data
                structure_param = mos(call_func_with_ignored_args(
                    self.cleanup_augmentation_structure, structure_param_x, aug_cache=aug_cache,
                    condition_tensor=condition_tensor
                ))

        with tf.variable_scope("deterministic"), tf.variable_scope("structure"):
            mos.extra_outputs["save"]["heatmap"] = heatmap[:main_batch_size]

            # entropy loss to encourage heatmap separation across different channels
            if "heatmap_separation_loss_weight" in self.options and \
                    rbool(self.options["heatmap_separation_loss_weight"]):
                total_heatmap_entropy = keypoints_2d.keypoint_map_depth_entropy_with_real_bg(heatmap)
                separation_loss = total_heatmap_entropy * self.options["heatmap_separation_loss_weight"]
                separation_loss.disp_name = "separation"
                tgu.add_to_aux_loss(separation_loss)

            # register structure_param for storing
            mos.extra_outputs["save"]["structure_param"] = structure_param
            mos.extra_outputs["for_decoder"]["structure_param"] = structure_param

            # structure_param matching
            if extra_inputs is not None and "structure_param" in extra_inputs and \
                    "structure_detection_loss_weight" in self.options and \
                    rbool(self.options["structure_detection_loss_weight"]):
                structure_param_dist = self.structure_param_distance(
                    extra_inputs["structure_param"], tf.stop_gradient(structure_param))
                structure_detection_loss = \
                    self.options["structure_detection_loss_weight"] * tf.reduce_mean(structure_param_dist, axis=0)
                structure_detection_loss.disp_name = "struct_detection"
                tgu.add_to_aux_loss(structure_detection_loss)
                mos.extra_outputs["discriminator_remark"]["generator_aux_loss"].append(structure_detection_loss)

        deterministic_collection.sub_snapshot("structure_deterministic")
        encoded_structure_vars = deterministic_collection["structure_deterministic"].get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES)
        if "freeze_encoded_structure" in self.options and rbool(self.options["freeze_encoded_structure"]):
            tgu.add_to_freeze_collection(encoded_structure_vars)
        if "encoded_structure_lr_mult" in self.options and rbool(self.options["encoded_structure_lr_mult"]):
            for v in encoded_structure_vars:
                v.lr_mult = self.options["encoded_structure_lr_mult"]

        with tf.variable_scope("variational"), tf.variable_scope("structure"):
            structure_latent = mos(self.structure2latent(structure_param))

        if self.structure_as_final_class:
            with tf.variable_scope("deterministic"):
                # use the main batch only
                heatmap = heatmap[:main_batch_size]
                overall_feature = overall_feature[:main_batch_size]

        return overall_feature, heatmap, structure_latent, mos

    def augment_images(self, image_tensor):
        return image_tensor

    def cleanup_augmentation_structure(self, structure_param, aug_cache, condition_tensor=None):
        return structure_param

    def image2sharedfeature(self, image_tensor):
        return image_tensor

    @abstractmethod
    def image2heatmap(self, image_tensor):
        return None

    def heatmap_postprocess(self, heatmap):
        return heatmap

    def heatmap_postpostprocess(self, heatmap):
        return heatmap

    def heatmap2structure_poststep(self, structure_pack):
        return structure_pack

    @abstractmethod
    def heatmap2structure(self, heatmap_tensor):
        return None

    def structure2latent(self, structure_tensor):
        # simply copy the structure as latent
        input_shape = tmf.get_shape(structure_tensor)
        latent_tensor = tf.reshape(structure_tensor, [input_shape[0], -1])
        return latent_tensor

    def structure_param2euclidean(self, structure_param):
        return structure_param

    def structure_param_distance(self, p1, p2):
        batch_size = tmf.get_shape(p1)[0]
        r1 = self.structure_param2euclidean(p1)
        r2 = self.structure_param2euclidean(p2)
        r1 = tf.reshape(r1, [batch_size, -1])
        r2 = tf.reshape(r2, [batch_size, -1])
        return tf.reduce_sum(tf.square(r2-r1), axis=1)

    def pad_input_tensor(self, input_tensor):

        if self.target_input_size is None:
            return input_tensor

        if (
            isinstance(self.target_input_size, collections.Iterable) and
            isinstance(self.target_input_size, collections.Sized)
        ):
            assert len(self.target_input_size) == 2, "wrong target_input_size"
            final_input_size = self.target_input_size
        else:
            final_input_size = [self.target_input_size] * 2

        init_input_size = tmf.get_shape(input_tensor)[1:3]

        assert math.isclose(final_input_size[0]/init_input_size[0], final_input_size[1]/init_input_size[1]), \
            "enlarge ratio should be the same (for the simplicity of other implementation)"

        assert final_input_size[0] >= init_input_size[0] and final_input_size[1] >= init_input_size[1], \
            "target input size should not be smaller the actual input size"

        if init_input_size[0] == final_input_size[0] and init_input_size[1] == final_input_size[1]:
            return input_tensor
        else:
            the_pad_y_begin = (final_input_size[0] - init_input_size[0]) // 2
            the_pad_x_begin = (final_input_size[1] - init_input_size[1]) // 2
            the_padding = [
                [0, 0],
                [the_pad_y_begin, final_input_size[0] - init_input_size[0] - the_pad_y_begin],
                [the_pad_x_begin, final_input_size[1] - init_input_size[1] - the_pad_x_begin],
                [0] * 2,
            ]
            paded_input_tensor = tmf.pad(
                tensor=input_tensor, paddings=the_padding, mode="MEAN_EDGE",
                geometric_axis=[1, 2]
            )
            return paded_input_tensor