import os
import time
import tensorflow as tf
from batcher import Batcher
import beam_search
import data
import cPickle as pk
import json
import pyrouge
import util
import logging
import numpy as np
import pdb

FLAGS = tf.app.flags.FLAGS

SECS_UNTIL_NEW_CKPT = 60  # max number of seconds before loading new checkpoint

class End2EndEvaluator(object):
  """Evaluate selector and rewriter"""

  def __init__(self, model, batcher, vocab):
    """Initialize decoder.

    Args:
      model: a Seq2SeqAttentionModel object.
      batcher: a Batcher object.
      vocab: Vocabulary object
    """
    self._model = model
    self._model.build_graph()
    self._batcher = batcher
    self._vocab = vocab
    self._saver = tf.train.Saver(max_to_keep=3) # we use this to load checkpoints for decoding
    self._sess = tf.Session(config=util.get_config())
    if FLAGS.mode == 'evalall':
      self.prepare_evaluate()

  def prepare_evaluate(self, ckpt_path=None):
    # Load an initial checkpoint to use for decoding
    if FLAGS.mode == 'evalall':
      if FLAGS.load_best_eval_model:
        tf.logging.info('Loading best eval checkpoint')
        ckpt_path = util.load_ckpt(self._saver, self._sess, ckpt_dir='eval'+FLAGS.eval_method)
      elif FLAGS.eval_ckpt_path:
        ckpt_path = util.load_ckpt(self._saver, self._sess, ckpt_path=FLAGS.eval_ckpt_path)
      else:
        tf.logging.info('Loading best train checkpoint')
        ckpt_path = util.load_ckpt(self._saver, self._sess)
    elif FLAGS.mode == 'eval':
      _ = util.load_ckpt(self._saver, self._sess, ckpt_path=ckpt_path) # load a new checkpoint

    if FLAGS.single_pass:
      # Make a descriptive decode directory name
      ckpt_name = "ckpt-" + ckpt_path.split('-')[-1] # this is something of the form "ckpt-123456"
      self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name))
      tf.logging.info('Save evaluation results to '+ self._decode_dir)
      if os.path.exists(self._decode_dir):
        if FLAGS.mode == 'eval':
          return False  # The checkpoint has already been evaluated. Evaluate next one.
        else:
          raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir)
    else: # Generic decode dir name
      self._decode_dir = os.path.join(FLAGS.log_root, "decode")

    # Make the decode dir if necessary
    if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)

    if FLAGS.single_pass:
      # Make the dirs to contain output written in the correct format for pyrouge
      self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
      if not os.path.exists(self._rouge_ref_dir): os.mkdir(self._rouge_ref_dir)
      self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
      if not os.path.exists(self._rouge_dec_dir): os.mkdir(self._rouge_dec_dir)
      if FLAGS.save_vis:
        self._rouge_vis_dir = os.path.join(self._decode_dir, "visualize")
        if not os.path.exists(self._rouge_vis_dir): os.mkdir(self._rouge_vis_dir)
      if FLAGS.save_pkl:
        self._result_dir = os.path.join(self._decode_dir, "result")
        if not os.path.exists(self._result_dir): os.mkdir(self._result_dir)
    return True

  def evaluate(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
    t0 = time.time()
    counter = 0

    while True:
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_results, rouge_results_str = rouge_log(rouge_results_dict, self._decode_dir)
        t1 = time.time()
        tf.logging.info("evaluation time: %.3f min", (t1-t0)/60.0)
        return rouge_results, rouge_results_str

      if FLAGS.decode_method == 'greedy':
        output_ids = self._model.run_greedy_search(self._sess, batch)
        for i in range(FLAGS.batch_size):
          self.process_one_article(batch.original_articles_sents[i], batch.original_abstracts_sents[i], \
                                   batch.original_extracts_ids[i], output_ids[i], \
                                   batch.art_oovs[i], None, None, None, None, None, counter)
          counter += 1
      elif FLAGS.decode_method == 'beam':
        # Get sentence probabilities from selector
        selector_output = self._model._selector.run_eval_step(self._sess, batch, probs_only=True)
        sent_probs = selector_output['probs'][0].tolist()
        # Run beam search to get best Hypothesis
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_hyp.tokens[1:]]    # remove start token
        best_hyp.log_probs = best_hyp.log_probs[1:]   # remove start token probability
        self.process_one_article(batch.original_articles_sents[0], batch.original_abstracts_sents[0], \
                                 batch.original_extracts_ids[0], output_ids, batch.art_oovs[0], \
                                 best_hyp.attn_dists_norescale, best_hyp.attn_dists, \
                                 best_hyp.p_gens, best_hyp.log_probs, sent_probs, counter)
        counter += 1

 
  def process_one_article(self, original_article_sents, original_abstract_sents, \
                          original_selected_ids, output_ids, oovs, attn_dists_norescale, \
                          attn_dists, p_gens, log_probs, sent_probs, counter):
    # Remove the [STOP] token from decoded_words, if necessary
    decoded_words = data.outputids2words(output_ids, self._vocab, oovs)
    try:
      fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
      decoded_words = decoded_words[:fst_stop_idx]
    except ValueError:
      decoded_words = decoded_words
    decoded_output = ' '.join(decoded_words) # single string
    decoded_sents = data.words2sents(decoded_words)

    if FLAGS.single_pass:
      verbose = False if FLAGS.mode == 'eval' else True
      self.write_for_rouge(original_abstract_sents, decoded_sents, counter, verbose) # write ref summary and decoded summary to file, to eval with pyrouge later
      if FLAGS.decode_method == 'beam' and FLAGS.save_vis:
        sent_probs_per_word = []
        for sent_id, sent in enumerate(original_article_sents):
          sent_len = len(sent.split(' '))
          for _ in range(sent_len):
            if sent_id < FLAGS.max_art_len:
              sent_probs_per_word.append(sent_probs[sent_id])
            else:
              sent_probs_per_word.append(0)
        original_article = ' '.join(original_article_sents)
        original_abstract = ' '.join(original_abstract_sents)
        article_withunks = data.show_art_oovs(original_article, self._vocab) # string
        abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, oovs)
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, attn_dists_norescale, \
                               attn_dists, p_gens, log_probs, sent_probs_per_word, counter, verbose)
      if FLAGS.save_pkl:
        self.save_result(original_article_sents, original_abstract_sents, \
                         original_selected_ids, decoded_sents, counter, verbose)

  def save_result(self, article_sents, reference_sents, gt_ids, decoded_sents, index, verbose=False):
    """save the result in pickle format"""
    data = {'article': article_sents, 
            'reference': reference_sents, 
            'gt_ids': gt_ids, 
            'decoded': decoded_sents}
    output_fname = os.path.join(self._result_dir, 'result_%06d.pkl' % index)
    with open(output_fname, 'wb') as output_file:
      pk.dump(data, output_file)
    if verbose:
      tf.logging.info('Wrote result data to %s', output_fname)

  def write_for_rouge(self, reference_sents, decoded_sents, ex_index, verbose=False):
    """Write output to file in correct format for eval with pyrouge. This is called in single_pass mode.

    Args:
      reference_sents: list of strings
      decoded_words: list of strings
      ex_index: int, the index with which to label the files
    """
    # pyrouge calls a perl script that puts the data into HTML files.
    # Therefore we need to make our output HTML safe.
    decoded_sents = [make_html_safe(w) for w in decoded_sents]
    reference_sents = [make_html_safe(w) for w in reference_sents]

    # Write to file
    ref_file = os.path.join(self._rouge_ref_dir, "%06d_reference.txt" % ex_index)
    decoded_file = os.path.join(self._rouge_dec_dir, "%06d_decoded.txt" % ex_index)

    with open(ref_file, "w") as f:
      for idx,sent in enumerate(reference_sents):
        f.write(sent) if idx==len(reference_sents)-1 else f.write(sent+"\n")
    with open(decoded_file, "w") as f:
      for idx,sent in enumerate(decoded_sents):
        f.write(sent) if idx==len(decoded_sents)-1 else f.write(sent+"\n")

    if verbose:
      tf.logging.info("Wrote example %i to file" % ex_index)


  def write_for_attnvis(self, article, abstract, decoded_words, attn_dists_norescale, attn_dists, p_gens, log_probs, \
                        sent_probs, count=None, verbose=False):
    """Write some data to json file, which can be read into the in-browser attention visualizer tool:
      https://github.com/abisee/attn_vis

    Args:
      article: The original article string.
      abstract: The human (correct) abstract string.
      attn_dists: List of arrays; the attention distributions.
      decoded_words: List of strings; the words of the generated summary.
      p_gens: List of scalars; the p_gen values. If not running in pointer-generator mode, list of None.
    """
    article_lst = article.split() # list of words
    decoded_lst = decoded_words # list of decoded words
    to_write = {
        'article_lst': [make_html_safe(t) for t in article_lst],
        'decoded_lst': [make_html_safe(t) for t in decoded_lst],
        'abstract_str': make_html_safe(abstract),
        'attn_dists_norescale': attn_dists_norescale,
        'attn_dists': attn_dists,
        'probs': np.exp(log_probs).tolist(),
        'sent_probs': sent_probs
    }
    to_write['p_gens'] = p_gens
    if count != None:
      output_fname = os.path.join(self._rouge_vis_dir, 'attn_vis_data_%06d.json' % count)
    else:
      output_fname = os.path.join(self._decode_dir, 'attn_vis_data.json')
    with open(output_fname, 'w') as output_file:
      json.dump(to_write, output_file)
    if verbose:
      tf.logging.info('Wrote visualization data to %s', output_fname)

  def init_batcher(self):
    self._batcher = Batcher(FLAGS.data_path, self._vocab, self._model._hps, single_pass=FLAGS.single_pass)


def make_html_safe(s):
  """Replace any angled brackets in string s to avoid interfering with HTML attention visualizer."""
  s.replace("<", "&lt;")
  s.replace(">", "&gt;")
  return s


def rouge_eval(ref_dir, dec_dir):
  """Evaluate the files in ref_dir and dec_dir with pyrouge, returning results_dict"""
  r = pyrouge.Rouge155()
  r.model_filename_pattern = '#ID#_reference.txt'
  r.system_filename_pattern = '(\d+)_decoded.txt'
  r.model_dir = ref_dir
  r.system_dir = dec_dir
  logging.getLogger('global').setLevel(logging.WARNING) # silence pyrouge logging
  rouge_results = r.convert_and_evaluate()
  return r.output_to_dict(rouge_results)


def rouge_log(results_dict, dir_to_write):
  """Log ROUGE results to screen and write to file.

  Args:
    results_dict: the dictionary returned by pyrouge
    dir_to_write: the directory where we will write the results to"""
  rouge_results = {}
  log_str = ""
  for x in ["1","2","l"]:
    log_str += "\nROUGE-%s:\n" % x
    for y in ["f_score", "recall", "precision"]:
      key = "rouge_%s_%s" % (x,y)
      key_cb = key + "_cb"
      key_ce = key + "_ce"
      val = results_dict[key]
      val_cb = results_dict[key_cb]
      val_ce = results_dict[key_ce]
      if y == 'f_score':
        rouge_results[x] = val
      log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
  tf.logging.info(log_str) # log to screen
  results_file = os.path.join(dir_to_write, "ROUGE_results.txt")
  tf.logging.info("Writing final ROUGE results to %s...", results_file)
  with open(results_file, "w") as f:
    f.write(log_str)
  return rouge_results, log_str

def get_decode_dir_name(ckpt_name):
  """Make a descriptive name for the decode dir, including the name of the checkpoint we use to decode. This is called in single_pass mode."""

  if "train" in FLAGS.data_path: dataset = "train"
  elif "val" in FLAGS.data_path: dataset = "val"
  elif "test" in FLAGS.data_path: dataset = "test"
  else: raise ValueError("FLAGS.data_path %s should contain one of train, val or test" % (FLAGS.data_path))
  dirname = "decode_%s_%imaxenc_%ibeam_%imindec_%imaxdec" % (dataset, FLAGS.max_enc_steps, FLAGS.beam_size, FLAGS.min_dec_steps, FLAGS.max_dec_steps)
  if ckpt_name is not None:
    dirname += "_%s_%s" % (ckpt_name, FLAGS.decode_method)
  return dirname