"""Hierarchical MPN decoder: generates molecules from latent vectors by
autoregressively growing a tree of structural motifs and attaching each
motif to the partially assembled molecular graph."""
import torch
import torch.nn as nn
import torch.nn.functional as F

from hgraph.nnutils import *
from hgraph.encoder import IncHierMPNEncoder
from hgraph.mol_graph import MolGraph
from hgraph.inc_graph import IncTree, IncGraph

class HTuple:
    """Lightweight container for one hierarchy level's decoder state: node
    vectors, message vectors, and 0/1 masks over vertices (vmask) and
    edges (emask)."""
    def __init__(self, node=None, mess=None, vmask=None, emask=None):
        self.node, self.mess = node, mess
        self.vmask, self.emask = vmask, emask

class HierMPNDecoder(nn.Module):

    def __init__(self, vocab, avocab, rnn_type, embed_size, hidden_size, latent_size, depthT, depthG, dropout, attention=False):
        super(HierMPNDecoder, self).__init__()
        self.vocab = vocab
        self.avocab = avocab
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.latent_size = latent_size
        self.use_attention = attention
        self.itensor = torch.LongTensor([]).cuda() #empty template tensor; new_tensor() calls inherit its dtype/device (assumes CUDA)

        self.hmpn = IncHierMPNEncoder(vocab, avocab, rnn_type, embed_size, hidden_size, depthT, depthG, dropout)
        self.rnn_cell = self.hmpn.tree_encoder.rnn
        self.E_assm = self.hmpn.E_i #attachment-config (icls) embedding, shared with the encoder
        self.E_order = torch.eye(MolGraph.MAX_POS).cuda() #one-hot embedding of a child's attachment order

        self.topoNN = nn.Sequential( #topology: expand a new child (1) or backtrack (0)
                nn.Linear(hidden_size + latent_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, 1)
        )
        self.clsNN = nn.Sequential( #motif class scores
                nn.Linear(hidden_size + latent_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, vocab.size()[0])
        )
        self.iclsNN = nn.Sequential( #attachment-configuration (icls) scores
                nn.Linear(hidden_size + latent_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, vocab.size()[1])
        )
        self.matchNN = nn.Sequential( #features for scoring candidate attachment points
                nn.Linear(hidden_size + embed_size + MolGraph.MAX_POS, hidden_size),
                nn.ReLU(),
        )
        self.W_assm = nn.Linear(hidden_size, latent_size)

        if latent_size != hidden_size:
            self.W_root = nn.Linear(latent_size, hidden_size) #project the latent root vector up to hidden size

        if self.use_attention:
            self.A_topo = nn.Linear(hidden_size, latent_size)
            self.A_cls = nn.Linear(hidden_size, latent_size)
            self.A_assm = nn.Linear(hidden_size, latent_size)

        #losses are summed over predictions and normalized by batch size in forward()
        self.topo_loss = nn.BCEWithLogitsLoss(reduction='sum')
        self.cls_loss = nn.CrossEntropyLoss(reduction='sum')
        self.icls_loss = nn.CrossEntropyLoss(reduction='sum')
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
        
    def apply_tree_mask(self, tensors, cur, prev):
        """Hide tree nodes/edges that have not been decoded yet by redirecting
        their indices to the padding row (index 0)."""
        fnode, fmess, agraph, bgraph, cgraph, scope = tensors
        agraph = agraph * index_select_ND(cur.emask, 0, agraph)
        bgraph = bgraph * index_select_ND(cur.emask, 0, bgraph)
        cgraph = cgraph * index_select_ND(prev.vmask, 0, cgraph)
        return fnode, fmess, agraph, bgraph, cgraph, scope
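    # Masking sketch (illustrative values, not from the data): row 0 of every
    # feature tensor is the padding entry, so zeroing an index is equivalent
    # to deleting that neighbor. E.g. if emask = [0,1,0,1] and a row of agraph
    # is [1,2,3], then index_select_ND(emask, 0, [1,2,3]) = [1,0,1] and the
    # masked row becomes [1,0,3]: edge 2 now points at padding.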

    def apply_graph_mask(self, tensors, hgraph):
        fnode, fmess, agraph, bgraph, scope = tensors
        agraph = agraph * index_select_ND(hgraph.emask, 0, agraph)
        bgraph = bgraph * index_select_ND(hgraph.emask, 0, bgraph)
        return fnode, fmess, agraph, bgraph, scope

    def update_graph_mask(self, graph_batch, new_atoms, hgraph):
        """Reveal newly decoded atoms, plus the bonds of the subgraph they
        induce, by flipping their mask bits to 1."""
        new_atom_index = hgraph.vmask.new_tensor(new_atoms)
        hgraph.vmask.scatter_(0, new_atom_index, 1)

        new_atom_set = set(new_atoms)
        new_bonds = [] #new bonds: the edges of the subgraph induced by new_atoms
        for zid in new_atoms:
            for nid in graph_batch[zid]:
                if nid not in new_atom_set: continue
                new_bonds.append( graph_batch[zid][nid]['mess_idx'] )

        new_bond_index = hgraph.emask.new_tensor(new_bonds)
        if len(new_bonds) > 0:
            hgraph.emask.scatter_(0, new_bond_index, 1)
        return new_atom_index, new_bond_index

    def init_decoder_state(self, tree_batch, tree_tensors, src_root_vecs):
        """Attach one virtual incoming message per molecule root so the latent
        root vector participates in tree message passing from the start."""
        batch_size = len(src_root_vecs)
        num_mess = len(tree_tensors[1])
        agraph = tree_tensors[2].clone()
        bgraph = tree_tensors[3].clone()

        for i,tup in enumerate(tree_tensors[-1]):
            root = tup[0]
            assert agraph[root,-1].item() == 0
            agraph[root,-1] = num_mess + i
            for v in tree_batch.successors(root):
                mess_idx = tree_batch[root][v]['mess_idx'] 
                assert bgraph[mess_idx,-1].item() == 0
                bgraph[mess_idx,-1] = num_mess + i

        new_tree_tensors = tree_tensors[:2] + [agraph, bgraph] + tree_tensors[4:]
        htree = HTuple()
        htree.mess = self.rnn_cell.get_init_state(tree_tensors[1], src_root_vecs)
        htree.emask = torch.cat( [bgraph.new_zeros(num_mess), bgraph.new_ones(batch_size)], dim=0 )

        return htree, new_tree_tensors
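    # Indexing note: message ids 0..num_mess-1 are the real tree edges (0
    # being padding); id num_mess+i is a virtual edge feeding src_root_vecs[i]
    # into molecule i's root, so the latent code enters message passing at the
    # first step. emask starts with exactly these virtual edges enabled.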

    def attention(self, src_vecs, batch_idx, queries, W_att):
        """Attention over each molecule's source vectors, with the decoder
        state as query; accepts flat or two-dimensional batch indices."""
        size = batch_idx.size()
        if batch_idx.dim() > 1:
            batch_idx = batch_idx.view(-1)
            queries = queries.view(-1, queries.size(-1))

        src_vecs = src_vecs.index_select(0, batch_idx)
        att_score = torch.bmm( src_vecs, W_att(queries).unsqueeze(-1) )
        att_vecs = F.softmax(att_score, dim=1) * src_vecs
        att_vecs = att_vecs.sum(dim=1)
        return att_vecs if len(size) == 1 else att_vecs.view(size[0], size[1], -1)

    def get_topo_score(self, src_tree_vecs, batch_idx, topo_vecs):
        if self.use_attention:
            topo_cxt = self.attention(src_tree_vecs, batch_idx, topo_vecs, self.A_topo)
        else:
            topo_cxt = src_tree_vecs.index_select(index=batch_idx, dim=0)
        return self.topoNN( torch.cat([topo_vecs, topo_cxt], dim=-1) ).squeeze(-1)

    def get_cls_score(self, src_tree_vecs, batch_idx, cls_vecs, cls_labs):
        if self.use_attention:
            cls_cxt = self.attention(src_tree_vecs, batch_idx, cls_vecs, self.A_cls)
        else:
            cls_cxt = src_tree_vecs.index_select(index=batch_idx, dim=0)

        cls_vecs = torch.cat([cls_vecs, cls_cxt], dim=-1)
        cls_scores = self.clsNN(cls_vecs)

        if cls_labs is None: #inference mode: no masking
            icls_scores = self.iclsNN(cls_vecs)
        else:
            vocab_masks = self.vocab.get_mask(cls_labs)
            icls_scores = self.iclsNN(cls_vecs) + vocab_masks #additive log-space mask: entries are 0 (valid) or -inf (invalid)
        return cls_scores, icls_scores
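    # Log-mask sketch: get_mask yields 0 for attachment configs compatible
    # with the motif class and -inf otherwise (see comment above), so
    #     softmax(icls_logits + mask) == softmax of the logits restricted to valid configs,
    # i.e. adding the mask in logit space multiplies probabilities by 0 or 1.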

    def get_assm_score(self, src_graph_vecs, batch_idx, assm_vecs):
        if self.use_attention:
            assm_cxt = self.attention(src_graph_vecs, batch_idx, assm_vecs, self.A_assm)
        else:
            assm_cxt = index_select_ND(src_graph_vecs, 0, batch_idx)
        return (self.W_assm(assm_vecs) * assm_cxt).sum(dim=-1) #dot product between each candidate and its molecule's graph vector

    def forward(self, src_mol_vecs, graphs, tensors, orders):
        """Teacher-forced training pass: follow each molecule's ground-truth
        decoding order, collect topology, motif class, and attachment
        predictions, and return the batch loss with per-task accuracies."""
        batch_size = len(orders)
        tree_batch, graph_batch = graphs
        tree_tensors, graph_tensors = tensors
        inter_tensors = tree_tensors

        src_root_vecs, src_tree_vecs, src_graph_vecs = src_mol_vecs
        init_vecs = src_root_vecs if self.latent_size == self.hidden_size else self.W_root(src_root_vecs)

        htree, tree_tensors = self.init_decoder_state(tree_batch, tree_tensors, init_vecs)
        hinter = HTuple(
            mess = self.rnn_cell.get_init_state(inter_tensors[1]),
            emask = self.itensor.new_zeros(inter_tensors[1].size(0))
        )
        hgraph = HTuple(
            mess = self.rnn_cell.get_init_state(graph_tensors[1]),
            vmask = self.itensor.new_zeros(graph_tensors[0].size(0)),
            emask = self.itensor.new_zeros(graph_tensors[1].size(0))
        )
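        # Three coupled decoder states mirror the encoder hierarchy: hgraph
        # (atoms and bonds), hinter (attachment layer, which shares the tree
        # tensors), and htree (motif tree). All masks start at zero and are
        # revealed incrementally as nodes are decoded.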
        
        all_topo_preds, all_cls_preds, all_assm_preds = [], [], []
        new_atoms = []
        tree_scope = tree_tensors[-1]
        for i in range(batch_size):
            root = tree_batch.nodes[ tree_scope[i][0] ]
            clab, ilab = self.vocab[ root['label'] ]
            all_cls_preds.append( (init_vecs[i], i, clab, ilab) ) #cluster prediction
            new_atoms.extend(root['cluster'])

        subgraph = self.update_graph_mask(graph_batch, new_atoms, hgraph)
        graph_tensors = self.hmpn.embed_graph(graph_tensors) + (graph_tensors[-1],) #preprocess graph tensors

        maxt = max(len(x) for x in orders)
        max_cls_size = max(len(attr) * 2 for _,attr in tree_batch.nodes(data='cluster'))

        for t in range(maxt):
            batch_list = [i for i in range(batch_size) if t < len(orders[i])]
            assert htree.emask[0].item() == 0 and hinter.emask[0].item() == 0 and hgraph.vmask[0].item() == 0 and hgraph.emask[0].item() == 0 #index 0 is padding and must stay masked

            subtree = [], []
            for i in batch_list:
                xid, yid, tlab = orders[i][t]
                subtree[0].append(xid)
                if yid is not None:
                    mess_idx = tree_batch[xid][yid]['mess_idx']
                    subtree[1].append(mess_idx)

            subtree = htree.emask.new_tensor(subtree[0]), htree.emask.new_tensor(subtree[1]) 
            htree.emask.scatter_(0, subtree[1], 1)
            hinter.emask.scatter_(0, subtree[1], 1)

            cur_tree_tensors = self.apply_tree_mask(tree_tensors, htree, hgraph)
            cur_inter_tensors = self.apply_tree_mask(inter_tensors, hinter, hgraph)
            cur_graph_tensors = self.apply_graph_mask(graph_tensors, hgraph)
            htree, hinter, hgraph = self.hmpn(cur_tree_tensors, cur_inter_tensors, cur_graph_tensors, htree, hinter, hgraph, subtree, subgraph)
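            # One teacher-forced step: re-encode only the revealed portion of
            # the hierarchy, so htree/hinter/hgraph now summarize the partial
            # molecule up to time t.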

            new_atoms = []
            for i in batch_list:
                xid, yid, tlab = orders[i][t]
                all_topo_preds.append( (htree.node[xid], i, tlab) ) #topology prediction
                if yid is not None:
                    mess_idx = tree_batch[xid][yid]['mess_idx']
                    new_atoms.extend( tree_batch.nodes[yid]['cluster'] ) #NOTE: regardless of tlab = 0 or 1

                if tlab == 0: continue

                cls = tree_batch.nodes[yid]['smiles']
                clab, ilab = self.vocab[ tree_batch.nodes[yid]['label'] ]
                mess_idx = tree_batch[xid][yid]['mess_idx']
                hmess = self.rnn_cell.get_hidden_state(htree.mess)
                all_cls_preds.append( (hmess[mess_idx], i, clab, ilab) ) #cluster prediction using message
                
                inter_label = tree_batch.nodes[yid]['inter_label']
                inter_label = [ (pos, self.vocab[(cls, icls)][1]) for pos,icls in inter_label ]
                inter_size = self.vocab.get_inter_size(ilab)

                if len(tree_batch.nodes[xid]['cluster']) > 2: #attachment is ambiguous only when the parent cluster is a ring
                    nth_child = tree_batch[yid][xid]['label'] #must be the yid -> xid edge (child-order labels in the graph differ from the tree)
                    cands = tree_batch.nodes[yid]['assm_cands']
                    icls = list(zip(*inter_label))[1]
                    cand_vecs = self.enum_attach(hgraph, cands, icls, nth_child)

                    if len(cand_vecs) < max_cls_size:
                        pad_len = max_cls_size - len(cand_vecs)
                        cand_vecs = F.pad(cand_vecs, (0,0,0,pad_len))

                    batch_idx = hgraph.emask.new_tensor( [i] * max_cls_size )
                    all_assm_preds.append( (cand_vecs, batch_idx, 0) ) #the label is always the first of assm_cands
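                    # Padding to max_cls_size lets zip_tensors stack the
                    # variable-length candidate lists into one batch; the
                    # target is index 0 because the ground-truth attachment
                    # is listed first in assm_cands.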

            subgraph = self.update_graph_mask(graph_batch, new_atoms, hgraph)

        topo_vecs, batch_idx, topo_labels = zip_tensors(all_topo_preds)
        topo_scores = self.get_topo_score(src_tree_vecs, batch_idx, topo_vecs)
        topo_loss = self.topo_loss(topo_scores, topo_labels.float())
        topo_acc = get_accuracy_bin(topo_scores, topo_labels)

        cls_vecs, batch_idx, cls_labs, icls_labs = zip_tensors(all_cls_preds)
        cls_scores, icls_scores = self.get_cls_score(src_tree_vecs, batch_idx, cls_vecs, cls_labs)
        cls_loss = self.cls_loss(cls_scores, cls_labs) + self.icls_loss(icls_scores, icls_labs)
        cls_acc = get_accuracy(cls_scores, cls_labs)
        icls_acc = get_accuracy(icls_scores, icls_labs)

        if len(all_assm_preds) > 0:
            assm_vecs, batch_idx, assm_labels = zip_tensors(all_assm_preds)
            assm_scores = self.get_assm_score(src_graph_vecs, batch_idx, assm_vecs)
            assm_loss = self.assm_loss(assm_scores, assm_labels)
            assm_acc = get_accuracy_sym(assm_scores, assm_labels)
        else:
            assm_loss, assm_acc = 0, 1
        
        loss = (topo_loss + cls_loss + assm_loss) / batch_size #losses are summed, so normalize per molecule
        return loss, cls_acc, icls_acc, topo_acc, assm_acc

    def enum_attach(self, hgraph, cands, icls, nth_child):
        """Build a vector for every candidate attachment from its atom vector,
        attachment-config embedding, and one-hot child order via matchNN."""
        cands = self.itensor.new_tensor(cands)
        icls_vecs = self.itensor.new_tensor(icls * len(cands))
        icls_vecs = self.E_assm( icls_vecs )

        nth_child = self.itensor.new_tensor([nth_child] * len(cands.view(-1)))
        order_vecs = self.E_order.index_select(0, nth_child)

        cand_vecs = hgraph.node.index_select(0, cands.view(-1))
        cand_vecs = torch.cat( [cand_vecs, icls_vecs, order_vecs], dim=-1 )
        cand_vecs = self.matchNN(cand_vecs)

        if len(icls) == 2: #two attachment atoms per candidate: sum the pair for order invariance
            cand_vecs = cand_vecs.view(-1, 2, self.hidden_size).sum(dim=1)
        return cand_vecs
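    # Shape sketch (illustrative): for a two-point attachment with
    # cands = [(a1,a2), (a3,a4)], icls = (i1,i2) and nth_child = 1, the method
    # builds 4 rows of [atom vector; config embedding; one-hot order], runs
    # matchNN, then sums each pair, returning one vector per candidate.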

    def decode(self, src_mol_vecs, greedy=True, max_decode_step=100, beam=5):
        """Autoregressive generation: grow each molecule's motif tree depth
        first, predicting expand/backtrack, motif class, and attachment at
        every step while assembling the molecular graph incrementally."""
        src_root_vecs, src_tree_vecs, src_graph_vecs = src_mol_vecs
        batch_size = len(src_root_vecs)

        tree_batch = IncTree(batch_size, node_fdim=2, edge_fdim=3)
        graph_batch = IncGraph(self.avocab, batch_size, node_fdim=self.hmpn.atom_size, edge_fdim=self.hmpn.atom_size + self.hmpn.bond_size)
        stack = [[] for _ in range(batch_size)]

        init_vecs = src_root_vecs if self.latent_size == self.hidden_size else self.W_root(src_root_vecs)
        batch_idx = self.itensor.new_tensor(range(batch_size))
        cls_scores, icls_scores = self.get_cls_score(src_tree_vecs, batch_idx, init_vecs, None)
        root_cls = cls_scores.max(dim=-1)[1]
        icls_scores = icls_scores + self.vocab.get_mask(root_cls)
        root_cls, root_icls = root_cls.tolist(), icls_scores.max(dim=-1)[1].tolist()

        super_root = tree_batch.add_node() 
        for bid in range(batch_size):
            clab, ilab = root_cls[bid], root_icls[bid]
            root_idx = tree_batch.add_node( batch_idx.new_tensor([clab, ilab]) )
            tree_batch.add_edge(super_root, root_idx) 
            stack[bid].append(root_idx)

            root_smiles = self.vocab.get_ismiles(ilab)
            new_atoms, new_bonds, attached = graph_batch.add_mol(bid, root_smiles, [], 0)
            tree_batch.register_cgraph(root_idx, new_atoms, new_bonds, attached)
        
        #invariant: the tree and inter levels share the same tensors here (the inter level's init vector is zero)
        tree_tensors = tree_batch.get_tensors()
        graph_tensors = graph_batch.get_tensors()

        htree = HTuple( mess = self.rnn_cell.get_init_state(tree_tensors[1]) )
        hinter = HTuple( mess = self.rnn_cell.get_init_state(tree_tensors[1]) )
        hgraph = HTuple( mess = self.rnn_cell.get_init_state(graph_tensors[1]) )
        h = self.rnn_cell.get_hidden_state(htree.mess)
        h[1 : batch_size + 1] = init_vecs #wiring root (only for tree, not inter)
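        # Message slot 0 is padding; slots 1..batch_size hold the super_root ->
        # root edges, so writing init_vecs there injects each latent code into
        # its tree, mirroring the virtual root messages used in training.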
        
        for t in range(max_decode_step):
            batch_list = [ bid for bid in range(batch_size) if len(stack[bid]) > 0 ]
            if len(batch_list) == 0: break

            batch_idx = batch_idx.new_tensor(batch_list)
            cur_tree_nodes = [stack[bid][-1] for bid in batch_list]
            subtree = batch_idx.new_tensor(cur_tree_nodes), batch_idx.new_tensor([])
            subgraph = batch_idx.new_tensor( tree_batch.get_cluster_nodes(cur_tree_nodes) ), batch_idx.new_tensor( tree_batch.get_cluster_edges(cur_tree_nodes) )

            htree, hinter, hgraph = self.hmpn(tree_tensors, tree_tensors, graph_tensors, htree, hinter, hgraph, subtree, subgraph)
            topo_scores = self.get_topo_score(src_tree_vecs, batch_idx, htree.node.index_select(0, subtree[0]))
            topo_scores = torch.sigmoid(topo_scores)
            if greedy:
                topo_preds = topo_scores.tolist() #thresholded at 0.5 below
            else:
                topo_preds = torch.bernoulli(topo_scores).tolist() #samples are exactly 0/1, so the same 0.5 test applies

            new_mess = []
            expand_list = []
            for i,bid in enumerate(batch_list):
                if topo_preds[i] > 0.5 and tree_batch.can_expand(stack[bid][-1]):
                    expand_list.append( (len(new_mess), bid) )
                    new_node = tree_batch.add_node() #new node label is yet to be predicted
                    edge_feature = batch_idx.new_tensor( [stack[bid][-1], new_node, 0] ) #parent to child is 0
                    new_edge = tree_batch.add_edge(stack[bid][-1], new_node, edge_feature) 
                    stack[bid].append(new_node)
                    new_mess.append(new_edge)
                else:
                    child = stack[bid].pop()
                    if len(stack[bid]) > 0:
                        nth_child = tree_batch.graph.in_degree(stack[bid][-1]) #the child -> father edge has not been added yet
                        edge_feature = batch_idx.new_tensor( [child, stack[bid][-1], nth_child] )
                        new_edge = tree_batch.add_edge(child, stack[bid][-1], edge_feature)
                        new_mess.append(new_edge)

            subtree = subtree[0], batch_idx.new_tensor(new_mess)
            subgraph = [], []
            htree, hinter, hgraph = self.hmpn(tree_tensors, tree_tensors, graph_tensors, htree, hinter, hgraph, subtree, subgraph)
            cur_mess = self.rnn_cell.get_hidden_state(htree.mess).index_select(0, subtree[1])

            if len(expand_list) > 0:
                idx_in_mess, expand_list = zip(*expand_list)
                idx_in_mess = batch_idx.new_tensor( idx_in_mess )
                expand_idx = batch_idx.new_tensor( expand_list )
                forward_mess = cur_mess.index_select(0, idx_in_mess)
                cls_scores, icls_scores = self.get_cls_score(src_tree_vecs, expand_idx, forward_mess, None)
                scores, cls_topk, icls_topk = hier_topk(cls_scores, icls_scores, self.vocab, beam)
                if not greedy:
                    scores = torch.exp(scores) #score is output of log_softmax
                    shuf_idx = torch.multinomial(scores, beam, replacement=True).tolist()

            for i,bid in enumerate(expand_list):
                new_node, fa_node = stack[bid][-1], stack[bid][-2]
                success = False
                cls_beam = range(beam) if greedy else shuf_idx[i]
                for kk in cls_beam: #try until one is chemically valid
                    if success: break
                    clab, ilab = cls_topk[i][kk], icls_topk[i][kk]
                    node_feature = batch_idx.new_tensor( [clab, ilab] )
                    tree_batch.set_node_feature(new_node, node_feature)
                    smiles, ismiles = self.vocab.get_smiles(clab), self.vocab.get_ismiles(ilab)
                    fa_cluster, _, fa_used = tree_batch.get_cluster(fa_node)
                    inter_cands, anchor_smiles, attach_points = graph_batch.get_assm_cands(fa_cluster, fa_used, ismiles)

                    if len(inter_cands) == 0:
                        continue
                    elif len(inter_cands) == 1:
                        sorted_cands = [(inter_cands[0], 0)]
                        nth_child = 0
                    else:
                        nth_child = tree_batch.graph.in_degree(fa_node)
                        icls = [self.vocab[ (smiles,x) ][1] for x in anchor_smiles]
                        cands = inter_cands if len(attach_points) <= 2 else [ (x[0],x[-1]) for x in inter_cands ]
                        cand_vecs = self.enum_attach(hgraph, cands, icls, nth_child)

                        assm_batch_idx = batch_idx.new_tensor( [bid] * len(inter_cands) ) #avoid clobbering the outer batch_idx
                        assm_scores = self.get_assm_score(src_graph_vecs, assm_batch_idx, cand_vecs).tolist()
                        sorted_cands = sorted( list(zip(inter_cands, assm_scores)), key = lambda x:x[1], reverse=True )

                    for inter_label,_ in sorted_cands:
                        inter_label = list(zip(inter_label, attach_points))
                        if graph_batch.try_add_mol(bid, ismiles, inter_label):
                            new_atoms, new_bonds, attached = graph_batch.add_mol(bid, ismiles, inter_label, nth_child)
                            tree_batch.register_cgraph(new_node, new_atoms, new_bonds, attached)
                            tree_batch.update_attached(fa_node, inter_label)
                            success = True
                            break

                if not success: #force backtrack
                    child = stack[bid].pop() #pop the dummy new_node which can't be added
                    nth_child = tree_batch.graph.in_degree(stack[bid][-1]) 
                    edge_feature = batch_idx.new_tensor( [child, stack[bid][-1], nth_child] )
                    new_edge = tree_batch.add_edge(child, stack[bid][-1], edge_feature)

                    child = stack[bid].pop() 
                    if len(stack[bid]) > 0:
                        nth_child = tree_batch.graph.in_degree(stack[bid][-1]) 
                        edge_feature = batch_idx.new_tensor( [child, stack[bid][-1], nth_child] )
                        new_edge = tree_batch.add_edge(child, stack[bid][-1], edge_feature)

        return graph_batch.get_mol()
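# Usage sketch (hedged; hyperparameters and variable names are illustrative,
# not prescribed by this module). Given encoder outputs for a batch:
#   decoder = HierMPNDecoder(vocab, avocab, 'LSTM', embed_size=250,
#                            hidden_size=250, latent_size=32,
#                            depthT=15, depthG=15, dropout=0.0)
#   loss, cls_acc, icls_acc, topo_acc, assm_acc = \
#       decoder((root_vecs, tree_vecs, graph_vecs), graphs, tensors, orders)
#   mols = decoder.decode((root_vecs, tree_vecs, graph_vecs), greedy=True)
# decode() runs a stack-based DFS per molecule: a topology score above 0.5
# pushes a new child motif, one below backtracks, and generation stops when
# every stack is empty or max_decode_step is reached.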