# -*- coding: utf-8 -*-
'''MalwaResNet.  An attempt at a deep architecture (34 layers) for grokking malicious/benign from bytes.  Inspired by ResNet: https://arxiv.org/pdf/1512.03385.pdf
   Designed to allow you to slurp the whole file into GPU memory by breaking it apart into chunks.
'''
from keras.models import Sequential, Model
from keras.layers import Dense, BatchNormalization, Dropout, Activation, TimeDistributed, Embedding, AveragePooling1D, GlobalAveragePooling2D, Conv1D, Flatten, ELU, Reshape
from keras.regularizers import l2

from keras.layers import Input, Add

charset = set(range(257))  # byte values 0-255, plus 256 as an "EOF" sentinel (distinct from byte 0)

def ResidualBlock1D_helper(layers, kernel_size, filters, final_stride=1):
    '''Return a function that applies `layers` ELU+Conv1D steps to its input and adds
    the result back onto the input (the residual/shortcut connection).'''
    def f(_input):
        basic = _input
        for ln in range(layers):
            # basic = BatchNormalization()(basic)  # triggers known keras bug w/ TimeDistributed: https://github.com/fchollet/keras/issues/5221
            basic = ELU()(basic)
            basic = Conv1D(filters, kernel_size, kernel_initializer='he_normal',
                           kernel_regularizer=l2(1.e-4), padding='same')(basic)

        # pool_size=1 with strides=final_stride subsamples without averaging
        return AveragePooling1D(pool_size=1, strides=final_stride)(Add()([_input, basic]))

    return f

def ResidualBlock1D(input_layer, layers, kernel_size, final_stride=1):
    '''Create a resblock as a standalone Model (via Keras' functional API) whose input
    shape matches the output of `input_layer`, so it can be wrapped in TimeDistributed.'''
    model_shape = input_layer.output_shape
    _inp = Input(shape=model_shape[-2:])  # per-chunk shape: (chunk_length, channels)
    _out = ResidualBlock1D_helper(layers=layers, kernel_size=kernel_size,
                                  filters=input_layer.output_shape[-1], final_stride=final_stride)(_inp)
    model = Model(inputs=_inp, outputs=_out)
    return model

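# Illustrative preprocessing sketch (an assumption, not the original pipeline): one way to
# turn a file's raw bytes into the (file_chunks, chunk_size) integer array this model
# expects, appending the 256 "EOF" symbol and zero-padding the tail.  The helper name
# `bytes_to_chunks` and the padding scheme are hypothetical.
def bytes_to_chunks(data, file_chunks, chunk_size):
    import numpy as np
    x = np.frombuffer(data, dtype=np.uint8).astype(np.int64)
    x = np.append(x, 256)                     # EOF sentinel; hence len(charset) == 257
    x = x[:file_chunks * chunk_size]          # truncate files that are too large
    pad = file_chunks * chunk_size - len(x)
    x = np.pad(x, (0, pad), mode='constant')  # zero-pad files that are too small
    return x.reshape(file_chunks, chunk_size)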

def create_model(input_shape, byte_embedding_size=2, lite=False):
    '''End to end deep learning.

    Note that to fit in the memory of most graphics cards, the input file must be split into multiple chunks, which take
    up contiguous memory on the GPU.  Each chunk is then processed separately and the results are combined within the
    model.  Note that this isn't exactly the same as having operated on the whole file, because of boundary artifacts at
    the edges of each chunk.  But (shrug). It works.

    Args:
        input_shape (tuple): input shape to the model. For this model, should be of shape (file_chunks, chunk_size)
        byte_embedding_size (int): each byte is embedded into a space of this dimension
        lite (bool): set this to True if your GPU is complaining about memory. If it still borks, (shrug), sorry.

    Returns:
        keras.models.Sequential : a model to train
    '''
    file_chunks, chunk_size = input_shape

    model = Sequential()

    # first, we'll represent bytes in some embedding space.
    # if byte_embedding_size=2, then each byte (e.g., 'A') will be mapped to a point in 2d space
    # the mapping is learned end-to-end to minimize overall loss

    # TimeDistributed is operating on each chunk
    model.add(TimeDistributed(Embedding(len(charset), byte_embedding_size,
                                        input_length=chunk_size, name="embedding"), input_shape=(file_chunks, chunk_size)))
    # output shape: (nb_batch, file_chunks, chunk_size, byte_embedding_size)

    kernel_size = 3
    filters = 64
    # an initial convolutional layer (kernel size 7), followed by 2x downsampling via pooling
    model.add(TimeDistributed(Conv1D(filters, kernel_size * 2 + 1,
                                     kernel_initializer='he_normal', kernel_regularizer=l2(1.e-4), padding='same')))
    model.add(BatchNormalization())
    model.add(ELU())  # in place of Activation('relu')

    # pool and downsample
    model.add(TimeDistributed(AveragePooling1D(pool_size=2)))

    # add residual blocks
    # filters=64
    model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    if not lite: model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size, final_stride=2)))

    model.add(TimeDistributed(Conv1D(filters * 2, 1, kernel_initializer='he_normal')))
    filters *= 2  # 128

    model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    if not lite: model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    if not lite: model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size, final_stride=2)))

    model.add(TimeDistributed(Conv1D(filters * 2, 1, kernel_initializer='he_normal')))
    filters *= 2  # 256

    model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    if not lite: model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    if not lite: model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    if not lite: model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    if not lite: model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size, final_stride=2)))

    model.add(TimeDistributed(Conv1D(filters * 2, 1, kernel_initializer='he_normal')))
    filters *= 2  # 512

    model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    if not lite: model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size)))
    model.add(TimeDistributed(ResidualBlock1D(model.layers[-1], layers=2, kernel_size=kernel_size, final_stride=1)))

    # output shape: (nb_batch, file_chunks, downsampled_chunk_size, filters )

    # average over chunks and within each chunk
    model.add(GlobalAveragePooling2D())
    # (nb_batch, filters)

    # add fully-connected layers
    model.add(Dense(1000, kernel_initializer='he_normal'))
    model.add(BatchNormalization())
    model.add(ELU())  # in place of Activation('relu')

    # output layer
    model.add(Dense(1, activation='sigmoid', kernel_initializer='he_normal'))

    # we'll optimize with plain old sgd
    model.compile(loss='binary_crossentropy',
                  optimizer='sgd', metrics=['accuracy'])

    return model
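

# Example usage (a sketch, not from the original module; the shapes below are assumptions,
# not recommended settings): build the lite model on 16 chunks of 4096 bytes each, print
# its layer summary, and run a forward pass on random byte data.
if __name__ == '__main__':
    import numpy as np
    file_chunks, chunk_size = 16, 4096
    model = create_model((file_chunks, chunk_size), byte_embedding_size=2, lite=True)
    model.summary()
    x = np.random.randint(0, len(charset), size=(1, file_chunks, chunk_size))
    print(model.predict(x))  # a single sigmoid score in [0, 1]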