#!/usr/bin/env python

"""
    rsna_train.py
    - Train models using keras and tensorflow backend.

    @author Alesson Scapinello
	@author Bernardo Henz
    @author Daniel Souza
    @author Felipe Kitamura
    @author Igor Santos
    @author José Venson
"""

import argparse
import importlib
import os
import sys
import warnings

import keras
import keras.preprocessing.image
import tensorflow as tf

# These imports assume the keras_retinanet package lives alongside this script;
# adjust them if you copy the script elsewhere.
from keras_retinanet.keras_retinanet import layers  # noqa: F401
from keras_retinanet.keras_retinanet import losses
from keras_retinanet.keras_retinanet import models
from keras_retinanet.keras_retinanet.callbacks import RedirectModel
from keras_retinanet.keras_retinanet.models.retinanet import AnchorParameters, retinanet_bbox
from keras_retinanet.keras_retinanet.preprocessing.csv_generator import CSVGenerator  # noqa: F401
from keras_retinanet.keras_retinanet.utils.anchors import make_shapes_callback
from keras_retinanet.keras_retinanet.utils.keras_version import check_keras_version
from keras_retinanet.keras_retinanet.utils.model import freeze as freeze_model
from keras_retinanet.keras_retinanet.utils.transform import random_transform_generator
from keras.callbacks import TensorBoard


from rsna_generator import RsnaGenerator, ImageOnlyTransformations

from keras import backend as K

class TrainValTensorBoard(TensorBoard):
    """ TensorBoard callback that logs training and validation metrics to separate
        subdirectories so both can be displayed on the same plots.
    """
    def __init__(self, log_dir='./logs', **kwargs):
        # Make the original `TensorBoard` log to a subdirectory 'training'
        training_log_dir = os.path.join(log_dir, 'training')
        super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)

        # Log the validation metrics to a separate subdirectory
        self.val_log_dir = os.path.join(log_dir, 'validation')

    def set_model(self, model):
        # Setup writer for validation metrics
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        # Pop the validation logs and handle them separately with
        # `self.val_writer`. Also rename the keys so that they can
        # be plotted on the same figure with the training metrics
        logs = logs or {}
        val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()

        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()
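
# To inspect the resulting curves, point TensorBoard at the parent log directory, e.g.:
#   tensorboard --logdir ./logs
# The 'training' and 'validation' runs will then appear together in the same plots.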


def set_keras_backend(backend):
    """ Make sure Keras runs on the requested backend and, for tensorflow,
        recreate the session with GPU memory growth enabled.
    """
    if K.backend() != backend:
        os.environ['KERAS_BACKEND'] = backend
        importlib.reload(K)
        assert K.backend() == backend
    if backend == "tensorflow":
        # Rebuild the session so TensorFlow allocates GPU memory on demand
        # instead of grabbing it all up front.
        K.get_session().close()
        K.clear_session()
        cfg = K.tf.ConfigProto()
        cfg.gpu_options.allow_growth = True
        K.set_session(K.tf.Session(config=cfg))



def makedirs(path):
    """ Try to create the directory;
        pass if it already exists, fail otherwise.
    """
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise


def get_session():
    """ Construct a TensorFlow session with GPU memory growth enabled.
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)


def model_with_weights(model, weights, skip_mismatch):
    """ Load weights for model.

    Args
        model         : The model to load weights for.
        weights       : The weights to load.
        skip_mismatch : If True, skips layers whose shape of weights doesn't match with the model.
    """
    if weights is not None:
        model.load_weights(weights, by_name=True, skip_mismatch=skip_mismatch)
    return model


def create_models(backbone_retinanet, num_classes, weights, multi_gpu=0, freeze_backbone=False,
                  anchors_ratios=None, anchors_scales=None, noise_aug_std=None, dropout_rate=None):
    """ Creates three models (model, training_model, prediction_model).

    Args
        backbone_retinanet : A function to call to create a retinanet model with a given backbone.
        num_classes        : The number of classes to train.
        weights            : The weights to load into the model.
        multi_gpu          : The number of GPUs to use for training.
        freeze_backbone    : If True, disables learning for the backbone.
        anchors_ratios     : Aspect ratios of the anchor boxes.
        anchors_scales     : Scales of the anchor boxes.
        noise_aug_std      : Std of the random noise added during training (None disables it).
        dropout_rate       : Dropout rate used by the model (None disables dropout).

    Returns
        model            : The base model. This is also the model that is saved in snapshots.
        training_model   : The training model. If multi_gpu=0, this is identical to model.
        prediction_model : The model wrapped with utility functions to perform object detection (applies regression values and performs NMS).
    """
    modifier = freeze_model if freeze_backbone else None

    anchor_params = AnchorParameters.default
    anchor_params.ratios = anchors_ratios
    anchor_params.scales = anchors_scales
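    # keras-retinanet builds len(ratios) * len(scales) anchors per feature-map location,
    # so the default 3 ratios and 3 scales yield 9 anchors per location.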
    # Keras recommends initialising a multi-gpu model on the CPU to ease weight sharing, and to prevent OOM errors.
    # optionally wrap in a parallel model
    if multi_gpu > 1:
        from keras.utils import multi_gpu_model
        with tf.device('/cpu:0'):
            model = model_with_weights(
                backbone_retinanet(num_classes, modifier=modifier, num_anchors=anchor_params.num_anchors(),
                                   noise_aug_std=noise_aug_std, dropout_rate=dropout_rate),
                weights=weights, skip_mismatch=True)
        training_model = multi_gpu_model(model, gpus=multi_gpu)
    else:
        model = model_with_weights(
            backbone_retinanet(num_classes, modifier=modifier, num_anchors=anchor_params.num_anchors(),
                               noise_aug_std=noise_aug_std, dropout_rate=dropout_rate),
            weights=weights, skip_mismatch=True)
        training_model = model

    # make prediction model
    prediction_model = retinanet_bbox(model=model,anchors_ratios=anchors_ratios,anchors_scales=anchors_scales)

    # compile model
    training_model.compile(
        loss={
            'regression'    : losses.smooth_l1(),
            'classification': losses.focal()
        },
        optimizer=keras.optimizers.adam(lr=1e-4, clipnorm=0.001)
    )

    return model, training_model, prediction_model


def create_callbacks(model, training_model, prediction_model, validation_generator, args):
    """ Creates the callbacks to use during training.

    Args
        model: The base model.
        training_model: The model that is used for training.
        prediction_model: The model that should be used for validation.
        validation_generator: The generator for creating validation data.
        args: parseargs args object.

    Returns:
        A list of callbacks used for training.
    """
    callbacks = []

    if args.tensorboard_dir:
        callbacks.append(TrainValTensorBoard(write_graph=False, log_dir=args.tensorboard_dir))

    # save the model
    if args.snapshots:
        # ensure directory created first; otherwise h5py will error after epoch.
        makedirs(args.snapshot_path)
        checkpoint = keras.callbacks.ModelCheckpoint(
            os.path.join(
                args.snapshot_path,
                '{backbone}_{dataset_type}_{{epoch:02d}}.h5'.format(backbone=args.backbone, dataset_type=args.dataset_type)
            ),
            verbose=1
        )
        checkpoint = RedirectModel(checkpoint, model)
        callbacks.append(checkpoint)

    # Append Reduce LR on plateau callback (see keras documentation)
    callbacks.append(keras.callbacks.ReduceLROnPlateau(
        monitor  = 'loss',
        factor   = 0.1,
        patience = 2,
        verbose  = 1,
        mode     = 'auto',
        epsilon  = 0.0001,
        cooldown = 0,
        min_lr   = 0
    ))

    return callbacks


def create_generators(args, preprocess_image):
    """ Create generators for training and validation.

    Args
        args             : parseargs object containing configuration for generators.
        preprocess_image : Function that preprocesses an image for the network.
    """
    common_args = {
        'batch_size'       : args.batch_size,
        'image_min_side'   : args.image_min_side,
        'image_max_side'   : args.image_max_side,
        'preprocess_image' : preprocess_image,
    }

    # create random transform generator for augmenting training data
    if args.random_transform:
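        # Note: rotation and shear values are given in radians, and flip_x_chance is the
        # probability of a horizontal flip (see keras_retinanet.utils.transform.random_transform_generator).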
        transform_generator = random_transform_generator(
            min_rotation=-0.1,
            max_rotation=0.1,
            min_translation=(-0.1, -0.1),
            max_translation=(0.1, 0.1),
            min_shear=-0.1,
            max_shear=0.1,
            min_scaling=(0.9, 0.9),
            max_scaling=(1.1, 1.1),
            flip_x_chance=0.5,
            flip_y_chance=0.0,
        )
    else:
        transform_generator = random_transform_generator(flip_x_chance=0.0)
    
    # Optionally add image-only augmentations (noise, contrast and brightness jitter)
    if args.image_only_transformations:
        image_only_transformations = ImageOnlyTransformations(noise_std=0.02,
                                                              contrast_level=0.3,
                                                              brightness_level=0.1)
    else:
        image_only_transformations = None

    # Create the training and validation generators
    train_generator = RsnaGenerator(
        args.rsna_train_json,
        args.rsna_path,
        transform_generator=transform_generator,
        image_only_transformations = image_only_transformations,
        bbox_aug_std = args.bbox_aug_std,
        anchors_ratios = args.anchor_boxes,
        anchors_scales = args.anchor_scales,
        dicom_load_mode=args.dicom_load_mode,
        hist_eq=args.hist_eq,
        **common_args
    )
    
    train_generator.transform_parameters.fill_mode = 'constant'
    
    validation_generator = RsnaGenerator(
        args.rsna_val_json,
        args.rsna_path,
        image_only_transformations=None,
        anchors_ratios = args.anchor_boxes,
        anchors_scales = args.anchor_scales,
        dicom_load_mode=args.dicom_load_mode,
        hist_eq=args.hist_eq,
        **common_args
    )

    return train_generator, validation_generator


def check_args(parsed_args):
    """ Function to check for inherent contradictions within parsed arguments.
    For example, batch_size < num_gpus
    Intended to raise errors prior to backend initialisation.

    Args
        parsed_args: parser.parse_args()

    Returns
        parsed_args
    """

    if parsed_args.multi_gpu > 1 and parsed_args.batch_size < parsed_args.multi_gpu:
        raise ValueError(
            "Batch size ({}) must be equal to or higher than the number of GPUs ({})".format(parsed_args.batch_size,
                                                                                             parsed_args.multi_gpu))
    if parsed_args.multi_gpu > 1 and parsed_args.snapshot:
        raise ValueError(
            "Multi GPU training ({}) and resuming from snapshots ({}) is not supported.".format(parsed_args.multi_gpu,
                                                                                                parsed_args.snapshot))

    if parsed_args.multi_gpu > 1 and not parsed_args.multi_gpu_force:
        raise ValueError("Multi-GPU support is experimental, use at own risk! Run with --multi-gpu-force if you wish to continue.")

    if 'resnet' not in parsed_args.backbone:
        warnings.warn('Using experimental backbone {}. Only resnet50 has been properly tested.'.format(parsed_args.backbone))

    return parsed_args


def parse_args(args):
    """ Parse the arguments.
    """
    parser     = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.')
    subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type')
    subparsers.required = True

    rsna_parser = subparsers.add_parser('rsna')
    rsna_parser.add_argument('rsna_path', help='Path to the RSNA dataset directory containing the DICOM images.')
    rsna_parser.add_argument('rsna_train_json', help='Path to training json.')
    rsna_parser.add_argument('rsna_val_json', help='Path to validation json.')

    group = parser.add_mutually_exclusive_group()
    group.add_argument('--snapshot',          help='Resume training from a snapshot.')
    group.add_argument('--imagenet-weights',  help='Initialize the model with pretrained imagenet weights. This is the default behaviour.', action='store_const', const=True, default=True)
    group.add_argument('--weights',           help='Initialize the model with weights from a file.')
    group.add_argument('--no-weights',        help='Don\'t initialize the model with any weights.', dest='imagenet_weights', action='store_const', const=False)

    parser.add_argument('--backbone',        help='Backbone model used by retinanet.', default='resnet50', type=str)
    parser.add_argument('--batch-size',      help='Size of the batches.', default=1, type=int)
    parser.add_argument('--gpu',             help='Id of the GPU to use (as reported by nvidia-smi).')
    parser.add_argument('--multi-gpu',       help='Number of GPUs to use for parallel processing.', type=int, default=0)
    parser.add_argument('--multi-gpu-force', help='Extra flag needed to enable (experimental) multi-gpu support.', action='store_true')
    parser.add_argument('--epochs',          help='Number of epochs to train.', type=int, default=50)
    parser.add_argument('--steps',           help='Number of steps per epoch.', type=int, default=10000)
    parser.add_argument('--val_steps',       help='Number of validation steps per epoch.', type=int, default=400)
    parser.add_argument('--snapshot-path',   help='Path to store snapshots of models during training (defaults to \'./snapshots\')', default='./snapshots')
    parser.add_argument('--tensorboard_dir', help='Log directory for Tensorboard output', default='./logs')
    parser.add_argument('--no-snapshots',    help='Disable saving snapshots.', dest='snapshots', action='store_false')
    parser.add_argument('--no-evaluation',   help='Disable per epoch evaluation.', dest='evaluation', action='store_false')
    parser.add_argument('--freeze-backbone', help='Freeze training of backbone layers.', action='store_true')
    parser.add_argument('--data-aug', help='Enables random-transforms and image-only-transforms.', action='store_true')
    parser.add_argument('--random-transform', help='Randomly transform image and annotations.', action='store_true')
    parser.add_argument('--image_only_transformations', help='Randomly perform image-only transformations.', action='store_true')
    parser.add_argument('--noise_aug_std', help='Defines the std of the random noise added during training. If not set, no noise is added.', type=float, default=None)
    parser.add_argument('--bbox_aug_std', help='Defines the std of the bounding box augmentations (no augmentation when not set).', type=float, default=None)
    parser.add_argument('--dropout_rate', help='Defines the dropout rate.', type=float, default=None)

    parser.add_argument('--image-min-side', help='Rescale the image so the smallest side is min_side.', type=int, default=800)
    parser.add_argument('--image-max-side', help='Rescale the image if the largest side is larger than max_side.', type=int, default=1333)
    parser.add_argument('--dicom_load_mode', help='Decide to load only image (image) or sex and view position as well (image_sex_view).', type=str, default='image')
    parser.add_argument('--hist_eq', help='Perform histogram equalization', action='store_true')
    
    parser.add_argument('--anchor_boxes', help='Comma-separated list of anchor box aspect ratios.', type=str, default='0.5,1,2')
    parser.add_argument('--anchor_scales', help='Comma-separated list of anchor scales.', type=str, default='1, 1.25992105, 1.58740105')
    parser.add_argument('--score_threshold', help='Threshold on score to filter detections with (defaults to 0.2).', default=0.2, type=float)
    parser.add_argument('--nms_threshold',   help='Non-maximum suppression threshold.', type=float, default=0.1)

    return check_args(parser.parse_args(args))

def main(args=None):
    set_keras_backend("tensorflow")  
    # parse arguments
    if args is None:
        args = sys.argv[1:]

    # Parse anchor ratios and scales from their comma-separated string form
    args = parse_args(args)
    args.anchor_boxes = [float(item) for item in args.anchor_boxes.split(',')]
    args.anchor_scales = [float(item) for item in args.anchor_scales.split(',')]
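    # e.g. '0.5,1,2' -> [0.5, 1.0, 2.0] and '1, 1.25992105, 1.58740105' -> [1.0, 1.25992105, 1.58740105]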
    print('Using anchors: {}'.format(args.anchor_boxes))
    
    if args.data_aug:
        args.random_transform = True
        args.image_only_transformations = True

    # create object that stores backbone information
    backbone = models.backbone(args.backbone, args.noise_aug_std)

    # make sure keras is the minimum required version
    check_keras_version()

    # optionally choose specific GPU
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    keras.backend.tensorflow_backend.set_session(get_session())

    # create the generators
    train_generator, validation_generator = create_generators(args, backbone.preprocess_image)
    
    # create the model
    if args.snapshot is not None:
        print('Loading model, this may take a second...')
        model            = models.load_model(args.snapshot, backbone_name=args.backbone, anchors_ratios=args.anchor_boxes,anchors_scales=args.anchor_scales)
        training_model   = model
        prediction_model = retinanet_bbox(model=model, anchors_ratios=args.anchor_boxes,anchors_scales=args.anchor_scales)
    else:
        weights = args.weights
        # default to imagenet if nothing else is specified
        if weights is None and args.imagenet_weights:
            weights = backbone.download_imagenet()

        print('Creating model, this may take a second...')
        model, training_model, prediction_model = create_models(
            backbone_retinanet=backbone.retinanet,
            num_classes=1,
            weights=weights,
            multi_gpu=args.multi_gpu,
            freeze_backbone=args.freeze_backbone,
            anchors_ratios=args.anchor_boxes,
            anchors_scales=args.anchor_scales,
            noise_aug_std=args.noise_aug_std,
            dropout_rate=args.dropout_rate
        )

    # print model summary (model.summary() prints directly and returns None)
    model.summary()

    # this lets the generator compute backbone layer shapes using the actual backbone model
    if 'vgg' in args.backbone or 'densenet' in args.backbone:
        train_generator.compute_shapes = make_shapes_callback(model)
        if validation_generator:
            validation_generator.compute_shapes = train_generator.compute_shapes

    # create the callbacks
    callbacks = create_callbacks(
        model,
        training_model,
        prediction_model,
        validation_generator, 
        args,
    )

    # start training
    training_model.fit_generator(
        generator=train_generator,
        steps_per_epoch=args.steps,
        epochs=args.epochs,
        verbose=1,
        callbacks=callbacks,
        validation_data=validation_generator,
        validation_steps=args.val_steps
    )

    
if __name__ == '__main__':
    main()