#-*-coding:utf-8-*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from resnet import resnet_model
from resnet import resnet_run_loop

# Network input geometry: every decoded image is resized to a square
# _IMAGE_SIZE x _IMAGE_SIZE, 3-channel tensor. _HEIGHT/_WIDTH are derived from
# _IMAGE_SIZE so the crop size and the model input size cannot drift apart.
_NUM_CHANNELS = 3
_IMAGE_SIZE = 224
_HEIGHT = _IMAGE_SIZE
_WIDTH = _IMAGE_SIZE

# Number of output classes predicted by the model.
_NUM_CLASSES = 5

# Example counts per split. Records are serialized tf.Example protos stored in
# TFRecord shards. The 'train' count feeds the learning-rate schedule and the
# per-epoch example budget used while training.
_NUM_IMAGES = {
    'train': 230944,
    'validation': 19448,
}

# Dataset name passed through to resnet_run_loop.resnet_main.
DATASET_NAME = 'nsfw'
###############################################################################
# Data processing
###############################################################################

def get_filenames(is_training, data_dir):
  """Return the TFRecord shard paths under `data_dir` for one split.

  The tree rooted at `data_dir` is walked recursively. A file is selected when
  its base name is exactly `nsfw_train_<anything>.tfrecord` (training) or
  `nsfw_validation_<anything>.tfrecord` (evaluation).

  Args:
    is_training: If True, select training shards; otherwise validation shards.
    data_dir: Root directory to search.

  Returns:
    A sorted list of full file paths, or an empty list when nothing matches.
    Sorting makes the shard read order independent of `os.walk` ordering.
  """
  split = 'train' if is_training else 'validation'
  # The literal '.' is escaped and the end is anchored with \Z, because
  # re.match only anchors the start. Without this, names such as
  # 'nsfw_train_0.tfrecord.tmp' or 'nsfw_train_0Xtfrecord' would be picked up.
  # \Z (rather than re.fullmatch) keeps this Python-2 compatible, like the
  # __future__ imports at the top of the file.
  pattern = re.compile(r'nsfw_%s_.*\.tfrecord\Z' % split)

  file_names = []
  for top, _, files in os.walk(data_dir):
    file_names.extend(
        os.path.join(top, name) for name in files if pattern.match(name))
  return sorted(file_names)

def preprocess_image(image, is_training):
  """Optionally augment an image, then standardize it per image.

  Args:
    image: A single image tensor laid out as [height, width, depth].
    is_training: When True, apply a pad-then-random-crop and a random
      horizontal flip before standardization. When False, only standardize.

  Returns:
    The standardized image tensor. When `is_training` is True it has shape
    [_HEIGHT, _WIDTH, _NUM_CHANNELS]; otherwise it keeps the input's shape.
  """
  if not is_training:
    return tf.image.per_image_standardization(image)

  # Pad or centrally crop to _IMAGE_SIZE + 8 on each axis. For an
  # _IMAGE_SIZE input this adds 4 pixels per side. Then take a random
  # [_HEIGHT, _WIDTH] window, which acts as a small random translation.
  padded_size = _IMAGE_SIZE + 8
  augmented = tf.image.resize_image_with_crop_or_pad(
      image, padded_size, padded_size)
  augmented = tf.random_crop(augmented, [_HEIGHT, _WIDTH, _NUM_CHANNELS])
  augmented = tf.image.random_flip_left_right(augmented)

  # per_image_standardization subtracts the pixel mean and divides by the
  # (adjusted) standard deviation -- not the variance.
  return tf.image.per_image_standardization(augmented)


def parse_record(raw_record, is_training, dtype):
  """Parse one serialized tf.Example into a preprocessed (image, label) pair.

  Args:
    raw_record: Scalar string tensor holding one serialized tf.Example.
    is_training: Forwarded to `preprocess_image` to enable augmentation.
    dtype: Floating dtype of the returned image tensor.

  Returns:
    A tuple `(image, label)`:
      image: `dtype` tensor of shape [_HEIGHT, _WIDTH, _NUM_CHANNELS] when
        training, or [_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS] otherwise.
      label: int32 scalar read from 'image/class/label'. NOTE(review): assumed
        to be a class id in [0, _NUM_CLASSES); this is not checked here.
  """
  # All five features are required (no defaults), so a record missing any of
  # them fails to parse. Only the encoded bytes and the label are used below.
  image_feature_description = {
      'image/height': tf.FixedLenFeature([], tf.int64),
      'image/width': tf.FixedLenFeature([], tf.int64),
      'image/format': tf.FixedLenFeature([], tf.string),
      'image/class/label': tf.FixedLenFeature([], tf.int64),
      'image/encoded': tf.FixedLenFeature([], tf.string),
  }
  parsed = tf.parse_single_example(raw_record, image_feature_description)

  image = tf.image.decode_image(parsed['image/encoded'], channels=3)
  image = tf.image.convert_image_dtype(image, dtype)
  # decode_image yields a tensor of unknown rank, but resize_images needs a
  # known rank, so pin it to a single 3-channel image.
  # NOTE(review): animated GIFs decode to a 4-D tensor and would fail here.
  image.set_shape([None, None, 3])
  image = tf.image.resize_images(image, [_IMAGE_SIZE, _IMAGE_SIZE])
  image = preprocess_image(image, is_training)
  # resize_images may return float32 regardless of the input dtype (e.g. when
  # dtype is float16), so cast back to the requested dtype.
  image = tf.cast(image, dtype)

  label = tf.cast(parsed['image/class/label'], tf.int32)
  return image, label


def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
                           parse_record_fn, num_epochs=1, num_gpus=None,
                           examples_per_epoch=None, dtype=tf.float32):
  """Turn a dataset of raw records into a batched, prefetched input pipeline.

  Args:
    dataset: tf.data.Dataset of serialized records.
    is_training: If True, shuffle records and (see below) truncate the
      repeated dataset for multi-GPU training.
    batch_size: Number of examples per batch.
    shuffle_buffer: Buffer size for `Dataset.shuffle` (training only).
    parse_record_fn: Callable `(record, is_training, dtype) -> (image, label)`.
    num_epochs: Passed to `Dataset.repeat`; None repeats indefinitely.
    num_gpus: Number of devices batches are spread across.
    examples_per_epoch: Examples in one epoch; used for the truncation.
    dtype: Forwarded to `parse_record_fn`.

  Returns:
    A tf.data.Dataset yielding batches of parsed (image, label) pairs.
  """
  # Prefetch one batch worth of raw records to smooth out read latency ahead
  # of shuffling and parsing.
  dataset = dataset.prefetch(buffer_size=batch_size)
  if is_training:
    # Shuffle before repeat so shuffling respects epoch boundaries.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)

  dataset = dataset.repeat(num_epochs)

  # num_epochs=None means "repeat forever", so there is no finite total to
  # truncate to. Skip the truncation instead of failing on None * int.
  if is_training and num_gpus and examples_per_epoch and num_epochs:
    total_examples = num_epochs * examples_per_epoch
    # Force the number of batches to be divisible by the number of devices.
    # This prevents some devices from receiving batches while others do not,
    # which can lead to a lockup.
    total_batches = total_examples // batch_size // num_gpus * num_gpus
    # Dataset transformations return a new dataset, so the result must be
    # reassigned; otherwise the truncation is silently dropped.
    dataset = dataset.take(total_batches * batch_size)

  # Parse the raw records into images and labels. Testing has shown that
  # setting num_parallel_batches > 1 produces no improvement in throughput,
  # since batch_size is almost always much greater than the number of CPU
  # cores.
  dataset = dataset.apply(
      tf.contrib.data.map_and_batch(
          lambda value: parse_record_fn(value, is_training, dtype),
          batch_size=batch_size,
          num_parallel_batches=1,
          drop_remainder=False))

  # Prefetch again so the parsing work above runs in the background, off the
  # critical training path. AUTOTUNE lets DistributionStrategies pick how many
  # batches to buffer based on the number of devices.
  dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

  return dataset



def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
             dtype=tf.float32):
  """Build the training or validation input pipeline from TFRecord shards.

  Args:
    is_training: Select training shards (and shuffling/augmentation) if True,
      validation shards otherwise.
    data_dir: Root directory searched recursively for shards.
    batch_size: Number of examples per batch.
    num_epochs: Number of passes over the data.
    num_gpus: Number of devices; used to keep the batch count divisible.
    dtype: Dtype of the produced image tensors.

  Returns:
    A tf.data.Dataset of batched (image, label) pairs.

  Raises:
    ValueError: if no matching shard is found under `data_dir`.
  """
  filenames = get_filenames(is_training, data_dir)
  if not filenames:
    # An empty file list would otherwise build an empty dataset, and the run
    # would silently train or evaluate on nothing.
    raise ValueError(
        'No %s TFRecord files found under data_dir=%r.'
        % ('training' if is_training else 'validation', data_dir))
  tf.logging.info('Reading %d TFRecord file(s) from %s',
                  len(filenames), data_dir)

  # NOTE(review): shards are read in sorted order and shuffled with a
  # 500-record buffer, which only mixes records locally. If shards are grouped
  # by class, consider shuffling/interleaving the file list as well.
  dataset = tf.data.TFRecordDataset(filenames=filenames)
  return process_record_dataset(
      dataset=dataset,
      is_training=is_training,
      batch_size=batch_size,
      shuffle_buffer=500,
      parse_record_fn=parse_record,
      num_epochs=num_epochs,
      num_gpus=num_gpus,
      # Only training uses the per-epoch count (for multi-GPU truncation).
      examples_per_epoch=_NUM_IMAGES['train'] if is_training else None,
      dtype=dtype)

def _get_block_sizes(resnet_size):
  """Retrieve the size of each block_layer in the ResNet model.

  The number of block layers used for the Resnet model varies according
  to the size of the model. This helper returns the layer set for the
  requested depth and raises if a non-standard size has been selected.

  Args:
    resnet_size: The number of convolutional layers needed in the model.

  Returns:
    A list of four block sizes, one per block group.

  Raises:
    ValueError: if `resnet_size` is not one of 18, 34, 50, 101, 152, 200.
  """
  choices = {
      18: [2, 2, 2, 2],
      34: [3, 4, 6, 3],
      50: [3, 4, 6, 3],
      101: [3, 4, 23, 3],
      152: [3, 8, 36, 3],
      200: [3, 24, 36, 3]
  }

  try:
    return choices[resnet_size]
  except KeyError:
    # sorted() gives a readable "[18, 34, ...]" rather than a
    # "dict_keys([...])" repr in the message.
    err = ('Could not find layers for selected Resnet size.\n'
           'Size received: {}; sizes allowed: {}.'.format(
               resnet_size, sorted(choices)))
    raise ValueError(err)



###############################################################################
# Running the model
###############################################################################
class Model(resnet_model.Model):
  """ResNet with block layouts for 18/34/50/101/152/200 layers on NSFW data."""

  def __init__(self,
               resnet_size,
               data_format=None,
               num_classes=_NUM_CLASSES,
               resnet_version=resnet_model.DEFAULT_VERSION,
               dtype=resnet_model.DEFAULT_DTYPE):
    """Configure the ResNet for the requested depth.

    Args:
      resnet_size: Number of layers; one of 18, 34, 50, 101, 152, 200.
      data_format: Either 'channels_first' or 'channels_last', specifying which
        data format to use when setting up the model.
      num_classes: The number of output classes needed from the model.
      resnet_version: Integer representing which version of the ResNet network
        to use. See README for details. Valid values: [1, 2]
      dtype: The TensorFlow dtype to use for calculations.

    Raises:
      ValueError: if `resnet_size` is not a supported depth (raised by
        `_get_block_sizes`).
    """
    # Depth validation is delegated to _get_block_sizes. A CIFAR-style
    # "6n + 2" rule does not fit these ImageNet-style layouts: it would reject
    # 18, 34 and 101, which the table supports.
    block_sizes = _get_block_sizes(resnet_size)

    # Depths below 50 use basic blocks with a 512-wide final layer; 50 and
    # deeper use bottleneck blocks with a 2048-wide final layer.
    bottleneck = resnet_size >= 50
    final_size = 2048 if bottleneck else 512

    # NOTE(review): conv_stride=1 with no initial max-pool keeps the full
    # 224x224 resolution going into the first block group. The usual ImageNet
    # stem uses conv_stride=2 plus a 3x3/2 max-pool, which is far cheaper.
    # Confirm this is intended.
    super(Model, self).__init__(
        resnet_size=resnet_size,
        bottleneck=bottleneck,
        num_classes=num_classes,
        num_filters=64,
        kernel_size=7,
        conv_stride=1,
        first_pool_size=None,
        first_pool_stride=None,
        block_sizes=block_sizes,
        block_strides=[1, 2, 2, 2],
        final_size=final_size,
        resnet_version=resnet_version,
        data_format=data_format,
        dtype=dtype
    )


def nsfw_model_fn(features, labels, mode, params):
  """Estimator model function for the NSFW ResNet classifier.

  Args:
    features: Image batch, reshaped here to
      [-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS].
    labels: Label batch, forwarded unchanged.
    mode: Estimator mode, forwarded unchanged.
    params: Dict read for 'fine_tune', 'batch_size', 'resnet_size',
      'data_format', 'resnet_version', 'loss_scale' and 'dtype'.

  Returns:
    Whatever resnet_run_loop.resnet_model_fn returns for these arguments
    (presumably a tf.estimator.EstimatorSpec -- confirm in resnet_run_loop).
  """
  image_shape = [-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS]
  features = tf.reshape(features, image_shape)

  # Fine-tuning: no warmup, base LR 0.1. Otherwise: warmup, base LR 0.128.
  warmup, base_lr = (False, .1) if params['fine_tune'] else (True, .128)

  # Step decay at epochs 30/60/80/90 over the training-set size.
  # batch_denom=256 presumably scales the LR by batch_size / 256.
  learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=256,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
      warmup=warmup,
      base_lr=base_lr)

  # The filter accepts every variable, so none is excluded from weight decay
  # (batch-norm variables included). The comments this code originally
  # shipped with report that this and the 2e-4 decay helped on CIFAR-10.
  # NOTE(review): not re-validated on this dataset.
  def loss_filter_fn(_):
    return True

  return resnet_run_loop.resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=Model,
      resnet_size=params['resnet_size'],
      weight_decay=2e-4,
      learning_rate_fn=learning_rate_fn,
      momentum=0.9,
      data_format=params['data_format'],
      resnet_version=params['resnet_version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=loss_filter_fn,
      dtype=params['dtype'],
      fine_tune=params['fine_tune'])


def set_defaults(**kwargs):
  """Override the default value of each named absl flag.

  Args:
    **kwargs: Flag name -> new default value.
  """
  for flag_name in kwargs:
    flags.FLAGS.set_default(name=flag_name, value=kwargs[flag_name])


def define_nsfw_flags():
  """Register the shared ResNet flags and apply NSFW-specific defaults."""
  resnet_run_loop.define_resnet_flags()
  flags.adopt_module_key_flags(resnet_run_loop)

  nsfw_defaults = {
      'data_dir': '',
      'model_dir': '',
      # Given as a string; presumably the flag is string-valued in
      # resnet_run_loop -- confirm there.
      'resnet_size': '50',
      'train_epochs': 50,
      'epochs_between_evals': 50,
      'batch_size': 128,
  }
  set_defaults(**nsfw_defaults)


def run_nsfw(flags_obj):
  """Run resnet_run_loop.resnet_main with this module's model and input fns.

  Args:
    flags_obj: The parsed absl flags object.
  """
  image_shape = [_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS]
  resnet_run_loop.resnet_main(
      flags_obj, nsfw_model_fn, input_fn, DATASET_NAME, shape=image_shape)


def main(_):
  """absl entry point; the positional argv argument is ignored."""
  parsed_flags = flags.FLAGS
  run_nsfw(parsed_flags)


if __name__ == '__main__':
  # Emit INFO-level TensorFlow logs.
  tf.logging.set_verbosity(tf.logging.INFO)
  # Flags must be defined before absl parses argv inside app.run.
  define_nsfw_flags()
  absl_app.run(main=main)