# -*- coding: utf-8 -*- """ @File : rest.py @Time : 2020/4/6 上午9:39 @Author : yizuotian @Description : restful服务 """ import argparse import base64 import itertools import sys import cv2 import numpy as np import torch import tornado.httpserver import tornado.wsgi from flask import Flask, request import crnn from config import cfg app = Flask(__name__) app.config["JSON_AS_ASCII"] = False def pre_process_image(image, h, w): """ :param image: [H,W] :param h: 图像高度 :param w: 图像宽度 :return: """ if h != 32 and h < w: new_w = int(w * 32 / h) image = cv2.resize(image, (new_w, 32)) if w != 32 and w < h: new_h = int(h * 32 / w) image = cv2.resize(image, (32, new_h)) if h < w: image = np.array(image).T # [W,H] image = image.astype(np.float32) / 255. image -= 0.5 image /= 0.5 image = image[np.newaxis, np.newaxis, :, :] # [B,C,W,H] return image def inference(image, h, w): """ 预测图像 :param image: [H,W] :param h: 图像高度 :param w: 图像宽度 :return: text """ image = torch.FloatTensor(image) if h > w: predict = v_net(image)[0].detach().cpu().numpy() # [W,num_classes] else: predict = h_net(image)[0].detach().cpu().numpy() # [W,num_classes] image.to(device) label = np.argmax(predict[:], axis=1) label = [alpha[class_id] for class_id in label] label = [k for k, g in itertools.groupby(list(label))] # label = ''.join(label).replace(' ', '') return label @app.route('/crnn', methods=['POST']) def ocr_rest(): """ :return: """ img = base64.decodebytes(request.form.get('img').encode()) img = np.frombuffer(img, dtype=np.uint8) h, w = request.form.getlist('shape', type=int) img = img.reshape((h, w)) # 预处理 img = pre_process_image(img, h, w) # 预测 text = inference(img, h, w) text = ''.join(text) print("text:{}".format(text)) return {'text': text} def start_tornado(app, port=5000): http_server = tornado.httpserver.HTTPServer( tornado.wsgi.WSGIContainer(app)) http_server.listen(port) print("Tornado server starting on port {}".format(port)) tornado.ioloop.IOLoop.instance().start() if __name__ == '__main__': """ Usage: export KMP_DUPLICATE_LIB_OK=TRUE python rest.py -l output/crnn.horizontal.061.pth -v output/crnn.vertical.090.pth -d cuda """ parse = argparse.ArgumentParser() parse.add_argument('-l', "--weight-path-horizontal", type=str, default=None, help="weight path") parse.add_argument('-v', "--weight-path-vertical", type=str, default=None, help="weight path") parse.add_argument('-d', "--device", type=str, default='cpu', help="cpu or cuda") args = parse.parse_args(sys.argv[1:]) alpha = cfg.word.get_all_words() device = torch.device('cuda' if args.device == 'cuda' and torch.cuda.is_available() else 'cpu') # 加载权重,水平方向 h_net = crnn.CRNN(num_classes=len(alpha)) h_net.load_state_dict(torch.load(args.weight_path_horizontal, map_location='cpu')['model']) h_net.eval() h_net.to(device) # 垂直方向 v_net = crnn.CRNNV(num_classes=len(alpha)) v_net.load_state_dict(torch.load(args.weight_path_vertical, map_location='cpu')['model']) v_net.eval() v_net.to(device) # 启动restful服务 start_tornado(app, 5000)