"""Trains the FRRN A architecture on the CityScapes dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import functools
import pickle
import lasagne
import logging as pylogging
from dltools import hybrid_training, utility, architectures, logging, optimizer, hooks, losses
import theano
import theano.tensor as T
import numpy as np

# Try to import pychianti and, if it is not installed, fall back to the python
# adapter. This allows us to use the code without installing the C++ library.
# Only ImportError is caught so that genuine errors raised while importing the
# native module (or a KeyboardInterrupt) are not silently masked.
try:
    import pychianti
except ImportError:
    from dltools import pychianti_adapter as pychianti


# Number of output classes, and the full-resolution frame geometry that
# define_network divides by the sample factor when not cropping.
NUM_CLASSES = 19
IMAGE_CHANNELS = 3
IMAGE_ROWS, IMAGE_COLS = 1024, 2048

# Passed to the FRRN builder as base_channels, lanes and multiplier.
BASE_CHANNELS = 48
FR_CHANNELS = 32
MULTIPLIER = 2

# Passed as `multiplier` to losses.bootstrapped_xentropy.
BOOTSTRAP_MULTIPLIER = 64

# Number of updates between two learning-rate reductions.
REDUCE_LR_INTERVAL = 5000


def define_network(arch, batch_size, sample_factor, crop_size=None):
    """Builds the lasagne graph for the requested FRRN variant.

    Args:
        arch: Architecture identifier, either "frrn_a" or "frrn_b".
        batch_size: Number of images per batch (first input dimension).
        sample_factor: Integer divisor applied to the full frame size
            (IMAGE_ROWS x IMAGE_COLS) when `crop_size` is None.
        crop_size: Side length of the square training crops. None means the
            network input covers the full (subsampled) frame.

    Returns:
        The network instance produced by the architecture builder.

    Raises:
        ValueError: If `arch` is not a valid architecture identifier.
    """
    if arch == "frrn_a":
        builder_cls = architectures.FRRNABuilder
    elif arch == "frrn_b":
        builder_cls = architectures.FRRNBBuilder
    else:
        raise ValueError("Invalid network architecture {}.".format(arch))

    # Symbolic 4D float32 placeholder for the image batch.
    images = T.ftensor4()

    frrn_builder = builder_cls(
        base_channels=BASE_CHANNELS,
        lanes=FR_CHANNELS,
        multiplier=MULTIPLIER,
        num_classes=NUM_CLASSES,
    )

    # Spatial input size: square crops, or the subsampled full frame.
    if crop_size is not None:
        rows = cols = crop_size
    else:
        rows = IMAGE_ROWS // sample_factor
        cols = IMAGE_COLS // sample_factor

    return frrn_builder.build(
        input_var=images,
        input_shape=(batch_size, IMAGE_CHANNELS, rows, cols))


def compile_train_function(network, batch_size, learning_rate,
                           min_learning_rate=5e-5,
                           lr_decay=1.25,
                           reduce_lr_interval=REDUCE_LR_INTERVAL):
    """Compiles the training function.

    The returned closure runs one Adam step per call. Every
    `reduce_lr_interval` updates it first divides the learning rate by
    `lr_decay`, clamped from below at `min_learning_rate`. The schedule only
    ever lowers the rate. If the initial learning rate is already below
    `min_learning_rate`, it is left unchanged instead of being raised to the
    floor.

    Args:
        network: The network instance. Its first input layer provides the
            symbolic image input.
        batch_size: The training batch size, forwarded to the bootstrapped
            cross-entropy loss.
        learning_rate: The initial learning rate.
        min_learning_rate: Lower bound for the decayed learning rate.
        lr_decay: Divisor applied to the learning rate at each reduction.
        reduce_lr_interval: Number of updates between learning-rate
            reductions. Must be positive.

    Returns:
    A function `update(imgs, targets, update_counter)` that performs one
    optimization step on the batch and returns the loss produced by the
    compiled gradient step. `update_counter` drives the learning-rate
    schedule (presumably the zero-based index of the current update -
    confirm against the optimizer).

    Raises:
        ValueError: If `reduce_lr_interval` is not positive.
    """
    if reduce_lr_interval <= 0:
        raise ValueError(
            "reduce_lr_interval must be positive, got {}.".format(
                reduce_lr_interval))

    learning_rate = np.float32(learning_rate)
    # Keep the schedule constants float32 to match the float32 shared
    # learning-rate variable they update.
    lr_floor = np.float32(min_learning_rate)
    lr_divisor = np.float32(lr_decay)

    input_var = network.input_layers[0].input_var
    target_var = T.ftensor4()

    # Loss function
    loss_fn = functools.partial(
        losses.bootstrapped_xentropy,
        targets=target_var,
        batch_size=batch_size,
        multiplier=BOOTSTRAP_MULTIPLIER
    )

    # The learning rate lives in a shared variable so that reduce_lr can
    # modify it between update steps.
    lr = theano.shared(learning_rate)
    update_fn = functools.partial(lasagne.updates.adam, learning_rate=lr)

    pylogging.info("Compile Adam updates")
    gd_step = hybrid_training.compile_gd_step(
        network, loss_fn, [input_var, target_var], update_fn)

    # Decay towards the floor. The outer T.minimum ensures a reduction can
    # never increase the rate (e.g. when the initial rate is below the floor).
    reduced_lr = T.minimum(lr, T.maximum(lr_floor, lr / lr_divisor))
    reduce_lr = theano.function(
        inputs=[],
        updates=collections.OrderedDict([(lr, reduced_lr)])
    )

    def _compute_update(imgs, targets, update_counter):
        if (update_counter + 1) % reduce_lr_interval == 0:
            reduce_lr()
        return gd_step(imgs, targets)

    return _compute_update


def compile_validation_function(network, batch_size):
    """Builds the deterministic evaluation function.

    Args:
        network: The network instance. Its first input layer provides the
            symbolic image input.
        batch_size: The batch size, forwarded to the bootstrapped loss.

    Returns:
    A compiled theano function mapping (images, targets) to
    [argmax of the predictions over axis 1, bootstrapped cross-entropy loss],
    i.e. the predicted segmentation mask and the loss.
    """
    images = network.input_layers[0].input_var
    labels = T.ftensor4()

    # deterministic=True requests the inference-mode graph from lasagne.
    all_outputs = lasagne.layers.get_output(
        network.output_layers, deterministic=True)
    scores = all_outputs[0]

    val_loss = losses.bootstrapped_xentropy(
        predictions=scores,
        targets=labels,
        batch_size=batch_size,
        multiplier=BOOTSTRAP_MULTIPLIER,
    )
    segmentation = T.argmax(scores, axis=1)

    pylogging.info("Compile validation function")
    return theano.function(inputs=[images, labels],
                           outputs=[segmentation, val_loss])


def get_training_provider(cityscapes_folder,
                          sample_factor,
                          batch_size,
                          iterator_type,
                          crop_size,
                          weights_file="data_weights.pkl"):
    """Creates the training data provider.

    Args:
        cityscapes_folder: The folder in which the Cityscapes Dataset is
            located.
        sample_factor: The image sampling factor. A Subsample augmentor is
            only added when it is greater than 1.
        batch_size: The batch size.
        iterator_type: The iterator type. "uniform" or "weighted".
        crop_size: The size of the image crops. None will result in full-frame
            training.
        weights_file: Pickle file holding a mapping from image file name
            (basename of the image path) to its sampling weight. Only read
            when `iterator_type` is "weighted". A relative path is opened
            relative to the current working directory.

    Returns:
    A chianti data provider.

    Raises:
        ValueError: If `iterator_type` is neither "uniform" nor "weighted".
        KeyError: If `weights_file` has no weight for a training image.
    """
    import os  # Function-scope import keeps this block self-contained.

    # Fail fast on a bad iterator type before building augmentors or
    # enumerating the dataset.
    if iterator_type not in ("uniform", "weighted"):
        raise ValueError("Invalid iterator type {}.".format(iterator_type))

    augmentors = [
        pychianti.Augmentor.Translation(120),
    ]

    if sample_factor > 1:
        augmentors.append(pychianti.Augmentor.Subsample(sample_factor))

    if crop_size is not None:
        augmentors.append(pychianti.Augmentor.Crop(crop_size, NUM_CLASSES))

    augmentors.extend([
        pychianti.Augmentor.Gamma(0.05),
        pychianti.Augmentor.Rotation(10),
        pychianti.Augmentor.Saturation(0.5, 1.5),
        pychianti.Augmentor.Hue(-30, 30),
    ])

    images = utility.get_image_label_pairs(cityscapes_folder, "train")

    if iterator_type == "uniform":
        iterator = pychianti.Iterator.Random(images)
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # `weights_file` at a trusted file.
        with open(weights_file, "rb") as f:
            image_weights = pickle.load(f)

        weights = []
        for img in images:
            # Weights are keyed by file name; img[0] is presumably the image
            # path of the (image, label) pair - confirm against
            # utility.get_image_label_pairs.
            image_name = os.path.basename(img[0])
            try:
                weights.append(image_weights[image_name])
            except KeyError:
                raise KeyError(
                    "No sampling weight for image {} in {}.".format(
                        image_name, weights_file))
        iterator = pychianti.Iterator.WeightedRandom(images, weights)

    provider = pychianti.DataProvider(
        pychianti.Augmentor.Combined(augmentors),
        pychianti.Loader.RGB(),
        pychianti.Loader.ValueMapper(utility.cityscapes_value_map),
        iterator,
        batch_size,
        NUM_CLASSES)

    return provider


def get_validation_provider(cityscapes_folder,
                            sample_factor,
                            batch_size,
                            crop_size=None):
    """Builds the data provider used for validation.

    Unlike the training provider, no random augmentation is applied. Only
    optional subsampling/cropping is performed, and the "val" split is visited
    with a sequential iterator.

    Args:
        cityscapes_folder: Root folder of the Cityscapes Dataset.
        sample_factor: The image sampling factor. Subsampling is only applied
            when it is greater than 1.
        batch_size: The batch size.
        crop_size: Side length of square crops, or None to validate on the
            full frames.

    Returns:
    A chianti data provider.
    """
    # (enabled, factory) pairs; factories run in order for enabled entries.
    optional_steps = (
        (sample_factor > 1,
         lambda: pychianti.Augmentor.Subsample(sample_factor)),
        (crop_size is not None,
         lambda: pychianti.Augmentor.Crop(crop_size, NUM_CLASSES)),
    )
    augmentors = [make() for enabled, make in optional_steps if enabled]

    val_pairs = utility.get_image_label_pairs(cityscapes_folder, "val")

    pipeline = pychianti.Augmentor.Combined(augmentors)
    rgb_loader = pychianti.Loader.RGB()
    label_mapper = pychianti.Loader.ValueMapper(utility.cityscapes_value_map)
    ordering = pychianti.Iterator.Sequential(val_pairs)

    return pychianti.DataProvider(
        pipeline, rgb_loader, label_mapper, ordering, batch_size, NUM_CLASSES)


def _parse_arguments():
    """Parses and validates the command-line arguments of the training script.

    Returns:
        The parsed argparse namespace. Out-of-range numeric arguments abort
        the program through `parser.error`.
    """
    parser = argparse.ArgumentParser(
        description="Trains a Full-Resolution Residual"
                    " Network on the Cityscapes"
                    " Dataset.")

    parser.add_argument("--architecture",
                        type=str,
                        choices=["frrn_a", "frrn_b"],
                        required=True,
                        help="The network architecture type.")

    parser.add_argument("--model_file",
                        type=str,
                        required=True,
                        help="The model filename. Weights are initialized to "
                             "the given values if the file exists. Snapshots "
                             "are stored using a _snapshot_[iteration] "
                             "post-fix.")

    parser.add_argument("--log_file",
                        type=str,
                        required=True,
                        help="The log filename. Use log_monitor.py in order to "
                             "monitor training progress in the terminal.")

    parser.add_argument("--cs_folder",
                        type=str,
                        required=True,
                        help="The folder that contains the Cityscapes Dataset.")

    parser.add_argument("--batch_size",
                        type=int,
                        default=3,
                        help="The batch size.")

    parser.add_argument("--validation_interval",
                        type=int,
                        default=500,
                        help="The validation interval.")

    parser.add_argument("--iterator",
                        type=str,
                        default="uniform",
                        choices=["uniform", "weighted"],
                        help="The dataset iterator type.")

    parser.add_argument("--crop_size",
                        type=int,
                        default=0,
                        help="The size of crops to extract from the "
                             "full-resolution images. If 0, then no crops "
                             "will be extracted.")

    parser.add_argument("--learning_rate",
                        type=float,
                        default=1e-3,
                        help="The learning rate to use.")

    parser.add_argument("--sample_factor",
                        type=int,
                        default=0,
                        help="The sampling factor. If 0, the architecture "
                             "default is used (4 for frrn_a, 2 for frrn_b). "
                             "Ignored when --crop_size is set.")

    args = parser.parse_args()

    # Reject out-of-range values up front. Otherwise they only surface later
    # as obscure failures (e.g. negative input shapes in define_network).
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer.")
    if args.validation_interval < 1:
        parser.error("--validation_interval must be a positive integer.")
    if args.crop_size < 0:
        parser.error("--crop_size must be non-negative.")
    if args.sample_factor < 0:
        parser.error("--sample_factor must be non-negative.")
    if args.learning_rate <= 0:
        parser.error("--learning_rate must be positive.")

    return args


def _resolve_sampling(architecture, requested_sample_factor,
                      requested_crop_size):
    """Determines the effective sampling factor and crop size.

    Args:
        architecture: "frrn_a" or "frrn_b".
        requested_sample_factor: The --sample_factor value. 0 selects the
            architecture default (4 for frrn_a, 2 otherwise).
        requested_crop_size: The --crop_size value. Values <= 0 disable
            cropping.

    Returns:
        A `(sample_factor, crop_size)` tuple. When cropping is enabled the
        sample factor is forced to 1 because crops are extracted from the
        full-resolution images. Otherwise `crop_size` is None.
    """
    if requested_crop_size > 0:
        if requested_sample_factor not in (0, 1):
            pylogging.warning(
                "--sample_factor {} is ignored because --crop_size is set; "
                "crops are extracted from full-resolution images.".format(
                    requested_sample_factor))
        return 1, requested_crop_size

    if requested_sample_factor != 0:
        return requested_sample_factor, None
    return (4 if architecture == "frrn_a" else 2), None


def main():
    """Trains a FRRN architecture on the Cityscapes Dataset."""
    args = _parse_arguments()

    sample_factor, crop_size = _resolve_sampling(
        args.architecture, args.sample_factor, args.crop_size)
    pylogging.info("Sample factor: {}".format(sample_factor))

    # Define the network lasagne graph and try to load the model file
    network = define_network(args.architecture,
                             args.batch_size,
                             sample_factor,
                             crop_size)
    pylogging.info("Try to load weights from {}".format(args.model_file))
    network.load_model(args.model_file)

    # Get the logger
    logger = logging.FileLogWriter(args.log_file)

    # Create the optimizer. Snapshots and validation share the same interval.
    opt = optimizer.MiniBatchOptimizer(
        compile_train_function(network, args.batch_size, args.learning_rate),
        get_training_provider(args.cs_folder,
                              sample_factor,
                              args.batch_size,
                              args.iterator,
                              crop_size),
        [
            hooks.SnapshotHook(
                args.model_file,
                network,
                interval=args.validation_interval),
            hooks.LoggingHook(logger),
            hooks.SegmentationValidationHook(
                compile_validation_function(network, args.batch_size),
                get_validation_provider(args.cs_folder,
                                        sample_factor,
                                        args.batch_size,
                                        crop_size),
                logger,
                interval=args.validation_interval)
        ])

    pylogging.info("Start training")
    opt.optimize()


if __name__ == "__main__":
    pylogging.basicConfig(format="%(asctime)s %(levelname)s %(message)s",
                          level=pylogging.DEBUG)
    main()