import os
import sys
import random
import traceback
from tensorflow.keras.optimizers import RMSprop, Adam
from scipy.stats import rankdata
import math
import numpy as np
from tqdm import tqdm
import argparse
# Seed Python's ``random`` module so the negative-description shuffle done
# during training is reproducible (NumPy / TensorFlow are not seeded here).
random.seed(42)
import threading 
import configs
import logging
# Module logger; the root logging configuration below is applied at import
# time (INFO level, timestamped "time: name: level: message" lines).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

from utils import normalize, pad, convert, revert
import models, configs, data_loader

class SearchEngine:
    """Train, evaluate and query a joint code/description embedding model.

    Data comes from ``data_loader``, padding/vector helpers from ``utils`` and
    the network from the ``model`` object passed to each method.
    """

    def __init__(self, args, conf=None):
        """
        :param args: parsed CLI namespace; ``args.data_path + args.dataset``
            is the dataset directory.
        :param conf: configuration dict with optional ``training_params``,
            ``data_params`` and ``model_params`` sections. ``None`` (the
            default) is treated as an empty configuration.
        """
        if conf is None:
            conf = {}
        self.data_path = args.data_path + args.dataset + '/'
        self.train_params = conf.get('training_params', dict())
        self.data_params = conf.get('data_params', dict())
        self.model_params = conf.get('model_params', dict())

        # Validation pool, loaded lazily by ``valid``.
        self._eval_sets = None

        # Search state, set by the caller before ``search``: sequences of
        # chunks indexed in parallel (chunk i of ``_code_reprs`` holds the
        # vectors for chunk i of ``_codebase``).
        self._code_reprs = None
        self._codebase = None
        self._codebase_chunksize = 2000000

    ##### Model Loading / saving #####
    @staticmethod
    def _model_dir(model):
        """Return the checkpoint directory used for ``model``'s class."""
        return f"./output/{model.__class__.__name__}/models/"

    def save_model(self, model, epoch):
        """Write the model's code and description weight files for ``epoch``."""
        model_path = self._model_dir(model)
        os.makedirs(model_path, exist_ok=True)
        model.save(model_path + f"epo{epoch}_code.h5", model_path + f"epo{epoch}_desc.h5", overwrite=True)

    def load_model(self, model, epoch):
        """Load the weight files written by ``save_model`` for ``epoch``.

        :raises FileNotFoundError: if either weight file is missing.
        """
        model_path = self._model_dir(model)
        code_path = model_path + f"epo{epoch}_code.h5"
        desc_path = model_path + f"epo{epoch}_desc.h5"
        for path in (code_path, desc_path):
            # Explicit check rather than ``assert``, which ``python -O`` strips.
            if not os.path.exists(path):
                raise FileNotFoundError(f"Weights at epoch {epoch} not found")
        model.load(code_path, desc_path)

    ##### Training #####
    def train(self, model):
        """Train ``model`` one data chunk per epoch.

        Epoch ``i`` reads the ``i``-th chunk of ``chunk_size`` training
        examples. Negative descriptions are the chunk's own descriptions in
        shuffled order. The model is saved every ``save_every`` epochs and
        validated every ``valid_every`` epochs.
        """
        reload_epoch = self.train_params.get('reload', 0)
        if reload_epoch > 0:
            self.load_model(model, reload_epoch)
        valid_every = self.train_params.get('valid_every', None)
        save_every = self.train_params.get('save_every', None)
        batch_size = self.train_params.get('batch_size', 128)
        nb_epoch = self.train_params.get('nb_epoch', 10)
        split = self.train_params.get('validation_split', 0)
        chunk_size = self.train_params.get('chunk_size', 100000)

        # Start at +inf so the first epoch is always recorded, whatever the
        # scale of the loss.
        best = {'loss': float('inf'), 'epoch': 0}

        # NOTE(review): the range stops before ``nb_epoch``, so a fresh run
        # trains nb_epoch - 1 epochs; confirm that this is intended.
        for i in range(reload_epoch + 1, nb_epoch):
            print('Epoch %d :: \n' % i, end='')

            logger.debug('loading data chunk..')
            offset = (i - 1) * chunk_size
            names = data_loader.load_hdf5(self.data_path + self.data_params['train_methname'], offset, chunk_size)
            apis = data_loader.load_hdf5(self.data_path + self.data_params['train_apiseq'], offset, chunk_size)
            tokens = data_loader.load_hdf5(self.data_path + self.data_params['train_tokens'], offset, chunk_size)
            descs = data_loader.load_hdf5(self.data_path + self.data_params['train_desc'], offset, chunk_size)

            logger.debug('padding data..')
            methnames = pad(names, self.data_params['methname_len'])
            apiseqs = pad(apis, self.data_params['apiseq_len'])
            tokens = pad(tokens, self.data_params['tokens_len'])
            good_descs = pad(descs, self.data_params['desc_len'])
            # Negative samples: the same descriptions, randomly mismatched.
            bad_descs = list(descs)
            random.shuffle(bad_descs)
            bad_descs = pad(bad_descs, self.data_params['desc_len'])

            hist = model.fit([methnames, apiseqs, tokens, good_descs, bad_descs], epochs=1, batch_size=batch_size, validation_split=split)

            # ``val_loss`` only exists when validation data is used
            # (validation_split > 0); fall back to the training loss.
            monitor = 'val_loss' if 'val_loss' in hist.history else 'loss'
            epoch_loss = hist.history[monitor][0]
            if epoch_loss < best['loss']:
                best = {'loss': epoch_loss, 'epoch': i}
            print('Best: Loss = {}, Epoch = {}'.format(best['loss'], best['epoch']))

            if save_every is not None and i % save_every == 0:
                self.save_model(model, i)

            if valid_every is not None and i % valid_every == 0:
                # Metrics are logged inside ``valid``.
                self.valid(model, 1000, 1)

    ##### Evaluation in the develop set #####
    @staticmethod
    def _rank_metrics(real, predict):
        """Score a ranked prediction list against the relevant items.

        :param real: list of relevant item ids.
        :param predict: list of predicted item ids, best first.
        :return: ``(acc, mrr, map, ndcg)``; all 0.0 when ``real`` is empty.
        """
        if not real:
            return 0.0, 0.0, 0.0, 0.0

        def rank_of(item):
            # 1-based rank of the first occurrence of ``item``, or None.
            try:
                return predict.index(item) + 1
            except ValueError:
                return None

        ranks = [rank_of(item) for item in real]
        n_real = float(len(real))
        acc = sum(1.0 for r in ranks if r is not None) / n_real
        mrr = sum(1.0 / float(r) for r in ranks if r is not None) / n_real
        avg_prec = sum((pos + 1) / float(r) for pos, r in enumerate(ranks) if r is not None) / n_real
        # Binary relevance: each hit contributes (2^1 - 1) * log(2) / log(rank + 1).
        dcg = 0.0
        for rank, item in enumerate(predict, start=1):
            if item in real:
                dcg += (math.pow(2, 1) - 1.0) * (math.log(2) / math.log(rank + 1))
        idcg = 0
        for i in range(len(real)):
            idcg += (math.pow(2, 1) - 1.0) * (math.log(2) / math.log(i + 2))
        return acc, mrr, avg_prec, dcg / float(idcg)

    def valid(self, model, poolsize, K):
        """
        validate in a code pool.
        param: poolsize - size of the code pool, if -1, load the whole test set
        param: K - number of top-ranked candidates scored per query
        return: (ACC, MRR, MAP, nDCG) averaged over all queries in the pool
        """
        # The cache is keyed on ``poolsize``, so a later call with a different
        # pool size reloads instead of silently reusing the old pool.
        if self._eval_sets is None or self._eval_sets.get('poolsize') != poolsize:
            methnames = data_loader.load_hdf5(self.data_path + self.data_params['valid_methname'], 0, poolsize)
            apiseqs = data_loader.load_hdf5(self.data_path + self.data_params['valid_apiseq'], 0, poolsize)
            tokens = data_loader.load_hdf5(self.data_path + self.data_params['valid_tokens'], 0, poolsize)
            descs = data_loader.load_hdf5(self.data_path + self.data_params['valid_desc'], 0, poolsize)
            self._eval_sets = {'methnames': methnames, 'apiseqs': apiseqs, 'tokens': tokens,
                               'descs': descs, 'poolsize': poolsize}

        data_len = len(self._eval_sets['descs'])
        if data_len == 0:
            logger.warning('Validation set is empty; no metrics computed.')
            nan = float('nan')
            return nan, nan, nan, nan

        # The code side is identical for every query, so pad it once.
        methnames = pad(self._eval_sets['methnames'], self.data_params['methname_len'])
        apiseqs = pad(self._eval_sets['apiseqs'], self.data_params['apiseq_len'])
        tokens = pad(self._eval_sets['tokens'], self.data_params['tokens_len'])
        # At least one candidate, and never more than the pool holds
        # (argpartition rejects kth >= pool size).
        n_results = max(1, min(K, data_len))

        accs, mrrs, maps, ndcgs = [], [], [], []
        for i in tqdm(range(data_len)):
            desc = self._eval_sets['descs'][i]  # description that matches code i
            descs = pad([desc] * data_len, self.data_params['desc_len'])
            sims = model.predict([methnames, apiseqs, tokens, descs], batch_size=data_len).flatten()
            # argpartition only isolates the top n_results; sort them by
            # decreasing similarity so rank-based metrics see true ranks.
            top = np.argpartition(np.negative(sims), kth=n_results - 1)[:n_results]
            top = top[np.argsort(np.negative(sims[top]), kind='stable')]
            predict = [int(k) for k in top]
            q_acc, q_mrr, q_map, q_ndcg = self._rank_metrics([i], predict)
            accs.append(q_acc)
            mrrs.append(q_mrr)
            maps.append(q_map)
            ndcgs.append(q_ndcg)
        acc, mrr, avg_prec, ndcg = np.mean(accs), np.mean(mrrs), np.mean(maps), np.mean(ndcgs)
        logger.info(f'ACC={acc}, MRR={mrr}, MAP={avg_prec}, nDCG={ndcg}')
        return acc, mrr, avg_prec, ndcg

    ##### Compute Representation #####
    def repr_code(self, model):
        """Encode every snippet of the ``use`` split and return the vectors
        as float64, normalized with ``utils.normalize``."""
        logger.info('Loading the use data ..')
        methnames = data_loader.load_hdf5(self.data_path + self.data_params['use_methname'], 0, -1)
        apiseqs = data_loader.load_hdf5(self.data_path + self.data_params['use_apiseq'], 0, -1)
        tokens = data_loader.load_hdf5(self.data_path + self.data_params['use_tokens'], 0, -1)
        methnames = pad(methnames, self.data_params['methname_len'])
        apiseqs = pad(apiseqs, self.data_params['apiseq_len'])
        tokens = pad(tokens, self.data_params['tokens_len'])

        logger.info('Representing code ..')
        vecs = model.repr_code([methnames, apiseqs, tokens], batch_size=10000)
        # ``np.float`` (an alias of float64) was removed in NumPy 1.24.
        vecs = vecs.astype(np.float64)
        vecs = normalize(vecs)
        return vecs

    def search(self, model, vocab, query, n_results=10):
        """Return up to ``n_results`` candidate snippets per codebase chunk.

        :return: ``(codes, sims)``, two parallel lists, unsorted across chunks.
        """
        desc = [convert(vocab, query)]  # convert desc sentence to word indices
        padded_desc = pad(desc, self.data_params['desc_len'])
        desc_repr = model.repr_desc([padded_desc])
        desc_repr = desc_repr.astype(np.float32)
        desc_repr = normalize(desc_repr).T  # [dim x 1]

        # Each thread writes into its own pair of lists. With shared lists,
        # two threads could interleave their separate ``codes.extend`` and
        # ``sims.extend`` calls and misalign code/similarity pairs.
        partial_results = []
        threads = []
        for i, code_reprs_chunk in enumerate(self._code_reprs):
            chunk_codes, chunk_sims = [], []
            partial_results.append((chunk_codes, chunk_sims))
            threads.append(threading.Thread(
                target=self.search_thread,
                args=(chunk_codes, chunk_sims, desc_repr, code_reprs_chunk, i, n_results)))
        for t in threads:
            t.start()
        for t in threads:  # wait until all sub-threads finish
            t.join()

        codes, sims = [], []
        for chunk_codes, chunk_sims in partial_results:
            codes.extend(chunk_codes)
            sims.extend(chunk_sims)
        return codes, sims

    def search_thread(self, codes, sims, desc_repr, code_reprs, i, n_results):
        """Append the top-``n_results`` snippets of chunk ``i`` and their
        similarities to ``codes``/``sims`` (in no particular order)."""
        # 1. compute similarity
        chunk_sims = np.squeeze(np.dot(code_reprs, desc_repr), axis=1)  # [pool_size]
        # 2. choose top results; a chunk (e.g. the last one) may hold fewer
        # codes than requested, and argpartition needs kth < len.
        k = min(n_results, chunk_sims.shape[0])
        if k <= 0:
            return
        maxinds = np.argpartition(np.negative(chunk_sims), kth=k - 1)[:k]
        codes.extend(self._codebase[i][idx] for idx in maxinds)
        sims.extend(chunk_sims[maxinds])

    def postproc(self, codes_sims):
        """Drop near-duplicate results, preserving input order.

        An entry is a duplicate when its first 80 characters equal those of
        any earlier entry and the similarities differ by less than 0.01.

        :return: zip iterator of the surviving ``(code, sim)`` pairs.
        """
        pairs = list(codes_sims)  # empty input yields an empty result
        final_codes, final_sims = [], []
        for i, (code, sim) in enumerate(pairs):
            is_dup = any(code[:80] == prev_code[:80] and abs(sim - prev_sim) < 0.01
                         for prev_code, prev_sim in pairs[:i])
            if not is_dup:
                final_codes.append(code)
                final_sims.append(sim)
        return zip(final_codes, final_sims)

    
def parse_args(argv=None):
    """Parse the command-line options.

    :param argv: argument list to parse; ``None`` (the default) reads
        ``sys.argv[1:]`` exactly as before.
    :return: ``argparse.Namespace`` with ``data_path``, ``model``,
        ``dataset``, ``mode`` and ``verbose``.
    """
    parser = argparse.ArgumentParser("Train and Test Code Search(Embedding) Model")
    parser.add_argument("--data_path", type=str, default='./data/', help="working directory")
    parser.add_argument("--model", type=str, default="JointEmbeddingModel", help="model name")
    parser.add_argument("--dataset", type=str, default="github", help="dataset name")
    parser.add_argument("--mode", choices=["train", "eval", "repr_code", "search"], default='train',
                        help="The mode to run. The `train` mode trains a model;"
                        " the `eval` mode evaluates a trained model;"
                        " the `repr_code` mode computes code vectors with a trained model;"
                        " the `search` mode answers natural-language queries interactively.")
    # NOTE(review): store_true combined with default=True means this option is
    # always True; kept as-is for compatibility.
    parser.add_argument("--verbose", action="store_true", default=True, help="Be verbose")
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    config = getattr(configs, 'config_' + args.model)()
    engine = SearchEngine(args, config)

    ##### Define model ######
    logger.info('Build Model')
    model = getattr(models, args.model)(config)  # initialize the model
    model.build()
    model.summary(export_path=f"./output/{args.model}/")

    training_params = config.get('training_params', dict())
    optimizer = training_params.get('optimizer', 'adam')
    model.compile(optimizer=optimizer)

    data_path = args.data_path + args.dataset + '/'
    # Checkpoint epoch to restore before eval/repr_code/search (0 = none).
    reload_epoch = training_params.get('reload', 0)

    if args.mode == 'train':
        engine.train(model)

    elif args.mode == 'eval':  # evaluate for a specific epoch
        if reload_epoch > 0:
            engine.load_model(model, reload_epoch)
        # SearchEngine exposes evaluation as ``valid``; poolsize -1 loads the
        # whole set, and the top 10 candidates are scored.
        engine.valid(model, -1, 10)

    elif args.mode == 'repr_code':
        if reload_epoch > 0:
            engine.load_model(model, reload_epoch)
        vecs = engine.repr_code(model)
        data_loader.save_code_reprs(vecs, data_path + config['data_params']['use_codevecs'])

    elif args.mode == 'search':
        # search code based on a desc
        if reload_epoch > 0:
            engine.load_model(model, reload_epoch)
        engine._code_reprs = data_loader.load_code_reprs(data_path + config['data_params']['use_codevecs'], engine._codebase_chunksize)
        engine._codebase = data_loader.load_codebase(data_path + config['data_params']['use_codebase'], engine._codebase_chunksize)
        vocab = data_loader.load_pickle(data_path + config['data_params']['vocab_desc'])
        while True:
            try:
                query = input('Input Query: ')
                n_results = int(input('How many results? '))
            except ValueError:
                # A mistyped number should not end the interactive session.
                print("Exception while parsing your input:")
                traceback.print_exc()
                continue
            except Exception:
                # e.g. EOFError when stdin is closed: stop the loop.
                print("Exception while parsing your input:")
                traceback.print_exc()
                break
            if n_results < 1:
                print('Please ask for at least one result.')
                continue
            query = query.lower().replace('how to ', '').replace('how do i ', '').replace('how can i ', '').replace('?', '').strip()
            codes, sims = engine.search(model, vocab, query, n_results)
            zipped = sorted(zip(codes, sims), reverse=True, key=lambda x: x[1])
            zipped = list(engine.postproc(zipped))[:n_results]
            results = '\n\n'.join(map(str, zipped))  # combine the result into a returning string
            print(results)