"""
Module for dataset creation.
Usage:
python3 dataset.py --command create_dataset --input data/train-v1.1.json --output data/train.bin,data/validation.bin --split 0.8,0.2
python3 dataset.py --command create_vocab --input data/train-v1.1.json --output data/vocab

Modified ver. of https://github.com/tensorflow/models/blob/master/textsum/data.py
"""
import glob
import json
import struct
from random import shuffle

import tensorflow as tf
from tensorflow.core.example import example_pb2
from spacy.en import English

nlp = English()

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('command', 'create_dataset',
                           'Either create_vocab or create_dataset. '
                           'Specify FLAGS.input and FLAGS.output accordingly.')
tf.app.flags.DEFINE_string('input', '', 'path to input data')
tf.app.flags.DEFINE_string('output', '', 'comma separated paths to files')
tf.app.flags.DEFINE_string('split', '', 'comma separated fractions of training/validation')

# special tokens
PARAGRAPH_START = '<p>'
PARAGRAPH_END = '</p>'
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
UNKNOWN_TOKEN = '<unk>'
PAD_TOKEN = '<pad>'

# special tokens for OOVs
WORD_BEGIN = '<b>'
WORD_CONTINUE = '<c>'
WORD_END = '<e>'
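# Example (illustrative): with the markers above, an out-of-vocabulary word such
# as "zyzzyva" is spelled out character by character as
#   <b> <c> z <c> y <c> z <c> z <c> y <c> v <c> a <e>
# so it can be recovered later (see tokens_to_ids and ids_to_tokens below).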


class Vocab:
  """Vocab class for mapping token and ids."""

  def __init__(self, file_path, max_size):
    self._token_to_id = {}
    self._id_to_token = {}
    self._size = 0

    with open(file_path, 'rt', encoding='utf-8') as f:
      for line in f:
        tokens = line.split()
        # a token that is itself whitespace is dropped by split(); recover it
        # from the raw line as everything before the separating space
        if len(tokens) == 1:
          count = tokens[0]
          idx = line.index(count)
          t = line[:idx-1]
          tokens = (t, count)
        if len(tokens) != 2:
          continue
        # duplicates
        if tokens[0] in self._token_to_id:
          continue
        if self._size >= max_size:
          tf.logging.warn('Warning! Too many tokens: >%d\n' % max_size)
          break
        self._size += 1
        self._token_to_id[tokens[0]] = self._size
        self._id_to_token[self._size] = tokens[0]

  def __len__(self):
    return self._size

  def tokenToId(self, token):
    if token not in self._token_to_id:
      tf.logging.warn('id not found for token: %s\n' % token)
      return self._token_to_id[UNKNOWN_TOKEN]
    return self._token_to_id[token]

  def idToToken(self, _id):
    if _id not in self._id_to_token:
      tf.logging.warn('token not found for id: %d\n' % _id)
      return UNKNOWN_TOKEN
    return self._id_to_token[_id]
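
# Example vocab file layout expected by Vocab (and produced by create_vocab below),
# one "token count" pair per line; counts here are purely illustrative:
#   <s> 0
#   <unk> 0
#   the 123456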

def create_vocab(input_file, output_file, max_size=200000):
  """Generates vocab from input_file.
     Args:
        input_file: input file path
        output_file: output file path
        max_size: maximum number of vocabulary entries (including special tokens)
  """
  from collections import Counter
  counter = Counter()

  with open(input_file, 'r', encoding='utf-8') as data_file:
    parsed_file = json.load(data_file)
    data = parsed_file['data']

    for datum in data:
      for paragraph in datum['paragraphs']:
        context = nlp(paragraph['context'].lower())
        # count individual characters (needed to spell out OOV tokens) as well as word tokens
        counter.update(context.text)
        counter.update(map(lambda c: c.text, context))
        for qas in paragraph['qas']:
          question = nlp(qas['question'].lower())
          counter.update(question.text)
          counter.update(map(lambda c: c.text, question))

  with open(output_file, 'wt', encoding='utf-8') as f:
    # reserve entries for the special tokens
    special_tokens = [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN,
                      WORD_BEGIN, WORD_CONTINUE, WORD_END]
    for token in special_tokens:
      f.write(token + ' 0\n')
    for token, count in counter.most_common(max_size - len(special_tokens)):
      f.write(token + ' ' + str(count) + '\n')

def create_dataset(input_file, output_files, split_fractions):
  """Generates train/validation files from input_file.
     Args:
        input_file: input file path
        output_files: output file paths, one per split
        split_fractions: train/validation split fractions
  """
  from nltk.tokenize import sent_tokenize

  with open(input_file, 'r', encoding='utf-8') as data_file:
    parsed_file = json.load(data_file)
    data = parsed_file['data']
    len_data = len(data)
    # cumulative split boundaries, e.g. [0.8, 0.2] -> [0, int(0.8*N), int(1.0*N)]
    cumulative = [sum(split_fractions[:i + 1]) for i in range(len(split_fractions))]
    indices = [0] + [int(len_data * c) for c in cumulative]

    # shuffle data by topic
    shuffle(data)

    for i in range(1, len(indices)):
      subset = data[indices[i-1]:indices[i]]
      with open(output_files[i-1], 'wb') as writer:
        for datum in subset:
          for paragraph in datum['paragraphs']:
            context = nlp(paragraph['context']).text

            sentences = sent_tokenize(context)
            context = PARAGRAPH_START + ' '.join([SENTENCE_START + sentence + SENTENCE_END for sentence in sentences]) + PARAGRAPH_END
            context = context.encode('utf-8')

            qas = paragraph['qas']
            for qa in qas:
              question = nlp(qa['question']).text
              answer = nlp(qa['answers'][0]['text']).text  # just take the first answer

              sentences = sent_tokenize(question)
              question = PARAGRAPH_START + ' '.join([SENTENCE_START + sentence + SENTENCE_END for sentence in sentences]) + PARAGRAPH_END
              question = question.encode('utf-8')

              sentences = sent_tokenize(answer)
              answer = PARAGRAPH_START + ' '.join([SENTENCE_START + sentence + SENTENCE_END for sentence in sentences]) + PARAGRAPH_END
              answer = answer.encode('utf-8')

              tf_example = example_pb2.Example()
              tf_example.features.feature['context'].bytes_list.value.extend([context])
              tf_example.features.feature['question'].bytes_list.value.extend([question])
              tf_example.features.feature['answer'].bytes_list.value.extend([answer])
              tf_example_str = tf_example.SerializeToString()
              str_len = len(tf_example_str)
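              # record format: an 8-byte length (struct format 'q', native byte
              # order) followed by the serialized tf.Example; tf_Examples() below
              # reads this layout back when iterating over the data files.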
              writer.write(struct.pack('q', str_len))
              writer.write(struct.pack('%ds' % str_len, tf_example_str))

def snippet_gen(text, start_tok, end_tok, inclusive=False):
  """Generates consecutive snippets between start and end tokens.
     Args:
        text: a string
        start_tok: a string denoting the start of snippets
        end_tok: a string denoting the end of snippets
        inclusive: Whether to include the start/end tokens in the returned snippets.
     Yields:
        String snippets
  """
  cur = 0
  while True:
    try:
      start_p = text.index(start_tok, cur)
      end_p = text.index(end_tok, start_p + 1)
      cur = end_p + len(end_tok)
      if inclusive:
        yield text[start_p:cur]
      else:
        yield text[start_p+len(start_tok):end_p]
    except ValueError:
      # no more snippets; returning ends the generator (PEP 479 disallows
      # raising StopIteration inside a generator)
      return

def to_sentences(paragraph, include_token=False):
  """Takes tokens of a paragraph and returns list of sentences.
     Args:
        paragraph: string, text of paragraph
        include_token: Whether to include the sentence separator tokens in the result.
     Returns:
        List of sentence strings.
  """
  if not isinstance(paragraph, str):
    paragraph = paragraph.decode('utf-8')
  s_gen = snippet_gen(paragraph, SENTENCE_START, SENTENCE_END, include_token)
  return [s for s in s_gen]
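
# Example (illustrative): with the default non-inclusive mode,
#   to_sentences('<p><s>Where is the library?</s><s>It is downtown.</s></p>')
#   returns ['Where is the library?', 'It is downtown.']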

def pad(ids, pad_id, length):
  """Pad or trim list to len length.
     Args:
        ids: list of ints to pad
        pad_id: what to pad with
        length: length to pad or trim to
     Returns:
        ids trimmed or padded with pad_id
  """
  assert pad_id is not None
  assert length is not None

  if len(ids) < length:
    a = [pad_id] * (length - len(ids))
    return ids + a
  else:
    return ids[:length]
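
# Example (illustrative): pad([3, 7], pad_id=0, length=4) returns [3, 7, 0, 0],
# while pad([3, 7, 9], pad_id=0, length=2) trims to [3, 7].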

def tokens_to_ids(text, vocab, pad_len=None, pad_id=None):
  """Get ids corresponding to tokens in text.
  Assumes tokens separated by space.
  Args:
    text: a string
    vocab: TextVocabularyFile object
    pad_len: int, length to pad to
    pad_id: int, token id for pad symbol
  Returns:
    A list of ints representing token ids.
  """
  ids = []
  b = vocab.tokenToId(WORD_BEGIN)
  c = vocab.tokenToId(WORD_CONTINUE)
  e = vocab.tokenToId(WORD_END)
  unk = vocab.tokenToId(UNKNOWN_TOKEN)
  token_iterator = map(lambda x: x.text, nlp(text.lower()))
  for token in token_iterator:
    _id = vocab.tokenToId(token)
    if _id == unk:  # token is OOV
      ids.append(b)
      for character in token:
        ids.append(c)
        ids.append(vocab.tokenToId(character))
      ids.append(e)
    else:  # token is present in vocab
      ids.append(_id)
  if pad_len is not None:
    return pad(ids, pad_id, pad_len)
  return ids

def ids_to_tokens(ids_list, vocab):
  """Get tokens from ids.
     Args:
        ids_list: list of int32
        vocab: Vocab object
     Returns:
        List of tokens corresponding to ids.
  """
  assert isinstance(ids_list, list), '%s is not a list' % ids_list
  answer = []
  tmp = ''
  # iterate through each id and recover any OOV words
  for _id in ids_list:
    token = vocab.idToToken(_id)
    if token == PAD_TOKEN:
      continue
    if token == WORD_BEGIN:
      tmp += token
    elif token == WORD_END:
      tmp = ''.join(tmp.split(WORD_CONTINUE))
      answer.append(tmp[len(WORD_BEGIN):])  # strip the leading <b> marker
      tmp = ''
    elif len(tmp) > 0:
      tmp += token
    else:
      answer.append(token)
  return answer
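
# Example (illustrative): assuming 'the' is in the vocab, 'qubit' is not, and
# each of its characters is,
#   ids_to_tokens(tokens_to_ids('the qubit', vocab), vocab) returns ['the', 'qubit']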

def tf_Examples(data_path, num_epochs=None):
  """Generates tf.Examples from path of data files.
    Binary data format: <length><blob>. <length> represents the byte size
    of <blob>. <blob> is a serialized tf.Example proto. The tf.Example contains
    the context, question and answer text.
  Args:
    data_path: path to tf.Example data files.
    num_epochs: Number of times to go through the data. None means infinite.
  Yields:
    Deserialized tf.Example.
  If multiple files are specified, they are accessed in a random order.
  """
  epoch = 0
  while True:
    if num_epochs is not None and epoch >= num_epochs:
      break
    filelist = glob.glob(data_path)
    assert filelist, 'Empty filelist.'
    shuffle(filelist)
    for f in filelist:
      with open(f, 'rb') as reader:
        while True:
          len_bytes = reader.read(8)
          if not len_bytes:
            break
          str_len = struct.unpack('q', len_bytes)[0]
          example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
          yield example_pb2.Example.FromString(example_str)

    epoch += 1
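
# Example usage (illustrative): iterate over one epoch and decode a feature.
#   for ex in tf_Examples('data/train.bin', num_epochs=1):
#     context = ex.features.feature['context'].bytes_list.value[0].decode('utf-8')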

def main(unused_argv):
  assert FLAGS.command and FLAGS.input and FLAGS.output
  output_files = FLAGS.output.split(',')
  input_file = FLAGS.input

  if FLAGS.command == 'create_dataset':
    assert FLAGS.split

    split_fractions = [float(s) for s in FLAGS.split.split(',')]

    assert len(output_files) == len(split_fractions)

    create_dataset(input_file, output_files, split_fractions)

  elif FLAGS.command == 'create_vocab':
    assert len(output_files) == 1

    create_vocab(input_file, output_files[0])


if __name__ == '__main__':
  tf.app.run()