#!/usr/bin/env python # -------------------------------------------------------- # Tensorflow Faster R-CNN # Licensed under The MIT License [see LICENSE for details] # Written by Xinlei Chen, based on code from Ross Girshick # -------------------------------------------------------- """ Demo script showing detections in sample images. See README.md for installation instructions before running. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import glob import time import _init_paths from model.config import cfg from model.test import im_detect from model.nms_wrapper import nms from text_connector import TextDetector from utils.timer import Timer import tensorflow as tf import matplotlib.pyplot as plt import numpy as np import os import cv2 import argparse from nets.vgg16 import vgg16 from nets.resnet_v1 import Resnetv1 from nets.squeezenet import SqueezeNet from nets.mobilenet_v2 import MobileNetV2 from utils import helper CLASSES = ('__background__', 'text') def vis_detections(im, class_name, dets, thresh=0.5, text=False): """Draw detected bounding boxes.""" inds = np.where(dets[:, -1] >= thresh)[0] if len(inds) == 0: return im = im[:, :, (2, 1, 0)] fig, ax = plt.subplots(figsize=(12, 12)) ax.imshow(im, aspect='equal') for i in inds: bbox = dets[i, :8] score = dets[i, -1] ax.add_line( plt.Line2D([bbox[0], bbox[2], bbox[6], bbox[4], bbox[0]], [bbox[1], bbox[3], bbox[7], bbox[5], bbox[1]], color='red', linewidth=3) ) if text: ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5), fontsize=14, color='white') ax.set_title(('{} detections with ' 'p({} | box) >= {:.1f}').format(class_name, class_name, thresh), fontsize=14) plt.axis('off') plt.tight_layout() plt.draw() plt.show() def save_result(img, img_name, text_lines, result_dir): dst = img color = (0, 150, 0) for bbox in text_lines: bbox = [int(x) for x in bbox] p1 = (bbox[0], bbox[1]) p2 = (bbox[2], bbox[3]) p3 = (bbox[6], bbox[7]) p4 = (bbox[4], bbox[5]) dst = cv2.line(dst, p1, p2, color, 2) dst = cv2.line(dst, p2, p3, color, 2) dst = cv2.line(dst, p3, p4, color, 2) dst = cv2.line(dst, p4, p1, color, 2) img_path = os.path.join(result_dir, img_name[0:-4] + '.jpg') cv2.imwrite(img_path, dst) def recover_scale(boxes, scale): """ :param boxes: [(x1, y1, x2, y2)] :param scale: image scale :return: """ tmp_boxes = [] for b in boxes: tmp_boxes.append([int(x / scale) for x in b]) return np.asarray(tmp_boxes).astype(np.float32) def draw_rpn_boxes(img, img_name, boxes, scores, im_scale, nms, save_dir): """ :param boxes: [(x1, y1, x2, y2)] """ boxes = recover_scale(boxes, im_scale) base_name = img_name.split('/')[-1] color = (0, 255, 0) out = img.copy() if nms: boxes, scores = TextDetector.pre_process(boxes, scores) file_name = "%s_rpn_nms.jpg" % base_name else: file_name = "%s_rpn.jpg" % base_name for i, box in enumerate(boxes): cv2.rectangle(out, (box[0], box[1]), (box[2], box[3]), color, 2) cx = int((box[0] + box[2]) / 2) cy = int((box[1] + box[3]) / 2) cv2.putText(out, "%.01f" % scores[i], (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.2, (255, 0, 0)) cv2.imwrite(os.path.join(save_dir, file_name), out) def demo(sess, net, im_file, result_dir, viz=False, oriented=False): """Detect object classes in an image using pre-computed object proposals.""" # Load the demo image im = helper.read_rgb_img(im_file) # Detect all object classes and regress object bounds timer = Timer() timer.tic() scores, boxes, resized_im_shape, im_scale = im_detect(sess, net, im) timer.toc() im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) img_name = im_file.split('/')[-1] draw_rpn_boxes(im, img_name, boxes, scores[:, np.newaxis], im_scale, True, result_dir) draw_rpn_boxes(im, img_name, boxes, scores[:, np.newaxis], im_scale, False, result_dir) # Run TextDetector to merge small box line_detector = TextDetector(oriented) # line_detector 的输入必须是在 scale 之后的图片上!!, # 如果还原了以后再进行行构建,原图可能太大,导致每个 anchor 的 width 很大,导致 MAX_HORIZONTAL_GAP 太小 # text_lines point order: left-top, right-top, left-bottom, right-bottom text_lines = line_detector.detect(boxes, scores[:, np.newaxis], resized_im_shape) print("Image %s, detect %d text lines in %.3fs" % (im_file, len(text_lines), timer.diff)) if len(text_lines) != 0: text_lines = recover_scale(text_lines, im_scale) save_result(im, img_name, text_lines, result_dir) # Visualize detections if viz: vis_detections(im, CLASSES[1], text_lines) def parse_args(): """Parse input arguments.""" parser = argparse.ArgumentParser(description='Tensorflow CTPN demo') parser.add_argument('--net', dest='net', choices=['vgg16', 'squeeze', 'mobile'], default='vgg16') parser.add_argument('--img_dir', default='/home/cwq/data/ICDAR13/123') parser.add_argument('--dataset', dest='dataset', help='model tag', default='voc_2007_trainval') parser.add_argument('--tag', dest='tag', help='model tag', default='vgg_latin_chn_newdata') parser.add_argument('--viz', action='store_true', default=False, help='show result') parser.add_argument('-o', '--oriented', action='store_true', default=False, help='output rotated detect box') args = parser.parse_args() if not os.path.exists(args.img_dir): print("img dir not exists.") exit(-1) args.result_dir = os.path.join('./data/result', args.tag) if not os.path.exists(args.result_dir): os.makedirs(args.result_dir) return args if __name__ == '__main__': args = parse_args() # model path netname = args.net dataset = args.dataset ckpt_dir = os.path.join('output', netname, dataset, args.tag) ckpt = tf.train.get_checkpoint_state(ckpt_dir) # set config tfconfig = tf.ConfigProto(allow_soft_placement=True) tfconfig.gpu_options.allow_growth = True # init session sess = tf.Session(config=tfconfig) # load network if netname == 'vgg16': net = vgg16() elif netname == 'res101': net = Resnetv1(num_layers=101) elif netname == 'mobile': net = MobileNetV2() elif args.net == 'squeeze': net = SqueezeNet() else: raise NotImplementedError net.create_architecture("TEST", num_classes=len(CLASSES), tag=args.tag, anchor_width=cfg.CTPN.ANCHOR_WIDTH, anchor_h_ratio_step=cfg.CTPN.H_RADIO_STEP, num_anchors=cfg.CTPN.NUM_ANCHORS) saver = tf.train.Saver() saver.restore(sess, ckpt.model_checkpoint_path) print('Loaded network {:s}'.format(ckpt.model_checkpoint_path)) im_files = glob.glob(args.img_dir + "/*.*") for im_file in im_files: demo(sess, net, im_file, args.result_dir, args.viz, args.oriented)