# TODO: 1. make PADs and split the big string into sentences.
#       2. make batches first and then split (or maybe not, since the validation and test sets do not need batches)

from nltk.tokenize import word_tokenize
import collections
import itertools
import logging
import re
import os
import numpy as np
import time
import pickle
from utils import set_logger

class Data(object):
	def __init__(self, dataPath, savePath, paramSavePath, logPath, debug, split_percent, batch_size, timestr, timestep, window):
		'''
			* dataPath: path to the data file. We have two data files.
				One is the full-size dataset described in the paper.
				The other is a much smaller dataset of 100 sentences
				drawn from the arXiv and book datasets, used for early code tests.
			* debug: whether we are testing the code or doing a real training run.
				default: debug = True, i.e. code-testing mode.
			* split_percent: fractions for the training / validation / test split,
				e.g. (0.8, 0.1, 0.1).
		'''
		self.debug = debug
		self.savePath = savePath
		self.dataPath = dataPath if not self.debug else '../data/data_pre.txt'
		self.paramSavePath = paramSavePath
		self.logger = set_logger(logPath, timestr, os.path.basename(__file__))
		self.split_percent = split_percent
		self.timestep = timestep
		self.window = window
		self.load_data()
	#   self.data: list of all sentences (lines) in the data file
	#   self.sentSize: number of sentences
		self.clean_str()
		self.word2num()
	#   self.mapToNum: word -> index map; a word's index is self.mapToNum['word']
	#   self.mapToWord: index -> word map (the inverse of self.mapToNum)
	#   self.dataNum: self.data with every word replaced by its index (np.ndarray)
	#   self.vocabSize: vocabulary size
		self.split_tvt()
	#   self.train: training set
	#   self.validation: validation set
	#   self.test: test set
	#   self.shift(): move the first 10% of self.dataNum to the end and split the sets again.
		self.batch_size = batch_size if not self.debug else 10
	def load_data(self):
		'''	Load data from self.dataPath into a list of lines, one sentence per line.
		'''
		try:
			with open(self.dataPath) as f:
				self.data = f.read().splitlines()
			self.sentSize = len(self.data)
			if self.debug:
				self.logger.info('load_data finished')
		except OSError:
			msg = 'File does not exist.\n' + \
				'Sorry, the dataset used here is protected under copyright.\n\n' + \
				'If you would like to use the dataset, ' + \
				'please kindly read the README.md under the data folder.'
			self.logger.error(msg)
			raise Exception(msg)
	
	def clean_str(self):
		"""	Tokenization/string cleaning for all datasets except for SST.
			https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
			It seems that they usually use this function to do some cleaning,
			though I really don't know why. Hope I can figure it out later.
		"""
		for i in range(self.sentSize):
			string = self.data[i]
			string = re.sub(r"[^A-Za-z0-9(),!?\'\`_]", " ", string)
			string = re.sub(r"\'s", " \'s", string)
			string = re.sub(r"\'ve", " \'ve", string)
			string = re.sub(r"n\'t", " n\'t", string)
			string = re.sub(r"\'re", " \'re", string)
			string = re.sub(r"\'d", " \'d", string)
			string = re.sub(r"\'ll", " \'ll", string)
			string = re.sub(r",", " , ", string)
			string = re.sub(r"!", " ! ", string)
			string = re.sub(r"\(", " ( ", string)
			string = re.sub(r"\)", " ) ", string)
			string = re.sub(r"\?", " ? ", string)
			string = re.sub(r"\s{2,}", " ", string)
			string = re.sub("__LaTex__", "", string) # consider how to deal with latex later.
			# keep some special tokens from being converted to lowercase.
			specialwords = ['EOS', 'GOO', '__LaTex__']
			toLower = lambda x: " ".join( a if a in specialwords else a.lower() \
				for a in x.split() )
			self.data[i] = toLower(string.strip())
		if self.debug:
			self.logger.info('clean_str finished')
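
	# Illustrative example only (given the rules above): a raw line such as
	#   "We've shown EOS that f(x) isn't convex, right?"
	# would come out roughly as
	#   "we 've shown EOS that f ( x ) is n't convex , right ?"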
	
	def word2num(self):
		''' Map each word to a unique integer index.
			Here we index by frequency rank, though the particular
			indexing scheme does not really matter.
		'''
		# Drop sentences with more than self.timestep words;
		# padding everything to the longest sentence would add too many PADs.
		sentenceSplit = [word_tokenize(x) for x in self.data]
		#sentenceLength = np.asarray([len(x) for x in sentenceSplit])
		# Pad each sentence with 'PAD' at the end up to self.timestep words,
		# then add max(self.window) - 1 'PAD's before and after the whole sentence.
		sentenceSplitPAD = []
		for i in range(len(sentenceSplit)):
			if self.timestep - len(sentenceSplit[i]) >= 0:
				sentenceSplitPAD.append(['PAD'] * (max(self.window) - 1) + sentenceSplit[i] + \
										['PAD'] * (self.timestep - len(sentenceSplit[i]) + max(self.window) - 1))
		#self.dataArr = np.asarray(sentenceSplitPAD)
		# Dropping the few over-length sentences should not change the word
		# counts significantly, so frequencies are counted over the padded,
		# filtered sentences below.
		words = list(itertools.chain(*sentenceSplitPAD))
		a = collections.Counter(words)
		# add some symbols. see url below for details
		# https://github.com/nicolas-ivanov/tf_seq2seq_chatbot/issues/15
		a['UNK'] = 1
		# a['PAD'] = 1
		self.vocabSize = len(a)
		b = a.most_common(self.vocabSize)
		self.mapToNum = collections.defaultdict(list)
		self.mapToWord = collections.defaultdict(list)
		for i, (k, _) in enumerate(b):
			self.mapToNum[k].append(i)
			self.mapToWord[i].append(k)
		self.dataNum = []
		for sentence in sentenceSplitPAD:
			sentenceNum = []
			for word in sentence:
				if word in self.mapToNum:
					sentenceNum.extend(self.mapToNum[word])
				else:
					sentenceNum.extend(self.mapToNum['UNK'])
			self.dataNum.append(sentenceNum)
		self.dataNum = np.asarray(self.dataNum)
		self.sentSize = len(self.dataNum)
		# save dataNum, in case we may use it later.
		fileName = 'dataNum_' + time.strftime("%Y%m%d_%H%M%S")
		np.save(self.savePath + fileName, self.dataNum)
		self.logger.info("'dataNum' saved to " + self.savePath + fileName)
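
	# Illustrative note on the resulting structures (actual indices depend on corpus
	# frequencies): because of the defaultdict(list) pattern each value is a one-element
	# list, e.g. self.mapToNum['the'] might be [0] and self.mapToWord[0] ['the'].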
	
	def split_tvt(self, shift=False):
		'''	Split the data into training, validation, and test sets.
			* shift: if True, move the first 10% of self.dataNum to the end and split again.
		'''
		if shift:
			self.dataNum = np.concatenate((self.dataNum[int(self.sentSize * 0.1):],
											self.dataNum[:int(self.sentSize * 0.1)]))
		self.train = self.dataNum[:int(self.sentSize * self.split_percent[0])]
		self.validation = self.dataNum[int(self.sentSize * self.split_percent[0]):
				int(self.sentSize * (self.split_percent[0] + self.split_percent[1]))]
		self.test = self.dataNum[int(self.sentSize * (self.split_percent[0] + self.split_percent[1])):]
		if self.debug:
			self.logger.info('split_tvt finished')
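
	# For example, with split_percent = (0.8, 0.1, 0.1) (illustrative values) the first 80%
	# of sentences form the training set, the next 10% the validation set, and the last 10%
	# the test set.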

	def shift(self):
		'''	A convenience wrapper around split_tvt with shift=True.
			Moves the first 10% of self.dataNum to the end and splits the sets again.
		'''
		self.split_tvt(shift=True)

	def get_first_batch(self, whichSet='train'):
		'''	Return the first batch of the chosen set and reset its batch counter.
			whichSet: one of 'train', 'validation', 'test'
		'''
		if whichSet == 'train':
			self.trainBatchCnt = 0
			self.trainMaxBatch = int(len(self.train) / self.batch_size)
			return self.train[:self.batch_size]
		elif whichSet == 'validation':
			self.validationBatchCnt = 0
			self.validationMaxBatch = int(len(self.validation) / self.batch_size)
			return self.validation[:self.batch_size]
		elif whichSet == 'test':
			self.testBatchCnt = 0
			self.testMaxBatch = int(len(self.test) / self.batch_size)
			return self.test[:self.batch_size]
		else:
			msg = 'Wrong set name!\n'+ \
				  'Should be train / validation / test.'
			raise Exception(msg)

	def next_batch(self, whichSet='train'):
		'''	Return the next batch of the chosen set; call get_first_batch first.
			whichSet: one of 'train', 'validation', 'test'
		'''
		if whichSet == 'train':
			self.trainBatchCnt += 1
			assert self.trainBatchCnt < self.trainMaxBatch
			return self.train[self.trainBatchCnt * self.batch_size: (self.trainBatchCnt + 1) * self.batch_size]
		elif whichSet == 'validation':
			self.validationBatchCnt += 1
			assert self.validationBatchCnt < self.validationMaxBatch
			return self.validation[self.validationBatchCnt * self.batch_size: (self.validationBatchCnt + 1) * self.batch_size]
		elif whichSet == 'test':
			self.testBatchCnt += 1
			assert self.testBatchCnt < self.testMaxBatch
			return self.test[self.testBatchCnt * self.batch_size: (self.testBatchCnt + 1) * self.batch_size]
		else:
			msg = 'Wrong set name!\n'+ \
				  'Should be train / validation / test.'
			raise Exception(msg)
	# The save/load helpers below are adapted from:
	# https://stackoverflow.com/questions/17219481/save-to-file-and-load-an-instance-of-a-python-class-with-its-attributes
	def save(self, fileName):
		'''	Pickle this Data instance to self.paramSavePath + fileName.
		'''
		assert fileName is not None

		with open(self.paramSavePath + fileName, 'wb') as f:
			pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
	
	@staticmethod
	def load(paramSavePath, fileName):
		'''	Load a pickled Data instance saved by save(). As a static method it
			cannot read self.paramSavePath, so the save directory is passed in.
		'''
		assert fileName is not None

		with open(paramSavePath + fileName, 'rb') as f:
			return pickle.load(f)
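

# Minimal usage sketch (not part of the original pipeline): the paths, split fractions,
# batch size, timestep, and window below are placeholder values chosen only to show how
# the class is meant to be driven; adjust them to the real configuration before running.
if __name__ == '__main__':
	timestr = time.strftime("%Y%m%d_%H%M%S")
	data = Data(dataPath='../data/data.txt', savePath='../save/', paramSavePath='../params/',
				logPath='../log/', debug=True, split_percent=(0.8, 0.1, 0.1),
				batch_size=10, timestr=timestr, timestep=50, window=(1, 2, 3))
	# walk through the training set one batch at a time
	batch = data.get_first_batch('train')
	for _ in range(data.trainMaxBatch - 1):
		batch = data.next_batch('train')
	# persist the processed object so preprocessing does not have to be redone
	data.save('data_' + timestr + '.pkl')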