'''
Skip-thought vectors
https://github.com/ryankiros/skip-thoughts
'''
import os

import theano
import theano.tensor as tensor

import _pickle as pkl
import numpy
import copy
import nltk

from collections import OrderedDict, defaultdict
from scipy.linalg import norm
from nltk.tokenize import word_tokenize

# Passed as the `profile` argument of theano.scan in gru_layer.
profile = False

#-----------------------------------------------------------------------------#
# Specify model and table locations here
#-----------------------------------------------------------------------------#
# Both values are used as plain string prefixes (file names are appended with
# `+`), so each one has to end with a path separator.
path_to_models = 'Data/skipthoughts/'
path_to_tables = 'Data/skipthoughts/'
#-----------------------------------------------------------------------------#

# Parameter archives of the uni-skip and bi-skip models. load_model also opens
# '<archive path>.pkl' for each of them to read the model options.
path_to_umodel = path_to_models + 'uni_skip.npz'
path_to_bmodel = path_to_models + 'bi_skip.npz'


def load_model():
	"""
	Assemble the uni-skip and bi-skip encoders together with their word tables.

	Reads the pickled options stored next to each parameter archive
	('<archive path>.pkl'), restores the saved parameters, compiles one Theano
	function per encoder and loads the word tables.

	Returns a dict with the keys 'uoptions', 'boptions', 'utable', 'btable',
	'f_w2v' (uni-skip encoder function) and 'f_w2v2' (bi-skip encoder function).
	"""
	print('Loading model parameters...')
	# NOTE: unpickling can execute arbitrary code; only use trusted option files.
	with open(path_to_umodel + '.pkl', 'rb') as uni_file:
		uoptions = pkl.load(uni_file)
	with open(path_to_bmodel + '.pkl', 'rb') as bi_file:
		boptions = pkl.load(bi_file)

	# Freshly initialised parameters are overwritten with the saved values and
	# then wrapped as Theano shared variables.
	uni_shared = init_tparams(load_params(path_to_umodel, init_params(uoptions)))
	bi_shared = init_tparams(load_params(path_to_bmodel, init_params_bi(boptions)))

	print('Compiling encoders...')
	uni_emb, uni_mask, uni_ctx = build_encoder(uni_shared, uoptions)
	f_w2v = theano.function([uni_emb, uni_mask], uni_ctx, name='f_w2v')
	bi_emb, bi_mask, bi_ctx = build_encoder_bi(bi_shared, boptions)
	f_w2v2 = theano.function([bi_emb, bi_mask], bi_ctx, name='f_w2v2')

	print('Loading tables...')
	utable, btable = load_tables()

	print('Packing up...')
	return {
		'uoptions': uoptions,
		'boptions': boptions,
		'utable': utable,
		'btable': btable,
		'f_w2v': f_w2v,
		'f_w2v2': f_w2v2,
	}


def load_tables():
	"""
	Load the uni-skip and bi-skip word embedding tables.

	Reads 'utable.npy', 'btable.npy' and 'dictionary.txt' from path_to_tables.
	Each line of the dictionary file is decoded as UTF-8 and stripped; the i-th
	word is paired with the i-th entry of each array.

	Returns:
		(utable, btable): two OrderedDicts mapping word -> embedding, in
		dictionary order.
	"""
	utable = numpy.load(path_to_tables + 'utable.npy', encoding='bytes')
	btable = numpy.load(path_to_tables + 'btable.npy', encoding='bytes')
	# The context manager closes the handle even when decoding a line fails.
	with open(path_to_tables + 'dictionary.txt', 'rb') as f:
		words = [line.decode('utf-8').strip() for line in f]
	# zip stops at the shorter input, so extra words or extra rows are dropped.
	utable = OrderedDict(zip(words, utable))
	btable = OrderedDict(zip(words, btable))
	return utable, btable


def encode(model, X, use_norm=True, verbose=True, batch_size=128, use_eos=False):
	"""
	Encode sentences in the list X. Each entry will return a vector

	model: dict holding 'utable', 'btable', 'uoptions', 'boptions', 'f_w2v'
		and 'f_w2v2' (the keys produced by load_model)
	X: list of sentence strings
	use_norm: divide every uni-skip / bi-skip vector by its norm
	verbose: print each sentence length when its bucket is processed
	batch_size: controls into how many minibatches a length bucket is split
	use_eos: place the '<eos>' table entry after the last word of each sentence

	Returns a float32 array with one row per sentence: the uni-skip vector
	followed by the bi-skip vector.
	"""
	# first, do preprocessing
	X = preprocess(X)

	utable = model['utable']
	btable = model['btable']
	udim_word = model['uoptions']['dim_word']
	bdim_word = model['boptions']['dim_word']

	ufeatures = numpy.zeros((len(X), model['uoptions']['dim']), dtype='float32')
	bfeatures = numpy.zeros((len(X), 2 * model['boptions']['dim']), dtype='float32')

	# length dictionary: sentence length -> indices of the sentences with that length
	ds = defaultdict(list)
	captions = [s.split() for s in X]
	for i, s in enumerate(captions):
		ds[len(s)].append(i)

	# Get features. This encodes by length, in order to avoid wasting computation
	for k in ds.keys():
		if verbose:
			print(k)
		indices = ds[k]
		numbatches = int(len(indices) // batch_size) + 1
		# every sentence in this bucket has k words (+1 step for '<eos>')
		nsteps = k + 1 if use_eos else k
		for minibatch in range(numbatches):
			caps = indices[minibatch::numbatches]
			if not caps:
				# the strided split can leave a trailing empty minibatch
				# (e.g. batch_size=1); there is nothing to encode for it
				continue

			uembedding = numpy.zeros((nsteps, len(caps), udim_word), dtype='float32')
			bembedding = numpy.zeros((nsteps, len(caps), bdim_word), dtype='float32')
			for ind, c in enumerate(caps):
				for j, word in enumerate(captions[c]):
					# words missing from the table fall back to the 'UNK' entry
					key = word if word in utable else 'UNK'
					uembedding[j, ind] = utable[key]
					bembedding[j, ind] = btable[key]
				if use_eos:
					uembedding[-1, ind] = utable['<eos>']
					bembedding[-1, ind] = btable['<eos>']

			# all sentences in the bucket share one length, so the mask is all ones
			mask = numpy.ones((nsteps, len(caps)), dtype='float32')
			uff = model['f_w2v'](uembedding, mask)
			bff = model['f_w2v2'](bembedding, mask)
			if use_norm:
				for j in range(len(uff)):
					uff[j] /= norm(uff[j])
					bff[j] /= norm(bff[j])
			for ind, c in enumerate(caps):
				ufeatures[c] = uff[ind]
				bfeatures[c] = bff[ind]

	features = numpy.c_[ufeatures, bfeatures]
	return features


def preprocess(text):
	"""
	Tokenize every entry of text for the encoder.

	Each entry is split into sentences, every sentence is word-tokenized, and
	the tokens are re-joined with single spaces. Every sentence contributes a
	leading space, so a non-empty result starts with ' '.
	"""
	sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
	X = []
	for t in text:
		pieces = [' ' + ' '.join(word_tokenize(s)) for s in sent_detector.tokenize(t)]
		X.append(''.join(pieces))
	return X


def nn(model, text, vectors, query, k=5):
	"""
	Print the k sentences of text closest to query.

	text: list of sentences
	vectors: the corresponding representations for text
	query: a string to search
	k: number of neighbours to print

	Sentences are ranked by the dot product between the normalized query
	encoding and each row of vectors, highest first. Each printed line holds a
	sentence followed by its index in text.
	"""
	query_vec = encode(model, [query])
	query_vec /= norm(query_vec)
	similarity = numpy.dot(query_vec, vectors.T).flatten()
	ranking = numpy.argsort(similarity)[::-1]
	top = ranking[:k]
	sentences = [text[a] for a in top]
	print('QUERY: ' + query)
	print('NEAREST: ')
	for sentence, position in zip(sentences, top):
		print(sentence, position)


def word_features(table, dim=620):
	"""
	Extract word features into a normalized matrix

	table: mapping word -> embedding vector
	dim: length of each embedding vector (defaults to 620)

	Returns a float32 array of shape (len(table), dim) whose i-th row is the
	i-th vector of table (in iteration order) divided by its norm.
	"""
	features = numpy.zeros((len(table), dim), dtype='float32')
	# Iterate the values directly: dict views cannot be indexed on Python 3.
	for i, f in enumerate(table.values()):
		features[i] = f / norm(f)
	return features


def nn_words(table, wordvecs, query, k=10):
	"""
	Get the nearest neighbour words

	table: mapping word -> embedding vector
	wordvecs: matrix with one row per word, in the iteration order of table
	query: a word that is a key of table
	k: number of neighbours to print

	Words are ranked by the dot product between table[query] and each row of
	wordvecs, highest first, and printed one per line.
	"""
	# Materialize the keys: dict views cannot be indexed on Python 3.
	keys = list(table.keys())
	qf = table[query]
	scores = numpy.dot(qf, wordvecs.T).flatten()
	sorted_args = numpy.argsort(scores)[::-1]
	words = [keys[a] for a in sorted_args[:k]]
	print('QUERY: ' + query)
	print('NEAREST: ')
	for w in words:
		print(w)


def _p(pp, name):
	"""
	Build a parameter name by joining prefix pp and name with an underscore.
	"""
	return '{}_{}'.format(pp, name)


def init_tparams(params):
	"""
	Wrap every entry of params in a Theano shared variable.

	Returns an OrderedDict with the same keys, in the same order; each shared
	variable is named after its key.
	"""
	return OrderedDict(
		(key, theano.shared(value, name=key)) for key, value in params.items()
	)


def load_params(path, params):
	"""
	load parameters

	path: location of the saved archive, passed to numpy.load
	params: dict of parameters; every key that is present in the archive is
		overwritten with the stored value

	Keys of params that are missing from the archive keep their current value
	and a warning is issued for each of them.

	Returns the same params object, updated in place.
	"""
	# The module never imports `warnings` at file level, so bring it in here.
	import warnings

	# The context manager closes the archive's file handle when loading is done.
	with numpy.load(path) as pp:
		for kk in params:
			if kk not in pp:
				warnings.warn('%s is not in the archive'%kk)
				continue
			params[kk] = pp[kk]
	return params


# layers: 'name': ('parameter initializer', 'feedforward')
layers = {'gru': ('param_init_gru', 'gru_layer')}

def get_layer(name):
	"""
	Return the (parameter initializer, feedforward) functions of layer `name`.

	The function names registered in `layers` are resolved in this module's
	namespace at call time. Raises KeyError when `name` is not registered.
	"""
	init_name, forward_name = layers[name]
	# Plain namespace lookup instead of eval: only names can be resolved,
	# never arbitrary expressions.
	namespace = globals()
	return (namespace[init_name], namespace[forward_name])


def init_params(options):
	"""
	Create the initial parameters of the (unidirectional) encoder.

	Reads 'n_words_src', 'dim_word', 'dim' and 'encoder' from options and
	returns an OrderedDict holding the word embedding 'Wemb' plus the
	parameters of the encoder layer under the prefix 'encoder'.
	"""
	params = OrderedDict()

	# embedding
	params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word'])

	# encoder: the layer's own initializer adds its parameters to the dict
	init_encoder = get_layer(options['encoder'])[0]
	return init_encoder(options, params, prefix='encoder',
						nin=options['dim_word'], dim=options['dim'])


def init_params_bi(options):
	"""
	Create the initial parameters of the bidirectional encoder.

	Reads 'n_words_src', 'dim_word', 'dim' and 'encoder' from options and
	returns an OrderedDict holding the word embedding 'Wemb' plus two sets of
	encoder parameters, under the prefixes 'encoder' and 'encoder_r'.
	"""
	params = OrderedDict()

	# embedding
	params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word'])

	# encoder: one set of layer parameters per prefix
	for prefix in ('encoder', 'encoder_r'):
		init_encoder = get_layer(options['encoder'])[0]
		params = init_encoder(options, params, prefix=prefix,
							  nin=options['dim_word'], dim=options['dim'])
	return params


def build_encoder(tparams, options):
	"""
	Build the symbolic graph of an encoder that takes pre-computed word embeddings.

	Returns (embedding, x_mask, ctx): the two symbolic float32 inputs (a
	3-tensor and a matrix) and the last element of the encoder layer's output.
	"""
	# symbolic inputs
	embedding = tensor.tensor3('embedding', dtype='float32')
	x_mask = tensor.matrix('x_mask', dtype='float32')

	# encoder
	encoder_forward = get_layer(options['encoder'])[1]
	hidden_states = encoder_forward(tparams, embedding, options,
									prefix='encoder',
									mask=x_mask)[0]

	# the context vector is the state after the last step
	return embedding, x_mask, hidden_states[-1]


def build_encoder_bi(tparams, options):
	"""
	Build the symbolic graph of a bidirectional encoder that takes pre-computed
	word embeddings.

	The 'encoder' layer reads the input as given, the 'encoder_r' layer reads
	it reversed along the first axis. Returns (embedding, x_mask, ctx) where
	ctx concatenates the last output of both layers along axis 1.
	"""
	# symbolic inputs and their reversed views
	embedding = tensor.tensor3('embedding', dtype='float32')
	embeddingr = embedding[::-1]
	x_mask = tensor.matrix('x_mask', dtype='float32')
	xr_mask = x_mask[::-1]

	# encoder: forward and reversed pass
	forward_states = get_layer(options['encoder'])[1](tparams, embedding, options,
													  prefix='encoder',
													  mask=x_mask)[0]
	backward_states = get_layer(options['encoder'])[1](tparams, embeddingr, options,
													   prefix='encoder_r',
													   mask=xr_mask)[0]

	ctx = tensor.concatenate([forward_states[-1], backward_states[-1]], axis=1)
	return embedding, x_mask, ctx


# some utilities
def ortho_weight(ndim):
	"""
	Return a random orthogonal (ndim, ndim) float32 matrix.

	The matrix is the left factor `u` of the SVD of a standard-normal sample.
	"""
	gaussian = numpy.random.randn(ndim, ndim)
	left_vectors = numpy.linalg.svd(gaussian)[0]
	return left_vectors.astype('float32')


def norm_weight(nin,nout=None, scale=0.1, ortho=True):
	"""
	Return a random (nin, nout) float32 weight matrix.

	nin: number of rows
	nout: number of columns; defaults to nin
	scale: bound of the uniform initialization
	ortho: when the matrix is square, use ortho_weight instead of uniform

	Non-square matrices (or square ones with ortho=False) are drawn uniformly
	from [-scale, scale).
	"""
	if nout is None:
		nout = nin
	if nout == nin and ortho:
		W = ortho_weight(nin)
	else:
		W = numpy.random.uniform(low=-scale, high=scale, size=(nin, nout))
	return W.astype('float32')


def param_init_gru(options, params, prefix='gru', nin=None, dim=None):
	"""
	parameter init for GRU

	options: model options; 'dim_proj' is read when nin or dim is not given
	params: dict that receives the new entries (modified in place)
	prefix: name prefix of the created parameters
	nin: input size
	dim: hidden state size

	Adds '<prefix>_W' (nin, 2*dim), '<prefix>_b' (2*dim,), '<prefix>_U'
	(dim, 2*dim), '<prefix>_Wx' (nin, dim), '<prefix>_Ux' (dim, dim) and
	'<prefix>_bx' (dim,) to params and returns params.
	"""
	if nin is None:
		nin = options['dim_proj']
	if dim is None:
		dim = options['dim_proj']

	# W, b and U each hold two blocks side by side (one per gate)
	W = numpy.concatenate([norm_weight(nin,dim),
						   norm_weight(nin,dim)], axis=1)
	params[_p(prefix,'W')] = W
	params[_p(prefix,'b')] = numpy.zeros((2 * dim,)).astype('float32')
	U = numpy.concatenate([ortho_weight(dim),
						   ortho_weight(dim)], axis=1)
	params[_p(prefix,'U')] = U

	# weights of the candidate hidden state
	Wx = norm_weight(nin, dim)
	params[_p(prefix,'Wx')] = Wx
	Ux = ortho_weight(dim)
	params[_p(prefix,'Ux')] = Ux
	params[_p(prefix,'bx')] = numpy.zeros((dim,)).astype('float32')

	return params


def gru_layer(tparams, state_below, options, prefix='gru', mask=None, **kwargs):
	"""
	Forward pass through GRU layer

	tparams: dict of shared variables holding '<prefix>_W', '<prefix>_b',
		'<prefix>_U', '<prefix>_Wx', '<prefix>_Ux' and '<prefix>_bx'
	state_below: symbolic input; scan iterates over its first axis, and when
		it is 3-dimensional its second axis gives the number of samples
	options: not used by this layer
	prefix: name prefix of the layer's parameters
	mask: symbolic mask iterated alongside the input; where it is 0 the
		previous hidden state is carried over unchanged. Defaults to all ones.
	kwargs: accepted and ignored

	Returns a one-element list holding the scan output (the hidden state
	after every step).
	"""
	nsteps = state_below.shape[0]
	if state_below.ndim == 3:
		n_samples = state_below.shape[1]
	else:
		n_samples = 1

	dim = tparams[_p(prefix,'Ux')].shape[1]

	# identity test, so no `==` operator of a symbolic variable is involved
	if mask is None:
		mask = tensor.alloc(1., state_below.shape[0], 1)

	def _slice(_x, n, dim):
		# n-th block of width dim along the last axis
		if _x.ndim == 3:
			return _x[:, :, n*dim:(n+1)*dim]
		return _x[:, n*dim:(n+1)*dim]

	# input projections are computed once for all steps, outside of scan
	state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + tparams[_p(prefix, 'b')]
	state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) + tparams[_p(prefix, 'bx')]
	U = tparams[_p(prefix, 'U')]
	Ux = tparams[_p(prefix, 'Ux')]

	def _step_slice(m_, x_, xx_, h_, U, Ux):
		preact = tensor.dot(h_, U)
		preact += x_

		# first block of preact -> r, second block -> u
		r = tensor.nnet.sigmoid(_slice(preact, 0, dim))
		u = tensor.nnet.sigmoid(_slice(preact, 1, dim))

		# candidate state: r scales the recurrent contribution
		preactx = tensor.dot(h_, Ux)
		preactx = preactx * r
		preactx = preactx + xx_

		h = tensor.tanh(preactx)

		# u interpolates between the previous state and the candidate
		h = u * h_ + (1. - u) * h
		# masked positions keep the previous state
		h = m_[:,None] * h + (1. - m_)[:,None] * h_

		return h

	seqs = [mask, state_below_, state_belowx]
	_step = _step_slice

	rval, updates = theano.scan(_step,
								sequences=seqs,
								outputs_info = [tensor.alloc(0., n_samples, dim)],
								non_sequences = [U, Ux],
								name=_p(prefix, '_layers'),
								n_steps=nsteps,
								profile=profile,
								strict=True)
	rval = [rval]
	return rval