#-*- coding: utf8
from __future__ import division, print_function

from prme import mrr

import pandas as pd
import plac
import numpy as np

def _read_test_trace(trace_fpath, to, hyper2id, obj2id):
    """Parse the tail of a tab-separated trace file into model-id arrays.

    Lines whose 0-based index is ``< to`` are skipped. Each remaining
    non-blank line must contain exactly four tab-separated fields
    ``dt, h, s, d``. A line is kept only when ``h`` is a key of ``hyper2id``
    and both ``s`` and ``d`` are keys of ``obj2id``.

    Returns ``(dts, HSDs)``:
      - ``dts`` is a C-contiguous float64 array of the parsed ``dt`` values.
      - ``HSDs`` is a C-contiguous int32 array of shape ``(n, 3)`` whose rows
        are ``[hyper2id[h], obj2id[s], obj2id[d]]``.

    Raises ValueError if a non-blank line does not split into four fields.
    """
    dts = []
    HSDs = []
    with open(trace_fpath) as trace_file:
        for i, line in enumerate(trace_file):
            if i < to:
                continue

            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) would otherwise make
                # the 4-way unpack below raise.
                continue

            dt, h, s, d = line.split('\t')
            if h in hyper2id and s in obj2id and d in obj2id:
                dts.append(float(dt))
                HSDs.append([hyper2id[h], obj2id[s], obj2id[d]])

    dts = np.array(dts, order='C', dtype='d')
    # reshape keeps the (n, 3) layout even when no line survived filtering
    # (np.array([]) alone would have shape (0,)).
    HSDs = np.array(HSDs, order='C', dtype='i4').reshape(-1, 3)
    return dts, HSDs


def main(model, out_fpath):
    """Compute reciprocal ranks on the held-out tail of a trace and save them.

    model: path of an HDF5 store (opened with ``pd.HDFStore``). It must hold
        the keys from_, to, trace_fpath, XP_hk, XP_ok, XG_ok, alpha, tau,
        hyper2id and obj2id.
    out_fpath: path of the text file the reciprocal ranks are written to
        (via ``np.savetxt``).

    Trace lines with index ``>= to`` are parsed, mapped through the stored
    ``hyper2id`` / ``obj2id`` tables, and passed to ``mrr.compute`` together
    with the stored parameters.

    Raises ValueError if the stored ``from_`` value is not 0.
    """
    # The context manager guarantees the store is closed even if a key is
    # missing. Everything is materialised as numpy arrays/dicts, so nothing
    # read below needs the store to stay open.
    with pd.HDFStore(model) as store:
        def scalar(key):
            # Scalars are read from the top-left cell, positionally.
            return store[key].values[0][0]

        from_ = scalar('from_')
        # Explicit check rather than ``assert``, so it also runs under -O.
        if from_ != 0:
            raise ValueError('from_ must be 0, got %r' % (from_,))

        to = scalar('to')
        trace_fpath = scalar('trace_fpath')
        alpha = scalar('alpha')
        tau = scalar('tau')

        XP_hk = store['XP_hk'].values
        XP_ok = store['XP_ok'].values
        XG_ok = store['XG_ok'].values

        hyper2id = dict(store['hyper2id'].values)
        obj2id = dict(store['obj2id'].values)

    dts, HSDs = _read_test_trace(trace_fpath, to, hyper2id, obj2id)

    # NOTE(review): `queries` is never used below. The draw is kept only
    # because it advances numpy's global RNG before mrr.compute runs. Confirm
    # whether mrr.compute should receive HSDs[queries] / dts[queries]
    # instead, or whether this line can be dropped.
    num_queries = min(10000, len(HSDs))
    queries = np.random.choice(len(HSDs), size=num_queries)

    rrs = mrr.compute(dts, HSDs, XP_hk, XP_ok, XG_ok, alpha, tau)
    np.savetxt(out_fpath, rrs)
    
# Run the CLI only when executed as a script, not when the module is imported.
if __name__ == '__main__':
    plac.call(main)