"""
Class for loading imitation learning TFRecords, and applying data augmentation during training.

imgaug code from Carla authors:
https://github.com/carla-simulator/imitation-learning/issues/1#issuecomment-355747357

Dosovitskiy et al. CARLA: An Open Urban Driving Simulator
http://proceedings.mlr.press/v78/dosovitskiy17a/dosovitskiy17a.pdf:
"To further reduce overfitting, we performed extensive data augmentation by adding Gaussian blur,
additive Gaussian noise, pixel dropout, additive and multiplicative brightness variation, contrast variation,
and saturation variation"

Codevilla et al. End-to-end Driving via Conditional Imitation Learning
http://vladlen.info/papers/conditional-imitation.pdf:
"Transformations include change in contrast, brightness, and tone, as well as addition of Gaussian blur, Gaussian noise,
salt-and-pepper noise, and region dropout (masking out a random set of rectangles in the image, each rectangle taking
roughly 1% of image area)"
"""
from __future__ import unicode_literals

import functools

import tensorflow as tf

from common.train import inputs
from common.util import img_aug
import constants as ilc


def convert_image_tf(img_str):
    """Decode raw image bytes into an (IMG_HEIGHT, IMG_WIDTH, 3) uint8 tensor."""
    flat_pixels = tf.decode_raw(img_str, tf.uint8)
    return tf.reshape(flat_pixels, shape=(ilc.IMG_HEIGHT, ilc.IMG_WIDTH, 3))


def get_feat_schema():
    """Return the tf.train.Example parsing schema.

    Maps the record key and raw image bytes to scalar string features, and
    every target in ilc.TGT_KEYS to a scalar float32 feature.
    """
    schema = {key: tf.FixedLenFeature(dtype=tf.float32, shape=[]) for key in ilc.TGT_KEYS}
    schema[ilc.FEATKEY_KEY] = tf.FixedLenFeature(dtype=tf.string, shape=[])
    schema[ilc.FEATKEY_IMG] = tf.FixedLenFeature(dtype=tf.string, shape=[])
    return schema


class Preprocessor(object):
    """Base class for preprocessing steps that run at training time, shortly before the data enters the model_fn.

    Subclasses implement `preprocess`. Note: some subclasses (e.g.
    ProbabilisticImageAugmentor) additionally assert that mode is TRAIN.
    """

    def preprocess(self, dataset, mode):
        """Applies transformation to dataset

        Args:
            dataset: a tf.data.Dataset
            mode: a tf.estimator.ModeKeys

        Returns:
            a tf.data.Dataset

        Raises:
            NotImplementedError: always, on the base class; subclasses must override.
        """
        raise NotImplementedError


class FilterValidIntention(Preprocessor):
    """Drops examples whose high-level command falls outside the valid range [2, 5]."""

    def preprocess(self, dataset, mode):
        return dataset.filter(self.filter_valid_intentions)

    @staticmethod
    def filter_valid_intentions(tf_example):
        """Return True if high-level command is in {2, 3, 4, 5}.

        Args:
            tf_example: Dict[str, tf.Tensor]

        Returns:
            tf.Tensor (type=bool)
        """
        cmd = tf_example[ilc.TGT_HIGH_LVL_CMD]
        at_least_two = tf.greater_equal(cmd, 2)
        at_most_five = tf.less_equal(cmd, 5)
        return tf.logical_and(at_least_two, at_most_five)


class CarlaPreprocessor(Preprocessor):
    """Decodes the raw image bytes and splits each parsed example into (features, targets)."""

    def preprocess(self, dataset, mode):
        """Applies transformation to dataset

        Args:
            dataset: a tf.data.Dataset
            mode: a tf.estimator.ModeKeys

        Returns:
            a tf.data.Dataset
        """
        return dataset.map(self.read_fn, num_parallel_calls=16)

    @staticmethod
    def read_fn(tf_example):
        """Given a tf_example dict, separates into feature_dict and target_dict"""
        decoded = convert_image_tf(tf_example[ilc.FEATKEY_IMG])
        as_float = tf.squeeze(tf.cast(decoded, tf.float32))
        # Scale uint8 pixel values into [0, 1].
        normalized = tf.div(as_float, 255.0)

        feats = {
            ilc.FEATKEY_IMG: normalized,
            ilc.TGT_SPEED: tf_example[ilc.TGT_SPEED],
        }
        tgts = {key: tf_example[key] for key in ilc.TGT_KEYS}
        return feats, tgts


class ProbabilisticImageAugmentor(Preprocessor):
    """Preprocessor that applies `image_transform` to input image with `augmentation_prob` probability.

    Raises:
        AssertionError if mode is not TRAIN
    """

    def __init__(self, augmentation_prob, image_transform):
        super(ProbabilisticImageAugmentor, self).__init__()
        self.augmentation_prob = augmentation_prob
        self.image_transform = image_transform

    def apply_to_image(self, feats, tgts):
        """Per example, flips a weighted coin; on heads, replaces the image with its transform."""
        source_img = feats[ilc.FEATKEY_IMG]
        should_augment = tf.less(tf.random_uniform([], 0, 1.0), self.augmentation_prob)
        feats[ilc.FEATKEY_IMG] = tf.cond(
            should_augment,
            lambda: self.image_transform(source_img),
            lambda: source_img,
        )
        return feats, tgts

    def preprocess(self, dataset, mode):
        assert mode == tf.estimator.ModeKeys.TRAIN, 'Should not augment eval / inference'
        return dataset.map(self.apply_to_image, num_parallel_calls=16)


def _rand_gauss_blur(img):
    """Blur `img` with a Gaussian whose stddev is drawn uniformly from [0, 1.5)."""
    blur_stddev = tf.random_uniform([], 0, 1.5)
    return img_aug.gauss_blur(img, stddev=blur_stddev)


def _rand_gauss_noise(img):
    """Add Gaussian noise to `img` with stddev drawn uniformly from [0, 0.05)."""
    noise_stddev = tf.random_uniform([], 0, 0.05)
    fair_coin = tf.less(tf.random_uniform([], 0, 1.0), 0.5)
    # NOTE(review): gauss_noise's second argument appears to require a Python bool,
    # hence the tf.cond over two literal-flag calls — confirm against img_aug.
    return tf.cond(
        fair_coin,
        lambda: img_aug.gauss_noise(img, True, stddev=noise_stddev),
        lambda: img_aug.gauss_noise(img, False, stddev=noise_stddev),
    )


def _rand_pixelwise_dropout(img):
    """Drop individual pixels of `img` with probability drawn uniformly from [0, 0.1)."""
    fair_coin = tf.less(tf.random_uniform([], 0.0, 1.0), 0.5)
    drop_prob = tf.random_uniform([], 0, 0.1)
    # A fair coin selects between the two boolean modes of pixelwise_dropout.
    return tf.cond(
        fair_coin,
        lambda: img_aug.pixelwise_dropout(img, drop_prob, True),
        lambda: img_aug.pixelwise_dropout(img, drop_prob, False),
    )


def _rand_coarse_pixelwise_dropout(img):
    """Mask out coarse rectangular regions of `img` with randomized size and drop rate."""
    fair_coin = tf.less(tf.random_uniform([], 0.0, 1.0), 0.5)
    drop_prob = tf.random_uniform([], 0, 0.1)
    frac_height = tf.random_uniform([], 0.08, 0.2)
    frac_width = tf.random_uniform([], 0.08, 0.2)
    # A fair coin selects between the two boolean modes of coarse_pixelwise_dropout.
    return tf.cond(
        fair_coin,
        lambda: img_aug.coarse_pixelwise_dropout(img, frac_height, frac_width, drop_prob, True),
        lambda: img_aug.coarse_pixelwise_dropout(img, frac_height, frac_width, drop_prob, False),
    )


def train_input_fn(tfrecord_fpaths, batch_size, shuffle_buffer_size):
    """Build the training input_fn: parse, filter invalid commands, decode, and augment.

    Args:
        tfrecord_fpaths: list of TFRecord file paths
        batch_size: int, number of examples per batch
        shuffle_buffer_size: int, size of the shuffle buffer

    Returns:
        an input_fn for tf.estimator training (repeats indefinitely; num_epochs=None)
    """
    # BUGFIX: brightness max_delta must be 32.0 / 255.0, not 32 / 255 — this file
    # targets Python 2 (see the __future__ import at the top), where integer
    # division makes 32 / 255 == 0 and silently disables the brightness augmentation.
    input_fn = inputs.input_fn_factory(
        tfrecord_fpaths=tfrecord_fpaths,
        feature_schema=get_feat_schema(),
        batch_size=batch_size,
        mode=tf.estimator.ModeKeys.TRAIN,
        num_epochs=None,
        model_preprocessors=[
            FilterValidIntention(),
            CarlaPreprocessor(),
            # Augmentation probabilities follow Codevilla et al. / Dosovitskiy et al.
            ProbabilisticImageAugmentor(0.09, _rand_gauss_blur),
            ProbabilisticImageAugmentor(0.09, _rand_gauss_noise),
            ProbabilisticImageAugmentor(0.30, _rand_pixelwise_dropout),
            ProbabilisticImageAugmentor(0.30, _rand_coarse_pixelwise_dropout),
            ProbabilisticImageAugmentor(0.30, functools.partial(tf.image.random_brightness, max_delta=32.0 / 255.0)),
            ProbabilisticImageAugmentor(0.30, functools.partial(tf.image.random_saturation, lower=0.5, upper=1.5)),
            ProbabilisticImageAugmentor(0.09, functools.partial(tf.image.random_contrast, lower=0.5, upper=1.2)),
        ],
        shuffle=True,
        shuffle_buffer_size=shuffle_buffer_size,
    )
    return input_fn


def evaluation_input_fn(tfrecord_fpaths, batch_size):
    """Build the single-epoch evaluation input_fn: parse, filter, decode — no augmentation."""
    return inputs.input_fn_factory(
        tfrecord_fpaths=tfrecord_fpaths,
        feature_schema=get_feat_schema(),
        batch_size=batch_size,
        mode=tf.estimator.ModeKeys.EVAL,
        num_epochs=1,
        model_preprocessors=[FilterValidIntention(), CarlaPreprocessor()],
        num_parallel_calls=16,
    )