from __future__ import print_function, division
import os,sys,copy;import numpy as np
from collections import defaultdict
sys.path.append('./')
from treetime import TreeAnc
from treetime import GTR
from treetime import seq_utils
from Bio import Phylo, AlignIO
from sf_miscellaneous import write_json, write_pickle
from sf_geneCluster_align_makeTree import load_sorted_clusters

def infer_gene_gain_loss(path, rates = (1.0, 1.0)):
    """
    Reconstruct ancestral gene presence/absence on the core-SNP strain tree.

    Reads the strain tree ``<path>/geneCluster/strain_tree.nwk`` and the
    presence/absence pseudo-alignment ``<path>/geneCluster/genePresence.aln``
    (one column per gene cluster, e.g. 001001010110). It builds a two-state
    GTR model on the alphabet ['0', '1'] and runs treetime's ML ancestral
    reconstruction.

    Parameters
    ----------
    path : str
        Run directory containing the ``geneCluster/`` folder.
    rates : sequence of two floats, optional
        Weights for states '0' and '1'. Their sum becomes the overall rate
        ``mu`` and the normalized values become the equilibrium frequencies.

    Returns
    -------
    TreeAnc
        The treetime object. Every clade carries a ``genepresence``
        attribute holding its sequence after the reconstruction.
    """
    # two-state GTR model: equilibrium frequencies proportional to `rates`,
    # overall rate = their sum
    mu = np.sum(rates)
    gene_pi = np.array(rates)/mu
    gain_loss_model = GTR.custom(pi = gene_pi, mu=mu,
                           W=np.ones((2,2)),
                           alphabet = np.array(['0','1']))
    # add "unknown" state to profile: '-' is compatible with both states
    gain_loss_model.profile_map['-'] = np.ones(2)

    # define file names for pseudo alignment of presence/absence patterns as in 001001010110
    sep='/'
    fasta = sep.join([path.rstrip(sep), 'geneCluster', 'genePresence.aln'])
    # strain tree based on core gene SNPs
    nwk =  sep.join([path.rstrip(sep), 'geneCluster', 'strain_tree.nwk'])

    # instantiate treetime with custom GTR
    t = TreeAnc(nwk, gtr =gain_loss_model, verbose=2)
    # fix leaves names since Bio.Phylo interprets numeric leaf names as confidence
    for leaf in t.tree.get_terminals():
        if leaf.name is None:
            leaf.name = str(leaf.confidence)
    t.aln = fasta
    # give the root a small positive branch length
    t.tree.root.branch_length=0.0001
    t.reconstruct_anc(method='ml')

    # keep a copy of the reconstructed presence/absence on every node
    for n in t.tree.find_clades():
        n.genepresence = n.sequence

    return t


def export_gain_loss(tree, path, merged_gain_loss_output):
    '''
    Write the tree, per-gene gain/loss event counts and per-cluster patterns.

    For every non-root node, the transition along the branch from its parent
    is encoded per gene column as ``2*parent_presence + node_presence``:
      0: absent -> absent, 1: absent -> present (gain),
      2: present -> absent (loss), 3: present -> present.

    Files written to ``<path>/geneCluster/``:
      strain_tree.nwk           -- tree with internal node names assigned by treetime
      dt_geneEvents.cpk         -- {gene column index: number of gain + loss events}
      dt_genePattern.cpk        -- {node name: transition string},
                                   only if merged_gain_loss_output
      <clusterID>_patterns.json -- otherwise: one file per cluster holding the
                                   transition codes of all non-root nodes in preorder

    Parameters
    ----------
    tree : TreeAnc
        Nodes must carry ``genepresence`` sequences of '0'/'1' characters.
    path : str
        Run directory containing ``geneCluster/``.
    merged_gain_loss_output : bool
        Select the merged pickle output instead of per-cluster JSON files.
    '''
    sep='/'
    cluster_dir = sep.join([path.rstrip(sep), 'geneCluster'])
    # trailing separator kept: per-cluster file names are built by plain concatenation
    output_path = cluster_dir + sep
    events_dict_path = sep.join([cluster_dir, 'dt_geneEvents.cpk'])
    gene_pattern_dict_path = sep.join([cluster_dir, 'dt_genePattern.cpk'])

    # write final tree with internal node names as assigned by treetime
    tree_fname = sep.join([cluster_dir, 'strain_tree.nwk'])
    Phylo.write(tree.tree, tree_fname, 'newick')

    gene_gain_loss_dict=defaultdict(str)
    preorder_strain_list= [] #store the preorder nodes as strain list
    for node in tree.tree.find_clades(order='preorder'):
        if node.up is None:
            continue
        gain_loss = [str(int(ancestral)*2+int(derived))
                     for ancestral,derived in zip(node.up.genepresence, node.genepresence)]
        gene_gain_loss_dict[node.name]="".join(gain_loss)
        preorder_strain_list.append(node.name)

    gain_loss_array = np.array([[i for i in gain_loss_str]
                                for gain_loss_str in gene_gain_loss_dict.values()], dtype=int)
    # 1 and 2 are codes for gain/loss events
    events_array = ((gain_loss_array == 1) | (gain_loss_array == 2)).sum(axis=0)
    events_dict =  { index:event for index, event in enumerate(events_array) }

    write_pickle(events_dict_path, events_dict)

    if merged_gain_loss_output:
        write_pickle(gene_pattern_dict_path, gene_gain_loss_dict)
    else:
        ## strainID as key, presence pattern as value (converted into np.array)
        sorted_genelist = load_sorted_clusters(path)
        strainID_keymap = dict(enumerate(preorder_strain_list))
        # rows: non-root nodes in preorder, columns: gene clusters.
        # list(...) yields one-character *str* cells. The former dtype 'c'
        # produced bytes, so str(cell) wrote "b'0'" under Python 3.
        presence_arr = np.array([list(gene_gain_loss_dict[k]) for k in preorder_strain_list])
        ## if true, write pattern dict instead of pattern string in a json file
        pattern_json_flag=False
        for ind, (clusterID, _gene) in enumerate(sorted_genelist):
            pattern_fname='%s%s_patterns.json'%(output_path,clusterID)
            column = presence_arr[:, ind]
            if pattern_json_flag:
                pattern_dt = {strainID_keymap[strain_ind]: str(patt)
                              for strain_ind, patt in enumerate(column)}
                write_json(pattern_dt, pattern_fname, indent=1)
            else:
                # previously written unconditionally, which overwrote the JSON dict above
                with open(pattern_fname,'w') as write_pattern:
                    write_pattern.write('{"patterns":"'+''.join(str(patt) for patt in column)+'"}')


def process_gain_loss(path, merged_gain_loss_output):
    """
    Infer gene gain/loss events on the strain tree and export them.

    Steps:
      1. Reconstruct presence/absence with default rates (infer_gene_gain_loss).
      2. Thin the alignment to unique patterns and flag rare and near-core
         patterns to be ignored in the rate fit.
      3. Fit (pi_present, mu) by maximum likelihood from three start points
         (L-BFGS-B); fall back to (0.5, 1.0) if no run succeeds.
      4. Reconstruct ancestral presence/absence under those parameters and
         export the result.

    Parameters
    ----------
    path : str
        Run directory containing ``geneCluster/``.
    merged_gain_loss_output : bool
        Forwarded to export_gain_loss.
    """
    from scipy.optimize import minimize

    ##  infer gain/loss event
    tree = infer_gene_gain_loss(path)
    create_visible_pattern_dictionary(tree)
    set_seq_to_patternseq(tree)
    set_visible_pattern_to_ignore(tree,p=-1,mergeequalstrains=True)

    def myminimizer(c):
        return compute_totallh(tree,c)

    results = []
    any_success = False
    with np.errstate(divide='ignore'):
        # three start values for pi_present, all with mu = 1
        for pi_start in (0.5, 0.2, 0.8):
            try:
                res = minimize(myminimizer, [pi_start, 1.], method='L-BFGS-B',
                               bounds=[(0.01, 0.99), (0.1, 100.)])
            except Exception as err:
                print('Warning: ML optimization starting at pi_present=%s failed: %s' % (pi_start, err))
                continue
            results.append(res)
            any_success = any_success or bool(res.success)

    if any_success:
        print('successfully estimated the gtr parameters. Reconstructing ancestral states...')
        # best (lowest negative log-likelihood) of the runs that returned a result
        best = min(results, key=lambda r: r.fun)
        change_gtr_parameters_forgainloss(tree, best.x[0], best.x[1])
        print('estimated gain/loss rate: ', best.x[0], best.x[1])
    else:
        print('Warning: failed to estimated the gtr parameters by ML.')
        change_gtr_parameters_forgainloss(tree,0.5,1.0)
    _reconstruct_and_export(tree, path, merged_gain_loss_output)


def _reconstruct_and_export(tree, path, merged_gain_loss_output):
    """Re-run ML ancestral reconstruction on full presence/absence and export it."""
    # set_seq_to_patternseq replaced node.sequence with the thinned pattern
    # alignment; restore the full per-gene presence/absence first.
    # NOTE(review): whether TreeAnc.reconstruct_anc reads node.sequence or its
    # own stored alignment depends on the treetime version; the restore is
    # harmless in the latter case.
    for node in tree.tree.find_clades():
        node.sequence = node.genepresence
    tree.reconstruct_anc(method='ml')
    # export_gain_loss reads node.genepresence, so refresh it from this
    # reconstruction (as infer_gene_gain_loss does after its initial one);
    # otherwise the export would use the default-rate reconstruction.
    for node in tree.tree.find_clades():
        node.genepresence = node.sequence
    export_gain_loss(tree, path, merged_gain_loss_output)


def create_visible_pattern_dictionary(tree):
    """
    Thin the presence/absence data so each leaf pattern occurs only once.

    Attributes built on ``tree.tree``:
      patterndict       -- {pattern tuple: [first gene column with this pattern,
                             number of genes with it, include flag]}
      clusterdict       -- {first gene column with a pattern: [number of genes,
                             include flag]}
      pattern_abundance -- gene counts ordered by first column
      pattern_include   -- include flags ordered by first column
      corepattern_index -- position of the all-'1' pattern (only set if it occurs)

    An artificial all-'0' (null) pattern is appended with abundance 0 and
    include flag 0. Every node that has a ``sequence`` gets ``patternseq``:
    its sequence restricted to the first column of each pattern, plus a
    trailing '0' for the null pattern.
    """
    root = tree.tree
    leaves = root.get_terminals()
    numstrains = len(leaves)
    corepattern = ('1',)*numstrains
    nullpattern = ('0',)*numstrains
    root.patterndict = {}
    root.clusterdict = {}

    numgenes = leaves[0].genepresence.shape[0]
    for genenumber in range(numgenes):
        pattern = tuple(leaf.genepresence[genenumber] for leaf in leaves)
        if pattern == nullpattern:
            print("Warning: There seems to be a nullpattern in the data! Check your presence absence pseudoalignment at pos", genenumber+1)
        entry = root.patterndict.get(pattern)
        if entry is None:
            # first occurrence: remember its column, one gene so far, included
            entry = [genenumber, 1, 1]
            root.patterndict[pattern] = entry
        else:
            entry[1] += 1
        root.clusterdict[entry[0]] = [entry[1], 1]

    # keep only the first column of every pattern, then append the null column
    kept_columns = sorted(root.clusterdict)
    for node in root.find_clades():
        if not hasattr(node, 'sequence'):
            continue
        if len(node.sequence) != numgenes:
            print ("Warning: Nonmatching number of genes in sequence")
        node.patternseq = np.append(node.sequence[kept_columns], ['0',])

    # artificial null pattern: never observed, excluded from the estimation
    root.patterndict[nullpattern] = [numgenes, 0, 0]
    root.clusterdict[numgenes] = [0, 0]

    final_columns = sorted(root.clusterdict)
    ordered = [root.clusterdict[col] for col in final_columns]
    root.pattern_abundance = [abundance for abundance, _flag in ordered]
    root.pattern_include = [flag for _abundance, flag in ordered]
    # a core pattern should exist unless the data is e.g. single-cell sequencing
    if corepattern in root.patterndict:
        root.corepattern_index = final_columns.index(root.patterndict[corepattern][0])

def _indicator_tuple(index, numstrains, fill, mark):
    """Return a numstrains-long tuple of `fill`, with `mark` at every position in `index`."""
    slots = [fill for _ in range(numstrains)]
    for position in index:
        slots[position] = mark
    return tuple(slots)


def index2pattern(index,numstrains):
    """
    Return the 0/1 tuple of length numstrains that is 1 exactly at `index`.
    """
    return _indicator_tuple(index, numstrains, 0, 1)


def index2pattern_reverse(index,numstrains):
    """
    Return the 0/1 tuple of length numstrains that is 0 exactly at `index`.
    """
    return _indicator_tuple(index, numstrains, 1, 0)

def create_ignoring_pattern_dictionary(tree,p = 0):
    """
    Create a dictionary of patterns for extended core genes and extended unique genes.

    For every set of 1..p strain indices, this adds two patterns:
      - the pattern present only in those strains;
      - the pattern absent only in those strains.
    These patterns are meant to be ignored when inferring gene gain/loss rates.

    Result: ``tree.tree.unpatterndict = {pattern tuple: [-1, 0, 0]}``.

    Parameters
    ----------
    tree : TreeAnc
    p : int, optional
        Maximal number of strains; 0 means int(numstrains/10).

    NOTE(review): keys are tuples of ints (0/1) from index2pattern, whereas
    tree.tree.patterndict uses tuples of '0'/'1' characters, so the two
    dictionaries' keys do not compare equal.
    """
    import itertools
    tree.tree.unpatterndict = {}
    numstrains = len(tree.tree.get_terminals())
    if p == 0:
        p = int(numstrains/10)

    # all sets of indices for p or less of numstrains individuals
    index_sets = itertools.chain.from_iterable(
        itertools.combinations(range(numstrains), size) for size in range(1, p + 1))

    for indices in index_sets:
        tree.tree.unpatterndict[index2pattern(indices,numstrains)] = [-1,0,0]
        tree.tree.unpatterndict[index2pattern_reverse(indices,numstrains)] = [-1,0,0]


def create_distance_matrix(tree):
    """
    Store pairwise leaf-to-leaf tree distances in ``tree.tree.distance_matrix``.

    Rows and columns follow the order of ``tree.tree.get_terminals()``. Tree
    distances are symmetric, so each pair is evaluated once and mirrored; the
    diagonal stays 0. The attribute is set only after the matrix is complete,
    so a failure cannot leave a partially filled matrix that callers would
    reuse.
    """
    leaves = tree.tree.get_terminals()
    numstrains = len(leaves)
    distance_matrix = np.zeros([numstrains, numstrains])
    for i, leaf1 in enumerate(leaves):
        for j in range(i + 1, numstrains):
            dist = tree.tree.distance(leaf1, leaves[j])
            distance_matrix[i, j] = dist
            distance_matrix[j, i] = dist
    tree.tree.distance_matrix = distance_matrix

def merge_strains(distances,indices,mindist = 0.0):
    """
    Count the strains in `indices` after greedily merging those within `mindist`.

    Repeatedly picks an arbitrary remaining strain as a representative and
    drops every other remaining strain whose distance to it is <= mindist.

    Parameters
    ----------
    distances : 2D array indexable as distances[i, j]
    indices : iterable of strain indices
    mindist : float, optional

    Returns
    -------
    int
        Number of representatives (0 for empty `indices`).
    """
    remain = set(indices)
    nrepresentatives = 0
    while remain:
        representative = remain.pop()
        nrepresentatives += 1
        # a fresh set each round (the old code never actually cleared its
        # temporary set because `.clear` was not called)
        remain -= {j for j in remain if distances[representative, j] <= mindist}
    return nrepresentatives

def set_visible_pattern_to_ignore(tree,p = -1,mergeequalstrains = False,lowfreq = True, highfreq = True):
    """
    Exclude patterns present in at most p strains or absent in at most p strains.

    Excluded patterns get include flag 0 in both ``tree.tree.patterndict``
    and ``tree.tree.clusterdict``, and ``tree.tree.pattern_include`` is
    rebuilt. Patterns are never re-included.

    Parameters
    ----------
    tree : TreeAnc
        Must already have patterndict/clusterdict
        (create_visible_pattern_dictionary).
    p : int, optional
        Threshold; -1 means int(numstrains/10).
    mergeequalstrains : bool, optional
        Count strains at tree distance 0 as one (merge_strains), both for
        numstrains and for per-pattern counts.
    lowfreq : bool, optional
        Apply the "present in at most p strains" filter.
    highfreq : bool, optional
        Apply the "absent in at most p strains" filter.
    """
    numstrains = len(tree.tree.get_terminals())
    if mergeequalstrains:
        if not hasattr(tree.tree,'distance_matrix'):
            create_distance_matrix(tree)
        numstrains = merge_strains(tree.tree.distance_matrix,np.array(range(numstrains)) )
    if p == -1:
        p = int(numstrains/10)
    for pattern, entry in tree.tree.patterndict.items():
        freq = pattern.count('1')
        exclude = False
        # NB: counts live in their own names. The old code overwrote the
        # boolean lowfreq/highfreq switches with counts, so a count of 0
        # (e.g. the core pattern) disabled the filter for all later patterns.
        if lowfreq:
            if mergeequalstrains:
                # indices of individuals that carry the gene
                present = np.where(np.array(pattern) == '1')[0]
                n_present = merge_strains(tree.tree.distance_matrix, present)
            else:
                n_present = freq
            exclude = n_present <= p
        if highfreq:
            if mergeequalstrains:
                # indices of individuals that lack the gene
                absent = np.where(np.array(pattern) == '0')[0]
                n_absent = merge_strains(tree.tree.distance_matrix, absent)
            else:
                n_absent = numstrains - freq
            exclude = exclude or n_absent <= p
        if exclude:
            entry[2] = 0
            tree.tree.clusterdict[entry[0]][1] = 0

    tree.tree.pattern_include = [tree.tree.clusterdict[key][1] for key in sorted(tree.tree.clusterdict)]
    if sum(tree.tree.pattern_include) == 0:
        print('WARNING all pattern have been excluded, estimation of parameters is thus impossible')


def _check_seq_and_patternseq(tree):
    """
    Debug helper: report, per leaf, whether ``sequence`` equals ``patternseq``.

    A leaf is reported as wrong if the two differ at any position or in
    length. The old check `all(a != b)` flagged a leaf only when every
    position differed, and failed on length mismatch.
    """
    for leaf in tree.tree.get_terminals():
        if not np.array_equal(leaf.sequence, leaf.patternseq):
            print('WARNING: wrong pattern in ',leaf.name)
        else:
            print(leaf.name, ' is ok')

def compute_lh(tree,verbose=0):
    """
    compute the likelihood for each gene presence pattern in the sequence given the gtr parameters

    Runs a postorder (leaves -> root) pass over tree.tree using tree.gtr.
    The result is stored in tree.tree.root.pattern_profile_lh: an array of
    shape (L, n_states) with the log-likelihood of each column of the leaf
    sequences, conditional on each root state. L is the length of the first
    leaf's sequence.

    Side effects: sets ``profile`` and ``lh_prefactor`` on every node, and
    ``seq_msg_to_parent`` on every child of an internal node. Leaf data is
    read from ``leaf.sequence``; in process_gain_loss this is the thinned
    pattern sequence installed by set_seq_to_patternseq.

    Parameters
    ----------
    tree : TreeAnc
    verbose : int, optional
        Print a progress message when > 2.
    """

    # lower bound applied to ch.branch_length before propagating profiles
    min_branch_length = 1e-10
    L = tree.tree.get_terminals()[0].sequence.shape[0]
    n_states = tree.gtr.alphabet.shape[0]
    if verbose > 2:
        print ("Walking up the tree, computing likelihoods for the pattern in the leaves...")
    for leaf in tree.tree.get_terminals():
        # in any case, set the profile (one row per column, via gtr.profile_map)
        leaf.profile = seq_utils.seq2prof(leaf.sequence, tree.gtr.profile_map)
        leaf.lh_prefactor = np.zeros(L)
    for node in tree.tree.get_nonterminals(order='postorder'): #leaves -> root
        # regardless of what was before, set the profile to ones
        node.lh_prefactor = np.zeros(L)
        node.profile = np.ones((L, n_states)) # this has to be ones in each entry -> we will multiply it
        for ch in node.clades:
            # presumably P(child subtree data | parent state) per column, as
            # returned by treetime's propagate_profile — TODO confirm for the
            # installed treetime version
            ch.seq_msg_to_parent = tree.gtr.propagate_profile(ch.profile,
                max(ch.branch_length, min_branch_length),
                return_log=False) # raw prob to transfer prob up
            node.profile *= ch.seq_msg_to_parent
            node.lh_prefactor += ch.lh_prefactor
        pre = node.profile.sum(axis=1) #sum over the presence/absence states

        # rescale each row to sum 1 and keep the log of the scale factor in
        # lh_prefactor, so the products above stay in floating-point range
        node.profile = (node.profile.T/pre).T # normalize so that the sum is 1
        node.lh_prefactor += np.log(pre) # and store log-prefactor

    # undo the normalization: log(normalized root profile) + accumulated log-prefactors
    tree.tree.root.pattern_profile_lh = (np.log(tree.tree.root.profile).transpose() + tree.tree.root.lh_prefactor).transpose()

def change_gtr_parameters_forgainloss(tree,pi_present,mu):
    """
    Set the binary gain/loss GTR model of `tree` in place.

    Parameters
    ----------
    tree : TreeAnc
        Its ``gtr`` (alphabet ['0', '1']) is modified.
    pi_present : float
        Equilibrium frequency of state '1' (present); state '0' gets
        1 - pi_present.
    mu : float
        Overall rate, assigned to tree.gtr.mu.
    """
    # equilibrium frequencies in alphabet order ['0', '1']
    genepi = np.array([1.0-pi_present,pi_present])
    genepi /= genepi.sum()
    tree.gtr.Pi = genepi
    # change speed
    tree.gtr.mu = mu
    # flow matrix: every column of ones((2,2)) sums to 2, so the diagonal is
    # set to -1, giving W = [[-1, 1], [1, -1]]
    # NOTE(review): assumes tree.gtr.W is a plain array attribute, so that
    # fill_diagonal modifies it in place — confirm for the treetime version
    tree.gtr.W = np.ones((2,2))
    np.fill_diagonal(tree.gtr.W, - ((tree.gtr.W).sum(axis=0) - 1.))
    # private treetime methods; presumably rebuild Q from W/Pi and recompute
    # its eigendecomposition — TODO confirm for the installed treetime version
    tree.gtr._check_fix_Q()
    # meanwhile tree.gtr._check_fix_Q() keeps mu
    tree.gtr._eig()


def compute_totallh(tree,params,adjustcore = True,verbose = 0):
    """
    compute the total likelihood for all genes with presence pattern set to include
    conditioned on not observing pattern set to not include (e.g. the nullpattern)
    be careful: this function changes the gtr model

    Parameters
    ----------
    tree : TreeAnc
        Prepared by create_visible_pattern_dictionary, set_seq_to_patternseq
        and set_visible_pattern_to_ignore.
    params : sequence
        (pi_present, mu) passed to change_gtr_parameters_forgainloss.
    adjustcore : bool, optional
        Unused; kept for interface compatibility.
    verbose : int, optional

    Returns
    -------
    float
        Negative total log-likelihood (for minimization), or 1e50 if it is
        not finite.
    """
    from scipy.special import logsumexp

    # change the relation of genegain and geneloss rate and the speed
    pi_present = params[0]
    mymu = params[1]
    change_gtr_parameters_forgainloss(tree,pi_present,mymu)

    # this gives the log likelihoods for each pattern and root state
    compute_lh(tree)
    root = tree.tree.root
    # log P(pattern) = log sum_s Pi_s * exp(log L_s). Evaluated in log space so
    # very unlikely patterns (large trees) do not underflow to log(0) = -inf.
    root.pattern_lh = logsumexp(root.pattern_profile_lh, axis=1, b=tree.gtr.Pi)

    abundance = np.asarray(tree.tree.pattern_abundance)
    include = np.asarray(tree.tree.pattern_include)
    weights = abundance * include
    used = weights != 0
    # compute the likelihood of all genes with included pattern;
    # masking avoids -inf * 0 = nan from excluded/empty patterns
    root.total_llh = np.sum(root.pattern_lh[used] * weights[used])
    # total probability of the patterns that should not be included
    ll_forsumofignored = np.sum(np.exp(root.pattern_lh) * (1 - include))
    if verbose > 2:
        print("adjusting for all pattern that have been set to pattern_include == 0")
    # condition every included gene on not showing an excluded pattern
    root.total_llh = root.total_llh - np.log1p(-ll_forsumofignored) * np.sum(weights)
    if verbose > 3:
        print("totalLH:", pi_present, mymu, root.total_llh)
    if not np.isfinite(root.total_llh):
        return 1e50
    return root.total_llh * -1.

def _copy_attr_to_sequence(tree, source_attr):
    """Set node.sequence = node.<source_attr> on all nodes; drop sequence where the source is missing."""
    for node in tree.tree.find_clades():
        if hasattr(node, source_attr):
            node.sequence = getattr(node, source_attr)
        elif hasattr(node, 'sequence'):
            # attribute name must be a string (the old code passed an
            # undefined name and raised NameError)
            delattr(node, 'sequence')


def set_seq_to_patternseq(tree):
    """Use the thinned pattern sequence (node.patternseq) as node.sequence on every node."""
    _copy_attr_to_sequence(tree, 'patternseq')


def set_seq_to_genepresence(tree):
    """Use the full presence/absence sequence (node.genepresence) as node.sequence on every node."""
    _copy_attr_to_sequence(tree, 'genepresence')

def _plot_curve(filename, xaxis, evaluate):
    """Plot evaluate(x) for every x in xaxis and save the figure to filename."""
    import matplotlib.pyplot as plt
    values = [evaluate(x) for x in xaxis]
    plt.plot(xaxis, values)
    plt.savefig(filename)
    plt.close()


def plot_ll(filename,tree,mu =1.0):
    """Plot the negative total log-likelihood against pi_present at fixed mu."""
    grid = np.linspace(0.0001, 0.999, num=500)
    _plot_curve(filename, grid, lambda pi_value: compute_totallh(tree, [pi_value, mu]))


def plot_ll_mu(filename,tree,pi_present =0.5,mu_max = 10):
    """Plot the negative total log-likelihood against mu at fixed pi_present."""
    grid = np.linspace(0.01, mu_max, num=500)
    _plot_curve(filename, grid, lambda mu_value: compute_totallh(tree, [pi_present, mu_value]))