"""dense net"""
import keras
from keras.applications import densenet
from keras.utils import get_file

from . import retinanet
from . import Backbone
from ..utils.image import preprocess_image

# Maps a backbone name to (number of conv blocks in each of the four dense
# blocks, keras.applications constructor for that DenseNet variant).
# The block counts are used by densenet_retinanet to locate the layer that
# closes each dense block.
allowed_backbones = {
    'densenet121': ([6, 12, 24, 16], densenet.DenseNet121),
    'densenet169': ([6, 12, 32, 32], densenet.DenseNet169),
    'densenet201': ([6, 12, 48, 32], densenet.DenseNet201),
}

class DenseNetBackbone(Backbone):
    """ Describes backbone information and provides utility functions.
    """

    def retinanet(self, *args, **kwargs):
        """ Returns a retinanet model using the correct backbone.

        All arguments are forwarded to densenet_retinanet, with ``backbone``
        set to this instance's backbone name.
        """
        return densenet_retinanet(*args, backbone=self.backbone, **kwargs)

    def download_imagenet(self):
        """ Download pre-trained weights for the specified backbone name.

        This name is in the format
        {backbone}_weights_tf_dim_ordering_tf_kernels_notop
        where backbone is the densenet + number of layers (e.g. densenet121).
        For more info check the explanation from the keras densenet script
        itself:
        https://github.com/keras-team/keras/blob/master/keras/applications/densenet.py

        Returns
            The path of the weights file, as returned by keras.utils.get_file
            (files are cached under the 'models' subdirectory).

        Raises
            ValueError: if Keras is configured with the "channels_first"
            image data format, for which no weights are available.
        """
        # Base URL of the release hosting the "notop" (no classifier head)
        # DenseNet weight files. Implicit string concatenation replaces the
        # original backslash continuation, whose second half had been lost.
        origin = ('https://github.com/fchollet/deep-learning-models/releases/'
                  'download/v0.8/')
        file_name = '{}_weights_tf_dim_ordering_tf_kernels_notop.h5'

        # load weights
        if keras.backend.image_data_format() == 'channels_first':
            raise ValueError(
                'Weights for "channels_first" format are not available.')

        weights_url = origin + file_name.format(self.backbone)
        return get_file(file_name.format(self.backbone),
                        weights_url, cache_subdir='models')

    def validate(self):
        """ Checks whether the backbone string is correct.

        Only the part of the backbone name before the first '_' is checked
        against allowed_backbones.

        Raises
            ValueError: if the backbone is not one of allowed_backbones.
        """
        backbone = self.backbone.split('_')[0]

        if backbone not in allowed_backbones:
            raise ValueError(
                'Backbone (\'{}\') not in allowed backbones ({}).'.format(
                    backbone, allowed_backbones.keys()))

    def preprocess_image(self, inputs):
        """ Takes as input an image and prepares it for being passed through
        the network.

        Delegates to the module-level ``preprocess_image`` helper in 'tf'
        mode (the method name does not shadow it inside the body).
        """
        return preprocess_image(inputs, mode='tf')


def densenet_retinanet(num_classes, backbone='densenet121',
                       inputs=None, modifier=None, **kwargs):
    """ Constructs a retinanet model using a densenet backbone.

    Args
        num_classes: int
            Number of classes to predict.
        backbone: str
            Which backbone to use (one of ('densenet121', 'densenet169',
            'densenet201')).
        inputs: tensor
            The inputs to the network (defaults to a Tensor of shape
            (None, None, 3)).
        modifier: function
            A function handler which can modify the backbone before using it in
            retinanet (this can be used to freeze backbone layers for example).
        **kwargs:
            Forwarded unchanged to retinanet.retinanet.

    Returns
        RetinaNet model with a DenseNet backbone.

    Raises
        KeyError: if ``backbone`` is not a key of allowed_backbones.
    """
    # choose default input
    if inputs is None:
        inputs = keras.layers.Input((None, None, 3))

    blocks, creator = allowed_backbones[backbone]
    # Build the DenseNet feature extractor on top of `inputs`, without the
    # classification head and without loading any weights here.
    model = creator(input_tensor=inputs, include_top=False,
                    pooling=None, weights=None)

    # get last conv layer from the end of each dense block.
    # Keras names the concatenation that closes block `b` of dense stage `s`
    # 'conv{s}_block{b}_concat', with dense stages numbered from 2, so the
    # last layer of stage idx + 2 is block number blocks[idx].
    layer_outputs = [
        model.get_layer(
            name='conv{}_block{}_concat'.format(idx + 2, block_num)).output
        for idx, block_num in enumerate(blocks)
    ]

    # create the densenet backbone; the first dense block's output is not
    # used as a feature level.
    model = keras.models.Model(
        inputs=inputs, outputs=layer_outputs[1:], name=model.name)

    # invoke modifier if given
    if modifier:
        model = modifier(model)

    # create the full model
    # NOTE(review): keyword names assume the package's
    # retinanet.retinanet(inputs, backbone_layers, num_classes, ...) signature
    # - confirm against models/retinanet.py.
    model = retinanet.retinanet(inputs=inputs, num_classes=num_classes,
                                backbone_layers=model.outputs, **kwargs)

    return model