# -*- coding: utf-8 -*-
"""
 @File    : generator.py
 @Time    : 2019/12/22 下午8:22
 @Author  : yizuotian
 @Description    :  中文数据生成器
"""
import random

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data.dataset import Dataset

from fontutils import FONT_CHARS_DICT


def random_color(lower_val, upper_val):
    return [random.randint(lower_val, upper_val),
            random.randint(lower_val, upper_val),
            random.randint(lower_val, upper_val)]


def put_text(image, x, y, text, font, color=None):
    """
    写中文字
    :param image:
    :param x:
    :param y:
    :param text:
    :param font:
    :param color:
    :return:
    """

    im = Image.fromarray(image)
    draw = ImageDraw.Draw(im)
    color = (255, 0, 0) if color is None else color
    draw.text((x, y), text, color, font=font)
    return np.array(im)


class Generator(Dataset):
    def __init__(self, alpha, direction='horizontal'):
        """

        :param alpha: 所有字符
        :param direction: 文字方向:horizontal|vertical
        """
        super(Generator, self).__init__()
        self.alpha = alpha
        self.direction = direction
        self.alpha_list = list(alpha)
        self.min_len = 5
        self.max_len_list = [16, 19, 24, 26]
        self.max_len = max(self.max_len_list)
        self.font_size_list = [30, 25, 20, 18]
        self.font_path_list = list(FONT_CHARS_DICT.keys())
        self.font_list = []  # 二位列表[size,font]
        for size in self.font_size_list:
            self.font_list.append([ImageFont.truetype(font_path, size=size)
                                   for font_path in self.font_path_list])
        if self.direction == 'horizontal':
            self.im_h = 32
            self.im_w = 512
        else:
            self.im_h = 512
            self.im_w = 32

    def gen_background(self):
        """
        生成背景;随机背景|纯色背景|合成背景
        :return:
        """
        a = random.random()
        pure_bg = np.ones((self.im_h, self.im_w, 3)) * np.array(random_color(0, 100))
        random_bg = np.random.rand(self.im_h, self.im_w, 3) * 100
        if a < 0.1:
            return random_bg
        elif a < 0.8:
            return pure_bg
        else:
            b = random.random()
            mix_bg = b * pure_bg + (1 - b) * random_bg
            return mix_bg

    def horizontal_draw(self, draw, text, font, color, char_w, char_h):
        """
        水平方向文字合成
        :param draw:
        :param text:
        :param font:
        :param color:
        :param char_w:
        :param char_h:
        :return:
        """
        text_w = len(text) * char_w
        h_margin = max(self.im_h - char_h, 1)
        w_margin = max(self.im_w - text_w, 1)
        x_shift = np.random.randint(0, w_margin)
        y_shift = np.random.randint(0, h_margin)
        i = 0
        while i < len(text):
            draw.text((x_shift, y_shift), text[i], color, font=font)
            i += 1
            x_shift += char_w
            y_shift = np.random.randint(0, h_margin)
            # 如果下个字符超出图像,则退出
            if x_shift + char_w > self.im_w:
                break
        return text[:i]

    def vertical_draw(self, draw, text, font, color, char_w, char_h):
        """
        锤子方向文字生成
        :param draw:
        :param text:
        :param font:
        :param color:
        :param char_w:
        :param char_h:
        :return:
        """
        text_h = len(text) * char_h
        h_margin = max(self.im_h - text_h, 1)
        w_margin = max(self.im_w - char_w, 1)
        x_shift = np.random.randint(0, w_margin)
        y_shift = np.random.randint(0, h_margin)
        i = 0
        while i < len(text):
            draw.text((x_shift, y_shift), text[i], color, font=font)
            i += 1
            x_shift = np.random.randint(0, w_margin)
            y_shift += char_h
            # 如果下个字符超出图像,则退出
            if y_shift + char_h > self.im_h:
                break
        return text[:i]

    def draw_text(self, draw, text, font, color, char_w, char_h):
        if self.direction == 'horizontal':
            return self.horizontal_draw(draw, text, font, color, char_w, char_h)
        return self.vertical_draw(draw, text, font, color, char_w, char_h)

    def gen_image(self):
        idx = np.random.randint(len(self.max_len_list))
        image = self.gen_background()
        image = image.astype(np.uint8)
        target_len = int(np.random.uniform(self.min_len, self.max_len_list[idx], size=1))

        # 随机选择size,font
        size_idx = np.random.randint(len(self.font_size_list))
        font_idx = np.random.randint(len(self.font_path_list))
        font = self.font_list[size_idx][font_idx]
        font_path = self.font_path_list[font_idx]
        # 在选中font字体的可见字符中随机选择target_len个字符
        text = np.random.choice(FONT_CHARS_DICT[font_path], target_len)
        text = ''.join(text)
        # 计算字体的w和h
        w, char_h = font.getsize(text)
        char_w = int(w / len(text))

        # 写文字,生成图像
        im = Image.fromarray(image)
        draw = ImageDraw.Draw(im)
        color = tuple(random_color(105, 255))
        text = self.draw_text(draw, text, font, color, char_w, char_h)
        target_len = len(text)  # target_len可能变小了
        # 对应的类别
        indices = np.array([self.alpha.index(c) for c in text])
        # 转为灰度图
        image = np.array(im)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # 亮度反转
        if random.random() > 0.5:
            image = 255 - image
        return image, indices, target_len

    def __getitem__(self, item):
        image, indices, target_len = self.gen_image()
        if self.direction == 'horizontal':
            image = np.transpose(image[:, :, np.newaxis], axes=(2, 1, 0))  # [H,W,C]=>[C,W,H]
        else:
            image = np.transpose(image[:, :, np.newaxis], axes=(2, 0, 1))  # [H,W,C]=>[C,H,W]
        # 标准化
        image = image.astype(np.float32) / 255.
        image -= 0.5
        image /= 0.5

        target = np.zeros(shape=(self.max_len,), dtype=np.long)
        target[:target_len] = indices
        input_len = self.im_w // 4 - 3
        return image, target, input_len, target_len

    def __len__(self):
        return len(self.alpha) * 100


def test_image_gen(direction='vertical'):
    from config import cfg
    gen = Generator(cfg.word.get_all_words()[:10], direction=direction)
    for i in range(10):
        im, indices, target_len = gen.gen_image()
        # cv2.imwrite('output/{}-{:03d}.jpg'.format(direction, i + 1), im)
        print(''.join([gen.alpha[j] for j in indices]))


def test_gen():
    from data.words import Word
    gen = Generator(Word().get_all_words())
    for x in gen:
        print(x[1])


def test_font_size():
    font = ImageFont.truetype('fonts/simsun.ttc')
    print(font.size)
    font.size = 20
    print(font.size)


if __name__ == '__main__':
    test_image_gen('horizontal')
    # test_image_gen('vertical')
    # test_gen()
    # test_font_size()