#
# Copyright 2018-2019 IBM Corp. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import io
import logging
from PIL import Image
from keras.backend import clear_session
from keras import models
from keras.preprocessing.image import img_to_array
from keras.applications import imagenet_utils
import numpy as np
from maxfw.model import MAXModelWrapper
from config import DEFAULT_MODEL_PATH
from flask import abort
import json
import os

logger = logging.getLogger()


class ModelWrapper(MAXModelWrapper):
    """Model wrapper for Keras models"""

    MODEL_NAME = 'resnet50'
    MODEL_INPUT_IMG_SIZE = (224, 224)
    MODEL_LICENSE = 'MIT'
    # Passed as `mode` to keras' imagenet_utils.preprocess_input in _pre_process.
    MODEL_MODE = 'caffe'

    MODEL_META_DATA = {
        'id': '{}-keras-imagenet'.format(MODEL_NAME.lower()),
        'name': '{} Keras Model'.format(MODEL_NAME),
        'description': '{} Keras model trained on ImageNet'.format(MODEL_NAME),
        'type': 'image_classification',
        'license': '{}'.format(MODEL_LICENSE),
        'source': 'https://developer.ibm.com/exchanges/models/all/max-resnet-50/'
    }

    def __init__(self, path=DEFAULT_MODEL_PATH):
        """Load the Keras model and its class index from a model directory.

        Args:
            path: directory containing ``resnet50.h5`` (the Keras model) and
                ``class_index.json`` (the index -> class mapping). Defaults to
                ``DEFAULT_MODEL_PATH``.
        """
        logger.info('Loading model from: {}...'.format(path))
        # Reset any Keras/backend state left over from a previously built model.
        clear_session()
        self.model = models.load_model(os.path.join(path, 'resnet50.h5'))
        # this seems to be required to make Keras models play nicely with threads
        self.model._make_predict_function()
        logger.info('Loaded model: {}'.format(self.model.name))
        # Read the class index from the same directory as the weights, so that a
        # custom `path` is honoured for both files.
        with open(os.path.join(path, 'class_index.json'), encoding='utf-8') as class_file:
            self.classes = json.load(class_file)

    def read_image(self, image_data):
        """Decode raw image bytes into an RGB PIL image.

        Args:
            image_data: the encoded image file contents as bytes.

        Returns:
            A PIL ``Image`` in ``'RGB'`` mode (converted if necessary).

        Aborts the current Flask request with HTTP 400 if PIL raises an
        ``IOError`` while opening or converting the data.
        """
        try:
            image = Image.open(io.BytesIO(image_data))
            # Conversion is kept inside the try so decode errors raised during
            # conversion are also reported as a 400.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except IOError:
            abort(400, 'Invalid file type/extension. Please provide a valid image (supported formats: JPEG, PNG, TIFF).')

    def _pre_process(self, image):
        """Turn a PIL image into a single-item model input batch.

        The image is resized to ``MODEL_INPUT_IMG_SIZE`` and converted to an
        array. A leading batch axis is added, and then keras'
        ``imagenet_utils.preprocess_input`` is applied with ``MODEL_MODE``.
        """
        image = image.resize(self.MODEL_INPUT_IMG_SIZE)
        image = img_to_array(image)
        # The model expects a batch, so wrap the single image in a batch axis.
        image = np.expand_dims(image, axis=0)
        return imagenet_utils.preprocess_input(image, mode=self.MODEL_MODE)

    def _post_process(self, preds, top_k=5):
        """Return the ``top_k`` highest-scoring classes for the first prediction.

        Args:
            preds: model output; only ``preds[0]`` is used (presumably shape
                ``(1, num_classes)`` -- TODO confirm against ``_predict``).
            top_k: number of classes to return, highest score first. It should
                be non-negative. Defaults to 5.

        Returns:
            A list of ``[entry[0], entry[1], score]`` lists, where ``entry`` is
            ``self.classes[str(class_index)]``. Presumably ``entry`` is
            ``[wnid, label]`` -- verify against ``class_index.json``.
        """
        scores = preds[0]
        # argsort is ascending; reverse it for descending order, then keep
        # the first top_k. Unlike a `[-top_k:]` slice, this yields an empty
        # list for top_k == 0.
        top_indices = scores.argsort()[::-1][:top_k]
        results = []
        for idx in top_indices:
            entry = self.classes[str(idx)]
            results.append([entry[0], entry[1], scores[idx]])
        return results

    def _predict(self, x):
        """Run the loaded Keras model on a preprocessed input batch ``x``."""
        return self.model.predict(x)