import networkx as nx
import numpy as np
import openie
from fuzzywuzzy import fuzz
from jericho.util import clean


class StateAction(object):

    def __init__(self, spm, vocab, vocab_rev, tsv_file, max_word_len):
        """Set up an empty knowledge graph and the encoders used to represent it.

        Args:
            spm: Tokenizer object; its ``encode_as_ids`` method is used to
                encode observation text.
            vocab: Action vocabulary, stored as ``vocab_act``.
            vocab_rev: Mapping from (truncated) action token to integer id,
                stored as ``vocab_act_rev``.
            tsv_file: Path of the tab-separated vocabulary file handed to
                ``load_vocab_kge``.
            max_word_len: Number of leading characters kept from each action
                token before it is looked up in ``vocab_act_rev``.
        """
        # Collaborators and configuration supplied by the caller.
        self.sp = spm
        self.vocab_act = vocab
        self.vocab_act_rev = vocab_rev
        self.max_word_len = max_word_len

        # Per-step state; everything starts out empty.
        self.room = ""
        self.visible_state = ""
        self.drqa_input = ""
        self.vis_pruned_actions = []
        self.pruned_actions_rep = []

        # The knowledge graph itself and its encoded representation.
        self.graph_state = nx.DiGraph()
        self.graph_state_rep = []

        # Vocabulary of the graph, and a square adjacency matrix with one
        # row/column per entity in that vocabulary.
        self.vocab_kge = self.load_vocab_kge(tsv_file)
        num_entities = len(self.vocab_kge['entity'])
        self.adj_matrix = np.zeros((num_entities, num_entities))

    def visualize(self):
        """Draw the current knowledge graph with its relation labels.

        The edge-label mapping is also printed to stdout. This method only
        issues the networkx drawing calls; it does not display the figure
        itself (for example via matplotlib's ``plt.show()``).
        """
        graph = self.graph_state
        layout = nx.spring_layout(graph)

        # Label every edge with the relation stored in its 'rel' attribute.
        labels = {}
        for edge in graph.edges:
            labels[edge] = graph.edges[edge]['rel']
        print(labels)

        nx.draw_networkx_edge_labels(graph, layout, labels)
        nx.draw(graph, pos=layout, with_labels=True, node_size=200, font_size=10)

    def load_vocab_kge(self, tsv_file):
        """Load the knowledge-graph vocabulary from a tab-separated file.

        Every non-blank line must consist of exactly two tab-separated
        fields: a name and its integer id. Surrounding whitespace of both
        fields is stripped. Blank lines (such as a trailing empty line) are
        ignored.

        Args:
            tsv_file: Path of the vocabulary file.

        Returns:
            A dict with the keys ``'entity'`` and ``'relation'``, each mapping
            name -> integer id. The two mappings are separate dict objects.

        Raises:
            ValueError: If a non-blank line does not have exactly two
                tab-separated fields, or its id is not an integer.
        """
        ent = {}
        with open(tsv_file, 'r') as f:
            for line in f:
                # A blank line cannot be unpacked into (name, id); skip it
                # instead of failing on an otherwise valid file.
                if not line.strip():
                    continue
                name, idx = line.split('\t')
                ent[name.strip()] = int(idx.strip())

        # NOTE(review): the relation vocabulary is taken from the same file
        # as the entity vocabulary, so both mappings have identical content.
        # Confirm whether a dedicated relation file was intended.
        rel = dict(ent)

        return {'entity': ent, 'relation': rel}

    def update_state(self, visible_state, inventory_state, objs, prev_action=None, cache=None):
        """Update the knowledge graph from the latest observation.

        Triples (head, relation, tail) are collected from the OpenIE
        extraction of the room description, from the inventory text, from the
        previous action and from the interactive objects. Every triple whose
        head and tail (with whitespace replaced by ``_``) are both in the
        entity vocabulary is added to ``self.graph_state`` as an edge carrying
        the relation in its ``rel`` attribute.

        Args:
            visible_state: Observation text. The first line is used as the
                room name, the remaining lines as the room description.
            inventory_state: Inventory text. Items are read from the lines
                following the first ``:``; the first word of every item line
                is dropped. Text that cannot be parsed this way is ignored.
            objs: Iterable of objects; each one yields an
                ``(object, 'in', room)`` triple.
            prev_action: Previous action text, or None. If it contains a
                direction word, a ``(previous room, '<dir> of', room)``
                triple may be added.
            cache: Previously extracted OpenIE sentences to reuse. When None,
                ``openie.call_stanford_openie`` is called on the description.

        Returns:
            A tuple ``(rules, sents)``. ``rules`` is the list of all collected
            triples, including those that were not added to the graph because
            of the vocabulary check. ``sents`` is the OpenIE sentence data
            that was used. If ``sents`` equals ``""``, ``rules`` is empty.
        """
        # NOTE(review): self.room is read here but never assigned in this
        # method, so it keeps whatever value it was given elsewhere.
        prev_room = self.room
        prev_room_words = set(prev_room.split())

        graph_copy = self.graph_state.copy()
        con_cs = [graph_copy.subgraph(c) for c in nx.weakly_connected_components(graph_copy)]

        # Remember the (last) weakly connected component that has a node
        # whose words include every word of the previous room name.
        prev_room_subgraph = None
        for con_c in con_cs:
            for node in con_c.nodes:
                if prev_room_words.issubset(set(str(node).split())):
                    prev_room_subgraph = nx.induced_subgraph(graph_copy, con_c.nodes)

        # Drop every edge whose source node name contains 'you'. Iterating
        # the original graph while editing the copy keeps the iteration safe.
        for edge in self.graph_state.edges:
            if 'you' in edge[0]:
                graph_copy.remove_edge(*edge)

        self.graph_state = graph_copy

        visible_state = visible_state.split('\n')
        room = visible_state[0]
        visible_state = clean(' '.join(visible_state[1:]))

        dirs = ['north', 'south', 'east', 'west', 'southeast', 'southwest', 'northeast', 'northwest', 'up', 'down']

        self.visible_state = str(visible_state)
        rules = []

        if cache is None:
            sents = openie.call_stanford_openie(self.visible_state)['sentences']
        else:
            sents = cache

        if sents == "":
            # Keep the same (rules, sents) shape as the normal return so that
            # callers which unpack the result (such as `step`) keep working.
            return [], sents

        in_aliases = ['are in', 'are facing', 'are standing', 'are behind', 'are above', 'are below', 'are in front']

        in_rl = []
        in_flag = False

        for i, ov in enumerate(sents):
            sent = ' '.join([a['word'] for a in ov['tokens']])
            triple = ov['openie']

            # Direction words mentioned in any sentence but the first one are
            # recorded as exits of the room (plain substring match).
            if i != 0:
                for d in dirs:
                    if d in sent:
                        rules.append((room, 'has', 'exit to ' + d))

            for tr in triple:
                h, r, t = tr['subject'].lower(), tr['relation'].lower(), tr['object'].lower()

                if h == 'you':
                    # Normalise "you are in/facing/..." style relations to
                    # a single "in" relation and collect them separately.
                    for rp in in_aliases:
                        if fuzz.token_set_ratio(r, rp) > 80:
                            r = "in"
                            in_rl.append((h, r, t))
                            in_flag = True
                            break

                if h == 'it':
                    # Stop processing this sentence's triples at the first
                    # one whose subject is the pronoun 'it'.
                    break
                if not in_flag:
                    # Once any location triple has been seen, no further
                    # OpenIE triples are recorded.
                    rules.append((h, r, t))

        if in_flag:
            # Walk the location triples, moving on to a candidate whenever
            # its tail contains all words of the currently selected tail; the
            # selected tail becomes the room name.
            cur_t = in_rl[0]
            for h, r, t in in_rl:
                if set(cur_t[2].split()).issubset(set(t.split())):
                    cur_t = h, r, t
            rules.append(cur_t)
            room = cur_t[2]

        # Inventory parsing is best effort: text without a ':' section, or an
        # inventory that is not a string, contributes no triples.
        try:
            items = inventory_state.split(':')[1].split('\n')[1:]
            for item in items:
                rules.append(('you', 'have', str(' '.join(item.split()[1:]))))
        except (AttributeError, IndexError, TypeError):
            pass

        if prev_action is not None:
            for d in dirs:
                if d in prev_action and self.room != "":
                    rules.append((prev_room, d + ' of', room))
                    if prev_room_subgraph is not None:
                        # Edge attributes live in the edge view; indexing the
                        # graph itself by an edge tuple looks up a node.
                        for ed in prev_room_subgraph.edges:
                            rules.append((ed[0], prev_room_subgraph.edges[ed]['rel'], ed[1]))
                    break

        for o in objs:
            rules.append((str(o), 'in', room))

        # Only triples whose endpoints are both known entities (and are not
        # the pronoun 'it') become edges of the graph.
        entities = self.vocab_kge['entity']
        for rule in rules:
            u = '_'.join(str(rule[0]).split())
            v = '_'.join(str(rule[2]).split())
            if u in entities and v in entities and u != 'it' and v != 'it':
                self.graph_state.add_edge(rule[0], rule[2], rel=rule[1])

        return rules, sents

    def get_state_rep_kge(self):
        """Return the ids of all entities in the graph and rebuild the adjacency matrix.

        ``self.adj_matrix`` is reset to a zero matrix with one row/column per
        vocabulary entity; for every graph edge ``u -> v`` whose endpoints
        (with whitespace replaced by ``_``) are both in the entity vocabulary,
        ``adj_matrix[id(u)][id(v)]`` is set to 1.

        Returns:
            A list of the distinct entity ids that occur in such edges, in
            no particular order.
        """
        entities = self.vocab_kge['entity']
        self.adj_matrix = np.zeros((len(entities), len(entities)))
        seen_ids = set()

        for u, v in self.graph_state.edges:
            u = '_'.join(str(u).split())
            v = '_'.join(str(v).split())

            # Skip an edge with an unknown endpoint instead of stopping, so
            # that the remaining valid edges are still represented.
            if u not in entities or v not in entities:
                continue

            u_idx = entities[u]
            v_idx = entities[v]
            self.adj_matrix[u_idx][v_idx] = 1

            seen_ids.add(u_idx)
            seen_ids.add(v_idx)

        return list(seen_ids)

    def get_state_kge(self):
        """Return the names of all entities in the graph and rebuild the adjacency matrix.

        ``self.adj_matrix`` is reset to a zero matrix with one row/column per
        vocabulary entity; for every graph edge ``u -> v`` whose endpoints
        (with whitespace replaced by ``_``) are both in the entity vocabulary,
        ``adj_matrix[id(u)][id(v)]`` is set to 1.

        Returns:
            A list of the distinct entity names (in their underscore-joined
            vocabulary form) that occur in such edges, in no particular order.
        """
        entities = self.vocab_kge['entity']
        self.adj_matrix = np.zeros((len(entities), len(entities)))
        seen_names = set()

        for u, v in self.graph_state.edges:
            u = '_'.join(str(u).split())
            v = '_'.join(str(v).split())

            # Skip an edge with an unknown endpoint instead of stopping, so
            # that the remaining valid edges are still represented.
            if u not in entities or v not in entities:
                continue

            self.adj_matrix[entities[u]][entities[v]] = 1

            seen_names.add(u)
            seen_names.add(v)

        return list(seen_names)

    def get_obs_rep(self, *args):
        """Encode each observation text and pad the results to a common length.

        Args:
            *args: Observation strings; each one is encoded with
                ``get_visible_state_rep_drqa``.

        Returns:
            The result of ``pad_sequences`` on the encoded observations with
            ``maxlen=300``.
        """
        encoded = []
        for observation in args:
            encoded.append(self.get_visible_state_rep_drqa(observation))
        return pad_sequences(encoded, maxlen=300)

    def get_visible_state_rep_drqa(self, state_description):
        """Strip unwanted fragments from a description and encode it as ids.

        The fragments are removed one after another as plain substrings (also
        inside longer words), in the order listed, before the text is passed
        to the tokenizer's ``encode_as_ids``.

        Args:
            state_description: Text to encode.

        Returns:
            Whatever ``self.sp.encode_as_ids`` returns for the cleaned text.
        """
        cleaned = state_description
        for fragment in ('=', '-', '\'', ':', '[', ']', 'eos', 'EOS', 'SOS', 'UNK', 'unk', 'sos', '<', '>'):
            cleaned = cleaned.replace(fragment, '')
        return self.sp.encode_as_ids(cleaned)

    def get_action_rep_drqa(self, action):
        """Encode an action as a fixed-length list of 20 token ids.

        The action is converted to a string and split on whitespace. Each of
        the first 20 tokens is cut to ``self.max_word_len`` characters and
        looked up in ``self.vocab_act_rev``; unknown tokens map to 0, and the
        list is filled up with 0 to a length of 20.

        Args:
            action: The action to encode.

        Returns:
            A list of exactly 20 ids.
        """
        width = 20
        lookup = self.vocab_act_rev
        tokens = str(action).split()[:width]
        ids = [lookup.get(tok[:self.max_word_len], 0) for tok in tokens]
        return ids + [0] * (width - len(ids))

    def step(self, visible_state, inventory_state, objs, prev_action=None, cache=None, gat=True):
        """Process one observation and refresh all derived representations.

        The knowledge graph is updated via ``update_state``; afterwards
        ``pruned_actions_rep``, ``drqa_input`` and ``graph_state_rep`` are
        recomputed from the new state.

        Args:
            visible_state: Observation text, forwarded to ``update_state``.
            inventory_state: Inventory text, forwarded to ``update_state``.
            objs: Objects, forwarded to ``update_state``.
            prev_action: Previous action, forwarded to ``update_state``.
            cache: Cached OpenIE sentences, forwarded to ``update_state``.
            gat: Not used by this method.

        Returns:
            The two values returned by ``update_state``.
        """
        rules, openie_sents = self.update_state(
            visible_state, inventory_state, objs, prev_action, cache)

        self.pruned_actions_rep = list(map(self.get_action_rep_drqa, self.vis_pruned_actions))

        # Only the cleaned room description is encoded as the text input.
        self.drqa_input = self.get_visible_state_rep_drqa(self.visible_state)

        # get_state_rep_kge() rebuilds self.adj_matrix, so it has to run
        # before the matrix is read.
        entity_ids = self.get_state_rep_kge()
        self.graph_state_rep = (entity_ids, self.adj_matrix)

        return rules, openie_sents


def pad_sequences(sequences, maxlen=None, dtype='int32', value=0.):
    '''
    Bring a list of sequences to a common length (partially borrowed from Keras).

    Sequences longer than `maxlen` keep only their last `maxlen` elements;
    shorter ones are filled up with `value` at the end.

    # Arguments
        sequences: list of lists where each element is a sequence
        maxlen: int, target length; defaults to the length of the longest sequence
        dtype: type of the resulting array
        value: float, fill value used for padding
    # Returns
        x: numpy array with dimensions (number_of_sequences, maxlen),
           followed by the per-element shape of the samples, if any
    # Raises
        ValueError: if a sequence's per-element shape differs from that of
            the first non-empty sequence
    '''
    if maxlen is None:
        maxlen = np.max([len(seq) for seq in sequences])

    # The per-element shape is dictated by the first non-empty sequence;
    # every other sequence is checked against it below.
    sample_shape = next(
        (np.asarray(seq).shape[1:] for seq in sequences if len(seq) > 0),
        tuple())

    out_shape = (len(sequences), maxlen) + sample_shape
    x = (np.ones(out_shape) * value).astype(dtype)

    for idx, seq in enumerate(sequences):
        if len(seq) == 0:
            # Nothing to copy; the row stays filled with `value`.
            continue
        # Keep the tail of the sequence when it is too long.
        trunc = np.asarray(seq[-maxlen:], dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))
        # Write the data at the front; the rest of the row is the padding.
        x[idx, :len(trunc)] = trunc
    return x