import tensorflow as tf

import VAE
import MLP


# The statistical distribution library tf.distributions was introduced in
# TF 1.3. On older releases the same API lives in tf.contrib.distributions,
# so fall back to it when the top-level attribute is missing.
try:
    distributions = tf.distributions
    kl_divergence = tf.distributions.kl_divergence
except AttributeError:
    # Only a missing attribute means "old TensorFlow"; any other error
    # (including KeyboardInterrupt) must propagate instead of being hidden.
    distributions = tf.contrib.distributions
    kl_divergence = tf.contrib.distributions.kl_divergence


class SCAN(object):
    """Two-stage trainer: a beta-VAE on images, then SCAN on symbols.

    Stage one fits the image encoder/decoder pair; stage two fits the symbol
    encoder/decoder pair while pulling the symbol posterior towards the image
    posterior through an additional KL term.
    """

    def __init__(self, cfg):
        """Store the configuration and build the four sub-networks.

        Args:
            cfg: configuration object. ``train`` reads ``learning_rate``,
                ``epoch``, ``beta_vae``, ``beta_scan`` and ``lambda_scan``
                from it.
        """
        self.cfg = cfg
        self.img_encoder = VAE.Encoder()
        self.img_decoder = VAE.Decoder()
        self.sym_encoder = MLP.Encoder()
        self.sym_decoder = MLP.Decoder()
        # Created lazily by train() so that callers may assign their own
        # session to this attribute before training starts.
        self.sess = None

    def train(self):
        """Build both training graphs and run them one after the other.

        Each stage runs ``cfg.epoch`` optimisation steps with its own Adam
        optimizer. A ``tf.Session`` is created on first use when
        ``self.sess`` has not been set by the caller.
        """
        img, sym = self.read_data_sets()

        if self.sess is None:
            self.sess = tf.Session()

        with tf.variable_scope("beta_VAE"):
            img_q_mu, img_q_sigma = self.img_encoder(img)
            img_z = distributions.Normal(img_q_mu, img_q_sigma)
            # One latent draw per posterior. Passing a sample count here
            # would prepend an extra axis to the sample.
            # NOTE(review): assumes the encoder outputs already carry the
            # batch dimension -- confirm against VAE.Encoder.
            img_gen = self.img_decoder(img_z.sample())

            # NOTE(review): mean squared error is assumed as the
            # reconstruction measure; swap in the likelihood that matches
            # VAE.Decoder's output if it differs.
            img_reconstruct_error = tf.reduce_mean(
                tf.square(img_gen - tf.cast(img, img_gen.dtype)))

            # Unit Gaussian prior with the posterior's own shape and dtype.
            img_z_prior = distributions.Normal(
                tf.zeros_like(img_q_mu), tf.ones_like(img_q_sigma))
            KL_divergence = tf.reduce_mean(kl_divergence(img_z, img_z_prior))
            KL_divergence = self.cfg.beta_vae * KL_divergence

            # The optimizer minimises, so the KL penalty is added: a
            # subtracted KL term would be driven up instead of down.
            loss = img_reconstruct_error + KL_divergence

        # train beta VAE
        self._fit(loss)

        with tf.variable_scope("SCAN"):
            sym_q_mu, sym_q_sigma = self.sym_encoder(sym)
            sym_z = distributions.Normal(sym_q_mu, sym_q_sigma)
            sym_gen = self.sym_decoder(sym_z.sample())

            # NOTE(review): same squared-error assumption as for images.
            sym_reconstruct_error = tf.reduce_mean(
                tf.square(sym_gen - tf.cast(sym, sym_gen.dtype)))

            sym_z_prior = distributions.Normal(
                tf.zeros_like(sym_q_mu), tf.ones_like(sym_q_sigma))
            beta_KL_divergence = tf.reduce_mean(
                kl_divergence(sym_z, sym_z_prior))
            beta_KL_divergence = self.cfg.beta_scan * beta_KL_divergence

            # The image posterior is a fixed target in this stage: blocking
            # its gradients keeps the second optimizer from modifying the
            # image encoder trained above.
            img_z_target = distributions.Normal(
                tf.stop_gradient(img_q_mu), tf.stop_gradient(img_q_sigma))
            lambda_KL_divergence = tf.reduce_mean(
                kl_divergence(img_z_target, sym_z))

            loss = sym_reconstruct_error + beta_KL_divergence
            loss += self.cfg.lambda_scan * lambda_KL_divergence

        # train SCAN
        self._fit(loss)

    def _fit(self, loss):
        """Minimise ``loss`` with Adam for ``cfg.epoch`` steps."""
        optimizer = tf.train.AdamOptimizer(self.cfg.learning_rate)
        train_op = optimizer.minimize(loss)
        # minimize() creates optimizer variables of its own, so the
        # initialisation has to happen after it.
        self._initialize_pending_variables()

        for _ in range(self.cfg.epoch):
            self.sess.run(train_op)

    def _initialize_pending_variables(self):
        """Initialise only the global variables that are still uninitialised.

        Variables that already hold values (for example those trained in
        the first stage) are left untouched.
        """
        pending = set(self.sess.run(tf.report_uninitialized_variables()))
        fresh = [
            var for var in tf.global_variables()
            if var.op.name.encode("utf-8") in pending
        ]
        if fresh:
            self.sess.run(tf.variables_initializer(fresh))

    def inference(self):
        """Not implemented yet; does nothing."""
        pass

    def read_data_sets(self):
        """
        Returns:
            data queues of image and symbol.

        The current implementation is a stub that returns two empty lists.
        """
        img, sym = [], []

        return img, sym