import tensorflow as tf

from ..model_helper import TrainModel, EvalModel, InferModel
from . import taware_iterators
from .taware_model import TopicAwareSeq2SeqModel
from thred.util import vocab


def create_train_model(hparams, scope=None, num_workers=1, jobid=0, extra_args=None):
    """Create train graph, model, and iterator."""
    train_file = hparams.train_data

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "train"):
        vocab_table = vocab.create_vocab_table(hparams.vocab_file)

        dataset = tf.data.TextLineDataset(train_file)
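        # Scalar fed at iterator-initialization time; a resumed run can skip
        # the examples it already consumed in the current epoch.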
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

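        # The iterator buckets examples by length (num_buckets) to reduce
        # padding waste, and shards the data across workers via
        # num_shards/shard_index for distributed training.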
        iterator = taware_iterators.get_iterator(
            dataset,
            vocab_table,
            batch_size=hparams.batch_size,
            num_buckets=hparams.num_buckets,
            topic_words_per_utterance=hparams.topic_words_per_utterance,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder,
            num_shards=num_workers,
            shard_index=jobid)

        # Note: One can set model_device_fn to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        model_device_fn = None
        if extra_args:
            model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = TopicAwareSeq2SeqModel(
                mode=tf.contrib.learn.ModeKeys.TRAIN,
                iterator=iterator,
                params=hparams,
                scope=scope)

    return TrainModel(
        graph=graph,
        model=model,
        iterator=iterator,
        skip_count_placeholder=skip_count_placeholder)
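

# A minimal, never-called usage sketch of driving the returned TrainModel.
# It assumes an NMT-style ``model.train(sess)`` step, which this module does
# not define; treat that call as an assumption, not a guaranteed API.
def _example_train_loop(hparams):
    train_model = create_train_model(hparams)
    with train_model.graph.as_default(), tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        # Feed 0 to start at the top of the data; a positive count resumes
        # mid-epoch by skipping already-consumed examples.
        sess.run(train_model.iterator.initializer,
                 feed_dict={train_model.skip_count_placeholder: 0})
        while True:
            try:
                train_model.model.train(sess)  # assumed NMT-style step API
            except tf.errors.OutOfRangeError:
                break  # epoch exhausted; re-initialize the iterator to go on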


def create_eval_model(hparams, scope=None):
    """Create train graph, model, src/tgt file holders, and iterator."""
    vocab_file = hparams.vocab_file
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        vocab_table = vocab.create_vocab_table(vocab_file)
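        # The evaluation file is fed at iterator-initialization time, so the
        # same graph can be reused across different dev/test files.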
        eval_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)

        eval_dataset = tf.data.TextLineDataset(eval_file_placeholder)
        iterator = taware_iterators.get_iterator(
            eval_dataset,
            vocab_table,
            batch_size=hparams.batch_size,
            num_buckets=hparams.num_buckets,
            topic_words_per_utterance=hparams.topic_words_per_utterance,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len)
        model = TopicAwareSeq2SeqModel(
            mode=tf.contrib.learn.ModeKeys.EVAL,
            iterator=iterator,
            params=hparams,
            scope=scope,
            log_trainables=False)
    return EvalModel(
        graph=graph,
        model=model,
        eval_file_placeholder=eval_file_placeholder,
        iterator=iterator)
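

# Illustrative, never-called sketch of running the returned EvalModel. It
# assumes an NMT-style ``model.eval(sess)`` step and an ``hparams.dev_data``
# path, neither of which is defined in this module.
def _example_eval_run(hparams):
    eval_model = create_eval_model(hparams)
    with eval_model.graph.as_default(), tf.Session() as sess:
        sess.run(tf.tables_initializer())
        # Restoring trained weights from a checkpoint is omitted here.
        sess.run(eval_model.iterator.initializer,
                 feed_dict={eval_model.eval_file_placeholder: hparams.dev_data})
        while True:
            try:
                eval_model.model.eval(sess)  # assumed NMT-style step API
            except tf.errors.OutOfRangeError:
                break  # evaluation file fully consumed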


def create_infer_model(hparams, scope=None):
    """Create inference model."""
    graph = tf.Graph()
    vocab_file = hparams.vocab_file

    with graph.as_default(), tf.container(scope or "infer"):
        vocab_table = vocab.create_vocab_table(vocab_file)
        reverse_vocab_table = vocab.create_rev_vocab_table(vocab_file)

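        # Source utterances and the batch size are fed directly at run time,
        # so inference operates on in-memory strings rather than input files.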
        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
        iterator = taware_iterators.get_infer_iterator(
            src_dataset,
            vocab_table,
            batch_size=batch_size_placeholder,
            topic_words_per_utterance=hparams.topic_words_per_utterance,
            src_max_len=hparams.src_max_len)
        model = TopicAwareSeq2SeqModel(
            mode=tf.contrib.learn.ModeKeys.INFER,
            iterator=iterator,
            params=hparams,
            rev_vocab_table=reverse_vocab_table,
            scope=scope,
            log_trainables=False)
    return InferModel(
        graph=graph,
        model=model,
        src_placeholder=src_placeholder,
        batch_size_placeholder=batch_size_placeholder,
        iterator=iterator)
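

# Illustrative, never-called sketch of running the returned InferModel on a
# list of raw utterance strings. ``model.decode(sess)`` is an assumed
# NMT-style API, not something defined in this module.
def _example_infer_run(hparams, utterances):
    infer_model = create_infer_model(hparams)
    with infer_model.graph.as_default(), tf.Session() as sess:
        sess.run(tf.tables_initializer())
        # Restoring trained weights from a checkpoint is omitted here.
        sess.run(infer_model.iterator.initializer,
                 feed_dict={infer_model.src_placeholder: utterances,
                            infer_model.batch_size_placeholder: len(utterances)})
        while True:
            try:
                infer_model.model.decode(sess)  # assumed decode API
            except tf.errors.OutOfRangeError:
                break  # all fed utterances decoded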