import os
import sys
import re
import collections
import numpy as np
import scipy
import json
import itertools
import pickle
import gc
import gzip
import argparse

def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.
    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    return re.findall(r'(?:\w+)|\S', sent)

def list_to_map(l):
    '''Convert a list of values to a map from values to indices'''
    return {val:i for i,val in enumerate(l)}

def parse_stories(lines):
    '''
    Parse stories provided in the bAbI tasks format, with a knowledge graph
    appended to each sentence line after an '=' separator (a JSON object with
    "nodes" and "edges"). Question lines contain a tab separating the question
    from its comma-separated answer list.
    '''
    data = []
    story = []
    for line in lines:
        if line[-1] == "\n":
            line = line[:-1]
        nid, line = line.split(' ', 1)
        nid = int(nid)
        if nid == 1:
            story = []
            questions = []
        if '\t' in line:
            q, apre = line.split('\t')[:2]
            a = apre.split(',')
            q = tokenize(q)
            substory = [x for x in story if x]
            data.append((substory, q, a))
            story.append('')
        else:
            line, graph = line.split('=', 1)
            sent = tokenize(line)
            graph_parsed = json.loads(graph)
            story.append((sent, graph_parsed))
    return data
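
# A sketch of the expected input, assuming the '=' graph-annotation convention
# handled above (sentence text and graph contents here are hypothetical):
#   1 Mary went to the kitchen.={"nodes":["mary","kitchen"],"edges":[{"from":"mary","to":"kitchen","type":"is_in"}]}
#   2 Where is Mary?<TAB>kitchen
# Sentence lines become (token_list, parsed_graph) pairs; a tab-separated
# question line emits one (substory, question_tokens, answer_list) example.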

def get_stories(taskname):
    with open(taskname, 'r') as f:
        lines = f.readlines()
    return parse_stories(lines)

def get_max_sentence_length(stories):
    return max((max((len(sentence) for (sentence, graph) in sents_graphs)) for (sents_graphs, query, answer) in stories))

def get_max_query_length(stories):
    return max((len(query) for (sents_graphs, query, answer) in stories))

def get_max_num_queries(stories):
    return max((len(query) for (sents_graphs, query, answer) in stories))

def get_max_nodes_per_iter(stories):
    result = 0
    for (sents_graphs, query, answer) in stories:
        prev_nodes = set()
        for (sentence, graph) in sents_graphs:
            cur_nodes = set(graph["nodes"])
            new_nodes = len(cur_nodes - prev_nodes)
            if new_nodes > result:
                result = new_nodes
            prev_nodes = cur_nodes
    return result

def get_buckets(stories, max_ignore_unbatched=100, max_pad_amount=25):
    sentencecounts = [len(sents_graphs) for (sents_graphs, query, answer) in stories]
    countpairs = sorted(collections.Counter(sentencecounts).items())

    buckets = []
    smallest_left_val = 0
    num_unbatched = max_ignore_unbatched
    for val,ct in countpairs:
        num_unbatched += ct
        if val - smallest_left_val > max_pad_amount or num_unbatched > max_ignore_unbatched:
            buckets.append(val)
            smallest_left_val = val
            num_unbatched = 0
    if buckets[-1] != countpairs[-1][0]:
        buckets.append(countpairs[-1][0])

    return buckets
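
# Sketch of the bucketing behaviour on hypothetical data: with story-length
# counts {2: 50, 3: 200, 10: 40} and the defaults above, the first value always
# opens a bucket (num_unbatched starts saturated), 3 opens another because its
# 200 stories exceed max_ignore_unbatched, and the largest length is always
# appended at the end, giving buckets == [2, 3, 10]. Each story is later padded
# up to the sentence count of its bucket.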

PAD_WORD = "<PAD>"

def get_wordlist(stories):
    words = [PAD_WORD] + sorted(list(set((word
        for (sents_graphs, query, answer) in stories
        for wordbag in itertools.chain((s for s,g in sents_graphs), [query])
        for word in wordbag ))))
    wordmap = list_to_map(words)
    return words, wordmap

def get_answer_list(stories):
    words = sorted(list(set(word for (sents_graphs, query, answer) in stories for word in answer)))
    wordmap = list_to_map(words)
    return words, wordmap

def pad_story(story, num_sentences, sentence_length):
    def pad(lst, dlen, padval):
        return lst + [padval]*(dlen - len(lst))
    
    sents_graphs, query, answer = story
    padded_sents_graphs = [(pad(s,sentence_length,PAD_WORD), g) for s,g in sents_graphs]
    padded_query = pad(query,sentence_length,PAD_WORD)

    sentgraph_padding = (pad([],sentence_length,PAD_WORD), padded_sents_graphs[-1][1])
    return (pad(padded_sents_graphs, num_sentences, sentgraph_padding), padded_query, answer)
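
# pad_story pads each sentence and the query to sentence_length tokens of
# <PAD>, then pads the sentence list itself to num_sentences entries; the
# padding entries reuse the final sentence's graph so the graph sequence
# stays aligned with the sentence sequence.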

def get_unqualified_id(s):
    return s.split("#")[0]
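# For example (hypothetical ids), "mary#1" and "mary#2" both reduce to the
# unqualified node type "mary".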

def get_graph_lists(stories):
    node_words = sorted(list(set(get_unqualified_id(node)
        for (sents_graphs, query, answer) in stories
        for sent,graph in sents_graphs
        for node in graph["nodes"])))
    nodemap = list_to_map(node_words)
    edge_words = sorted(list(set(get_unqualified_id(edge["type"])
        for (sents_graphs, query, answer) in stories
        for sent,graph in sents_graphs
        for edge in graph["edges"])))
    edgemap = list_to_map(edge_words)
    return node_words, nodemap, edge_words, edgemap

def convert_graph(graphs, nodemap, edgemap, new_nodes_per_iter, dynamic=True):
    num_node_ids = len(nodemap)
    num_edge_types = len(edgemap)

    full_size = len(graphs)*new_nodes_per_iter + 1

    prev_size = 1
    processed_nodes = []
    index_map = {}
    all_num_nodes = []
    all_node_ids = []
    all_node_strengths = []
    all_edges = []
    if not dynamic:
        processed_nodes = list(nodemap.keys())
        index_map = nodemap.copy()
        prev_size = num_node_ids
        full_size = prev_size
        new_nodes_per_iter = 0
    for g in graphs:
        active_nodes = g["nodes"]
        active_edges = g["edges"]

        new_nodes = [e for e in active_nodes if e not in processed_nodes]

        num_new_nodes = len(new_nodes)
        if not dynamic:
            assert num_new_nodes == 0, "Cannot create more nodes in non-dynamic mode!\n{}".format(graphs)
        
        new_node_strengths = np.zeros([new_nodes_per_iter], np.float32)
        new_node_strengths[:num_new_nodes] = 1.0

        new_node_ids = np.zeros([new_nodes_per_iter, num_node_ids], np.float32)
        for i, node in enumerate(new_nodes):
            new_node_ids[i,nodemap[get_unqualified_id(node)]] = 1.0
            index_map[node] = prev_size + i

        next_edges = np.zeros([full_size, full_size, num_edge_types], np.float32)
        for edge in active_edges:
            next_edges[index_map[edge["from"]],
                       index_map[edge["to"]],
                       edgemap[get_unqualified_id(edge["type"])]] = 1.0

        processed_nodes.extend(new_nodes)
        prev_size += new_nodes_per_iter

        all_num_nodes.append(num_new_nodes)
        all_node_ids.append(new_node_ids)
        all_edges.append(next_edges)
        all_node_strengths.append(new_node_strengths)

    return np.stack(all_num_nodes), np.stack(all_node_strengths), np.stack(all_node_ids), np.stack(all_edges)
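
# Shapes of the stacked outputs, with T = len(graphs) and full_size =
# T*new_nodes_per_iter + 1 (or num_node_ids when dynamic=False):
#   num_nodes:      (T,)                                      new nodes created at each step
#   node_strengths: (T, new_nodes_per_iter)                   1.0 for each node actually created
#   node_ids:       (T, new_nodes_per_iter, num_node_ids)     one-hot node-type assignment
#   edges:          (T, full_size, full_size, num_edge_types) adjacency indexed by edge type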

def convert_story(story, wordmap, answer_map, graph_node_map, graph_edge_map, new_nodes_per_iter, dynamic=True):
    """
    Converts a story in format
        ([(sentence, graph)], [(index, question_arr, answer)])
    to a consolidated story in format
        (sentence_arr, [graph_arr_dict], [(index, question_arr, answer)])
    and also replaces words according to the input maps
    """
    sents_graphs, query, answer = story

    sentence_arr = [[wordmap[w] for w in s] for s,g in sents_graphs]
    graphs = convert_graph([g for s,g in sents_graphs], graph_node_map, graph_edge_map, new_nodes_per_iter, dynamic)
    query_arr = [wordmap[w] for w in query]
    answer_arr = [answer_map[w] for w in answer]
    return (sentence_arr, graphs, query_arr, answer_arr)

def bucket_stories(stories, buckets, wordmap, answer_map, graph_node_map, graph_edge_map, sentence_length, new_nodes_per_iter, dynamic=True):
    def process_story(s, bucket_len):
        return convert_story(pad_story(s, bucket_len, sentence_length), wordmap, answer_map, graph_node_map, graph_edge_map, new_nodes_per_iter, dynamic)
    return [ [process_story(story,bmax) for story in stories if bstart < len(story[0]) <= bmax]
                for bstart, bmax in zip([0]+buckets,buckets)]

def prepare_stories(stories, dynamic=True):
    sentence_length = max(get_max_sentence_length(stories), get_max_query_length(stories))
    buckets = get_buckets(stories)
    wordlist, wordmap = get_wordlist(stories)
    anslist, ansmap = get_answer_list(stories)
    new_nodes_per_iter = get_max_nodes_per_iter(stories)

    graph_node_list, graph_node_map, graph_edge_list, graph_edge_map = get_graph_lists(stories)
    bucketed = bucket_stories(stories, buckets, wordmap, ansmap, graph_node_map, graph_edge_map, sentence_length, new_nodes_per_iter, dynamic)
    return sentence_length, new_nodes_per_iter, buckets, wordlist, anslist, graph_node_list, graph_edge_list, bucketed

def print_batch(story, wordlist, anslist, file=sys.stdout):
    sents, query, answer = story
    for batch,(s,q,a) in enumerate(zip(sents,query,answer)):
        file.write("Story {}\n".format(batch))
        for sent in s:
            file.write(" ".join([wordlist[word] for word in sent]) + "\n")
        file.write(" ".join(wordlist[word] for word in q) + "\n")
        file.write(" ".join(anslist[word] for word in a.nonzero()[1]) + "\n")

MetadataList = collections.namedtuple("MetadataList", ["sentence_length", "new_nodes_per_iter", "buckets", "wordlist", "anslist", "graph_node_list", "graph_edge_list"])
PreppedStory = collections.namedtuple("PreppedStory", ["converted", "sentences", "query", "answer"])
def generate_metadata(stories, dynamic=True):
    sentence_length = max(get_max_sentence_length(stories), get_max_query_length(stories))
    buckets = get_buckets(stories)
    wordlist, wordmap = get_wordlist(stories)
    anslist, ansmap = get_answer_list(stories)
    new_nodes_per_iter = get_max_nodes_per_iter(stories)
    graph_node_list, graph_node_map, graph_edge_list, graph_edge_map = get_graph_lists(stories)
    metadata = MetadataList(sentence_length, new_nodes_per_iter, buckets, wordlist, anslist, graph_node_list, graph_edge_list)
    return metadata

def preprocess_stories(stories, savedir, dynamic=True, metadata_file=None):
    if metadata_file is None:
        metadata = generate_metadata(stories, dynamic)
    else:
        with open(metadata_file,'rb') as f:
            metadata = pickle.load(f)

    buckets = get_buckets(stories)
    sentence_length, new_nodes_per_iter, old_buckets, wordlist, anslist, graph_node_list, graph_edge_list = metadata
    metadata = metadata._replace(buckets=buckets)

    if not os.path.exists(savedir):
        os.makedirs(savedir)

    with open(os.path.join(savedir,'metadata.p'),'wb') as f:
        pickle.dump(metadata, f)

    bucketed_files = [[] for _ in buckets] 

    for i,story in enumerate(stories):
        bucket_idx, cur_bucket = next(((idx,bmax) for (idx,(bstart, bmax)) in enumerate(zip([0]+buckets,buckets))
                                        if bstart < len(story[0]) <= bmax), (None,None))
        assert cur_bucket is not None, "Couldn't put story of length {} into buckets {}".format(len(story[0]), buckets)
        bucket_dir = os.path.join(savedir, "bucket_{}".format(cur_bucket))
        if not os.path.exists(bucket_dir):
            os.makedirs(bucket_dir)
        story_fn = os.path.join(bucket_dir, "story_{}.pz".format(i))

        sents_graphs, query, answer = story
        sents = [s for s,g in sents_graphs]
        cvtd = convert_story(pad_story(story, cur_bucket, sentence_length), list_to_map(wordlist), list_to_map(anslist), list_to_map(graph_node_list), list_to_map(graph_edge_list), new_nodes_per_iter, dynamic)

        prepped = PreppedStory(cvtd, sents, query, answer)

        with gzip.open(story_fn, 'wb') as zf:
            pickle.dump(prepped, zf)

        bucketed_files[bucket_idx].append(os.path.relpath(story_fn, savedir))
        gc.collect() # we don't want to use too much memory, so try to clean it up

    with open(os.path.join(savedir,'file_list.p'),'wb') as f:
        pickle.dump(bucketed_files, f)
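
# Resulting directory layout under savedir:
#   metadata.p               pickled MetadataList (buckets recomputed for this set)
#   bucket_<N>/story_<i>.pz  gzipped pickle of a single PreppedStory
#   file_list.p              per-bucket lists of story paths relative to savedir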

def main(file, dynamic, metadata_file=None):
    stories = get_stories(file)
    dirname, ext = os.path.splitext(file)
    preprocess_stories(stories, dirname, dynamic, metadata_file)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parse a graph file')
    parser.add_argument("file", help="Graph file to parse")
    parser.add_argument("--static", dest="dynamic", action="store_false", help="Don't use dynamic nodes")
    parser.add_argument("--metadata-file", default=None, help="Use this particular metadata file instead of building it from scratch")
    args = vars(parser.parse_args())
    main(**args)
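
# Example invocation (hypothetical script and file names; output is written to a
# directory named after the input file, minus its extension):
#   python graph_parse.py tasks/task_1_train.txt --metadata-file shared/metadata.p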