"""Generate language using XLNet"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np
import os
import re

from tqdm import tqdm
import absl.logging as _logging  # pylint: disable=unused-import
import tensorflow as tf
import sentencepiece as spm

import model_utils
from data_utils import CLS_ID, special_symbols, EOD_ID

import xlnet
from prepro_utils import preprocess_text, encode_ids



EOP_ID = special_symbols["<eop>"]

parser = argparse.ArgumentParser()
# Model
parser.add_argument("--model_config_path", default=None,
                    help="Model config path.", type=str)
parser.add_argument("--clamp_len", default=-1,
                    help="Clamp length", type=int)
parser.add_argument("--same_length", default=False,
                    help="Same length attention", action='store_true')

# Data and memory
parser.add_argument("--batch_size", default=1, help='batch size', type=int)
parser.add_argument("--max_mem_length", default=128,
                    help="Max sequence length for cached hidden states"
                    " which each predicted token is conditioned upon"
                    ". Directly increases the memory requirement", type=int)
parser.add_argument("--uncased", default=False,
                    help="Use uncased inputs or not.", action='store_true')

# I/O paths
parser.add_argument("--init_checkpoint", default=None,
                    help="checkpoint path for initializing the model. "
                    "Could be a pretrained model or a finetuned model.")
parser.add_argument("--spiece_model_file", default="",
                    help="Sentence Piece model path.")
parser.add_argument("--input_file", default="",
                    help="File containing prompts separated by empty new line "
                    "for conditional sampling")

# prediction
parser.add_argument("--num_samples", default=1,
                    help="Number of samples to predict per instance", type=int)
parser.add_argument(
    "--interactive",
    default=False,
    help="Flag for interactive prediction through command line",
    action='store_true')
parser.add_argument(
    "--unconditional",
    default=False,
    help="Prints samples wihtout any prompt",
    action='store_true')
parser.add_argument(
    "--top_p",
    default=0,
    help="Top-p coverage to use. Set 0 to use top_k sampling",
    type=float)
parser.add_argument(
    "--top_k",
    default=40,
    help="Top-k sampling strategy parameter. Use only when top-p is zero. Set"
    "-1 to use all the samples",
    type=int)
parser.add_argument("--temperature", default=1,
                    help="Scaling factor for logits", type=int)
parser.add_argument("--num_toks_pred", default=1024,
                    help="Number of tokens to predict", type=int)
parser.add_argument("--bidirectional_eachstep", help="Compute bidirectional"
                    "attention every step. Consumes a lot of time but better results",
                    action='store_true')

FLAGS = parser.parse_args()


def _create_mask(qlen, mlen):
    """Simple bi-directional attention mask. Attend
    to all token in sequence and memory"""
    klen = qlen + mlen
    return tf.zeros((qlen, klen))

def get_preprocessor(examples, tokenize_fn, pad_ids):
    """
    Input:
    examples: [List[str]] input texts
    tokenize_fn: [function] encodes text into IDs
    Output:
    tf input features
    """
    def generator():
        for example in examples:
            tokens = tokenize_fn(example)
            yield pad_ids + tokens

    return generator


def get_input_dataset(preprocessor):
    """Returns tf.data.Dataset for input"""
    batch_size = FLAGS.batch_size
    max_mem_length = FLAGS.max_mem_length

    def mask(ids):
        example = {'input_k': ids}
        input_k = example['input_k'][-max_mem_length:]
        seq_len = tf.shape(input_k)[0]
        input_mask = tf.tile(
            tf.convert_to_tensor(
                [0],
                dtype=tf.float32),
            [seq_len])
        pad_len = tf.maximum(0, max_mem_length - seq_len)
        pad_tensor = tf.concat([[[pad_len]], [[0]]], axis=-1)
        input_k = tf.pad(input_k, pad_tensor, constant_values=0)
        input_mask = tf.pad(input_mask, pad_tensor, constant_values=1)
        example['input_mask'] = input_mask
        example['input_k'] = input_k
        example['seg_id'] = tf.convert_to_tensor([0] * max_mem_length)
        return example

    dataset = tf.data.Dataset.from_generator(preprocessor,
                                             output_types=tf.int32)
    dataset = dataset.map(mask)

    dataset = dataset.batch(batch_size,
                            drop_remainder=False)
    dataset.prefetch(1)
    return dataset



def inputs_and_mask(latest_tokens, batch_size):
    """Computes inputs and masks for prediction loop.
    A dummy token ([CLS]) is appended at the at of the previous
    tokens

    Input:
    latest_tokens: Tensor [batch_size,1] or None
            If None then last dimension is 1 in the returned tensors

    output:
    input_k: [batch_size,2] latest_tokens with a dummy
            token appened at the end of the sequence
    seg_id: [batch_size,2]
    attn_masks: [batch_size,2,2]
    input_q: [batch_size,2]
            masks the tokens to predict. In this case the last token
    """
    input_k = tf.tile([[CLS_ID]], [batch_size, 1])
    seg_id = tf.tile([[0]], [batch_size, 1])
    input_q = tf.tile([[1]], [batch_size, 1])

    if latest_tokens is not None:
        input_k = tf.concat([latest_tokens, input_k], axis=-1)
        seg_id = tf.tile(seg_id, [1, 2])
        input_q_0 = tf.tile([[0]], [batch_size, 1])
        input_q = tf.concat([input_q_0, input_q], axis=-1)
        target_mapping = tf.tile(tf.constant(
            [[[0], [1]]], dtype=tf.float32), [1, 1, batch_size])
        attn_masks = tf.convert_to_tensor([[0, 1], [0, 1]], dtype=tf.float32)
    else:
        attn_masks = tf.convert_to_tensor([[1]], dtype=tf.float32)
        target_mapping = tf.tile(tf.constant(
            [[[1]]], dtype=tf.float32), [1, 1, batch_size])

    attn_masks = tf.tile(attn_masks[None, :, :], [batch_size, 1, 1])
    input_q = tf.cast(input_q, tf.float32)

    return input_k, seg_id, attn_masks, input_q, target_mapping


def get_logits(xlnet_model, xlnet_config):
    """Builds the graph for calculating the final logits"""
    lookup_table = xlnet_model.get_embedding_table()
    tie_weight = True

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        initializer = xlnet_model.get_initializer()
        hidden = xlnet_model.get_sequence_output()[-1:, :, :]
        n_token = xlnet_config.n_token
        d_model = xlnet_config.d_model

        with tf.variable_scope('lm_loss'):
            if tie_weight:
                assert lookup_table is not None, \
                    'lookup_table cannot be None for tie_weight'
                softmax_w = lookup_table
            else:
                softmax_w = tf.get_variable(
                    'weight', [
                        n_token, d_model], dtype=hidden.dtype, initializer=initializer)

            softmax_b = tf.get_variable('bias', [n_token], dtype=hidden.dtype,
                                        initializer=tf.zeros_initializer())

            logits = tf.einsum('ibd,nd->ibn', hidden, softmax_w) + softmax_b

    return logits


def sampling_strategy():
    """Based on flags return either top_k or
    top_p strategy."""
    if FLAGS.top_p != 0:
        return 'top_p'

    return 'top_k'


def sample_token(logits):
    """
    Inputs:
    logits: tf.Tensor([batch_size,len,num_tokens])
    Outpus:
    samples: tf.Tensor([batch_size,len,1])
    """
    # credits: https://github.com/nshepperd/gpt-2

    logits /= FLAGS.temperature

    batch_size = tf.shape(logits)[0]
    seq_len = tf.shape(logits)[1]
    num_toks = tf.shape(logits)[2]

    if sampling_strategy() == 'top_p':
        logits_sorted = tf.sort(logits,
                                direction="DESCENDING",
                                axis=-1)
        probs = tf.nn.softmax(logits_sorted, axis=-1)
        cum_probs = tf.math.cumsum(probs,
                                   axis=-1,
                                   exclusive=True)
        logits_masked = tf.where(cum_probs < FLAGS.top_p,
                                 logits_sorted,
                                 tf.ones_like(logits_sorted) * 100)
        min_logits = tf.reduce_min(logits_masked, axis=-1)

        logits_masked = tf.where(logits < min_logits,
                                 tf.ones_like(logits) * -1e10,
                                 logits)

    elif sampling_strategy() == "top_k":
        if FLAGS.top_k != 0:
            values, _ = tf.nn.top_k(logits, k=FLAGS.top_k)
            min_values = values[:, :, -1:]
            logits_masked = tf.where(
                logits < min_values,
                tf.ones_like(logits, dtype=logits.dtype) * -1e10,
                logits,
            )
    else:
        raise NotImplementedError("Invalid sampling strategy")

    logits_masked = tf.reshape(logits_masked, (-1, num_toks))

    samples = tf.random.categorical(logits_masked,
                                    num_samples=1,
                                    dtype=tf.int32)

    probs = tf.nn.softmax(tf.reshape(logits, (-1, num_toks)), axis=-1)
    confidences = tf.gather_nd(params=probs, batch_dims=1, indices=samples)

    return tf.reshape(samples, (batch_size, seq_len, 1)),\
        tf.reshape(confidences, (batch_size, seq_len, 1))

def prediction_graph_memory():
    """Gets features and
    return predicted tokens)
    features: Dict[str:tf.train.features] Contains following features:
              input_k
              seg_id
              input_mask
    """

    features = {
        "input_k": tf.placeholder(tf.int32, (None, None)),
        "seg_id": tf.placeholder(tf.int32, (None, None)),
        "input_mask": tf.placeholder(tf.float32, (None, None))
    }

    # Building prediction graph
    # Transforming features for batch channel on last axis
    inp = tf.transpose(features["input_k"], [1, 0])
    seg_id = tf.transpose(features["seg_id"], [1, 0])
    inp_mask = tf.transpose(features["input_mask"], [1, 0])

    # Model config
    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(False, True, FLAGS)
    run_config.mem_len = FLAGS.max_mem_length

    perm_mask = _create_mask(tf.shape(inp)[0], 0)[:, :, None]
    # Getting the hidden states for the prompts
    xlnet_model = xlnet.XLNetModel(
        xlnet_config=xlnet_config,
        run_config=run_config,
        input_ids=inp,
        seg_ids=seg_id,
        input_mask=inp_mask,
        perm_mask=perm_mask)

    # getting memory
    mems = xlnet_model.get_new_memory()

    latest_tokens = None
    prev_tokens = None
    prev_confs = None
    batch_size = tf.shape(mems[0])[1]

    def cond(*_):
        """Dummy condition since we stop based on iteration"""
        return True

    def body(mems, latest_tokens, mem_mask, prev_tokens, prev_confs):
        """The main body of sampling loop.
        mem: cache memory--calculated hidden states
             of previous tokens
        latest_tokens: latest sampled tokens
        mem_mask: masking for setting previous memory zero. Used for padding
        prev_tokens: all the previous tokens including latest_tokens
        prev_confs: confidences of respective tokens in prev_tokens
        """

        # get dummy input token and permutation mask
        input_k, seg_id, perm_mask, input_q, target_mapping = \
            inputs_and_mask(latest_tokens,
                            batch_size)

        input_k = tf.transpose(input_k, (1, 0))
        input_q = tf.transpose(input_q, (1, 0))
        seg_id = tf.transpose(seg_id, (1, 0))
        perm_mask = tf.transpose(perm_mask, (1, 2, 0))
        # Set the hidden state of the padded tokens to be zero[
        for i, mem in enumerate(mems):
            mems[i] = (1 - mem_mask[:, :, None]) * mems[i]
        # Get logits
        xlnet_model = xlnet.XLNetModel(
            xlnet_config=xlnet_config,
            run_config=run_config,
            input_ids=input_k,
            seg_ids=seg_id,
            perm_mask=perm_mask,
            mems=mems,
            input_mask=None,
            inp_q=input_q,
            target_mapping=target_mapping)

        logits = get_logits(xlnet_model, xlnet_config)

        # Getting new memory
        new_mems = xlnet_model.get_new_memory()

        # sample a token
        logits = tf.transpose(logits, (1, 0, 2))
        sampled_tokens, confs = sample_token(logits)
        sampled_tokens = sampled_tokens[:, -1, :]  # Last token
        confs = confs[:, -1, :]  # Last token
        prev_tokens = sampled_tokens if prev_tokens is None \
            else tf.concat([prev_tokens, sampled_tokens], axis=1)
        prev_confs = confs if prev_confs is None \
            else tf.concat([prev_confs, confs], axis=1)
        # Cache the memory of the the last latest_tokens
        if latest_tokens is not None:
            merged_mems = []

            for i, mem in enumerate(mems):
                merged_mems.append(
                    tf.concat([mems[i][1:], new_mems[i][-2:-1]], axis=0))
            mem_mask = tf.concat(
                [mem_mask[1:], tf.zeros_like(mem_mask[:1])], axis=0)
            return [
                merged_mems,
                sampled_tokens,
                mem_mask,
                prev_tokens,
                prev_confs]

        return [mems, sampled_tokens, mem_mask, prev_tokens, prev_confs]

    mems, latest_tokens, mem_mask, prev_tokens, prev_confs = body(
        mems, latest_tokens, inp_mask, prev_tokens, prev_confs)

    args = tf.while_loop(
        cond=cond,
        body=body,
        maximum_iterations=FLAGS.num_toks_pred - 1,
        loop_vars=[mems, latest_tokens, mem_mask, prev_tokens, prev_confs],
        shape_invariants=[
            [tf.TensorShape([None, None, None]) for _ in range(len(mems))],
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None])
        ]
    )

    predicted_tokens, predicted_confs = args[-2:]
    return (predicted_tokens, predicted_confs), features

def prediction_graph_no_memory():
    """Builds graphs and returns prediction and input features.
    Output:
    predictions: Tuple(Tensors) Currently returns sampled tokens and confidences
    features: Dict[str:tf.train.features] Contains following features:
              input_k
              seg_id
              input_mask
    """

    features = {
        "input_k": tf.placeholder(tf.int32, (None, None)),
        "seg_id": tf.placeholder(tf.int32, (None, None)),
        "input_mask": tf.placeholder(tf.float32, (None, None))
    }

    # Building prediction graph
    # Transforming features for batch channel on last axis
    inp = tf.transpose(features["input_k"], [1, 0])
    seg_id = tf.transpose(features["seg_id"], [1, 0])
    inp_mask = tf.transpose(features["input_mask"], [1, 0])

    # Model config
    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(False, True, FLAGS)
    run_config.mem_len = FLAGS.max_mem_length

    perm_mask = _create_mask(tf.shape(inp)[0], 0)[:, :, None]
    # Getting the hidden states for the prompts

    prev_tokens = None
    prev_conf = None
    # target mapping
    seq_len = tf.shape(inp)[0]
    batch_size = tf.shape(inp)[-1]
    target_mapping = tf.concat(
        [tf.zeros((1, seq_len - 1, batch_size)), tf.ones((1, 1, batch_size))], axis=1)

    def cond(*_):
        """Dummy condition since we stop based on iteration"""
        return True

    def recalc(inp, inp_mask, seg_id, perm_mask):
        """Augment the inputs for the new token. Appends 1 row or columns accordingly"""
        input_q = tf.zeros_like(inp, dtype=tf.float32)
        inp = tf.pad(inp, tf.convert_to_tensor(
            [[0, 1], [0, 0]]), constant_values=0)
        inp_mask = tf.pad(inp_mask, tf.convert_to_tensor(
            [[0, 1], [0, 0]]), constant_values=0)
        seg_id = tf.pad(seg_id, tf.convert_to_tensor(
            [[0, 1], [0, 0]]), constant_values=0)
        col = tf.ones(tf.shape(perm_mask)[0:1], dtype=tf.float32)
        perm_mask = tf.concat([perm_mask, col[:, None, None]], axis=1)
        row = tf.concat([tf.zeros(tf.shape(perm_mask)[1:2] - 1, dtype=tf.float32),
                         tf.ones([1], dtype=tf.float32)], axis=0)
        perm_mask = tf.concat([perm_mask, row[None, :, None]], axis=0)
        input_q = tf.pad(input_q, tf.convert_to_tensor(
            [[0, 1], [0, 0]]), constant_values=1)

        return inp[1:], inp_mask[1:], perm_mask[1:, 1:], input_q[1:], seg_id[1:]

    def body(inp, inp_mask, seg_id, perm_mask, prev_tokens, prev_conf):
        """The main body of sampling loop.
        inp: input ids
        inp_mask: input masks for paddings, etc.
        seg_id: segment id. Zeros here.
        perm_mask: permutation mask to pass to transformer
        prev_tokens: all the previous tokens including latest_tokens
        prev_conf: confidences of respective tokens in prev_tokens
        """

        # get dummy input token and permutation mask
        input_k, input_mask, perm_mask, input_q, seg_id = recalc(
            inp, inp_mask, seg_id, perm_mask)
        # Get logits
        xlnet_model = xlnet.XLNetModel(
            xlnet_config=xlnet_config,
            run_config=run_config,
            input_ids=input_k,
            seg_ids=seg_id,
            input_mask=inp_mask,
            perm_mask=perm_mask,
            inp_q=input_q,
            target_mapping=target_mapping)

        logits = get_logits(xlnet_model, xlnet_config)

        # sample a token
        logits = tf.transpose(logits, (1, 0, 2))
        sampled_tokens, confidences = sample_token(logits)
        sampled_tokens = sampled_tokens[:, -1, :]  # Last token
        confidences = confidences[:, -1, :]
        prev_tokens = sampled_tokens if prev_tokens is None \
            else tf.concat([prev_tokens, sampled_tokens], axis=1)
        prev_conf = confidences if prev_conf is None \
            else tf.concat([prev_conf, confidences], axis=1)

        input_k = tf.concat(
            [input_k[:-1], tf.transpose(sampled_tokens, (1, 0))], axis=0)
        perm_mask = _create_mask(tf.shape(input_k)[0], 0)[:, :, None]
        return [input_k, input_mask, seg_id, perm_mask, prev_tokens, prev_conf]

    inp, inp_mask, seg_id, perm_mask, prev_tokens, prev_conf = body(
        inp, inp_mask, seg_id, perm_mask, prev_tokens, prev_conf)
    args = tf.while_loop(
        cond=cond,
        body=body,
        maximum_iterations=FLAGS.num_toks_pred - 1,
        loop_vars=[inp, inp_mask, seg_id, perm_mask, prev_tokens, prev_conf],
        shape_invariants=[
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
        ]
    )
    predicted_tokens, predicted_confs = args[-2:]
    return (predicted_tokens, predicted_confs), features


def main():
    """Main function routine"""

    tf.logging.set_verbosity(tf.logging.INFO)

    # Text encoding
    sp = spm.SentencePieceProcessor()
    sp.Load(FLAGS.spiece_model_file)

    def tokenize_fn(text):
        text = preprocess_text(text, lower=FLAGS.uncased)
        return encode_ids(sp, text)

    # Temporary fix for context problem.
    pad_txt = """In 1991, the remains of Russian Tsar Nicholas II and his family
                (except for Alexei and Maria) are discovered.
                The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
                remainder of the story. 1883 Western Siberia,
                a young Grigori Rasputin is asked by his father and a group of men to perform magic.
                Rasputin has a vision and denounces one of the men as a horse thief. Although his
                father initially slaps him for making such an accusation, Rasputin watches as the
                man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
                the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
                 with people, even a bishop, begging for his blessing. """
    pad_ids = tokenize_fn(pad_txt)
    pad_ids.append(EOD_ID)

    to_special_symbol = {v:k for k,v in special_symbols.items()}
    def parse_ids(toks):
        """Uses sentencepiece to conver to text. Subsitute
        EOP_ID and EOD_ID with new lines, and rest with their names"""
        start = 0
        sent = ""
        for i in range(len(toks)):
          if toks[i] in to_special_symbol:
            if start<i:
              sent+=sp.decode_ids(toks[start:i])
            if toks[i] in [EOD_ID,EOP_ID]:
                replace_by = "\n\n"
            else:
                replace_by = to_special_symbol[toks[i]]

            sent+=f" {replace_by} "
            start=i+1
        if start<len(toks):
          sent+=sp.decode_ids(toks[start:])

        return sent

    if not FLAGS.bidirectional_eachstep:
        prediction_graph = prediction_graph_memory
    else:
        prediction_graph = prediction_graph_no_memory

    predictions, features = prediction_graph()
    gpu_options = tf.GPUOptions(allow_growth=True)
    model_utils.init_from_checkpoint(FLAGS, global_vars=False)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())

        def predict(examples):
            """Given a list of texts in examples
            return the result"""
            preprocessor = get_preprocessor(examples,
                                            tokenize_fn, pad_ids)
            dataset = get_input_dataset(preprocessor)
            example = dataset.make_one_shot_iterator().get_next()

            num_examples = len(examples)
            num_batches = int(np.ceil(num_examples / FLAGS.batch_size))

            for _ in tqdm(range(num_batches)):
                inputs = sess.run(example)
                output, conf = sess.run(
                    predictions, feed_dict={
                        features[k]: v for k, v in inputs.items()})
                for _output,_conf in zip(output,conf):
                    yield _output,_conf

        if FLAGS.unconditional or FLAGS.interactive:
            tf.logging.info("Interactive flag received."
                            " Ignoring input files if any.")
            while True:
                if FLAGS.unconditional:
                    text = ""
                else:
                    text = input("----PROMPT----\n")
                outputs = predict([text] * FLAGS.num_samples)
                for i, (output,_) in enumerate(outputs):
                    out = parse_ids(output.tolist())
                    print("======SAMPLE {}======".format(i))
                    print(out)
                    print("=====================")
                if FLAGS.unconditional:
                    break
        else:
            assert FLAGS.input_file!="", "Please provide either an"\
            " input file or set interactive flag for command line input"
            assert os.path.exists(FLAGS.input_file), FLAGS.input_file+\
            " does not exists"

            with open(FLAGS.input_file) as f:
                texts = []
                text = ""
                for line in f:
                    if line.strip()=="":
                        if text!="":
                            # Removing the last <eop> of prompt
                            # since it is not desired
                            if text.endswith("<eop>"):
                                text=text[:-5]
                            texts.extend([text]*FLAGS.num_samples)
                            text=""
                        continue
                    text+=re.sub(r'\n','<eop>',line)
                if text!="":
                    texts.extend([text]*FLAGS.num_samples)

            tf.logging.info("Got %s lines in the input file",
                            len(texts)//FLAGS.num_samples)
            tf.logging.info("Sampling each line %s times",FLAGS.num_samples)

            outputs = iter(predict(texts))
            with open(os.path.join(FLAGS.input_file+".xlnet"),'w') as f:
                for i in range(0,len(texts),FLAGS.num_samples):
                    f.write("\n======Example {}=================\n".format(i))
                    f.write(texts[i])
                    for j in range(FLAGS.num_samples):
                        output,_ = next(outputs)
                        out = parse_ids(output.tolist())
                        f.write("\n======Example {} SAMPLE {}======\n".format(i,j))
                        f.write(out)
                        f.write("\n==================================\n")



if __name__ == "__main__":

    # Fixed flags
    FLAGS.use_tpu = False
    FLAGS.use_bfloat16 = False
    FLAGS.dropout = 0
    FLAGS.dropatt = 0
    FLAGS.init = "normal"
    FLAGS.init_std = 0.02
    FLAGS.init_range = 0.1
    print("Args: {}".format(vars(FLAGS)))
    main()