import os
import re
import math
import sys
import shutil
import json
import traceback
import PIL.Image as PilImage
import threading
import tkinter as tk
from tkinter import messagebox
from tkinter import ttk
from tkinter import filedialog
from constants import *
from config import ModelConfig, OUTPUT_SHAPE1_MAP, NETWORK_MAP, DataAugmentationEntity, PretreatmentEntity, CORE_VERSION
from make_dataset import DataSets
from predict_testing import Predict
from trains import Trains
from category import category_extract, SIMPLE_CATEGORY_MODEL
from gui.utils import LayoutGUI
from gui.data_augmentation import DataAugmentationDialog
from gui.pretreatment import PretreatmentDialog

NOT_EDITABLE_MSG = "ONLY SUPPORT MODIFICATION FROM FILE"


class Wizard:

    job: threading.Thread
    current_task: Trains = None
    is_task_running: bool = False
    data_augmentation_entity = DataAugmentationEntity()
    pretreatment_entity = PretreatmentEntity()
    extract_regex = ".*?(?=_)"
    label_split = ""
    model_conf: ModelConfig = None

    def __init__(self, parent: tk.Tk):
        self.layout = {
            'global': {
                'start': {'x': 15, 'y': 20},
                'space': {'x': 15, 'y': 25},
                'tiny_space': {'x': 5, 'y': 10}
            }
        }
        self.parent = parent
        self.parent.iconbitmap(Wizard.resource_path("resource/icon.ico"))
        self.current_project: str = ""
        self.project_root_path = "./projects"
        if not os.path.exists(self.project_root_path):
            os.makedirs(self.project_root_path)
        self.parent.title('Image Classification Wizard Tool based on Deep Learning')
        self.parent.resizable(width=False, height=False)
        self.window_width = 815
        self.window_height = 700
        self.layout_utils = LayoutGUI(self.layout, self.window_width)
        screenwidth = self.parent.winfo_screenwidth()
        screenheight = self.parent.winfo_screenheight()
        size = '%dx%d+%d+%d' % (
            self.window_width,
            self.window_height,
            (screenwidth - self.window_width) / 2,
            (screenheight - self.window_height) / 2
        )

        self.parent.bind('<Button-1>', lambda x: self.blank_click(x))

        # ============================= Menu 1 =====================================
        self.menubar = tk.Menu(self.parent)
        self.data_menu = tk.Menu(self.menubar, tearoff=False)
        self.help_menu = tk.Menu(self.menubar, tearoff=False)
        self.system_menu = tk.Menu(self.menubar, tearoff=False)
        self.edit_var = tk.DoubleVar()
        self.label_from_var = tk.StringVar()

        self.memory_usage_menu = tk.Menu(self.menubar, tearoff=False)
        self.memory_usage_menu.add_radiobutton(label="50%", variable=self.edit_var, value=0.5)
        self.memory_usage_menu.add_radiobutton(label="60%", variable=self.edit_var, value=0.6)
        self.memory_usage_menu.add_radiobutton(label="70%", variable=self.edit_var, value=0.7)
        self.memory_usage_menu.add_radiobutton(label="80%", variable=self.edit_var, value=0.8)

        self.label_from_menu = tk.Menu(self.menubar, tearoff=False)
        self.label_from_menu.add_radiobutton(label="FileName", variable=self.label_from_var, value='FileName')
        self.label_from_menu.add_radiobutton(label="TXT", variable=self.label_from_var, value='TXT')

        self.menubar.add_cascade(label="System", menu=self.system_menu)
        self.system_menu.add_cascade(label="Memory Usage", menu=self.memory_usage_menu)

        self.data_menu.add_command(label="Data Augmentation", command=lambda: self.popup_data_augmentation())
        self.data_menu.add_command(label="Pretreatment", command=lambda: self.popup_pretreatment())
        self.data_menu.add_separator()
        self.data_menu.add_command(label="Clear Dataset", command=lambda: self.clear_dataset())
        self.data_menu.add_separator()
        self.data_menu.add_cascade(label="Label From", menu=self.label_from_menu)
        self.menubar.add_cascade(label="Data", menu=self.data_menu)

        self.help_menu.add_command(label="About", command=lambda: self.popup_about())
        self.menubar.add_cascade(label="Help", menu=self.help_menu)

        self.parent.config(menu=self.menubar)

        # ============================= Group 1 =====================================
        self.label_frame_source = ttk.Labelframe(self.parent, text='Sample Source')
        self.label_frame_source.place(
            x=self.layout['global']['start']['x'],
            y=self.layout['global']['start']['y'],
            width=790,
            height=150
        )

        # 训练集源路径 - 标签
        self.dataset_train_path_text = ttk.Label(self.parent, text='Training Path', anchor=tk.W)
        self.layout_utils.inside_widget(
            src=self.dataset_train_path_text,
            target=self.label_frame_source,
            width=90,
            height=20
        )

        # 训练集源路径 - 输入控件
        self.source_train_path_listbox = tk.Listbox(self.parent, font=('微软雅黑', 9))
        self.layout_utils.next_to_widget(
            src=self.source_train_path_listbox,
            target=self.dataset_train_path_text,
            width=600,
            height=50,
            tiny_space=True
        )
        self.source_train_path_listbox.bind(
            sequence="<Delete>",
            func=lambda x: self.listbox_delete_item_callback(x, self.source_train_path_listbox)
        )
        self.listbox_scrollbar(self.source_train_path_listbox)

        # 训练集源路径 - 按钮
        self.btn_browse_train = ttk.Button(
            self.parent, text='Browse', command=lambda: self.browse_dataset(DatasetType.Directory, RunMode.Trains)
        )
        self.layout_utils.next_to_widget(
            src=self.btn_browse_train,
            target=self.source_train_path_listbox,
            width=60,
            height=24,
            tiny_space=True
        )

        # 验证集源路径 - 标签
        label_edge = self.layout_utils.object_edge_info(self.dataset_train_path_text)
        widget_edge = self.layout_utils.object_edge_info(self.source_train_path_listbox)
        self.dataset_validation_path_text = ttk.Label(self.parent, text='Validation Path', anchor=tk.W)
        self.dataset_validation_path_text.place(
            x=label_edge['x'],
            y=widget_edge['edge_y'] + self.layout['global']['space']['y'] / 2,
            width=90,
            height=20
        )

        # 验证集源路径 - 输入控件
        self.source_validation_path_listbox = tk.Listbox(self.parent, font=('微软雅黑', 9))
        self.layout_utils.next_to_widget(
            src=self.source_validation_path_listbox,
            target=self.dataset_validation_path_text,
            width=600,
            height=50,
            tiny_space=True
        )
        self.source_validation_path_listbox.bind(
            sequence="<Delete>",
            func=lambda x: self.listbox_delete_item_callback(x, self.source_validation_path_listbox)
        )
        self.listbox_scrollbar(self.source_validation_path_listbox)

        # 训练集源路径 - 按钮
        self.btn_browse_validation = ttk.Button(
            self.parent, text='Browse', command=lambda: self.browse_dataset(DatasetType.Directory, RunMode.Validation)
        )
        self.layout_utils.next_to_widget(
            src=self.btn_browse_validation,
            target=self.source_validation_path_listbox,
            width=60,
            height=24,
            tiny_space=True
        )

        # ============================= Group 2 =====================================
        self.label_frame_neu = ttk.Labelframe(self.parent, text='Neural Network')
        self.layout_utils.below_widget(
            src=self.label_frame_neu,
            target=self.label_frame_source,
            width=790,
            height=120,
            tiny_space=False
        )

        # 最大标签数目 - 标签
        self.label_num_text = ttk.Label(self.parent, text='Label Num', anchor=tk.W)
        self.layout_utils.inside_widget(
            src=self.label_num_text,
            target=self.label_frame_neu,
            width=65,
            height=20,
        )

        # 最大标签数目 - 滚动框
        self.label_num_spin = ttk.Spinbox(self.parent, from_=1, to=12)
        self.label_num_spin.set(1)
        self.layout_utils.next_to_widget(
            src=self.label_num_spin,
            target=self.label_num_text,
            width=50,
            height=20,
            tiny_space=True
        )

        # 图像通道 - 标签
        self.channel_text = ttk.Label(self.parent, text='Channel', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.channel_text,
            target=self.label_num_spin,
            width=50,
            height=20,
            tiny_space=False
        )

        # 图像通道 - 下拉框
        self.comb_channel = ttk.Combobox(self.parent, values=(3, 1), state='readonly')
        self.comb_channel.current(1)
        self.layout_utils.next_to_widget(
            src=self.comb_channel,
            target=self.channel_text,
            width=38,
            height=20,
            tiny_space=True
        )

        # 卷积层 - 标签
        self.neu_cnn_text = ttk.Label(self.parent, text='CNN Layer', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.neu_cnn_text,
            target=self.comb_channel,
            width=65,
            height=20,
            tiny_space=False
        )

        # 卷积层 - 下拉框
        self.comb_neu_cnn = ttk.Combobox(self.parent, values=[_.name for _ in CNNNetwork], state='readonly')
        self.comb_neu_cnn.current(0)
        self.layout_utils.next_to_widget(
            src=self.comb_neu_cnn,
            target=self.neu_cnn_text,
            width=80,
            height=20,
            tiny_space=True
        )

        # 循环层 - 标签
        self.neu_recurrent_text = ttk.Label(self.parent, text='Recurrent Layer', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.neu_recurrent_text,
            target=self.comb_neu_cnn,
            width=95,
            height=20,
            tiny_space=False
        )

        # 循环层 - 下拉框
        self.comb_recurrent = ttk.Combobox(self.parent, values=[_.name for _ in RecurrentNetwork], state='readonly')
        self.comb_recurrent.current(1)
        self.layout_utils.next_to_widget(
            src=self.comb_recurrent,
            target=self.neu_recurrent_text,
            width=112,
            height=20,
            tiny_space=True
        )
        self.comb_recurrent.bind("<<ComboboxSelected>>", lambda x: self.auto_loss(x))

        # 循环层单元数 - 标签
        self.units_num_text = ttk.Label(self.parent, text='UnitsNum', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.units_num_text,
            target=self.comb_recurrent,
            width=60,
            height=20,
            tiny_space=False
        )

        # 循环层单元数 - 下拉框
        self.units_num_spin = ttk.Spinbox(self.parent, from_=16, to=512, increment=16, wrap=True)
        self.units_num_spin.set(64)
        self.layout_utils.next_to_widget(
            src=self.units_num_spin,
            target=self.units_num_text,
            width=55,
            height=20,
            tiny_space=True
        )

        # 损失函数 - 标签
        self.loss_func_text = ttk.Label(self.parent, text='Loss Function', anchor=tk.W)
        self.layout_utils.below_widget(
            src=self.loss_func_text,
            target=self.label_num_text,
            width=85,
            height=20,
            tiny_space=True
        )

        # 损失函数 - 下拉框
        self.comb_loss = ttk.Combobox(self.parent, values=[_.name for _ in LossFunction], state='readonly')
        self.comb_loss.current(1)
        self.layout_utils.next_to_widget(
            src=self.comb_loss,
            target=self.loss_func_text,
            width=101,
            height=20,
            tiny_space=True
        )

        # 优化器 - 标签
        self.optimizer_text = ttk.Label(self.parent, text='Optimizer', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.optimizer_text,
            target=self.comb_loss,
            width=60,
            height=20,
            tiny_space=False
        )

        # 优化器 - 下拉框
        self.comb_optimizer = ttk.Combobox(self.parent, values=[_.name for _ in Optimizer], state='readonly')
        self.comb_optimizer.current(0)
        self.layout_utils.next_to_widget(
            src=self.comb_optimizer,
            target=self.optimizer_text,
            width=88,
            height=20,
            tiny_space=True
        )

        # 学习率 - 标签
        self.learning_rate_text = ttk.Label(self.parent, text='Learning Rate', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.learning_rate_text,
            target=self.comb_optimizer,
            width=85,
            height=20,
            tiny_space=False
        )

        # 学习率 - 滚动框
        self.learning_rate_spin = ttk.Spinbox(self.parent, from_=0.00001, to=0.1, increment='0.0001')
        self.learning_rate_spin.set(0.001)
        self.layout_utils.next_to_widget(
            src=self.learning_rate_spin,
            target=self.learning_rate_text,
            width=67,
            height=20,
            tiny_space=True
        )

        # Resize - 标签
        self.resize_text = ttk.Label(self.parent, text='Resize', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.resize_text,
            target=self.learning_rate_spin,
            width=36,
            height=20,
            tiny_space=False
        )

        # Resize - 输入框
        self.resize_val = tk.StringVar()
        self.resize_val.set('[150, 50]')
        self.resize_entry = ttk.Entry(self.parent, textvariable=self.resize_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.resize_entry,
            target=self.resize_text,
            width=60,
            height=20,
            tiny_space=True
        )

        # Size - 标签
        self.size_text = ttk.Label(self.parent, text='Size', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.size_text,
            target=self.resize_entry,
            width=30,
            height=20,
            tiny_space=False
        )

        # Size - 输入框
        self.size_val = tk.StringVar()
        self.size_val.set('[-1, -1]')
        self.size_entry = ttk.Entry(self.parent, textvariable=self.size_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.size_entry,
            target=self.size_text,
            width=60,
            height=20,
            tiny_space=True
        )

        # 类别 - 标签
        self.category_text = ttk.Label(self.parent, text='Category', anchor=tk.W)
        self.layout_utils.below_widget(
            src=self.category_text,
            target=self.loss_func_text,
            width=72,
            height=20,
            tiny_space=True
        )

        # 类别 - 下拉框
        self.comb_category = ttk.Combobox(self.parent, values=(
            'CUSTOMIZED',
            'NUMERIC',
            'ALPHANUMERIC',
            'ALPHANUMERIC_LOWER',
            'ALPHANUMERIC_UPPER',
            'ALPHABET_LOWER',
            'ALPHABET_UPPER',
            'ALPHABET',
            'ARITHMETIC',
            'FLOAT',
            'CHS_3500',
            'ALPHANUMERIC_CHS_3500_LOWER',
            'DOCUMENT_OCR'
        ), state='readonly')
        self.comb_category.current(1)
        self.comb_category.bind("<<ComboboxSelected>>", lambda x: self.comb_category_callback(x))
        self.layout_utils.next_to_widget(
            src=self.comb_category,
            target=self.category_text,
            width=225,
            height=20,
            tiny_space=True
        )

        # 类别 - 自定义输入框
        self.category_val = tk.StringVar()
        self.category_val.set('')
        self.category_entry = ttk.Entry(self.parent, textvariable=self.category_val, justify=tk.LEFT, state=tk.DISABLED)
        self.layout_utils.next_to_widget(
            src=self.category_entry,
            target=self.comb_category,
            width=440,
            height=20,
            tiny_space=False
        )

        # ============================= Group 3 =====================================
        self.label_frame_train = ttk.Labelframe(self.parent, text='Training Configuration')
        self.layout_utils.below_widget(
            src=self.label_frame_train,
            target=self.label_frame_neu,
            width=790,
            height=60,
            tiny_space=True
        )

        # 任务完成标准 - 准确率 - 标签
        self.end_acc_text = ttk.Label(self.parent, text='End Accuracy', anchor=tk.W)
        self.layout_utils.inside_widget(
            src=self.end_acc_text,
            target=self.label_frame_train,
            width=85,
            height=20,
        )

        # 任务完成标准 - 准确率 - 输入框
        self.end_acc_val = tk.DoubleVar()
        self.end_acc_val.set(0.95)
        self.end_acc_entry = ttk.Entry(self.parent, textvariable=self.end_acc_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.end_acc_entry,
            target=self.end_acc_text,
            width=56,
            height=20,
            tiny_space=True
        )

        # 任务完成标准 - 平均损失 - 标签
        self.end_cost_text = ttk.Label(self.parent, text='End Cost', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.end_cost_text,
            target=self.end_acc_entry,
            width=60,
            height=20,
            tiny_space=False
        )

        # 任务完成标准 - 平均损失 - 输入框
        self.end_cost_val = tk.DoubleVar()
        self.end_cost_val.set(0.5)
        self.end_cost_entry = ttk.Entry(self.parent, textvariable=self.end_cost_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.end_cost_entry,
            target=self.end_cost_text,
            width=58,
            height=20,
            tiny_space=True
        )

        # 任务完成标准 - 循环轮次 - 标签
        self.end_epochs_text = ttk.Label(self.parent, text='End Epochs', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.end_epochs_text,
            target=self.end_cost_entry,
            width=72,
            height=20,
            tiny_space=False
        )

        # 任务完成标准 - 循环轮次 - 输入框
        self.end_epochs_spin = ttk.Spinbox(self.parent, from_=0, to=10000)
        self.end_epochs_spin.set(2)
        self.layout_utils.next_to_widget(
            src=self.end_epochs_spin,
            target=self.end_epochs_text,
            width=50,
            height=20,
            tiny_space=True
        )

        # 训练批次大小 - 标签
        self.batch_size_text = ttk.Label(self.parent, text='Train BatchSize', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.batch_size_text,
            target=self.end_epochs_spin,
            width=90,
            height=20,
            tiny_space=False
        )

        # 训练批次大小 - 输入框
        self.batch_size_val = tk.IntVar()
        self.batch_size_val.set(64)
        self.batch_size_entry = ttk.Entry(self.parent, textvariable=self.batch_size_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.batch_size_entry,
            target=self.batch_size_text,
            width=40,
            height=20,
            tiny_space=True
        )

        # 验证批次大小 - 标签
        self.validation_batch_size_text = ttk.Label(self.parent, text='Validation BatchSize', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.validation_batch_size_text,
            target=self.batch_size_entry,
            width=120,
            height=20,
            tiny_space=False
        )

        # 验证批次大小 - 输入框
        self.validation_batch_size_val = tk.IntVar()
        self.validation_batch_size_val.set(300)
        self.validation_batch_size_entry = ttk.Entry(self.parent, textvariable=self.validation_batch_size_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.validation_batch_size_entry,
            target=self.validation_batch_size_text,
            width=40,
            height=20,
            tiny_space=True
        )

        # ============================= Group 5 =====================================
        self.label_frame_project = ttk.Labelframe(self.parent, text='Project Configuration')
        self.layout_utils.below_widget(
            src=self.label_frame_project,
            target=self.label_frame_train,
            width=790,
            height=60,
            tiny_space=True
        )

        # 项目名 - 标签
        self.project_name_text = ttk.Label(self.parent, text='Project Name', anchor=tk.W)
        self.layout_utils.inside_widget(
            src=self.project_name_text,
            target=self.label_frame_project,
            width=90,
            height=20
        )

        # 项目名 - 下拉输入框
        self.comb_project_name = ttk.Combobox(self.parent)
        self.layout_utils.next_to_widget(
            src=self.comb_project_name,
            target=self.project_name_text,
            width=430,
            height=20,
            tiny_space=True
        )
        self.comb_project_name.bind(
            sequence="<Return>",
            func=lambda x: self.project_name_fill_callback(x)
        )
        self.comb_project_name.bind(
            sequence="<Button-1>",
            func=lambda x: self.fetch_projects()
        )

        def read_conf(event):
            threading.Thread(target=self.read_conf).start()

        self.comb_project_name.bind("<<ComboboxSelected>>", read_conf)

        # 保存配置 - 按钮
        self.btn_save_conf = ttk.Button(
            self.parent, text='Save Configuration', command=lambda: self.save_conf()
        )
        self.layout_utils.next_to_widget(
            src=self.btn_save_conf,
            target=self.comb_project_name,
            width=130,
            height=24,
            tiny_space=False,
            offset_y=-2
        )

        # 删除项目 - 按钮
        self.btn_delete = ttk.Button(
            self.parent, text='Delete', command=lambda: self.delete_project()
        )
        self.layout_utils.next_to_widget(
            src=self.btn_delete,
            target=self.btn_save_conf,
            width=80,
            height=24,
            tiny_space=False,
        )

        # ============================= Group 6 =====================================
        self.label_frame_dataset = ttk.Labelframe(
            self.parent, text='Sample Dataset'
        )
        self.layout_utils.below_widget(
            src=self.label_frame_dataset,
            target=self.label_frame_project,
            width=790,
            height=170,
            tiny_space=True
        )

        # 附加训练集 - 按钮
        self.btn_attach_dataset = ttk.Button(
            self.parent,
            text='Attach Dataset',
            command=lambda: self.attach_dataset()
        )
        self.layout_utils.inside_widget(
            src=self.btn_attach_dataset,
            target=self.label_frame_dataset,
            width=120,
            height=24,
        )

        # 附加训练集 - 显示框
        self.attach_dataset_val = tk.StringVar()
        self.attach_dataset_val.set('')
        self.attach_dataset_entry = ttk.Entry(
            self.parent, textvariable=self.attach_dataset_val, justify=tk.LEFT, state=tk.DISABLED
        )
        self.layout_utils.next_to_widget(
            src=self.attach_dataset_entry,
            target=self.btn_attach_dataset,
            width=420,
            height=24,
            tiny_space=True
        )

        # 验证集数目 - 标签
        self.validation_num_text = ttk.Label(self.parent, text='Validation Set Num', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.validation_num_text,
            target=self.attach_dataset_entry,
            width=120,
            height=20,
            tiny_space=False,
            offset_y=2
        )

        # 验证集数目 - 输入框
        self.validation_num_val = tk.IntVar()
        self.validation_num_val.set(300)
        self.validation_num_entry = ttk.Entry(self.parent, textvariable=self.validation_num_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.validation_num_entry,
            target=self.validation_num_text,
            width=71,
            height=20,
            tiny_space=True
        )

        # 训练集路径 - 标签
        self.dataset_train_path_text = ttk.Label(self.parent, text='Training Dataset', anchor=tk.W)
        self.layout_utils.below_widget(
            src=self.dataset_train_path_text,
            target=self.btn_attach_dataset,
            width=100,
            height=20,
            tiny_space=False
        )

        # 训练集路径 - 列表框
        self.dataset_train_listbox = tk.Listbox(self.parent, font=('微软雅黑', 9))
        self.layout_utils.next_to_widget(
            src=self.dataset_train_listbox,
            target=self.dataset_train_path_text,
            width=640,
            height=36,
            tiny_space=False
        )
        self.dataset_train_listbox.bind(
            sequence="<Delete>",
            func=lambda x: self.listbox_delete_item_callback(x, self.dataset_train_listbox)
        )
        self.listbox_scrollbar(self.dataset_train_listbox)

        # 验证集路径 - 标签
        label_edge = self.layout_utils.object_edge_info(self.dataset_train_path_text)
        widget_edge = self.layout_utils.object_edge_info(self.dataset_train_listbox)
        self.dataset_validation_path_text = ttk.Label(self.parent, text='Validation Dataset', anchor=tk.W)
        self.dataset_validation_path_text.place(
            x=label_edge['x'],
            y=widget_edge['edge_y'] + self.layout['global']['space']['y'] / 2,
            width=100,
            height=20
        )

        # 验证集路径 - 下拉输入框
        self.dataset_validation_listbox = tk.Listbox(self.parent, font=('微软雅黑', 9))
        self.layout_utils.next_to_widget(
            src=self.dataset_validation_listbox,
            target=self.dataset_validation_path_text,
            width=640,
            height=36,
            tiny_space=False
        )
        self.dataset_validation_listbox.bind(
            sequence="<Delete>",
            func=lambda x: self.listbox_delete_item_callback(x, self.dataset_validation_listbox)
        )
        self.listbox_scrollbar(self.dataset_validation_listbox)

        self.sample_map = {
            DatasetType.Directory: {
                RunMode.Trains: self.source_train_path_listbox,
                RunMode.Validation: self.source_validation_path_listbox
            },
            DatasetType.TFRecords: {
                RunMode.Trains: self.dataset_train_listbox,
                RunMode.Validation: self.dataset_validation_listbox
            }
        }

        # 开始训练 - 按钮
        self.btn_training = ttk.Button(self.parent, text='Start Training', command=lambda: self.start_training())
        self.layout_utils.widget_from_right(
            src=self.btn_training,
            target=self.label_frame_dataset,
            width=120,
            height=24,
            tiny_space=True
        )

        # 终止训练 - 按钮
        self.btn_stop = ttk.Button(self.parent, text='Stop', command=lambda: self.stop_training())
        self.button_state(self.btn_stop, tk.DISABLED)
        self.layout_utils.before_widget(
            src=self.btn_stop,
            target=self.btn_training,
            width=60,
            height=24,
            tiny_space=True
        )

        # 编译模型 - 按钮
        self.btn_compile = ttk.Button(self.parent, text='Compile', command=lambda: self.compile())
        self.layout_utils.before_widget(
            src=self.btn_compile,
            target=self.btn_stop,
            width=80,
            height=24,
            tiny_space=True
        )

        # 打包训练集 - 按钮
        self.btn_make_dataset = ttk.Button(self.parent, text='Make Dataset', command=lambda: self.make_dataset())
        self.layout_utils.before_widget(
            src=self.btn_make_dataset,
            target=self.btn_compile,
            width=120,
            height=24,
            tiny_space=True
        )

        # 清除训练记录 - 按钮
        self.btn_reset_history = ttk.Button(
            self.parent, text='Reset History', command=lambda: self.reset_history()
        )
        self.layout_utils.before_widget(
            src=self.btn_reset_history,
            target=self.btn_make_dataset,
            width=120,
            height=24,
            tiny_space=True
        )

        # 预测 - 按钮
        self.btn_testing = ttk.Button(
            self.parent, text='Testing', command=lambda: self.testing_model()
        )
        self.layout_utils.before_widget(
            src=self.btn_testing,
            target=self.btn_reset_history,
            width=80,
            height=24,
            tiny_space=True
        )

        self.parent.geometry(size)

    @staticmethod
    def threading_exec(func, *args) -> threading.Thread:
        th = threading.Thread(target=func, args=args)
        th.setDaemon(True)
        th.start()
        return th

    def popup_data_augmentation(self):
        if not self.current_project:
            messagebox.showerror(
                "Error!", "Please set the project name first."
            )
            return
        data_augmentation = DataAugmentationDialog()
        data_augmentation.read_conf(self.data_augmentation_entity)

    def popup_pretreatment(self):
        if not self.current_project:
            messagebox.showerror(
                "Error!", "Please set the project name first."
            )
            return
        pretreatment = PretreatmentDialog()
        pretreatment.read_conf(self.pretreatment_entity)

    @staticmethod
    def listbox_scrollbar(listbox: tk.Listbox):
        y_scrollbar = tk.Scrollbar(
            listbox, command=listbox.yview
        )
        y_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        listbox.config(yscrollcommand=y_scrollbar.set)

    def blank_click(self, event):
        if self.current_project != self.comb_project_name.get():
            self.project_name_fill_callback(event)

    def project_name_fill_callback(self, event):
        suffix = '-{}-{}-H{}-{}-C{}'.format(
            self.comb_neu_cnn.get(),
            self.comb_recurrent.get(),
            self.units_num_spin.get(),
            self.comb_loss.get(),
            self.comb_channel.get(),
        )
        current_project_name = self.comb_project_name.get()
        if len(current_project_name) > 0 and current_project_name not in self.project_names:
            self.extract_regex = ".*?(?=_)"
            self.label_from_var.set('FileName')
            self.sample_map[DatasetType.Directory][RunMode.Trains].delete(0, tk.END)
            self.sample_map[DatasetType.Directory][RunMode.Validation].delete(0, tk.END)
            self.category_val.set("")
            if not current_project_name.endswith(suffix):
                self.comb_project_name.insert(tk.END, suffix)
            self.current_project = self.comb_project_name.get()
            self.update_dataset_files_path(mode=RunMode.Trains)
            self.update_dataset_files_path(mode=RunMode.Validation)
            self.data_augmentation_entity = DataAugmentationEntity()
            self.pretreatment_entity = PretreatmentEntity()

    @property
    def project_path(self):
        if not self.current_project:
            return None
        project_path = "{}/{}".format(self.project_root_path, self.current_project)
        if not os.path.exists(project_path):
            os.makedirs(project_path)
        return project_path

    def update_dataset_files_path(self, mode: RunMode):
        dataset_name = "dataset/{}.0.tfrecords".format(mode.value)
        dataset_path = os.path.join(self.project_path, dataset_name)
        dataset_path = dataset_path.replace("\\", '/')
        self.sample_map[DatasetType.TFRecords][mode].delete(0, tk.END)
        self.sample_map[DatasetType.TFRecords][mode].insert(tk.END, dataset_path)
        self.save_conf()

    def attach_dataset(self):
        if self.is_task_running:
            messagebox.showerror(
                "Error!", "Please terminate the current training first or wait for the training to end."
            )
            return
        if not self.current_project:
            messagebox.showerror(
                "Error!", "Please set the project name first."
            )
            return
        filename = filedialog.askdirectory()
        if not filename:
            return
        model_conf = ModelConfig(self.current_project)

        if not self.check_dataset(model_conf):
            return

        self.attach_dataset_val.set(filename)
        self.sample_map[DatasetType.Directory][RunMode.Trains].insert(tk.END, filename)
        self.button_state(self.btn_attach_dataset, tk.DISABLED)

        for mode in [RunMode.Trains, RunMode.Validation]:
            attached_dataset_name = model_conf.dataset_increasing_name(mode)
            attached_dataset_name = "dataset/{}".format(attached_dataset_name)
            attached_dataset_path = os.path.join(self.project_path, attached_dataset_name)
            attached_dataset_path = attached_dataset_path.replace("\\", '/')
            if mode == RunMode.Validation and self.validation_num_val.get() == 0:
                continue
            self.sample_map[DatasetType.TFRecords][mode].insert(tk.END, attached_dataset_path)
        self.save_conf()
        model_conf = ModelConfig(self.current_project)
        self.threading_exec(
            lambda: DataSets(model_conf).make_dataset(
                trains_path=filename,
                is_add=True,
                callback=lambda: self.button_state(self.btn_attach_dataset, tk.NORMAL),
                msg=lambda x: tk.messagebox.showinfo('Attach Dataset Status', x)
            )
        )
        pass

    @staticmethod
    def button_state(btn: ttk.Button, state: str):
        btn['state'] = state

    def delete_project(self):
        if not self.current_project:
            messagebox.showerror(
                "Error!", "Please select a project to delete."
            )
            return
        if self.is_task_running:
            messagebox.showerror(
                "Error!", "Please terminate the current training first or wait for the training to end."
            )
            return
        project_path = "./projects/{}".format(self.current_project)
        try:
            shutil.rmtree(project_path)
        except Exception as e:
            messagebox.showerror(
                "Error!", json.dumps(e.args, ensure_ascii=False)
            )
        messagebox.showinfo(
            "Error!", "Delete successful!"
        )
        self.comb_project_name.delete(0, tk.END)

    def reset_history(self):
        if not self.current_project:
            messagebox.showerror(
                "Error!", "Please select a project first."
            )
            return
        if self.is_task_running:
            messagebox.showerror(
                "Error!", "Please terminate the current training first or wait for the training to end."
            )
            return
        project_history_path = "./projects/{}/model".format(self.current_project)
        try:
            shutil.rmtree(project_history_path)
        except Exception as e:
            messagebox.showerror(
                "Error!", json.dumps(e.args, ensure_ascii=False)
            )
        messagebox.showinfo(
            "Error!", "Delete history successful!"
        )

    def testing_model(self):
        filename = filedialog.askdirectory()
        if not filename:
            return
        filename = filename.replace("\\", "/")
        predict = Predict(project_name=self.current_project)
        predict.testing(image_dir=filename, limit=self.validation_batch_size)

    def clear_dataset(self):
        if not self.current_project:
            messagebox.showerror(
                "Error!", "Please select a project first."
            )
            return
        if self.is_task_running:
            messagebox.showerror(
                "Error!", "Please terminate the current training first or wait for the training to end."
            )
            return
        project_history_path = "./projects/{}/dataset".format(self.current_project)
        try:
            shutil.rmtree(project_history_path)
            self.dataset_train_listbox.delete(1, tk.END)
            self.dataset_validation_listbox.delete(1, tk.END)
        except Exception as e:
            messagebox.showerror(
                "Error!", json.dumps(e.args, ensure_ascii=False)
            )
        messagebox.showinfo(
            "Error!", "Clear dataset successful!"
        )

    @staticmethod
    def popup_about():
        messagebox.showinfo("About", "Image Classification Wizard Tool based on Deep Learning 1.0 CORE_VERSION({})\n\nAuthor's mailbox: kerlomz@gmail.com\n\nQQ Group: 857149419".format(CORE_VERSION))

    def auto_loss(self, event):
        if self.comb_recurrent.get() == 'NoRecurrent':
            self.comb_loss.set("CrossEntropy")

    @staticmethod
    def get_param(src: dict, key, default=None):
        result = src.get(key)
        return result if result else default

    def read_conf(self):
        print('Reading configuration...')
        selected = self.comb_project_name.get()
        self.current_project = selected
        model_conf = ModelConfig(selected)
        self.edit_var.set(model_conf.memory_usage)
        self.size_val.set("[{}, {}]".format(model_conf.image_width, model_conf.image_height))
        self.resize_val.set(json.dumps(model_conf.resize))
        self.source_train_path_listbox.delete(0, tk.END)
        self.source_validation_path_listbox.delete(0, tk.END)
        self.dataset_validation_listbox.delete(0, tk.END)
        self.dataset_train_listbox.delete(0, tk.END)
        for source_train in self.get_param(model_conf.trains_path, DatasetType.Directory, default=[]):
            self.source_train_path_listbox.insert(tk.END, source_train)
        for source_validation in self.get_param(model_conf.validation_path, DatasetType.Directory, default=[]):
            self.source_validation_path_listbox.insert(tk.END, source_validation)
        self.label_num_spin.set(model_conf.max_label_num)
        self.comb_channel.set(model_conf.image_channel)
        self.comb_neu_cnn.set(model_conf.neu_cnn_param)
        self.comb_recurrent.set(model_conf.neu_recurrent_param)
        self.units_num_spin.set(model_conf.units_num)
        self.comb_loss.set(model_conf.loss_func_param)

        self.extract_regex = model_conf.extract_regex
        self.label_split = model_conf.label_split
        self.label_from_var.set(model_conf.label_from.value)

        self.comb_optimizer.set(model_conf.neu_optimizer_param)
        self.learning_rate_spin.set(model_conf.trains_learning_rate)
        self.end_acc_val.set(model_conf.trains_end_acc)
        self.end_cost_val.set(model_conf.trains_end_cost)
        self.end_epochs_spin.set(model_conf.trains_end_epochs)
        self.batch_size_val.set(model_conf.batch_size)
        self.validation_batch_size_val.set(model_conf.validation_batch_size)
        self.validation_num_val.set(model_conf.validation_set_num)

        self.data_augmentation_entity.binaryzation = model_conf.da_binaryzation
        self.data_augmentation_entity.median_blur = model_conf.da_median_blur
        self.data_augmentation_entity.gaussian_blur = model_conf.da_gaussian_blur
        self.data_augmentation_entity.equalize_hist = model_conf.da_equalize_hist
        self.data_augmentation_entity.laplace = model_conf.da_laplace
        self.data_augmentation_entity.warp_perspective = model_conf.da_warp_perspective
        self.data_augmentation_entity.rotate = model_conf.da_rotate
        self.data_augmentation_entity.sp_noise = model_conf.da_sp_noise
        self.data_augmentation_entity.brightness = model_conf.da_brightness
        self.data_augmentation_entity.hue = model_conf.da_hue
        self.data_augmentation_entity.saturation = model_conf.da_saturation
        self.data_augmentation_entity.gamma = model_conf.da_gamma
        self.data_augmentation_entity.channel_swap = model_conf.da_channel_swap
        self.data_augmentation_entity.random_blank = model_conf.da_random_blank
        self.data_augmentation_entity.random_transition = model_conf.da_random_transition
        self.data_augmentation_entity.random_captcha = model_conf.da_random_captcha

        self.pretreatment_entity.binaryzation = model_conf.pre_binaryzation
        self.pretreatment_entity.replace_transparent = model_conf.pre_replace_transparent
        self.pretreatment_entity.horizontal_stitching = model_conf.pre_horizontal_stitching
        self.pretreatment_entity.concat_frames = model_conf.pre_concat_frames
        self.pretreatment_entity.blend_frames = model_conf.pre_blend_frames
        self.pretreatment_entity.exec_map = model_conf.pre_exec_map

        for dataset_validation in self.get_param(model_conf.validation_path, DatasetType.TFRecords, default=[]):
            self.dataset_validation_listbox.insert(tk.END, dataset_validation)
        for dataset_train in self.get_param(model_conf.trains_path, DatasetType.TFRecords, default=[]):
            self.dataset_train_listbox.insert(tk.END, dataset_train)

        # print('Loading category configuration...')
        if isinstance(model_conf.category_param, list):
            self.category_entry['state'] = tk.DISABLED
            self.comb_category.set('CUSTOMIZED')
            if len(model_conf.category_param) > 1000:
                self.category_val.set(NOT_EDITABLE_MSG)
            else:
                self.category_val.set(model_conf.category_param_text)
                self.category_entry['state'] = tk.NORMAL
        else:
            self.category_val.set("")
            self.category_entry['state'] = tk.DISABLED
            self.comb_category.set(model_conf.category_param)
        # print('Loading configuration is completed.')
        self.model_conf = model_conf
        return self.model_conf

    @property
    def validation_batch_size(self):
        # if self.dataset_validation_listbox.size() > 1:
        return self.validation_batch_size_val.get()
        # else:
        #     return min(self.validation_batch_size_val.get(), self.validation_num_val.get())

    @property
    def device_usage(self):
        return self.edit_var.get()

    def save_conf(self):
        if not self.current_project:
            messagebox.showerror(
                "Error!", "Please set the project name first."
            )
            return
        model_conf = ModelConfig(
            project_name=self.current_project,
            MemoryUsage=self.device_usage,
            CNNNetwork=self.neu_cnn,
            RecurrentNetwork=self.neu_recurrent,
            UnitsNum=self.units_num_spin.get(),
            Optimizer=self.optimizer,
            LossFunction=self.loss_func,
            Decoder=self.comb_loss.get(),
            ModelName=self.current_project,
            ModelField=ModelField.Image.value,
            ModelScene=ModelScene.Classification.value,
            Category=self.category,
            Resize=self.resize,
            ImageChannel=self.comb_channel.get(),
            ImageWidth=self.image_width,
            ImageHeight=self.image_height,
            MaxLabelNum=self.label_num_spin.get(),
            AutoPadding=True,
            ReplaceTransparent=False,
            HorizontalStitching=False,
            OutputSplit='',
            LabelFrom=self.label_from_var.get(),
            ExtractRegex=self.extract_regex,
            LabelSplit=self.label_split,
            DatasetTrainsPath=self.dataset_value(
                dataset_type=DatasetType.TFRecords, mode=RunMode.Trains
            ),
            DatasetValidationPath=self.dataset_value(
                dataset_type=DatasetType.TFRecords, mode=RunMode.Validation
            ),
            SourceTrainPath=self.dataset_value(
                dataset_type=DatasetType.Directory, mode=RunMode.Trains
            ),
            SourceValidationPath=self.dataset_value(
                dataset_type=DatasetType.Directory, mode=RunMode.Validation
            ),
            ValidationSetNum=self.validation_num_val.get(),
            SavedSteps=100,
            ValidationSteps=500,
            EndAcc=self.end_acc_val.get(),
            EndCost=self.end_cost_val.get(),
            EndEpochs=self.end_epochs_spin.get(),
            BatchSize=self.batch_size_val.get(),
            ValidationBatchSize=self.validation_batch_size,
            LearningRate=self.learning_rate_spin.get(),
            DA_Binaryzation=self.data_augmentation_entity.binaryzation,
            DA_MedianBlur=self.data_augmentation_entity.median_blur,
            DA_GaussianBlur=self.data_augmentation_entity.gaussian_blur,
            DA_EqualizeHist=self.data_augmentation_entity.equalize_hist,
            DA_Laplace=self.data_augmentation_entity.laplace,
            DA_WarpPerspective=self.data_augmentation_entity.warp_perspective,
            DA_Rotate=self.data_augmentation_entity.rotate,
            DA_PepperNoise=self.data_augmentation_entity.sp_noise,
            DA_Brightness=self.data_augmentation_entity.brightness,
            DA_Saturation=self.data_augmentation_entity.saturation,
            DA_Hue=self.data_augmentation_entity.hue,
            DA_Gamma=self.data_augmentation_entity.gamma,
            DA_ChannelSwap=self.data_augmentation_entity.channel_swap,
            DA_RandomBlank=self.data_augmentation_entity.random_blank,
            DA_RandomTransition=self.data_augmentation_entity.random_transition,
            DA_RandomCaptcha=self.data_augmentation_entity.random_captcha,
            Pre_Binaryzation=self.pretreatment_entity.binaryzation,
            Pre_ReplaceTransparent=self.pretreatment_entity.replace_transparent,
            Pre_HorizontalStitching=self.pretreatment_entity.horizontal_stitching,
            Pre_ConcatFrames=self.pretreatment_entity.concat_frames,
            Pre_BlendFrames=self.pretreatment_entity.blend_frames,
            Pre_ExecuteMap=self.pretreatment_entity.exec_map
        )
        model_conf.update()
        return model_conf

    def make_dataset(self):
        if not self.current_project:
            messagebox.showerror(
                "Error!", "Please set the project name first."
            )
            return
        if self.is_task_running:
            messagebox.showerror(
                "Error!", "Please terminate the current training first or wait for the training to end."
            )
            return
        self.save_conf()
        self.button_state(self.btn_make_dataset, tk.DISABLED)
        model_conf = ModelConfig(self.current_project)
        train_path = self.dataset_value(DatasetType.Directory, RunMode.Trains)
        validation_path = self.dataset_value(DatasetType.Directory, RunMode.Validation)
        if len(train_path) < 1:
            messagebox.showerror(
                "Error!", "{} Sample set has not been added.".format(RunMode.Trains.value)
            )
            self.button_state(self.btn_make_dataset, tk.NORMAL)
            return
        self.threading_exec(
            lambda: DataSets(model_conf).make_dataset(
                trains_path=train_path,
                validation_path=validation_path,
                is_add=False,
                callback=lambda: self.button_state(self.btn_make_dataset, tk.NORMAL),
                msg=lambda x: tk.messagebox.showinfo('Make Dataset Status', x)
            )
        )

    @property
    def size(self):
        return self.json_filter(self.size_val.get(), int)

    @property
    def image_height(self):
        return self.size[1]

    @property
    def image_width(self):
        return self.size[0]

    @property
    def resize(self):
        return self.json_filter(self.resize_val.get(), int)

    @property
    def neu_cnn(self):
        return self.comb_neu_cnn.get()

    @property
    def neu_recurrent(self):
        return self.comb_recurrent.get()

    @property
    def loss_func(self):
        return self.comb_loss.get()

    @property
    def optimizer(self):
        return self.comb_optimizer.get()

    @staticmethod
    def json_filter(content, item_type):
        if not content:
            messagebox.showerror(
                "Error!", "To select a customized category, you must specify the category set manually."
            )
            return None
        try:
            content = json.loads(content)
        except ValueError as e:
            messagebox.showerror(
                "Error!", "Input must be of type JSON."
            )
            return None
        content = [item_type(i) for i in content]
        return content

    @property
    def category(self):
        comb_selected = self.comb_category.get()
        if not comb_selected:
            messagebox.showerror(
                "Error!", "Please select built-in category or custom category first"
            )
            return None
        if comb_selected == 'CUSTOMIZED':
            category_value = self.category_entry.get()
            if category_value == NOT_EDITABLE_MSG:
                return self.model_conf.category_param_text
            category_value = category_value.replace("'", '"') if "'" in category_value else category_value
            category_value = self.json_filter(category_value, str)
        else:
            category_value = comb_selected
        return category_value

    def dataset_value(self, dataset_type: DatasetType, mode: RunMode):
        listbox = self.sample_map[dataset_type][mode]
        value = list(listbox.get(0, listbox.size() - 1))
        return value

    def compile_task(self):
        if not self.current_project:
            messagebox.showerror(
                "Error!", "Please set the project name first."
            )
            return
        model_conf = ModelConfig(project_name=self.current_project)
        if not os.path.exists(model_conf.model_root_path):
            messagebox.showerror(
                "Error", "Model storage folder does not exist."
            )
            return
        if len(os.listdir(model_conf.model_root_path)) < 3:
            messagebox.showerror(
                "Error", "There is no training model record, please train before compiling."
            )
            return
        try:
            if not self.current_task:
                self.current_task = Trains(model_conf)

            self.current_task.compile_graph(0)
            status = 'Compile completed'
        except Exception as e:
            messagebox.showerror(
                e.__class__.__name__, json.dumps(e.args, ensure_ascii=False)
            )
            status = 'Compile failure'
        tk.messagebox.showinfo('Compile Status', status)

    def compile(self):
        self.job = self.threading_exec(
            lambda: self.compile_task()
        )

    def training_task(self):
        model_conf = ModelConfig(project_name=self.current_project)

        self.current_task = Trains(model_conf)
        try:
            self.button_state(self.btn_training, tk.DISABLED)
            self.button_state(self.btn_stop, tk.NORMAL)
            self.is_task_running = True
            self.current_task.train_process()
            status = 'Training completed'
        except Exception as e:
            traceback.print_exc()
            messagebox.showerror(
                e.__class__.__name__, json.dumps(e.args, ensure_ascii=False)
            )
            status = 'Training failure'
        self.button_state(self.btn_training, tk.NORMAL)
        self.button_state(self.btn_stop, tk.DISABLED)
        self.comb_project_name['state'] = tk.NORMAL
        self.is_task_running = False
        tk.messagebox.showinfo('Training Status', status)

    @staticmethod
    def check_dataset(model_conf):
        trains_path = model_conf.trains_path[DatasetType.TFRecords]
        validation_path = model_conf.validation_path[DatasetType.TFRecords]
        if not trains_path or not validation_path:
            messagebox.showerror(
                "Error!", "Training set or validation set not defined."
            )
            return False
        for tp in trains_path:
            if not os.path.exists(tp):
                messagebox.showerror(
                    "Error!", "Training set path does not exist, please make dataset first"
                )
                return False
        for vp in validation_path:
            if not os.path.exists(vp):
                messagebox.showerror(
                    "Error!", "Validation set path does not exist, please make dataset first"
                )
                return False
        return True

    def start_training(self):
        if not self.check_resize():
            return
        if not self.current_project:
            messagebox.showerror(
                "Error!", "Please set the project name first."
            )
            return
        model_conf = self.save_conf()
        if not self.check_dataset(model_conf):
            return
        self.comb_project_name['state'] = tk.DISABLED
        self.job = self.threading_exec(
            lambda: self.training_task()
        )

    def stop_training(self):
        self.current_task.stop_flag = True

    @property
    def project_names(self):
        return [i.name for i in os.scandir(self.project_root_path) if i.is_dir()]

    def fetch_projects(self):
        self.comb_project_name['values'] = self.project_names

    def browse_dataset(self, dataset_type: DatasetType, mode: RunMode):
        if not self.current_project:
            messagebox.showerror(
                "Error!", "Please define the project name first."
            )
            return
        filename = filedialog.askdirectory()
        if not filename:
            return
        is_sub = False
        for i, item in enumerate(os.scandir(filename)):
            if item.is_dir():
                path = item.path.replace("\\", "/")
                if self.sample_map[dataset_type][mode].size() == 0:
                    self.fetch_sample([path])
                self.sample_map[dataset_type][mode].insert(tk.END, path)
                if i > 0:
                    continue
                is_sub = True
            else:
                break
        if not is_sub:
            filename = filename.replace("\\", "/")
            if self.sample_map[dataset_type][mode].size() == 0:
                self.fetch_sample([filename])
            self.sample_map[dataset_type][mode].insert(tk.END, filename)

    @staticmethod
    def closest_category(category):
        category = set(category)
        category_group = dict()
        for key in SIMPLE_CATEGORY_MODEL.keys():
            category_set = set(category_extract(key))
            if category <= category_set:
                category_group[key] = len(category_set) - len(category)
        if not category_group:
            return None
        min_index = min(category_group.values())
        for k, v in category_group.items():
            if v == min_index:
                return k

    def fetch_sample(self, dataset_path):
        file_names = os.listdir(dataset_path[0])[0:100]
        category = list()
        len_label = -1

        for file_name in file_names:
            if "_" in file_name:
                label = file_name.split("_")[0]
                label = [i for i in label]
                len_label = len(label)
                category.extend(label)

        size = PilImage.open(os.path.join(dataset_path[0], file_names[0])).size
        self.size_val.set(json.dumps(size))
        self.resize_val.set(json.dumps(size))
        self.label_num_spin.set(len_label)
        if not self.category_val.get() or self.category_val.get() != NOT_EDITABLE_MSG:
            category_pram = self.closest_category(category)
            if not category_pram:
                return
            self.comb_category.set(category_pram)

    def listbox_delete_item_callback(self, event, listbox: tk.Listbox):
        i = listbox.curselection()[0]
        listbox.delete(i)
        self.save_conf()

    def comb_category_callback(self, event):
        comb_selected = self.comb_category.get()
        if comb_selected == 'CUSTOMIZED':
            self.category_entry['state'] = tk.NORMAL
        else:
            self.category_entry.delete(0, tk.END)
            self.category_entry['state'] = tk.DISABLED

    def check_resize(self):
        if self.loss_func == 'CTC':
            return True
        param = OUTPUT_SHAPE1_MAP[NETWORK_MAP[self.neu_cnn]]
        shape1w = math.ceil(1.0*self.resize[0]/param[0])
        shape1h = math.ceil(1.0*self.resize[1]/param[0])
        input_s1 = shape1w * shape1h * param[1]
        label_num = int(self.label_num_spin.get())
        if input_s1 % label_num != 0:
            messagebox.showerror(
                "Error!", "Shape[1] = {} must divide the label_num = {}.".format(input_s1, label_num)
            )
            return False
        return True

    @staticmethod
    def resource_path(relative_path):
        try:
            # PyInstaller creates a temp folder and stores path in _MEIPASS
            base_path = sys._MEIPASS
        except AttributeError:
            base_path = os.path.abspath(".")
        return os.path.join(base_path, relative_path)


if __name__ == '__main__':
    root = tk.Tk()
    app = Wizard(root)
    root.mainloop()