from __future__ import print_function
import os, sys, time
import argparse
import cv2 as cv
import numpy as np
import random # import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms import transforms
import torchgeometry as tgm

from model import DexiNet
from losses import weighted_cross_entropy_loss
from dexi_utils import cv_imshow, dataset_info


class testDataset(Dataset):
    def __init__(self, data_root, arg = None):
        self.data_root = data_root
        self.arg = arg
        self.transforms = transforms
        self.mean_bgr = arg.mean_pixel_values[0:3] if len(arg.mean_pixel_values)==4\
            else arg.mean_pixel_values

        self.data_index = self._build_index()


    def _build_index(self):
        sample_indices = []
        if not self.arg.test_data == "CLASSIC":
            list_name = os.path.join(self.data_root,self.arg.test_list)#os.path.abspath(self.data_root)
            with open(list_name,'r') as f:
                files = f.readlines()

            files = [line.strip() for line in files]

            pairs = [line.split() for line in files]
            images_path = [line[0] for line in pairs]
            labels_path = [line[1] for line in pairs]
            sample_indices = [images_path,labels_path]
        else:

            # for single image testing
            images_path = os.listdir(self.data_root)
            labels_path = None
            sample_indices = [images_path, labels_path]
        return sample_indices

    def __len__(self):
        return len(self.data_index[0])

    def __getitem__(self, idx):
        # get data sample
        # image_path, label_path = self.data_index[idx]
        image_path = self.data_index[0][idx]
        label_path = self.data_index[1][idx] if not self.arg.test_data=="CLASSIC" else None
        img_name = os.path.basename(image_path)
        file_name = img_name[:-3]+"png"

        # base dir
        if self.arg.test_data.upper() == 'BIPED':
            img_dir = os.path.join(self.arg.input_val_dir,'imgs','test')
            gt_dir = os.path.join(self.arg.input_val_dir,'edge_maps','test')
        elif self.arg.test_data.upper() == 'CLASSIC':
            img_dir = self.arg.input_val_dir
            gt_dir = None
        else:
            img_dir = self.arg.input_val_dir
            gt_dir = self.arg.input_val_dir

        # load data
        image = cv.imread(os.path.join(img_dir,image_path), cv.IMREAD_COLOR)
        if not self.arg.test_data == "CLASSIC":
            label = cv.imread(os.path.join(gt_dir,label_path), cv.IMREAD_COLOR)
        else:
            label=None

        im_shape =[image.shape[0],image.shape[1]]
        image, label = self.transform(img=image, gt=label)

        return dict(images=image, labels=label, file_names=file_name,image_shape=im_shape)

    def transform(self, img, gt):

        # gt[gt< 51] = 0 # test without gt discrimination
        if self.arg.test_data=="CLASSIC":
            img_height = img.shape[0] if img.shape[0] % 16 == 0 else ((img.shape[0] // 16) + 1) * 16
            img_width = img.shape[1] if img.shape[1] % 16 == 0 else ((img.shape[1] // 16) + 1) * 16
            print('Real-size:',img.shape, "Ideal size:",[img_height,img_width])
            img = cv.resize(img, (self.arg.test_im_width,self.arg.test_im_height))
            gt = None
        elif img.shape[0]<512 or img.shape[1]<512:
            img = cv.resize(img, (512, 512))
            gt = cv.resize(gt, (512, 512))
        elif img.shape[0]%16!=0 or img.shape[1]%16!=0:
            img_width = ((img.shape[1] // 16) + 1) * 16
            img_height = ((img.shape[0] // 16) + 1) * 16
            img = cv.resize(img, (img_width, img_height))
            gt = cv.resize(gt, (img_width, img_height))


        # if self.yita is not None:
        #     gt[gt >= self.yita] = 1
        img = np.array(img, dtype=np.float32)
        # if self.rgb:
        #     img = img[:, :, ::-1]  # RGB->BGR
        if not self.arg.test_data=="CLASSIC":
            gt = np.array(gt, dtype=np.float32)
            if len(gt.shape) == 3:
                gt = gt[:, :, 0]
            gt /= 255.
            gt = torch.from_numpy(np.array([gt])).float()
        else:
            gt = np.zeros((img.shape[:2]))
            gt=torch.from_numpy(np.array([gt])).float()
        img -= self.mean_bgr
        img = img.transpose((2, 0, 1))
        img = torch.from_numpy(img.copy()).float()

        return img, gt


class BipedMyDataset(Dataset):
    train_modes = ['train', 'test',]
    dataset_types = ['rgbr',]
    data_types = ['aug',]
    def __init__(self, data_root, train_mode='train', dataset_type='rgbr',
                 is_scaling=None, arg=None):
        self.data_root = data_root
        self.train_mode = train_mode
        self.dataset_type = dataset_type
        self.data_type = 'aug' # be aware that this might change in the future
        self.scale = is_scaling
        self.arg =arg
        self.mean_bgr = arg.mean_pixel_values[0:3] if len(arg.mean_pixel_values) == 4 \
            else arg.mean_pixel_values

        self.data_index = self._build_index()

    def _build_index(self):
        assert self.train_mode in self.train_modes, self.train_mode
        assert self.dataset_type in self.dataset_types, self.dataset_type
        assert self.data_type in self.data_types, self.data_type
        sample_indices = []
        data_root = os.path.abspath(self.data_root)
        images_path = os.path.join(data_root, 'imgs', self.train_mode,
                                   self.dataset_type, self.data_type)
        labels_path = os.path.join(data_root, 'edge_maps', self.train_mode,
                                   self.dataset_type, self.data_type)
        for directory_name in os.listdir(images_path):
             image_directories = os.path.join(images_path, directory_name)
             for file_name_ext in os.listdir(image_directories):
                 file_name = file_name_ext[:-4]
                 sample_indices.append(
                     (os.path.join(images_path, directory_name, file_name + '.jpg'),
                      os.path.join(labels_path, directory_name, file_name + '.png'),)
                 )
        return sample_indices

    def __len__(self):
        return len(self.data_index)

    def __getitem__(self, idx):
        # get data sample
        image_path, label_path = self.data_index[idx]
        
        # load data
        image = cv.imread(image_path, cv.IMREAD_COLOR)
        label = cv.imread(label_path, cv.IMREAD_GRAYSCALE)
        image, label = self.transform(img=image, gt=label)
        return dict(images=image, labels=label)

    def transform(self, img, gt):

        gt = np.array(gt, dtype=np.float32)
        if len(gt.shape) == 3:
            gt = gt[:, :, 0]
        # gt[gt< 51] = 0 # test without gt discrimination
        gt /= 255.
        # if self.yita is not None:
        #     gt[gt >= self.yita] = 1
        img = np.array(img, dtype=np.float32)
        # if self.rgb:
        #     img = img[:, :, ::-1]  # RGB->BGR
        img -= self.mean_bgr
        # data = []
        # if self.scale is not None:
        #     for scl in self.scale:
        #         img_scale = cv.resize(img, None, fx=scl, fy=scl, interpolation=cv.INTER_LINEAR)
        #         data.append(torch.from_numpy(img_scale.transpose((2, 0, 1))).float())
        #     return data, gt
        crop_size = self.arg.img_height if self.arg.img_height == self.arg.img_width else 400

        if self.arg.crop_img:
            _, h, w = gt.size()
            assert (crop_size < h and crop_size < w)
            i = random.randint(0, h - crop_size)
            j = random.randint(0, w - crop_size)
            img = img[:, i:i + crop_size, j:j + crop_size]
            gt = gt[:, i:i + crop_size, j:j + crop_size]
        else:
            img = cv.resize(img, dsize=(self.arg.img_width, self.arg.img_height ))
            gt = cv.resize(gt, dsize=(self.arg.img_width, self.arg.img_height ))
        img = img.transpose((2, 0, 1))
        img = torch.from_numpy(img.copy()).float()
        gt = torch.from_numpy(np.array([gt])).float()
        return img, gt

def image_normalization(img, img_min=0, img_max=255):
    """This is a typical image normalization function
    where the minimum and maximum of the image is needed
    source: https://en.wikipedia.org/wiki/Normalization_(image_processing)
    :param img: an image could be gray scale or color
    :param img_min:  for default is 0
    :param img_max: for default is 255
    :return: a normalized image, if max is 255 the dtype is uint8
    """
    img = np.float32(img)
    epsilon=1e-12 # whenever an inconsistent image
    img = (img-np.min(img))*(img_max-img_min)/((np.max(img)-np.min(img))+epsilon)+img_min
    return img

def restore_rgb(config,I, restore_rgb=False):
    """
    :param config: [args.channel_swap, args.mean_pixel_value]
    :param I: and image or a set of images
    :return: an image or a set of images restored
    """

    if  len(I)>3 and not type(I)==np.ndarray:
        I =np.array(I)
        I = I[:,:,:,0:3]
        n = I.shape[0]
        for i in range(n):
            x = I[i,...]
            x = np.array(x, dtype=np.float32)
            x += config[1]
            if restore_rgb:
                x = x[:, :, config[0]]
            x = image_normalization(x)
            I[i,:,:,:]=x
    elif len(I.shape)==3 and I.shape[-1]==3:
        I = np.array(I, dtype=np.float32)
        I += config[1]
        if restore_rgb:
            I = I[:, :, config[0]]
        I = image_normalization(I)
    else:
        print("Sorry the input data size is out of our configuration")
    # print("The enterely I data {} restored".format(I.shape))
    return I

def visualize_result(imgs_list, arg):
    """
    data 2 image in one matrix
    :param imgs_list: a list of prediction, gt and input data
    :param arg:
    :return: one image with the whole of imgs_list data
    """
    n_imgs = len(imgs_list)
    data_list =[]
    for i in range(n_imgs):
        tmp = imgs_list[i]
        if tmp.shape[1]==3:
            tmp = np.transpose(np.squeeze(tmp[1]),[1,2,0])
            tmp=restore_rgb([arg.channel_swap,arg.mean_pixel_values[:3]],tmp)
            tmp = np.uint8(image_normalization(tmp))
        else:
            tmp= np.squeeze(tmp[1])
            if len(tmp.shape) == 2:
                tmp = np.uint8(image_normalization(tmp))
                tmp = cv.bitwise_not(tmp)
                tmp = cv.cvtColor(tmp, cv.COLOR_GRAY2BGR)
            else:
                tmp = np.uint8(image_normalization(tmp))
        data_list.append(tmp)
    img = data_list[0]
    if n_imgs % 2 == 0:
        imgs = np.zeros((img.shape[0] * 2 + 10, img.shape[1] * (n_imgs // 2) + ((n_imgs // 2 - 1) * 5), 3))
    else:
        imgs = np.zeros((img.shape[0] * 2 + 10, img.shape[1] * ((1 + n_imgs) // 2) + ((n_imgs // 2) * 5), 3))
        n_imgs += 1

    k=0
    imgs = np.uint8(imgs)
    i_step = img.shape[0]+10
    j_step = img.shape[1]+5
    for i in range(2):
        for j in range(n_imgs//2):
            if k<len(data_list):
                imgs[i*i_step:i*i_step+img.shape[0],j*j_step:j*j_step+img.shape[1],:]=data_list[k]
                k+=1
            else:
                pass
    return imgs

def create_directory(dir_path):
    """Creates an empty directory.
    Args:
        dir_path (str): the absolute path to the directory to create.
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

def train(epoch, dataloader, model, criterion, optimizer, device,
          log_interval_vis, tb_writer, args=None):
    imgs_res_folder =os.path.join(args.output_dir,'current_res')
    create_directory(imgs_res_folder)
    model.train()
    for batch_id, sample_batched in enumerate(dataloader):
        images = sample_batched['images'].to(device)  # BxCxHxW
        labels = sample_batched['labels'].to(device)  # BxHxW
        # labels = labels[:, None]  # Bx1xHxW

        preds_list = model(images)
        loss = sum([criterion(preds, labels) for preds in preds_list])
        loss /= images.shape[0]  # the batch size

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_id%5==0:
            print(time.ctime(),'Epoch: {0} Sample {1}/{2} Loss: {3}' \
              .format(epoch, batch_id, len(dataloader), loss.item()))

        if tb_writer is not None:
            tb_writer.add_scalar('data/loss', loss.detach(), (len(dataloader)*epoch+batch_id))

        if batch_id % log_interval_vis == 0:
            res_data = []
            img = images.cpu().numpy()
            res_data.append(img)
            ed_gt = labels.cpu().numpy()
            res_data.append(ed_gt)
            for i in range(len(preds_list)):
                tmp = preds_list[i]
                tmp = torch.sigmoid(tmp)
                tmp = tmp.cpu().detach().numpy()
                res_data.append(tmp)
            vis_imgs = visualize_result(res_data, arg=args)
            del tmp, res_data
            vis_imgs = cv.resize(vis_imgs,(int(vis_imgs.shape[1]*0.8),int(vis_imgs.shape[0]*0.8)))
            img_test = 'Epoch: {0} Sample {1}/{2} Loss: {3}' \
              .format(epoch, batch_id, len(dataloader), loss.item())
            BLACK = (0, 0, 255)
            font = cv.FONT_HERSHEY_SIMPLEX
            font_size = 1.1
            font_color = BLACK
            font_thickness = 2
            x, y = 30, 30
            vis_imgs = cv.putText(vis_imgs, img_test, (x, y), font, font_size, font_color, font_thickness, cv.LINE_AA)
            cv.imwrite(os.path.join(imgs_res_folder,'results.png'),vis_imgs)


def save_image_batch_to_disk(tensor, output_dir, file_names, img_shape=None,arg=None):

    os.makedirs(output_dir,exist_ok=True)
    if not arg.is_testing:
        assert len(tensor.shape) == 4, tensor.shape
        for tensor_image, file_name in zip(tensor, file_names):
            image_vis = tgm.utils.tensor_to_image(torch.sigmoid(tensor_image))[..., 0]
            image_vis = (255.0*(1.0- image_vis)).astype(np.uint8) #
            output_file_name = os.path.join(output_dir, file_name)
            assert cv.imwrite(output_file_name, image_vis)
    else:
        output_dir_f = os.path.join(output_dir,'f')
        output_dir_a = os.path.join(output_dir,'a')
        os.makedirs(output_dir_f, exist_ok=True)
        os.makedirs(output_dir_a,exist_ok=True)
        # 255.0 * (1.0 - em_a)
        edge_maps = []
        for i in tensor:
            tmp = torch.sigmoid(i).cpu().detach().numpy()
            edge_maps.append(tmp)
        # edge_maps.append(tmp)
        tensor = np.array(edge_maps)
        idx =0
        image_shape = [x.cpu().detach().numpy() for x in img_shape]
        image_shape = [[y, x] for x, y in zip(image_shape[0], image_shape[1])]
        for i_shape, file_name in zip(image_shape,file_names):
            tmp = tensor[:,idx,...]
            tmp = np.transpose(np.squeeze(tmp),[0,1,2])
            preds = []
            for i in range(tmp.shape[0]):
                tmp_img = tmp[i]
                tmp_img[tmp_img<0.0] = 0.0
                tmp_img =255.0 * (1.0 - tmp_img)
                if not tmp_img.shape[1]==i_shape[0] or not tmp_img.shape[0]==i_shape[1]:
                    tmp_img = cv.resize(tmp_img,(i_shape[0],i_shape[1]))
                preds.append(tmp_img)
                if i==6:
                    fuse = tmp_img
            average = np.array(preds,dtype=np.float32)
            average = np.uint8(np.mean(average,axis=0))
            output_file_name_f = os.path.join(output_dir_f, file_name)
            output_file_name_a = os.path.join(output_dir_a, file_name)
            assert cv.imwrite(output_file_name_f, fuse)
            assert cv.imwrite(output_file_name_a, np.uint8(average))
            idx+=1
        

def validation(epoch, dataloader, model, device, output_dir, arg=None):
    model.eval()
    total_losses = []

    for batch_id, sample_batched in enumerate(dataloader):
        images = sample_batched['images'].to(device)
        labels = sample_batched['labels'].to(device)
        file_names = sample_batched['file_names']
        output = model(images)
        save_image_batch_to_disk(output[-1], output_dir, file_names, arg=arg)


def weight_init(m):
    if isinstance(m, (nn.Conv2d, )):

        torch.nn.init.normal_(m.weight,mean=0, std=0.01)
        if m.weight.data.shape[1]==torch.Size([1]):
            torch.nn.init.normal_(m.weight, mean=0.0,)
        if m.weight.data.shape==torch.Size([1,6,1,1]):
            torch.nn.init.constant_(m.weight,0.2)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    # for fusion layer
    if isinstance(m, (nn.ConvTranspose2d,)):

        torch.nn.init.normal_(m.weight,mean=0, std=0.01)
        if m.weight.data.shape[1] == torch.Size([1]):
            torch.nn.init.normal_(m.weight, std=0.1)

        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

def main():
    # Testing settings
    DATASET_NAME= ['BIPED','BSDS','BSDS300','CID','DCD','MULTICUE',
                    'PASCAL','NYUD','CLASSIC'] # 8
    TEST_DATA = DATASET_NAME[8]
    data_inf = dataset_info(TEST_DATA)
    # training settings
    parser = argparse.ArgumentParser(description='Training application.')
    # Data parameters
    parser.add_argument('--input-dir', type=str,default='/opt/dataset/BIPED/edges',
                        help='the path to the directory with the input data.')
    parser.add_argument('--input-val-dir', type=str,default=data_inf['data_dir'],
                        help='the path to the directory with the input data for validation.')
    parser.add_argument('--output_dir', type=str, default='checkpoints',
                        help='the path to output the results.')
    parser.add_argument('--test_data', type=str, default=TEST_DATA,
                        help='Name of the dataset.')
    parser.add_argument('--test_list', type=str, default=data_inf['file_name'],
                        help='Name of the dataset.')
    parser.add_argument('--is_testing', type=bool, default=True,
                        help='Just for testing')
    parser.add_argument('--use_prev_trained', type=bool, default=True,
                        help='use previous trained data') # Just for test
    parser.add_argument('--checkpoint_data', type=str, default='24/24_model.pth',
                        help='Just for testing') #  '19/19_*.pht'
    parser.add_argument('--test_im_width', type=int, default=data_inf['img_width'],
                        help='image height for testing')
    parser.add_argument('--test_im_height', type=int, default=data_inf['img_height'],
                        help=' image height for testing')
    parser.add_argument('--res_dir', type=str, default='result',
                        help='Result directory')
    parser.add_argument('--log-interval-vis', type=int, default=50,
                        help='how many batches to wait before logging training status')
    # Optimization parameters
    parser.add_argument('--optimizer', type=str, choices=['adam', 'sgd'], default='adam',
                        help='the optimization solver to use (default: adam)')
    parser.add_argument('--num-epochs', type=int, default=25, metavar='N',
                        help='number of training epochs (default: 100)')
    # parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
    #                     help='learning rate (default: 1e-3)')
    parser.add_argument('--wd', type=float, default=1e-5, metavar='WD',
                        help='weight decay (default: 1e-5)')
    parser.add_argument('--lr', default=1e-4, type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_stepsize', default=1e4, type=int,
                        help='Learning rate step size.')
    parser.add_argument('--batch-size', type=int, default=8, metavar='B',
                        help='the mini-batch size (default: 2)')
    parser.add_argument('--num-workers', default=8, type=int,
                        help='the number of workers for the dataloader.')
    parser.add_argument('--tensorboard', action='store_true', default=True,
                        help='use tensorboard for logging purposes'),
    parser.add_argument('--gpu', type=str, default='1',
                        help='select GPU'),
    parser.add_argument('--img_width', type = int, default = 400, help='image size for training')
    parser.add_argument('--img_height', type = int, default = 400, help='image size for training')
    parser.add_argument('--channel_swap', default=[2, 1, 0], type=int)
    parser.add_argument('--crop_img', default=False, type=bool,
                        help='If true crop training images, other ways resizing')
    parser.add_argument('--mean_pixel_values', default=[104.00699, 116.66877, 122.67892, 137.86],
                        type=float)  # [103.939,116.779,123.68] [104.00699, 116.66877, 122.67892]
    args = parser.parse_args()

    tb_writer = None
    if args.tensorboard and not args.is_testing:
        from tensorboardX import SummaryWriter # previous torch version
        # from torch.utils.tensorboard import SummaryWriter # for torch 1.4 or greather
        tb_writer = SummaryWriter(log_dir=args.output_dir)
    print(" **** You have available ", torch.cuda.device_count(), "GPUs!")
    print("Pytorch version: ", torch.__version__)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    device = torch.device('cpu' if torch.cuda.device_count() == 0 else 'cuda')
    model = DexiNet().to(device)
    # model = nn.DataParallel(model)
    model.apply(weight_init)

    if not args.is_testing:

        dataset_train = BipedMyDataset(args.input_dir, train_mode='train',
                                      arg=args)

        dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size,
                                      shuffle=True, num_workers=args.num_workers)
    dataset_val = testDataset(args.input_val_dir, arg=args)
    dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size,
                                shuffle=False, num_workers=args.num_workers)
    # for testing
    if args.is_testing:
        model.load_state_dict(torch.load(os.path.join(args.output_dir,args.checkpoint_data), map_location=device))

        model.eval()

        output_dir = os.path.join(args.res_dir, "BIPED2" + args.test_data)
        with torch.no_grad():
            for batch_id, sample_batched in enumerate(dataloader_val):
                images = sample_batched['images'].to(device)
                if not args.test_data == "CLASSIC":
                    labels = sample_batched['labels'].to(device)
                file_names = sample_batched['file_names']
                image_shape = sample_batched['image_shape']
                print("input image size: ",images.shape)
                output = model(images)
                save_image_batch_to_disk(output, output_dir, file_names,image_shape, arg=args)

        print("Testing ended in ",args.test_data, "dataset")
        sys.exit()

    criterion = weighted_cross_entropy_loss
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # Learning rate scheduler.
    # lr_schd = lr_scheduler.StepLR(optimizer, step_size=args.lr_stepsize,
    #                               gamma=args.lr_gamma)

    for epoch in range(args.num_epochs):
        # Create output directory
        output_dir_epoch = os.path.join(args.output_dir, str(epoch))
        img_test_dir = os.path.join(output_dir_epoch,args.test_data+'_res')
        create_directory(output_dir_epoch)
        create_directory(img_test_dir)
        # with torch.no_grad():
        #     validation(epoch, dataloader_val, model, device, img_test_dir,arg=args)
        train(epoch, dataloader_train, model, criterion, optimizer, device,
              args.log_interval_vis, tb_writer, args=args)

        # lr_schd.step() # decay lr at the end of the epoch.
    
        with torch.no_grad():
            validation(epoch, dataloader_val, model, device, img_test_dir,arg=args)

        try:
            net_state_dict = model.module.state_dict()
        except:
            net_state_dict = model.state_dict()

        torch.save(net_state_dict, os.path.join(
                   output_dir_epoch, '{0}_model.pth'.format(epoch)))

        
if __name__ == '__main__':
    main()