"""
This file is used to generate TFRecords from the AI Challenger dataset.

@ Author: Yu Sun. vxallset@outlook.com

@ Date created: Jun 04, 2019

@ Last modified: Jun 27, 2019
"""

import numpy as np
import os
import time
from skimage import io, draw
from skimage.transform import resize
from random import shuffle
import tensorflow as tf
import json

# The images which contain a person are collected from the AI Challenger dataset in the following steps:
#     1. Get the coordinates of the bounding box in the original image.
#     2. Adjust the ratio of the bounding box to be 4:3 (height : width)
#
#     Note that the coordinates of the keypoints are also re-calculated when the foreground parts are clipped
#     from the original images.


def draw_points_on_img(img, point_ver, point_hor, point_class):
    """
    Blank out a disc around every keypoint whose class is not 3.

    :param img: the image array, indexed as [row, column, channel]. It is modified in place.
    :param point_ver: the vertical (row) coordinates of the keypoints.
    :param point_hor: the horizontal (column) coordinates of the keypoints.
    :param point_class: the class of each keypoint; keypoints of class 3 are skipped.
    :return: the same image object, with the pixels of every drawn disc set to 0.
    """
    for idx, one_class in enumerate(point_class):
        if one_class == 3:
            continue

        # radius 10; (256, 192) is handed to draw.circle as its shape argument
        rows, cols = draw.circle(point_ver[idx], point_hor[idx], 10, (256, 192))
        img[rows, cols, :] = 0

    return img

def draw_lines_on_img(img, point_ver, point_hor, point_class):
    """
    Draw the skeleton segments between keypoints on an image.

    A segment is drawn with colour [255, 0, 0] only when both of its end points are present in the input
    and neither of them has class 3.

    :param img: the image array to draw on. It is modified in place.
    :param point_ver: the vertical (row) coordinates of the keypoints.
    :param point_hor: the horizontal (column) coordinates of the keypoints.
    :param point_class: the class of each keypoint.
    :return: the same image object with the segments drawn on it.
    """
    # each entry is [start keypoint id, end keypoint id]
    line_list = [[0, 1], [1, 2], [3, 4], [4, 5], [6, 7], [7, 8], [9, 10],
               [10, 11], [12, 13], [13, 6], [13, 9], [13, 0], [13, 3]]

    point_number = len(point_class)

    # key point class: 1:visible, 2: not visible, 3: not marked
    for start_point_id, end_point_id in line_list:
        # ignore the segments whose end points are not part of the input
        if start_point_id >= point_number or end_point_id >= point_number:
            continue
        # a segment needs both of its end points to be marked
        if point_class[start_point_id] == 3 or point_class[end_point_id] == 3:
            continue

        rr, cc = draw.line(int(point_ver[start_point_id]), int(point_hor[start_point_id]),
                           int(point_ver[end_point_id]), int(point_hor[end_point_id]))
        draw.set_color(img, [rr, cc], [255, 0, 0])

    return img

def _int64_feature(value):
    """
    Wrap a single value into a tf.train.Feature that holds an Int64List.

    :param value: the value to store; it becomes the only element of the list.
    :return: the tf.train.Feature.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def _bytes_feature(value):
    """
    Wrap a single value into a tf.train.Feature that holds a BytesList.

    :param value: the value to store; it becomes the only element of the list.
    :return: the tf.train.Feature.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
    """
    Wrap a single value into a tf.train.Feature that holds a FloatList.

    :param value: the value to store; it becomes the only element of the list.
    :return: the tf.train.Feature.
    """
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def extract_people_from_dataset(dataset_root_path='../../../dataset/ai_challenger/', image_save_path='../dataset/imgs/',
                                tfrecords_path='../dataset/', is_shuffle=True):
    """
    This function is used to extract people from the AI Challenger dataset. Each extracted image contains only one
    person, is resized to 256 x 192 (height x width) and is saved as a single .jpg file. The image and the
    corresponding re-scaled annotation are also saved into the file 'train.tfrecords' inside tfrecords_path.

    An existing 'train.tfrecords' file and every file that is already inside image_save_path are deleted first.

    :param dataset_root_path: the root path of the AI Challenger dataset.
    :param image_save_path: the path used to save the clipped images.
    :param tfrecords_path: the path used to save the .tfrecords file.
    :param is_shuffle: is shuffle. NOTE(review): this flag is not used by the function at the moment.
    :return: None.
    """
    annotation_file = os.path.join(dataset_root_path, 'keypoint_train_annotations_20170909.json')
    image_read_path = os.path.join(dataset_root_path, 'train_images')
    tfrecords_file = os.path.join(tfrecords_path, 'train.tfrecords')

    # start from a clean state: no stale record file and no stale clipped images
    if not os.path.exists(tfrecords_path):
        os.makedirs(tfrecords_path)
    if os.path.exists(tfrecords_file):
        os.remove(tfrecords_file)
    if os.path.exists(image_save_path):
        for onefile in os.listdir(image_save_path):
            os.remove(os.path.join(image_save_path, onefile))
    else:
        # io.imsave can not write into a directory that does not exist
        os.makedirs(image_save_path)

    # the annotation file is only needed while it is parsed, so close it before the long loop starts
    with open(annotation_file, 'r') as jsfile:
        data = json.load(jsfile)

    saved_number = 0
    image_number = 0
    start_time = time.time()
    with tf.python_io.TFRecordWriter(tfrecords_file) as tfwriter:
        for one_item in data:
            img_id = one_item['image_id']
            image_number += 1
            if image_number % 100 == 0:
                print('Processed {} images, extracted {} people from the dataset. '
                      'time = {}'.format(image_number, saved_number, time.time() - start_time))

            kps = one_item['keypoint_annotations']
            boxes = one_item['human_annotations']

            # read image
            img_filename = os.path.join(image_read_path, img_id + '.jpg')
            img = io.imread(img_filename)

            for i in range(len(boxes)):
                # construct the name of a human in the dictionary,
                # for example, the first one (when i = 0) is 'human1'
                human_name = 'human' + str(i+1)

                kp = kps[human_name]
                box = boxes[human_name]
                p1_hor, p1_ver, p2_hor, p2_ver = box
                foreground = img[p1_ver:p2_ver, p1_hor:p2_hor, :]

                try:
                    foreground = resize(foreground, (256, 192, 3))
                except ValueError:
                    print('ValueError at image {} and {}'.format(image_number, human_name))
                    # this person can not be clipped, so there is nothing to save for it
                    continue

                # the resized image is scaled by 255 before it is converted to uint8
                foreground = foreground * 255.0
                foreground_uint8 = np.uint8(foreground)

                # the annotation is a flat list of (horizontal, vertical, class) triples; the coordinates
                # are mapped from the original image into the 256 x 192 clipped image
                kp_hor = (np.array(kp[0::3]) - p1_hor) / (p2_hor - p1_hor) * 192
                kp_ver = (np.array(kp[1::3]) - p1_ver) / (p2_ver - p1_ver) * 256
                kp_class = np.array(kp[2::3])

                # keypoints outside of the bounding box would wrap around when they are cast to uint8,
                # so they are moved to the border of the clipped image instead
                kp_hor = np.clip(kp_hor, 0, 191)
                kp_ver = np.clip(kp_ver, 0, 255)

                img_name = img_id + '_' + human_name + '.jpg'

                io.imsave(os.path.join(image_save_path, img_name), foreground_uint8)

                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={
                            'image_name': _bytes_feature(img_name.encode()),
                            'image_raw': _bytes_feature(foreground_uint8.tobytes()),
                            'keypoints_ver': _bytes_feature(np.uint8(kp_ver).tobytes()),
                            'keypoints_hor': _bytes_feature(np.uint8(kp_hor).tobytes()),
                            'keypoints_class': _bytes_feature(np.uint8(kp_class).tobytes())
                        }))
                tfwriter.write(example.SerializeToString())

                saved_number += 1
    print('Extracted {} people from the dataset in total.'.format(saved_number))

def decode_proto(proto):
    """
    Decode one serialized example into the image and its keypoint annotation.

    :param proto: a serialized example that contains the string features 'image_name', 'image_raw',
                  'keypoints_ver', 'keypoints_hor' and 'keypoints_class'.
    :return: image_name, image (reshaped to [256, 192, 3]), keypoints_ver, keypoints_hor and keypoints_class.
             Everything except image_name is decoded as uint8.
    """
    features = tf.parse_single_example(proto,
                                       features={
                                           'image_name': tf.FixedLenFeature([], tf.string),
                                           'image_raw': tf.FixedLenFeature([], tf.string),
                                           'keypoints_ver': tf.FixedLenFeature([], tf.string),
                                           'keypoints_hor': tf.FixedLenFeature([], tf.string),
                                           'keypoints_class': tf.FixedLenFeature([], tf.string),
                                       })
    image_name = features['image_name']

    # the raw bytes are decoded as uint8 and reshaped to height x width x channels
    image_raw = tf.decode_raw(features['image_raw'], out_type=tf.uint8)
    image = tf.reshape(image_raw, [256, 192, 3])

    keypoints_ver = tf.decode_raw(features['keypoints_ver'], out_type=tf.uint8)
    keypoints_hor = tf.decode_raw(features['keypoints_hor'], out_type=tf.uint8)
    keypoints_class = tf.decode_raw(features['keypoints_class'], out_type=tf.uint8)
    return image_name, image, keypoints_ver, keypoints_hor, keypoints_class

def decode_tfrecord(filename_queue):
    """
    Read one record from a filename queue and decode it with decode_proto.

    :param filename_queue: the queue which is handed to tf.TFRecordReader.read.
    :return: image_name, image, keypoints_ver, keypoints_hor and keypoints_class, as returned by decode_proto.
    """
    record_reader = tf.TFRecordReader()
    # read() returns a (key, value) pair; only the serialized value is needed
    _key, serialized_example = record_reader.read(filename_queue)
    return decode_proto(serialized_example)

def input_batch(datasetname, batch_size, num_epochs):
    """
    This function is used to decode the TFrecord and return a batch of images as well as their information.

    :param datasetname: the name of the TFrecord file.
    :param batch_size: the number of images in a batch
    :param num_epochs: the number of epochs
    :return: the next element of a one-shot iterator, i.e. a batch of images as well as their information
             (see decode_proto for the order of the elements)
    """
    with tf.name_scope('input_batch'):
        mydataset = tf.data.TFRecordDataset(datasetname)
        mydataset = mydataset.map(decode_proto)

        # Shuffling is disabled at the moment. The shuffle transformation uses a finite-sized buffer to shuffle
        # elements in memory. The parameter is the number of elements in the buffer. For completely uniform
        # shuffling, set the parameter to be the same as the number of elements in the dataset.
        # The original author noted that the number of elements of the dataset could not be used as the parameter.

        # NOTE(review): the dataset is repeated num_epochs * 2 times, not num_epochs times - confirm the intent.
        mydataset = mydataset.repeat(num_epochs * 2)
        # drop all the data that can't be used to make up a batch
        mydataset = mydataset.batch(batch_size, drop_remainder=True)
        iterator = mydataset.make_one_shot_iterator()

        nextelement = iterator.get_next()
        return nextelement

def mytest():
    """
    Read the first 10 records of '../dataset/train.tfrecords', print their annotation and blank out a disc
    around every keypoint whose class is greater than 0.

    :return: None.
    """
    tfrecord_file = '../dataset/train.tfrecords'

    filename_queue = tf.train.string_input_producer([tfrecord_file], num_epochs=None)
    image_name, image, keypoints_ver, keypoints_hor, keypoints_class = decode_tfrecord(filename_queue)

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            # while not coord.should_stop():
            for _ in range(10):
                img_name, img, point_ver, point_hor, point_class = sess.run([image_name, image, keypoints_ver,
                                                                             keypoints_hor, keypoints_class])

                print(img_name, point_hor, point_ver, point_class)

                for point_id, one_class in enumerate(point_class):
                    if one_class > 0:
                        rr, cc = draw.circle(point_ver[point_id], point_hor[point_id], 10, (256, 192))
                        img[rr, cc, :] = 0
                # NOTE(review): the annotated image is neither displayed nor saved here.

        except tf.errors.OutOfRangeError:
            print('Done reading')
        finally:
            # always stop the queue runner threads, otherwise they outlive the test
            coord.request_stop()
            coord.join(threads)

if __name__ == '__main__':