from tkinter import *
from tkinter import messagebox
import math

# Side length of one grid cell, in canvas units; cell (row, col) spans
# col * CELL_SIZE .. (col + 1) * CELL_SIZE horizontally and likewise vertically.
CELL_SIZE = 80
# Inset, on every side, of the inner rectangle drawn inside terminal cells.
CELL_PADDING = 10
# Distance from a cell's centre to the tip of its policy arrow (half a cell).
ARROW_LENGTH = CELL_SIZE / 2

class GridWorldWindow(object):
    """Tkinter window that renders a gridworld and its control buttons.

    The window stacks three button rows (value iteration, policy iteration,
    reset) above a canvas. The canvas holds, for every grid cell, a
    background rectangle, a text item showing the cell's value and a line
    item used as the policy arrow. The canvas ids of those items are stored
    in ``ids_rect``, ``ids_text`` and ``ids_arrow``, indexed ``[row][col]``.

    The buttons are created without a ``command``; presumably the caller
    attaches the callbacks through the ``btn_*`` attributes.
    """

    def __init__(self, metadata):
        """Build the window, the buttons and the initial grid.

        Args:
            metadata: Mapping with the keys ``'width'`` and ``'height'``
                (grid size in cells), ``'obstacles'`` (iterable of
                ``(row, col)`` pairs) and ``'terminals'`` (iterable of
                mappings, each with a ``'state'`` entry holding a
                ``(row, col)`` pair).
        """
        self.window = Tk()
        self.window.title('Gridworld')
        self.window.geometry('{}x{}'.format(1080, 720))

        # Extract the grid description; states are normalised to tuples so
        # they can be compared with the (row, col) tuples used below.
        self.grid_width = metadata['width']
        self.grid_height = metadata['height']
        self.obstacles = [tuple(obstacle) for obstacle in metadata['obstacles']]
        self.terminals = [tuple(terminal['state']) for terminal in metadata['terminals']]

        self.canvas_width = self.grid_width * CELL_SIZE
        self.canvas_height = self.grid_height * CELL_SIZE

        # Placeholders for the tkinter canvas ids of every modifiable item;
        # each row gets its own list so rows never alias one another.
        self.ids_text = [[0] * self.grid_width for _ in range(self.grid_height)]
        self.ids_rect = [[0] * self.grid_width for _ in range(self.grid_height)]
        self.ids_arrow = [[0] * self.grid_width for _ in range(self.grid_height)]

        # Buttons are packed before the canvas so they sit above it.
        self._create_buttons()

        self.canvas = Canvas(self.window, width=self.canvas_width, height=self.canvas_height, bg='black')
        self.canvas.pack(padx=10, pady=10)

        self._create_grid()

    def _make_button(self, parent, label, anchor):
        """Create a button in ``parent``, pack it to the left and return it."""
        button = Button(parent, text=label, anchor=anchor)
        button.pack(side=LEFT)
        return button

    def _create_buttons(self):
        """Create the three button rows and expose every button as ``btn_*``."""
        # The frames are packed first, in display order (top to bottom).
        self.frame_value_buttons = Frame(self.window)
        self.frame_value_buttons.pack(padx=5, pady=5)

        self.frame_policy_buttons = Frame(self.window)
        self.frame_policy_buttons.pack(padx=5, pady=5)

        self.frame_reset_buttons = Frame(self.window)
        self.frame_reset_buttons.pack(padx=5, pady=5)

        self.btn_value_iteration_1_step = self._make_button(
            self.frame_value_buttons, '1-Step Value Iteration', W)
        self.btn_value_iteration_100_steps = self._make_button(
            self.frame_value_buttons, '100-Step Value Iteration', E)
        self.btn_value_iteration_slow = self._make_button(
            self.frame_value_buttons, 'Slow Value Iteration', E)

        self.btn_policy_iteration_1_step = self._make_button(
            self.frame_policy_buttons, '1-Step Policy Iteration', E)
        self.btn_policy_iteration_100_steps = self._make_button(
            self.frame_policy_buttons, '100-Step Policy Iteration', E)
        self.btn_policy_iteration_slow = self._make_button(
            self.frame_policy_buttons, 'Slow Policy Iteration', E)

        self.btn_reset = self._make_button(self.frame_reset_buttons, 'Reset', E)

    def _create_grid(self):
        """Create the canvas items of every cell and record their ids.

        Obstacles get a grey rectangle and no text; every other cell gets an
        unfilled rectangle and the text ``'0.00'``. Terminal cells get an
        additional inset rectangle. Each cell also gets a zero-length arrow
        line that ``update_grid`` later repositions.
        """
        # Sets give constant-time membership tests inside the cell loop.
        obstacles = set(self.obstacles)
        terminals = set(self.terminals)

        for row in range(self.grid_height):
            top = row * CELL_SIZE
            bottom = (row + 1) * CELL_SIZE
            for col in range(self.grid_width):
                left = col * CELL_SIZE
                right = (col + 1) * CELL_SIZE

                if (row, col) in obstacles:
                    fill = 'grey'
                    text = None
                else:
                    # None leaves tkinter's default in place: no fill, so the
                    # black canvas background shows through.
                    fill = None
                    text = '0.00'

                self.ids_rect[row][col] = self.canvas.create_rectangle(
                    left, top, right, bottom, fill=fill, outline='white')
                if (row, col) in terminals:
                    # Inset outline marking the cell as terminal; its id is
                    # not kept, so it is never recoloured afterwards.
                    self.canvas.create_rectangle(
                        left + CELL_PADDING, top + CELL_PADDING,
                        right - CELL_PADDING, bottom - CELL_PADDING,
                        fill=fill, outline='white')

                self.ids_text[row][col] = self.canvas.create_text(
                    left + CELL_SIZE/2, top + CELL_SIZE/2, text=text, fill='white')
                self.ids_arrow[row][col] = self.canvas.create_line(
                    0, 0, 0, 0, width=2, arrow=LAST, fill='white')

    def _compute_color(self, value):
        """Map a value to a ``'#rrggbb'`` colour string.

        Positive values are shades of green and negative values shades of
        red; the intensity grows linearly with the magnitude and saturates
        at 255 once the magnitude reaches 1.0. Zero maps to black. NaN,
        which compares false against everything, also maps to black so that
        a colour string is always returned.
        """
        if value > 0:
            green = 255 if value >= 1.0 else math.floor(value * 256)
            return '#{:02x}{:02x}{:02x}'.format(0, green, 0)
        if value < 0:
            red = 255 if -value >= 1.0 else math.floor(-value * 256)
            return '#{:02x}{:02x}{:02x}'.format(red, 0, 0)
        # Zero, or NaN falling through both comparisons.
        return '#000000'

    def show_dialog(self, text):
        """Show ``text`` in a modal information dialog titled 'Info'."""
        messagebox.showinfo('Info', text)

    def update_grid(self, values, policy):
        """Redraw cell colours, value labels and policy arrows.

        Args:
            values: Mapping from ``(row, col)`` state to its numeric value.
            policy: Mapping from ``(row, col)`` state to an action whose
                first two entries are the row and column direction. It is
                only read for states that are not terminal.
        """
        terminals = set(self.terminals)

        for state, value in values.items():
            row, col = state[0], state[1]

            self.canvas.itemconfig(self.ids_rect[row][col], fill=self._compute_color(value))
            self.canvas.itemconfig(self.ids_text[row][col], text='{:.2f}'.format(value))

            if state in terminals:
                continue

            action = policy[state]
            # The tip lies ARROW_LENGTH from the cell centre in the action's
            # direction; the tail is one step back along the same direction,
            # so the segment itself is only one canvas unit long.
            tip_x = col * CELL_SIZE + CELL_SIZE/2 + action[1] * ARROW_LENGTH
            tip_y = row * CELL_SIZE + CELL_SIZE/2 + action[0] * ARROW_LENGTH
            self.canvas.coords(self.ids_arrow[row][col],
                               tip_x - action[1], tip_y - action[0], tip_x, tip_y)

    def clear(self):
        """Reset every cell to its initial look and hide all arrows."""
        obstacles = set(self.obstacles)
        blank = self._compute_color(0)

        for row in range(self.grid_height):
            for col in range(self.grid_width):
                if (row, col) in obstacles:
                    fill = 'grey'
                    text = None
                else:
                    fill = blank
                    text = '0.00'
                self.canvas.itemconfig(self.ids_rect[row][col], fill=fill)
                self.canvas.itemconfig(self.ids_text[row][col], text=text)
                # Collapsing the line to a single point hides the arrow.
                self.canvas.coords(self.ids_arrow[row][col], 0, 0, 0, 0)

    def run(self):
        """Enter the tkinter event loop of this window (blocks until it ends)."""
        # Use the window's own loop instead of the module-level mainloop(),
        # which depends on tkinter's implicit default root.
        self.window.mainloop()