from __future__ import print_function, division import sys import os sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from lib import replay_memory from common import load_annotations, numpy_to_tk_image import Tkinter from collections import OrderedDict import imgaug as ia import cPickle as pickle class AttributeGroup(object): def __init__(self, name, name_shown, attributes=None, default_attribute=None): self.name = name self.name_shown = name_shown self.attributes = attributes if attributes is not None else [] self.default_attribute = default_attribute def append(self, att): self.attributes.append(att) def get_by_name(self, name): atts = [att for att in self.attributes if att.name == name] return atts[0] if len(atts) > 0 else None def set_default_by_name(self, name): self.default_attribute = self.get_by_name(name) class Attribute(object): def __init__(self, name, name_shown): self.name = name self.name_shown = name_shown """ ATTRIBUTE_GROUPS = [ ( "Road Type", "highway", OrderedDict([ ("country_road", "Country Road"), ("highway", "Highway"), ("highway_entry_exit", "Highway entry/exit"), ("open_area", "Open Area / Parking Lot"), ("fuel_station", "Fuel Station"), ("hotel", "Hotel"), ("rest_area", "Rest Area"), ("city_road", "City Road"), ("toll_booth", "Toll Booth") ]) ) ] """ ATTRIBUTE_GROUP_ROAD_TYPE = AttributeGroup("road_type", "Road Type") ATTRIBUTE_GROUP_ROAD_TYPE.append(Attribute("country_road", "Country Road")) ATTRIBUTE_GROUP_ROAD_TYPE.append(Attribute("highway", "Highway")) ATTRIBUTE_GROUP_ROAD_TYPE.append(Attribute("highway_entry_exit", "Highway entry/exit")) ATTRIBUTE_GROUP_ROAD_TYPE.append(Attribute("open_area", "Open Area / Parking Lot")) ATTRIBUTE_GROUP_ROAD_TYPE.append(Attribute("fuel_station", "Fuel Station")) ATTRIBUTE_GROUP_ROAD_TYPE.append(Attribute("hotel", "Hotel")) ATTRIBUTE_GROUP_ROAD_TYPE.append(Attribute("rest_area", "Rest Area")) ATTRIBUTE_GROUP_ROAD_TYPE.append(Attribute("city_road", "City Road")) ATTRIBUTE_GROUP_ROAD_TYPE.append(Attribute("toll_booth", "Toll Booth")) ATTRIBUTE_GROUP_ROAD_TYPE.append(Attribute("other", "Other")) ATTRIBUTE_GROUP_ROAD_TYPE.set_default_by_name("highway") ATTRIBUTE_GROUP_INTERSECTION = AttributeGroup("intersection", "Intersection") ATTRIBUTE_GROUP_INTERSECTION.append(Attribute("none", "None")) ATTRIBUTE_GROUP_INTERSECTION.append(Attribute("t-left", "T (left -|)")) ATTRIBUTE_GROUP_INTERSECTION.append(Attribute("t-right", "T (right |-)")) ATTRIBUTE_GROUP_INTERSECTION.append(Attribute("t-frontal", "T (frontal -.-)")) ATTRIBUTE_GROUP_INTERSECTION.append(Attribute("cross", "Cross")) ATTRIBUTE_GROUP_INTERSECTION.append(Attribute("roundabout", "Roundabout")) ATTRIBUTE_GROUP_INTERSECTION.append(Attribute("other", "Other")) ATTRIBUTE_GROUP_INTERSECTION.set_default_by_name("none") ATTRIBUTE_GROUP_DIRECTION = AttributeGroup("direction", "Direction") ATTRIBUTE_GROUP_DIRECTION.append(Attribute("unidirection", "Unidirection")) ATTRIBUTE_GROUP_DIRECTION.append(Attribute("bidirection", "Bidirection")) ATTRIBUTE_GROUP_DIRECTION.append(Attribute("other", "Other")) ATTRIBUTE_GROUP_DIRECTION.set_default_by_name("bidirection") ATTRIBUTE_GROUP_LANE_COUNT = AttributeGroup("lane-count", "Lane Count (current dir.)") ATTRIBUTE_GROUP_LANE_COUNT.append(Attribute("1", "1")) ATTRIBUTE_GROUP_LANE_COUNT.append(Attribute("2", "2")) ATTRIBUTE_GROUP_LANE_COUNT.append(Attribute("3", "3")) ATTRIBUTE_GROUP_LANE_COUNT.append(Attribute("4+", "4+")) ATTRIBUTE_GROUP_LANE_COUNT.append(Attribute("other", "Other")) ATTRIBUTE_GROUP_LANE_COUNT.set_default_by_name("2") ATTRIBUTE_GROUP_CURVE = AttributeGroup("curve", "Curve (current lane)") ATTRIBUTE_GROUP_CURVE.append(Attribute("straight", "Straight")) ATTRIBUTE_GROUP_CURVE.append(Attribute("left-slight", "Left (slight)")) ATTRIBUTE_GROUP_CURVE.append(Attribute("left-medium", "Left (medium)")) ATTRIBUTE_GROUP_CURVE.append(Attribute("left-strong", "Left (strong)")) ATTRIBUTE_GROUP_CURVE.append(Attribute("right-slight", "Right (slight)")) ATTRIBUTE_GROUP_CURVE.append(Attribute("right-medium", "Right (medium)")) ATTRIBUTE_GROUP_CURVE.append(Attribute("right-strong", "Right (strong)")) ATTRIBUTE_GROUP_CURVE.append(Attribute("other", "Other")) ATTRIBUTE_GROUP_CURVE.set_default_by_name("straight") ATTRIBUTE_GROUP_SPACE_FRONT = AttributeGroup("space-front", "Space (Front)") ATTRIBUTE_GROUP_SPACE_FRONT.append(Attribute("plenty", "plenty (>3s)")) ATTRIBUTE_GROUP_SPACE_FRONT.append(Attribute("some", "some (1-3s)")) ATTRIBUTE_GROUP_SPACE_FRONT.append(Attribute("minimal", "minimal (<1s)")) ATTRIBUTE_GROUP_SPACE_FRONT.append(Attribute("none", "none (crashing)")) ATTRIBUTE_GROUP_SPACE_FRONT.set_default_by_name("plenty") ATTRIBUTE_GROUP_SPACE_LEFT = AttributeGroup("space-left", "Space (Left)") ATTRIBUTE_GROUP_SPACE_LEFT.append(Attribute("plenty", "plenty (good)")) ATTRIBUTE_GROUP_SPACE_LEFT.append(Attribute("some", "some (meh)")) ATTRIBUTE_GROUP_SPACE_LEFT.append(Attribute("minimal", "minimal (bad)")) ATTRIBUTE_GROUP_SPACE_LEFT.append(Attribute("none", "none (crashing)")) ATTRIBUTE_GROUP_SPACE_LEFT.set_default_by_name("plenty") ATTRIBUTE_GROUP_SPACE_RIGHT = AttributeGroup("space-right", "Space (Right)") ATTRIBUTE_GROUP_SPACE_RIGHT.append(Attribute("plenty", "plenty (good)")) ATTRIBUTE_GROUP_SPACE_RIGHT.append(Attribute("some", "some (meh)")) ATTRIBUTE_GROUP_SPACE_RIGHT.append(Attribute("minimal", "minimal (bad)")) ATTRIBUTE_GROUP_SPACE_RIGHT.append(Attribute("none", "none (crashing)")) ATTRIBUTE_GROUP_SPACE_RIGHT.set_default_by_name("plenty") ATTRIBUTE_GROUP_OFFROAD = AttributeGroup("offroad", "Offroad") ATTRIBUTE_GROUP_OFFROAD.append(Attribute("onroad", "Onroad")) ATTRIBUTE_GROUP_OFFROAD.append(Attribute("slightly", "Slightly")) ATTRIBUTE_GROUP_OFFROAD.append(Attribute("significantly", "Significantly")) ATTRIBUTE_GROUP_OFFROAD.set_default_by_name("onroad") ATTRIBUTE_GROUPS = [ ATTRIBUTE_GROUP_ROAD_TYPE, ATTRIBUTE_GROUP_INTERSECTION, ATTRIBUTE_GROUP_DIRECTION, ATTRIBUTE_GROUP_LANE_COUNT, ATTRIBUTE_GROUP_CURVE, ATTRIBUTE_GROUP_SPACE_FRONT, ATTRIBUTE_GROUP_SPACE_LEFT, ATTRIBUTE_GROUP_SPACE_RIGHT, ATTRIBUTE_GROUP_OFFROAD ] def main(): print("Loading replay memory...") memory = replay_memory.ReplayMemory.create_instance_supervised() win = AttributesAnnotationWindow.create( memory, save_to_fp="annotations_attributes.pickle", every_nth_example=25 ) win.autosave_every_nth = 100 win.master.wm_title("Annotate attributes") Tkinter.mainloop() class AttributesAnnotationWindow(object): def __init__(self, master, canvas, memory, current_state_idx, annotations, save_to_fp, every_nth_example=10, zoom_factor=4): self.master = master self.canvas = canvas self.memory = memory self.current_state_idx = current_state_idx self.annotations = annotations if annotations is not None else dict() self.current_annotation = None self.background_label = None self.dirty = False self.last_autosave = 0 self.every_nth_example = every_nth_example self.zoom_factor = zoom_factor self.autosave_every_nth = 20 self.save_to_fp = save_to_fp self.is_showing_directly_previous_state = False self.directly_previous_state = None self.current_state = None self.att_group_to_variable = dict() #self.switch_to_state(self.current_state_idx, autosave=False) #self.current_state = memory.get_state_by_id(current_state_idx) @staticmethod def create(memory, save_to_fp, every_nth_example=10, zoom_factor=2): colcount = max([len(att_group.attributes) for att_group in ATTRIBUTE_GROUPS]) print("Loading previous annotations...") annotations = load_annotations(save_to_fp) #is_annotated = dict([(str(annotation.idx), True) for annotation in annotations]) current_state_idx = memory.id_min if annotations is not None: while current_state_idx < memory.id_max: key = str(current_state_idx) if key not in annotations: break current_state_idx += every_nth_example print("ID of first unannotated state: %d" % (current_state_idx,)) master = Tkinter.Tk() master.grid() state = memory.get_state_by_id(current_state_idx) canvas_height = state.screenshot_rs.shape[0] * zoom_factor canvas_width = state.screenshot_rs.shape[1] * zoom_factor print("canvas height, width:", canvas_height, canvas_width) canvas = Tkinter.Canvas(master, width=canvas_width, height=canvas_height) #canvas.pack() canvas.grid(row=0, column=0, columnspan=colcount) canvas.focus_set() #y = int(canvas_height / 2) #w.create_line(0, y, canvas_width, y, fill="#476042") message = Tkinter.Label(master, text="Press S to save.") #message.pack(side=Tkinter.BOTTOM) message.grid(row=1, column=0, columnspan=colcount) window_state = AttributesAnnotationWindow( master, canvas, memory, current_state_idx, annotations, save_to_fp, every_nth_example, zoom_factor ) def build_lambda(att_group, att): return lambda: window_state.on_radio_click(att_group, att) for row_idx, att_group in enumerate(ATTRIBUTE_GROUPS): print(row_idx) var = Tkinter.StringVar() window_state.att_group_to_variable[att_group.name] = var var.set(att_group.default_attribute.name) lab = Tkinter.Label(master, text=att_group.name_shown) lab.grid(row=2+row_idx, column=0, sticky=Tkinter.W) #lab = Tkinter.Label(master, text=att_group.name_shown).pack(side=Tkinter.LEFT) #lab = Tkinter.Label(master, text=att_group.name_shown).pack(anchor=Tkinter.S) print("default:", att_group.default_attribute.name) for col_idx, att in enumerate(att_group.attributes): print("ns/n", att.name_shown, att.name) c = Tkinter.Radiobutton( master, text=att.name_shown, variable=var, value=att.name, command=build_lambda(att_group, att) ) #c.pack(side=Tkinter.LEFT) c.grid(row=2+row_idx, column=col_idx+1, sticky=Tkinter.W) print(row_idx, col_idx) canvas.bind("<s>", lambda event: window_state.save_annotations(force=True)) canvas.bind("<p>", lambda event: window_state.toggle_previous_screenshot()) canvas.bind("<Left>", lambda event: window_state.previous_state(autosave=True)) canvas.bind("<Right>", lambda event: window_state.next_state(autosave=True)) window_state.switch_to_state(window_state.current_state_idx, autosave=False) return window_state def on_radio_click(self, att_group, att): print("radio click", att_group.name, att.name) var = self.att_group_to_variable[att_group.name] var.set(att.name) self.current_annotation["attributes"][att_group.name] = att.name self.dirty = True #def update_annotations(self): def toggle_previous_screenshot(self): if self.directly_previous_state is not None: if self.is_showing_directly_previous_state: self.set_canvas_background(self._generate_heatmap()) else: self.set_canvas_background(self.directly_previous_state.screenshot_rs) self.is_showing_directly_previous_state = not self.is_showing_directly_previous_state def previous_state(self, autosave): print("Switching to previous state...") self.current_state_idx -= self.every_nth_example assert self.current_state_idx >= self.memory.id_min, "Start of memory reached (%d vs %d)" % (self.current_state_idx, self.memory.id_min) self.switch_to_state(self.current_state_idx, autosave=autosave) def next_state(self, autosave): print("Switching to next state...") self.current_state_idx += self.every_nth_example assert self.current_state_idx <= self.memory.id_max, "End of memory reached (%d vs %d)" % (self.current_state_idx, self.memory.id_max) self.switch_to_state(self.current_state_idx, autosave=autosave) def switch_to_state(self, idx, autosave): print("Switching to state %d (autosave=%s)..." % (idx, str(autosave))) self.directly_previous_state = self.memory.get_state_by_id(idx-1) self.current_state = self.memory.get_state_by_id(idx) assert self.current_state is not None self.current_state_idx = idx if autosave: if (self.last_autosave+1) % self.autosave_every_nth == 0: # only autosaves if dirty flag is true, ie any example was changed self.save_annotations() self.last_autosave = 0 else: self.last_autosave += 1 print("last_autosave=", self.last_autosave) key = str(self.current_state_idx) if key in self.annotations: self.current_annotation = self.annotations[key] print("Annotation for state ", key, " available.") print("Attributes: ", self.annotations[key]["attributes"]) else: print("No annotation yet for state ", key) last_annotation = self.current_annotation self.current_annotation = { "idx": self.current_state_idx, "from_datetime": self.current_state.from_datetime, "screenshot_rs": self.current_state.screenshot_rs, "attributes": dict() } for att_group in ATTRIBUTE_GROUPS: if last_annotation is not None: self.current_annotation["attributes"][att_group.name] = last_annotation["attributes"][att_group.name] else: self.current_annotation["attributes"][att_group.name] = att_group.default_attribute.name self.annotations[key] = self.current_annotation print("Initializing attributes to ", self.current_annotation["attributes"]) # set variables to new annotation state for att_group_name in self.current_annotation["attributes"]: var = self.att_group_to_variable[att_group_name] val = self.current_annotation["attributes"][att_group_name] var.set(val) self.is_showing_directly_previous_state = False self.set_canvas_background(self._generate_heatmap()) #self.update_annotation_grid(self.grid, initial=True) def save_annotations(self, force=False): #print(self.annotations) if self.dirty or force: print("Saving to %s..." % (self.save_to_fp,)) with open(self.save_to_fp, "w") as f: pickle.dump(self.annotations, f, protocol=-1) self.dirty = False print("Finished saving.") else: print("Not saved (not marked dirty)") def set_canvas_background(self, image): if self.background_label is None: # initialize background image label (first call) #img = self.current_state.screenshot_rs #bg_img_tk = numpy_to_tk_image(np.zeros(img.shape)) img_heatmap = self._generate_heatmap() img_heatmap_rs = ia.imresize_single_image(img_heatmap, (img_heatmap.shape[0]*self.zoom_factor, img_heatmap.shape[1]*self.zoom_factor), interpolation="nearest") bg_img_tk = numpy_to_tk_image(img_heatmap_rs) self.background_label = Tkinter.Label(self.canvas, image=bg_img_tk) self.background_label.place(x=0, y=0, relwidth=1, relheight=1, anchor=Tkinter.NW) self.background_label.image = bg_img_tk #print("image size", image.shape) #print("image height, width", image.to_array().shape) image_rs = ia.imresize_single_image(image, (image.shape[0]*self.zoom_factor, image.shape[1]*self.zoom_factor), interpolation="nearest") image_tk = numpy_to_tk_image(image_rs) self.background_label.configure(image=image_tk) self.background_label.image = image_tk def _generate_heatmap(self): #return util.draw_heatmap_overlay(self.current_state.screenshot_rs, self.grid, alpha=self.heatmap_alpha) return self.current_state.screenshot_rs if __name__ == "__main__": main()