#
# Copyright 2018-2019 IBM Corp. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from PIL import Image
import io
import numpy as np
import logging

import tensorflow
from maxfw.model import MAXModelWrapper

from config import DEFAULT_MODEL_PATH, MODEL_NAME
from train_mitoses import normalize

logger = logging.getLogger()


class ModelWrapper(MAXModelWrapper):

    MODEL_META_DATA = {
        'id': '{}-keras-model'.format(MODEL_NAME.lower()),
        'name': '{} Keras Model'.format(MODEL_NAME),
        'description': '{} Keras model trained on TUPAC16 data to detect mitosis'.format(MODEL_NAME),
        'type': 'image_classification',
        'license': 'Custom',
        'source': 'https://developer.ibm.com/exchanges/models/all/max-breast-cancer-mitosis-detector/'
    }

    def __init__(self, path=DEFAULT_MODEL_PATH):
        logger.info('Loading model from: {}...'.format(path))
        self.sess = tensorflow.keras.backend.get_session()
        base_model = tensorflow.keras.models.load_model(path, compile=False)
        probs = tensorflow.keras.layers.Activation('sigmoid', name="sigmoid")(base_model.output)
        self.model = tensorflow.keras.models.Model(inputs=base_model.input, outputs=probs)
        self.input_tensor = self.model.input
        self.output_tensor = self.model.output

    def _read_image(self, image_data):
        image = Image.open(io.BytesIO(image_data))
        if image.size != (64, 64):
            raise ValueError(
                "The input file must be a PNG image of size (64, 64). Got %s %s." % (image.format, image.size))
        image = np.array(image)
        return image

    def _pre_process(self, image):
        image = np.expand_dims(image, 0)
        return image

    def _post_process(self, preds):
        return preds[0][0]

    def _predict(self, x):
        norm_patch_batch = normalize((np.array(x) / 255).astype(np.float32), "resnet_custom")
        out_batch = self.output_tensor.eval(feed_dict={self.input_tensor: norm_patch_batch}, session=self.sess)
        return out_batch