from keras.models import Sequential
from keras.layers import Dense, LSTM, GRU, Bidirectional
from keras.regularizers import l1, l2, l1_l2
from keras.optimizers import Adam, SGD
from .useful import *
from .Configs import get_all

class RecurrentModel():
	"""Recurrent neural network class with model generator from dictionary of parameters"""
	def __init__(self, x_train, y_train, parameters):
		self.current_params = parameters
		self.x_train = x_train
		self.output_dim = len(y_train[0]) # y_train is expected to be one-hot encoded
		# every parameter listed in Configs.get_all() must be supplied
		for key in get_all():
			if key not in parameters:
				raise KeyError("missing parameter: {}".format(key))


	def get_model(self):
		"""Builds and compiles a model from the stored training data and parameter dictionary"""
		dims = tuple(self.x_train.shape[1:]) # drop the dataset-size dimension, keep e.g. (timesteps, features)
		model = Sequential()

		# input layer: only return sequences if further recurrent layers follow
		return_sequences = self.current_params["depth"] > 1
		model.add(self._input_layer(dims, return_sequences))

		# hidden layers
		if self.current_params["depth"] > 2:
			for _ in range(self.current_params["depth"] - 2):
				model.add(self._hidden_layer(True))
		if self.current_params["depth"] >= 2:
			model.add(self._hidden_layer(False))

		# output layer
		model.add(self._output_layer(self.output_dim))

		if self.current_params["optimiser"] == "adam":
			opt = Adam() # use default learning rate for adam
		elif self.current_params["optimiser"] == "sgd":
			opt = SGD(lr=self.current_params["learning_rate"])
		else:
			raise ValueError("optimiser must be 'adam' or 'sgd', not {}".format(self.current_params["optimiser"]))

		model.compile(
			optimizer=opt, # pass the configured optimiser object, not the bare string
			loss=self.current_params["loss"],
			metrics=['acc']
		)

		return model

	def _generate_regulariser(self, l1_value, l2_value):
		""" Returns keras l1/l2 regulariser"""
		if l1_value and l2_value:
			return l1_l2(l1=l1_value, l2=l2_value)
		elif l1_value:
			return l1(l1_value)
		elif l2_value:
			return l2(l2_value)
		return None


	def _input_layer(self, dims, return_sequences):
		""" Returns a GRU or LSTM input layer, optionally wrapped in a Bidirectional layer """
		if self.current_params["bidirectional"]:
			return Bidirectional(self._middle_hidden_layer(return_sequences), input_shape=dims)

		layer_class = GRU if self.current_params["layer_type"] == "GRU" else LSTM
		return layer_class(self.current_params["hidden_neurons"],
			input_shape=dims,
			return_sequences=return_sequences,
			kernel_initializer=self.current_params["kernel_initializer"],
			recurrent_initializer=self.current_params["recurrent_initializer"],
			recurrent_regularizer=self._generate_regulariser(self.current_params["r_l1_reg"], self.current_params["r_l2_reg"]),
			bias_regularizer=self._generate_regulariser(self.current_params["b_l1_reg"], self.current_params["b_l2_reg"]),
			dropout=self.current_params["dropout"],
			recurrent_dropout=self.current_params["recurrent_dropout"]
		)

	def _hidden_layer(self, return_sequences):
		""" Returns a GRU or LSTM hidden layer, optionally wrapped in a Bidirectional layer """
		layer = self._middle_hidden_layer(return_sequences)
		if self.current_params["bidirectional"]:
			return Bidirectional(layer)
		return layer

	def _middle_hidden_layer(self, return_sequences):
		""" Returns a GRU or LSTM layer without an input shape, used for hidden layers and bidirectional wrapping """
		layer_class = GRU if self.current_params["layer_type"] == "GRU" else LSTM
		return layer_class(self.current_params["hidden_neurons"],
			return_sequences=return_sequences,
			kernel_initializer=self.current_params["kernel_initializer"],
			recurrent_initializer=self.current_params["recurrent_initializer"],
			recurrent_regularizer=self._generate_regulariser(self.current_params["r_l1_reg"], self.current_params["r_l2_reg"]),
			bias_regularizer=self._generate_regulariser(self.current_params["b_l1_reg"], self.current_params["b_l2_reg"]),
			dropout=self.current_params["dropout"],
			recurrent_dropout=self.current_params["recurrent_dropout"]
		)

	def _output_layer(self, possible_classes):
		""" Returns output layer of feed-forward neurons """
		return Dense(
			possible_classes,
			activation=self.current_params["activation"],
		)

class FFNN(RecurrentModel):
	"""Feed-forward network built from the same parameter dictionary, overriding the layer builders with Dense layers"""

	def _input_layer(self, dims, return_sequences):
		""" Returns a dense input layer; assumes x_train is 2-D (samples, features), return_sequences is ignored """
		return Dense(self.current_params["hidden_neurons"],
			input_shape=dims,
			activation=self.current_params["activation"],
			kernel_initializer=self.current_params["kernel_initializer"]
		)

	def _hidden_layer(self, return_sequences):
		""" Returns a dense hidden layer; return_sequences is ignored """
		return self._middle_hidden_layer()

	def _middle_hidden_layer(self):
		# Dense has no dropout argument; dropout would need a separate Dropout layer
		return Dense(self.current_params["hidden_neurons"],
			activation=self.current_params["activation"],
			kernel_initializer=self.current_params["kernel_initializer"]
		)

	def _output_layer(self, possible_classes):
		""" Returns output layer of feed-forward neurons """
		return Dense(
			possible_classes,
			activation=self.current_params["activation"],
		)


def generate_model(x, y, params, model_type="rnn"):
	"""Returns a new compiled model built from training data x,
	one-hot training labels y and a parameter dictionary"""
	if model_type == "rnn":
		model_gen = RecurrentModel(x, y, params)
	elif model_type == "ffnn":
		model_gen = FFNN(x, y, params)
	else:
		raise ValueError("model_type must be 'rnn' or 'ffnn', not {}".format(model_type))
	return model_gen.get_model()
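

# Minimal usage sketch (an assumption, not part of the module's public code): the keys
# below are the ones this module actually reads from the parameter dictionary, but the
# full required set is whatever Configs.get_all() returns, so extend the dict to match.
# The relative imports mean this only runs when executed as part of the package.
if __name__ == "__main__":
	import numpy as np

	example_params = {
		"depth": 2,                        # total number of recurrent layers
		"hidden_neurons": 32,
		"layer_type": "LSTM",              # "LSTM" or "GRU"
		"bidirectional": False,
		"kernel_initializer": "glorot_uniform",
		"recurrent_initializer": "orthogonal",
		"r_l1_reg": 0.0, "r_l2_reg": 0.0,  # recurrent-weight regularisation
		"b_l1_reg": 0.0, "b_l2_reg": 0.0,  # bias regularisation
		"dropout": 0.1,
		"recurrent_dropout": 0.1,
		"optimiser": "adam",
		"learning_rate": 0.001,            # only read when optimiser is "sgd"
		"loss": "categorical_crossentropy",
		"activation": "softmax",
	}

	# toy data: 100 sequences of 20 timesteps with 8 features, 3 one-hot classes
	x = np.random.rand(100, 20, 8)
	y = np.eye(3)[np.random.randint(0, 3, 100)]

	model = generate_model(x, y, example_params, model_type="rnn")
	model.summary()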