import os
import sys

sys.path.append(os.getcwd())

from tkinter import Frame, Label, BOTH, Tk, LEFT, HORIZONTAL, Scale, Button, GROOVE, filedialog, PhotoImage, messagebox

import PIL.Image
import PIL.ImageTk
import numpy
import torch

from poser.morph_rotate_combine_poser import MorphRotateCombinePoser256Param6
from poser.poser import Poser
from tha.combiner import CombinerSpec
from tha.face_morpher import FaceMorpherSpec
from tha.two_algo_face_rotator import TwoAlgoFaceRotatorSpec
from util import extract_pytorch_image_from_filelike, rgba_to_numpy_image


class ManualPoserApp:
    """Tkinter GUI for interactively posing one image with a ``Poser``.

    The window has three panels packed left to right:
    - the loaded source image;
    - a control column with one slider per pose parameter and a
      "Load Image ..." button;
    - the posed result.

    A timer re-reads the sliders about 30 times per second. It re-runs the
    poser whenever the pose or the source image has changed.
    """

    # Delay between two polls of the sliders, in milliseconds (~30 per second).
    UPDATE_INTERVAL_MS = 1000 // 30

    def __init__(self,
                 master,
                 poser: Poser,
                 torch_device: torch.device):
        """Build the widgets and start the periodic update loop.

        Args:
            master: Tk root (or Toplevel) window hosting the app; its title is set.
            poser: Model used to pose the source image. It also supplies the pose
                parameters (bounds, default values, display names) and the
                required input image size.
            torch_device: Device the source image and pose tensors are moved to.
        """
        super().__init__()
        self.master = master
        self.poser = poser
        self.torch_device = torch_device

        self.master.title("Manual Poser")

        # Fixed 256x256 panel; pack_propagate(0) keeps it from resizing to its content.
        source_image_frame = Frame(self.master, width=256, height=256)
        source_image_frame.pack_propagate(0)
        source_image_frame.pack(side=LEFT)

        self.source_image_label = Label(source_image_frame, text="Nothing yet!")
        self.source_image_label.pack(fill=BOTH, expand=True)

        control_frame = Frame(self.master, borderwidth=2, relief=GROOVE)
        control_frame.pack(side=LEFT, fill='y')

        # One slider per pose parameter, in the poser's parameter order; each
        # slider is followed by a label with the parameter's display name.
        self.param_sliders = []
        for param in self.poser.pose_parameters():
            slider = Scale(control_frame,
                           from_=param.lower_bound,
                           to=param.upper_bound,
                           length=256,
                           resolution=0.001,
                           orient=HORIZONTAL)
            slider.set(param.default_value)
            slider.pack(fill='x')
            self.param_sliders.append(slider)

            label = Label(control_frame, text=param.display_name)
            label.pack()

        posed_image_frame = Frame(self.master, width=256, height=256)
        posed_image_frame.pack_propagate(0)
        posed_image_frame.pack(side=LEFT)

        self.posed_image_label = Label(posed_image_frame, text="Nothing yet!")
        self.posed_image_label.pack(fill=BOTH, expand=True)

        self.load_source_image_button = Button(control_frame, text="Load Image ...", relief=GROOVE,
                                               command=self.load_image)
        self.load_source_image_button.pack(fill='x')

        self.pose_size = len(self.poser.pose_parameters())
        # Source image tensor on torch_device with a leading dimension of size 1; None until loaded.
        self.source_image = None
        # Never reassigned by the methods of this class.
        self.posed_image = None
        # Pose tensors of shape (1, pose_size) built from the sliders.
        self.current_pose = None
        self.last_pose = None
        # Forces a re-pose on the next tick even if the sliders did not move.
        self.needs_update = False

        self.master.after(self.UPDATE_INTERVAL_MS, self.update_image)

    def load_image(self):
        """Ask the user for a PNG file and make it the new source image.

        Cancelling the dialog does nothing. If the image's width or height
        differs from ``poser.image_size()``, an error dialog is shown and the
        current source image is kept.
        """
        file_name = filedialog.askopenfilename(
            filetypes=[("PNG", '*.png')],
            initialdir="data/illust")
        # An empty result means the dialog was cancelled.
        if not file_name:
            return

        image = PhotoImage(file=file_name)
        required_size = self.poser.image_size()
        if image.width() != required_size or image.height() != required_size:
            message = "The loaded image has size %dx%d, but we require %dx%d." \
                      % (image.width(), image.height(), required_size, required_size)
            messagebox.showerror("Wrong image size!", message)
            # Do not display or pose an image the poser cannot accept.
            return

        self.source_image_label.configure(image=image, text="")
        # Hold a reference on the widget so the PhotoImage stays alive while displayed.
        self.source_image_label.image = image
        self.source_image_label.pack()

        # unsqueeze(dim=0) adds a leading size-1 dimension (a batch of one image).
        self.source_image = extract_pytorch_image_from_filelike(file_name) \
            .to(self.torch_device).unsqueeze(dim=0)
        self.needs_update = True

    def update_pose(self):
        """Read all slider values into ``self.current_pose``.

        The result is a tensor of shape (1, pose_size) on ``self.torch_device``.
        Entry i holds the value of the i-th slider.
        """
        # float() because Scale.get() may return an int; this keeps torch's
        # default floating dtype, matching a torch.zeros-allocated tensor.
        values = [float(slider.get()) for slider in self.param_sliders]
        self.current_pose = torch.tensor(values, device=self.torch_device).unsqueeze(dim=0)

    def update_image(self):
        """Re-pose the source image if needed, then schedule the next tick.

        The poser runs only when a source image is loaded and either
        ``needs_update`` is set or the summed absolute slider change since the
        last posed pose is at least 1e-5. The next tick is always scheduled,
        even if posing raises, so the update loop never stops.
        """
        try:
            self.update_pose()
            if self.source_image is None:
                return
            pose_unchanged = (self.last_pose is not None and
                              (self.last_pose - self.current_pose).abs().sum().item() < 1e-5)
            if pose_unchanged and not self.needs_update:
                return

            # Mark this pose as handled before running the model. A pose that
            # makes the poser raise is then not retried on every tick, only
            # after the sliders move or a new image is loaded.
            self.last_pose = self.current_pose
            self.needs_update = False

            posed_image = self.poser.pose(self.source_image, self.current_pose).detach().cpu()
            # Presumably an HxWx4 float array in [0, 1] (scaled by 255 for RGBA) -- TODO confirm.
            numpy_image = rgba_to_numpy_image(posed_image[0])
            # Clip before the uint8 cast so out-of-range values cannot wrap around.
            pixels = numpy.clip(numpy.rint(numpy_image * 255.0), 0, 255).astype(numpy.uint8)
            pil_image = PIL.Image.fromarray(pixels, mode='RGBA')
            photo_image = PIL.ImageTk.PhotoImage(image=pil_image)

            self.posed_image_label.configure(image=photo_image, text="")
            # Hold a reference on the widget so the image stays alive while displayed.
            self.posed_image_label.image = photo_image
            self.posed_image_label.pack()
        finally:
            self.master.after(self.UPDATE_INTERVAL_MS, self.update_image)


if __name__ == "__main__":
    # Prefer the GPU, but fall back to the CPU so the tool can still start on
    # machines without CUDA. The CPU path is slower and only works if the poser
    # can load its checkpoint files onto the CPU.
    torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    poser = MorphRotateCombinePoser256Param6(
        morph_module_spec=FaceMorpherSpec(),
        morph_module_file_name="data/face_morpher.pt",
        rotate_module_spec=TwoAlgoFaceRotatorSpec(),
        rotate_module_file_name="data/two_algo_face_rotator.pt",
        combine_module_spec=CombinerSpec(),
        combine_module_file_name="data/combiner.pt",
        device=torch_device)
    root = Tk()
    app = ManualPoserApp(master=root, poser=poser, torch_device=torch_device)
    root.mainloop()