# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads the WikiQuestions dataset.

An example consists of question, table. Additionally, we store the processed
columns which store the entries after performing number, date and other
preprocessing as done in the baseline.
columns, column names and processed columns are split into word and number
columns.
lookup answer (or matrix) is also split into number and word lookup matrix
Author: aneelakantan (Arvind Neelakantan)
"""
import math
import os
import re
import numpy as np
import unicodedata as ud
import tensorflow as tf

bad_number = -200000.0  # Sentinel placed in number-column entries that are missing or corrupted.

def is_nan_or_inf(number):
  return math.isnan(number) or math.isinf(number)

def strip_accents(s):
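  """Removes accents by NFKD-normalizing s and dropping combining marks."""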
  u = unicode(s, "utf-8")
  u_new = ''.join(c for c in ud.normalize('NFKD', u) if ud.category(c) != 'Mn')
  return u_new.encode("utf-8")


def correct_unicode(string):
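  """Maps common unicode punctuation to ASCII equivalents and strips most
  remaining high-codepoint characters.
  """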
  string = strip_accents(string)
  string = re.sub("\xc2\xa0", " ", string).strip()
  string = re.sub("\xe2\x80\x93", "-", string).strip()
  #string = re.sub(ur'[\u0300-\u036F]', "", string)
  string = re.sub("‚", ",", string)
  string = re.sub("…", "...", string)
  #string = re.sub("[·・]", ".", string)
  string = re.sub("ˆ", "^", string)
  string = re.sub("˜", "~", string)
  string = re.sub("‹", "<", string)
  string = re.sub("›", ">", string)
  #string = re.sub("[‘’´`]", "'", string)
  #string = re.sub("[“”«»]", "\"", string)
  #string = re.sub("[•†‡]", "", string)
  #string = re.sub("[‐‑–—]", "-", string)
  string = re.sub(ur'[\u2E00-\uFFFF]', "", string)
  string = re.sub("\\s+", " ", string).strip()
  return string


def simple_normalize(string):
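  """Removes citation markers, trailing asterisks, parenthesized year ranges
  and surrounding quotation marks.
  """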
  string = correct_unicode(string)
  # Citations
  string = re.sub("\[(nb ?)?\d+\]", "", string)
  string = re.sub("\*+$", "", string)
  # Year in parenthesis
  string = re.sub("\(\d* ?-? ?\d*\)", "", string)
  string = re.sub("^\"(.*)\"$", "", string)
  return string


def full_normalize(string):
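  """Aggressively normalizes string: drops bracketed and trailing
  parenthesized text, non-ASCII characters and most punctuation, lowercases,
  and maps an empty result to "UNK".
  """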
  #print "an: ", string
  string = simple_normalize(string)
  # Remove trailing info in brackets
  string = re.sub("\[[^\]]*\]", "", string)
  # Remove most unicode characters in other languages
  string = re.sub(ur'[\u007F-\uFFFF]', "", string.strip())
  # Remove trailing info in parenthesis
  string = re.sub("\([^)]*\)$", "", string.strip())
  string = final_normalize(string)
  # Get rid of question marks
  string = re.sub("\?", "", string).strip()
  # Get rid of trailing colons (usually occur in column titles)
  string = re.sub("\:$", " ", string).strip()
  # Get rid of slashes
  string = re.sub(r"/", " ", string).strip()
  string = re.sub(r"\\", " ", string).strip()
  # Replace colons and dashes with spaces
  # Note: need better replacement for this when parsing time
  string = re.sub(r"\:", " ", string).strip()
  string = re.sub("-", " ", string).strip()
  # Convert empty strings to UNK
  # Important to do this last or near last
  if not string:
    string = "UNK"
  return string

def final_normalize(string):
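  """Collapses whitespace, lowercases, and removes escaped newlines, quote
  characters and asterisks.
  """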
  # Remove leading and trailing whitespace
  string = re.sub("\\s+", " ", string).strip()
  # Convert entirely to lowercase
  string = string.lower()
  # Get rid of strangely escaped newline characters
  string = re.sub("\\\\n", " ", string).strip()
  # Get rid of quotation marks
  string = re.sub(r"\"", "", string).strip()
  string = re.sub(r"\'", "", string).strip()
  string = re.sub(r"`", "", string).strip()
  # Get rid of *
  string = re.sub("\*", "", string).strip()
  return string

def is_number(x):
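  """Returns True if x can be parsed as a finite float."""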
  try:
    f = float(x)
    return not is_nan_or_inf(f)
  except (ValueError, TypeError):
    return False


class WikiExample(object):
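  """One question with its answer and the key of the table it refers to."""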

  def __init__(self, question_id, question, answer, table_key):
    self.question_id = question_id
    self.question = question
    self.answer = answer
    self.table_key = table_key
    self.lookup_matrix = []
    self.is_bad_example = False
    self.is_word_lookup = False
    self.is_ambiguous_word_lookup = False
    self.is_number_lookup = False
    self.is_number_calc = False
    self.is_unknown_answer = False


class TableInfo(object):
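  """A table's columns partitioned into word and number columns, together
  with their names, original indices and processed (numeric) versions.
  """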

  def __init__(self, word_columns, word_column_names, word_column_indices,
               number_columns, number_column_names, number_column_indices,
               processed_word_columns, processed_number_columns, orig_columns):
    self.word_columns = word_columns
    self.word_column_names = word_column_names
    self.word_column_indices = word_column_indices
    self.number_columns = number_columns
    self.number_column_names = number_column_names
    self.number_column_indices = number_column_indices
    self.processed_word_columns = processed_word_columns
    self.processed_number_columns = processed_number_columns
    self.orig_columns = orig_columns


class WikiQuestionLoader(object):
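  """Loads the question ids of one data split (train, dev or test)."""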

  def __init__(self, data_name, root_folder):
    self.root_folder = root_folder
    self.data_folder = os.path.join(self.root_folder, "data")
    self.examples = []
    self.data_name = data_name

  def num_questions(self):
    return len(self.examples)

  def load_qa(self):
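    """Reads the split's data file and collects every question id."""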
    data_source = os.path.join(self.data_folder, self.data_name)
    f = tf.gfile.GFile(data_source, "r")
    id_regex = re.compile("\(id ([^\)]*)\)")
    for line in f:
      id_match = id_regex.search(line)
      question_id = id_match.group(1)
      self.examples.append(question_id)

  def load(self):
    self.load_qa()


def is_date(word):
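  """Returns True for dates in the annotation's YYYY-MM-DD format, where
  unknown fields may be masked with "X" or "x" (e.g. "xxxx-03-xx").
  """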
  if not re.search("[a-z0-9]", word, re.IGNORECASE):
    return False
  if (len(word) != 10):
    return False
  if (word[4] != "-"):
    return False
  if (word[7] != "-"):
    return False
  for i in range(len(word)):
    if (not (word[i] == "X" or word[i] == "x" or word[i] == "-" or re.search(
        "[0-9]", word[i]))):
      return False
  return True


class WikiQuestionGenerator(object):
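  """Builds WikiExample objects for the train, dev and test splits from the
  annotated question and table files.
  """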

  def __init__(self, train_name, dev_name, test_name, root_folder):
    self.train_name = train_name
    self.dev_name = dev_name
    self.test_name = test_name
    self.train_loader = WikiQuestionLoader(train_name, root_folder)
    self.dev_loader = WikiQuestionLoader(dev_name, root_folder)
    self.test_loader = WikiQuestionLoader(test_name, root_folder)
    self.bad_examples = 0
    self.root_folder = root_folder
    self.data_folder = os.path.join(self.root_folder, "annotated/data")
    self.annotated_examples = {}
    self.annotated_tables = {}
    self.annotated_word_reject = {"-lrb-": 1, "-rrb-": 1, "UNK": 1}

  def is_money(self, word):
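    """Returns True if word consists only of digits, "." and "E", the shape
    of normalized money and scientific-notation values.
    """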
    if not re.search("[a-z0-9]", word, re.IGNORECASE):
      return False
    for i in range(len(word)):
      if (not (word[i] == "E" or word[i] == "." or re.search("[0-9]",
                                                             word[i]))):
        return False
    return True

  def remove_consecutive(self, ner_tags, ner_values):
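    """Collapses runs of tokens that share one NER value so the value is kept
    only once; earlier tokens in a run are marked "A" (keep the original
    word) or "," (drop) for pre_process_sentence.
    """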
    for i in range(len(ner_tags)):
      if ((ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
           ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE") and
          i + 1 < len(ner_tags) and ner_tags[i] == ner_tags[i + 1] and
          ner_values[i] == ner_values[i + 1] and ner_values[i] != ""):
        word = ner_values[i]
        word = word.replace(">", "").replace("<", "").replace("=", "").replace(
            "%", "").replace("~", "").replace("$", "").replace("£", "").replace(
                "€", "")
        if (re.search("[A-Z]", word) and not (is_date(word)) and not (
            self.is_money(word))):
          ner_values[i] = "A"
        else:
          ner_values[i] = ","
    return ner_tags, ner_values

  def pre_process_sentence(self, tokens, ner_tags, ner_values):
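    """Converts "|"-separated tokens into a normalized word list, substituting
    NER values for number, money, percent and date tokens and converting
    numeric strings to floats.
    """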
    sentence = []
    tokens = tokens.split("|")
    ner_tags = ner_tags.split("|")
    ner_values = ner_values.split("|")
    ner_tags, ner_values = self.remove_consecutive(ner_tags, ner_values)
    #print "old: ", tokens
    for i in range(len(tokens)):
      word = tokens[i]
      if (ner_values[i] != "" and
          (ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
           ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE")):
        word = ner_values[i]
        word = word.replace(">", "").replace("<", "").replace("=", "").replace(
            "%", "").replace("~", "").replace("$", "").replace("£", "").replace(
                "€", "")
        if (re.search("[A-Z]", word) and not (is_date(word)) and not (
            self.is_money(word))):
          word = tokens[i]
        if (is_number(ner_values[i])):
          word = float(ner_values[i])
        elif (is_number(word)):
          word = float(word)
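        # Special case: keep the literal token "score" instead of its NER value.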
        if (tokens[i] == "score"):
          word = "score"
      if (is_number(word)):
        word = float(word)
      if word not in self.annotated_word_reject:
        if (is_number(word) or is_date(word) or self.is_money(word)):
          sentence.append(word)
        else:
          word = full_normalize(word)
          if (word not in self.annotated_word_reject and
              re.search("[a-z0-9]", word, re.IGNORECASE)):
            sentence.append(word.replace(",", ""))
    if (len(sentence) == 0):
      sentence.append("UNK")
    return sentence

  def load_annotated_data(self, in_file):
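    """Parses a tab-separated .annotated questions file into WikiExample
    objects, keyed by question id, and records each referenced table.
    """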
    self.annotated_examples = {}
    self.annotated_tables = {}
    f = tf.gfile.GFile(in_file, "r")
    counter = 0
    for line in f:
      if (counter > 0):
        line = line.strip()
        (question_id, utterance, context, target_value, tokens, lemma_tokens,
         pos_tags, ner_tags, ner_values, target_canon) = line.split("\t")
        question = self.pre_process_sentence(tokens, ner_tags, ner_values)
        target_canon = target_canon.split("|")
        self.annotated_examples[question_id] = WikiExample(
            question_id, question, target_canon, context)
        self.annotated_tables[context] = []
      counter += 1
    print "Annotated examples loaded ", len(self.annotated_examples)
    f.close()

  def is_number_column(self, a):
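    """Returns True if every entry of column a is a single number token."""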
    for w in a:
      if (len(w) != 1):
        return False
      if (not (is_number(w[0]))):
        return False
    return True

  def convert_table(self, table):
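    """Joins each tokenized cell of table back into a space-separated string."""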
    answer = []
    for i in range(len(table)):
      temp = []
      for j in range(len(table[i])):
        temp.append(" ".join([str(w) for w in table[i][j]]))
      answer.append(temp)
    return answer

  def load_annotated_tables(self):
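    """Reads the annotated file of every referenced table into a TableInfo.

    Each file is scanned twice: first to find the table dimensions, then to
    fill in the columns. Short lines are padded with tabs to 13 fields.
    """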
    for table in self.annotated_tables.keys():
      annotated_table = table.replace("csv", "annotated")
      orig_columns = []
      processed_columns = []
      f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
      counter = 0
      for line in f:
        if (counter > 0):
          line = line.strip()
          line = line + "\t" * (13 - len(line.split("\t")))
          (row, col, read_id, content, tokens, lemma_tokens, pos_tags, ner_tags,
           ner_values, number, date, num2, read_list) = line.split("\t")
        counter += 1
      f.close()
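      # row and col retain their values from the file's last data line, which
      # is assumed to carry the maximum row and column indices.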
      max_row = int(row)
      max_col = int(col)
      for i in range(max_col + 1):
        orig_columns.append([])
        processed_columns.append([])
        for j in range(max_row + 1):
          orig_columns[i].append(bad_number)
          processed_columns[i].append(bad_number)
      f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
      counter = 0
      column_names = []
      for line in f:
        if (counter > 0):
          line = line.strip()
          line = line + "\t" * (13 - len(line.split("\t")))
          (row, col, read_id, content, tokens, lemma_tokens, pos_tags, ner_tags,
           ner_values, number, date, num2, read_list) = line.split("\t")
          entry = self.pre_process_sentence(tokens, ner_tags, ner_values)
          if (row == "-1"):
            column_names.append(entry)
          else:
            orig_columns[int(col)][int(row)] = entry
            if (len(entry) == 1 and is_number(entry[0])):
              processed_columns[int(col)][int(row)] = float(entry[0])
            else:
              for single_entry in entry:
                if (is_number(single_entry)):
                  processed_columns[int(col)][int(row)] = float(single_entry)
                  break
              nt = ner_tags.split("|")
              nv = ner_values.split("|")
              for i_entry in range(len(tokens.split("|"))):
                if (nt[i_entry] == "DATE" and
                    is_number(nv[i_entry].replace("-", "").replace("X", ""))):
                  processed_columns[int(col)][int(row)] = float(nv[
                      i_entry].replace("-", "").replace("X", ""))
            # Lowercase the "X" masks in date entries.
            if (len(entry) == 1 and not is_number(entry[0]) and
                is_date(entry[0])):
              entry[0] = entry[0].replace("X", "x")
        counter += 1
      word_columns = []
      processed_word_columns = []
      word_column_names = []
      word_column_indices = []
      number_columns = []
      processed_number_columns = []
      number_column_names = []
      number_column_indices = []
      for i in range(max_col + 1):
        if (self.is_number_column(orig_columns[i])):
          number_column_indices.append(i)
          number_column_names.append(column_names[i])
          temp = []
          for w in orig_columns[i]:
            if (is_number(w[0])):
              temp.append(w[0])
          number_columns.append(temp)
          processed_number_columns.append(processed_columns[i])
        else:
          word_column_indices.append(i)
          word_column_names.append(column_names[i])
          word_columns.append(orig_columns[i])
          processed_word_columns.append(processed_columns[i])
      table_info = TableInfo(
          word_columns, word_column_names, word_column_indices, number_columns,
          number_column_names, number_column_indices, processed_word_columns,
          processed_number_columns, orig_columns)
      self.annotated_tables[table] = table_info
      f.close()

  def answer_classification(self):
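    """Classifies every example as a word lookup, number lookup, number
    calculation or bad example, and builds its lookup matrices.
    """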
    lookup_questions = 0
    number_lookup_questions = 0
    word_lookup_questions = 0
    ambiguous_lookup_questions = 0
    number_questions = 0
    bad_questions = 0
    ice_bad_questions = 0
    tot = 0
    got = 0
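    # ice maps a question id to the rows of arvind-with-norms-2.tsv that
    # describe where its answers match cells of the table.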
    ice = {}
    with tf.gfile.GFile(
        os.path.join(self.root_folder, "arvind-with-norms-2.tsv"),
        mode="r") as f:
      lines = f.readlines()
      for line in lines:
        line = line.strip()
        if line.split("\t")[0] not in self.annotated_examples:
          continue
        if (len(line.split("\t")) == 4):
          line = line + "\t" * (5 - len(line.split("\t")))
          if (not (is_number(line.split("\t")[2]))):
            ice_bad_questions += 1
        (example_id, ans_index, ans_raw, process_answer,
         matched_cells) = line.split("\t")
        if example_id in ice:
          ice[example_id].append(line.split("\t"))
        else:
          ice[example_id] = [line.split("\t")]
    for q_id in self.annotated_examples.keys():
      tot += 1
      example = self.annotated_examples[q_id]
      table_info = self.annotated_tables[example.table_key]
      # Figure out if the answer is numerical or lookup
      n_cols = len(table_info.orig_columns)
      n_rows = len(table_info.orig_columns[0])
      example.lookup_matrix = np.zeros((n_rows, n_cols))
      exact_matches = {}
      for (example_id, ans_index, ans_raw, process_answer,
           matched_cells) in ice[q_id]:
        for match_cell in matched_cells.split("|"):
          if (len(match_cell.split(",")) == 2):
            (row, col) = match_cell.split(",")
            row = int(row)
            col = int(col)
            if (row >= 0):
              exact_matches[ans_index] = 1
      answer_is_in_table = len(exact_matches) == len(example.answer)
      if (answer_is_in_table):
        for (example_id, ans_index, ans_raw, process_answer,
             matched_cells) in ice[q_id]:
          for match_cell in matched_cells.split("|"):
            if (len(match_cell.split(",")) == 2):
              (row, col) = match_cell.split(",")
              row = int(row)
              col = int(col)
              example.lookup_matrix[row, col] = float(ans_index) + 1.0
      example.lookup_number_answer = 0.0
      if (answer_is_in_table):
        lookup_questions += 1
        if len(example.answer) == 1 and is_number(example.answer[0]):
          example.number_answer = float(example.answer[0])
          number_lookup_questions += 1
          example.is_number_lookup = True
        else:
          #print "word lookup"
          example.calc_answer = example.number_answer = 0.0
          word_lookup_questions += 1
          example.is_word_lookup = True
      else:
        if (len(example.answer) == 1 and is_number(example.answer[0])):
          example.number_answer = float(example.answer[0])
          example.is_number_calc = True
        else:
          bad_questions += 1
          example.is_bad_example = True
          example.is_unknown_answer = True
      example.is_lookup = example.is_word_lookup or example.is_number_lookup
      if not example.is_word_lookup and not example.is_bad_example:
        number_questions += 1
        example.calc_answer = float(example.answer[0])
        example.lookup_number_answer = example.calc_answer
      # Split up the lookup matrix into word part and number part
      number_column_indices = table_info.number_column_indices
      word_column_indices = table_info.word_column_indices
      example.word_columns = table_info.word_columns
      example.number_columns = table_info.number_columns
      example.word_column_names = table_info.word_column_names
      example.processed_number_columns = table_info.processed_number_columns
      example.processed_word_columns = table_info.processed_word_columns
      example.number_column_names = table_info.number_column_names
      example.number_lookup_matrix = example.lookup_matrix[:,
                                                           number_column_indices]
      example.word_lookup_matrix = example.lookup_matrix[:, word_column_indices]

  def load(self):
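    """Loads all splits and returns (train, dev, test) lists of WikiExample."""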
    train_data = []
    dev_data = []
    test_data = []
    self.load_annotated_data(
        os.path.join(self.data_folder, "training.annotated"))
    self.load_annotated_tables()
    self.answer_classification()
    self.train_loader.load()
    self.dev_loader.load()
    for i in range(self.train_loader.num_questions()):
      example = self.train_loader.examples[i]
      example = self.annotated_examples[example]
      train_data.append(example)
    for i in range(self.dev_loader.num_questions()):
      example = self.dev_loader.examples[i]
      dev_data.append(self.annotated_examples[example])

    self.load_annotated_data(
        os.path.join(self.data_folder, "pristine-unseen-tables.annotated"))
    self.load_annotated_tables()
    self.answer_classification()
    self.test_loader.load()
    for i in range(self.test_loader.num_questions()):
      example = self.test_loader.examples[i]
      test_data.append(self.annotated_examples[example])
    return train_data, dev_data, test_data