from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

from PIL import Image

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
import config

IMAGE_HEIGHT = config.IMAGE_HEIGHT
IMAGE_WIDTH = config.IMAGE_WIDTH
CHAR_SETS = config.CHAR_SETS
CLASSES_NUM = config.CLASSES_NUM
CHARS_NUM = config.CHARS_NUM

RECORD_DIR = config.RECORD_DIR
TRAIN_FILE = config.TRAIN_FILE
VALID_FILE = config.VALID_FILE

FLAGS = None

def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

  
def label_to_one_hot(label):
  one_hot_label = np.zeros([CHARS_NUM, CLASSES_NUM])
  offset = []
  index = []
  for i, c in enumerate(label):
    offset.append(i)
    index.append(CHAR_SETS.index(c))
  one_hot_index = [offset, index]
  one_hot_label[one_hot_index] = 1.0
  return one_hot_label.astype(np.uint8)


def conver_to_tfrecords(data_set, name):
  """Converts a dataset to tfrecords."""
  if not os.path.exists(RECORD_DIR):
      os.makedirs(RECORD_DIR)
  filename = os.path.join(RECORD_DIR, name)
  print('>> Writing', filename)
  writer = tf.python_io.TFRecordWriter(filename)
  num_examples = len(data_set)
  for index in range(num_examples):
    image = data_set[index][0]
    height = image.shape[0]
    width = image.shape[1]
    image_raw = image.tostring()
    label = data_set[index][1]
    label_raw = label_to_one_hot(label).tostring()
    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'label_raw': _bytes_feature(label_raw),
        'image_raw': _bytes_feature(image_raw)}))
    writer.write(example.SerializeToString())
  writer.close()
  print('>> Writing Done!')
  

def create_data_list(image_dir):
  if not gfile.Exists(image_dir):
    print("Image director '" + image_dir + "' not found.")
    return None
  extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG']
  print("Looking for images in '" + image_dir + "'")
  file_list = []
  for extension in extensions:
    file_glob = os.path.join(image_dir, '*.' + extension)
    file_list.extend(gfile.Glob(file_glob))
  if not file_list:
    print("No files found in '" + image_dir + "'")
    return None
  images = []
  labels = []
  for file_name in file_list:
    image = Image.open(file_name)
    image_gray = image.convert('L')
    image_resize = image_gray.resize(size=(IMAGE_WIDTH,IMAGE_HEIGHT))
    input_img = np.array(image_resize, dtype='int16')
    image.close()
    label_name = os.path.basename(file_name).split('_')[0]
    images.append(input_img)
    labels.append(label_name)
  return zip(images, labels)


def main(_):
  training_data = create_data_list(FLAGS.train_dir)
  conver_to_tfrecords(training_data, TRAIN_FILE)
    
  validation_data = create_data_list(FLAGS.valid_dir)
  conver_to_tfrecords(validation_data, VALID_FILE)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--train_dir',
      type=str,
      default='./data/train_data',
      help='Directory training to get captcha data files and write the converted result.'
  )
  parser.add_argument(
      '--valid_dir',
      type=str,
      default='./data/valid_data',
      help='Directory validation to get captcha data files and write the converted result.'
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)