from __future__ import print_function, division import sys import os sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from lib import replay_memory from lib import util import Tkinter from PIL import Image, ImageTk import numpy as np import cPickle as pickle import time import imgaug as ia def numpy_to_tk_image(image): image_pil = Image.fromarray(image) image_tk = ImageTk.PhotoImage(image_pil) return image_tk def load_annotations(fp): if os.path.isfile(fp): return pickle.load(open(fp, "r")) else: return None def draw_normal_distribution(height, width, x, y, size): if 0 <= y < height and 0 <= x < width: pad_by = size * 10 img = np.zeros((pad_by+height+pad_by, pad_by+width+pad_by), dtype=np.float32) #img = img.pad(img, ((20, 20), (20, 20))) #normal = util.create_2d_gaussian(size=size*2, fwhm=size) normal = util.create_2d_gaussian(size=size*4, sigma=size) #print(normal) normal_h, normal_w = normal.shape normal_hh, normal_wh = normal_h//2, normal_w//2 #print("normal size", normal.shape) #print("img.shape", img.shape) #print("img[y-normal_hh:y+normal_hh, x-normal_wh:x+normal_wh]", img[y-normal_hh:y+normal_hh, x-normal_wh:x+normal_wh].shape) y1 = np.clip(y-normal_hh+pad_by, 0, img.shape[0]-1) #-(2*pad_by)) y2 = np.clip(y+normal_hh+pad_by, 0, img.shape[0]-1) #-(2*pad_by)) x1 = np.clip(x-normal_wh+pad_by, 0, img.shape[1]-1) #-(2*pad_by)) x2 = np.clip(x+normal_wh+pad_by, 0, img.shape[1]-1) #-(2*pad_by)) if x2 - x1 > 0 and y2 - y1 > 0: img[y1:y2, x1:x2] = normal return img[pad_by:-pad_by, pad_by:-pad_by] else: return np.zeros((height, width), dtype=np.float32) class GridAnnotationWindow(object): def __init__(self, master, canvas, memory, current_state_idx, annotations, current_anno_attribute_name, save_to_fp, every_nth_example=10, zoom_factor=4): self.master = master self.canvas = canvas self.memory = memory self.current_state_idx = current_state_idx self.annotations = annotations if annotations is not None else dict() self.current_annotation = None self.background_label = None self.eraser = False self.dirty = False self.brush_size = 3 self.last_autosave = 0 self.heatmap_alpha = 0.3 self.heatmap_alphas = (0.1, 0.3) self.current_anno_attribute_name = current_anno_attribute_name self.every_nth_example = every_nth_example self.zoom_factor = zoom_factor self.autosave_every_nth = 20 self.save_to_fp = save_to_fp self.is_showing_directly_previous_state = False self.directly_previous_state = None self.current_state = None self.switch_to_state(self.current_state_idx, autosave=False) #self.current_state = memory.get_state_by_id(current_state_idx) @staticmethod def create(memory, current_anno_attribute_name, save_to_fp, every_nth_example=10, zoom_factor=4): print("Loading previous annotations...") annotations = load_annotations(save_to_fp) #is_annotated = dict([(str(annotation.idx), True) for annotation in annotations]) current_state_idx = memory.id_min if annotations is not None: while current_state_idx < memory.id_max: key = str(current_state_idx) if key not in annotations or current_anno_attribute_name not in annotations[key]: break current_state_idx += every_nth_example print("ID of first unannotated state: %d" % (current_state_idx,)) master = Tkinter.Tk() state = memory.get_state_by_id(current_state_idx) canvas_height = state.screenshot_rs.shape[0] * zoom_factor canvas_width = state.screenshot_rs.shape[1] * zoom_factor print("canvas height, width:", canvas_height, canvas_width) canvas = Tkinter.Canvas(master, width=canvas_width, height=canvas_height) canvas.pack() canvas.focus_set() #y = int(canvas_height / 2) #w.create_line(0, y, canvas_width, y, fill="#476042") message = Tkinter.Label(master, text="Click to draw annotation. Press E to switch to eraser mode. Press S to save. Use Numpad +/- for brush size.") message.pack(side=Tkinter.BOTTOM) window_state = GridAnnotationWindow( master, canvas, memory, current_state_idx, annotations, current_anno_attribute_name, save_to_fp, every_nth_example, zoom_factor ) #canvas.bind("<Button-1>", OnPaint(window_state)) #master.bind("<Button-1>", lambda event: print(event)) #master.bind("<Button-3>", lambda event: print("right", event)) #master.bind("<ButtonPress-1>", lambda event: print("press", event)) master.bind("<B1-Motion>", window_state.on_left_mouse_button) #master.bind("<ButtonRelease-1>", lambda event: print("release", event)) master.bind("<B3-Motion>", window_state.on_right_mouse_button) canvas.bind("<e>", lambda event: window_state.toggle_eraser()) canvas.bind("<s>", lambda event: window_state.save_annotations(force=True)) canvas.bind("<w>", lambda event: window_state.toggle_heatmap()) canvas.bind("<p>", lambda event: window_state.toggle_previous_screenshot()) canvas.bind("<Left>", lambda event: window_state.previous_state(autosave=True)) canvas.bind("<Right>", lambda event: window_state.next_state(autosave=True)) canvas.bind("<KP_Add>", lambda event: window_state.increase_brush_size()) canvas.bind("<KP_Subtract>", lambda event: window_state.decrease_brush_size()) return window_state @property def grid(self): return self.current_annotation[self.current_anno_attribute_name] def toggle_eraser(self): self.eraser = not self.eraser print("Eraser set to %s" % (self.eraser,)) def toggle_heatmap(self): pos = self.heatmap_alphas.index(self.heatmap_alpha) self.heatmap_alpha = self.heatmap_alphas[(pos+1) % len(self.heatmap_alphas)] self.set_canvas_background(self._generate_heatmap()) def toggle_previous_screenshot(self): if self.directly_previous_state is not None: if self.is_showing_directly_previous_state: self.set_canvas_background(self._generate_heatmap()) else: self.set_canvas_background(self.directly_previous_state.screenshot_rs) self.is_showing_directly_previous_state = not self.is_showing_directly_previous_state def increase_brush_size(self): self.brush_size = np.clip(self.brush_size+1, 1, 100) print("Increased brush size to %d" % (self.brush_size,)) def decrease_brush_size(self): self.brush_size = np.clip(self.brush_size-1, 1, 100) print("Decreased brush size to %d" % (self.brush_size,)) def previous_state(self, autosave): print("Switching to previous state...") self.current_state_idx -= self.every_nth_example assert self.current_state_idx >= self.memory.id_min, "Start of memory reached (%d vs %d)" % (self.current_state_idx, self.memory.id_min) self.switch_to_state(self.current_state_idx, autosave=autosave) def next_state(self, autosave): print("Switching to next state...") self.current_state_idx += self.every_nth_example assert self.current_state_idx <= self.memory.id_max, "End of memory reached (%d vs %d)" % (self.current_state_idx, self.memory.id_max) self.switch_to_state(self.current_state_idx, autosave=autosave) def switch_to_state(self, idx, autosave): print("Switching to state %d (autosave=%s)..." % (idx, str(autosave))) self.directly_previous_state = self.memory.get_state_by_id(idx-1) self.current_state = self.memory.get_state_by_id(idx) assert self.current_state is not None self.current_state_idx = idx if autosave: if (self.last_autosave+1) % self.autosave_every_nth == 0: # only autosaves if dirty flag is true, ie any example was changed self.save_annotations() self.last_autosave = 0 else: self.last_autosave += 1 print("last_autosave=", self.last_autosave) key = str(self.current_state_idx) if key in self.annotations: self.current_annotation = self.annotations[key] else: self.current_annotation = { "idx": self.current_state_idx, "from_datetime": self.current_state.from_datetime, "screenshot_rs": self.current_state.screenshot_rs, } self.annotations[key] = self.current_annotation annos_done = [key for key in self.current_annotation.keys() if key not in ["idx", "from_datetime", "screenshot_rs"]] print("Annotations added to this state: %s" % (", ".join(annos_done))) if self.current_anno_attribute_name not in annos_done: print("This state has not yet been annotated with '%s'." % (self.current_anno_attribute_name,)) img = self.current_state.screenshot_rs empty_grid = np.zeros((img.shape[0], img.shape[1]), dtype=np.float32) self.current_annotation[self.current_anno_attribute_name] = empty_grid self.is_showing_directly_previous_state = False self.update_annotation_grid(self.grid, initial=True) def save_annotations(self, force=False): #print(self.annotations) if self.dirty or force: print("Saving...") with open(self.save_to_fp, "w") as f: pickle.dump(self.annotations, f, protocol=-1) self.dirty = False print("Finished saving.") else: print("Not saved (not marked dirty)") """ def redraw_canvas(self): img = generate_canvas_image(self.current_state.screenshot_rs, self.grid) self.canvas.delete(Tkinter.ALL) self.set_canvas_background(self.canvas, img) """ def update_annotation_grid(self, annotation_grid, initial=False): self.current_annotation[self.current_anno_attribute_name] = annotation_grid #self.redraw_canvas() #img = generate_canvas_image(self.current_state.screenshot_rs, annotation_grid) img_heatmap = self._generate_heatmap() self.set_canvas_background(img_heatmap) if not initial: self.dirty = True def set_canvas_background(self, image): if self.background_label is None: # initialize background image label (first call) #img = self.current_state.screenshot_rs #bg_img_tk = numpy_to_tk_image(np.zeros(img.shape)) img_heatmap = self._generate_heatmap() img_heatmap_rs = ia.imresize_single_image(img_heatmap, (img_heatmap.shape[0]*self.zoom_factor, img_heatmap.shape[1]*self.zoom_factor), interpolation="nearest") bg_img_tk = numpy_to_tk_image(img_heatmap_rs) self.background_label = Tkinter.Label(self.canvas, image=bg_img_tk) self.background_label.place(x=0, y=0, relwidth=1, relheight=1, anchor=Tkinter.NW) self.background_label.image = bg_img_tk #print("image size", image.shape) #print("image height, width", image.to_array().shape) image_rs = ia.imresize_single_image(image, (image.shape[0]*self.zoom_factor, image.shape[1]*self.zoom_factor), interpolation="nearest") image_tk = numpy_to_tk_image(image_rs) self.background_label.configure(image=image_tk) self.background_label.image = image_tk def _generate_heatmap(self): return util.draw_heatmap_overlay(self.current_state.screenshot_rs, self.grid, alpha=self.heatmap_alpha) def on_left_mouse_button(self, event): #canvas = event.widget x = self.canvas.canvasx(event.x) / self.zoom_factor y = self.canvas.canvasy(event.y) / self.zoom_factor height, width = self.current_state.screenshot_rs.shape[0:2] #x = event.x #y = event.y #canvas.delete(Tkinter.ALL) grid = self.grid normal = draw_normal_distribution(height, width, int(x), int(y), self.brush_size) #normal = np.zeros_like(grid) #normal[int(y)-2:int(y)+2, int(x)-2:int(x)+2] = 1.0 if not self.eraser: #grid = np.clip(grid + normal, 0, 1) grid = np.maximum(grid, normal) else: grid = grid - normal grid = np.clip(grid, 0, 1) self.update_annotation_grid(grid) #time.sleep(0.1) def on_right_mouse_button(self, event): x = self.canvas.canvasx(event.x) / self.zoom_factor y = self.canvas.canvasy(event.y) / self.zoom_factor height, width = self.current_state.screenshot_rs.shape[0:2] grid = self.grid normal = draw_normal_distribution(height, width, int(x), int(y), self.brush_size) grid = grid - normal grid = np.clip(grid, 0, 1) self.update_annotation_grid(grid)