from sklearn.metrics import mean_squared_error
from .models import make_model
import numpy as np
import math
from keras.optimizers import Adam
from keras.models import clone_model
import keras
import keras.backend as K

def fine_tuning_loss(y_true, y_pred):
    """SURE-style fine-tuning loss.

    The network predicts a per-pixel affine map (scale `a`, offset `b`) in
    y_pred channels 0 and 1. y_true carries the noisy image in channel 1 and
    the current sigma guess in channel 2. The loss is the squared residual of
    the affine denoising plus a divergence term driven by sigma.
    """
    a = y_pred[:, :, :, 0]          # predicted per-pixel scale
    b = y_pred[:, :, :, 1]          # predicted per-pixel offset
    noisy = y_true[:, :, :, 1]      # noisy input image
    sigma = y_true[:, :, :, 2]      # current noise-level guess
    residual = noisy - (a * noisy + b)
    sigma_sq = K.square(sigma)
    return K.mean(K.square(residual) + 2 * a * sigma_sq - sigma_sq)

# Candidate noise levels for the binary search: 0, 5, 10, ..., 95.
sigma_arr = np.arange(20) * 5

class Fine_tuning:
    """Estimate the noise level of a single noisy image.

    Runs a binary search over the candidate sigmas in `sigma_arr`: for each
    midpoint, a fresh clone of the sigma-estimation network is fine-tuned on
    the image, and the search direction is chosen by whether the training
    loss went negative (see `Save_result`).
    """

    def __init__(self, noisy_image, ep):
        # Normalise pixel values to [0, 1] floats.
        self.noisy_img = np.float32(noisy_image) / 255.

        self.img_x, self.img_y = self.noisy_img.shape[0], self.noisy_img.shape[1]

        self.ep = ep                  # epochs of fine-tuning per candidate sigma
        self.mini_batch_size = 1      # single-image training

        # Template network; cloned once per candidate sigma in get_model().
        self.model_copy = make_model(self.img_x, self.img_y)

    def preprocessing(self, sigma_hat):
        # Standardise the image (shift 0.5, scale 0.2) as the network input.
        standardised = (self.noisy_img - 0.5) / 0.2
        self.X_data = standardised.reshape(1, self.img_x, self.img_y, 1)

        # Target tensor: channel 1 = noisy image, channel 2 = sigma guess.
        # Channel 0 is unused here (reserved for a clean reference image).
        self.Y_data = np.zeros((1, self.img_x, self.img_y, 3))
        self.Y_data[:, :, :, 1] = self.noisy_img
        self.Y_data[:, :, :, 2] = sigma_hat / 255.

    def get_model(self):
        # Fresh clone so every candidate sigma starts from identical weights.
        net = clone_model(self.model_copy)
        net.load_weights('./weights/sigma_estimation_model.hdf5')
        optimiser = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0)
        net.compile(loss=fine_tuning_loss, optimizer=optimiser)
        return net

    def estimation(self):
        # Binary search over indices into sigma_arr.
        lo, hi = 0, 20

        while True:
            monitor = Save_result()

            mid = (lo + hi) // 2
            sigma_hat = sigma_arr[mid]
            self.preprocessing(sigma_hat)

            print ('current sigma_hat : ', sigma_hat)
            print ('')

            net = self.get_model()
            net.fit(self.X_data, self.Y_data, verbose=0,
                    batch_size=self.mini_batch_size, epochs=self.ep,
                    callbacks=[monitor])
            went_negative = monitor.get_result()
            del net

            # Negative loss means the sigma guess was too high.
            if went_negative:
                hi = mid
            else:
                lo = mid

            # Converged once the midpoint stops moving.
            if mid == (lo + hi) // 2:
                return sigma_arr[mid]

class Save_result(keras.callbacks.Callback):
    """Keras callback that stops training as soon as the loss turns negative.

    A negative fine-tuning loss signals that the current sigma guess is too
    high; `Fine_tuning.estimation` reads the flag via `get_result()` to steer
    its binary search.
    """

    def __init__(self):
        super().__init__()  # required so the base Callback initialises itself
        # True once a negative loss has been observed during this fit().
        self.loss_less_than_zero = False

    def on_epoch_end(self, epoch, logs=None):
        # Avoid a mutable default argument; Keras may pass logs=None.
        logs = logs or {}
        current_loss = logs.get('loss')
        # Guard against a missing 'loss' entry, which would make the
        # comparison raise a TypeError.
        if current_loss is not None and current_loss < 0:
            print("Epoch %05d: early stopping, Loss < 0" % epoch)
            self.model.stop_training = True
            self.loss_less_than_zero = True

    def get_result(self):
        """Return True if any epoch ended with a negative loss."""
        return self.loss_less_than_zero