try: import cPickle as pickle
except: import pickle
from time import time
from argparse import ArgumentParser
import importlib
import json
import cPickle
import networkx as nx
import itertools
import pdb
import sys
sys.path.insert(0, './')

from gem.utils      import graph_util, plot_util
from gem.evaluation import visualize_embedding as viz
from gem.evaluation.evaluate_graph_reconstruction import expGR
from gem.evaluation.evaluate_link_prediction import expLP
from gem.evaluation.evaluate_node_classification import expNC
from gem.evaluation.visualize_embedding import expVis

# Short method name (as given via -meth / params.conf and used as the module
# name under gem.embedding) -> name of the embedding class inside that module.
# call_exps resolves classes via getattr(import_module(...), methClassMap[m]).
methClassMap = {
    "gf": "GraphFactorization",
    "hope": "HOPE",
    "lap": "LaplacianEigenmaps",
    "lle": "LocallyLinearEmbedding",
    "node2vec": "node2vec",
    "sdne": "SDNE",
}


def learn_emb(MethObj, di_graph, params, res_pre, m_summ):
    """Learn an embedding (or load a previously saved one) and return it.

    When params["load_emb"] is falsy ("0"), the embedding is learned with
    ``MethObj.learn_embedding`` and both the embedding and the learning
    time are pickled to ``<res_pre>_<m_summ>.emb`` / ``.learnT``.
    Otherwise the embedding is loaded from ``<res_pre>_<m_summ>.emb``; a
    missing ``.learnT`` file only produces a message.

    Args:
        MethObj: embedding method instance providing ``learn_embedding``.
        di_graph: graph passed to ``learn_embedding``.
        params: experiment configuration; reads "experiments" and
            "load_emb" (an int-like value such as "0"/"1").
        res_pre: path prefix for the saved files.
        m_summ: method summary string, used in file names and messages.

    Returns:
        The embedding, or None when "lp" is the only experiment.

    Raises:
        IOError: if loading is requested and the .emb file cannot be read.
    """
    # run_exps calls expLP without X, so when link prediction is the only
    # experiment there is no need to learn or load an embedding here.
    if params["experiments"] == ["lp"]:
        return None
    print('Learning Embedding: %s' % m_summ)
    emb_path = '%s_%s.emb' % (res_pre, m_summ)
    time_path = '%s_%s.learnT' % (res_pre, m_summ)
    if not bool(int(params["load_emb"])):
        X, learn_t = MethObj.learn_embedding(graph=di_graph,
                                             edge_f=None,
                                             no_python=True)
        print('\tTime to learn embedding: %f sec' % learn_t)
        # `with` guarantees the pickles are flushed and the handles closed.
        with open(emb_path, 'wb') as emb_f:
            pickle.dump(X, emb_f)
        with open(time_path, 'wb') as time_f:
            pickle.dump(learn_t, time_f)
    else:
        with open(emb_path, 'rb') as emb_f:
            X = pickle.load(emb_f)
        try:
            with open(time_path, 'rb') as time_f:
                learn_t = pickle.load(time_f)
        except IOError:
            # Timing info is optional; the saved embedding is still usable.
            print('\tTime info not found')
        else:
            print('\tTime to learn emb.: %f sec' % learn_t)
    return X


def run_exps(MethObj, di_graph, data_set, node_labels, params):
    """Obtain an embedding for one method/dataset and run the requested experiments.

    Args:
        MethObj: embedding method instance; must provide
            ``get_method_summary()`` and ``get_method_name()``.
        di_graph: graph of ``data_set``.
        data_set: dataset name; result files are prefixed with
            ``gem/results/<data_set>``.
        node_labels: node labels used by "nc" and "viz", or None.
        params: experiment configuration; reads "experiments",
            "n_sample_nodes", "rounds", "is_undirected", "load_emb" (via
            learn_emb) and, for "nc", "nc_test_ratio_arr".
    """
    m_summ = MethObj.get_method_summary()
    res_pre = "gem/results/%s" % data_set
    X = learn_emb(MethObj, di_graph, params, res_pre, m_summ)
    experiments = params["experiments"]
    if "gr" in experiments:
        expGR(di_graph, MethObj,
              X, params["n_sample_nodes"],
              params["rounds"], res_pre,
              m_summ, is_undirected=params["is_undirected"])
    if "lp" in experiments:
        # expLP is not passed X; it only receives the graph and MethObj.
        expLP(di_graph, MethObj,
              params["n_sample_nodes"],
              params["rounds"], res_pre,
              m_summ, is_undirected=params["is_undirected"])
    if "nc" in experiments:
        if "nc_test_ratio_arr" not in params:
            print('NC test ratio not provided')
        else:
            expNC(X, node_labels, params["nc_test_ratio_arr"],
                  params["rounds"], res_pre,
                  m_summ)
    if "viz" in experiments:
        X_viz = X
        if MethObj.get_method_name() == 'hope_gsvd':
            # Only the first half of the columns is visualized.
            # NOTE(review): presumably hope_gsvd stores two concatenated
            # factor matrices; confirm. `//` keeps the slice bound an int
            # (`/` yields a float under Python 3 and breaks the slice).
            X_viz = X[:, :X.shape[1] // 2]
        expVis(X_viz, res_pre, m_summ,
               node_labels=node_labels, di_graph=di_graph)


def _flag_enabled(value):
    """Interpret a config flag such as "0"/"1", 0/1, true/false as a bool."""
    # '%s' formatting (rather than str()) also accepts Python 2 unicode values,
    # which is what json.load returns for strings there.
    return ('%s' % value).strip().lower() not in ('', '0', 'false', 'no', 'none')


def call_exps(params, data_set):
    """Run every (dimension, method) combination in ``params`` on one dataset.

    Reads per-method hyperparameters from
    ``gem/experiments/config/<data_set>.conf`` (JSON, keyed by short method
    name), optionally node labels from
    ``gem/data/<data_set>/node_labels.pickle``, and the graph from
    ``gem/data/<data_set>/graph.gpickle``. Each method class is resolved via
    ``methClassMap`` and instantiated with ``{"d": dim}`` updated by its
    config hyperparameters.

    Args:
        params: experiment configuration; reads "node_labels", "dimensions"
            and "methods" here, and passes the rest to run_exps.
        data_set: dataset name.
    """
    print('Dataset: %s' % data_set)
    with open('gem/experiments/config/%s.conf' % data_set, 'r') as conf_f:
        model_hyp = json.load(conf_f)
    # Flags in params are int-like strings elsewhere in this script
    # (load_emb, is_undirected go through bool(int(...))). A plain
    # bool("0") would be True and load labels even when disabled.
    if _flag_enabled(params["node_labels"]):
        with open('gem/data/%s/node_labels.pickle' % data_set, 'rb') as lbl_f:
            node_labels = pickle.load(lbl_f)
    else:
        node_labels = None
    di_graph = nx.read_gpickle('gem/data/%s/graph.gpickle' % data_set)
    for d, meth in itertools.product(params["dimensions"], params["methods"]):
        dim = int(d)
        MethClass = getattr(
            importlib.import_module("gem.embedding.%s" % meth),
            methClassMap[meth]
        )
        hyp = {"d": dim}
        # Config hyperparameters are applied last, so they override "d" if present.
        hyp.update(model_hyp[meth])
        MethObj = MethClass(hyp)
        run_exps(MethObj, di_graph, data_set, node_labels, params)


if __name__ == '__main__':
    ''' Sample usage
    python experiments/exp.py -data sbm -dim 128 -meth sdne -exp gr,lp
    '''
    t1 = time()
    parser = ArgumentParser(description='Graph Embedding Experiments')
    parser.add_argument('-data', '--data_sets',
                        help='dataset names (default: sbm)')
    parser.add_argument('-dim', '--dimensions',
                        help='embedding dimensions list(default: 2^1 to 2^8)')
    parser.add_argument('-meth', '--methods',
                        help='method list (default: all methods)')
    parser.add_argument('-exp', '--experiments',
                        help='exp list (default: gr,lp,viz,nc)')
    parser.add_argument('-lemb', '--load_emb',
                        help='load saved embeddings (default: False)')
    parser.add_argument('-lexp', '--load_exp',
                        help='load saved experiment results (default: False)')
    parser.add_argument('-rounds', '--rounds',
                        help='number of rounds (default: 5)')
    parser.add_argument('-plot', '--plot',
                        help='plot the results (default: False)')
    parser.add_argument('-saveMAP', '--save_MAP',
                        help='save MAP in a latex table (default: False)')

    # Defaults come from params.conf; any argument given on the command line
    # overrides the corresponding entry.
    with open('gem/experiments/config/params.conf', 'r') as params_f:
        params = json.load(params_f)
    args = vars(parser.parse_args())
    # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
    for k, v in args.items():
        if v is not None:
            params[k] = v
    # Normalize the raw string/int-like values into the types run_exps uses.
    params["experiments"] = params["experiments"].split(',')
    params["data_sets"] = params["data_sets"].split(',')
    params["rounds"] = int(params["rounds"])
    params["n_sample_nodes"] = int(params["n_sample_nodes"])
    params["is_undirected"] = bool(int(params["is_undirected"]))
    if params["methods"] == "all":
        # list() so "methods" is a real list on Python 3 as well.
        params["methods"] = list(methClassMap.keys())
    else:
        params["methods"] = params["methods"].split(',')
    params["dimensions"] = params["dimensions"].split(',')
    if "nc_test_ratio_arr" in params:
        params["nc_test_ratio_arr"] = [
            float(ratio) for ratio in params["nc_test_ratio_arr"].split(',')
        ]
    for data_set in params["data_sets"]:
        call_exps(params, data_set)