from keras import Model
from keras.layers import Conv2D, Conv2DTranspose, concatenate
from keras.optimizers import Adam


class AutoEncoderModel(object):
    """U-Net style convolutional auto-encoder built with the Keras functional API.

    The encoder is seven ``conv_block`` stages. Each is a stride-1 convolution
    followed by a stride-2 convolution, so each stage halves the spatial
    resolution. The decoder is seven ``deconv_block`` stages. Each upsamples by
    2 with a transposed convolution and concatenates the encoder feature map of
    the same resolution as a skip connection; the last stage has no skip. A
    3-channel sigmoid convolution produces the reconstruction. Constructing an
    instance builds, wires and compiles the model immediately.

    NOTE(review): with seven stride-2 stages and ``padding='same'``, decoder
    and encoder maps only have matching shapes for concatenation when the
    input height and width are multiples of 2**7 = 128. Confirm that inputs
    satisfy this.

    Attributes:
        left, right: the input tensors given to ``__init__``. Only ``left``
            feeds the network; ``right`` is stored but not used by this class.
        rows, cols: stored but not used by this class.
        output, left_est: the reconstructed image tensor.
        right_est: not produced yet; always ``None``.
        lr: learning rate used when building the default optimizer.
        optimizer: the ``optimizer`` argument given to ``__init__``.
        model: the compiled ``keras.Model``.
    """

    def __init__(self, left_input, right_input, lr=1e-4, rows=128, cols=512,
                 optimizer=None):
        """Build and compile the auto-encoder.

        Args:
            left_input: Keras input tensor fed to the encoder.
            right_input: Keras input tensor; stored as ``self.right`` only.
            lr: learning rate for the default ``Adam`` optimizer.
            rows, cols: stored as attributes; they do not shape the graph,
                which takes its shape from ``left_input``.
            optimizer: optimizer handed to ``Model.compile``. ``None`` (the
                default) builds ``Adam`` with learning rate ``lr``. Any other
                value, such as a Keras optimizer instance or a name like
                ``'adadelta'``, is passed through unchanged, and ``lr`` is
                then not applied.
        """
        self.rows = rows
        self.cols = cols
        self.left = left_input
        self.right = right_input
        self.left_est = None
        self.right_est = None
        self.output = None
        self.model = None
        self.lr = lr
        self.optimizer = optimizer
        self.build_architecture()
        self.build_outputs()
        self.build_model()

    @staticmethod
    def conv(input, channels, kernel_size, strides, activation='elu'):
        """Apply a same-padded ``Conv2D`` to ``input`` and return the result.

        ``activation`` defaults to ELU.
        """
        return Conv2D(channels, kernel_size=kernel_size, strides=strides,
                      padding='same', activation=activation)(input)

    @staticmethod
    def deconv(input, channels, kernel_size, scale):
        """Apply a same-padded ``Conv2DTranspose`` with stride ``scale``.

        No activation is passed, so the layer is linear.
        """
        return Conv2DTranspose(channels, kernel_size=kernel_size,
                               strides=scale, padding='same')(input)

    def conv_block(self, input, channels, kernel_size):
        """Encoder stage: stride-1 conv, then stride-2 conv that halves H and W."""
        conv1 = self.conv(input, channels, kernel_size, 1)
        return self.conv(conv1, channels, kernel_size, 2)

    def deconv_block(self, input, channels, kernel_size, skip):
        """Decoder stage: upsample by 2, optionally fuse ``skip``, then refine.

        Args:
            input: tensor to upsample.
            channels: number of filters for both layers in the stage.
            kernel_size: kernel size for both layers in the stage.
            skip: encoder tensor to concatenate after upsampling, or ``None``
                to skip the concatenation.
        """
        deconv1 = self.deconv(input, channels, kernel_size, 2)

        if skip is not None:
            # Axis 3 is the channel axis for 4D channels-last tensors
            # (assumes the default ``channels_last`` data format).
            concat1 = concatenate([deconv1, skip], 3)
        else:
            concat1 = deconv1

        return self.conv(concat1, channels, kernel_size, 1)

    def get_output(self, deconv):
        """Map decoder features to a 3-channel sigmoid-activated image."""
        return self.conv(deconv, 3, 3, 1, 'sigmoid')

    def build_architecture(self):
        """Build the encoder/decoder graph and store the result in ``self.output``."""
        # encoder: each stage halves the spatial resolution
        conv1 = self.conv_block(self.left, 32, 7)
        conv2 = self.conv_block(conv1, 64, 5)
        conv3 = self.conv_block(conv2, 128, 3)
        conv4 = self.conv_block(conv3, 256, 3)
        conv5 = self.conv_block(conv4, 512, 3)
        conv6 = self.conv_block(conv5, 512, 3)
        conv7 = self.conv_block(conv6, 512, 3)

        # decoder: each stage doubles the resolution and fuses the encoder
        # map of the same size (conv6 .. conv1) as a skip connection
        deconv7 = self.deconv_block(conv7, 512, 3, conv6)
        deconv6 = self.deconv_block(deconv7, 512, 3, conv5)
        deconv5 = self.deconv_block(deconv6, 256, 3, conv4)
        deconv4 = self.deconv_block(deconv5, 128, 3, conv3)
        deconv3 = self.deconv_block(deconv4, 64, 3, conv2)
        deconv2 = self.deconv_block(deconv3, 32, 3, conv1)
        deconv1 = self.deconv_block(deconv2, 16, 3, None)

        self.output = self.get_output(deconv1)

    def build_outputs(self):
        """Expose the network output as the left-image estimate."""
        self.left_est = self.output
        # TODO: a right-image estimate is not produced; ``right_est`` stays None.

    def _build_optimizer(self):
        """Return the optimizer passed to ``Model.compile``.

        Uses ``Adam`` with ``self.lr`` when no optimizer was supplied;
        otherwise returns ``self.optimizer`` unchanged.
        """
        if self.optimizer is None:
            # Passed positionally: Adam's first parameter is the learning rate
            # in all Keras versions (named ``lr`` in older releases and
            # ``learning_rate`` in newer ones).
            return Adam(self.lr)
        return self.optimizer

    def build_model(self):
        """Wrap the graph in a ``keras.Model`` and compile it.

        Uses an MAE loss, an MSE metric, and the optimizer from
        ``_build_optimizer``.
        """
        self.model = Model(inputs=[self.left], outputs=[self.left_est])
        self.model.compile(loss=['mae'],
                           optimizer=self._build_optimizer(),
                           metrics=['mse'])