# -*- coding: utf-8 -*-
"""Image classifiers: a local model, a URL-fetching wrapper, a remote HTTP
proxy, and a mock that returns canned messages."""
import cv2
import numpy as np
import h5py
import requests

import data
import deploy

from .decorators import timeout
from . import exceptions as exc
from .shortcuts import at_random


# Flag passed to cv2.imdecode to request a colour decode. The legacy
# CV_LOAD_IMAGE_COLOR name was removed in OpenCV 3; IMREAD_COLOR is the
# portable spelling, so prefer it and fall back only on very old builds.
_IMREAD_COLOR = getattr(cv2, 'IMREAD_COLOR', None)
if _IMREAD_COLOR is None:
    _IMREAD_COLOR = cv2.CV_LOAD_IMAGE_COLOR


@timeout(30)
def fetch_cvimage_from_url(url, maxsize=10 * 1024 * 1024):
    """Download ``url`` and decode its body into an OpenCV image.

    The body is streamed in 2048-byte chunks so that an oversized response
    is rejected as soon as it crosses ``maxsize``.

    :param url: address of the image to fetch.
    :param maxsize: maximum accepted body size, in bytes (default 10 MiB).
    :returns: the value returned by ``cv2.imdecode``, or ``None`` when the
        body is empty.
    :raises ValueError: if the body is larger than ``maxsize`` bytes.
    """
    req = requests.get(url, timeout=5, stream=True)
    try:
        # A bytearray grows in amortized O(1) and accepts byte chunks on both
        # Python 2 and 3; repeated str concatenation was quadratic and fails
        # on Python 3, where iter_content yields bytes.
        content = bytearray()
        for chunk in req.iter_content(2048):
            content += chunk
            if len(content) > maxsize:
                raise ValueError('Response too large')
    finally:
        # Release the streamed connection on success and on every error path,
        # not only on the oversize branch.
        req.close()
    if not content:
        # cv2.imdecode asserts on an empty buffer; report "no image" instead so
        # callers can treat it like any other undecodable response.
        return None
    img_array = np.asarray(content, dtype=np.uint8)
    return cv2.imdecode(img_array, _IMREAD_COLOR)


class MockClassifier(object):
    """Stand-in classifier that ignores its input and returns one canned
    prediction chosen at random."""

    def classify(self, *args, **kwargs):
        """Return a single-element list holding ``deploy.Prediction(1, message, 100)``.

        All arguments are accepted and ignored.
        """
        message = at_random(
            "I hope a mock message like this won't get caught by Twitter's spam filter",
            "But I must explain to you how all this mistaken idea was born",
            "At vero eos et accusamus et iusto odio dignissimos ducimus qui blanditiis",
            "Excepteur sint occaecat cupidatat non proident",
        )
        return [deploy.Prediction(1, message, 100)]


class ImageClassifier(object):
    """Classify in-memory OpenCV images with a model loaded via ``deploy``."""

    def __init__(self, dataset_path, input_shape, model_name='model'):
        """Load the model, the category mapping and the dataset mean.

        :param dataset_path: path handed to ``data.get_mean`` to obtain the
            average image used for normalisation.
        :param input_shape: shape the model expects; also used as the resize
            target in ``classify``.
        :param model_name: name passed through to ``deploy.load_model``.
        """
        catname_to_categories = data.get_categories()
        # Inverted mapping (category -> catname) used to label model outputs.
        self.category_to_catnames = {v: k for k, v in catname_to_categories.items()}
        self.model = deploy.load_model(
            input_shape=input_shape,
            n_outputs=len(catname_to_categories),
            model_name=model_name)
        self.input_shape = input_shape
        self.average_image = data.get_mean(dataset_path)

    def classify(self, cvimage):
        """Normalise ``cvimage`` and return ``deploy.apply_model``'s predictions."""
        normalized = deploy.normalize_cvimage(cvimage, size=self.input_shape,
                                              mean=self.average_image)
        return deploy.apply_model(normalized, self.model, self.category_to_catnames)


class URLClassifier(object):
    """Fetch an image by URL and delegate classification to a wrapped classifier."""

    def __init__(self, image_classifier):
        # Any object exposing ``classify(cvimage)``, e.g. ImageClassifier.
        self._image_classifier = image_classifier

    def classify(self, url=None):
        """Download and classify the image at ``url``.

        :raises exc.NotImage: if the response could not be decoded into an image.
        """
        cvimage = fetch_cvimage_from_url(url)
        if cvimage is None:
            raise exc.NotImage(url)
        return self._image_classifier.classify(cvimage)


class RemoteClassifier(object):
    """Proxy that forwards classification requests to a remote HTTP service."""

    def __init__(self, base_url):
        self._base_url = base_url

    def classify(self, **params):
        """GET ``base_url`` with ``params`` as query arguments and parse the reply.

        The JSON reply is expected to carry either an ``'error'`` key or a
        ``'y'`` list of keyword-argument dicts for ``deploy.Prediction``.

        :returns: a list of ``deploy.Prediction`` objects.
        :raises exc.RemoteError: if the reply contains an ``'error'`` key.
        :raises exc.TimeoutError: if the HTTP request times out (60 s).
        """
        # Only the HTTP round-trip can raise requests' Timeout, so keep the
        # try body to that call alone.
        try:
            r = requests.get(self._base_url, params=params, timeout=60).json()
        except requests.exceptions.Timeout:
            raise exc.TimeoutError
        if 'error' in r:
            raise exc.RemoteError(r['error'])
        # Materialise the predictions: map() is a one-shot lazy iterator on
        # Python 3, whereas callers written for Python 2 received a list.
        return [deploy.Prediction(**guess) for guess in r['y']]