#################################################################
# Code written by Edward Choi (mp2893@gatech.edu)
# For bug reports, please contact the author at the email address above
#################################################################

import sys, random
import numpy as np
import cPickle as pickle
from collections import OrderedDict
import argparse

import theano
import theano.tensor as T
from theano import config

def sigmoid(x):
	"""Return the logistic function 1 / (1 + exp(-x)), applied element-wise via NumPy."""
	denominator = 1. + np.exp(-x)
	return 1. / denominator

def numpy_floatX(data):
	"""Return ``data`` as a NumPy array whose dtype is Theano's configured ``floatX``."""
	target_dtype = config.floatX
	return np.asarray(data, dtype=target_dtype)

def load_embedding(infile):
	"""Load a pickled embedding matrix and cast it to Theano's ``floatX`` dtype.

	Args:
		infile: path to the pickled object; it is converted with ``np.array``.

	Returns:
		A NumPy array of dtype ``config.floatX``.
	"""
	# NOTE: unpickling can execute arbitrary code; only load trusted files.
	# The context manager closes the handle even if unpickling fails.
	with open(infile, 'rb') as fd:
		raw = pickle.load(fd)
	return np.array(raw).astype(config.floatX)

def load_params(options):
	"""Load the model parameters named in ``options`` into an ordered dict.

	Args:
		options: mapping that must provide ``'modelFile'`` (path handed to
			``np.load``) and ``'embFile'`` (path to a pickled embedding
			matrix, or an empty string to keep the stored ``W_emb``).

	Returns:
		An ``OrderedDict`` mapping parameter names to NumPy arrays. When
		``options['embFile']`` is non-empty, ``'W_emb'`` is replaced by the
		pickled matrix cast to ``config.floatX``.
	"""
	params = OrderedDict()
	weights = np.load(options['modelFile'])
	try:
		# Indexing the archive reads each array into memory, so the archive
		# can be closed safely once the loop is done.
		for key in weights.files:
			params[key] = weights[key]
	finally:
		weights.close()

	embFile = options['embFile']
	if len(embFile) > 0:
		# NOTE: unpickling can execute arbitrary code; only load trusted files.
		with open(embFile, 'rb') as fd:
			params['W_emb'] = np.array(pickle.load(fd)).astype(config.floatX)
	return params

def init_tparams(params, options):
	"""Wrap every entry of ``params`` in a Theano shared variable.

	Each shared variable is named after its key. ``options`` is accepted
	but not used here. Returns an ``OrderedDict`` with the same keys, in
	the same order, as ``params``.
	"""
	shared_params = OrderedDict()
	for key in params:
		shared_params[key] = theano.shared(params[key], name=key)
	return shared_params

def _slice(_x, n, dim):
	if _x.ndim == 3:
		return _x[:, :, n*dim:(n+1)*dim]
	return _x[:, n*dim:(n+1)*dim]

def gru_layer(tparams, emb, name, hiddenDimSize):
	"""Build the symbolic GRU recurrence over ``emb`` and return every hidden state.

	Args:
		tparams: mapping holding ``'W_gru_<name>'``, ``'b_gru_<name>'`` and
			``'U_gru_<name>'``.
		emb: symbolic input; its first axis is scanned over. When it is
			3-D, its second axis is used as the number of samples.
		name: suffix selecting which GRU parameters to use.
		hiddenDimSize: width of each gate slice and of the hidden state.

	Returns:
		The stacked outputs of ``theano.scan`` (one hidden state per step).
	"""
	n_steps = emb.shape[0]
	# Non-3-D input is treated as a single sample.
	n_samples = emb.shape[1] if emb.ndim == 3 else 1

	def stepFn(wx, h, U_gru):
		# Input and recurrent projections hold three slices of width
		# hiddenDimSize, used in order as reset gate, update gate, candidate.
		uh = T.dot(h, U_gru)
		reset_gate = T.nnet.sigmoid(_slice(wx, 0, hiddenDimSize) + _slice(uh, 0, hiddenDimSize))
		update_gate = T.nnet.sigmoid(_slice(wx, 1, hiddenDimSize) + _slice(uh, 1, hiddenDimSize))
		candidate = T.tanh(_slice(wx, 2, hiddenDimSize) + reset_gate * _slice(uh, 2, hiddenDimSize))
		return update_gate * h + ((1. - update_gate) * candidate)

	# The input projection does not depend on the hidden state, so it is
	# computed for all steps at once before the scan.
	Wx = T.dot(emb, tparams['W_gru_'+name]) + tparams['b_gru_'+name]
	initial_state = T.alloc(numpy_floatX(0.0), n_samples, hiddenDimSize)
	hidden_states, _ = theano.scan(
		fn=stepFn,
		sequences=[Wx],
		outputs_info=initial_state,
		non_sequences=[tparams['U_gru_'+name]],
		name='gru_layer',
		n_steps=n_steps
	)

	return hidden_states
	
def build_model(tparams, options):
	"""Build the symbolic graph producing the attention weights ``alpha`` and ``beta``.

	Args:
		tparams: mapping of shared variables; reads ``'w_alpha'``,
			``'b_alpha'``, ``'W_beta'``, ``'b_beta'`` and the GRU parameters
			for the suffixes ``'a'`` and ``'b'``.
		options: mapping providing ``'alphaHiddenDimSize'`` and
			``'betaHiddenDimSize'``.

	Returns:
		``(x, alpha, beta)``: the 3-D symbolic input, the softmax-normalised
		scalar attention, and the tanh-squashed vector attention.
	"""
	alpha_hidden = options['alphaHiddenDimSize']
	beta_hidden = options['betaHiddenDimSize']

	x = T.tensor3('x', dtype=config.floatX)

	# Both GRUs read the input reversed along its first axis; their outputs
	# are flipped back so they line up with the original step order.
	reversed_input = x[::-1]
	# NOTE(review): the 0.5 factor presumably compensates for dropout applied
	# at training time -- confirm against the training script.
	hidden_alpha = gru_layer(tparams, reversed_input, 'a', alpha_hidden)[::-1] * 0.5
	hidden_beta = gru_layer(tparams, reversed_input, 'b', beta_hidden)[::-1] * 0.5

	alpha_logits = T.dot(hidden_alpha, tparams['w_alpha']) + tparams['b_alpha']
	alpha_logits = alpha_logits.reshape((alpha_logits.shape[0], alpha_logits.shape[1]))
	# softmax normalises along rows, so transpose to normalise over the
	# first axis of alpha_logits and transpose back afterwards.
	alpha = (T.nnet.softmax(alpha_logits.T)).T

	beta_logits = T.dot(hidden_beta, tparams['W_beta']) + tparams['b_beta']
	beta = T.tanh(beta_logits)

	return x, alpha, beta

def padMatrixWithTime(seqs, times, options):
	"""Turn sequences and their time values into zero-padded arrays.

	Args:
		seqs: list of sequences; every element of a sequence is used as an
			index (or list of indices) into the last axis of ``x``.
		times: per-sequence time values, paired with ``seqs`` by position.
		options: mapping providing ``'inputDimSize'`` and ``'useLogTime'``.

	Returns:
		``(x, t, lengths)`` where ``x`` has shape
		``(maxlen, n_samples, inputDimSize)`` with ones at the indexed
		positions, ``t`` has shape ``(maxlen, n_samples)`` (``log(t + 1)`` is
		applied when ``options['useLogTime']`` is truthy) and ``lengths`` is
		an int32 array of sequence lengths.
	"""
	lengths = np.array([len(patient) for patient in seqs]).astype('int32')
	num_patients = len(seqs)
	longest = np.max(lengths)

	x = np.zeros((longest, num_patients, options['inputDimSize'])).astype(config.floatX)
	t = np.zeros((longest, num_patients)).astype(config.floatX)
	for col, (patient, durations) in enumerate(zip(seqs, times)):
		for step, codes in enumerate(patient):
			# Basic indexing returns a view, so the assignment writes into x.
			visit_vec = x[step, col, :]
			visit_vec[codes] = 1.
		t[:lengths[col], col] = durations

	if options['useLogTime']:
		t = np.log(t + 1.)

	return x, t, lengths

def padMatrixWithoutTime(seqs, options):
	"""Turn sequences into a zero-padded array.

	Args:
		seqs: list of sequences; every element of a sequence is used as an
			index (or list of indices) into the last axis of ``x``.
		options: mapping providing ``'inputDimSize'``.

	Returns:
		``(x, lengths)`` where ``x`` has shape
		``(maxlen, n_samples, inputDimSize)`` with ones at the indexed
		positions and ``lengths`` is an int32 array of sequence lengths.
	"""
	lengths = np.array([len(patient) for patient in seqs]).astype('int32')
	num_patients = len(seqs)
	longest = np.max(lengths)

	x = np.zeros((longest, num_patients, options['inputDimSize'])).astype(config.floatX)
	for col, patient in enumerate(seqs):
		for step, codes in enumerate(patient):
			# Basic indexing returns a view, so the assignment writes into x.
			visit_vec = x[step, col, :]
			visit_vec[codes] = 1.

	return x, lengths

def load_data_debug(seqFile, labelFile, timeFile=''):
	"""Load pickled data and split it into train/valid/test sets sorted by length.

	The data is shuffled with a fixed seed (``np.random.seed(0)``, which
	reseeds NumPy's global generator), then the first 15% of the permutation
	becomes the test set, the next 10% the validation set and the rest the
	training set. Each set is sorted by ascending sequence length.

	Args:
		seqFile: path to the pickled sequences.
		labelFile: path to the pickled labels.
		timeFile: path to the pickled time values, or ``''`` to skip them.

	Returns:
		``(train_set, valid_set, test_set)``; each is a tuple
		``(x, y, t)`` of lists, with ``t`` being ``None`` when no
		``timeFile`` is given.
	"""
	def load_pickled_array(path):
		# NOTE: unpickling can execute arbitrary code; only load trusted files.
		# The context manager closes the handle even if unpickling fails.
		with open(path, 'rb') as fd:
			return np.array(pickle.load(fd))

	def len_argsort(seq):
		return sorted(range(len(seq)), key=lambda x: len(seq[x]))

	def reorder(values, order):
		return [values[i] for i in order]

	useTime = len(timeFile) > 0
	sequences = load_pickled_array(seqFile)
	labels = load_pickled_array(labelFile)
	times = load_pickled_array(timeFile) if useTime else None

	dataSize = len(labels)
	np.random.seed(0)
	ind = np.random.permutation(dataSize)
	nTest = int(0.15 * dataSize)
	nValid = int(0.10 * dataSize)

	# Order matches the return value: train, valid, test.
	split_indices = (ind[nTest+nValid:], ind[nTest:nTest+nValid], ind[:nTest])

	data_sets = []
	for indices in split_indices:
		set_x = sequences[indices]
		set_y = labels[indices]
		# One ordering per split keeps x, y and t aligned after sorting.
		order = len_argsort(set_x)
		set_t = reorder(times[indices], order) if useTime else None
		data_sets.append((reorder(set_x, order), reorder(set_y, order), set_t))

	train_set, valid_set, test_set = data_sets
	return train_set, valid_set, test_set

def load_data(dataFile, labelFile, timeFile):
	"""Load pickled sequences, labels and optional times, sorted by sequence length.

	Args:
		dataFile: path to the pickled sequences.
		labelFile: path to the pickled labels.
		timeFile: path to the pickled time values, or ``''`` to skip them.

	Returns:
		A tuple ``(x, y, t)`` of lists ordered by ascending ``len`` of each
		sequence; ``t`` is ``None`` when ``timeFile`` is empty.
	"""
	def load_pickled_array(path):
		# NOTE: unpickling can execute arbitrary code; only load trusted files.
		# The context manager closes the handle even if unpickling fails.
		with open(path, 'rb') as fd:
			return np.array(pickle.load(fd))

	useTime = len(timeFile) > 0
	test_set_x = load_pickled_array(dataFile)
	test_set_y = load_pickled_array(labelFile)
	test_set_t = load_pickled_array(timeFile) if useTime else None

	# sorted() is stable, so sequences of equal length keep their file order.
	sorted_index = sorted(range(len(test_set_x)), key=lambda x: len(test_set_x[x]))
	test_set_x = [test_set_x[i] for i in sorted_index]
	test_set_y = [test_set_y[i] for i in sorted_index]
	if useTime:
		test_set_t = [test_set_t[i] for i in sorted_index]

	return (test_set_x, test_set_y, test_set_t)

def print2file(buf, outFile):
	"""Append ``buf`` followed by a newline to the file at ``outFile``."""
	# The context manager closes the handle even if the write raises.
	with open(outFile, 'a') as outfd:
		outfd.write(buf + '\n')

def train_RETAIN(
	modelFile='model.npz',
	seqFile='seqFile.txt',
	labelFile='labelFile.txt',
	outFile='outFile.txt',
	timeFile='timeFile.txt',
	typeFile='types.txt',
	useLogTime=True,
	embFile='embFile.txt',
	logEps=1e-8
):
	options = locals().copy()

	if len(timeFile) > 0: useTime = True
	else: useTime = False
	options['useTime'] = useTime

	if len(embFile) > 0: useFixedEmb = True
	else: useFixedEmb = False
	options['useFixedEmb'] = useFixedEmb
	
	print 'Loading the parameters ... ',
	params = load_params(options)
	tparams = init_tparams(params, options)

	options['alphaHiddenDimSize'] = params['w_alpha'].shape[0]
	options['betaHiddenDimSize'] = params['W_beta'].shape[0]
	options['inputDimSize'] = params['W_emb'].shape[0]

	print 'Building the model ... ',
	x, alpha, beta =  build_model(tparams, options)
	get_result = theano.function(inputs=[x], outputs=[alpha, beta], name='get_result')

	print 'Loading data ... ',
	testSet = load_data(seqFile, labelFile, timeFile)
	print 'done'

	types = pickle.load(open(typeFile, 'rb'))
	rtypes = dict([(v,k) for k,v in types.iteritems()])

	print 'Contribution calculation start!!'
	count = 0
	outfd = open(outFile, 'w')
	for index in range(len(testSet[0])):
		if count % 100 == 0: print 'processed %d patients' % count
		count += 1
		batchX = [testSet[0][index]]
		label = testSet[1][index]

		if useTime: 
			batchT = [testSet[2][index]]
			x, t, lengths = padMatrixWithTime(batchX, batchT, options)
		else:
			x, lengths = padMatrixWithoutTime(batchX, options)

		n_timesteps = x.shape[0]
		n_samples = x.shape[1]
		emb = np.dot(x, params['W_emb'])

		if useTime:
			temb = np.concatenate([emb, t.reshape((n_timesteps,n_samples,1))], axis=2)
		else:
			temb = emb

		alpha, beta = get_result(temb)
		alpha = alpha[:,0]
		beta = beta[:,0,:]

		ct = (alpha[:,None] * beta * emb[:,0,:]).sum(axis=0)
		y_t = sigmoid(np.dot(ct, params['w_output']) + params['b_output'])

		buf = ''
		patient = batchX[0]
		for i in range(len(patient)):
			visit = patient[i]
			buf += '-------------- visit_index:%d ---------------\n' % i
			for j in range(len(visit)):
				code = visit[j]
				contribution = np.dot(params['w_output'].flatten(), alpha[i] * beta[i] * params['W_emb'][code])
				buf += '%s:%f  ' % (rtypes[code], contribution)
			buf += '\n------------------------------------\n'
		buf += 'patient_index:%d, label:%d, score:%f\n\n' % (index, label, y_t)
		outfd.write(buf + '\n')
	outfd.close()
	
def parse_arguments(parser):
	"""Register this script's command-line arguments on ``parser`` and parse ``sys.argv``.

	Args:
		parser: an ``argparse.ArgumentParser`` to populate.

	Returns:
		The namespace produced by ``parser.parse_args()``.
	"""
	parser.add_argument('model_file', type=str, metavar='<model_file>', help='The path to the Numpy-compressed file containing the model parameters.')
	parser.add_argument('seq_file', type=str, metavar='<visit_file>', help='The path to the cPickled file containing visit information of patients')
	parser.add_argument('label_file', type=str, metavar='<label_file>', help='The path to the cPickled file containing label information of patients')
	parser.add_argument('type_file', type=str, metavar='<type_file>', help='The path to the cPickled dictionary for mapping medical code strings to integers')
	# The old help text described saving models; this script writes a contribution report.
	parser.add_argument('out_file', type=str, metavar='<out_file>', help='The path to the output text file. The contribution of every medical code in every visit, plus the label and predicted score of each patient, will be written here')
	parser.add_argument('--time_file', type=str, default='', help='The path to the cPickled file containing durations between visits of patients. If you are not using duration information, do not use this option')
	parser.add_argument('--use_log_time', type=int, default=1, choices=[0,1], help='Use logarithm of time duration to dampen the impact of the outliers (0 for false, 1 for true) (default value: 1)')
	parser.add_argument('--embed_file', type=str, default='', help='The path to the cPickled file containing the representation vectors of medical codes. If you are not using medical code representations, do not use this option')
	args = parser.parse_args()
	return args

if __name__ == '__main__':
	# Script entry point: parse the command line and forward every option
	# to train_RETAIN by keyword.
	args = parse_arguments(argparse.ArgumentParser())

	retain_kwargs = dict(
		modelFile=args.model_file,
		seqFile=args.seq_file,
		labelFile=args.label_file,
		typeFile=args.type_file,
		outFile=args.out_file,
		timeFile=args.time_file,
		useLogTime=args.use_log_time,
		embFile=args.embed_file
	)
	train_RETAIN(**retain_kwargs)