#
# Copyright 2018-2019 IBM Corp. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import io
import logging

import numpy as np
import skimage.color
import skimage.io
import skimage.transform
from flask import abort
from maxfw.model import MAXModelWrapper

from config import DEFAULT_MODEL_PATH, MODEL_META_DATA as model_meta
from core.srgan_controller import SRGAN_controller

logger = logging.getLogger()


class ModelWrapper(MAXModelWrapper):
    '''Wrap the SRGAN controller behind the MAX model interface.'''

    MODEL_META_DATA = model_meta

    def __init__(self, path=DEFAULT_MODEL_PATH):
        '''
        Load the SRGAN model.

        :param path: checkpoint location handed to the SRGAN controller
                     (defaults to DEFAULT_MODEL_PATH).
        '''
        logger.info('Loading model from: {}...'.format(path))

        # Initialize the SRGAN controller with the path we were given, so
        # the logged location and the loaded checkpoint always agree.
        self.SRGAN = SRGAN_controller(checkpoint=path)

        logger.info('Loaded model')

    def _read_image(self, image_data):
        '''Read the image from a Bytestream.'''
        image = skimage.io.imread(io.BytesIO(image_data), plugin='imageio')
        return image

    def _pre_process(self, image):
        '''
        Preprocess the image.

        1. Standardize the dtype and the number of channels (RGB)
        2. Reject images whose downscale factor would exceed 4;
           downscale the remaining images whose largest side exceeds 500px
        3. Normalize the image by its maximum value
        4. Convert to a float32 batch of shape (1, H, W, C)

        :param image: array as returned by `_read_image`.
        :return: float32 numpy array with a leading batch dimension.
        '''
        # Standardize input dtype of image
        image = image.astype('uint8')

        # If grayscale. Convert to RGB for consistency.
        if image.ndim != 3:
            image = skimage.color.gray2rgb(image)

        # If has an alpha channel, remove it for consistency
        if image.shape[-1] == 4:
            image = image[..., :3]

        # Resize dimensions that are too large to 500px (instead of raising an error)
        logger.info(f'image input dim: {image.shape[0]}x{image.shape[1]}')

        # a. find factor
        factor = np.ceil(max(image.shape[0], image.shape[1]) / 500)

        # b. resize
        if factor > 1:
            # at least one image dimension is bigger than 500px
            if factor > 4:
                message = "The dimensions of the image are too big (>2000px). The image would have been downscaled instead."
                logger.error(message)
                abort(400, message)

            # Pass the target shape as plain ints rather than numpy floats.
            target_shape = (int(np.floor(image.shape[0] / factor)),
                            int(np.floor(image.shape[1] / factor)))
            image = skimage.transform.resize(image, target_shape, anti_aliasing=True)
            logger.info(f'image resized to: {image.shape[0]}x{image.shape[1]}')

        # Normalize image. An all-black image has a maximum of 0; dividing
        # by it would fill the array with NaN, so leave such an image as is.
        peak = np.max(image)
        if peak > 0:
            image = image / peak

        # Convert the image to numpy array with dtype float32 as required by the SRGAN
        # (1, H, W, C)
        image = np.array([image]).astype(np.float32)

        return image

    def _predict(self, image):
        '''Call the model'''
        return self.SRGAN.upscale(image)

    def write_image(self, image):
        '''
        Return the generated image as output.

        :param image: image array to serialize.
        :return: io.BytesIO stream rewound to the start of the saved image.
        '''
        logger.info(f'image output dim: {image.shape[0]}x{image.shape[1]}')
        stream = io.BytesIO()
        skimage.io.imsave(stream, image)
        stream.seek(0)
        return stream

    def _post_process(self, result):
        '''Post-processing. The model output is returned unchanged.'''
        return result