"""Interactive DeepGrabCut demo.

Shows images from ``ims/*.jpg`` in an OpenCV window, lets the user draw a
closed outline with the mouse, turns that outline into a distance map and
feeds image + distance map to the network to obtain a foreground mask,
which is blended into the displayed image.

Keys (compared against ``cv2.waitKey(1) & 0xFF``):

* ``s``     - reload the current image and enter outline mode
* ``space`` - (outline mode) run the segmentation on the drawn outline
* ``esc``   - (outline mode) discard the strokes drawn so far
* ``d``     - write the displayed image to ``./<index>out.png``
* ``n``/``p`` - next / previous image (wrapping around)
* ``c``     - quit
"""
import os
from copy import deepcopy
from glob import glob

import cv2
import numpy as np
import torch
from torch.nn.functional import upsample

from dataloaders import utils
import networks.deeplab_resnet as resnet

WINDOW_NAME = 'draw'
MODEL_PATH = os.path.join('models/', 'deepgc_pascal_epoch-99.pth')

# Key codes handled by the event loops.
KEY_SAVE = ord('d')    # 100
KEY_START = ord('s')   # 115
KEY_RUN = 32           # space
KEY_RESET = 27         # esc
KEY_QUIT = ord('c')    # 99
KEY_NEXT = ord('n')    # 110
KEY_PREV = ord('p')    # 112

# Value written by the flood fill into the region enclosed by the outline.
FILL_VALUE = 5

# --- state shared between the mouse callback and the event loops -----------
drawing = False  # true if mouse is pressed
start = False
# Previous stroke point (ix, iy) and first stroke point (xs, ys).
ix = iy = xs = ys = None

# TODO: import test image path
image_list = glob(os.path.join('ims/', '*.'+'jpg'))
if not image_list:
    # Indexing the empty list below would raise IndexError anyway; say why.
    raise IndexError("No '*.jpg' images found in 'ims/'")
img_shape = (450, 450)


def _load_image(index):
    """Read ``image_list[index]`` and resize it to ``img_shape`` as uint8.

    Raises:
        IOError: if OpenCV cannot read the file (``cv2.imread`` gave None).
    """
    path = image_list[index]
    raw = cv2.imread(path)
    if raw is None:
        raise IOError("Could not read image: {}".format(path))
    return utils.fixed_resize(raw, img_shape).astype(np.uint8)


image = _load_image(0)

w = img_shape[0]
h = img_shape[1]
# Canvas that receives the user's strokes in white on black.
output = np.zeros((w, h, 3), np.uint8)

# Probability above which a pixel is reported as foreground.
thres = 0.8
# Bounding box of the strokes; (0xFFF, 0, 0xFFF, 0) means "nothing drawn".
left = 0xFFF
right = 0
up = 0xFFF
down = 0

gpu_id = 0
device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")

# Create the network and load the weights
net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
print("Initializing weights from: {}".format(MODEL_PATH))
state_dict_checkpoint = torch.load(MODEL_PATH,
                                   map_location=lambda storage, loc: storage)
net.load_state_dict(state_dict_checkpoint)
net.eval()
net.to(device)


def _grow_bbox(x, y):
    """Extend the stroke bounding box so that it contains (x, y)."""
    global left, right, up, down
    left = min(left, x)
    right = max(right, x)
    up = min(up, y)
    down = max(down, y)


def _draw_segment(start_pt, end_pt):
    """Draw one stroke segment on the displayed image and on the canvas."""
    cv2.line(image, start_pt, end_pt, (0, 0, 255), 2)
    cv2.line(output, start_pt, end_pt, (255, 255, 255), 1)


def interactive_drawing(event, x, y, flag, params):
    """Mouse callback: record a free-hand outline drawn with the left button.

    Button down starts a stroke, mouse moves extend it, and button up draws
    the last segment and then closes the outline back to the first point.

    Args:
        event: OpenCV mouse event code.
        x, y: pointer position reported with the event.
        flag, params: unused; present to match the callback signature.

    Returns:
        The ``(x, y)`` pair the callback was invoked with.
    """
    global xs, ys, ix, iy, drawing
    if event == cv2.EVENT_LBUTTONDOWN:
        drawing = True
        ix, iy = x, y
        xs, ys = x, y
        _grow_bbox(x, y)

    elif event == cv2.EVENT_MOUSEMOVE:
        if drawing:
            _draw_segment((ix, iy), (x, y))
            ix, iy = x, y
            _grow_bbox(x, y)

    elif event == cv2.EVENT_LBUTTONUP:
        # Ignore a release that has no matching press: the stroke points
        # would be unset (or stale) and there is nothing to close.
        if drawing:
            drawing = False
            _draw_segment((ix, iy), (x, y))
            ix, iy = x, y
            # The release point is part of the outline, so count it too.
            _grow_bbox(x, y)
            # Close the outline back to where the stroke started.
            _draw_segment((ix, iy), (xs, ys))

    return x, y


def _reset_canvas(image_idx):
    """Reload image ``image_idx`` and forget every stroke drawn so far."""
    global image, output, drawing, left, right, up, down
    left, right, up, down = 0xFFF, 0, 0xFFF, 0
    drawing = False
    image = _load_image(image_idx)
    output = np.zeros((image.shape[0], image.shape[1], 3), np.uint8)


def _collect_outline(image_idx):
    """Show the image until space is pressed; esc discards the strokes."""
    while True:
        cv2.imshow(WINDOW_NAME, image)
        key = cv2.waitKey(1) & 0xFF
        if key == KEY_RUN:
            return
        if key == KEY_RESET:
            # Also clears the bounding box, so the flood-fill seed is not
            # computed from strokes that were just erased.
            _reset_canvas(image_idx)


def _outline_seed(canvas):
    """Return the bounding-box centre as an (x, y) seed inside ``canvas``."""
    rows, cols = canvas.shape[:2]
    seed_x = int((left + right) / 2)
    seed_y = int((up + down) / 2)
    # Defensive clamp: the flood fill needs a seed inside the image.
    return (min(max(seed_x, 0), cols - 1), min(max(seed_y, 0), rows - 1))


def _build_distance_map(strokes, seed):
    """Turn the stroke canvas into the distance-map input channel.

    Each pixel holds its Euclidean distance to the nearest stroke pixel,
    negated inside the region flood-filled from ``seed``, offset by 128 and
    clipped to [0, 255].

    Args:
        strokes: stroke canvas; channel 0 is non-zero on the strokes.
        seed: ``(x, y)`` start point for the flood fill.

    Returns:
        uint8 array with a trailing channel axis of size 1.
    """
    on_stroke = (strokes[:, :, 0] > 0).astype(np.uint8)
    region = deepcopy(on_stroke)
    # floodFill wants a mask 2 pixels larger than the image; non-zero mask
    # pixels (the border and the strokes) stop the fill.
    fill_mask = np.ones((on_stroke.shape[0] + 2, on_stroke.shape[1] + 2), np.uint8)
    fill_mask[1:-1, 1:-1] = on_stroke
    cv2.floodFill(region, fill_mask, seed, FILL_VALUE)

    sign = region.astype(np.int8)
    sign[sign == FILL_VALUE] = -1  # pixel inside the drawn outline
    sign[sign == 0] = 1            # pixel outside the outline (strokes are already 1)

    off_stroke = (on_stroke == 0).astype(np.uint8)
    # compute distance inside and outside the outline
    distance = cv2.distanceTransform(off_stroke, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
    dismap = np.clip(sign * distance + 128, 0, 255)

    dismap = utils.fixed_resize(dismap, img_shape).astype(np.uint8)
    return np.expand_dims(dismap, axis=-1)


def _predict_mask(image_bgr, dismap):
    """Run the network and return a uint8 mask (255 foreground, 0 elsewhere).

    Args:
        image_bgr: image in OpenCV channel order.
        dismap: distance map with a trailing channel axis.

    Returns:
        uint8 array with a trailing channel axis of size 1.
    """
    image_rgb = image_bgr[:, :, ::-1]  # change to rgb
    merge_input = np.concatenate((image_rgb, dismap), axis=2).astype(np.float32)
    inputs = torch.from_numpy(merge_input.transpose((2, 0, 1))[np.newaxis, ...])

    # Run a forward pass; inference only, so no autograd graph is needed.
    with torch.no_grad():
        logits = net(inputs.to(device))
        # NOTE: `upsample` is deprecated in newer PyTorch in favour of
        # `interpolate`; kept because it is the name this file imports.
        logits = upsample(logits, size=img_shape, mode='bilinear', align_corners=True)
    logits = logits.detach().to(torch.device('cpu'))

    scores = np.transpose(logits.numpy()[0, ...], (1, 2, 0))
    probability = np.squeeze(1 / (1 + np.exp(-scores)))  # sigmoid
    mask = np.where(probability > thres, 255, 0).astype(np.uint8)
    return np.expand_dims(mask, axis=-1)


def _overlay_mask(image_bgr, mask):
    """Blend ``mask`` into the red channel of the BGR image."""
    empty = np.zeros_like(mask)
    display_mask = np.concatenate([empty, empty, mask], axis=-1)
    return cv2.addWeighted(image_bgr, 0.9, display_mask, 0.5, 0.1)


def main():
    """Run the interactive window until the quit key is pressed."""
    global image, output
    cv2.namedWindow(WINDOW_NAME, flags=cv2.WINDOW_NORMAL)
    cv2.setMouseCallback(WINDOW_NAME, interactive_drawing)
    image_idx = 0

    try:
        while True:
            cv2.imshow(WINDOW_NAME, image)
            k = cv2.waitKey(1) & 0xFF

            if k == KEY_SAVE:  # D
                cv2.imwrite('./' + str(image_idx) + 'out.png', image)

            elif k == KEY_START:
                _reset_canvas(image_idx)
                _collect_outline(image_idx)
                if left > right or up > down:
                    # Bounding box still at its sentinel: nothing was drawn,
                    # so there is no valid flood-fill seed.
                    print('No outline drawn; skipping segmentation.')
                    continue
                dismap = _build_distance_map(output, _outline_seed(output))
                # Kept from the original flow; presumably a no-op while the
                # canvas already has the size in img_shape.
                output = cv2.resize(output, img_shape)
                image = _overlay_mask(image, _predict_mask(image, dismap))

            elif k == KEY_QUIT:
                break

            elif k in (KEY_NEXT, KEY_PREV):
                image_idx += 1 if k == KEY_NEXT else -1
                if image_idx >= len(image_list):
                    print('Already the last image. Starting from the beginning.')
                    image_idx = 0
                elif image_idx < 0:
                    print('Reached the first image. Starting from the end.')
                    image_idx = len(image_list)-1
                _reset_canvas(image_idx)
    finally:
        # Close the window even when the loop exits through an exception.
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()