from pprint import pprint, pformat

import os
import shutil
import pickle

import logging
logging.basicConfig(format="%(levelname)-8s:%(filename)s.%(funcName)20s >>   %(message)s")
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

import torch
from torch import nn
from torch.autograd import Variable

from collections import namedtuple, defaultdict

"""
    Local Utilities, Helper Functions

"""
def mkdir_if_exist_not(name):
    if not os.path.isdir(name):
        os.mkdir(name)
    
def initialize_task(hpconfig='hpconfig.py', prefix='run00'):
    log.info('loading hyperparameters from {}'.format(hpconfig))
    root_dir = hpconfig.replace('.py', '') + '__' + hash_file(hpconfig)[-6:]
    mkdir_if_exist_not(prefix)
    root_dir = '{}/{}'.format(prefix, root_dir)
    mkdir_if_exist_not(root_dir)
    mkdir_if_exist_not('{}/results'.format(root_dir))
    mkdir_if_exist_not('{}/results/metrics'.format(root_dir))
    mkdir_if_exist_not('{}/weights'.format(root_dir))
    mkdir_if_exist_not('{}/plots'.format(root_dir))

    shutil.copy(hpconfig, root_dir)
    shutil.copy('config.py', root_dir)

    return root_dir
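
# A minimal usage sketch (assuming 'hpconfig.py' and 'config.py' exist in the
# working directory); the hyperparameter file is hashed so every distinct
# configuration gets its own run directory:
#
#   root_dir = initialize_task('hpconfig.py', prefix='run00')
#   # -> 'run00/hpconfig__<6-hex-char hash>' containing results/metrics/,
#   #    weights/ and plots/, plus copies of hpconfig.py and config.py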

"""
Logging utils
"""
def logger(func, dlevel=logging.INFO):
    def wrapper(*args, **kwargs):
        level = log.getEffectiveLevel()
        log.setLevel(dlevel)                # run the call at dlevel
        try:
            return func(*args, **kwargs)
        finally:
            log.setLevel(level)             # restore the previous level

    return wrapper
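
# A minimal usage sketch ('train_step' is hypothetical, for illustration):
# the wrapped call runs with the log level set to dlevel and the previous
# level is restored when it returns.
#
#   @logger
#   def train_step(batch):
#       ...
#
#   # or pick the level explicitly:
#   # train_step = logger(train_step_fn, dlevel=logging.DEBUG)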


from tqdm import tqdm as _tqdm

def tqdm(a, *args, **kwargs):
    return _tqdm(a, *args, ncols=100, **kwargs) # if config.CONFIG.tqdm else a


def squeeze(lol):
    """
    Flatten a list of lists into a single list.

    Args:
        lol : list of lists

    Returns:
        list

    """
    return [i for l in lol for i in l]
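
# e.g. squeeze([[1, 2], [3], []]) == [1, 2, 3]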

"""
    util functions to enable pretty print on namedtuple

"""
def _namedtuple_repr_(self):
    return pformat(self.___asdict())

def ___asdict(self):
    d = self._asdict()
    for k, v in d.items():
        if hasattr(v, '_asdict'):
            d[k] = ___asdict(v)

    return dict(d)
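
# A minimal usage sketch: both helpers are meant to be attached to a
# namedtuple class so that nested configs print as plain dicts. 'HpConfig'
# is a hypothetical namedtuple used only for illustration.
#
#   HpConfig = namedtuple('HpConfig', ['lr', 'optim'])
#   HpConfig.___asdict = ___asdict
#   HpConfig.__repr__  = _namedtuple_repr_
#   print(HpConfig(lr=0.001, optim='adam'))  # -> {'lr': 0.001, 'optim': 'adam'}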


"""
# Batching utils   
"""
import numpy as np
def seq_maxlen(seqs):
    return max(len(seq) for seq in seqs)

PAD = 0
def pad_seq(seqs, maxlen=0, PAD=PAD):
    def pad_seq_(seq):
        return seq[:maxlen] + [PAD]*(maxlen-len(seq))

    if len(seqs) == 0:
        return seqs

    if isinstance(seqs[0], list):   # a batch of sequences
        maxlen = maxlen if maxlen else seq_maxlen(seqs)
        seqs = [pad_seq_(seq) for seq in seqs]
    else:                           # a single sequence
        maxlen = maxlen if maxlen else len(seqs)
        seqs = pad_seq_(seqs)

    return seqs
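
# e.g. pad_seq([[1, 2], [1, 2, 3]])           == [[1, 2, 0], [1, 2, 3]]
#      pad_seq([[1, 2], [1, 2, 3]], maxlen=2) == [[1, 2], [1, 2]]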


class ListTable(list):
    """ Overridden list class which takes a 2-dimensional list of 
    the form [[1,2,3],[4,5,6]], and renders an HTML Table in 
    IPython Notebook. 
    Taken from http://calebmadrigal.com/display-list-as-table-in-ipython-notebook/"""
    
    def _repr_html_(self):
        html = ["<table>"]
        for row in self:
            html.append("<tr>")
            
            for col in row:
                html.append("<td>{0}</td>".format(col))
            
            html.append("</tr>")
        html.append("</table>")
        return ''.join(html)

    def __repr__(self):
        lines = []
        for row in self:
            lines.append('|'.join(str(col) for col in row))
        log.debug('number of lines: {}'.format(len(lines)))
        return '\n'.join(lines + ['\n'])
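
# A minimal usage sketch: as the last expression of a notebook cell the table
# renders as HTML; elsewhere it falls back to the pipe-separated __repr__.
#
#   t = ListTable([[1, 2, 3], [4, 5, 6]])
#   t   # notebook: HTML table; terminal: "1|2|3\n4|5|6"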

"""
torch utils
"""
def are_weights_same(model1, model2):
    m1dict = model1.state_dict()
    m2dict = model2.state_dict()

    if m1dict.keys() != m2dict.keys():
        log.error('model state_dict keys don\'t match')
        log.error(pformat(m1dict.keys()))
        log.error(pformat(m2dict.keys()))
        return False

    for p in m1dict.keys():
        ne = m1dict[p].data.ne(m2dict[p].data)  # elementwise not-equal mask
        if ne.sum() > 0:
            print('===== {} ===='.format(p))
            print(ne.cpu().numpy())
            print('sum = ', ne.sum().cpu().numpy())

            return False

    return True
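
# e.g. to verify a reloaded checkpoint (hypothetical path, for illustration):
#   model2.load_state_dict(torch.load('run00/weights/best.pth'))
#   assert are_weights_same(model1, model2)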

def LongVar(config, array, requires_grad=False):
    return Var(config, array, requires_grad).long()

def Var(config, array, requires_grad=False):
    ret =  Variable(torch.Tensor(array), requires_grad=requires_grad)
    if config.CONFIG.cuda:
        ret = ret.cuda()

    return ret
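
# e.g. a padded batch of token ids becomes a cuda-aware LongTensor Variable
# (assuming a config object with CONFIG.cuda and 'batch_ids', a list of int
# lists, both hypothetical):
#   ids = LongVar(config, pad_seq(batch_ids))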

def init_hidden(config, batch_size, cell):
    layers = 1
    if isinstance(cell, (nn.LSTM, nn.GRU)):
        layers = cell.num_layers
        if cell.bidirectional:
            layers = layers * 2

    if isinstance(cell, (nn.LSTM, nn.LSTMCell)):
        hidden  = Variable(torch.zeros(layers, batch_size, cell.hidden_size))
        context = Variable(torch.zeros(layers, batch_size, cell.hidden_size))
    
        if config.CONFIG.cuda:
            hidden  = hidden.cuda()
            context = context.cuda()
        return hidden, context

    if isinstance(cell, (nn.GRU, nn.GRUCell)):
        hidden  = Variable(torch.zeros(layers, batch_size, cell.hidden_size))
        if config.CONFIG.cuda:
            hidden  = hidden.cuda()
        return hidden
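
# A minimal usage sketch (assuming a config object with CONFIG.cuda and an
# 'inputs' sequence tensor, both hypothetical):
#
#   rnn = nn.LSTM(input_size=300, hidden_size=128, num_layers=2)
#   h0, c0 = init_hidden(config, batch_size=32, cell=rnn)   # each 2 x 32 x 128
#   output, (hn, cn) = rnn(inputs, (h0, c0))
#
# Note that nn.LSTMCell/nn.GRUCell expect states of shape (batch, hidden_size),
# so the leading layer dim must be squeezed before feeding a *Cell.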

class FLAGS:
    CONTINUE_TRAINING = 0
    STOP_TRAINING = 1
    
class Averager(list):
    def __init__(self, config, filename=None, ylim=None, *args, **kwargs):
        super(Averager, self).__init__(*args, **kwargs)
        self.config = config
        self.filename = filename
        self.ylim = ylim
        if filename:
            try:
                f = '{}.pkl'.format(filename)
                if os.path.isfile(f):
                    log.debug('loading {}'.format(f))
                    with open(f, 'rb') as pkl:
                        self.extend(pickle.load(pkl))
            except Exception:
                log.exception('could not load {}, starting empty'.format(filename))
                open(filename, 'w').close()

    @property
    def avg(self):
        if len(self):
            return sum(self)/len(self)
        else:
            return 0

    def __str__(self):
        if len(self) > 0:
            # min/max/avg/latest
            return '{:0.4f}/{:0.4f}/{:0.4f}/{:0.4f}'.format(min(self), max(self), self.avg, self[-1])

        return '<empty>'

    def empty(self):
        del self[:]

    def write_to_file(self):
        
        if self.filename:
            if self.config.CONFIG.plot_metrics:
                import matplotlib.pyplot as plt
                plt.plot(self)
                plt.title(os.path.basename(self.filename), fontsize=20)
                plt.xlabel('epoch')
                if self.ylim:
                    plt.ylim(*self.ylim)

                plt.savefig('{}.{}'.format(self.filename, 'png'))
                plt.close()

            pickle.dump(list(self), open('{}.pkl'.format(self.filename), 'wb'))
            with open(self.filename, 'a') as f:
                f.write(self.__str__() + '\n')
                f.flush()

    

class EpochAverager(Averager):
    def __init__(self, config, filename=None, *args, **kwargs):
        super(EpochAverager, self).__init__(config, filename, *args, **kwargs)
        self.config = config
        self.epoch_cache = Averager(config, filename, *args, **kwargs)

    def cache(self, a):
        self.epoch_cache.append(a)

    def clear_cache(self):
        super(EpochAverager, self).append(self.epoch_cache.avg)
        self.epoch_cache.empty()
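
# A minimal usage sketch (assuming a config with CONFIG.plot_metrics and a
# hypothetical training loop, for illustration):
#
#   loss_avg = EpochAverager(config, filename='run00/results/metrics/loss')
#   for epoch in range(epochs):
#       for batch in batches:
#           loss = train_step(batch)   # hypothetical
#           loss_avg.cache(loss.item())
#       loss_avg.clear_cache()         # fold the epoch's batches into one point
#       loss_avg.write_to_file()       # appends min/max/avg/latest, dumps .pkl (and .png when plotting)
#   print(loss_avg)                    # e.g. '0.1234/0.9876/0.4567/0.1234'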
                

# Python program to find SHA256 hash string of a file
#https://www.quickprogrammingtips.com/python/how-to-calculate-sha256-hash-of-a-file-in-python.html
import hashlib

def hash_file(filename):
    sha256_hash = hashlib.sha256()
    with open(filename,"rb") as f:
        # Read and update hash string value in blocks of 4K
        for byte_block in iter(lambda: f.read(4096),b""):
            sha256_hash.update(byte_block)
            
    return sha256_hash.hexdigest()
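
# e.g. initialize_task uses the last 6 hex chars of this digest to name run
# directories: hash_file('hpconfig.py')[-6:]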


def dump_vocab_tsv(config, vocab, embedding, filepath):
    """Write embeddings and their tokens as parallel TSV files."""
    assert embedding.shape[0] == len(vocab)

    vector_filepath = filepath.replace('.tsv', '.vector.tsv')
    token_filepath  = filepath.replace('.tsv', '.token.tsv')

    with open(vector_filepath, 'w') as vector_file, \
         open(token_filepath,  'w') as token_file:
        for i, vector in enumerate(embedding):
            vector_file.write('\t'.join([str(v) for v in vector]) + '\n')
            token_file.write(vocab[i] + '\n')


def dump_cosine_similarity_tsv(config, vocab, embedding, filepath, count=100):
    assert embedding.shape[0] == len(vocab)

    matrix_filepath     = filepath.replace('.tsv', '.matrix.pkl')
    similar_filepath    = filepath.replace('.tsv', '.similar.tsv')
    dissimilar_filepath = filepath.replace('.tsv', '.dissimilar.tsv')

    # row-normalize so that E.E^T is the cosine-similarity matrix
    e_norm = embedding / embedding.norm(dim=1)[:, None]
    scores = torch.mm(e_norm, e_norm.t())

    with open(matrix_filepath, 'wb') as matrix_file:
        pickle.dump(scores.cpu().numpy(), matrix_file)

    similars    = scores.topk(count, dim=1)[1]         # indices of the most similar tokens
    dissimilars = (1 - scores).topk(count, dim=1)[1]   # indices of the least similar tokens

    with open(similar_filepath, 'w') as similar_file, \
         open(dissimilar_filepath, 'w') as dissimilar_file:
        for i in range(len(vocab)):
            similar_file.write('|'.join(vocab.index2word[j] for j in similars[i]) + '\n')
            dissimilar_file.write('|'.join(vocab.index2word[j] for j in dissimilars[i]) + '\n')

def conv2d_output_size(W, H, F=3, S=1, P=1):
    """Spatial output size of a conv2d: floor((dim - F + 2P)/S) + 1 per
    dimension, for input width W, height H, filter size F, stride S, padding P."""
    W2 = (W - F + 2*P)//S + 1
    H2 = (H - F + 2*P)//S + 1
    return (W2, H2)
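
# e.g. a 28x28 input through a 3x3 filter with stride 1 and padding 1 keeps
# its size: conv2d_output_size(28, 28) == (28, 28); with P=0 it shrinks to
# (26, 26).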