import os
import cv2
import sys
import time
import collections
import torch
import argparse
import numpy as np
import params
import torchvision.transforms as transforms

from torch.autograd import Variable
from torch.utils import data

from Detection.PSEnet import models
from Detection.PSEnet import util
from Detection.PSEnet.pypse import pse as pypse
from PIL import Image
from Recognition.crnn.models import crnn
from Recognition.crnn import utils

def scaleimg(img, long_size=2240):
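    # Rescale so the longest image side equals long_size (aspect ratio preserved);
    # PSENet runs on this resized copy and the boxes are mapped back to the
    # original resolution later in detect().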
    h, w = img.shape[0:2]
    scale = long_size * 1.0 / max(h, w)
    img = cv2.resize(img, dsize=None, fx=scale, fy=scale)
    return img

def crop(img, bbox):
    # Crop the axis-aligned bounding box of a detected region, then convert it to
    # the 100x32 grayscale tensor expected by the CRNN recognizer.
    bbox = bbox.reshape(4, 2)
    topleft_x = max(np.min(bbox[:, 0]), 0)
    topleft_y = max(np.min(bbox[:, 1]), 0)
    bot_right_x = np.max(bbox[:, 0])
    bot_right_y = np.max(bbox[:, 1])
    cropped_img = img[topleft_y:bot_right_y, topleft_x:bot_right_x]
    cropped_img = cv2.resize(cropped_img, (100, 32))
    cropped_img = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2GRAY)
    cropped_img = Image.fromarray(cropped_img)
    cropped_img = transforms.ToTensor()(cropped_img)
    return cropped_img

def drawBBox(bboxs,img):
    for bbox in bboxs:
        bbox = np.reshape(bbox,(4,2))
        cv2.drawContours(img, [bbox],-1, (0, 255, 0), 2)
    cv2.imwrite('result.jpg',img)

def detect(org_img):
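    # Build the PSENet backbone selected in params, load the trained checkpoint,
    # run a single forward pass on the resized image, and post-process the
    # predicted kernels with progressive scale expansion (PSE) into word boxes.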
    if params.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=7, scale=params.scale)
    elif params.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=7, scale=params.scale)
    elif params.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=7, scale=params.scale)
    else:
        raise ValueError("Unsupported backbone: {}".format(params.arch))
    for param in model.parameters():
        param.requires_grad = False

    model = model.cuda()

    if params.PSEnet_path is not None:                                         
        if os.path.isfile(params.PSEnet_path):
            print("Loading model and optimizer from checkpoint '{}'".format(params.PSEnet_path))
            checkpoint = torch.load(params.PSEnet_path)
            
            # Checkpoints saved from nn.DataParallel prefix every key with
            # 'module.'; strip it so the keys match this single-GPU model.
            d = collections.OrderedDict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:] if key.startswith('module.') else key
                d[tmp] = value
            model.load_state_dict(d)

            print("Loaded checkpoint '{}' (epoch {})"
                  .format(params.PSEnet_path, checkpoint['epoch']))
            sys.stdout.flush()
        else:
            print("No checkpoint found at '{}'".format(params.PSEnet_path))
            sys.stdout.flush()

    model.eval()
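    # Preprocess: BGR -> RGB, resize the long side to 2240, then normalize
    # (the mean/std below are presumably the detector's training statistics).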
    scaled_img = scaleimg(org_img[:,:,[2,1,0]])
    scaled_img = Image.fromarray(scaled_img)
    scaled_img = scaled_img.convert('RGB')
    scaled_img = transforms.ToTensor()(scaled_img)
    scaled_img = transforms.Normalize(mean=[0.0618, 0.1206, 0.2677], std=[1.0214, 1.0212, 1.0242])(scaled_img)
    scaled_img = torch.unsqueeze(scaled_img,0)
    scaled_img = Variable(scaled_img.cuda())

    outputs = model(scaled_img)
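    # Channel 0 of the output is the full text-region map; the remaining
    # kernel_num-1 channels are progressively shrunk kernels. The sigmoid of
    # channel 0 is kept as a confidence score, then all maps are binarized
    # at binary_th.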

    score = torch.sigmoid(outputs[:, 0, :, :])
    outputs = (torch.sign(outputs - params.binary_th) + 1) / 2

    text = outputs[:, 0, :, :]
    kernels = outputs[:, 0:params.kernel_num, :, :] * text

    score = score.data.cpu().numpy()[0].astype(np.float32)
    text = text.data.cpu().numpy()[0].astype(np.uint8)
    kernels = kernels.data.cpu().numpy()[0].astype(np.uint8)
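    # Progressive scale expansion: grow text instances outward from the smallest
    # kernel; the area threshold is divided by scale^2 because the prediction
    # maps are downsampled by params.scale relative to the resized image.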
    pred = pypse(kernels, params.min_kernel_area / (params.scale * params.scale))

    scale = (org_img.shape[1] * 1.0 / pred.shape[1], org_img.shape[0] * 1.0 / pred.shape[0])
    label = pred
    label_num = np.max(label) + 1
    bboxes = []
    for i in range(1, label_num):
        points = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1]

        if points.shape[0] < params.min_area / (params.scale * params.scale):
            continue

        score_i = np.mean(score[label == i])
        if score_i < params.min_score:
            continue

        rect = cv2.minAreaRect(points)
        bbox = cv2.boxPoints(rect) * scale
        bbox = bbox.astype('int32')
        bboxes.append(bbox.reshape(-1))
    # Draw on a copy so the recognizer later crops from the unmarked original image.
    drawBBox(bboxes, org_img.copy())
    return bboxes

def recognise(bboxes,org_img):
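    # For every detected box: crop it from the original image, resize it to the
    # CRNN's 100x32 grayscale input, and decode the prediction with greedy CTC.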
    nclass = len(params.alphabet) + 1
    model = crnn.CRNN(params.imgH, params.nc, nclass, params.nh)
    if torch.cuda.is_available():
        model = model.cuda()

    # load model
    print('loading pretrained model from %s' % params.crnn_path)
    if params.multi_gpu:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(params.crnn_path, map_location='cuda' if torch.cuda.is_available() else 'cpu'))
    model.eval()
    converter = utils.strLabelConverter(params.alphabet)
    print('PREDICTION:')
    for bbox in bboxes:
        cropped_img = crop(org_img,bbox)
        image = cropped_img
        if torch.cuda.is_available():
            image = image.cuda()
        image = image.view(1, *image.size())
        image = Variable(image)
        preds = model(image)
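        # Greedy CTC decoding: take the argmax over the alphabet at each time
        # step; the raw string keeps blanks/repeats, the sim string collapses them.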

        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)

        preds_size = Variable(torch.IntTensor([preds.size(0)]))
        raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
        sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
        print('%-20s => %-20s' % (raw_pred, sim_pred))


def main(args):
    print('reading image..')
    image = cv2.imread(args.img)
    if image is None:
        sys.exit('Could not read image: {}'.format(args.img))
    print('detecting text')
    bboxes = detect(image)
    print('recognizing text')
    recognise(bboxes, image)

if __name__ == '__main__':
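    # Usage: pass the test image path via --img; detections are written to
    # result.jpg and the recognized strings are printed to stdout.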

    parser = argparse.ArgumentParser(description='PSENet text detection + CRNN text recognition demo')
    parser.add_argument('--img', nargs='?', type=str, default='demo/tr_img_09961.jpg',
                        help='Path to test image')
    args = parser.parse_args()
    main(args)