#-*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import time
import sys
import glob
import cv2
import six
import gc
import paddle
import paddle.fluid as fluid
from collections import namedtuple
import paddle.dataset as dataset

from data_augmentor import DataAugmentor
import data_augmentor

# 路径相关
RootPath = os.path.abspath("./")
sys.path.append(RootPath)
print ("项目根目录路径为: ",RootPath)

# 标注数据类别
Label = namedtuple( 'Label' , [

    'name'        ,
    'id'          ,
    'trainId'     ,
    'category'    ,
    'categoryId'  ,
    'hasInstances',
    'ignoreInEval',
    'color'       ,
    ] )

# 标注定义
labels = [
    #           name     id trainId      category  catId hasInstances ignoreInEval            color
    Label(     'void' ,   0 ,     0,        'void' ,   0 ,      False ,      False , (  0,   0,   0) ),
    Label(    's_w_d' , 200 ,     1 ,   'dividing' ,   1 ,      False ,      False , ( 70, 130, 180) ),
    Label(    's_y_d' , 204 ,     1 ,   'dividing' ,   1 ,      False ,      False , (220,  20,  60) ),
    Label(  'ds_w_dn' , 213 ,     1 ,   'dividing' ,   1 ,      False ,       True , (128,   0, 128) ),
    Label(  'ds_y_dn' , 209 ,     1 ,   'dividing' ,   1 ,      False ,      False , (255, 0,   0) ),
    Label(  'sb_w_do' , 206 ,     1 ,   'dividing' ,   1 ,      False ,       True , (  0,   0,  60) ),
    Label(  'sb_y_do' , 207 ,     1 ,   'dividing' ,   1 ,      False ,       True , (  0,  60, 100) ),
    Label(    'b_w_g' , 201 ,     2 ,    'guiding' ,   2 ,      False ,      False , (  0,   0, 142) ),
    Label(    'b_y_g' , 203 ,     2 ,    'guiding' ,   2 ,      False ,      False , (119,  11,  32) ),
    Label(   'db_w_g' , 211 ,     2 ,    'guiding' ,   2 ,      False ,       True , (244,  35, 232) ),
    Label(   'db_y_g' , 208 ,     2 ,    'guiding' ,   2 ,      False ,       True , (  0,   0, 160) ),
    Label(   'db_w_s' , 216 ,     3 ,   'stopping' ,   3 ,      False ,       True , (153, 153, 153) ),
    Label(    's_w_s' , 217 ,     3 ,   'stopping' ,   3 ,      False ,      False , (220, 220,   0) ),
    Label(   'ds_w_s' , 215 ,     3 ,   'stopping' ,   3 ,      False ,       True , (250, 170,  30) ),
    Label(    's_w_c' , 218 ,     4 ,    'chevron' ,   4 ,      False ,       True , (102, 102, 156) ),
    Label(    's_y_c' , 219 ,     4 ,    'chevron' ,   4 ,      False ,       True , (128,   0,   0) ),
    Label(    's_w_p' , 210 ,     5 ,    'parking' ,   5 ,      False ,      False , (128,  64, 128) ),
    Label(    's_n_p' , 232 ,     5 ,    'parking' ,   5 ,      False ,       True , (238, 232, 170) ),
    Label(   'c_wy_z' , 214 ,     6 ,      'zebra' ,   6 ,      False ,      False , (190, 153, 153) ),
    Label(    'a_w_u' , 202 ,     7 ,  'thru/turn' ,   7 ,      False ,       True , (  0,   0, 230) ),
    Label(    'a_w_t' , 220 ,     7 ,  'thru/turn' ,   7 ,      False ,      False , (128, 128,   0) ),
    Label(   'a_w_tl' , 221 ,     7 ,  'thru/turn' ,   7 ,      False ,      False , (128,  78, 160) ),
    Label(   'a_w_tr' , 222 ,     7 ,  'thru/turn' ,   7 ,      False ,      False , (150, 100, 100) ),
    Label(  'a_w_tlr' , 231 ,     7 ,  'thru/turn' ,   7 ,      False ,       True , (255, 165,   0) ),
    Label(    'a_w_l' , 224 ,     7 ,  'thru/turn' ,   7 ,      False ,      False , (180, 165, 180) ),
    Label(    'a_w_r' , 225 ,     7 ,  'thru/turn' ,   7 ,      False ,      False , (107, 142,  35) ),
    Label(   'a_w_lr' , 226 ,     7 ,  'thru/turn' ,   7 ,      False ,      False , (201, 255, 229) ),
    Label(   'a_n_lu' , 230 ,     7 ,  'thru/turn' ,   7 ,      False ,       True , (0,   191, 255) ),
    Label(   'a_w_tu' , 228 ,     7 ,  'thru/turn' ,   7 ,      False ,       True , ( 51, 255,  51) ),
    Label(    'a_w_m' , 229 ,     7 ,  'thru/turn' ,   7 ,      False ,       True , (250, 128, 114) ),
    Label(    'a_y_t' , 233 ,     7 ,  'thru/turn' ,   7 ,      False ,       True , (127, 255,   0) ),
    Label(   'b_n_sr' , 205 ,     8 ,  'reduction' ,   8 ,      False ,      False , (255, 128,   0) ),
    Label(  'd_wy_za' , 212 ,     8 ,  'attention' ,   8 ,      False ,       True , (  0, 255, 255) ),
    Label(  'r_wy_np' , 227 ,     8 , 'no parking' ,   8 ,      False ,      False , (178, 132, 190) ),
    Label( 'vom_wy_n' , 223 ,     8 ,     'others' ,   8 ,      False ,       True , (128, 128,  64) ),
    Label(   'om_n_n' , 250 ,     8 ,     'others' ,   8 ,      False ,      False , (102,   0, 204) ),
    Label(    'noise' , 249 ,     0 ,    'ignored' ,   0 ,      False ,       True , (  0, 153, 153) ),
    Label(  'ignored' , 255 ,     0 ,    'ignored' ,   0 ,      False ,       True , (255, 255, 255) ),
]

# 名字转标注
name2label      = { label.name    : label for label in labels           }

# id转标注
id2label        = { label.id      : label for label in labels           }

# 训练id转标注
trainId2label   = { label.trainId : label for label in reversed(labels) }

print ("标准转换检测 200 --->  1: ",id2label[200].trainId)

# 数据预处理
augmentor = DataAugmentor()
class TrainDataReader:
    def __init__(self, dataset_dir, subset='train',rows=2000, cols=1354, shuffle=True, birdeye=True):
        label_dirname = dataset_dir + subset
        print (label_dirname)

        if six.PY2:
            import commands
            label_files = commands.getoutput(
                "find %s -type f | grep _bin.png | sort" %
                label_dirname).splitlines()
        else:
            import subprocess
            label_files = subprocess.getstatusoutput(
                "find %s -type f | grep _bin.png | sort" %
                label_dirname)[-1].splitlines()

        print ('---')
        print (label_files[0])
        self.label_files = label_files
        self.label_dirname = label_dirname
        self.rows = rows
        self.cols = cols
        self.index = 0
        self.subset = subset
        self.dataset_dir = dataset_dir
        self.shuffle = shuffle
        self.M = 0
        self.Minv = 0
        self.reset()
        self.get_M_Minv()
        self.augmentor = 0
        self.birdeye = birdeye
        print("images total number", len(label_files))

    # 标签转分类 255 ignore ?
    def label2classes(self, label,row,col):
        x = np.zeros([row,col,9]).astype(np.int64)
        for i in range(row):
            for j in range(col):
                    try:
                        trainId = id2label[int(label[i][j])].trainId
                        x[i, j ,trainId] = 1 # 属于第m类,第三维m处值为1
                    except Exception as err:
                        #print('像素级标签值异常',err)
                        pass
        return x


    def get_M_Minv(self):
        # 左上、右上、左下、右下
        src = np.float32([[800, 730], [2583, 730], [0, 1709], [3383, 1709]])
        dst = np.float32([[0, 0], [3999,0], [1300, 3999], [2700, 3999]])
        self.M = cv2.getPerspectiveTransform(src, dst)
        self.Minv = cv2.getPerspectiveTransform(dst,src)

    def reset(self, shuffle=False):
        self.index = 0
        if self.shuffle:
            np.random.shuffle(self.label_files)

    def next_img(self):
        self.index += 1
        if self.index >= len(self.label_files):
            self.reset()

    def prev_img(self):
        if self.index >= 1:
            self.index -= 1

    def get_img(self):
        #if self.augmentor != 0 and self.augmentor < 2:
        #    self.prev_img()
        while True:
            label_name = self.label_files[self.index]
            img_name = label_name.replace('_bin.png', '.jpg')
            img_name = img_name.replace('Label', 'ColorImage')
            label = cv2.imread(label_name,cv2.IMREAD_GRAYSCALE)
            img = cv2.imread(img_name)
            if img is None:
                print("load img failed:", img_name)
                self.next_img()
            else:
                break
        try:
            if self.birdeye ==True:
                warped_img = cv2.warpPerspective(img, self.M, (4000, 4000),flags=cv2.INTER_CUBIC)
                warped_label = cv2.warpPerspective(label, self.M, (4000, 4000),flags=cv2.INTER_NEAREST)
                label = cv2.resize(warped_label, (self.cols, self.rows), interpolation=cv2.INTER_NEAREST)
                img   = cv2.resize(warped_img, (self.cols, self.rows), interpolation=cv2.INTER_CUBIC)
            else:
                label = cv2.resize(label, (self.cols, self.rows), interpolation=cv2.INTER_NEAREST)
                img   = cv2.resize(img, (self.cols, self.rows), interpolation=cv2.INTER_CUBIC)
        except Exception as err:
            print('warped_error: ',err)
            img = np.zeros([self.cols,self.rows,3]).astype(np.uint8)
            label = np.zeros([self.cols,self.rows]).astype(np.uint8)
        # 数据增广
        if self.augmentor != 0:
            if self.augmentor < 2:
                img,label = augmentor.disturb(img, label)
            else :
                self.augmentor = 0
        img   = img.transpose((2,0,1))
        label = self.label2classes(label,self.rows, self.cols) # 转换为 9 分类
        return img, label, label_name

    def get_batch(self, batch_size=1):
        imgs = []
        labels = []
        names = []
        while len(imgs) < batch_size:
            img, label, label_name = self.get_img()
            imgs.append(img)
            labels.append(label)
            names.append(label_name)
            self.next_img()
            self.augmentor += 1
        return np.array(imgs), np.array(labels), names

    def get_batch_generator(self, batch_size, total_step):
        def do_get_batch():
            for i in range(total_step):
                gc.collect() 
                try:
                    imgs, labels, names = self.get_batch(batch_size)
                except Exception as err:
                    imgs, labels, names = self.get_batch(batch_size)
                    print('Generator 异常',err)

                imgs   = imgs.astype(np.float32)
                labels = labels.astype(np.float32)
                imgs  /= 255
                yield i, imgs, labels, names

        batches = do_get_batch()
        try:
            from prefetch_generator import BackgroundGenerator
            batches = BackgroundGenerator(batches, 10)
        except:
            print(
                "You can install 'prefetch_generator' for acceleration of data reading."
            )
        return batches

class TestDataReader:
    def __init__(self, dataset_dir, subset='test',rows=880, cols=596, shuffle=False, birdeye=True):
        image_dirname = os.path.join(dataset_dir,subset)
        print (image_dirname)
        image_files = sorted(glob.glob(image_dirname+"/image/*."+"jpg"))
        print ('---')
        print (image_files[0])
        self.image_files = image_files
        self.image_dirname = image_dirname
        self.rows = rows
        self.cols = cols
        self.index = 0
        self.subset = subset
        self.dataset_dir = dataset_dir
        self.shuffle = shuffle
        self.M = 0
        self.Minv = 0
        self.reset()
        self.get_M_Minv()
        self.birdeye = birdeye
        print("images total number", len(image_files))

    def get_M_Minv(self):
        # 左上、右上、左下、右下
        src = np.float32([[800, 730], [2583, 730], [0, 1709], [3383, 1709]])
        dst = np.float32([[0, 0], [3999,0], [1300, 3999], [2700, 3999]])
        self.M = cv2.getPerspectiveTransform(src, dst)
        self.Minv = cv2.getPerspectiveTransform(dst,src)

    def reset(self, shuffle=False):
        self.index = 0
        if self.shuffle:
            np.random.shuffle(self.image_files)

    def next_img(self):
        self.index += 1
        if self.index >= len(self.image_files):
            self.reset()

    def get_img(self):
        while True:
            img_name = self.image_files[self.index]
            label_name = img_name.replace('.jpg', '.png')
            img = cv2.imread(img_name)
            if img is None:
                print("load img failed:", img_name)
                self.next_img()
            else:
                break
        if self.birdeye == True:
            warped_img = cv2.warpPerspective(img, self.M, (4000, 4000),flags=cv2.INTER_CUBIC)
            img   = cv2.resize(warped_img, (self.cols, self.rows), interpolation=cv2.INTER_CUBIC)
        else:
            img   = cv2.resize(img, (self.cols, self.rows), interpolation=cv2.INTER_CUBIC)
        img   = img.transpose((2,0,1))
        return img, label_name

    def get_batch(self, batch_size=1):
        imgs = []
        labels = []
        names = []
        while len(imgs) < batch_size:
            img, label_name = self.get_img()
            imgs.append(img)
            names.append(label_name)
            self.next_img()
        return np.array(imgs), names

    def get_batch_generator(self, batch_size, total_step):
        def do_get_batch():
            for i in range(total_step):
                imgs   = []
                names  = []
                try:
                    imgs, names = self.get_batch(batch_size)
                except Exception as err:
                    imgs, names = self.get_batch(batch_size)
                    print('Generator 异常',err)
                    
                imgs   = imgs.astype(np.float32)
                imgs  /= 255

                yield i, imgs, names

        batches = do_get_batch()
        try:
            from prefetch_generator import BackgroundGenerator
            batches = BackgroundGenerator(batches,10)
        except:
            print(
                "You can install 'prefetch_generator' for acceleration of data reading."
            )
        return batches

class EvalDataReader:
    def __init__(self, dataset_dir, subset='val',rows=512, cols=1024, shuffle=True, birdeye=True):
        label_dirname = os.path.join(dataset_dir,subset)
        print (label_dirname)
        label_files = sorted(glob.glob(label_dirname+"/label/*."+"png"))
        print ('---')
        print (label_files[0])
        self.label_files = label_files
        self.label_dirname = label_dirname
        self.rows = rows
        self.cols = cols
        self.index = 0
        self.subset = subset
        self.dataset_dir = dataset_dir
        self.shuffle = shuffle
        self.reset()
        self.augmentor = 0
        self.M = 0
        self.Minv = 0
        self.get_M_Minv()
        self.birdeye = birdeye
        print("images total number", len(label_files))

    # 标签转分类 255 ignore ?
    def label2classes(self, label,row,col):
        x = np.zeros([row,col,9]).astype(np.int64)
        for i in range(row):
            for j in range(col):
                    try:
                        trainId = id2label[int(label[i][j])].trainId
                        x[i, j ,trainId] = 1 # 属于第m类,第三维m处值为1
                    except Exception as err:
                        print('像素级标签值异常',err)
                        pass
        return x

    def get_M_Minv(self):
        # 左上、右上、左下、右下
        src = np.float32([[800, 730], [2583, 730], [0, 1709], [3383, 1709]])
        dst = np.float32([[0, 0], [3999,0], [1300, 3999], [2700, 3999]])
        self.M = cv2.getPerspectiveTransform(src, dst)
        self.Minv = cv2.getPerspectiveTransform(dst,src)

    def reset(self, shuffle=False):
        self.index = 0
        if self.shuffle:
            np.random.shuffle(self.label_files)

    def next_img(self):
        self.index += 1
        if self.index >= len(self.label_files):
            self.reset()

    def prev_img(self):
        if self.index >= 1:
            self.index -= 1

    def get_img(self):
        #if self.augmentor != 0 and self.augmentor < 6:
        #    self.prev_img()
        while True:
            label_name = self.label_files[self.index]
            img_name = label_name.replace('label', 'image')
            img_name = img_name.replace('_bin.png', '.jpg')
            label = cv2.imread(label_name,cv2.IMREAD_GRAYSCALE)
            img = cv2.imread(img_name)
            if img is None:
                print("load img failed:", img_name)
                self.next_img()
            else:
                break

        warped_img = cv2.warpPerspective(img, self.M, (4000, 4000),flags=cv2.INTER_CUBIC)
        warped_label = cv2.warpPerspective(label, self.M, (4000, 4000),flags=cv2.INTER_NEAREST)

        if self.birdeye == True:
            label = cv2.resize(warped_label, (self.cols, self.rows), interpolation=cv2.INTER_NEAREST)
            img   = cv2.resize(warped_img, (self.cols, self.rows), interpolation=cv2.INTER_CUBIC)
        else:
            label = cv2.resize(label, (self.cols, self.rows), interpolation=cv2.INTER_NEAREST)
            img   = cv2.resize(img, (self.cols, self.rows), interpolation=cv2.INTER_CUBIC)
            
        img   = img.transpose((2,0,1))
        label = self.label2classes(label,self.rows, self.cols) # 转换为 9 分类
        return img, label, label_name

    def get_batch(self, batch_size=1):
        imgs = []
        labels = []
        names = []
        while len(imgs) < batch_size:
            img, label, label_name = self.get_img()
            imgs.append(img)
            labels.append(label)
            names.append(label_name)
            self.next_img()
            self.augmentor += 1
        return np.array(imgs), np.array(labels), names

    def get_batch_generator(self, batch_size, total_step):
        def do_get_batch():
            for i in range(total_step):
                gc.collect() 
                try:
                    imgs, labels, names = self.get_batch(batch_size)
                except Exception as err:
                    imgs, labels, names = self.get_batch(batch_size)
                    print('Generator 异常',err)

                imgs   = imgs.astype(np.float32)
                labels = labels.astype(np.float32)
                imgs  /= 255

                yield i, imgs, labels, names

        batches = do_get_batch()
        try:
            from prefetch_generator import BackgroundGenerator
            batches = BackgroundGenerator(batches, 10)
        except:
            print(
                "You can install 'prefetch_generator' for acceleration of data reading."
            )
        return batches