from PyQt5.QtWidgets import * #QWidget, QApplication
from PyQt5.QtGui import * #QPainter, QPainterPath
from PyQt5.QtCore import * #Qt
import sys
from PIL import Image
import numpy as np
import time
import cv2
import torch
from ui_shadow_draw.ui_sketch import UISketch
from ui_shadow_draw.ui_recorder import UIRecorder
import qdarkstyle
from ui_shadow_draw.gangate_draw import GANGATEDraw
from ui_shadow_draw.gangate_vis import GANGATEVis
from data.base_dataset import get_transform


import time
import os
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from util import html
from util import util

# Parse the test-time options, then pin the settings this interactive demo
# relies on before the transform and the model are built from them.
opt = TestOptions().parse()
for _opt_name, _opt_value in (
        ('nThreads', 1),           # test code only supports nThreads = 1
        ('batchSize', 1),          # test code only supports batchSize = 1
        ('serial_batches', True),  # no shuffle
        ('no_flip', True),         # no flip
        ('loadSize', 256),
):
    setattr(opt, _opt_name, _opt_value)
del _opt_name, _opt_value  # keep the loop helpers out of the module namespace

# Both are module-level globals used by GANGATEGui below.
transform = get_transform(opt)
model = create_model(opt)

class GANGATEGui(QWidget):
    """Main window: a drawing pad, a gallery of generations and the controls.

    The widget wires a ``GANGATEDraw`` pad and a ``GANGATEVis`` gallery to the
    module-level ``opt``, ``transform`` and ``model`` objects.  Images are
    exchanged with the rest of the program through files in ``./imgs``
    (relative to the current working directory).

    Attributes set in ``__init__``:
        win_size, img_size: sizes forwarded to the two child widgets.
        labelId: integer id of the selected object class
            (0 Bicycle, 1 Cat, 2 Chair, 3 Hamburger, 4 Pizza, 5 Teddy).
        enable_shadow: whether ``generate`` shows the blended shadow image.
        which_shadow_img: index of the single generation shown when shadows
            are disabled.
    """

    # Spatial size the scribble tensor is reshaped to before it is handed to
    # the model (see get_network_input).
    _NET_INPUT_SIZE = 128
    # Directory the intermediate images are written to.
    _IMAGE_DIR = './imgs'

    def __init__(self, win_size=384, img_size=384):
        """Build the widgets, lay them out, set defaults and connect signals.

        Args:
            win_size: forwarded to both child widgets and used as their fixed
                width and height.
            img_size: forwarded to both child widgets; also the size every
                image is resized to in ``browse`` and ``get_paste_image``.
        """
        QWidget.__init__(self)

        self.win_size = win_size
        self.img_size = img_size

        self._build_widgets()
        self._build_layout()

        # Initial selection: "Teddy" class, drawing mode, shadows enabled.
        # Done after the layout is installed, as in the original flow.
        self.bTeddy.setChecked(True)
        self.labelId = 5
        self.bDrawStroke.setChecked(True)

        self.enable_shadow = True
        self.which_shadow_img = 0

        # Connected last, so the programmatic state changes above do not
        # trigger a generation.
        self._connect_signals()

    @staticmethod
    def _with_tooltip(widget, tooltip):
        """Set ``tooltip`` on ``widget`` and return the widget."""
        widget.setToolTip(tooltip)
        return widget

    def _build_widgets(self):
        """Create the drawing pad, the gallery and every button/check box."""
        self.drawWidget = GANGATEDraw(win_size=self.win_size, img_size=self.img_size)
        self.drawWidget.setFixedSize(self.win_size, self.win_size)

        self.visWidget = GANGATEVis(win_size=self.win_size, img_size=self.img_size)
        self.visWidget.setFixedSize(self.win_size, self.win_size)

        # Object-class selectors.
        self.bBicycle = self._with_tooltip(
            QRadioButton("Bicycle"), "This button enables generation of a Bicycle")
        self.bCat = self._with_tooltip(
            QRadioButton("Cat"), "This button enables generation of a Cat")
        self.bChair = self._with_tooltip(
            QRadioButton("Chair"), "This button enables generation of a Chair")
        self.bHamburger = self._with_tooltip(
            QRadioButton("Hamburger"), "This button enables generation of a Hamburger")
        self.bPizza = self._with_tooltip(
            QRadioButton("Pizza"), "This button enables generation of a Pizza")
        self.bTeddy = self._with_tooltip(
            QRadioButton("Teddy"), "This button enables generation of a Teddy")

        # NOTE(review): bGenerate is created and connected but never added to
        # any layout in this class, so it is not shown - confirm that this is
        # intended.
        self.bGenerate = self._with_tooltip(
            QPushButton('Generate !'), "This button generates the final image to render")
        self.bReset = self._with_tooltip(
            QPushButton('Reset !'), "This button resets the drawing pad !")
        self.bRandomize = self._with_tooltip(
            QPushButton('Dice'),
            "This button generates new set of generations the drawing pad !")

        # Interaction-mode selectors for the drawing pad.
        self.bMoveStroke = self._with_tooltip(
            QRadioButton('Move Stroke'),
            "This button moves the selected stroke on the drawing pad !")
        self.bWarpStroke = self._with_tooltip(
            QRadioButton('Warp Stroke'),
            "This button warps the selected stroke on the drawing pad !")
        self.bDrawStroke = self._with_tooltip(
            QRadioButton('Draw Stroke'),
            "This button reverts back to the drawing mode on the drawing pad !")
        self.bSelectPatch = self._with_tooltip(
            QRadioButton('Select Patch'),
            "This button selects patches from the shadows!")

        self.bEnableShadows = QCheckBox('Enable Shadows')
        self.bEnableShadows.toggle()  # start checked

    def _build_layout(self):
        """Arrange the widgets: pad | gallery | class buttons, controls below."""
        # Left column: the drawing pad inside a titled group box.
        self.drawWidgetBox = QGroupBox()
        self.drawWidgetBox.setTitle('Drawing Pad')
        pad_inner = QVBoxLayout()
        pad_inner.addWidget(self.drawWidget)
        self.drawWidgetBox.setLayout(pad_inner)
        pad_column = QVBoxLayout()
        pad_column.addWidget(self.drawWidgetBox)

        # Middle column: the gallery inside a titled group box.
        self.visWidgetBox = QGroupBox()
        self.visWidgetBox.setTitle('Generations')
        vis_inner = QVBoxLayout()
        vis_inner.addWidget(self.visWidget)
        self.visWidgetBox.setLayout(vis_inner)
        vis_column = QVBoxLayout()
        vis_column.addWidget(self.visWidgetBox)

        # Right column: class selectors in a 3 x 2 grid.
        class_grid = QGridLayout()
        class_grid.addWidget(self.bBicycle, 0, 0)
        class_grid.addWidget(self.bCat, 1, 0)
        class_grid.addWidget(self.bChair, 2, 0)
        class_grid.addWidget(self.bHamburger, 0, 1)
        class_grid.addWidget(self.bPizza, 1, 1)
        class_grid.addWidget(self.bTeddy, 2, 1)

        top_row = QHBoxLayout()
        top_row.addLayout(pad_column)
        top_row.addLayout(vis_column)
        top_row.addLayout(class_grid)

        # Bottom row: one line of controls inside their own group box.
        controls_grid = QGridLayout()
        control_widgets = (
            self.bReset, self.bRandomize, self.bDrawStroke, self.bMoveStroke,
            self.bWarpStroke, self.bSelectPatch, self.bEnableShadows,
        )
        for column, widget in enumerate(control_widgets):
            controls_grid.addWidget(widget, 0, column)

        controlBox = QGroupBox()
        controlBox.setTitle('Controls')
        controlBox.setLayout(controls_grid)

        root = QVBoxLayout()
        root.addLayout(top_row)
        root.addWidget(controlBox)
        self.setLayout(root)

    def _connect_signals(self):
        """Connect every button/check box to its handler method."""
        self.bBicycle.clicked.connect(self.Bicycle)
        self.bCat.clicked.connect(self.Cat)
        self.bChair.clicked.connect(self.Chair)
        self.bHamburger.clicked.connect(self.Hamburger)
        self.bPizza.clicked.connect(self.Pizza)
        self.bTeddy.clicked.connect(self.Teddy)

        self.bGenerate.clicked.connect(self.generate)
        self.bReset.clicked.connect(self.reset)
        self.bRandomize.clicked.connect(self.randomize)
        self.bMoveStroke.clicked.connect(self.move_stroke)
        self.bWarpStroke.clicked.connect(self.warp_stroke)
        self.bDrawStroke.clicked.connect(self.draw_stroke)
        self.bSelectPatch.clicked.connect(self.select_patch)
        self.bEnableShadows.stateChanged.connect(self.toggle_shadow)

    @classmethod
    def _ensure_image_dir(cls):
        """Create the image directory if it does not exist yet.

        cv2.imwrite reports a failed write only through its return value, so
        without the directory the images read back later would be missing.
        """
        if not os.path.isdir(cls._IMAGE_DIR):
            os.makedirs(cls._IMAGE_DIR)

    def toggle_shadow(self, state):
        """Handle the 'Enable Shadows' check box, then regenerate.

        Args:
            state: check state delivered by ``stateChanged``.
        """
        self.enable_shadow = (state == Qt.Checked)
        if not self.enable_shadow:
            self.drawWidget.cycleShadows()
        self.generate()

    def Bicycle(self):
        """Select class id 0 (Bicycle)."""
        self.labelId = 0

    def Cat(self):
        """Select class id 1 (Cat)."""
        self.labelId = 1

    def Chair(self):
        """Select class id 2 (Chair)."""
        self.labelId = 2

    def Hamburger(self):
        """Select class id 3 (Hamburger)."""
        self.labelId = 3

    def Pizza(self):
        """Select class id 4 (Pizza)."""
        self.labelId = 4

    def Teddy(self):
        """Select class id 5 (Teddy)."""
        self.labelId = 5

    def get_network_input(self):
        """Turn the current scribble into the input dictionary for the model.

        The scribble is also written to ``./imgs/current_scribble.jpg``.

        Returns:
            dict with keys 'A', 'A_sparse', 'A_mask', 'B' (all the same
            tensor, expanded to ``opt.num_interpolate`` entries), empty
            'A_paths' / 'B_paths' and 'label' (``labelId`` repeated
            ``opt.num_interpolate`` times).
        """
        cv2_scribble = self.drawWidget.getDrawImage()
        cv2_scribble = cv2.cvtColor(cv2_scribble, cv2.COLOR_BGR2RGB)

        self._ensure_image_dir()
        # NOTE(review): the image is written after the BGR->RGB conversion
        # while cv2.imwrite expects BGR, so the saved file presumably has its
        # red/blue channels swapped - confirm whether anything reads it.
        cv2.imwrite('./imgs/current_scribble.jpg', cv2_scribble)
        pil_scribble = Image.fromarray(cv2_scribble)

        size = self._NET_INPUT_SIZE
        A = transform(pil_scribble)
        # NOTE(review): resize_ reinterprets the tensor's storage; this equals
        # "add a batch dimension" only if transform already yields an
        # (input_nc, 128, 128) tensor - confirm.
        A = A.resize_(1, opt.input_nc, size, size)
        A = A.expand(opt.num_interpolate, opt.input_nc, size, size)
        B = A

        label = torch.LongTensor([self.labelId]).expand(opt.num_interpolate)
        return {'A': A, 'A_sparse': A, 'A_mask': A,
                'B': B, 'A_paths': '', 'B_paths': '', 'label': label}

    def browse(self, pos_y, pos_x):
        """Highlight the gallery cell under a position and show its shadow.

        The gallery is treated as a grid with two columns and
        ``opt.num_interpolate / 2`` rows.  The selected cell is framed in
        green in the gallery, and the matching generation is painted in green
        on top of the shadow image shown on the drawing pad.

        Positions outside the grid are clamped to the nearest cell.  If any
        of the required image files cannot be read (for example because
        nothing has been generated yet) the corresponding update is skipped
        instead of raising inside the Qt event handler.

        Args:
            pos_y: coordinate divided by the column width to pick the column.
            pos_x: coordinate divided by the row height to pick the row.
                NOTE(review): the names look swapped with respect to the
                usual x/y convention - verify against the caller.
        """
        num_cols = 2
        # max(1, ...) guards the divisions below when num_interpolate < 2.
        num_rows = max(1, int(opt.num_interpolate / 2))
        div_rows = max(1, int(self.img_size / num_rows))
        div_cols = max(1, int(self.img_size / num_cols))

        # Clamp so that a position on the far edge (or outside the widget)
        # still maps to an existing cell.
        which_row = min(max(int(pos_x / div_rows), 0), num_rows - 1)
        which_col = min(max(int(pos_y / div_cols), 0), num_cols - 1)

        cv2_gallery = cv2.imread('imgs/fake_B_gallery.png')
        if cv2_gallery is None:
            return
        cv2_gallery = cv2.resize(cv2_gallery, (self.img_size, self.img_size))

        top_left = (which_col * div_cols, which_row * div_rows)
        bottom_right = ((which_col + 1) * div_cols, (which_row + 1) * div_rows)
        cv2_gallery = cv2.rectangle(cv2_gallery, top_left, bottom_right, (0, 255, 0), 8)
        self.visWidget.update_vis_cv2(cv2_gallery)

        which_highlight = which_row * num_cols + which_col
        cv2_img = cv2.imread('imgs/test_fake_B_shadow.png')
        img_gray = cv2.imread('imgs/test_%d_L_fake_B_inter.png' % (which_highlight),
                              cv2.IMREAD_GRAYSCALE)
        if cv2_img is None or img_gray is None:
            return
        cv2_img = cv2.resize(cv2_img, (self.img_size, self.img_size))
        img_gray = cv2.resize(img_gray, (self.img_size, self.img_size))

        # Binarise the selected generation and paint its zero-valued pixels
        # green on the shadow image.
        _, im_bw = cv2.threshold(img_gray, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
        cv2_img[im_bw == 0] = [0, 255, 0]

        self.drawWidget.setShadowImage(cv2_img)

    def generate(self):
        """Run the model on the current scribble and refresh both widgets.

        Every visual returned by ``model.get_latent_noise_visualization()``
        is saved as ``./imgs/test_<label>.png``.  The drawing pad then gets
        either the blended shadow image or, with shadows disabled, the single
        generation selected by ``which_shadow_img``; the gallery is reloaded
        from ``imgs/fake_B_gallery.png``.
        """
        data = self.get_network_input()
        model.set_input(data)
        visuals = model.get_latent_noise_visualization()

        self._ensure_image_dir()
        for label, image_numpy in visuals.items():
            save_path = os.path.join(self._IMAGE_DIR, 'test_%s.png' % (label))
            util.save_image(image_numpy, save_path)

        # Read the result back with OpenCV for the drawing pad.
        if self.enable_shadow:
            shadow_path = 'imgs/test_fake_B_shadow.png'
        else:
            shadow_path = 'imgs/test_%d_L_fake_B_inter.png' % (self.which_shadow_img)
        self.drawWidget.setShadowImage(cv2.imread(shadow_path))

        self.visWidget.update_vis('imgs/fake_B_gallery.png')

    def cycle_shadow_image(self):
        """Advance ``which_shadow_img`` (wrapping at ``opt.num_interpolate``) and regenerate."""
        self.which_shadow_img = (self.which_shadow_img + 1) % opt.num_interpolate
        self.generate()

    def get_paste_image(self):
        """Return the currently selected generation resized to ``img_size``.

        Reads ``imgs/test_<which_shadow_img>_L_fake_B_inter.png``; if that
        file cannot be read, ``cv2.resize`` raises.
        """
        cv2_img = cv2.imread('imgs/test_%d_L_fake_B_inter.png' % (self.which_shadow_img))
        return cv2.resize(cv2_img, (self.img_size, self.img_size))

    def reset(self):
        """Forward to ``drawWidget.reset()``."""
        self.drawWidget.reset()

    def move_stroke(self):
        """Forward to ``drawWidget.move_stroke()``."""
        self.drawWidget.move_stroke()

    def warp_stroke(self):
        """Forward to ``drawWidget.warp_stroke()``."""
        self.drawWidget.warp_stroke()

    def draw_stroke(self):
        """Forward to ``drawWidget.draw_stroke()``."""
        self.drawWidget.draw_stroke()

    def select_patch(self):
        """Forward to ``drawWidget.select_patch()``."""
        self.drawWidget.select_patch()

    def randomize(self):
        """Ask the model for new noise, then regenerate."""
        model.randomize_noise()
        self.generate()

    def scribble(self):
        """Forward to ``drawWidget.scribble()``."""
        self.drawWidget.scribble()

    def erase(self):
        """Forward to ``drawWidget.erase()``."""
        self.drawWidget.erase()


if __name__ == '__main__':
    # Script entry point: create the application and the main window, show
    # it, and hand control to the Qt event loop.
    app = QApplication(sys.argv)

    window = GANGATEGui()
    window.setWindowTitle('iSketchNFill')

    # Clear the maximize-button hint from the window flags.
    flags_without_maximize = window.windowFlags() & ~Qt.WindowMaximizeButtonHint
    window.setWindowFlags(flags_without_maximize)

    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)