import sys
import os
import time
import pickle
import numpy as np
from utils import list_rules, print_rules


class Experiment(object):
    """
    This class handles all experiment-related activities,
    including training, testing, early stopping, and visualizing
    results, such as getting attentions and getting rules.

    Args:
        sess: a TensorFlow session
        saver: a TensorFlow saver
        option: an Option object that contains hyperparameters
        learner: an inductive learner that can
                 update its parameters and perform inference
        data: a Data object that provides
              num_batch_train/valid/test,
              next_train/valid/test,
              and a parser for getting rules
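
    Example (a minimal sketch; assumes TensorFlow 1.x imported as tf, and
    that Option, Learner, and Data are the companion classes defined
    elsewhere in this repository):

        with tf.Session() as sess:
            saver = tf.train.Saver()
            experiment = Experiment(sess, saver, option, learner, data)
            experiment.train()
            experiment.get_rules()
            experiment.close_log_file()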
    """
    
    def __init__(self, sess, saver, option, learner, data):
        self.sess = sess
        self.saver = saver
        self.option = option
        self.learner = learner
        self.data = data
        self.start = time.time()
        # Helper that appends elapsed wall-clock time to a log message.
        self.msg_with_time = lambda msg: \
                "%s Time elapsed %0.2f hrs (%0.1f mins)" \
                % (msg, (time.time() - self.start) / 3600.,
                        (time.time() - self.start) / 60.)

        self.epoch = 0
        self.best_valid_loss = np.inf
        self.best_valid_in_top = 0.
        self.train_stats = []
        self.valid_stats = []
        self.test_stats = []
        self.early_stopped = False
        self.log_file = open(os.path.join(self.option.this_expsdir, "log.txt"), "w")


    def one_epoch(self, mode, num_batch, next_fn):
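        """Run one pass over the data, updating parameters only in "train" mode.

        Returns the per-example losses and in-top indicators produced by the
        learner for this epoch.
        """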
        epoch_loss = []
        epoch_in_top = []
        for batch in range(num_batch):
            if (batch + 1) % max(1, num_batch // self.option.print_per_batch) == 0:
                sys.stdout.write("%d/%d\t" % (batch+1, num_batch))
                sys.stdout.flush()
            
            (qq, hh, tt), mdb = next_fn()
            run_fn = self.learner.update if mode == "train" else self.learner.predict
            loss, in_top = run_fn(self.sess, qq, hh, tt, mdb)
            epoch_loss += list(loss)
            epoch_in_top += list(in_top)
                                    
        msg = self.msg_with_time(
                "Epoch %d mode %s Loss %0.4f In top %0.4f." 
                % (self.epoch+1, mode, np.mean(epoch_loss), np.mean(epoch_in_top)))
        print(msg)
        self.log_file.write(msg + "\n")
        return epoch_loss, epoch_in_top

    def one_epoch_train(self):
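        # When option.resplit is set, re-split the training data each epoch
        # (see Data.train_resplit) before running the training pass.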
        if self.epoch > 0 and self.option.resplit:
            self.data.train_resplit(self.option.no_link_percent)
        loss, in_top = self.one_epoch("train", 
                                      self.data.num_batch_train, 
                                      self.data.next_train)
        
        self.train_stats.append([loss, in_top])
        
    def one_epoch_valid(self):
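        # Track the best validation metrics seen so far; early_stop()
        # compares the latest epoch against these.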
        loss, in_top = self.one_epoch("valid", 
                                      self.data.num_batch_valid, 
                                      self.data.next_valid)
        self.valid_stats.append([loss, in_top])
        self.best_valid_loss = min(self.best_valid_loss, np.mean(loss))
        self.best_valid_in_top = max(self.best_valid_in_top, np.mean(in_top))

    def one_epoch_test(self):
        loss, in_top = self.one_epoch("test", 
                                      self.data.num_batch_test,
                                      self.data.next_test)
        self.test_stats.append([loss, in_top])
    
    def early_stop(self):
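        # The latest validation epoch counts as an improvement if its mean
        # loss or mean in-top matches the best value recorded so far.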
        loss_improve = self.best_valid_loss == np.mean(self.valid_stats[-1][0])
        in_top_improve = self.best_valid_in_top == np.mean(self.valid_stats[-1][1])
        if loss_improve or in_top_improve:
            return False
        return self.epoch >= self.option.min_epoch

    def train(self):
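        # Alternate train/valid/test epochs, checkpointing after each epoch,
        # until max_epoch is reached or the early-stopping criterion fires.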
        while (self.epoch < self.option.max_epoch and not self.early_stopped):
            self.one_epoch_train()
            self.one_epoch_valid()
            self.one_epoch_test()
            self.epoch += 1
            model_path = self.saver.save(self.sess, 
                                         self.option.model_path,
                                         global_step=self.epoch)
            print("Model saved at %s" % model_path)
            
            if self.early_stop():
                self.early_stopped = True
                print("Early stopped at epoch %d" % (self.epoch))
        
        all_test_in_top = [np.mean(x[1]) for x in self.test_stats]
        best_test_epoch = np.argmax(all_test_in_top)
        best_test = all_test_in_top[best_test_epoch]
        
        msg = "Best test in top: %0.4f at epoch %d." % (best_test, best_test_epoch + 1)       
        print(msg)
        self.log_file.write(msg + "\n")
        with open(os.path.join(self.option.this_expsdir, "results.pckl"), "wb") as f:
            pickle.dump([self.train_stats, self.valid_stats, self.test_stats], f)

    def get_predictions(self):
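        """Write ranked test predictions to disk and log the in-top score."""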
        if self.option.query_is_language:
            all_accu = []
            all_num_preds = []
            all_num_preds_no_mistake = []

        f = open(os.path.join(self.option.this_expsdir, "test_predictions.txt"), "w")
        if self.option.get_phead:
            f_p = open(os.path.join(self.option.this_expsdir, "test_preds_and_probs.txt"), "w")
        all_in_top = []
        for batch in range(self.data.num_batch_test):
            if (batch + 1) % max(1, self.data.num_batch_test // self.option.print_per_batch) == 0:
                sys.stdout.write("%d/%d\t" % (batch+1, self.data.num_batch_test))
                sys.stdout.flush()
            (qq, hh, tt), mdb = self.data.next_test()
            in_top, predictions_this_batch \
                    = self.learner.get_predictions_given_queries(self.sess, qq, hh, tt, mdb)
            all_in_top += list(in_top)

            for i, (q, h, t) in enumerate(zip(qq, hh, tt)):
                p_head = predictions_this_batch[i, h]
                if self.option.adv_rank:
                    # Adversarial ranking: ties with the true head count against it.
                    eval_fn = lambda j, p: (p >= p_head) and (j != h)
                elif self.option.rand_break:
                    # Break ties with the true head uniformly at random.
                    eval_fn = lambda j, p: (p > p_head) or \
                            ((p == p_head) and (j != h) and (np.random.uniform() < 0.5))
                else:
                    eval_fn = lambda j, p: p > p_head
                this_predictions = [(j, p)
                                    for j, p in enumerate(predictions_this_batch[i, :])
                                    if eval_fn(j, p)]
                this_predictions = sorted(this_predictions, key=lambda x: x[1], reverse=True)
                if self.option.query_is_language:
                    all_num_preds.append(len(this_predictions))
                    mistake = False
                    for k, _ in this_predictions:
                        assert k != h
                        if not self.data.is_true(q, k, t):
                            mistake = True
                            break
                    all_accu.append(not mistake)
                    if not mistake:
                        all_num_preds_no_mistake.append(len(this_predictions))
                else:
                    this_predictions.append((h, p_head))
                    this_predictions = [self.data.number_to_entity[j] for j, _ in this_predictions]
                    q_string = self.data.parser["query"][q]
                    h_string = self.data.number_to_entity[h]
                    t_string = self.data.number_to_entity[t]
                    to_write = [q_string, h_string, t_string] + this_predictions
                    f.write(",".join(to_write) + "\n")
                    if self.option.get_phead:
                        f_p.write(",".join(to_write + [str(p_head)]) + "\n")
        f.close()
        if self.option.get_phead:
            f_p.close()
        
        if self.option.query_is_language:
            print("Averaged num of preds", np.mean(all_num_preds))
            print("Averaged num of preds for no mistake", np.mean(all_num_preds_no_mistake))
            msg = "Accuracy %0.4f" % np.mean(all_accu)
            print(msg)
            self.log_file.write(msg + "\n")

        msg = "Test in top %0.4f" % np.mean(all_in_top)
        msg += self.msg_with_time("\nTest predictions written.")
        print(msg)
        self.log_file.write(msg + "\n")
        
    def get_attentions(self):
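        """Collect the operator and memory attentions for every rule query."""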
        # Queries are processed one at a time under type checking; otherwise
        # they are split into batches of at most option.batch_size.
        if not self.option.query_is_language and self.option.type_check:
            query_batches = [[i] for i in self.data.query_for_rules]
        else:
            num_batch = int(np.ceil(1.0 * len(self.data.query_for_rules) / self.option.batch_size))
            query_batches = np.array_split(self.data.query_for_rules, num_batch)

        all_attention_operators = {}
        all_attention_memories = {}

        for queries in query_batches:
            attention_operators, attention_memories \
            = self.learner.get_attentions_given_queries(self.sess, queries)
            
            # Tuple-ize in order to be used as dict keys
            if self.option.query_is_language:
                queries = [tuple(q) for q in queries]

            for i, query in enumerate(queries):
                all_attention_operators[query] = \
                        [[attn[i] for attn in attn_step]
                         for attn_step in attention_operators]
                all_attention_memories[query] = \
                        [attn_step[i, :]
                         for attn_step in attention_memories]
        with open(os.path.join(self.option.this_expsdir, "attentions.pckl"), "wb") as f:
            pickle.dump([all_attention_operators, all_attention_memories], f)
               
        msg = self.msg_with_time("Attentions collected.")
        print(msg)
        self.log_file.write(msg + "\n")

        all_queries = [q for batch in query_batches for q in batch]
        return all_attention_operators, all_attention_memories, all_queries

    def get_rules(self):
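        """Extract rules from the learned attentions and write them to disk."""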
        all_attention_operators, all_attention_memories, queries = self.get_attentions()

        all_listed_rules = {}
        all_printed_rules = []
        for i, q in enumerate(queries):
            if not self.option.query_is_language:
                if (i + 1) % max(1, len(queries) // 5) == 0:
                    sys.stdout.write("%d/%d\t" % (i + 1, len(queries)))
                    sys.stdout.flush()
            else:
                # Tuple-ize so the query can be used as a dict key
                q = tuple(q)
            all_listed_rules[q] = list_rules(all_attention_operators[q],
                                             all_attention_memories[q],
                                             self.option.rule_thr)
            all_printed_rules += print_rules(q, 
                                             all_listed_rules[q], 
                                             self.data.parser,
                                             self.option.query_is_language)

        with open(os.path.join(self.option.this_expsdir, "rules.pckl"), "wb") as f:
            pickle.dump(all_listed_rules, f)
        with open(os.path.join(self.option.this_expsdir, "rules.txt"), "w") as f:
            for line in all_printed_rules:
                f.write(line + "\n")
        msg = self.msg_with_time("\nRules listed and printed.")
        print(msg)
        self.log_file.write(msg + "\n")

    def get_vocab_embedding(self):
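        """Fetch the learner's vocabulary embedding and pickle it with labels."""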
        vocab_embedding = self.learner.get_vocab_embedding(self.sess)
        msg = self.msg_with_time("Vocabulary embedding retrieved.")
        print(msg)
        self.log_file.write(msg + "\n")
        
        vocab_embed_file = os.path.join(self.option.this_expsdir, "vocab_embed.pckl")
        with open(vocab_embed_file, "wb") as f:
            pickle.dump({"embedding": vocab_embedding,
                         "labels": self.data.query_vocab_to_number}, f)
        msg = self.msg_with_time("Vocabulary embedding stored.")
        print(msg)
        self.log_file.write(msg + "\n")

    def close_log_file(self):
        self.log_file.close()