"""Converts a dataset from per-trajectory pkl/jpg files into consolidated hdf5 files."""
import h5py
import cv2
import numpy as np
import copy
import os
import sys
import glob
import imageio
from multiprocessing import Pool, Manager
from visual_mpc.utils.sync import ManagedSyncCounter
import random
import functools
from tqdm import tqdm

# pickle lives in a different module on Python 2; import it at module level so
# worker functions can use it no matter how the script is entered
if sys.version_info[0] == 2:
    import cPickle as pkl
else:
    import pickle as pkl


MANDATORY_KEYS = ['camera_configuration', 'policy_desc', 'bin_type', 'bin_insert', 'contains_annotation',
                  'robot', 'gripper', 'background', 'action_space', 'object_classes', 'primitives', 'camera_type']


def serialize_image(img):
    assert img.dtype == np.uint8, "Must be uint8!"
    # cv2.imencode returns (success, buffer); keep only the encoded JPEG bytes
    return cv2.imencode('.jpg', img)[1]


def serialize_video(imgs, temp_name_append):
    mp4_name = './temp{}.mp4'.format(temp_name_append)
    try:
        assert imgs.dtype == np.uint8, "Must be uint8 array!"
        assert not os.path.exists(mp4_name), "file {} exists!".format(mp4_name)
        # writing to a real .mp4 file (rather than an in-memory buffer) is a hack to
        # ensure imageio successfully saves as mp4 instead of getting the encoding confused
        writer = imageio.get_writer(mp4_name)
        for i in imgs:
            writer.append_data(i[:, :, ::-1])    # frames come from cv2 in BGR; imageio expects RGB
        writer.close()

        with open(mp4_name, 'rb') as f:
            buf = f.read()
    finally:
        # always clean up the temp file, even if encoding failed
        if os.path.exists(mp4_name):
            os.remove(mp4_name)

    return np.frombuffer(buf, dtype=np.uint8)


def save_dict(data_container, dict_group, video_encoding, t_index):
    for k, d in data_container.items():
        if 'images' == k:
            # image tensors are (T, n_cams, H, W, 3); each camera gets its own group
            T, n_cams = d.shape[:2]
            dict_group.attrs['n_cams'] = n_cams
            dict_group.attrs['cam_encoding'] = video_encoding

            for n in range(n_cams):
                cam_group = dict_group.create_group("cam{}_video".format(n))
                if video_encoding == 'mp4':
                    data = cam_group.create_dataset("frames", data=serialize_video(d[:, n], t_index))
                    data.attrs['shape'] = d[0, n].shape
                    data.attrs['T'] = d.shape[0]
                    data.attrs['image_format'] = 'RGB'
                elif video_encoding == 'jpeg':
                    for t in range(T):
                        data = cam_group.create_dataset("frame{}".format(t), data=serialize_image(d[t, n]))
                        data.attrs['shape'] = d[t, n].shape
                        data.attrs['image_format'] = 'RGB'
                else:
                    raise ValueError("unsupported video encoding: {}".format(video_encoding))
        elif 'image' in k:
            # single images (e.g. a goal image) are stored as individual JPEGs
            data = dict_group.create_dataset(k, data=serialize_image(d))
            data.attrs['shape'] = d.shape
        else:
            dict_group.create_dataset(k, data=d)


def save_hdf5(filename, env_obs, policy_out, agent_data, meta_data, video_encoding='mp4', t_index=None):
    if t_index is None:
        t_index = random.randint(0, 9999999)
    # meta-data includes calibration "number", policy "type" descriptor, environment bounds
    with h5py.File(filename, 'w') as f:
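        # file layout: 'env', 'misc', and 'policy' groups plus a 'metadata' group of attributes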
        f.create_dataset('file_version', data='0.1.0')
        for data_container, name in zip([env_obs, agent_data], ['env', 'misc']):
            save_dict(data_container, f.create_group(name), video_encoding, t_index)

        # policy outputs are recorded per time-step; stack each key into a single (T, ...) array
        policy_dict = {}
        first_keys = list(policy_out[0].keys())
        for k in first_keys:
            assert all([k in p for p in policy_out[1:]]), "hdf5 format requires that keys be uniform across time!"
            policy_dict[k] = np.concatenate([p[k][None] for p in policy_out], axis=0)
        save_dict(policy_dict, f.create_group('policy'), video_encoding, t_index)

        meta_data_group = f.create_group('metadata')
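        # mandatory keys are popped so the catch-all loop below doesn't write them twice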
        for mandatory_key in MANDATORY_KEYS:
            meta_data_group.attrs[mandatory_key] = meta_data.pop(mandatory_key)
        
        for k in meta_data.keys():
            meta_data_group.attrs[k] = meta_data[k]
            

def save_worker(traj_data, cntr, group_name=''):
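    # NOTE: relies on the module-level `args` and `video_encoding` set in __main__
    # (worker processes inherit them via fork); t_index gives each worker a random,
    # very likely unique suffix for its temp mp4 file so parallel encodes don't collide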
    t_index = random.randint(0, 9999999)
    t, meta_data = traj_data

    try:
        env_obs = pkl.load(open('{}/obs_dict.pkl'.format(t), 'rb'), encoding='latin1')
        if meta_data['contains_annotation']:
            env_obs['bbox_annotations'] = pkl.load(open('{}/annotation_array.pkl'.format(t), 'rb'), encoding='latin1')
        n_cams = len(glob.glob('{}/images*'.format(t)))
        if n_cams:
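            # cap T at the shortest camera stream so every step has a frame from each camera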
            T = min([len(glob.glob('{}/images{}/*.jpg'.format(t, i))) for i in range(n_cams)])
            height, width = cv2.imread('{}/images0/im_0.jpg'.format(t)).shape[:2]
            env_obs['images'] = np.zeros((T, n_cams, height, width, 3), dtype=np.uint8)

            for n in range(n_cams):
                for time in range(T):
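                    # cv2.imread returns BGR; channel order is accounted for at serialization time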
                    env_obs['images'][time, n] = cv2.imread('{}/images{}/im_{}.jpg'.format(t, n, time))

        policy_out = pkl.load(open('{}/policy_out.pkl'.format(t), 'rb'), encoding='latin1')
        agent_data = pkl.load(open('{}/agent_data.pkl'.format(t), 'rb'), encoding='latin1')

        def store_in_metadata_if_exists(key):
            if key in agent_data:
                meta_data[key] = agent_data.pop(key)

        for k in ['goal_reached', 'term_t']:
            store_in_metadata_if_exists(k)

        # atomically fetch-and-increment the shared counter so each output file gets a unique index
        c = cntr.ret_increment
        save_hdf5('{}/{}traj{}.hdf5'.format(args.output_folder, group_name, c), env_obs, policy_out, agent_data, meta_data, video_encoding, t_index)
        return True
    except (FileNotFoundError, NotADirectoryError):
        # skip incomplete trajectory folders and stray files (e.g. the group's hparams.json)
        return False


if __name__ == '__main__':
    import argparse
    import json
    import shutil

    # raw_input was renamed to input in Python 3
    if sys.version_info[0] == 2:
        input_fn = raw_input
    else:
        input_fn = input

    parser = argparse.ArgumentParser(description="converts dataset from pkl format to hdf5")
    parser.add_argument('input_folder', type=str, help='where raw files are stored')
    parser.add_argument('output_folder', type=str, help='where to save')
    parser.add_argument('--output_group_name', type=str, default='', help='name to prepend in front of trajs')
    parser.add_argument('--video_jpeg_encoding', action='store_true', default=False, help='uses jpeg encoding for video frames instead of mp4')
    parser.add_argument('--counter', type=int, help='where to start counter', default=0)
    parser.add_argument('--n_workers', type=int, help='number of multi-threaded workers', default=1)
    args = parser.parse_args()

    assert args.n_workers >= 1, "can't have less than 1 worker thread!"
    args.input_folder, args.output_folder = [os.path.expanduser(x) for x in (args.input_folder, args.output_folder)]
    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)
    elif input_fn('path {} exists, should folder be deleted? (y/n): '.format(args.output_folder)) == 'y':
        shutil.rmtree(args.output_folder)
        os.makedirs(args.output_folder)   
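    # any answer other than 'y' keeps the existing folder and writes new files into it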
    
    if args.video_jpeg_encoding:
        video_encoding = 'jpeg'
    else:
        video_encoding = 'mp4'
        # serialize_video writes to ./temp*.mp4, so leftover files would collide
        if len(glob.glob('./temp*.mp4')) != 0:
            raise EnvironmentError("Please delete all temp*.mp4 files! (needed for saving)")
    
    # each subfolder of input_folder is a "traj group" whose trajectories share one hparams.json
    traj_groups = glob.glob(args.input_folder + "/*")
    print('found {} traj groups!'.format(len(traj_groups)))

    trajs, annotations_loaded = [], 0
    for group in traj_groups:
        with open('{}/hparams.json'.format(group), 'r') as f:
            meta_data_dict = json.load(f)
        # note: this glob also matches hparams.json itself; those entries are skipped
        # later when save_worker hits a NotADirectoryError
        group_trajs = glob.glob('{}/*'.format(group))
        for t in group_trajs:
            traj_meta_data = copy.deepcopy(meta_data_dict)
            traj_meta_data['object_batch'] = group
            if os.path.exists('{}/annotation_array.pkl'.format(t)):
                traj_meta_data['contains_annotation'] = True
                annotations_loaded += 1
            else:
                traj_meta_data['contains_annotation'] = False
            
            # object_classes may be stored as a single '+'-delimited string
            if isinstance(traj_meta_data['object_classes'], str):
                traj_meta_data['object_classes'] = traj_meta_data['object_classes'].split("+")
            
            assert all([k in traj_meta_data for k in MANDATORY_KEYS]), 'metadata for {} is missing keys!'.format(t)
            assert isinstance(traj_meta_data['object_classes'], list), "did not split object classes!"
            assert all([isinstance(x, str) for x in traj_meta_data['object_classes']]), 'object classes are not all strings!'

            trajs.append((t, traj_meta_data))
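    # shuffle so output indices are decorrelated from the on-disk group ordering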
    random.shuffle(trajs)
    
    print('Loaded {} trajectories with {} annotations!'.format(len(trajs), annotations_loaded))

    # the managed counter is shared across worker processes for unique output numbering
    cntr = ManagedSyncCounter(Manager(), args.counter)
    if args.n_workers == 1:
        saved = 0
        for t in tqdm(trajs):
            saved += save_worker(t, cntr, args.output_group_name)
        
        print('saved {} total trajs'.format(saved))
    else:
        map_fn = functools.partial(save_worker, cntr=cntr, group_name=args.output_group_name)
        p = Pool(args.n_workers)
        # save_worker returns True on success, so the sum counts saved trajectories
        print('saved {} total trajs'.format(sum(tqdm(p.imap_unordered(map_fn, trajs), total=len(trajs)))))
        p.close()
        p.join()