import re
import json
import numpy as np
from tqdm import tqdm
import pdb
import os
import pickle
import string
import sys

def unpickle(p):
	# pickle files must be opened in binary mode
	with open(p, 'rb') as f:
		return pickle.load(f)

def load_json(p):
	with open(p) as f:
		return json.load(f)
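
# Expected annotation layout (an assumption, inferred from the indexing below):
# parallel lists of captions and image file names, e.g.
#   data = {'caption_entity': ['a man riding a horse', ...],
#           'file_name': ['COCO_train2014_000000000009.jpg', ...]}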

def clean_words(data):
	"""Collect a word -> first-seen-index map and a word -> frequency map
	over all captions of at most max_w words."""
	word_dict = {}
	freq = {}
	# start indices at 1 (0 is reserved for <BOS> later)
	idx = 1
	sentence_count = 0
	eliminate = 0
	max_w = 30
	for k in tqdm(range(len(data['caption_entity']))):
		sen = data['caption_entity'][k]
		# tokenize on spaces and drop empty tokens from repeated spaces
		words = [w for w in re.split(' ', sen) if w != ""]
		if len(words) <= max_w:
			sentence_count += 1
			for word in words:
				word = word.replace("\n", "")
				# strip punctuation characters
				for p in string.punctuation:
					word = word.replace(p, '')
				word = word.lower()
				if word not in word_dict:
					word_dict[word] = idx
					idx += 1
					freq[word] = 1
				else:
					freq[word] += 1
		else:
			eliminate += 1
	print('Threshold(max_words) =', max_w)
	print('Eliminate =', eliminate)
	print('Total sentence_count =', sentence_count)
	print('Number of different words =', len(word_dict))
	print('Saving....')
	np.savez('K_cleaned_words', dict=word_dict, freq=freq)
	return word_dict, freq
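
# Structure of the results (illustrative values only):
#   word_dict = {'a': 1, 'man': 2, ...}          # word -> first-seen index
#   freq      = {'a': 51234, 'man': 8720, ...}   # word -> occurrence count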

# phase: 'train' or 'val'; selects the annotation file and the OOV policy below
phase = sys.argv[1]
data_path = './mscoco_data/K_annotation_'+phase+'2014.pkl'
data = unpickle(data_path)
thres = 5
if not os.path.isfile('./mscoco_data/dictionary_'+str(thres)+'.npz'):
	# build the vocabulary, keeping only words at or above the frequency threshold
	if not os.path.isfile('K_cleaned_words.npz'):
		word_dict, freq = clean_words(data)
	else:
		# allow_pickle is required to load dict objects with recent numpy
		words = np.load('K_cleaned_words.npz', allow_pickle=True)
		word_dict = words['dict'].item(0)
		freq = words['freq'].item(0)
	idx2word = {}
	word2idx = {}
	idx = 1
	for k in tqdm(word_dict):
		if freq[k] >= thres and k != "":
			word2idx[k] = idx
			idx2word[str(idx)] = k
			idx += 1

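	# Final index layout: 0 = <BOS>, 1..V = vocabulary words, V+1 = <EOS>,
	# V+2 = <UNK>, V+3 = <NOT> (padding), where V = len(word2idx) at this point.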
	word2idx['<BOS>'] = 0
	idx2word["0"] = '<BOS>'
	word2idx['<EOS>'] = len(word2idx)
	idx2word[str(len(idx2word))] = '<EOS>'
	word2idx['<UNK>'] = len(word2idx)
	idx2word[str(len(idx2word))] = '<UNK>'
	word2idx['<NOT>'] = len(word2idx)
	idx2word[str(len(idx2word))] = '<NOT>'
	print('Threshold of word frequency =', thres)
	print('Total words in the dictionary =', len(word2idx))
	np.savez('./mscoco_data/dictionary_'+str(thres), word2idx=word2idx, idx2word=idx2word)
else:
	tem = np.load('./mscoco_data/dictionary_'+str(thres)+'.npz', allow_pickle=True)
	word2idx = tem['word2idx'].item(0)
	idx2word = tem['idx2word'].item(0)
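
# Illustrative helper (not part of the original pipeline): map a token id
# vector back to words via idx2word, skipping the special tokens.
def decode(token_ids):
	words = [idx2word[str(int(i))] for i in token_ids]
	return ' '.join(w for w in words if w not in ('<BOS>', '<EOS>', '<NOT>'))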

num_sentence = 0
eliminate = 0
tokenized_caption_list = []
caption_list = []
filename_list = []
caption_length = []
for k in tqdm(range(len(data['caption_entity']))):
	sen = data['caption_entity'][k]
	filename = data['file_name'][k]
	# tokenize on spaces and drop empty tokens from repeated spaces
	words = [w for w in re.split(' ', sen) if w != ""]
	# fixed-length buffer: up to 30 word ids plus one <EOS>, padded with <NOT>
	tokenized_sent = np.zeros([30 + 1], dtype=int)
	tokenized_sent.fill(int(word2idx['<NOT>']))
	#tokenized_sent[0] = int(word2idx[u'<BOS>'])
	valid = True
	count = 0
	caption = []
	
	if len(words) <= 30:
		for word in words:
			try:
				word = word.lower()
				# mirror the cleaning applied in clean_words
				word = word.replace("\n", "")
				for p in string.punctuation:
					word = word.replace(p, '')
				if word != "":
					idx = int(word2idx[word])
					tokenized_sent[count] = idx
					caption.append(word)
					count += 1
			except KeyError:
				# out-of-vocabulary word: drop the whole sentence when
				# training, substitute <UNK> otherwise
				if phase == 'train':
					valid = False
					break
				else:
					tokenized_sent[count] = int(word2idx[u'<UNK>'])
					count += 1
		if valid:
			tokenized_sent[count] = word2idx['<EOS>']
			caption_list.append(caption)
			# caption length = word tokens + the <EOS> marker
			length = count + 1
			tokenized_caption_list.append(tokenized_sent)
			filename_list.append(filename)
			caption_length.append(length)
			num_sentence += 1
		else:
			# should be unreachable for 'val': OOV words map to <UNK> above
			if phase == 'val':
				pdb.set_trace()
			eliminate += 1
tokenized_caption_info = {}
tokenized_caption_info['caption_length'] = np.asarray(caption_length)
tokenized_caption_info['tokenized_caption_list'] = np.asarray(tokenized_caption_list)
# captions vary in length, so force an object array
tokenized_caption_info['caption_list'] = np.asarray(caption_list, dtype=object)
tokenized_caption_info['filename_list'] = np.asarray(filename_list)
print('Number of sentences =', num_sentence)
# pickle requires binary mode
with open('./mscoco_data/tokenized_'+phase+'_caption.pkl', 'wb') as outfile:
	pickle.dump(tokenized_caption_info, outfile)
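
# Reading the result back (a sketch; the path mirrors the dump above):
#   with open('./mscoco_data/tokenized_' + phase + '_caption.pkl', 'rb') as f:
#       info = pickle.load(f)
#   info['tokenized_caption_list'][i] is a length-31 int vector:
#   word ids, then <EOS>, then <NOT> padding.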