# -*- coding: utf-8 -*-
"""
Created on Sun Apr 28 18:32:15 2019

@author: wmy
"""

import os
import random

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.losses import mean_absolute_error, mean_squared_error
from keras.models import load_model
from keras.optimizers import Adam
from PIL import Image

from model import wdsr_a, wdsr_b
from optimizer import AdamWithWeightsNormalization
from utils import DataLoader

class SuperResolution(object):
    """WDSR-B super-resolution model bundled with its training and sampling loop.

    The constructor builds the network with ``wdsr_b``, compiles it with
    ``AdamWithWeightsNormalization`` (lr=0.001) using :meth:`mae` as the loss
    and :meth:`psnr` as the metric, optionally loads pretrained weights and
    creates a ``DataLoader`` with ``crop_size=256``.
    """

    def __init__(self, scale=4, num_res_blocks=32, pretrained_weights=None, name=None):
        """Build, compile and optionally restore the model.

        Args:
            scale: Scale factor forwarded to ``wdsr_b`` and ``DataLoader``.
            num_res_blocks: Number of residual blocks forwarded to ``wdsr_b``.
            pretrained_weights: Optional path of a weights file loaded into
                the compiled model; ``None`` skips loading.
            name: Optional label stored on the instance; this class does not
                read it anywhere else.
        """
        self.scale = scale
        self.num_res_blocks = num_res_blocks
        self.model = wdsr_b(scale=scale, num_res_blocks=num_res_blocks)
        self.model.compile(optimizer=AdamWithWeightsNormalization(lr=0.001),
                           loss=self.mae, metrics=[self.psnr])
        if pretrained_weights is not None:
            self.model.load_weights(pretrained_weights)
            print("[OK] weights loaded.")
        self.data_loader = DataLoader(scale=scale, crop_size=256)
        self.pretrained_weights = pretrained_weights
        # Used by train() whenever no explicit save path is given.
        self.default_weights_save_path = 'weights/wdsr-b-{}-x{}.h5'.format(
            self.num_res_blocks, self.scale)
        self.name = name

    @staticmethod
    def _crop_hr(hr, sr):
        """Return the target tensor that ``sr`` is compared against.

        In the training phase ``hr`` is cropped symmetrically along axes 1
        and 2 by ``margin``; outside the training phase ``hr`` is returned
        untouched. ``margin`` is derived from axis 1 only and applied to both
        axes. Both arguments are indexed as 4-D tensors (presumably
        batch, height, width, channels -- confirm against the data loader).
        """
        margin = (tf.shape(hr)[1] - tf.shape(sr)[1]) // 2
        # A zero margin needs its own branch: hr[:, 0:-0, ...] would be empty.
        cropped = tf.cond(tf.equal(margin, 0),
                          lambda: hr,
                          lambda: hr[:, margin:-margin, margin:-margin, :])
        target = K.in_train_phase(cropped, hr)
        # Tell Keras this tensor depends on the learning-phase flag.
        target.uses_learning_phase = True
        return target

    @staticmethod
    def _sample_path(save_folder, epoch, batch, tag):
        """Return ``<save_folder>/epoch_<epoch>_batch_<batch>_<tag>.jpg``."""
        return "{}/epoch_{}_batch_{}_{}.jpg".format(save_folder, epoch, batch, tag)

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` if it is missing."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def mae(self, hr, sr):
        """Loss: mean absolute error between the (phase-dependent) target and ``sr``."""
        return mean_absolute_error(self._crop_hr(hr, sr), sr)

    def psnr(self, hr, sr):
        """Metric: PSNR between the (phase-dependent) target and ``sr``, with max_val=255."""
        return tf.image.psnr(self._crop_hr(hr, sr), sr, max_val=255)

    def train(self, epoches=10000, batch_size=8, weights_save_path=None):
        """Train batch by batch, sampling periodically and saving every epoch.

        Args:
            epoches: Number of passes over ``self.data_loader.batches(...)``.
            batch_size: Batch size forwarded to the data loader.
            weights_save_path: Where weights are written after each epoch;
                ``None`` falls back to ``self.default_weights_save_path``.
        """
        if weights_save_path is None:
            weights_save_path = self.default_weights_save_path
        # Create the target directory up front so a missing folder is not
        # discovered only after a whole epoch of training.
        self._ensure_parent_dir(weights_save_path)
        for epoch in range(epoches):
            batches = self.data_loader.batches(batch_size=batch_size)
            for batch_i, (lrs, hrs) in enumerate(batches):
                temp_loss, temp_psnr = self.model.train_on_batch(lrs, hrs)
                print("[epoch: {}/{}][batch: {}/{}][loss: {}][psnr: {}]".format(
                    epoch+1, epoches, batch_i+1, self.data_loader.n_batches,
                    temp_loss, temp_psnr))
                # Dump a visual sample every 25 batches.
                if (batch_i+1) % 25 == 0:
                    self.sample(epoch=epoch+1, batch=batch_i+1)
            self.model.save_weights(weights_save_path)
            print("[OK] weights saved.")

    def sample(self, setpath='datasets/train', save_folder='samples', epoch=1, batch=1):
        """Save an LR / SR / HR triple for one randomly chosen image.

        Args:
            setpath: Folder handed to ``self.data_loader.search``.
            save_folder: Output folder; created if it does not exist.
            epoch: Epoch number embedded in the output file names.
            batch: Batch number embedded in the output file names.

        The image objects come from the data loader; they are used through
        ``resize``, ``size`` and ``save`` (presumably PIL images -- confirm
        against ``DataLoader``).
        """
        images = self.data_loader.search(setpath)
        image = random.choice(images)
        hr = self.data_loader.imread(image)
        lr = self.data_loader.downsampling(hr)
        # The low-res image is resized to the HR size for the saved preview.
        lr_resize = lr.resize(hr.size)
        lr_array = np.asarray(lr)
        # predict() works on batches: wrap the single image, unwrap the result.
        sr = self.model.predict(np.array([lr_array]))[0]
        sr = np.clip(sr, 0, 255).astype('uint8')
        sr = Image.fromarray(sr)
        if save_folder:
            os.makedirs(save_folder, exist_ok=True)
        lr_resize.save(self._sample_path(save_folder, epoch, batch, 'lr'))
        sr.save(self._sample_path(save_folder, epoch, batch, 'sr'))
        hr.save(self._sample_path(save_folder, epoch, batch, 'hr'))


# Resume training from previously saved weights. train() is called without a
# save path, so each epoch writes to the instance's default path, which for
# the default scale / block count is 'weights/wdsr-b-32-x4.h5'.
PRETRAINED_WEIGHTS = './weights/wdsr-b-32-x4.h5'

sr = SuperResolution(pretrained_weights=PRETRAINED_WEIGHTS)
sr.train()