#-*-coding:utf-8-*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from resnet import resnet_model
from resnet import resnet_run_loop

# Number of channels per image; parse_record decodes every image with
# channels=3.
_NUM_CHANNELS = 3
# Default number of output classes for Model.
_NUM_CLASSES = 5

# Each record is a serialized tf.Example read from a TFRecord file; see
# parse_record for the feature keys it must contain.

# Example counts per split. Only the 'train' count is read in this file (for
# the learning-rate schedule and for the per-epoch example count).
_NUM_IMAGES = {
    'train': 230944,
    'validation': 19448,
}

# Dataset name handed to resnet_run_loop.resnet_main.
DATASET_NAME = 'nsfw'

# Height and width, in pixels, of the square images fed to the model.
_IMAGE_SIZE = 64

###############################################################################
# Data processing
###############################################################################

def get_filenames(is_training, data_dir):
    """Return the TFRecord files of one split found under `data_dir`.

    The directory tree is walked recursively. A file is selected when its
    name starts with ``nsfw_train_`` (training) or ``nsfw_validation_``
    (otherwise) and is followed somewhere later by a literal ``.tfrecord``.
    The match is anchored at the start of the name only, so trailing
    characters after ``.tfrecord`` are still accepted.

    Args:
      is_training: True to select training files, False for validation files.
      data_dir: Root directory to search.

    Returns:
      A sorted list of matching file paths; empty when nothing matches.
    """
    # The dot is escaped so that only a literal '.tfrecord' matches; an
    # unescaped '.' would also accept names such as 'nsfw_train_xtfrecord'.
    if is_training:
        pattern = re.compile(r'nsfw_train_.*\.tfrecord')
    else:
        pattern = re.compile(r'nsfw_validation_.*\.tfrecord')

    file_names = []
    for top, _, files in os.walk(data_dir):
        for name in files:
            if pattern.match(name):
                file_names.append(os.path.join(top, name))

    # os.walk yields entries in filesystem order; sort so that the file order
    # is the same on every run and every machine.
    file_names.sort()
    return file_names

def preprocess_image(image, is_training):
  """Augment (training only) and standardize one image.

  Args:
    image: A single image tensor of layout [height, width, depth].
    is_training: When true, a random crop and a random horizontal flip are
      applied before standardization.

  Returns:
    The result of tf.image.per_image_standardization on the (possibly
    augmented) image.
  """
  if not is_training:
    # Evaluation path: no augmentation, only standardization.
    return tf.image.per_image_standardization(image)

  # Crop or pad to a canvas 8 pixels larger in each dimension so the random
  # crop below has room to shift.
  canvas_size = _IMAGE_SIZE + 8
  padded = tf.image.resize_image_with_crop_or_pad(
      image, canvas_size, canvas_size)

  # Take a random [_IMAGE_SIZE, _IMAGE_SIZE] window out of the canvas.
  crop_shape = [_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS]
  cropped = tf.random_crop(padded, crop_shape)

  # Mirror the image left-to-right at random.
  flipped = tf.image.random_flip_left_right(cropped)

  return tf.image.per_image_standardization(flipped)


def parse_record(raw_record, is_training):
    """Turn one serialized tf.Example into an (image, label) pair.

    Args:
      raw_record: A serialized tf.Example containing the feature keys listed
        in the spec below. All five keys are required (no defaults are
        given), although only the encoded image and the label are read.
      is_training: Forwarded to preprocess_image.

    Returns:
      A tuple (image, label): the decoded image converted to float32, resized
      to [_IMAGE_SIZE, _IMAGE_SIZE] and passed through preprocess_image, and
      the int64 class label.
    """
    int64_scalar = tf.FixedLenFeature([], tf.int64)
    bytes_scalar = tf.FixedLenFeature([], tf.string)
    example = tf.parse_single_example(
        raw_record,
        {
            'image/height': int64_scalar,
            'image/width': int64_scalar,
            'image/format': bytes_scalar,
            'image/class/label': int64_scalar,
            'image/encoded': bytes_scalar,
        })

    decoded = tf.image.decode_image(example['image/encoded'], channels=3)
    pixels = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
    # Declare a rank-3 static shape with 3 channels before resizing.
    pixels.set_shape([None, None, 3])
    resized = tf.image.resize_images(pixels, [_IMAGE_SIZE, _IMAGE_SIZE])

    processed = preprocess_image(resized, is_training)
    return processed, example['image/class/label']


def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
                           parse_record_fn, num_epochs=1, num_gpus=None,
                           examples_per_epoch=None, dtype=tf.float32):
    """Shuffle, repeat, parse and batch a dataset of raw records.

    Args:
      dataset: A dataset whose elements are raw records accepted by
        `parse_record_fn`.
      is_training: Whether the data feeds training. Enables shuffling and
        the multi-GPU truncation below, and is passed to `parse_record_fn`.
      batch_size: Number of examples per batch.
      shuffle_buffer: Buffer size used for shuffling when training.
      parse_record_fn: Callable taking (raw_record, is_training) and returning
        the parsed example.
      num_epochs: Number of times the dataset is repeated.
      num_gpus: Number of devices. When set together with
        `examples_per_epoch` during training, the number of batches is
        truncated to a multiple of this value.
      examples_per_epoch: Number of examples in one epoch; only used for the
        truncation described above.
      dtype: Accepted for interface compatibility; not used by this function.

    Returns:
      The batched dataset, with a final prefetch applied.
    """
    dataset = dataset.prefetch(buffer_size=batch_size)
    if is_training:
        # Shuffle before repeating so that each epoch is shuffled on its own.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)

    dataset = dataset.repeat(num_epochs)

    if is_training and num_gpus and examples_per_epoch:
        total_examples = num_epochs * examples_per_epoch
        # Force the number of batches to be divisible by the number of devices.
        # This prevents some devices from receiving batches while others do
        # not, which can lead to a lockup.
        total_batches = total_examples // batch_size // num_gpus * num_gpus
        # Dataset transformations return a new dataset instead of modifying
        # the receiver, so the result has to be assigned for the truncation
        # to take effect.
        dataset = dataset.take(total_batches * batch_size)

    # Parse the raw records into images and labels and batch them in one
    # fused step.
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            lambda value: parse_record_fn(value, is_training),
            batch_size=batch_size,
            num_parallel_batches=1,
            drop_remainder=False))

    # Operations between the final prefetch and the get_next call to the
    # iterator happen synchronously at run time. Prefetch again here so the
    # processing above runs in the background, off the critical training path.
    dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    return dataset



def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
             dtype=tf.float32):
    """Build the input dataset for one split.

    Args:
      is_training: True for the training split, False for validation.
      data_dir: Directory searched by get_filenames for TFRecord files.
      batch_size: Number of examples per batch.
      num_epochs: Number of times the data is repeated.
      num_gpus: Number of devices, forwarded to process_record_dataset.
      dtype: Forwarded to process_record_dataset.

    Returns:
      The dataset produced by process_record_dataset.
    """
    record_files = get_filenames(is_training, data_dir)
    print(record_files)

    # The per-epoch example count is only supplied for training.
    examples_per_epoch = None
    if is_training:
        examples_per_epoch = _NUM_IMAGES['train']

    return process_record_dataset(
        dataset=tf.data.TFRecordDataset(filenames=record_files),
        is_training=is_training,
        batch_size=batch_size,
        shuffle_buffer=500,
        parse_record_fn=parse_record,
        num_epochs=num_epochs,
        num_gpus=num_gpus,
        examples_per_epoch=examples_per_epoch,
        dtype=dtype,
    )



###############################################################################
# Running the model
###############################################################################
class Model(resnet_model.Model):
  """resnet_model.Model preconfigured for this file's dataset.

  NOTE(review): the fixed hyperparameters below appear to be carried over
  from the CIFAR-10 configuration -- confirm they suit this dataset.
  """

  def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
               resnet_version=resnet_model.DEFAULT_VERSION,
               dtype=resnet_model.DEFAULT_DTYPE):
    """Configure the base model from `resnet_size`.

    Args:
      resnet_size: Network depth; must have the form 6n + 2. Each of the
        three block groups then gets n blocks.
      data_format: Either 'channels_first' or 'channels_last', specifying
        which data format to use when setting up the model.
      num_classes: The number of output classes needed from the model.
      resnet_version: Integer selecting the ResNet version, passed through to
        the base class.
      dtype: The TensorFlow dtype to use for calculations.

    Raises:
      ValueError: if `resnet_size` is not of the form 6n + 2.
    """
    blocks_per_group, leftover = divmod(resnet_size - 2, 6)
    if leftover:
      raise ValueError('resnet_size must be 6n + 2:', resnet_size)

    super(Model, self).__init__(
        resnet_size=resnet_size,
        resnet_version=resnet_version,
        data_format=data_format,
        dtype=dtype,
        num_classes=num_classes,
        bottleneck=False,
        num_filters=16,
        kernel_size=3,
        conv_stride=1,
        first_pool_size=None,
        first_pool_stride=None,
        block_sizes=[blocks_per_group] * 3,
        block_strides=[1, 2, 2],
        final_size=64,
    )


def cifar10_model_fn(features, labels, mode, params):
  """Model function for this dataset, delegating to resnet_model_fn.

  Args:
    features: Batch of images; reshaped here to
      [-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS].
    labels: Batch of labels, passed through unchanged.
    mode: Passed through unchanged to resnet_run_loop.resnet_model_fn.
    params: Dict providing 'batch_size', 'resnet_size', 'data_format',
      'resnet_version', 'loss_scale', 'dtype' and 'fine_tune'.

  Returns:
    Whatever resnet_run_loop.resnet_model_fn returns.
  """
  batch_shape = [-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS]
  features = tf.reshape(features, batch_shape)

  # Piecewise learning-rate schedule with boundaries at epochs 10, 20 and 30.
  learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=128,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[10, 20, 30],
      decay_rates=[1, 0.1, 0.01, 0.001])

  # NOTE(review): the weight decay of 2e-4 and the choice to regularize every
  # variable (the filter below accepts everything, batch-norm variables
  # included) appear to come from CIFAR-10 tuning -- confirm for this dataset.
  return resnet_run_loop.resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=Model,
      resnet_size=params['resnet_size'],
      weight_decay=2e-4,
      learning_rate_fn=learning_rate_fn,
      momentum=0.9,
      data_format=params['data_format'],
      resnet_version=params['resnet_version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=lambda _: True,
      dtype=params['dtype'],
      fine_tune=params['fine_tune'],
  )


def set_defaults(**kwargs):
  """Override flag defaults; each keyword is a flag name and its new default."""
  flag_values = flags.FLAGS
  for flag_name in kwargs:
    flag_values.set_default(name=flag_name, value=kwargs[flag_name])


def define_flower_flags():
  """Register the ResNet flags and override selected defaults for this script."""
  resnet_run_loop.define_resnet_flags()
  flags.adopt_module_key_flags(resnet_run_loop)

  default_overrides = {
      'data_dir': '',
      'model_dir': '',
      'resnet_size': '32',
      'train_epochs': 50,
      'epochs_between_evals': 50,
      'batch_size': 128,
  }
  set_defaults(**default_overrides)


def run_flower(flags_obj):
  """Run the ResNet training and eval loop on this file's dataset.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  image_shape = [_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS]
  resnet_run_loop.resnet_main(
      flags_obj, cifar10_model_fn, input_fn, DATASET_NAME, shape=image_shape)


def main(_):
  """Entry point handed to absl_app.run; the positional argument is ignored."""
  parsed_flags = flags.FLAGS
  run_flower(parsed_flags)


if __name__ == '__main__':
  # Show INFO-level TensorFlow log messages.
  tf.logging.set_verbosity(tf.logging.INFO)
  # Register the flags and their defaults before handing control to absl.
  define_flower_flags()
  absl_app.run(main)