"""
This file generates TFRecords from the AI Challenger human keypoint dataset.

@ Author: Yu Sun. vxallset@outlook.com

@ Date created: Jun 04, 2019

@ Last modified: Jun 27, 2019

"""
import numpy as np
import os
import time
from skimage import io, draw
from skimage.transform import resize
from random import shuffle
import tensorflow as tf
import json

"""
# The image which contains a person is collected from the AI Challenger dataset in the following steps:
    1. Get the coordinate of the bounding box in the original image.
    2. Adjust the ratio of the bounding box to be 4:3 (height : width)

    Note that the coordinates of keypoints are also re-calculated when the foreground parts are clipped from the 
    original images.

"""


def draw_points_on_img(img, point_ver, point_hor, point_class):
    """Draw a black disk on the image at every marked keypoint.

    Keypoint classes: 1 = visible, 2 = not visible, 3 = not marked.
    """
    for i in range(len(point_class)):
        if point_class[i] != 3:
            # draw.circle(row, col, radius, shape) clips the disk to a 256 x 192 image.
            rr, cc = draw.circle(point_ver[i], point_hor[i], 10, (256, 192))
            img[rr, cc, :] = 0

    return img


def draw_lines_on_img(img, point_ver, point_hor, point_class):
    """Draw the skeleton by connecting pairs of marked keypoints with red lines."""
    # Each pair is an edge of the skeleton, given as keypoint indices.
    line_list = [[0, 1], [1, 2], [3, 4], [4, 5], [6, 7], [7, 8], [9, 10],
                 [10, 11], [12, 13], [13, 6], [13, 9], [13, 0], [13, 3]]

    # Keypoint classes: 1 = visible, 2 = not visible, 3 = not marked.
    for start_point_id, end_point_id in line_list:
        # Skip an edge if either endpoint is not marked.
        if point_class[start_point_id] == 3 or point_class[end_point_id] == 3:
            continue
        rr, cc = draw.line(int(point_ver[start_point_id]), int(point_hor[start_point_id]),
                           int(point_ver[end_point_id]), int(point_hor[end_point_id]))
        draw.set_color(img, (rr, cc), [255, 0, 0])

    return img


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
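

# Illustrative usage of the helpers above (comments only, not executed; the
# feature names and values below are made up): each helper wraps a single
# Python value in a tf.train.Feature so it can be packed into a tf.train.Example,
# as extract_people_from_dataset() does below.
#
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'height': _int64_feature(256),
#       'scale': _float_feature(1.0),
#       'name': _bytes_feature(b'human1'),
#   }))
#   serialized_bytes = example.SerializeToString()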


def extract_people_from_dataset(dataset_root_path='../../../dataset/ai_challenger/', image_save_path='../dataset/imgs/',
                                tfrecords_path='../dataset/', is_shuffle=True):
    """
    This function is used to extract people from the AI Challenger dataset. The extract image will contain only one
    person each and will be saved as a single .jpg file. At last, the image and the the corresponding annotation will
    be saved into a .tfrecord file.

    :param dataset_root_path: the root path of the AI Challenger dataset.
    :param image_save_path: the path used to save the clipped images.
    :param tfrecord_path: the path used to save the .tfrecords file.
    :param is_shuffle: is shuffle.
    :return: None.
    """
    annotation_file = os.path.join(dataset_root_path, 'keypoint_train_annotations_20170909.json')
    image_read_path = os.path.join(dataset_root_path, 'train_images')
    tfrecords_file = os.path.join(tfrecords_path, 'train.tfrecords')

    if not os.path.exists(tfrecords_path):
        os.mkdir(tfrecords_path)
    if os.path.exists(tfrecords_file):
        os.remove(tfrecords_file)
    if os.path.exists(image_save_path):
        useless = os.listdir(image_save_path)
        for onefile in useless:
            os.remove(os.path.join(image_save_path, onefile))
    else:
        os.mkdir(image_save_path)

    saved_number = 0
    image_number = 0
    start_time = time.time()
    with tf.python_io.TFRecordWriter(tfrecords_file) as tfwriter:

        with open(annotation_file, 'r') as jsfile:
            data = json.load(jsfile)

            if is_shuffle:
                # Shuffle the annotation entries so that the records are not
                # written in the dataset's original order.
                shuffle(data)

            for one_item in data:
                img_id = one_item['image_id']
                image_number += 1
                if image_number % 100 == 0:
                    print('Processed {} images, extracted {} people from the dataset. '
                          'time = {}'.format(image_number, saved_number, time.time() - start_time))

                kps = one_item['keypoint_annotations']
                boxes = one_item['human_annotations']

                # read image
                img_filename = os.path.join(image_read_path, img_id + '.jpg')
                img = io.imread(img_filename)

                for i in range(len(boxes)):
                    # construct the name of a human in the dictionary,
                    # for example, the first one (when i = 0) is 'human1'
                    human_name = 'human' + str(i+1)

                    kp = kps[human_name]
                    box = boxes[human_name]
                    # The box is stored as [x1, y1, x2, y2] (top-left and
                    # bottom-right corners); crop the person from the image.
                    p1_hor, p1_ver, p2_hor, p2_ver = box
                    foreground = img[p1_ver:p2_ver, p1_hor:p2_hor, :]

                    try:
                        foreground = resize(foreground, (256, 192, 3))
                    except ValueError:
                        print('ValueError at image {} and {}'.format(image_number, human_name))
                        continue

                    # resize() returns floats in [0, 1]; scale back to uint8 pixel values.
                    foreground = foreground * 255.0
                    foreground_uint8 = np.uint8(foreground)

                    # The keypoints come as a flat [x1, y1, v1, x2, y2, v2, ...] list;
                    # map them from original-image coordinates into the 256 x 192 crop.
                    kp_hor = (np.array(kp[0::3]) - p1_hor) / (p2_hor - p1_hor) * 192
                    kp_ver = (np.array(kp[1::3]) - p1_ver) / (p2_ver - p1_ver) * 256
                    kp_class = np.array(kp[2::3])

                    img_name = img_id + '_' + human_name + '.jpg'

                    io.imsave(os.path.join(image_save_path, img_name), foreground_uint8)

                    example = tf.train.Example(
                        features=tf.train.Features(
                            feature={
                                'image_name': _bytes_feature(img_name.encode()),
                                'image_raw': _bytes_feature(foreground_uint8.tobytes()),
                                # Clip before the uint8 cast: a keypoint on the bottom or right
                                # edge maps to exactly 256 or 192 and would otherwise wrap around.
                                'keypoints_ver': _bytes_feature(np.uint8(np.clip(kp_ver, 0, 255)).tobytes()),
                                'keypoints_hor': _bytes_feature(np.uint8(np.clip(kp_hor, 0, 255)).tobytes()),
                                'keypoints_class': _bytes_feature(np.uint8(kp_class).tobytes())
                            }))
                    tfwriter.write(example.SerializeToString())

                    saved_number += 1
    print('Extracted {} people from the dataset in total.'.format(saved_number))
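

# A minimal sanity check, assuming the same TF 1.x API used above; the helper
# name count_records is ours, not part of the original pipeline. It iterates
# over the finished file and counts the serialized examples.
def count_records(tfrecords_file='../dataset/train.tfrecords'):
    """Return the number of serialized examples stored in a .tfrecords file."""
    return sum(1 for _ in tf.python_io.tf_record_iterator(tfrecords_file))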


def decode_proto(proto):
    features = tf.parse_single_example(proto,
                                       features={
                                           'image_name': tf.FixedLenFeature([], tf.string),
                                           'image_raw': tf.FixedLenFeature([], tf.string),
                                           'keypoints_ver': tf.FixedLenFeature([], tf.string),
                                           'keypoints_hor': tf.FixedLenFeature([], tf.string),
                                           'keypoints_class': tf.FixedLenFeature([], tf.string),
                                       })
    image_name = features['image_name']

    image_raw = tf.decode_raw(features['image_raw'], out_type=tf.uint8)
    image = tf.reshape(image_raw, [256, 192, 3])

    keypoints_ver = tf.decode_raw(features['keypoints_ver'], out_type=tf.uint8)
    keypoints_hor = tf.decode_raw(features['keypoints_hor'], out_type=tf.uint8)
    keypoints_class = tf.decode_raw(features['keypoints_class'], out_type=tf.uint8)
    return image_name, image, keypoints_ver, keypoints_hor, keypoints_class


def decode_tfrecord(filename_queue):
    tfreader = tf.TFRecordReader()
    _, proto = tfreader.read(filename_queue)
    image_name, image, keypoints_ver, keypoints_hor, keypoints_class = decode_proto(proto)

    return image_name, image, keypoints_ver, keypoints_hor, keypoints_class


def input_batch(datasetname, batch_size, num_epochs):
    """
    This function is used to decode the TFrecord and return a batch of images as well as their information
    :param datasetname: the name of the TFrecord file.
    :param batch_size: the number of images in a batch
    :param num_epochs: the number of epochs
    :return: a batch of images as well as their information
    """
    with tf.name_scope('input_batch'):
        # The shuffle transformation uses a finite-sized buffer to shuffle elements
        # in memory. The parameter is the number of elements in the buffer. For
        # completely uniform shuffling, set the parameter to be the same as the
        # number of elements in the dataset.
        mydataset = tf.data.TFRecordDataset(datasetname)
        mydataset = mydataset.map(decode_proto)

        # Shuffling with a buffer as large as the whole dataset did not work here
        # (most likely a memory limit), so shuffling is disabled for now:
        # mydataset = mydataset.shuffle(200)
        mydataset = mydataset.repeat(num_epochs * 2)
        # Drop the final partial batch so every batch has exactly batch_size samples.
        mydataset = mydataset.batch(batch_size, drop_remainder=True)
        iterator = mydataset.make_one_shot_iterator()

        nextelement = iterator.get_next()
        return nextelement
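

# A minimal usage sketch for input_batch() (comments only; the file path and
# sizes are example values): each sess.run() call yields one batch of decoded,
# fixed-shape tensors.
#
#   next_batch = input_batch('../dataset/train.tfrecords', batch_size=8, num_epochs=1)
#   with tf.Session() as sess:
#       names, imgs, kps_ver, kps_hor, kps_class = sess.run(next_batch)
#       print(imgs.shape)  # (8, 256, 192, 3), dtype uint8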


def mytest():
    tfrecord_file = '../dataset/train.tfrecords'

    filename_queue = tf.train.string_input_producer([tfrecord_file], num_epochs=None)
    image_name, image, keypoints_ver, keypoints_hor, keypoints_class = decode_tfrecord(filename_queue)

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            # while not coord.should_stop():
            for i in range(10):
                img_name, img, point_ver, point_hor, point_class = sess.run([image_name, image, keypoints_ver,
                                                                             keypoints_hor, keypoints_class])

                print(img_name, point_hor, point_ver, point_class)

                # Use a separate loop variable so the batch counter i is not shadowed.
                for j in range(len(point_class)):
                    if point_class[j] > 0:
                        rr, cc = draw.circle(point_ver[j], point_hor[j], 10, (256, 192))
                        img[rr, cc, :] = 0

                io.imshow(img)
                io.show()

        except tf.errors.OutOfRangeError:
            print('Done reading')
        finally:
            coord.request_stop()
            # Wait for the reader threads to finish before closing the session.
            coord.join(threads)


if __name__ == '__main__':
    extract_people_from_dataset()
    #mytest()