import json
import os
import random
from os import path
from string import ascii_uppercase, digits, punctuation

import colorama
import numpy
import regex
import torch
from colorama import Fore
from torch.utils import data

from my_classes import TextBox, TextLine
from my_utils import robust_padding

VOCAB = ascii_uppercase + digits + punctuation + " \t\n"


class MyDataset(data.Dataset):
    def __init__(
        self, dict_path="data/data_dict.pth", device="cpu", val_size=76, test_path=None
    ):
        if dict_path is None:
            self.val_dict = {}
            self.train_dict = {}
        else:
            data_items = list(torch.load(dict_path).items())
            random.shuffle(data_items)

            self.val_dict = dict(data_items[:val_size])
            self.train_dict = dict(data_items[val_size:])

        if test_path is None:
            self.test_dict = {}
        else:
            self.test_dict = torch.load(test_path)

        self.device = device

    def get_test_data(self, key):
        text = self.test_dict[key]
        text_tensor = torch.zeros(len(text), 1, dtype=torch.long)
        text_tensor[:, 0] = torch.LongTensor([VOCAB.find(c) for c in text])

        return text_tensor.to(self.device)

    def get_train_data(self, batch_size=8):
        samples = random.sample(self.train_dict.keys(), batch_size)

        texts = [self.train_dict[k][0] for k in samples]
        labels = [self.train_dict[k][1] for k in samples]

        robust_padding(texts, labels)

        maxlen = max(len(t) for t in texts)

        text_tensor = torch.zeros(maxlen, batch_size, dtype=torch.long)
        for i, text in enumerate(texts):
            text_tensor[:, i] = torch.LongTensor([VOCAB.find(c) for c in text])

        truth_tensor = torch.zeros(maxlen, batch_size, dtype=torch.long)
        for i, label in enumerate(labels):
            truth_tensor[:, i] = torch.LongTensor(label)

        return text_tensor.to(self.device), truth_tensor.to(self.device)

    def get_val_data(self, batch_size=8, device="cpu"):
        keys = random.sample(self.val_dict.keys(), batch_size)

        texts = [self.val_dict[k][0] for k in keys]
        labels = [self.val_dict[k][1] for k in keys]

        maxlen = max(len(s) for s in texts)
        texts = [s.ljust(maxlen, " ") for s in texts]
        labels = [
            numpy.pad(a, (0, maxlen - len(a)), mode="constant", constant_values=0)
            for a in labels
        ]

        text_tensor = torch.zeros(maxlen, batch_size, dtype=torch.long)
        for i, text in enumerate(texts):
            text_tensor[:, i] = torch.LongTensor([VOCAB.find(c) for c in text])

        truth_tensor = torch.zeros(maxlen, batch_size, dtype=torch.long)
        for i, label in enumerate(labels):
            truth_tensor[:, i] = torch.LongTensor(label)

        return keys, text_tensor.to(self.device), truth_tensor.to(self.device)


def get_files(data_path="data/"):
    """Return the paired ``.json`` and ``.txt`` entries found in *data_path*.

    Args:
        data_path: Directory to scan (not recursive).

    Returns:
        ``(json_files, txt_files)``: two lists of ``os.DirEntry`` objects, each
        sorted by path, such that entries at the same index share the same
        path stem.

    Raises:
        AssertionError: If the two lists differ in length or a pair of entries
            does not share the same stem.
    """
    # One directory pass; the context manager releases the OS handle promptly.
    with os.scandir(data_path) as entries:
        by_path = sorted(entries, key=lambda entry: entry.path)

    json_files = [f for f in by_path if f.name.endswith(".json")]
    txt_files = [f for f in by_path if f.name.endswith(".txt")]

    # Raised explicitly (rather than via ``assert``) so the checks still run
    # under ``python -O`` while keeping the exception type unchanged.
    if len(json_files) != len(txt_files):
        raise AssertionError(
            f"{data_path!r} holds {len(json_files)} .json files "
            f"but {len(txt_files)} .txt files"
        )
    for f1, f2 in zip(json_files, txt_files):
        if path.splitext(f1)[0] != path.splitext(f2)[0]:
            raise AssertionError(f"unpaired files: {f1.path!r} vs {f2.path!r}")

    return json_files, txt_files


def sort_text(txt_file):
    """Rebuild the text of *txt_file* as newline-separated lines.

    Each line of the file is turned into a ``TextBox``; the boxes are ordered
    by their ``y`` attribute and merged greedily into ``TextLine`` objects.

    Args:
        txt_file: Path (or path-like object) of the file to read. It is opened
            with the platform default encoding.

    Returns:
        The ``str()`` of every resulting ``TextLine`` joined with ``"\\n"``,
        or ``""`` when the file contains no lines.
    """
    with open(txt_file, "r") as txt_opened:
        boxes = sorted((TextBox(line) for line in txt_opened), key=lambda box: box.y)

    # An empty file has no first box to seed the first line with.
    if not boxes:
        return ""

    text_lines = [TextLine(boxes[0])]
    for box in boxes[1:]:
        try:
            text_lines[-1].insert(box)
        except ValueError:
            # ``insert`` rejecting the box is taken as "start a new line".
            text_lines.append(TextLine(box))

    return "\n".join(str(text_line) for text_line in text_lines)


def create_test_data(
    image_dir="tmp/task3-test(347p)",
    text_dir="tmp/text.task1&2-test(361p)/",
    out_path="data/test_dict.pth",
):
    """Build the test dictionary and save it with ``torch.save``.

    The keys are the stems of the ``.jpg`` files in *image_dir*. For each key
    the file ``<key>.txt`` in *text_dir* is run through ``sort_text`` and the
    result is stored under that key.

    Args:
        image_dir: Directory whose ``.jpg`` file names define the keys.
        text_dir: Directory holding one ``<key>.txt`` file per key.
        out_path: Destination file for the resulting ``key -> text`` dict.

    Raises:
        FileNotFoundError: If a directory or an expected ``.txt`` file is missing.
    """
    with os.scandir(image_dir) as entries:
        keys = sorted(
            path.splitext(f.name)[0] for f in entries if f.name.endswith(".jpg")
        )

    test_dict = {k: sort_text(path.join(text_dir, k + ".txt")) for k in keys}

    torch.save(test_dict, out_path)


def _find_total_span(value, text):
    """Return the ``(start, end)`` span of the total amount in *text*, or ``None``."""
    # Escape the value so characters such as "." or "$" match literally.
    literal = regex.escape(value)
    # Tried from most to least specific; group 2 always holds the value.
    # ``[^C]*`` lets the gap run over anything except the letter "C".
    patterns = (
        r"(\bTOTAL[^C]*ROUND[^C]*)(" + literal + r")(\b)",
        r"(\bTOTAL[^C]*)(" + literal + r")(\b)",
        r"(\b)(" + literal + r")(\b)",
        r"()(" + literal + r")()",
    )
    for pattern in patterns:
        found = regex.search(pattern, text)
        if found is not None:
            return found.span(2)
    return None


def _find_field_span(value, text, max_errors=3):
    """Return the ``(start, end)`` span of *value* in *text*, or ``None``.

    An exact occurrence wins; otherwise a fuzzy whole-word search is tried
    with 1, 2, ... *max_errors* allowed errors.
    """
    start = text.find(value)
    if start != -1:
        return start, start + len(value)

    literal = regex.escape(value)
    for errors in range(1, max_errors + 1):
        found = regex.search(r"(\b" + literal + r"\b){e<=" + str(errors) + r"}", text)
        if found is not None:
            # Use the span of the fuzzy match itself rather than searching for
            # the matched text again, which could hit an earlier occurrence.
            return found.span()
    return None


def create_data(data_path="tmp/data/"):
    """Build per-character class labels for every receipt in *data_path*.

    For each ``.json``/``.txt`` pair returned by ``get_files`` the text is
    rebuilt with ``sort_text`` and every value of the JSON object is located
    in it. The characters covered by the value of the ``i``-th JSON key get
    class ``i + 1``; all other characters keep class 0.

    NOTE(review): the class index follows the key order inside each JSON
    file, so the files presumably all list their keys in the same order --
    confirm against the data.

    Args:
        data_path: Directory containing the paired ``.json`` and ``.txt`` files.

    Returns:
        ``(keys, data_dict)`` where ``keys`` are the JSON file stems and
        ``data_dict[key]`` is ``(text, text_class)`` with ``text_class`` a
        NumPy integer array of ``len(text)``.

    Raises:
        ValueError: If a JSON value cannot be located in its text.
    """
    json_files, txt_files = get_files(data_path)
    keys = [path.splitext(f.name)[0] for f in json_files]

    data_dict = {}

    for key, json_file, txt_file in zip(keys, json_files, txt_files):
        with open(json_file, "r", encoding="utf-8") as json_opened:
            key_info = json.load(json_opened)

        text = sort_text(txt_file)
        # Same length as ``text``, so spans found here index ``text`` as well.
        text_space = regex.sub(r"[\t\n]", " ", text)

        text_class = numpy.zeros(len(text), dtype=int)

        print()
        print(json_file.path, txt_file.path)
        for i, (k, v) in enumerate(key_info.items()):
            if k == "total":
                span = _find_total_span(v, text_space)
            else:
                span = _find_field_span(v, text_space)

            if span is None:
                raise ValueError(
                    f"cannot locate {k!r} value {v!r} in {txt_file.path}"
                )

            start, end = span
            text_class[start:end] = i + 1

        data_dict[key] = (text, text_class)

    return keys, data_dict


def color_print(text, text_class):
    """Print *text* one character at a time, coloured by its class label.

    Labels 1, 2, 3 and 4 are shown in red, green, blue and yellow; any other
    label is shown in white. The colour is reset after the last character and
    an empty line is printed.

    Args:
        text: The characters to print.
        text_class: One label per character, paired with *text* via ``zip``.
    """
    colorama.init()

    # Checked in order with ``==``, so the first matching label wins.
    palette = (
        (1, Fore.RED),
        (2, Fore.GREEN),
        (3, Fore.BLUE),
        (4, Fore.YELLOW),
    )

    for char, label in zip(text, text_class):
        colour = next((tint for cls, tint in palette if label == cls), Fore.WHITE)
        print(colour + char, end="")

    print(Fore.RESET)
    print()


if __name__ == "__main__":
    # Running this module as a script builds the test dictionary.
    create_test_data()

    # ------------------------------------------------------------------
    # Disabled scratch snippets kept for manual experiments.
    # ------------------------------------------------------------------

    # Draw one training batch from a saved dictionary and inspect it:
    # dataset = MyDataset("data/data_dict2.pth")
    # text, truth = dataset.get_train_data()
    # print(text)
    # print(truth)

    # Colour-print every labelled text of a saved dictionary:
    # dict3 = torch.load("data/data_dict3.pth")
    # for k in dict3.keys():
    #     text, text_class = dict3[k]
    #     color_print(text, text_class)

    # Rebuild the labelled dictionary and save it:
    # keys, data_dict = create_data()
    # torch.save(data_dict, "data/data_dict4.pth")

    # Try a TOTAL pattern on a sample string and show each group's span:
    # s = "START 0 TOTAL:1.00, START TOTAL: 1.00 END"
    # rs = regex.search(r"(\sTOTAL.*)(1.00)(\s)", s)
    # for i in range(len(rs)):
    #     print(repr(rs[i]), rs.span(i))