from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os
import glob
import re
import sys
import cv2
import h5py
import torch
import numpy as np
import argparse
import json

from threading import Thread, Lock
if sys.version_info[0] == 2:
	import Queue as queue
else:
	import queue


# Maps each dataset split name to the sub-folder(s) under the image root
# whose files make up that split.
folder_map = dict(
	train=["train2014"],
	val=["val2014"],
	trainval=["train2014", "val2014"],
	test=["test2015"],
)


# Matches the numeric image index that precedes the ".jpg" extension,
# e.g. "..._000000000009.jpg" -> "000000000009".
_IMAGE_INDEX_PATTERN = re.compile(r"_([0-9]+)\.jpg")


def _parse_image_index(img_path):
	"""Return the integer index embedded in the file name of `img_path`."""
	# Only look at the file name so that directory names cannot produce matches.
	matches = _IMAGE_INDEX_PATTERN.findall(os.path.basename(img_path))
	if len(matches) != 1:
		raise ValueError("Expected exactly one index in image path %r, found %d"
			% (img_path, len(matches)))
	return int(matches[0])


def save_images(image_path, image_type, data_path, data_name, num_workers):
	"""
	Resize every image of a split to 448x448 RGB and store them in one HDF5 file.

	Two files are written to `data_path`:
		"<data_name>_<image_type>.h5": dataset "images" of shape (N, 448, 448, 3), uint8.
		"<data_name>_<image_type>.pt": dict mapping image index (parsed from the
			file name) to its row in the "images" dataset, saved with torch.save.
	--------------------
	Arguments:
		image_path (str): root directory that contains the split sub-folders.
		image_type (str): "train", "val", "trainval", or "test".
		data_path (str): directory in which the output files are written.
		data_name (str): prefix of the output file names.
		num_workers (int): number of threads used to load images (at least 1).
	--------------------
	Raises:
		ValueError: if `num_workers` < 1, if a file name does not contain exactly
			one index, or if two files share the same index.
		RuntimeError: if one or more images could not be read or processed.
	"""
	if num_workers < 1:
		# Without any worker the queue would never drain and q.join() would hang.
		raise ValueError("num_workers must be at least 1, got %r" % (num_workers,))

	images_path = []
	for data in folder_map[image_type]:
		folder = os.path.join(image_path, data)
		images_path.extend(glob.glob(folder + "/*"))
	# glob returns files in arbitrary order; sort so the row layout is reproducible.
	images_path.sort()

	# Validate all file names before the output file is created (and truncated).
	images_idx = {}
	for i, img_path in enumerate(images_path):
		idx = _parse_image_index(img_path)
		if idx in images_idx:
			raise ValueError("Duplicated index %d found in image path %r" % (idx, img_path))
		images_idx[idx] = i

	q = queue.Queue()
	for item in enumerate(images_path):
		q.put(item)

	lock = Lock()
	failures = []
	total = len(images_path)

	with h5py.File(os.path.join(data_path, "%s_%s.h5" % (data_name, image_type)), "w") as dataset:
		images = dataset.create_dataset("images", (total, 448, 448, 3), dtype=np.uint8)

		def _worker():
			while True:
				i, img_path = q.get()
				if i is None:
					# Sentinel: no more work for this thread.
					q.task_done()
					break
				try:
					img = cv2.imread(img_path, cv2.IMREAD_COLOR)
					if img is None:
						# cv2.imread signals unreadable files by returning None.
						raise IOError("cv2.imread could not read the file")
					img = cv2.cvtColor(cv2.resize(img, (448, 448)), cv2.COLOR_BGR2RGB)

					# Writes to the HDF5 dataset are serialized through the lock.
					with lock:
						if i % 1000 == 0:
							print("processing %i/%i" % (i, total))
						images[i] = img
				except Exception as exc:
					# Record the failure instead of letting the thread die;
					# a dead worker would leave q.join() waiting forever.
					with lock:
						failures.append((img_path, exc))
				finally:
					q.task_done()

		threads = []
		for _ in range(num_workers):
			thread = Thread(target=_worker)
			thread.daemon = True
			thread.start()
			threads.append(thread)
		q.join()

		print("Terminating threads...")
		# One sentinel per worker, then wait so no thread outlives the open file.
		for _ in threads:
			q.put((None, None))
		for thread in threads:
			thread.join()

	if failures:
		first_path, first_exc = failures[0]
		raise RuntimeError("Failed to process %d of %d images; first failure: %s (%s)"
			% (len(failures), total, first_path, first_exc))

	torch.save(images_idx, os.path.join(data_path, "%s_%s.pt" % (data_name, image_type)))
	print("Finish saving images...")


def main(opt):
	"""
	Build the stored image files for the "train" and "val" splits.
	--------------------
	Arguments:
		opt: parsed options providing `img_path`, `data_path`, `data_name`
			and `num_workers`.
	"""
	# Only "train" and "val" are generated here; the other keys of folder_map
	# ("trainval", "test") can be produced by adding them to this table.
	jobs = (
		("Create train images dataset...", "train"),
		("Create val images dataset...", "val"),
	)

	for banner, split in jobs:
		print(banner)
		save_images(opt.img_path, split, opt.data_path, opt.data_name, opt.num_workers)

	print("Done!")


if __name__ == "__main__":
	# Command-line options: where the images live, where the output goes,
	# the output file prefix, and how many loader threads to start.
	cli = argparse.ArgumentParser()
	cli.add_argument("--img_path", default="/ceph/kien/data2.0")
	cli.add_argument("--data_name", default="cocoimages")
	cli.add_argument("--data_path", default="/ceph/kien/VQA/dataset")
	cli.add_argument("--num_workers", type=int, default=8)
	opt = cli.parse_args()

	# Echo the effective configuration before starting the run.
	print("Parsed input parameters:")
	print(json.dumps(vars(opt), indent=2))

	main(opt)