'''
Skip-thought vectors
https://github.com/ryankiros/skip-thoughts
'''
import os

import theano
import theano.tensor as tensor

import cPickle as pkl
import numpy
import copy
import nltk

from collections import OrderedDict, defaultdict
from scipy.linalg import norm
from nltk.tokenize import word_tokenize

# Forwarded as the 'profile' argument of theano.scan in gru_layer
profile = False

#-----------------------------------------------------------------------------#
# Specify model and table locations here
#-----------------------------------------------------------------------------#
# Both values are used as plain string prefixes (file names are appended with
# '+'), so each one must end with a path separator.
path_to_models = 'Data/skipthoughts/'
path_to_tables = 'Data/skipthoughts/'
#-----------------------------------------------------------------------------#

# Parameter archives of the two encoders. load_model also opens
# '<archive path>.pkl' next to each of them to read the model options.
path_to_umodel = path_to_models + 'uni_skip.npz'
path_to_bmodel = path_to_models + 'bi_skip.npz'


def load_model():
	"""
	Load the model with saved tables
	"""
	# Load model options
	print 'Loading model parameters...'
	with open('%s.pkl'%path_to_umodel, 'rb') as f:
		uoptions = pkl.load(f)
	with open('%s.pkl'%path_to_bmodel, 'rb') as f:
		boptions = pkl.load(f)

	# Load parameters
	uparams = init_params(uoptions)
	uparams = load_params(path_to_umodel, uparams)
	utparams = init_tparams(uparams)
	bparams = init_params_bi(boptions)
	bparams = load_params(path_to_bmodel, bparams)
	btparams = init_tparams(bparams)

	# Extractor functions
	print 'Compiling encoders...'
	embedding, x_mask, ctxw2v = build_encoder(utparams, uoptions)
	f_w2v = theano.function([embedding, x_mask], ctxw2v, name='f_w2v')
	embedding, x_mask, ctxw2v = build_encoder_bi(btparams, boptions)
	f_w2v2 = theano.function([embedding, x_mask], ctxw2v, name='f_w2v2')

	# Tables
	print 'Loading tables...'
	utable, btable = load_tables()

	# Store everything we need in a dictionary
	print 'Packing up...'
	model = {}
	model['uoptions'] = uoptions
	model['boptions'] = boptions
	model['utable'] = utable
	model['btable'] = btable
	model['f_w2v'] = f_w2v
	model['f_w2v2'] = f_w2v2

	return model


def load_tables():
	"""
	Load the tables

	Reads 'utable.npy', 'btable.npy' and 'dictionary.txt' from
	path_to_tables. The words of the dictionary file (one per line, UTF-8)
	are paired in order with the rows of each table.

	Returns (utable, btable): two OrderedDicts mapping word -> table row.
	"""
	utable = numpy.load(path_to_tables + 'utable.npy')
	btable = numpy.load(path_to_tables + 'btable.npy')
	# 'with' closes the handle even when reading or decoding a line raises
	with open(path_to_tables + 'dictionary.txt', 'rb') as f:
		words = [line.decode('utf-8').strip() for line in f]
	utable = OrderedDict(zip(words, utable))
	btable = OrderedDict(zip(words, btable))
	return utable, btable


def encode(model, X, use_norm=True, verbose=True, batch_size=128, use_eos=False):
	"""
	Encode sentences in the list X. Each entry will return a vector
	"""
	# first, do preprocessing
	X = preprocess(X)

	# word dictionary and init
	d = defaultdict(lambda : 0)
	for w in model['utable'].keys():
		d[w] = 1
	ufeatures = numpy.zeros((len(X), model['uoptions']['dim']), dtype='float32')
	bfeatures = numpy.zeros((len(X), 2 * model['boptions']['dim']), dtype='float32')

	# length dictionary
	ds = defaultdict(list)
	captions = [s.split() for s in X]
	for i,s in enumerate(captions):
		ds[len(s)].append(i)

	# Get features. This encodes by length, in order to avoid wasting computation
	for k in ds.keys():
		if verbose:
			print k
		numbatches = len(ds[k]) / batch_size + 1
		for minibatch in range(numbatches):
			caps = ds[k][minibatch::numbatches]

			if use_eos:
				uembedding = numpy.zeros((k+1, len(caps), model['uoptions']['dim_word']), dtype='float32')
				bembedding = numpy.zeros((k+1, len(caps), model['boptions']['dim_word']), dtype='float32')
			else:
				uembedding = numpy.zeros((k, len(caps), model['uoptions']['dim_word']), dtype='float32')
				bembedding = numpy.zeros((k, len(caps), model['boptions']['dim_word']), dtype='float32')
			for ind, c in enumerate(caps):
				caption = captions[c]
				for j in range(len(caption)):
					if d[caption[j]] > 0:
						uembedding[j,ind] = model['utable'][caption[j]]
						bembedding[j,ind] = model['btable'][caption[j]]
					else:
						uembedding[j,ind] = model['utable']['UNK']
						bembedding[j,ind] = model['btable']['UNK']
				if use_eos:
					uembedding[-1,ind] = model['utable']['<eos>']
					bembedding[-1,ind] = model['btable']['<eos>']
			if use_eos:
				uff = model['f_w2v'](uembedding, numpy.ones((len(caption)+1,len(caps)), dtype='float32'))
				bff = model['f_w2v2'](bembedding, numpy.ones((len(caption)+1,len(caps)), dtype='float32'))
			else:
				uff = model['f_w2v'](uembedding, numpy.ones((len(caption),len(caps)), dtype='float32'))
				bff = model['f_w2v2'](bembedding, numpy.ones((len(caption),len(caps)), dtype='float32'))
			if use_norm:
				for j in range(len(uff)):
					uff[j] /= norm(uff[j])
					bff[j] /= norm(bff[j])
			for ind, c in enumerate(caps):
				ufeatures[c] = uff[ind]
				bfeatures[c] = bff[ind]

	features = numpy.c_[ufeatures, bfeatures]
	return features


def preprocess(text):
	"""
	Tokenize every entry of text for the encoder.

	Each entry is split into sentences, every sentence is word-tokenized and
	its tokens are joined with single spaces. The sentences of one entry are
	concatenated, each preceded by one space.
	"""
	splitter = nltk.data.load('tokenizers/punkt/english.pickle')
	return [
		''.join(' ' + ' '.join(word_tokenize(sentence))
				for sentence in splitter.tokenize(entry))
		for entry in text
	]


def nn(model, text, vectors, query, k=5):
	"""
	Return the nearest neighbour sentences to query
	text: list of sentences
	vectors: the corresponding representations for text
	query: a string to search
	"""
	qf = encode(model, [query])
	qf /= norm(qf)
	scores = numpy.dot(qf, vectors.T).flatten()
	sorted_args = numpy.argsort(scores)[::-1]
	sentences = [text[a] for a in sorted_args[:k]]
	print 'QUERY: ' + query
	print 'NEAREST: '
	for i, s in enumerate(sentences):
		print s, sorted_args[i]


def word_features(table):
	"""
	Extract word features into a normalized matrix

	table: mapping word -> embedding vector; all vectors must have the same
		length

	Returns a float32 array with one row per entry of table, in the table's
	iteration order, each row divided by its norm. The number of columns is
	taken from the first vector; an empty table gives a (0, 620) array.
	"""
	vectors = list(table.values())
	# infer the embedding width instead of assuming 620-dimensional words
	dim = len(vectors[0]) if vectors else 620
	features = numpy.zeros((len(vectors), dim), dtype='float32')
	for i, f in enumerate(vectors):
		features[i] = f / norm(f)
	return features


def nn_words(table, wordvecs, query, k=10):
	"""
	Get the nearest neighbour words
	"""
	keys = table.keys()
	qf = table[query]
	scores = numpy.dot(qf, wordvecs.T).flatten()
	sorted_args = numpy.argsort(scores)[::-1]
	words = [keys[a] for a in sorted_args[:k]]
	print 'QUERY: ' + query
	print 'NEAREST: '
	for i, w in enumerate(words):
		print w


def _p(pp, name):
	"""
	Join prefix pp and name with an underscore, e.g. ('encoder', 'W')
	gives 'encoder_W'.
	"""
	return '%(prefix)s_%(name)s' % {'prefix': pp, 'name': name}


def init_tparams(params):
	"""
	Wrap every entry of params in a Theano shared variable.

	Returns an OrderedDict with the same keys, in the same order, where each
	shared variable is named after its key.
	"""
	return OrderedDict(
		(key, theano.shared(params[key], name=key)) for key in params)


def load_params(path, params):
	"""
	load parameters

	path: location of the archive, passed to numpy.load
	params: dictionary of initial values; every key that is present in the
		archive is overwritten with the stored value (modified in place)

	A warning is issued for each key of params that the archive lacks; that
	entry keeps its initial value.

	Returns params.
	"""
	# imported here because the module header does not import 'warnings'
	import warnings

	pp = numpy.load(path)
	try:
		for kk in params:
			if kk not in pp:
				warnings.warn('%s is not in the archive'%kk)
				continue
			params[kk] = pp[kk]
	finally:
		# an .npz archive keeps its file open until it is closed explicitly
		close = getattr(pp, 'close', None)
		if close is not None:
			close()
	return params


# layers: 'name': ('parameter initializer', 'feedforward')
layers = {
	'gru': ('param_init_gru', 'gru_layer'),
}

def get_layer(name):
	"""
	Return the pair (parameter initializer, feedforward function)
	registered in `layers` under name.
	"""
	entry = layers[name]
	# the table stores function names; evaluate them to get the functions
	initializer = eval(entry[0])
	feedforward = eval(entry[1])
	return (initializer, feedforward)


def init_params(options):
	"""
	Create the initial parameters of the encoder: the word embedding
	matrix 'Wemb' followed by the parameters of the encoder layer.
	"""
	params = OrderedDict()

	# embedding
	vocab_size = options['n_words_src']
	word_dim = options['dim_word']
	params['Wemb'] = norm_weight(vocab_size, word_dim)

	# encoder: the layer type is selected by options['encoder']
	init_layer, _ = get_layer(options['encoder'])
	return init_layer(options, params, prefix='encoder',
					  nin=options['dim_word'], dim=options['dim'])


def init_params_bi(options):
	"""
	Create the initial parameters of the bidirectional encoder: the word
	embedding matrix 'Wemb', then one set of layer parameters under the
	prefix 'encoder' and one under 'encoder_r'.
	"""
	params = OrderedDict()

	# embedding
	params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word'])

	# encoder: same layer type for both prefixes, forward one first
	init_layer, _ = get_layer(options['encoder'])
	for prefix in ('encoder', 'encoder_r'):
		params = init_layer(options, params, prefix=prefix,
							nin=options['dim_word'], dim=options['dim'])
	return params


def build_encoder(tparams, options):
	"""
	Build the symbolic encoder on top of pre-computed word embeddings.

	Returns (embedding, x_mask, ctx): the two symbolic inputs and the last
	element of the encoder layer's output sequence.
	"""
	# symbolic inputs: word embeddings (source) and their mask
	embedding = tensor.tensor3('embedding', dtype='float32')
	x_mask = tensor.matrix('x_mask', dtype='float32')

	# encoder
	_, feedforward = get_layer(options['encoder'])
	outputs = feedforward(tparams, embedding, options,
						  prefix='encoder', mask=x_mask)
	hidden_states = outputs[0]

	return embedding, x_mask, hidden_states[-1]


def build_encoder_bi(tparams, options):
	"""
	Build the symbolic bidirectional encoder on top of pre-computed word
	embeddings.

	Returns (embedding, x_mask, ctx), where ctx concatenates, along axis 1,
	the last output of the 'encoder' layer and the last output of the
	'encoder_r' layer, which reads the input reversed along its first axis.
	"""
	# symbolic inputs: word embeddings (source) and their mask
	embedding = tensor.tensor3('embedding', dtype='float32')
	x_mask = tensor.matrix('x_mask', dtype='float32')

	# reversed views of both inputs for the second direction
	embedding_reversed = embedding[::-1]
	mask_reversed = x_mask[::-1]

	# encoder
	_, feedforward = get_layer(options['encoder'])
	forward = feedforward(tparams, embedding, options,
						  prefix='encoder', mask=x_mask)
	backward = feedforward(tparams, embedding_reversed, options,
						   prefix='encoder_r', mask=mask_reversed)

	last_states = [forward[0][-1], backward[0][-1]]
	ctx = tensor.concatenate(last_states, axis=1)

	return embedding, x_mask, ctx


# some utilities
def ortho_weight(ndim):
	"""
	Return a random ndim x ndim orthogonal matrix as float32: the left
	singular vectors of a matrix of standard normal samples.
	"""
	gaussian = numpy.random.randn(ndim, ndim)
	left_vectors = numpy.linalg.svd(gaussian)[0]
	return left_vectors.astype('float32')


def norm_weight(nin, nout=None, scale=0.1, ortho=True):
	"""
	Return a random float32 weight matrix of shape (nin, nout).

	nin: number of rows
	nout: number of columns; defaults to nin
	scale: bound of the uniform distribution used for non-orthogonal init
	ortho: if True and the matrix is square, use ortho_weight instead

	Otherwise the entries are drawn uniformly from [-scale, scale).
	"""
	# identity test: None is a singleton, '==' is the wrong comparison for it
	if nout is None:
		nout = nin
	if nout == nin and ortho:
		W = ortho_weight(nin)
	else:
		W = numpy.random.uniform(low=-scale, high=scale, size=(nin, nout))
	return W.astype('float32')


def param_init_gru(options, params, prefix='gru', nin=None, dim=None):
	"""
	parameter init for GRU

	options: model options; only 'dim_proj' is read, as the fallback value
		for nin and dim
	params: dictionary that receives the new entries (modified in place)
	prefix: name prefix of the layer's parameters
	nin: input size
	dim: hidden size

	Adds the following entries, each named '<prefix>_<name>':
		W  (nin, 2*dim)   b  (2*dim,)   U  (dim, 2*dim)
		Wx (nin, dim)     bx (dim,)     Ux (dim, dim)

	Returns params.
	"""
	# identity tests: None is a singleton, '==' is the wrong comparison for it
	if nin is None:
		nin = options['dim_proj']
	if dim is None:
		dim = options['dim_proj']

	# two input-to-hidden blocks stored side by side, with their bias
	W = numpy.concatenate([norm_weight(nin,dim),
						   norm_weight(nin,dim)], axis=1)
	params[_p(prefix,'W')] = W
	params[_p(prefix,'b')] = numpy.zeros((2 * dim,)).astype('float32')
	# the matching hidden-to-hidden blocks
	U = numpy.concatenate([ortho_weight(dim),
						   ortho_weight(dim)], axis=1)
	params[_p(prefix,'U')] = U

	# weights of the candidate hidden state
	Wx = norm_weight(nin, dim)
	params[_p(prefix,'Wx')] = Wx
	Ux = ortho_weight(dim)
	params[_p(prefix,'Ux')] = Ux
	params[_p(prefix,'bx')] = numpy.zeros((dim,)).astype('float32')

	return params


def gru_layer(tparams, state_below, options, prefix='gru', mask=None, **kwargs):
	"""
	Forward pass through GRU layer

	tparams: mapping that holds the layer's shared variables under the names
		'<prefix>_W', '<prefix>_b', '<prefix>_U', '<prefix>_Wx',
		'<prefix>_Ux' and '<prefix>_bx'
	state_below: symbolic input; the recurrence runs along its first axis
	options: not read by this function
	prefix: name prefix of the layer's parameters
	mask: symbolic mask scanned together with the input; where it is 0 the
		previous hidden state is carried over. Defaults to all ones.

	Returns a one-element list containing the output of the scan.
	"""
	nsteps = state_below.shape[0]
	if state_below.ndim == 3:
		n_samples = state_below.shape[1]
	else:
		n_samples = 1

	dim = tparams[_p(prefix,'Ux')].shape[1]

	# 'is None' rather than '== None': the identity test does not depend on
	# how the mask object implements '=='
	if mask is None:
		mask = tensor.alloc(1., state_below.shape[0], 1)

	def _slice(_x, n, dim):
		# n-th block of width dim along the last axis
		if _x.ndim == 3:
			return _x[:, :, n*dim:(n+1)*dim]
		return _x[:, n*dim:(n+1)*dim]

	# input projections for all steps at once, computed outside the scan
	state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + tparams[_p(prefix, 'b')]
	state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) + tparams[_p(prefix, 'bx')]
	U = tparams[_p(prefix, 'U')]
	Ux = tparams[_p(prefix, 'Ux')]

	def _step_slice(m_, x_, xx_, h_, U, Ux):
		# m_: mask, x_ / xx_: projected inputs, h_: previous hidden state
		preact = tensor.dot(h_, U) + x_

		# the two blocks of preact give the gates r and u
		r = tensor.nnet.sigmoid(_slice(preact, 0, dim))
		u = tensor.nnet.sigmoid(_slice(preact, 1, dim))

		# candidate state: r scales the recurrent contribution
		preactx = tensor.dot(h_, Ux) * r + xx_
		h = tensor.tanh(preactx)

		# u mixes the previous state with the candidate
		h = u * h_ + (1. - u) * h
		# where the mask is 0 the previous state is kept
		h = m_[:,None] * h + (1. - m_)[:,None] * h_

		return h

	seqs = [mask, state_below_, state_belowx]

	# strict=True: the step function may only use what is passed to it, so
	# both recurrent matrices are handed over as non_sequences
	rval, _updates = theano.scan(_step_slice,
								 sequences=seqs,
								 outputs_info=[tensor.alloc(0., n_samples, dim)],
								 non_sequences=[U, Ux],
								 name=_p(prefix, '_layers'),
								 n_steps=nsteps,
								 profile=profile,
								 strict=True)
	return [rval]