import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib.cudnn_rnn import CudnnLSTM
from tensorflow.contrib.cudnn_rnn.python.layers.cudnn_rnn import CUDNN_RNN_BIDIRECTION
import os
from tasks.acp.data import RealData


class Parser:
    def parse(self, x, context, is_training):
        raise NotImplementedError()

    def restore(self):
        """
        Must return a tuple of (scope, restore_file_path).
        """
        raise NotImplementedError()


class NoOpParser(Parser):
    def restore(self):
        return None

    def parse(self, x, context, is_training):
        return x


class OptionalParser(Parser):
    def __init__(self, delegate: Parser, bs, seq_out, n_out, eos_idx):
        self.eos_idx = eos_idx
        self.n_out = n_out
        self.seq_out = seq_out
        self.bs = bs

        self.delegate = delegate

    def restore(self):
        return self.delegate.restore()

    def parse(self, x, context, is_training):
        parsed = self.delegate.parse(x, context, is_training)

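        # Give the model an explicit "no answer" option: a learned scalar logit
        # computed from the context vector is added to the EOS character at
        # every output position, on top of the delegate's logits.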
        empty_answer = tf.constant(self.eos_idx, tf.int32, shape=(self.bs, self.seq_out))
        empty_answer = tf.one_hot(empty_answer, self.n_out)  # (bs, seq_out, n_out)

        logit_empty = layers.fully_connected(context, 1, activation_fn=None)  # (bs, 1)

        return parsed + tf.reshape(logit_empty, (self.bs, 1, 1)) * empty_answer


class AmountParser(Parser):
    """
    You should pre-train this parser to parse amounts otherwise it's hard to learn jointly.
    """
    seq_in = RealData.seq_in
    seq_out = RealData.seq_amount
    n_out = len(RealData.chars)
    scope = 'parse/amounts'

    def __init__(self, bs):
        os.makedirs("./snapshots/amounts", exist_ok=True)
        self.bs = bs

    def restore(self):
        return self.scope, "./snapshots/amounts/best"

    def parse(self, x, context, is_training):
        with tf.variable_scope(self.scope):
            # Input RNN

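            # CudnnLSTM consumes time-major input (seq, batch, features); the
            # bidirectional direction doubles the output feature size to 256.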
            in_rnn = CudnnLSTM(1, 128, direction=CUDNN_RNN_BIDIRECTION, name="in_rnn")
            h_in, _ = in_rnn(tf.transpose(x, [1, 0, 2]))
            h_in = tf.reshape(tf.transpose(h_in, [1, 0, 2]), (self.bs, self.seq_in, 1, 256))  # (bs, seq_in, 1, 256); 2 * 128 from the bidirectional RNN

            # Output RNN
            out_input = tf.zeros((self.seq_out, self.bs, 1))  # time-major dummy decoder input; teacher forcing could be used here instead
            out_rnn = CudnnLSTM(1, 128, name="out_rnn")
            h_out, _ = out_rnn(out_input)
            h_out = tf.reshape(tf.transpose(h_out, [1, 0, 2]), (self.bs, 1, self.seq_out, 128))  # (bs, 1, seq_out, 128)

            # Bahdanau attention
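            # Additive attention: score every (input position, output position)
            # pair, then softmax over the input axis so each output step gets a
            # distribution over input tokens.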
            att = tf.nn.tanh(layers.fully_connected(h_out, 128, activation_fn=None) + layers.fully_connected(h_in, 128, activation_fn=None))
            att = layers.fully_connected(att, 1, activation_fn=None)  # (bs, seq_in, seq_out, 1)
            att = tf.nn.softmax(att, axis=1)  # (bs, seq_in, seq_out, 1)

            attended_h = tf.reduce_sum(att * h_in, axis=1)  # (bs, seq_out, 256)

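            # Pointer-generator switch: p_gen gates the generated vocabulary
            # distribution, p_copy = 1 - p_gen gates the distribution copied
            # from the input via the attention weights.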
            p_gen = layers.fully_connected(attended_h, 1, activation_fn=tf.nn.sigmoid)  # (bs, seq_out, 1)
            p_copy = (1 - p_gen)

            # Generate
            gen = layers.fully_connected(attended_h, self.n_out, activation_fn=None)  # (bs, seq_out, n_out)
            gen = tf.reshape(gen, (self.bs, self.seq_out, self.n_out))

            # Copy
            copy = tf.log(tf.reduce_sum(att * tf.reshape(x, (self.bs, self.seq_in, 1, self.n_out)), axis=1) + 1e-8)  # (bs, seq_out, n_out); 1e-8 avoids log(0)

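            # Blend the two heads per output position with the learned switch.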
            output_logits = p_copy * copy + p_gen * gen
            return output_logits


class DateParser(Parser):
    """
    You should pre-train this parser to parse dates; otherwise it is hard to learn jointly.
    """
    seq_out = RealData.seq_date
    n_out = len(RealData.chars)
    scope = 'parse/date'

    def __init__(self, bs):
        os.makedirs("./snapshots/dates", exist_ok=True)
        self.bs = bs

    def restore(self):
        return self.scope, "./snapshots/dates/best"

    def parse(self, x, context, is_training):
        with tf.variable_scope(self.scope):
            for i in range(4):
                x = tf.layers.conv1d(x, 128, 3, padding="same", activation=tf.nn.relu)  # (bs, 128, 128) for seq_in = 128; conv preserves length
                x = tf.layers.max_pooling1d(x, 2, 2)  # pooling halves the length each pass: 128 -> 64 -> 32 -> 16 -> 8
            x = tf.reduce_sum(x, axis=1)  # (bs, 128)

            x = tf.concat([x, context], axis=1)  # (bs, 256)
            for i in range(3):
                x = layers.fully_connected(x, 256)

            x = layers.dropout(x, is_training=is_training)
            x = layers.fully_connected(x, self.seq_out * self.n_out, activation_fn=None)
            return tf.reshape(x, (self.bs, self.seq_out, self.n_out))
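

# A minimal smoke-test sketch, assuming TF1-style graph mode and placeholder
# shapes; eos_idx=0 and the context size of 128 are illustrative assumptions,
# and running the graph needs a CUDA device for AmountParser's Cudnn ops.
if __name__ == "__main__":
    bs = 4
    n_out = len(RealData.chars)
    x = tf.placeholder(tf.float32, (bs, RealData.seq_in, n_out))
    context = tf.placeholder(tf.float32, (bs, 128))

    # Wrap the amount parser so the model can also emit an "empty" answer.
    parser = OptionalParser(AmountParser(bs), bs, RealData.seq_amount, n_out, eos_idx=0)
    logits = parser.parse(x, context, is_training=False)
    print(logits)  # Tensor of shape (bs, seq_amount, n_out)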