import time
from multiprocessing import cpu_count
from multiprocessing.dummy import Pool as ThreadPool

import cv2
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import save_image

from Dataloader import EvaluateSet
from models.text_segmentation import TextSegament, XceptionTextSegment


def draw_bounding_box(img, mask, area_threshold=100, color=(50, 128, 30)):
    """Paint a filled convex hull over every sufficiently large mask region.

    Args:
        img: Image array that is drawn on IN PLACE and also returned.
            NOTE(review): in this script it is an RGB uint8 array built from a
            PIL image; any array accepted by ``cv2.drawContours`` works.
        mask: Mask passed to ``cv2.findContours``. In this script it is an
            8-bit single-channel array whose nonzero pixels are foreground.
        area_threshold: Contours whose ``cv2.contourArea`` is not strictly
            greater than this value are skipped.
        color: Fill colour handed to ``cv2.drawContours``, interpreted in
            ``img``'s channel order. Defaults to the original hard-coded value.

    Returns:
        ``img``, after drawing.
    """
    # findContours returns (contours, hierarchy) on OpenCV 2.4 / 4.x but
    # (image, contours, hierarchy) on 3.x; index -2 is the contour list on all.
    contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    for contour in contours:
        if cv2.contourArea(contour) > area_threshold:
            hull = cv2.convexHull(contour)
            # thickness=-1 fills the hull instead of outlining it.
            cv2.drawContours(img, [hull], 0, color, -1)

    return img


def process(eval_img, device='cpu'):
    """Segment text in one evaluation sample and write the result images.

    Runs the module-level ``model`` on the input and thresholds the sigmoid
    output at 0.5. It then grows the binary mask by one pixel with a 3x3,
    stride-1 max-pool and applies the sample's ``unpadder``. Two files are
    written:

    * ``<file_name>_mask.jpg``: the binary mask.
    * ``<file_name>_contour.jpg``: the original image with filled convex
      hulls drawn over mask regions whose contour area exceeds 500 pixels.

    Args:
        eval_img: One ``EvaluateSet`` item, unpacked as
            ``((img, origin, unpadder), file_name)``. NOTE(review): presumably
            ``img`` is the normalised, padded network input and ``origin`` a
            batch holding the unnormalised image. Confirm against Dataloader.
        device: Device ``img`` is moved to. It must match the device the
            global ``model`` lives on.
    """
    (img, origin, unpadder), file_name = eval_img
    with torch.no_grad():
        logits = model(img.to(device))

    mask = torch.sigmoid(logits) > 0.5
    # A stride-1, 3x3 max-pool over a 0/1 mask is a one-pixel dilation.
    mask = F.max_pool2d(mask.float(), kernel_size=3, stride=1, padding=1).byte()
    mask = unpadder(mask)
    mask = mask.float().cpu()

    save_image(mask, file_name + '_mask.jpg')

    origin_np = np.array(to_pil_image(origin[0]))
    mask_u8 = np.array(to_pil_image(mask[0]).convert("L"), dtype='uint8')
    contour_img = draw_bounding_box(origin_np, mask_u8, 500)
    Image.fromarray(contour_img).save(file_name + "_contour.jpg")


# Run on the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = XceptionTextSegment()
model.total_parameters()

# the model is trained with in-place batch norm, but the weights are compatible with torch's batch norm
model.load_state_dict(torch.load("checkpoints/text_seg_model_681epos.pt", map_location='cpu'))
# eval() switches batch norm to its running statistics (and disables any
# dropout). Without it, inference normalises each image with its own batch stats.
model = model.to(device).eval()

evalset = EvaluateSet(mean=[0.4935, 0.4563, 0.4544],
                      std=[0.3769, 0.3615, 0.3566],
                      img_folder='test_data',
                      resize=600)

# perf_counter is the monotonic, high-resolution clock intended for timing.
# process() ends with a .cpu() copy, which waits for queued GPU work to finish.
start = time.perf_counter()
for sample in evalset:
    process(sample, device)

print("Runtime :{}".format(time.perf_counter() - start))